├── .gitignore
├── LICENSE
├── README.md
├── dataset
│   ├── __init__.py
│   └── dataset_360D.py
├── exporters
│   ├── __init__.py
│   └── image.py
├── filesystem
│   └── file_utils.py
├── infer.py
├── models
│   ├── __init__.py
│   ├── modules.py
│   └── resnet360.py
├── spherical
│   ├── __init__.py
│   ├── cartesian.py
│   ├── derivatives.py
│   ├── grid.py
│   └── weights.py
├── supervision
│   ├── __init__.py
│   ├── direct.py
│   ├── photometric.py
│   ├── smoothness.py
│   ├── splatting.py
│   └── ssim.py
├── test.py
├── train_lr.py
├── train_sv.py
├── train_tc.py
├── train_ud.py
└── utils
    ├── __init__.py
    ├── checkpoint.py
    ├── framework.py
    ├── init.py
    ├── meters.py
    ├── opt.py
    └── visualization.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2019, Visual Computing Lab, Information Technologies Institute, Centre for Research and Technology Hellas 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 
15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Spherical View Synthesis for Self-Supervised 360o Depth Estimation 2 | 3 | [![Paper](http://img.shields.io/badge/paper-arxiv.1909.08112-critical.svg?style=plastic)](https://arxiv.org/pdf/1909.08112.pdf) 4 | [![Conference](http://img.shields.io/badge/3DV-2019-blue.svg?style=plastic)](http://3dv19.gel.ulaval.ca/) 5 | [![Project Page](http://img.shields.io/badge/Project-Page-blueviolet.svg?style=plastic)](https://vcl3d.github.io/SphericalViewSynthesis/) 6 | ___ 7 | 8 | # Data 9 | 10 | > ![IMPORTANT](https://img.shields.io/badge/IMPORTANT-DATA_UPDATE-C70039?style=plastic&logo=dataversioncontrol&logoWidth=40&logoColor=C70039) An updated dataset is now available that fixes a critical issue with 3D60: the lighting bias introduced by the light source placed at the origin. More information can be found at the [Pano3D project page](https://vcl3d.github.io/Pano3D/). 11 | > 12 | The 360o stereo data used to train the self-supervised models are available [here](https://vcl3d.github.io/3D60/) and are part of a larger dataset __\[[1](#OmniDepth), [2](#HyperSphere)\]__ that contains rendered color images, depth and normal maps for each viewpoint in a trinocular setup. 13 | 14 | ___ 15 | 16 | ## Train 17 | Training code to reproduce our experiments is available in this repository. 18 | 19 | A separate training script is available for each variant: 20 | 21 | * [`train_ud.py`](./train_ud.py) for vertical stereo (__UD__) training 22 | * [`train_lr.py`](./train_lr.py) for horizontal stereo (__LR__) training 23 | * [`train_tc.py`](./train_tc.py) for trinocular stereo (__TC__) training, using the `photo_ratio` argument to train the different __TC__ variants. 24 | * [`train_sv.py`](./train_sv.py) for supervised (__SV__) training 25 | 26 | The PyTorch implementation of the differentiable depth-image-based forward rendering ([_`splatting`_](./supervision/splatting.py#L9)), presented in __\[[3](#LSI)\]__ and originally implemented in [TensorFlow](https://github.com/google/layered-scene-inference), is also [available](./supervision/splatting.py#L73). 27 | 28 | ## Test 29 | 30 | Our evaluation script [`test.py`](./test.py) also adapts the metrics calculation to spherical data, adding [spherical weighting](./spherical/weights.py#L8) and [spiral sampling](./test.py#L92). 
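As a rough sketch of how this weighting is meant to be applied when averaging a per-pixel metric over an equirectangular image (using `create_spherical_grid` and `theta_confidence` from [`spherical/grid.py`](./spherical/grid.py) and [`spherical/weights.py`](./spherical/weights.py); `pred` and `gt` are placeholder depth tensors):

```python
import torch
from spherical.grid import create_spherical_grid
from spherical.weights import theta_confidence

pred = torch.rand(1, 1, 256, 512) + 0.5    # placeholder depth prediction
gt = torch.rand(1, 1, 256, 512) + 0.5      # placeholder ground-truth depth
sgrid = create_spherical_grid(width=512)   # [1, 2, 256, 512] grid of (phi, theta) angles
weights = theta_confidence(sgrid)          # [1, 1, 256, 512], fading towards the poles
abs_rel = torch.abs(pred - gt) / gt        # per-pixel metric, here absolute relative error
weighted_mean = (weights * abs_rel).sum() / weights.sum()
```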
31 | 32 | ## Pre-trained Models 33 | Our PyTorch pre-trained models (corresponding to those reported in the paper) are available at our [releases](https://github.com/VCL3D/SphericalViewSynthesis/releases) and contain these model variants: 34 | 35 | * [__UD__ @ epoch 16](https://github.com/VCL3D/SphericalViewSynthesis/releases/download/UD/ud.pt) 36 | * [__TC8__ @ epoch 16](https://github.com/VCL3D/SphericalViewSynthesis/releases/download/TC8/tc8.pt) 37 | * [__TC6__ @ epoch 28](https://github.com/VCL3D/SphericalViewSynthesis/releases/download/TC6/tc6.pt) 38 | * [__TC4__ @ epoch 17](https://github.com/VCL3D/SphericalViewSynthesis/releases/download/TC4/tc4.pt) 39 | * [__TC2__ @ epoch 20](https://github.com/VCL3D/SphericalViewSynthesis/releases/download/TC2/tc2.pt) 40 | * [__LR__ @ epoch 18](https://github.com/VCL3D/SphericalViewSynthesis/releases/download/LR/lr.pt) 41 | * [__SV__ @ epoch 24](https://github.com/VCL3D/SphericalViewSynthesis/releases/download/SV/sv.pt) 42 | 43 | ___ 44 | 45 | ## Citation 46 | If you use this code and/or data, please cite the following: 47 | ``` 48 | @inproceedings{zioulis2019spherical, 49 | author = "Zioulis, Nikolaos and Karakottas, Antonis and Zarpalas, Dimitris and Alvarez, Federico and Daras, Petros", 50 | title = "Spherical View Synthesis for Self-Supervised $360^o$ Depth Estimation", 51 | booktitle = "International Conference on 3D Vision (3DV)", 52 | month = "September", 53 | year = "2019" 54 | } 55 | ``` 56 | 57 | 58 | # References 59 | __\[[1](https://vcl.iti.gr/360-dataset)\]__ Zioulis, N.__\*__, Karakottas, A.__\*__, Zarpalas, D., and Daras, P. (2018). [Omnidepth: Dense depth estimation for indoors spherical panoramas](https://arxiv.org/pdf/1807.09620.pdf). In Proceedings of the European Conference on Computer Vision (ECCV). 60 | 61 | __\[[2](https://vcl3d.github.io/HyperSphereSurfaceRegression/)\]__ Karakottas, A., Zioulis, N., Samaras, S., Ataloglou, D., Gkitsas, V., Zarpalas, D., and Daras, P. (2019). [360o Surface Regression with a Hyper-sphere Loss](https://arxiv.org/pdf/1909.07043.pdf). In Proceedings of the International Conference on 3D Vision (3DV). 62 | 63 | __[3]__ Tulsiani, S., Tucker, R., and Snavely, N. (2018). [Layer-structured 3d scene inference via view synthesis](https://arxiv.org/pdf/1807.10264.pdf). In Proceedings of the European Conference on Computer Vision (ECCV). 
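For a quick single-image sanity check of the released weights above, [`infer.py`](./infer.py) can be invoked along these lines (paths are placeholders; the input panorama is resized to 512x256 and the predicted depth is saved next to it as an `.exr` file):

```sh
python infer.py --input_path ./panorama.png --weights ./ud.pt --gpu 0
```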
64 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset_360D import * 2 | -------------------------------------------------------------------------------- /dataset/dataset_360D.py: -------------------------------------------------------------------------------- 1 | ################################### 2 | # 360 dataset pytorch dataloader 3 | ################################### 4 | import os 5 | 6 | import numpy as np 7 | import cv2 8 | import PIL.Image as Image 9 | import datetime 10 | 11 | import torch 12 | from torch.utils.data import Dataset 13 | from torchvision import transforms 14 | # Ignore warnings 15 | import warnings 16 | warnings.filterwarnings("ignore") 17 | 18 | ############################################################################################################ 19 | # We use a text file to hold our dataset's filenames 20 | # The "filenames list" has the following format 21 | # 22 | # path/to/Left/rgb.png path/to/Right/rgb.png path/to/Up/rgb.png path/to/Left/depth.exr path/to/Right/depth.exr path/to/Up/depth.exr 23 | # 24 | # We also have a Trinocular version, but you get the idea. 25 | ############################################################################################################# 26 | 27 | class Dataset360D(Dataset): 28 | #360D Dataset# 29 | def __init__(self, filenamesFile, delimiter, mode, inputShape, transform=None, rescaled=False): 30 | ######################################################################################################### 31 | # Arguments: 32 | # -filenamesFile: Absolute path to the aforementioned filenames .txt file 33 | # -transform : (Optional) transform to be applied on a sample 34 | # -mode : Dataset mode. Available options: mono, lr (Left-Right), ud (Up-Down), tc (Trinocular); delimiter separates the paths on each line, inputShape is [height, width], and rescaled additionally returns x2/x4 downscaled versions 35 | ######################################################################################################### 36 | self.height = inputShape[0] 37 | self.width = inputShape[1] 38 | self.sample = {} # one dataset sample (dictionary) 39 | self.resize2 = transforms.Resize([128, 256]) # function to resize input image by a factor of 2 40 | self.resize4 = transforms.Resize([64, 128]) # function to resize input image by a factor of 4 41 | self.pilToTensor = transforms.ToTensor() if transform is None else transforms.Compose( 42 | [ 43 | transforms.ToTensor(), # function to convert pillow image to tensor 44 | transform 45 | ] 46 | ) 47 | self.filenamesFilePath = filenamesFile # file containing image paths to load 48 | self.delimiter = delimiter # delimiter in filenames file 49 | self.mode = mode # dataset mode 50 | self.initDict(self.mode) # initializes dictionary with filepaths 51 | self.loadFilenamesFile() # loads filepaths to dictionary 52 | self.rescaled = rescaled 53 | 54 | # Check if given dataset mode is correct 55 | # Available modes: mono, lr, ud, tc 56 | def checkMode(self, mode): 57 | accepted = False 58 | if (mode != "mono" and mode != "lr" and mode != "ud" and mode != "tc"): 59 | print("{} | Given dataset mode [{}] is not known. Available modes: mono, lr, ud, tc".format(datetime.datetime.now(), mode)) 60 | exit() 61 | else: 62 | accepted = True 63 | return accepted 64 | 65 | # initializes dictionary's lists w.r.t. 
the dataset's mode 66 | def initDict(self, mode): 67 | if (mode == "mono"): 68 | self.sample["leftRGB"] = [] 69 | self.sample["leftRGB2"] = [] 70 | self.sample["leftRGB4"] = [] 71 | self.sample["leftDepth"] = [] 72 | self.sample["leftDepth2"] = [] 73 | self.sample["leftDepth4"] = [] 74 | elif (mode == "lr"): 75 | self.sample["leftRGB"] = [] 76 | self.sample["leftRGB2"] = [] 77 | self.sample["leftRGB4"] = [] 78 | self.sample["rightRGB"] = [] 79 | self.sample["rightRGB2"] = [] 80 | self.sample["rightRGB4"] = [] 81 | self.sample["leftDepth"] = [] 82 | self.sample["leftDepth2"] = [] 83 | self.sample["leftDepth4"] = [] 84 | self.sample["rightDepth"] = [] 85 | self.sample["rightDepth2"] = [] 86 | self.sample["rightDepth4"] = [] 87 | elif (mode == "ud"): 88 | self.sample["leftRGB"] = [] 89 | self.sample["leftRGB2"] = [] 90 | self.sample["leftRGB4"] = [] 91 | self.sample["upRGB"] = [] 92 | self.sample["upRGB2"] = [] 93 | self.sample["upRGB4"] = [] 94 | self.sample["leftDepth"] = [] 95 | self.sample["leftDepth2"] = [] 96 | self.sample["leftDepth4"] = [] 97 | self.sample["upDepth"] = [] 98 | self.sample["upDepth2"] = [] 99 | self.sample["upDepth4"] = [] 100 | elif (mode == "tc"): 101 | self.sample["leftRGB"] = [] 102 | self.sample["leftRGB2"] = [] 103 | self.sample["leftRGB4"] = [] 104 | self.sample["rightRGB"] = [] 105 | self.sample["rightRGB2"] = [] 106 | self.sample["rightRGB4"] = [] 107 | self.sample["upRGB"] = [] 108 | self.sample["upRGB2"] = [] 109 | self.sample["upRGB4"] = [] 110 | self.sample["leftDepth"] = [] 111 | self.sample["leftDepth2"] = [] 112 | self.sample["leftDepth4"] = [] 113 | self.sample["rightDepth"] = [] 114 | self.sample["rightDepth2"] = [] 115 | self.sample["rightDepth4"] = [] 116 | self.sample["upDepth"] = [] 117 | self.sample["upDepth2"] = [] 118 | self.sample["upDepth4"] = [] 119 | 120 | # configures samples when in mono mode 121 | # loads filepaths to dictionary's list 122 | def initModeMono(self, lines): 123 | for line in lines: 124 | leftRGBPath = line.split(self.delimiter)[0] 125 | leftDepthPath = line.split(self.delimiter)[3] 126 | self.sample["leftRGB"].append(leftRGBPath) 127 | self.sample["leftDepth"].append(leftDepthPath) 128 | 129 | 130 | # configures dataset samples when in Left-Right mode 131 | def initModeLR(self, lines): 132 | for line in lines: 133 | leftRGBPath = line.split(self.delimiter)[0] 134 | rightRGBPath = line.split(self.delimiter)[1] 135 | leftDepthPath = line.split(self.delimiter)[3] 136 | rightDepthPath = line.split(self.delimiter)[4] 137 | self.sample["leftRGB"].append(leftRGBPath) 138 | self.sample["rightRGB"].append(rightRGBPath) 139 | self.sample["leftDepth"].append(leftDepthPath) 140 | self.sample["rightDepth"].append(rightDepthPath) 141 | 142 | # configures dataset samples when in Up-Down mode 143 | def initModeUD(self, lines): 144 | for line in lines: 145 | leftRGBPath = line.split(self.delimiter)[0] 146 | upRGBPath = line.split(self.delimiter)[2] 147 | leftDepthPath = line.split(self.delimiter)[3] 148 | upDepthPath = line.split(self.delimiter)[5] 149 | self.sample["leftRGB"].append(leftRGBPath) 150 | self.sample["upRGB"].append(upRGBPath) 151 | self.sample["leftDepth"].append(leftDepthPath) 152 | self.sample["upDepth"].append(upDepthPath) 153 | 154 | # configures dataset samples when in Trinocular mode 155 | def initModeTC(self, lines): 156 | for line in lines: 157 | leftRGBPath = line.split(self.delimiter)[0] 158 | rightRGBPath = line.split(self.delimiter)[1] 159 | upRGBPath = line.split(self.delimiter)[2] 160 | leftDepthPath = 
line.split(self.delimiter)[3] 161 | rightDepthPath = line.split(self.delimiter)[4] 162 | upDepthPath = line.split(self.delimiter)[5] 163 | self.sample["leftRGB"].append(leftRGBPath) 164 | self.sample["rightRGB"].append(rightRGBPath) 165 | self.sample["upRGB"].append(upRGBPath) 166 | self.sample["leftDepth"].append(leftDepthPath) 167 | self.sample["rightDepth"].append(rightDepthPath) 168 | self.sample["upDepth"].append(upDepthPath) 169 | 170 | # Loads filenames from .txt file and saves the samples' paths w.r.t. the dataset mode 171 | def loadFilenamesFile(self): 172 | if (not os.path.exists(self.filenamesFilePath)): 173 | print("{} | Filepath [{}] does not exist.".format(datetime.datetime.now(), self.filenamesFilePath)) 174 | exit() 175 | fileID = open(self.filenamesFilePath, "r") 176 | lines = fileID.readlines() 177 | if (len(lines) == 0): 178 | print("{} | Empty filenames file: {}".format(datetime.datetime.now(), self.filenamesFilePath)) 179 | exit() 180 | self.length = len(lines) 181 | if (self.mode == "mono"): 182 | self.initModeMono(lines) 183 | elif (self.mode == "lr"): 184 | self.initModeLR(lines) 185 | elif (self.mode == "ud"): 186 | self.initModeUD(lines) 187 | elif (self.mode == "tc"): 188 | self.initModeTC(lines) 189 | 190 | # loads sample from dataset mono mode 191 | def loadItemMono(self, idx): 192 | item = {} 193 | if (idx >= self.length): 194 | print("Index [{}] out of range. Dataset length: {}".format(idx, self.length)) 195 | else: 196 | dtmp = np.array(cv2.imread(self.sample["leftDepth"][idx], cv2.IMREAD_ANYDEPTH)) 197 | left_depth = torch.from_numpy(dtmp) 198 | left_depth.unsqueeze_(0) 199 | if self.rescaled: 200 | dtmp2 = cv2.resize(dtmp, (dtmp.shape[1] // 2, dtmp.shape[0] // 2)) 201 | dtmp4 = cv2.resize(dtmp, (dtmp.shape[1] // 4, dtmp.shape[0] // 4)) 202 | left_depth2 = torch.from_numpy(dtmp2) 203 | left_depth2.unsqueeze_(0) 204 | left_depth4 = torch.from_numpy(dtmp4) 205 | left_depth4.unsqueeze_(0) 206 | 207 | pilRGB = Image.open(self.sample["leftRGB"][idx]) 208 | rgb = self.pilToTensor(pilRGB) 209 | if self.rescaled: 210 | rgb2 = self.pilToTensor(self.resize2(pilRGB)) 211 | rgb4 = self.pilToTensor(self.resize4(pilRGB)) 212 | item = { 213 | "leftRGB": rgb, 214 | "leftRGB2": rgb2, 215 | "leftRGB4": rgb4, 216 | "leftDepth": left_depth, 217 | "leftDepth2": left_depth2, 218 | "leftDepth4": left_depth4, 219 | "leftDepth_filename": os.path.basename(self.sample["leftDepth"][idx][:-4]) 220 | } if self.rescaled else { 221 | "leftRGB": rgb, 222 | "leftDepth": left_depth, 223 | "leftDepth_filename": os.path.basename(self.sample["leftDepth"][idx][:-4]) 224 | } 225 | return item 226 | 227 | # loads sample from dataset lr mode 228 | def loadItemLR(self, idx): 229 | item = {} 230 | if (idx >= self.length): 231 | print("Index [{}] out of range. 
Dataset length: {}".format(idx, self.length)) 232 | else: 233 | leftRGB = Image.open(self.sample["leftRGB"][idx]) 234 | rightRGB = Image.open(self.sample["rightRGB"][idx]) 235 | if self.rescaled: 236 | leftRGB2 = self.pilToTensor(self.resize2(leftRGB)) 237 | leftRGB4 = self.pilToTensor(self.resize4(leftRGB)) 238 | rightRGB2 = self.pilToTensor(self.resize2(rightRGB)) 239 | rightRGB4 = self.pilToTensor(self.resize4(rightRGB)) 240 | 241 | dtmp = np.array(cv2.imread(self.sample["leftDepth"][idx], cv2.IMREAD_ANYDEPTH)) 242 | left_depth = torch.from_numpy(dtmp) 243 | left_depth.unsqueeze_(0) 244 | if self.rescaled: 245 | dtmp2 = cv2.resize(dtmp, (dtmp.shape[1] // 2, dtmp.shape[0] // 2)) 246 | dtmp4 = cv2.resize(dtmp, (dtmp.shape[1] // 4, dtmp.shape[0] // 4)) 247 | left_depth2 = torch.from_numpy(dtmp2) 248 | left_depth2.unsqueeze_(0) 249 | left_depth4 = torch.from_numpy(dtmp4) 250 | left_depth4.unsqueeze_(0) 251 | 252 | dtmp = np.array(cv2.imread(self.sample["rightDepth"][idx], cv2.IMREAD_ANYDEPTH)) 253 | right_depth = torch.from_numpy(dtmp) 254 | right_depth.unsqueeze_(0) 255 | if self.rescaled: 256 | tmp2 = cv2.resize(dtmp, (dtmp.shape[1] // 2, dtmp.shape[0] // 2)) 257 | dtmp4 = cv2.resize(dtmp, (dtmp.shape[1] // 4, dtmp.shape[0] // 4)) 258 | right_depth2 = torch.from_numpy(dtmp2) 259 | right_depth2.unsqueeze_(0) 260 | right_depth4 = torch.from_numpy(dtmp4) 261 | right_depth4.unsqueeze_(0) 262 | item = { 263 | "leftRGB": self.pilToTensor(leftRGB), 264 | "rightRGB": self.pilToTensor(rightRGB), 265 | "leftRGB2": leftRGB2, 266 | "rightRGB2": rightRGB2, 267 | "leftRGB4": leftRGB4, 268 | "rightRGB4": rightRGB4 , 269 | "leftDepth": left_depth, 270 | 'leftDepth2': left_depth2, 271 | 'leftDepth4': left_depth4, 272 | "rightDepth": right_depth, 273 | "rightDepth2": right_depth2, 274 | "rightDepth4": right_depth4, 275 | 'leftDepth_filename': os.path.basename(self.sample['leftDepth'][idx][:-4]) 276 | } if self.rescaled else { 277 | "leftRGB": self.pilToTensor(leftRGB), 278 | "rightRGB": self.pilToTensor(rightRGB), 279 | "leftDepth": left_depth, 280 | "rightDepth": right_depth, 281 | 'leftDepth_filename': os.path.basename(self.sample['leftDepth'][idx][:-4]) 282 | } 283 | return item 284 | 285 | # loads sample from dataset ud mode 286 | def loadItemUD(self, idx): 287 | item = {} 288 | if (idx >= self.length): 289 | print("Index [{}] out of range. 
Dataset length: {}".format(idx, self.length)) 290 | else: 291 | leftRGB = Image.open(self.sample["leftRGB"][idx]) 292 | upRGB = Image.open(self.sample["upRGB"][idx]) 293 | if self.rescaled: 294 | leftRGB2 = self.pilToTensor(self.resize2(leftRGB)) 295 | leftRGB4 = self.pilToTensor(self.resize4(leftRGB)) 296 | upRGB2 = self.pilToTensor(self.resize2(upRGB)) 297 | upRGB4 = self.pilToTensor(self.resize4(upRGB)) 298 | dtmp = np.array(cv2.imread(self.sample["leftDepth"][idx], cv2.IMREAD_ANYDEPTH)) 299 | depth = torch.from_numpy(dtmp) 300 | depth.unsqueeze_(0) 301 | if self.rescaled: 302 | dtmp2 = cv2.resize(dtmp, (self.width // 2, self.height // 2)) 303 | dtmp4 = cv2.resize(dtmp, (self.width // 4, self.height // 4)) 304 | depth2 = torch.from_numpy(dtmp2) 305 | depth2.unsqueeze_(0) 306 | depth4 = torch.from_numpy(dtmp4) 307 | depth4.unsqueeze_(0) 308 | 309 | 310 | dtmp = np.array(cv2.imread(self.sample["upDepth"][idx], cv2.IMREAD_ANYDEPTH)) 311 | up_depth = torch.from_numpy(dtmp) 312 | up_depth.unsqueeze_(0) 313 | if self.rescaled: 314 | tmp2 = cv2.resize(dtmp, (dtmp.shape[1] // 2, dtmp.shape[0] // 2)) 315 | dtmp4 = cv2.resize(dtmp, (dtmp.shape[1] // 4, dtmp.shape[0] // 4)) 316 | up_depth2 = torch.from_numpy(dtmp2) 317 | up_depth2.unsqueeze_(0) 318 | up_depth4 = torch.from_numpy(dtmp4) 319 | up_depth4.unsqueeze_(0) 320 | 321 | item = { 322 | "leftRGB": self.pilToTensor(leftRGB), 323 | "upRGB": self.pilToTensor(upRGB), 324 | "leftRGB2": leftRGB2, 325 | "upRGB2": upRGB2, 326 | "leftRGB4": leftRGB4, 327 | "upRGB4": upRGB4, 328 | "leftDepth": depth, 329 | "leftDepth2": depth2, 330 | "leftDepth4": depth4, 331 | "upDepth": up_depth, 332 | "upDepth2": up_depth2, 333 | "upDepth4": up_depth4, 334 | 'leftDepth_filename': os.path.basename(self.sample['leftDepth'][idx][:-4]) 335 | } if self.rescaled else { 336 | "leftRGB": self.pilToTensor(leftRGB), 337 | "upRGB": self.pilToTensor(upRGB), 338 | "leftDepth": depth, 339 | "upDepth": up_depth, 340 | 'leftDepth_filename': os.path.basename(self.sample['leftDepth'][idx][:-4]) 341 | } 342 | return item 343 | 344 | # loads sample from dataset tc mode 345 | def loadItemTC(self, idx): 346 | item = {} 347 | if (idx >= self.length): 348 | print("Index [{}] out of range. 
Dataset length: {}".format(idx, self.length)) 349 | else: 350 | leftRGB = Image.open(self.sample["leftRGB"][idx]) 351 | rightRGB = Image.open(self.sample["rightRGB"][idx]) 352 | upRGB = Image.open(self.sample["upRGB"][idx]) 353 | if self.rescaled: 354 | leftRGB2 = self.pilToTensor(self.resize2(leftRGB)) 355 | leftRGB4 = self.pilToTensor(self.resize4(leftRGB)) 356 | rightRGB2 = self.pilToTensor(self.resize2(rightRGB)) 357 | rightRGB4 = self.pilToTensor(self.resize4(rightRGB)) 358 | upRGB2 = self.pilToTensor(self.resize2(upRGB)) 359 | upRGB4 = self.pilToTensor(self.resize4(upRGB)) 360 | 361 | dtmp = np.array(cv2.imread(self.sample["leftDepth"][idx], cv2.IMREAD_ANYDEPTH)) 362 | depth = torch.from_numpy(dtmp) 363 | depth.unsqueeze_(0) 364 | if self.rescaled: 365 | dtmp2 = cv2.resize(dtmp, (self.width // 2, self.height // 2)) 366 | dtmp4 = cv2.resize(dtmp, (self.width // 4, self.height // 4)) 367 | depth2 = torch.from_numpy(dtmp2) 368 | depth2.unsqueeze_(0) 369 | depth4 = torch.from_numpy(dtmp4) 370 | depth4.unsqueeze_(0) 371 | 372 | dtmp = np.array(cv2.imread(self.sample["rightDepth"][idx], cv2.IMREAD_ANYDEPTH)) 373 | right_depth = torch.from_numpy(dtmp) 374 | right_depth.unsqueeze_(0) 375 | if self.rescaled: 376 | tmp2 = cv2.resize(dtmp, (dtmp.shape[1] // 2, dtmp.shape[0] // 2)) 377 | dtmp4 = cv2.resize(dtmp, (dtmp.shape[1] // 4, dtmp.shape[0] // 4)) 378 | right_depth2 = torch.from_numpy(dtmp2) 379 | right_depth2.unsqueeze_(0) 380 | right_depth4 = torch.from_numpy(dtmp4) 381 | right_depth4.unsqueeze_(0) 382 | 383 | dtmp = np.array(cv2.imread(self.sample["upDepth"][idx], cv2.IMREAD_ANYDEPTH)) 384 | up_depth = torch.from_numpy(dtmp) 385 | up_depth.unsqueeze_(0) 386 | if self.rescaled: 387 | tmp2 = cv2.resize(dtmp, (dtmp.shape[1] // 2, dtmp.shape[0] // 2)) 388 | dtmp4 = cv2.resize(dtmp, (dtmp.shape[1] // 4, dtmp.shape[0] // 4)) 389 | up_depth2 = torch.from_numpy(dtmp2) 390 | up_depth2.unsqueeze_(0) 391 | up_depth4 = torch.from_numpy(dtmp4) 392 | up_depth4.unsqueeze_(0) 393 | 394 | item = { 395 | "leftRGB": self.pilToTensor(leftRGB), 396 | "rightRGB": self.pilToTensor(rightRGB), 397 | "upRGB": self.pilToTensor(upRGB), 398 | "leftRGB2": leftRGB2, 399 | "rightRGB2": rightRGB2, 400 | "upRGB2": upRGB2, 401 | "leftRGB4": leftRGB4, 402 | "rightRGB4": rightRGB4, 403 | "upRGB4": upRGB4, 404 | "leftDepth": depth, 405 | "leftDepth2": depth2, 406 | "leftDepth4": depth4, 407 | "upDepth": up_depth, 408 | "upDepth2": up_depth2, 409 | "upDepth4": up_depth4, 410 | "rightDepth": right_depth, 411 | "rightDepth2": right_depth2, 412 | "rightDepth4": right_depth4, 413 | "depthFilename": os.path.basename(self.sample["leftDepth"][idx][:-4]) 414 | } if self.rescaled else { 415 | "leftRGB": self.pilToTensor(leftRGB), 416 | "rightRGB": self.pilToTensor(rightRGB), 417 | "upRGB": self.pilToTensor(upRGB), 418 | "leftDepth": depth, 419 | "rightDepth": right_depth, 420 | "upDepth": up_depth, 421 | "depthFilename": os.path.basename(self.sample["leftDepth"][idx][:-4]) 422 | } 423 | return item 424 | 425 | # torch override 426 | # returns samples length 427 | def __len__(self): 428 | return self.length 429 | 430 | # torch override 431 | def __getitem__(self, idx): 432 | if (self.mode == "mono"): 433 | return self.loadItemMono(idx) 434 | elif(self.mode == "lr"): 435 | return self.loadItemLR(idx) 436 | elif(self.mode == "ud"): 437 | return self.loadItemUD(idx) 438 | elif(self.mode == "tc"): 439 | return self.loadItemTC(idx) 440 | 441 | 442 | 443 | -------------------------------------------------------------------------------- 
/exporters/__init__.py: -------------------------------------------------------------------------------- 1 | from .image import * -------------------------------------------------------------------------------- /exporters/image.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import numpy 4 | 5 | def save_image(filename, tensor, scale=255.0): 6 | b, _, __, ___ = tensor.size() 7 | for n in range(b): 8 | array = tensor[n, :, :, :].detach().cpu().numpy() 9 | array = array.transpose(1, 2, 0) * scale 10 | cv2.imwrite(filename.replace("#", str(n)), array) 11 | 12 | def save_depth(filename, tensor, scale=1000.0): 13 | b, _, __, ___ = tensor.size() 14 | for n in range(b): 15 | array = tensor[n, :, :, :].detach().cpu().numpy() 16 | array = array.transpose(1, 2, 0) * scale 17 | array = numpy.uint16(array) 18 | cv2.imwrite(filename.replace("#", str(n)), array) 19 | 20 | def save_data(filename, tensor, scale=1000.0): 21 | b, _, __, ___ = tensor.size() 22 | for n in range(b): 23 | array = tensor[n, :, :, :].detach().cpu().numpy() 24 | array = array.transpose(1, 2, 0) * scale 25 | array = numpy.float32(array) 26 | cv2.imwrite(filename.replace("#", str(n)), array) 27 | -------------------------------------------------------------------------------- /filesystem/file_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | ''' 4 | Filesystem class 5 | provides file control utilities like tensor saving etc. 6 | ''' 7 | class Filesystem: 8 | def __init__(self): 9 | self.cwd = os.getcwd() 10 | if os.path.isfile(self.cwd): 11 | self.cwd = os.path.basename(self.cwd) 12 | ''' 13 | Creates directory 14 | either by giving the absolute path to create 15 | or the relative path w.r.t. the current working directory 16 | 17 | \param path the path to create 18 | ''' 19 | def mkdir(self, path): 20 | if os.path.isabs(path): 21 | if not os.path.exists(path): 22 | os.mkdir(path) 23 | else: 24 | pathToCreate = os.path.join(self.cwd, path) 25 | if not os.path.exists(pathToCreate): 26 | os.mkdir(pathToCreate) -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import numpy 5 | import cv2 6 | 7 | import torch 8 | 9 | import models 10 | import utils 11 | import exporters 12 | 13 | def parse_arguments(args): 14 | usage_text = ( 15 | "Spherical depth estimation inference." 16 | ) 17 | parser = argparse.ArgumentParser(description=usage_text) 18 | parser.add_argument("--input_path", type=str, help="Path to the input spherical panorama image.") 19 | parser.add_argument('--weights', type=str, help='Path to the trained weights file.') 20 | parser.add_argument('-g','--gpu', type=str, default='0', help='The ids of the GPU(s) that will be utilized. (e.g. 0 or 0,1, or 0,2). 
Use -1 for CPU.') 21 | return parser.parse_known_args(args) 22 | 23 | if __name__ == "__main__": 24 | args, unknown = parse_arguments(sys.argv) 25 | gpus = [int(id) for id in args.gpu.split(',') if int(id) >= 0] 26 | # device & visualizers 27 | device = torch.device("cuda:{}" .format(gpus[0])\ 28 | if torch.cuda.is_available() and len(gpus) > 0 and gpus[0] >= 0\ 29 | else "cpu") 30 | # model 31 | model = models.get_model("resnet_coord", {}) 32 | utils.init.initialize_weights(model, args.weights, pred_bias=None) 33 | model = model.to(device) 34 | # test data 35 | width, height = 512, 256 36 | if not os.path.exists(args.input_path): 37 | print("Input image path does not exist (%s)." % args.input_path) 38 | exit(-1) 39 | img = cv2.imread(args.input_path) 40 | h, w, _ = img.shape 41 | if h != height or w != width: # resize whenever either dimension differs from the expected 512x256 input 42 | img = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) 43 | img = img.transpose(2, 0, 1) / 255.0 44 | img = torch.from_numpy(img).float().expand(1, -1, -1, -1) 45 | model.eval() 46 | with torch.no_grad(): 47 | left_rgb = img.to(device) 48 | ''' Prediction ''' 49 | left_depth_pred = torch.abs(model(left_rgb)) 50 | exporters.image.save_data(os.path.join( 51 | os.path.dirname(args.input_path), 52 | os.path.splitext(os.path.basename( 53 | args.input_path))[0] + "_depth.exr"), 54 | left_depth_pred, scale=1.0) 55 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet360 import * 2 | 3 | import sys 4 | 5 | def get_model(name, model_params): 6 | if name == 'resnet_coord': 7 | return ResNet360( 8 | # conv_type='standard', activation='elu', norm_type='none', \ 9 | conv_type='coord', activation='elu', norm_type='none', \ 10 | width=512, 11 | ) 12 | else: 13 | print("Could not find the requested model ({})".format(name), file=sys.stderr) -------------------------------------------------------------------------------- /models/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | ''' 6 | Code adapted from https://github.com/uber-research/coordconv 7 | accompanying the paper "An Intriguing Failing of Convolutional Neural Networks and the CoordConv Solution" (NeurIPS 2018) 8 | ''' 9 | 10 | class AddCoords360(nn.Module): 11 | def __init__(self, x_dim=64, y_dim=64, with_r=False): 12 | super(AddCoords360, self).__init__() 13 | self.x_dim = int(x_dim) 14 | self.y_dim = int(y_dim) 15 | self.with_r = with_r 16 | 17 | def forward(self, input_tensor): 18 | """ 19 | input_tensor: (batch, c, x_dim, y_dim) 20 | """ 21 | batch_size_tensor = input_tensor.shape[0] 22 | 23 | xx_ones = torch.ones([1, self.y_dim], dtype=torch.float32, device=input_tensor.device) 24 | xx_ones = xx_ones.unsqueeze(-1) 25 | 26 | xx_range = torch.arange(self.x_dim, dtype=torch.float32, device=input_tensor.device).unsqueeze(0) 27 | xx_range = xx_range.unsqueeze(1) 28 | 29 | xx_channel = torch.matmul(xx_ones, xx_range) 30 | xx_channel = xx_channel.unsqueeze(-1) 31 | 32 | yy_ones = torch.ones([1, self.x_dim], dtype=torch.float32, device=input_tensor.device) 33 | yy_ones = yy_ones.unsqueeze(1) 34 | 35 | yy_range = torch.arange(self.y_dim, dtype=torch.float32, device=input_tensor.device).unsqueeze(0) 36 | yy_range = yy_range.unsqueeze(-1) 37 | 38 | yy_channel = torch.matmul(yy_range, yy_ones) 39 | yy_channel = yy_channel.unsqueeze(-1) 40 | 41 | xx_channel 
= xx_channel.permute(0, 3, 2, 1) 42 | yy_channel = yy_channel.permute(0, 3, 2, 1) 43 | 44 | xx_channel = xx_channel.float() / (self.x_dim - 1) 45 | yy_channel = yy_channel.float() / (self.y_dim - 1) 46 | 47 | xx_channel = xx_channel * 2 - 1 48 | yy_channel = yy_channel * 2 - 1 49 | 50 | xx_channel = xx_channel.repeat(batch_size_tensor, 1, 1, 1) 51 | yy_channel = yy_channel.repeat(batch_size_tensor, 1, 1, 1) 52 | 53 | ret = torch.cat([input_tensor, xx_channel, yy_channel], dim=1) 54 | 55 | if self.with_r: 56 | rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2)) 57 | ret = torch.cat([ret, rr], dim=1) 58 | 59 | return ret 60 | 61 | class CoordConv360(nn.Module): 62 | """CoordConv layer as in the paper.""" 63 | def __init__(self, x_dim, y_dim, with_r, in_channels, out_channels, kernel_size, *args, **kwargs): 64 | super(CoordConv360, self).__init__() 65 | self.addcoords = AddCoords360(x_dim=x_dim, y_dim=y_dim, with_r=with_r) 66 | in_size = in_channels+2 67 | if with_r: 68 | in_size += 1 69 | self.conv = nn.Conv2d(in_size, out_channels, kernel_size, **kwargs) 70 | 71 | def forward(self, input_tensor): 72 | ret = self.addcoords(input_tensor) 73 | ret = self.conv(ret) 74 | return ret 75 | 76 | 77 | def create_conv(in_size, out_size, conv_type, padding=1, stride=1, kernel_size=3, width=512): 78 | if conv_type == 'standard': 79 | return nn.Conv2d(in_channels=in_size, out_channels=out_size, \ 80 | kernel_size=kernel_size, padding=padding, stride=stride) 81 | elif conv_type == 'coord': 82 | return CoordConv360(x_dim=width / 2.0, y_dim=width,\ 83 | with_r=False, kernel_size=kernel_size, stride=stride,\ 84 | in_channels=in_size, out_channels=out_size, padding=padding) 85 | 86 | def create_activation(activation): 87 | if activation == 'relu': 88 | return nn.ReLU(inplace=True) 89 | elif activation == 'elu': 90 | return nn.ELU(inplace=True) 91 | 92 | class Identity(nn.Module): 93 | def forward(self, x): 94 | return x 95 | 96 | def create_normalization(out_size, norm_type): 97 | if norm_type == 'batchnorm': 98 | return nn.BatchNorm2d(out_size) 99 | elif norm_type == 'groupnorm': 100 | return nn.GroupNorm(out_size // 4, out_size) 101 | elif norm_type == 'none': 102 | return Identity() 103 | 104 | def create_downscale(out_size, down_mode): 105 | if down_mode == 'pool': 106 | return torch.nn.modules.MaxPool2d(2) 107 | elif down_mode == 'downconv': 108 | return nn.Conv2d(in_channels=out_size, out_channels=out_size, kernel_size=3,\ 109 | stride=2, padding=1, bias=False) 110 | elif down_mode == 'gaussian': 111 | print("Not implemented") -------------------------------------------------------------------------------- /models/resnet360.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import functools 5 | 6 | from .modules import * 7 | 8 | # adapted from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 9 | 10 | class ResNet360(nn.Module): 11 | """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations. 
12 | 13 | We adapt the Torch code and ideas from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style) 14 | """ 15 | def __init__( 16 | self, 17 | in_channels=3, 18 | out_channels=1, 19 | depth=5, 20 | wf=32, 21 | conv_type='coord', 22 | padding='kernel', 23 | norm_type='none', 24 | activation='elu', 25 | up_mode='upconv', 26 | down_mode='downconv', 27 | width=512, 28 | use_dropout=False, 29 | padding_type='reflect', 30 | ): 31 | """Construct a Resnet-based generator 32 | 33 | Parameters: 34 | in_channels (int) -- the number of channels in input images 35 | out_channels (int) -- the number of channels in output images 36 | wf (int) -- the number of filters of the first conv layer 37 | norm_type (str) -- the type of normalization layer 38 | use_dropout (bool) -- whether to use dropout layers 39 | depth (int) -- the number of ResNet blocks 40 | padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero 41 | """ 42 | assert(depth >= 0) 43 | super(ResNet360, self).__init__() 44 | model = ( 45 | [ 46 | create_conv(in_channels, wf, conv_type, \ 47 | kernel_size=7, padding=3, stride=1, width=width), 48 | create_normalization(wf, norm_type), 49 | create_activation(activation) 50 | ] 51 | ) 52 | 53 | n_downsampling = 2 54 | for i in range(n_downsampling): 55 | mult = 2 ** i 56 | model += ( 57 | [ 58 | create_conv(wf * mult, wf * mult * 2, conv_type, \ 59 | kernel_size=3, stride=2, padding=1, width=width // (i+1)), 60 | create_normalization(wf * mult * 2, norm_type), 61 | create_activation(activation) 62 | ] 63 | ) 64 | 65 | mult = 2 ** n_downsampling 66 | for i in range(depth): 67 | model += [ResnetBlock(wf * mult, activation=activation, \ 68 | norm_type=norm_type, conv_type=conv_type, \ 69 | width=width // (2 ** n_downsampling))] 70 | 71 | for i in range(n_downsampling): 72 | mult = 2 ** (n_downsampling - i) 73 | model += ( 74 | [ 75 | nn.ConvTranspose2d(wf * mult, int(wf * mult / 2), 76 | kernel_size=3, stride=2, 77 | padding=1, output_padding=1), 78 | create_normalization(int(wf * mult / 2), norm_type), 79 | create_activation(activation) 80 | ] 81 | ) 82 | 83 | model += [create_conv(wf, out_channels, conv_type, \ 84 | kernel_size=7, padding=3, width=width)] 85 | 86 | self.model = nn.Sequential(*model) 87 | 88 | def forward(self, input): 89 | """Standard forward""" 90 | return self.model(input) 91 | 92 | 93 | class ResnetBlock(nn.Module): 94 | """Define a Resnet block""" 95 | 96 | def __init__(self, dim, norm_type, conv_type, activation, width): 97 | """Initialize the Resnet block 98 | 99 | A resnet block is a conv block with skip connections 100 | We construct the conv block in the constructor, 101 | and implement the skip connection in the forward function. 
102 | Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf 103 | """ 104 | super(ResnetBlock, self).__init__() 105 | conv_block = [] 106 | conv_block +=( 107 | [ 108 | create_conv(dim, dim, conv_type, width=width), 109 | create_normalization(dim, norm_type), 110 | create_activation(activation), 111 | ] 112 | ) 113 | conv_block +=( 114 | [ 115 | create_conv(dim, dim, conv_type, width=width), 116 | create_normalization(dim, norm_type), 117 | ] 118 | ) 119 | 120 | self.block = nn.Sequential(*conv_block) 121 | 122 | def forward(self, x): 123 | """Forward function (with skip connections)""" 124 | out = x + self.block(x) # add skip connections 125 | return out -------------------------------------------------------------------------------- /spherical/__init__.py: -------------------------------------------------------------------------------- 1 | from .grid import * 2 | from .cartesian import * 3 | from .derivatives import * 4 | from .weights import * -------------------------------------------------------------------------------- /spherical/cartesian.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .grid import * 4 | 5 | ''' 6 | Cartesian coordinates extraction from Spherical coordinates 7 | z is forward axis 8 | y is the up axis 9 | x is the right axis 10 | r is the radius (i.e. spherical depth) 11 | phi is the longitude/azimuthal rotation angle (defined on the x-z plane) 12 | theta is the latitude/elevation rotation angle (defined on the y-z plane) 13 | ''' 14 | def coord_x(sgrid, depth): 15 | return ( # r * sin(phi) * sin(theta) -> r * cos(phi) * -cos(theta) in our offsets 16 | depth # this is due to the offsets as explained below 17 | * torch.cos(phi(sgrid)) # long = x - 3 * pi / 2 18 | * -1 * torch.cos(theta(sgrid)) # lat = y - pi / 2 19 | ) 20 | 21 | def coord_y(sgrid, depth): 22 | return ( # r * cos(theta) -> r * sin(theta) in our offsets 23 | depth # this is due to the offsets as explained below 24 | * torch.sin(theta(sgrid)) # lat = y - pi / 2 25 | ) 26 | 27 | def coord_z(sgrid, depth): 28 | return ( # r * cos(phi) * sin(theta) -> r * -sin(phi) * -cos(theta) in our offsets 29 | depth # this is due to the offsets as explained above 30 | * torch.sin(phi(sgrid)) # * -1 31 | * torch.cos(theta(sgrid)) # * -1 32 | ) # the -1s cancel out 33 | 34 | def coords_3d(sgrid, depth): 35 | return torch.cat( 36 | ( 37 | coord_x(sgrid, depth), 38 | coord_y(sgrid, depth), 39 | coord_z(sgrid, depth) 40 | ), dim=1 41 | ) 42 | 43 | def xi(pcloud): 44 | return pcloud[:, 0, :, :].unsqueeze(1) 45 | 46 | def yi(pcloud): 47 | return pcloud[:, 1, :, :].unsqueeze(1) 48 | 49 | def zeta(pcloud): 50 | return pcloud[:, 2, :, :].unsqueeze(1) -------------------------------------------------------------------------------- /spherical/derivatives.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .grid import * 4 | from .cartesian import * 5 | 6 | ''' Image (I) spatial derivatives ''' 7 | def dI_du(img): 8 | right_pad = (0, 1, 0, 0) 9 | tensor = torch.nn.functional.pad(img, right_pad, mode="replicate") 10 | gu = tensor[:, :, :, :-1] - tensor[:, :, :, 1:] # NCHW 11 | return gu 12 | 13 | def dI_dv(img): 14 | bottom_pad = (0, 0, 0, 1) 15 | tensor = torch.nn.functional.pad(img, bottom_pad, mode="replicate") 16 | dv = tensor[:, :, :-1, :] - tensor[:, :, 1:, :] # NCHW 17 | return dv 18 | 19 | def dI_duv(img): 20 | du = dI_du(img) 21 | dv = dI_dv(img) 22 | duv = torch.cat((du, dv), dim=1) 23 | 
duv_mag = torch.norm(duv, p=2, dim=1, keepdim=True) 24 | return duv_mag 25 | 26 | ''' 27 | Spherical coordinates (r, phi, theta) derivatives 28 | w.r.t. their Cartesian counterparts (x, y, z) 29 | ''' 30 | def dr_dx(sgrid): 31 | return ( # sin(lat) * sin(long) -> cos(long) * -cos(lat) 32 | -1 # this is due to the offsets as explained below 33 | * torch.cos(phi(sgrid)) # long = x - 3 * pi / 2 34 | * torch.cos(theta(sgrid)) # lat = y - pi / 2 35 | ) # the depth (radius) distortion for each spherical coord with a horizontal baseline 36 | 37 | def dphi_dx(sgrid): 38 | return ( # cos(long) / sin(lat) -> -sin(long) / -cos(lat) 39 | torch.sin(phi(sgrid)) # * -1 40 | / torch.cos(theta(sgrid)) # * -1 41 | ) # the -1s cancel out and are omitted 42 | 43 | def dtheta_dx(sgrid): 44 | return ( # sin(long) * cos(lat) -> cos(long) * sin(lat) 45 | torch.cos(phi(sgrid)) * torch.sin(theta(sgrid)) 46 | ) 47 | 48 | def dtheta_dy(sgrid): 49 | return ( # -sin(lat) -> -1 * -cos(lat) == cos(lat) 50 | torch.cos(theta(sgrid)) 51 | ) 52 | 53 | def dphi_horizontal(sgrid, depth, baseline): 54 | _, __, h, ___ = depth.size() 55 | return torch.clamp( 56 | ( 57 | torch.sin(phi(sgrid)) 58 | / ( 59 | depth 60 | * torch.cos(theta(sgrid)) 61 | ) 62 | * baseline 63 | * (h / numpy.pi) 64 | ), 65 | -h, h # h = w/2 the max disparity due to our spherical nature (i.e. front/back symmetry) 66 | ) 67 | 68 | def dtheta_horizontal(sgrid, depth, baseline): 69 | _, __, h, ___ = depth.size() 70 | return torch.clamp( 71 | ( 72 | torch.cos(phi(sgrid)) 73 | * torch.sin(theta(sgrid)) 74 | * baseline 75 | / depth 76 | * (h / numpy.pi) 77 | ), 78 | 0, h 79 | ) 80 | 81 | def dr_horizontal(sgrid, baseline): 82 | return ( # sin(lat) * sin(long) -> cos(long) * -cos(lat) 83 | -1 # this is due to the offsets as explained below 84 | * torch.cos(phi(sgrid)) # long = x - 3 * pi / 2 85 | * torch.cos(theta(sgrid)) # lat = y - pi / 2 86 | * baseline 87 | ) # the depth (radius) distortion for each spherical coord with a horizontal baseline 88 | 89 | def dtheta_vertical(sgrid, depth, baseline): 90 | _, __, h, ___ = depth.size() 91 | return ( 92 | torch.cos(theta(sgrid)) 93 | * baseline 94 | / depth 95 | * (h / numpy.pi) 96 | ) 97 | 98 | ''' 99 | Structured Point Cloud Vertices (V) spatial derivatives 100 | ''' 101 | def dV_dx(pcloud): 102 | return dI_duv(xi(pcloud)) 103 | 104 | def dV_dy(pcloud): 105 | return dI_duv(yi(pcloud)) 106 | 107 | def dV_dz(pcloud): 108 | return dI_duv(zeta(pcloud)) 109 | 110 | def dV_dxyz(pcloud): 111 | du_x = dI_du(xi(pcloud)) 112 | dv_x = dI_dv(xi(pcloud)) 113 | 114 | du_y = dI_du(yi(pcloud)) 115 | dv_y = dI_dv(yi(pcloud)) 116 | 117 | du_z = dI_du(zeta(pcloud)) 118 | dv_z = dI_dv(zeta(pcloud)) 119 | 120 | du_xyz = torch.abs(du_x) + torch.abs(du_y) + torch.abs(du_z) 121 | dv_xyz = torch.abs(dv_x) + torch.abs(dv_y) + torch.abs(dv_z) 122 | 123 | duv_xyz = torch.cat((du_xyz, dv_xyz), dim=1) 124 | duv_xyz_mag = torch.norm(duv_xyz, p=2, dim=1, keepdim=True) 125 | return duv_xyz_mag -------------------------------------------------------------------------------- /spherical/grid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy 3 | 4 | def create_image_grid(width, height, data_type=torch.float32): 5 | v_range = ( 6 | torch.arange(0, height) # [0 - h] 7 | .view(1, height, 1) # [1, [0 - h], 1] 8 | .expand(1, height, width) # [1, [0 - h], W] 9 | .type(data_type) # [1, H, W] 10 | ) 11 | u_range = ( 12 | torch.arange(0, width) # [0 - w] 13 | .view(1, 1, width) # [1, 1, [0 - 
w]] 14 | .expand(1, height, width) # [1, H, [0 - w]] 15 | .type(data_type) # [1, H, W] 16 | ) 17 | return torch.stack((u_range, v_range), dim=1) # [1, 2, H, W] 18 | 19 | def coord_u(uvgrid): 20 | return uvgrid[:, 0, :, :].unsqueeze(1) 21 | 22 | def coord_v(uvgrid): 23 | return uvgrid[:, 1, :, :].unsqueeze(1) 24 | 25 | def create_spherical_grid(width, horizontal_shift=(-numpy.pi - numpy.pi / 2.0), 26 | vertical_shift=(-numpy.pi / 2.0), data_type=torch.float32): 27 | height = int(width // 2.0) 28 | v_range = ( 29 | torch.arange(0, height) # [0 - h] 30 | .view(1, height, 1) # [1, [0 - h], 1] 31 | .expand(1, height, width) # [1, [0 - h], W] 32 | .type(data_type) # [1, H, W] 33 | ) 34 | u_range = ( 35 | torch.arange(0, width) # [0 - w] 36 | .view(1, 1, width) # [1, 1, [0 - w]] 37 | .expand(1, height, width) # [1, H, [0 - w]] 38 | .type(data_type) # [1, H, W] 39 | ) 40 | u_range *= (2 * numpy.pi / width) # [0, 2 * pi] 41 | v_range *= (numpy.pi / height) # [0, pi] 42 | u_range += horizontal_shift # [-hs, 2 * pi - hs] -> standard values are [-3 * pi / 2, pi / 2] 43 | v_range += vertical_shift # [-vs, pi - vs] -> standard values are [-pi / 2, pi / 2] 44 | return torch.stack((u_range, v_range), dim=1) # [1, 2, H, W] 45 | 46 | def phi(sgrid): # longitude or azimuth 47 | return sgrid[:, 0, :, :].unsqueeze(1) 48 | 49 | def azimuth(sgrid): # longitude or phi 50 | return sgrid[:, 0, :, :].unsqueeze(1) 51 | 52 | def longitude(sgrid): # phi or azimuth 53 | return sgrid[:, 0, :, :].unsqueeze(1) 54 | 55 | def theta(sgrid): # latitude or elevation 56 | return sgrid[:, 1, :, :].unsqueeze(1) 57 | 58 | def elevation(sgrid): # theta or latitude 59 | return sgrid[:, 1, :, :].unsqueeze(1) 60 | 61 | def latitude(sgrid): # theta or elevation 62 | return sgrid[:, 1, :, :].unsqueeze(1) -------------------------------------------------------------------------------- /spherical/weights.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .grid import * 4 | 5 | def phi_confidence(sgrid): # fading towards horizontal singularities 6 | return torch.abs(torch.sin(phi(sgrid))) 7 | 8 | def theta_confidence(sgrid): # fading towards vertical singularities 9 | return torch.abs(torch.cos(theta(sgrid))) 10 | 11 | def spherical_confidence(sgrid, zero_low=0.0, one_high=1.0): 12 | weights = phi_confidence(sgrid) * theta_confidence(sgrid) 13 | weights[weights < zero_low] = 0.0 14 | weights[weights > one_high] = 1.0 15 | return weights -------------------------------------------------------------------------------- /supervision/__init__.py: -------------------------------------------------------------------------------- 1 | from .splatting import * 2 | from .photometric import * 3 | from .smoothness import * 4 | from .direct import * -------------------------------------------------------------------------------- /supervision/direct.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def calculate_berhu_loss(pred, gt, mask, weights): 4 | diff = gt - pred 5 | abs_diff = torch.abs(diff) 6 | c = torch.max(abs_diff).item() / 5 7 | leq = (abs_diff <= c).float() 8 | l2_losses = (diff**2 + c**2) / (2 * c) 9 | loss = leq * abs_diff + (1 - leq) * l2_losses 10 | _, channels, __, ___ = loss.size() # channel count (unused); avoids shadowing the BerHu threshold c 11 | count = torch.sum(mask, dim=[1, 2, 3], keepdim=True).float() 12 | masked_loss = loss * mask.float() 13 | weighted_loss = masked_loss * weights 14 | return torch.mean(torch.sum(weighted_loss, dim=[1, 2, 3], keepdim=True) / count) 
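The BerHu (reverse Huber) loss above applies an L1 penalty up to the adaptive threshold c = max|gt - pred| / 5 and a scaled L2 penalty beyond it; a minimal usage sketch (all tensors are placeholders, and the per-pixel weights could e.g. come from `spherical/weights.py`):

```python
import torch
from supervision.direct import calculate_berhu_loss

pred = torch.rand(2, 1, 256, 512) * 5.0  # placeholder depth prediction
gt = torch.rand(2, 1, 256, 512) * 5.0    # placeholder ground-truth depth
mask = gt > 0                            # per-pixel validity mask
weights = torch.ones_like(gt)            # uniform (or spherical) per-pixel weights
loss = calculate_berhu_loss(pred, gt, mask, weights)
```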
-------------------------------------------------------------------------------- /supervision/photometric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .ssim import * 4 | 5 | class PhotometricLossParameters(object): 6 | def __init__(self, alpha=0.85, l1_estimator='none',\ 7 | ssim_estimator='none', window=7, std=1.5, ssim_mode='gaussian'): 8 | super(PhotometricLossParameters, self).__init__() 9 | self.alpha = alpha 10 | self.l1_estimator = l1_estimator 11 | self.ssim_estimator = ssim_estimator 12 | self.window = window 13 | self.std = std 14 | self.ssim_mode = ssim_mode 15 | 16 | def get_alpha(self): 17 | return self.alpha 18 | 19 | def get_l1_estimator(self): 20 | return self.l1_estimator 21 | 22 | def get_ssim_estimator(self): 23 | return self.ssim_estimator 24 | 25 | def get_window(self): 26 | return self.window 27 | 28 | def get_std(self): 29 | return self.std 30 | 31 | def get_ssim_mode(self): 32 | return self.ssim_mode 33 | 34 | def calculate_loss(pred, gt, params, mask, weights): 35 | valid_mask = mask.type(gt.dtype) 36 | masked_gt = gt * valid_mask 37 | masked_pred = pred * valid_mask 38 | l1 = torch.abs(masked_gt - masked_pred) 39 | d_ssim = torch.clamp( 40 | ( 41 | 1 - ssim_loss(masked_pred, masked_gt, kernel_size=params.get_window(), 42 | std=params.get_std(), mode=params.get_ssim_mode()) 43 | ) / 2, 0, 1) 44 | loss = ( 45 | d_ssim * params.get_alpha() 46 | + l1 * (1 - params.get_alpha()) 47 | ) 48 | loss *= valid_mask 49 | loss *= weights 50 | count = torch.sum(mask, dim=[1, 2, 3], keepdim=True).float() 51 | return torch.mean(torch.sum(loss, dim=[1, 2, 3], keepdim=True) / count) 52 | -------------------------------------------------------------------------------- /supervision/smoothness.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def guided_smoothness_loss(input_duv, guide_duv, mask, weights): 4 | guidance_weights = torch.exp(-guide_duv) 5 | smoothness = input_duv * guidance_weights 6 | smoothness[~mask] = 0.0 7 | smoothness *= weights 8 | return torch.sum(smoothness) / torch.sum(mask) -------------------------------------------------------------------------------- /supervision/splatting.py: -------------------------------------------------------------------------------- 1 | ''' 2 | PyTorch implementation of https://github.com/google/layered-scene-inference 3 | accompanying the paper "Layer-structured 3D Scene Inference via View Synthesis", 4 | ECCV 2018 https://shubhtuls.github.io/lsi/ 5 | ''' 6 | 7 | import torch 8 | 9 | def __splat__(values, coords, splatted): 10 | b, c, h, w = splatted.size() 11 | uvs = coords 12 | u = uvs[:, 0, :, :].unsqueeze(1) 13 | v = uvs[:, 1, :, :].unsqueeze(1) 14 | 15 | u0 = torch.floor(u) 16 | u1 = u0 + 1 17 | v0 = torch.floor(v) 18 | v1 = v0 + 1 19 | 20 | u0_safe = torch.clamp(u0, 0.0, w-1) 21 | v0_safe = torch.clamp(v0, 0.0, h-1) 22 | u1_safe = torch.clamp(u1, 0.0, w-1) 23 | v1_safe = torch.clamp(v1, 0.0, h-1) 24 | 25 | u0_w = (u1 - u) * (u0 == u0_safe).detach().type(values.dtype) 26 | u1_w = (u - u0) * (u1 == u1_safe).detach().type(values.dtype) 27 | v0_w = (v1 - v) * (v0 == v0_safe).detach().type(values.dtype) 28 | v1_w = (v - v0) * (v1 == v1_safe).detach().type(values.dtype) 29 | 30 | top_left_w = u0_w * v0_w 31 | top_right_w = u1_w * v0_w 32 | bottom_left_w = u0_w * v1_w 33 | bottom_right_w = u1_w * v1_w 34 | 35 | weight_threshold = 1e-3 36 | top_left_w *= (top_left_w >= 
weight_threshold).detach().type(values.dtype) 37 | top_right_w *= (top_right_w >= weight_threshold).detach().type(values.dtype) 38 | bottom_left_w *= (bottom_left_w >= weight_threshold).detach().type(values.dtype) 39 | bottom_right_w *= (bottom_right_w >= weight_threshold).detach().type(values.dtype) 40 | 41 | for channel in range(c): 42 | top_left_values = values[:, channel, :, :].unsqueeze(1) * top_left_w 43 | top_right_values = values[:, channel, :, :].unsqueeze(1) * top_right_w 44 | bottom_left_values = values[:, channel, :, :].unsqueeze(1) * bottom_left_w 45 | bottom_right_values = values[:, channel, :, :].unsqueeze(1) * bottom_right_w 46 | 47 | top_left_values = top_left_values.reshape(b, -1) 48 | top_right_values = top_right_values.reshape(b, -1) 49 | bottom_left_values = bottom_left_values.reshape(b, -1) 50 | bottom_right_values = bottom_right_values.reshape(b, -1) 51 | 52 | top_left_indices = (u0_safe + v0_safe * w).reshape(b, -1).type(torch.int64) 53 | top_right_indices = (u1_safe + v0_safe * w).reshape(b, -1).type(torch.int64) 54 | bottom_left_indices = (u0_safe + v1_safe * w).reshape(b, -1).type(torch.int64) 55 | bottom_right_indices = (u1_safe + v1_safe * w).reshape(b, -1).type(torch.int64) 56 | 57 | splatted_channel = splatted[:, channel, :, :].unsqueeze(1) 58 | splatted_channel = splatted_channel.reshape(b, -1) 59 | splatted_channel.scatter_add_(1, top_left_indices, top_left_values) 60 | splatted_channel.scatter_add_(1, top_right_indices, top_right_values) 61 | splatted_channel.scatter_add_(1, bottom_left_indices, bottom_left_values) 62 | splatted_channel.scatter_add_(1, bottom_right_indices, bottom_right_values) 63 | # the scatter_add_ calls above write through per-channel views of splatted, so it is updated in place 64 | 65 | def __weighted_average_splat__(depth, weights, epsilon=1e-8): 66 | zero_weights = (weights <= epsilon).detach().type(depth.dtype) 67 | return depth / (weights + epsilon * zero_weights) 68 | 69 | def __depth_distance_weights__(depth, max_depth=20.0): 70 | weights = 1.0 / torch.exp(2 * depth / max_depth) 71 | return weights 72 | 73 | def render(img, depth, coords, max_depth=20.0): 74 | splatted_img = torch.zeros_like(img) 75 | splatted_wgts = torch.zeros_like(depth) 76 | weights = __depth_distance_weights__(depth, max_depth=max_depth) 77 | __splat__(img * weights, coords, splatted_img) 78 | __splat__(weights, coords, splatted_wgts) 79 | recon = __weighted_average_splat__(splatted_img, splatted_wgts) 80 | mask = (splatted_wgts > 1e-3).detach() 81 | return recon, mask 82 | 83 | def render_to(src, tgt, wgts, depth, coords, max_depth=20.0): 84 | weights = __depth_distance_weights__(depth, max_depth=max_depth) 85 | __splat__(src * weights, coords, tgt) 86 | __splat__(weights, coords, wgts) 87 | tgt.copy_(__weighted_average_splat__(tgt, wgts)) # write back in place so the caller receives the averaged result 88 | mask = (wgts > 1e-3).detach() 89 | return mask -------------------------------------------------------------------------------- /supervision/ssim.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Code modified from https://github.com/Po-Hsun-Su/pytorch-ssim 3 | ''' 4 | 5 | import torch 6 | import numpy 7 | import math 8 | 9 | def __gaussian__(kernel_size, std, data_type=torch.float32): 10 | gaussian = numpy.array([math.exp(-(x - kernel_size//2)**2/float(2*std**2)) for x in range(kernel_size)]) 11 | gaussian /= numpy.sum(gaussian) 12 | return torch.tensor(gaussian, dtype=data_type) 13 | 14 | def __create_kernel__(kernel_size, data_type=torch.float32, channels=3, std=1.5): 15 | gaussian1d = __gaussian__(kernel_size, std).unsqueeze(1) 16 
| gaussian2d = torch.mm(gaussian1d, gaussian1d.t())\ 17 | .type(data_type)\ 18 | .unsqueeze(0)\ 19 | .unsqueeze(0) 20 | window = gaussian2d.expand(channels, 1, kernel_size, kernel_size).contiguous() 21 | return window 22 | 23 | def __ssim_gaussian__(prediction, groundtruth, kernel, kernel_size, channels=3): 24 | padding = kernel_size // 2 25 | prediction_mean = torch.nn.functional.conv2d(prediction, kernel, padding=padding, groups=channels) 26 | groundtruth_mean = torch.nn.functional.conv2d(groundtruth, kernel, padding=padding, groups=channels) 27 | 28 | prediction_mean_squared = prediction_mean.pow(2) 29 | groundtruth_mean_squared = groundtruth_mean.pow(2) 30 | prediction_mean_times_groundtruth_mean = prediction_mean * groundtruth_mean 31 | 32 | prediction_sigma_squared = torch.nn.functional.conv2d(prediction * prediction, kernel, padding=padding, groups=channels)\ 33 | - prediction_mean_squared 34 | groundtruth_sigma_squared = torch.nn.functional.conv2d(groundtruth * groundtruth, kernel, padding=padding, groups=channels)\ 35 | - groundtruth_mean_squared 36 | prediction_groundtruth_covariance = torch.nn.functional.conv2d(prediction * groundtruth, kernel, padding=padding, groups=channels)\ 37 | - prediction_mean_times_groundtruth_mean 38 | 39 | C1 = 0.01**2 # assume that images are in the [0, 1] range 40 | C2 = 0.03**2 # assume that images are in the [0, 1] range 41 | 42 | return ( 43 | ( # numerator 44 | (2 * prediction_mean_times_groundtruth_mean + C1) # luminance term 45 | * (2 * prediction_groundtruth_covariance + C2) # structural term 46 | ) 47 | / # division 48 | ( # denominator 49 | (prediction_mean_squared + groundtruth_mean_squared + C1) # luminance term 50 | * (prediction_sigma_squared + groundtruth_sigma_squared + C2) # structural term 51 | ) 52 | ) 53 | 54 | def ssim_gaussian(prediction, groundtruth, kernel_size=11, std=1.5): 55 | (_, channels, _, _) = prediction.size() 56 | kernel = __create_kernel__(kernel_size, data_type=prediction.type(),\ 57 | channels=channels, std=std) 58 | 59 | if prediction.is_cuda: 60 | kernel = kernel.to(prediction.get_device()) 61 | kernel = kernel.type_as(prediction) 62 | 63 | return __ssim_gaussian__(prediction, groundtruth, kernel, kernel_size, channels) 64 | 65 | def ssim_box(prediction, groundtruth, kernel_size=3): 66 | C1 = 0.01 ** 2 67 | C2 = 0.03 ** 2 68 | 69 | prediction_mean = torch.nn.AvgPool2d(kernel_size, stride=1)(prediction) 70 | groundtruth_mean = torch.nn.AvgPool2d(kernel_size, stride=1)(groundtruth) 71 | prediction_groundtruth_mean = prediction_mean * groundtruth_mean 72 | prediction_mean_squared = prediction_mean.pow(2) 73 | groundtruth_mean_squared = groundtruth_mean.pow(2) 74 | 75 | prediction_sigma = torch.nn.AvgPool2d(kernel_size, stride=1)(prediction * prediction) - prediction_mean_squared 76 | groundtruth_sigma = torch.nn.AvgPool2d(kernel_size, stride=1)(groundtruth * groundtruth) - groundtruth_mean_squared 77 | correlation = torch.nn.AvgPool2d(kernel_size, stride=1)(prediction * groundtruth) - prediction_groundtruth_mean 78 | 79 | numerator = (2 * prediction_groundtruth_mean + C1) * (2 * correlation + C2) 80 | denominator = (prediction_mean_squared + groundtruth_mean_squared + C1)\ 81 | * (prediction_sigma + groundtruth_sigma + C2) 82 | ssim = numerator / denominator 83 | pad = kernel_size // 2 84 | return torch.nn.functional.pad(ssim, (pad, pad, pad, pad)) 85 | 86 | def ssim_loss(prediction, groundtruth, kernel_size=5, std=1.5, mode='gaussian'): 87 | if mode == 'gaussian': 88 | return ssim_gaussian(prediction, groundtruth, 
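# __create_kernel__ above builds the 2D window as the outer product of a normalized 1D
# Gaussian with itself (a separable kernel), expanded into one filter per channel for the
# grouped convolutions used by __ssim_gaussian__. The map returned by both estimators is
# the standard per-pixel SSIM,
#   SSIM(x, y) = ((2*mu_x*mu_y + C1) * (2*sigma_xy + C2))
#              / ((mu_x^2 + mu_y^2 + C1) * (sigma_x^2 + sigma_y^2 + C2)),
# with C1 = 0.01^2 and C2 = 0.03^2 chosen for images in the [0, 1] range.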
kernel_size=kernel_size, std=std) 89 | elif mode == 'box': 90 | return ssim_box(prediction, groundtruth, kernel_size=kernel_size) -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import numpy 5 | 6 | import torch 7 | import torchvision 8 | 9 | import models 10 | import dataset 11 | import utils 12 | from filesystem import file_utils 13 | 14 | import supervision as L 15 | import exporters as IO 16 | import spherical as S360 17 | 18 | def parse_arguments(args): 19 | usage_text = ( 20 | "Semi-supervised Spherical Depth Estimation Testing." 21 | ) 22 | parser = argparse.ArgumentParser(description=usage_text) 23 | # enumerables 24 | parser.add_argument('-b',"--batch_size", type=int, help="Test a number of samples each iteration.") 25 | parser.add_argument('--save_iters', type=int, default=100, help='Maximum test iterations whose results will be saved.') 26 | # paths 27 | parser.add_argument("--test_path", type=str, help="Path to the testing file containing the test set file paths") 28 | parser.add_argument("--save_path", type=str, help="Path to the folder where the models and results will be saved at.") 29 | # model 30 | parser.add_argument("--configuration", required=False, type=str, default='mono', help="Data loader configuration , , , ", choices=['mono', 'lr', 'ud', 'tc']) 31 | parser.add_argument('--weights', type=str, help='Path to the trained weights file.') 32 | parser.add_argument('--model', default="default", type=str, help='Model selection argument.') 33 | # hardware 34 | parser.add_argument('-g','--gpu', type=str, default='0', help='The ids of the GPU(s) that will be utilized. (e.g. 0 or 0,1, or 0,2). Use -1 for CPU.') 35 | # other 36 | parser.add_argument('-n','--name', type=str, default='default_name', help='The name of this train/test. 
Used when storing information.') 37 | parser.add_argument("--visdom", type=str, nargs='?', default=None, const="127.0.0.1", help="Visdom server IP (port defaults to 8097)") 38 | parser.add_argument("--visdom_iters", type=int, default=400, help = "Iteration interval that results will be reported at the visdom server for visualization.") 39 | # metrics 40 | parser.add_argument("--depth_thres", type=float, default=20.0, help = "Depth threshold - depth clipping.") 41 | parser.add_argument("--width", type=float, default=512, help = "Spherical image width.") 42 | parser.add_argument("--baseline", type=float, default=0.26, help = "Stereo baseline distance (in either axis).") 43 | parser.add_argument("--median_scale", required=False, default=False, action="store_true", help = "Perform median scaling before calculating metrics.") 44 | parser.add_argument("--spherical_weights", required=False, default=False, action="store_true", help = "Use spherical weighting when calculating the metrics.") 45 | parser.add_argument("--spherical_sampling", required=False, default=False, action="store_true", help = "Use spherical sampling when calculating the metrics.") 46 | # save options 47 | parser.add_argument("--save_recon", required=False, default=False, action="store_true", help = "Flag to toggle reconstructed result saving.") 48 | parser.add_argument("--save_original", required=False, default=False, action="store_true", help = "Flag to toggle input (image) saving.") 49 | parser.add_argument("--save_depth", required=False, default=False, action="store_true", help = "Flag to toggle output (depth) saving.") 50 | return parser.parse_known_args(args) 51 | 52 | def compute_errors(gt, pred, invalid_mask, weights, sampling, mode='cpu', median_scale=False): 53 | b, _, __, ___ = gt.size() 54 | scale = torch.median(gt.reshape(b, -1), dim=1)[0] / torch.median(pred.reshape(b, -1), dim=1)[0]\ 55 | if median_scale else torch.tensor(1.0).expand(b, 1, 1, 1).to(gt.device) 56 | pred = pred * scale.reshape(b, 1, 1, 1) 57 | valid_sum = torch.sum(~invalid_mask, dim=[1, 2, 3], keepdim=True) 58 | gt[invalid_mask] = 0.0 59 | pred[invalid_mask] = 0.0 60 | thresh = torch.max((gt / pred), (pred / gt)) 61 | thresh[invalid_mask | (sampling < 0.5)] = 2.0 62 | 63 | sum_dims = [1, 2, 3] 64 | delta_valid_sum = torch.sum(~invalid_mask & (sampling > 0), dim=[1, 2, 3], keepdim=True) 65 | delta1 = (thresh < 1.25 ).float().sum(dim=sum_dims, keepdim=True).float() / delta_valid_sum.float() 66 | delta2 = (thresh < (1.25 ** 2)).float().sum(dim=sum_dims, keepdim=True).float() / delta_valid_sum.float() 67 | delta3 = (thresh < (1.25 ** 3)).float().sum(dim=sum_dims, keepdim=True).float() / delta_valid_sum.float() 68 | 69 | rmse = (gt - pred) ** 2 70 | rmse[invalid_mask] = 0.0 71 | rmse_w = rmse * weights 72 | rmse_mean = torch.sqrt(rmse_w.sum(dim=sum_dims, keepdim=True) / valid_sum.float()) 73 | 74 | rmse_log = (torch.log(gt) - torch.log(pred)) ** 2 75 | rmse_log[invalid_mask] = 0.0 76 | rmse_log_w = rmse_log * weights 77 | rmse_log_mean = torch.sqrt(rmse_log_w.sum(dim=sum_dims, keepdim=True) / valid_sum.float()) 78 | 79 | abs_rel = (torch.abs(gt - pred) / gt) 80 | abs_rel[invalid_mask] = 0.0 81 | abs_rel_w = abs_rel * weights 82 | abs_rel_mean = abs_rel_w.sum(dim=sum_dims, keepdim=True) / valid_sum.float() 83 | 84 | sq_rel = (((gt - pred)**2) / gt) 85 | sq_rel[invalid_mask] = 0.0 86 | sq_rel_w = sq_rel * weights 87 | sq_rel_mean = sq_rel_w.sum(dim=sum_dims, keepdim=True) / valid_sum.float() 88 | 89 | return (abs_rel_mean, abs_rel), (sq_rel_mean, sq_rel), 
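# compute_errors follows the usual monocular-depth protocol: delta_k is the fraction of
# valid, sampled pixels whose ratio max(gt / pred, pred / gt) stays below 1.25^k, while
# abs_rel, sq_rel, RMSE and log-RMSE are (optionally spherically weighted) means over the
# valid pixels. A worked single-pixel instance with gt = 2.0 and pred = 1.6 (meters):
#   ratio   = max(2.0 / 1.6, 1.6 / 2.0) = 1.25   -> counts toward delta2 and delta3,
#                                                   but not delta1 (strict < 1.25)
#   abs_rel = |2.0 - 1.6| / 2.0 = 0.2
#   sq_rel  = (2.0 - 1.6)**2 / 2.0 = 0.08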
(rmse_mean, rmse), \ 90 | (rmse_log_mean, rmse_log), delta1, delta2, delta3 91 | 92 | def spiral_sampling(grid, percentage): 93 | b, c, h, w = grid.size() 94 | N = torch.tensor(h*w*percentage).int().float() 95 | sampling = torch.zeros_like(grid)[:, 0, :, :].unsqueeze(1) 96 | phi_k = torch.tensor(0.0).float() 97 | for k in torch.arange(N - 1): 98 | k = k.float() + 1.0 99 | h_k = -1 + 2 * (k - 1) / (N - 1) 100 | theta_k = torch.acos(h_k) 101 | phi_k = phi_k + torch.tensor(3.6).float() / torch.sqrt(N) / torch.sqrt(1 - h_k * h_k) \ 102 | if k > 1.0 else torch.tensor(0.0).float() 103 | phi_k = torch.fmod(phi_k, 2 * numpy.pi) 104 | sampling[:, :, int(theta_k / numpy.pi * h) - 1, int(phi_k / numpy.pi / 2 * w) - 1] += 1.0 105 | return (sampling > 0).float() 106 | 107 | if __name__ == "__main__": 108 | args, unknown = parse_arguments(sys.argv) 109 | gpus = [int(id) for id in args.gpu.split(',') if int(id) >= 0] 110 | # device & visualizers 111 | device = torch.device("cuda:{}" .format(gpus[0])\ 112 | if torch.cuda.is_available() and len(gpus) > 0 and gpus[0] >= 0\ 113 | else "cpu") 114 | plot_visualizer, image_visualizer = (utils.NullVisualizer(), utils.NullVisualizer())\ 115 | if args.visdom is None\ 116 | else ( 117 | utils.VisdomPlotVisualizer(args.name + "_test_plots_", args.visdom), 118 | utils.VisdomImageVisualizer(args.name + "_test_images_", args.visdom,\ 119 | count=2 if 2 <= args.batch_size else args.batch_size) 120 | ) 121 | image_visualizer.update_epoch(0) 122 | # model 123 | model_params = { 'width': 512, 'configuration': args.configuration } 124 | model = models.get_model(args.model, model_params) 125 | utils.init.initialize_weights(model, args.weights, pred_bias=None) 126 | if (len(gpus) > 1): 127 | model = torch.nn.parallel.DataParallel(model, gpus) 128 | model = model.to(device) 129 | # test data 130 | width, height = args.width, args.width // 2 131 | test_data = dataset.dataset_360D.Dataset360D(args.test_path, " ", args.configuration, [height, width]) 132 | test_data_iterator = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size,\ 133 | num_workers=args.batch_size // 4 // (2 if len(gpus) > 0 else 1), pin_memory=False, shuffle=False) 134 | fs = file_utils.Filesystem() 135 | fs.mkdir(args.save_path) 136 | print("Test size : {}".format(args.batch_size * test_data_iterator.__len__())) 137 | # params & error vars 138 | max_save_iters = args.save_iters if args.save_iters > 0\ 139 | else args.batch_size * test_data_iterator.__len__() 140 | errors = numpy.zeros((7, args.batch_size * test_data_iterator.__len__()), numpy.float32) 141 | weights = S360.weights.theta_confidence( 142 | S360.grid.create_spherical_grid(width) 143 | ).to(device) if args.spherical_weights else torch.ones(1, 1, height, width).to(device) 144 | sampling = spiral_sampling(S360.grid.create_image_grid(width, height), 0.25).to(device) \ 145 | if args.spherical_sampling else torch.ones(1, 1, height, width).to(device) 146 | # loop over test set 147 | model.eval() 148 | with torch.no_grad(): 149 | counter = 0 150 | uvgrid = S360.grid.create_image_grid(width, height).to(device) 151 | sgrid = S360.grid.create_spherical_grid(width).to(device) 152 | for test_batch_id , test_batch in enumerate(test_data_iterator): 153 | ''' Data ''' 154 | left_rgb = test_batch['leftRGB'].to(device) 155 | left_depth = test_batch['leftDepth'].to(device) 156 | if 'rightRGB' in test_batch: 157 | right_rgb = test_batch['rightRGB'].to(device) 158 | mask = (left_depth > args.depth_thres) 159 | b, c, h, w = left_rgb.size() 160 | ''' 
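
The spiral_sampling routine above realizes a generalized-spiral point set on the sphere: heights h_k = -1 + 2(k - 1)/(N - 1) step uniformly along the polar axis, theta_k = arccos(h_k), and the azimuth advances by 3.6 / (sqrt(N) * sqrt(1 - h_k^2)) between consecutive points before being binned into the equirectangular grid. A standalone sketch of the same construction (the helper name is hypothetical; the poles are special-cased just as the loop above does for k = 1):

import numpy as np

def spherical_spiral(n_points):
    # returns (theta, phi) pairs that are roughly equidistributed on the unit sphere
    thetas, phis, phi = [], [], 0.0
    for k in range(1, n_points + 1):
        h = -1.0 + 2.0 * (k - 1.0) / (n_points - 1.0)   # uniform steps in z
        theta = np.arccos(h)
        if 1 < k < n_points:                            # avoid 1 - h*h == 0 at the poles
            phi = (phi + 3.6 / np.sqrt(n_points) / np.sqrt(1.0 - h * h)) % (2.0 * np.pi)
        thetas.append(theta)
        phis.append(phi)
    return np.array(thetas), np.array(phis)

# e.g. map to pixel indices via row = theta / pi * H, col = phi / (2 * pi) * W
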
Prediction ''' 161 | left_depth_pred = torch.abs(model(left_rgb)) 162 | ''' Errors ''' 163 | abs_rel_t, sq_rel_t, rmse_t, rmse_log_t, delta1, delta2, delta3\ 164 | = compute_errors(left_depth, left_depth_pred, mask, weights=weights, sampling=sampling, \ 165 | mode='gpu' if torch.cuda.is_available() and len(gpus) > 0 and gpus[0] >= 0 else "cpu", \ 166 | median_scale=args.median_scale) 167 | ''' Visualize & Append Errors ''' 168 | for i in range(b): 169 | idx = counter + i 170 | errors[:, idx] = abs_rel_t[0][i], sq_rel_t[0][i], rmse_t[0][i], \ 171 | rmse_log_t[0][i], delta1[i], delta2[i], delta3[i] 172 | for j in range(7): 173 | plot_visualizer.append_loss(1, idx, torch.tensor(errors[0, idx]), "abs_rel") 174 | plot_visualizer.append_loss(1, idx, torch.tensor(errors[1, idx]), "sq_rel") 175 | plot_visualizer.append_loss(1, idx, torch.tensor(errors[2, idx]), "rmse") 176 | plot_visualizer.append_loss(1, idx, torch.tensor(errors[3, idx]), "rmse_log") 177 | plot_visualizer.append_loss(1, idx, torch.tensor(errors[4, idx]), "delta1") 178 | plot_visualizer.append_loss(1, idx, torch.tensor(errors[5, idx]), "delta2") 179 | plot_visualizer.append_loss(1, idx, torch.tensor(errors[6, idx]), "delta3") 180 | ''' Store ''' 181 | if counter < args.save_iters: 182 | if args.save_original: 183 | IO.image.save_image(os.path.join(args.save_path,\ 184 | str(counter) + "_" + args.name + "_#_left.png"), left_rgb) 185 | if args.save_depth: 186 | IO.image.save_data(os.path.join(args.save_path,\ 187 | str(counter) + "_" + args.name + "_#_depth.exr"), left_depth_pred, scale=1.0) 188 | if args.save_recon: 189 | rads = sgrid.expand(b, -1, -1, -1) 190 | uv = uvgrid.expand(b, -1, -1, -1) 191 | disp = torch.cat( 192 | ( 193 | S360.derivatives.dphi_horizontal(rads, left_depth_pred, args.baseline), 194 | S360.derivatives.dtheta_horizontal(rads, left_depth_pred, args.baseline) 195 | ), dim=1 196 | ) 197 | right_render_coords = uv + disp 198 | right_render_coords[:, 0, :, :] = torch.fmod(right_render_coords[:, 0, :, :] + width, width) 199 | right_render_coords[torch.isnan(right_render_coords)] = 0.0 200 | right_render_coords[torch.isinf(right_render_coords)] = 0.0 201 | right_rgb_t, right_mask_t = L.splatting.render(left_rgb, left_depth_pred, right_render_coords, max_depth=args.depth_thres) 202 | IO.image.save_image(os.path.join(args.save_path,\ 203 | str(counter) + "_" + args.name + "_#_right_t.png"), right_rgb_t) 204 | counter += b 205 | ''' Visualize Predictions ''' 206 | if args.visdom_iters > 0 and (counter + 1) % args.visdom_iters <= args.batch_size: 207 | image_visualizer.show_separate_images(left_rgb, 'input') 208 | if 'rightRGB' in test_batch: 209 | image_visualizer.show_separate_images(right_rgb, 'target') 210 | image_visualizer.show_map(left_depth_pred, 'depth') 211 | if args.save_recon: 212 | image_visualizer.show_separate_images(right_rgb_t, 'recon') 213 | mean_errors = errors.mean(1) 214 | error_names = ['abs_rel','sq_rel','rmse','log_rmse','delta1','delta2','delta3'] 215 | print("Results ({}): ".format(args.name)) 216 | print("\t{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(*error_names)) 217 | print("\t{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".format(*mean_errors)) 218 | 219 | 220 | -------------------------------------------------------------------------------- /train_lr.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import numpy 5 | 6 | import torch 7 | 8 | import models 9 | 
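
With test.py complete above, a minimal invocation sketch using only the flags its parser defines (all paths and the checkpoint name below are hypothetical placeholders):

python test.py -b 4 --test_path ./splits/test_files.txt --save_path ./results \
    --weights ./checkpoints/model_state.pth --configuration mono \
    --spherical_weights --spherical_sampling --save_depth
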
import dataset
10 | import utils
11 | 
12 | import supervision as L
13 | import exporters as IO
14 | import spherical as S360
15 | 
16 | def parse_arguments(args):
17 | usage_text = (
18 | "Omnidirectional Horizontal Stereo Placement (Left-Right, LR) Training."
19 | )
20 | parser = argparse.ArgumentParser(description=usage_text)
21 | # durations
22 | parser.add_argument('-e',"--epochs", type=int, help="Train for a total number of epochs.")
23 | parser.add_argument('-b',"--batch_size", type=int, help="Train with a number of samples each train iteration.")
24 | parser.add_argument("--test_batch_size", default=1, type=int, help="Test with a number of samples each test iteration.")
25 | parser.add_argument('-d','--disp_iters', type=int, default=50, help='Log training progress (i.e. loss etc.) on console every <disp_iters> iterations.')
26 | parser.add_argument('--save_iters', type=int, default=100, help='Maximum test iterations to perform each test run.')
27 | # paths
28 | parser.add_argument("--train_path", type=str, help="Path to the training file containing the train set files paths")
parser.add_argument("--test_path", type=str, help="Path to the testing file containing the test set file paths")
parser.add_argument("--save_path", type=str, help="Path to the folder where the models and results will be saved at.")
29 | # model
30 | parser.add_argument("--configuration", required=False, type=str, default='mono', help="Data loader configuration <mono>, <lr>, <ud>, <tc>", choices=['mono', 'lr', 'ud', 'tc'])
31 | parser.add_argument('--weight_init', type=str, default="xavier", help='Weight initialization method, or path to weights file (for fine-tuning or continuing training)')
32 | parser.add_argument('--model', default="default", type=str, help='Model selection argument.')
33 | # optimization
34 | parser.add_argument('-o','--optimizer', type=str, default="adam", help='The optimizer that will be used during training.')
35 | parser.add_argument("--opt_state", type=str, help="Path to stored optimizer state file to continue training")
36 | parser.add_argument('-l','--lr', type=float, default=0.0002, help='Optimization Learning Rate.')
37 | parser.add_argument('-m','--momentum', type=float, default=0.9, help='Optimization Momentum.')
38 | parser.add_argument('--momentum2', type=float, default=0.999, help='Optimization Second Momentum (optional, only used by some optimizers).')
39 | parser.add_argument('--epsilon', type=float, default=1e-8, help='Optimization Epsilon (optional, only used by some optimizers).')
40 | parser.add_argument('--weight_decay', type=float, default=0, help='Optimization Weight Decay.')
41 | # hardware
42 | parser.add_argument('-g','--gpu', type=str, default='0', help='The ids of the GPU(s) that will be utilized. (e.g. 0 or 0,1, or 0,2). Use -1 for CPU.')
43 | # other
44 | parser.add_argument('-n','--name', type=str, default='default_name', help='The name of this train/test. 
Used when storing information.') 45 | parser.add_argument("--visdom", type=str, nargs='?', default=None, const="127.0.0.1", help="Visdom server IP (port defaults to 8097)") 46 | parser.add_argument("--visdom_iters", type=int, default=400, help = "Iteration interval that results will be reported at the visdom server for visualization.") 47 | parser.add_argument("--seed", type=int, default=1337, help="Fixed manual seed, zero means no seeding.") 48 | # network specific params 49 | parser.add_argument("--photo_w", type=float, default=1.0, help = "Photometric loss weight.") 50 | parser.add_argument("--smooth_reg_w", type=float, default=0.1, help = "Smoothness regularization weight.") 51 | parser.add_argument("--ssim_window", type=int, default=7, help = "Kernel size to use in SSIM calculation.") 52 | parser.add_argument("--ssim_mode", type=str, default='gaussian', help = "Type of SSIM averaging (either gaussian or box).") 53 | parser.add_argument("--ssim_std", type=float, default=1.5, help = "SSIM standard deviation value used when creating the gaussian averaging kernels.") 54 | parser.add_argument("--ssim_alpha", type=float, default=0.85, help = "Alpha factor to weight the SSIM and L1 losses, where a x SSIM and (1 - a) x L1.") 55 | parser.add_argument("--pred_bias", type=float, default=5.0, help = "Initialize prediction layers' bias to the given value (helps convergence).") 56 | # details 57 | parser.add_argument("--depth_thres", type=float, default=20.0, help = "Depth threshold - depth clipping.") 58 | parser.add_argument("--baseline", type=float, default=0.26, help = "Stereo baseline distance (in either axis).") 59 | parser.add_argument("--width", type=float, default=512, help = "Spherical image width.") 60 | return parser.parse_known_args(args) 61 | 62 | if __name__ == "__main__": 63 | args, unknown = parse_arguments(sys.argv) 64 | gpus = [int(id) for id in args.gpu.split(',') if int(id) >= 0] 65 | # device & visualizers 66 | device, visualizers, model_params = utils.initialize(args) 67 | plot_viz = visualizers[0] 68 | img_viz = visualizers[1] 69 | # model 70 | model = models.get_model(args.model, model_params) 71 | utils.init.initialize_weights(model, args.weight_init, pred_bias=args.pred_bias) 72 | if (len(gpus) > 1): 73 | model = torch.nn.parallel.DataParallel(model, gpus) 74 | model = model.to(device) 75 | # optimizer 76 | optimizer = utils.init_optimizer(model, args) 77 | # train data 78 | train_data = dataset.dataset_360D.Dataset360D(args.train_path, " ", args.configuration, [256, 512]) 79 | train_data_iterator = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,\ 80 | num_workers=args.batch_size // len(gpus) // len(gpus), pin_memory=False, shuffle=True) 81 | # test data 82 | test_data = dataset.dataset_360D.Dataset360D(args.test_path, " ", args.configuration, [256, 512]) 83 | test_data_iterator = torch.utils.data.DataLoader(test_data, batch_size=args.test_batch_size,\ 84 | num_workers=args.batch_size // len(gpus) // len(gpus), pin_memory=False, shuffle=True) 85 | print("Data size : {0} | Test size : {1}".format(\ 86 | args.batch_size * train_data_iterator.__len__(), \ 87 | args.test_batch_size * test_data_iterator.__len__())) 88 | # params 89 | width = args.width 90 | height = args.width // 2 91 | photo_params = L.photometric.PhotometricLossParameters( 92 | alpha=args.ssim_alpha, l1_estimator='none', ssim_estimator='none', 93 | ssim_mode=args.ssim_mode, std=args.ssim_std, window=args.ssim_window 94 | ) 95 | iteration_counter = 0 96 | # meters 97 | total_loss = 
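
The PhotometricLossParameters bundle above carries the SSIM/L1 trade-off that the ssim_alpha help text states as a x SSIM and (1 - a) x L1. supervision/photometric.py is not reproduced in this listing, so the following is only a sketch of that stated blend under common conventions (masking and the spherical attention weights are omitted, and both the clamped-dissimilarity form and the L.ssim module path are assumptions):

dssim = torch.clamp((1.0 - L.ssim.ssim_loss(pred, target, kernel_size=args.ssim_window,
    std=args.ssim_std, mode=args.ssim_mode)) / 2.0, 0.0, 1.0)
photo = args.ssim_alpha * dssim.mean() + (1.0 - args.ssim_alpha) * torch.abs(pred - target).mean()
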
utils.AverageMeter() 98 | running_photo_loss = utils.AverageMeter() 99 | running_depth_smooth_loss = utils.AverageMeter() 100 | # train / test loop 101 | model.train() 102 | plot_viz.config(**vars(args)) 103 | for epoch in range(args.epochs): 104 | print("Training | Epoch: {}".format(epoch)) 105 | img_viz.update_epoch(epoch) 106 | for batch_id, batch in enumerate(train_data_iterator): 107 | optimizer.zero_grad() 108 | active_loss = torch.tensor(0.0).to(device) 109 | ''' Data ''' 110 | left_rgb = batch['leftRGB'].to(device) 111 | b, _, __, ___ = left_rgb.size() 112 | expand_size = (b, -1, -1, -1) 113 | sgrid = S360.grid.create_spherical_grid(width).to(device) 114 | uvgrid = S360.grid.create_image_grid(width, height).to(device) 115 | right_rgb = batch['rightRGB'].to(device) 116 | left_depth = batch['leftDepth'].to(device) 117 | right_depth = batch['rightDepth'].to(device) 118 | ''' Prediction ''' 119 | left_depth_pred = torch.abs(model(left_rgb)) 120 | ''' Forward Rendering LR ''' 121 | disp = torch.cat( 122 | ( 123 | S360.derivatives.dphi_horizontal(sgrid, left_depth_pred, args.baseline), 124 | S360.derivatives.dtheta_horizontal(sgrid, left_depth_pred, args.baseline) 125 | ), 126 | dim=1 127 | ) 128 | right_render_coords = uvgrid + disp 129 | right_render_coords[:, 0, :, :] = torch.fmod(right_render_coords[:, 0, :, :] + width, width) 130 | right_render_coords[torch.isnan(right_render_coords)] = 0.0 131 | right_render_coords[torch.isinf(right_render_coords)] = 0.0 132 | right_rgb_t, right_mask_t = L.splatting.render(left_rgb, left_depth_pred,\ 133 | right_render_coords, max_depth=args.depth_thres) 134 | ''' Loss LR ''' 135 | right_cutoff_mask = (right_depth < args.depth_thres) 136 | right_mask_t &= ~(right_depth > args.depth_thres) 137 | attention_weights = S360.weights.phi_confidence( 138 | S360.grid.create_spherical_grid(width)).to(device) 139 | # attention_weights = S360.weights.spherical_confidence( 140 | # S360.grid.create_spherical_grid(width), zero_low=0.001 141 | # ).to(device) 142 | # attention_weights = torch.ones_like(left_depth) 143 | photo_loss = L.photometric.calculate_loss(right_rgb_t, right_rgb, photo_params, 144 | mask=right_cutoff_mask, weights=attention_weights) 145 | active_loss += photo_loss * args.photo_w 146 | ''' Loss Prior (3D Smoothness) ''' 147 | left_xyz = S360.cartesian.coords_3d(sgrid, left_depth_pred) 148 | dI_dxyz = S360.derivatives.dV_dxyz(left_xyz) 149 | guidance_duv = S360.derivatives.dI_duv(left_rgb) 150 | # attention_weights = torch.zeros_like(left_depth) 151 | depth_smooth_loss = L.smoothness.guided_smoothness_loss( 152 | dI_dxyz, guidance_duv, right_cutoff_mask, (1.0 - attention_weights) 153 | * right_cutoff_mask.type(attention_weights.dtype) 154 | ) 155 | active_loss += depth_smooth_loss * args.smooth_reg_w 156 | ''' Update Params ''' 157 | active_loss.backward() 158 | optimizer.step() 159 | ''' Visualize''' 160 | total_loss.update(active_loss) 161 | running_depth_smooth_loss.update(depth_smooth_loss) 162 | running_photo_loss.update(photo_loss) 163 | iteration_counter += b 164 | if (iteration_counter + 1) % args.disp_iters <= args.batch_size: 165 | print("Epoch: {}, iteration: {}\nPhotometric: {}\nSmoothness: {}\nTotal average loss: {}\n"\ 166 | .format(epoch, iteration_counter, running_photo_loss.avg, \ 167 | running_depth_smooth_loss.avg, total_loss.avg)) 168 | plot_viz.append_loss(epoch + 1, iteration_counter, total_loss.avg, "avg") 169 | plot_viz.append_loss(epoch + 1, iteration_counter, running_photo_loss.avg, "photo") 170 | 
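
One detail of the warp above that is easy to miss: the equirectangular u axis is periodic, so horizontal target coordinates are wrapped with torch.fmod(u + width, width) before splatting. The + width shift matters because fmod keeps the dividend's sign; a quick check:

# torch.fmod(torch.tensor(-3.0), 512.0)          ->  -3.0  (invalid column)
# torch.fmod(torch.tensor(-3.0) + 512.0, 512.0)  -> 509.0  (correctly wrapped column)
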
plot_viz.append_loss(epoch + 1, iteration_counter, running_depth_smooth_loss.avg, "smooth")
171 | total_loss.reset()
172 | running_photo_loss.reset()
173 | running_depth_smooth_loss.reset()
174 | if args.visdom_iters > 0 and (iteration_counter + 1) % args.visdom_iters <= args.batch_size:
175 | img_viz.show_separate_images(left_rgb, 'input')
176 | img_viz.show_separate_images(right_rgb, 'target')
177 | img_viz.show_map(left_depth_pred, 'depth')
178 | img_viz.show_separate_images(torch.clamp(right_rgb_t, min=0.0, max=1.0), 'recon')
179 | ''' Save '''
180 | print("Saving model @ epoch #" + str(epoch))
181 | utils.checkpoint.save_network_state(model, optimizer, epoch,\
182 | args.name + "_model_state", args.save_path)
183 | ''' Test '''
184 | print("Testing model @ epoch #" + str(epoch))
185 | model.eval()
186 | with torch.no_grad():
187 | rmse_avg = torch.tensor(0.0).float()
188 | counter = torch.tensor(0.0).float()
189 | for test_batch_id , test_batch in enumerate(test_data_iterator):
190 | left_rgb = test_batch['leftRGB'].to(device)
191 | b, c, h, w = left_rgb.size()
192 | rads = sgrid.expand(b, -1, -1, -1)
193 | uv = uvgrid.expand(b, -1, -1, -1)
194 | left_depth_pred = torch.abs(model(left_rgb))
195 | left_depth = test_batch['leftDepth'].to(device)
196 | left_depth[torch.isnan(left_depth)] = 50.0
197 | left_depth[torch.isinf(left_depth)] = 50.0
198 | mse = (left_depth_pred - left_depth) ** 2 # per-pixel squared error
199 | mse[torch.isnan(mse)] = 0.0
200 | mse[torch.isinf(mse)] = 0.0
201 | mask = (left_depth < args.depth_thres).float()
202 | if torch.sum(mask) == 0:
203 | continue
204 | rmse = torch.sqrt(torch.sum(mse * mask) / torch.sum(mask).float())
205 | if not torch.isnan(rmse):
206 | rmse_avg += rmse.cpu().float()
207 | counter += torch.tensor(b).float()
208 | if counter < args.save_iters:
209 | disp = torch.cat(
210 | (
211 | S360.derivatives.dphi_horizontal(rads, left_depth_pred, args.baseline),
212 | S360.derivatives.dtheta_horizontal(rads, left_depth_pred, args.baseline)
213 | ), dim=1
214 | )
215 | right_render_coords = uv + disp
216 | right_render_coords[:, 0, :, :] = torch.fmod(right_render_coords[:, 0, :, :] + width, width)
217 | right_render_coords[torch.isnan(right_render_coords)] = 0.0
218 | right_render_coords[torch.isinf(right_render_coords)] = 0.0
219 | right_rgb_t, right_mask_t = L.splatting.render(left_rgb, left_depth_pred, right_render_coords, max_depth=args.depth_thres)
220 | IO.image.save_image(os.path.join(args.save_path,\
221 | str(epoch) + "_" + str(counter) + "_#_left.png"), left_rgb)
222 | IO.image.save_image(os.path.join(args.save_path,\
223 | str(epoch) + "_" + str(counter) + "_#_right_t.png"), right_rgb_t)
224 | IO.image.save_data(os.path.join(args.save_path,\
225 | str(epoch) + "_" + str(counter) + "_#_depth.exr"), left_depth_pred, scale=1.0)
226 | rmse_avg /= counter
227 | print("Testing epoch {}: RMSE = {}".format(epoch+1, rmse_avg))
228 | plot_viz.append_loss(epoch + 1, epoch + 1, rmse_avg, "rmse", mode='test')
229 | torch.enable_grad()
230 | model.train()
231 | 
--------------------------------------------------------------------------------
/train_sv.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import numpy
5 | 
6 | import torch
7 | import torchvision
8 | 
9 | import models
10 | import dataset
11 | import utils
12 | 
13 | import supervision as L
14 | import exporters as IO
15 | import spherical as S360
16 | 
17 | def parse_arguments(args):
18 | usage_text = (
19 | "Omnidirectional 
Supervised (SV) Training." 20 | ) 21 | parser = argparse.ArgumentParser(description=usage_text) 22 | # durations 23 | parser.add_argument('-e',"--epochs", type=int, help="Train for a total number of epochs.") 24 | parser.add_argument('-b',"--batch_size", type=int, help="Train with a number of samples each train iteration.") 25 | parser.add_argument("--test_batch_size", default=1, type=int, help="Test with a number of samples each test iteration.") 26 | parser.add_argument('-d','--disp_iters', type=int, default=50, help='Log training progress (i.e. loss etc.) on console every iterations.') 27 | parser.add_argument('--save_iters', type=int, default=100, help='Maximum test iterations to perform each test run.') 28 | # paths 29 | parser.add_argument("--train_path", type=str, help="Path to the training file containing the train set files paths") 30 | parser.add_argument("--test_path", type=str, help="Path to the testing file containing the test set file paths") 31 | parser.add_argument("--save_path", type=str, help="Path to the folder where the models and results will be saved at.") 32 | # model 33 | parser.add_argument("--configuration", required=False, type=str, default='mono', help="Data loader configuration , , , ", choices=['mono', 'lr', 'ud', 'tc']) 34 | parser.add_argument('--weight_init', type=str, default="xavier", help='Weight initialization method, or path to weights file (for fine-tuning or continuing training)') 35 | parser.add_argument('--model', default="default", type=str, help='Model selection argument.') 36 | # optimization 37 | parser.add_argument('-o','--optimizer', type=str, default="adam", help='The optimizer that will be used during training.') 38 | parser.add_argument("--opt_state", type=str, help="Path to stored optimizer state file to continue training)") 39 | parser.add_argument('-l','--lr', type=float, default=0.0002, help='Optimization Learning Rate.') 40 | parser.add_argument('-m','--momentum', type=float, default=0.9, help='Optimization Momentum.') 41 | parser.add_argument('--momentum2', type=float, default=0.999, help='Optimization Second Momentum (optional, only used by some optimizers).') 42 | parser.add_argument('--epsilon', type=float, default=1e-8, help='Optimization Epsilon (optional, only used by some optimizers).') 43 | parser.add_argument('--weight_decay', type=float, default=0, help='Optimization Weight Decay.') 44 | # hardware 45 | parser.add_argument('-g','--gpu', type=str, default='0', help='The ids of the GPU(s) that will be utilized. (e.g. 0 or 0,1, or 0,2). Use -1 for CPU.') 46 | # other 47 | parser.add_argument('-n','--name', type=str, default='default_name', help='The name of this train/test. 
Used when storing information.') 48 | parser.add_argument("--visdom", type=str, nargs='?', default=None, const="127.0.0.1", help="Visdom server IP (port defaults to 8097)") 49 | parser.add_argument("--visdom_iters", type=int, default=400, help = "Iteration interval that results will be reported at the visdom server for visualization.") 50 | parser.add_argument("--seed", type=int, default=1337, help="Fixed manual seed, zero means no seeding.") 51 | # network specific params 52 | parser.add_argument("--depth_w", type=float, default=1.0, help = "Photometric loss weight.") 53 | parser.add_argument("--smooth_reg_w", type=float, default=0.1, help = "Smoothness regularization weight.") 54 | parser.add_argument("--ssim_window", type=int, default=7, help = "Kernel size to use in SSIM calculation.") 55 | parser.add_argument("--ssim_mode", type=str, default='gaussian', help = "Type of SSIM averaging (either gaussian or box).") 56 | parser.add_argument("--ssim_std", type=float, default=1.5, help = "SSIM standard deviation value used when creating the gaussian averaging kernels.") 57 | parser.add_argument("--ssim_alpha", type=float, default=0.85, help = "Alpha factor to weight the SSIM and L1 losses, where a x SSIM and (1 - a) x L1.") 58 | parser.add_argument("--pred_bias", type=float, default=5.0, help = "Initialize prediction layers' bias to the given value (helps convergence).") 59 | # details 60 | parser.add_argument("--depth_thres", type=float, default=20.0, help = "Depth threshold - depth clipping.") 61 | parser.add_argument("--baseline", type=float, default=0.26, help = "Stereo baseline distance (in either axis).") 62 | parser.add_argument("--width", type=float, default=512, help = "Spherical image width.") 63 | return parser.parse_known_args(args) 64 | 65 | if __name__ == "__main__": 66 | args, unknown = parse_arguments(sys.argv) 67 | gpus = [int(id) for id in args.gpu.split(',') if int(id) >= 0] 68 | # device & visualizers 69 | device, visualizers, model_params = utils.initialize(args) 70 | plot_viz = visualizers[0] 71 | img_viz = visualizers[1] 72 | # model 73 | model = models.get_model(args.model, model_params) 74 | utils.init.initialize_weights(model, args.weight_init, pred_bias=args.pred_bias) 75 | if (len(gpus) > 1): 76 | model = torch.nn.parallel.DataParallel(model, gpus) 77 | model = model.to(device) 78 | # optimizer 79 | optimizer = utils.init_optimizer(model, args) 80 | # train data 81 | train_data = dataset.dataset_360D.Dataset360D(args.train_path, " ", args.configuration, [256, 512]) 82 | train_data_iterator = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,\ 83 | num_workers=args.batch_size // len(gpus), pin_memory=True, shuffle=True) 84 | # test data 85 | test_data = dataset.dataset_360D.Dataset360D(args.test_path, " ", args.configuration, [256, 512]) 86 | test_data_iterator = torch.utils.data.DataLoader(test_data, batch_size=args.test_batch_size,\ 87 | num_workers=args.batch_size // len(gpus), pin_memory=True, shuffle=True) 88 | print("Data size : {0} | Test size : {1}".format(\ 89 | args.batch_size * train_data_iterator.__len__(), \ 90 | args.test_batch_size * test_data_iterator.__len__())) 91 | # params 92 | width = args.width 93 | height = args.width // 2 94 | iteration_counter = 0 95 | # meters 96 | total_loss = utils.AverageMeter() 97 | running_depth_loss = utils.AverageMeter() 98 | running_depth_smooth_loss = utils.AverageMeter() 99 | # train / test loop 100 | model.train() 101 | plot_viz.config(**vars(args)) 102 | for epoch in range(args.epochs): 103 | 
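
The epoch loop below supervises predictions directly with L.direct.calculate_berhu_loss. supervision/direct.py is not part of this listing, so the snippet here is only the textbook reverse-Huber (BerHu) form such a loss usually takes, with the customary adaptive cutoff; the function is a hypothetical reference, not the repository implementation:

import torch

def berhu_reference(pred, gt, mask):
    # L1 near zero, scaled L2 beyond an adaptive cutoff c (reverse Huber)
    err = torch.abs(pred - gt)[mask]
    c = 0.2 * err.max()                         # common choice: 20% of the max residual
    l2_branch = (err ** 2 + c ** 2) / (2.0 * c)
    return torch.where(err <= c, err, l2_branch).mean()
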
print("Training | Epoch: {}".format(epoch)) 104 | img_viz.update_epoch(epoch) 105 | for batch_id, batch in enumerate(train_data_iterator): 106 | optimizer.zero_grad() 107 | active_loss = torch.tensor(0.0).to(device) 108 | ''' Data ''' 109 | left_rgb = batch['leftRGB'].to(device) 110 | b, _, __, ___ = left_rgb.size() 111 | left_depth = batch['leftDepth'].to(device) 112 | ''' Prediction ''' 113 | left_depth_pred = torch.abs(model(left_rgb)) 114 | ''' Berhu Loss ''' 115 | left_cutoff_mask = (left_depth < args.depth_thres) 116 | attention_weights = S360.weights.theta_confidence( 117 | S360.grid.create_spherical_grid(width)).to(device) 118 | # attention_weights = torch.ones_like(left_depth) 119 | depth_loss = L.direct.calculate_berhu_loss(left_depth_pred, left_depth, 120 | mask=left_cutoff_mask, weights=attention_weights) 121 | active_loss += depth_loss * args.depth_w 122 | ''' Loss Prior (3D Smoothness) ''' 123 | left_xyz = S360.cartesian.coords_3d( 124 | S360.grid.create_spherical_grid(width).to(device), left_depth_pred) 125 | dI_dxyz = S360.derivatives.dV_dxyz(left_xyz) 126 | guidance_duv = S360.derivatives.dI_duv(left_rgb) 127 | # attention_weights = torch.zeros_like(left_depth) 128 | depth_smooth_loss = L.smoothness.guided_smoothness_loss( 129 | dI_dxyz, guidance_duv, left_cutoff_mask, (1.0 - attention_weights) 130 | * left_cutoff_mask.type(attention_weights.dtype) 131 | ) 132 | active_loss += depth_smooth_loss * args.smooth_reg_w 133 | ''' Update Params ''' 134 | active_loss.backward() 135 | optimizer.step() 136 | ''' Visualize''' 137 | total_loss.update(active_loss) 138 | running_depth_smooth_loss.update(depth_smooth_loss) 139 | running_depth_loss.update(depth_loss) 140 | iteration_counter += b 141 | if (iteration_counter + 1) % args.disp_iters <= args.batch_size: 142 | print("Epoch: {}, iteration: {}\nBerhu: {}\nSmoothness: {}\nTotal average loss: {}\n"\ 143 | .format(epoch, iteration_counter, running_depth_loss.avg, \ 144 | running_depth_smooth_loss.avg, total_loss.avg)) 145 | plot_viz.append_loss(epoch + 1, iteration_counter, total_loss.avg, "avg") 146 | plot_viz.append_loss(epoch + 1, iteration_counter, running_depth_loss.avg, "berhu") 147 | plot_viz.append_loss(epoch + 1, iteration_counter, running_depth_smooth_loss.avg, "smooth") 148 | total_loss.reset() 149 | running_depth_loss.reset() 150 | running_depth_smooth_loss.reset() 151 | if args.visdom_iters > 0 and (iteration_counter + 1) % args.visdom_iters <= args.batch_size: 152 | img_viz.show_separate_images(left_rgb, 'input') 153 | img_viz.show_map(left_depth * left_cutoff_mask.float(), 'target') 154 | img_viz.show_map(left_depth_pred, 'depth') 155 | ''' Save ''' 156 | print("Saving model @ epoch #" + str(epoch)) 157 | utils.checkpoint.save_network_state(model, optimizer, epoch,\ 158 | args.name + "_model_state", args.save_path) 159 | ''' Test ''' 160 | print("Testing model @ epoch #" + str(epoch)) 161 | model.eval() 162 | with torch.no_grad(): 163 | rmse_avg = torch.tensor(0.0).float() 164 | counter = torch.tensor(0.0).float() 165 | for test_batch_id , test_batch in enumerate(test_data_iterator): 166 | left_rgb = test_batch['leftRGB'].to(device) 167 | b, c, h, w = left_rgb.size() 168 | left_depth_pred = torch.abs(model(left_rgb)) 169 | left_depth = test_batch['leftDepth'].to(device) 170 | left_depth[torch.isnan(left_depth)] = 50.0 171 | left_depth[torch.isinf(left_depth)] = 50.0 172 | mse = (left_depth_pred ** 2) - (left_depth ** 2) 173 | mse[torch.isnan(mse)] = 0.0 174 | mse[torch.isinf(mse)] = 0.0 175 | mask = (left_depth < 
args.depth_thres).float() 176 | if torch.sum(mask) == 0: 177 | continue 178 | rmse = torch.sqrt(torch.sum(mse * mask) / torch.sum(mask)) 179 | if not torch.isnan(rmse): 180 | rmse_avg += rmse.cpu().float() 181 | counter += torch.tensor(b).float() 182 | if counter < args.save_iters: 183 | IO.image.save_image(os.path.join(args.save_path,\ 184 | str(epoch) + "_" + str(counter) + "_#_left.png"), left_rgb) 185 | IO.image.save_data(os.path.join(args.save_path,\ 186 | str(epoch) + "_" + str(counter) + "_#_depth.exr"), left_depth_pred, scale=1.0) 187 | rmse_avg /= counter 188 | print("Testing epoch {}: RMSE = {}".format(epoch+1, rmse_avg)) 189 | plot_viz.append_loss(epoch + 1, iteration_counter, rmse_avg, "rmse") 190 | torch.enable_grad() 191 | model.train() 192 | -------------------------------------------------------------------------------- /train_tc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import numpy 5 | 6 | import torch 7 | 8 | import models 9 | import dataset 10 | import utils 11 | 12 | import supervision as L 13 | import exporters as IO 14 | import spherical as S360 15 | 16 | def parse_arguments(args): 17 | usage_text = ( 18 | "Omnidirectional Trinocular Stereo Placement (Up-Down & Left-Right , UD+LR) Training" 19 | ) 20 | parser = argparse.ArgumentParser(description=usage_text) 21 | # durations 22 | parser.add_argument('-e',"--epochs", type=int, help="Train for a total number of epochs.") 23 | parser.add_argument('-b',"--batch_size", type=int, help="Train with a number of samples each train iteration.") 24 | parser.add_argument("--test_batch_size", default=1, type=int, help="Test with a number of samples each test iteration.") 25 | parser.add_argument('-d','--disp_iters', type=int, default=50, help='Log training progress (i.e. loss etc.) 
on console every <disp_iters> iterations.')
26 | parser.add_argument('--save_iters', type=int, default=100, help='Maximum test iterations to perform each test run.')
27 | # paths
28 | parser.add_argument("--train_path", type=str, help="Path to the training file containing the train set files paths")
29 | parser.add_argument("--test_path", type=str, help="Path to the testing file containing the test set file paths")
30 | parser.add_argument("--save_path", type=str, help="Path to the folder where the models and results will be saved at.")
31 | # model
32 | parser.add_argument("--configuration", required=False, type=str, default='tc', help="Training configuration <mono>, <lr>, <ud>, <tc>", choices=['mono', 'lr', 'ud', 'tc'])
33 | parser.add_argument('--weight_init', type=str, default="xavier", help='Weight initialization method, or path to weights file (for fine-tuning or continuing training)')
34 | parser.add_argument('--model', default="default", type=str, help='Model selection argument.')
35 | # optimization
36 | parser.add_argument('-o','--optimizer', type=str, default="adam", help='The optimizer that will be used during training.')
37 | parser.add_argument("--opt_state", type=str, help="Path to stored optimizer state file to continue training")
38 | parser.add_argument('-l','--lr', type=float, default=0.0002, help='Optimization Learning Rate.')
39 | parser.add_argument('-m','--momentum', type=float, default=0.9, help='Optimization Momentum.')
40 | parser.add_argument('--momentum2', type=float, default=0.999, help='Optimization Second Momentum (optional, only used by some optimizers).')
41 | parser.add_argument('--epsilon', type=float, default=1e-8, help='Optimization Epsilon (optional, only used by some optimizers).')
42 | parser.add_argument('--weight_decay', type=float, default=0, help='Optimization Weight Decay.')
43 | # hardware
44 | parser.add_argument('-g','--gpu', type=str, default='0', help='The ids of the GPU(s) that will be utilized. (e.g. 0 or 0,1, or 0,2). Use -1 for CPU.')
45 | # other
46 | parser.add_argument('-n','--name', type=str, default='default_name', help='The name of this train/test. 
Used when storing information.') 47 | parser.add_argument("--visdom", type=str, nargs='?', default=None, const="127.0.0.1", help="Visdom server IP (port defaults to 8097)") 48 | parser.add_argument("--visdom_iters", type=int, default=400, help = "Iteration interval that results will be reported at the visdom server for visualization.") 49 | parser.add_argument("--seed", type=int, default=1337, help="Fixed manual seed, zero means no seeding.") 50 | # network specific params 51 | parser.add_argument("--photo_w", type=float, default=1.0, help = "Photometric loss weight.") 52 | parser.add_argument("--photo_ratio", type=float, default=0.5, help = "Ratio between right (1-ratio) and up (ratio) photometric loss.") 53 | parser.add_argument("--smooth_reg_w", type=float, default=0.1, help = "Smoothness regularization weight.") 54 | parser.add_argument("--ssim_window", type=int, default=7, help = "Kernel size to use in SSIM calculation.") 55 | parser.add_argument("--ssim_mode", type=str, default='gaussian', help = "Type of SSIM averaging (either gaussian or box).") 56 | parser.add_argument("--ssim_std", type=float, default=1.5, help = "SSIM standard deviation value used when creating the gaussian averaging kernels.") 57 | parser.add_argument("--ssim_alpha", type=float, default=0.85, help = "Alpha factor to weight the SSIM and L1 losses, where a x SSIM and (1 - a) x L1.") 58 | parser.add_argument("--pred_bias", type=float, default=5.0, help = "Initialize prediction layers' bias to the given value (helps convergence).") 59 | # details 60 | parser.add_argument("--depth_thres", type=float, default=20.0, help = "Depth threshold - depth clipping.") 61 | parser.add_argument("--baseline", type=float, default=0.26, help = "Stereo baseline distance (in either axis).") 62 | parser.add_argument("--width", type=float, default=512, help = "Spherical image width.") 63 | return parser.parse_known_args(args) 64 | 65 | if __name__ == "__main__": 66 | args, unknown = parse_arguments(sys.argv) 67 | gpus = [int(id) for id in args.gpu.split(',') if int(id) >= 0] 68 | # device & visualizers 69 | device, visualizers, model_params = utils.initialize(args) 70 | plot_viz = visualizers[0] 71 | img_viz = visualizers[1] 72 | # model 73 | model = models.get_model(args.model, model_params) 74 | utils.init.initialize_weights(model, args.weight_init, pred_bias=args.pred_bias) 75 | if (len(gpus) > 1): 76 | model = torch.nn.parallel.DataParallel(model, gpus) 77 | model = model.to(device) 78 | # optimizer 79 | optimizer = utils.init_optimizer(model, args) 80 | # train data 81 | train_data = dataset.dataset_360D.Dataset360D(args.train_path, " ", args.configuration, [256, 512]) 82 | train_data_iterator = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,\ 83 | num_workers=args.batch_size // len(gpus) // 4, pin_memory=False, shuffle=True) 84 | # test data 85 | test_data = dataset.dataset_360D.Dataset360D(args.test_path, " ", args.configuration, [256, 512]) 86 | test_data_iterator = torch.utils.data.DataLoader(test_data, batch_size=args.test_batch_size,\ 87 | num_workers=args.batch_size // len(gpus) // 4, pin_memory=False, shuffle=True) 88 | print("Data size : {0} | Test size : {1}".format(\ 89 | args.batch_size * train_data_iterator.__len__(), \ 90 | args.test_batch_size * test_data_iterator.__len__())) 91 | # params 92 | width = args.width 93 | height = args.width // 2 94 | photo_params = L.photometric.PhotometricLossParameters( 95 | alpha=args.ssim_alpha, l1_estimator='none', ssim_estimator='none', 96 | 
ssim_mode=args.ssim_mode, std=args.ssim_std, window=args.ssim_window 97 | ) 98 | iteration_counter = 0 99 | # meters 100 | total_loss = utils.AverageMeter() 101 | running_photo_loss_lr = utils.AverageMeter() 102 | running_photo_loss_ud = utils.AverageMeter() 103 | running_depth_smooth_loss = utils.AverageMeter() 104 | # train / test loop 105 | model.train() 106 | plot_viz.config(**vars(args)) 107 | for epoch in range(args.epochs): 108 | print("Training | Epoch: {}".format(epoch)) 109 | img_viz.update_epoch(epoch) 110 | for batch_id, batch in enumerate(train_data_iterator): 111 | optimizer.zero_grad() 112 | active_loss = torch.tensor(0.0).to(device) 113 | ''' Data ''' 114 | left_rgb = batch['leftRGB'].to(device) 115 | b, _, __, ___ = left_rgb.size() 116 | expand_size = (b, -1, -1, -1) 117 | sgrid = S360.grid.create_spherical_grid(width).to(device) 118 | uvgrid = S360.grid.create_image_grid(width, height).to(device) 119 | right_rgb = batch['rightRGB'].to(device) 120 | up_rgb = batch['upRGB'].to(device) 121 | left_depth = batch['leftDepth'].to(device) 122 | up_depth = batch['upDepth'].to(device) 123 | right_depth = batch['rightDepth'].to(device) 124 | ''' Prediction ''' 125 | left_depth_pred = torch.abs(model(left_rgb)) 126 | ''' Forward Rendering LR ''' 127 | disp_lr = torch.cat( 128 | ( 129 | S360.derivatives.dphi_horizontal(sgrid, left_depth_pred, args.baseline), 130 | S360.derivatives.dtheta_horizontal(sgrid, left_depth_pred, args.baseline) 131 | ), 132 | dim=1 133 | ) 134 | right_render_coords = uvgrid + disp_lr 135 | right_render_coords[:, 0, :, :] = torch.fmod(right_render_coords[:, 0, :, :] + width, width) 136 | right_render_coords[torch.isnan(right_render_coords)] = 0.0 137 | right_render_coords[torch.isinf(right_render_coords)] = 0.0 138 | right_rgb_t, right_mask_t = L.splatting.render(left_rgb, left_depth_pred,\ 139 | right_render_coords, max_depth=args.depth_thres) 140 | ''' Forward Rendering UD ''' 141 | disp_ud = torch.cat( 142 | ( 143 | torch.zeros_like(left_depth_pred), 144 | S360.derivatives.dtheta_vertical(sgrid, left_depth_pred, args.baseline) 145 | ), 146 | dim=1 147 | ) 148 | up_render_coords = uvgrid + disp_ud 149 | up_render_coords[torch.isnan(up_render_coords)] = 0.0 150 | up_render_coords[torch.isinf(up_render_coords)] = 0.0 151 | up_rgb_t, up_mask_t = L.splatting.render(left_rgb, left_depth_pred,\ 152 | up_render_coords, max_depth=args.depth_thres) 153 | ''' Loss LR ''' 154 | right_cutoff_mask = (right_depth < args.depth_thres) 155 | attention_weights_lr = S360.weights.phi_confidence( 156 | S360.grid.create_spherical_grid(width)).to(device) 157 | # attention_weights_lr = S360.weights.spherical_confidence( 158 | # S360.grid.create_spherical_grid(width), zero_low=0.001 159 | # ).to(device) 160 | photo_loss_lr = L.photometric.calculate_loss(right_rgb_t, right_rgb, photo_params, 161 | mask=right_cutoff_mask, weights=attention_weights_lr) 162 | active_loss += photo_loss_lr * args.photo_w * (1 - args.photo_ratio) 163 | ''' Loss UD ''' 164 | up_cutoff_mask = (up_depth < args.depth_thres) 165 | attention_weights_ud = S360.weights.theta_confidence( 166 | S360.grid.create_spherical_grid(width)).to(device) 167 | photo_loss_ud = L.photometric.calculate_loss(up_rgb_t, up_rgb, photo_params, 168 | mask=up_cutoff_mask, weights=attention_weights_ud) 169 | active_loss += photo_loss_ud * args.photo_w * args.photo_ratio 170 | ''' Loss Prior (3D Smoothness) ''' 171 | left_xyz = S360.cartesian.coords_3d(sgrid, left_depth_pred) 172 | dI_dxyz = S360.derivatives.dV_dxyz(left_xyz) 173 | 
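
Two things worth spelling out in the trinocular step above: the vertical pair needs no horizontal displacement, since an up-down baseline only shifts the polar angle (disp_ud stacks zeros for u with dtheta_vertical for v, whereas the horizontal pair moves both phi and theta); and the two photometric terms are blended by photo_ratio, so with r = args.photo_ratio the synthesis part of the objective accumulated above is:

# args.photo_w * ((1 - r) * photo_loss_lr + r * photo_loss_ud)
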
tc_cutoff_mask = right_cutoff_mask & up_cutoff_mask
174 | guidance_duv = S360.derivatives.dI_duv(left_rgb)
175 | depth_smooth_loss = L.smoothness.guided_smoothness_loss(
176 | dI_dxyz, guidance_duv, tc_cutoff_mask, (1.0 - attention_weights_ud)
177 | * tc_cutoff_mask.type(attention_weights_ud.dtype)
178 | )
179 | active_loss += depth_smooth_loss * args.smooth_reg_w
180 | ''' Update Params '''
181 | active_loss.backward()
182 | optimizer.step()
183 | ''' Visualize'''
184 | total_loss.update(active_loss)
185 | running_depth_smooth_loss.update(depth_smooth_loss)
186 | running_photo_loss_lr.update(photo_loss_lr)
187 | running_photo_loss_ud.update(photo_loss_ud)
188 | iteration_counter += b
189 | if (iteration_counter + 1) % args.disp_iters <= args.batch_size:
190 | print("Epoch: {}, iteration: {}\nPhotometric (LR-UD): {} - {}\nSmoothness: {}\nTotal average loss: {}\n"\
191 | .format(epoch, iteration_counter, running_photo_loss_lr.avg, \
192 | running_photo_loss_ud.avg, running_depth_smooth_loss.avg, total_loss.avg))
193 | plot_viz.append_loss(epoch + 1, iteration_counter, total_loss.avg, "avg")
194 | plot_viz.append_loss(epoch + 1, iteration_counter, running_photo_loss_lr.avg, "photo_lr")
195 | plot_viz.append_loss(epoch + 1, iteration_counter, running_photo_loss_ud.avg, "photo_ud")
196 | plot_viz.append_loss(epoch + 1, iteration_counter, running_depth_smooth_loss.avg, "smooth")
197 | total_loss.reset()
198 | running_photo_loss_lr.reset()
199 | running_photo_loss_ud.reset()
200 | running_depth_smooth_loss.reset()
201 | if args.visdom_iters > 0 and (iteration_counter + 1) % args.visdom_iters <= args.batch_size:
202 | img_viz.show_separate_images(left_rgb, 'input')
203 | img_viz.show_separate_images(right_rgb, 'right')
204 | img_viz.show_separate_images(up_rgb, 'up')
205 | img_viz.show_map(left_depth_pred, 'depth')
206 | img_viz.show_separate_images(torch.clamp(right_rgb_t, min=0.0, max=1.0), 'recon_lr')
207 | img_viz.show_separate_images(torch.clamp(up_rgb_t, min=0.0, max=1.0), 'recon_ud')
208 | ''' Save '''
209 | print("Saving model @ epoch #" + str(epoch))
210 | utils.checkpoint.save_network_state(model, optimizer, epoch,\
211 | args.name + "_model_state", args.save_path)
212 | ''' Test '''
213 | print("Testing model @ epoch #" + str(epoch))
214 | model.eval()
215 | with torch.no_grad():
216 | rmse_avg = torch.tensor(0.0).float()
217 | counter = torch.tensor(0.0).float()
218 | for test_batch_id , test_batch in enumerate(test_data_iterator):
219 | left_rgb = test_batch['leftRGB'].to(device)
220 | b, c, h, w = left_rgb.size()
221 | rads = sgrid.expand(b, -1, -1, -1)
222 | uv = uvgrid.expand(b, -1, -1, -1)
223 | left_depth_pred = torch.abs(model(left_rgb))
224 | left_depth = test_batch['leftDepth'].to(device)
225 | left_depth[torch.isnan(left_depth)] = 50.0
226 | left_depth[torch.isinf(left_depth)] = 50.0
227 | mse = (left_depth_pred - left_depth) ** 2 # per-pixel squared error
228 | mse[torch.isnan(mse)] = 0.0
229 | mse[torch.isinf(mse)] = 0.0
230 | mask = (left_depth < args.depth_thres).float()
231 | if torch.sum(mask) == 0:
232 | continue
233 | rmse = torch.sqrt(torch.sum(mse * mask) / torch.sum(mask).float())
234 | if not torch.isnan(rmse):
235 | rmse_avg += rmse.cpu().float()
236 | counter += torch.tensor(b).float()
237 | if counter < args.save_iters:
238 | disp = torch.cat(
239 | (
240 | S360.derivatives.dphi_horizontal(rads, left_depth_pred, args.baseline),
241 | S360.derivatives.dtheta_horizontal(rads, left_depth_pred, args.baseline)
242 | ), dim=1
243 | )
244 | right_render_coords = uv + disp
245 | 
right_render_coords[:, 0, :, :] = torch.fmod(right_render_coords[:, 0, :, :] + width, width) 246 | right_render_coords[torch.isnan(right_render_coords)] = 0.0 247 | right_render_coords[torch.isinf(right_render_coords)] = 0.0 248 | right_rgb_t, right_mask_t = L.splatting.render(left_rgb, left_depth_pred, right_render_coords, max_depth=args.depth_thres) 249 | # save 250 | IO.image.save_image(os.path.join(args.save_path,\ 251 | str(epoch) + "_" + str(counter) + "_#_left.png"), left_rgb) 252 | IO.image.save_image(os.path.join(args.save_path,\ 253 | str(epoch) + "_" + str(counter) + "_#_right_t.png"), right_rgb_t) 254 | IO.image.save_data(os.path.join(args.save_path,\ 255 | str(epoch) + "_" + str(counter) + "_#_depth.exr"), left_depth_pred, scale=1.0) 256 | if (counter == 0) or (torch.isnan(rmse_avg) > 0): 257 | print("Error calculating RMSE (val:%f , sum:%d)" % (rmse_avg, counter)) 258 | plot_viz.append_loss(epoch + 1, epoch + 1, torch.tensor(0.0), "rmse", mode='test') 259 | else: 260 | rmse_avg /= counter 261 | print("Testing epoch {}: RMSE = {}".format(epoch+1, rmse_avg)) 262 | plot_viz.append_loss(epoch + 1, epoch + 1, rmse_avg, "rmse", mode='test') 263 | torch.enable_grad() 264 | model.train() 265 | -------------------------------------------------------------------------------- /train_ud.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import numpy 5 | 6 | import torch 7 | 8 | import models 9 | import dataset 10 | import utils 11 | 12 | import supervision as L 13 | import exporters as IO 14 | import spherical as S360 15 | 16 | def parse_arguments(args): 17 | usage_text = ( 18 | "Omnidirectional Vertical Stereo Placement (Up-Down , UD) Training." 19 | ) 20 | parser = argparse.ArgumentParser(description=usage_text) 21 | # durations 22 | parser.add_argument('-e',"--epochs", type=int, help="Train for a total number of epochs.") 23 | parser.add_argument('-b',"--batch_size", type=int, help="Train with a number of samples each train iteration.") 24 | parser.add_argument("--test_batch_size", default=1, type=int, help="Test with a number of samples each test iteration.") 25 | parser.add_argument('-d','--disp_iters', type=int, default=50, help='Log training progress (i.e. loss etc.) 
on console every iterations.') 26 | parser.add_argument('--save_iters', type=int, default=100, help='Maximum test iterations to perform each test run.') 27 | # paths 28 | parser.add_argument("--train_path", type=str, help="Path to the training file containing the train set files paths") 29 | parser.add_argument("--test_path", type=str, help="Path to the testing file containing the test set file paths") 30 | parser.add_argument("--save_path", type=str, help="Path to the folder where the models and results will be saved at.") 31 | # model 32 | parser.add_argument("--configuration", required=False, type=str, default='mono', help="Data loader configuration , , , ", choices=['mono', 'lr', 'ud', 'tc']) 33 | parser.add_argument('--weight_init', type=str, default="xavier", help='Weight initialization method, or path to weights file (for fine-tuning or continuing training)') 34 | parser.add_argument('--model', default="default", type=str, help='Model selection argument.') 35 | # optimization 36 | parser.add_argument('-o','--optimizer', type=str, default="adam", help='The optimizer that will be used during training.') 37 | parser.add_argument("--opt_state", type=str, help="Path to stored optimizer state file to continue training)") 38 | parser.add_argument('-l','--lr', type=float, default=0.0002, help='Optimization Learning Rate.') 39 | parser.add_argument('-m','--momentum', type=float, default=0.9, help='Optimization Momentum.') 40 | parser.add_argument('--momentum2', type=float, default=0.999, help='Optimization Second Momentum (optional, only used by some optimizers).') 41 | parser.add_argument('--epsilon', type=float, default=1e-8, help='Optimization Epsilon (optional, only used by some optimizers).') 42 | parser.add_argument('--weight_decay', type=float, default=0, help='Optimization Weight Decay.') 43 | # hardware 44 | parser.add_argument('-g','--gpu', type=str, default='0', help='The ids of the GPU(s) that will be utilized. (e.g. 0 or 0,1, or 0,2). Use -1 for CPU.') 45 | # other 46 | parser.add_argument('-n','--name', type=str, default='default_name', help='The name of this train/test. 
47 |     parser.add_argument("--visdom", type=str, nargs='?', default=None, const="127.0.0.1", help="Visdom server IP (port defaults to 8097).")
48 |     parser.add_argument("--visdom_iters", type=int, default=400, help="Iteration interval at which results are reported to the visdom server for visualization.")
49 |     parser.add_argument("--seed", type=int, default=1337, help="Fixed manual seed; zero means no seeding.")
50 |     # network specific params
51 |     parser.add_argument("--photo_w", type=float, default=1.0, help="Photometric loss weight.")
52 |     parser.add_argument("--smooth_reg_w", type=float, default=0.1, help="Smoothness regularization weight.")
53 |     parser.add_argument("--ssim_window", type=int, default=7, help="Kernel size to use in the SSIM calculation.")
54 |     parser.add_argument("--ssim_mode", type=str, default='gaussian', help="Type of SSIM averaging (either gaussian or box).")
55 |     parser.add_argument("--ssim_std", type=float, default=1.5, help="SSIM standard deviation value used when creating the gaussian averaging kernels.")
56 |     parser.add_argument("--ssim_alpha", type=float, default=0.85, help="Alpha factor balancing the SSIM and L1 losses, i.e. a * SSIM + (1 - a) * L1.")
57 |     parser.add_argument("--pred_bias", type=float, default=5.0, help="Initialize the prediction layers' bias to the given value (helps convergence).")
58 |     # details
59 |     parser.add_argument("--depth_thres", type=float, default=20.0, help="Depth threshold used for depth clipping.")
60 |     parser.add_argument("--baseline", type=float, default=0.26, help="Stereo baseline distance (in either axis).")
61 |     parser.add_argument("--width", type=int, default=512, help="Spherical image width.")
62 |     return parser.parse_known_args(args)
63 | 
64 | if __name__ == "__main__":
65 |     args, unknown = parse_arguments(sys.argv)
66 |     gpus = [int(id) for id in args.gpu.split(',') if int(id) >= 0]
67 |     # device & visualizers
68 |     device, visualizers, model_params = utils.initialize(args)
69 |     plot_viz = visualizers[0]
70 |     img_viz = visualizers[1]
71 |     # model
72 |     model = models.get_model(args.model, model_params)
73 |     utils.init.initialize_weights(model, args.weight_init, pred_bias=args.pred_bias)
74 |     if (len(gpus) > 1):
75 |         model = torch.nn.parallel.DataParallel(model, gpus)
76 |     model = model.to(device)
77 |     # optimizer
78 |     optimizer = utils.init_optimizer(model, args)
79 |     # train data
80 |     train_data = dataset.dataset_360D.Dataset360D(args.train_path, " ", args.configuration, [256, 512])
81 |     train_data_iterator = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,\
82 |         num_workers=args.batch_size // 4 // len(gpus), pin_memory=False, shuffle=True)
83 |     # test data
84 |     test_data = dataset.dataset_360D.Dataset360D(args.test_path, " ", args.configuration, [256, 512])
85 |     test_data_iterator = torch.utils.data.DataLoader(test_data, batch_size=args.test_batch_size,\
86 |         num_workers=args.batch_size // 4 // len(gpus), pin_memory=False, shuffle=True)
87 |     print("Data size : {0} | Test size : {1}".format(\
88 |         args.batch_size * train_data_iterator.__len__(), \
89 |         args.test_batch_size * test_data_iterator.__len__()))
90 |     # params
91 |     width = args.width
92 |     height = args.width // 2
93 |     photo_params = L.photometric.PhotometricLossParameters(
94 |         alpha=args.ssim_alpha, l1_estimator='none', ssim_estimator='none',
95 |         ssim_mode=args.ssim_mode, std=args.ssim_std, window=args.ssim_window
96 |     )
97 |     iteration_counter = 0
98 |     # meters
99 |     total_loss = utils.AverageMeter()
100 |     running_photo_loss = utils.AverageMeter()
101 |     running_depth_smooth_loss = utils.AverageMeter()
102 |     # train / test loop
103 |     model.train()
104 |     plot_viz.config(**vars(args))
105 |     for epoch in range(args.epochs):
106 |         print("Training | Epoch: {}".format(epoch))
107 |         img_viz.update_epoch(epoch)
108 |         for batch_id, batch in enumerate(train_data_iterator):
109 |             optimizer.zero_grad()
110 |             active_loss = torch.tensor(0.0).to(device)
111 |             ''' Data '''
112 |             left_rgb = batch['leftRGB'].to(device)
113 |             b, _, __, ___ = left_rgb.size()
114 |             expand_size = (b, -1, -1, -1)
115 |             sgrid = S360.grid.create_spherical_grid(width).to(device)
116 |             uvgrid = S360.grid.create_image_grid(width, height).to(device)
117 |             up_rgb = batch['upRGB'].to(device)
118 |             left_depth = batch['leftDepth'].to(device)
119 |             up_depth = batch['upDepth'].to(device)
120 |             ''' Prediction '''
121 |             left_depth_pred = torch.abs(model(left_rgb))
122 |             ''' Forward Rendering UD '''
123 |             disp = torch.cat(
124 |                 (
125 |                     torch.zeros_like(left_depth_pred),
126 |                     S360.derivatives.dtheta_vertical(sgrid, left_depth_pred, args.baseline)
127 |                 ),
128 |                 dim=1
129 |             )
130 |             up_render_coords = uvgrid + disp
131 |             up_render_coords[torch.isnan(up_render_coords)] = 0.0
132 |             up_render_coords[torch.isinf(up_render_coords)] = 0.0
133 |             up_rgb_t, up_mask_t = L.splatting.render(left_rgb, left_depth_pred,\
134 |                 up_render_coords, max_depth=args.depth_thres)
135 |             ''' Loss UD '''
136 |             up_cutoff_mask = (up_depth < args.depth_thres)
137 |             up_mask_t &= ~(up_depth > args.depth_thres)
138 |             attention_weights = S360.weights.theta_confidence(
139 |                 S360.grid.create_spherical_grid(width)).to(device)
140 |             # attention_weights = torch.ones_like(left_depth)
141 |             photo_loss = L.photometric.calculate_loss(up_rgb_t, up_rgb, photo_params,
142 |                 mask=up_cutoff_mask, weights=attention_weights)
143 |             active_loss += photo_loss * args.photo_w
144 |             ''' Loss Prior (3D Smoothness) '''
145 |             left_xyz = S360.cartesian.coords_3d(sgrid, left_depth_pred)
146 |             dI_dxyz = S360.derivatives.dV_dxyz(left_xyz)
147 |             guidance_duv = S360.derivatives.dI_duv(left_rgb)
148 |             # attention_weights = torch.zeros_like(left_depth)
149 |             depth_smooth_loss = L.smoothness.guided_smoothness_loss(
150 |                 dI_dxyz, guidance_duv, up_cutoff_mask, (1.0 - attention_weights)
151 |                 * up_cutoff_mask.type(attention_weights.dtype)
152 |             )
153 |             active_loss += depth_smooth_loss * args.smooth_reg_w
154 |             ''' Update Params '''
155 |             active_loss.backward()
156 |             optimizer.step()
157 |             ''' Visualize '''
158 |             total_loss.update(active_loss.detach())  # detach so the meters do not retain the autograd graph
159 |             running_depth_smooth_loss.update(depth_smooth_loss.detach())
160 |             running_photo_loss.update(photo_loss.detach())
161 |             iteration_counter += b
162 |             if (iteration_counter + 1) % args.disp_iters <= args.batch_size:
163 |                 print("Epoch: {}, iteration: {}\nPhotometric: {}\nSmoothness: {}\nTotal average loss: {}\n"\
164 |                     .format(epoch, iteration_counter, running_photo_loss.avg, \
165 |                     running_depth_smooth_loss.avg, total_loss.avg))
166 |                 plot_viz.append_loss(epoch + 1, iteration_counter, total_loss.avg, "avg")
167 |                 plot_viz.append_loss(epoch + 1, iteration_counter, running_photo_loss.avg, "photo")
168 |                 plot_viz.append_loss(epoch + 1, iteration_counter, running_depth_smooth_loss.avg, "smooth")
169 |                 total_loss.reset()
170 |                 running_photo_loss.reset()
171 |                 running_depth_smooth_loss.reset()
172 |             if args.visdom_iters > 0 and (iteration_counter + 1) % args.visdom_iters <= args.batch_size:
173 |                 img_viz.show_separate_images(left_rgb, 'input')
174 |                 img_viz.show_separate_images(up_rgb, 'target')
175 |                 img_viz.show_map(left_depth_pred, 'depth')
176 |                 img_viz.show_separate_images(torch.clamp(up_rgb_t, min=0.0, max=1.0), 'recon')
177 |         ''' Save '''
178 |         print("Saving model @ epoch #" + str(epoch))
179 |         utils.checkpoint.save_network_state(model, optimizer, epoch,\
180 |             args.name + "_model_state", args.save_path)
181 |         ''' Test '''
182 |         print("Testing model @ epoch #" + str(epoch))
183 |         model.eval()
184 |         with torch.no_grad():
185 |             rmse_avg = torch.tensor(0.0).float()
186 |             counter = torch.tensor(0.0).float()
187 |             for test_batch_id, test_batch in enumerate(test_data_iterator):
188 |                 left_rgb = test_batch['leftRGB'].to(device)
189 |                 b, c, h, w = left_rgb.size()
190 |                 rads = sgrid.expand(b, -1, -1, -1)
191 |                 uv = uvgrid.expand(b, -1, -1, -1)
192 |                 left_depth_pred = torch.abs(model(left_rgb))
193 |                 left_depth = test_batch['leftDepth'].to(device)
194 |                 left_depth[torch.isnan(left_depth)] = 50.0
195 |                 left_depth[torch.isinf(left_depth)] = 50.0
196 |                 mse = (left_depth_pred - left_depth) ** 2  # per-pixel squared error
197 |                 mse[torch.isnan(mse)] = 0.0
198 |                 mse[torch.isinf(mse)] = 0.0
199 |                 mask = (left_depth < args.depth_thres).float()
200 |                 if torch.sum(mask) == 0:
201 |                     continue
202 |                 rmse = torch.sqrt(torch.sum(mse * mask) / torch.sum(mask).float())
203 |                 if not torch.isnan(rmse):
204 |                     rmse_avg += rmse.cpu().float()
205 |                     counter += torch.tensor(b).float()
206 |                 if counter < args.save_iters:
207 |                     disp = torch.cat(
208 |                         (
209 |                             torch.zeros_like(left_depth_pred),
210 |                             S360.derivatives.dtheta_vertical(rads, left_depth_pred, args.baseline)
211 |                         ), dim=1
212 |                     )
213 |                     up_render_coords = uv + disp
214 |                     up_render_coords[torch.isnan(up_render_coords)] = 0.0
215 |                     up_render_coords[torch.isinf(up_render_coords)] = 0.0
216 |                     up_rgb_t, up_mask_t = L.splatting.render(left_rgb, left_depth_pred, \
217 |                         up_render_coords, max_depth=args.depth_thres)
218 |                     # save
219 |                     IO.image.save_image(os.path.join(args.save_path,\
220 |                         str(epoch) + "_" + str(int(counter)) + "_#_left.png"), left_rgb)
221 |                     IO.image.save_image(os.path.join(args.save_path,\
222 |                         str(epoch) + "_" + str(int(counter)) + "_#_up_t.png"), up_rgb_t)
223 |                     IO.image.save_data(os.path.join(args.save_path,\
224 |                         str(epoch) + "_" + str(int(counter)) + "_#_depth.exr"), left_depth_pred, scale=1.0)
225 |             rmse_avg /= counter
226 |             print("Testing epoch {}: RMSE = {}".format(epoch + 1, rmse_avg))
227 |             plot_viz.append_loss(epoch + 1, epoch + 1, rmse_avg, "rmse", mode='test')
228 |         # gradient tracking resumes automatically once the no_grad block exits
229 |         model.train()
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .init import *
2 | from .opt import *
3 | from .visualization import *
4 | from .framework import *
5 | from .meters import *
6 | from .checkpoint import *
--------------------------------------------------------------------------------
/utils/checkpoint.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | 
 3 | from os import path
 4 | 
 5 | def save_network_state(model, optimizer, epoch, name, save_path):
 6 |     if not path.exists(save_path):
 7 |         raise ValueError("{} is not a valid path to save the model state.".format(save_path))
 8 |     torch.save(
 9 |         {
10 |             'epoch' : epoch,
11 |             'model_state_dict' : model.state_dict(),
12 |             'optimizer_state_dict' : optimizer.state_dict()
13 |         }, path.join(save_path, "{}_e{}.pt".format(name, epoch)))
14 | 
15 | 
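Note that utils/checkpoint.py only provides the save side; restoring is split between utils/init.py (model weights) and utils/framework.py (optimizer state). For orientation, a combined loader could look like the minimal sketch below. The helper name load_network_state is hypothetical and not part of the repository, but the dictionary keys mirror those written by save_network_state above.

import torch
from os import path

def load_network_state(model, optimizer, name, epoch, save_path, device='cpu'):
    # Reads the dictionary written by save_network_state and restores both states.
    state = torch.load(path.join(save_path, "{}_e{}.pt".format(name, epoch)),
                       map_location=device)
    model.load_state_dict(state['model_state_dict'])
    optimizer.load_state_dict(state['optimizer_state_dict'])
    return state['epoch']  # the epoch at which the checkpoint was written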
-------------------------------------------------------------------------------- /utils/framework.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import datetime 3 | import numpy 4 | import random 5 | 6 | from .opt import * 7 | from .visualization import * 8 | 9 | def initialize(args): #TODO: add visdom count as argument 10 | # create and init device 11 | print("{} | Torch Version: {}".format(datetime.datetime.now(), torch.__version__)) 12 | if args.seed > 0: 13 | print("Set to reproducibility mode with seed: {}".format(args.seed)) 14 | torch.manual_seed(args.seed) 15 | torch.cuda.manual_seed_all(args.seed) 16 | numpy.random.seed(args.seed) 17 | torch.backends.cudnn.deterministic = True 18 | torch.backends.cudnn.benchmark = False 19 | random.seed(args.seed) 20 | gpus = [int(id) for id in args.gpu.split(',') if int(id) >= 0] 21 | device = torch.device("cuda:{}" .format(gpus[0]) if torch.cuda.is_available() and len(gpus) > 0 and gpus[0] >= 0 else "cpu") 22 | print("Training {0} for {1} epochs using a batch size of {2} on {3}".format(args.name, args.epochs, args.batch_size, device)) 23 | # create visualizer 24 | visualizer = (NullVisualizer(), NullVisualizer())\ 25 | if args.visdom is None\ 26 | else ( 27 | VisdomPlotVisualizer(args.name + "_plots_", args.visdom), 28 | VisdomImageVisualizer(args.name + "_images_", args.visdom,\ 29 | count=2 if 2 <= args.batch_size else args.batch_size) 30 | ) 31 | if args.visdom is None: 32 | args.visdom_iters = 0 33 | # create & init model 34 | model_params = { 35 | 'width': 512, 36 | 'height': 256, 37 | 'configuration': args.configuration, 38 | } 39 | return device, visualizer, model_params 40 | 41 | def init_optimizer(model, args): 42 | opt_params = OptimizerParameters(learning_rate=args.lr, momentum=args.momentum,\ 43 | momentum2=args.momentum2, epsilon=args.epsilon) 44 | optimizer = get_optimizer(args.optimizer, model.parameters(), opt_params) 45 | if args.opt_state is not None: 46 | opt_state = torch.load(args.opt_state) 47 | print("Loading previously saved optimizer state from {}".format(args.opt_state)) 48 | optimizer.load_state_dict(opt_state["optimizer_state_dict"]) 49 | return optimizer 50 | -------------------------------------------------------------------------------- /utils/init.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import os 4 | import sys 5 | 6 | def initialize_weights(model, init = "xavier", pred_bias=None): 7 | init_func = None 8 | if init == "xavier": 9 | init_func = torch.nn.init.xavier_normal_ 10 | elif init == "kaiming": 11 | init_func = torch.nn.init.kaiming_normal_ 12 | elif init == "gaussian" or init == "normal": 13 | init_func = torch.nn.init.normal_ 14 | 15 | if init_func is not None: 16 | #TODO: logging /w print or lib 17 | for module in model.modules(): 18 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \ 19 | or isinstance(module, torch.nn.ConvTranspose2d): 20 | init_func(module.weight) 21 | if module.bias is not None: 22 | module.bias.data.zero_() 23 | elif isinstance(module, torch.nn.BatchNorm2d): 24 | module.weight.data.fill_(1) 25 | module.bias.data.zero_() 26 | if pred_bias is not None: 27 | list(model.modules())[-1].bias.data.fill_(pred_bias) 28 | elif os.path.exists(init): 29 | #TODO: logging /w print or lib 30 | weights = torch.load(init, map_location={'cuda:1':'cuda:0'}) 31 | model.load_state_dict(weights["model_state_dict"]) 32 | else: 33 | print("Error when 
initializing model's weights, {} either doesn't exist or is not a valid initialization function.".format(init), \
34 |             file=sys.stderr)
35 | 
36 | def initialize_prediction_bias(model, pred_bias=None):
37 |     if pred_bias is not None:
38 |         list(model.modules())[-1].bias.data.fill_(pred_bias)
39 | 
40 | 
--------------------------------------------------------------------------------
/utils/meters.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | 
 3 | # Computes and stores the average and current value
 4 | class AverageMeter(object):
 5 |     def __init__(self):
 6 |         self.reset()
 7 | 
 8 |     def reset(self):
 9 |         self.val = torch.tensor(0.0)
10 |         self.avg = torch.tensor(0.0)
11 |         self.sum = torch.tensor(0.0)
12 |         self.count = torch.tensor(0.0)
13 | 
14 |     def update(self, val, n=1):
15 |         self.val = val
16 |         self.sum += val * n
17 |         self.count += n
18 |         self.avg = self.sum / self.count
--------------------------------------------------------------------------------
/utils/opt.py:
--------------------------------------------------------------------------------
 1 | import math
 2 | import torch
 3 | import torch.optim as optim
 4 | from torch.optim import Optimizer
 5 | 
 6 | import sys
 7 | 
 8 | class OptimizerParameters(object):
 9 |     def __init__(self, learning_rate=0.001, momentum=0.9, momentum2=0.999,\
10 |         epsilon=1e-8, weight_decay=0.0005, damp=0):
11 |         super(OptimizerParameters, self).__init__()
12 |         self.learning_rate = learning_rate
13 |         self.momentum = momentum
14 |         self.momentum2 = momentum2
15 |         self.epsilon = epsilon
16 |         self.damp = damp
17 |         self.weight_decay = weight_decay
18 | 
19 |     def get_learning_rate(self):
20 |         return self.learning_rate
21 | 
22 |     def get_momentum(self):
23 |         return self.momentum
24 | 
25 |     def get_momentum2(self):
26 |         return self.momentum2
27 | 
28 |     def get_epsilon(self):
29 |         return self.epsilon
30 | 
31 |     def get_weight_decay(self):
32 |         return self.weight_decay
33 | 
34 |     def get_damp(self):
35 |         return self.damp
36 | 
37 | def get_optimizer(opt_type, model_params, opt_params):
38 |     if opt_type == "adam":
39 |         return optim.Adam(model_params, \
40 |             lr=opt_params.get_learning_rate(), \
41 |             betas=(opt_params.get_momentum(), opt_params.get_momentum2()), \
42 |             eps=opt_params.get_epsilon(),
43 |         )
44 |     elif opt_type == "adabound" or opt_type == "amsbound":
45 |         return AdaBound(model_params, \
46 |             lr=opt_params.get_learning_rate(), \
47 |             betas=(opt_params.get_momentum(), opt_params.get_momentum2()), \
48 |             eps=opt_params.get_epsilon(),
49 |             weight_decay=opt_params.get_weight_decay(),\
50 |             final_lr=0.001, gamma=0.002,\
51 |             amsbound=True if opt_type == "amsbound" else False
52 |         )
53 |     elif opt_type == "sgd":
54 |         return optim.SGD(model_params, \
55 |             lr=opt_params.get_learning_rate(), \
56 |             momentum=opt_params.get_momentum(), \
57 |             weight_decay=opt_params.get_weight_decay(), \
58 |             dampening=opt_params.get_damp() \
59 |         )
60 |     else:
61 |         print("Error when initializing optimizer, {} is not a valid optimizer type.".format(opt_type), \
62 |             file=sys.stderr)
63 |         return None
64 | 
65 | def adjust_learning_rate(optimizer, epoch, scale=2):
66 |     """Decays each param group's learning rate by a factor of 10 every `scale` epochs."""
67 |     for param_group in optimizer.param_groups:
68 |         lr = param_group['lr']
69 |         lr = lr * (0.1 ** (epoch // scale))
70 |         param_group['lr'] = lr
71 | 
72 | 
73 | '''
74 | Code from https://github.com/Luolc/AdaBound
75 | '''
76 | 
77 | class AdaBound(Optimizer):
78 |     """Implements AdaBound algorithm.
79 | It has been proposed in `Adaptive Gradient Methods with Dynamic Bound of Learning Rate`_. 80 | Arguments: 81 | params (iterable): iterable of parameters to optimize or dicts defining 82 | parameter groups 83 | lr (float, optional): Adam learning rate (default: 1e-3) 84 | betas (Tuple[float, float], optional): coefficients used for computing 85 | running averages of gradient and its square (default: (0.9, 0.999)) 86 | final_lr (float, optional): final (SGD) learning rate (default: 0.1) 87 | gamma (float, optional): convergence speed of the bound functions (default: 1e-3) 88 | eps (float, optional): term added to the denominator to improve 89 | numerical stability (default: 1e-8) 90 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 91 | amsbound (boolean, optional): whether to use the AMSBound variant of this algorithm 92 | .. Adaptive Gradient Methods with Dynamic Bound of Learning Rate: 93 | https://openreview.net/forum?id=Bkg3g2R9FX 94 | """ 95 | 96 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), final_lr=0.1, gamma=1e-3, 97 | eps=1e-8, weight_decay=0, amsbound=False): 98 | if not 0.0 <= lr: 99 | raise ValueError("Invalid learning rate: {}".format(lr)) 100 | if not 0.0 <= eps: 101 | raise ValueError("Invalid epsilon value: {}".format(eps)) 102 | if not 0.0 <= betas[0] < 1.0: 103 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 104 | if not 0.0 <= betas[1] < 1.0: 105 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 106 | if not 0.0 <= final_lr: 107 | raise ValueError("Invalid final learning rate: {}".format(final_lr)) 108 | if not 0.0 <= gamma < 1.0: 109 | raise ValueError("Invalid gamma parameter: {}".format(gamma)) 110 | defaults = dict(lr=lr, betas=betas, final_lr=final_lr, gamma=gamma, eps=eps, 111 | weight_decay=weight_decay, amsbound=amsbound) 112 | super(AdaBound, self).__init__(params, defaults) 113 | 114 | self.base_lrs = list(map(lambda group: group['lr'], self.param_groups)) 115 | 116 | def __setstate__(self, state): 117 | super(AdaBound, self).__setstate__(state) 118 | for group in self.param_groups: 119 | group.setdefault('amsbound', False) 120 | 121 | def step(self, closure=None): 122 | """Performs a single optimization step. 123 | Arguments: 124 | closure (callable, optional): A closure that reevaluates the model 125 | and returns the loss. 126 | """ 127 | loss = None 128 | if closure is not None: 129 | loss = closure() 130 | 131 | for group, base_lr in zip(self.param_groups, self.base_lrs): 132 | for p in group['params']: 133 | if p.grad is None: 134 | continue 135 | grad = p.grad.data 136 | if grad.is_sparse: 137 | raise RuntimeError( 138 | 'Adam does not support sparse gradients, please consider SparseAdam instead') 139 | amsbound = group['amsbound'] 140 | 141 | state = self.state[p] 142 | 143 | # State initialization 144 | if len(state) == 0: 145 | state['step'] = 0 146 | # Exponential moving average of gradient values 147 | state['exp_avg'] = torch.zeros_like(p.data) 148 | # Exponential moving average of squared gradient values 149 | state['exp_avg_sq'] = torch.zeros_like(p.data) 150 | if amsbound: 151 | # Maintains max of all exp. moving avg. of sq. grad. 
values
152 |                         state['max_exp_avg_sq'] = torch.zeros_like(p.data)
153 | 
154 |                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
155 |                 if amsbound:
156 |                     max_exp_avg_sq = state['max_exp_avg_sq']
157 |                 beta1, beta2 = group['betas']
158 | 
159 |                 state['step'] += 1
160 | 
161 |                 if group['weight_decay'] != 0:
162 |                     grad = grad.add(group['weight_decay'], p.data)
163 | 
164 |                 # Decay the first and second moment running average coefficient
165 |                 exp_avg.mul_(beta1).add_(1 - beta1, grad)
166 |                 exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
167 |                 if amsbound:
168 |                     # Maintains the maximum of all 2nd moment running avg. till now
169 |                     torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
170 |                     # Use the max. for normalizing running avg. of gradient
171 |                     denom = max_exp_avg_sq.sqrt().add_(group['eps'])
172 |                 else:
173 |                     denom = exp_avg_sq.sqrt().add_(group['eps'])
174 | 
175 |                 bias_correction1 = 1 - beta1 ** state['step']
176 |                 bias_correction2 = 1 - beta2 ** state['step']
177 |                 step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
178 | 
179 |                 # Applies bounds on actual learning rate
180 |                 # lr_scheduler cannot affect final_lr, this is a workaround to apply lr decay
181 |                 final_lr = group['final_lr'] * group['lr'] / base_lr
182 |                 lower_bound = final_lr * (1 - 1 / (group['gamma'] * state['step'] + 1))
183 |                 upper_bound = final_lr * (1 + 1 / (group['gamma'] * state['step']))
184 |                 step_size = torch.full_like(denom, step_size)
185 |                 step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(exp_avg)
186 | 
187 |                 p.data.add_(-step_size)
188 | 
189 |         return loss
--------------------------------------------------------------------------------
/utils/visualization.py:
--------------------------------------------------------------------------------
 1 | import visdom
 2 | import numpy
 3 | import torch
 4 | import datetime
 5 | from json2html import *
 6 | 
 7 | class NullVisualizer(object):
 8 |     def __init__(self):
 9 |         self.name = __name__
10 | 
11 |     def append_loss(self, epoch, global_iteration, loss, loss_name="total", mode='train'):
12 |         pass
13 | 
14 |     # no-op stub so training without a visdom server does not crash on config calls
15 |     def config(self, **kwargs):
16 |         pass
17 | 
18 |     def show_images(self, images, title):
19 |         pass
20 | 
21 |     def update_epoch(self, epoch):
22 |         pass
23 | 
24 | class VisdomPlotVisualizer(object):
25 |     def __init__(self, name, server="http://localhost"):
26 |         self.visualizer = visdom.Visdom(server=server, port=8097, env=name,\
27 |             use_incoming_socket=False)
28 |         self.name = name
29 |         self.server = server
30 |         self.first_train_value = True
31 |         self.first_test_value = True
32 |         self.plots = {}
33 | 
34 |     def append_loss(self, epoch, global_iteration, loss, loss_name="total", mode='train'):
35 |         plot_name = loss_name + ('_loss' if mode == 'train' else '_error')
36 |         opts = (
37 |             {
38 |                 'title': plot_name,
39 |                 'xlabel': 'iterations',
40 |                 'ylabel': loss_name
41 |             })
42 |         loss_value = float(loss.detach().cpu().numpy())
43 |         if loss_name not in self.plots:
44 |             self.plots[loss_name] = self.visualizer.line(X=numpy.array([global_iteration]),\
45 |                 Y=numpy.array([loss_value]), opts=opts)
46 |         else:
47 |             self.visualizer.line(X=numpy.array([global_iteration]),\
48 |                 Y=numpy.array([loss_value]), win=self.plots[loss_name], name=mode, update='append')
49 | 
50 |     def config(self, **kwargs):
51 |         self.visualizer.text(json2html.convert(json=dict(kwargs)))
52 | 
53 |     def update_epoch(self, epoch):
54 |         pass
55 | 
56 | class VisdomImageVisualizer(object):
57 |     def __init__(self, name, server="http://localhost", count=2):
58 |         self.name = name
59 |         self.server = server
60 |         self.count = count
61 | 
62 |     def update_epoch(self, epoch):
63 |         self.visualizer = visdom.Visdom(server=self.server, port=8097,\
64 |             env=self.name + str(epoch), use_incoming_socket=False)
65 | 
66 |     def show_separate_images(self, images, title):
67 |         b, c, h, w = images.size()
68 |         take = self.count if self.count < b else b
69 |         recon_images = images.detach().cpu()[:take, [2, 1, 0], :, :]\
70 |             if c == 3 else images.detach().cpu()[:take, :, :, :]
71 |         for i in range(take):
72 |             img = recon_images[i, :, :, :]
73 |             opts = (
74 |             {
75 |                 'title': title + "_" + str(i),
76 |                 'width': w, 'height': h
77 |             })
78 |             self.visualizer.image(img, opts=opts,\
79 |                 win=self.name + title + "_window_" + str(i))
80 | 
81 |     def show_map(self, maps, title):
82 |         b, c, h, w = maps.size()
83 |         maps_cpu = torch.flip(maps, dims=[2]).detach().cpu()[:self.count, :, :, :]
84 |         for i in range(min(b, self.count)):
85 |             opts = (
86 |             {
87 |                 'title': title + str(i), 'colormap': 'Viridis'
88 |             })
89 |             heatmap = maps_cpu[i, :, :, :].squeeze(0)
90 |             # maps are flipped vertically above, since visdom heatmaps draw row 0 at the bottom
91 |             self.visualizer.heatmap(heatmap,\
92 |                 opts=opts, win=self.name + title + "_window_" + str(i))
--------------------------------------------------------------------------------
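As a closing illustration of how the utilities above fit together, the following minimal sketch exercises the meter/plotter pattern that the train_*.py scripts rely on. It is illustrative only: the loss values are synthetic, it assumes the repository root is on PYTHONPATH, and NullVisualizer is used so that no visdom server is required (substitute VisdomPlotVisualizer with a reachable server for live plots).

import torch
from utils.meters import AverageMeter
from utils.visualization import NullVisualizer

meter = AverageMeter()
plotter = NullVisualizer()  # or: VisdomPlotVisualizer("demo_plots_", "127.0.0.1")
for it in range(1, 101):
    loss = torch.tensor(1.0 / it)  # synthetic stand-in for active_loss.detach()
    meter.update(loss)
    if it % 50 == 0:  # mirrors the disp_iters logging cadence of the training scripts
        plotter.append_loss(1, it, meter.avg, "avg")
        meter.reset()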