├── .gitignore
├── COPYING
├── README.md
├── pytorch
│   ├── README.md
│   ├── arguments_test_eigen.txt
│   ├── arguments_test_nyu.txt
│   ├── arguments_train_eigen.txt
│   ├── arguments_train_nyu.txt
│   ├── bts.py
│   ├── bts_dataloader.py
│   ├── bts_eval.py
│   ├── bts_live_3d.py
│   ├── bts_main.py
│   ├── bts_test.py
│   ├── distributed_sampler_no_evenly_divisible.py
│   ├── run_bts_eval_schedule.py
│   └── run_bts_live_3d.sh
├── tensorflow
│   ├── Dockerfile
│   ├── README.md
│   ├── arguments_test_eigen.txt
│   ├── arguments_test_nyu.txt
│   ├── arguments_train_eigen.txt
│   ├── arguments_train_nyu.txt
│   ├── average_gradients.py
│   ├── bts.py
│   ├── bts_dataloader.py
│   ├── bts_eval.py
│   ├── bts_live_3d.py
│   ├── bts_main.py
│   ├── bts_sequence.py
│   ├── bts_test.py
│   ├── custom_layer
│   │   ├── CMakeLists.txt
│   │   ├── _local_planar_guidance_grad.py
│   │   ├── local_planar_guidance.cc
│   │   ├── local_planar_guidance.cu
│   │   └── local_planar_guidance.h
│   ├── notebooks
│   │   └── example_nyu_v2.py.ipynb
│   ├── requirements.txt
│   ├── resnet_v1.py
│   └── run_bts_eval_schedule.py
├── train_test_inputs
│   ├── eigen_test_files_with_gt.txt
│   ├── eigen_train_files_with_gt.txt
│   ├── nyudepthv2_test_files_with_gt.txt
│   └── nyudepthv2_train_files_with_gt.txt
└── utils
    ├── download_from_gdrive.py
    ├── eval_with_pngs.py
    ├── extract_official_train_test_set_from_mat.py
    ├── kitti_archives_to_download.txt
    ├── nyudepthv2_archives_to_download.txt
    ├── splits.mat
    ├── sync_project_frames_multi_threads.m
    └── train_scenes.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by https://www.gitignore.io/api/c++,python
2 | # Edit at https://www.gitignore.io/?templates=c++,python
3 |
4 | # Notebook checkpoints
5 | .ipynb_checkpoints
6 |
7 | # python cache files
8 | *.pyc
9 |
10 | custom_layer/build
11 | .idea
12 | models
13 | result*
14 | .gitignore
15 | *.mat
16 | *.zip
17 | !splits.mat
18 | utils/toolbox_nyu_depth_v2
19 |
20 | ### C++ ###
21 | # Prerequisites
22 | *.d
23 |
24 | # Compiled Object files
25 | *.slo
26 | *.lo
27 | *.o
28 | *.obj
29 |
30 | # Precompiled Headers
31 | *.gch
32 | *.pch
33 |
34 | # Compiled Dynamic libraries
35 | *.so
36 | *.dylib
37 | *.dll
38 |
39 | # Fortran module files
40 | *.mod
41 | *.smod
42 |
43 | # Compiled Static libraries
44 | *.lai
45 | *.la
46 | *.a
47 | *.lib
48 |
49 | # Executables
50 | *.exe
51 | *.out
52 | *.app
53 |
54 | ### Python ###
55 | # Byte-compiled / optimized / DLL files
56 | __pycache__/
57 | *.py[cod]
58 | *$py.class
59 |
60 | # C extensions
61 |
62 | # Distribution / packaging
63 | .Python
64 | build/
65 | develop-eggs/
66 | dist/
67 | downloads/
68 | eggs/
69 | .eggs/
70 | lib/
71 | lib64/
72 | parts/
73 | sdist/
74 | var/
75 | wheels/
76 | pip-wheel-metadata/
77 | share/python-wheels/
78 | *.egg-info/
79 | .installed.cfg
80 | *.egg
81 | MANIFEST
82 |
83 | # PyInstaller
84 | # Usually these files are written by a python script from a template
85 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
86 | *.manifest
87 | *.spec
88 |
89 | # Installer logs
90 | pip-log.txt
91 | pip-delete-this-directory.txt
92 |
93 | # Unit test / coverage reports
94 | htmlcov/
95 | .tox/
96 | .nox/
97 | .coverage
98 | .coverage.*
99 | .cache
100 | nosetests.xml
101 | coverage.xml
102 | *.cover
103 | .hypothesis/
104 | .pytest_cache/
105 |
106 | # Translations
107 | *.mo
108 | *.pot
109 |
110 | # Scrapy stuff:
111 | .scrapy
112 |
113 | # Sphinx documentation
114 | docs/_build/
115 |
116 | # PyBuilder
117 | target/
118 |
119 | # pyenv
120 | .python-version
121 |
122 | # pipenv
123 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
124 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
125 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
126 | # install all needed dependencies.
127 | #Pipfile.lock
128 |
129 | # celery beat schedule file
130 | celerybeat-schedule
131 |
132 | # SageMath parsed files
133 | *.sage.py
134 |
135 | # Spyder project settings
136 | .spyderproject
137 | .spyproject
138 |
139 | # Rope project settings
140 | .ropeproject
141 |
142 | # Mr Developer
143 | .mr.developer.cfg
144 | .project
145 | .pydevproject
146 |
147 | # mkdocs documentation
148 | /site
149 |
150 | # mypy
151 | .mypy_cache/
152 | .dmypy.json
153 | dmypy.json
154 |
155 | # Pyre type checker
156 | .pyre/
157 |
158 | # End of https://www.gitignore.io/api/c++,python
159 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # BTS
2 |
3 | From Big to Small: Multi-Scale Local Planar Guidance for Monocular Depth Estimation
4 | [arXiv](https://arxiv.org/abs/1907.10326)
5 | [Supplementary material](https://arxiv.org/src/1907.10326v4/anc/bts_sm.pdf)
6 |
7 | ## Video Demo 1
8 | [Watch on YouTube](https://www.youtube.com/watch?v=2fPdZYzx9Cg)
9 | ## Video Demo 2
10 | [Watch on YouTube](https://www.youtube.com/watch?v=1J-GSb0fROw)
11 |
12 | ## Note
13 | This repository contains TensorFlow and PyTorch implementations of BTS.
14 | ## Preparation for all implementations
15 | ```shell
16 | $ cd ~
17 | $ mkdir workspace
18 | $ cd workspace
19 | ### Make a folder for datasets
20 | $ mkdir dataset
21 | ### Clone this repo
22 | $ git clone https://github.com/cleinc/bts
23 | ```
24 | ## Prepare [NYU Depth V2](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html) test set
25 | ```shell
26 | $ cd ~/workspace/bts/utils
27 | ### Get official NYU Depth V2 split file
28 | $ wget http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/nyu_depth_v2_labeled.mat
29 | ### Convert mat file to image files
30 | $ python extract_official_train_test_set_from_mat.py nyu_depth_v2_labeled.mat splits.mat ../../dataset/nyu_depth_v2/official_splits/
31 | ```
32 | ## Prepare [KITTI](http://www.cvlibs.net/download.php?file=data_depth_annotated.zip) official ground truth depth maps
33 | Download the ground truth depth maps from this link: [KITTI](http://www.cvlibs.net/download.php?file=data_depth_annotated.zip).\
34 | Then,
35 | ```
36 | $ cd ~/workspace/dataset
37 | $ mkdir kitti_dataset && cd kitti_dataset
38 | $ mv ~/Downloads/data_depth_annotated.zip .
39 | $ unzip data_depth_annotated.zip
40 | ```
41 |
42 | Follow the instructions for the implementation of your choice below.
43 |
44 | ## TensorFlow Implementation
45 | [[./tensorflow/]](./tensorflow/)
46 | ## PyTorch Implementation
47 | [[./pytorch/]](./pytorch/)
48 |
49 | ## Model Zoo
50 | ### KITTI Eigen Split
51 |
52 | | Base Network | cap | d1 | d2 | d3 | AbsRel | SqRel | RMSE | RMSElog | SILog | log10 | #Params | Model Download |
53 | |:------------:|:-----:|:-----:|:-----:|:-----:|:------:|:-----:|:-----:|:-------:|:-----:|:-----:|:-------:|:--------------------------------:|
54 | | ResNet50 | 0-80m | 0.954 | 0.992 | 0.998 | 0.061 | 0.250 | 2.803 | 0.098 | 9.030 | 0.027 | 49.5M | [bts_eigen_v2_pytorch_resnet50](https://cogaplex-bts.s3.ap-northeast-2.amazonaws.com/bts_eigen_v2_pytorch_resnet50.zip) |
55 | | ResNet101 | 0-80m | 0.954 | 0.992 | 0.998 | 0.061 | 0.261 | 2.834 | 0.099 | 9.075 | 0.027 | 68.5M | [bts_eigen_v2_pytorch_resnet101](https://cogaplex-bts.s3.ap-northeast-2.amazonaws.com/bts_eigen_v2_pytorch_resnet101.zip) |
56 | | ResNext50 | 0-80m | 0.954 | 0.993 | 0.998 | 0.061 | 0.245 | 2.774 | 0.098 | 9.014 | 0.027 | 49.0M | [bts_eigen_v2_pytorch_resnext50](https://cogaplex-bts.s3.ap-northeast-2.amazonaws.com/bts_eigen_v2_pytorch_resnext50.zip) |
57 | | ResNext101 | 0-80m | 0.956 | 0.993 | 0.998 | 0.059 | 0.241 | 2.756 | 0.096 | 8.781 | 0.026 | 112.8M | [bts_eigen_v2_pytorch_resnext101](https://cogaplex-bts.s3.ap-northeast-2.amazonaws.com/bts_eigen_v2_pytorch_resnext101.zip) |
58 | | DenseNet121 | 0-80m | 0.951 | 0.993 | 0.998 | 0.063 | 0.256 | 2.850 | 0.100 | 9.221 | 0.028 | 21.2M | [bts_eigen_v2_pytorch_densenet121](https://cogaplex-bts.s3.ap-northeast-2.amazonaws.com/bts_eigen_v2_pytorch_densenet121.zip) |
59 | | DenseNet161 | 0-80m | 0.955 | 0.993 | 0.998 | 0.060 | 0.249 | 2.798 | 0.096 | 8.933 | 0.027 | 47.0M | [bts_eigen_v2_pytorch_densenet161](https://cogaplex-bts.s3.ap-northeast-2.amazonaws.com/bts_eigen_v2_pytorch_densenet161.zip) |
60 |
61 | ### NYU Depth V2
62 |
63 | | Base Network | d1 | d2 | d3 | AbsRel | SqRel | RMSE | RMSElog | SILog | log10 | #Params | Model Download |
64 | |:------------:|:-----:|:-----:|:-----:|:------:|:-----:|:-----:|:-------:|:------:|:-----:|:-------:|:------------------------------:|
65 | | ResNet50 | 0.865 | 0.975 | 0.993 | 0.119 | 0.075 | 0.419 | 0.152 | 12.368 | 0.051 | 49.5M | [bts_nyu_v2_pytorch_resnet50](https://cogaplex-bts.s3.ap-northeast-2.amazonaws.com/bts_nyu_v2_pytorch_resnet50.zip) |
66 | | ResNet101 | 0.871 | 0.977 | 0.995 | 0.113 | 0.068 | 0.407 | 0.148 | 11.886 | 0.049 | 68.5M | [bts_nyu_v2_pytorch_resnet101](https://cogaplex-bts.s3.ap-northeast-2.amazonaws.com/bts_nyu_v2_pytorch_resnet101.zip) |
67 | | ResNext50 | 0.867 | 0.977 | 0.995 | 0.116 | 0.070 | 0.414 | 0.150 | 12.186 | 0.050 | 49.0M | [bts_nyu_v2_pytorch_resnext50](https://cogaplex-bts.s3.ap-northeast-2.amazonaws.com/bts_nyu_v2_pytorch_resnext50.zip) |
68 | | ResNext101 | 0.880 | 0.977 | 0.994 | 0.111 | 0.069 | 0.399 | 0.145 | 11.680 | 0.048 | 112.8M | [bts_nyu_v2_pytorch_resnext101](https://cogaplex-bts.s3.ap-northeast-2.amazonaws.com/bts_nyu_v2_pytorch_resnext101.zip) |
69 | | DenseNet121 | 0.871 | 0.977 | 0.993 | 0.118 | 0.072 | 0.410 | 0.149 | 12.028 | 0.050 | 21.2M | [bts_nyu_v2_pytorch_densenet121](https://cogaplex-bts.s3.ap-northeast-2.amazonaws.com/bts_nyu_v2_pytorch_densenet121.zip) |
70 | | DenseNet161 | 0.885 | 0.978 | 0.994 | 0.110 | 0.066 | 0.392 | 0.142 | 11.533 | 0.047 | 47.0M | [bts_nyu_v2_pytorch_densenet161](https://cogaplex-bts.s3.ap-northeast-2.amazonaws.com/bts_nyu_v2_pytorch_densenet161.zip) |
71 | | MobileNetV2 | TBA | TBA | TBA | TBA | TBA | TBA | TBA | TBA | TBA | 16.3M | [bts_nyu_v2_pytorch_mobilenetv2](https://cogaplex-bts.s3.ap-northeast-2.amazonaws.com/bts_nyu_v2_pytorch_mobilenetv2.zip) |
72 |
73 | Note: Modify arguments '--encoder', '--model_name', '--checkpoint_path' and '--pred_path' accordingly.
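For instance, any model in the tables above can be fetched and unpacked the same way as in the per-framework instructions below. A sketch for the PyTorch ResNet50 KITTI model (swap the URL for any other row):

```shell
$ cd ~/workspace/bts/pytorch
$ mkdir -p models && cd models
### Download and unpack one of the Model Zoo checkpoints
$ wget https://cogaplex-bts.s3.ap-northeast-2.amazonaws.com/bts_eigen_v2_pytorch_resnet50.zip
$ unzip bts_eigen_v2_pytorch_resnet50.zip
```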
74 |
75 | ## Live Demo
76 | Finally, we provide live 3D demo implementations for both TensorFlow and PyTorch. \
77 | For best results, obtain the correct intrinsic values for your webcam and put them in bts_live_3d.py. \
78 | Sample usage for PyTorch:
79 | ```
80 | $ cd ~/workspace/bts/pytorch
81 | $ python bts_live_3d.py --model_name bts_nyu_v2_pytorch_densenet161 \
82 | --encoder densenet161_bts \
83 | --checkpoint_path ./models/bts_nyu_v2_pytorch_densenet161/model \
84 | --max_depth 10 \
85 | --input_height 480 \
86 | --input_width 640
87 | ```
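As a rough illustration of what such intrinsic values look like, here is a minimal sketch of a pinhole camera matrix. The variable names and numbers are placeholders, not the actual ones used inside bts_live_3d.py, so adapt them to that script and to your own calibration:

```python
import numpy as np

# Placeholder pinhole intrinsics for a 640x480 webcam.
fx, fy = 518.9, 519.5   # focal lengths in pixels (placeholder values)
cx, cy = 320.0, 240.0   # principal point (placeholder values)

K = np.array([[fx, 0.0, cx],
              [0.0, fy, cy],
              [0.0, 0.0, 1.0]], dtype=np.float32)
print(K)
```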
88 |
89 | ## Citation
90 | If you find this work useful for your research, please consider citing our paper:
91 | ```
92 | @article{lee2019big,
93 | title={From big to small: Multi-scale local planar guidance for monocular depth estimation},
94 | author={Lee, Jin Han and Han, Myung-Kyu and Ko, Dong Wook and Suh, Il Hong},
95 | journal={arXiv preprint arXiv:1907.10326},
96 | year={2019}
97 | }
98 | ```
99 |
100 | ## License
101 | Copyright (C) 2019 Jin Han Lee, Myung-Kyu Han, Dong Wook Ko and Il Hong Suh \
102 | This Software is licensed under GPL-3.0-or-later.
103 |
--------------------------------------------------------------------------------
/pytorch/README.md:
--------------------------------------------------------------------------------
1 | # BTS
2 | From Big to Small: Multi-Scale Local Planar Guidance for Monocular Depth Estimation
3 | [arXiv](https://arxiv.org/abs/1907.10326)
4 | [Supplementary material](https://arxiv.org/src/1907.10326v4/anc/bts_sm.pdf)
5 |
6 | ## Note
7 | This folder contains a PyTorch implementation of BTS.\
8 | We tested this code under Python 3.6, PyTorch 1.2.0, and CUDA 10.0 on Ubuntu 18.04.
9 |
10 | ## Testing with [NYU Depth V2](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html)
11 | First, make sure that you have prepared the test set using the instructions in README.md at the root of this repo.
12 | ```shell
13 | $ cd ~/workspace/bts/pytorch
14 | $ mkdir models
15 | ### Get BTS model trained with NYU Depth V2
16 | $ cd models
17 | $ wget https://cogaplex-bts.s3.ap-northeast-2.amazonaws.com/bts_nyu_v2_pytorch_densenet161.zip
18 | $ unzip bts_nyu_v2_pytorch_densenet161.zip
19 | ```
20 | Once the preparation steps are completed, you can test BTS using the following commands.
21 | ```
22 | $ cd ~/workspace/bts/pytorch
23 | $ python bts_test.py arguments_test_nyu.txt
24 | ```
25 | This will save results to ./result_bts_nyu_v2_pytorch_densenet161. With a single RTX 2080 Ti, it takes about 41 seconds to process the 654 test images.
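If you want to inspect the saved raw depth predictions directly, they can be read back the same way the evaluation scripts read the ground truth PNGs. A minimal sketch, assuming the NYU scaling of 1000 PNG units per metre (256 for KITTI) and a placeholder filename:

```python
import cv2
import numpy as np

# Placeholder path; point it at any PNG under ./result_bts_nyu_v2_pytorch_densenet161/raw/
pred_path = './result_bts_nyu_v2_pytorch_densenet161/raw/some_prediction.png'

pred = cv2.imread(pred_path, -1)                  # read the 16-bit PNG as-is
pred_metres = pred.astype(np.float32) / 1000.0    # NYU convention; use 256.0 for KITTI
print('depth range: {:.3f} m to {:.3f} m'.format(pred_metres.min(), pred_metres.max()))
```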
26 |
27 | ## Evaluation
28 | The following command will evaluate the prediction results for NYU Depth V2.
29 | ```
30 | $ cd ~/workspace/bts/pytorch
31 | $ python ../utils/eval_with_pngs.py --pred_path result_bts_nyu_v2_pytorch_densenet161/raw/ --gt_path ../../dataset/nyu_depth_v2/official_splits/test/ --dataset nyu --min_depth_eval 1e-3 --max_depth_eval 10 --eigen_crop
32 | ```
33 |
34 | You should see outputs like this:
35 | ```
36 | Raw png files reading done
37 | Evaluating 654 files
38 | GT files reading done
39 | 0 GT files missing
40 | Computing errors
41 | d1, d2, d3, AbsRel, SqRel, RMSE, RMSElog, SILog, log10
42 | 0.885, 0.978, 0.994, 0.110, 0.066, 0.392, 0.142, 11.533, 0.047
43 | Done.
44 | ```
45 |
46 | ## Preparation for Training
47 | ### NYU Depth V2
48 | Download the dataset we used in this work.
49 | ```
50 | $ cd ~/workspace/bts
51 | $ python utils/download_from_gdrive.py 1AysroWpfISmm-yRFGBgFTrLy6FjQwvwP ../dataset/nyu_depth_v2/sync.zip
52 | $ cd ../dataset/nyu_depth_v2 && unzip sync.zip
53 | ```
54 | Alternatively, you can download it from the following link:
55 | https://drive.google.com/file/d/1AysroWpfISmm-yRFGBgFTrLy6FjQwvwP/view?usp=sharing
56 | Please make sure the downloaded file ends up at ~/workspace/dataset/nyu_depth_v2/sync.zip
57 |
58 | Once the dataset is ready, you can train the network using the following command.
59 | ```
60 | $ cd ~/workspace/bts/pytorch
61 | $ python bts_main.py arguments_train_nyu.txt
62 | ```
63 | You can check the training using tensorboard:
64 | ```
65 | $ tensorboard --logdir ./models/bts_nyu_test/ --port 6006
66 | ```
67 | Open localhost:6006 with your favorite browser to see the progress of training.
68 |
69 | ### KITTI
70 | You can also train BTS on the KITTI dataset using the following procedure.
71 | First, make sure that you have prepared the ground truth depth maps from [KITTI](http://www.cvlibs.net/download.php?file=data_depth_annotated.zip).
72 | If you have not, please follow the instructions in README.md at the root of this repo.
73 | Then, download and unzip the raw dataset using the following commands.
74 | ```
75 | $ cd ~/workspace/dataset/kitti_dataset
76 | $ aria2c -x 16 -i ../../bts/utils/kitti_archives_to_download.txt
77 | $ parallel unzip ::: *.zip
78 | ```
79 | Finally, we can train our network with
80 | ```
81 | $ cd ~/workspace/bts/pytorch
82 | $ python bts_main.py arguments_train_eigen.txt
83 | ```
84 |
85 | ## Testing and Evaluation with [KITTI](http://www.cvlibs.net/datasets/kitti/eval_depth.php?benchmark=depth_prediction)
86 | Once you have the KITTI dataset and the official ground truth depth maps, you can test and evaluate our model with the following commands.
87 | ```
88 | ### Get model trained with KITTI Eigen split
89 | $ cd ~/workspace/bts/pytorch/models
90 | $ wget https://cogaplex-bts.s3.ap-northeast-2.amazonaws.com/bts_eigen_v2_pytorch_densenet161.zip
92 | $ unzip bts_eigen_v2_pytorch_densenet161.zip
92 | ```
93 | Test and save results.
94 | ```
95 | $ cd ~/workspace/bts/pytorch
96 | $ python bts_test.py arguments_test_eigen.txt
97 | ```
98 | This will save results to ./result_bts_eigen_v2_pytorch_densenet161.
99 | Finally, we can evaluate the prediction results with
100 | ```
101 | $ cd ~/workspace/bts/pytorch
102 | $ python ../utils/eval_with_pngs.py --pred_path result_bts_eigen_v2_pytorch_densenet161/raw/ --gt_path ../../dataset/kitti_dataset/data_depth_annotated/ --dataset kitti --min_depth_eval 1e-3 --max_depth_eval 80 --do_kb_crop --garg_crop
103 | ```
104 | You should see outputs like this (the missing GT files are expected: some frames in the Eigen test split have no official ground truth and are simply skipped during evaluation):
105 | ```
106 | GT files reading done
107 | 45 GT files missing
108 | Computing errors
109 | d1, d2, d3, AbsRel, SqRel, RMSE, RMSElog, SILog, log10
110 | 0.955, 0.993, 0.998, 0.060, 0.249, 2.798, 0.096, 8.933, 0.027
111 | Done.
112 | ```
113 |
114 | Also, in this PyTorch implementation, you can use various base networks with pretrained weights as the encoder for BTS.\
115 | Available options are: resnet50_bts, resnet101_bts, resnext50_bts, resnext101_bts, densenet121_bts and densenet161_bts.\
116 | Simply change the '--encoder' argument in arguments_train_*.txt to your choice.
117 |
118 | ## Model Zoo
119 | ### KITTI Eigen Split
120 |
121 | | Base Network | d1 | d2 | d3 | AbsRel | SqRel | RMSE | RMSElog | SILog | log10 | #Params | Model Download |
122 | |:------------:|:-----:|:-----:|:-----:|:------:|:-----:|:-----:|:-------:|:-----:|:-----:|:-------:|:--------------------------------:|
123 | | ResNet50 | 0.954 | 0.992 | 0.998 | 0.061 | 0.250 | 2.803 | 0.098 | 9.030 | 0.027 | 49.5M | [bts_eigen_v2_pytorch_resnet50](https://cogaplex-bts.s3.ap-northeast-2.amazonaws.com/bts_eigen_v2_pytorch_resnet50.zip) |
124 | | ResNet101 | 0.954 | 0.992 | 0.998 | 0.061 | 0.261 | 2.834 | 0.099 | 9.075 | 0.027 | 68.5M | [bts_eigen_v2_pytorch_resnet101](https://cogaplex-bts.s3.ap-northeast-2.amazonaws.com/bts_eigen_v2_pytorch_resnet101.zip) |
125 | | ResNext50 | 0.954 | 0.993 | 0.998 | 0.061 | 0.245 | 2.774 | 0.098 | 9.014 | 0.027 | 49.0M | [bts_eigen_v2_pytorch_resnext50](https://cogaplex-bts.s3.ap-northeast-2.amazonaws.com/bts_eigen_v2_pytorch_resnext50.zip) |
126 | | ResNext101 | 0.956 | 0.993 | 0.998 | 0.059 | 0.241 | 2.756 | 0.096 | 8.781 | 0.026 | 112.8M | [bts_eigen_v2_pytorch_resnext101](https://cogaplex-bts.s3.ap-northeast-2.amazonaws.com/bts_eigen_v2_pytorch_resnext101.zip) |
127 | | DenseNet121 | 0.951 | 0.993 | 0.998 | 0.063 | 0.256 | 2.850 | 0.100 | 9.221 | 0.028 | 21.2M | [bts_eigen_v2_pytorch_densenet121](https://cogaplex-bts.s3.ap-northeast-2.amazonaws.com/bts_eigen_v2_pytorch_densenet121.zip) |
128 | | DenseNet161 | 0.955 | 0.993 | 0.998 | 0.060 | 0.249 | 2.798 | 0.096 | 8.933 | 0.027 | 47.0M | [bts_eigen_v2_pytorch_densenet161](https://cogaplex-bts.s3.ap-northeast-2.amazonaws.com/bts_eigen_v2_pytorch_densenet161.zip) |
129 |
130 | ### NYU Depth V2
131 |
132 | | Base Network | d1 | d2 | d3 | AbsRel | SqRel | RMSE | RMSElog | SILog | log10 | #Params | Model Download |
133 | |:------------:|:-----:|:-----:|:-----:|:------:|:-----:|:-----:|:-------:|:------:|:-----:|:-------:|:------------------------------:|
134 | | ResNet50 | 0.865 | 0.975 | 0.993 | 0.119 | 0.075 | 0.419 | 0.152 | 12.368 | 0.051 | 49.5M | [bts_nyu_v2_pytorch_resnet50](https://cogaplex-bts.s3.ap-northeast-2.amazonaws.com/bts_nyu_v2_pytorch_resnet50.zip) |
135 | | ResNet101 | 0.871 | 0.977 | 0.995 | 0.113 | 0.068 | 0.407 | 0.148 | 11.886 | 0.049 | 68.5M | [bts_nyu_v2_pytorch_resnet101](https://cogaplex-bts.s3.ap-northeast-2.amazonaws.com/bts_nyu_v2_pytorch_resnet101.zip) |
136 | | ResNext50 | 0.867 | 0.977 | 0.995 | 0.116 | 0.070 | 0.414 | 0.150 | 12.186 | 0.050 | 49.0M | [bts_nyu_v2_pytorch_resnext50](https://cogaplex-bts.s3.ap-northeast-2.amazonaws.com/bts_nyu_v2_pytorch_resnext50.zip) |
137 | | ResNext101 | 0.880 | 0.977 | 0.994 | 0.111 | 0.069 | 0.399 | 0.145 | 11.680 | 0.048 | 112.8M | [bts_nyu_v2_pytorch_resnext101](https://cogaplex-bts.s3.ap-northeast-2.amazonaws.com/bts_nyu_v2_pytorch_resnext101.zip) |
138 | | DenseNet121 | 0.871 | 0.977 | 0.993 | 0.118 | 0.072 | 0.410 | 0.149 | 12.028 | 0.050 | 21.2M | [bts_nyu_v2_pytorch_densenet121](https://cogaplex-bts.s3.ap-northeast-2.amazonaws.com/bts_nyu_v2_pytorch_densenet121.zip) |
139 | | DenseNet161 | 0.885 | 0.978 | 0.994 | 0.110 | 0.066 | 0.392 | 0.142 | 11.533 | 0.047 | 47.0M | [bts_nyu_v2_pytorch_densenet161](https://cogaplex-bts.s3.ap-northeast-2.amazonaws.com/bts_nyu_v2_pytorch_densenet161.zip) |
140 |
141 | Note: Modify arguments '--encoder', '--model_name', '--checkpoint_path' and '--pred_path' accordingly.
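For example, to test with the ResNet50 NYU model from the table above, arguments_test_nyu.txt would change roughly as follows; only the encoder, model name and checkpoint path differ from the DenseNet161 defaults shipped in this repo:

```
--encoder resnet50_bts
--data_path ../../dataset/nyu_depth_v2/official_splits/test/
--dataset nyu
--filenames_file ../train_test_inputs/nyudepthv2_test_files_with_gt.txt
--model_name bts_nyu_v2_pytorch_resnet50
--checkpoint_path ./models/bts_nyu_v2_pytorch_resnet50/model
--input_height 480
--input_width 640
--max_depth 10
```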
142 |
143 | ## License
144 | Copyright (C) 2019 Jin Han Lee, Myung-Kyu Han, Dong Wook Ko and Il Hong Suh \
145 | This Software is licensed under GPL-3.0-or-later.
146 |
--------------------------------------------------------------------------------
/pytorch/arguments_test_eigen.txt:
--------------------------------------------------------------------------------
1 | --encoder densenet161_bts
2 | --data_path ../../dataset/kitti_dataset/
3 | --dataset kitti
4 | --filenames_file ../train_test_inputs/eigen_test_files_with_gt.txt
5 | --model_name bts_eigen_v2_pytorch_densenet161
6 | --checkpoint_path ./models/bts_eigen_v2_pytorch_densenet161/model
7 | --input_height 352
8 | --input_width 1216
9 | --max_depth 80
10 | --do_kb_crop
11 |
--------------------------------------------------------------------------------
/pytorch/arguments_test_nyu.txt:
--------------------------------------------------------------------------------
1 | --encoder densenet161_bts
2 | --data_path ../../dataset/nyu_depth_v2/official_splits/test/
3 | --dataset nyu
4 | --filenames_file ../train_test_inputs/nyudepthv2_test_files_with_gt.txt
5 | --model_name bts_nyu_v2_pytorch_densenet161
6 | --checkpoint_path ./models/bts_nyu_v2_pytorch_densenet161/model
7 | --input_height 480
8 | --input_width 640
9 | --max_depth 10
10 |
--------------------------------------------------------------------------------
/pytorch/arguments_train_eigen.txt:
--------------------------------------------------------------------------------
1 | --mode train
2 | --model_name bts_eigen_v2_pytorch_test
3 | --encoder densenet161_bts
4 | --dataset kitti
5 | --data_path ../../dataset/kitti_dataset/
6 | --gt_path ../../dataset/kitti_dataset/data_depth_annotated/
7 | --filenames_file ../train_test_inputs/eigen_train_files_with_gt.txt
8 | --batch_size 4
9 | --num_epochs 50
10 | --learning_rate 1e-4
11 | --weight_decay 1e-2
12 | --adam_eps 1e-3
13 | --num_threads 1
14 | --input_height 352
15 | --input_width 704
16 | --max_depth 80
17 | --do_kb_crop
18 | --do_random_rotate
19 | --degree 1.0
20 | --log_directory ./models/
21 | --multiprocessing_distributed
22 | --dist_url tcp://127.0.0.1:2345
23 |
24 | --log_freq 100
25 | --do_online_eval
26 | --eval_freq 500
27 | --data_path_eval ../../dataset/kitti_dataset/
28 | --gt_path_eval ../../dataset/kitti_dataset/data_depth_annotated/
29 | --filenames_file_eval ../train_test_inputs/eigen_test_files_with_gt.txt
30 | --min_depth_eval 1e-3
31 | --max_depth_eval 80
32 | --eval_summary_directory ./models/eval/
33 | --garg_crop
--------------------------------------------------------------------------------
/pytorch/arguments_train_nyu.txt:
--------------------------------------------------------------------------------
1 | --mode train
2 | --model_name bts_nyu_v2_pytorch_test
3 | --encoder densenet161_bts
4 | --dataset nyu
5 | --data_path ../../dataset/nyu_depth_v2/sync/
6 | --gt_path ../../dataset/nyu_depth_v2/sync/
7 | --filenames_file ../train_test_inputs/nyudepthv2_train_files_with_gt.txt
8 | --batch_size 4
9 | --num_epochs 50
10 | --learning_rate 1e-4
11 | --weight_decay 1e-2
12 | --adam_eps 1e-3
13 | --num_threads 1
14 | --input_height 416
15 | --input_width 544
16 | --max_depth 10
17 | --do_random_rotate
18 | --degree 2.5
19 | --log_directory ./models/
20 | --multiprocessing_distributed
21 | --dist_url tcp://127.0.0.1:2345
22 |
23 | --log_freq 100
24 | --do_online_eval
25 | --eval_freq 500
26 | --data_path_eval ../../dataset/nyu_depth_v2/official_splits/test/
27 | --gt_path_eval ../../dataset/nyu_depth_v2/official_splits/test/
28 | --filenames_file_eval ../train_test_inputs/nyudepthv2_test_files_with_gt.txt
29 | --min_depth_eval 1e-3
30 | --max_depth_eval 10
31 | --eval_summary_directory ./models/eval/
32 | --eigen_crop
--------------------------------------------------------------------------------
/pytorch/bts.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2019 Jin Han Lee
2 | #
3 | # This file is a part of BTS.
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU General Public License
15 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 |
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.functional as torch_nn_func
20 | import math
21 |
22 | from collections import namedtuple
23 |
24 |
25 | # This sets the batch norm layers in pytorch as if {'is_training': False, 'scale': True} in tensorflow
26 | def bn_init_as_tf(m):
27 | if isinstance(m, nn.BatchNorm2d):
28 | m.track_running_stats = True # These two lines enable using stats (moving mean and var) loaded from pretrained model
29 | m.eval() # or zero mean and variance of one if the batch norm layer has no pretrained values
30 | m.affine = True
31 | m.requires_grad = True
32 |
33 |
34 | def weights_init_xavier(m):
35 | if isinstance(m, nn.Conv2d):
36 | torch.nn.init.xavier_uniform_(m.weight)
37 | if m.bias is not None:
38 | torch.nn.init.zeros_(m.bias)
39 |
40 |
41 | class silog_loss(nn.Module):
42 | def __init__(self, variance_focus):
43 | super(silog_loss, self).__init__()
44 | self.variance_focus = variance_focus
45 |
46 | def forward(self, depth_est, depth_gt, mask):
47 | d = torch.log(depth_est[mask]) - torch.log(depth_gt[mask])
48 | return torch.sqrt((d ** 2).mean() - self.variance_focus * (d.mean() ** 2)) * 10.0
49 |
50 |
51 | class atrous_conv(nn.Sequential):
52 | def __init__(self, in_channels, out_channels, dilation, apply_bn_first=True):
53 | super(atrous_conv, self).__init__()
54 | self.atrous_conv = torch.nn.Sequential()
55 | if apply_bn_first:
56 | self.atrous_conv.add_module('first_bn', nn.BatchNorm2d(in_channels, momentum=0.01, affine=True, track_running_stats=True, eps=1.1e-5))
57 |
58 | self.atrous_conv.add_module('aconv_sequence', nn.Sequential(nn.ReLU(),
59 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels*2, bias=False, kernel_size=1, stride=1, padding=0),
60 | nn.BatchNorm2d(out_channels*2, momentum=0.01, affine=True, track_running_stats=True),
61 | nn.ReLU(),
62 | nn.Conv2d(in_channels=out_channels * 2, out_channels=out_channels, bias=False, kernel_size=3, stride=1,
63 | padding=(dilation, dilation), dilation=dilation)))
64 |
65 | def forward(self, x):
66 | return self.atrous_conv.forward(x)
67 |
68 |
69 | class upconv(nn.Module):
70 | def __init__(self, in_channels, out_channels, ratio=2):
71 | super(upconv, self).__init__()
72 | self.elu = nn.ELU()
73 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, bias=False, kernel_size=3, stride=1, padding=1)
74 | self.ratio = ratio
75 |
76 | def forward(self, x):
77 | up_x = torch_nn_func.interpolate(x, scale_factor=self.ratio, mode='nearest')
78 | out = self.conv(up_x)
79 | out = self.elu(out)
80 | return out
81 |
82 |
83 | class reduction_1x1(nn.Sequential):
84 | def __init__(self, num_in_filters, num_out_filters, max_depth, is_final=False):
85 | super(reduction_1x1, self).__init__()
86 | self.max_depth = max_depth
87 | self.is_final = is_final
88 | self.sigmoid = nn.Sigmoid()
89 | self.reduc = torch.nn.Sequential()
90 |
91 | while num_out_filters >= 4:
92 | if num_out_filters < 8:
93 | if self.is_final:
94 | self.reduc.add_module('final', torch.nn.Sequential(nn.Conv2d(num_in_filters, out_channels=1, bias=False,
95 | kernel_size=1, stride=1, padding=0),
96 | nn.Sigmoid()))
97 | else:
98 | self.reduc.add_module('plane_params', torch.nn.Conv2d(num_in_filters, out_channels=3, bias=False,
99 | kernel_size=1, stride=1, padding=0))
100 | break
101 | else:
102 | self.reduc.add_module('inter_{}_{}'.format(num_in_filters, num_out_filters),
103 | torch.nn.Sequential(nn.Conv2d(in_channels=num_in_filters, out_channels=num_out_filters,
104 | bias=False, kernel_size=1, stride=1, padding=0),
105 | nn.ELU()))
106 |
107 | num_in_filters = num_out_filters
108 | num_out_filters = num_out_filters // 2
109 |
110 | def forward(self, net):
111 | net = self.reduc.forward(net)
112 | if not self.is_final:
113 | theta = self.sigmoid(net[:, 0, :, :]) * math.pi / 3
114 | phi = self.sigmoid(net[:, 1, :, :]) * math.pi * 2
115 | dist = self.sigmoid(net[:, 2, :, :]) * self.max_depth
116 | n1 = torch.mul(torch.sin(theta), torch.cos(phi)).unsqueeze(1)
117 | n2 = torch.mul(torch.sin(theta), torch.sin(phi)).unsqueeze(1)
118 | n3 = torch.cos(theta).unsqueeze(1)
119 | n4 = dist.unsqueeze(1)
120 | net = torch.cat([n1, n2, n3, n4], dim=1)
121 |
122 | return net
123 |
124 | class local_planar_guidance(nn.Module):
125 | def __init__(self, upratio):
126 | super(local_planar_guidance, self).__init__()
127 | self.upratio = upratio
128 | self.u = torch.arange(self.upratio).reshape([1, 1, self.upratio]).float()
129 | self.v = torch.arange(int(self.upratio)).reshape([1, self.upratio, 1]).float()
130 | self.upratio = float(upratio)
131 |
132 | def forward(self, plane_eq, focal):
133 | plane_eq_expanded = torch.repeat_interleave(plane_eq, int(self.upratio), 2)
134 | plane_eq_expanded = torch.repeat_interleave(plane_eq_expanded, int(self.upratio), 3)
135 | n1 = plane_eq_expanded[:, 0, :, :]
136 | n2 = plane_eq_expanded[:, 1, :, :]
137 | n3 = plane_eq_expanded[:, 2, :, :]
138 | n4 = plane_eq_expanded[:, 3, :, :]
139 |
140 | u = self.u.repeat(plane_eq.size(0), plane_eq.size(2) * int(self.upratio), plane_eq.size(3)).cuda()
141 | u = (u - (self.upratio - 1) * 0.5) / self.upratio
142 |
143 | v = self.v.repeat(plane_eq.size(0), plane_eq.size(2), plane_eq.size(3) * int(self.upratio)).cuda()
144 | v = (v - (self.upratio - 1) * 0.5) / self.upratio
145 |
146 | return n4 / (n1 * u + n2 * v + n3)
147 |
148 | class bts(nn.Module):
149 | def __init__(self, params, feat_out_channels, num_features=512):
150 | super(bts, self).__init__()
151 | self.params = params
152 |
153 | self.upconv5 = upconv(feat_out_channels[4], num_features)
154 | self.bn5 = nn.BatchNorm2d(num_features, momentum=0.01, affine=True, eps=1.1e-5)
155 |
156 | self.conv5 = torch.nn.Sequential(nn.Conv2d(num_features + feat_out_channels[3], num_features, 3, 1, 1, bias=False),
157 | nn.ELU())
158 | self.upconv4 = upconv(num_features, num_features // 2)
159 | self.bn4 = nn.BatchNorm2d(num_features // 2, momentum=0.01, affine=True, eps=1.1e-5)
160 | self.conv4 = torch.nn.Sequential(nn.Conv2d(num_features // 2 + feat_out_channels[2], num_features // 2, 3, 1, 1, bias=False),
161 | nn.ELU())
162 | self.bn4_2 = nn.BatchNorm2d(num_features // 2, momentum=0.01, affine=True, eps=1.1e-5)
163 |
164 | self.daspp_3 = atrous_conv(num_features // 2, num_features // 4, 3, apply_bn_first=False)
165 | self.daspp_6 = atrous_conv(num_features // 2 + num_features // 4 + feat_out_channels[2], num_features // 4, 6)
166 | self.daspp_12 = atrous_conv(num_features + feat_out_channels[2], num_features // 4, 12)
167 | self.daspp_18 = atrous_conv(num_features + num_features // 4 + feat_out_channels[2], num_features // 4, 18)
168 | self.daspp_24 = atrous_conv(num_features + num_features // 2 + feat_out_channels[2], num_features // 4, 24)
169 | self.daspp_conv = torch.nn.Sequential(nn.Conv2d(num_features + num_features // 2 + num_features // 4, num_features // 4, 3, 1, 1, bias=False),
170 | nn.ELU())
171 | self.reduc8x8 = reduction_1x1(num_features // 4, num_features // 4, self.params.max_depth)
172 | self.lpg8x8 = local_planar_guidance(8)
173 |
174 | self.upconv3 = upconv(num_features // 4, num_features // 4)
175 | self.bn3 = nn.BatchNorm2d(num_features // 4, momentum=0.01, affine=True, eps=1.1e-5)
176 | self.conv3 = torch.nn.Sequential(nn.Conv2d(num_features // 4 + feat_out_channels[1] + 1, num_features // 4, 3, 1, 1, bias=False),
177 | nn.ELU())
178 | self.reduc4x4 = reduction_1x1(num_features // 4, num_features // 8, self.params.max_depth)
179 | self.lpg4x4 = local_planar_guidance(4)
180 |
181 | self.upconv2 = upconv(num_features // 4, num_features // 8)
182 | self.bn2 = nn.BatchNorm2d(num_features // 8, momentum=0.01, affine=True, eps=1.1e-5)
183 | self.conv2 = torch.nn.Sequential(nn.Conv2d(num_features // 8 + feat_out_channels[0] + 1, num_features // 8, 3, 1, 1, bias=False),
184 | nn.ELU())
185 |
186 | self.reduc2x2 = reduction_1x1(num_features // 8, num_features // 16, self.params.max_depth)
187 | self.lpg2x2 = local_planar_guidance(2)
188 |
189 | self.upconv1 = upconv(num_features // 8, num_features // 16)
190 | self.reduc1x1 = reduction_1x1(num_features // 16, num_features // 32, self.params.max_depth, is_final=True)
191 | self.conv1 = torch.nn.Sequential(nn.Conv2d(num_features // 16 + 4, num_features // 16, 3, 1, 1, bias=False),
192 | nn.ELU())
193 | self.get_depth = torch.nn.Sequential(nn.Conv2d(num_features // 16, 1, 3, 1, 1, bias=False),
194 | nn.Sigmoid())
195 |
196 | def forward(self, features, focal):
197 | skip0, skip1, skip2, skip3 = features[0], features[1], features[2], features[3]
198 | dense_features = torch.nn.ReLU()(features[4])
199 | upconv5 = self.upconv5(dense_features) # H/16
200 | upconv5 = self.bn5(upconv5)
201 | concat5 = torch.cat([upconv5, skip3], dim=1)
202 | iconv5 = self.conv5(concat5)
203 |
204 | upconv4 = self.upconv4(iconv5) # H/8
205 | upconv4 = self.bn4(upconv4)
206 | concat4 = torch.cat([upconv4, skip2], dim=1)
207 | iconv4 = self.conv4(concat4)
208 | iconv4 = self.bn4_2(iconv4)
209 |
210 | daspp_3 = self.daspp_3(iconv4)
211 | concat4_2 = torch.cat([concat4, daspp_3], dim=1)
212 | daspp_6 = self.daspp_6(concat4_2)
213 | concat4_3 = torch.cat([concat4_2, daspp_6], dim=1)
214 | daspp_12 = self.daspp_12(concat4_3)
215 | concat4_4 = torch.cat([concat4_3, daspp_12], dim=1)
216 | daspp_18 = self.daspp_18(concat4_4)
217 | concat4_5 = torch.cat([concat4_4, daspp_18], dim=1)
218 | daspp_24 = self.daspp_24(concat4_5)
219 | concat4_daspp = torch.cat([iconv4, daspp_3, daspp_6, daspp_12, daspp_18, daspp_24], dim=1)
220 | daspp_feat = self.daspp_conv(concat4_daspp)
221 |
222 | reduc8x8 = self.reduc8x8(daspp_feat)
223 | plane_normal_8x8 = reduc8x8[:, :3, :, :]
224 | plane_normal_8x8 = torch_nn_func.normalize(plane_normal_8x8, 2, 1)
225 | plane_dist_8x8 = reduc8x8[:, 3, :, :]
226 | plane_eq_8x8 = torch.cat([plane_normal_8x8, plane_dist_8x8.unsqueeze(1)], 1)
227 | depth_8x8 = self.lpg8x8(plane_eq_8x8, focal)
228 | depth_8x8_scaled = depth_8x8.unsqueeze(1) / self.params.max_depth
229 | depth_8x8_scaled_ds = torch_nn_func.interpolate(depth_8x8_scaled, scale_factor=0.25, mode='nearest')
230 |
231 | upconv3 = self.upconv3(daspp_feat) # H/4
232 | upconv3 = self.bn3(upconv3)
233 | concat3 = torch.cat([upconv3, skip1, depth_8x8_scaled_ds], dim=1)
234 | iconv3 = self.conv3(concat3)
235 |
236 | reduc4x4 = self.reduc4x4(iconv3)
237 | plane_normal_4x4 = reduc4x4[:, :3, :, :]
238 | plane_normal_4x4 = torch_nn_func.normalize(plane_normal_4x4, 2, 1)
239 | plane_dist_4x4 = reduc4x4[:, 3, :, :]
240 | plane_eq_4x4 = torch.cat([plane_normal_4x4, plane_dist_4x4.unsqueeze(1)], 1)
241 | depth_4x4 = self.lpg4x4(plane_eq_4x4, focal)
242 | depth_4x4_scaled = depth_4x4.unsqueeze(1) / self.params.max_depth
243 | depth_4x4_scaled_ds = torch_nn_func.interpolate(depth_4x4_scaled, scale_factor=0.5, mode='nearest')
244 |
245 | upconv2 = self.upconv2(iconv3) # H/2
246 | upconv2 = self.bn2(upconv2)
247 | concat2 = torch.cat([upconv2, skip0, depth_4x4_scaled_ds], dim=1)
248 | iconv2 = self.conv2(concat2)
249 |
250 | reduc2x2 = self.reduc2x2(iconv2)
251 | plane_normal_2x2 = reduc2x2[:, :3, :, :]
252 | plane_normal_2x2 = torch_nn_func.normalize(plane_normal_2x2, 2, 1)
253 | plane_dist_2x2 = reduc2x2[:, 3, :, :]
254 | plane_eq_2x2 = torch.cat([plane_normal_2x2, plane_dist_2x2.unsqueeze(1)], 1)
255 | depth_2x2 = self.lpg2x2(plane_eq_2x2, focal)
256 | depth_2x2_scaled = depth_2x2.unsqueeze(1) / self.params.max_depth
257 |
258 | upconv1 = self.upconv1(iconv2)
259 | reduc1x1 = self.reduc1x1(upconv1)
260 | concat1 = torch.cat([upconv1, reduc1x1, depth_2x2_scaled, depth_4x4_scaled, depth_8x8_scaled], dim=1)
261 | iconv1 = self.conv1(concat1)
262 | final_depth = self.params.max_depth * self.get_depth(iconv1)
263 | if self.params.dataset == 'kitti':
264 | final_depth = final_depth * focal.view(-1, 1, 1, 1).float() / 715.0873
265 |
266 | return depth_8x8_scaled, depth_4x4_scaled, depth_2x2_scaled, reduc1x1, final_depth
267 |
268 | class encoder(nn.Module):
269 | def __init__(self, params):
270 | super(encoder, self).__init__()
271 | self.params = params
272 | import torchvision.models as models
273 | if params.encoder == 'densenet121_bts':
274 | self.base_model = models.densenet121(pretrained=True).features
275 | self.feat_names = ['relu0', 'pool0', 'transition1', 'transition2', 'norm5']
276 | self.feat_out_channels = [64, 64, 128, 256, 1024]
277 | elif params.encoder == 'densenet161_bts':
278 | self.base_model = models.densenet161(pretrained=True).features
279 | self.feat_names = ['relu0', 'pool0', 'transition1', 'transition2', 'norm5']
280 | self.feat_out_channels = [96, 96, 192, 384, 2208]
281 | elif params.encoder == 'resnet50_bts':
282 | self.base_model = models.resnet50(pretrained=True)
283 | self.feat_names = ['relu', 'layer1', 'layer2', 'layer3', 'layer4']
284 | self.feat_out_channels = [64, 256, 512, 1024, 2048]
285 | elif params.encoder == 'resnet101_bts':
286 | self.base_model = models.resnet101(pretrained=True)
287 | self.feat_names = ['relu', 'layer1', 'layer2', 'layer3', 'layer4']
288 | self.feat_out_channels = [64, 256, 512, 1024, 2048]
289 | elif params.encoder == 'resnext50_bts':
290 | self.base_model = models.resnext50_32x4d(pretrained=True)
291 | self.feat_names = ['relu', 'layer1', 'layer2', 'layer3', 'layer4']
292 | self.feat_out_channels = [64, 256, 512, 1024, 2048]
293 | elif params.encoder == 'resnext101_bts':
294 | self.base_model = models.resnext101_32x8d(pretrained=True)
295 | self.feat_names = ['relu', 'layer1', 'layer2', 'layer3', 'layer4']
296 | self.feat_out_channels = [64, 256, 512, 1024, 2048]
297 | elif params.encoder == 'mobilenetv2_bts':
298 | self.base_model = models.mobilenet_v2(pretrained=True).features
299 | self.feat_inds = [2, 4, 7, 11, 19]
300 | self.feat_out_channels = [16, 24, 32, 64, 1280]
301 | self.feat_names = []
302 | else:
303 | print('Not supported encoder: {}'.format(params.encoder))
304 |
305 | def forward(self, x):
306 | feature = x
307 | skip_feat = []
308 | i = 1
309 | for k, v in self.base_model._modules.items():
310 | if 'fc' in k or 'avgpool' in k:
311 | continue
312 | feature = v(feature)
313 | if self.params.encoder == 'mobilenetv2_bts':
314 | if i == 2 or i == 4 or i == 7 or i == 11 or i == 19:
315 | skip_feat.append(feature)
316 | else:
317 | if any(x in k for x in self.feat_names):
318 | skip_feat.append(feature)
319 | i = i + 1
320 | return skip_feat
321 |
322 |
323 | class BtsModel(nn.Module):
324 | def __init__(self, params):
325 | super(BtsModel, self).__init__()
326 | self.encoder = encoder(params)
327 | self.decoder = bts(params, self.encoder.feat_out_channels, params.bts_size)
328 |
329 | def forward(self, x, focal):
330 | skip_feat = self.encoder(x)
331 | return self.decoder(skip_feat, focal)
332 |
--------------------------------------------------------------------------------
/pytorch/bts_dataloader.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2019 Jin Han Lee
2 | #
3 | # This file is a part of BTS.
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU General Public License
15 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 |
17 | import numpy as np
18 | import torch
19 | from torch.utils.data import Dataset, DataLoader
20 | import torch.utils.data.distributed
21 | from torchvision import transforms
22 | from PIL import Image
23 | import os
24 | import random
25 |
26 | from distributed_sampler_no_evenly_divisible import *
27 |
28 |
29 | def _is_pil_image(img):
30 | return isinstance(img, Image.Image)
31 |
32 |
33 | def _is_numpy_image(img):
34 | return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
35 |
36 |
37 | def preprocessing_transforms(mode):
38 | return transforms.Compose([
39 | ToTensor(mode=mode)
40 | ])
41 |
42 |
43 | class BtsDataLoader(object):
44 | def __init__(self, args, mode):
45 | if mode == 'train':
46 | self.training_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode))
47 | if args.distributed:
48 | self.train_sampler = torch.utils.data.distributed.DistributedSampler(self.training_samples)
49 | else:
50 | self.train_sampler = None
51 |
52 | self.data = DataLoader(self.training_samples, args.batch_size,
53 | shuffle=(self.train_sampler is None),
54 | num_workers=args.num_threads,
55 | pin_memory=True,
56 | sampler=self.train_sampler)
57 |
58 | elif mode == 'online_eval':
59 | self.testing_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode))
60 | if args.distributed:
61 | # self.eval_sampler = torch.utils.data.distributed.DistributedSampler(self.testing_samples, shuffle=False)
62 | self.eval_sampler = DistributedSamplerNoEvenlyDivisible(self.testing_samples, shuffle=False)
63 | else:
64 | self.eval_sampler = None
65 | self.data = DataLoader(self.testing_samples, 1,
66 | shuffle=False,
67 | num_workers=1,
68 | pin_memory=True,
69 | sampler=self.eval_sampler)
70 |
71 | elif mode == 'test':
72 | self.testing_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode))
73 | self.data = DataLoader(self.testing_samples, 1, shuffle=False, num_workers=1)
74 |
75 | else:
76 | print('mode should be one of \'train, test, online_eval\'. Got {}'.format(mode))
77 |
78 |
79 | class DataLoadPreprocess(Dataset):
80 | def __init__(self, args, mode, transform=None, is_for_online_eval=False):
81 | self.args = args
82 | if mode == 'online_eval':
83 | with open(args.filenames_file_eval, 'r') as f:
84 | self.filenames = f.readlines()
85 | else:
86 | with open(args.filenames_file, 'r') as f:
87 | self.filenames = f.readlines()
88 |
89 | self.mode = mode
90 | self.transform = transform
91 | self.to_tensor = ToTensor
92 | self.is_for_online_eval = is_for_online_eval
93 |
94 | def __getitem__(self, idx):
95 | sample_path = self.filenames[idx]
96 | focal = float(sample_path.split()[2])
97 |
98 | if self.mode == 'train':
99 | if self.args.dataset == 'kitti' and self.args.use_right is True and random.random() > 0.5:
100 | image_path = os.path.join(self.args.data_path, "./" + sample_path.split()[3])
101 | depth_path = os.path.join(self.args.gt_path, "./" + sample_path.split()[4])
102 | else:
103 | image_path = os.path.join(self.args.data_path, "./" + sample_path.split()[0])
104 | depth_path = os.path.join(self.args.gt_path, "./" + sample_path.split()[1])
105 |
106 | image = Image.open(image_path)
107 | depth_gt = Image.open(depth_path)
108 |
109 | if self.args.do_kb_crop is True:
110 | height = image.height
111 | width = image.width
112 | top_margin = int(height - 352)
113 | left_margin = int((width - 1216) / 2)
114 | depth_gt = depth_gt.crop((left_margin, top_margin, left_margin + 1216, top_margin + 352))
115 | image = image.crop((left_margin, top_margin, left_margin + 1216, top_margin + 352))
116 |
117 | # To avoid blank boundaries due to pixel registration
118 | if self.args.dataset == 'nyu':
119 | depth_gt = depth_gt.crop((43, 45, 608, 472))
120 | image = image.crop((43, 45, 608, 472))
121 |
122 | if self.args.do_random_rotate is True:
123 | random_angle = (random.random() - 0.5) * 2 * self.args.degree
124 | image = self.rotate_image(image, random_angle)
125 | depth_gt = self.rotate_image(depth_gt, random_angle, flag=Image.NEAREST)
126 |
127 | image = np.asarray(image, dtype=np.float32) / 255.0
128 | depth_gt = np.asarray(depth_gt, dtype=np.float32)
129 | depth_gt = np.expand_dims(depth_gt, axis=2)
130 |
131 | if self.args.dataset == 'nyu':
132 | depth_gt = depth_gt / 1000.0
133 | else:
134 | depth_gt = depth_gt / 256.0
135 |
136 | image, depth_gt = self.random_crop(image, depth_gt, self.args.input_height, self.args.input_width)
137 | image, depth_gt = self.train_preprocess(image, depth_gt)
138 | sample = {'image': image, 'depth': depth_gt, 'focal': focal}
139 |
140 | else:
141 | if self.mode == 'online_eval':
142 | data_path = self.args.data_path_eval
143 | else:
144 | data_path = self.args.data_path
145 |
146 | image_path = os.path.join(data_path, "./" + sample_path.split()[0])
147 | image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
148 |
149 | if self.mode == 'online_eval':
150 | gt_path = self.args.gt_path_eval
151 | depth_path = os.path.join(gt_path, "./" + sample_path.split()[1])
152 | has_valid_depth = False
153 | try:
154 | depth_gt = Image.open(depth_path)
155 | has_valid_depth = True
156 | except IOError:
157 | depth_gt = False
158 | # print('Missing gt for {}'.format(image_path))
159 |
160 | if has_valid_depth:
161 | depth_gt = np.asarray(depth_gt, dtype=np.float32)
162 | depth_gt = np.expand_dims(depth_gt, axis=2)
163 | if self.args.dataset == 'nyu':
164 | depth_gt = depth_gt / 1000.0
165 | else:
166 | depth_gt = depth_gt / 256.0
167 |
168 | if self.args.do_kb_crop is True:
169 | height = image.shape[0]
170 | width = image.shape[1]
171 | top_margin = int(height - 352)
172 | left_margin = int((width - 1216) / 2)
173 | image = image[top_margin:top_margin + 352, left_margin:left_margin + 1216, :]
174 | if self.mode == 'online_eval' and has_valid_depth:
175 | depth_gt = depth_gt[top_margin:top_margin + 352, left_margin:left_margin + 1216, :]
176 |
177 | if self.mode == 'online_eval':
178 | sample = {'image': image, 'depth': depth_gt, 'focal': focal, 'has_valid_depth': has_valid_depth}
179 | else:
180 | sample = {'image': image, 'focal': focal}
181 |
182 | if self.transform:
183 | sample = self.transform(sample)
184 |
185 | return sample
186 |
187 | def rotate_image(self, image, angle, flag=Image.BILINEAR):
188 | result = image.rotate(angle, resample=flag)
189 | return result
190 |
191 | def random_crop(self, img, depth, height, width):
192 | assert img.shape[0] >= height
193 | assert img.shape[1] >= width
194 | assert img.shape[0] == depth.shape[0]
195 | assert img.shape[1] == depth.shape[1]
196 | x = random.randint(0, img.shape[1] - width)
197 | y = random.randint(0, img.shape[0] - height)
198 | img = img[y:y + height, x:x + width, :]
199 | depth = depth[y:y + height, x:x + width, :]
200 | return img, depth
201 |
202 | def train_preprocess(self, image, depth_gt):
203 | # Random flipping
204 | do_flip = random.random()
205 | if do_flip > 0.5:
206 | image = (image[:, ::-1, :]).copy()
207 | depth_gt = (depth_gt[:, ::-1, :]).copy()
208 |
209 | # Random gamma, brightness, color augmentation
210 | do_augment = random.random()
211 | if do_augment > 0.5:
212 | image = self.augment_image(image)
213 |
214 | return image, depth_gt
215 |
216 | def augment_image(self, image):
217 | # gamma augmentation
218 | gamma = random.uniform(0.9, 1.1)
219 | image_aug = image ** gamma
220 |
221 | # brightness augmentation
222 | if self.args.dataset == 'nyu':
223 | brightness = random.uniform(0.75, 1.25)
224 | else:
225 | brightness = random.uniform(0.9, 1.1)
226 | image_aug = image_aug * brightness
227 |
228 | # color augmentation
229 | colors = np.random.uniform(0.9, 1.1, size=3)
230 | white = np.ones((image.shape[0], image.shape[1]))
231 | color_image = np.stack([white * colors[i] for i in range(3)], axis=2)
232 | image_aug *= color_image
233 | image_aug = np.clip(image_aug, 0, 1)
234 |
235 | return image_aug
236 |
237 | def __len__(self):
238 | return len(self.filenames)
239 |
240 |
241 | class ToTensor(object):
242 | def __init__(self, mode):
243 | self.mode = mode
244 | self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
245 |
246 | def __call__(self, sample):
247 | image, focal = sample['image'], sample['focal']
248 | image = self.to_tensor(image)
249 | image = self.normalize(image)
250 |
251 | if self.mode == 'test':
252 | return {'image': image, 'focal': focal}
253 |
254 | depth = sample['depth']
255 | if self.mode == 'train':
256 | depth = self.to_tensor(depth)
257 | return {'image': image, 'depth': depth, 'focal': focal}
258 | else:
259 | has_valid_depth = sample['has_valid_depth']
260 | return {'image': image, 'depth': depth, 'focal': focal, 'has_valid_depth': has_valid_depth}
261 |
262 | def to_tensor(self, pic):
263 | if not (_is_pil_image(pic) or _is_numpy_image(pic)):
264 | raise TypeError(
265 | 'pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
266 |
267 | if isinstance(pic, np.ndarray):
268 | img = torch.from_numpy(pic.transpose((2, 0, 1)))
269 | return img
270 |
271 | # handle PIL Image
272 | if pic.mode == 'I':
273 | img = torch.from_numpy(np.array(pic, np.int32, copy=False))
274 | elif pic.mode == 'I;16':
275 | img = torch.from_numpy(np.array(pic, np.int16, copy=False))
276 | else:
277 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
278 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
279 | if pic.mode == 'YCbCr':
280 | nchannel = 3
281 | elif pic.mode == 'I;16':
282 | nchannel = 1
283 | else:
284 | nchannel = len(pic.mode)
285 | img = img.view(pic.size[1], pic.size[0], nchannel)
286 |
287 | img = img.transpose(0, 1).transpose(0, 2).contiguous()
288 | if isinstance(img, torch.ByteTensor):
289 | return img.float()
290 | else:
291 | return img
292 |
--------------------------------------------------------------------------------
/pytorch/bts_eval.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2019 Jin Han Lee
2 | #
3 | # This file is a part of BTS.
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU General Public License
15 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 |
17 | from __future__ import absolute_import, division, print_function
18 |
19 | import os
20 | import argparse
21 | import time
22 | import numpy as np
23 | import cv2
24 | import sys
25 |
26 | import torch
27 | import torch.nn as nn
28 | import torch.nn.utils as utils
29 | import torchvision.utils as vutils
30 | import torch.backends.cudnn as cudnn
31 | from torch.autograd import Variable
32 | from tensorboardX import SummaryWriter
33 | from bts_dataloader import *
34 |
35 | def convert_arg_line_to_args(arg_line):
36 | for arg in arg_line.split():
37 | if not arg.strip():
38 | continue
39 | yield arg
40 |
41 |
42 | parser = argparse.ArgumentParser(description='BTS PyTorch implementation.', fromfile_prefix_chars='@')
43 | parser.convert_arg_line_to_args = convert_arg_line_to_args
44 |
45 | parser.add_argument('--model_name', type=str, help='model name', default='bts_v0_0_1')
46 | parser.add_argument('--encoder', type=str, help='type of encoder, densenet121_bts or densenet161_bts',
47 | default='densenet161_bts')
48 | parser.add_argument('--data_path', type=str, help='path to the data', required=True)
49 | parser.add_argument('--gt_path', type=str, help='path to the groundtruth data', required=False)
50 | parser.add_argument('--filenames_file', type=str, help='path to the filenames text file', required=True)
51 | parser.add_argument('--input_height', type=int, help='input height', default=480)
52 | parser.add_argument('--input_width', type=int, help='input width', default=640)
53 | parser.add_argument('--max_depth', type=float, help='maximum depth in estimation', default=80)
54 | parser.add_argument('--output_directory', type=str,
55 | help='output directory for summary, if empty outputs to checkpoint folder', default='')
56 | parser.add_argument('--checkpoint_path', type=str, help='path to a specific checkpoint to load', default='')
57 | parser.add_argument('--dataset', type=str, help='dataset to evaluate on, kitti or nyu', default='nyu')
58 | parser.add_argument('--eigen_crop', help='if set, crops according to Eigen NIPS14', action='store_true')
59 | parser.add_argument('--garg_crop', help='if set, crops according to Garg ECCV16', action='store_true')
60 |
61 | parser.add_argument('--min_depth_eval', type=float, help='minimum depth for evaluation', default=1e-3)
62 | parser.add_argument('--max_depth_eval', type=float, help='maximum depth for evaluation', default=80)
63 | parser.add_argument('--do_kb_crop', help='if set, crop input images as kitti benchmark images', action='store_true')
64 | parser.add_argument('--bts_size', type=int, help='initial num_filters in bts', default=512)
65 |
66 | if sys.argv.__len__() == 2:
67 | arg_filename_with_prefix = '@' + sys.argv[1]
68 | args = parser.parse_args([arg_filename_with_prefix])
69 | else:
70 | args = parser.parse_args()
71 |
72 | model_dir = os.path.dirname(args.checkpoint_path)
73 | sys.path.append(model_dir)
74 |
75 | for key, val in vars(__import__(args.model_name)).items():
76 | if key.startswith('__') and key.endswith('__'):
77 | continue
78 | vars()[key] = val
79 |
80 |
81 | def compute_errors(gt, pred):
82 | thresh = np.maximum((gt / pred), (pred / gt))
83 | d1 = (thresh < 1.25).mean()
84 | d2 = (thresh < 1.25 ** 2).mean()
85 | d3 = (thresh < 1.25 ** 3).mean()
86 |
87 | rmse = (gt - pred) ** 2
88 | rmse = np.sqrt(rmse.mean())
89 |
90 | rmse_log = (np.log(gt) - np.log(pred)) ** 2
91 | rmse_log = np.sqrt(rmse_log.mean())
92 |
93 | abs_rel = np.mean(np.abs(gt - pred) / gt)
94 | sq_rel = np.mean(((gt - pred) ** 2) / gt)
95 |
96 | err = np.log(pred) - np.log(gt)
97 | silog = np.sqrt(np.mean(err ** 2) - np.mean(err) ** 2) * 100
98 |
99 | err = np.abs(np.log10(pred) - np.log10(gt))
100 | log10 = np.mean(err)
101 |
102 | return silog, log10, abs_rel, sq_rel, rmse, rmse_log, d1, d2, d3
103 |
104 |
105 | def get_num_lines(file_path):
106 | f = open(file_path, 'r')
107 | lines = f.readlines()
108 | f.close()
109 | return len(lines)
110 |
111 |
112 | def test(params):
113 | global gt_depths, is_missing, missing_ids
114 | gt_depths = []
115 | is_missing = []
116 | missing_ids = set()
117 | write_summary = False
118 | steps = set()
119 |
120 | if os.path.isdir(args.checkpoint_path):
121 | import glob
122 | models = [f for f in glob.glob(args.checkpoint_path + "/model*")]
123 |
124 | for model in models:
125 | step = model.split('-')[-1]
126 | steps.add('{:06d}'.format(int(step)))
127 |
128 | lines = []
129 | if os.path.exists(args.checkpoint_path + '/evaluated_checkpoints'):
130 | with open(args.checkpoint_path + '/evaluated_checkpoints') as file:
131 | lines = file.readlines()
132 |
133 | for line in lines:
134 | if line.rstrip() in steps:
135 | steps.remove(line.rstrip())
136 |
137 | steps = sorted(steps)
138 | if args.output_directory != '':
139 | summary_path = os.path.join(args.output_directory, args.model_name)
140 | else:
141 | summary_path = os.path.join(args.checkpoint_path, 'eval')
142 |
143 | write_summary = True
144 | else:
145 | steps.add('{:06d}'.format(int(args.checkpoint_path.split('-')[-1])))
146 |
147 |
148 | if len(steps) == 0:
149 | print('No new model to evaluate. Abort.')
150 | return
151 |
152 | args.mode = 'test'
153 | dataloader = BtsDataLoader(args, 'test')
154 |
155 | model = BtsModel(params=params)
156 | model = torch.nn.DataParallel(model)
157 |
158 | cudnn.benchmark = True
159 |
160 | if write_summary:
161 | summary_writer = SummaryWriter(summary_path, flush_secs=30)
162 |
163 | for step in steps:
164 | if os.path.isdir(args.checkpoint_path):
165 | checkpoint = torch.load(os.path.join(args.checkpoint_path, 'model-' + str(int(step))))
166 | model.load_state_dict(checkpoint['model'])
167 | else:
168 | checkpoint = torch.load(args.checkpoint_path)
169 | model.load_state_dict(checkpoint['model'])
170 |
171 | model.eval()
172 | model.cuda()
173 |
174 | num_test_samples = get_num_lines(args.filenames_file)
175 |
176 | with open(args.filenames_file) as f:
177 | lines = f.readlines()
178 |
179 | print('now testing {} files for step {}'.format(num_test_samples, step))
180 |
181 | pred_depths = []
182 |
183 | start_time = time.time()
184 | with torch.no_grad():
185 | for _, sample in enumerate(dataloader.data):
186 | image = Variable(sample['image'].cuda())
187 | focal = Variable(sample['focal'].cuda())
188 | # image = Variable(sample['image'])
189 | # focal = Variable(sample['focal'])
190 | # Predict
191 | lpg8x8, lpg4x4, lpg2x2, reduc1x1, depth_est = model(image, focal)
192 | pred_depths.append(depth_est.cpu().numpy().squeeze())
193 |
194 | elapsed_time = time.time() - start_time
195 |         print('Elapsed time: %s' % str(elapsed_time))
196 | print('Done.')
197 |
198 | if len(gt_depths) == 0:
199 | for t_id in range(num_test_samples):
200 | gt_depth_path = os.path.join(args.gt_path, lines[t_id].split()[1])
201 | depth = cv2.imread(gt_depth_path, -1)
202 | if depth is None:
203 | print('Missing: %s ' % gt_depth_path)
204 | missing_ids.add(t_id)
205 | continue
206 |
207 | if args.dataset == 'nyu':
208 | depth = depth.astype(np.float32) / 1000.0
209 | else:
210 | depth = depth.astype(np.float32) / 256.0
211 |
212 | gt_depths.append(depth)
213 |
214 | print('Computing errors')
215 | silog, log10, abs_rel, sq_rel, rms, log_rms, d1, d2, d3 = eval(pred_depths, int(step))
216 |
217 | if write_summary:
218 | summary_writer.add_scalar('silog', silog.mean(), int(step))
219 | summary_writer.add_scalar('abs_rel', abs_rel.mean(), int(step))
220 | summary_writer.add_scalar('log10', log10.mean(), int(step))
221 | summary_writer.add_scalar('sq_rel', sq_rel.mean(), int(step))
222 | summary_writer.add_scalar('rms', rms.mean(), int(step))
223 | summary_writer.add_scalar('log_rms', log_rms.mean(), int(step))
224 | summary_writer.add_scalar('d1', d1.mean(), int(step))
225 | summary_writer.add_scalar('d2', d2.mean(), int(step))
226 | summary_writer.add_scalar('d3', d3.mean(), int(step))
227 | summary_writer.flush()
228 |
229 | with open(os.path.dirname(args.checkpoint_path) + '/evaluated_checkpoints', 'a') as file:
230 | file.write(step + '\n')
231 |
232 | print('Evaluation done')
233 |
234 |
235 | def eval(pred_depths, step):
236 | num_samples = get_num_lines(args.filenames_file)
237 | pred_depths_valid = []
238 |
239 | for t_id in range(num_samples):
240 | if t_id in missing_ids:
241 | continue
242 |
243 | pred_depths_valid.append(pred_depths[t_id])
244 |
245 | num_samples = num_samples - len(missing_ids)
246 |
247 | silog = np.zeros(num_samples, np.float32)
248 | log10 = np.zeros(num_samples, np.float32)
249 | rms = np.zeros(num_samples, np.float32)
250 | log_rms = np.zeros(num_samples, np.float32)
251 | abs_rel = np.zeros(num_samples, np.float32)
252 | sq_rel = np.zeros(num_samples, np.float32)
253 | d1 = np.zeros(num_samples, np.float32)
254 | d2 = np.zeros(num_samples, np.float32)
255 | d3 = np.zeros(num_samples, np.float32)
256 |
257 | for i in range(num_samples):
258 | gt_depth = gt_depths[i]
259 | pred_depth = pred_depths_valid[i]
260 |
261 | if args.do_kb_crop:
262 | height, width = gt_depth.shape
263 | top_margin = int(height - 352)
264 | left_margin = int((width - 1216) / 2)
265 | pred_depth_uncropped = np.zeros((height, width), dtype=np.float32)
266 | pred_depth_uncropped[top_margin:top_margin + 352, left_margin:left_margin + 1216] = pred_depth
267 | pred_depth = pred_depth_uncropped
268 |
269 | pred_depth[pred_depth < args.min_depth_eval] = args.min_depth_eval
270 | pred_depth[pred_depth > args.max_depth_eval] = args.max_depth_eval
271 | pred_depth[np.isinf(pred_depth)] = args.max_depth_eval
272 | pred_depth[np.isnan(pred_depth)] = args.min_depth_eval
273 |
274 | valid_mask = np.logical_and(gt_depth > args.min_depth_eval, gt_depth < args.max_depth_eval)
275 |
276 | if args.garg_crop or args.eigen_crop:
277 | gt_height, gt_width = gt_depth.shape
278 | eval_mask = np.zeros(valid_mask.shape)
279 |
280 | if args.garg_crop:
281 | eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height), int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1
282 |
283 | elif args.eigen_crop:
284 | if args.dataset == 'kitti':
285 | eval_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height), int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = 1
286 | else:
287 | eval_mask[45:471, 41:601] = 1
288 |
289 | valid_mask = np.logical_and(valid_mask, eval_mask)
290 |
291 | silog[i], log10[i], abs_rel[i], sq_rel[i], rms[i], log_rms[i], d1[i], d2[i], d3[i] = compute_errors(
292 | gt_depth[valid_mask], pred_depth[valid_mask])
293 |
294 | print("{:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}".format('silog', 'abs_rel', 'log10', 'rms',
295 | 'sq_rel', 'log_rms', 'd1', 'd2', 'd3'))
296 | print("{:7.4f}, {:7.4f}, {:7.3f}, {:7.3f}, {:7.3f}, {:7.3f}, {:7.3f}, {:7.3f}, {:7.3f}".format(
297 | silog.mean(), abs_rel.mean(), log10.mean(), rms.mean(), sq_rel.mean(), log_rms.mean(), d1.mean(), d2.mean(),
298 | d3.mean()))
299 |
300 | return silog, log10, abs_rel, sq_rel, rms, log_rms, d1, d2, d3
301 |
302 |
303 | if __name__ == '__main__':
304 | test(args)
--------------------------------------------------------------------------------
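A minimal usage sketch of the metric pipeline above (illustrative values only; it assumes `compute_errors` from bts_eval.py is already in scope, e.g. pasted into the same interpreter session): predictions are clamped to the evaluation depth range, invalid ground-truth pixels are masked out, and the remaining pixels are passed to `compute_errors`, mirroring what `eval()` does per image.
```python
import numpy as np

# Toy ground truth and prediction; a zero in gt marks a missing depth reading.
gt = np.array([2.0, 4.0, 8.0, 0.0], dtype=np.float32)
pred = np.array([2.2, 3.5, 9.0, 1.0], dtype=np.float32)

min_depth_eval, max_depth_eval = 1e-3, 80.0
pred = np.clip(pred, min_depth_eval, max_depth_eval)               # same clamping as eval()
valid_mask = np.logical_and(gt > min_depth_eval, gt < max_depth_eval)

# compute_errors is the function defined in bts_eval.py above.
silog, log10, abs_rel, sq_rel, rmse, rmse_log, d1, d2, d3 = compute_errors(gt[valid_mask], pred[valid_mask])
print('abs_rel=%.3f rmse=%.3f d1=%.3f' % (abs_rel, rmse, d1))
```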
/pytorch/bts_test.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2019 Jin Han Lee
2 | #
3 | # This file is a part of BTS.
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU General Public License
15 | # along with this program. If not, see
16 |
17 | from __future__ import absolute_import, division, print_function
18 |
19 | import os
20 | import argparse
21 | import time
22 | import numpy as np
23 | import cv2
24 | import sys
25 |
26 | import torch
27 | import torch.nn as nn
28 | from torch.autograd import Variable
29 | from bts_dataloader import *
30 |
31 | import errno
32 | import matplotlib.pyplot as plt
33 | from tqdm import tqdm
34 |
35 | from bts_dataloader import *
36 |
37 |
38 | def convert_arg_line_to_args(arg_line):
39 | for arg in arg_line.split():
40 | if not arg.strip():
41 | continue
42 | yield arg
43 |
44 |
45 | parser = argparse.ArgumentParser(description='BTS PyTorch implementation.', fromfile_prefix_chars='@')
46 | parser.convert_arg_line_to_args = convert_arg_line_to_args
47 |
48 | parser.add_argument('--model_name', type=str, help='model name', default='bts_nyu_v2')
49 | parser.add_argument('--encoder', type=str, help='type of encoder, vgg or densenet121_bts or densenet161_bts',
50 | default='densenet161_bts')
51 | parser.add_argument('--data_path', type=str, help='path to the data', required=True)
52 | parser.add_argument('--filenames_file', type=str, help='path to the filenames text file', required=True)
53 | parser.add_argument('--input_height', type=int, help='input height', default=480)
54 | parser.add_argument('--input_width', type=int, help='input width', default=640)
55 | parser.add_argument('--max_depth', type=float, help='maximum depth in estimation', default=80)
56 | parser.add_argument('--checkpoint_path', type=str, help='path to a specific checkpoint to load', default='')
57 | parser.add_argument('--dataset', type=str, help='dataset to test on, kitti or nyu', default='nyu')
58 | parser.add_argument('--do_kb_crop', help='if set, crop input images as kitti benchmark images', action='store_true')
59 | parser.add_argument('--save_lpg', help='if set, save outputs from lpg layers', action='store_true')
60 | parser.add_argument('--bts_size', type=int, help='initial num_filters in bts', default=512)
61 |
62 | if sys.argv.__len__() == 2:
63 | arg_filename_with_prefix = '@' + sys.argv[1]
64 | args = parser.parse_args([arg_filename_with_prefix])
65 | else:
66 | args = parser.parse_args()
67 |
68 | model_dir = os.path.dirname(args.checkpoint_path)
69 | sys.path.append(model_dir)
70 |
71 | for key, val in vars(__import__(args.model_name)).items():
72 | if key.startswith('__') and key.endswith('__'):
73 | continue
74 | vars()[key] = val
75 |
76 |
77 | def get_num_lines(file_path):
78 | f = open(file_path, 'r')
79 | lines = f.readlines()
80 | f.close()
81 | return len(lines)
82 |
83 |
84 | def test(params):
85 | """Test function."""
86 | args.mode = 'test'
87 | dataloader = BtsDataLoader(args, 'test')
88 |
89 | model = BtsModel(params=args)
90 | model = torch.nn.DataParallel(model)
91 |
92 | checkpoint = torch.load(args.checkpoint_path)
93 | model.load_state_dict(checkpoint['model'])
94 | model.eval()
95 | model.cuda()
96 |
97 | num_params = sum([np.prod(p.size()) for p in model.parameters()])
98 | print("Total number of parameters: {}".format(num_params))
99 |
100 | num_test_samples = get_num_lines(args.filenames_file)
101 |
102 | with open(args.filenames_file) as f:
103 | lines = f.readlines()
104 |
105 | print('now testing {} files with {}'.format(num_test_samples, args.checkpoint_path))
106 |
107 | pred_depths = []
108 | pred_8x8s = []
109 | pred_4x4s = []
110 | pred_2x2s = []
111 | pred_1x1s = []
112 |
113 | start_time = time.time()
114 | with torch.no_grad():
115 | for _, sample in enumerate(tqdm(dataloader.data)):
116 | image = Variable(sample['image'].cuda())
117 | focal = Variable(sample['focal'].cuda())
118 | # Predict
119 | lpg8x8, lpg4x4, lpg2x2, reduc1x1, depth_est = model(image, focal)
120 | pred_depths.append(depth_est.cpu().numpy().squeeze())
121 | pred_8x8s.append(lpg8x8[0].cpu().numpy().squeeze())
122 | pred_4x4s.append(lpg4x4[0].cpu().numpy().squeeze())
123 | pred_2x2s.append(lpg2x2[0].cpu().numpy().squeeze())
124 | pred_1x1s.append(reduc1x1[0].cpu().numpy().squeeze())
125 |
126 | elapsed_time = time.time() - start_time
127 |     print('Elapsed time: %s' % str(elapsed_time))
128 | print('Done.')
129 |
130 | save_name = 'result_' + args.model_name
131 |
132 | print('Saving result pngs..')
133 | if not os.path.exists(os.path.dirname(save_name)):
134 | try:
135 | os.mkdir(save_name)
136 | os.mkdir(save_name + '/raw')
137 | os.mkdir(save_name + '/cmap')
138 | os.mkdir(save_name + '/rgb')
139 | os.mkdir(save_name + '/gt')
140 | except OSError as e:
141 | if e.errno != errno.EEXIST:
142 | raise
143 |
144 | for s in tqdm(range(num_test_samples)):
145 | if args.dataset == 'kitti':
146 | date_drive = lines[s].split('/')[1]
147 | filename_pred_png = save_name + '/raw/' + date_drive + '_' + lines[s].split()[0].split('/')[-1].replace(
148 | '.jpg', '.png')
149 | filename_cmap_png = save_name + '/cmap/' + date_drive + '_' + lines[s].split()[0].split('/')[
150 | -1].replace('.jpg', '.png')
151 | filename_image_png = save_name + '/rgb/' + date_drive + '_' + lines[s].split()[0].split('/')[-1]
152 | elif args.dataset == 'kitti_benchmark':
153 | filename_pred_png = save_name + '/raw/' + lines[s].split()[0].split('/')[-1].replace('.jpg', '.png')
154 | filename_cmap_png = save_name + '/cmap/' + lines[s].split()[0].split('/')[-1].replace('.jpg', '.png')
155 | filename_image_png = save_name + '/rgb/' + lines[s].split()[0].split('/')[-1]
156 | else:
157 | scene_name = lines[s].split()[0].split('/')[0]
158 | filename_pred_png = save_name + '/raw/' + scene_name + '_' + lines[s].split()[0].split('/')[1].replace(
159 | '.jpg', '.png')
160 | filename_cmap_png = save_name + '/cmap/' + scene_name + '_' + lines[s].split()[0].split('/')[1].replace(
161 | '.jpg', '.png')
162 | filename_gt_png = save_name + '/gt/' + scene_name + '_' + lines[s].split()[0].split('/')[1].replace(
163 | '.jpg', '.png')
164 | filename_image_png = save_name + '/rgb/' + scene_name + '_' + lines[s].split()[0].split('/')[1]
165 |
166 | rgb_path = os.path.join(args.data_path, './' + lines[s].split()[0])
167 | image = cv2.imread(rgb_path)
168 | if args.dataset == 'nyu':
169 | gt_path = os.path.join(args.data_path, './' + lines[s].split()[1])
170 | gt = cv2.imread(gt_path, -1).astype(np.float32) / 1000.0 # Visualization purpose only
171 | gt[gt == 0] = np.amax(gt)
172 |
173 | pred_depth = pred_depths[s]
174 | pred_8x8 = pred_8x8s[s]
175 | pred_4x4 = pred_4x4s[s]
176 | pred_2x2 = pred_2x2s[s]
177 | pred_1x1 = pred_1x1s[s]
178 |
179 | if args.dataset == 'kitti' or args.dataset == 'kitti_benchmark':
180 | pred_depth_scaled = pred_depth * 256.0
181 | else:
182 | pred_depth_scaled = pred_depth * 1000.0
183 |
184 | pred_depth_scaled = pred_depth_scaled.astype(np.uint16)
185 | cv2.imwrite(filename_pred_png, pred_depth_scaled, [cv2.IMWRITE_PNG_COMPRESSION, 0])
186 |
187 | if args.save_lpg:
188 | cv2.imwrite(filename_image_png, image[10:-1 - 9, 10:-1 - 9, :])
189 | if args.dataset == 'nyu':
190 | plt.imsave(filename_gt_png, np.log10(gt[10:-1 - 9, 10:-1 - 9]), cmap='Greys')
191 | pred_depth_cropped = pred_depth[10:-1 - 9, 10:-1 - 9]
192 | plt.imsave(filename_cmap_png, np.log10(pred_depth_cropped), cmap='Greys')
193 | pred_8x8_cropped = pred_8x8[10:-1 - 9, 10:-1 - 9]
194 | filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_8x8.png')
195 | plt.imsave(filename_lpg_cmap_png, np.log10(pred_8x8_cropped), cmap='Greys')
196 | pred_4x4_cropped = pred_4x4[10:-1 - 9, 10:-1 - 9]
197 | filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_4x4.png')
198 | plt.imsave(filename_lpg_cmap_png, np.log10(pred_4x4_cropped), cmap='Greys')
199 | pred_2x2_cropped = pred_2x2[10:-1 - 9, 10:-1 - 9]
200 | filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_2x2.png')
201 | plt.imsave(filename_lpg_cmap_png, np.log10(pred_2x2_cropped), cmap='Greys')
202 | pred_1x1_cropped = pred_1x1[10:-1 - 9, 10:-1 - 9]
203 | filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_1x1.png')
204 | plt.imsave(filename_lpg_cmap_png, np.log10(pred_1x1_cropped), cmap='Greys')
205 | else:
206 | plt.imsave(filename_cmap_png, np.log10(pred_depth), cmap='Greys')
207 | filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_8x8.png')
208 | plt.imsave(filename_lpg_cmap_png, np.log10(pred_8x8), cmap='Greys')
209 | filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_4x4.png')
210 | plt.imsave(filename_lpg_cmap_png, np.log10(pred_4x4), cmap='Greys')
211 | filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_2x2.png')
212 | plt.imsave(filename_lpg_cmap_png, np.log10(pred_2x2), cmap='Greys')
213 | filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_1x1.png')
214 | plt.imsave(filename_lpg_cmap_png, np.log10(pred_1x1), cmap='Greys')
215 |
216 | return
217 |
218 |
219 | if __name__ == '__main__':
220 | test(args)
221 |
--------------------------------------------------------------------------------
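The 16-bit PNG convention used by bts_test.py above (KITTI predictions scaled by 256, NYU by 1000, before casting to uint16) can be sanity-checked with a small round-trip. This sketch is illustrative only; the file name is a placeholder.
```python
import numpy as np
import cv2

depth_m = np.full((4, 4), 3.25, dtype=np.float32)                    # depth in metres
cv2.imwrite('pred_example.png', (depth_m * 256.0).astype(np.uint16),
            [cv2.IMWRITE_PNG_COMPRESSION, 0])                        # KITTI-style encoding
decoded = cv2.imread('pred_example.png', -1).astype(np.float32) / 256.0
assert np.allclose(decoded, depth_m, atol=1.0 / 256.0)               # only quantization error remains
```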
/pytorch/distributed_sampler_no_evenly_divisible.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.utils.data import Sampler
4 | import torch.distributed as dist
5 |
6 |
7 | class DistributedSamplerNoEvenlyDivisible(Sampler):
8 | """Sampler that restricts data loading to a subset of the dataset.
9 |
10 | It is especially useful in conjunction with
11 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
12 | process can pass a DistributedSampler instance as a DataLoader sampler,
13 | and load a subset of the original dataset that is exclusive to it.
14 |
15 | .. note::
16 | Dataset is assumed to be of constant size.
17 |
18 | Arguments:
19 | dataset: Dataset used for sampling.
20 | num_replicas (optional): Number of processes participating in
21 | distributed training.
22 | rank (optional): Rank of the current process within num_replicas.
23 | shuffle (optional): If true (default), sampler will shuffle the indices
24 | """
25 |
26 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
27 | if num_replicas is None:
28 | if not dist.is_available():
29 | raise RuntimeError("Requires distributed package to be available")
30 | num_replicas = dist.get_world_size()
31 | if rank is None:
32 | if not dist.is_available():
33 | raise RuntimeError("Requires distributed package to be available")
34 | rank = dist.get_rank()
35 | self.dataset = dataset
36 | self.num_replicas = num_replicas
37 | self.rank = rank
38 | self.epoch = 0
39 | num_samples = int(math.floor(len(self.dataset) * 1.0 / self.num_replicas))
40 | rest = len(self.dataset) - num_samples * self.num_replicas
41 | if self.rank < rest:
42 | num_samples += 1
43 | self.num_samples = num_samples
44 | self.total_size = len(dataset)
45 | # self.total_size = self.num_samples * self.num_replicas
46 | self.shuffle = shuffle
47 |
48 | def __iter__(self):
49 | # deterministically shuffle based on epoch
50 | g = torch.Generator()
51 | g.manual_seed(self.epoch)
52 | if self.shuffle:
53 | indices = torch.randperm(len(self.dataset), generator=g).tolist()
54 | else:
55 | indices = list(range(len(self.dataset)))
56 |
57 | # add extra samples to make it evenly divisible
58 | # indices += indices[:(self.total_size - len(indices))]
59 | # assert len(indices) == self.total_size
60 |
61 | # subsample
62 | indices = indices[self.rank:self.total_size:self.num_replicas]
63 | self.num_samples = len(indices)
64 | # assert len(indices) == self.num_samples
65 |
66 | return iter(indices)
67 |
68 | def __len__(self):
69 | return self.num_samples
70 |
71 | def set_epoch(self, epoch):
72 | self.epoch = epoch
73 |
--------------------------------------------------------------------------------
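A minimal sketch of how DistributedSamplerNoEvenlyDivisible could be exercised on its own; passing num_replicas and rank explicitly is an assumption made here so the example runs without initializing torch.distributed.
```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from distributed_sampler_no_evenly_divisible import DistributedSamplerNoEvenlyDivisible

dataset = TensorDataset(torch.arange(10).float())
sampler = DistributedSamplerNoEvenlyDivisible(dataset, num_replicas=3, rank=0, shuffle=False)
loader = DataLoader(dataset, batch_size=1, sampler=sampler)

# Rank 0 receives indices 0, 3, 6, 9 (4 samples); ranks 1 and 2 receive 3 samples each,
# so no sample is duplicated to pad the dataset to an evenly divisible size.
print(len(sampler), [int(batch[0]) for batch in loader])
```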
/pytorch/run_bts_eval_schedule.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2019 Jin Han Lee
2 | #
3 | # This file is a part of BTS.
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU General Public License
15 | # along with this program. If not, see
16 |
17 | import os
18 | import datetime
19 | from apscheduler.schedulers.blocking import BlockingScheduler
20 | scheduler = BlockingScheduler()
21 |
22 | @scheduler.scheduled_job('interval', minutes=1, start_date=datetime.datetime.now() + datetime.timedelta(0,3))
23 | def run_eval():
24 | command = 'export CUDA_VISIBLE_DEVICES=0; ' \
25 | '/usr/bin/python ' \
26 | 'bts_eval.py ' \
27 | '--encoder densenet161_bts ' \
28 | '--dataset kitti ' \
29 | '--data_path ../../dataset/kitti_dataset/ ' \
30 | '--gt_path ../../dataset/kitti_dataset/data_depth_annotated/ ' \
31 | '--filenames_file ../train_test_inputs/eigen_test_files_with_gt.txt ' \
32 | '--input_height 352 ' \
33 | '--input_width 1216 ' \
34 | '--garg_crop ' \
35 | '--max_depth 80 ' \
36 | '--max_depth_eval 80 ' \
37 | '--output_directory ./models/eval-eigen/ ' \
38 | '--model_name bts_eigen_v0_0_1 ' \
39 | '--checkpoint_path ./models/bts_eigen_v0_0_1/ ' \
40 | '--do_kb_crop '
41 |
42 | print('Executing: %s' % command)
43 | os.system(command)
44 | print('Finished: %s' % datetime.datetime.now())
45 |
46 | scheduler.configure()
47 | scheduler.start()
--------------------------------------------------------------------------------
/pytorch/run_bts_live_3d.sh:
--------------------------------------------------------------------------------
1 | python3 bts_live_3d.py --model_name bts_nyu_v2_pytorch_densenet161 --encoder densenet161_bts --checkpoint_path ./models/bts_nyu_v2_pytorch_densenet161/model --max_depth 10 --input_height 480 --input_width 640
2 |
--------------------------------------------------------------------------------
/tensorflow/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM tensorflow/tensorflow:1.13.2-gpu-jupyter
2 |
3 | # libcuda.so.1 is not available by default so we add what are probably stubs.
4 | # See https://github.com/tensorflow/tensorflow/issues/25865
5 | # If we leave the stubs on the library path, we later get an error about CUDA
6 | # versions not matching, so we remove them at the end (see below).
7 | ENV LD_LIBRARY_PATH_OLD="${LD_LIBRARY_PATH}"
8 | ENV LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/cuda-10.0/compat"
9 |
10 | # Load everything we need to build the custom layer and stuff required by opencv.
11 | RUN apt-get update && apt-get install -y \
12 | build-essential \
13 | cmake \
14 | g++ \
15 | libsm6 \
16 | libxext6 \
17 | libxrender-dev \
18 | && rm -rf /var/lib/apt/lists/*
19 |
20 | # Setup our build paths
21 | RUN mkdir -p /build
22 | COPY custom_layer /build/custom_layer
23 | RUN mkdir -p /build/custom_layer/build
24 |
25 | # Compile the new layer
26 | WORKDIR /build/custom_layer/build
27 | RUN cmake -D CUDA_TOOLKIT_ROOT_DIR=/usr/local/cuda ..
28 | RUN make -j
29 |
30 | # Install the python requirements.
31 | COPY requirements.txt /
32 | RUN pip install -r /requirements.txt
33 |
34 | # Copy in the full repo.
35 | COPY . /bts
36 | WORKDIR /bts
37 |
38 | # Put the new layer we built into /bts/custom_layer
39 | RUN cp -r /build/custom_layer/build custom_layer/.
40 |
41 | # Download the model locally.
42 | RUN mkdir -p models \
43 | && python utils/download_from_gdrive.py 1ipme-fkV4pIx87sOs31R9CD_Qg-85__h models/bts_nyu.zip \
44 | && cd models \
45 | && unzip bts_nyu.zip
46 |
47 | # Set the path back to avoid error (see above).
48 | ENV LD_LIBRARY_PATH="${LD_LIBRARY_PATH_OLD}"
49 |
50 | # Add relevant paths to the PYTHONPATH so they can be imported from anywhere.
51 | ENV PYTHONPATH=/bts:/bts/models/bts_nyu
52 |
--------------------------------------------------------------------------------
/tensorflow/README.md:
--------------------------------------------------------------------------------
1 | # BTS
2 | From Big to Small: Multi-Scale Local Planar Guidance for Monocular Depth Estimation
3 | [arXiv](https://arxiv.org/abs/1907.10326)
4 | [Supplementary material](https://arxiv.org/src/1907.10326v4/anc/bts_sm.pdf)
5 |
6 | ## Note
7 | This folder contains a TensorFlow implementation of BTS.\
8 | We tested this code under Python 2.7 and 3.6, TensorFlow 1.14, and CUDA 10.0 on Ubuntu 18.04.
9 |
10 | If you use TensorFlow built from source, v1.14 works. \
11 | If you use TensorFlow installed using pip, versions up to v1.13.2 work. \
12 | Currently, TensorFlow v1.14.0 installed using pip results in a segmentation fault.
13 |
14 |
15 | ## Preparation
16 | ```shell
17 | $ cd ~/workspace/bts/tensorflow/custom_layer
18 | $ mkdir build && cd build
19 | $ cmake -D CUDA_TOOLKIT_ROOT_DIR=/usr/local/cuda ..
20 | $ make -j
21 | ```
22 | If you encounter an error "fatal error: third_party/gpus/cuda/include/cuda_fp16.h: No such file or directory",
23 | open "tensorflow/include/tensorflow/core/util/gpu_kernel_helper.h" and edit a line from
24 | ```
25 | #include "third_party/gpus/cuda/include/cuda_fp16.h"
26 | ```
27 | to
28 | ```
29 | #include "cuda_fp16.h"
30 | ```
31 | Also, you will need to edit lines in "tensorflow/include/tensorflow/core/util/gpu_device_functions.h" from
32 | ```
33 | #include "third_party/gpus/cuda/include/cuComplex.h"
34 | #include "third_party/gpus/cuda/include/cuda.h"
35 | ```
36 | to
37 | ```
38 | #include "cuComplex.h"
39 | #include "cuda.h"
40 | ```
41 |
42 | If you are testing with a TensorFlow version lower than 1.14, please edit a line in "local_planar_guidance.cu" from
43 | ```
44 | #include "tensorflow/include/tensorflow/core/util/gpu_kernel_helper.h"
45 | ```
46 | to
47 | ```
48 | #include "tensorflow/include/tensorflow/core/util/cuda_kernel_helper.h"
49 | ```
50 |
51 | Then issue the make commands again.
52 | ```shell
53 | $ cmake ..
54 | $ make -j
55 | ```
56 |
57 | ## Testing with [NYU Depth V2](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html)
58 | First, make sure that you have prepared the test set using the instructions in README.md at the root of this repo.
59 | ```shell
60 | $ cd ~/workspace/bts/tensorflow
61 | $ mkdir models
62 | ### Get BTS model trained with NYU Depth V2
63 | $ python ../utils/download_from_gdrive.py 1goRL8aZw8bwZ8cZmne_cJTBnBOT6ii0S models/bts_nyu_v2.zip
64 | $ cd models
65 | $ unzip bts_nyu_v2.zip
66 | ```
67 | Once the preparation steps are completed, you can test BTS using the following commands.
68 | ```
69 | $ cd ~/workspace/bts/tensorflow
70 | $ python bts_test.py arguments_test_nyu.txt
71 | ```
72 | This will save results to ./result_bts_nyu_v2. With a single RTX 2080 Ti, it takes about 34 seconds to process the 654 test images.
73 |
74 | ## Evaluation
75 | The following command will evaluate the prediction results for NYU Depth V2.
76 | ```
77 | $ cd ~/workspace/bts
78 | $ python utils/eval_with_pngs.py --pred_path ./tensorflow/result_bts_nyu_v2/raw/ --gt_path ../dataset/nyu_depth_v2/official_splits/test/ --dataset nyu --min_depth_eval 1e-3 --max_depth_eval 10 --eigen_crop
79 | ```
80 |
81 | You should see outputs like this:
82 | ```
83 | Raw png files reading done
84 | Evaluating 654 files
85 | GT files reading done
86 | 0 GT files missing
87 | Computing errors
88 | d1, d2, d3, AbsRel, SqRel, RMSE, RMSElog, SILog, log10
89 | 0.886, 0.981, 0.995, 0.110, 0.059, 0.350, 0.138, 11.076, 0.046
90 | Done.
91 | ```
92 |
93 | ## Preparation for Training
94 | ### NYU Depth V2
95 | First, you need to download the DenseNet-161 model pretrained on ImageNet.
96 | ```
97 | # Get DenseNet-161 model pretrained with ImageNet
98 | $ cd ~/workspace/bts
99 | $ python utils/download_from_gdrive.py 1rn7xBF5eSISFKL2bIa8o3d8dNnsrlWfJ tensorflow/models/densenet161_imagenet.zip
100 | $ cd tensorflow/models && unzip densenet161_imagenet.zip
101 | ```
102 | Then, download the dataset we used in this work.
103 | ```
104 | $ cd ~/workspace/bts
105 | $ python utils/download_from_gdrive.py 1AysroWpfISmm-yRFGBgFTrLy6FjQwvwP ../dataset/nyu_depth_v2/sync.zip
106 | $ unzip sync.zip
107 | ```
108 |
109 | Alternatively, you can download it from the following link: https://drive.google.com/file/d/1AysroWpfISmm-yRFGBgFTrLy6FjQwvwP/view?usp=sharing Please make sure to place the downloaded file at ~/workspace/bts/dataset/nyu_depth_v2/sync.zip
110 |
111 | Alternatively, you can prepare the dataset yourself with a MATLAB script, using the original files from the official [NYU Depth V2](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html) site.
112 | There are two options for downloading the original files: single file downloading and segmented-files downloading.
113 |
114 | Single file downloading:
115 | ```
116 | $ cd ~/workspace/dataset/nyu_depth_v2
117 | $ mkdir raw && cd raw
118 | $ wget http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/nyu_depth_v2_raw.zip
119 | $ unzip nyu_depth_v2_raw.zip
120 | ```
121 | Segmented-files downloading:
122 | ```
123 | $ cd ~/workspace/dataset/nyu_depth_v2
124 | $ mkdir raw && cd raw
125 | $ aria2c -x 16 -i ../../../bts/utils/nyudepthv2_archives_to_download.txt
126 | $ cd ~/workspace/bts
127 | $ python utils/download_from_gdrive.py 1xBwO6qU8UCS69POJJ0-9luaG_1pS1khW ../dataset/nyu_depth_v2/raw/bathroom_0039.zip
128 | $ python utils/download_from_gdrive.py 1IFoci9kns6vOV833S7osV6c5HmGxZsBp ../dataset/nyu_depth_v2/raw/bedroom_0076a.zip
129 | $ python utils/download_from_gdrive.py 1ysSeyiOiOI1EKr1yhmKy4jcYiXdgLP4f ../dataset/nyu_depth_v2/raw/living_room_0018.zip
130 | $ python utils/download_from_gdrive.py 1QkHkK46VuKBPszB-mb6ysFp7VO92UgfB ../dataset/nyu_depth_v2/raw/living_room_0019.zip
131 | $ python utils/download_from_gdrive.py 1g1Xc3urlI_nIcgWk8I-UaFXJHiKGzK6w ../dataset/nyu_depth_v2/raw/living_room_0020.zip
132 | $ parallel unzip ::: *.zip
133 | ```
134 | Get the official MATLAB toolbox for RGB and depth synchronization.
135 | ```
136 | $ cd ~/workspace/bts/utils
137 | $ wget http://cs.nyu.edu/~silberman/code/toolbox_nyu_depth_v2.zip
138 | $ unzip toolbox_nyu_depth_v2.zip
139 | $ cd toolbox_nyu_depth_v2
140 | $ mv ../sync_project_frames_multi_threads.m .
141 | $ mv ../train_scenes.txt .
142 | ```
143 | Run the script "sync_project_frames_multi_threads.m" in MATLAB to get synchronized RGB and depth images.
144 | This will save RGB-depth pairs in "~/workspace/dataset/nyu_depth_v2/sync".
145 |
146 | Once the dataset is ready, you can train the network using the following command.
147 | ```
148 | $ cd ~/workspace/bts/tensorflow
149 | $ python bts_main.py arguments_train_nyu.txt
150 | ```
151 | You can monitor the training progress using tensorboard:
152 | ```
153 | $ tensorboard --logdir ./models/bts_nyu_test/ --port 6006
154 | ```
155 | Open localhost:6006 with your favorite browser to see the progress of training.
156 |
157 | ### KITTI
158 | You can also train BTS with the KITTI dataset using the following procedure.
159 | First, make sure that you have prepared the ground truth depth maps from [KITTI](http://www.cvlibs.net/download.php?file=data_depth_annotated.zip).
160 | If you have not, please follow the instructions in README.md at the root of this repo.
161 | Then, download and unzip the raw dataset using the following commands.
162 | ```
163 | $ cd ~/workspace/dataset/kitti_dataset
164 | $ aria2c -x 16 -i ../../bts/utils/kitti_archives_to_download.txt
165 | $ parallel unzip ::: *.zip
166 | ```
167 | Finally, we can train our network with
168 | ```
169 | $ cd ~/workspace/bts/tensorflow
170 | $ python bts_main.py arguments_train_eigen.txt
171 | ```
172 |
173 | ## Testing and Evaluation with [KITTI](http://www.cvlibs.net/datasets/kitti/eval_depth.php?benchmark=depth_prediction)
174 | Once you have the KITTI dataset and the official ground truth depth maps, you can test and evaluate our model with the following commands.
175 | ```
176 | # Get KITTI model trained with KITTI Eigen split
177 | $ cd ~/workspace/bts
178 | $ python utils/download_from_gdrive.py 1nhukEgl3YdTBKVzcjxUp6ZFMsKKM3xfg tensorflow/models/bts_eigen_v2.zip
179 | $ cd tensorflow/models && unzip bts_eigen_v2.zip
180 | ```
181 | Test and save results.
182 | ```
183 | $ cd ~/workspace/bts/tensorflow
184 | $ python bts_test.py arguments_test_eigen.txt
185 | ```
186 | This will save results to ./result_bts_eigen_v2.
187 | Finally, we can evaluate the prediction results with
188 | ```
189 | $ cd ~/workspace/bts
190 | $ python utils/eval_with_pngs.py --pred_path ./tensorflow/result_bts_eigen_v2/raw/ --gt_path ../dataset/kitti_dataset/data_depth_annotated/ --dataset kitti --min_depth_eval 1e-3 --max_depth_eval 80 --do_kb_crop --garg_crop
191 | ```
192 | You should see outputs like this:
193 | ```
194 | GT files reading done
195 | 45 GT files missing
196 | Computing errors
197 | d1, d2, d3, AbsRel, SqRel, RMSE, RMSElog, SILog, log10
198 | 0.952, 0.993, 0.998, 0.063, 0.257, 2.791, 0.099, 9.168, 0.028
199 | Done.
200 | ```
201 |
202 | ## License
203 | Copyright (C) 2019 Jin Han Lee, Myung-Kyu Han, Dong Wook Ko and Il Hong Suh \
204 | This Software is licensed under GPL-3.0-or-later.
205 |
--------------------------------------------------------------------------------
/tensorflow/arguments_test_eigen.txt:
--------------------------------------------------------------------------------
1 | --encoder densenet161_bts
2 | --data_path ../dataset/kitti_dataset/
3 | --dataset kitti
4 | --filenames_file ./train_test_inputs/eigen_test_files_with_gt.txt
5 | --model_name bts_eigen_v2
6 | --checkpoint_path ./models/bts_eigen_v2/model
7 | --input_height 352
8 | --input_width 1216
9 | --max_depth 80
10 | --do_kb_crop
--------------------------------------------------------------------------------
/tensorflow/arguments_test_nyu.txt:
--------------------------------------------------------------------------------
1 | --encoder densenet161_bts
2 | --data_path ../dataset/nyu_depth_v2/official_splits/test/
3 | --dataset nyu
4 | --filenames_file ./train_test_inputs/nyudepthv2_test_files_with_gt.txt
5 | --model_name bts_nyu_v2
6 | --checkpoint_path ./models/bts_nyu_v2/model
7 | --input_height 480
8 | --input_width 640
9 | --max_depth 10
10 |
--------------------------------------------------------------------------------
/tensorflow/arguments_train_eigen.txt:
--------------------------------------------------------------------------------
1 | --mode train
2 | --model_name bts_eigen_test
3 | --encoder densenet161_bts
4 | --dataset kitti
5 | --data_path ../dataset/kitti_dataset/
6 | --gt_path ../dataset/kitti_dataset/data_depth_annotated/
7 | --filenames_file ./train_test_inputs/eigen_train_files_with_gt.txt
8 | --batch_size 4
9 | --num_epochs 50
10 | --learning_rate 1e-4
11 | --num_gpus 1
12 | --num_threads 1
13 | --input_height 352
14 | --input_width 704
15 | --max_depth 80
16 | --do_kb_crop
17 | --do_random_rotate
18 | --degree 1.0
19 | --log_directory ./models/
20 | --pretrained_model ./models/densenet161_imagenet/model
21 | --fix_first_conv_blocks
22 |
--------------------------------------------------------------------------------
/tensorflow/arguments_train_nyu.txt:
--------------------------------------------------------------------------------
1 | --mode train
2 | --model_name bts_nyu_test
3 | --encoder densenet161_bts
4 | --dataset nyu
5 | --data_path ../dataset/nyu_depth_v2/sync/
6 | --gt_path ../dataset/nyu_depth_v2/sync/
7 | --filenames_file ./train_test_inputs/nyudepthv2_train_files_with_gt.txt
8 | --batch_size 4
9 | --num_epochs 50
10 | --learning_rate 1e-4
11 | --num_gpus 1
12 | --num_threads 1
13 | --input_height 416
14 | --input_width 544
15 | --max_depth 10
16 | --do_random_rotate
17 | --degree 2.5
18 | --log_directory ./models/
19 | --pretrained_model ./models/densenet161_imagenet/model
20 | --fix_first_conv_blocks
21 |
--------------------------------------------------------------------------------
/tensorflow/average_gradients.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import, division, print_function
16 | import tensorflow as tf
17 |
18 | def average_gradients(tower_grads):
19 |
20 | average_grads = []
21 | for grad_and_vars in zip(*tower_grads):
22 | # Note that each grad_and_vars looks like the following:
23 | # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
24 | grads = []
25 | for g, _ in grad_and_vars:
26 | # Add 0 dimension to the gradients to represent the tower.
27 | expanded_g = tf.expand_dims(g, 0)
28 |
29 | # Append on a 'tower' dimension which we will average over below.
30 | grads.append(expanded_g)
31 |
32 | # Average over the 'tower' dimension.
33 | grad = tf.concat(axis=0, values=grads)
34 | grad = tf.reduce_mean(grad, 0)
35 |
36 | # Keep in mind that the Variables are redundant because they are shared
37 | # across towers. So .. we will just return the first tower's pointer to
38 | # the Variable.
39 | v = grad_and_vars[0][1]
40 | grad_and_var = (grad, v)
41 | average_grads.append(grad_and_var)
42 |
43 | return average_grads
44 |
--------------------------------------------------------------------------------
/tensorflow/bts_dataloader.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2019 Jin Han Lee
2 | #
3 | # This file is a part of BTS.
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU General Public License
15 | # along with this program. If not, see
16 |
17 | from __future__ import absolute_import, division, print_function
18 | import tensorflow as tf
19 | from tensorflow.python.ops import array_ops
20 |
21 |
22 | class BtsDataloader(object):
23 | """bts dataloader"""
24 |
25 | def __init__(self, data_path, gt_path, filenames_file, params, mode,
26 | do_rotate=False, degree=5.0, do_kb_crop=False):
27 |
28 | self.data_path = data_path
29 | self.gt_path = gt_path
30 | self.params = params
31 | self.mode = mode
32 |
33 | self.do_rotate = do_rotate
34 | self.degree = degree
35 |
36 | self.do_kb_crop = do_kb_crop
37 |
38 | with open(filenames_file, 'r') as f:
39 | filenames = f.readlines()
40 |
41 | if mode == 'train':
42 | assert not self.params.batch_size % self.params.num_gpus
43 | mini_batch_size = int(self.params.batch_size / self.params.num_gpus)
44 |
45 | self.loader = tf.data.Dataset.from_tensor_slices(filenames)
46 | self.loader = self.loader.apply(tf.contrib.data.shuffle_and_repeat(len(filenames)))
47 | self.loader = self.loader.map(self.parse_function_train, num_parallel_calls=params.num_threads)
48 | self.loader = self.loader.map(self.train_preprocess, num_parallel_calls=params.num_threads)
49 | self.loader = self.loader.batch(mini_batch_size)
50 | self.loader = self.loader.prefetch(mini_batch_size)
51 |
52 | else:
53 | self.loader = tf.data.Dataset.from_tensor_slices(filenames)
54 | self.loader = self.loader.map(self.parse_function_test, num_parallel_calls=1)
55 | self.loader = self.loader.map(self.test_preprocess, num_parallel_calls=1)
56 | self.loader = self.loader.batch(1)
57 | self.loader = self.loader.prefetch(1)
58 |
59 | def parse_function_test(self, line):
60 | split_line = tf.string_split([line]).values
61 | image_path = tf.string_join([self.data_path, split_line[0]])
62 |
63 | if self.params.dataset == 'nyu':
64 | image = tf.image.decode_jpeg(tf.read_file(image_path))
65 | else:
66 | image = tf.image.decode_png(tf.read_file(image_path))
67 |
68 | image = tf.image.convert_image_dtype(image, tf.float32)
69 | focal = tf.string_to_number(split_line[2])
70 |
71 | if self.do_kb_crop is True:
72 | height = tf.shape(image)[0]
73 | width = tf.shape(image)[1]
74 | top_margin = tf.to_int32(height - 352)
75 | left_margin = tf.to_int32((width - 1216) / 2)
76 | image = image[top_margin:top_margin + 352, left_margin:left_margin + 1216, :]
77 |
78 | return image, focal
79 |
80 | def test_preprocess(self, image, focal):
81 |
82 | image.set_shape([None, None, 3])
83 |
84 | image *= 255.0
85 | image = self.mean_image_subtraction(image, [123.68, 116.78, 103.94])
86 |
87 | if self.params.encoder == 'densenet161_bts' or self.params.encoder == 'densenet121_bts':
88 | image *= 0.017
89 |
90 | return image, focal
91 |
92 | def parse_function_train(self, line):
93 | split_line = tf.string_split([line]).values
94 | image_path = tf.string_join([self.data_path, split_line[0]])
95 | depth_gt_path = tf.string_join([self.gt_path, tf.string_strip(split_line[1])])
96 |
97 | if self.params.dataset == 'nyu':
98 | image = tf.image.decode_jpeg(tf.read_file(image_path))
99 | else:
100 | image = tf.image.decode_png(tf.read_file(image_path))
101 |
102 | depth_gt = tf.image.decode_png(tf.read_file(depth_gt_path), channels=0, dtype=tf.uint16)
103 |
104 | if self.params.dataset == 'nyu':
105 | depth_gt = tf.cast(depth_gt, tf.float32) / 1000.0
106 | else:
107 | depth_gt = tf.cast(depth_gt, tf.float32) / 256.0
108 |
109 | image = tf.image.convert_image_dtype(image, tf.float32)
110 | focal = tf.string_to_number(split_line[2])
111 |
112 | # To avoid blank boundaries due to pixel registration
113 | if self.params.dataset == 'nyu':
114 | depth_gt = depth_gt[45:472, 43:608, :]
115 | image = image[45:472, 43:608, :]
116 |
117 | if self.do_kb_crop is True:
118 | print('Cropping training images as kitti benchmark images')
119 | height = tf.shape(image)[0]
120 | width = tf.shape(image)[1]
121 | top_margin = tf.to_int32(height - 352)
122 | left_margin = tf.to_int32((width - 1216) / 2)
123 | depth_gt = depth_gt[top_margin:top_margin + 352, left_margin:left_margin + 1216, :]
124 | image = image[top_margin:top_margin + 352, left_margin:left_margin + 1216, :]
125 |
126 | if self.do_rotate is True:
127 | random_angle = tf.random_uniform([], - self.degree * 3.141592 / 180, self.degree * 3.141592 / 180)
128 | image = tf.contrib.image.rotate(image, random_angle, interpolation='BILINEAR')
129 | depth_gt = tf.contrib.image.rotate(depth_gt, random_angle, interpolation='NEAREST')
130 |
131 | print('Do random cropping from fixed size input')
132 | image, depth_gt = self.random_crop_fixed_size(image, depth_gt)
133 |
134 | return image, depth_gt, focal
135 |
136 | def train_preprocess(self, image, depth_gt, focal):
137 | # Random flipping
138 | do_flip = tf.random_uniform([], 0, 1)
139 | image = tf.cond(do_flip > 0.5, lambda: tf.image.flip_left_right(image), lambda: image)
140 | depth_gt = tf.cond(do_flip > 0.5, lambda: tf.image.flip_left_right(depth_gt), lambda: depth_gt)
141 |
142 | # Random gamma, brightness, color augmentation
143 | do_augment = tf.random_uniform([], 0, 1)
144 | image = tf.cond(do_augment > 0.5, lambda: self.augment_image(image), lambda: image)
145 |
146 | image.set_shape([self.params.height, self.params.width, 3])
147 | depth_gt.set_shape([self.params.height, self.params.width, 1])
148 |
149 | image *= 255.0
150 | image = self.mean_image_subtraction(image, [123.68, 116.78, 103.94])
151 |
152 | if self.params.encoder == 'densenet161_bts' or self.params.encoder == 'densenet121_bts':
153 | image *= 0.017
154 |
155 | return image, depth_gt, focal
156 |
157 | def random_crop_fixed_size(self, image, depth_gt):
158 | image_depth = tf.concat([image, depth_gt], 2)
159 | image_depth_cropped = tf.random_crop(image_depth, [self.params.height, self.params.width, 4])
160 |
161 | image_cropped = image_depth_cropped[:, :, 0:3]
162 | depth_gt_cropped = tf.expand_dims(image_depth_cropped[:, :, 3], 2)
163 |
164 | return image_cropped, depth_gt_cropped
165 |
166 | def augment_image(self, image):
167 | # gamma augmentation
168 | gamma = tf.random_uniform([], 0.9, 1.1)
169 | image_aug = image ** gamma
170 |
171 | # brightness augmentation
172 | if self.params.dataset == 'nyu':
173 | brightness = tf.random_uniform([], 0.75, 1.25)
174 | else:
175 | brightness = tf.random_uniform([], 0.9, 1.1)
176 | image_aug = image_aug * brightness
177 |
178 | # color augmentation
179 | colors = tf.random_uniform([3], 0.9, 1.1)
180 | white = tf.ones([tf.shape(image)[0], tf.shape(image)[1]])
181 | color_image = tf.stack([white * colors[i] for i in range(3)], axis=2)
182 | image_aug *= color_image
183 |
184 | # clip
185 | if self.params.encoder == 'densenet161_bts' or self.params.encoder == 'densenet121_bts':
186 | image_aug = tf.clip_by_value(image_aug, 0, 1)
187 | else:
188 | image_aug = tf.clip_by_value(image_aug, 0, 255)
189 |
190 | return image_aug
191 |
192 | @staticmethod
193 | def mean_image_subtraction(image, means):
194 | """Subtracts the given means from each image channel.
195 | For example:
196 | means = [123.68, 116.779, 103.939]
197 | image = mean_image_subtraction(image, means)
198 | Note that the rank of `image` must be known.
199 | Args:
200 | image: a tensor of size [height, width, C].
201 | means: a C-vector of values to subtract from each channel.
202 | Returns:
203 | the centered image.
204 | Raises:
205 | ValueError: If the rank of `image` is unknown, if `image` has a rank other
206 | than three or if the number of channels in `image` doesn't match the
207 | number of values in `means`.
208 | """
209 |
210 | if image.get_shape().ndims != 3:
211 | raise ValueError('Input must be of size [height, width, C>0]')
212 | num_channels = image.get_shape().as_list()[-1]
213 | if len(means) != num_channels:
214 | raise ValueError('len(means) must match the number of channels')
215 |
216 | channels = tf.split(axis=2, num_or_size_splits=num_channels, value=image)
217 | for i in range(num_channels):
218 | channels[i] -= means[i]
219 | return tf.concat(axis=2, values=channels)
220 |
--------------------------------------------------------------------------------
/tensorflow/bts_eval.py:
--------------------------------------------------------------------------------
1 | # This file is a part of BTS.
2 | # This program is free software: you can redistribute it and/or modify
3 | # it under the terms of the GNU General Public License as published by
4 | # the Free Software Foundation, either version 3 of the License, or
5 | # (at your option) any later version.
6 | #
7 | # This program is distributed in the hope that it will be useful,
8 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 | # GNU General Public License for more details.
11 | #
12 | # You should have received a copy of the GNU General Public License
13 | # along with this program. If not, see
14 |
15 | from __future__ import absolute_import, division, print_function
16 |
17 | import os
18 | import argparse
19 | import time
20 | import numpy as np
21 | import cv2
22 | import sys
23 |
24 | from bts_dataloader import *
25 |
26 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
27 |
28 |
29 | def convert_arg_line_to_args(arg_line):
30 | for arg in arg_line.split():
31 | if not arg.strip():
32 | continue
33 | yield arg
34 |
35 |
36 | parser = argparse.ArgumentParser(description='BTS TensorFlow implementation.', fromfile_prefix_chars='@')
37 | parser.convert_arg_line_to_args = convert_arg_line_to_args
38 |
39 | parser.add_argument('--model_name', type=str, help='model name', default='bts_v0_0_1')
40 | parser.add_argument('--encoder', type=str, help='type of encoder, densenet121_bts or densenet161_bts', default='densenet161_bts')
41 | parser.add_argument('--data_path', type=str, help='path to the data', required=True)
42 | parser.add_argument('--gt_path', type=str, help='path to the groundtruth data', required=False)
43 | parser.add_argument('--filenames_file', type=str, help='path to the filenames text file', required=True)
44 | parser.add_argument('--input_height', type=int, help='input height', default=480)
45 | parser.add_argument('--input_width', type=int, help='input width', default=640)
46 | parser.add_argument('--max_depth', type=float, help='maximum depth in estimation', default=80)
47 | parser.add_argument('--output_directory', type=str, help='output directory for summary, if empty outputs to checkpoint folder', default='')
48 | parser.add_argument('--checkpoint_path', type=str, help='path to a specific checkpoint to load', default='')
49 | parser.add_argument('--dataset', type=str, help='dataset to evaluate on, kitti or nyu', default='nyu')
50 | parser.add_argument('--eigen_crop', help='if set, crops according to Eigen NIPS14', action='store_true')
51 | parser.add_argument('--garg_crop', help='if set, crops according to Garg ECCV16', action='store_true')
52 |
53 | parser.add_argument('--min_depth_eval', type=float, help='minimum depth for evaluation', default=1e-3)
54 | parser.add_argument('--max_depth_eval', type=float, help='maximum depth for evaluation', default=80)
55 | parser.add_argument('--do_kb_crop', help='if set, crop input images as kitti benchmark images', action='store_true')
56 |
57 |
58 | if sys.argv.__len__() == 2:
59 | arg_filename_with_prefix = '@' + sys.argv[1]
60 | args = parser.parse_args([arg_filename_with_prefix])
61 | else:
62 | args = parser.parse_args()
63 |
64 | model_dir = os.path.dirname(args.checkpoint_path)
65 | sys.path.append(model_dir)
66 |
67 | for key, val in vars(__import__(args.model_name)).items():
68 | if key.startswith('__') and key.endswith('__'):
69 | continue
70 | vars()[key] = val
71 |
72 |
73 | def compute_errors(gt, pred):
74 | thresh = np.maximum((gt / pred), (pred / gt))
75 | d1 = (thresh < 1.25).mean()
76 | d2 = (thresh < 1.25 ** 2).mean()
77 | d3 = (thresh < 1.25 ** 3).mean()
78 |
79 | rmse = (gt - pred) ** 2
80 | rmse = np.sqrt(rmse.mean())
81 |
82 | rmse_log = (np.log(gt) - np.log(pred)) ** 2
83 | rmse_log = np.sqrt(rmse_log.mean())
84 |
85 | abs_rel = np.mean(np.abs(gt - pred) / gt)
86 | sq_rel = np.mean(((gt - pred)**2) / gt)
87 |
88 | err = np.log(pred) - np.log(gt)
89 | silog = np.sqrt(np.mean(err ** 2) - np.mean(err) ** 2) * 100
90 |
91 | err = np.abs(np.log10(pred) - np.log10(gt))
92 | log10 = np.mean(err)
93 |
94 | return silog, log10, abs_rel, sq_rel, rmse, rmse_log, d1, d2, d3
95 |
96 |
97 | def get_num_lines(file_path):
98 | f = open(file_path, 'r')
99 | lines = f.readlines()
100 | f.close()
101 | return len(lines)
102 |
103 |
104 | def test(params):
105 | global gt_depths, is_missing, missing_ids
106 | gt_depths = []
107 | is_missing = []
108 | missing_ids = set()
109 |
110 | write_summary = False
111 |
112 | if os.path.exists(args.checkpoint_path + '.meta'):
113 | steps = [str(args.checkpoint_path).split('/')[-1].split('-')[-1]]
114 | else:
115 | with open(args.checkpoint_path + '/checkpoint') as file:
116 | lines = file.readlines()[1:]
117 |
118 | steps = set()
119 | for line in lines:
120 | step = line.split()[1].split('/')[-1].split('-')[-1].replace('\"', '')
121 | steps.add('{:06d}'.format(int(step)))
122 |
123 | lines = []
124 | if os.path.exists(args.checkpoint_path + '/evaluated_checkpoints'):
125 | with open(args.checkpoint_path + '/evaluated_checkpoints') as file:
126 | lines = file.readlines()
127 |
128 | for line in lines:
129 | if line.rstrip() in steps:
130 | steps.remove(line.rstrip())
131 |
132 | steps = sorted(steps)
133 | if args.output_directory != '':
134 | summary_path = os.path.join(args.output_directory, args.model_name)
135 | else:
136 | summary_path = os.path.join(args.checkpoint_path, 'eval')
137 | write_summary = True
138 |
139 | if len(steps) == 0:
140 | print('No new model to evaluate. Abort.')
141 | return
142 |
143 |     time_modified = os.path.getmtime(args.checkpoint_path + '/checkpoint')
144 | time_diff = time.time() - time_modified
145 | if time_diff < 60:
146 | print('Model file might not be mature due to short time_diff: %s' % str(time_diff))
147 | print('Aborting')
148 | return
149 | else:
150 | print('time_diff: %s' % str(time_diff))
151 |
152 | dataloader = BtsDataloader(args.data_path, args.gt_path, args.filenames_file, params, 'test',
153 | do_kb_crop=args.do_kb_crop)
154 |
155 | dataloader_iter = dataloader.loader.make_initializable_iterator()
156 | iter_init_op = dataloader_iter.initializer
157 | image, focal = dataloader_iter.get_next()
158 |
159 | model = BtsModel(params, 'test', image, None, focal=focal, bn_training=False)
160 |
161 | if write_summary:
162 | summary_writer = tf.summary.FileWriter(summary_path)
163 |
164 | # SESSION
165 | config = tf.ConfigProto(allow_soft_placement=True)
166 | sess = tf.Session(config=config)
167 |
168 | # INIT
169 | sess.run(tf.global_variables_initializer())
170 | sess.run(tf.local_variables_initializer())
171 | coordinator = tf.train.Coordinator()
172 | threads = tf.train.start_queue_runners(sess=sess, coord=coordinator)
173 |
174 | # SAVER
175 | train_saver = tf.train.Saver()
176 |
177 | with tf.device('/cpu:0'):
178 | for step in steps:
179 |
180 | if os.path.exists(args.checkpoint_path + '.meta'):
181 | restore_path = args.checkpoint_path
182 | else:
183 | restore_path = os.path.join(args.checkpoint_path, 'model-' + str(int(step)))
184 |
185 | # RESTORE
186 | train_saver.restore(sess, restore_path)
187 |
188 | num_test_samples = get_num_lines(args.filenames_file)
189 |
190 | with open(args.filenames_file) as f:
191 | lines = f.readlines()
192 |
193 | print('now testing {} files for step {}'.format(num_test_samples, step))
194 | sess.run(iter_init_op)
195 |
196 | pred_depths = []
197 |
198 | start_time = time.time()
199 | for s in range(num_test_samples):
200 | depth = sess.run([model.depth_est])
201 | pred_depths.append(depth[0].squeeze())
202 |
203 | elapsed_time = time.time() - start_time
204 |             print('Elapsed time: %s' % str(elapsed_time))
205 | print('Done.')
206 |
207 | if len(gt_depths) == 0:
208 | for t_id in range(num_test_samples):
209 | gt_depth_path = os.path.join(args.gt_path, lines[t_id].split()[1])
210 | depth = cv2.imread(gt_depth_path, -1)
211 | if depth is None:
212 | print('Missing: %s ' % gt_depth_path)
213 | missing_ids.add(t_id)
214 | continue
215 |
216 | if args.dataset == 'nyu':
217 | depth = depth.astype(np.float32) / 1000.0
218 | else:
219 | depth = depth.astype(np.float32) / 256.0
220 |
221 | gt_depths.append(depth)
222 |
223 | print('Computing errors')
224 | silog, log10, abs_rel, sq_rel, rms, log_rms, d1, d2, d3 = eval(pred_depths, int(step))
225 |
226 | if write_summary:
227 | summary = tf.Summary()
228 | summary.value.add(tag='silog', simple_value=silog.mean())
229 | summary.value.add(tag='abs_rel', simple_value=abs_rel.mean())
230 | summary.value.add(tag='log10', simple_value=log10.mean())
231 | summary.value.add(tag='sq_rel', simple_value=sq_rel.mean())
232 | summary.value.add(tag='rms', simple_value=rms.mean())
233 | summary.value.add(tag='log_rms', simple_value=log_rms.mean())
234 | summary.value.add(tag='d1', simple_value=d1.mean())
235 | summary.value.add(tag='d2', simple_value=d2.mean())
236 | summary.value.add(tag='d3', simple_value=d3.mean())
237 |
238 | summary_writer.add_summary(summary, global_step=step)
239 | summary_writer.flush()
240 |
241 | with open(os.path.dirname(args.checkpoint_path) + '/evaluated_checkpoints', 'a') as file:
242 | file.write(step + '\n')
243 |
244 | print('Evaluation done')
245 |
246 |
247 | def eval(pred_depths, step):
248 |
249 | num_samples = get_num_lines(args.filenames_file)
250 | pred_depths_valid = []
251 |
252 | for t_id in range(num_samples):
253 | if t_id in missing_ids:
254 | continue
255 |
256 | pred_depths_valid.append(pred_depths[t_id])
257 |
258 | num_samples = num_samples - len(missing_ids)
259 |
260 | silog = np.zeros(num_samples, np.float32)
261 | log10 = np.zeros(num_samples, np.float32)
262 | rms = np.zeros(num_samples, np.float32)
263 | log_rms = np.zeros(num_samples, np.float32)
264 | abs_rel = np.zeros(num_samples, np.float32)
265 | sq_rel = np.zeros(num_samples, np.float32)
266 | d1 = np.zeros(num_samples, np.float32)
267 | d2 = np.zeros(num_samples, np.float32)
268 | d3 = np.zeros(num_samples, np.float32)
269 |
270 | for i in range(num_samples):
271 |
272 | gt_depth = gt_depths[i]
273 | pred_depth = pred_depths_valid[i]
274 |
275 | if args.do_kb_crop:
276 | height, width = gt_depth.shape
277 | top_margin = int(height - 352)
278 | left_margin = int((width - 1216) / 2)
279 | pred_depth_uncropped = np.zeros((height, width), dtype=np.float32)
280 | pred_depth_uncropped[top_margin:top_margin + 352, left_margin:left_margin + 1216] = pred_depth
281 | pred_depth = pred_depth_uncropped
282 |
283 | pred_depth[pred_depth < args.min_depth_eval] = args.min_depth_eval
284 | pred_depth[pred_depth > args.max_depth_eval] = args.max_depth_eval
285 | pred_depth[np.isinf(pred_depth)] = args.max_depth_eval
286 |
287 | valid_mask = np.logical_and(gt_depth > args.min_depth_eval, gt_depth < args.max_depth_eval)
288 |
289 | if args.garg_crop or args.eigen_crop:
290 | gt_height, gt_width = gt_depth.shape
291 | eval_mask = np.zeros(valid_mask.shape)
292 |
293 | if args.garg_crop:
294 | eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height), int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1
295 |
296 |
297 | elif args.eigen_crop:
298 | if args.dataset == 'kitti':
299 | eval_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height), int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = 1
300 |
301 | else:
302 | eval_mask[45:471, 41:601] = 1
303 |
304 | valid_mask = np.logical_and(valid_mask, eval_mask)
305 |
306 | silog[i], log10[i], abs_rel[i], sq_rel[i], rms[i], log_rms[i], d1[i], d2[i], d3[i] = compute_errors(gt_depth[valid_mask], pred_depth[valid_mask])
307 |
308 | print("{:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}".format('silog', 'abs_rel', 'log10', 'rms', 'sq_rel', 'log_rms', 'd1', 'd2', 'd3'))
309 | print("{:7.4f}, {:7.4f}, {:7.3f}, {:7.3f}, {:7.3f}, {:7.3f}, {:7.3f}, {:7.3f}, {:7.3f}".format(
310 | silog.mean(), abs_rel.mean(), log10.mean(), rms.mean(), sq_rel.mean(), log_rms.mean(), d1.mean(), d2.mean(), d3.mean()))
311 |
312 | return silog, log10, abs_rel, sq_rel, rms, log_rms, d1, d2, d3
313 |
314 |
315 | def main(_):
316 |
317 | params = bts_parameters(
318 | encoder=args.encoder,
319 | height=args.input_height,
320 | width=args.input_width,
321 | batch_size=None,
322 | dataset=args.dataset,
323 | max_depth=args.max_depth,
324 | num_gpus=None,
325 | num_threads=None,
326 | num_epochs=None)
327 |
328 | test(params)
329 |
330 |
331 | if __name__ == '__main__':
332 | tf.app.run()
333 |
334 |
335 |
336 |
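The eval() routine above averages per-image metrics returned by compute_errors (the same formulas appear in utils/eval_with_pngs.py later in this listing). A minimal NumPy sketch of those metrics on made-up toy arrays, for reference only:

```python
import numpy as np

# Toy ground-truth / prediction pairs (illustrative values only, in meters).
gt = np.array([1.0, 2.0, 4.0], dtype=np.float32)
pred = np.array([1.1, 1.8, 4.5], dtype=np.float32)

thresh = np.maximum(gt / pred, pred / gt)
d1 = (thresh < 1.25).mean()                      # accuracy under threshold 1.25
d2 = (thresh < 1.25 ** 2).mean()
d3 = (thresh < 1.25 ** 3).mean()

abs_rel = np.mean(np.abs(gt - pred) / gt)        # absolute relative error
sq_rel = np.mean(((gt - pred) ** 2) / gt)        # squared relative error
rms = np.sqrt(np.mean((gt - pred) ** 2))         # RMSE
log_rms = np.sqrt(np.mean((np.log(gt) - np.log(pred)) ** 2))
log10 = np.mean(np.abs(np.log10(pred) - np.log10(gt)))

err = np.log(pred) - np.log(gt)
silog = np.sqrt(np.mean(err ** 2) - np.mean(err) ** 2) * 100  # scale-invariant log error

print(silog, log10, abs_rel, sq_rel, rms, log_rms, d1, d2, d3)
```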
--------------------------------------------------------------------------------
/tensorflow/bts_main.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2019 Jin Han Lee
2 | #
3 | # This file is a part of BTS.
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU General Public License
15 | # along with this program. If not, see
16 |
17 | from __future__ import absolute_import, division, print_function
18 |
19 | import os
20 | import argparse
21 | import time
22 | import datetime
23 | import sys
24 |
25 | from average_gradients import *
26 | from tensorflow.python import pywrap_tensorflow
27 | from bts_dataloader import *
28 |
29 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
30 |
31 |
32 | def convert_arg_line_to_args(arg_line):
33 | for arg in arg_line.split():
34 | if not arg.strip():
35 | continue
36 | yield arg
37 |
38 |
39 | parser = argparse.ArgumentParser(description='BTS TensorFlow implementation.', fromfile_prefix_chars='@')
40 | parser.convert_arg_line_to_args = convert_arg_line_to_args
41 |
42 | parser.add_argument('--mode', type=str, help='train or test', default='train')
43 | parser.add_argument('--model_name', type=str, help='model name', default='bts_eigen_v2')
44 | parser.add_argument('--encoder', type=str, help='type of encoder, densenet121_bts, densenet161_bts, resnet101_bts or resnet50_bts', default='densenet161_bts')
45 | parser.add_argument('--dataset', type=str, help='dataset to train on, kitti or nyu', default='nyu')
46 | parser.add_argument('--data_path', type=str, help='path to the data', required=False)
47 | parser.add_argument('--gt_path', type=str, help='path to the groundtruth data', required=False)
48 | parser.add_argument('--filenames_file', type=str, help='path to the filenames text file', required=False)
49 | parser.add_argument('--input_height', type=int, help='input height', default=480)
50 | parser.add_argument('--input_width', type=int, help='input width', default=640)
51 | parser.add_argument('--batch_size', type=int, help='batch size', default=4)
52 | parser.add_argument('--num_epochs', type=int, help='number of epochs', default=50)
53 | parser.add_argument('--learning_rate', type=float, help='initial learning rate', default=1e-4)
54 | parser.add_argument('--end_learning_rate', type=float, help='end learning rate', default=-1)
55 | parser.add_argument('--max_depth', type=float, help='maximum depth in estimation', default=10)
56 | parser.add_argument('--do_random_rotate', help='if set, will perform random rotation for augmentation', action='store_true')
57 | parser.add_argument('--degree', type=float, help='random rotation maximum degree', default=2.5)
58 | parser.add_argument('--do_kb_crop', help='if set, crop input images as kitti benchmark images', action='store_true')
59 | parser.add_argument('--num_gpus', type=int, help='number of GPUs to use for training', default=1)
60 | parser.add_argument('--num_threads', type=int, help='number of threads to use for data loading', default=1)
61 | parser.add_argument('--log_directory', type=str, help='directory to save checkpoints and summaries', default='')
62 | parser.add_argument('--checkpoint_path', type=str, help='path to a checkpoint to load', default='')
63 | parser.add_argument('--pretrained_model', type=str, help='path to a pretrained model checkpoint to load', default='')
64 | parser.add_argument('--retrain', help='if used with checkpoint_path, will restart training from step zero', action='store_true')
65 | parser.add_argument('--fix_first_conv_blocks', help='if set, will fix the first two conv blocks', action='store_true')
66 | parser.add_argument('--fix_first_conv_block', help='if set, will fix the first conv block', action='store_true')
67 |
68 | if sys.argv.__len__() == 2:
69 | arg_filename_with_prefix = '@' + sys.argv[1]
70 | args = parser.parse_args([arg_filename_with_prefix])
71 | else:
72 | args = parser.parse_args()
73 |
74 | if args.mode == 'train' and not args.checkpoint_path:
75 | from bts import *
76 |
77 | elif args.mode == 'train' and args.checkpoint_path:
78 | model_dir = os.path.dirname(args.checkpoint_path)
79 | model_name = os.path.basename(model_dir)
80 | import sys
81 | sys.path.append(model_dir)
82 | for key, val in vars(__import__(model_name)).items():
83 | if key.startswith('__') and key.endswith('__'):
84 | continue
85 | vars()[key] = val
86 |
87 |
88 | def get_num_lines(file_path):
89 | f = open(file_path, 'r')
90 | lines = f.readlines()
91 | f.close()
92 | return len(lines)
93 |
94 |
95 | def get_tensors_in_checkpoint_file(file_name, all_tensors=True, tensor_name=None):
96 | varlist = []
97 | var_value = []
98 | reader = pywrap_tensorflow.NewCheckpointReader(file_name)
99 | if all_tensors:
100 | var_to_shape_map = reader.get_variable_to_shape_map()
101 | for key in sorted(var_to_shape_map):
102 | varlist.append(key)
103 | var_value.append(reader.get_tensor(key))
104 | else:
105 | varlist.append(tensor_name)
106 | var_value.append(reader.get_tensor(tensor_name))
107 | return (varlist, var_value)
108 |
109 |
110 | def build_tensors_in_checkpoint_file(loaded_tensors):
111 | full_var_list = list()
112 | var_check = set()
113 | # Loop all loaded tensors
114 | for i, tensor_name in enumerate(loaded_tensors[0]):
115 | # Extract tensor
116 | try:
117 | tensor_aux = tf.get_default_graph().get_tensor_by_name(tensor_name+":0")
118 |         except KeyError:
119 |             print(tensor_name + ' is in pretrained model but not in current training model')
120 |             continue
121 |         if tensor_aux not in var_check:
122 |             full_var_list.append(tensor_aux)
123 |             var_check.add(tensor_aux)
124 |     return full_var_list
125 |
126 | def train(params):
127 |
128 | with tf.Graph().as_default(), tf.device('/cpu:0'):
129 |
130 | global_step = tf.Variable(0, trainable=False)
131 |
132 | num_training_samples = get_num_lines(args.filenames_file)
133 |
134 | steps_per_epoch = np.ceil(num_training_samples / params.batch_size).astype(np.int32)
135 | num_total_steps = params.num_epochs * steps_per_epoch
136 | start_learning_rate = args.learning_rate
137 |
138 | end_learning_rate = args.end_learning_rate if args.end_learning_rate != -1 else start_learning_rate * 0.1
139 | learning_rate = tf.train.polynomial_decay(start_learning_rate, global_step, num_total_steps, end_learning_rate, 0.9)
140 |
141 | opt_step = tf.train.AdamOptimizer(learning_rate, epsilon=1e-8)
142 |
143 | print("Total number of samples: {}".format(num_training_samples))
144 | print("Total number of steps: {}".format(num_total_steps))
145 |
146 | if args.fix_first_conv_blocks or args.fix_first_conv_block:
147 | if args.fix_first_conv_blocks:
148 | print('Fixing first two conv blocks')
149 | else:
150 | print('Fixing first conv block')
151 |
152 | dataloader = BtsDataloader(args.data_path, args.gt_path, args.filenames_file, params, args.mode,
153 | do_rotate=args.do_random_rotate, degree=args.degree,
154 | do_kb_crop=args.do_kb_crop)
155 |
156 | dataloader_iter = dataloader.loader.make_initializable_iterator()
157 | iter_init_op = dataloader_iter.initializer
158 |
159 | tower_grads = []
160 | tower_losses = []
161 | reuse_variables = None
162 |
163 | with tf.variable_scope(tf.get_variable_scope()):
164 | for i in range(args.num_gpus):
165 | with tf.device('/gpu:%d' % i):
166 | image, depth_gt, focal = dataloader_iter.get_next()
167 | model = BtsModel(params, args.mode, image, depth_gt, focal=focal,
168 | reuse_variables=reuse_variables, model_index=i, bn_training=False)
169 |
170 | loss = model.total_loss
171 | tower_losses.append(loss)
172 |
173 | reuse_variables = True
174 |
175 | if args.fix_first_conv_blocks or args.fix_first_conv_block:
176 | trainable_vars = tf.trainable_variables()
177 | if args.encoder == 'resnet101_bts' or args.encoder == 'resnet50_bts':
178 | first_conv_name = args.encoder.replace('_bts', '') + '/conv1'
179 | if args.fix_first_conv_blocks:
180 |                             g_vars = [var for var in trainable_vars
181 |                                       if not any(s in var.name for s in (first_conv_name, 'block1', 'block2'))]
182 |                         else:
183 |                             g_vars = [var for var in trainable_vars
184 |                                       if not any(s in var.name for s in (first_conv_name, 'block1'))]
185 |                     else:
186 |                         if args.fix_first_conv_blocks:
187 |                             g_vars = [var for var in trainable_vars
188 |                                       if not any(s in var.name for s in ('conv1', 'dense_block1', 'dense_block2', 'transition_block1', 'transition_block2'))]
189 |                         else:
190 |                             g_vars = [var for var in trainable_vars
191 |                                       if not any(s in var.name for s in ('dense_block1', 'transition_block1'))]
192 | else:
193 | g_vars = None
194 |
195 | grads = opt_step.compute_gradients(loss, var_list=g_vars)
196 |
197 | tower_grads.append(grads)
198 |
199 | with tf.variable_scope(tf.get_variable_scope()):
200 | with tf.device('/gpu:%d' % (args.num_gpus - 1)):
201 | grads = average_gradients(tower_grads)
202 | apply_gradient_op = opt_step.apply_gradients(grads, global_step=global_step)
203 | total_loss = tf.reduce_mean(tower_losses)
204 |
205 | tf.summary.scalar('learning_rate', learning_rate, ['model_0'])
206 | tf.summary.scalar('total_loss', total_loss, ['model_0'])
207 | summary_op = tf.summary.merge_all('model_0')
208 |
209 | config = tf.ConfigProto(allow_soft_placement=True)
210 | config.gpu_options.allow_growth = True
211 | sess = tf.Session(config=config)
212 |
213 | summary_writer = tf.summary.FileWriter(args.log_directory + '/' + args.model_name, sess.graph)
214 | train_saver = tf.train.Saver(max_to_keep=200)
215 |
216 | total_num_parameters = 0
217 | for variable in tf.trainable_variables():
218 | total_num_parameters += np.array(variable.get_shape().as_list()).prod()
219 |
220 | print("Total number of trainable parameters: {}".format(total_num_parameters))
221 |
222 | sess.run(tf.global_variables_initializer())
223 | sess.run(tf.local_variables_initializer())
224 |
225 | coordinator = tf.train.Coordinator()
226 | threads = tf.train.start_queue_runners(sess=sess, coord=coordinator)
227 |
228 | if args.pretrained_model != '':
229 | vars_to_restore = get_tensors_in_checkpoint_file(file_name=args.pretrained_model)
230 | tensors_to_load = build_tensors_in_checkpoint_file(vars_to_restore)
231 | loader = tf.train.Saver(tensors_to_load)
232 | loader.restore(sess, args.pretrained_model)
233 |
234 | # Load checkpoint if set
235 | if args.checkpoint_path != '':
236 | restore_path = args.checkpoint_path
237 | train_saver.restore(sess, restore_path)
238 |
239 | if args.retrain:
240 | sess.run(global_step.assign(0))
241 |
242 | start_step = global_step.eval(session=sess)
243 | start_time = time.time()
244 | duration = 0
245 | should_init_iter_op = False
246 | if args.mode == 'train':
247 | should_init_iter_op = True
248 | for step in range(start_step, num_total_steps):
249 | before_op_time = time.time()
250 | if step % steps_per_epoch == 0 or should_init_iter_op is True:
251 | sess.run(iter_init_op)
252 | should_init_iter_op = False
253 |
254 | _, lr, loss_value = sess.run([apply_gradient_op, learning_rate, total_loss])
255 |
256 | print('step: {}/{}, lr: {:.12f}, loss: {:.12f}'.format(step, num_total_steps, lr, loss_value))
257 |
258 | duration += time.time() - before_op_time
259 | if step and step % 100 == 0:
260 | examples_per_sec = params.batch_size / duration * 100
261 | duration = 0
262 | time_sofar = (time.time() - start_time) / 3600
263 | training_time_left = (num_total_steps / step - 1.0) * time_sofar
264 | print('%s:' % args.model_name)
265 | print_string = 'examples/s: {:4.2f} | loss: {:.5f} | time elapsed: {:.2f}h | time left: {:.2f}h'
266 | print(print_string.format(examples_per_sec, loss_value, time_sofar, training_time_left))
267 | summary_str = sess.run(summary_op)
268 | summary_writer.add_summary(summary_str, global_step=step)
269 | summary_writer.flush()
270 |
271 | if step and step % 500 == 0:
272 | train_saver.save(sess, args.log_directory + '/' + args.model_name + '/model', global_step=step)
273 |
274 | train_saver.save(sess, args.log_directory + '/' + args.model_name + '/model', global_step=num_total_steps)
275 | print('%s training finished' % args.model_name)
276 | print(datetime.datetime.now())
277 |
278 |
279 | def main(_):
280 |
281 | params = bts_parameters(
282 | encoder=args.encoder,
283 | height=args.input_height,
284 | width=args.input_width,
285 | batch_size=args.batch_size,
286 | dataset=args.dataset,
287 | max_depth=args.max_depth,
288 | num_gpus=args.num_gpus,
289 | num_threads=args.num_threads,
290 | num_epochs=args.num_epochs)
291 |
292 | if args.mode == 'train':
293 | model_filename = args.model_name + '.py'
294 | command = 'mkdir ' + args.log_directory + '/' + args.model_name
295 | os.system(command)
296 |
297 | custom_layer_path = args.log_directory + '/' + args.model_name + '/' + 'custom_layer'
298 | command = 'mkdir ' + custom_layer_path
299 | os.system(command)
300 |
301 | command = 'cp ' + './custom_layer/* ' + custom_layer_path + '/.'
302 | os.system(command)
303 |
304 | args_out_path = args.log_directory + '/' + args.model_name + '/' + sys.argv[1]
305 | command = 'cp ' + sys.argv[1] + ' ' + args_out_path
306 | os.system(command)
307 |
308 | if args.checkpoint_path == '':
309 | model_out_path = args.log_directory + '/' + args.model_name + '/' + model_filename
310 | command = 'cp bts.py ' + model_out_path
311 | os.system(command)
312 | else:
313 | loaded_model_dir = os.path.dirname(args.checkpoint_path)
314 | loaded_model_name = os.path.basename(loaded_model_dir)
315 | loaded_model_filename = loaded_model_name + '.py'
316 |
317 | model_out_path = args.log_directory + '/' + args.model_name + '/' + model_filename
318 | command = 'cp ' + loaded_model_dir + '/' + loaded_model_filename + ' ' + model_out_path
319 | os.system(command)
320 |
321 | train(params)
322 |
323 | elif args.mode == 'test':
324 | print('This script does not support testing. Use bts_test.py instead.')
325 |
326 |
327 | if __name__ == '__main__':
328 | tf.app.run()
329 |
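train() above shapes the learning rate with tf.train.polynomial_decay(start_lr, global_step, num_total_steps, end_lr, 0.9). A small sketch of the equivalent closed form, handy when reproducing the schedule outside TensorFlow (the step counts below are just example numbers):

```python
import numpy as np

def poly_lr(step, total_steps, lr0=1e-4, lr_end=1e-5, power=0.9):
    # Same curve tf.train.polynomial_decay yields for these arguments:
    #   lr = (lr0 - lr_end) * (1 - step / total_steps) ** power + lr_end
    frac = np.clip(step / float(total_steps), 0.0, 1.0)
    return (lr0 - lr_end) * (1.0 - frac) ** power + lr_end

# When --end_learning_rate is left at -1, train() falls back to lr_end = 0.1 * lr0.
for s in (0, 15000, 30000):
    print(s, poly_lr(s, 30000))
```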
--------------------------------------------------------------------------------
/tensorflow/bts_sequence.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2019 Jin Han Lee
2 | #
3 | # This file is a part of BTS.
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU General Public License
15 | # along with this program. If not, see
16 |
17 | from __future__ import absolute_import, division, print_function
18 |
19 | import os
20 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
21 |
22 | import numpy as np
23 | import argparse
24 | import time
25 | import glob
26 | import cv2
27 | import errno
28 | import matplotlib.pyplot as plt
29 | import sys
30 | import tensorflow as tf
31 | from tqdm import tqdm
32 |
33 | from bts_dataloader import *
34 |
35 | parser = argparse.ArgumentParser(description='BTS TensorFlow implementation.')
36 |
37 | parser.add_argument('--model_name', type=str, help='model name', default='bts_v0_0_1')
38 | parser.add_argument('--encoder', type=str, help='type of encoder, densenet121_bts or densenet161_bts', default='densenet161_bts')
39 | parser.add_argument('--dataset', type=str, help='dataset to test, kitti or nyu', default='')
40 | parser.add_argument('--max_depth', type=float, help='maximum depth in estimation', default=80)
41 | parser.add_argument('--focal', type=float, help='focal length in pixels', default=-1)
42 | parser.add_argument('--image_path', type=str, help='image sequence path', required=True)
43 | parser.add_argument('--out_path', type=str, help='output path', required=True)
44 | parser.add_argument('--checkpoint_path', type=str, help='path to a checkpoint to load', required=True)
45 | parser.add_argument('--input_height', type=int, help='input height', default=480)
46 | parser.add_argument('--input_width', type=int, help='input width', default=640)
47 |
48 | args = parser.parse_args()
49 |
50 | model_dir = os.path.dirname(args.checkpoint_path)
51 |
52 | sys.path.append(model_dir)
53 | for key, val in vars(__import__(args.model_name)).items():
54 | if key.startswith('__') and key.endswith('__'):
55 | continue
56 | vars()[key] = val
57 |
58 |
59 | def test_sequence(params):
60 | image_files = []
61 |
62 | for filename in glob.glob(os.path.join(args.image_path, '*.png')):
63 | image_files.append(filename)
64 |
65 | image_files.sort()
66 |
67 | num_test_samples = len(image_files)
68 | if num_test_samples == 0:
69 | print("No images found! Program abort.")
70 | return
71 |     focal = args.focal  # fallback for custom datasets; overridden below for nyu and kitti
72 | if args.dataset == 'nyu':
73 | focal = 518.8579
74 | elif args.dataset == 'kitti':
75 |         focal = 718.856  # for visualization purposes only
76 | elif args.dataset == '' and args.focal == -1:
77 | print('Custom dataset needs to specify focal length with --focal')
78 | return
79 |
80 | image = tf.placeholder(tf.float32, [1, args.input_height, args.input_width, 3])
81 | focals = tf.constant([focal])
82 |
83 | model = BtsModel(params, 'test', image, None, focal=focals, bn_training=False)
84 |
85 | # SESSION
86 | config = tf.ConfigProto(allow_soft_placement=True)
87 | sess = tf.Session(config=config)
88 |
89 | # INIT
90 | sess.run(tf.global_variables_initializer())
91 | sess.run(tf.local_variables_initializer())
92 | coordinator = tf.train.Coordinator()
93 | threads = tf.train.start_queue_runners(sess=sess, coord=coordinator)
94 |
95 | # SAVER
96 | train_saver = tf.train.Saver()
97 |
98 | with tf.device('/cpu:0'):
99 | restore_path = args.checkpoint_path
100 |
101 | # RESTORE
102 | train_saver.restore(sess, restore_path)
103 |
104 | print('now testing {} files for model {}'.format(num_test_samples, args.checkpoint_path))
105 |
106 | print('Saving result pngs')
107 |     if not os.path.exists(args.out_path):
108 | try:
109 | os.mkdir(args.out_path)
110 | os.mkdir(args.out_path + '/depth')
111 | os.mkdir(args.out_path + '/reduc1x1')
112 | os.mkdir(args.out_path + '/lpg2x2')
113 | os.mkdir(args.out_path + '/lpg4x4')
114 | os.mkdir(args.out_path + '/lpg8x8')
115 | os.mkdir(args.out_path + '/rgb')
116 | except OSError as e:
117 | if e.errno != errno.EEXIST:
118 | raise
119 |
120 | start_time = time.time()
121 | for s in tqdm(range(num_test_samples)):
122 | input_image = cv2.imread(image_files[s])
123 |
124 | if args.dataset == 'kitti':
125 | height, width, ch = input_image.shape
126 | top_margin = int(height - 352)
127 | left_margin = int((width - 1216) / 2)
128 | input_image = input_image[top_margin:top_margin + 352, left_margin:left_margin + 1216, :]
129 |
130 | input_image_original = input_image
131 | input_image = input_image.astype(np.float32)
132 |
133 | # Normalize image
134 | input_image[:, :, 0] = (input_image[:, :, 0] - 103.939) * 0.017
135 | input_image[:, :, 1] = (input_image[:, :, 1] - 116.779) * 0.017
136 | input_image[:, :, 2] = (input_image[:, :, 2] - 123.68) * 0.017
137 |
138 | input_images = np.reshape(input_image, (1, args.input_height, args.input_width, 3))
139 |
140 | depth, pred_8x8, pred_4x4, pred_2x2, pred_1x1 = sess.run(
141 | [model.depth_est, model.lpg8x8, model.lpg4x4, model.lpg2x2, model.reduc1x1], feed_dict={image: input_images})
142 |
143 | pred_depth = depth.squeeze()
144 | pred_8x8 = pred_8x8.squeeze()
145 | pred_4x4 = pred_4x4.squeeze()
146 | pred_2x2 = pred_2x2.squeeze()
147 | pred_1x1 = pred_1x1.squeeze()
148 |
149 | save_path = os.path.join(args.out_path, 'depth', image_files[s].split('/')[-1])
150 | plt.imsave(save_path, np.log10(pred_depth), cmap='Greys')
151 |
152 | save_path = os.path.join(args.out_path, 'rgb', image_files[s].split('/')[-1])
153 | cv2.imwrite(save_path, input_image_original)
154 |
155 | save_path = os.path.join(args.out_path, 'reduc1x1', image_files[s].split('/')[-1])
156 | plt.imsave(save_path, np.log10(pred_1x1), cmap='Greys')
157 |
158 | save_path = os.path.join(args.out_path, 'lpg2x2', image_files[s].split('/')[-1])
159 | plt.imsave(save_path, np.log10(pred_2x2), cmap='Greys')
160 |
161 | save_path = os.path.join(args.out_path, 'lpg4x4', image_files[s].split('/')[-1])
162 | plt.imsave(save_path, np.log10(pred_4x4), cmap='Greys')
163 |
164 | save_path = os.path.join(args.out_path, 'lpg8x8', image_files[s].split('/')[-1])
165 | plt.imsave(save_path, np.log10(pred_8x8), cmap='Greys')
166 |
167 | print('{}/{}'.format(s, num_test_samples))
168 |
169 | elapsed_time = time.time() - start_time
170 |     print('Elapsed time: %s' % str(elapsed_time))
171 | print('done.')
172 |
173 | def main(_):
174 |
175 | params = bts_parameters(
176 | encoder=args.encoder,
177 | height=args.input_height,
178 | width=args.input_width,
179 | batch_size=None,
180 | dataset=None,
181 | max_depth=args.max_depth,
182 | num_gpus=None,
183 | num_threads=None,
184 | num_epochs=None,
185 | )
186 |
187 | test_sequence(params)
188 |
189 | if __name__ == '__main__':
190 | tf.app.run()
191 |
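test_sequence() above normalizes each BGR frame channel-wise with the ImageNet means and a 0.017 scale before feeding it to the network. A vectorized sketch of that preprocessing (the helper name is ours, not part of the repository):

```python
import numpy as np

# ImageNet channel means in BGR order, matching the per-channel constants above.
IMAGENET_BGR_MEAN = np.array([103.939, 116.779, 123.68], dtype=np.float32)

def normalize_bgr(image_bgr):
    """Caffe-style normalization: subtract BGR means, then scale by 0.017 (~1/58.8)."""
    return (image_bgr.astype(np.float32) - IMAGENET_BGR_MEAN) * 0.017
```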
--------------------------------------------------------------------------------
/tensorflow/bts_test.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2019 Jin Han Lee
2 | #
3 | # This file is a part of BTS.
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU General Public License
15 | # along with this program. If not, see
16 |
17 | from __future__ import absolute_import, division, print_function
18 |
19 | import os
20 | import numpy as np
21 | import argparse
22 | import time
23 | import tensorflow as tf
24 | import errno
25 | import matplotlib.pyplot as plt
26 | import cv2
27 | import sys
28 | from tqdm import tqdm
29 |
30 | from bts_dataloader import *
31 |
32 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
33 |
34 |
35 | def convert_arg_line_to_args(arg_line):
36 | for arg in arg_line.split():
37 | if not arg.strip():
38 | continue
39 | yield arg
40 |
41 |
42 | parser = argparse.ArgumentParser(description='BTS TensorFlow implementation.', fromfile_prefix_chars='@')
43 | parser.convert_arg_line_to_args = convert_arg_line_to_args
44 |
45 | parser.add_argument('--model_name', type=str, help='model name', default='bts_nyu_v2')
46 | parser.add_argument('--encoder', type=str, help='type of encoder, vgg, densenet121_bts or densenet161_bts', default='densenet161_bts')
47 | parser.add_argument('--data_path', type=str, help='path to the data', required=True)
48 | parser.add_argument('--filenames_file', type=str, help='path to the filenames text file', required=True)
49 | parser.add_argument('--input_height', type=int, help='input height', default=480)
50 | parser.add_argument('--input_width', type=int, help='input width', default=640)
51 | parser.add_argument('--max_depth', type=float, help='maximum depth in estimation', default=80)
52 | parser.add_argument('--checkpoint_path', type=str, help='path to a specific checkpoint to load', default='')
53 | parser.add_argument('--dataset', type=str, help='dataset to test on, kitti, kitti_benchmark or nyu', default='nyu')
54 | parser.add_argument('--do_kb_crop', help='if set, crop input images as kitti benchmark images', action='store_true')
55 | parser.add_argument('--save_lpg', help='if set, save outputs from lpg layers', action='store_true')
56 |
57 | if sys.argv.__len__() == 2:
58 | arg_filename_with_prefix = '@' + sys.argv[1]
59 | args = parser.parse_args([arg_filename_with_prefix])
60 | else:
61 | args = parser.parse_args()
62 |
63 | model_dir = os.path.dirname(args.checkpoint_path)
64 | sys.path.append(model_dir)
65 |
66 | for key, val in vars(__import__(args.model_name)).items():
67 | if key.startswith('__') and key.endswith('__'):
68 | continue
69 | vars()[key] = val
70 |
71 |
72 | def get_num_lines(file_path):
73 | f = open(file_path, 'r')
74 | lines = f.readlines()
75 | f.close()
76 | return len(lines)
77 |
78 |
79 | def test(params):
80 | """Test function."""
81 |
82 | dataloader = BtsDataloader(args.data_path, None, args.filenames_file, params, 'test', do_kb_crop=args.do_kb_crop)
83 |
84 | dataloader_iter = dataloader.loader.make_initializable_iterator()
85 | iter_init_op = dataloader_iter.initializer
86 | image, focal = dataloader_iter.get_next()
87 |
88 | model = BtsModel(params, 'test', image, None, focal=focal, bn_training=False)
89 |
90 | # SESSION
91 | config = tf.ConfigProto(allow_soft_placement=True)
92 | sess = tf.Session(config=config)
93 |
94 | # INIT
95 | sess.run(tf.global_variables_initializer())
96 | sess.run(tf.local_variables_initializer())
97 |
98 | # SAVER
99 | train_saver = tf.train.Saver()
100 |
101 | with tf.device('/cpu:0'):
102 | restore_path = args.checkpoint_path
103 |
104 | # RESTORE
105 | train_saver.restore(sess, restore_path)
106 |
107 | num_test_samples = get_num_lines(args.filenames_file)
108 |
109 | with open(args.filenames_file) as f:
110 | lines = f.readlines()
111 |
112 | print('Now testing {} files with {}'.format(num_test_samples, args.checkpoint_path))
113 | sess.run(iter_init_op)
114 |
115 | pred_depths = []
116 | pred_8x8s = []
117 | pred_4x4s = []
118 | pred_2x2s = []
119 |
120 | start_time = time.time()
121 | print('Processing images..')
122 | for s in tqdm(range(num_test_samples)):
123 | depth, pred_8x8, pred_4x4, pred_2x2 = sess.run([model.depth_est, model.lpg8x8, model.lpg4x4, model.lpg2x2])
124 | pred_depths.append(depth[0].squeeze())
125 |
126 | pred_8x8s.append(pred_8x8[0].squeeze())
127 | pred_4x4s.append(pred_4x4[0].squeeze())
128 | pred_2x2s.append(pred_2x2[0].squeeze())
129 |
130 | print('Done.')
131 |
132 | save_name = 'result_' + args.model_name
133 |
134 | print('Saving result pngs..')
135 | if not os.path.exists(os.path.dirname(save_name)):
136 | try:
137 | os.mkdir(save_name)
138 | os.mkdir(save_name + '/raw')
139 | os.mkdir(save_name + '/cmap')
140 | os.mkdir(save_name + '/rgb')
141 | os.mkdir(save_name + '/gt')
142 | except OSError as e:
143 | if e.errno != errno.EEXIST:
144 | raise
145 |
146 | for s in tqdm(range(num_test_samples)):
147 | if args.dataset == 'kitti':
148 | date_drive = lines[s].split('/')[1]
149 | filename_pred_png = save_name + '/raw/' + date_drive + '_' + lines[s].split()[0].split('/')[-1].replace('.jpg', '.png')
150 | filename_cmap_png = save_name + '/cmap/' + date_drive + '_' + lines[s].split()[0].split('/')[-1].replace('.jpg', '.png')
151 | filename_image_png = save_name + '/rgb/' + date_drive + '_' + lines[s].split()[0].split('/')[-1]
152 | elif args.dataset == 'kitti_benchmark':
153 | filename_pred_png = save_name + '/raw/' + lines[s].split()[0].split('/')[-1].replace('.jpg', '.png')
154 | filename_cmap_png = save_name + '/cmap/' + lines[s].split()[0].split('/')[-1].replace('.jpg', '.png')
155 | filename_image_png = save_name + '/rgb/' + lines[s].split()[0].split('/')[-1]
156 | else:
157 | scene_name = lines[s].split()[0].split('/')[0]
158 | filename_pred_png = save_name + '/raw/' + scene_name + '_' + lines[s].split()[0].split('/')[1].replace('.jpg', '.png')
159 | filename_cmap_png = save_name + '/cmap/' + scene_name + '_' + lines[s].split()[0].split('/')[1].replace('.jpg', '.png')
160 | filename_gt_png = save_name + '/gt/' + scene_name + '_' + lines[s].split()[0].split('/')[1].replace('.jpg', '.png')
161 | filename_image_png = save_name + '/rgb/' + scene_name + '_' + lines[s].split()[0].split('/')[1]
162 |
163 | rgb_path = os.path.join(args.data_path, lines[s].split()[0])
164 | image = cv2.imread(rgb_path)
165 | if args.dataset == 'nyu':
166 | gt_path = os.path.join(args.data_path, lines[s].split()[1])
167 | gt = cv2.imread(gt_path, -1).astype(np.float32) / 1000.0 # Visualization purpose only
168 | gt[gt == 0] = np.amax(gt)
169 |
170 | pred_depth = pred_depths[s]
171 | pred_8x8 = pred_8x8s[s]
172 | pred_4x4 = pred_4x4s[s]
173 | pred_2x2 = pred_2x2s[s]
174 |
175 | if args.dataset == 'kitti' or args.dataset == 'kitti_benchmark':
176 | pred_depth_scaled = pred_depth * 256.0
177 | else:
178 | pred_depth_scaled = pred_depth * 1000.0
179 |
180 | pred_depth_scaled = pred_depth_scaled.astype(np.uint16)
181 | cv2.imwrite(filename_pred_png, pred_depth_scaled, [cv2.IMWRITE_PNG_COMPRESSION, 0])
182 |
183 | if args.save_lpg:
184 | cv2.imwrite(filename_image_png, image[10:-1 - 9, 10:-1 - 9, :])
185 | if args.dataset == 'nyu':
186 | plt.imsave(filename_gt_png, np.log10(gt[10:-1 - 9, 10:-1 - 9]), cmap='Greys')
187 | pred_depth_cropped = pred_depth[10:-1 - 9, 10:-1 - 9]
188 | plt.imsave(filename_cmap_png, np.log10(pred_depth_cropped), cmap='Greys')
189 | pred_8x8_cropped = pred_8x8[10:-1 - 9, 10:-1 - 9]
190 | filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_8x8.png')
191 | plt.imsave(filename_lpg_cmap_png, np.log10(pred_8x8_cropped), cmap='Greys')
192 | pred_4x4_cropped = pred_4x4[10:-1 - 9, 10:-1 - 9]
193 | filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_4x4.png')
194 | plt.imsave(filename_lpg_cmap_png, np.log10(pred_4x4_cropped), cmap='Greys')
195 | pred_2x2_cropped = pred_2x2[10:-1 - 9, 10:-1 - 9]
196 | filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_2x2.png')
197 | plt.imsave(filename_lpg_cmap_png, np.log10(pred_2x2_cropped), cmap='Greys')
198 | else:
199 | plt.imsave(filename_cmap_png, np.log10(pred_depth), cmap='Greys')
200 | filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_8x8.png')
201 | plt.imsave(filename_lpg_cmap_png, np.log10(pred_8x8), cmap='Greys')
202 | filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_4x4.png')
203 | plt.imsave(filename_lpg_cmap_png, np.log10(pred_4x4), cmap='Greys')
204 | filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_2x2.png')
205 | plt.imsave(filename_lpg_cmap_png, np.log10(pred_2x2), cmap='Greys')
206 |
207 | return
208 |
209 |
210 | def main(_):
211 |
212 | params = bts_parameters(
213 | encoder=args.encoder,
214 | height=args.input_height,
215 | width=args.input_width,
216 | batch_size=None,
217 | dataset=args.dataset,
218 | max_depth=args.max_depth,
219 | num_gpus=None,
220 | num_threads=None,
221 | num_epochs=None)
222 |
223 | test(params)
224 |
225 |
226 | if __name__ == '__main__':
227 | tf.app.run()
228 |
229 |
230 |
231 |
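test() above writes predictions as 16-bit PNGs scaled by 256 for KITTI and by 1000 for NYU. A small sketch of reading such a file back into meters, mirroring the scaling used in utils/eval_with_pngs.py (the file name below is hypothetical):

```python
import cv2
import numpy as np

def load_pred_depth(png_path, dataset='nyu'):
    """Undo the uint16 scaling applied when the prediction was saved."""
    depth_png = cv2.imread(png_path, -1)  # -1 keeps the 16-bit values unchanged
    depth = depth_png.astype(np.float32)
    if dataset in ('kitti', 'kitti_benchmark'):
        return depth / 256.0
    return depth / 1000.0

# Example (hypothetical file name):
# depth_m = load_pred_depth('result_bts_nyu_v2/raw/scene_rgb_00045.png', dataset='nyu')
```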
--------------------------------------------------------------------------------
/tensorflow/custom_layer/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2019 Jin Han Lee
2 | #
3 | # This file is a part of BTS.
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU General Public License
15 | # along with this program. If not, see
16 |
17 | cmake_minimum_required(VERSION 2.8)
18 |
19 | # get tensorflow include dirs, see https://www.tensorflow.org/how_tos/adding_an_op/
20 | execute_process(COMMAND python -c "import tensorflow; print(tensorflow.sysconfig.get_include())" OUTPUT_VARIABLE TF_INC)
21 | execute_process(COMMAND python -c "import tensorflow; print(tensorflow.sysconfig.get_lib())" OUTPUT_VARIABLE TF_LIB)
22 | execute_process(COMMAND python -c "import tensorflow; print(tensorflow.sysconfig.get_compile_flags())" OUTPUT_VARIABLE TF_COMPILE_FLAGS)
23 | execute_process(COMMAND python -c "import tensorflow; print(tensorflow.sysconfig.get_link_flags())" OUTPUT_VARIABLE TF_LINK_FLAGS)
24 |
25 | if(WIN32)
26 | string(REPLACE "\r\n" "" TF_INC "${TF_INC}")
27 | string(REPLACE "\r\n" "" TF_LIB "${TF_LIB}")
28 | string(REPLACE "\r\n" "" TF_COMPILE_FLAGS "${TF_COMPILE_FLAGS}")
29 | string(REPLACE "\r\n" "" TF_LINK_FLAGS "${TF_LINK_FLAGS}")
30 | else(WIN32)
31 | string(REPLACE "\n" "" TF_INC "${TF_INC}")
32 | string(REPLACE "\n" "" TF_LIB "${TF_LIB}")
33 | string(REPLACE "\n" "" TF_COMPILE_FLAGS "${TF_COMPILE_FLAGS}")
34 | string(REPLACE "\n" "" TF_LINK_FLAGS "${TF_LINK_FLAGS}")
35 | endif(WIN32)
36 |
37 | string(REPLACE "[" "" TF_COMPILE_FLAGS "${TF_COMPILE_FLAGS}")
38 | string(REPLACE "]" "" TF_COMPILE_FLAGS "${TF_COMPILE_FLAGS}")
39 | string(REPLACE "'" "" TF_COMPILE_FLAGS "${TF_COMPILE_FLAGS}")
40 | string(REPLACE "," "" TF_COMPILE_FLAGS "${TF_COMPILE_FLAGS}")
41 |
42 | string(REPLACE "[" "" TF_LINK_FLAGS "${TF_LINK_FLAGS}")
43 | string(REPLACE "]" "" TF_LINK_FLAGS "${TF_LINK_FLAGS}")
44 | string(REPLACE "'" "" TF_LINK_FLAGS "${TF_LINK_FLAGS}")
45 | string(REPLACE "," "" TF_LINK_FLAGS "${TF_LINK_FLAGS}")
46 |
47 | message("TF_INC: ${TF_INC}")
48 | message("TF_LIB: ${TF_LIB}")
49 | message("TF_COMPILE_FLAGS: ${TF_COMPILE_FLAGS}")
50 | message("TF_LINK_FLAGS: ${TF_LINK_FLAGS}")
51 |
52 | find_package(CUDA)
53 |
54 | set(TF_NSYNC_INC "${TF_INC}/external/nsync/public/")
55 | message(${TF_NSYNC_INC})
56 | include_directories(${TF_NSYNC_INC})
57 |
58 | # C++11 required for tensorflow
59 | set(CMAKE_CXX_FLAGS "-std=c++11 -fPIC ${CMAKE_CXX_FLAGS} ${TF_COMPILE_FLAGS} -DGOOGLE_CUDA")
60 | message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
61 |
62 | #pass flags to c++ compiler
63 | set(CUDA_PROPAGATE_HOST_FLAGS OFF)
64 | list(APPEND CUDA_NVCC_FLAGS "-std=c++11 -DNDEBUG")
65 |
66 | # build the actual op library so it can be used directly
67 | include_directories(${TF_INC})
68 | link_directories(${TF_LIB})
69 | include_directories("/usr/local/")
70 |
71 | cuda_add_library(lpg SHARED local_planar_guidance.cu local_planar_guidance.cc
72 | OPTIONS -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC --expt-relaxed-constexpr)
73 | target_link_libraries(lpg ${TF_LINK_FLAGS})
74 |
--------------------------------------------------------------------------------
/tensorflow/custom_layer/_local_planar_guidance_grad.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2019 Jin Han Lee
2 | #
3 | # This file is a part of BTS.
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU General Public License
15 | # along with this program. If not, see
16 |
17 | import tensorflow as tf
18 | from tensorflow.python.framework import ops
19 |
20 | lpg = tf.load_op_library('custom_layer/build/liblpg.so')
21 |
22 | @ops.RegisterGradient("LocalPlanarGuidance")
23 | def _local_planar_guidance_grad_cc(op, depth_grad):
24 | """
25 | The gradient for `local_planar_guidance` using the operation implemented in C++.
26 |
27 | :param op: `local_planar_guidance` `Operation` that we are differentiating, which we can use
28 | to find the inputs and outputs of the original op.
29 | :param grad: gradient with respect to the output of the `local_planar_guidance` op.
30 | :return: gradients with respect to the input of `local_planar_guidance`.
31 | """
32 |
33 | return lpg.local_planar_guidance_grad(depth_grad, op.inputs[0], op.inputs[1])
34 |
--------------------------------------------------------------------------------
/tensorflow/custom_layer/local_planar_guidance.cu:
--------------------------------------------------------------------------------
1 | /**********************************************************************
2 | Copyright (C) 2019 Jin Han Lee
3 |
4 | This file is a part of BTS.
5 | This program is free software: you can redistribute it and/or modify
6 | it under the terms of the GNU General Public License as published by
7 | the Free Software Foundation, either version 3 of the License, or
8 | (at your option) any later version.
9 |
10 | This program is distributed in the hope that it will be useful,
11 | but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | GNU General Public License for more details.
14 |
15 | You should have received a copy of the GNU General Public License
16 | along with this program. If not, see
17 | ***********************************************************************/
18 |
19 | #ifdef GOOGLE_CUDA
20 | #define EIGEN_USE_GPU
21 | #include "local_planar_guidance.h"
22 | #include "tensorflow/core/util/cuda_kernel_helper.h" // tf <= 1.13.2
23 | // #include "tensorflow/core/util/gpu_kernel_helper.h" // tf >= 1.14.0
24 |
25 | using namespace tensorflow;
26 |
27 | using GPUDevice = Eigen::GpuDevice;
28 |
29 | template struct functor::LocalPlanarGuidanceKernel<GPUDevice>;
30 |
31 | template struct functor::LocalPlanarGuidanceGradKernel<GPUDevice>;
32 |
33 | __global__ void LocalPlanarGuidanceFunctor(const int nthreads,
34 | const int input_height,
35 | const int input_width,
36 | const int depth_height,
37 | const int depth_width,
38 | const float* input,
39 | const float* focal,
40 | float* depth)
41 | {
42 | CUDA_1D_KERNEL_LOOP(index, nthreads)
43 | {
44 | const int num_threads_row = depth_height / input_height;
45 | const int num_threads_col = depth_width / input_width;
46 |
47 | int batch = index;
48 | const int col = (batch % depth_width);
49 | batch /= depth_width;
50 | const int row = (batch % depth_height);
51 | batch /= depth_height;
52 |
53 | const int input_row = row / num_threads_row;
54 | const int input_col = col / num_threads_col;
55 |
56 | float fo = focal[batch];
57 |
58 | float v = ((float)(row % num_threads_row) - (float)(num_threads_row - 1.0f) / 2.0f) / (float)num_threads_row;
59 | float u = ((float)(col % num_threads_col) - (float)(num_threads_col - 1.0f) / 2.0f) / (float)num_threads_col;
60 |
61 | unsigned int input_index = batch*input_height*input_width*4 + input_row*input_width*4 + input_col*4;
62 |
63 | float n1 = input[input_index+0];
64 | float n2 = input[input_index+1];
65 | float n3 = input[input_index+2];
66 | float n4 = input[input_index+3];
67 |
68 | float numerator = n4;
69 | float denominator = (n1*u + n2*v + n3);
70 | depth[index] = numerator / denominator;
71 | }
72 | }
73 |
74 | namespace functor {
75 | template <>
76 | void LocalPlanarGuidanceKernel<GPUDevice>::operator()(const GPUDevice& d,
77 | const int batch_size,
78 | const int input_height,
79 | const int input_width,
80 | const int depth_height,
81 | const int depth_width,
82 | const float* input,
83 | const float* focal,
84 | float* depth)
85 | {
86 | const int kThreadsPerBlock = 1024;
87 | const int output_size = batch_size*depth_height*depth_width;
88 | LocalPlanarGuidanceFunctor
89 | <<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, d.stream()>>>
90 | (output_size, input_height, input_width, depth_height, depth_width, input, focal, depth);
91 | d.synchronize();
92 | }
93 | }
94 |
95 | __global__ void LocalPlanarGuidanceGradFunctor(const int nthreads,
96 | const int input_height,
97 | const int input_width,
98 | const int depth_height,
99 | const int depth_width,
100 | const float* depth_grad,
101 | const float* input,
102 | const float* focal,
103 | float* grad_input)
104 | {
105 | CUDA_1D_KERNEL_LOOP(index, nthreads)
106 | {
107 | unsigned int num_threads_row = depth_height / input_height;
108 | unsigned int num_threads_col = depth_width / input_width;
109 |
110 | int batch = index;
111 | const int input_col = (batch % input_width);
112 | batch /= input_width;
113 | const int input_row = (batch % input_height);
114 | batch /= input_height;
115 |
116 | grad_input[index * 4 + 0] = 0.0f;
117 | grad_input[index * 4 + 1] = 0.0f;
118 | grad_input[index * 4 + 2] = 0.0f;
119 | grad_input[index * 4 + 3] = 0.0f;
120 |
121 | float n1 = input[index * 4 + 0];
122 | float n2 = input[index * 4 + 1];
123 | float n3 = input[index * 4 + 2];
124 | float n4 = input[index * 4 + 3];
125 |
126 | float fo = focal[batch];
127 |
128 | for(unsigned int r = 0; r < num_threads_row; ++r)
129 | {
130 | for(unsigned int c = 0; c < num_threads_col; ++c)
131 | {
132 | unsigned int col = input_col * num_threads_col + c;
133 | unsigned int row = input_row * num_threads_row + r;
134 |
135 | float v = ((float)(row % num_threads_row) - (float)(num_threads_row - 1.0f) / 2.0f) / (float)num_threads_row;
136 | float u = ((float)(col % num_threads_col) - (float)(num_threads_col - 1.0f) / 2.0f) / (float)num_threads_col;
137 |
138 | unsigned int depth_index = batch*depth_height*depth_width + row*depth_width + col;
139 |
140 | float denominator = n1*u + n2*v + n3;
141 | float denominator_sq = denominator*denominator;
142 |
143 | grad_input[index * 4 + 0] += depth_grad[depth_index] * (-1.0f * u) / denominator_sq;
144 | grad_input[index * 4 + 1] += depth_grad[depth_index] * (-1.0f * v) / denominator_sq;
145 | grad_input[index * 4 + 2] += depth_grad[depth_index] * (-1.0f) / denominator_sq;
146 | grad_input[index * 4 + 3] += depth_grad[depth_index] / denominator;
147 | }
148 | }
149 | }
150 | }
151 |
152 | namespace functor {
153 | template <>
154 | void LocalPlanarGuidanceGradKernel<GPUDevice>::operator()(const GPUDevice& d,
155 | const int batch_size,
156 | const int input_height,
157 | const int input_width,
158 | const int depth_height,
159 | const int depth_width,
160 | const float* depth_grad,
161 | const float* input,
162 | const float* focal,
163 | float* depth)
164 | {
165 | const int kThreadsPerBlock = 1024;
166 | const int output_size = batch_size*input_height*input_width;
167 | LocalPlanarGuidanceGradFunctor
168 | <<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, d.stream()>>>
169 | (output_size, input_height, input_width, depth_height, depth_width, depth_grad, input, focal, depth);
170 | d.synchronize();
171 | }
172 | }
173 |
174 | #endif // GOOGLE_CUDA
175 |
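For reference, a pure-NumPy sketch of what LocalPlanarGuidanceFunctor above computes for a single image: each cell of the coarse 4-channel plane map (n1..n4) is expanded by the up-ratio k, with (u, v) the normalized offsets inside the cell and depth = n4 / (n1*u + n2*v + n3). This is only an illustration of the math, not a drop-in replacement for the op (note the kernel also receives a focal input that this formula does not use):

```python
import numpy as np

def lpg_forward(plane, out_h, out_w):
    """NumPy sketch of the forward local planar guidance computation.

    plane: (in_h, in_w, 4) array of plane coefficients n1..n4 per coarse cell.
    Returns an (out_h, out_w) depth map with depth = n4 / (n1*u + n2*v + n3).
    """
    in_h, in_w, _ = plane.shape
    k_r, k_c = out_h // in_h, out_w // in_w              # up-ratio per axis
    depth = np.empty((out_h, out_w), dtype=np.float32)
    for row in range(out_h):
        for col in range(out_w):
            n1, n2, n3, n4 = plane[row // k_r, col // k_c]
            # normalized offsets inside the k x k cell, centered on zero
            v = ((row % k_r) - (k_r - 1.0) / 2.0) / k_r
            u = ((col % k_c) - (k_c - 1.0) / 2.0) / k_c
            depth[row, col] = n4 / (n1 * u + n2 * v + n3)
    return depth

# e.g. expanding a 60x80 plane map by 8x to 480x640 (random coefficients, illustration only):
# depth = lpg_forward(np.random.rand(60, 80, 4).astype(np.float32) + 0.5, 480, 640)
```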
--------------------------------------------------------------------------------
/tensorflow/custom_layer/local_planar_guidance.h:
--------------------------------------------------------------------------------
1 | /**********************************************************************
2 | Copyright (C) 2019 Jin Han Lee
3 |
4 | This file is a part of BTS.
5 | This program is free software: you can redistribute it and/or modify
6 | it under the terms of the GNU General Public License as published by
7 | the Free Software Foundation, either version 3 of the License, or
8 | (at your option) any later version.
9 |
10 | This program is distributed in the hope that it will be useful,
11 | but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | GNU General Public License for more details.
14 |
15 | You should have received a copy of the GNU General Public License
16 | along with this program. If not, see
17 | ***********************************************************************/
18 |
19 | #ifndef COMPUTE_DEPTH_H_
20 | #define COMPUTE_DEPTH_H_
21 |
22 | template <typename Device>
23 | struct LocalPlanarGuidanceKernel
24 | {
25 | void operator()(const Device& d,
26 | const int batch_size,
27 | const int input_height,
28 | const int input_width,
29 | const int depth_height,
30 | const int depth_width,
31 | const float* input,
32 | const float* focal,
33 | float* depth);
34 | };
35 |
36 | template <typename Device>
37 | struct LocalPlanarGuidanceGradKernel
38 | {
39 | void operator()(const Device& d,
40 | const int batch_size,
41 | const int input_height,
42 | const int input_width,
43 | const int depth_height,
44 | const int depth_width,
45 | const float* depth_grad,
46 | const float* input,
47 | const float* focal,
48 | float* grad_input);
49 | };
50 |
51 | #if GOOGLE_CUDA
52 | namespace functor { // Trick for GPU implementation forward declaration
53 | template <typename Device>
54 | struct LocalPlanarGuidanceKernel
55 | {
56 | void operator()(const Device& d,
57 | const int batch_size,
58 | const int input_height,
59 | const int input_width,
60 | const int depth_height,
61 | const int depth_width,
62 | const float* input,
63 | const float* focal,
64 | float* depth);
65 | };
66 |
67 | template <typename Device>
68 | struct LocalPlanarGuidanceGradKernel
69 | {
70 | void operator()(const Device& d,
71 | const int batch_size,
72 | const int input_height,
73 | const int input_width,
74 | const int depth_height,
75 | const int depth_width,
76 | const float* depth_grad,
77 | const float* input,
78 | const float* focal,
79 | float* grad_input);
80 | };
81 | }
82 | #endif
83 |
84 | #endif // COMPUTE_DEPTH_H_
--------------------------------------------------------------------------------
/tensorflow/notebooks/example_nyu_v2.py.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stdout",
10 | "output_type": "stream",
11 | "text": [
12 | "Wed Sep 4 03:46:01 2019 \r\n",
13 | "+-----------------------------------------------------------------------------+\r\n",
14 | "| NVIDIA-SMI 418.56 Driver Version: 418.56 CUDA Version: 10.1 |\r\n",
15 | "|-------------------------------+----------------------+----------------------+\r\n",
16 | "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\r\n",
17 | "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\r\n",
18 | "|===============================+======================+======================|\r\n",
19 | "| 0 GeForce GTX TIT... Off | 00000000:01:00.0 On | N/A |\r\n",
20 | "| 22% 37C P8 15W / 250W | 413MiB / 12211MiB | 0% Default |\r\n",
21 | "+-------------------------------+----------------------+----------------------+\r\n",
22 | " \r\n",
23 | "+-----------------------------------------------------------------------------+\r\n",
24 | "| Processes: GPU Memory |\r\n",
25 | "| GPU PID Type Process name Usage |\r\n",
26 | "|=============================================================================|\r\n",
27 | "+-----------------------------------------------------------------------------+\r\n"
28 | ]
29 | }
30 | ],
31 | "source": [
32 | "# Check that the GPU exists.\n",
33 | "! nvidia-smi"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": 2,
39 | "metadata": {},
40 | "outputs": [
41 | {
42 | "name": "stdout",
43 | "output_type": "stream",
44 | "text": [
45 | "Check: is GPU Available?\n",
46 | "True\n",
47 | "Note that this will still run if it's not, but it will run a lot slower.\n"
48 | ]
49 | }
50 | ],
51 | "source": [
52 | "import tensorflow as tf\n",
53 | "print(\"Check: is GPU Available?\")\n",
54 | "print(tf.test.is_gpu_available())\n",
55 | "print(\"Note that this will still run if it's not, but it will run a lot slower.\")"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": 4,
61 | "metadata": {},
62 | "outputs": [
63 | {
64 | "name": "stdout",
65 | "output_type": "stream",
66 | "text": [
67 | "WARNING:tensorflow:From /bts/bts_dataloader.py:68: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
68 | "Instructions for updating:\n",
69 | "Use tf.cast instead.\n",
70 | "WARNING:tensorflow:From /usr/local/lib/python2.7/dist-packages/tensorflow/python/data/ops/dataset_ops.py:1419: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
71 | "Instructions for updating:\n",
72 | "Colocations handled automatically by placer.\n",
73 | "==================================\n",
74 | " upconv5 in/out: 2208 / 512\n",
75 | " iconv5 in/out: 896 / 512\n",
76 | " upconv4 in/out: 512 / 256\n",
77 | " iconv4 in/out: 448 / 256\n",
78 | " aspp in/out: 896 / 128\n",
79 | "reduc8x8 in/out: 128 / 4\n",
80 | " lpg8x8 in/out: 4 / 1\n",
81 | " upconv3 in/out: 128 / 128\n",
82 | " iconv3 in/out: 225 / 128\n",
83 | "reduc4x4 in/out: 128 / 4\n",
84 | " lpg4x4 in/out: 4 / 1\n",
85 | " upconv2 in/out: 128 / 64\n",
86 | " iconv2 in/out: 161 / 64\n",
87 | "reduc2x2 in/out: 64 / 4\n",
88 | " lpg2x2 in/out: 4 / 1\n",
89 | " upconv1 in/out: 64 / 32\n",
90 | " iconv1 in/out: 35 / 32\n",
91 | " depth in/out: 32 / 1\n",
92 | "==================================\n",
93 | "WARNING:tensorflow:From bts_test.py:97: start_queue_runners (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.\n",
94 | "Instructions for updating:\n",
95 | "To construct input pipelines, use the `tf.data` module.\n",
96 | "WARNING:tensorflow:`tf.train.start_queue_runners()` was called when no queue runners were defined. You can safely remove the call to this deprecated function.\n",
97 | "WARNING:tensorflow:From /usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py:1266: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.\n",
98 | "Instructions for updating:\n",
99 | "Use standard file APIs to check for files with this prefix.\n",
100 | "Now testing 654 files with ./models/bts_nyu/model\n",
101 | "Processing images..\n",
102 | "100%|██████████| 654/654 [01:05<00:00, 9.94it/s]\n",
103 | "Done.\n",
104 | "Saving result pngs..\n",
105 | " 38%|███▊ | 251/654 [00:56<01:30, 4.43it/s]bts_test.py:189: RuntimeWarning: divide by zero encountered in log10\n",
106 | " plt.imsave(filename_lpg_cmap_png, np.log10(pred_4x4_cropped), cmap='Greys')\n",
107 | "100%|██████████| 654/654 [02:25<00:00, 4.51it/s]\n"
108 | ]
109 | }
110 | ],
111 | "source": [
112 | "! cd /bts && python bts_test.py \\\n",
113 | "--encoder densenet161_bts \\\n",
114 | "--data_path /data/nyu_depth_v2/official_splits/test/ \\\n",
115 | "--dataset nyu \\\n",
116 | "--filenames_file ./train_test_inputs/nyudepthv2_test_files_with_gt.txt \\\n",
117 | "--model_name bts_nyu \\\n",
118 | "--checkpoint_path ./models/bts_nyu/model \\\n",
119 | "--input_height 480 \\\n",
120 | "--input_width 640 \\\n",
121 | "--max_depth 10"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": null,
127 | "metadata": {},
128 | "outputs": [],
129 | "source": []
130 | }
131 | ],
132 | "metadata": {
133 | "kernelspec": {
134 | "display_name": "Python 2",
135 | "language": "python",
136 | "name": "python2"
137 | },
138 | "language_info": {
139 | "codemirror_mode": {
140 | "name": "ipython",
141 | "version": 2
142 | },
143 | "file_extension": ".py",
144 | "mimetype": "text/x-python",
145 | "name": "python",
146 | "nbconvert_exporter": "python",
147 | "pygments_lexer": "ipython2",
148 | "version": "2.7.15+"
149 | }
150 | },
151 | "nbformat": 4,
152 | "nbformat_minor": 2
153 | }
154 |
--------------------------------------------------------------------------------
/tensorflow/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv-python
2 | scipy
3 | tqdm
4 | requests
5 |
--------------------------------------------------------------------------------
/tensorflow/run_bts_eval_schedule.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2019 Jin Han Lee
2 | #
3 | # This file is a part of BTS.
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU General Public License
15 | # along with this program. If not, see
16 |
17 | import os
18 | import datetime
19 | from apscheduler.schedulers.blocking import BlockingScheduler
20 | scheduler = BlockingScheduler()
21 |
22 | @scheduler.scheduled_job('interval', minutes=1, start_date=datetime.datetime.now() + datetime.timedelta(0,3))
23 | def run_eval():
24 | command = 'export CUDA_VISIBLE_DEVICES=0; ' \
25 | '/usr/bin/python ' \
26 | 'bts_eval.py ' \
27 | '--encoder densenet161_bts ' \
28 | '--dataset kitti ' \
29 | '--data_path ../../dataset/kitti_dataset/ ' \
30 | '--gt_path ../../dataset/kitti_dataset/data_depth_annotated/ ' \
31 | '--filenames_file ../train_test_inputs/eigen_test_files_with_gt.txt ' \
32 | '--input_height 352 ' \
33 | '--input_width 1216 ' \
34 | '--garg_crop ' \
35 | '--max_depth 80 ' \
36 | '--max_depth_eval 80 ' \
37 | '--output_directory ./models/eval-eigen/ ' \
38 | '--model_name bts_eigen_v0_0_1 ' \
39 | '--checkpoint_path ./models/bts_eigen_v0_0_1/ ' \
40 | '--do_kb_crop '
41 |
42 | print('Executing: %s' % command)
43 | os.system(command)
44 | print('Finished: %s' % datetime.datetime.now())
45 |
46 | scheduler.configure()
47 | scheduler.start()
--------------------------------------------------------------------------------
/utils/download_from_gdrive.py:
--------------------------------------------------------------------------------
1 | # Source: https://stackoverflow.com/a/39225039
2 |
3 | import requests
4 |
5 |
6 | def download_file_from_google_drive(id, destination):
7 | def get_confirm_token(response):
8 | for key, value in response.cookies.items():
9 | if key.startswith('download_warning'):
10 | return value
11 |
12 | return None
13 |
14 | def save_response_content(response, destination):
15 | CHUNK_SIZE = 32768
16 |
17 | with open(destination, "wb") as f:
18 | for chunk in response.iter_content(CHUNK_SIZE):
19 | if chunk: # filter out keep-alive new chunks
20 | f.write(chunk)
21 |
22 | URL = "https://docs.google.com/uc?export=download"
23 |
24 | session = requests.Session()
25 |
26 | response = session.get(URL, params = { 'id' : id }, stream = True)
27 | token = get_confirm_token(response)
28 |
29 | if token:
30 | params = { 'id' : id, 'confirm' : token }
31 | response = session.get(URL, params = params, stream = True)
32 |
33 | save_response_content(response, destination)
34 |
35 |
36 | if __name__ == "__main__":
37 | import sys
38 |     if len(sys.argv) != 3:
39 | print("Usage: python google_drive.py drive_file_id destination_file_path")
40 | else:
41 | # TAKE ID FROM SHAREABLE LINK
42 | file_id = sys.argv[1]
43 | # DESTINATION FILE ON YOUR DISK
44 | destination = sys.argv[2]
45 | download_file_from_google_drive(file_id, destination)
46 |
--------------------------------------------------------------------------------
/utils/eval_with_pngs.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2019 Jin Han Lee
2 | #
3 | # This file is a part of BTS.
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU General Public License
15 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 |
17 | from __future__ import absolute_import, division, print_function
18 |
19 | import os
20 | import argparse
21 | import fnmatch
22 | import cv2
23 | import numpy as np
24 |
25 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
26 |
27 |
28 | def convert_arg_line_to_args(arg_line):
29 | for arg in arg_line.split():
30 | if not arg.strip():
31 | continue
32 | yield arg
33 |
34 |
35 | parser = argparse.ArgumentParser(description='BTS TensorFlow implementation.', fromfile_prefix_chars='@')
36 | parser.convert_arg_line_to_args = convert_arg_line_to_args
37 |
38 | parser.add_argument('--pred_path', type=str, help='path to the prediction results in png', required=True)
39 | parser.add_argument('--gt_path', type=str, help='root path to the groundtruth data', required=False)
40 | parser.add_argument('--dataset', type=str, help='dataset to test on, nyu or kitti', default='nyu')
41 | parser.add_argument('--eigen_crop', help='if set, crops according to Eigen NIPS14', action='store_true')
42 | parser.add_argument('--garg_crop', help='if set, crops according to Garg ECCV16', action='store_true')
43 | parser.add_argument('--min_depth_eval', type=float, help='minimum depth for evaluation', default=1e-3)
44 | parser.add_argument('--max_depth_eval', type=float, help='maximum depth for evaluation', default=80)
45 | parser.add_argument('--do_kb_crop', help='if set, crop input images as kitti benchmark images', action='store_true')
46 |
47 | args = parser.parse_args()
48 |
49 |
50 | def compute_errors(gt, pred):
51 | thresh = np.maximum((gt / pred), (pred / gt))
52 | d1 = (thresh < 1.25).mean()
53 | d2 = (thresh < 1.25 ** 2).mean()
54 | d3 = (thresh < 1.25 ** 3).mean()
55 |
56 | rmse = (gt - pred) ** 2
57 | rmse = np.sqrt(rmse.mean())
58 |
59 | rmse_log = (np.log(gt) - np.log(pred)) ** 2
60 | rmse_log = np.sqrt(rmse_log.mean())
61 |
62 | abs_rel = np.mean(np.abs(gt - pred) / gt)
63 | sq_rel = np.mean(((gt - pred)**2) / gt)
64 |
65 | err = np.log(pred) - np.log(gt)
66 | silog = np.sqrt(np.mean(err ** 2) - np.mean(err) ** 2) * 100
67 |
68 | err = np.abs(np.log10(pred) - np.log10(gt))
69 | log10 = np.mean(err)
70 |
71 | return silog, log10, abs_rel, sq_rel, rmse, rmse_log, d1, d2, d3
72 |
73 |
74 | def test():
75 | global gt_depths, missing_ids, pred_filenames
76 | gt_depths = []
77 | missing_ids = set()
78 | pred_filenames = []
79 |
80 | for root, dirnames, filenames in os.walk(args.pred_path):
81 | for pred_filename in fnmatch.filter(filenames, '*.png'):
82 | if 'cmap' in pred_filename or 'gt' in pred_filename:
83 | continue
84 | dirname = root.replace(args.pred_path, '')
85 | pred_filenames.append(os.path.join(dirname, pred_filename))
86 |
87 | num_test_samples = len(pred_filenames)
88 |
89 | pred_depths = []
90 |
91 | for i in range(num_test_samples):
92 | pred_depth_path = os.path.join(args.pred_path, pred_filenames[i])
93 | pred_depth = cv2.imread(pred_depth_path, -1)
94 | if pred_depth is None:
95 | print('Missing: %s ' % pred_depth_path)
96 | missing_ids.add(i)
97 | continue
98 |
99 | if args.dataset == 'nyu':
100 | pred_depth = pred_depth.astype(np.float32) / 1000.0
101 | else:
102 | pred_depth = pred_depth.astype(np.float32) / 256.0
103 |
104 | pred_depths.append(pred_depth)
105 |
106 | print('Raw png files reading done')
107 | print('Evaluating {} files'.format(len(pred_depths)))
108 |
109 | if args.dataset == 'kitti':
110 | for t_id in range(num_test_samples):
111 | file_dir = pred_filenames[t_id].split('.')[0]
112 | filename = file_dir.split('_')[-1]
113 | directory = file_dir.replace('_' + filename, '')
114 | gt_depth_path = os.path.join(args.gt_path, directory, 'proj_depth/groundtruth/image_02', filename + '.png')
115 | depth = cv2.imread(gt_depth_path, -1)
116 | if depth is None:
117 | print('Missing: %s ' % gt_depth_path)
118 | missing_ids.add(t_id)
119 | continue
120 |
121 | depth = depth.astype(np.float32) / 256.0
122 | gt_depths.append(depth)
123 |
124 | elif args.dataset == 'nyu':
125 | for t_id in range(num_test_samples):
126 | file_dir = pred_filenames[t_id].split('.')[0]
127 | filename = file_dir.split('_')[-1]
128 | directory = file_dir.replace('_rgb_'+file_dir.split('_')[-1], '')
129 | gt_depth_path = os.path.join(args.gt_path, directory, 'sync_depth_' + filename + '.png')
130 | depth = cv2.imread(gt_depth_path, -1)
131 | if depth is None:
132 | print('Missing: %s ' % gt_depth_path)
133 | missing_ids.add(t_id)
134 | continue
135 |
136 | depth = depth.astype(np.float32) / 1000.0
137 | gt_depths.append(depth)
138 |
139 | print('GT files reading done')
140 | print('{} GT files missing'.format(len(missing_ids)))
141 |
142 | print('Computing errors')
143 | eval(pred_depths)
144 |
145 | print('Done.')
146 |
147 |
148 | def eval(pred_depths):
149 |
150 | num_samples = len(pred_depths)
151 | pred_depths_valid = []
152 |
153 | i = 0
154 | for t_id in range(num_samples):
155 | if t_id in missing_ids:
156 | continue
157 |
158 | pred_depths_valid.append(pred_depths[t_id])
159 |
160 | num_samples = num_samples - len(missing_ids)
161 |
162 | silog = np.zeros(num_samples, np.float32)
163 | log10 = np.zeros(num_samples, np.float32)
164 | rms = np.zeros(num_samples, np.float32)
165 | log_rms = np.zeros(num_samples, np.float32)
166 | abs_rel = np.zeros(num_samples, np.float32)
167 | sq_rel = np.zeros(num_samples, np.float32)
168 | d1 = np.zeros(num_samples, np.float32)
169 | d2 = np.zeros(num_samples, np.float32)
170 | d3 = np.zeros(num_samples, np.float32)
171 |
172 | for i in range(num_samples):
173 |
174 | gt_depth = gt_depths[i]
175 | pred_depth = pred_depths_valid[i]
176 |
177 | pred_depth[pred_depth < args.min_depth_eval] = args.min_depth_eval
178 | pred_depth[pred_depth > args.max_depth_eval] = args.max_depth_eval
179 | pred_depth[np.isinf(pred_depth)] = args.max_depth_eval
180 |
181 | gt_depth[np.isinf(gt_depth)] = 0
182 | gt_depth[np.isnan(gt_depth)] = 0
183 |
184 | valid_mask = np.logical_and(gt_depth > args.min_depth_eval, gt_depth < args.max_depth_eval)
185 |
186 | if args.do_kb_crop:
187 | height, width = gt_depth.shape
188 | top_margin = int(height - 352)
189 | left_margin = int((width - 1216) / 2)
190 | pred_depth_uncropped = np.zeros((height, width), dtype=np.float32)
191 | pred_depth_uncropped[top_margin:top_margin + 352, left_margin:left_margin + 1216] = pred_depth
192 | pred_depth = pred_depth_uncropped
193 |
194 | if args.garg_crop or args.eigen_crop:
195 | gt_height, gt_width = gt_depth.shape
196 | eval_mask = np.zeros(valid_mask.shape)
197 |
198 | if args.garg_crop:
199 | eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height), int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1
200 |
201 | elif args.eigen_crop:
202 | if args.dataset == 'kitti':
203 | eval_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height), int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = 1
204 | else:
205 | eval_mask[45:471, 41:601] = 1
206 |
207 | valid_mask = np.logical_and(valid_mask, eval_mask)
208 |
209 | silog[i], log10[i], abs_rel[i], sq_rel[i], rms[i], log_rms[i], d1[i], d2[i], d3[i] = compute_errors(gt_depth[valid_mask], pred_depth[valid_mask])
210 |
211 | print("{:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}".format(
212 | 'd1', 'd2', 'd3', 'AbsRel', 'SqRel', 'RMSE', 'RMSElog', 'SILog', 'log10'))
213 | print("{:7.3f}, {:7.3f}, {:7.3f}, {:7.3f}, {:7.3f}, {:7.3f}, {:7.3f}, {:7.3f}, {:7.3f}".format(
214 | d1.mean(), d2.mean(), d3.mean(),
215 | abs_rel.mean(), sq_rel.mean(), rms.mean(), log_rms.mean(), silog.mean(), log10.mean()))
216 |
217 | return silog, log10, abs_rel, sq_rel, rms, log_rms, d1, d2, d3
218 |
219 |
220 | def main():
221 | test()
222 |
223 |
224 | if __name__ == '__main__':
225 | main()
226 |
227 |
228 |
229 |
--------------------------------------------------------------------------------
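Note: compute_errors in eval_with_pngs.py returns the standard monocular-depth metrics (threshold accuracies d1/d2/d3, AbsRel, SqRel, RMSE, RMSE log, SILog, log10). A self-contained sketch with assumed toy values, mirroring the formulas in the file, illustrates what a few of these numbers mean.

import numpy as np

# Toy ground-truth and predicted depths in metres (illustrative values only).
gt = np.array([1.0, 2.0, 4.0, 8.0], dtype=np.float32)
pred = np.array([1.1, 1.8, 4.4, 7.5], dtype=np.float32)

thresh = np.maximum(gt / pred, pred / gt)
d1 = (thresh < 1.25).mean()                        # fraction of pixels within 25% of GT
abs_rel = np.mean(np.abs(gt - pred) / gt)          # mean absolute relative error
rmse = np.sqrt(np.mean((gt - pred) ** 2))          # root-mean-square error in metres
err = np.log(pred) - np.log(gt)
silog = np.sqrt(np.mean(err ** 2) - np.mean(err) ** 2) * 100  # scale-invariant log error
print(d1, abs_rel, rmse, silog)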
/utils/extract_official_train_test_set_from_mat.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | #######################################################################################
4 | # The MIT License
5 |
6 | # Copyright (c) 2014 Hannes Schulz, University of Bonn
7 | # Copyright (c) 2013 Benedikt Waldvogel, University of Bonn
8 | # Copyright (c) 2008-2009 Sebastian Nowozin
9 |
10 | # Permission is hereby granted, free of charge, to any person obtaining a copy
11 | # of this software and associated documentation files (the "Software"), to deal
12 | # in the Software without restriction, including without limitation the rights
13 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
14 | # copies of the Software, and to permit persons to whom the Software is
15 | # furnished to do so, subject to the following conditions:
16 | #
17 | # The above copyright notice and this permission notice shall be included in all
18 | # copies or substantial portions of the Software.
19 | #
20 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
21 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
23 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
24 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
25 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
26 | # SOFTWARE.
27 | #######################################################################################
28 | #
29 | # Helper script to convert the NYU Depth v2 dataset Matlab file into a set of
30 | # PNG and JPEG images.
31 | #
32 | # See https://github.com/deeplearningais/curfil/wiki/Training-and-Prediction-with-the-NYU-Depth-v2-Dataset
33 |
34 | from __future__ import print_function
35 |
36 | import h5py
37 | import numpy as np
38 | import os
39 | import scipy.io
40 | import sys
41 | import cv2
42 |
43 |
44 | def convert_image(i, scene, depth_raw, image):
45 |
46 | idx = int(i) + 1
47 | if idx in train_images:
48 | train_test = "train"
49 | else:
50 | assert idx in test_images, "index %d neither found in training set nor in test set" % idx
51 | train_test = "test"
52 |
53 | folder = "%s/%s/%s" % (out_folder, train_test, scene)
54 | if not os.path.exists(folder):
55 | os.makedirs(folder)
56 |
57 | img_depth = depth_raw * 1000.0
58 | img_depth_uint16 = img_depth.astype(np.uint16)
59 | cv2.imwrite("%s/sync_depth_%05d.png" % (folder, i), img_depth_uint16)
60 | image = image[:, :, ::-1]
61 | image_black_boundary = np.zeros((480, 640, 3), dtype=np.uint8)
62 | image_black_boundary[7:474, 7:632, :] = image[7:474, 7:632, :]
63 | cv2.imwrite("%s/rgb_%05d.jpg" % (folder, i), image_black_boundary)
64 |
65 |
66 | if __name__ == "__main__":
67 |
68 | if len(sys.argv) < 4:
69 |         print("usage: %s <h5_file> <train_test_split> <out_folder>" % sys.argv[0], file=sys.stderr)
70 | sys.exit(0)
71 |
72 | h5_file = h5py.File(sys.argv[1], "r")
73 |     # h5py cannot open the splits .mat file, but scipy.io can
74 | train_test = scipy.io.loadmat(sys.argv[2])
75 | out_folder = sys.argv[3]
76 |
77 | test_images = set([int(x) for x in train_test["testNdxs"]])
78 | train_images = set([int(x) for x in train_test["trainNdxs"]])
79 | print("%d training images" % len(train_images))
80 | print("%d test images" % len(test_images))
81 |
82 | depth_raw = h5_file['rawDepths']
83 |
84 | print("reading", sys.argv[1])
85 |
86 | images = h5_file['images']
87 | scenes = [u''.join(chr(c) for c in h5_file[obj_ref]) for obj_ref in h5_file['sceneTypes'][0]]
88 |
89 | print("processing images")
90 | for i, image in enumerate(images):
91 | print("image", i + 1, "/", len(images))
92 | convert_image(i, scenes[i], depth_raw[i, :, :].T, image.T)
93 |
94 | print("Finished")
--------------------------------------------------------------------------------
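Note: the extractor above writes depth as 16-bit PNGs in millimetres (depth_raw * 1000 cast to uint16) alongside JPEG RGB frames, split into train/ and test/ scene folders. A minimal sketch of reading one of those PNGs back to metric depth; the path is a hypothetical example of the layout the script produces, not a file shipped with the repository.

import cv2
import numpy as np

# Hypothetical output path: <out_folder>/<train|test>/<scene>/sync_depth_<idx>.png
depth_png = cv2.imread('official_splits/test/kitchen_0001/sync_depth_00000.png', -1)
depth_m = depth_png.astype(np.float32) / 1000.0   # undo the millimetre scaling used on write
print(depth_m.min(), depth_m.max())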
/utils/kitti_archives_to_download.txt:
--------------------------------------------------------------------------------
1 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_calib.zip
2 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0001/2011_09_26_drive_0001_sync.zip
3 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0002/2011_09_26_drive_0002_sync.zip
4 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0005/2011_09_26_drive_0005_sync.zip
5 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0009/2011_09_26_drive_0009_sync.zip
6 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0011/2011_09_26_drive_0011_sync.zip
7 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0013/2011_09_26_drive_0013_sync.zip
8 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0014/2011_09_26_drive_0014_sync.zip
9 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0015/2011_09_26_drive_0015_sync.zip
10 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0017/2011_09_26_drive_0017_sync.zip
11 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0018/2011_09_26_drive_0018_sync.zip
12 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0019/2011_09_26_drive_0019_sync.zip
13 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0020/2011_09_26_drive_0020_sync.zip
14 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0022/2011_09_26_drive_0022_sync.zip
15 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0023/2011_09_26_drive_0023_sync.zip
16 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0027/2011_09_26_drive_0027_sync.zip
17 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0028/2011_09_26_drive_0028_sync.zip
18 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0029/2011_09_26_drive_0029_sync.zip
19 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0032/2011_09_26_drive_0032_sync.zip
20 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0035/2011_09_26_drive_0035_sync.zip
21 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0036/2011_09_26_drive_0036_sync.zip
22 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0039/2011_09_26_drive_0039_sync.zip
23 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0046/2011_09_26_drive_0046_sync.zip
24 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0048/2011_09_26_drive_0048_sync.zip
25 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0051/2011_09_26_drive_0051_sync.zip
26 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0052/2011_09_26_drive_0052_sync.zip
27 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0056/2011_09_26_drive_0056_sync.zip
28 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0057/2011_09_26_drive_0057_sync.zip
29 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0059/2011_09_26_drive_0059_sync.zip
30 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0060/2011_09_26_drive_0060_sync.zip
31 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0061/2011_09_26_drive_0061_sync.zip
32 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0064/2011_09_26_drive_0064_sync.zip
33 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0070/2011_09_26_drive_0070_sync.zip
34 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0079/2011_09_26_drive_0079_sync.zip
35 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0084/2011_09_26_drive_0084_sync.zip
36 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0086/2011_09_26_drive_0086_sync.zip
37 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0087/2011_09_26_drive_0087_sync.zip
38 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0091/2011_09_26_drive_0091_sync.zip
39 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0093/2011_09_26_drive_0093_sync.zip
40 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0095/2011_09_26_drive_0095_sync.zip
41 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0096/2011_09_26_drive_0096_sync.zip
42 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0101/2011_09_26_drive_0101_sync.zip
43 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0104/2011_09_26_drive_0104_sync.zip
44 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0106/2011_09_26_drive_0106_sync.zip
45 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0113/2011_09_26_drive_0113_sync.zip
46 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0117/2011_09_26_drive_0117_sync.zip
47 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_28_calib.zip
48 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_28_drive_0001/2011_09_28_drive_0001_sync.zip
49 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_28_drive_0002/2011_09_28_drive_0002_sync.zip
50 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_29_calib.zip
51 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_29_drive_0004/2011_09_29_drive_0004_sync.zip
52 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_29_drive_0026/2011_09_29_drive_0026_sync.zip
53 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_29_drive_0071/2011_09_29_drive_0071_sync.zip
54 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_30_calib.zip
55 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_30_drive_0016/2011_09_30_drive_0016_sync.zip
56 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_30_drive_0018/2011_09_30_drive_0018_sync.zip
57 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_30_drive_0020/2011_09_30_drive_0020_sync.zip
58 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_30_drive_0027/2011_09_30_drive_0027_sync.zip
59 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_30_drive_0028/2011_09_30_drive_0028_sync.zip
60 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_30_drive_0033/2011_09_30_drive_0033_sync.zip
61 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_30_drive_0034/2011_09_30_drive_0034_sync.zip
62 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_10_03_calib.zip
63 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_10_03_drive_0027/2011_10_03_drive_0027_sync.zip
64 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_10_03_drive_0034/2011_10_03_drive_0034_sync.zip
65 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_10_03_drive_0042/2011_10_03_drive_0042_sync.zip
66 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_10_03_drive_0047/2011_10_03_drive_0047_sync.zip
67 |
--------------------------------------------------------------------------------
/utils/nyudepthv2_archives_to_download.txt:
--------------------------------------------------------------------------------
1 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/basements.zip
2 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/bathrooms_part1.zip
3 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/bathrooms_part2.zip
4 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/bathrooms_part3.zip
5 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/bathrooms_part4.zip
6 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/bedrooms_part1.zip
7 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/bedrooms_part2.zip
8 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/bedrooms_part3.zip
9 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/bedrooms_part4.zip
10 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/bedrooms_part5.zip
11 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/bedrooms_part6.zip
12 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/bedrooms_part7.zip
13 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/bookstore_part1.zip
14 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/bookstore_part2.zip
15 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/bookstore_part3.zip
16 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/cafe.zip
17 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/classrooms.zip
18 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/dining_rooms_part1.zip
19 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/dining_rooms_part2.zip
20 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/furniture_stores.zip
21 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/home_offices.zip
22 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/kitchens_part1.zip
23 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/kitchens_part2.zip
24 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/kitchens_part3.zip
25 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/libraries.zip
26 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/living_rooms_part1.zip
27 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/living_rooms_part2.zip
28 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/living_rooms_part3.zip
29 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/living_rooms_part4.zip
30 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/misc_part1.zip
31 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/misc_part2.zip
32 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/offices_part1.zip
33 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/offices_part2.zip
34 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/office_kitchens.zip
35 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/playrooms.zip
36 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/reception_rooms.zip
37 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/studies.zip
38 | http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/study_rooms.zip
--------------------------------------------------------------------------------
/utils/splits.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cleinc/bts/dd62221bc50ff3cbe4559a832c94776830247e2e/utils/splits.mat
--------------------------------------------------------------------------------
/utils/sync_project_frames_multi_threads.m:
--------------------------------------------------------------------------------
1 | % The directory where you extracted the raw dataset.
2 | datasetDir = '../../../dataset/nyu_depth_v2/raw';
3 | % The directory where you want save the synced images and depths.
4 | dstDir = '../../../dataset/nyu_depth_v2/sync';
5 |
6 | fid = fopen('train_scenes.txt', 'r');
7 | tline = fgetl(fid);
8 | sceneNames = cell(0, 1);
9 | while ischar(tline)
10 | sceneNames{end+1, 1} = tline;
11 | tline = fgetl(fid);
12 | end
13 | fclose(fid);
14 |
15 | num_threads = 6;
16 | sample_step = 7;
17 |
18 | parpool(num_threads);
19 |
20 | for aa = 1 : num_threads : numel(sceneNames)
21 | actual_num_threads = min(numel(sceneNames) - aa + 1, num_threads);
22 | sceneNames_batch = sceneNames(aa:aa+actual_num_threads-1);
23 | parfor i = 1:actual_num_threads
24 | sceneName = sceneNames_batch{i};
25 |         % The absolute directory of the current scene.
26 | sceneDir = sprintf('%s/%s', datasetDir, sceneName);
27 |
28 | % Reads the list of frames.
29 | frameList = get_synched_frames(sceneDir);
30 |         saveDir = sprintf('%s/%s', dstDir, sceneName);
31 | if ~exist(saveDir, 'dir')
32 | % Folder does not exist so create it.
33 | mkdir(saveDir);
34 | end
35 |
36 | ind = 0;
37 |
38 | % Displays each pair of synchronized RGB and Depth frames.
39 | for ii = 1 : sample_step : numel(frameList)
40 | imgRgb = imread([sceneDir '/' frameList(ii).rawRgbFilename]);
41 | if frameList(ii).rawDepthFilename == "d-1315166703.129542-2466101449.pgm" % Faulty image
42 | continue;
43 | end
44 | imgDepthRaw = swapbytes(imread([sceneDir '/' frameList(ii).rawDepthFilename]));
45 | [imgDepthProj, imgRgbUd] = project_depth_map(imgDepthRaw, imgRgb);
46 |
47 | rgb_dst = sprintf('%s/rgb_%05d.jpg', saveDir, ind);
48 | imwrite(imgRgbUd, rgb_dst);
49 |
50 | imgDepthProj = uint16(imgDepthProj * 1000.0);
51 | sync_depth_dst = sprintf('%s/sync_depth_%05d.png', saveDir, ind);
52 | imwrite(imgDepthProj, sync_depth_dst);
53 |
54 | ind = ind + 1;
55 | fprintf('%d/%d done\n', ii, numel(frameList));
56 | end
57 |         fprintf('%s done\n', sceneName);
58 | end
59 | end
--------------------------------------------------------------------------------
/utils/train_scenes.txt:
--------------------------------------------------------------------------------
1 | basement_0001a
2 | basement_0001b
3 | bathroom_0001
4 | bathroom_0002
5 | bathroom_0005
6 | bathroom_0006
7 | bathroom_0007
8 | bathroom_0010
9 | bathroom_0011
10 | bathroom_0013
11 | bathroom_0014a
12 | bathroom_0016
13 | bathroom_0019
14 | bathroom_0023
15 | bathroom_0024
16 | bathroom_0028
17 | bathroom_0030
18 | bathroom_0033
19 | bathroom_0034
20 | bathroom_0035
21 | bathroom_0039
22 | bathroom_0041
23 | bathroom_0042
24 | bathroom_0045a
25 | bathroom_0048
26 | bathroom_0049
27 | bathroom_0050
28 | bathroom_0051
29 | bathroom_0053
30 | bathroom_0054
31 | bathroom_0055
32 | bathroom_0056
33 | bathroom_0057
34 | bedroom_0004
35 | bedroom_0010
36 | bedroom_0012
37 | bedroom_0014
38 | bedroom_0015
39 | bedroom_0016
40 | bedroom_0017
41 | bedroom_0019
42 | bedroom_0020
43 | bedroom_0021
44 | bedroom_0025
45 | bedroom_0026
46 | bedroom_0028
47 | bedroom_0029
48 | bedroom_0031
49 | bedroom_0033
50 | bedroom_0034
51 | bedroom_0035
52 | bedroom_0036
53 | bedroom_0038
54 | bedroom_0039
55 | bedroom_0040
56 | bedroom_0041
57 | bedroom_0042
58 | bedroom_0045
59 | bedroom_0047
60 | bedroom_0050
61 | bedroom_0051
62 | bedroom_0052
63 | bedroom_0053
64 | bedroom_0056a
65 | bedroom_0056b
66 | bedroom_0057
67 | bedroom_0059
68 | bedroom_0060
69 | bedroom_0062
70 | bedroom_0063
71 | bedroom_0065
72 | bedroom_0066
73 | bedroom_0067a
74 | bedroom_0067b
75 | bedroom_0069
76 | bedroom_0071
77 | bedroom_0072
78 | bedroom_0074
79 | bedroom_0076a
80 | bedroom_0078
81 | bedroom_0079
82 | bedroom_0080
83 | bedroom_0081
84 | bedroom_0082
85 | bedroom_0086
86 | bedroom_0090
87 | bedroom_0094
88 | bedroom_0096
89 | bedroom_0097
90 | bedroom_0098
91 | bedroom_0100
92 | bedroom_0104
93 | bedroom_0106
94 | bedroom_0107
95 | bedroom_0113
96 | bedroom_0116
97 | bedroom_0118
98 | bedroom_0120
99 | bedroom_0124
100 | bedroom_0125a
101 | bedroom_0125b
102 | bedroom_0126
103 | bedroom_0129
104 | bedroom_0130
105 | bedroom_0132
106 | bedroom_0136
107 | bedroom_0138
108 | bedroom_0140
109 | bookstore_0001d
110 | bookstore_0001e
111 | bookstore_0001f
112 | bookstore_0001g
113 | bookstore_0001h
114 | bookstore_0001i
115 | bookstore_0001j
116 | cafe_0001a
117 | cafe_0001b
118 | cafe_0001c
119 | classroom_0003
120 | classroom_0004
121 | classroom_0005
122 | classroom_0006
123 | classroom_0010
124 | classroom_0011
125 | classroom_0012
126 | classroom_0016
127 | classroom_0018
128 | classroom_0022
129 | computer_lab_0002
130 | conference_room_0001
131 | conference_room_0002
132 | dinette_0001
133 | dining_room_0001b
134 | dining_room_0002
135 | dining_room_0004
136 | dining_room_0007
137 | dining_room_0008
138 | dining_room_0010
139 | dining_room_0012
140 | dining_room_0013
141 | dining_room_0014
142 | dining_room_0015
143 | dining_room_0016
144 | dining_room_0019
145 | dining_room_0023
146 | dining_room_0024
147 | dining_room_0028
148 | dining_room_0029
149 | dining_room_0031
150 | dining_room_0033
151 | dining_room_0034
152 | dining_room_0037
153 | excercise_room_0001
154 | foyer_0002
155 | furniture_store_0001a
156 | furniture_store_0001b
157 | furniture_store_0001c
158 | furniture_store_0001d
159 | furniture_store_0001e
160 | furniture_store_0001f
161 | furniture_store_0002a
162 | furniture_store_0002b
163 | furniture_store_0002c
164 | furniture_store_0002d
165 | home_office_0004
166 | home_office_0005
167 | home_office_0006
168 | home_office_0007
169 | home_office_0008
170 | home_office_0011
171 | home_office_0013
172 | home_storage_0001
173 | indoor_balcony_0001
174 | kitchen_0003
175 | kitchen_0006
176 | kitchen_0008
177 | kitchen_0010
178 | kitchen_0011a
179 | kitchen_0011b
180 | kitchen_0016
181 | kitchen_0017
182 | kitchen_0019a
183 | kitchen_0019b
184 | kitchen_0028a
185 | kitchen_0028b
186 | kitchen_0029a
187 | kitchen_0029b
188 | kitchen_0029c
189 | kitchen_0031
190 | kitchen_0033
191 | kitchen_0035a
192 | kitchen_0035b
193 | kitchen_0037
194 | kitchen_0041
195 | kitchen_0043
196 | kitchen_0045a
197 | kitchen_0045b
198 | kitchen_0047
199 | kitchen_0048
200 | kitchen_0049
201 | kitchen_0050
202 | kitchen_0051
203 | kitchen_0052
204 | kitchen_0053
205 | kitchen_0059
206 | kitchen_0060
207 | laundry_room_0001
208 | living_room_0004
209 | living_room_0005
210 | living_room_0006
211 | living_room_0010
212 | living_room_0011
213 | living_room_0012
214 | living_room_0018
215 | living_room_0019
216 | living_room_0020
217 | living_room_0022
218 | living_room_0029
219 | living_room_0032
220 | living_room_0033
221 | living_room_0035
222 | living_room_0037
223 | living_room_0038
224 | living_room_0039
225 | living_room_0040
226 | living_room_0042a
227 | living_room_0042b
228 | living_room_0046a
229 | living_room_0046b
230 | living_room_0047a
231 | living_room_0047b
232 | living_room_0050
233 | living_room_0055
234 | living_room_0058
235 | living_room_0062
236 | living_room_0063
237 | living_room_0067
238 | living_room_0068
239 | living_room_0069a
240 | living_room_0069b
241 | living_room_0070
242 | living_room_0071
243 | living_room_0078
244 | living_room_0082
245 | living_room_0083
246 | living_room_0085
247 | living_room_0086a
248 | living_room_0086b
249 | nyu_office_0
250 | nyu_office_1
251 | office_0003
252 | office_0004
253 | office_0006
254 | office_0009
255 | office_0011
256 | office_0012
257 | office_0018
258 | office_0019
259 | office_0021
260 | office_0023
261 | office_0024
262 | office_0025
263 | office_0026
264 | office_kitchen_0001a
265 | office_kitchen_0001b
266 | office_kitchen_0003
267 | playroom_0002
268 | playroom_0003
269 | playroom_0004
270 | playroom_0006
271 | printer_room_0001
272 | reception_room_0001a
273 | reception_room_0001b
274 | reception_room_0002
275 | reception_room_0004
276 | student_lounge_0001
277 | study_0003
278 | study_0004
279 | study_0005
280 | study_0006
281 | study_0008
282 | study_room_0004
283 | study_room_0005a
284 | study_room_0005b
--------------------------------------------------------------------------------