├── LICENSE ├── README.md ├── assets ├── pointcloud2.png └── visual_compare.png ├── configs ├── DDAD.conf ├── base.conf └── kitti.conf ├── data_split ├── kitti_eigen_test.txt ├── kitti_eigen_train.txt └── kitti_eigen_val.txt ├── datasets ├── DDAD.py ├── DDAD_crop.py ├── DDAD_forward.py ├── __pycache__ │ ├── DDAD.cpython-37.pyc │ ├── DDAD_crop.cpython-37.pyc │ ├── DDAD_forward.cpython-37.pyc │ ├── kitti.cpython-37.pyc │ ├── kitti.cpython-39.pyc │ ├── kitti_odometry.cpython-37.pyc │ └── kitti_odometry.cpython-39.pyc ├── kitti.py └── kitti_odometry.py ├── eval_ddad.py ├── eval_kitti.py ├── generate_dynamic_mask.py ├── hybrid_evaluate_depth.py ├── networks ├── AFNet.py ├── __init__.py ├── __pycache__ │ ├── AFNet.cpython-37.pyc │ ├── AFNet.cpython-39.pyc │ ├── AFNet_efficient.cpython-37.pyc │ ├── AFNet_main.cpython-37.pyc │ ├── AFNet_mobile.cpython-37.pyc │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-39.pyc │ ├── module.cpython-37.pyc │ ├── module.cpython-39.pyc │ └── mvs2d.cpython-37.pyc ├── mobilenet.py └── module.py ├── options.py ├── requirements.txt ├── scripts ├── test.sh └── train.sh ├── train_af.py ├── train_kitti.py ├── trainer_base_af.py ├── trainer_base_kitti.py ├── utils.py └── visual_ddad.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 cjd24-coder 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |

AFNet: Adaptive Fusion of Single-View and Multi-View Depth for Autonomous Driving

3 |

**CVPR 2024**

4 | 5 | [Paper](https://arxiv.org/pdf/2403.07535.pdf) 6 |

7 | 8 | This work presents AFNet, a new multi-view and single-view depth fusion network that alleviates the defects of existing multi-view methods, which fail under noisy poses in real-world autonomous driving scenarios. 9 | 10 | ![teaser](assets/pointcloud2.png) 11 | 12 | 13 | ## ✏️ Changelog 14 | ### Mar. 20 2024 15 | * Initial release. Due to a confidentiality agreement, the accuracy of the currently reproduced model on KITTI differs very slightly from that reported in the paper. We release this initial version first; the final version will be released soon. 16 | 17 | * In addition, the models trained under noisy poses will be released soon. 18 | 19 | 20 | ## ⚙️ Installation 21 | 22 | The code is tested with CUDA 11.7. Please use the following commands to install dependencies: 23 | 24 | ``` 25 | conda create --name AFNet python=3.7 26 | conda activate AFNet 27 | pip install -r requirements.txt 28 | ``` 29 | 30 | ## 🎬 Demo 31 | ![teaser](assets/visual_compare.png) 32 | 33 | 34 | ## ⏳ Training & Testing 35 | 36 | We use 4 Nvidia 3090 GPUs for training. You may need to modify 'CUDA_VISIBLE_DEVICES' and the batch size to fit your GPU resources. 37 | 38 | #### Training 39 | First download and extract the DDAD and KITTI data and splits. You should download and process the DDAD dataset following [DDAD🔗](https://github.com/TRI-ML/DDAD). 40 | #### Download 41 | [__split__ 🔗](https://1drv.ms/u/s!AtFfCZ2Ckf3DghYrBvQ-DWCQR1Nd?e=Q6qz8d) (you need to move the json file in this split archive to the data_split path; see the sketch below) 42 | [ models 🔗](https://1drv.ms/u/s!AtFfCZ2Ckf3DghVXXZY611mqxa8B?e=nZ7taR) (models for testing) 43 | 44 | Then run the following command to train our model: 45 | ``` 46 | bash scripts/train.sh 47 | ``` 48 | 49 | #### Testing 50 | First download and extract the data, splits and pretrained models. 
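A minimal preparation sketch, assuming the split download contains `DDAD_video.json` (the filename read by `datasets/DDAD.py`); the archive name below is a placeholder for whatever you actually downloaded:

```
# hypothetical archive name -- adjust to the file you downloaded
unzip split.zip -d split/
mv split/DDAD_video.json data_split/

# point data_path and log_dir in configs/DDAD.conf / configs/kitti.conf
# at your local dataset and output directories before training or testing
```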
51 | 52 | ### DDAD: 53 | run: 54 | ``` 55 | python eval_ddad.py --cfg "./configs/DDAD.conf" 56 | ``` 57 | 58 | You should get results similar to these: 59 | 60 | | abs_rel | sq_rel | log10 | rmse | rmse_log | a1 | a2 | a3 | abs_diff | 61 | |---------|--------|-------|-------|----------|-------|-------|-------|----------| 62 | | 0.088 | 0.979 | 0.035 | 4.60 | 0.154 | 0.917 | 0.972 | 0.987 | 2.042 | 63 | 64 | ### KITTI: 65 | run: 66 | ``` 67 | python eval_kitti.py --cfg "./configs/kitti.conf" 68 | ``` 69 | You should get results similar to these: 70 | 71 | | abs_rel | sq_rel | log10 | rmse | rmse_log | a1 | a2 | a3 | abs_diff | 72 | |---------|--------|-------|-------|----------|-------|-------|-------|----------| 73 | | 0.044 | 0.132 | 0.019 | 1.712 | 0.069 | 0.980 | 0.997 | 0.999 | 0.804 | 74 | 75 | 76 | #### Acknowledgement 77 | Thanks to Zhenpei Yang for open-sourcing his excellent work [MVS2D](https://github.com/zhenpeiyang/MVS2D?tab=readme-ov-file#nov-27-2021) 78 | 79 | ## Citation 80 | 81 | If you find this project useful, please consider citing: 82 | 83 | ```bibtex 84 | @misc{cheng2024adaptive, 85 | title={Adaptive Fusion of Single-View and Multi-View Depth for Autonomous Driving}, 86 | author={JunDa Cheng and Wei Yin and Kaixuan Wang and Xiaozhi Chen and Shijie Wang and Xin Yang}, 87 | year={2024}, 88 | eprint={2403.07535}, 89 | archivePrefix={arXiv}, 90 | primaryClass={cs.CV} 91 | } 92 | ``` 93 | 94 | 95 | -------------------------------------------------------------------------------- /assets/pointcloud2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junda24/AFNet/1d0416a2ec740519a8c464461962bebc238338bc/assets/pointcloud2.png -------------------------------------------------------------------------------- /assets/visual_compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junda24/AFNet/1d0416a2ec740519a8c464461962bebc238338bc/assets/visual_compare.png -------------------------------------------------------------------------------- /configs/DDAD.conf: -------------------------------------------------------------------------------- 1 | include "./base.conf" 2 | dataset = DDAD 3 | data_path = /data/cjd/ddad/my_ddad/ 4 | height = 448 5 | width = 768 6 | eval_height = 1216 7 | eval_width = 1920 8 | min_depth = 1.0 9 | max_depth = 80.0 10 | EVAL_MIN_DEPTH = 1.0 11 | EVAL_MAX_DEPTH = 80.0 12 | use_test = 0 13 | num_frame = 3 14 | num_frame_test = 3 15 | perturb_pose = 0 16 | LR = 1e-4 17 | fullsize_eval = 0 18 | DECAY_STEP_LIST = [10, 20] 19 | pred_conf = 1 20 | use_skip = 1 21 | num_epochs = 30 22 | loss_d = L1 23 | batch_size = 4 24 | nlabel = 96 25 | max_conf = 0.06 26 | min_conf = 2e-4 27 | multiprocessing_distributed = 1 28 | monitor_key = [abs_rel, thre1] 29 | monitor_goal = [minimize, maximize] 30 | num_workers = 16 31 | log_dir = /data/cjd/AFnet/ 32 | model_name = AFNet_mobile -------------------------------------------------------------------------------- /configs/base.conf: -------------------------------------------------------------------------------- 1 | data_path = ./data/ScanNet 2 | log_dir = experiments/test/ 3 | model_name = test 4 | overwrite = True 5 | note = "" 6 | mode = train 7 | 8 | ## TRAINING OPTIONS 9 | num_workers = 4 10 | DECAY_STEP_LIST = [10, 20] 11 | LR = 2e-4 12 | LR_DECAY = 0.1 13 | LR_CLIP = 1e-6 14 | WEIGHT_DECAY = 0 15 | MOMENTUM = 0.9 16 | GRAD_NORM_CLIP = 10.0 17 | loss_d = L1 18 | launcher = pytorch 19 | 
tcp_port = 18891 20 | multiprocessing_distributed = 0 21 | dist_backend = nccl 22 | dist_url = tcp://127.0.0.1:1234 23 | num_epochs = 15 24 | monitor_key = "" 25 | optimizer = adam 26 | epoch_size = -1 27 | log_frequency = 10 28 | val_frequency = 100 29 | save_frequency = 1 30 | save_prediction = 0 31 | eval_vis = 0 32 | num_epoch = 30 33 | batch_size = 2 34 | 35 | ## MODEL OPTIONS 36 | pred_conf = 0 37 | use_skip = 1 38 | nlabel = 64 39 | mono = 0 40 | multi_view_agg = 1 41 | depth_embedding = learned 42 | robust = False 43 | unet_channel_mode = v0 44 | use_unet = True 45 | num_depth_regressor_anchor = 512 46 | att_rate = 4 47 | inv_depth = 1 48 | resnet_layers = [2,2,2,2] 49 | max_conf = -1 50 | clip_conf = True 51 | model_size = default 52 | fast = 0 53 | depth_head = default 54 | feat_dim = 32 55 | nhead = 1 56 | 57 | ## DATASET OPTIONS 58 | height = 640 59 | width = 480 60 | dataset = ScanNet 61 | input_scale = 0 62 | output_scale = 2 63 | min_depth = 0.1 64 | max_depth = 10.0 65 | seed = 0 66 | use_test = 0 67 | num_frame = 3 68 | perturb_pose = 0 69 | sceneID = None 70 | filter = None 71 | fullsize_eval = 0 72 | random_lighting = 0 73 | disable_color_aug = 0 74 | 75 | 76 | ## EVAL OPTIONS 77 | conf_thres = 3.5 78 | num_consistent = 1 79 | geo_pixel_thres = 0.75 80 | geo_depth_thres = 0.0001 81 | att_thres = 0.25 82 | EVAL_MIN_DEPTH = 0.5 83 | EVAL_MAX_DEPTH = 10.0 84 | disable_median_scaling = 1 85 | num_frame_test = 5 86 | val_epoch_size = -1 87 | 88 | benchmark_fps = False 89 | -------------------------------------------------------------------------------- /configs/kitti.conf: -------------------------------------------------------------------------------- 1 | include "./base.conf" 2 | dataset = kitti 3 | data_path = /data/cjd/kitti_raw/ 4 | height = 352 5 | width = 1216 6 | eval_height = 352 7 | eval_width = 1216 8 | min_depth = 1.0 9 | max_depth = 60.0 10 | EVAL_MIN_DEPTH = 1.0 11 | EVAL_MAX_DEPTH = 60.0 12 | use_test = 0 13 | num_frame = 3 14 | num_frame_test = 3 15 | perturb_pose = 0 16 | LR = 1e-4 17 | fullsize_eval = 0 18 | DECAY_STEP_LIST = [10, 20] 19 | pred_conf = 1 20 | use_skip = 1 21 | num_epochs = 50 22 | loss_d = L1 23 | batch_size = 3 24 | nlabel = 96 25 | max_conf = 0.06 26 | min_conf = 2e-4 27 | multiprocessing_distributed = 1 28 | monitor_key = [abs_rel, thre1] 29 | monitor_goal = [minimize, maximize] 30 | num_workers = 16 31 | log_dir = /data/cjd/AFnet/ 32 | model_name = AF_kitti3 -------------------------------------------------------------------------------- /data_split/kitti_eigen_test.txt: -------------------------------------------------------------------------------- 1 | 2011_09_26 0002 val 69 2 | 2011_09_26 0002 val 54 3 | 2011_09_26 0002 val 42 4 | 2011_09_26 0002 val 57 5 | 2011_09_26 0002 val 30 6 | 2011_09_26 0002 val 27 7 | 2011_09_26 0002 val 12 8 | 2011_09_26 0002 val 36 9 | 2011_09_26 0002 val 33 10 | 2011_09_26 0002 val 15 11 | 2011_09_26 0002 val 39 12 | 2011_09_26 0002 val 9 13 | 2011_09_26 0002 val 51 14 | 2011_09_26 0002 val 60 15 | 2011_09_26 0002 val 21 16 | 2011_09_26 0002 val 24 17 | 2011_09_26 0002 val 45 18 | 2011_09_26 0002 val 18 19 | 2011_09_26 0002 val 48 20 | 2011_09_26 0002 val 6 21 | 2011_09_26 0002 val 63 22 | 2011_09_26 0009 train 16 23 | 2011_09_26 0009 train 32 24 | 2011_09_26 0009 train 48 25 | 2011_09_26 0009 train 64 26 | 2011_09_26 0009 train 80 27 | 2011_09_26 0009 train 96 28 | 2011_09_26 0009 train 112 29 | 2011_09_26 0009 train 128 30 | 2011_09_26 0009 train 144 31 | 2011_09_26 0009 train 160 32 | 2011_09_26 0009 
train 176 33 | 2011_09_26 0009 train 196 34 | 2011_09_26 0009 train 212 35 | 2011_09_26 0009 train 228 36 | 2011_09_26 0009 train 244 37 | 2011_09_26 0009 train 260 38 | 2011_09_26 0009 train 276 39 | 2011_09_26 0009 train 292 40 | 2011_09_26 0009 train 308 41 | 2011_09_26 0009 train 324 42 | 2011_09_26 0009 train 340 43 | 2011_09_26 0009 train 356 44 | 2011_09_26 0009 train 372 45 | 2011_09_26 0009 train 388 46 | 2011_09_26 0013 val 90 47 | 2011_09_26 0013 val 50 48 | 2011_09_26 0013 val 110 49 | 2011_09_26 0013 val 115 50 | 2011_09_26 0013 val 60 51 | 2011_09_26 0013 val 105 52 | 2011_09_26 0013 val 125 53 | 2011_09_26 0013 val 20 54 | 2011_09_26 0013 val 85 55 | 2011_09_26 0013 val 70 56 | 2011_09_26 0013 val 80 57 | 2011_09_26 0013 val 65 58 | 2011_09_26 0013 val 95 59 | 2011_09_26 0013 val 130 60 | 2011_09_26 0013 val 100 61 | 2011_09_26 0013 val 10 62 | 2011_09_26 0013 val 30 63 | 2011_09_26 0013 val 135 64 | 2011_09_26 0013 val 40 65 | 2011_09_26 0013 val 5 66 | 2011_09_26 0013 val 120 67 | 2011_09_26 0013 val 45 68 | 2011_09_26 0013 val 35 69 | 2011_09_26 0020 val 69 70 | 2011_09_26 0020 val 57 71 | 2011_09_26 0020 val 12 72 | 2011_09_26 0020 val 72 73 | 2011_09_26 0020 val 18 74 | 2011_09_26 0020 val 63 75 | 2011_09_26 0020 val 15 76 | 2011_09_26 0020 val 66 77 | 2011_09_26 0020 val 6 78 | 2011_09_26 0020 val 48 79 | 2011_09_26 0020 val 60 80 | 2011_09_26 0020 val 9 81 | 2011_09_26 0020 val 33 82 | 2011_09_26 0020 val 21 83 | 2011_09_26 0020 val 75 84 | 2011_09_26 0020 val 27 85 | 2011_09_26 0020 val 45 86 | 2011_09_26 0020 val 78 87 | 2011_09_26 0020 val 36 88 | 2011_09_26 0020 val 51 89 | 2011_09_26 0020 val 54 90 | 2011_09_26 0020 val 42 91 | 2011_09_26 0023 val 18 92 | 2011_09_26 0023 val 90 93 | 2011_09_26 0023 val 126 94 | 2011_09_26 0023 val 378 95 | 2011_09_26 0023 val 36 96 | 2011_09_26 0023 val 288 97 | 2011_09_26 0023 val 198 98 | 2011_09_26 0023 val 450 99 | 2011_09_26 0023 val 144 100 | 2011_09_26 0023 val 72 101 | 2011_09_26 0023 val 252 102 | 2011_09_26 0023 val 180 103 | 2011_09_26 0023 val 432 104 | 2011_09_26 0023 val 396 105 | 2011_09_26 0023 val 54 106 | 2011_09_26 0023 val 468 107 | 2011_09_26 0023 val 306 108 | 2011_09_26 0023 val 108 109 | 2011_09_26 0023 val 162 110 | 2011_09_26 0023 val 342 111 | 2011_09_26 0023 val 270 112 | 2011_09_26 0023 val 414 113 | 2011_09_26 0023 val 216 114 | 2011_09_26 0023 val 360 115 | 2011_09_26 0023 val 324 116 | 2011_09_26 0027 train 77 117 | 2011_09_26 0027 train 35 118 | 2011_09_26 0027 train 91 119 | 2011_09_26 0027 train 112 120 | 2011_09_26 0027 train 7 121 | 2011_09_26 0027 train 175 122 | 2011_09_26 0027 train 42 123 | 2011_09_26 0027 train 98 124 | 2011_09_26 0027 train 133 125 | 2011_09_26 0027 train 161 126 | 2011_09_26 0027 train 14 127 | 2011_09_26 0027 train 126 128 | 2011_09_26 0027 train 168 129 | 2011_09_26 0027 train 70 130 | 2011_09_26 0027 train 84 131 | 2011_09_26 0027 train 140 132 | 2011_09_26 0027 train 49 133 | 2011_09_26 0027 train 182 134 | 2011_09_26 0027 train 147 135 | 2011_09_26 0027 train 56 136 | 2011_09_26 0027 train 63 137 | 2011_09_26 0027 train 21 138 | 2011_09_26 0027 train 119 139 | 2011_09_26 0027 train 28 140 | 2011_09_26 0029 train 380 141 | 2011_09_26 0029 train 394 142 | 2011_09_26 0029 train 324 143 | 2011_09_26 0029 train 268 144 | 2011_09_26 0029 train 366 145 | 2011_09_26 0029 train 296 146 | 2011_09_26 0029 train 14 147 | 2011_09_26 0029 train 28 148 | 2011_09_26 0029 train 182 149 | 2011_09_26 0029 train 168 150 | 2011_09_26 0029 train 196 151 | 2011_09_26 0029 train 140 152 | 
2011_09_26 0029 train 84 153 | 2011_09_26 0029 train 56 154 | 2011_09_26 0029 train 112 155 | 2011_09_26 0029 train 352 156 | 2011_09_26 0029 train 126 157 | 2011_09_26 0029 train 70 158 | 2011_09_26 0029 train 310 159 | 2011_09_26 0029 train 154 160 | 2011_09_26 0029 train 98 161 | 2011_09_26 0029 train 408 162 | 2011_09_26 0029 train 42 163 | 2011_09_26 0029 train 338 164 | 2011_09_26 0036 val 128 165 | 2011_09_26 0036 val 192 166 | 2011_09_26 0036 val 32 167 | 2011_09_26 0036 val 352 168 | 2011_09_26 0036 val 608 169 | 2011_09_26 0036 val 224 170 | 2011_09_26 0036 val 576 171 | 2011_09_26 0036 val 672 172 | 2011_09_26 0036 val 64 173 | 2011_09_26 0036 val 448 174 | 2011_09_26 0036 val 704 175 | 2011_09_26 0036 val 640 176 | 2011_09_26 0036 val 512 177 | 2011_09_26 0036 val 768 178 | 2011_09_26 0036 val 160 179 | 2011_09_26 0036 val 416 180 | 2011_09_26 0036 val 480 181 | 2011_09_26 0036 val 288 182 | 2011_09_26 0036 val 544 183 | 2011_09_26 0036 val 96 184 | 2011_09_26 0036 val 384 185 | 2011_09_26 0036 val 256 186 | 2011_09_26 0036 val 320 187 | 2011_09_26 0046 train 5 188 | 2011_09_26 0046 train 10 189 | 2011_09_26 0046 train 15 190 | 2011_09_26 0046 train 20 191 | 2011_09_26 0046 train 25 192 | 2011_09_26 0046 train 30 193 | 2011_09_26 0046 train 35 194 | 2011_09_26 0046 train 40 195 | 2011_09_26 0046 train 45 196 | 2011_09_26 0046 train 50 197 | 2011_09_26 0046 train 55 198 | 2011_09_26 0046 train 60 199 | 2011_09_26 0046 train 65 200 | 2011_09_26 0046 train 70 201 | 2011_09_26 0046 train 75 202 | 2011_09_26 0046 train 80 203 | 2011_09_26 0046 train 85 204 | 2011_09_26 0046 train 90 205 | 2011_09_26 0046 train 95 206 | 2011_09_26 0046 train 100 207 | 2011_09_26 0046 train 105 208 | 2011_09_26 0046 train 110 209 | 2011_09_26 0046 train 115 210 | 2011_09_26 0048 train 5 211 | 2011_09_26 0048 train 6 212 | 2011_09_26 0048 train 7 213 | 2011_09_26 0048 train 8 214 | 2011_09_26 0048 train 9 215 | 2011_09_26 0048 train 10 216 | 2011_09_26 0048 train 11 217 | 2011_09_26 0048 train 12 218 | 2011_09_26 0048 train 13 219 | 2011_09_26 0048 train 14 220 | 2011_09_26 0048 train 15 221 | 2011_09_26 0048 train 16 222 | 2011_09_26 0052 train 46 223 | 2011_09_26 0052 train 14 224 | 2011_09_26 0052 train 36 225 | 2011_09_26 0052 train 28 226 | 2011_09_26 0052 train 26 227 | 2011_09_26 0052 train 50 228 | 2011_09_26 0052 train 40 229 | 2011_09_26 0052 train 8 230 | 2011_09_26 0052 train 16 231 | 2011_09_26 0052 train 44 232 | 2011_09_26 0052 train 18 233 | 2011_09_26 0052 train 32 234 | 2011_09_26 0052 train 42 235 | 2011_09_26 0052 train 10 236 | 2011_09_26 0052 train 20 237 | 2011_09_26 0052 train 48 238 | 2011_09_26 0052 train 52 239 | 2011_09_26 0052 train 6 240 | 2011_09_26 0052 train 30 241 | 2011_09_26 0052 train 12 242 | 2011_09_26 0052 train 38 243 | 2011_09_26 0052 train 22 244 | 2011_09_26 0056 train 11 245 | 2011_09_26 0056 train 33 246 | 2011_09_26 0056 train 242 247 | 2011_09_26 0056 train 253 248 | 2011_09_26 0056 train 286 249 | 2011_09_26 0056 train 154 250 | 2011_09_26 0056 train 99 251 | 2011_09_26 0056 train 220 252 | 2011_09_26 0056 train 22 253 | 2011_09_26 0056 train 77 254 | 2011_09_26 0056 train 187 255 | 2011_09_26 0056 train 143 256 | 2011_09_26 0056 train 66 257 | 2011_09_26 0056 train 176 258 | 2011_09_26 0056 train 110 259 | 2011_09_26 0056 train 275 260 | 2011_09_26 0056 train 264 261 | 2011_09_26 0056 train 198 262 | 2011_09_26 0056 train 55 263 | 2011_09_26 0056 train 88 264 | 2011_09_26 0056 train 121 265 | 2011_09_26 0056 train 209 266 | 2011_09_26 0056 train 165 267 
| 2011_09_26 0056 train 231 268 | 2011_09_26 0056 train 44 269 | 2011_09_26 0059 train 56 270 | 2011_09_26 0059 train 344 271 | 2011_09_26 0059 train 358 272 | 2011_09_26 0059 train 316 273 | 2011_09_26 0059 train 238 274 | 2011_09_26 0059 train 98 275 | 2011_09_26 0059 train 112 276 | 2011_09_26 0059 train 28 277 | 2011_09_26 0059 train 14 278 | 2011_09_26 0059 train 330 279 | 2011_09_26 0059 train 154 280 | 2011_09_26 0059 train 42 281 | 2011_09_26 0059 train 302 282 | 2011_09_26 0059 train 182 283 | 2011_09_26 0059 train 288 284 | 2011_09_26 0059 train 140 285 | 2011_09_26 0059 train 274 286 | 2011_09_26 0059 train 224 287 | 2011_09_26 0059 train 196 288 | 2011_09_26 0059 train 126 289 | 2011_09_26 0059 train 84 290 | 2011_09_26 0059 train 210 291 | 2011_09_26 0059 train 70 292 | 2011_09_26 0064 train 528 293 | 2011_09_26 0064 train 308 294 | 2011_09_26 0064 train 44 295 | 2011_09_26 0064 train 352 296 | 2011_09_26 0064 train 66 297 | 2011_09_26 0064 train 506 298 | 2011_09_26 0064 train 176 299 | 2011_09_26 0064 train 22 300 | 2011_09_26 0064 train 242 301 | 2011_09_26 0064 train 462 302 | 2011_09_26 0064 train 418 303 | 2011_09_26 0064 train 110 304 | 2011_09_26 0064 train 440 305 | 2011_09_26 0064 train 396 306 | 2011_09_26 0064 train 154 307 | 2011_09_26 0064 train 374 308 | 2011_09_26 0064 train 88 309 | 2011_09_26 0064 train 286 310 | 2011_09_26 0064 train 550 311 | 2011_09_26 0064 train 264 312 | 2011_09_26 0064 train 220 313 | 2011_09_26 0064 train 330 314 | 2011_09_26 0064 train 484 315 | 2011_09_26 0064 train 198 316 | 2011_09_26 0084 train 283 317 | 2011_09_26 0084 train 361 318 | 2011_09_26 0084 train 270 319 | 2011_09_26 0084 train 127 320 | 2011_09_26 0084 train 205 321 | 2011_09_26 0084 train 218 322 | 2011_09_26 0084 train 153 323 | 2011_09_26 0084 train 335 324 | 2011_09_26 0084 train 192 325 | 2011_09_26 0084 train 348 326 | 2011_09_26 0084 train 101 327 | 2011_09_26 0084 train 49 328 | 2011_09_26 0084 train 179 329 | 2011_09_26 0084 train 140 330 | 2011_09_26 0084 train 374 331 | 2011_09_26 0084 train 322 332 | 2011_09_26 0084 train 309 333 | 2011_09_26 0084 train 244 334 | 2011_09_26 0084 train 62 335 | 2011_09_26 0084 train 257 336 | 2011_09_26 0084 train 88 337 | 2011_09_26 0084 train 114 338 | 2011_09_26 0084 train 75 339 | 2011_09_26 0084 train 296 340 | 2011_09_26 0084 train 231 341 | 2011_09_26 0086 train 7 342 | 2011_09_26 0086 train 196 343 | 2011_09_26 0086 train 439 344 | 2011_09_26 0086 train 169 345 | 2011_09_26 0086 train 115 346 | 2011_09_26 0086 train 34 347 | 2011_09_26 0086 train 304 348 | 2011_09_26 0086 train 331 349 | 2011_09_26 0086 train 277 350 | 2011_09_26 0086 train 520 351 | 2011_09_26 0086 train 682 352 | 2011_09_26 0086 train 628 353 | 2011_09_26 0086 train 88 354 | 2011_09_26 0086 train 601 355 | 2011_09_26 0086 train 574 356 | 2011_09_26 0086 train 223 357 | 2011_09_26 0086 train 655 358 | 2011_09_26 0086 train 358 359 | 2011_09_26 0086 train 412 360 | 2011_09_26 0086 train 142 361 | 2011_09_26 0086 train 385 362 | 2011_09_26 0086 train 61 363 | 2011_09_26 0086 train 493 364 | 2011_09_26 0086 train 466 365 | 2011_09_26 0086 train 250 366 | 2011_09_26 0093 train 16 367 | 2011_09_26 0093 train 32 368 | 2011_09_26 0093 train 48 369 | 2011_09_26 0093 train 64 370 | 2011_09_26 0093 train 80 371 | 2011_09_26 0093 train 96 372 | 2011_09_26 0093 train 112 373 | 2011_09_26 0093 train 128 374 | 2011_09_26 0093 train 144 375 | 2011_09_26 0093 train 160 376 | 2011_09_26 0093 train 176 377 | 2011_09_26 0093 train 192 378 | 2011_09_26 0093 train 208 
379 | 2011_09_26 0093 train 224 380 | 2011_09_26 0093 train 240 381 | 2011_09_26 0093 train 256 382 | 2011_09_26 0093 train 305 383 | 2011_09_26 0093 train 321 384 | 2011_09_26 0093 train 337 385 | 2011_09_26 0093 train 353 386 | 2011_09_26 0093 train 369 387 | 2011_09_26 0093 train 385 388 | 2011_09_26 0093 train 401 389 | 2011_09_26 0093 train 417 390 | 2011_09_26 0096 train 19 391 | 2011_09_26 0096 train 38 392 | 2011_09_26 0096 train 57 393 | 2011_09_26 0096 train 76 394 | 2011_09_26 0096 train 95 395 | 2011_09_26 0096 train 114 396 | 2011_09_26 0096 train 133 397 | 2011_09_26 0096 train 152 398 | 2011_09_26 0096 train 171 399 | 2011_09_26 0096 train 190 400 | 2011_09_26 0096 train 209 401 | 2011_09_26 0096 train 228 402 | 2011_09_26 0096 train 247 403 | 2011_09_26 0096 train 266 404 | 2011_09_26 0096 train 285 405 | 2011_09_26 0096 train 304 406 | 2011_09_26 0096 train 323 407 | 2011_09_26 0096 train 342 408 | 2011_09_26 0096 train 361 409 | 2011_09_26 0096 train 380 410 | 2011_09_26 0096 train 399 411 | 2011_09_26 0096 train 418 412 | 2011_09_26 0096 train 437 413 | 2011_09_26 0096 train 456 414 | 2011_09_26 0101 train 692 415 | 2011_09_26 0101 train 930 416 | 2011_09_26 0101 train 760 417 | 2011_09_26 0101 train 896 418 | 2011_09_26 0101 train 284 419 | 2011_09_26 0101 train 148 420 | 2011_09_26 0101 train 522 421 | 2011_09_26 0101 train 794 422 | 2011_09_26 0101 train 624 423 | 2011_09_26 0101 train 726 424 | 2011_09_26 0101 train 216 425 | 2011_09_26 0101 train 318 426 | 2011_09_26 0101 train 488 427 | 2011_09_26 0101 train 590 428 | 2011_09_26 0101 train 454 429 | 2011_09_26 0101 train 862 430 | 2011_09_26 0101 train 386 431 | 2011_09_26 0101 train 352 432 | 2011_09_26 0101 train 420 433 | 2011_09_26 0101 train 658 434 | 2011_09_26 0101 train 828 435 | 2011_09_26 0101 train 556 436 | 2011_09_26 0101 train 114 437 | 2011_09_26 0101 train 182 438 | 2011_09_26 0101 train 80 439 | 2011_09_26 0106 train 15 440 | 2011_09_26 0106 train 35 441 | 2011_09_26 0106 train 43 442 | 2011_09_26 0106 train 51 443 | 2011_09_26 0106 train 59 444 | 2011_09_26 0106 train 67 445 | 2011_09_26 0106 train 75 446 | 2011_09_26 0106 train 83 447 | 2011_09_26 0106 train 91 448 | 2011_09_26 0106 train 99 449 | 2011_09_26 0106 train 107 450 | 2011_09_26 0106 train 115 451 | 2011_09_26 0106 train 123 452 | 2011_09_26 0106 train 131 453 | 2011_09_26 0106 train 139 454 | 2011_09_26 0106 train 147 455 | 2011_09_26 0106 train 155 456 | 2011_09_26 0106 train 163 457 | 2011_09_26 0106 train 171 458 | 2011_09_26 0106 train 179 459 | 2011_09_26 0106 train 187 460 | 2011_09_26 0106 train 195 461 | 2011_09_26 0106 train 203 462 | 2011_09_26 0106 train 211 463 | 2011_09_26 0106 train 219 464 | 2011_09_26 0117 train 312 465 | 2011_09_26 0117 train 494 466 | 2011_09_26 0117 train 104 467 | 2011_09_26 0117 train 130 468 | 2011_09_26 0117 train 156 469 | 2011_09_26 0117 train 182 470 | 2011_09_26 0117 train 598 471 | 2011_09_26 0117 train 416 472 | 2011_09_26 0117 train 364 473 | 2011_09_26 0117 train 26 474 | 2011_09_26 0117 train 78 475 | 2011_09_26 0117 train 572 476 | 2011_09_26 0117 train 468 477 | 2011_09_26 0117 train 260 478 | 2011_09_26 0117 train 624 479 | 2011_09_26 0117 train 234 480 | 2011_09_26 0117 train 442 481 | 2011_09_26 0117 train 390 482 | 2011_09_26 0117 train 546 483 | 2011_09_26 0117 train 286 484 | 2011_09_26 0117 train 338 485 | 2011_09_26 0117 train 208 486 | 2011_09_26 0117 train 650 487 | 2011_09_26 0117 train 52 488 | 2011_09_28 0002 train 24 489 | 2011_09_28 0002 train 21 490 | 2011_09_28 0002 
train 36 491 | 2011_09_28 0002 train 51 492 | 2011_09_28 0002 train 18 493 | 2011_09_28 0002 train 33 494 | 2011_09_28 0002 train 90 495 | 2011_09_28 0002 train 45 496 | 2011_09_28 0002 train 54 497 | 2011_09_28 0002 train 12 498 | 2011_09_28 0002 train 39 499 | 2011_09_28 0002 train 9 500 | 2011_09_28 0002 train 30 501 | 2011_09_28 0002 train 78 502 | 2011_09_28 0002 train 60 503 | 2011_09_28 0002 train 48 504 | 2011_09_28 0002 train 84 505 | 2011_09_28 0002 train 81 506 | 2011_09_28 0002 train 6 507 | 2011_09_28 0002 train 57 508 | 2011_09_28 0002 train 72 509 | 2011_09_28 0002 train 87 510 | 2011_09_28 0002 train 63 511 | 2011_09_29 0071 train 252 512 | 2011_09_29 0071 train 540 513 | 2011_09_29 0071 train 36 514 | 2011_09_29 0071 train 360 515 | 2011_09_29 0071 train 807 516 | 2011_09_29 0071 train 879 517 | 2011_09_29 0071 train 288 518 | 2011_09_29 0071 train 771 519 | 2011_09_29 0071 train 216 520 | 2011_09_29 0071 train 951 521 | 2011_09_29 0071 train 324 522 | 2011_09_29 0071 train 432 523 | 2011_09_29 0071 train 504 524 | 2011_09_29 0071 train 576 525 | 2011_09_29 0071 train 108 526 | 2011_09_29 0071 train 180 527 | 2011_09_29 0071 train 72 528 | 2011_09_29 0071 train 612 529 | 2011_09_29 0071 train 915 530 | 2011_09_29 0071 train 735 531 | 2011_09_29 0071 train 144 532 | 2011_09_29 0071 train 396 533 | 2011_09_29 0071 train 468 534 | 2011_09_30 0016 val 132 535 | 2011_09_30 0016 val 11 536 | 2011_09_30 0016 val 154 537 | 2011_09_30 0016 val 22 538 | 2011_09_30 0016 val 242 539 | 2011_09_30 0016 val 198 540 | 2011_09_30 0016 val 176 541 | 2011_09_30 0016 val 231 542 | 2011_09_30 0016 val 220 543 | 2011_09_30 0016 val 88 544 | 2011_09_30 0016 val 143 545 | 2011_09_30 0016 val 55 546 | 2011_09_30 0016 val 33 547 | 2011_09_30 0016 val 187 548 | 2011_09_30 0016 val 110 549 | 2011_09_30 0016 val 44 550 | 2011_09_30 0016 val 77 551 | 2011_09_30 0016 val 66 552 | 2011_09_30 0016 val 165 553 | 2011_09_30 0016 val 264 554 | 2011_09_30 0016 val 253 555 | 2011_09_30 0016 val 209 556 | 2011_09_30 0016 val 121 557 | 2011_09_30 0018 train 107 558 | 2011_09_30 0018 train 2247 559 | 2011_09_30 0018 train 1391 560 | 2011_09_30 0018 train 535 561 | 2011_09_30 0018 train 1819 562 | 2011_09_30 0018 train 1177 563 | 2011_09_30 0018 train 428 564 | 2011_09_30 0018 train 1926 565 | 2011_09_30 0018 train 749 566 | 2011_09_30 0018 train 1284 567 | 2011_09_30 0018 train 2140 568 | 2011_09_30 0018 train 1605 569 | 2011_09_30 0018 train 1498 570 | 2011_09_30 0018 train 642 571 | 2011_09_30 0018 train 2740 572 | 2011_09_30 0018 train 2419 573 | 2011_09_30 0018 train 856 574 | 2011_09_30 0018 train 2526 575 | 2011_09_30 0018 train 1712 576 | 2011_09_30 0018 train 1070 577 | 2011_09_30 0018 train 2033 578 | 2011_09_30 0018 train 214 579 | 2011_09_30 0018 train 963 580 | 2011_09_30 0018 train 2633 581 | 2011_09_30 0027 train 533 582 | 2011_09_30 0027 train 1040 583 | 2011_09_30 0027 train 82 584 | 2011_09_30 0027 train 205 585 | 2011_09_30 0027 train 835 586 | 2011_09_30 0027 train 451 587 | 2011_09_30 0027 train 164 588 | 2011_09_30 0027 train 794 589 | 2011_09_30 0027 train 328 590 | 2011_09_30 0027 train 615 591 | 2011_09_30 0027 train 917 592 | 2011_09_30 0027 train 369 593 | 2011_09_30 0027 train 287 594 | 2011_09_30 0027 train 123 595 | 2011_09_30 0027 train 876 596 | 2011_09_30 0027 train 410 597 | 2011_09_30 0027 train 492 598 | 2011_09_30 0027 train 958 599 | 2011_09_30 0027 train 656 600 | 2011_09_30 0027 train 753 601 | 2011_09_30 0027 train 574 602 | 2011_09_30 0027 train 1081 603 | 2011_09_30 0027 
train 41 604 | 2011_09_30 0027 train 246 605 | 2011_10_03 0027 train 2906 606 | 2011_10_03 0027 train 2544 607 | 2011_10_03 0027 train 362 608 | 2011_10_03 0027 train 4535 609 | 2011_10_03 0027 train 734 610 | 2011_10_03 0027 train 1096 611 | 2011_10_03 0027 train 4173 612 | 2011_10_03 0027 train 543 613 | 2011_10_03 0027 train 1277 614 | 2011_10_03 0027 train 4354 615 | 2011_10_03 0027 train 1458 616 | 2011_10_03 0027 train 1820 617 | 2011_10_03 0027 train 3449 618 | 2011_10_03 0027 train 3268 619 | 2011_10_03 0027 train 915 620 | 2011_10_03 0027 train 2363 621 | 2011_10_03 0027 train 2725 622 | 2011_10_03 0027 train 181 623 | 2011_10_03 0027 train 1639 624 | 2011_10_03 0027 train 3992 625 | 2011_10_03 0027 train 3087 626 | 2011_10_03 0027 train 2001 627 | 2011_10_03 0027 train 3811 628 | 2011_10_03 0027 train 3630 629 | 2011_10_03 0047 val 96 630 | 2011_10_03 0047 val 800 631 | 2011_10_03 0047 val 320 632 | 2011_10_03 0047 val 576 633 | 2011_10_03 0047 val 480 634 | 2011_10_03 0047 val 640 635 | 2011_10_03 0047 val 32 636 | 2011_10_03 0047 val 384 637 | 2011_10_03 0047 val 160 638 | 2011_10_03 0047 val 704 639 | 2011_10_03 0047 val 736 640 | 2011_10_03 0047 val 672 641 | 2011_10_03 0047 val 64 642 | 2011_10_03 0047 val 288 643 | 2011_10_03 0047 val 352 644 | 2011_10_03 0047 val 512 645 | 2011_10_03 0047 val 544 646 | 2011_10_03 0047 val 608 647 | 2011_10_03 0047 val 128 648 | 2011_10_03 0047 val 224 649 | 2011_10_03 0047 val 416 650 | 2011_10_03 0047 val 192 651 | 2011_10_03 0047 val 448 652 | 2011_10_03 0047 val 768 -------------------------------------------------------------------------------- /datasets/DDAD.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | import os 3 | import random 4 | import numpy as np 5 | import copy 6 | from PIL import Image # using pillow-simd for increased speed 7 | from time import time 8 | import torch 9 | import torch.utils.data as data 10 | from torchvision import transforms 11 | import cv2 12 | 13 | cv2.setNumThreads(0) 14 | import glob 15 | import utils 16 | import torch.nn.functional as F 17 | from utils import npy 18 | import json 19 | 20 | 21 | 22 | class DDAD(data.Dataset): 23 | def __init__(self, opt, is_train): 24 | super(DDAD, self).__init__() 25 | self.opt = opt 26 | self.json_path = "data_split/DDAD_video.json" 27 | self.data_path_root = self.opt.data_path 28 | self.is_train = is_train 29 | if self.is_train: 30 | self.data_path = os.path.join(self.data_path_root, 'train/') 31 | else: 32 | self.data_path = os.path.join(self.data_path_root, 'val/') 33 | 34 | f = open(self.json_path, 'r') 35 | content_all = f.read() 36 | json_list_all = json.loads(content_all) 37 | f.close() 38 | 39 | if self.is_train: 40 | self.file_names = json_list_all["train"] 41 | print('train', len(self.file_names)) 42 | else: 43 | self.file_names = json_list_all["val"] 44 | print('val', len(self.file_names)) 45 | 46 | print('filter_pre', len(self.file_names)) 47 | self.file_names = [x for x in self.file_names if 'timestamp' in x.keys() and 'timestamp_back' in x.keys() and 'timestamp_forward' in x.keys()] 48 | print('filter_after', len(self.file_names)) 49 | 50 | 51 | 52 | def get_k_ori_randomcrop(self, k_raw, inputs, x1, y1): 53 | 54 | fx_ori = k_raw[0,0] 55 | fy_ori = k_raw[1,1] 56 | fx_virtual = 1060.0 57 | fx_scale = fx_ori / fx_virtual 58 | inputs[("depth_gt", 0, 0)] = inputs[("depth_gt", 0, 0)] / fx_scale 59 | 60 | pose_cur = inputs[("pose", 0)] 61 | pose_cur[:3, 
3] = pose_cur[:3, 3] / fx_scale 62 | inputs[("pose", 0)] = pose_cur 63 | inputs[("pose_inv", 0)] = np.linalg.inv(inputs[("pose", 0)]) 64 | 65 | pose_pre = inputs[("pose", 1)] 66 | pose_pre[:3, 3] = pose_pre[:3, 3] / fx_scale 67 | inputs[("pose", 1)] = pose_pre 68 | inputs[("pose_inv", 1)] = np.linalg.inv(inputs[("pose", 1)]) 69 | 70 | pose_next = inputs[("pose", 2)] 71 | pose_next[:3, 3] = pose_next[:3, 3] / fx_scale 72 | inputs[("pose", 2)] = pose_next 73 | inputs[("pose_inv", 2)] = np.linalg.inv(inputs[("pose", 2)]) 74 | 75 | K = np.zeros((3,3), dtype = float) 76 | K[0,0] = fx_ori 77 | K[1,1] = fy_ori 78 | K[2,2] = 1.0 79 | K[0,2] = k_raw[0,2] 80 | K[1,2] = k_raw[1,2] 81 | 82 | h_crop = y1 - self.opt.height 83 | w_crop = x1 84 | 85 | K[0,2] = K[0,2] - w_crop 86 | K[1,2] = K[1,2] - h_crop 87 | 88 | return K, inputs 89 | 90 | def get_k_ori_centercrop(self, k_raw, inputs): 91 | 92 | fx_ori = k_raw[0,0] 93 | fy_ori = k_raw[1,1] 94 | fx_virtual = 1060.0 95 | fx_scale = fx_ori / fx_virtual 96 | inputs[("depth_gt", 0, 0)] = inputs[("depth_gt", 0, 0)] / fx_scale 97 | 98 | pose_cur = inputs[("pose", 0)] 99 | pose_cur[:3, 3] = pose_cur[:3, 3] / fx_scale 100 | inputs[("pose", 0)] = pose_cur 101 | inputs[("pose_inv", 0)] = np.linalg.inv(inputs[("pose", 0)]) 102 | 103 | pose_pre = inputs[("pose", 1)] 104 | pose_pre[:3, 3] = pose_pre[:3, 3] / fx_scale 105 | inputs[("pose", 1)] = pose_pre 106 | inputs[("pose_inv", 1)] = np.linalg.inv(inputs[("pose", 1)]) 107 | 108 | pose_next = inputs[("pose", 2)] 109 | pose_next[:3, 3] = pose_next[:3, 3] / fx_scale 110 | inputs[("pose", 2)] = pose_next 111 | inputs[("pose_inv", 2)] = np.linalg.inv(inputs[("pose", 2)]) 112 | 113 | K = np.zeros((3,3), dtype = float) 114 | K[0,0] = fx_ori 115 | K[1,1] = fy_ori 116 | K[2,2] = 1.0 117 | K[0,2] = k_raw[0,2] 118 | K[1,2] = k_raw[1,2] 119 | 120 | h_crop = 0.0 121 | w_crop = 8.0 122 | 123 | K[0,2] = K[0,2] - w_crop 124 | K[1,2] = K[1,2] - h_crop 125 | 126 | return K, inputs 127 | 128 | 129 | def __len__(self): 130 | return len(self.file_names) 131 | 132 | 133 | def __getitem__(self, index): 134 | inputs = {} 135 | cur_npz_path = self.data_path + str(self.file_names[index]['timestamp']) + '_' + self.file_names[index]['Camera'] + '.npz' 136 | pre_npz_path = self.data_path + str(self.file_names[index]['timestamp_back']) + '_' + self.file_names[index]['Camera'] + '.npz' 137 | next_npz_path = self.data_path + str(self.file_names[index]['timestamp_forward']) + '_' + self.file_names[index]['Camera'] + '.npz' 138 | 139 | file_cur = np.load(cur_npz_path) 140 | file_pre = np.load(pre_npz_path) 141 | file_next = np.load(next_npz_path) 142 | 143 | depth_cur_gt = file_cur['depth'] 144 | depth_cur_gt = np.array(depth_cur_gt).astype(np.float32) 145 | 146 | 147 | inputs[("depth_gt", 0, 0)] = depth_cur_gt 148 | 149 | 150 | if self.is_train: 151 | if random.randint(0, 10) < 8: 152 | y_center = int((1216 + self.opt.height)/2) 153 | y1 = random.randint(int(y_center - 70), int(y_center + 50)) 154 | else: 155 | y1 = random.randint(self.opt.height, 1216) 156 | x1 = random.randint(0, 1930 - self.opt.width) 157 | else: 158 | y1 = int((1216 + self.opt.height)/2) 159 | x1 = int((1936 - self.opt.width)/2) 160 | 161 | inputs[("depth_gt", 0, 0)] = inputs[("depth_gt", 0, 0)][y1 - int(self.opt.height):y1, x1:x1+int(self.opt.width)][None, :, :] 162 | 163 | rgb_cur = file_cur['rgb'] 164 | rgb_cur = rgb_cur[y1 - int(self.opt.height):y1, x1:x1+int(self.opt.width)] 165 | rgb_cur = cv2.cvtColor(rgb_cur, cv2.COLOR_BGR2RGB) 166 | rgb_cur = 
torch.from_numpy(rgb_cur).permute(2, 0, 1) / 255. 167 | inputs[("color", 0, 0)] = rgb_cur 168 | 169 | 170 | 171 | rgb_pre = file_pre['rgb'] 172 | rgb_pre = rgb_pre[y1 - int(self.opt.height):y1, x1:x1+int(self.opt.width)] 173 | rgb_pre = cv2.cvtColor(rgb_pre, cv2.COLOR_BGR2RGB) 174 | rgb_pre = torch.from_numpy(rgb_pre).permute(2, 0, 1) / 255. 175 | inputs[("color", 1, 0)] = rgb_pre 176 | 177 | 178 | rgb_next = file_next['rgb'] 179 | rgb_next = rgb_next[y1 - int(self.opt.height):y1, x1:x1+int(self.opt.width)] 180 | rgb_next = cv2.cvtColor(rgb_next, cv2.COLOR_BGR2RGB) 181 | rgb_next = torch.from_numpy(rgb_next).permute(2, 0, 1) / 255. 182 | inputs[("color", 2, 0)] = rgb_next 183 | 184 | 185 | pose_cur = file_cur['pose'] 186 | pose_cur = np.linalg.inv(pose_cur).astype('float32') 187 | inputs[("pose", 0)] = pose_cur 188 | 189 | pose_pre = file_pre['pose'] 190 | pose_pre = np.linalg.inv(pose_pre).astype('float32') 191 | inputs[("pose", 1)] = pose_pre 192 | 193 | pose_next = file_next['pose'] 194 | pose_next = np.linalg.inv(pose_next).astype('float32') 195 | inputs[("pose", 2)] = pose_next 196 | 197 | k_raw = file_cur['intrinsics'] 198 | k_crop, inputs = self.get_k_ori_randomcrop(k_raw, inputs, x1, y1) 199 | inputs = self.get_K(k_crop, inputs) 200 | 201 | inputs = self.compute_projection_matrix(inputs) 202 | 203 | 204 | inputs['num_frame'] = 3 205 | 206 | return inputs 207 | 208 | 209 | 210 | def get_K(self, K, inputs): 211 | inv_K = np.linalg.inv(K) 212 | K_pool = {} 213 | ho, wo = self.opt.height, self.opt.width 214 | for i in range(6): 215 | K_pool[(ho // 2**i, wo // 2**i)] = K.copy().astype('float32') 216 | K_pool[(ho // 2**i, wo // 2**i)][:2, :] /= 2**i 217 | 218 | inputs['K_pool'] = K_pool 219 | 220 | inputs[("inv_K_pool", 0)] = {} 221 | for k, v in K_pool.items(): 222 | K44 = np.eye(4) 223 | K44[:3, :3] = v 224 | inputs[("inv_K_pool", 0)][k] = np.linalg.inv(K44).astype('float32') 225 | 226 | inputs[("inv_K", 0)] = torch.from_numpy(inv_K.astype('float32')) 227 | 228 | inputs[("K", 0)] = torch.from_numpy(K.astype('float32')) 229 | 230 | return inputs 231 | 232 | def get_K_test(self, K, inputs): 233 | inv_K = np.linalg.inv(K) 234 | K_pool = {} 235 | ho, wo = self.opt.eval_height, self.opt.eval_width 236 | for i in range(6): 237 | K_pool[(ho // 2**i, wo // 2**i)] = K.copy().astype('float32') 238 | K_pool[(ho // 2**i, wo // 2**i)][:2, :] /= 2**i 239 | 240 | inputs['K_pool'] = K_pool 241 | 242 | inputs[("inv_K_pool", 0)] = {} 243 | for k, v in K_pool.items(): 244 | K44 = np.eye(4) 245 | K44[:3, :3] = v 246 | inputs[("inv_K_pool", 0)][k] = np.linalg.inv(K44).astype('float32') 247 | 248 | inputs[("inv_K", 0)] = torch.from_numpy(inv_K.astype('float32')) 249 | 250 | inputs[("K", 0)] = torch.from_numpy(K.astype('float32')) 251 | 252 | return inputs 253 | 254 | def compute_projection_matrix(self, inputs): 255 | for i in range(self.opt.num_frame): 256 | inputs[("proj", i)] = {} 257 | for k, v in inputs['K_pool'].items(): 258 | K44 = np.eye(4) 259 | K44[:3, :3] = v 260 | inputs[("proj", 261 | i)][k] = np.matmul(K44, inputs[("pose", 262 | i)]).astype('float32') 263 | return inputs 264 | -------------------------------------------------------------------------------- /datasets/DDAD_crop.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | import os 3 | import random 4 | import numpy as np 5 | import copy 6 | from PIL import Image # using pillow-simd for increased speed 7 | from time import time 8 | 
import torch 9 | import torch.utils.data as data 10 | from torchvision import transforms 11 | import cv2 12 | 13 | cv2.setNumThreads(0) 14 | import glob 15 | import utils 16 | import torch.nn.functional as F 17 | from utils import npy 18 | import json 19 | 20 | 21 | 22 | class DDAD(data.Dataset): 23 | def __init__(self, opt, is_train): 24 | super(DDAD, self).__init__() 25 | self.opt = opt 26 | self.json_path = "/home/cjd/tmp/DDAD_video.json" 27 | self.data_path_root = '/data/cjd/ddad/my_ddad/' 28 | self.is_train = is_train 29 | if self.is_train: 30 | self.data_path = os.path.join(self.data_path_root, 'train/') 31 | else: 32 | self.data_path = os.path.join(self.data_path_root, 'val/') 33 | 34 | f = open(self.json_path, 'r') 35 | content_all = f.read() 36 | json_list_all = json.loads(content_all) 37 | f.close() 38 | 39 | if self.is_train: 40 | self.file_names = json_list_all["train"] 41 | # self.file_names = self.file_names[:300] 42 | print('train', len(self.file_names)) 43 | else: 44 | self.file_names = json_list_all["val"] 45 | # self.file_names = self.file_names[:300] 46 | 47 | print('val', len(self.file_names)) 48 | 49 | print('filter_pre', len(self.file_names)) 50 | self.file_names = [x for x in self.file_names if 'timestamp' in x.keys() and 'timestamp_back' in x.keys() and 'timestamp_forward' in x.keys() and x['Camera'] == 'CAMERA_01'] 51 | print('filter_after', len(self.file_names)) 52 | self.file_names = self.file_names[:50] 53 | 54 | 55 | 56 | def get_k_ori_randomcrop(self, k_raw, inputs, x1, y1): 57 | 58 | fx_ori = k_raw[0,0] 59 | fy_ori = k_raw[1,1] 60 | fx_virtual = 1060.0 61 | fx_scale = fx_ori / fx_virtual 62 | # print('fx_scale', fx_scale) 63 | inputs[("depth_gt", 0, 0)] = inputs[("depth_gt", 0, 0)] / fx_scale 64 | 65 | pose_cur = inputs[("pose", 0)] 66 | pose_cur[:3, 3] = pose_cur[:3, 3] / fx_scale 67 | inputs[("pose", 0)] = pose_cur 68 | inputs[("pose_inv", 0)] = np.linalg.inv(inputs[("pose", 0)]) 69 | 70 | pose_pre = inputs[("pose", 1)] 71 | pose_pre[:3, 3] = pose_pre[:3, 3] / fx_scale 72 | inputs[("pose", 1)] = pose_pre 73 | inputs[("pose_inv", 1)] = np.linalg.inv(inputs[("pose", 1)]) 74 | 75 | pose_next = inputs[("pose", 2)] 76 | pose_next[:3, 3] = pose_next[:3, 3] / fx_scale 77 | inputs[("pose", 2)] = pose_next 78 | inputs[("pose_inv", 2)] = np.linalg.inv(inputs[("pose", 2)]) 79 | 80 | # inputs['focal_scale'] = float(fx_scale) 81 | 82 | K = np.zeros((3,3), dtype = float) 83 | K[0,0] = fx_ori 84 | K[1,1] = fy_ori 85 | K[2,2] = 1.0 86 | K[0,2] = k_raw[0,2] 87 | K[1,2] = k_raw[1,2] 88 | 89 | h_crop = y1 - self.opt.height 90 | w_crop = x1 91 | 92 | K[0,2] = K[0,2] - w_crop 93 | K[1,2] = K[1,2] - h_crop 94 | 95 | return K, inputs 96 | 97 | def get_k_ori_centercrop(self, k_raw, inputs, x_center, y_center): 98 | 99 | fx_ori = k_raw[0,0] 100 | fy_ori = k_raw[1,1] 101 | fx_virtual = 1060.0 102 | fx_scale = fx_ori / fx_virtual 103 | # print('fx_scale', fx_scale) 104 | inputs[("depth_gt", 0, 0)] = inputs[("depth_gt", 0, 0)] / fx_scale 105 | 106 | pose_cur = inputs[("pose", 0)] 107 | pose_cur[:3, 3] = pose_cur[:3, 3] / fx_scale 108 | inputs[("pose", 0)] = pose_cur 109 | inputs[("pose_inv", 0)] = np.linalg.inv(inputs[("pose", 0)]) 110 | 111 | pose_pre = inputs[("pose", 1)] 112 | pose_pre[:3, 3] = pose_pre[:3, 3] / fx_scale 113 | inputs[("pose", 1)] = pose_pre 114 | inputs[("pose_inv", 1)] = np.linalg.inv(inputs[("pose", 1)]) 115 | 116 | pose_next = inputs[("pose", 2)] 117 | pose_next[:3, 3] = pose_next[:3, 3] / fx_scale 118 | inputs[("pose", 2)] = pose_next 119 | inputs[("pose_inv", 
2)] = np.linalg.inv(inputs[("pose", 2)]) 120 | 121 | # inputs['focal_scale'] = float(fx_scale) 122 | 123 | K = np.zeros((3,3), dtype = float) 124 | K[0,0] = fx_ori 125 | K[1,1] = fy_ori 126 | K[2,2] = 1.0 127 | K[0,2] = k_raw[0,2] 128 | K[1,2] = k_raw[1,2] 129 | 130 | h_crop = int(y_center - 0.5*self.opt.height) 131 | w_crop = int(x_center - 0.5*self.opt.width) 132 | 133 | K[0,2] = K[0,2] - w_crop 134 | K[1,2] = K[1,2] - h_crop 135 | 136 | return K, inputs 137 | 138 | 139 | def __len__(self): 140 | return len(self.file_names) 141 | 142 | 143 | def __getitem__(self, index): 144 | inputs = {} 145 | cur_npz_path = self.data_path + str(self.file_names[index]['timestamp']) + '_' + self.file_names[index]['Camera'] + '.npz' 146 | pre_npz_path = self.data_path + str(self.file_names[index]['timestamp_back']) + '_' + self.file_names[index]['Camera'] + '.npz' 147 | next_npz_path = self.data_path + str(self.file_names[index]['timestamp_forward']) + '_' + self.file_names[index]['Camera'] + '.npz' 148 | 149 | cur_mask_path = cur_npz_path.replace('.npz', '_dynamic.npz') 150 | inputs['dynamic_mask'] = cur_mask_path 151 | file_cur = np.load(cur_npz_path) 152 | file_pre = np.load(pre_npz_path) 153 | file_next = np.load(next_npz_path) 154 | 155 | depth_cur_gt = file_cur['depth'] 156 | depth_cur_gt = np.array(depth_cur_gt).astype(np.float32) 157 | 158 | 159 | inputs[("depth_gt", 0, 0)] = depth_cur_gt 160 | 161 | 162 | if self.is_train: 163 | #h_ori=1216, w_ori=1936 164 | if random.randint(0, 10) < 8: 165 | y_center = int((1216 + self.opt.height)/2) 166 | y1 = random.randint(int(y_center - 70), int(y_center + 50)) 167 | else: 168 | y1 = random.randint(self.opt.height, 1216) 169 | x1 = random.randint(0, 1930 - self.opt.width) 170 | # else: 171 | # y1 = int((1216 + self.opt.height)/2) 172 | # x1 = int((1936 - self.opt.width)/2) 173 | 174 | inputs[("depth_gt", 0, 0)] = inputs[("depth_gt", 0, 0)][y1 - int(self.opt.height):y1, x1:x1+int(self.opt.width)][None, :, :] 175 | 176 | rgb_cur = file_cur['rgb'] 177 | rgb_cur = rgb_cur[y1 - int(self.opt.height):y1, x1:x1+int(self.opt.width)] 178 | rgb_cur = cv2.cvtColor(rgb_cur, cv2.COLOR_BGR2RGB) 179 | rgb_cur = torch.from_numpy(rgb_cur).permute(2, 0, 1) / 255. 180 | inputs[("color", 0, 0)] = rgb_cur 181 | 182 | 183 | 184 | rgb_pre = file_pre['rgb'] 185 | rgb_pre = rgb_pre[y1 - int(self.opt.height):y1, x1:x1+int(self.opt.width)] 186 | rgb_pre = cv2.cvtColor(rgb_pre, cv2.COLOR_BGR2RGB) 187 | rgb_pre = torch.from_numpy(rgb_pre).permute(2, 0, 1) / 255. 188 | inputs[("color", 1, 0)] = rgb_pre 189 | 190 | 191 | rgb_next = file_next['rgb'] 192 | rgb_next = rgb_next[y1 - int(self.opt.height):y1, x1:x1+int(self.opt.width)] 193 | rgb_next = cv2.cvtColor(rgb_next, cv2.COLOR_BGR2RGB) 194 | rgb_next = torch.from_numpy(rgb_next).permute(2, 0, 1) / 255. 
195 | inputs[("color", 2, 0)] = rgb_next 196 | 197 | 198 | pose_cur = file_cur['pose'] 199 | pose_cur = np.linalg.inv(pose_cur).astype('float32') 200 | inputs[("pose", 0)] = pose_cur 201 | 202 | pose_pre = file_pre['pose'] 203 | pose_pre = np.linalg.inv(pose_pre).astype('float32') 204 | inputs[("pose", 1)] = pose_pre 205 | 206 | pose_next = file_next['pose'] 207 | pose_next = np.linalg.inv(pose_next).astype('float32') 208 | inputs[("pose", 2)] = pose_next 209 | 210 | k_raw = file_cur['intrinsics'] 211 | # print('k_raw',k_raw) 212 | k_crop, inputs = self.get_k_ori_randomcrop(k_raw, inputs, x1, y1) 213 | inputs = self.get_K(k_crop, inputs) 214 | 215 | 216 | else: 217 | x_center = int((1936 + self.opt.width)/2) 218 | y_center = int((1216 + self.opt.height)/2) 219 | 220 | inputs[("depth_gt", 0, 0)] = inputs[("depth_gt", 0, 0)][int(y_center - 0.5*self.opt.height):int(y_center + 0.5*self.opt.height), int(x_center - 0.5*self.opt.width):int(x_center + 0.5*self.opt.width)][None, :, :] 221 | print(inputs[("depth_gt", 0, 0)].shape) 222 | rgb_cur = file_cur['rgb'] 223 | rgb_cur = rgb_cur[int(y_center - 0.5*self.opt.height):int(y_center + 0.5*self.opt.height), int(x_center - 0.5*self.opt.width):int(x_center + 0.5*self.opt.width)] 224 | rgb_cur = cv2.cvtColor(rgb_cur, cv2.COLOR_BGR2RGB) 225 | rgb_cur = torch.from_numpy(rgb_cur).permute(2, 0, 1) / 255. 226 | inputs[("color", 0, 0)] = rgb_cur 227 | 228 | 229 | 230 | rgb_pre = file_pre['rgb'] 231 | rgb_pre = rgb_pre[int(y_center - 0.5*self.opt.height):int(y_center + 0.5*self.opt.height), int(x_center - 0.5*self.opt.width):int(x_center + 0.5*self.opt.width)] 232 | rgb_pre = cv2.cvtColor(rgb_pre, cv2.COLOR_BGR2RGB) 233 | rgb_pre = torch.from_numpy(rgb_pre).permute(2, 0, 1) / 255. 234 | inputs[("color", 1, 0)] = rgb_pre 235 | 236 | 237 | rgb_next = file_next['rgb'] 238 | rgb_next = rgb_next[int(y_center - 0.5*self.opt.height):int(y_center + 0.5*self.opt.height), int(x_center - 0.5*self.opt.width):int(x_center + 0.5*self.opt.width)] 239 | rgb_next = cv2.cvtColor(rgb_next, cv2.COLOR_BGR2RGB) 240 | rgb_next = torch.from_numpy(rgb_next).permute(2, 0, 1) / 255. 
241 | inputs[("color", 2, 0)] = rgb_next 242 | 243 | 244 | pose_cur = file_cur['pose'] 245 | pose_cur = np.linalg.inv(pose_cur).astype('float32') 246 | inputs[("pose", 0)] = pose_cur 247 | 248 | pose_pre = file_pre['pose'] 249 | pose_pre = np.linalg.inv(pose_pre).astype('float32') 250 | inputs[("pose", 1)] = pose_pre 251 | 252 | pose_next = file_next['pose'] 253 | pose_next = np.linalg.inv(pose_next).astype('float32') 254 | inputs[("pose", 2)] = pose_next 255 | 256 | k_raw = file_cur['intrinsics'] 257 | 258 | k_crop, inputs = self.get_k_ori_centercrop(k_raw, inputs, x_center, y_center) 259 | 260 | inputs = self.get_K_test(k_crop, inputs) 261 | 262 | inputs = self.compute_projection_matrix(inputs) 263 | 264 | 265 | inputs['num_frame'] = 3 266 | 267 | # for key, value in inputs.items(): 268 | # print(key, value.dtype) 269 | 270 | return inputs 271 | 272 | 273 | 274 | def get_K(self, K, inputs): 275 | inv_K = np.linalg.inv(K) 276 | K_pool = {} 277 | ho, wo = self.opt.height, self.opt.width 278 | for i in range(6): 279 | K_pool[(ho // 2**i, wo // 2**i)] = K.copy().astype('float32') 280 | K_pool[(ho // 2**i, wo // 2**i)][:2, :] /= 2**i 281 | 282 | inputs['K_pool'] = K_pool 283 | 284 | inputs[("inv_K_pool", 0)] = {} 285 | for k, v in K_pool.items(): 286 | K44 = np.eye(4) 287 | K44[:3, :3] = v 288 | inputs[("inv_K_pool", 0)][k] = np.linalg.inv(K44).astype('float32') 289 | 290 | inputs[("inv_K", 0)] = torch.from_numpy(inv_K.astype('float32')) 291 | 292 | inputs[("K", 0)] = torch.from_numpy(K.astype('float32')) 293 | 294 | return inputs 295 | 296 | def get_K_test(self, K, inputs): 297 | inv_K = np.linalg.inv(K) 298 | K_pool = {} 299 | ho, wo = self.opt.eval_height, self.opt.eval_width 300 | for i in range(6): 301 | K_pool[(ho // 2**i, wo // 2**i)] = K.copy().astype('float32') 302 | K_pool[(ho // 2**i, wo // 2**i)][:2, :] /= 2**i 303 | 304 | inputs['K_pool'] = K_pool 305 | 306 | inputs[("inv_K_pool", 0)] = {} 307 | for k, v in K_pool.items(): 308 | K44 = np.eye(4) 309 | K44[:3, :3] = v 310 | inputs[("inv_K_pool", 0)][k] = np.linalg.inv(K44).astype('float32') 311 | 312 | inputs[("inv_K", 0)] = torch.from_numpy(inv_K.astype('float32')) 313 | 314 | inputs[("K", 0)] = torch.from_numpy(K.astype('float32')) 315 | 316 | return inputs 317 | 318 | def compute_projection_matrix(self, inputs): 319 | for i in range(self.opt.num_frame): 320 | inputs[("proj", i)] = {} 321 | for k, v in inputs['K_pool'].items(): 322 | K44 = np.eye(4) 323 | K44[:3, :3] = v 324 | inputs[("proj", 325 | i)][k] = np.matmul(K44, inputs[("pose", 326 | i)]).astype('float32') 327 | return inputs 328 | -------------------------------------------------------------------------------- /datasets/DDAD_forward.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | import os 3 | import random 4 | import numpy as np 5 | import copy 6 | from PIL import Image # using pillow-simd for increased speed 7 | from time import time 8 | import torch 9 | import torch.utils.data as data 10 | from torchvision import transforms 11 | import cv2 12 | 13 | cv2.setNumThreads(0) 14 | import glob 15 | import utils 16 | import torch.nn.functional as F 17 | from utils import npy 18 | import json 19 | 20 | 21 | 22 | class DDAD(data.Dataset): 23 | def __init__(self, opt, is_train): 24 | super(DDAD, self).__init__() 25 | self.opt = opt 26 | self.json_path = "/home/cjd/tmp/DDAD_video.json" 27 | self.data_path_root = '/data/cjd/ddad/my_ddad/' 28 | self.is_train = is_train 29 
| if self.is_train: 30 | self.data_path = os.path.join(self.data_path_root, 'train/') 31 | else: 32 | self.data_path = os.path.join(self.data_path_root, 'val/') 33 | 34 | f = open(self.json_path, 'r') 35 | content_all = f.read() 36 | json_list_all = json.loads(content_all) 37 | f.close() 38 | 39 | if self.is_train: 40 | self.file_names = json_list_all["train"] 41 | # self.file_names = self.file_names[:300] 42 | print('train', len(self.file_names)) 43 | else: 44 | self.file_names = json_list_all["val"] 45 | # self.file_names = self.file_names[:300] 46 | 47 | print('val', len(self.file_names)) 48 | 49 | print('filter_pre', len(self.file_names)) 50 | self.file_names = [x for x in self.file_names if 'timestamp' in x.keys() and 'timestamp_back' in x.keys() and 'timestamp_forward' in x.keys() and x['Camera'] == 'CAMERA_01'] 51 | print('filter_after', len(self.file_names)) 52 | self.file_names = self.file_names[:50] 53 | 54 | 55 | 56 | def get_k_ori_randomcrop(self, k_raw, inputs, x1, y1): 57 | 58 | fx_ori = k_raw[0,0] 59 | fy_ori = k_raw[1,1] 60 | fx_virtual = 1060.0 61 | fx_scale = fx_ori / fx_virtual 62 | # print('fx_scale', fx_scale) 63 | inputs[("depth_gt", 0, 0)] = inputs[("depth_gt", 0, 0)] / fx_scale 64 | 65 | pose_cur = inputs[("pose", 0)] 66 | pose_cur[:3, 3] = pose_cur[:3, 3] / fx_scale 67 | inputs[("pose", 0)] = pose_cur 68 | inputs[("pose_inv", 0)] = np.linalg.inv(inputs[("pose", 0)]) 69 | 70 | pose_pre = inputs[("pose", 1)] 71 | pose_pre[:3, 3] = pose_pre[:3, 3] / fx_scale 72 | inputs[("pose", 1)] = pose_pre 73 | inputs[("pose_inv", 1)] = np.linalg.inv(inputs[("pose", 1)]) 74 | 75 | pose_next = inputs[("pose", 2)] 76 | pose_next[:3, 3] = pose_next[:3, 3] / fx_scale 77 | inputs[("pose", 2)] = pose_next 78 | inputs[("pose_inv", 2)] = np.linalg.inv(inputs[("pose", 2)]) 79 | 80 | # inputs['focal_scale'] = float(fx_scale) 81 | 82 | K = np.zeros((3,3), dtype = float) 83 | K[0,0] = fx_ori 84 | K[1,1] = fy_ori 85 | K[2,2] = 1.0 86 | K[0,2] = k_raw[0,2] 87 | K[1,2] = k_raw[1,2] 88 | 89 | h_crop = y1 - self.opt.height 90 | w_crop = x1 91 | 92 | K[0,2] = K[0,2] - w_crop 93 | K[1,2] = K[1,2] - h_crop 94 | 95 | return K, inputs 96 | 97 | def get_k_ori_centercrop(self, k_raw, inputs): 98 | 99 | fx_ori = k_raw[0,0] 100 | fy_ori = k_raw[1,1] 101 | fx_virtual = 1060.0 102 | fx_scale = fx_ori / fx_virtual 103 | # print('fx_scale', fx_scale) 104 | inputs[("depth_gt", 0, 0)] = inputs[("depth_gt", 0, 0)] / fx_scale 105 | 106 | pose_cur = inputs[("pose", 0)] 107 | pose_cur[:3, 3] = pose_cur[:3, 3] / fx_scale 108 | inputs[("pose", 0)] = pose_cur 109 | inputs[("pose_inv", 0)] = np.linalg.inv(inputs[("pose", 0)]) 110 | 111 | pose_pre = inputs[("pose", 1)] 112 | pose_pre[:3, 3] = pose_pre[:3, 3] / fx_scale 113 | inputs[("pose", 1)] = pose_pre 114 | inputs[("pose_inv", 1)] = np.linalg.inv(inputs[("pose", 1)]) 115 | 116 | pose_next = inputs[("pose", 2)] 117 | pose_next[:3, 3] = pose_next[:3, 3] / fx_scale 118 | inputs[("pose", 2)] = pose_next 119 | inputs[("pose_inv", 2)] = np.linalg.inv(inputs[("pose", 2)]) 120 | 121 | # inputs['focal_scale'] = float(fx_scale) 122 | 123 | K = np.zeros((3,3), dtype = float) 124 | K[0,0] = fx_ori 125 | K[1,1] = fy_ori 126 | K[2,2] = 1.0 127 | K[0,2] = k_raw[0,2] 128 | K[1,2] = k_raw[1,2] 129 | 130 | h_crop = 0.0 131 | w_crop = 8.0 132 | 133 | K[0,2] = K[0,2] - w_crop 134 | K[1,2] = K[1,2] - h_crop 135 | 136 | return K, inputs 137 | 138 | 139 | def __len__(self): 140 | return len(self.file_names) 141 | 142 | 143 | def __getitem__(self, index): 144 | inputs = {} 145 | 
cur_npz_path = self.data_path + str(self.file_names[index]['timestamp']) + '_' + self.file_names[index]['Camera'] + '.npz' 146 | pre_npz_path = self.data_path + str(self.file_names[index]['timestamp_back']) + '_' + self.file_names[index]['Camera'] + '.npz' 147 | next_npz_path = self.data_path + str(self.file_names[index]['timestamp_forward']) + '_' + self.file_names[index]['Camera'] + '.npz' 148 | 149 | cur_mask_path = cur_npz_path.replace('.npz', '_dynamic.npz') 150 | inputs['dynamic_mask'] = cur_mask_path 151 | file_cur = np.load(cur_npz_path) 152 | file_pre = np.load(pre_npz_path) 153 | file_next = np.load(next_npz_path) 154 | 155 | depth_cur_gt = file_cur['depth'] 156 | depth_cur_gt = np.array(depth_cur_gt).astype(np.float32) 157 | 158 | 159 | inputs[("depth_gt", 0, 0)] = depth_cur_gt 160 | 161 | 162 | if self.is_train: 163 | #h_ori=1216, w_ori=1936 164 | if random.randint(0, 10) < 8: 165 | y_center = int((1216 + self.opt.height)/2) 166 | y1 = random.randint(int(y_center - 70), int(y_center + 50)) 167 | else: 168 | y1 = random.randint(self.opt.height, 1216) 169 | x1 = random.randint(0, 1930 - self.opt.width) 170 | # else: 171 | # y1 = int((1216 + self.opt.height)/2) 172 | # x1 = int((1936 - self.opt.width)/2) 173 | 174 | inputs[("depth_gt", 0, 0)] = inputs[("depth_gt", 0, 0)][y1 - int(self.opt.height):y1, x1:x1+int(self.opt.width)][None, :, :] 175 | 176 | rgb_cur = file_cur['rgb'] 177 | rgb_cur = rgb_cur[y1 - int(self.opt.height):y1, x1:x1+int(self.opt.width)] 178 | rgb_cur = cv2.cvtColor(rgb_cur, cv2.COLOR_BGR2RGB) 179 | rgb_cur = torch.from_numpy(rgb_cur).permute(2, 0, 1) / 255. 180 | inputs[("color", 0, 0)] = rgb_cur 181 | 182 | 183 | 184 | rgb_pre = file_pre['rgb'] 185 | rgb_pre = rgb_pre[y1 - int(self.opt.height):y1, x1:x1+int(self.opt.width)] 186 | rgb_pre = cv2.cvtColor(rgb_pre, cv2.COLOR_BGR2RGB) 187 | rgb_pre = torch.from_numpy(rgb_pre).permute(2, 0, 1) / 255. 188 | inputs[("color", 1, 0)] = rgb_pre 189 | 190 | 191 | rgb_next = file_next['rgb'] 192 | rgb_next = rgb_next[y1 - int(self.opt.height):y1, x1:x1+int(self.opt.width)] 193 | rgb_next = cv2.cvtColor(rgb_next, cv2.COLOR_BGR2RGB) 194 | rgb_next = torch.from_numpy(rgb_next).permute(2, 0, 1) / 255. 195 | inputs[("color", 2, 0)] = rgb_next 196 | 197 | 198 | pose_cur = file_cur['pose'] 199 | pose_cur = np.linalg.inv(pose_cur).astype('float32') 200 | inputs[("pose", 0)] = pose_cur 201 | 202 | pose_pre = file_pre['pose'] 203 | pose_pre = np.linalg.inv(pose_pre).astype('float32') 204 | inputs[("pose", 1)] = pose_pre 205 | 206 | pose_next = file_next['pose'] 207 | pose_next = np.linalg.inv(pose_next).astype('float32') 208 | inputs[("pose", 2)] = pose_next 209 | 210 | k_raw = file_cur['intrinsics'] 211 | # print('k_raw',k_raw) 212 | k_crop, inputs = self.get_k_ori_randomcrop(k_raw, inputs, x1, y1) 213 | inputs = self.get_K(k_crop, inputs) 214 | 215 | 216 | else: 217 | inputs[("depth_gt", 0, 0)] = inputs[("depth_gt", 0, 0)][:, 8:1928][None, :, :] 218 | 219 | rgb_cur = file_cur['rgb'] 220 | rgb_cur = rgb_cur[:, 8:1928] 221 | rgb_cur = cv2.cvtColor(rgb_cur, cv2.COLOR_BGR2RGB) 222 | rgb_cur = torch.from_numpy(rgb_cur).permute(2, 0, 1) / 255. 223 | inputs[("color", 0, 0)] = rgb_cur 224 | 225 | 226 | 227 | rgb_pre = file_pre['rgb'] 228 | rgb_pre = rgb_pre[:, 8:1928] 229 | rgb_pre = cv2.cvtColor(rgb_pre, cv2.COLOR_BGR2RGB) 230 | rgb_pre = torch.from_numpy(rgb_pre).permute(2, 0, 1) / 255. 
231 | inputs[("color", 1, 0)] = rgb_pre 232 | 233 | 234 | rgb_next = file_next['rgb'] 235 | rgb_next = rgb_next[:, 8:1928] 236 | rgb_next = cv2.cvtColor(rgb_next, cv2.COLOR_BGR2RGB) 237 | rgb_next = torch.from_numpy(rgb_next).permute(2, 0, 1) / 255. 238 | inputs[("color", 2, 0)] = rgb_next 239 | 240 | 241 | pose_cur = file_cur['pose'] 242 | pose_cur = np.linalg.inv(pose_cur).astype('float32') 243 | inputs[("pose", 0)] = pose_cur 244 | 245 | pose_pre = file_pre['pose'] 246 | pose_pre = np.linalg.inv(pose_pre).astype('float32') 247 | inputs[("pose", 1)] = pose_pre 248 | 249 | pose_next = file_next['pose'] 250 | pose_next = np.linalg.inv(pose_next).astype('float32') 251 | inputs[("pose", 2)] = pose_next 252 | 253 | k_raw = file_cur['intrinsics'] 254 | 255 | k_crop, inputs = self.get_k_ori_centercrop(k_raw, inputs) 256 | 257 | inputs = self.get_K_test(k_crop, inputs) 258 | 259 | inputs = self.compute_projection_matrix(inputs) 260 | 261 | 262 | inputs['num_frame'] = 3 263 | 264 | # for key, value in inputs.items(): 265 | # print(key, value.dtype) 266 | 267 | return inputs 268 | 269 | 270 | 271 | def get_K(self, K, inputs): 272 | inv_K = np.linalg.inv(K) 273 | K_pool = {} 274 | ho, wo = self.opt.height, self.opt.width 275 | for i in range(6): 276 | K_pool[(ho // 2**i, wo // 2**i)] = K.copy().astype('float32') 277 | K_pool[(ho // 2**i, wo // 2**i)][:2, :] /= 2**i 278 | 279 | inputs['K_pool'] = K_pool 280 | 281 | inputs[("inv_K_pool", 0)] = {} 282 | for k, v in K_pool.items(): 283 | K44 = np.eye(4) 284 | K44[:3, :3] = v 285 | inputs[("inv_K_pool", 0)][k] = np.linalg.inv(K44).astype('float32') 286 | 287 | inputs[("inv_K", 0)] = torch.from_numpy(inv_K.astype('float32')) 288 | 289 | inputs[("K", 0)] = torch.from_numpy(K.astype('float32')) 290 | 291 | return inputs 292 | 293 | def get_K_test(self, K, inputs): 294 | inv_K = np.linalg.inv(K) 295 | K_pool = {} 296 | ho, wo = self.opt.eval_height, self.opt.eval_width 297 | for i in range(6): 298 | K_pool[(ho // 2**i, wo // 2**i)] = K.copy().astype('float32') 299 | K_pool[(ho // 2**i, wo // 2**i)][:2, :] /= 2**i 300 | 301 | inputs['K_pool'] = K_pool 302 | 303 | inputs[("inv_K_pool", 0)] = {} 304 | for k, v in K_pool.items(): 305 | K44 = np.eye(4) 306 | K44[:3, :3] = v 307 | inputs[("inv_K_pool", 0)][k] = np.linalg.inv(K44).astype('float32') 308 | 309 | inputs[("inv_K", 0)] = torch.from_numpy(inv_K.astype('float32')) 310 | 311 | inputs[("K", 0)] = torch.from_numpy(K.astype('float32')) 312 | 313 | return inputs 314 | 315 | def compute_projection_matrix(self, inputs): 316 | for i in range(self.opt.num_frame): 317 | inputs[("proj", i)] = {} 318 | for k, v in inputs['K_pool'].items(): 319 | K44 = np.eye(4) 320 | K44[:3, :3] = v 321 | inputs[("proj", 322 | i)][k] = np.matmul(K44, inputs[("pose", 323 | i)]).astype('float32') 324 | return inputs 325 | -------------------------------------------------------------------------------- /datasets/__pycache__/DDAD.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junda24/AFNet/1d0416a2ec740519a8c464461962bebc238338bc/datasets/__pycache__/DDAD.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/DDAD_crop.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junda24/AFNet/1d0416a2ec740519a8c464461962bebc238338bc/datasets/__pycache__/DDAD_crop.cpython-37.pyc 
-------------------------------------------------------------------------------- /datasets/__pycache__/DDAD_forward.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junda24/AFNet/1d0416a2ec740519a8c464461962bebc238338bc/datasets/__pycache__/DDAD_forward.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/kitti.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junda24/AFNet/1d0416a2ec740519a8c464461962bebc238338bc/datasets/__pycache__/kitti.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/kitti.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junda24/AFNet/1d0416a2ec740519a8c464461962bebc238338bc/datasets/__pycache__/kitti.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/kitti_odometry.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junda24/AFNet/1d0416a2ec740519a8c464461962bebc238338bc/datasets/__pycache__/kitti_odometry.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/kitti_odometry.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junda24/AFNet/1d0416a2ec740519a8c464461962bebc238338bc/datasets/__pycache__/kitti_odometry.cpython-39.pyc -------------------------------------------------------------------------------- /datasets/kitti.py: -------------------------------------------------------------------------------- 1 | # dataloader for KITTI / when training & testing F-Net and MaGNet 2 | import os 3 | import random 4 | import glob 5 | 6 | import numpy as np 7 | import torch 8 | import torch.utils.data as data 9 | import torch.utils.data.distributed 10 | from PIL import Image 11 | 12 | from torch.utils.data import Dataset, DataLoader 13 | from torchvision import transforms 14 | 15 | import pykitti 16 | 17 | import cv2 18 | 19 | import utils 20 | import torch.nn.functional as F 21 | from utils import npy 22 | import json 23 | 24 | 25 | 26 | class DDAD_kitti(data.Dataset): 27 | def __init__(self, opt, is_train): 28 | super(DDAD_kitti, self).__init__() 29 | self.opt = opt 30 | self.is_train = is_train 31 | if self.is_train: 32 | with open("./data_split/kitti_eigen_train.txt", 'r') as f: 33 | self.filenames = f.readlines() 34 | else: 35 | with open("./data_split/kitti_eigen_test.txt", 'r') as f: 36 | self.filenames = f.readlines() 37 | 38 | # self.mode = mode 39 | self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 40 | self.dataset_path = self.opt.data_path 41 | 42 | # local window 43 | self.window_radius = int(2) 44 | self.n_views = int(2) 45 | self.frame_interval = self.window_radius // (self.n_views // 2) 46 | self.img_idx_center = self.n_views // 2 47 | 48 | # window_idx_list 49 | self.window_idx_list = list(range(-self.n_views // 2, (self.n_views // 2) + 1)) 50 | self.window_idx_list = [i * self.frame_interval for i in self.window_idx_list] 51 | 52 | # image resolution 53 | self.img_H = self.opt.height # 352 54 | self.img_W = self.opt.width # 1216 55 | 56 | def __len__(self): 57 
| return len(self.filenames) 58 | 59 | # get camera intrinscs 60 | def get_cam_intrinsics(self, p_data): 61 | raw_img_size = p_data.get_cam2(0).size 62 | raw_W = int(raw_img_size[0]) 63 | raw_H = int(raw_img_size[1]) 64 | 65 | top_margin = int(raw_H - 352) 66 | left_margin = int((raw_W - 1216) / 2) 67 | 68 | # original intrinsic matrix (4X4) 69 | IntM_ = p_data.calib.K_cam2 70 | 71 | # updated intrinsic matrix 72 | IntM = np.zeros((3, 3)) 73 | IntM[2, 2] = 1. 74 | IntM[0, 0] = IntM_[0, 0] 75 | IntM[1, 1] = IntM_[1, 1] 76 | IntM[0, 2] = (IntM_[0, 2] - left_margin) 77 | IntM[1, 2] = (IntM_[1, 2] - top_margin) 78 | 79 | IntM = IntM.astype(np.float32) 80 | return IntM 81 | 82 | def __getitem__(self, idx): 83 | inputs = {} 84 | date, drive, mode, img_idx = self.filenames[idx].split(' ') 85 | img_idx = int(img_idx) 86 | scene_name = '%s_drive_%s_sync' % (date, drive) 87 | 88 | # identify the neighbor views 89 | img_idx_list = [img_idx + i for i in self.window_idx_list] 90 | p_data = pykitti.raw(self.dataset_path + '/rawdata', date, drive, frames=img_idx_list) 91 | 92 | # cam intrinsics 93 | cam_intrins = self.get_cam_intrinsics(p_data) 94 | 95 | # color augmentation 96 | color_aug = False 97 | if self.is_train: 98 | if random.random() > 0.5: 99 | color_aug = True 100 | aug_gamma = random.uniform(0.9, 1.1) 101 | aug_brightness = random.uniform(0.9, 1.1) 102 | aug_colors = np.random.uniform(0.9, 1.1, size=3) 103 | 104 | # data array 105 | data_array = [] 106 | for i in range(self.n_views + 1): 107 | cur_idx = img_idx_list[i] 108 | 109 | # read img 110 | img_name = '%010d.png' % cur_idx 111 | img_path = self.dataset_path + '/rawdata/{}/{}/image_02/data/{}'.format(date, scene_name, img_name) 112 | img = Image.open(img_path).convert("RGB") 113 | 114 | # kitti benchmark crop 115 | height = img.height 116 | width = img.width 117 | top_margin = int(height - 352) 118 | left_margin = int((width - 1216) / 2) 119 | img = img.crop((left_margin, top_margin, left_margin + 1216, top_margin + 352)) 120 | 121 | # to tensor 122 | img = np.array(img).astype(np.float32) / 255.0 # (H, W, 3) 123 | img_ori = torch.from_numpy(img).permute(2, 0, 1) 124 | if color_aug: 125 | img = self.augment_image(img, aug_gamma, aug_brightness, aug_colors) 126 | img = torch.from_numpy(img).permute(2, 0, 1) # (3, H, W) 127 | img = self.normalize(img) 128 | 129 | # read dmap (only for the ref img) 130 | if i == self.img_idx_center: 131 | dmap_path = self.dataset_path + '/{}/{}/proj_depth/groundtruth/image_02/{}'.format(mode, scene_name, 132 | img_name) 133 | gt_dmap = Image.open(dmap_path).crop((left_margin, top_margin, left_margin + 1216, top_margin + 352)) 134 | gt_dmap = np.array(gt_dmap)[:, :, np.newaxis].astype(np.float32) # (H, W, 1) 135 | gt_dmap = gt_dmap / 256.0 136 | gt_dmap = torch.from_numpy(gt_dmap).permute(2, 0, 1) # (1, H, W) 137 | else: 138 | gt_dmap = 0.0 139 | 140 | # read extM 141 | pose = p_data.oxts[i].T_w_imu 142 | M_imu2cam = p_data.calib.T_cam2_imu 143 | extM = np.matmul(M_imu2cam, np.linalg.inv(pose)) 144 | extM = extM.astype('float32') 145 | 146 | data_dict = { 147 | 'img_ori': img_ori, 148 | 'img': img, 149 | 'gt_dmap': gt_dmap, 150 | 'extM': extM, 151 | 'scene_name': scene_name, 152 | 'img_idx': str(img_idx), 153 | } 154 | data_array.append(data_dict) 155 | 156 | inputs[("color", 0, 0)] = data_array[1]['img'] 157 | inputs[("img_ori", 0, 0)] = data_array[1]['img_ori'] 158 | 159 | inputs[("depth_gt", 0, 0)] = data_array[1]['gt_dmap'] 160 | inputs[("pose", 0)] = data_array[1]['extM'] 161 | 162 | inputs[("color", 
1, 0)] = data_array[0]['img'] 163 | inputs[("pose", 1)] = data_array[0]['extM'] 164 | inputs[("img_ori", 1, 0)] = data_array[0]['img_ori'] 165 | 166 | inputs[("color", 2, 0)] = data_array[2]['img'] 167 | inputs[("pose", 2)] = data_array[2]['extM'] 168 | inputs[("img_ori", 2, 0)] = data_array[2]['img_ori'] 169 | 170 | 171 | inputs = self.get_K(cam_intrins, inputs) 172 | 173 | inputs = self.compute_projection_matrix(inputs) 174 | 175 | inputs['num_frame'] = 3 176 | 177 | return inputs 178 | 179 | def augment_image(self, image, gamma, brightness, colors): 180 | # gamma augmentation 181 | image_aug = image ** gamma 182 | 183 | # brightness augmentation 184 | image_aug = image_aug * brightness 185 | 186 | # color augmentation 187 | white = np.ones((image.shape[0], image.shape[1])) 188 | color_image = np.stack([white * colors[i] for i in range(3)], axis=2) 189 | image_aug *= color_image 190 | image_aug = np.clip(image_aug, 0, 1) 191 | 192 | return image_aug 193 | 194 | def get_K(self, K, inputs): 195 | inv_K = np.linalg.inv(K) 196 | K_pool = {} 197 | ho, wo = self.opt.height, self.opt.width 198 | for i in range(6): 199 | K_pool[(ho // 2**i, wo // 2**i)] = K.copy().astype('float32') 200 | K_pool[(ho // 2**i, wo // 2**i)][:2, :] /= 2**i 201 | 202 | inputs['K_pool'] = K_pool 203 | 204 | inputs[("inv_K_pool", 0)] = {} 205 | for k, v in K_pool.items(): 206 | K44 = np.eye(4) 207 | K44[:3, :3] = v 208 | inputs[("inv_K_pool", 0)][k] = np.linalg.inv(K44).astype('float32') 209 | 210 | inputs[("inv_K", 0)] = torch.from_numpy(inv_K.astype('float32')) 211 | 212 | inputs[("K", 0)] = torch.from_numpy(K.astype('float32')) 213 | 214 | return inputs 215 | 216 | 217 | def compute_projection_matrix(self, inputs): 218 | for i in range(3): 219 | inputs[("proj", i)] = {} 220 | for k, v in inputs['K_pool'].items(): 221 | K44 = np.eye(4) 222 | K44[:3, :3] = v 223 | inputs[("proj", 224 | i)][k] = np.matmul(K44, inputs[("pose", 225 | i)]).astype('float32') 226 | return inputs -------------------------------------------------------------------------------- /eval_ddad.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import sys 5 | import cv2 6 | import torch 7 | import torch.nn as nn 8 | from options import MVS2DOptions, EvalCfg 9 | import networks 10 | from torch.utils.data import DataLoader 11 | from datasets.DDAD import DDAD 12 | import torch.nn.functional as F 13 | from utils import * 14 | from hybrid_evaluate_depth import evaluate_depth_maps, compute_errors,compute_errors1,compute_errors_perimage 15 | import os 16 | os.environ['CUDA_VISIBLE_DEVICES'] = '3' 17 | 18 | 19 | def to_gpu(inputs, keys=None): 20 | if keys == None: 21 | keys = inputs.keys() 22 | for key in keys: 23 | if key not in inputs: 24 | continue 25 | ipt = inputs[key] 26 | if type(ipt) == torch.Tensor: 27 | inputs[key] = ipt.cuda() 28 | elif type(ipt) == list and type(ipt[0]) == torch.Tensor: 29 | inputs[key] = [ 30 | x.cuda() for x in ipt 31 | ] 32 | elif type(ipt) == dict: 33 | for k in ipt.keys(): 34 | if type(ipt[k]) == torch.Tensor: 35 | ipt[k] = ipt[k].cuda() 36 | 37 | 38 | options = MVS2DOptions() 39 | opts = options.parse() 40 | 41 | opts.cfg = "/configs/DDAD.conf" 42 | dataset = DDAD(opts, False) 43 | data_loader = DataLoader(dataset, 44 | 1, 45 | shuffle=False, 46 | num_workers=4, 47 | pin_memory=True, 48 | drop_last=False, 49 | sampler=None) 50 | model = networks.MVS2D(opt=opts).cuda() 51 | pretrained_dict = 
torch.load("pretrained_model/DDAD/model_DDAD.pth") 52 | 53 | model.load_state_dict(pretrained_dict) 54 | model.eval() 55 | 56 | index = 0 57 | 58 | total_result_sum = {} 59 | total_result_count = {} 60 | min_depth = opts.EVAL_MIN_DEPTH 61 | max_depth = opts.EVAL_MAX_DEPTH 62 | 63 | with torch.no_grad(): 64 | for batch_idx, inputs in enumerate(data_loader): 65 | to_gpu(inputs) 66 | 67 | imgs, proj_mats, pose_mats = [], [], [] 68 | for i in range(inputs['num_frame'][0].item()): 69 | imgs.append(inputs[('color', i, 0)]) 70 | proj_mats.append(inputs[('proj', i)]) 71 | pose_mats.append(inputs[('pose', i)]) 72 | 73 | depth_gt = inputs[("depth_gt", 0, 0)] 74 | depth_gt_np = depth_gt.cpu().detach().numpy().squeeze() 75 | 76 | mask = (depth_gt_np>min_depth) & (depth_gt_np < max_depth) 77 | outputs = model(imgs[0], imgs[1:], pose_mats[0], pose_mats[1:], 78 | inputs[('inv_K_pool', 0)]) 79 | depth_pred_1 = outputs[('depth_pred', 0)] 80 | depth_pred_2 = outputs[('depth_pred_2', 0)] 81 | 82 | depth_pred_2_np = depth_pred_2.cpu().detach().numpy().squeeze() 83 | depth_pred_1_np = depth_pred_1.cpu().detach().numpy().squeeze() 84 | 85 | error_temp = compute_errors_perimage(depth_gt_np[mask], depth_pred_1_np[mask], min_depth, max_depth) 86 | error_temp_2_ = compute_errors_perimage(depth_gt_np[mask], depth_pred_2_np[mask], min_depth, max_depth) 87 | print('cur',index, error_temp) 88 | index = index + 1 89 | error_temp_2 = {} 90 | for k,v in error_temp_2_.items(): 91 | new_k = k + '_2' 92 | error_temp_2[new_k] = error_temp_2_[k] 93 | 94 | error_temp_all = {} 95 | error_temp_all.update(error_temp) 96 | error_temp_all.update(error_temp_2) 97 | 98 | for k,v in error_temp_all.items(): 99 | if not isinstance(v,float): 100 | v=v.items() 101 | if k in total_result_sum: 102 | total_result_sum[k] = total_result_sum[k] + v 103 | else: 104 | total_result_sum[k] = v 105 | 106 | for k in total_result_sum.keys(): 107 | total_result_count[k] = total_result_sum['valid_number'] 108 | 109 | print('final####################################') 110 | 111 | for k in total_result_sum.keys(): 112 | this_tensor = torch.tensor([total_result_sum[k], total_result_count[k]]) 113 | this_list = [this_tensor] 114 | this_tensor = this_list[0].detach().cpu().numpy() 115 | reduce_sum = this_tensor[0].item() 116 | reduce_count = this_tensor[1].item() 117 | reduce_mean = reduce_sum / reduce_count 118 | print(k, reduce_mean) 119 | 120 | -------------------------------------------------------------------------------- /eval_kitti.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import sys 5 | import cv2 6 | import torch 7 | import torch.nn as nn 8 | from options import MVS2DOptions, EvalCfg 9 | import networks 10 | from torch.utils.data import DataLoader 11 | from datasets.kitti import DDAD_kitti 12 | from hybrid_evaluate_depth import evaluate_depth_maps, compute_errors,compute_errors1,compute_errors_perimage 13 | import torch.nn.functional as F 14 | import os 15 | os.environ['CUDA_VISIBLE_DEVICES'] = '3' 16 | 17 | def to_gpu(inputs, keys=None): 18 | if keys == None: 19 | keys = inputs.keys() 20 | for key in keys: 21 | if key not in inputs: 22 | continue 23 | ipt = inputs[key] 24 | if type(ipt) == torch.Tensor: 25 | inputs[key] = ipt.cuda() 26 | elif type(ipt) == list and type(ipt[0]) == torch.Tensor: 27 | inputs[key] = [ 28 | x.cuda() for x in ipt 29 | ] 30 | elif type(ipt) == dict: 31 | for k in ipt.keys(): 32 | if 
type(ipt[k]) == torch.Tensor: 33 | ipt[k] = ipt[k].cuda() 34 | 35 | 36 | options = MVS2DOptions() 37 | opts = options.parse() 38 | 39 | # opts.cfg = "./configs/kitti.conf" 40 | dataset = DDAD_kitti(opts, False) 41 | data_loader = DataLoader(dataset, 42 | 1, 43 | shuffle=False, 44 | num_workers=4, 45 | pin_memory=True, 46 | drop_last=False, 47 | sampler=None) 48 | model = networks.MVS2D(opt=opts).cuda() 49 | pretrained_dict = torch.load("pretrained_model/kitti/model_kitti.pth") 50 | 51 | model.load_state_dict(pretrained_dict) 52 | model.eval() 53 | 54 | min_depth = opts.EVAL_MIN_DEPTH 55 | max_depth = opts.EVAL_MAX_DEPTH 56 | 57 | index = 0 58 | total_result_sum = {} 59 | total_result_count = {} 60 | with torch.no_grad(): 61 | for batch_idx, inputs in enumerate(data_loader): 62 | to_gpu(inputs) 63 | 64 | imgs, proj_mats, pose_mats = [], [], [] 65 | for i in range(inputs['num_frame'][0].item()): 66 | imgs.append(inputs[('color', i, 0)]) 67 | proj_mats.append(inputs[('proj', i)]) 68 | pose_mats.append(inputs[('pose', i)]) 69 | 70 | depth_gt = inputs[("depth_gt", 0, 0)] 71 | depth_gt_np = depth_gt.cpu().detach().numpy().squeeze() 72 | mask = (depth_gt_np>min_depth) & (depth_gt_np < max_depth) 73 | 74 | if np.sum(mask.astype(np.float32)) > 5: 75 | 76 | outputs = model(imgs[0], imgs[1:], pose_mats[0], pose_mats[1:], 77 | inputs[('inv_K_pool', 0)]) 78 | depth_pred_1_tensor = outputs[('depth_pred', 0)] 79 | depth_pred_2_tensor = outputs[('depth_pred_2', 0)] 80 | 81 | depth_pred_2 = depth_pred_2_tensor.cpu().detach().numpy().squeeze() 82 | depth_pred_1 = depth_pred_1_tensor.cpu().detach().numpy().squeeze() 83 | 84 | error_temp = compute_errors_perimage(depth_gt_np[mask], depth_pred_1[mask], min_depth, max_depth) 85 | error_temp_2_ = compute_errors_perimage(depth_gt_np[mask], depth_pred_2[mask], min_depth, max_depth) 86 | print('cur',index, error_temp) 87 | index = index + 1 88 | error_temp_2 = {} 89 | for k,v in error_temp_2_.items(): 90 | new_k = k + '_2' 91 | error_temp_2[new_k] = error_temp_2_[k] 92 | 93 | error_temp_all = {} 94 | error_temp_all.update(error_temp) 95 | error_temp_all.update(error_temp_2) 96 | 97 | for k,v in error_temp_all.items(): 98 | if not isinstance(v,float): 99 | v=v.items() 100 | if k in total_result_sum: 101 | total_result_sum[k] = total_result_sum[k] + v 102 | else: 103 | total_result_sum[k] = v 104 | 105 | for k in total_result_sum.keys(): 106 | total_result_count[k] = total_result_sum['valid_number'] 107 | 108 | print('final####################################') 109 | 110 | for k in total_result_sum.keys(): 111 | this_tensor = torch.tensor([total_result_sum[k], total_result_count[k]]) 112 | this_list = [this_tensor] 113 | this_tensor = this_list[0].detach().cpu().numpy() 114 | reduce_sum = this_tensor[0].item() 115 | reduce_count = this_tensor[1].item() 116 | reduce_mean = reduce_sum / reduce_count 117 | print(k, reduce_mean) 118 | 119 | -------------------------------------------------------------------------------- /generate_dynamic_mask.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | import os 3 | import random 4 | import numpy as np 5 | import copy 6 | from PIL import Image # using pillow-simd for increased speed 7 | from time import time 8 | import torch 9 | import torch.utils.data as data 10 | from torchvision import transforms 11 | import cv2 12 | 13 | cv2.setNumThreads(0) 14 | import glob 15 | import utils 16 | import torch.nn.functional as F 17 | from 
utils import npy 18 | import json 19 | from mmdet.apis import init_detector, inference_detector 20 | import mmcv 21 | from skimage.metrics import structural_similarity 22 | 23 | import os 24 | # os.environ['CUDA_VISIBLE_DEVICES'] = '3' 25 | 26 | def to_gpu(inputs, keys=None): 27 | if keys == None: 28 | keys = inputs.keys() 29 | for key in keys: 30 | if key not in inputs: 31 | continue 32 | ipt = inputs[key] 33 | if type(ipt) == torch.Tensor: 34 | inputs[key] = ipt.cuda() 35 | elif type(ipt) == list and type(ipt[0]) == torch.Tensor: 36 | inputs[key] = [ 37 | x.cuda() for x in ipt 38 | ] 39 | elif type(ipt) == dict: 40 | for k in ipt.keys(): 41 | if type(ipt[k]) == torch.Tensor: 42 | ipt[k] = ipt[k].cuda() 43 | 44 | def homo_warping_depth(src_fea, src_proj, ref_proj, depth_values): 45 | # src_fea: [B, C, H, W] 46 | # src_proj: [B, 4, 4] 47 | # ref_proj: [B, 4, 4] 48 | # depth_values: [B, Ndepth, H, W] 49 | # out: [B, C, Ndepth, H, W] 50 | batch, channels = src_fea.shape[0], src_fea.shape[1] 51 | num_depth = depth_values.shape[1] 52 | #height, width = src_fea.shape[2], src_fea.shape[3] 53 | h_src, w_src = src_fea.shape[2], src_fea.shape[3] 54 | h_ref, w_ref = depth_values.shape[2], depth_values.shape[3] 55 | 56 | with torch.no_grad(): 57 | 58 | proj = torch.matmul(src_proj, torch.inverse(ref_proj)) 59 | rot = proj[:, :3, :3] # [B,3,3] 60 | trans = proj[:, :3, 3:4] # [B,3,1] 61 | 62 | 63 | y, x = torch.meshgrid([torch.arange(0, h_ref, dtype=torch.float32, device=src_fea.device), 64 | torch.arange(0, w_ref, dtype=torch.float32, device=src_fea.device)]) 65 | y, x = y.contiguous(), x.contiguous() 66 | y, x = y.view(h_ref * w_ref), x.view(h_ref * w_ref) 67 | xyz = torch.stack((x, y, torch.ones_like(x))) # [3, H*W] 68 | xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1) # [B, 3, H*W] 69 | 70 | rot_xyz = torch.matmul(rot, xyz) 71 | print(rot_xyz.shape) 72 | print(depth_values.shape) 73 | rot_depth_xyz = rot_xyz * depth_values.view(batch, 1, -1) 74 | 75 | proj_xyz = rot_depth_xyz + trans.view(batch,3,1) 76 | 77 | proj_xy = proj_xyz[:, :2, :] / proj_xyz[:, 2:3, :] # [B, 2, Ndepth, H*W] 78 | z = proj_xyz[:, 2:3, :].view(batch, h_ref, w_ref) 79 | proj_x_normalized = proj_xy[:, 0, :] / ((w_src - 1) / 2.0) - 1 80 | proj_y_normalized = proj_xy[:, 1, :] / ((h_src - 1) / 2.0) - 1 81 | X_mask = ((proj_x_normalized > 1)+(proj_x_normalized < -1)).detach() 82 | proj_x_normalized[X_mask] = 2 # make sure that no point in warped image is a combinaison of im and gray 83 | Y_mask = ((proj_y_normalized > 1)+(proj_y_normalized < -1)).detach() 84 | proj_y_normalized[Y_mask] = 2 85 | proj_xy = torch.stack((proj_x_normalized, proj_y_normalized), dim=2) # [B, Ndepth, H*W, 2] 86 | grid = proj_xy 87 | proj_mask = ((X_mask + Y_mask) > 0).view(batch, num_depth, h_ref, w_ref) 88 | proj_mask = (proj_mask + (z <= 0)) > 0 89 | 90 | warped_src_fea = F.grid_sample(src_fea, grid.view(batch, h_ref, w_ref, 2), mode='bilinear', 91 | padding_mode='zeros', align_corners=True) 92 | 93 | warped_src_fea = warped_src_fea.view(batch, channels, num_depth, h_ref, w_ref) 94 | 95 | #return warped_src_fea , proj_mask 96 | return warped_src_fea 97 | 98 | def main(): 99 | json_path = "/home/cjd/tmp/DDAD_video.json" 100 | data_path_ori = '/data/cjd/ddad/ddad_train_val/' 101 | data_path_root = '/data/cjd/ddad/my_ddad/' 102 | data_path = os.path.join(data_path_root, 'val/') 103 | f = open(json_path, 'r') 104 | content_all = f.read() 105 | json_list_all = json.loads(content_all) 106 | f.close() 107 | file_names = json_list_all["val"] 108 | file_names = [x 
for x in file_names if 'timestamp' in x.keys() and 'timestamp_back' in x.keys() and 'timestamp_forward' in x.keys() and x['Camera'] == 'CAMERA_01'] 109 | 110 | model = init_detector("/home/cjd/mmdetection3d-master/configs/nuimages/htc_x101_64x4d_fpn_dconv_c3-c5_coco-20e_16x1_20e_nuim.py", "/home/cjd/MVS2D/htc_x101_64x4d_fpn_dconv_c3-c5_coco-20e_16x1_20e_nuim_20201008_211222-0b16ac4b.pth", device = 'cuda:3') 111 | # print(model.CLASSES) 112 | class_all = model.CLASSES 113 | print(class_all) 114 | lenth = len(file_names) 115 | print(lenth) 116 | for index in range(lenth): 117 | inputs = {} 118 | # cur_img_path = data_path_ori + str(file_names[index]['video_num']) + '/rgb/' + file_names[index]['Camera'] +'/'+ str(file_names[index]['timestamp']) + '.png' 119 | cur_npz_path = data_path + str(file_names[index]['timestamp']) + '_' + file_names[index]['Camera'] + '.npz' 120 | pre_npz_path = data_path + str(file_names[index]['timestamp_back']) + '_' + file_names[index]['Camera'] + '.npz' 121 | next_npz_path = data_path + str(file_names[index]['timestamp_forward']) + '_' + file_names[index]['Camera'] + '.npz' 122 | 123 | file_cur = np.load(cur_npz_path) 124 | file_pre = np.load(pre_npz_path) 125 | file_next = np.load(next_npz_path) 126 | 127 | 128 | depth_cur_gt = file_cur['depth'] 129 | depth_cur_gt = np.array(depth_cur_gt).astype(np.float32) 130 | 131 | inputs[("depth_gt", 0, 0)] = torch.from_numpy(depth_cur_gt) 132 | 133 | rgb_cur = file_cur['rgb'] 134 | # print(rgb_cur.shape) 135 | rgb_cur_input = cv2.cvtColor(rgb_cur, cv2.COLOR_BGR2RGB) 136 | rgb_cur_input = torch.from_numpy(rgb_cur_input).permute(2, 0, 1) / 255. 137 | inputs[("color", 0, 0)] = rgb_cur_input 138 | # cv2.imwrite('img.png', rgb_cur) 139 | pose_cur = file_cur['pose'] 140 | pose_cur = np.linalg.inv(pose_cur).astype('float32') 141 | inputs[("pose", 0)] = pose_cur 142 | rgb_pre = file_pre['rgb'] 143 | rgb_pre_input = cv2.cvtColor(rgb_pre, cv2.COLOR_BGR2RGB) 144 | rgb_pre_input = torch.from_numpy(rgb_pre_input).permute(2, 0, 1) / 255. 145 | inputs[("color", 1, 0)] = rgb_pre_input 146 | pose_pre = file_pre['pose'] 147 | pose_pre = np.linalg.inv(pose_pre).astype('float32') 148 | inputs[("pose", 1)] = pose_pre 149 | rgb_next = file_next['rgb'] 150 | rgb_next_input = cv2.cvtColor(rgb_next, cv2.COLOR_BGR2RGB) 151 | rgb_next_input = torch.from_numpy(rgb_next_input).permute(2, 0, 1) / 255. 
152 | inputs[("color", 2, 0)] = rgb_next_input 153 | pose_next = file_next['pose'] 154 | pose_next = np.linalg.inv(pose_next).astype('float32') 155 | inputs[("pose", 2)] = pose_next 156 | K = file_cur['intrinsics'] 157 | 158 | inv_K = np.linalg.inv(K) 159 | 160 | K_pool = {} 161 | ho, wo, _ = rgb_cur.shape 162 | for i in range(6): 163 | K_pool[(ho // 2**i, wo // 2**i)] = K.copy().astype('float32') 164 | K_pool[(ho // 2**i, wo // 2**i)][:2, :] /= 2**i 165 | 166 | inputs['K_pool'] = K_pool 167 | 168 | inputs[("inv_K_pool", 0)] = {} 169 | for k, v in K_pool.items(): 170 | K44 = np.eye(4) 171 | K44[:3, :3] = v 172 | inputs[("inv_K_pool", 0)][k] = np.linalg.inv(K44).astype('float32') 173 | 174 | inputs[("inv_K", 0)] = torch.from_numpy(inv_K.astype('float32')) 175 | 176 | inputs[("K", 0)] = torch.from_numpy(K.astype('float32')) 177 | 178 | for i in range(3): 179 | inputs[("proj", i)] = {} 180 | for k, v in inputs['K_pool'].items(): 181 | K44 = np.eye(4) 182 | K44[:3, :3] = v 183 | inputs[("proj", 184 | i)][k] = torch.from_numpy(np.matmul(K44, inputs[("pose", 185 | i)]).astype('float32')) 186 | to_gpu(inputs) 187 | h, w, _ = rgb_cur.shape 188 | imgs, proj_mats, pose_mats = [], [], [] 189 | for i in range(3): 190 | imgs.append(inputs[('color', i, 0)]) 191 | proj_mats.append(inputs[('proj', i)]) 192 | pose_mats.append(inputs[('pose', i)]) 193 | 194 | depth_gt = inputs[("depth_gt", 0, 0)][None,None,:,:] 195 | img0 = imgs[0][None,:,:,:] 196 | img1 = imgs[1][None,:,:,:] 197 | 198 | proj_mats_0 = proj_mats[0][(h, w)][None,:,:] 199 | 200 | proj_mats_1 = proj_mats[1][(h, w)][None,:,:] 201 | # print(img1.shape) 202 | # print(proj_mats_0.shape) 203 | # print(depth_gt.shape) 204 | warped_img0 = homo_warping_depth(img1, proj_mats_1, proj_mats_0, depth_gt) 205 | img0_np = img0[0].cpu().detach().numpy().squeeze().transpose(1,2,0) 206 | warped_img0_np = warped_img0[0].cpu().detach().numpy().squeeze().transpose(1,2,0) 207 | depth_gt_np = depth_gt.cpu().detach().numpy().squeeze() 208 | # img0_np = (img0_np / img0_np.max() * 255).astype(np.uint8) 209 | # cv2.imwrite('img0.png', img0_np) 210 | 211 | # warped_img0_np = (warped_img0_np / warped_img0_np.max() * 255).astype(np.uint8) 212 | # cv2.imwrite('warped_img.png', warped_img0_np) 213 | # print(rgb_cur.shape) 214 | # img = mmcv.imread(cur_img_path) 215 | result = inference_detector(model, rgb_cur) 216 | 217 | index_list = [0,1,2,3,4,5,6,7] 218 | 219 | mask_all = np.zeros_like(rgb_cur[:,:,0], dtype = bool) 220 | for index_ in index_list: 221 | object_number = len(result[1][index_]) 222 | for index_object in range(object_number): 223 | mask_now = result[1][index_][index_object] & (depth_gt_np > 0) 224 | # diff_now = warped_img0_np[mask_now] - img0_np[mask_now] 225 | if np.sum(mask_now.astype(float)) > 50: 226 | ssim = structural_similarity(img0_np[mask_now], warped_img0_np[mask_now], multichannel = True) 227 | else: 228 | ssim = 0.3 229 | print(ssim) 230 | if result[0][index_][index_object][4] > 0.5 and ssim < 0.75: 231 | # print(result[0][index_][index_object]) 232 | mask_all = mask_all + result[1][index_][index_object] 233 | 234 | # # car_number = len(result[1][0]) 235 | 236 | # # for car_index in range(car_number): 237 | # # car_mask = car_mask + result[1][0][car_index] 238 | mask_all_vis = mask_all.astype(np.float) 239 | mask_all_vis = (mask_all_vis*255).astype(np.uint8) 240 | 241 | save_path = cur_npz_path.replace('.npz', '_dynamic.npz') 242 | print(save_path) 243 | 244 | np.savez(save_path, mask_all) 245 | 246 | # cv2.imwrite('seg_mask.png',mask_all_vis) 247 
| # a = input('print something') 248 | # print(a) 249 | 250 | 251 | if __name__ == "__main__": 252 | main() -------------------------------------------------------------------------------- /hybrid_evaluate_depth.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | from open3d import * 3 | import os 4 | import cv2 5 | import numpy as np 6 | import torch 7 | from utils import write_ply, backproject_depth, v, npy, Thres_metrics_np 8 | 9 | cv2.setNumThreads( 10 | 0) # This speeds up evaluation 5x on our unix systems (OpenCV 3.3.1) 11 | 12 | def compute_errors_perimage(gt, pred, min_depth, max_depth): 13 | valid_mask = (gt > min_depth) & (gt < max_depth) 14 | epe = np.mean(np.abs(gt[valid_mask] - pred[valid_mask])) 15 | abs_rel = np.mean(np.abs(gt[valid_mask] - pred[valid_mask]) / gt[valid_mask]) 16 | sq_rel = np.mean(((gt[valid_mask] - pred[valid_mask])**2) / gt[valid_mask]) 17 | 18 | rmse = (gt - pred)**2 19 | rmse = np.sqrt(rmse.mean()) 20 | 21 | rmse_log = (np.log(gt) - np.log(pred))**2 22 | rmse_log = np.sqrt(rmse_log.mean()) 23 | 24 | thresh = np.maximum((gt / pred), (pred / gt)) 25 | a1 = (thresh < 1.25).mean() 26 | a2 = (thresh < 1.25**2).mean() 27 | a3 = (thresh < 1.25**3).mean() 28 | log10 = np.mean(np.abs(np.log10(pred) - np.log10(gt))) 29 | 30 | return { 31 | 'abs_rel':abs_rel.item(), 32 | 'sq_rel':sq_rel.item(), 33 | 'rmse':rmse.item(), 34 | 'rmse_log':rmse_log.item(), 35 | 'a1':a1.item(), 36 | 'a2':a2.item(), 37 | 'a3':a3.item(), 38 | 'log10':log10.item(), 39 | 'valid_number':1.0, 40 | 'abs_diff':epe.item() 41 | } 42 | 43 | def compute_errors(gt, pred, disable_median_scaling, min_depth, max_depth, 44 | interval): 45 | """Computation of error metrics between predicted and ground truth depths 46 | """ 47 | # if not disable_median_scaling: 48 | # ratio = np.median(gt) / np.median(pred) 49 | # pred *= ratio 50 | 51 | # pred[pred < min_depth] = min_depth 52 | # pred[pred > max_depth] = max_depth 53 | mask = np.logical_and(gt > min_depth, gt < max_depth) 54 | 55 | thresh = np.maximum((gt / pred), (pred / gt)) 56 | a1 = (thresh < 1.25).mean() 57 | a2 = (thresh < 1.25**2).mean() 58 | a3 = (thresh < 1.25**3).mean() 59 | 60 | rmse = (gt - pred)**2 61 | rmse = np.sqrt(rmse.mean()) 62 | 63 | rmse_log = (np.log(gt) - np.log(pred))**2 64 | rmse_log = np.sqrt(rmse_log.mean()) 65 | 66 | abs_rel = np.mean(np.abs(gt[mask] - pred[mask]) / gt[mask]) 67 | print('1',abs_rel) 68 | abs_rel_2 = np.sum(np.abs(gt[mask] - pred[mask]) / gt[mask])/np.sum(mask.astype(np.float32)) 69 | 70 | print('2', abs_rel_2) 71 | abs_diff = np.mean(np.abs(gt - pred)) 72 | # abs_diff_median = np.median(np.abs(gt - pred)) 73 | 74 | sq_rel = np.mean(((gt - pred)**2) / gt) 75 | log10 = np.mean(np.abs(np.log10(pred) - np.log10(gt))) 76 | # mask = np.ones_like(pred) 77 | # thre1 = Thres_metrics_np(pred, gt, mask, 1.0, 0.2) 78 | # thre3 = Thres_metrics_np(pred, gt, mask, 1.0, 0.5) 79 | # thre5 = Thres_metrics_np(pred, gt, mask, 1.0, 1.0) 80 | 81 | result = {} 82 | result['abs_rel'] = abs_rel 83 | result['sq_rel'] = sq_rel 84 | result['rmse'] = rmse 85 | result['rmse_log'] = rmse_log 86 | result['log10'] = log10 87 | result['a1'] = a1 88 | result['a2'] = a2 89 | result['a3'] = a3 90 | result['abs_diff'] = abs_diff 91 | result['total_count'] = 1.0 92 | 93 | return result 94 | 95 | def compute_errors1(gt, pred, disable_median_scaling, min_depth, max_depth, 96 | interval): 97 | """Computation of error metrics between predicted and 
ground truth depths 98 | """ 99 | # if not disable_median_scaling: 100 | # ratio = np.median(gt) / np.median(pred) 101 | # pred *= ratio 102 | 103 | # pred[pred < min_depth] = min_depth 104 | # pred[pred > max_depth] = max_depth 105 | 106 | thresh = np.maximum((gt / pred), (pred / gt)) 107 | a1 = (thresh < 1.25).mean() 108 | a2 = (thresh < 1.25**2).mean() 109 | a3 = (thresh < 1.25**3).mean() 110 | 111 | rmse = (gt - pred)**2 112 | rmse = np.sqrt(rmse.mean()) 113 | 114 | rmse_log = (np.log(gt) - np.log(pred))**2 115 | rmse_log = np.sqrt(rmse_log.mean()) 116 | 117 | abs_rel = np.mean(np.abs(gt - pred) / gt) 118 | abs_diff = np.mean(np.abs(gt - pred)) 119 | # abs_diff_median = np.median(np.abs(gt - pred)) 120 | 121 | sq_rel = np.mean(((gt - pred)**2) / gt) 122 | log10 = np.mean(np.abs(np.log10(pred) - np.log10(gt))) 123 | # mask = np.ones_like(pred) 124 | # thre1 = Thres_metrics_np(pred, gt, mask, 1.0, 0.2) 125 | # thre3 = Thres_metrics_np(pred, gt, mask, 1.0, 0.5) 126 | # thre5 = Thres_metrics_np(pred, gt, mask, 1.0, 1.0) 127 | 128 | result = {} 129 | result['abs_rel'] = abs_rel 130 | result['sq_rel'] = sq_rel 131 | result['rmse'] = rmse 132 | result['rmse_log'] = rmse_log 133 | result['log10'] = log10 134 | result['a1'] = a1 135 | result['a2'] = a2 136 | result['a3'] = a3 137 | result['abs_diff'] = abs_diff 138 | result['total_count'] = 1.0 139 | 140 | return result 141 | 142 | 143 | # return abs_rel, sq_rel, log10, rmse, rmse_log, a1, a2, a3, abs_diff, abs_diff_median, thre1, thre3, thre5 144 | 145 | 146 | def evaluate_depth_maps(results, config, do_print=False): 147 | errors = [] 148 | 149 | if not os.path.exists(config.save_dir): 150 | os.makedirs(config.save_dir) 151 | 152 | print('eval against gt depth map of size: %sx%d' % 153 | (results[0][1].shape[0], results[0][1].shape[1])) 154 | for i in range(len(results)): 155 | if i % 100 == 0: 156 | print('evaluation : %d/%d' % (i, len(results))) 157 | 158 | gt_depth = results[i][1] 159 | gt_height, gt_width = gt_depth.shape[:2] 160 | pred_depth = results[i][0] 161 | filename = results[i][2] 162 | inv_K = results[i][3] 163 | if gt_width != pred_depth.shape[1] or gt_height != pred_depth.shape[0]: 164 | pred_depth = cv2.resize(pred_depth, (gt_width, gt_height), 165 | interpolation=cv2.INTER_NEAREST) 166 | mask = np.logical_and(gt_depth > config.MIN_DEPTH, 167 | gt_depth < config.MAX_DEPTH) 168 | if not mask.sum(): 169 | continue 170 | 171 | ind = np.where(mask.flatten())[0] 172 | if config.vis: 173 | cam_points = backproject_depth(pred_depth, inv_K, mask=False) 174 | cam_points_gt = backproject_depth(gt_depth, inv_K, mask=False) 175 | write_ply('%s/%s_pred.ply' % (config.save_dir, filename), 176 | cam_points[ind]) 177 | write_ply('%s/%s_gt.ply' % (config.save_dir, filename), 178 | cam_points_gt[ind]) 179 | 180 | dataset = filename.split('_')[0] 181 | interval = (935 - 425) / (128 - 1) # Interval value used by MVSNet 182 | errors.append( 183 | (compute_errors(gt_depth[mask], pred_depth[mask], 184 | config.disable_median_scaling, config.MIN_DEPTH, 185 | config.MAX_DEPTH, interval), dataset, filename)) 186 | 187 | with open('%s/errors.txt' % (config.save_dir), 'w') as f: 188 | for x, _, fID in errors: 189 | tex = fID + ' ' + ' '.join(['%.3f' % y for y in x]) 190 | f.write(tex + '\n') 191 | 192 | np.save('%s/error.npy' % config.save_dir, errors) 193 | results = {} 194 | all_errors = [x[0] for x in errors] 195 | 196 | print(f"total example evaluated: {len(all_errors)}") 197 | all_mean_errors = np.array(all_errors).mean(0) 198 | if do_print: 199 | 
print("\n all") 200 | print("\n " + 201 | ("{:>8} | " * 202 | 13).format("abs_rel", "sq_rel", "log10", "rmse", "rmse_log", 203 | "a1", "a2", "a3", "abs_diff", "abs_diff_median")) 204 | print(("&{: 8.3f} " * 13).format(*all_mean_errors.tolist()) + "\\\\") 205 | 206 | error_names = [ 207 | "abs_rel", "sq_rel", "log10", "rmse", "rmse_log", "a1", "a2", "a3", 208 | "abs_diff", "abs_diff_median", "thre1", "thre3", "thre5" 209 | ] 210 | results['depth'] = {'error_names': error_names, 'errors': all_mean_errors} 211 | 212 | errors_per_dataset = {} 213 | for x in errors: 214 | key = x[1] 215 | if key not in errors_per_dataset: 216 | errors_per_dataset[key] = [x[0]] 217 | else: 218 | errors_per_dataset[key].append(x[0]) 219 | if config.print_per_dataset_stats: 220 | for key in errors_per_dataset.keys(): 221 | errors_ = errors_per_dataset[key] 222 | mean_errors = np.array(errors_).mean(0) 223 | 224 | print("\n dataset %s: %d" % (key, len(errors_))) 225 | print("\n " + 226 | ("{:>8} | " * 227 | 13).format("abs_rel", "sq_rel", "log10", "rmse", "rmse_log", 228 | "a1", "a2", "a3", "abs_diff", "abs_diff_median", 229 | "thre1", "thre3", "thre5")) 230 | print(("&{: 8.3f} " * 13).format(*mean_errors.tolist()) + "\\\\") 231 | 232 | print("\n-> Done!") 233 | return results 234 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- 1 | # from .mvs2d import MVS2D 2 | from .AFNet import MVS2D 3 | # from .AFNet_main import MVS2D 4 | # from .AFNet_mobile import MVS2D 5 | # from .AFNet_efficient import MVS2D -------------------------------------------------------------------------------- /networks/__pycache__/AFNet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junda24/AFNet/1d0416a2ec740519a8c464461962bebc238338bc/networks/__pycache__/AFNet.cpython-37.pyc -------------------------------------------------------------------------------- /networks/__pycache__/AFNet.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junda24/AFNet/1d0416a2ec740519a8c464461962bebc238338bc/networks/__pycache__/AFNet.cpython-39.pyc -------------------------------------------------------------------------------- /networks/__pycache__/AFNet_efficient.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junda24/AFNet/1d0416a2ec740519a8c464461962bebc238338bc/networks/__pycache__/AFNet_efficient.cpython-37.pyc -------------------------------------------------------------------------------- /networks/__pycache__/AFNet_main.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junda24/AFNet/1d0416a2ec740519a8c464461962bebc238338bc/networks/__pycache__/AFNet_main.cpython-37.pyc -------------------------------------------------------------------------------- /networks/__pycache__/AFNet_mobile.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junda24/AFNet/1d0416a2ec740519a8c464461962bebc238338bc/networks/__pycache__/AFNet_mobile.cpython-37.pyc -------------------------------------------------------------------------------- /networks/__pycache__/__init__.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junda24/AFNet/1d0416a2ec740519a8c464461962bebc238338bc/networks/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /networks/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junda24/AFNet/1d0416a2ec740519a8c464461962bebc238338bc/networks/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /networks/__pycache__/module.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junda24/AFNet/1d0416a2ec740519a8c464461962bebc238338bc/networks/__pycache__/module.cpython-37.pyc -------------------------------------------------------------------------------- /networks/__pycache__/module.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junda24/AFNet/1d0416a2ec740519a8c464461962bebc238338bc/networks/__pycache__/module.cpython-39.pyc -------------------------------------------------------------------------------- /networks/__pycache__/mvs2d.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Junda24/AFNet/1d0416a2ec740519a8c464461962bebc238338bc/networks/__pycache__/mvs2d.cpython-37.pyc -------------------------------------------------------------------------------- /networks/mobilenet.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.hub import load_state_dict_from_url 3 | 4 | model_urls = { 5 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', 6 | } 7 | 8 | 9 | def _make_divisible(v, divisor, min_value=None): 10 | """ 11 | This function is taken from the original tf repo. 12 | It ensures that all layers have a channel number that is divisible by 8 13 | It can be seen here: 14 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 15 | :param v: 16 | :param divisor: 17 | :param min_value: 18 | :return: 19 | """ 20 | if min_value is None: 21 | min_value = divisor 22 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 23 | # Make sure that round down does not go down by more than 10%. 
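# Worked examples of the rounding above: _make_divisible(32 * 0.75, 8) -> 24 and
# _make_divisible(17, 8) -> 16 round to the nearest multiple of 8, while
# _make_divisible(11, 8) -> 16 because rounding 11 down to 8 would drop it by more
# than 10%, so the check below adds one extra divisor.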
24 | if new_v < 0.9 * v: 25 | new_v += divisor 26 | return new_v 27 | 28 | class ConvBNReLU(nn.Sequential): 29 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, norm_layer=None): 30 | padding = (kernel_size - 1) // 2 31 | if norm_layer is None: 32 | norm_layer = nn.BatchNorm2d 33 | super(ConvBNReLU, self).__init__( 34 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 35 | norm_layer(out_planes), 36 | nn.ReLU6(inplace=True) 37 | ) 38 | 39 | 40 | class InvertedResidual(nn.Module): 41 | def __init__(self, inp, oup, stride, expand_ratio, norm_layer=None): 42 | super(InvertedResidual, self).__init__() 43 | self.stride = stride 44 | assert stride in [1, 2] 45 | 46 | if norm_layer is None: 47 | norm_layer = nn.BatchNorm2d 48 | 49 | hidden_dim = int(round(inp * expand_ratio)) 50 | self.use_res_connect = self.stride == 1 and inp == oup 51 | 52 | layers = [] 53 | if expand_ratio != 1: 54 | # pw 55 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer)) 56 | layers.extend([ 57 | # dw 58 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer), 59 | # pw-linear 60 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 61 | norm_layer(oup), 62 | ]) 63 | self.conv = nn.Sequential(*layers) 64 | 65 | def forward(self, x): 66 | if self.use_res_connect: 67 | return x + self.conv(x) 68 | else: 69 | return self.conv(x) 70 | 71 | 72 | class MobileNetV2(nn.Module): 73 | def __init__(self, 74 | num_classes=1000, 75 | width_mult=1.0, 76 | inverted_residual_setting=None, 77 | round_nearest=8, 78 | block=None, 79 | norm_layer=None): 80 | """ 81 | MobileNet V2 main class 82 | Args: 83 | num_classes (int): Number of classes 84 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 85 | inverted_residual_setting: Network structure 86 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 87 | Set to 1 to turn off rounding 88 | block: Module specifying inverted residual building block for mobilenet 89 | norm_layer: Module specifying the normalization layer to use 90 | """ 91 | super(MobileNetV2, self).__init__() 92 | 93 | if block is None: 94 | block = InvertedResidual 95 | 96 | if norm_layer is None: 97 | norm_layer = nn.BatchNorm2d 98 | 99 | input_channel = 32 100 | last_channel = 1280 101 | 102 | if inverted_residual_setting is None: 103 | inverted_residual_setting = [ 104 | # t, c, n, s 105 | [1, 16, 1, 1], 106 | [6, 24, 2, 2], 107 | [6, 32, 3, 2], 108 | [6, 64, 4, 2], 109 | [6, 96, 3, 1], 110 | [6, 160, 3, 2], 111 | [6, 320, 1, 1], 112 | ] 113 | 114 | # only check the first element, assuming user knows t,c,n,s are required 115 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 116 | raise ValueError("inverted_residual_setting should be non-empty " 117 | "or a 4-element list, got {}".format(inverted_residual_setting)) 118 | 119 | # building first layer 120 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 121 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 122 | features = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)] 123 | # building inverted residual blocks 124 | for t, c, n, s in inverted_residual_setting: 125 | output_channel = _make_divisible(c * width_mult, round_nearest) 126 | for i in range(n): 127 | stride = s if i == 0 else 1 128 | features.append(block(input_channel, 
output_channel, stride, expand_ratio=t, norm_layer=norm_layer)) 129 | input_channel = output_channel 130 | # building last several layers 131 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer)) 132 | # make it nn.Sequential 133 | self.features = nn.Sequential(*features) 134 | 135 | # building classifier 136 | self.classifier = nn.Sequential( 137 | nn.Dropout(0.2), 138 | nn.Linear(self.last_channel, num_classes), 139 | ) 140 | 141 | # weight initialization 142 | for m in self.modules(): 143 | if isinstance(m, nn.Conv2d): 144 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 145 | if m.bias is not None: 146 | nn.init.zeros_(m.bias) 147 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 148 | nn.init.ones_(m.weight) 149 | nn.init.zeros_(m.bias) 150 | elif isinstance(m, nn.Linear): 151 | nn.init.normal_(m.weight, 0, 0.01) 152 | nn.init.zeros_(m.bias) 153 | 154 | def _forward_impl(self, x): 155 | # This exists since TorchScript doesn't support inheritance, so the superclass method 156 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass 157 | x = self.features(x) 158 | # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0] 159 | # x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1) 160 | # x = self.classifier(x) 161 | return x 162 | 163 | def forward(self, x): 164 | return self._forward_impl(x) 165 | 166 | 167 | def mobilenet_v2(pretrained=False, progress=True, **kwargs): 168 | """ 169 | Constructs a MobileNetV2 architecture from 170 | `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_. 171 | Args: 172 | pretrained (bool): If True, returns a model pre-trained on ImageNet 173 | progress (bool): If True, displays a progress bar of the download to stderr 174 | """ 175 | model = MobileNetV2(**kwargs) 176 | if pretrained: 177 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], 178 | progress=progress) 179 | model.load_state_dict(state_dict) 180 | return model -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | import os 3 | import argparse 4 | from pyhocon import ConfigFactory 5 | 6 | 7 | class MVS2DOptions: 8 | def __init__(self): 9 | self.parser = argparse.ArgumentParser(description="MVS2D options") 10 | 11 | # PATH OPTIONS 12 | self.parser.add_argument("--log_dir", type=str, help="", default=None) 13 | self.parser.add_argument("--model_name", 14 | type=str, 15 | help="", 16 | default=None) 17 | self.parser.add_argument("--num_epochs", 18 | type=int, 19 | default=None, 20 | help="") 21 | self.parser.add_argument("--overwrite", 22 | type=int, 23 | default=None, 24 | help="") 25 | self.parser.add_argument('--note', type=str, help="", default=None) 26 | self.parser.add_argument("--DECAY_STEP_LIST", 27 | nargs="+", 28 | type=int, 29 | help="", 30 | default=None) 31 | # DATA OPTIONS 32 | self.parser.add_argument( 33 | "--mode", 34 | type=str, 35 | default=None, 36 | choices=["train", "test", "train+test", "full_test", "recon"]) 37 | self.parser.add_argument("--num_workers", 38 | type=int, 39 | help="", 40 | default=None) 41 | self.parser.add_argument("--use_test", type=int, default=None, help="") 42 | self.parser.add_argument("--robust", type=int, default=None, help="") 43 | self.parser.add_argument("--perturb_pose", 44 | type=int, 45 | default=None, 
46 | help="") 47 | self.parser.add_argument('--num_frame', 48 | type=int, 49 | help="", 50 | default=None) 51 | self.parser.add_argument('--fullsize_eval', 52 | type=int, 53 | help="", 54 | default=None) 55 | self.parser.add_argument('--filter', nargs="+", type=str, default=None) 56 | # MODEL OPTIONS 57 | self.parser.add_argument("--load_weights_folder", 58 | type=str, 59 | help="", 60 | default=None) 61 | 62 | # TRAINING OPTIONS 63 | 64 | # OPTIMIZATION OPTIONS 65 | 66 | # MULTI-GPU OPTIONS 67 | self.parser.add_argument("--world_size", type=int, default=1, help="") 68 | self.parser.add_argument("--multiprocessing_distributed", 69 | type=int, 70 | default=None, 71 | help="") 72 | self.parser.add_argument('--rank', type=int, help="", default=0) 73 | self.parser.add_argument('--gpu', type=int, help="", default=None) 74 | self.parser.add_argument('--local_rank', type=int, help="", default=0) 75 | self.parser.add_argument('--tcp_port', type=int, default=None, help="") 76 | 77 | # OTHERS 78 | self.parser.add_argument('--save_prediction', 79 | type=int, 80 | help="", 81 | default=None) 82 | self.parser.add_argument("--debug", help="", action="store_true") 83 | # self.parser.add_argument('--cfg', type=str, default="./configs/DDAD_kitti.conf") 84 | self.parser.add_argument('--cfg', type=str) 85 | 86 | self.parser.add_argument('--epoch_size', type=int, default=None) 87 | self.parser.add_argument('--val_epoch_size', type=int, default=None) 88 | 89 | def parse(self): 90 | self.options = self.parser.parse_args() 91 | cfg = ConfigFactory.parse_file(self.options.cfg) 92 | for k in cfg.keys(): 93 | if k not in self.options: 94 | setattr(self.options, k, cfg[k]) 95 | else: 96 | if getattr(self.options, k) is None: 97 | setattr(self.options, k, cfg[k]) 98 | 99 | return self.options 100 | 101 | 102 | class EvalCfg(object): 103 | def __init__(self, 104 | save_dir, 105 | min_depth=1e-3, 106 | max_depth=10.0, 107 | vis=False, 108 | disable_median_scaling=True, 109 | eigen_crop=True, 110 | print_per_dataset_stats=False, 111 | garg_crop=False): 112 | self.save_dir = save_dir 113 | self.vis = vis 114 | self.MIN_DEPTH = min_depth 115 | self.MAX_DEPTH = max_depth 116 | self.eigen_crop = eigen_crop 117 | self.garg_crop = garg_crop 118 | self.disable_median_scaling = disable_median_scaling 119 | self.print_per_dataset_stats = print_per_dataset_stats 120 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | imageio==2.9.0 2 | ipdb==0.12.3 3 | joblib==1.0.1 4 | lz4==3.1.3 5 | matplotlib==3.1.3 6 | numpy==1.20.3 7 | open3d==0.10.0.0 8 | opencv_contrib_python==3.4.2.17 9 | path.py==12.5.0 10 | Pillow==8.4.0 11 | pyhocon==0.3.58 12 | pykdtree==1.3.4 13 | requests==2.26.0 14 | scipy==1.7.1 15 | tensorboardX==2.4.1 16 | -------------------------------------------------------------------------------- /scripts/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | BASE_DIR='./' 3 | EXP_DIR="${BASE_DIR}/experiments/release/ScanNet/exp0" 4 | cfg=./configs/scannet/release.conf 5 | 6 | ## test 7 | CUDA_VISIBLE_DEVICES=0 python train.py --model_name=config0_test --mode=test --cfg $cfg --load_weights_folder=./pretrained_model/scannet/MVS2D --use_test=1 --fullsize_eval=1 8 | -------------------------------------------------------------------------------- /scripts/train.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # cfg=./configs/DDAD.conf 3 | cfg=./configs/kitti.conf 4 | 5 | 6 | CUDA_VISIBLE_DEVICES=0,1,2 OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=3 train_kitti.py --num_epochs=60 --DECAY_STEP_LIST 30 40 --cfg $cfg --load_weights_folder=/pretrained_model/kitti/ --fullsize_eval=1 --use_test=0 7 | -------------------------------------------------------------------------------- /train_af.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | from open3d import * 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | # from tensorboardX import SummaryWriter 7 | from torch.utils.tensorboard import SummaryWriter 8 | import json 9 | from utils import * 10 | import networks 11 | import os 12 | import glob 13 | import random 14 | import torch.optim as optim 15 | from options import MVS2DOptions, EvalCfg 16 | from trainer_base_af import BaseTrainer 17 | from hybrid_evaluate_depth import evaluate_depth_maps, compute_errors,compute_errors1,compute_errors_perimage 18 | from dtu_pyeval import dtu_pyeval 19 | import pprint 20 | import torch.distributed as dist 21 | 22 | 23 | class Trainer(BaseTrainer): 24 | def __init__(self, options): 25 | super(Trainer, self).__init__(options) 26 | 27 | def build_model(self): 28 | self.parameters_to_train = [] 29 | self.model = networks.MVS2D(opt=self.opt).cuda() 30 | self.parameters_to_train += list(self.model.parameters()) 31 | parameters_count(self.model, 'MVS2D') 32 | 33 | # def build_optimizer(self): 34 | # if self.opt.optimizer.lower() == 'adam': 35 | # self.model_optimizer = optim.Adam( 36 | # self.model.parameters(), 37 | # lr=self.opt.LR, 38 | # weight_decay=self.opt.WEIGHT_DECAY) 39 | # elif self.opt.optimizer.lower() == 'sgd': 40 | # self.model_optimizer = optim.SGD( 41 | # self.model.parameters(), 42 | # lr=self.opt.LR, 43 | # weight_decay=self.opt.WEIGHT_DECAY) 44 | 45 | # def val_epoch(self): 46 | # print("Validation") 47 | # writer = self.writers['val'] 48 | # self.set_eval() 49 | # results_depth = [] 50 | # val_loss = [] 51 | # config = EvalCfg( 52 | # eigen_crop=False, 53 | # garg_crop=False, 54 | # min_depth=self.opt.EVAL_MIN_DEPTH, 55 | # max_depth=self.opt.EVAL_MAX_DEPTH, 56 | # vis=self.epoch % 10 == 0 and self.opt.eval_vis, 57 | # disable_median_scaling=self.opt.disable_median_scaling, 58 | # print_per_dataset_stats=self.opt.dataset == 'DeMoN', 59 | # save_dir=os.path.join(self.log_path, 'eval_%03d' % self.epoch)) 60 | # if not os.path.exists(config.save_dir): 61 | # os.makedirs(config.save_dir) 62 | # print('evaluation results save to folder %s' % config.save_dir) 63 | # times = [] 64 | # val_stats = defaultdict(list) 65 | # dict_pred = {} 66 | # dict_pred_2 = {} 67 | # total_result_count = {} 68 | # total_result_count_2 = {} 69 | 70 | # with torch.no_grad(): 71 | # for batch_idx, inputs in enumerate(self.val_loader): 72 | # if self.opt.val_epoch_size != -1 and batch_idx >= self.opt.val_epoch_size: 73 | # break 74 | # if batch_idx % 100 == 0: 75 | # print(batch_idx, len(self.val_loader)) 76 | # # filenames = inputs["filenames"] 77 | # losses, outputs = self.process_batch(inputs, 'val') 78 | # # b = len(inputs["filenames"]) 79 | 80 | # s = 0 81 | # pred_depth = npy(outputs[('depth_pred', s)]) 82 | # pred_depth_2 = npy(outputs[('depth_pred_2', s)]) 83 | # depth_gt = npy(inputs[('depth_gt', 0, s)]) 84 | # 
mask = np.logical_and(depth_gt > config.MIN_DEPTH, 85 | # depth_gt < config.MAX_DEPTH) 86 | # interval = (935 - 425) / (128 - 1) # Interval value used by MVSNet 87 | # # dict_pred_temp = compute_errors(depth_gt[mask], pred_depth[mask], config.disable_median_scaling, config.MIN_DEPTH, config.MAX_DEPTH, interval) 88 | # dict_pred_temp = compute_errors(depth_gt, pred_depth, config.disable_median_scaling, config.MIN_DEPTH, config.MAX_DEPTH, interval) 89 | 90 | # dict_pred_temp_2 = compute_errors1(depth_gt[mask], pred_depth_2[mask], config.disable_median_scaling, config.MIN_DEPTH, config.MAX_DEPTH, interval) 91 | 92 | # for k, v in dict_pred_temp.items(): 93 | # # print(k,v) 94 | # if k in dict_pred: 95 | # dict_pred[k] = dict_pred[k] + v 96 | # # dict_pred['total_count'] = dict_pred['total_count'] + 1.0 97 | # else: 98 | # dict_pred[k] = v 99 | # # dict_pred['total_count'] = 1.0 100 | 101 | # for k, v in dict_pred_temp_2.items(): 102 | # k = k + '_2' 103 | # if k in dict_pred_2: 104 | # dict_pred_2[k] = dict_pred_2[k] + v 105 | # # dict_pred_2['total_count_2'] = dict_pred_2['total_count_2'] + 1.0 106 | # else: 107 | # dict_pred_2[k] = v 108 | # # dict_pred_2['total_count_2'] = 1.0 109 | # if batch_idx % 80 == 0: 110 | # writer.add_image('image0', inputs[("color", 0, 0)][0], global_step=self.step, walltime=None, dataformats='CHW') 111 | # writer.add_image('image1', inputs[("color", 1, 0)][0], global_step=self.step, walltime=None, dataformats='CHW') 112 | # writer.add_image('image2', inputs[("color", 2, 0)][0], global_step=self.step, walltime=None, dataformats='CHW') 113 | # depth_gt = gray_2_colormap_np(inputs[("depth_gt", 0, 0)][0][0]) 114 | # writer.add_image('depth_gt', depth_gt, global_step=self.step, walltime=None, dataformats='HWC') 115 | # depth_pred = gray_2_colormap_np(outputs[('depth_pred', 0)][0][0]) 116 | # writer.add_image('depth_pred', depth_pred, global_step=self.step, walltime=None, dataformats='HWC') 117 | 118 | # # print('total_count', dict_pred['total_count']) 119 | # # print('total_count_2', dict_pred_2['total_count_2']) 120 | # # print('abs_rel', dict_pred['abs_rel']) 121 | 122 | # for k in dict_pred.keys(): 123 | # total_result_count[k] = dict_pred['total_count'] 124 | 125 | # for k in dict_pred_2.keys(): 126 | # total_result_count_2[k] = dict_pred_2['total_count_2'] 127 | 128 | 129 | # for k in dict_pred.keys(): 130 | # this_tensor = torch.tensor([dict_pred[k], total_result_count[k]]).to(self.device) 131 | # this_list = [this_tensor] 132 | # torch.distributed.all_reduce_multigpu(this_list) 133 | # this_tensor = this_list[0].detach().cpu().numpy() 134 | # reduce_sum = this_tensor[0].item() 135 | # reduce_count = this_tensor[1].item() 136 | # reduce_mean = reduce_sum / reduce_count 137 | # if self.is_master: 138 | # writer.add_scalar(k, reduce_mean, self.step) 139 | 140 | # for k in dict_pred_2.keys(): 141 | # this_tensor = torch.tensor([dict_pred_2[k], total_result_count_2[k]]).to(self.device) 142 | # this_list = [this_tensor] 143 | # torch.distributed.all_reduce_multigpu(this_list) 144 | # this_tensor = this_list[0].detach().cpu().numpy() 145 | # reduce_sum = this_tensor[0].item() 146 | # reduce_count = this_tensor[1].item() 147 | # reduce_mean = reduce_sum / reduce_count 148 | # if self.is_master: 149 | # writer.add_scalar(k, reduce_mean, self.step) 150 | 151 | # self.set_train() 152 | 153 | def val_epoch(self): 154 | print("Validation") 155 | writer = self.writers['val'] 156 | self.set_eval() 157 | results_depth = [] 158 | val_loss = [] 159 | config = EvalCfg( 160 | 
eigen_crop=False, 161 | garg_crop=False, 162 | min_depth=self.opt.EVAL_MIN_DEPTH, 163 | max_depth=self.opt.EVAL_MAX_DEPTH, 164 | vis=self.epoch % 10 == 0 and self.opt.eval_vis, 165 | disable_median_scaling=self.opt.disable_median_scaling, 166 | print_per_dataset_stats=self.opt.dataset == 'DeMoN', 167 | save_dir=os.path.join(self.log_path, 'eval_%03d' % self.epoch)) 168 | if not os.path.exists(config.save_dir) and self.is_master: 169 | os.makedirs(config.save_dir) 170 | print('evaluation results save to folder %s' % config.save_dir) 171 | times = [] 172 | val_stats = defaultdict(list) 173 | total_result_sum = {} 174 | total_result_count = {} 175 | 176 | with torch.no_grad(): 177 | for batch_idx, inputs in enumerate(self.val_loader): 178 | if self.opt.val_epoch_size != -1 and batch_idx >= self.opt.val_epoch_size: 179 | break 180 | if batch_idx % 100 == 0: 181 | print(batch_idx, len(self.val_loader)) 182 | # filenames = inputs["filenames"] 183 | losses, outputs = self.process_batch(inputs, 'val') 184 | # b = len(inputs["filenames"]) 185 | 186 | s = 0 187 | pred_depth = npy(outputs[('depth_pred', s)]) 188 | pred_depth_2 = npy(outputs[('depth_pred_2', s)]) 189 | depth_gt = npy(inputs[('depth_gt', 0, s)]) 190 | mask = np.logical_and(depth_gt > config.MIN_DEPTH, 191 | depth_gt < config.MAX_DEPTH) 192 | error_temp = compute_errors_perimage(depth_gt[mask], pred_depth[mask], config.MIN_DEPTH, config.MAX_DEPTH) 193 | error_temp_2_ = compute_errors_perimage(depth_gt[mask], pred_depth_2[mask], config.MIN_DEPTH, config.MAX_DEPTH) 194 | 195 | error_temp_2 = {} 196 | for k,v in error_temp_2_.items(): 197 | new_k = k + '_2' 198 | error_temp_2[new_k] = error_temp_2_[k] 199 | 200 | error_temp_all = {} 201 | error_temp_all.update(error_temp) 202 | error_temp_all.update(error_temp_2) 203 | 204 | for k,v in error_temp_all.items(): 205 | if not isinstance(v,float): 206 | v=v.items() 207 | if k in total_result_sum: 208 | total_result_sum[k] = total_result_sum[k] + v 209 | else: 210 | total_result_sum[k] = v 211 | self.eval_step += 1 212 | 213 | if self.eval_step % 80 == 0 and self.is_master: 214 | writer.add_image('image0', inputs[("color", 0, 0)][0], global_step=self.eval_step, walltime=None, dataformats='CHW') 215 | writer.add_image('image1', inputs[("color", 1, 0)][0], global_step=self.eval_step, walltime=None, dataformats='CHW') 216 | writer.add_image('image2', inputs[("color", 2, 0)][0], global_step=self.eval_step, walltime=None, dataformats='CHW') 217 | depth_gt = gray_2_colormap_np(inputs[("depth_gt", 0, 0)][0][0]) 218 | writer.add_image('depth_gt', depth_gt, global_step=self.eval_step, walltime=None, dataformats='HWC') 219 | depth_pred = gray_2_colormap_np(outputs[('depth_pred', 0)][0][0]) 220 | writer.add_image('depth_pred', depth_pred, global_step=self.eval_step, walltime=None, dataformats='HWC') 221 | depth_pred_2 = gray_2_colormap_np(outputs[('depth_pred_2', 0)][0][0]) 222 | writer.add_image('depth_pred_2', depth_pred_2, global_step=self.eval_step, walltime=None, dataformats='HWC') 223 | 224 | # print('total_count', dict_pred['total_count']) 225 | # print('total_count_2', dict_pred_2['total_count_2']) 226 | # print('abs_rel', dict_pred['abs_rel']) 227 | 228 | for k in total_result_sum.keys(): 229 | total_result_count[k] = total_result_sum['valid_number'] 230 | 231 | 232 | for k in total_result_sum.keys(): 233 | this_tensor = torch.tensor([total_result_sum[k], total_result_count[k]]).to(self.device) 234 | this_list = [this_tensor] 235 | torch.distributed.all_reduce_multigpu(this_list) 236 | 
torch.distributed.barrier() 237 | this_tensor = this_list[0].detach().cpu().numpy() 238 | reduce_sum = this_tensor[0].item() 239 | reduce_count = this_tensor[1].item() 240 | reduce_mean = reduce_sum / reduce_count 241 | if self.is_master: 242 | writer.add_scalar(k, reduce_mean, self.eval_step) 243 | 244 | self.set_train() 245 | 246 | def process_batch(self, inputs, mode): 247 | self.to_gpu(inputs) 248 | 249 | imgs, proj_mats, pose_mats = [], [], [] 250 | for i in range(inputs['num_frame'][0].item()): 251 | imgs.append(inputs[('color', i, self.opt.input_scale)]) 252 | proj_mats.append(inputs[('proj', i)]) 253 | pose_mats.append(inputs[('pose', i)]) 254 | 255 | # outputs = self.model(imgs[0], imgs[1:], proj_mats[0], proj_mats[1:], 256 | # inputs[('inv_K_pool', 0)]) 257 | outputs = self.model(imgs[0], imgs[1:], pose_mats[0], pose_mats[1:], 258 | inputs[('inv_K_pool', 0)]) 259 | losses = self.compute_losses(inputs, outputs) 260 | return losses, outputs 261 | 262 | def compute_losses(self, inputs, outputs): 263 | losses, loss, s = {}, 0, 0 264 | depth_pred = outputs[('depth_pred', s)] 265 | depth_pred_2 = outputs[('depth_pred_2', s)] 266 | depth_gt = inputs[('depth_gt', 0, s)] 267 | 268 | 269 | valid_depth = (depth_gt > 0) 270 | 271 | # if self.opt.pred_conf: 272 | # log_conf_pred = outputs[('log_conf_pred', s)] 273 | # conf_pred = torch.exp(log_conf_pred) 274 | # min_conf = self.opt.min_conf 275 | # max_conf = self.opt.max_conf if self.opt.max_conf != -1 else None 276 | # conf_pred = conf_pred.clamp(min_conf, max_conf) 277 | # loss_depth = ((depth_pred - depth_gt).abs() / conf_pred + 278 | # log_conf_pred)[valid_depth].mean() 279 | # else: 280 | loss_depth = (depth_pred[valid_depth] - depth_gt[valid_depth]).abs().mean() 281 | 282 | loss_depth_2 = (depth_pred_2[valid_depth] - depth_gt[valid_depth]).abs().mean() 283 | 284 | losses["depth"] = loss_depth 285 | losses["depth_2"] = loss_depth_2 286 | 287 | loss += loss_depth + loss_depth_2 288 | losses["loss"] = loss 289 | 290 | return losses 291 | 292 | 293 | def run_fusion(dense_folder, out_folder, opts): 294 | cmd = f"python patchmatch_fusion.py \ 295 | --dense_folder {dense_folder} \ 296 | --outdir {out_folder} \ 297 | --n_proc 4 \ 298 | --conf_thres {opts.conf_thres} \ 299 | --att_thres {opts.att_thres} \ 300 | --use_conf_thres {opts.pred_conf} \ 301 | --geo_depth_thres {opts.geo_depth_thres} \ 302 | --geo_pixel_thres {opts.geo_pixel_thres} \ 303 | --num_consistent {opts.num_consistent} \ 304 | " 305 | 306 | os.system(cmd) 307 | 308 | 309 | if __name__ == "__main__": 310 | options = MVS2DOptions() 311 | opts = options.parse() 312 | 313 | set_random_seed(666) 314 | 315 | if torch.cuda.device_count() > 1 and not opts.multiprocessing_distributed: 316 | raise Exception( 317 | "Detected more than 1 GPU. 
Please set multiprocessing_distributed=1 or set CUDA_VISIBLE_DEVICES" 318 | ) 319 | 320 | opts.distributed = opts.world_size > 1 or opts.multiprocessing_distributed 321 | if opts.multiprocessing_distributed: 322 | # total_gpus, opts.rank = init_dist_pytorch(opts.tcp_port, 323 | # opts.local_rank, 324 | # backend='nccl') 325 | print('opts.local_rank', opts.local_rank) 326 | torch.cuda.set_device(opts.local_rank) 327 | dist.init_process_group("nccl", rank=opts.local_rank, world_size=3) 328 | opts.ngpus_per_node = 3 329 | opts.gpu = opts.local_rank 330 | print("Use GPU: {}/{} for training".format(opts.gpu, 331 | opts.ngpus_per_node)) 332 | else: 333 | opts.gpu = 0 334 | 335 | if opts.mode == 'train': 336 | trainer = Trainer(opts) 337 | trainer.train() 338 | 339 | elif opts.mode == 'test': 340 | trainer = Trainer(opts) 341 | trainer.val() 342 | 343 | elif opts.mode == 'full_test': 344 | ## save depth prediction 345 | opts.mode = 'test' 346 | trainer = Trainer(opts) 347 | trainer.val() 348 | 349 | ## fuse dense prediction into final point cloud 350 | dense_folder = f"{opts.log_dir}/{opts.model_name}/eval_000/prediction" 351 | out_folder = f"{opts.log_dir}/{opts.model_name}/recon" 352 | run_fusion(dense_folder, out_folder, opts) 353 | 354 | ## eval point cloud 355 | MeanData, MeanStl, MeanAvg = dtu_pyeval( 356 | f"{out_folder}", 357 | gt_dir='./data/SampleSet/MVS Data/', 358 | voxel_down_sample=False, 359 | fn=f"{out_folder}/result.txt") 360 | -------------------------------------------------------------------------------- /train_kitti.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | from open3d import * 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | # from tensorboardX import SummaryWriter 7 | from torch.utils.tensorboard import SummaryWriter 8 | import json 9 | from utils import * 10 | import networks 11 | import os 12 | import glob 13 | import random 14 | import torch.optim as optim 15 | from options import MVS2DOptions, EvalCfg 16 | from trainer_base_kitti import BaseTrainer 17 | from hybrid_evaluate_depth import evaluate_depth_maps, compute_errors,compute_errors1,compute_errors_perimage 18 | from dtu_pyeval import dtu_pyeval 19 | import pprint 20 | import torch.distributed as dist 21 | 22 | 23 | class Trainer(BaseTrainer): 24 | def __init__(self, options): 25 | super(Trainer, self).__init__(options) 26 | 27 | def build_model(self): 28 | self.parameters_to_train = [] 29 | self.model = networks.MVS2D(opt=self.opt).cuda() 30 | self.parameters_to_train += list(self.model.parameters()) 31 | parameters_count(self.model, 'MVS2D') 32 | 33 | # def build_optimizer(self): 34 | # if self.opt.optimizer.lower() == 'adam': 35 | # self.model_optimizer = optim.Adam( 36 | # self.model.parameters(), 37 | # lr=self.opt.LR, 38 | # weight_decay=self.opt.WEIGHT_DECAY) 39 | # elif self.opt.optimizer.lower() == 'sgd': 40 | # self.model_optimizer = optim.SGD( 41 | # self.model.parameters(), 42 | # lr=self.opt.LR, 43 | # weight_decay=self.opt.WEIGHT_DECAY) 44 | 45 | # def val_epoch(self): 46 | # print("Validation") 47 | # writer = self.writers['val'] 48 | # self.set_eval() 49 | # results_depth = [] 50 | # val_loss = [] 51 | # config = EvalCfg( 52 | # eigen_crop=False, 53 | # garg_crop=False, 54 | # min_depth=self.opt.EVAL_MIN_DEPTH, 55 | # max_depth=self.opt.EVAL_MAX_DEPTH, 56 | # vis=self.epoch % 10 == 0 and self.opt.eval_vis, 57 | # 
disable_median_scaling=self.opt.disable_median_scaling, 58 | # print_per_dataset_stats=self.opt.dataset == 'DeMoN', 59 | # save_dir=os.path.join(self.log_path, 'eval_%03d' % self.epoch)) 60 | # if not os.path.exists(config.save_dir): 61 | # os.makedirs(config.save_dir) 62 | # print('evaluation results save to folder %s' % config.save_dir) 63 | # times = [] 64 | # val_stats = defaultdict(list) 65 | # dict_pred = {} 66 | # dict_pred_2 = {} 67 | # total_result_count = {} 68 | # total_result_count_2 = {} 69 | 70 | # with torch.no_grad(): 71 | # for batch_idx, inputs in enumerate(self.val_loader): 72 | # if self.opt.val_epoch_size != -1 and batch_idx >= self.opt.val_epoch_size: 73 | # break 74 | # if batch_idx % 100 == 0: 75 | # print(batch_idx, len(self.val_loader)) 76 | # # filenames = inputs["filenames"] 77 | # losses, outputs = self.process_batch(inputs, 'val') 78 | # # b = len(inputs["filenames"]) 79 | 80 | # s = 0 81 | # pred_depth = npy(outputs[('depth_pred', s)]) 82 | # pred_depth_2 = npy(outputs[('depth_pred_2', s)]) 83 | # depth_gt = npy(inputs[('depth_gt', 0, s)]) 84 | # mask = np.logical_and(depth_gt > config.MIN_DEPTH, 85 | # depth_gt < config.MAX_DEPTH) 86 | # interval = (935 - 425) / (128 - 1) # Interval value used by MVSNet 87 | # # dict_pred_temp = compute_errors(depth_gt[mask], pred_depth[mask], config.disable_median_scaling, config.MIN_DEPTH, config.MAX_DEPTH, interval) 88 | # dict_pred_temp = compute_errors(depth_gt, pred_depth, config.disable_median_scaling, config.MIN_DEPTH, config.MAX_DEPTH, interval) 89 | 90 | # dict_pred_temp_2 = compute_errors1(depth_gt[mask], pred_depth_2[mask], config.disable_median_scaling, config.MIN_DEPTH, config.MAX_DEPTH, interval) 91 | 92 | # for k, v in dict_pred_temp.items(): 93 | # # print(k,v) 94 | # if k in dict_pred: 95 | # dict_pred[k] = dict_pred[k] + v 96 | # # dict_pred['total_count'] = dict_pred['total_count'] + 1.0 97 | # else: 98 | # dict_pred[k] = v 99 | # # dict_pred['total_count'] = 1.0 100 | 101 | # for k, v in dict_pred_temp_2.items(): 102 | # k = k + '_2' 103 | # if k in dict_pred_2: 104 | # dict_pred_2[k] = dict_pred_2[k] + v 105 | # # dict_pred_2['total_count_2'] = dict_pred_2['total_count_2'] + 1.0 106 | # else: 107 | # dict_pred_2[k] = v 108 | # # dict_pred_2['total_count_2'] = 1.0 109 | # if batch_idx % 80 == 0: 110 | # writer.add_image('image0', inputs[("color", 0, 0)][0], global_step=self.step, walltime=None, dataformats='CHW') 111 | # writer.add_image('image1', inputs[("color", 1, 0)][0], global_step=self.step, walltime=None, dataformats='CHW') 112 | # writer.add_image('image2', inputs[("color", 2, 0)][0], global_step=self.step, walltime=None, dataformats='CHW') 113 | # depth_gt = gray_2_colormap_np(inputs[("depth_gt", 0, 0)][0][0]) 114 | # writer.add_image('depth_gt', depth_gt, global_step=self.step, walltime=None, dataformats='HWC') 115 | # depth_pred = gray_2_colormap_np(outputs[('depth_pred', 0)][0][0]) 116 | # writer.add_image('depth_pred', depth_pred, global_step=self.step, walltime=None, dataformats='HWC') 117 | 118 | # # print('total_count', dict_pred['total_count']) 119 | # # print('total_count_2', dict_pred_2['total_count_2']) 120 | # # print('abs_rel', dict_pred['abs_rel']) 121 | 122 | # for k in dict_pred.keys(): 123 | # total_result_count[k] = dict_pred['total_count'] 124 | 125 | # for k in dict_pred_2.keys(): 126 | # total_result_count_2[k] = dict_pred_2['total_count_2'] 127 | 128 | 129 | # for k in dict_pred.keys(): 130 | # this_tensor = torch.tensor([dict_pred[k], 
total_result_count[k]]).to(self.device) 131 | # this_list = [this_tensor] 132 | # torch.distributed.all_reduce_multigpu(this_list) 133 | # this_tensor = this_list[0].detach().cpu().numpy() 134 | # reduce_sum = this_tensor[0].item() 135 | # reduce_count = this_tensor[1].item() 136 | # reduce_mean = reduce_sum / reduce_count 137 | # if self.is_master: 138 | # writer.add_scalar(k, reduce_mean, self.step) 139 | 140 | # for k in dict_pred_2.keys(): 141 | # this_tensor = torch.tensor([dict_pred_2[k], total_result_count_2[k]]).to(self.device) 142 | # this_list = [this_tensor] 143 | # torch.distributed.all_reduce_multigpu(this_list) 144 | # this_tensor = this_list[0].detach().cpu().numpy() 145 | # reduce_sum = this_tensor[0].item() 146 | # reduce_count = this_tensor[1].item() 147 | # reduce_mean = reduce_sum / reduce_count 148 | # if self.is_master: 149 | # writer.add_scalar(k, reduce_mean, self.step) 150 | 151 | # self.set_train() 152 | 153 | def val_epoch(self): 154 | print("Validation") 155 | writer = self.writers['val'] 156 | self.set_eval() 157 | results_depth = [] 158 | val_loss = [] 159 | config = EvalCfg( 160 | eigen_crop=False, 161 | garg_crop=False, 162 | min_depth=self.opt.EVAL_MIN_DEPTH, 163 | max_depth=self.opt.EVAL_MAX_DEPTH, 164 | vis=self.epoch % 10 == 0 and self.opt.eval_vis, 165 | disable_median_scaling=self.opt.disable_median_scaling, 166 | print_per_dataset_stats=self.opt.dataset == 'DeMoN', 167 | save_dir=os.path.join(self.log_path, 'eval_%03d' % self.epoch)) 168 | if not os.path.exists(config.save_dir) and self.is_master: 169 | os.makedirs(config.save_dir) 170 | print('evaluation results save to folder %s' % config.save_dir) 171 | times = [] 172 | val_stats = defaultdict(list) 173 | total_result_sum = {} 174 | total_result_count = {} 175 | 176 | with torch.no_grad(): 177 | for batch_idx, inputs in enumerate(self.val_loader): 178 | if self.opt.val_epoch_size != -1 and batch_idx >= self.opt.val_epoch_size: 179 | break 180 | if batch_idx % 100 == 0: 181 | print(batch_idx, len(self.val_loader)) 182 | # filenames = inputs["filenames"] 183 | losses, outputs = self.process_batch(inputs, 'val') 184 | # b = len(inputs["filenames"]) 185 | 186 | s = 0 187 | pred_depth = npy(outputs[('depth_pred', s)]) 188 | pred_depth_2 = npy(outputs[('depth_pred_2', s)]) 189 | depth_gt = npy(inputs[('depth_gt', 0, s)]) 190 | mask = np.logical_and(depth_gt > config.MIN_DEPTH, 191 | depth_gt < config.MAX_DEPTH) 192 | error_temp = compute_errors_perimage(depth_gt[mask], pred_depth[mask], config.MIN_DEPTH, config.MAX_DEPTH) 193 | error_temp_2_ = compute_errors_perimage(depth_gt[mask], pred_depth_2[mask], config.MIN_DEPTH, config.MAX_DEPTH) 194 | 195 | error_temp_2 = {} 196 | for k,v in error_temp_2_.items(): 197 | new_k = k + '_2' 198 | error_temp_2[new_k] = error_temp_2_[k] 199 | 200 | error_temp_all = {} 201 | error_temp_all.update(error_temp) 202 | error_temp_all.update(error_temp_2) 203 | 204 | for k,v in error_temp_all.items(): 205 | if not isinstance(v,float): 206 | v=v.items() 207 | if k in total_result_sum: 208 | total_result_sum[k] = total_result_sum[k] + v 209 | else: 210 | total_result_sum[k] = v 211 | self.eval_step += 1 212 | 213 | if self.eval_step % 80 == 0 and self.is_master: 214 | writer.add_image('image0', inputs[("img_ori", 0, 0)][0], global_step=self.eval_step, walltime=None, dataformats='CHW') 215 | writer.add_image('image1', inputs[("img_ori", 1, 0)][0], global_step=self.eval_step, walltime=None, dataformats='CHW') 216 | writer.add_image('image2', inputs[("img_ori", 2, 0)][0], 
global_step=self.eval_step, walltime=None, dataformats='CHW') 217 | depth_gt = gray_2_colormap_np(inputs[("depth_gt", 0, 0)][0][0]) 218 | writer.add_image('depth_gt', depth_gt, global_step=self.eval_step, walltime=None, dataformats='HWC') 219 | depth_pred = gray_2_colormap_np(outputs[('depth_pred', 0)][0][0]) 220 | writer.add_image('depth_pred', depth_pred, global_step=self.eval_step, walltime=None, dataformats='HWC') 221 | depth_pred_2 = gray_2_colormap_np(outputs[('depth_pred_2', 0)][0][0]) 222 | writer.add_image('depth_pred_2', depth_pred_2, global_step=self.eval_step, walltime=None, dataformats='HWC') 223 | 224 | # print('total_count', dict_pred['total_count']) 225 | # print('total_count_2', dict_pred_2['total_count_2']) 226 | # print('abs_rel', dict_pred['abs_rel']) 227 | 228 | for k in total_result_sum.keys(): 229 | total_result_count[k] = total_result_sum['valid_number'] 230 | 231 | 232 | for k in total_result_sum.keys(): 233 | this_tensor = torch.tensor([total_result_sum[k], total_result_count[k]]).to(self.device) 234 | this_list = [this_tensor] 235 | torch.distributed.all_reduce_multigpu(this_list) 236 | torch.distributed.barrier() 237 | this_tensor = this_list[0].detach().cpu().numpy() 238 | reduce_sum = this_tensor[0].item() 239 | reduce_count = this_tensor[1].item() 240 | reduce_mean = reduce_sum / reduce_count 241 | if self.is_master: 242 | writer.add_scalar(k, reduce_mean, self.eval_step) 243 | 244 | self.set_train() 245 | 246 | def process_batch(self, inputs, mode): 247 | self.to_gpu(inputs) 248 | 249 | imgs, proj_mats, pose_mats = [], [], [] 250 | for i in range(inputs['num_frame'][0].item()): 251 | imgs.append(inputs[('color', i, self.opt.input_scale)]) 252 | proj_mats.append(inputs[('proj', i)]) 253 | pose_mats.append(inputs[('pose', i)]) 254 | 255 | # outputs = self.model(imgs[0], imgs[1:], proj_mats[0], proj_mats[1:], 256 | # inputs[('inv_K_pool', 0)]) 257 | outputs = self.model(imgs[0], imgs[1:], pose_mats[0], pose_mats[1:], 258 | inputs[('inv_K_pool', 0)]) 259 | losses = self.compute_losses(inputs, outputs) 260 | return losses, outputs 261 | 262 | def compute_losses(self, inputs, outputs): 263 | losses, loss, s = {}, 0, 0 264 | depth_pred = outputs[('depth_pred', s)] 265 | depth_pred_2 = outputs[('depth_pred_2', s)] 266 | depth_gt = inputs[('depth_gt', 0, s)] 267 | 268 | 269 | valid_depth = (depth_gt > 0) 270 | 271 | # if self.opt.pred_conf: 272 | # log_conf_pred = outputs[('log_conf_pred', s)] 273 | # conf_pred = torch.exp(log_conf_pred) 274 | # min_conf = self.opt.min_conf 275 | # max_conf = self.opt.max_conf if self.opt.max_conf != -1 else None 276 | # conf_pred = conf_pred.clamp(min_conf, max_conf) 277 | # loss_depth = ((depth_pred - depth_gt).abs() / conf_pred + 278 | # log_conf_pred)[valid_depth].mean() 279 | # else: 280 | loss_depth = (depth_pred[valid_depth] - depth_gt[valid_depth]).abs().mean() 281 | 282 | loss_depth_2 = (depth_pred_2[valid_depth] - depth_gt[valid_depth]).abs().mean() 283 | 284 | losses["depth"] = loss_depth 285 | losses["depth_2"] = loss_depth_2 286 | 287 | loss += loss_depth + loss_depth_2 288 | losses["loss"] = loss 289 | 290 | return losses 291 | 292 | 293 | def run_fusion(dense_folder, out_folder, opts): 294 | cmd = f"python patchmatch_fusion.py \ 295 | --dense_folder {dense_folder} \ 296 | --outdir {out_folder} \ 297 | --n_proc 4 \ 298 | --conf_thres {opts.conf_thres} \ 299 | --att_thres {opts.att_thres} \ 300 | --use_conf_thres {opts.pred_conf} \ 301 | --geo_depth_thres {opts.geo_depth_thres} \ 302 | --geo_pixel_thres 
{opts.geo_pixel_thres} \ 303 | --num_consistent {opts.num_consistent} \ 304 | " 305 | 306 | os.system(cmd) 307 | 308 | 309 | if __name__ == "__main__": 310 | options = MVS2DOptions() 311 | opts = options.parse() 312 | 313 | set_random_seed(666) 314 | 315 | if torch.cuda.device_count() > 1 and not opts.multiprocessing_distributed: 316 | raise Exception( 317 | "Detected more than 1 GPU. Please set multiprocessing_distributed=1 or set CUDA_VISIBLE_DEVICES" 318 | ) 319 | 320 | opts.distributed = opts.world_size > 1 or opts.multiprocessing_distributed 321 | if opts.multiprocessing_distributed: 322 | # total_gpus, opts.rank = init_dist_pytorch(opts.tcp_port, 323 | # opts.local_rank, 324 | # backend='nccl') 325 | print('opts.local_rank', opts.local_rank) 326 | torch.cuda.set_device(opts.local_rank) 327 | dist.init_process_group("nccl", rank=opts.local_rank, world_size=3) 328 | opts.ngpus_per_node = 3 329 | opts.gpu = opts.local_rank 330 | print("Use GPU: {}/{} for training".format(opts.gpu, 331 | opts.ngpus_per_node)) 332 | else: 333 | opts.gpu = 0 334 | 335 | if opts.mode == 'train': 336 | trainer = Trainer(opts) 337 | trainer.train() 338 | 339 | elif opts.mode == 'test': 340 | trainer = Trainer(opts) 341 | trainer.val() 342 | 343 | elif opts.mode == 'full_test': 344 | ## save depth prediction 345 | opts.mode = 'test' 346 | trainer = Trainer(opts) 347 | trainer.val() 348 | 349 | ## fuse dense prediction into final point cloud 350 | dense_folder = f"{opts.log_dir}/{opts.model_name}/eval_000/prediction" 351 | out_folder = f"{opts.log_dir}/{opts.model_name}/recon" 352 | run_fusion(dense_folder, out_folder, opts) 353 | 354 | ## eval point cloud 355 | MeanData, MeanStl, MeanAvg = dtu_pyeval( 356 | f"{out_folder}", 357 | gt_dir='./data/SampleSet/MVS Data/', 358 | voxel_down_sample=False, 359 | fn=f"{out_folder}/result.txt") 360 | -------------------------------------------------------------------------------- /trainer_base_af.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | from open3d import * 3 | import numpy as np 4 | import time 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | from torch.utils.data import DataLoader 10 | # from tensorboardX import SummaryWriter 11 | from torch.utils.tensorboard import SummaryWriter 12 | import json 13 | from utils import * 14 | import os 15 | import stat 16 | import glob 17 | import shutil 18 | from torch.autograd import Variable 19 | import torch.optim.lr_scheduler as lr_sched 20 | from options import MVS2DOptions 21 | import torch.backends.cudnn as cudnn 22 | 23 | cudnn.benchmark = True 24 | 25 | g = torch.Generator() 26 | g.manual_seed(0) 27 | 28 | 29 | def worker_init_fn(worker_id): 30 | seed = np.random.get_state()[1][0] + worker_id 31 | np.random.seed(seed) 32 | import random 33 | random.seed(seed) 34 | 35 | 36 | def file_remove_readonly(func, path, execinfo): 37 | os.chmod(path, stat.S_IWUSR)#修改文件权限 38 | func(path) 39 | 40 | 41 | class BaseTrainer(object): 42 | def __init__(self, options): 43 | 44 | self.is_best = {} 45 | self.epoch = 0 46 | self.step = 0 47 | self.eval_step = 0 48 | self.opt = options 49 | self.is_master = self.opt.gpu == 0 50 | self.opt.is_master = self.opt.gpu == 0 51 | self.device = self.opt.gpu 52 | 53 | 54 | # base_dir = '.' 
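# Layout of the experiment folder that BaseTrainer creates for each run:
#   <log_dir>/<model_name>/log.txt                 - plain-text training log written by log_string()
#   <log_dir>/<model_name>/train/ and val/         - TensorBoard event files used by self.writers
#   <log_dir>/<model_name>/models/opt.json         - snapshot of all options written by save_opts()
#   <log_dir>/<model_name>/models/weights_latest/  - model_<epoch>.pth and adam.pth from save_model()
#   <log_dir>/<model_name>/eval_%03d/              - per-epoch validation outputs from val_epoch()
# When opt.overwrite is set, the master process deletes any existing folder before writing.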
55 | # self.opt.log_dir = os.path.join(base_dir, self.opt.log_dir) 56 | self.log_path = os.path.join(self.opt.log_dir, self.opt.model_name) 57 | if self.opt.is_master: 58 | if os.path.exists(self.log_path) and self.opt.overwrite: 59 | try: 60 | shutil.rmtree(self.log_path) 61 | # shutil.rmtree(self.log_path, onerror=file_remove_readonly) 62 | except: 63 | print('overwrite folder failed') 64 | self.log_file = os.path.join(self.log_path, "log.txt") 65 | 66 | self.writers = {} 67 | for mode in ["train", "val"]: 68 | self.writers[mode] = SummaryWriter( 69 | os.path.join(self.log_path, mode)) 70 | if self.opt.is_master: 71 | with open(self.log_file, 'w') as f: 72 | f.write(self.opt.note + '\n') 73 | 74 | self.save_opts() 75 | 76 | self.build_dataset() 77 | 78 | self.build_model() 79 | 80 | # self.build_optimizer() 81 | self.fetch_optimizer() 82 | 83 | if self.opt.load_weights_folder is not None: 84 | self.load_model() 85 | 86 | if self.opt.distributed: 87 | if self.opt.gpu is not None: 88 | print( 89 | f"batch size on GPU: {self.opt.gpu}: {self.opt.batch_size}" 90 | ) 91 | 92 | self.model = torch.nn.parallel.DistributedDataParallel( 93 | self.model, 94 | device_ids=[self.opt.gpu], 95 | find_unused_parameters=True) 96 | else: 97 | model = torch.nn.parallel.DistributedDataParallel( 98 | self.model, find_unused_parameters=True) 99 | 100 | # self.build_scheduler() 101 | 102 | self.total_data_time = 0 103 | self.total_op_time = 0 104 | if self.opt.epoch_size == -1: 105 | self.opt.epoch_size = len(self.train_loader) 106 | 107 | if self.opt.is_master: 108 | print("Training model named:\n ", self.opt.model_name) 109 | print("Models and tensorboard events files are saved to:\n ", 110 | self.opt.log_dir) 111 | 112 | self.num_total_steps = len(self.train_loader) * self.opt.num_epochs 113 | print("There are {:d} training items and {:d} validation items\n". 
114 | format( 115 | len(self.train_loader) * self.opt.batch_size, 116 | len(self.val_loader) * 1)) 117 | 118 | # def build_optimizer(self): 119 | # optimizer = optim.Adam(self.model.parameters(), 120 | # lr=self.opt.LR, 121 | # weight_decay=self.opt.WEIGHT_DECAY) 122 | 123 | # self.model_optimizer = optimizer 124 | 125 | 126 | # def build_scheduler(self): 127 | # total_iters_each_epoch = len(self.train_loader) 128 | # decay_steps = [ 129 | # x * total_iters_each_epoch for x in self.opt.DECAY_STEP_LIST 130 | # ] 131 | # total_steps = total_iters_each_epoch * self.opt.num_epochs 132 | 133 | # def lr_lbmd(cur_epoch): 134 | # cur_decay = 1 135 | # for decay_step in decay_steps: 136 | # if cur_epoch >= decay_step: 137 | # cur_decay = cur_decay * self.opt.LR_DECAY 138 | # return max(cur_decay, self.opt.LR_CLIP / self.opt.LR) 139 | 140 | # self.model_lr_scheduler = lr_sched.LambdaLR(self.model_optimizer, 141 | # lr_lbmd, 142 | # last_epoch=-1) 143 | 144 | def fetch_optimizer(self): 145 | """ Create the optimizer and learning rate scheduler """ 146 | total_iters_each_epoch = len(self.train_loader) 147 | total_steps = total_iters_each_epoch * self.opt.num_epochs 148 | optimizer = optim.AdamW(self.model.parameters(), lr=self.opt.LR, weight_decay=.00001, eps=1e-8) 149 | 150 | scheduler = optim.lr_scheduler.OneCycleLR(optimizer, self.opt.LR, total_steps+100, 151 | pct_start=0.01, cycle_momentum=False, anneal_strategy='linear') 152 | self.model_lr_scheduler = scheduler 153 | self.model_optimizer = optimizer 154 | 155 | 156 | def to_gpu(self, inputs, keys=None): 157 | if keys == None: 158 | keys = inputs.keys() 159 | for key in keys: 160 | if key not in inputs: 161 | continue 162 | ipt = inputs[key] 163 | if type(ipt) == torch.Tensor: 164 | inputs[key] = ipt.cuda(self.opt.gpu, non_blocking=True) 165 | elif type(ipt) == list and type(ipt[0]) == torch.Tensor: 166 | inputs[key] = [ 167 | x.cuda(self.opt.gpu, non_blocking=True) for x in ipt 168 | ] 169 | elif type(ipt) == dict: 170 | for k in ipt.keys(): 171 | if type(ipt[k]) == torch.Tensor: 172 | ipt[k] = ipt[k].cuda(self.opt.gpu, non_blocking=True) 173 | 174 | def build_dataset(self): 175 | # if self.opt.dataset == 'ScanNet': 176 | # from datasets.ScanNet import ScanNet as Dataset 177 | # elif self.opt.dataset == 'DeMoN': 178 | # from datasets.DeMoN import DeMoN as Dataset 179 | # elif self.opt.dataset == 'DTU': 180 | # from datasets.DTU import DTU as Dataset 181 | # else: 182 | # raise Exception("Unknown Dataset") 183 | from datasets.DDAD import DDAD 184 | 185 | train_dataset = DDAD(self.opt, True) 186 | if self.opt.distributed: 187 | self.train_sampler = torch.utils.data.distributed.DistributedSampler( 188 | train_dataset) 189 | else: 190 | self.train_sampler = None 191 | self.train_loader = DataLoader(train_dataset, 192 | self.opt.batch_size, 193 | shuffle=(self.train_sampler is None), 194 | num_workers=self.opt.num_workers, 195 | pin_memory=True, 196 | worker_init_fn=worker_init_fn, 197 | drop_last=True, 198 | sampler=self.train_sampler) 199 | 200 | val_dataset = DDAD(self.opt, False) 201 | if self.opt.distributed: 202 | self.val_sampler = torch.utils.data.distributed.DistributedSampler( 203 | val_dataset) 204 | else: 205 | self.val_sampler = None 206 | self.val_loader = DataLoader(val_dataset, 207 | self.opt.batch_size, 208 | shuffle=False, 209 | num_workers=self.opt.num_workers, 210 | pin_memory=True, 211 | worker_init_fn=worker_init_fn, 212 | drop_last=False, 213 | sampler=self.val_sampler) 214 | 215 | # val_dataset = DDAD(self.opt, False) 216 | # 
self.val_sampler = None 217 | # self.val_loader = DataLoader(val_dataset, 218 | # 1, 219 | # shuffle=False, 220 | # num_workers=self.opt.num_workers, 221 | # pin_memory=True, 222 | # drop_last=False, 223 | # sampler=self.val_sampler) 224 | 225 | def log_time(self, batch_idx, op_time, step_time, loss): 226 | """Print a logging statement to the terminal 227 | """ 228 | if self.opt.distributed: 229 | ops_per_sec = self.opt.ngpus_per_node * self.opt.batch_size / op_time 230 | steps_per_sec = self.opt.ngpus_per_node * self.opt.batch_size / step_time 231 | else: 232 | ops_per_sec = self.opt.batch_size / op_time 233 | steps_per_sec = self.opt.batch_size / step_time 234 | time_sofar = time.time() - self.start_time 235 | 236 | training_time_left = (self.num_total_steps / self.step - 237 | 1.0) * time_sofar if self.step > 0 else 0 238 | print_string = "epoch {:>3} | batch {:>6}/{:>6} | ops/s: {:5.1f} | steps/s: {:5.1f} | t_data/t_op: {:5.1f} " + \ 239 | " | loss: {:.5f} | time elapsed: {} | time left: {} | lr: {:.7f}" 240 | self.log_string( 241 | print_string.format(self.epoch, batch_idx, len(self.train_loader), 242 | ops_per_sec, steps_per_sec, 243 | self.total_data_time / self.total_op_time, 244 | loss, sec_to_hm_str(time_sofar), 245 | sec_to_hm_str(training_time_left), 246 | self.model_optimizer.param_groups[0]['lr'])) 247 | 248 | def train_epoch(self): 249 | if self.opt.is_master: 250 | print("Training") 251 | self.writers['train'].add_scalar( 252 | "lr", self.model_optimizer.param_groups[0]['lr'], self.step) 253 | self.set_train() 254 | before_data_loader_time = time.time() 255 | time_last_step = time.time() 256 | 257 | if self.opt.epoch_size == 0: 258 | return 259 | 260 | for batch_idx, inputs in enumerate(self.train_loader): 261 | if batch_idx >= self.opt.epoch_size: 262 | break 263 | after_data_loader_time = time.time() 264 | duration_data = after_data_loader_time - before_data_loader_time 265 | self.total_data_time += duration_data 266 | before_op_time = time.time() 267 | 268 | self.model_lr_scheduler.step(self.step) 269 | 270 | if self.opt.is_master: 271 | try: 272 | cur_lr = float(self.model_optimizer.lr) 273 | except: 274 | cur_lr = self.model_optimizer.param_groups[0]['lr'] 275 | 276 | self.writers['train'].add_scalar('meta_data/learning_rate', 277 | cur_lr, self.step) 278 | 279 | self.model_optimizer.zero_grad() 280 | losses, outputs = self.process_batch(inputs, 'train') 281 | losses['loss'].backward() 282 | 283 | torch.nn.utils.clip_grad_norm_(self.parameters_to_train, 284 | self.opt.GRAD_NORM_CLIP) 285 | 286 | contain_nan = False 287 | for weight in self.parameters_to_train: 288 | if weight.grad is not None: 289 | if torch.any(torch.isnan(weight.grad)): 290 | print('skip parameters update because of nan in grad') 291 | contain_nan = True 292 | if not contain_nan: 293 | self.model_optimizer.step() 294 | 295 | duration = time.time() - before_op_time 296 | self.total_op_time += duration 297 | 298 | if self.opt.is_master and batch_idx % self.opt.log_frequency == 0: 299 | duration_step = time.time() - time_last_step 300 | self.log_time(batch_idx, duration, duration_step, 301 | losses["loss"].cpu().data) 302 | self.log("train", inputs, losses, batch_idx, outputs) 303 | self.step += 1 304 | before_data_loader_time = time.time() 305 | time_last_step = time.time() 306 | 307 | def update_monitor_key(self, metrics, keys, goals): 308 | if len(keys): 309 | if type(keys) != list: 310 | keys = [keys] 311 | for key, goal in zip(keys, goals): 312 | val = metrics[key] 313 | if not hasattr(self, key): 
314 | setattr(self, key, val) 315 | self.is_best[key] = True 316 | else: 317 | if goal == 'minimize': 318 | if val < getattr(self, key): 319 | self.is_best[key] = True 320 | setattr(self, key, val) 321 | else: 322 | self.is_best[key] = False 323 | elif goal == 'maximize': 324 | if val > getattr(self, key): 325 | self.is_best[key] = True 326 | setattr(self, key, val) 327 | else: 328 | self.is_best[key] = False 329 | 330 | def set_train(self): 331 | self.model.train() 332 | 333 | def set_eval(self): 334 | self.model.eval() 335 | 336 | def train(self): 337 | self.start_time = time.time() 338 | if self.opt.is_master: 339 | print("Total epoch: %d " % self.opt.num_epochs) 340 | print("train loader size: %d " % len(self.train_loader)) 341 | print("val loader size: %d " % len(self.val_loader)) 342 | print("log_frequency: %d " % self.opt.log_frequency) 343 | for self.epoch in range(self.opt.num_epochs): 344 | if self.opt.distributed: 345 | self.train_sampler.set_epoch(self.epoch) 346 | self.train_epoch() 347 | # if self.opt.is_master: 348 | self.val_epoch() 349 | torch.distributed.barrier() 350 | torch.cuda.empty_cache() 351 | if self.opt.is_master: 352 | self.save_model(monitor_key=self.opt.monitor_key) 353 | 354 | def val(self): 355 | self.val_epoch() 356 | 357 | def process_batch(self, inputs, mode): 358 | raise Exception("Need to implement process_batch") 359 | 360 | def compute_losses(self, inputs, outputs): 361 | raise Exception("Need to implement compute_losses") 362 | 363 | def log_string(self, content): 364 | with open(self.log_file, 'a') as f: 365 | f.write(content + '\n') 366 | print(content, flush=True) 367 | 368 | def log(self, mode, inputs, losses, batch_idx, outputs): 369 | """Write an event to the tensorboard events file 370 | """ 371 | writer = self.writers[mode] 372 | for l, v in losses.items(): 373 | if type(losses[l]) == dict: 374 | writer.add_scalars("{}".format(l), v, self.step) 375 | else: 376 | writer.add_scalar("{}".format(l), v, self.step) 377 | 378 | if batch_idx % 150 == 0: 379 | writer.add_image('image0', inputs[("color", 0, 0)][0], global_step=self.step, walltime=None, dataformats='CHW') 380 | writer.add_image('image1', inputs[("color", 1, 0)][0], global_step=self.step, walltime=None, dataformats='CHW') 381 | writer.add_image('image2', inputs[("color", 2, 0)][0], global_step=self.step, walltime=None, dataformats='CHW') 382 | depth_gt = gray_2_colormap_np(inputs[("depth_gt", 0, 0)][0][0]) 383 | writer.add_image('depth_gt', depth_gt, global_step=self.step, walltime=None, dataformats='HWC') 384 | depth_pred = gray_2_colormap_np(outputs[('depth_pred', 0)][0][0]) 385 | writer.add_image('depth_pred', depth_pred, global_step=self.step, walltime=None, dataformats='HWC') 386 | depth_pred_2 = gray_2_colormap_np(outputs[('depth_pred_2', 0)][0][0]) 387 | writer.add_image('depth_pred_2', depth_pred_2, global_step=self.step, walltime=None, dataformats='HWC') 388 | 389 | 390 | def save_opts(self): 391 | """Save options to disk so we know what we ran this experiment with 392 | """ 393 | models_dir = os.path.join(self.log_path, "models") 394 | if not os.path.exists(models_dir): 395 | os.makedirs(models_dir) 396 | to_save = self.opt.__dict__.copy() 397 | 398 | with open(os.path.join(models_dir, 'opt.json'), 'w') as f: 399 | json.dump(to_save, f, indent=2, sort_keys=True) 400 | 401 | def clean_models(self, keep_ids): 402 | models = glob.glob(os.path.join(self.log_path, "models", "weights_*")) 403 | models = sorted(models, 404 | key=lambda x: int(x.split('/')[-1].split('_')[-1])) 405 | 
for i in range(len(models) - 1): 406 | epoch = int(models[i].split('/')[-1].split('_')[-1]) 407 | if epoch not in keep_ids: 408 | shutil.rmtree(models[i]) 409 | 410 | def save_model(self, monitor_key=""): 411 | """Save model weights to disk 412 | """ 413 | save_folder = os.path.join(self.log_path, "models", "weights_latest") 414 | if not os.path.exists(save_folder): 415 | os.makedirs(save_folder) 416 | print("save model to folder %s" % save_folder) 417 | 418 | save_path = os.path.join(save_folder, "model_{}.pth".format(self.epoch)) 419 | if self.opt.distributed: 420 | to_save = self.model.module.state_dict() 421 | else: 422 | to_save = self.model.state_dict() 423 | 424 | torch.save(to_save, save_path) 425 | save_path_opt = os.path.join(save_folder, "{}.pth".format("adam")) 426 | torch.save(self.model_optimizer.state_dict(), save_path_opt) 427 | # if len(monitor_key): 428 | # if type(monitor_key) != list: 429 | # monitor_key = [monitor_key] 430 | # for key in monitor_key: 431 | # if not self.is_best[key]: 432 | # continue 433 | # save_folder = os.path.join(self.log_path, "models", 434 | # f"weights_best_{key}") 435 | # os.makedirs(save_folder, exist_ok=True) 436 | # cmd = f"cp {save_path} {save_folder}/model.pth" 437 | # os.system(cmd) 438 | # cmd = f"cp {save_path_opt} {save_folder}/adam.pth" 439 | # os.system(cmd) 440 | # with open(f"{save_folder}/key.txt", "w") as f: 441 | # val = getattr(self, key) 442 | # f.write(f"{key} {val}\n") 443 | 444 | def load_model(self): 445 | self.opt.load_weights_folder = os.path.expanduser( 446 | self.opt.load_weights_folder) 447 | assert os.path.isdir(self.opt.load_weights_folder), \ 448 | "Cannot find folder {}".format(self.opt.load_weights_folder) 449 | print("loading model from folder {}".format( 450 | self.opt.load_weights_folder)) 451 | 452 | try: 453 | self.epoch = int( 454 | self.opt.load_weights_folder.split('/')[-2].split('_')[1]) 455 | except: 456 | self.epoch = 0 457 | 458 | try: 459 | path = os.path.join(self.opt.load_weights_folder, 460 | "{}.pth".format("model")) 461 | model_dict = self.model.state_dict() 462 | pretrained_dict = torch.load(path, 'cpu') 463 | for k, v in pretrained_dict.items(): 464 | if k not in model_dict: 465 | print('model dict missing ', k, v.shape) 466 | for k, v in model_dict.items(): 467 | if k not in pretrained_dict: 468 | print('pretrained_dict missing ', k, v.shape) 469 | pretrained_dict = { 470 | k: v 471 | for k, v in pretrained_dict.items() if k in model_dict 472 | } 473 | model_dict.update(pretrained_dict) 474 | self.model.load_state_dict(model_dict) 475 | except Exception as e: 476 | print(e) 477 | print("Fail loading {}".format("model")) 478 | 479 | # loading optimizer state 480 | optimizer_load_path = os.path.join(self.opt.load_weights_folder, 481 | "adam.pth") 482 | if os.path.isfile(optimizer_load_path): 483 | print("Loading optimizer weights") 484 | optimizer_dict = torch.load(optimizer_load_path, 'cpu') 485 | try: 486 | self.model_optimizer.load_state_dict(optimizer_dict) 487 | except Exception as e: 488 | print(e) 489 | print("Fail loading optimizer weights") 490 | else: 491 | print("Cannot find optimizer weights so optimizer is randomly initialized") 492 | -------------------------------------------------------------------------------- /trainer_base_kitti.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | from open3d import * 3 | import numpy as np 4 | import time 5 | 6 | import torch 7 | import 
torch.nn.functional as F 8 | import torch.optim as optim 9 | from torch.utils.data import DataLoader 10 | # from tensorboardX import SummaryWriter 11 | from torch.utils.tensorboard import SummaryWriter 12 | import json 13 | from utils import * 14 | import os 15 | import stat 16 | import glob 17 | import shutil 18 | from torch.autograd import Variable 19 | import torch.optim.lr_scheduler as lr_sched 20 | from options import MVS2DOptions 21 | import torch.backends.cudnn as cudnn 22 | 23 | cudnn.benchmark = True 24 | 25 | g = torch.Generator() 26 | g.manual_seed(0) 27 | 28 | 29 | def worker_init_fn(worker_id): 30 | seed = np.random.get_state()[1][0] + worker_id 31 | np.random.seed(seed) 32 | import random 33 | random.seed(seed) 34 | 35 | 36 | def file_remove_readonly(func, path, execinfo): 37 | os.chmod(path, stat.S_IWUSR)#修改文件权限 38 | func(path) 39 | 40 | 41 | class BaseTrainer(object): 42 | def __init__(self, options): 43 | 44 | self.is_best = {} 45 | self.epoch = 0 46 | self.step = 0 47 | self.eval_step = 0 48 | self.opt = options 49 | self.is_master = self.opt.gpu == 0 50 | self.opt.is_master = self.opt.gpu == 0 51 | self.device = self.opt.gpu 52 | 53 | 54 | # base_dir = '.' 55 | # self.opt.log_dir = os.path.join(base_dir, self.opt.log_dir) 56 | self.log_path = os.path.join(self.opt.log_dir, self.opt.model_name) 57 | if self.opt.is_master: 58 | if os.path.exists(self.log_path) and self.opt.overwrite: 59 | try: 60 | shutil.rmtree(self.log_path) 61 | # shutil.rmtree(self.log_path, onerror=file_remove_readonly) 62 | except: 63 | print('overwrite folder failed') 64 | self.log_file = os.path.join(self.log_path, "log.txt") 65 | 66 | self.writers = {} 67 | for mode in ["train", "val"]: 68 | self.writers[mode] = SummaryWriter( 69 | os.path.join(self.log_path, mode)) 70 | if self.opt.is_master: 71 | with open(self.log_file, 'w') as f: 72 | f.write(self.opt.note + '\n') 73 | 74 | self.save_opts() 75 | 76 | self.build_dataset() 77 | 78 | self.build_model() 79 | 80 | # self.build_optimizer() 81 | self.fetch_optimizer() 82 | 83 | if self.opt.load_weights_folder is not None: 84 | self.load_model() 85 | 86 | if self.opt.distributed: 87 | if self.opt.gpu is not None: 88 | print( 89 | f"batch size on GPU: {self.opt.gpu}: {self.opt.batch_size}" 90 | ) 91 | 92 | self.model = torch.nn.parallel.DistributedDataParallel( 93 | self.model, 94 | device_ids=[self.opt.gpu], 95 | find_unused_parameters=True) 96 | else: 97 | model = torch.nn.parallel.DistributedDataParallel( 98 | self.model, find_unused_parameters=True) 99 | 100 | # self.build_scheduler() 101 | 102 | self.total_data_time = 0 103 | self.total_op_time = 0 104 | if self.opt.epoch_size == -1: 105 | self.opt.epoch_size = len(self.train_loader) 106 | 107 | if self.opt.is_master: 108 | print("Training model named:\n ", self.opt.model_name) 109 | print("Models and tensorboard events files are saved to:\n ", 110 | self.opt.log_dir) 111 | 112 | self.num_total_steps = len(self.train_loader) * self.opt.num_epochs 113 | print("There are {:d} training items and {:d} validation items\n". 
114 | format( 115 | len(self.train_loader) * self.opt.batch_size, 116 | len(self.val_loader) * 1)) 117 | 118 | # def build_optimizer(self): 119 | # optimizer = optim.Adam(self.model.parameters(), 120 | # lr=self.opt.LR, 121 | # weight_decay=self.opt.WEIGHT_DECAY) 122 | 123 | # self.model_optimizer = optimizer 124 | 125 | 126 | # def build_scheduler(self): 127 | # total_iters_each_epoch = len(self.train_loader) 128 | # decay_steps = [ 129 | # x * total_iters_each_epoch for x in self.opt.DECAY_STEP_LIST 130 | # ] 131 | # total_steps = total_iters_each_epoch * self.opt.num_epochs 132 | 133 | # def lr_lbmd(cur_epoch): 134 | # cur_decay = 1 135 | # for decay_step in decay_steps: 136 | # if cur_epoch >= decay_step: 137 | # cur_decay = cur_decay * self.opt.LR_DECAY 138 | # return max(cur_decay, self.opt.LR_CLIP / self.opt.LR) 139 | 140 | # self.model_lr_scheduler = lr_sched.LambdaLR(self.model_optimizer, 141 | # lr_lbmd, 142 | # last_epoch=-1) 143 | 144 | def fetch_optimizer(self): 145 | """ Create the optimizer and learning rate scheduler """ 146 | total_iters_each_epoch = len(self.train_loader) 147 | total_steps = total_iters_each_epoch * self.opt.num_epochs 148 | optimizer = optim.AdamW(self.model.parameters(), lr=self.opt.LR, weight_decay=.00001, eps=1e-8) 149 | 150 | scheduler = optim.lr_scheduler.OneCycleLR(optimizer, self.opt.LR, total_steps+100, 151 | pct_start=0.01, cycle_momentum=False, anneal_strategy='linear') 152 | self.model_lr_scheduler = scheduler 153 | self.model_optimizer = optimizer 154 | 155 | 156 | def to_gpu(self, inputs, keys=None): 157 | if keys == None: 158 | keys = inputs.keys() 159 | for key in keys: 160 | if key not in inputs: 161 | continue 162 | ipt = inputs[key] 163 | if type(ipt) == torch.Tensor: 164 | inputs[key] = ipt.cuda(self.opt.gpu, non_blocking=True) 165 | elif type(ipt) == list and type(ipt[0]) == torch.Tensor: 166 | inputs[key] = [ 167 | x.cuda(self.opt.gpu, non_blocking=True) for x in ipt 168 | ] 169 | elif type(ipt) == dict: 170 | for k in ipt.keys(): 171 | if type(ipt[k]) == torch.Tensor: 172 | ipt[k] = ipt[k].cuda(self.opt.gpu, non_blocking=True) 173 | 174 | def build_dataset(self): 175 | # if self.opt.dataset == 'ScanNet': 176 | # from datasets.ScanNet import ScanNet as Dataset 177 | # elif self.opt.dataset == 'DeMoN': 178 | # from datasets.DeMoN import DeMoN as Dataset 179 | # elif self.opt.dataset == 'DTU': 180 | # from datasets.DTU import DTU as Dataset 181 | # else: 182 | # raise Exception("Unknown Dataset") 183 | from datasets.kitti import DDAD_kitti 184 | 185 | train_dataset = DDAD_kitti(self.opt, True) 186 | if self.opt.distributed: 187 | self.train_sampler = torch.utils.data.distributed.DistributedSampler( 188 | train_dataset) 189 | else: 190 | self.train_sampler = None 191 | self.train_loader = DataLoader(train_dataset, 192 | self.opt.batch_size, 193 | shuffle=(self.train_sampler is None), 194 | num_workers=self.opt.num_workers, 195 | pin_memory=True, 196 | worker_init_fn=worker_init_fn, 197 | drop_last=True, 198 | sampler=self.train_sampler) 199 | 200 | val_dataset = DDAD_kitti(self.opt, False) 201 | if self.opt.distributed: 202 | self.val_sampler = torch.utils.data.distributed.DistributedSampler( 203 | val_dataset) 204 | else: 205 | self.val_sampler = None 206 | self.val_loader = DataLoader(val_dataset, 207 | self.opt.batch_size, 208 | shuffle=False, 209 | num_workers=self.opt.num_workers, 210 | pin_memory=True, 211 | worker_init_fn=worker_init_fn, 212 | drop_last=False, 213 | sampler=self.val_sampler) 214 | 215 | # val_dataset = 
DDAD(self.opt, False) 216 | # self.val_sampler = None 217 | # self.val_loader = DataLoader(val_dataset, 218 | # 1, 219 | # shuffle=False, 220 | # num_workers=self.opt.num_workers, 221 | # pin_memory=True, 222 | # drop_last=False, 223 | # sampler=self.val_sampler) 224 | 225 | def log_time(self, batch_idx, op_time, step_time, loss): 226 | """Print a logging statement to the terminal 227 | """ 228 | if self.opt.distributed: 229 | ops_per_sec = self.opt.ngpus_per_node * self.opt.batch_size / op_time 230 | steps_per_sec = self.opt.ngpus_per_node * self.opt.batch_size / step_time 231 | else: 232 | ops_per_sec = self.opt.batch_size / op_time 233 | steps_per_sec = self.opt.batch_size / step_time 234 | time_sofar = time.time() - self.start_time 235 | 236 | training_time_left = (self.num_total_steps / self.step - 237 | 1.0) * time_sofar if self.step > 0 else 0 238 | print_string = "epoch {:>3} | batch {:>6}/{:>6} | ops/s: {:5.1f} | steps/s: {:5.1f} | t_data/t_op: {:5.1f} " + \ 239 | " | loss: {:.5f} | time elapsed: {} | time left: {} | lr: {:.7f}" 240 | self.log_string( 241 | print_string.format(self.epoch, batch_idx, len(self.train_loader), 242 | ops_per_sec, steps_per_sec, 243 | self.total_data_time / self.total_op_time, 244 | loss, sec_to_hm_str(time_sofar), 245 | sec_to_hm_str(training_time_left), 246 | self.model_optimizer.param_groups[0]['lr'])) 247 | 248 | def train_epoch(self): 249 | if self.opt.is_master: 250 | print("Training") 251 | self.writers['train'].add_scalar( 252 | "lr", self.model_optimizer.param_groups[0]['lr'], self.step) 253 | self.set_train() 254 | before_data_loader_time = time.time() 255 | time_last_step = time.time() 256 | 257 | if self.opt.epoch_size == 0: 258 | return 259 | 260 | for batch_idx, inputs in enumerate(self.train_loader): 261 | if batch_idx >= self.opt.epoch_size: 262 | break 263 | after_data_loader_time = time.time() 264 | duration_data = after_data_loader_time - before_data_loader_time 265 | self.total_data_time += duration_data 266 | before_op_time = time.time() 267 | 268 | self.model_lr_scheduler.step(self.step) 269 | 270 | if self.opt.is_master: 271 | try: 272 | cur_lr = float(self.model_optimizer.lr) 273 | except: 274 | cur_lr = self.model_optimizer.param_groups[0]['lr'] 275 | 276 | self.writers['train'].add_scalar('meta_data/learning_rate', 277 | cur_lr, self.step) 278 | 279 | self.model_optimizer.zero_grad() 280 | losses, outputs = self.process_batch(inputs, 'train') 281 | losses['loss'].backward() 282 | 283 | torch.nn.utils.clip_grad_norm_(self.parameters_to_train, 284 | self.opt.GRAD_NORM_CLIP) 285 | 286 | contain_nan = False 287 | for weight in self.parameters_to_train: 288 | if weight.grad is not None: 289 | if torch.any(torch.isnan(weight.grad)): 290 | print('skip parameters update because of nan in grad') 291 | contain_nan = True 292 | if not contain_nan: 293 | self.model_optimizer.step() 294 | 295 | duration = time.time() - before_op_time 296 | self.total_op_time += duration 297 | 298 | if self.opt.is_master and batch_idx % self.opt.log_frequency == 0: 299 | duration_step = time.time() - time_last_step 300 | self.log_time(batch_idx, duration, duration_step, 301 | losses["loss"].cpu().data) 302 | self.log("train", inputs, losses, batch_idx, outputs) 303 | self.step += 1 304 | before_data_loader_time = time.time() 305 | time_last_step = time.time() 306 | 307 | def update_monitor_key(self, metrics, keys, goals): 308 | if len(keys): 309 | if type(keys) != list: 310 | keys = [keys] 311 | for key, goal in zip(keys, goals): 312 | val = metrics[key] 313 
| if not hasattr(self, key): 314 | setattr(self, key, val) 315 | self.is_best[key] = True 316 | else: 317 | if goal == 'minimize': 318 | if val < getattr(self, key): 319 | self.is_best[key] = True 320 | setattr(self, key, val) 321 | else: 322 | self.is_best[key] = False 323 | elif goal == 'maximize': 324 | if val > getattr(self, key): 325 | self.is_best[key] = True 326 | setattr(self, key, val) 327 | else: 328 | self.is_best[key] = False 329 | 330 | def set_train(self): 331 | self.model.train() 332 | 333 | def set_eval(self): 334 | self.model.eval() 335 | 336 | def train(self): 337 | self.start_time = time.time() 338 | if self.opt.is_master: 339 | print("Total epoch: %d " % self.opt.num_epochs) 340 | print("train loader size: %d " % len(self.train_loader)) 341 | print("val loader size: %d " % len(self.val_loader)) 342 | print("log_frequency: %d " % self.opt.log_frequency) 343 | for self.epoch in range(self.opt.num_epochs): 344 | if self.opt.distributed: 345 | self.train_sampler.set_epoch(self.epoch) 346 | self.train_epoch() 347 | # if self.opt.is_master: 348 | self.val_epoch() 349 | torch.distributed.barrier() 350 | torch.cuda.empty_cache() 351 | if self.opt.is_master: 352 | self.save_model(monitor_key=self.opt.monitor_key) 353 | 354 | def val(self): 355 | self.val_epoch() 356 | 357 | def process_batch(self, inputs, mode): 358 | raise Exception("Need to implement process_batch") 359 | 360 | def compute_losses(self, inputs, outputs): 361 | raise Exception("Need to implement compute_losses") 362 | 363 | def log_string(self, content): 364 | with open(self.log_file, 'a') as f: 365 | f.write(content + '\n') 366 | print(content, flush=True) 367 | 368 | def log(self, mode, inputs, losses, batch_idx, outputs): 369 | """Write an event to the tensorboard events file 370 | """ 371 | writer = self.writers[mode] 372 | for l, v in losses.items(): 373 | if type(losses[l]) == dict: 374 | writer.add_scalars("{}".format(l), v, self.step) 375 | else: 376 | writer.add_scalar("{}".format(l), v, self.step) 377 | 378 | if batch_idx % 150 == 0: 379 | writer.add_image('image0', inputs[("img_ori", 0, 0)][0], global_step=self.step, walltime=None, dataformats='CHW') 380 | writer.add_image('image1', inputs[("img_ori", 1, 0)][0], global_step=self.step, walltime=None, dataformats='CHW') 381 | writer.add_image('image2', inputs[("img_ori", 2, 0)][0], global_step=self.step, walltime=None, dataformats='CHW') 382 | depth_gt = gray_2_colormap_np(inputs[("depth_gt", 0, 0)][0][0]) 383 | writer.add_image('depth_gt', depth_gt, global_step=self.step, walltime=None, dataformats='HWC') 384 | depth_pred = gray_2_colormap_np(outputs[('depth_pred', 0)][0][0]) 385 | writer.add_image('depth_pred', depth_pred, global_step=self.step, walltime=None, dataformats='HWC') 386 | depth_pred_2 = gray_2_colormap_np(outputs[('depth_pred_2', 0)][0][0]) 387 | writer.add_image('depth_pred_2', depth_pred_2, global_step=self.step, walltime=None, dataformats='HWC') 388 | 389 | 390 | def save_opts(self): 391 | """Save options to disk so we know what we ran this experiment with 392 | """ 393 | models_dir = os.path.join(self.log_path, "models") 394 | if not os.path.exists(models_dir): 395 | os.makedirs(models_dir) 396 | to_save = self.opt.__dict__.copy() 397 | 398 | with open(os.path.join(models_dir, 'opt.json'), 'w') as f: 399 | json.dump(to_save, f, indent=2, sort_keys=True) 400 | 401 | def clean_models(self, keep_ids): 402 | models = glob.glob(os.path.join(self.log_path, "models", "weights_*")) 403 | models = sorted(models, 404 | key=lambda x: 
int(x.split('/')[-1].split('_')[-1])) 405 | for i in range(len(models) - 1): 406 | epoch = int(models[i].split('/')[-1].split('_')[-1]) 407 | if epoch not in keep_ids: 408 | shutil.rmtree(models[i]) 409 | 410 | def save_model(self, monitor_key=""): 411 | """Save model weights to disk 412 | """ 413 | save_folder = os.path.join(self.log_path, "models", "weights_latest") 414 | if not os.path.exists(save_folder): 415 | os.makedirs(save_folder) 416 | print("save model to folder %s" % save_folder) 417 | 418 | save_path = os.path.join(save_folder, "model_{}.pth".format(self.epoch)) 419 | if self.opt.distributed: 420 | to_save = self.model.module.state_dict() 421 | else: 422 | to_save = self.model.state_dict() 423 | 424 | torch.save(to_save, save_path) 425 | save_path_opt = os.path.join(save_folder, "{}.pth".format("adam")) 426 | torch.save(self.model_optimizer.state_dict(), save_path_opt) 427 | # if len(monitor_key): 428 | # if type(monitor_key) != list: 429 | # monitor_key = [monitor_key] 430 | # for key in monitor_key: 431 | # if not self.is_best[key]: 432 | # continue 433 | # save_folder = os.path.join(self.log_path, "models", 434 | # f"weights_best_{key}") 435 | # os.makedirs(save_folder, exist_ok=True) 436 | # cmd = f"cp {save_path} {save_folder}/model.pth" 437 | # os.system(cmd) 438 | # cmd = f"cp {save_path_opt} {save_folder}/adam.pth" 439 | # os.system(cmd) 440 | # with open(f"{save_folder}/key.txt", "w") as f: 441 | # val = getattr(self, key) 442 | # f.write(f"{key} {val}\n") 443 | 444 | def load_model(self): 445 | self.opt.load_weights_folder = os.path.expanduser( 446 | self.opt.load_weights_folder) 447 | assert os.path.isdir(self.opt.load_weights_folder), \ 448 | "Cannot find folder {}".format(self.opt.load_weights_folder) 449 | print("loading model from folder {}".format( 450 | self.opt.load_weights_folder)) 451 | 452 | try: 453 | self.epoch = int( 454 | self.opt.load_weights_folder.split('/')[-2].split('_')[1]) 455 | except: 456 | self.epoch = 0 457 | 458 | try: 459 | path = os.path.join(self.opt.load_weights_folder, 460 | "{}.pth".format("model")) 461 | model_dict = self.model.state_dict() 462 | pretrained_dict = torch.load(path, 'cpu') 463 | for k, v in pretrained_dict.items(): 464 | if k not in model_dict: 465 | print('model dict missing ', k, v.shape) 466 | for k, v in model_dict.items(): 467 | if k not in pretrained_dict: 468 | print('pretrained_dict missing ', k, v.shape) 469 | pretrained_dict = { 470 | k: v 471 | for k, v in pretrained_dict.items() if k in model_dict 472 | } 473 | model_dict.update(pretrained_dict) 474 | self.model.load_state_dict(model_dict) 475 | except Exception as e: 476 | print(e) 477 | print("Fail loading {}".format("model")) 478 | 479 | # loading optimizer state 480 | optimizer_load_path = os.path.join(self.opt.load_weights_folder, 481 | "adam.pth") 482 | if os.path.isfile(optimizer_load_path): 483 | print("Loading optimizer weights") 484 | optimizer_dict = torch.load(optimizer_load_path, 'cpu') 485 | try: 486 | self.model_optimizer.load_state_dict(optimizer_dict) 487 | except Exception as e: 488 | print(e) 489 | print("Fail loading optimizer weights") 490 | else: 491 | print("Cannot find optimizer weights so optimizer is randomly initialized") 492 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | import open3d as o3d 3 | from collections import defaultdict 4 
| import os 5 | import random 6 | import cv2 7 | import numpy as np 8 | import torch 9 | from torch.autograd import Variable 10 | import time 11 | import torch.nn.functional as F 12 | import re 13 | import collections.abc as container_abcs 14 | import torch.distributed as dist 15 | import torch.multiprocessing as mp 16 | import subprocess 17 | import matplotlib 18 | import matplotlib.pyplot as plt 19 | 20 | def gray_2_colormap_np_2(img, cmap = 'rainbow', max = None): 21 | img = img.squeeze() 22 | assert img.ndim == 2 23 | img[img<0] = 0 24 | mask_invalid = img < 1e-10 25 | if max == None: 26 | img = img / (img.max() + 1e-8) 27 | else: 28 | img = img/(max + 1e-8) 29 | 30 | norm = matplotlib.colors.Normalize(vmin=0, vmax=1.1) 31 | cmap_m = matplotlib.cm.get_cmap(cmap) 32 | map = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap_m) 33 | colormap = (map.to_rgba(img)[:,:,:3]*255).astype(np.uint8) 34 | colormap[mask_invalid] = 0 35 | 36 | return colormap 37 | 38 | def gray_2_colormap_np(img, cmap = 'rainbow', max = None): 39 | img = img.cpu().detach().numpy().squeeze() 40 | assert img.ndim == 2 41 | img[img<0] = 0 42 | mask_invalid = img < 1e-10 43 | if max == None: 44 | img = img / (img.max() + 1e-8) 45 | else: 46 | img = img/(max + 1e-8) 47 | 48 | norm = matplotlib.colors.Normalize(vmin=0, vmax=1.1) 49 | cmap_m = matplotlib.cm.get_cmap(cmap) 50 | map = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap_m) 51 | colormap = (map.to_rgba(img)[:,:,:3]*255).astype(np.uint8) 52 | colormap[mask_invalid] = 0 53 | 54 | return colormap 55 | 56 | 57 | def plot(xs, 58 | ys, 59 | stds=None, 60 | xlabel='', 61 | ylabel='', 62 | title='', 63 | legends=None, 64 | save_fn='test.png', 65 | marker=None, 66 | marker_size=12): 67 | MARKERS = ["o", "X", "D", "^", "<", "v", ">"] 68 | if marker is None: 69 | marker = MARKERS[3] 70 | 71 | nline = len(ys) 72 | nrows, ncols = 1, 1 73 | fig, ax = plt.subplots(figsize=(7, 7)) 74 | grid = plt.GridSpec(nrows, ncols, figure=fig) 75 | 76 | ax1 = plt.subplot(grid[0, 0]) 77 | lh = [] 78 | for i in range(nline): 79 | if stds is not None: 80 | #l, _, _= ax1.errorbar(xs, ys[i], yerr=stds[i], linewidth=4, marker=MARKERS[0], markersize=1, ) 81 | l, = ax1.plot( 82 | xs, 83 | ys[i], 84 | linewidth=4, 85 | marker=marker, 86 | markersize=marker_size, 87 | ) 88 | color = l.get_color() 89 | low = [x[0] for x in stds[i]] 90 | high = [x[1] for x in stds[i]] 91 | ax1.fill_between(xs, low, high, color=color, alpha=.1) 92 | 93 | else: 94 | l, = ax1.plot( 95 | xs, 96 | ys[i], 97 | linewidth=4, 98 | marker=marker, 99 | markersize=marker_size, 100 | ) 101 | lh.append(l) 102 | 103 | ax1.set_xlabel(xlabel, fontsize=25) 104 | ax1.set_ylabel(ylabel, fontsize=25) 105 | ax1.set_title(title, fontsize=25) 106 | if legends is not None: 107 | lgnd = ax1.legend(lh, legends, fontsize=15) 108 | plt.savefig(save_fn) 109 | 110 | 111 | def init_dist_slurm(tcp_port, local_rank, backend='nccl'): 112 | """ 113 | modified from https://github.com/open-mmlab/mmdetection 114 | Args: 115 | tcp_port: 116 | backend: 117 | Returns: 118 | """ 119 | proc_id = int(os.environ['SLURM_PROCID']) 120 | ntasks = int(os.environ['SLURM_NTASKS']) 121 | node_list = os.environ['SLURM_NODELIST'] 122 | num_gpus = torch.cuda.device_count() 123 | torch.cuda.set_device(proc_id % num_gpus) 124 | addr = subprocess.getoutput( 125 | 'scontrol show hostname {} | head -n1'.format(node_list)) 126 | os.environ['MASTER_PORT'] = str(tcp_port) 127 | os.environ['MASTER_ADDR'] = addr 128 | os.environ['WORLD_SIZE'] = str(ntasks) 129 | os.environ['RANK'] = 
str(proc_id) 130 | dist.init_process_group(backend=backend) 131 | 132 | total_gpus = dist.get_world_size() 133 | rank = dist.get_rank() 134 | return total_gpus, rank 135 | 136 | 137 | def init_dist_pytorch(tcp_port, local_rank, backend='nccl'): 138 | if mp.get_start_method(allow_none=True) is None: 139 | mp.set_start_method('spawn') 140 | 141 | num_gpus = torch.cuda.device_count() 142 | torch.cuda.set_device(local_rank % num_gpus) 143 | dist.init_process_group(backend=backend, 144 | init_method='tcp://127.0.0.1:%d' % tcp_port, 145 | rank=local_rank, 146 | world_size=num_gpus) 147 | rank = dist.get_rank() 148 | return num_gpus, rank 149 | 150 | 151 | def set_random_seed(seed): 152 | random.seed(seed) 153 | np.random.seed(seed) 154 | torch.manual_seed(seed) 155 | torch.backends.cudnn.deterministic = True 156 | torch.backends.cudnn.benchmark = False 157 | 158 | 159 | def randomRotation(epsilon): 160 | axis = (np.random.rand(3) - 0.5) 161 | axis /= np.linalg.norm(axis) 162 | dtheta = np.random.randn(1) * np.pi * epsilon 163 | K = np.array( 164 | [0, -axis[2], axis[1], axis[2], 0, -axis[0], -axis[1], axis[0], 165 | 0]).reshape(3, 3) 166 | dR = np.eye(3) + np.sin(dtheta) * K + (1 - np.cos(dtheta)) * np.matmul( 167 | K, K) 168 | return dR 169 | 170 | 171 | def angle_axis_to_rotation_matrix(angle_axis): 172 | """Convert 3d vector of axis-angle rotation to 4x4 rotation matrix 173 | 174 | Args: 175 | angle_axis (Tensor): tensor of 3d vector of axis-angle rotations. 176 | 177 | Returns: 178 | Tensor: tensor of 4x4 rotation matrices. 179 | 180 | Shape: 181 | - Input: :math:`(N, 3)` 182 | - Output: :math:`(N, 4, 4)` 183 | 184 | Example: 185 | >>> input = torch.rand(1, 3) # Nx3 186 | >>> output = tgm.angle_axis_to_rotation_matrix(input) # Nx4x4 187 | """ 188 | def _compute_rotation_matrix(angle_axis, theta2, eps=1e-6): 189 | # We want to be careful to only evaluate the square root if the 190 | # norm of the angle_axis vector is greater than zero. Otherwise 191 | # we get a division by zero. 
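# Note (added comment): the entries computed below are the Rodrigues formula written out element by element, R = cos(theta) * I + sin(theta) * [w]_x + (1 - cos(theta)) * w * w^T, with w = angle_axis / theta the unit rotation axis; the nine terms r00 ... r22 are concatenated and reshaped to an (N, 3, 3) rotation matrix.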
192 | k_one = 1.0 193 | theta = torch.sqrt(theta2) 194 | wxyz = angle_axis / (theta + eps) 195 | wx, wy, wz = torch.chunk(wxyz, 3, dim=1) 196 | cos_theta = torch.cos(theta) 197 | sin_theta = torch.sin(theta) 198 | 199 | r00 = cos_theta + wx * wx * (k_one - cos_theta) 200 | r10 = wz * sin_theta + wx * wy * (k_one - cos_theta) 201 | r20 = -wy * sin_theta + wx * wz * (k_one - cos_theta) 202 | r01 = wx * wy * (k_one - cos_theta) - wz * sin_theta 203 | r11 = cos_theta + wy * wy * (k_one - cos_theta) 204 | r21 = wx * sin_theta + wy * wz * (k_one - cos_theta) 205 | r02 = wy * sin_theta + wx * wz * (k_one - cos_theta) 206 | r12 = -wx * sin_theta + wy * wz * (k_one - cos_theta) 207 | r22 = cos_theta + wz * wz * (k_one - cos_theta) 208 | rotation_matrix = torch.cat( 209 | [r00, r01, r02, r10, r11, r12, r20, r21, r22], dim=1) 210 | return rotation_matrix.view(-1, 3, 3) 211 | 212 | def _compute_rotation_matrix_taylor(angle_axis): 213 | rx, ry, rz = torch.chunk(angle_axis, 3, dim=1) 214 | k_one = torch.ones_like(rx) 215 | rotation_matrix = torch.cat( 216 | [k_one, -rz, ry, rz, k_one, -rx, -ry, rx, k_one], dim=1) 217 | return rotation_matrix.view(-1, 3, 3) 218 | 219 | # stolen from ceres/rotation.h 220 | 221 | _angle_axis = torch.unsqueeze(angle_axis, dim=1) 222 | theta2 = torch.matmul(_angle_axis, _angle_axis.transpose(1, 2)) 223 | theta2 = torch.squeeze(theta2, dim=1) 224 | 225 | # compute rotation matrices 226 | rotation_matrix_normal = _compute_rotation_matrix(angle_axis, theta2) 227 | rotation_matrix_taylor = _compute_rotation_matrix_taylor(angle_axis) 228 | 229 | # create mask to handle both cases 230 | eps = 1e-6 231 | mask = (theta2 > eps).view(-1, 1, 1).to(theta2.device) 232 | mask_pos = (mask).type_as(theta2) 233 | mask_neg = (mask == False).type_as(theta2) # noqa 234 | 235 | # create output pose matrix 236 | batch_size = angle_axis.shape[0] 237 | rotation_matrix = torch.eye(4).to(angle_axis.device).type_as(angle_axis) 238 | rotation_matrix = rotation_matrix.view(1, 4, 4).repeat(batch_size, 1, 1) 239 | # fill output matrix with masked values 240 | rotation_matrix[..., :3, :3] = \ 241 | mask_pos * rotation_matrix_normal + mask_neg * rotation_matrix_taylor 242 | return rotation_matrix # Nx4x4 243 | 244 | 245 | default_collate_err_msg_format = ( 246 | "default_collate: batch must contain tensors, numpy arrays, numbers, " 247 | "dicts or lists; found {}") 248 | 249 | np_str_obj_array_pattern = re.compile(r'[SaUO]') 250 | 251 | 252 | def default_convert(data): 253 | r"""Converts each NumPy array data field into a tensor""" 254 | elem_type = type(data) 255 | if isinstance(data, torch.Tensor): 256 | return data 257 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 258 | and elem_type.__name__ != 'string_': 259 | # array of string classes and object 260 | if elem_type.__name__ == 'ndarray' \ 261 | and np_str_obj_array_pattern.search(data.dtype.str) is not None: 262 | return data 263 | return torch.as_tensor(data) 264 | elif isinstance(data, container_abcs.Mapping): 265 | return {key: default_convert(data[key]) for key in data} 266 | elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple 267 | return elem_type(*(default_convert(d) for d in data)) 268 | elif isinstance(data, container_abcs.Sequence) and not isinstance( 269 | data, str): 270 | return [default_convert(d) for d in data] 271 | else: 272 | return data 273 | 274 | 275 | def custom_collate(batch): 276 | r"""Puts each data field into a tensor with outer dimension batch size""" 277 | 278 | elem 
= batch[0] 279 | elem_type = type(elem) 280 | if isinstance(batch, list) and isinstance(elem, tuple): 281 | #data = torch.cat((x[0] for x in batch)) 282 | return [x[0] for x in batch] 283 | if type(elem) == tuple and elem[1] == 'varlen': 284 | return [x[0] for x in batch] 285 | 286 | if isinstance(elem, torch.Tensor): 287 | out = None 288 | if torch.utils.data.get_worker_info() is not None: 289 | # If we're in a background process, concatenate directly into a 290 | # shared memory tensor to avoid an extra copy 291 | numel = sum([x.numel() for x in batch]) 292 | storage = elem.storage()._new_shared(numel) 293 | out = elem.new(storage) 294 | try: 295 | return torch.stack(batch, 0, out=out) 296 | except: 297 | import ipdb 298 | ipdb.set_trace() 299 | 300 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 301 | and elem_type.__name__ != 'string_': 302 | elem = batch[0] 303 | if elem_type.__name__ == 'ndarray': 304 | # array of string classes and object 305 | if np_str_obj_array_pattern.search(elem.dtype.str) is not None: 306 | raise TypeError( 307 | default_collate_err_msg_format.format(elem.dtype)) 308 | 309 | return custom_collate([torch.as_tensor(b) for b in batch]) 310 | elif elem.shape == (): # scalars 311 | return torch.as_tensor(batch) 312 | elif isinstance(elem, float): 313 | return torch.tensor(batch, dtype=torch.float64) 314 | #elif isinstance(elem, int_classes): 315 | elif isinstance(elem, int): 316 | return torch.tensor(batch) 317 | #elif isinstance(elem, string_classes): 318 | elif isinstance(elem, str): 319 | return batch 320 | elif isinstance(elem, container_abcs.Mapping): 321 | return {key: custom_collate([d[key] for d in batch]) for key in elem} 322 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple 323 | return elem_type(*(custom_collate(samples) for samples in zip(*batch))) 324 | elif isinstance(elem, container_abcs.Sequence): 325 | transposed = zip(*batch) 326 | return [custom_collate(samples) for samples in transposed] 327 | 328 | raise TypeError(default_collate_err_msg_format.format(elem_type)) 329 | 330 | 331 | _use_shared_memory = False 332 | 333 | np_str_obj_array_pattern = re.compile(r'[SaUO]') 334 | 335 | error_msg_fmt = "batch must contain tensors, numbers, dicts or lists; found {}" 336 | 337 | numpy_type_map = { 338 | 'float64': torch.DoubleTensor, 339 | 'float32': torch.FloatTensor, 340 | 'float16': torch.HalfTensor, 341 | 'int64': torch.LongTensor, 342 | 'int32': torch.IntTensor, 343 | 'int16': torch.ShortTensor, 344 | 'int8': torch.CharTensor, 345 | 'uint8': torch.ByteTensor, 346 | } 347 | 348 | 349 | def default_collatev1_1(batch): 350 | r"""Puts each data field into a tensor with outer dimension batch size""" 351 | 352 | elem = batch[0] 353 | elem_type = type(batch[0]) 354 | if isinstance(batch, list) and isinstance(elem, tuple): 355 | #data = torch.cat((x[0] for x in batch)) 356 | return [x[0] for x in batch] 357 | if isinstance(batch[0], torch.Tensor): 358 | out = None 359 | if _use_shared_memory: 360 | # If we're in a background process, concatenate directly into a 361 | # shared memory tensor to avoid an extra copy 362 | numel = sum([x.numel() for x in batch]) 363 | storage = batch[0].storage()._new_shared(numel) 364 | out = batch[0].new(storage) 365 | try: 366 | return torch.stack(batch, 0, out=out) 367 | except: 368 | import ipdb 369 | ipdb.set_trace() 370 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 371 | and elem_type.__name__ != 'string_': 372 | elem = batch[0] 373 | try: 374 | 
if elem_type.__name__ == 'ndarray': 375 | # array of string classes and object 376 | if np_str_obj_array_pattern.search(elem.dtype.str) is not None: 377 | raise TypeError(error_msg_fmt.format(elem.dtype)) 378 | 379 | return default_collatev1_1( 380 | [torch.from_numpy(b) for b in batch]) 381 | except: 382 | import ipdb 383 | ipdb.set_trace() 384 | if elem.shape == (): # scalars 385 | py_type = float if elem.dtype.name.startswith('float') else int 386 | return numpy_type_map[elem.dtype.name](list(map(py_type, batch))) 387 | elif isinstance(batch[0], float): 388 | return torch.tensor(batch, dtype=torch.float64) 389 | #elif isinstance(batch[0], int_classes): 390 | elif isinstance(batch[0], int): 391 | return torch.tensor(batch) 392 | #elif isinstance(batch[0], string_classes): 393 | elif isinstance(batch[0], str): 394 | return batch 395 | elif isinstance(batch[0], container_abcs.Mapping): 396 | return { 397 | key: default_collatev1_1([d[key] for d in batch]) 398 | for key in batch[0] 399 | } 400 | elif isinstance(batch[0], tuple) and hasattr(batch[0], 401 | '_fields'): # namedtuple 402 | return type(batch[0])(*(default_collatev1_1(samples) 403 | for samples in zip(*batch))) 404 | elif isinstance(batch[0], container_abcs.Sequence): 405 | transposed = zip(*batch) 406 | return [default_collatev1_1(samples) for samples in transposed] 407 | 408 | raise TypeError((error_msg_fmt.format(type(batch[0])))) 409 | 410 | 411 | def backproject_depth_th(depth, inv_K, mask=False, device='cuda'): 412 | h, w = depth.shape 413 | idu, idv = np.meshgrid(range(w), range(h)) 414 | grid = np.stack((idu.flatten(), idv.flatten(), np.ones([w * h]))) 415 | grid = torch.from_numpy(grid).float().to(device) 416 | x = torch.matmul(inv_K[:3, :3], grid) 417 | x = x * depth.flatten()[None, :] 418 | x = x.t() 419 | if mask: 420 | x = x[depth.flatten() > 0] 421 | return x 422 | 423 | 424 | def backproject_depth(depth, inv_K, mask=False): 425 | h, w = depth.shape 426 | idu, idv = np.meshgrid(range(w), range(h)) 427 | grid = np.stack((idu.flatten(), idv.flatten(), np.ones([w * h]))) 428 | x = np.matmul(inv_K[:3, :3], grid) 429 | x = x * depth.flatten()[None, :] 430 | x = x.T 431 | if mask: 432 | x = x[depth.flatten() > 0] 433 | return x 434 | 435 | 436 | def parameters_count(net, name, do_print=True): 437 | model_parameters = filter(lambda p: p.requires_grad, net.parameters()) 438 | params = sum([np.prod(p.size()) for p in model_parameters]) 439 | if do_print: 440 | print('#params %s: %.3f M' % (name, params / 1e6)) 441 | return params 442 | 443 | 444 | def cuda_time(): 445 | torch.cuda.synchronize() 446 | return time.time() 447 | 448 | 449 | def transform3x3(pc, T): 450 | # T: [4,4] 451 | # pc: [n, 3] 452 | # return: [n, 3] 453 | return (np.matmul(T[:3, :3], pc.T)).T 454 | 455 | 456 | def transform4x4(pc, T): 457 | # T: [4,4] 458 | # pc: [n, 3] 459 | # return: [n, 3] 460 | return (np.matmul(T[:3, :3], pc.T) + T[:3, 3:4]).T 461 | 462 | 463 | def transform4x4_th(pc, T): 464 | # T: [4,4] 465 | # pc: [n, 3] 466 | # return: [n, 3] 467 | return (torch.matmul(T[:3, :3], pc.t()) + T[:3, 3:4]).t() 468 | 469 | 470 | def v(var, cuda=True, volatile=False): 471 | if type(var) == torch.Tensor or type(var) == torch.DoubleTensor: 472 | res = Variable(var.float(), volatile=volatile) 473 | elif type(var) == np.ndarray: 474 | res = Variable(torch.from_numpy(var).float(), volatile=volatile) 475 | if cuda: 476 | res = res.cuda() 477 | return res 478 | 479 | 480 | def npy(var): 481 | return var.data.cpu().numpy() 482 | 483 | 484 | def 
worker_init_fn(worker_id): 485 | np.random.seed(np.random.get_state()[1][0] + worker_id) 486 | 487 | 488 | def sec_to_hm(t): 489 | """Convert time in seconds to time in hours, minutes and seconds 490 | e.g. 10239 -> (2, 50, 39) 491 | """ 492 | t = int(t) 493 | s = t % 60 494 | t //= 60 495 | m = t % 60 496 | t //= 60 497 | return t, m, s 498 | 499 | 500 | def sec_to_hm_str(t): 501 | """Convert time in seconds to a nice string 502 | e.g. 10239 -> '02h50m39s' 503 | """ 504 | h, m, s = sec_to_hm(t) 505 | return "{:02d}h{:02d}m{:02d}s".format(h, m, s) 506 | 507 | 508 | def write_ply(fn, point, normal=None, color=None): 509 | 510 | ply = o3d.geometry.PointCloud() 511 | ply.points = o3d.utility.Vector3dVector(point) 512 | if color is not None: 513 | ply.colors = o3d.utility.Vector3dVector(color) 514 | if normal is not None: 515 | ply.normals = o3d.utility.Vector3dVector(normal) 516 | o3d.io.write_point_cloud(fn, ply) 517 | 518 | 519 | def skew(x): 520 | return np.array([[0, -x[2], x[1]], [x[2], 0, -x[0]], [-x[1], x[0], 0]]) 521 | 522 | 523 | def Thres_metrics(pred, gt, mask, interval, thre): 524 | abs_diff = (pred - gt).abs() / interval 525 | metric = (mask * (abs_diff < thre).float()).sum() / mask.sum() 526 | return metric 527 | 528 | 529 | def Thres_metrics_np(pred, gt, mask, interval, thre): 530 | abs_diff = np.abs(pred - gt) / interval 531 | metric = (mask * (abs_diff < thre)).sum() / mask.sum() 532 | return metric 533 | -------------------------------------------------------------------------------- /visual_ddad.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import sys 5 | import cv2 6 | import torch 7 | import torch.nn as nn 8 | from options import MVS2DOptions, EvalCfg 9 | import networks 10 | from torch.utils.data import DataLoader 11 | from datasets.DDAD import DDAD 12 | import torch.nn.functional as F 13 | from utils import * 14 | 15 | def resize_depth_preserve(depth, shape): 16 | """ 17 | Resizes depth map preserving all valid depth pixels 18 | Multiple downsampled points can be assigned to the same pixel. 
19 | 20 | Parameters 21 | ---------- 22 | depth : np.array [h,w] 23 | Depth map 24 | shape : tuple (H,W) 25 | Output shape 26 | 27 | Returns 28 | ------- 29 | depth : np.array [H,W,1] 30 | Resized depth map 31 | """ 32 | 33 | # Store dimensions and reshapes to single column 34 | depth = np.squeeze(depth) 35 | h, w = depth.shape 36 | x = depth.reshape(-1) 37 | # Create coordinate grid 38 | uv = np.mgrid[:h, :w].transpose(1, 2, 0).reshape(-1, 2) 39 | # Filters valid points 40 | idx = x > 0 41 | crd, val = uv[idx], x[idx] 42 | # Downsamples coordinates 43 | crd[:, 0] = (crd[:, 0] * (shape[0] / h)).astype(np.int32) 44 | crd[:, 1] = (crd[:, 1] * (shape[1] / w)).astype(np.int32) 45 | # Filters points inside image 46 | idx = (crd[:, 0] < shape[0]) & (crd[:, 1] < shape[1]) 47 | crd, val = crd[idx], val[idx] 48 | # Creates downsampled depth image and assigns points 49 | depth = np.zeros(shape) 50 | depth[crd[:, 0], crd[:, 1]] = val 51 | # Return resized depth map 52 | return np.expand_dims(depth, axis=0) 53 | 54 | 55 | def homo_warping_depth(src_fea, src_proj, ref_proj, depth_values): 56 | # src_fea: [B, C, H, W] 57 | # src_proj: [B, 4, 4] 58 | # ref_proj: [B, 4, 4] 59 | # depth_values: [B, Ndepth, H, W] 60 | # out: [B, C, Ndepth, H, W] 61 | batch, channels = src_fea.shape[0], src_fea.shape[1] 62 | num_depth = depth_values.shape[1] 63 | #height, width = src_fea.shape[2], src_fea.shape[3] 64 | h_src, w_src = src_fea.shape[2], src_fea.shape[3] 65 | h_ref, w_ref = depth_values.shape[2], depth_values.shape[3] 66 | 67 | with torch.no_grad(): 68 | 69 | proj = torch.matmul(src_proj, torch.inverse(ref_proj)) 70 | rot = proj[:, :3, :3] # [B,3,3] 71 | trans = proj[:, :3, 3:4] # [B,3,1] 72 | 73 | 74 | y, x = torch.meshgrid([torch.arange(0, h_ref, dtype=torch.float32, device=src_fea.device), 75 | torch.arange(0, w_ref, dtype=torch.float32, device=src_fea.device)]) 76 | y, x = y.contiguous(), x.contiguous() 77 | y, x = y.view(h_ref * w_ref), x.view(h_ref * w_ref) 78 | xyz = torch.stack((x, y, torch.ones_like(x))) # [3, H*W] 79 | xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1) # [B, 3, H*W] 80 | 81 | rot_xyz = torch.matmul(rot, xyz) 82 | rot_depth_xyz = rot_xyz * depth_values.view(batch, 1, -1) 83 | 84 | proj_xyz = rot_depth_xyz + trans.view(batch,3,1) 85 | 86 | proj_xy = proj_xyz[:, :2, :] / proj_xyz[:, 2:3, :] # [B, 2, Ndepth, H*W] 87 | z = proj_xyz[:, 2:3, :].view(batch, h_ref, w_ref) 88 | proj_x_normalized = proj_xy[:, 0, :] / ((w_src - 1) / 2.0) - 1 89 | proj_y_normalized = proj_xy[:, 1, :] / ((h_src - 1) / 2.0) - 1 90 | X_mask = ((proj_x_normalized > 1)+(proj_x_normalized < -1)).detach() 91 | proj_x_normalized[X_mask] = 2 # make sure that no point in warped image is a combinaison of im and gray 92 | Y_mask = ((proj_y_normalized > 1)+(proj_y_normalized < -1)).detach() 93 | proj_y_normalized[Y_mask] = 2 94 | proj_xy = torch.stack((proj_x_normalized, proj_y_normalized), dim=2) # [B, Ndepth, H*W, 2] 95 | grid = proj_xy 96 | proj_mask = ((X_mask + Y_mask) > 0).view(batch, num_depth, h_ref, w_ref) 97 | proj_mask = (proj_mask + (z <= 0)) > 0 98 | 99 | warped_src_fea = F.grid_sample(src_fea, grid.view(batch, h_ref, w_ref, 2), mode='bilinear', 100 | padding_mode='zeros', align_corners=True) 101 | 102 | warped_src_fea = warped_src_fea.view(batch, channels, num_depth, h_ref, w_ref) 103 | 104 | #return warped_src_fea , proj_mask 105 | return warped_src_fea 106 | 107 | 108 | def to_gpu(inputs, keys=None): 109 | if keys == None: 110 | keys = inputs.keys() 111 | for key in keys: 112 | if key not in inputs: 113 | 
continue 114 | ipt = inputs[key] 115 | if type(ipt) == torch.Tensor: 116 | inputs[key] = ipt.cuda() 117 | elif type(ipt) == list and type(ipt[0]) == torch.Tensor: 118 | inputs[key] = [ 119 | x.cuda() for x in ipt 120 | ] 121 | elif type(ipt) == dict: 122 | for k in ipt.keys(): 123 | if type(ipt[k]) == torch.Tensor: 124 | ipt[k] = ipt[k].cuda() 125 | 126 | 127 | options = MVS2DOptions() 128 | opts = options.parse() 129 | 130 | # opts.width = int(640) 131 | # opts.height = int(480) 132 | dataset = DDAD(opts, False) 133 | data_loader = DataLoader(dataset, 134 | 1, 135 | shuffle=False, 136 | num_workers=1, 137 | pin_memory=True, 138 | drop_last=False, 139 | sampler=None) 140 | model = networks.MVS2D(opt=opts).cuda() 141 | pretrained_dict = torch.load("/home/cjd/MVS2D/log/AFNet/models/weights_latest/model.pth") 142 | model_dict = model.state_dict() 143 | pretrained_dict = { 144 | k: v 145 | for k, v in pretrained_dict.items() if k in model_dict 146 | } 147 | model_dict.update(pretrained_dict) 148 | model.load_state_dict(model_dict) 149 | 150 | model.eval() 151 | root_path = '/data/cjd/AFnet/visual/ddad/' 152 | with torch.no_grad(): 153 | for batch_idx, inputs in enumerate(data_loader): 154 | print(batch_idx) 155 | to_gpu(inputs) 156 | 157 | imgs, proj_mats, pose_mats = [], [], [] 158 | for i in range(inputs['num_frame'][0].item()): 159 | imgs.append(inputs[('color', i, 0)]) 160 | proj_mats.append(inputs[('proj', i)]) 161 | pose_mats.append(inputs[('pose', i)]) 162 | 163 | pose_mats[0] = pose_mats[0]*0.75 164 | pose_mats[1] = pose_mats[1]*0.75 165 | pose_mats[2] = pose_mats[2]*0.75 166 | 167 | outputs = model(imgs[0], imgs[1:], pose_mats[0], pose_mats[1:], inputs[('inv_K_pool', 0)]) 168 | 169 | depth_gt = inputs[("depth_gt", 0, 0)][0].cpu().detach().numpy().squeeze() 170 | depth_gt = resize_depth_preserve(depth_gt, (608,960)) 171 | depth_gt_path = os.path.join(root_path,'depth_gt','{}.png'.format(batch_idx)) 172 | depth_gt_np = gray_2_colormap_np_2(depth_gt ,max = 120)[:,:,::-1] 173 | 174 | img0 = imgs[0] 175 | 176 | 177 | depth_pred = outputs[('depth_pred', 0)][0] 178 | depth_pred_2 = outputs[('depth_pred_2', 0)][0] 179 | 180 | depth_pred_np = gray_2_colormap_np(depth_pred ,max = 120)[:,:,::-1] 181 | depth_pred_2_np = gray_2_colormap_np(depth_pred_2 ,max = 120)[:,:,::-1] 182 | img0_path = os.path.join(root_path,'img0', '{}.png'.format(batch_idx)) 183 | depth_1_path = os.path.join(root_path,'depth_1','{}.png'.format(batch_idx)) 184 | depth_2_path = os.path.join(root_path,'depth_2', '{}.png'.format(batch_idx)) 185 | 186 | 187 | img0_np = img0[0].cpu().detach().numpy().squeeze().transpose(1,2,0) 188 | img0_np = (img0_np / img0_np.max() * 255).astype(np.uint8) 189 | cv2.imwrite(img0_path, img0_np) 190 | cv2.imwrite(depth_1_path, depth_pred_np) 191 | cv2.imwrite(depth_2_path, depth_pred_2_np) 192 | cv2.imwrite(depth_gt_path, depth_gt_np) 193 | 194 | 195 | 196 | 197 | 198 | 199 | # a = input('input some') 200 | # print(a) 201 | 202 | # break 203 | --------------------------------------------------------------------------------
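The helpers defined in utils.py above can also be combined for quick offline visualization. The following is a minimal, hypothetical usage sketch (not part of the repository): it colorizes a depth map with gray_2_colormap_np_2, back-projects the valid pixels with backproject_depth, and writes a colored point cloud with write_ply. The input file names "depth.npy" and "K.npy", the output names, and the 120 m colormap cap (borrowed from visual_ddad.py) are assumptions.
```
# Hypothetical sketch using the utils.py helpers; file names are placeholders.
import numpy as np
import cv2
from utils import backproject_depth, gray_2_colormap_np_2, write_ply

depth = np.load("depth.npy")   # assumed H x W metric depth map
K = np.load("K.npy")           # assumed 3 x 3 camera intrinsics

# Colorize the depth map; returns an H x W x 3 uint8 RGB image.
depth_vis = gray_2_colormap_np_2(depth.copy(), max=120)
cv2.imwrite("depth_vis.png", depth_vis[:, :, ::-1])  # cv2 expects BGR

# Back-project only the valid (depth > 0) pixels into camera coordinates (N x 3).
points = backproject_depth(depth, np.linalg.inv(K), mask=True)

# Colors for the same valid pixels, scaled to [0, 1] as open3d expects.
colors = depth_vis.reshape(-1, 3)[depth.flatten() > 0] / 255.0

# Save a colored point cloud next to the depth visualization.
write_ply("cloud.ply", points, color=colors)
```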