├── .gitignore
├── LICENSE.md
├── README.md
├── data
│   ├── kitti15_img_left.jpg
│   ├── kitti15_img_right.jpg
│   ├── sf_img_left.jpg
│   ├── sf_img_right.jpg
│   └── teaser.png
├── envs
│   ├── bi3d_conda_env.yml
│   └── bi3d_pytorch_19_01.DockerFile
└── src
    ├── models
    │   ├── Bi3DNet.py
    │   ├── DispRefine2D.py
    │   ├── FeatExtractNet.py
    │   ├── GCNet.py
    │   ├── PSMNet.py
    │   ├── RefineNet2D.py
    │   ├── RefineNet3D.py
    │   ├── SegNet2D.py
    │   └── __init__.py
    ├── project.toml
    ├── run_binary_depth_estimation.py
    ├── run_continuous_depth_estimation.py
    ├── run_demo_kitti15.sh
    ├── run_demo_sf.sh
    └── util.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Add any directories, files, or patterns you don't want to be tracked by version control
2 |
3 | *.png
4 | *.pfm
5 | *.pth.tar
6 | *.npy
7 | *.ppm
8 | *.pyc
9 | *.tar
10 | *.zip
11 | *.gif
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | # NVIDIA Source Code License for Bi3D
2 |
3 | ## 1. Definitions
4 |
5 | “Licensor” means any person or entity that distributes its Work.
6 |
7 | “Software” means the original work of authorship made available under this License.
8 |
9 | “Work” means the Software and any additions to or derivative works of the Software that are made available under this License.
10 |
11 | “NVIDIA Processors” means any central processing unit (CPU), graphics processing unit (GPU), field-programmable gate array (FPGA), application-specific integrated circuit (ASIC) or any combination thereof designed, made, sold, or provided by NVIDIA or its affiliates.
12 |
13 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this License, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.
14 |
15 | Works, including the Software, are “made available” under this License by including in or with the Work either (a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License.
16 |
17 | ## 2. License Grant
18 |
19 | ### 2.1 Copyright Grant.
20 | Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.
21 |
22 | ## 3. Limitations
23 |
24 | ### 3.1 Redistribution.
25 | You may reproduce or distribute the Work only if (a) you do so under this License, (b) you include a complete copy of this License with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work.
26 |
27 | ### 3.2 Derivative Works.
28 | You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this License (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself.
29 |
30 | ### 3.3 Use Limitation.
31 | The Work and any derivative works thereof only may be used or intended for use non-commercially and with NVIDIA Processors. Notwithstanding the foregoing, NVIDIA and its affiliates may use the Work and any derivative works commercially. As used herein, “non-commercially” means for research or evaluation purposes only.
32 |
33 | ### 3.4 Patent Claims.
34 | If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this License from such Licensor (including the grant in Section 2.1) will terminate immediately.
35 |
36 | ### 3.5 Trademarks.
37 | This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this License.
38 |
39 | ### 3.6 Termination.
40 | If you violate any term of this License, then your rights under this License (including the grant in Section 2.1) will terminate immediately.
41 |
42 | ## 4. Disclaimer of Warranty.
43 |
44 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
45 |
46 | ## 5. Limitation of Liability.
47 |
48 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Bi3D — Official PyTorch Implementation
2 |
3 | 
4 |
5 | **Bi3D: Stereo Depth Estimation via Binary Classifications**
6 | Abhishek Badki, Alejandro Troccoli, Kihwan Kim, Jan Kautz, Pradeep Sen, and Orazio Gallo
7 | IEEE CVPR 2020
8 |
9 | ## Abstract:
10 | *Stereo-based depth estimation is a cornerstone of computer vision, with state-of-the-art methods delivering accurate results in real time. For several applications such as autonomous navigation, however, it may be useful to trade accuracy for lower latency. We present Bi3D, a method that estimates depth via a series of binary classifications. Rather than testing if objects are* at *a particular depth D, as existing stereo methods do, it classifies them as being* closer *or* farther *than D. This property offers a powerful mechanism to balance accuracy and latency. Given a strict time budget, Bi3D can detect objects closer than a given distance in as little as a few milliseconds, or estimate depth with arbitrarily coarse quantization, with complexity linear with the number of quantization levels. Bi3D can also use the allotted quantization levels to get continuous depth, but in a specific depth range. For standard stereo (i.e., continuous depth on the whole range), our method is close to or on par with state-of-the-art, finely tuned stereo methods.*
11 |
12 |
13 | ## Paper:
14 | https://arxiv.org/pdf/2005.07274.pdf
15 |
16 | ## Videos:
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 | ## Citing Bi3D:
28 | @InProceedings{badki2020Bi3D,
29 | author = {Badki, Abhishek and Troccoli, Alejandro and Kim, Kihwan and Kautz, Jan and Sen, Pradeep and Gallo, Orazio},
30 | title = {{Bi3D}: {S}tereo Depth Estimation via Binary Classifications},
31 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
32 | year = {2020}
33 | }
34 |
35 | or the arXiv paper
36 |
37 | @InProceedings{badki2020Bi3D,
38 | author = {Badki, Abhishek and Troccoli, Alejandro and Kim, Kihwan and Kautz, Jan and Sen, Pradeep and Gallo, Orazio},
39 | title = {{Bi3D}: {S}tereo Depth Estimation via Binary Classifications},
40 | booktitle = {arXiv preprint arXiv:2005.07274},
41 | year = {2020}
42 | }
43 |
44 |
45 | ## Code:
46 |
47 | ### License
48 |
49 | Copyright (C) 2020 NVIDIA Corporation. All rights reserved.
50 |
51 | Licensed under the [NVIDIA Source Code License](LICENSE.md)
52 |
53 | ### Description
54 |
55 |
56 | ### Setup
57 |
58 | We offer two ways of setting up your environment: through Docker or through Conda.
59 |
60 | #### Docker
61 | For convenience, we provide a Dockerfile to build a container image to run the code. The image will contain the Python dependencies.
62 |
63 | System requirements:
64 |
65 | 1. Docker (Tested on version 19.03.11)
66 |
67 | 2. [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker/wiki)
68 |
69 | 3. NVIDIA GPU driver.
70 |
71 | Build the container image:
72 | ```
73 | docker build -t bi3d . -f envs/bi3d_pytorch_19_01.DockerFile
74 | ```
75 | To launch the container, run the following:
76 | ```
77 | docker run --rm -it --gpus=all -v $(pwd):/bi3d -w /bi3d --net=host --ipc=host bi3d:latest /bin/bash
78 | ```
79 |
80 | #### Conda
81 | All dependencies will be installed automatically using the following:
82 | ```
83 | conda env create -f envs/bi3d_conda_env.yml
84 | ```
85 | You can activate the environment by running:
86 | ```
87 | conda activate bi3d
88 | ```
89 |
90 | ### Pre-trained models
91 | Download the pre-trained models [here](https://drive.google.com/file/d/1X4Ing9WumtIxonNXXCzKJulJtPgzk61n).
92 |
93 | ### Run the demo
94 |
95 | ```
96 | cd src
97 | # RUN DEMO FOR SCENEFLOW DATASET
98 | sh run_demo_sf.sh
99 | # RUN DEMO FOR KITTI15 DATASET
100 | sh run_demo_kitti15.sh
101 | ```
102 |
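103 | You can also call `run_binary_depth_estimation.py` (and `run_continuous_depth_estimation.py` for full depth maps) directly. A direct invocation of the binary (closer/farther than a given disparity) mode has the following form; the checkpoint path, crop size, and disparity values below are placeholders to adapt to your setup (disparity values must be multiples of 3), and the demo scripts show the exact arguments used for the provided images:
104 |
105 | ```
106 | cd src
107 | python run_binary_depth_estimation.py \
108 |     --pretrained <path_to_binary_depth_checkpoint>.pth.tar \
109 |     --img_left ../data/sf_img_left.jpg \
110 |     --img_right ../data/sf_img_right.jpg \
111 |     --disp_vals 24 36 54 96 \
112 |     --crop_height 576 --crop_width 960
113 | ```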
--------------------------------------------------------------------------------
/data/kitti15_img_left.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Bi3D/4b5fdb48d820b8a5cfd95a7d2d82ea56f18cd597/data/kitti15_img_left.jpg
--------------------------------------------------------------------------------
/data/kitti15_img_right.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Bi3D/4b5fdb48d820b8a5cfd95a7d2d82ea56f18cd597/data/kitti15_img_right.jpg
--------------------------------------------------------------------------------
/data/sf_img_left.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Bi3D/4b5fdb48d820b8a5cfd95a7d2d82ea56f18cd597/data/sf_img_left.jpg
--------------------------------------------------------------------------------
/data/sf_img_right.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Bi3D/4b5fdb48d820b8a5cfd95a7d2d82ea56f18cd597/data/sf_img_right.jpg
--------------------------------------------------------------------------------
/data/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/Bi3D/4b5fdb48d820b8a5cfd95a7d2d82ea56f18cd597/data/teaser.png
--------------------------------------------------------------------------------
/envs/bi3d_conda_env.yml:
--------------------------------------------------------------------------------
1 | name: bi3d
2 | channels:
3 | - pytorch
4 | - soumith
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=main
8 | - blas=1.0=mkl
9 | - ca-certificates=2020.6.24=0
10 | - certifi=2020.6.20=py37_0
11 | - cudatoolkit=10.0.130=0
12 | - freetype=2.10.2=h5ab3b9f_0
13 | - intel-openmp=2020.1=217
14 | - jpeg=9b=h024ee3a_2
15 | - lcms2=2.11=h396b838_0
16 | - ld_impl_linux-64=2.33.1=h53a641e_7
17 | - libedit=3.1.20191231=h14c3975_1
18 | - libffi=3.3=he6710b0_2
19 | - libgcc-ng=9.1.0=hdf63c60_0
20 | - libgfortran-ng=7.3.0=hdf63c60_0
21 | - libpng=1.6.37=hbc83047_0
22 | - libstdcxx-ng=9.1.0=hdf63c60_0
23 | - libtiff=4.1.0=h2733197_1
24 | - lz4-c=1.9.2=he6710b0_0
25 | - mkl=2020.1=217
26 | - mkl-service=2.3.0=py37he904b0f_0
27 | - mkl_fft=1.1.0=py37h23d657b_0
28 | - mkl_random=1.1.1=py37h0573a6f_0
29 | - ncurses=6.2=he6710b0_1
30 | - ninja=1.9.0=py37hfd86e86_0
31 | - numpy=1.18.5=py37ha1c710e_0
32 | - numpy-base=1.18.5=py37hde5b4d6_0
33 | - olefile=0.46=py_0
34 | - openssl=1.1.1g=h7b6447c_0
35 | - pillow=7.2.0=py37hb39fc2d_0
36 | - pip=20.1.1=py37_1
37 | - python=3.7.7=hcff3b4d_5
38 | - pytorch=1.4.0=py3.7_cuda10.0.130_cudnn7.6.3_0
39 | - readline=8.0=h7b6447c_0
40 | - setuptools=49.2.0=py37_0
41 | - six=1.15.0=py_0
42 | - sqlite=3.32.3=h62c20be_0
43 | - tk=8.6.10=hbc83047_0
44 | - torchvision=0.5.0=py37_cu100
45 | - wheel=0.34.2=py37_0
46 | - xz=5.2.5=h7b6447c_0
47 | - zlib=1.2.11=h7b6447c_3
48 | - zstd=1.4.5=h0b5b093_0
49 | - pip:
50 | - imageio==2.9.0
51 | - opencv-python==4.3.0.36
52 | - protobuf==3.12.2
53 | - tensorboardx==2.1
54 |
55 |
--------------------------------------------------------------------------------
/envs/bi3d_pytorch_19_01.DockerFile:
--------------------------------------------------------------------------------
1 | FROM nvcr.io/nvidia/pytorch:19.01-py3
2 |
3 | RUN pip install Pillow
4 | RUN pip install imageio
5 | RUN pip install tensorboardX
6 | RUN pip install opencv-python
7 |
--------------------------------------------------------------------------------
/src/models/Bi3DNet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import numpy as np
10 | import torch
11 | import torch.nn as nn
12 | from torch.autograd import Variable
13 | import torch.nn.functional as F
14 |
15 | import models.FeatExtractNet as FeatNet
16 | import models.SegNet2D as SegNet
17 | import models.RefineNet2D as RefineNet
18 | import models.RefineNet3D as RefineNet3D
19 |
20 |
21 | __all__ = ["bi3dnet_binary_depth", "bi3dnet_continuous_depth_2D", "bi3dnet_continuous_depth_3D"]
22 |
23 |
24 | def compute_cost_volume(features_left, features_right, disp_ids, max_disp, is_disps_per_example):
25 |
26 | batch_size = features_left.shape[0]
27 | feature_size = features_left.shape[1]
28 | H = features_left.shape[2]
29 | W = features_left.shape[3]
30 |
31 | psv_size = disp_ids.shape[1]
32 |
33 | psv = Variable(features_left.new_zeros(batch_size, psv_size, feature_size * 2, H, W + max_disp)).cuda()
34 |
35 | if is_disps_per_example:
36 | for i in range(batch_size):
37 | psv[i, 0, :feature_size, :, 0:W] = features_left[i]
38 | psv[i, 0, feature_size:, :, disp_ids[i, 0] : W + disp_ids[i, 0]] = features_right[i]
39 | psv = psv.contiguous()
40 | else:
41 | for i in range(psv_size):
42 | psv[:, i, :feature_size, :, 0:W] = features_left
43 | psv[:, i, feature_size:, :, disp_ids[0, i] : W + disp_ids[0, i]] = features_right
44 | psv = psv.contiguous()
45 |
46 | return psv
47 |
48 |
49 | """
50 | Bi3DNet for continuous depthmap generation. Doesn't use 3D regularization.
51 | """
52 |
53 |
54 | class Bi3DNetContinuousDepth2D(nn.Module):
55 | def __init__(self, options, featnet_arch, segnet_arch, refinenet_arch=None, max_disparity=192):
56 |
57 | super(Bi3DNetContinuousDepth2D, self).__init__()
58 |
59 | self.max_disparity = max_disparity
60 | self.max_disparity_seg = int(self.max_disparity / 3)
61 | self.is_disps_per_example = False
62 | self.is_save_memory = False
63 |
64 | self.is_refine = True
65 | if refinenet_arch == None:
66 | self.is_refine = False
67 |
68 | self.featnet = FeatNet.__dict__[featnet_arch](options, data=None)
69 | self.segnet = SegNet.__dict__[segnet_arch](options, data=None)
70 | if self.is_refine:
71 | self.refinenet = RefineNet.__dict__[refinenet_arch](options, data=None)
72 |
73 | return
74 |
75 | def forward(self, img_left, img_right, disp_ids):
76 |
77 | batch_size = img_left.shape[0]
78 | psv_size = disp_ids.shape[1]
79 |
80 | if psv_size == 1:
81 | self.is_disps_per_example = True
82 | else:
83 | self.is_disps_per_example = False
84 |
85 | # Feature Extraction
86 | features_left = self.featnet(img_left)
87 | features_right = self.featnet(img_right)
88 | feature_size = features_left.shape[1]
89 | H = features_left.shape[2]
90 | W = features_left.shape[3]
91 |
92 | # Cost Volume Generation
93 | psv = compute_cost_volume(
94 | features_left, features_right, disp_ids, self.max_disparity_seg, self.is_disps_per_example
95 | )
96 |
97 | psv = psv.view(batch_size * psv_size, feature_size * 2, H, W + self.max_disparity_seg)
98 |
99 | # Segmentation Network
100 | seg_raw_low_res = self.segnet(psv)[:, :, :, :W]
101 | seg_raw_low_res = seg_raw_low_res.view(batch_size, 1, psv_size, H, W)
102 |
103 | # Upsampling
104 | seg_prob_low_res_up = torch.sigmoid(
105 | F.interpolate(
106 | seg_raw_low_res,
107 | size=[psv_size * 3, img_left.size()[-2], img_left.size()[-1]],
108 | mode="trilinear",
109 | align_corners=False,
110 | )
111 | )
112 | seg_prob_low_res_up = seg_prob_low_res_up[:, 0, 1:-1, :, :]
113 |
114 | # Projection
115 | disparity_normalized = torch.mean((seg_prob_low_res_up), dim=1, keepdim=True)
116 |
117 | # Refinement
118 | if self.is_refine:
119 | refine_net_input = torch.cat((disparity_normalized, img_left), dim=1)
120 | disparity_normalized = self.refinenet(refine_net_input)
121 |
122 | return seg_prob_low_res_up, disparity_normalized
123 |
124 |
125 | def bi3dnet_continuous_depth_2D(options, data=None):
126 |
127 | print("==> USING Bi3DNetContinuousDepth2D")
128 | for key in options:
129 | if "bi3dnet" in key:
130 | print("{} : {}".format(key, options[key]))
131 |
132 | model = Bi3DNetContinuousDepth2D(
133 | options,
134 | featnet_arch=options["bi3dnet_featnet_arch"],
135 | segnet_arch=options["bi3dnet_segnet_arch"],
136 | refinenet_arch=options["bi3dnet_refinenet_arch"],
137 | max_disparity=options["bi3dnet_max_disparity"],
138 | )
139 |
140 | if data is not None:
141 | model.load_state_dict(data["state_dict"])
142 |
143 | return model
144 |
145 |
146 | """
147 | Bi3DNet for continuous depthmap generation. Uses 3D regularization.
148 | """
149 |
150 |
151 | class Bi3DNetContinuousDepth3D(nn.Module):
152 | def __init__(
153 | self,
154 | options,
155 | featnet_arch,
156 | segnet_arch,
157 | refinenet_arch=None,
158 | refinenet3d_arch=None,
159 | max_disparity=192,
160 | ):
161 |
162 | super(Bi3DNetContinuousDepth3D, self).__init__()
163 |
164 | self.max_disparity = max_disparity
165 | self.max_disparity_seg = int(self.max_disparity / 3)
166 | self.is_disps_per_example = False
167 | self.is_save_memory = False
168 |
169 | self.is_refine = True
170 | if refinenet_arch == None:
171 | self.is_refine = False
172 |
173 | self.featnet = FeatNet.__dict__[featnet_arch](options, data=None)
174 | self.segnet = SegNet.__dict__[segnet_arch](options, data=None)
175 | if self.is_refine:
176 | self.refinenet = RefineNet.__dict__[refinenet_arch](options, data=None)
177 | self.refinenet3d = RefineNet3D.__dict__[refinenet3d_arch](options, data=None)
178 |
179 | return
180 |
181 | def forward(self, img_left, img_right, disp_ids):
182 |
183 | batch_size = img_left.shape[0]
184 | psv_size = disp_ids.shape[1]
185 |
186 | if psv_size == 1:
187 | self.is_disps_per_example = True
188 | else:
189 | self.is_disps_per_example = False
190 |
191 | # Feature Extraction
192 | features_left = self.featnet(img_left)
193 | features_right = self.featnet(img_right)
194 | feature_size = features_left.shape[1]
195 | H = features_left.shape[2]
196 | W = features_left.shape[3]
197 |
198 | # Cost Volume Generation
199 | psv = compute_cost_volume(
200 | features_left, features_right, disp_ids, self.max_disparity_seg, self.is_disps_per_example
201 | )
202 |
203 | psv = psv.view(batch_size * psv_size, feature_size * 2, H, W + self.max_disparity_seg)
204 |
205 | # Segmentation Network
206 | seg_raw_low_res = self.segnet(psv)[:, :, :, :W] # cropped to remove excess boundary
207 | seg_raw_low_res = seg_raw_low_res.view(batch_size, 1, psv_size, H, W)
208 |
209 | # Upsampling
210 | seg_prob_low_res_up = torch.sigmoid(
211 | F.interpolate(
212 | seg_raw_low_res,
213 | size=[psv_size * 3, img_left.size()[-2], img_left.size()[-1]],
214 | mode="trilinear",
215 | align_corners=False,
216 | )
217 | )
218 |
219 | seg_prob_low_res_up = seg_prob_low_res_up[:, 0, 1:-1, :, :]
220 |
221 | # Upsampling after 3D Regularization
222 | seg_raw_low_res_refined = seg_raw_low_res
223 | seg_raw_low_res_refined[:, :, 1:, :, :] = self.refinenet3d(
224 | features_left, seg_raw_low_res_refined[:, :, 1:, :, :]
225 | )
226 |
227 | seg_prob_low_res_refined_up = torch.sigmoid(
228 | F.interpolate(
229 | seg_raw_low_res_refined,
230 | size=[psv_size * 3, img_left.size()[-2], img_left.size()[-1]],
231 | mode="trilinear",
232 | align_corners=False,
233 | )
234 | )
235 |
236 | seg_prob_low_res_refined_up = seg_prob_low_res_refined_up[:, 0, 1:-1, :, :]
237 |
238 | # Projection
239 | disparity_normalized_noisy = torch.mean((seg_prob_low_res_refined_up), dim=1, keepdim=True)
240 |
241 |         # Refinement (falls back to the unrefined estimate if no refinement network is configured)
242 |         disparity_normalized = disparity_normalized_noisy
243 |         if self.is_refine:
244 |             refine_net_input = torch.cat((disparity_normalized_noisy, img_left), dim=1)
245 |             disparity_normalized = self.refinenet(refine_net_input)
246 | return (
247 | seg_prob_low_res_up,
248 | seg_prob_low_res_refined_up,
249 | disparity_normalized_noisy,
250 | disparity_normalized,
251 | )
252 |
253 |
254 | def bi3dnet_continuous_depth_3D(options, data=None):
255 |
256 | print("==> USING Bi3DNetContinuousDepth3D")
257 | for key in options:
258 | if "bi3dnet" in key:
259 | print("{} : {}".format(key, options[key]))
260 |
261 | model = Bi3DNetContinuousDepth3D(
262 | options,
263 | featnet_arch=options["bi3dnet_featnet_arch"],
264 | segnet_arch=options["bi3dnet_segnet_arch"],
265 | refinenet_arch=options["bi3dnet_refinenet_arch"],
266 | refinenet3d_arch=options["bi3dnet_regnet_arch"],
267 | max_disparity=options["bi3dnet_max_disparity"],
268 | )
269 |
270 | if data is not None:
271 | model.load_state_dict(data["state_dict"])
272 |
273 | return model
274 |
275 |
276 | """
277 | Bi3DNet for binary depthmap generation.
278 | """
279 |
280 |
281 | class Bi3DNetBinaryDepth(nn.Module):
282 | def __init__(
283 | self,
284 | options,
285 | featnet_arch,
286 | segnet_arch,
287 | refinenet_arch=None,
288 | featnethr_arch=None,
289 | max_disparity=192,
290 | is_disps_per_example=False,
291 | ):
292 |
293 | super(Bi3DNetBinaryDepth, self).__init__()
294 |
295 | self.max_disparity = max_disparity
296 | self.max_disparity_seg = int(max_disparity / 3)
297 | self.is_disps_per_example = is_disps_per_example
298 |
299 | self.is_refine = True
300 | if refinenet_arch == None:
301 | self.is_refine = False
302 |
303 | self.featnet = FeatNet.__dict__[featnet_arch](options, data=None)
304 | self.featnethr = FeatNet.__dict__[featnethr_arch](options, data=None)
305 | self.segnet = SegNet.__dict__[segnet_arch](options, data=None)
306 | if self.is_refine:
307 | self.refinenet = RefineNet.__dict__[refinenet_arch](options, data=None)
308 |
309 | return
310 |
311 | def forward(self, img_left, img_right, disp_ids):
312 |
313 | batch_size = img_left.shape[0]
314 | psv_size = disp_ids.shape[1]
315 |
316 | if psv_size == 1:
317 | self.is_disps_per_example = True
318 | else:
319 | self.is_disps_per_example = False
320 |
321 | # Feature Extraction
322 | features = self.featnet(torch.cat((img_left, img_right), dim=0))
323 |
324 | features_left = features[:batch_size, :, :, :]
325 | features_right = features[batch_size:, :, :, :]
326 |
327 | if self.is_refine:
328 | features_lefthr = self.featnethr(img_left)
329 | feature_size = features_left.shape[1]
330 | H = features_left.shape[2]
331 | W = features_left.shape[3]
332 |
333 | # Cost Volume Generation
334 | psv = compute_cost_volume(
335 | features_left, features_right, disp_ids, self.max_disparity_seg, self.is_disps_per_example
336 | )
337 |
338 | psv = psv.view(batch_size * psv_size, feature_size * 2, H, W + self.max_disparity_seg)
339 |
340 | # Segmentation Network
341 | seg_raw_low_res = self.segnet(psv)[:, :, :, :W] # cropped to remove excess boundary
342 | seg_prob_low_res = torch.sigmoid(seg_raw_low_res)
343 | seg_prob_low_res = seg_prob_low_res.view(batch_size, psv_size, H, W)
344 |
345 | seg_prob_low_res_up = F.interpolate(
346 | seg_prob_low_res, size=img_left.size()[-2:], mode="bilinear", align_corners=False
347 | )
348 | out = []
349 | out.append(seg_prob_low_res_up)
350 |
351 | # Refinement
352 | if self.is_refine:
353 | seg_raw_high_res = F.interpolate(
354 | seg_raw_low_res, size=img_left.size()[-2:], mode="bilinear", align_corners=False
355 | )
356 | # Refine Net
357 | features_left_expand = (
358 | features_lefthr[:, None, :, :, :].expand(-1, psv_size, -1, -1, -1).contiguous()
359 | )
360 | features_left_expand = features_left_expand.view(
361 | -1, features_lefthr.size()[1], features_lefthr.size()[2], features_lefthr.size()[3]
362 | )
363 | refine_net_input = torch.cat((seg_raw_high_res, features_left_expand), dim=1)
364 |
365 | seg_raw_high_res = self.refinenet(refine_net_input)
366 |
367 | seg_prob_high_res = torch.sigmoid(seg_raw_high_res)
368 | seg_prob_high_res = seg_prob_high_res.view(
369 | batch_size, psv_size, img_left.size()[-2], img_left.size()[-1]
370 | )
371 | out.append(seg_prob_high_res)
372 | else:
373 | out.append(seg_prob_low_res_up)
374 |
375 | return out
376 |
377 |
378 | def bi3dnet_binary_depth(options, data=None):
379 |
380 | print("==> USING Bi3DNetBinaryDepth")
381 | for key in options:
382 | if "bi3dnet" in key:
383 | print("{} : {}".format(key, options[key]))
384 |
385 | model = Bi3DNetBinaryDepth(
386 | options,
387 | featnet_arch=options["bi3dnet_featnet_arch"],
388 | segnet_arch=options["bi3dnet_segnet_arch"],
389 | refinenet_arch=options["bi3dnet_refinenet_arch"],
390 | featnethr_arch=options["bi3dnet_featnethr_arch"],
391 | max_disparity=options["bi3dnet_max_disparity"],
392 | is_disps_per_example=options["bi3dnet_disps_per_example_true"],
393 | )
394 |
395 | if data is not None:
396 | model.load_state_dict(data["state_dict"])
397 |
398 | return model
399 |
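400 |
401 | if __name__ == "__main__":
402 |     # Minimal smoke test (illustrative only, not part of the original pipeline): builds the
403 |     # binary-depth variant with randomly initialized weights and runs one forward pass on
404 |     # random inputs. Assumes a CUDA device; the input height/width are divisible by 3
405 |     # (feature stride) and the resulting feature maps by 32 (SegNet2D encoder stride).
406 |     opts = {
407 |         "bi3dnet_featnet_arch": "featextractnetspp",
408 |         "bi3dnet_featnethr_arch": "featextractnethr",
409 |         "bi3dnet_segnet_arch": "segnet2d",
410 |         "bi3dnet_refinenet_arch": "segrefinenet",
411 |         "bi3dnet_max_disparity": 192,
412 |         "bi3dnet_disps_per_example_true": False,
413 |         "featextractnethr_out_planes": 16,
414 |         "segrefinenet_in_planes": 17,
415 |         "segrefinenet_out_planes": 8,
416 |     }
417 |     net = bi3dnet_binary_depth(opts).cuda().eval()
418 |     img_left = torch.rand(1, 3, 576, 960).cuda()
419 |     img_right = torch.rand(1, 3, 576, 960).cuda()
420 |     disp_ids = torch.tensor([[12, 21, 30]])  # candidate disparities at 1/3 resolution
421 |     with torch.no_grad():
422 |         seg_prob_low, seg_prob_high = net(img_left, img_right, disp_ids)
423 |     print(seg_prob_low.shape, seg_prob_high.shape)  # both torch.Size([1, 3, 576, 960])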
--------------------------------------------------------------------------------
/src/models/DispRefine2D.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2019 Xuanyi Li (xuanyili.edu@gmail.com)
4 | # Copyright (c) 2020 NVIDIA
5 | #
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 | #
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 | #
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE.
23 |
24 | import torch
25 | import torch.nn as nn
26 | import torch.nn.functional as F
27 | import math
28 |
29 | from models.PSMNet import conv2d
30 | from models.PSMNet import conv2d_lrelu
31 |
32 | """
33 | The code in this file is adapted
34 | from https://github.com/meteorshowers/StereoNet-ActiveStereoNet
35 | """
36 |
37 |
38 | class BasicBlock(nn.Module):
39 |
40 | expansion = 1
41 |
42 | def __init__(self, inplanes, planes, stride, downsample, pad, dilation):
43 |
44 | super(BasicBlock, self).__init__()
45 |
46 | self.conv1 = conv2d_lrelu(inplanes, planes, 3, stride, pad, dilation)
47 | self.conv2 = conv2d(planes, planes, 3, 1, pad, dilation)
48 |
49 | self.downsample = downsample
50 | self.stride = stride
51 |
52 | def forward(self, x):
53 |
54 | out = self.conv1(x)
55 | out = self.conv2(out)
56 |
57 | if self.downsample is not None:
58 | x = self.downsample(x)
59 |
60 | out += x
61 |
62 | return out
63 |
64 |
65 | class DispRefineNet(nn.Module):
66 | def __init__(self, out_planes=32):
67 |
68 | super(DispRefineNet, self).__init__()
69 |
70 | self.out_planes = out_planes
71 |
72 | self.conv2d_feature = conv2d_lrelu(
73 | in_planes=4, out_planes=self.out_planes, kernel_size=3, stride=1, pad=1, dilation=1
74 | )
75 |
76 | self.residual_astrous_blocks = nn.ModuleList()
77 | astrous_list = [1, 2, 4, 8, 1, 1]
78 | for di in astrous_list:
79 | self.residual_astrous_blocks.append(
80 | BasicBlock(self.out_planes, self.out_planes, stride=1, downsample=None, pad=1, dilation=di)
81 | )
82 |
83 | self.conv2d_out = nn.Conv2d(self.out_planes, 1, kernel_size=3, stride=1, padding=1)
84 |
85 | for m in self.modules():
86 | if isinstance(m, nn.Conv2d):
87 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
88 | m.weight.data.normal_(0, math.sqrt(2.0 / n))
89 | elif isinstance(m, nn.Conv3d):
90 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
91 | m.weight.data.normal_(0, math.sqrt(2.0 / n))
92 | elif isinstance(m, nn.BatchNorm2d):
93 | m.weight.data.fill_(1)
94 | m.bias.data.zero_()
95 | elif isinstance(m, nn.BatchNorm3d):
96 | m.weight.data.fill_(1)
97 | m.bias.data.zero_()
98 | elif isinstance(m, nn.Linear):
99 | m.bias.data.zero_()
100 |
101 | return
102 |
103 | def forward(self, x):
104 |
105 | disp = x[:, 0, :, :][:, None, :, :]
106 | output = self.conv2d_feature(x)
107 |
108 | for astrous_block in self.residual_astrous_blocks:
109 | output = astrous_block(output)
110 |
111 | output = self.conv2d_out(output) # residual disparity
112 | output = output + disp # final disparity
113 |
114 | return output
115 |
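116 |
117 | if __name__ == "__main__":
118 |     # Illustrative shape check (not part of the original pipeline): DispRefineNet takes a
119 |     # 4-channel input (the normalized disparity map concatenated with the RGB guide image)
120 |     # and returns a refined single-channel disparity map at the same resolution.
121 |     net = DispRefineNet(out_planes=32)
122 |     x = torch.rand(1, 4, 96, 128)
123 |     print(net(x).shape)  # torch.Size([1, 1, 96, 128])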
--------------------------------------------------------------------------------
/src/models/FeatExtractNet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | from __future__ import print_function
10 | import torch
11 | import torch.nn as nn
12 | import math
13 |
14 | from models.PSMNet import conv2d
15 | from models.PSMNet import conv2d_relu
16 | from models.PSMNet import FeatExtractNetSPP
17 |
18 | __all__ = ["featextractnetspp", "featextractnethr"]
19 |
20 |
21 | """
22 | Feature extraction network.
23 | Generates 16D features at the image resolution.
24 | Used for final refinement.
25 | """
26 |
27 |
28 | class FeatExtractNetHR(nn.Module):
29 | def __init__(self, out_planes=16):
30 |
31 | super(FeatExtractNetHR, self).__init__()
32 |
33 | self.conv1 = nn.Sequential(
34 | conv2d_relu(3, out_planes, kernel_size=3, stride=1, pad=1, dilation=1),
35 | conv2d_relu(out_planes, out_planes, kernel_size=3, stride=1, pad=1, dilation=1),
36 | nn.Conv2d(out_planes, out_planes, kernel_size=1, padding=0, stride=1, bias=False),
37 | )
38 |
39 | for m in self.modules():
40 | if isinstance(m, nn.Conv2d):
41 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
42 | m.weight.data.normal_(0, math.sqrt(2.0 / n))
43 | elif isinstance(m, nn.Conv3d):
44 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
45 | m.weight.data.normal_(0, math.sqrt(2.0 / n))
46 | elif isinstance(m, nn.BatchNorm2d):
47 | m.weight.data.fill_(1)
48 | m.bias.data.zero_()
49 | elif isinstance(m, nn.BatchNorm3d):
50 | m.weight.data.fill_(1)
51 | m.bias.data.zero_()
52 | elif isinstance(m, nn.Linear):
53 | m.bias.data.zero_()
54 |
55 | return
56 |
57 | def forward(self, input):
58 |
59 | output = self.conv1(input)
60 | return output
61 |
62 |
63 | def featextractnethr(options, data=None):
64 |
65 | print("==> USING FeatExtractNetHR")
66 | for key in options:
67 | if "featextractnethr" in key:
68 | print("{} : {}".format(key, options[key]))
69 |
70 | model = FeatExtractNetHR(out_planes=options["featextractnethr_out_planes"])
71 |
72 | if data is not None:
73 | model.load_state_dict(data["state_dict"])
74 |
75 | return model
76 |
77 |
78 | """
79 | Feature extraction network.
80 | Generates 32D features at 1/3 of the input resolution.
81 | Uses Spatial Pyramid Pooling inspired by PSMNet.
82 | """
83 |
84 |
85 | def featextractnetspp(options, data=None):
86 |
87 | print("==> USING FeatExtractNetSPP")
88 | for key in options:
89 | if "feat" in key:
90 | print("{} : {}".format(key, options[key]))
91 |
92 | model = FeatExtractNetSPP()
93 |
94 | if data is not None:
95 | model.load_state_dict(data["state_dict"])
96 |
97 | return model
98 |
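99 |
100 | if __name__ == "__main__":
101 |     # Illustrative shape check (not part of the original pipeline): the HR network keeps the
102 |     # input resolution with 16 channels, while the SPP network (FeatExtractNetSPP, defined in
103 |     # models.PSMNet) returns 32 channels at 1/3 of the input resolution.
104 |     x = torch.rand(1, 3, 576, 960)
105 |     net_hr = featextractnethr({"featextractnethr_out_planes": 16})
106 |     net_spp = featextractnetspp({})
107 |     print(net_hr(x).shape)   # torch.Size([1, 16, 576, 960])
108 |     print(net_spp(x).shape)  # torch.Size([1, 32, 192, 320])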
--------------------------------------------------------------------------------
/src/models/GCNet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018 Wang Yufeng
2 | # Copyright (c) 2020 NVIDIA
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import torch
17 | import torch.nn as nn
18 |
19 | """
20 | The code in this file is adapted from https://github.com/wyf2017/DSMnet
21 | """
22 |
23 |
24 | def conv3d_relu(in_planes, out_planes, kernel_size=3, stride=1, activefun=nn.ReLU(inplace=True)):
25 |
26 | return nn.Sequential(
27 | nn.Conv3d(in_planes, out_planes, kernel_size, stride, padding=(kernel_size - 1) // 2, bias=True),
28 | activefun,
29 | )
30 |
31 |
32 | def deconv3d_relu(in_planes, out_planes, kernel_size=4, stride=2, activefun=nn.ReLU(inplace=True)):
33 |
34 | assert stride > 1
35 | p = (kernel_size - 1) // 2
36 | op = stride - (kernel_size - 2 * p)
37 | return nn.Sequential(
38 | nn.ConvTranspose3d(
39 | in_planes, out_planes, kernel_size, stride, padding=p, output_padding=op, bias=True
40 | ),
41 | activefun,
42 | )
43 |
44 |
45 | """
46 | GCNet style 3D regularization network
47 | """
48 |
49 |
50 | class feature3d(nn.Module):
51 | def __init__(self, num_F):
52 |
53 | super(feature3d, self).__init__()
54 | self.F = num_F
55 |
56 | self.l19 = conv3d_relu(self.F + 32, self.F, kernel_size=3, stride=1)
57 | self.l20 = conv3d_relu(self.F, self.F, kernel_size=3, stride=1)
58 |
59 | self.l21 = conv3d_relu(self.F + 32, self.F * 2, kernel_size=3, stride=2)
60 | self.l22 = conv3d_relu(self.F * 2, self.F * 2, kernel_size=3, stride=1)
61 | self.l23 = conv3d_relu(self.F * 2, self.F * 2, kernel_size=3, stride=1)
62 |
63 | self.l24 = conv3d_relu(self.F * 2, self.F * 2, kernel_size=3, stride=2)
64 | self.l25 = conv3d_relu(self.F * 2, self.F * 2, kernel_size=3, stride=1)
65 | self.l26 = conv3d_relu(self.F * 2, self.F * 2, kernel_size=3, stride=1)
66 |
67 | self.l27 = conv3d_relu(self.F * 2, self.F * 2, kernel_size=3, stride=2)
68 | self.l28 = conv3d_relu(self.F * 2, self.F * 2, kernel_size=3, stride=1)
69 | self.l29 = conv3d_relu(self.F * 2, self.F * 2, kernel_size=3, stride=1)
70 |
71 | self.l30 = conv3d_relu(self.F * 2, self.F * 4, kernel_size=3, stride=2)
72 | self.l31 = conv3d_relu(self.F * 4, self.F * 4, kernel_size=3, stride=1)
73 | self.l32 = conv3d_relu(self.F * 4, self.F * 4, kernel_size=3, stride=1)
74 |
75 | self.l33 = deconv3d_relu(self.F * 4, self.F * 2, kernel_size=3, stride=2)
76 | self.l34 = deconv3d_relu(self.F * 2, self.F * 2, kernel_size=3, stride=2)
77 | self.l35 = deconv3d_relu(self.F * 2, self.F * 2, kernel_size=3, stride=2)
78 | self.l36 = deconv3d_relu(self.F * 2, self.F, kernel_size=3, stride=2)
79 |
80 | self.l37 = nn.Conv3d(self.F, 1, kernel_size=3, stride=1, padding=1, bias=True)
81 |
82 | def forward(self, x):
83 |
84 | x18 = x
85 | x21 = self.l21(x18)
86 | x24 = self.l24(x21)
87 | x27 = self.l27(x24)
88 | x30 = self.l30(x27)
89 | x31 = self.l31(x30)
90 | x32 = self.l32(x31)
91 |
92 | x29 = self.l29(self.l28(x27))
93 | x33 = self.l33(x32) + x29
94 |
95 | x26 = self.l26(self.l25(x24))
96 | x34 = self.l34(x33) + x26
97 |
98 | x23 = self.l23(self.l22(x21))
99 | x35 = self.l35(x34) + x23
100 |
101 | x20 = self.l20(self.l19(x18))
102 | x36 = self.l36(x35) + x20
103 |
104 | x37 = self.l37(x36)
105 |
106 | conf_volume_wo_sig = x37
107 |
108 | return conf_volume_wo_sig
109 |
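110 |
111 | if __name__ == "__main__":
112 |     # Illustrative check (not part of the original pipeline) of the deconv3d_relu padding
113 |     # arithmetic: for kernel_size=3 and stride=2, p=1 and output_padding=1, so each
114 |     # transposed convolution exactly doubles the D, H and W dimensions.
115 |     up = deconv3d_relu(8, 4, kernel_size=3, stride=2)
116 |     x = torch.rand(1, 8, 4, 6, 6)
117 |     print(up(x).shape)  # torch.Size([1, 4, 8, 12, 12])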
--------------------------------------------------------------------------------
/src/models/PSMNet.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Jia-Ren Chang
4 | # Copyright (c) 2020 NVIDIA
5 | #
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 | #
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 | #
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE.
23 |
24 | import torch
25 | import torch.nn as nn
26 | import torch.nn.functional as F
27 | import math
28 |
29 | """
30 | The code in this file is adapted from https://github.com/JiaRenChang/PSMNet
31 | """
32 |
33 |
34 | def conv2d(in_planes, out_planes, kernel_size, stride, pad, dilation):
35 |
36 | return nn.Sequential(
37 | nn.Conv2d(
38 | in_planes,
39 | out_planes,
40 | kernel_size=kernel_size,
41 | stride=stride,
42 | padding=dilation if dilation > 1 else pad,
43 | dilation=dilation,
44 | bias=True,
45 | )
46 | )
47 |
48 |
49 | def conv2d_relu(in_planes, out_planes, kernel_size, stride, pad, dilation):
50 |
51 | return nn.Sequential(
52 | nn.Conv2d(
53 | in_planes,
54 | out_planes,
55 | kernel_size=kernel_size,
56 | stride=stride,
57 | padding=dilation if dilation > 1 else pad,
58 | dilation=dilation,
59 | bias=True,
60 | ),
61 | nn.ReLU(inplace=True),
62 | )
63 |
64 |
65 | def conv2d_lrelu(in_planes, out_planes, kernel_size, stride, pad, dilation=1):
66 |
67 | return nn.Sequential(
68 | nn.Conv2d(
69 | in_planes,
70 | out_planes,
71 | kernel_size=kernel_size,
72 | stride=stride,
73 | padding=dilation if dilation > 1 else pad,
74 | dilation=dilation,
75 | bias=True,
76 | ),
77 | nn.LeakyReLU(0.1, inplace=True),
78 | )
79 |
80 |
81 | class BasicBlock(nn.Module):
82 |
83 | expansion = 1
84 |
85 | def __init__(self, inplanes, planes, stride, downsample, pad, dilation):
86 |
87 | super(BasicBlock, self).__init__()
88 |
89 | self.conv1 = conv2d_relu(inplanes, planes, 3, stride, pad, dilation)
90 | self.conv2 = conv2d(planes, planes, 3, 1, pad, dilation)
91 |
92 | self.downsample = downsample
93 | self.stride = stride
94 |
95 | def forward(self, x):
96 |
97 | out = self.conv1(x)
98 | out = self.conv2(out)
99 |
100 | if self.downsample is not None:
101 | x = self.downsample(x)
102 |
103 | out += x
104 |
105 | return out
106 |
107 |
108 | class FeatExtractNetSPP(nn.Module):
109 | def __init__(self):
110 |
111 | super(FeatExtractNetSPP, self).__init__()
112 |
113 | self.align_corners = False
114 | self.inplanes = 32
115 |
116 | self.firstconv = nn.Sequential(
117 | conv2d_relu(3, 32, 3, 3, 1, 1), conv2d_relu(32, 32, 3, 1, 1, 1), conv2d_relu(32, 32, 3, 1, 1, 1)
118 | )
119 |
120 | self.layer1 = self._make_layer(BasicBlock, 32, 2, 1, 1, 2)
121 |
122 | self.branch1 = nn.Sequential(nn.AvgPool2d((64, 64), stride=(64, 64)), conv2d_relu(32, 32, 1, 1, 0, 1))
123 |
124 | self.branch2 = nn.Sequential(nn.AvgPool2d((32, 32), stride=(32, 32)), conv2d_relu(32, 32, 1, 1, 0, 1))
125 |
126 | self.branch3 = nn.Sequential(nn.AvgPool2d((16, 16), stride=(16, 16)), conv2d_relu(32, 32, 1, 1, 0, 1))
127 |
128 | self.branch4 = nn.Sequential(nn.AvgPool2d((8, 8), stride=(8, 8)), conv2d_relu(32, 32, 1, 1, 0, 1))
129 |
130 | self.lastconv = nn.Sequential(
131 | conv2d_relu(160, 64, 3, 1, 1, 1),
132 | nn.Conv2d(64, 32, kernel_size=1, padding=0, stride=1, bias=False),
133 | )
134 |
135 | for m in self.modules():
136 | if isinstance(m, nn.Conv2d):
137 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
138 | m.weight.data.normal_(0, math.sqrt(2.0 / n))
139 | elif isinstance(m, nn.Conv3d):
140 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
141 | m.weight.data.normal_(0, math.sqrt(2.0 / n))
142 | elif isinstance(m, nn.BatchNorm2d):
143 | m.weight.data.fill_(1)
144 | m.bias.data.zero_()
145 | elif isinstance(m, nn.BatchNorm3d):
146 | m.weight.data.fill_(1)
147 | m.bias.data.zero_()
148 | elif isinstance(m, nn.Linear):
149 | m.bias.data.zero_()
150 |
151 | def _make_layer(self, block, planes, blocks, stride, pad, dilation):
152 | downsample = None
153 | if stride != 1 or self.inplanes != planes * block.expansion:
154 | downsample = nn.Sequential(
155 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
156 | nn.BatchNorm2d(planes * block.expansion),
157 | )
158 |
159 | layers = []
160 | layers.append(block(self.inplanes, planes, stride, downsample, pad, dilation))
161 | self.inplanes = planes * block.expansion
162 | for i in range(1, blocks):
163 | layers.append(block(self.inplanes, planes, 1, None, pad, dilation))
164 |
165 | return nn.Sequential(*layers)
166 |
167 | def forward(self, input):
168 |
169 | output0 = self.firstconv(input)
170 | output1 = self.layer1(output0)
171 |
172 | output_branch1 = self.branch1(output1)
173 | output_branch1 = F.interpolate(
174 | output_branch1,
175 | (output1.size()[2], output1.size()[3]),
176 | mode="bilinear",
177 | align_corners=self.align_corners,
178 | )
179 |
180 | output_branch2 = self.branch2(output1)
181 | output_branch2 = F.interpolate(
182 | output_branch2,
183 | (output1.size()[2], output1.size()[3]),
184 | mode="bilinear",
185 | align_corners=self.align_corners,
186 | )
187 |
188 | output_branch3 = self.branch3(output1)
189 | output_branch3 = F.interpolate(
190 | output_branch3,
191 | (output1.size()[2], output1.size()[3]),
192 | mode="bilinear",
193 | align_corners=self.align_corners,
194 | )
195 |
196 | output_branch4 = self.branch4(output1)
197 | output_branch4 = F.interpolate(
198 | output_branch4,
199 | (output1.size()[2], output1.size()[3]),
200 | mode="bilinear",
201 | align_corners=self.align_corners,
202 | )
203 |
204 | output_feature = torch.cat(
205 | (output1, output_branch4, output_branch3, output_branch2, output_branch1), 1
206 | )
207 |
208 | output_feature = self.lastconv(output_feature)
209 |
210 | return output_feature
211 |
--------------------------------------------------------------------------------
/src/models/RefineNet2D.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | from __future__ import print_function
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | import math
14 | import argparse
15 | import time
16 | import torch.backends.cudnn as cudnn
17 |
18 | from models.PSMNet import conv2d
19 | from models.PSMNet import conv2d_lrelu
20 |
21 | from models.DispRefine2D import DispRefineNet
22 |
23 | __all__ = ["disprefinenet", "segrefinenet"]
24 |
25 |
26 | """
27 | Disparity refinement network.
28 | Takes the input image concatenated with the disparity map and generates a refined disparity map,
29 | using the input image as a guide.
30 | """
31 |
32 |
33 | def disprefinenet(options, data=None):
34 |
35 | print("==> USING DispRefineNet")
36 | for key in options:
37 | if "disprefinenet" in key:
38 | print("{} : {}".format(key, options[key]))
39 |
40 | model = DispRefineNet(out_planes=options["disprefinenet_out_planes"])
41 |
42 | if data is not None:
43 | model.load_state_dict(data["state_dict"])
44 |
45 | return model
46 |
47 |
48 | """
49 | Binary segmentation refinement network.
50 | Takes as input high-resolution features of the input image and the raw binary segmentation map.
51 | Generates a refined segmentation map using the image features as a guide.
52 | """
53 |
54 |
55 | class SegRefineNet(nn.Module):
56 | def __init__(self, in_planes=17, out_planes=8):
57 |
58 | super(SegRefineNet, self).__init__()
59 |
60 | self.conv1 = nn.Sequential(conv2d_lrelu(in_planes, out_planes, kernel_size=3, stride=1, pad=1))
61 |
62 | self.classif1 = nn.Conv2d(out_planes, 1, kernel_size=3, padding=1, stride=1, bias=False)
63 |
64 | for m in self.modules():
65 | if isinstance(m, nn.Conv2d):
66 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
67 | m.weight.data.normal_(0, math.sqrt(2.0 / n))
68 | elif isinstance(m, nn.Conv3d):
69 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
70 | m.weight.data.normal_(0, math.sqrt(2.0 / n))
71 | elif isinstance(m, nn.BatchNorm2d):
72 | m.weight.data.fill_(1)
73 | m.bias.data.zero_()
74 | elif isinstance(m, nn.BatchNorm3d):
75 | m.weight.data.fill_(1)
76 | m.bias.data.zero_()
77 | elif isinstance(m, nn.Linear):
78 | m.bias.data.zero_()
79 |
80 | def forward(self, input):
81 |
82 | output0 = self.conv1(input)
83 | output = self.classif1(output0)
84 |
85 | return output
86 |
87 |
88 | def segrefinenet(options, data=None):
89 |
90 | print("==> USING SegRefineNet")
91 | for key in options:
92 | if "segrefinenet" in key:
93 | print("{} : {}".format(key, options[key]))
94 |
95 | model = SegRefineNet(
96 | in_planes=options["segrefinenet_in_planes"], out_planes=options["segrefinenet_out_planes"]
97 | )
98 |
99 | if data is not None:
100 | model.load_state_dict(data["state_dict"])
101 |
102 | return model
103 |
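104 |
105 | if __name__ == "__main__":
106 |     # Illustrative shape check (not part of the original pipeline): SegRefineNet fuses a
107 |     # single-channel raw segmentation map with 16-channel full-resolution image features
108 |     # (17 input planes) and returns a refined single-channel map at the same resolution.
109 |     net = segrefinenet({"segrefinenet_in_planes": 17, "segrefinenet_out_planes": 8})
110 |     x = torch.rand(2, 17, 96, 128)
111 |     print(net(x).shape)  # torch.Size([2, 1, 96, 128])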
--------------------------------------------------------------------------------
/src/models/RefineNet3D.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import torch
10 | import torch.nn as nn
11 | import numpy as np
12 |
13 | __all__ = ["segregnet3d"]
14 |
15 | from models.GCNet import conv3d_relu
16 | from models.GCNet import deconv3d_relu
17 | from models.GCNet import feature3d
18 |
19 |
20 | def net_init(net):
21 |
22 | for m in net.modules():
23 | if isinstance(m, nn.Linear):
24 | m.weight.data = fanin_init(m.weight.data.size())
25 | elif isinstance(m, nn.Conv3d):
26 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
27 | m.weight.data.normal_(0, np.sqrt(2.0 / n))
28 | elif isinstance(m, nn.Conv2d):
29 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
30 | m.weight.data.normal_(0, np.sqrt(2.0 / n))
31 | elif isinstance(m, nn.Conv1d):
32 | n = m.kernel_size[0] * m.out_channels
33 | m.weight.data.normal_(0, np.sqrt(2.0 / n))
34 | elif isinstance(m, nn.BatchNorm3d):
35 | m.weight.data.fill_(1)
36 | m.bias.data.zero_()
37 | elif isinstance(m, nn.BatchNorm2d):
38 | m.weight.data.fill_(1)
39 | m.bias.data.zero_()
40 | elif isinstance(m, nn.BatchNorm1d):
41 | m.weight.data.fill_(1)
42 | m.bias.data.zero_()
43 |
44 |
45 | class SegRegNet3D(nn.Module):
46 | def __init__(self, F=16):
47 |
48 | super(SegRegNet3D, self).__init__()
49 |
50 | self.conf_preprocess = conv3d_relu(1, F, kernel_size=3, stride=1)
51 | self.layer3d = feature3d(F)
52 |
53 | net_init(self)
54 |
55 | def forward(self, fL, conf_volume):
56 |
57 | fL_stack = fL[:, :, None, :, :].repeat(1, 1, int(conf_volume.shape[2]), 1, 1)
58 | conf_vol_preprocess = self.conf_preprocess(conf_volume)
59 | input_volume = torch.cat((fL_stack, conf_vol_preprocess), dim=1)
60 | oL = self.layer3d(input_volume)
61 |
62 | return oL
63 |
64 |
65 | def segregnet3d(options, data=None):
66 |
67 | print("==> USING SegRegNet3D")
68 | for key in options:
69 | if "regnet" in key:
70 | print("{} : {}".format(key, options[key]))
71 |
72 | model = SegRegNet3D(F=options["regnet_out_planes"])
73 | if data is not None:
74 | model.load_state_dict(data["state_dict"])
75 |
76 | return model
77 |
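78 |
79 | if __name__ == "__main__":
80 |     # Illustrative shape check (not part of the original pipeline): SegRegNet3D regularizes a
81 |     # confidence volume of shape (B, 1, D, H, W) guided by 32-channel left-image features of
82 |     # shape (B, 32, H, W); D, H and W must be divisible by 16 to survive the four stride-2
83 |     # stages of the 3D hourglass.
84 |     net = segregnet3d({"regnet_out_planes": 16})
85 |     features_left = torch.rand(1, 32, 32, 48)
86 |     conf_volume = torch.rand(1, 1, 16, 32, 48)
87 |     print(net(features_left, conf_volume).shape)  # torch.Size([1, 1, 16, 32, 48])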
--------------------------------------------------------------------------------
/src/models/SegNet2D.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import torch
10 | import torch.nn as nn
11 | import argparse
12 | import math
13 | import torch.nn.functional as F
14 | import torch.backends.cudnn as cudnn
15 | import time
16 |
17 | __all__ = ["segnet2d"]
18 |
19 | # Util Functions
20 | def conv(in_planes, out_planes, kernel_size=3, stride=1, activefun=nn.LeakyReLU(0.1, inplace=True)):
21 |
22 | return nn.Sequential(
23 | nn.Conv2d(
24 | in_planes,
25 | out_planes,
26 | kernel_size=kernel_size,
27 | stride=stride,
28 | padding=(kernel_size - 1) // 2,
29 | bias=True,
30 | ),
31 | activefun,
32 | )
33 |
34 |
35 | def deconv(in_planes, out_planes, kernel_size=4, stride=2, activefun=nn.LeakyReLU(0.1, inplace=True)):
36 |
37 | return nn.Sequential(
38 | nn.ConvTranspose2d(
39 | in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=1, bias=True
40 | ),
41 | activefun,
42 | )
43 |
44 |
45 | class SegNet2D(nn.Module):
46 | def __init__(self):
47 |
48 | super(SegNet2D, self).__init__()
49 |
50 | self.activefun = nn.LeakyReLU(0.1, inplace=True)
51 |
52 | cps = [64, 128, 256, 512, 512, 512]
53 | dps = [512, 512, 256, 128, 64]
54 |
55 | # Encoder
56 | self.conv1 = conv(cps[0], cps[1], kernel_size=3, stride=2, activefun=self.activefun)
57 | self.conv1_1 = conv(cps[1], cps[1], kernel_size=3, stride=1, activefun=self.activefun)
58 |
59 | self.conv2 = conv(cps[1], cps[2], kernel_size=3, stride=2, activefun=self.activefun)
60 | self.conv2_1 = conv(cps[2], cps[2], kernel_size=3, stride=1, activefun=self.activefun)
61 |
62 | self.conv3 = conv(cps[2], cps[3], kernel_size=3, stride=2, activefun=self.activefun)
63 | self.conv3_1 = conv(cps[3], cps[3], kernel_size=3, stride=1, activefun=self.activefun)
64 |
65 | self.conv4 = conv(cps[3], cps[4], kernel_size=3, stride=2, activefun=self.activefun)
66 | self.conv4_1 = conv(cps[4], cps[4], kernel_size=3, stride=1, activefun=self.activefun)
67 |
68 | self.conv5 = conv(cps[4], cps[5], kernel_size=3, stride=2, activefun=self.activefun)
69 | self.conv5_1 = conv(cps[5], cps[5], kernel_size=3, stride=1, activefun=self.activefun)
70 |
71 | # Decoder
72 | self.deconv5 = deconv(cps[5], dps[0], kernel_size=4, stride=2, activefun=self.activefun)
73 | self.deconv5_1 = conv(dps[0] + cps[4], dps[0], kernel_size=3, stride=1, activefun=self.activefun)
74 |
75 | self.deconv4 = deconv(cps[4], dps[1], kernel_size=4, stride=2, activefun=self.activefun)
76 | self.deconv4_1 = conv(dps[1] + cps[3], dps[1], kernel_size=3, stride=1, activefun=self.activefun)
77 |
78 | self.deconv3 = deconv(dps[1], dps[2], kernel_size=4, stride=2, activefun=self.activefun)
79 | self.deconv3_1 = conv(dps[2] + cps[2], dps[2], kernel_size=3, stride=1, activefun=self.activefun)
80 |
81 | self.deconv2 = deconv(dps[2], dps[3], kernel_size=4, stride=2, activefun=self.activefun)
82 | self.deconv2_1 = conv(dps[3] + cps[1], dps[3], kernel_size=3, stride=1, activefun=self.activefun)
83 |
84 | self.deconv1 = deconv(dps[3], dps[4], kernel_size=4, stride=2, activefun=self.activefun)
85 | self.deconv1_1 = conv(dps[4] + cps[0], dps[4], kernel_size=3, stride=1, activefun=self.activefun)
86 |
87 | self.last_conv = nn.Conv2d(dps[4], 1, kernel_size=3, stride=1, padding=1, bias=True)
88 |
89 | # Init
90 | for m in self.modules():
91 | if isinstance(m, nn.Conv2d):
92 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
93 | m.weight.data.normal_(0, math.sqrt(2.0 / n))
94 | elif isinstance(m, nn.Conv3d):
95 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
96 | m.weight.data.normal_(0, math.sqrt(2.0 / n))
97 | elif isinstance(m, nn.BatchNorm2d):
98 | m.weight.data.fill_(1)
99 | m.bias.data.zero_()
100 | elif isinstance(m, nn.BatchNorm3d):
101 | m.weight.data.fill_(1)
102 | m.bias.data.zero_()
103 | elif isinstance(m, nn.Linear):
104 | m.bias.data.zero_()
105 |
106 | return
107 |
108 | def forward(self, x):
109 |
110 | out_conv0 = x
111 | out_conv1 = self.conv1_1(self.conv1(out_conv0))
112 | out_conv2 = self.conv2_1(self.conv2(out_conv1))
113 | out_conv3 = self.conv3_1(self.conv3(out_conv2))
114 | out_conv4 = self.conv4_1(self.conv4(out_conv3))
115 | out_conv5 = self.conv5_1(self.conv5(out_conv4))
116 |
117 | out_deconv5 = self.deconv5(out_conv5)
118 | out_deconv5_1 = self.deconv5_1(torch.cat((out_conv4, out_deconv5), 1))
119 |
120 | out_deconv4 = self.deconv4(out_deconv5_1)
121 | out_deconv4_1 = self.deconv4_1(torch.cat((out_conv3, out_deconv4), 1))
122 |
123 | out_deconv3 = self.deconv3(out_deconv4_1)
124 | out_deconv3_1 = self.deconv3_1(torch.cat((out_conv2, out_deconv3), 1))
125 |
126 | out_deconv2 = self.deconv2(out_deconv3_1)
127 | out_deconv2_1 = self.deconv2_1(torch.cat((out_conv1, out_deconv2), 1))
128 |
129 | out_deconv1 = self.deconv1(out_deconv2_1)
130 | out_deconv1_1 = self.deconv1_1(torch.cat((out_conv0, out_deconv1), 1))
131 |
132 | raw_seg = self.last_conv(out_deconv1_1)
133 |
134 | return raw_seg
135 |
136 |
137 | def segnet2d(options, data=None):
138 |
139 | print("==> USING SegNet2D")
140 | for key in options:
141 | if "segnet2d" in key:
142 | print("{} : {}".format(key, options[key]))
143 |
144 | model = SegNet2D()
145 |
146 | if data is not None:
147 | model.load_state_dict(data["state_dict"])
148 |
149 | return model
150 |
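151 |
152 | if __name__ == "__main__":
153 |     # Illustrative shape check (not part of the original pipeline): SegNet2D consumes the
154 |     # 64-channel cost-volume slices built in Bi3DNet (two stacked 32-channel feature maps)
155 |     # and returns a single-channel raw segmentation map at the same resolution. Height and
156 |     # width must be divisible by 32 (five stride-2 encoder stages).
157 |     net = segnet2d({})
158 |     x = torch.rand(1, 64, 192, 384)
159 |     print(net(x).shape)  # torch.Size([1, 1, 192, 384])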
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .Bi3DNet import *
2 | from .FeatExtractNet import *
3 | from .SegNet2D import *
4 | from .RefineNet2D import *
5 | from .RefineNet3D import *
6 | from .PSMNet import *
7 | from .GCNet import *
8 | from .DispRefine2D import *
9 |
10 |
--------------------------------------------------------------------------------
/src/project.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 110
3 | target-version = ['py37']
--------------------------------------------------------------------------------
/src/run_binary_depth_estimation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import argparse
10 | import os
11 | import torch
12 | import torchvision.transforms as transforms
13 | from PIL import Image
14 |
15 | import models
16 | import cv2
17 | import numpy as np
18 |
19 | from util import disp2rgb, str2bool
20 | import random
21 |
22 | model_names = sorted(name for name in models.__dict__ if name.islower() and not name.startswith("__"))
23 |
24 |
25 | # Parse arguments
26 | parser = argparse.ArgumentParser(allow_abbrev=False)
27 |
28 | # Model
29 | parser.add_argument("--arch", type=str, default="bi3dnet_binary_depth")
30 |
31 | parser.add_argument("--bi3dnet_featnet_arch", type=str, default="featextractnetspp")
32 | parser.add_argument("--bi3dnet_featnethr_arch", type=str, default="featextractnethr")
33 | parser.add_argument("--bi3dnet_segnet_arch", type=str, default="segnet2d")
34 | parser.add_argument("--bi3dnet_refinenet_arch", type=str, default="segrefinenet")
35 | parser.add_argument("--bi3dnet_max_disparity", type=int, default=192)
36 | parser.add_argument("--bi3dnet_disps_per_example_true", type=str2bool, default=True)
37 |
38 | parser.add_argument("--featextractnethr_out_planes", type=int, default=16)
39 | parser.add_argument("--segrefinenet_in_planes", type=int, default=17)
40 | parser.add_argument("--segrefinenet_out_planes", type=int, default=8)
41 |
42 | # Input
43 | parser.add_argument("--pretrained", type=str)
44 | parser.add_argument("--img_left", type=str)
45 | parser.add_argument("--img_right", type=str)
46 | parser.add_argument("--disp_vals", type=float, nargs="*")
47 | parser.add_argument("--crop_height", type=int)
48 | parser.add_argument("--crop_width", type=int)
49 |
50 | args, unknown = parser.parse_known_args()
51 |
52 | ####################################################################################################
53 | def main():
54 |
55 | options = vars(args)
56 | print("==> ALL PARAMETERS")
57 | for key in options:
58 | print("{} : {}".format(key, options[key]))
59 |
60 | out_dir = "out"
61 | if not os.path.isdir(out_dir):
62 | os.mkdir(out_dir)
63 |
64 | base_name = os.path.splitext(os.path.basename(args.img_left))[0]
65 |
66 | # Model
67 | network_data = torch.load(args.pretrained)
68 | print("=> using pre-trained model '{}'".format(args.arch))
69 | model = models.__dict__[args.arch](options, network_data).cuda()
70 |
71 | # Inputs
72 | img_left = Image.open(args.img_left).convert("RGB")
73 | img_left = transforms.functional.to_tensor(img_left)
74 | img_left = transforms.functional.normalize(img_left, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
75 | img_left = img_left.type(torch.cuda.FloatTensor)[None, :, :, :]
76 | img_right = Image.open(args.img_right).convert("RGB")
77 | img_right = transforms.functional.to_tensor(img_right)
78 | img_right = transforms.functional.normalize(img_right, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
79 | img_right = img_right.type(torch.cuda.FloatTensor)[None, :, :, :]
80 |
81 | segs = []
82 | for disp_val in args.disp_vals:
83 |
84 | assert disp_val % 3 == 0, "disparity value should be a multiple of 3 as we downsample the image by 3"
85 | disp_long = torch.Tensor([[disp_val / 3]]).type(torch.LongTensor).cuda()
86 |
87 | # Pad inputs
88 | tw = args.crop_width
89 | th = args.crop_height
90 | assert tw % 96 == 0, "image dimensions should be a multiple of 96"
91 | assert th % 96 == 0, "image dimensions should be a multiple of 96"
92 | h = img_left.shape[2]
93 | w = img_left.shape[3]
94 | x1 = random.randint(0, max(0, w - tw))
95 | y1 = random.randint(0, max(0, h - th))
96 | pad_w = tw - w if tw - w > 0 else 0
97 | pad_h = th - h if th - h > 0 else 0
98 | pad_opr = torch.nn.ZeroPad2d((pad_w, 0, pad_h, 0))
99 | img_left = img_left[:, :, y1 : y1 + min(th, h), x1 : x1 + min(tw, w)]
100 | img_right = img_right[:, :, y1 : y1 + min(th, h), x1 : x1 + min(tw, w)]
101 | img_left_pad = pad_opr(img_left)
102 | img_right_pad = pad_opr(img_right)
103 |
104 | # Inference
105 | model.eval()
106 | with torch.no_grad():
107 | output = model(img_left_pad, img_right_pad, disp_long)[1][:, :, pad_h:, pad_w:]
108 |
109 | # Write binary depth results
110 | seg_img = output[0, 0][None, :, :].clone().cpu().detach().numpy()
111 | seg_img = np.transpose(seg_img * 255.0, (1, 2, 0))
112 | cv2.imwrite(
113 | os.path.join(out_dir, "%s_%s_seg_confidence_%d.png" % (base_name, args.arch, disp_val)), seg_img
114 | )
115 |
116 | segs.append(output[0, 0][None, :, :].clone().cpu().detach().numpy())
117 |
118 | # Generate quantized depth results
119 | segs = np.concatenate(segs, axis=0)
120 | segs = np.insert(segs, 0, np.ones((1, h, w), dtype=np.float32), axis=0)
121 | segs = np.append(segs, np.zeros((1, h, w), dtype=np.float32), axis=0)
122 |
123 | segs = 1.0 - segs
124 |
125 | # Get the pdf values for each segmented region
126 | pdf_method = segs[1:, :, :] - segs[:-1, :, :]
127 |
128 | # Get the labels
129 |     labels_method = np.argmax(pdf_method, axis=0).astype(np.int64)  # np.int was removed in NumPy 1.24
130 | disp_map = labels_method.astype(np.float32)
131 |
132 | disp_vals = args.disp_vals
133 | disp_vals.insert(0, 0)
134 | disp_vals.append(args.bi3dnet_max_disparity)
135 |
136 | for i in range(len(disp_vals) - 1):
137 | min_disp = disp_vals[i]
138 | max_disp = disp_vals[i + 1]
139 | mid_disp = 0.5 * (min_disp + max_disp)
140 | disp_map[labels_method == i] = mid_disp
141 |
142 | disp_vals_str_list = ["%d" % disp_val for disp_val in disp_vals]
143 | disp_vals_str = "-".join(disp_vals_str_list)
144 |
145 | img_disp = np.clip(disp_map, 0, args.bi3dnet_max_disparity)
146 | img_disp = img_disp / args.bi3dnet_max_disparity
147 | img_disp = (disp2rgb(img_disp) * 255.0).astype(np.uint8)
148 |
149 | cv2.imwrite(
150 | os.path.join(out_dir, "%s_%s_quant_depth_%s.png" % (base_name, args.arch, disp_vals_str)), img_disp
151 | )
152 |
153 | return
154 |
155 |
156 | if __name__ == "__main__":
157 | main()
158 |
--------------------------------------------------------------------------------
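run_binary_depth_estimation.py combines the per-plane confidences into a quantized disparity map by stacking them between an all-ones and an all-zeros plane, differencing the resulting CDF-like volume to get per-bin probabilities, taking the per-pixel argmax, and assigning the midpoint of the winning disparity bin. A standalone NumPy sketch of that combination step, with hand-made toy confidences in place of the network output:

    import numpy as np

    disp_vals = [12, 24, 36]                 # binary-depth planes (--disp_vals)
    max_disparity = 192                      # --bi3dnet_max_disparity
    h, w = 1, 2

    # One confidence plane per disparity value: P(pixel is closer than the plane).
    segs = np.array([
        [[0.9, 0.1]],                        # plane d = 12
        [[0.8, 0.05]],                       # plane d = 24
        [[0.2, 0.01]],                       # plane d = 36
    ], dtype=np.float32)

    # Bound the stack: everything is closer than 0, nothing is closer than max_disparity.
    segs = np.insert(segs, 0, np.ones((1, h, w), dtype=np.float32), axis=0)
    segs = np.append(segs, np.zeros((1, h, w), dtype=np.float32), axis=0)

    cdf = 1.0 - segs                         # P(disparity < plane)
    pdf = cdf[1:] - cdf[:-1]                 # probability of each disparity bin
    labels = np.argmax(pdf, axis=0)          # most likely bin per pixel

    bounds = [0] + disp_vals + [max_disparity]
    disp_map = np.zeros((h, w), dtype=np.float32)
    for i in range(len(bounds) - 1):
        disp_map[labels == i] = 0.5 * (bounds[i] + bounds[i + 1])

    print(disp_map)                          # -> [[30.  6.]] for the toy confidences above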
/src/run_continuous_depth_estimation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import argparse
10 | import os
11 | import time
12 | import torch
13 | import torchvision.transforms as transforms
14 | from PIL import Image
15 |
16 | import models
17 | import cv2
18 | import numpy as np
19 | from util import disp2rgb, str2bool
20 |
21 | import random
22 |
23 | model_names = sorted(name for name in models.__dict__ if name.islower() and not name.startswith("__"))
24 |
25 |
26 | # Parse Arguments
27 | parser = argparse.ArgumentParser(allow_abbrev=False)
28 |
29 | # Experiment Type
30 | parser.add_argument("--arch", type=str, default="bi3dnet_continuous_depth_2D")
31 |
32 | parser.add_argument("--bi3dnet_featnet_arch", type=str, default="featextractnetspp")
33 | parser.add_argument("--bi3dnet_segnet_arch", type=str, default="segnet2d")
34 | parser.add_argument("--bi3dnet_refinenet_arch", type=str, default="disprefinenet")
35 | parser.add_argument("--bi3dnet_regnet_arch", type=str, default="segregnet3d")
36 | parser.add_argument("--bi3dnet_max_disparity", type=int, default=192)
37 | parser.add_argument("--regnet_out_planes", type=int, default=16)
38 | parser.add_argument("--disprefinenet_out_planes", type=int, default=32)
39 | parser.add_argument("--bi3dnet_disps_per_example_true", type=str2bool, default=True)
40 |
41 | # Input
42 | parser.add_argument("--pretrained", type=str)
43 | parser.add_argument("--img_left", type=str)
44 | parser.add_argument("--img_right", type=str)
45 | parser.add_argument("--disp_range_min", type=int)
46 | parser.add_argument("--disp_range_max", type=int)
47 | parser.add_argument("--crop_height", type=int)
48 | parser.add_argument("--crop_width", type=int)
49 |
50 | args, unknown = parser.parse_known_args()
51 |
52 | ##############################################################################################################
53 | def main():
54 |
55 | options = vars(args)
56 | print("==> ALL PARAMETERS")
57 | for key in options:
58 | print("{} : {}".format(key, options[key]))
59 |
60 | out_dir = "out"
61 | if not os.path.isdir(out_dir):
62 | os.mkdir(out_dir)
63 |
64 | base_name = os.path.splitext(os.path.basename(args.img_left))[0]
65 |
66 | # Model
67 | if args.pretrained:
68 | network_data = torch.load(args.pretrained)
69 | else:
70 | print("Need an input model")
71 | exit()
72 |
73 | print("=> using pre-trained model '{}'".format(args.arch))
74 | model = models.__dict__[args.arch](options, network_data).cuda()
75 |
76 | # Inputs
77 | img_left = Image.open(args.img_left).convert("RGB")
78 | img_right = Image.open(args.img_right).convert("RGB")
79 | img_left = transforms.functional.to_tensor(img_left)
80 | img_right = transforms.functional.to_tensor(img_right)
81 | img_left = transforms.functional.normalize(img_left, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
82 | img_right = transforms.functional.normalize(img_right, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
83 | img_left = img_left.type(torch.cuda.FloatTensor)[None, :, :, :]
84 | img_right = img_right.type(torch.cuda.FloatTensor)[None, :, :, :]
85 |
86 | # Prepare Disparities
87 | max_disparity = args.disp_range_max
88 | min_disparity = args.disp_range_min
89 |
90 | assert max_disparity % 3 == 0 and min_disparity % 3 == 0, "disparities should be divisible by 3"
91 |
92 | if args.arch == "bi3dnet_continuous_depth_3D":
93 | assert (
94 | max_disparity - min_disparity
95 | ) % 48 == 0, "for 3D regularization the difference in disparities should be divisible by 48"
96 |
97 | max_disp_levels = (max_disparity - min_disparity) + 1
98 |
99 | max_disparity_3x = int(max_disparity / 3)
100 | min_disparity_3x = int(min_disparity / 3)
101 | max_disp_levels_3x = (max_disparity_3x - min_disparity_3x) + 1
102 | disp_3x = np.linspace(min_disparity_3x, max_disparity_3x, max_disp_levels_3x, dtype=np.int32)
103 | disp_long_3x_main = torch.from_numpy(disp_3x).type(torch.LongTensor).cuda()
104 | disp_float_main = np.linspace(min_disparity, max_disparity, max_disp_levels, dtype=np.float32)
105 | disp_float_main = torch.from_numpy(disp_float_main).type(torch.float32).cuda()
106 | delta = 1
107 | d_min_GT = min_disparity - 0.5 * delta
108 | d_max_GT = max_disparity + 0.5 * delta
109 | disp_long_3x = disp_long_3x_main[None, :].expand(img_left.shape[0], -1)
110 | disp_float = disp_float_main[None, :].expand(img_left.shape[0], -1)
111 |
112 | # Pad Inputs
113 | tw = args.crop_width
114 | th = args.crop_height
115 |     assert tw % 96 == 0, "image dimensions should be a multiple of 96"
116 |     assert th % 96 == 0, "image dimensions should be a multiple of 96"
117 | h = img_left.shape[2]
118 | w = img_left.shape[3]
119 | x1 = random.randint(0, max(0, w - tw))
120 | y1 = random.randint(0, max(0, h - th))
121 | pad_w = tw - w if tw - w > 0 else 0
122 | pad_h = th - h if th - h > 0 else 0
123 | pad_opr = torch.nn.ZeroPad2d((pad_w, 0, pad_h, 0))
124 | img_left = img_left[:, :, y1 : y1 + min(th, h), x1 : x1 + min(tw, w)]
125 | img_right = img_right[:, :, y1 : y1 + min(th, h), x1 : x1 + min(tw, w)]
126 | img_left_pad = pad_opr(img_left)
127 | img_right_pad = pad_opr(img_right)
128 |
129 | # Inference
130 | model.eval()
131 | with torch.no_grad():
132 | if args.arch == "bi3dnet_continuous_depth_2D":
133 | output_seg_low_res_upsample, output_disp_normalized = model(
134 | img_left_pad, img_right_pad, disp_long_3x
135 | )
136 | output_seg = output_seg_low_res_upsample
137 | else:
138 | (
139 | output_seg_low_res_upsample,
140 | output_seg_low_res_upsample_refined,
141 | output_disp_normalized_no_reg,
142 | output_disp_normalized,
143 | ) = model(img_left_pad, img_right_pad, disp_long_3x)
144 | output_seg = output_seg_low_res_upsample_refined
145 |
146 | output_seg = output_seg[:, :, pad_h:, pad_w:]
147 | output_disp_normalized = output_disp_normalized[:, :, pad_h:, pad_w:]
148 | output_disp = torch.clamp(
149 | output_disp_normalized * delta * max_disp_levels + d_min_GT, min=d_min_GT, max=d_max_GT
150 | )
151 |
152 | # Write Results
153 | max_disparity_color = 192
154 | output_disp_clamp = output_disp[0, 0, :, :].cpu().clone().numpy()
155 | output_disp_clamp[output_disp_clamp < min_disparity] = 0
156 | output_disp_clamp[output_disp_clamp > max_disparity] = max_disparity_color
157 | disp_np_ours_color = disp2rgb(output_disp_clamp / max_disparity_color) * 255.0
158 | cv2.imwrite(
159 | os.path.join(out_dir, "%s_%s_%d_%d.png" % (base_name, args.arch, min_disparity, max_disparity)),
160 | disp_np_ours_color,
161 | )
162 |
163 | return
164 |
165 |
166 | if __name__ == "__main__":
167 | main()
168 |
--------------------------------------------------------------------------------
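run_continuous_depth_estimation.py enumerates the candidate disparity planes at one third of the input resolution and later converts the network's normalized output back to pixel disparities via output_disp = clamp(output_disp_normalized * delta * max_disp_levels + d_min_GT, d_min_GT, d_max_GT). A small sketch of that bookkeeping for one example range (plain arithmetic, no model involved):

    import numpy as np

    min_disparity, max_disparity = 12, 48     # --disp_range_min / --disp_range_max, multiples of 3
    delta = 1

    # Candidate planes fed to the network, defined at 1/3 of the input resolution.
    disp_3x = np.arange(min_disparity // 3, max_disparity // 3 + 1, dtype=np.int32)  # planes 4..16
    max_disp_levels = (max_disparity - min_disparity) + 1                            # 37

    # Disparity range representable after de-normalization.
    d_min_GT = min_disparity - 0.5 * delta    # 11.5
    d_max_GT = max_disparity + 0.5 * delta    # 48.5

    # Map a normalized prediction in [0, 1] back to pixel disparity.
    disp_normalized = 0.5                     # stand-in for the network output
    disp = np.clip(disp_normalized * delta * max_disp_levels + d_min_GT, d_min_GT, d_max_GT)  # 30.0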
/src/run_demo_kitti15.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # GENERATE BINARY DEPTH SEGMENTATIONS AND COMBINE THEM TO GENERATE QUANTIZED DEPTH
4 | CUDA_VISIBLE_DEVICES=0 python run_binary_depth_estimation.py \
5 | --arch bi3dnet_binary_depth \
6 | --bi3dnet_featnet_arch featextractnetspp \
7 | --bi3dnet_featnethr_arch featextractnethr \
8 | --bi3dnet_segnet_arch segnet2d \
9 | --bi3dnet_refinenet_arch segrefinenet \
10 | --featextractnethr_out_planes 16 \
11 | --segrefinenet_in_planes 17 \
12 | --segrefinenet_out_planes 8 \
13 | --crop_height 384 --crop_width 1248 \
14 | --disp_vals 12 21 30 39 48 \
15 | --img_left '../data/kitti15_img_left.jpg' \
16 | --img_right '../data/kitti15_img_right.jpg' \
17 | --pretrained '../model_weights/kitti15_binary_depth.pth.tar'
18 |
19 |
20 | # FULL RANGE CONTINUOUS DEPTH ESTIMATION WITHOUT 3D REGULARIZATION
21 | CUDA_VISIBLE_DEVICES=0 python run_continuous_depth_estimation.py \
22 | --arch bi3dnet_continuous_depth_2D \
23 | --bi3dnet_featnet_arch featextractnetspp \
24 | --bi3dnet_segnet_arch segnet2d \
25 | --bi3dnet_refinenet_arch disprefinenet \
26 | --disprefinenet_out_planes 32 \
27 | --crop_height 384 --crop_width 1248 \
28 | --disp_range_min 0 \
29 | --disp_range_max 192 \
30 | --bi3dnet_max_disparity 192 \
31 | --img_left '../data/kitti15_img_left.jpg' \
32 | --img_right '../data/kitti15_img_right.jpg' \
33 | --pretrained '../model_weights/kitti15_continuous_depth_no_conf_reg.pth.tar'
34 |
35 |
36 | # SELECTIVE RANGE CONTINUOUS DEPTH ESTIMATION WITHOUT 3D REGULARIZATION
37 | CUDA_VISIBLE_DEVICES=0 python run_continuous_depth_estimation.py \
38 | --arch bi3dnet_continuous_depth_2D \
39 | --bi3dnet_featnet_arch featextractnetspp \
40 | --bi3dnet_segnet_arch segnet2d \
41 | --bi3dnet_refinenet_arch disprefinenet \
42 | --disprefinenet_out_planes 32 \
43 | --crop_height 384 --crop_width 1248 \
44 | --disp_range_min 12 \
45 | --disp_range_max 48 \
46 | --bi3dnet_max_disparity 192 \
47 | --img_left '../data/kitti15_img_left.jpg' \
48 | --img_right '../data/kitti15_img_right.jpg' \
49 | --pretrained '../model_weights/kitti15_continuous_depth_no_conf_reg.pth.tar'
50 |
51 |
52 | # FULL RANGE CONTINUOUS DEPTH ESTIMATION WITH 3D REGULARIZATION
53 | CUDA_VISIBLE_DEVICES=0 python run_continuous_depth_estimation.py \
54 | --arch bi3dnet_continuous_depth_3D \
55 | --bi3dnet_featnet_arch featextractnetspp \
56 | --bi3dnet_segnet_arch segnet2d \
57 | --bi3dnet_refinenet_arch disprefinenet \
58 | --bi3dnet_regnet_arch segregnet3d \
59 | --disprefinenet_out_planes 32 \
60 | --regnet_out_planes 16 \
61 | --crop_height 384 --crop_width 1248 \
62 | --disp_range_min 0 \
63 | --disp_range_max 192 \
64 | --bi3dnet_max_disparity 192 \
65 | --img_left '../data/kitti15_img_left.jpg' \
66 | --img_right '../data/kitti15_img_right.jpg' \
67 | --pretrained '../model_weights/kitti15_continuous_depth_conf_reg.pth.tar'
68 |
--------------------------------------------------------------------------------
/src/run_demo_sf.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # GENERATE BINARY DEPTH SEGMENTATIONS AND COMBINE THEM TO GENERATE QUANTIZED DEPTH
4 | CUDA_VISIBLE_DEVICES=0 python run_binary_depth_estimation.py \
5 | --arch bi3dnet_binary_depth \
6 | --bi3dnet_featnet_arch featextractnetspp \
7 | --bi3dnet_featnethr_arch featextractnethr \
8 | --bi3dnet_segnet_arch segnet2d \
9 | --bi3dnet_refinenet_arch segrefinenet \
10 | --featextractnethr_out_planes 16 \
11 | --segrefinenet_in_planes 17 \
12 | --segrefinenet_out_planes 8 \
13 | --crop_height 576 --crop_width 960 \
14 | --disp_vals 24 36 54 96 144 \
15 | --img_left '../data/sf_img_left.jpg' \
16 | --img_right '../data/sf_img_right.jpg' \
17 | --pretrained '../model_weights/sf_binary_depth.pth.tar'
18 |
19 |
20 | # FULL RANGE CONTINUOUS DEPTH ESTIMATION WITHOUT 3D REGULARIZATION
21 | CUDA_VISIBLE_DEVICES=0 python run_continuous_depth_estimation.py \
22 | --arch bi3dnet_continuous_depth_2D \
23 | --bi3dnet_featnet_arch featextractnetspp \
24 | --bi3dnet_segnet_arch segnet2d \
25 | --bi3dnet_refinenet_arch disprefinenet \
26 | --disprefinenet_out_planes 32 \
27 | --crop_height 576 --crop_width 960 \
28 | --disp_range_min 0 \
29 | --disp_range_max 192 \
30 | --bi3dnet_max_disparity 192 \
31 | --img_left '../data/sf_img_left.jpg' \
32 | --img_right '../data/sf_img_right.jpg' \
33 | --pretrained '../model_weights/sf_continuous_depth_no_conf_reg.pth.tar'
34 |
35 |
36 | # SELECTIVE RANGE CONTINUOUS DEPTH ESTIMATION WITHOUT 3D REGULARIZATION
37 | CUDA_VISIBLE_DEVICES=0 python run_continuous_depth_estimation.py \
38 | --arch bi3dnet_continuous_depth_2D \
39 | --bi3dnet_featnet_arch featextractnetspp \
40 | --bi3dnet_segnet_arch segnet2d \
41 | --bi3dnet_refinenet_arch disprefinenet \
42 | --disprefinenet_out_planes 32 \
43 | --crop_height 576 --crop_width 960 \
44 | --disp_range_min 18 \
45 | --disp_range_max 60 \
46 | --bi3dnet_max_disparity 192 \
47 | --img_left '../data/sf_img_left.jpg' \
48 | --img_right '../data/sf_img_right.jpg' \
49 | --pretrained '../model_weights/sf_continuous_depth_no_conf_reg.pth.tar'
50 |
51 |
52 | # FULL RANGE CONTINUOUS DEPTH ESTIMATION WITH 3D REGULARIZATION
53 | CUDA_VISIBLE_DEVICES=0 python run_continuous_depth_estimation.py \
54 | --arch bi3dnet_continuous_depth_3D \
55 | --bi3dnet_featnet_arch featextractnetspp \
56 | --bi3dnet_segnet_arch segnet2d \
57 | --bi3dnet_refinenet_arch disprefinenet \
58 | --bi3dnet_regnet_arch segregnet3d \
59 | --disprefinenet_out_planes 32 \
60 | --regnet_out_planes 16 \
61 | --crop_height 576 --crop_width 960 \
62 | --disp_range_min 0 \
63 | --disp_range_max 192 \
64 | --bi3dnet_max_disparity 192 \
65 | --img_left '../data/sf_img_left.jpg' \
66 | --img_right '../data/sf_img_right.jpg' \
67 | --pretrained '../model_weights/sf_continuous_depth_conf_reg.pth.tar'
68 |
--------------------------------------------------------------------------------
/src/util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import os
10 | import numpy as np
11 |
12 |
13 | def disp2rgb(disp):
14 | H = disp.shape[0]
15 | W = disp.shape[1]
16 |
17 | I = disp.flatten()
18 |
19 | map = np.array(
20 | [
21 | [0, 0, 0, 114],
22 | [0, 0, 1, 185],
23 | [1, 0, 0, 114],
24 | [1, 0, 1, 174],
25 | [0, 1, 0, 114],
26 | [0, 1, 1, 185],
27 | [1, 1, 0, 114],
28 | [1, 1, 1, 0],
29 | ]
30 | )
31 | bins = map[:-1, 3]
32 | cbins = np.cumsum(bins)
33 | bins = bins / cbins[-1]
34 | cbins = cbins[:-1] / cbins[-1]
35 |
36 | ind = np.minimum(
37 | np.sum(np.repeat(I[None, :], 6, axis=0) > np.repeat(cbins[:, None], I.shape[0], axis=1), axis=0), 6
38 | )
39 | bins = np.reciprocal(bins)
40 | cbins = np.append(np.array([[0]]), cbins[:, None])
41 |
42 | I = np.multiply(I - cbins[ind], bins[ind])
43 | I = np.minimum(
44 | np.maximum(
45 | np.multiply(map[ind, 0:3], np.repeat(1 - I[:, None], 3, axis=1))
46 | + np.multiply(map[ind + 1, 0:3], np.repeat(I[:, None], 3, axis=1)),
47 | 0,
48 | ),
49 | 1,
50 | )
51 |
52 | I = np.reshape(I, [H, W, 3]).astype(np.float32)
53 |
54 | return I
55 |
56 |
57 | def str2bool(bool_input_string):
58 | if isinstance(bool_input_string, bool):
59 | return bool_input_string
60 |     if bool_input_string.lower() in ("true",):
61 |         return True
62 |     elif bool_input_string.lower() in ("false",):
63 |         return False
64 |     else:
65 |         raise ValueError("Expected a boolean value ('true' or 'false').")
66 |
--------------------------------------------------------------------------------
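disp2rgb expects an H x W disparity map already normalized to [0, 1] and returns an H x W x 3 float32 image in [0, 1] using the KITTI-style disparity colour map defined above. A minimal usage sketch mirroring the run scripts (the random array is just a stand-in for a real disparity map):

    import numpy as np
    import cv2

    from util import disp2rgb

    max_disparity = 192
    disp = np.random.uniform(0, max_disparity, size=(384, 1248)).astype(np.float32)

    rgb = disp2rgb(np.clip(disp, 0, max_disparity) / max_disparity)   # H x W x 3, float32 in [0, 1]
    cv2.imwrite("disp_color.png", (rgb * 255.0).astype(np.uint8))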