├── .gitignore ├── LICENSE ├── README.md ├── configs ├── arguments_eval_kittieigen.txt ├── arguments_eval_nyu.txt ├── arguments_train_kittieigen.txt └── arguments_train_nyu.txt ├── data_splits ├── eigen_test_files_with_gt.txt ├── eigen_train_files_with_gt.txt ├── kitti_depth_prediction_train.txt ├── kitti_official_test.txt ├── kitti_official_valid.txt ├── nyudepthv2_test_files_with_gt.txt └── nyudepthv2_train_files_with_gt_dense.txt ├── files ├── intro.png ├── office_00633.jpg ├── office_00633_depth.jpg ├── office_00633_pcd.jpg ├── output_nyu1_compressed.gif └── output_nyu2_compressed.gif └── newcrfs ├── dataloaders ├── __init__.py ├── dataloader.py └── dataloader_kittipred.py ├── demo.py ├── eval.py ├── networks ├── NewCRFDepth.py ├── __init__.py ├── newcrf_layers.py ├── newcrf_utils.py ├── swin_transformer.py └── uper_crf_head.py ├── test.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | __pycache__ 3 | model_zoo 4 | models 5 | datasets -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Alibaba Cloud 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## NeW CRFs: Neural Window Fully-connected CRFs for Monocular Depth Estimation 2 | 3 | This is the official PyTorch implementation code for NeWCRFs. For technical details, please refer to: 4 | 5 | **NeW CRFs: Neural Window Fully-connected CRFs for Monocular Depth Estimation**
6 | Weihao Yuan, Xiaodong Gu, Zuozhuo Dai, Siyu Zhu, Ping Tan
7 | **CVPR 2022**
8 | **[[Project Page](https://weihaosky.github.io/newcrfs/)]** | 9 | **[[Paper](https://arxiv.org/abs/2203.01502)]**
10 | 11 | 12 | ![intro](files/intro.png) 13 | ![office](files/office_00633.jpg) ![office depth](files/office_00633_depth.jpg) ![office point cloud](files/office_00633_pcd.jpg) 14 |
15 | 20 | 21 | ![Output1](files/output_nyu2_compressed.gif) 22 | 23 | ## Bibtex 24 | If you find this code useful in your research, please cite: 25 | 26 | ``` 27 | @inproceedings{yuan2022newcrfs, 28 | title={NeWCRFs: Neural Window Fully-connected CRFs for Monocular Depth Estimation}, 29 | author={Yuan, Weihao and Gu, Xiaodong and Dai, Zuozhuo and Zhu, Siyu and Tan, Ping}, 30 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 31 | pages={}, 32 | year={2022} 33 | } 34 | ``` 35 | 36 | ## Contents 37 | 1. [Installation](#installation) 38 | 2. [Datasets](#datasets) 39 | 3. [Training](#training) 40 | 4. [Evaluation](#evaluation) 41 | 5. [Models](#models) 42 | 6. [Demo](#demo) 43 | 44 | ## Installation 45 | ``` 46 | conda create -n newcrfs python=3.8 47 | conda activate newcrfs 48 | conda install pytorch=1.10.0 torchvision cudatoolkit=11.1 49 | pip install matplotlib tqdm tensorboardX timm mmcv 50 | ``` 51 | 52 | 53 | ## Datasets 54 | You can prepare the KITTI and NYUv2 datasets according to [here](https://github.com/cleinc/bts), and then modify the data paths in the config files to point to your dataset locations. The ground-truth depth maps are stored as 16-bit PNGs; see the depth-scaling sketch at the end of this README. 55 | 56 | Or you can download the NYUv2 data from [here](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/newcrfs/datasets/nyu/sync.zip) and the KITTI data from [here](http://www.cvlibs.net/datasets/kitti/eval_depth.php?benchmark=depth_prediction). 57 | 58 | 59 | ## Training 60 | First download the pretrained encoder backbone from [here](https://github.com/microsoft/Swin-Transformer), and then modify the pretrain path in the config files. 61 | 62 | Training the NYUv2 model: 63 | ``` 64 | python newcrfs/train.py configs/arguments_train_nyu.txt 65 | ``` 66 | 67 | Training the KITTI model: 68 | ``` 69 | python newcrfs/train.py configs/arguments_train_kittieigen.txt 70 | ``` 71 | 72 | 73 | ## Evaluation 74 | Evaluate the NYUv2 model: 75 | ``` 76 | python newcrfs/eval.py configs/arguments_eval_nyu.txt 77 | ``` 78 | 79 | Evaluate the KITTI model: 80 | ``` 81 | python newcrfs/eval.py configs/arguments_eval_kittieigen.txt 82 | ``` 83 | 84 | ## Models 85 | | Model | Abs.Rel. | Sqr.Rel | RMSE | RMSElog | a1 | a2 | a3 | SILog | 86 | | :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | 87 | |[NYUv2](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/newcrfs/models/model_nyu.ckpt) | 0.0952 | 0.0443 | 0.3310 | 0.1185 | 0.923 | 0.992 | 0.998 | 9.1023 | 88 | |[KITTI_Eigen](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/newcrfs/models/model_kittieigen.ckpt) | 0.0520 | 0.1482 | 2.0716 | 0.0780 | 0.975 | 0.997 | 0.999 | 6.9859 | 89 | 90 | 91 | ## Demo 92 | Test images with the indoor model: 93 | ``` 94 | python newcrfs/test.py --data_path datasets/test_data --dataset nyu --filenames_file data_splits/test_list.txt --checkpoint_path model_nyu.ckpt --max_depth 10 --save_viz 95 | ``` 96 | 97 | Play with the live demo on a video or your webcam: 98 | ``` 99 | python newcrfs/demo.py --dataset nyu --checkpoint_path model_zoo/model_nyu.ckpt --max_depth 10 --video video.mp4 100 | ``` 101 | 102 | ![Output1](files/output_nyu1_compressed.gif) 103 | 104 | [Demo video1](https://www.youtube.com/watch?v=RrWQIpXoP2Y) 105 | 106 | [Demo video2](https://www.youtube.com/watch?v=fD3sWH_54cg) 107 | 108 | [Demo video3](https://www.youtube.com/watch?v=IztmOYZNirM) 109 | 110 | ## Acknowledgements 111 | Thanks to Jin Han Lee for open-sourcing the excellent work [BTS](https://github.com/cleinc/bts). 112 | Thanks to Microsoft Research Asia for open-sourcing the excellent work [Swin Transformer](https://github.com/microsoft/Swin-Transformer).
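As a reference for the depth format mentioned in the Datasets section, here is a minimal sketch of how the ground-truth depth PNGs map to meters; it mirrors the scaling used in `newcrfs/dataloaders/dataloader.py`, and the helper name `load_depth_png` is only illustrative (it is not part of this repository):

```python
import numpy as np
from PIL import Image

def load_depth_png(path, dataset='nyu'):
    # Ground-truth depth maps are 16-bit PNGs: NYUv2 stores millimeters,
    # KITTI stores depth * 256, so divide by 1000 or 256 to obtain meters.
    depth = np.asarray(Image.open(path), dtype=np.float32)
    return depth / 1000.0 if dataset == 'nyu' else depth / 256.0
```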
-------------------------------------------------------------------------------- /configs/arguments_eval_kittieigen.txt: -------------------------------------------------------------------------------- 1 | --model_name newcrfs_kittieigen 2 | --encoder large07 3 | --dataset kitti 4 | --input_height 352 5 | --input_width 1216 6 | --max_depth 80 7 | --do_kb_crop 8 | 9 | --data_path_eval datasets/kitti/ 10 | --gt_path_eval datasets/kitti/ 11 | --filenames_file_eval data_splits/eigen_test_files_with_gt.txt 12 | --min_depth_eval 1e-3 13 | --max_depth_eval 80 14 | --garg_crop 15 | 16 | --checkpoint_path model_zoo/model_kittieigen.ckpt -------------------------------------------------------------------------------- /configs/arguments_eval_nyu.txt: -------------------------------------------------------------------------------- 1 | --model_name newcrfs_nyu 2 | --encoder large07 3 | --dataset nyu 4 | --input_height 480 5 | --input_width 640 6 | --max_depth 10 7 | 8 | --data_path_eval datasets/nyu/official_splits/test/ 9 | --gt_path_eval datasets/nyu/official_splits/test/ 10 | --filenames_file_eval data_splits/nyudepthv2_test_files_with_gt.txt 11 | --min_depth_eval 1e-3 12 | --max_depth_eval 10 13 | --eigen_crop 14 | 15 | --checkpoint_path model_zoo/model_nyu.ckpt -------------------------------------------------------------------------------- /configs/arguments_train_kittieigen.txt: -------------------------------------------------------------------------------- 1 | --mode train 2 | --model_name newcrfs_kittieigen 3 | --encoder large07 4 | --pretrain model_zoo/swin_transformer/swin_large_patch4_window7_224_22k.pth 5 | --dataset kitti 6 | --data_path datasets/kitti/ 7 | --gt_path datasets/kitti/ 8 | --filenames_file data_splits/eigen_train_files_with_gt.txt 9 | --batch_size 8 10 | --num_epochs 50 11 | --learning_rate 2e-5 12 | --weight_decay 1e-2 13 | --adam_eps 1e-3 14 | --num_threads 1 15 | --input_height 352 16 | --input_width 1120 17 | --max_depth 80 18 | --do_kb_crop 19 | --do_random_rotate 20 | --degree 1.0 21 | --log_directory ./models/ 22 | --multiprocessing_distributed 23 | --dist_url tcp://127.0.0.1:2345 24 | 25 | --log_freq 100 26 | --do_online_eval 27 | --eval_freq 1000 28 | --data_path_eval datasets/kitti/ 29 | --gt_path_eval datasets/kitti/ 30 | --filenames_file_eval data_splits/eigen_test_files_with_gt.txt 31 | --min_depth_eval 1e-3 32 | --max_depth_eval 80 33 | --garg_crop 34 | -------------------------------------------------------------------------------- /configs/arguments_train_nyu.txt: -------------------------------------------------------------------------------- 1 | --mode train 2 | --model_name newcrfs_nyu 3 | --encoder large07 4 | --pretrain model_zoo/swin_transformer/swin_large_patch4_window7_224_22k.pth 5 | --dataset nyu 6 | --data_path datasets/nyu/sync/ 7 | --gt_path datasets/nyu/sync/ 8 | --filenames_file data_splits/nyudepthv2_train_files_with_gt_dense.txt 9 | --batch_size 8 10 | --num_epochs 50 11 | --learning_rate 2e-5 12 | --weight_decay 1e-2 13 | --adam_eps 1e-3 14 | --num_threads 1 15 | --input_height 480 16 | --input_width 640 17 | --max_depth 10 18 | --do_random_rotate 19 | --degree 2.5 20 | --log_directory ./models/ 21 | --multiprocessing_distributed 22 | --dist_url tcp://127.0.0.1:2345 23 | 24 | --log_freq 100 25 | --do_online_eval 26 | --eval_freq 1000 27 | --data_path_eval datasets/nyu/official_splits/test/ 28 |
--gt_path_eval datasets/nyu/official_splits/test/ 29 | --filenames_file_eval data_splits/nyudepthv2_test_files_with_gt.txt 30 | --min_depth_eval 1e-3 31 | --max_depth_eval 10 32 | --eigen_crop 33 | -------------------------------------------------------------------------------- /data_splits/kitti_official_test.txt: -------------------------------------------------------------------------------- 1 | depth_selection/test_depth_prediction_anonymous/image/0000000000.png 2 | depth_selection/test_depth_prediction_anonymous/image/0000000001.png 3 | depth_selection/test_depth_prediction_anonymous/image/0000000002.png 4 | depth_selection/test_depth_prediction_anonymous/image/0000000003.png 5 | depth_selection/test_depth_prediction_anonymous/image/0000000004.png 6 | depth_selection/test_depth_prediction_anonymous/image/0000000005.png 7 | depth_selection/test_depth_prediction_anonymous/image/0000000006.png 8 | depth_selection/test_depth_prediction_anonymous/image/0000000007.png 9 | depth_selection/test_depth_prediction_anonymous/image/0000000008.png 10 | depth_selection/test_depth_prediction_anonymous/image/0000000009.png 11 | depth_selection/test_depth_prediction_anonymous/image/0000000010.png 12 | depth_selection/test_depth_prediction_anonymous/image/0000000011.png 13 | depth_selection/test_depth_prediction_anonymous/image/0000000012.png 14 | depth_selection/test_depth_prediction_anonymous/image/0000000013.png 15 | depth_selection/test_depth_prediction_anonymous/image/0000000014.png 16 | depth_selection/test_depth_prediction_anonymous/image/0000000015.png 17 | depth_selection/test_depth_prediction_anonymous/image/0000000016.png 18 | depth_selection/test_depth_prediction_anonymous/image/0000000017.png 19 | depth_selection/test_depth_prediction_anonymous/image/0000000018.png 20 | depth_selection/test_depth_prediction_anonymous/image/0000000019.png 21 | depth_selection/test_depth_prediction_anonymous/image/0000000020.png 22 | depth_selection/test_depth_prediction_anonymous/image/0000000021.png 23 | depth_selection/test_depth_prediction_anonymous/image/0000000022.png 24 | depth_selection/test_depth_prediction_anonymous/image/0000000023.png 25 | depth_selection/test_depth_prediction_anonymous/image/0000000024.png 26 | depth_selection/test_depth_prediction_anonymous/image/0000000025.png 27 | depth_selection/test_depth_prediction_anonymous/image/0000000026.png 28 | depth_selection/test_depth_prediction_anonymous/image/0000000027.png 29 | depth_selection/test_depth_prediction_anonymous/image/0000000028.png 30 | depth_selection/test_depth_prediction_anonymous/image/0000000029.png 31 | depth_selection/test_depth_prediction_anonymous/image/0000000030.png 32 | depth_selection/test_depth_prediction_anonymous/image/0000000031.png 33 | depth_selection/test_depth_prediction_anonymous/image/0000000032.png 34 | depth_selection/test_depth_prediction_anonymous/image/0000000033.png 35 | depth_selection/test_depth_prediction_anonymous/image/0000000034.png 36 | depth_selection/test_depth_prediction_anonymous/image/0000000035.png 37 | depth_selection/test_depth_prediction_anonymous/image/0000000036.png 38 | depth_selection/test_depth_prediction_anonymous/image/0000000037.png 39 | depth_selection/test_depth_prediction_anonymous/image/0000000038.png 40 | depth_selection/test_depth_prediction_anonymous/image/0000000039.png 41 | depth_selection/test_depth_prediction_anonymous/image/0000000040.png 42 | depth_selection/test_depth_prediction_anonymous/image/0000000041.png 43 | 
depth_selection/test_depth_prediction_anonymous/image/0000000042.png 44 | depth_selection/test_depth_prediction_anonymous/image/0000000043.png 45 | depth_selection/test_depth_prediction_anonymous/image/0000000044.png 46 | depth_selection/test_depth_prediction_anonymous/image/0000000045.png 47 | depth_selection/test_depth_prediction_anonymous/image/0000000046.png 48 | depth_selection/test_depth_prediction_anonymous/image/0000000047.png 49 | depth_selection/test_depth_prediction_anonymous/image/0000000048.png 50 | depth_selection/test_depth_prediction_anonymous/image/0000000049.png 51 | depth_selection/test_depth_prediction_anonymous/image/0000000050.png 52 | depth_selection/test_depth_prediction_anonymous/image/0000000051.png 53 | depth_selection/test_depth_prediction_anonymous/image/0000000052.png 54 | depth_selection/test_depth_prediction_anonymous/image/0000000053.png 55 | depth_selection/test_depth_prediction_anonymous/image/0000000054.png 56 | depth_selection/test_depth_prediction_anonymous/image/0000000055.png 57 | depth_selection/test_depth_prediction_anonymous/image/0000000056.png 58 | depth_selection/test_depth_prediction_anonymous/image/0000000057.png 59 | depth_selection/test_depth_prediction_anonymous/image/0000000058.png 60 | depth_selection/test_depth_prediction_anonymous/image/0000000059.png 61 | depth_selection/test_depth_prediction_anonymous/image/0000000060.png 62 | depth_selection/test_depth_prediction_anonymous/image/0000000061.png 63 | depth_selection/test_depth_prediction_anonymous/image/0000000062.png 64 | depth_selection/test_depth_prediction_anonymous/image/0000000063.png 65 | depth_selection/test_depth_prediction_anonymous/image/0000000064.png 66 | depth_selection/test_depth_prediction_anonymous/image/0000000065.png 67 | depth_selection/test_depth_prediction_anonymous/image/0000000066.png 68 | depth_selection/test_depth_prediction_anonymous/image/0000000067.png 69 | depth_selection/test_depth_prediction_anonymous/image/0000000068.png 70 | depth_selection/test_depth_prediction_anonymous/image/0000000069.png 71 | depth_selection/test_depth_prediction_anonymous/image/0000000070.png 72 | depth_selection/test_depth_prediction_anonymous/image/0000000071.png 73 | depth_selection/test_depth_prediction_anonymous/image/0000000072.png 74 | depth_selection/test_depth_prediction_anonymous/image/0000000073.png 75 | depth_selection/test_depth_prediction_anonymous/image/0000000074.png 76 | depth_selection/test_depth_prediction_anonymous/image/0000000075.png 77 | depth_selection/test_depth_prediction_anonymous/image/0000000076.png 78 | depth_selection/test_depth_prediction_anonymous/image/0000000077.png 79 | depth_selection/test_depth_prediction_anonymous/image/0000000078.png 80 | depth_selection/test_depth_prediction_anonymous/image/0000000079.png 81 | depth_selection/test_depth_prediction_anonymous/image/0000000080.png 82 | depth_selection/test_depth_prediction_anonymous/image/0000000081.png 83 | depth_selection/test_depth_prediction_anonymous/image/0000000082.png 84 | depth_selection/test_depth_prediction_anonymous/image/0000000083.png 85 | depth_selection/test_depth_prediction_anonymous/image/0000000084.png 86 | depth_selection/test_depth_prediction_anonymous/image/0000000085.png 87 | depth_selection/test_depth_prediction_anonymous/image/0000000086.png 88 | depth_selection/test_depth_prediction_anonymous/image/0000000087.png 89 | depth_selection/test_depth_prediction_anonymous/image/0000000088.png 90 | depth_selection/test_depth_prediction_anonymous/image/0000000089.png 91 | 
depth_selection/test_depth_prediction_anonymous/image/0000000090.png 92 | depth_selection/test_depth_prediction_anonymous/image/0000000091.png 93 | depth_selection/test_depth_prediction_anonymous/image/0000000092.png 94 | depth_selection/test_depth_prediction_anonymous/image/0000000093.png 95 | depth_selection/test_depth_prediction_anonymous/image/0000000094.png 96 | depth_selection/test_depth_prediction_anonymous/image/0000000095.png 97 | depth_selection/test_depth_prediction_anonymous/image/0000000096.png 98 | depth_selection/test_depth_prediction_anonymous/image/0000000097.png 99 | depth_selection/test_depth_prediction_anonymous/image/0000000098.png 100 | depth_selection/test_depth_prediction_anonymous/image/0000000099.png 101 | depth_selection/test_depth_prediction_anonymous/image/0000000100.png 102 | depth_selection/test_depth_prediction_anonymous/image/0000000101.png 103 | depth_selection/test_depth_prediction_anonymous/image/0000000102.png 104 | depth_selection/test_depth_prediction_anonymous/image/0000000103.png 105 | depth_selection/test_depth_prediction_anonymous/image/0000000104.png 106 | depth_selection/test_depth_prediction_anonymous/image/0000000105.png 107 | depth_selection/test_depth_prediction_anonymous/image/0000000106.png 108 | depth_selection/test_depth_prediction_anonymous/image/0000000107.png 109 | depth_selection/test_depth_prediction_anonymous/image/0000000108.png 110 | depth_selection/test_depth_prediction_anonymous/image/0000000109.png 111 | depth_selection/test_depth_prediction_anonymous/image/0000000110.png 112 | depth_selection/test_depth_prediction_anonymous/image/0000000111.png 113 | depth_selection/test_depth_prediction_anonymous/image/0000000112.png 114 | depth_selection/test_depth_prediction_anonymous/image/0000000113.png 115 | depth_selection/test_depth_prediction_anonymous/image/0000000114.png 116 | depth_selection/test_depth_prediction_anonymous/image/0000000115.png 117 | depth_selection/test_depth_prediction_anonymous/image/0000000116.png 118 | depth_selection/test_depth_prediction_anonymous/image/0000000117.png 119 | depth_selection/test_depth_prediction_anonymous/image/0000000118.png 120 | depth_selection/test_depth_prediction_anonymous/image/0000000119.png 121 | depth_selection/test_depth_prediction_anonymous/image/0000000120.png 122 | depth_selection/test_depth_prediction_anonymous/image/0000000121.png 123 | depth_selection/test_depth_prediction_anonymous/image/0000000122.png 124 | depth_selection/test_depth_prediction_anonymous/image/0000000123.png 125 | depth_selection/test_depth_prediction_anonymous/image/0000000124.png 126 | depth_selection/test_depth_prediction_anonymous/image/0000000125.png 127 | depth_selection/test_depth_prediction_anonymous/image/0000000126.png 128 | depth_selection/test_depth_prediction_anonymous/image/0000000127.png 129 | depth_selection/test_depth_prediction_anonymous/image/0000000128.png 130 | depth_selection/test_depth_prediction_anonymous/image/0000000129.png 131 | depth_selection/test_depth_prediction_anonymous/image/0000000130.png 132 | depth_selection/test_depth_prediction_anonymous/image/0000000131.png 133 | depth_selection/test_depth_prediction_anonymous/image/0000000132.png 134 | depth_selection/test_depth_prediction_anonymous/image/0000000133.png 135 | depth_selection/test_depth_prediction_anonymous/image/0000000134.png 136 | depth_selection/test_depth_prediction_anonymous/image/0000000135.png 137 | depth_selection/test_depth_prediction_anonymous/image/0000000136.png 138 | 
depth_selection/test_depth_prediction_anonymous/image/0000000137.png 139 | depth_selection/test_depth_prediction_anonymous/image/0000000138.png 140 | depth_selection/test_depth_prediction_anonymous/image/0000000139.png 141 | depth_selection/test_depth_prediction_anonymous/image/0000000140.png 142 | depth_selection/test_depth_prediction_anonymous/image/0000000141.png 143 | depth_selection/test_depth_prediction_anonymous/image/0000000142.png 144 | depth_selection/test_depth_prediction_anonymous/image/0000000143.png 145 | depth_selection/test_depth_prediction_anonymous/image/0000000144.png 146 | depth_selection/test_depth_prediction_anonymous/image/0000000145.png 147 | depth_selection/test_depth_prediction_anonymous/image/0000000146.png 148 | depth_selection/test_depth_prediction_anonymous/image/0000000147.png 149 | depth_selection/test_depth_prediction_anonymous/image/0000000148.png 150 | depth_selection/test_depth_prediction_anonymous/image/0000000149.png 151 | depth_selection/test_depth_prediction_anonymous/image/0000000150.png 152 | depth_selection/test_depth_prediction_anonymous/image/0000000151.png 153 | depth_selection/test_depth_prediction_anonymous/image/0000000152.png 154 | depth_selection/test_depth_prediction_anonymous/image/0000000153.png 155 | depth_selection/test_depth_prediction_anonymous/image/0000000154.png 156 | depth_selection/test_depth_prediction_anonymous/image/0000000155.png 157 | depth_selection/test_depth_prediction_anonymous/image/0000000156.png 158 | depth_selection/test_depth_prediction_anonymous/image/0000000157.png 159 | depth_selection/test_depth_prediction_anonymous/image/0000000158.png 160 | depth_selection/test_depth_prediction_anonymous/image/0000000159.png 161 | depth_selection/test_depth_prediction_anonymous/image/0000000160.png 162 | depth_selection/test_depth_prediction_anonymous/image/0000000161.png 163 | depth_selection/test_depth_prediction_anonymous/image/0000000162.png 164 | depth_selection/test_depth_prediction_anonymous/image/0000000163.png 165 | depth_selection/test_depth_prediction_anonymous/image/0000000164.png 166 | depth_selection/test_depth_prediction_anonymous/image/0000000165.png 167 | depth_selection/test_depth_prediction_anonymous/image/0000000166.png 168 | depth_selection/test_depth_prediction_anonymous/image/0000000167.png 169 | depth_selection/test_depth_prediction_anonymous/image/0000000168.png 170 | depth_selection/test_depth_prediction_anonymous/image/0000000169.png 171 | depth_selection/test_depth_prediction_anonymous/image/0000000170.png 172 | depth_selection/test_depth_prediction_anonymous/image/0000000171.png 173 | depth_selection/test_depth_prediction_anonymous/image/0000000172.png 174 | depth_selection/test_depth_prediction_anonymous/image/0000000173.png 175 | depth_selection/test_depth_prediction_anonymous/image/0000000174.png 176 | depth_selection/test_depth_prediction_anonymous/image/0000000175.png 177 | depth_selection/test_depth_prediction_anonymous/image/0000000176.png 178 | depth_selection/test_depth_prediction_anonymous/image/0000000177.png 179 | depth_selection/test_depth_prediction_anonymous/image/0000000178.png 180 | depth_selection/test_depth_prediction_anonymous/image/0000000179.png 181 | depth_selection/test_depth_prediction_anonymous/image/0000000180.png 182 | depth_selection/test_depth_prediction_anonymous/image/0000000181.png 183 | depth_selection/test_depth_prediction_anonymous/image/0000000182.png 184 | depth_selection/test_depth_prediction_anonymous/image/0000000183.png 185 | 
depth_selection/test_depth_prediction_anonymous/image/0000000184.png 186 | depth_selection/test_depth_prediction_anonymous/image/0000000185.png 187 | depth_selection/test_depth_prediction_anonymous/image/0000000186.png 188 | depth_selection/test_depth_prediction_anonymous/image/0000000187.png 189 | depth_selection/test_depth_prediction_anonymous/image/0000000188.png 190 | depth_selection/test_depth_prediction_anonymous/image/0000000189.png 191 | depth_selection/test_depth_prediction_anonymous/image/0000000190.png 192 | depth_selection/test_depth_prediction_anonymous/image/0000000191.png 193 | depth_selection/test_depth_prediction_anonymous/image/0000000192.png 194 | depth_selection/test_depth_prediction_anonymous/image/0000000193.png 195 | depth_selection/test_depth_prediction_anonymous/image/0000000194.png 196 | depth_selection/test_depth_prediction_anonymous/image/0000000195.png 197 | depth_selection/test_depth_prediction_anonymous/image/0000000196.png 198 | depth_selection/test_depth_prediction_anonymous/image/0000000197.png 199 | depth_selection/test_depth_prediction_anonymous/image/0000000198.png 200 | depth_selection/test_depth_prediction_anonymous/image/0000000199.png 201 | depth_selection/test_depth_prediction_anonymous/image/0000000200.png 202 | depth_selection/test_depth_prediction_anonymous/image/0000000201.png 203 | depth_selection/test_depth_prediction_anonymous/image/0000000202.png 204 | depth_selection/test_depth_prediction_anonymous/image/0000000203.png 205 | depth_selection/test_depth_prediction_anonymous/image/0000000204.png 206 | depth_selection/test_depth_prediction_anonymous/image/0000000205.png 207 | depth_selection/test_depth_prediction_anonymous/image/0000000206.png 208 | depth_selection/test_depth_prediction_anonymous/image/0000000207.png 209 | depth_selection/test_depth_prediction_anonymous/image/0000000208.png 210 | depth_selection/test_depth_prediction_anonymous/image/0000000209.png 211 | depth_selection/test_depth_prediction_anonymous/image/0000000210.png 212 | depth_selection/test_depth_prediction_anonymous/image/0000000211.png 213 | depth_selection/test_depth_prediction_anonymous/image/0000000212.png 214 | depth_selection/test_depth_prediction_anonymous/image/0000000213.png 215 | depth_selection/test_depth_prediction_anonymous/image/0000000214.png 216 | depth_selection/test_depth_prediction_anonymous/image/0000000215.png 217 | depth_selection/test_depth_prediction_anonymous/image/0000000216.png 218 | depth_selection/test_depth_prediction_anonymous/image/0000000217.png 219 | depth_selection/test_depth_prediction_anonymous/image/0000000218.png 220 | depth_selection/test_depth_prediction_anonymous/image/0000000219.png 221 | depth_selection/test_depth_prediction_anonymous/image/0000000220.png 222 | depth_selection/test_depth_prediction_anonymous/image/0000000221.png 223 | depth_selection/test_depth_prediction_anonymous/image/0000000222.png 224 | depth_selection/test_depth_prediction_anonymous/image/0000000223.png 225 | depth_selection/test_depth_prediction_anonymous/image/0000000224.png 226 | depth_selection/test_depth_prediction_anonymous/image/0000000225.png 227 | depth_selection/test_depth_prediction_anonymous/image/0000000226.png 228 | depth_selection/test_depth_prediction_anonymous/image/0000000227.png 229 | depth_selection/test_depth_prediction_anonymous/image/0000000228.png 230 | depth_selection/test_depth_prediction_anonymous/image/0000000229.png 231 | depth_selection/test_depth_prediction_anonymous/image/0000000230.png 232 | 
depth_selection/test_depth_prediction_anonymous/image/0000000231.png 233 | depth_selection/test_depth_prediction_anonymous/image/0000000232.png 234 | depth_selection/test_depth_prediction_anonymous/image/0000000233.png 235 | depth_selection/test_depth_prediction_anonymous/image/0000000234.png 236 | depth_selection/test_depth_prediction_anonymous/image/0000000235.png 237 | depth_selection/test_depth_prediction_anonymous/image/0000000236.png 238 | depth_selection/test_depth_prediction_anonymous/image/0000000237.png 239 | depth_selection/test_depth_prediction_anonymous/image/0000000238.png 240 | depth_selection/test_depth_prediction_anonymous/image/0000000239.png 241 | depth_selection/test_depth_prediction_anonymous/image/0000000240.png 242 | depth_selection/test_depth_prediction_anonymous/image/0000000241.png 243 | depth_selection/test_depth_prediction_anonymous/image/0000000242.png 244 | depth_selection/test_depth_prediction_anonymous/image/0000000243.png 245 | depth_selection/test_depth_prediction_anonymous/image/0000000244.png 246 | depth_selection/test_depth_prediction_anonymous/image/0000000245.png 247 | depth_selection/test_depth_prediction_anonymous/image/0000000246.png 248 | depth_selection/test_depth_prediction_anonymous/image/0000000247.png 249 | depth_selection/test_depth_prediction_anonymous/image/0000000248.png 250 | depth_selection/test_depth_prediction_anonymous/image/0000000249.png 251 | depth_selection/test_depth_prediction_anonymous/image/0000000250.png 252 | depth_selection/test_depth_prediction_anonymous/image/0000000251.png 253 | depth_selection/test_depth_prediction_anonymous/image/0000000252.png 254 | depth_selection/test_depth_prediction_anonymous/image/0000000253.png 255 | depth_selection/test_depth_prediction_anonymous/image/0000000254.png 256 | depth_selection/test_depth_prediction_anonymous/image/0000000255.png 257 | depth_selection/test_depth_prediction_anonymous/image/0000000256.png 258 | depth_selection/test_depth_prediction_anonymous/image/0000000257.png 259 | depth_selection/test_depth_prediction_anonymous/image/0000000258.png 260 | depth_selection/test_depth_prediction_anonymous/image/0000000259.png 261 | depth_selection/test_depth_prediction_anonymous/image/0000000260.png 262 | depth_selection/test_depth_prediction_anonymous/image/0000000261.png 263 | depth_selection/test_depth_prediction_anonymous/image/0000000262.png 264 | depth_selection/test_depth_prediction_anonymous/image/0000000263.png 265 | depth_selection/test_depth_prediction_anonymous/image/0000000264.png 266 | depth_selection/test_depth_prediction_anonymous/image/0000000265.png 267 | depth_selection/test_depth_prediction_anonymous/image/0000000266.png 268 | depth_selection/test_depth_prediction_anonymous/image/0000000267.png 269 | depth_selection/test_depth_prediction_anonymous/image/0000000268.png 270 | depth_selection/test_depth_prediction_anonymous/image/0000000269.png 271 | depth_selection/test_depth_prediction_anonymous/image/0000000270.png 272 | depth_selection/test_depth_prediction_anonymous/image/0000000271.png 273 | depth_selection/test_depth_prediction_anonymous/image/0000000272.png 274 | depth_selection/test_depth_prediction_anonymous/image/0000000273.png 275 | depth_selection/test_depth_prediction_anonymous/image/0000000274.png 276 | depth_selection/test_depth_prediction_anonymous/image/0000000275.png 277 | depth_selection/test_depth_prediction_anonymous/image/0000000276.png 278 | depth_selection/test_depth_prediction_anonymous/image/0000000277.png 279 | 
depth_selection/test_depth_prediction_anonymous/image/0000000278.png 280 | depth_selection/test_depth_prediction_anonymous/image/0000000279.png 281 | depth_selection/test_depth_prediction_anonymous/image/0000000280.png 282 | depth_selection/test_depth_prediction_anonymous/image/0000000281.png 283 | depth_selection/test_depth_prediction_anonymous/image/0000000282.png 284 | depth_selection/test_depth_prediction_anonymous/image/0000000283.png 285 | depth_selection/test_depth_prediction_anonymous/image/0000000284.png 286 | depth_selection/test_depth_prediction_anonymous/image/0000000285.png 287 | depth_selection/test_depth_prediction_anonymous/image/0000000286.png 288 | depth_selection/test_depth_prediction_anonymous/image/0000000287.png 289 | depth_selection/test_depth_prediction_anonymous/image/0000000288.png 290 | depth_selection/test_depth_prediction_anonymous/image/0000000289.png 291 | depth_selection/test_depth_prediction_anonymous/image/0000000290.png 292 | depth_selection/test_depth_prediction_anonymous/image/0000000291.png 293 | depth_selection/test_depth_prediction_anonymous/image/0000000292.png 294 | depth_selection/test_depth_prediction_anonymous/image/0000000293.png 295 | depth_selection/test_depth_prediction_anonymous/image/0000000294.png 296 | depth_selection/test_depth_prediction_anonymous/image/0000000295.png 297 | depth_selection/test_depth_prediction_anonymous/image/0000000296.png 298 | depth_selection/test_depth_prediction_anonymous/image/0000000297.png 299 | depth_selection/test_depth_prediction_anonymous/image/0000000298.png 300 | depth_selection/test_depth_prediction_anonymous/image/0000000299.png 301 | depth_selection/test_depth_prediction_anonymous/image/0000000300.png 302 | depth_selection/test_depth_prediction_anonymous/image/0000000301.png 303 | depth_selection/test_depth_prediction_anonymous/image/0000000302.png 304 | depth_selection/test_depth_prediction_anonymous/image/0000000303.png 305 | depth_selection/test_depth_prediction_anonymous/image/0000000304.png 306 | depth_selection/test_depth_prediction_anonymous/image/0000000305.png 307 | depth_selection/test_depth_prediction_anonymous/image/0000000306.png 308 | depth_selection/test_depth_prediction_anonymous/image/0000000307.png 309 | depth_selection/test_depth_prediction_anonymous/image/0000000308.png 310 | depth_selection/test_depth_prediction_anonymous/image/0000000309.png 311 | depth_selection/test_depth_prediction_anonymous/image/0000000310.png 312 | depth_selection/test_depth_prediction_anonymous/image/0000000311.png 313 | depth_selection/test_depth_prediction_anonymous/image/0000000312.png 314 | depth_selection/test_depth_prediction_anonymous/image/0000000313.png 315 | depth_selection/test_depth_prediction_anonymous/image/0000000314.png 316 | depth_selection/test_depth_prediction_anonymous/image/0000000315.png 317 | depth_selection/test_depth_prediction_anonymous/image/0000000316.png 318 | depth_selection/test_depth_prediction_anonymous/image/0000000317.png 319 | depth_selection/test_depth_prediction_anonymous/image/0000000318.png 320 | depth_selection/test_depth_prediction_anonymous/image/0000000319.png 321 | depth_selection/test_depth_prediction_anonymous/image/0000000320.png 322 | depth_selection/test_depth_prediction_anonymous/image/0000000321.png 323 | depth_selection/test_depth_prediction_anonymous/image/0000000322.png 324 | depth_selection/test_depth_prediction_anonymous/image/0000000323.png 325 | depth_selection/test_depth_prediction_anonymous/image/0000000324.png 326 | 
depth_selection/test_depth_prediction_anonymous/image/0000000325.png 327 | depth_selection/test_depth_prediction_anonymous/image/0000000326.png 328 | depth_selection/test_depth_prediction_anonymous/image/0000000327.png 329 | depth_selection/test_depth_prediction_anonymous/image/0000000328.png 330 | depth_selection/test_depth_prediction_anonymous/image/0000000329.png 331 | depth_selection/test_depth_prediction_anonymous/image/0000000330.png 332 | depth_selection/test_depth_prediction_anonymous/image/0000000331.png 333 | depth_selection/test_depth_prediction_anonymous/image/0000000332.png 334 | depth_selection/test_depth_prediction_anonymous/image/0000000333.png 335 | depth_selection/test_depth_prediction_anonymous/image/0000000334.png 336 | depth_selection/test_depth_prediction_anonymous/image/0000000335.png 337 | depth_selection/test_depth_prediction_anonymous/image/0000000336.png 338 | depth_selection/test_depth_prediction_anonymous/image/0000000337.png 339 | depth_selection/test_depth_prediction_anonymous/image/0000000338.png 340 | depth_selection/test_depth_prediction_anonymous/image/0000000339.png 341 | depth_selection/test_depth_prediction_anonymous/image/0000000340.png 342 | depth_selection/test_depth_prediction_anonymous/image/0000000341.png 343 | depth_selection/test_depth_prediction_anonymous/image/0000000342.png 344 | depth_selection/test_depth_prediction_anonymous/image/0000000343.png 345 | depth_selection/test_depth_prediction_anonymous/image/0000000344.png 346 | depth_selection/test_depth_prediction_anonymous/image/0000000345.png 347 | depth_selection/test_depth_prediction_anonymous/image/0000000346.png 348 | depth_selection/test_depth_prediction_anonymous/image/0000000347.png 349 | depth_selection/test_depth_prediction_anonymous/image/0000000348.png 350 | depth_selection/test_depth_prediction_anonymous/image/0000000349.png 351 | depth_selection/test_depth_prediction_anonymous/image/0000000350.png 352 | depth_selection/test_depth_prediction_anonymous/image/0000000351.png 353 | depth_selection/test_depth_prediction_anonymous/image/0000000352.png 354 | depth_selection/test_depth_prediction_anonymous/image/0000000353.png 355 | depth_selection/test_depth_prediction_anonymous/image/0000000354.png 356 | depth_selection/test_depth_prediction_anonymous/image/0000000355.png 357 | depth_selection/test_depth_prediction_anonymous/image/0000000356.png 358 | depth_selection/test_depth_prediction_anonymous/image/0000000357.png 359 | depth_selection/test_depth_prediction_anonymous/image/0000000358.png 360 | depth_selection/test_depth_prediction_anonymous/image/0000000359.png 361 | depth_selection/test_depth_prediction_anonymous/image/0000000360.png 362 | depth_selection/test_depth_prediction_anonymous/image/0000000361.png 363 | depth_selection/test_depth_prediction_anonymous/image/0000000362.png 364 | depth_selection/test_depth_prediction_anonymous/image/0000000363.png 365 | depth_selection/test_depth_prediction_anonymous/image/0000000364.png 366 | depth_selection/test_depth_prediction_anonymous/image/0000000365.png 367 | depth_selection/test_depth_prediction_anonymous/image/0000000366.png 368 | depth_selection/test_depth_prediction_anonymous/image/0000000367.png 369 | depth_selection/test_depth_prediction_anonymous/image/0000000368.png 370 | depth_selection/test_depth_prediction_anonymous/image/0000000369.png 371 | depth_selection/test_depth_prediction_anonymous/image/0000000370.png 372 | depth_selection/test_depth_prediction_anonymous/image/0000000371.png 373 | 
depth_selection/test_depth_prediction_anonymous/image/0000000372.png 374 | depth_selection/test_depth_prediction_anonymous/image/0000000373.png 375 | depth_selection/test_depth_prediction_anonymous/image/0000000374.png 376 | depth_selection/test_depth_prediction_anonymous/image/0000000375.png 377 | depth_selection/test_depth_prediction_anonymous/image/0000000376.png 378 | depth_selection/test_depth_prediction_anonymous/image/0000000377.png 379 | depth_selection/test_depth_prediction_anonymous/image/0000000378.png 380 | depth_selection/test_depth_prediction_anonymous/image/0000000379.png 381 | depth_selection/test_depth_prediction_anonymous/image/0000000380.png 382 | depth_selection/test_depth_prediction_anonymous/image/0000000381.png 383 | depth_selection/test_depth_prediction_anonymous/image/0000000382.png 384 | depth_selection/test_depth_prediction_anonymous/image/0000000383.png 385 | depth_selection/test_depth_prediction_anonymous/image/0000000384.png 386 | depth_selection/test_depth_prediction_anonymous/image/0000000385.png 387 | depth_selection/test_depth_prediction_anonymous/image/0000000386.png 388 | depth_selection/test_depth_prediction_anonymous/image/0000000387.png 389 | depth_selection/test_depth_prediction_anonymous/image/0000000388.png 390 | depth_selection/test_depth_prediction_anonymous/image/0000000389.png 391 | depth_selection/test_depth_prediction_anonymous/image/0000000390.png 392 | depth_selection/test_depth_prediction_anonymous/image/0000000391.png 393 | depth_selection/test_depth_prediction_anonymous/image/0000000392.png 394 | depth_selection/test_depth_prediction_anonymous/image/0000000393.png 395 | depth_selection/test_depth_prediction_anonymous/image/0000000394.png 396 | depth_selection/test_depth_prediction_anonymous/image/0000000395.png 397 | depth_selection/test_depth_prediction_anonymous/image/0000000396.png 398 | depth_selection/test_depth_prediction_anonymous/image/0000000397.png 399 | depth_selection/test_depth_prediction_anonymous/image/0000000398.png 400 | depth_selection/test_depth_prediction_anonymous/image/0000000399.png 401 | depth_selection/test_depth_prediction_anonymous/image/0000000400.png 402 | depth_selection/test_depth_prediction_anonymous/image/0000000401.png 403 | depth_selection/test_depth_prediction_anonymous/image/0000000402.png 404 | depth_selection/test_depth_prediction_anonymous/image/0000000403.png 405 | depth_selection/test_depth_prediction_anonymous/image/0000000404.png 406 | depth_selection/test_depth_prediction_anonymous/image/0000000405.png 407 | depth_selection/test_depth_prediction_anonymous/image/0000000406.png 408 | depth_selection/test_depth_prediction_anonymous/image/0000000407.png 409 | depth_selection/test_depth_prediction_anonymous/image/0000000408.png 410 | depth_selection/test_depth_prediction_anonymous/image/0000000409.png 411 | depth_selection/test_depth_prediction_anonymous/image/0000000410.png 412 | depth_selection/test_depth_prediction_anonymous/image/0000000411.png 413 | depth_selection/test_depth_prediction_anonymous/image/0000000412.png 414 | depth_selection/test_depth_prediction_anonymous/image/0000000413.png 415 | depth_selection/test_depth_prediction_anonymous/image/0000000414.png 416 | depth_selection/test_depth_prediction_anonymous/image/0000000415.png 417 | depth_selection/test_depth_prediction_anonymous/image/0000000416.png 418 | depth_selection/test_depth_prediction_anonymous/image/0000000417.png 419 | depth_selection/test_depth_prediction_anonymous/image/0000000418.png 420 | 
depth_selection/test_depth_prediction_anonymous/image/0000000419.png 421 | depth_selection/test_depth_prediction_anonymous/image/0000000420.png 422 | depth_selection/test_depth_prediction_anonymous/image/0000000421.png 423 | depth_selection/test_depth_prediction_anonymous/image/0000000422.png 424 | depth_selection/test_depth_prediction_anonymous/image/0000000423.png 425 | depth_selection/test_depth_prediction_anonymous/image/0000000424.png 426 | depth_selection/test_depth_prediction_anonymous/image/0000000425.png 427 | depth_selection/test_depth_prediction_anonymous/image/0000000426.png 428 | depth_selection/test_depth_prediction_anonymous/image/0000000427.png 429 | depth_selection/test_depth_prediction_anonymous/image/0000000428.png 430 | depth_selection/test_depth_prediction_anonymous/image/0000000429.png 431 | depth_selection/test_depth_prediction_anonymous/image/0000000430.png 432 | depth_selection/test_depth_prediction_anonymous/image/0000000431.png 433 | depth_selection/test_depth_prediction_anonymous/image/0000000432.png 434 | depth_selection/test_depth_prediction_anonymous/image/0000000433.png 435 | depth_selection/test_depth_prediction_anonymous/image/0000000434.png 436 | depth_selection/test_depth_prediction_anonymous/image/0000000435.png 437 | depth_selection/test_depth_prediction_anonymous/image/0000000436.png 438 | depth_selection/test_depth_prediction_anonymous/image/0000000437.png 439 | depth_selection/test_depth_prediction_anonymous/image/0000000438.png 440 | depth_selection/test_depth_prediction_anonymous/image/0000000439.png 441 | depth_selection/test_depth_prediction_anonymous/image/0000000440.png 442 | depth_selection/test_depth_prediction_anonymous/image/0000000441.png 443 | depth_selection/test_depth_prediction_anonymous/image/0000000442.png 444 | depth_selection/test_depth_prediction_anonymous/image/0000000443.png 445 | depth_selection/test_depth_prediction_anonymous/image/0000000444.png 446 | depth_selection/test_depth_prediction_anonymous/image/0000000445.png 447 | depth_selection/test_depth_prediction_anonymous/image/0000000446.png 448 | depth_selection/test_depth_prediction_anonymous/image/0000000447.png 449 | depth_selection/test_depth_prediction_anonymous/image/0000000448.png 450 | depth_selection/test_depth_prediction_anonymous/image/0000000449.png 451 | depth_selection/test_depth_prediction_anonymous/image/0000000450.png 452 | depth_selection/test_depth_prediction_anonymous/image/0000000451.png 453 | depth_selection/test_depth_prediction_anonymous/image/0000000452.png 454 | depth_selection/test_depth_prediction_anonymous/image/0000000453.png 455 | depth_selection/test_depth_prediction_anonymous/image/0000000454.png 456 | depth_selection/test_depth_prediction_anonymous/image/0000000455.png 457 | depth_selection/test_depth_prediction_anonymous/image/0000000456.png 458 | depth_selection/test_depth_prediction_anonymous/image/0000000457.png 459 | depth_selection/test_depth_prediction_anonymous/image/0000000458.png 460 | depth_selection/test_depth_prediction_anonymous/image/0000000459.png 461 | depth_selection/test_depth_prediction_anonymous/image/0000000460.png 462 | depth_selection/test_depth_prediction_anonymous/image/0000000461.png 463 | depth_selection/test_depth_prediction_anonymous/image/0000000462.png 464 | depth_selection/test_depth_prediction_anonymous/image/0000000463.png 465 | depth_selection/test_depth_prediction_anonymous/image/0000000464.png 466 | depth_selection/test_depth_prediction_anonymous/image/0000000465.png 467 | 
depth_selection/test_depth_prediction_anonymous/image/0000000466.png 468 | depth_selection/test_depth_prediction_anonymous/image/0000000467.png 469 | depth_selection/test_depth_prediction_anonymous/image/0000000468.png 470 | depth_selection/test_depth_prediction_anonymous/image/0000000469.png 471 | depth_selection/test_depth_prediction_anonymous/image/0000000470.png 472 | depth_selection/test_depth_prediction_anonymous/image/0000000471.png 473 | depth_selection/test_depth_prediction_anonymous/image/0000000472.png 474 | depth_selection/test_depth_prediction_anonymous/image/0000000473.png 475 | depth_selection/test_depth_prediction_anonymous/image/0000000474.png 476 | depth_selection/test_depth_prediction_anonymous/image/0000000475.png 477 | depth_selection/test_depth_prediction_anonymous/image/0000000476.png 478 | depth_selection/test_depth_prediction_anonymous/image/0000000477.png 479 | depth_selection/test_depth_prediction_anonymous/image/0000000478.png 480 | depth_selection/test_depth_prediction_anonymous/image/0000000479.png 481 | depth_selection/test_depth_prediction_anonymous/image/0000000480.png 482 | depth_selection/test_depth_prediction_anonymous/image/0000000481.png 483 | depth_selection/test_depth_prediction_anonymous/image/0000000482.png 484 | depth_selection/test_depth_prediction_anonymous/image/0000000483.png 485 | depth_selection/test_depth_prediction_anonymous/image/0000000484.png 486 | depth_selection/test_depth_prediction_anonymous/image/0000000485.png 487 | depth_selection/test_depth_prediction_anonymous/image/0000000486.png 488 | depth_selection/test_depth_prediction_anonymous/image/0000000487.png 489 | depth_selection/test_depth_prediction_anonymous/image/0000000488.png 490 | depth_selection/test_depth_prediction_anonymous/image/0000000489.png 491 | depth_selection/test_depth_prediction_anonymous/image/0000000490.png 492 | depth_selection/test_depth_prediction_anonymous/image/0000000491.png 493 | depth_selection/test_depth_prediction_anonymous/image/0000000492.png 494 | depth_selection/test_depth_prediction_anonymous/image/0000000493.png 495 | depth_selection/test_depth_prediction_anonymous/image/0000000494.png 496 | depth_selection/test_depth_prediction_anonymous/image/0000000495.png 497 | depth_selection/test_depth_prediction_anonymous/image/0000000496.png 498 | depth_selection/test_depth_prediction_anonymous/image/0000000497.png 499 | depth_selection/test_depth_prediction_anonymous/image/0000000498.png 500 | depth_selection/test_depth_prediction_anonymous/image/0000000499.png 501 | -------------------------------------------------------------------------------- /files/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/NeWCRFs/d327bf7ca8fb43959734bb02ddc7b56cf283c8d9/files/intro.png -------------------------------------------------------------------------------- /files/office_00633.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/NeWCRFs/d327bf7ca8fb43959734bb02ddc7b56cf283c8d9/files/office_00633.jpg -------------------------------------------------------------------------------- /files/office_00633_depth.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/NeWCRFs/d327bf7ca8fb43959734bb02ddc7b56cf283c8d9/files/office_00633_depth.jpg -------------------------------------------------------------------------------- 
/files/office_00633_pcd.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/NeWCRFs/d327bf7ca8fb43959734bb02ddc7b56cf283c8d9/files/office_00633_pcd.jpg -------------------------------------------------------------------------------- /files/output_nyu1_compressed.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/NeWCRFs/d327bf7ca8fb43959734bb02ddc7b56cf283c8d9/files/output_nyu1_compressed.gif -------------------------------------------------------------------------------- /files/output_nyu2_compressed.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/NeWCRFs/d327bf7ca8fb43959734bb02ddc7b56cf283c8d9/files/output_nyu2_compressed.gif -------------------------------------------------------------------------------- /newcrfs/dataloaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/NeWCRFs/d327bf7ca8fb43959734bb02ddc7b56cf283c8d9/newcrfs/dataloaders/__init__.py -------------------------------------------------------------------------------- /newcrfs/dataloaders/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | import torch.utils.data.distributed 4 | from torchvision import transforms 5 | 6 | import numpy as np 7 | from PIL import Image 8 | import os 9 | import random 10 | 11 | from utils import DistributedSamplerNoEvenlyDivisible 12 | 13 | 14 | def _is_pil_image(img): 15 | return isinstance(img, Image.Image) 16 | 17 | 18 | def _is_numpy_image(img): 19 | return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) 20 | 21 | 22 | def preprocessing_transforms(mode): 23 | return transforms.Compose([ 24 | ToTensor(mode=mode) 25 | ]) 26 | 27 | 28 | class NewDataLoader(object): 29 | def __init__(self, args, mode): 30 | if mode == 'train': 31 | self.training_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode)) 32 | if args.distributed: 33 | self.train_sampler = torch.utils.data.distributed.DistributedSampler(self.training_samples) 34 | else: 35 | self.train_sampler = None 36 | 37 | self.data = DataLoader(self.training_samples, args.batch_size, 38 | shuffle=(self.train_sampler is None), 39 | num_workers=args.num_threads, 40 | pin_memory=True, 41 | sampler=self.train_sampler) 42 | 43 | elif mode == 'online_eval': 44 | self.testing_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode)) 45 | if args.distributed: 46 | # self.eval_sampler = torch.utils.data.distributed.DistributedSampler(self.testing_samples, shuffle=False) 47 | self.eval_sampler = DistributedSamplerNoEvenlyDivisible(self.testing_samples, shuffle=False) 48 | else: 49 | self.eval_sampler = None 50 | self.data = DataLoader(self.testing_samples, 1, 51 | shuffle=False, 52 | num_workers=1, 53 | pin_memory=True, 54 | sampler=self.eval_sampler) 55 | 56 | elif mode == 'test': 57 | self.testing_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode)) 58 | self.data = DataLoader(self.testing_samples, 1, shuffle=False, num_workers=1) 59 | 60 | else: 61 | print('mode should be one of \'train, test, online_eval\'. 
Got {}'.format(mode)) 62 | 63 | 64 | class DataLoadPreprocess(Dataset): 65 | def __init__(self, args, mode, transform=None, is_for_online_eval=False): 66 | self.args = args 67 | if mode == 'online_eval': 68 | with open(args.filenames_file_eval, 'r') as f: 69 | self.filenames = f.readlines() 70 | else: 71 | with open(args.filenames_file, 'r') as f: 72 | self.filenames = f.readlines() 73 | 74 | self.mode = mode 75 | self.transform = transform 76 | self.to_tensor = ToTensor 77 | self.is_for_online_eval = is_for_online_eval 78 | 79 | def __getitem__(self, idx): 80 | sample_path = self.filenames[idx] 81 | # focal = float(sample_path.split()[2]) 82 | focal = 518.8579 83 | 84 | if self.mode == 'train': 85 | if self.args.dataset == 'kitti': 86 | rgb_file = sample_path.split()[0] 87 | depth_file = os.path.join(sample_path.split()[0].split('/')[0], sample_path.split()[1]) 88 | if self.args.use_right is True and random.random() > 0.5: 89 | rgb_file = rgb_file.replace('image_02', 'image_03')  # assign the result: str.replace returns a new string 90 | depth_file = depth_file.replace('image_02', 'image_03') 91 | else: 92 | rgb_file = sample_path.split()[0] 93 | depth_file = sample_path.split()[1] 94 | 95 | image_path = os.path.join(self.args.data_path, rgb_file) 96 | depth_path = os.path.join(self.args.gt_path, depth_file) 97 | 98 | image = Image.open(image_path) 99 | depth_gt = Image.open(depth_path) 100 | 101 | if self.args.do_kb_crop is True: 102 | height = image.height 103 | width = image.width 104 | top_margin = int(height - 352) 105 | left_margin = int((width - 1216) / 2) 106 | depth_gt = depth_gt.crop((left_margin, top_margin, left_margin + 1216, top_margin + 352)) 107 | image = image.crop((left_margin, top_margin, left_margin + 1216, top_margin + 352)) 108 | 109 | # To avoid blank boundaries due to pixel registration 110 | if self.args.dataset == 'nyu': 111 | if self.args.input_height == 480: 112 | depth_gt = np.array(depth_gt) 113 | valid_mask = np.zeros_like(depth_gt) 114 | valid_mask[45:472, 43:608] = 1 115 | depth_gt[valid_mask==0] = 0 116 | depth_gt = Image.fromarray(depth_gt) 117 | else: 118 | depth_gt = depth_gt.crop((43, 45, 608, 472)) 119 | image = image.crop((43, 45, 608, 472)) 120 | 121 | if self.args.do_random_rotate is True: 122 | random_angle = (random.random() - 0.5) * 2 * self.args.degree 123 | image = self.rotate_image(image, random_angle) 124 | depth_gt = self.rotate_image(depth_gt, random_angle, flag=Image.NEAREST) 125 | 126 | image = np.asarray(image, dtype=np.float32) / 255.0 127 | depth_gt = np.asarray(depth_gt, dtype=np.float32) 128 | depth_gt = np.expand_dims(depth_gt, axis=2) 129 | 130 | if self.args.dataset == 'nyu': 131 | depth_gt = depth_gt / 1000.0 132 | else: 133 | depth_gt = depth_gt / 256.0 134 | 135 | if image.shape[0] != self.args.input_height or image.shape[1] != self.args.input_width: 136 | image, depth_gt = self.random_crop(image, depth_gt, self.args.input_height, self.args.input_width) 137 | image, depth_gt = self.train_preprocess(image, depth_gt) 138 | sample = {'image': image, 'depth': depth_gt, 'focal': focal} 139 | 140 | else: 141 | if self.mode == 'online_eval': 142 | data_path = self.args.data_path_eval 143 | else: 144 | data_path = self.args.data_path 145 | 146 | image_path = os.path.join(data_path, "./" + sample_path.split()[0]) 147 | image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0 148 | 149 | if self.mode == 'online_eval': 150 | gt_path = self.args.gt_path_eval 151 | depth_path = os.path.join(gt_path, "./" + sample_path.split()[1]) 152 | if self.args.dataset == 'kitti': 153 | depth_path = 
os.path.join(gt_path, sample_path.split()[0].split('/')[0], sample_path.split()[1]) 154 | has_valid_depth = False 155 | try: 156 | depth_gt = Image.open(depth_path) 157 | has_valid_depth = True 158 | except IOError: 159 | depth_gt = False 160 | # print('Missing gt for {}'.format(image_path)) 161 | 162 | if has_valid_depth: 163 | depth_gt = np.asarray(depth_gt, dtype=np.float32) 164 | depth_gt = np.expand_dims(depth_gt, axis=2) 165 | if self.args.dataset == 'nyu': 166 | depth_gt = depth_gt / 1000.0 167 | else: 168 | depth_gt = depth_gt / 256.0 169 | 170 | if self.args.do_kb_crop is True: 171 | height = image.shape[0] 172 | width = image.shape[1] 173 | top_margin = int(height - 352) 174 | left_margin = int((width - 1216) / 2) 175 | image = image[top_margin:top_margin + 352, left_margin:left_margin + 1216, :] 176 | if self.mode == 'online_eval' and has_valid_depth: 177 | depth_gt = depth_gt[top_margin:top_margin + 352, left_margin:left_margin + 1216, :] 178 | 179 | if self.mode == 'online_eval': 180 | sample = {'image': image, 'depth': depth_gt, 'focal': focal, 'has_valid_depth': has_valid_depth} 181 | else: 182 | sample = {'image': image, 'focal': focal} 183 | 184 | if self.transform: 185 | sample = self.transform(sample) 186 | 187 | return sample 188 | 189 | def rotate_image(self, image, angle, flag=Image.BILINEAR): 190 | result = image.rotate(angle, resample=flag) 191 | return result 192 | 193 | def random_crop(self, img, depth, height, width): 194 | assert img.shape[0] >= height 195 | assert img.shape[1] >= width 196 | assert img.shape[0] == depth.shape[0] 197 | assert img.shape[1] == depth.shape[1] 198 | x = random.randint(0, img.shape[1] - width) 199 | y = random.randint(0, img.shape[0] - height) 200 | img = img[y:y + height, x:x + width, :] 201 | depth = depth[y:y + height, x:x + width, :] 202 | return img, depth 203 | 204 | def train_preprocess(self, image, depth_gt): 205 | # Random flipping 206 | do_flip = random.random() 207 | if do_flip > 0.5: 208 | image = (image[:, ::-1, :]).copy() 209 | depth_gt = (depth_gt[:, ::-1, :]).copy() 210 | 211 | # Random gamma, brightness, color augmentation 212 | do_augment = random.random() 213 | if do_augment > 0.5: 214 | image = self.augment_image(image) 215 | 216 | return image, depth_gt 217 | 218 | def augment_image(self, image): 219 | # gamma augmentation 220 | gamma = random.uniform(0.9, 1.1) 221 | image_aug = image ** gamma 222 | 223 | # brightness augmentation 224 | if self.args.dataset == 'nyu': 225 | brightness = random.uniform(0.75, 1.25) 226 | else: 227 | brightness = random.uniform(0.9, 1.1) 228 | image_aug = image_aug * brightness 229 | 230 | # color augmentation 231 | colors = np.random.uniform(0.9, 1.1, size=3) 232 | white = np.ones((image.shape[0], image.shape[1])) 233 | color_image = np.stack([white * colors[i] for i in range(3)], axis=2) 234 | image_aug *= color_image 235 | image_aug = np.clip(image_aug, 0, 1) 236 | 237 | return image_aug 238 | 239 | def __len__(self): 240 | return len(self.filenames) 241 | 242 | 243 | class ToTensor(object): 244 | def __init__(self, mode): 245 | self.mode = mode 246 | self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 247 | 248 | def __call__(self, sample): 249 | image, focal = sample['image'], sample['focal'] 250 | image = self.to_tensor(image) 251 | image = self.normalize(image) 252 | 253 | if self.mode == 'test': 254 | return {'image': image, 'focal': focal} 255 | 256 | depth = sample['depth'] 257 | if self.mode == 'train': 258 | depth = 
self.to_tensor(depth) 259 | return {'image': image, 'depth': depth, 'focal': focal} 260 | else: 261 | has_valid_depth = sample['has_valid_depth'] 262 | return {'image': image, 'depth': depth, 'focal': focal, 'has_valid_depth': has_valid_depth} 263 | 264 | def to_tensor(self, pic): 265 | if not (_is_pil_image(pic) or _is_numpy_image(pic)): 266 | raise TypeError( 267 | 'pic should be PIL Image or ndarray. Got {}'.format(type(pic))) 268 | 269 | if isinstance(pic, np.ndarray): 270 | img = torch.from_numpy(pic.transpose((2, 0, 1))) 271 | return img 272 | 273 | # handle PIL Image 274 | if pic.mode == 'I': 275 | img = torch.from_numpy(np.array(pic, np.int32, copy=False)) 276 | elif pic.mode == 'I;16': 277 | img = torch.from_numpy(np.array(pic, np.int16, copy=False)) 278 | else: 279 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 280 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK 281 | if pic.mode == 'YCbCr': 282 | nchannel = 3 283 | elif pic.mode == 'I;16': 284 | nchannel = 1 285 | else: 286 | nchannel = len(pic.mode) 287 | img = img.view(pic.size[1], pic.size[0], nchannel) 288 | 289 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 290 | if isinstance(img, torch.ByteTensor): 291 | return img.float() 292 | else: 293 | return img 294 | -------------------------------------------------------------------------------- /newcrfs/dataloaders/dataloader_kittipred.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | import torch.utils.data.distributed 4 | from torchvision import transforms 5 | 6 | import numpy as np 7 | from PIL import Image 8 | import os 9 | import random 10 | 11 | from utils import DistributedSamplerNoEvenlyDivisible 12 | 13 | 14 | def _is_pil_image(img): 15 | return isinstance(img, Image.Image) 16 | 17 | 18 | def _is_numpy_image(img): 19 | return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) 20 | 21 | 22 | def preprocessing_transforms(mode): 23 | return transforms.Compose([ 24 | ToTensor(mode=mode) 25 | ]) 26 | 27 | 28 | class NewDataLoader(object): 29 | def __init__(self, args, mode): 30 | if mode == 'train': 31 | self.training_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode)) 32 | if args.distributed: 33 | self.train_sampler = torch.utils.data.distributed.DistributedSampler(self.training_samples) 34 | else: 35 | self.train_sampler = None 36 | 37 | self.data = DataLoader(self.training_samples, args.batch_size, 38 | shuffle=(self.train_sampler is None), 39 | num_workers=args.num_threads, 40 | pin_memory=True, 41 | sampler=self.train_sampler) 42 | 43 | elif mode == 'online_eval': 44 | self.testing_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode)) 45 | if args.distributed: 46 | # self.eval_sampler = torch.utils.data.distributed.DistributedSampler(self.testing_samples, shuffle=False) 47 | self.eval_sampler = DistributedSamplerNoEvenlyDivisible(self.testing_samples, shuffle=False) 48 | else: 49 | self.eval_sampler = None 50 | self.data = DataLoader(self.testing_samples, 1, 51 | shuffle=False, 52 | num_workers=1, 53 | pin_memory=True, 54 | sampler=self.eval_sampler) 55 | 56 | elif mode == 'test': 57 | self.testing_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode)) 58 | self.data = DataLoader(self.testing_samples, 1, shuffle=False, num_workers=1) 59 | 60 | else: 61 | print('mode should be one of \'train, test, online_eval\'. 
Got {}'.format(mode)) 62 | 63 | 64 | class DataLoadPreprocess(Dataset): 65 | def __init__(self, args, mode, transform=None, is_for_online_eval=False): 66 | self.args = args 67 | if mode == 'online_eval': 68 | with open(args.filenames_file_eval, 'r') as f: 69 | self.filenames = f.readlines() 70 | else: 71 | with open(args.filenames_file, 'r') as f: 72 | self.filenames = f.readlines() 73 | 74 | self.mode = mode 75 | self.transform = transform 76 | self.to_tensor = ToTensor 77 | self.is_for_online_eval = is_for_online_eval 78 | 79 | def __getitem__(self, idx): 80 | sample_path = self.filenames[idx] 81 | # focal = float(sample_path.split()[2]) 82 | focal = 518.8579 83 | 84 | if self.mode == 'train': 85 | rgb_file = sample_path.split()[0] 86 | depth_file = rgb_file.replace('/image_02/data/', '/proj_depth/groundtruth/image_02/') 87 | if self.args.use_right is True and random.random() > 0.5: 88 | rgb_file.replace('image_02', 'image_03') 89 | depth_file.replace('image_02', 'image_03') 90 | 91 | image_path = os.path.join(self.args.data_path, rgb_file) 92 | depth_path = os.path.join(self.args.gt_path, depth_file) 93 | 94 | image = Image.open(image_path) 95 | depth_gt = Image.open(depth_path) 96 | 97 | if self.args.do_kb_crop is True: 98 | height = image.height 99 | width = image.width 100 | top_margin = int(height - 352) 101 | left_margin = int((width - 1216) / 2) 102 | depth_gt = depth_gt.crop((left_margin, top_margin, left_margin + 1216, top_margin + 352)) 103 | image = image.crop((left_margin, top_margin, left_margin + 1216, top_margin + 352)) 104 | 105 | if self.args.do_random_rotate is True: 106 | random_angle = (random.random() - 0.5) * 2 * self.args.degree 107 | image = self.rotate_image(image, random_angle) 108 | depth_gt = self.rotate_image(depth_gt, random_angle, flag=Image.NEAREST) 109 | 110 | image = np.asarray(image, dtype=np.float32) / 255.0 111 | depth_gt = np.asarray(depth_gt, dtype=np.float32) 112 | depth_gt = np.expand_dims(depth_gt, axis=2) 113 | 114 | depth_gt = depth_gt / 256.0 115 | 116 | if image.shape[0] != self.args.input_height or image.shape[1] != self.args.input_width: 117 | image, depth_gt = self.random_crop(image, depth_gt, self.args.input_height, self.args.input_width) 118 | image, depth_gt = self.train_preprocess(image, depth_gt) 119 | sample = {'image': image, 'depth': depth_gt, 'focal': focal} 120 | 121 | else: 122 | if self.mode == 'online_eval': 123 | data_path = self.args.data_path_eval 124 | else: 125 | data_path = self.args.data_path 126 | 127 | image_path = os.path.join(data_path, "./" + sample_path.split()[0]) 128 | image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0 129 | 130 | if self.mode == 'online_eval': 131 | gt_path = self.args.gt_path_eval 132 | depth_path = image_path.replace('/image/', '/groundtruth_depth/').replace('sync_image', 'sync_groundtruth_depth') 133 | has_valid_depth = False 134 | try: 135 | depth_gt = Image.open(depth_path) 136 | has_valid_depth = True 137 | except IOError: 138 | depth_gt = False 139 | # print('Missing gt for {}'.format(image_path)) 140 | 141 | if has_valid_depth: 142 | depth_gt = np.asarray(depth_gt, dtype=np.float32) 143 | depth_gt = np.expand_dims(depth_gt, axis=2) 144 | if self.args.dataset == 'nyu': 145 | depth_gt = depth_gt / 1000.0 146 | else: 147 | depth_gt = depth_gt / 256.0 148 | 149 | if self.args.do_kb_crop is True: 150 | height = image.shape[0] 151 | width = image.shape[1] 152 | top_margin = int(height - 352) 153 | left_margin = int((width - 1216) / 2) 154 | image = 
image[top_margin:top_margin + 352, left_margin:left_margin + 1216, :] 155 | if self.mode == 'online_eval' and has_valid_depth: 156 | depth_gt = depth_gt[top_margin:top_margin + 352, left_margin:left_margin + 1216, :] 157 | 158 | if self.mode == 'online_eval': 159 | sample = {'image': image, 'depth': depth_gt, 'focal': focal, 'has_valid_depth': has_valid_depth} 160 | else: 161 | sample = {'image': image, 'focal': focal} 162 | 163 | if self.transform: 164 | sample = self.transform(sample) 165 | 166 | return sample 167 | 168 | def rotate_image(self, image, angle, flag=Image.BILINEAR): 169 | result = image.rotate(angle, resample=flag) 170 | return result 171 | 172 | def random_crop(self, img, depth, height, width): 173 | assert img.shape[0] >= height 174 | assert img.shape[1] >= width 175 | assert img.shape[0] == depth.shape[0] 176 | assert img.shape[1] == depth.shape[1] 177 | x = random.randint(0, img.shape[1] - width) 178 | y = random.randint(0, img.shape[0] - height) 179 | img = img[y:y + height, x:x + width, :] 180 | depth = depth[y:y + height, x:x + width, :] 181 | return img, depth 182 | 183 | def train_preprocess(self, image, depth_gt): 184 | # Random flipping 185 | do_flip = random.random() 186 | if do_flip > 0.5: 187 | image = (image[:, ::-1, :]).copy() 188 | depth_gt = (depth_gt[:, ::-1, :]).copy() 189 | 190 | # Random gamma, brightness, color augmentation 191 | do_augment = random.random() 192 | if do_augment > 0.5: 193 | image = self.augment_image(image) 194 | 195 | return image, depth_gt 196 | 197 | def augment_image(self, image): 198 | # gamma augmentation 199 | gamma = random.uniform(0.9, 1.1) 200 | image_aug = image ** gamma 201 | 202 | # brightness augmentation 203 | if self.args.dataset == 'nyu': 204 | brightness = random.uniform(0.75, 1.25) 205 | else: 206 | brightness = random.uniform(0.9, 1.1) 207 | image_aug = image_aug * brightness 208 | 209 | # color augmentation 210 | colors = np.random.uniform(0.9, 1.1, size=3) 211 | white = np.ones((image.shape[0], image.shape[1])) 212 | color_image = np.stack([white * colors[i] for i in range(3)], axis=2) 213 | image_aug *= color_image 214 | image_aug = np.clip(image_aug, 0, 1) 215 | 216 | return image_aug 217 | 218 | def __len__(self): 219 | return len(self.filenames) 220 | 221 | 222 | class ToTensor(object): 223 | def __init__(self, mode): 224 | self.mode = mode 225 | self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 226 | 227 | def __call__(self, sample): 228 | image, focal = sample['image'], sample['focal'] 229 | image = self.to_tensor(image) 230 | image = self.normalize(image) 231 | 232 | if self.mode == 'test': 233 | return {'image': image, 'focal': focal} 234 | 235 | depth = sample['depth'] 236 | if self.mode == 'train': 237 | depth = self.to_tensor(depth) 238 | return {'image': image, 'depth': depth, 'focal': focal} 239 | else: 240 | has_valid_depth = sample['has_valid_depth'] 241 | return {'image': image, 'depth': depth, 'focal': focal, 'has_valid_depth': has_valid_depth} 242 | 243 | def to_tensor(self, pic): 244 | if not (_is_pil_image(pic) or _is_numpy_image(pic)): 245 | raise TypeError( 246 | 'pic should be PIL Image or ndarray. 
Got {}'.format(type(pic))) 247 | 248 | if isinstance(pic, np.ndarray): 249 | img = torch.from_numpy(pic.transpose((2, 0, 1))) 250 | return img 251 | 252 | # handle PIL Image 253 | if pic.mode == 'I': 254 | img = torch.from_numpy(np.array(pic, np.int32, copy=False)) 255 | elif pic.mode == 'I;16': 256 | img = torch.from_numpy(np.array(pic, np.int16, copy=False)) 257 | else: 258 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 259 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK 260 | if pic.mode == 'YCbCr': 261 | nchannel = 3 262 | elif pic.mode == 'I;16': 263 | nchannel = 1 264 | else: 265 | nchannel = len(pic.mode) 266 | img = img.view(pic.size[1], pic.size[0], nchannel) 267 | 268 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 269 | if isinstance(img, torch.ByteTensor): 270 | return img.float() 271 | else: 272 | return img 273 | -------------------------------------------------------------------------------- /newcrfs/demo.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | import torch 3 | import torch.nn as nn 4 | import torch.backends.cudnn as cudnn 5 | from torch.autograd import Variable 6 | 7 | import os 8 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' 9 | import sys 10 | import time 11 | import argparse 12 | import numpy as np 13 | 14 | import cv2 15 | from scipy import ndimage 16 | from skimage.transform import resize 17 | import matplotlib.pyplot as plt 18 | 19 | plasma = plt.get_cmap('plasma') 20 | greys = plt.get_cmap('Greys') 21 | 22 | # UI and OpenGL 23 | from PySide2 import QtCore, QtGui, QtWidgets, QtOpenGL 24 | from OpenGL import GL, GLU 25 | from OpenGL.arrays import vbo 26 | from OpenGL.GL import shaders 27 | import glm 28 | 29 | from utils import post_process_depth, flip_lr 30 | from networks.NewCRFDepth import NewCRFDepth 31 | 32 | 33 | # Argument Parser 34 | parser = argparse.ArgumentParser(description='NeWCRFs Live 3D') 35 | parser.add_argument('--model_name', type=str, help='model name', default='newcrfs') 36 | parser.add_argument('--encoder', type=str, help='type of encoder, base07, large07', default='large07') 37 | parser.add_argument('--max_depth', type=float, help='maximum depth in estimation', default=10) 38 | parser.add_argument('--checkpoint_path', type=str, help='path to a checkpoint to load', required=True) 39 | parser.add_argument('--input_height', type=int, help='input height', default=480) 40 | parser.add_argument('--input_width', type=int, help='input width', default=640) 41 | parser.add_argument('--dataset', type=str, help='dataset this model trained on', default='nyu') 42 | parser.add_argument('--crop', type=str, help='crop: kbcrop, edge, non', default='non') 43 | parser.add_argument('--video', type=str, help='video path', default='') 44 | 45 | args = parser.parse_args() 46 | 47 | # Image shapes 48 | height_rgb, width_rgb = args.input_height, args.input_width 49 | height_depth, width_depth = height_rgb, width_rgb 50 | 51 | 52 | # =============== Intrinsics rectify ================== 53 | # Open this if you have the real intrinsics 54 | Use_intrs_remap = False 55 | # Intrinsic parameters for your own webcam camera 56 | camera_matrix = np.zeros(shape=(3, 3)) 57 | camera_matrix[0, 0] = 5.4765313594010649e+02 58 | camera_matrix[0, 2] = 3.2516069906172453e+02 59 | camera_matrix[1, 1] = 5.4801781476172562e+02 60 | camera_matrix[1, 2] = 2.4794113960783835e+02 61 | camera_matrix[2, 2] = 1 62 | dist_coeffs = np.array([ 
3.7230261423972011e-02, -1.6171708069773008e-01, -3.5260752900266357e-04, 1.7161234226767313e-04, 1.0192711400840315e-01 ]) 63 | # Parameters for a model trained on NYU Depth V2 64 | new_camera_matrix = np.zeros(shape=(3, 3)) 65 | new_camera_matrix[0, 0] = 518.8579 66 | new_camera_matrix[0, 2] = 320 67 | new_camera_matrix[1, 1] = 518.8579 68 | new_camera_matrix[1, 2] = 240 69 | new_camera_matrix[2, 2] = 1 70 | 71 | R = np.identity(3, dtype=np.float) 72 | map1, map2 = cv2.initUndistortRectifyMap(camera_matrix, dist_coeffs, R, new_camera_matrix, (width_rgb, height_rgb), cv2.CV_32FC1) 73 | 74 | 75 | def load_model(): 76 | args.mode = 'test' 77 | model = NewCRFDepth(version='large07', inv_depth=False, max_depth=args.max_depth) 78 | model = torch.nn.DataParallel(model) 79 | 80 | checkpoint = torch.load(args.checkpoint_path) 81 | model.load_state_dict(checkpoint['model']) 82 | model.eval() 83 | model.cuda() 84 | 85 | return model 86 | 87 | # Function timing 88 | ticTime = time.time() 89 | 90 | 91 | def tic(): 92 | global ticTime; 93 | ticTime = time.time() 94 | 95 | 96 | def toc(): 97 | print('{0} seconds.'.format(time.time() - ticTime)) 98 | 99 | 100 | # Conversion from Numpy to QImage and back 101 | def np_to_qimage(a): 102 | im = a.copy() 103 | return QtGui.QImage(im.data, im.shape[1], im.shape[0], im.strides[0], QtGui.QImage.Format_RGB888).copy() 104 | 105 | 106 | def qimage_to_np(img): 107 | img = img.convertToFormat(QtGui.QImage.Format.Format_ARGB32) 108 | return np.array(img.constBits()).reshape(img.height(), img.width(), 4) 109 | 110 | 111 | # Compute edge magnitudes 112 | def edges(d): 113 | dx = ndimage.sobel(d, 0) # horizontal derivative 114 | dy = ndimage.sobel(d, 1) # vertical derivative 115 | return np.abs(dx) + np.abs(dy) 116 | 117 | 118 | # Main window 119 | class Window(QtWidgets.QWidget): 120 | updateInput = QtCore.Signal() 121 | 122 | def __init__(self, parent=None): 123 | QtWidgets.QWidget.__init__(self, parent) 124 | self.model = None 125 | self.capture = None 126 | self.glWidget = GLWidget() 127 | 128 | mainLayout = QtWidgets.QVBoxLayout() 129 | 130 | # Input / output views 131 | viewsLayout = QtWidgets.QGridLayout() 132 | self.inputViewer = QtWidgets.QLabel("[Click to start]") 133 | self.inputViewer.setPixmap(QtGui.QPixmap(width_rgb, height_rgb)) 134 | self.outputViewer = QtWidgets.QLabel("[Click to start]") 135 | self.outputViewer.setPixmap(QtGui.QPixmap(width_rgb, height_rgb)) 136 | 137 | imgsFrame = QtWidgets.QFrame() 138 | inputsLayout = QtWidgets.QVBoxLayout() 139 | imgsFrame.setLayout(inputsLayout) 140 | inputsLayout.addWidget(self.inputViewer) 141 | inputsLayout.addWidget(self.outputViewer) 142 | 143 | viewsLayout.addWidget(imgsFrame, 0, 0) 144 | viewsLayout.addWidget(self.glWidget, 0, 1) 145 | viewsLayout.setColumnStretch(1, 10) 146 | mainLayout.addLayout(viewsLayout) 147 | 148 | # Load depth estimation model 149 | toolsLayout = QtWidgets.QHBoxLayout() 150 | 151 | self.button2 = QtWidgets.QPushButton("Webcam") 152 | self.button2.clicked.connect(self.loadCamera) 153 | toolsLayout.addWidget(self.button2) 154 | 155 | self.button3 = QtWidgets.QPushButton("Video") 156 | self.button3.clicked.connect(self.loadVideoFile) 157 | toolsLayout.addWidget(self.button3) 158 | 159 | self.button4 = QtWidgets.QPushButton("Pause") 160 | self.button4.clicked.connect(self.loadImage) 161 | toolsLayout.addWidget(self.button4) 162 | 163 | self.button6 = QtWidgets.QPushButton("Refresh") 164 | self.button6.clicked.connect(self.updateCloud) 165 | toolsLayout.addWidget(self.button6) 166 | 167 
| mainLayout.addLayout(toolsLayout) 168 | 169 | self.setLayout(mainLayout) 170 | self.setWindowTitle(self.tr("NeWCRFs Live")) 171 | 172 | # Signals 173 | self.updateInput.connect(self.update_input) 174 | 175 | # Default example 176 | if self.glWidget.rgb.any() and self.glWidget.depth.any(): 177 | img = (self.glWidget.rgb * 255).astype('uint8') 178 | self.inputViewer.setPixmap(QtGui.QPixmap.fromImage(np_to_qimage(img))) 179 | coloredDepth = (plasma(self.glWidget.depth[:, :, 0])[:, :, :3] * 255).astype('uint8') 180 | self.outputViewer.setPixmap(QtGui.QPixmap.fromImage(np_to_qimage(coloredDepth))) 181 | 182 | def loadModel(self): 183 | print('== loadModel') 184 | QtGui.QGuiApplication.setOverrideCursor(QtCore.Qt.WaitCursor) 185 | tic() 186 | self.model = load_model() 187 | print('Model loaded.') 188 | toc() 189 | self.updateCloud() 190 | QtGui.QGuiApplication.restoreOverrideCursor() 191 | 192 | def loadCamera(self): 193 | print('== loadCamera') 194 | tic() 195 | self.model = load_model() 196 | print('Model loaded.') 197 | toc() 198 | self.capture = cv2.VideoCapture(0) 199 | self.updateInput.emit() 200 | 201 | def loadVideoFile(self): 202 | print('== loadVideoFile') 203 | self.model = load_model() 204 | self.capture = cv2.VideoCapture(args.video) 205 | self.updateInput.emit() 206 | 207 | def loadImage(self): 208 | print('== loadImage') 209 | self.capture = None 210 | img = (self.glWidget.rgb * 255).astype('uint8') 211 | self.inputViewer.setPixmap(QtGui.QPixmap.fromImage(np_to_qimage(img))) 212 | self.updateCloud() 213 | 214 | def loadImageFile(self): 215 | print('== loadImageFile') 216 | self.capture = None 217 | filename = \ 218 | QtWidgets.QFileDialog.getOpenFileName(None, 'Select image', '', self.tr('Image files (*.jpg *.png)'))[0] 219 | img = QtGui.QImage(filename).scaledToHeight(height_rgb) 220 | xstart = 0 221 | if img.width() > width_rgb: xstart = (img.width() - width_rgb) // 2 222 | img = img.copy(xstart, 0, xstart + width_rgb, height_rgb) 223 | self.inputViewer.setPixmap(QtGui.QPixmap.fromImage(img)) 224 | print('== loadImageFile') 225 | self.updateCloud() 226 | 227 | def update_input(self): 228 | print('== update_input') 229 | # Don't update anymore if no capture device is set 230 | if self.capture == None: 231 | return 232 | 233 | # Capture a frame 234 | ret, frame = self.capture.read() 235 | 236 | # Loop video playback if current stream is video file 237 | if not ret: 238 | self.capture.set(cv2.CAP_PROP_POS_FRAMES, 0) 239 | ret, frame = self.capture.read() 240 | 241 | # Prepare image and show in UI 242 | if Use_intrs_remap: 243 | frame_ud = cv2.remap(frame, map1, map2, interpolation=cv2.INTER_LINEAR) 244 | else: 245 | frame_ud = cv2.resize(frame, (width_rgb, height_rgb), interpolation=cv2.INTER_LINEAR) 246 | frame = cv2.cvtColor(frame_ud, cv2.COLOR_BGR2RGB) 247 | image = np_to_qimage(frame) 248 | self.inputViewer.setPixmap(QtGui.QPixmap.fromImage(image)) 249 | 250 | # Update the point cloud 251 | self.updateCloud() 252 | 253 | def updateCloud(self): 254 | print('== updateCloud') 255 | rgb8 = qimage_to_np(self.inputViewer.pixmap().toImage()) 256 | self.glWidget.rgb = (rgb8[:, :, :3] / 255)[:, :, ::-1] 257 | 258 | if self.model: 259 | input_image = rgb8[:, :, :3].astype(np.float32) 260 | 261 | # Normalize image 262 | input_image[:, :, 0] = (input_image[:, :, 0] - 123.68) * 0.017 263 | input_image[:, :, 1] = (input_image[:, :, 1] - 116.78) * 0.017 264 | input_image[:, :, 2] = (input_image[:, :, 2] - 103.94) * 0.017 265 | 266 | H, W, _ = input_image.shape 267 | if args.crop == 
'kbcrop': 268 | top_margin = int(H - 352) 269 | left_margin = int((W - 1216) / 2) 270 | input_image_cropped = input_image[top_margin:top_margin + 352, 271 | left_margin:left_margin + 1216] 272 | elif args.crop == 'edge': 273 | input_image_cropped = input_image[32:-32, 32:-32, :] 274 | else: 275 | input_image_cropped = input_image 276 | 277 | input_images = np.expand_dims(input_image_cropped, axis=0) 278 | input_images = np.transpose(input_images, (0, 3, 1, 2)) 279 | 280 | with torch.no_grad(): 281 | image = Variable(torch.from_numpy(input_images)).cuda() 282 | # Predict 283 | depth_est = self.model(image) 284 | post_process = True 285 | if post_process: 286 | image_flipped = flip_lr(image) 287 | depth_est_flipped = self.model(image_flipped) 288 | depth_cropped = post_process_depth(depth_est, depth_est_flipped) 289 | 290 | depth = np.zeros((height_depth, width_depth), dtype=np.float32) 291 | if args.crop == 'kbcrop': 292 | depth[top_margin:top_margin + 352, left_margin:left_margin + 1216] = \ 293 | depth_cropped[0].cpu().squeeze() / args.max_depth 294 | elif args.crop == 'edge': 295 | depth[32:-32, 32:-32] = depth_cropped[0].cpu().squeeze() / args.max_depth 296 | else: 297 | depth[:, :] = depth_cropped[0].cpu().squeeze() / args.max_depth 298 | 299 | coloredDepth = (greys(np.log10(depth * args.max_depth))[:, :, :3] * 255).astype('uint8') 300 | self.outputViewer.setPixmap(QtGui.QPixmap.fromImage(np_to_qimage(coloredDepth))) 301 | self.glWidget.depth = depth 302 | 303 | else: 304 | self.glWidget.depth = 0.5 + np.zeros((height_rgb // 2, width_rgb // 2, 1)) 305 | 306 | self.glWidget.updateRGBD() 307 | self.glWidget.updateGL() 308 | 309 | # Update to next frame if we are live 310 | QtCore.QTimer.singleShot(10, self.updateInput) 311 | 312 | 313 | class GLWidget(QtOpenGL.QGLWidget): 314 | def __init__(self, parent=None): 315 | QtOpenGL.QGLWidget.__init__(self, parent) 316 | 317 | self.object = 0 318 | self.xRot = 5040 319 | self.yRot = 40 320 | self.zRot = 0 321 | self.zoomLevel = 9 322 | 323 | self.lastPos = QtCore.QPoint() 324 | 325 | self.green = QtGui.QColor.fromCmykF(0.0, 0.0, 0.0, 1.0) 326 | self.black = QtGui.QColor.fromCmykF(0.0, 0.0, 0.0, 1.0) 327 | 328 | # Precompute for world coordinates 329 | self.xx, self.yy = self.worldCoords(width=width_rgb, height=height_rgb) 330 | 331 | self.rgb = np.zeros((height_rgb, width_rgb, 3), dtype=np.uint8) 332 | self.depth = np.zeros((height_depth, height_depth), dtype=np.float32) 333 | 334 | self.col_vbo = None 335 | self.pos_vbo = None 336 | if self.rgb.any() and self.depth.any(): 337 | self.updateRGBD() 338 | 339 | def xRotation(self): 340 | return self.xRot 341 | 342 | def yRotation(self): 343 | return self.yRot 344 | 345 | def zRotation(self): 346 | return self.zRot 347 | 348 | def minimumSizeHint(self): 349 | return QtCore.QSize(height_rgb, width_rgb) 350 | 351 | def sizeHint(self): 352 | return QtCore.QSize(height_rgb, width_rgb) 353 | 354 | def setXRotation(self, angle): 355 | if angle != self.xRot: 356 | self.xRot = angle 357 | self.emit(QtCore.SIGNAL("xRotationChanged(int)"), angle) 358 | self.updateGL() 359 | 360 | def setYRotation(self, angle): 361 | if angle != self.yRot: 362 | self.yRot = angle 363 | self.emit(QtCore.SIGNAL("yRotationChanged(int)"), angle) 364 | self.updateGL() 365 | 366 | def setZRotation(self, angle): 367 | if angle != self.zRot: 368 | self.zRot = angle 369 | self.emit(QtCore.SIGNAL("zRotationChanged(int)"), angle) 370 | self.updateGL() 371 | 372 | def resizeGL(self, width, height): 373 | GL.glViewport(0, 0, width, height) 
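    # A minimal standalone sketch (NumPy) of the back-projection performed by
    # worldCoords() and posFromDepth() further down in this class: a pinhole
    # model with the NYUv2 focal length 518.8579 and the image centre as the
    # principal point, applied to a depth map `depth` of shape (H, W) in metres.
    #
    #     fx = fy = 518.8579
    #     cx, cy = W / 2, H / 2
    #     u, v = np.meshgrid(np.arange(W), np.arange(H))
    #     X = (u - cx) / fx * depth
    #     Y = (v - cy) / fy * depth
    #     points = np.stack([X, Y, depth], axis=-1).reshape(-1, 3)   # (H*W, 3)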
374 | 375 | def mousePressEvent(self, event): 376 | self.lastPos = QtCore.QPoint(event.pos()) 377 | 378 | def mouseMoveEvent(self, event): 379 | dx = -(event.x() - self.lastPos.x()) 380 | dy = (event.y() - self.lastPos.y()) 381 | 382 | if event.buttons() & QtCore.Qt.LeftButton: 383 | self.setXRotation(self.xRot + dy) 384 | self.setYRotation(self.yRot + dx) 385 | elif event.buttons() & QtCore.Qt.RightButton: 386 | self.setXRotation(self.xRot + dy) 387 | self.setZRotation(self.zRot + dx) 388 | 389 | self.lastPos = QtCore.QPoint(event.pos()) 390 | 391 | def wheelEvent(self, event): 392 | numDegrees = event.delta() / 8 393 | numSteps = numDegrees / 15 394 | self.zoomLevel = self.zoomLevel + numSteps 395 | event.accept() 396 | self.updateGL() 397 | 398 | def initializeGL(self): 399 | self.qglClearColor(self.black.darker()) 400 | GL.glShadeModel(GL.GL_FLAT) 401 | GL.glEnable(GL.GL_DEPTH_TEST) 402 | GL.glEnable(GL.GL_CULL_FACE) 403 | 404 | VERTEX_SHADER = shaders.compileShader("""#version 330 405 | layout(location = 0) in vec3 position; 406 | layout(location = 1) in vec3 color; 407 | uniform mat4 mvp; out vec4 frag_color; 408 | void main() {gl_Position = mvp * vec4(position, 1.0);frag_color = vec4(color, 1.0);}""", GL.GL_VERTEX_SHADER) 409 | 410 | FRAGMENT_SHADER = shaders.compileShader("""#version 330 411 | in vec4 frag_color; out vec4 out_color; 412 | void main() {out_color = frag_color;}""", GL.GL_FRAGMENT_SHADER) 413 | 414 | self.shaderProgram = shaders.compileProgram(VERTEX_SHADER, FRAGMENT_SHADER) 415 | 416 | self.UNIFORM_LOCATIONS = { 417 | 'position': GL.glGetAttribLocation(self.shaderProgram, 'position'), 418 | 'color': GL.glGetAttribLocation(self.shaderProgram, 'color'), 419 | 'mvp': GL.glGetUniformLocation(self.shaderProgram, 'mvp'), 420 | } 421 | 422 | shaders.glUseProgram(self.shaderProgram) 423 | 424 | def paintGL(self): 425 | if self.rgb.any() and self.depth.any(): 426 | GL.glClear(GL.GL_COLOR_BUFFER_BIT | GL.GL_DEPTH_BUFFER_BIT) 427 | self.drawObject() 428 | 429 | def worldCoords(self, width, height): 430 | cx, cy = width / 2, height / 2 431 | fx = 518.8579 432 | fy = 518.8579 433 | xx, yy = np.tile(range(width), height), np.repeat(range(height), width) 434 | xx = (xx - cx) / fx 435 | yy = (yy - cy) / fy 436 | return xx, yy 437 | 438 | def posFromDepth(self, depth): 439 | length = depth.shape[0] * depth.shape[1] 440 | 441 | depth[edges(depth) > 0.3] = 1e6 # Hide depth edges 442 | z = depth.reshape(length) 443 | 444 | return np.dstack((self.xx * z, self.yy * z, z)).reshape((length, 3)) 445 | 446 | def createPointCloudVBOfromRGBD(self): 447 | # Create position and color VBOs 448 | self.pos_vbo = vbo.VBO(data=self.pos, usage=GL.GL_DYNAMIC_DRAW, target=GL.GL_ARRAY_BUFFER) 449 | self.col_vbo = vbo.VBO(data=self.col, usage=GL.GL_DYNAMIC_DRAW, target=GL.GL_ARRAY_BUFFER) 450 | 451 | def updateRGBD(self): 452 | # RGBD dimensions 453 | width, height = self.depth.shape[1], self.depth.shape[0] 454 | 455 | # Reshape 456 | points = self.posFromDepth(self.depth.copy()) 457 | colors = resize(self.rgb, (height, width)).reshape((height * width, 3)) 458 | 459 | # Flatten and convert to float32 460 | self.pos = points.astype('float32') 461 | self.col = colors.reshape(height * width, 3).astype('float32') 462 | 463 | # Move center of scene 464 | self.pos = self.pos + glm.vec3(0, -0.06, -0.3) 465 | 466 | # Create VBOs 467 | if not self.col_vbo: 468 | self.createPointCloudVBOfromRGBD() 469 | 470 | def drawObject(self): 471 | # Update camera 472 | model, view, proj = glm.mat4(1), glm.mat4(1), 
glm.perspective(45, self.width() / self.height(), 0.01, 100) 473 | center, up, eye = glm.vec3(0, -0.075, 0), glm.vec3(0, -1, 0), glm.vec3(0, 0, -0.4 * (self.zoomLevel / 10)) 474 | view = glm.lookAt(eye, center, up) 475 | model = glm.rotate(model, self.xRot / 160.0, glm.vec3(1, 0, 0)) 476 | model = glm.rotate(model, self.yRot / 160.0, glm.vec3(0, 1, 0)) 477 | model = glm.rotate(model, self.zRot / 160.0, glm.vec3(0, 0, 1)) 478 | mvp = proj * view * model 479 | GL.glUniformMatrix4fv(self.UNIFORM_LOCATIONS['mvp'], 1, False, glm.value_ptr(mvp)) 480 | 481 | # Update data 482 | self.pos_vbo.set_array(self.pos) 483 | self.col_vbo.set_array(self.col) 484 | 485 | # Point size 486 | GL.glPointSize(2) 487 | 488 | # Position 489 | self.pos_vbo.bind() 490 | GL.glEnableVertexAttribArray(0) 491 | GL.glVertexAttribPointer(0, 3, GL.GL_FLOAT, GL.GL_FALSE, 0, None) 492 | 493 | # Color 494 | self.col_vbo.bind() 495 | GL.glEnableVertexAttribArray(1) 496 | GL.glVertexAttribPointer(1, 3, GL.GL_FLOAT, GL.GL_FALSE, 0, None) 497 | 498 | # Draw 499 | GL.glDrawArrays(GL.GL_POINTS, 0, self.pos.shape[0]) 500 | 501 | if __name__ == '__main__': 502 | app = QtWidgets.QApplication(sys.argv) 503 | window = Window() 504 | window.show() 505 | res = app.exec_() -------------------------------------------------------------------------------- /newcrfs/eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.backends.cudnn as cudnn 3 | 4 | import os, sys 5 | import argparse 6 | import numpy as np 7 | from tqdm import tqdm 8 | 9 | from utils import post_process_depth, flip_lr, compute_errors 10 | from networks.NewCRFDepth import NewCRFDepth 11 | 12 | 13 | def convert_arg_line_to_args(arg_line): 14 | for arg in arg_line.split(): 15 | if not arg.strip(): 16 | continue 17 | yield arg 18 | 19 | 20 | parser = argparse.ArgumentParser(description='NeWCRFs PyTorch implementation.', fromfile_prefix_chars='@') 21 | parser.convert_arg_line_to_args = convert_arg_line_to_args 22 | 23 | parser.add_argument('--model_name', type=str, help='model name', default='newcrfs') 24 | parser.add_argument('--encoder', type=str, help='type of encoder, base07, large07', default='large07') 25 | parser.add_argument('--checkpoint_path', type=str, help='path to a checkpoint to load', default='') 26 | 27 | # Dataset 28 | parser.add_argument('--dataset', type=str, help='dataset to train on, kitti or nyu', default='nyu') 29 | parser.add_argument('--input_height', type=int, help='input height', default=480) 30 | parser.add_argument('--input_width', type=int, help='input width', default=640) 31 | parser.add_argument('--max_depth', type=float, help='maximum depth in estimation', default=10) 32 | 33 | # Preprocessing 34 | parser.add_argument('--do_random_rotate', help='if set, will perform random rotation for augmentation', action='store_true') 35 | parser.add_argument('--degree', type=float, help='random rotation maximum degree', default=2.5) 36 | parser.add_argument('--do_kb_crop', help='if set, crop input images as kitti benchmark images', action='store_true') 37 | parser.add_argument('--use_right', help='if set, will randomly use right images when train on KITTI', action='store_true') 38 | 39 | # Eval 40 | parser.add_argument('--data_path_eval', type=str, help='path to the data for evaluation', required=False) 41 | parser.add_argument('--gt_path_eval', type=str, help='path to the groundtruth data for evaluation', required=False) 42 | parser.add_argument('--filenames_file_eval', type=str, help='path 
to the filenames text file for evaluation', required=False) 43 | parser.add_argument('--min_depth_eval', type=float, help='minimum depth for evaluation', default=1e-3) 44 | parser.add_argument('--max_depth_eval', type=float, help='maximum depth for evaluation', default=80) 45 | parser.add_argument('--eigen_crop', help='if set, crops according to Eigen NIPS14', action='store_true') 46 | parser.add_argument('--garg_crop', help='if set, crops according to Garg ECCV16', action='store_true') 47 | 48 | 49 | if sys.argv.__len__() == 2: 50 | arg_filename_with_prefix = '@' + sys.argv[1] 51 | args = parser.parse_args([arg_filename_with_prefix]) 52 | else: 53 | args = parser.parse_args() 54 | 55 | if args.dataset == 'kitti' or args.dataset == 'nyu': 56 | from dataloaders.dataloader import NewDataLoader 57 | elif args.dataset == 'kittipred': 58 | from dataloaders.dataloader_kittipred import NewDataLoader 59 | 60 | 61 | def eval(model, dataloader_eval, post_process=False): 62 | eval_measures = torch.zeros(10).cuda() 63 | for _, eval_sample_batched in enumerate(tqdm(dataloader_eval.data)): 64 | with torch.no_grad(): 65 | image = torch.autograd.Variable(eval_sample_batched['image'].cuda()) 66 | gt_depth = eval_sample_batched['depth'] 67 | has_valid_depth = eval_sample_batched['has_valid_depth'] 68 | if not has_valid_depth: 69 | # print('Invalid depth. continue.') 70 | continue 71 | 72 | pred_depth = model(image) 73 | if post_process: 74 | image_flipped = flip_lr(image) 75 | pred_depth_flipped = model(image_flipped) 76 | pred_depth = post_process_depth(pred_depth, pred_depth_flipped) 77 | 78 | pred_depth = pred_depth.cpu().numpy().squeeze() 79 | gt_depth = gt_depth.cpu().numpy().squeeze() 80 | 81 | if args.do_kb_crop: 82 | height, width = gt_depth.shape 83 | top_margin = int(height - 352) 84 | left_margin = int((width - 1216) / 2) 85 | pred_depth_uncropped = np.zeros((height, width), dtype=np.float32) 86 | pred_depth_uncropped[top_margin:top_margin + 352, left_margin:left_margin + 1216] = pred_depth 87 | pred_depth = pred_depth_uncropped 88 | 89 | pred_depth[pred_depth < args.min_depth_eval] = args.min_depth_eval 90 | pred_depth[pred_depth > args.max_depth_eval] = args.max_depth_eval 91 | pred_depth[np.isinf(pred_depth)] = args.max_depth_eval 92 | pred_depth[np.isnan(pred_depth)] = args.min_depth_eval 93 | 94 | valid_mask = np.logical_and(gt_depth > args.min_depth_eval, gt_depth < args.max_depth_eval) 95 | 96 | if args.garg_crop or args.eigen_crop: 97 | gt_height, gt_width = gt_depth.shape 98 | eval_mask = np.zeros(valid_mask.shape) 99 | 100 | if args.garg_crop: 101 | eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height), int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1 102 | 103 | elif args.eigen_crop: 104 | if args.dataset == 'kitti': 105 | eval_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height), int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = 1 106 | elif args.dataset == 'nyu': 107 | eval_mask[45:471, 41:601] = 1 108 | 109 | valid_mask = np.logical_and(valid_mask, eval_mask) 110 | 111 | measures = compute_errors(gt_depth[valid_mask], pred_depth[valid_mask]) 112 | 113 | eval_measures[:9] += torch.tensor(measures).cuda() 114 | eval_measures[9] += 1 115 | 116 | eval_measures_cpu = eval_measures.cpu() 117 | cnt = eval_measures_cpu[9].item() 118 | eval_measures_cpu /= cnt 119 | print('Computing errors for {} eval samples'.format(int(cnt)), ', post_process: ', post_process) 120 | print("{:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}".format('silog', 
'abs_rel', 'log10', 'rms', 121 | 'sq_rel', 'log_rms', 'd1', 'd2', 122 | 'd3')) 123 | for i in range(8): 124 | print('{:7.4f}, '.format(eval_measures_cpu[i]), end='') 125 | print('{:7.4f}'.format(eval_measures_cpu[8])) 126 | return eval_measures_cpu 127 | 128 | 129 | def main_worker(args): 130 | 131 | # CRF model 132 | model = NewCRFDepth(version=args.encoder, inv_depth=False, max_depth=args.max_depth, pretrained=None) 133 | model.train() 134 | 135 | num_params = sum([np.prod(p.size()) for p in model.parameters()]) 136 | print("== Total number of parameters: {}".format(num_params)) 137 | 138 | num_params_update = sum([np.prod(p.shape) for p in model.parameters() if p.requires_grad]) 139 | print("== Total number of learning parameters: {}".format(num_params_update)) 140 | 141 | model = torch.nn.DataParallel(model) 142 | model.cuda() 143 | 144 | print("== Model Initialized") 145 | 146 | if args.checkpoint_path != '': 147 | if os.path.isfile(args.checkpoint_path): 148 | print("== Loading checkpoint '{}'".format(args.checkpoint_path)) 149 | checkpoint = torch.load(args.checkpoint_path, map_location='cpu') 150 | model.load_state_dict(checkpoint['model']) 151 | print("== Loaded checkpoint '{}'".format(args.checkpoint_path)) 152 | del checkpoint 153 | else: 154 | print("== No checkpoint found at '{}'".format(args.checkpoint_path)) 155 | 156 | cudnn.benchmark = True 157 | 158 | dataloader_eval = NewDataLoader(args, 'online_eval') 159 | 160 | # ===== Evaluation ====== 161 | model.eval() 162 | with torch.no_grad(): 163 | eval_measures = eval(model, dataloader_eval, post_process=True) 164 | 165 | 166 | def main(): 167 | torch.cuda.empty_cache() 168 | args.distributed = False 169 | ngpus_per_node = torch.cuda.device_count() 170 | if ngpus_per_node > 1: 171 | print("This machine has more than 1 gpu. Please set \'CUDA_VISIBLE_DEVICES=0\'") 172 | return -1 173 | 174 | main_worker(args) 175 | 176 | 177 | if __name__ == '__main__': 178 | main() 179 | -------------------------------------------------------------------------------- /newcrfs/networks/NewCRFDepth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .swin_transformer import SwinTransformer 6 | from .newcrf_layers import NewCRF 7 | from .uper_crf_head import PSP 8 | ######################################################################################################################## 9 | 10 | 11 | class NewCRFDepth(nn.Module): 12 | """ 13 | Depth network based on neural window FC-CRFs architecture. 
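    Illustrative usage (a minimal sketch, not part of the original docstring;
    it assumes the mmcv/timm dependencies from the README, random weights, and
    an NYUv2-sized input):

        >>> import torch
        >>> net = NewCRFDepth(version='large07', inv_depth=False,
        ...                   max_depth=10, pretrained=None)
        >>> rgb = torch.randn(1, 3, 480, 640)      # (B, 3, H, W) RGB batch
        >>> depth = net(rgb)                       # metric depth in (0, max_depth)
        >>> depth.shape
        torch.Size([1, 1, 480, 640])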
14 | """ 15 | def __init__(self, version=None, inv_depth=False, pretrained=None, 16 | frozen_stages=-1, min_depth=0.1, max_depth=100.0, **kwargs): 17 | super().__init__() 18 | 19 | self.inv_depth = inv_depth 20 | self.with_auxiliary_head = False 21 | self.with_neck = False 22 | 23 | norm_cfg = dict(type='BN', requires_grad=True) 24 | # norm_cfg = dict(type='GN', requires_grad=True, num_groups=8) 25 | 26 | window_size = int(version[-2:]) 27 | 28 | if version[:-2] == 'base': 29 | embed_dim = 128 30 | depths = [2, 2, 18, 2] 31 | num_heads = [4, 8, 16, 32] 32 | in_channels = [128, 256, 512, 1024] 33 | elif version[:-2] == 'large': 34 | embed_dim = 192 35 | depths = [2, 2, 18, 2] 36 | num_heads = [6, 12, 24, 48] 37 | in_channels = [192, 384, 768, 1536] 38 | elif version[:-2] == 'tiny': 39 | embed_dim = 96 40 | depths = [2, 2, 6, 2] 41 | num_heads = [3, 6, 12, 24] 42 | in_channels = [96, 192, 384, 768] 43 | 44 | backbone_cfg = dict( 45 | embed_dim=embed_dim, 46 | depths=depths, 47 | num_heads=num_heads, 48 | window_size=window_size, 49 | ape=False, 50 | drop_path_rate=0.3, 51 | patch_norm=True, 52 | use_checkpoint=False, 53 | frozen_stages=frozen_stages 54 | ) 55 | 56 | embed_dim = 512 57 | decoder_cfg = dict( 58 | in_channels=in_channels, 59 | in_index=[0, 1, 2, 3], 60 | pool_scales=(1, 2, 3, 6), 61 | channels=embed_dim, 62 | dropout_ratio=0.0, 63 | num_classes=32, 64 | norm_cfg=norm_cfg, 65 | align_corners=False 66 | ) 67 | 68 | self.backbone = SwinTransformer(**backbone_cfg) 69 | v_dim = decoder_cfg['num_classes']*4 70 | win = 7 71 | crf_dims = [128, 256, 512, 1024] 72 | v_dims = [64, 128, 256, embed_dim] 73 | self.crf3 = NewCRF(input_dim=in_channels[3], embed_dim=crf_dims[3], window_size=win, v_dim=v_dims[3], num_heads=32) 74 | self.crf2 = NewCRF(input_dim=in_channels[2], embed_dim=crf_dims[2], window_size=win, v_dim=v_dims[2], num_heads=16) 75 | self.crf1 = NewCRF(input_dim=in_channels[1], embed_dim=crf_dims[1], window_size=win, v_dim=v_dims[1], num_heads=8) 76 | self.crf0 = NewCRF(input_dim=in_channels[0], embed_dim=crf_dims[0], window_size=win, v_dim=v_dims[0], num_heads=4) 77 | 78 | self.decoder = PSP(**decoder_cfg) 79 | self.disp_head1 = DispHead(input_dim=crf_dims[0]) 80 | 81 | self.up_mode = 'bilinear' 82 | if self.up_mode == 'mask': 83 | self.mask_head = nn.Sequential( 84 | nn.Conv2d(crf_dims[0], 64, 3, padding=1), 85 | nn.ReLU(inplace=True), 86 | nn.Conv2d(64, 16*9, 1, padding=0)) 87 | 88 | self.min_depth = min_depth 89 | self.max_depth = max_depth 90 | 91 | self.init_weights(pretrained=pretrained) 92 | 93 | def init_weights(self, pretrained=None): 94 | """Initialize the weights in backbone and heads. 95 | 96 | Args: 97 | pretrained (str, optional): Path to pre-trained weights. 98 | Defaults to None. 
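        Example (a hedged sketch; the Swin-L checkpoint name below follows the
        official Swin Transformer release and is an assumption, not a file
        shipped with this repository):
            >>> net = NewCRFDepth(version='large07', max_depth=10,
            ...     pretrained='model_zoo/swin_large_patch4_window7_224_22k.pth')

        The path is forwarded to ``self.backbone.init_weights``; the PSP decoder
        and the NewCRF modules are initialized from scratch.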
99 | """ 100 | print(f'== Load encoder backbone from: {pretrained}') 101 | self.backbone.init_weights(pretrained=pretrained) 102 | self.decoder.init_weights() 103 | if self.with_auxiliary_head: 104 | if isinstance(self.auxiliary_head, nn.ModuleList): 105 | for aux_head in self.auxiliary_head: 106 | aux_head.init_weights() 107 | else: 108 | self.auxiliary_head.init_weights() 109 | 110 | def upsample_mask(self, disp, mask): 111 | """ Upsample disp [H/4, W/4, 1] -> [H, W, 1] using convex combination """ 112 | N, _, H, W = disp.shape 113 | mask = mask.view(N, 1, 9, 4, 4, H, W) 114 | mask = torch.softmax(mask, dim=2) 115 | 116 | up_disp = F.unfold(disp, kernel_size=3, padding=1) 117 | up_disp = up_disp.view(N, 1, 9, 1, 1, H, W) 118 | 119 | up_disp = torch.sum(mask * up_disp, dim=2) 120 | up_disp = up_disp.permute(0, 1, 4, 2, 5, 3) 121 | return up_disp.reshape(N, 1, 4*H, 4*W) 122 | 123 | def forward(self, imgs): 124 | 125 | feats = self.backbone(imgs) 126 | if self.with_neck: 127 | feats = self.neck(feats) 128 | 129 | ppm_out = self.decoder(feats) 130 | 131 | e3 = self.crf3(feats[3], ppm_out) 132 | e3 = nn.PixelShuffle(2)(e3) 133 | e2 = self.crf2(feats[2], e3) 134 | e2 = nn.PixelShuffle(2)(e2) 135 | e1 = self.crf1(feats[1], e2) 136 | e1 = nn.PixelShuffle(2)(e1) 137 | e0 = self.crf0(feats[0], e1) 138 | 139 | if self.up_mode == 'mask': 140 | mask = self.mask_head(e0) 141 | d1 = self.disp_head1(e0, 1) 142 | d1 = self.upsample_mask(d1, mask) 143 | else: 144 | d1 = self.disp_head1(e0, 4) 145 | 146 | depth = d1 * self.max_depth 147 | 148 | return depth 149 | 150 | 151 | class DispHead(nn.Module): 152 | def __init__(self, input_dim=100): 153 | super(DispHead, self).__init__() 154 | # self.norm1 = nn.BatchNorm2d(input_dim) 155 | self.conv1 = nn.Conv2d(input_dim, 1, 3, padding=1) 156 | # self.relu = nn.ReLU(inplace=True) 157 | self.sigmoid = nn.Sigmoid() 158 | 159 | def forward(self, x, scale): 160 | # x = self.relu(self.norm1(x)) 161 | x = self.sigmoid(self.conv1(x)) 162 | if scale > 1: 163 | x = upsample(x, scale_factor=scale) 164 | return x 165 | 166 | 167 | class DispUnpack(nn.Module): 168 | def __init__(self, input_dim=100, hidden_dim=128): 169 | super(DispUnpack, self).__init__() 170 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 171 | self.conv2 = nn.Conv2d(hidden_dim, 16, 3, padding=1) 172 | self.relu = nn.ReLU(inplace=True) 173 | self.sigmoid = nn.Sigmoid() 174 | self.pixel_shuffle = nn.PixelShuffle(4) 175 | 176 | def forward(self, x, output_size): 177 | x = self.relu(self.conv1(x)) 178 | x = self.sigmoid(self.conv2(x)) # [b, 16, h/4, w/4] 179 | # x = torch.reshape(x, [x.shape[0], 1, x.shape[2]*4, x.shape[3]*4]) 180 | x = self.pixel_shuffle(x) 181 | 182 | return x 183 | 184 | 185 | def upsample(x, scale_factor=2, mode="bilinear", align_corners=False): 186 | """Upsample input tensor by a factor of 2 187 | """ 188 | return F.interpolate(x, scale_factor=scale_factor, mode=mode, align_corners=align_corners) -------------------------------------------------------------------------------- /newcrfs/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/NeWCRFs/d327bf7ca8fb43959734bb02ddc7b56cf283c8d9/newcrfs/networks/__init__.py -------------------------------------------------------------------------------- /newcrfs/networks/newcrf_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 
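# Shape sketch for the window utilities defined below (illustrative only), with
# the window_size = 7 used throughout NeWCRFs:
#
#     x = torch.randn(2, 56, 56, 192)          # (B, H, W, C) feature map
#     windows = window_partition(x, 7)         # -> (128, 7, 7, 192): 8*8 windows per image
#     y = window_reverse(windows, 7, 56, 56)   # -> (2, 56, 56, 192), identical to x
#
# Attention is computed independently inside each 7x7 window; window_reverse
# stitches the windows back into the original (B, H, W, C) layout.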
4 | import torch.utils.checkpoint as checkpoint 5 | import numpy as np 6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 7 | 8 | 9 | class Mlp(nn.Module): 10 | """ Multilayer perceptron.""" 11 | 12 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 13 | super().__init__() 14 | out_features = out_features or in_features 15 | hidden_features = hidden_features or in_features 16 | self.fc1 = nn.Linear(in_features, hidden_features) 17 | self.act = act_layer() 18 | self.fc2 = nn.Linear(hidden_features, out_features) 19 | self.drop = nn.Dropout(drop) 20 | 21 | def forward(self, x): 22 | x = self.fc1(x) 23 | x = self.act(x) 24 | x = self.drop(x) 25 | x = self.fc2(x) 26 | x = self.drop(x) 27 | return x 28 | 29 | 30 | def window_partition(x, window_size): 31 | """ 32 | Args: 33 | x: (B, H, W, C) 34 | window_size (int): window size 35 | 36 | Returns: 37 | windows: (num_windows*B, window_size, window_size, C) 38 | """ 39 | B, H, W, C = x.shape 40 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 41 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 42 | return windows 43 | 44 | 45 | def window_reverse(windows, window_size, H, W): 46 | """ 47 | Args: 48 | windows: (num_windows*B, window_size, window_size, C) 49 | window_size (int): Window size 50 | H (int): Height of image 51 | W (int): Width of image 52 | 53 | Returns: 54 | x: (B, H, W, C) 55 | """ 56 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 57 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 58 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 59 | return x 60 | 61 | 62 | class WindowAttention(nn.Module): 63 | """ Window based multi-head self attention (W-MSA) module with relative position bias. 64 | It supports both of shifted and non-shifted window. 65 | 66 | Args: 67 | dim (int): Number of input channels. 68 | window_size (tuple[int]): The height and width of the window. 69 | num_heads (int): Number of attention heads. 70 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 71 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 72 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 73 | proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 74 | """ 75 | 76 | def __init__(self, dim, window_size, num_heads, v_dim, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 77 | 78 | super().__init__() 79 | self.dim = dim 80 | self.window_size = window_size # Wh, Ww 81 | self.num_heads = num_heads 82 | head_dim = dim // num_heads 83 | self.scale = qk_scale or head_dim ** -0.5 84 | 85 | # define a parameter table of relative position bias 86 | self.relative_position_bias_table = nn.Parameter( 87 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 88 | 89 | # get pair-wise relative position index for each token inside the window 90 | coords_h = torch.arange(self.window_size[0]) 91 | coords_w = torch.arange(self.window_size[1]) 92 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 93 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 94 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 95 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 96 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 97 | relative_coords[:, :, 1] += self.window_size[1] - 1 98 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 99 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 100 | self.register_buffer("relative_position_index", relative_position_index) 101 | 102 | self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias) 103 | self.attn_drop = nn.Dropout(attn_drop) 104 | self.proj = nn.Linear(v_dim, v_dim) 105 | self.proj_drop = nn.Dropout(proj_drop) 106 | 107 | trunc_normal_(self.relative_position_bias_table, std=.02) 108 | self.softmax = nn.Softmax(dim=-1) 109 | 110 | def forward(self, x, v, mask=None): 111 | """ Forward function. 112 | 113 | Args: 114 | x: input features with shape of (num_windows*B, N, C) 115 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 116 | """ 117 | B_, N, C = x.shape 118 | qk = self.qk(x).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 119 | q, k = qk[0], qk[1] # make torchscript happy (cannot use tensor as tuple) 120 | 121 | q = q * self.scale 122 | attn = (q @ k.transpose(-2, -1)) 123 | 124 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 125 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 126 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 127 | attn = attn + relative_position_bias.unsqueeze(0) 128 | 129 | if mask is not None: 130 | nW = mask.shape[0] 131 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 132 | attn = attn.view(-1, self.num_heads, N, N) 133 | attn = self.softmax(attn) 134 | else: 135 | attn = self.softmax(attn) 136 | 137 | attn = self.attn_drop(attn) 138 | 139 | # assert self.dim % v.shape[-1] == 0, "self.dim % v.shape[-1] != 0" 140 | # repeat_num = self.dim // v.shape[-1] 141 | # v = v.view(B_, N, self.num_heads // repeat_num, -1).transpose(1, 2).repeat(1, repeat_num, 1, 1) 142 | 143 | assert self.dim == v.shape[-1], "self.dim != v.shape[-1]" 144 | v = v.view(B_, N, self.num_heads, -1).transpose(1, 2) 145 | 146 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 147 | x = self.proj(x) 148 | x = self.proj_drop(x) 149 | return x 150 | 151 | 152 | class CRFBlock(nn.Module): 153 | """ CRF Block. 
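    Each block applies (shifted-)window multi-head attention in which the input
    feature ``x`` provides the queries and keys and the coarser-level feature
    ``v`` provides the values, followed by an MLP; both steps use residual
    connections with stochastic depth.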
154 | 155 | Args: 156 | dim (int): Number of input channels. 157 | num_heads (int): Number of attention heads. 158 | window_size (int): Window size. 159 | shift_size (int): Shift size for SW-MSA. 160 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 161 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 162 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 163 | drop (float, optional): Dropout rate. Default: 0.0 164 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 165 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 166 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 167 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 168 | """ 169 | 170 | def __init__(self, dim, num_heads, v_dim, window_size=7, shift_size=0, 171 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 172 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 173 | super().__init__() 174 | self.dim = dim 175 | self.num_heads = num_heads 176 | self.v_dim = v_dim 177 | self.window_size = window_size 178 | self.shift_size = shift_size 179 | self.mlp_ratio = mlp_ratio 180 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 181 | 182 | self.norm1 = norm_layer(dim) 183 | self.attn = WindowAttention( 184 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, v_dim=v_dim, 185 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 186 | 187 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 188 | self.norm2 = norm_layer(v_dim) 189 | mlp_hidden_dim = int(v_dim * mlp_ratio) 190 | self.mlp = Mlp(in_features=v_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 191 | 192 | self.H = None 193 | self.W = None 194 | 195 | def forward(self, x, v, mask_matrix): 196 | """ Forward function. 197 | 198 | Args: 199 | x: Input feature, tensor size (B, H*W, C). 200 | H, W: Spatial resolution of the input feature. 201 | mask_matrix: Attention mask for cyclic shift. 
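            v: Value feature from the coarser level, tensor size (B, H, W, v_dim),
                channel-last; v_dim equals C in this repository. For example, at
                the 1/4 scale of a 480x640 input, x is (B, 19200, 128) and v is
                (B, 120, 160, 128).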
202 | """ 203 | B, L, C = x.shape 204 | H, W = self.H, self.W 205 | assert L == H * W, "input feature has wrong size" 206 | 207 | shortcut = x 208 | x = self.norm1(x) 209 | x = x.view(B, H, W, C) 210 | 211 | # pad feature maps to multiples of window size 212 | pad_l = pad_t = 0 213 | pad_r = (self.window_size - W % self.window_size) % self.window_size 214 | pad_b = (self.window_size - H % self.window_size) % self.window_size 215 | x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) 216 | v = F.pad(v, (0, 0, pad_l, pad_r, pad_t, pad_b)) 217 | _, Hp, Wp, _ = x.shape 218 | 219 | # cyclic shift 220 | if self.shift_size > 0: 221 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 222 | shifted_v = torch.roll(v, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 223 | attn_mask = mask_matrix 224 | else: 225 | shifted_x = x 226 | shifted_v = v 227 | attn_mask = None 228 | 229 | # partition windows 230 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 231 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 232 | v_windows = window_partition(shifted_v, self.window_size) # nW*B, window_size, window_size, C 233 | v_windows = v_windows.view(-1, self.window_size * self.window_size, v_windows.shape[-1]) # nW*B, window_size*window_size, C 234 | 235 | # W-MSA/SW-MSA 236 | attn_windows = self.attn(x_windows, v_windows, mask=attn_mask) # nW*B, window_size*window_size, C 237 | 238 | # merge windows 239 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.v_dim) 240 | shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C 241 | 242 | # reverse cyclic shift 243 | if self.shift_size > 0: 244 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 245 | else: 246 | x = shifted_x 247 | 248 | if pad_r > 0 or pad_b > 0: 249 | x = x[:, :H, :W, :].contiguous() 250 | 251 | x = x.view(B, H * W, self.v_dim) 252 | 253 | # FFN 254 | x = shortcut + self.drop_path(x) 255 | x = x + self.drop_path(self.mlp(self.norm2(x))) 256 | 257 | return x 258 | 259 | 260 | class BasicCRFLayer(nn.Module): 261 | """ A basic NeWCRFs layer for one stage. 262 | 263 | Args: 264 | dim (int): Number of feature channels 265 | depth (int): Depths of this stage. 266 | num_heads (int): Number of attention head. 267 | window_size (int): Local window size. Default: 7. 268 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. 269 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 270 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 271 | drop (float, optional): Dropout rate. Default: 0.0 272 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 273 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 274 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 275 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 276 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
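        Example (an illustrative shape sketch; the dimensions follow the finest
        CRF stage, ``crf0``, of the 'large07' model on a 480x640 input):
            >>> import torch
            >>> layer = BasicCRFLayer(dim=128, depth=2, num_heads=4, v_dim=128,
            ...                       window_size=7)
            >>> x = torch.randn(2, 120 * 160, 128)   # flattened 1/4-scale feature
            >>> v = torch.randn(2, 120, 160, 128)    # channel-last value feature
            >>> out, H, W, _, _, _ = layer(x, v, 120, 160)
            >>> out.shape
            torch.Size([2, 19200, 128])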
277 | """ 278 | 279 | def __init__(self, 280 | dim, 281 | depth, 282 | num_heads, 283 | v_dim, 284 | window_size=7, 285 | mlp_ratio=4., 286 | qkv_bias=True, 287 | qk_scale=None, 288 | drop=0., 289 | attn_drop=0., 290 | drop_path=0., 291 | norm_layer=nn.LayerNorm, 292 | downsample=None, 293 | use_checkpoint=False): 294 | super().__init__() 295 | self.window_size = window_size 296 | self.shift_size = window_size // 2 297 | self.depth = depth 298 | self.use_checkpoint = use_checkpoint 299 | 300 | # build blocks 301 | self.blocks = nn.ModuleList([ 302 | CRFBlock( 303 | dim=dim, 304 | num_heads=num_heads, 305 | v_dim=v_dim, 306 | window_size=window_size, 307 | shift_size=0 if (i % 2 == 0) else window_size // 2, 308 | mlp_ratio=mlp_ratio, 309 | qkv_bias=qkv_bias, 310 | qk_scale=qk_scale, 311 | drop=drop, 312 | attn_drop=attn_drop, 313 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 314 | norm_layer=norm_layer) 315 | for i in range(depth)]) 316 | 317 | # patch merging layer 318 | if downsample is not None: 319 | self.downsample = downsample(dim=dim, norm_layer=norm_layer) 320 | else: 321 | self.downsample = None 322 | 323 | def forward(self, x, v, H, W): 324 | """ Forward function. 325 | 326 | Args: 327 | x: Input feature, tensor size (B, H*W, C). 328 | H, W: Spatial resolution of the input feature. 329 | """ 330 | 331 | # calculate attention mask for SW-MSA 332 | Hp = int(np.ceil(H / self.window_size)) * self.window_size 333 | Wp = int(np.ceil(W / self.window_size)) * self.window_size 334 | img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 335 | h_slices = (slice(0, -self.window_size), 336 | slice(-self.window_size, -self.shift_size), 337 | slice(-self.shift_size, None)) 338 | w_slices = (slice(0, -self.window_size), 339 | slice(-self.window_size, -self.shift_size), 340 | slice(-self.shift_size, None)) 341 | cnt = 0 342 | for h in h_slices: 343 | for w in w_slices: 344 | img_mask[:, h, w, :] = cnt 345 | cnt += 1 346 | 347 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 348 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 349 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 350 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 351 | 352 | for blk in self.blocks: 353 | blk.H, blk.W = H, W 354 | if self.use_checkpoint: 355 | x = checkpoint.checkpoint(blk, x, attn_mask) 356 | else: 357 | x = blk(x, v, attn_mask) 358 | if self.downsample is not None: 359 | x_down = self.downsample(x, H, W) 360 | Wh, Ww = (H + 1) // 2, (W + 1) // 2 361 | return x, H, W, x_down, Wh, Ww 362 | else: 363 | return x, H, W, x, H, W 364 | 365 | 366 | class NewCRF(nn.Module): 367 | def __init__(self, 368 | input_dim=96, 369 | embed_dim=96, 370 | v_dim=64, 371 | window_size=7, 372 | num_heads=4, 373 | depth=2, 374 | patch_size=4, 375 | in_chans=3, 376 | norm_layer=nn.LayerNorm, 377 | patch_norm=True): 378 | super().__init__() 379 | 380 | self.embed_dim = embed_dim 381 | self.patch_norm = patch_norm 382 | 383 | if input_dim != embed_dim: 384 | self.proj_x = nn.Conv2d(input_dim, embed_dim, 3, padding=1) 385 | else: 386 | self.proj_x = None 387 | 388 | if v_dim != embed_dim: 389 | self.proj_v = nn.Conv2d(v_dim, embed_dim, 3, padding=1) 390 | elif embed_dim % v_dim == 0: 391 | self.proj_v = None 392 | 393 | # For now, v_dim need to be equal to embed_dim, because the output of window-attn is the input of shift-window-attn 394 | v_dim = 
embed_dim 395 | assert v_dim == embed_dim 396 | 397 | self.crf_layer = BasicCRFLayer( 398 | dim=embed_dim, 399 | depth=depth, 400 | num_heads=num_heads, 401 | v_dim=v_dim, 402 | window_size=window_size, 403 | mlp_ratio=4., 404 | qkv_bias=True, 405 | qk_scale=None, 406 | drop=0., 407 | attn_drop=0., 408 | drop_path=0., 409 | norm_layer=norm_layer, 410 | downsample=None, 411 | use_checkpoint=False) 412 | 413 | layer = norm_layer(embed_dim) 414 | layer_name = 'norm_crf' 415 | self.add_module(layer_name, layer) 416 | 417 | 418 | def forward(self, x, v): 419 | if self.proj_x is not None: 420 | x = self.proj_x(x) 421 | if self.proj_v is not None: 422 | v = self.proj_v(v) 423 | 424 | Wh, Ww = x.size(2), x.size(3) 425 | x = x.flatten(2).transpose(1, 2) 426 | v = v.transpose(1, 2).transpose(2, 3) 427 | 428 | x_out, H, W, x, Wh, Ww = self.crf_layer(x, v, Wh, Ww) 429 | norm_layer = getattr(self, f'norm_crf') 430 | x_out = norm_layer(x_out) 431 | out = x_out.view(-1, H, W, self.embed_dim).permute(0, 3, 1, 2).contiguous() 432 | 433 | return out -------------------------------------------------------------------------------- /newcrfs/networks/newcrf_utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import os 3 | import os.path as osp 4 | import pkgutil 5 | import warnings 6 | from collections import OrderedDict 7 | from importlib import import_module 8 | 9 | import torch 10 | import torchvision 11 | import torch.nn as nn 12 | from torch.utils import model_zoo 13 | from torch.nn import functional as F 14 | from torch.nn.parallel import DataParallel, DistributedDataParallel 15 | from torch import distributed as dist 16 | 17 | TORCH_VERSION = torch.__version__ 18 | 19 | 20 | def resize(input, 21 | size=None, 22 | scale_factor=None, 23 | mode='nearest', 24 | align_corners=None, 25 | warning=True): 26 | if warning: 27 | if size is not None and align_corners: 28 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 29 | output_h, output_w = tuple(int(x) for x in size) 30 | if output_h > input_h or output_w > output_h: 31 | if ((output_h > 1 and output_w > 1 and input_h > 1 32 | and input_w > 1) and (output_h - 1) % (input_h - 1) 33 | and (output_w - 1) % (input_w - 1)): 34 | warnings.warn( 35 | f'When align_corners={align_corners}, ' 36 | 'the output would more aligned if ' 37 | f'input size {(input_h, input_w)} is `x+1` and ' 38 | f'out size {(output_h, output_w)} is `nx+1`') 39 | if isinstance(size, torch.Size): 40 | size = tuple(int(x) for x in size) 41 | return F.interpolate(input, size, scale_factor, mode, align_corners) 42 | 43 | 44 | def normal_init(module, mean=0, std=1, bias=0): 45 | if hasattr(module, 'weight') and module.weight is not None: 46 | nn.init.normal_(module.weight, mean, std) 47 | if hasattr(module, 'bias') and module.bias is not None: 48 | nn.init.constant_(module.bias, bias) 49 | 50 | 51 | def is_module_wrapper(module): 52 | module_wrappers = (DataParallel, DistributedDataParallel) 53 | return isinstance(module, module_wrappers) 54 | 55 | 56 | def get_dist_info(): 57 | if TORCH_VERSION < '1.0': 58 | initialized = dist._initialized 59 | else: 60 | if dist.is_available(): 61 | initialized = dist.is_initialized() 62 | else: 63 | initialized = False 64 | if initialized: 65 | rank = dist.get_rank() 66 | world_size = dist.get_world_size() 67 | else: 68 | rank = 0 69 | world_size = 1 70 | return rank, world_size 71 | 72 | 73 | def load_state_dict(module, state_dict, strict=False, logger=None): 74 | """Load state_dict to a 
module. 75 | 76 | This method is modified from :meth:`torch.nn.Module.load_state_dict`. 77 | Default value for ``strict`` is set to ``False`` and the message for 78 | param mismatch will be shown even if strict is False. 79 | 80 | Args: 81 | module (Module): Module that receives the state_dict. 82 | state_dict (OrderedDict): Weights. 83 | strict (bool): whether to strictly enforce that the keys 84 | in :attr:`state_dict` match the keys returned by this module's 85 | :meth:`~torch.nn.Module.state_dict` function. Default: ``False``. 86 | logger (:obj:`logging.Logger`, optional): Logger to log the error 87 | message. If not specified, print function will be used. 88 | """ 89 | unexpected_keys = [] 90 | all_missing_keys = [] 91 | err_msg = [] 92 | 93 | metadata = getattr(state_dict, '_metadata', None) 94 | state_dict = state_dict.copy() 95 | if metadata is not None: 96 | state_dict._metadata = metadata 97 | 98 | # use _load_from_state_dict to enable checkpoint version control 99 | def load(module, prefix=''): 100 | # recursively check parallel module in case that the model has a 101 | # complicated structure, e.g., nn.Module(nn.Module(DDP)) 102 | if is_module_wrapper(module): 103 | module = module.module 104 | local_metadata = {} if metadata is None else metadata.get( 105 | prefix[:-1], {}) 106 | module._load_from_state_dict(state_dict, prefix, local_metadata, True, 107 | all_missing_keys, unexpected_keys, 108 | err_msg) 109 | for name, child in module._modules.items(): 110 | if child is not None: 111 | load(child, prefix + name + '.') 112 | 113 | load(module) 114 | load = None # break load->load reference cycle 115 | 116 | # ignore "num_batches_tracked" of BN layers 117 | missing_keys = [ 118 | key for key in all_missing_keys if 'num_batches_tracked' not in key 119 | ] 120 | 121 | if unexpected_keys: 122 | err_msg.append('unexpected key in source ' 123 | f'state_dict: {", ".join(unexpected_keys)}\n') 124 | if missing_keys: 125 | err_msg.append( 126 | f'missing keys in source state_dict: {", ".join(missing_keys)}\n') 127 | 128 | rank, _ = get_dist_info() 129 | if len(err_msg) > 0 and rank == 0: 130 | err_msg.insert( 131 | 0, 'The model and loaded state dict do not match exactly\n') 132 | err_msg = '\n'.join(err_msg) 133 | if strict: 134 | raise RuntimeError(err_msg) 135 | elif logger is not None: 136 | logger.warning(err_msg) 137 | else: 138 | print(err_msg) 139 | 140 | 141 | def load_url_dist(url, model_dir=None): 142 | """In distributed setting, this function only download checkpoint at local 143 | rank 0.""" 144 | rank, world_size = get_dist_info() 145 | rank = int(os.environ.get('LOCAL_RANK', rank)) 146 | if rank == 0: 147 | checkpoint = model_zoo.load_url(url, model_dir=model_dir) 148 | if world_size > 1: 149 | torch.distributed.barrier() 150 | if rank > 0: 151 | checkpoint = model_zoo.load_url(url, model_dir=model_dir) 152 | return checkpoint 153 | 154 | 155 | def get_torchvision_models(): 156 | model_urls = dict() 157 | for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__): 158 | if ispkg: 159 | continue 160 | _zoo = import_module(f'torchvision.models.{name}') 161 | if hasattr(_zoo, 'model_urls'): 162 | _urls = getattr(_zoo, 'model_urls') 163 | model_urls.update(_urls) 164 | return model_urls 165 | 166 | 167 | def _load_checkpoint(filename, map_location=None): 168 | """Load checkpoint from somewhere (modelzoo, file, url). 169 | 170 | Args: 171 | filename (str): Accept local filepath, URL, ``torchvision://xxx``, 172 | ``open-mmlab://xxx``. 
Please refer to ``docs/model_zoo.md`` for 173 | details. 174 | map_location (str | None): Same as :func:`torch.load`. Default: None. 175 | 176 | Returns: 177 | dict | OrderedDict: The loaded checkpoint. It can be either an 178 | OrderedDict storing model weights or a dict containing other 179 | information, which depends on the checkpoint. 180 | """ 181 | if filename.startswith('modelzoo://'): 182 | warnings.warn('The URL scheme of "modelzoo://" is deprecated, please ' 183 | 'use "torchvision://" instead') 184 | model_urls = get_torchvision_models() 185 | model_name = filename[11:] 186 | checkpoint = load_url_dist(model_urls[model_name]) 187 | else: 188 | if not osp.isfile(filename): 189 | raise IOError(f'{filename} is not a checkpoint file') 190 | checkpoint = torch.load(filename, map_location=map_location) 191 | return checkpoint 192 | 193 | 194 | def load_checkpoint(model, 195 | filename, 196 | map_location='cpu', 197 | strict=False, 198 | logger=None): 199 | """Load checkpoint from a file or URI. 200 | 201 | Args: 202 | model (Module): Module to load checkpoint. 203 | filename (str): Accept local filepath, URL, ``torchvision://xxx``, 204 | ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for 205 | details. 206 | map_location (str): Same as :func:`torch.load`. 207 | strict (bool): Whether to allow different params for the model and 208 | checkpoint. 209 | logger (:mod:`logging.Logger` or None): The logger for error message. 210 | 211 | Returns: 212 | dict or OrderedDict: The loaded checkpoint. 213 | """ 214 | checkpoint = _load_checkpoint(filename, map_location) 215 | # OrderedDict is a subclass of dict 216 | if not isinstance(checkpoint, dict): 217 | raise RuntimeError( 218 | f'No state_dict found in checkpoint file {filename}') 219 | # get state_dict from checkpoint 220 | if 'state_dict' in checkpoint: 221 | state_dict = checkpoint['state_dict'] 222 | elif 'model' in checkpoint: 223 | state_dict = checkpoint['model'] 224 | else: 225 | state_dict = checkpoint 226 | # strip prefix of state_dict 227 | if list(state_dict.keys())[0].startswith('module.'): 228 | state_dict = {k[7:]: v for k, v in state_dict.items()} 229 | 230 | # for MoBY, load model of online branch 231 | if sorted(list(state_dict.keys()))[0].startswith('encoder'): 232 | state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')} 233 | 234 | # reshape absolute position embedding 235 | if state_dict.get('absolute_pos_embed') is not None: 236 | absolute_pos_embed = state_dict['absolute_pos_embed'] 237 | N1, L, C1 = absolute_pos_embed.size() 238 | N2, C2, H, W = model.absolute_pos_embed.size() 239 | if N1 != N2 or C1 != C2 or L != H*W: 240 | logger.warning("Error in loading absolute_pos_embed, pass") 241 | else: 242 | state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2) 243 | 244 | # interpolate position bias table if needed 245 | relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k] 246 | for table_key in relative_position_bias_table_keys: 247 | table_pretrained = state_dict[table_key] 248 | table_current = model.state_dict()[table_key] 249 | L1, nH1 = table_pretrained.size() 250 | L2, nH2 = table_current.size() 251 | if nH1 != nH2: 252 | logger.warning(f"Error in loading {table_key}, pass") 253 | else: 254 | if L1 != L2: 255 | S1 = int(L1 ** 0.5) 256 | S2 = int(L2 ** 0.5) 257 | table_pretrained_resized = F.interpolate( 258 | table_pretrained.permute(1, 0).view(1, nH1, S1, S1), 259 
| size=(S2, S2), mode='bicubic') 260 | state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0) 261 | 262 | # load state_dict 263 | load_state_dict(model, state_dict, strict, logger) 264 | return checkpoint -------------------------------------------------------------------------------- /newcrfs/networks/swin_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.utils.checkpoint as checkpoint 5 | import numpy as np 6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 7 | 8 | from .newcrf_utils import load_checkpoint 9 | 10 | 11 | class Mlp(nn.Module): 12 | """ Multilayer perceptron.""" 13 | 14 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 15 | super().__init__() 16 | out_features = out_features or in_features 17 | hidden_features = hidden_features or in_features 18 | self.fc1 = nn.Linear(in_features, hidden_features) 19 | self.act = act_layer() 20 | self.fc2 = nn.Linear(hidden_features, out_features) 21 | self.drop = nn.Dropout(drop) 22 | 23 | def forward(self, x): 24 | x = self.fc1(x) 25 | x = self.act(x) 26 | x = self.drop(x) 27 | x = self.fc2(x) 28 | x = self.drop(x) 29 | return x 30 | 31 | 32 | def window_partition(x, window_size): 33 | """ 34 | Args: 35 | x: (B, H, W, C) 36 | window_size (int): window size 37 | 38 | Returns: 39 | windows: (num_windows*B, window_size, window_size, C) 40 | """ 41 | B, H, W, C = x.shape 42 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 43 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 44 | return windows 45 | 46 | 47 | def window_reverse(windows, window_size, H, W): 48 | """ 49 | Args: 50 | windows: (num_windows*B, window_size, window_size, C) 51 | window_size (int): Window size 52 | H (int): Height of image 53 | W (int): Width of image 54 | 55 | Returns: 56 | x: (B, H, W, C) 57 | """ 58 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 59 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 60 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 61 | return x 62 | 63 | 64 | class WindowAttention(nn.Module): 65 | """ Window based multi-head self attention (W-MSA) module with relative position bias. 66 | It supports both of shifted and non-shifted window. 67 | 68 | Args: 69 | dim (int): Number of input channels. 70 | window_size (tuple[int]): The height and width of the window. 71 | num_heads (int): Number of attention heads. 72 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 73 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 74 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 75 | proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 76 | """ 77 | 78 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 79 | 80 | super().__init__() 81 | self.dim = dim 82 | self.window_size = window_size # Wh, Ww 83 | self.num_heads = num_heads 84 | head_dim = dim // num_heads 85 | self.scale = qk_scale or head_dim ** -0.5 86 | 87 | # define a parameter table of relative position bias 88 | self.relative_position_bias_table = nn.Parameter( 89 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 90 | 91 | # get pair-wise relative position index for each token inside the window 92 | coords_h = torch.arange(self.window_size[0]) 93 | coords_w = torch.arange(self.window_size[1]) 94 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 95 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 96 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 97 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 98 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 99 | relative_coords[:, :, 1] += self.window_size[1] - 1 100 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 101 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 102 | self.register_buffer("relative_position_index", relative_position_index) 103 | 104 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 105 | self.attn_drop = nn.Dropout(attn_drop) 106 | self.proj = nn.Linear(dim, dim) 107 | self.proj_drop = nn.Dropout(proj_drop) 108 | 109 | trunc_normal_(self.relative_position_bias_table, std=.02) 110 | self.softmax = nn.Softmax(dim=-1) 111 | 112 | def forward(self, x, mask=None): 113 | """ Forward function. 114 | 115 | Args: 116 | x: input features with shape of (num_windows*B, N, C) 117 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 118 | """ 119 | B_, N, C = x.shape 120 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 121 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 122 | 123 | q = q * self.scale 124 | attn = (q @ k.transpose(-2, -1)) 125 | 126 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 127 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 128 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 129 | attn = attn + relative_position_bias.unsqueeze(0) 130 | 131 | if mask is not None: 132 | nW = mask.shape[0] 133 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 134 | attn = attn.view(-1, self.num_heads, N, N) 135 | attn = self.softmax(attn) 136 | else: 137 | attn = self.softmax(attn) 138 | 139 | attn = self.attn_drop(attn) 140 | 141 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 142 | x = self.proj(x) 143 | x = self.proj_drop(x) 144 | return x 145 | 146 | 147 | class SwinTransformerBlock(nn.Module): 148 | """ Swin Transformer Block. 149 | 150 | Args: 151 | dim (int): Number of input channels. 152 | num_heads (int): Number of attention heads. 153 | window_size (int): Window size. 154 | shift_size (int): Shift size for SW-MSA. 155 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 156 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True 157 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 158 | drop (float, optional): Dropout rate. Default: 0.0 159 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 160 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 161 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 162 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 163 | """ 164 | 165 | def __init__(self, dim, num_heads, window_size=7, shift_size=0, 166 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 167 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 168 | super().__init__() 169 | self.dim = dim 170 | self.num_heads = num_heads 171 | self.window_size = window_size 172 | self.shift_size = shift_size 173 | self.mlp_ratio = mlp_ratio 174 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 175 | 176 | self.norm1 = norm_layer(dim) 177 | self.attn = WindowAttention( 178 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 179 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 180 | 181 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 182 | self.norm2 = norm_layer(dim) 183 | mlp_hidden_dim = int(dim * mlp_ratio) 184 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 185 | 186 | self.H = None 187 | self.W = None 188 | 189 | def forward(self, x, mask_matrix): 190 | """ Forward function. 191 | 192 | Args: 193 | x: Input feature, tensor size (B, H*W, C). 194 | H, W: Spatial resolution of the input feature. 195 | mask_matrix: Attention mask for cyclic shift. 196 | """ 197 | B, L, C = x.shape 198 | H, W = self.H, self.W 199 | assert L == H * W, "input feature has wrong size" 200 | 201 | shortcut = x 202 | x = self.norm1(x) 203 | x = x.view(B, H, W, C) 204 | 205 | # pad feature maps to multiples of window size 206 | pad_l = pad_t = 0 207 | pad_r = (self.window_size - W % self.window_size) % self.window_size 208 | pad_b = (self.window_size - H % self.window_size) % self.window_size 209 | x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) 210 | _, Hp, Wp, _ = x.shape 211 | 212 | # cyclic shift 213 | if self.shift_size > 0: 214 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 215 | attn_mask = mask_matrix 216 | else: 217 | shifted_x = x 218 | attn_mask = None 219 | 220 | # partition windows 221 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 222 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 223 | 224 | # W-MSA/SW-MSA 225 | attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C 226 | 227 | # merge windows 228 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 229 | shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C 230 | 231 | # reverse cyclic shift 232 | if self.shift_size > 0: 233 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 234 | else: 235 | x = shifted_x 236 | 237 | if pad_r > 0 or pad_b > 0: 238 | x = x[:, :H, :W, :].contiguous() 239 | 240 | x = x.view(B, H * W, C) 241 | 242 | # FFN 243 | x = shortcut + self.drop_path(x) 244 | x = x + self.drop_path(self.mlp(self.norm2(x))) 245 | 246 | return x 247 | 248 | 249 | class PatchMerging(nn.Module): 250 | """ Patch 
Merging Layer 251 | 252 | Args: 253 | dim (int): Number of input channels. 254 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 255 | """ 256 | def __init__(self, dim, norm_layer=nn.LayerNorm): 257 | super().__init__() 258 | self.dim = dim 259 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 260 | self.norm = norm_layer(4 * dim) 261 | 262 | def forward(self, x, H, W): 263 | """ Forward function. 264 | 265 | Args: 266 | x: Input feature, tensor size (B, H*W, C). 267 | H, W: Spatial resolution of the input feature. 268 | """ 269 | B, L, C = x.shape 270 | assert L == H * W, "input feature has wrong size" 271 | 272 | x = x.view(B, H, W, C) 273 | 274 | # padding 275 | pad_input = (H % 2 == 1) or (W % 2 == 1) 276 | if pad_input: 277 | x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) 278 | 279 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 280 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 281 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 282 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 283 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 284 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 285 | 286 | x = self.norm(x) 287 | x = self.reduction(x) 288 | 289 | return x 290 | 291 | 292 | class BasicLayer(nn.Module): 293 | """ A basic Swin Transformer layer for one stage. 294 | 295 | Args: 296 | dim (int): Number of feature channels 297 | depth (int): Depths of this stage. 298 | num_heads (int): Number of attention head. 299 | window_size (int): Local window size. Default: 7. 300 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. 301 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 302 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 303 | drop (float, optional): Dropout rate. Default: 0.0 304 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 305 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 306 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 307 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 308 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 309 | """ 310 | 311 | def __init__(self, 312 | dim, 313 | depth, 314 | num_heads, 315 | window_size=7, 316 | mlp_ratio=4., 317 | qkv_bias=True, 318 | qk_scale=None, 319 | drop=0., 320 | attn_drop=0., 321 | drop_path=0., 322 | norm_layer=nn.LayerNorm, 323 | downsample=None, 324 | use_checkpoint=False): 325 | super().__init__() 326 | self.window_size = window_size 327 | self.shift_size = window_size // 2 328 | self.depth = depth 329 | self.use_checkpoint = use_checkpoint 330 | 331 | # build blocks 332 | self.blocks = nn.ModuleList([ 333 | SwinTransformerBlock( 334 | dim=dim, 335 | num_heads=num_heads, 336 | window_size=window_size, 337 | shift_size=0 if (i % 2 == 0) else window_size // 2, 338 | mlp_ratio=mlp_ratio, 339 | qkv_bias=qkv_bias, 340 | qk_scale=qk_scale, 341 | drop=drop, 342 | attn_drop=attn_drop, 343 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 344 | norm_layer=norm_layer) 345 | for i in range(depth)]) 346 | 347 | # patch merging layer 348 | if downsample is not None: 349 | self.downsample = downsample(dim=dim, norm_layer=norm_layer) 350 | else: 351 | self.downsample = None 352 | 353 | def forward(self, x, H, W): 354 | """ Forward function. 355 | 356 | Args: 357 | x: Input feature, tensor size (B, H*W, C). 
358 | H, W: Spatial resolution of the input feature. 359 | """ 360 | 361 | # calculate attention mask for SW-MSA 362 | Hp = int(np.ceil(H / self.window_size)) * self.window_size 363 | Wp = int(np.ceil(W / self.window_size)) * self.window_size 364 | img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 365 | h_slices = (slice(0, -self.window_size), 366 | slice(-self.window_size, -self.shift_size), 367 | slice(-self.shift_size, None)) 368 | w_slices = (slice(0, -self.window_size), 369 | slice(-self.window_size, -self.shift_size), 370 | slice(-self.shift_size, None)) 371 | cnt = 0 372 | for h in h_slices: 373 | for w in w_slices: 374 | img_mask[:, h, w, :] = cnt 375 | cnt += 1 376 | 377 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 378 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 379 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 380 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 381 | 382 | for blk in self.blocks: 383 | blk.H, blk.W = H, W 384 | if self.use_checkpoint: 385 | x = checkpoint.checkpoint(blk, x, attn_mask) 386 | else: 387 | x = blk(x, attn_mask) 388 | if self.downsample is not None: 389 | x_down = self.downsample(x, H, W) 390 | Wh, Ww = (H + 1) // 2, (W + 1) // 2 391 | return x, H, W, x_down, Wh, Ww 392 | else: 393 | return x, H, W, x, H, W 394 | 395 | 396 | class PatchEmbed(nn.Module): 397 | """ Image to Patch Embedding 398 | 399 | Args: 400 | patch_size (int): Patch token size. Default: 4. 401 | in_chans (int): Number of input image channels. Default: 3. 402 | embed_dim (int): Number of linear projection output channels. Default: 96. 403 | norm_layer (nn.Module, optional): Normalization layer. Default: None 404 | """ 405 | 406 | def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 407 | super().__init__() 408 | patch_size = to_2tuple(patch_size) 409 | self.patch_size = patch_size 410 | 411 | self.in_chans = in_chans 412 | self.embed_dim = embed_dim 413 | 414 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 415 | if norm_layer is not None: 416 | self.norm = norm_layer(embed_dim) 417 | else: 418 | self.norm = None 419 | 420 | def forward(self, x): 421 | """Forward function.""" 422 | # padding 423 | _, _, H, W = x.size() 424 | if W % self.patch_size[1] != 0: 425 | x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) 426 | if H % self.patch_size[0] != 0: 427 | x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) 428 | 429 | x = self.proj(x) # B C Wh Ww 430 | if self.norm is not None: 431 | Wh, Ww = x.size(2), x.size(3) 432 | x = x.flatten(2).transpose(1, 2) 433 | x = self.norm(x) 434 | x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) 435 | 436 | return x 437 | 438 | 439 | class SwinTransformer(nn.Module): 440 | """ Swin Transformer backbone. 441 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 442 | https://arxiv.org/pdf/2103.14030 443 | 444 | Args: 445 | pretrain_img_size (int): Input image size for training the pretrained model, 446 | used in absolute postion embedding. Default 224. 447 | patch_size (int | tuple(int)): Patch size. Default: 4. 448 | in_chans (int): Number of input image channels. Default: 3. 449 | embed_dim (int): Number of linear projection output channels. Default: 96. 450 | depths (tuple[int]): Depths of each Swin Transformer stage. 
451 | num_heads (tuple[int]): Number of attention head of each stage. 452 | window_size (int): Window size. Default: 7. 453 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. 454 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 455 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. 456 | drop_rate (float): Dropout rate. 457 | attn_drop_rate (float): Attention dropout rate. Default: 0. 458 | drop_path_rate (float): Stochastic depth rate. Default: 0.2. 459 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 460 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. 461 | patch_norm (bool): If True, add normalization after patch embedding. Default: True. 462 | out_indices (Sequence[int]): Output from which stages. 463 | frozen_stages (int): Stages to be frozen (stop grad and set eval mode). 464 | -1 means not freezing any parameters. 465 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 466 | """ 467 | 468 | def __init__(self, 469 | pretrain_img_size=224, 470 | patch_size=4, 471 | in_chans=3, 472 | embed_dim=96, 473 | depths=[2, 2, 6, 2], 474 | num_heads=[3, 6, 12, 24], 475 | window_size=7, 476 | mlp_ratio=4., 477 | qkv_bias=True, 478 | qk_scale=None, 479 | drop_rate=0., 480 | attn_drop_rate=0., 481 | drop_path_rate=0.2, 482 | norm_layer=nn.LayerNorm, 483 | ape=False, 484 | patch_norm=True, 485 | out_indices=(0, 1, 2, 3), 486 | frozen_stages=-1, 487 | use_checkpoint=False): 488 | super().__init__() 489 | 490 | self.pretrain_img_size = pretrain_img_size 491 | self.num_layers = len(depths) 492 | self.embed_dim = embed_dim 493 | self.ape = ape 494 | self.patch_norm = patch_norm 495 | self.out_indices = out_indices 496 | self.frozen_stages = frozen_stages 497 | 498 | # split image into non-overlapping patches 499 | self.patch_embed = PatchEmbed( 500 | patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 501 | norm_layer=norm_layer if self.patch_norm else None) 502 | 503 | # absolute position embedding 504 | if self.ape: 505 | pretrain_img_size = to_2tuple(pretrain_img_size) 506 | patch_size = to_2tuple(patch_size) 507 | patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]] 508 | 509 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])) 510 | trunc_normal_(self.absolute_pos_embed, std=.02) 511 | 512 | self.pos_drop = nn.Dropout(p=drop_rate) 513 | 514 | # stochastic depth 515 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 516 | 517 | # build layers 518 | self.layers = nn.ModuleList() 519 | for i_layer in range(self.num_layers): 520 | layer = BasicLayer( 521 | dim=int(embed_dim * 2 ** i_layer), 522 | depth=depths[i_layer], 523 | num_heads=num_heads[i_layer], 524 | window_size=window_size, 525 | mlp_ratio=mlp_ratio, 526 | qkv_bias=qkv_bias, 527 | qk_scale=qk_scale, 528 | drop=drop_rate, 529 | attn_drop=attn_drop_rate, 530 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 531 | norm_layer=norm_layer, 532 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 533 | use_checkpoint=use_checkpoint) 534 | self.layers.append(layer) 535 | 536 | num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] 537 | self.num_features = num_features 538 | 539 | # add a norm layer for each output 540 | for i_layer in 
out_indices: 541 | layer = norm_layer(num_features[i_layer]) 542 | layer_name = f'norm{i_layer}' 543 | self.add_module(layer_name, layer) 544 | 545 | self._freeze_stages() 546 | 547 | def _freeze_stages(self): 548 | if self.frozen_stages >= 0: 549 | self.patch_embed.eval() 550 | for param in self.patch_embed.parameters(): 551 | param.requires_grad = False 552 | 553 | if self.frozen_stages >= 1 and self.ape: 554 | self.absolute_pos_embed.requires_grad = False 555 | 556 | if self.frozen_stages >= 2: 557 | self.pos_drop.eval() 558 | for i in range(0, self.frozen_stages - 1): 559 | m = self.layers[i] 560 | m.eval() 561 | for param in m.parameters(): 562 | param.requires_grad = False 563 | 564 | def init_weights(self, pretrained=None): 565 | """Initialize the weights in backbone. 566 | 567 | Args: 568 | pretrained (str, optional): Path to pre-trained weights. 569 | Defaults to None. 570 | """ 571 | 572 | def _init_weights(m): 573 | if isinstance(m, nn.Linear): 574 | trunc_normal_(m.weight, std=.02) 575 | if isinstance(m, nn.Linear) and m.bias is not None: 576 | nn.init.constant_(m.bias, 0) 577 | elif isinstance(m, nn.LayerNorm): 578 | nn.init.constant_(m.bias, 0) 579 | nn.init.constant_(m.weight, 1.0) 580 | 581 | if isinstance(pretrained, str): 582 | self.apply(_init_weights) 583 | # logger = get_root_logger() 584 | load_checkpoint(self, pretrained, strict=False) 585 | elif pretrained is None: 586 | self.apply(_init_weights) 587 | else: 588 | raise TypeError('pretrained must be a str or None') 589 | 590 | def forward(self, x): 591 | """Forward function.""" 592 | x = self.patch_embed(x) 593 | 594 | Wh, Ww = x.size(2), x.size(3) 595 | if self.ape: 596 | # interpolate the position embedding to the corresponding size 597 | absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic') 598 | x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C 599 | else: 600 | x = x.flatten(2).transpose(1, 2) 601 | x = self.pos_drop(x) 602 | 603 | outs = [] 604 | for i in range(self.num_layers): 605 | layer = self.layers[i] 606 | x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) 607 | 608 | if i in self.out_indices: 609 | norm_layer = getattr(self, f'norm{i}') 610 | x_out = norm_layer(x_out) 611 | 612 | out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() 613 | outs.append(out) 614 | 615 | return tuple(outs) 616 | 617 | def train(self, mode=True): 618 | """Convert the model into training mode while keep layers freezed.""" 619 | super(SwinTransformer, self).train(mode) 620 | self._freeze_stages() 621 | -------------------------------------------------------------------------------- /newcrfs/networks/uper_crf_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from mmcv.cnn import ConvModule 6 | from .newcrf_utils import resize, normal_init 7 | 8 | 9 | class PPM(nn.ModuleList): 10 | """Pooling Pyramid Module used in PSPNet. 11 | 12 | Args: 13 | pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid 14 | Module. 15 | in_channels (int): Input channels. 16 | channels (int): Channels after modules, before conv_seg. 17 | conv_cfg (dict|None): Config of conv layers. 18 | norm_cfg (dict|None): Config of norm layers. 19 | act_cfg (dict): Config of activation layers. 20 | align_corners (bool): align_corners argument of F.interpolate. 
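For orientation before the decode heads below, here is an illustrative sketch (not part of the repository) of the four-level feature pyramid the Swin backbone above produces. The configuration shown is a Swin-T-sized choice with random weights and a 224x224 input; it assumes `timm` is installed and that the `newcrfs/` directory is the working directory, matching the `networks.*` import style used by `train.py`.

```
import torch
from networks.swin_transformer import SwinTransformer

backbone = SwinTransformer(embed_dim=96, depths=[2, 2, 6, 2],
                           num_heads=[3, 6, 12, 24], window_size=7)
backbone.init_weights(pretrained=None)   # random init; pass a Swin checkpoint path to load pretrained weights
backbone.eval()

with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224))
print([tuple(f.shape) for f in feats])
# [(1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)] -- strides 4/8/16/32
```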
21 | """ 22 | 23 | def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg, 24 | act_cfg, align_corners): 25 | super(PPM, self).__init__() 26 | self.pool_scales = pool_scales 27 | self.align_corners = align_corners 28 | self.in_channels = in_channels 29 | self.channels = channels 30 | self.conv_cfg = conv_cfg 31 | self.norm_cfg = norm_cfg 32 | self.act_cfg = act_cfg 33 | for pool_scale in pool_scales: 34 | # == if batch size = 1, BN is not supported, change to GN 35 | if pool_scale == 1: norm_cfg = dict(type='GN', requires_grad=True, num_groups=256) 36 | self.append( 37 | nn.Sequential( 38 | nn.AdaptiveAvgPool2d(pool_scale), 39 | ConvModule( 40 | self.in_channels, 41 | self.channels, 42 | 1, 43 | conv_cfg=self.conv_cfg, 44 | norm_cfg=norm_cfg, 45 | act_cfg=self.act_cfg))) 46 | 47 | def forward(self, x): 48 | """Forward function.""" 49 | ppm_outs = [] 50 | for ppm in self: 51 | ppm_out = ppm(x) 52 | upsampled_ppm_out = resize( 53 | ppm_out, 54 | size=x.size()[2:], 55 | mode='bilinear', 56 | align_corners=self.align_corners) 57 | ppm_outs.append(upsampled_ppm_out) 58 | return ppm_outs 59 | 60 | 61 | class BaseDecodeHead(nn.Module): 62 | """Base class for BaseDecodeHead. 63 | 64 | Args: 65 | in_channels (int|Sequence[int]): Input channels. 66 | channels (int): Channels after modules, before conv_seg. 67 | num_classes (int): Number of classes. 68 | dropout_ratio (float): Ratio of dropout layer. Default: 0.1. 69 | conv_cfg (dict|None): Config of conv layers. Default: None. 70 | norm_cfg (dict|None): Config of norm layers. Default: None. 71 | act_cfg (dict): Config of activation layers. 72 | Default: dict(type='ReLU') 73 | in_index (int|Sequence[int]): Input feature index. Default: -1 74 | input_transform (str|None): Transformation type of input features. 75 | Options: 'resize_concat', 'multiple_select', None. 76 | 'resize_concat': Multiple feature maps will be resize to the 77 | same size as first one and than concat together. 78 | Usually used in FCN head of HRNet. 79 | 'multiple_select': Multiple feature maps will be bundle into 80 | a list and passed into decode head. 81 | None: Only one select feature map is allowed. 82 | Default: None. 83 | loss_decode (dict): Config of decode loss. 84 | Default: dict(type='CrossEntropyLoss'). 85 | ignore_index (int | None): The label index to be ignored. When using 86 | masked BCE loss, ignore_index should be set to None. Default: 255 87 | sampler (dict|None): The config of segmentation map sampler. 88 | Default: None. 89 | align_corners (bool): align_corners argument of F.interpolate. 90 | Default: False. 
91 | """ 92 | 93 | def __init__(self, 94 | in_channels, 95 | channels, 96 | *, 97 | num_classes, 98 | dropout_ratio=0.1, 99 | conv_cfg=None, 100 | norm_cfg=None, 101 | act_cfg=dict(type='ReLU'), 102 | in_index=-1, 103 | input_transform=None, 104 | loss_decode=dict( 105 | type='CrossEntropyLoss', 106 | use_sigmoid=False, 107 | loss_weight=1.0), 108 | ignore_index=255, 109 | sampler=None, 110 | align_corners=False): 111 | super(BaseDecodeHead, self).__init__() 112 | self._init_inputs(in_channels, in_index, input_transform) 113 | self.channels = channels 114 | self.num_classes = num_classes 115 | self.dropout_ratio = dropout_ratio 116 | self.conv_cfg = conv_cfg 117 | self.norm_cfg = norm_cfg 118 | self.act_cfg = act_cfg 119 | self.in_index = in_index 120 | # self.loss_decode = build_loss(loss_decode) 121 | self.ignore_index = ignore_index 122 | self.align_corners = align_corners 123 | # if sampler is not None: 124 | # self.sampler = build_pixel_sampler(sampler, context=self) 125 | # else: 126 | # self.sampler = None 127 | 128 | # self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1) 129 | # self.conv1 = nn.Conv2d(channels, num_classes, 3, padding=1) 130 | if dropout_ratio > 0: 131 | self.dropout = nn.Dropout2d(dropout_ratio) 132 | else: 133 | self.dropout = None 134 | self.fp16_enabled = False 135 | 136 | def extra_repr(self): 137 | """Extra repr.""" 138 | s = f'input_transform={self.input_transform}, ' \ 139 | f'ignore_index={self.ignore_index}, ' \ 140 | f'align_corners={self.align_corners}' 141 | return s 142 | 143 | def _init_inputs(self, in_channels, in_index, input_transform): 144 | """Check and initialize input transforms. 145 | 146 | The in_channels, in_index and input_transform must match. 147 | Specifically, when input_transform is None, only single feature map 148 | will be selected. So in_channels and in_index must be of type int. 149 | When input_transform 150 | 151 | Args: 152 | in_channels (int|Sequence[int]): Input channels. 153 | in_index (int|Sequence[int]): Input feature index. 154 | input_transform (str|None): Transformation type of input features. 155 | Options: 'resize_concat', 'multiple_select', None. 156 | 'resize_concat': Multiple feature maps will be resize to the 157 | same size as first one and than concat together. 158 | Usually used in FCN head of HRNet. 159 | 'multiple_select': Multiple feature maps will be bundle into 160 | a list and passed into decode head. 161 | None: Only one select feature map is allowed. 162 | """ 163 | 164 | if input_transform is not None: 165 | assert input_transform in ['resize_concat', 'multiple_select'] 166 | self.input_transform = input_transform 167 | self.in_index = in_index 168 | if input_transform is not None: 169 | assert isinstance(in_channels, (list, tuple)) 170 | assert isinstance(in_index, (list, tuple)) 171 | assert len(in_channels) == len(in_index) 172 | if input_transform == 'resize_concat': 173 | self.in_channels = sum(in_channels) 174 | else: 175 | self.in_channels = in_channels 176 | else: 177 | assert isinstance(in_channels, int) 178 | assert isinstance(in_index, int) 179 | self.in_channels = in_channels 180 | 181 | def init_weights(self): 182 | """Initialize weights of classification layer.""" 183 | # normal_init(self.conv_seg, mean=0, std=0.01) 184 | # normal_init(self.conv1, mean=0, std=0.01) 185 | 186 | def _transform_inputs(self, inputs): 187 | """Transform inputs for decoder. 188 | 189 | Args: 190 | inputs (list[Tensor]): List of multi-level img features. 
191 | 192 | Returns: 193 | Tensor: The transformed inputs 194 | """ 195 | 196 | if self.input_transform == 'resize_concat': 197 | inputs = [inputs[i] for i in self.in_index] 198 | upsampled_inputs = [ 199 | resize( 200 | input=x, 201 | size=inputs[0].shape[2:], 202 | mode='bilinear', 203 | align_corners=self.align_corners) for x in inputs 204 | ] 205 | inputs = torch.cat(upsampled_inputs, dim=1) 206 | elif self.input_transform == 'multiple_select': 207 | inputs = [inputs[i] for i in self.in_index] 208 | else: 209 | inputs = inputs[self.in_index] 210 | 211 | return inputs 212 | 213 | def forward(self, inputs): 214 | """Placeholder of forward function.""" 215 | pass 216 | 217 | def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg): 218 | """Forward function for training. 219 | Args: 220 | inputs (list[Tensor]): List of multi-level img features. 221 | img_metas (list[dict]): List of image info dict where each dict 222 | has: 'img_shape', 'scale_factor', 'flip', and may also contain 223 | 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 224 | For details on the values of these keys see 225 | `mmseg/datasets/pipelines/formatting.py:Collect`. 226 | gt_semantic_seg (Tensor): Semantic segmentation masks 227 | used if the architecture supports semantic segmentation task. 228 | train_cfg (dict): The training config. 229 | 230 | Returns: 231 | dict[str, Tensor]: a dictionary of loss components 232 | """ 233 | seg_logits = self.forward(inputs) 234 | losses = self.losses(seg_logits, gt_semantic_seg) 235 | return losses 236 | 237 | def forward_test(self, inputs, img_metas, test_cfg): 238 | """Forward function for testing. 239 | 240 | Args: 241 | inputs (list[Tensor]): List of multi-level img features. 242 | img_metas (list[dict]): List of image info dict where each dict 243 | has: 'img_shape', 'scale_factor', 'flip', and may also contain 244 | 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 245 | For details on the values of these keys see 246 | `mmseg/datasets/pipelines/formatting.py:Collect`. 247 | test_cfg (dict): The testing config. 248 | 249 | Returns: 250 | Tensor: Output segmentation map. 
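The two `input_transform` modes handled above can be checked directly. The following sketch is illustrative only (the channel counts, `num_classes`, and indices are assumptions): with `'resize_concat'`, every selected level is upsampled to the first one's resolution and the channels are concatenated.

```
import torch
from networks.uper_crf_head import BaseDecodeHead

head = BaseDecodeHead([96, 192], 128, num_classes=1,
                      in_index=[0, 1], input_transform='resize_concat')
feats = [torch.randn(1, 96, 56, 56), torch.randn(1, 192, 28, 28)]
print(tuple(head._transform_inputs(feats).shape))   # (1, 288, 56, 56)
```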
251 | """ 252 | return self.forward(inputs) 253 | 254 | 255 | class UPerHead(BaseDecodeHead): 256 | def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): 257 | super(UPerHead, self).__init__( 258 | input_transform='multiple_select', **kwargs) 259 | # FPN Module 260 | self.lateral_convs = nn.ModuleList() 261 | self.fpn_convs = nn.ModuleList() 262 | for in_channels in self.in_channels: # skip the top layer 263 | l_conv = ConvModule( 264 | in_channels, 265 | self.channels, 266 | 1, 267 | conv_cfg=self.conv_cfg, 268 | norm_cfg=self.norm_cfg, 269 | act_cfg=self.act_cfg, 270 | inplace=True) 271 | fpn_conv = ConvModule( 272 | self.channels, 273 | self.channels, 274 | 3, 275 | padding=1, 276 | conv_cfg=self.conv_cfg, 277 | norm_cfg=self.norm_cfg, 278 | act_cfg=self.act_cfg, 279 | inplace=True) 280 | self.lateral_convs.append(l_conv) 281 | self.fpn_convs.append(fpn_conv) 282 | 283 | def forward(self, inputs): 284 | """Forward function.""" 285 | 286 | inputs = self._transform_inputs(inputs) 287 | 288 | # build laterals 289 | laterals = [ 290 | lateral_conv(inputs[i]) 291 | for i, lateral_conv in enumerate(self.lateral_convs) 292 | ] 293 | 294 | # laterals.append(self.psp_forward(inputs)) 295 | 296 | # build top-down path 297 | used_backbone_levels = len(laterals) 298 | for i in range(used_backbone_levels - 1, 0, -1): 299 | prev_shape = laterals[i - 1].shape[2:] 300 | laterals[i - 1] += resize( 301 | laterals[i], 302 | size=prev_shape, 303 | mode='bilinear', 304 | align_corners=self.align_corners) 305 | 306 | # build outputs 307 | fpn_outs = [ 308 | self.fpn_convs[i](laterals[i]) 309 | for i in range(used_backbone_levels - 1) 310 | ] 311 | # append psp feature 312 | fpn_outs.append(laterals[-1]) 313 | 314 | return fpn_outs[0] 315 | 316 | 317 | 318 | class PSP(BaseDecodeHead): 319 | """Unified Perceptual Parsing for Scene Understanding. 320 | 321 | This head is the implementation of `UPerNet 322 | `_. 323 | 324 | Args: 325 | pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid 326 | Module applied on the last feature. Default: (1, 2, 3, 6). 
327 | """ 328 | 329 | def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): 330 | super(PSP, self).__init__( 331 | input_transform='multiple_select', **kwargs) 332 | # PSP Module 333 | self.psp_modules = PPM( 334 | pool_scales, 335 | self.in_channels[-1], 336 | self.channels, 337 | conv_cfg=self.conv_cfg, 338 | norm_cfg=self.norm_cfg, 339 | act_cfg=self.act_cfg, 340 | align_corners=self.align_corners) 341 | self.bottleneck = ConvModule( 342 | self.in_channels[-1] + len(pool_scales) * self.channels, 343 | self.channels, 344 | 3, 345 | padding=1, 346 | conv_cfg=self.conv_cfg, 347 | norm_cfg=self.norm_cfg, 348 | act_cfg=self.act_cfg) 349 | 350 | def psp_forward(self, inputs): 351 | """Forward function of PSP module.""" 352 | x = inputs[-1] 353 | psp_outs = [x] 354 | psp_outs.extend(self.psp_modules(x)) 355 | psp_outs = torch.cat(psp_outs, dim=1) 356 | output = self.bottleneck(psp_outs) 357 | 358 | return output 359 | 360 | def forward(self, inputs): 361 | """Forward function.""" 362 | inputs = self._transform_inputs(inputs) 363 | 364 | return self.psp_forward(inputs) 365 | -------------------------------------------------------------------------------- /newcrfs/test.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | 7 | import os, sys, errno 8 | import argparse 9 | import time 10 | import numpy as np 11 | import cv2 12 | import matplotlib.pyplot as plt 13 | from tqdm import tqdm 14 | 15 | from utils import post_process_depth, flip_lr 16 | from networks.NewCRFDepth import NewCRFDepth 17 | 18 | 19 | def convert_arg_line_to_args(arg_line): 20 | for arg in arg_line.split(): 21 | if not arg.strip(): 22 | continue 23 | yield arg 24 | 25 | 26 | parser = argparse.ArgumentParser(description='NeWCRFs PyTorch implementation.', fromfile_prefix_chars='@') 27 | parser.convert_arg_line_to_args = convert_arg_line_to_args 28 | 29 | parser.add_argument('--model_name', type=str, help='model name', default='newcrfs') 30 | parser.add_argument('--encoder', type=str, help='type of encoder, base07, large07', default='large07') 31 | parser.add_argument('--data_path', type=str, help='path to the data', required=True) 32 | parser.add_argument('--filenames_file', type=str, help='path to the filenames text file', required=True) 33 | parser.add_argument('--input_height', type=int, help='input height', default=480) 34 | parser.add_argument('--input_width', type=int, help='input width', default=640) 35 | parser.add_argument('--max_depth', type=float, help='maximum depth in estimation', default=10) 36 | parser.add_argument('--checkpoint_path', type=str, help='path to a specific checkpoint to load', default='') 37 | parser.add_argument('--dataset', type=str, help='dataset to train on', default='nyu') 38 | parser.add_argument('--do_kb_crop', help='if set, crop input images as kitti benchmark images', action='store_true') 39 | parser.add_argument('--save_viz', help='if set, save visulization of the outputs', action='store_true') 40 | 41 | if sys.argv.__len__() == 2: 42 | arg_filename_with_prefix = '@' + sys.argv[1] 43 | args = parser.parse_args([arg_filename_with_prefix]) 44 | else: 45 | args = parser.parse_args() 46 | 47 | if args.dataset == 'kitti' or args.dataset == 'nyu': 48 | from dataloaders.dataloader import NewDataLoader 49 | elif args.dataset == 'kittipred': 50 | from dataloaders.dataloader_kittipred import NewDataLoader 51 | 
52 | model_dir = os.path.dirname(args.checkpoint_path) 53 | sys.path.append(model_dir) 54 | 55 | 56 | def get_num_lines(file_path): 57 | f = open(file_path, 'r') 58 | lines = f.readlines() 59 | f.close() 60 | return len(lines) 61 | 62 | 63 | def test(params): 64 | """Test function.""" 65 | args.mode = 'test' 66 | dataloader = NewDataLoader(args, 'test') 67 | 68 | model = NewCRFDepth(version='large07', inv_depth=False, max_depth=args.max_depth) 69 | model = torch.nn.DataParallel(model) 70 | 71 | checkpoint = torch.load(args.checkpoint_path) 72 | model.load_state_dict(checkpoint['model']) 73 | model.eval() 74 | model.cuda() 75 | 76 | num_params = sum([np.prod(p.size()) for p in model.parameters()]) 77 | print("Total number of parameters: {}".format(num_params)) 78 | 79 | num_test_samples = get_num_lines(args.filenames_file) 80 | 81 | with open(args.filenames_file) as f: 82 | lines = f.readlines() 83 | 84 | print('now testing {} files with {}'.format(num_test_samples, args.checkpoint_path)) 85 | 86 | pred_depths = [] 87 | start_time = time.time() 88 | with torch.no_grad(): 89 | for _, sample in enumerate(tqdm(dataloader.data)): 90 | image = Variable(sample['image'].cuda()) 91 | # Predict 92 | depth_est = model(image) 93 | post_process = True 94 | if post_process: 95 | image_flipped = flip_lr(image) 96 | depth_est_flipped = model(image_flipped) 97 | depth_est = post_process_depth(depth_est, depth_est_flipped) 98 | 99 | pred_depth = depth_est.cpu().numpy().squeeze() 100 | 101 | if args.do_kb_crop: 102 | height, width = 352, 1216 103 | top_margin = int(height - 352) 104 | left_margin = int((width - 1216) / 2) 105 | pred_depth_uncropped = np.zeros((height, width), dtype=np.float32) 106 | pred_depth_uncropped[top_margin:top_margin + 352, left_margin:left_margin + 1216] = pred_depth 107 | pred_depth = pred_depth_uncropped 108 | 109 | pred_depths.append(pred_depth) 110 | 111 | elapsed_time = time.time() - start_time 112 | print('Elapesed time: %s' % str(elapsed_time)) 113 | print('Done.') 114 | 115 | save_name = 'models/result_' + args.model_name 116 | 117 | print('Saving result pngs..') 118 | if not os.path.exists(save_name): 119 | try: 120 | os.mkdir(save_name) 121 | os.mkdir(save_name + '/raw') 122 | os.mkdir(save_name + '/cmap') 123 | os.mkdir(save_name + '/rgb') 124 | os.mkdir(save_name + '/gt') 125 | except OSError as e: 126 | if e.errno != errno.EEXIST: 127 | raise 128 | 129 | for s in tqdm(range(num_test_samples)): 130 | if args.dataset == 'kitti': 131 | date_drive = lines[s].split('/')[1] 132 | filename_pred_png = save_name + '/raw/' + date_drive + '_' + lines[s].split()[0].split('/')[-1].replace( 133 | '.jpg', '.png') 134 | filename_cmap_png = save_name + '/cmap/' + date_drive + '_' + lines[s].split()[0].split('/')[ 135 | -1].replace('.jpg', '.png') 136 | filename_image_png = save_name + '/rgb/' + date_drive + '_' + lines[s].split()[0].split('/')[-1] 137 | elif args.dataset == 'kittipred': 138 | filename_pred_png = save_name + '/raw/' + lines[s].split()[0].split('/')[-1].replace('.jpg', '.png') 139 | filename_cmap_png = save_name + '/cmap/' + lines[s].split()[0].split('/')[-1].replace('.jpg', '.png') 140 | filename_image_png = save_name + '/rgb/' + lines[s].split()[0].split('/')[-1] 141 | else: 142 | scene_name = lines[s].split()[0].split('/')[0] 143 | filename_pred_png = save_name + '/raw/' + scene_name + '_' + lines[s].split()[0].split('/')[1].replace( 144 | '.jpg', '.png') 145 | filename_cmap_png = save_name + '/cmap/' + scene_name + '_' + lines[s].split()[0].split('/rgb_')[1].replace( 
146 | '.jpg', '.png') 147 | filename_gt_png = save_name + '/gt/' + scene_name + '_' + lines[s].split()[0].split('/rgb_')[1].replace( 148 | '.jpg', '_gt.png') 149 | filename_image_png = save_name + '/rgb/' + scene_name + '_' + lines[s].split()[0].split('/rgb_')[1] 150 | 151 | rgb_path = os.path.join(args.data_path, './' + lines[s].split()[0]) 152 | image = cv2.imread(rgb_path) 153 | if args.dataset == 'nyu': 154 | gt_path = os.path.join(args.data_path, './' + lines[s].split()[1]) 155 | gt = cv2.imread(gt_path, -1).astype(np.float32) / 1000.0 # Visualization purpose only 156 | gt[gt == 0] = np.amax(gt) 157 | 158 | pred_depth = pred_depths[s] 159 | 160 | if args.dataset == 'kitti' or args.dataset == 'kittipred': 161 | pred_depth_scaled = pred_depth * 256.0 162 | else: 163 | pred_depth_scaled = pred_depth * 1000.0 164 | 165 | pred_depth_scaled = pred_depth_scaled.astype(np.uint16) 166 | cv2.imwrite(filename_pred_png, pred_depth_scaled, [cv2.IMWRITE_PNG_COMPRESSION, 0]) 167 | 168 | if args.save_viz: 169 | cv2.imwrite(filename_image_png, image[10:-1 - 9, 10:-1 - 9, :]) 170 | if args.dataset == 'nyu': 171 | plt.imsave(filename_gt_png, (10 - gt) / 10, cmap='plasma') 172 | pred_depth_cropped = pred_depth[10:-1 - 9, 10:-1 - 9] 173 | plt.imsave(filename_cmap_png, (10 - pred_depth) / 10, cmap='plasma') 174 | else: 175 | plt.imsave(filename_cmap_png, np.log10(pred_depth), cmap='Greys') 176 | 177 | return 178 | 179 | 180 | if __name__ == '__main__': 181 | test(args) 182 | -------------------------------------------------------------------------------- /newcrfs/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.utils as utils 4 | import torch.backends.cudnn as cudnn 5 | import torch.distributed as dist 6 | import torch.multiprocessing as mp 7 | 8 | import os, sys, time 9 | from telnetlib import IP 10 | import argparse 11 | import numpy as np 12 | from tqdm import tqdm 13 | 14 | from tensorboardX import SummaryWriter 15 | 16 | from utils import post_process_depth, flip_lr, silog_loss, compute_errors, eval_metrics, \ 17 | block_print, enable_print, normalize_result, inv_normalize, convert_arg_line_to_args 18 | from networks.NewCRFDepth import NewCRFDepth 19 | 20 | 21 | parser = argparse.ArgumentParser(description='NeWCRFs PyTorch implementation.', fromfile_prefix_chars='@') 22 | parser.convert_arg_line_to_args = convert_arg_line_to_args 23 | 24 | parser.add_argument('--mode', type=str, help='train or test', default='train') 25 | parser.add_argument('--model_name', type=str, help='model name', default='newcrfs') 26 | parser.add_argument('--encoder', type=str, help='type of encoder, base07, large07', default='large07') 27 | parser.add_argument('--pretrain', type=str, help='path of pretrained encoder', default=None) 28 | 29 | # Dataset 30 | parser.add_argument('--dataset', type=str, help='dataset to train on, kitti or nyu', default='nyu') 31 | parser.add_argument('--data_path', type=str, help='path to the data', required=True) 32 | parser.add_argument('--gt_path', type=str, help='path to the groundtruth data', required=True) 33 | parser.add_argument('--filenames_file', type=str, help='path to the filenames text file', required=True) 34 | parser.add_argument('--input_height', type=int, help='input height', default=480) 35 | parser.add_argument('--input_width', type=int, help='input width', default=640) 36 | parser.add_argument('--max_depth', type=float, help='maximum depth in estimation', default=10) 37 | 38 
| # Log and save 39 | parser.add_argument('--log_directory', type=str, help='directory to save checkpoints and summaries', default='') 40 | parser.add_argument('--checkpoint_path', type=str, help='path to a checkpoint to load', default='') 41 | parser.add_argument('--log_freq', type=int, help='Logging frequency in global steps', default=100) 42 | parser.add_argument('--save_freq', type=int, help='Checkpoint saving frequency in global steps', default=5000) 43 | 44 | # Training 45 | parser.add_argument('--weight_decay', type=float, help='weight decay factor for optimization', default=1e-2) 46 | parser.add_argument('--retrain', help='if used with checkpoint_path, will restart training from step zero', action='store_true') 47 | parser.add_argument('--adam_eps', type=float, help='epsilon in Adam optimizer', default=1e-6) 48 | parser.add_argument('--batch_size', type=int, help='batch size', default=4) 49 | parser.add_argument('--num_epochs', type=int, help='number of epochs', default=50) 50 | parser.add_argument('--learning_rate', type=float, help='initial learning rate', default=1e-4) 51 | parser.add_argument('--end_learning_rate', type=float, help='end learning rate', default=-1) 52 | parser.add_argument('--variance_focus', type=float, help='lambda in paper: [0, 1], higher value more focus on minimizing variance of error', default=0.85) 53 | 54 | # Preprocessing 55 | parser.add_argument('--do_random_rotate', help='if set, will perform random rotation for augmentation', action='store_true') 56 | parser.add_argument('--degree', type=float, help='random rotation maximum degree', default=2.5) 57 | parser.add_argument('--do_kb_crop', help='if set, crop input images as kitti benchmark images', action='store_true') 58 | parser.add_argument('--use_right', help='if set, will randomly use right images when train on KITTI', action='store_true') 59 | 60 | # Multi-gpu training 61 | parser.add_argument('--num_threads', type=int, help='number of threads to use for data loading', default=1) 62 | parser.add_argument('--world_size', type=int, help='number of nodes for distributed training', default=1) 63 | parser.add_argument('--rank', type=int, help='node rank for distributed training', default=0) 64 | parser.add_argument('--dist_url', type=str, help='url used to set up distributed training', default='tcp://127.0.0.1:1234') 65 | parser.add_argument('--dist_backend', type=str, help='distributed backend', default='nccl') 66 | parser.add_argument('--gpu', type=int, help='GPU id to use.', default=None) 67 | parser.add_argument('--multiprocessing_distributed', help='Use multi-processing distributed training to launch ' 68 | 'N processes per node, which has N GPUs. 
This is the ' 69 | 'fastest way to use PyTorch for either single node or ' 70 | 'multi node data parallel training', action='store_true',) 71 | # Online eval 72 | parser.add_argument('--do_online_eval', help='if set, perform online eval in every eval_freq steps', action='store_true') 73 | parser.add_argument('--data_path_eval', type=str, help='path to the data for online evaluation', required=False) 74 | parser.add_argument('--gt_path_eval', type=str, help='path to the groundtruth data for online evaluation', required=False) 75 | parser.add_argument('--filenames_file_eval', type=str, help='path to the filenames text file for online evaluation', required=False) 76 | parser.add_argument('--min_depth_eval', type=float, help='minimum depth for evaluation', default=1e-3) 77 | parser.add_argument('--max_depth_eval', type=float, help='maximum depth for evaluation', default=80) 78 | parser.add_argument('--eigen_crop', help='if set, crops according to Eigen NIPS14', action='store_true') 79 | parser.add_argument('--garg_crop', help='if set, crops according to Garg ECCV16', action='store_true') 80 | parser.add_argument('--eval_freq', type=int, help='Online evaluation frequency in global steps', default=500) 81 | parser.add_argument('--eval_summary_directory', type=str, help='output directory for eval summary,' 82 | 'if empty outputs to checkpoint folder', default='') 83 | 84 | if sys.argv.__len__() == 2: 85 | arg_filename_with_prefix = '@' + sys.argv[1] 86 | args = parser.parse_args([arg_filename_with_prefix]) 87 | else: 88 | args = parser.parse_args() 89 | 90 | if args.dataset == 'kitti' or args.dataset == 'nyu': 91 | from dataloaders.dataloader import NewDataLoader 92 | elif args.dataset == 'kittipred': 93 | from dataloaders.dataloader_kittipred import NewDataLoader 94 | 95 | 96 | def online_eval(model, dataloader_eval, gpu, ngpus, post_process=False): 97 | eval_measures = torch.zeros(10).cuda(device=gpu) 98 | for _, eval_sample_batched in enumerate(tqdm(dataloader_eval.data)): 99 | with torch.no_grad(): 100 | image = torch.autograd.Variable(eval_sample_batched['image'].cuda(gpu, non_blocking=True)) 101 | gt_depth = eval_sample_batched['depth'] 102 | has_valid_depth = eval_sample_batched['has_valid_depth'] 103 | if not has_valid_depth: 104 | # print('Invalid depth. 
continue.') 105 | continue 106 | 107 | pred_depth = model(image) 108 | if post_process: 109 | image_flipped = flip_lr(image) 110 | pred_depth_flipped = model(image_flipped) 111 | pred_depth = post_process_depth(pred_depth, pred_depth_flipped) 112 | 113 | pred_depth = pred_depth.cpu().numpy().squeeze() 114 | gt_depth = gt_depth.cpu().numpy().squeeze() 115 | 116 | if args.do_kb_crop: 117 | height, width = gt_depth.shape 118 | top_margin = int(height - 352) 119 | left_margin = int((width - 1216) / 2) 120 | pred_depth_uncropped = np.zeros((height, width), dtype=np.float32) 121 | pred_depth_uncropped[top_margin:top_margin + 352, left_margin:left_margin + 1216] = pred_depth 122 | pred_depth = pred_depth_uncropped 123 | 124 | pred_depth[pred_depth < args.min_depth_eval] = args.min_depth_eval 125 | pred_depth[pred_depth > args.max_depth_eval] = args.max_depth_eval 126 | pred_depth[np.isinf(pred_depth)] = args.max_depth_eval 127 | pred_depth[np.isnan(pred_depth)] = args.min_depth_eval 128 | 129 | valid_mask = np.logical_and(gt_depth > args.min_depth_eval, gt_depth < args.max_depth_eval) 130 | 131 | if args.garg_crop or args.eigen_crop: 132 | gt_height, gt_width = gt_depth.shape 133 | eval_mask = np.zeros(valid_mask.shape) 134 | 135 | if args.garg_crop: 136 | eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height), int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1 137 | 138 | elif args.eigen_crop: 139 | if args.dataset == 'kitti': 140 | eval_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height), int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = 1 141 | elif args.dataset == 'nyu': 142 | eval_mask[45:471, 41:601] = 1 143 | 144 | valid_mask = np.logical_and(valid_mask, eval_mask) 145 | 146 | measures = compute_errors(gt_depth[valid_mask], pred_depth[valid_mask]) 147 | 148 | eval_measures[:9] += torch.tensor(measures).cuda(device=gpu) 149 | eval_measures[9] += 1 150 | 151 | if args.multiprocessing_distributed: 152 | group = dist.new_group([i for i in range(ngpus)]) 153 | dist.all_reduce(tensor=eval_measures, op=dist.ReduceOp.SUM, group=group) 154 | 155 | if not args.multiprocessing_distributed or gpu == 0: 156 | eval_measures_cpu = eval_measures.cpu() 157 | cnt = eval_measures_cpu[9].item() 158 | eval_measures_cpu /= cnt 159 | print('Computing errors for {} eval samples'.format(int(cnt)), ', post_process: ', post_process) 160 | print("{:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}".format('silog', 'abs_rel', 'log10', 'rms', 161 | 'sq_rel', 'log_rms', 'd1', 'd2', 162 | 'd3')) 163 | for i in range(8): 164 | print('{:7.4f}, '.format(eval_measures_cpu[i]), end='') 165 | print('{:7.4f}'.format(eval_measures_cpu[8])) 166 | return eval_measures_cpu 167 | 168 | return None 169 | 170 | 171 | def main_worker(gpu, ngpus_per_node, args): 172 | args.gpu = gpu 173 | 174 | if args.gpu is not None: 175 | print("== Use GPU: {} for training".format(args.gpu)) 176 | 177 | if args.distributed: 178 | if args.dist_url == "env://" and args.rank == -1: 179 | args.rank = int(os.environ["RANK"]) 180 | if args.multiprocessing_distributed: 181 | args.rank = args.rank * ngpus_per_node + gpu 182 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank) 183 | 184 | # NeWCRFs model 185 | model = NewCRFDepth(version=args.encoder, inv_depth=False, max_depth=args.max_depth, pretrained=args.pretrain) 186 | model.train() 187 | 188 | num_params = sum([np.prod(p.size()) for p in model.parameters()]) 189 | print("== Total number of 
parameters: {}".format(num_params)) 190 | 191 | num_params_update = sum([np.prod(p.shape) for p in model.parameters() if p.requires_grad]) 192 | print("== Total number of learning parameters: {}".format(num_params_update)) 193 | 194 | if args.distributed: 195 | if args.gpu is not None: 196 | torch.cuda.set_device(args.gpu) 197 | model.cuda(args.gpu) 198 | args.batch_size = int(args.batch_size / ngpus_per_node) 199 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 200 | else: 201 | model.cuda() 202 | model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True) 203 | else: 204 | model = torch.nn.DataParallel(model) 205 | model.cuda() 206 | 207 | if args.distributed: 208 | print("== Model Initialized on GPU: {}".format(args.gpu)) 209 | else: 210 | print("== Model Initialized") 211 | 212 | global_step = 0 213 | best_eval_measures_lower_better = torch.zeros(6).cpu() + 1e3 214 | best_eval_measures_higher_better = torch.zeros(3).cpu() 215 | best_eval_steps = np.zeros(9, dtype=np.int32) 216 | 217 | # Training parameters 218 | optimizer = torch.optim.Adam([{'params': model.module.parameters()}], 219 | lr=args.learning_rate) 220 | 221 | model_just_loaded = False 222 | if args.checkpoint_path != '': 223 | if os.path.isfile(args.checkpoint_path): 224 | print("== Loading checkpoint '{}'".format(args.checkpoint_path)) 225 | if args.gpu is None: 226 | checkpoint = torch.load(args.checkpoint_path) 227 | else: 228 | loc = 'cuda:{}'.format(args.gpu) 229 | checkpoint = torch.load(args.checkpoint_path, map_location=loc) 230 | model.load_state_dict(checkpoint['model']) 231 | optimizer.load_state_dict(checkpoint['optimizer']) 232 | if not args.retrain: 233 | try: 234 | global_step = checkpoint['global_step'] 235 | best_eval_measures_higher_better = checkpoint['best_eval_measures_higher_better'].cpu() 236 | best_eval_measures_lower_better = checkpoint['best_eval_measures_lower_better'].cpu() 237 | best_eval_steps = checkpoint['best_eval_steps'] 238 | except KeyError: 239 | print("Could not load values for online evaluation") 240 | 241 | print("== Loaded checkpoint '{}' (global_step {})".format(args.checkpoint_path, checkpoint['global_step'])) 242 | else: 243 | print("== No checkpoint found at '{}'".format(args.checkpoint_path)) 244 | model_just_loaded = True 245 | del checkpoint 246 | 247 | cudnn.benchmark = True 248 | 249 | dataloader = NewDataLoader(args, 'train') 250 | dataloader_eval = NewDataLoader(args, 'online_eval') 251 | 252 | # ===== Evaluation before training ====== 253 | # model.eval() 254 | # with torch.no_grad(): 255 | # eval_measures = online_eval(model, dataloader_eval, gpu, ngpus_per_node, post_process=True) 256 | 257 | # Logging 258 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0): 259 | writer = SummaryWriter(args.log_directory + '/' + args.model_name + '/summaries', flush_secs=30) 260 | if args.do_online_eval: 261 | if args.eval_summary_directory != '': 262 | eval_summary_path = os.path.join(args.eval_summary_directory, args.model_name) 263 | else: 264 | eval_summary_path = os.path.join(args.log_directory, args.model_name, 'eval') 265 | eval_summary_writer = SummaryWriter(eval_summary_path, flush_secs=30) 266 | 267 | silog_criterion = silog_loss(variance_focus=args.variance_focus) 268 | 269 | start_time = time.time() 270 | duration = 0 271 | 272 | num_log_images = args.batch_size 273 | end_learning_rate = args.end_learning_rate if 
args.end_learning_rate != -1 else 0.1 * args.learning_rate 274 | 275 | var_sum = [var.sum().item() for var in model.parameters() if var.requires_grad] 276 | var_cnt = len(var_sum) 277 | var_sum = np.sum(var_sum) 278 | 279 | print("== Initial variables' sum: {:.3f}, avg: {:.3f}".format(var_sum, var_sum/var_cnt)) 280 | 281 | steps_per_epoch = len(dataloader.data) 282 | num_total_steps = args.num_epochs * steps_per_epoch 283 | epoch = global_step // steps_per_epoch 284 | 285 | while epoch < args.num_epochs: 286 | if args.distributed: 287 | dataloader.train_sampler.set_epoch(epoch) 288 | 289 | for step, sample_batched in enumerate(dataloader.data): 290 | optimizer.zero_grad() 291 | before_op_time = time.time() 292 | 293 | image = torch.autograd.Variable(sample_batched['image'].cuda(args.gpu, non_blocking=True)) 294 | depth_gt = torch.autograd.Variable(sample_batched['depth'].cuda(args.gpu, non_blocking=True)) 295 | 296 | depth_est = model(image) 297 | 298 | if args.dataset == 'nyu': 299 | mask = depth_gt > 0.1 300 | else: 301 | mask = depth_gt > 1.0 302 | 303 | loss = silog_criterion.forward(depth_est, depth_gt, mask.to(torch.bool)) 304 | loss.backward() 305 | for param_group in optimizer.param_groups: 306 | current_lr = (args.learning_rate - end_learning_rate) * (1 - global_step / num_total_steps) ** 0.9 + end_learning_rate 307 | param_group['lr'] = current_lr 308 | 309 | optimizer.step() 310 | 311 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0): 312 | print('[epoch][s/s_per_e/gs]: [{}][{}/{}/{}], lr: {:.12f}, loss: {:.12f}'.format(epoch, step, steps_per_epoch, global_step, current_lr, loss)) 313 | if np.isnan(loss.cpu().item()): 314 | print('NaN in loss occurred. Aborting training.') 315 | return -1 316 | 317 | duration += time.time() - before_op_time 318 | if global_step and global_step % args.log_freq == 0 and not model_just_loaded: 319 | var_sum = [var.sum().item() for var in model.parameters() if var.requires_grad] 320 | var_cnt = len(var_sum) 321 | var_sum = np.sum(var_sum) 322 | examples_per_sec = args.batch_size / duration * args.log_freq 323 | duration = 0 324 | time_sofar = (time.time() - start_time) / 3600 325 | training_time_left = (num_total_steps / global_step - 1.0) * time_sofar 326 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0): 327 | print("{}".format(args.model_name)) 328 | print_string = 'GPU: {} | examples/s: {:4.2f} | loss: {:.5f} | var sum: {:.3f} avg: {:.3f} | time elapsed: {:.2f}h | time left: {:.2f}h' 329 | print(print_string.format(args.gpu, examples_per_sec, loss, var_sum.item(), var_sum.item()/var_cnt, time_sofar, training_time_left)) 330 | 331 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 332 | and args.rank % ngpus_per_node == 0): 333 | writer.add_scalar('silog_loss', loss, global_step) 334 | writer.add_scalar('learning_rate', current_lr, global_step) 335 | writer.add_scalar('var average', var_sum.item()/var_cnt, global_step) 336 | depth_gt = torch.where(depth_gt < 1e-3, depth_gt * 0 + 1e3, depth_gt) 337 | for i in range(num_log_images): 338 | writer.add_image('depth_gt/image/{}'.format(i), normalize_result(1/depth_gt[i, :, :, :].data), global_step) 339 | writer.add_image('depth_est/image/{}'.format(i), normalize_result(1/depth_est[i, :, :, :].data), global_step) 340 | writer.add_image('image/image/{}'.format(i), inv_normalize(image[i, :, :, :]).data, global_step) 341 | writer.flush() 342 | 343 | 
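The training loop above anneals the learning rate every step with a power-0.9 polynomial schedule, decaying from --learning_rate to --end_learning_rate (which falls back to one tenth of the initial rate when left at its default of -1). A standalone sketch of that schedule; the helper below is illustrative, not part of train.py:

```
def poly_lr(global_step, num_total_steps, base_lr=1e-4, end_lr=-1, power=0.9):
    # Same rule as main_worker: an end_lr of -1 means 0.1 * base_lr.
    if end_lr == -1:
        end_lr = 0.1 * base_lr
    return (base_lr - end_lr) * (1 - global_step / num_total_steps) ** power + end_lr

# poly_lr(0, 10000) -> 1e-4 at the first step; poly_lr(10000, 10000) -> 1e-5 at the last.
```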
if args.do_online_eval and global_step and global_step % args.eval_freq == 0 and not model_just_loaded: 344 | time.sleep(0.1) 345 | model.eval() 346 | with torch.no_grad(): 347 | eval_measures = online_eval(model, dataloader_eval, gpu, ngpus_per_node, post_process=True) 348 | if eval_measures is not None: 349 | for i in range(9): 350 | eval_summary_writer.add_scalar(eval_metrics[i], eval_measures[i].cpu(), int(global_step)) 351 | measure = eval_measures[i] 352 | is_best = False 353 | if i < 6 and measure < best_eval_measures_lower_better[i]: 354 | old_best = best_eval_measures_lower_better[i].item() 355 | best_eval_measures_lower_better[i] = measure.item() 356 | is_best = True 357 | elif i >= 6 and measure > best_eval_measures_higher_better[i-6]: 358 | old_best = best_eval_measures_higher_better[i-6].item() 359 | best_eval_measures_higher_better[i-6] = measure.item() 360 | is_best = True 361 | if is_best: 362 | old_best_step = best_eval_steps[i] 363 | old_best_name = '/model-{}-best_{}_{:.5f}'.format(old_best_step, eval_metrics[i], old_best) 364 | model_path = args.log_directory + '/' + args.model_name + old_best_name 365 | if os.path.exists(model_path): 366 | command = 'rm {}'.format(model_path) 367 | os.system(command) 368 | best_eval_steps[i] = global_step 369 | model_save_name = '/model-{}-best_{}_{:.5f}'.format(global_step, eval_metrics[i], measure) 370 | print('New best for {}. Saving model: {}'.format(eval_metrics[i], model_save_name)) 371 | checkpoint = {'global_step': global_step, 372 | 'model': model.state_dict(), 373 | 'optimizer': optimizer.state_dict(), 374 | 'best_eval_measures_higher_better': best_eval_measures_higher_better, 375 | 'best_eval_measures_lower_better': best_eval_measures_lower_better, 376 | 'best_eval_steps': best_eval_steps 377 | } 378 | torch.save(checkpoint, args.log_directory + '/' + args.model_name + model_save_name) 379 | eval_summary_writer.flush() 380 | model.train() 381 | block_print() 382 | enable_print() 383 | 384 | model_just_loaded = False 385 | global_step += 1 386 | 387 | epoch += 1 388 | 389 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0): 390 | writer.close() 391 | if args.do_online_eval: 392 | eval_summary_writer.close() 393 | 394 | 395 | def main(): 396 | if args.mode != 'train': 397 | print('train.py is only for training.') 398 | return -1 399 | 400 | command = 'mkdir ' + os.path.join(args.log_directory, args.model_name) 401 | os.system(command) 402 | 403 | args_out_path = os.path.join(args.log_directory, args.model_name) 404 | command = 'cp ' + sys.argv[1] + ' ' + args_out_path 405 | os.system(command) 406 | 407 | save_files = True 408 | if save_files: 409 | aux_out_path = os.path.join(args.log_directory, args.model_name) 410 | networks_savepath = os.path.join(aux_out_path, 'networks') 411 | dataloaders_savepath = os.path.join(aux_out_path, 'dataloaders') 412 | command = 'cp newcrfs/train.py ' + aux_out_path 413 | os.system(command) 414 | command = 'mkdir -p ' + networks_savepath + ' && cp newcrfs/networks/*.py ' + networks_savepath 415 | os.system(command) 416 | command = 'mkdir -p ' + dataloaders_savepath + ' && cp newcrfs/dataloaders/*.py ' + dataloaders_savepath 417 | os.system(command) 418 | 419 | torch.cuda.empty_cache() 420 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 421 | 422 | ngpus_per_node = torch.cuda.device_count() 423 | if ngpus_per_node > 1 and not args.multiprocessing_distributed: 424 | print("This machine has more than 1 
gpu. Please specify --multiprocessing_distributed, or set \'CUDA_VISIBLE_DEVICES=0\'") 425 | return -1 426 | 427 | if args.do_online_eval: 428 | print("You have specified --do_online_eval.") 429 | print("This will evaluate the model every eval_freq {} steps and save best models for individual eval metrics." 430 | .format(args.eval_freq)) 431 | 432 | if args.multiprocessing_distributed: 433 | args.world_size = ngpus_per_node * args.world_size 434 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 435 | else: 436 | main_worker(args.gpu, ngpus_per_node, args) 437 | 438 | 439 | if __name__ == '__main__': 440 | main() 441 | -------------------------------------------------------------------------------- /newcrfs/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.distributed as dist 4 | from torch.utils.data import Sampler 5 | from torchvision import transforms 6 | 7 | import os, sys 8 | import numpy as np 9 | import math 10 | import torch 11 | 12 | 13 | def convert_arg_line_to_args(arg_line): 14 | for arg in arg_line.split(): 15 | if not arg.strip(): 16 | continue 17 | yield arg 18 | 19 | 20 | def block_print(): 21 | sys.stdout = open(os.devnull, 'w') 22 | 23 | 24 | def enable_print(): 25 | sys.stdout = sys.__stdout__ 26 | 27 | 28 | def get_num_lines(file_path): 29 | f = open(file_path, 'r') 30 | lines = f.readlines() 31 | f.close() 32 | return len(lines) 33 | 34 | 35 | def colorize(value, vmin=None, vmax=None, cmap='Greys'): 36 | value = value.cpu().numpy()[:, :, :] 37 | value = np.log10(value) 38 | 39 | vmin = value.min() if vmin is None else vmin 40 | vmax = value.max() if vmax is None else vmax 41 | 42 | if vmin != vmax: 43 | value = (value - vmin) / (vmax - vmin) 44 | else: 45 | value = value*0. 46 | 47 | cmapper = matplotlib.cm.get_cmap(cmap) 48 | value = cmapper(value, bytes=True) 49 | 50 | img = value[:, :, :3] 51 | 52 | return img.transpose((2, 0, 1)) 53 | 54 | 55 | def normalize_result(value, vmin=None, vmax=None): 56 | value = value.cpu().numpy()[0, :, :] 57 | 58 | vmin = value.min() if vmin is None else vmin 59 | vmax = value.max() if vmax is None else vmax 60 | 61 | if vmin != vmax: 62 | value = (value - vmin) / (vmax - vmin) 63 | else: 64 | value = value * 0. 
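One caveat in utils.py: the `colorize` helper above calls `matplotlib.cm.get_cmap`, but the module only imports torch, torchvision, numpy and the standard library, so calling it as-is raises a NameError. A self-contained variant, assuming matplotlib is installed (the name `colorize_depth` and the squeeze to a 2-D map are our additions):

```
import matplotlib.cm
import numpy as np

def colorize_depth(depth, vmin=None, vmax=None, cmap='Greys'):
    # depth: torch tensor holding a single depth map; returns a (3, H, W) uint8 image.
    value = np.log10(depth.cpu().numpy().squeeze())
    vmin = value.min() if vmin is None else vmin
    vmax = value.max() if vmax is None else vmax
    value = (value - vmin) / (vmax - vmin) if vmax != vmin else value * 0.
    rgba = matplotlib.cm.get_cmap(cmap)(value, bytes=True)  # (H, W, 4) uint8
    return rgba[:, :, :3].transpose((2, 0, 1))              # drop alpha, move to CHW
```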
65 | 66 | return np.expand_dims(value, 0) 67 | 68 | 69 | inv_normalize = transforms.Normalize( 70 | mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], 71 | std=[1/0.229, 1/0.224, 1/0.225] 72 | ) 73 | 74 | 75 | eval_metrics = ['silog', 'abs_rel', 'log10', 'rms', 'sq_rel', 'log_rms', 'd1', 'd2', 'd3'] 76 | 77 | 78 | def compute_errors(gt, pred): 79 | thresh = np.maximum((gt / pred), (pred / gt)) 80 | d1 = (thresh < 1.25).mean() 81 | d2 = (thresh < 1.25 ** 2).mean() 82 | d3 = (thresh < 1.25 ** 3).mean() 83 | 84 | rms = (gt - pred) ** 2 85 | rms = np.sqrt(rms.mean()) 86 | 87 | log_rms = (np.log(gt) - np.log(pred)) ** 2 88 | log_rms = np.sqrt(log_rms.mean()) 89 | 90 | abs_rel = np.mean(np.abs(gt - pred) / gt) 91 | sq_rel = np.mean(((gt - pred) ** 2) / gt) 92 | 93 | err = np.log(pred) - np.log(gt) 94 | silog = np.sqrt(np.mean(err ** 2) - np.mean(err) ** 2) * 100 95 | 96 | err = np.abs(np.log10(pred) - np.log10(gt)) 97 | log10 = np.mean(err) 98 | 99 | return [silog, abs_rel, log10, rms, sq_rel, log_rms, d1, d2, d3] 100 | 101 | 102 | class silog_loss(nn.Module): 103 | def __init__(self, variance_focus): 104 | super(silog_loss, self).__init__() 105 | self.variance_focus = variance_focus 106 | 107 | def forward(self, depth_est, depth_gt, mask): 108 | d = torch.log(depth_est[mask]) - torch.log(depth_gt[mask]) 109 | return torch.sqrt((d ** 2).mean() - self.variance_focus * (d.mean() ** 2)) * 10.0 110 | 111 | 112 | def flip_lr(image): 113 | """ 114 | Flip image horizontally 115 | 116 | Parameters 117 | ---------- 118 | image : torch.Tensor [B,3,H,W] 119 | Image to be flipped 120 | 121 | Returns 122 | ------- 123 | image_flipped : torch.Tensor [B,3,H,W] 124 | Flipped image 125 | """ 126 | assert image.dim() == 4, 'You need to provide a [B,C,H,W] image to flip' 127 | return torch.flip(image, [3]) 128 | 129 | 130 | def fuse_inv_depth(inv_depth, inv_depth_hat, method='mean'): 131 | """ 132 | Fuse inverse depth and flipped inverse depth maps 133 | 134 | Parameters 135 | ---------- 136 | inv_depth : torch.Tensor [B,1,H,W] 137 | Inverse depth map 138 | inv_depth_hat : torch.Tensor [B,1,H,W] 139 | Flipped inverse depth map produced from a flipped image 140 | method : str 141 | Method that will be used to fuse the inverse depth maps 142 | 143 | Returns 144 | ------- 145 | fused_inv_depth : torch.Tensor [B,1,H,W] 146 | Fused inverse depth map 147 | """ 148 | if method == 'mean': 149 | return 0.5 * (inv_depth + inv_depth_hat) 150 | elif method == 'max': 151 | return torch.max(inv_depth, inv_depth_hat) 152 | elif method == 'min': 153 | return torch.min(inv_depth, inv_depth_hat) 154 | else: 155 | raise ValueError('Unknown post-process method {}'.format(method)) 156 | 157 | 158 | def post_process_depth(depth, depth_flipped, method='mean'): 159 | """ 160 | Post-process an inverse and flipped inverse depth map 161 | 162 | Parameters 163 | ---------- 164 | inv_depth : torch.Tensor [B,1,H,W] 165 | Inverse depth map 166 | inv_depth_flipped : torch.Tensor [B,1,H,W] 167 | Inverse depth map produced from a flipped image 168 | method : str 169 | Method that will be used to fuse the inverse depth maps 170 | 171 | Returns 172 | ------- 173 | inv_depth_pp : torch.Tensor [B,1,H,W] 174 | Post-processed inverse depth map 175 | """ 176 | B, C, H, W = depth.shape 177 | inv_depth_hat = flip_lr(depth_flipped) 178 | inv_depth_fused = fuse_inv_depth(depth, inv_depth_hat, method=method) 179 | xs = torch.linspace(0., 1., W, device=depth.device, 180 | dtype=depth.dtype).repeat(B, C, H, 1) 181 | mask = 1.0 - torch.clamp(20. 
* (xs - 0.05), 0., 1.) 182 | mask_hat = flip_lr(mask) 183 | return mask_hat * depth + mask * inv_depth_hat + \ 184 | (1.0 - mask - mask_hat) * inv_depth_fused 185 | 186 | 187 | class DistributedSamplerNoEvenlyDivisible(Sampler): 188 | """Sampler that restricts data loading to a subset of the dataset. 189 | 190 | It is especially useful in conjunction with 191 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 192 | process can pass a DistributedSampler instance as a DataLoader sampler, 193 | and load a subset of the original dataset that is exclusive to it. 194 | 195 | .. note:: 196 | Dataset is assumed to be of constant size. 197 | 198 | Arguments: 199 | dataset: Dataset used for sampling. 200 | num_replicas (optional): Number of processes participating in 201 | distributed training. 202 | rank (optional): Rank of the current process within num_replicas. 203 | shuffle (optional): If true (default), sampler will shuffle the indices 204 | """ 205 | 206 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 207 | if num_replicas is None: 208 | if not dist.is_available(): 209 | raise RuntimeError("Requires distributed package to be available") 210 | num_replicas = dist.get_world_size() 211 | if rank is None: 212 | if not dist.is_available(): 213 | raise RuntimeError("Requires distributed package to be available") 214 | rank = dist.get_rank() 215 | self.dataset = dataset 216 | self.num_replicas = num_replicas 217 | self.rank = rank 218 | self.epoch = 0 219 | num_samples = int(math.floor(len(self.dataset) * 1.0 / self.num_replicas)) 220 | rest = len(self.dataset) - num_samples * self.num_replicas 221 | if self.rank < rest: 222 | num_samples += 1 223 | self.num_samples = num_samples 224 | self.total_size = len(dataset) 225 | # self.total_size = self.num_samples * self.num_replicas 226 | self.shuffle = shuffle 227 | 228 | def __iter__(self): 229 | # deterministically shuffle based on epoch 230 | g = torch.Generator() 231 | g.manual_seed(self.epoch) 232 | if self.shuffle: 233 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 234 | else: 235 | indices = list(range(len(self.dataset))) 236 | 237 | # add extra samples to make it evenly divisible 238 | # indices += indices[:(self.total_size - len(indices))] 239 | # assert len(indices) == self.total_size 240 | 241 | # subsample 242 | indices = indices[self.rank:self.total_size:self.num_replicas] 243 | self.num_samples = len(indices) 244 | # assert len(indices) == self.num_samples 245 | 246 | return iter(indices) 247 | 248 | def __len__(self): 249 | return self.num_samples 250 | 251 | def set_epoch(self, epoch): 252 | self.epoch = epoch --------------------------------------------------------------------------------
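The online evaluation in train.py uses post_process_depth as a horizontal-flip test-time augmentation: the network runs on the image and on its mirror, and the two predictions are fused with the edge-aware ramp masks defined above. A minimal usage sketch, assuming it is run from the newcrfs directory (so `utils` resolves) with a loaded `model` and a normalized `(B, 3, H, W)` image batch; the wrapper name is ours:

```
import torch
from utils import flip_lr, post_process_depth  # newcrfs/utils.py above

@torch.no_grad()
def predict_with_flip_tta(model, image):
    # Mirror the flip-and-fuse step of online_eval in train.py.
    pred = model(image)                   # (B, 1, H, W) depth
    pred_flipped = model(flip_lr(image))  # depth of the mirrored image
    return post_process_depth(pred, pred_flipped)
```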