├── README.md
├── networks
│   ├── depth_decoderv2.py
│   └── uncert_decoder.py
└── trainer.py

/README.md:
--------------------------------------------------------------------------------
# The official repo for SUB-Depth

[SUB-Depth: Self-distillation and Uncertainty Boosting Self-supervised Monocular Depth Estimation](https://arxiv.org/abs/2111.09692v2)

# TO DO
- release training code and upload trained models

--------------------------------------------------------------------------------
/networks/depth_decoderv2.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import, division, print_function

import numpy as np
import torch
import torch.nn as nn

from collections import OrderedDict
from layers import *


class DepthDecoderV2(nn.Module):
    def __init__(self, num_ch_enc, scales=range(4), num_output_channels=2, use_skips=True, lite_model=None):
        super(DepthDecoderV2, self).__init__()

        self.num_output_channels = num_output_channels
        self.use_skips = use_skips
        self.upsample_mode = 'nearest'
        self.scales = scales
        self.num_ch_enc = num_ch_enc
        self.num_ch_dec = np.array([16, 32, 64, 128, 256])

        # decoder
        self.convs = OrderedDict()
        for i in range(4, -1, -1):  # i = [4, 3, 2, 1, 0]
            # upconv_0
            num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
            num_ch_out = self.num_ch_dec[i]
            self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out)

            # upconv_1
            num_ch_in = self.num_ch_dec[i]
            if self.use_skips and i > 0:
                num_ch_in += self.num_ch_enc[i - 1]
            num_ch_out = self.num_ch_dec[i]
            self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out)

        for s in self.scales:
            self.convs[("dispconv", s)] = Conv3x3(self.num_ch_dec[s], self.num_output_channels)

        self.decoder = nn.ModuleList(list(self.convs.values()))
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_features):
        self.outputs = {}
        x = input_features[-1]
        for i in range(4, -1, -1):
            x = self.convs[("upconv", i, 0)](x)
            x = [upsample(x)]
            if self.use_skips and i > 0:
                x += [input_features[i - 1]]
            x = torch.cat(x, 1)
            x = self.convs[("upconv", i, 1)](x)
            if i in self.scales:
                final = self.sigmoid(self.convs[("dispconv", i)](x))
                # channel 0 is the disparity map, channel 1 the per-pixel uncertainty
                self.outputs[("disp", i)] = final[:, 0, :, :].unsqueeze(1)
                self.outputs[("uncert", i)] = final[:, 1, :, :].unsqueeze(1)
        return self.outputs

--------------------------------------------------------------------------------
/networks/uncert_decoder.py:
--------------------------------------------------------------------------------
# Copyright Niantic 2019. Patent Pending. All rights reserved.
#
# This software is licensed under the terms of the Monodepth2 licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.
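#
# UncertDecoder mirrors the Monodepth2 depth decoder, but instead of predicting
# disparity it outputs a single sigmoid-activated, per-pixel uncertainty map at
# the finest scale only.  In trainer.py it is fed the features of a dedicated
# ResNet encoder ("ph_uncert_enc") and its output is used to weight the
# photometric loss.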

from __future__ import absolute_import, division, print_function

import numpy as np
import torch
import torch.nn as nn

from collections import OrderedDict
from layers import *


class UncertDecoder(nn.Module):
    def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True):
        # num_ch_enc = np.array([64, 64, 128, 256, 512])
        super(UncertDecoder, self).__init__()

        self.num_output_channels = num_output_channels
        self.use_skips = use_skips
        self.upsample_mode = 'nearest'
        self.scales = scales
        self.num_ch_enc = num_ch_enc
        self.num_ch_dec = np.array([16, 32, 64, 128, 256])

        self.convs = OrderedDict()
        for i in range(4, -1, -1):  # i = [4, 3, 2, 1, 0]
            # upconv_0
            num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
            num_ch_out = self.num_ch_dec[i]
            self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out)

            # upconv_1
            num_ch_in = self.num_ch_dec[i]
            if self.use_skips and i > 0:
                num_ch_in += self.num_ch_enc[i - 1]
            num_ch_out = self.num_ch_dec[i]
            self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out)

        # the uncertainty map is only predicted at the finest scale
        self.convs[("uncert_conv", 0)] = Conv3x3(self.num_ch_dec[0], self.num_output_channels)

        self.decoder = nn.ModuleList(list(self.convs.values()))
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_features):
        # decoder
        x = input_features[-1]
        for i in range(4, -1, -1):  # [4, 3, 2, 1, 0]
            x = self.convs[("upconv", i, 0)](x)
            x = [upsample(x)]  # upsample() is defined in layers.py
            if self.use_skips and i > 0:
                x += [input_features[i - 1]]
            x = torch.cat(x, 1)
            x = self.convs[("upconv", i, 1)](x)
            if i == 0:
                uncertainty_map = self.sigmoid(self.convs[("uncert_conv", i)](x))
        return uncertainty_map

--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import, division, print_function
from datetime import datetime
import os
import numpy as np
import math
import time
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import json
import torchvision
from utils import *
from kitti_utils import *
from layers import *

import datasets
import networks
import hr_networks
from IPython import embed


class Trainer:
    def __init__(self, options):
        now = datetime.now()
        current_time_date = now.strftime("%d%m%Y-%H:%M:%S")
        self.opt = options
        self.log_path = os.path.join(self.opt.log_dir, self.opt.model_name, current_time_date)

        # check that height and width are multiples of 32
        assert self.opt.height % 32 == 0, "'height' must be a multiple of 32"
        assert self.opt.width % 32 == 0, "'width' must be a multiple of 32"

        self.models = {}
        self.parameters_to_train = []

        # train on GPU unless --no_cuda is set
        self.device = torch.device("cpu" if self.opt.no_cuda else "cuda:0")
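
        # The rest of __init__ builds the SUB-Depth training setup: a pre-trained
        # Monodepth2-style teacher (encoder_teacher + depth_teacher, loaded from
        # --teacher_weights_folder and never added to the optimizer) provides
        # pseudo ground truth for self-distillation; the student's DepthDecoderV2
        # predicts disparity plus a per-pixel depth uncertainty; and a separate
        # encoder/decoder pair (ph_uncert_enc / ph_uncert_dec) predicts a
        # photometric uncertainty map that weights the reprojection loss.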
        self.num_scales = len(self.opt.scales)  # scales used in the loss, default [0, 1, 2, 3]
        self.num_input_frames = len(self.opt.frame_ids)  # frames to load, default [0, -1, 1]
        self.num_pose_frames = 2 if self.opt.pose_model_input == "pairs" else self.num_input_frames
        assert self.opt.frame_ids[0] == 0, "frame_ids must start with 0"

        self.use_pose_net = not (self.opt.use_stereo and self.opt.frame_ids == [0])

        if self.opt.use_stereo:
            self.opt.frame_ids.append("s")

        ################################## teacher network ##################################

        self.models["encoder_teacher"] = networks.ResnetEncoder(
            self.opt.num_layers, self.opt.weights_init == "pretrained")

        self.models["depth_teacher"] = networks.DepthDecoder(
            self.models["encoder_teacher"].num_ch_enc, self.opt.scales)

        encoder_path = os.path.join(self.opt.teacher_weights_folder, "encoder.pth")
        decoder_path = os.path.join(self.opt.teacher_weights_folder, "depth.pth")

        encoder_dict = torch.load(encoder_path, map_location=self.device)
        decoder_dict = torch.load(decoder_path, map_location=self.device)
        model_dict = self.models["encoder_teacher"].state_dict()
        dec_model_dict = self.models["depth_teacher"].state_dict()

        self.models["encoder_teacher"].load_state_dict({k: v for k, v in encoder_dict.items() if k in model_dict})
        self.models["depth_teacher"].load_state_dict({k: v for k, v in decoder_dict.items() if k in dec_model_dict})

        ################################## student network ###################################

        self.models["encoder"] = networks.ResnetEncoder(
            self.opt.num_layers, self.opt.weights_init == "pretrained")

        self.models["depth"] = networks.DepthDecoderV2(
            self.models["encoder"].num_ch_enc, self.opt.scales)

        self.models["encoder"].to(self.device)
        self.models["encoder_teacher"].to(self.device)
        self.parameters_to_train += list(self.models["encoder"].parameters())
        self.models["depth"].to(self.device)
        self.models["depth_teacher"].to(self.device)
        self.parameters_to_train += list(self.models["depth"].parameters())

        self.models["pose_encoder"] = networks.ResnetEncoder(
            self.opt.num_layers,
            self.opt.weights_init == "pretrained",
            num_input_images=self.num_pose_frames)  # num_input_images=2

        self.models["pose"] = networks.PoseDecoder(
            self.models["pose_encoder"].num_ch_enc,
            num_input_features=1,
            num_frames_to_predict_for=2)

        ### independent photometric uncertainty modelling

        self.models["ph_uncert_enc"] = networks.ResnetEncoder(
            self.opt.num_layers,
            self.opt.weights_init == "pretrained",
            num_input_images=2)

        self.models["ph_uncert_dec"] = networks.UncertDecoder(
            self.models["ph_uncert_enc"].num_ch_enc)

        self.models["ph_uncert_enc"].to(self.device)
        self.parameters_to_train += list(self.models["ph_uncert_enc"].parameters())
        self.models["ph_uncert_dec"].to(self.device)
        self.parameters_to_train += list(self.models["ph_uncert_dec"].parameters())

        self.models["pose_encoder"].to(self.device)
        self.models["pose"].to(self.device)
        self.parameters_to_train += list(self.models["pose_encoder"].parameters())
        self.parameters_to_train += list(self.models["pose"].parameters())

        # note that the teacher networks are not added to parameters_to_train
        self.model_optimizer = optim.Adam(self.parameters_to_train, self.opt.learning_rate)  # learning_rate default: 1e-4
        self.model_lr_scheduler = optim.lr_scheduler.StepLR(
            self.model_optimizer, self.opt.scheduler_step_size, 0.1)  # scheduler_step_size default: 15

        if self.opt.load_weights_folder is not None:
            self.load_model()

        print("Training model named:\n ", self.opt.model_name)
        print("Models and tensorboard events files are saved to:\n ", self.log_path)
        print("Training is using:\n ", self.device)

        # data
        datasets_dict = {"kitti": datasets.KITTIRAWDataset,
                         "kitti_odom": datasets.KITTIOdomDataset,
                         }
        self.dataset = datasets_dict["kitti"]
        fpath = os.path.join(os.path.dirname(__file__), "splits", self.opt.split, "{}_files.txt")

        train_filenames = readlines(fpath.format("train"))
        val_filenames = readlines(fpath.format("val"))
        img_ext = '.png' if self.opt.png else '.jpg'
        num_train_samples = len(train_filenames)
        self.num_total_steps = num_train_samples // self.opt.batch_size * self.opt.num_epochs

        train_dataset = self.dataset(
            self.opt.data_path, train_filenames, self.opt.height, self.opt.width,
            self.opt.frame_ids, 4, is_train=True, img_ext=img_ext)
        self.train_loader = DataLoader(
            train_dataset, self.opt.batch_size, shuffle=True,
            num_workers=self.opt.num_workers, pin_memory=True, drop_last=True)

        val_dataset = datasets.KITTIRAWDataset(
            self.opt.data_path, val_filenames, self.opt.height, self.opt.width,
            self.opt.frame_ids, 4, is_train=False, img_ext=img_ext)
        self.val_loader = DataLoader(
            val_dataset, self.opt.batch_size, shuffle=True,
            num_workers=self.opt.num_workers, pin_memory=True, drop_last=True)
        self.val_iter = iter(self.val_loader)

        self.writers = {}

        if not self.opt.no_ssim:
            self.ssim = SSIM()
            self.ssim.to(self.device)
        self.num_total_batch = len(train_dataset) // self.opt.batch_size

        self.backproject_depth = {}
        self.project_3d = {}
        for scale in self.opt.scales:
            h = self.opt.height // (2 ** scale)
            w = self.opt.width // (2 ** scale)

            self.backproject_depth[scale] = BackprojectDepth(self.opt.batch_size, h, w)  # defined in layers.py
            self.backproject_depth[scale].to(self.device)

            self.project_3d[scale] = Project3D(self.opt.batch_size, h, w)
            self.project_3d[scale].to(self.device)

        self.depth_metric_names = [
            "de/abs_rel", "de/sq_rel", "de/rms", "de/log_rms", "da/a1", "da/a2", "da/a3"]

        print("Using split:\n ", self.opt.split)
        print("There are {:d} training items and {:d} validation items\n".format(
            len(train_dataset), len(val_dataset)))

        self.save_opts()

    def set_train(self):
        """Convert all models to training mode
        """
        for m in self.models.values():
            m.train()

    def set_eval(self):
        """Convert all models to testing/evaluation mode
        """
        for m in self.models.values():
            m.eval()

    def train(self):
        """Run the entire training pipeline
        """
        self.init_time = time.time()
        if isinstance(self.opt.load_weights_folder, str):
            # weight folders are named "weights_{epoch}" by save_model()
            self.epoch_start = int(self.opt.load_weights_folder.rstrip("/").split("_")[-1]) + 1
        else:
            self.epoch_start = 0
        self.step = 0
        self.start_time = time.time()
        for self.epoch in range(self.epoch_start, self.opt.num_epochs):
            self.run_epoch()
            if (self.epoch + 1) % self.opt.save_frequency == 0:  # epochs between saves, default 1
                self.save_model()
        self.total_training_time = time.time() - self.init_time
        print('====>total training time:{}'.format(sec_to_hm_str(self.total_training_time)))

    def run_epoch(self):
        """Run a single epoch of training and validation
        """
        self.set_train()
        self.every_epoch_start_time = time.time()

        for batch_idx, inputs in enumerate(self.train_loader):
            before_op_time = time.time()
            outputs, losses = self.process_batch(inputs)
            self.model_optimizer.zero_grad()
            losses["loss"].backward()
            self.model_optimizer.step()
            duration = time.time() - before_op_time
            early_phase = batch_idx % self.opt.log_frequency == 0 and self.step < 2000  # log_frequency default: 250
            late_phase = self.step % 2000 == 0

            if early_phase or late_phase:
                self.log_time(batch_idx, duration, losses["loss"].cpu().data)

                if "depth_gt" in inputs:
                    self.compute_depth_losses(inputs, outputs, losses)

                #self.log("train", inputs, outputs, losses)
                self.val()
            self.step += 1

        self.model_lr_scheduler.step()
        self.every_epoch_end_time = time.time()
        print("====>training time of this epoch:{}".format(sec_to_hm_str(self.every_epoch_end_time - self.every_epoch_start_time)))

    def process_batch_ts(self, inputs, outputs, losses):
        """Compute the teacher's pseudo ground truth and add the self-distillation loss
        """
        outputs_t = self.models["depth_teacher"](self.models["encoder_teacher"](inputs["color_aug", 0, 0]))

        # detach the teacher predictions so no gradient flows back into the teacher
        for key, item in outputs_t.items():
            outputs_t[key] = item.detach()
        regression_loss = self.regress_loss(outputs_t, outputs)
        losses["loss"] += regression_loss
        return losses

    def process_batch(self, inputs):
        """Pass a minibatch through the network and generate images and losses
        """
        for key, ipt in inputs.items():
            inputs[key] = ipt.to(self.device)  # move input tensors to the training device

        # feed the frame_id 0 image through the student depth network
        features = self.models["encoder"](inputs["color_aug", 0, 0])
        outputs = self.models["depth"](features)

        outputs.update(self.predict_poses(inputs, features))

        self.generate_images_pred(inputs, outputs)
        losses = self.compute_losses(inputs, outputs)
        losses = self.process_batch_ts(inputs, outputs, losses)
        return outputs, losses

    def predict_poses(self, inputs, features):
        """Predict poses between input frames for monocular sequences.
        """
        outputs = {}
        if self.num_pose_frames == 2:
            # In this setting, we compute the pose to each source frame via a
            # separate forward pass through the pose network.
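
            # Each temporally adjacent pair (frame -1 or +1 together with frame 0)
            # is also passed through the photometric-uncertainty encoder/decoder;
            # the resulting per-pair maps are averaged into
            # outputs["photometric_uncert"], which later weights the minimum
            # reprojection loss in compute_losses().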

            # select what features the pose network takes as input
            if self.opt.pose_model_type == "shared":
                pose_feats = {f_i: features[f_i] for f_i in self.opt.frame_ids}
            else:
                pose_feats = {f_i: inputs["color_aug", f_i, 0] for f_i in self.opt.frame_ids}
            # pose_feats is a dict keyed by frame_id (e.g. 0, -1, 1)

            uncerts = []
            for f_i in self.opt.frame_ids[1:]:
                if f_i != "s":
                    # To maintain ordering we always pass frames in temporal order
                    if f_i < 0:
                        pose_inputs = [pose_feats[f_i], pose_feats[0]]  # neighbouring frames
                        uncerts_inputs = [pose_feats[f_i], pose_feats[0]]
                    else:
                        pose_inputs = [pose_feats[0], pose_feats[f_i]]
                        uncerts_inputs = [pose_feats[0], pose_feats[f_i]]
                    if self.opt.pose_model_type == "separate_resnet":
                        pose_inputs = [self.models["pose_encoder"](torch.cat(pose_inputs, 1))]
                    elif self.opt.pose_model_type == "posecnn":
                        pose_inputs = torch.cat(pose_inputs, 1)

                    axisangle, translation = self.models["pose"](pose_inputs)
                    outputs[("axisangle", 0, f_i)] = axisangle
                    outputs[("translation", 0, f_i)] = translation
                    outputs[("cam_T_cam", 0, f_i)] = transformation_from_parameters(
                        axisangle[:, 0], translation[:, 0], invert=(f_i < 0))

                    uncert_inputs = self.models["ph_uncert_enc"](torch.cat(uncerts_inputs, 1))
                    uncert = self.models["ph_uncert_dec"](uncert_inputs)
                    uncerts.append(uncert)
            photometric_uncerts = torch.cat(uncerts, 1)
            outputs["photometric_uncert"] = photometric_uncerts.mean(1, True)

        else:
            # Here we input all frames to the pose net (and predict all poses) together
            if self.opt.pose_model_type in ["separate_resnet", "posecnn"]:
                pose_inputs = torch.cat(
                    [inputs[("color_aug", i, 0)] for i in self.opt.frame_ids if i != "s"], 1)

                if self.opt.pose_model_type == "separate_resnet":
                    pose_inputs = [self.models["pose_encoder"](pose_inputs)]

            elif self.opt.pose_model_type == "shared":
                pose_inputs = [features[i] for i in self.opt.frame_ids if i != "s"]

            axisangle, translation = self.models["pose"](pose_inputs)

            for i, f_i in enumerate(self.opt.frame_ids[1:]):
                if f_i != "s":
                    outputs[("axisangle", 0, f_i)] = axisangle
                    outputs[("translation", 0, f_i)] = translation
                    outputs[("cam_T_cam", 0, f_i)] = transformation_from_parameters(
                        axisangle[:, i], translation[:, i])

        return outputs

    def val(self):
        """Validate the model on a single minibatch
        """
        self.set_eval()
        try:
            inputs = next(self.val_iter)
        except StopIteration:
            self.val_iter = iter(self.val_loader)
            inputs = next(self.val_iter)

        with torch.no_grad():
            outputs, losses = self.process_batch(inputs)

            if "depth_gt" in inputs:
                self.compute_depth_losses(inputs, outputs, losses)

            del inputs, outputs, losses

        self.set_train()

    def generate_images_pred(self, inputs, outputs):
        """Generate the warped (reprojected) color images for a minibatch.
        Generated images are saved into the `outputs` dictionary.
        """
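        # For each scale: convert the sigmoid disparity to depth, back-project the
        # target pixels to a point cloud with the inverse intrinsics, transform the
        # points with the predicted (or stereo) pose T, project them into the
        # source frame, and sample the source image with grid_sample to synthesise
        # the warped colour image used by the photometric loss.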
        for scale in self.opt.scales:
            disp = outputs[("disp", scale)]
            if self.opt.v1_multiscale:
                source_scale = scale
            else:
                disp = F.interpolate(
                    disp, [self.opt.height, self.opt.width], mode="bilinear", align_corners=False)
                source_scale = 0

            _, depth = disp_to_depth(disp, self.opt.min_depth, self.opt.max_depth)  # disp_to_depth is defined in layers.py

            outputs[("depth", 0, scale)] = depth

            for i, frame_id in enumerate(self.opt.frame_ids[1:]):

                if frame_id == "s":
                    T = inputs["stereo_T"]
                else:
                    T = outputs[("cam_T_cam", 0, frame_id)]

                if self.opt.pose_model_type == "posecnn":

                    axisangle = outputs[("axisangle", 0, frame_id)]
                    translation = outputs[("translation", 0, frame_id)]

                    inv_depth = 1 / depth
                    mean_inv_depth = inv_depth.mean(3, True).mean(2, True)

                    T = transformation_from_parameters(
                        axisangle[:, 0], translation[:, 0] * mean_inv_depth[:, 0], frame_id < 0)

                cam_points = self.backproject_depth[source_scale](
                    depth, inputs[("inv_K", source_scale)])
                pix_coords = self.project_3d[source_scale](
                    cam_points, inputs[("K", source_scale)], T)
                outputs[("sample", frame_id, scale)] = pix_coords

                outputs[("color", frame_id, scale)] = F.grid_sample(
                    inputs[("color", frame_id, source_scale)],
                    outputs[("sample", frame_id, scale)],
                    padding_mode="border")

                if not self.opt.disable_automasking:
                    outputs[("color_identity", frame_id, scale)] = \
                        inputs[("color", frame_id, source_scale)]

    def regress_loss(self, outputs_t, outputs):
        """Self-distillation loss: L1 distance to the teacher's disparity,
        weighted by the student's predicted depth uncertainty
        """
        abs_diff = torch.abs(outputs[("disp", 0)] - outputs_t[("disp", 0)])
        uncerted_l1_loss = (abs_diff / outputs[("uncert", 0)] + torch.log(outputs[("uncert", 0)])).mean()
        return uncerted_l1_loss

    def compute_reprojection_loss(self, pred, target):
        """Computes reprojection loss between a batch of predicted and target images
        """
        abs_diff = torch.abs(target - pred)
        l1_loss = abs_diff.mean(1, True)

        if self.opt.no_ssim:
            reprojection_loss = l1_loss
        else:
            ssim_loss = self.ssim(pred, target).mean(1, True)
            reprojection_loss = 0.85 * ssim_loss + 0.15 * l1_loss

        return reprojection_loss

    def compute_losses(self, inputs, outputs):
        """Compute the reprojection and smoothness losses for a minibatch
        """
        losses = {}
        total_loss = 0

        for scale in self.opt.scales:
            loss = 0
            reprojection_losses = []

            if self.opt.v1_multiscale:
                source_scale = scale
            else:
                source_scale = 0

            disp = outputs[("disp", scale)]
            color = inputs[("color", 0, scale)]
            target = inputs[("color", 0, source_scale)]

            for frame_id in self.opt.frame_ids[1:]:
                pred = outputs[("color", frame_id, scale)]
                reprojection_losses.append(self.compute_reprojection_loss(pred, target))
            reprojection_losses = torch.cat(reprojection_losses, 1)

            if not self.opt.disable_automasking:
                identity_reprojection_losses = []
                for frame_id in self.opt.frame_ids[1:]:
                    pred = inputs[("color", frame_id, source_scale)]
                    identity_reprojection_losses.append(
                        self.compute_reprojection_loss(pred, target))

                identity_reprojection_losses = torch.cat(identity_reprojection_losses, 1)
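
                # Automasking (from Monodepth2): the identity reprojection loss
                # compares the target against the *unwarped* source frames; taking
                # the per-pixel minimum over identity and warped losses below masks
                # out pixels that do not change between frames (static cameras or
                # objects moving at the same speed as the camera).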
                if self.opt.avg_reprojection:
                    identity_reprojection_loss = identity_reprojection_losses.mean(1, keepdim=True)
                else:
                    # save both images, and do min all at once below
                    identity_reprojection_loss = identity_reprojection_losses

            elif self.opt.predictive_mask:
                mask = outputs["predictive_mask"]["disp", scale]
                if not self.opt.v1_multiscale:
                    mask = F.interpolate(
                        mask, [self.opt.height, self.opt.width],
                        mode="bilinear", align_corners=False)

                reprojection_losses *= mask

                # add a loss pushing mask to 1 (using nn.BCELoss for stability)
                weighting_loss = 0.2 * nn.BCELoss()(mask, torch.ones_like(mask))
                loss += weighting_loss.mean()

            if self.opt.avg_reprojection:
                reprojection_loss = reprojection_losses.mean(1, keepdim=True)
            else:
                reprojection_loss = reprojection_losses

            if not self.opt.disable_automasking:
                # add random numbers to break ties
                identity_reprojection_loss += torch.randn(
                    identity_reprojection_loss.shape, device=self.device) * 0.00001
                combined = torch.cat((identity_reprojection_loss, reprojection_loss), dim=1)
            else:
                combined = reprojection_loss

            if combined.shape[1] == 1:
                to_optimise = combined
            else:
                to_optimise, idxs = torch.min(combined, dim=1)

            if not self.opt.disable_automasking:
                outputs["identity_selection/{}".format(scale)] = (
                    idxs > identity_reprojection_loss.shape[1] - 1).float()

            # weight the photometric term by the predicted photometric uncertainty;
            # squeeze both to [B, H, W] so the division broadcasts per pixel
            if to_optimise.dim() == 4:
                to_optimise = to_optimise.squeeze(1)
            photometric_uncert = outputs["photometric_uncert"].squeeze(1)
            to_optimise_1 = to_optimise / photometric_uncert + torch.log(photometric_uncert)
            loss += to_optimise_1.mean()

            mean_disp = disp.mean(2, True).mean(3, True)
            norm_disp = disp / (mean_disp + 1e-7)
            smooth_loss = get_smooth_loss(norm_disp, color)

            loss += self.opt.disparity_smoothness * smooth_loss / (2 ** scale)  # disparity_smoothness default: 1e-3
            total_loss += loss
            losses["loss/{}".format(scale)] = loss

        total_loss /= self.num_scales
        losses["loss"] = total_loss
        return losses

    def compute_depth_losses(self, inputs, outputs, losses):
        """Compute depth metrics, to allow monitoring during training

        This isn't particularly accurate as it averages over the entire batch,
        so it is only used to give an indication of validation performance
        """
        depth_pred = outputs[("depth", 0, 0)]
        depth_pred = torch.clamp(F.interpolate(
            depth_pred, [375, 1242], mode="bilinear", align_corners=False), 1e-3, 80)
        depth_pred = depth_pred.detach()

        depth_gt = inputs["depth_gt"]
        mask = depth_gt > 0

        # garg/eigen crop
        crop_mask = torch.zeros_like(mask)
        crop_mask[:, :, 153:371, 44:1197] = 1
        mask = mask * crop_mask

        depth_gt = depth_gt[mask]
        depth_pred = depth_pred[mask]
        depth_pred *= torch.median(depth_gt) / torch.median(depth_pred)

        depth_pred = torch.clamp(depth_pred, min=1e-3, max=80)

        depth_errors = compute_depth_errors(depth_gt, depth_pred)

        for i, metric in enumerate(self.depth_metric_names):
            losses[metric] = np.array(depth_errors[i].cpu())
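
    # Note on the uncertainty weighting used in regress_loss() and compute_losses():
    # both terms follow the common heteroscedastic form
    #     loss = residual / sigma + log(sigma),
    # where sigma is the sigmoid-activated uncertainty predicted by DepthDecoderV2
    # (distillation term) or by UncertDecoder (photometric term).  Dividing by
    # sigma lets the network down-weight pixels it is unsure about, while
    # log(sigma) penalises predicting large uncertainty everywhere.
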
    def log_time(self, batch_idx, duration, loss):
        """Print a logging statement to the terminal
        """
        samples_per_sec = self.opt.batch_size / duration
        time_sofar = time.time() - self.start_time
        training_time_left = (
            self.num_total_steps / self.step - 1.0) * time_sofar if self.step > 0 else 0
        print_string = "epoch {:>3} | batch_idx {:>6} | examples/s: {:5.1f}" + \
            " | loss: {:.5f} | time elapsed: {} | time left: {}"
        print(print_string.format(self.epoch, batch_idx, samples_per_sec, loss,
                                  sec_to_hm_str(time_sofar), sec_to_hm_str(training_time_left)))

    def log(self, mode, inputs, outputs, losses):
        """Write an event to the tensorboard events file
        """
        writer = self.writers[mode]
        for l, v in losses.items():
            writer.add_scalar("{}".format(l), v, self.step)

        for j in range(min(4, self.opt.batch_size)):  # write a maximum of four images
            for s in self.opt.scales:
                for frame_id in self.opt.frame_ids:
                    writer.add_image(
                        "color_{}_{}/{}".format(frame_id, s, j),
                        inputs[("color", frame_id, s)][j].data, self.step)
                    if s == 0 and frame_id != 0:
                        writer.add_image(
                            "color_pred_{}_{}/{}".format(frame_id, s, j),
                            outputs[("color", frame_id, s)][j].data, self.step)

                writer.add_image(
                    "disp_{}/{}".format(s, j),
                    normalize_image(outputs[("disp", s)][j]), self.step)

                if self.opt.predictive_mask:
                    for f_idx, frame_id in enumerate(self.opt.frame_ids[1:]):
                        writer.add_image(
                            "predictive_mask_{}_{}/{}".format(frame_id, s, j),
                            outputs["predictive_mask"][("disp", s)][j, f_idx][None, ...],
                            self.step)

                elif not self.opt.disable_automasking:
                    writer.add_image(
                        "automask_{}/{}".format(s, j),
                        outputs["identity_selection/{}".format(s)][j][None, ...], self.step)

    def save_opts(self):
        """Save options to disk so we know what we ran this experiment with
        """
        models_dir = os.path.join(self.log_path, "models")
        if not os.path.exists(models_dir):
            os.makedirs(models_dir)
        to_save = self.opt.__dict__.copy()

        with open(os.path.join(models_dir, 'opt.json'), 'w') as f:
            json.dump(to_save, f, indent=2)

    def save_model(self):
        """Save model weights to disk
        """
        save_folder = os.path.join(self.log_path, "models", "weights_{}".format(self.epoch))
        if not os.path.exists(save_folder):
            os.makedirs(save_folder)

        for model_name, model in self.models.items():
            if model_name in ["encoder", "depth", "uncert", "pose_encoder", "pose", "ph_uncert_enc", "ph_uncert_dec"]:
                save_path = os.path.join(save_folder, "{}.pth".format(model_name))
                to_save = model.state_dict()
                if model_name == 'encoder':
                    # save the sizes - these are needed at prediction time
                    to_save['height'] = self.opt.height
                    to_save['width'] = self.opt.width
                    to_save['use_stereo'] = self.opt.use_stereo
                torch.save(to_save, save_path)

        # save the optimizer state so load_model() can restore it from "adam.pth"
        torch.save(self.model_optimizer.state_dict(), os.path.join(save_folder, "adam.pth"))

    def load_model(self):
        """Load model(s) from disk
        """
        self.opt.load_weights_folder = os.path.expanduser(self.opt.load_weights_folder)

        assert os.path.isdir(self.opt.load_weights_folder), \
            "Cannot find folder {}".format(self.opt.load_weights_folder)
        print("loading model from folder {}".format(self.opt.load_weights_folder))

        for n in self.opt.models_to_load:
            print("Loading {} weights...".format(n))
            path = os.path.join(self.opt.load_weights_folder, "{}.pth".format(n))
            model_dict = self.models[n].state_dict()
            pretrained_dict = torch.load(path, map_location=self.device)
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
            model_dict.update(pretrained_dict)
            self.models[n].load_state_dict(model_dict)

        # loading adam state
        optimizer_load_path = os.path.join(self.opt.load_weights_folder, "adam.pth")
        if os.path.isfile(optimizer_load_path):
            print("Loading Adam weights")
            optimizer_dict = torch.load(optimizer_load_path, map_location=self.device)
            self.model_optimizer.load_state_dict(optimizer_dict)
        else:
            print("Cannot find Adam weights so Adam is randomly initialized")

--------------------------------------------------------------------------------