├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── arguments_eval_kittieigen.txt
│   ├── arguments_eval_nyu.txt
│   ├── arguments_train_kittieigen.txt
│   └── arguments_train_nyu.txt
├── data_splits
│   ├── eigen_test_files_with_gt.txt
│   ├── eigen_train_files_with_gt.txt
│   ├── kitti_depth_prediction_train.txt
│   ├── kitti_official_test.txt
│   ├── kitti_official_valid.txt
│   ├── nyudepthv2_test_files_with_gt.txt
│   └── nyudepthv2_train_files_with_gt_dense.txt
├── files
│   ├── intro.png
│   ├── office_00633.jpg
│   ├── office_00633_depth.jpg
│   ├── office_00633_pcd.jpg
│   ├── output_nyu1_compressed.gif
│   └── output_nyu2_compressed.gif
└── newcrfs
    ├── dataloaders
    │   ├── __init__.py
    │   ├── dataloader.py
    │   └── dataloader_kittipred.py
    ├── demo.py
    ├── eval.py
    ├── networks
    │   ├── NewCRFDepth.py
    │   ├── __init__.py
    │   ├── newcrf_layers.py
    │   ├── newcrf_utils.py
    │   ├── swin_transformer.py
    │   └── uper_crf_head.py
    ├── test.py
    ├── train.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | __pycache__
3 | model_zoo
4 | models
5 | datasets
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Alibaba Cloud
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## NeW CRFs: Neural Window Fully-connected CRFs for Monocular Depth Estimation
2 |
3 | This is the official PyTorch implementation code for NeWCRFs. For technical details, please refer to:
4 |
5 | **NeW CRFs: Neural Window Fully-connected CRFs for Monocular Depth Estimation**
6 | Weihao Yuan, Xiaodong Gu, Zuozhuo Dai, Siyu Zhu, Ping Tan
7 | **CVPR 2022**
8 | **[[Project Page](https://weihaosky.github.io/newcrfs/)]** |
9 | **[[Paper](https://arxiv.org/abs/2203.01502)]**
10 |
11 |
12 |
13 |
14 |
15 |
20 |
21 | 
22 |
23 | ## Bibtex
24 | If you find this code useful in your research, please cite:
25 |
26 | ```
27 | @inproceedings{yuan2022newcrfs,
28 |   title={NeWCRFs: Neural Window Fully-connected CRFs for Monocular Depth Estimation},
29 |   author={Yuan, Weihao and Gu, Xiaodong and Dai, Zuozhuo and Zhu, Siyu and Tan, Ping},
30 |   booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
31 |   pages={},
32 |   year={2022}
33 | }
34 | ```
35 |
36 | ## Contents
37 | 1. [Installation](#installation)
38 | 2. [Datasets](#datasets)
39 | 3. [Training](#training)
40 | 4. [Evaluation](#evaluation)
41 | 5. [Models](#models)
42 | 6. [Demo](#demo)
43 |
44 | ## Installation
45 | ```
46 | conda create -n newcrfs python=3.8
47 | conda activate newcrfs
48 | conda install pytorch=1.10.0 torchvision cudatoolkit=11.1 -c pytorch -c conda-forge
49 | pip install matplotlib tqdm tensorboardX timm mmcv
50 | ```
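
If the environment was created successfully, a quick import check should run cleanly. This is a minimal sanity-check sketch; it only assumes the packages installed above and prints the detected versions.
```
# Sanity check: confirm the packages installed above import and CUDA is visible to PyTorch.
import torch
import torchvision
import matplotlib
import tqdm
import tensorboardX
import timm
import mmcv

print('torch:', torch.__version__, '| torchvision:', torchvision.__version__)
print('mmcv:', mmcv.__version__, '| timm:', timm.__version__)
print('CUDA available:', torch.cuda.is_available())
```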
51 |
52 |
53 | ## Datasets
54 | You can prepare the KITTI and NYUv2 datasets by following the instructions [here](https://github.com/cleinc/bts), and then modify the data paths in the config files to point to your dataset locations.
55 |
56 | Alternatively, download the NYUv2 data from [here](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/newcrfs/datasets/nyu/sync.zip) and the KITTI data from [here](http://www.cvlibs.net/datasets/kitti/eval_depth.php?benchmark=depth_prediction).
57 |
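The config files under `configs/` reference the directories below relative to the repository root. The following sketch simply checks that those directories exist; the paths are taken from the shipped config files, so adjust them there if your data lives elsewhere.
```
# Verify the dataset/model layout expected by the paths in configs/*.txt.
import os

expected = [
    'datasets/nyu/sync',                  # --data_path / --gt_path (NYUv2 training)
    'datasets/nyu/official_splits/test',  # --data_path_eval / --gt_path_eval (NYUv2 eval)
    'datasets/kitti',                     # --data_path / --gt_path (KITTI)
    'model_zoo/swin_transformer',         # --pretrain (Swin backbone checkpoint)
]

for path in expected:
    status = 'ok' if os.path.isdir(path) else 'MISSING'
    print('{:40s} {}'.format(path, status))
```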
58 |
59 | ## Training
60 | First download the pretrained Swin Transformer (Swin-L) encoder backbone from [here](https://github.com/microsoft/Swin-Transformer), and then point the `--pretrain` path in the config files to the downloaded checkpoint.
61 |
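Before launching a run, it can help to confirm that the downloaded backbone checkpoint loads. This sketch assumes the `--pretrain` path used in the shipped configs and that the official Swin checkpoint stores its weights under a `model` key; inspect the printed keys if yours differs.
```
# Quick inspection of the Swin-L backbone checkpoint referenced by --pretrain.
import torch

ckpt_path = 'model_zoo/swin_transformer/swin_large_patch4_window7_224_22k.pth'
ckpt = torch.load(ckpt_path, map_location='cpu')

print('top-level keys:', list(ckpt.keys()))
state_dict = ckpt.get('model', ckpt)  # official Swin checkpoints usually nest weights under 'model'
print('number of parameter tensors:', len(state_dict))
```
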
62 | Training the NYUv2 model:
63 | ```
64 | python newcrfs/train.py configs/arguments_train_nyu.txt
65 | ```
66 |
67 | Training the KITTI model:
68 | ```
69 | python newcrfs/train.py configs/arguments_train_kittieigen.txt
70 | ```
71 |
72 |
73 | ## Evaluation
74 | Evaluate the NYUv2 model:
75 | ```
76 | python newcrfs/eval.py configs/arguments_eval_nyu.txt
77 | ```
78 |
79 | Evaluate the KITTI model:
80 | ```
81 | python newcrfs/eval.py configs/arguments_eval_kittieigen.txt
82 | ```
83 |
84 | ## Models
85 | | Model | Abs.Rel. | Sq.Rel. | RMSE | RMSE log | a1 | a2 | a3 | SILog |
86 | | :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
87 | |[NYUv2](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/newcrfs/models/model_nyu.ckpt) | 0.0952 | 0.0443 | 0.3310 | 0.1185 | 0.923 | 0.992 | 0.998 | 9.1023 |
88 | |[KITTI_Eigen](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/newcrfs/models/model_kittieigen.ckpt) | 0.0520 | 0.1482 | 2.0716 | 0.0780 | 0.975 | 0.997 | 0.999 | 6.9859 |
89 |
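For reference, the metrics above follow the standard monocular depth evaluation protocol. The sketch below computes them from matched ground-truth/prediction arrays using the usual definitions; it mirrors the conventional `compute_errors` routine rather than quoting the repo's `eval.py` verbatim.
```
# Standard depth-estimation error metrics (Eigen et al. protocol).
# `gt` and `pred` are 1-D numpy arrays of valid, matched depth values in meters.
import numpy as np

def compute_errors(gt, pred):
    thresh = np.maximum(gt / pred, pred / gt)
    a1 = (thresh < 1.25).mean()
    a2 = (thresh < 1.25 ** 2).mean()
    a3 = (thresh < 1.25 ** 3).mean()

    abs_rel = np.mean(np.abs(gt - pred) / gt)
    sq_rel = np.mean(((gt - pred) ** 2) / gt)
    rmse = np.sqrt(np.mean((gt - pred) ** 2))
    rmse_log = np.sqrt(np.mean((np.log(gt) - np.log(pred)) ** 2))

    err = np.log(pred) - np.log(gt)
    silog = np.sqrt(np.mean(err ** 2) - np.mean(err) ** 2) * 100

    return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3, silog
```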
90 |
91 | ## Demo
92 | Test images with the indoor model:
93 | ```
94 | python newcrfs/test.py --data_path datasets/test_data --dataset nyu --filenames_file data_splits/test_list.txt --checkpoint_path model_nyu.ckpt --max_depth 10 --save_viz
95 | ```
96 |
97 | Play with the live demo from a video or your webcam:
98 | ```
99 | python newcrfs/demo.py --dataset nyu --checkpoint_path model_zoo/model_nyu.ckpt --max_depth 10 --video video.mp4
100 | ```
101 |
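Single-image inference can also be scripted directly. The sketch below is an outline built on assumptions: the `NewCRFDepth(version=..., inv_depth=False, max_depth=...)` constructor in `newcrfs/networks/NewCRFDepth.py`, a checkpoint whose weights sit under a `'model'` key with `module.` prefixes left by distributed training, and the ImageNet normalization used in the dataloader. Check it against `newcrfs/test.py` before relying on it.
```
# Hedged single-image inference sketch; verify names against newcrfs/test.py.
import numpy as np
import torch
from PIL import Image

from networks.NewCRFDepth import NewCRFDepth  # run from inside newcrfs/

# Assumed constructor arguments, mirroring the NYUv2 config files.
model = NewCRFDepth(version='large07', inv_depth=False, max_depth=10)
checkpoint = torch.load('model_zoo/model_nyu.ckpt', map_location='cpu')
state_dict = checkpoint.get('model', checkpoint)
# Strip the 'module.' prefix left by (Distributed)DataParallel, if present.
state_dict = {k[len('module.'):] if k.startswith('module.') else k: v
              for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.eval()

image = np.asarray(Image.open('files/office_00633.jpg'), dtype=np.float32) / 255.0
image = torch.from_numpy(image.transpose(2, 0, 1)).unsqueeze(0)
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
image = (image - mean) / std

with torch.no_grad():
    depth = model(image)  # assumed output: (1, 1, H, W) depth map in meters
print('predicted depth range: %.2f - %.2f m' % (depth.min(), depth.max()))
```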
102 | 
103 |
104 | [Demo video1](https://www.youtube.com/watch?v=RrWQIpXoP2Y)
105 |
106 | [Demo video2](https://www.youtube.com/watch?v=fD3sWH_54cg)
107 |
108 | [Demo video3](https://www.youtube.com/watch?v=IztmOYZNirM)
109 |
110 | ## Acknowledgements
111 | Thanks to Jin Han Lee for open-sourcing the excellent work [BTS](https://github.com/cleinc/bts).
112 | Thanks to Microsoft Research Asia for open-sourcing the excellent work [Swin Transformer](https://github.com/microsoft/Swin-Transformer).
--------------------------------------------------------------------------------
/configs/arguments_eval_kittieigen.txt:
--------------------------------------------------------------------------------
1 | --model_name newcrfs_kittieigen
2 | --encoder large07
3 | --dataset kitti
4 | --input_height 352
5 | --input_width 1216
6 | --max_depth 80
7 | --do_kb_crop
8 |
9 | --data_path_eval datasets/kitti/
10 | --gt_path_eval datasets/kitti/
11 | --filenames_file_eval data_splits/eigen_test_files_with_gt.txt
12 | --min_depth_eval 1e-3
13 | --max_depth_eval 80
14 | --garg_crop
15 |
16 | --checkpoint_path model_zoo/model_kittieigen.ckpt
--------------------------------------------------------------------------------
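
These config files are plain `--flag value` listings that `newcrfs/train.py` and `newcrfs/eval.py` read from the command line (see the README above). As a self-contained illustration of how such a file can be consumed, the sketch below uses argparse's `fromfile_prefix_chars` mechanism with a per-line splitter; it is a standalone example that declares only a few of the options, not the repo's exact parser.
```
# Standalone sketch: read an "--option value per line" config file with argparse.
import argparse

def convert_arg_line_to_args(arg_line):
    # split "--input_height 352" into ["--input_height", "352"]; blank lines yield nothing
    return arg_line.split()

parser = argparse.ArgumentParser(fromfile_prefix_chars='@')
parser.convert_arg_line_to_args = convert_arg_line_to_args
parser.add_argument('--model_name', type=str)
parser.add_argument('--encoder', type=str, default='large07')
parser.add_argument('--dataset', type=str)
parser.add_argument('--input_height', type=int)
parser.add_argument('--input_width', type=int)
parser.add_argument('--max_depth', type=float)
parser.add_argument('--do_kb_crop', action='store_true')
# ... remaining options omitted for brevity

if __name__ == '__main__':
    # prefix the file name with '@' so argparse reads arguments from it;
    # parse_known_args tolerates the options not declared above
    args, _ = parser.parse_known_args(['@configs/arguments_eval_kittieigen.txt'])
    print(args.model_name, args.input_height, args.input_width)
```
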
/configs/arguments_eval_nyu.txt:
--------------------------------------------------------------------------------
1 | --model_name newcrfs_nyu
2 | --encoder large07
3 | --dataset nyu
4 | --input_height 480
5 | --input_width 640
6 | --max_depth 10
7 |
8 | --data_path_eval datasets/nyu/official_splits/test/
9 | --gt_path_eval datasets/nyu/official_splits/test/
10 | --filenames_file_eval data_splits/nyudepthv2_test_files_with_gt.txt
11 | --min_depth_eval 1e-3
12 | --max_depth_eval 10
13 | --eigen_crop
14 |
15 | --checkpoint_path model_zoo/model_nyu.ckpt
--------------------------------------------------------------------------------
/configs/arguments_train_kittieigen.txt:
--------------------------------------------------------------------------------
1 | --mode train
2 | --model_name newcrfs_kittieigen
3 | --encoder large07
4 | --pretrain model_zoo/swin_transformer/swin_large_patch4_window7_224_22k.pth
5 | --dataset kitti
6 | --data_path datasets/kitti/
7 | --gt_path datasets/kitti/
8 | --filenames_file data_splits/eigen_train_files_with_gt.txt
9 | --batch_size 8
10 | --num_epochs 50
11 | --learning_rate 2e-5
12 | --weight_decay 1e-2
13 | --adam_eps 1e-3
14 | --num_threads 1
15 | --input_height 352
16 | --input_width 1120
17 | --max_depth 80
18 | --do_kb_crop
19 | --do_random_rotate
20 | --degree 1.0
21 | --log_directory ./models/
22 | --multiprocessing_distributed
23 | --dist_url tcp://127.0.0.1:2345
24 |
25 | --log_freq 100
26 | --do_online_eval
27 | --eval_freq 1000
28 | --data_path_eval datasets/kitti/
29 | --gt_path_eval datasets/kitti/
30 | --filenames_file_eval data_splits/eigen_test_files_with_gt.txt
31 | --min_depth_eval 1e-3
32 | --max_depth_eval 80
33 | --garg_crop
34 |
--------------------------------------------------------------------------------
/configs/arguments_train_nyu.txt:
--------------------------------------------------------------------------------
1 | --mode train
2 | --model_name newcrfs_nyu
3 | --encoder large07
4 | --pretrain model_zoo/swin_transformer/swin_large_patch4_window7_224_22k.pth
5 | --dataset nyu
6 | --data_path datasets/nyu/sync/
7 | --gt_path datasets/nyu/sync/
8 | --filenames_file data_splits/nyudepthv2_train_files_with_gt_dense.txt
9 | --batch_size 8
10 | --num_epochs 50
11 | --learning_rate 2e-5
12 | --weight_decay 1e-2
13 | --adam_eps 1e-3
14 | --num_threads 1
15 | --input_height 480
16 | --input_width 640
17 | --max_depth 10
18 | --do_random_rotate
19 | --degree 2.5
20 | --log_directory ./models/
21 | --multiprocessing_distributed
22 | --dist_url tcp://127.0.0.1:2345
23 |
24 | --log_freq 100
25 | --do_online_eval
26 | --eval_freq 1000
27 | --data_path_eval datasets/nyu/official_splits/test/
28 | --gt_path_eval datasets/nyu/official_splits/test/
29 | --filenames_file_eval data_splits/nyudepthv2_test_files_with_gt.txt
30 | --min_depth_eval 1e-3
31 | --max_depth_eval 10
32 | --eigen_crop
33 |
--------------------------------------------------------------------------------
/data_splits/kitti_official_test.txt:
--------------------------------------------------------------------------------
1 | depth_selection/test_depth_prediction_anonymous/image/0000000000.png
2 | depth_selection/test_depth_prediction_anonymous/image/0000000001.png
3 | depth_selection/test_depth_prediction_anonymous/image/0000000002.png
4 | depth_selection/test_depth_prediction_anonymous/image/0000000003.png
5 | depth_selection/test_depth_prediction_anonymous/image/0000000004.png
6 | depth_selection/test_depth_prediction_anonymous/image/0000000005.png
7 | depth_selection/test_depth_prediction_anonymous/image/0000000006.png
8 | depth_selection/test_depth_prediction_anonymous/image/0000000007.png
9 | depth_selection/test_depth_prediction_anonymous/image/0000000008.png
10 | depth_selection/test_depth_prediction_anonymous/image/0000000009.png
11 | depth_selection/test_depth_prediction_anonymous/image/0000000010.png
12 | depth_selection/test_depth_prediction_anonymous/image/0000000011.png
13 | depth_selection/test_depth_prediction_anonymous/image/0000000012.png
14 | depth_selection/test_depth_prediction_anonymous/image/0000000013.png
15 | depth_selection/test_depth_prediction_anonymous/image/0000000014.png
16 | depth_selection/test_depth_prediction_anonymous/image/0000000015.png
17 | depth_selection/test_depth_prediction_anonymous/image/0000000016.png
18 | depth_selection/test_depth_prediction_anonymous/image/0000000017.png
19 | depth_selection/test_depth_prediction_anonymous/image/0000000018.png
20 | depth_selection/test_depth_prediction_anonymous/image/0000000019.png
21 | depth_selection/test_depth_prediction_anonymous/image/0000000020.png
22 | depth_selection/test_depth_prediction_anonymous/image/0000000021.png
23 | depth_selection/test_depth_prediction_anonymous/image/0000000022.png
24 | depth_selection/test_depth_prediction_anonymous/image/0000000023.png
25 | depth_selection/test_depth_prediction_anonymous/image/0000000024.png
26 | depth_selection/test_depth_prediction_anonymous/image/0000000025.png
27 | depth_selection/test_depth_prediction_anonymous/image/0000000026.png
28 | depth_selection/test_depth_prediction_anonymous/image/0000000027.png
29 | depth_selection/test_depth_prediction_anonymous/image/0000000028.png
30 | depth_selection/test_depth_prediction_anonymous/image/0000000029.png
31 | depth_selection/test_depth_prediction_anonymous/image/0000000030.png
32 | depth_selection/test_depth_prediction_anonymous/image/0000000031.png
33 | depth_selection/test_depth_prediction_anonymous/image/0000000032.png
34 | depth_selection/test_depth_prediction_anonymous/image/0000000033.png
35 | depth_selection/test_depth_prediction_anonymous/image/0000000034.png
36 | depth_selection/test_depth_prediction_anonymous/image/0000000035.png
37 | depth_selection/test_depth_prediction_anonymous/image/0000000036.png
38 | depth_selection/test_depth_prediction_anonymous/image/0000000037.png
39 | depth_selection/test_depth_prediction_anonymous/image/0000000038.png
40 | depth_selection/test_depth_prediction_anonymous/image/0000000039.png
41 | depth_selection/test_depth_prediction_anonymous/image/0000000040.png
42 | depth_selection/test_depth_prediction_anonymous/image/0000000041.png
43 | depth_selection/test_depth_prediction_anonymous/image/0000000042.png
44 | depth_selection/test_depth_prediction_anonymous/image/0000000043.png
45 | depth_selection/test_depth_prediction_anonymous/image/0000000044.png
46 | depth_selection/test_depth_prediction_anonymous/image/0000000045.png
47 | depth_selection/test_depth_prediction_anonymous/image/0000000046.png
48 | depth_selection/test_depth_prediction_anonymous/image/0000000047.png
49 | depth_selection/test_depth_prediction_anonymous/image/0000000048.png
50 | depth_selection/test_depth_prediction_anonymous/image/0000000049.png
51 | depth_selection/test_depth_prediction_anonymous/image/0000000050.png
52 | depth_selection/test_depth_prediction_anonymous/image/0000000051.png
53 | depth_selection/test_depth_prediction_anonymous/image/0000000052.png
54 | depth_selection/test_depth_prediction_anonymous/image/0000000053.png
55 | depth_selection/test_depth_prediction_anonymous/image/0000000054.png
56 | depth_selection/test_depth_prediction_anonymous/image/0000000055.png
57 | depth_selection/test_depth_prediction_anonymous/image/0000000056.png
58 | depth_selection/test_depth_prediction_anonymous/image/0000000057.png
59 | depth_selection/test_depth_prediction_anonymous/image/0000000058.png
60 | depth_selection/test_depth_prediction_anonymous/image/0000000059.png
61 | depth_selection/test_depth_prediction_anonymous/image/0000000060.png
62 | depth_selection/test_depth_prediction_anonymous/image/0000000061.png
63 | depth_selection/test_depth_prediction_anonymous/image/0000000062.png
64 | depth_selection/test_depth_prediction_anonymous/image/0000000063.png
65 | depth_selection/test_depth_prediction_anonymous/image/0000000064.png
66 | depth_selection/test_depth_prediction_anonymous/image/0000000065.png
67 | depth_selection/test_depth_prediction_anonymous/image/0000000066.png
68 | depth_selection/test_depth_prediction_anonymous/image/0000000067.png
69 | depth_selection/test_depth_prediction_anonymous/image/0000000068.png
70 | depth_selection/test_depth_prediction_anonymous/image/0000000069.png
71 | depth_selection/test_depth_prediction_anonymous/image/0000000070.png
72 | depth_selection/test_depth_prediction_anonymous/image/0000000071.png
73 | depth_selection/test_depth_prediction_anonymous/image/0000000072.png
74 | depth_selection/test_depth_prediction_anonymous/image/0000000073.png
75 | depth_selection/test_depth_prediction_anonymous/image/0000000074.png
76 | depth_selection/test_depth_prediction_anonymous/image/0000000075.png
77 | depth_selection/test_depth_prediction_anonymous/image/0000000076.png
78 | depth_selection/test_depth_prediction_anonymous/image/0000000077.png
79 | depth_selection/test_depth_prediction_anonymous/image/0000000078.png
80 | depth_selection/test_depth_prediction_anonymous/image/0000000079.png
81 | depth_selection/test_depth_prediction_anonymous/image/0000000080.png
82 | depth_selection/test_depth_prediction_anonymous/image/0000000081.png
83 | depth_selection/test_depth_prediction_anonymous/image/0000000082.png
84 | depth_selection/test_depth_prediction_anonymous/image/0000000083.png
85 | depth_selection/test_depth_prediction_anonymous/image/0000000084.png
86 | depth_selection/test_depth_prediction_anonymous/image/0000000085.png
87 | depth_selection/test_depth_prediction_anonymous/image/0000000086.png
88 | depth_selection/test_depth_prediction_anonymous/image/0000000087.png
89 | depth_selection/test_depth_prediction_anonymous/image/0000000088.png
90 | depth_selection/test_depth_prediction_anonymous/image/0000000089.png
91 | depth_selection/test_depth_prediction_anonymous/image/0000000090.png
92 | depth_selection/test_depth_prediction_anonymous/image/0000000091.png
93 | depth_selection/test_depth_prediction_anonymous/image/0000000092.png
94 | depth_selection/test_depth_prediction_anonymous/image/0000000093.png
95 | depth_selection/test_depth_prediction_anonymous/image/0000000094.png
96 | depth_selection/test_depth_prediction_anonymous/image/0000000095.png
97 | depth_selection/test_depth_prediction_anonymous/image/0000000096.png
98 | depth_selection/test_depth_prediction_anonymous/image/0000000097.png
99 | depth_selection/test_depth_prediction_anonymous/image/0000000098.png
100 | depth_selection/test_depth_prediction_anonymous/image/0000000099.png
101 | depth_selection/test_depth_prediction_anonymous/image/0000000100.png
102 | depth_selection/test_depth_prediction_anonymous/image/0000000101.png
103 | depth_selection/test_depth_prediction_anonymous/image/0000000102.png
104 | depth_selection/test_depth_prediction_anonymous/image/0000000103.png
105 | depth_selection/test_depth_prediction_anonymous/image/0000000104.png
106 | depth_selection/test_depth_prediction_anonymous/image/0000000105.png
107 | depth_selection/test_depth_prediction_anonymous/image/0000000106.png
108 | depth_selection/test_depth_prediction_anonymous/image/0000000107.png
109 | depth_selection/test_depth_prediction_anonymous/image/0000000108.png
110 | depth_selection/test_depth_prediction_anonymous/image/0000000109.png
111 | depth_selection/test_depth_prediction_anonymous/image/0000000110.png
112 | depth_selection/test_depth_prediction_anonymous/image/0000000111.png
113 | depth_selection/test_depth_prediction_anonymous/image/0000000112.png
114 | depth_selection/test_depth_prediction_anonymous/image/0000000113.png
115 | depth_selection/test_depth_prediction_anonymous/image/0000000114.png
116 | depth_selection/test_depth_prediction_anonymous/image/0000000115.png
117 | depth_selection/test_depth_prediction_anonymous/image/0000000116.png
118 | depth_selection/test_depth_prediction_anonymous/image/0000000117.png
119 | depth_selection/test_depth_prediction_anonymous/image/0000000118.png
120 | depth_selection/test_depth_prediction_anonymous/image/0000000119.png
121 | depth_selection/test_depth_prediction_anonymous/image/0000000120.png
122 | depth_selection/test_depth_prediction_anonymous/image/0000000121.png
123 | depth_selection/test_depth_prediction_anonymous/image/0000000122.png
124 | depth_selection/test_depth_prediction_anonymous/image/0000000123.png
125 | depth_selection/test_depth_prediction_anonymous/image/0000000124.png
126 | depth_selection/test_depth_prediction_anonymous/image/0000000125.png
127 | depth_selection/test_depth_prediction_anonymous/image/0000000126.png
128 | depth_selection/test_depth_prediction_anonymous/image/0000000127.png
129 | depth_selection/test_depth_prediction_anonymous/image/0000000128.png
130 | depth_selection/test_depth_prediction_anonymous/image/0000000129.png
131 | depth_selection/test_depth_prediction_anonymous/image/0000000130.png
132 | depth_selection/test_depth_prediction_anonymous/image/0000000131.png
133 | depth_selection/test_depth_prediction_anonymous/image/0000000132.png
134 | depth_selection/test_depth_prediction_anonymous/image/0000000133.png
135 | depth_selection/test_depth_prediction_anonymous/image/0000000134.png
136 | depth_selection/test_depth_prediction_anonymous/image/0000000135.png
137 | depth_selection/test_depth_prediction_anonymous/image/0000000136.png
138 | depth_selection/test_depth_prediction_anonymous/image/0000000137.png
139 | depth_selection/test_depth_prediction_anonymous/image/0000000138.png
140 | depth_selection/test_depth_prediction_anonymous/image/0000000139.png
141 | depth_selection/test_depth_prediction_anonymous/image/0000000140.png
142 | depth_selection/test_depth_prediction_anonymous/image/0000000141.png
143 | depth_selection/test_depth_prediction_anonymous/image/0000000142.png
144 | depth_selection/test_depth_prediction_anonymous/image/0000000143.png
145 | depth_selection/test_depth_prediction_anonymous/image/0000000144.png
146 | depth_selection/test_depth_prediction_anonymous/image/0000000145.png
147 | depth_selection/test_depth_prediction_anonymous/image/0000000146.png
148 | depth_selection/test_depth_prediction_anonymous/image/0000000147.png
149 | depth_selection/test_depth_prediction_anonymous/image/0000000148.png
150 | depth_selection/test_depth_prediction_anonymous/image/0000000149.png
151 | depth_selection/test_depth_prediction_anonymous/image/0000000150.png
152 | depth_selection/test_depth_prediction_anonymous/image/0000000151.png
153 | depth_selection/test_depth_prediction_anonymous/image/0000000152.png
154 | depth_selection/test_depth_prediction_anonymous/image/0000000153.png
155 | depth_selection/test_depth_prediction_anonymous/image/0000000154.png
156 | depth_selection/test_depth_prediction_anonymous/image/0000000155.png
157 | depth_selection/test_depth_prediction_anonymous/image/0000000156.png
158 | depth_selection/test_depth_prediction_anonymous/image/0000000157.png
159 | depth_selection/test_depth_prediction_anonymous/image/0000000158.png
160 | depth_selection/test_depth_prediction_anonymous/image/0000000159.png
161 | depth_selection/test_depth_prediction_anonymous/image/0000000160.png
162 | depth_selection/test_depth_prediction_anonymous/image/0000000161.png
163 | depth_selection/test_depth_prediction_anonymous/image/0000000162.png
164 | depth_selection/test_depth_prediction_anonymous/image/0000000163.png
165 | depth_selection/test_depth_prediction_anonymous/image/0000000164.png
166 | depth_selection/test_depth_prediction_anonymous/image/0000000165.png
167 | depth_selection/test_depth_prediction_anonymous/image/0000000166.png
168 | depth_selection/test_depth_prediction_anonymous/image/0000000167.png
169 | depth_selection/test_depth_prediction_anonymous/image/0000000168.png
170 | depth_selection/test_depth_prediction_anonymous/image/0000000169.png
171 | depth_selection/test_depth_prediction_anonymous/image/0000000170.png
172 | depth_selection/test_depth_prediction_anonymous/image/0000000171.png
173 | depth_selection/test_depth_prediction_anonymous/image/0000000172.png
174 | depth_selection/test_depth_prediction_anonymous/image/0000000173.png
175 | depth_selection/test_depth_prediction_anonymous/image/0000000174.png
176 | depth_selection/test_depth_prediction_anonymous/image/0000000175.png
177 | depth_selection/test_depth_prediction_anonymous/image/0000000176.png
178 | depth_selection/test_depth_prediction_anonymous/image/0000000177.png
179 | depth_selection/test_depth_prediction_anonymous/image/0000000178.png
180 | depth_selection/test_depth_prediction_anonymous/image/0000000179.png
181 | depth_selection/test_depth_prediction_anonymous/image/0000000180.png
182 | depth_selection/test_depth_prediction_anonymous/image/0000000181.png
183 | depth_selection/test_depth_prediction_anonymous/image/0000000182.png
184 | depth_selection/test_depth_prediction_anonymous/image/0000000183.png
185 | depth_selection/test_depth_prediction_anonymous/image/0000000184.png
186 | depth_selection/test_depth_prediction_anonymous/image/0000000185.png
187 | depth_selection/test_depth_prediction_anonymous/image/0000000186.png
188 | depth_selection/test_depth_prediction_anonymous/image/0000000187.png
189 | depth_selection/test_depth_prediction_anonymous/image/0000000188.png
190 | depth_selection/test_depth_prediction_anonymous/image/0000000189.png
191 | depth_selection/test_depth_prediction_anonymous/image/0000000190.png
192 | depth_selection/test_depth_prediction_anonymous/image/0000000191.png
193 | depth_selection/test_depth_prediction_anonymous/image/0000000192.png
194 | depth_selection/test_depth_prediction_anonymous/image/0000000193.png
195 | depth_selection/test_depth_prediction_anonymous/image/0000000194.png
196 | depth_selection/test_depth_prediction_anonymous/image/0000000195.png
197 | depth_selection/test_depth_prediction_anonymous/image/0000000196.png
198 | depth_selection/test_depth_prediction_anonymous/image/0000000197.png
199 | depth_selection/test_depth_prediction_anonymous/image/0000000198.png
200 | depth_selection/test_depth_prediction_anonymous/image/0000000199.png
201 | depth_selection/test_depth_prediction_anonymous/image/0000000200.png
202 | depth_selection/test_depth_prediction_anonymous/image/0000000201.png
203 | depth_selection/test_depth_prediction_anonymous/image/0000000202.png
204 | depth_selection/test_depth_prediction_anonymous/image/0000000203.png
205 | depth_selection/test_depth_prediction_anonymous/image/0000000204.png
206 | depth_selection/test_depth_prediction_anonymous/image/0000000205.png
207 | depth_selection/test_depth_prediction_anonymous/image/0000000206.png
208 | depth_selection/test_depth_prediction_anonymous/image/0000000207.png
209 | depth_selection/test_depth_prediction_anonymous/image/0000000208.png
210 | depth_selection/test_depth_prediction_anonymous/image/0000000209.png
211 | depth_selection/test_depth_prediction_anonymous/image/0000000210.png
212 | depth_selection/test_depth_prediction_anonymous/image/0000000211.png
213 | depth_selection/test_depth_prediction_anonymous/image/0000000212.png
214 | depth_selection/test_depth_prediction_anonymous/image/0000000213.png
215 | depth_selection/test_depth_prediction_anonymous/image/0000000214.png
216 | depth_selection/test_depth_prediction_anonymous/image/0000000215.png
217 | depth_selection/test_depth_prediction_anonymous/image/0000000216.png
218 | depth_selection/test_depth_prediction_anonymous/image/0000000217.png
219 | depth_selection/test_depth_prediction_anonymous/image/0000000218.png
220 | depth_selection/test_depth_prediction_anonymous/image/0000000219.png
221 | depth_selection/test_depth_prediction_anonymous/image/0000000220.png
222 | depth_selection/test_depth_prediction_anonymous/image/0000000221.png
223 | depth_selection/test_depth_prediction_anonymous/image/0000000222.png
224 | depth_selection/test_depth_prediction_anonymous/image/0000000223.png
225 | depth_selection/test_depth_prediction_anonymous/image/0000000224.png
226 | depth_selection/test_depth_prediction_anonymous/image/0000000225.png
227 | depth_selection/test_depth_prediction_anonymous/image/0000000226.png
228 | depth_selection/test_depth_prediction_anonymous/image/0000000227.png
229 | depth_selection/test_depth_prediction_anonymous/image/0000000228.png
230 | depth_selection/test_depth_prediction_anonymous/image/0000000229.png
231 | depth_selection/test_depth_prediction_anonymous/image/0000000230.png
232 | depth_selection/test_depth_prediction_anonymous/image/0000000231.png
233 | depth_selection/test_depth_prediction_anonymous/image/0000000232.png
234 | depth_selection/test_depth_prediction_anonymous/image/0000000233.png
235 | depth_selection/test_depth_prediction_anonymous/image/0000000234.png
236 | depth_selection/test_depth_prediction_anonymous/image/0000000235.png
237 | depth_selection/test_depth_prediction_anonymous/image/0000000236.png
238 | depth_selection/test_depth_prediction_anonymous/image/0000000237.png
239 | depth_selection/test_depth_prediction_anonymous/image/0000000238.png
240 | depth_selection/test_depth_prediction_anonymous/image/0000000239.png
241 | depth_selection/test_depth_prediction_anonymous/image/0000000240.png
242 | depth_selection/test_depth_prediction_anonymous/image/0000000241.png
243 | depth_selection/test_depth_prediction_anonymous/image/0000000242.png
244 | depth_selection/test_depth_prediction_anonymous/image/0000000243.png
245 | depth_selection/test_depth_prediction_anonymous/image/0000000244.png
246 | depth_selection/test_depth_prediction_anonymous/image/0000000245.png
247 | depth_selection/test_depth_prediction_anonymous/image/0000000246.png
248 | depth_selection/test_depth_prediction_anonymous/image/0000000247.png
249 | depth_selection/test_depth_prediction_anonymous/image/0000000248.png
250 | depth_selection/test_depth_prediction_anonymous/image/0000000249.png
251 | depth_selection/test_depth_prediction_anonymous/image/0000000250.png
252 | depth_selection/test_depth_prediction_anonymous/image/0000000251.png
253 | depth_selection/test_depth_prediction_anonymous/image/0000000252.png
254 | depth_selection/test_depth_prediction_anonymous/image/0000000253.png
255 | depth_selection/test_depth_prediction_anonymous/image/0000000254.png
256 | depth_selection/test_depth_prediction_anonymous/image/0000000255.png
257 | depth_selection/test_depth_prediction_anonymous/image/0000000256.png
258 | depth_selection/test_depth_prediction_anonymous/image/0000000257.png
259 | depth_selection/test_depth_prediction_anonymous/image/0000000258.png
260 | depth_selection/test_depth_prediction_anonymous/image/0000000259.png
261 | depth_selection/test_depth_prediction_anonymous/image/0000000260.png
262 | depth_selection/test_depth_prediction_anonymous/image/0000000261.png
263 | depth_selection/test_depth_prediction_anonymous/image/0000000262.png
264 | depth_selection/test_depth_prediction_anonymous/image/0000000263.png
265 | depth_selection/test_depth_prediction_anonymous/image/0000000264.png
266 | depth_selection/test_depth_prediction_anonymous/image/0000000265.png
267 | depth_selection/test_depth_prediction_anonymous/image/0000000266.png
268 | depth_selection/test_depth_prediction_anonymous/image/0000000267.png
269 | depth_selection/test_depth_prediction_anonymous/image/0000000268.png
270 | depth_selection/test_depth_prediction_anonymous/image/0000000269.png
271 | depth_selection/test_depth_prediction_anonymous/image/0000000270.png
272 | depth_selection/test_depth_prediction_anonymous/image/0000000271.png
273 | depth_selection/test_depth_prediction_anonymous/image/0000000272.png
274 | depth_selection/test_depth_prediction_anonymous/image/0000000273.png
275 | depth_selection/test_depth_prediction_anonymous/image/0000000274.png
276 | depth_selection/test_depth_prediction_anonymous/image/0000000275.png
277 | depth_selection/test_depth_prediction_anonymous/image/0000000276.png
278 | depth_selection/test_depth_prediction_anonymous/image/0000000277.png
279 | depth_selection/test_depth_prediction_anonymous/image/0000000278.png
280 | depth_selection/test_depth_prediction_anonymous/image/0000000279.png
281 | depth_selection/test_depth_prediction_anonymous/image/0000000280.png
282 | depth_selection/test_depth_prediction_anonymous/image/0000000281.png
283 | depth_selection/test_depth_prediction_anonymous/image/0000000282.png
284 | depth_selection/test_depth_prediction_anonymous/image/0000000283.png
285 | depth_selection/test_depth_prediction_anonymous/image/0000000284.png
286 | depth_selection/test_depth_prediction_anonymous/image/0000000285.png
287 | depth_selection/test_depth_prediction_anonymous/image/0000000286.png
288 | depth_selection/test_depth_prediction_anonymous/image/0000000287.png
289 | depth_selection/test_depth_prediction_anonymous/image/0000000288.png
290 | depth_selection/test_depth_prediction_anonymous/image/0000000289.png
291 | depth_selection/test_depth_prediction_anonymous/image/0000000290.png
292 | depth_selection/test_depth_prediction_anonymous/image/0000000291.png
293 | depth_selection/test_depth_prediction_anonymous/image/0000000292.png
294 | depth_selection/test_depth_prediction_anonymous/image/0000000293.png
295 | depth_selection/test_depth_prediction_anonymous/image/0000000294.png
296 | depth_selection/test_depth_prediction_anonymous/image/0000000295.png
297 | depth_selection/test_depth_prediction_anonymous/image/0000000296.png
298 | depth_selection/test_depth_prediction_anonymous/image/0000000297.png
299 | depth_selection/test_depth_prediction_anonymous/image/0000000298.png
300 | depth_selection/test_depth_prediction_anonymous/image/0000000299.png
301 | depth_selection/test_depth_prediction_anonymous/image/0000000300.png
302 | depth_selection/test_depth_prediction_anonymous/image/0000000301.png
303 | depth_selection/test_depth_prediction_anonymous/image/0000000302.png
304 | depth_selection/test_depth_prediction_anonymous/image/0000000303.png
305 | depth_selection/test_depth_prediction_anonymous/image/0000000304.png
306 | depth_selection/test_depth_prediction_anonymous/image/0000000305.png
307 | depth_selection/test_depth_prediction_anonymous/image/0000000306.png
308 | depth_selection/test_depth_prediction_anonymous/image/0000000307.png
309 | depth_selection/test_depth_prediction_anonymous/image/0000000308.png
310 | depth_selection/test_depth_prediction_anonymous/image/0000000309.png
311 | depth_selection/test_depth_prediction_anonymous/image/0000000310.png
312 | depth_selection/test_depth_prediction_anonymous/image/0000000311.png
313 | depth_selection/test_depth_prediction_anonymous/image/0000000312.png
314 | depth_selection/test_depth_prediction_anonymous/image/0000000313.png
315 | depth_selection/test_depth_prediction_anonymous/image/0000000314.png
316 | depth_selection/test_depth_prediction_anonymous/image/0000000315.png
317 | depth_selection/test_depth_prediction_anonymous/image/0000000316.png
318 | depth_selection/test_depth_prediction_anonymous/image/0000000317.png
319 | depth_selection/test_depth_prediction_anonymous/image/0000000318.png
320 | depth_selection/test_depth_prediction_anonymous/image/0000000319.png
321 | depth_selection/test_depth_prediction_anonymous/image/0000000320.png
322 | depth_selection/test_depth_prediction_anonymous/image/0000000321.png
323 | depth_selection/test_depth_prediction_anonymous/image/0000000322.png
324 | depth_selection/test_depth_prediction_anonymous/image/0000000323.png
325 | depth_selection/test_depth_prediction_anonymous/image/0000000324.png
326 | depth_selection/test_depth_prediction_anonymous/image/0000000325.png
327 | depth_selection/test_depth_prediction_anonymous/image/0000000326.png
328 | depth_selection/test_depth_prediction_anonymous/image/0000000327.png
329 | depth_selection/test_depth_prediction_anonymous/image/0000000328.png
330 | depth_selection/test_depth_prediction_anonymous/image/0000000329.png
331 | depth_selection/test_depth_prediction_anonymous/image/0000000330.png
332 | depth_selection/test_depth_prediction_anonymous/image/0000000331.png
333 | depth_selection/test_depth_prediction_anonymous/image/0000000332.png
334 | depth_selection/test_depth_prediction_anonymous/image/0000000333.png
335 | depth_selection/test_depth_prediction_anonymous/image/0000000334.png
336 | depth_selection/test_depth_prediction_anonymous/image/0000000335.png
337 | depth_selection/test_depth_prediction_anonymous/image/0000000336.png
338 | depth_selection/test_depth_prediction_anonymous/image/0000000337.png
339 | depth_selection/test_depth_prediction_anonymous/image/0000000338.png
340 | depth_selection/test_depth_prediction_anonymous/image/0000000339.png
341 | depth_selection/test_depth_prediction_anonymous/image/0000000340.png
342 | depth_selection/test_depth_prediction_anonymous/image/0000000341.png
343 | depth_selection/test_depth_prediction_anonymous/image/0000000342.png
344 | depth_selection/test_depth_prediction_anonymous/image/0000000343.png
345 | depth_selection/test_depth_prediction_anonymous/image/0000000344.png
346 | depth_selection/test_depth_prediction_anonymous/image/0000000345.png
347 | depth_selection/test_depth_prediction_anonymous/image/0000000346.png
348 | depth_selection/test_depth_prediction_anonymous/image/0000000347.png
349 | depth_selection/test_depth_prediction_anonymous/image/0000000348.png
350 | depth_selection/test_depth_prediction_anonymous/image/0000000349.png
351 | depth_selection/test_depth_prediction_anonymous/image/0000000350.png
352 | depth_selection/test_depth_prediction_anonymous/image/0000000351.png
353 | depth_selection/test_depth_prediction_anonymous/image/0000000352.png
354 | depth_selection/test_depth_prediction_anonymous/image/0000000353.png
355 | depth_selection/test_depth_prediction_anonymous/image/0000000354.png
356 | depth_selection/test_depth_prediction_anonymous/image/0000000355.png
357 | depth_selection/test_depth_prediction_anonymous/image/0000000356.png
358 | depth_selection/test_depth_prediction_anonymous/image/0000000357.png
359 | depth_selection/test_depth_prediction_anonymous/image/0000000358.png
360 | depth_selection/test_depth_prediction_anonymous/image/0000000359.png
361 | depth_selection/test_depth_prediction_anonymous/image/0000000360.png
362 | depth_selection/test_depth_prediction_anonymous/image/0000000361.png
363 | depth_selection/test_depth_prediction_anonymous/image/0000000362.png
364 | depth_selection/test_depth_prediction_anonymous/image/0000000363.png
365 | depth_selection/test_depth_prediction_anonymous/image/0000000364.png
366 | depth_selection/test_depth_prediction_anonymous/image/0000000365.png
367 | depth_selection/test_depth_prediction_anonymous/image/0000000366.png
368 | depth_selection/test_depth_prediction_anonymous/image/0000000367.png
369 | depth_selection/test_depth_prediction_anonymous/image/0000000368.png
370 | depth_selection/test_depth_prediction_anonymous/image/0000000369.png
371 | depth_selection/test_depth_prediction_anonymous/image/0000000370.png
372 | depth_selection/test_depth_prediction_anonymous/image/0000000371.png
373 | depth_selection/test_depth_prediction_anonymous/image/0000000372.png
374 | depth_selection/test_depth_prediction_anonymous/image/0000000373.png
375 | depth_selection/test_depth_prediction_anonymous/image/0000000374.png
376 | depth_selection/test_depth_prediction_anonymous/image/0000000375.png
377 | depth_selection/test_depth_prediction_anonymous/image/0000000376.png
378 | depth_selection/test_depth_prediction_anonymous/image/0000000377.png
379 | depth_selection/test_depth_prediction_anonymous/image/0000000378.png
380 | depth_selection/test_depth_prediction_anonymous/image/0000000379.png
381 | depth_selection/test_depth_prediction_anonymous/image/0000000380.png
382 | depth_selection/test_depth_prediction_anonymous/image/0000000381.png
383 | depth_selection/test_depth_prediction_anonymous/image/0000000382.png
384 | depth_selection/test_depth_prediction_anonymous/image/0000000383.png
385 | depth_selection/test_depth_prediction_anonymous/image/0000000384.png
386 | depth_selection/test_depth_prediction_anonymous/image/0000000385.png
387 | depth_selection/test_depth_prediction_anonymous/image/0000000386.png
388 | depth_selection/test_depth_prediction_anonymous/image/0000000387.png
389 | depth_selection/test_depth_prediction_anonymous/image/0000000388.png
390 | depth_selection/test_depth_prediction_anonymous/image/0000000389.png
391 | depth_selection/test_depth_prediction_anonymous/image/0000000390.png
392 | depth_selection/test_depth_prediction_anonymous/image/0000000391.png
393 | depth_selection/test_depth_prediction_anonymous/image/0000000392.png
394 | depth_selection/test_depth_prediction_anonymous/image/0000000393.png
395 | depth_selection/test_depth_prediction_anonymous/image/0000000394.png
396 | depth_selection/test_depth_prediction_anonymous/image/0000000395.png
397 | depth_selection/test_depth_prediction_anonymous/image/0000000396.png
398 | depth_selection/test_depth_prediction_anonymous/image/0000000397.png
399 | depth_selection/test_depth_prediction_anonymous/image/0000000398.png
400 | depth_selection/test_depth_prediction_anonymous/image/0000000399.png
401 | depth_selection/test_depth_prediction_anonymous/image/0000000400.png
402 | depth_selection/test_depth_prediction_anonymous/image/0000000401.png
403 | depth_selection/test_depth_prediction_anonymous/image/0000000402.png
404 | depth_selection/test_depth_prediction_anonymous/image/0000000403.png
405 | depth_selection/test_depth_prediction_anonymous/image/0000000404.png
406 | depth_selection/test_depth_prediction_anonymous/image/0000000405.png
407 | depth_selection/test_depth_prediction_anonymous/image/0000000406.png
408 | depth_selection/test_depth_prediction_anonymous/image/0000000407.png
409 | depth_selection/test_depth_prediction_anonymous/image/0000000408.png
410 | depth_selection/test_depth_prediction_anonymous/image/0000000409.png
411 | depth_selection/test_depth_prediction_anonymous/image/0000000410.png
412 | depth_selection/test_depth_prediction_anonymous/image/0000000411.png
413 | depth_selection/test_depth_prediction_anonymous/image/0000000412.png
414 | depth_selection/test_depth_prediction_anonymous/image/0000000413.png
415 | depth_selection/test_depth_prediction_anonymous/image/0000000414.png
416 | depth_selection/test_depth_prediction_anonymous/image/0000000415.png
417 | depth_selection/test_depth_prediction_anonymous/image/0000000416.png
418 | depth_selection/test_depth_prediction_anonymous/image/0000000417.png
419 | depth_selection/test_depth_prediction_anonymous/image/0000000418.png
420 | depth_selection/test_depth_prediction_anonymous/image/0000000419.png
421 | depth_selection/test_depth_prediction_anonymous/image/0000000420.png
422 | depth_selection/test_depth_prediction_anonymous/image/0000000421.png
423 | depth_selection/test_depth_prediction_anonymous/image/0000000422.png
424 | depth_selection/test_depth_prediction_anonymous/image/0000000423.png
425 | depth_selection/test_depth_prediction_anonymous/image/0000000424.png
426 | depth_selection/test_depth_prediction_anonymous/image/0000000425.png
427 | depth_selection/test_depth_prediction_anonymous/image/0000000426.png
428 | depth_selection/test_depth_prediction_anonymous/image/0000000427.png
429 | depth_selection/test_depth_prediction_anonymous/image/0000000428.png
430 | depth_selection/test_depth_prediction_anonymous/image/0000000429.png
431 | depth_selection/test_depth_prediction_anonymous/image/0000000430.png
432 | depth_selection/test_depth_prediction_anonymous/image/0000000431.png
433 | depth_selection/test_depth_prediction_anonymous/image/0000000432.png
434 | depth_selection/test_depth_prediction_anonymous/image/0000000433.png
435 | depth_selection/test_depth_prediction_anonymous/image/0000000434.png
436 | depth_selection/test_depth_prediction_anonymous/image/0000000435.png
437 | depth_selection/test_depth_prediction_anonymous/image/0000000436.png
438 | depth_selection/test_depth_prediction_anonymous/image/0000000437.png
439 | depth_selection/test_depth_prediction_anonymous/image/0000000438.png
440 | depth_selection/test_depth_prediction_anonymous/image/0000000439.png
441 | depth_selection/test_depth_prediction_anonymous/image/0000000440.png
442 | depth_selection/test_depth_prediction_anonymous/image/0000000441.png
443 | depth_selection/test_depth_prediction_anonymous/image/0000000442.png
444 | depth_selection/test_depth_prediction_anonymous/image/0000000443.png
445 | depth_selection/test_depth_prediction_anonymous/image/0000000444.png
446 | depth_selection/test_depth_prediction_anonymous/image/0000000445.png
447 | depth_selection/test_depth_prediction_anonymous/image/0000000446.png
448 | depth_selection/test_depth_prediction_anonymous/image/0000000447.png
449 | depth_selection/test_depth_prediction_anonymous/image/0000000448.png
450 | depth_selection/test_depth_prediction_anonymous/image/0000000449.png
451 | depth_selection/test_depth_prediction_anonymous/image/0000000450.png
452 | depth_selection/test_depth_prediction_anonymous/image/0000000451.png
453 | depth_selection/test_depth_prediction_anonymous/image/0000000452.png
454 | depth_selection/test_depth_prediction_anonymous/image/0000000453.png
455 | depth_selection/test_depth_prediction_anonymous/image/0000000454.png
456 | depth_selection/test_depth_prediction_anonymous/image/0000000455.png
457 | depth_selection/test_depth_prediction_anonymous/image/0000000456.png
458 | depth_selection/test_depth_prediction_anonymous/image/0000000457.png
459 | depth_selection/test_depth_prediction_anonymous/image/0000000458.png
460 | depth_selection/test_depth_prediction_anonymous/image/0000000459.png
461 | depth_selection/test_depth_prediction_anonymous/image/0000000460.png
462 | depth_selection/test_depth_prediction_anonymous/image/0000000461.png
463 | depth_selection/test_depth_prediction_anonymous/image/0000000462.png
464 | depth_selection/test_depth_prediction_anonymous/image/0000000463.png
465 | depth_selection/test_depth_prediction_anonymous/image/0000000464.png
466 | depth_selection/test_depth_prediction_anonymous/image/0000000465.png
467 | depth_selection/test_depth_prediction_anonymous/image/0000000466.png
468 | depth_selection/test_depth_prediction_anonymous/image/0000000467.png
469 | depth_selection/test_depth_prediction_anonymous/image/0000000468.png
470 | depth_selection/test_depth_prediction_anonymous/image/0000000469.png
471 | depth_selection/test_depth_prediction_anonymous/image/0000000470.png
472 | depth_selection/test_depth_prediction_anonymous/image/0000000471.png
473 | depth_selection/test_depth_prediction_anonymous/image/0000000472.png
474 | depth_selection/test_depth_prediction_anonymous/image/0000000473.png
475 | depth_selection/test_depth_prediction_anonymous/image/0000000474.png
476 | depth_selection/test_depth_prediction_anonymous/image/0000000475.png
477 | depth_selection/test_depth_prediction_anonymous/image/0000000476.png
478 | depth_selection/test_depth_prediction_anonymous/image/0000000477.png
479 | depth_selection/test_depth_prediction_anonymous/image/0000000478.png
480 | depth_selection/test_depth_prediction_anonymous/image/0000000479.png
481 | depth_selection/test_depth_prediction_anonymous/image/0000000480.png
482 | depth_selection/test_depth_prediction_anonymous/image/0000000481.png
483 | depth_selection/test_depth_prediction_anonymous/image/0000000482.png
484 | depth_selection/test_depth_prediction_anonymous/image/0000000483.png
485 | depth_selection/test_depth_prediction_anonymous/image/0000000484.png
486 | depth_selection/test_depth_prediction_anonymous/image/0000000485.png
487 | depth_selection/test_depth_prediction_anonymous/image/0000000486.png
488 | depth_selection/test_depth_prediction_anonymous/image/0000000487.png
489 | depth_selection/test_depth_prediction_anonymous/image/0000000488.png
490 | depth_selection/test_depth_prediction_anonymous/image/0000000489.png
491 | depth_selection/test_depth_prediction_anonymous/image/0000000490.png
492 | depth_selection/test_depth_prediction_anonymous/image/0000000491.png
493 | depth_selection/test_depth_prediction_anonymous/image/0000000492.png
494 | depth_selection/test_depth_prediction_anonymous/image/0000000493.png
495 | depth_selection/test_depth_prediction_anonymous/image/0000000494.png
496 | depth_selection/test_depth_prediction_anonymous/image/0000000495.png
497 | depth_selection/test_depth_prediction_anonymous/image/0000000496.png
498 | depth_selection/test_depth_prediction_anonymous/image/0000000497.png
499 | depth_selection/test_depth_prediction_anonymous/image/0000000498.png
500 | depth_selection/test_depth_prediction_anonymous/image/0000000499.png
501 |
--------------------------------------------------------------------------------
/files/intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aliyun/NeWCRFs/d327bf7ca8fb43959734bb02ddc7b56cf283c8d9/files/intro.png
--------------------------------------------------------------------------------
/files/office_00633.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aliyun/NeWCRFs/d327bf7ca8fb43959734bb02ddc7b56cf283c8d9/files/office_00633.jpg
--------------------------------------------------------------------------------
/files/office_00633_depth.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aliyun/NeWCRFs/d327bf7ca8fb43959734bb02ddc7b56cf283c8d9/files/office_00633_depth.jpg
--------------------------------------------------------------------------------
/files/office_00633_pcd.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aliyun/NeWCRFs/d327bf7ca8fb43959734bb02ddc7b56cf283c8d9/files/office_00633_pcd.jpg
--------------------------------------------------------------------------------
/files/output_nyu1_compressed.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aliyun/NeWCRFs/d327bf7ca8fb43959734bb02ddc7b56cf283c8d9/files/output_nyu1_compressed.gif
--------------------------------------------------------------------------------
/files/output_nyu2_compressed.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aliyun/NeWCRFs/d327bf7ca8fb43959734bb02ddc7b56cf283c8d9/files/output_nyu2_compressed.gif
--------------------------------------------------------------------------------
/newcrfs/dataloaders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aliyun/NeWCRFs/d327bf7ca8fb43959734bb02ddc7b56cf283c8d9/newcrfs/dataloaders/__init__.py
--------------------------------------------------------------------------------
/newcrfs/dataloaders/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 | import torch.utils.data.distributed
4 | from torchvision import transforms
5 |
6 | import numpy as np
7 | from PIL import Image
8 | import os
9 | import random
10 |
11 | from utils import DistributedSamplerNoEvenlyDivisible
12 |
13 |
14 | def _is_pil_image(img):
15 |     return isinstance(img, Image.Image)
16 |
17 |
18 | def _is_numpy_image(img):
19 |     return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
20 |
21 |
22 | def preprocessing_transforms(mode):
23 |     return transforms.Compose([
24 |         ToTensor(mode=mode)
25 |     ])
26 |
27 |
28 | class NewDataLoader(object):
29 |     def __init__(self, args, mode):
30 |         if mode == 'train':
31 |             self.training_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode))
32 |             if args.distributed:
33 |                 self.train_sampler = torch.utils.data.distributed.DistributedSampler(self.training_samples)
34 |             else:
35 |                 self.train_sampler = None
36 |
37 |             self.data = DataLoader(self.training_samples, args.batch_size,
38 |                                    shuffle=(self.train_sampler is None),
39 |                                    num_workers=args.num_threads,
40 |                                    pin_memory=True,
41 |                                    sampler=self.train_sampler)
42 |
43 |         elif mode == 'online_eval':
44 |             self.testing_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode))
45 |             if args.distributed:
46 |                 # self.eval_sampler = torch.utils.data.distributed.DistributedSampler(self.testing_samples, shuffle=False)
47 |                 self.eval_sampler = DistributedSamplerNoEvenlyDivisible(self.testing_samples, shuffle=False)
48 |             else:
49 |                 self.eval_sampler = None
50 |             self.data = DataLoader(self.testing_samples, 1,
51 |                                    shuffle=False,
52 |                                    num_workers=1,
53 |                                    pin_memory=True,
54 |                                    sampler=self.eval_sampler)
55 |
56 |         elif mode == 'test':
57 |             self.testing_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode))
58 |             self.data = DataLoader(self.testing_samples, 1, shuffle=False, num_workers=1)
59 |
60 |         else:
61 |             print('mode should be one of \'train, test, online_eval\'. Got {}'.format(mode))
62 |
63 |
64 | class DataLoadPreprocess(Dataset):
65 |     def __init__(self, args, mode, transform=None, is_for_online_eval=False):
66 |         self.args = args
67 |         if mode == 'online_eval':
68 |             with open(args.filenames_file_eval, 'r') as f:
69 |                 self.filenames = f.readlines()
70 |         else:
71 |             with open(args.filenames_file, 'r') as f:
72 |                 self.filenames = f.readlines()
73 |
74 |         self.mode = mode
75 |         self.transform = transform
76 |         self.to_tensor = ToTensor
77 |         self.is_for_online_eval = is_for_online_eval
78 |
79 |     def __getitem__(self, idx):
80 |         sample_path = self.filenames[idx]
81 |         # focal = float(sample_path.split()[2])
82 |         focal = 518.8579
83 |
84 |         if self.mode == 'train':
85 |             if self.args.dataset == 'kitti':
86 |                 rgb_file = sample_path.split()[0]
87 |                 depth_file = os.path.join(sample_path.split()[0].split('/')[0], sample_path.split()[1])
88 |                 if self.args.use_right is True and random.random() > 0.5:
89 |                     rgb_file = rgb_file.replace('image_02', 'image_03')
90 |                     depth_file = depth_file.replace('image_02', 'image_03')
91 |             else:
92 |                 rgb_file = sample_path.split()[0]
93 |                 depth_file = sample_path.split()[1]
94 |
95 |             image_path = os.path.join(self.args.data_path, rgb_file)
96 |             depth_path = os.path.join(self.args.gt_path, depth_file)
97 |
98 |             image = Image.open(image_path)
99 |             depth_gt = Image.open(depth_path)
100 |
101 |             if self.args.do_kb_crop is True:
102 |                 height = image.height
103 |                 width = image.width
104 |                 top_margin = int(height - 352)
105 |                 left_margin = int((width - 1216) / 2)
106 |                 depth_gt = depth_gt.crop((left_margin, top_margin, left_margin + 1216, top_margin + 352))
107 |                 image = image.crop((left_margin, top_margin, left_margin + 1216, top_margin + 352))
108 |
109 |             # To avoid blank boundaries due to pixel registration
110 |             if self.args.dataset == 'nyu':
111 |                 if self.args.input_height == 480:
112 |                     depth_gt = np.array(depth_gt)
113 |                     valid_mask = np.zeros_like(depth_gt)
114 |                     valid_mask[45:472, 43:608] = 1
115 |                     depth_gt[valid_mask==0] = 0
116 |                     depth_gt = Image.fromarray(depth_gt)
117 |                 else:
118 |                     depth_gt = depth_gt.crop((43, 45, 608, 472))
119 |                     image = image.crop((43, 45, 608, 472))
120 |
121 |             if self.args.do_random_rotate is True:
122 |                 random_angle = (random.random() - 0.5) * 2 * self.args.degree
123 |                 image = self.rotate_image(image, random_angle)
124 |                 depth_gt = self.rotate_image(depth_gt, random_angle, flag=Image.NEAREST)
125 |
126 |             image = np.asarray(image, dtype=np.float32) / 255.0
127 |             depth_gt = np.asarray(depth_gt, dtype=np.float32)
128 |             depth_gt = np.expand_dims(depth_gt, axis=2)
129 |
130 |             if self.args.dataset == 'nyu':
131 |                 depth_gt = depth_gt / 1000.0
132 |             else:
133 |                 depth_gt = depth_gt / 256.0
134 |
135 |             if image.shape[0] != self.args.input_height or image.shape[1] != self.args.input_width:
136 |                 image, depth_gt = self.random_crop(image, depth_gt, self.args.input_height, self.args.input_width)
137 |             image, depth_gt = self.train_preprocess(image, depth_gt)
138 |             sample = {'image': image, 'depth': depth_gt, 'focal': focal}
139 |
140 |         else:
141 |             if self.mode == 'online_eval':
142 |                 data_path = self.args.data_path_eval
143 |             else:
144 |                 data_path = self.args.data_path
145 |
146 |             image_path = os.path.join(data_path, "./" + sample_path.split()[0])
147 |             image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
148 |
149 |             if self.mode == 'online_eval':
150 |                 gt_path = self.args.gt_path_eval
151 |                 depth_path = os.path.join(gt_path, "./" + sample_path.split()[1])
152 |                 if self.args.dataset == 'kitti':
153 |                     depth_path = os.path.join(gt_path, sample_path.split()[0].split('/')[0], sample_path.split()[1])
154 |                 has_valid_depth = False
155 |                 try:
156 |                     depth_gt = Image.open(depth_path)
157 |                     has_valid_depth = True
158 |                 except IOError:
159 |                     depth_gt = False
160 |                     # print('Missing gt for {}'.format(image_path))
161 |
162 |                 if has_valid_depth:
163 |                     depth_gt = np.asarray(depth_gt, dtype=np.float32)
164 |                     depth_gt = np.expand_dims(depth_gt, axis=2)
165 |                     if self.args.dataset == 'nyu':
166 |                         depth_gt = depth_gt / 1000.0
167 |                     else:
168 |                         depth_gt = depth_gt / 256.0
169 |
170 |             if self.args.do_kb_crop is True:
171 |                 height = image.shape[0]
172 |                 width = image.shape[1]
173 |                 top_margin = int(height - 352)
174 |                 left_margin = int((width - 1216) / 2)
175 |                 image = image[top_margin:top_margin + 352, left_margin:left_margin + 1216, :]
176 |                 if self.mode == 'online_eval' and has_valid_depth:
177 |                     depth_gt = depth_gt[top_margin:top_margin + 352, left_margin:left_margin + 1216, :]
178 |
179 |             if self.mode == 'online_eval':
180 |                 sample = {'image': image, 'depth': depth_gt, 'focal': focal, 'has_valid_depth': has_valid_depth}
181 |             else:
182 |                 sample = {'image': image, 'focal': focal}
183 |
184 |         if self.transform:
185 |             sample = self.transform(sample)
186 |
187 |         return sample
188 |
189 | def rotate_image(self, image, angle, flag=Image.BILINEAR):
190 | result = image.rotate(angle, resample=flag)
191 | return result
192 |
193 | def random_crop(self, img, depth, height, width):
194 | assert img.shape[0] >= height
195 | assert img.shape[1] >= width
196 | assert img.shape[0] == depth.shape[0]
197 | assert img.shape[1] == depth.shape[1]
198 | x = random.randint(0, img.shape[1] - width)
199 | y = random.randint(0, img.shape[0] - height)
200 | img = img[y:y + height, x:x + width, :]
201 | depth = depth[y:y + height, x:x + width, :]
202 | return img, depth
203 |
204 | def train_preprocess(self, image, depth_gt):
205 | # Random flipping
206 | do_flip = random.random()
207 | if do_flip > 0.5:
208 | image = (image[:, ::-1, :]).copy()
209 | depth_gt = (depth_gt[:, ::-1, :]).copy()
210 |
211 | # Random gamma, brightness, color augmentation
212 | do_augment = random.random()
213 | if do_augment > 0.5:
214 | image = self.augment_image(image)
215 |
216 | return image, depth_gt
217 |
218 | def augment_image(self, image):
219 | # gamma augmentation
220 | gamma = random.uniform(0.9, 1.1)
221 | image_aug = image ** gamma
222 |
223 | # brightness augmentation
224 | if self.args.dataset == 'nyu':
225 | brightness = random.uniform(0.75, 1.25)
226 | else:
227 | brightness = random.uniform(0.9, 1.1)
228 | image_aug = image_aug * brightness
229 |
230 | # color augmentation
231 | colors = np.random.uniform(0.9, 1.1, size=3)
232 | white = np.ones((image.shape[0], image.shape[1]))
233 | color_image = np.stack([white * colors[i] for i in range(3)], axis=2)
234 | image_aug *= color_image
235 | image_aug = np.clip(image_aug, 0, 1)
236 |
237 | return image_aug
238 |
239 | def __len__(self):
240 | return len(self.filenames)
241 |
242 |
243 | class ToTensor(object):
244 | def __init__(self, mode):
245 | self.mode = mode
246 | self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
247 |
248 | def __call__(self, sample):
249 | image, focal = sample['image'], sample['focal']
250 | image = self.to_tensor(image)
251 | image = self.normalize(image)
252 |
253 | if self.mode == 'test':
254 | return {'image': image, 'focal': focal}
255 |
256 | depth = sample['depth']
257 | if self.mode == 'train':
258 | depth = self.to_tensor(depth)
259 | return {'image': image, 'depth': depth, 'focal': focal}
260 | else:
261 | has_valid_depth = sample['has_valid_depth']
262 | return {'image': image, 'depth': depth, 'focal': focal, 'has_valid_depth': has_valid_depth}
263 |
264 | def to_tensor(self, pic):
265 | if not (_is_pil_image(pic) or _is_numpy_image(pic)):
266 | raise TypeError(
267 | 'pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
268 |
269 | if isinstance(pic, np.ndarray):
270 | img = torch.from_numpy(pic.transpose((2, 0, 1)))
271 | return img
272 |
273 | # handle PIL Image
274 | if pic.mode == 'I':
275 | img = torch.from_numpy(np.array(pic, np.int32, copy=False))
276 | elif pic.mode == 'I;16':
277 | img = torch.from_numpy(np.array(pic, np.int16, copy=False))
278 | else:
279 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
280 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
281 | if pic.mode == 'YCbCr':
282 | nchannel = 3
283 | elif pic.mode == 'I;16':
284 | nchannel = 1
285 | else:
286 | nchannel = len(pic.mode)
287 | img = img.view(pic.size[1], pic.size[0], nchannel)
288 |
289 | img = img.transpose(0, 1).transpose(0, 2).contiguous()
290 | if isinstance(img, torch.ByteTensor):
291 | return img.float()
292 | else:
293 | return img
294 |
--------------------------------------------------------------------------------
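
The loader above is driven from train.py through an argparse namespace. Below is a minimal usage sketch (not part of the repository), assuming it is run from the `newcrfs/` directory; every path and hyperparameter in the namespace is a placeholder for the values that normally come from `configs/arguments_train_nyu.txt`.

```python
from argparse import Namespace
from dataloaders.dataloader import NewDataLoader

# Hypothetical options; the real ones are parsed from the arguments_train_*.txt configs.
args = Namespace(
    mode='train', dataset='nyu',
    data_path='datasets/nyu/train', gt_path='datasets/nyu/train',            # placeholder paths
    filenames_file='../data_splits/nyudepthv2_train_files_with_gt_dense.txt',
    input_height=480, input_width=640,
    do_random_rotate=True, degree=2.5, do_kb_crop=False, use_right=False,
    batch_size=8, num_threads=4, distributed=False)

loader = NewDataLoader(args, 'train')
for batch in loader.data:
    # After ToTensor and default collation: image is [B, 3, 480, 640], depth is [B, 1, 480, 640]
    print(batch['image'].shape, batch['depth'].shape)
    break
```
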
/newcrfs/dataloaders/dataloader_kittipred.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 | import torch.utils.data.distributed
4 | from torchvision import transforms
5 |
6 | import numpy as np
7 | from PIL import Image
8 | import os
9 | import random
10 |
11 | from utils import DistributedSamplerNoEvenlyDivisible
12 |
13 |
14 | def _is_pil_image(img):
15 | return isinstance(img, Image.Image)
16 |
17 |
18 | def _is_numpy_image(img):
19 | return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
20 |
21 |
22 | def preprocessing_transforms(mode):
23 | return transforms.Compose([
24 | ToTensor(mode=mode)
25 | ])
26 |
27 |
28 | class NewDataLoader(object):
29 | def __init__(self, args, mode):
30 | if mode == 'train':
31 | self.training_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode))
32 | if args.distributed:
33 | self.train_sampler = torch.utils.data.distributed.DistributedSampler(self.training_samples)
34 | else:
35 | self.train_sampler = None
36 |
37 | self.data = DataLoader(self.training_samples, args.batch_size,
38 | shuffle=(self.train_sampler is None),
39 | num_workers=args.num_threads,
40 | pin_memory=True,
41 | sampler=self.train_sampler)
42 |
43 | elif mode == 'online_eval':
44 | self.testing_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode))
45 | if args.distributed:
46 | # self.eval_sampler = torch.utils.data.distributed.DistributedSampler(self.testing_samples, shuffle=False)
47 | self.eval_sampler = DistributedSamplerNoEvenlyDivisible(self.testing_samples, shuffle=False)
48 | else:
49 | self.eval_sampler = None
50 | self.data = DataLoader(self.testing_samples, 1,
51 | shuffle=False,
52 | num_workers=1,
53 | pin_memory=True,
54 | sampler=self.eval_sampler)
55 |
56 | elif mode == 'test':
57 | self.testing_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode))
58 | self.data = DataLoader(self.testing_samples, 1, shuffle=False, num_workers=1)
59 |
60 | else:
61 | print('mode should be one of \'train, test, online_eval\'. Got {}'.format(mode))
62 |
63 |
64 | class DataLoadPreprocess(Dataset):
65 | def __init__(self, args, mode, transform=None, is_for_online_eval=False):
66 | self.args = args
67 | if mode == 'online_eval':
68 | with open(args.filenames_file_eval, 'r') as f:
69 | self.filenames = f.readlines()
70 | else:
71 | with open(args.filenames_file, 'r') as f:
72 | self.filenames = f.readlines()
73 |
74 | self.mode = mode
75 | self.transform = transform
76 | self.to_tensor = ToTensor
77 | self.is_for_online_eval = is_for_online_eval
78 |
79 | def __getitem__(self, idx):
80 | sample_path = self.filenames[idx]
81 | # focal = float(sample_path.split()[2])
82 | focal = 518.8579
83 |
84 | if self.mode == 'train':
85 | rgb_file = sample_path.split()[0]
86 | depth_file = rgb_file.replace('/image_02/data/', '/proj_depth/groundtruth/image_02/')
87 | if self.args.use_right is True and random.random() > 0.5:
88 |                 rgb_file = rgb_file.replace('image_02', 'image_03')  # str.replace returns a new string; assign it back
89 |                 depth_file = depth_file.replace('image_02', 'image_03')
90 |
91 | image_path = os.path.join(self.args.data_path, rgb_file)
92 | depth_path = os.path.join(self.args.gt_path, depth_file)
93 |
94 | image = Image.open(image_path)
95 | depth_gt = Image.open(depth_path)
96 |
97 | if self.args.do_kb_crop is True:
98 | height = image.height
99 | width = image.width
100 | top_margin = int(height - 352)
101 | left_margin = int((width - 1216) / 2)
102 | depth_gt = depth_gt.crop((left_margin, top_margin, left_margin + 1216, top_margin + 352))
103 | image = image.crop((left_margin, top_margin, left_margin + 1216, top_margin + 352))
104 |
105 | if self.args.do_random_rotate is True:
106 | random_angle = (random.random() - 0.5) * 2 * self.args.degree
107 | image = self.rotate_image(image, random_angle)
108 | depth_gt = self.rotate_image(depth_gt, random_angle, flag=Image.NEAREST)
109 |
110 | image = np.asarray(image, dtype=np.float32) / 255.0
111 | depth_gt = np.asarray(depth_gt, dtype=np.float32)
112 | depth_gt = np.expand_dims(depth_gt, axis=2)
113 |
114 | depth_gt = depth_gt / 256.0
115 |
116 | if image.shape[0] != self.args.input_height or image.shape[1] != self.args.input_width:
117 | image, depth_gt = self.random_crop(image, depth_gt, self.args.input_height, self.args.input_width)
118 | image, depth_gt = self.train_preprocess(image, depth_gt)
119 | sample = {'image': image, 'depth': depth_gt, 'focal': focal}
120 |
121 | else:
122 | if self.mode == 'online_eval':
123 | data_path = self.args.data_path_eval
124 | else:
125 | data_path = self.args.data_path
126 |
127 | image_path = os.path.join(data_path, "./" + sample_path.split()[0])
128 | image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
129 |
130 | if self.mode == 'online_eval':
131 | gt_path = self.args.gt_path_eval
132 | depth_path = image_path.replace('/image/', '/groundtruth_depth/').replace('sync_image', 'sync_groundtruth_depth')
133 | has_valid_depth = False
134 | try:
135 | depth_gt = Image.open(depth_path)
136 | has_valid_depth = True
137 | except IOError:
138 | depth_gt = False
139 | # print('Missing gt for {}'.format(image_path))
140 |
141 | if has_valid_depth:
142 | depth_gt = np.asarray(depth_gt, dtype=np.float32)
143 | depth_gt = np.expand_dims(depth_gt, axis=2)
144 | if self.args.dataset == 'nyu':
145 | depth_gt = depth_gt / 1000.0
146 | else:
147 | depth_gt = depth_gt / 256.0
148 |
149 | if self.args.do_kb_crop is True:
150 | height = image.shape[0]
151 | width = image.shape[1]
152 | top_margin = int(height - 352)
153 | left_margin = int((width - 1216) / 2)
154 | image = image[top_margin:top_margin + 352, left_margin:left_margin + 1216, :]
155 | if self.mode == 'online_eval' and has_valid_depth:
156 | depth_gt = depth_gt[top_margin:top_margin + 352, left_margin:left_margin + 1216, :]
157 |
158 | if self.mode == 'online_eval':
159 | sample = {'image': image, 'depth': depth_gt, 'focal': focal, 'has_valid_depth': has_valid_depth}
160 | else:
161 | sample = {'image': image, 'focal': focal}
162 |
163 | if self.transform:
164 | sample = self.transform(sample)
165 |
166 | return sample
167 |
168 | def rotate_image(self, image, angle, flag=Image.BILINEAR):
169 | result = image.rotate(angle, resample=flag)
170 | return result
171 |
172 | def random_crop(self, img, depth, height, width):
173 | assert img.shape[0] >= height
174 | assert img.shape[1] >= width
175 | assert img.shape[0] == depth.shape[0]
176 | assert img.shape[1] == depth.shape[1]
177 | x = random.randint(0, img.shape[1] - width)
178 | y = random.randint(0, img.shape[0] - height)
179 | img = img[y:y + height, x:x + width, :]
180 | depth = depth[y:y + height, x:x + width, :]
181 | return img, depth
182 |
183 | def train_preprocess(self, image, depth_gt):
184 | # Random flipping
185 | do_flip = random.random()
186 | if do_flip > 0.5:
187 | image = (image[:, ::-1, :]).copy()
188 | depth_gt = (depth_gt[:, ::-1, :]).copy()
189 |
190 | # Random gamma, brightness, color augmentation
191 | do_augment = random.random()
192 | if do_augment > 0.5:
193 | image = self.augment_image(image)
194 |
195 | return image, depth_gt
196 |
197 | def augment_image(self, image):
198 | # gamma augmentation
199 | gamma = random.uniform(0.9, 1.1)
200 | image_aug = image ** gamma
201 |
202 | # brightness augmentation
203 | if self.args.dataset == 'nyu':
204 | brightness = random.uniform(0.75, 1.25)
205 | else:
206 | brightness = random.uniform(0.9, 1.1)
207 | image_aug = image_aug * brightness
208 |
209 | # color augmentation
210 | colors = np.random.uniform(0.9, 1.1, size=3)
211 | white = np.ones((image.shape[0], image.shape[1]))
212 | color_image = np.stack([white * colors[i] for i in range(3)], axis=2)
213 | image_aug *= color_image
214 | image_aug = np.clip(image_aug, 0, 1)
215 |
216 | return image_aug
217 |
218 | def __len__(self):
219 | return len(self.filenames)
220 |
221 |
222 | class ToTensor(object):
223 | def __init__(self, mode):
224 | self.mode = mode
225 | self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
226 |
227 | def __call__(self, sample):
228 | image, focal = sample['image'], sample['focal']
229 | image = self.to_tensor(image)
230 | image = self.normalize(image)
231 |
232 | if self.mode == 'test':
233 | return {'image': image, 'focal': focal}
234 |
235 | depth = sample['depth']
236 | if self.mode == 'train':
237 | depth = self.to_tensor(depth)
238 | return {'image': image, 'depth': depth, 'focal': focal}
239 | else:
240 | has_valid_depth = sample['has_valid_depth']
241 | return {'image': image, 'depth': depth, 'focal': focal, 'has_valid_depth': has_valid_depth}
242 |
243 | def to_tensor(self, pic):
244 | if not (_is_pil_image(pic) or _is_numpy_image(pic)):
245 | raise TypeError(
246 | 'pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
247 |
248 | if isinstance(pic, np.ndarray):
249 | img = torch.from_numpy(pic.transpose((2, 0, 1)))
250 | return img
251 |
252 | # handle PIL Image
253 | if pic.mode == 'I':
254 | img = torch.from_numpy(np.array(pic, np.int32, copy=False))
255 | elif pic.mode == 'I;16':
256 | img = torch.from_numpy(np.array(pic, np.int16, copy=False))
257 | else:
258 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
259 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
260 | if pic.mode == 'YCbCr':
261 | nchannel = 3
262 | elif pic.mode == 'I;16':
263 | nchannel = 1
264 | else:
265 | nchannel = len(pic.mode)
266 | img = img.view(pic.size[1], pic.size[0], nchannel)
267 |
268 | img = img.transpose(0, 1).transpose(0, 2).contiguous()
269 | if isinstance(img, torch.ByteTensor):
270 | return img.float()
271 | else:
272 | return img
273 |
--------------------------------------------------------------------------------
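
Unlike `dataloader.py`, which reads the ground-truth path from the second column of the filenames file, this loader derives it from the RGB path by string replacement (the training branch maps raw KITTI frames to the `proj_depth/groundtruth` folder; the online-eval branch maps benchmark `image` folders to `groundtruth_depth`). A pure-string illustration with a hypothetical KITTI path:

```python
# Hypothetical KITTI raw-data path; only the folder layout matters here.
rgb_file = '2011_09_26/2011_09_26_drive_0002_sync/image_02/data/0000000005.png'
depth_file = rgb_file.replace('/image_02/data/', '/proj_depth/groundtruth/image_02/')
print(depth_file)
# 2011_09_26/2011_09_26_drive_0002_sync/proj_depth/groundtruth/image_02/0000000005.png
```
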
/newcrfs/demo.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 | import torch
3 | import torch.nn as nn
4 | import torch.backends.cudnn as cudnn
5 | from torch.autograd import Variable
6 |
7 | import os
8 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
9 | import sys
10 | import time
11 | import argparse
12 | import numpy as np
13 |
14 | import cv2
15 | from scipy import ndimage
16 | from skimage.transform import resize
17 | import matplotlib.pyplot as plt
18 |
19 | plasma = plt.get_cmap('plasma')
20 | greys = plt.get_cmap('Greys')
21 |
22 | # UI and OpenGL
23 | from PySide2 import QtCore, QtGui, QtWidgets, QtOpenGL
24 | from OpenGL import GL, GLU
25 | from OpenGL.arrays import vbo
26 | from OpenGL.GL import shaders
27 | import glm
28 |
29 | from utils import post_process_depth, flip_lr
30 | from networks.NewCRFDepth import NewCRFDepth
31 |
32 |
33 | # Argument Parser
34 | parser = argparse.ArgumentParser(description='NeWCRFs Live 3D')
35 | parser.add_argument('--model_name', type=str, help='model name', default='newcrfs')
36 | parser.add_argument('--encoder', type=str, help='type of encoder, base07, large07', default='large07')
37 | parser.add_argument('--max_depth', type=float, help='maximum depth in estimation', default=10)
38 | parser.add_argument('--checkpoint_path', type=str, help='path to a checkpoint to load', required=True)
39 | parser.add_argument('--input_height', type=int, help='input height', default=480)
40 | parser.add_argument('--input_width', type=int, help='input width', default=640)
41 | parser.add_argument('--dataset', type=str, help='dataset this model trained on', default='nyu')
42 | parser.add_argument('--crop', type=str, help='crop: kbcrop, edge, non', default='non')
43 | parser.add_argument('--video', type=str, help='video path', default='')
44 |
45 | args = parser.parse_args()
46 |
47 | # Image shapes
48 | height_rgb, width_rgb = args.input_height, args.input_width
49 | height_depth, width_depth = height_rgb, width_rgb
50 |
51 |
52 | # =============== Intrinsics rectify ==================
53 | # Enable this if you have the real intrinsics of your camera
54 | Use_intrs_remap = False
55 | # Intrinsic parameters for your own webcam camera
56 | camera_matrix = np.zeros(shape=(3, 3))
57 | camera_matrix[0, 0] = 5.4765313594010649e+02
58 | camera_matrix[0, 2] = 3.2516069906172453e+02
59 | camera_matrix[1, 1] = 5.4801781476172562e+02
60 | camera_matrix[1, 2] = 2.4794113960783835e+02
61 | camera_matrix[2, 2] = 1
62 | dist_coeffs = np.array([ 3.7230261423972011e-02, -1.6171708069773008e-01, -3.5260752900266357e-04, 1.7161234226767313e-04, 1.0192711400840315e-01 ])
63 | # Parameters for a model trained on NYU Depth V2
64 | new_camera_matrix = np.zeros(shape=(3, 3))
65 | new_camera_matrix[0, 0] = 518.8579
66 | new_camera_matrix[0, 2] = 320
67 | new_camera_matrix[1, 1] = 518.8579
68 | new_camera_matrix[1, 2] = 240
69 | new_camera_matrix[2, 2] = 1
70 |
71 | R = np.identity(3, dtype=np.float64)  # np.float has been removed from recent NumPy versions
72 | map1, map2 = cv2.initUndistortRectifyMap(camera_matrix, dist_coeffs, R, new_camera_matrix, (width_rgb, height_rgb), cv2.CV_32FC1)
73 |
74 |
75 | def load_model():
76 | args.mode = 'test'
77 | model = NewCRFDepth(version='large07', inv_depth=False, max_depth=args.max_depth)
78 | model = torch.nn.DataParallel(model)
79 |
80 | checkpoint = torch.load(args.checkpoint_path)
81 | model.load_state_dict(checkpoint['model'])
82 | model.eval()
83 | model.cuda()
84 |
85 | return model
86 |
87 | # Function timing
88 | ticTime = time.time()
89 |
90 |
91 | def tic():
92 |     global ticTime
93 | ticTime = time.time()
94 |
95 |
96 | def toc():
97 | print('{0} seconds.'.format(time.time() - ticTime))
98 |
99 |
100 | # Conversion from Numpy to QImage and back
101 | def np_to_qimage(a):
102 | im = a.copy()
103 | return QtGui.QImage(im.data, im.shape[1], im.shape[0], im.strides[0], QtGui.QImage.Format_RGB888).copy()
104 |
105 |
106 | def qimage_to_np(img):
107 | img = img.convertToFormat(QtGui.QImage.Format.Format_ARGB32)
108 | return np.array(img.constBits()).reshape(img.height(), img.width(), 4)
109 |
110 |
111 | # Compute edge magnitudes
112 | def edges(d):
113 |     dx = ndimage.sobel(d, 0)  # derivative along rows (vertical image direction)
114 |     dy = ndimage.sobel(d, 1)  # derivative along columns (horizontal image direction)
115 | return np.abs(dx) + np.abs(dy)
116 |
117 |
118 | # Main window
119 | class Window(QtWidgets.QWidget):
120 | updateInput = QtCore.Signal()
121 |
122 | def __init__(self, parent=None):
123 | QtWidgets.QWidget.__init__(self, parent)
124 | self.model = None
125 | self.capture = None
126 | self.glWidget = GLWidget()
127 |
128 | mainLayout = QtWidgets.QVBoxLayout()
129 |
130 | # Input / output views
131 | viewsLayout = QtWidgets.QGridLayout()
132 | self.inputViewer = QtWidgets.QLabel("[Click to start]")
133 | self.inputViewer.setPixmap(QtGui.QPixmap(width_rgb, height_rgb))
134 | self.outputViewer = QtWidgets.QLabel("[Click to start]")
135 | self.outputViewer.setPixmap(QtGui.QPixmap(width_rgb, height_rgb))
136 |
137 | imgsFrame = QtWidgets.QFrame()
138 | inputsLayout = QtWidgets.QVBoxLayout()
139 | imgsFrame.setLayout(inputsLayout)
140 | inputsLayout.addWidget(self.inputViewer)
141 | inputsLayout.addWidget(self.outputViewer)
142 |
143 | viewsLayout.addWidget(imgsFrame, 0, 0)
144 | viewsLayout.addWidget(self.glWidget, 0, 1)
145 | viewsLayout.setColumnStretch(1, 10)
146 | mainLayout.addLayout(viewsLayout)
147 |
148 | # Load depth estimation model
149 | toolsLayout = QtWidgets.QHBoxLayout()
150 |
151 | self.button2 = QtWidgets.QPushButton("Webcam")
152 | self.button2.clicked.connect(self.loadCamera)
153 | toolsLayout.addWidget(self.button2)
154 |
155 | self.button3 = QtWidgets.QPushButton("Video")
156 | self.button3.clicked.connect(self.loadVideoFile)
157 | toolsLayout.addWidget(self.button3)
158 |
159 | self.button4 = QtWidgets.QPushButton("Pause")
160 | self.button4.clicked.connect(self.loadImage)
161 | toolsLayout.addWidget(self.button4)
162 |
163 | self.button6 = QtWidgets.QPushButton("Refresh")
164 | self.button6.clicked.connect(self.updateCloud)
165 | toolsLayout.addWidget(self.button6)
166 |
167 | mainLayout.addLayout(toolsLayout)
168 |
169 | self.setLayout(mainLayout)
170 | self.setWindowTitle(self.tr("NeWCRFs Live"))
171 |
172 | # Signals
173 | self.updateInput.connect(self.update_input)
174 |
175 | # Default example
176 | if self.glWidget.rgb.any() and self.glWidget.depth.any():
177 | img = (self.glWidget.rgb * 255).astype('uint8')
178 | self.inputViewer.setPixmap(QtGui.QPixmap.fromImage(np_to_qimage(img)))
179 | coloredDepth = (plasma(self.glWidget.depth[:, :, 0])[:, :, :3] * 255).astype('uint8')
180 | self.outputViewer.setPixmap(QtGui.QPixmap.fromImage(np_to_qimage(coloredDepth)))
181 |
182 | def loadModel(self):
183 | print('== loadModel')
184 | QtGui.QGuiApplication.setOverrideCursor(QtCore.Qt.WaitCursor)
185 | tic()
186 | self.model = load_model()
187 | print('Model loaded.')
188 | toc()
189 | self.updateCloud()
190 | QtGui.QGuiApplication.restoreOverrideCursor()
191 |
192 | def loadCamera(self):
193 | print('== loadCamera')
194 | tic()
195 | self.model = load_model()
196 | print('Model loaded.')
197 | toc()
198 | self.capture = cv2.VideoCapture(0)
199 | self.updateInput.emit()
200 |
201 | def loadVideoFile(self):
202 | print('== loadVideoFile')
203 | self.model = load_model()
204 | self.capture = cv2.VideoCapture(args.video)
205 | self.updateInput.emit()
206 |
207 | def loadImage(self):
208 | print('== loadImage')
209 | self.capture = None
210 | img = (self.glWidget.rgb * 255).astype('uint8')
211 | self.inputViewer.setPixmap(QtGui.QPixmap.fromImage(np_to_qimage(img)))
212 | self.updateCloud()
213 |
214 | def loadImageFile(self):
215 | print('== loadImageFile')
216 | self.capture = None
217 | filename = \
218 | QtWidgets.QFileDialog.getOpenFileName(None, 'Select image', '', self.tr('Image files (*.jpg *.png)'))[0]
219 | img = QtGui.QImage(filename).scaledToHeight(height_rgb)
220 | xstart = 0
221 | if img.width() > width_rgb: xstart = (img.width() - width_rgb) // 2
222 |         img = img.copy(xstart, 0, width_rgb, height_rgb)  # QImage.copy expects (x, y, width, height)
223 | self.inputViewer.setPixmap(QtGui.QPixmap.fromImage(img))
224 | print('== loadImageFile')
225 | self.updateCloud()
226 |
227 | def update_input(self):
228 | print('== update_input')
229 | # Don't update anymore if no capture device is set
230 |         if self.capture is None:
231 | return
232 |
233 | # Capture a frame
234 | ret, frame = self.capture.read()
235 |
236 | # Loop video playback if current stream is video file
237 | if not ret:
238 | self.capture.set(cv2.CAP_PROP_POS_FRAMES, 0)
239 | ret, frame = self.capture.read()
240 |
241 | # Prepare image and show in UI
242 | if Use_intrs_remap:
243 | frame_ud = cv2.remap(frame, map1, map2, interpolation=cv2.INTER_LINEAR)
244 | else:
245 | frame_ud = cv2.resize(frame, (width_rgb, height_rgb), interpolation=cv2.INTER_LINEAR)
246 | frame = cv2.cvtColor(frame_ud, cv2.COLOR_BGR2RGB)
247 | image = np_to_qimage(frame)
248 | self.inputViewer.setPixmap(QtGui.QPixmap.fromImage(image))
249 |
250 | # Update the point cloud
251 | self.updateCloud()
252 |
253 | def updateCloud(self):
254 | print('== updateCloud')
255 | rgb8 = qimage_to_np(self.inputViewer.pixmap().toImage())
256 | self.glWidget.rgb = (rgb8[:, :, :3] / 255)[:, :, ::-1]
257 |
258 | if self.model:
259 | input_image = rgb8[:, :, :3].astype(np.float32)
260 |
261 | # Normalize image
262 | input_image[:, :, 0] = (input_image[:, :, 0] - 123.68) * 0.017
263 | input_image[:, :, 1] = (input_image[:, :, 1] - 116.78) * 0.017
264 | input_image[:, :, 2] = (input_image[:, :, 2] - 103.94) * 0.017
265 |
266 | H, W, _ = input_image.shape
267 | if args.crop == 'kbcrop':
268 | top_margin = int(H - 352)
269 | left_margin = int((W - 1216) / 2)
270 | input_image_cropped = input_image[top_margin:top_margin + 352,
271 | left_margin:left_margin + 1216]
272 | elif args.crop == 'edge':
273 | input_image_cropped = input_image[32:-32, 32:-32, :]
274 | else:
275 | input_image_cropped = input_image
276 |
277 | input_images = np.expand_dims(input_image_cropped, axis=0)
278 | input_images = np.transpose(input_images, (0, 3, 1, 2))
279 |
280 | with torch.no_grad():
281 | image = Variable(torch.from_numpy(input_images)).cuda()
282 | # Predict
283 | depth_est = self.model(image)
284 | post_process = True
285 | if post_process:
286 | image_flipped = flip_lr(image)
287 | depth_est_flipped = self.model(image_flipped)
288 | depth_cropped = post_process_depth(depth_est, depth_est_flipped)
289 |
290 | depth = np.zeros((height_depth, width_depth), dtype=np.float32)
291 | if args.crop == 'kbcrop':
292 | depth[top_margin:top_margin + 352, left_margin:left_margin + 1216] = \
293 | depth_cropped[0].cpu().squeeze() / args.max_depth
294 | elif args.crop == 'edge':
295 | depth[32:-32, 32:-32] = depth_cropped[0].cpu().squeeze() / args.max_depth
296 | else:
297 | depth[:, :] = depth_cropped[0].cpu().squeeze() / args.max_depth
298 |
299 | coloredDepth = (greys(np.log10(depth * args.max_depth))[:, :, :3] * 255).astype('uint8')
300 | self.outputViewer.setPixmap(QtGui.QPixmap.fromImage(np_to_qimage(coloredDepth)))
301 | self.glWidget.depth = depth
302 |
303 | else:
304 |             self.glWidget.depth = 0.5 + np.zeros((height_depth, width_depth, 1), dtype=np.float32)  # match the precomputed self.xx / self.yy grid
305 |
306 | self.glWidget.updateRGBD()
307 | self.glWidget.updateGL()
308 |
309 | # Update to next frame if we are live
310 | QtCore.QTimer.singleShot(10, self.updateInput)
311 |
312 |
313 | class GLWidget(QtOpenGL.QGLWidget):
314 | def __init__(self, parent=None):
315 | QtOpenGL.QGLWidget.__init__(self, parent)
316 |
317 | self.object = 0
318 | self.xRot = 5040
319 | self.yRot = 40
320 | self.zRot = 0
321 | self.zoomLevel = 9
322 |
323 | self.lastPos = QtCore.QPoint()
324 |
325 | self.green = QtGui.QColor.fromCmykF(0.0, 0.0, 0.0, 1.0)
326 | self.black = QtGui.QColor.fromCmykF(0.0, 0.0, 0.0, 1.0)
327 |
328 | # Precompute for world coordinates
329 | self.xx, self.yy = self.worldCoords(width=width_rgb, height=height_rgb)
330 |
331 | self.rgb = np.zeros((height_rgb, width_rgb, 3), dtype=np.uint8)
332 |         self.depth = np.zeros((height_depth, width_depth), dtype=np.float32)
333 |
334 | self.col_vbo = None
335 | self.pos_vbo = None
336 | if self.rgb.any() and self.depth.any():
337 | self.updateRGBD()
338 |
339 | def xRotation(self):
340 | return self.xRot
341 |
342 | def yRotation(self):
343 | return self.yRot
344 |
345 | def zRotation(self):
346 | return self.zRot
347 |
348 | def minimumSizeHint(self):
349 | return QtCore.QSize(height_rgb, width_rgb)
350 |
351 | def sizeHint(self):
352 | return QtCore.QSize(height_rgb, width_rgb)
353 |
354 | def setXRotation(self, angle):
355 | if angle != self.xRot:
356 | self.xRot = angle
357 | self.emit(QtCore.SIGNAL("xRotationChanged(int)"), angle)
358 | self.updateGL()
359 |
360 | def setYRotation(self, angle):
361 | if angle != self.yRot:
362 | self.yRot = angle
363 | self.emit(QtCore.SIGNAL("yRotationChanged(int)"), angle)
364 | self.updateGL()
365 |
366 | def setZRotation(self, angle):
367 | if angle != self.zRot:
368 | self.zRot = angle
369 | self.emit(QtCore.SIGNAL("zRotationChanged(int)"), angle)
370 | self.updateGL()
371 |
372 | def resizeGL(self, width, height):
373 | GL.glViewport(0, 0, width, height)
374 |
375 | def mousePressEvent(self, event):
376 | self.lastPos = QtCore.QPoint(event.pos())
377 |
378 | def mouseMoveEvent(self, event):
379 | dx = -(event.x() - self.lastPos.x())
380 | dy = (event.y() - self.lastPos.y())
381 |
382 | if event.buttons() & QtCore.Qt.LeftButton:
383 | self.setXRotation(self.xRot + dy)
384 | self.setYRotation(self.yRot + dx)
385 | elif event.buttons() & QtCore.Qt.RightButton:
386 | self.setXRotation(self.xRot + dy)
387 | self.setZRotation(self.zRot + dx)
388 |
389 | self.lastPos = QtCore.QPoint(event.pos())
390 |
391 | def wheelEvent(self, event):
392 | numDegrees = event.delta() / 8
393 | numSteps = numDegrees / 15
394 | self.zoomLevel = self.zoomLevel + numSteps
395 | event.accept()
396 | self.updateGL()
397 |
398 | def initializeGL(self):
399 | self.qglClearColor(self.black.darker())
400 | GL.glShadeModel(GL.GL_FLAT)
401 | GL.glEnable(GL.GL_DEPTH_TEST)
402 | GL.glEnable(GL.GL_CULL_FACE)
403 |
404 | VERTEX_SHADER = shaders.compileShader("""#version 330
405 | layout(location = 0) in vec3 position;
406 | layout(location = 1) in vec3 color;
407 | uniform mat4 mvp; out vec4 frag_color;
408 | void main() {gl_Position = mvp * vec4(position, 1.0);frag_color = vec4(color, 1.0);}""", GL.GL_VERTEX_SHADER)
409 |
410 | FRAGMENT_SHADER = shaders.compileShader("""#version 330
411 | in vec4 frag_color; out vec4 out_color;
412 | void main() {out_color = frag_color;}""", GL.GL_FRAGMENT_SHADER)
413 |
414 | self.shaderProgram = shaders.compileProgram(VERTEX_SHADER, FRAGMENT_SHADER)
415 |
416 | self.UNIFORM_LOCATIONS = {
417 | 'position': GL.glGetAttribLocation(self.shaderProgram, 'position'),
418 | 'color': GL.glGetAttribLocation(self.shaderProgram, 'color'),
419 | 'mvp': GL.glGetUniformLocation(self.shaderProgram, 'mvp'),
420 | }
421 |
422 | shaders.glUseProgram(self.shaderProgram)
423 |
424 | def paintGL(self):
425 | if self.rgb.any() and self.depth.any():
426 | GL.glClear(GL.GL_COLOR_BUFFER_BIT | GL.GL_DEPTH_BUFFER_BIT)
427 | self.drawObject()
428 |
429 | def worldCoords(self, width, height):
430 | cx, cy = width / 2, height / 2
431 | fx = 518.8579
432 | fy = 518.8579
433 | xx, yy = np.tile(range(width), height), np.repeat(range(height), width)
434 | xx = (xx - cx) / fx
435 | yy = (yy - cy) / fy
436 | return xx, yy
437 |
438 | def posFromDepth(self, depth):
439 | length = depth.shape[0] * depth.shape[1]
440 |
441 | depth[edges(depth) > 0.3] = 1e6 # Hide depth edges
442 | z = depth.reshape(length)
443 |
444 | return np.dstack((self.xx * z, self.yy * z, z)).reshape((length, 3))
445 |
446 | def createPointCloudVBOfromRGBD(self):
447 | # Create position and color VBOs
448 | self.pos_vbo = vbo.VBO(data=self.pos, usage=GL.GL_DYNAMIC_DRAW, target=GL.GL_ARRAY_BUFFER)
449 | self.col_vbo = vbo.VBO(data=self.col, usage=GL.GL_DYNAMIC_DRAW, target=GL.GL_ARRAY_BUFFER)
450 |
451 | def updateRGBD(self):
452 | # RGBD dimensions
453 | width, height = self.depth.shape[1], self.depth.shape[0]
454 |
455 | # Reshape
456 | points = self.posFromDepth(self.depth.copy())
457 | colors = resize(self.rgb, (height, width)).reshape((height * width, 3))
458 |
459 | # Flatten and convert to float32
460 | self.pos = points.astype('float32')
461 | self.col = colors.reshape(height * width, 3).astype('float32')
462 |
463 | # Move center of scene
464 | self.pos = self.pos + glm.vec3(0, -0.06, -0.3)
465 |
466 | # Create VBOs
467 | if not self.col_vbo:
468 | self.createPointCloudVBOfromRGBD()
469 |
470 | def drawObject(self):
471 | # Update camera
472 | model, view, proj = glm.mat4(1), glm.mat4(1), glm.perspective(45, self.width() / self.height(), 0.01, 100)
473 | center, up, eye = glm.vec3(0, -0.075, 0), glm.vec3(0, -1, 0), glm.vec3(0, 0, -0.4 * (self.zoomLevel / 10))
474 | view = glm.lookAt(eye, center, up)
475 | model = glm.rotate(model, self.xRot / 160.0, glm.vec3(1, 0, 0))
476 | model = glm.rotate(model, self.yRot / 160.0, glm.vec3(0, 1, 0))
477 | model = glm.rotate(model, self.zRot / 160.0, glm.vec3(0, 0, 1))
478 | mvp = proj * view * model
479 | GL.glUniformMatrix4fv(self.UNIFORM_LOCATIONS['mvp'], 1, False, glm.value_ptr(mvp))
480 |
481 | # Update data
482 | self.pos_vbo.set_array(self.pos)
483 | self.col_vbo.set_array(self.col)
484 |
485 | # Point size
486 | GL.glPointSize(2)
487 |
488 | # Position
489 | self.pos_vbo.bind()
490 | GL.glEnableVertexAttribArray(0)
491 | GL.glVertexAttribPointer(0, 3, GL.GL_FLOAT, GL.GL_FALSE, 0, None)
492 |
493 | # Color
494 | self.col_vbo.bind()
495 | GL.glEnableVertexAttribArray(1)
496 | GL.glVertexAttribPointer(1, 3, GL.GL_FLOAT, GL.GL_FALSE, 0, None)
497 |
498 | # Draw
499 | GL.glDrawArrays(GL.GL_POINTS, 0, self.pos.shape[0])
500 |
501 | if __name__ == '__main__':
502 | app = QtWidgets.QApplication(sys.argv)
503 | window = Window()
504 | window.show()
505 | res = app.exec_()
--------------------------------------------------------------------------------
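
The Qt/OpenGL code aside, the core of `demo.py` is the inference path in `Window.updateCloud`: normalize the raw RGB frame, run the model, and average the prediction with a horizontally flipped pass. The sketch below reproduces just that path; it assumes it is run from the `newcrfs/` directory with a CUDA device, that the checkpoint path is a placeholder, and that the input is a 640x480 NYU-style frame (the sample image in `files/`).

```python
import cv2
import numpy as np
import torch

from utils import post_process_depth, flip_lr
from networks.NewCRFDepth import NewCRFDepth

model = torch.nn.DataParallel(NewCRFDepth(version='large07', inv_depth=False, max_depth=10))
model.load_state_dict(torch.load('model_zoo/model_nyu.ckpt', map_location='cpu')['model'])  # placeholder checkpoint
model.eval().cuda()

frame = cv2.cvtColor(cv2.imread('../files/office_00633.jpg'), cv2.COLOR_BGR2RGB).astype(np.float32)
frame = (frame - np.array([123.68, 116.78, 103.94], dtype=np.float32)) * 0.017  # same normalization as updateCloud

x = torch.from_numpy(frame.transpose(2, 0, 1)[None]).cuda()
with torch.no_grad():
    depth = post_process_depth(model(x), model(flip_lr(x)))  # flip-averaged prediction, shape [1, 1, H, W]
print(depth.squeeze().cpu().numpy().shape)
```
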
/newcrfs/eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.backends.cudnn as cudnn
3 |
4 | import os, sys
5 | import argparse
6 | import numpy as np
7 | from tqdm import tqdm
8 |
9 | from utils import post_process_depth, flip_lr, compute_errors
10 | from networks.NewCRFDepth import NewCRFDepth
11 |
12 |
13 | def convert_arg_line_to_args(arg_line):
14 | for arg in arg_line.split():
15 | if not arg.strip():
16 | continue
17 | yield arg
18 |
19 |
20 | parser = argparse.ArgumentParser(description='NeWCRFs PyTorch implementation.', fromfile_prefix_chars='@')
21 | parser.convert_arg_line_to_args = convert_arg_line_to_args
22 |
23 | parser.add_argument('--model_name', type=str, help='model name', default='newcrfs')
24 | parser.add_argument('--encoder', type=str, help='type of encoder, base07, large07', default='large07')
25 | parser.add_argument('--checkpoint_path', type=str, help='path to a checkpoint to load', default='')
26 |
27 | # Dataset
28 | parser.add_argument('--dataset', type=str, help='dataset to train on, kitti or nyu', default='nyu')
29 | parser.add_argument('--input_height', type=int, help='input height', default=480)
30 | parser.add_argument('--input_width', type=int, help='input width', default=640)
31 | parser.add_argument('--max_depth', type=float, help='maximum depth in estimation', default=10)
32 |
33 | # Preprocessing
34 | parser.add_argument('--do_random_rotate', help='if set, will perform random rotation for augmentation', action='store_true')
35 | parser.add_argument('--degree', type=float, help='random rotation maximum degree', default=2.5)
36 | parser.add_argument('--do_kb_crop', help='if set, crop input images as kitti benchmark images', action='store_true')
37 | parser.add_argument('--use_right', help='if set, will randomly use right images when train on KITTI', action='store_true')
38 |
39 | # Eval
40 | parser.add_argument('--data_path_eval', type=str, help='path to the data for evaluation', required=False)
41 | parser.add_argument('--gt_path_eval', type=str, help='path to the groundtruth data for evaluation', required=False)
42 | parser.add_argument('--filenames_file_eval', type=str, help='path to the filenames text file for evaluation', required=False)
43 | parser.add_argument('--min_depth_eval', type=float, help='minimum depth for evaluation', default=1e-3)
44 | parser.add_argument('--max_depth_eval', type=float, help='maximum depth for evaluation', default=80)
45 | parser.add_argument('--eigen_crop', help='if set, crops according to Eigen NIPS14', action='store_true')
46 | parser.add_argument('--garg_crop', help='if set, crops according to Garg ECCV16', action='store_true')
47 |
48 |
49 | if len(sys.argv) == 2:
50 | arg_filename_with_prefix = '@' + sys.argv[1]
51 | args = parser.parse_args([arg_filename_with_prefix])
52 | else:
53 | args = parser.parse_args()
54 |
55 | if args.dataset == 'kitti' or args.dataset == 'nyu':
56 | from dataloaders.dataloader import NewDataLoader
57 | elif args.dataset == 'kittipred':
58 | from dataloaders.dataloader_kittipred import NewDataLoader
59 |
60 |
61 | def eval(model, dataloader_eval, post_process=False):
62 | eval_measures = torch.zeros(10).cuda()
63 | for _, eval_sample_batched in enumerate(tqdm(dataloader_eval.data)):
64 | with torch.no_grad():
65 | image = torch.autograd.Variable(eval_sample_batched['image'].cuda())
66 | gt_depth = eval_sample_batched['depth']
67 | has_valid_depth = eval_sample_batched['has_valid_depth']
68 | if not has_valid_depth:
69 | # print('Invalid depth. continue.')
70 | continue
71 |
72 | pred_depth = model(image)
73 | if post_process:
74 | image_flipped = flip_lr(image)
75 | pred_depth_flipped = model(image_flipped)
76 | pred_depth = post_process_depth(pred_depth, pred_depth_flipped)
77 |
78 | pred_depth = pred_depth.cpu().numpy().squeeze()
79 | gt_depth = gt_depth.cpu().numpy().squeeze()
80 |
81 | if args.do_kb_crop:
82 | height, width = gt_depth.shape
83 | top_margin = int(height - 352)
84 | left_margin = int((width - 1216) / 2)
85 | pred_depth_uncropped = np.zeros((height, width), dtype=np.float32)
86 | pred_depth_uncropped[top_margin:top_margin + 352, left_margin:left_margin + 1216] = pred_depth
87 | pred_depth = pred_depth_uncropped
88 |
89 | pred_depth[pred_depth < args.min_depth_eval] = args.min_depth_eval
90 | pred_depth[pred_depth > args.max_depth_eval] = args.max_depth_eval
91 | pred_depth[np.isinf(pred_depth)] = args.max_depth_eval
92 | pred_depth[np.isnan(pred_depth)] = args.min_depth_eval
93 |
94 | valid_mask = np.logical_and(gt_depth > args.min_depth_eval, gt_depth < args.max_depth_eval)
95 |
96 | if args.garg_crop or args.eigen_crop:
97 | gt_height, gt_width = gt_depth.shape
98 | eval_mask = np.zeros(valid_mask.shape)
99 |
100 | if args.garg_crop:
101 | eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height), int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1
102 |
103 | elif args.eigen_crop:
104 | if args.dataset == 'kitti':
105 | eval_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height), int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = 1
106 | elif args.dataset == 'nyu':
107 | eval_mask[45:471, 41:601] = 1
108 |
109 | valid_mask = np.logical_and(valid_mask, eval_mask)
110 |
111 | measures = compute_errors(gt_depth[valid_mask], pred_depth[valid_mask])
112 |
113 | eval_measures[:9] += torch.tensor(measures).cuda()
114 | eval_measures[9] += 1
115 |
116 | eval_measures_cpu = eval_measures.cpu()
117 | cnt = eval_measures_cpu[9].item()
118 | eval_measures_cpu /= cnt
119 | print('Computing errors for {} eval samples'.format(int(cnt)), ', post_process: ', post_process)
120 | print("{:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}".format('silog', 'abs_rel', 'log10', 'rms',
121 | 'sq_rel', 'log_rms', 'd1', 'd2',
122 | 'd3'))
123 | for i in range(8):
124 | print('{:7.4f}, '.format(eval_measures_cpu[i]), end='')
125 | print('{:7.4f}'.format(eval_measures_cpu[8]))
126 | return eval_measures_cpu
127 |
128 |
129 | def main_worker(args):
130 |
131 | # CRF model
132 | model = NewCRFDepth(version=args.encoder, inv_depth=False, max_depth=args.max_depth, pretrained=None)
133 | model.train()
134 |
135 | num_params = sum([np.prod(p.size()) for p in model.parameters()])
136 | print("== Total number of parameters: {}".format(num_params))
137 |
138 | num_params_update = sum([np.prod(p.shape) for p in model.parameters() if p.requires_grad])
139 | print("== Total number of learning parameters: {}".format(num_params_update))
140 |
141 | model = torch.nn.DataParallel(model)
142 | model.cuda()
143 |
144 | print("== Model Initialized")
145 |
146 | if args.checkpoint_path != '':
147 | if os.path.isfile(args.checkpoint_path):
148 | print("== Loading checkpoint '{}'".format(args.checkpoint_path))
149 | checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
150 | model.load_state_dict(checkpoint['model'])
151 | print("== Loaded checkpoint '{}'".format(args.checkpoint_path))
152 | del checkpoint
153 | else:
154 | print("== No checkpoint found at '{}'".format(args.checkpoint_path))
155 |
156 | cudnn.benchmark = True
157 |
158 | dataloader_eval = NewDataLoader(args, 'online_eval')
159 |
160 | # ===== Evaluation ======
161 | model.eval()
162 | with torch.no_grad():
163 | eval_measures = eval(model, dataloader_eval, post_process=True)
164 |
165 |
166 | def main():
167 | torch.cuda.empty_cache()
168 | args.distributed = False
169 | ngpus_per_node = torch.cuda.device_count()
170 | if ngpus_per_node > 1:
171 | print("This machine has more than 1 gpu. Please set \'CUDA_VISIBLE_DEVICES=0\'")
172 | return -1
173 |
174 | main_worker(args)
175 |
176 |
177 | if __name__ == '__main__':
178 | main()
179 |
--------------------------------------------------------------------------------
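
`eval.py` is normally launched with a single config-file argument (for example `python newcrfs/eval.py configs/arguments_eval_nyu.txt`, assuming the paths inside that file are set up); lines 49-53 turn that argument into an `@`-prefixed fromfile argument so argparse reads the options from it. The masking logic in `eval()` is the part most worth illustrating; below is a self-contained sketch with placeholder depth values and the Garg crop fractions used above.

```python
import numpy as np

# Random placeholder ground truth at KITTI benchmark resolution.
gt_depth = np.random.uniform(0.0, 90.0, size=(352, 1216)).astype(np.float32)
min_depth_eval, max_depth_eval = 1e-3, 80.0

# Keep only pixels with ground truth inside the evaluation depth range.
valid_mask = np.logical_and(gt_depth > min_depth_eval, gt_depth < max_depth_eval)

# Garg crop: restrict evaluation to the central image region, as in eval() above.
gt_height, gt_width = gt_depth.shape
eval_mask = np.zeros_like(valid_mask)
eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height),
          int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = True
valid_mask = np.logical_and(valid_mask, eval_mask)

print(valid_mask.sum(), 'of', valid_mask.size, 'pixels would enter compute_errors')
```
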
/newcrfs/networks/NewCRFDepth.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .swin_transformer import SwinTransformer
6 | from .newcrf_layers import NewCRF
7 | from .uper_crf_head import PSP
8 | ########################################################################################################################
9 |
10 |
11 | class NewCRFDepth(nn.Module):
12 | """
13 | Depth network based on neural window FC-CRFs architecture.
14 | """
15 | def __init__(self, version=None, inv_depth=False, pretrained=None,
16 | frozen_stages=-1, min_depth=0.1, max_depth=100.0, **kwargs):
17 | super().__init__()
18 |
19 | self.inv_depth = inv_depth
20 | self.with_auxiliary_head = False
21 | self.with_neck = False
22 |
23 | norm_cfg = dict(type='BN', requires_grad=True)
24 | # norm_cfg = dict(type='GN', requires_grad=True, num_groups=8)
25 |
26 | window_size = int(version[-2:])
27 |
28 | if version[:-2] == 'base':
29 | embed_dim = 128
30 | depths = [2, 2, 18, 2]
31 | num_heads = [4, 8, 16, 32]
32 | in_channels = [128, 256, 512, 1024]
33 | elif version[:-2] == 'large':
34 | embed_dim = 192
35 | depths = [2, 2, 18, 2]
36 | num_heads = [6, 12, 24, 48]
37 | in_channels = [192, 384, 768, 1536]
38 | elif version[:-2] == 'tiny':
39 | embed_dim = 96
40 | depths = [2, 2, 6, 2]
41 | num_heads = [3, 6, 12, 24]
42 | in_channels = [96, 192, 384, 768]
43 |
44 | backbone_cfg = dict(
45 | embed_dim=embed_dim,
46 | depths=depths,
47 | num_heads=num_heads,
48 | window_size=window_size,
49 | ape=False,
50 | drop_path_rate=0.3,
51 | patch_norm=True,
52 | use_checkpoint=False,
53 | frozen_stages=frozen_stages
54 | )
55 |
56 | embed_dim = 512
57 | decoder_cfg = dict(
58 | in_channels=in_channels,
59 | in_index=[0, 1, 2, 3],
60 | pool_scales=(1, 2, 3, 6),
61 | channels=embed_dim,
62 | dropout_ratio=0.0,
63 | num_classes=32,
64 | norm_cfg=norm_cfg,
65 | align_corners=False
66 | )
67 |
68 | self.backbone = SwinTransformer(**backbone_cfg)
69 | v_dim = decoder_cfg['num_classes']*4
70 | win = 7
71 | crf_dims = [128, 256, 512, 1024]
72 | v_dims = [64, 128, 256, embed_dim]
73 | self.crf3 = NewCRF(input_dim=in_channels[3], embed_dim=crf_dims[3], window_size=win, v_dim=v_dims[3], num_heads=32)
74 | self.crf2 = NewCRF(input_dim=in_channels[2], embed_dim=crf_dims[2], window_size=win, v_dim=v_dims[2], num_heads=16)
75 | self.crf1 = NewCRF(input_dim=in_channels[1], embed_dim=crf_dims[1], window_size=win, v_dim=v_dims[1], num_heads=8)
76 | self.crf0 = NewCRF(input_dim=in_channels[0], embed_dim=crf_dims[0], window_size=win, v_dim=v_dims[0], num_heads=4)
77 |
78 | self.decoder = PSP(**decoder_cfg)
79 | self.disp_head1 = DispHead(input_dim=crf_dims[0])
80 |
81 | self.up_mode = 'bilinear'
82 | if self.up_mode == 'mask':
83 | self.mask_head = nn.Sequential(
84 | nn.Conv2d(crf_dims[0], 64, 3, padding=1),
85 | nn.ReLU(inplace=True),
86 | nn.Conv2d(64, 16*9, 1, padding=0))
87 |
88 | self.min_depth = min_depth
89 | self.max_depth = max_depth
90 |
91 | self.init_weights(pretrained=pretrained)
92 |
93 | def init_weights(self, pretrained=None):
94 | """Initialize the weights in backbone and heads.
95 |
96 | Args:
97 | pretrained (str, optional): Path to pre-trained weights.
98 | Defaults to None.
99 | """
100 | print(f'== Load encoder backbone from: {pretrained}')
101 | self.backbone.init_weights(pretrained=pretrained)
102 | self.decoder.init_weights()
103 | if self.with_auxiliary_head:
104 | if isinstance(self.auxiliary_head, nn.ModuleList):
105 | for aux_head in self.auxiliary_head:
106 | aux_head.init_weights()
107 | else:
108 | self.auxiliary_head.init_weights()
109 |
110 | def upsample_mask(self, disp, mask):
111 | """ Upsample disp [H/4, W/4, 1] -> [H, W, 1] using convex combination """
112 | N, _, H, W = disp.shape
113 | mask = mask.view(N, 1, 9, 4, 4, H, W)
114 | mask = torch.softmax(mask, dim=2)
115 |
116 | up_disp = F.unfold(disp, kernel_size=3, padding=1)
117 | up_disp = up_disp.view(N, 1, 9, 1, 1, H, W)
118 |
119 | up_disp = torch.sum(mask * up_disp, dim=2)
120 | up_disp = up_disp.permute(0, 1, 4, 2, 5, 3)
121 | return up_disp.reshape(N, 1, 4*H, 4*W)
122 |
123 | def forward(self, imgs):
124 |
125 | feats = self.backbone(imgs)
126 | if self.with_neck:
127 | feats = self.neck(feats)
128 |
129 | ppm_out = self.decoder(feats)
130 |
131 | e3 = self.crf3(feats[3], ppm_out)
132 | e3 = nn.PixelShuffle(2)(e3)
133 | e2 = self.crf2(feats[2], e3)
134 | e2 = nn.PixelShuffle(2)(e2)
135 | e1 = self.crf1(feats[1], e2)
136 | e1 = nn.PixelShuffle(2)(e1)
137 | e0 = self.crf0(feats[0], e1)
138 |
139 | if self.up_mode == 'mask':
140 | mask = self.mask_head(e0)
141 | d1 = self.disp_head1(e0, 1)
142 | d1 = self.upsample_mask(d1, mask)
143 | else:
144 | d1 = self.disp_head1(e0, 4)
145 |
146 | depth = d1 * self.max_depth
147 |
148 | return depth
149 |
150 |
151 | class DispHead(nn.Module):
152 | def __init__(self, input_dim=100):
153 | super(DispHead, self).__init__()
154 | # self.norm1 = nn.BatchNorm2d(input_dim)
155 | self.conv1 = nn.Conv2d(input_dim, 1, 3, padding=1)
156 | # self.relu = nn.ReLU(inplace=True)
157 | self.sigmoid = nn.Sigmoid()
158 |
159 | def forward(self, x, scale):
160 | # x = self.relu(self.norm1(x))
161 | x = self.sigmoid(self.conv1(x))
162 | if scale > 1:
163 | x = upsample(x, scale_factor=scale)
164 | return x
165 |
166 |
167 | class DispUnpack(nn.Module):
168 | def __init__(self, input_dim=100, hidden_dim=128):
169 | super(DispUnpack, self).__init__()
170 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
171 | self.conv2 = nn.Conv2d(hidden_dim, 16, 3, padding=1)
172 | self.relu = nn.ReLU(inplace=True)
173 | self.sigmoid = nn.Sigmoid()
174 | self.pixel_shuffle = nn.PixelShuffle(4)
175 |
176 | def forward(self, x, output_size):
177 | x = self.relu(self.conv1(x))
178 | x = self.sigmoid(self.conv2(x)) # [b, 16, h/4, w/4]
179 | # x = torch.reshape(x, [x.shape[0], 1, x.shape[2]*4, x.shape[3]*4])
180 | x = self.pixel_shuffle(x)
181 |
182 | return x
183 |
184 |
185 | def upsample(x, scale_factor=2, mode="bilinear", align_corners=False):
186 |     """Upsample input tensor by the given scale factor
187 | """
188 | return F.interpolate(x, scale_factor=scale_factor, mode=mode, align_corners=align_corners)
--------------------------------------------------------------------------------
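
`NewCRFDepth.forward` runs the Swin backbone, feeds the top feature map together with the PSP output into `crf3`, then alternates PixelShuffle upsampling with the remaining CRF stages before `DispHead` predicts a sigmoid disparity scaled by `max_depth`. A dry-run shape check (a sketch, not repository code), assuming it is run from the `newcrfs/` directory with random weights and an input size divisible by 32:

```python
import torch
from networks.NewCRFDepth import NewCRFDepth

model = NewCRFDepth(version='large07', inv_depth=False, max_depth=10, pretrained=None).eval()
x = torch.randn(1, 3, 480, 640)
with torch.no_grad():
    depth = model(x)
print(depth.shape)  # torch.Size([1, 1, 480, 640]); values lie in (0, max_depth)
```
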
/newcrfs/networks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aliyun/NeWCRFs/d327bf7ca8fb43959734bb02ddc7b56cf283c8d9/newcrfs/networks/__init__.py
--------------------------------------------------------------------------------
/newcrfs/networks/newcrf_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.utils.checkpoint as checkpoint
5 | import numpy as np
6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
7 |
8 |
9 | class Mlp(nn.Module):
10 | """ Multilayer perceptron."""
11 |
12 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
13 | super().__init__()
14 | out_features = out_features or in_features
15 | hidden_features = hidden_features or in_features
16 | self.fc1 = nn.Linear(in_features, hidden_features)
17 | self.act = act_layer()
18 | self.fc2 = nn.Linear(hidden_features, out_features)
19 | self.drop = nn.Dropout(drop)
20 |
21 | def forward(self, x):
22 | x = self.fc1(x)
23 | x = self.act(x)
24 | x = self.drop(x)
25 | x = self.fc2(x)
26 | x = self.drop(x)
27 | return x
28 |
29 |
30 | def window_partition(x, window_size):
31 | """
32 | Args:
33 | x: (B, H, W, C)
34 | window_size (int): window size
35 |
36 | Returns:
37 | windows: (num_windows*B, window_size, window_size, C)
38 | """
39 | B, H, W, C = x.shape
40 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
41 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
42 | return windows
43 |
44 |
45 | def window_reverse(windows, window_size, H, W):
46 | """
47 | Args:
48 | windows: (num_windows*B, window_size, window_size, C)
49 | window_size (int): Window size
50 | H (int): Height of image
51 | W (int): Width of image
52 |
53 | Returns:
54 | x: (B, H, W, C)
55 | """
56 | B = int(windows.shape[0] / (H * W / window_size / window_size))
57 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
58 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
59 | return x
60 |
61 |
62 | class WindowAttention(nn.Module):
63 | """ Window based multi-head self attention (W-MSA) module with relative position bias.
64 |     It supports both shifted and non-shifted windows.
65 |
66 | Args:
67 | dim (int): Number of input channels.
68 | window_size (tuple[int]): The height and width of the window.
69 | num_heads (int): Number of attention heads.
70 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
71 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
72 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
73 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0
74 | """
75 |
76 | def __init__(self, dim, window_size, num_heads, v_dim, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
77 |
78 | super().__init__()
79 | self.dim = dim
80 | self.window_size = window_size # Wh, Ww
81 | self.num_heads = num_heads
82 | head_dim = dim // num_heads
83 | self.scale = qk_scale or head_dim ** -0.5
84 |
85 | # define a parameter table of relative position bias
86 | self.relative_position_bias_table = nn.Parameter(
87 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
88 |
89 | # get pair-wise relative position index for each token inside the window
90 | coords_h = torch.arange(self.window_size[0])
91 | coords_w = torch.arange(self.window_size[1])
92 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
93 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
94 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
95 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
96 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
97 | relative_coords[:, :, 1] += self.window_size[1] - 1
98 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
99 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
100 | self.register_buffer("relative_position_index", relative_position_index)
101 |
102 | self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias)
103 | self.attn_drop = nn.Dropout(attn_drop)
104 | self.proj = nn.Linear(v_dim, v_dim)
105 | self.proj_drop = nn.Dropout(proj_drop)
106 |
107 | trunc_normal_(self.relative_position_bias_table, std=.02)
108 | self.softmax = nn.Softmax(dim=-1)
109 |
110 | def forward(self, x, v, mask=None):
111 | """ Forward function.
112 |
113 | Args:
114 | x: input features with shape of (num_windows*B, N, C)
115 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
116 | """
117 | B_, N, C = x.shape
118 | qk = self.qk(x).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
119 | q, k = qk[0], qk[1] # make torchscript happy (cannot use tensor as tuple)
120 |
121 | q = q * self.scale
122 | attn = (q @ k.transpose(-2, -1))
123 |
124 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
125 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
126 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
127 | attn = attn + relative_position_bias.unsqueeze(0)
128 |
129 | if mask is not None:
130 | nW = mask.shape[0]
131 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
132 | attn = attn.view(-1, self.num_heads, N, N)
133 | attn = self.softmax(attn)
134 | else:
135 | attn = self.softmax(attn)
136 |
137 | attn = self.attn_drop(attn)
138 |
139 | # assert self.dim % v.shape[-1] == 0, "self.dim % v.shape[-1] != 0"
140 | # repeat_num = self.dim // v.shape[-1]
141 | # v = v.view(B_, N, self.num_heads // repeat_num, -1).transpose(1, 2).repeat(1, repeat_num, 1, 1)
142 |
143 | assert self.dim == v.shape[-1], "self.dim != v.shape[-1]"
144 | v = v.view(B_, N, self.num_heads, -1).transpose(1, 2)
145 |
146 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
147 | x = self.proj(x)
148 | x = self.proj_drop(x)
149 | return x
150 |
151 |
152 | class CRFBlock(nn.Module):
153 | """ CRF Block.
154 |
155 | Args:
156 | dim (int): Number of input channels.
157 | num_heads (int): Number of attention heads.
158 | window_size (int): Window size.
159 | shift_size (int): Shift size for SW-MSA.
160 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
161 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
162 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
163 | drop (float, optional): Dropout rate. Default: 0.0
164 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
165 | drop_path (float, optional): Stochastic depth rate. Default: 0.0
166 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
167 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
168 | """
169 |
170 | def __init__(self, dim, num_heads, v_dim, window_size=7, shift_size=0,
171 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
172 | act_layer=nn.GELU, norm_layer=nn.LayerNorm):
173 | super().__init__()
174 | self.dim = dim
175 | self.num_heads = num_heads
176 | self.v_dim = v_dim
177 | self.window_size = window_size
178 | self.shift_size = shift_size
179 | self.mlp_ratio = mlp_ratio
180 |         assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
181 |
182 | self.norm1 = norm_layer(dim)
183 | self.attn = WindowAttention(
184 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, v_dim=v_dim,
185 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
186 |
187 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
188 | self.norm2 = norm_layer(v_dim)
189 | mlp_hidden_dim = int(v_dim * mlp_ratio)
190 | self.mlp = Mlp(in_features=v_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
191 |
192 | self.H = None
193 | self.W = None
194 |
195 | def forward(self, x, v, mask_matrix):
196 | """ Forward function.
197 |
198 | Args:
199 | x: Input feature, tensor size (B, H*W, C).
200 |             v: Value feature, tensor size (B, H, W, v_dim); the spatial size is taken from self.H, self.W.
201 | mask_matrix: Attention mask for cyclic shift.
202 | """
203 | B, L, C = x.shape
204 | H, W = self.H, self.W
205 | assert L == H * W, "input feature has wrong size"
206 |
207 | shortcut = x
208 | x = self.norm1(x)
209 | x = x.view(B, H, W, C)
210 |
211 | # pad feature maps to multiples of window size
212 | pad_l = pad_t = 0
213 | pad_r = (self.window_size - W % self.window_size) % self.window_size
214 | pad_b = (self.window_size - H % self.window_size) % self.window_size
215 | x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
216 | v = F.pad(v, (0, 0, pad_l, pad_r, pad_t, pad_b))
217 | _, Hp, Wp, _ = x.shape
218 |
219 | # cyclic shift
220 | if self.shift_size > 0:
221 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
222 | shifted_v = torch.roll(v, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
223 | attn_mask = mask_matrix
224 | else:
225 | shifted_x = x
226 | shifted_v = v
227 | attn_mask = None
228 |
229 | # partition windows
230 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
231 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
232 | v_windows = window_partition(shifted_v, self.window_size) # nW*B, window_size, window_size, C
233 | v_windows = v_windows.view(-1, self.window_size * self.window_size, v_windows.shape[-1]) # nW*B, window_size*window_size, C
234 |
235 | # W-MSA/SW-MSA
236 | attn_windows = self.attn(x_windows, v_windows, mask=attn_mask) # nW*B, window_size*window_size, C
237 |
238 | # merge windows
239 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.v_dim)
240 | shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
241 |
242 | # reverse cyclic shift
243 | if self.shift_size > 0:
244 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
245 | else:
246 | x = shifted_x
247 |
248 | if pad_r > 0 or pad_b > 0:
249 | x = x[:, :H, :W, :].contiguous()
250 |
251 | x = x.view(B, H * W, self.v_dim)
252 |
253 | # FFN
254 | x = shortcut + self.drop_path(x)
255 | x = x + self.drop_path(self.mlp(self.norm2(x)))
256 |
257 | return x
258 |
259 |
260 | class BasicCRFLayer(nn.Module):
261 | """ A basic NeWCRFs layer for one stage.
262 |
263 | Args:
264 | dim (int): Number of feature channels
265 |         depth (int): Depth of this stage.
266 |         num_heads (int): Number of attention heads.
267 | window_size (int): Local window size. Default: 7.
268 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
269 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
270 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
271 | drop (float, optional): Dropout rate. Default: 0.0
272 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
273 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
274 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
275 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
276 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
277 | """
278 |
279 | def __init__(self,
280 | dim,
281 | depth,
282 | num_heads,
283 | v_dim,
284 | window_size=7,
285 | mlp_ratio=4.,
286 | qkv_bias=True,
287 | qk_scale=None,
288 | drop=0.,
289 | attn_drop=0.,
290 | drop_path=0.,
291 | norm_layer=nn.LayerNorm,
292 | downsample=None,
293 | use_checkpoint=False):
294 | super().__init__()
295 | self.window_size = window_size
296 | self.shift_size = window_size // 2
297 | self.depth = depth
298 | self.use_checkpoint = use_checkpoint
299 |
300 | # build blocks
301 | self.blocks = nn.ModuleList([
302 | CRFBlock(
303 | dim=dim,
304 | num_heads=num_heads,
305 | v_dim=v_dim,
306 | window_size=window_size,
307 | shift_size=0 if (i % 2 == 0) else window_size // 2,
308 | mlp_ratio=mlp_ratio,
309 | qkv_bias=qkv_bias,
310 | qk_scale=qk_scale,
311 | drop=drop,
312 | attn_drop=attn_drop,
313 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
314 | norm_layer=norm_layer)
315 | for i in range(depth)])
316 |
317 | # patch merging layer
318 | if downsample is not None:
319 | self.downsample = downsample(dim=dim, norm_layer=norm_layer)
320 | else:
321 | self.downsample = None
322 |
323 | def forward(self, x, v, H, W):
324 | """ Forward function.
325 |
326 | Args:
327 | x: Input feature, tensor size (B, H*W, C).
328 |             v: Value feature (B, H, W, C); H, W: Spatial resolution of the input feature.
329 | """
330 |
331 | # calculate attention mask for SW-MSA
332 | Hp = int(np.ceil(H / self.window_size)) * self.window_size
333 | Wp = int(np.ceil(W / self.window_size)) * self.window_size
334 | img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
335 | h_slices = (slice(0, -self.window_size),
336 | slice(-self.window_size, -self.shift_size),
337 | slice(-self.shift_size, None))
338 | w_slices = (slice(0, -self.window_size),
339 | slice(-self.window_size, -self.shift_size),
340 | slice(-self.shift_size, None))
341 | cnt = 0
342 | for h in h_slices:
343 | for w in w_slices:
344 | img_mask[:, h, w, :] = cnt
345 | cnt += 1
346 |
347 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
348 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
349 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
350 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
351 |
352 | for blk in self.blocks:
353 | blk.H, blk.W = H, W
354 | if self.use_checkpoint:
355 |                 x = checkpoint.checkpoint(blk, x, v, attn_mask)
356 | else:
357 | x = blk(x, v, attn_mask)
358 | if self.downsample is not None:
359 | x_down = self.downsample(x, H, W)
360 | Wh, Ww = (H + 1) // 2, (W + 1) // 2
361 | return x, H, W, x_down, Wh, Ww
362 | else:
363 | return x, H, W, x, H, W
364 |
365 |
366 | class NewCRF(nn.Module):
367 | def __init__(self,
368 | input_dim=96,
369 | embed_dim=96,
370 | v_dim=64,
371 | window_size=7,
372 | num_heads=4,
373 | depth=2,
374 | patch_size=4,
375 | in_chans=3,
376 | norm_layer=nn.LayerNorm,
377 | patch_norm=True):
378 | super().__init__()
379 |
380 | self.embed_dim = embed_dim
381 | self.patch_norm = patch_norm
382 |
383 | if input_dim != embed_dim:
384 | self.proj_x = nn.Conv2d(input_dim, embed_dim, 3, padding=1)
385 | else:
386 | self.proj_x = None
387 |
388 | if v_dim != embed_dim:
389 | self.proj_v = nn.Conv2d(v_dim, embed_dim, 3, padding=1)
390 | elif embed_dim % v_dim == 0:
391 | self.proj_v = None
392 |
393 |         # For now, v_dim needs to equal embed_dim, because the output of the window attention is the input of the shifted-window attention
394 | v_dim = embed_dim
395 | assert v_dim == embed_dim
396 |
397 | self.crf_layer = BasicCRFLayer(
398 | dim=embed_dim,
399 | depth=depth,
400 | num_heads=num_heads,
401 | v_dim=v_dim,
402 | window_size=window_size,
403 | mlp_ratio=4.,
404 | qkv_bias=True,
405 | qk_scale=None,
406 | drop=0.,
407 | attn_drop=0.,
408 | drop_path=0.,
409 | norm_layer=norm_layer,
410 | downsample=None,
411 | use_checkpoint=False)
412 |
413 | layer = norm_layer(embed_dim)
414 | layer_name = 'norm_crf'
415 | self.add_module(layer_name, layer)
416 |
417 |
418 | def forward(self, x, v):
419 | if self.proj_x is not None:
420 | x = self.proj_x(x)
421 | if self.proj_v is not None:
422 | v = self.proj_v(v)
423 |
424 | Wh, Ww = x.size(2), x.size(3)
425 | x = x.flatten(2).transpose(1, 2)
426 | v = v.transpose(1, 2).transpose(2, 3)
427 |
428 | x_out, H, W, x, Wh, Ww = self.crf_layer(x, v, Wh, Ww)
429 |         norm_layer = getattr(self, 'norm_crf')
430 | x_out = norm_layer(x_out)
431 | out = x_out.view(-1, H, W, self.embed_dim).permute(0, 3, 1, 2).contiguous()
432 |
433 | return out
--------------------------------------------------------------------------------
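
The `NewCRF` wrapper above chains the optional input/value projections, one `BasicCRFLayer` of window plus shifted-window attention, and a final `norm_crf` layer, returning a fused feature map at the input resolution. A minimal shape-check sketch (the import path and channel sizes are illustrative, not prescribed by the code):

```python
import torch
from newcrfs.networks.newcrf_layers import NewCRF  # import path assumed

crf = NewCRF(input_dim=512, embed_dim=512, v_dim=64, window_size=7, num_heads=16)
x = torch.randn(1, 512, 15, 20)   # backbone feature map, (B, input_dim, H, W)
v = torch.randn(1, 64, 15, 20)    # value branch, (B, v_dim, H, W); projected to embed_dim inside
out = crf(x, v)
print(out.shape)                  # torch.Size([1, 512, 15, 20])
```
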
/newcrfs/networks/newcrf_utils.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import os
3 | import os.path as osp
4 | import pkgutil
5 | import warnings
6 | from collections import OrderedDict
7 | from importlib import import_module
8 |
9 | import torch
10 | import torchvision
11 | import torch.nn as nn
12 | from torch.utils import model_zoo
13 | from torch.nn import functional as F
14 | from torch.nn.parallel import DataParallel, DistributedDataParallel
15 | from torch import distributed as dist
16 |
17 | TORCH_VERSION = torch.__version__
18 |
19 |
20 | def resize(input,
21 | size=None,
22 | scale_factor=None,
23 | mode='nearest',
24 | align_corners=None,
25 | warning=True):
26 | if warning:
27 | if size is not None and align_corners:
28 | input_h, input_w = tuple(int(x) for x in input.shape[2:])
29 | output_h, output_w = tuple(int(x) for x in size)
30 |             if output_h > input_h or output_w > input_w:
31 | if ((output_h > 1 and output_w > 1 and input_h > 1
32 | and input_w > 1) and (output_h - 1) % (input_h - 1)
33 | and (output_w - 1) % (input_w - 1)):
34 | warnings.warn(
35 | f'When align_corners={align_corners}, '
36 |                         'the output would be more aligned if '
37 | f'input size {(input_h, input_w)} is `x+1` and '
38 | f'out size {(output_h, output_w)} is `nx+1`')
39 | if isinstance(size, torch.Size):
40 | size = tuple(int(x) for x in size)
41 | return F.interpolate(input, size, scale_factor, mode, align_corners)
42 |
43 |
44 | def normal_init(module, mean=0, std=1, bias=0):
45 | if hasattr(module, 'weight') and module.weight is not None:
46 | nn.init.normal_(module.weight, mean, std)
47 | if hasattr(module, 'bias') and module.bias is not None:
48 | nn.init.constant_(module.bias, bias)
49 |
50 |
51 | def is_module_wrapper(module):
52 | module_wrappers = (DataParallel, DistributedDataParallel)
53 | return isinstance(module, module_wrappers)
54 |
55 |
56 | def get_dist_info():
57 | if TORCH_VERSION < '1.0':
58 | initialized = dist._initialized
59 | else:
60 | if dist.is_available():
61 | initialized = dist.is_initialized()
62 | else:
63 | initialized = False
64 | if initialized:
65 | rank = dist.get_rank()
66 | world_size = dist.get_world_size()
67 | else:
68 | rank = 0
69 | world_size = 1
70 | return rank, world_size
71 |
72 |
73 | def load_state_dict(module, state_dict, strict=False, logger=None):
74 | """Load state_dict to a module.
75 |
76 | This method is modified from :meth:`torch.nn.Module.load_state_dict`.
77 | Default value for ``strict`` is set to ``False`` and the message for
78 | param mismatch will be shown even if strict is False.
79 |
80 | Args:
81 | module (Module): Module that receives the state_dict.
82 | state_dict (OrderedDict): Weights.
83 | strict (bool): whether to strictly enforce that the keys
84 | in :attr:`state_dict` match the keys returned by this module's
85 | :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
86 | logger (:obj:`logging.Logger`, optional): Logger to log the error
87 | message. If not specified, print function will be used.
88 | """
89 | unexpected_keys = []
90 | all_missing_keys = []
91 | err_msg = []
92 |
93 | metadata = getattr(state_dict, '_metadata', None)
94 | state_dict = state_dict.copy()
95 | if metadata is not None:
96 | state_dict._metadata = metadata
97 |
98 | # use _load_from_state_dict to enable checkpoint version control
99 | def load(module, prefix=''):
100 | # recursively check parallel module in case that the model has a
101 | # complicated structure, e.g., nn.Module(nn.Module(DDP))
102 | if is_module_wrapper(module):
103 | module = module.module
104 | local_metadata = {} if metadata is None else metadata.get(
105 | prefix[:-1], {})
106 | module._load_from_state_dict(state_dict, prefix, local_metadata, True,
107 | all_missing_keys, unexpected_keys,
108 | err_msg)
109 | for name, child in module._modules.items():
110 | if child is not None:
111 | load(child, prefix + name + '.')
112 |
113 | load(module)
114 | load = None # break load->load reference cycle
115 |
116 | # ignore "num_batches_tracked" of BN layers
117 | missing_keys = [
118 | key for key in all_missing_keys if 'num_batches_tracked' not in key
119 | ]
120 |
121 | if unexpected_keys:
122 | err_msg.append('unexpected key in source '
123 | f'state_dict: {", ".join(unexpected_keys)}\n')
124 | if missing_keys:
125 | err_msg.append(
126 | f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
127 |
128 | rank, _ = get_dist_info()
129 | if len(err_msg) > 0 and rank == 0:
130 | err_msg.insert(
131 | 0, 'The model and loaded state dict do not match exactly\n')
132 | err_msg = '\n'.join(err_msg)
133 | if strict:
134 | raise RuntimeError(err_msg)
135 | elif logger is not None:
136 | logger.warning(err_msg)
137 | else:
138 | print(err_msg)
139 |
140 |
141 | def load_url_dist(url, model_dir=None):
142 |     """In the distributed setting, this function only downloads the checkpoint
143 |     at local rank 0."""
144 | rank, world_size = get_dist_info()
145 | rank = int(os.environ.get('LOCAL_RANK', rank))
146 | if rank == 0:
147 | checkpoint = model_zoo.load_url(url, model_dir=model_dir)
148 | if world_size > 1:
149 | torch.distributed.barrier()
150 | if rank > 0:
151 | checkpoint = model_zoo.load_url(url, model_dir=model_dir)
152 | return checkpoint
153 |
154 |
155 | def get_torchvision_models():
156 | model_urls = dict()
157 | for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
158 | if ispkg:
159 | continue
160 | _zoo = import_module(f'torchvision.models.{name}')
161 | if hasattr(_zoo, 'model_urls'):
162 | _urls = getattr(_zoo, 'model_urls')
163 | model_urls.update(_urls)
164 | return model_urls
165 |
166 |
167 | def _load_checkpoint(filename, map_location=None):
168 | """Load checkpoint from somewhere (modelzoo, file, url).
169 |
170 | Args:
171 | filename (str): Accept local filepath, URL, ``torchvision://xxx``,
172 | ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
173 | details.
174 | map_location (str | None): Same as :func:`torch.load`. Default: None.
175 |
176 | Returns:
177 | dict | OrderedDict: The loaded checkpoint. It can be either an
178 | OrderedDict storing model weights or a dict containing other
179 | information, which depends on the checkpoint.
180 | """
181 | if filename.startswith('modelzoo://'):
182 | warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
183 | 'use "torchvision://" instead')
184 | model_urls = get_torchvision_models()
185 | model_name = filename[11:]
186 | checkpoint = load_url_dist(model_urls[model_name])
187 | else:
188 | if not osp.isfile(filename):
189 | raise IOError(f'{filename} is not a checkpoint file')
190 | checkpoint = torch.load(filename, map_location=map_location)
191 | return checkpoint
192 |
193 |
194 | def load_checkpoint(model,
195 | filename,
196 | map_location='cpu',
197 | strict=False,
198 | logger=None):
199 | """Load checkpoint from a file or URI.
200 |
201 | Args:
202 | model (Module): Module to load checkpoint.
203 | filename (str): Accept local filepath, URL, ``torchvision://xxx``,
204 | ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
205 | details.
206 | map_location (str): Same as :func:`torch.load`.
207 | strict (bool): Whether to allow different params for the model and
208 | checkpoint.
209 | logger (:mod:`logging.Logger` or None): The logger for error message.
210 |
211 | Returns:
212 | dict or OrderedDict: The loaded checkpoint.
213 | """
214 | checkpoint = _load_checkpoint(filename, map_location)
215 | # OrderedDict is a subclass of dict
216 | if not isinstance(checkpoint, dict):
217 | raise RuntimeError(
218 | f'No state_dict found in checkpoint file {filename}')
219 | # get state_dict from checkpoint
220 | if 'state_dict' in checkpoint:
221 | state_dict = checkpoint['state_dict']
222 | elif 'model' in checkpoint:
223 | state_dict = checkpoint['model']
224 | else:
225 | state_dict = checkpoint
226 | # strip prefix of state_dict
227 | if list(state_dict.keys())[0].startswith('module.'):
228 | state_dict = {k[7:]: v for k, v in state_dict.items()}
229 |
230 | # for MoBY, load model of online branch
231 | if sorted(list(state_dict.keys()))[0].startswith('encoder'):
232 | state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}
233 |
234 | # reshape absolute position embedding
235 | if state_dict.get('absolute_pos_embed') is not None:
236 | absolute_pos_embed = state_dict['absolute_pos_embed']
237 | N1, L, C1 = absolute_pos_embed.size()
238 | N2, C2, H, W = model.absolute_pos_embed.size()
239 | if N1 != N2 or C1 != C2 or L != H*W:
240 | logger.warning("Error in loading absolute_pos_embed, pass")
241 | else:
242 | state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)
243 |
244 | # interpolate position bias table if needed
245 | relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
246 | for table_key in relative_position_bias_table_keys:
247 | table_pretrained = state_dict[table_key]
248 | table_current = model.state_dict()[table_key]
249 | L1, nH1 = table_pretrained.size()
250 | L2, nH2 = table_current.size()
251 | if nH1 != nH2:
252 | logger.warning(f"Error in loading {table_key}, pass")
253 | else:
254 | if L1 != L2:
255 | S1 = int(L1 ** 0.5)
256 | S2 = int(L2 ** 0.5)
257 | table_pretrained_resized = F.interpolate(
258 | table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
259 | size=(S2, S2), mode='bicubic')
260 | state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0)
261 |
262 | # load state_dict
263 | load_state_dict(model, state_dict, strict, logger)
264 | return checkpoint
--------------------------------------------------------------------------------
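
`newcrf_utils.py` is a trimmed-down copy of the mmcv/mmseg checkpoint helpers plus a thin `resize` wrapper around `F.interpolate`; `SwinTransformer.init_weights` below calls `load_checkpoint` when a pretrained path is given. A small sketch of the `resize` helper (tensor sizes are illustrative):

```python
import torch
from newcrfs.networks.newcrf_utils import resize  # import path assumed

feat = torch.randn(1, 96, 30, 40)
up = resize(feat, size=(120, 160), mode='bilinear', align_corners=False)
print(up.shape)  # torch.Size([1, 96, 120, 160])

# Typical checkpoint loading (the file path here is hypothetical):
# load_checkpoint(model, 'model_zoo/swin_large_patch4_window7_224_22k.pth', map_location='cpu')
```
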
/newcrfs/networks/swin_transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.utils.checkpoint as checkpoint
5 | import numpy as np
6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
7 |
8 | from .newcrf_utils import load_checkpoint
9 |
10 |
11 | class Mlp(nn.Module):
12 | """ Multilayer perceptron."""
13 |
14 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
15 | super().__init__()
16 | out_features = out_features or in_features
17 | hidden_features = hidden_features or in_features
18 | self.fc1 = nn.Linear(in_features, hidden_features)
19 | self.act = act_layer()
20 | self.fc2 = nn.Linear(hidden_features, out_features)
21 | self.drop = nn.Dropout(drop)
22 |
23 | def forward(self, x):
24 | x = self.fc1(x)
25 | x = self.act(x)
26 | x = self.drop(x)
27 | x = self.fc2(x)
28 | x = self.drop(x)
29 | return x
30 |
31 |
32 | def window_partition(x, window_size):
33 | """
34 | Args:
35 | x: (B, H, W, C)
36 | window_size (int): window size
37 |
38 | Returns:
39 | windows: (num_windows*B, window_size, window_size, C)
40 | """
41 | B, H, W, C = x.shape
42 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
43 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
44 | return windows
45 |
46 |
47 | def window_reverse(windows, window_size, H, W):
48 | """
49 | Args:
50 | windows: (num_windows*B, window_size, window_size, C)
51 | window_size (int): Window size
52 | H (int): Height of image
53 | W (int): Width of image
54 |
55 | Returns:
56 | x: (B, H, W, C)
57 | """
58 | B = int(windows.shape[0] / (H * W / window_size / window_size))
59 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
60 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
61 | return x
62 |
63 |
64 | class WindowAttention(nn.Module):
65 | """ Window based multi-head self attention (W-MSA) module with relative position bias.
66 |     It supports both shifted and non-shifted windows.
67 |
68 | Args:
69 | dim (int): Number of input channels.
70 | window_size (tuple[int]): The height and width of the window.
71 | num_heads (int): Number of attention heads.
72 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
73 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
74 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
75 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0
76 | """
77 |
78 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
79 |
80 | super().__init__()
81 | self.dim = dim
82 | self.window_size = window_size # Wh, Ww
83 | self.num_heads = num_heads
84 | head_dim = dim // num_heads
85 | self.scale = qk_scale or head_dim ** -0.5
86 |
87 | # define a parameter table of relative position bias
88 | self.relative_position_bias_table = nn.Parameter(
89 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
90 |
91 | # get pair-wise relative position index for each token inside the window
92 | coords_h = torch.arange(self.window_size[0])
93 | coords_w = torch.arange(self.window_size[1])
94 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
95 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
96 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
97 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
98 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
99 | relative_coords[:, :, 1] += self.window_size[1] - 1
100 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
101 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
102 | self.register_buffer("relative_position_index", relative_position_index)
103 |
104 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
105 | self.attn_drop = nn.Dropout(attn_drop)
106 | self.proj = nn.Linear(dim, dim)
107 | self.proj_drop = nn.Dropout(proj_drop)
108 |
109 | trunc_normal_(self.relative_position_bias_table, std=.02)
110 | self.softmax = nn.Softmax(dim=-1)
111 |
112 | def forward(self, x, mask=None):
113 | """ Forward function.
114 |
115 | Args:
116 | x: input features with shape of (num_windows*B, N, C)
117 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
118 | """
119 | B_, N, C = x.shape
120 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
121 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
122 |
123 | q = q * self.scale
124 | attn = (q @ k.transpose(-2, -1))
125 |
126 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
127 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
128 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
129 | attn = attn + relative_position_bias.unsqueeze(0)
130 |
131 | if mask is not None:
132 | nW = mask.shape[0]
133 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
134 | attn = attn.view(-1, self.num_heads, N, N)
135 | attn = self.softmax(attn)
136 | else:
137 | attn = self.softmax(attn)
138 |
139 | attn = self.attn_drop(attn)
140 |
141 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
142 | x = self.proj(x)
143 | x = self.proj_drop(x)
144 | return x
145 |
146 |
147 | class SwinTransformerBlock(nn.Module):
148 | """ Swin Transformer Block.
149 |
150 | Args:
151 | dim (int): Number of input channels.
152 | num_heads (int): Number of attention heads.
153 | window_size (int): Window size.
154 | shift_size (int): Shift size for SW-MSA.
155 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
156 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
157 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
158 | drop (float, optional): Dropout rate. Default: 0.0
159 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
160 | drop_path (float, optional): Stochastic depth rate. Default: 0.0
161 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
162 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
163 | """
164 |
165 | def __init__(self, dim, num_heads, window_size=7, shift_size=0,
166 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
167 | act_layer=nn.GELU, norm_layer=nn.LayerNorm):
168 | super().__init__()
169 | self.dim = dim
170 | self.num_heads = num_heads
171 | self.window_size = window_size
172 | self.shift_size = shift_size
173 | self.mlp_ratio = mlp_ratio
174 |         assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
175 |
176 | self.norm1 = norm_layer(dim)
177 | self.attn = WindowAttention(
178 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
179 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
180 |
181 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
182 | self.norm2 = norm_layer(dim)
183 | mlp_hidden_dim = int(dim * mlp_ratio)
184 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
185 |
186 | self.H = None
187 | self.W = None
188 |
189 | def forward(self, x, mask_matrix):
190 | """ Forward function.
191 |
192 | Args:
193 | x: Input feature, tensor size (B, H*W, C).
194 | H, W: Spatial resolution of the input feature.
195 | mask_matrix: Attention mask for cyclic shift.
196 | """
197 | B, L, C = x.shape
198 | H, W = self.H, self.W
199 | assert L == H * W, "input feature has wrong size"
200 |
201 | shortcut = x
202 | x = self.norm1(x)
203 | x = x.view(B, H, W, C)
204 |
205 | # pad feature maps to multiples of window size
206 | pad_l = pad_t = 0
207 | pad_r = (self.window_size - W % self.window_size) % self.window_size
208 | pad_b = (self.window_size - H % self.window_size) % self.window_size
209 | x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
210 | _, Hp, Wp, _ = x.shape
211 |
212 | # cyclic shift
213 | if self.shift_size > 0:
214 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
215 | attn_mask = mask_matrix
216 | else:
217 | shifted_x = x
218 | attn_mask = None
219 |
220 | # partition windows
221 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
222 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
223 |
224 | # W-MSA/SW-MSA
225 | attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
226 |
227 | # merge windows
228 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
229 | shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
230 |
231 | # reverse cyclic shift
232 | if self.shift_size > 0:
233 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
234 | else:
235 | x = shifted_x
236 |
237 | if pad_r > 0 or pad_b > 0:
238 | x = x[:, :H, :W, :].contiguous()
239 |
240 | x = x.view(B, H * W, C)
241 |
242 | # FFN
243 | x = shortcut + self.drop_path(x)
244 | x = x + self.drop_path(self.mlp(self.norm2(x)))
245 |
246 | return x
247 |
248 |
249 | class PatchMerging(nn.Module):
250 | """ Patch Merging Layer
251 |
252 | Args:
253 | dim (int): Number of input channels.
254 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
255 | """
256 | def __init__(self, dim, norm_layer=nn.LayerNorm):
257 | super().__init__()
258 | self.dim = dim
259 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
260 | self.norm = norm_layer(4 * dim)
261 |
262 | def forward(self, x, H, W):
263 | """ Forward function.
264 |
265 | Args:
266 | x: Input feature, tensor size (B, H*W, C).
267 | H, W: Spatial resolution of the input feature.
268 | """
269 | B, L, C = x.shape
270 | assert L == H * W, "input feature has wrong size"
271 |
272 | x = x.view(B, H, W, C)
273 |
274 | # padding
275 | pad_input = (H % 2 == 1) or (W % 2 == 1)
276 | if pad_input:
277 | x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
278 |
279 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
280 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
281 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
282 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
283 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
284 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
285 |
286 | x = self.norm(x)
287 | x = self.reduction(x)
288 |
289 | return x
290 |
291 |
292 | class BasicLayer(nn.Module):
293 | """ A basic Swin Transformer layer for one stage.
294 |
295 | Args:
296 | dim (int): Number of feature channels
297 |         depth (int): Depth of this stage.
298 |         num_heads (int): Number of attention heads.
299 | window_size (int): Local window size. Default: 7.
300 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
301 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
302 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
303 | drop (float, optional): Dropout rate. Default: 0.0
304 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
305 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
306 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
307 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
308 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
309 | """
310 |
311 | def __init__(self,
312 | dim,
313 | depth,
314 | num_heads,
315 | window_size=7,
316 | mlp_ratio=4.,
317 | qkv_bias=True,
318 | qk_scale=None,
319 | drop=0.,
320 | attn_drop=0.,
321 | drop_path=0.,
322 | norm_layer=nn.LayerNorm,
323 | downsample=None,
324 | use_checkpoint=False):
325 | super().__init__()
326 | self.window_size = window_size
327 | self.shift_size = window_size // 2
328 | self.depth = depth
329 | self.use_checkpoint = use_checkpoint
330 |
331 | # build blocks
332 | self.blocks = nn.ModuleList([
333 | SwinTransformerBlock(
334 | dim=dim,
335 | num_heads=num_heads,
336 | window_size=window_size,
337 | shift_size=0 if (i % 2 == 0) else window_size // 2,
338 | mlp_ratio=mlp_ratio,
339 | qkv_bias=qkv_bias,
340 | qk_scale=qk_scale,
341 | drop=drop,
342 | attn_drop=attn_drop,
343 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
344 | norm_layer=norm_layer)
345 | for i in range(depth)])
346 |
347 | # patch merging layer
348 | if downsample is not None:
349 | self.downsample = downsample(dim=dim, norm_layer=norm_layer)
350 | else:
351 | self.downsample = None
352 |
353 | def forward(self, x, H, W):
354 | """ Forward function.
355 |
356 | Args:
357 | x: Input feature, tensor size (B, H*W, C).
358 | H, W: Spatial resolution of the input feature.
359 | """
360 |
361 | # calculate attention mask for SW-MSA
362 | Hp = int(np.ceil(H / self.window_size)) * self.window_size
363 | Wp = int(np.ceil(W / self.window_size)) * self.window_size
364 | img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
365 | h_slices = (slice(0, -self.window_size),
366 | slice(-self.window_size, -self.shift_size),
367 | slice(-self.shift_size, None))
368 | w_slices = (slice(0, -self.window_size),
369 | slice(-self.window_size, -self.shift_size),
370 | slice(-self.shift_size, None))
371 | cnt = 0
372 | for h in h_slices:
373 | for w in w_slices:
374 | img_mask[:, h, w, :] = cnt
375 | cnt += 1
376 |
377 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
378 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
379 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
380 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
381 |
382 | for blk in self.blocks:
383 | blk.H, blk.W = H, W
384 | if self.use_checkpoint:
385 | x = checkpoint.checkpoint(blk, x, attn_mask)
386 | else:
387 | x = blk(x, attn_mask)
388 | if self.downsample is not None:
389 | x_down = self.downsample(x, H, W)
390 | Wh, Ww = (H + 1) // 2, (W + 1) // 2
391 | return x, H, W, x_down, Wh, Ww
392 | else:
393 | return x, H, W, x, H, W
394 |
395 |
396 | class PatchEmbed(nn.Module):
397 | """ Image to Patch Embedding
398 |
399 | Args:
400 | patch_size (int): Patch token size. Default: 4.
401 | in_chans (int): Number of input image channels. Default: 3.
402 | embed_dim (int): Number of linear projection output channels. Default: 96.
403 | norm_layer (nn.Module, optional): Normalization layer. Default: None
404 | """
405 |
406 | def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
407 | super().__init__()
408 | patch_size = to_2tuple(patch_size)
409 | self.patch_size = patch_size
410 |
411 | self.in_chans = in_chans
412 | self.embed_dim = embed_dim
413 |
414 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
415 | if norm_layer is not None:
416 | self.norm = norm_layer(embed_dim)
417 | else:
418 | self.norm = None
419 |
420 | def forward(self, x):
421 | """Forward function."""
422 | # padding
423 | _, _, H, W = x.size()
424 | if W % self.patch_size[1] != 0:
425 | x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
426 | if H % self.patch_size[0] != 0:
427 | x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
428 |
429 | x = self.proj(x) # B C Wh Ww
430 | if self.norm is not None:
431 | Wh, Ww = x.size(2), x.size(3)
432 | x = x.flatten(2).transpose(1, 2)
433 | x = self.norm(x)
434 | x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
435 |
436 | return x
437 |
438 |
439 | class SwinTransformer(nn.Module):
440 | """ Swin Transformer backbone.
441 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
442 | https://arxiv.org/pdf/2103.14030
443 |
444 | Args:
445 | pretrain_img_size (int): Input image size for training the pretrained model,
446 | used in absolute postion embedding. Default 224.
447 | patch_size (int | tuple(int)): Patch size. Default: 4.
448 | in_chans (int): Number of input image channels. Default: 3.
449 | embed_dim (int): Number of linear projection output channels. Default: 96.
450 | depths (tuple[int]): Depths of each Swin Transformer stage.
451 | num_heads (tuple[int]): Number of attention head of each stage.
452 | window_size (int): Window size. Default: 7.
453 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
454 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
455 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
456 | drop_rate (float): Dropout rate.
457 | attn_drop_rate (float): Attention dropout rate. Default: 0.
458 | drop_path_rate (float): Stochastic depth rate. Default: 0.2.
459 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
460 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
461 | patch_norm (bool): If True, add normalization after patch embedding. Default: True.
462 | out_indices (Sequence[int]): Output from which stages.
463 | frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
464 | -1 means not freezing any parameters.
465 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
466 | """
467 |
468 | def __init__(self,
469 | pretrain_img_size=224,
470 | patch_size=4,
471 | in_chans=3,
472 | embed_dim=96,
473 | depths=[2, 2, 6, 2],
474 | num_heads=[3, 6, 12, 24],
475 | window_size=7,
476 | mlp_ratio=4.,
477 | qkv_bias=True,
478 | qk_scale=None,
479 | drop_rate=0.,
480 | attn_drop_rate=0.,
481 | drop_path_rate=0.2,
482 | norm_layer=nn.LayerNorm,
483 | ape=False,
484 | patch_norm=True,
485 | out_indices=(0, 1, 2, 3),
486 | frozen_stages=-1,
487 | use_checkpoint=False):
488 | super().__init__()
489 |
490 | self.pretrain_img_size = pretrain_img_size
491 | self.num_layers = len(depths)
492 | self.embed_dim = embed_dim
493 | self.ape = ape
494 | self.patch_norm = patch_norm
495 | self.out_indices = out_indices
496 | self.frozen_stages = frozen_stages
497 |
498 | # split image into non-overlapping patches
499 | self.patch_embed = PatchEmbed(
500 | patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
501 | norm_layer=norm_layer if self.patch_norm else None)
502 |
503 | # absolute position embedding
504 | if self.ape:
505 | pretrain_img_size = to_2tuple(pretrain_img_size)
506 | patch_size = to_2tuple(patch_size)
507 | patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
508 |
509 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
510 | trunc_normal_(self.absolute_pos_embed, std=.02)
511 |
512 | self.pos_drop = nn.Dropout(p=drop_rate)
513 |
514 | # stochastic depth
515 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
516 |
517 | # build layers
518 | self.layers = nn.ModuleList()
519 | for i_layer in range(self.num_layers):
520 | layer = BasicLayer(
521 | dim=int(embed_dim * 2 ** i_layer),
522 | depth=depths[i_layer],
523 | num_heads=num_heads[i_layer],
524 | window_size=window_size,
525 | mlp_ratio=mlp_ratio,
526 | qkv_bias=qkv_bias,
527 | qk_scale=qk_scale,
528 | drop=drop_rate,
529 | attn_drop=attn_drop_rate,
530 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
531 | norm_layer=norm_layer,
532 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
533 | use_checkpoint=use_checkpoint)
534 | self.layers.append(layer)
535 |
536 | num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
537 | self.num_features = num_features
538 |
539 | # add a norm layer for each output
540 | for i_layer in out_indices:
541 | layer = norm_layer(num_features[i_layer])
542 | layer_name = f'norm{i_layer}'
543 | self.add_module(layer_name, layer)
544 |
545 | self._freeze_stages()
546 |
547 | def _freeze_stages(self):
548 | if self.frozen_stages >= 0:
549 | self.patch_embed.eval()
550 | for param in self.patch_embed.parameters():
551 | param.requires_grad = False
552 |
553 | if self.frozen_stages >= 1 and self.ape:
554 | self.absolute_pos_embed.requires_grad = False
555 |
556 | if self.frozen_stages >= 2:
557 | self.pos_drop.eval()
558 | for i in range(0, self.frozen_stages - 1):
559 | m = self.layers[i]
560 | m.eval()
561 | for param in m.parameters():
562 | param.requires_grad = False
563 |
564 | def init_weights(self, pretrained=None):
565 | """Initialize the weights in backbone.
566 |
567 | Args:
568 | pretrained (str, optional): Path to pre-trained weights.
569 | Defaults to None.
570 | """
571 |
572 | def _init_weights(m):
573 | if isinstance(m, nn.Linear):
574 | trunc_normal_(m.weight, std=.02)
575 | if isinstance(m, nn.Linear) and m.bias is not None:
576 | nn.init.constant_(m.bias, 0)
577 | elif isinstance(m, nn.LayerNorm):
578 | nn.init.constant_(m.bias, 0)
579 | nn.init.constant_(m.weight, 1.0)
580 |
581 | if isinstance(pretrained, str):
582 | self.apply(_init_weights)
583 | # logger = get_root_logger()
584 | load_checkpoint(self, pretrained, strict=False)
585 | elif pretrained is None:
586 | self.apply(_init_weights)
587 | else:
588 | raise TypeError('pretrained must be a str or None')
589 |
590 | def forward(self, x):
591 | """Forward function."""
592 | x = self.patch_embed(x)
593 |
594 | Wh, Ww = x.size(2), x.size(3)
595 | if self.ape:
596 | # interpolate the position embedding to the corresponding size
597 | absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
598 | x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
599 | else:
600 | x = x.flatten(2).transpose(1, 2)
601 | x = self.pos_drop(x)
602 |
603 | outs = []
604 | for i in range(self.num_layers):
605 | layer = self.layers[i]
606 | x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
607 |
608 | if i in self.out_indices:
609 | norm_layer = getattr(self, f'norm{i}')
610 | x_out = norm_layer(x_out)
611 |
612 | out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
613 | outs.append(out)
614 |
615 | return tuple(outs)
616 |
617 | def train(self, mode=True):
618 |         """Convert the model into training mode while keeping layers frozen."""
619 | super(SwinTransformer, self).train(mode)
620 | self._freeze_stages()
621 |
--------------------------------------------------------------------------------
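
The backbone returns a feature pyramid with channel widths `embed_dim * 2**i` at strides 4, 8, 16, and 32. A minimal forward-pass sketch with the default Swin-T configuration above (random weights, CPU; in practice `pretrained` would point to an ImageNet checkpoint):

```python
import torch
from newcrfs.networks.swin_transformer import SwinTransformer  # import path assumed

backbone = SwinTransformer(embed_dim=96,
                           depths=[2, 2, 6, 2],
                           num_heads=[3, 6, 12, 24],
                           window_size=7,
                           out_indices=(0, 1, 2, 3))
backbone.init_weights(pretrained=None)   # random init; pass a checkpoint path to load weights
backbone.eval()

with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 480, 640))
for f in feats:
    print(f.shape)
# (1, 96, 120, 160), (1, 192, 60, 80), (1, 384, 30, 40), (1, 768, 15, 20)
```
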
/newcrfs/networks/uper_crf_head.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from mmcv.cnn import ConvModule
6 | from .newcrf_utils import resize, normal_init
7 |
8 |
9 | class PPM(nn.ModuleList):
10 | """Pooling Pyramid Module used in PSPNet.
11 |
12 | Args:
13 | pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
14 | Module.
15 | in_channels (int): Input channels.
16 | channels (int): Channels after modules, before conv_seg.
17 | conv_cfg (dict|None): Config of conv layers.
18 | norm_cfg (dict|None): Config of norm layers.
19 | act_cfg (dict): Config of activation layers.
20 | align_corners (bool): align_corners argument of F.interpolate.
21 | """
22 |
23 | def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
24 | act_cfg, align_corners):
25 | super(PPM, self).__init__()
26 | self.pool_scales = pool_scales
27 | self.align_corners = align_corners
28 | self.in_channels = in_channels
29 | self.channels = channels
30 | self.conv_cfg = conv_cfg
31 | self.norm_cfg = norm_cfg
32 | self.act_cfg = act_cfg
33 | for pool_scale in pool_scales:
34 |             # if batch size is 1, BN on the 1x1 pooled map is not supported; switch to GN
35 | if pool_scale == 1: norm_cfg = dict(type='GN', requires_grad=True, num_groups=256)
36 | self.append(
37 | nn.Sequential(
38 | nn.AdaptiveAvgPool2d(pool_scale),
39 | ConvModule(
40 | self.in_channels,
41 | self.channels,
42 | 1,
43 | conv_cfg=self.conv_cfg,
44 | norm_cfg=norm_cfg,
45 | act_cfg=self.act_cfg)))
46 |
47 | def forward(self, x):
48 | """Forward function."""
49 | ppm_outs = []
50 | for ppm in self:
51 | ppm_out = ppm(x)
52 | upsampled_ppm_out = resize(
53 | ppm_out,
54 | size=x.size()[2:],
55 | mode='bilinear',
56 | align_corners=self.align_corners)
57 | ppm_outs.append(upsampled_ppm_out)
58 | return ppm_outs
59 |
60 |
61 | class BaseDecodeHead(nn.Module):
62 | """Base class for BaseDecodeHead.
63 |
64 | Args:
65 | in_channels (int|Sequence[int]): Input channels.
66 | channels (int): Channels after modules, before conv_seg.
67 | num_classes (int): Number of classes.
68 | dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
69 | conv_cfg (dict|None): Config of conv layers. Default: None.
70 | norm_cfg (dict|None): Config of norm layers. Default: None.
71 | act_cfg (dict): Config of activation layers.
72 | Default: dict(type='ReLU')
73 | in_index (int|Sequence[int]): Input feature index. Default: -1
74 | input_transform (str|None): Transformation type of input features.
75 | Options: 'resize_concat', 'multiple_select', None.
76 |             'resize_concat': Multiple feature maps will be resized to the
77 |                 same size as the first one and then concatenated together.
78 |                 Usually used in the FCN head of HRNet.
79 |             'multiple_select': Multiple feature maps will be bundled into
80 |                 a list and passed into the decode head.
81 | None: Only one select feature map is allowed.
82 | Default: None.
83 | loss_decode (dict): Config of decode loss.
84 | Default: dict(type='CrossEntropyLoss').
85 | ignore_index (int | None): The label index to be ignored. When using
86 | masked BCE loss, ignore_index should be set to None. Default: 255
87 | sampler (dict|None): The config of segmentation map sampler.
88 | Default: None.
89 | align_corners (bool): align_corners argument of F.interpolate.
90 | Default: False.
91 | """
92 |
93 | def __init__(self,
94 | in_channels,
95 | channels,
96 | *,
97 | num_classes,
98 | dropout_ratio=0.1,
99 | conv_cfg=None,
100 | norm_cfg=None,
101 | act_cfg=dict(type='ReLU'),
102 | in_index=-1,
103 | input_transform=None,
104 | loss_decode=dict(
105 | type='CrossEntropyLoss',
106 | use_sigmoid=False,
107 | loss_weight=1.0),
108 | ignore_index=255,
109 | sampler=None,
110 | align_corners=False):
111 | super(BaseDecodeHead, self).__init__()
112 | self._init_inputs(in_channels, in_index, input_transform)
113 | self.channels = channels
114 | self.num_classes = num_classes
115 | self.dropout_ratio = dropout_ratio
116 | self.conv_cfg = conv_cfg
117 | self.norm_cfg = norm_cfg
118 | self.act_cfg = act_cfg
119 | self.in_index = in_index
120 | # self.loss_decode = build_loss(loss_decode)
121 | self.ignore_index = ignore_index
122 | self.align_corners = align_corners
123 | # if sampler is not None:
124 | # self.sampler = build_pixel_sampler(sampler, context=self)
125 | # else:
126 | # self.sampler = None
127 |
128 | # self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
129 | # self.conv1 = nn.Conv2d(channels, num_classes, 3, padding=1)
130 | if dropout_ratio > 0:
131 | self.dropout = nn.Dropout2d(dropout_ratio)
132 | else:
133 | self.dropout = None
134 | self.fp16_enabled = False
135 |
136 | def extra_repr(self):
137 | """Extra repr."""
138 | s = f'input_transform={self.input_transform}, ' \
139 | f'ignore_index={self.ignore_index}, ' \
140 | f'align_corners={self.align_corners}'
141 | return s
142 |
143 | def _init_inputs(self, in_channels, in_index, input_transform):
144 | """Check and initialize input transforms.
145 |
146 | The in_channels, in_index and input_transform must match.
147 | Specifically, when input_transform is None, only single feature map
148 | will be selected. So in_channels and in_index must be of type int.
149 |         When input_transform is not None, in_channels and in_index must be lists (or tuples) of the same length.
150 |
151 | Args:
152 | in_channels (int|Sequence[int]): Input channels.
153 | in_index (int|Sequence[int]): Input feature index.
154 | input_transform (str|None): Transformation type of input features.
155 | Options: 'resize_concat', 'multiple_select', None.
156 |                 'resize_concat': Multiple feature maps will be resized to the
157 |                     same size as the first one and then concatenated together.
158 |                     Usually used in the FCN head of HRNet.
159 |                 'multiple_select': Multiple feature maps will be bundled into
160 |                     a list and passed into the decode head.
161 | None: Only one select feature map is allowed.
162 | """
163 |
164 | if input_transform is not None:
165 | assert input_transform in ['resize_concat', 'multiple_select']
166 | self.input_transform = input_transform
167 | self.in_index = in_index
168 | if input_transform is not None:
169 | assert isinstance(in_channels, (list, tuple))
170 | assert isinstance(in_index, (list, tuple))
171 | assert len(in_channels) == len(in_index)
172 | if input_transform == 'resize_concat':
173 | self.in_channels = sum(in_channels)
174 | else:
175 | self.in_channels = in_channels
176 | else:
177 | assert isinstance(in_channels, int)
178 | assert isinstance(in_index, int)
179 | self.in_channels = in_channels
180 |
181 | def init_weights(self):
182 | """Initialize weights of classification layer."""
183 | # normal_init(self.conv_seg, mean=0, std=0.01)
184 | # normal_init(self.conv1, mean=0, std=0.01)
185 |
186 | def _transform_inputs(self, inputs):
187 | """Transform inputs for decoder.
188 |
189 | Args:
190 | inputs (list[Tensor]): List of multi-level img features.
191 |
192 | Returns:
193 | Tensor: The transformed inputs
194 | """
195 |
196 | if self.input_transform == 'resize_concat':
197 | inputs = [inputs[i] for i in self.in_index]
198 | upsampled_inputs = [
199 | resize(
200 | input=x,
201 | size=inputs[0].shape[2:],
202 | mode='bilinear',
203 | align_corners=self.align_corners) for x in inputs
204 | ]
205 | inputs = torch.cat(upsampled_inputs, dim=1)
206 | elif self.input_transform == 'multiple_select':
207 | inputs = [inputs[i] for i in self.in_index]
208 | else:
209 | inputs = inputs[self.in_index]
210 |
211 | return inputs
212 |
213 | def forward(self, inputs):
214 | """Placeholder of forward function."""
215 | pass
216 |
217 | def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
218 | """Forward function for training.
219 | Args:
220 | inputs (list[Tensor]): List of multi-level img features.
221 | img_metas (list[dict]): List of image info dict where each dict
222 | has: 'img_shape', 'scale_factor', 'flip', and may also contain
223 | 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
224 | For details on the values of these keys see
225 | `mmseg/datasets/pipelines/formatting.py:Collect`.
226 | gt_semantic_seg (Tensor): Semantic segmentation masks
227 | used if the architecture supports semantic segmentation task.
228 | train_cfg (dict): The training config.
229 |
230 | Returns:
231 | dict[str, Tensor]: a dictionary of loss components
232 | """
233 | seg_logits = self.forward(inputs)
234 | losses = self.losses(seg_logits, gt_semantic_seg)
235 | return losses
236 |
237 | def forward_test(self, inputs, img_metas, test_cfg):
238 | """Forward function for testing.
239 |
240 | Args:
241 | inputs (list[Tensor]): List of multi-level img features.
242 | img_metas (list[dict]): List of image info dict where each dict
243 | has: 'img_shape', 'scale_factor', 'flip', and may also contain
244 | 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
245 | For details on the values of these keys see
246 | `mmseg/datasets/pipelines/formatting.py:Collect`.
247 | test_cfg (dict): The testing config.
248 |
249 | Returns:
250 | Tensor: Output segmentation map.
251 | """
252 | return self.forward(inputs)
253 |
254 |
255 | class UPerHead(BaseDecodeHead):
256 | def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
257 | super(UPerHead, self).__init__(
258 | input_transform='multiple_select', **kwargs)
259 | # FPN Module
260 | self.lateral_convs = nn.ModuleList()
261 | self.fpn_convs = nn.ModuleList()
262 |         for in_channels in self.in_channels:  # one lateral conv and one FPN conv per input level
263 | l_conv = ConvModule(
264 | in_channels,
265 | self.channels,
266 | 1,
267 | conv_cfg=self.conv_cfg,
268 | norm_cfg=self.norm_cfg,
269 | act_cfg=self.act_cfg,
270 | inplace=True)
271 | fpn_conv = ConvModule(
272 | self.channels,
273 | self.channels,
274 | 3,
275 | padding=1,
276 | conv_cfg=self.conv_cfg,
277 | norm_cfg=self.norm_cfg,
278 | act_cfg=self.act_cfg,
279 | inplace=True)
280 | self.lateral_convs.append(l_conv)
281 | self.fpn_convs.append(fpn_conv)
282 |
283 | def forward(self, inputs):
284 | """Forward function."""
285 |
286 | inputs = self._transform_inputs(inputs)
287 |
288 | # build laterals
289 | laterals = [
290 | lateral_conv(inputs[i])
291 | for i, lateral_conv in enumerate(self.lateral_convs)
292 | ]
293 |
294 | # laterals.append(self.psp_forward(inputs))
295 |
296 | # build top-down path
297 | used_backbone_levels = len(laterals)
298 | for i in range(used_backbone_levels - 1, 0, -1):
299 | prev_shape = laterals[i - 1].shape[2:]
300 | laterals[i - 1] += resize(
301 | laterals[i],
302 | size=prev_shape,
303 | mode='bilinear',
304 | align_corners=self.align_corners)
305 |
306 | # build outputs
307 | fpn_outs = [
308 | self.fpn_convs[i](laterals[i])
309 | for i in range(used_backbone_levels - 1)
310 | ]
311 |         # append the top-level lateral feature
312 | fpn_outs.append(laterals[-1])
313 |
314 | return fpn_outs[0]
315 |
316 |
317 |
318 | class PSP(BaseDecodeHead):
319 | """Unified Perceptual Parsing for Scene Understanding.
320 |
321 |     This head is the pyramid pooling (PSP) branch of
322 |     UPerNet.
323 |
324 | Args:
325 | pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
326 | Module applied on the last feature. Default: (1, 2, 3, 6).
327 | """
328 |
329 | def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
330 | super(PSP, self).__init__(
331 | input_transform='multiple_select', **kwargs)
332 | # PSP Module
333 | self.psp_modules = PPM(
334 | pool_scales,
335 | self.in_channels[-1],
336 | self.channels,
337 | conv_cfg=self.conv_cfg,
338 | norm_cfg=self.norm_cfg,
339 | act_cfg=self.act_cfg,
340 | align_corners=self.align_corners)
341 | self.bottleneck = ConvModule(
342 | self.in_channels[-1] + len(pool_scales) * self.channels,
343 | self.channels,
344 | 3,
345 | padding=1,
346 | conv_cfg=self.conv_cfg,
347 | norm_cfg=self.norm_cfg,
348 | act_cfg=self.act_cfg)
349 |
350 | def psp_forward(self, inputs):
351 | """Forward function of PSP module."""
352 | x = inputs[-1]
353 | psp_outs = [x]
354 | psp_outs.extend(self.psp_modules(x))
355 | psp_outs = torch.cat(psp_outs, dim=1)
356 | output = self.bottleneck(psp_outs)
357 |
358 | return output
359 |
360 | def forward(self, inputs):
361 | """Forward function."""
362 | inputs = self._transform_inputs(inputs)
363 |
364 | return self.psp_forward(inputs)
365 |
--------------------------------------------------------------------------------
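
`PSP` keeps only the pyramid-pooling branch of UPerNet and applies it to the coarsest of the selected inputs, producing the feature that the depth decoder can consume at its top level. A hedged construction sketch; the channel list, `num_classes`, and norm config are illustrative, and mmcv must be installed for `ConvModule`:

```python
import torch
from newcrfs.networks.uper_crf_head import PSP  # import path assumed

decoder = PSP(in_channels=[96, 192, 384, 768],
              in_index=[0, 1, 2, 3],
              pool_scales=(1, 2, 3, 6),
              channels=512,
              dropout_ratio=0.0,
              num_classes=32,
              norm_cfg=dict(type='BN', requires_grad=True),
              align_corners=False)

feats = [torch.randn(1, 96, 120, 160),
         torch.randn(1, 192, 60, 80),
         torch.randn(1, 384, 30, 40),
         torch.randn(1, 768, 15, 20)]
out = decoder(feats)     # the pooling pyramid runs on feats[-1] only
print(out.shape)         # torch.Size([1, 512, 15, 20])
```
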
/newcrfs/test.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.autograd import Variable
6 |
7 | import os, sys, errno
8 | import argparse
9 | import time
10 | import numpy as np
11 | import cv2
12 | import matplotlib.pyplot as plt
13 | from tqdm import tqdm
14 |
15 | from utils import post_process_depth, flip_lr
16 | from networks.NewCRFDepth import NewCRFDepth
17 |
18 |
19 | def convert_arg_line_to_args(arg_line):
20 | for arg in arg_line.split():
21 | if not arg.strip():
22 | continue
23 | yield arg
24 |
25 |
26 | parser = argparse.ArgumentParser(description='NeWCRFs PyTorch implementation.', fromfile_prefix_chars='@')
27 | parser.convert_arg_line_to_args = convert_arg_line_to_args
28 |
29 | parser.add_argument('--model_name', type=str, help='model name', default='newcrfs')
30 | parser.add_argument('--encoder', type=str, help='type of encoder, base07, large07', default='large07')
31 | parser.add_argument('--data_path', type=str, help='path to the data', required=True)
32 | parser.add_argument('--filenames_file', type=str, help='path to the filenames text file', required=True)
33 | parser.add_argument('--input_height', type=int, help='input height', default=480)
34 | parser.add_argument('--input_width', type=int, help='input width', default=640)
35 | parser.add_argument('--max_depth', type=float, help='maximum depth in estimation', default=10)
36 | parser.add_argument('--checkpoint_path', type=str, help='path to a specific checkpoint to load', default='')
37 | parser.add_argument('--dataset', type=str, help='dataset to train on', default='nyu')
38 | parser.add_argument('--do_kb_crop', help='if set, crop input images as kitti benchmark images', action='store_true')
39 | parser.add_argument('--save_viz', help='if set, save visualization of the outputs', action='store_true')
40 |
41 | if len(sys.argv) == 2:
42 | arg_filename_with_prefix = '@' + sys.argv[1]
43 | args = parser.parse_args([arg_filename_with_prefix])
44 | else:
45 | args = parser.parse_args()
46 |
47 | if args.dataset == 'kitti' or args.dataset == 'nyu':
48 | from dataloaders.dataloader import NewDataLoader
49 | elif args.dataset == 'kittipred':
50 | from dataloaders.dataloader_kittipred import NewDataLoader
51 |
52 | model_dir = os.path.dirname(args.checkpoint_path)
53 | sys.path.append(model_dir)
54 |
55 |
56 | def get_num_lines(file_path):
57 | f = open(file_path, 'r')
58 | lines = f.readlines()
59 | f.close()
60 | return len(lines)
61 |
62 |
63 | def test(params):
64 | """Test function."""
65 | args.mode = 'test'
66 | dataloader = NewDataLoader(args, 'test')
67 |
68 | model = NewCRFDepth(version='large07', inv_depth=False, max_depth=args.max_depth)
69 | model = torch.nn.DataParallel(model)
70 |
71 | checkpoint = torch.load(args.checkpoint_path)
72 | model.load_state_dict(checkpoint['model'])
73 | model.eval()
74 | model.cuda()
75 |
76 | num_params = sum([np.prod(p.size()) for p in model.parameters()])
77 | print("Total number of parameters: {}".format(num_params))
78 |
79 | num_test_samples = get_num_lines(args.filenames_file)
80 |
81 | with open(args.filenames_file) as f:
82 | lines = f.readlines()
83 |
84 | print('now testing {} files with {}'.format(num_test_samples, args.checkpoint_path))
85 |
86 | pred_depths = []
87 | start_time = time.time()
88 | with torch.no_grad():
89 | for _, sample in enumerate(tqdm(dataloader.data)):
90 | image = Variable(sample['image'].cuda())
91 | # Predict
92 | depth_est = model(image)
93 | post_process = True
94 | if post_process:
95 | image_flipped = flip_lr(image)
96 | depth_est_flipped = model(image_flipped)
97 | depth_est = post_process_depth(depth_est, depth_est_flipped)
98 |
99 | pred_depth = depth_est.cpu().numpy().squeeze()
100 |
101 | if args.do_kb_crop:
102 | height, width = 352, 1216
103 | top_margin = int(height - 352)
104 | left_margin = int((width - 1216) / 2)
105 | pred_depth_uncropped = np.zeros((height, width), dtype=np.float32)
106 | pred_depth_uncropped[top_margin:top_margin + 352, left_margin:left_margin + 1216] = pred_depth
107 | pred_depth = pred_depth_uncropped
108 |
109 | pred_depths.append(pred_depth)
110 |
111 | elapsed_time = time.time() - start_time
112 |     print('Elapsed time: %s' % str(elapsed_time))
113 | print('Done.')
114 |
115 | save_name = 'models/result_' + args.model_name
116 |
117 | print('Saving result pngs..')
118 | if not os.path.exists(save_name):
119 | try:
120 | os.mkdir(save_name)
121 | os.mkdir(save_name + '/raw')
122 | os.mkdir(save_name + '/cmap')
123 | os.mkdir(save_name + '/rgb')
124 | os.mkdir(save_name + '/gt')
125 | except OSError as e:
126 | if e.errno != errno.EEXIST:
127 | raise
128 |
129 | for s in tqdm(range(num_test_samples)):
130 | if args.dataset == 'kitti':
131 | date_drive = lines[s].split('/')[1]
132 | filename_pred_png = save_name + '/raw/' + date_drive + '_' + lines[s].split()[0].split('/')[-1].replace(
133 | '.jpg', '.png')
134 | filename_cmap_png = save_name + '/cmap/' + date_drive + '_' + lines[s].split()[0].split('/')[
135 | -1].replace('.jpg', '.png')
136 | filename_image_png = save_name + '/rgb/' + date_drive + '_' + lines[s].split()[0].split('/')[-1]
137 | elif args.dataset == 'kittipred':
138 | filename_pred_png = save_name + '/raw/' + lines[s].split()[0].split('/')[-1].replace('.jpg', '.png')
139 | filename_cmap_png = save_name + '/cmap/' + lines[s].split()[0].split('/')[-1].replace('.jpg', '.png')
140 | filename_image_png = save_name + '/rgb/' + lines[s].split()[0].split('/')[-1]
141 | else:
142 | scene_name = lines[s].split()[0].split('/')[0]
143 | filename_pred_png = save_name + '/raw/' + scene_name + '_' + lines[s].split()[0].split('/')[1].replace(
144 | '.jpg', '.png')
145 | filename_cmap_png = save_name + '/cmap/' + scene_name + '_' + lines[s].split()[0].split('/rgb_')[1].replace(
146 | '.jpg', '.png')
147 | filename_gt_png = save_name + '/gt/' + scene_name + '_' + lines[s].split()[0].split('/rgb_')[1].replace(
148 | '.jpg', '_gt.png')
149 | filename_image_png = save_name + '/rgb/' + scene_name + '_' + lines[s].split()[0].split('/rgb_')[1]
150 |
151 | rgb_path = os.path.join(args.data_path, './' + lines[s].split()[0])
152 | image = cv2.imread(rgb_path)
153 | if args.dataset == 'nyu':
154 | gt_path = os.path.join(args.data_path, './' + lines[s].split()[1])
155 | gt = cv2.imread(gt_path, -1).astype(np.float32) / 1000.0 # Visualization purpose only
156 | gt[gt == 0] = np.amax(gt)
157 |
158 | pred_depth = pred_depths[s]
159 |
160 | if args.dataset == 'kitti' or args.dataset == 'kittipred':
161 | pred_depth_scaled = pred_depth * 256.0
162 | else:
163 | pred_depth_scaled = pred_depth * 1000.0
164 |
165 | pred_depth_scaled = pred_depth_scaled.astype(np.uint16)
166 | cv2.imwrite(filename_pred_png, pred_depth_scaled, [cv2.IMWRITE_PNG_COMPRESSION, 0])
167 |
168 | if args.save_viz:
169 | cv2.imwrite(filename_image_png, image[10:-1 - 9, 10:-1 - 9, :])
170 | if args.dataset == 'nyu':
171 | plt.imsave(filename_gt_png, (10 - gt) / 10, cmap='plasma')
172 | pred_depth_cropped = pred_depth[10:-1 - 9, 10:-1 - 9]
173 |                 plt.imsave(filename_cmap_png, (10 - pred_depth_cropped) / 10, cmap='plasma')
174 | else:
175 | plt.imsave(filename_cmap_png, np.log10(pred_depth), cmap='Greys')
176 |
177 | return
178 |
179 |
180 | if __name__ == '__main__':
181 | test(args)
182 |
--------------------------------------------------------------------------------
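For reference, the flip-based test-time augmentation applied inside `test()` above can be reproduced on its own. A minimal sketch, assuming a trained `model` callable and the `flip_lr` / `post_process_depth` helpers from `newcrfs/utils.py` (the bare `utils` import path assumes the script is run from the `newcrfs/` directory; the function name is illustrative only):

```python
import torch
from utils import flip_lr, post_process_depth  # helpers defined in newcrfs/utils.py

@torch.no_grad()
def predict_with_flip_tta(model, image):
    """Fuse the prediction on an image with the prediction on its mirror image."""
    depth = model(image)                   # [B,1,H,W] depth from the original input
    depth_flipped = model(flip_lr(image))  # depth predicted on the mirrored input
    return post_process_depth(depth, depth_flipped)
```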
/newcrfs/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.utils as utils
4 | import torch.backends.cudnn as cudnn
5 | import torch.distributed as dist
6 | import torch.multiprocessing as mp
7 |
8 | import os, sys, time
9 |
10 | import argparse
11 | import numpy as np
12 | from tqdm import tqdm
13 |
14 | from tensorboardX import SummaryWriter
15 |
16 | from utils import post_process_depth, flip_lr, silog_loss, compute_errors, eval_metrics, \
17 | block_print, enable_print, normalize_result, inv_normalize, convert_arg_line_to_args
18 | from networks.NewCRFDepth import NewCRFDepth
19 |
20 |
21 | parser = argparse.ArgumentParser(description='NeWCRFs PyTorch implementation.', fromfile_prefix_chars='@')
22 | parser.convert_arg_line_to_args = convert_arg_line_to_args
23 |
24 | parser.add_argument('--mode', type=str, help='train or test', default='train')
25 | parser.add_argument('--model_name', type=str, help='model name', default='newcrfs')
26 | parser.add_argument('--encoder', type=str, help='type of encoder, base07, large07', default='large07')
27 | parser.add_argument('--pretrain', type=str, help='path of pretrained encoder', default=None)
28 |
29 | # Dataset
30 | parser.add_argument('--dataset', type=str, help='dataset to train on, kitti or nyu', default='nyu')
31 | parser.add_argument('--data_path', type=str, help='path to the data', required=True)
32 | parser.add_argument('--gt_path', type=str, help='path to the groundtruth data', required=True)
33 | parser.add_argument('--filenames_file', type=str, help='path to the filenames text file', required=True)
34 | parser.add_argument('--input_height', type=int, help='input height', default=480)
35 | parser.add_argument('--input_width', type=int, help='input width', default=640)
36 | parser.add_argument('--max_depth', type=float, help='maximum depth in estimation', default=10)
37 |
38 | # Log and save
39 | parser.add_argument('--log_directory', type=str, help='directory to save checkpoints and summaries', default='')
40 | parser.add_argument('--checkpoint_path', type=str, help='path to a checkpoint to load', default='')
41 | parser.add_argument('--log_freq', type=int, help='Logging frequency in global steps', default=100)
42 | parser.add_argument('--save_freq', type=int, help='Checkpoint saving frequency in global steps', default=5000)
43 |
44 | # Training
45 | parser.add_argument('--weight_decay', type=float, help='weight decay factor for optimization', default=1e-2)
46 | parser.add_argument('--retrain', help='if used with checkpoint_path, will restart training from step zero', action='store_true')
47 | parser.add_argument('--adam_eps', type=float, help='epsilon in Adam optimizer', default=1e-6)
48 | parser.add_argument('--batch_size', type=int, help='batch size', default=4)
49 | parser.add_argument('--num_epochs', type=int, help='number of epochs', default=50)
50 | parser.add_argument('--learning_rate', type=float, help='initial learning rate', default=1e-4)
51 | parser.add_argument('--end_learning_rate', type=float, help='end learning rate', default=-1)
52 | parser.add_argument('--variance_focus', type=float, help='lambda in paper: [0, 1], higher value more focus on minimizing variance of error', default=0.85)
53 |
54 | # Preprocessing
55 | parser.add_argument('--do_random_rotate', help='if set, will perform random rotation for augmentation', action='store_true')
56 | parser.add_argument('--degree', type=float, help='random rotation maximum degree', default=2.5)
57 | parser.add_argument('--do_kb_crop', help='if set, crop input images as in the KITTI benchmark', action='store_true')
58 | parser.add_argument('--use_right', help='if set, randomly use right-camera images when training on KITTI', action='store_true')
59 |
60 | # Multi-gpu training
61 | parser.add_argument('--num_threads', type=int, help='number of threads to use for data loading', default=1)
62 | parser.add_argument('--world_size', type=int, help='number of nodes for distributed training', default=1)
63 | parser.add_argument('--rank', type=int, help='node rank for distributed training', default=0)
64 | parser.add_argument('--dist_url', type=str, help='url used to set up distributed training', default='tcp://127.0.0.1:1234')
65 | parser.add_argument('--dist_backend', type=str, help='distributed backend', default='nccl')
66 | parser.add_argument('--gpu', type=int, help='GPU id to use.', default=None)
67 | parser.add_argument('--multiprocessing_distributed', help='Use multi-processing distributed training to launch '
68 | 'N processes per node, which has N GPUs. This is the '
69 | 'fastest way to use PyTorch for either single node or '
70 | 'multi node data parallel training', action='store_true',)
71 | # Online eval
72 | parser.add_argument('--do_online_eval', help='if set, perform online evaluation every eval_freq steps', action='store_true')
73 | parser.add_argument('--data_path_eval', type=str, help='path to the data for online evaluation', required=False)
74 | parser.add_argument('--gt_path_eval', type=str, help='path to the groundtruth data for online evaluation', required=False)
75 | parser.add_argument('--filenames_file_eval', type=str, help='path to the filenames text file for online evaluation', required=False)
76 | parser.add_argument('--min_depth_eval', type=float, help='minimum depth for evaluation', default=1e-3)
77 | parser.add_argument('--max_depth_eval', type=float, help='maximum depth for evaluation', default=80)
78 | parser.add_argument('--eigen_crop', help='if set, crops according to Eigen NIPS14', action='store_true')
79 | parser.add_argument('--garg_crop', help='if set, crops according to Garg ECCV16', action='store_true')
80 | parser.add_argument('--eval_freq', type=int, help='Online evaluation frequency in global steps', default=500)
81 | parser.add_argument('--eval_summary_directory', type=str, help='output directory for eval summary, '
82 | 'if empty outputs to checkpoint folder', default='')
83 |
84 | if sys.argv.__len__() == 2:
85 | arg_filename_with_prefix = '@' + sys.argv[1]
86 | args = parser.parse_args([arg_filename_with_prefix])
87 | else:
88 | args = parser.parse_args()
89 |
90 | if args.dataset == 'kitti' or args.dataset == 'nyu':
91 | from dataloaders.dataloader import NewDataLoader
92 | elif args.dataset == 'kittipred':
93 | from dataloaders.dataloader_kittipred import NewDataLoader
94 |
95 |
96 | def online_eval(model, dataloader_eval, gpu, ngpus, post_process=False):
97 | eval_measures = torch.zeros(10).cuda(device=gpu)
98 | for _, eval_sample_batched in enumerate(tqdm(dataloader_eval.data)):
99 | with torch.no_grad():
100 | image = torch.autograd.Variable(eval_sample_batched['image'].cuda(gpu, non_blocking=True))
101 | gt_depth = eval_sample_batched['depth']
102 | has_valid_depth = eval_sample_batched['has_valid_depth']
103 | if not has_valid_depth:
104 | # print('Invalid depth. continue.')
105 | continue
106 |
107 | pred_depth = model(image)
108 | if post_process:
109 | image_flipped = flip_lr(image)
110 | pred_depth_flipped = model(image_flipped)
111 | pred_depth = post_process_depth(pred_depth, pred_depth_flipped)
112 |
113 | pred_depth = pred_depth.cpu().numpy().squeeze()
114 | gt_depth = gt_depth.cpu().numpy().squeeze()
115 |
116 | if args.do_kb_crop:
117 | height, width = gt_depth.shape
118 | top_margin = int(height - 352)
119 | left_margin = int((width - 1216) / 2)
120 | pred_depth_uncropped = np.zeros((height, width), dtype=np.float32)
121 | pred_depth_uncropped[top_margin:top_margin + 352, left_margin:left_margin + 1216] = pred_depth
122 | pred_depth = pred_depth_uncropped
123 |
124 | pred_depth[pred_depth < args.min_depth_eval] = args.min_depth_eval
125 | pred_depth[pred_depth > args.max_depth_eval] = args.max_depth_eval
126 | pred_depth[np.isinf(pred_depth)] = args.max_depth_eval
127 | pred_depth[np.isnan(pred_depth)] = args.min_depth_eval
128 |
129 | valid_mask = np.logical_and(gt_depth > args.min_depth_eval, gt_depth < args.max_depth_eval)
130 |
131 | if args.garg_crop or args.eigen_crop:
132 | gt_height, gt_width = gt_depth.shape
133 | eval_mask = np.zeros(valid_mask.shape)
134 |
135 | if args.garg_crop:
136 | eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height), int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1
137 |
138 | elif args.eigen_crop:
139 | if args.dataset == 'kitti':
140 | eval_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height), int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = 1
141 | elif args.dataset == 'nyu':
142 | eval_mask[45:471, 41:601] = 1
143 |
144 | valid_mask = np.logical_and(valid_mask, eval_mask)
145 |
146 | measures = compute_errors(gt_depth[valid_mask], pred_depth[valid_mask])
147 |
148 | eval_measures[:9] += torch.tensor(measures).cuda(device=gpu)
149 | eval_measures[9] += 1
150 |
151 | if args.multiprocessing_distributed:
152 | group = dist.new_group([i for i in range(ngpus)])
153 | dist.all_reduce(tensor=eval_measures, op=dist.ReduceOp.SUM, group=group)
154 |
155 | if not args.multiprocessing_distributed or gpu == 0:
156 | eval_measures_cpu = eval_measures.cpu()
157 | cnt = eval_measures_cpu[9].item()
158 | eval_measures_cpu /= cnt
159 |         print('Computing errors for {} eval samples, post_process: {}'.format(int(cnt), post_process))
160 | print("{:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}".format('silog', 'abs_rel', 'log10', 'rms',
161 | 'sq_rel', 'log_rms', 'd1', 'd2',
162 | 'd3'))
163 | for i in range(8):
164 | print('{:7.4f}, '.format(eval_measures_cpu[i]), end='')
165 | print('{:7.4f}'.format(eval_measures_cpu[8]))
166 | return eval_measures_cpu
167 |
168 | return None
169 |
170 |
171 | def main_worker(gpu, ngpus_per_node, args):
172 | args.gpu = gpu
173 |
174 | if args.gpu is not None:
175 | print("== Use GPU: {} for training".format(args.gpu))
176 |
177 | if args.distributed:
178 | if args.dist_url == "env://" and args.rank == -1:
179 | args.rank = int(os.environ["RANK"])
180 | if args.multiprocessing_distributed:
181 | args.rank = args.rank * ngpus_per_node + gpu
182 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank)
183 |
184 | # NeWCRFs model
185 | model = NewCRFDepth(version=args.encoder, inv_depth=False, max_depth=args.max_depth, pretrained=args.pretrain)
186 | model.train()
187 |
188 | num_params = sum([np.prod(p.size()) for p in model.parameters()])
189 | print("== Total number of parameters: {}".format(num_params))
190 |
191 | num_params_update = sum([np.prod(p.shape) for p in model.parameters() if p.requires_grad])
192 | print("== Total number of learning parameters: {}".format(num_params_update))
193 |
194 | if args.distributed:
195 | if args.gpu is not None:
196 | torch.cuda.set_device(args.gpu)
197 | model.cuda(args.gpu)
198 | args.batch_size = int(args.batch_size / ngpus_per_node)
199 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
200 | else:
201 | model.cuda()
202 | model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
203 | else:
204 | model = torch.nn.DataParallel(model)
205 | model.cuda()
206 |
207 | if args.distributed:
208 | print("== Model Initialized on GPU: {}".format(args.gpu))
209 | else:
210 | print("== Model Initialized")
211 |
212 | global_step = 0
213 | best_eval_measures_lower_better = torch.zeros(6).cpu() + 1e3
214 | best_eval_measures_higher_better = torch.zeros(3).cpu()
215 | best_eval_steps = np.zeros(9, dtype=np.int32)
216 |
217 | # Training parameters
218 | optimizer = torch.optim.Adam([{'params': model.module.parameters()}],
219 | lr=args.learning_rate)
220 |
221 | model_just_loaded = False
222 | if args.checkpoint_path != '':
223 | if os.path.isfile(args.checkpoint_path):
224 | print("== Loading checkpoint '{}'".format(args.checkpoint_path))
225 | if args.gpu is None:
226 | checkpoint = torch.load(args.checkpoint_path)
227 | else:
228 | loc = 'cuda:{}'.format(args.gpu)
229 | checkpoint = torch.load(args.checkpoint_path, map_location=loc)
230 | model.load_state_dict(checkpoint['model'])
231 | optimizer.load_state_dict(checkpoint['optimizer'])
232 | if not args.retrain:
233 | try:
234 | global_step = checkpoint['global_step']
235 | best_eval_measures_higher_better = checkpoint['best_eval_measures_higher_better'].cpu()
236 | best_eval_measures_lower_better = checkpoint['best_eval_measures_lower_better'].cpu()
237 | best_eval_steps = checkpoint['best_eval_steps']
238 | except KeyError:
239 | print("Could not load values for online evaluation")
240 |
241 | print("== Loaded checkpoint '{}' (global_step {})".format(args.checkpoint_path, checkpoint['global_step']))
242 | else:
243 | print("== No checkpoint found at '{}'".format(args.checkpoint_path))
244 | model_just_loaded = True
245 | del checkpoint
246 |
247 | cudnn.benchmark = True
248 |
249 | dataloader = NewDataLoader(args, 'train')
250 | dataloader_eval = NewDataLoader(args, 'online_eval')
251 |
252 | # ===== Evaluation before training ======
253 | # model.eval()
254 | # with torch.no_grad():
255 | # eval_measures = online_eval(model, dataloader_eval, gpu, ngpus_per_node, post_process=True)
256 |
257 | # Logging
258 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
259 | writer = SummaryWriter(args.log_directory + '/' + args.model_name + '/summaries', flush_secs=30)
260 | if args.do_online_eval:
261 | if args.eval_summary_directory != '':
262 | eval_summary_path = os.path.join(args.eval_summary_directory, args.model_name)
263 | else:
264 | eval_summary_path = os.path.join(args.log_directory, args.model_name, 'eval')
265 | eval_summary_writer = SummaryWriter(eval_summary_path, flush_secs=30)
266 |
267 | silog_criterion = silog_loss(variance_focus=args.variance_focus)
268 |
269 | start_time = time.time()
270 | duration = 0
271 |
272 | num_log_images = args.batch_size
273 | end_learning_rate = args.end_learning_rate if args.end_learning_rate != -1 else 0.1 * args.learning_rate
274 |
275 | var_sum = [var.sum().item() for var in model.parameters() if var.requires_grad]
276 | var_cnt = len(var_sum)
277 | var_sum = np.sum(var_sum)
278 |
279 | print("== Initial variables' sum: {:.3f}, avg: {:.3f}".format(var_sum, var_sum/var_cnt))
280 |
281 | steps_per_epoch = len(dataloader.data)
282 | num_total_steps = args.num_epochs * steps_per_epoch
283 | epoch = global_step // steps_per_epoch
284 |
285 | while epoch < args.num_epochs:
286 | if args.distributed:
287 | dataloader.train_sampler.set_epoch(epoch)
288 |
289 | for step, sample_batched in enumerate(dataloader.data):
290 | optimizer.zero_grad()
291 | before_op_time = time.time()
292 |
293 | image = torch.autograd.Variable(sample_batched['image'].cuda(args.gpu, non_blocking=True))
294 | depth_gt = torch.autograd.Variable(sample_batched['depth'].cuda(args.gpu, non_blocking=True))
295 |
296 | depth_est = model(image)
297 |
298 | if args.dataset == 'nyu':
299 | mask = depth_gt > 0.1
300 | else:
301 | mask = depth_gt > 1.0
302 |
303 | loss = silog_criterion.forward(depth_est, depth_gt, mask.to(torch.bool))
304 | loss.backward()
305 | for param_group in optimizer.param_groups:
306 | current_lr = (args.learning_rate - end_learning_rate) * (1 - global_step / num_total_steps) ** 0.9 + end_learning_rate
307 | param_group['lr'] = current_lr
308 |
309 | optimizer.step()
310 |
311 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
312 | print('[epoch][s/s_per_e/gs]: [{}][{}/{}/{}], lr: {:.12f}, loss: {:.12f}'.format(epoch, step, steps_per_epoch, global_step, current_lr, loss))
313 | if np.isnan(loss.cpu().item()):
314 | print('NaN in loss occurred. Aborting training.')
315 | return -1
316 |
317 | duration += time.time() - before_op_time
318 | if global_step and global_step % args.log_freq == 0 and not model_just_loaded:
319 | var_sum = [var.sum().item() for var in model.parameters() if var.requires_grad]
320 | var_cnt = len(var_sum)
321 | var_sum = np.sum(var_sum)
322 | examples_per_sec = args.batch_size / duration * args.log_freq
323 | duration = 0
324 | time_sofar = (time.time() - start_time) / 3600
325 | training_time_left = (num_total_steps / global_step - 1.0) * time_sofar
326 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
327 | print("{}".format(args.model_name))
328 | print_string = 'GPU: {} | examples/s: {:4.2f} | loss: {:.5f} | var sum: {:.3f} avg: {:.3f} | time elapsed: {:.2f}h | time left: {:.2f}h'
329 | print(print_string.format(args.gpu, examples_per_sec, loss, var_sum.item(), var_sum.item()/var_cnt, time_sofar, training_time_left))
330 |
331 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed
332 | and args.rank % ngpus_per_node == 0):
333 | writer.add_scalar('silog_loss', loss, global_step)
334 | writer.add_scalar('learning_rate', current_lr, global_step)
335 | writer.add_scalar('var average', var_sum.item()/var_cnt, global_step)
336 | depth_gt = torch.where(depth_gt < 1e-3, depth_gt * 0 + 1e3, depth_gt)
337 | for i in range(num_log_images):
338 | writer.add_image('depth_gt/image/{}'.format(i), normalize_result(1/depth_gt[i, :, :, :].data), global_step)
339 | writer.add_image('depth_est/image/{}'.format(i), normalize_result(1/depth_est[i, :, :, :].data), global_step)
340 | writer.add_image('image/image/{}'.format(i), inv_normalize(image[i, :, :, :]).data, global_step)
341 | writer.flush()
342 |
343 | if args.do_online_eval and global_step and global_step % args.eval_freq == 0 and not model_just_loaded:
344 | time.sleep(0.1)
345 | model.eval()
346 | with torch.no_grad():
347 | eval_measures = online_eval(model, dataloader_eval, gpu, ngpus_per_node, post_process=True)
348 | if eval_measures is not None:
349 | for i in range(9):
350 | eval_summary_writer.add_scalar(eval_metrics[i], eval_measures[i].cpu(), int(global_step))
351 | measure = eval_measures[i]
352 | is_best = False
353 | if i < 6 and measure < best_eval_measures_lower_better[i]:
354 | old_best = best_eval_measures_lower_better[i].item()
355 | best_eval_measures_lower_better[i] = measure.item()
356 | is_best = True
357 | elif i >= 6 and measure > best_eval_measures_higher_better[i-6]:
358 | old_best = best_eval_measures_higher_better[i-6].item()
359 | best_eval_measures_higher_better[i-6] = measure.item()
360 | is_best = True
361 | if is_best:
362 | old_best_step = best_eval_steps[i]
363 | old_best_name = '/model-{}-best_{}_{:.5f}'.format(old_best_step, eval_metrics[i], old_best)
364 | model_path = args.log_directory + '/' + args.model_name + old_best_name
365 | if os.path.exists(model_path):
366 | command = 'rm {}'.format(model_path)
367 | os.system(command)
368 | best_eval_steps[i] = global_step
369 | model_save_name = '/model-{}-best_{}_{:.5f}'.format(global_step, eval_metrics[i], measure)
370 | print('New best for {}. Saving model: {}'.format(eval_metrics[i], model_save_name))
371 | checkpoint = {'global_step': global_step,
372 | 'model': model.state_dict(),
373 | 'optimizer': optimizer.state_dict(),
374 | 'best_eval_measures_higher_better': best_eval_measures_higher_better,
375 | 'best_eval_measures_lower_better': best_eval_measures_lower_better,
376 | 'best_eval_steps': best_eval_steps
377 | }
378 | torch.save(checkpoint, args.log_directory + '/' + args.model_name + model_save_name)
379 | eval_summary_writer.flush()
380 | model.train()
381 | block_print()
382 | enable_print()
383 |
384 | model_just_loaded = False
385 | global_step += 1
386 |
387 | epoch += 1
388 |
389 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
390 | writer.close()
391 | if args.do_online_eval:
392 | eval_summary_writer.close()
393 |
394 |
395 | def main():
396 | if args.mode != 'train':
397 | print('train.py is only for training.')
398 | return -1
399 |
400 | command = 'mkdir ' + os.path.join(args.log_directory, args.model_name)
401 | os.system(command)
402 |
403 | args_out_path = os.path.join(args.log_directory, args.model_name)
404 | command = 'cp ' + sys.argv[1] + ' ' + args_out_path
405 | os.system(command)
406 |
407 | save_files = True
408 | if save_files:
409 | aux_out_path = os.path.join(args.log_directory, args.model_name)
410 | networks_savepath = os.path.join(aux_out_path, 'networks')
411 | dataloaders_savepath = os.path.join(aux_out_path, 'dataloaders')
412 | command = 'cp newcrfs/train.py ' + aux_out_path
413 | os.system(command)
414 | command = 'mkdir -p ' + networks_savepath + ' && cp newcrfs/networks/*.py ' + networks_savepath
415 | os.system(command)
416 | command = 'mkdir -p ' + dataloaders_savepath + ' && cp newcrfs/dataloaders/*.py ' + dataloaders_savepath
417 | os.system(command)
418 |
419 | torch.cuda.empty_cache()
420 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed
421 |
422 | ngpus_per_node = torch.cuda.device_count()
423 | if ngpus_per_node > 1 and not args.multiprocessing_distributed:
424 | print("This machine has more than 1 gpu. Please specify --multiprocessing_distributed, or set \'CUDA_VISIBLE_DEVICES=0\'")
425 | return -1
426 |
427 | if args.do_online_eval:
428 | print("You have specified --do_online_eval.")
429 | print("This will evaluate the model every eval_freq {} steps and save best models for individual eval metrics."
430 | .format(args.eval_freq))
431 |
432 | if args.multiprocessing_distributed:
433 | args.world_size = ngpus_per_node * args.world_size
434 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
435 | else:
436 | main_worker(args.gpu, ngpus_per_node, args)
437 |
438 |
439 | if __name__ == '__main__':
440 | main()
441 |
--------------------------------------------------------------------------------
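Two implementation details of `train.py` worth calling out. First, because the parser is built with `fromfile_prefix_chars='@'` and `main()` prepends `@` when exactly one command-line argument is given, training is launched by passing a single argument file, e.g. `python newcrfs/train.py <arguments_file.txt>`. Second, the learning rate follows a polynomial decay with power 0.9 from `--learning_rate` down to `--end_learning_rate` (or one tenth of the initial rate when the latter is left at -1). A minimal sketch of that schedule, factored into a standalone function for clarity (the function name is illustrative, not part of the code base):

```python
def poly_decay_lr(base_lr, end_lr, global_step, num_total_steps, power=0.9):
    """Polynomial decay as applied inside the training loop of train.py."""
    return (base_lr - end_lr) * (1.0 - global_step / num_total_steps) ** power + end_lr

# Example with the defaults (base_lr=1e-4, end_lr resolved to 1e-5): halfway through
# training the rate is (1e-4 - 1e-5) * 0.5**0.9 + 1e-5, roughly 5.8e-5.
```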
/newcrfs/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.distributed as dist
4 | from torch.utils.data import Sampler
5 | from torchvision import transforms
6 |
7 | import os, sys
8 | import numpy as np
9 | import math
10 | import matplotlib.cm  # used by colorize()
11 |
12 |
13 | def convert_arg_line_to_args(arg_line):
14 | for arg in arg_line.split():
15 | if not arg.strip():
16 | continue
17 | yield arg
18 |
19 |
20 | def block_print():
21 | sys.stdout = open(os.devnull, 'w')
22 |
23 |
24 | def enable_print():
25 | sys.stdout = sys.__stdout__
26 |
27 |
28 | def get_num_lines(file_path):
29 | f = open(file_path, 'r')
30 | lines = f.readlines()
31 | f.close()
32 | return len(lines)
33 |
34 |
35 | def colorize(value, vmin=None, vmax=None, cmap='Greys'):
36 | value = value.cpu().numpy()[:, :, :]
37 | value = np.log10(value)
38 |
39 | vmin = value.min() if vmin is None else vmin
40 | vmax = value.max() if vmax is None else vmax
41 |
42 | if vmin != vmax:
43 | value = (value - vmin) / (vmax - vmin)
44 | else:
45 | value = value*0.
46 |
47 | cmapper = matplotlib.cm.get_cmap(cmap)
48 | value = cmapper(value, bytes=True)
49 |
50 | img = value[:, :, :3]
51 |
52 | return img.transpose((2, 0, 1))
53 |
54 |
55 | def normalize_result(value, vmin=None, vmax=None):
56 | value = value.cpu().numpy()[0, :, :]
57 |
58 | vmin = value.min() if vmin is None else vmin
59 | vmax = value.max() if vmax is None else vmax
60 |
61 | if vmin != vmax:
62 | value = (value - vmin) / (vmax - vmin)
63 | else:
64 | value = value * 0.
65 |
66 | return np.expand_dims(value, 0)
67 |
68 |
69 | inv_normalize = transforms.Normalize(
70 | mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
71 | std=[1/0.229, 1/0.224, 1/0.225]
72 | )
73 |
74 |
75 | eval_metrics = ['silog', 'abs_rel', 'log10', 'rms', 'sq_rel', 'log_rms', 'd1', 'd2', 'd3']
76 |
77 |
78 | def compute_errors(gt, pred):
79 | thresh = np.maximum((gt / pred), (pred / gt))
80 | d1 = (thresh < 1.25).mean()
81 | d2 = (thresh < 1.25 ** 2).mean()
82 | d3 = (thresh < 1.25 ** 3).mean()
83 |
84 | rms = (gt - pred) ** 2
85 | rms = np.sqrt(rms.mean())
86 |
87 | log_rms = (np.log(gt) - np.log(pred)) ** 2
88 | log_rms = np.sqrt(log_rms.mean())
89 |
90 | abs_rel = np.mean(np.abs(gt - pred) / gt)
91 | sq_rel = np.mean(((gt - pred) ** 2) / gt)
92 |
93 | err = np.log(pred) - np.log(gt)
94 | silog = np.sqrt(np.mean(err ** 2) - np.mean(err) ** 2) * 100
95 |
96 | err = np.abs(np.log10(pred) - np.log10(gt))
97 | log10 = np.mean(err)
98 |
99 | return [silog, abs_rel, log10, rms, sq_rel, log_rms, d1, d2, d3]
100 |
101 |
102 | class silog_loss(nn.Module):
103 | def __init__(self, variance_focus):
104 | super(silog_loss, self).__init__()
105 | self.variance_focus = variance_focus
106 |
107 | def forward(self, depth_est, depth_gt, mask):
108 | d = torch.log(depth_est[mask]) - torch.log(depth_gt[mask])
109 | return torch.sqrt((d ** 2).mean() - self.variance_focus * (d.mean() ** 2)) * 10.0
110 |
111 |
112 | def flip_lr(image):
113 | """
114 | Flip image horizontally
115 |
116 | Parameters
117 | ----------
118 | image : torch.Tensor [B,3,H,W]
119 | Image to be flipped
120 |
121 | Returns
122 | -------
123 | image_flipped : torch.Tensor [B,3,H,W]
124 | Flipped image
125 | """
126 | assert image.dim() == 4, 'You need to provide a [B,C,H,W] image to flip'
127 | return torch.flip(image, [3])
128 |
129 |
130 | def fuse_inv_depth(inv_depth, inv_depth_hat, method='mean'):
131 | """
132 | Fuse inverse depth and flipped inverse depth maps
133 |
134 | Parameters
135 | ----------
136 | inv_depth : torch.Tensor [B,1,H,W]
137 | Inverse depth map
138 | inv_depth_hat : torch.Tensor [B,1,H,W]
139 | Flipped inverse depth map produced from a flipped image
140 | method : str
141 | Method that will be used to fuse the inverse depth maps
142 |
143 | Returns
144 | -------
145 | fused_inv_depth : torch.Tensor [B,1,H,W]
146 | Fused inverse depth map
147 | """
148 | if method == 'mean':
149 | return 0.5 * (inv_depth + inv_depth_hat)
150 | elif method == 'max':
151 | return torch.max(inv_depth, inv_depth_hat)
152 | elif method == 'min':
153 | return torch.min(inv_depth, inv_depth_hat)
154 | else:
155 | raise ValueError('Unknown post-process method {}'.format(method))
156 |
157 |
158 | def post_process_depth(depth, depth_flipped, method='mean'):
159 | """
160 |     Post-process a depth map together with the prediction from a horizontally flipped input
161 |
162 | Parameters
163 | ----------
164 |     depth : torch.Tensor [B,1,H,W]
165 |         Depth map
166 |     depth_flipped : torch.Tensor [B,1,H,W]
167 |         Depth map produced from a horizontally flipped image
168 |     method : str
169 |         Method that will be used to fuse the two depth maps
170 |
171 | Returns
172 | -------
173 |     depth_pp : torch.Tensor [B,1,H,W]
174 |         Post-processed (flip-fused) depth map
175 | """
176 | B, C, H, W = depth.shape
177 | inv_depth_hat = flip_lr(depth_flipped)
178 | inv_depth_fused = fuse_inv_depth(depth, inv_depth_hat, method=method)
179 | xs = torch.linspace(0., 1., W, device=depth.device,
180 | dtype=depth.dtype).repeat(B, C, H, 1)
181 | mask = 1.0 - torch.clamp(20. * (xs - 0.05), 0., 1.)
182 | mask_hat = flip_lr(mask)
183 | return mask_hat * depth + mask * inv_depth_hat + \
184 | (1.0 - mask - mask_hat) * inv_depth_fused
185 |
186 |
187 | class DistributedSamplerNoEvenlyDivisible(Sampler):
188 | """Sampler that restricts data loading to a subset of the dataset.
189 |
190 | It is especially useful in conjunction with
191 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
192 | process can pass a DistributedSampler instance as a DataLoader sampler,
193 | and load a subset of the original dataset that is exclusive to it.
194 |
195 | .. note::
196 | Dataset is assumed to be of constant size.
197 |
198 | Arguments:
199 | dataset: Dataset used for sampling.
200 | num_replicas (optional): Number of processes participating in
201 | distributed training.
202 | rank (optional): Rank of the current process within num_replicas.
203 | shuffle (optional): If true (default), sampler will shuffle the indices
204 | """
205 |
206 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
207 | if num_replicas is None:
208 | if not dist.is_available():
209 | raise RuntimeError("Requires distributed package to be available")
210 | num_replicas = dist.get_world_size()
211 | if rank is None:
212 | if not dist.is_available():
213 | raise RuntimeError("Requires distributed package to be available")
214 | rank = dist.get_rank()
215 | self.dataset = dataset
216 | self.num_replicas = num_replicas
217 | self.rank = rank
218 | self.epoch = 0
219 | num_samples = int(math.floor(len(self.dataset) * 1.0 / self.num_replicas))
220 | rest = len(self.dataset) - num_samples * self.num_replicas
221 | if self.rank < rest:
222 | num_samples += 1
223 | self.num_samples = num_samples
224 | self.total_size = len(dataset)
225 | # self.total_size = self.num_samples * self.num_replicas
226 | self.shuffle = shuffle
227 |
228 | def __iter__(self):
229 | # deterministically shuffle based on epoch
230 | g = torch.Generator()
231 | g.manual_seed(self.epoch)
232 | if self.shuffle:
233 | indices = torch.randperm(len(self.dataset), generator=g).tolist()
234 | else:
235 | indices = list(range(len(self.dataset)))
236 |
237 | # add extra samples to make it evenly divisible
238 | # indices += indices[:(self.total_size - len(indices))]
239 | # assert len(indices) == self.total_size
240 |
241 | # subsample
242 | indices = indices[self.rank:self.total_size:self.num_replicas]
243 | self.num_samples = len(indices)
244 | # assert len(indices) == self.num_samples
245 |
246 | return iter(indices)
247 |
248 | def __len__(self):
249 | return self.num_samples
250 |
251 | def set_epoch(self, epoch):
252 | self.epoch = epoch
--------------------------------------------------------------------------------
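To make the intended use of the metric and loss helpers in `utils.py` explicit, here is a minimal, self-contained sketch; the random arrays stand in for real data, and the bare `utils` import path again assumes the working directory is `newcrfs/`:

```python
import numpy as np
import torch
from utils import silog_loss, compute_errors, eval_metrics  # from newcrfs/utils.py

# compute_errors returns the nine metrics in the same order as eval_metrics.
gt = np.random.uniform(0.5, 10.0, size=(480, 640)).astype(np.float32)
pred = gt * np.random.uniform(0.9, 1.1, size=gt.shape).astype(np.float32)
valid = np.logical_and(gt > 1e-3, gt < 10.0)  # mimic the min/max depth masking
for name, value in zip(eval_metrics, compute_errors(gt[valid], pred[valid])):
    print('{:>8}: {:.4f}'.format(name, value))

# silog_loss expects strictly positive depths and a boolean validity mask,
# as in the training loop of train.py.
criterion = silog_loss(variance_focus=0.85)
depth_est = torch.rand(2, 1, 480, 640) * 9.0 + 1.0
depth_gt = torch.rand(2, 1, 480, 640) * 9.0 + 1.0
loss = criterion(depth_est, depth_gt, depth_gt > 0.1)
print('silog loss: {:.4f}'.format(loss.item()))
```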