├── .gitignore ├── LICENSE.md ├── README.md ├── data ├── kitti15_img_left.jpg ├── kitti15_img_right.jpg ├── sf_img_left.jpg ├── sf_img_right.jpg └── teaser.png ├── envs ├── bi3d_conda_env.yml └── bi3d_pytorch_19_01.DockerFile └── src ├── models ├── Bi3DNet.py ├── DispRefine2D.py ├── FeatExtractNet.py ├── GCNet.py ├── PSMNet.py ├── RefineNet2D.py ├── RefineNet3D.py ├── SegNet2D.py └── __init__.py ├── project.toml ├── run_binary_depth_estimation.py ├── run_continuous_depth_estimation.py ├── run_demo_kitti15.sh ├── run_demo_sf.sh └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Add any directories, files, or patterns you don't want to be tracked by version control 2 | 3 | *.png 4 | *.pfm 5 | *.pth.tar 6 | *.npy 7 | *.ppm 8 | *.pyc 9 | *.tar 10 | *.zip 11 | *.gif -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | # NVIDIA Source Code License for Bi3D 2 | 3 | ## 1. Definitions 4 | 5 | “Licensor” means any person or entity that distributes its Work. 6 | 7 | “Software” means the original work of authorship made available under this License. 8 | 9 | “Work” means the Software and any additions to or derivative works of the Software that are made available under this License. 10 | 11 | “NVIDIA Processors” means any central processing unit (CPU), graphics processing unit (GPU), field-programmable gate array (FPGA), application-specific integrated circuit (ASIC) or any combination thereof designed, made, sold, or provided by NVIDIA or its affiliates. 12 | 13 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this License, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work. 14 | 15 | Works, including the Software, are “made available” under this License by including in or with the Work either (a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License. 16 | 17 | ## 2. License Grant 18 | 19 | ### 2.1 Copyright Grant. 20 | Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form. 21 | 22 | ## 3. Limitations 23 | 24 | ### 3.1 Redistribution. 25 | You may reproduce or distribute the Work only if (a) you do so under this License, (b) you include a complete copy of this License with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work. 26 | 27 | ### 3.2 Derivative Works. 28 | You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this License (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself. 29 | 30 | ### 3.3 Use Limitation. 
31 | The Work and any derivative works thereof only may be used or intended for use non-commercially and with NVIDIA Processors. Notwithstanding the foregoing, NVIDIA and its affiliates may use the Work and any derivative works commercially. As used herein, “non-commercially” means for research or evaluation purposes only. 32 | 33 | ### 3.4 Patent Claims. 34 | If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this License from such Licensor (including the grant in Section 2.1) will terminate immediately. 35 | 36 | ### 3.5 Trademarks. 37 | This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this License. 38 | 39 | ### 3.6 Termination. 40 | If you violate any term of this License, then your rights under this License (including the grant in Section 2.1) will terminate immediately. 41 | 42 | ## 4. Disclaimer of Warranty. 43 | 44 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. 45 | 46 | ## 5. Limitation of Liability. 47 | 48 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Bi3D — Official PyTorch Implementation 2 | 3 | ![Teaser image](data/teaser.png) 4 | 5 | **Bi3D: Stereo Depth Estimation via Binary Classifications**
6 | Abhishek Badki, Alejandro Troccoli, Kihwan Kim, Jan Kautz, Pradeep Sen, and Orazio Gallo<br>
7 | IEEE CVPR 2020<br>
8 | 9 | ## Abstract: 10 | *Stereo-based depth estimation is a cornerstone of computer vision, with state-of-the-art methods delivering accurate results in real time. For several applications such as autonomous navigation, however, it may be useful to trade accuracy for lower latency. We present Bi3D, a method that estimates depth via a series of binary classifications. Rather than testing if objects are* at *a particular depth D, as existing stereo methods do, it classifies them as being* closer *or* farther *than D. This property offers a powerful mechanism to balance accuracy and latency. Given a strict time budget, Bi3D can detect objects closer than a given distance in as little as a few milliseconds, or estimate depth with arbitrarily coarse quantization, with complexity linear with the number of quantization levels. Bi3D can also use the allotted quantization levels to get continuous depth, but in a specific depth range. For standard stereo (i.e., continuous depth on the whole range), our method is close to or on par with state-of-the-art, finely tuned stereo methods.* 11 | 12 | 13 | ## Paper: 14 | https://arxiv.org/pdf/2005.07274.pdf
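In code, the projection from binary classifications to a continuous disparity estimate is a one-line reduction. The sketch below is illustrative only — the tensor sizes, the random stand-in for the network output, and the final scaling by `max_disparity` are assumptions; the released implementation performs this projection step inside `src/models/Bi3DNet.py`:
```
import torch

# seg_prob: per-plane probabilities that each pixel is closer than candidate
# disparity plane i, for N roughly equally spaced planes covering [0, max_disparity].
B, N, H, W = 1, 64, 192, 192          # illustrative sizes
max_disparity = 192
seg_prob = torch.rand(B, N, H, W)     # stand-in for the network output

# Averaging the binary decisions over the plane dimension gives the expected
# fraction of planes the pixel lies in front of, i.e. a normalized disparity.
disparity_normalized = torch.mean(seg_prob, dim=1, keepdim=True)  # (B, 1, H, W)
disparity = disparity_normalized * max_disparity                  # disparity in pixels
```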
15 | 16 | ## Videos:
17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | ## Citing Bi3D: 28 | @InProceedings{badki2020Bi3D, 29 | author = {Badki, Abhishek and Troccoli, Alejandro and Kim, Kihwan and Kautz, Jan and Sen, Pradeep and Gallo, Orazio}, 30 | title = {{Bi3D}: {S}tereo Depth Estimation via Binary Classifications}, 31 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 32 | year = {2020} 33 | } 34 | 35 | or the arXiv paper 36 | 37 | @InProceedings{badki2020Bi3D, 38 | author = {Badki, Abhishek and Troccoli, Alejandro and Kim, Kihwan and Kautz, Jan and Sen, Pradeep and Gallo, Orazio}, 39 | title = {{Bi3D}: {S}tereo Depth Estimation via Binary Classifications}, 40 | booktitle = {arXiv preprint arXiv:2005.07274}, 41 | year = {2020} 42 | } 43 | 44 | 45 | ## Code:
46 | 47 | ### License 48 | 49 | Copyright (C) 2020 NVIDIA Corporation. All rights reserved. 50 | 51 | Licensed under the [NVIDIA Source Code License](LICENSE.md) 52 | 53 | ### Description 54 | 55 | 56 | ### Setup 57 | 58 | We offer two ways of setting up your environemnt, through Docker or Conda. 59 | 60 | #### Docker 61 | For convenience, we provide a Dockerfile to build a container image to run the code. The image will contain the Python dependencies. 62 | 63 | System requirements: 64 | 65 | 1. Docker (Tested on version 19.03.11) 66 | 67 | 2. [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker/wiki) 68 | 69 | 3. NVIDIA GPU driver. 70 | 71 | Build the container image: 72 | ``` 73 | docker build -t bi3d . -f envs/bi3d_pytorch_19_01.DockerFile 74 | ``` 75 | To launch the container, run the following: 76 | ``` 77 | docker run --rm -it --gpus=all -v $(pwd):/bi3d -w /bi3d --net=host --ipc=host bi3d:latest /bin/bash 78 | ``` 79 | 80 | #### Conda 81 | All dependencies will be installed automatically using the following: 82 | ``` 83 | conda env create -f envs/bi3d_conda_env.yml 84 | ``` 85 | You can activate the environment by running: 86 | ``` 87 | conda activate bi3d 88 | ``` 89 | 90 | ### Pre-trained models 91 | Download the pre-trained models [here](https://drive.google.com/file/d/1X4Ing9WumtIxonNXXCzKJulJtPgzk61n). 92 | 93 | ### Run the demo 94 | 95 | ``` 96 | cd src 97 | # RUN DEMO FOR SCENEFLOW DATASET 98 | sh run_demo_sf.sh 99 | # RUN DEMO FOR KITTI15 DATASET 100 | sh run_demo_kitti15.sh 101 | ``` 102 | -------------------------------------------------------------------------------- /data/kitti15_img_left.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Bi3D/4b5fdb48d820b8a5cfd95a7d2d82ea56f18cd597/data/kitti15_img_left.jpg -------------------------------------------------------------------------------- /data/kitti15_img_right.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Bi3D/4b5fdb48d820b8a5cfd95a7d2d82ea56f18cd597/data/kitti15_img_right.jpg -------------------------------------------------------------------------------- /data/sf_img_left.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Bi3D/4b5fdb48d820b8a5cfd95a7d2d82ea56f18cd597/data/sf_img_left.jpg -------------------------------------------------------------------------------- /data/sf_img_right.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Bi3D/4b5fdb48d820b8a5cfd95a7d2d82ea56f18cd597/data/sf_img_right.jpg -------------------------------------------------------------------------------- /data/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/Bi3D/4b5fdb48d820b8a5cfd95a7d2d82ea56f18cd597/data/teaser.png -------------------------------------------------------------------------------- /envs/bi3d_conda_env.yml: -------------------------------------------------------------------------------- 1 | name: bi3d 2 | channels: 3 | - pytorch 4 | - soumith 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - blas=1.0=mkl 9 | - ca-certificates=2020.6.24=0 10 | - certifi=2020.6.20=py37_0 11 | - cudatoolkit=10.0.130=0 12 | - freetype=2.10.2=h5ab3b9f_0 13 | - intel-openmp=2020.1=217 14 | - jpeg=9b=h024ee3a_2 15 | - lcms2=2.11=h396b838_0 
16 | - ld_impl_linux-64=2.33.1=h53a641e_7 17 | - libedit=3.1.20191231=h14c3975_1 18 | - libffi=3.3=he6710b0_2 19 | - libgcc-ng=9.1.0=hdf63c60_0 20 | - libgfortran-ng=7.3.0=hdf63c60_0 21 | - libpng=1.6.37=hbc83047_0 22 | - libstdcxx-ng=9.1.0=hdf63c60_0 23 | - libtiff=4.1.0=h2733197_1 24 | - lz4-c=1.9.2=he6710b0_0 25 | - mkl=2020.1=217 26 | - mkl-service=2.3.0=py37he904b0f_0 27 | - mkl_fft=1.1.0=py37h23d657b_0 28 | - mkl_random=1.1.1=py37h0573a6f_0 29 | - ncurses=6.2=he6710b0_1 30 | - ninja=1.9.0=py37hfd86e86_0 31 | - numpy=1.18.5=py37ha1c710e_0 32 | - numpy-base=1.18.5=py37hde5b4d6_0 33 | - olefile=0.46=py_0 34 | - openssl=1.1.1g=h7b6447c_0 35 | - pillow=7.2.0=py37hb39fc2d_0 36 | - pip=20.1.1=py37_1 37 | - python=3.7.7=hcff3b4d_5 38 | - pytorch=1.4.0=py3.7_cuda10.0.130_cudnn7.6.3_0 39 | - readline=8.0=h7b6447c_0 40 | - setuptools=49.2.0=py37_0 41 | - six=1.15.0=py_0 42 | - sqlite=3.32.3=h62c20be_0 43 | - tk=8.6.10=hbc83047_0 44 | - torchvision=0.5.0=py37_cu100 45 | - wheel=0.34.2=py37_0 46 | - xz=5.2.5=h7b6447c_0 47 | - zlib=1.2.11=h7b6447c_3 48 | - zstd=1.4.5=h0b5b093_0 49 | - pip: 50 | - imageio==2.9.0 51 | - opencv-python==4.3.0.36 52 | - protobuf==3.12.2 53 | - tensorboardx==2.1 54 | 55 | -------------------------------------------------------------------------------- /envs/bi3d_pytorch_19_01.DockerFile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:19.01-py3 2 | 3 | RUN pip install Pillow 4 | RUN pip install imageio 5 | RUN pip install tensorboardX 6 | RUN pip install opencv-python 7 | -------------------------------------------------------------------------------- /src/models/Bi3DNet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
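#
# A minimal construction sketch (illustrative only, not part of the original
# file): the factory functions below build a model from a plain options dict.
# The architecture names and hyper-parameters shown here are the defaults used
# by src/run_binary_depth_estimation.py.
#
#   options = {
#       "bi3dnet_featnet_arch": "featextractnetspp",
#       "bi3dnet_featnethr_arch": "featextractnethr",
#       "bi3dnet_segnet_arch": "segnet2d",
#       "bi3dnet_refinenet_arch": "segrefinenet",
#       "bi3dnet_max_disparity": 192,
#       "bi3dnet_disps_per_example_true": True,
#       "featextractnethr_out_planes": 16,
#       "segrefinenet_in_planes": 17,
#       "segrefinenet_out_planes": 8,
#   }
#   model = bi3dnet_binary_depth(options).cuda()
#
#   # img_left / img_right: (B, 3, H, W) normalized image tensors on the GPU.
#   # disp_ids: (B, N) LongTensor of candidate disparity planes divided by 3,
#   # since features are computed at 1/3 resolution.
#   out = model(img_left, img_right, disp_ids)  # [coarse probability map, refined probability map]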
8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | from torch.autograd import Variable 13 | import torch.nn.functional as F 14 | 15 | import models.FeatExtractNet as FeatNet 16 | import models.SegNet2D as SegNet 17 | import models.RefineNet2D as RefineNet 18 | import models.RefineNet3D as RefineNet3D 19 | 20 | 21 | __all__ = ["bi3dnet_binary_depth", "bi3dnet_continuous_depth_2D", "bi3dnet_continuous_depth_3D"] 22 | 23 | 24 | def compute_cost_volume(features_left, features_right, disp_ids, max_disp, is_disps_per_example): 25 | 26 | batch_size = features_left.shape[0] 27 | feature_size = features_left.shape[1] 28 | H = features_left.shape[2] 29 | W = features_left.shape[3] 30 | 31 | psv_size = disp_ids.shape[1] 32 | 33 | psv = Variable(features_left.new_zeros(batch_size, psv_size, feature_size * 2, H, W + max_disp)).cuda() 34 | 35 | if is_disps_per_example: 36 | for i in range(batch_size): 37 | psv[i, 0, :feature_size, :, 0:W] = features_left[i] 38 | psv[i, 0, feature_size:, :, disp_ids[i, 0] : W + disp_ids[i, 0]] = features_right[i] 39 | psv = psv.contiguous() 40 | else: 41 | for i in range(psv_size): 42 | psv[:, i, :feature_size, :, 0:W] = features_left 43 | psv[:, i, feature_size:, :, disp_ids[0, i] : W + disp_ids[0, i]] = features_right 44 | psv = psv.contiguous() 45 | 46 | return psv 47 | 48 | 49 | """ 50 | Bi3DNet for continuous depthmap generation. Doesn't use 3D regularization. 51 | """ 52 | 53 | 54 | class Bi3DNetContinuousDepth2D(nn.Module): 55 | def __init__(self, options, featnet_arch, segnet_arch, refinenet_arch=None, max_disparity=192): 56 | 57 | super(Bi3DNetContinuousDepth2D, self).__init__() 58 | 59 | self.max_disparity = max_disparity 60 | self.max_disparity_seg = int(self.max_disparity / 3) 61 | self.is_disps_per_example = False 62 | self.is_save_memory = False 63 | 64 | self.is_refine = True 65 | if refinenet_arch == None: 66 | self.is_refine = False 67 | 68 | self.featnet = FeatNet.__dict__[featnet_arch](options, data=None) 69 | self.segnet = SegNet.__dict__[segnet_arch](options, data=None) 70 | if self.is_refine: 71 | self.refinenet = RefineNet.__dict__[refinenet_arch](options, data=None) 72 | 73 | return 74 | 75 | def forward(self, img_left, img_right, disp_ids): 76 | 77 | batch_size = img_left.shape[0] 78 | psv_size = disp_ids.shape[1] 79 | 80 | if psv_size == 1: 81 | self.is_disps_per_example = True 82 | else: 83 | self.is_disps_per_example = False 84 | 85 | # Feature Extraction 86 | features_left = self.featnet(img_left) 87 | features_right = self.featnet(img_right) 88 | feature_size = features_left.shape[1] 89 | H = features_left.shape[2] 90 | W = features_left.shape[3] 91 | 92 | # Cost Volume Generation 93 | psv = compute_cost_volume( 94 | features_left, features_right, disp_ids, self.max_disparity_seg, self.is_disps_per_example 95 | ) 96 | 97 | psv = psv.view(batch_size * psv_size, feature_size * 2, H, W + self.max_disparity_seg) 98 | 99 | # Segmentation Network 100 | seg_raw_low_res = self.segnet(psv)[:, :, :, :W] 101 | seg_raw_low_res = seg_raw_low_res.view(batch_size, 1, psv_size, H, W) 102 | 103 | # Upsampling 104 | seg_prob_low_res_up = torch.sigmoid( 105 | F.interpolate( 106 | seg_raw_low_res, 107 | size=[psv_size * 3, img_left.size()[-2], img_left.size()[-1]], 108 | mode="trilinear", 109 | align_corners=False, 110 | ) 111 | ) 112 | seg_prob_low_res_up = seg_prob_low_res_up[:, 0, 1:-1, :, :] 113 | 114 | # Projection 115 | disparity_normalized = torch.mean((seg_prob_low_res_up), dim=1, keepdim=True) 116 | 117 | # Refinement 118 | if 
self.is_refine: 119 | refine_net_input = torch.cat((disparity_normalized, img_left), dim=1) 120 | disparity_normalized = self.refinenet(refine_net_input) 121 | 122 | return seg_prob_low_res_up, disparity_normalized 123 | 124 | 125 | def bi3dnet_continuous_depth_2D(options, data=None): 126 | 127 | print("==> USING Bi3DNetContinuousDepth2D") 128 | for key in options: 129 | if "bi3dnet" in key: 130 | print("{} : {}".format(key, options[key])) 131 | 132 | model = Bi3DNetContinuousDepth2D( 133 | options, 134 | featnet_arch=options["bi3dnet_featnet_arch"], 135 | segnet_arch=options["bi3dnet_segnet_arch"], 136 | refinenet_arch=options["bi3dnet_refinenet_arch"], 137 | max_disparity=options["bi3dnet_max_disparity"], 138 | ) 139 | 140 | if data is not None: 141 | model.load_state_dict(data["state_dict"]) 142 | 143 | return model 144 | 145 | 146 | """ 147 | Bi3DNet for continuous depthmap generation. Uses 3D regularization. 148 | """ 149 | 150 | 151 | class Bi3DNetContinuousDepth3D(nn.Module): 152 | def __init__( 153 | self, 154 | options, 155 | featnet_arch, 156 | segnet_arch, 157 | refinenet_arch=None, 158 | refinenet3d_arch=None, 159 | max_disparity=192, 160 | ): 161 | 162 | super(Bi3DNetContinuousDepth3D, self).__init__() 163 | 164 | self.max_disparity = max_disparity 165 | self.max_disparity_seg = int(self.max_disparity / 3) 166 | self.is_disps_per_example = False 167 | self.is_save_memory = False 168 | 169 | self.is_refine = True 170 | if refinenet_arch == None: 171 | self.is_refine = False 172 | 173 | self.featnet = FeatNet.__dict__[featnet_arch](options, data=None) 174 | self.segnet = SegNet.__dict__[segnet_arch](options, data=None) 175 | if self.is_refine: 176 | self.refinenet = RefineNet.__dict__[refinenet_arch](options, data=None) 177 | self.refinenet3d = RefineNet3D.__dict__[refinenet3d_arch](options, data=None) 178 | 179 | return 180 | 181 | def forward(self, img_left, img_right, disp_ids): 182 | 183 | batch_size = img_left.shape[0] 184 | psv_size = disp_ids.shape[1] 185 | 186 | if psv_size == 1: 187 | self.is_disps_per_example = True 188 | else: 189 | self.is_disps_per_example = False 190 | 191 | # Feature Extraction 192 | features_left = self.featnet(img_left) 193 | features_right = self.featnet(img_right) 194 | feature_size = features_left.shape[1] 195 | H = features_left.shape[2] 196 | W = features_left.shape[3] 197 | 198 | # Cost Volume Generation 199 | psv = compute_cost_volume( 200 | features_left, features_right, disp_ids, self.max_disparity_seg, self.is_disps_per_example 201 | ) 202 | 203 | psv = psv.view(batch_size * psv_size, feature_size * 2, H, W + self.max_disparity_seg) 204 | 205 | # Segmentation Network 206 | seg_raw_low_res = self.segnet(psv)[:, :, :, :W] # cropped to remove excess boundary 207 | seg_raw_low_res = seg_raw_low_res.view(batch_size, 1, psv_size, H, W) 208 | 209 | # Upsampling 210 | seg_prob_low_res_up = torch.sigmoid( 211 | F.interpolate( 212 | seg_raw_low_res, 213 | size=[psv_size * 3, img_left.size()[-2], img_left.size()[-1]], 214 | mode="trilinear", 215 | align_corners=False, 216 | ) 217 | ) 218 | 219 | seg_prob_low_res_up = seg_prob_low_res_up[:, 0, 1:-1, :, :] 220 | 221 | # Upsampling after 3D Regularization 222 | seg_raw_low_res_refined = seg_raw_low_res 223 | seg_raw_low_res_refined[:, :, 1:, :, :] = self.refinenet3d( 224 | features_left, seg_raw_low_res_refined[:, :, 1:, :, :] 225 | ) 226 | 227 | seg_prob_low_res_refined_up = torch.sigmoid( 228 | F.interpolate( 229 | seg_raw_low_res_refined, 230 | size=[psv_size * 3, img_left.size()[-2], 
img_left.size()[-1]], 231 | mode="trilinear", 232 | align_corners=False, 233 | ) 234 | ) 235 | 236 | seg_prob_low_res_refined_up = seg_prob_low_res_refined_up[:, 0, 1:-1, :, :] 237 | 238 | # Projection 239 | disparity_normalized_noisy = torch.mean((seg_prob_low_res_refined_up), dim=1, keepdim=True) 240 | 241 | # Refinement 242 | if self.is_refine: 243 | refine_net_input = torch.cat((disparity_normalized_noisy, img_left), dim=1) 244 | disparity_normalized = self.refinenet(refine_net_input) 245 | 246 | return ( 247 | seg_prob_low_res_up, 248 | seg_prob_low_res_refined_up, 249 | disparity_normalized_noisy, 250 | disparity_normalized, 251 | ) 252 | 253 | 254 | def bi3dnet_continuous_depth_3D(options, data=None): 255 | 256 | print("==> USING Bi3DNetContinuousDepth3D") 257 | for key in options: 258 | if "bi3dnet" in key: 259 | print("{} : {}".format(key, options[key])) 260 | 261 | model = Bi3DNetContinuousDepth3D( 262 | options, 263 | featnet_arch=options["bi3dnet_featnet_arch"], 264 | segnet_arch=options["bi3dnet_segnet_arch"], 265 | refinenet_arch=options["bi3dnet_refinenet_arch"], 266 | refinenet3d_arch=options["bi3dnet_regnet_arch"], 267 | max_disparity=options["bi3dnet_max_disparity"], 268 | ) 269 | 270 | if data is not None: 271 | model.load_state_dict(data["state_dict"]) 272 | 273 | return model 274 | 275 | 276 | """ 277 | Bi3DNet for binary depthmap generation. 278 | """ 279 | 280 | 281 | class Bi3DNetBinaryDepth(nn.Module): 282 | def __init__( 283 | self, 284 | options, 285 | featnet_arch, 286 | segnet_arch, 287 | refinenet_arch=None, 288 | featnethr_arch=None, 289 | max_disparity=192, 290 | is_disps_per_example=False, 291 | ): 292 | 293 | super(Bi3DNetBinaryDepth, self).__init__() 294 | 295 | self.max_disparity = max_disparity 296 | self.max_disparity_seg = int(max_disparity / 3) 297 | self.is_disps_per_example = is_disps_per_example 298 | 299 | self.is_refine = True 300 | if refinenet_arch == None: 301 | self.is_refine = False 302 | 303 | self.featnet = FeatNet.__dict__[featnet_arch](options, data=None) 304 | self.featnethr = FeatNet.__dict__[featnethr_arch](options, data=None) 305 | self.segnet = SegNet.__dict__[segnet_arch](options, data=None) 306 | if self.is_refine: 307 | self.refinenet = RefineNet.__dict__[refinenet_arch](options, data=None) 308 | 309 | return 310 | 311 | def forward(self, img_left, img_right, disp_ids): 312 | 313 | batch_size = img_left.shape[0] 314 | psv_size = disp_ids.shape[1] 315 | 316 | if psv_size == 1: 317 | self.is_disps_per_example = True 318 | else: 319 | self.is_disps_per_example = False 320 | 321 | # Feature Extraction 322 | features = self.featnet(torch.cat((img_left, img_right), dim=0)) 323 | 324 | features_left = features[:batch_size, :, :, :] 325 | features_right = features[batch_size:, :, :, :] 326 | 327 | if self.is_refine: 328 | features_lefthr = self.featnethr(img_left) 329 | feature_size = features_left.shape[1] 330 | H = features_left.shape[2] 331 | W = features_left.shape[3] 332 | 333 | # Cost Volume Generation 334 | psv = compute_cost_volume( 335 | features_left, features_right, disp_ids, self.max_disparity_seg, self.is_disps_per_example 336 | ) 337 | 338 | psv = psv.view(batch_size * psv_size, feature_size * 2, H, W + self.max_disparity_seg) 339 | 340 | # Segmentation Network 341 | seg_raw_low_res = self.segnet(psv)[:, :, :, :W] # cropped to remove excess boundary 342 | seg_prob_low_res = torch.sigmoid(seg_raw_low_res) 343 | seg_prob_low_res = seg_prob_low_res.view(batch_size, psv_size, H, W) 344 | 345 | seg_prob_low_res_up = 
F.interpolate( 346 | seg_prob_low_res, size=img_left.size()[-2:], mode="bilinear", align_corners=False 347 | ) 348 | out = [] 349 | out.append(seg_prob_low_res_up) 350 | 351 | # Refinement 352 | if self.is_refine: 353 | seg_raw_high_res = F.interpolate( 354 | seg_raw_low_res, size=img_left.size()[-2:], mode="bilinear", align_corners=False 355 | ) 356 | # Refine Net 357 | features_left_expand = ( 358 | features_lefthr[:, None, :, :, :].expand(-1, psv_size, -1, -1, -1).contiguous() 359 | ) 360 | features_left_expand = features_left_expand.view( 361 | -1, features_lefthr.size()[1], features_lefthr.size()[2], features_lefthr.size()[3] 362 | ) 363 | refine_net_input = torch.cat((seg_raw_high_res, features_left_expand), dim=1) 364 | 365 | seg_raw_high_res = self.refinenet(refine_net_input) 366 | 367 | seg_prob_high_res = torch.sigmoid(seg_raw_high_res) 368 | seg_prob_high_res = seg_prob_high_res.view( 369 | batch_size, psv_size, img_left.size()[-2], img_left.size()[-1] 370 | ) 371 | out.append(seg_prob_high_res) 372 | else: 373 | out.append(seg_prob_low_res_up) 374 | 375 | return out 376 | 377 | 378 | def bi3dnet_binary_depth(options, data=None): 379 | 380 | print("==> USING Bi3DNetBinaryDepth") 381 | for key in options: 382 | if "bi3dnet" in key: 383 | print("{} : {}".format(key, options[key])) 384 | 385 | model = Bi3DNetBinaryDepth( 386 | options, 387 | featnet_arch=options["bi3dnet_featnet_arch"], 388 | segnet_arch=options["bi3dnet_segnet_arch"], 389 | refinenet_arch=options["bi3dnet_refinenet_arch"], 390 | featnethr_arch=options["bi3dnet_featnethr_arch"], 391 | max_disparity=options["bi3dnet_max_disparity"], 392 | is_disps_per_example=options["bi3dnet_disps_per_example_true"], 393 | ) 394 | 395 | if data is not None: 396 | model.load_state_dict(data["state_dict"]) 397 | 398 | return model 399 | -------------------------------------------------------------------------------- /src/models/DispRefine2D.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2019 Xuanyi Li (xuanyili.edu@gmail.com) 4 | # Copyright (c) 2020 NVIDIA 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
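#
# DispRefineNet (below) consumes a 4-channel tensor: the 1-channel normalized
# disparity map concatenated with the 3-channel RGB left image that guides the
# refinement (see the "refine_net_input" concatenation in Bi3DNet.py). A stack
# of residual blocks with dilations [1, 2, 4, 8, 1, 1] predicts a residual that
# is added back onto the input disparity to produce the refined output.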
23 | 24 | import torch 25 | import torch.nn as nn 26 | import torch.nn.functional as F 27 | import math 28 | 29 | from models.PSMNet import conv2d 30 | from models.PSMNet import conv2d_lrelu 31 | 32 | """ 33 | The code in this file is adapted 34 | from https://github.com/meteorshowers/StereoNet-ActiveStereoNet 35 | """ 36 | 37 | 38 | class BasicBlock(nn.Module): 39 | 40 | expansion = 1 41 | 42 | def __init__(self, inplanes, planes, stride, downsample, pad, dilation): 43 | 44 | super(BasicBlock, self).__init__() 45 | 46 | self.conv1 = conv2d_lrelu(inplanes, planes, 3, stride, pad, dilation) 47 | self.conv2 = conv2d(planes, planes, 3, 1, pad, dilation) 48 | 49 | self.downsample = downsample 50 | self.stride = stride 51 | 52 | def forward(self, x): 53 | 54 | out = self.conv1(x) 55 | out = self.conv2(out) 56 | 57 | if self.downsample is not None: 58 | x = self.downsample(x) 59 | 60 | out += x 61 | 62 | return out 63 | 64 | 65 | class DispRefineNet(nn.Module): 66 | def __init__(self, out_planes=32): 67 | 68 | super(DispRefineNet, self).__init__() 69 | 70 | self.out_planes = out_planes 71 | 72 | self.conv2d_feature = conv2d_lrelu( 73 | in_planes=4, out_planes=self.out_planes, kernel_size=3, stride=1, pad=1, dilation=1 74 | ) 75 | 76 | self.residual_astrous_blocks = nn.ModuleList() 77 | astrous_list = [1, 2, 4, 8, 1, 1] 78 | for di in astrous_list: 79 | self.residual_astrous_blocks.append( 80 | BasicBlock(self.out_planes, self.out_planes, stride=1, downsample=None, pad=1, dilation=di) 81 | ) 82 | 83 | self.conv2d_out = nn.Conv2d(self.out_planes, 1, kernel_size=3, stride=1, padding=1) 84 | 85 | for m in self.modules(): 86 | if isinstance(m, nn.Conv2d): 87 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 88 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 89 | elif isinstance(m, nn.Conv3d): 90 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels 91 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 92 | elif isinstance(m, nn.BatchNorm2d): 93 | m.weight.data.fill_(1) 94 | m.bias.data.zero_() 95 | elif isinstance(m, nn.BatchNorm3d): 96 | m.weight.data.fill_(1) 97 | m.bias.data.zero_() 98 | elif isinstance(m, nn.Linear): 99 | m.bias.data.zero_() 100 | 101 | return 102 | 103 | def forward(self, x): 104 | 105 | disp = x[:, 0, :, :][:, None, :, :] 106 | output = self.conv2d_feature(x) 107 | 108 | for astrous_block in self.residual_astrous_blocks: 109 | output = astrous_block(output) 110 | 111 | output = self.conv2d_out(output) # residual disparity 112 | output = output + disp # final disparity 113 | 114 | return output 115 | -------------------------------------------------------------------------------- /src/models/FeatExtractNet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | from __future__ import print_function 10 | import torch 11 | import torch.nn as nn 12 | import math 13 | 14 | from models.PSMNet import conv2d 15 | from models.PSMNet import conv2d_relu 16 | from models.PSMNet import FeatExtractNetSPP 17 | 18 | __all__ = ["featextractnetspp", "featextractnethr"] 19 | 20 | 21 | """ 22 | Feature extraction network. 23 | Generates 16D features at the image resolution. 24 | Used for final refinement. 25 | """ 26 | 27 | 28 | class FeatExtractNetHR(nn.Module): 29 | def __init__(self, out_planes=16): 30 | 31 | super(FeatExtractNetHR, self).__init__() 32 | 33 | self.conv1 = nn.Sequential( 34 | conv2d_relu(3, out_planes, kernel_size=3, stride=1, pad=1, dilation=1), 35 | conv2d_relu(out_planes, out_planes, kernel_size=3, stride=1, pad=1, dilation=1), 36 | nn.Conv2d(out_planes, out_planes, kernel_size=1, padding=0, stride=1, bias=False), 37 | ) 38 | 39 | for m in self.modules(): 40 | if isinstance(m, nn.Conv2d): 41 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 42 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 43 | elif isinstance(m, nn.Conv3d): 44 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels 45 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 46 | elif isinstance(m, nn.BatchNorm2d): 47 | m.weight.data.fill_(1) 48 | m.bias.data.zero_() 49 | elif isinstance(m, nn.BatchNorm3d): 50 | m.weight.data.fill_(1) 51 | m.bias.data.zero_() 52 | elif isinstance(m, nn.Linear): 53 | m.bias.data.zero_() 54 | 55 | return 56 | 57 | def forward(self, input): 58 | 59 | output = self.conv1(input) 60 | return output 61 | 62 | 63 | def featextractnethr(options, data=None): 64 | 65 | print("==> USING FeatExtractNetHR") 66 | for key in options: 67 | if "featextractnethr" in key: 68 | print("{} : {}".format(key, options[key])) 69 | 70 | model = FeatExtractNetHR(out_planes=options["featextractnethr_out_planes"]) 71 | 72 | if data is not None: 73 | model.load_state_dict(data["state_dict"]) 74 | 75 | return model 76 | 77 | 78 | """ 79 | Feature extraction network. 80 | Generates 32D features at 3x less resolution. 81 | Uses Spatial Pyramid Pooling inspired by PSMNet. 82 | """ 83 | 84 | 85 | def featextractnetspp(options, data=None): 86 | 87 | print("==> USING FeatExtractNetSPP") 88 | for key in options: 89 | if "feat" in key: 90 | print("{} : {}".format(key, options[key])) 91 | 92 | model = FeatExtractNetSPP() 93 | 94 | if data is not None: 95 | model.load_state_dict(data["state_dict"]) 96 | 97 | return model 98 | -------------------------------------------------------------------------------- /src/models/GCNet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018 Wang Yufeng 2 | # Copyright (c) 2020 NVIDIA 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
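#
# Note on the transposed-conv helper below: deconv3d_relu sets
# p = (kernel_size - 1) // 2 and op = stride - (kernel_size - 2 * p), so for the
# kernel_size=3, stride=2 calls in feature3d the output size is
# (D_in - 1) * 2 - 2 * 1 + 3 + 1 = 2 * D_in: each up-convolution exactly doubles
# the spatial dimensions halved by the matching stride-2 encoder convolution.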
15 | 16 | import torch 17 | import torch.nn as nn 18 | 19 | """ 20 | The code in this file is adapted from https://github.com/wyf2017/DSMnet 21 | """ 22 | 23 | 24 | def conv3d_relu(in_planes, out_planes, kernel_size=3, stride=1, activefun=nn.ReLU(inplace=True)): 25 | 26 | return nn.Sequential( 27 | nn.Conv3d(in_planes, out_planes, kernel_size, stride, padding=(kernel_size - 1) // 2, bias=True), 28 | activefun, 29 | ) 30 | 31 | 32 | def deconv3d_relu(in_planes, out_planes, kernel_size=4, stride=2, activefun=nn.ReLU(inplace=True)): 33 | 34 | assert stride > 1 35 | p = (kernel_size - 1) // 2 36 | op = stride - (kernel_size - 2 * p) 37 | return nn.Sequential( 38 | nn.ConvTranspose3d( 39 | in_planes, out_planes, kernel_size, stride, padding=p, output_padding=op, bias=True 40 | ), 41 | activefun, 42 | ) 43 | 44 | 45 | """ 46 | GCNet style 3D regularization network 47 | """ 48 | 49 | 50 | class feature3d(nn.Module): 51 | def __init__(self, num_F): 52 | 53 | super(feature3d, self).__init__() 54 | self.F = num_F 55 | 56 | self.l19 = conv3d_relu(self.F + 32, self.F, kernel_size=3, stride=1) 57 | self.l20 = conv3d_relu(self.F, self.F, kernel_size=3, stride=1) 58 | 59 | self.l21 = conv3d_relu(self.F + 32, self.F * 2, kernel_size=3, stride=2) 60 | self.l22 = conv3d_relu(self.F * 2, self.F * 2, kernel_size=3, stride=1) 61 | self.l23 = conv3d_relu(self.F * 2, self.F * 2, kernel_size=3, stride=1) 62 | 63 | self.l24 = conv3d_relu(self.F * 2, self.F * 2, kernel_size=3, stride=2) 64 | self.l25 = conv3d_relu(self.F * 2, self.F * 2, kernel_size=3, stride=1) 65 | self.l26 = conv3d_relu(self.F * 2, self.F * 2, kernel_size=3, stride=1) 66 | 67 | self.l27 = conv3d_relu(self.F * 2, self.F * 2, kernel_size=3, stride=2) 68 | self.l28 = conv3d_relu(self.F * 2, self.F * 2, kernel_size=3, stride=1) 69 | self.l29 = conv3d_relu(self.F * 2, self.F * 2, kernel_size=3, stride=1) 70 | 71 | self.l30 = conv3d_relu(self.F * 2, self.F * 4, kernel_size=3, stride=2) 72 | self.l31 = conv3d_relu(self.F * 4, self.F * 4, kernel_size=3, stride=1) 73 | self.l32 = conv3d_relu(self.F * 4, self.F * 4, kernel_size=3, stride=1) 74 | 75 | self.l33 = deconv3d_relu(self.F * 4, self.F * 2, kernel_size=3, stride=2) 76 | self.l34 = deconv3d_relu(self.F * 2, self.F * 2, kernel_size=3, stride=2) 77 | self.l35 = deconv3d_relu(self.F * 2, self.F * 2, kernel_size=3, stride=2) 78 | self.l36 = deconv3d_relu(self.F * 2, self.F, kernel_size=3, stride=2) 79 | 80 | self.l37 = nn.Conv3d(self.F, 1, kernel_size=3, stride=1, padding=1, bias=True) 81 | 82 | def forward(self, x): 83 | 84 | x18 = x 85 | x21 = self.l21(x18) 86 | x24 = self.l24(x21) 87 | x27 = self.l27(x24) 88 | x30 = self.l30(x27) 89 | x31 = self.l31(x30) 90 | x32 = self.l32(x31) 91 | 92 | x29 = self.l29(self.l28(x27)) 93 | x33 = self.l33(x32) + x29 94 | 95 | x26 = self.l26(self.l25(x24)) 96 | x34 = self.l34(x33) + x26 97 | 98 | x23 = self.l23(self.l22(x21)) 99 | x35 = self.l35(x34) + x23 100 | 101 | x20 = self.l20(self.l19(x18)) 102 | x36 = self.l36(x35) + x20 103 | 104 | x37 = self.l37(x36) 105 | 106 | conf_volume_wo_sig = x37 107 | 108 | return conf_volume_wo_sig 109 | -------------------------------------------------------------------------------- /src/models/PSMNet.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Jia-Ren Chang 4 | # Copyright (c) 2020 NVIDIA 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the 
"Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | 24 | import torch 25 | import torch.nn as nn 26 | import torch.nn.functional as F 27 | import math 28 | 29 | """ 30 | The code in this file is adapted from https://github.com/JiaRenChang/PSMNet 31 | """ 32 | 33 | 34 | def conv2d(in_planes, out_planes, kernel_size, stride, pad, dilation): 35 | 36 | return nn.Sequential( 37 | nn.Conv2d( 38 | in_planes, 39 | out_planes, 40 | kernel_size=kernel_size, 41 | stride=stride, 42 | padding=dilation if dilation > 1 else pad, 43 | dilation=dilation, 44 | bias=True, 45 | ) 46 | ) 47 | 48 | 49 | def conv2d_relu(in_planes, out_planes, kernel_size, stride, pad, dilation): 50 | 51 | return nn.Sequential( 52 | nn.Conv2d( 53 | in_planes, 54 | out_planes, 55 | kernel_size=kernel_size, 56 | stride=stride, 57 | padding=dilation if dilation > 1 else pad, 58 | dilation=dilation, 59 | bias=True, 60 | ), 61 | nn.ReLU(inplace=True), 62 | ) 63 | 64 | 65 | def conv2d_lrelu(in_planes, out_planes, kernel_size, stride, pad, dilation=1): 66 | 67 | return nn.Sequential( 68 | nn.Conv2d( 69 | in_planes, 70 | out_planes, 71 | kernel_size=kernel_size, 72 | stride=stride, 73 | padding=dilation if dilation > 1 else pad, 74 | dilation=dilation, 75 | bias=True, 76 | ), 77 | nn.LeakyReLU(0.1, inplace=True), 78 | ) 79 | 80 | 81 | class BasicBlock(nn.Module): 82 | 83 | expansion = 1 84 | 85 | def __init__(self, inplanes, planes, stride, downsample, pad, dilation): 86 | 87 | super(BasicBlock, self).__init__() 88 | 89 | self.conv1 = conv2d_relu(inplanes, planes, 3, stride, pad, dilation) 90 | self.conv2 = conv2d(planes, planes, 3, 1, pad, dilation) 91 | 92 | self.downsample = downsample 93 | self.stride = stride 94 | 95 | def forward(self, x): 96 | 97 | out = self.conv1(x) 98 | out = self.conv2(out) 99 | 100 | if self.downsample is not None: 101 | x = self.downsample(x) 102 | 103 | out += x 104 | 105 | return out 106 | 107 | 108 | class FeatExtractNetSPP(nn.Module): 109 | def __init__(self): 110 | 111 | super(FeatExtractNetSPP, self).__init__() 112 | 113 | self.align_corners = False 114 | self.inplanes = 32 115 | 116 | self.firstconv = nn.Sequential( 117 | conv2d_relu(3, 32, 3, 3, 1, 1), conv2d_relu(32, 32, 3, 1, 1, 1), conv2d_relu(32, 32, 3, 1, 1, 1) 118 | ) 119 | 120 | self.layer1 = self._make_layer(BasicBlock, 32, 2, 1, 1, 2) 121 | 122 | self.branch1 = nn.Sequential(nn.AvgPool2d((64, 64), stride=(64, 64)), conv2d_relu(32, 32, 1, 1, 0, 1)) 123 | 124 | self.branch2 = nn.Sequential(nn.AvgPool2d((32, 32), stride=(32, 32)), conv2d_relu(32, 32, 1, 1, 0, 1)) 125 | 126 | self.branch3 = 
nn.Sequential(nn.AvgPool2d((16, 16), stride=(16, 16)), conv2d_relu(32, 32, 1, 1, 0, 1)) 127 | 128 | self.branch4 = nn.Sequential(nn.AvgPool2d((8, 8), stride=(8, 8)), conv2d_relu(32, 32, 1, 1, 0, 1)) 129 | 130 | self.lastconv = nn.Sequential( 131 | conv2d_relu(160, 64, 3, 1, 1, 1), 132 | nn.Conv2d(64, 32, kernel_size=1, padding=0, stride=1, bias=False), 133 | ) 134 | 135 | for m in self.modules(): 136 | if isinstance(m, nn.Conv2d): 137 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 138 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 139 | elif isinstance(m, nn.Conv3d): 140 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels 141 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 142 | elif isinstance(m, nn.BatchNorm2d): 143 | m.weight.data.fill_(1) 144 | m.bias.data.zero_() 145 | elif isinstance(m, nn.BatchNorm3d): 146 | m.weight.data.fill_(1) 147 | m.bias.data.zero_() 148 | elif isinstance(m, nn.Linear): 149 | m.bias.data.zero_() 150 | 151 | def _make_layer(self, block, planes, blocks, stride, pad, dilation): 152 | downsample = None 153 | if stride != 1 or self.inplanes != planes * block.expansion: 154 | downsample = nn.Sequential( 155 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), 156 | nn.BatchNorm2d(planes * block.expansion), 157 | ) 158 | 159 | layers = [] 160 | layers.append(block(self.inplanes, planes, stride, downsample, pad, dilation)) 161 | self.inplanes = planes * block.expansion 162 | for i in range(1, blocks): 163 | layers.append(block(self.inplanes, planes, 1, None, pad, dilation)) 164 | 165 | return nn.Sequential(*layers) 166 | 167 | def forward(self, input): 168 | 169 | output0 = self.firstconv(input) 170 | output1 = self.layer1(output0) 171 | 172 | output_branch1 = self.branch1(output1) 173 | output_branch1 = F.interpolate( 174 | output_branch1, 175 | (output1.size()[2], output1.size()[3]), 176 | mode="bilinear", 177 | align_corners=self.align_corners, 178 | ) 179 | 180 | output_branch2 = self.branch2(output1) 181 | output_branch2 = F.interpolate( 182 | output_branch2, 183 | (output1.size()[2], output1.size()[3]), 184 | mode="bilinear", 185 | align_corners=self.align_corners, 186 | ) 187 | 188 | output_branch3 = self.branch3(output1) 189 | output_branch3 = F.interpolate( 190 | output_branch3, 191 | (output1.size()[2], output1.size()[3]), 192 | mode="bilinear", 193 | align_corners=self.align_corners, 194 | ) 195 | 196 | output_branch4 = self.branch4(output1) 197 | output_branch4 = F.interpolate( 198 | output_branch4, 199 | (output1.size()[2], output1.size()[3]), 200 | mode="bilinear", 201 | align_corners=self.align_corners, 202 | ) 203 | 204 | output_feature = torch.cat( 205 | (output1, output_branch4, output_branch3, output_branch2, output_branch1), 1 206 | ) 207 | 208 | output_feature = self.lastconv(output_feature) 209 | 210 | return output_feature 211 | -------------------------------------------------------------------------------- /src/models/RefineNet2D.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
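#
# Channel bookkeeping for SegRefineNet (below): the default in_planes=17 is
# 1 upsampled segmentation logit concatenated with the 16 full-resolution image
# features produced by FeatExtractNetHR (featextractnethr_out_planes=16); the
# output is a single refined logit per pixel for each candidate disparity plane.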
8 | 9 | from __future__ import print_function 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import math 14 | import argparse 15 | import time 16 | import torch.backends.cudnn as cudnn 17 | 18 | from models.PSMNet import conv2d 19 | from models.PSMNet import conv2d_lrelu 20 | 21 | from models.DispRefine2D import DispRefineNet 22 | 23 | __all__ = ["disprefinenet", "segrefinenet"] 24 | 25 | 26 | """ 27 | Disparity refinement network. 28 | Takes concatenated input image and the disparity map to generate refined disparity map. 29 | Generates refined output using input image as guide. 30 | """ 31 | 32 | 33 | def disprefinenet(options, data=None): 34 | 35 | print("==> USING DispRefineNet") 36 | for key in options: 37 | if "disprefinenet" in key: 38 | print("{} : {}".format(key, options[key])) 39 | 40 | model = DispRefineNet(out_planes=options["disprefinenet_out_planes"]) 41 | 42 | if data is not None: 43 | model.load_state_dict(data["state_dict"]) 44 | 45 | return model 46 | 47 | 48 | """ 49 | Binary segmentation refinement network. 50 | Takes as input high resolution features of input image and the disparity map. 51 | Generates refined output using input image as guide. 52 | """ 53 | 54 | 55 | class SegRefineNet(nn.Module): 56 | def __init__(self, in_planes=17, out_planes=8): 57 | 58 | super(SegRefineNet, self).__init__() 59 | 60 | self.conv1 = nn.Sequential(conv2d_lrelu(in_planes, out_planes, kernel_size=3, stride=1, pad=1)) 61 | 62 | self.classif1 = nn.Conv2d(out_planes, 1, kernel_size=3, padding=1, stride=1, bias=False) 63 | 64 | for m in self.modules(): 65 | if isinstance(m, nn.Conv2d): 66 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 67 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 68 | elif isinstance(m, nn.Conv3d): 69 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels 70 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 71 | elif isinstance(m, nn.BatchNorm2d): 72 | m.weight.data.fill_(1) 73 | m.bias.data.zero_() 74 | elif isinstance(m, nn.BatchNorm3d): 75 | m.weight.data.fill_(1) 76 | m.bias.data.zero_() 77 | elif isinstance(m, nn.Linear): 78 | m.bias.data.zero_() 79 | 80 | def forward(self, input): 81 | 82 | output0 = self.conv1(input) 83 | output = self.classif1(output0) 84 | 85 | return output 86 | 87 | 88 | def segrefinenet(options, data=None): 89 | 90 | print("==> USING SegRefineNet") 91 | for key in options: 92 | if "segrefinenet" in key: 93 | print("{} : {}".format(key, options[key])) 94 | 95 | model = SegRefineNet( 96 | in_planes=options["segrefinenet_in_planes"], out_planes=options["segrefinenet_out_planes"] 97 | ) 98 | 99 | if data is not None: 100 | model.load_state_dict(data["state_dict"]) 101 | 102 | return model 103 | -------------------------------------------------------------------------------- /src/models/RefineNet3D.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
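#
# SegRegNet3D (below) takes the 1/3-resolution left-image features fL (32
# channels) and the raw confidence volume of shape (B, 1, D, H, W). The volume
# is lifted to F=16 channels by conf_preprocess, fL is tiled across the D
# candidate planes, and the concatenated (F + 32)-channel volume is regularized
# by the GCNet-style 3D hourglass in feature3d, which returns a single-channel
# refined volume (the sigmoid is applied later in Bi3DNet.py).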
8 | 9 | import torch 10 | import torch.nn as nn 11 | import numpy as np 12 | 13 | __all__ = ["segregnet3d"] 14 | 15 | from models.GCNet import conv3d_relu 16 | from models.GCNet import deconv3d_relu 17 | from models.GCNet import feature3d 18 | 19 | 20 | def net_init(net): 21 | 22 | for m in net.modules(): 23 | if isinstance(m, nn.Linear): 24 | m.weight.data = fanin_init(m.weight.data.size()) 25 | elif isinstance(m, nn.Conv3d): 26 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels 27 | m.weight.data.normal_(0, np.sqrt(2.0 / n)) 28 | elif isinstance(m, nn.Conv2d): 29 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 30 | m.weight.data.normal_(0, np.sqrt(2.0 / n)) 31 | elif isinstance(m, nn.Conv1d): 32 | n = m.kernel_size[0] * m.out_channels 33 | m.weight.data.normal_(0, np.sqrt(2.0 / n)) 34 | elif isinstance(m, nn.BatchNorm3d): 35 | m.weight.data.fill_(1) 36 | m.bias.data.zero_() 37 | elif isinstance(m, nn.BatchNorm2d): 38 | m.weight.data.fill_(1) 39 | m.bias.data.zero_() 40 | elif isinstance(m, nn.BatchNorm1d): 41 | m.weight.data.fill_(1) 42 | m.bias.data.zero_() 43 | 44 | 45 | class SegRegNet3D(nn.Module): 46 | def __init__(self, F=16): 47 | 48 | super(SegRegNet3D, self).__init__() 49 | 50 | self.conf_preprocess = conv3d_relu(1, F, kernel_size=3, stride=1) 51 | self.layer3d = feature3d(F) 52 | 53 | net_init(self) 54 | 55 | def forward(self, fL, conf_volume): 56 | 57 | fL_stack = fL[:, :, None, :, :].repeat(1, 1, int(conf_volume.shape[2]), 1, 1) 58 | conf_vol_preprocess = self.conf_preprocess(conf_volume) 59 | input_volume = torch.cat((fL_stack, conf_vol_preprocess), dim=1) 60 | oL = self.layer3d(input_volume) 61 | 62 | return oL 63 | 64 | 65 | def segregnet3d(options, data=None): 66 | 67 | print("==> USING SegRegNet3D") 68 | for key in options: 69 | if "regnet" in key: 70 | print("{} : {}".format(key, options[key])) 71 | 72 | model = SegRegNet3D(F=options["regnet_out_planes"]) 73 | if data is not None: 74 | model.load_state_dict(data["state_dict"]) 75 | 76 | return model 77 | -------------------------------------------------------------------------------- /src/models/SegNet2D.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
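#
# SegNet2D (below) is a 2D encoder-decoder with five stride-2 stages and skip
# connections (channel plan cps/dps). Its 64-channel input is one slice of the
# cost volume: the 32-channel left features concatenated with the 32-channel
# right features shifted by the candidate disparity (see compute_cost_volume in
# Bi3DNet.py). It outputs one raw segmentation logit per pixel, which Bi3DNet
# turns into a "closer/farther than this plane" probability via a sigmoid.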
8 | 9 | import torch 10 | import torch.nn as nn 11 | import argparse 12 | import math 13 | import torch.nn.functional as F 14 | import torch.backends.cudnn as cudnn 15 | import time 16 | 17 | __all__ = ["segnet2d"] 18 | 19 | # Util Functions 20 | def conv(in_planes, out_planes, kernel_size=3, stride=1, activefun=nn.LeakyReLU(0.1, inplace=True)): 21 | 22 | return nn.Sequential( 23 | nn.Conv2d( 24 | in_planes, 25 | out_planes, 26 | kernel_size=kernel_size, 27 | stride=stride, 28 | padding=(kernel_size - 1) // 2, 29 | bias=True, 30 | ), 31 | activefun, 32 | ) 33 | 34 | 35 | def deconv(in_planes, out_planes, kernel_size=4, stride=2, activefun=nn.LeakyReLU(0.1, inplace=True)): 36 | 37 | return nn.Sequential( 38 | nn.ConvTranspose2d( 39 | in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=1, bias=True 40 | ), 41 | activefun, 42 | ) 43 | 44 | 45 | class SegNet2D(nn.Module): 46 | def __init__(self): 47 | 48 | super(SegNet2D, self).__init__() 49 | 50 | self.activefun = nn.LeakyReLU(0.1, inplace=True) 51 | 52 | cps = [64, 128, 256, 512, 512, 512] 53 | dps = [512, 512, 256, 128, 64] 54 | 55 | # Encoder 56 | self.conv1 = conv(cps[0], cps[1], kernel_size=3, stride=2, activefun=self.activefun) 57 | self.conv1_1 = conv(cps[1], cps[1], kernel_size=3, stride=1, activefun=self.activefun) 58 | 59 | self.conv2 = conv(cps[1], cps[2], kernel_size=3, stride=2, activefun=self.activefun) 60 | self.conv2_1 = conv(cps[2], cps[2], kernel_size=3, stride=1, activefun=self.activefun) 61 | 62 | self.conv3 = conv(cps[2], cps[3], kernel_size=3, stride=2, activefun=self.activefun) 63 | self.conv3_1 = conv(cps[3], cps[3], kernel_size=3, stride=1, activefun=self.activefun) 64 | 65 | self.conv4 = conv(cps[3], cps[4], kernel_size=3, stride=2, activefun=self.activefun) 66 | self.conv4_1 = conv(cps[4], cps[4], kernel_size=3, stride=1, activefun=self.activefun) 67 | 68 | self.conv5 = conv(cps[4], cps[5], kernel_size=3, stride=2, activefun=self.activefun) 69 | self.conv5_1 = conv(cps[5], cps[5], kernel_size=3, stride=1, activefun=self.activefun) 70 | 71 | # Decoder 72 | self.deconv5 = deconv(cps[5], dps[0], kernel_size=4, stride=2, activefun=self.activefun) 73 | self.deconv5_1 = conv(dps[0] + cps[4], dps[0], kernel_size=3, stride=1, activefun=self.activefun) 74 | 75 | self.deconv4 = deconv(cps[4], dps[1], kernel_size=4, stride=2, activefun=self.activefun) 76 | self.deconv4_1 = conv(dps[1] + cps[3], dps[1], kernel_size=3, stride=1, activefun=self.activefun) 77 | 78 | self.deconv3 = deconv(dps[1], dps[2], kernel_size=4, stride=2, activefun=self.activefun) 79 | self.deconv3_1 = conv(dps[2] + cps[2], dps[2], kernel_size=3, stride=1, activefun=self.activefun) 80 | 81 | self.deconv2 = deconv(dps[2], dps[3], kernel_size=4, stride=2, activefun=self.activefun) 82 | self.deconv2_1 = conv(dps[3] + cps[1], dps[3], kernel_size=3, stride=1, activefun=self.activefun) 83 | 84 | self.deconv1 = deconv(dps[3], dps[4], kernel_size=4, stride=2, activefun=self.activefun) 85 | self.deconv1_1 = conv(dps[4] + cps[0], dps[4], kernel_size=3, stride=1, activefun=self.activefun) 86 | 87 | self.last_conv = nn.Conv2d(dps[4], 1, kernel_size=3, stride=1, padding=1, bias=True) 88 | 89 | # Init 90 | for m in self.modules(): 91 | if isinstance(m, nn.Conv2d): 92 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 93 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 94 | elif isinstance(m, nn.Conv3d): 95 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels 96 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 97 | elif 
isinstance(m, nn.BatchNorm2d): 98 | m.weight.data.fill_(1) 99 | m.bias.data.zero_() 100 | elif isinstance(m, nn.BatchNorm3d): 101 | m.weight.data.fill_(1) 102 | m.bias.data.zero_() 103 | elif isinstance(m, nn.Linear): 104 | m.bias.data.zero_() 105 | 106 | return 107 | 108 | def forward(self, x): 109 | 110 | out_conv0 = x 111 | out_conv1 = self.conv1_1(self.conv1(out_conv0)) 112 | out_conv2 = self.conv2_1(self.conv2(out_conv1)) 113 | out_conv3 = self.conv3_1(self.conv3(out_conv2)) 114 | out_conv4 = self.conv4_1(self.conv4(out_conv3)) 115 | out_conv5 = self.conv5_1(self.conv5(out_conv4)) 116 | 117 | out_deconv5 = self.deconv5(out_conv5) 118 | out_deconv5_1 = self.deconv5_1(torch.cat((out_conv4, out_deconv5), 1)) 119 | 120 | out_deconv4 = self.deconv4(out_deconv5_1) 121 | out_deconv4_1 = self.deconv4_1(torch.cat((out_conv3, out_deconv4), 1)) 122 | 123 | out_deconv3 = self.deconv3(out_deconv4_1) 124 | out_deconv3_1 = self.deconv3_1(torch.cat((out_conv2, out_deconv3), 1)) 125 | 126 | out_deconv2 = self.deconv2(out_deconv3_1) 127 | out_deconv2_1 = self.deconv2_1(torch.cat((out_conv1, out_deconv2), 1)) 128 | 129 | out_deconv1 = self.deconv1(out_deconv2_1) 130 | out_deconv1_1 = self.deconv1_1(torch.cat((out_conv0, out_deconv1), 1)) 131 | 132 | raw_seg = self.last_conv(out_deconv1_1) 133 | 134 | return raw_seg 135 | 136 | 137 | def segnet2d(options, data=None): 138 | 139 | print("==> USING SegNet2D") 140 | for key in options: 141 | if "segnet2d" in key: 142 | print("{} : {}".format(key, options[key])) 143 | 144 | model = SegNet2D() 145 | 146 | if data is not None: 147 | model.load_state_dict(data["state_dict"]) 148 | 149 | return model 150 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .Bi3DNet import * 2 | from .FeatExtractNet import * 3 | from .SegNet2D import * 4 | from .RefineNet2D import * 5 | from .RefineNet3D import * 6 | from .PSMNet import * 7 | from .GCNet import * 8 | from .DispRefine2D import * 9 | 10 | -------------------------------------------------------------------------------- /src/project.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 110 3 | target-version = ['py37'] -------------------------------------------------------------------------------- /src/run_binary_depth_estimation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
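#
# Example invocation (illustrative; the run_demo_*.sh scripts are not shown in
# this excerpt, so the checkpoint path, crop size, and disparity values below
# are placeholders rather than the shipped settings):
#
#   python run_binary_depth_estimation.py \
#       --pretrained /path/to/bi3dnet_binary_depth.pth.tar \
#       --img_left  ../data/kitti15_img_left.jpg \
#       --img_right ../data/kitti15_img_right.jpg \
#       --disp_vals 24 36 54 96 144 \
#       --crop_height 384 --crop_width 1248
#
# The script asserts that every --disp_vals entry is a multiple of 3 (features
# are computed at 1/3 resolution) and that --crop_height and --crop_width are
# multiples of 96.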
8 | 9 | import argparse 10 | import os 11 | import torch 12 | import torchvision.transforms as transforms 13 | from PIL import Image 14 | 15 | import models 16 | import cv2 17 | import numpy as np 18 | 19 | from util import disp2rgb, str2bool 20 | import random 21 | 22 | model_names = sorted(name for name in models.__dict__ if name.islower() and not name.startswith("__")) 23 | 24 | 25 | # Parse arguments 26 | parser = argparse.ArgumentParser(allow_abbrev=False) 27 | 28 | # Model 29 | parser.add_argument("--arch", type=str, default="bi3dnet_binary_depth") 30 | 31 | parser.add_argument("--bi3dnet_featnet_arch", type=str, default="featextractnetspp") 32 | parser.add_argument("--bi3dnet_featnethr_arch", type=str, default="featextractnethr") 33 | parser.add_argument("--bi3dnet_segnet_arch", type=str, default="segnet2d") 34 | parser.add_argument("--bi3dnet_refinenet_arch", type=str, default="segrefinenet") 35 | parser.add_argument("--bi3dnet_max_disparity", type=int, default=192) 36 | parser.add_argument("--bi3dnet_disps_per_example_true", type=str2bool, default=True) 37 | 38 | parser.add_argument("--featextractnethr_out_planes", type=int, default=16) 39 | parser.add_argument("--segrefinenet_in_planes", type=int, default=17) 40 | parser.add_argument("--segrefinenet_out_planes", type=int, default=8) 41 | 42 | # Input 43 | parser.add_argument("--pretrained", type=str) 44 | parser.add_argument("--img_left", type=str) 45 | parser.add_argument("--img_right", type=str) 46 | parser.add_argument("--disp_vals", type=float, nargs="*") 47 | parser.add_argument("--crop_height", type=int) 48 | parser.add_argument("--crop_width", type=int) 49 | 50 | args, unknown = parser.parse_known_args() 51 | 52 | #################################################################################################### 53 | def main(): 54 | 55 | options = vars(args) 56 | print("==> ALL PARAMETERS") 57 | for key in options: 58 | print("{} : {}".format(key, options[key])) 59 | 60 | out_dir = "out" 61 | if not os.path.isdir(out_dir): 62 | os.mkdir(out_dir) 63 | 64 | base_name = os.path.splitext(os.path.basename(args.img_left))[0] 65 | 66 | # Model 67 | network_data = torch.load(args.pretrained) 68 | print("=> using pre-trained model '{}'".format(args.arch)) 69 | model = models.__dict__[args.arch](options, network_data).cuda() 70 | 71 | # Inputs 72 | img_left = Image.open(args.img_left).convert("RGB") 73 | img_left = transforms.functional.to_tensor(img_left) 74 | img_left = transforms.functional.normalize(img_left, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 75 | img_left = img_left.type(torch.cuda.FloatTensor)[None, :, :, :] 76 | img_right = Image.open(args.img_right).convert("RGB") 77 | img_right = transforms.functional.to_tensor(img_right) 78 | img_right = transforms.functional.normalize(img_right, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 79 | img_right = img_right.type(torch.cuda.FloatTensor)[None, :, :, :] 80 | 81 | segs = [] 82 | for disp_val in args.disp_vals: 83 | 84 | assert disp_val % 3 == 0, "disparity value should be a multiple of 3 as we downsample the image by 3" 85 | disp_long = torch.Tensor([[disp_val / 3]]).type(torch.LongTensor).cuda() 86 | 87 | # Pad inputs 88 | tw = args.crop_width 89 | th = args.crop_height 90 | assert tw % 96 == 0, "image dimensions should be a multiple of 96" 91 | assert th % 96 == 0, "image dimensions should be a multiple of 96" 92 | h = img_left.shape[2] 93 | w = img_left.shape[3] 94 | x1 = random.randint(0, max(0, w - tw)) 95 | y1 = random.randint(0, max(0, h - th)) 96 | pad_w = tw - w if tw - w > 0 else 0 
97 |         pad_h = th - h if th - h > 0 else 0
98 |         pad_opr = torch.nn.ZeroPad2d((pad_w, 0, pad_h, 0))
99 |         img_left = img_left[:, :, y1 : y1 + min(th, h), x1 : x1 + min(tw, w)]
100 |         img_right = img_right[:, :, y1 : y1 + min(th, h), x1 : x1 + min(tw, w)]
101 |         img_left_pad = pad_opr(img_left)
102 |         img_right_pad = pad_opr(img_right)
103 | 
104 |         # Inference
105 |         model.eval()
106 |         with torch.no_grad():
107 |             output = model(img_left_pad, img_right_pad, disp_long)[1][:, :, pad_h:, pad_w:]
108 | 
109 |         # Write binary depth results
110 |         seg_img = output[0, 0][None, :, :].clone().cpu().detach().numpy()
111 |         seg_img = np.transpose(seg_img * 255.0, (1, 2, 0))
112 |         cv2.imwrite(
113 |             os.path.join(out_dir, "%s_%s_seg_confidence_%d.png" % (base_name, args.arch, disp_val)), seg_img
114 |         )
115 | 
116 |         segs.append(output[0, 0][None, :, :].clone().cpu().detach().numpy())
117 | 
118 |     # Generate quantized depth results
119 |     segs = np.concatenate(segs, axis=0)
120 |     segs = np.insert(segs, 0, np.ones((1, h, w), dtype=np.float32), axis=0)
121 |     segs = np.append(segs, np.zeros((1, h, w), dtype=np.float32), axis=0)
122 | 
123 |     segs = 1.0 - segs
124 | 
125 |     # Get the pdf values for each segmented region
126 |     pdf_method = segs[1:, :, :] - segs[:-1, :, :]
127 | 
128 |     # Get the labels (index of the most likely disparity bin for each pixel)
129 |     labels_method = np.argmax(pdf_method, axis=0).astype(np.int64)
130 |     disp_map = labels_method.astype(np.float32)
131 | 
132 |     disp_vals = args.disp_vals
133 |     disp_vals.insert(0, 0)
134 |     disp_vals.append(args.bi3dnet_max_disparity)
135 | 
136 |     for i in range(len(disp_vals) - 1):
137 |         min_disp = disp_vals[i]
138 |         max_disp = disp_vals[i + 1]
139 |         mid_disp = 0.5 * (min_disp + max_disp)
140 |         disp_map[labels_method == i] = mid_disp
141 | 
142 |     disp_vals_str_list = ["%d" % disp_val for disp_val in disp_vals]
143 |     disp_vals_str = "-".join(disp_vals_str_list)
144 | 
145 |     img_disp = np.clip(disp_map, 0, args.bi3dnet_max_disparity)
146 |     img_disp = img_disp / args.bi3dnet_max_disparity
147 |     img_disp = (disp2rgb(img_disp) * 255.0).astype(np.uint8)
148 | 
149 |     cv2.imwrite(
150 |         os.path.join(out_dir, "%s_%s_quant_depth_%s.png" % (base_name, args.arch, disp_vals_str)), img_disp
151 |     )
152 | 
153 |     return
154 | 
155 | 
156 | if __name__ == "__main__":
157 |     main()
158 | 
--------------------------------------------------------------------------------
/src/run_continuous_depth_estimation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
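#
# This script runs Bi3D continuous depth estimation on a rectified stereo pair
# over the disparity range [--disp_range_min, --disp_range_max] (both must be
# divisible by 3). With --arch bi3dnet_continuous_depth_2D the disparity map is
# produced with 2D refinement only; with bi3dnet_continuous_depth_3D an extra 3D
# regularization network (segregnet3d) refines the confidence volume, which
# requires the range width to be divisible by 48 (see the asserts below). The
# clamped disparity map is colorized with util.disp2rgb and written to out/.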
8 | 9 | import argparse 10 | import os 11 | import time 12 | import torch 13 | import torchvision.transforms as transforms 14 | from PIL import Image 15 | 16 | import models 17 | import cv2 18 | import numpy as np 19 | from util import disp2rgb, str2bool 20 | 21 | import random 22 | 23 | model_names = sorted(name for name in models.__dict__ if name.islower() and not name.startswith("__")) 24 | 25 | 26 | # Parse Arguments 27 | parser = argparse.ArgumentParser(allow_abbrev=False) 28 | 29 | # Experiment Type 30 | parser.add_argument("--arch", type=str, default="bi3dnet_continuous_depth_2D") 31 | 32 | parser.add_argument("--bi3dnet_featnet_arch", type=str, default="featextractnetspp") 33 | parser.add_argument("--bi3dnet_segnet_arch", type=str, default="segnet2d") 34 | parser.add_argument("--bi3dnet_refinenet_arch", type=str, default="disprefinenet") 35 | parser.add_argument("--bi3dnet_regnet_arch", type=str, default="segregnet3d") 36 | parser.add_argument("--bi3dnet_max_disparity", type=int, default=192) 37 | parser.add_argument("--regnet_out_planes", type=int, default=16) 38 | parser.add_argument("--disprefinenet_out_planes", type=int, default=32) 39 | parser.add_argument("--bi3dnet_disps_per_example_true", type=str2bool, default=True) 40 | 41 | # Input 42 | parser.add_argument("--pretrained", type=str) 43 | parser.add_argument("--img_left", type=str) 44 | parser.add_argument("--img_right", type=str) 45 | parser.add_argument("--disp_range_min", type=int) 46 | parser.add_argument("--disp_range_max", type=int) 47 | parser.add_argument("--crop_height", type=int) 48 | parser.add_argument("--crop_width", type=int) 49 | 50 | args, unknown = parser.parse_known_args() 51 | 52 | ############################################################################################################## 53 | def main(): 54 | 55 | options = vars(args) 56 | print("==> ALL PARAMETERS") 57 | for key in options: 58 | print("{} : {}".format(key, options[key])) 59 | 60 | out_dir = "out" 61 | if not os.path.isdir(out_dir): 62 | os.mkdir(out_dir) 63 | 64 | base_name = os.path.splitext(os.path.basename(args.img_left))[0] 65 | 66 | # Model 67 | if args.pretrained: 68 | network_data = torch.load(args.pretrained) 69 | else: 70 | print("Need an input model") 71 | exit() 72 | 73 | print("=> using pre-trained model '{}'".format(args.arch)) 74 | model = models.__dict__[args.arch](options, network_data).cuda() 75 | 76 | # Inputs 77 | img_left = Image.open(args.img_left).convert("RGB") 78 | img_right = Image.open(args.img_right).convert("RGB") 79 | img_left = transforms.functional.to_tensor(img_left) 80 | img_right = transforms.functional.to_tensor(img_right) 81 | img_left = transforms.functional.normalize(img_left, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 82 | img_right = transforms.functional.normalize(img_right, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 83 | img_left = img_left.type(torch.cuda.FloatTensor)[None, :, :, :] 84 | img_right = img_right.type(torch.cuda.FloatTensor)[None, :, :, :] 85 | 86 | # Prepare Disparities 87 | max_disparity = args.disp_range_max 88 | min_disparity = args.disp_range_min 89 | 90 | assert max_disparity % 3 == 0 and min_disparity % 3 == 0, "disparities should be divisible by 3" 91 | 92 | if args.arch == "bi3dnet_continuous_depth_3D": 93 | assert ( 94 | max_disparity - min_disparity 95 | ) % 48 == 0, "for 3D regularization the difference in disparities should be divisible by 48" 96 | 97 | max_disp_levels = (max_disparity - min_disparity) + 1 98 | 99 | max_disparity_3x = int(max_disparity / 3) 100 | min_disparity_3x = 
int(min_disparity / 3) 101 | max_disp_levels_3x = (max_disparity_3x - min_disparity_3x) + 1 102 | disp_3x = np.linspace(min_disparity_3x, max_disparity_3x, max_disp_levels_3x, dtype=np.int32) 103 | disp_long_3x_main = torch.from_numpy(disp_3x).type(torch.LongTensor).cuda() 104 | disp_float_main = np.linspace(min_disparity, max_disparity, max_disp_levels, dtype=np.float32) 105 | disp_float_main = torch.from_numpy(disp_float_main).type(torch.float32).cuda() 106 | delta = 1 107 | d_min_GT = min_disparity - 0.5 * delta 108 | d_max_GT = max_disparity + 0.5 * delta 109 | disp_long_3x = disp_long_3x_main[None, :].expand(img_left.shape[0], -1) 110 | disp_float = disp_float_main[None, :].expand(img_left.shape[0], -1) 111 | 112 | # Pad Inputs 113 | tw = args.crop_width 114 | th = args.crop_height 115 | assert tw % 96 == 0, "image dimensions should be multiple of 96" 116 | assert th % 96 == 0, "image dimensions should be multiple of 96" 117 | h = img_left.shape[2] 118 | w = img_left.shape[3] 119 | x1 = random.randint(0, max(0, w - tw)) 120 | y1 = random.randint(0, max(0, h - th)) 121 | pad_w = tw - w if tw - w > 0 else 0 122 | pad_h = th - h if th - h > 0 else 0 123 | pad_opr = torch.nn.ZeroPad2d((pad_w, 0, pad_h, 0)) 124 | img_left = img_left[:, :, y1 : y1 + min(th, h), x1 : x1 + min(tw, w)] 125 | img_right = img_right[:, :, y1 : y1 + min(th, h), x1 : x1 + min(tw, w)] 126 | img_left_pad = pad_opr(img_left) 127 | img_right_pad = pad_opr(img_right) 128 | 129 | # Inference 130 | model.eval() 131 | with torch.no_grad(): 132 | if args.arch == "bi3dnet_continuous_depth_2D": 133 | output_seg_low_res_upsample, output_disp_normalized = model( 134 | img_left_pad, img_right_pad, disp_long_3x 135 | ) 136 | output_seg = output_seg_low_res_upsample 137 | else: 138 | ( 139 | output_seg_low_res_upsample, 140 | output_seg_low_res_upsample_refined, 141 | output_disp_normalized_no_reg, 142 | output_disp_normalized, 143 | ) = model(img_left_pad, img_right_pad, disp_long_3x) 144 | output_seg = output_seg_low_res_upsample_refined 145 | 146 | output_seg = output_seg[:, :, pad_h:, pad_w:] 147 | output_disp_normalized = output_disp_normalized[:, :, pad_h:, pad_w:] 148 | output_disp = torch.clamp( 149 | output_disp_normalized * delta * max_disp_levels + d_min_GT, min=d_min_GT, max=d_max_GT 150 | ) 151 | 152 | # Write Results 153 | max_disparity_color = 192 154 | output_disp_clamp = output_disp[0, 0, :, :].cpu().clone().numpy() 155 | output_disp_clamp[output_disp_clamp < min_disparity] = 0 156 | output_disp_clamp[output_disp_clamp > max_disparity] = max_disparity_color 157 | disp_np_ours_color = disp2rgb(output_disp_clamp / max_disparity_color) * 255.0 158 | cv2.imwrite( 159 | os.path.join(out_dir, "%s_%s_%d_%d.png" % (base_name, args.arch, min_disparity, max_disparity)), 160 | disp_np_ours_color, 161 | ) 162 | 163 | return 164 | 165 | 166 | if __name__ == "__main__": 167 | main() 168 | -------------------------------------------------------------------------------- /src/run_demo_kitti15.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # GENERATE BINARY DEPTH SEGMENTATIONS AND COMBINE THEM TO GENERATE QUANTIZED DEPTH 4 | CUDA_VISIBLE_DEVICES=0 python run_binary_depth_estimation.py \ 5 | --arch bi3dnet_binary_depth \ 6 | --bi3dnet_featnet_arch featextractnetspp \ 7 | --bi3dnet_featnethr_arch featextractnethr \ 8 | --bi3dnet_segnet_arch segnet2d \ 9 | --bi3dnet_refinenet_arch segrefinenet \ 10 | --featextractnethr_out_planes 16 \ 11 | --segrefinenet_in_planes 
17 \
12 | --segrefinenet_out_planes 8 \
13 | --crop_height 384 --crop_width 1248 \
14 | --disp_vals 12 21 30 39 48 \
15 | --img_left '../data/kitti15_img_left.jpg' \
16 | --img_right '../data/kitti15_img_right.jpg' \
17 | --pretrained '../model_weights/kitti15_binary_depth.pth.tar'
18 | 
19 | 
20 | # FULL RANGE CONTINUOUS DEPTH ESTIMATION WITHOUT 3D REGULARIZATION
21 | CUDA_VISIBLE_DEVICES=0 python run_continuous_depth_estimation.py \
22 | --arch bi3dnet_continuous_depth_2D \
23 | --bi3dnet_featnet_arch featextractnetspp \
24 | --bi3dnet_segnet_arch segnet2d \
25 | --bi3dnet_refinenet_arch disprefinenet \
26 | --disprefinenet_out_planes 32 \
27 | --crop_height 384 --crop_width 1248 \
28 | --disp_range_min 0 \
29 | --disp_range_max 192 \
30 | --bi3dnet_max_disparity 192 \
31 | --img_left '../data/kitti15_img_left.jpg' \
32 | --img_right '../data/kitti15_img_right.jpg' \
33 | --pretrained '../model_weights/kitti15_continuous_depth_no_conf_reg.pth.tar'
34 | 
35 | 
36 | # SELECTIVE RANGE CONTINUOUS DEPTH ESTIMATION WITHOUT 3D REGULARIZATION
37 | CUDA_VISIBLE_DEVICES=0 python run_continuous_depth_estimation.py \
38 | --arch bi3dnet_continuous_depth_2D \
39 | --bi3dnet_featnet_arch featextractnetspp \
40 | --bi3dnet_segnet_arch segnet2d \
41 | --bi3dnet_refinenet_arch disprefinenet \
42 | --disprefinenet_out_planes 32 \
43 | --crop_height 384 --crop_width 1248 \
44 | --disp_range_min 12 \
45 | --disp_range_max 48 \
46 | --bi3dnet_max_disparity 192 \
47 | --img_left '../data/kitti15_img_left.jpg' \
48 | --img_right '../data/kitti15_img_right.jpg' \
49 | --pretrained '../model_weights/kitti15_continuous_depth_no_conf_reg.pth.tar'
50 | 
51 | 
52 | # FULL RANGE CONTINUOUS DEPTH ESTIMATION WITH 3D REGULARIZATION
53 | CUDA_VISIBLE_DEVICES=0 python run_continuous_depth_estimation.py \
54 | --arch bi3dnet_continuous_depth_3D \
55 | --bi3dnet_featnet_arch featextractnetspp \
56 | --bi3dnet_segnet_arch segnet2d \
57 | --bi3dnet_refinenet_arch disprefinenet \
58 | --bi3dnet_regnet_arch segregnet3d \
59 | --disprefinenet_out_planes 32 \
60 | --regnet_out_planes 16 \
61 | --crop_height 384 --crop_width 1248 \
62 | --disp_range_min 0 \
63 | --disp_range_max 192 \
64 | --bi3dnet_max_disparity 192 \
65 | --img_left '../data/kitti15_img_left.jpg' \
66 | --img_right '../data/kitti15_img_right.jpg' \
67 | --pretrained '../model_weights/kitti15_continuous_depth_conf_reg.pth.tar'
68 | 
--------------------------------------------------------------------------------
/src/run_demo_sf.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | # GENERATE BINARY DEPTH SEGMENTATIONS AND COMBINE THEM TO GENERATE QUANTIZED DEPTH
4 | CUDA_VISIBLE_DEVICES=0 python run_binary_depth_estimation.py \
5 | --arch bi3dnet_binary_depth \
6 | --bi3dnet_featnet_arch featextractnetspp \
7 | --bi3dnet_featnethr_arch featextractnethr \
8 | --bi3dnet_segnet_arch segnet2d \
9 | --bi3dnet_refinenet_arch segrefinenet \
10 | --featextractnethr_out_planes 16 \
11 | --segrefinenet_in_planes 17 \
12 | --segrefinenet_out_planes 8 \
13 | --crop_height 576 --crop_width 960 \
14 | --disp_vals 24 36 54 96 144 \
15 | --img_left '../data/sf_img_left.jpg' \
16 | --img_right '../data/sf_img_right.jpg' \
17 | --pretrained '../model_weights/sf_binary_depth.pth.tar'
18 | 
19 | 
20 | # FULL RANGE CONTINUOUS DEPTH ESTIMATION WITHOUT 3D REGULARIZATION
21 | CUDA_VISIBLE_DEVICES=0 python run_continuous_depth_estimation.py \
22 | --arch bi3dnet_continuous_depth_2D \
23 | --bi3dnet_featnet_arch featextractnetspp \
24 | --bi3dnet_segnet_arch segnet2d \
25 | --bi3dnet_refinenet_arch disprefinenet \
26 | --disprefinenet_out_planes 32 \
27 | --crop_height 576 --crop_width 960 \
28 | --disp_range_min 0 \
29 | --disp_range_max 192 \
30 | --bi3dnet_max_disparity 192 \
31 | --img_left '../data/sf_img_left.jpg' \
32 | --img_right '../data/sf_img_right.jpg' \
33 | --pretrained '../model_weights/sf_continuous_depth_no_conf_reg.pth.tar'
34 | 
35 | 
36 | # SELECTIVE RANGE CONTINUOUS DEPTH ESTIMATION WITHOUT 3D REGULARIZATION
37 | CUDA_VISIBLE_DEVICES=0 python run_continuous_depth_estimation.py \
38 | --arch bi3dnet_continuous_depth_2D \
39 | --bi3dnet_featnet_arch featextractnetspp \
40 | --bi3dnet_segnet_arch segnet2d \
41 | --bi3dnet_refinenet_arch disprefinenet \
42 | --disprefinenet_out_planes 32 \
43 | --crop_height 576 --crop_width 960 \
44 | --disp_range_min 18 \
45 | --disp_range_max 60 \
46 | --bi3dnet_max_disparity 192 \
47 | --img_left '../data/sf_img_left.jpg' \
48 | --img_right '../data/sf_img_right.jpg' \
49 | --pretrained '../model_weights/sf_continuous_depth_no_conf_reg.pth.tar'
50 | 
51 | 
52 | # FULL RANGE CONTINUOUS DEPTH ESTIMATION WITH 3D REGULARIZATION
53 | CUDA_VISIBLE_DEVICES=0 python run_continuous_depth_estimation.py \
54 | --arch bi3dnet_continuous_depth_3D \
55 | --bi3dnet_featnet_arch featextractnetspp \
56 | --bi3dnet_segnet_arch segnet2d \
57 | --bi3dnet_refinenet_arch disprefinenet \
58 | --bi3dnet_regnet_arch segregnet3d \
59 | --disprefinenet_out_planes 32 \
60 | --regnet_out_planes 16 \
61 | --crop_height 576 --crop_width 960 \
62 | --disp_range_min 0 \
63 | --disp_range_max 192 \
64 | --bi3dnet_max_disparity 192 \
65 | --img_left '../data/sf_img_left.jpg' \
66 | --img_right '../data/sf_img_right.jpg' \
67 | --pretrained '../model_weights/sf_continuous_depth_conf_reg.pth.tar'
68 | 
--------------------------------------------------------------------------------
/src/util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
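#
# Small helpers shared by the run_* scripts: disp2rgb() maps a disparity image
# normalized to [0, 1] onto a piecewise-linear color map for visualization, and
# str2bool() converts "true"/"false" strings into booleans for use as an
# argparse argument type.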
8 | 
9 | import os
10 | import numpy as np
11 | 
12 | 
13 | def disp2rgb(disp):
14 |     H = disp.shape[0]
15 |     W = disp.shape[1]
16 | 
17 |     I = disp.flatten()
18 | 
19 |     map = np.array(
20 |         [
21 |             [0, 0, 0, 114],
22 |             [0, 0, 1, 185],
23 |             [1, 0, 0, 114],
24 |             [1, 0, 1, 174],
25 |             [0, 1, 0, 114],
26 |             [0, 1, 1, 185],
27 |             [1, 1, 0, 114],
28 |             [1, 1, 1, 0],
29 |         ]
30 |     )
31 |     bins = map[:-1, 3]
32 |     cbins = np.cumsum(bins)
33 |     bins = bins / cbins[-1]
34 |     cbins = cbins[:-1] / cbins[-1]
35 | 
36 |     ind = np.minimum(
37 |         np.sum(np.repeat(I[None, :], 6, axis=0) > np.repeat(cbins[:, None], I.shape[0], axis=1), axis=0), 6
38 |     )
39 |     bins = np.reciprocal(bins)
40 |     cbins = np.append(np.array([[0]]), cbins[:, None])
41 | 
42 |     I = np.multiply(I - cbins[ind], bins[ind])
43 |     I = np.minimum(
44 |         np.maximum(
45 |             np.multiply(map[ind, 0:3], np.repeat(1 - I[:, None], 3, axis=1))
46 |             + np.multiply(map[ind + 1, 0:3], np.repeat(I[:, None], 3, axis=1)),
47 |             0,
48 |         ),
49 |         1,
50 |     )
51 | 
52 |     I = np.reshape(I, [H, W, 3]).astype(np.float32)
53 | 
54 |     return I
55 | 
56 | 
57 | def str2bool(bool_input_string):
58 |     if isinstance(bool_input_string, bool):
59 |         return bool_input_string
60 |     if bool_input_string.lower() in ("true",):
61 |         return True
62 |     elif bool_input_string.lower() in ("false",):
63 |         return False
64 |     else:
65 |         raise ValueError("Please provide a boolean value ('true' or 'false').")
66 | 
--------------------------------------------------------------------------------