DecTrain: Deciding When to Train a Monocular Depth DNN Online

Zih-Sing Fu*, Soumya Sudhakar*, Sertac Karaman, Vivienne Sze
*Equal Contribution
MIT LEAN

📄 [Paper] | 🌐 [Project]


This repo contains code to run DecTrain, an algorithm that decides when to train a monocular depth DNN, published in our paper.

Abstract

Deep neural networks (DNNs) can deteriorate in accuracy when deployment data differs from training data. While performing online training at all timesteps can improve accuracy, it is computationally expensive. We propose DecTrain, a new algorithm that decides when to train a monocular depth DNN online using self-supervision with low overhead. To make the decision at each timestep, DecTrain compares the cost of training with the predicted accuracy gain. We evaluate DecTrain on out-of-distribution data, and find DecTrain maintains accuracy compared to online training at all timesteps, while training only 44% of the time on average. We also compare the recovery of a low inference cost DNN using DecTrain and a more generalizable high inference cost DNN on various sequences. DecTrain recovers the majority (97%) of the accuracy gain of online training at all timesteps while reducing computation compared to the high inference cost DNN which recovers only 66%. With an even smaller DNN, we achieve 89% recovery while reducing computation by 56%. DecTrain enables low-cost online training for a smaller DNN to have competitive accuracy with a larger, more generalizable DNN at a lower overall computational cost.
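To make the decision rule concrete, here is a minimal sketch of the per-timestep loop described above: run depth inference every frame, predict the accuracy gain of one online training step, and train only when that gain outweighs the training cost. All names (decide_to_train, online_loop, the unit training cost) are illustrative, not this repo's actual API.

```python
# Hypothetical sketch of the DecTrain decision rule: train online only when
# the predicted accuracy gain exceeds the compute cost of a training step.
# Names and the cost scale below are illustrative, not the repo's interface.

TRAIN_COST = 1.0  # relative compute cost of one online training step (assumed)

def decide_to_train(predicted_gain: float, cost: float = TRAIN_COST) -> bool:
    """Train only when the expected accuracy gain justifies the cost."""
    return predicted_gain > cost

def online_loop(frames, depth_dnn, decision_dnn):
    """Run inference every frame; count how often training is triggered."""
    trained = 0
    for frame in frames:
        depth = depth_dnn(frame)           # always run depth inference
        gain = decision_dnn(frame, depth)  # predicted gain from training now
        if decide_to_train(gain):
            trained += 1
            # a self-supervised update of depth_dnn would happen here
    return trained
```

In the paper's evaluation this kind of gating trains only a fraction of timesteps (44% on average) while keeping most of the accuracy of always-on training.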

Installation

Clone this repo into a directory of your choice (e.g. dectrain) and initialize its submodules using

git clone https://github.com/soumya-ss/on_the_fly_learning.git dectrain
cd dectrain
git checkout code-release
git submodule update --init --recursive

This code was tested with Python 3.9 and PyTorch 2.0.0 on Ubuntu 20.04. To install the dependencies, create and activate a new virtual environment with Python 3.9, install the packages listed in requirements.txt, and then install PyTorch (see https://pytorch.org/get-started/locally/ for the command matching your CUDA version).

python3 -m venv <path/to/venv> && source <path/to/venv>/bin/activate                             # create and activate virtual environment
pip3 install -r requirements.txt                                                                 # install required packages
pip3 install torch==2.0.0 torchvision==0.15.0 --index-url https://download.pytorch.org/whl/cu118 # install PyTorch 2.0.0

Lastly, apply the patches to the submodules for interfacing with our code:

bash scripts/apply_submodule_patches.sh

Download pretrained models

Please find the pretrained monocular depth DNNs here, and the pretrained decision DNNs here. Download and unzip the files to dectrain/models/.

mkdir models
unzip <path/to/depth_models.zip> -d models
unzip <path/to/decision_models.zip> -d models

After unzipping, the models should appear in dectrain/models/depth and dectrain/models/decision. The depth DNNs are pretrained on NYUDepthV2, and the decision DNNs are pretrained on our pre-collected statistics here.

Download datasets

For online depth training, our experiments use ScanNet, SUN3D, and KITTI-360. Please download the datasets and create symbolic links in this repo.

mkdir -p datasets/depth
ln -s <path/to/scannet> datasets/depth/scannet
ln -s <path/to/sun3d> datasets/depth/sun3d
ln -s <path/to/kitti360> datasets/depth/kitti-360

For pretraining the decision DNN, please download our pre-collected training data here, and follow the steps below to save the files to datasets/decision:

mkdir -p datasets
unzip <path/to/decision_dataset.zip> -d datasets

This dataset is also used for source replay when running online decision DNN training. Note: we are still working on cleaning up the code for bridging the raw depth dataset format to our framework's interface.

Setup configurations

We provide the config files used for our experiments in configs/. You can also download them from here. configs/decision-dnn is for offline training of the decision DNN; the other folders are all for online depth DNN training.

Online depth DNN training

Please make sure to activate the virtual environment for this project. We provide an example script (scripts/online_train_depth_dnn.sh) for running the online depth DNN training:

bash scripts/online_train_depth_dnn.sh

The script will run the online depth DNN training with the provided experiment configs, then run a compute estimation at the end of training. The training results will be stored in outputs/.

Offline decision DNN training

Please make sure to activate the virtual environment for this project. We provide a script (scripts/offline_train_decision_dnn.sh) as an example of training our decision DNN:

bash scripts/offline_train_decision_dnn.sh

The script trains the decision DNN with our pre-collected online training statistics. To collect your own online training statistics with a given config, please make sure the following options are enabled:

acquisition_type: all
record_policy_training_data: true
record_policy_training_data_pattern: <list of 1/0, 1 is train, 0 is not train>
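For illustration, these keys might appear in a config file as follows. This is only a sketch: the pattern length and the concrete 1/0 values are hypothetical and depend on your experiment.

```yaml
# Hypothetical excerpt from an experiment config (values are illustrative)
acquisition_type: all
record_policy_training_data: true
# 1 = train at this timestep, 0 = do not train; an example 6-step pattern
record_policy_training_data_pattern: [1, 0, 1, 1, 0, 1]
```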

Acknowledgement

This work builds on many open-source projects, including CoDEPS, DinoV2, TUM-RGBD tools, and UfM. We greatly appreciate their awesome work!

Citation

If you reference this work, please consider citing the following:

@article{ral2025dectrain,
  title={DecTrain: Deciding When to Train a Monocular Depth DNN Online},
  author={Fu, Zih-Sing and Sudhakar, Soumya and Karaman, Sertac and Sze, Vivienne},
  journal={IEEE Robotics and Automation Letters},
  year={2025},
  publisher={IEEE}
}
