├── LICENSE
├── README.md
├── assets
│   ├── pointcloud2.png
│   └── visual_compare.png
├── configs
│   ├── DDAD.conf
│   ├── base.conf
│   └── kitti.conf
├── data_split
│   ├── kitti_eigen_test.txt
│   ├── kitti_eigen_train.txt
│   └── kitti_eigen_val.txt
├── datasets
│   ├── DDAD.py
│   ├── DDAD_crop.py
│   ├── DDAD_forward.py
│   ├── __pycache__
│   │   ├── DDAD.cpython-37.pyc
│   │   ├── DDAD_crop.cpython-37.pyc
│   │   ├── DDAD_forward.cpython-37.pyc
│   │   ├── kitti.cpython-37.pyc
│   │   ├── kitti.cpython-39.pyc
│   │   ├── kitti_odometry.cpython-37.pyc
│   │   └── kitti_odometry.cpython-39.pyc
│   ├── kitti.py
│   └── kitti_odometry.py
├── eval_ddad.py
├── eval_kitti.py
├── generate_dynamic_mask.py
├── hybrid_evaluate_depth.py
├── networks
│   ├── AFNet.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── AFNet.cpython-37.pyc
│   │   ├── AFNet.cpython-39.pyc
│   │   ├── AFNet_efficient.cpython-37.pyc
│   │   ├── AFNet_main.cpython-37.pyc
│   │   ├── AFNet_mobile.cpython-37.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── __init__.cpython-39.pyc
│   │   ├── module.cpython-37.pyc
│   │   ├── module.cpython-39.pyc
│   │   └── mvs2d.cpython-37.pyc
│   ├── mobilenet.py
│   └── module.py
├── options.py
├── requirements.txt
├── scripts
│   ├── test.sh
│   └── train.sh
├── train_af.py
├── train_kitti.py
├── trainer_base_af.py
├── trainer_base_kitti.py
├── utils.py
└── visual_ddad.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 cjd24-coder
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # AFNet: Adaptive Fusion of Single-View and Multi-View Depth for Autonomous Driving
3 | **CVPR 2024**
4 |
5 | [Paper](https://arxiv.org/pdf/2403.07535.pdf)
6 |
7 |
8 | This work presents AFNet, a new network that adaptively fuses single-view and multi-view depth to alleviate the defects of existing multi-view methods, which fail under noisy poses in real-world autonomous driving scenarios.
9 |
10 | 
11 |
12 |
13 | ## ✏️ Changelog
14 | ### Mar. 20 2024
15 | * Initial release. Due to a confidentiality agreement, the accuracy of the currently reproduced model on KITTI differs very slightly from that reported in the paper. We release this initial version first; the final version will be released soon.
16 |
17 | * In addition, the models trained under noisy poses will be released soon.
18 |
19 |
20 | ## ⚙️ Installation
21 |
22 | The code is tested with CUDA 11.7. Please use the following commands to install the dependencies:
23 |
24 | ```bash
25 | conda create --name AFNet python=3.7
26 | conda activate AFNet
27 | pip install -r requirements.txt
28 | ```
29 |
30 | ## 🎬 Demo
31 | 
32 |
33 |
34 | ## ⏳ Training & Testing
35 |
36 | We use 4 NVIDIA RTX 3090 GPUs for training. You may need to modify `CUDA_VISIBLE_DEVICES` and the batch size to fit your GPU resources.
37 |
38 | #### Training
39 | First download and extract the DDAD and KITTI data and the splits. Download and process the DDAD dataset following [DDAD🔗](https://github.com/TRI-ML/DDAD); the per-frame format expected by the loaders is sketched below.
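The loaders in `datasets/DDAD.py` expect, under `train/` and `val/`, one `.npz` file per frame named `<timestamp>_<CAMERA>.npz` containing `rgb` (BGR, uint8), `depth` (float32, metres), `pose` (4×4, inverted by the loader), and `intrinsics` (3×3), together with the `data_split/DDAD_video.json` split file listing `timestamp`, `timestamp_back`, `timestamp_forward`, and `Camera` for each sample. Below is a minimal sketch of writing one such frame; the field names are taken from the loader, and the official DDAD preprocessing may differ in details.

```python
import numpy as np

def save_ddad_frame(out_dir, timestamp, camera, rgb_bgr, depth, pose, intrinsics):
    """Write one sample in the layout read by datasets/DDAD.py (sketch only).
    rgb_bgr: HxWx3 uint8 (the loader converts BGR->RGB), depth: HxW float32 metric depth,
    pose: 4x4 matrix (the loader applies np.linalg.inv), intrinsics: 3x3 pinhole K."""
    np.savez(f"{out_dir}/{timestamp}_{camera}.npz",
             rgb=rgb_bgr, depth=depth, pose=pose, intrinsics=intrinsics)
```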
40 | #### Download
41 | [__split__ 🔗](https://1drv.ms/u/s!AtFfCZ2Ckf3DghYrBvQ-DWCQR1Nd?e=Q6qz8d) (move the JSON file from this split archive into the `data_split` directory)
42 | [models 🔗](https://1drv.ms/u/s!AtFfCZ2Ckf3DghVXXZY611mqxa8B?e=nZ7taR) (pretrained models for testing)
43 |
44 | Then run the following command to train our model.
45 | ```bash
46 | bash scripts/train.sh
47 | ```
48 |
49 | #### Testing
50 | First download and extract the data, the splits, and the pretrained models.
51 |
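Both evaluation scripts take a `--cfg` path to a HOCON-style `.conf` file; `configs/DDAD.conf` and `configs/kitti.conf` pull in the shared defaults via `include "./base.conf"`. As a minimal sketch of how such a file can be loaded, assuming the `pyhocon` package (the repository's actual option handling lives in `options.py`, which is not shown here):

```python
# Minimal sketch, assuming pyhocon: parse a .conf file together with its include'd base.conf.
from pyhocon import ConfigFactory

cfg = ConfigFactory.parse_file("./configs/kitti.conf")  # resolves include "./base.conf"
print(cfg["dataset"], cfg["height"], cfg["width"])      # kitti 352 1216
print(cfg.get_list("DECAY_STEP_LIST"))                  # [10, 20]
```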
52 | ### DDAD:
53 | Run:
54 | ```bash
55 | python eval_ddad.py --cfg "./configs/DDAD.conf"
56 | ```
57 |
58 | You should get results similar to these:
59 |
60 | | abs_rel | sq_rel | log10 | rmse | rmse_log | a1 | a2 | a3 | abs_diff |
61 | |---------|--------|-------|-------|----------|-------|-------|-------|----------|
62 | | 0.088 | 0.979 | 0.035 | 4.60 | 0.154 | 0.917 | 0.972 | 0.987 | 2.042 |
63 |
64 | ### KITTI:
65 | Run:
66 | ```bash
67 | python eval_kitti.py --cfg "./configs/kitti.conf"
68 | ```
69 | You should get results similar to these:
70 |
71 | | abs_rel | sq_rel | log10 | rmse | rmse_log | a1 | a2 | a3 | abs_diff |
72 | |---------|--------|-------|-------|----------|-------|-------|-------|----------|
73 | | 0.044 | 0.132 | 0.019 | 1.712 | 0.069 | 0.980 | 0.997 | 0.999 | 0.804 |
74 |
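The columns are the standard depth-evaluation metrics: the abs_rel, sq_rel, log10, RMSE, and log-RMSE errors, the δ < 1.25ᵏ accuracies a1/a2/a3, and the mean absolute error abs_diff. A minimal NumPy sketch of how these are commonly computed follows; the repository's own implementation is in `hybrid_evaluate_depth.py` and may differ in masking and scaling details.

```python
import numpy as np

def depth_metrics(gt, pred):
    """gt, pred: 1-D arrays of valid depths (already masked to the evaluation depth range)."""
    thresh = np.maximum(gt / pred, pred / gt)
    a1, a2, a3 = [(thresh < 1.25 ** k).mean() for k in (1, 2, 3)]
    abs_rel = np.mean(np.abs(gt - pred) / gt)
    sq_rel = np.mean(((gt - pred) ** 2) / gt)
    rmse = np.sqrt(np.mean((gt - pred) ** 2))
    rmse_log = np.sqrt(np.mean((np.log(gt) - np.log(pred)) ** 2))
    log10 = np.mean(np.abs(np.log10(gt) - np.log10(pred)))
    abs_diff = np.mean(np.abs(gt - pred))
    return dict(abs_rel=abs_rel, sq_rel=sq_rel, log10=log10, rmse=rmse,
                rmse_log=rmse_log, a1=a1, a2=a2, a3=a3, abs_diff=abs_diff)
```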
75 |
76 | ## Acknowledgement
77 | Thanks to Zhenpei Yang for open-sourcing his excellent work [MVS2D](https://github.com/zhenpeiyang/MVS2D?tab=readme-ov-file#nov-27-2021).
78 |
79 | ## Citation
80 |
81 | If you find this project useful, please consider citing:
82 |
83 | ```bibtex
84 | @misc{cheng2024adaptive,
85 | title={Adaptive Fusion of Single-View and Multi-View Depth for Autonomous Driving},
86 | author={JunDa Cheng and Wei Yin and Kaixuan Wang and Xiaozhi Chen and Shijie Wang and Xin Yang},
87 | year={2024},
88 | eprint={2403.07535},
89 | archivePrefix={arXiv},
90 | primaryClass={cs.CV}
91 | }
92 | ```
93 |
94 |
95 |
--------------------------------------------------------------------------------
/assets/pointcloud2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Junda24/AFNet/1d0416a2ec740519a8c464461962bebc238338bc/assets/pointcloud2.png
--------------------------------------------------------------------------------
/assets/visual_compare.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Junda24/AFNet/1d0416a2ec740519a8c464461962bebc238338bc/assets/visual_compare.png
--------------------------------------------------------------------------------
/configs/DDAD.conf:
--------------------------------------------------------------------------------
1 | include "./base.conf"
2 | dataset = DDAD
3 | data_path = /data/cjd/ddad/my_ddad/
4 | height = 448
5 | width = 768
6 | eval_height = 1216
7 | eval_width = 1920
8 | min_depth = 1.0
9 | max_depth = 80.0
10 | EVAL_MIN_DEPTH = 1.0
11 | EVAL_MAX_DEPTH = 80.0
12 | use_test = 0
13 | num_frame = 3
14 | num_frame_test = 3
15 | perturb_pose = 0
16 | LR = 1e-4
17 | fullsize_eval = 0
18 | DECAY_STEP_LIST = [10, 20]
19 | pred_conf = 1
20 | use_skip = 1
21 | num_epochs = 30
22 | loss_d = L1
23 | batch_size = 4
24 | nlabel = 96
25 | max_conf = 0.06
26 | min_conf = 2e-4
27 | multiprocessing_distributed = 1
28 | monitor_key = [abs_rel, thre1]
29 | monitor_goal = [minimize, maximize]
30 | num_workers = 16
31 | log_dir = /data/cjd/AFnet/
32 | model_name = AFNet_mobile
--------------------------------------------------------------------------------
/configs/base.conf:
--------------------------------------------------------------------------------
1 | data_path = ./data/ScanNet
2 | log_dir = experiments/test/
3 | model_name = test
4 | overwrite = True
5 | note = ""
6 | mode = train
7 |
8 | ## TRAINING OPTIONS
9 | num_workers = 4
10 | DECAY_STEP_LIST = [10, 20]
11 | LR = 2e-4
12 | LR_DECAY = 0.1
13 | LR_CLIP = 1e-6
14 | WEIGHT_DECAY = 0
15 | MOMENTUM = 0.9
16 | GRAD_NORM_CLIP = 10.0
17 | loss_d = L1
18 | launcher = pytorch
19 | tcp_port = 18891
20 | multiprocessing_distributed = 0
21 | dist_backend = nccl
22 | dist_url = tcp://127.0.0.1:1234
23 | num_epochs = 15
24 | monitor_key = ""
25 | optimizer = adam
26 | epoch_size = -1
27 | log_frequency = 10
28 | val_frequency = 100
29 | save_frequency = 1
30 | save_prediction = 0
31 | eval_vis = 0
32 | num_epoch = 30
33 | batch_size = 2
34 |
35 | ## MODEL OPTIONS
36 | pred_conf = 0
37 | use_skip = 1
38 | nlabel = 64
39 | mono = 0
40 | multi_view_agg = 1
41 | depth_embedding = learned
42 | robust = False
43 | unet_channel_mode = v0
44 | use_unet = True
45 | num_depth_regressor_anchor = 512
46 | att_rate = 4
47 | inv_depth = 1
48 | resnet_layers = [2,2,2,2]
49 | max_conf = -1
50 | clip_conf = True
51 | model_size = default
52 | fast = 0
53 | depth_head = default
54 | feat_dim = 32
55 | nhead = 1
56 |
57 | ## DATASET OPTIONS
58 | height = 640
59 | width = 480
60 | dataset = ScanNet
61 | input_scale = 0
62 | output_scale = 2
63 | min_depth = 0.1
64 | max_depth = 10.0
65 | seed = 0
66 | use_test = 0
67 | num_frame = 3
68 | perturb_pose = 0
69 | sceneID = None
70 | filter = None
71 | fullsize_eval = 0
72 | random_lighting = 0
73 | disable_color_aug = 0
74 |
75 |
76 | ## EVAL OPTIONS
77 | conf_thres = 3.5
78 | num_consistent = 1
79 | geo_pixel_thres = 0.75
80 | geo_depth_thres = 0.0001
81 | att_thres = 0.25
82 | EVAL_MIN_DEPTH = 0.5
83 | EVAL_MAX_DEPTH = 10.0
84 | disable_median_scaling = 1
85 | num_frame_test = 5
86 | val_epoch_size = -1
87 |
88 | benchmark_fps = False
89 |
--------------------------------------------------------------------------------
/configs/kitti.conf:
--------------------------------------------------------------------------------
1 | include "./base.conf"
2 | dataset = kitti
3 | data_path = /data/cjd/kitti_raw/
4 | height = 352
5 | width = 1216
6 | eval_height = 352
7 | eval_width = 1216
8 | min_depth = 1.0
9 | max_depth = 60.0
10 | EVAL_MIN_DEPTH = 1.0
11 | EVAL_MAX_DEPTH = 60.0
12 | use_test = 0
13 | num_frame = 3
14 | num_frame_test = 3
15 | perturb_pose = 0
16 | LR = 1e-4
17 | fullsize_eval = 0
18 | DECAY_STEP_LIST = [10, 20]
19 | pred_conf = 1
20 | use_skip = 1
21 | num_epochs = 50
22 | loss_d = L1
23 | batch_size = 3
24 | nlabel = 96
25 | max_conf = 0.06
26 | min_conf = 2e-4
27 | multiprocessing_distributed = 1
28 | monitor_key = [abs_rel, thre1]
29 | monitor_goal = [minimize, maximize]
30 | num_workers = 16
31 | log_dir = /data/cjd/AFnet/
32 | model_name = AF_kitti3
--------------------------------------------------------------------------------
/data_split/kitti_eigen_test.txt:
--------------------------------------------------------------------------------
1 | 2011_09_26 0002 val 69
2 | 2011_09_26 0002 val 54
3 | 2011_09_26 0002 val 42
4 | 2011_09_26 0002 val 57
5 | 2011_09_26 0002 val 30
6 | 2011_09_26 0002 val 27
7 | 2011_09_26 0002 val 12
8 | 2011_09_26 0002 val 36
9 | 2011_09_26 0002 val 33
10 | 2011_09_26 0002 val 15
11 | 2011_09_26 0002 val 39
12 | 2011_09_26 0002 val 9
13 | 2011_09_26 0002 val 51
14 | 2011_09_26 0002 val 60
15 | 2011_09_26 0002 val 21
16 | 2011_09_26 0002 val 24
17 | 2011_09_26 0002 val 45
18 | 2011_09_26 0002 val 18
19 | 2011_09_26 0002 val 48
20 | 2011_09_26 0002 val 6
21 | 2011_09_26 0002 val 63
22 | 2011_09_26 0009 train 16
23 | 2011_09_26 0009 train 32
24 | 2011_09_26 0009 train 48
25 | 2011_09_26 0009 train 64
26 | 2011_09_26 0009 train 80
27 | 2011_09_26 0009 train 96
28 | 2011_09_26 0009 train 112
29 | 2011_09_26 0009 train 128
30 | 2011_09_26 0009 train 144
31 | 2011_09_26 0009 train 160
32 | 2011_09_26 0009 train 176
33 | 2011_09_26 0009 train 196
34 | 2011_09_26 0009 train 212
35 | 2011_09_26 0009 train 228
36 | 2011_09_26 0009 train 244
37 | 2011_09_26 0009 train 260
38 | 2011_09_26 0009 train 276
39 | 2011_09_26 0009 train 292
40 | 2011_09_26 0009 train 308
41 | 2011_09_26 0009 train 324
42 | 2011_09_26 0009 train 340
43 | 2011_09_26 0009 train 356
44 | 2011_09_26 0009 train 372
45 | 2011_09_26 0009 train 388
46 | 2011_09_26 0013 val 90
47 | 2011_09_26 0013 val 50
48 | 2011_09_26 0013 val 110
49 | 2011_09_26 0013 val 115
50 | 2011_09_26 0013 val 60
51 | 2011_09_26 0013 val 105
52 | 2011_09_26 0013 val 125
53 | 2011_09_26 0013 val 20
54 | 2011_09_26 0013 val 85
55 | 2011_09_26 0013 val 70
56 | 2011_09_26 0013 val 80
57 | 2011_09_26 0013 val 65
58 | 2011_09_26 0013 val 95
59 | 2011_09_26 0013 val 130
60 | 2011_09_26 0013 val 100
61 | 2011_09_26 0013 val 10
62 | 2011_09_26 0013 val 30
63 | 2011_09_26 0013 val 135
64 | 2011_09_26 0013 val 40
65 | 2011_09_26 0013 val 5
66 | 2011_09_26 0013 val 120
67 | 2011_09_26 0013 val 45
68 | 2011_09_26 0013 val 35
69 | 2011_09_26 0020 val 69
70 | 2011_09_26 0020 val 57
71 | 2011_09_26 0020 val 12
72 | 2011_09_26 0020 val 72
73 | 2011_09_26 0020 val 18
74 | 2011_09_26 0020 val 63
75 | 2011_09_26 0020 val 15
76 | 2011_09_26 0020 val 66
77 | 2011_09_26 0020 val 6
78 | 2011_09_26 0020 val 48
79 | 2011_09_26 0020 val 60
80 | 2011_09_26 0020 val 9
81 | 2011_09_26 0020 val 33
82 | 2011_09_26 0020 val 21
83 | 2011_09_26 0020 val 75
84 | 2011_09_26 0020 val 27
85 | 2011_09_26 0020 val 45
86 | 2011_09_26 0020 val 78
87 | 2011_09_26 0020 val 36
88 | 2011_09_26 0020 val 51
89 | 2011_09_26 0020 val 54
90 | 2011_09_26 0020 val 42
91 | 2011_09_26 0023 val 18
92 | 2011_09_26 0023 val 90
93 | 2011_09_26 0023 val 126
94 | 2011_09_26 0023 val 378
95 | 2011_09_26 0023 val 36
96 | 2011_09_26 0023 val 288
97 | 2011_09_26 0023 val 198
98 | 2011_09_26 0023 val 450
99 | 2011_09_26 0023 val 144
100 | 2011_09_26 0023 val 72
101 | 2011_09_26 0023 val 252
102 | 2011_09_26 0023 val 180
103 | 2011_09_26 0023 val 432
104 | 2011_09_26 0023 val 396
105 | 2011_09_26 0023 val 54
106 | 2011_09_26 0023 val 468
107 | 2011_09_26 0023 val 306
108 | 2011_09_26 0023 val 108
109 | 2011_09_26 0023 val 162
110 | 2011_09_26 0023 val 342
111 | 2011_09_26 0023 val 270
112 | 2011_09_26 0023 val 414
113 | 2011_09_26 0023 val 216
114 | 2011_09_26 0023 val 360
115 | 2011_09_26 0023 val 324
116 | 2011_09_26 0027 train 77
117 | 2011_09_26 0027 train 35
118 | 2011_09_26 0027 train 91
119 | 2011_09_26 0027 train 112
120 | 2011_09_26 0027 train 7
121 | 2011_09_26 0027 train 175
122 | 2011_09_26 0027 train 42
123 | 2011_09_26 0027 train 98
124 | 2011_09_26 0027 train 133
125 | 2011_09_26 0027 train 161
126 | 2011_09_26 0027 train 14
127 | 2011_09_26 0027 train 126
128 | 2011_09_26 0027 train 168
129 | 2011_09_26 0027 train 70
130 | 2011_09_26 0027 train 84
131 | 2011_09_26 0027 train 140
132 | 2011_09_26 0027 train 49
133 | 2011_09_26 0027 train 182
134 | 2011_09_26 0027 train 147
135 | 2011_09_26 0027 train 56
136 | 2011_09_26 0027 train 63
137 | 2011_09_26 0027 train 21
138 | 2011_09_26 0027 train 119
139 | 2011_09_26 0027 train 28
140 | 2011_09_26 0029 train 380
141 | 2011_09_26 0029 train 394
142 | 2011_09_26 0029 train 324
143 | 2011_09_26 0029 train 268
144 | 2011_09_26 0029 train 366
145 | 2011_09_26 0029 train 296
146 | 2011_09_26 0029 train 14
147 | 2011_09_26 0029 train 28
148 | 2011_09_26 0029 train 182
149 | 2011_09_26 0029 train 168
150 | 2011_09_26 0029 train 196
151 | 2011_09_26 0029 train 140
152 | 2011_09_26 0029 train 84
153 | 2011_09_26 0029 train 56
154 | 2011_09_26 0029 train 112
155 | 2011_09_26 0029 train 352
156 | 2011_09_26 0029 train 126
157 | 2011_09_26 0029 train 70
158 | 2011_09_26 0029 train 310
159 | 2011_09_26 0029 train 154
160 | 2011_09_26 0029 train 98
161 | 2011_09_26 0029 train 408
162 | 2011_09_26 0029 train 42
163 | 2011_09_26 0029 train 338
164 | 2011_09_26 0036 val 128
165 | 2011_09_26 0036 val 192
166 | 2011_09_26 0036 val 32
167 | 2011_09_26 0036 val 352
168 | 2011_09_26 0036 val 608
169 | 2011_09_26 0036 val 224
170 | 2011_09_26 0036 val 576
171 | 2011_09_26 0036 val 672
172 | 2011_09_26 0036 val 64
173 | 2011_09_26 0036 val 448
174 | 2011_09_26 0036 val 704
175 | 2011_09_26 0036 val 640
176 | 2011_09_26 0036 val 512
177 | 2011_09_26 0036 val 768
178 | 2011_09_26 0036 val 160
179 | 2011_09_26 0036 val 416
180 | 2011_09_26 0036 val 480
181 | 2011_09_26 0036 val 288
182 | 2011_09_26 0036 val 544
183 | 2011_09_26 0036 val 96
184 | 2011_09_26 0036 val 384
185 | 2011_09_26 0036 val 256
186 | 2011_09_26 0036 val 320
187 | 2011_09_26 0046 train 5
188 | 2011_09_26 0046 train 10
189 | 2011_09_26 0046 train 15
190 | 2011_09_26 0046 train 20
191 | 2011_09_26 0046 train 25
192 | 2011_09_26 0046 train 30
193 | 2011_09_26 0046 train 35
194 | 2011_09_26 0046 train 40
195 | 2011_09_26 0046 train 45
196 | 2011_09_26 0046 train 50
197 | 2011_09_26 0046 train 55
198 | 2011_09_26 0046 train 60
199 | 2011_09_26 0046 train 65
200 | 2011_09_26 0046 train 70
201 | 2011_09_26 0046 train 75
202 | 2011_09_26 0046 train 80
203 | 2011_09_26 0046 train 85
204 | 2011_09_26 0046 train 90
205 | 2011_09_26 0046 train 95
206 | 2011_09_26 0046 train 100
207 | 2011_09_26 0046 train 105
208 | 2011_09_26 0046 train 110
209 | 2011_09_26 0046 train 115
210 | 2011_09_26 0048 train 5
211 | 2011_09_26 0048 train 6
212 | 2011_09_26 0048 train 7
213 | 2011_09_26 0048 train 8
214 | 2011_09_26 0048 train 9
215 | 2011_09_26 0048 train 10
216 | 2011_09_26 0048 train 11
217 | 2011_09_26 0048 train 12
218 | 2011_09_26 0048 train 13
219 | 2011_09_26 0048 train 14
220 | 2011_09_26 0048 train 15
221 | 2011_09_26 0048 train 16
222 | 2011_09_26 0052 train 46
223 | 2011_09_26 0052 train 14
224 | 2011_09_26 0052 train 36
225 | 2011_09_26 0052 train 28
226 | 2011_09_26 0052 train 26
227 | 2011_09_26 0052 train 50
228 | 2011_09_26 0052 train 40
229 | 2011_09_26 0052 train 8
230 | 2011_09_26 0052 train 16
231 | 2011_09_26 0052 train 44
232 | 2011_09_26 0052 train 18
233 | 2011_09_26 0052 train 32
234 | 2011_09_26 0052 train 42
235 | 2011_09_26 0052 train 10
236 | 2011_09_26 0052 train 20
237 | 2011_09_26 0052 train 48
238 | 2011_09_26 0052 train 52
239 | 2011_09_26 0052 train 6
240 | 2011_09_26 0052 train 30
241 | 2011_09_26 0052 train 12
242 | 2011_09_26 0052 train 38
243 | 2011_09_26 0052 train 22
244 | 2011_09_26 0056 train 11
245 | 2011_09_26 0056 train 33
246 | 2011_09_26 0056 train 242
247 | 2011_09_26 0056 train 253
248 | 2011_09_26 0056 train 286
249 | 2011_09_26 0056 train 154
250 | 2011_09_26 0056 train 99
251 | 2011_09_26 0056 train 220
252 | 2011_09_26 0056 train 22
253 | 2011_09_26 0056 train 77
254 | 2011_09_26 0056 train 187
255 | 2011_09_26 0056 train 143
256 | 2011_09_26 0056 train 66
257 | 2011_09_26 0056 train 176
258 | 2011_09_26 0056 train 110
259 | 2011_09_26 0056 train 275
260 | 2011_09_26 0056 train 264
261 | 2011_09_26 0056 train 198
262 | 2011_09_26 0056 train 55
263 | 2011_09_26 0056 train 88
264 | 2011_09_26 0056 train 121
265 | 2011_09_26 0056 train 209
266 | 2011_09_26 0056 train 165
267 | 2011_09_26 0056 train 231
268 | 2011_09_26 0056 train 44
269 | 2011_09_26 0059 train 56
270 | 2011_09_26 0059 train 344
271 | 2011_09_26 0059 train 358
272 | 2011_09_26 0059 train 316
273 | 2011_09_26 0059 train 238
274 | 2011_09_26 0059 train 98
275 | 2011_09_26 0059 train 112
276 | 2011_09_26 0059 train 28
277 | 2011_09_26 0059 train 14
278 | 2011_09_26 0059 train 330
279 | 2011_09_26 0059 train 154
280 | 2011_09_26 0059 train 42
281 | 2011_09_26 0059 train 302
282 | 2011_09_26 0059 train 182
283 | 2011_09_26 0059 train 288
284 | 2011_09_26 0059 train 140
285 | 2011_09_26 0059 train 274
286 | 2011_09_26 0059 train 224
287 | 2011_09_26 0059 train 196
288 | 2011_09_26 0059 train 126
289 | 2011_09_26 0059 train 84
290 | 2011_09_26 0059 train 210
291 | 2011_09_26 0059 train 70
292 | 2011_09_26 0064 train 528
293 | 2011_09_26 0064 train 308
294 | 2011_09_26 0064 train 44
295 | 2011_09_26 0064 train 352
296 | 2011_09_26 0064 train 66
297 | 2011_09_26 0064 train 506
298 | 2011_09_26 0064 train 176
299 | 2011_09_26 0064 train 22
300 | 2011_09_26 0064 train 242
301 | 2011_09_26 0064 train 462
302 | 2011_09_26 0064 train 418
303 | 2011_09_26 0064 train 110
304 | 2011_09_26 0064 train 440
305 | 2011_09_26 0064 train 396
306 | 2011_09_26 0064 train 154
307 | 2011_09_26 0064 train 374
308 | 2011_09_26 0064 train 88
309 | 2011_09_26 0064 train 286
310 | 2011_09_26 0064 train 550
311 | 2011_09_26 0064 train 264
312 | 2011_09_26 0064 train 220
313 | 2011_09_26 0064 train 330
314 | 2011_09_26 0064 train 484
315 | 2011_09_26 0064 train 198
316 | 2011_09_26 0084 train 283
317 | 2011_09_26 0084 train 361
318 | 2011_09_26 0084 train 270
319 | 2011_09_26 0084 train 127
320 | 2011_09_26 0084 train 205
321 | 2011_09_26 0084 train 218
322 | 2011_09_26 0084 train 153
323 | 2011_09_26 0084 train 335
324 | 2011_09_26 0084 train 192
325 | 2011_09_26 0084 train 348
326 | 2011_09_26 0084 train 101
327 | 2011_09_26 0084 train 49
328 | 2011_09_26 0084 train 179
329 | 2011_09_26 0084 train 140
330 | 2011_09_26 0084 train 374
331 | 2011_09_26 0084 train 322
332 | 2011_09_26 0084 train 309
333 | 2011_09_26 0084 train 244
334 | 2011_09_26 0084 train 62
335 | 2011_09_26 0084 train 257
336 | 2011_09_26 0084 train 88
337 | 2011_09_26 0084 train 114
338 | 2011_09_26 0084 train 75
339 | 2011_09_26 0084 train 296
340 | 2011_09_26 0084 train 231
341 | 2011_09_26 0086 train 7
342 | 2011_09_26 0086 train 196
343 | 2011_09_26 0086 train 439
344 | 2011_09_26 0086 train 169
345 | 2011_09_26 0086 train 115
346 | 2011_09_26 0086 train 34
347 | 2011_09_26 0086 train 304
348 | 2011_09_26 0086 train 331
349 | 2011_09_26 0086 train 277
350 | 2011_09_26 0086 train 520
351 | 2011_09_26 0086 train 682
352 | 2011_09_26 0086 train 628
353 | 2011_09_26 0086 train 88
354 | 2011_09_26 0086 train 601
355 | 2011_09_26 0086 train 574
356 | 2011_09_26 0086 train 223
357 | 2011_09_26 0086 train 655
358 | 2011_09_26 0086 train 358
359 | 2011_09_26 0086 train 412
360 | 2011_09_26 0086 train 142
361 | 2011_09_26 0086 train 385
362 | 2011_09_26 0086 train 61
363 | 2011_09_26 0086 train 493
364 | 2011_09_26 0086 train 466
365 | 2011_09_26 0086 train 250
366 | 2011_09_26 0093 train 16
367 | 2011_09_26 0093 train 32
368 | 2011_09_26 0093 train 48
369 | 2011_09_26 0093 train 64
370 | 2011_09_26 0093 train 80
371 | 2011_09_26 0093 train 96
372 | 2011_09_26 0093 train 112
373 | 2011_09_26 0093 train 128
374 | 2011_09_26 0093 train 144
375 | 2011_09_26 0093 train 160
376 | 2011_09_26 0093 train 176
377 | 2011_09_26 0093 train 192
378 | 2011_09_26 0093 train 208
379 | 2011_09_26 0093 train 224
380 | 2011_09_26 0093 train 240
381 | 2011_09_26 0093 train 256
382 | 2011_09_26 0093 train 305
383 | 2011_09_26 0093 train 321
384 | 2011_09_26 0093 train 337
385 | 2011_09_26 0093 train 353
386 | 2011_09_26 0093 train 369
387 | 2011_09_26 0093 train 385
388 | 2011_09_26 0093 train 401
389 | 2011_09_26 0093 train 417
390 | 2011_09_26 0096 train 19
391 | 2011_09_26 0096 train 38
392 | 2011_09_26 0096 train 57
393 | 2011_09_26 0096 train 76
394 | 2011_09_26 0096 train 95
395 | 2011_09_26 0096 train 114
396 | 2011_09_26 0096 train 133
397 | 2011_09_26 0096 train 152
398 | 2011_09_26 0096 train 171
399 | 2011_09_26 0096 train 190
400 | 2011_09_26 0096 train 209
401 | 2011_09_26 0096 train 228
402 | 2011_09_26 0096 train 247
403 | 2011_09_26 0096 train 266
404 | 2011_09_26 0096 train 285
405 | 2011_09_26 0096 train 304
406 | 2011_09_26 0096 train 323
407 | 2011_09_26 0096 train 342
408 | 2011_09_26 0096 train 361
409 | 2011_09_26 0096 train 380
410 | 2011_09_26 0096 train 399
411 | 2011_09_26 0096 train 418
412 | 2011_09_26 0096 train 437
413 | 2011_09_26 0096 train 456
414 | 2011_09_26 0101 train 692
415 | 2011_09_26 0101 train 930
416 | 2011_09_26 0101 train 760
417 | 2011_09_26 0101 train 896
418 | 2011_09_26 0101 train 284
419 | 2011_09_26 0101 train 148
420 | 2011_09_26 0101 train 522
421 | 2011_09_26 0101 train 794
422 | 2011_09_26 0101 train 624
423 | 2011_09_26 0101 train 726
424 | 2011_09_26 0101 train 216
425 | 2011_09_26 0101 train 318
426 | 2011_09_26 0101 train 488
427 | 2011_09_26 0101 train 590
428 | 2011_09_26 0101 train 454
429 | 2011_09_26 0101 train 862
430 | 2011_09_26 0101 train 386
431 | 2011_09_26 0101 train 352
432 | 2011_09_26 0101 train 420
433 | 2011_09_26 0101 train 658
434 | 2011_09_26 0101 train 828
435 | 2011_09_26 0101 train 556
436 | 2011_09_26 0101 train 114
437 | 2011_09_26 0101 train 182
438 | 2011_09_26 0101 train 80
439 | 2011_09_26 0106 train 15
440 | 2011_09_26 0106 train 35
441 | 2011_09_26 0106 train 43
442 | 2011_09_26 0106 train 51
443 | 2011_09_26 0106 train 59
444 | 2011_09_26 0106 train 67
445 | 2011_09_26 0106 train 75
446 | 2011_09_26 0106 train 83
447 | 2011_09_26 0106 train 91
448 | 2011_09_26 0106 train 99
449 | 2011_09_26 0106 train 107
450 | 2011_09_26 0106 train 115
451 | 2011_09_26 0106 train 123
452 | 2011_09_26 0106 train 131
453 | 2011_09_26 0106 train 139
454 | 2011_09_26 0106 train 147
455 | 2011_09_26 0106 train 155
456 | 2011_09_26 0106 train 163
457 | 2011_09_26 0106 train 171
458 | 2011_09_26 0106 train 179
459 | 2011_09_26 0106 train 187
460 | 2011_09_26 0106 train 195
461 | 2011_09_26 0106 train 203
462 | 2011_09_26 0106 train 211
463 | 2011_09_26 0106 train 219
464 | 2011_09_26 0117 train 312
465 | 2011_09_26 0117 train 494
466 | 2011_09_26 0117 train 104
467 | 2011_09_26 0117 train 130
468 | 2011_09_26 0117 train 156
469 | 2011_09_26 0117 train 182
470 | 2011_09_26 0117 train 598
471 | 2011_09_26 0117 train 416
472 | 2011_09_26 0117 train 364
473 | 2011_09_26 0117 train 26
474 | 2011_09_26 0117 train 78
475 | 2011_09_26 0117 train 572
476 | 2011_09_26 0117 train 468
477 | 2011_09_26 0117 train 260
478 | 2011_09_26 0117 train 624
479 | 2011_09_26 0117 train 234
480 | 2011_09_26 0117 train 442
481 | 2011_09_26 0117 train 390
482 | 2011_09_26 0117 train 546
483 | 2011_09_26 0117 train 286
484 | 2011_09_26 0117 train 338
485 | 2011_09_26 0117 train 208
486 | 2011_09_26 0117 train 650
487 | 2011_09_26 0117 train 52
488 | 2011_09_28 0002 train 24
489 | 2011_09_28 0002 train 21
490 | 2011_09_28 0002 train 36
491 | 2011_09_28 0002 train 51
492 | 2011_09_28 0002 train 18
493 | 2011_09_28 0002 train 33
494 | 2011_09_28 0002 train 90
495 | 2011_09_28 0002 train 45
496 | 2011_09_28 0002 train 54
497 | 2011_09_28 0002 train 12
498 | 2011_09_28 0002 train 39
499 | 2011_09_28 0002 train 9
500 | 2011_09_28 0002 train 30
501 | 2011_09_28 0002 train 78
502 | 2011_09_28 0002 train 60
503 | 2011_09_28 0002 train 48
504 | 2011_09_28 0002 train 84
505 | 2011_09_28 0002 train 81
506 | 2011_09_28 0002 train 6
507 | 2011_09_28 0002 train 57
508 | 2011_09_28 0002 train 72
509 | 2011_09_28 0002 train 87
510 | 2011_09_28 0002 train 63
511 | 2011_09_29 0071 train 252
512 | 2011_09_29 0071 train 540
513 | 2011_09_29 0071 train 36
514 | 2011_09_29 0071 train 360
515 | 2011_09_29 0071 train 807
516 | 2011_09_29 0071 train 879
517 | 2011_09_29 0071 train 288
518 | 2011_09_29 0071 train 771
519 | 2011_09_29 0071 train 216
520 | 2011_09_29 0071 train 951
521 | 2011_09_29 0071 train 324
522 | 2011_09_29 0071 train 432
523 | 2011_09_29 0071 train 504
524 | 2011_09_29 0071 train 576
525 | 2011_09_29 0071 train 108
526 | 2011_09_29 0071 train 180
527 | 2011_09_29 0071 train 72
528 | 2011_09_29 0071 train 612
529 | 2011_09_29 0071 train 915
530 | 2011_09_29 0071 train 735
531 | 2011_09_29 0071 train 144
532 | 2011_09_29 0071 train 396
533 | 2011_09_29 0071 train 468
534 | 2011_09_30 0016 val 132
535 | 2011_09_30 0016 val 11
536 | 2011_09_30 0016 val 154
537 | 2011_09_30 0016 val 22
538 | 2011_09_30 0016 val 242
539 | 2011_09_30 0016 val 198
540 | 2011_09_30 0016 val 176
541 | 2011_09_30 0016 val 231
542 | 2011_09_30 0016 val 220
543 | 2011_09_30 0016 val 88
544 | 2011_09_30 0016 val 143
545 | 2011_09_30 0016 val 55
546 | 2011_09_30 0016 val 33
547 | 2011_09_30 0016 val 187
548 | 2011_09_30 0016 val 110
549 | 2011_09_30 0016 val 44
550 | 2011_09_30 0016 val 77
551 | 2011_09_30 0016 val 66
552 | 2011_09_30 0016 val 165
553 | 2011_09_30 0016 val 264
554 | 2011_09_30 0016 val 253
555 | 2011_09_30 0016 val 209
556 | 2011_09_30 0016 val 121
557 | 2011_09_30 0018 train 107
558 | 2011_09_30 0018 train 2247
559 | 2011_09_30 0018 train 1391
560 | 2011_09_30 0018 train 535
561 | 2011_09_30 0018 train 1819
562 | 2011_09_30 0018 train 1177
563 | 2011_09_30 0018 train 428
564 | 2011_09_30 0018 train 1926
565 | 2011_09_30 0018 train 749
566 | 2011_09_30 0018 train 1284
567 | 2011_09_30 0018 train 2140
568 | 2011_09_30 0018 train 1605
569 | 2011_09_30 0018 train 1498
570 | 2011_09_30 0018 train 642
571 | 2011_09_30 0018 train 2740
572 | 2011_09_30 0018 train 2419
573 | 2011_09_30 0018 train 856
574 | 2011_09_30 0018 train 2526
575 | 2011_09_30 0018 train 1712
576 | 2011_09_30 0018 train 1070
577 | 2011_09_30 0018 train 2033
578 | 2011_09_30 0018 train 214
579 | 2011_09_30 0018 train 963
580 | 2011_09_30 0018 train 2633
581 | 2011_09_30 0027 train 533
582 | 2011_09_30 0027 train 1040
583 | 2011_09_30 0027 train 82
584 | 2011_09_30 0027 train 205
585 | 2011_09_30 0027 train 835
586 | 2011_09_30 0027 train 451
587 | 2011_09_30 0027 train 164
588 | 2011_09_30 0027 train 794
589 | 2011_09_30 0027 train 328
590 | 2011_09_30 0027 train 615
591 | 2011_09_30 0027 train 917
592 | 2011_09_30 0027 train 369
593 | 2011_09_30 0027 train 287
594 | 2011_09_30 0027 train 123
595 | 2011_09_30 0027 train 876
596 | 2011_09_30 0027 train 410
597 | 2011_09_30 0027 train 492
598 | 2011_09_30 0027 train 958
599 | 2011_09_30 0027 train 656
600 | 2011_09_30 0027 train 753
601 | 2011_09_30 0027 train 574
602 | 2011_09_30 0027 train 1081
603 | 2011_09_30 0027 train 41
604 | 2011_09_30 0027 train 246
605 | 2011_10_03 0027 train 2906
606 | 2011_10_03 0027 train 2544
607 | 2011_10_03 0027 train 362
608 | 2011_10_03 0027 train 4535
609 | 2011_10_03 0027 train 734
610 | 2011_10_03 0027 train 1096
611 | 2011_10_03 0027 train 4173
612 | 2011_10_03 0027 train 543
613 | 2011_10_03 0027 train 1277
614 | 2011_10_03 0027 train 4354
615 | 2011_10_03 0027 train 1458
616 | 2011_10_03 0027 train 1820
617 | 2011_10_03 0027 train 3449
618 | 2011_10_03 0027 train 3268
619 | 2011_10_03 0027 train 915
620 | 2011_10_03 0027 train 2363
621 | 2011_10_03 0027 train 2725
622 | 2011_10_03 0027 train 181
623 | 2011_10_03 0027 train 1639
624 | 2011_10_03 0027 train 3992
625 | 2011_10_03 0027 train 3087
626 | 2011_10_03 0027 train 2001
627 | 2011_10_03 0027 train 3811
628 | 2011_10_03 0027 train 3630
629 | 2011_10_03 0047 val 96
630 | 2011_10_03 0047 val 800
631 | 2011_10_03 0047 val 320
632 | 2011_10_03 0047 val 576
633 | 2011_10_03 0047 val 480
634 | 2011_10_03 0047 val 640
635 | 2011_10_03 0047 val 32
636 | 2011_10_03 0047 val 384
637 | 2011_10_03 0047 val 160
638 | 2011_10_03 0047 val 704
639 | 2011_10_03 0047 val 736
640 | 2011_10_03 0047 val 672
641 | 2011_10_03 0047 val 64
642 | 2011_10_03 0047 val 288
643 | 2011_10_03 0047 val 352
644 | 2011_10_03 0047 val 512
645 | 2011_10_03 0047 val 544
646 | 2011_10_03 0047 val 608
647 | 2011_10_03 0047 val 128
648 | 2011_10_03 0047 val 224
649 | 2011_10_03 0047 val 416
650 | 2011_10_03 0047 val 192
651 | 2011_10_03 0047 val 448
652 | 2011_10_03 0047 val 768
--------------------------------------------------------------------------------
/datasets/DDAD.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 | import os
3 | import random
4 | import numpy as np
5 | import copy
6 | from PIL import Image # using pillow-simd for increased speed
7 | from time import time
8 | import torch
9 | import torch.utils.data as data
10 | from torchvision import transforms
11 | import cv2
12 |
13 | cv2.setNumThreads(0)
14 | import glob
15 | import utils
16 | import torch.nn.functional as F
17 | from utils import npy
18 | import json
19 |
20 |
21 |
22 | class DDAD(data.Dataset):
23 | def __init__(self, opt, is_train):
24 | super(DDAD, self).__init__()
25 | self.opt = opt
26 | self.json_path = "data_split/DDAD_video.json"
27 | self.data_path_root = self.opt.data_path
28 | self.is_train = is_train
29 | if self.is_train:
30 | self.data_path = os.path.join(self.data_path_root, 'train/')
31 | else:
32 | self.data_path = os.path.join(self.data_path_root, 'val/')
33 |
34 | f = open(self.json_path, 'r')
35 | content_all = f.read()
36 | json_list_all = json.loads(content_all)
37 | f.close()
38 |
39 | if self.is_train:
40 | self.file_names = json_list_all["train"]
41 | print('train', len(self.file_names))
42 | else:
43 | self.file_names = json_list_all["val"]
44 | print('val', len(self.file_names))
45 |
46 | print('filter_pre', len(self.file_names))
47 | self.file_names = [x for x in self.file_names if 'timestamp' in x.keys() and 'timestamp_back' in x.keys() and 'timestamp_forward' in x.keys()]
48 | print('filter_after', len(self.file_names))
49 |
50 |
51 |
52 | def get_k_ori_randomcrop(self, k_raw, inputs, x1, y1):
53 |
54 | fx_ori = k_raw[0,0]
55 | fy_ori = k_raw[1,1]
56 | fx_virtual = 1060.0
57 | fx_scale = fx_ori / fx_virtual
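# Canonicalize to a virtual focal length of 1060 px: the ground-truth depth and the
# pose translations below are divided by fx_scale so that all samples share one scale.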
58 | inputs[("depth_gt", 0, 0)] = inputs[("depth_gt", 0, 0)] / fx_scale
59 |
60 | pose_cur = inputs[("pose", 0)]
61 | pose_cur[:3, 3] = pose_cur[:3, 3] / fx_scale
62 | inputs[("pose", 0)] = pose_cur
63 | inputs[("pose_inv", 0)] = np.linalg.inv(inputs[("pose", 0)])
64 |
65 | pose_pre = inputs[("pose", 1)]
66 | pose_pre[:3, 3] = pose_pre[:3, 3] / fx_scale
67 | inputs[("pose", 1)] = pose_pre
68 | inputs[("pose_inv", 1)] = np.linalg.inv(inputs[("pose", 1)])
69 |
70 | pose_next = inputs[("pose", 2)]
71 | pose_next[:3, 3] = pose_next[:3, 3] / fx_scale
72 | inputs[("pose", 2)] = pose_next
73 | inputs[("pose_inv", 2)] = np.linalg.inv(inputs[("pose", 2)])
74 |
75 | K = np.zeros((3,3), dtype = float)
76 | K[0,0] = fx_ori
77 | K[1,1] = fy_ori
78 | K[2,2] = 1.0
79 | K[0,2] = k_raw[0,2]
80 | K[1,2] = k_raw[1,2]
81 |
82 | h_crop = y1 - self.opt.height
83 | w_crop = x1
84 |
85 | K[0,2] = K[0,2] - w_crop
86 | K[1,2] = K[1,2] - h_crop
87 |
88 | return K, inputs
89 |
90 | def get_k_ori_centercrop(self, k_raw, inputs):
91 |
92 | fx_ori = k_raw[0,0]
93 | fy_ori = k_raw[1,1]
94 | fx_virtual = 1060.0
95 | fx_scale = fx_ori / fx_virtual
96 | inputs[("depth_gt", 0, 0)] = inputs[("depth_gt", 0, 0)] / fx_scale
97 |
98 | pose_cur = inputs[("pose", 0)]
99 | pose_cur[:3, 3] = pose_cur[:3, 3] / fx_scale
100 | inputs[("pose", 0)] = pose_cur
101 | inputs[("pose_inv", 0)] = np.linalg.inv(inputs[("pose", 0)])
102 |
103 | pose_pre = inputs[("pose", 1)]
104 | pose_pre[:3, 3] = pose_pre[:3, 3] / fx_scale
105 | inputs[("pose", 1)] = pose_pre
106 | inputs[("pose_inv", 1)] = np.linalg.inv(inputs[("pose", 1)])
107 |
108 | pose_next = inputs[("pose", 2)]
109 | pose_next[:3, 3] = pose_next[:3, 3] / fx_scale
110 | inputs[("pose", 2)] = pose_next
111 | inputs[("pose_inv", 2)] = np.linalg.inv(inputs[("pose", 2)])
112 |
113 | K = np.zeros((3,3), dtype = float)
114 | K[0,0] = fx_ori
115 | K[1,1] = fy_ori
116 | K[2,2] = 1.0
117 | K[0,2] = k_raw[0,2]
118 | K[1,2] = k_raw[1,2]
119 |
120 | h_crop = 0.0
121 | w_crop = 8.0
122 |
123 | K[0,2] = K[0,2] - w_crop
124 | K[1,2] = K[1,2] - h_crop
125 |
126 | return K, inputs
127 |
128 |
129 | def __len__(self):
130 | return len(self.file_names)
131 |
132 |
133 | def __getitem__(self, index):
134 | inputs = {}
135 | cur_npz_path = self.data_path + str(self.file_names[index]['timestamp']) + '_' + self.file_names[index]['Camera'] + '.npz'
136 | pre_npz_path = self.data_path + str(self.file_names[index]['timestamp_back']) + '_' + self.file_names[index]['Camera'] + '.npz'
137 | next_npz_path = self.data_path + str(self.file_names[index]['timestamp_forward']) + '_' + self.file_names[index]['Camera'] + '.npz'
138 |
139 | file_cur = np.load(cur_npz_path)
140 | file_pre = np.load(pre_npz_path)
141 | file_next = np.load(next_npz_path)
142 |
143 | depth_cur_gt = file_cur['depth']
144 | depth_cur_gt = np.array(depth_cur_gt).astype(np.float32)
145 |
146 |
147 | inputs[("depth_gt", 0, 0)] = depth_cur_gt
148 |
149 |
150 | if self.is_train:
151 | if random.randint(0, 10) < 8:
152 | y_center = int((1216 + self.opt.height)/2)
153 | y1 = random.randint(int(y_center - 70), int(y_center + 50))
154 | else:
155 | y1 = random.randint(self.opt.height, 1216)
156 | x1 = random.randint(0, 1930 - self.opt.width)
157 | else:
158 | y1 = int((1216 + self.opt.height)/2)
159 | x1 = int((1936 - self.opt.width)/2)
160 |
161 | inputs[("depth_gt", 0, 0)] = inputs[("depth_gt", 0, 0)][y1 - int(self.opt.height):y1, x1:x1+int(self.opt.width)][None, :, :]
162 |
163 | rgb_cur = file_cur['rgb']
164 | rgb_cur = rgb_cur[y1 - int(self.opt.height):y1, x1:x1+int(self.opt.width)]
165 | rgb_cur = cv2.cvtColor(rgb_cur, cv2.COLOR_BGR2RGB)
166 | rgb_cur = torch.from_numpy(rgb_cur).permute(2, 0, 1) / 255.
167 | inputs[("color", 0, 0)] = rgb_cur
168 |
169 |
170 |
171 | rgb_pre = file_pre['rgb']
172 | rgb_pre = rgb_pre[y1 - int(self.opt.height):y1, x1:x1+int(self.opt.width)]
173 | rgb_pre = cv2.cvtColor(rgb_pre, cv2.COLOR_BGR2RGB)
174 | rgb_pre = torch.from_numpy(rgb_pre).permute(2, 0, 1) / 255.
175 | inputs[("color", 1, 0)] = rgb_pre
176 |
177 |
178 | rgb_next = file_next['rgb']
179 | rgb_next = rgb_next[y1 - int(self.opt.height):y1, x1:x1+int(self.opt.width)]
180 | rgb_next = cv2.cvtColor(rgb_next, cv2.COLOR_BGR2RGB)
181 | rgb_next = torch.from_numpy(rgb_next).permute(2, 0, 1) / 255.
182 | inputs[("color", 2, 0)] = rgb_next
183 |
184 |
185 | pose_cur = file_cur['pose']
186 | pose_cur = np.linalg.inv(pose_cur).astype('float32')
187 | inputs[("pose", 0)] = pose_cur
188 |
189 | pose_pre = file_pre['pose']
190 | pose_pre = np.linalg.inv(pose_pre).astype('float32')
191 | inputs[("pose", 1)] = pose_pre
192 |
193 | pose_next = file_next['pose']
194 | pose_next = np.linalg.inv(pose_next).astype('float32')
195 | inputs[("pose", 2)] = pose_next
196 |
197 | k_raw = file_cur['intrinsics']
198 | k_crop, inputs = self.get_k_ori_randomcrop(k_raw, inputs, x1, y1)
199 | inputs = self.get_K(k_crop, inputs)
200 |
201 | inputs = self.compute_projection_matrix(inputs)
202 |
203 |
204 | inputs['num_frame'] = 3
205 |
206 | return inputs
207 |
208 |
209 |
210 | def get_K(self, K, inputs):
211 | inv_K = np.linalg.inv(K)
212 | K_pool = {}
213 | ho, wo = self.opt.height, self.opt.width
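# Build intrinsics for 6 pyramid levels: level i corresponds to a 2**i downsampling,
# so the first two rows of K are divided by 2**i.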
214 | for i in range(6):
215 | K_pool[(ho // 2**i, wo // 2**i)] = K.copy().astype('float32')
216 | K_pool[(ho // 2**i, wo // 2**i)][:2, :] /= 2**i
217 |
218 | inputs['K_pool'] = K_pool
219 |
220 | inputs[("inv_K_pool", 0)] = {}
221 | for k, v in K_pool.items():
222 | K44 = np.eye(4)
223 | K44[:3, :3] = v
224 | inputs[("inv_K_pool", 0)][k] = np.linalg.inv(K44).astype('float32')
225 |
226 | inputs[("inv_K", 0)] = torch.from_numpy(inv_K.astype('float32'))
227 |
228 | inputs[("K", 0)] = torch.from_numpy(K.astype('float32'))
229 |
230 | return inputs
231 |
232 | def get_K_test(self, K, inputs):
233 | inv_K = np.linalg.inv(K)
234 | K_pool = {}
235 | ho, wo = self.opt.eval_height, self.opt.eval_width
236 | for i in range(6):
237 | K_pool[(ho // 2**i, wo // 2**i)] = K.copy().astype('float32')
238 | K_pool[(ho // 2**i, wo // 2**i)][:2, :] /= 2**i
239 |
240 | inputs['K_pool'] = K_pool
241 |
242 | inputs[("inv_K_pool", 0)] = {}
243 | for k, v in K_pool.items():
244 | K44 = np.eye(4)
245 | K44[:3, :3] = v
246 | inputs[("inv_K_pool", 0)][k] = np.linalg.inv(K44).astype('float32')
247 |
248 | inputs[("inv_K", 0)] = torch.from_numpy(inv_K.astype('float32'))
249 |
250 | inputs[("K", 0)] = torch.from_numpy(K.astype('float32'))
251 |
252 | return inputs
253 |
254 | def compute_projection_matrix(self, inputs):
255 | for i in range(self.opt.num_frame):
256 | inputs[("proj", i)] = {}
257 | for k, v in inputs['K_pool'].items():
258 | K44 = np.eye(4)
259 | K44[:3, :3] = v
260 | inputs[("proj",
261 | i)][k] = np.matmul(K44, inputs[("pose",
262 | i)]).astype('float32')
263 | return inputs
264 |
--------------------------------------------------------------------------------
/datasets/DDAD_crop.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 | import os
3 | import random
4 | import numpy as np
5 | import copy
6 | from PIL import Image # using pillow-simd for increased speed
7 | from time import time
8 | import torch
9 | import torch.utils.data as data
10 | from torchvision import transforms
11 | import cv2
12 |
13 | cv2.setNumThreads(0)
14 | import glob
15 | import utils
16 | import torch.nn.functional as F
17 | from utils import npy
18 | import json
19 |
20 |
21 |
22 | class DDAD(data.Dataset):
23 | def __init__(self, opt, is_train):
24 | super(DDAD, self).__init__()
25 | self.opt = opt
26 | self.json_path = "/home/cjd/tmp/DDAD_video.json"
27 | self.data_path_root = '/data/cjd/ddad/my_ddad/'
28 | self.is_train = is_train
29 | if self.is_train:
30 | self.data_path = os.path.join(self.data_path_root, 'train/')
31 | else:
32 | self.data_path = os.path.join(self.data_path_root, 'val/')
33 |
34 | f = open(self.json_path, 'r')
35 | content_all = f.read()
36 | json_list_all = json.loads(content_all)
37 | f.close()
38 |
39 | if self.is_train:
40 | self.file_names = json_list_all["train"]
41 | # self.file_names = self.file_names[:300]
42 | print('train', len(self.file_names))
43 | else:
44 | self.file_names = json_list_all["val"]
45 | # self.file_names = self.file_names[:300]
46 |
47 | print('val', len(self.file_names))
48 |
49 | print('filter_pre', len(self.file_names))
50 | self.file_names = [x for x in self.file_names if 'timestamp' in x.keys() and 'timestamp_back' in x.keys() and 'timestamp_forward' in x.keys() and x['Camera'] == 'CAMERA_01']
51 | print('filter_after', len(self.file_names))
52 | self.file_names = self.file_names[:50]
53 |
54 |
55 |
56 | def get_k_ori_randomcrop(self, k_raw, inputs, x1, y1):
57 |
58 | fx_ori = k_raw[0,0]
59 | fy_ori = k_raw[1,1]
60 | fx_virtual = 1060.0
61 | fx_scale = fx_ori / fx_virtual
62 | # print('fx_scale', fx_scale)
63 | inputs[("depth_gt", 0, 0)] = inputs[("depth_gt", 0, 0)] / fx_scale
64 |
65 | pose_cur = inputs[("pose", 0)]
66 | pose_cur[:3, 3] = pose_cur[:3, 3] / fx_scale
67 | inputs[("pose", 0)] = pose_cur
68 | inputs[("pose_inv", 0)] = np.linalg.inv(inputs[("pose", 0)])
69 |
70 | pose_pre = inputs[("pose", 1)]
71 | pose_pre[:3, 3] = pose_pre[:3, 3] / fx_scale
72 | inputs[("pose", 1)] = pose_pre
73 | inputs[("pose_inv", 1)] = np.linalg.inv(inputs[("pose", 1)])
74 |
75 | pose_next = inputs[("pose", 2)]
76 | pose_next[:3, 3] = pose_next[:3, 3] / fx_scale
77 | inputs[("pose", 2)] = pose_next
78 | inputs[("pose_inv", 2)] = np.linalg.inv(inputs[("pose", 2)])
79 |
80 | # inputs['focal_scale'] = float(fx_scale)
81 |
82 | K = np.zeros((3,3), dtype = float)
83 | K[0,0] = fx_ori
84 | K[1,1] = fy_ori
85 | K[2,2] = 1.0
86 | K[0,2] = k_raw[0,2]
87 | K[1,2] = k_raw[1,2]
88 |
89 | h_crop = y1 - self.opt.height
90 | w_crop = x1
91 |
92 | K[0,2] = K[0,2] - w_crop
93 | K[1,2] = K[1,2] - h_crop
94 |
95 | return K, inputs
96 |
97 | def get_k_ori_centercrop(self, k_raw, inputs, x_center, y_center):
98 |
99 | fx_ori = k_raw[0,0]
100 | fy_ori = k_raw[1,1]
101 | fx_virtual = 1060.0
102 | fx_scale = fx_ori / fx_virtual
103 | # print('fx_scale', fx_scale)
104 | inputs[("depth_gt", 0, 0)] = inputs[("depth_gt", 0, 0)] / fx_scale
105 |
106 | pose_cur = inputs[("pose", 0)]
107 | pose_cur[:3, 3] = pose_cur[:3, 3] / fx_scale
108 | inputs[("pose", 0)] = pose_cur
109 | inputs[("pose_inv", 0)] = np.linalg.inv(inputs[("pose", 0)])
110 |
111 | pose_pre = inputs[("pose", 1)]
112 | pose_pre[:3, 3] = pose_pre[:3, 3] / fx_scale
113 | inputs[("pose", 1)] = pose_pre
114 | inputs[("pose_inv", 1)] = np.linalg.inv(inputs[("pose", 1)])
115 |
116 | pose_next = inputs[("pose", 2)]
117 | pose_next[:3, 3] = pose_next[:3, 3] / fx_scale
118 | inputs[("pose", 2)] = pose_next
119 | inputs[("pose_inv", 2)] = np.linalg.inv(inputs[("pose", 2)])
120 |
121 | # inputs['focal_scale'] = float(fx_scale)
122 |
123 | K = np.zeros((3,3), dtype = float)
124 | K[0,0] = fx_ori
125 | K[1,1] = fy_ori
126 | K[2,2] = 1.0
127 | K[0,2] = k_raw[0,2]
128 | K[1,2] = k_raw[1,2]
129 |
130 | h_crop = int(y_center - 0.5*self.opt.height)
131 | w_crop = int(x_center - 0.5*self.opt.width)
132 |
133 | K[0,2] = K[0,2] - w_crop
134 | K[1,2] = K[1,2] - h_crop
135 |
136 | return K, inputs
137 |
138 |
139 | def __len__(self):
140 | return len(self.file_names)
141 |
142 |
143 | def __getitem__(self, index):
144 | inputs = {}
145 | cur_npz_path = self.data_path + str(self.file_names[index]['timestamp']) + '_' + self.file_names[index]['Camera'] + '.npz'
146 | pre_npz_path = self.data_path + str(self.file_names[index]['timestamp_back']) + '_' + self.file_names[index]['Camera'] + '.npz'
147 | next_npz_path = self.data_path + str(self.file_names[index]['timestamp_forward']) + '_' + self.file_names[index]['Camera'] + '.npz'
148 |
149 | cur_mask_path = cur_npz_path.replace('.npz', '_dynamic.npz')
150 | inputs['dynamic_mask'] = cur_mask_path
151 | file_cur = np.load(cur_npz_path)
152 | file_pre = np.load(pre_npz_path)
153 | file_next = np.load(next_npz_path)
154 |
155 | depth_cur_gt = file_cur['depth']
156 | depth_cur_gt = np.array(depth_cur_gt).astype(np.float32)
157 |
158 |
159 | inputs[("depth_gt", 0, 0)] = depth_cur_gt
160 |
161 |
162 | if self.is_train:
163 | #h_ori=1216, w_ori=1936
164 | if random.randint(0, 10) < 8:
165 | y_center = int((1216 + self.opt.height)/2)
166 | y1 = random.randint(int(y_center - 70), int(y_center + 50))
167 | else:
168 | y1 = random.randint(self.opt.height, 1216)
169 | x1 = random.randint(0, 1930 - self.opt.width)
170 | # else:
171 | # y1 = int((1216 + self.opt.height)/2)
172 | # x1 = int((1936 - self.opt.width)/2)
173 |
174 | inputs[("depth_gt", 0, 0)] = inputs[("depth_gt", 0, 0)][y1 - int(self.opt.height):y1, x1:x1+int(self.opt.width)][None, :, :]
175 |
176 | rgb_cur = file_cur['rgb']
177 | rgb_cur = rgb_cur[y1 - int(self.opt.height):y1, x1:x1+int(self.opt.width)]
178 | rgb_cur = cv2.cvtColor(rgb_cur, cv2.COLOR_BGR2RGB)
179 | rgb_cur = torch.from_numpy(rgb_cur).permute(2, 0, 1) / 255.
180 | inputs[("color", 0, 0)] = rgb_cur
181 |
182 |
183 |
184 | rgb_pre = file_pre['rgb']
185 | rgb_pre = rgb_pre[y1 - int(self.opt.height):y1, x1:x1+int(self.opt.width)]
186 | rgb_pre = cv2.cvtColor(rgb_pre, cv2.COLOR_BGR2RGB)
187 | rgb_pre = torch.from_numpy(rgb_pre).permute(2, 0, 1) / 255.
188 | inputs[("color", 1, 0)] = rgb_pre
189 |
190 |
191 | rgb_next = file_next['rgb']
192 | rgb_next = rgb_next[y1 - int(self.opt.height):y1, x1:x1+int(self.opt.width)]
193 | rgb_next = cv2.cvtColor(rgb_next, cv2.COLOR_BGR2RGB)
194 | rgb_next = torch.from_numpy(rgb_next).permute(2, 0, 1) / 255.
195 | inputs[("color", 2, 0)] = rgb_next
196 |
197 |
198 | pose_cur = file_cur['pose']
199 | pose_cur = np.linalg.inv(pose_cur).astype('float32')
200 | inputs[("pose", 0)] = pose_cur
201 |
202 | pose_pre = file_pre['pose']
203 | pose_pre = np.linalg.inv(pose_pre).astype('float32')
204 | inputs[("pose", 1)] = pose_pre
205 |
206 | pose_next = file_next['pose']
207 | pose_next = np.linalg.inv(pose_next).astype('float32')
208 | inputs[("pose", 2)] = pose_next
209 |
210 | k_raw = file_cur['intrinsics']
211 | # print('k_raw',k_raw)
212 | k_crop, inputs = self.get_k_ori_randomcrop(k_raw, inputs, x1, y1)
213 | inputs = self.get_K(k_crop, inputs)
214 |
215 |
216 | else:
217 | x_center = int((1936 + self.opt.width)/2)
218 | y_center = int((1216 + self.opt.height)/2)
219 |
220 | inputs[("depth_gt", 0, 0)] = inputs[("depth_gt", 0, 0)][int(y_center - 0.5*self.opt.height):int(y_center + 0.5*self.opt.height), int(x_center - 0.5*self.opt.width):int(x_center + 0.5*self.opt.width)][None, :, :]
221 | print(inputs[("depth_gt", 0, 0)].shape)
222 | rgb_cur = file_cur['rgb']
223 | rgb_cur = rgb_cur[int(y_center - 0.5*self.opt.height):int(y_center + 0.5*self.opt.height), int(x_center - 0.5*self.opt.width):int(x_center + 0.5*self.opt.width)]
224 | rgb_cur = cv2.cvtColor(rgb_cur, cv2.COLOR_BGR2RGB)
225 | rgb_cur = torch.from_numpy(rgb_cur).permute(2, 0, 1) / 255.
226 | inputs[("color", 0, 0)] = rgb_cur
227 |
228 |
229 |
230 | rgb_pre = file_pre['rgb']
231 | rgb_pre = rgb_pre[int(y_center - 0.5*self.opt.height):int(y_center + 0.5*self.opt.height), int(x_center - 0.5*self.opt.width):int(x_center + 0.5*self.opt.width)]
232 | rgb_pre = cv2.cvtColor(rgb_pre, cv2.COLOR_BGR2RGB)
233 | rgb_pre = torch.from_numpy(rgb_pre).permute(2, 0, 1) / 255.
234 | inputs[("color", 1, 0)] = rgb_pre
235 |
236 |
237 | rgb_next = file_next['rgb']
238 | rgb_next = rgb_next[int(y_center - 0.5*self.opt.height):int(y_center + 0.5*self.opt.height), int(x_center - 0.5*self.opt.width):int(x_center + 0.5*self.opt.width)]
239 | rgb_next = cv2.cvtColor(rgb_next, cv2.COLOR_BGR2RGB)
240 | rgb_next = torch.from_numpy(rgb_next).permute(2, 0, 1) / 255.
241 | inputs[("color", 2, 0)] = rgb_next
242 |
243 |
244 | pose_cur = file_cur['pose']
245 | pose_cur = np.linalg.inv(pose_cur).astype('float32')
246 | inputs[("pose", 0)] = pose_cur
247 |
248 | pose_pre = file_pre['pose']
249 | pose_pre = np.linalg.inv(pose_pre).astype('float32')
250 | inputs[("pose", 1)] = pose_pre
251 |
252 | pose_next = file_next['pose']
253 | pose_next = np.linalg.inv(pose_next).astype('float32')
254 | inputs[("pose", 2)] = pose_next
255 |
256 | k_raw = file_cur['intrinsics']
257 |
258 | k_crop, inputs = self.get_k_ori_centercrop(k_raw, inputs, x_center, y_center)
259 |
260 | inputs = self.get_K_test(k_crop, inputs)
261 |
262 | inputs = self.compute_projection_matrix(inputs)
263 |
264 |
265 | inputs['num_frame'] = 3
266 |
267 | # for key, value in inputs.items():
268 | # print(key, value.dtype)
269 |
270 | return inputs
271 |
272 |
273 |
274 | def get_K(self, K, inputs):
275 | inv_K = np.linalg.inv(K)
276 | K_pool = {}
277 | ho, wo = self.opt.height, self.opt.width
278 | for i in range(6):
279 | K_pool[(ho // 2**i, wo // 2**i)] = K.copy().astype('float32')
280 | K_pool[(ho // 2**i, wo // 2**i)][:2, :] /= 2**i
281 |
282 | inputs['K_pool'] = K_pool
283 |
284 | inputs[("inv_K_pool", 0)] = {}
285 | for k, v in K_pool.items():
286 | K44 = np.eye(4)
287 | K44[:3, :3] = v
288 | inputs[("inv_K_pool", 0)][k] = np.linalg.inv(K44).astype('float32')
289 |
290 | inputs[("inv_K", 0)] = torch.from_numpy(inv_K.astype('float32'))
291 |
292 | inputs[("K", 0)] = torch.from_numpy(K.astype('float32'))
293 |
294 | return inputs
295 |
296 | def get_K_test(self, K, inputs):
297 | inv_K = np.linalg.inv(K)
298 | K_pool = {}
299 | ho, wo = self.opt.eval_height, self.opt.eval_width
300 | for i in range(6):
301 | K_pool[(ho // 2**i, wo // 2**i)] = K.copy().astype('float32')
302 | K_pool[(ho // 2**i, wo // 2**i)][:2, :] /= 2**i
303 |
304 | inputs['K_pool'] = K_pool
305 |
306 | inputs[("inv_K_pool", 0)] = {}
307 | for k, v in K_pool.items():
308 | K44 = np.eye(4)
309 | K44[:3, :3] = v
310 | inputs[("inv_K_pool", 0)][k] = np.linalg.inv(K44).astype('float32')
311 |
312 | inputs[("inv_K", 0)] = torch.from_numpy(inv_K.astype('float32'))
313 |
314 | inputs[("K", 0)] = torch.from_numpy(K.astype('float32'))
315 |
316 | return inputs
317 |
318 | def compute_projection_matrix(self, inputs):
319 | for i in range(self.opt.num_frame):
320 | inputs[("proj", i)] = {}
321 | for k, v in inputs['K_pool'].items():
322 | K44 = np.eye(4)
323 | K44[:3, :3] = v
324 | inputs[("proj",
325 | i)][k] = np.matmul(K44, inputs[("pose",
326 | i)]).astype('float32')
327 | return inputs
328 |
--------------------------------------------------------------------------------
/datasets/DDAD_forward.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 | import os
3 | import random
4 | import numpy as np
5 | import copy
6 | from PIL import Image # using pillow-simd for increased speed
7 | from time import time
8 | import torch
9 | import torch.utils.data as data
10 | from torchvision import transforms
11 | import cv2
12 |
13 | cv2.setNumThreads(0)
14 | import glob
15 | import utils
16 | import torch.nn.functional as F
17 | from utils import npy
18 | import json
19 |
20 |
21 |
22 | class DDAD(data.Dataset):
23 | def __init__(self, opt, is_train):
24 | super(DDAD, self).__init__()
25 | self.opt = opt
26 | self.json_path = "/home/cjd/tmp/DDAD_video.json"
27 | self.data_path_root = '/data/cjd/ddad/my_ddad/'
28 | self.is_train = is_train
29 | if self.is_train:
30 | self.data_path = os.path.join(self.data_path_root, 'train/')
31 | else:
32 | self.data_path = os.path.join(self.data_path_root, 'val/')
33 |
34 | f = open(self.json_path, 'r')
35 | content_all = f.read()
36 | json_list_all = json.loads(content_all)
37 | f.close()
38 |
39 | if self.is_train:
40 | self.file_names = json_list_all["train"]
41 | # self.file_names = self.file_names[:300]
42 | print('train', len(self.file_names))
43 | else:
44 | self.file_names = json_list_all["val"]
45 | # self.file_names = self.file_names[:300]
46 |
47 | print('val', len(self.file_names))
48 |
49 | print('filter_pre', len(self.file_names))
50 | self.file_names = [x for x in self.file_names if 'timestamp' in x.keys() and 'timestamp_back' in x.keys() and 'timestamp_forward' in x.keys() and x['Camera'] == 'CAMERA_01']
51 | print('filter_after', len(self.file_names))
52 | self.file_names = self.file_names[:50]
53 |
54 |
55 |
56 | def get_k_ori_randomcrop(self, k_raw, inputs, x1, y1):
57 |
58 | fx_ori = k_raw[0,0]
59 | fy_ori = k_raw[1,1]
60 | fx_virtual = 1060.0
61 | fx_scale = fx_ori / fx_virtual
62 | # print('fx_scale', fx_scale)
63 | inputs[("depth_gt", 0, 0)] = inputs[("depth_gt", 0, 0)] / fx_scale
64 |
65 | pose_cur = inputs[("pose", 0)]
66 | pose_cur[:3, 3] = pose_cur[:3, 3] / fx_scale
67 | inputs[("pose", 0)] = pose_cur
68 | inputs[("pose_inv", 0)] = np.linalg.inv(inputs[("pose", 0)])
69 |
70 | pose_pre = inputs[("pose", 1)]
71 | pose_pre[:3, 3] = pose_pre[:3, 3] / fx_scale
72 | inputs[("pose", 1)] = pose_pre
73 | inputs[("pose_inv", 1)] = np.linalg.inv(inputs[("pose", 1)])
74 |
75 | pose_next = inputs[("pose", 2)]
76 | pose_next[:3, 3] = pose_next[:3, 3] / fx_scale
77 | inputs[("pose", 2)] = pose_next
78 | inputs[("pose_inv", 2)] = np.linalg.inv(inputs[("pose", 2)])
79 |
80 | # inputs['focal_scale'] = float(fx_scale)
81 |
82 | K = np.zeros((3,3), dtype = float)
83 | K[0,0] = fx_ori
84 | K[1,1] = fy_ori
85 | K[2,2] = 1.0
86 | K[0,2] = k_raw[0,2]
87 | K[1,2] = k_raw[1,2]
88 |
89 | h_crop = y1 - self.opt.height
90 | w_crop = x1
91 |
92 | K[0,2] = K[0,2] - w_crop
93 | K[1,2] = K[1,2] - h_crop
94 |
95 | return K, inputs
96 |
97 | def get_k_ori_centercrop(self, k_raw, inputs):
98 |
99 | fx_ori = k_raw[0,0]
100 | fy_ori = k_raw[1,1]
101 | fx_virtual = 1060.0
102 | fx_scale = fx_ori / fx_virtual
103 | # print('fx_scale', fx_scale)
104 | inputs[("depth_gt", 0, 0)] = inputs[("depth_gt", 0, 0)] / fx_scale
105 |
106 | pose_cur = inputs[("pose", 0)]
107 | pose_cur[:3, 3] = pose_cur[:3, 3] / fx_scale
108 | inputs[("pose", 0)] = pose_cur
109 | inputs[("pose_inv", 0)] = np.linalg.inv(inputs[("pose", 0)])
110 |
111 | pose_pre = inputs[("pose", 1)]
112 | pose_pre[:3, 3] = pose_pre[:3, 3] / fx_scale
113 | inputs[("pose", 1)] = pose_pre
114 | inputs[("pose_inv", 1)] = np.linalg.inv(inputs[("pose", 1)])
115 |
116 | pose_next = inputs[("pose", 2)]
117 | pose_next[:3, 3] = pose_next[:3, 3] / fx_scale
118 | inputs[("pose", 2)] = pose_next
119 | inputs[("pose_inv", 2)] = np.linalg.inv(inputs[("pose", 2)])
120 |
121 | # inputs['focal_scale'] = float(fx_scale)
122 |
123 | K = np.zeros((3,3), dtype = float)
124 | K[0,0] = fx_ori
125 | K[1,1] = fy_ori
126 | K[2,2] = 1.0
127 | K[0,2] = k_raw[0,2]
128 | K[1,2] = k_raw[1,2]
129 |
130 | h_crop = 0.0
131 | w_crop = 8.0
132 |
133 | K[0,2] = K[0,2] - w_crop
134 | K[1,2] = K[1,2] - h_crop
135 |
136 | return K, inputs
137 |
138 |
139 | def __len__(self):
140 | return len(self.file_names)
141 |
142 |
143 | def __getitem__(self, index):
144 | inputs = {}
145 | cur_npz_path = self.data_path + str(self.file_names[index]['timestamp']) + '_' + self.file_names[index]['Camera'] + '.npz'
146 | pre_npz_path = self.data_path + str(self.file_names[index]['timestamp_back']) + '_' + self.file_names[index]['Camera'] + '.npz'
147 | next_npz_path = self.data_path + str(self.file_names[index]['timestamp_forward']) + '_' + self.file_names[index]['Camera'] + '.npz'
148 |
149 | cur_mask_path = cur_npz_path.replace('.npz', '_dynamic.npz')
150 | inputs['dynamic_mask'] = cur_mask_path
151 | file_cur = np.load(cur_npz_path)
152 | file_pre = np.load(pre_npz_path)
153 | file_next = np.load(next_npz_path)
154 |
155 | depth_cur_gt = file_cur['depth']
156 | depth_cur_gt = np.array(depth_cur_gt).astype(np.float32)
157 |
158 |
159 | inputs[("depth_gt", 0, 0)] = depth_cur_gt
160 |
161 |
162 | if self.is_train:
163 | #h_ori=1216, w_ori=1936
164 | if random.randint(0, 10) < 8:
165 | y_center = int((1216 + self.opt.height)/2)
166 | y1 = random.randint(int(y_center - 70), int(y_center + 50))
167 | else:
168 | y1 = random.randint(self.opt.height, 1216)
169 | x1 = random.randint(0, 1930 - self.opt.width)
170 | # else:
171 | # y1 = int((1216 + self.opt.height)/2)
172 | # x1 = int((1936 - self.opt.width)/2)
173 |
174 | inputs[("depth_gt", 0, 0)] = inputs[("depth_gt", 0, 0)][y1 - int(self.opt.height):y1, x1:x1+int(self.opt.width)][None, :, :]
175 |
176 | rgb_cur = file_cur['rgb']
177 | rgb_cur = rgb_cur[y1 - int(self.opt.height):y1, x1:x1+int(self.opt.width)]
178 | rgb_cur = cv2.cvtColor(rgb_cur, cv2.COLOR_BGR2RGB)
179 | rgb_cur = torch.from_numpy(rgb_cur).permute(2, 0, 1) / 255.
180 | inputs[("color", 0, 0)] = rgb_cur
181 |
182 |
183 |
184 | rgb_pre = file_pre['rgb']
185 | rgb_pre = rgb_pre[y1 - int(self.opt.height):y1, x1:x1+int(self.opt.width)]
186 | rgb_pre = cv2.cvtColor(rgb_pre, cv2.COLOR_BGR2RGB)
187 | rgb_pre = torch.from_numpy(rgb_pre).permute(2, 0, 1) / 255.
188 | inputs[("color", 1, 0)] = rgb_pre
189 |
190 |
191 | rgb_next = file_next['rgb']
192 | rgb_next = rgb_next[y1 - int(self.opt.height):y1, x1:x1+int(self.opt.width)]
193 | rgb_next = cv2.cvtColor(rgb_next, cv2.COLOR_BGR2RGB)
194 | rgb_next = torch.from_numpy(rgb_next).permute(2, 0, 1) / 255.
195 | inputs[("color", 2, 0)] = rgb_next
196 |
197 |
198 | pose_cur = file_cur['pose']
199 | pose_cur = np.linalg.inv(pose_cur).astype('float32')
200 | inputs[("pose", 0)] = pose_cur
201 |
202 | pose_pre = file_pre['pose']
203 | pose_pre = np.linalg.inv(pose_pre).astype('float32')
204 | inputs[("pose", 1)] = pose_pre
205 |
206 | pose_next = file_next['pose']
207 | pose_next = np.linalg.inv(pose_next).astype('float32')
208 | inputs[("pose", 2)] = pose_next
209 |
210 | k_raw = file_cur['intrinsics']
211 | # print('k_raw',k_raw)
212 | k_crop, inputs = self.get_k_ori_randomcrop(k_raw, inputs, x1, y1)
213 | inputs = self.get_K(k_crop, inputs)
214 |
215 |
216 | else:
217 | inputs[("depth_gt", 0, 0)] = inputs[("depth_gt", 0, 0)][:, 8:1928][None, :, :]
218 |
219 | rgb_cur = file_cur['rgb']
220 | rgb_cur = rgb_cur[:, 8:1928]
221 | rgb_cur = cv2.cvtColor(rgb_cur, cv2.COLOR_BGR2RGB)
222 | rgb_cur = torch.from_numpy(rgb_cur).permute(2, 0, 1) / 255.
223 | inputs[("color", 0, 0)] = rgb_cur
224 |
225 |
226 |
227 | rgb_pre = file_pre['rgb']
228 | rgb_pre = rgb_pre[:, 8:1928]
229 | rgb_pre = cv2.cvtColor(rgb_pre, cv2.COLOR_BGR2RGB)
230 | rgb_pre = torch.from_numpy(rgb_pre).permute(2, 0, 1) / 255.
231 | inputs[("color", 1, 0)] = rgb_pre
232 |
233 |
234 | rgb_next = file_next['rgb']
235 | rgb_next = rgb_next[:, 8:1928]
236 | rgb_next = cv2.cvtColor(rgb_next, cv2.COLOR_BGR2RGB)
237 | rgb_next = torch.from_numpy(rgb_next).permute(2, 0, 1) / 255.
238 | inputs[("color", 2, 0)] = rgb_next
239 |
240 |
241 | pose_cur = file_cur['pose']
242 | pose_cur = np.linalg.inv(pose_cur).astype('float32')
243 | inputs[("pose", 0)] = pose_cur
244 |
245 | pose_pre = file_pre['pose']
246 | pose_pre = np.linalg.inv(pose_pre).astype('float32')
247 | inputs[("pose", 1)] = pose_pre
248 |
249 | pose_next = file_next['pose']
250 | pose_next = np.linalg.inv(pose_next).astype('float32')
251 | inputs[("pose", 2)] = pose_next
252 |
253 | k_raw = file_cur['intrinsics']
254 |
255 | k_crop, inputs = self.get_k_ori_centercrop(k_raw, inputs)
256 |
257 | inputs = self.get_K_test(k_crop, inputs)
258 |
259 | inputs = self.compute_projection_matrix(inputs)
260 |
261 |
262 | inputs['num_frame'] = 3
263 |
264 | # for key, value in inputs.items():
265 | # print(key, value.dtype)
266 |
267 | return inputs
268 |
269 |
270 |
271 | def get_K(self, K, inputs):
272 | inv_K = np.linalg.inv(K)
273 | K_pool = {}
274 | ho, wo = self.opt.height, self.opt.width
275 | for i in range(6):
276 | K_pool[(ho // 2**i, wo // 2**i)] = K.copy().astype('float32')
277 | K_pool[(ho // 2**i, wo // 2**i)][:2, :] /= 2**i
278 |
279 | inputs['K_pool'] = K_pool
280 |
281 | inputs[("inv_K_pool", 0)] = {}
282 | for k, v in K_pool.items():
283 | K44 = np.eye(4)
284 | K44[:3, :3] = v
285 | inputs[("inv_K_pool", 0)][k] = np.linalg.inv(K44).astype('float32')
286 |
287 | inputs[("inv_K", 0)] = torch.from_numpy(inv_K.astype('float32'))
288 |
289 | inputs[("K", 0)] = torch.from_numpy(K.astype('float32'))
290 |
291 | return inputs
292 |
293 | def get_K_test(self, K, inputs):
294 | inv_K = np.linalg.inv(K)
295 | K_pool = {}
296 | ho, wo = self.opt.eval_height, self.opt.eval_width
297 | for i in range(6):
298 | K_pool[(ho // 2**i, wo // 2**i)] = K.copy().astype('float32')
299 | K_pool[(ho // 2**i, wo // 2**i)][:2, :] /= 2**i
300 |
301 | inputs['K_pool'] = K_pool
302 |
303 | inputs[("inv_K_pool", 0)] = {}
304 | for k, v in K_pool.items():
305 | K44 = np.eye(4)
306 | K44[:3, :3] = v
307 | inputs[("inv_K_pool", 0)][k] = np.linalg.inv(K44).astype('float32')
308 |
309 | inputs[("inv_K", 0)] = torch.from_numpy(inv_K.astype('float32'))
310 |
311 | inputs[("K", 0)] = torch.from_numpy(K.astype('float32'))
312 |
313 | return inputs
314 |
315 | def compute_projection_matrix(self, inputs):
316 | for i in range(self.opt.num_frame):
317 | inputs[("proj", i)] = {}
318 | for k, v in inputs['K_pool'].items():
319 | K44 = np.eye(4)
320 | K44[:3, :3] = v
321 | inputs[("proj",
322 | i)][k] = np.matmul(K44, inputs[("pose",
323 | i)]).astype('float32')
324 | return inputs
325 |
--------------------------------------------------------------------------------
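The DDAD loaders above build `K_pool`, a pyramid of intrinsics obtained by halving fx, fy, cx, cy once per level, and `compute_projection_matrix` turns each level into a 4x4 matrix `K44 @ pose`. Below is a minimal standalone sketch of that bookkeeping with hypothetical intrinsics and an identity pose (the 1060 focal length mirrors `fx_virtual` above; the resolution is an assumed `opt.height`/`opt.width`):

```python
import numpy as np

# Hypothetical intrinsics and a world-to-camera pose; 6 pyramid levels as in get_K().
K = np.array([[1060.0, 0.0, 968.0],
              [0.0, 1060.0, 608.0],
              [0.0, 0.0, 1.0]], dtype=np.float32)
pose = np.eye(4, dtype=np.float32)
height, width = 384, 640      # assumed network input resolution

K_pool, proj = {}, {}
for i in range(6):
    K_i = K.copy()
    K_i[:2, :] /= 2 ** i      # scale fx, fy, cx, cy for pyramid level i
    K_pool[(height // 2 ** i, width // 2 ** i)] = K_i

    K44 = np.eye(4, dtype=np.float32)
    K44[:3, :3] = K_i
    proj[(height // 2 ** i, width // 2 ** i)] = K44 @ pose   # 4x4 projection matrix

# A homogeneous world point projects to pixel (u, v) at the full-resolution level:
X = np.array([2.0, 1.0, 10.0, 1.0], dtype=np.float32)
x = proj[(height, width)] @ X
u, v = x[0] / x[2], x[1] / x[2]
print(round(float(u), 1), round(float(v), 1))   # 1180.0 714.0
```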
/datasets/kitti.py:
--------------------------------------------------------------------------------
1 | # dataloader for KITTI / when training & testing F-Net and MaGNet
2 | import os
3 | import random
4 | import glob
5 |
6 | import numpy as np
7 | import torch
8 | import torch.utils.data as data
9 | import torch.utils.data.distributed
10 | from PIL import Image
11 |
12 | from torch.utils.data import Dataset, DataLoader
13 | from torchvision import transforms
14 |
15 | import pykitti
16 |
17 | import cv2
18 |
19 | import utils
20 | import torch.nn.functional as F
21 | from utils import npy
22 | import json
23 |
24 |
25 |
26 | class DDAD_kitti(data.Dataset):
27 | def __init__(self, opt, is_train):
28 | super(DDAD_kitti, self).__init__()
29 | self.opt = opt
30 | self.is_train = is_train
31 | if self.is_train:
32 | with open("./data_split/kitti_eigen_train.txt", 'r') as f:
33 | self.filenames = f.readlines()
34 | else:
35 | with open("./data_split/kitti_eigen_test.txt", 'r') as f:
36 | self.filenames = f.readlines()
37 |
38 | # self.mode = mode
39 | self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
40 | self.dataset_path = self.opt.data_path
41 |
42 | # local window
43 | self.window_radius = int(2)
44 | self.n_views = int(2)
45 | self.frame_interval = self.window_radius // (self.n_views // 2)
46 | self.img_idx_center = self.n_views // 2
47 |
48 | # window_idx_list
49 | self.window_idx_list = list(range(-self.n_views // 2, (self.n_views // 2) + 1))
50 | self.window_idx_list = [i * self.frame_interval for i in self.window_idx_list]
51 |
52 | # image resolution
53 | self.img_H = self.opt.height # 352
54 | self.img_W = self.opt.width # 1216
55 |
56 | def __len__(self):
57 | return len(self.filenames)
58 |
59 |     # get camera intrinsics
60 | def get_cam_intrinsics(self, p_data):
61 | raw_img_size = p_data.get_cam2(0).size
62 | raw_W = int(raw_img_size[0])
63 | raw_H = int(raw_img_size[1])
64 |
65 | top_margin = int(raw_H - 352)
66 | left_margin = int((raw_W - 1216) / 2)
67 |
68 |         # original intrinsic matrix (3x3)
69 | IntM_ = p_data.calib.K_cam2
70 |
71 | # updated intrinsic matrix
72 | IntM = np.zeros((3, 3))
73 | IntM[2, 2] = 1.
74 | IntM[0, 0] = IntM_[0, 0]
75 | IntM[1, 1] = IntM_[1, 1]
76 | IntM[0, 2] = (IntM_[0, 2] - left_margin)
77 | IntM[1, 2] = (IntM_[1, 2] - top_margin)
78 |
79 | IntM = IntM.astype(np.float32)
80 | return IntM
81 |
82 | def __getitem__(self, idx):
83 | inputs = {}
84 | date, drive, mode, img_idx = self.filenames[idx].split(' ')
85 | img_idx = int(img_idx)
86 | scene_name = '%s_drive_%s_sync' % (date, drive)
87 |
88 | # identify the neighbor views
89 | img_idx_list = [img_idx + i for i in self.window_idx_list]
90 | p_data = pykitti.raw(self.dataset_path + '/rawdata', date, drive, frames=img_idx_list)
91 |
92 | # cam intrinsics
93 | cam_intrins = self.get_cam_intrinsics(p_data)
94 |
95 | # color augmentation
96 | color_aug = False
97 | if self.is_train:
98 | if random.random() > 0.5:
99 | color_aug = True
100 | aug_gamma = random.uniform(0.9, 1.1)
101 | aug_brightness = random.uniform(0.9, 1.1)
102 | aug_colors = np.random.uniform(0.9, 1.1, size=3)
103 |
104 | # data array
105 | data_array = []
106 | for i in range(self.n_views + 1):
107 | cur_idx = img_idx_list[i]
108 |
109 | # read img
110 | img_name = '%010d.png' % cur_idx
111 | img_path = self.dataset_path + '/rawdata/{}/{}/image_02/data/{}'.format(date, scene_name, img_name)
112 | img = Image.open(img_path).convert("RGB")
113 |
114 | # kitti benchmark crop
115 | height = img.height
116 | width = img.width
117 | top_margin = int(height - 352)
118 | left_margin = int((width - 1216) / 2)
119 | img = img.crop((left_margin, top_margin, left_margin + 1216, top_margin + 352))
120 |
121 | # to tensor
122 | img = np.array(img).astype(np.float32) / 255.0 # (H, W, 3)
123 | img_ori = torch.from_numpy(img).permute(2, 0, 1)
124 | if color_aug:
125 | img = self.augment_image(img, aug_gamma, aug_brightness, aug_colors)
126 | img = torch.from_numpy(img).permute(2, 0, 1) # (3, H, W)
127 | img = self.normalize(img)
128 |
129 | # read dmap (only for the ref img)
130 | if i == self.img_idx_center:
131 | dmap_path = self.dataset_path + '/{}/{}/proj_depth/groundtruth/image_02/{}'.format(mode, scene_name,
132 | img_name)
133 | gt_dmap = Image.open(dmap_path).crop((left_margin, top_margin, left_margin + 1216, top_margin + 352))
134 | gt_dmap = np.array(gt_dmap)[:, :, np.newaxis].astype(np.float32) # (H, W, 1)
135 | gt_dmap = gt_dmap / 256.0
136 | gt_dmap = torch.from_numpy(gt_dmap).permute(2, 0, 1) # (1, H, W)
137 | else:
138 | gt_dmap = 0.0
139 |
140 | # read extM
141 | pose = p_data.oxts[i].T_w_imu
142 | M_imu2cam = p_data.calib.T_cam2_imu
143 | extM = np.matmul(M_imu2cam, np.linalg.inv(pose))
144 | extM = extM.astype('float32')
145 |
146 | data_dict = {
147 | 'img_ori': img_ori,
148 | 'img': img,
149 | 'gt_dmap': gt_dmap,
150 | 'extM': extM,
151 | 'scene_name': scene_name,
152 | 'img_idx': str(img_idx),
153 | }
154 | data_array.append(data_dict)
155 |
156 | inputs[("color", 0, 0)] = data_array[1]['img']
157 | inputs[("img_ori", 0, 0)] = data_array[1]['img_ori']
158 |
159 | inputs[("depth_gt", 0, 0)] = data_array[1]['gt_dmap']
160 | inputs[("pose", 0)] = data_array[1]['extM']
161 |
162 | inputs[("color", 1, 0)] = data_array[0]['img']
163 | inputs[("pose", 1)] = data_array[0]['extM']
164 | inputs[("img_ori", 1, 0)] = data_array[0]['img_ori']
165 |
166 | inputs[("color", 2, 0)] = data_array[2]['img']
167 | inputs[("pose", 2)] = data_array[2]['extM']
168 | inputs[("img_ori", 2, 0)] = data_array[2]['img_ori']
169 |
170 |
171 | inputs = self.get_K(cam_intrins, inputs)
172 |
173 | inputs = self.compute_projection_matrix(inputs)
174 |
175 | inputs['num_frame'] = 3
176 |
177 | return inputs
178 |
179 | def augment_image(self, image, gamma, brightness, colors):
180 | # gamma augmentation
181 | image_aug = image ** gamma
182 |
183 | # brightness augmentation
184 | image_aug = image_aug * brightness
185 |
186 | # color augmentation
187 | white = np.ones((image.shape[0], image.shape[1]))
188 | color_image = np.stack([white * colors[i] for i in range(3)], axis=2)
189 | image_aug *= color_image
190 | image_aug = np.clip(image_aug, 0, 1)
191 |
192 | return image_aug
193 |
194 | def get_K(self, K, inputs):
195 | inv_K = np.linalg.inv(K)
196 | K_pool = {}
197 | ho, wo = self.opt.height, self.opt.width
198 | for i in range(6):
199 | K_pool[(ho // 2**i, wo // 2**i)] = K.copy().astype('float32')
200 | K_pool[(ho // 2**i, wo // 2**i)][:2, :] /= 2**i
201 |
202 | inputs['K_pool'] = K_pool
203 |
204 | inputs[("inv_K_pool", 0)] = {}
205 | for k, v in K_pool.items():
206 | K44 = np.eye(4)
207 | K44[:3, :3] = v
208 | inputs[("inv_K_pool", 0)][k] = np.linalg.inv(K44).astype('float32')
209 |
210 | inputs[("inv_K", 0)] = torch.from_numpy(inv_K.astype('float32'))
211 |
212 | inputs[("K", 0)] = torch.from_numpy(K.astype('float32'))
213 |
214 | return inputs
215 |
216 |
217 | def compute_projection_matrix(self, inputs):
218 | for i in range(3):
219 | inputs[("proj", i)] = {}
220 | for k, v in inputs['K_pool'].items():
221 | K44 = np.eye(4)
222 | K44[:3, :3] = v
223 | inputs[("proj",
224 | i)][k] = np.matmul(K44, inputs[("pose",
225 | i)]).astype('float32')
226 | return inputs
--------------------------------------------------------------------------------
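A minimal sketch of how the `DDAD_kitti` loader above might be driven. The `opt` namespace and the dataset path are placeholders, not the repository's actual configuration; only the fields the class reads (`height`, `width`, `data_path`) are filled in:

```python
from argparse import Namespace
from torch.utils.data import DataLoader
from datasets.kitti import DDAD_kitti

opt = Namespace(height=352, width=1216, data_path='/path/to/kitti')  # assumed option fields
dataset = DDAD_kitti(opt, is_train=False)
loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)

for inputs in loader:
    # Each item carries the reference view plus its two temporal neighbours.
    ref_img = inputs[("color", 0, 0)]        # normalized reference image, (B, 3, H, W)
    gt_depth = inputs[("depth_gt", 0, 0)]    # sparse LiDAR depth, (B, 1, H, W)
    proj_ref = inputs[("proj", 0)]           # per-resolution 4x4 projection matrices
    break
```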
/eval_ddad.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import sys
5 | import cv2
6 | import torch
7 | import torch.nn as nn
8 | from options import MVS2DOptions, EvalCfg
9 | import networks
10 | from torch.utils.data import DataLoader
11 | from datasets.DDAD import DDAD
12 | import torch.nn.functional as F
13 | from utils import *
14 | from hybrid_evaluate_depth import evaluate_depth_maps, compute_errors,compute_errors1,compute_errors_perimage
15 | import os
16 | os.environ['CUDA_VISIBLE_DEVICES'] = '3'
17 |
18 |
19 | def to_gpu(inputs, keys=None):
20 |     if keys is None:
21 | keys = inputs.keys()
22 | for key in keys:
23 | if key not in inputs:
24 | continue
25 | ipt = inputs[key]
26 | if type(ipt) == torch.Tensor:
27 | inputs[key] = ipt.cuda()
28 | elif type(ipt) == list and type(ipt[0]) == torch.Tensor:
29 | inputs[key] = [
30 | x.cuda() for x in ipt
31 | ]
32 | elif type(ipt) == dict:
33 | for k in ipt.keys():
34 | if type(ipt[k]) == torch.Tensor:
35 | ipt[k] = ipt[k].cuda()
36 |
37 |
38 | options = MVS2DOptions()
39 | opts = options.parse()
40 |
41 | opts.cfg = "./configs/DDAD.conf"
42 | dataset = DDAD(opts, False)
43 | data_loader = DataLoader(dataset,
44 | 1,
45 | shuffle=False,
46 | num_workers=4,
47 | pin_memory=True,
48 | drop_last=False,
49 | sampler=None)
50 | model = networks.MVS2D(opt=opts).cuda()
51 | pretrained_dict = torch.load("pretrained_model/DDAD/model_DDAD.pth")
52 |
53 | model.load_state_dict(pretrained_dict)
54 | model.eval()
55 |
56 | index = 0
57 |
58 | total_result_sum = {}
59 | total_result_count = {}
60 | min_depth = opts.EVAL_MIN_DEPTH
61 | max_depth = opts.EVAL_MAX_DEPTH
62 |
63 | with torch.no_grad():
64 | for batch_idx, inputs in enumerate(data_loader):
65 | to_gpu(inputs)
66 |
67 | imgs, proj_mats, pose_mats = [], [], []
68 | for i in range(inputs['num_frame'][0].item()):
69 | imgs.append(inputs[('color', i, 0)])
70 | proj_mats.append(inputs[('proj', i)])
71 | pose_mats.append(inputs[('pose', i)])
72 |
73 | depth_gt = inputs[("depth_gt", 0, 0)]
74 | depth_gt_np = depth_gt.cpu().detach().numpy().squeeze()
75 |
76 | mask = (depth_gt_np>min_depth) & (depth_gt_np < max_depth)
77 | outputs = model(imgs[0], imgs[1:], pose_mats[0], pose_mats[1:],
78 | inputs[('inv_K_pool', 0)])
79 | depth_pred_1 = outputs[('depth_pred', 0)]
80 | depth_pred_2 = outputs[('depth_pred_2', 0)]
81 |
82 | depth_pred_2_np = depth_pred_2.cpu().detach().numpy().squeeze()
83 | depth_pred_1_np = depth_pred_1.cpu().detach().numpy().squeeze()
84 |
85 | error_temp = compute_errors_perimage(depth_gt_np[mask], depth_pred_1_np[mask], min_depth, max_depth)
86 | error_temp_2_ = compute_errors_perimage(depth_gt_np[mask], depth_pred_2_np[mask], min_depth, max_depth)
87 | print('cur',index, error_temp)
88 | index = index + 1
89 | error_temp_2 = {}
90 | for k,v in error_temp_2_.items():
91 | new_k = k + '_2'
92 | error_temp_2[new_k] = error_temp_2_[k]
93 |
94 | error_temp_all = {}
95 | error_temp_all.update(error_temp)
96 | error_temp_all.update(error_temp_2)
97 |
98 | for k,v in error_temp_all.items():
99 | if not isinstance(v,float):
100 |                 v = v.item()
101 | if k in total_result_sum:
102 | total_result_sum[k] = total_result_sum[k] + v
103 | else:
104 | total_result_sum[k] = v
105 |
106 | for k in total_result_sum.keys():
107 | total_result_count[k] = total_result_sum['valid_number']
108 |
109 | print('final####################################')
110 |
111 | for k in total_result_sum.keys():
112 | this_tensor = torch.tensor([total_result_sum[k], total_result_count[k]])
113 | this_list = [this_tensor]
114 | this_tensor = this_list[0].detach().cpu().numpy()
115 | reduce_sum = this_tensor[0].item()
116 | reduce_count = this_tensor[1].item()
117 | reduce_mean = reduce_sum / reduce_count
118 | print(k, reduce_mean)
119 |
120 |
--------------------------------------------------------------------------------
/eval_kitti.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import sys
5 | import cv2
6 | import torch
7 | import torch.nn as nn
8 | from options import MVS2DOptions, EvalCfg
9 | import networks
10 | from torch.utils.data import DataLoader
11 | from datasets.kitti import DDAD_kitti
12 | from hybrid_evaluate_depth import evaluate_depth_maps, compute_errors,compute_errors1,compute_errors_perimage
13 | import torch.nn.functional as F
14 | import os
15 | os.environ['CUDA_VISIBLE_DEVICES'] = '3'
16 |
17 | def to_gpu(inputs, keys=None):
18 |     if keys is None:
19 | keys = inputs.keys()
20 | for key in keys:
21 | if key not in inputs:
22 | continue
23 | ipt = inputs[key]
24 | if type(ipt) == torch.Tensor:
25 | inputs[key] = ipt.cuda()
26 | elif type(ipt) == list and type(ipt[0]) == torch.Tensor:
27 | inputs[key] = [
28 | x.cuda() for x in ipt
29 | ]
30 | elif type(ipt) == dict:
31 | for k in ipt.keys():
32 | if type(ipt[k]) == torch.Tensor:
33 | ipt[k] = ipt[k].cuda()
34 |
35 |
36 | options = MVS2DOptions()
37 | opts = options.parse()
38 |
39 | # opts.cfg = "./configs/kitti.conf"
40 | dataset = DDAD_kitti(opts, False)
41 | data_loader = DataLoader(dataset,
42 | 1,
43 | shuffle=False,
44 | num_workers=4,
45 | pin_memory=True,
46 | drop_last=False,
47 | sampler=None)
48 | model = networks.MVS2D(opt=opts).cuda()
49 | pretrained_dict = torch.load("pretrained_model/kitti/model_kitti.pth")
50 |
51 | model.load_state_dict(pretrained_dict)
52 | model.eval()
53 |
54 | min_depth = opts.EVAL_MIN_DEPTH
55 | max_depth = opts.EVAL_MAX_DEPTH
56 |
57 | index = 0
58 | total_result_sum = {}
59 | total_result_count = {}
60 | with torch.no_grad():
61 | for batch_idx, inputs in enumerate(data_loader):
62 | to_gpu(inputs)
63 |
64 | imgs, proj_mats, pose_mats = [], [], []
65 | for i in range(inputs['num_frame'][0].item()):
66 | imgs.append(inputs[('color', i, 0)])
67 | proj_mats.append(inputs[('proj', i)])
68 | pose_mats.append(inputs[('pose', i)])
69 |
70 | depth_gt = inputs[("depth_gt", 0, 0)]
71 | depth_gt_np = depth_gt.cpu().detach().numpy().squeeze()
72 | mask = (depth_gt_np>min_depth) & (depth_gt_np < max_depth)
73 |
74 | if np.sum(mask.astype(np.float32)) > 5:
75 |
76 | outputs = model(imgs[0], imgs[1:], pose_mats[0], pose_mats[1:],
77 | inputs[('inv_K_pool', 0)])
78 | depth_pred_1_tensor = outputs[('depth_pred', 0)]
79 | depth_pred_2_tensor = outputs[('depth_pred_2', 0)]
80 |
81 | depth_pred_2 = depth_pred_2_tensor.cpu().detach().numpy().squeeze()
82 | depth_pred_1 = depth_pred_1_tensor.cpu().detach().numpy().squeeze()
83 |
84 | error_temp = compute_errors_perimage(depth_gt_np[mask], depth_pred_1[mask], min_depth, max_depth)
85 | error_temp_2_ = compute_errors_perimage(depth_gt_np[mask], depth_pred_2[mask], min_depth, max_depth)
86 | print('cur',index, error_temp)
87 | index = index + 1
88 | error_temp_2 = {}
89 | for k,v in error_temp_2_.items():
90 | new_k = k + '_2'
91 | error_temp_2[new_k] = error_temp_2_[k]
92 |
93 | error_temp_all = {}
94 | error_temp_all.update(error_temp)
95 | error_temp_all.update(error_temp_2)
96 |
97 | for k,v in error_temp_all.items():
98 | if not isinstance(v,float):
99 |                     v = v.item()
100 | if k in total_result_sum:
101 | total_result_sum[k] = total_result_sum[k] + v
102 | else:
103 | total_result_sum[k] = v
104 |
105 | for k in total_result_sum.keys():
106 | total_result_count[k] = total_result_sum['valid_number']
107 |
108 | print('final####################################')
109 |
110 | for k in total_result_sum.keys():
111 | this_tensor = torch.tensor([total_result_sum[k], total_result_count[k]])
112 | this_list = [this_tensor]
113 | this_tensor = this_list[0].detach().cpu().numpy()
114 | reduce_sum = this_tensor[0].item()
115 | reduce_count = this_tensor[1].item()
116 | reduce_mean = reduce_sum / reduce_count
117 | print(k, reduce_mean)
118 |
119 |
--------------------------------------------------------------------------------
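Both evaluation scripts accumulate the per-image metric dictionaries the same way: every entry is summed across images and later divided by the accumulated `valid_number`. The same pattern in isolation, with made-up metric values (not real results):

```python
# Sketch of the sum-then-divide aggregation used by eval_ddad.py / eval_kitti.py.
per_image_metrics = [
    {'abs_rel': 0.05, 'rmse': 2.0, 'valid_number': 1.0},
    {'abs_rel': 0.07, 'rmse': 2.4, 'valid_number': 1.0},
]

total = {}
for metrics in per_image_metrics:
    for k, v in metrics.items():
        total[k] = total.get(k, 0.0) + float(v)

count = total['valid_number']
means = {k: v / count for k, v in total.items() if k != 'valid_number'}
print(means)   # {'abs_rel': 0.06, 'rmse': 2.2} (up to float rounding)
```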
/generate_dynamic_mask.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 | import os
3 | import random
4 | import numpy as np
5 | import copy
6 | from PIL import Image # using pillow-simd for increased speed
7 | from time import time
8 | import torch
9 | import torch.utils.data as data
10 | from torchvision import transforms
11 | import cv2
12 |
13 | cv2.setNumThreads(0)
14 | import glob
15 | import utils
16 | import torch.nn.functional as F
17 | from utils import npy
18 | import json
19 | from mmdet.apis import init_detector, inference_detector
20 | import mmcv
21 | from skimage.metrics import structural_similarity
22 |
23 | import os
24 | # os.environ['CUDA_VISIBLE_DEVICES'] = '3'
25 |
26 | def to_gpu(inputs, keys=None):
27 |     if keys is None:
28 | keys = inputs.keys()
29 | for key in keys:
30 | if key not in inputs:
31 | continue
32 | ipt = inputs[key]
33 | if type(ipt) == torch.Tensor:
34 | inputs[key] = ipt.cuda()
35 | elif type(ipt) == list and type(ipt[0]) == torch.Tensor:
36 | inputs[key] = [
37 | x.cuda() for x in ipt
38 | ]
39 | elif type(ipt) == dict:
40 | for k in ipt.keys():
41 | if type(ipt[k]) == torch.Tensor:
42 | ipt[k] = ipt[k].cuda()
43 |
44 | def homo_warping_depth(src_fea, src_proj, ref_proj, depth_values):
45 | # src_fea: [B, C, H, W]
46 | # src_proj: [B, 4, 4]
47 | # ref_proj: [B, 4, 4]
48 | # depth_values: [B, Ndepth, H, W]
49 | # out: [B, C, Ndepth, H, W]
50 | batch, channels = src_fea.shape[0], src_fea.shape[1]
51 | num_depth = depth_values.shape[1]
52 | #height, width = src_fea.shape[2], src_fea.shape[3]
53 | h_src, w_src = src_fea.shape[2], src_fea.shape[3]
54 | h_ref, w_ref = depth_values.shape[2], depth_values.shape[3]
55 |
56 | with torch.no_grad():
57 |
58 | proj = torch.matmul(src_proj, torch.inverse(ref_proj))
59 | rot = proj[:, :3, :3] # [B,3,3]
60 | trans = proj[:, :3, 3:4] # [B,3,1]
61 |
62 |
63 | y, x = torch.meshgrid([torch.arange(0, h_ref, dtype=torch.float32, device=src_fea.device),
64 | torch.arange(0, w_ref, dtype=torch.float32, device=src_fea.device)])
65 | y, x = y.contiguous(), x.contiguous()
66 | y, x = y.view(h_ref * w_ref), x.view(h_ref * w_ref)
67 | xyz = torch.stack((x, y, torch.ones_like(x))) # [3, H*W]
68 | xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1) # [B, 3, H*W]
69 |
70 | rot_xyz = torch.matmul(rot, xyz)
71 | print(rot_xyz.shape)
72 | print(depth_values.shape)
73 | rot_depth_xyz = rot_xyz * depth_values.view(batch, 1, -1)
74 |
75 | proj_xyz = rot_depth_xyz + trans.view(batch,3,1)
76 |
77 | proj_xy = proj_xyz[:, :2, :] / proj_xyz[:, 2:3, :] # [B, 2, Ndepth, H*W]
78 | z = proj_xyz[:, 2:3, :].view(batch, h_ref, w_ref)
79 | proj_x_normalized = proj_xy[:, 0, :] / ((w_src - 1) / 2.0) - 1
80 | proj_y_normalized = proj_xy[:, 1, :] / ((h_src - 1) / 2.0) - 1
81 | X_mask = ((proj_x_normalized > 1)+(proj_x_normalized < -1)).detach()
82 |         proj_x_normalized[X_mask] = 2  # make sure that no point in the warped image is a combination of valid content and padding
83 | Y_mask = ((proj_y_normalized > 1)+(proj_y_normalized < -1)).detach()
84 | proj_y_normalized[Y_mask] = 2
85 | proj_xy = torch.stack((proj_x_normalized, proj_y_normalized), dim=2) # [B, Ndepth, H*W, 2]
86 | grid = proj_xy
87 | proj_mask = ((X_mask + Y_mask) > 0).view(batch, num_depth, h_ref, w_ref)
88 | proj_mask = (proj_mask + (z <= 0)) > 0
89 |
90 | warped_src_fea = F.grid_sample(src_fea, grid.view(batch, h_ref, w_ref, 2), mode='bilinear',
91 | padding_mode='zeros', align_corners=True)
92 |
93 | warped_src_fea = warped_src_fea.view(batch, channels, num_depth, h_ref, w_ref)
94 |
95 | #return warped_src_fea , proj_mask
96 | return warped_src_fea
97 |
98 | def main():
99 | json_path = "/home/cjd/tmp/DDAD_video.json"
100 | data_path_ori = '/data/cjd/ddad/ddad_train_val/'
101 | data_path_root = '/data/cjd/ddad/my_ddad/'
102 | data_path = os.path.join(data_path_root, 'val/')
103 | f = open(json_path, 'r')
104 | content_all = f.read()
105 | json_list_all = json.loads(content_all)
106 | f.close()
107 | file_names = json_list_all["val"]
108 | file_names = [x for x in file_names if 'timestamp' in x.keys() and 'timestamp_back' in x.keys() and 'timestamp_forward' in x.keys() and x['Camera'] == 'CAMERA_01']
109 |
110 | model = init_detector("/home/cjd/mmdetection3d-master/configs/nuimages/htc_x101_64x4d_fpn_dconv_c3-c5_coco-20e_16x1_20e_nuim.py", "/home/cjd/MVS2D/htc_x101_64x4d_fpn_dconv_c3-c5_coco-20e_16x1_20e_nuim_20201008_211222-0b16ac4b.pth", device = 'cuda:3')
111 | # print(model.CLASSES)
112 | class_all = model.CLASSES
113 | print(class_all)
114 |     length = len(file_names)
115 |     print(length)
116 |     for index in range(length):
117 | inputs = {}
118 | # cur_img_path = data_path_ori + str(file_names[index]['video_num']) + '/rgb/' + file_names[index]['Camera'] +'/'+ str(file_names[index]['timestamp']) + '.png'
119 | cur_npz_path = data_path + str(file_names[index]['timestamp']) + '_' + file_names[index]['Camera'] + '.npz'
120 | pre_npz_path = data_path + str(file_names[index]['timestamp_back']) + '_' + file_names[index]['Camera'] + '.npz'
121 | next_npz_path = data_path + str(file_names[index]['timestamp_forward']) + '_' + file_names[index]['Camera'] + '.npz'
122 |
123 | file_cur = np.load(cur_npz_path)
124 | file_pre = np.load(pre_npz_path)
125 | file_next = np.load(next_npz_path)
126 |
127 |
128 | depth_cur_gt = file_cur['depth']
129 | depth_cur_gt = np.array(depth_cur_gt).astype(np.float32)
130 |
131 | inputs[("depth_gt", 0, 0)] = torch.from_numpy(depth_cur_gt)
132 |
133 | rgb_cur = file_cur['rgb']
134 | # print(rgb_cur.shape)
135 | rgb_cur_input = cv2.cvtColor(rgb_cur, cv2.COLOR_BGR2RGB)
136 | rgb_cur_input = torch.from_numpy(rgb_cur_input).permute(2, 0, 1) / 255.
137 | inputs[("color", 0, 0)] = rgb_cur_input
138 | # cv2.imwrite('img.png', rgb_cur)
139 | pose_cur = file_cur['pose']
140 | pose_cur = np.linalg.inv(pose_cur).astype('float32')
141 | inputs[("pose", 0)] = pose_cur
142 | rgb_pre = file_pre['rgb']
143 | rgb_pre_input = cv2.cvtColor(rgb_pre, cv2.COLOR_BGR2RGB)
144 | rgb_pre_input = torch.from_numpy(rgb_pre_input).permute(2, 0, 1) / 255.
145 | inputs[("color", 1, 0)] = rgb_pre_input
146 | pose_pre = file_pre['pose']
147 | pose_pre = np.linalg.inv(pose_pre).astype('float32')
148 | inputs[("pose", 1)] = pose_pre
149 | rgb_next = file_next['rgb']
150 | rgb_next_input = cv2.cvtColor(rgb_next, cv2.COLOR_BGR2RGB)
151 | rgb_next_input = torch.from_numpy(rgb_next_input).permute(2, 0, 1) / 255.
152 | inputs[("color", 2, 0)] = rgb_next_input
153 | pose_next = file_next['pose']
154 | pose_next = np.linalg.inv(pose_next).astype('float32')
155 | inputs[("pose", 2)] = pose_next
156 | K = file_cur['intrinsics']
157 |
158 | inv_K = np.linalg.inv(K)
159 |
160 | K_pool = {}
161 | ho, wo, _ = rgb_cur.shape
162 | for i in range(6):
163 | K_pool[(ho // 2**i, wo // 2**i)] = K.copy().astype('float32')
164 | K_pool[(ho // 2**i, wo // 2**i)][:2, :] /= 2**i
165 |
166 | inputs['K_pool'] = K_pool
167 |
168 | inputs[("inv_K_pool", 0)] = {}
169 | for k, v in K_pool.items():
170 | K44 = np.eye(4)
171 | K44[:3, :3] = v
172 | inputs[("inv_K_pool", 0)][k] = np.linalg.inv(K44).astype('float32')
173 |
174 | inputs[("inv_K", 0)] = torch.from_numpy(inv_K.astype('float32'))
175 |
176 | inputs[("K", 0)] = torch.from_numpy(K.astype('float32'))
177 |
178 | for i in range(3):
179 | inputs[("proj", i)] = {}
180 | for k, v in inputs['K_pool'].items():
181 | K44 = np.eye(4)
182 | K44[:3, :3] = v
183 | inputs[("proj",
184 | i)][k] = torch.from_numpy(np.matmul(K44, inputs[("pose",
185 | i)]).astype('float32'))
186 | to_gpu(inputs)
187 | h, w, _ = rgb_cur.shape
188 | imgs, proj_mats, pose_mats = [], [], []
189 | for i in range(3):
190 | imgs.append(inputs[('color', i, 0)])
191 | proj_mats.append(inputs[('proj', i)])
192 | pose_mats.append(inputs[('pose', i)])
193 |
194 | depth_gt = inputs[("depth_gt", 0, 0)][None,None,:,:]
195 | img0 = imgs[0][None,:,:,:]
196 | img1 = imgs[1][None,:,:,:]
197 |
198 | proj_mats_0 = proj_mats[0][(h, w)][None,:,:]
199 |
200 | proj_mats_1 = proj_mats[1][(h, w)][None,:,:]
201 | # print(img1.shape)
202 | # print(proj_mats_0.shape)
203 | # print(depth_gt.shape)
204 | warped_img0 = homo_warping_depth(img1, proj_mats_1, proj_mats_0, depth_gt)
205 | img0_np = img0[0].cpu().detach().numpy().squeeze().transpose(1,2,0)
206 | warped_img0_np = warped_img0[0].cpu().detach().numpy().squeeze().transpose(1,2,0)
207 | depth_gt_np = depth_gt.cpu().detach().numpy().squeeze()
208 | # img0_np = (img0_np / img0_np.max() * 255).astype(np.uint8)
209 | # cv2.imwrite('img0.png', img0_np)
210 |
211 | # warped_img0_np = (warped_img0_np / warped_img0_np.max() * 255).astype(np.uint8)
212 | # cv2.imwrite('warped_img.png', warped_img0_np)
213 | # print(rgb_cur.shape)
214 | # img = mmcv.imread(cur_img_path)
215 | result = inference_detector(model, rgb_cur)
216 |
217 | index_list = [0,1,2,3,4,5,6,7]
218 |
219 | mask_all = np.zeros_like(rgb_cur[:,:,0], dtype = bool)
220 | for index_ in index_list:
221 | object_number = len(result[1][index_])
222 | for index_object in range(object_number):
223 | mask_now = result[1][index_][index_object] & (depth_gt_np > 0)
224 | # diff_now = warped_img0_np[mask_now] - img0_np[mask_now]
225 | if np.sum(mask_now.astype(float)) > 50:
226 | ssim = structural_similarity(img0_np[mask_now], warped_img0_np[mask_now], multichannel = True)
227 | else:
228 | ssim = 0.3
229 | print(ssim)
230 | if result[0][index_][index_object][4] > 0.5 and ssim < 0.75:
231 | # print(result[0][index_][index_object])
232 | mask_all = mask_all + result[1][index_][index_object]
233 |
234 | # # car_number = len(result[1][0])
235 |
236 | # # for car_index in range(car_number):
237 | # # car_mask = car_mask + result[1][0][car_index]
238 |         mask_all_vis = mask_all.astype(float)
239 | mask_all_vis = (mask_all_vis*255).astype(np.uint8)
240 |
241 | save_path = cur_npz_path.replace('.npz', '_dynamic.npz')
242 | print(save_path)
243 |
244 | np.savez(save_path, mask_all)
245 |
246 | # cv2.imwrite('seg_mask.png',mask_all_vis)
247 | # a = input('print something')
248 | # print(a)
249 |
250 |
251 | if __name__ == "__main__":
252 | main()
--------------------------------------------------------------------------------
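generate_dynamic_mask.py labels an instance as dynamic when its detection score exceeds 0.5 and the SSIM between the current frame and the previous frame warped with ground-truth depth drops below 0.75 inside the instance mask. A reduced sketch of that decision rule with random stand-in images (the thresholds and the small-mask fallback value 0.3 are the ones used above; the `multichannel` flag follows the older scikit-image API that the script relies on):

```python
import numpy as np
from skimage.metrics import structural_similarity

h, w = 120, 160
img_cur = np.random.rand(h, w, 3).astype(np.float32)       # stand-in for the current frame
img_warped = np.random.rand(h, w, 3).astype(np.float32)    # stand-in for the warped previous frame
instance_mask = np.zeros((h, w), dtype=bool)
instance_mask[40:80, 50:110] = True                         # hypothetical detected instance
det_score = 0.9                                             # hypothetical detection confidence

if instance_mask.sum() > 50:
    ssim = structural_similarity(img_cur[instance_mask],
                                 img_warped[instance_mask],
                                 multichannel=True, data_range=1.0)
else:
    ssim = 0.3        # too few valid pixels: fall back to "dynamic"

is_dynamic = det_score > 0.5 and ssim < 0.75
print(ssim, is_dynamic)
```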
/hybrid_evaluate_depth.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 | from open3d import *
3 | import os
4 | import cv2
5 | import numpy as np
6 | import torch
7 | from utils import write_ply, backproject_depth, v, npy, Thres_metrics_np
8 |
9 | cv2.setNumThreads(
10 | 0) # This speeds up evaluation 5x on our unix systems (OpenCV 3.3.1)
11 |
12 | def compute_errors_perimage(gt, pred, min_depth, max_depth):
13 | valid_mask = (gt > min_depth) & (gt < max_depth)
14 | epe = np.mean(np.abs(gt[valid_mask] - pred[valid_mask]))
15 | abs_rel = np.mean(np.abs(gt[valid_mask] - pred[valid_mask]) / gt[valid_mask])
16 | sq_rel = np.mean(((gt[valid_mask] - pred[valid_mask])**2) / gt[valid_mask])
17 |
18 | rmse = (gt - pred)**2
19 | rmse = np.sqrt(rmse.mean())
20 |
21 | rmse_log = (np.log(gt) - np.log(pred))**2
22 | rmse_log = np.sqrt(rmse_log.mean())
23 |
24 | thresh = np.maximum((gt / pred), (pred / gt))
25 | a1 = (thresh < 1.25).mean()
26 | a2 = (thresh < 1.25**2).mean()
27 | a3 = (thresh < 1.25**3).mean()
28 | log10 = np.mean(np.abs(np.log10(pred) - np.log10(gt)))
29 |
30 | return {
31 | 'abs_rel':abs_rel.item(),
32 | 'sq_rel':sq_rel.item(),
33 | 'rmse':rmse.item(),
34 | 'rmse_log':rmse_log.item(),
35 | 'a1':a1.item(),
36 | 'a2':a2.item(),
37 | 'a3':a3.item(),
38 | 'log10':log10.item(),
39 | 'valid_number':1.0,
40 | 'abs_diff':epe.item()
41 | }
42 |
43 | def compute_errors(gt, pred, disable_median_scaling, min_depth, max_depth,
44 | interval):
45 | """Computation of error metrics between predicted and ground truth depths
46 | """
47 | # if not disable_median_scaling:
48 | # ratio = np.median(gt) / np.median(pred)
49 | # pred *= ratio
50 |
51 | # pred[pred < min_depth] = min_depth
52 | # pred[pred > max_depth] = max_depth
53 | mask = np.logical_and(gt > min_depth, gt < max_depth)
54 |
55 | thresh = np.maximum((gt / pred), (pred / gt))
56 | a1 = (thresh < 1.25).mean()
57 | a2 = (thresh < 1.25**2).mean()
58 | a3 = (thresh < 1.25**3).mean()
59 |
60 | rmse = (gt - pred)**2
61 | rmse = np.sqrt(rmse.mean())
62 |
63 | rmse_log = (np.log(gt) - np.log(pred))**2
64 | rmse_log = np.sqrt(rmse_log.mean())
65 |
66 | abs_rel = np.mean(np.abs(gt[mask] - pred[mask]) / gt[mask])
67 | print('1',abs_rel)
68 | abs_rel_2 = np.sum(np.abs(gt[mask] - pred[mask]) / gt[mask])/np.sum(mask.astype(np.float32))
69 |
70 | print('2', abs_rel_2)
71 | abs_diff = np.mean(np.abs(gt - pred))
72 | # abs_diff_median = np.median(np.abs(gt - pred))
73 |
74 | sq_rel = np.mean(((gt - pred)**2) / gt)
75 | log10 = np.mean(np.abs(np.log10(pred) - np.log10(gt)))
76 | # mask = np.ones_like(pred)
77 | # thre1 = Thres_metrics_np(pred, gt, mask, 1.0, 0.2)
78 | # thre3 = Thres_metrics_np(pred, gt, mask, 1.0, 0.5)
79 | # thre5 = Thres_metrics_np(pred, gt, mask, 1.0, 1.0)
80 |
81 | result = {}
82 | result['abs_rel'] = abs_rel
83 | result['sq_rel'] = sq_rel
84 | result['rmse'] = rmse
85 | result['rmse_log'] = rmse_log
86 | result['log10'] = log10
87 | result['a1'] = a1
88 | result['a2'] = a2
89 | result['a3'] = a3
90 | result['abs_diff'] = abs_diff
91 | result['total_count'] = 1.0
92 |
93 | return result
94 |
95 | def compute_errors1(gt, pred, disable_median_scaling, min_depth, max_depth,
96 | interval):
97 | """Computation of error metrics between predicted and ground truth depths
98 | """
99 | # if not disable_median_scaling:
100 | # ratio = np.median(gt) / np.median(pred)
101 | # pred *= ratio
102 |
103 | # pred[pred < min_depth] = min_depth
104 | # pred[pred > max_depth] = max_depth
105 |
106 | thresh = np.maximum((gt / pred), (pred / gt))
107 | a1 = (thresh < 1.25).mean()
108 | a2 = (thresh < 1.25**2).mean()
109 | a3 = (thresh < 1.25**3).mean()
110 |
111 | rmse = (gt - pred)**2
112 | rmse = np.sqrt(rmse.mean())
113 |
114 | rmse_log = (np.log(gt) - np.log(pred))**2
115 | rmse_log = np.sqrt(rmse_log.mean())
116 |
117 | abs_rel = np.mean(np.abs(gt - pred) / gt)
118 | abs_diff = np.mean(np.abs(gt - pred))
119 | # abs_diff_median = np.median(np.abs(gt - pred))
120 |
121 | sq_rel = np.mean(((gt - pred)**2) / gt)
122 | log10 = np.mean(np.abs(np.log10(pred) - np.log10(gt)))
123 | # mask = np.ones_like(pred)
124 | # thre1 = Thres_metrics_np(pred, gt, mask, 1.0, 0.2)
125 | # thre3 = Thres_metrics_np(pred, gt, mask, 1.0, 0.5)
126 | # thre5 = Thres_metrics_np(pred, gt, mask, 1.0, 1.0)
127 |
128 | result = {}
129 | result['abs_rel'] = abs_rel
130 | result['sq_rel'] = sq_rel
131 | result['rmse'] = rmse
132 | result['rmse_log'] = rmse_log
133 | result['log10'] = log10
134 | result['a1'] = a1
135 | result['a2'] = a2
136 | result['a3'] = a3
137 | result['abs_diff'] = abs_diff
138 | result['total_count'] = 1.0
139 |
140 | return result
141 |
142 |
143 | # return abs_rel, sq_rel, log10, rmse, rmse_log, a1, a2, a3, abs_diff, abs_diff_median, thre1, thre3, thre5
144 |
145 |
146 | def evaluate_depth_maps(results, config, do_print=False):
147 | errors = []
148 |
149 | if not os.path.exists(config.save_dir):
150 | os.makedirs(config.save_dir)
151 |
152 | print('eval against gt depth map of size: %sx%d' %
153 | (results[0][1].shape[0], results[0][1].shape[1]))
154 | for i in range(len(results)):
155 | if i % 100 == 0:
156 | print('evaluation : %d/%d' % (i, len(results)))
157 |
158 | gt_depth = results[i][1]
159 | gt_height, gt_width = gt_depth.shape[:2]
160 | pred_depth = results[i][0]
161 | filename = results[i][2]
162 | inv_K = results[i][3]
163 | if gt_width != pred_depth.shape[1] or gt_height != pred_depth.shape[0]:
164 | pred_depth = cv2.resize(pred_depth, (gt_width, gt_height),
165 | interpolation=cv2.INTER_NEAREST)
166 | mask = np.logical_and(gt_depth > config.MIN_DEPTH,
167 | gt_depth < config.MAX_DEPTH)
168 | if not mask.sum():
169 | continue
170 |
171 | ind = np.where(mask.flatten())[0]
172 | if config.vis:
173 | cam_points = backproject_depth(pred_depth, inv_K, mask=False)
174 | cam_points_gt = backproject_depth(gt_depth, inv_K, mask=False)
175 | write_ply('%s/%s_pred.ply' % (config.save_dir, filename),
176 | cam_points[ind])
177 | write_ply('%s/%s_gt.ply' % (config.save_dir, filename),
178 | cam_points_gt[ind])
179 |
180 | dataset = filename.split('_')[0]
181 | interval = (935 - 425) / (128 - 1) # Interval value used by MVSNet
182 | errors.append(
183 | (compute_errors(gt_depth[mask], pred_depth[mask],
184 | config.disable_median_scaling, config.MIN_DEPTH,
185 | config.MAX_DEPTH, interval), dataset, filename))
186 |
187 | with open('%s/errors.txt' % (config.save_dir), 'w') as f:
188 | for x, _, fID in errors:
189 | tex = fID + ' ' + ' '.join(['%.3f' % y for y in x])
190 | f.write(tex + '\n')
191 |
192 | np.save('%s/error.npy' % config.save_dir, errors)
193 | results = {}
194 | all_errors = [x[0] for x in errors]
195 |
196 |     print(f"total examples evaluated: {len(all_errors)}")
197 | all_mean_errors = np.array(all_errors).mean(0)
198 | if do_print:
199 | print("\n all")
200 | print("\n " +
201 | ("{:>8} | " *
202 |                13).format("abs_rel", "sq_rel", "log10", "rmse", "rmse_log", "a1", "a2",
203 |                           "a3", "abs_diff", "abs_diff_median", "thre1", "thre3", "thre5"))
204 | print(("&{: 8.3f} " * 13).format(*all_mean_errors.tolist()) + "\\\\")
205 |
206 | error_names = [
207 | "abs_rel", "sq_rel", "log10", "rmse", "rmse_log", "a1", "a2", "a3",
208 | "abs_diff", "abs_diff_median", "thre1", "thre3", "thre5"
209 | ]
210 | results['depth'] = {'error_names': error_names, 'errors': all_mean_errors}
211 |
212 | errors_per_dataset = {}
213 | for x in errors:
214 | key = x[1]
215 | if key not in errors_per_dataset:
216 | errors_per_dataset[key] = [x[0]]
217 | else:
218 | errors_per_dataset[key].append(x[0])
219 | if config.print_per_dataset_stats:
220 | for key in errors_per_dataset.keys():
221 | errors_ = errors_per_dataset[key]
222 | mean_errors = np.array(errors_).mean(0)
223 |
224 | print("\n dataset %s: %d" % (key, len(errors_)))
225 | print("\n " +
226 | ("{:>8} | " *
227 | 13).format("abs_rel", "sq_rel", "log10", "rmse", "rmse_log",
228 | "a1", "a2", "a3", "abs_diff", "abs_diff_median",
229 | "thre1", "thre3", "thre5"))
230 | print(("&{: 8.3f} " * 13).format(*mean_errors.tolist()) + "\\\\")
231 |
232 | print("\n-> Done!")
233 | return results
234 |
--------------------------------------------------------------------------------
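A quick sanity check of `compute_errors_perimage` on hand-picked toy arrays (assuming the repository's dependencies, including open3d, are installed so the module imports): with gt = [2, 4, 8] and pred = [1, 5, 8], abs_rel = mean(1/2, 1/4, 0) = 0.25.

```python
import numpy as np
from hybrid_evaluate_depth import compute_errors_perimage

gt = np.array([2.0, 4.0, 8.0])
pred = np.array([1.0, 5.0, 8.0])

errors = compute_errors_perimage(gt, pred, min_depth=0.5, max_depth=100.0)
print(errors['abs_rel'])   # 0.25
print(errors['a1'])        # fraction of pixels with max(gt/pred, pred/gt) < 1.25
```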
/networks/__init__.py:
--------------------------------------------------------------------------------
1 | # from .mvs2d import MVS2D
2 | from .AFNet import MVS2D
3 | # from .AFNet_main import MVS2D
4 | # from .AFNet_mobile import MVS2D
5 | # from .AFNet_efficient import MVS2D
--------------------------------------------------------------------------------
/networks/mobilenet.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch.hub import load_state_dict_from_url
3 |
4 | model_urls = {
5 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
6 | }
7 |
8 |
9 | def _make_divisible(v, divisor, min_value=None):
10 | """
11 | This function is taken from the original tf repo.
12 | It ensures that all layers have a channel number that is divisible by 8
13 | It can be seen here:
14 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
15 | :param v:
16 | :param divisor:
17 | :param min_value:
18 | :return:
19 | """
20 | if min_value is None:
21 | min_value = divisor
22 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
23 | # Make sure that round down does not go down by more than 10%.
24 | if new_v < 0.9 * v:
25 | new_v += divisor
26 | return new_v
27 |
28 | class ConvBNReLU(nn.Sequential):
29 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, norm_layer=None):
30 | padding = (kernel_size - 1) // 2
31 | if norm_layer is None:
32 | norm_layer = nn.BatchNorm2d
33 | super(ConvBNReLU, self).__init__(
34 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
35 | norm_layer(out_planes),
36 | nn.ReLU6(inplace=True)
37 | )
38 |
39 |
40 | class InvertedResidual(nn.Module):
41 | def __init__(self, inp, oup, stride, expand_ratio, norm_layer=None):
42 | super(InvertedResidual, self).__init__()
43 | self.stride = stride
44 | assert stride in [1, 2]
45 |
46 | if norm_layer is None:
47 | norm_layer = nn.BatchNorm2d
48 |
49 | hidden_dim = int(round(inp * expand_ratio))
50 | self.use_res_connect = self.stride == 1 and inp == oup
51 |
52 | layers = []
53 | if expand_ratio != 1:
54 | # pw
55 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
56 | layers.extend([
57 | # dw
58 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer),
59 | # pw-linear
60 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
61 | norm_layer(oup),
62 | ])
63 | self.conv = nn.Sequential(*layers)
64 |
65 | def forward(self, x):
66 | if self.use_res_connect:
67 | return x + self.conv(x)
68 | else:
69 | return self.conv(x)
70 |
71 |
72 | class MobileNetV2(nn.Module):
73 | def __init__(self,
74 | num_classes=1000,
75 | width_mult=1.0,
76 | inverted_residual_setting=None,
77 | round_nearest=8,
78 | block=None,
79 | norm_layer=None):
80 | """
81 | MobileNet V2 main class
82 | Args:
83 | num_classes (int): Number of classes
84 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
85 | inverted_residual_setting: Network structure
86 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number
87 | Set to 1 to turn off rounding
88 | block: Module specifying inverted residual building block for mobilenet
89 | norm_layer: Module specifying the normalization layer to use
90 | """
91 | super(MobileNetV2, self).__init__()
92 |
93 | if block is None:
94 | block = InvertedResidual
95 |
96 | if norm_layer is None:
97 | norm_layer = nn.BatchNorm2d
98 |
99 | input_channel = 32
100 | last_channel = 1280
101 |
102 | if inverted_residual_setting is None:
103 | inverted_residual_setting = [
104 | # t, c, n, s
105 | [1, 16, 1, 1],
106 | [6, 24, 2, 2],
107 | [6, 32, 3, 2],
108 | [6, 64, 4, 2],
109 | [6, 96, 3, 1],
110 | [6, 160, 3, 2],
111 | [6, 320, 1, 1],
112 | ]
113 |
114 | # only check the first element, assuming user knows t,c,n,s are required
115 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
116 | raise ValueError("inverted_residual_setting should be non-empty "
117 | "or a 4-element list, got {}".format(inverted_residual_setting))
118 |
119 | # building first layer
120 | input_channel = _make_divisible(input_channel * width_mult, round_nearest)
121 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
122 | features = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
123 | # building inverted residual blocks
124 | for t, c, n, s in inverted_residual_setting:
125 | output_channel = _make_divisible(c * width_mult, round_nearest)
126 | for i in range(n):
127 | stride = s if i == 0 else 1
128 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
129 | input_channel = output_channel
130 | # building last several layers
131 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer))
132 | # make it nn.Sequential
133 | self.features = nn.Sequential(*features)
134 |
135 | # building classifier
136 | self.classifier = nn.Sequential(
137 | nn.Dropout(0.2),
138 | nn.Linear(self.last_channel, num_classes),
139 | )
140 |
141 | # weight initialization
142 | for m in self.modules():
143 | if isinstance(m, nn.Conv2d):
144 | nn.init.kaiming_normal_(m.weight, mode='fan_out')
145 | if m.bias is not None:
146 | nn.init.zeros_(m.bias)
147 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
148 | nn.init.ones_(m.weight)
149 | nn.init.zeros_(m.bias)
150 | elif isinstance(m, nn.Linear):
151 | nn.init.normal_(m.weight, 0, 0.01)
152 | nn.init.zeros_(m.bias)
153 |
154 | def _forward_impl(self, x):
155 | # This exists since TorchScript doesn't support inheritance, so the superclass method
156 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass
157 | x = self.features(x)
158 | # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0]
159 | # x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1)
160 | # x = self.classifier(x)
161 | return x
162 |
163 | def forward(self, x):
164 | return self._forward_impl(x)
165 |
166 |
167 | def mobilenet_v2(pretrained=False, progress=True, **kwargs):
168 | """
169 | Constructs a MobileNetV2 architecture from
170 |     `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
171 | Args:
172 | pretrained (bool): If True, returns a model pre-trained on ImageNet
173 | progress (bool): If True, displays a progress bar of the download to stderr
174 | """
175 | model = MobileNetV2(**kwargs)
176 | if pretrained:
177 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
178 | progress=progress)
179 | model.load_state_dict(state_dict)
180 | return model
--------------------------------------------------------------------------------
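Since `MobileNetV2.forward` above returns the output of `self.features` directly (the pooling and classifier calls are commented out), the network behaves as a stride-32 feature extractor. A small shape check with random input and no pretrained weights:

```python
import torch
from networks.mobilenet import mobilenet_v2

model = mobilenet_v2(pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    feat = model(x)
print(feat.shape)   # torch.Size([1, 1280, 7, 7]): overall stride 32, last_channel = 1280
```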
/options.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 | import os
3 | import argparse
4 | from pyhocon import ConfigFactory
5 |
6 |
7 | class MVS2DOptions:
8 | def __init__(self):
9 | self.parser = argparse.ArgumentParser(description="MVS2D options")
10 |
11 | # PATH OPTIONS
12 | self.parser.add_argument("--log_dir", type=str, help="", default=None)
13 | self.parser.add_argument("--model_name",
14 | type=str,
15 | help="",
16 | default=None)
17 | self.parser.add_argument("--num_epochs",
18 | type=int,
19 | default=None,
20 | help="")
21 | self.parser.add_argument("--overwrite",
22 | type=int,
23 | default=None,
24 | help="")
25 | self.parser.add_argument('--note', type=str, help="", default=None)
26 | self.parser.add_argument("--DECAY_STEP_LIST",
27 | nargs="+",
28 | type=int,
29 | help="",
30 | default=None)
31 | # DATA OPTIONS
32 | self.parser.add_argument(
33 | "--mode",
34 | type=str,
35 | default=None,
36 | choices=["train", "test", "train+test", "full_test", "recon"])
37 | self.parser.add_argument("--num_workers",
38 | type=int,
39 | help="",
40 | default=None)
41 | self.parser.add_argument("--use_test", type=int, default=None, help="")
42 | self.parser.add_argument("--robust", type=int, default=None, help="")
43 | self.parser.add_argument("--perturb_pose",
44 | type=int,
45 | default=None,
46 | help="")
47 | self.parser.add_argument('--num_frame',
48 | type=int,
49 | help="",
50 | default=None)
51 | self.parser.add_argument('--fullsize_eval',
52 | type=int,
53 | help="",
54 | default=None)
55 | self.parser.add_argument('--filter', nargs="+", type=str, default=None)
56 | # MODEL OPTIONS
57 | self.parser.add_argument("--load_weights_folder",
58 | type=str,
59 | help="",
60 | default=None)
61 |
62 | # TRAINING OPTIONS
63 |
64 | # OPTIMIZATION OPTIONS
65 |
66 | # MULTI-GPU OPTIONS
67 | self.parser.add_argument("--world_size", type=int, default=1, help="")
68 | self.parser.add_argument("--multiprocessing_distributed",
69 | type=int,
70 | default=None,
71 | help="")
72 | self.parser.add_argument('--rank', type=int, help="", default=0)
73 | self.parser.add_argument('--gpu', type=int, help="", default=None)
74 | self.parser.add_argument('--local_rank', type=int, help="", default=0)
75 | self.parser.add_argument('--tcp_port', type=int, default=None, help="")
76 |
77 | # OTHERS
78 | self.parser.add_argument('--save_prediction',
79 | type=int,
80 | help="",
81 | default=None)
82 | self.parser.add_argument("--debug", help="", action="store_true")
83 | # self.parser.add_argument('--cfg', type=str, default="./configs/DDAD_kitti.conf")
84 | self.parser.add_argument('--cfg', type=str)
85 |
86 | self.parser.add_argument('--epoch_size', type=int, default=None)
87 | self.parser.add_argument('--val_epoch_size', type=int, default=None)
88 |
89 | def parse(self):
90 | self.options = self.parser.parse_args()
91 | cfg = ConfigFactory.parse_file(self.options.cfg)
92 | for k in cfg.keys():
93 | if k not in self.options:
94 | setattr(self.options, k, cfg[k])
95 | else:
96 | if getattr(self.options, k) is None:
97 | setattr(self.options, k, cfg[k])
98 |
99 | return self.options
100 |
101 |
102 | class EvalCfg(object):
103 | def __init__(self,
104 | save_dir,
105 | min_depth=1e-3,
106 | max_depth=10.0,
107 | vis=False,
108 | disable_median_scaling=True,
109 | eigen_crop=True,
110 | print_per_dataset_stats=False,
111 | garg_crop=False):
112 | self.save_dir = save_dir
113 | self.vis = vis
114 | self.MIN_DEPTH = min_depth
115 | self.MAX_DEPTH = max_depth
116 | self.eigen_crop = eigen_crop
117 | self.garg_crop = garg_crop
118 | self.disable_median_scaling = disable_median_scaling
119 | self.print_per_dataset_stats = print_per_dataset_stats
120 |
--------------------------------------------------------------------------------
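`MVS2DOptions.parse()` loads the HOCON file given by `--cfg` and copies each config key onto the argparse namespace only when the corresponding option is missing or was left at its `None` default, so command-line flags take precedence over the config. A minimal sketch of that behaviour with a hypothetical two-key config (`height` and `num_epochs` are real option/config names; the values are illustrative):

```python
from pyhocon import ConfigFactory

# Hypothetical config contents; in the repo these live in configs/*.conf.
cfg = ConfigFactory.parse_string("height = 352\nnum_epochs = 60")

class Opts:                    # stand-in for the parsed argparse namespace
    height = None              # not set on the CLI -> filled from the config
    num_epochs = 80            # set on the CLI     -> CLI value is kept

opts = Opts()
for k in cfg.keys():
    if getattr(opts, k, None) is None:
        setattr(opts, k, cfg[k])

print(opts.height, opts.num_epochs)   # 352 80
```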
/requirements.txt:
--------------------------------------------------------------------------------
1 | imageio==2.9.0
2 | ipdb==0.12.3
3 | joblib==1.0.1
4 | lz4==3.1.3
5 | matplotlib==3.1.3
6 | numpy==1.20.3
7 | open3d==0.10.0.0
8 | opencv_contrib_python==3.4.2.17
9 | path.py==12.5.0
10 | Pillow==8.4.0
11 | pyhocon==0.3.58
12 | pykdtree==1.3.4
13 | requests==2.26.0
14 | scipy==1.7.1
15 | tensorboardX==2.4.1
16 |
--------------------------------------------------------------------------------
/scripts/test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | BASE_DIR='./'
3 | EXP_DIR="${BASE_DIR}/experiments/release/ScanNet/exp0"
4 | cfg=./configs/scannet/release.conf
5 |
6 | ## test
7 | CUDA_VISIBLE_DEVICES=0 python train.py --model_name=config0_test --mode=test --cfg $cfg --load_weights_folder=./pretrained_model/scannet/MVS2D --use_test=1 --fullsize_eval=1
8 |
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # cfg=./configs/DDAD.conf
3 | cfg=./configs/kitti.conf
4 |
5 |
6 | CUDA_VISIBLE_DEVICES=0,1,2 OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=3 train_kitti.py --num_epochs=60 --DECAY_STEP_LIST 30 40 --cfg $cfg --load_weights_folder=./pretrained_model/kitti/ --fullsize_eval=1 --use_test=0
7 |
--------------------------------------------------------------------------------
/train_af.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 | from open3d import *
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 | # from tensorboardX import SummaryWriter
7 | from torch.utils.tensorboard import SummaryWriter
8 | import json
9 | from utils import *
10 | import networks
11 | import os
12 | import glob
13 | import random
14 | import torch.optim as optim
15 | from options import MVS2DOptions, EvalCfg
16 | from trainer_base_af import BaseTrainer
17 | from hybrid_evaluate_depth import evaluate_depth_maps, compute_errors,compute_errors1,compute_errors_perimage
18 | from dtu_pyeval import dtu_pyeval
19 | import pprint
20 | import torch.distributed as dist
21 |
22 |
23 | class Trainer(BaseTrainer):
24 | def __init__(self, options):
25 | super(Trainer, self).__init__(options)
26 |
27 | def build_model(self):
28 | self.parameters_to_train = []
29 | self.model = networks.MVS2D(opt=self.opt).cuda()
30 | self.parameters_to_train += list(self.model.parameters())
31 | parameters_count(self.model, 'MVS2D')
32 |
33 | # def build_optimizer(self):
34 | # if self.opt.optimizer.lower() == 'adam':
35 | # self.model_optimizer = optim.Adam(
36 | # self.model.parameters(),
37 | # lr=self.opt.LR,
38 | # weight_decay=self.opt.WEIGHT_DECAY)
39 | # elif self.opt.optimizer.lower() == 'sgd':
40 | # self.model_optimizer = optim.SGD(
41 | # self.model.parameters(),
42 | # lr=self.opt.LR,
43 | # weight_decay=self.opt.WEIGHT_DECAY)
44 |
45 | # def val_epoch(self):
46 | # print("Validation")
47 | # writer = self.writers['val']
48 | # self.set_eval()
49 | # results_depth = []
50 | # val_loss = []
51 | # config = EvalCfg(
52 | # eigen_crop=False,
53 | # garg_crop=False,
54 | # min_depth=self.opt.EVAL_MIN_DEPTH,
55 | # max_depth=self.opt.EVAL_MAX_DEPTH,
56 | # vis=self.epoch % 10 == 0 and self.opt.eval_vis,
57 | # disable_median_scaling=self.opt.disable_median_scaling,
58 | # print_per_dataset_stats=self.opt.dataset == 'DeMoN',
59 | # save_dir=os.path.join(self.log_path, 'eval_%03d' % self.epoch))
60 | # if not os.path.exists(config.save_dir):
61 | # os.makedirs(config.save_dir)
62 | # print('evaluation results save to folder %s' % config.save_dir)
63 | # times = []
64 | # val_stats = defaultdict(list)
65 | # dict_pred = {}
66 | # dict_pred_2 = {}
67 | # total_result_count = {}
68 | # total_result_count_2 = {}
69 |
70 | # with torch.no_grad():
71 | # for batch_idx, inputs in enumerate(self.val_loader):
72 | # if self.opt.val_epoch_size != -1 and batch_idx >= self.opt.val_epoch_size:
73 | # break
74 | # if batch_idx % 100 == 0:
75 | # print(batch_idx, len(self.val_loader))
76 | # # filenames = inputs["filenames"]
77 | # losses, outputs = self.process_batch(inputs, 'val')
78 | # # b = len(inputs["filenames"])
79 |
80 | # s = 0
81 | # pred_depth = npy(outputs[('depth_pred', s)])
82 | # pred_depth_2 = npy(outputs[('depth_pred_2', s)])
83 | # depth_gt = npy(inputs[('depth_gt', 0, s)])
84 | # mask = np.logical_and(depth_gt > config.MIN_DEPTH,
85 | # depth_gt < config.MAX_DEPTH)
86 | # interval = (935 - 425) / (128 - 1) # Interval value used by MVSNet
87 | # # dict_pred_temp = compute_errors(depth_gt[mask], pred_depth[mask], config.disable_median_scaling, config.MIN_DEPTH, config.MAX_DEPTH, interval)
88 | # dict_pred_temp = compute_errors(depth_gt, pred_depth, config.disable_median_scaling, config.MIN_DEPTH, config.MAX_DEPTH, interval)
89 |
90 | # dict_pred_temp_2 = compute_errors1(depth_gt[mask], pred_depth_2[mask], config.disable_median_scaling, config.MIN_DEPTH, config.MAX_DEPTH, interval)
91 |
92 | # for k, v in dict_pred_temp.items():
93 | # # print(k,v)
94 | # if k in dict_pred:
95 | # dict_pred[k] = dict_pred[k] + v
96 | # # dict_pred['total_count'] = dict_pred['total_count'] + 1.0
97 | # else:
98 | # dict_pred[k] = v
99 | # # dict_pred['total_count'] = 1.0
100 |
101 | # for k, v in dict_pred_temp_2.items():
102 | # k = k + '_2'
103 | # if k in dict_pred_2:
104 | # dict_pred_2[k] = dict_pred_2[k] + v
105 | # # dict_pred_2['total_count_2'] = dict_pred_2['total_count_2'] + 1.0
106 | # else:
107 | # dict_pred_2[k] = v
108 | # # dict_pred_2['total_count_2'] = 1.0
109 | # if batch_idx % 80 == 0:
110 | # writer.add_image('image0', inputs[("color", 0, 0)][0], global_step=self.step, walltime=None, dataformats='CHW')
111 | # writer.add_image('image1', inputs[("color", 1, 0)][0], global_step=self.step, walltime=None, dataformats='CHW')
112 | # writer.add_image('image2', inputs[("color", 2, 0)][0], global_step=self.step, walltime=None, dataformats='CHW')
113 | # depth_gt = gray_2_colormap_np(inputs[("depth_gt", 0, 0)][0][0])
114 | # writer.add_image('depth_gt', depth_gt, global_step=self.step, walltime=None, dataformats='HWC')
115 | # depth_pred = gray_2_colormap_np(outputs[('depth_pred', 0)][0][0])
116 | # writer.add_image('depth_pred', depth_pred, global_step=self.step, walltime=None, dataformats='HWC')
117 |
118 | # # print('total_count', dict_pred['total_count'])
119 | # # print('total_count_2', dict_pred_2['total_count_2'])
120 | # # print('abs_rel', dict_pred['abs_rel'])
121 |
122 | # for k in dict_pred.keys():
123 | # total_result_count[k] = dict_pred['total_count']
124 |
125 | # for k in dict_pred_2.keys():
126 | # total_result_count_2[k] = dict_pred_2['total_count_2']
127 |
128 |
129 | # for k in dict_pred.keys():
130 | # this_tensor = torch.tensor([dict_pred[k], total_result_count[k]]).to(self.device)
131 | # this_list = [this_tensor]
132 | # torch.distributed.all_reduce_multigpu(this_list)
133 | # this_tensor = this_list[0].detach().cpu().numpy()
134 | # reduce_sum = this_tensor[0].item()
135 | # reduce_count = this_tensor[1].item()
136 | # reduce_mean = reduce_sum / reduce_count
137 | # if self.is_master:
138 | # writer.add_scalar(k, reduce_mean, self.step)
139 |
140 | # for k in dict_pred_2.keys():
141 | # this_tensor = torch.tensor([dict_pred_2[k], total_result_count_2[k]]).to(self.device)
142 | # this_list = [this_tensor]
143 | # torch.distributed.all_reduce_multigpu(this_list)
144 | # this_tensor = this_list[0].detach().cpu().numpy()
145 | # reduce_sum = this_tensor[0].item()
146 | # reduce_count = this_tensor[1].item()
147 | # reduce_mean = reduce_sum / reduce_count
148 | # if self.is_master:
149 | # writer.add_scalar(k, reduce_mean, self.step)
150 |
151 | # self.set_train()
152 |
153 | def val_epoch(self):
154 | print("Validation")
155 | writer = self.writers['val']
156 | self.set_eval()
157 | results_depth = []
158 | val_loss = []
159 | config = EvalCfg(
160 | eigen_crop=False,
161 | garg_crop=False,
162 | min_depth=self.opt.EVAL_MIN_DEPTH,
163 | max_depth=self.opt.EVAL_MAX_DEPTH,
164 | vis=self.epoch % 10 == 0 and self.opt.eval_vis,
165 | disable_median_scaling=self.opt.disable_median_scaling,
166 | print_per_dataset_stats=self.opt.dataset == 'DeMoN',
167 | save_dir=os.path.join(self.log_path, 'eval_%03d' % self.epoch))
168 | if not os.path.exists(config.save_dir) and self.is_master:
169 | os.makedirs(config.save_dir)
170 | print('evaluation results will be saved to folder %s' % config.save_dir)
171 | times = []
172 | val_stats = defaultdict(list)
173 | total_result_sum = {}
174 | total_result_count = {}
175 |
176 | with torch.no_grad():
177 | for batch_idx, inputs in enumerate(self.val_loader):
178 | if self.opt.val_epoch_size != -1 and batch_idx >= self.opt.val_epoch_size:
179 | break
180 | if batch_idx % 100 == 0:
181 | print(batch_idx, len(self.val_loader))
182 | # filenames = inputs["filenames"]
183 | losses, outputs = self.process_batch(inputs, 'val')
184 | # b = len(inputs["filenames"])
185 |
186 | s = 0
187 | pred_depth = npy(outputs[('depth_pred', s)])
188 | pred_depth_2 = npy(outputs[('depth_pred_2', s)])
189 | depth_gt = npy(inputs[('depth_gt', 0, s)])
190 | mask = np.logical_and(depth_gt > config.MIN_DEPTH,
191 | depth_gt < config.MAX_DEPTH)
192 | error_temp = compute_errors_perimage(depth_gt[mask], pred_depth[mask], config.MIN_DEPTH, config.MAX_DEPTH)
193 | error_temp_2_ = compute_errors_perimage(depth_gt[mask], pred_depth_2[mask], config.MIN_DEPTH, config.MAX_DEPTH)
194 |
195 | error_temp_2 = {}
196 | for k,v in error_temp_2_.items():
197 | new_k = k + '_2'
198 | error_temp_2[new_k] = error_temp_2_[k]
199 |
200 | error_temp_all = {}
201 | error_temp_all.update(error_temp)
202 | error_temp_all.update(error_temp_2)
203 |
204 | for k,v in error_temp_all.items():
205 | if not isinstance(v, float):
206 | v = v.item()  # convert a 0-d numpy/torch scalar to a Python float
207 | if k in total_result_sum:
208 | total_result_sum[k] = total_result_sum[k] + v
209 | else:
210 | total_result_sum[k] = v
211 | self.eval_step += 1
212 |
213 | if self.eval_step % 80 == 0 and self.is_master:
214 | writer.add_image('image0', inputs[("color", 0, 0)][0], global_step=self.eval_step, walltime=None, dataformats='CHW')
215 | writer.add_image('image1', inputs[("color", 1, 0)][0], global_step=self.eval_step, walltime=None, dataformats='CHW')
216 | writer.add_image('image2', inputs[("color", 2, 0)][0], global_step=self.eval_step, walltime=None, dataformats='CHW')
217 | depth_gt = gray_2_colormap_np(inputs[("depth_gt", 0, 0)][0][0])
218 | writer.add_image('depth_gt', depth_gt, global_step=self.eval_step, walltime=None, dataformats='HWC')
219 | depth_pred = gray_2_colormap_np(outputs[('depth_pred', 0)][0][0])
220 | writer.add_image('depth_pred', depth_pred, global_step=self.eval_step, walltime=None, dataformats='HWC')
221 | depth_pred_2 = gray_2_colormap_np(outputs[('depth_pred_2', 0)][0][0])
222 | writer.add_image('depth_pred_2', depth_pred_2, global_step=self.eval_step, walltime=None, dataformats='HWC')
223 |
224 | # print('total_count', dict_pred['total_count'])
225 | # print('total_count_2', dict_pred_2['total_count_2'])
226 | # print('abs_rel', dict_pred['abs_rel'])
227 |
228 | for k in total_result_sum.keys():
229 | total_result_count[k] = total_result_sum['valid_number']
230 |
231 |
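# Aggregate metrics across GPUs: per-metric sums and valid-pixel counts are
# all-reduced, then averaged before logging on the master process.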
232 | for k in total_result_sum.keys():
233 | this_tensor = torch.tensor([total_result_sum[k], total_result_count[k]]).to(self.device)
234 | this_list = [this_tensor]
235 | torch.distributed.all_reduce_multigpu(this_list)
236 | torch.distributed.barrier()
237 | this_tensor = this_list[0].detach().cpu().numpy()
238 | reduce_sum = this_tensor[0].item()
239 | reduce_count = this_tensor[1].item()
240 | reduce_mean = reduce_sum / reduce_count
241 | if self.is_master:
242 | writer.add_scalar(k, reduce_mean, self.eval_step)
243 |
244 | self.set_train()
245 |
246 | def process_batch(self, inputs, mode):
247 | self.to_gpu(inputs)
248 |
249 | imgs, proj_mats, pose_mats = [], [], []
250 | for i in range(inputs['num_frame'][0].item()):
251 | imgs.append(inputs[('color', i, self.opt.input_scale)])
252 | proj_mats.append(inputs[('proj', i)])
253 | pose_mats.append(inputs[('pose', i)])
254 |
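# Frame 0 is the reference view, the remaining frames are source views; the
# network also receives the camera poses and the inverse-intrinsics pyramid.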
255 | # outputs = self.model(imgs[0], imgs[1:], proj_mats[0], proj_mats[1:],
256 | # inputs[('inv_K_pool', 0)])
257 | outputs = self.model(imgs[0], imgs[1:], pose_mats[0], pose_mats[1:],
258 | inputs[('inv_K_pool', 0)])
259 | losses = self.compute_losses(inputs, outputs)
260 | return losses, outputs
261 |
262 | def compute_losses(self, inputs, outputs):
263 | losses, loss, s = {}, 0, 0
264 | depth_pred = outputs[('depth_pred', s)]
265 | depth_pred_2 = outputs[('depth_pred_2', s)]
266 | depth_gt = inputs[('depth_gt', 0, s)]
267 |
268 |
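# Supervise only pixels with valid ground-truth depth; both the fused output
# (depth_pred) and the single-view output (depth_pred_2) get an L1 loss.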
269 | valid_depth = (depth_gt > 0)
270 |
271 | # if self.opt.pred_conf:
272 | # log_conf_pred = outputs[('log_conf_pred', s)]
273 | # conf_pred = torch.exp(log_conf_pred)
274 | # min_conf = self.opt.min_conf
275 | # max_conf = self.opt.max_conf if self.opt.max_conf != -1 else None
276 | # conf_pred = conf_pred.clamp(min_conf, max_conf)
277 | # loss_depth = ((depth_pred - depth_gt).abs() / conf_pred +
278 | # log_conf_pred)[valid_depth].mean()
279 | # else:
280 | loss_depth = (depth_pred[valid_depth] - depth_gt[valid_depth]).abs().mean()
281 |
282 | loss_depth_2 = (depth_pred_2[valid_depth] - depth_gt[valid_depth]).abs().mean()
283 |
284 | losses["depth"] = loss_depth
285 | losses["depth_2"] = loss_depth_2
286 |
287 | loss += loss_depth + loss_depth_2
288 | losses["loss"] = loss
289 |
290 | return losses
291 |
292 |
293 | def run_fusion(dense_folder, out_folder, opts):
294 | cmd = f"python patchmatch_fusion.py \
295 | --dense_folder {dense_folder} \
296 | --outdir {out_folder} \
297 | --n_proc 4 \
298 | --conf_thres {opts.conf_thres} \
299 | --att_thres {opts.att_thres} \
300 | --use_conf_thres {opts.pred_conf} \
301 | --geo_depth_thres {opts.geo_depth_thres} \
302 | --geo_pixel_thres {opts.geo_pixel_thres} \
303 | --num_consistent {opts.num_consistent} \
304 | "
305 |
306 | os.system(cmd)
307 |
308 |
309 | if __name__ == "__main__":
310 | options = MVS2DOptions()
311 | opts = options.parse()
312 |
313 | set_random_seed(666)
314 |
315 | if torch.cuda.device_count() > 1 and not opts.multiprocessing_distributed:
316 | raise Exception(
317 | "Detected more than 1 GPU. Please set multiprocessing_distributed=1 or set CUDA_VISIBLE_DEVICES"
318 | )
319 |
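# Distributed setup: torch.distributed.launch starts one process per GPU and
# passes --local_rank; the NCCL group is initialised with a fixed world size of 3.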
320 | opts.distributed = opts.world_size > 1 or opts.multiprocessing_distributed
321 | if opts.multiprocessing_distributed:
322 | # total_gpus, opts.rank = init_dist_pytorch(opts.tcp_port,
323 | # opts.local_rank,
324 | # backend='nccl')
325 | print('opts.local_rank', opts.local_rank)
326 | torch.cuda.set_device(opts.local_rank)
327 | dist.init_process_group("nccl", rank=opts.local_rank, world_size=3)
328 | opts.ngpus_per_node = 3
329 | opts.gpu = opts.local_rank
330 | print("Use GPU: {}/{} for training".format(opts.gpu,
331 | opts.ngpus_per_node))
332 | else:
333 | opts.gpu = 0
334 |
335 | if opts.mode == 'train':
336 | trainer = Trainer(opts)
337 | trainer.train()
338 |
339 | elif opts.mode == 'test':
340 | trainer = Trainer(opts)
341 | trainer.val()
342 |
343 | elif opts.mode == 'full_test':
344 | ## save depth prediction
345 | opts.mode = 'test'
346 | trainer = Trainer(opts)
347 | trainer.val()
348 |
349 | ## fuse dense prediction into final point cloud
350 | dense_folder = f"{opts.log_dir}/{opts.model_name}/eval_000/prediction"
351 | out_folder = f"{opts.log_dir}/{opts.model_name}/recon"
352 | run_fusion(dense_folder, out_folder, opts)
353 |
354 | ## eval point cloud
355 | MeanData, MeanStl, MeanAvg = dtu_pyeval(
356 | f"{out_folder}",
357 | gt_dir='./data/SampleSet/MVS Data/',
358 | voxel_down_sample=False,
359 | fn=f"{out_folder}/result.txt")
360 |
--------------------------------------------------------------------------------
/train_kitti.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 | from open3d import *
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 | # from tensorboardX import SummaryWriter
7 | from torch.utils.tensorboard import SummaryWriter
8 | import json
9 | from utils import *
10 | import networks
11 | import os
12 | import glob
13 | import random
14 | import torch.optim as optim
15 | from options import MVS2DOptions, EvalCfg
16 | from trainer_base_kitti import BaseTrainer
17 | from hybrid_evaluate_depth import evaluate_depth_maps, compute_errors,compute_errors1,compute_errors_perimage
18 | from dtu_pyeval import dtu_pyeval
19 | import pprint
20 | import torch.distributed as dist
21 |
22 |
23 | class Trainer(BaseTrainer):
24 | def __init__(self, options):
25 | super(Trainer, self).__init__(options)
26 |
27 | def build_model(self):
28 | self.parameters_to_train = []
29 | self.model = networks.MVS2D(opt=self.opt).cuda()
30 | self.parameters_to_train += list(self.model.parameters())
31 | parameters_count(self.model, 'MVS2D')
32 |
33 | # def build_optimizer(self):
34 | # if self.opt.optimizer.lower() == 'adam':
35 | # self.model_optimizer = optim.Adam(
36 | # self.model.parameters(),
37 | # lr=self.opt.LR,
38 | # weight_decay=self.opt.WEIGHT_DECAY)
39 | # elif self.opt.optimizer.lower() == 'sgd':
40 | # self.model_optimizer = optim.SGD(
41 | # self.model.parameters(),
42 | # lr=self.opt.LR,
43 | # weight_decay=self.opt.WEIGHT_DECAY)
44 |
45 | # def val_epoch(self):
46 | # print("Validation")
47 | # writer = self.writers['val']
48 | # self.set_eval()
49 | # results_depth = []
50 | # val_loss = []
51 | # config = EvalCfg(
52 | # eigen_crop=False,
53 | # garg_crop=False,
54 | # min_depth=self.opt.EVAL_MIN_DEPTH,
55 | # max_depth=self.opt.EVAL_MAX_DEPTH,
56 | # vis=self.epoch % 10 == 0 and self.opt.eval_vis,
57 | # disable_median_scaling=self.opt.disable_median_scaling,
58 | # print_per_dataset_stats=self.opt.dataset == 'DeMoN',
59 | # save_dir=os.path.join(self.log_path, 'eval_%03d' % self.epoch))
60 | # if not os.path.exists(config.save_dir):
61 | # os.makedirs(config.save_dir)
62 | # print('evaluation results save to folder %s' % config.save_dir)
63 | # times = []
64 | # val_stats = defaultdict(list)
65 | # dict_pred = {}
66 | # dict_pred_2 = {}
67 | # total_result_count = {}
68 | # total_result_count_2 = {}
69 |
70 | # with torch.no_grad():
71 | # for batch_idx, inputs in enumerate(self.val_loader):
72 | # if self.opt.val_epoch_size != -1 and batch_idx >= self.opt.val_epoch_size:
73 | # break
74 | # if batch_idx % 100 == 0:
75 | # print(batch_idx, len(self.val_loader))
76 | # # filenames = inputs["filenames"]
77 | # losses, outputs = self.process_batch(inputs, 'val')
78 | # # b = len(inputs["filenames"])
79 |
80 | # s = 0
81 | # pred_depth = npy(outputs[('depth_pred', s)])
82 | # pred_depth_2 = npy(outputs[('depth_pred_2', s)])
83 | # depth_gt = npy(inputs[('depth_gt', 0, s)])
84 | # mask = np.logical_and(depth_gt > config.MIN_DEPTH,
85 | # depth_gt < config.MAX_DEPTH)
86 | # interval = (935 - 425) / (128 - 1) # Interval value used by MVSNet
87 | # # dict_pred_temp = compute_errors(depth_gt[mask], pred_depth[mask], config.disable_median_scaling, config.MIN_DEPTH, config.MAX_DEPTH, interval)
88 | # dict_pred_temp = compute_errors(depth_gt, pred_depth, config.disable_median_scaling, config.MIN_DEPTH, config.MAX_DEPTH, interval)
89 |
90 | # dict_pred_temp_2 = compute_errors1(depth_gt[mask], pred_depth_2[mask], config.disable_median_scaling, config.MIN_DEPTH, config.MAX_DEPTH, interval)
91 |
92 | # for k, v in dict_pred_temp.items():
93 | # # print(k,v)
94 | # if k in dict_pred:
95 | # dict_pred[k] = dict_pred[k] + v
96 | # # dict_pred['total_count'] = dict_pred['total_count'] + 1.0
97 | # else:
98 | # dict_pred[k] = v
99 | # # dict_pred['total_count'] = 1.0
100 |
101 | # for k, v in dict_pred_temp_2.items():
102 | # k = k + '_2'
103 | # if k in dict_pred_2:
104 | # dict_pred_2[k] = dict_pred_2[k] + v
105 | # # dict_pred_2['total_count_2'] = dict_pred_2['total_count_2'] + 1.0
106 | # else:
107 | # dict_pred_2[k] = v
108 | # # dict_pred_2['total_count_2'] = 1.0
109 | # if batch_idx % 80 == 0:
110 | # writer.add_image('image0', inputs[("color", 0, 0)][0], global_step=self.step, walltime=None, dataformats='CHW')
111 | # writer.add_image('image1', inputs[("color", 1, 0)][0], global_step=self.step, walltime=None, dataformats='CHW')
112 | # writer.add_image('image2', inputs[("color", 2, 0)][0], global_step=self.step, walltime=None, dataformats='CHW')
113 | # depth_gt = gray_2_colormap_np(inputs[("depth_gt", 0, 0)][0][0])
114 | # writer.add_image('depth_gt', depth_gt, global_step=self.step, walltime=None, dataformats='HWC')
115 | # depth_pred = gray_2_colormap_np(outputs[('depth_pred', 0)][0][0])
116 | # writer.add_image('depth_pred', depth_pred, global_step=self.step, walltime=None, dataformats='HWC')
117 |
118 | # # print('total_count', dict_pred['total_count'])
119 | # # print('total_count_2', dict_pred_2['total_count_2'])
120 | # # print('abs_rel', dict_pred['abs_rel'])
121 |
122 | # for k in dict_pred.keys():
123 | # total_result_count[k] = dict_pred['total_count']
124 |
125 | # for k in dict_pred_2.keys():
126 | # total_result_count_2[k] = dict_pred_2['total_count_2']
127 |
128 |
129 | # for k in dict_pred.keys():
130 | # this_tensor = torch.tensor([dict_pred[k], total_result_count[k]]).to(self.device)
131 | # this_list = [this_tensor]
132 | # torch.distributed.all_reduce_multigpu(this_list)
133 | # this_tensor = this_list[0].detach().cpu().numpy()
134 | # reduce_sum = this_tensor[0].item()
135 | # reduce_count = this_tensor[1].item()
136 | # reduce_mean = reduce_sum / reduce_count
137 | # if self.is_master:
138 | # writer.add_scalar(k, reduce_mean, self.step)
139 |
140 | # for k in dict_pred_2.keys():
141 | # this_tensor = torch.tensor([dict_pred_2[k], total_result_count_2[k]]).to(self.device)
142 | # this_list = [this_tensor]
143 | # torch.distributed.all_reduce_multigpu(this_list)
144 | # this_tensor = this_list[0].detach().cpu().numpy()
145 | # reduce_sum = this_tensor[0].item()
146 | # reduce_count = this_tensor[1].item()
147 | # reduce_mean = reduce_sum / reduce_count
148 | # if self.is_master:
149 | # writer.add_scalar(k, reduce_mean, self.step)
150 |
151 | # self.set_train()
152 |
153 | def val_epoch(self):
154 | print("Validation")
155 | writer = self.writers['val']
156 | self.set_eval()
157 | results_depth = []
158 | val_loss = []
159 | config = EvalCfg(
160 | eigen_crop=False,
161 | garg_crop=False,
162 | min_depth=self.opt.EVAL_MIN_DEPTH,
163 | max_depth=self.opt.EVAL_MAX_DEPTH,
164 | vis=self.epoch % 10 == 0 and self.opt.eval_vis,
165 | disable_median_scaling=self.opt.disable_median_scaling,
166 | print_per_dataset_stats=self.opt.dataset == 'DeMoN',
167 | save_dir=os.path.join(self.log_path, 'eval_%03d' % self.epoch))
168 | if not os.path.exists(config.save_dir) and self.is_master:
169 | os.makedirs(config.save_dir)
170 | print('evaluation results will be saved to folder %s' % config.save_dir)
171 | times = []
172 | val_stats = defaultdict(list)
173 | total_result_sum = {}
174 | total_result_count = {}
175 |
176 | with torch.no_grad():
177 | for batch_idx, inputs in enumerate(self.val_loader):
178 | if self.opt.val_epoch_size != -1 and batch_idx >= self.opt.val_epoch_size:
179 | break
180 | if batch_idx % 100 == 0:
181 | print(batch_idx, len(self.val_loader))
182 | # filenames = inputs["filenames"]
183 | losses, outputs = self.process_batch(inputs, 'val')
184 | # b = len(inputs["filenames"])
185 |
186 | s = 0
187 | pred_depth = npy(outputs[('depth_pred', s)])
188 | pred_depth_2 = npy(outputs[('depth_pred_2', s)])
189 | depth_gt = npy(inputs[('depth_gt', 0, s)])
190 | mask = np.logical_and(depth_gt > config.MIN_DEPTH,
191 | depth_gt < config.MAX_DEPTH)
192 | error_temp = compute_errors_perimage(depth_gt[mask], pred_depth[mask], config.MIN_DEPTH, config.MAX_DEPTH)
193 | error_temp_2_ = compute_errors_perimage(depth_gt[mask], pred_depth_2[mask], config.MIN_DEPTH, config.MAX_DEPTH)
194 |
195 | error_temp_2 = {}
196 | for k,v in error_temp_2_.items():
197 | new_k = k + '_2'
198 | error_temp_2[new_k] = error_temp_2_[k]
199 |
200 | error_temp_all = {}
201 | error_temp_all.update(error_temp)
202 | error_temp_all.update(error_temp_2)
203 |
204 | for k,v in error_temp_all.items():
205 | if not isinstance(v, float):
206 | v = v.item()  # convert a 0-d numpy/torch scalar to a Python float
207 | if k in total_result_sum:
208 | total_result_sum[k] = total_result_sum[k] + v
209 | else:
210 | total_result_sum[k] = v
211 | self.eval_step += 1
212 |
213 | if self.eval_step % 80 == 0 and self.is_master:
214 | writer.add_image('image0', inputs[("img_ori", 0, 0)][0], global_step=self.eval_step, walltime=None, dataformats='CHW')
215 | writer.add_image('image1', inputs[("img_ori", 1, 0)][0], global_step=self.eval_step, walltime=None, dataformats='CHW')
216 | writer.add_image('image2', inputs[("img_ori", 2, 0)][0], global_step=self.eval_step, walltime=None, dataformats='CHW')
217 | depth_gt = gray_2_colormap_np(inputs[("depth_gt", 0, 0)][0][0])
218 | writer.add_image('depth_gt', depth_gt, global_step=self.eval_step, walltime=None, dataformats='HWC')
219 | depth_pred = gray_2_colormap_np(outputs[('depth_pred', 0)][0][0])
220 | writer.add_image('depth_pred', depth_pred, global_step=self.eval_step, walltime=None, dataformats='HWC')
221 | depth_pred_2 = gray_2_colormap_np(outputs[('depth_pred_2', 0)][0][0])
222 | writer.add_image('depth_pred_2', depth_pred_2, global_step=self.eval_step, walltime=None, dataformats='HWC')
223 |
224 | # print('total_count', dict_pred['total_count'])
225 | # print('total_count_2', dict_pred_2['total_count_2'])
226 | # print('abs_rel', dict_pred['abs_rel'])
227 |
228 | for k in total_result_sum.keys():
229 | total_result_count[k] = total_result_sum['valid_number']
230 |
231 |
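# Reduce metrics over all GPUs: sum and valid-pixel count are all-reduced per
# metric, and the master process logs the resulting mean.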
232 | for k in total_result_sum.keys():
233 | this_tensor = torch.tensor([total_result_sum[k], total_result_count[k]]).to(self.device)
234 | this_list = [this_tensor]
235 | torch.distributed.all_reduce_multigpu(this_list)
236 | torch.distributed.barrier()
237 | this_tensor = this_list[0].detach().cpu().numpy()
238 | reduce_sum = this_tensor[0].item()
239 | reduce_count = this_tensor[1].item()
240 | reduce_mean = reduce_sum / reduce_count
241 | if self.is_master:
242 | writer.add_scalar(k, reduce_mean, self.eval_step)
243 |
244 | self.set_train()
245 |
246 | def process_batch(self, inputs, mode):
247 | self.to_gpu(inputs)
248 |
249 | imgs, proj_mats, pose_mats = [], [], []
250 | for i in range(inputs['num_frame'][0].item()):
251 | imgs.append(inputs[('color', i, self.opt.input_scale)])
252 | proj_mats.append(inputs[('proj', i)])
253 | pose_mats.append(inputs[('pose', i)])
254 |
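# The first frame is the reference view and the rest are source views; poses
# and the inverse-intrinsics pyramid are passed to the network as well.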
255 | # outputs = self.model(imgs[0], imgs[1:], proj_mats[0], proj_mats[1:],
256 | # inputs[('inv_K_pool', 0)])
257 | outputs = self.model(imgs[0], imgs[1:], pose_mats[0], pose_mats[1:],
258 | inputs[('inv_K_pool', 0)])
259 | losses = self.compute_losses(inputs, outputs)
260 | return losses, outputs
261 |
262 | def compute_losses(self, inputs, outputs):
263 | losses, loss, s = {}, 0, 0
264 | depth_pred = outputs[('depth_pred', s)]
265 | depth_pred_2 = outputs[('depth_pred_2', s)]
266 | depth_gt = inputs[('depth_gt', 0, s)]
267 |
268 |
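# Only pixels with valid ground-truth depth contribute; both depth branches
# (fused and single-view) are trained with an L1 loss.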
269 | valid_depth = (depth_gt > 0)
270 |
271 | # if self.opt.pred_conf:
272 | # log_conf_pred = outputs[('log_conf_pred', s)]
273 | # conf_pred = torch.exp(log_conf_pred)
274 | # min_conf = self.opt.min_conf
275 | # max_conf = self.opt.max_conf if self.opt.max_conf != -1 else None
276 | # conf_pred = conf_pred.clamp(min_conf, max_conf)
277 | # loss_depth = ((depth_pred - depth_gt).abs() / conf_pred +
278 | # log_conf_pred)[valid_depth].mean()
279 | # else:
280 | loss_depth = (depth_pred[valid_depth] - depth_gt[valid_depth]).abs().mean()
281 |
282 | loss_depth_2 = (depth_pred_2[valid_depth] - depth_gt[valid_depth]).abs().mean()
283 |
284 | losses["depth"] = loss_depth
285 | losses["depth_2"] = loss_depth_2
286 |
287 | loss += loss_depth + loss_depth_2
288 | losses["loss"] = loss
289 |
290 | return losses
291 |
292 |
293 | def run_fusion(dense_folder, out_folder, opts):
294 | cmd = f"python patchmatch_fusion.py \
295 | --dense_folder {dense_folder} \
296 | --outdir {out_folder} \
297 | --n_proc 4 \
298 | --conf_thres {opts.conf_thres} \
299 | --att_thres {opts.att_thres} \
300 | --use_conf_thres {opts.pred_conf} \
301 | --geo_depth_thres {opts.geo_depth_thres} \
302 | --geo_pixel_thres {opts.geo_pixel_thres} \
303 | --num_consistent {opts.num_consistent} \
304 | "
305 |
306 | os.system(cmd)
307 |
308 |
309 | if __name__ == "__main__":
310 | options = MVS2DOptions()
311 | opts = options.parse()
312 |
313 | set_random_seed(666)
314 |
315 | if torch.cuda.device_count() > 1 and not opts.multiprocessing_distributed:
316 | raise Exception(
317 | "Detected more than 1 GPU. Please set multiprocessing_distributed=1 or set CUDA_VISIBLE_DEVICES"
318 | )
319 |
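# One process per GPU is launched via torch.distributed.launch (--local_rank);
# the NCCL process group here assumes a fixed world size of 3 GPUs.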
320 | opts.distributed = opts.world_size > 1 or opts.multiprocessing_distributed
321 | if opts.multiprocessing_distributed:
322 | # total_gpus, opts.rank = init_dist_pytorch(opts.tcp_port,
323 | # opts.local_rank,
324 | # backend='nccl')
325 | print('opts.local_rank', opts.local_rank)
326 | torch.cuda.set_device(opts.local_rank)
327 | dist.init_process_group("nccl", rank=opts.local_rank, world_size=3)
328 | opts.ngpus_per_node = 3
329 | opts.gpu = opts.local_rank
330 | print("Use GPU: {}/{} for training".format(opts.gpu,
331 | opts.ngpus_per_node))
332 | else:
333 | opts.gpu = 0
334 |
335 | if opts.mode == 'train':
336 | trainer = Trainer(opts)
337 | trainer.train()
338 |
339 | elif opts.mode == 'test':
340 | trainer = Trainer(opts)
341 | trainer.val()
342 |
343 | elif opts.mode == 'full_test':
344 | ## save depth prediction
345 | opts.mode = 'test'
346 | trainer = Trainer(opts)
347 | trainer.val()
348 |
349 | ## fuse dense prediction into final point cloud
350 | dense_folder = f"{opts.log_dir}/{opts.model_name}/eval_000/prediction"
351 | out_folder = f"{opts.log_dir}/{opts.model_name}/recon"
352 | run_fusion(dense_folder, out_folder, opts)
353 |
354 | ## eval point cloud
355 | MeanData, MeanStl, MeanAvg = dtu_pyeval(
356 | f"{out_folder}",
357 | gt_dir='./data/SampleSet/MVS Data/',
358 | voxel_down_sample=False,
359 | fn=f"{out_folder}/result.txt")
360 |
--------------------------------------------------------------------------------
/trainer_base_af.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 | from open3d import *
3 | import numpy as np
4 | import time
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | import torch.optim as optim
9 | from torch.utils.data import DataLoader
10 | # from tensorboardX import SummaryWriter
11 | from torch.utils.tensorboard import SummaryWriter
12 | import json
13 | from utils import *
14 | import os
15 | import stat
16 | import glob
17 | import shutil
18 | from torch.autograd import Variable
19 | import torch.optim.lr_scheduler as lr_sched
20 | from options import MVS2DOptions
21 | import torch.backends.cudnn as cudnn
22 |
23 | cudnn.benchmark = True
24 |
25 | g = torch.Generator()
26 | g.manual_seed(0)
27 |
28 |
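# Seed numpy and random differently in every DataLoader worker so that data
# augmentation does not repeat across workers.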
29 | def worker_init_fn(worker_id):
30 | seed = np.random.get_state()[1][0] + worker_id
31 | np.random.seed(seed)
32 | import random
33 | random.seed(seed)
34 |
35 |
36 | def file_remove_readonly(func, path, execinfo):
37 | os.chmod(path, stat.S_IWUSR)  # make the file writable so it can be deleted
38 | func(path)
39 |
40 |
41 | class BaseTrainer(object):
42 | def __init__(self, options):
43 |
44 | self.is_best = {}
45 | self.epoch = 0
46 | self.step = 0
47 | self.eval_step = 0
48 | self.opt = options
49 | self.is_master = self.opt.gpu == 0
50 | self.opt.is_master = self.opt.gpu == 0
51 | self.device = self.opt.gpu
52 |
53 |
54 | # base_dir = '.'
55 | # self.opt.log_dir = os.path.join(base_dir, self.opt.log_dir)
56 | self.log_path = os.path.join(self.opt.log_dir, self.opt.model_name)
57 | if self.opt.is_master:
58 | if os.path.exists(self.log_path) and self.opt.overwrite:
59 | try:
60 | shutil.rmtree(self.log_path)
61 | # shutil.rmtree(self.log_path, onerror=file_remove_readonly)
62 | except:
63 | print('overwrite folder failed')
64 | self.log_file = os.path.join(self.log_path, "log.txt")
65 |
66 | self.writers = {}
67 | for mode in ["train", "val"]:
68 | self.writers[mode] = SummaryWriter(
69 | os.path.join(self.log_path, mode))
70 | if self.opt.is_master:
71 | with open(self.log_file, 'w') as f:
72 | f.write(self.opt.note + '\n')
73 |
74 | self.save_opts()
75 |
76 | self.build_dataset()
77 |
78 | self.build_model()
79 |
80 | # self.build_optimizer()
81 | self.fetch_optimizer()
82 |
83 | if self.opt.load_weights_folder is not None:
84 | self.load_model()
85 |
86 | if self.opt.distributed:
87 | if self.opt.gpu is not None:
88 | print(
89 | f"batch size on GPU: {self.opt.gpu}: {self.opt.batch_size}"
90 | )
91 |
92 | self.model = torch.nn.parallel.DistributedDataParallel(
93 | self.model,
94 | device_ids=[self.opt.gpu],
95 | find_unused_parameters=True)
96 | else:
97 | self.model = torch.nn.parallel.DistributedDataParallel(
98 | self.model, find_unused_parameters=True)
99 |
100 | # self.build_scheduler()
101 |
102 | self.total_data_time = 0
103 | self.total_op_time = 0
104 | if self.opt.epoch_size == -1:
105 | self.opt.epoch_size = len(self.train_loader)
106 |
107 | if self.opt.is_master:
108 | print("Training model named:\n ", self.opt.model_name)
109 | print("Models and tensorboard events files are saved to:\n ",
110 | self.opt.log_dir)
111 |
112 | self.num_total_steps = len(self.train_loader) * self.opt.num_epochs
113 | print("There are {:d} training items and {:d} validation items\n".
114 | format(
115 | len(self.train_loader) * self.opt.batch_size,
116 | len(self.val_loader) * 1))
117 |
118 | # def build_optimizer(self):
119 | # optimizer = optim.Adam(self.model.parameters(),
120 | # lr=self.opt.LR,
121 | # weight_decay=self.opt.WEIGHT_DECAY)
122 |
123 | # self.model_optimizer = optimizer
124 |
125 |
126 | # def build_scheduler(self):
127 | # total_iters_each_epoch = len(self.train_loader)
128 | # decay_steps = [
129 | # x * total_iters_each_epoch for x in self.opt.DECAY_STEP_LIST
130 | # ]
131 | # total_steps = total_iters_each_epoch * self.opt.num_epochs
132 |
133 | # def lr_lbmd(cur_epoch):
134 | # cur_decay = 1
135 | # for decay_step in decay_steps:
136 | # if cur_epoch >= decay_step:
137 | # cur_decay = cur_decay * self.opt.LR_DECAY
138 | # return max(cur_decay, self.opt.LR_CLIP / self.opt.LR)
139 |
140 | # self.model_lr_scheduler = lr_sched.LambdaLR(self.model_optimizer,
141 | # lr_lbmd,
142 | # last_epoch=-1)
143 |
144 | def fetch_optimizer(self):
145 | """ Create the optimizer and learning rate scheduler """
146 | total_iters_each_epoch = len(self.train_loader)
147 | total_steps = total_iters_each_epoch * self.opt.num_epochs
148 | optimizer = optim.AdamW(self.model.parameters(), lr=self.opt.LR, weight_decay=.00001, eps=1e-8)
149 |
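# OneCycleLR over the whole run (plus a little slack): the learning rate warms
# up for the first 1% of steps and then decays linearly.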
150 | scheduler = optim.lr_scheduler.OneCycleLR(optimizer, self.opt.LR, total_steps+100,
151 | pct_start=0.01, cycle_momentum=False, anneal_strategy='linear')
152 | self.model_lr_scheduler = scheduler
153 | self.model_optimizer = optimizer
154 |
155 |
156 | def to_gpu(self, inputs, keys=None):
157 | if keys is None:
158 | keys = inputs.keys()
159 | for key in keys:
160 | if key not in inputs:
161 | continue
162 | ipt = inputs[key]
163 | if type(ipt) == torch.Tensor:
164 | inputs[key] = ipt.cuda(self.opt.gpu, non_blocking=True)
165 | elif type(ipt) == list and type(ipt[0]) == torch.Tensor:
166 | inputs[key] = [
167 | x.cuda(self.opt.gpu, non_blocking=True) for x in ipt
168 | ]
169 | elif type(ipt) == dict:
170 | for k in ipt.keys():
171 | if type(ipt[k]) == torch.Tensor:
172 | ipt[k] = ipt[k].cuda(self.opt.gpu, non_blocking=True)
173 |
174 | def build_dataset(self):
175 | # if self.opt.dataset == 'ScanNet':
176 | # from datasets.ScanNet import ScanNet as Dataset
177 | # elif self.opt.dataset == 'DeMoN':
178 | # from datasets.DeMoN import DeMoN as Dataset
179 | # elif self.opt.dataset == 'DTU':
180 | # from datasets.DTU import DTU as Dataset
181 | # else:
182 | # raise Exception("Unknown Dataset")
183 | from datasets.DDAD import DDAD
184 |
185 | train_dataset = DDAD(self.opt, True)
186 | if self.opt.distributed:
187 | self.train_sampler = torch.utils.data.distributed.DistributedSampler(
188 | train_dataset)
189 | else:
190 | self.train_sampler = None
191 | self.train_loader = DataLoader(train_dataset,
192 | self.opt.batch_size,
193 | shuffle=(self.train_sampler is None),
194 | num_workers=self.opt.num_workers,
195 | pin_memory=True,
196 | worker_init_fn=worker_init_fn,
197 | drop_last=True,
198 | sampler=self.train_sampler)
199 |
200 | val_dataset = DDAD(self.opt, False)
201 | if self.opt.distributed:
202 | self.val_sampler = torch.utils.data.distributed.DistributedSampler(
203 | val_dataset)
204 | else:
205 | self.val_sampler = None
206 | self.val_loader = DataLoader(val_dataset,
207 | self.opt.batch_size,
208 | shuffle=False,
209 | num_workers=self.opt.num_workers,
210 | pin_memory=True,
211 | worker_init_fn=worker_init_fn,
212 | drop_last=False,
213 | sampler=self.val_sampler)
214 |
215 | # val_dataset = DDAD(self.opt, False)
216 | # self.val_sampler = None
217 | # self.val_loader = DataLoader(val_dataset,
218 | # 1,
219 | # shuffle=False,
220 | # num_workers=self.opt.num_workers,
221 | # pin_memory=True,
222 | # drop_last=False,
223 | # sampler=self.val_sampler)
224 |
225 | def log_time(self, batch_idx, op_time, step_time, loss):
226 | """Print a logging statement to the terminal
227 | """
228 | if self.opt.distributed:
229 | ops_per_sec = self.opt.ngpus_per_node * self.opt.batch_size / op_time
230 | steps_per_sec = self.opt.ngpus_per_node * self.opt.batch_size / step_time
231 | else:
232 | ops_per_sec = self.opt.batch_size / op_time
233 | steps_per_sec = self.opt.batch_size / step_time
234 | time_sofar = time.time() - self.start_time
235 |
236 | training_time_left = (self.num_total_steps / self.step -
237 | 1.0) * time_sofar if self.step > 0 else 0
238 | print_string = "epoch {:>3} | batch {:>6}/{:>6} | ops/s: {:5.1f} | steps/s: {:5.1f} | t_data/t_op: {:5.1f} " + \
239 | " | loss: {:.5f} | time elapsed: {} | time left: {} | lr: {:.7f}"
240 | self.log_string(
241 | print_string.format(self.epoch, batch_idx, len(self.train_loader),
242 | ops_per_sec, steps_per_sec,
243 | self.total_data_time / self.total_op_time,
244 | loss, sec_to_hm_str(time_sofar),
245 | sec_to_hm_str(training_time_left),
246 | self.model_optimizer.param_groups[0]['lr']))
247 |
248 | def train_epoch(self):
249 | if self.opt.is_master:
250 | print("Training")
251 | self.writers['train'].add_scalar(
252 | "lr", self.model_optimizer.param_groups[0]['lr'], self.step)
253 | self.set_train()
254 | before_data_loader_time = time.time()
255 | time_last_step = time.time()
256 |
257 | if self.opt.epoch_size == 0:
258 | return
259 |
260 | for batch_idx, inputs in enumerate(self.train_loader):
261 | if batch_idx >= self.opt.epoch_size:
262 | break
263 | after_data_loader_time = time.time()
264 | duration_data = after_data_loader_time - before_data_loader_time
265 | self.total_data_time += duration_data
266 | before_op_time = time.time()
267 |
268 | self.model_lr_scheduler.step(self.step)
269 |
270 | if self.opt.is_master:
271 | try:
272 | cur_lr = float(self.model_optimizer.lr)
273 | except:
274 | cur_lr = self.model_optimizer.param_groups[0]['lr']
275 |
276 | self.writers['train'].add_scalar('meta_data/learning_rate',
277 | cur_lr, self.step)
278 |
279 | self.model_optimizer.zero_grad()
280 | losses, outputs = self.process_batch(inputs, 'train')
281 | losses['loss'].backward()
282 |
283 | torch.nn.utils.clip_grad_norm_(self.parameters_to_train,
284 | self.opt.GRAD_NORM_CLIP)
285 |
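# Skip the optimizer step for this batch if any parameter gradient contains NaN.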
286 | contain_nan = False
287 | for weight in self.parameters_to_train:
288 | if weight.grad is not None:
289 | if torch.any(torch.isnan(weight.grad)):
290 | print('skip parameters update because of nan in grad')
291 | contain_nan = True
292 | if not contain_nan:
293 | self.model_optimizer.step()
294 |
295 | duration = time.time() - before_op_time
296 | self.total_op_time += duration
297 |
298 | if self.opt.is_master and batch_idx % self.opt.log_frequency == 0:
299 | duration_step = time.time() - time_last_step
300 | self.log_time(batch_idx, duration, duration_step,
301 | losses["loss"].cpu().data)
302 | self.log("train", inputs, losses, batch_idx, outputs)
303 | self.step += 1
304 | before_data_loader_time = time.time()
305 | time_last_step = time.time()
306 |
307 | def update_monitor_key(self, metrics, keys, goals):
308 | if len(keys):
309 | if type(keys) != list:
310 | keys = [keys]
311 | for key, goal in zip(keys, goals):
312 | val = metrics[key]
313 | if not hasattr(self, key):
314 | setattr(self, key, val)
315 | self.is_best[key] = True
316 | else:
317 | if goal == 'minimize':
318 | if val < getattr(self, key):
319 | self.is_best[key] = True
320 | setattr(self, key, val)
321 | else:
322 | self.is_best[key] = False
323 | elif goal == 'maximize':
324 | if val > getattr(self, key):
325 | self.is_best[key] = True
326 | setattr(self, key, val)
327 | else:
328 | self.is_best[key] = False
329 |
330 | def set_train(self):
331 | self.model.train()
332 |
333 | def set_eval(self):
334 | self.model.eval()
335 |
336 | def train(self):
337 | self.start_time = time.time()
338 | if self.opt.is_master:
339 | print("Total epoch: %d " % self.opt.num_epochs)
340 | print("train loader size: %d " % len(self.train_loader))
341 | print("val loader size: %d " % len(self.val_loader))
342 | print("log_frequency: %d " % self.opt.log_frequency)
343 | for self.epoch in range(self.opt.num_epochs):
344 | if self.opt.distributed:
345 | self.train_sampler.set_epoch(self.epoch)
346 | self.train_epoch()
347 | # if self.opt.is_master:
348 | self.val_epoch()
349 | torch.distributed.barrier()
350 | torch.cuda.empty_cache()
351 | if self.opt.is_master:
352 | self.save_model(monitor_key=self.opt.monitor_key)
353 |
354 | def val(self):
355 | self.val_epoch()
356 |
357 | def process_batch(self, inputs, mode):
358 | raise Exception("Need to implement process_batch")
359 |
360 | def compute_losses(self, inputs, outputs):
361 | raise Exception("Need to implement compute_losses")
362 |
363 | def log_string(self, content):
364 | with open(self.log_file, 'a') as f:
365 | f.write(content + '\n')
366 | print(content, flush=True)
367 |
368 | def log(self, mode, inputs, losses, batch_idx, outputs):
369 | """Write an event to the tensorboard events file
370 | """
371 | writer = self.writers[mode]
372 | for l, v in losses.items():
373 | if type(losses[l]) == dict:
374 | writer.add_scalars("{}".format(l), v, self.step)
375 | else:
376 | writer.add_scalar("{}".format(l), v, self.step)
377 |
378 | if batch_idx % 150 == 0:
379 | writer.add_image('image0', inputs[("color", 0, 0)][0], global_step=self.step, walltime=None, dataformats='CHW')
380 | writer.add_image('image1', inputs[("color", 1, 0)][0], global_step=self.step, walltime=None, dataformats='CHW')
381 | writer.add_image('image2', inputs[("color", 2, 0)][0], global_step=self.step, walltime=None, dataformats='CHW')
382 | depth_gt = gray_2_colormap_np(inputs[("depth_gt", 0, 0)][0][0])
383 | writer.add_image('depth_gt', depth_gt, global_step=self.step, walltime=None, dataformats='HWC')
384 | depth_pred = gray_2_colormap_np(outputs[('depth_pred', 0)][0][0])
385 | writer.add_image('depth_pred', depth_pred, global_step=self.step, walltime=None, dataformats='HWC')
386 | depth_pred_2 = gray_2_colormap_np(outputs[('depth_pred_2', 0)][0][0])
387 | writer.add_image('depth_pred_2', depth_pred_2, global_step=self.step, walltime=None, dataformats='HWC')
388 |
389 |
390 | def save_opts(self):
391 | """Save options to disk so we know what we ran this experiment with
392 | """
393 | models_dir = os.path.join(self.log_path, "models")
394 | if not os.path.exists(models_dir):
395 | os.makedirs(models_dir)
396 | to_save = self.opt.__dict__.copy()
397 |
398 | with open(os.path.join(models_dir, 'opt.json'), 'w') as f:
399 | json.dump(to_save, f, indent=2, sort_keys=True)
400 |
401 | def clean_models(self, keep_ids):
402 | models = glob.glob(os.path.join(self.log_path, "models", "weights_*"))
403 | models = sorted(models,
404 | key=lambda x: int(x.split('/')[-1].split('_')[-1]))
405 | for i in range(len(models) - 1):
406 | epoch = int(models[i].split('/')[-1].split('_')[-1])
407 | if epoch not in keep_ids:
408 | shutil.rmtree(models[i])
409 |
410 | def save_model(self, monitor_key=""):
411 | """Save model weights to disk
412 | """
413 | save_folder = os.path.join(self.log_path, "models", "weights_latest")
414 | if not os.path.exists(save_folder):
415 | os.makedirs(save_folder)
416 | print("save model to folder %s" % save_folder)
417 |
418 | save_path = os.path.join(save_folder, "model_{}.pth".format(self.epoch))
419 | if self.opt.distributed:
420 | to_save = self.model.module.state_dict()
421 | else:
422 | to_save = self.model.state_dict()
423 |
424 | torch.save(to_save, save_path)
425 | save_path_opt = os.path.join(save_folder, "{}.pth".format("adam"))
426 | torch.save(self.model_optimizer.state_dict(), save_path_opt)
427 | # if len(monitor_key):
428 | # if type(monitor_key) != list:
429 | # monitor_key = [monitor_key]
430 | # for key in monitor_key:
431 | # if not self.is_best[key]:
432 | # continue
433 | # save_folder = os.path.join(self.log_path, "models",
434 | # f"weights_best_{key}")
435 | # os.makedirs(save_folder, exist_ok=True)
436 | # cmd = f"cp {save_path} {save_folder}/model.pth"
437 | # os.system(cmd)
438 | # cmd = f"cp {save_path_opt} {save_folder}/adam.pth"
439 | # os.system(cmd)
440 | # with open(f"{save_folder}/key.txt", "w") as f:
441 | # val = getattr(self, key)
442 | # f.write(f"{key} {val}\n")
443 |
444 | def load_model(self):
445 | self.opt.load_weights_folder = os.path.expanduser(
446 | self.opt.load_weights_folder)
447 | assert os.path.isdir(self.opt.load_weights_folder), \
448 | "Cannot find folder {}".format(self.opt.load_weights_folder)
449 | print("loading model from folder {}".format(
450 | self.opt.load_weights_folder))
451 |
452 | try:
453 | self.epoch = int(
454 | self.opt.load_weights_folder.split('/')[-2].split('_')[1])
455 | except:
456 | self.epoch = 0
457 |
458 | try:
459 | path = os.path.join(self.opt.load_weights_folder,
460 | "{}.pth".format("model"))
461 | model_dict = self.model.state_dict()
462 | pretrained_dict = torch.load(path, 'cpu')
463 | for k, v in pretrained_dict.items():
464 | if k not in model_dict:
465 | print('model dict missing ', k, v.shape)
466 | for k, v in model_dict.items():
467 | if k not in pretrained_dict:
468 | print('pretrained_dict missing ', k, v.shape)
469 | pretrained_dict = {
470 | k: v
471 | for k, v in pretrained_dict.items() if k in model_dict
472 | }
473 | model_dict.update(pretrained_dict)
474 | self.model.load_state_dict(model_dict)
475 | except Exception as e:
476 | print(e)
477 | print("Fail loading {}".format("model"))
478 |
479 | # loading optimizer state
480 | optimizer_load_path = os.path.join(self.opt.load_weights_folder,
481 | "adam.pth")
482 | if os.path.isfile(optimizer_load_path):
483 | print("Loading optimizer weights")
484 | optimizer_dict = torch.load(optimizer_load_path, 'cpu')
485 | try:
486 | self.model_optimizer.load_state_dict(optimizer_dict)
487 | except Exception as e:
488 | print(e)
489 | print("Fail loading optimizer weights")
490 | else:
491 | print("Cannot find optimizer weights so optimizer is randomly initialized")
492 |
--------------------------------------------------------------------------------
/trainer_base_kitti.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 | from open3d import *
3 | import numpy as np
4 | import time
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | import torch.optim as optim
9 | from torch.utils.data import DataLoader
10 | # from tensorboardX import SummaryWriter
11 | from torch.utils.tensorboard import SummaryWriter
12 | import json
13 | from utils import *
14 | import os
15 | import stat
16 | import glob
17 | import shutil
18 | from torch.autograd import Variable
19 | import torch.optim.lr_scheduler as lr_sched
20 | from options import MVS2DOptions
21 | import torch.backends.cudnn as cudnn
22 |
23 | cudnn.benchmark = True
24 |
25 | g = torch.Generator()
26 | g.manual_seed(0)
27 |
28 |
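# Give each DataLoader worker its own numpy/random seed so augmentations differ
# between workers.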
29 | def worker_init_fn(worker_id):
30 | seed = np.random.get_state()[1][0] + worker_id
31 | np.random.seed(seed)
32 | import random
33 | random.seed(seed)
34 |
35 |
36 | def file_remove_readonly(func, path, execinfo):
37 | os.chmod(path, stat.S_IWUSR)  # make the file writable so it can be deleted
38 | func(path)
39 |
40 |
41 | class BaseTrainer(object):
42 | def __init__(self, options):
43 |
44 | self.is_best = {}
45 | self.epoch = 0
46 | self.step = 0
47 | self.eval_step = 0
48 | self.opt = options
49 | self.is_master = self.opt.gpu == 0
50 | self.opt.is_master = self.opt.gpu == 0
51 | self.device = self.opt.gpu
52 |
53 |
54 | # base_dir = '.'
55 | # self.opt.log_dir = os.path.join(base_dir, self.opt.log_dir)
56 | self.log_path = os.path.join(self.opt.log_dir, self.opt.model_name)
57 | if self.opt.is_master:
58 | if os.path.exists(self.log_path) and self.opt.overwrite:
59 | try:
60 | shutil.rmtree(self.log_path)
61 | # shutil.rmtree(self.log_path, onerror=file_remove_readonly)
62 | except:
63 | print('overwrite folder failed')
64 | self.log_file = os.path.join(self.log_path, "log.txt")
65 |
66 | self.writers = {}
67 | for mode in ["train", "val"]:
68 | self.writers[mode] = SummaryWriter(
69 | os.path.join(self.log_path, mode))
70 | if self.opt.is_master:
71 | with open(self.log_file, 'w') as f:
72 | f.write(self.opt.note + '\n')
73 |
74 | self.save_opts()
75 |
76 | self.build_dataset()
77 |
78 | self.build_model()
79 |
80 | # self.build_optimizer()
81 | self.fetch_optimizer()
82 |
83 | if self.opt.load_weights_folder is not None:
84 | self.load_model()
85 |
86 | if self.opt.distributed:
87 | if self.opt.gpu is not None:
88 | print(
89 | f"batch size on GPU: {self.opt.gpu}: {self.opt.batch_size}"
90 | )
91 |
92 | self.model = torch.nn.parallel.DistributedDataParallel(
93 | self.model,
94 | device_ids=[self.opt.gpu],
95 | find_unused_parameters=True)
96 | else:
97 | self.model = torch.nn.parallel.DistributedDataParallel(
98 | self.model, find_unused_parameters=True)
99 |
100 | # self.build_scheduler()
101 |
102 | self.total_data_time = 0
103 | self.total_op_time = 0
104 | if self.opt.epoch_size == -1:
105 | self.opt.epoch_size = len(self.train_loader)
106 |
107 | if self.opt.is_master:
108 | print("Training model named:\n ", self.opt.model_name)
109 | print("Models and tensorboard events files are saved to:\n ",
110 | self.opt.log_dir)
111 |
112 | self.num_total_steps = len(self.train_loader) * self.opt.num_epochs
113 | print("There are {:d} training items and {:d} validation items\n".
114 | format(
115 | len(self.train_loader) * self.opt.batch_size,
116 | len(self.val_loader) * 1))
117 |
118 | # def build_optimizer(self):
119 | # optimizer = optim.Adam(self.model.parameters(),
120 | # lr=self.opt.LR,
121 | # weight_decay=self.opt.WEIGHT_DECAY)
122 |
123 | # self.model_optimizer = optimizer
124 |
125 |
126 | # def build_scheduler(self):
127 | # total_iters_each_epoch = len(self.train_loader)
128 | # decay_steps = [
129 | # x * total_iters_each_epoch for x in self.opt.DECAY_STEP_LIST
130 | # ]
131 | # total_steps = total_iters_each_epoch * self.opt.num_epochs
132 |
133 | # def lr_lbmd(cur_epoch):
134 | # cur_decay = 1
135 | # for decay_step in decay_steps:
136 | # if cur_epoch >= decay_step:
137 | # cur_decay = cur_decay * self.opt.LR_DECAY
138 | # return max(cur_decay, self.opt.LR_CLIP / self.opt.LR)
139 |
140 | # self.model_lr_scheduler = lr_sched.LambdaLR(self.model_optimizer,
141 | # lr_lbmd,
142 | # last_epoch=-1)
143 |
144 | def fetch_optimizer(self):
145 | """ Create the optimizer and learning rate scheduler """
146 | total_iters_each_epoch = len(self.train_loader)
147 | total_steps = total_iters_each_epoch * self.opt.num_epochs
148 | optimizer = optim.AdamW(self.model.parameters(), lr=self.opt.LR, weight_decay=.00001, eps=1e-8)
149 |
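# OneCycleLR schedule over all training steps (with a small buffer); 1% linear
# warm-up followed by linear decay.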
150 | scheduler = optim.lr_scheduler.OneCycleLR(optimizer, self.opt.LR, total_steps+100,
151 | pct_start=0.01, cycle_momentum=False, anneal_strategy='linear')
152 | self.model_lr_scheduler = scheduler
153 | self.model_optimizer = optimizer
154 |
155 |
156 | def to_gpu(self, inputs, keys=None):
157 | if keys is None:
158 | keys = inputs.keys()
159 | for key in keys:
160 | if key not in inputs:
161 | continue
162 | ipt = inputs[key]
163 | if type(ipt) == torch.Tensor:
164 | inputs[key] = ipt.cuda(self.opt.gpu, non_blocking=True)
165 | elif type(ipt) == list and type(ipt[0]) == torch.Tensor:
166 | inputs[key] = [
167 | x.cuda(self.opt.gpu, non_blocking=True) for x in ipt
168 | ]
169 | elif type(ipt) == dict:
170 | for k in ipt.keys():
171 | if type(ipt[k]) == torch.Tensor:
172 | ipt[k] = ipt[k].cuda(self.opt.gpu, non_blocking=True)
173 |
174 | def build_dataset(self):
175 | # if self.opt.dataset == 'ScanNet':
176 | # from datasets.ScanNet import ScanNet as Dataset
177 | # elif self.opt.dataset == 'DeMoN':
178 | # from datasets.DeMoN import DeMoN as Dataset
179 | # elif self.opt.dataset == 'DTU':
180 | # from datasets.DTU import DTU as Dataset
181 | # else:
182 | # raise Exception("Unknown Dataset")
183 | from datasets.kitti import DDAD_kitti
184 |
185 | train_dataset = DDAD_kitti(self.opt, True)
186 | if self.opt.distributed:
187 | self.train_sampler = torch.utils.data.distributed.DistributedSampler(
188 | train_dataset)
189 | else:
190 | self.train_sampler = None
191 | self.train_loader = DataLoader(train_dataset,
192 | self.opt.batch_size,
193 | shuffle=(self.train_sampler is None),
194 | num_workers=self.opt.num_workers,
195 | pin_memory=True,
196 | worker_init_fn=worker_init_fn,
197 | drop_last=True,
198 | sampler=self.train_sampler)
199 |
200 | val_dataset = DDAD_kitti(self.opt, False)
201 | if self.opt.distributed:
202 | self.val_sampler = torch.utils.data.distributed.DistributedSampler(
203 | val_dataset)
204 | else:
205 | self.val_sampler = None
206 | self.val_loader = DataLoader(val_dataset,
207 | self.opt.batch_size,
208 | shuffle=False,
209 | num_workers=self.opt.num_workers,
210 | pin_memory=True,
211 | worker_init_fn=worker_init_fn,
212 | drop_last=False,
213 | sampler=self.val_sampler)
214 |
215 | # val_dataset = DDAD(self.opt, False)
216 | # self.val_sampler = None
217 | # self.val_loader = DataLoader(val_dataset,
218 | # 1,
219 | # shuffle=False,
220 | # num_workers=self.opt.num_workers,
221 | # pin_memory=True,
222 | # drop_last=False,
223 | # sampler=self.val_sampler)
224 |
225 | def log_time(self, batch_idx, op_time, step_time, loss):
226 | """Print a logging statement to the terminal
227 | """
228 | if self.opt.distributed:
229 | ops_per_sec = self.opt.ngpus_per_node * self.opt.batch_size / op_time
230 | steps_per_sec = self.opt.ngpus_per_node * self.opt.batch_size / step_time
231 | else:
232 | ops_per_sec = self.opt.batch_size / op_time
233 | steps_per_sec = self.opt.batch_size / step_time
234 | time_sofar = time.time() - self.start_time
235 |
236 | training_time_left = (self.num_total_steps / self.step -
237 | 1.0) * time_sofar if self.step > 0 else 0
238 | print_string = "epoch {:>3} | batch {:>6}/{:>6} | ops/s: {:5.1f} | steps/s: {:5.1f} | t_data/t_op: {:5.1f} " + \
239 | " | loss: {:.5f} | time elapsed: {} | time left: {} | lr: {:.7f}"
240 | self.log_string(
241 | print_string.format(self.epoch, batch_idx, len(self.train_loader),
242 | ops_per_sec, steps_per_sec,
243 | self.total_data_time / self.total_op_time,
244 | loss, sec_to_hm_str(time_sofar),
245 | sec_to_hm_str(training_time_left),
246 | self.model_optimizer.param_groups[0]['lr']))
247 |
248 | def train_epoch(self):
249 | if self.opt.is_master:
250 | print("Training")
251 | self.writers['train'].add_scalar(
252 | "lr", self.model_optimizer.param_groups[0]['lr'], self.step)
253 | self.set_train()
254 | before_data_loader_time = time.time()
255 | time_last_step = time.time()
256 |
257 | if self.opt.epoch_size == 0:
258 | return
259 |
260 | for batch_idx, inputs in enumerate(self.train_loader):
261 | if batch_idx >= self.opt.epoch_size:
262 | break
263 | after_data_loader_time = time.time()
264 | duration_data = after_data_loader_time - before_data_loader_time
265 | self.total_data_time += duration_data
266 | before_op_time = time.time()
267 |
268 | self.model_lr_scheduler.step(self.step)
269 |
270 | if self.opt.is_master:
271 | try:
272 | cur_lr = float(self.model_optimizer.lr)
273 | except:
274 | cur_lr = self.model_optimizer.param_groups[0]['lr']
275 |
276 | self.writers['train'].add_scalar('meta_data/learning_rate',
277 | cur_lr, self.step)
278 |
279 | self.model_optimizer.zero_grad()
280 | losses, outputs = self.process_batch(inputs, 'train')
281 | losses['loss'].backward()
282 |
283 | torch.nn.utils.clip_grad_norm_(self.parameters_to_train,
284 | self.opt.GRAD_NORM_CLIP)
285 |
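# If any gradient contains NaN, skip this optimizer step instead of corrupting the weights.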
286 | contain_nan = False
287 | for weight in self.parameters_to_train:
288 | if weight.grad is not None:
289 | if torch.any(torch.isnan(weight.grad)):
290 | print('skip parameters update because of nan in grad')
291 | contain_nan = True
292 | if not contain_nan:
293 | self.model_optimizer.step()
294 |
295 | duration = time.time() - before_op_time
296 | self.total_op_time += duration
297 |
298 | if self.opt.is_master and batch_idx % self.opt.log_frequency == 0:
299 | duration_step = time.time() - time_last_step
300 | self.log_time(batch_idx, duration, duration_step,
301 | losses["loss"].cpu().data)
302 | self.log("train", inputs, losses, batch_idx, outputs)
303 | self.step += 1
304 | before_data_loader_time = time.time()
305 | time_last_step = time.time()
306 |
307 | def update_monitor_key(self, metrics, keys, goals):
308 | if len(keys):
309 | if type(keys) != list:
310 | keys = [keys]
311 | for key, goal in zip(keys, goals):
312 | val = metrics[key]
313 | if not hasattr(self, key):
314 | setattr(self, key, val)
315 | self.is_best[key] = True
316 | else:
317 | if goal == 'minimize':
318 | if val < getattr(self, key):
319 | self.is_best[key] = True
320 | setattr(self, key, val)
321 | else:
322 | self.is_best[key] = False
323 | elif goal == 'maximize':
324 | if val > getattr(self, key):
325 | self.is_best[key] = True
326 | setattr(self, key, val)
327 | else:
328 | self.is_best[key] = False
329 |
330 | def set_train(self):
331 | self.model.train()
332 |
333 | def set_eval(self):
334 | self.model.eval()
335 |
336 | def train(self):
337 | self.start_time = time.time()
338 | if self.opt.is_master:
339 | print("Total epoch: %d " % self.opt.num_epochs)
340 | print("train loader size: %d " % len(self.train_loader))
341 | print("val loader size: %d " % len(self.val_loader))
342 | print("log_frequency: %d " % self.opt.log_frequency)
343 | for self.epoch in range(self.opt.num_epochs):
344 | if self.opt.distributed:
345 | self.train_sampler.set_epoch(self.epoch)
346 | self.train_epoch()
347 | # if self.opt.is_master:
348 | self.val_epoch()
349 |             if self.opt.distributed: torch.distributed.barrier()
350 | torch.cuda.empty_cache()
351 | if self.opt.is_master:
352 | self.save_model(monitor_key=self.opt.monitor_key)
353 |
354 | def val(self):
355 | self.val_epoch()
356 |
357 | def process_batch(self, inputs, mode):
358 |         raise NotImplementedError("Need to implement process_batch")
359 |
360 | def compute_losses(self, inputs, outputs):
361 |         raise NotImplementedError("Need to implement compute_losses")
362 |
363 | def log_string(self, content):
364 | with open(self.log_file, 'a') as f:
365 | f.write(content + '\n')
366 | print(content, flush=True)
367 |
368 | def log(self, mode, inputs, losses, batch_idx, outputs):
369 | """Write an event to the tensorboard events file
370 | """
371 | writer = self.writers[mode]
372 | for l, v in losses.items():
373 | if type(losses[l]) == dict:
374 | writer.add_scalars("{}".format(l), v, self.step)
375 | else:
376 | writer.add_scalar("{}".format(l), v, self.step)
377 |
378 | if batch_idx % 150 == 0:
379 | writer.add_image('image0', inputs[("img_ori", 0, 0)][0], global_step=self.step, walltime=None, dataformats='CHW')
380 | writer.add_image('image1', inputs[("img_ori", 1, 0)][0], global_step=self.step, walltime=None, dataformats='CHW')
381 | writer.add_image('image2', inputs[("img_ori", 2, 0)][0], global_step=self.step, walltime=None, dataformats='CHW')
382 | depth_gt = gray_2_colormap_np(inputs[("depth_gt", 0, 0)][0][0])
383 | writer.add_image('depth_gt', depth_gt, global_step=self.step, walltime=None, dataformats='HWC')
384 | depth_pred = gray_2_colormap_np(outputs[('depth_pred', 0)][0][0])
385 | writer.add_image('depth_pred', depth_pred, global_step=self.step, walltime=None, dataformats='HWC')
386 | depth_pred_2 = gray_2_colormap_np(outputs[('depth_pred_2', 0)][0][0])
387 | writer.add_image('depth_pred_2', depth_pred_2, global_step=self.step, walltime=None, dataformats='HWC')
388 |
389 |
390 | def save_opts(self):
391 | """Save options to disk so we know what we ran this experiment with
392 | """
393 | models_dir = os.path.join(self.log_path, "models")
394 | if not os.path.exists(models_dir):
395 | os.makedirs(models_dir)
396 | to_save = self.opt.__dict__.copy()
397 |
398 | with open(os.path.join(models_dir, 'opt.json'), 'w') as f:
399 | json.dump(to_save, f, indent=2, sort_keys=True)
400 |
401 | def clean_models(self, keep_ids):
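    |         # Delete checkpoint folders whose epoch id is not in keep_ids,
    |         # always keeping the most recent one (the last entry after sorting).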
402 | models = glob.glob(os.path.join(self.log_path, "models", "weights_*"))
403 | models = sorted(models,
404 | key=lambda x: int(x.split('/')[-1].split('_')[-1]))
405 | for i in range(len(models) - 1):
406 | epoch = int(models[i].split('/')[-1].split('_')[-1])
407 | if epoch not in keep_ids:
408 | shutil.rmtree(models[i])
409 |
410 | def save_model(self, monitor_key=""):
411 | """Save model weights to disk
412 | """
413 | save_folder = os.path.join(self.log_path, "models", "weights_latest")
414 | if not os.path.exists(save_folder):
415 | os.makedirs(save_folder)
416 | print("save model to folder %s" % save_folder)
417 |
418 | save_path = os.path.join(save_folder, "model_{}.pth".format(self.epoch))
419 | if self.opt.distributed:
420 | to_save = self.model.module.state_dict()
421 | else:
422 | to_save = self.model.state_dict()
423 |
424 | torch.save(to_save, save_path)
425 | save_path_opt = os.path.join(save_folder, "{}.pth".format("adam"))
426 | torch.save(self.model_optimizer.state_dict(), save_path_opt)
427 | # if len(monitor_key):
428 | # if type(monitor_key) != list:
429 | # monitor_key = [monitor_key]
430 | # for key in monitor_key:
431 | # if not self.is_best[key]:
432 | # continue
433 | # save_folder = os.path.join(self.log_path, "models",
434 | # f"weights_best_{key}")
435 | # os.makedirs(save_folder, exist_ok=True)
436 | # cmd = f"cp {save_path} {save_folder}/model.pth"
437 | # os.system(cmd)
438 | # cmd = f"cp {save_path_opt} {save_folder}/adam.pth"
439 | # os.system(cmd)
440 | # with open(f"{save_folder}/key.txt", "w") as f:
441 | # val = getattr(self, key)
442 | # f.write(f"{key} {val}\n")
443 |
444 | def load_model(self):
445 | self.opt.load_weights_folder = os.path.expanduser(
446 | self.opt.load_weights_folder)
447 | assert os.path.isdir(self.opt.load_weights_folder), \
448 | "Cannot find folder {}".format(self.opt.load_weights_folder)
449 | print("loading model from folder {}".format(
450 | self.opt.load_weights_folder))
451 |
452 | try:
453 | self.epoch = int(
454 | self.opt.load_weights_folder.split('/')[-2].split('_')[1])
455 |         except (IndexError, ValueError):
456 | self.epoch = 0
457 |
458 | try:
459 | path = os.path.join(self.opt.load_weights_folder,
460 | "{}.pth".format("model"))
461 | model_dict = self.model.state_dict()
462 | pretrained_dict = torch.load(path, 'cpu')
463 | for k, v in pretrained_dict.items():
464 | if k not in model_dict:
465 | print('model dict missing ', k, v.shape)
466 | for k, v in model_dict.items():
467 | if k not in pretrained_dict:
468 | print('pretrained_dict missing ', k, v.shape)
469 | pretrained_dict = {
470 | k: v
471 | for k, v in pretrained_dict.items() if k in model_dict
472 | }
473 | model_dict.update(pretrained_dict)
474 | self.model.load_state_dict(model_dict)
475 | except Exception as e:
476 | print(e)
477 |             print("Failed to load {}".format("model"))
478 |
479 | # loading optimizer state
480 | optimizer_load_path = os.path.join(self.opt.load_weights_folder,
481 | "adam.pth")
482 | if os.path.isfile(optimizer_load_path):
483 | print("Loading optimizer weights")
484 | optimizer_dict = torch.load(optimizer_load_path, 'cpu')
485 | try:
486 | self.model_optimizer.load_state_dict(optimizer_dict)
487 | except Exception as e:
488 | print(e)
489 |                 print("Failed to load optimizer weights")
490 | else:
491 | print("Cannot find optimizer weights so optimizer is randomly initialized")
492 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 | import open3d as o3d
3 | from collections import defaultdict
4 | import os
5 | import random
6 | import cv2
7 | import numpy as np
8 | import torch
9 | from torch.autograd import Variable
10 | import time
11 | import torch.nn.functional as F
12 | import re
13 | import collections.abc as container_abcs
14 | import torch.distributed as dist
15 | import torch.multiprocessing as mp
16 | import subprocess
17 | import matplotlib
18 | import matplotlib.pyplot as plt
19 |
20 | def gray_2_colormap_np_2(img, cmap = 'rainbow', max = None):
21 | img = img.squeeze()
22 | assert img.ndim == 2
23 | img[img<0] = 0
24 | mask_invalid = img < 1e-10
25 |     if max is None:
26 | img = img / (img.max() + 1e-8)
27 | else:
28 | img = img/(max + 1e-8)
29 |
30 | norm = matplotlib.colors.Normalize(vmin=0, vmax=1.1)
31 | cmap_m = matplotlib.cm.get_cmap(cmap)
32 | map = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap_m)
33 | colormap = (map.to_rgba(img)[:,:,:3]*255).astype(np.uint8)
34 | colormap[mask_invalid] = 0
35 |
36 | return colormap
37 |
38 | def gray_2_colormap_np(img, cmap = 'rainbow', max = None):
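    |     # Same as gray_2_colormap_np_2 but for torch tensors: the map is moved to
    |     # CPU, normalized to [0, 1] (by its own max or the given `max`), run through
    |     # the matplotlib colormap and returned as an HxWx3 uint8 image, with
    |     # invalid (near-zero) pixels painted black.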
39 | img = img.cpu().detach().numpy().squeeze()
40 | assert img.ndim == 2
41 | img[img<0] = 0
42 | mask_invalid = img < 1e-10
43 |     if max is None:
44 | img = img / (img.max() + 1e-8)
45 | else:
46 | img = img/(max + 1e-8)
47 |
48 | norm = matplotlib.colors.Normalize(vmin=0, vmax=1.1)
49 | cmap_m = matplotlib.cm.get_cmap(cmap)
50 | map = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap_m)
51 | colormap = (map.to_rgba(img)[:,:,:3]*255).astype(np.uint8)
52 | colormap[mask_invalid] = 0
53 |
54 | return colormap
55 |
56 |
57 | def plot(xs,
58 | ys,
59 | stds=None,
60 | xlabel='',
61 | ylabel='',
62 | title='',
63 | legends=None,
64 | save_fn='test.png',
65 | marker=None,
66 | marker_size=12):
67 | MARKERS = ["o", "X", "D", "^", "<", "v", ">"]
68 | if marker is None:
69 | marker = MARKERS[3]
70 |
71 | nline = len(ys)
72 | nrows, ncols = 1, 1
73 | fig, ax = plt.subplots(figsize=(7, 7))
74 | grid = plt.GridSpec(nrows, ncols, figure=fig)
75 |
76 | ax1 = plt.subplot(grid[0, 0])
77 | lh = []
78 | for i in range(nline):
79 | if stds is not None:
80 | #l, _, _= ax1.errorbar(xs, ys[i], yerr=stds[i], linewidth=4, marker=MARKERS[0], markersize=1, )
81 | l, = ax1.plot(
82 | xs,
83 | ys[i],
84 | linewidth=4,
85 | marker=marker,
86 | markersize=marker_size,
87 | )
88 | color = l.get_color()
89 | low = [x[0] for x in stds[i]]
90 | high = [x[1] for x in stds[i]]
91 | ax1.fill_between(xs, low, high, color=color, alpha=.1)
92 |
93 | else:
94 | l, = ax1.plot(
95 | xs,
96 | ys[i],
97 | linewidth=4,
98 | marker=marker,
99 | markersize=marker_size,
100 | )
101 | lh.append(l)
102 |
103 | ax1.set_xlabel(xlabel, fontsize=25)
104 | ax1.set_ylabel(ylabel, fontsize=25)
105 | ax1.set_title(title, fontsize=25)
106 | if legends is not None:
107 | lgnd = ax1.legend(lh, legends, fontsize=15)
108 | plt.savefig(save_fn)
109 |
110 |
111 | def init_dist_slurm(tcp_port, local_rank, backend='nccl'):
112 | """
113 | modified from https://github.com/open-mmlab/mmdetection
114 | Args:
115 |         tcp_port: TCP port used to reach the master process.
116 |         backend: torch.distributed backend, 'nccl' by default.
117 |     Returns: (total_gpus, rank) of the initialized process group.
118 | """
119 | proc_id = int(os.environ['SLURM_PROCID'])
120 | ntasks = int(os.environ['SLURM_NTASKS'])
121 | node_list = os.environ['SLURM_NODELIST']
122 | num_gpus = torch.cuda.device_count()
123 | torch.cuda.set_device(proc_id % num_gpus)
124 | addr = subprocess.getoutput(
125 | 'scontrol show hostname {} | head -n1'.format(node_list))
126 | os.environ['MASTER_PORT'] = str(tcp_port)
127 | os.environ['MASTER_ADDR'] = addr
128 | os.environ['WORLD_SIZE'] = str(ntasks)
129 | os.environ['RANK'] = str(proc_id)
130 | dist.init_process_group(backend=backend)
131 |
132 | total_gpus = dist.get_world_size()
133 | rank = dist.get_rank()
134 | return total_gpus, rank
135 |
136 |
137 | def init_dist_pytorch(tcp_port, local_rank, backend='nccl'):
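    |     # Single-node multi-GPU initialization: every process is pinned to one
    |     # local GPU and joins an NCCL process group at tcp://127.0.0.1:<tcp_port>;
    |     # the world size is assumed to equal the number of visible GPUs.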
138 | if mp.get_start_method(allow_none=True) is None:
139 | mp.set_start_method('spawn')
140 |
141 | num_gpus = torch.cuda.device_count()
142 | torch.cuda.set_device(local_rank % num_gpus)
143 | dist.init_process_group(backend=backend,
144 | init_method='tcp://127.0.0.1:%d' % tcp_port,
145 | rank=local_rank,
146 | world_size=num_gpus)
147 | rank = dist.get_rank()
148 | return num_gpus, rank
149 |
150 |
151 | def set_random_seed(seed):
152 | random.seed(seed)
153 | np.random.seed(seed)
154 | torch.manual_seed(seed)
155 | torch.backends.cudnn.deterministic = True
156 | torch.backends.cudnn.benchmark = False
157 |
158 |
159 | def randomRotation(epsilon):
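    |     # Sample a small random rotation: a random axis and an angle drawn from
    |     # N(0, (pi * epsilon)^2), assembled via the Rodrigues formula
    |     # R = I + sin(theta) K + (1 - cos(theta)) K^2, with K the skew matrix of
    |     # the axis. Presumably used to perturb poses for the noisy-pose setting.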
160 | axis = (np.random.rand(3) - 0.5)
161 | axis /= np.linalg.norm(axis)
162 | dtheta = np.random.randn(1) * np.pi * epsilon
163 | K = np.array(
164 | [0, -axis[2], axis[1], axis[2], 0, -axis[0], -axis[1], axis[0],
165 | 0]).reshape(3, 3)
166 | dR = np.eye(3) + np.sin(dtheta) * K + (1 - np.cos(dtheta)) * np.matmul(
167 | K, K)
168 | return dR
169 |
170 |
171 | def angle_axis_to_rotation_matrix(angle_axis):
172 | """Convert 3d vector of axis-angle rotation to 4x4 rotation matrix
173 |
174 | Args:
175 | angle_axis (Tensor): tensor of 3d vector of axis-angle rotations.
176 |
177 | Returns:
178 | Tensor: tensor of 4x4 rotation matrices.
179 |
180 | Shape:
181 | - Input: :math:`(N, 3)`
182 | - Output: :math:`(N, 4, 4)`
183 |
184 | Example:
185 | >>> input = torch.rand(1, 3) # Nx3
186 | >>> output = tgm.angle_axis_to_rotation_matrix(input) # Nx4x4
187 | """
188 | def _compute_rotation_matrix(angle_axis, theta2, eps=1e-6):
189 | # We want to be careful to only evaluate the square root if the
190 | # norm of the angle_axis vector is greater than zero. Otherwise
191 | # we get a division by zero.
192 | k_one = 1.0
193 | theta = torch.sqrt(theta2)
194 | wxyz = angle_axis / (theta + eps)
195 | wx, wy, wz = torch.chunk(wxyz, 3, dim=1)
196 | cos_theta = torch.cos(theta)
197 | sin_theta = torch.sin(theta)
198 |
199 | r00 = cos_theta + wx * wx * (k_one - cos_theta)
200 | r10 = wz * sin_theta + wx * wy * (k_one - cos_theta)
201 | r20 = -wy * sin_theta + wx * wz * (k_one - cos_theta)
202 | r01 = wx * wy * (k_one - cos_theta) - wz * sin_theta
203 | r11 = cos_theta + wy * wy * (k_one - cos_theta)
204 | r21 = wx * sin_theta + wy * wz * (k_one - cos_theta)
205 | r02 = wy * sin_theta + wx * wz * (k_one - cos_theta)
206 | r12 = -wx * sin_theta + wy * wz * (k_one - cos_theta)
207 | r22 = cos_theta + wz * wz * (k_one - cos_theta)
208 | rotation_matrix = torch.cat(
209 | [r00, r01, r02, r10, r11, r12, r20, r21, r22], dim=1)
210 | return rotation_matrix.view(-1, 3, 3)
211 |
212 | def _compute_rotation_matrix_taylor(angle_axis):
213 | rx, ry, rz = torch.chunk(angle_axis, 3, dim=1)
214 | k_one = torch.ones_like(rx)
215 | rotation_matrix = torch.cat(
216 | [k_one, -rz, ry, rz, k_one, -rx, -ry, rx, k_one], dim=1)
217 | return rotation_matrix.view(-1, 3, 3)
218 |
219 | # stolen from ceres/rotation.h
220 |
221 | _angle_axis = torch.unsqueeze(angle_axis, dim=1)
222 | theta2 = torch.matmul(_angle_axis, _angle_axis.transpose(1, 2))
223 | theta2 = torch.squeeze(theta2, dim=1)
224 |
225 | # compute rotation matrices
226 | rotation_matrix_normal = _compute_rotation_matrix(angle_axis, theta2)
227 | rotation_matrix_taylor = _compute_rotation_matrix_taylor(angle_axis)
228 |
229 | # create mask to handle both cases
230 | eps = 1e-6
231 | mask = (theta2 > eps).view(-1, 1, 1).to(theta2.device)
232 | mask_pos = (mask).type_as(theta2)
233 | mask_neg = (mask == False).type_as(theta2) # noqa
234 |
235 | # create output pose matrix
236 | batch_size = angle_axis.shape[0]
237 | rotation_matrix = torch.eye(4).to(angle_axis.device).type_as(angle_axis)
238 | rotation_matrix = rotation_matrix.view(1, 4, 4).repeat(batch_size, 1, 1)
239 | # fill output matrix with masked values
240 | rotation_matrix[..., :3, :3] = \
241 | mask_pos * rotation_matrix_normal + mask_neg * rotation_matrix_taylor
242 | return rotation_matrix # Nx4x4
243 |
244 |
245 | default_collate_err_msg_format = (
246 | "default_collate: batch must contain tensors, numpy arrays, numbers, "
247 | "dicts or lists; found {}")
248 |
249 | np_str_obj_array_pattern = re.compile(r'[SaUO]')
250 |
251 |
252 | def default_convert(data):
253 | r"""Converts each NumPy array data field into a tensor"""
254 | elem_type = type(data)
255 | if isinstance(data, torch.Tensor):
256 | return data
257 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
258 | and elem_type.__name__ != 'string_':
259 | # array of string classes and object
260 | if elem_type.__name__ == 'ndarray' \
261 | and np_str_obj_array_pattern.search(data.dtype.str) is not None:
262 | return data
263 | return torch.as_tensor(data)
264 | elif isinstance(data, container_abcs.Mapping):
265 | return {key: default_convert(data[key]) for key in data}
266 | elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple
267 | return elem_type(*(default_convert(d) for d in data))
268 | elif isinstance(data, container_abcs.Sequence) and not isinstance(
269 |             data, (str, bytes)):
270 | return [default_convert(d) for d in data]
271 | else:
272 | return data
273 |
274 |
275 | def custom_collate(batch):
276 | r"""Puts each data field into a tensor with outer dimension batch size"""
277 |
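    |     # Variant of torch's default_collate: tuple-valued fields (the 'varlen'
    |     # convention for variable-length data) are returned as plain lists of
    |     # their first element instead of being stacked into a tensor.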
278 | elem = batch[0]
279 | elem_type = type(elem)
280 | if isinstance(batch, list) and isinstance(elem, tuple):
281 | #data = torch.cat((x[0] for x in batch))
282 | return [x[0] for x in batch]
283 | if type(elem) == tuple and elem[1] == 'varlen':
284 | return [x[0] for x in batch]
285 |
286 | if isinstance(elem, torch.Tensor):
287 | out = None
288 | if torch.utils.data.get_worker_info() is not None:
289 | # If we're in a background process, concatenate directly into a
290 | # shared memory tensor to avoid an extra copy
291 | numel = sum([x.numel() for x in batch])
292 | storage = elem.storage()._new_shared(numel)
293 | out = elem.new(storage)
294 | try:
295 | return torch.stack(batch, 0, out=out)
296 | except:
297 | import ipdb
298 | ipdb.set_trace()
299 |
300 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
301 | and elem_type.__name__ != 'string_':
302 | elem = batch[0]
303 | if elem_type.__name__ == 'ndarray':
304 | # array of string classes and object
305 | if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
306 | raise TypeError(
307 | default_collate_err_msg_format.format(elem.dtype))
308 |
309 | return custom_collate([torch.as_tensor(b) for b in batch])
310 | elif elem.shape == (): # scalars
311 | return torch.as_tensor(batch)
312 | elif isinstance(elem, float):
313 | return torch.tensor(batch, dtype=torch.float64)
314 | #elif isinstance(elem, int_classes):
315 | elif isinstance(elem, int):
316 | return torch.tensor(batch)
317 | #elif isinstance(elem, string_classes):
318 | elif isinstance(elem, str):
319 | return batch
320 | elif isinstance(elem, container_abcs.Mapping):
321 | return {key: custom_collate([d[key] for d in batch]) for key in elem}
322 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
323 | return elem_type(*(custom_collate(samples) for samples in zip(*batch)))
324 | elif isinstance(elem, container_abcs.Sequence):
325 | transposed = zip(*batch)
326 | return [custom_collate(samples) for samples in transposed]
327 |
328 | raise TypeError(default_collate_err_msg_format.format(elem_type))
329 |
330 |
331 | _use_shared_memory = False
332 |
333 | np_str_obj_array_pattern = re.compile(r'[SaUO]')
334 |
335 | error_msg_fmt = "batch must contain tensors, numbers, dicts or lists; found {}"
336 |
337 | numpy_type_map = {
338 | 'float64': torch.DoubleTensor,
339 | 'float32': torch.FloatTensor,
340 | 'float16': torch.HalfTensor,
341 | 'int64': torch.LongTensor,
342 | 'int32': torch.IntTensor,
343 | 'int16': torch.ShortTensor,
344 | 'int8': torch.CharTensor,
345 | 'uint8': torch.ByteTensor,
346 | }
347 |
348 |
349 | def default_collatev1_1(batch):
350 | r"""Puts each data field into a tensor with outer dimension batch size"""
351 |
352 | elem = batch[0]
353 | elem_type = type(batch[0])
354 | if isinstance(batch, list) and isinstance(elem, tuple):
355 | #data = torch.cat((x[0] for x in batch))
356 | return [x[0] for x in batch]
357 | if isinstance(batch[0], torch.Tensor):
358 | out = None
359 | if _use_shared_memory:
360 | # If we're in a background process, concatenate directly into a
361 | # shared memory tensor to avoid an extra copy
362 | numel = sum([x.numel() for x in batch])
363 | storage = batch[0].storage()._new_shared(numel)
364 | out = batch[0].new(storage)
365 | try:
366 | return torch.stack(batch, 0, out=out)
367 | except:
368 | import ipdb
369 | ipdb.set_trace()
370 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
371 | and elem_type.__name__ != 'string_':
372 | elem = batch[0]
373 | try:
374 | if elem_type.__name__ == 'ndarray':
375 | # array of string classes and object
376 | if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
377 | raise TypeError(error_msg_fmt.format(elem.dtype))
378 |
379 | return default_collatev1_1(
380 | [torch.from_numpy(b) for b in batch])
381 | except:
382 | import ipdb
383 | ipdb.set_trace()
384 | if elem.shape == (): # scalars
385 | py_type = float if elem.dtype.name.startswith('float') else int
386 | return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
387 | elif isinstance(batch[0], float):
388 | return torch.tensor(batch, dtype=torch.float64)
389 | #elif isinstance(batch[0], int_classes):
390 | elif isinstance(batch[0], int):
391 | return torch.tensor(batch)
392 | #elif isinstance(batch[0], string_classes):
393 | elif isinstance(batch[0], str):
394 | return batch
395 | elif isinstance(batch[0], container_abcs.Mapping):
396 | return {
397 | key: default_collatev1_1([d[key] for d in batch])
398 | for key in batch[0]
399 | }
400 | elif isinstance(batch[0], tuple) and hasattr(batch[0],
401 | '_fields'): # namedtuple
402 | return type(batch[0])(*(default_collatev1_1(samples)
403 | for samples in zip(*batch)))
404 | elif isinstance(batch[0], container_abcs.Sequence):
405 | transposed = zip(*batch)
406 | return [default_collatev1_1(samples) for samples in transposed]
407 |
408 | raise TypeError((error_msg_fmt.format(type(batch[0]))))
409 |
410 |
411 | def backproject_depth_th(depth, inv_K, mask=False, device='cuda'):
412 | h, w = depth.shape
413 | idu, idv = np.meshgrid(range(w), range(h))
414 | grid = np.stack((idu.flatten(), idv.flatten(), np.ones([w * h])))
415 | grid = torch.from_numpy(grid).float().to(device)
416 | x = torch.matmul(inv_K[:3, :3], grid)
417 | x = x * depth.flatten()[None, :]
418 | x = x.t()
419 | if mask:
420 | x = x[depth.flatten() > 0]
421 | return x
422 |
423 |
424 | def backproject_depth(depth, inv_K, mask=False):
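    |     # Back-project a depth map to a point cloud in the camera frame: every
    |     # pixel (u, v) with depth d becomes d * K^{-1} [u, v, 1]^T, giving an
    |     # [H*W, 3] array (only points with d > 0 are kept when mask=True).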
425 | h, w = depth.shape
426 | idu, idv = np.meshgrid(range(w), range(h))
427 | grid = np.stack((idu.flatten(), idv.flatten(), np.ones([w * h])))
428 | x = np.matmul(inv_K[:3, :3], grid)
429 | x = x * depth.flatten()[None, :]
430 | x = x.T
431 | if mask:
432 | x = x[depth.flatten() > 0]
433 | return x
434 |
435 |
436 | def parameters_count(net, name, do_print=True):
437 | model_parameters = filter(lambda p: p.requires_grad, net.parameters())
438 | params = sum([np.prod(p.size()) for p in model_parameters])
439 | if do_print:
440 | print('#params %s: %.3f M' % (name, params / 1e6))
441 | return params
442 |
443 |
444 | def cuda_time():
445 | torch.cuda.synchronize()
446 | return time.time()
447 |
448 |
449 | def transform3x3(pc, T):
450 | # T: [4,4]
451 | # pc: [n, 3]
452 | # return: [n, 3]
453 | return (np.matmul(T[:3, :3], pc.T)).T
454 |
455 |
456 | def transform4x4(pc, T):
457 | # T: [4,4]
458 | # pc: [n, 3]
459 | # return: [n, 3]
460 | return (np.matmul(T[:3, :3], pc.T) + T[:3, 3:4]).T
461 |
462 |
463 | def transform4x4_th(pc, T):
464 | # T: [4,4]
465 | # pc: [n, 3]
466 | # return: [n, 3]
467 | return (torch.matmul(T[:3, :3], pc.t()) + T[:3, 3:4]).t()
468 |
469 |
470 | def v(var, cuda=True, volatile=False):
471 | if type(var) == torch.Tensor or type(var) == torch.DoubleTensor:
472 | res = Variable(var.float(), volatile=volatile)
473 | elif type(var) == np.ndarray:
474 | res = Variable(torch.from_numpy(var).float(), volatile=volatile)
475 | if cuda:
476 | res = res.cuda()
477 | return res
478 |
479 |
480 | def npy(var):
481 | return var.data.cpu().numpy()
482 |
483 |
484 | def worker_init_fn(worker_id):
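    |     # Re-seed numpy inside each DataLoader worker so that random
    |     # augmentations differ across workers instead of being duplicated.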
485 | np.random.seed(np.random.get_state()[1][0] + worker_id)
486 |
487 |
488 | def sec_to_hm(t):
489 | """Convert time in seconds to time in hours, minutes and seconds
490 | e.g. 10239 -> (2, 50, 39)
491 | """
492 | t = int(t)
493 | s = t % 60
494 | t //= 60
495 | m = t % 60
496 | t //= 60
497 | return t, m, s
498 |
499 |
500 | def sec_to_hm_str(t):
501 | """Convert time in seconds to a nice string
502 | e.g. 10239 -> '02h50m39s'
503 | """
504 | h, m, s = sec_to_hm(t)
505 | return "{:02d}h{:02d}m{:02d}s".format(h, m, s)
506 |
507 |
508 | def write_ply(fn, point, normal=None, color=None):
509 |
510 | ply = o3d.geometry.PointCloud()
511 | ply.points = o3d.utility.Vector3dVector(point)
512 | if color is not None:
513 | ply.colors = o3d.utility.Vector3dVector(color)
514 | if normal is not None:
515 | ply.normals = o3d.utility.Vector3dVector(normal)
516 | o3d.io.write_point_cloud(fn, ply)
517 |
518 |
519 | def skew(x):
520 | return np.array([[0, -x[2], x[1]], [x[2], 0, -x[0]], [-x[1], x[0], 0]])
521 |
522 |
523 | def Thres_metrics(pred, gt, mask, interval, thre):
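    |     # Fraction of valid pixels (mask == 1) whose absolute depth error,
    |     # measured in multiples of `interval`, falls below the threshold `thre`.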
524 | abs_diff = (pred - gt).abs() / interval
525 | metric = (mask * (abs_diff < thre).float()).sum() / mask.sum()
526 | return metric
527 |
528 |
529 | def Thres_metrics_np(pred, gt, mask, interval, thre):
530 | abs_diff = np.abs(pred - gt) / interval
531 | metric = (mask * (abs_diff < thre)).sum() / mask.sum()
532 | return metric
533 |
--------------------------------------------------------------------------------
/visual_ddad.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import sys
5 | import cv2
6 | import torch
7 | import torch.nn as nn
8 | from options import MVS2DOptions, EvalCfg
9 | import networks
10 | from torch.utils.data import DataLoader
11 | from datasets.DDAD import DDAD
12 | import torch.nn.functional as F
13 | from utils import *
14 |
15 | def resize_depth_preserve(depth, shape):
16 | """
17 | Resizes depth map preserving all valid depth pixels
18 | Multiple downsampled points can be assigned to the same pixel.
19 |
20 | Parameters
21 | ----------
22 | depth : np.array [h,w]
23 | Depth map
24 | shape : tuple (H,W)
25 | Output shape
26 |
27 | Returns
28 | -------
29 |     depth : np.array [1,H,W]
30 | Resized depth map
31 | """
32 |
33 | # Store dimensions and reshapes to single column
34 | depth = np.squeeze(depth)
35 | h, w = depth.shape
36 | x = depth.reshape(-1)
37 | # Create coordinate grid
38 | uv = np.mgrid[:h, :w].transpose(1, 2, 0).reshape(-1, 2)
39 | # Filters valid points
40 | idx = x > 0
41 | crd, val = uv[idx], x[idx]
42 | # Downsamples coordinates
43 | crd[:, 0] = (crd[:, 0] * (shape[0] / h)).astype(np.int32)
44 | crd[:, 1] = (crd[:, 1] * (shape[1] / w)).astype(np.int32)
45 | # Filters points inside image
46 | idx = (crd[:, 0] < shape[0]) & (crd[:, 1] < shape[1])
47 | crd, val = crd[idx], val[idx]
48 | # Creates downsampled depth image and assigns points
49 | depth = np.zeros(shape)
50 | depth[crd[:, 0], crd[:, 1]] = val
51 | # Return resized depth map
52 | return np.expand_dims(depth, axis=0)
53 |
54 |
55 | def homo_warping_depth(src_fea, src_proj, ref_proj, depth_values):
56 | # src_fea: [B, C, H, W]
57 | # src_proj: [B, 4, 4]
58 | # ref_proj: [B, 4, 4]
59 | # depth_values: [B, Ndepth, H, W]
60 | # out: [B, C, Ndepth, H, W]
61 | batch, channels = src_fea.shape[0], src_fea.shape[1]
62 | num_depth = depth_values.shape[1]
63 | #height, width = src_fea.shape[2], src_fea.shape[3]
64 | h_src, w_src = src_fea.shape[2], src_fea.shape[3]
65 | h_ref, w_ref = depth_values.shape[2], depth_values.shape[3]
66 |
67 | with torch.no_grad():
68 |
69 | proj = torch.matmul(src_proj, torch.inverse(ref_proj))
70 | rot = proj[:, :3, :3] # [B,3,3]
71 | trans = proj[:, :3, 3:4] # [B,3,1]
72 |
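    |         # Build the reference-view pixel grid, lift it to 3D with the given
    |         # depths, transform it into the source view and normalize the result
    |         # to the [-1, 1] coordinates expected by F.grid_sample.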
73 |
74 | y, x = torch.meshgrid([torch.arange(0, h_ref, dtype=torch.float32, device=src_fea.device),
75 | torch.arange(0, w_ref, dtype=torch.float32, device=src_fea.device)])
76 | y, x = y.contiguous(), x.contiguous()
77 | y, x = y.view(h_ref * w_ref), x.view(h_ref * w_ref)
78 | xyz = torch.stack((x, y, torch.ones_like(x))) # [3, H*W]
79 | xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1) # [B, 3, H*W]
80 |
81 | rot_xyz = torch.matmul(rot, xyz)
82 | rot_depth_xyz = rot_xyz * depth_values.view(batch, 1, -1)
83 |
84 | proj_xyz = rot_depth_xyz + trans.view(batch,3,1)
85 |
86 |         proj_xy = proj_xyz[:, :2, :] / proj_xyz[:, 2:3, :]  # [B, 2, H*W]
87 | z = proj_xyz[:, 2:3, :].view(batch, h_ref, w_ref)
88 | proj_x_normalized = proj_xy[:, 0, :] / ((w_src - 1) / 2.0) - 1
89 | proj_y_normalized = proj_xy[:, 1, :] / ((h_src - 1) / 2.0) - 1
90 |         X_mask = ((proj_x_normalized > 1) + (proj_x_normalized < -1)).detach()
91 |         proj_x_normalized[X_mask] = 2  # push out-of-range points fully outside so the warped value is not a blend of image and padding
92 |         Y_mask = ((proj_y_normalized > 1) + (proj_y_normalized < -1)).detach()
93 |         proj_y_normalized[Y_mask] = 2
94 |         proj_xy = torch.stack((proj_x_normalized, proj_y_normalized), dim=2)  # [B, H*W, 2]
95 | grid = proj_xy
96 | proj_mask = ((X_mask + Y_mask) > 0).view(batch, num_depth, h_ref, w_ref)
97 | proj_mask = (proj_mask + (z <= 0)) > 0
98 |
99 | warped_src_fea = F.grid_sample(src_fea, grid.view(batch, h_ref, w_ref, 2), mode='bilinear',
100 | padding_mode='zeros', align_corners=True)
101 |
102 | warped_src_fea = warped_src_fea.view(batch, channels, num_depth, h_ref, w_ref)
103 |
104 | #return warped_src_fea , proj_mask
105 | return warped_src_fea
106 |
107 |
108 | def to_gpu(inputs, keys=None):
109 |     if keys is None:
110 | keys = inputs.keys()
111 | for key in keys:
112 | if key not in inputs:
113 | continue
114 | ipt = inputs[key]
115 | if type(ipt) == torch.Tensor:
116 | inputs[key] = ipt.cuda()
117 | elif type(ipt) == list and type(ipt[0]) == torch.Tensor:
118 | inputs[key] = [
119 | x.cuda() for x in ipt
120 | ]
121 | elif type(ipt) == dict:
122 | for k in ipt.keys():
123 | if type(ipt[k]) == torch.Tensor:
124 | ipt[k] = ipt[k].cuda()
125 |
126 |
127 | options = MVS2DOptions()
128 | opts = options.parse()
129 |
130 | # opts.width = int(640)
131 | # opts.height = int(480)
132 | dataset = DDAD(opts, False)
133 | data_loader = DataLoader(dataset,
134 | 1,
135 | shuffle=False,
136 | num_workers=1,
137 | pin_memory=True,
138 | drop_last=False,
139 | sampler=None)
140 | model = networks.MVS2D(opt=opts).cuda()
141 | pretrained_dict = torch.load("/home/cjd/MVS2D/log/AFNet/models/weights_latest/model.pth")
142 | model_dict = model.state_dict()
143 | pretrained_dict = {
144 | k: v
145 | for k, v in pretrained_dict.items() if k in model_dict
146 | }
147 | model_dict.update(pretrained_dict)
148 | model.load_state_dict(model_dict)
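    | # Only keys present in both the checkpoint and the current model are loaded;
    | # any remaining parameters keep their freshly initialized values.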
149 |
150 | model.eval()
151 | root_path = '/data/cjd/AFnet/visual/ddad/'
152 | with torch.no_grad():
153 | for batch_idx, inputs in enumerate(data_loader):
154 | print(batch_idx)
155 | to_gpu(inputs)
156 |
157 | imgs, proj_mats, pose_mats = [], [], []
158 | for i in range(inputs['num_frame'][0].item()):
159 | imgs.append(inputs[('color', i, 0)])
160 | proj_mats.append(inputs[('proj', i)])
161 | pose_mats.append(inputs[('pose', i)])
162 |
163 | pose_mats[0] = pose_mats[0]*0.75
164 | pose_mats[1] = pose_mats[1]*0.75
165 | pose_mats[2] = pose_mats[2]*0.75
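    |         # Scaling the full 4x4 pose matrices by 0.75 perturbs the poses fed
    |         # to the network; presumably this mimics the scaled / noisy-pose
    |         # setting used for the qualitative robustness comparison.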
166 |
167 | outputs = model(imgs[0], imgs[1:], pose_mats[0], pose_mats[1:], inputs[('inv_K_pool', 0)])
168 |
169 | depth_gt = inputs[("depth_gt", 0, 0)][0].cpu().detach().numpy().squeeze()
170 | depth_gt = resize_depth_preserve(depth_gt, (608,960))
171 | depth_gt_path = os.path.join(root_path,'depth_gt','{}.png'.format(batch_idx))
172 | depth_gt_np = gray_2_colormap_np_2(depth_gt ,max = 120)[:,:,::-1]
173 |
174 | img0 = imgs[0]
175 |
176 |
177 | depth_pred = outputs[('depth_pred', 0)][0]
178 | depth_pred_2 = outputs[('depth_pred_2', 0)][0]
179 |
180 | depth_pred_np = gray_2_colormap_np(depth_pred ,max = 120)[:,:,::-1]
181 | depth_pred_2_np = gray_2_colormap_np(depth_pred_2 ,max = 120)[:,:,::-1]
182 | img0_path = os.path.join(root_path,'img0', '{}.png'.format(batch_idx))
183 | depth_1_path = os.path.join(root_path,'depth_1','{}.png'.format(batch_idx))
184 | depth_2_path = os.path.join(root_path,'depth_2', '{}.png'.format(batch_idx))
185 |
186 |
187 | img0_np = img0[0].cpu().detach().numpy().squeeze().transpose(1,2,0)
188 | img0_np = (img0_np / img0_np.max() * 255).astype(np.uint8)
189 | cv2.imwrite(img0_path, img0_np)
190 | cv2.imwrite(depth_1_path, depth_pred_np)
191 | cv2.imwrite(depth_2_path, depth_pred_2_np)
192 | cv2.imwrite(depth_gt_path, depth_gt_np)
193 |
194 |
195 |
196 |
197 |
198 |
199 | # a = input('input some')
200 | # print(a)
201 |
202 | # break
203 |
--------------------------------------------------------------------------------