├── .gitignore
├── LICENSE
├── README.md
├── dataset
│   ├── __init__.py
│   └── dataset_360D.py
├── exporters
│   ├── __init__.py
│   └── image.py
├── filesystem
│   └── file_utils.py
├── infer.py
├── models
│   ├── __init__.py
│   ├── modules.py
│   └── resnet360.py
├── spherical
│   ├── __init__.py
│   ├── cartesian.py
│   ├── derivatives.py
│   ├── grid.py
│   └── weights.py
├── supervision
│   ├── __init__.py
│   ├── direct.py
│   ├── photometric.py
│   ├── smoothness.py
│   ├── splatting.py
│   └── ssim.py
├── test.py
├── train_lr.py
├── train_sv.py
├── train_tc.py
├── train_ud.py
└── utils
    ├── __init__.py
    ├── checkpoint.py
    ├── framework.py
    ├── init.py
    ├── meters.py
    ├── opt.py
    └── visualization.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 2-Clause License
2 |
3 | Copyright (c) 2019, Visual Computing Lab, Information Technologies Institute, Centre for Research and Technology Hellas
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Spherical View Synthesis for Self-Supervised 360° Depth Estimation
2 |
3 | [Paper](https://arxiv.org/pdf/1909.08112.pdf)
4 | [Conference](http://3dv19.gel.ulaval.ca/)
5 | [Project Page](https://vcl3d.github.io/SphericalViewSynthesis/)
6 | ___
7 |
8 | # Data
9 |
10 | > An updated dataset is now available that fixes a critical issue with 3D60: the lighting bias introduced by the light source placed at the origin. More information can be found at the [Pano3D project page](https://vcl3d.github.io/Pano3D/).
11 | >
12 | The 360° stereo data used to train the self-supervised models are available [here](https://vcl3d.github.io/3D60/) and are part of a larger dataset __\[[1](#OmniDepth), [2](#HyperSphere)\]__ that contains rendered color images, depth maps, and normal maps for each viewpoint in a trinocular setup.
13 |
14 | ___
15 |
16 | ## Train
17 | Training code to reproduce our experiments is available in this repository.
18 | 
19 | A separate training script is provided for each variant:
20 |
21 | * [`train_ud.py`](./train_ud.py) for vertical stereo (__UD__) training
22 | * [`train_lr.py`](./train_lr.py) for horizontal stereo (__LR__) training
23 | * [`train_tc.py`](./train_tc.py) for trinocular stereo (__TC__) training, using the `photo_ratio` argument to train the different __TC__ variants.
24 | * [`train_sv.py`](./train_sv.py) for supervised (__SV__) training
25 |
26 | The PyTorch implementation of the differentiable depth-image-based forward rendering ([_`splatting`_](./supervision/splatting.py#L9)), presented in __\[[3](#LSI)\]__ and originally implemented in [TensorFlow](https://github.com/google/layered-scene-inference), is also [available](./supervision/splatting.py#L73).
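
The snippet below is a minimal sketch (not part of the repository) of how the renderer can be driven; the identity `coords` warp is only a placeholder for the forward-projected target coordinates that would normally be computed from depth and baseline:

```python
import torch
from supervision import splatting

b, c, h, w = 2, 3, 256, 512
img = torch.rand(b, c, h, w)           # source color image
depth = torch.rand(b, 1, h, w) * 10.0  # per-pixel spherical depth (radius)

# target (u, v) coordinates for every source pixel; identity warp as a placeholder
vs, us = torch.meshgrid(torch.arange(h), torch.arange(w))
coords = torch.stack((us, vs), dim=0).unsqueeze(0).expand(b, -1, -1, -1).float()

recon, mask = splatting.render(img, depth, coords, max_depth=20.0)  # [B, 3, H, W], [B, 1, H, W]
```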
27 |
28 | ## Test
29 |
30 | Our evaluation script [`test.py`](./test.py) adapts the metrics calculation to spherical data through [spherical weighting](./spherical/weights.py#L8) and [spiral sampling](./test.py#L92).
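
As a hedged illustration (not repository code) of the spherical weighting idea, a per-pixel error map can be averaged with confidence weights that fade towards the poles:

```python
import torch
import spherical

sgrid = spherical.create_spherical_grid(width=512)   # [1, 2, 256, 512] of (phi, theta)
weights = spherical.spherical_confidence(sgrid)      # in [0, 1], lowest at the poles
error = torch.rand_like(weights)                     # stand-in for |pred - gt|
weighted_mae = (error * weights).sum() / weights.sum()
```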
31 |
32 | ## Pre-trained Models
33 | Our PyTorch pre-trained models (corresponding to those reported in the paper) are available at our [releases](https://github.com/VCL3D/SphericalViewSynthesis/releases) and include the following variants:
34 |
35 | * [__UD__ @ epoch 16](https://github.com/VCL3D/SphericalViewSynthesis/releases/download/UD/ud.pt)
36 | * [__TC8__ @ epoch 16](https://github.com/VCL3D/SphericalViewSynthesis/releases/download/TC8/tc8.pt)
37 | * [__TC6__ @ epoch 28](https://github.com/VCL3D/SphericalViewSynthesis/releases/download/TC6/tc6.pt)
38 | * [__TC4__ @ epoch 17](https://github.com/VCL3D/SphericalViewSynthesis/releases/download/TC4/tc4.pt)
39 | * [__TC2__ @ epoch 20](https://github.com/VCL3D/SphericalViewSynthesis/releases/download/TC2/tc2.pt)
40 | * [__LR__ @ epoch 18](https://github.com/VCL3D/SphericalViewSynthesis/releases/download/LR/lr.pt)
41 | * [__SV__ @ epoch 24](https://github.com/VCL3D/SphericalViewSynthesis/releases/download/SV/sv.pt)
42 |
43 | ___
44 |
45 | ## Citation
46 | If you use this code and/or data, please cite the following:
47 | ```
48 | @inproceedings{zioulis2019spherical,
49 |   author = "Zioulis, Nikolaos and Karakottas, Antonis and Zarpalas, Dimitris and Alvarez, Federico and Daras, Petros",
50 |   title = "Spherical View Synthesis for Self-Supervised $360^o$ Depth Estimation",
51 |   booktitle = "International Conference on 3D Vision (3DV)",
52 |   month = "September",
53 |   year = "2019"
54 | }
55 | ```
56 |
57 |
58 | # References
59 | __\[[1](https://vcl.iti.gr/360-dataset)\]__ Zioulis, N.__\*__, Karakottas, A.__\*__, Zarpalas, D., and Daras, P. (2018). [Omnidepth: Dense depth estimation for indoors spherical panoramas](https://arxiv.org/pdf/1807.09620.pdf). In Proceedings of the European Conference on Computer Vision (ECCV).
60 |
61 | __\[[2](https://vcl3d.github.io/HyperSphereSurfaceRegression/)\]__ Karakottas, A., Zioulis, N., Samaras, S., Ataloglou, D., Gkitsas, V., Zarpalas, D., and Daras, P. (2019). [360° Surface Regression with a Hyper-sphere Loss](https://arxiv.org/pdf/1909.07043.pdf). In Proceedings of the International Conference on 3D Vision (3DV).
62 |
63 | __[3]__ Tulsiani, S., Tucker, R., and Snavely, N. (2018). [Layer-structured 3d scene inference via view synthesis](https://arxiv.org/pdf/1807.10264.pdf). In Proceedings of the European Conference on Computer Vision (ECCV).
64 |
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset_360D import *
2 |
--------------------------------------------------------------------------------
/dataset/dataset_360D.py:
--------------------------------------------------------------------------------
1 | ###################################
2 | # 360 dataset pytorch dataloader
3 | ###################################
4 | import os
5 |
6 | import numpy as np
7 | import cv2
8 | import PIL.Image as Image
9 | import datetime
10 |
11 | import torch
12 | from torch.utils.data import Dataset
13 | from torchvision import transforms
14 | # Ignore warnings
15 | import warnings
16 | warnings.filterwarnings("ignore")
17 |
18 | ############################################################################################################
19 | # We use a text file to hold our dataset's filenames
20 | # The "filenames list" has the following format
21 | #
22 | # path/to/Left/rgb.png path/to/Right/rgb.png path/to/Up/rgb.png path/to/Left/depth.exr path/to/Right/depth.exr path/to/Up/depth.exr
23 | #
24 | # We also have a Trinocular version, but you get the idea.
25 | #############################################################################################################
26 |
27 | class Dataset360D(Dataset):
28 | #360D Dataset#
29 | def __init__(self, filenamesFile, delimiter, mode, inputShape, transform=None, rescaled=False):
30 | #########################################################################################################
31 | # Arguments:
32 | # -filenamesFile: Absolute path to the aforementioned filenames .txt file
33 | # -transform : (Optional) transform to be applied on a sample
34 | # -mode : Dataset mode. Available options: mono, lr (Left-Right), ud (Up-Down), tc (Trinocular)
35 | #########################################################################################################
36 | self.height = inputShape[0]
37 | self.width = inputShape[1]
38 | self.sample = {} # one dataset sample (dictionary)
39 |         self.resize2 = transforms.Resize([128, 256]) # resize to half resolution (assumes a 256x512 input)
40 |         self.resize4 = transforms.Resize([64, 128]) # resize to quarter resolution (assumes a 256x512 input)
41 |         self.pilToTensor = transforms.ToTensor() if transform is None else transforms.Compose(
42 |             [
43 |                 transforms.ToTensor(), # function to convert pillow image to tensor
44 |                 transform
45 |             ]
46 |         )
47 | self.filenamesFilePath = filenamesFile # file containing image paths to load
48 | self.delimiter = delimiter # delimiter in filenames file
49 | self.mode = mode # dataset mode
50 | self.initDict(self.mode) # initializes dictionary with filepaths
51 | self.loadFilenamesFile() # loads filepaths to dictionary
52 | self.rescaled = rescaled
53 |
54 | # Check if given dataset mode is correct
55 | # Available modes: mono, lr, ud, tc
56 | def checkMode(self, mode):
57 | accepted = False
58 | if (mode != "mono" and mode != "lr" and mode != "ud" and mode != "tc"):
59 | print("{} | Given dataset mode [{}] is not known. Available modes: mono, lr, ud, tc".format(datetime.datetime.now(), mode))
60 | exit()
61 | else:
62 | accepted = True
63 | return accepted
64 |
65 | # initializes dictionary's lists w.r.t. the dataset's mode
66 | def initDict(self, mode):
67 | if (mode == "mono"):
68 | self.sample["leftRGB"] = []
69 | self.sample["leftRGB2"] = []
70 | self.sample["leftRGB4"] = []
71 | self.sample["leftDepth"] = []
72 | self.sample["leftDepth2"] = []
73 | self.sample["leftDepth4"] = []
74 | elif (mode == "lr"):
75 | self.sample["leftRGB"] = []
76 | self.sample["leftRGB2"] = []
77 | self.sample["leftRGB4"] = []
78 | self.sample["rightRGB"] = []
79 | self.sample["rightRGB2"] = []
80 | self.sample["rightRGB4"] = []
81 | self.sample["leftDepth"] = []
82 | self.sample["leftDepth2"] = []
83 | self.sample["leftDepth4"] = []
84 | self.sample["rightDepth"] = []
85 | self.sample["rightDepth2"] = []
86 | self.sample["rightDepth4"] = []
87 | elif (mode == "ud"):
88 | self.sample["leftRGB"] = []
89 | self.sample["leftRGB2"] = []
90 | self.sample["leftRGB4"] = []
91 | self.sample["upRGB"] = []
92 | self.sample["upRGB2"] = []
93 | self.sample["upRGB4"] = []
94 | self.sample["leftDepth"] = []
95 | self.sample["leftDepth2"] = []
96 | self.sample["leftDepth4"] = []
97 | self.sample["upDepth"] = []
98 | self.sample["upDepth2"] = []
99 | self.sample["upDepth4"] = []
100 | elif (mode == "tc"):
101 | self.sample["leftRGB"] = []
102 | self.sample["leftRGB2"] = []
103 | self.sample["leftRGB4"] = []
104 | self.sample["rightRGB"] = []
105 | self.sample["rightRGB2"] = []
106 | self.sample["rightRGB4"] = []
107 | self.sample["upRGB"] = []
108 | self.sample["upRGB2"] = []
109 | self.sample["upRGB4"] = []
110 | self.sample["leftDepth"] = []
111 | self.sample["leftDepth2"] = []
112 | self.sample["leftDepth4"] = []
113 | self.sample["rightDepth"] = []
114 | self.sample["rightDepth2"] = []
115 | self.sample["rightDepth4"] = []
116 | self.sample["upDepth"] = []
117 | self.sample["upDepth2"] = []
118 | self.sample["upDepth4"] = []
119 |
120 | # configures samples when in mono mode
121 | # loads filepaths to dictionary's list
122 | def initModeMono(self, lines):
123 | for line in lines:
124 | leftRGBPath = line.split(self.delimiter)[0]
125 | leftDepthPath = line.split(self.delimiter)[3]
126 | self.sample["leftRGB"].append(leftRGBPath)
127 | self.sample["leftDepth"].append(leftDepthPath)
128 |
129 |
130 | # configures dataset samples when in Left-Right mode
131 | def initModeLR(self, lines):
132 | for line in lines:
133 | leftRGBPath = line.split(self.delimiter)[0]
134 | rightRGBPath = line.split(self.delimiter)[1]
135 | leftDepthPath = line.split(self.delimiter)[3]
136 | rightDepthPath = line.split(self.delimiter)[4]
137 | self.sample["leftRGB"].append(leftRGBPath)
138 | self.sample["rightRGB"].append(rightRGBPath)
139 | self.sample["leftDepth"].append(leftDepthPath)
140 | self.sample["rightDepth"].append(rightDepthPath)
141 |
142 | # configures dataset samples when in Up-Down mode
143 | def initModeUD(self, lines):
144 | for line in lines:
145 | leftRGBPath = line.split(self.delimiter)[0]
146 | upRGBPath = line.split(self.delimiter)[2]
147 | leftDepthPath = line.split(self.delimiter)[3]
148 | upDepthPath = line.split(self.delimiter)[5]
149 | self.sample["leftRGB"].append(leftRGBPath)
150 | self.sample["upRGB"].append(upRGBPath)
151 | self.sample["leftDepth"].append(leftDepthPath)
152 | self.sample["upDepth"].append(upDepthPath)
153 |
154 | # configures dataset samples when in Trinocular mode
155 | def initModeTC(self, lines):
156 | for line in lines:
157 | leftRGBPath = line.split(self.delimiter)[0]
158 | rightRGBPath = line.split(self.delimiter)[1]
159 | upRGBPath = line.split(self.delimiter)[2]
160 | leftDepthPath = line.split(self.delimiter)[3]
161 | rightDepthPath = line.split(self.delimiter)[4]
162 | upDepthPath = line.split(self.delimiter)[5]
163 | self.sample["leftRGB"].append(leftRGBPath)
164 | self.sample["rightRGB"].append(rightRGBPath)
165 | self.sample["upRGB"].append(upRGBPath)
166 | self.sample["leftDepth"].append(leftDepthPath)
167 | self.sample["rightDepth"].append(rightDepthPath)
168 | self.sample["upDepth"].append(upDepthPath)
169 |
170 | # Loads filenames from .txt file and saves the samples' paths w.r.t. the dataset mode
171 | def loadFilenamesFile(self):
172 | if (not os.path.exists(self.filenamesFilePath)):
173 | print("{} | Filepath [{}] does not exist.".format(datetime.datetime.now(), self.filenamesFilePath))
174 | exit()
175 |         with open(self.filenamesFilePath, "r") as fileID:
176 |             lines = fileID.readlines()
177 |         if (len(lines) == 0):
178 |             print("{} | File is empty: {}".format(datetime.datetime.now(), self.filenamesFilePath))
179 |             exit()
180 | self.length = len(lines)
181 | if (self.mode == "mono"):
182 | self.initModeMono(lines)
183 | elif (self.mode == "lr"):
184 | self.initModeLR(lines)
185 | elif (self.mode == "ud"):
186 | self.initModeUD(lines)
187 | elif (self.mode == "tc"):
188 | self.initModeTC(lines)
189 |
190 | # loads sample from dataset mono mode
191 | def loadItemMono(self, idx):
192 | item = {}
193 | if (idx >= self.length):
194 | print("Index [{}] out of range. Dataset length: {}".format(idx, self.length))
195 | else:
196 | dtmp = np.array(cv2.imread(self.sample["leftDepth"][idx], cv2.IMREAD_ANYDEPTH))
197 | left_depth = torch.from_numpy(dtmp)
198 | left_depth.unsqueeze_(0)
199 | if self.rescaled:
200 | dtmp2 = cv2.resize(dtmp, (dtmp.shape[1] // 2, dtmp.shape[0] // 2))
201 | dtmp4 = cv2.resize(dtmp, (dtmp.shape[1] // 4, dtmp.shape[0] // 4))
202 | left_depth2 = torch.from_numpy(dtmp2)
203 | left_depth2.unsqueeze_(0)
204 | left_depth4 = torch.from_numpy(dtmp4)
205 | left_depth4.unsqueeze_(0)
206 |
207 | pilRGB = Image.open(self.sample["leftRGB"][idx])
208 | rgb = self.pilToTensor(pilRGB)
209 | if self.rescaled:
210 | rgb2 = self.pilToTensor(self.resize2(pilRGB))
211 | rgb4 = self.pilToTensor(self.resize4(pilRGB))
212 | item = {
213 | "leftRGB": rgb,
214 | "leftRGB2": rgb2,
215 | "leftRGB4": rgb4,
216 | "leftDepth": left_depth,
217 | "leftDepth2": left_depth2,
218 | "leftDepth4": left_depth4,
219 | "leftDepth_filename": os.path.basename(self.sample["leftDepth"][idx][:-4])
220 | } if self.rescaled else {
221 | "leftRGB": rgb,
222 | "leftDepth": left_depth,
223 | "leftDepth_filename": os.path.basename(self.sample["leftDepth"][idx][:-4])
224 | }
225 | return item
226 |
227 | # loads sample from dataset lr mode
228 | def loadItemLR(self, idx):
229 | item = {}
230 | if (idx >= self.length):
231 | print("Index [{}] out of range. Dataset length: {}".format(idx, self.length))
232 | else:
233 | leftRGB = Image.open(self.sample["leftRGB"][idx])
234 | rightRGB = Image.open(self.sample["rightRGB"][idx])
235 | if self.rescaled:
236 | leftRGB2 = self.pilToTensor(self.resize2(leftRGB))
237 | leftRGB4 = self.pilToTensor(self.resize4(leftRGB))
238 | rightRGB2 = self.pilToTensor(self.resize2(rightRGB))
239 | rightRGB4 = self.pilToTensor(self.resize4(rightRGB))
240 |
241 | dtmp = np.array(cv2.imread(self.sample["leftDepth"][idx], cv2.IMREAD_ANYDEPTH))
242 | left_depth = torch.from_numpy(dtmp)
243 | left_depth.unsqueeze_(0)
244 | if self.rescaled:
245 | dtmp2 = cv2.resize(dtmp, (dtmp.shape[1] // 2, dtmp.shape[0] // 2))
246 | dtmp4 = cv2.resize(dtmp, (dtmp.shape[1] // 4, dtmp.shape[0] // 4))
247 | left_depth2 = torch.from_numpy(dtmp2)
248 | left_depth2.unsqueeze_(0)
249 | left_depth4 = torch.from_numpy(dtmp4)
250 | left_depth4.unsqueeze_(0)
251 |
252 | dtmp = np.array(cv2.imread(self.sample["rightDepth"][idx], cv2.IMREAD_ANYDEPTH))
253 | right_depth = torch.from_numpy(dtmp)
254 | right_depth.unsqueeze_(0)
255 | if self.rescaled:
256 |                 dtmp2 = cv2.resize(dtmp, (dtmp.shape[1] // 2, dtmp.shape[0] // 2))
257 | dtmp4 = cv2.resize(dtmp, (dtmp.shape[1] // 4, dtmp.shape[0] // 4))
258 | right_depth2 = torch.from_numpy(dtmp2)
259 | right_depth2.unsqueeze_(0)
260 | right_depth4 = torch.from_numpy(dtmp4)
261 | right_depth4.unsqueeze_(0)
262 | item = {
263 | "leftRGB": self.pilToTensor(leftRGB),
264 | "rightRGB": self.pilToTensor(rightRGB),
265 | "leftRGB2": leftRGB2,
266 | "rightRGB2": rightRGB2,
267 | "leftRGB4": leftRGB4,
268 |                 "rightRGB4": rightRGB4,
269 | "leftDepth": left_depth,
270 | 'leftDepth2': left_depth2,
271 | 'leftDepth4': left_depth4,
272 | "rightDepth": right_depth,
273 | "rightDepth2": right_depth2,
274 | "rightDepth4": right_depth4,
275 | 'leftDepth_filename': os.path.basename(self.sample['leftDepth'][idx][:-4])
276 | } if self.rescaled else {
277 | "leftRGB": self.pilToTensor(leftRGB),
278 | "rightRGB": self.pilToTensor(rightRGB),
279 | "leftDepth": left_depth,
280 | "rightDepth": right_depth,
281 | 'leftDepth_filename': os.path.basename(self.sample['leftDepth'][idx][:-4])
282 | }
283 | return item
284 |
285 | # loads sample from dataset ud mode
286 | def loadItemUD(self, idx):
287 | item = {}
288 | if (idx >= self.length):
289 | print("Index [{}] out of range. Dataset length: {}".format(idx, self.length))
290 | else:
291 | leftRGB = Image.open(self.sample["leftRGB"][idx])
292 | upRGB = Image.open(self.sample["upRGB"][idx])
293 | if self.rescaled:
294 | leftRGB2 = self.pilToTensor(self.resize2(leftRGB))
295 | leftRGB4 = self.pilToTensor(self.resize4(leftRGB))
296 | upRGB2 = self.pilToTensor(self.resize2(upRGB))
297 | upRGB4 = self.pilToTensor(self.resize4(upRGB))
298 | dtmp = np.array(cv2.imread(self.sample["leftDepth"][idx], cv2.IMREAD_ANYDEPTH))
299 | depth = torch.from_numpy(dtmp)
300 | depth.unsqueeze_(0)
301 | if self.rescaled:
302 | dtmp2 = cv2.resize(dtmp, (self.width // 2, self.height // 2))
303 | dtmp4 = cv2.resize(dtmp, (self.width // 4, self.height // 4))
304 | depth2 = torch.from_numpy(dtmp2)
305 | depth2.unsqueeze_(0)
306 | depth4 = torch.from_numpy(dtmp4)
307 | depth4.unsqueeze_(0)
308 |
309 |
310 | dtmp = np.array(cv2.imread(self.sample["upDepth"][idx], cv2.IMREAD_ANYDEPTH))
311 | up_depth = torch.from_numpy(dtmp)
312 | up_depth.unsqueeze_(0)
313 | if self.rescaled:
314 |                 dtmp2 = cv2.resize(dtmp, (dtmp.shape[1] // 2, dtmp.shape[0] // 2))
315 | dtmp4 = cv2.resize(dtmp, (dtmp.shape[1] // 4, dtmp.shape[0] // 4))
316 | up_depth2 = torch.from_numpy(dtmp2)
317 | up_depth2.unsqueeze_(0)
318 | up_depth4 = torch.from_numpy(dtmp4)
319 | up_depth4.unsqueeze_(0)
320 |
321 | item = {
322 | "leftRGB": self.pilToTensor(leftRGB),
323 | "upRGB": self.pilToTensor(upRGB),
324 | "leftRGB2": leftRGB2,
325 | "upRGB2": upRGB2,
326 | "leftRGB4": leftRGB4,
327 | "upRGB4": upRGB4,
328 | "leftDepth": depth,
329 | "leftDepth2": depth2,
330 | "leftDepth4": depth4,
331 | "upDepth": up_depth,
332 | "upDepth2": up_depth2,
333 | "upDepth4": up_depth4,
334 | 'leftDepth_filename': os.path.basename(self.sample['leftDepth'][idx][:-4])
335 | } if self.rescaled else {
336 | "leftRGB": self.pilToTensor(leftRGB),
337 | "upRGB": self.pilToTensor(upRGB),
338 | "leftDepth": depth,
339 | "upDepth": up_depth,
340 | 'leftDepth_filename': os.path.basename(self.sample['leftDepth'][idx][:-4])
341 | }
342 | return item
343 |
344 | # loads sample from dataset tc mode
345 | def loadItemTC(self, idx):
346 | item = {}
347 | if (idx >= self.length):
348 | print("Index [{}] out of range. Dataset length: {}".format(idx, self.length))
349 | else:
350 | leftRGB = Image.open(self.sample["leftRGB"][idx])
351 | rightRGB = Image.open(self.sample["rightRGB"][idx])
352 | upRGB = Image.open(self.sample["upRGB"][idx])
353 | if self.rescaled:
354 | leftRGB2 = self.pilToTensor(self.resize2(leftRGB))
355 | leftRGB4 = self.pilToTensor(self.resize4(leftRGB))
356 | rightRGB2 = self.pilToTensor(self.resize2(rightRGB))
357 | rightRGB4 = self.pilToTensor(self.resize4(rightRGB))
358 | upRGB2 = self.pilToTensor(self.resize2(upRGB))
359 | upRGB4 = self.pilToTensor(self.resize4(upRGB))
360 |
361 | dtmp = np.array(cv2.imread(self.sample["leftDepth"][idx], cv2.IMREAD_ANYDEPTH))
362 | depth = torch.from_numpy(dtmp)
363 | depth.unsqueeze_(0)
364 | if self.rescaled:
365 | dtmp2 = cv2.resize(dtmp, (self.width // 2, self.height // 2))
366 | dtmp4 = cv2.resize(dtmp, (self.width // 4, self.height // 4))
367 | depth2 = torch.from_numpy(dtmp2)
368 | depth2.unsqueeze_(0)
369 | depth4 = torch.from_numpy(dtmp4)
370 | depth4.unsqueeze_(0)
371 |
372 | dtmp = np.array(cv2.imread(self.sample["rightDepth"][idx], cv2.IMREAD_ANYDEPTH))
373 | right_depth = torch.from_numpy(dtmp)
374 | right_depth.unsqueeze_(0)
375 | if self.rescaled:
376 |                 dtmp2 = cv2.resize(dtmp, (dtmp.shape[1] // 2, dtmp.shape[0] // 2))
377 | dtmp4 = cv2.resize(dtmp, (dtmp.shape[1] // 4, dtmp.shape[0] // 4))
378 | right_depth2 = torch.from_numpy(dtmp2)
379 | right_depth2.unsqueeze_(0)
380 | right_depth4 = torch.from_numpy(dtmp4)
381 | right_depth4.unsqueeze_(0)
382 |
383 | dtmp = np.array(cv2.imread(self.sample["upDepth"][idx], cv2.IMREAD_ANYDEPTH))
384 | up_depth = torch.from_numpy(dtmp)
385 | up_depth.unsqueeze_(0)
386 | if self.rescaled:
387 |                 dtmp2 = cv2.resize(dtmp, (dtmp.shape[1] // 2, dtmp.shape[0] // 2))
388 | dtmp4 = cv2.resize(dtmp, (dtmp.shape[1] // 4, dtmp.shape[0] // 4))
389 | up_depth2 = torch.from_numpy(dtmp2)
390 | up_depth2.unsqueeze_(0)
391 | up_depth4 = torch.from_numpy(dtmp4)
392 | up_depth4.unsqueeze_(0)
393 |
394 | item = {
395 | "leftRGB": self.pilToTensor(leftRGB),
396 | "rightRGB": self.pilToTensor(rightRGB),
397 | "upRGB": self.pilToTensor(upRGB),
398 | "leftRGB2": leftRGB2,
399 | "rightRGB2": rightRGB2,
400 | "upRGB2": upRGB2,
401 | "leftRGB4": leftRGB4,
402 | "rightRGB4": rightRGB4,
403 | "upRGB4": upRGB4,
404 | "leftDepth": depth,
405 | "leftDepth2": depth2,
406 | "leftDepth4": depth4,
407 | "upDepth": up_depth,
408 | "upDepth2": up_depth2,
409 | "upDepth4": up_depth4,
410 | "rightDepth": right_depth,
411 | "rightDepth2": right_depth2,
412 | "rightDepth4": right_depth4,
413 | "depthFilename": os.path.basename(self.sample["leftDepth"][idx][:-4])
414 | } if self.rescaled else {
415 | "leftRGB": self.pilToTensor(leftRGB),
416 | "rightRGB": self.pilToTensor(rightRGB),
417 | "upRGB": self.pilToTensor(upRGB),
418 | "leftDepth": depth,
419 | "rightDepth": right_depth,
420 | "upDepth": up_depth,
421 | "depthFilename": os.path.basename(self.sample["leftDepth"][idx][:-4])
422 | }
423 | return item
424 |
425 | # torch override
426 | # returns samples length
427 | def __len__(self):
428 | return self.length
429 |
430 | # torch override
431 | def __getitem__(self, idx):
432 | if (self.mode == "mono"):
433 | return self.loadItemMono(idx)
434 | elif(self.mode == "lr"):
435 | return self.loadItemLR(idx)
436 | elif(self.mode == "ud"):
437 | return self.loadItemUD(idx)
438 | elif(self.mode == "tc"):
439 | return self.loadItemTC(idx)
440 |
441 |
442 |
443 |
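# Usage sketch (illustrative, not part of the training scripts; the filenames
# file path and delimiter below are placeholders):
#
#   from torch.utils.data import DataLoader
#   dataset = Dataset360D("train.txt", " ", "ud", inputShape=(256, 512))
#   loader = DataLoader(dataset, batch_size=4, shuffle=True)
#   batch = next(iter(loader))  # dict with "leftRGB", "upRGB", "leftDepth", "upDepth", ...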
--------------------------------------------------------------------------------
/exporters/__init__.py:
--------------------------------------------------------------------------------
1 | from .image import *
--------------------------------------------------------------------------------
/exporters/image.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import cv2
3 | import numpy
4 |
5 | def save_image(filename, tensor, scale=255.0):
6 | b, _, __, ___ = tensor.size()
7 | for n in range(b):
8 | array = tensor[n, :, :, :].detach().cpu().numpy()
9 | array = array.transpose(1, 2, 0) * scale
10 | cv2.imwrite(filename.replace("#", str(n)), array)
11 |
12 | def save_depth(filename, tensor, scale=1000.0):
13 | b, _, __, ___ = tensor.size()
14 | for n in range(b):
15 | array = tensor[n, :, :, :].detach().cpu().numpy()
16 | array = array.transpose(1, 2, 0) * scale
17 | array = numpy.uint16(array)
18 | cv2.imwrite(filename.replace("#", str(n)), array)
19 |
20 | def save_data(filename, tensor, scale=1000.0):
21 | b, _, __, ___ = tensor.size()
22 | for n in range(b):
23 | array = tensor[n, :, :, :].detach().cpu().numpy()
24 | array = array.transpose(1, 2, 0) * scale
25 | array = numpy.float32(array)
26 | cv2.imwrite(filename.replace("#", str(n)), array)
27 |
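# Note (assumption, not verified against a specific OpenCV build): cv2.imwrite
# infers the format from the file extension, so save_depth targets 16-bit images
# (e.g. PNG) and save_data float images (e.g. EXR); writing .exr files requires
# OpenCV compiled with OpenEXR support. cv2 also expects color channels in BGR order.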
--------------------------------------------------------------------------------
/filesystem/file_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | '''
4 | Filesystem class
5 | provides file control utilities like tensor saving etc.
6 | '''
7 | class Filesystem:
8 | def __init__(self):
9 | self.cwd = os.getcwd()
10 |         if os.path.isfile(self.cwd):
11 |             self.cwd = os.path.dirname(self.cwd)
12 | '''
13 | Creates directory
14 | either by giving the absolute path to create
15 | or the relative path w.r.t. the current working directory
16 |
17 | \param path the path to create
18 | '''
19 | def mkdir(self, path):
20 | if os.path.isabs(path):
21 | if not os.path.exists(path):
22 | os.mkdir(path)
23 | else:
24 | pathToCreate = os.path.join(self.cwd, path)
25 | if not os.path.exists(pathToCreate):
26 | os.mkdir(pathToCreate)
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import numpy
5 | import cv2
6 |
7 | import torch
8 |
9 | import models
10 | import utils
11 | import exporters
12 |
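# Example invocation (paths are placeholders):
#   python infer.py --input_path ./pano.png --weights ./ud.pt --gpu 0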
13 | def parse_arguments(args):
14 | usage_text = (
15 | "Semi-supervised Spherical Depth Estimation Testing."
16 | )
17 | parser = argparse.ArgumentParser(description=usage_text)
18 | parser.add_argument("--input_path", type=str, help="Path to the input spherical panorama image.")
19 | parser.add_argument('--weights', type=str, help='Path to the trained weights file.')
20 | parser.add_argument('-g','--gpu', type=str, default='0', help='The ids of the GPU(s) that will be utilized. (e.g. 0 or 0,1, or 0,2). Use -1 for CPU.')
21 | return parser.parse_known_args(args)
22 |
23 | if __name__ == "__main__":
24 | args, unknown = parse_arguments(sys.argv)
25 | gpus = [int(id) for id in args.gpu.split(',') if int(id) >= 0]
26 | # device & visualizers
27 | device = torch.device("cuda:{}" .format(gpus[0])\
28 | if torch.cuda.is_available() and len(gpus) > 0 and gpus[0] >= 0\
29 | else "cpu")
30 | # model
31 | model = models.get_model("resnet_coord", {})
32 | utils.init.initialize_weights(model, args.weights, pred_bias=None)
33 | model = model.to(device)
34 | # test data
35 | width, height = 512, 256
36 | if not os.path.exists(args.input_path):
37 | print("Input image path does not exist (%s)." % args.input_path)
38 | exit(-1)
39 | img = cv2.imread(args.input_path)
40 | h, w, _ = img.shape
41 |     if h != height or w != width:
42 | img = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
43 | img = img.transpose(2, 0, 1) / 255.0
44 | img = torch.from_numpy(img).float().expand(1, -1, -1, -1)
45 | model.eval()
46 | with torch.no_grad():
47 | left_rgb = img.to(device)
48 | ''' Prediction '''
49 | left_depth_pred = torch.abs(model(left_rgb))
50 | exporters.image.save_data(os.path.join(
51 | os.path.dirname(args.input_path),
52 | os.path.splitext(os.path.basename(
53 | args.input_path))[0] + "_depth.exr"),
54 | left_depth_pred, scale=1.0)
55 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet360 import *
2 |
3 | import sys
4 |
5 | def get_model(name, model_params):
6 | if name == 'resnet_coord':
7 | return ResNet360(
8 | # conv_type='standard', activation='elu', norm_type='none', \
9 | conv_type='coord', activation='elu', norm_type='none', \
10 | width=512,
11 | )
12 | else:
13 | print("Could not find the requested model ({})".format(name), file=sys.stderr)
--------------------------------------------------------------------------------
/models/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | '''
6 | Code adapted from https://github.com/uber-research/coordconv
7 | accompanying the paper "An Intriguing Failing of Convolutional Neural Networks and the CoordConv Solution" (NeurIPS 2018)
8 | '''
9 |
10 | class AddCoords360(nn.Module):
11 | def __init__(self, x_dim=64, y_dim=64, with_r=False):
12 | super(AddCoords360, self).__init__()
13 | self.x_dim = int(x_dim)
14 | self.y_dim = int(y_dim)
15 | self.with_r = with_r
16 |
17 | def forward(self, input_tensor):
18 | """
19 | input_tensor: (batch, c, x_dim, y_dim)
20 | """
21 | batch_size_tensor = input_tensor.shape[0]
22 |
23 | xx_ones = torch.ones([1, self.y_dim], dtype=torch.float32, device=input_tensor.device)
24 | xx_ones = xx_ones.unsqueeze(-1)
25 |
26 | xx_range = torch.arange(self.x_dim, dtype=torch.float32, device=input_tensor.device).unsqueeze(0)
27 | xx_range = xx_range.unsqueeze(1)
28 |
29 | xx_channel = torch.matmul(xx_ones, xx_range)
30 | xx_channel = xx_channel.unsqueeze(-1)
31 |
32 | yy_ones = torch.ones([1, self.x_dim], dtype=torch.float32, device=input_tensor.device)
33 | yy_ones = yy_ones.unsqueeze(1)
34 |
35 | yy_range = torch.arange(self.y_dim, dtype=torch.float32, device=input_tensor.device).unsqueeze(0)
36 | yy_range = yy_range.unsqueeze(-1)
37 |
38 | yy_channel = torch.matmul(yy_range, yy_ones)
39 | yy_channel = yy_channel.unsqueeze(-1)
40 |
41 | xx_channel = xx_channel.permute(0, 3, 2, 1)
42 | yy_channel = yy_channel.permute(0, 3, 2, 1)
43 |
44 | xx_channel = xx_channel.float() / (self.x_dim - 1)
45 | yy_channel = yy_channel.float() / (self.y_dim - 1)
46 |
47 | xx_channel = xx_channel * 2 - 1
48 | yy_channel = yy_channel * 2 - 1
49 |
50 | xx_channel = xx_channel.repeat(batch_size_tensor, 1, 1, 1)
51 | yy_channel = yy_channel.repeat(batch_size_tensor, 1, 1, 1)
52 |
53 | ret = torch.cat([input_tensor, xx_channel, yy_channel], dim=1)
54 |
55 | if self.with_r:
56 | rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2))
57 | ret = torch.cat([ret, rr], dim=1)
58 |
59 | return ret
60 |
61 | class CoordConv360(nn.Module):
62 | """CoordConv layer as in the paper."""
63 | def __init__(self, x_dim, y_dim, with_r, in_channels, out_channels, kernel_size, *args, **kwargs):
64 | super(CoordConv360, self).__init__()
65 | self.addcoords = AddCoords360(x_dim=x_dim, y_dim=y_dim, with_r=with_r)
66 | in_size = in_channels+2
67 | if with_r:
68 | in_size += 1
69 | self.conv = nn.Conv2d(in_size, out_channels, kernel_size, **kwargs)
70 |
71 | def forward(self, input_tensor):
72 | ret = self.addcoords(input_tensor)
73 | ret = self.conv(ret)
74 | return ret
75 |
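# Usage sketch (hypothetical dimensions for a 256x512 equirectangular feature map):
#   conv = CoordConv360(x_dim=256, y_dim=512, with_r=False,
#                       in_channels=3, out_channels=32, kernel_size=3, padding=1)
#   out = conv(torch.rand(1, 3, 256, 512))  # -> [1, 32, 256, 512]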
76 |
77 | def create_conv(in_size, out_size, conv_type, padding=1, stride=1, kernel_size=3, width=512):
78 | if conv_type == 'standard':
79 | return nn.Conv2d(in_channels=in_size, out_channels=out_size, \
80 | kernel_size=kernel_size, padding=padding, stride=stride)
81 | elif conv_type == 'coord':
82 | return CoordConv360(x_dim=width / 2.0, y_dim=width,\
83 | with_r=False, kernel_size=kernel_size, stride=stride,\
84 | in_channels=in_size, out_channels=out_size, padding=padding)
85 |
86 | def create_activation(activation):
87 | if activation == 'relu':
88 | return nn.ReLU(inplace=True)
89 | elif activation == 'elu':
90 | return nn.ELU(inplace=True)
91 |
92 | class Identity(nn.Module):
93 | def forward(self, x):
94 | return x
95 |
96 | def create_normalization(out_size, norm_type):
97 | if norm_type == 'batchnorm':
98 | return nn.BatchNorm2d(out_size)
99 | elif norm_type == 'groupnorm':
100 | return nn.GroupNorm(out_size // 4, out_size)
101 | elif norm_type == 'none':
102 | return Identity()
103 |
104 | def create_downscale(out_size, down_mode):
105 | if down_mode == 'pool':
106 | return torch.nn.modules.MaxPool2d(2)
107 | elif down_mode == 'downconv':
108 | return nn.Conv2d(in_channels=out_size, out_channels=out_size, kernel_size=3,\
109 | stride=2, padding=1, bias=False)
110 | elif down_mode == 'gaussian':
111 | print("Not implemented")
--------------------------------------------------------------------------------
/models/resnet360.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | import functools
5 |
6 | from .modules import *
7 |
8 | # adapted from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
9 |
10 | class ResNet360(nn.Module):
11 | """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
12 |
13 | We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
14 | """
15 | def __init__(
16 | self,
17 | in_channels=3,
18 | out_channels=1,
19 | depth=5,
20 | wf=32,
21 | conv_type='coord',
22 | padding='kernel',
23 | norm_type='none',
24 | activation='elu',
25 | up_mode='upconv',
26 | down_mode='downconv',
27 | width=512,
28 | use_dropout=False,
29 | padding_type='reflect',
30 | ):
31 | """Construct a Resnet-based generator
32 |
33 | Parameters:
34 |             in_channels (int)  -- the number of channels in input images
35 |             out_channels (int) -- the number of channels in output images
36 |             wf (int)           -- the number of filters in the first conv layer
37 |             norm_type (str)    -- the normalization layer type
38 |             use_dropout (bool) -- whether to use dropout layers (currently unused)
39 |             depth (int)        -- the number of ResNet blocks
40 |             padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero (currently unused)
41 | """
42 | assert(depth >= 0)
43 | super(ResNet360, self).__init__()
44 | model = (
45 | [
46 | create_conv(in_channels, wf, conv_type, \
47 | kernel_size=7, padding=3, stride=1, width=width),
48 | create_normalization(wf, norm_type),
49 | create_activation(activation)
50 | ]
51 | )
52 |
53 | n_downsampling = 2
54 | for i in range(n_downsampling):
55 | mult = 2 ** i
56 | model += (
57 | [
58 | create_conv(wf * mult, wf * mult * 2, conv_type, \
59 | kernel_size=3, stride=2, padding=1, width=width // (i+1)),
60 | create_normalization(wf * mult * 2, norm_type),
61 | create_activation(activation)
62 | ]
63 | )
64 |
65 | mult = 2 ** n_downsampling
66 | for i in range(depth):
67 | model += [ResnetBlock(wf * mult, activation=activation, \
68 | norm_type=norm_type, conv_type=conv_type, \
69 | width=width // (2 ** n_downsampling))]
70 |
71 | for i in range(n_downsampling):
72 | mult = 2 ** (n_downsampling - i)
73 | model += (
74 | [
75 | nn.ConvTranspose2d(wf * mult, int(wf * mult / 2),
76 | kernel_size=3, stride=2,
77 | padding=1, output_padding=1),
78 | create_normalization(int(wf * mult / 2), norm_type),
79 | create_activation(activation)
80 | ]
81 | )
82 |
83 | model += [create_conv(wf, out_channels, conv_type, \
84 | kernel_size=7, padding=3, width=width)]
85 |
86 | self.model = nn.Sequential(*model)
87 |
88 | def forward(self, input):
89 | """Standard forward"""
90 | return self.model(input)
91 |
92 |
93 | class ResnetBlock(nn.Module):
94 | """Define a Resnet block"""
95 |
96 | def __init__(self, dim, norm_type, conv_type, activation, width):
97 | """Initialize the Resnet block
98 |
99 | A resnet block is a conv block with skip connections
100 | We construct a conv block with build_conv_block function,
101 | and implement skip connections in function.
102 | Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
103 | """
104 | super(ResnetBlock, self).__init__()
105 | conv_block = []
106 | conv_block +=(
107 | [
108 | create_conv(dim, dim, conv_type, width=width),
109 | create_normalization(dim, norm_type),
110 | create_activation(activation),
111 | ]
112 | )
113 | conv_block +=(
114 | [
115 | create_conv(dim, dim, conv_type, width=width),
116 | create_normalization(dim, norm_type),
117 | ]
118 | )
119 |
120 | self.block = nn.Sequential(*conv_block)
121 |
122 | def forward(self, x):
123 | """Forward function (with skip connections)"""
124 | out = x + self.block(x) # add skip connections
125 | return out
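# Usage sketch (matches the 256x512 input resolution used by infer.py):
#   net = ResNet360(conv_type='coord', activation='elu', norm_type='none', width=512)
#   depth = net(torch.rand(1, 3, 256, 512))  # -> [1, 1, 256, 512]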
--------------------------------------------------------------------------------
/spherical/__init__.py:
--------------------------------------------------------------------------------
1 | from .grid import *
2 | from .cartesian import *
3 | from .derivatives import *
4 | from .weights import *
--------------------------------------------------------------------------------
/spherical/cartesian.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .grid import *
4 |
5 | '''
6 | Cartesian coordinates extraction from Spherical coordinates
7 | z is forward axis
8 | y is the up axis
9 | x is the right axis
10 | r is the radius (i.e. spherical depth)
11 | phi is the longitude/azimuthal rotation angle (defined on the x-z plane)
12 | theta is the latitude/elevation rotation angle (defined on the y-z plane)
13 | '''
14 | def coord_x(sgrid, depth):
15 | return ( # r * sin(phi) * sin(theta) -> r * cos(phi) * -cos(theta) in our offsets
16 | depth # this is due to the offsets as explained below
17 | * torch.cos(phi(sgrid)) # long = x - 3 * pi / 2
18 | * -1 * torch.cos(theta(sgrid)) # lat = y - pi / 2
19 | )
20 |
21 | def coord_y(sgrid, depth):
22 | return ( # r * cos(theta) -> r * sin(theta) in our offsets
23 | depth # this is due to the offsets as explained below
24 | * torch.sin(theta(sgrid)) # lat = y - pi / 2
25 | )
26 |
27 | def coord_z(sgrid, depth):
28 | return ( # r * cos(phi) * sin(theta) -> r * -sin(phi) * -cos(theta) in our offsets
29 | depth # this is due to the offsets as explained above
30 | * torch.sin(phi(sgrid)) # * -1
31 | * torch.cos(theta(sgrid)) # * -1
32 | ) # the -1s cancel out
33 |
34 | def coords_3d(sgrid, depth):
35 | return torch.cat(
36 | (
37 | coord_x(sgrid, depth),
38 | coord_y(sgrid, depth),
39 | coord_z(sgrid, depth)
40 | ), dim=1
41 | )
42 |
43 | def xi(pcloud):
44 | return pcloud[:, 0, :, :].unsqueeze(1)
45 |
46 | def yi(pcloud):
47 | return pcloud[:, 1, :, :].unsqueeze(1)
48 |
49 | def zeta(pcloud):
50 | return pcloud[:, 2, :, :].unsqueeze(1)
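# Worked sketch: back-projecting an equirectangular depth map into a structured
# point cloud (tensor shapes follow the repository's [B, C, H, W] convention):
#   sgrid = create_spherical_grid(512)   # [1, 2, 256, 512] of (phi, theta)
#   depth = torch.ones(1, 1, 256, 512)   # unit-radius sphere
#   pcloud = coords_3d(sgrid, depth)     # [1, 3, 256, 512] of (x, y, z)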
--------------------------------------------------------------------------------
/spherical/derivatives.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy
3 | from .grid import *
4 | from .cartesian import *
5 |
6 | ''' Image (I) spatial derivatives '''
7 | def dI_du(img):
8 | right_pad = (0, 1, 0, 0)
9 | tensor = torch.nn.functional.pad(img, right_pad, mode="replicate")
10 | gu = tensor[:, :, :, :-1] - tensor[:, :, :, 1:] # NCHW
11 | return gu
12 |
13 | def dI_dv(img):
14 | bottom_pad = (0, 0, 0, 1)
15 | tensor = torch.nn.functional.pad(img, bottom_pad, mode="replicate")
16 | dv = tensor[:, :, :-1, :] - tensor[:, :, 1:, :] # NCHW
17 | return dv
18 |
19 | def dI_duv(img):
20 | du = dI_du(img)
21 | dv = dI_dv(img)
22 | duv = torch.cat((du, dv), dim=1)
23 | duv_mag = torch.norm(duv, p=2, dim=1, keepdim=True)
24 | return duv_mag
25 |
26 | '''
27 | Spherical coordinates (r, phi, theta) derivatives
28 | w.r.t. their Cartesian counterparts (x, y, z)
29 | '''
30 | def dr_dx(sgrid):
31 | return ( # sin(lat) * sin(long) -> cos(long) * -cos(lat)
32 |         -1 # this is due to the offsets as explained below
33 | * torch.cos(phi(sgrid)) # long = x - 3 * pi / 2
34 | * torch.cos(theta(sgrid)) # lat = y - pi / 2
35 | ) # the depth (radius) distortion for each spherical coord with a horizontal baseline
36 |
37 | def dphi_dx(sgrid):
38 | return ( # cos(long) / sin(lat) -> -sin(long) / -cos(lat)
39 | torch.sin(phi(sgrid)) # * -1
40 | / torch.cos(theta(sgrid)) # * -1
41 |     ) # the -1s cancel out and are omitted
42 |
43 | def dtheta_dx(sgrid):
44 | return ( # sin(long) * cos(lat) -> cos(long) * sin(lat)
45 | torch.cos(phi(sgrid)) * torch.sin(theta(sgrid))
46 | )
47 |
48 | def dtheta_dy(sgrid):
49 | return ( # -sin(lat) -> -1 * -cos(lat) == cos(lat)
50 | torch.cos(theta(sgrid))
51 | )
52 |
53 | def dphi_horizontal(sgrid, depth, baseline):
54 | _, __, h, ___ = depth.size()
55 | return torch.clamp(
56 | (
57 | torch.sin(phi(sgrid))
58 | / (
59 | depth
60 | * torch.cos(theta(sgrid))
61 | )
62 | * baseline
63 | * (h / numpy.pi)
64 | ),
65 | -h, h # h = w/2 the max disparity due to our spherical nature (i.e. front/back symmetry)
66 | )
67 |
68 | def dtheta_horizontal(sgrid, depth, baseline):
69 | _, __, h, ___ = depth.size()
70 | return torch.clamp(
71 | (
72 | torch.cos(phi(sgrid))
73 | * torch.sin(theta(sgrid))
74 | * baseline
75 | / depth
76 | * (h / numpy.pi)
77 | ),
78 | 0, h
79 | )
80 |
81 | def dr_horizontal(sgrid, baseline):
82 | return ( # sin(lat) * sin(long) -> cos(long) * -cos(lat)
83 | -1 # this is due to the offsets as explained below
84 | * torch.cos(phi(sgrid)) # long = x - 3 * pi / 2
85 | * torch.cos(theta(sgrid)) # lat = y - pi / 2
86 | * baseline
87 | ) # the depth (radius) distortion for each spherical coord with a horizontal baseline
88 |
89 | def dtheta_vertical(sgrid, depth, baseline):
90 | _, __, h, ___ = depth.size()
91 | return (
92 | torch.cos(theta(sgrid))
93 | * baseline
94 | / depth
95 | * (h / numpy.pi)
96 | )
97 |
98 | '''
99 | Structured Point Cloud Vertices (V) spatial derivatives
100 | '''
101 | def dV_dx(pcloud):
102 | return dI_duv(xi(pcloud))
103 |
104 | def dV_dy(pcloud):
105 | return dI_duv(yi(pcloud))
106 |
107 | def dV_dz(pcloud):
108 | return dI_duv(zeta(pcloud))
109 |
110 | def dV_dxyz(pcloud):
111 | du_x = dI_du(xi(pcloud))
112 | dv_x = dI_dv(xi(pcloud))
113 |
114 | du_y = dI_du(yi(pcloud))
115 | dv_y = dI_dv(yi(pcloud))
116 |
117 | du_z = dI_du(zeta(pcloud))
118 | dv_z = dI_dv(zeta(pcloud))
119 |
120 | du_xyz = torch.abs(du_x) + torch.abs(du_y) + torch.abs(du_z)
121 | dv_xyz = torch.abs(dv_x) + torch.abs(dv_y) + torch.abs(dv_z)
122 |
123 | duv_xyz = torch.cat((du_xyz, dv_xyz), dim=1)
124 |     duv_xyz_mag = torch.norm(duv_xyz, p=2, dim=1, keepdim=True)
125 |     return duv_xyz_mag
--------------------------------------------------------------------------------
/spherical/grid.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy
3 |
4 | def create_image_grid(width, height, data_type=torch.float32):
5 | v_range = (
6 | torch.arange(0, height) # [0 - h]
7 | .view(1, height, 1) # [1, [0 - h], 1]
8 | .expand(1, height, width) # [1, [0 - h], W]
9 | .type(data_type) # [1, H, W]
10 | )
11 | u_range = (
12 | torch.arange(0, width) # [0 - w]
13 | .view(1, 1, width) # [1, 1, [0 - w]]
14 | .expand(1, height, width) # [1, H, [0 - w]]
15 | .type(data_type) # [1, H, W]
16 | )
17 | return torch.stack((u_range, v_range), dim=1) # [1, 2, H, W]
18 |
19 | def coord_u(uvgrid):
20 | return uvgrid[:, 0, :, :].unsqueeze(1)
21 |
22 | def coord_v(uvgrid):
23 | return uvgrid[:, 1, :, :].unsqueeze(1)
24 |
25 | def create_spherical_grid(width, horizontal_shift=(-numpy.pi - numpy.pi / 2.0),
26 | vertical_shift=(-numpy.pi / 2.0), data_type=torch.float32):
27 | height = int(width // 2.0)
28 | v_range = (
29 | torch.arange(0, height) # [0 - h]
30 | .view(1, height, 1) # [1, [0 - h], 1]
31 | .expand(1, height, width) # [1, [0 - h], W]
32 | .type(data_type) # [1, H, W]
33 | )
34 | u_range = (
35 | torch.arange(0, width) # [0 - w]
36 | .view(1, 1, width) # [1, 1, [0 - w]]
37 | .expand(1, height, width) # [1, H, [0 - w]]
38 | .type(data_type) # [1, H, W]
39 | )
40 | u_range *= (2 * numpy.pi / width) # [0, 2 * pi]
41 | v_range *= (numpy.pi / height) # [0, pi]
42 | u_range += horizontal_shift # [-hs, 2 * pi - hs] -> standard values are [-3 * pi / 2, pi / 2]
43 | v_range += vertical_shift # [-vs, pi - vs] -> standard values are [-pi / 2, pi / 2]
44 | return torch.stack((u_range, v_range), dim=1) # [1, 2, H, W]
45 |
46 | def phi(sgrid): # longitude or azimuth
47 | return sgrid[:, 0, :, :].unsqueeze(1)
48 |
49 | def azimuth(sgrid): # longitude or phi
50 | return sgrid[:, 0, :, :].unsqueeze(1)
51 |
52 | def longitude(sgrid): # phi or azimuth
53 | return sgrid[:, 0, :, :].unsqueeze(1)
54 |
55 | def theta(sgrid): # latitude or elevation
56 | return sgrid[:, 1, :, :].unsqueeze(1)
57 |
58 | def elevation(sgrid): # theta or latitude
59 | return sgrid[:, 1, :, :].unsqueeze(1)
60 |
61 | def latitude(sgrid): # theta or elevation
62 | return sgrid[:, 1, :, :].unsqueeze(1)
--------------------------------------------------------------------------------
/spherical/weights.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .grid import *
4 |
5 | def phi_confidence(sgrid): # fading towards horizontal singularities
6 | return torch.abs(torch.sin(phi(sgrid)))
7 |
8 | def theta_confidence(sgrid): # fading towards vertical singularities
9 | return torch.abs(torch.cos(theta(sgrid)))
10 |
11 | def spherical_confidence(sgrid, zero_low=0.0, one_high=1.0):
12 | weights = phi_confidence(sgrid) * theta_confidence(sgrid)
13 | weights[weights < zero_low] = 0.0
14 | weights[weights > one_high] = 1.0
15 | return weights
--------------------------------------------------------------------------------
/supervision/__init__.py:
--------------------------------------------------------------------------------
1 | from .splatting import *
2 | from .photometric import *
3 | from .smoothness import *
4 | from .direct import *
--------------------------------------------------------------------------------
/supervision/direct.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
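# BerHu (reverse Huber) loss: L(d) = |d| for |d| <= c, else (d^2 + c^2) / (2c),
# with the threshold c set adaptively to max|d| / 5 over the current batch.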
3 | def calculate_berhu_loss(pred, gt, mask, weights):
4 | diff = gt - pred
5 | abs_diff = torch.abs(diff)
6 | c = torch.max(abs_diff).item() / 5
7 | leq = (abs_diff <= c).float()
8 | l2_losses = (diff**2 + c**2) / (2 * c)
9 | loss = leq * abs_diff + (1 - leq) * l2_losses
10 |     _, channels, __, ___ = loss.size() # channel count (unused; renamed to avoid shadowing the threshold c)
11 | count = torch.sum(mask, dim=[1, 2, 3], keepdim=True).float()
12 | masked_loss = loss * mask.float()
13 | weighted_loss = masked_loss * weights
14 | return torch.mean(torch.sum(weighted_loss, dim=[1, 2, 3], keepdim=True) / count)
--------------------------------------------------------------------------------
/supervision/photometric.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .ssim import *
4 |
5 | class PhotometricLossParameters(object):
6 | def __init__(self, alpha=0.85, l1_estimator='none',\
7 | ssim_estimator='none', window=7, std=1.5, ssim_mode='gaussian'):
8 | super(PhotometricLossParameters, self).__init__()
9 | self.alpha = alpha
10 | self.l1_estimator = l1_estimator
11 | self.ssim_estimator = ssim_estimator
12 | self.window = window
13 | self.std = std
14 | self.ssim_mode = ssim_mode
15 |
16 | def get_alpha(self):
17 | return self.alpha
18 |
19 | def get_l1_estimator(self):
20 | return self.l1_estimator
21 |
22 | def get_ssim_estimator(self):
23 | return self.ssim_estimator
24 |
25 | def get_window(self):
26 | return self.window
27 |
28 | def get_std(self):
29 | return self.std
30 |
31 | def get_ssim_mode(self):
32 | return self.ssim_mode
33 |
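# Photometric image reconstruction loss: an alpha-blend of structural dissimilarity,
# (1 - SSIM) / 2, and L1, i.e. alpha * d_ssim + (1 - alpha) * l1, restricted to the
# valid-pixel mask and modulated by per-pixel weights (e.g. spherical confidence).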
34 | def calculate_loss(pred, gt, params, mask, weights):
35 | valid_mask = mask.type(gt.dtype)
36 | masked_gt = gt * valid_mask
37 | masked_pred = pred * valid_mask
38 | l1 = torch.abs(masked_gt - masked_pred)
39 | d_ssim = torch.clamp(
40 | (
41 | 1 - ssim_loss(masked_pred, masked_gt, kernel_size=params.get_window(),
42 | std=params.get_std(), mode=params.get_ssim_mode())
43 | ) / 2, 0, 1)
44 | loss = (
45 | d_ssim * params.get_alpha()
46 | + l1 * (1 - params.get_alpha())
47 | )
48 | loss *= valid_mask
49 | loss *= weights
50 | count = torch.sum(mask, dim=[1, 2, 3], keepdim=True).float()
51 | return torch.mean(torch.sum(loss, dim=[1, 2, 3], keepdim=True) / count)
52 |
--------------------------------------------------------------------------------
/supervision/smoothness.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
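# Edge-aware (guided) smoothness: input gradients (e.g. of predicted depth) are
# down-weighted where the guide image has strong gradients, via exp(-|grad guide|),
# so smoothness is mainly enforced in textureless regions.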
3 | def guided_smoothness_loss(input_duv, guide_duv, mask, weights):
4 | guidance_weights = torch.exp(-guide_duv)
5 | smoothness = input_duv * guidance_weights
6 | smoothness[~mask] = 0.0
7 | smoothness *= weights
8 | return torch.sum(smoothness) / torch.sum(mask)
--------------------------------------------------------------------------------
/supervision/splatting.py:
--------------------------------------------------------------------------------
1 | '''
2 | PyTorch implementation of https://github.com/google/layered-scene-inference
3 | accompanying the paper "Layer-structured 3D Scene Inference via View Synthesis",
4 | ECCV 2018 https://shubhtuls.github.io/lsi/
5 | '''
6 |
7 | import torch
8 |
9 | def __splat__(values, coords, splatted):
10 | b, c, h, w = splatted.size()
11 | uvs = coords
12 | u = uvs[:, 0, :, :].unsqueeze(1)
13 | v = uvs[:, 1, :, :].unsqueeze(1)
14 |
15 | u0 = torch.floor(u)
16 | u1 = u0 + 1
17 | v0 = torch.floor(v)
18 | v1 = v0 + 1
19 |
20 | u0_safe = torch.clamp(u0, 0.0, w-1)
21 | v0_safe = torch.clamp(v0, 0.0, h-1)
22 | u1_safe = torch.clamp(u1, 0.0, w-1)
23 | v1_safe = torch.clamp(v1, 0.0, h-1)
24 |
25 | u0_w = (u1 - u) * (u0 == u0_safe).detach().type(values.dtype)
26 | u1_w = (u - u0) * (u1 == u1_safe).detach().type(values.dtype)
27 | v0_w = (v1 - v) * (v0 == v0_safe).detach().type(values.dtype)
28 | v1_w = (v - v0) * (v1 == v1_safe).detach().type(values.dtype)
29 |
30 | top_left_w = u0_w * v0_w
31 | top_right_w = u1_w * v0_w
32 | bottom_left_w = u0_w * v1_w
33 | bottom_right_w = u1_w * v1_w
34 |
35 | weight_threshold = 1e-3
36 | top_left_w *= (top_left_w >= weight_threshold).detach().type(values.dtype)
37 | top_right_w *= (top_right_w >= weight_threshold).detach().type(values.dtype)
38 | bottom_left_w *= (bottom_left_w >= weight_threshold).detach().type(values.dtype)
39 | bottom_right_w *= (bottom_right_w >= weight_threshold).detach().type(values.dtype)
40 |
41 | for channel in range(c):
42 | top_left_values = values[:, channel, :, :].unsqueeze(1) * top_left_w
43 | top_right_values = values[:, channel, :, :].unsqueeze(1) * top_right_w
44 | bottom_left_values = values[:, channel, :, :].unsqueeze(1) * bottom_left_w
45 | bottom_right_values = values[:, channel, :, :].unsqueeze(1) * bottom_right_w
46 |
47 | top_left_values = top_left_values.reshape(b, -1)
48 | top_right_values = top_right_values.reshape(b, -1)
49 | bottom_left_values = bottom_left_values.reshape(b, -1)
50 | bottom_right_values = bottom_right_values.reshape(b, -1)
51 |
52 | top_left_indices = (u0_safe + v0_safe * w).reshape(b, -1).type(torch.int64)
53 | top_right_indices = (u1_safe + v0_safe * w).reshape(b, -1).type(torch.int64)
54 | bottom_left_indices = (u0_safe + v1_safe * w).reshape(b, -1).type(torch.int64)
55 | bottom_right_indices = (u1_safe + v1_safe * w).reshape(b, -1).type(torch.int64)
56 |
57 | splatted_channel = splatted[:, channel, :, :].unsqueeze(1)
58 | splatted_channel = splatted_channel.reshape(b, -1)
59 | splatted_channel.scatter_add_(1, top_left_indices, top_left_values)
60 | splatted_channel.scatter_add_(1, top_right_indices, top_right_values)
61 | splatted_channel.scatter_add_(1, bottom_left_indices, bottom_left_values)
62 | splatted_channel.scatter_add_(1, bottom_right_indices, bottom_right_values)
63 | splatted = splatted.reshape(b, c, h, w)
64 |
65 | def __weighted_average_splat__(depth, weights, epsilon=1e-8):
66 | zero_weights = (weights <= epsilon).detach().type(depth.dtype)
67 | return depth / (weights + epsilon * zero_weights)
68 |
69 | def __depth_distance_weights__(depth, max_depth=20.0):
70 | weights = 1.0 / torch.exp(2 * depth / max_depth)
71 | return weights
72 |
73 | def render(img, depth, coords, max_depth=20.0):
74 | splatted_img = torch.zeros_like(img)
75 | splatted_wgts = torch.zeros_like(depth)
76 | weights = __depth_distance_weights__(depth, max_depth=max_depth)
77 | __splat__(img * weights, coords, splatted_img)
78 | __splat__(weights, coords, splatted_wgts)
79 | recon = __weighted_average_splat__(splatted_img, splatted_wgts)
80 | mask = (splatted_wgts > 1e-3).detach()
81 | return recon, mask
82 |
83 | def render_to(src, tgt, wgts, depth, coords, max_depth=20.0):
84 | weights = __depth_distance_weights__(depth, max_depth=max_depth)
85 | __splat__(src * weights, coords, tgt)
86 | __splat__(weights, coords, wgts)
87 | tgt = __weighted_average_splat__(tgt, wgts)
88 | mask = (wgts > 1e-3).detach()
89 | return mask
--------------------------------------------------------------------------------
/supervision/ssim.py:
--------------------------------------------------------------------------------
1 | '''
2 | Code modified from https://github.com/Po-Hsun-Su/pytorch-ssim
3 | '''
4 |
5 | import torch
6 | import numpy
7 | import math
8 |
9 | def __gaussian__(kernel_size, std, data_type=torch.float32):
10 | gaussian = numpy.array([math.exp(-(x - kernel_size//2)**2/float(2*std**2)) for x in range(kernel_size)])
11 | gaussian /= numpy.sum(gaussian)
12 | return torch.tensor(gaussian, dtype=data_type)
13 |
14 | def __create_kernel__(kernel_size, data_type=torch.float32, channels=3, std=1.5):
15 | gaussian1d = __gaussian__(kernel_size, std).unsqueeze(1)
16 | gaussian2d = torch.mm(gaussian1d, gaussian1d.t())\
17 | .type(data_type)\
18 | .unsqueeze(0)\
19 | .unsqueeze(0)
20 | window = gaussian2d.expand(channels, 1, kernel_size, kernel_size).contiguous()
21 | return window
22 |
23 | def __ssim_gaussian__(prediction, groundtruth, kernel, kernel_size, channels=3):
24 | padding = kernel_size // 2
25 | prediction_mean = torch.nn.functional.conv2d(prediction, kernel, padding=padding, groups=channels)
26 | groundtruth_mean = torch.nn.functional.conv2d(groundtruth, kernel, padding=padding, groups=channels)
27 |
28 | prediction_mean_squared = prediction_mean.pow(2)
29 | groundtruth_mean_squared = groundtruth_mean.pow(2)
30 | prediction_mean_times_groundtruth_mean = prediction_mean * groundtruth_mean
31 |
32 | prediction_sigma_squared = torch.nn.functional.conv2d(prediction * prediction, kernel, padding=padding, groups=channels)\
33 | - prediction_mean_squared
34 | groundtruth_sigma_squared = torch.nn.functional.conv2d(groundtruth * groundtruth, kernel, padding=padding, groups=channels)\
35 | - groundtruth_mean_squared
36 | prediction_groundtruth_covariance = torch.nn.functional.conv2d(prediction * groundtruth, kernel, padding=padding, groups=channels)\
37 | - prediction_mean_times_groundtruth_mean
38 |
39 | C1 = 0.01**2 # assume that images are in the [0, 1] range
40 | C2 = 0.03**2 # assume that images are in the [0, 1] range
41 |
42 | return (
43 | ( # numerator
44 | (2 * prediction_mean_times_groundtruth_mean + C1) # luminance term
45 | * (2 * prediction_groundtruth_covariance + C2) # structural term
46 | )
47 | / # division
48 | ( # denominator
49 | (prediction_mean_squared + groundtruth_mean_squared + C1) # luminance term
50 | * (prediction_sigma_squared + groundtruth_sigma_squared + C2) # structural term
51 | )
52 | )
53 |
54 | def ssim_gaussian(prediction, groundtruth, kernel_size=11, std=1.5):
55 | (_, channels, _, _) = prediction.size()
56 | kernel = __create_kernel__(kernel_size, data_type=prediction.type(),\
57 | channels=channels, std=std)
58 |
59 | if prediction.is_cuda:
60 | kernel = kernel.to(prediction.get_device())
61 | kernel = kernel.type_as(prediction)
62 |
63 | return __ssim_gaussian__(prediction, groundtruth, kernel, kernel_size, channels)
64 |
65 | def ssim_box(prediction, groundtruth, kernel_size=3):
66 | C1 = 0.01 ** 2
67 | C2 = 0.03 ** 2
68 |
69 | prediction_mean = torch.nn.AvgPool2d(kernel_size, stride=1)(prediction)
70 | groundtruth_mean = torch.nn.AvgPool2d(kernel_size, stride=1)(groundtruth)
71 | prediction_groundtruth_mean = prediction_mean * groundtruth_mean
72 | prediction_mean_squared = prediction_mean.pow(2)
73 | groundtruth_mean_squared = groundtruth_mean.pow(2)
74 |
75 | prediction_sigma = torch.nn.AvgPool2d(kernel_size, stride=1)(prediction * prediction) - prediction_mean_squared
76 | groundtruth_sigma = torch.nn.AvgPool2d(kernel_size, stride=1)(groundtruth * groundtruth) - groundtruth_mean_squared
77 | correlation = torch.nn.AvgPool2d(kernel_size, stride=1)(prediction * groundtruth) - prediction_groundtruth_mean
78 |
79 | numerator = (2 * prediction_groundtruth_mean + C1) * (2 * correlation + C2)
80 | denominator = (prediction_mean_squared + groundtruth_mean_squared + C1)\
81 | * (prediction_sigma + groundtruth_sigma + C2)
82 | ssim = numerator / denominator
83 | pad = kernel_size // 2
84 | return torch.nn.functional.pad(ssim, (pad, pad, pad, pad))
85 |
86 | def ssim_loss(prediction, groundtruth, kernel_size=5, std=1.5, mode='gaussian'):
87 | if mode == 'gaussian':
88 | return ssim_gaussian(prediction, groundtruth, kernel_size=kernel_size, std=std)
89 | elif mode == 'box':
90 |         return ssim_box(prediction, groundtruth, kernel_size=kernel_size)
91 |     else:
92 |         raise ValueError("Unsupported SSIM mode: {}".format(mode))
--------------------------------------------------------------------------------
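A quick sanity check for the SSIM functions above; note that `ssim_loss` returns the per-pixel similarity map (values near 1 for identical inputs), so a dissimilarity term in the usual photometric-loss form would be `(1 - ssim) / 2`. A self-contained sketch:

```python
import torch

from supervision.ssim import ssim_loss

prediction = torch.rand(2, 3, 64, 64)
groundtruth = prediction.clone()

# per-pixel SSIM map; ~1 everywhere since the inputs are identical
ssim_map = ssim_loss(prediction, groundtruth, kernel_size=5, std=1.5, mode='gaussian')
print(ssim_map.shape)   # torch.Size([2, 3, 64, 64])
print(ssim_map.mean())  # ~1.0

# the usual photometric dissimilarity term built from the SSIM map
dissimilarity = ((1.0 - ssim_map) / 2.0).mean()
print(dissimilarity)    # ~0.0
```
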
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import numpy
5 |
6 | import torch
7 | import torchvision
8 |
9 | import models
10 | import dataset
11 | import utils
12 | from filesystem import file_utils
13 |
14 | import supervision as L
15 | import exporters as IO
16 | import spherical as S360
17 |
18 | def parse_arguments(args):
19 | usage_text = (
20 | "Semi-supervised Spherical Depth Estimation Testing."
21 | )
22 | parser = argparse.ArgumentParser(description=usage_text)
23 | # enumerables
24 | parser.add_argument('-b',"--batch_size", type=int, help="Test a number of samples each iteration.")
25 | parser.add_argument('--save_iters', type=int, default=100, help='Maximum test iterations whose results will be saved.')
26 | # paths
27 | parser.add_argument("--test_path", type=str, help="Path to the testing file containing the test set file paths")
28 | parser.add_argument("--save_path", type=str, help="Path to the folder where the models and results will be saved at.")
29 | # model
30 |     parser.add_argument("--configuration", required=False, type=str, default='mono', help="Data loader configuration (one of: mono, lr, ud, tc).", choices=['mono', 'lr', 'ud', 'tc'])
31 | parser.add_argument('--weights', type=str, help='Path to the trained weights file.')
32 | parser.add_argument('--model', default="default", type=str, help='Model selection argument.')
33 | # hardware
34 | parser.add_argument('-g','--gpu', type=str, default='0', help='The ids of the GPU(s) that will be utilized. (e.g. 0 or 0,1, or 0,2). Use -1 for CPU.')
35 | # other
36 | parser.add_argument('-n','--name', type=str, default='default_name', help='The name of this train/test. Used when storing information.')
37 | parser.add_argument("--visdom", type=str, nargs='?', default=None, const="127.0.0.1", help="Visdom server IP (port defaults to 8097)")
38 | parser.add_argument("--visdom_iters", type=int, default=400, help = "Iteration interval that results will be reported at the visdom server for visualization.")
39 | # metrics
40 | parser.add_argument("--depth_thres", type=float, default=20.0, help = "Depth threshold - depth clipping.")
41 |     parser.add_argument("--width", type=int, default=512, help = "Spherical image width.")
42 | parser.add_argument("--baseline", type=float, default=0.26, help = "Stereo baseline distance (in either axis).")
43 | parser.add_argument("--median_scale", required=False, default=False, action="store_true", help = "Perform median scaling before calculating metrics.")
44 | parser.add_argument("--spherical_weights", required=False, default=False, action="store_true", help = "Use spherical weighting when calculating the metrics.")
45 | parser.add_argument("--spherical_sampling", required=False, default=False, action="store_true", help = "Use spherical sampling when calculating the metrics.")
46 | # save options
47 | parser.add_argument("--save_recon", required=False, default=False, action="store_true", help = "Flag to toggle reconstructed result saving.")
48 | parser.add_argument("--save_original", required=False, default=False, action="store_true", help = "Flag to toggle input (image) saving.")
49 | parser.add_argument("--save_depth", required=False, default=False, action="store_true", help = "Flag to toggle output (depth) saving.")
50 | return parser.parse_known_args(args)
51 |
52 | def compute_errors(gt, pred, invalid_mask, weights, sampling, mode='cpu', median_scale=False):
53 | b, _, __, ___ = gt.size()
54 | scale = torch.median(gt.reshape(b, -1), dim=1)[0] / torch.median(pred.reshape(b, -1), dim=1)[0]\
55 | if median_scale else torch.tensor(1.0).expand(b, 1, 1, 1).to(gt.device)
56 | pred = pred * scale.reshape(b, 1, 1, 1)
57 | valid_sum = torch.sum(~invalid_mask, dim=[1, 2, 3], keepdim=True)
58 | gt[invalid_mask] = 0.0
59 | pred[invalid_mask] = 0.0
60 | thresh = torch.max((gt / pred), (pred / gt))
61 | thresh[invalid_mask | (sampling < 0.5)] = 2.0
62 |
63 | sum_dims = [1, 2, 3]
64 | delta_valid_sum = torch.sum(~invalid_mask & (sampling > 0), dim=[1, 2, 3], keepdim=True)
65 | delta1 = (thresh < 1.25 ).float().sum(dim=sum_dims, keepdim=True).float() / delta_valid_sum.float()
66 | delta2 = (thresh < (1.25 ** 2)).float().sum(dim=sum_dims, keepdim=True).float() / delta_valid_sum.float()
67 | delta3 = (thresh < (1.25 ** 3)).float().sum(dim=sum_dims, keepdim=True).float() / delta_valid_sum.float()
68 |
69 | rmse = (gt - pred) ** 2
70 | rmse[invalid_mask] = 0.0
71 | rmse_w = rmse * weights
72 | rmse_mean = torch.sqrt(rmse_w.sum(dim=sum_dims, keepdim=True) / valid_sum.float())
73 |
74 | rmse_log = (torch.log(gt) - torch.log(pred)) ** 2
75 | rmse_log[invalid_mask] = 0.0
76 | rmse_log_w = rmse_log * weights
77 | rmse_log_mean = torch.sqrt(rmse_log_w.sum(dim=sum_dims, keepdim=True) / valid_sum.float())
78 |
79 | abs_rel = (torch.abs(gt - pred) / gt)
80 | abs_rel[invalid_mask] = 0.0
81 | abs_rel_w = abs_rel * weights
82 | abs_rel_mean = abs_rel_w.sum(dim=sum_dims, keepdim=True) / valid_sum.float()
83 |
84 | sq_rel = (((gt - pred)**2) / gt)
85 | sq_rel[invalid_mask] = 0.0
86 | sq_rel_w = sq_rel * weights
87 | sq_rel_mean = sq_rel_w.sum(dim=sum_dims, keepdim=True) / valid_sum.float()
88 |
89 | return (abs_rel_mean, abs_rel), (sq_rel_mean, sq_rel), (rmse_mean, rmse), \
90 | (rmse_log_mean, rmse_log), delta1, delta2, delta3
91 |
92 | def spiral_sampling(grid, percentage):
93 | b, c, h, w = grid.size()
94 | N = torch.tensor(h*w*percentage).int().float()
95 | sampling = torch.zeros_like(grid)[:, 0, :, :].unsqueeze(1)
96 | phi_k = torch.tensor(0.0).float()
97 | for k in torch.arange(N - 1):
98 | k = k.float() + 1.0
99 | h_k = -1 + 2 * (k - 1) / (N - 1)
100 | theta_k = torch.acos(h_k)
101 | phi_k = phi_k + torch.tensor(3.6).float() / torch.sqrt(N) / torch.sqrt(1 - h_k * h_k) \
102 | if k > 1.0 else torch.tensor(0.0).float()
103 | phi_k = torch.fmod(phi_k, 2 * numpy.pi)
104 | sampling[:, :, int(theta_k / numpy.pi * h) - 1, int(phi_k / numpy.pi / 2 * w) - 1] += 1.0
105 | return (sampling > 0).float()
106 |
107 | if __name__ == "__main__":
108 | args, unknown = parse_arguments(sys.argv)
109 | gpus = [int(id) for id in args.gpu.split(',') if int(id) >= 0]
110 | # device & visualizers
111 | device = torch.device("cuda:{}" .format(gpus[0])\
112 | if torch.cuda.is_available() and len(gpus) > 0 and gpus[0] >= 0\
113 | else "cpu")
114 | plot_visualizer, image_visualizer = (utils.NullVisualizer(), utils.NullVisualizer())\
115 | if args.visdom is None\
116 | else (
117 | utils.VisdomPlotVisualizer(args.name + "_test_plots_", args.visdom),
118 | utils.VisdomImageVisualizer(args.name + "_test_images_", args.visdom,\
119 | count=2 if 2 <= args.batch_size else args.batch_size)
120 | )
121 | image_visualizer.update_epoch(0)
122 | # model
123 | model_params = { 'width': 512, 'configuration': args.configuration }
124 | model = models.get_model(args.model, model_params)
125 | utils.init.initialize_weights(model, args.weights, pred_bias=None)
126 | if (len(gpus) > 1):
127 | model = torch.nn.parallel.DataParallel(model, gpus)
128 | model = model.to(device)
129 | # test data
130 | width, height = args.width, args.width // 2
131 | test_data = dataset.dataset_360D.Dataset360D(args.test_path, " ", args.configuration, [height, width])
132 | test_data_iterator = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size,\
133 | num_workers=args.batch_size // 4 // (2 if len(gpus) > 0 else 1), pin_memory=False, shuffle=False)
134 | fs = file_utils.Filesystem()
135 | fs.mkdir(args.save_path)
136 | print("Test size : {}".format(args.batch_size * test_data_iterator.__len__()))
137 | # params & error vars
138 | max_save_iters = args.save_iters if args.save_iters > 0\
139 | else args.batch_size * test_data_iterator.__len__()
140 | errors = numpy.zeros((7, args.batch_size * test_data_iterator.__len__()), numpy.float32)
141 | weights = S360.weights.theta_confidence(
142 | S360.grid.create_spherical_grid(width)
143 | ).to(device) if args.spherical_weights else torch.ones(1, 1, height, width).to(device)
144 | sampling = spiral_sampling(S360.grid.create_image_grid(width, height), 0.25).to(device) \
145 | if args.spherical_sampling else torch.ones(1, 1, height, width).to(device)
146 | # loop over test set
147 | model.eval()
148 | with torch.no_grad():
149 | counter = 0
150 | uvgrid = S360.grid.create_image_grid(width, height).to(device)
151 | sgrid = S360.grid.create_spherical_grid(width).to(device)
152 | for test_batch_id , test_batch in enumerate(test_data_iterator):
153 | ''' Data '''
154 | left_rgb = test_batch['leftRGB'].to(device)
155 | left_depth = test_batch['leftDepth'].to(device)
156 | if 'rightRGB' in test_batch:
157 | right_rgb = test_batch['rightRGB'].to(device)
158 | mask = (left_depth > args.depth_thres)
159 | b, c, h, w = left_rgb.size()
160 | ''' Prediction '''
161 | left_depth_pred = torch.abs(model(left_rgb))
162 | ''' Errors '''
163 | abs_rel_t, sq_rel_t, rmse_t, rmse_log_t, delta1, delta2, delta3\
164 | = compute_errors(left_depth, left_depth_pred, mask, weights=weights, sampling=sampling, \
165 | mode='gpu' if torch.cuda.is_available() and len(gpus) > 0 and gpus[0] >= 0 else "cpu", \
166 | median_scale=args.median_scale)
167 | ''' Visualize & Append Errors '''
168 | for i in range(b):
169 | idx = counter + i
170 | errors[:, idx] = abs_rel_t[0][i], sq_rel_t[0][i], rmse_t[0][i], \
171 | rmse_log_t[0][i], delta1[i], delta2[i], delta3[i]
172 |                 # append each error metric once per sample
173 |                 plot_visualizer.append_loss(1, idx, torch.tensor(errors[0, idx]), "abs_rel")
174 |                 plot_visualizer.append_loss(1, idx, torch.tensor(errors[1, idx]), "sq_rel")
175 |                 plot_visualizer.append_loss(1, idx, torch.tensor(errors[2, idx]), "rmse")
176 |                 plot_visualizer.append_loss(1, idx, torch.tensor(errors[3, idx]), "rmse_log")
177 |                 plot_visualizer.append_loss(1, idx, torch.tensor(errors[4, idx]), "delta1")
178 |                 plot_visualizer.append_loss(1, idx, torch.tensor(errors[5, idx]), "delta2")
179 |                 plot_visualizer.append_loss(1, idx, torch.tensor(errors[6, idx]), "delta3")
180 | ''' Store '''
181 | if counter < args.save_iters:
182 | if args.save_original:
183 | IO.image.save_image(os.path.join(args.save_path,\
184 | str(counter) + "_" + args.name + "_#_left.png"), left_rgb)
185 | if args.save_depth:
186 | IO.image.save_data(os.path.join(args.save_path,\
187 | str(counter) + "_" + args.name + "_#_depth.exr"), left_depth_pred, scale=1.0)
188 | if args.save_recon:
189 | rads = sgrid.expand(b, -1, -1, -1)
190 | uv = uvgrid.expand(b, -1, -1, -1)
191 | disp = torch.cat(
192 | (
193 | S360.derivatives.dphi_horizontal(rads, left_depth_pred, args.baseline),
194 | S360.derivatives.dtheta_horizontal(rads, left_depth_pred, args.baseline)
195 | ), dim=1
196 | )
197 | right_render_coords = uv + disp
198 | right_render_coords[:, 0, :, :] = torch.fmod(right_render_coords[:, 0, :, :] + width, width)
199 | right_render_coords[torch.isnan(right_render_coords)] = 0.0
200 | right_render_coords[torch.isinf(right_render_coords)] = 0.0
201 | right_rgb_t, right_mask_t = L.splatting.render(left_rgb, left_depth_pred, right_render_coords, max_depth=args.depth_thres)
202 | IO.image.save_image(os.path.join(args.save_path,\
203 | str(counter) + "_" + args.name + "_#_right_t.png"), right_rgb_t)
204 | counter += b
205 | ''' Visualize Predictions '''
206 | if args.visdom_iters > 0 and (counter + 1) % args.visdom_iters <= args.batch_size:
207 | image_visualizer.show_separate_images(left_rgb, 'input')
208 | if 'rightRGB' in test_batch:
209 | image_visualizer.show_separate_images(right_rgb, 'target')
210 | image_visualizer.show_map(left_depth_pred, 'depth')
211 | if args.save_recon:
212 | image_visualizer.show_separate_images(right_rgb_t, 'recon')
213 | mean_errors = errors.mean(1)
214 | error_names = ['abs_rel','sq_rel','rmse','log_rmse','delta1','delta2','delta3']
215 | print("Results ({}): ".format(args.name))
216 | print("\t{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(*error_names))
217 | print("\t{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".format(*mean_errors))
218 |
219 |
220 |
--------------------------------------------------------------------------------
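The delta accuracy metrics in `compute_errors` above reduce to a symmetric ratio test between prediction and ground truth; a minimal standalone illustration with synthetic tensors (no repository dependencies):

```python
import torch

gt = torch.rand(1, 1, 32, 64) * 10.0 + 0.1
pred = gt * 1.1  # uniformly 10% over-estimated

# symmetric ratio: max(gt/pred, pred/gt) >= 1, equal to 1 for a perfect match
thresh = torch.max(gt / pred, pred / gt)
delta1 = (thresh < 1.25).float().mean()
delta2 = (thresh < 1.25 ** 2).float().mean()
print(delta1.item(), delta2.item())  # 1.0 1.0 (a 10% error passes both thresholds)
```
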
/train_lr.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import numpy
5 |
6 | import torch
7 |
8 | import models
9 | import dataset
10 | import utils
11 |
12 | import supervision as L
13 | import exporters as IO
14 | import spherical as S360
15 |
16 | def parse_arguments(args):
17 | usage_text = (
18 | "Omnidirectional Horizontal Stereo Placement (Left-Right , LR) Training."
19 | )
20 | parser = argparse.ArgumentParser(description=usage_text)
21 | # durations
22 | parser.add_argument('-e',"--epochs", type=int, help="Train for a total number of epochs.")
23 | parser.add_argument('-b',"--batch_size", type=int, help="Train with a number of samples each train iteration.")
24 | parser.add_argument("--test_batch_size", default=1, type=int, help="Test with a number of samples each test iteration.")
25 | parser.add_argument('-d','--disp_iters', type=int, default=50, help='Log training progress (i.e. loss etc.) on console every iterations.')
26 | parser.add_argument('--save_iters', type=int, default=100, help='Maximum test iterations to perform each test run.')
27 | # paths
28 |     parser.add_argument("--train_path", type=str, help="Path to the training file containing the train set file paths")
29 |     parser.add_argument("--test_path", type=str, help="Path to the testing file containing the test set file paths")
30 |     parser.add_argument("--save_path", type=str, help="Path to the folder where the models and results will be saved at.")
31 |     # model
32 |     parser.add_argument("--configuration", required=False, type=str, default='mono', help="Data loader configuration (one of: mono, lr, ud, tc).", choices=['mono', 'lr', 'ud', 'tc'])
31 | parser.add_argument('--weight_init', type=str, default="xavier", help='Weight initialization method, or path to weights file (for fine-tuning or continuing training)')
32 | parser.add_argument('--model', default="default", type=str, help='Model selection argument.')
33 | # optimization
34 | parser.add_argument('-o','--optimizer', type=str, default="adam", help='The optimizer that will be used during training.')
35 | parser.add_argument("--opt_state", type=str, help="Path to stored optimizer state file to continue training)")
36 | parser.add_argument('-l','--lr', type=float, default=0.0002, help='Optimization Learning Rate.')
37 | parser.add_argument('-m','--momentum', type=float, default=0.9, help='Optimization Momentum.')
38 | parser.add_argument('--momentum2', type=float, default=0.999, help='Optimization Second Momentum (optional, only used by some optimizers).')
39 | parser.add_argument('--epsilon', type=float, default=1e-8, help='Optimization Epsilon (optional, only used by some optimizers).')
40 | parser.add_argument('--weight_decay', type=float, default=0, help='Optimization Weight Decay.')
41 | # hardware
42 | parser.add_argument('-g','--gpu', type=str, default='0', help='The ids of the GPU(s) that will be utilized. (e.g. 0 or 0,1, or 0,2). Use -1 for CPU.')
43 | # other
44 | parser.add_argument('-n','--name', type=str, default='default_name', help='The name of this train/test. Used when storing information.')
45 | parser.add_argument("--visdom", type=str, nargs='?', default=None, const="127.0.0.1", help="Visdom server IP (port defaults to 8097)")
46 | parser.add_argument("--visdom_iters", type=int, default=400, help = "Iteration interval that results will be reported at the visdom server for visualization.")
47 | parser.add_argument("--seed", type=int, default=1337, help="Fixed manual seed, zero means no seeding.")
48 | # network specific params
49 | parser.add_argument("--photo_w", type=float, default=1.0, help = "Photometric loss weight.")
50 | parser.add_argument("--smooth_reg_w", type=float, default=0.1, help = "Smoothness regularization weight.")
51 | parser.add_argument("--ssim_window", type=int, default=7, help = "Kernel size to use in SSIM calculation.")
52 | parser.add_argument("--ssim_mode", type=str, default='gaussian', help = "Type of SSIM averaging (either gaussian or box).")
53 | parser.add_argument("--ssim_std", type=float, default=1.5, help = "SSIM standard deviation value used when creating the gaussian averaging kernels.")
54 | parser.add_argument("--ssim_alpha", type=float, default=0.85, help = "Alpha factor to weight the SSIM and L1 losses, where a x SSIM and (1 - a) x L1.")
55 | parser.add_argument("--pred_bias", type=float, default=5.0, help = "Initialize prediction layers' bias to the given value (helps convergence).")
56 | # details
57 | parser.add_argument("--depth_thres", type=float, default=20.0, help = "Depth threshold - depth clipping.")
58 | parser.add_argument("--baseline", type=float, default=0.26, help = "Stereo baseline distance (in either axis).")
59 |     parser.add_argument("--width", type=int, default=512, help = "Spherical image width.")
60 | return parser.parse_known_args(args)
61 |
62 | if __name__ == "__main__":
63 | args, unknown = parse_arguments(sys.argv)
64 | gpus = [int(id) for id in args.gpu.split(',') if int(id) >= 0]
65 | # device & visualizers
66 | device, visualizers, model_params = utils.initialize(args)
67 | plot_viz = visualizers[0]
68 | img_viz = visualizers[1]
69 | # model
70 | model = models.get_model(args.model, model_params)
71 | utils.init.initialize_weights(model, args.weight_init, pred_bias=args.pred_bias)
72 | if (len(gpus) > 1):
73 | model = torch.nn.parallel.DataParallel(model, gpus)
74 | model = model.to(device)
75 | # optimizer
76 | optimizer = utils.init_optimizer(model, args)
77 | # train data
78 | train_data = dataset.dataset_360D.Dataset360D(args.train_path, " ", args.configuration, [256, 512])
79 | train_data_iterator = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,\
80 |         num_workers=args.batch_size // max(len(gpus), 1), pin_memory=False, shuffle=True)
81 | # test data
82 | test_data = dataset.dataset_360D.Dataset360D(args.test_path, " ", args.configuration, [256, 512])
83 | test_data_iterator = torch.utils.data.DataLoader(test_data, batch_size=args.test_batch_size,\
84 |         num_workers=args.batch_size // max(len(gpus), 1), pin_memory=False, shuffle=True)
85 | print("Data size : {0} | Test size : {1}".format(\
86 | args.batch_size * train_data_iterator.__len__(), \
87 | args.test_batch_size * test_data_iterator.__len__()))
88 | # params
89 | width = args.width
90 | height = args.width // 2
91 | photo_params = L.photometric.PhotometricLossParameters(
92 | alpha=args.ssim_alpha, l1_estimator='none', ssim_estimator='none',
93 | ssim_mode=args.ssim_mode, std=args.ssim_std, window=args.ssim_window
94 | )
95 | iteration_counter = 0
96 | # meters
97 | total_loss = utils.AverageMeter()
98 | running_photo_loss = utils.AverageMeter()
99 | running_depth_smooth_loss = utils.AverageMeter()
100 | # train / test loop
101 | model.train()
102 | plot_viz.config(**vars(args))
103 | for epoch in range(args.epochs):
104 | print("Training | Epoch: {}".format(epoch))
105 | img_viz.update_epoch(epoch)
106 | for batch_id, batch in enumerate(train_data_iterator):
107 | optimizer.zero_grad()
108 | active_loss = torch.tensor(0.0).to(device)
109 | ''' Data '''
110 | left_rgb = batch['leftRGB'].to(device)
111 | b, _, __, ___ = left_rgb.size()
112 | expand_size = (b, -1, -1, -1)
113 | sgrid = S360.grid.create_spherical_grid(width).to(device)
114 | uvgrid = S360.grid.create_image_grid(width, height).to(device)
115 | right_rgb = batch['rightRGB'].to(device)
116 | left_depth = batch['leftDepth'].to(device)
117 | right_depth = batch['rightDepth'].to(device)
118 | ''' Prediction '''
119 | left_depth_pred = torch.abs(model(left_rgb))
120 | ''' Forward Rendering LR '''
121 | disp = torch.cat(
122 | (
123 | S360.derivatives.dphi_horizontal(sgrid, left_depth_pred, args.baseline),
124 | S360.derivatives.dtheta_horizontal(sgrid, left_depth_pred, args.baseline)
125 | ),
126 | dim=1
127 | )
128 | right_render_coords = uvgrid + disp
129 | right_render_coords[:, 0, :, :] = torch.fmod(right_render_coords[:, 0, :, :] + width, width)
130 | right_render_coords[torch.isnan(right_render_coords)] = 0.0
131 | right_render_coords[torch.isinf(right_render_coords)] = 0.0
132 | right_rgb_t, right_mask_t = L.splatting.render(left_rgb, left_depth_pred,\
133 | right_render_coords, max_depth=args.depth_thres)
134 | ''' Loss LR '''
135 | right_cutoff_mask = (right_depth < args.depth_thres)
136 | right_mask_t &= ~(right_depth > args.depth_thres)
137 | attention_weights = S360.weights.phi_confidence(
138 | S360.grid.create_spherical_grid(width)).to(device)
139 | # attention_weights = S360.weights.spherical_confidence(
140 | # S360.grid.create_spherical_grid(width), zero_low=0.001
141 | # ).to(device)
142 | # attention_weights = torch.ones_like(left_depth)
143 | photo_loss = L.photometric.calculate_loss(right_rgb_t, right_rgb, photo_params,
144 | mask=right_cutoff_mask, weights=attention_weights)
145 | active_loss += photo_loss * args.photo_w
146 | ''' Loss Prior (3D Smoothness) '''
147 | left_xyz = S360.cartesian.coords_3d(sgrid, left_depth_pred)
148 | dI_dxyz = S360.derivatives.dV_dxyz(left_xyz)
149 | guidance_duv = S360.derivatives.dI_duv(left_rgb)
150 | # attention_weights = torch.zeros_like(left_depth)
151 | depth_smooth_loss = L.smoothness.guided_smoothness_loss(
152 | dI_dxyz, guidance_duv, right_cutoff_mask, (1.0 - attention_weights)
153 | * right_cutoff_mask.type(attention_weights.dtype)
154 | )
155 | active_loss += depth_smooth_loss * args.smooth_reg_w
156 | ''' Update Params '''
157 | active_loss.backward()
158 | optimizer.step()
159 | ''' Visualize'''
160 | total_loss.update(active_loss)
161 | running_depth_smooth_loss.update(depth_smooth_loss)
162 | running_photo_loss.update(photo_loss)
163 | iteration_counter += b
164 | if (iteration_counter + 1) % args.disp_iters <= args.batch_size:
165 | print("Epoch: {}, iteration: {}\nPhotometric: {}\nSmoothness: {}\nTotal average loss: {}\n"\
166 | .format(epoch, iteration_counter, running_photo_loss.avg, \
167 | running_depth_smooth_loss.avg, total_loss.avg))
168 | plot_viz.append_loss(epoch + 1, iteration_counter, total_loss.avg, "avg")
169 | plot_viz.append_loss(epoch + 1, iteration_counter, running_photo_loss.avg, "photo")
170 | plot_viz.append_loss(epoch + 1, iteration_counter, running_depth_smooth_loss.avg, "smooth")
171 | total_loss.reset()
172 | running_photo_loss.reset()
173 | running_depth_smooth_loss.reset()
174 | if args.visdom_iters > 0 and (iteration_counter + 1) % args.visdom_iters <= args.batch_size:
175 | img_viz.show_separate_images(left_rgb, 'input')
176 | img_viz.show_separate_images(right_rgb, 'target')
177 | img_viz.show_map(left_depth_pred, 'depth')
178 | img_viz.show_separate_images(torch.clamp(right_rgb_t, min=0.0, max=1.0), 'recon')
179 | ''' Save '''
180 | print("Saving model @ epoch #" + str(epoch))
181 | utils.checkpoint.save_network_state(model, optimizer, epoch,\
182 | args.name + "_model_state", args.save_path)
183 | ''' Test '''
184 | print("Testing model @ epoch #" + str(epoch))
185 | model.eval()
186 | with torch.no_grad():
187 | rmse_avg = torch.tensor(0.0).float()
188 | counter = torch.tensor(0.0).float()
189 | for test_batch_id , test_batch in enumerate(test_data_iterator):
190 | left_rgb = test_batch['leftRGB'].to(device)
191 | b, c, h, w = left_rgb.size()
192 | rads = sgrid.expand(b, -1, -1, -1)
193 | uv = uvgrid.expand(b, -1, -1, -1)
194 | left_depth_pred = torch.abs(model(left_rgb))
195 | left_depth = test_batch['leftDepth'].to(device)
196 | left_depth[torch.isnan(left_depth)] = 50.0
197 | left_depth[torch.isinf(left_depth)] = 50.0
198 |                 mse = (left_depth_pred - left_depth) ** 2  # per-pixel squared error
199 | mse[torch.isnan(mse)] = 0.0
200 | mse[torch.isinf(mse)] = 0.0
201 | mask = (left_depth < args.depth_thres).float()
202 | if torch.sum(mask) == 0:
203 | continue
204 | rmse = torch.sqrt(torch.sum(mse * mask) / torch.sum(mask).float())
205 | if not torch.isnan(rmse):
206 | rmse_avg += rmse.cpu().float()
207 | counter += torch.tensor(b).float()
208 | if counter < args.save_iters:
209 | disp = torch.cat(
210 | (
211 | S360.derivatives.dphi_horizontal(rads, left_depth_pred, args.baseline),
212 | S360.derivatives.dtheta_horizontal(rads, left_depth_pred, args.baseline)
213 | ), dim=1
214 | )
215 | right_render_coords = uv + disp
216 | right_render_coords[:, 0, :, :] = torch.fmod(right_render_coords[:, 0, :, :] + width, width)
217 | right_render_coords[torch.isnan(right_render_coords)] = 0.0
218 | right_render_coords[torch.isinf(right_render_coords)] = 0.0
219 | right_rgb_t, right_mask_t = L.splatting.render(left_rgb, left_depth_pred, right_render_coords, max_depth=args.depth_thres)
220 | IO.image.save_image(os.path.join(args.save_path,\
221 | str(epoch) + "_" + str(counter) + "_#_left.png"), left_rgb)
222 | IO.image.save_image(os.path.join(args.save_path,\
223 | str(epoch) + "_" + str(counter) + "_#_right_t.png"), right_rgb_t)
224 | IO.image.save_data(os.path.join(args.save_path,\
225 | str(epoch) + "_" + str(counter) + "_#_depth.exr"), left_depth_pred, scale=1.0)
226 | rmse_avg /= counter
227 | print("Testing epoch {}: RMSE = {}".format(epoch+1, rmse_avg))
228 | plot_viz.append_loss(epoch + 1, epoch + 1, rmse_avg, "rmse", mode='test')
229 |         # gradients are re-enabled automatically once the no_grad() context exits
230 | model.train()
231 |
--------------------------------------------------------------------------------
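The heart of the left-right supervision in `train_lr.py` is the depth-to-disparity warp that produces the right-view splatting coordinates; a compact sketch of just that step, mirroring the training loop above with a random depth map standing in for a network prediction:

```python
import torch

import spherical as S360

width, height, baseline = 512, 256, 0.26
sgrid = S360.grid.create_spherical_grid(width)
uvgrid = S360.grid.create_image_grid(width, height)

depth = torch.rand(1, 1, height, width) * 5.0 + 0.5
disp = torch.cat(
    (
        S360.derivatives.dphi_horizontal(sgrid, depth, baseline),
        S360.derivatives.dtheta_horizontal(sgrid, depth, baseline)
    ), dim=1
)
coords = uvgrid + disp
# wrap the horizontal coordinate around the equirectangular seam
coords[:, 0, :, :] = torch.fmod(coords[:, 0, :, :] + width, width)
print(coords.shape)  # torch.Size([1, 2, 256, 512])
```
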
/train_sv.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import numpy
5 |
6 | import torch
7 | import torchvision
8 |
9 | import models
10 | import dataset
11 | import utils
12 |
13 | import supervision as L
14 | import exporters as IO
15 | import spherical as S360
16 |
17 | def parse_arguments(args):
18 | usage_text = (
19 | "Omnidirectional Supervised (SV) Training."
20 | )
21 | parser = argparse.ArgumentParser(description=usage_text)
22 | # durations
23 | parser.add_argument('-e',"--epochs", type=int, help="Train for a total number of epochs.")
24 | parser.add_argument('-b',"--batch_size", type=int, help="Train with a number of samples each train iteration.")
25 | parser.add_argument("--test_batch_size", default=1, type=int, help="Test with a number of samples each test iteration.")
26 | parser.add_argument('-d','--disp_iters', type=int, default=50, help='Log training progress (i.e. loss etc.) on console every iterations.')
27 | parser.add_argument('--save_iters', type=int, default=100, help='Maximum test iterations to perform each test run.')
28 | # paths
29 |     parser.add_argument("--train_path", type=str, help="Path to the training file containing the train set file paths")
30 | parser.add_argument("--test_path", type=str, help="Path to the testing file containing the test set file paths")
31 | parser.add_argument("--save_path", type=str, help="Path to the folder where the models and results will be saved at.")
32 | # model
33 |     parser.add_argument("--configuration", required=False, type=str, default='mono', help="Data loader configuration (one of: mono, lr, ud, tc).", choices=['mono', 'lr', 'ud', 'tc'])
34 | parser.add_argument('--weight_init', type=str, default="xavier", help='Weight initialization method, or path to weights file (for fine-tuning or continuing training)')
35 | parser.add_argument('--model', default="default", type=str, help='Model selection argument.')
36 | # optimization
37 | parser.add_argument('-o','--optimizer', type=str, default="adam", help='The optimizer that will be used during training.')
38 | parser.add_argument("--opt_state", type=str, help="Path to stored optimizer state file to continue training)")
39 | parser.add_argument('-l','--lr', type=float, default=0.0002, help='Optimization Learning Rate.')
40 | parser.add_argument('-m','--momentum', type=float, default=0.9, help='Optimization Momentum.')
41 | parser.add_argument('--momentum2', type=float, default=0.999, help='Optimization Second Momentum (optional, only used by some optimizers).')
42 | parser.add_argument('--epsilon', type=float, default=1e-8, help='Optimization Epsilon (optional, only used by some optimizers).')
43 | parser.add_argument('--weight_decay', type=float, default=0, help='Optimization Weight Decay.')
44 | # hardware
45 | parser.add_argument('-g','--gpu', type=str, default='0', help='The ids of the GPU(s) that will be utilized. (e.g. 0 or 0,1, or 0,2). Use -1 for CPU.')
46 | # other
47 | parser.add_argument('-n','--name', type=str, default='default_name', help='The name of this train/test. Used when storing information.')
48 | parser.add_argument("--visdom", type=str, nargs='?', default=None, const="127.0.0.1", help="Visdom server IP (port defaults to 8097)")
49 | parser.add_argument("--visdom_iters", type=int, default=400, help = "Iteration interval that results will be reported at the visdom server for visualization.")
50 | parser.add_argument("--seed", type=int, default=1337, help="Fixed manual seed, zero means no seeding.")
51 | # network specific params
52 |     parser.add_argument("--depth_w", type=float, default=1.0, help = "Direct depth supervision (BerHu) loss weight.")
53 | parser.add_argument("--smooth_reg_w", type=float, default=0.1, help = "Smoothness regularization weight.")
54 | parser.add_argument("--ssim_window", type=int, default=7, help = "Kernel size to use in SSIM calculation.")
55 | parser.add_argument("--ssim_mode", type=str, default='gaussian', help = "Type of SSIM averaging (either gaussian or box).")
56 | parser.add_argument("--ssim_std", type=float, default=1.5, help = "SSIM standard deviation value used when creating the gaussian averaging kernels.")
57 | parser.add_argument("--ssim_alpha", type=float, default=0.85, help = "Alpha factor to weight the SSIM and L1 losses, where a x SSIM and (1 - a) x L1.")
58 | parser.add_argument("--pred_bias", type=float, default=5.0, help = "Initialize prediction layers' bias to the given value (helps convergence).")
59 | # details
60 | parser.add_argument("--depth_thres", type=float, default=20.0, help = "Depth threshold - depth clipping.")
61 | parser.add_argument("--baseline", type=float, default=0.26, help = "Stereo baseline distance (in either axis).")
62 |     parser.add_argument("--width", type=int, default=512, help = "Spherical image width.")
63 | return parser.parse_known_args(args)
64 |
65 | if __name__ == "__main__":
66 | args, unknown = parse_arguments(sys.argv)
67 | gpus = [int(id) for id in args.gpu.split(',') if int(id) >= 0]
68 | # device & visualizers
69 | device, visualizers, model_params = utils.initialize(args)
70 | plot_viz = visualizers[0]
71 | img_viz = visualizers[1]
72 | # model
73 | model = models.get_model(args.model, model_params)
74 | utils.init.initialize_weights(model, args.weight_init, pred_bias=args.pred_bias)
75 | if (len(gpus) > 1):
76 | model = torch.nn.parallel.DataParallel(model, gpus)
77 | model = model.to(device)
78 | # optimizer
79 | optimizer = utils.init_optimizer(model, args)
80 | # train data
81 | train_data = dataset.dataset_360D.Dataset360D(args.train_path, " ", args.configuration, [256, 512])
82 | train_data_iterator = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,\
83 |         num_workers=args.batch_size // max(len(gpus), 1), pin_memory=True, shuffle=True)
84 | # test data
85 | test_data = dataset.dataset_360D.Dataset360D(args.test_path, " ", args.configuration, [256, 512])
86 | test_data_iterator = torch.utils.data.DataLoader(test_data, batch_size=args.test_batch_size,\
87 |         num_workers=args.batch_size // max(len(gpus), 1), pin_memory=True, shuffle=True)
88 | print("Data size : {0} | Test size : {1}".format(\
89 | args.batch_size * train_data_iterator.__len__(), \
90 | args.test_batch_size * test_data_iterator.__len__()))
91 | # params
92 | width = args.width
93 | height = args.width // 2
94 | iteration_counter = 0
95 | # meters
96 | total_loss = utils.AverageMeter()
97 | running_depth_loss = utils.AverageMeter()
98 | running_depth_smooth_loss = utils.AverageMeter()
99 | # train / test loop
100 | model.train()
101 | plot_viz.config(**vars(args))
102 | for epoch in range(args.epochs):
103 | print("Training | Epoch: {}".format(epoch))
104 | img_viz.update_epoch(epoch)
105 | for batch_id, batch in enumerate(train_data_iterator):
106 | optimizer.zero_grad()
107 | active_loss = torch.tensor(0.0).to(device)
108 | ''' Data '''
109 | left_rgb = batch['leftRGB'].to(device)
110 | b, _, __, ___ = left_rgb.size()
111 | left_depth = batch['leftDepth'].to(device)
112 | ''' Prediction '''
113 | left_depth_pred = torch.abs(model(left_rgb))
114 | ''' Berhu Loss '''
115 | left_cutoff_mask = (left_depth < args.depth_thres)
116 | attention_weights = S360.weights.theta_confidence(
117 | S360.grid.create_spherical_grid(width)).to(device)
118 | # attention_weights = torch.ones_like(left_depth)
119 | depth_loss = L.direct.calculate_berhu_loss(left_depth_pred, left_depth,
120 | mask=left_cutoff_mask, weights=attention_weights)
121 | active_loss += depth_loss * args.depth_w
122 | ''' Loss Prior (3D Smoothness) '''
123 | left_xyz = S360.cartesian.coords_3d(
124 | S360.grid.create_spherical_grid(width).to(device), left_depth_pred)
125 | dI_dxyz = S360.derivatives.dV_dxyz(left_xyz)
126 | guidance_duv = S360.derivatives.dI_duv(left_rgb)
127 | # attention_weights = torch.zeros_like(left_depth)
128 | depth_smooth_loss = L.smoothness.guided_smoothness_loss(
129 | dI_dxyz, guidance_duv, left_cutoff_mask, (1.0 - attention_weights)
130 | * left_cutoff_mask.type(attention_weights.dtype)
131 | )
132 | active_loss += depth_smooth_loss * args.smooth_reg_w
133 | ''' Update Params '''
134 | active_loss.backward()
135 | optimizer.step()
136 | ''' Visualize'''
137 | total_loss.update(active_loss)
138 | running_depth_smooth_loss.update(depth_smooth_loss)
139 | running_depth_loss.update(depth_loss)
140 | iteration_counter += b
141 | if (iteration_counter + 1) % args.disp_iters <= args.batch_size:
142 | print("Epoch: {}, iteration: {}\nBerhu: {}\nSmoothness: {}\nTotal average loss: {}\n"\
143 | .format(epoch, iteration_counter, running_depth_loss.avg, \
144 | running_depth_smooth_loss.avg, total_loss.avg))
145 | plot_viz.append_loss(epoch + 1, iteration_counter, total_loss.avg, "avg")
146 | plot_viz.append_loss(epoch + 1, iteration_counter, running_depth_loss.avg, "berhu")
147 | plot_viz.append_loss(epoch + 1, iteration_counter, running_depth_smooth_loss.avg, "smooth")
148 | total_loss.reset()
149 | running_depth_loss.reset()
150 | running_depth_smooth_loss.reset()
151 | if args.visdom_iters > 0 and (iteration_counter + 1) % args.visdom_iters <= args.batch_size:
152 | img_viz.show_separate_images(left_rgb, 'input')
153 | img_viz.show_map(left_depth * left_cutoff_mask.float(), 'target')
154 | img_viz.show_map(left_depth_pred, 'depth')
155 | ''' Save '''
156 | print("Saving model @ epoch #" + str(epoch))
157 | utils.checkpoint.save_network_state(model, optimizer, epoch,\
158 | args.name + "_model_state", args.save_path)
159 | ''' Test '''
160 | print("Testing model @ epoch #" + str(epoch))
161 | model.eval()
162 | with torch.no_grad():
163 | rmse_avg = torch.tensor(0.0).float()
164 | counter = torch.tensor(0.0).float()
165 | for test_batch_id , test_batch in enumerate(test_data_iterator):
166 | left_rgb = test_batch['leftRGB'].to(device)
167 | b, c, h, w = left_rgb.size()
168 | left_depth_pred = torch.abs(model(left_rgb))
169 | left_depth = test_batch['leftDepth'].to(device)
170 | left_depth[torch.isnan(left_depth)] = 50.0
171 | left_depth[torch.isinf(left_depth)] = 50.0
172 |             mse = (left_depth_pred - left_depth) ** 2  # per-pixel squared error
173 | mse[torch.isnan(mse)] = 0.0
174 | mse[torch.isinf(mse)] = 0.0
175 | mask = (left_depth < args.depth_thres).float()
176 | if torch.sum(mask) == 0:
177 | continue
178 | rmse = torch.sqrt(torch.sum(mse * mask) / torch.sum(mask))
179 | if not torch.isnan(rmse):
180 | rmse_avg += rmse.cpu().float()
181 | counter += torch.tensor(b).float()
182 | if counter < args.save_iters:
183 | IO.image.save_image(os.path.join(args.save_path,\
184 | str(epoch) + "_" + str(counter) + "_#_left.png"), left_rgb)
185 | IO.image.save_data(os.path.join(args.save_path,\
186 | str(epoch) + "_" + str(counter) + "_#_depth.exr"), left_depth_pred, scale=1.0)
187 | rmse_avg /= counter
188 | print("Testing epoch {}: RMSE = {}".format(epoch+1, rmse_avg))
189 |         plot_viz.append_loss(epoch + 1, epoch + 1, rmse_avg, "rmse", mode='test')
190 |     # gradients are re-enabled automatically once the no_grad() context exits
191 | model.train()
192 |
--------------------------------------------------------------------------------
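The BerHu (reverse Huber) objective used by `train_sv.py` behaves as L1 for small residuals and as a scaled L2 past a threshold, commonly set to 20% of the largest residual in the batch. A minimal self-contained sketch of that idea; the repository's `supervision.direct.calculate_berhu_loss` additionally supports validity masks and attention weights, so this is an illustration rather than its actual implementation:

```python
import torch

def berhu(pred, gt):
    residual = torch.abs(pred - gt)
    c = 0.2 * residual.max().clamp(min=1e-6)   # switch point between the L1 and L2 branches
    l2 = (residual ** 2 + c ** 2) / (2.0 * c)  # continuous continuation past the threshold
    return torch.where(residual <= c, residual, l2).mean()

pred = torch.rand(1, 1, 8, 8) * 10.0
gt = torch.rand(1, 1, 8, 8) * 10.0
print(berhu(pred, gt))
```
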
/train_tc.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import numpy
5 |
6 | import torch
7 |
8 | import models
9 | import dataset
10 | import utils
11 |
12 | import supervision as L
13 | import exporters as IO
14 | import spherical as S360
15 |
16 | def parse_arguments(args):
17 | usage_text = (
18 | "Omnidirectional Trinocular Stereo Placement (Up-Down & Left-Right , UD+LR) Training"
19 | )
20 | parser = argparse.ArgumentParser(description=usage_text)
21 | # durations
22 | parser.add_argument('-e',"--epochs", type=int, help="Train for a total number of epochs.")
23 | parser.add_argument('-b',"--batch_size", type=int, help="Train with a number of samples each train iteration.")
24 | parser.add_argument("--test_batch_size", default=1, type=int, help="Test with a number of samples each test iteration.")
25 | parser.add_argument('-d','--disp_iters', type=int, default=50, help='Log training progress (i.e. loss etc.) on console every iterations.')
26 | parser.add_argument('--save_iters', type=int, default=100, help='Maximum test iterations to perform each test run.')
27 | # paths
28 |     parser.add_argument("--train_path", type=str, help="Path to the training file containing the train set file paths")
29 | parser.add_argument("--test_path", type=str, help="Path to the testing file containing the test set file paths")
30 | parser.add_argument("--save_path", type=str, help="Path to the folder where the models and results will be saved at.")
31 | # model
32 |     parser.add_argument("--configuration", required=False, type=str, default='tc', help="Training configuration (one of: mono, lr, ud, tc).", choices=['mono', 'lr', 'ud', 'tc'])
33 | parser.add_argument('--weight_init', type=str, default="xavier", help='Weight initialization method, or path to weights file (for fine-tuning or continuing training)')
34 | parser.add_argument('--model', default="default", type=str, help='Model selection argument.')
35 | # optimization
36 | parser.add_argument('-o','--optimizer', type=str, default="adam", help='The optimizer that will be used during training.')
37 | parser.add_argument("--opt_state", type=str, help="Path to stored optimizer state file to continue training)")
38 | parser.add_argument('-l','--lr', type=float, default=0.0002, help='Optimization Learning Rate.')
39 | parser.add_argument('-m','--momentum', type=float, default=0.9, help='Optimization Momentum.')
40 | parser.add_argument('--momentum2', type=float, default=0.999, help='Optimization Second Momentum (optional, only used by some optimizers).')
41 | parser.add_argument('--epsilon', type=float, default=1e-8, help='Optimization Epsilon (optional, only used by some optimizers).')
42 | parser.add_argument('--weight_decay', type=float, default=0, help='Optimization Weight Decay.')
43 | # hardware
44 | parser.add_argument('-g','--gpu', type=str, default='0', help='The ids of the GPU(s) that will be utilized. (e.g. 0 or 0,1, or 0,2). Use -1 for CPU.')
45 | # other
46 | parser.add_argument('-n','--name', type=str, default='default_name', help='The name of this train/test. Used when storing information.')
47 | parser.add_argument("--visdom", type=str, nargs='?', default=None, const="127.0.0.1", help="Visdom server IP (port defaults to 8097)")
48 | parser.add_argument("--visdom_iters", type=int, default=400, help = "Iteration interval that results will be reported at the visdom server for visualization.")
49 | parser.add_argument("--seed", type=int, default=1337, help="Fixed manual seed, zero means no seeding.")
50 | # network specific params
51 | parser.add_argument("--photo_w", type=float, default=1.0, help = "Photometric loss weight.")
52 | parser.add_argument("--photo_ratio", type=float, default=0.5, help = "Ratio between right (1-ratio) and up (ratio) photometric loss.")
53 | parser.add_argument("--smooth_reg_w", type=float, default=0.1, help = "Smoothness regularization weight.")
54 | parser.add_argument("--ssim_window", type=int, default=7, help = "Kernel size to use in SSIM calculation.")
55 | parser.add_argument("--ssim_mode", type=str, default='gaussian', help = "Type of SSIM averaging (either gaussian or box).")
56 | parser.add_argument("--ssim_std", type=float, default=1.5, help = "SSIM standard deviation value used when creating the gaussian averaging kernels.")
57 | parser.add_argument("--ssim_alpha", type=float, default=0.85, help = "Alpha factor to weight the SSIM and L1 losses, where a x SSIM and (1 - a) x L1.")
58 | parser.add_argument("--pred_bias", type=float, default=5.0, help = "Initialize prediction layers' bias to the given value (helps convergence).")
59 | # details
60 | parser.add_argument("--depth_thres", type=float, default=20.0, help = "Depth threshold - depth clipping.")
61 | parser.add_argument("--baseline", type=float, default=0.26, help = "Stereo baseline distance (in either axis).")
62 |     parser.add_argument("--width", type=int, default=512, help = "Spherical image width.")
63 | return parser.parse_known_args(args)
64 |
65 | if __name__ == "__main__":
66 | args, unknown = parse_arguments(sys.argv)
67 | gpus = [int(id) for id in args.gpu.split(',') if int(id) >= 0]
68 | # device & visualizers
69 | device, visualizers, model_params = utils.initialize(args)
70 | plot_viz = visualizers[0]
71 | img_viz = visualizers[1]
72 | # model
73 | model = models.get_model(args.model, model_params)
74 | utils.init.initialize_weights(model, args.weight_init, pred_bias=args.pred_bias)
75 | if (len(gpus) > 1):
76 | model = torch.nn.parallel.DataParallel(model, gpus)
77 | model = model.to(device)
78 | # optimizer
79 | optimizer = utils.init_optimizer(model, args)
80 | # train data
81 | train_data = dataset.dataset_360D.Dataset360D(args.train_path, " ", args.configuration, [256, 512])
82 | train_data_iterator = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,\
83 |         num_workers=args.batch_size // max(len(gpus), 1) // 4, pin_memory=False, shuffle=True)
84 | # test data
85 | test_data = dataset.dataset_360D.Dataset360D(args.test_path, " ", args.configuration, [256, 512])
86 | test_data_iterator = torch.utils.data.DataLoader(test_data, batch_size=args.test_batch_size,\
87 |         num_workers=args.batch_size // max(len(gpus), 1) // 4, pin_memory=False, shuffle=True)
88 | print("Data size : {0} | Test size : {1}".format(\
89 | args.batch_size * train_data_iterator.__len__(), \
90 | args.test_batch_size * test_data_iterator.__len__()))
91 | # params
92 | width = args.width
93 | height = args.width // 2
94 | photo_params = L.photometric.PhotometricLossParameters(
95 | alpha=args.ssim_alpha, l1_estimator='none', ssim_estimator='none',
96 | ssim_mode=args.ssim_mode, std=args.ssim_std, window=args.ssim_window
97 | )
98 | iteration_counter = 0
99 | # meters
100 | total_loss = utils.AverageMeter()
101 | running_photo_loss_lr = utils.AverageMeter()
102 | running_photo_loss_ud = utils.AverageMeter()
103 | running_depth_smooth_loss = utils.AverageMeter()
104 | # train / test loop
105 | model.train()
106 | plot_viz.config(**vars(args))
107 | for epoch in range(args.epochs):
108 | print("Training | Epoch: {}".format(epoch))
109 | img_viz.update_epoch(epoch)
110 | for batch_id, batch in enumerate(train_data_iterator):
111 | optimizer.zero_grad()
112 | active_loss = torch.tensor(0.0).to(device)
113 | ''' Data '''
114 | left_rgb = batch['leftRGB'].to(device)
115 | b, _, __, ___ = left_rgb.size()
116 | expand_size = (b, -1, -1, -1)
117 | sgrid = S360.grid.create_spherical_grid(width).to(device)
118 | uvgrid = S360.grid.create_image_grid(width, height).to(device)
119 | right_rgb = batch['rightRGB'].to(device)
120 | up_rgb = batch['upRGB'].to(device)
121 | left_depth = batch['leftDepth'].to(device)
122 | up_depth = batch['upDepth'].to(device)
123 | right_depth = batch['rightDepth'].to(device)
124 | ''' Prediction '''
125 | left_depth_pred = torch.abs(model(left_rgb))
126 | ''' Forward Rendering LR '''
127 | disp_lr = torch.cat(
128 | (
129 | S360.derivatives.dphi_horizontal(sgrid, left_depth_pred, args.baseline),
130 | S360.derivatives.dtheta_horizontal(sgrid, left_depth_pred, args.baseline)
131 | ),
132 | dim=1
133 | )
134 | right_render_coords = uvgrid + disp_lr
135 | right_render_coords[:, 0, :, :] = torch.fmod(right_render_coords[:, 0, :, :] + width, width)
136 | right_render_coords[torch.isnan(right_render_coords)] = 0.0
137 | right_render_coords[torch.isinf(right_render_coords)] = 0.0
138 | right_rgb_t, right_mask_t = L.splatting.render(left_rgb, left_depth_pred,\
139 | right_render_coords, max_depth=args.depth_thres)
140 | ''' Forward Rendering UD '''
141 | disp_ud = torch.cat(
142 | (
143 | torch.zeros_like(left_depth_pred),
144 | S360.derivatives.dtheta_vertical(sgrid, left_depth_pred, args.baseline)
145 | ),
146 | dim=1
147 | )
148 | up_render_coords = uvgrid + disp_ud
149 | up_render_coords[torch.isnan(up_render_coords)] = 0.0
150 | up_render_coords[torch.isinf(up_render_coords)] = 0.0
151 | up_rgb_t, up_mask_t = L.splatting.render(left_rgb, left_depth_pred,\
152 | up_render_coords, max_depth=args.depth_thres)
153 | ''' Loss LR '''
154 | right_cutoff_mask = (right_depth < args.depth_thres)
155 | attention_weights_lr = S360.weights.phi_confidence(
156 | S360.grid.create_spherical_grid(width)).to(device)
157 | # attention_weights_lr = S360.weights.spherical_confidence(
158 | # S360.grid.create_spherical_grid(width), zero_low=0.001
159 | # ).to(device)
160 | photo_loss_lr = L.photometric.calculate_loss(right_rgb_t, right_rgb, photo_params,
161 | mask=right_cutoff_mask, weights=attention_weights_lr)
162 | active_loss += photo_loss_lr * args.photo_w * (1 - args.photo_ratio)
163 | ''' Loss UD '''
164 | up_cutoff_mask = (up_depth < args.depth_thres)
165 | attention_weights_ud = S360.weights.theta_confidence(
166 | S360.grid.create_spherical_grid(width)).to(device)
167 | photo_loss_ud = L.photometric.calculate_loss(up_rgb_t, up_rgb, photo_params,
168 | mask=up_cutoff_mask, weights=attention_weights_ud)
169 | active_loss += photo_loss_ud * args.photo_w * args.photo_ratio
170 | ''' Loss Prior (3D Smoothness) '''
171 | left_xyz = S360.cartesian.coords_3d(sgrid, left_depth_pred)
172 | dI_dxyz = S360.derivatives.dV_dxyz(left_xyz)
173 |             tc_cutoff_mask = right_cutoff_mask & up_cutoff_mask
174 |             guidance_duv = S360.derivatives.dI_duv(left_rgb)
175 |             depth_smooth_loss = L.smoothness.guided_smoothness_loss(
176 |                 dI_dxyz, guidance_duv, tc_cutoff_mask, (1.0 - attention_weights_ud)
177 |                 * tc_cutoff_mask.type(attention_weights_ud.dtype)
178 |             )
179 | active_loss += depth_smooth_loss * args.smooth_reg_w
180 | ''' Update Params '''
181 | active_loss.backward()
182 | optimizer.step()
183 | ''' Visualize'''
184 | total_loss.update(active_loss)
185 | running_depth_smooth_loss.update(depth_smooth_loss)
186 | running_photo_loss_lr.update(photo_loss_lr)
187 | running_photo_loss_ud.update(photo_loss_ud)
188 | iteration_counter += b
189 | if (iteration_counter + 1) % args.disp_iters <= args.batch_size:
190 | print("Epoch: {}, iteration: {}\nPhotometric (LR-UD): {} - {}\nSmoothness: {}\nTotal average loss: {}\n"\
191 | .format(epoch, iteration_counter, running_photo_loss_lr.avg, \
192 | running_photo_loss_ud.avg, running_depth_smooth_loss.avg, total_loss.avg))
193 | plot_viz.append_loss(epoch + 1, iteration_counter, total_loss.avg, "avg")
194 | plot_viz.append_loss(epoch + 1, iteration_counter, running_photo_loss_lr.avg, "photo_lr")
195 | plot_viz.append_loss(epoch + 1, iteration_counter, running_photo_loss_ud.avg, "photo_ud")
196 | plot_viz.append_loss(epoch + 1, iteration_counter, running_depth_smooth_loss.avg, "smooth")
197 | total_loss.reset()
198 | running_photo_loss_lr.reset()
199 | running_photo_loss_ud.reset()
200 | running_depth_smooth_loss.reset()
201 | if args.visdom_iters > 0 and (iteration_counter + 1) % args.visdom_iters <= args.batch_size:
202 | img_viz.show_separate_images(left_rgb, 'input')
203 | img_viz.show_separate_images(right_rgb, 'right')
204 | img_viz.show_separate_images(up_rgb, 'up')
205 | img_viz.show_map(left_depth_pred, 'depth')
206 | img_viz.show_separate_images(torch.clamp(right_rgb_t, min=0.0, max=1.0), 'recon_lr')
207 | img_viz.show_separate_images(torch.clamp(up_rgb_t, min=0.0, max=1.0), 'recon_ud')
208 | ''' Save '''
209 | print("Saving model @ epoch #" + str(epoch))
210 | utils.checkpoint.save_network_state(model, optimizer, epoch,\
211 | args.name + "_model_state", args.save_path)
212 | ''' Test '''
213 | print("Testing model @ epoch #" + str(epoch))
214 | model.eval()
215 | with torch.no_grad():
216 | rmse_avg = torch.tensor(0.0).float()
217 | counter = torch.tensor(0.0).float()
218 | for test_batch_id , test_batch in enumerate(test_data_iterator):
219 | left_rgb = test_batch['leftRGB'].to(device)
220 | b, c, h, w = left_rgb.size()
221 | rads = sgrid.expand(b, -1, -1, -1)
222 | uv = uvgrid.expand(b, -1, -1, -1)
223 | left_depth_pred = torch.abs(model(left_rgb))
224 | left_depth = test_batch['leftDepth'].to(device)
225 | left_depth[torch.isnan(left_depth)] = 50.0
226 | left_depth[torch.isinf(left_depth)] = 50.0
227 |             mse = (left_depth_pred - left_depth) ** 2  # per-pixel squared error
228 | mse[torch.isnan(mse)] = 0.0
229 | mse[torch.isinf(mse)] = 0.0
230 | mask = (left_depth < args.depth_thres).float()
231 | if torch.sum(mask) == 0:
232 | continue
233 | rmse = torch.sqrt(torch.sum(mse * mask) / torch.sum(mask).float())
234 | if not torch.isnan(rmse):
235 | rmse_avg += rmse.cpu().float()
236 | counter += torch.tensor(b).float()
237 | if counter < args.save_iters:
238 | disp = torch.cat(
239 | (
240 | S360.derivatives.dphi_horizontal(rads, left_depth_pred, args.baseline),
241 | S360.derivatives.dtheta_horizontal(rads, left_depth_pred, args.baseline)
242 | ), dim=1
243 | )
244 | right_render_coords = uv + disp
245 | right_render_coords[:, 0, :, :] = torch.fmod(right_render_coords[:, 0, :, :] + width, width)
246 | right_render_coords[torch.isnan(right_render_coords)] = 0.0
247 | right_render_coords[torch.isinf(right_render_coords)] = 0.0
248 | right_rgb_t, right_mask_t = L.splatting.render(left_rgb, left_depth_pred, right_render_coords, max_depth=args.depth_thres)
249 | # save
250 | IO.image.save_image(os.path.join(args.save_path,\
251 | str(epoch) + "_" + str(counter) + "_#_left.png"), left_rgb)
252 | IO.image.save_image(os.path.join(args.save_path,\
253 | str(epoch) + "_" + str(counter) + "_#_right_t.png"), right_rgb_t)
254 | IO.image.save_data(os.path.join(args.save_path,\
255 | str(epoch) + "_" + str(counter) + "_#_depth.exr"), left_depth_pred, scale=1.0)
256 |         if (counter == 0) or torch.isnan(rmse_avg):
257 |             print("Error calculating RMSE (val: {}, sum: {})".format(rmse_avg, counter))
258 |             plot_viz.append_loss(epoch + 1, epoch + 1, torch.tensor(0.0), "rmse", mode='test')
259 |         else:
260 |             rmse_avg /= counter
261 |             print("Testing epoch {}: RMSE = {}".format(epoch+1, rmse_avg))
262 |             plot_viz.append_loss(epoch + 1, epoch + 1, rmse_avg, "rmse", mode='test')
263 |     # gradients are re-enabled automatically once the no_grad() context exits
264 | model.train()
265 |
--------------------------------------------------------------------------------
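For the up-down pair in the trinocular setup above, a camera displaced along the vertical axis leaves the azimuthal (phi) coordinate untouched, so the disparity field has only a theta component; a sketch of that warp in isolation, mirroring the step in the training loop:

```python
import torch

import spherical as S360

width, height, baseline = 512, 256, 0.26
sgrid = S360.grid.create_spherical_grid(width)
uvgrid = S360.grid.create_image_grid(width, height)

depth = torch.rand(1, 1, height, width) * 5.0 + 0.5
disp_ud = torch.cat(
    (
        torch.zeros_like(depth),                                   # no horizontal shift
        S360.derivatives.dtheta_vertical(sgrid, depth, baseline)   # vertical shift only
    ), dim=1
)
up_coords = uvgrid + disp_ud
print(up_coords.shape)  # torch.Size([1, 2, 256, 512])
```
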
/train_ud.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import numpy
5 |
6 | import torch
7 |
8 | import models
9 | import dataset
10 | import utils
11 |
12 | import supervision as L
13 | import exporters as IO
14 | import spherical as S360
15 |
16 | def parse_arguments(args):
17 | usage_text = (
18 | "Omnidirectional Vertical Stereo Placement (Up-Down , UD) Training."
19 | )
20 | parser = argparse.ArgumentParser(description=usage_text)
21 | # durations
22 | parser.add_argument('-e',"--epochs", type=int, help="Train for a total number of epochs.")
23 | parser.add_argument('-b',"--batch_size", type=int, help="Train with a number of samples each train iteration.")
24 | parser.add_argument("--test_batch_size", default=1, type=int, help="Test with a number of samples each test iteration.")
25 | parser.add_argument('-d','--disp_iters', type=int, default=50, help='Log training progress (i.e. loss etc.) on console every iterations.')
26 | parser.add_argument('--save_iters', type=int, default=100, help='Maximum test iterations to perform each test run.')
27 | # paths
28 |     parser.add_argument("--train_path", type=str, help="Path to the training file containing the train set file paths")
29 | parser.add_argument("--test_path", type=str, help="Path to the testing file containing the test set file paths")
30 | parser.add_argument("--save_path", type=str, help="Path to the folder where the models and results will be saved at.")
31 | # model
32 |     parser.add_argument("--configuration", required=False, type=str, default='mono', help="Data loader configuration (one of: mono, lr, ud, tc).", choices=['mono', 'lr', 'ud', 'tc'])
33 | parser.add_argument('--weight_init', type=str, default="xavier", help='Weight initialization method, or path to weights file (for fine-tuning or continuing training)')
34 | parser.add_argument('--model', default="default", type=str, help='Model selection argument.')
35 | # optimization
36 | parser.add_argument('-o','--optimizer', type=str, default="adam", help='The optimizer that will be used during training.')
37 | parser.add_argument("--opt_state", type=str, help="Path to stored optimizer state file to continue training)")
38 | parser.add_argument('-l','--lr', type=float, default=0.0002, help='Optimization Learning Rate.')
39 | parser.add_argument('-m','--momentum', type=float, default=0.9, help='Optimization Momentum.')
40 | parser.add_argument('--momentum2', type=float, default=0.999, help='Optimization Second Momentum (optional, only used by some optimizers).')
41 | parser.add_argument('--epsilon', type=float, default=1e-8, help='Optimization Epsilon (optional, only used by some optimizers).')
42 | parser.add_argument('--weight_decay', type=float, default=0, help='Optimization Weight Decay.')
43 | # hardware
44 | parser.add_argument('-g','--gpu', type=str, default='0', help='The ids of the GPU(s) that will be utilized. (e.g. 0 or 0,1, or 0,2). Use -1 for CPU.')
45 | # other
46 | parser.add_argument('-n','--name', type=str, default='default_name', help='The name of this train/test. Used when storing information.')
47 | parser.add_argument("--visdom", type=str, nargs='?', default=None, const="127.0.0.1", help="Visdom server IP (port defaults to 8097)")
48 | parser.add_argument("--visdom_iters", type=int, default=400, help = "Iteration interval that results will be reported at the visdom server for visualization.")
49 | parser.add_argument("--seed", type=int, default=1337, help="Fixed manual seed, zero means no seeding.")
50 | # network specific params
51 | parser.add_argument("--photo_w", type=float, default=1.0, help = "Photometric loss weight.")
52 | parser.add_argument("--smooth_reg_w", type=float, default=0.1, help = "Smoothness regularization weight.")
53 | parser.add_argument("--ssim_window", type=int, default=7, help = "Kernel size to use in SSIM calculation.")
54 | parser.add_argument("--ssim_mode", type=str, default='gaussian', help = "Type of SSIM averaging (either gaussian or box).")
55 | parser.add_argument("--ssim_std", type=float, default=1.5, help = "SSIM standard deviation value used when creating the gaussian averaging kernels.")
56 | parser.add_argument("--ssim_alpha", type=float, default=0.85, help = "Alpha factor to weight the SSIM and L1 losses, where a x SSIM and (1 - a) x L1.")
57 | parser.add_argument("--pred_bias", type=float, default=5.0, help = "Initialize prediction layers' bias to the given value (helps convergence).")
58 | # details
59 | parser.add_argument("--depth_thres", type=float, default=20.0, help = "Depth threshold - depth clipping.")
60 | parser.add_argument("--baseline", type=float, default=0.26, help = "Stereo baseline distance (in either axis).")
61 | parser.add_argument("--width", type=float, default=512, help = "Spherical image width.")
62 | return parser.parse_known_args(args)
63 |
64 | if __name__ == "__main__":
65 | args, unknown = parse_arguments(sys.argv)
66 | gpus = [int(id) for id in args.gpu.split(',') if int(id) >= 0]
67 | # device & visualizers
68 | device, visualizers, model_params = utils.initialize(args)
69 | plot_viz = visualizers[0]
70 | img_viz = visualizers[1]
71 | # model
72 | model = models.get_model(args.model, model_params)
73 | utils.init.initialize_weights(model, args.weight_init, pred_bias=args.pred_bias)
74 | if (len(gpus) > 1):
75 | model = torch.nn.parallel.DataParallel(model, gpus)
76 | model = model.to(device)
77 | # optimizer
78 | optimizer = utils.init_optimizer(model, args)
79 | # train data
80 | train_data = dataset.dataset_360D.Dataset360D(args.train_path, " ", args.configuration, [256, 512])
81 | train_data_iterator = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,\
82 | num_workers=args.batch_size // 4 // max(len(gpus), 1), pin_memory=False, shuffle=True)
83 | # test data
84 | test_data = dataset.dataset_360D.Dataset360D(args.test_path, " ", args.configuration, [256, 512])
85 | test_data_iterator = torch.utils.data.DataLoader(test_data, batch_size=args.test_batch_size,\
86 | num_workers=args.batch_size // 4 // max(len(gpus), 1), pin_memory=False, shuffle=True)
87 | print("Data size : {0} | Test size : {1}".format(\
88 | args.batch_size * train_data_iterator.__len__(), \
89 | args.test_batch_size * test_data_iterator.__len__()))
90 | # params
91 | width = args.width
92 | height = args.width // 2
93 | photo_params = L.photometric.PhotometricLossParameters(
94 | alpha=args.ssim_alpha, l1_estimator='none', ssim_estimator='none',
95 | ssim_mode=args.ssim_mode, std=args.ssim_std, window=args.ssim_window
96 | )
97 | iteration_counter = 0
98 | # meters
99 | total_loss = utils.AverageMeter()
100 | running_photo_loss = utils.AverageMeter()
101 | running_depth_smooth_loss = utils.AverageMeter()
102 | # train / test loop
103 | model.train()
104 | plot_viz.config(**vars(args))
105 | for epoch in range(args.epochs):
106 | print("Training | Epoch: {}".format(epoch))
107 | img_viz.update_epoch(epoch)
108 | for batch_id, batch in enumerate(train_data_iterator):
109 | optimizer.zero_grad()
110 | active_loss = torch.tensor(0.0).to(device)
111 | ''' Data '''
112 | left_rgb = batch['leftRGB'].to(device)
113 | b, _, __, ___ = left_rgb.size()
114 | expand_size = (b, -1, -1, -1)
115 | sgrid = S360.grid.create_spherical_grid(width).to(device)
116 | uvgrid = S360.grid.create_image_grid(width, height).to(device)
117 | up_rgb = batch['upRGB'].to(device)
118 | left_depth = batch['leftDepth'].to(device)
119 | up_depth = batch['upDepth'].to(device)
120 | ''' Prediction '''
121 | left_depth_pred = torch.abs(model(left_rgb))
122 | ''' Forward Rendering UD '''
123 | disp = torch.cat(
124 | (
125 | torch.zeros_like(left_depth_pred),
126 | S360.derivatives.dtheta_vertical(sgrid, left_depth_pred, args.baseline)
127 | ),
128 | dim=1
129 | )
130 | up_render_coords = uvgrid + disp
131 | up_render_coords[torch.isnan(up_render_coords)] = 0.0
132 | up_render_coords[torch.isinf(up_render_coords)] = 0.0
133 | up_rgb_t, up_mask_t = L.splatting.render(left_rgb, left_depth_pred,\
134 | up_render_coords, max_depth=args.depth_thres)
135 | ''' Loss UD '''
136 | up_cutoff_mask = (up_depth < args.depth_thres)
137 | up_mask_t &= ~(up_depth > args.depth_thres)
138 | attention_weights = S360.weights.theta_confidence(
139 | S360.grid.create_spherical_grid(width)).to(device)
140 | # attention_weights = torch.ones_like(left_depth)
141 | photo_loss = L.photometric.calculate_loss(up_rgb_t, up_rgb, photo_params,
142 | mask=up_cutoff_mask, weights=attention_weights)
143 | active_loss += photo_loss * args.photo_w
144 | ''' Loss Prior (3D Smoothness) '''
145 | left_xyz = S360.cartesian.coords_3d(sgrid, left_depth_pred)
146 | dI_dxyz = S360.derivatives.dV_dxyz(left_xyz)
147 | guidance_duv = S360.derivatives.dI_duv(left_rgb)
148 | # attention_weights = torch.zeros_like(left_depth)
149 | depth_smooth_loss = L.smoothness.guided_smoothness_loss(
150 | dI_dxyz, guidance_duv, up_cutoff_mask, (1.0 - attention_weights)
151 | * up_cutoff_mask.type(attention_weights.dtype)
152 | )
153 | active_loss += depth_smooth_loss * args.smooth_reg_w
154 | ''' Update Params '''
155 | active_loss.backward()
156 | optimizer.step()
157 | ''' Visualize'''
158 | total_loss.update(active_loss)
159 | running_depth_smooth_loss.update(depth_smooth_loss)
160 | running_photo_loss.update(photo_loss)
161 | iteration_counter += b
162 | if (iteration_counter + 1) % args.disp_iters <= args.batch_size:
163 | print("Epoch: {}, iteration: {}\nPhotometric: {}\nSmoothness: {}\nTotal average loss: {}\n"\
164 | .format(epoch, iteration_counter, running_photo_loss.avg, \
165 | running_depth_smooth_loss.avg, total_loss.avg))
166 | plot_viz.append_loss(epoch + 1, iteration_counter, total_loss.avg, "avg")
167 | plot_viz.append_loss(epoch + 1, iteration_counter, running_photo_loss.avg, "photo")
168 | plot_viz.append_loss(epoch + 1, iteration_counter, running_depth_smooth_loss.avg, "smooth")
169 | total_loss.reset()
170 | running_photo_loss.reset()
171 | running_depth_smooth_loss.reset()
172 | if args.visdom_iters > 0 and (iteration_counter + 1) % args.visdom_iters <= args.batch_size:
173 | img_viz.show_separate_images(left_rgb, 'input')
174 | img_viz.show_separate_images(up_rgb, 'target')
175 | img_viz.show_map(left_depth_pred, 'depth')
176 | img_viz.show_separate_images(torch.clamp(up_rgb_t, min=0.0, max=1.0), 'recon')
177 | ''' Save '''
178 | print("Saving model @ epoch #" + str(epoch))
179 | utils.checkpoint.save_network_state(model, optimizer, epoch,\
180 | args.name + "_model_state", args.save_path)
181 | ''' Test '''
182 | print("Testing model @ epoch #" + str(epoch))
183 | model.eval()
184 | with torch.no_grad():
185 | rmse_avg = torch.tensor(0.0).float()
186 | counter = torch.tensor(0.0).float()
187 | for test_batch_id , test_batch in enumerate(test_data_iterator):
188 | left_rgb = test_batch['leftRGB'].to(device)
189 | b, c, h, w = left_rgb.size()
190 | rads = sgrid.expand(b, -1, -1, -1)
191 | uv = uvgrid.expand(b, -1, -1, -1)
192 | left_depth_pred = torch.abs(model(left_rgb))
193 | left_depth = test_batch['leftDepth'].to(device)
194 | left_depth[torch.isnan(left_depth)] = 50.0
195 | left_depth[torch.isinf(left_depth)] = 50.0
196 | mse = (left_depth_pred - left_depth) ** 2  # per-pixel squared error
197 | mse[torch.isnan(mse)] = 0.0
198 | mse[torch.isinf(mse)] = 0.0
199 | mask = (left_depth < args.depth_thres).float()
200 | if torch.sum(mask) == 0:
201 | continue
202 | rmse = torch.sqrt(torch.sum(mse * mask) / torch.sum(mask).float())
203 | if not torch.isnan(rmse):
204 | rmse_avg += rmse.cpu().float()
205 | counter += torch.tensor(b).float()
206 | if counter < args.save_iters:
207 | disp = torch.cat(
208 | (
209 | torch.zeros_like(left_depth_pred),
210 | S360.derivatives.dtheta_vertical(rads, left_depth_pred, args.baseline)
211 | ), dim=1
212 | )
213 | up_render_coords = uv + disp
214 | up_render_coords[torch.isnan(up_render_coords)] = 0.0
215 | up_render_coords[torch.isinf(up_render_coords)] = 0.0
216 | up_rgb_t, up_mask_t = L.splatting.render(left_rgb, left_depth_pred, \
217 | up_render_coords, max_depth=args.depth_thres)
218 | # save
219 | IO.image.save_image(os.path.join(args.save_path,\
220 | str(epoch) + "_" + str(int(counter)) + "_#_left.png"), left_rgb)
221 | IO.image.save_image(os.path.join(args.save_path,\
222 | str(epoch) + "_" + str(int(counter)) + "_#_up_t.png"), up_rgb_t)
223 | IO.image.save_data(os.path.join(args.save_path,\
224 | str(epoch) + "_" + str(int(counter)) + "_#_depth.exr"), left_depth_pred, scale=1.0)
225 | if (counter == 0) or torch.isnan(rmse_avg):
226 | print("Error calculating RMSE (val:%f , sum:%d)" % (rmse_avg, counter))
227 | plot_viz.append_loss(epoch + 1, epoch + 1, torch.tensor(0.0), "rmse", mode='test')
228 | else:
229 | rmse_avg /= counter
230 | print("Testing epoch {}: RMSE = {}".format(epoch+1, rmse_avg))
231 | plot_viz.append_loss(epoch + 1, epoch + 1, rmse_avg, "rmse", mode='test')
232 | torch.set_grad_enabled(True)
233 | model.train()
234 |
--------------------------------------------------------------------------------
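A hypothetical invocation of the trainer above; the split files, folder names, and hyperparameter values are illustrative, and the train/test path files are whatever sample lists Dataset360D consumes:

    # the save folder must already exist (save_network_state raises otherwise)
    mkdir -p ./checkpoints
    python train_ud.py --train_path ./splits/train.txt --test_path ./splits/test.txt \
        --save_path ./checkpoints -e 20 -b 8 --configuration ud --visdom
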
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .init import *
2 | from .opt import *
3 | from .visualization import *
4 | from .framework import *
5 | from .meters import *
6 | from .checkpoint import *
--------------------------------------------------------------------------------
/utils/checkpoint.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from os import path
4 |
5 | def save_network_state(model, optimizer, epoch, name, save_path):
6 | if not path.exists(save_path):
7 | raise ValueError("{} not a valid path to save model state".format(save_path))
8 | torch.save(
9 | {
10 | 'epoch' : epoch,
11 | 'model_state_dict' : model.state_dict(),
12 | 'optimizer_state_dict' : optimizer.state_dict()
13 | }, path.join(save_path, "{}_e{}.pt".format(name, epoch)))
14 |
15 |
--------------------------------------------------------------------------------
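To resume from a state file written by save_network_state, a minimal sketch; the filename follows the "{name}_e{epoch}.pt" pattern above, and model/optimizer stand in for already-constructed instances:

    import torch

    state = torch.load("./checkpoints/default_name_model_state_e4.pt", map_location="cpu")
    model.load_state_dict(state["model_state_dict"])          # network weights
    optimizer.load_state_dict(state["optimizer_state_dict"])  # optimizer moments etc.
    start_epoch = state["epoch"] + 1                          # continue from the next epoch
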
/utils/framework.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import datetime
3 | import numpy
4 | import random
5 |
6 | from .opt import *
7 | from .visualization import *
8 |
9 | def initialize(args): #TODO: add visdom count as argument
10 | # create and init device
11 | print("{} | Torch Version: {}".format(datetime.datetime.now(), torch.__version__))
12 | if args.seed > 0:
13 | print("Set to reproducibility mode with seed: {}".format(args.seed))
14 | torch.manual_seed(args.seed)
15 | torch.cuda.manual_seed_all(args.seed)
16 | numpy.random.seed(args.seed)
17 | torch.backends.cudnn.deterministic = True
18 | torch.backends.cudnn.benchmark = False
19 | random.seed(args.seed)
20 | gpus = [int(id) for id in args.gpu.split(',') if int(id) >= 0]
21 | device = torch.device("cuda:{}" .format(gpus[0]) if torch.cuda.is_available() and len(gpus) > 0 and gpus[0] >= 0 else "cpu")
22 | print("Training {0} for {1} epochs using a batch size of {2} on {3}".format(args.name, args.epochs, args.batch_size, device))
23 | # create visualizers
24 | visualizers = (NullVisualizer(), NullVisualizer())\
25 | if args.visdom is None\
26 | else (
27 | VisdomPlotVisualizer(args.name + "_plots_", args.visdom),
28 | VisdomImageVisualizer(args.name + "_images_", args.visdom,\
29 | count=2 if 2 <= args.batch_size else args.batch_size)
30 | )
31 | if args.visdom is None:
32 | args.visdom_iters = 0
33 | # create & init model
34 | model_params = {
35 | 'width': 512,
36 | 'height': 256,
37 | 'configuration': args.configuration,
38 | }
39 | return device, visualizers, model_params
40 |
41 | def init_optimizer(model, args):
42 | opt_params = OptimizerParameters(learning_rate=args.lr, momentum=args.momentum,\
43 | momentum2=args.momentum2, epsilon=args.epsilon)
44 | optimizer = get_optimizer(args.optimizer, model.parameters(), opt_params)
45 | if args.opt_state is not None:
46 | opt_state = torch.load(args.opt_state)
47 | print("Loading previously saved optimizer state from {}".format(args.opt_state))
48 | optimizer.load_state_dict(opt_state["optimizer_state_dict"])
49 | return optimizer
50 |
--------------------------------------------------------------------------------
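A sketch of calling initialize() outside the trainers, e.g. for a quick smoke test; the SimpleNamespace fields are the ones the function reads, the values are illustrative, and the visdom dependency is assumed installed since utils.visualization imports it at module load:

    from types import SimpleNamespace
    from utils.framework import initialize

    args = SimpleNamespace(seed=1337, gpu="-1", visdom=None, visdom_iters=400,
                           name="demo", epochs=1, batch_size=2, configuration="mono")
    # gpu="-1" selects the CPU; visdom=None yields the Null visualizers
    device, (plot_viz, img_viz), model_params = initialize(args)
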
/utils/init.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import os
4 | import sys
5 |
6 | def initialize_weights(model, init = "xavier", pred_bias=None):
7 | init_func = None
8 | if init == "xavier":
9 | init_func = torch.nn.init.xavier_normal_
10 | elif init == "kaiming":
11 | init_func = torch.nn.init.kaiming_normal_
12 | elif init == "gaussian" or init == "normal":
13 | init_func = torch.nn.init.normal_
14 |
15 | if init_func is not None:
16 | #TODO: logging /w print or lib
17 | for module in model.modules():
18 | if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear,
19 | torch.nn.ConvTranspose2d)):
20 | init_func(module.weight)
21 | if module.bias is not None:
22 | module.bias.data.zero_()
23 | elif isinstance(module, torch.nn.BatchNorm2d):
24 | module.weight.data.fill_(1)
25 | module.bias.data.zero_()
26 | if pred_bias is not None:
27 | list(model.modules())[-1].bias.data.fill_(pred_bias)
28 | elif os.path.exists(init):
29 | #TODO: logging /w print or lib
30 | weights = torch.load(init, map_location={'cuda:1':'cuda:0'})
31 | model.load_state_dict(weights["model_state_dict"])
32 | else:
33 | print("Error when initializing model's weights, {} either doesn't exist or is not a valid initialization function.".format(init), \
34 | file=sys.stderr)
35 |
36 | def initialize_prediction_bias(model, pred_bias=None):
37 | if pred_bias is not None:
38 | list(model.modules())[-1].bias.data.fill_(pred_bias)
39 |
40 |
--------------------------------------------------------------------------------
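A small sketch of the initializer above on a toy network (the Sequential model is illustrative); note that pred_bias fills the bias of the last module returned by model.modules(), here the final Conv2d:

    import torch
    from utils.init import initialize_weights

    net = torch.nn.Sequential(
        torch.nn.Conv2d(3, 16, 3, padding=1),
        torch.nn.BatchNorm2d(16),
        torch.nn.Conv2d(16, 1, 3, padding=1),  # prediction layer
    )
    initialize_weights(net, init="xavier", pred_bias=5.0)
    print(net[-1].bias.data.unique())  # tensor([5.])
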
/utils/meters.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | # Computes and stores the average and current value
4 | class AverageMeter(object):
5 | def __init__(self):
6 | self.reset()
7 |
8 | def reset(self):
9 | self.val = torch.tensor(0.0)
10 | self.avg = torch.tensor(0.0)
11 | self.sum = torch.tensor(0.0)
12 | self.count = torch.tensor(0.0)
13 |
14 | def update(self, val, n=1):
15 | self.val = val
16 | self.sum += val * n
17 | self.count += n
18 | self.avg = self.sum / self.count
--------------------------------------------------------------------------------
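Typical use of the meter, mirroring the running losses in the trainers above (values are illustrative):

    import torch
    from utils.meters import AverageMeter

    meter = AverageMeter()
    meter.update(torch.tensor(2.0))        # single value, n=1
    meter.update(torch.tensor(4.0), n=3)   # batch of 3 with mean loss 4.0
    print(meter.avg)                       # (2.0 + 4.0 * 3) / 4 = 3.5
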
/utils/opt.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.optim as optim
4 | from torch.optim import Optimizer
5 |
6 | import sys
7 |
8 | class OptimizerParameters(object):
9 | def __init__(self, learning_rate=0.001, momentum=0.9, momentum2=0.999,\
10 | epsilon=1e-8, weight_decay=0.0005, damp=0):
11 | super(OptimizerParameters, self).__init__()
12 | self.learning_rate = learning_rate
13 | self.momentum = momentum
14 | self.momentum2 = momentum2
15 | self.epsilon = epsilon
16 | self.damp = damp
17 | self.weight_decay = weight_decay
18 |
19 | def get_learning_rate(self):
20 | return self.learning_rate
21 |
22 | def get_momentum(self):
23 | return self.momentum
24 |
25 | def get_momentum2(self):
26 | return self.momentum2
27 |
28 | def get_epsilon(self):
29 | return self.epsilon
30 |
31 | def get_weight_decay(self):
32 | return self.weight_decay
33 |
34 | def get_damp(self):
35 | return self.damp
36 |
37 | def get_optimizer(opt_type, model_params, opt_params):
38 | if opt_type == "adam":
39 | return optim.Adam(model_params, \
40 | lr=opt_params.get_learning_rate(), \
41 | betas=(opt_params.get_momentum(), opt_params.get_momentum2()), \
42 | eps=opt_params.get_epsilon(),
43 | )
44 | elif opt_type == "adabound" or opt_type == "amsbound":
45 | return AdaBound(model_params, \
46 | lr=opt_params.get_learning_rate(), \
47 | betas=(opt_params.get_momentum(), opt_params.get_momentum2()), \
48 | eps=opt_params.get_epsilon(),
49 | weight_decay=opt_params.get_weight_decay(),\
50 | final_lr=0.001, gamma=0.002,\
51 | amsbound=True if opt_type == "amsbound" else False
52 | )
53 | elif opt_type == "sgd":
54 | return optim.SGD(model_params, \
55 | lr=opt_params.get_learning_rate(), \
56 | momentum=opt_params.get_momentum(), \
57 | weight_decay=opt_params.get_weight_decay(), \
58 | dampening=opt_params.get_damp() \
59 | )
60 | else:
61 | print("Error when initializing optimizer, {} is not a valid optimizer type.".format(opt_type), \
62 | file=sys.stderr)
63 | return None
64 |
65 | def adjust_learning_rate(optimizer, epoch, scale=2):
66 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
67 | for param_group in optimizer.param_groups:
68 | lr = param_group['lr']
69 | lr = lr * (0.1 ** (epoch // scale))
70 | param_group['lr'] = lr
71 |
72 |
73 | '''
74 | Code from https://github.com/Luolc/AdaBound
75 | '''
76 |
77 | class AdaBound(Optimizer):
78 | """Implements AdaBound algorithm.
79 | It has been proposed in `Adaptive Gradient Methods with Dynamic Bound of Learning Rate`_.
80 | Arguments:
81 | params (iterable): iterable of parameters to optimize or dicts defining
82 | parameter groups
83 | lr (float, optional): Adam learning rate (default: 1e-3)
84 | betas (Tuple[float, float], optional): coefficients used for computing
85 | running averages of gradient and its square (default: (0.9, 0.999))
86 | final_lr (float, optional): final (SGD) learning rate (default: 0.1)
87 | gamma (float, optional): convergence speed of the bound functions (default: 1e-3)
88 | eps (float, optional): term added to the denominator to improve
89 | numerical stability (default: 1e-8)
90 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
91 | amsbound (boolean, optional): whether to use the AMSBound variant of this algorithm
92 | .. Adaptive Gradient Methods with Dynamic Bound of Learning Rate:
93 | https://openreview.net/forum?id=Bkg3g2R9FX
94 | """
95 |
96 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), final_lr=0.1, gamma=1e-3,
97 | eps=1e-8, weight_decay=0, amsbound=False):
98 | if not 0.0 <= lr:
99 | raise ValueError("Invalid learning rate: {}".format(lr))
100 | if not 0.0 <= eps:
101 | raise ValueError("Invalid epsilon value: {}".format(eps))
102 | if not 0.0 <= betas[0] < 1.0:
103 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
104 | if not 0.0 <= betas[1] < 1.0:
105 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
106 | if not 0.0 <= final_lr:
107 | raise ValueError("Invalid final learning rate: {}".format(final_lr))
108 | if not 0.0 <= gamma < 1.0:
109 | raise ValueError("Invalid gamma parameter: {}".format(gamma))
110 | defaults = dict(lr=lr, betas=betas, final_lr=final_lr, gamma=gamma, eps=eps,
111 | weight_decay=weight_decay, amsbound=amsbound)
112 | super(AdaBound, self).__init__(params, defaults)
113 |
114 | self.base_lrs = list(map(lambda group: group['lr'], self.param_groups))
115 |
116 | def __setstate__(self, state):
117 | super(AdaBound, self).__setstate__(state)
118 | for group in self.param_groups:
119 | group.setdefault('amsbound', False)
120 |
121 | def step(self, closure=None):
122 | """Performs a single optimization step.
123 | Arguments:
124 | closure (callable, optional): A closure that reevaluates the model
125 | and returns the loss.
126 | """
127 | loss = None
128 | if closure is not None:
129 | loss = closure()
130 |
131 | for group, base_lr in zip(self.param_groups, self.base_lrs):
132 | for p in group['params']:
133 | if p.grad is None:
134 | continue
135 | grad = p.grad.data
136 | if grad.is_sparse:
137 | raise RuntimeError(
138 | 'AdaBound does not support sparse gradients, please consider SparseAdam instead')
139 | amsbound = group['amsbound']
140 |
141 | state = self.state[p]
142 |
143 | # State initialization
144 | if len(state) == 0:
145 | state['step'] = 0
146 | # Exponential moving average of gradient values
147 | state['exp_avg'] = torch.zeros_like(p.data)
148 | # Exponential moving average of squared gradient values
149 | state['exp_avg_sq'] = torch.zeros_like(p.data)
150 | if amsbound:
151 | # Maintains max of all exp. moving avg. of sq. grad. values
152 | state['max_exp_avg_sq'] = torch.zeros_like(p.data)
153 |
154 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
155 | if amsbound:
156 | max_exp_avg_sq = state['max_exp_avg_sq']
157 | beta1, beta2 = group['betas']
158 |
159 | state['step'] += 1
160 |
161 | if group['weight_decay'] != 0:
162 | grad = grad.add(group['weight_decay'], p.data)
163 |
164 | # Decay the first and second moment running average coefficient
165 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
166 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
167 | if amsbound:
168 | # Maintains the maximum of all 2nd moment running avg. till now
169 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
170 | # Use the max. for normalizing running avg. of gradient
171 | denom = max_exp_avg_sq.sqrt().add_(group['eps'])
172 | else:
173 | denom = exp_avg_sq.sqrt().add_(group['eps'])
174 |
175 | bias_correction1 = 1 - beta1 ** state['step']
176 | bias_correction2 = 1 - beta2 ** state['step']
177 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
178 |
179 | # Applies bounds on actual learning rate
180 | # lr_scheduler cannot affect final_lr, this is a workaround to apply lr decay
181 | final_lr = group['final_lr'] * group['lr'] / base_lr
182 | lower_bound = final_lr * (1 - 1 / (group['gamma'] * state['step'] + 1))
183 | upper_bound = final_lr * (1 + 1 / (group['gamma'] * state['step']))
184 | step_size = torch.full_like(denom, step_size)
185 | step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(exp_avg)
186 |
187 | p.data.add_(-step_size)
188 |
189 | return loss
--------------------------------------------------------------------------------
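A minimal construction sketch for the helpers above; the toy parameter list is illustrative, and the values mirror the trainers' defaults:

    import torch
    from utils.opt import OptimizerParameters, get_optimizer

    opt_params = OptimizerParameters(learning_rate=2e-4, momentum=0.9,
                                     momentum2=0.999, epsilon=1e-8)
    weights = [torch.nn.Parameter(torch.zeros(4, 4))]
    # "sgd", "adabound" and "amsbound" are the other accepted types
    optimizer = get_optimizer("adam", weights, opt_params)

For AdaBound, each element-wise step is clamped into [lower_bound, upper_bound], and both bounds converge toward final_lr as state['step'] grows, so the optimizer smoothly transitions from Adam-like to SGD-like updates.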
/utils/visualization.py:
--------------------------------------------------------------------------------
1 | import visdom
2 | import numpy
3 | import torch
4 | import datetime
5 | from json2html import *
6 |
7 | class NullVisualizer(object):
8 | def __init__(self):
9 | self.name = __name__
10 |
11 | def append_loss(self, epoch, global_iteration, loss, loss_name="total", mode='train'):
12 | pass
13 |
14 | def config(self, **kwargs):
15 | pass
16 |
17 | def update_epoch(self, epoch):
18 | pass
19 |
20 | class VisdomPlotVisualizer(object):
21 | def __init__(self, name, server="http://localhost"):
22 | self.visualizer = visdom.Visdom(server=server, port=8097, env=name,\
23 | use_incoming_socket=False)
24 | self.name = name
25 | self.server = server
26 | self.first_train_value = True
27 | self.first_test_value = True
28 | self.plots = {}
29 |
30 | def append_loss(self, epoch, global_iteration, loss, loss_name="total", mode='train'):
31 | plot_name = loss_name + ('_loss' if mode == 'train' else '_error')
32 | opts = (
33 | {
34 | 'title': plot_name,
35 | 'xlabel': 'iterations',
36 | 'ylabel': loss_name
37 | })
38 | loss_value = float(loss.detach().cpu().numpy())
39 | if loss_name not in self.plots:
40 | self.plots[loss_name] = self.visualizer.line(X=numpy.array([global_iteration]),\
41 | Y=numpy.array([loss_value]), opts=opts)
42 | else:
43 | self.visualizer.line(X=numpy.array([global_iteration]),\
44 | Y=numpy.array([loss_value]), win=self.plots[loss_name], name=mode, update='append')
45 |
46 | def config(self, **kwargs):
47 | self.visualizer.text(json2html.convert(json=dict(kwargs)))
48 |
49 | def update_epoch(self, epoch):
50 | pass
51 |
52 | class VisdomImageVisualizer(object):
53 | def __init__(self, name, server="http://localhost", count=2):
54 | self.name = name
55 | self.server = server
56 | self.count = count
57 |
58 | def update_epoch(self, epoch):
59 | self.visualizer = visdom.Visdom(server=self.server, port=8097,\
60 | env=self.name + str(epoch), use_incoming_socket=False)
61 |
62 | def show_separate_images(self, images, title):
63 | b, c, h, w = images.size()
64 | take = self.count if self.count < b else b
65 | recon_images = images.detach().cpu()[:take, [2, 1, 0], :, :]\
66 | if c == 3 else images.detach().cpu()[:take, :, :, :]
67 | for i in range(take):
68 | img = recon_images[i, :, :, :]
69 | opts = (
70 | {
71 | 'title': title + "_" + str(i),
72 | 'width': w, 'height': h
73 | })
74 | self.visualizer.image(img, opts=opts,\
75 | win=self.name + title + "_window_" + str(i))
76 |
77 | def show_map(self, maps, title):
78 | b, c, h, w = maps.size()
79 | maps_cpu = torch.flip(maps, dims=[2]).detach().cpu()[:self.count, :, :, :]
80 | for i in range(min(b, self.count)):
81 | opts = (
82 | {
83 | 'title': title + str(i), 'colormap': 'Viridis'
84 | })
85 | heatmap = maps_cpu[i, :, :, :].squeeze(0)
86 | #TODO: flip images before heatmap call
87 | self.visualizer.heatmap(heatmap,\
88 | opts=opts, win=self.name + title + "_window_" + str(i))
89 |
--------------------------------------------------------------------------------
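Both visdom visualizers assume a server reachable on port 8097. One way to start it locally before launching a trainer with --visdom (standard visdom usage, not specific to this repository):

    python -m visdom.server  # serves on http://localhost:8097 by default
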