├── .gitignore ├── EVALUATION.md ├── LICENSE ├── README.md ├── dpt ├── __init__.py ├── base_model.py ├── blocks.py ├── midas_net.py ├── models.py ├── transforms.py └── vit.py ├── input └── .placeholder ├── output_monodepth └── .placeholder ├── output_semseg └── .placeholder ├── requirements.txt ├── run_monodepth.py ├── run_segmentation.py ├── setup.py ├── util ├── __init__.py ├── io.py ├── misc.py └── pallete.py └── weights └── .placeholder /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | *.png 107 | *.pfm 108 | *.jpg 109 | *.jpeg 110 | *.pt 111 | .DS_Store 112 | -------------------------------------------------------------------------------- /EVALUATION.md: -------------------------------------------------------------------------------- 1 | ### Genral-purpose models 2 | The general-purpose models are affine-invariant and as such need a pre-alignment step before an error can be computed. 
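The gists linked below contain the reference evaluation code. As a quick illustration only, the pre-alignment amounts to a least-squares fit of a scale and shift between the predicted inverse depth and the ground truth converted to inverse depth; the sketch below is a minimal NumPy version of that idea. The `align_prediction` helper, its parameters, and the depth-range defaults are illustrative and not part of this repository.

```python
import numpy as np


def align_prediction(prediction, depth_gt, min_depth=1e-3, max_depth=10.0):
    """Least-squares scale/shift alignment of an affine-invariant
    inverse-depth prediction to metric ground-truth depth."""
    mask = (depth_gt > min_depth) & (depth_gt < max_depth)

    # ground truth expressed in inverse-depth (disparity) space
    target = np.zeros_like(depth_gt, dtype=np.float64)
    target[mask] = 1.0 / depth_gt[mask]

    p = prediction[mask].astype(np.float64)
    t = target[mask]

    # closed-form solution of  min_{s,b} sum_i (s * p_i + b - t_i)^2
    a_00, a_01, a_11 = np.sum(p * p), np.sum(p), float(p.size)
    b_0, b_1 = np.sum(p * t), np.sum(t)
    det = a_00 * a_11 - a_01 * a_01
    scale = (a_11 * b_0 - a_01 * b_1) / det
    shift = (a_00 * b_1 - a_01 * b_0) / det

    aligned = scale * prediction + shift

    # clamp to the valid disparity range before inverting back to depth
    aligned = np.clip(aligned, 1.0 / max_depth, 1.0 / min_depth)
    depth_pred = 1.0 / aligned

    # metrics (AbsRel, RMSE, delta thresholds, ...) are then computed
    # on the masked pixels of depth_pred against depth_gt
    return depth_pred, mask
```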
3 | 4 | Sample code for NYUv2 can be found here: 5 | https://gist.github.com/ranftlr/a1c7a24ebb24ce0e2f2ace5bce917022 6 | 7 | Sample code for KITTI can be found here: 8 | https://gist.github.com/ranftlr/45f4c7ddeb1bbb88d606bc600cab6c8d 9 | 10 | 11 | ### KITTI 12 | * Remove images from `/input/` and `/output_monodepth/` folders 13 | * Download `kitti_eval_dataset.zip` https://drive.google.com/file/d/1GbfMGuwg2VS06Vl75-_tB5FDj9EOrjl0/view?usp=sharing and unzip it in the `/input/` folder (or follow this repository https://github.com/cogaplex-bts/bts to get RGB and Depth images from list [eigen_test_files_with_gt.txt](https://github.com/cogaplex-bts/bts/blob/master/train_test_inputs/eigen_test_files_with_gt.txt) ) 14 | * Download [dpt_hybrid_kitti-cb926ef4.pt](https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid_kitti-cb926ef4.pt) model and place it in the `/weights/` folder 15 | * Download [eval_with_pngs.py](https://raw.githubusercontent.com/cogaplex-bts/bts/5a55542ebbe849eb85b5ce9592365225b93d8b28/utils/eval_with_pngs.py) in the root folder 16 | * `python run_monodepth.py --model_type dpt_hybrid_kitti --kitti_crop --absolute_depth` 17 | * `python ./eval_with_pngs.py --pred_path ./output_monodepth/ --gt_path ./input/gt/ --dataset kitti --min_depth_eval 1e-3 --max_depth_eval 80 --garg_crop --do_kb_crop` 18 | 19 | Result: 20 | ``` 21 | Evaluating 697 files 22 | GT files reading done 23 | 45 GT files missing 24 | Computing errors 25 | d1, d2, d3, AbsRel, SqRel, RMSE, RMSElog, SILog, log10 26 | 0.959, 0.995, 0.999, 0.062, 0.222, 2.575, 0.092, 8.282, 0.027 27 | Done. 28 | ``` 29 | 30 | ---- 31 | 32 | ### NYUv2 33 | * Remove images from `/input/` and `/output_monodepth/` folders 34 | * Download `nyu_eval_dataset.zip` https://drive.google.com/file/d/1b37uu-bqTZcSwokGkHIOEXuuBdfo80HI/view?usp=sharing and unzip it in the `/input/` folder (or follow this repository https://github.com/cogaplex-bts/bts to get RGB and Depth images from list [nyudepthv2_test_files_with_gt.txt](https://github.com/cogaplex-bts/bts/blob/master/train_test_inputs/nyudepthv2_test_files_with_gt.txt) ) 35 | * Download [dpt_hybrid_nyu-2ce69ec7.pt](https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid_nyu-2ce69ec7.pt) model (**or a new model** that is fine-tuned with slightly different hyperparameters [dpt_hybrid_nyu_new-217f207d.pt](https://drive.google.com/file/d/1Nxv2OiqhAMosBL2a3pflamTW39dMjaSp/view?usp=sharing) ) and place it in the `/weights/` folder 36 | * Download [eval_with_pngs.py](https://raw.githubusercontent.com/cogaplex-bts/bts/5a55542ebbe849eb85b5ce9592365225b93d8b28/utils/eval_with_pngs.py) in the root folder 37 | * `python run_monodepth.py --model_type dpt_hybrid_nyu --absolute_depth` 38 | (or **for new model** `python run_monodepth.py --model_type dpt_hybrid_nyu --absolute_depth --model_weights weights/dpt_hybrid_nyu_new-217f207d.pt` ) 39 | * `python ./eval_with_pngs.py --pred_path ./output_monodepth/ --gt_path ./input/gt/ --dataset nyu --max_depth_eval 10 --eigen_crop` 40 | 41 | Result (old model) - **from paper**: 42 | ``` 43 | Evaluating 654 files 44 | GT files reading done 45 | 0 GT files missing 46 | Computing errors 47 | d1, d2, d3, AbsRel, SqRel, RMSE, RMSElog, SILog, log10 48 | 0.904, 0.988, 0.998, 0.109, 0.054, 0.357, 0.129, 9.521, 0.045 49 | Done. 
50 | ``` 51 | 52 | Result (new model): 53 | ``` 54 | GT files reading done 55 | 697 GT files missing 56 | Computing errors 57 | d1, d2, d3, AbsRel, SqRel, RMSE, RMSElog, SILog, log10 58 | 0.905, 0.988, 0.998, 0.109, 0.055, 0.357, 0.129, 9.427, 0.045 59 | Done. 60 | ``` 61 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Intel ISL (Intel Intelligent Systems Lab) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PROJECT NOT UNDER ACTIVE MANAGEMENT 2 | This project will no longer be maintained by Intel. 3 | Intel has ceased development and contributions including, but not limited to, maintenance, bug fixes, new releases, or updates, to this project. 4 | Intel no longer accepts patches to this project. 5 | If you have an ongoing need to use this project, are interested in independently developing it, or would like to maintain patches for the open source software community, please create your own fork of this project. 
6 | 7 | ## Vision Transformers for Dense Prediction 8 | 9 | This repository contains code and models for our [paper](https://arxiv.org/abs/2103.13413): 10 | 11 | > Vision Transformers for Dense Prediction 12 | > René Ranftl, Alexey Bochkovskiy, Vladlen Koltun 13 | 14 | 15 | ### Changelog 16 | * [March 2021] Initial release of inference code and models 17 | 18 | ### Setup 19 | 20 | 1) Download the model weights and place them in the `weights` folder: 21 | 22 | 23 | Monodepth: 24 | - [dpt_hybrid-midas-501f0c75.pt](https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt), [Mirror](https://drive.google.com/file/d/1dgcJEYYw1F8qirXhZxgNK8dWWz_8gZBD/view?usp=sharing) 25 | - [dpt_large-midas-2f21e586.pt](https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt), [Mirror](https://drive.google.com/file/d/1vnuhoMc6caF-buQQ4hK0CeiMk9SjwB-G/view?usp=sharing) 26 | 27 | Segmentation: 28 | - [dpt_hybrid-ade20k-53898607.pt](https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-ade20k-53898607.pt), [Mirror](https://drive.google.com/file/d/1zKIAMbltJ3kpGLMh6wjsq65_k5XQ7_9m/view?usp=sharing) 29 | - [dpt_large-ade20k-b12dca68.pt](https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-ade20k-b12dca68.pt), [Mirror](https://drive.google.com/file/d/1foDpUM7CdS8Zl6GPdkrJaAOjskb7hHe-/view?usp=sharing) 30 | 31 | 2) Set up dependencies: 32 | 33 | ```shell 34 | pip install -r requirements.txt 35 | ``` 36 | 37 | The code was tested with Python 3.7, PyTorch 1.8.0, OpenCV 4.5.1, and timm 0.4.5 38 | 39 | ### Usage 40 | 41 | 1) Place one or more input images in the folder `input`. 42 | 43 | 2) Run a monocular depth estimation model: 44 | 45 | ```shell 46 | python run_monodepth.py 47 | ``` 48 | 49 | Or run a semantic segmentation model: 50 | 51 | ```shell 52 | python run_segmentation.py 53 | ``` 54 | 55 | 3) The results are written to the folder `output_monodepth` and `output_semseg`, respectively. 56 | 57 | Use the flag `-t` to switch between different models. Possible options are `dpt_hybrid` (default) and `dpt_large`. 58 | 59 | 60 | **Additional models:** 61 | 62 | - Monodepth finetuned on KITTI: [dpt_hybrid_kitti-cb926ef4.pt](https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid_kitti-cb926ef4.pt) [Mirror](https://drive.google.com/file/d/1-oJpORoJEdxj4LTV-Pc17iB-smp-khcX/view?usp=sharing) 63 | - Monodepth finetuned on NYUv2: [dpt_hybrid_nyu-2ce69ec7.pt](https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid_nyu-2ce69ec7.pt) [Mirror](https\://drive.google.com/file/d/1NjiFw1Z9lUAfTPZu4uQ9gourVwvmd58O/view?usp=sharing) 64 | 65 | Run with 66 | 67 | ```shell 68 | python run_monodepth -t [dpt_hybrid_kitti|dpt_hybrid_nyu] 69 | ``` 70 | 71 | ### Evaluation 72 | 73 | Hints on how to evaluate monodepth models can be found here: https://github.com/intel-isl/DPT/blob/main/EVALUATION.md 74 | 75 | 76 | ### Citation 77 | 78 | Please cite our papers if you use this code or any of the models. 
79 | ``` 80 | @article{Ranftl2021, 81 | author = {Ren\'{e} Ranftl and Alexey Bochkovskiy and Vladlen Koltun}, 82 | title = {Vision Transformers for Dense Prediction}, 83 | journal = {ArXiv preprint}, 84 | year = {2021}, 85 | } 86 | ``` 87 | 88 | ``` 89 | @article{Ranftl2020, 90 | author = {Ren\'{e} Ranftl and Katrin Lasinger and David Hafner and Konrad Schindler and Vladlen Koltun}, 91 | title = {Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer}, 92 | journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)}, 93 | year = {2020}, 94 | } 95 | ``` 96 | 97 | ### Acknowledgements 98 | 99 | Our work builds on and uses code from [timm](https://github.com/rwightman/pytorch-image-models) and [PyTorch-Encoding](https://github.com/zhanghang1989/PyTorch-Encoding). We'd like to thank the authors for making these libraries available. 100 | 101 | ### License 102 | 103 | MIT License 104 | -------------------------------------------------------------------------------- /dpt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/DPT/cd3fe90bb4c48577535cc4d51b602acca688a2ee/dpt/__init__.py -------------------------------------------------------------------------------- /dpt/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device("cpu")) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /dpt/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .vit import ( 5 | _make_pretrained_vitb_rn50_384, 6 | _make_pretrained_vitl16_384, 7 | _make_pretrained_vitb16_384, 8 | forward_vit, 9 | ) 10 | 11 | 12 | def _make_encoder( 13 | backbone, 14 | features, 15 | use_pretrained, 16 | groups=1, 17 | expand=False, 18 | exportable=True, 19 | hooks=None, 20 | use_vit_only=False, 21 | use_readout="ignore", 22 | enable_attention_hooks=False, 23 | ): 24 | if backbone == "vitl16_384": 25 | pretrained = _make_pretrained_vitl16_384( 26 | use_pretrained, 27 | hooks=hooks, 28 | use_readout=use_readout, 29 | enable_attention_hooks=enable_attention_hooks, 30 | ) 31 | scratch = _make_scratch( 32 | [256, 512, 1024, 1024], features, groups=groups, expand=expand 33 | ) # ViT-L/16 - 85.0% Top1 (backbone) 34 | elif backbone == "vitb_rn50_384": 35 | pretrained = _make_pretrained_vitb_rn50_384( 36 | use_pretrained, 37 | hooks=hooks, 38 | use_vit_only=use_vit_only, 39 | use_readout=use_readout, 40 | enable_attention_hooks=enable_attention_hooks, 41 | ) 42 | scratch = _make_scratch( 43 | [256, 512, 768, 768], features, groups=groups, expand=expand 44 | ) # ViT-H/16 - 85.0% Top1 (backbone) 45 | elif backbone == "vitb16_384": 46 | pretrained = _make_pretrained_vitb16_384( 47 | use_pretrained, 48 | hooks=hooks, 49 | use_readout=use_readout, 50 | enable_attention_hooks=enable_attention_hooks, 51 | ) 52 | scratch = _make_scratch( 53 | [96, 192, 384, 768], features, groups=groups, expand=expand 54 | ) # ViT-B/16 - 84.6% Top1 (backbone) 55 | elif backbone == "resnext101_wsl": 56 | pretrained = 
_make_pretrained_resnext101_wsl(use_pretrained) 57 | scratch = _make_scratch( 58 | [256, 512, 1024, 2048], features, groups=groups, expand=expand 59 | ) # efficientnet_lite3 60 | else: 61 | print(f"Backbone '{backbone}' not implemented") 62 | assert False 63 | 64 | return pretrained, scratch 65 | 66 | 67 | def _make_scratch(in_shape, out_shape, groups=1, expand=False): 68 | scratch = nn.Module() 69 | 70 | out_shape1 = out_shape 71 | out_shape2 = out_shape 72 | out_shape3 = out_shape 73 | out_shape4 = out_shape 74 | if expand == True: 75 | out_shape1 = out_shape 76 | out_shape2 = out_shape * 2 77 | out_shape3 = out_shape * 4 78 | out_shape4 = out_shape * 8 79 | 80 | scratch.layer1_rn = nn.Conv2d( 81 | in_shape[0], 82 | out_shape1, 83 | kernel_size=3, 84 | stride=1, 85 | padding=1, 86 | bias=False, 87 | groups=groups, 88 | ) 89 | scratch.layer2_rn = nn.Conv2d( 90 | in_shape[1], 91 | out_shape2, 92 | kernel_size=3, 93 | stride=1, 94 | padding=1, 95 | bias=False, 96 | groups=groups, 97 | ) 98 | scratch.layer3_rn = nn.Conv2d( 99 | in_shape[2], 100 | out_shape3, 101 | kernel_size=3, 102 | stride=1, 103 | padding=1, 104 | bias=False, 105 | groups=groups, 106 | ) 107 | scratch.layer4_rn = nn.Conv2d( 108 | in_shape[3], 109 | out_shape4, 110 | kernel_size=3, 111 | stride=1, 112 | padding=1, 113 | bias=False, 114 | groups=groups, 115 | ) 116 | 117 | return scratch 118 | 119 | 120 | def _make_resnet_backbone(resnet): 121 | pretrained = nn.Module() 122 | pretrained.layer1 = nn.Sequential( 123 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 124 | ) 125 | 126 | pretrained.layer2 = resnet.layer2 127 | pretrained.layer3 = resnet.layer3 128 | pretrained.layer4 = resnet.layer4 129 | 130 | return pretrained 131 | 132 | 133 | def _make_pretrained_resnext101_wsl(use_pretrained): 134 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") 135 | return _make_resnet_backbone(resnet) 136 | 137 | 138 | class Interpolate(nn.Module): 139 | """Interpolation module.""" 140 | 141 | def __init__(self, scale_factor, mode, align_corners=False): 142 | """Init. 143 | 144 | Args: 145 | scale_factor (float): scaling 146 | mode (str): interpolation mode 147 | """ 148 | super(Interpolate, self).__init__() 149 | 150 | self.interp = nn.functional.interpolate 151 | self.scale_factor = scale_factor 152 | self.mode = mode 153 | self.align_corners = align_corners 154 | 155 | def forward(self, x): 156 | """Forward pass. 157 | 158 | Args: 159 | x (tensor): input 160 | 161 | Returns: 162 | tensor: interpolated data 163 | """ 164 | 165 | x = self.interp( 166 | x, 167 | scale_factor=self.scale_factor, 168 | mode=self.mode, 169 | align_corners=self.align_corners, 170 | ) 171 | 172 | return x 173 | 174 | 175 | class ResidualConvUnit(nn.Module): 176 | """Residual convolution module.""" 177 | 178 | def __init__(self, features): 179 | """Init. 180 | 181 | Args: 182 | features (int): number of features 183 | """ 184 | super().__init__() 185 | 186 | self.conv1 = nn.Conv2d( 187 | features, features, kernel_size=3, stride=1, padding=1, bias=True 188 | ) 189 | 190 | self.conv2 = nn.Conv2d( 191 | features, features, kernel_size=3, stride=1, padding=1, bias=True 192 | ) 193 | 194 | self.relu = nn.ReLU(inplace=True) 195 | 196 | def forward(self, x): 197 | """Forward pass. 
198 | 199 | Args: 200 | x (tensor): input 201 | 202 | Returns: 203 | tensor: output 204 | """ 205 | out = self.relu(x) 206 | out = self.conv1(out) 207 | out = self.relu(out) 208 | out = self.conv2(out) 209 | 210 | return out + x 211 | 212 | 213 | class FeatureFusionBlock(nn.Module): 214 | """Feature fusion block.""" 215 | 216 | def __init__(self, features): 217 | """Init. 218 | 219 | Args: 220 | features (int): number of features 221 | """ 222 | super(FeatureFusionBlock, self).__init__() 223 | 224 | self.resConfUnit1 = ResidualConvUnit(features) 225 | self.resConfUnit2 = ResidualConvUnit(features) 226 | 227 | def forward(self, *xs): 228 | """Forward pass. 229 | 230 | Returns: 231 | tensor: output 232 | """ 233 | output = xs[0] 234 | 235 | if len(xs) == 2: 236 | output += self.resConfUnit1(xs[1]) 237 | 238 | output = self.resConfUnit2(output) 239 | 240 | output = nn.functional.interpolate( 241 | output, scale_factor=2, mode="bilinear", align_corners=True 242 | ) 243 | 244 | return output 245 | 246 | 247 | class ResidualConvUnit_custom(nn.Module): 248 | """Residual convolution module.""" 249 | 250 | def __init__(self, features, activation, bn): 251 | """Init. 252 | 253 | Args: 254 | features (int): number of features 255 | """ 256 | super().__init__() 257 | 258 | self.bn = bn 259 | 260 | self.groups = 1 261 | 262 | self.conv1 = nn.Conv2d( 263 | features, 264 | features, 265 | kernel_size=3, 266 | stride=1, 267 | padding=1, 268 | bias=not self.bn, 269 | groups=self.groups, 270 | ) 271 | 272 | self.conv2 = nn.Conv2d( 273 | features, 274 | features, 275 | kernel_size=3, 276 | stride=1, 277 | padding=1, 278 | bias=not self.bn, 279 | groups=self.groups, 280 | ) 281 | 282 | if self.bn == True: 283 | self.bn1 = nn.BatchNorm2d(features) 284 | self.bn2 = nn.BatchNorm2d(features) 285 | 286 | self.activation = activation 287 | 288 | self.skip_add = nn.quantized.FloatFunctional() 289 | 290 | def forward(self, x): 291 | """Forward pass. 292 | 293 | Args: 294 | x (tensor): input 295 | 296 | Returns: 297 | tensor: output 298 | """ 299 | 300 | out = self.activation(x) 301 | out = self.conv1(out) 302 | if self.bn == True: 303 | out = self.bn1(out) 304 | 305 | out = self.activation(out) 306 | out = self.conv2(out) 307 | if self.bn == True: 308 | out = self.bn2(out) 309 | 310 | if self.groups > 1: 311 | out = self.conv_merge(out) 312 | 313 | return self.skip_add.add(out, x) 314 | 315 | # return out + x 316 | 317 | 318 | class FeatureFusionBlock_custom(nn.Module): 319 | """Feature fusion block.""" 320 | 321 | def __init__( 322 | self, 323 | features, 324 | activation, 325 | deconv=False, 326 | bn=False, 327 | expand=False, 328 | align_corners=True, 329 | ): 330 | """Init. 331 | 332 | Args: 333 | features (int): number of features 334 | """ 335 | super(FeatureFusionBlock_custom, self).__init__() 336 | 337 | self.deconv = deconv 338 | self.align_corners = align_corners 339 | 340 | self.groups = 1 341 | 342 | self.expand = expand 343 | out_features = features 344 | if self.expand == True: 345 | out_features = features // 2 346 | 347 | self.out_conv = nn.Conv2d( 348 | features, 349 | out_features, 350 | kernel_size=1, 351 | stride=1, 352 | padding=0, 353 | bias=True, 354 | groups=1, 355 | ) 356 | 357 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) 358 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) 359 | 360 | self.skip_add = nn.quantized.FloatFunctional() 361 | 362 | def forward(self, *xs): 363 | """Forward pass. 
364 | 365 | Returns: 366 | tensor: output 367 | """ 368 | output = xs[0] 369 | 370 | if len(xs) == 2: 371 | res = self.resConfUnit1(xs[1]) 372 | output = self.skip_add.add(output, res) 373 | # output += res 374 | 375 | output = self.resConfUnit2(output) 376 | 377 | output = nn.functional.interpolate( 378 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners 379 | ) 380 | 381 | output = self.out_conv(output) 382 | 383 | return output 384 | -------------------------------------------------------------------------------- /dpt/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet_large(BaseModel): 13 | """Network for monocular depth estimation.""" 14 | 15 | def __init__(self, path=None, features=256, non_negative=True): 16 | """Init. 17 | 18 | Args: 19 | path (str, optional): Path to saved model. Defaults to None. 20 | features (int, optional): Number of features. Defaults to 256. 21 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 22 | """ 23 | print("Loading weights: ", path) 24 | 25 | super(MidasNet_large, self).__init__() 26 | 27 | use_pretrained = False if path is None else True 28 | 29 | self.pretrained, self.scratch = _make_encoder( 30 | backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained 31 | ) 32 | 33 | self.scratch.refinenet4 = FeatureFusionBlock(features) 34 | self.scratch.refinenet3 = FeatureFusionBlock(features) 35 | self.scratch.refinenet2 = FeatureFusionBlock(features) 36 | self.scratch.refinenet1 = FeatureFusionBlock(features) 37 | 38 | self.scratch.output_conv = nn.Sequential( 39 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 40 | Interpolate(scale_factor=2, mode="bilinear"), 41 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 42 | nn.ReLU(True), 43 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 44 | nn.ReLU(True) if non_negative else nn.Identity(), 45 | ) 46 | 47 | if path: 48 | self.load(path) 49 | 50 | def forward(self, x): 51 | """Forward pass. 
52 | 53 | Args: 54 | x (tensor): input data (image) 55 | 56 | Returns: 57 | tensor: depth 58 | """ 59 | 60 | layer_1 = self.pretrained.layer1(x) 61 | layer_2 = self.pretrained.layer2(layer_1) 62 | layer_3 = self.pretrained.layer3(layer_2) 63 | layer_4 = self.pretrained.layer4(layer_3) 64 | 65 | layer_1_rn = self.scratch.layer1_rn(layer_1) 66 | layer_2_rn = self.scratch.layer2_rn(layer_2) 67 | layer_3_rn = self.scratch.layer3_rn(layer_3) 68 | layer_4_rn = self.scratch.layer4_rn(layer_4) 69 | 70 | path_4 = self.scratch.refinenet4(layer_4_rn) 71 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 72 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 73 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 74 | 75 | out = self.scratch.output_conv(path_1) 76 | 77 | return torch.squeeze(out, dim=1) 78 | -------------------------------------------------------------------------------- /dpt/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_model import BaseModel 6 | from .blocks import ( 7 | FeatureFusionBlock, 8 | FeatureFusionBlock_custom, 9 | Interpolate, 10 | _make_encoder, 11 | forward_vit, 12 | ) 13 | 14 | 15 | def _make_fusion_block(features, use_bn): 16 | return FeatureFusionBlock_custom( 17 | features, 18 | nn.ReLU(False), 19 | deconv=False, 20 | bn=use_bn, 21 | expand=False, 22 | align_corners=True, 23 | ) 24 | 25 | 26 | class DPT(BaseModel): 27 | def __init__( 28 | self, 29 | head, 30 | features=256, 31 | backbone="vitb_rn50_384", 32 | readout="project", 33 | channels_last=False, 34 | use_bn=False, 35 | enable_attention_hooks=False, 36 | ): 37 | 38 | super(DPT, self).__init__() 39 | 40 | self.channels_last = channels_last 41 | 42 | hooks = { 43 | "vitb_rn50_384": [0, 1, 8, 11], 44 | "vitb16_384": [2, 5, 8, 11], 45 | "vitl16_384": [5, 11, 17, 23], 46 | } 47 | 48 | # Instantiate backbone and reassemble blocks 49 | self.pretrained, self.scratch = _make_encoder( 50 | backbone, 51 | features, 52 | False, # Set to true of you want to train from scratch, uses ImageNet weights 53 | groups=1, 54 | expand=False, 55 | exportable=False, 56 | hooks=hooks[backbone], 57 | use_readout=readout, 58 | enable_attention_hooks=enable_attention_hooks, 59 | ) 60 | 61 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 62 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 63 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 64 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 65 | 66 | self.scratch.output_conv = head 67 | 68 | def forward(self, x): 69 | if self.channels_last == True: 70 | x.contiguous(memory_format=torch.channels_last) 71 | 72 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 73 | 74 | layer_1_rn = self.scratch.layer1_rn(layer_1) 75 | layer_2_rn = self.scratch.layer2_rn(layer_2) 76 | layer_3_rn = self.scratch.layer3_rn(layer_3) 77 | layer_4_rn = self.scratch.layer4_rn(layer_4) 78 | 79 | path_4 = self.scratch.refinenet4(layer_4_rn) 80 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 81 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 82 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 83 | 84 | out = self.scratch.output_conv(path_1) 85 | 86 | return out 87 | 88 | 89 | class DPTDepthModel(DPT): 90 | def __init__( 91 | self, path=None, non_negative=True, scale=1.0, shift=0.0, invert=False, **kwargs 92 | ): 93 | features = kwargs["features"] if "features" in 
kwargs else 256 94 | 95 | self.scale = scale 96 | self.shift = shift 97 | self.invert = invert 98 | 99 | head = nn.Sequential( 100 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 101 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 102 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 103 | nn.ReLU(True), 104 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 105 | nn.ReLU(True) if non_negative else nn.Identity(), 106 | nn.Identity(), 107 | ) 108 | 109 | super().__init__(head, **kwargs) 110 | 111 | if path is not None: 112 | self.load(path) 113 | 114 | def forward(self, x): 115 | inv_depth = super().forward(x).squeeze(dim=1) 116 | 117 | if self.invert: 118 | depth = self.scale * inv_depth + self.shift 119 | depth[depth < 1e-8] = 1e-8 120 | depth = 1.0 / depth 121 | return depth 122 | else: 123 | return inv_depth 124 | 125 | 126 | class DPTSegmentationModel(DPT): 127 | def __init__(self, num_classes, path=None, **kwargs): 128 | 129 | features = kwargs["features"] if "features" in kwargs else 256 130 | 131 | kwargs["use_bn"] = True 132 | 133 | head = nn.Sequential( 134 | nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False), 135 | nn.BatchNorm2d(features), 136 | nn.ReLU(True), 137 | nn.Dropout(0.1, False), 138 | nn.Conv2d(features, num_classes, kernel_size=1), 139 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 140 | ) 141 | 142 | super().__init__(head, **kwargs) 143 | 144 | self.auxlayer = nn.Sequential( 145 | nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False), 146 | nn.BatchNorm2d(features), 147 | nn.ReLU(True), 148 | nn.Dropout(0.1, False), 149 | nn.Conv2d(features, num_classes, kernel_size=1), 150 | ) 151 | 152 | if path is not None: 153 | self.load(path) 154 | -------------------------------------------------------------------------------- /dpt/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import math 4 | 5 | 6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): 7 | """Rezise the sample to ensure the given size. Keeps aspect ratio. 8 | 9 | Args: 10 | sample (dict): sample 11 | size (tuple): image size 12 | 13 | Returns: 14 | tuple: new size 15 | """ 16 | shape = list(sample["disparity"].shape) 17 | 18 | if shape[0] >= size[0] and shape[1] >= size[1]: 19 | return sample 20 | 21 | scale = [0, 0] 22 | scale[0] = size[0] / shape[0] 23 | scale[1] = size[1] / shape[1] 24 | 25 | scale = max(scale) 26 | 27 | shape[0] = math.ceil(scale * shape[0]) 28 | shape[1] = math.ceil(scale * shape[1]) 29 | 30 | # resize 31 | sample["image"] = cv2.resize( 32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method 33 | ) 34 | 35 | sample["disparity"] = cv2.resize( 36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST 37 | ) 38 | sample["mask"] = cv2.resize( 39 | sample["mask"].astype(np.float32), 40 | tuple(shape[::-1]), 41 | interpolation=cv2.INTER_NEAREST, 42 | ) 43 | sample["mask"] = sample["mask"].astype(bool) 44 | 45 | return tuple(shape) 46 | 47 | 48 | class Resize(object): 49 | """Resize sample to given size (width, height).""" 50 | 51 | def __init__( 52 | self, 53 | width, 54 | height, 55 | resize_target=True, 56 | keep_aspect_ratio=False, 57 | ensure_multiple_of=1, 58 | resize_method="lower_bound", 59 | image_interpolation_method=cv2.INTER_AREA, 60 | ): 61 | """Init. 
62 | 63 | Args: 64 | width (int): desired output width 65 | height (int): desired output height 66 | resize_target (bool, optional): 67 | True: Resize the full sample (image, mask, target). 68 | False: Resize image only. 69 | Defaults to True. 70 | keep_aspect_ratio (bool, optional): 71 | True: Keep the aspect ratio of the input sample. 72 | Output sample might not have the given width and height, and 73 | resize behaviour depends on the parameter 'resize_method'. 74 | Defaults to False. 75 | ensure_multiple_of (int, optional): 76 | Output width and height is constrained to be multiple of this parameter. 77 | Defaults to 1. 78 | resize_method (str, optional): 79 | "lower_bound": Output will be at least as large as the given size. 80 | "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) 81 | "minimal": Scale as least as possible. (Output size might be smaller than given size.) 82 | Defaults to "lower_bound". 83 | """ 84 | self.__width = width 85 | self.__height = height 86 | 87 | self.__resize_target = resize_target 88 | self.__keep_aspect_ratio = keep_aspect_ratio 89 | self.__multiple_of = ensure_multiple_of 90 | self.__resize_method = resize_method 91 | self.__image_interpolation_method = image_interpolation_method 92 | 93 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None): 94 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) 95 | 96 | if max_val is not None and y > max_val: 97 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) 98 | 99 | if y < min_val: 100 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) 101 | 102 | return y 103 | 104 | def get_size(self, width, height): 105 | # determine new height and width 106 | scale_height = self.__height / height 107 | scale_width = self.__width / width 108 | 109 | if self.__keep_aspect_ratio: 110 | if self.__resize_method == "lower_bound": 111 | # scale such that output size is lower bound 112 | if scale_width > scale_height: 113 | # fit width 114 | scale_height = scale_width 115 | else: 116 | # fit height 117 | scale_width = scale_height 118 | elif self.__resize_method == "upper_bound": 119 | # scale such that output size is upper bound 120 | if scale_width < scale_height: 121 | # fit width 122 | scale_height = scale_width 123 | else: 124 | # fit height 125 | scale_width = scale_height 126 | elif self.__resize_method == "minimal": 127 | # scale as least as possbile 128 | if abs(1 - scale_width) < abs(1 - scale_height): 129 | # fit width 130 | scale_height = scale_width 131 | else: 132 | # fit height 133 | scale_width = scale_height 134 | else: 135 | raise ValueError( 136 | f"resize_method {self.__resize_method} not implemented" 137 | ) 138 | 139 | if self.__resize_method == "lower_bound": 140 | new_height = self.constrain_to_multiple_of( 141 | scale_height * height, min_val=self.__height 142 | ) 143 | new_width = self.constrain_to_multiple_of( 144 | scale_width * width, min_val=self.__width 145 | ) 146 | elif self.__resize_method == "upper_bound": 147 | new_height = self.constrain_to_multiple_of( 148 | scale_height * height, max_val=self.__height 149 | ) 150 | new_width = self.constrain_to_multiple_of( 151 | scale_width * width, max_val=self.__width 152 | ) 153 | elif self.__resize_method == "minimal": 154 | new_height = self.constrain_to_multiple_of(scale_height * height) 155 | new_width = self.constrain_to_multiple_of(scale_width * width) 156 | else: 157 | raise ValueError(f"resize_method 
{self.__resize_method} not implemented") 158 | 159 | return (new_width, new_height) 160 | 161 | def __call__(self, sample): 162 | width, height = self.get_size( 163 | sample["image"].shape[1], sample["image"].shape[0] 164 | ) 165 | 166 | # resize sample 167 | sample["image"] = cv2.resize( 168 | sample["image"], 169 | (width, height), 170 | interpolation=self.__image_interpolation_method, 171 | ) 172 | 173 | if self.__resize_target: 174 | if "disparity" in sample: 175 | sample["disparity"] = cv2.resize( 176 | sample["disparity"], 177 | (width, height), 178 | interpolation=cv2.INTER_NEAREST, 179 | ) 180 | 181 | if "depth" in sample: 182 | sample["depth"] = cv2.resize( 183 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST 184 | ) 185 | 186 | sample["mask"] = cv2.resize( 187 | sample["mask"].astype(np.float32), 188 | (width, height), 189 | interpolation=cv2.INTER_NEAREST, 190 | ) 191 | sample["mask"] = sample["mask"].astype(bool) 192 | 193 | return sample 194 | 195 | 196 | class NormalizeImage(object): 197 | """Normlize image by given mean and std.""" 198 | 199 | def __init__(self, mean, std): 200 | self.__mean = mean 201 | self.__std = std 202 | 203 | def __call__(self, sample): 204 | sample["image"] = (sample["image"] - self.__mean) / self.__std 205 | 206 | return sample 207 | 208 | 209 | class PrepareForNet(object): 210 | """Prepare sample for usage as network input.""" 211 | 212 | def __init__(self): 213 | pass 214 | 215 | def __call__(self, sample): 216 | image = np.transpose(sample["image"], (2, 0, 1)) 217 | sample["image"] = np.ascontiguousarray(image).astype(np.float32) 218 | 219 | if "mask" in sample: 220 | sample["mask"] = sample["mask"].astype(np.float32) 221 | sample["mask"] = np.ascontiguousarray(sample["mask"]) 222 | 223 | if "disparity" in sample: 224 | disparity = sample["disparity"].astype(np.float32) 225 | sample["disparity"] = np.ascontiguousarray(disparity) 226 | 227 | if "depth" in sample: 228 | depth = sample["depth"].astype(np.float32) 229 | sample["depth"] = np.ascontiguousarray(depth) 230 | 231 | return sample 232 | -------------------------------------------------------------------------------- /dpt/vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import timm 4 | import types 5 | import math 6 | import torch.nn.functional as F 7 | 8 | 9 | activations = {} 10 | 11 | 12 | def get_activation(name): 13 | def hook(model, input, output): 14 | activations[name] = output 15 | 16 | return hook 17 | 18 | 19 | attention = {} 20 | 21 | 22 | def get_attention(name): 23 | def hook(module, input, output): 24 | x = input[0] 25 | B, N, C = x.shape 26 | qkv = ( 27 | module.qkv(x) 28 | .reshape(B, N, 3, module.num_heads, C // module.num_heads) 29 | .permute(2, 0, 3, 1, 4) 30 | ) 31 | q, k, v = ( 32 | qkv[0], 33 | qkv[1], 34 | qkv[2], 35 | ) # make torchscript happy (cannot use tensor as tuple) 36 | 37 | attn = (q @ k.transpose(-2, -1)) * module.scale 38 | 39 | attn = attn.softmax(dim=-1) # [:,:,1,1:] 40 | attention[name] = attn 41 | 42 | return hook 43 | 44 | 45 | def get_mean_attention_map(attn, token, shape): 46 | attn = attn[:, :, token, 1:] 47 | attn = attn.unflatten(2, torch.Size([shape[2] // 16, shape[3] // 16])).float() 48 | attn = torch.nn.functional.interpolate( 49 | attn, size=shape[2:], mode="bicubic", align_corners=False 50 | ).squeeze(0) 51 | 52 | all_attn = torch.mean(attn, 0) 53 | 54 | return all_attn 55 | 56 | 57 | class Slice(nn.Module): 58 | def __init__(self, 
start_index=1): 59 | super(Slice, self).__init__() 60 | self.start_index = start_index 61 | 62 | def forward(self, x): 63 | return x[:, self.start_index :] 64 | 65 | 66 | class AddReadout(nn.Module): 67 | def __init__(self, start_index=1): 68 | super(AddReadout, self).__init__() 69 | self.start_index = start_index 70 | 71 | def forward(self, x): 72 | if self.start_index == 2: 73 | readout = (x[:, 0] + x[:, 1]) / 2 74 | else: 75 | readout = x[:, 0] 76 | return x[:, self.start_index :] + readout.unsqueeze(1) 77 | 78 | 79 | class ProjectReadout(nn.Module): 80 | def __init__(self, in_features, start_index=1): 81 | super(ProjectReadout, self).__init__() 82 | self.start_index = start_index 83 | 84 | self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) 85 | 86 | def forward(self, x): 87 | readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) 88 | features = torch.cat((x[:, self.start_index :], readout), -1) 89 | 90 | return self.project(features) 91 | 92 | 93 | class Transpose(nn.Module): 94 | def __init__(self, dim0, dim1): 95 | super(Transpose, self).__init__() 96 | self.dim0 = dim0 97 | self.dim1 = dim1 98 | 99 | def forward(self, x): 100 | x = x.transpose(self.dim0, self.dim1) 101 | return x 102 | 103 | 104 | def forward_vit(pretrained, x): 105 | b, c, h, w = x.shape 106 | 107 | glob = pretrained.model.forward_flex(x) 108 | 109 | layer_1 = pretrained.activations["1"] 110 | layer_2 = pretrained.activations["2"] 111 | layer_3 = pretrained.activations["3"] 112 | layer_4 = pretrained.activations["4"] 113 | 114 | layer_1 = pretrained.act_postprocess1[0:2](layer_1) 115 | layer_2 = pretrained.act_postprocess2[0:2](layer_2) 116 | layer_3 = pretrained.act_postprocess3[0:2](layer_3) 117 | layer_4 = pretrained.act_postprocess4[0:2](layer_4) 118 | 119 | unflatten = nn.Sequential( 120 | nn.Unflatten( 121 | 2, 122 | torch.Size( 123 | [ 124 | h // pretrained.model.patch_size[1], 125 | w // pretrained.model.patch_size[0], 126 | ] 127 | ), 128 | ) 129 | ) 130 | 131 | if layer_1.ndim == 3: 132 | layer_1 = unflatten(layer_1) 133 | if layer_2.ndim == 3: 134 | layer_2 = unflatten(layer_2) 135 | if layer_3.ndim == 3: 136 | layer_3 = unflatten(layer_3) 137 | if layer_4.ndim == 3: 138 | layer_4 = unflatten(layer_4) 139 | 140 | layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) 141 | layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) 142 | layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) 143 | layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) 144 | 145 | return layer_1, layer_2, layer_3, layer_4 146 | 147 | 148 | def _resize_pos_embed(self, posemb, gs_h, gs_w): 149 | posemb_tok, posemb_grid = ( 150 | posemb[:, : self.start_index], 151 | posemb[0, self.start_index :], 152 | ) 153 | 154 | gs_old = int(math.sqrt(len(posemb_grid))) 155 | 156 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 157 | posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") 158 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) 159 | 160 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 161 | 162 | return posemb 163 | 164 | 165 | def forward_flex(self, x): 166 | b, c, h, w = x.shape 167 | 168 | pos_embed = self._resize_pos_embed( 169 | self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] 170 | ) 171 | 172 | B = x.shape[0] 173 | 174 | if hasattr(self.patch_embed, "backbone"): 
175 | x = self.patch_embed.backbone(x) 176 | if isinstance(x, (list, tuple)): 177 | x = x[-1] # last feature if backbone outputs list/tuple of features 178 | 179 | x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) 180 | 181 | if getattr(self, "dist_token", None) is not None: 182 | cls_tokens = self.cls_token.expand( 183 | B, -1, -1 184 | ) # stole cls_tokens impl from Phil Wang, thanks 185 | dist_token = self.dist_token.expand(B, -1, -1) 186 | x = torch.cat((cls_tokens, dist_token, x), dim=1) 187 | else: 188 | cls_tokens = self.cls_token.expand( 189 | B, -1, -1 190 | ) # stole cls_tokens impl from Phil Wang, thanks 191 | x = torch.cat((cls_tokens, x), dim=1) 192 | 193 | x = x + pos_embed 194 | x = self.pos_drop(x) 195 | 196 | for blk in self.blocks: 197 | x = blk(x) 198 | 199 | x = self.norm(x) 200 | 201 | return x 202 | 203 | 204 | def get_readout_oper(vit_features, features, use_readout, start_index=1): 205 | if use_readout == "ignore": 206 | readout_oper = [Slice(start_index)] * len(features) 207 | elif use_readout == "add": 208 | readout_oper = [AddReadout(start_index)] * len(features) 209 | elif use_readout == "project": 210 | readout_oper = [ 211 | ProjectReadout(vit_features, start_index) for out_feat in features 212 | ] 213 | else: 214 | assert ( 215 | False 216 | ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" 217 | 218 | return readout_oper 219 | 220 | 221 | def _make_vit_b16_backbone( 222 | model, 223 | features=[96, 192, 384, 768], 224 | size=[384, 384], 225 | hooks=[2, 5, 8, 11], 226 | vit_features=768, 227 | use_readout="ignore", 228 | start_index=1, 229 | enable_attention_hooks=False, 230 | ): 231 | pretrained = nn.Module() 232 | 233 | pretrained.model = model 234 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) 235 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) 236 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) 237 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) 238 | 239 | pretrained.activations = activations 240 | 241 | if enable_attention_hooks: 242 | pretrained.model.blocks[hooks[0]].attn.register_forward_hook( 243 | get_attention("attn_1") 244 | ) 245 | pretrained.model.blocks[hooks[1]].attn.register_forward_hook( 246 | get_attention("attn_2") 247 | ) 248 | pretrained.model.blocks[hooks[2]].attn.register_forward_hook( 249 | get_attention("attn_3") 250 | ) 251 | pretrained.model.blocks[hooks[3]].attn.register_forward_hook( 252 | get_attention("attn_4") 253 | ) 254 | pretrained.attention = attention 255 | 256 | readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) 257 | 258 | # 32, 48, 136, 384 259 | pretrained.act_postprocess1 = nn.Sequential( 260 | readout_oper[0], 261 | Transpose(1, 2), 262 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 263 | nn.Conv2d( 264 | in_channels=vit_features, 265 | out_channels=features[0], 266 | kernel_size=1, 267 | stride=1, 268 | padding=0, 269 | ), 270 | nn.ConvTranspose2d( 271 | in_channels=features[0], 272 | out_channels=features[0], 273 | kernel_size=4, 274 | stride=4, 275 | padding=0, 276 | bias=True, 277 | dilation=1, 278 | groups=1, 279 | ), 280 | ) 281 | 282 | pretrained.act_postprocess2 = nn.Sequential( 283 | readout_oper[1], 284 | Transpose(1, 2), 285 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 286 | nn.Conv2d( 287 | in_channels=vit_features, 288 | out_channels=features[1], 289 | kernel_size=1, 290 
| stride=1, 291 | padding=0, 292 | ), 293 | nn.ConvTranspose2d( 294 | in_channels=features[1], 295 | out_channels=features[1], 296 | kernel_size=2, 297 | stride=2, 298 | padding=0, 299 | bias=True, 300 | dilation=1, 301 | groups=1, 302 | ), 303 | ) 304 | 305 | pretrained.act_postprocess3 = nn.Sequential( 306 | readout_oper[2], 307 | Transpose(1, 2), 308 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 309 | nn.Conv2d( 310 | in_channels=vit_features, 311 | out_channels=features[2], 312 | kernel_size=1, 313 | stride=1, 314 | padding=0, 315 | ), 316 | ) 317 | 318 | pretrained.act_postprocess4 = nn.Sequential( 319 | readout_oper[3], 320 | Transpose(1, 2), 321 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 322 | nn.Conv2d( 323 | in_channels=vit_features, 324 | out_channels=features[3], 325 | kernel_size=1, 326 | stride=1, 327 | padding=0, 328 | ), 329 | nn.Conv2d( 330 | in_channels=features[3], 331 | out_channels=features[3], 332 | kernel_size=3, 333 | stride=2, 334 | padding=1, 335 | ), 336 | ) 337 | 338 | pretrained.model.start_index = start_index 339 | pretrained.model.patch_size = [16, 16] 340 | 341 | # We inject this function into the VisionTransformer instances so that 342 | # we can use it with interpolated position embeddings without modifying the library source. 343 | pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) 344 | pretrained.model._resize_pos_embed = types.MethodType( 345 | _resize_pos_embed, pretrained.model 346 | ) 347 | 348 | return pretrained 349 | 350 | 351 | def _make_vit_b_rn50_backbone( 352 | model, 353 | features=[256, 512, 768, 768], 354 | size=[384, 384], 355 | hooks=[0, 1, 8, 11], 356 | vit_features=768, 357 | use_vit_only=False, 358 | use_readout="ignore", 359 | start_index=1, 360 | enable_attention_hooks=False, 361 | ): 362 | pretrained = nn.Module() 363 | 364 | pretrained.model = model 365 | 366 | if use_vit_only == True: 367 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) 368 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) 369 | else: 370 | pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( 371 | get_activation("1") 372 | ) 373 | pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( 374 | get_activation("2") 375 | ) 376 | 377 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) 378 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) 379 | 380 | if enable_attention_hooks: 381 | pretrained.model.blocks[2].attn.register_forward_hook(get_attention("attn_1")) 382 | pretrained.model.blocks[5].attn.register_forward_hook(get_attention("attn_2")) 383 | pretrained.model.blocks[8].attn.register_forward_hook(get_attention("attn_3")) 384 | pretrained.model.blocks[11].attn.register_forward_hook(get_attention("attn_4")) 385 | pretrained.attention = attention 386 | 387 | pretrained.activations = activations 388 | 389 | readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) 390 | 391 | if use_vit_only == True: 392 | pretrained.act_postprocess1 = nn.Sequential( 393 | readout_oper[0], 394 | Transpose(1, 2), 395 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 396 | nn.Conv2d( 397 | in_channels=vit_features, 398 | out_channels=features[0], 399 | kernel_size=1, 400 | stride=1, 401 | padding=0, 402 | ), 403 | nn.ConvTranspose2d( 404 | in_channels=features[0], 405 | out_channels=features[0], 406 | kernel_size=4, 407 | stride=4, 408 | 
padding=0, 409 | bias=True, 410 | dilation=1, 411 | groups=1, 412 | ), 413 | ) 414 | 415 | pretrained.act_postprocess2 = nn.Sequential( 416 | readout_oper[1], 417 | Transpose(1, 2), 418 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 419 | nn.Conv2d( 420 | in_channels=vit_features, 421 | out_channels=features[1], 422 | kernel_size=1, 423 | stride=1, 424 | padding=0, 425 | ), 426 | nn.ConvTranspose2d( 427 | in_channels=features[1], 428 | out_channels=features[1], 429 | kernel_size=2, 430 | stride=2, 431 | padding=0, 432 | bias=True, 433 | dilation=1, 434 | groups=1, 435 | ), 436 | ) 437 | else: 438 | pretrained.act_postprocess1 = nn.Sequential( 439 | nn.Identity(), nn.Identity(), nn.Identity() 440 | ) 441 | pretrained.act_postprocess2 = nn.Sequential( 442 | nn.Identity(), nn.Identity(), nn.Identity() 443 | ) 444 | 445 | pretrained.act_postprocess3 = nn.Sequential( 446 | readout_oper[2], 447 | Transpose(1, 2), 448 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 449 | nn.Conv2d( 450 | in_channels=vit_features, 451 | out_channels=features[2], 452 | kernel_size=1, 453 | stride=1, 454 | padding=0, 455 | ), 456 | ) 457 | 458 | pretrained.act_postprocess4 = nn.Sequential( 459 | readout_oper[3], 460 | Transpose(1, 2), 461 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 462 | nn.Conv2d( 463 | in_channels=vit_features, 464 | out_channels=features[3], 465 | kernel_size=1, 466 | stride=1, 467 | padding=0, 468 | ), 469 | nn.Conv2d( 470 | in_channels=features[3], 471 | out_channels=features[3], 472 | kernel_size=3, 473 | stride=2, 474 | padding=1, 475 | ), 476 | ) 477 | 478 | pretrained.model.start_index = start_index 479 | pretrained.model.patch_size = [16, 16] 480 | 481 | # We inject this function into the VisionTransformer instances so that 482 | # we can use it with interpolated position embeddings without modifying the library source. 483 | pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) 484 | 485 | # We inject this function into the VisionTransformer instances so that 486 | # we can use it with interpolated position embeddings without modifying the library source. 
487 | pretrained.model._resize_pos_embed = types.MethodType( 488 | _resize_pos_embed, pretrained.model 489 | ) 490 | 491 | return pretrained 492 | 493 | 494 | def _make_pretrained_vitb_rn50_384( 495 | pretrained, 496 | use_readout="ignore", 497 | hooks=None, 498 | use_vit_only=False, 499 | enable_attention_hooks=False, 500 | ): 501 | model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) 502 | 503 | hooks = [0, 1, 8, 11] if hooks == None else hooks 504 | return _make_vit_b_rn50_backbone( 505 | model, 506 | features=[256, 512, 768, 768], 507 | size=[384, 384], 508 | hooks=hooks, 509 | use_vit_only=use_vit_only, 510 | use_readout=use_readout, 511 | enable_attention_hooks=enable_attention_hooks, 512 | ) 513 | 514 | 515 | def _make_pretrained_vitl16_384( 516 | pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False 517 | ): 518 | model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) 519 | 520 | hooks = [5, 11, 17, 23] if hooks == None else hooks 521 | return _make_vit_b16_backbone( 522 | model, 523 | features=[256, 512, 1024, 1024], 524 | hooks=hooks, 525 | vit_features=1024, 526 | use_readout=use_readout, 527 | enable_attention_hooks=enable_attention_hooks, 528 | ) 529 | 530 | 531 | def _make_pretrained_vitb16_384( 532 | pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False 533 | ): 534 | model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) 535 | 536 | hooks = [2, 5, 8, 11] if hooks == None else hooks 537 | return _make_vit_b16_backbone( 538 | model, 539 | features=[96, 192, 384, 768], 540 | hooks=hooks, 541 | use_readout=use_readout, 542 | enable_attention_hooks=enable_attention_hooks, 543 | ) 544 | 545 | 546 | def _make_pretrained_deitb16_384( 547 | pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False 548 | ): 549 | model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) 550 | 551 | hooks = [2, 5, 8, 11] if hooks == None else hooks 552 | return _make_vit_b16_backbone( 553 | model, 554 | features=[96, 192, 384, 768], 555 | hooks=hooks, 556 | use_readout=use_readout, 557 | enable_attention_hooks=enable_attention_hooks, 558 | ) 559 | 560 | 561 | def _make_pretrained_deitb16_distil_384( 562 | pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False 563 | ): 564 | model = timm.create_model( 565 | "vit_deit_base_distilled_patch16_384", pretrained=pretrained 566 | ) 567 | 568 | hooks = [2, 5, 8, 11] if hooks == None else hooks 569 | return _make_vit_b16_backbone( 570 | model, 571 | features=[96, 192, 384, 768], 572 | hooks=hooks, 573 | use_readout=use_readout, 574 | start_index=2, 575 | enable_attention_hooks=enable_attention_hooks, 576 | ) 577 | -------------------------------------------------------------------------------- /input/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/DPT/cd3fe90bb4c48577535cc4d51b602acca688a2ee/input/.placeholder -------------------------------------------------------------------------------- /output_monodepth/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/DPT/cd3fe90bb4c48577535cc4d51b602acca688a2ee/output_monodepth/.placeholder -------------------------------------------------------------------------------- /output_semseg/.placeholder: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/isl-org/DPT/cd3fe90bb4c48577535cc4d51b602acca688a2ee/output_semseg/.placeholder -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.8.1 2 | torchvision==0.9.1 3 | opencv-python==4.5.2.54 4 | timm==0.4.5 5 | -------------------------------------------------------------------------------- /run_monodepth.py: -------------------------------------------------------------------------------- 1 | """Compute depth maps for images in the input folder. 2 | """ 3 | import os 4 | import glob 5 | import torch 6 | import cv2 7 | import argparse 8 | 9 | import util.io 10 | 11 | from torchvision.transforms import Compose 12 | 13 | from dpt.models import DPTDepthModel 14 | from dpt.midas_net import MidasNet_large 15 | from dpt.transforms import Resize, NormalizeImage, PrepareForNet 16 | 17 | #from util.misc import visualize_attention 18 | 19 | 20 | def run(input_path, output_path, model_path, model_type="dpt_hybrid", optimize=True): 21 | """Run MonoDepthNN to compute depth maps. 22 | 23 | Args: 24 | input_path (str): path to input folder 25 | output_path (str): path to output folder 26 | model_path (str): path to saved model 27 | """ 28 | print("initialize") 29 | 30 | # select device 31 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 32 | print("device: %s" % device) 33 | 34 | # load network 35 | if model_type == "dpt_large": # DPT-Large 36 | net_w = net_h = 384 37 | model = DPTDepthModel( 38 | path=model_path, 39 | backbone="vitl16_384", 40 | non_negative=True, 41 | enable_attention_hooks=False, 42 | ) 43 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 44 | elif model_type == "dpt_hybrid": # DPT-Hybrid 45 | net_w = net_h = 384 46 | model = DPTDepthModel( 47 | path=model_path, 48 | backbone="vitb_rn50_384", 49 | non_negative=True, 50 | enable_attention_hooks=False, 51 | ) 52 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 53 | elif model_type == "dpt_hybrid_kitti": 54 | net_w = 1216 55 | net_h = 352 56 | 57 | model = DPTDepthModel( 58 | path=model_path, 59 | scale=0.00006016, 60 | shift=0.00579, 61 | invert=True, 62 | backbone="vitb_rn50_384", 63 | non_negative=True, 64 | enable_attention_hooks=False, 65 | ) 66 | 67 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 68 | elif model_type == "dpt_hybrid_nyu": 69 | net_w = 640 70 | net_h = 480 71 | 72 | model = DPTDepthModel( 73 | path=model_path, 74 | scale=0.000305, 75 | shift=0.1378, 76 | invert=True, 77 | backbone="vitb_rn50_384", 78 | non_negative=True, 79 | enable_attention_hooks=False, 80 | ) 81 | 82 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 83 | elif model_type == "midas_v21": # Convolutional model 84 | net_w = net_h = 384 85 | 86 | model = MidasNet_large(model_path, non_negative=True) 87 | normalization = NormalizeImage( 88 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 89 | ) 90 | else: 91 | assert ( 92 | False 93 | ), f"model_type '{model_type}' not implemented, use: --model_type [dpt_large|dpt_hybrid|dpt_hybrid_kitti|dpt_hybrid_nyu|midas_v21]" 94 | 95 | transform = Compose( 96 | [ 97 | Resize( 98 | net_w, 99 | net_h, 100 | resize_target=None, 101 | keep_aspect_ratio=True, 102 | ensure_multiple_of=32, 103 | resize_method="minimal", 104 | image_interpolation_method=cv2.INTER_CUBIC, 105 | ), 106 | normalization, 107 | 
PrepareForNet(), 108 | ] 109 | ) 110 | 111 | model.eval() 112 | 113 | if optimize == True and device == torch.device("cuda"): 114 | model = model.to(memory_format=torch.channels_last) 115 | model = model.half() 116 | 117 | model.to(device) 118 | 119 | # get input 120 | img_names = glob.glob(os.path.join(input_path, "*")) 121 | num_images = len(img_names) 122 | 123 | # create output folder 124 | os.makedirs(output_path, exist_ok=True) 125 | 126 | print("start processing") 127 | for ind, img_name in enumerate(img_names): 128 | if os.path.isdir(img_name): 129 | continue 130 | 131 | print(" processing {} ({}/{})".format(img_name, ind + 1, num_images)) 132 | # input 133 | 134 | img = util.io.read_image(img_name) 135 | 136 | if args.kitti_crop is True: 137 | height, width, _ = img.shape 138 | top = height - 352 139 | left = (width - 1216) // 2 140 | img = img[top : top + 352, left : left + 1216, :] 141 | 142 | img_input = transform({"image": img})["image"] 143 | 144 | # compute 145 | with torch.no_grad(): 146 | sample = torch.from_numpy(img_input).to(device).unsqueeze(0) 147 | 148 | if optimize == True and device == torch.device("cuda"): 149 | sample = sample.to(memory_format=torch.channels_last) 150 | sample = sample.half() 151 | 152 | prediction = model.forward(sample) 153 | prediction = ( 154 | torch.nn.functional.interpolate( 155 | prediction.unsqueeze(1), 156 | size=img.shape[:2], 157 | mode="bicubic", 158 | align_corners=False, 159 | ) 160 | .squeeze() 161 | .cpu() 162 | .numpy() 163 | ) 164 | 165 | if model_type == "dpt_hybrid_kitti": 166 | prediction *= 256 167 | 168 | if model_type == "dpt_hybrid_nyu": 169 | prediction *= 1000.0 170 | 171 | filename = os.path.join( 172 | output_path, os.path.splitext(os.path.basename(img_name))[0] 173 | ) 174 | util.io.write_depth(filename, prediction, bits=2, absolute_depth=args.absolute_depth) 175 | 176 | print("finished") 177 | 178 | 179 | if __name__ == "__main__": 180 | parser = argparse.ArgumentParser() 181 | 182 | parser.add_argument( 183 | "-i", "--input_path", default="input", help="folder with input images" 184 | ) 185 | 186 | parser.add_argument( 187 | "-o", 188 | "--output_path", 189 | default="output_monodepth", 190 | help="folder for output images", 191 | ) 192 | 193 | parser.add_argument( 194 | "-m", "--model_weights", default=None, help="path to model weights" 195 | ) 196 | 197 | parser.add_argument( 198 | "-t", 199 | "--model_type", 200 | default="dpt_hybrid", 201 | help="model type [dpt_large|dpt_hybrid|midas_v21]", 202 | ) 203 | 204 | parser.add_argument("--kitti_crop", dest="kitti_crop", action="store_true") 205 | parser.add_argument("--absolute_depth", dest="absolute_depth", action="store_true") 206 | 207 | parser.add_argument("--optimize", dest="optimize", action="store_true") 208 | parser.add_argument("--no-optimize", dest="optimize", action="store_false") 209 | 210 | parser.set_defaults(optimize=True) 211 | parser.set_defaults(kitti_crop=False) 212 | parser.set_defaults(absolute_depth=False) 213 | 214 | args = parser.parse_args() 215 | 216 | default_models = { 217 | "midas_v21": "weights/midas_v21-f6b98070.pt", 218 | "dpt_large": "weights/dpt_large-midas-2f21e586.pt", 219 | "dpt_hybrid": "weights/dpt_hybrid-midas-501f0c75.pt", 220 | "dpt_hybrid_kitti": "weights/dpt_hybrid_kitti-cb926ef4.pt", 221 | "dpt_hybrid_nyu": "weights/dpt_hybrid_nyu-2ce69ec7.pt", 222 | } 223 | 224 | if args.model_weights is None: 225 | args.model_weights = default_models[args.model_type] 226 | 227 | # set torch options 228 | torch.backends.cudnn.enabled = 
True 229 | torch.backends.cudnn.benchmark = True 230 | 231 | # compute depth maps 232 | run( 233 | args.input_path, 234 | args.output_path, 235 | args.model_weights, 236 | args.model_type, 237 | args.optimize, 238 | ) 239 | -------------------------------------------------------------------------------- /run_segmentation.py: -------------------------------------------------------------------------------- 1 | """Compute segmentation maps for images in the input folder. 2 | """ 3 | import os 4 | import glob 5 | import cv2 6 | import argparse 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | import util.io 12 | 13 | from torchvision.transforms import Compose 14 | from dpt.models import DPTSegmentationModel 15 | from dpt.transforms import Resize, NormalizeImage, PrepareForNet 16 | 17 | 18 | def run(input_path, output_path, model_path, model_type="dpt_hybrid", optimize=True): 19 | """Run segmentation network 20 | 21 | Args: 22 | input_path (str): path to input folder 23 | output_path (str): path to output folder 24 | model_path (str): path to saved model 25 | """ 26 | print("initialize") 27 | 28 | # select device 29 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 30 | print("device: %s" % device) 31 | 32 | net_w = net_h = 480 33 | 34 | # load network 35 | if model_type == "dpt_large": 36 | model = DPTSegmentationModel( 37 | 150, 38 | path=model_path, 39 | backbone="vitl16_384", 40 | ) 41 | elif model_type == "dpt_hybrid": 42 | model = DPTSegmentationModel( 43 | 150, 44 | path=model_path, 45 | backbone="vitb_rn50_384", 46 | ) 47 | else: 48 | assert ( 49 | False 50 | ), f"model_type '{model_type}' not implemented, use: --model_type [dpt_large|dpt_hybrid]" 51 | 52 | transform = Compose( 53 | [ 54 | Resize( 55 | net_w, 56 | net_h, 57 | resize_target=None, 58 | keep_aspect_ratio=True, 59 | ensure_multiple_of=32, 60 | resize_method="minimal", 61 | image_interpolation_method=cv2.INTER_CUBIC, 62 | ), 63 | NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 64 | PrepareForNet(), 65 | ] 66 | ) 67 | 68 | model.eval() 69 | 70 | if optimize == True and device == torch.device("cuda"): 71 | model = model.to(memory_format=torch.channels_last) 72 | model = model.half() 73 | 74 | model.to(device) 75 | 76 | # get input 77 | img_names = glob.glob(os.path.join(input_path, "*")) 78 | num_images = len(img_names) 79 | 80 | # create output folder 81 | os.makedirs(output_path, exist_ok=True) 82 | 83 | print("start processing") 84 | 85 | for ind, img_name in enumerate(img_names): 86 | 87 | print(" processing {} ({}/{})".format(img_name, ind + 1, num_images)) 88 | 89 | # input 90 | img = util.io.read_image(img_name) 91 | img_input = transform({"image": img})["image"] 92 | 93 | # compute 94 | with torch.no_grad(): 95 | sample = torch.from_numpy(img_input).to(device).unsqueeze(0) 96 | if optimize == True and device == torch.device("cuda"): 97 | sample = sample.to(memory_format=torch.channels_last) 98 | sample = sample.half() 99 | 100 | out = model.forward(sample) 101 | 102 | prediction = torch.nn.functional.interpolate( 103 | out, size=img.shape[:2], mode="bicubic", align_corners=False 104 | ) 105 | prediction = torch.argmax(prediction, dim=1) + 1 106 | prediction = prediction.squeeze().cpu().numpy() 107 | 108 | # output 109 | filename = os.path.join( 110 | output_path, os.path.splitext(os.path.basename(img_name))[0] 111 | ) 112 | util.io.write_segm_img(filename, img, prediction, alpha=0.5) 113 | 114 | print("finished") 115 | 116 | 117 | if __name__ == "__main__": 118 | parser 
= argparse.ArgumentParser() 119 | 120 | parser.add_argument( 121 | "-i", "--input_path", default="input", help="folder with input images" 122 | ) 123 | 124 | parser.add_argument( 125 | "-o", "--output_path", default="output_semseg", help="folder for output images" 126 | ) 127 | 128 | parser.add_argument( 129 | "-m", 130 | "--model_weights", 131 | default=None, 132 | help="path to the trained weights of model", 133 | ) 134 | 135 | # 'vit_large', 'vit_hybrid' 136 | parser.add_argument("-t", "--model_type", default="dpt_hybrid", help="model type") 137 | 138 | parser.add_argument("--optimize", dest="optimize", action="store_true") 139 | parser.add_argument("--no-optimize", dest="optimize", action="store_false") 140 | parser.set_defaults(optimize=True) 141 | 142 | args = parser.parse_args() 143 | 144 | default_models = { 145 | "dpt_large": "weights/dpt_large-ade20k-b12dca68.pt", 146 | "dpt_hybrid": "weights/dpt_hybrid-ade20k-53898607.pt", 147 | } 148 | 149 | if args.model_weights is None: 150 | args.model_weights = default_models[args.model_type] 151 | 152 | # set torch options 153 | torch.backends.cudnn.enabled = True 154 | torch.backends.cudnn.benchmark = True 155 | 156 | # compute segmentation maps 157 | run( 158 | args.input_path, 159 | args.output_path, 160 | args.model_weights, 161 | args.model_type, 162 | args.optimize, 163 | ) 164 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | __version__ = '0.0.1dev1' 4 | 5 | setuptools.setup( 6 | name='dpt', 7 | version=__version__, 8 | packages=setuptools.find_packages(), 9 | # Only put dependencies that's not depends on cuda directly. 10 | install_requires=['timm'] 11 | ) 12 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/DPT/cd3fe90bb4c48577535cc4d51b602acca688a2ee/util/__init__.py -------------------------------------------------------------------------------- /util/io.py: -------------------------------------------------------------------------------- 1 | """Utils for monoDepth. 2 | """ 3 | import sys 4 | import re 5 | import numpy as np 6 | import cv2 7 | import torch 8 | 9 | from PIL import Image 10 | 11 | 12 | from .pallete import get_mask_pallete 13 | 14 | def read_pfm(path): 15 | """Read pfm file. 
16 | 17 | Args: 18 | path (str): path to file 19 | 20 | Returns: 21 | tuple: (data, scale) 22 | """ 23 | with open(path, "rb") as file: 24 | 25 | color = None 26 | width = None 27 | height = None 28 | scale = None 29 | endian = None 30 | 31 | header = file.readline().rstrip() 32 | if header.decode("ascii") == "PF": 33 | color = True 34 | elif header.decode("ascii") == "Pf": 35 | color = False 36 | else: 37 | raise Exception("Not a PFM file: " + path) 38 | 39 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) 40 | if dim_match: 41 | width, height = list(map(int, dim_match.groups())) 42 | else: 43 | raise Exception("Malformed PFM header.") 44 | 45 | scale = float(file.readline().decode("ascii").rstrip()) 46 | if scale < 0: 47 | # little-endian 48 | endian = "<" 49 | scale = -scale 50 | else: 51 | # big-endian 52 | endian = ">" 53 | 54 | data = np.fromfile(file, endian + "f") 55 | shape = (height, width, 3) if color else (height, width) 56 | 57 | data = np.reshape(data, shape) 58 | data = np.flipud(data) 59 | 60 | return data, scale 61 | 62 | 63 | def write_pfm(path, image, scale=1): 64 | """Write pfm file. 65 | 66 | Args: 67 | path (str): pathto file 68 | image (array): data 69 | scale (int, optional): Scale. Defaults to 1. 70 | """ 71 | 72 | with open(path, "wb") as file: 73 | color = None 74 | 75 | if image.dtype.name != "float32": 76 | raise Exception("Image dtype must be float32.") 77 | 78 | image = np.flipud(image) 79 | 80 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 81 | color = True 82 | elif ( 83 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 84 | ): # greyscale 85 | color = False 86 | else: 87 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") 88 | 89 | file.write("PF\n" if color else "Pf\n".encode()) 90 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) 91 | 92 | endian = image.dtype.byteorder 93 | 94 | if endian == "<" or endian == "=" and sys.byteorder == "little": 95 | scale = -scale 96 | 97 | file.write("%f\n".encode() % scale) 98 | 99 | image.tofile(file) 100 | 101 | 102 | def read_image(path): 103 | """Read image and output RGB image (0-1). 104 | 105 | Args: 106 | path (str): path to file 107 | 108 | Returns: 109 | array: RGB image (0-1) 110 | """ 111 | img = cv2.imread(path) 112 | 113 | if img.ndim == 2: 114 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 115 | 116 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 117 | 118 | return img 119 | 120 | 121 | def resize_image(img): 122 | """Resize image and make it fit for network. 123 | 124 | Args: 125 | img (array): image 126 | 127 | Returns: 128 | tensor: data ready for network 129 | """ 130 | height_orig = img.shape[0] 131 | width_orig = img.shape[1] 132 | 133 | if width_orig > height_orig: 134 | scale = width_orig / 384 135 | else: 136 | scale = height_orig / 384 137 | 138 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int) 139 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int) 140 | 141 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) 142 | 143 | img_resized = ( 144 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() 145 | ) 146 | img_resized = img_resized.unsqueeze(0) 147 | 148 | return img_resized 149 | 150 | 151 | def resize_depth(depth, width, height): 152 | """Resize depth map and bring to CPU (numpy). 
153 | 154 | Args: 155 | depth (tensor): depth 156 | width (int): image width 157 | height (int): image height 158 | 159 | Returns: 160 | array: processed depth 161 | """ 162 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu") 163 | 164 | depth_resized = cv2.resize( 165 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC 166 | ) 167 | 168 | return depth_resized 169 | 170 | 171 | def write_depth(path, depth, bits=1, absolute_depth=False): 172 | """Write depth map to pfm and png file. 173 | 174 | Args: 175 | path (str): filepath without extension 176 | depth (array): depth 177 | """ 178 | write_pfm(path + ".pfm", depth.astype(np.float32)) 179 | 180 | if absolute_depth: 181 | out = depth 182 | else: 183 | depth_min = depth.min() 184 | depth_max = depth.max() 185 | 186 | max_val = (2 ** (8 * bits)) - 1 187 | 188 | if depth_max - depth_min > np.finfo("float").eps: 189 | out = max_val * (depth - depth_min) / (depth_max - depth_min) 190 | else: 191 | out = np.zeros(depth.shape, dtype=depth.dtype) 192 | 193 | if bits == 1: 194 | cv2.imwrite(path + ".png", out.astype("uint8"), [cv2.IMWRITE_PNG_COMPRESSION, 0]) 195 | elif bits == 2: 196 | cv2.imwrite(path + ".png", out.astype("uint16"), [cv2.IMWRITE_PNG_COMPRESSION, 0]) 197 | 198 | return 199 | 200 | 201 | def write_segm_img(path, image, labels, palette="detail", alpha=0.5): 202 | """Write depth map to pfm and png file. 203 | 204 | Args: 205 | path (str): filepath without extension 206 | image (array): input image 207 | labels (array): labeling of the image 208 | """ 209 | 210 | mask = get_mask_pallete(labels, "ade20k") 211 | 212 | img = Image.fromarray(np.uint8(255*image)).convert("RGBA") 213 | seg = mask.convert("RGBA") 214 | 215 | out = Image.blend(img, seg, alpha) 216 | 217 | out.save(path + ".png") 218 | 219 | return 220 | -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | from dpt.vit import get_mean_attention_map 4 | 5 | def visualize_attention(input, model, prediction, model_type): 6 | input = (input + 1.0)/2.0 7 | 8 | attn1 = model.pretrained.attention["attn_1"] 9 | attn2 = model.pretrained.attention["attn_2"] 10 | attn3 = model.pretrained.attention["attn_3"] 11 | attn4 = model.pretrained.attention["attn_4"] 12 | 13 | plt.subplot(3,4,1), plt.imshow(input.squeeze().permute(1,2,0)), plt.title("Input", fontsize=8), plt.axis("off") 14 | plt.subplot(3,4,2), plt.imshow(prediction), plt.set_cmap("inferno"), plt.title("Prediction", fontsize=8), plt.axis("off") 15 | 16 | if model_type == "dpt_hybrid": 17 | h = [3,6,9,12] 18 | else: 19 | h = [6,12,18,24] 20 | 21 | # upper left 22 | plt.subplot(345), 23 | ax1 = plt.imshow(get_mean_attention_map(attn1, 1, input.shape)) 24 | plt.ylabel("Upper left corner", fontsize=8) 25 | plt.title(f"Layer {h[0]}", fontsize=8) 26 | gc = plt.gca() 27 | gc.axes.xaxis.set_ticklabels([]) 28 | gc.axes.yaxis.set_ticklabels([]) 29 | gc.axes.xaxis.set_ticks([]) 30 | gc.axes.yaxis.set_ticks([]) 31 | 32 | 33 | plt.subplot(346), 34 | plt.imshow(get_mean_attention_map(attn2, 1, input.shape)) 35 | plt.title(f"Layer {h[1]}", fontsize=8) 36 | plt.axis("off"), 37 | 38 | plt.subplot(347), 39 | plt.imshow(get_mean_attention_map(attn3, 1, input.shape)) 40 | plt.title(f"Layer {h[2]}", fontsize=8) 41 | plt.axis("off"), 42 | 43 | 44 | plt.subplot(348), 45 | plt.imshow(get_mean_attention_map(attn4, 1, input.shape)) 46 | plt.title(f"Layer {h[3]}", fontsize=8) 47 | 
plt.axis("off"), 48 | 49 | 50 | # lower right 51 | plt.subplot(3,4,9), plt.imshow(get_mean_attention_map(attn1, -1, input.shape)) 52 | plt.ylabel("Lower right corner", fontsize=8) 53 | gc = plt.gca() 54 | gc.axes.xaxis.set_ticklabels([]) 55 | gc.axes.yaxis.set_ticklabels([]) 56 | gc.axes.xaxis.set_ticks([]) 57 | gc.axes.yaxis.set_ticks([]) 58 | 59 | plt.subplot(3,4,10), plt.imshow(get_mean_attention_map(attn2, -1, input.shape)), plt.axis("off") 60 | plt.subplot(3,4,11), plt.imshow(get_mean_attention_map(attn3, -1, input.shape)), plt.axis("off") 61 | plt.subplot(3,4,12), plt.imshow(get_mean_attention_map(attn4, -1, input.shape)), plt.axis("off") 62 | plt.tight_layout() 63 | plt.show() 64 | -------------------------------------------------------------------------------- /util/pallete.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## ECE Department, Rutgers University 4 | ## Email: zhang.hang@rutgers.edu 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | from PIL import Image 12 | 13 | def get_mask_pallete(npimg, dataset='detail'): 14 | """Get image color pallete for visualizing masks""" 15 | # recovery boundary 16 | if dataset == 'pascal_voc': 17 | npimg[npimg==21] = 255 18 | # put colormap 19 | out_img = Image.fromarray(npimg.squeeze().astype('uint8')) 20 | if dataset == 'ade20k': 21 | out_img.putpalette(adepallete) 22 | elif dataset == 'citys': 23 | out_img.putpalette(citypallete) 24 | elif dataset in ('detail', 'pascal_voc', 'pascal_aug'): 25 | out_img.putpalette(vocpallete) 26 | return out_img 27 | 28 | def _get_voc_pallete(num_cls): 29 | n = num_cls 30 | pallete = [0]*(n*3) 31 | for j in range(0,n): 32 | lab = j 33 | pallete[j*3+0] = 0 34 | pallete[j*3+1] = 0 35 | pallete[j*3+2] = 0 36 | i = 0 37 | while (lab > 0): 38 | pallete[j*3+0] |= (((lab >> 0) & 1) << (7-i)) 39 | pallete[j*3+1] |= (((lab >> 1) & 1) << (7-i)) 40 | pallete[j*3+2] |= (((lab >> 2) & 1) << (7-i)) 41 | i = i + 1 42 | lab >>= 3 43 | return pallete 44 | 45 | vocpallete = _get_voc_pallete(256) 46 | 47 | adepallete = 
[0,0,0,120,120,120,180,120,120,6,230,230,80,50,50,4,200,3,120,120,80,140,140,140,204,5,255,230,230,230,4,250,7,224,5,255,235,255,7,150,5,61,120,120,70,8,255,51,255,6,82,143,255,140,204,255,4,255,51,7,204,70,3,0,102,200,61,230,250,255,6,51,11,102,255,255,7,71,255,9,224,9,7,230,220,220,220,255,9,92,112,9,255,8,255,214,7,255,224,255,184,6,10,255,71,255,41,10,7,255,255,224,255,8,102,8,255,255,61,6,255,194,7,255,122,8,0,255,20,255,8,41,255,5,153,6,51,255,235,12,255,160,150,20,0,163,255,140,140,140,250,10,15,20,255,0,31,255,0,255,31,0,255,224,0,153,255,0,0,0,255,255,71,0,0,235,255,0,173,255,31,0,255,11,200,200,255,82,0,0,255,245,0,61,255,0,255,112,0,255,133,255,0,0,255,163,0,255,102,0,194,255,0,0,143,255,51,255,0,0,82,255,0,255,41,0,255,173,10,0,255,173,255,0,0,255,153,255,92,0,255,0,255,255,0,245,255,0,102,255,173,0,255,0,20,255,184,184,0,31,255,0,255,61,0,71,255,255,0,204,0,255,194,0,255,82,0,10,255,0,112,255,51,0,255,0,194,255,0,122,255,0,255,163,255,153,0,0,255,10,255,112,0,143,255,0,82,0,255,163,255,0,255,235,0,8,184,170,133,0,255,0,255,92,184,0,255,255,0,31,0,184,255,0,214,255,255,0,112,92,255,0,0,224,255,112,224,255,70,184,160,163,0,255,153,0,255,71,255,0,255,0,163,255,204,0,255,0,143,0,255,235,133,255,0,255,0,235,245,0,255,255,0,122,255,245,0,10,190,212,214,255,0,0,204,255,20,0,255,255,255,0,0,153,255,0,41,255,0,255,204,41,0,255,41,255,0,173,0,255,0,245,255,71,0,255,122,0,255,0,255,184,0,92,255,184,255,0,0,133,255,255,214,0,25,194,194,102,255,0,92,0,255] 48 | 49 | citypallete = [ 50 | 128,64,128,244,35,232,70,70,70,102,102,156,190,153,153,153,153,153,250,170,30,220,220,0,107,142,35,152,251,152,70,130,180,220,20,60,255,0,0,0,0,142,0,0,70,0,60,100,0,80,100,0,0,230,119,11,32,128,192,0,0,64,128,128,64,128,0,192,128,128,192,128,64,64,0,192,64,0,64,192,0,192,192,0,64,64,128,192,64,128,64,192,128,192,192,128,0,0,64,128,0,64,0,128,64,128,128,64,0,0,192,128,0,192,0,128,192,128,128,192,64,0,64,192,0,64,64,128,64,192,128,64,64,0,192,192,0,192,64,128,192,192,128,192,0,64,64,128,64,64,0,192,64,128,192,64,0,64,192,128,64,192,0,192,192,128,192,192,64,64,64,192,64,64,64,192,64,192,192,64,64,64,192,192,64,192,64,192,192,192,192,192,32,0,0,160,0,0,32,128,0,160,128,0,32,0,128,160,0,128,32,128,128,160,128,128,96,0,0,224,0,0,96,128,0,224,128,0,96,0,128,224,0,128,96,128,128,224,128,128,32,64,0,160,64,0,32,192,0,160,192,0,32,64,128,160,64,128,32,192,128,160,192,128,96,64,0,224,64,0,96,192,0,224,192,0,96,64,128,224,64,128,96,192,128,224,192,128,32,0,64,160,0,64,32,128,64,160,128,64,32,0,192,160,0,192,32,128,192,160,128,192,96,0,64,224,0,64,96,128,64,224,128,64,96,0,192,224,0,192,96,128,192,224,128,192,32,64,64,160,64,64,32,192,64,160,192,64,32,64,192,160,64,192,32,192,192,160,192,192,96,64,64,224,64,64,96,192,64,224,192,64,96,64,192,224,64,192,96,192,192,224,192,192,0,32,0,128,32,0,0,160,0,128,160,0,0,32,128,128,32,128,0,160,128,128,160,128,64,32,0,192,32,0,64,160,0,192,160,0,64,32,128,192,32,128,64,160,128,192,160,128,0,96,0,128,96,0,0,224,0,128,224,0,0,96,128,128,96,128,0,224,128,128,224,128,64,96,0,192,96,0,64,224,0,192,224,0,64,96,128,192,96,128,64,224,128,192,224,128,0,32,64,128,32,64,0,160,64,128,160,64,0,32,192,128,32,192,0,160,192,128,160,192,64,32,64,192,32,64,64,160,64,192,160,64,64,32,192,192,32,192,64,160,192,192,160,192,0,96,64,128,96,64,0,224,64,128,224,64,0,96,192,128,96,192,0,224,192,128,224,192,64,96,64,192,96,64,64,224,64,192,224,64,64,96,192,192,96,192,64,224,192,192,224,192,32,32,0,160,32,0,32,160,0,160,160,0,32,32,128,160,32,128,32,160,128,160,160,128,96,32,0,224,32,0,96,160,0,224,160,0,96,3
2,128,224,32,128,96,160,128,224,160,128,32,96,0,160,96,0,32,224,0,160,224,0,32,96,128,160,96,128,32,224,128,160,224,128,96,96,0,224,96,0,96,224,0,224,224,0,96,96,128,224,96,128,96,224,128,224,224,128,32,32,64,160,32,64,32,160,64,160,160,64,32,32,192,160,32,192,32,160,192,160,160,192,96,32,64,224,32,64,96,160,64,224,160,64,96,32,192,224,32,192,96,160,192,224,160,192,32,96,64,160,96,64,32,224,64,160,224,64,32,96,192,160,96,192,32,224,192,160,224,192,96,96,64,224,96,64,96,224,64,224,224,64,96,96,192,224,96,192,96,224,192,0,0,0] 51 | -------------------------------------------------------------------------------- /weights/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/DPT/cd3fe90bb4c48577535cc4d51b602acca688a2ee/weights/.placeholder --------------------------------------------------------------------------------
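
Putting the scripts above together, the following is a condensed single-image sketch of the depth pipeline that run_monodepth.py implements, using the dpt_large defaults and skipping the optional half-precision path. The weight filename comes from the script's default_models table; input/example.jpg and the output name are placeholder paths assumed for illustration.

```
# Condensed from run_monodepth.py: one image through DPT-Large, full precision.
import cv2
import torch
from torchvision.transforms import Compose

import util.io
from dpt.models import DPTDepthModel
from dpt.transforms import Resize, NormalizeImage, PrepareForNet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# dpt_large entry of default_models; the file must already sit in weights/.
model = DPTDepthModel(
    path="weights/dpt_large-midas-2f21e586.pt",
    backbone="vitl16_384",
    non_negative=True,
    enable_attention_hooks=False,
)
model.eval()
model.to(device)

transform = Compose(
    [
        Resize(
            384,
            384,
            resize_target=None,
            keep_aspect_ratio=True,
            ensure_multiple_of=32,
            resize_method="minimal",
            image_interpolation_method=cv2.INTER_CUBIC,
        ),
        NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        PrepareForNet(),
    ]
)

img = util.io.read_image("input/example.jpg")  # RGB, values in [0, 1]
img_input = transform({"image": img})["image"]

with torch.no_grad():
    sample = torch.from_numpy(img_input).to(device).unsqueeze(0)
    prediction = model.forward(sample)
    prediction = (
        torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=img.shape[:2],
            mode="bicubic",
            align_corners=False,
        )
        .squeeze()
        .cpu()
        .numpy()
    )

# Relative (affine-invariant) depth: write_depth stores a PFM plus a
# min-max-normalised 16-bit PNG next to it.
util.io.write_depth("output_monodepth/example", prediction, bits=2)
```

With --optimize and a CUDA device the script additionally switches the model and the input to channels-last half precision, which this sketch leaves out for clarity.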
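
For the two metric models, run_monodepth.py scales predictions before saving: ×256 for dpt_hybrid_kitti and ×1000 for dpt_hybrid_nyu, and with --absolute_depth write_depth stores those values verbatim as 16-bit PNGs. Recovering metres from such a PNG is therefore a single division. The path below is a placeholder, and the division only applies to PNGs written with --absolute_depth; without that flag, write_depth min-max normalises each map to the full 16-bit range.

```
import cv2

# Any PNG produced by run_monodepth.py with --absolute_depth; placeholder path.
depth_png = cv2.imread("output_monodepth/example.png", cv2.IMREAD_UNCHANGED)  # uint16

depth_m = depth_png.astype("float32") / 256.0     # dpt_hybrid_kitti encoding
# depth_m = depth_png.astype("float32") / 1000.0  # dpt_hybrid_nyu encoding
```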
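
run_monodepth.py keeps enable_attention_hooks=False and leaves the util.misc import commented out; the sketch below is one way those pieces might be wired together, and it rests on the assumption that building the model with enable_attention_hooks=True fills model.pretrained.attention["attn_1"] … ["attn_4"] during the forward pass, as visualize_attention expects. The weight path is the same placeholder as above, and the random input merely keeps the snippet self-contained.

```
import torch

from dpt.models import DPTDepthModel
from util.misc import visualize_attention

# Assumption: the attention hooks populate model.pretrained.attention on forward().
model = DPTDepthModel(
    path="weights/dpt_large-midas-2f21e586.pt",
    backbone="vitl16_384",
    non_negative=True,
    enable_attention_hooks=True,
)
model.eval()

# Stand-in for a normalised 1x3xHxW batch; reuse `sample` from the depth sketch
# above for a real image. Kept on the CPU because visualize_attention plots the
# tensors directly with matplotlib.
sample = torch.randn(1, 3, 384, 384)

with torch.no_grad():
    prediction = model.forward(sample).squeeze().cpu().numpy()

# visualize_attention rescales the input from [-1, 1] to [0, 1] itself and shows
# the mean attention of four layers for the upper-left and lower-right corners.
visualize_attention(sample, model, prediction, "dpt_large")
```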
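
The bit loop in _get_voc_pallete is terse: for class index j it walks the bits of j three at a time and writes them into the red, green and blue channels at decreasing significance, which is what keeps neighbouring class ids visually distinct. The helper below is a standalone re-derivation for a single index (it is not part of the repository), with the familiar PASCAL VOC colours as a sanity check.

```
def voc_color(j):
    """Colour assigned to class index j by the _get_voc_pallete scheme."""
    r = g = b = 0
    lab, i = j, 0
    while lab > 0:
        r |= ((lab >> 0) & 1) << (7 - i)  # bit 0 -> red channel
        g |= ((lab >> 1) & 1) << (7 - i)  # bit 1 -> green channel
        b |= ((lab >> 2) & 1) << (7 - i)  # bit 2 -> blue channel
        i += 1
        lab >>= 3
    return r, g, b

print(voc_color(1))   # (128, 0, 0)     - VOC "aeroplane"
print(voc_color(2))   # (0, 128, 0)     - VOC "bicycle"
print(voc_color(15))  # (192, 128, 128) - VOC "person"
```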