├── .gitignore ├── README.md ├── criteria.py ├── dataset.py ├── metrics.py ├── network.py ├── omnidepth.yml ├── omnidepth_trainer.py ├── preprocess.py ├── show ├── 1.png ├── 2.png ├── 3.png └── 4.png ├── test_omnidepth.py ├── train_omnidepth.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Python template 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # celery beat schedule file 95 | celerybeat-schedule 96 | 97 | # SageMath parsed files 98 | *.sage.py 99 | 100 | # Environments 101 | .env 102 | .venv 103 | env/ 104 | venv/ 105 | ENV/ 106 | env.bak/ 107 | venv.bak/ 108 | 109 | # Spyder project settings 110 | .spyderproject 111 | .spyproject 112 | 113 | # Rope project settings 114 | .ropeproject 115 | 116 | # mkdocs documentation 117 | /site 118 | 119 | # mypy 120 | .mypy_cache/ 121 | .dmypy.json 122 | dmypy.json 123 | 124 | # Pyre type checker 125 | .pyre/ 126 | 127 | /data/ 128 | /experiments/ 129 | 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### OmniDepth-Pytorch 2 | A PyTorch reimplementation of the Omnidepth paper from Zioulis et al., ECCV 2018: 3 | 4 | Notable difference with the paper: PyTorch's weight decay for the Adam solver does not seem to function the same way as Caffe's. Hence, I do not use weight decay in training. 
Instead, I use a learning rate schedule.
5 | 
6 | This project adds some training visualizations on top of the original code, adapted from other projects.
7 | 
8 | ### Dependencies
9 | - Python 3.7
10 | - PyTorch 1.1
11 | - Visdom 0.1.8
12 | 
13 | ### Show
14 | ![training process](show/1.png)
15 | ![loss and metrics](show/2.png)
16 | ![indoor result](show/3.png)
17 | ![outdoor result](show/4.png)
18 | 
19 | ### Credit
20 | - Forked from https://github.com/meder411/OmniDepth-PyTorch
21 | 
22 | - If you use this repository, please make sure to cite the authors' original paper:
23 | Zioulis, Nikolaos, et al. "OmniDepth: Dense Depth Estimation for Indoors Spherical Panoramas."
24 | Proceedings of the European Conference on Computer Vision (ECCV). 2018.
25 | 
--------------------------------------------------------------------------------
/criteria.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | 
6 | class SquaredGradientLoss(nn.Module):
7 | '''Compute the gradient magnitude of an image using the simple filters as in:
8 | Garg, Ravi, et al. "Unsupervised cnn for single view depth estimation: Geometry to the rescue." European Conference on Computer Vision. Springer, Cham, 2016.
9 | '''
10 | 
11 | def __init__(self):
12 | super(SquaredGradientLoss, self).__init__()
13 | 
14 | self.register_buffer('dx_filter', torch.FloatTensor([
15 | [0, 0, 0],
16 | [-0.5, 0, 0.5],
17 | [0, 0, 0]]).view(1, 1, 3, 3))
18 | self.register_buffer('dy_filter', torch.FloatTensor([
19 | [0, -0.5, 0],
20 | [0, 0, 0],
21 | [0, 0.5, 0]]).view(1, 1, 3, 3))
22 | 
23 | def forward(self, pred, mask):
24 | dx = F.conv2d(
25 | pred,
26 | self.dx_filter.to(pred.get_device()),
27 | padding=1,
28 | groups=pred.shape[1])
29 | dy = F.conv2d(
30 | pred,
31 | self.dy_filter.to(pred.get_device()),
32 | padding=1,
33 | groups=pred.shape[1])
34 | 
35 | error = mask * \
36 | (dx.abs().sum(1, keepdim=True) + dy.abs().sum(1, keepdim=True))
37 | 
38 | return error.sum() / (mask > 0).sum().float()
39 | 
40 | 
41 | class L2Loss(nn.Module):
42 | 
43 | def __init__(self):
44 | super(L2Loss, self).__init__()
45 | 
46 | self.metric = nn.MSELoss(reduction='none')  # keep per-pixel errors so the mask can be applied element-wise
47 | 
48 | def forward(self, pred, gt, mask):
49 | error = mask * self.metric(pred, gt)
50 | return error.sum() / (mask > 0).sum().float()
51 | 
52 | 
53 | class MultiScaleL2Loss(nn.Module):
54 | 
55 | def __init__(self, alpha_list, beta_list):
56 | super(MultiScaleL2Loss, self).__init__()
57 | 
58 | self.depth_metric = L2Loss()
59 | self.grad_metric = SquaredGradientLoss()
60 | self.alpha_list = alpha_list
61 | self.beta_list = beta_list
62 | 
63 | def forward(self, pred_list, gt_list, mask_list):
64 | # Go through each scale and accumulate errors
65 | depth_error = 0
66 | for i in range(len(pred_list)):
67 | depth_pred = pred_list[i]
68 | depth_gt = gt_list[i]
69 | mask = mask_list[i]
70 | alpha = self.alpha_list[i]
71 | beta = self.beta_list[i]
72 | 
73 | # Compute depth error at this scale
74 | depth_error += alpha * self.depth_metric(
75 | depth_pred,
76 | depth_gt,
77 | mask)
78 | 
79 | # Compute gradient error at this scale
80 | depth_error += beta * self.grad_metric(
81 | depth_pred,
82 | mask)
83 | 
84 | return depth_error
85 | 
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | __author__ = "Marc Eder"
2 | 
3 | # * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
4 | # * * * * * * 
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * 5 | 6 | import torch 7 | import torch.utils.data 8 | 9 | import numpy as np 10 | from skimage import io 11 | 12 | from PIL import Image 13 | import os 14 | import re 15 | 16 | 17 | class OmniDepthDataset(torch.utils.data.Dataset): 18 | '''PyTorch dataset module for effiicient loading''' 19 | 20 | def __init__(self, imgs_path): 21 | self.image_list = [] 22 | for img_name in os.listdir(imgs_path): 23 | if re.match(r'.+\d\.jpeg', img_name): 24 | self.image_list.append(imgs_path + os.path.splitext(img_name)[0]) 25 | 26 | self.max_depth = 255.0 27 | 28 | def __getitem__(self, idx): 29 | '''Load the data''' 30 | # Select the panos to load 31 | common_paths = self.image_list[idx] 32 | 33 | # Load the panos 34 | rgb = self.readRGBPano(common_paths + ".jpeg") 35 | depth = self.readDepthPanoFromJPEG(common_paths + "_d.jpeg") 36 | depth_mask = ((depth <= self.max_depth) & (depth > 0.)).astype(np.uint8) 37 | # Threshold depths 38 | depth *= depth_mask 39 | 40 | # Make a list of loaded data 41 | pano_data = [rgb, depth, depth_mask, common_paths] 42 | # Convert to torch format 43 | pano_data[0] = torch.from_numpy(pano_data[0].transpose(2, 0, 1)).float() 44 | pano_data[1] = torch.from_numpy(pano_data[1][None, ...]).float() 45 | pano_data[2] = torch.from_numpy(pano_data[2][None, ...]).float() 46 | # Return the set of pano data 47 | return pano_data 48 | 49 | def __len__(self): 50 | '''Return the size of this dataset''' 51 | return len(self.image_list) 52 | 53 | def readRGBPano(self, path): 54 | '''Read RGB and normalize to [0,1].''' 55 | rgb = io.imread(path).astype(np.float32) / 255. 56 | return rgb 57 | 58 | def readDepthPanoFromJPEG(self, path): 59 | img = Image.open(path) 60 | img.load() 61 | data = np.asarray(img, dtype="float32") 62 | res = np.reshape(data, (img.height, img.width, 3))[..., 0] 63 | return res 64 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | # ========================== 5 | # Depth Prediction Metrics 6 | # ========================== 7 | 8 | def abs_rel_error(pred, gt, mask): 9 | '''Compute absolute relative difference error''' 10 | return ((pred[mask > 0] - gt[mask > 0]).abs() / gt[mask > 0]).mean() 11 | 12 | 13 | def sq_rel_error(pred, gt, mask): 14 | '''Compute squared relative difference error''' 15 | return (((pred[mask > 0] - gt[mask > 0]) ** 2) / gt[mask > 0]).mean() 16 | 17 | 18 | def lin_rms_sq_error(pred, gt, mask): 19 | '''Compute the linear RMS error except the final square-root step''' 20 | return ((pred[mask > 0] - gt[mask > 0]) ** 2).mean() 21 | 22 | 23 | def log_rms_sq_error(pred, gt, mask): 24 | '''Compute the log RMS error except the final square-root step''' 25 | mask = (mask > 0) & (pred > 1e-7) & (gt > 1e-7) # Compute a mask of valid values 26 | return ((pred[mask].log() - gt[mask].log()) ** 2).mean() 27 | 28 | 29 | def delta_inlier_ratio(pred, gt, mask, degree=1): 30 | '''Compute the delta inlier rate to a specified degree (def: 1)''' 31 | return (torch.max(pred[mask > 0] / gt[mask > 0], gt[mask > 0] / pred[mask > 0]) < (1.25 ** degree)).float().mean() 32 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | # * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * 2 | __author__ = "Marc Eder" 3 | # * * 
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from util import xavier_init 10 | 11 | 12 | class RectNet(nn.Module): 13 | 14 | def __init__(self): 15 | super(RectNet, self).__init__() 16 | 17 | # Network definition 18 | self.input0_0 = ConvELUBlock(3, 8, (3, 9), padding=(1, 4)) 19 | self.input0_1 = ConvELUBlock(3, 8, (5, 11), padding=(2, 5)) 20 | self.input0_2 = ConvELUBlock(3, 8, (5, 7), padding=(2, 3)) 21 | self.input0_3 = ConvELUBlock(3, 8, 7, padding=3) 22 | 23 | self.input1_0 = ConvELUBlock(32, 16, (3, 9), padding=(1, 4)) 24 | self.input1_1 = ConvELUBlock(32, 16, (3, 7), padding=(1, 3)) 25 | self.input1_2 = ConvELUBlock(32, 16, (3, 5), padding=(1, 2)) 26 | self.input1_3 = ConvELUBlock(32, 16, 5, padding=2) 27 | 28 | self.encoder0_0 = ConvELUBlock(64, 128, 3, stride=2, padding=1) 29 | self.encoder0_1 = ConvELUBlock(128, 128, 3, padding=1) 30 | self.encoder0_2 = ConvELUBlock(128, 128, 3, padding=1) 31 | 32 | self.encoder1_0 = ConvELUBlock(128, 256, 3, stride=2, padding=1) 33 | self.encoder1_1 = ConvELUBlock(256, 256, 3, padding=2, dilation=2) 34 | self.encoder1_2 = ConvELUBlock(256, 256, 3, padding=4, dilation=4) 35 | self.encoder1_3 = ConvELUBlock(512, 256, 1) 36 | 37 | self.encoder2_0 = ConvELUBlock(256, 512, 3, padding=8, dilation=8) 38 | self.encoder2_1 = ConvELUBlock(512, 512, 3, padding=16, dilation=16) 39 | self.encoder2_2 = ConvELUBlock(1024, 512, 1) 40 | 41 | self.decoder0_0 = ConvTransposeELUBlock(512, 256, 4, stride=2, padding=1) 42 | self.decoder0_1 = ConvELUBlock(256, 256, 5, padding=2) 43 | 44 | self.prediction0 = nn.Conv2d(256, 1, 3, padding=1) 45 | 46 | self.decoder1_0 = ConvTransposeELUBlock(256, 128, 4, stride=2, padding=1) 47 | self.decoder1_1 = ConvELUBlock(128, 128, 5, padding=2) 48 | self.decoder1_2 = ConvELUBlock(129, 64, 1) 49 | 50 | self.prediction1 = nn.Conv2d(64, 1, 3, padding=1) 51 | 52 | # Initialize the network weights 53 | self.apply(xavier_init) 54 | 55 | def forward(self, x): 56 | # First filter bank 57 | input0_0_out = self.input0_0(x) 58 | input0_1_out = self.input0_1(x) 59 | input0_2_out = self.input0_2(x) 60 | input0_3_out = self.input0_3(x) 61 | input0_out_cat = torch.cat( 62 | (input0_0_out, 63 | input0_1_out, 64 | input0_2_out, 65 | input0_3_out), 1) 66 | 67 | # Second filter bank 68 | input1_0_out = self.input1_0(input0_out_cat) 69 | input1_1_out = self.input1_1(input0_out_cat) 70 | input1_2_out = self.input1_2(input0_out_cat) 71 | input1_3_out = self.input1_3(input0_out_cat) 72 | 73 | # First encoding block 74 | encoder0_0_out = self.encoder0_0( 75 | torch.cat((input1_0_out, 76 | input1_1_out, 77 | input1_2_out, 78 | input1_3_out), 1)) 79 | encoder0_1_out = self.encoder0_1(encoder0_0_out) 80 | encoder0_2_out = self.encoder0_2(encoder0_1_out) 81 | 82 | # Second encoding block 83 | encoder1_0_out = self.encoder1_0(encoder0_2_out) 84 | encoder1_1_out = self.encoder1_1(encoder1_0_out) 85 | encoder1_2_out = self.encoder1_2(encoder1_1_out) 86 | encoder1_3_out = self.encoder1_3( 87 | torch.cat((encoder1_1_out, encoder1_2_out), 1)) 88 | 89 | # Third encoding block 90 | encoder2_0_out = self.encoder2_0(encoder1_3_out) 91 | encoder2_1_out = self.encoder2_1(encoder2_0_out) 92 | encoder2_2_out = self.encoder2_2( 93 | torch.cat((encoder2_0_out, encoder2_1_out), 1)) 94 | 95 | # First decoding block 96 | decoder0_0_out = self.decoder0_0(encoder2_2_out) 97 | decoder0_1_out = self.decoder0_1(decoder0_0_out) 98 | 99 | # 2x downsampled prediction 100 | 
pred_2x = self.prediction0(decoder0_1_out) 101 | upsampled_pred_2x = F.interpolate(pred_2x.detach(), scale_factor=2) 102 | 103 | # Second decoding block 104 | decoder1_0_out = self.decoder1_0(decoder0_1_out) 105 | decoder1_1_out = self.decoder1_1(decoder1_0_out) 106 | decoder1_2_out = self.decoder1_2( 107 | torch.cat((upsampled_pred_2x, decoder1_1_out), 1)) 108 | 109 | # Second prediction output (original scale) 110 | pred_1x = self.prediction1(decoder1_2_out) 111 | 112 | return [pred_1x, pred_2x] 113 | 114 | 115 | # ----------------------------------------------------------------------------- 116 | class UResNet(nn.Module): 117 | 118 | def __init__(self): 119 | super(UResNet, self).__init__() 120 | 121 | self.input0 = ConvELUBlock( 122 | in_channels=3, 123 | out_channels=32, 124 | kernel_size=7, 125 | stride=1, 126 | padding=3) 127 | self.input1 = ConvELUBlock( 128 | in_channels=32, 129 | out_channels=64, 130 | kernel_size=5, 131 | stride=1, 132 | padding=2) 133 | 134 | self.encoder0 = SkipBlock(64, 128) 135 | self.encoder1 = SkipBlock(128, 256) 136 | self.encoder2 = SkipBlock(256, 512) 137 | self.encoder3 = SkipBlock(512, 1024) 138 | 139 | self.decoder0_0 = ConvTransposeELUBlock( 140 | in_channels=1024, 141 | out_channels=512, 142 | kernel_size=4, 143 | stride=2, 144 | padding=1) 145 | self.decoder0_1 = ConvELUBlock( 146 | in_channels=512, 147 | out_channels=512, 148 | kernel_size=5, 149 | stride=1, 150 | padding=2) 151 | self.decoder1_0 = ConvTransposeELUBlock( 152 | in_channels=512, 153 | out_channels=256, 154 | kernel_size=4, 155 | stride=2, 156 | padding=1) 157 | self.decoder1_1 = ConvELUBlock( 158 | in_channels=256, 159 | out_channels=256, 160 | kernel_size=5, 161 | stride=1, 162 | padding=2) 163 | self.decoder2_0 = ConvTransposeELUBlock( 164 | in_channels=256, 165 | out_channels=128, 166 | kernel_size=4, 167 | stride=2, 168 | padding=1) 169 | self.decoder2_1 = ConvELUBlock( 170 | in_channels=128 + 1, 171 | out_channels=128, 172 | kernel_size=5, 173 | stride=1, 174 | padding=2) 175 | self.decoder3_0 = ConvTransposeELUBlock( 176 | in_channels=128, 177 | out_channels=64, 178 | kernel_size=4, 179 | stride=2, 180 | padding=1) 181 | self.decoder3_1 = ConvELUBlock( 182 | in_channels=64 + 1, 183 | out_channels=64, 184 | kernel_size=5, 185 | stride=1, 186 | padding=2) 187 | 188 | self.prediction0 = nn.Conv2d(256, 1, 3, padding=1) 189 | self.prediction1 = nn.Conv2d(128, 1, 3, padding=1) 190 | self.prediction2 = nn.Conv2d(64, 1, 3, padding=1) 191 | 192 | self.apply(xavier_init) 193 | 194 | def forward(self, x): 195 | # Encode down to 4x 196 | x = self.input0(x) 197 | x = self.input1(x) 198 | x = self.encoder0(x) 199 | x = self.encoder1(x) 200 | x = self.encoder2(x) 201 | x = self.encoder3(x) 202 | x = self.decoder0_0(x) 203 | x = self.decoder0_1(x) 204 | x = self.decoder1_0(x) 205 | x = self.decoder1_1(x) 206 | 207 | # Predict at 4x downsampled 208 | pred_4x = self.prediction0(x) 209 | 210 | # Upsample through convolution to 2x 211 | x = self.decoder2_0(x) 212 | upsampled_pred_4x = F.interpolate(pred_4x.detach(), scale_factor=2) 213 | 214 | # Predict at 2x downsampled 215 | x = self.decoder2_1(torch.cat((x, upsampled_pred_4x), 1)) 216 | pred_2x = self.prediction1(x) 217 | 218 | # Upsample through convolution to 1x 219 | x = self.decoder3_0(x) 220 | upsampled_pred_2x = F.interpolate(pred_2x.detach(), scale_factor=2) 221 | 222 | # Predict at 1x 223 | x = self.decoder3_1(torch.cat((x, upsampled_pred_2x), 1)) 224 | pred_1x = self.prediction2(x) 225 | 226 | return [pred_1x, pred_2x, pred_4x] 227 | 
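# ---------------------------------------------------------------------------
# Illustrative sketch (added for clarity, not part of the original file): a
# quick sanity check of the two RectNet output scales. The helper name and the
# 256x512 input size are assumptions; call it once the module has been fully
# imported, e.g. `python -c "import network; network._rectnet_shape_check()"`.
def _rectnet_shape_check(batch=1, height=256, width=512):
    '''Push a dummy equirectangular image through RectNet and report the
    resolutions of its full- and half-scale depth predictions.'''
    net = RectNet()
    with torch.no_grad():
        pred_1x, pred_2x = net(torch.randn(batch, 3, height, width))
    # Expected: pred_1x -> (batch, 1, height, width); pred_2x -> (batch, 1, height // 2, width // 2)
    print(pred_1x.shape, pred_2x.shape)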
228 | 229 | # ----------------------------------------------------------------------------- 230 | class ConvELUBlock(nn.Module): 231 | 232 | def __init__(self, 233 | in_channels, 234 | out_channels, 235 | kernel_size, 236 | stride=1, 237 | padding=0, 238 | dilation=1): 239 | super(ConvELUBlock, self).__init__() 240 | 241 | self.conv = nn.Conv2d( 242 | in_channels=in_channels, 243 | out_channels=out_channels, 244 | kernel_size=kernel_size, 245 | stride=stride, 246 | padding=padding, 247 | dilation=dilation) 248 | 249 | def forward(self, x): 250 | return F.elu(self.conv(x), inplace=True) 251 | 252 | 253 | # ----------------------------------------------------------------------------- 254 | class ConvTransposeELUBlock(nn.Module): 255 | 256 | def __init__(self, 257 | in_channels, 258 | out_channels, 259 | kernel_size, 260 | stride=1, 261 | padding=0, 262 | dilation=1): 263 | super(ConvTransposeELUBlock, self).__init__() 264 | 265 | self.conv = nn.ConvTranspose2d( 266 | in_channels=in_channels, 267 | out_channels=out_channels, 268 | kernel_size=kernel_size, 269 | stride=stride, 270 | padding=padding, 271 | dilation=dilation) 272 | 273 | def forward(self, x): 274 | return F.elu(self.conv(x), inplace=True) 275 | 276 | 277 | # ----------------------------------------------------------------------------- 278 | class SkipBlock(nn.Module): 279 | 280 | def __init__(self, in_channels, out_channels): 281 | super(SkipBlock, self).__init__() 282 | 283 | self.conv1 = ConvELUBlock( 284 | in_channels=in_channels, 285 | out_channels=out_channels, 286 | kernel_size=3, 287 | stride=2, 288 | padding=1) 289 | self.conv2 = ConvELUBlock( 290 | in_channels=out_channels, 291 | out_channels=out_channels, 292 | kernel_size=3, 293 | stride=1, 294 | padding=1) 295 | self.conv3 = ConvELUBlock( 296 | in_channels=out_channels, 297 | out_channels=out_channels, 298 | kernel_size=3, 299 | stride=1, 300 | padding=1) 301 | 302 | def forward(self, x): 303 | # First convolutional block 304 | out1 = self.conv1(x) 305 | 306 | # Second and third convolutional blocks 307 | out3 = self.conv3(self.conv2(out1)) 308 | 309 | # Return the sum of the outputs of the first block and the third block 310 | return out1 + out3 311 | -------------------------------------------------------------------------------- /omnidepth.yml: -------------------------------------------------------------------------------- 1 | name: omnidepth 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - blas=1.0=mkl 8 | - ca-certificates=2018.11.29=ha4d7672_0 9 | - certifi=2018.11.29=py37_1000 10 | - cffi=1.11.5=py37he75722e_1 11 | - cloudpickle=0.7.0=py_0 12 | - cudatoolkit=9.0=h13b8566_0 13 | - cycler=0.10.0=py_1 14 | - cytoolz=0.9.0.1=py37h14c3975_1001 15 | - dask-core=1.1.1=py_0 16 | - decorator=4.3.2=py_0 17 | - freetype=2.9.1=h8a8886c_1 18 | - icu=58.2=hf484d3e_1000 19 | - imageio=2.5.0=py37_0 20 | - intel-openmp=2019.1=144 21 | - jpeg=9b=h024ee3a_2 22 | - kiwisolver=1.0.1=py37h6bb024c_1002 23 | - libedit=3.1.20181209=hc058e9b_0 24 | - libffi=3.2.1=hd88cf55_4 25 | - libgcc-ng=8.2.0=hdf63c60_1 26 | - libgfortran-ng=7.3.0=hdf63c60_0 27 | - libpng=1.6.36=hbc83047_0 28 | - libstdcxx-ng=8.2.0=hdf63c60_1 29 | - libtiff=4.0.10=h2733197_2 30 | - matplotlib-base=3.0.2=py37h167e16e_1001 31 | - mkl=2019.1=144 32 | - mkl_fft=1.0.10=py37ha843d7b_0 33 | - mkl_random=1.0.2=py37hd81dba3_0 34 | - ncurses=6.1=he6710b0_1 35 | - networkx=2.2=py_1 36 | - ninja=1.8.2=py37h6bb024c_1 37 | - numpy=1.15.4=py37h7e9f1db_0 38 | - numpy-base=1.15.4=py37hde5b4d6_0 
39 | - olefile=0.46=py37_0 40 | - openssl=1.1.1a=h14c3975_1000 41 | - pillow=5.4.1=py37h34e0f95_0 42 | - pip=19.0.1=py37_0 43 | - pycparser=2.19=py37_0 44 | - pyparsing=2.3.1=py_0 45 | - python=3.7.2=h0371630_0 46 | - python-dateutil=2.8.0=py_0 47 | - pytorch=1.0.1=py3.7_cuda9.0.176_cudnn7.4.2_2 48 | - pywavelets=1.0.1=py37h3010b51_1000 49 | - readline=7.0=h7b6447c_5 50 | - scikit-image=0.14.2=py37hf484d3e_1 51 | - setuptools=40.8.0=py37_0 52 | - six=1.12.0=py37_0 53 | - sqlite=3.26.0=h7b6447c_0 54 | - tk=8.6.8=hbc83047_0 55 | - toolz=0.9.0=py_1 56 | - torchvision=0.2.1=py_2 57 | - wheel=0.32.3=py37_0 58 | - xz=5.2.4=h14c3975_4 59 | - zlib=1.2.11=h7b6447c_3 60 | - zstd=1.3.7=h0b5b093_0 61 | - pip: 62 | - chardet==3.0.4 63 | - idna==2.8 64 | - openexr==1.3.2 65 | - pyzmq==17.1.2 66 | - requests==2.21.0 67 | - scipy==1.2.1 68 | - torchfile==0.1.0 69 | - tornado==5.1.1 70 | - urllib3==1.24.1 71 | - visdom==0.1.8.8 72 | - websocket-client==0.54.0 73 | prefix: /data/anaconda3/envs/omnidepth 74 | 75 | -------------------------------------------------------------------------------- /omnidepth_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import time 4 | import datetime 5 | 6 | import math 7 | import shutil 8 | import os.path as osp 9 | 10 | import util 11 | from metrics import * 12 | 13 | 14 | # From https://github.com/fyu/drn 15 | class AverageMeter(object): 16 | """Computes and stores the average and current value""" 17 | 18 | def __init__(self): 19 | self.reset() 20 | 21 | def reset(self): 22 | self.val = 0 23 | self.avg = 0 24 | self.sum = 0 25 | self.count = 0 26 | 27 | def update(self, val, n=1): 28 | self.val = val 29 | self.sum += val * n 30 | self.count += n 31 | self.avg = self.sum / self.count 32 | 33 | def to_dict(self): 34 | return {'val': self.val, 35 | 'sum': self.sum, 36 | 'count': self.count, 37 | 'avg': self.avg} 38 | 39 | def from_dict(self, meter_dict): 40 | self.val = meter_dict['val'] 41 | self.sum = meter_dict['sum'] 42 | self.count = meter_dict['count'] 43 | self.avg = meter_dict['avg'] 44 | 45 | 46 | def visualize_rgb(rgb): 47 | # Scale back to [0,255] 48 | return (255 * rgb).byte() 49 | 50 | 51 | def visualize_mask(mask): 52 | '''Visualize the data mask''' 53 | mask /= mask.max() 54 | return (255 * mask).byte() 55 | 56 | 57 | class OmniDepthTrainer(object): 58 | 59 | def __init__( 60 | self, 61 | name, 62 | network, 63 | train_dataloader, 64 | val_dataloader, 65 | criterion, 66 | optimizer, 67 | checkpoint_dir, 68 | device, 69 | visdom=None, 70 | scheduler=None, 71 | num_epochs=20, 72 | validation_freq=1, 73 | visualization_freq=5, 74 | validation_sample_freq=-1, 75 | num_samples=0): 76 | 77 | # Name of this experiment 78 | self.name = name 79 | 80 | # Class instances 81 | self.network = network 82 | self.train_dataloader = train_dataloader 83 | self.val_dataloader = val_dataloader 84 | self.criterion = criterion 85 | self.optimizer = optimizer 86 | self.scheduler = scheduler 87 | 88 | # Training options 89 | self.num_epochs = num_epochs 90 | self.validation_freq = validation_freq 91 | self.visualization_freq = visualization_freq 92 | self.validation_sample_freq = validation_sample_freq 93 | 94 | # CUDA info 95 | self.device = device 96 | 97 | # Some timers 98 | self.batch_time_meter = AverageMeter() 99 | self.forward_time_meter = AverageMeter() 100 | self.backward_time_meter = AverageMeter() 101 | 102 | # Some trackers 103 | self.epoch = 0 104 | 105 | # Directory to 
store checkpoints 106 | self.checkpoint_dir = checkpoint_dir 107 | 108 | # Accuracy metric trackers 109 | self.abs_rel_error_meter = AverageMeter() 110 | self.sq_rel_error_meter = AverageMeter() 111 | self.lin_rms_sq_error_meter = AverageMeter() 112 | self.log_rms_sq_error_meter = AverageMeter() 113 | self.d1_inlier_meter = AverageMeter() 114 | self.d2_inlier_meter = AverageMeter() 115 | self.d3_inlier_meter = AverageMeter() 116 | 117 | # Track the best inlier ratio recorded so far 118 | self.best_d1_inlier = 0.0 119 | self.is_best = False 120 | 121 | # List of length 2 [Visdom instance, env] 122 | self.vis = visdom 123 | 124 | # Loss trackers 125 | self.loss = AverageMeter() 126 | self.i = 0 127 | self.num_samples = num_samples 128 | self.lin_rms_sq_error = 99999999999 129 | self.loss_error = 99999999999 130 | 131 | def forward_pass(self, inputs): 132 | ''' 133 | Accepts the inputs to the network as a Python list 134 | Returns the network output 135 | ''' 136 | return self.network(*inputs) 137 | 138 | def compute_loss(self, output, gt): 139 | ''' 140 | Returns the total loss 141 | ''' 142 | return self.criterion(output, gt[::2], gt[1::2]) 143 | 144 | def backward_pass(self, loss): 145 | # Computes the backward pass and updates the optimizer 146 | self.optimizer.zero_grad() 147 | loss.backward() 148 | self.optimizer.step() 149 | 150 | def train_one_epoch(self): 151 | # Put the model in train mode 152 | self.network = self.network.train() 153 | self.i = 0 154 | max_batch_num = len(self.train_dataloader) - 1 155 | end = time.time() 156 | self.loss.reset() 157 | # Load data 158 | for batch_num, data in enumerate(self.train_dataloader): 159 | # Parse the data into inputs, ground truth, and other 160 | inputs, gt, other = self.parse_data(data) 161 | 162 | # Run a forward pass 163 | forward_time = time.time() 164 | output = self.forward_pass(inputs) 165 | self.forward_time_meter.update(time.time() - forward_time) 166 | 167 | # Compute the loss(es) 168 | loss = self.compute_loss(output, gt) 169 | self.loss.update(loss) 170 | # Backpropagation of the total loss 171 | backward_time = time.time() 172 | self.backward_pass(loss) 173 | self.backward_time_meter.update(time.time() - backward_time) 174 | 175 | # Update batch times 176 | self.batch_time_meter.update(time.time() - end) 177 | end = time.time() 178 | # Every few batches 179 | if batch_num % self.visualization_freq == 0 or batch_num == max_batch_num: 180 | # Visualize the loss 181 | self.visualize_loss(batch_num, loss) 182 | self.visualize_samples(inputs, gt, other, output) 183 | self.print_progress(batch_num, max_batch_num, len(inputs[0]), loss) 184 | 185 | # Print the most recent batch report 186 | # self.print_batch_report(batch_num, loss) 187 | 188 | def validate(self): 189 | # print('Validating model....') 190 | 191 | # Put the model in eval mode 192 | self.network = self.network.eval() 193 | 194 | # Reset meter 195 | self.reset_eval_metrics() 196 | 197 | # Load data 198 | s = time.time() 199 | with torch.no_grad(): 200 | for batch_num, data in enumerate(self.val_dataloader): 201 | 202 | # Parse the data 203 | inputs, gt, other = self.parse_data(data) 204 | 205 | # Run a forward pass 206 | output = self.forward_pass(inputs) 207 | 208 | # Compute the evaluation metrics 209 | self.compute_eval_metrics(output, gt) 210 | 211 | # If trying to save intermediate outputs 212 | if self.validation_sample_freq >= 0: 213 | # Save the intermediate outputs 214 | if batch_num % self.validation_sample_freq == 0: 215 | self.save_samples(inputs, gt, other, 
output) 216 | 217 | # Print a report on the validation results 218 | # print('Validation finished in {} seconds'.format(time.time() - s)) 219 | # self.print_validation_report() 220 | 221 | def train(self, checkpoint_path=None, weights_only=False): 222 | print('Starting training') 223 | print(datetime.datetime.now()) 224 | start_time = datetime.datetime.now() 225 | # Load pretrained parameters if desired 226 | if checkpoint_path is not None: 227 | self.load_checkpoint(checkpoint_path, weights_only) 228 | if weights_only: 229 | self.initialize_visualizations() 230 | else: 231 | # Initialize any training visualizations 232 | self.initialize_visualizations() 233 | 234 | # Train for specified number of epochs 235 | for self.epoch in range(self.epoch, self.num_epochs): 236 | epoch_start_time = datetime.datetime.now() 237 | # Increment the LR scheduler 238 | if self.scheduler is not None: 239 | self.scheduler.step() 240 | # Run an epoch of training 241 | self.train_one_epoch() 242 | epoch_end_time = datetime.datetime.now() 243 | total_seconds = (epoch_end_time - epoch_start_time).seconds 244 | util.print_time('Epoch', total_seconds) 245 | if self.epoch % self.validation_freq == 0: 246 | self.validate() 247 | if self.lin_rms_sq_error_meter.avg <= self.lin_rms_sq_error and self.loss.avg <= self.loss_error: 248 | self.save_checkpoint() 249 | self.lin_rms_sq_error = self.lin_rms_sq_error_meter.avg 250 | self.loss_error = self.loss.avg 251 | self.visualize_metrics() 252 | end_time = datetime.datetime.now() 253 | seconds = (end_time - start_time).seconds 254 | util.print_time('Training', seconds) 255 | 256 | def evaluate(self, checkpoint_path): 257 | print('Evaluating model....') 258 | 259 | # Load the checkpoint to evaluate 260 | self.load_checkpoint(checkpoint_path, True, True) 261 | 262 | # Put the model in eval mode 263 | self.network = self.network.eval() 264 | 265 | # Reset meter 266 | self.reset_eval_metrics() 267 | 268 | # Load data 269 | s = time.time() 270 | with torch.no_grad(): 271 | for batch_num, data in enumerate(self.val_dataloader): 272 | 273 | # Parse the data 274 | inputs, gt, other = self.parse_data(data) 275 | 276 | # Run a forward pass 277 | output = self.forward_pass(inputs) 278 | 279 | # Compute the evaluation metrics 280 | self.compute_eval_metrics(output, gt) 281 | 282 | # If trying to save intermediate outputs 283 | if self.validation_sample_freq >= 0: 284 | # Save the intermediate outputs 285 | if batch_num % self.validation_sample_freq == 0: 286 | self.save_samples(inputs, gt, other, output) 287 | 288 | # Print a report on the validation results 289 | print('Evaluation finished in {} seconds'.format(time.time() - s)) 290 | self.print_validation_report() 291 | 292 | def predict(self, checkpoint_path): 293 | print('Prediction ....') 294 | 295 | # Load the checkpoint to evaluate 296 | self.load_checkpoint(checkpoint_path, True, True) 297 | 298 | # Put the model in eval mode 299 | self.network = self.network.eval() 300 | 301 | # Reset meter 302 | self.reset_eval_metrics() 303 | 304 | # Load data 305 | s = time.time() 306 | with torch.no_grad(): 307 | for batch_num, data in enumerate(self.val_dataloader): 308 | # Parse the data 309 | inputs, gt, other = self.parse_data(data) 310 | 311 | # Run a forward pass 312 | output = self.forward_pass(inputs) 313 | 314 | # Compute the evaluation metrics 315 | self.compute_eval_metrics(output, gt) 316 | 317 | # visu 318 | self.visualize_download_output(inputs, gt, output) 319 | 320 | # Print a report on the validation results 321 | 
print('Prediction finished in {} seconds'.format(time.time() - s)) 322 | self.print_validation_report() 323 | 324 | def filter_loss_gen(self, checkpoint_path, threshold, to): 325 | print('Filter ....') 326 | # Load the checkpoint to evaluate 327 | self.load_checkpoint(checkpoint_path, True, True) 328 | 329 | # Load data 330 | s = time.time() 331 | cnt = 0 332 | for batch_num, data in enumerate(self.val_dataloader): 333 | # Parse the data 334 | inputs, gt, common_path = self.parse_data(data) 335 | rgb_img_path = common_path[0] + '.jpeg' 336 | depth_img_path = common_path[0] + '_d.jpeg' 337 | # Run a forward pass 338 | output = self.forward_pass(inputs) 339 | 340 | loss = self.compute_loss(output, gt) 341 | if loss > threshold: 342 | shutil.move(rgb_img_path, to + osp.basename(rgb_img_path)) 343 | shutil.move(depth_img_path, to + osp.basename(depth_img_path)) 344 | cnt += 1 345 | print('img loss > {} has num: {}'.format(threshold, cnt)) 346 | 347 | # Print a report on the validation results 348 | print('Prediction finished in {} seconds'.format(time.time() - s)) 349 | 350 | def parse_data(self, data): 351 | ''' 352 | Returns a list of the inputs as first output, a list of the GT as a second output, and a list of the remaining info as a third output. Must be implemented. 353 | ''' 354 | rgb = data[0].to(self.device) 355 | gt_depth_1x = data[1].to(self.device) 356 | gt_depth_2x = F.interpolate(gt_depth_1x, scale_factor=0.5) 357 | gt_depth_4x = F.interpolate(gt_depth_1x, scale_factor=0.25) 358 | mask_1x = data[2].to(self.device) 359 | mask_2x = F.interpolate(mask_1x, scale_factor=0.5) 360 | mask_4x = F.interpolate(mask_1x, scale_factor=0.25) 361 | 362 | inputs = [rgb] 363 | gt = [gt_depth_1x, mask_1x, gt_depth_2x, mask_2x, gt_depth_4x, mask_4x] 364 | other = data[3] 365 | 366 | return inputs, gt, other 367 | 368 | def reset_eval_metrics(self): 369 | ''' 370 | Resets metrics used to evaluate the model 371 | ''' 372 | self.abs_rel_error_meter.reset() 373 | self.sq_rel_error_meter.reset() 374 | self.lin_rms_sq_error_meter.reset() 375 | self.log_rms_sq_error_meter.reset() 376 | self.d1_inlier_meter.reset() 377 | self.d2_inlier_meter.reset() 378 | self.d3_inlier_meter.reset() 379 | self.is_best = False 380 | 381 | def compute_eval_metrics(self, output, gt): 382 | ''' 383 | Computes metrics used to evaluate the model 384 | ''' 385 | depth_pred = output[0] 386 | gt_depth = gt[0] 387 | depth_mask = gt[1] 388 | 389 | N = depth_mask.sum() 390 | 391 | # Align the prediction scales via median 392 | median_scaling_factor = gt_depth[depth_mask > 0].median() / depth_pred[depth_mask > 0].median() 393 | depth_pred *= median_scaling_factor 394 | 395 | abs_rel = abs_rel_error(depth_pred, gt_depth, depth_mask) 396 | sq_rel = sq_rel_error(depth_pred, gt_depth, depth_mask) 397 | rms_sq_lin = lin_rms_sq_error(depth_pred, gt_depth, depth_mask) 398 | rms_sq_log = log_rms_sq_error(depth_pred, gt_depth, depth_mask) 399 | d1 = delta_inlier_ratio(depth_pred, gt_depth, depth_mask, degree=1) 400 | d2 = delta_inlier_ratio(depth_pred, gt_depth, depth_mask, degree=2) 401 | d3 = delta_inlier_ratio(depth_pred, gt_depth, depth_mask, degree=3) 402 | 403 | self.abs_rel_error_meter.update(abs_rel, N) 404 | self.sq_rel_error_meter.update(sq_rel, N) 405 | self.lin_rms_sq_error_meter.update(rms_sq_lin, N) 406 | self.log_rms_sq_error_meter.update(rms_sq_log, N) 407 | self.d1_inlier_meter.update(d1, N) 408 | self.d2_inlier_meter.update(d2, N) 409 | self.d3_inlier_meter.update(d3, N) 410 | 411 | def load_checkpoint(self, 
checkpoint_path=None, weights_only=False, eval_mode=False): 412 | ''' 413 | Initializes network with pretrained parameters 414 | ''' 415 | if checkpoint_path is not None: 416 | print('Loading checkpoint \'{}\''.format(checkpoint_path)) 417 | checkpoint = torch.load(checkpoint_path) 418 | 419 | # If we want to continue training where we left off, load entire training state 420 | if not weights_only: 421 | self.epoch = checkpoint['epoch'] 422 | experiment_name = checkpoint['experiment'] 423 | self.vis[1] = experiment_name 424 | self.best_d1_inlier = checkpoint['best_d1_inlier'] 425 | self.loss.from_dict(checkpoint['loss_meter']) 426 | else: 427 | print('NOTE: Loading weights only') 428 | 429 | # Load the optimizer and model state 430 | if not eval_mode: 431 | util.load_optimizer( 432 | self.optimizer, 433 | checkpoint['optimizer'], 434 | self.device) 435 | util.load_partial_model( 436 | self.network, 437 | checkpoint['state_dict']) 438 | 439 | print('Loaded checkpoint \'{}\' (epoch {})'.format( 440 | checkpoint_path, checkpoint['epoch'])) 441 | else: 442 | print('WARNING: No checkpoint found') 443 | 444 | def initialize_visualizations(self): 445 | ''' 446 | Initializes visualizations 447 | ''' 448 | 449 | self.vis[0].line( 450 | env=self.vis[1], 451 | X=torch.zeros(1, 1).long(), 452 | Y=torch.zeros(1, 1).float(), 453 | win='losses', 454 | opts=dict( 455 | title='Loss Plot', 456 | markers=False, 457 | xlabel='Iteration', 458 | ylabel='Loss', 459 | legend=['Total Loss'])) 460 | 461 | self.vis[0].line( 462 | env=self.vis[1], 463 | X=torch.zeros(1, 4).long(), 464 | Y=torch.zeros(1, 4).float(), 465 | win='error_metrics', 466 | opts=dict( 467 | title='Depth Error Metrics', 468 | markers=True, 469 | xlabel='Epoch', 470 | ylabel='Error', 471 | legend=['Abs. Rel. Error', 'Sq. Rel. 
Error', 'Linear RMS Error', 'Log RMS Error'])) 472 | 473 | self.vis[0].line( 474 | env=self.vis[1], 475 | X=torch.zeros(1, 3).long(), 476 | Y=torch.zeros(1, 3).float(), 477 | win='inlier_metrics', 478 | opts=dict( 479 | title='Depth Inlier Metrics', 480 | markers=True, 481 | xlabel='Epoch', 482 | ylabel='Percent', 483 | legend=['d1', 'd2', 'd3'])) 484 | 485 | def visualize_loss(self, batch_num, loss): 486 | ''' 487 | Updates the loss visualization 488 | ''' 489 | total_num_batches = self.epoch * len(self.train_dataloader) + batch_num 490 | self.vis[0].line( 491 | env=self.vis[1], 492 | X=torch.tensor([total_num_batches]), 493 | Y=torch.tensor([loss]), 494 | win='losses', 495 | update='append', 496 | opts=dict( 497 | legend=['Total Loss'])) 498 | 499 | def visualize_samples(self, inputs, gt, other, output): 500 | ''' 501 | Updates the output samples visualization 502 | ''' 503 | rgb = inputs[0][0].cpu() 504 | depth_pred = output[0][0].cpu() 505 | gt_depth = gt[0][0].cpu() 506 | depth_mask = gt[1][0].cpu() 507 | 508 | self.vis[0].image( 509 | visualize_rgb(rgb), 510 | env=self.vis[1], 511 | win='rgb', 512 | opts=dict( 513 | title='Input RGB Image', 514 | caption='Input RGB Image')) 515 | 516 | self.vis[0].image( 517 | visualize_mask(depth_mask), 518 | env=self.vis[1], 519 | win='mask', 520 | opts=dict( 521 | title='Mask', 522 | caption='Mask')) 523 | 524 | max_depth = max( 525 | ((depth_mask > 0).float() * gt_depth).max().item(), 526 | ((depth_mask > 0).float() * depth_pred).max().item()) 527 | self.vis[0].heatmap( 528 | depth_pred.squeeze().flip(-2), 529 | env=self.vis[1], 530 | win='depth_pred', 531 | opts=dict( 532 | title='Depth Prediction', 533 | caption='Depth Prediction', 534 | xmax=max_depth, 535 | xmin=gt_depth.min().item())) 536 | 537 | self.vis[0].heatmap( 538 | gt_depth.squeeze().flip(-2), 539 | env=self.vis[1], 540 | win='gt_depth', 541 | opts=dict( 542 | title='Depth GT', 543 | caption='Depth GT', 544 | xmax=max_depth)) 545 | 546 | def visualize_download_output(self, inputs, gt, output): 547 | ''' 548 | Updates the output samples visualization 549 | ''' 550 | rgb = inputs[0][0].cpu() 551 | depth_pred = output[0][0].cpu() 552 | gt_depth = gt[0][0].cpu() 553 | depth_mask = gt[1][0].cpu() 554 | 555 | self.vis[0].image( 556 | visualize_rgb(rgb), 557 | env=self.vis[1], 558 | win='rgb', 559 | opts=dict( 560 | title='Input RGB Image', 561 | caption='Input RGB Image')) 562 | 563 | max_depth = max( 564 | ((depth_mask > 0).float() * gt_depth).max().item(), 565 | ((depth_mask > 0).float() * depth_pred).max().item()) 566 | self.vis[0].image( 567 | depth_pred.squeeze(), 568 | env=self.vis[1], 569 | win='depth_pred', 570 | opts=dict( 571 | title='Depth Prediction', 572 | caption='Depth Prediction', 573 | xmax=max_depth, 574 | xmin=gt_depth.min().item())) 575 | 576 | self.vis[0].image( 577 | gt_depth.squeeze(), 578 | env=self.vis[1], 579 | win='gt_depth', 580 | opts=dict( 581 | title='Depth GT', 582 | caption='Depth GT', 583 | xmax=max_depth)) 584 | 585 | def visualize_metrics(self): 586 | ''' 587 | Updates the metrics visualization 588 | ''' 589 | abs_rel = self.abs_rel_error_meter.avg 590 | sq_rel = self.sq_rel_error_meter.avg 591 | lin_rms = math.sqrt(self.lin_rms_sq_error_meter.avg) 592 | log_rms = math.sqrt(self.log_rms_sq_error_meter.avg) 593 | d1 = self.d1_inlier_meter.avg 594 | d2 = self.d2_inlier_meter.avg 595 | d3 = self.d3_inlier_meter.avg 596 | 597 | errors = torch.FloatTensor([abs_rel, sq_rel, lin_rms, log_rms]) 598 | errors = errors.view(1, -1) 599 | epoch_expanded = 
torch.ones(errors.shape) * (self.epoch + 1)
600 | self.vis[0].line(
601 | env=self.vis[1],
602 | X=epoch_expanded,
603 | Y=errors,
604 | win='error_metrics',
605 | update='append',
606 | opts=dict(
607 | legend=['Abs. Rel. Error', 'Sq. Rel. Error', 'Linear RMS Error', 'Log RMS Error']))
608 | 
609 | inliers = torch.FloatTensor([d1, d2, d3])
610 | inliers = inliers.view(1, -1)
611 | epoch_expanded = torch.ones(inliers.shape) * (self.epoch + 1)
612 | self.vis[0].line(
613 | env=self.vis[1],
614 | X=epoch_expanded,
615 | Y=inliers,
616 | win='inlier_metrics',
617 | update='append',
618 | opts=dict(
619 | legend=['d1', 'd2', 'd3']))
620 | 
621 | def print_progress(self, batch_num, max_batch_num, batch_size, loss, max_arrow=50):
622 | self.i = (batch_num + 1) * batch_size if batch_num < max_batch_num else self.num_samples
623 | num_arrow = int(self.i * max_arrow / self.num_samples)  # number of '>' characters to display
624 | num_line = max_arrow - num_arrow  # number of '-' characters to display
625 | percent = self.i * 100.0 / self.num_samples  # completion percentage, formatted as xx.xx%
626 | process_bar = '[' + '>' * num_arrow + '-' * num_line + ']' \
627 | + '%.2f' % percent + '%' + '\r'  # progress string; '\r' returns the cursor to the start of the line without a newline
628 | if batch_num == max_batch_num:
629 | loss = self.loss.avg
630 | print('\r',
631 | 'Epoch: [{0}][{1}/{2}]'.format(self.epoch + 1, batch_num + 1,
632 | len(self.train_dataloader)) + ' - Loss: %.3f ' % loss + process_bar,
633 | end='')  # print the progress bar to the terminal, overwriting the previous line
634 | if self.i >= self.num_samples:
635 | print('')
636 | 
637 | def print_batch_report(self, batch_num, loss):
638 | '''
639 | Prints a report of the current batch
640 | '''
641 | print('Epoch: [{0}][{1}/{2}]\t'
642 | 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\n'
643 | 'Forward Time {forward_time.val:.3f} ({forward_time.avg:.3f})\t'
644 | 'Backward Time {backward_time.val:.3f} ({backward_time.avg:.3f})\n'
645 | 'Loss {loss:.4f} ({loss:.4f})\n\n'.format(
646 | self.epoch + 1,
647 | batch_num + 1,
648 | len(self.train_dataloader),
649 | batch_time=self.batch_time_meter,
650 | forward_time=self.forward_time_meter,
651 | backward_time=self.backward_time_meter,
652 | loss=loss))
653 | 
654 | def print_validation_report(self):
655 | '''
656 | Prints a report of the validation results
657 | '''
658 | print('Epoch: {}\n'
659 | ' Avg. Abs. Rel. Error: {:.4f}\n'
660 | ' Avg. Sq. Rel. Error: {:.4f}\n'
661 | ' Avg. Lin. RMS Error: {:.4f}\n'
662 | ' Avg. 
Log RMS Error: {:.4f}\n' 663 | ' Inlier D1: {:.4f}\n' 664 | ' Inlier D2: {:.4f}\n' 665 | ' Inlier D3: {:.4f}\n\n'.format( 666 | self.epoch + 1, 667 | self.abs_rel_error_meter.avg, 668 | self.sq_rel_error_meter.avg, 669 | math.sqrt(self.lin_rms_sq_error_meter.avg), 670 | math.sqrt(self.log_rms_sq_error_meter.avg), 671 | self.d1_inlier_meter.avg, 672 | self.d2_inlier_meter.avg, 673 | self.d3_inlier_meter.avg)) 674 | 675 | # Also update the best state tracker 676 | if self.best_d1_inlier < self.d1_inlier_meter.avg: 677 | self.best_d1_inlier = self.d1_inlier_meter.avg 678 | self.is_best = True 679 | 680 | def save_checkpoint(self): 681 | ''' 682 | Saves the model state 683 | ''' 684 | # Save latest checkpoint (constantly overwriting itself) 685 | checkpoint_path = osp.join( 686 | self.checkpoint_dir, 687 | 'epoch_latest.pth') 688 | 689 | # Actually saves the latest checkpoint and also updating the file holding the best one 690 | util.save_checkpoint( 691 | { 692 | 'epoch': self.epoch + 1, 693 | 'experiment': self.name, 694 | 'state_dict': self.network.state_dict(), 695 | 'optimizer': self.optimizer.state_dict(), 696 | 'loss_meter': self.loss.to_dict(), 697 | 'best_d1_inlier': self.best_d1_inlier 698 | }, 699 | self.is_best, 700 | filename=checkpoint_path) 701 | 702 | # Copies the latest checkpoint to another file stored for each epoch 703 | history_path = osp.join( 704 | self.checkpoint_dir, 705 | 'epoch{}_{:.0f}_{:.0f}.pth'.format((self.epoch + 1), math.sqrt(self.lin_rms_sq_error_meter.avg), 706 | math.sqrt(self.loss.avg))) 707 | shutil.copyfile(checkpoint_path, history_path) 708 | print('Checkpoint saved') 709 | 710 | def save_samples(self, inputs, gt, other, outputs): 711 | ''' 712 | Saves samples of the network inputs and outputs 713 | ''' 714 | pass 715 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import cv2 4 | from PIL import Image 5 | import os 6 | import re 7 | import shutil 8 | 9 | 10 | # imgs_path = './data/test/' 11 | # output_path = './data/panoimgs/result.jpg' 12 | # img_list = [imgs_path + str(num) + '.jpeg' for num in range(1, 3)] 13 | # depth_json_path = './data/test/U4qAucYgiXcf4P3kDmXxiQ.json' 14 | 15 | 16 | def stitch_pano_overlap(img_list, output_path): 17 | imgs = [] 18 | for img_name in img_list: 19 | img = cv2.imread(img_name) 20 | if img is None: 21 | print("can't read image " + img_name) 22 | sys.exit(-1) 23 | imgs.append(img) 24 | stitcher = cv2.Stitcher_create(cv2.Stitcher_PANORAMA) 25 | status, pano = stitcher.stitch(imgs) 26 | 27 | if status != cv2.Stitcher_OK: 28 | print("Can't stitch images, error code = %d" % status) 29 | sys.exit(-1) 30 | cv2.imwrite(output_path, pano) 31 | print("stitching completed successfully. %s saved!" 
% output_path) 32 | 33 | 34 | def stitch_pano(img_list, output_path): 35 | imgs = [] 36 | imgs_size = [] 37 | for item in img_list: 38 | img = Image.open(item) 39 | imgs.append(img) 40 | imgs_size.append(img.size) 41 | new_size = np.sum(imgs_size, axis=0) 42 | joint = Image.new('RGB', (new_size[0], imgs_size[0][1])) 43 | loc = [] 44 | x = 0 45 | loc.append((x, 0)) 46 | for item in imgs_size: 47 | x += list(item)[0] 48 | loc.append((x, 0)) 49 | for i, img in enumerate(imgs): 50 | joint.paste(img, loc[i]) 51 | joint.save(output_path) 52 | 53 | 54 | def pano_resize(imgs_path): 55 | img_cnt = 0 56 | for img_name in os.listdir(imgs_path): 57 | if not re.match(r'.+d\.jpeg', img_name): 58 | img_cnt += 1 59 | img_path = imgs_path + img_name 60 | img = Image.open(img_path) 61 | (x, y) = img.size 62 | x_new = 512 63 | y_new = int(y * x_new / x) 64 | out = img.resize((x_new, y_new), Image.ANTIALIAS) 65 | out.save(img_path) 66 | 67 | print("resized imgs count:", img_cnt) 68 | 69 | 70 | def delete_nopair_imgs(imgs_path): 71 | img_list = [] 72 | del_cnt = 0 73 | for img_name in os.listdir(imgs_path): 74 | if re.match(r'.+d\.jpeg', img_name): 75 | prefix = img_name[:-7] 76 | else: 77 | prefix = img_name[:-5] 78 | 79 | if prefix in img_list: 80 | img_list.remove(prefix) 81 | else: 82 | img_list.append(prefix) 83 | num_wrong_imgs = len(img_list) 84 | print('find no pair or repeat imgs num: ', num_wrong_imgs) 85 | for prefix in img_list: 86 | rgb_img = imgs_path + prefix + ".jpeg" 87 | depth_img = imgs_path + prefix + "_d.jpeg" 88 | if os.path.exists(rgb_img): 89 | os.remove(rgb_img) 90 | del_cnt += 1 91 | if os.path.exists(depth_img): 92 | os.remove(depth_img) 93 | del_cnt += 1 94 | print('delete no pair or repeat imgs num: ', del_cnt) 95 | 96 | 97 | def copy_n_imgs(num, origin, to): 98 | cnt = 0 99 | for img_name in os.listdir(origin): 100 | if not re.match(r'.+d\.jpeg', img_name): 101 | img = Image.open(origin + img_name) 102 | img.save(to + img_name) 103 | cnt += 1 104 | if cnt >= num: 105 | break 106 | print('copy {} imgs from {} to {}'.format(cnt, origin, to)) 107 | 108 | 109 | def copy_n_imgpairs(num, origin, to): 110 | cnt = 0 111 | img_list = [] 112 | for img_name in os.listdir(origin): 113 | if re.match(r'.+d\.jpeg', img_name): 114 | prefix = img_name[:-7] 115 | else: 116 | prefix = img_name[:-5] 117 | 118 | if prefix not in img_list: 119 | img_list.append(prefix) 120 | cnt += 1 121 | if cnt >= num: 122 | break 123 | for prefix in img_list: 124 | shutil.move(origin + prefix + '.jpeg', to + prefix + '.jpeg') 125 | shutil.move(origin + prefix + '_d.jpeg', to + prefix + '_d.jpeg') 126 | print('copy {} img pairs'.format(cnt)) 127 | 128 | 129 | def filter__nopair_imgs(origin, to): 130 | img_list = [] 131 | for img_name in os.listdir(origin): 132 | if re.match(r'.+d\.jpeg', img_name): 133 | prefix = img_name[:-7] 134 | else: 135 | prefix = img_name[:-5] 136 | 137 | if prefix in img_list: 138 | img_list.remove(prefix) 139 | else: 140 | img_list.append(prefix) 141 | 142 | for prefix in img_list: 143 | shutil.move(origin + prefix + '_d.jpeg', to + prefix + '_d.jpeg') 144 | print('copy {} img pairs'.format(len(img_list))) 145 | 146 | 147 | def move_files(origin, to): 148 | cnt = 0 149 | for file in os.listdir(origin): 150 | if re.match(r'.+d\.jpeg', file): 151 | shutil.move(origin + file, to + file) 152 | cnt += 1 153 | print('total move files num: %d' % cnt) 154 | 155 | 156 | if __name__ == '__main__': 157 | imgs_path = '/home/nowburn/python_projects/cv/OmniDepth/show/' 158 | origin = 
'/home/nowburn/python_projects/cv/OmniDepth/data/training/' 159 | to = '/home/nowburn/python_projects/cv/OmniDepth/data/tmp/' 160 | # copy_n_imgs(100, imgs_path, to) 161 | # copy_n_imgs(2000, origin, to) 162 | # copy_n_imgpairs(674, origin, to) 163 | #filter__nopair_imgs(origin, to) 164 | pano_resize(imgs_path) 165 | 166 | 167 | -------------------------------------------------------------------------------- /show/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nowburn/OmniDepth-Pytorch/c6acaa2dc095882a2de40fcfda77f80891406f0e/show/1.png -------------------------------------------------------------------------------- /show/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nowburn/OmniDepth-Pytorch/c6acaa2dc095882a2de40fcfda77f80891406f0e/show/2.png -------------------------------------------------------------------------------- /show/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nowburn/OmniDepth-Pytorch/c6acaa2dc095882a2de40fcfda77f80891406f0e/show/3.png -------------------------------------------------------------------------------- /show/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nowburn/OmniDepth-Pytorch/c6acaa2dc095882a2de40fcfda77f80891406f0e/show/4.png -------------------------------------------------------------------------------- /test_omnidepth.py: -------------------------------------------------------------------------------- 1 | import visdom 2 | 3 | from omnidepth_trainer import OmniDepthTrainer 4 | from network import * 5 | from dataset import * 6 | 7 | import os.path as osp 8 | from criteria import * 9 | 10 | # -------------- 11 | # PARAMETERS 12 | # -------------- 13 | network_type = 'RectNet' # 'RectNet' or 'UResNet' 14 | experiment_name = 'omnidepth' 15 | val_file_list = './data/show/test/' # List of evaluation files 16 | checkpoint_dir = osp.join('experiments', experiment_name) 17 | checkpoint_path = None 18 | checkpoint_path = osp.join(checkpoint_dir, 'epoch_latest.pth') 19 | batch_size = 1 20 | num_workers = 8 21 | validation_sample_freq = -1 22 | device_ids = [0] 23 | 24 | # ------------------------------------------------------- 25 | # Fill in the rest 26 | vis = visdom.Visdom() 27 | env = 'predict' 28 | device = torch.device('cuda', device_ids[0]) 29 | 30 | # UResNet 31 | if network_type == 'UResNet': 32 | model = UResNet() 33 | alpha_list = [0.445, 0.275, 0.13] 34 | beta_list = [0.15, 0., 0.] 
35 | # RectNet
36 | elif network_type == 'RectNet':
37 | model = RectNet()
38 | alpha_list = [0.535, 0.272]
39 | beta_list = [0.134, 0.068, ]
40 | else:
41 | assert False, 'Unsupported network type'
42 | 
43 | criterion = MultiScaleL2Loss(alpha_list, beta_list)
44 | # -------------------------------------------------------
45 | # Set up the training routine
46 | network = nn.DataParallel(
47 | model.float(),
48 | device_ids=device_ids).to(device)
49 | 
50 | val_dataloader = torch.utils.data.DataLoader(
51 | dataset=OmniDepthDataset(val_file_list),
52 | batch_size=batch_size,
53 | shuffle=False,
54 | num_workers=num_workers,
55 | drop_last=False)
56 | 
57 | trainer = OmniDepthTrainer(
58 | experiment_name,
59 | network,
60 | None,
61 | val_dataloader,
62 | criterion,
63 | None,
64 | checkpoint_dir,
65 | device,
66 | visdom=[vis, env],
67 | validation_sample_freq=validation_sample_freq)
68 | 
69 | trainer.predict(checkpoint_path)
70 | 
71 | 
--------------------------------------------------------------------------------
/train_omnidepth.py:
--------------------------------------------------------------------------------
1 | import visdom
2 | 
3 | from omnidepth_trainer import OmniDepthTrainer
4 | from network import *
5 | from criteria import *
6 | from dataset import *
7 | from util import mkdirs, set_caffe_param_mult
8 | 
9 | import os.path as osp
10 | 
11 | # --------------
12 | # PARAMETERS
13 | # --------------
14 | network_type = 'RectNet'  # 'RectNet' or 'UResNet'
15 | experiment_name = 'omnidepth'
16 | train_file_list = './data/training/'  # Directory containing the training image pairs
17 | val_file_list = './data/validation/'  # Directory containing the validation image pairs
18 | max_steps_file = './data/num_samples.txt'  # File storing the number of training samples, used by the progress bar
19 | checkpoint_dir = osp.join('experiments', experiment_name)
20 | checkpoint_path = None
21 | checkpoint_path = osp.join(checkpoint_dir, 'epoch_latest.pth')
22 | load_weights_only = True
23 | batch_size = 10
24 | num_workers = 10
25 | lr = 2e-4
26 | step_size = 3
27 | lr_decay = 0.5
28 | num_epochs = 9999
29 | validation_freq = 1
30 | visualization_freq = 5
31 | validation_sample_freq = -1
32 | device_ids = [0]
33 | num_samples = 0
34 | 
35 | # -------------------------------------------------------
36 | # Fill in the rest
37 | vis = visdom.Visdom()
38 | env = experiment_name
39 | device = torch.device('cuda', device_ids[0])
40 | 
41 | # UResNet
42 | if network_type == 'UResNet':
43 | model = UResNet()
44 | alpha_list = [0.445, 0.275, 0.13]
45 | beta_list = [0.15, 0., 0.]
46 | # RectNet 47 | elif network_type == 'RectNet': 48 | model = RectNet() 49 | alpha_list = [0.535, 0.272] 50 | beta_list = [0.134, 0.068, ] 51 | else: 52 | assert False, 'Unsupported network type' 53 | 54 | # Make the checkpoint directory 55 | mkdirs(checkpoint_dir) 56 | 57 | # Read_Write max_steps 58 | if load_weights_only: 59 | n = 0 60 | for img_name in os.listdir(train_file_list): 61 | if re.match(r'.+d\.jpeg', img_name): 62 | n += 1 63 | num_samples = n 64 | with open(max_steps_file, 'w') as f: 65 | f.write(str(num_samples)) 66 | else: 67 | with open(max_steps_file, 'r') as f: 68 | num_samples = int(f.read()) 69 | 70 | # ------------------------------------------------------- 71 | # Set up the training routine 72 | network = nn.DataParallel( 73 | model.float(), 74 | device_ids=device_ids).to(device) 75 | 76 | train_dataloader = torch.utils.data.DataLoader( 77 | dataset=OmniDepthDataset(train_file_list), 78 | batch_size=batch_size, 79 | shuffle=True, 80 | num_workers=num_workers, 81 | drop_last=False) 82 | 83 | val_dataloader = torch.utils.data.DataLoader( 84 | dataset=OmniDepthDataset(val_file_list), 85 | batch_size=batch_size, 86 | shuffle=False, 87 | num_workers=num_workers, 88 | drop_last=False) 89 | 90 | criterion = MultiScaleL2Loss(alpha_list, beta_list) 91 | 92 | # Set up network parameters with Caffe-like LR multipliers 93 | param_list = set_caffe_param_mult(network, lr, 0) 94 | optimizer = torch.optim.Adam( 95 | params=param_list, 96 | lr=lr) 97 | 98 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 99 | step_size=step_size, 100 | gamma=lr_decay) 101 | 102 | trainer = OmniDepthTrainer( 103 | experiment_name, 104 | network, 105 | train_dataloader, 106 | val_dataloader, 107 | criterion, 108 | optimizer, 109 | checkpoint_dir, 110 | device, 111 | visdom=[vis, env], 112 | scheduler=scheduler, 113 | num_epochs=num_epochs, 114 | validation_freq=validation_freq, 115 | visualization_freq=visualization_freq, 116 | validation_sample_freq=validation_sample_freq, 117 | num_samples=num_samples) 118 | 119 | trainer.train(checkpoint_path, load_weights_only) 120 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import numpy as np 5 | import Imath, array 6 | 7 | import os 8 | import os.path as osp 9 | import shutil 10 | 11 | 12 | def mkdirs(path): 13 | '''Convenience function to make all intermediate folders in creating a directory''' 14 | try: 15 | os.makedirs(path) 16 | except: 17 | pass 18 | 19 | 20 | def xavier_init(m): 21 | '''Provides Xavier initialization for the network weights and 22 | normally distributes batch norm params''' 23 | classname = m.__class__.__name__ 24 | if (classname.find('Conv2d') != -1) or (classname.find('ConvTranspose2d') != -1): 25 | nn.init.xavier_normal_(m.weight.data) 26 | m.bias.data.fill_(0) 27 | 28 | 29 | def save_checkpoint(state, is_best, filename): 30 | '''Saves a training checkpoints''' 31 | torch.save(state, filename) 32 | if is_best: 33 | basename = osp.basename(filename) # File basename 34 | idx = filename.find(basename) # Index where path ends and basename begins 35 | # Copy the file to a different filename in the same directory 36 | shutil.copyfile(filename, osp.join(filename[:idx], 'model_best.pth')) 37 | 38 | 39 | def load_partial_model(model, loaded_state_dict): 40 | '''Loaded a save model, even if the model is not a perfect match. 
This will run even if there are layers in the current network that are missing from the saved model.
41 | However, only layers whose names match are loaded; everything else keeps its current values.'''
42 | model_dict = model.state_dict()
43 | pretrained_dict = {k: v for k, v in loaded_state_dict.items() if k in model_dict}
44 | model_dict.update(pretrained_dict)
45 | model.load_state_dict(model_dict)
46 | 
47 | 
48 | def load_optimizer(optimizer, loaded_optimizer_dict, device):
49 | '''Loads the saved state of the optimizer and puts it back on the GPU if necessary. Similar to loading the partial model, this will load only the optimization parameters that match the current parameterization.'''
50 | optimizer_dict = optimizer.state_dict()
51 | pretrained_dict = {k: v for k, v in loaded_optimizer_dict.items()
52 | if k in optimizer_dict and k != 'param_groups'}
53 | optimizer_dict.update(pretrained_dict)
54 | optimizer.load_state_dict(optimizer_dict)
55 | for state in optimizer.state.values():
56 | for k, v in state.items():
57 | if torch.is_tensor(v):
58 | state[k] = v.to(device)
59 | 
60 | 
61 | def set_caffe_param_mult(m, base_lr, base_weight_decay):
62 | '''Function that allows us to assign a LR multiplier of 2 and a decay multiplier of 0 to the bias weights (which is common in Caffe)'''
63 | param_list = []
64 | for name, params in m.named_parameters():
65 | if name.find('bias') != -1:
66 | param_list.append({'params': params, 'lr': 2 * base_lr, 'weight_decay': 0.0})
67 | else:
68 | param_list.append({'params': params, 'lr': base_lr, 'weight_decay': base_weight_decay})
69 | return param_list
70 | 
71 | 
72 | def print_time(name, seconds):
73 | print(name + ' cost time: {}:{}:{}\n'.format(seconds // 3600, (seconds % 3600) // 60,
74 | seconds % 60))
75 | 
--------------------------------------------------------------------------------
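A minimal sketch (illustrative only, not part of the repository) of how `load_partial_model` above behaves: parameters whose names exist in both the checkpoint and the current model are restored, while layers that only exist in the current model keep their fresh initialization.

import torch.nn as nn
from util import load_partial_model

# Checkpoint from a smaller model: only the first two layers exist.
old = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Conv2d(8, 1, 3))
# Current model adds a third layer that the checkpoint knows nothing about.
new = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Conv2d(8, 1, 3), nn.Conv2d(1, 1, 1))

load_partial_model(new, old.state_dict())
# Layers 0 and 1 now carry the checkpoint weights; layer 2 keeps its initialization.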