├── License
├── README.md
├── evaluate_depth.py
├── evaluate_hr_depth.py
├── fig
│   └── kittiandds.png
├── networks
│   ├── __init__.py
│   ├── hr_decoder.py
│   ├── hr_layers.py
│   ├── mpvit.py
│   └── nets.py
└── trainer.py


/License:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Chaoqiang Zhao

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MonoViT

This is the reference PyTorch implementation for training and testing depth estimation models using the method described in

> **MonoViT: Self-Supervised Monocular Depth Estimation with a Vision Transformer** [arXiv](https://arxiv.org/abs/2208.03543)
>
> Chaoqiang Zhao*, Youmin Zhang*, Matteo Poggi, Fabio Tosi, Xianda Guo, Zheng Zhu, Guan Huang, Yang Tang, Stefano Mattoccia

[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/monovit-self-supervised-monocular-depth/monocular-depth-estimation-on-kitti-eigen-1)](https://paperswithcode.com/sota/monocular-depth-estimation-on-kitti-eigen-1?p=monovit-self-supervised-monocular-depth)

If you find our work useful in your research, please consider citing our paper:

```
@inproceedings{monovit,
  title={MonoViT: Self-Supervised Monocular Depth Estimation with a Vision Transformer},
  author={Zhao, Chaoqiang and Zhang, Youmin and Poggi, Matteo and Tosi, Fabio and Guo, Xianda and Zhu, Zheng and Huang, Guan and Tang, Yang and Mattoccia, Stefano},
  booktitle={International Conference on 3D Vision},
  year={2022}
}
```

## ⚙️ Setup

Assuming a fresh [Anaconda](https://www.anaconda.com/download/) distribution, you can install the dependencies with:
```shell
pip3 install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
pip install dominate==2.4.0 Pillow==6.1.0 visdom==0.1.8
pip install tensorboardX==1.4 opencv-python matplotlib scikit-image
pip3 install mmcv-full==1.3.0 mmsegmentation==0.11.0
pip install timm einops IPython
```
We ran our experiments with PyTorch 1.9.0, CUDA 11.1, Python 3.7 and Ubuntu 18.04.

Our code is built on [Monodepth2](https://github.com/nianticlabs/monodepth2).

## Results on KITTI

We provide the following options for `--model_name`:

| `--model_name` | Training modality | Pretrained? | Model resolution | Abs Rel | Sq Rel | RMSE | RMSE log | delta < 1.25 | delta < 1.25^2 | delta < 1.25^3 |
|-----------------------|-------------|------|-----------------|----|----|----|------|--------|--------|--------|
| [`mono_640x192`](https://drive.google.com/drive/folders/1VWDPuqiMPDD2P--Oka-yJgh8z7ouCX4D?usp=sharing) | Mono | Yes | 640 x 192 | 0.099 | 0.708 | 4.372 | 0.175 | 0.900 | 0.967 | 0.984 |
| [`mono+stereo_640x192`](https://drive.google.com/drive/folders/1_HPsL1Vg3s0LdOykfTT0aMlE6-u3IxQn?usp=sharing) | Mono + Stereo | Yes | 640 x 192 | 0.098 | 0.683 | 4.333 | 0.174 | 0.904 | 0.967 | 0.984 |
| [`mono_1024x320`](https://drive.google.com/drive/folders/1EDTSZ59CGW9rUoDL3EwEKn3PpZpUUGsS?usp=sharing) | Mono | Yes | 1024 x 320 | 0.096 | 0.714 | 4.292 | 0.172 | 0.908 | 0.968 | 0.984 |
| [`mono+stereo_1024x320`](https://drive.google.com/drive/folders/1tez1RQFO33MMyVAq_gkOVHoL2TO98-TH?usp=sharing) | Mono + Stereo | Yes | 1024 x 320 | 0.093 | 0.671 | 4.202 | 0.169 | 0.912 | 0.969 | 0.985 |
| [`mono_1280x384`](https://drive.google.com/drive/folders/1l3egRvLaoBqgYrgfktgpJt613QwZ4twT?usp=sharing) | Mono | Yes | 1280 x 384 | 0.094 | 0.682 | 4.200 | 0.170 | 0.912 | 0.969 | 0.984 |

## Robustness

| Model | Modality | mCE (%) | mRR (%) | Clean | Bright | Dark | Fog | Frost | Snow | Contrast | Defocus | Glass | Motion | Zoom | Elastic | Quant | Gaussian | Impulse | Shot | ISO | Pixelate | JPEG |
| :-- | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: |
| [MonoDepth2R18]() | Mono | 100.00 | 84.46 | 0.119 | 0.130 | 0.280 | 0.155 | 0.277 | 0.511 | 0.187 | 0.244 | 0.242 | 0.216 | 0.201 | 0.129 | 0.193 | 0.384 | 0.389 | 0.340 | 0.388 | 0.145 | 0.196 |
| [MonoDepth2R18+nopt]() | Mono | 119.75 | 82.50 | 0.144 | 0.183 | 0.343 | 0.311 | 0.312 | 0.399 | 0.416 | 0.254 | 0.232 | 0.199 | 0.207 | 0.148 | 0.212 | 0.441 | 0.452 | 0.402 | 0.453 | 0.153 | 0.171 |
| [MonoDepth2R18+HR]() | Mono | 106.06 | 82.44 | 0.114 | 0.129 | 0.376 | 0.155 | 0.271 | 0.582 | 0.214 | 0.393 | 0.257 | 0.230 | 0.232 | 0.123 | 0.215 | 0.326 | 0.352 | 0.317 | 0.344 | 0.138 | 0.198 |
| [MonoDepth2R50]() | Mono | 113.43 | 80.59 | 0.117 | 0.127 | 0.294 | 0.155 | 0.287 | 0.492 | 0.233 | 0.427 | 0.392 | 0.277 | 0.208 | 0.130 | 0.198 | 0.409 | 0.403 | 0.368 | 0.425 | 0.155 | 0.211 |
| [MaskOcc]() | Mono | 104.05 | 82.97 | 0.117 | 0.130 | 0.285 | 0.154 | 0.283 | 0.492 | 0.200 | 0.318 | 0.295 | 0.228 | 0.201 | 0.129 | 0.184 | 0.403 | 0.410 | 0.364 | 0.417 | 0.143 | 0.177 |
| [DNetR18]() | Mono | 104.71 | 83.34 | 0.118 | 0.128 | 0.264 | 0.156 | 0.317 | 0.504 | 0.209 | 0.348 | 0.320 | 0.242 | 0.215 | 0.131 | 0.189 | 0.362 | 0.366 | 0.326 | 0.357 | 0.145 | 0.190 |
| [CADepth]() | Mono | 110.11 | 80.07 | 0.108 | 0.121 | 0.300 | 0.142 | 0.324 | 0.529 | 0.193 | 0.356 | 0.347 | 0.285 | 0.208 | 0.121 | 0.192 | 0.423 | 0.433 | 0.383 | 0.448 | 0.144 | 0.195 |
| [HR-Depth]() | Mono | 103.73 | 82.93 | 0.112 | 0.121 | 0.289 | 0.151 | 0.279 | 0.481 | 0.213 | 0.356 | 0.300 | 0.263 | 0.224 | 0.124 | 0.187 | 0.363 | 0.373 | 0.336 | 0.374 | 0.135 | 0.176 |
| [DIFFNetHRNet]() | Mono | 94.96 | 85.41 | 0.102 | 0.111 | 0.222 | 0.131 | 0.199 | 0.352 | 0.161 | 0.513 | 0.330 | 0.280 | 0.197 | 0.114 | 0.165 | 0.292 | 0.266 | 0.255 | 0.270 | 0.135 | 0.202 |
| [ManyDepthsingle]() | Mono | 105.41 | 83.11 | 0.123 | 0.135 | 0.274 | 0.169 | 0.288 | 0.479 | 0.227 | 0.254 | 0.279 | 0.211 | 0.194 | 0.134 | 0.189 | 0.430 | 0.450 | 0.387 | 0.452 | 0.147 | 0.182 |
| [FSRE-Depth]() | Mono | 99.05 | 83.86 | 0.109 | 0.128 | 0.261 | 0.139 | 0.237 | 0.393 | 0.170 | 0.291 | 0.273 | 0.214 | 0.185 | 0.119 | 0.179 | 0.400 | 0.414 | 0.370 | 0.407 | 0.147 | 0.224 |
| [MonoViTMPViT]() | Mono | 79.33 | 89.15 | 0.099 | 0.106 | 0.243 | 0.116 | 0.213 | 0.275 | 0.119 | 0.180 | 0.204 | 0.163 | 0.179 | 0.118 | 0.146 | 0.310 | 0.293 | 0.271 | 0.290 | 0.162 | 0.154 |
| [MonoViTMPViT+HR]() | Mono | 70.79 | 90.67 | 0.090 | 0.097 | 0.221 | 0.113 | 0.217 | 0.253 | 0.113 | 0.146 | 0.159 | 0.144 | 0.175 | 0.098 | 0.138 | 0.267 | 0.246 | 0.236 | 0.246 | 0.135 | 0.145 |

The [RoboDepth Challenge Team](https://github.com/ldkong1205/RoboDepth) is evaluating the robustness of different depth estimation algorithms. Among the models listed above, the two MonoViT variants achieve the lowest mCE and the highest mRR.

## 💾 KITTI training data

You can download the entire [raw KITTI dataset](http://www.cvlibs.net/datasets/kitti/raw_data.php) by running:
```shell
wget -i splits/kitti_archives_to_download.txt -P kitti_data/
```
Then unzip with:
```shell
cd kitti_data
unzip "*.zip"
cd ..
```
**Warning:** it weighs about **175GB**, so make sure you have enough space to unzip too!

Our default settings expect that you have converted the png images to jpeg with this command, **which also deletes the raw KITTI `.png` files**:
```shell
find kitti_data/ -name '*.png' | parallel 'convert -quality 92 -sampling-factor 2x2,1x1,1x1 {.}.png {.}.jpg && rm {}'
```
**or** you can skip this conversion step and train from raw png files by adding the flag `--png` when training, at the expense of slower load times.

The above conversion command creates images which match our experiments, where KITTI `.png` images were converted to `.jpg` on Ubuntu 16.04 with default chroma subsampling `2x2,1x1,1x1`.
We found that Ubuntu 18.04 defaults to `2x2,2x2,2x2`, which gives different results, hence the explicit parameter in the conversion command.

You can also place the KITTI dataset wherever you like and point towards it with the `--data_path` flag during training and evaluation.

**Splits**

The train/test/validation splits are defined in the `splits/` folder.
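
Each line of a split file names one training or test frame. As an illustration, the sketch below turns such a line into an image path. It assumes the Monodepth2 convention of `<folder> <frame_index> <side>` lines and the raw KITTI layout with `image_02` (left) and `image_03` (right) colour-camera folders; `parse_split_line`, `frame_to_image_path` and `SIDE_TO_CAMERA` are hypothetical helpers, not part of this repository, so check the loader in `datasets/kitti_dataset.py` for the actual behaviour.

```python
import os

# Assumption: "l"/"r" map to the raw KITTI colour-camera folders
# image_02 (left) and image_03 (right), as in Monodepth2's KITTI loaders.
SIDE_TO_CAMERA = {"l": 2, "r": 3}


def parse_split_line(line):
    """Split one split-file line into (folder, frame_index, side).

    Lines without three fields fall back to frame 0 and no side.
    """
    parts = line.split()
    folder = parts[0]
    if len(parts) == 3:
        return folder, int(parts[1]), parts[2]
    return folder, 0, None


def frame_to_image_path(data_path, folder, frame_index, side, ext=".jpg"):
    """Build the path of one frame under the assumed raw KITTI layout.

    ext defaults to ".jpg" to match the PNG-to-JPEG conversion above;
    pass ".png" when training with --png.
    """
    camera_dir = "image_0{}".format(SIDE_TO_CAMERA[side])
    filename = "{:010d}{}".format(frame_index, ext)  # zero-padded to 10 digits
    return os.path.join(data_path, folder, camera_dir, "data", filename)


if __name__ == "__main__":
    # Illustrative split-file line (left image of frame 69)
    line = "2011_09_26/2011_09_26_drive_0002_sync 0000000069 l"
    folder, frame_index, side = parse_split_line(line)
    print(frame_to_image_path("kitti_data", folder, frame_index, side))
```

On Linux this prints `kitti_data/2011_09_26/2011_09_26_drive_0002_sync/image_02/data/0000000069.jpg`, which is also a quick way to check that `--data_path` points at the right directory.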
By default, the code will train a depth model using [Zhou's subset](https://github.com/tinghuiz/SfMLearner) of the standard Eigen split of KITTI, which is designed for monocular training.
You can also train a model using the new [benchmark split](http://www.cvlibs.net/datasets/kitti/eval_depth.php?benchmark=depth_prediction) or the [odometry split](http://www.cvlibs.net/datasets/kitti/eval_odometry.php) by setting the `--split` flag.


**Custom dataset**

You can train on a custom monocular or stereo dataset by writing a new dataloader class which inherits from `MonoDataset` – see the `KITTIDataset` class in `datasets/kitti_dataset.py` for an example.


## ⏳ Training

Please download the ImageNet-1K pretrained MPViT [model](https://dl.dropbox.com/s/y3dnmmy8h4npz7a/mpvit_small.pth) to `./ckpt/`.

To train, download Monodepth2, replace its depth network with MonoViT's, and adjust the depth-network settings, the optimizer and the learning rate according to `trainer.py`.

Because MonoViT and Monodepth2 target different PyTorch versions, the call to `transforms.ColorJitter.get_params` in the dataloader should also be changed to `transforms.ColorJitter`.

By default, models and tensorboard event files are saved to `./tmp/`.
This can be changed with the `--log_dir` flag.


**Monocular training:**
```shell
python train.py --model_name mono_model --learning_rate 5e-5
```

**Monocular + stereo training:**
```shell
python train.py --model_name mono+stereo_model --use_stereo --learning_rate 5e-5
```


### GPUs

This version of the code runs on a single GPU only.
You can specify which GPU to use with the `CUDA_VISIBLE_DEVICES` environment variable:
```shell
CUDA_VISIBLE_DEVICES=1 python train.py --model_name mono_model
```

## 📊 KITTI evaluation

To prepare the ground-truth depth maps, please follow the instructions in Monodepth2. This assumes that you have placed the KITTI dataset in the default location, `./kitti_data/`.

The following example command evaluates the epoch-19 weights of a model named `mono_model`. Use `evaluate_depth.py` for the 640x192 models, and `evaluate_hr_depth.py --height 320 --width 1024` (or `--height 384 --width 1280`) for the higher-resolution models:
```shell
python evaluate_depth.py --load_weights_folder ./tmp/mono_model/models/weights_19/ --eval_mono
```


An additional parameter, `--eval_split`, can be set.
Its three possible values are explained here:

| `--eval_split` | Test set size | For models trained with... | Description |
|-----------------------|---------------|----------------------------|--------------|
| **`eigen`** | 697 | `--split eigen_zhou` (default) or `--split eigen_full` | The standard Eigen test files |
| **`eigen_benchmark`** | 652 | `--split eigen_zhou` (default) or `--split eigen_full` | Evaluate with the improved ground truth from the [new KITTI depth benchmark](http://www.cvlibs.net/datasets/kitti/eval_depth.php?benchmark=depth_prediction) |
| **`benchmark`** | 500 | `--split benchmark` | The [new KITTI depth benchmark](http://www.cvlibs.net/datasets/kitti/eval_depth.php?benchmark=depth_prediction) test files |

## Contact us

Email: zhaocqilc@gmail.com

## Acknowledgement
We thank the authors of the following works:

[Monodepth2](https://github.com/nianticlabs/monodepth2)

[MPVIT](https://github.com/youngwanLEE/MPViT)

[HR-Depth](https://github.com/shawLyu/HR-Depth)

[DIFFNet](https://github.com/brandleyzhou/DIFFNet)

--------------------------------------------------------------------------------
/evaluate_depth.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import, division, print_function

import os
import cv2
import numpy as np

import torch
from torch.utils.data import DataLoader

from layers import disp_to_depth
from utils import readlines
from options import MonodepthOptions
import datasets
import networks

cv2.setNumThreads(0)  # This speeds up evaluation 5x on our unix systems (OpenCV 3.3.1)


splits_dir = os.path.join(os.path.dirname(__file__), "splits")

# Models which were trained with stereo supervision were trained with a nominal
# baseline of 0.1 units. The KITTI rig has a baseline of 54cm. Therefore,
# to convert our stereo predictions to real-world scale we multiply our depths by 5.4.
STEREO_SCALE_FACTOR = 5.4


def compute_errors(gt, pred):
    """Computation of error metrics between predicted and ground truth depths
    """
    thresh = np.maximum((gt / pred), (pred / gt))
    a1 = (thresh < 1.25 ).mean()
    a2 = (thresh < 1.25 ** 2).mean()
    a3 = (thresh < 1.25 ** 3).mean()

    rmse = (gt - pred) ** 2
    rmse = np.sqrt(rmse.mean())

    rmse_log = (np.log(gt) - np.log(pred)) ** 2
    rmse_log = np.sqrt(rmse_log.mean())

    abs_rel = np.mean(np.abs(gt - pred) / gt)

    sq_rel = np.mean(((gt - pred) ** 2) / gt)

    return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3


def batch_post_process_disparity(l_disp, r_disp):
    """Apply the disparity post-processing method as introduced in Monodepthv1
    """
    _, h, w = l_disp.shape
    m_disp = 0.5 * (l_disp + r_disp)
    l, _ = np.meshgrid(np.linspace(0, 1, w), np.linspace(0, 1, h))
    l_mask = (1.0 - np.clip(20 * (l - 0.05), 0, 1))[None, ...]
    r_mask = l_mask[:, :, ::-1]
    return r_mask * l_disp + l_mask * r_disp + (1.0 - l_mask - r_mask) * m_disp


def evaluate(opt):
    """Evaluates a pretrained model using a specified test set
    """
    MIN_DEPTH = 1e-3
    MAX_DEPTH = 80

    assert sum((opt.eval_mono, opt.eval_stereo)) == 1, \
        "Please choose mono or stereo evaluation by setting either --eval_mono or --eval_stereo"

    if opt.ext_disp_to_eval is None:

        opt.load_weights_folder = os.path.expanduser(opt.load_weights_folder)

        assert os.path.isdir(opt.load_weights_folder), \
            "Cannot find a folder at {}".format(opt.load_weights_folder)

        print("-> Loading weights from {}".format(opt.load_weights_folder))

        filenames = readlines(os.path.join(splits_dir, opt.eval_split, "test_files.txt"))
        encoder_path = os.path.join(opt.load_weights_folder, "encoder.pth")
        decoder_path = os.path.join(opt.load_weights_folder, "depth.pth")

        encoder_dict = torch.load(encoder_path)

        dataset = datasets.KITTIRAWDataset(opt.data_path, filenames,
                                           encoder_dict['height'], encoder_dict['width'],
                                           [0], 4, is_train=False)
        dataloader = DataLoader(dataset, 8, shuffle=False, num_workers=opt.num_workers,
                                pin_memory=True, drop_last=False)

        encoder = networks.mpvit_small()  #networks.ResnetEncoder(opt.num_layers, False)
        encoder.num_ch_enc = [64,128,216,288,288]  # = networks.ResnetEncoder(opt.num_layers, False)
        depth_decoder = networks.DepthDecoder()

        model_dict = encoder.state_dict()
        encoder.load_state_dict({k: v for k, v in encoder_dict.items() if k in model_dict})
        depth_decoder.load_state_dict(torch.load(decoder_path))

        encoder.cuda()
        encoder.eval()
        depth_decoder.cuda()
        depth_decoder.eval()

        pred_disps = []

        print("-> Computing predictions with size {}x{}".format(
            encoder_dict['width'], encoder_dict['height']))

        with torch.no_grad():
            for data in dataloader:
                input_color = data[("color", 0, 0)].cuda()

                if opt.post_process:
                    # Post-processed results require each image to have two forward passes
                    input_color = torch.cat((input_color, torch.flip(input_color, [3])), 0)

                output = depth_decoder(encoder(input_color))

                pred_disp, _ = disp_to_depth(output[("disp", 0)], opt.min_depth, opt.max_depth)
                pred_disp = pred_disp.cpu()[:, 0].numpy()

                if opt.post_process:
                    N = pred_disp.shape[0] // 2
                    pred_disp = batch_post_process_disparity(pred_disp[:N], pred_disp[N:, :, ::-1])

                pred_disps.append(pred_disp)

        pred_disps = np.concatenate(pred_disps)

    else:
        # Load predictions from file
        print("-> Loading predictions from {}".format(opt.ext_disp_to_eval))
        pred_disps = np.load(opt.ext_disp_to_eval)

        if opt.eval_eigen_to_benchmark:
            eigen_to_benchmark_ids = np.load(
                os.path.join(splits_dir, "benchmark", "eigen_to_benchmark_ids.npy"))

            pred_disps = pred_disps[eigen_to_benchmark_ids]

    if opt.save_pred_disps:
        output_path = os.path.join(
            opt.load_weights_folder, "disps_{}_split.npy".format(opt.eval_split))
        print("-> Saving predicted disparities to ", output_path)
        np.save(output_path, pred_disps)

    if opt.no_eval:
        print("-> Evaluation disabled. Done.")
        quit()

    elif opt.eval_split == 'benchmark':
        save_dir = os.path.join(opt.load_weights_folder, "benchmark_predictions")
        print("-> Saving out benchmark predictions to {}".format(save_dir))
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        for idx in range(len(pred_disps)):
            disp_resized = cv2.resize(pred_disps[idx], (1216, 352))
            depth = STEREO_SCALE_FACTOR / disp_resized
            depth = np.clip(depth, 0, 80)
            depth = np.uint16(depth * 256)
            save_path = os.path.join(save_dir, "{:010d}.png".format(idx))
            cv2.imwrite(save_path, depth)

        print("-> No ground truth is available for the KITTI benchmark, so not evaluating. Done.")
        quit()

    gt_path = os.path.join(splits_dir, opt.eval_split, "gt_depths.npz")
    gt_depths = np.load(gt_path, fix_imports=True, allow_pickle=True, encoding='latin1')["data"]

    print("-> Evaluating")

    if opt.eval_stereo:
        print(" Stereo evaluation - "
              "disabling median scaling, scaling by {}".format(STEREO_SCALE_FACTOR))
        opt.disable_median_scaling = True
        opt.pred_depth_scale_factor = STEREO_SCALE_FACTOR
    else:
        print(" Mono evaluation - using median scaling")

    errors = []
    ratios = []

    for i in range(pred_disps.shape[0]):

        gt_depth = gt_depths[i]
        gt_height, gt_width = gt_depth.shape[:2]

        pred_disp = pred_disps[i]
        pred_disp = cv2.resize(pred_disp, (gt_width, gt_height))
        pred_depth = 1 / pred_disp

        if opt.eval_split == "eigen":
            mask = np.logical_and(gt_depth > MIN_DEPTH, gt_depth < MAX_DEPTH)

            crop = np.array([0.40810811 * gt_height, 0.99189189 * gt_height,
                             0.03594771 * gt_width, 0.96405229 * gt_width]).astype(np.int32)
            crop_mask = np.zeros(mask.shape)
            crop_mask[crop[0]:crop[1], crop[2]:crop[3]] = 1
            mask = np.logical_and(mask, crop_mask)

        else:
            mask = gt_depth > 0

        pred_depth = pred_depth[mask]
        gt_depth = gt_depth[mask]

        pred_depth *= opt.pred_depth_scale_factor
        if not opt.disable_median_scaling:
            ratio = np.median(gt_depth) / np.median(pred_depth)
            ratios.append(ratio)
            pred_depth *= ratio

        pred_depth[pred_depth < MIN_DEPTH] = MIN_DEPTH
        pred_depth[pred_depth > MAX_DEPTH] = MAX_DEPTH

        errors.append(compute_errors(gt_depth, pred_depth))

    if not opt.disable_median_scaling:
        ratios = np.array(ratios)
        med = np.median(ratios)
        print(" Scaling ratios | med: {:0.3f} | std: {:0.3f}".format(med, np.std(ratios / med)))

    mean_errors = np.array(errors).mean(0)

    results_edit = open('results.txt', mode='a')
    results_edit.write("\n " + 'model_name: %s ' % (opt.load_weights_folder))
    results_edit.write("\n " + ("{:>8} | " * 7).format("abs_rel", "sq_rel", "rmse", "rmse_log", "a1", "a2", "a3"))
    results_edit.write("\n " + ("&{: 8.3f} " * 7).format(*mean_errors.tolist()) + "\\\\")
    results_edit.close()
    print("\n " + ("{:>8} | " * 7).format("abs_rel", "sq_rel", "rmse", "rmse_log", "a1", "a2", "a3"))
    print(("&{: 8.3f} " * 7).format(*mean_errors.tolist()) + "\\\\")
    print("\n-> Done!")


if __name__ == "__main__":
    options = MonodepthOptions()
    evaluate(options.parse())

--------------------------------------------------------------------------------
/evaluate_hr_depth.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import, division, print_function

import os
import cv2
import numpy as np

import torch
from torch.utils.data import DataLoader

from layers import disp_to_depth
from utils import readlines
from options import MonodepthOptions
import datasets
import networks

cv2.setNumThreads(0)  # This speeds up evaluation 5x on our unix systems (OpenCV 3.3.1)


splits_dir = os.path.join(os.path.dirname(__file__), "splits")

# Models which were trained with stereo supervision were trained with a nominal
# baseline of 0.1 units. The KITTI rig has a baseline of 54cm. Therefore,
# to convert our stereo predictions to real-world scale we multiply our depths by 5.4.
STEREO_SCALE_FACTOR = 5.4


def compute_errors(gt, pred):
    """Computation of error metrics between predicted and ground truth depths
    """
    thresh = np.maximum((gt / pred), (pred / gt))
    a1 = (thresh < 1.25 ).mean()
    a2 = (thresh < 1.25 ** 2).mean()
    a3 = (thresh < 1.25 ** 3).mean()

    rmse = (gt - pred) ** 2
    rmse = np.sqrt(rmse.mean())

    rmse_log = (np.log(gt) - np.log(pred)) ** 2
    rmse_log = np.sqrt(rmse_log.mean())

    abs_rel = np.mean(np.abs(gt - pred) / gt)

    sq_rel = np.mean(((gt - pred) ** 2) / gt)

    return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3


def batch_post_process_disparity(l_disp, r_disp):
    """Apply the disparity post-processing method as introduced in Monodepthv1
    """
    _, h, w = l_disp.shape
    m_disp = 0.5 * (l_disp + r_disp)
    l, _ = np.meshgrid(np.linspace(0, 1, w), np.linspace(0, 1, h))
    l_mask = (1.0 - np.clip(20 * (l - 0.05), 0, 1))[None, ...]
    r_mask = l_mask[:, :, ::-1]
    return r_mask * l_disp + l_mask * r_disp + (1.0 - l_mask - r_mask) * m_disp


def evaluate(opt):
    """Evaluates a pretrained model using a specified test set
    """
    MIN_DEPTH = 1e-3
    MAX_DEPTH = 80

    assert sum((opt.eval_mono, opt.eval_stereo)) == 1, \
        "Please choose mono or stereo evaluation by setting either --eval_mono or --eval_stereo"

    if opt.ext_disp_to_eval is None:

        opt.load_weights_folder = os.path.expanduser(opt.load_weights_folder)

        assert os.path.isdir(opt.load_weights_folder), \
            "Cannot find a folder at {}".format(opt.load_weights_folder)

        print("-> Loading weights from {}".format(opt.load_weights_folder))

        filenames = readlines(os.path.join(splits_dir, opt.eval_split, "test_files.txt"))
        depth_path = os.path.join(opt.load_weights_folder, "depth.pth")

        depth_dict = torch.load(depth_path)
        #new_dict = depth_dict
        new_dict = {}
        for k, v in depth_dict.items():
            # Drop the first 7 characters of each key (presumably the "module."
            # prefix added when the model was wrapped in nn.DataParallel)
            name = k[7:]
            new_dict[name] = v
        dataset = datasets.KITTIRAWDataset(opt.data_path, filenames,
                                           opt.height, opt.width,
                                           [0], 4, is_train=False)
        dataloader = DataLoader(dataset, 8, shuffle=False, num_workers=opt.num_workers,
                                pin_memory=True, drop_last=False)

        depth = networks.DeepNet('mpvitnet')
        depth.load_state_dict({k: v for k, v in new_dict.items() if k in depth.state_dict()})
        #depth.load_state_dict({k: v for k, v in new_dict.items() if k in depth.state_dict()})

        depth.cuda()
        depth.eval()

        pred_disps = []

        # Report width x height, matching evaluate_depth.py
        print("-> Computing predictions with size {}x{}".format(
            opt.width, opt.height))

        with torch.no_grad():
            for data in dataloader:
                input_color = data[("color", 0, 0)].cuda()

                if opt.post_process:
                    # Post-processed results require each image to have two forward passes
                    input_color = torch.cat((input_color, torch.flip(input_color, [3])), 0)

                output = depth(input_color)

                pred_disp, _ = disp_to_depth(output[("disp", 0)], opt.min_depth, opt.max_depth)
                pred_disp = pred_disp.cpu()[:, 0].numpy()

                if opt.post_process:
                    N = pred_disp.shape[0] // 2
                    pred_disp = batch_post_process_disparity(pred_disp[:N], pred_disp[N:, :, ::-1])

                pred_disps.append(pred_disp)

        pred_disps = np.concatenate(pred_disps)

    else:
        # Load predictions from file
        print("-> Loading predictions from {}".format(opt.ext_disp_to_eval))
        pred_disps = np.load(opt.ext_disp_to_eval)

        if opt.eval_eigen_to_benchmark:
            eigen_to_benchmark_ids = np.load(
                os.path.join(splits_dir, "benchmark", "eigen_to_benchmark_ids.npy"))

            pred_disps = pred_disps[eigen_to_benchmark_ids]

    if opt.save_pred_disps:
        output_path = os.path.join(
            opt.load_weights_folder, "disps_{}_split.npy".format(opt.eval_split))
        print("-> Saving predicted disparities to ", output_path)
        np.save(output_path, pred_disps)

    if opt.no_eval:
        print("-> Evaluation disabled. Done.")
        quit()

    elif opt.eval_split == 'benchmark':
        save_dir = os.path.join(opt.load_weights_folder, "benchmark_predictions")
        print("-> Saving out benchmark predictions to {}".format(save_dir))
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        for idx in range(len(pred_disps)):
            disp_resized = cv2.resize(pred_disps[idx], (1216, 352))
            depth = STEREO_SCALE_FACTOR / disp_resized
            depth = np.clip(depth, 0, 80)
            depth = np.uint16(depth * 256)
            save_path = os.path.join(save_dir, "{:010d}.png".format(idx))
            cv2.imwrite(save_path, depth)

        print("-> No ground truth is available for the KITTI benchmark, so not evaluating. Done.")
        quit()

    gt_path = os.path.join(splits_dir, opt.eval_split, "gt_depths.npz")
    gt_depths = np.load(gt_path, fix_imports=True, allow_pickle=True, encoding='latin1')["data"]

    print("-> Evaluating")

    if opt.eval_stereo:
        print(" Stereo evaluation - "
              "disabling median scaling, scaling by {}".format(STEREO_SCALE_FACTOR))
        opt.disable_median_scaling = True
        opt.pred_depth_scale_factor = STEREO_SCALE_FACTOR
    else:
        print(" Mono evaluation - using median scaling")

    errors = []
    ratios = []

    for i in range(pred_disps.shape[0]):

        gt_depth = gt_depths[i]
        gt_height, gt_width = gt_depth.shape[:2]

        pred_disp = pred_disps[i]
        pred_disp = cv2.resize(pred_disp, (gt_width, gt_height))
        pred_depth = 1 / pred_disp

        if opt.eval_split == "eigen":
            mask = np.logical_and(gt_depth > MIN_DEPTH, gt_depth < MAX_DEPTH)

            crop = np.array([0.40810811 * gt_height, 0.99189189 * gt_height,
                             0.03594771 * gt_width, 0.96405229 * gt_width]).astype(np.int32)
            crop_mask = np.zeros(mask.shape)
            crop_mask[crop[0]:crop[1], crop[2]:crop[3]] = 1
            mask = np.logical_and(mask, crop_mask)

        else:
            mask = gt_depth > 0

        pred_depth = pred_depth[mask]
        gt_depth = gt_depth[mask]

        pred_depth *= opt.pred_depth_scale_factor
        if not opt.disable_median_scaling:
            ratio = np.median(gt_depth) / np.median(pred_depth)
            ratios.append(ratio)
            pred_depth *= ratio

        pred_depth[pred_depth < MIN_DEPTH] = MIN_DEPTH
        pred_depth[pred_depth > MAX_DEPTH] = MAX_DEPTH

        errors.append(compute_errors(gt_depth, pred_depth))

    if not opt.disable_median_scaling:
        ratios = np.array(ratios)
        med = np.median(ratios)
        print(" Scaling ratios | med: {:0.3f} | std: {:0.3f}".format(med, np.std(ratios / med)))

    mean_errors = np.array(errors).mean(0)

    save_dir = opt.load_weights_folder[:-7]
    results_edit = open('results.txt', mode='a')
    results_edit.write("\n " + 'model_name: %s ' % (opt.load_weights_folder))
    results_edit.write("\n " + ("{:>8} | " * 7).format("abs_rel", "sq_rel", "rmse", "rmse_log", "a1", "a2", "a3"))
    results_edit.write("\n " + ("&{: 8.3f} " * 7).format(*mean_errors.tolist()) + "\\\\")
    results_edit.close()
    print("\n " + ("{:>8} | " * 7).format("abs_rel", "sq_rel", "rmse", "rmse_log", "a1", "a2", "a3"))
    print(("&{: 8.3f} " * 7).format(*mean_errors.tolist()) + "\\\\")
    print("\n-> Done!")

if __name__ == "__main__":
    options = MonodepthOptions()
    evaluate(options.parse())
--------------------------------------------------------------------------------
/fig/kittiandds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zxcqlf/MonoViT/3960e94ce4980ffb7dabc879bd5566323167126f/fig/kittiandds.png
--------------------------------------------------------------------------------
/networks/__init__.py:
--------------------------------------------------------------------------------
#from .resnet_encoder import ResnetEncoder
#from .pose_decoder import PoseDecoder
#from .pose_cnn import PoseCNN
from .hr_decoder import DepthDecoder
from .mpvit import *
from .nets import DeepNet

--------------------------------------------------------------------------------
/networks/hr_decoder.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import, division, print_function

import numpy as np
import torch
import torch.nn as nn
from collections import OrderedDict
from .hr_layers import *


class DepthDecoder(nn.Module):
    def __init__(self, ch_enc = [64,128,216,288,288], scales=range(4),num_ch_enc = [ 64, 64, 128, 256, 512 ], num_output_channels=1):
        super(DepthDecoder, self).__init__()
        self.num_output_channels = num_output_channels
        self.num_ch_enc = num_ch_enc
        self.ch_enc = ch_enc
        self.scales = scales
        self.num_ch_dec = np.array([16, 32, 64, 128, 256])
        self.convs = nn.ModuleDict()

        # decoder
        self.convs = nn.ModuleDict()

        # feature fusion
        self.convs["f4"] = Attention_Module(self.ch_enc[4] , num_ch_enc[4])
        self.convs["f3"] = Attention_Module(self.ch_enc[3] , num_ch_enc[3])
        self.convs["f2"] = Attention_Module(self.ch_enc[2] , num_ch_enc[2])
        self.convs["f1"] = Attention_Module(self.ch_enc[1] , num_ch_enc[1])

        self.all_position = ["01", "11", "21", "31", "02", "12", "22", "03", "13", "04"]
        self.attention_position = ["31", "22", "13", "04"]
        self.non_attention_position = ["01", "11", "21", "02", "12", "03"]

        for j in range(5):
            for i in range(5 - j):
                # upconv 0
                num_ch_in = num_ch_enc[i]
                if i == 0 and j != 0:
                    num_ch_in /= 2
                num_ch_out = num_ch_in / 2
                self.convs["X_{}{}_Conv_0".format(i, j)] = ConvBlock(num_ch_in, num_ch_out)

                # X_04 upconv 1, only add X_04 convolution
                if i == 0 and j == 4:
                    num_ch_in = num_ch_out
                    num_ch_out = self.num_ch_dec[i]
                    self.convs["X_{}{}_Conv_1".format(i, j)] = ConvBlock(num_ch_in, num_ch_out)

        # declare fSEModule and original module
        for index in self.attention_position:
            row = int(index[0])
            col = int(index[1])
            self.convs["X_" + index + "_attention"] = fSEModule(num_ch_enc[row + 1] // 2, self.num_ch_enc[row]
                                                                + self.num_ch_dec[row + 1] * (col - 1))
        for index in self.non_attention_position:
            row = int(index[0])
            col = int(index[1])
            if col == 1:
                self.convs["X_{}{}_Conv_1".format(row + 1, col - 1)] = ConvBlock(num_ch_enc[row + 1] // 2 +
                                                                                 self.num_ch_enc[row], self.num_ch_dec[row + 1])
            else:
                self.convs["X_"+index+"_downsample"] = Conv1x1(num_ch_enc[row+1] // 2 + self.num_ch_enc[row]
                                                                  + self.num_ch_dec[row+1]*(col-1), self.num_ch_dec[row + 1] * 2)
                self.convs["X_{}{}_Conv_1".format(row + 1, col - 1)] = ConvBlock(self.num_ch_dec[row + 1] * 2, self.num_ch_dec[row + 1])

        for i in range(4):
            self.convs["dispconv{}".format(i)] = Conv3x3(self.num_ch_dec[i], self.num_output_channels)

        self.decoder = nn.ModuleList(list(self.convs.values()))
        self.sigmoid = nn.Sigmoid()

    def nestConv(self, conv, high_feature, low_features):
        conv_0 = conv[0]
        conv_1 = conv[1]
        assert isinstance(low_features, list)
        high_features = [upsample(conv_0(high_feature))]
        for feature in low_features:
            high_features.append(feature)
        high_features = torch.cat(high_features, 1)
        if len(conv) == 3:
            high_features = conv[2](high_features)
        return conv_1(high_features)

    def forward(self, input_features):
        outputs = {}
        feat = {}
        feat[4] = self.convs["f4"](input_features[4])
        feat[3] = self.convs["f3"](input_features[3])
        feat[2] = self.convs["f2"](input_features[2])
        feat[1] = self.convs["f1"](input_features[1])
        feat[0] = input_features[0]

        features = {}
        for i in range(5):
            features["X_{}0".format(i)] = feat[i]
        # Network architecture
        for index in self.all_position:
            row = int(index[0])
            col = int(index[1])

            low_features = []
            for i in range(col):
                low_features.append(features["X_{}{}".format(row, i)])

            # add fSE block to decoder
            if index in self.attention_position:
                features["X_"+index] = self.convs["X_" + index + "_attention"](
                    self.convs["X_{}{}_Conv_0".format(row+1, col-1)](features["X_{}{}".format(row+1, col-1)]), low_features)
            elif index in self.non_attention_position:
                conv = [self.convs["X_{}{}_Conv_0".format(row + 1, col - 1)],
                        self.convs["X_{}{}_Conv_1".format(row + 1, col - 1)]]
                if col !=
1: 115 | conv.append(self.convs["X_" + index + "_downsample"]) 116 | features["X_" + index] = self.nestConv(conv, features["X_{}{}".format(row+1, col-1)], low_features) 117 | 118 | x = features["X_04"] 119 | x = self.convs["X_04_Conv_0"](x) 120 | x = self.convs["X_04_Conv_1"](upsample(x)) 121 | outputs[("disp", 0)] = self.sigmoid(self.convs["dispconv0"](x)) 122 | outputs[("disp", 1)] = self.sigmoid(self.convs["dispconv1"](features["X_04"])) 123 | outputs[("disp", 2)] = self.sigmoid(self.convs["dispconv2"](features["X_13"])) 124 | outputs[("disp", 3)] = self.sigmoid(self.convs["dispconv3"](features["X_22"])) 125 | return outputs 126 | 127 | -------------------------------------------------------------------------------- /networks/hr_layers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import numpy as np 4 | import math 5 | 6 | from matplotlib import pyplot as plt 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | 13 | 14 | def upsample(x): 15 | """Upsample input tensor by a factor of 2 16 | """ 17 | return F.interpolate(x, scale_factor=2, mode="nearest") 18 | 19 | 20 | def visual_feature(features,stage): 21 | feature_map = features.squeeze(0).cpu() 22 | n,h,w = feature_map.size() 23 | print(h,w) 24 | list_mean = [] 25 | #sum_feature_map = torch.sum(feature_map,0) 26 | sum_feature_map,_ = torch.max(feature_map,0) 27 | for i in range(n): 28 | list_mean.append(torch.mean(feature_map[i])) 29 | 30 | sum_mean = sum(list_mean) 31 | feature_map_weighted = torch.ones([n,h,w]) 32 | for i in range(n): 33 | feature_map_weighted[i,:,:] = (torch.mean(feature_map[i]) / sum_mean) * feature_map[i,:,:] 34 | sum_feature_map_weighted = torch.sum(feature_map_weighted,0) 35 | plt.imshow(sum_feature_map) 36 | #plt.savefig('feature_viz/{}_stage.png'.format(a)) 37 | plt.savefig('feature_viz/decoder_{}.png'.format(stage)) 38 | 
plt.imshow(sum_feature_map_weighted) 39 | #plt.savefig('feature_viz/{}_stage_weighted.png'.format(a)) 40 | plt.savefig('feature_viz/decoder_{}_weighted.png'.format(stage)) 41 | 42 | def depth_to_disp(depth, min_depth, max_depth): 43 | min_disp = 1 / max_depth 44 | max_disp = 1 / min_depth 45 | disp = 1 / depth - min_disp 46 | return disp / (max_disp - min_disp) 47 | 48 | def disp_to_depth(disp, min_depth, max_depth): 49 | """Convert network's sigmoid output into depth prediction 50 | The formula for this conversion is given in the 'additional considerations' 51 | section of the paper. 52 | """ 53 | min_disp = 1 / max_depth 54 | max_disp = 1 / min_depth 55 | scaled_disp = min_disp + (max_disp - min_disp) * disp 56 | depth = 1 / scaled_disp 57 | return scaled_disp, depth 58 | 59 | 60 | def transformation_from_parameters(axisangle, translation, invert=False): 61 | """Convert the network's (axisangle, translation) output into a 4x4 matrix 62 | """ 63 | R = rot_from_axisangle(axisangle) 64 | t = translation.clone() 65 | 66 | if invert: 67 | R = R.transpose(1, 2) 68 | t *= -1 69 | 70 | T = get_translation_matrix(t) 71 | 72 | if invert: 73 | M = torch.matmul(R, T) 74 | else: 75 | M = torch.matmul(T, R) 76 | 77 | return M 78 | 79 | 80 | def get_translation_matrix(translation_vector): 81 | """Convert a translation vector into a 4x4 transformation matrix 82 | """ 83 | T = torch.zeros(translation_vector.shape[0], 4, 4).to(device=translation_vector.device) 84 | 85 | t = translation_vector.contiguous().view(-1, 3, 1) 86 | 87 | T[:, 0, 0] = 1 88 | T[:, 1, 1] = 1 89 | T[:, 2, 2] = 1 90 | T[:, 3, 3] = 1 91 | T[:, :3, 3, None] = t 92 | 93 | return T 94 | 95 | 96 | def rot_from_axisangle(vec): 97 | """Convert an axisangle rotation into a 4x4 transformation matrix 98 | (adapted from https://github.com/Wallacoloo/printipi) 99 | Input 'vec' has to be Bx1x3 100 | """ 101 | angle = torch.norm(vec, 2, 2, True) 102 | axis = vec / (angle + 1e-7) 103 | 104 | ca = torch.cos(angle) 105 | sa = 
torch.sin(angle) 106 | C = 1 - ca 107 | 108 | x = axis[..., 0].unsqueeze(1) 109 | y = axis[..., 1].unsqueeze(1) 110 | z = axis[..., 2].unsqueeze(1) 111 | 112 | xs = x * sa 113 | ys = y * sa 114 | zs = z * sa 115 | xC = x * C 116 | yC = y * C 117 | zC = z * C 118 | xyC = x * yC 119 | yzC = y * zC 120 | zxC = z * xC 121 | 122 | rot = torch.zeros((vec.shape[0], 4, 4)).to(device=vec.device) 123 | 124 | rot[:, 0, 0] = torch.squeeze(x * xC + ca) 125 | rot[:, 0, 1] = torch.squeeze(xyC - zs) 126 | rot[:, 0, 2] = torch.squeeze(zxC + ys) 127 | rot[:, 1, 0] = torch.squeeze(xyC + zs) 128 | rot[:, 1, 1] = torch.squeeze(y * yC + ca) 129 | rot[:, 1, 2] = torch.squeeze(yzC - xs) 130 | rot[:, 2, 0] = torch.squeeze(zxC - ys) 131 | rot[:, 2, 1] = torch.squeeze(yzC + xs) 132 | rot[:, 2, 2] = torch.squeeze(z * zC + ca) 133 | rot[:, 3, 3] = 1 134 | 135 | return rot 136 | 137 | class ConvBlock(nn.Module): 138 | """Layer to perform a convolution followed by ELU 139 | """ 140 | def __init__(self, in_channels, out_channels): 141 | super(ConvBlock, self).__init__() 142 | 143 | self.conv = Conv3x3(in_channels, out_channels) 144 | self.nonlin = nn.ELU(inplace=True) 145 | 146 | def forward(self, x): 147 | out = self.conv(x) 148 | out = self.nonlin(out) 149 | return out 150 | 151 | 152 | class Conv3x3(nn.Module): 153 | """Layer to pad and convolve input 154 | """ 155 | def __init__(self, in_channels, out_channels, use_refl=True): 156 | super(Conv3x3, self).__init__() 157 | 158 | if use_refl: 159 | self.pad = nn.ReflectionPad2d(1) 160 | else: 161 | self.pad = nn.ZeroPad2d(1) 162 | self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3) 163 | 164 | def forward(self, x): 165 | out = self.pad(x) 166 | out = self.conv(out) 167 | return out 168 | 169 | class Conv1x1(nn.Module): 170 | def __init__(self, in_channels, out_channels): 171 | super(Conv1x1, self).__init__() 172 | 173 | self.conv = nn.Conv2d(in_channels, out_channels, 1, stride=1, bias=False) 174 | 175 | def forward(self, x): 176 | 
return self.conv(x) 177 | 178 | class ASPP(nn.Module): 179 | def __init__(self, in_channels, out_channels): 180 | super(ASPP, self).__init__() 181 | 182 | self.atrous_block1 = nn.Conv2d(in_channels, out_channels, 1, 1) 183 | self.atrous_block6 = nn.Conv2d(in_channels, out_channels, 3, 1, padding=6, dilation=6) 184 | self.atrous_block12 = nn.Conv2d(in_channels, out_channels, 3, 1, padding=12, dilation=12) 185 | self.atrous_block18 = nn.Conv2d(in_channels, out_channels, 3, 1, padding=18, dilation=18) 186 | 187 | self.conv1x1 = nn.Conv2d(out_channels*4, out_channels, 1, 1) 188 | 189 | def forward(self, features): 190 | features_1 = self.atrous_block18(features[0]) 191 | features_2 = self.atrous_block12(features[1]) 192 | features_3 = self.atrous_block6(features[2]) 193 | features_4 = self.atrous_block1(features[3]) 194 | 195 | output_feature = [features_1, features_2, features_3, features_4] 196 | output_feature = torch.cat(output_feature, 1) 197 | 198 | return self.conv1x1(output_feature) 199 | 200 | class BackprojectDepth(nn.Module): 201 | """Layer to transform a depth image into a point cloud 202 | """ 203 | def __init__(self, batch_size, height, width): 204 | super(BackprojectDepth, self).__init__() 205 | 206 | self.batch_size = batch_size 207 | self.height = height 208 | self.width = width 209 | 210 | # Prepare Coordinates shape [b,3,h*w] 211 | meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy') 212 | self.id_coords = np.stack(meshgrid, axis=0).astype(np.float32) 213 | self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords), 214 | requires_grad=False) 215 | 216 | self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width), 217 | requires_grad=False) 218 | 219 | self.pix_coords = torch.unsqueeze(torch.stack( 220 | [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0) 221 | self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1) 222 | self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, 
self.ones], 1), 223 | requires_grad=False) 224 | 225 | def forward(self, depth, inv_K): 226 | cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords) 227 | cam_points = depth.view(self.batch_size, 1, -1) * cam_points 228 | cam_points = torch.cat([cam_points, self.ones], 1) 229 | 230 | return cam_points 231 | 232 | 233 | class Project3D(nn.Module): 234 | """Layer which projects 3D points into a camera with intrinsics K and at position T 235 | """ 236 | def __init__(self, batch_size, height, width, eps=1e-7): 237 | super(Project3D, self).__init__() 238 | 239 | self.batch_size = batch_size 240 | self.height = height 241 | self.width = width 242 | self.eps = eps 243 | 244 | def forward(self, points, K, T): 245 | P = torch.matmul(K, T)[:, :3, :] 246 | 247 | cam_points = torch.matmul(P, points) 248 | 249 | pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps) 250 | pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width) 251 | pix_coords = pix_coords.permute(0, 2, 3, 1) 252 | # normalize 253 | pix_coords[..., 0] /= self.width - 1 254 | pix_coords[..., 1] /= self.height - 1 255 | pix_coords = (pix_coords - 0.5) * 2 256 | return pix_coords 257 | 258 | 259 | def upsample(x): 260 | """Upsample input tensor by a factor of 2 261 | """ 262 | return F.interpolate(x, scale_factor=2, mode="nearest") 263 | 264 | def get_smooth_loss(disp, img): 265 | """Computes the smoothness loss for a disparity image 266 | The color image is used for edge-aware smoothness 267 | """ 268 | grad_disp_x = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:]) 269 | grad_disp_y = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :]) 270 | 271 | grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True) 272 | grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True) 273 | 274 | grad_disp_x *= torch.exp(-grad_img_x) 275 | grad_disp_y *= torch.exp(-grad_img_y) 276 | 277 | return grad_disp_x.mean() + 
grad_disp_y.mean() 278 | 279 | 280 | class SSIM(nn.Module): 281 | """Layer to compute the SSIM loss between a pair of images 282 | """ 283 | def __init__(self): 284 | super(SSIM, self).__init__() 285 | self.mu_x_pool = nn.AvgPool2d(3, 1) 286 | self.mu_y_pool = nn.AvgPool2d(3, 1) 287 | self.sig_x_pool = nn.AvgPool2d(3, 1) 288 | self.sig_y_pool = nn.AvgPool2d(3, 1) 289 | self.sig_xy_pool = nn.AvgPool2d(3, 1) 290 | 291 | self.refl = nn.ReflectionPad2d(1) 292 | 293 | self.C1 = 0.01 ** 2 294 | self.C2 = 0.03 ** 2 295 | 296 | def forward(self, x, y): 297 | x = self.refl(x) 298 | y = self.refl(y) 299 | 300 | mu_x = self.mu_x_pool(x) 301 | mu_y = self.mu_y_pool(y) 302 | 303 | sigma_x = self.sig_x_pool(x ** 2) - mu_x ** 2 304 | sigma_y = self.sig_y_pool(y ** 2) - mu_y ** 2 305 | sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y 306 | 307 | SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2) 308 | SSIM_d = (mu_x ** 2 + mu_y ** 2 + self.C1) * (sigma_x + sigma_y + self.C2) 309 | 310 | return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1) 311 | 312 | 313 | def compute_depth_errors(gt, pred): 314 | """Computation of error metrics between predicted and ground truth depths 315 | """ 316 | thresh = torch.max((gt / pred), (pred / gt)) 317 | a1 = (thresh < 1.25 ).float().mean() 318 | a2 = (thresh < 1.25 ** 2).float().mean() 319 | a3 = (thresh < 1.25 ** 3).float().mean() 320 | 321 | rmse = (gt - pred) ** 2 322 | rmse = torch.sqrt(rmse.mean()) 323 | 324 | rmse_log = (torch.log(gt) - torch.log(pred)) ** 2 325 | rmse_log = torch.sqrt(rmse_log.mean()) 326 | 327 | abs_rel = torch.mean(torch.abs(gt - pred) / gt) 328 | 329 | sq_rel = torch.mean((gt - pred) ** 2 / gt) 330 | 331 | return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3 332 | 333 | class SE_block(nn.Module): 334 | def __init__(self, in_channel, visual_weights = False, reduction = 16 ): 335 | super(SE_block, self).__init__() 336 | reduction = reduction 337 | in_channel = in_channel 338 | self.avg_pool = 
nn.AdaptiveAvgPool2d(1) 339 | self.max_pool = nn.AdaptiveMaxPool2d(1) 340 | self.fc = nn.Sequential( 341 | nn.Linear(in_channel, in_channel // reduction, bias = False), 342 | nn.ReLU(inplace = True), 343 | nn.Linear(in_channel // reduction, in_channel, bias = False) 344 | ) 345 | self.sigmoid = nn.Sigmoid() 346 | self.relu = nn.ReLU(inplace = True) 347 | self.vis = False 348 | 349 | def forward(self, in_feature): 350 | 351 | b,c,_,_ = in_feature.size() 352 | output_weights_avg = self.avg_pool(in_feature).view(b,c) 353 | output_weights_max = self.max_pool(in_feature).view(b,c) 354 | output_weights_avg = self.fc(output_weights_avg).view(b,c,1,1) 355 | output_weights_max = self.fc(output_weights_max).view(b,c,1,1) 356 | output_weights = output_weights_avg + output_weights_max 357 | output_weights = self.sigmoid(output_weights) 358 | return output_weights.expand_as(in_feature) * in_feature 359 | 360 | ## ChannelAttention 361 | class ChannelAttention(nn.Module): 362 | def __init__(self, in_planes, ratio=16): 363 | super(ChannelAttention, self).__init__() 364 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 365 | 366 | self.fc = nn.Sequential( 367 | nn.Linear(in_planes,in_planes // ratio, bias = False), 368 | nn.ReLU(inplace = True), 369 | nn.Linear(in_planes // ratio, in_planes, bias = False) 370 | ) 371 | self.sigmoid = nn.Sigmoid() 372 | for m in self.modules(): 373 | if isinstance(m, nn.Conv2d): 374 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 375 | 376 | def forward(self, in_feature): 377 | x = in_feature 378 | b, c, _, _ = in_feature.size() 379 | avg_out = self.fc(self.avg_pool(x).view(b,c)).view(b, c, 1, 1) 380 | out = avg_out 381 | return self.sigmoid(out).expand_as(in_feature) * in_feature 382 | 383 | ## SpatialAttention 384 | 385 | class SpatialAttention(nn.Module): 386 | def __init__(self, kernel_size=7): 387 | super(SpatialAttention, self).__init__() 388 | 389 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
390 | self.sigmoid = nn.Sigmoid() 391 | 392 | for m in self.modules(): 393 | if isinstance(m, nn.Conv2d): 394 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 395 | def forward(self, in_feature): 396 | x = in_feature 397 | avg_out = torch.mean(x, dim=1, keepdim=True) 398 | max_out, _ = torch.max(x, dim=1, keepdim=True) 399 | x = torch.cat([avg_out, max_out], dim=1) 400 | #x = avg_out 401 | #x = max_out 402 | x = self.conv1(x) 403 | return self.sigmoid(x).expand_as(in_feature) * in_feature 404 | 405 | 406 | #CS means channel-spatial 407 | class CS_Block(nn.Module): 408 | def __init__(self, in_channel, reduction = 16 ): 409 | super(CS_Block, self).__init__() 410 | 411 | reduction = reduction 412 | in_channel = in_channel 413 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 414 | self.max_pool = nn.AdaptiveMaxPool2d(1) 415 | self.fc = nn.Sequential( 416 | nn.Linear(in_channel, in_channel // reduction, bias = False), 417 | nn.ReLU(inplace = True), 418 | nn.Linear(in_channel // reduction, in_channel, bias = False) 419 | ) 420 | self.sigmoid = nn.Sigmoid() 421 | ## Spatial_Block 422 | self.conv = nn.Conv2d(2,1,kernel_size = 1, bias = False) 423 | #self.conv = nn.Conv2d(1,1,kernel_size = 1, bias = False) 424 | self.relu = nn.ReLU(inplace = True) 425 | 426 | def forward(self, in_feature): 427 | 428 | b,c,_,_ = in_feature.size() 429 | 430 | 431 | output_weights_avg = self.avg_pool(in_feature).view(b,c) 432 | output_weights_max = self.max_pool(in_feature).view(b,c) 433 | 434 | output_weights_avg = self.fc(output_weights_avg).view(b,c,1,1) 435 | output_weights_max = self.fc(output_weights_max).view(b,c,1,1) 436 | 437 | output_weights = output_weights_avg + output_weights_max 438 | 439 | output_weights = self.sigmoid(output_weights) 440 | out_feature_1 = output_weights.expand_as(in_feature) * in_feature 441 | 442 | ## Spatial_Block 443 | in_feature_avg = torch.mean(out_feature_1,1,True) 444 | in_feature_max,_ = torch.max(out_feature_1,1,True) 445 | 
mixed_feature = torch.cat([in_feature_avg,in_feature_max],1) 446 | spatial_attention = self.sigmoid(self.conv(mixed_feature)) 447 | out_feature = spatial_attention.expand_as(out_feature_1) * out_feature_1 448 | ######################### 449 | 450 | return out_feature 451 | 452 | class Attention_Module(nn.Module): 453 | def __init__(self, high_feature_channel, output_channel = None): 454 | super(Attention_Module, self).__init__() 455 | in_channel = high_feature_channel 456 | out_channel = high_feature_channel 457 | if output_channel is not None: 458 | out_channel = output_channel 459 | channel = in_channel 460 | self.ca = ChannelAttention(channel) 461 | #self.sa = SpatialAttention() 462 | #self.cs = CS_Block(channel) 463 | self.conv_se = nn.Conv2d(in_channels = in_channel, out_channels = out_channel, kernel_size = 3, stride = 1, padding = 1 ) 464 | self.relu = nn.ReLU(inplace = True) 465 | 466 | def forward(self, high_features): 467 | 468 | features = high_features 469 | 470 | features = self.ca(features) 471 | #features = self.sa(features) 472 | #features = self.cs(features) 473 | 474 | return self.relu(self.conv_se(features)) 475 | 476 | class fSEModule(nn.Module): 477 | def __init__(self, high_feature_channel, low_feature_channels, output_channel=None): 478 | super(fSEModule, self).__init__() 479 | in_channel = high_feature_channel + low_feature_channels 480 | out_channel = high_feature_channel 481 | if output_channel is not None: 482 | out_channel = output_channel 483 | reduction = 16 484 | channel = in_channel 485 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 486 | 487 | self.fc = nn.Sequential( 488 | nn.Linear(channel, channel // reduction, bias=False), 489 | nn.ReLU(inplace=True), 490 | nn.Linear(channel // reduction, channel, bias=False) 491 | ) 492 | 493 | self.sigmoid = nn.Sigmoid() 494 | 495 | self.conv_se = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=1, stride=1) 496 | self.relu = nn.ReLU(inplace=True) 497 | 498 | def 
forward(self, high_features, low_features): 499 | features = [upsample(high_features)] 500 | features += low_features 501 | features = torch.cat(features, 1) 502 | 503 | b, c, _, _ = features.size() 504 | y = self.avg_pool(features).view(b, c) 505 | y = self.fc(y).view(b, c, 1, 1) 506 | 507 | y = self.sigmoid(y) 508 | features = features * y.expand_as(features) 509 | 510 | return self.relu(self.conv_se(features)) -------------------------------------------------------------------------------- /networks/mpvit.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- 2 | # MPViT: Multi-Path Vision Transformer for Dense Prediction 3 | # Copyright (c) 2022 Electronics and Telecommunications Research Institute (ETRI). 4 | # All Rights Reserved. 5 | # Written by Youngwan Lee 6 | # This source code is licensed(Dual License(GPL3.0 & Commercial)) under the license found in the 7 | # LICENSE file in the root directory of this source tree. 
8 | # -------------------------------------------------------------------------------- 9 | # References: 10 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 11 | # CoaT: https://github.com/mlpc-ucsd/CoaT 12 | # -------------------------------------------------------------------------------- 13 | 14 | 15 | import numpy as np 16 | import math 17 | 18 | import torch 19 | 20 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 21 | from timm.models.layers import DropPath, trunc_normal_ 22 | 23 | from einops import rearrange 24 | from functools import partial 25 | from torch import nn, einsum 26 | from torch.nn.modules.batchnorm import _BatchNorm 27 | 28 | from mmcv.runner import load_checkpoint,load_state_dict 29 | from mmcv.cnn import build_norm_layer 30 | 31 | from mmseg.utils import get_root_logger 32 | from mmseg.models.builder import BACKBONES 33 | 34 | __all__ = [ 35 | "mpvit_tiny", 36 | "mpvit_xsmall", 37 | "mpvit_small", 38 | "mpvit_base", 39 | ] 40 | 41 | def _cfg_mpvit(url="", **kwargs): 42 | return { 43 | "url": url, 44 | "num_classes": 1000, 45 | "input_size": (3, 224, 224), 46 | "pool_size": None, 47 | "crop_pct": 0.9, 48 | "interpolation": "bicubic", 49 | "mean": IMAGENET_DEFAULT_MEAN, 50 | "std": IMAGENET_DEFAULT_STD, 51 | "first_conv": "patch_embed.proj", 52 | "classifier": "head", 53 | **kwargs, 54 | } 55 | 56 | 57 | class Mlp(nn.Module): 58 | """Feed-forward network (FFN, a.k.a. 
MLP) class.""" 59 | 60 | def __init__( 61 | self, 62 | in_features, 63 | hidden_features=None, 64 | out_features=None, 65 | act_layer=nn.GELU, 66 | drop=0.0, 67 | ): 68 | super().__init__() 69 | out_features = out_features or in_features 70 | hidden_features = hidden_features or in_features 71 | self.fc1 = nn.Linear(in_features, hidden_features) 72 | self.act = act_layer() 73 | self.fc2 = nn.Linear(hidden_features, out_features) 74 | self.drop = nn.Dropout(drop) 75 | 76 | def forward(self, x): 77 | x = self.fc1(x) 78 | x = self.act(x) 79 | x = self.drop(x) 80 | x = self.fc2(x) 81 | x = self.drop(x) 82 | return x 83 | 84 | 85 | class Conv2d_BN(nn.Module): 86 | def __init__( 87 | self, 88 | in_ch, 89 | out_ch, 90 | kernel_size=1, 91 | stride=1, 92 | pad=0, 93 | dilation=1, 94 | groups=1, 95 | bn_weight_init=1, 96 | act_layer=None, 97 | norm_cfg=dict(type="BN"), 98 | ): 99 | super().__init__() 100 | # self.add_module('c', torch.nn.Conv2d( 101 | # a, b, ks, stride, pad, dilation, groups, bias=False)) 102 | self.conv = torch.nn.Conv2d( 103 | in_ch, out_ch, kernel_size, stride, pad, dilation, groups, bias=False 104 | ) 105 | self.bn = build_norm_layer(norm_cfg, out_ch)[1] 106 | 107 | torch.nn.init.constant_(self.bn.weight, bn_weight_init) 108 | torch.nn.init.constant_(self.bn.bias, 0) 109 | for m in self.modules(): 110 | if isinstance(m, nn.Conv2d): 111 | # Note that there is no bias due to BN 112 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 113 | m.weight.data.normal_(mean=0.0, std=np.sqrt(2.0 / fan_out)) 114 | 115 | self.act_layer = act_layer() if act_layer is not None else nn.Identity() 116 | 117 | def forward(self, x): 118 | x = self.conv(x) 119 | x = self.bn(x) 120 | x = self.act_layer(x) 121 | 122 | return x 123 | 124 | 125 | class DWConv2d_BN(nn.Module): 126 | """ 127 | Depthwise Separable Conv 128 | """ 129 | 130 | def __init__( 131 | self, 132 | in_ch, 133 | out_ch, 134 | kernel_size=1, 135 | stride=1, 136 | norm_layer=nn.BatchNorm2d, 137 | 
act_layer=nn.Hardswish, 138 | bn_weight_init=1, 139 | norm_cfg=dict(type="BN"), 140 | ): 141 | super().__init__() 142 | 143 | # dw 144 | self.dwconv = nn.Conv2d( 145 | in_ch, 146 | out_ch, 147 | kernel_size, 148 | stride, 149 | (kernel_size - 1) // 2, 150 | groups=out_ch, 151 | bias=False, 152 | ) 153 | # pw-linear 154 | self.pwconv = nn.Conv2d(out_ch, out_ch, 1, 1, 0, bias=False) 155 | self.bn = build_norm_layer(norm_cfg, out_ch)[1] 156 | self.act = act_layer() if act_layer is not None else nn.Identity() 157 | 158 | for m in self.modules(): 159 | if isinstance(m, nn.Conv2d): 160 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 161 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 162 | if m.bias is not None: 163 | m.bias.data.zero_() 164 | elif isinstance(m, nn.BatchNorm2d): 165 | m.weight.data.fill_(bn_weight_init) 166 | m.bias.data.zero_() 167 | 168 | def forward(self, x): 169 | 170 | x = self.dwconv(x) 171 | x = self.pwconv(x) 172 | x = self.bn(x) 173 | x = self.act(x) 174 | 175 | return x 176 | 177 | 178 | class DWCPatchEmbed(nn.Module): 179 | """ 180 | Depthwise Convolutional Patch Embedding layer 181 | Image to Patch Embedding 182 | """ 183 | 184 | def __init__( 185 | self, 186 | in_chans=3, 187 | embed_dim=768, 188 | patch_size=16, 189 | stride=1, 190 | pad=0, 191 | act_layer=nn.Hardswish, 192 | norm_cfg=dict(type="BN"), 193 | ): 194 | super().__init__() 195 | 196 | # TODO : confirm whether act_layer is effective or not 197 | self.patch_conv = DWConv2d_BN( 198 | in_chans, 199 | embed_dim, 200 | kernel_size=patch_size, 201 | stride=stride, 202 | act_layer=nn.Hardswish, 203 | norm_cfg=norm_cfg, 204 | ) 205 | 206 | def forward(self, x): 207 | x = self.patch_conv(x) 208 | 209 | return x 210 | 211 | 212 | class Patch_Embed_stage(nn.Module): 213 | def __init__(self, embed_dim, num_path=4, isPool=False, norm_cfg=dict(type="BN")): 214 | super(Patch_Embed_stage, self).__init__() 215 | 216 | self.patch_embeds = nn.ModuleList( 217 | [ 218 | DWCPatchEmbed( 219 | 
in_chans=embed_dim, 220 | embed_dim=embed_dim, 221 | patch_size=3, 222 | stride=2 if isPool and idx == 0 else 1, 223 | pad=1, 224 | norm_cfg=norm_cfg, 225 | ) 226 | for idx in range(num_path) 227 | ] 228 | ) 229 | 230 | # scale 231 | 232 | def forward(self, x): 233 | att_inputs = [] 234 | for pe in self.patch_embeds: 235 | x = pe(x) 236 | att_inputs.append(x) 237 | 238 | return att_inputs 239 | 240 | 241 | class ConvPosEnc(nn.Module): 242 | """Convolutional Position Encoding. 243 | Note: This module is similar to the conditional position encoding in CPVT. 244 | """ 245 | 246 | def __init__(self, dim, k=3): 247 | super(ConvPosEnc, self).__init__() 248 | 249 | self.proj = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim) 250 | 251 | def forward(self, x, size): 252 | B, N, C = x.shape 253 | H, W = size 254 | 255 | feat = x.transpose(1, 2).contiguous().view(B, C, H, W) 256 | x = self.proj(feat) + feat 257 | x = x.flatten(2).transpose(1, 2).contiguous() 258 | 259 | return x 260 | 261 | 262 | class ConvRelPosEnc(nn.Module): 263 | """Convolutional relative position encoding.""" 264 | def __init__(self, Ch, h, window): 265 | """Initialization. 266 | 267 | Ch: Channels per head. 268 | h: Number of heads. 269 | window: Window size(s) in convolutional relative positional encoding. 270 | It can have two forms: 271 | 1. An integer of window size, which assigns all attention heads 272 | with the same window size in ConvRelPosEnc. 273 | 2. A dict mapping window size to #attention head splits 274 | (e.g. {window size 1: #attention head split 1, window size 275 | 2: #attention head split 2}) 276 | It will apply different window size to 277 | the attention head splits. 278 | """ 279 | super().__init__() 280 | 281 | if isinstance(window, int): 282 | # Set the same window size for all attention heads. 
window = {window: h} 284 | self.window = window 285 | elif isinstance(window, dict): 286 | self.window = window 287 | else: 288 | raise ValueError() 289 | 290 | self.conv_list = nn.ModuleList() 291 | self.head_splits = [] 292 | for cur_window, cur_head_split in window.items(): 293 | dilation = 1 # Use dilation=1 by default. 294 | padding_size = (cur_window + (cur_window - 1) * 295 | (dilation - 1)) // 2 296 | cur_conv = nn.Conv2d( 297 | cur_head_split * Ch, 298 | cur_head_split * Ch, 299 | kernel_size=(cur_window, cur_window), 300 | padding=(padding_size, padding_size), 301 | dilation=(dilation, dilation), 302 | groups=cur_head_split * Ch, 303 | ) 304 | self.conv_list.append(cur_conv) 305 | self.head_splits.append(cur_head_split) 306 | self.channel_splits = [x * Ch for x in self.head_splits] 307 | 308 | def forward(self, q, v, size): 309 | """forward function""" 310 | B, h, N, Ch = q.shape 311 | H, W = size 312 | 313 | # We don't use CLS_TOKEN 314 | q_img = q 315 | v_img = v 316 | 317 | # Shape: [B, h, H*W, Ch] -> [B, h*Ch, H, W]. 318 | v_img = rearrange(v_img, "B h (H W) Ch -> B (h Ch) H W", H=H, W=W) 319 | # Split according to channels. 320 | v_img_list = torch.split(v_img, self.channel_splits, dim=1) 321 | conv_v_img_list = [ 322 | conv(x) for conv, x in zip(self.conv_list, v_img_list) 323 | ] 324 | conv_v_img = torch.cat(conv_v_img_list, dim=1) 325 | # Shape: [B, h*Ch, H, W] -> [B, h, H*W, Ch].
326 | conv_v_img = rearrange(conv_v_img, "B (h Ch) H W -> B h (H W) Ch", h=h) 327 | 328 | EV_hat_img = q_img * conv_v_img 329 | EV_hat = EV_hat_img 330 | return EV_hat 331 | 332 | 333 | class FactorAtt_ConvRelPosEnc(nn.Module): 334 | """Factorized attention with convolutional relative position encoding class.""" 335 | 336 | def __init__( 337 | self, 338 | dim, 339 | num_heads=8, 340 | qkv_bias=False, 341 | qk_scale=None, 342 | attn_drop=0.0, 343 | proj_drop=0.0, 344 | shared_crpe=None, 345 | ): 346 | super().__init__() 347 | self.num_heads = num_heads 348 | head_dim = dim // num_heads 349 | self.scale = qk_scale or head_dim ** -0.5 350 | 351 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 352 | self.attn_drop = nn.Dropout(attn_drop) # Note: attn_drop is actually not used. 353 | self.proj = nn.Linear(dim, dim) 354 | self.proj_drop = nn.Dropout(proj_drop) 355 | 356 | # Shared convolutional relative position encoding. 357 | self.crpe = shared_crpe 358 | 359 | def forward(self, x, size): 360 | B, N, C = x.shape 361 | 362 | # Generate Q, K, V. 363 | qkv = ( 364 | self.qkv(x) 365 | .reshape(B, N, 3, self.num_heads, C // self.num_heads) 366 | .permute(2, 0, 3, 1, 4) 367 | .contiguous() 368 | ) # Shape: [3, B, h, N, Ch]. 369 | q, k, v = qkv[0], qkv[1], qkv[2] # Shape: [B, h, N, Ch]. 370 | 371 | # Factorized attention. 372 | k_softmax = k.softmax(dim=2) # Softmax on dim N. 373 | k_softmax_T_dot_v = einsum( 374 | "b h n k, b h n v -> b h k v", k_softmax, v 375 | ) # Shape: [B, h, Ch, Ch]. 376 | factor_att = einsum( 377 | "b h n k, b h k v -> b h n v", q, k_softmax_T_dot_v 378 | ) # Shape: [B, h, N, Ch]. 379 | 380 | # Convolutional relative position encoding. 381 | crpe = self.crpe(q, v, size=size) # Shape: [B, h, N, Ch]. 382 | 383 | # Merge and reshape. 384 | x = self.scale * factor_att + crpe 385 | x = ( 386 | x.transpose(1, 2).reshape(B, N, C).contiguous() 387 | ) # Shape: [B, h, N, Ch] -> [B, N, h, Ch] -> [B, N, C]. 388 | 389 | # Output projection. 
390 | x = self.proj(x) 391 | x = self.proj_drop(x) 392 | 393 | return x 394 | 395 | 396 | class MHCABlock(nn.Module): 397 | def __init__( 398 | self, 399 | dim, 400 | num_heads, 401 | mlp_ratio=3, 402 | drop_path=0.0, 403 | qkv_bias=True, 404 | qk_scale=None, 405 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 406 | shared_cpe=None, 407 | shared_crpe=None, 408 | ): 409 | super().__init__() 410 | 411 | self.cpe = shared_cpe 412 | self.crpe = shared_crpe 413 | self.factoratt_crpe = FactorAtt_ConvRelPosEnc( 414 | dim, 415 | num_heads=num_heads, 416 | qkv_bias=qkv_bias, 417 | qk_scale=qk_scale, 418 | shared_crpe=shared_crpe, 419 | ) 420 | self.mlp = Mlp(in_features=dim, hidden_features=dim * mlp_ratio) 421 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 422 | 423 | self.norm1 = norm_layer(dim) 424 | self.norm2 = norm_layer(dim) 425 | 426 | def forward(self, x, size): 427 | # x.shape = [B, N, C] 428 | 429 | if self.cpe is not None: 430 | x = self.cpe(x, size) 431 | cur = self.norm1(x) 432 | x = x + self.drop_path(self.factoratt_crpe(cur, size)) 433 | 434 | cur = self.norm2(x) 435 | x = x + self.drop_path(self.mlp(cur)) 436 | return x 437 | 438 | 439 | class MHCAEncoder(nn.Module): 440 | def __init__( 441 | self, 442 | dim, 443 | num_layers=1, 444 | num_heads=8, 445 | mlp_ratio=3, 446 | drop_path_list=[], 447 | qk_scale=None, 448 | crpe_window={3: 2, 5: 3, 7: 3}, 449 | ): 450 | super().__init__() 451 | 452 | self.num_layers = num_layers 453 | self.cpe = ConvPosEnc(dim, k=3) 454 | self.crpe = ConvRelPosEnc(Ch=dim // num_heads, h=num_heads, window=crpe_window) 455 | self.MHCA_layers = nn.ModuleList( 456 | [ 457 | MHCABlock( 458 | dim, 459 | num_heads=num_heads, 460 | mlp_ratio=mlp_ratio, 461 | drop_path=drop_path_list[idx], 462 | qk_scale=qk_scale, 463 | shared_cpe=self.cpe, 464 | shared_crpe=self.crpe, 465 | ) 466 | for idx in range(self.num_layers) 467 | ] 468 | ) 469 | 470 | def forward(self, x, size): 471 | H, W = size 472 | B = x.shape[0] 
473 | # x's shape: [B, N, C] 474 | for layer in self.MHCA_layers: 475 | x = layer(x, (H, W)) 476 | 477 | # Reshape the output tokens: [B, N, C] -> [B, C, H, W] 478 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 479 | return x 480 | 481 | 482 | class ResBlock(nn.Module): 483 | def __init__( 484 | self, 485 | in_features, 486 | hidden_features=None, 487 | out_features=None, 488 | act_layer=nn.Hardswish, 489 | norm_cfg=dict(type="BN"), 490 | ): 491 | super().__init__() 492 | 493 | out_features = out_features or in_features 494 | hidden_features = hidden_features or in_features 495 | self.conv1 = Conv2d_BN( 496 | in_features, hidden_features, act_layer=act_layer, norm_cfg=norm_cfg 497 | ) 498 | self.dwconv = nn.Conv2d( 499 | hidden_features, 500 | hidden_features, 501 | 3, 502 | 1, 503 | 1, 504 | bias=False, 505 | groups=hidden_features, 506 | ) 507 | # self.norm = norm_layer(hidden_features) 508 | self.norm = build_norm_layer(norm_cfg, hidden_features)[1] 509 | self.act = act_layer() 510 | self.conv2 = Conv2d_BN(hidden_features, out_features, norm_cfg=norm_cfg) 511 | self.apply(self._init_weights) 512 | 513 | def _init_weights(self, m): 514 | if isinstance(m, nn.Conv2d): 515 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 516 | fan_out //= m.groups 517 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 518 | if m.bias is not None: 519 | m.bias.data.zero_() 520 | elif isinstance(m, nn.BatchNorm2d): 521 | m.weight.data.fill_(1) 522 | m.bias.data.zero_() 523 | 524 | def forward(self, x): 525 | identity = x 526 | feat = self.conv1(x) 527 | feat = self.dwconv(feat) 528 | feat = self.norm(feat) 529 | feat = self.act(feat) 530 | feat = self.conv2(feat) 531 | 532 | return identity + feat 533 | 534 | 535 | class MHCA_stage(nn.Module): 536 | def __init__( 537 | self, 538 | embed_dim, 539 | out_embed_dim, 540 | num_layers=1, 541 | num_heads=8, 542 | mlp_ratio=3, 543 | num_path=4, 544 | norm_cfg=dict(type="BN"), 545 | drop_path_list=[], 546 | ): 547 |
super().__init__() 548 | 549 | self.mhca_blks = nn.ModuleList( 550 | [ 551 | MHCAEncoder( 552 | embed_dim, 553 | num_layers, 554 | num_heads, 555 | mlp_ratio, 556 | drop_path_list=drop_path_list, 557 | ) 558 | for _ in range(num_path) 559 | ] 560 | ) 561 | 562 | self.InvRes = ResBlock( 563 | in_features=embed_dim, out_features=embed_dim, norm_cfg=norm_cfg 564 | ) 565 | self.aggregate = Conv2d_BN( 566 | embed_dim * (num_path + 1), 567 | out_embed_dim, 568 | act_layer=nn.Hardswish, 569 | norm_cfg=norm_cfg, 570 | ) 571 | 572 | def forward(self, inputs): 573 | att_outputs = [self.InvRes(inputs[0])] 574 | for x, encoder in zip(inputs, self.mhca_blks): 575 | # [B, C, H, W] -> [B, N, C] 576 | _, _, H, W = x.shape 577 | x = x.flatten(2).transpose(1, 2).contiguous() 578 | att_outputs.append(encoder(x, size=(H, W))) 579 | 580 | out_concat = torch.cat(att_outputs, dim=1) 581 | out = self.aggregate(out_concat) 582 | 583 | return out,att_outputs 584 | 585 | 586 | def dpr_generator(drop_path_rate, num_layers, num_stages): 587 | """ 588 | Generate drop path rate list following linear decay rule 589 | """ 590 | dpr_list = [x.item() for x in torch.linspace(0, drop_path_rate, sum(num_layers))] 591 | dpr = [] 592 | cur = 0 593 | for i in range(num_stages): 594 | dpr_per_stage = dpr_list[cur : cur + num_layers[i]] 595 | dpr.append(dpr_per_stage) 596 | cur += num_layers[i] 597 | 598 | return dpr 599 | 600 | 601 | @BACKBONES.register_module() 602 | class MPViT(nn.Module): 603 | """Multi-Path ViT class.""" 604 | 605 | def __init__( 606 | self, 607 | num_classes=80, 608 | in_chans=3, 609 | num_stages=4, 610 | num_layers=[1, 1, 1, 1], 611 | mlp_ratios=[8, 8, 4, 4], 612 | num_path=[4, 4, 4, 4], 613 | embed_dims=[64, 128, 256, 512], 614 | num_heads=[8, 8, 8, 8], 615 | drop_path_rate=0.2, 616 | norm_cfg=dict(type="BN"), 617 | norm_eval=False, 618 | pretrained=None, 619 | ): 620 | super().__init__() 621 | 622 | self.num_classes = num_classes 623 | self.num_stages = num_stages 624 | 
self.conv_norm_cfg = norm_cfg 625 | self.norm_eval = norm_eval 626 | 627 | dpr = dpr_generator(drop_path_rate, num_layers, num_stages) 628 | 629 | self.stem = nn.Sequential( 630 | Conv2d_BN( 631 | in_chans, 632 | embed_dims[0] // 2, 633 | kernel_size=3, 634 | stride=2, 635 | pad=1, 636 | act_layer=nn.Hardswish, 637 | norm_cfg=self.conv_norm_cfg, 638 | ), 639 | Conv2d_BN( 640 | embed_dims[0] // 2, 641 | embed_dims[0], 642 | kernel_size=3, 643 | stride=1, 644 | pad=1, 645 | act_layer=nn.Hardswish, 646 | norm_cfg=self.conv_norm_cfg, 647 | ), 648 | ) 649 | 650 | # Patch embeddings. 651 | self.patch_embed_stages = nn.ModuleList( 652 | [ 653 | Patch_Embed_stage( 654 | embed_dims[idx], 655 | num_path=num_path[idx], 656 | isPool= True, 657 | norm_cfg=self.conv_norm_cfg, 658 | ) 659 | for idx in range(self.num_stages) 660 | ] 661 | ) 662 | 663 | # Multi-Head Convolutional Self-Attention (MHCA) 664 | self.mhca_stages = nn.ModuleList( 665 | [ 666 | MHCA_stage( 667 | embed_dims[idx], 668 | embed_dims[idx + 1] 669 | if not (idx + 1) == self.num_stages 670 | else embed_dims[idx], 671 | num_layers[idx], 672 | num_heads[idx], 673 | mlp_ratios[idx], 674 | num_path[idx], 675 | norm_cfg=self.conv_norm_cfg, 676 | drop_path_list=dpr[idx], 677 | ) 678 | for idx in range(self.num_stages) 679 | ] 680 | ) 681 | 682 | def init_weights(self, pretrained=None): 683 | """Initialize the weights in backbone. 684 | 685 | Args: 686 | pretrained (str, optional): Path to pre-trained weights. 687 | Defaults to None. 
688 | """ 689 | 690 | def _init_weights(m): 691 | if isinstance(m, nn.Linear): 692 | trunc_normal_(m.weight, std=0.02) 693 | if isinstance(m, nn.Linear) and m.bias is not None: 694 | nn.init.constant_(m.bias, 0) 695 | elif isinstance(m, nn.LayerNorm): 696 | nn.init.constant_(m.bias, 0) 697 | nn.init.constant_(m.weight, 1.0) 698 | 699 | if isinstance(pretrained, str): 700 | self.apply(_init_weights) 701 | logger = get_root_logger() 702 | load_checkpoint(self, pretrained, strict=False, logger=logger) 703 | elif pretrained is None: 704 | self.apply(_init_weights) 705 | else: 706 | raise TypeError("pretrained must be a str or None") 707 | 708 | def forward_features(self, x): 709 | 710 | # x's shape : [B, C, H, W] 711 | outs = [] 712 | x = self.stem(x) # Shape : [B, C, H/2, W/2] (only the first stem conv has stride 2) 713 | outs.append(x) 714 | for idx in range(self.num_stages): 715 | att_inputs = self.patch_embed_stages[idx](x) 716 | #outs.append(att_inputs) 717 | x, _ = self.mhca_stages[idx](att_inputs) # per-path outputs are unused here 718 | outs.append(x) 719 | 720 | 721 | return outs 722 | 723 | def forward(self, x): 724 | x = self.forward_features(x) 725 | 726 | return x 727 | 728 | def train(self, mode=True): 729 | """Set the model to training mode while keeping the normalization layers 730 | frozen (in eval mode) when norm_eval is True.""" 731 | super(MPViT, self).train(mode) 732 | if mode and self.norm_eval: 733 | for m in self.modules(): 734 | # Only BatchNorm layers are switched back to eval mode. 735 | if isinstance(m, _BatchNorm): 736 | m.eval() 737 | 738 | 739 | def mpvit_tiny(**kwargs): 740 | """mpvit_tiny : 741 | 742 | - #paths : [2, 3, 3, 3] 743 | - #layers : [1, 2, 4, 1] 744 | - #channels : [64, 96, 176, 216] 745 | - MLP_ratio : 2 746 | Number of params: 5843736 747 | FLOPs : 1654163812 748 | Activations : 16641952 749 | """ 750 | 751 | model = MPViT( 752 | num_stages=4, 753 | num_path=[2, 3, 3, 3], 754 | num_layers=[1, 2, 4, 1], 755 | embed_dims=[64, 96, 176, 216], 756 | mlp_ratios=[2, 2, 2, 2], 757 | num_heads=[8, 8, 8, 8], 758 | **kwargs, 759 | ) 760 | model.default_cfg =
_cfg_mpvit() 761 | return model 762 | 763 | 764 | def mpvit_xsmall(**kwargs): 765 | """mpvit_xsmall : 766 | 767 | - #paths : [2, 3, 3, 3] 768 | - #layers : [1, 2, 4, 1] 769 | - #channels : [64, 128, 192, 256] 770 | - MLP_ratio : 4 771 | Number of params : 10573448 772 | FLOPs : 2971396560 773 | Activations : 21983464 774 | """ 775 | 776 | model = MPViT( 777 | num_stages=4, 778 | num_path=[2, 3, 3, 3], 779 | num_layers=[1, 2, 4, 1], 780 | embed_dims=[64, 128, 192, 256], 781 | mlp_ratios=[4, 4, 4, 4], 782 | num_heads=[8, 8, 8, 8], 783 | **kwargs, 784 | ) 785 | checkpoint = torch.load('./ckpt/mpvit_xsmall.pth', map_location=lambda storage, loc: storage)['model'] 786 | logger = get_root_logger() 787 | load_state_dict(model, checkpoint, strict=False, logger=logger) 788 | del checkpoint 789 | del logger 790 | model.default_cfg = _cfg_mpvit() 791 | return model 792 | 793 | 794 | def mpvit_small(**kwargs): 795 | """mpvit_small : 796 | 797 | - #paths : [2, 3, 3, 3] 798 | - #layers : [1, 3, 6, 3] 799 | - #channels : [64, 128, 216, 288] 800 | - MLP_ratio : 4 801 | Number of params : 22892400 802 | FLOPs : 4799650824 803 | Activations : 30601880 804 | """ 805 | 806 | model = MPViT( 807 | num_stages=4, 808 | num_path=[2, 3, 3, 3], 809 | num_layers=[1, 3, 6, 3], 810 | embed_dims=[64, 128, 216, 288], 811 | mlp_ratios=[4, 4, 4, 4], 812 | num_heads=[8, 8, 8, 8], 813 | **kwargs, 814 | ) 815 | checkpoint = torch.load('./ckpt/mpvit_small.pth', map_location=lambda storage, loc: storage)['model'] 816 | logger = get_root_logger() 817 | load_state_dict(model, checkpoint, strict=False, logger=logger) 818 | del checkpoint 819 | del logger 820 | model.default_cfg = _cfg_mpvit() 821 | return model 822 | 823 | 824 | def mpvit_base(**kwargs): 825 | """mpvit_base : 826 | 827 | - #paths : [2, 3, 3, 3] 828 | - #layers : [1, 3, 8, 3] 829 | - #channels : [128, 224, 368, 480] 830 | - MLP_ratio : 4 831 | Number of params: 74845976 832 | FLOPs : 16445326240 833 | Activations : 60204392 834 | """ 835 | 
836 | model = MPViT( 837 | num_stages=4, 838 | num_path=[2, 3, 3, 3], 839 | num_layers=[1, 3, 8, 3], 840 | embed_dims=[128, 224, 368, 480], 841 | mlp_ratios=[4, 4, 4, 4], 842 | num_heads=[8, 8, 8, 8], 843 | **kwargs, 844 | ) 845 | model.default_cfg = _cfg_mpvit() 846 | return model 847 | -------------------------------------------------------------------------------- /networks/nets.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | #from msilib.schema import Class 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | 8 | from collections import OrderedDict 9 | #from layers import * 10 | 11 | #from .resnet_encoder import ResnetEncoder 12 | from .hr_decoder import DepthDecoder 13 | #from .pose_decoder import PoseDecoder 14 | from .mpvit import * 15 | 16 | 17 | class DeepNet(nn.Module): 18 | def __init__(self,type,weights_init= "pretrained",num_layers=18,num_pose_frames=2,scales=range(4)): 19 | super(DeepNet, self).__init__() 20 | self.type = type 21 | self.num_layers=num_layers 22 | self.weights_init=weights_init 23 | self.num_pose_frames=num_pose_frames 24 | self.scales = scales 25 | 26 | 27 | if self.type =='mpvitnet': 28 | self.encoder = mpvit_small() 29 | self.decoder = DepthDecoder() 30 | 31 | else: 32 | raise ValueError("unsupported network type {!r}; only 'mpvitnet' is supported".format(self.type)) 33 | 34 | 35 | def forward(self, inputs): 36 | if self.type =='mpvitnet': 37 | self.outputs = self.decoder(self.encoder(inputs)) 38 | else: 39 | self.outputs = self.decoder(self.encoder(inputs)) 40 | return self.outputs 41 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import, division, print_function 3 | 4 | import numpy as np 5 | import time 6 | 7 | import torch 8 | import torch.nn.functional as F 9 |
import torch.optim as optim 10 | from torch.utils.data import DataLoader 11 | from tensorboardX import SummaryWriter 12 | 13 | import json 14 | 15 | from utils import * 16 | from kitti_utils import * 17 | from layers import * 18 | 19 | import datasets 20 | import networks 21 | from IPython import embed 22 | 23 | 24 | class Trainer: 25 | 26 | ####################### 27 | #### MonoViT ## 28 | ###################### 29 | #self.model_optimizer = optim.AdamW(self.parameters_to_train, self.opt.learning_rate) 30 | self.params = [ { 31 | "params":self.parameters_to_train, 32 | "lr": 1e-4 33 | #"weight_decay": 0.01 34 | }, 35 | { 36 | "params": list(self.models["encoder"].parameters()), 37 | "lr": self.opt.learning_rate 38 | #"weight_decay": 0.01 39 | } ] 40 | 41 | self.model_optimizer = optim.AdamW(self.params) 42 | self.model_lr_scheduler = optim.lr_scheduler.ExponentialLR( 43 | self.model_optimizer,0.9) 44 | --------------------------------------------------------------------------------
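`ConvRelPosEnc` in networks/mpvit.py splits the `h*Ch` value channels into groups, one depthwise conv per kernel size, and pads each conv with `(k + (k - 1) * (d - 1)) // 2`. The pure-Python sketch below (helper names are hypothetical, not part of the repo) checks that this padding keeps the spatial size unchanged for the odd kernels used, and shows the channel split produced by the default `crpe_window={3: 2, 5: 3, 7: 3}` with the first-stage width of `mpvit_small` (64 channels, 8 heads). It only reproduces the arithmetic, not the PyTorch layers.

```python
def same_padding(kernel_size, dilation=1):
    # Same formula as ConvRelPosEnc.__init__ (padding_size).
    return (kernel_size + (kernel_size - 1) * (dilation - 1)) // 2


def conv_out_size(size, kernel_size, padding, dilation=1, stride=1):
    # Standard output-size formula documented for torch.nn.Conv2d.
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def crpe_channel_splits(window, ch):
    # window maps kernel size -> number of heads; each head owns `ch` channels.
    # Dicts keep insertion order, matching the order of self.conv_list.
    return [heads * ch for heads in window.values()]


crpe_window = {3: 2, 5: 3, 7: 3}  # default crpe_window of MHCAEncoder
dim, num_heads = 64, 8            # first-stage embed_dim / num_heads of mpvit_small
ch = dim // num_heads             # Ch = dim // num_heads, as in MHCAEncoder

splits = crpe_channel_splits(crpe_window, ch)
print(splits, sum(splits))

# With stride 1 and odd kernels, the padding preserves H and W (checked for d = 1 and 2).
for k in crpe_window:
    for d in (1, 2):
        assert conv_out_size(48, k, same_padding(k, d), d) == 48
```

Because the head counts in `crpe_window` sum to `num_heads`, the splits cover all `h*Ch` channels, which is what `torch.split(v_img, self.channel_splits, dim=1)` requires.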
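`FactorAtt_ConvRelPosEnc.forward` never forms an N x N attention map: it takes a softmax of K over the token axis, multiplies it with V to get a per-head `Ch x Ch` matrix, and only then multiplies by Q. The NumPy sketch below (function names are hypothetical; the CRPE term and the output projection are left out) mirrors those two einsums and checks by associativity that the result equals the explicit `(Q softmax(K)^T) V` product.

```python
import numpy as np


def softmax(x, axis):
    # Numerically stable softmax along `axis`.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def factorized_attention(q, k, v, scale):
    # q, k, v: [B, h, N, Ch]; same einsum order as FactorAtt_ConvRelPosEnc.forward.
    k_softmax = softmax(k, axis=2)                        # softmax over the N tokens
    ktv = np.einsum("bhnk,bhnv->bhkv", k_softmax, v)      # [B, h, Ch, Ch]
    return scale * np.einsum("bhnk,bhkv->bhnv", q, ktv)   # [B, h, N, Ch]


def factorized_attention_naive(q, k, v, scale):
    # Same quantity through an explicit [B, h, N, N] matrix.
    k_softmax = softmax(k, axis=2)
    attn = q @ np.swapaxes(k_softmax, -1, -2)             # [B, h, N, N]
    return scale * (attn @ v)


rng = np.random.default_rng(0)
B, h, N, Ch = 2, 8, 16 * 20, 8  # e.g. a 16 x 20 token grid
q, k, v = (rng.standard_normal((B, h, N, Ch)) for _ in range(3))
scale = Ch ** -0.5              # head_dim ** -0.5 when qk_scale is None

fast = factorized_attention(q, k, v, scale)
slow = factorized_attention_naive(q, k, v, scale)
print(fast.shape, np.allclose(fast, slow))
```

For this grid the intermediate per head is 8 x 8 = 64 values instead of 320 x 320 = 102,400, which is why the factorized form scales linearly in the number of tokens.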
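`MHCA_stage.forward` flattens each path's feature map into tokens with `x.flatten(2).transpose(1, 2)`, and `MHCAEncoder.forward` turns them back into a map with `x.reshape(B, H, W, -1).permute(0, 3, 1, 2)`. A small NumPy sketch (hypothetical helper names, NumPy standing in for torch) confirms that the two layouts round-trip exactly, and computes the input width of the `aggregate` layer, which sees the `ResBlock` output plus one output per path.

```python
import numpy as np


def image_to_tokens(x):
    # [B, C, H, W] -> [B, N, C], like x.flatten(2).transpose(1, 2)
    B, C, H, W = x.shape
    return x.reshape(B, C, H * W).transpose(0, 2, 1)


def tokens_to_image(t, H, W):
    # [B, N, C] -> [B, C, H, W], like x.reshape(B, H, W, -1).permute(0, 3, 1, 2)
    B = t.shape[0]
    return t.reshape(B, H, W, -1).transpose(0, 3, 1, 2)


x = np.arange(2 * 64 * 6 * 10, dtype=np.float32).reshape(2, 64, 6, 10)
tokens = image_to_tokens(x)
back = tokens_to_image(tokens, 6, 10)
print(tokens.shape, back.shape, np.array_equal(x, back))

# Channels entering MHCA_stage.aggregate: embed_dim * (num_path + 1),
# e.g. the third stage of mpvit_small (embed_dim=216, num_path=3).
embed_dim, num_path = 216, 3
print(embed_dim * (num_path + 1))
```

Token index `n` corresponds to pixel `(n // W, n % W)`, so the row-major flattening and the `reshape(B, H, W, -1)` agree on ordering.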
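`dpr_generator` builds one linear ramp of stochastic-depth rates from 0 to `drop_path_rate` over all layers of all stages, then cuts it into per-stage lists. The pure-Python sketch below reproduces that slicing, with a small `linspace` helper (hypothetical) standing in for `torch.linspace`, for the `mpvit_small` layer counts and the `MPViT` default `drop_path_rate=0.2`. Values match the torch version up to float precision.

```python
def linspace(start, stop, num):
    # Pure-Python stand-in for torch.linspace (both endpoints included).
    if num == 1:
        return [float(start)]
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def dpr_generator_py(drop_path_rate, num_layers, num_stages):
    # Same logic as dpr_generator: one global ramp, sliced per stage.
    dpr_list = linspace(0.0, drop_path_rate, sum(num_layers))
    dpr, cur = [], 0
    for i in range(num_stages):
        dpr.append(dpr_list[cur:cur + num_layers[i]])
        cur += num_layers[i]
    return dpr


# mpvit_small: num_layers=[1, 3, 6, 3]; MPViT default drop_path_rate=0.2
dpr = dpr_generator_py(0.2, [1, 3, 6, 3], 4)
for stage, rates in enumerate(dpr):
    print(stage, [round(r, 4) for r in rates])
```

The first block of the first stage gets no drop path and the last block of the last stage gets the full rate, so deeper layers are dropped more often.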
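The `trainer.py` snippet gives AdamW two parameter groups, `self.parameters_to_train` at a fixed `1e-4` and the encoder at `self.opt.learning_rate`, and wraps them in `ExponentialLR` with `gamma=0.9`. The sketch below uses the scheduler's closed form, `base_lr * gamma ** steps`, to show how both groups decay. It assumes `step()` is called once per epoch, and the encoder value `5e-5` is a hypothetical placeholder for `self.opt.learning_rate`, not a value taken from the repo.

```python
def exponential_lr(base_lr, gamma, steps):
    # Closed form of torch.optim.lr_scheduler.ExponentialLR after `steps` calls to step().
    return base_lr * gamma ** steps


param_groups = {
    "parameters_to_train": 1e-4,  # hard-coded "lr": 1e-4 in the snippet
    "encoder": 5e-5,              # hypothetical stand-in for self.opt.learning_rate
}
gamma = 0.9

for epoch in (0, 1, 5, 10, 19):
    lrs = {name: exponential_lr(lr, gamma, epoch) for name, lr in param_groups.items()}
    print(epoch, {name: f"{lr:.2e}" for name, lr in lrs.items()})
```

Both groups shrink by the same factor each step, so their ratio stays fixed for the whole run. Note that PyTorch optimizers raise a `ValueError` if the same parameter appears in more than one group, so the encoder parameters must not also be included in `self.parameters_to_train`.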