├── .gitignore ├── README.md ├── create_nyu_h5.py ├── data.py ├── discritization.py ├── loss.py ├── lr_decay.py ├── model.py ├── progress_tracking.py ├── resnet_dilated.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | logs/ 2 | pretrained/ 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DORN_depth_estimation_Pytorch 2 | 3 | This is an unofficial PyTorch implementation of the [Deep Ordinal Regression Network for Monocular Depth Estimation](http://arxiv.org/abs/1806.02446) paper by Fu et al. 4 | 5 | Table. Performance on NYU V2. 6 | | Source | δ1 | δ2 | δ3 | rel | log10 | rms | 7 | | ------------- | ------------- | ------------- | ------------- | ------------- | ------------- | ------------- | 8 | |Original paper*| 0.828 | 0.965 | 0.992 | 0.115 | 0.051 | 0.509 | 9 | | This repo* | 0.806 | 0.957 | 0.989 | 0.151 | 0.062 | 0.586 | 10 | 11 | *Note that the data splits are different (see Known Differences below for details). The worse performance might be due to the smaller training set (795 vs. about 120K images). 12 | 13 | ## How to use 14 | 15 | These steps show how to run the code on the official split of the NYU V2 depth dataset. 16 | 17 | To prepare the data: 18 | - Download [nyu_depth_v2_labeled.mat](http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/nyu_depth_v2_labeled.mat) (data) and [splits.mat](http://horatio.cs.nyu.edu/mit/silberman/indoor_seg_sup/splits.mat). 19 | - Edit create_nyu_h5.py to set data_path (the folder with the .mat files from the previous step) and output_path. 20 | - Run: 21 | ```bash 22 | python create_nyu_h5.py 23 | ``` 24 | 25 | To start training on NYU V2, run: 26 | ```bash 27 | train.py [-h] [--dataset DATASET] [--data-path DATA_PATH] 28 | [--pretrained] [--epochs EPOCHS] [--bs BS] [--bs-test BS_TEST] 29 | [--lr LR] [--gpu GPU] 30 | ``` 31 | Or simply: 32 | ```bash 33 | python train.py --data-path DATA_PATH --pretrained 34 | ``` 35 | (where DATA_PATH is the same as the output_path used during data preparation). 36 | 37 | For more info on the arguments, run: 38 | ```bash 39 | python train.py --help 40 | ``` 41 | 42 | To train on a different dataset, a corresponding Dataset/DataLoader implementation is required (a minimal sketch is given below, after the Pretrained feature extractor section). 43 | 44 | To monitor training, use Tensorboard: 45 | ```bash 46 | tensorboard --logdir ./logs/ 47 | ``` 48 | 49 | ## Known Differences 50 | The implementation closely follows the paper and the [official repo](https://github.com/hufu6371/DORN) with some exceptions. The list of known differences: 51 | - Only training on the labeled part of NYU V2 is currently implemented (not on all the raw data). 52 | - ColorJitter is used instead of the color transformation from the paper by Eigen et al. 53 | - The feature extractor is pretrained on a different dataset. 54 | 55 | ## Pretrained feature extractor 56 | 57 | DORN uses a modified version of ResNet-101 as a feature extractor (with dilations and three 3x3 convolutional layers at the beginning instead of one 7x7 layer). If you select pretrained=True, weights pretrained on the MIT ADE20K dataset will be loaded from [this project](https://github.com/CSAILVision/semantic-segmentation-pytorch). This is different from the paper (the authors suggest pretraining on ImageNet), but it is the only suitable pretrained model on the Web that I am aware of.
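For reference, here is a minimal sketch (mirroring what train.py does) of constructing the model with the pretrained feature extractor; with pretrained=True the weights are downloaded automatically on first use:

```python
from model import DORN

# Build DORN for NYU V2. With pretrained=True, resnet_dilated.py downloads the
# ADE20K-pretrained dilated ResNet-101 encoder weights to ./pretrained/ on first use.
model = DORN(dataset='nyu', pretrained=True).cuda()
```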
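As noted in How to use, training on another dataset requires your own dataset class. Below is a minimal, hypothetical sketch that mirrors the interface of NYUDataset in data.py (each sample is an (rgb, depth) tensor pair at the 257x353 input size). The MyDepthDataset name and the .h5 file layout are illustrative assumptions only, and data.py's get_dataloaders() would also need a branch for the new dataset name:

```python
import os
import h5py
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as T


class MyDepthDataset(Dataset):
    """Hypothetical dataset: each .h5 file stores an 'rgb' (3 x H x W, uint8) array and a
    'depth' (H x W, float) array, i.e. the layout produced by create_nyu_h5.py."""

    def __init__(self, root):
        self.paths = sorted(os.path.join(root, f) for f in os.listdir(root) if f.endswith('.h5'))
        # DORN expects 257x353 inputs for NYU (see SceneUnderstandingModule.out_size).
        self.transform = T.Compose([T.Resize((288, 384)), T.CenterCrop((257, 353)), T.ToTensor()])

    def __getitem__(self, index):
        with h5py.File(self.paths[index], 'r') as f:
            rgb = Image.fromarray(np.array(f['rgb']).transpose((1, 2, 0)), 'RGB')
            depth = Image.fromarray(np.array(f['depth']), 'F')
        return self.transform(rgb), self.transform(depth)

    def __len__(self):
        return len(self.paths)
```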
58 | 59 | ## Requirements 60 | 61 | - Python 3 62 | - Pytorch (version 1.3 tested) 63 | - Torchvision 64 | - Tensorboard 65 | 66 | ## Acknowledgements 67 | 68 | The code is based on [this implementation](https://github.com/dontLoveBugs/DORN_pytorch). 69 | -------------------------------------------------------------------------------- /create_nyu_h5.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.io as sio 3 | import h5py 4 | import os 5 | 6 | # Modify next 2 lines 7 | data_path = '' # where to find nyu_depth_v2_labeled.mat and splits.mat 8 | output_path = '' # where to put the resulting dataset 9 | 10 | train_path = os.path.join(output_path, 'train/official/') 11 | val_path = os.path.join(output_path, 'val/official/') 12 | 13 | splilts = sio.loadmat(data_path + '/splits.mat') 14 | 15 | train_idx = splilts['trainNdxs'] 16 | val_idx = splilts['testNdxs'] 17 | 18 | train_idx = np.array(train_idx) 19 | val_idx = np.array(val_idx) 20 | 21 | f = h5py.File(data_path + '/nyu_depth_v2_labeled.mat') 22 | images = f["images"] 23 | depths = f["depths"] 24 | labels = f["labels"] 25 | 26 | images = np.array(images) 27 | depths = np.array(depths) 28 | labels = np.array(labels) 29 | 30 | 31 | if not os.path.isdir(train_path): 32 | os.makedirs(train_path) 33 | 34 | if not os.path.isdir(val_path): 35 | os.makedirs(val_path) 36 | 37 | for idx in range(len(train_idx)): 38 | f_idx = '{0:0>5}'.format(int(train_idx[idx])) 39 | print('train:', f_idx) 40 | h5f = h5py.File(train_path + f_idx + '.h5', 'w') 41 | 42 | h5f['rgb'] = np.transpose(images[train_idx[idx] - 1][0], (0, 2, 1)) 43 | h5f['depth'] = np.transpose(depths[train_idx[idx] - 1][0], (1, 0)) 44 | 45 | h5f.close() 46 | 47 | for idx in range(len(val_idx)): 48 | f_idx = '{0:0>5}'.format(int(val_idx[idx])) 49 | print('val:', f_idx) 50 | h5f = h5py.File(val_path + f_idx + '.h5', 'w') 51 | 52 | h5f['rgb'] = np.transpose(images[val_idx[idx] - 1][0], (0, 2, 1)) 53 | h5f['depth'] = np.transpose(depths[val_idx[idx] - 1][0], (1, 0)) 54 | 55 | h5f.close() 56 | 57 | print(train_idx[0]) 58 | print(images[train_idx[0] - 1][0].shape) 59 | print(depths[train_idx[0] - 1][0].shape) -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | import torchvision.transforms as T 4 | import torchvision.transforms.functional as TF 5 | from PIL import Image 6 | import numpy as np 7 | import h5py 8 | import os 9 | import random 10 | 11 | 12 | class NYUDataset(Dataset): 13 | def __init__(self, root, type): 14 | self.classes, self.class_to_idx = self.find_classes(root) 15 | self.imgs = self.make_dataset(root, self.class_to_idx) 16 | assert len(self.imgs) > 0, "Found 0 images in subfolders of: " + root + "\n" 17 | print("Found {} images in {} folder.".format(len(self.imgs), type)) 18 | if type == 'train': 19 | self.transform = self.train_transform 20 | else: 21 | self.transform = self.test_transform 22 | 23 | def find_classes(self, dir): 24 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 25 | classes.sort() 26 | class_to_idx = {classes[i]: i for i in range(len(classes))} 27 | return classes, class_to_idx 28 | 29 | 30 | def make_dataset(self, dir, class_to_idx): 31 | images = [] 32 | dir = os.path.expanduser(dir) 33 | for target in sorted(os.listdir(dir)): 34 | d = os.path.join(dir, target) 35 
| if not os.path.isdir(d): 36 | continue 37 | for root, _, fnames in sorted(os.walk(d)): 38 | for fname in sorted(fnames): 39 | if fname.endswith('.h5'): 40 | path = os.path.join(root, fname) 41 | item = (path, class_to_idx[target]) 42 | images.append(item) 43 | return images 44 | 45 | def train_transform(self, rgb, depth): 46 | """ 47 | Train data augmentation. For details see the DORN paper. 48 | """ 49 | # Resize for computational efficiency 50 | rgb = TF.resize(rgb, size=(288, 384)) 51 | depth = TF.resize(depth, size=(288, 384)) 52 | 53 | # Random rotation 54 | angle = T.RandomRotation.get_params(degrees=(-5,5)) 55 | rgb = TF.rotate(rgb, angle) 56 | depth = TF.rotate(depth, angle) 57 | 58 | # Random scaling 59 | s = np.random.uniform(1.0, 1.5) 60 | rgb = TF.resize(rgb, size=round(288 * s)) 61 | depth = TF.resize(depth, size=round(288 * s)) 62 | 63 | # Random crop 64 | i, j, h, w = T.RandomCrop.get_params(rgb, output_size=(257, 353)) 65 | rgb = TF.crop(rgb, i, j, h, w) 66 | depth = TF.crop(depth, i, j, h, w) 67 | 68 | # Random horizontal flipping 69 | if random.random() > 0.5: 70 | rgb = TF.hflip(rgb) 71 | depth = TF.hflip(depth) 72 | 73 | color_jitter = T.ColorJitter(0.4, 0.4, 0.4) 74 | rgb = color_jitter(rgb) 75 | 76 | rgb = TF.to_tensor(rgb) 77 | depth = TF.to_tensor(depth) 78 | 79 | depth /= s # preserves world-space geometry of the scene 80 | 81 | return rgb, depth 82 | 83 | def test_transform(self, rgb, depth): 84 | """ 85 | Test data augmentation. For details see the DORN paper. 86 | """ 87 | # data augmentations 88 | transform = T.Compose([ 89 | T.Resize((288, 384)), 90 | T.CenterCrop((257, 353)), 91 | T.ToTensor() 92 | ]) 93 | 94 | rgb = transform(rgb) 95 | depth = transform(depth) 96 | 97 | return rgb, depth 98 | 99 | def __getitem__(self, index): 100 | path, target = self.imgs[index] 101 | 102 | h5f = h5py.File(path, "r") 103 | rgb = Image.fromarray(np.array(h5f['rgb']).transpose((1, 2, 0)), 'RGB') 104 | depth = Image.fromarray(np.array(h5f['depth']), 'F') 105 | 106 | rgb, depth = self.transform(rgb, depth) 107 | return rgb, depth 108 | 109 | def __len__(self): 110 | return len(self.imgs) 111 | 112 | 113 | def get_dataloaders(dataset, data_path, bs, bs_test): 114 | if dataset == 'nyu': 115 | train_set = NYUDataset(os.path.join(data_path, 'train'), type='train') 116 | test_set = NYUDataset(os.path.join(data_path, 'val'), type='val') 117 | else: 118 | print('Not implemented for dataset', dataset) 119 | raise NotImplementedError 120 | 121 | train_loader = DataLoader(train_set, batch_size=bs, shuffle=True, num_workers=10, pin_memory=True) 122 | test_loader = DataLoader(test_set, batch_size=bs_test, shuffle=False, num_workers=10, pin_memory=True) 123 | return train_loader, test_loader -------------------------------------------------------------------------------- /discritization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class SID: 4 | def __init__(self, dataset): 5 | super(SID, self).__init__() 6 | if dataset == 'kitti': 7 | alpha = 0.001 8 | beta = 80.0 9 | elif dataset == 'nyu': 10 | alpha = 0.7113 11 | beta = 9.9955 12 | 13 | K = 80.0 14 | 15 | self.alpha = torch.tensor(alpha).cuda() 16 | self.beta = torch.tensor(beta).cuda() 17 | self.K = torch.tensor(K).cuda() 18 | 19 | def labels2depth(self, labels): 20 | depth = self.alpha * (self.beta / self.alpha) ** (labels.float() / self.K) 21 | return depth.float() 22 | 23 | 24 | def depth2labels(self, depth): 25 | labels = self.K * torch.log(depth / self.alpha) / 
torch.log(self.beta / self.alpha) 26 | return labels.cuda().round().int() 27 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class OrdinalLoss(nn.Module): 6 | """ 7 | Ordinal loss as defined in the paper "Deep Ordinal Regression Network for Monocular Depth Estimation". 8 | """ 9 | 10 | def __init__(self): 11 | super(OrdinalLoss, self).__init__() 12 | 13 | def forward(self, pred_softmax, target_labels): 14 | """ 15 | :param pred_softmax: predicted softmax probabilities P 16 | :param target_labels: ground truth ordinal labels 17 | :return: ordinal loss 18 | """ 19 | N, C, H, W = pred_softmax.size() # C - number of discrete sub-intervals (= number of channels) 20 | 21 | K = torch.zeros((N, C, H, W), dtype=torch.int).cuda() 22 | for i in range(C): 23 | K[:, i, :, :] = K[:, i, :, :] + i * torch.ones((N, H, W), dtype=torch.int).cuda() 24 | 25 | mask = (K <= target_labels).detach() 26 | 27 | loss = pred_softmax[mask].clamp(1e-8, 1e8).log().sum() + (1 - pred_softmax[~mask]).clamp(1e-8, 1e8).log().sum() 28 | loss /= -N * H * W 29 | return loss -------------------------------------------------------------------------------- /lr_decay.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler 2 | 3 | # Code from https://github.com/cmpark0126/pytorch-polynomial-lr-decay/blob/master/torch_poly_lr_decay/torch_poly_lr_decay.py 4 | class PolynomialLRDecay(_LRScheduler): 5 | """Polynomial learning rate decay until the step count reaches max_decay_steps. 6 | 7 | Args: 8 | optimizer (Optimizer): Wrapped optimizer. 9 | max_decay_steps: after this step, we stop decreasing the learning rate 10 | end_learning_rate: the learning rate value at which decay stops 11 | power: The power of the polynomial.
12 | """ 13 | 14 | def __init__(self, optimizer, max_decay_steps, end_learning_rate=0.0001, power=1.0): 15 | if max_decay_steps <= 1.: 16 | raise ValueError('max_decay_steps should be greater than 1.') 17 | self.max_decay_steps = max_decay_steps 18 | self.end_learning_rate = end_learning_rate 19 | self.power = power 20 | self.last_step = 0 21 | super().__init__(optimizer) 22 | 23 | def get_lr(self): 24 | if self.last_step > self.max_decay_steps: 25 | return [self.end_learning_rate for _ in self.base_lrs] 26 | 27 | return [(base_lr - self.end_learning_rate) * 28 | ((1 - self.last_step / self.max_decay_steps) ** (self.power)) + 29 | self.end_learning_rate for base_lr in self.base_lrs] 30 | 31 | def step(self, step=None): 32 | if step is None: 33 | step = self.last_step + 1 34 | self.last_step = step if step != 0 else 1 35 | if self.last_step <= self.max_decay_steps: 36 | decay_lrs = [(base_lr - self.end_learning_rate) * 37 | ((1 - self.last_step / self.max_decay_steps) ** (self.power)) + 38 | self.end_learning_rate for base_lr in self.base_lrs] 39 | for param_group, lr in zip(self.optimizer.param_groups, decay_lrs): 40 | param_group['lr'] = lr -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | from resnet_dilated import resnet101dilated 5 | 6 | 7 | def weights_init(model): 8 | for m in model.modules(): 9 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 10 | nn.init.xavier_normal_(m.weight.data) 11 | if m.bias is not None: 12 | m.bias.data.zero_() 13 | elif isinstance(m, nn.BatchNorm2d): 14 | m.weight.data.fill_(1.0) 15 | m.bias.data.zero_() 16 | 17 | 18 | class SceneUnderstandingModule(nn.Module): 19 | def __init__(self, dataset): 20 | super(SceneUnderstandingModule, self).__init__() 21 | if dataset == 'kitti': 22 | dilations = [6, 12, 18] 23 | self.out_size = (385, 513) 24 | elif dataset == 'nyu': 25 | dilations = [4, 8, 12] 26 | self.out_size = (257, 353) 27 | 28 | self.encoder = FullImageEncoder(dataset) 29 | self.aspp1 = nn.Sequential( 30 | nn.Conv2d(2048, 512, 1), 31 | nn.ReLU(inplace=True), 32 | nn.Conv2d(512, 512, 1), 33 | nn.ReLU(inplace=True) 34 | ) 35 | self.aspp2 = nn.Sequential( 36 | nn.Conv2d(2048, 512, 3, padding=dilations[0], dilation=dilations[0]), 37 | nn.ReLU(inplace=True), 38 | nn.Conv2d(512, 512, 1), 39 | nn.ReLU(inplace=True) 40 | ) 41 | self.aspp3 = nn.Sequential( 42 | nn.Conv2d(2048, 512, 3, padding=dilations[1], dilation=dilations[1]), 43 | nn.ReLU(inplace=True), 44 | nn.Conv2d(512, 512, 1), 45 | nn.ReLU(inplace=True) 46 | ) 47 | self.aspp4 = nn.Sequential( 48 | nn.Conv2d(2048, 512, 3, padding=dilations[2], dilation=dilations[2]), 49 | nn.ReLU(inplace=True), 50 | nn.Conv2d(512, 512, 1), 51 | nn.ReLU(inplace=True) 52 | ) 53 | self.concat_process = nn.Sequential( 54 | nn.Dropout2d(p=0.5), 55 | nn.Conv2d(512 * 5, 2048, 1), 56 | nn.ReLU(inplace=True), 57 | nn.Dropout2d(p=0.5), 58 | nn.Conv2d(2048, 160, 1) # out_channels = 160 = 2*K for K=80 sub-intervals, as in the official published models 59 | ) 60 | 61 | 62 | def forward(self, x): 63 | x1 = self.encoder(x) 64 | x2 = self.aspp1(x) 65 | x3 = self.aspp2(x) 66 | x4 = self.aspp3(x) 67 | x5 = self.aspp4(x) 68 | 69 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 70 | # print('Scene Understanding Module concat:', x.size()) 71 | x = self.concat_process(x) 72 | # print('Scene Understanding Module processed:', x.size()) 73 | x = 
F.interpolate(x, size=self.out_size, mode='bilinear', align_corners=True) 74 | return x 75 | 76 | 77 | class FullImageEncoder(nn.Module): 78 | def __init__(self, dataset): 79 | super(FullImageEncoder, self).__init__() 80 | if dataset == 'kitti': 81 | k = 16 82 | self.h, self.w = 49, 65 83 | self.h_, self.w_ = 4, 5 84 | elif dataset == 'nyu': 85 | k = 8 86 | self.h, self.w = 33, 45 87 | self.h_, self.w_ = 5, 6 88 | 89 | self.global_pooling = nn.AvgPool2d(k, stride=k, ceil_mode=True) # It seems Caffe uses ceil_mode by default. 90 | self.dropout = nn.Dropout2d(p=0.5) 91 | self.global_fc = nn.Linear(2048 * self.h_ * self.w_, 512) 92 | self.relu = nn.ReLU(inplace=True) 93 | self.conv1 = nn.Conv2d(512, 512, 1) # 1x1 conv 94 | 95 | def forward(self, x): 96 | # print('Full Image Encoder Input:', x.size()) 97 | x = self.global_pooling(x) 98 | x = self.dropout(x) 99 | # print('Full Image Encoder Pool:', x.size()) 100 | x = x.view(-1, 2048 * self.h_ * self.w_) 101 | # print('Full Image Encoder View1:', x.size()) 102 | x = self.global_fc(x) 103 | x = self.relu(x) 104 | # print('Full Image Encoder FC:', x.size()) 105 | x = x.view(-1, 512, 1, 1) 106 | # print('Full Image Encoder View2:', x.size()) 107 | x = self.conv1(x) 108 | x = self.relu(x) 109 | x = F.interpolate(x, size=(self.h, self.w), mode='bilinear', align_corners=True) # the "COPY" upsampling 110 | # print('Full Image Encoder Upsample:', x.size()) 111 | return x 112 | 113 | 114 | class OrdinalRegressionLayer(nn.Module): 115 | def __init__(self): 116 | super(OrdinalRegressionLayer, self).__init__() 117 | 118 | def forward(self, x): 119 | """ 120 | :param x: N x 2K x H x W; N - batch_size, 2K - channels, K - number of discrete sub-intervals 121 | :return: labels - ordinal labels (corresponding to discrete depth values) of size N x 1 x H x W 122 | softmax - predicted softmax probabilities P (as in the paper) of size N x K x H x W 123 | """ 124 | N, K, H, W = x.size() 125 | K = K // 2 # number of discrete sub-intervals 126 | 127 | odd = x[:, ::2, :, :].clone() 128 | even = x[:, 1::2, :, :].clone() 129 | 130 | odd = odd.view(N, 1, K * H * W) 131 | even = even.view(N, 1, K * H * W) 132 | 133 | paired_channels = torch.cat((odd, even), dim=1) 134 | paired_channels = paired_channels.clamp(min=1e-8, max=1e8) # prevent nans 135 | 136 | softmax = nn.functional.softmax(paired_channels, dim=1) 137 | 138 | softmax = softmax[:, 1, :] 139 | softmax = softmax.view(-1, K, H, W) 140 | labels = torch.sum((softmax > 0.5), dim=1).view(-1, 1, H, W) 141 | return labels, softmax 142 | 143 | 144 | class DORN(nn.Module): 145 | def __init__(self, dataset, pretrained=False): 146 | if not (dataset == 'kitti' or dataset == 'nyu'): 147 | raise NotImplementedError('Supported datasets: kitti | nyu (got %s)' % dataset) 148 | 149 | super(DORN, self).__init__() 150 | self.pretrained = pretrained 151 | 152 | self.dense_feature_extractor = resnet101dilated(pretrained=pretrained) 153 | self.scene_understanding_modulule = SceneUnderstandingModule(dataset=dataset) 154 | self.ordinal_regression = OrdinalRegressionLayer() 155 | 156 | weights_init(self.scene_understanding_modulule) 157 | weights_init(self.ordinal_regression) 158 | 159 | def forward(self, x): 160 | # Input image size KITTI: (385, 513), NYU: (257, 353) 161 | x = self.dense_feature_extractor(x) # Output KITTI: [batch, 2048, 49, 65], NYU: [batch, 2048, 33, 45]. 162 | x = self.scene_understanding_modulule(x) # Output has the same spatial size as the input image, but with 2K channels.
163 | labels, softmax = self.ordinal_regression(x) 164 | return labels, softmax 165 | 166 | def train(self, mode=True): 167 | """ 168 | Override train() to keep BN and the first two conv layers frozen. 169 | """ 170 | super().train(mode) 171 | 172 | if self.pretrained: 173 | # Freeze BatchNorm layers 174 | for module in self.modules(): 175 | if isinstance(module, nn.modules.BatchNorm2d): 176 | module.eval() 177 | 178 | # Freeze first two conv layers 179 | self.dense_feature_extractor.conv1.eval() 180 | self.dense_feature_extractor.conv2.eval() 181 | 182 | return self 183 | 184 | def get_1x_lr_params(self): 185 | for k in self.dense_feature_extractor.parameters(): 186 | if k.requires_grad: 187 | yield k 188 | 189 | def get_10x_lr_params(self): 190 | for module in [self.scene_understanding_modulule, self.ordinal_regression]: 191 | for k in module.parameters(): 192 | if k.requires_grad: 193 | yield k -------------------------------------------------------------------------------- /progress_tracking.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | def log10(x): 8 | """Return a new tensor with the base-10 logarithm of the elements of x. """ 9 | return torch.log(x) / math.log(10) 10 | 11 | class Result(object): 12 | def __init__(self): 13 | self.irmse, self.imae = 0, 0 14 | self.mse, self.rmse, self.mae = 0, 0, 0 15 | self.absrel, self.lg10 = 0, 0 16 | self.delta1, self.delta2, self.delta3 = 0, 0, 0 17 | 18 | def set_to_worst(self): 19 | self.irmse, self.imae = np.inf, np.inf 20 | self.mse, self.rmse, self.mae = np.inf, np.inf, np.inf 21 | self.absrel, self.lg10 = np.inf, np.inf 22 | self.delta1, self.delta2, self.delta3 = 0, 0, 0 23 | 24 | def update(self, irmse, imae, mse, rmse, mae, absrel, lg10, delta1, delta2, delta3): 25 | self.irmse, self.imae = irmse, imae 26 | self.mse, self.rmse, self.mae = mse, rmse, mae 27 | self.absrel, self.lg10 = absrel, lg10 28 | self.delta1, self.delta2, self.delta3 = delta1, delta2, delta3 29 | 30 | def evaluate(self, output, target): 31 | valid_mask = target>0 32 | output = output[valid_mask] 33 | target = target[valid_mask] 34 | 35 | abs_diff = (output - target).abs() 36 | 37 | self.mse = float((torch.pow(abs_diff, 2)).mean()) 38 | self.rmse = math.sqrt(self.mse) 39 | self.mae = float(abs_diff.mean()) 40 | self.lg10 = float((log10(output) - log10(target)).abs().mean()) 41 | self.absrel = float((abs_diff / target).mean()) 42 | 43 | maxRatio = torch.max(output / target, target / output) 44 | self.delta1 = float((maxRatio < 1.25).float().mean()) 45 | self.delta2 = float((maxRatio < 1.25 ** 2).float().mean()) 46 | self.delta3 = float((maxRatio < 1.25 ** 3).float().mean()) 47 | 48 | inv_output = 1 / output 49 | inv_target = 1 / target 50 | abs_inv_diff = (inv_output - inv_target).abs() 51 | self.irmse = math.sqrt((torch.pow(abs_inv_diff, 2)).mean()) 52 | self.imae = float(abs_inv_diff.mean()) 53 | 54 | 55 | class AverageMeter(object): 56 | def __init__(self): 57 | self.reset() 58 | 59 | def reset(self): 60 | self.count = 0.0 61 | 62 | self.sum_irmse, self.sum_imae = 0, 0 63 | self.sum_mse, self.sum_rmse, self.sum_mae = 0, 0, 0 64 | self.sum_absrel, self.sum_lg10 = 0, 0 65 | self.sum_delta1, self.sum_delta2, self.sum_delta3 = 0, 0, 0 66 | 67 | def update(self, result, n=1): 68 | self.count += n 69 | 70 | self.sum_irmse += n*result.irmse 71 | self.sum_imae += n*result.imae 72 | self.sum_mse += n*result.mse 73 | self.sum_rmse += 
n*result.rmse 74 | self.sum_mae += n*result.mae 75 | self.sum_absrel += n*result.absrel 76 | self.sum_lg10 += n*result.lg10 77 | self.sum_delta1 += n*result.delta1 78 | self.sum_delta2 += n*result.delta2 79 | self.sum_delta3 += n*result.delta3 80 | 81 | def average(self): 82 | avg = Result() 83 | avg.update( 84 | self.sum_irmse / self.count, self.sum_imae / self.count, 85 | self.sum_mse / self.count, self.sum_rmse / self.count, self.sum_mae / self.count, 86 | self.sum_absrel / self.count, self.sum_lg10 / self.count, 87 | self.sum_delta1 / self.count, self.sum_delta2 / self.count, self.sum_delta3 / self.count) 88 | return avg 89 | 90 | def log(self, logger, epoch, stage="Train"): 91 | avg = self.average() 92 | logger.add_scalar(stage + '/RMSE', avg.rmse, epoch) 93 | logger.add_scalar(stage + '/rml', avg.absrel, epoch) 94 | logger.add_scalar(stage + '/Log10', avg.lg10, epoch) 95 | logger.add_scalar(stage + '/Delta1', avg.delta1, epoch) 96 | logger.add_scalar(stage + '/Delta2', avg.delta2, epoch) 97 | logger.add_scalar(stage + '/Delta3', avg.delta3, epoch) 98 | 99 | 100 | class ImageBuilder(object): 101 | """ 102 | Builds an image iteratively row by row where the columns are (input image, target depth map, output depth map). 103 | """ 104 | def __init__(self): 105 | self.count = 0 106 | self.img_merge = None 107 | 108 | def has_image(self): 109 | return self.img_merge is not None 110 | 111 | def get_image(self): 112 | return torch.from_numpy(np.transpose(self.img_merge, (2, 0, 1)) / 255.0) 113 | 114 | def add_row(self, input, target, depth): 115 | if self.count == 0: 116 | self.img_merge = self.merge_into_row(input, target, depth) 117 | else: 118 | row = self.merge_into_row(input, target, depth) 119 | self.img_merge = np.vstack([self.img_merge, row]) 120 | 121 | self.count += 1 122 | 123 | @staticmethod 124 | def colored_depthmap(depth, d_min=None, d_max=None): 125 | if d_min is None: 126 | d_min = np.min(depth) 127 | if d_max is None: 128 | d_max = np.max(depth) 129 | depth_relative = (depth - d_min) / (d_max - d_min) 130 | return 255 * plt.cm.jet(depth_relative)[:, :, :3] # H, W, C 131 | 132 | @staticmethod 133 | def merge_into_row(input, depth_target, depth_pred): 134 | rgb = 255 * np.transpose(np.squeeze(input.cpu().numpy()), (1, 2, 0)) # H, W, C 135 | depth_target_cpu = np.squeeze(depth_target.cpu().numpy()) 136 | depth_pred_cpu = np.squeeze(depth_pred.data.cpu().numpy()) 137 | 138 | d_min = min(np.min(depth_target_cpu), np.min(depth_pred_cpu)) 139 | d_max = max(np.max(depth_target_cpu), np.max(depth_pred_cpu)) 140 | depth_target_col = ImageBuilder.colored_depthmap(depth_target_cpu, d_min, d_max) 141 | depth_pred_col = ImageBuilder.colored_depthmap(depth_pred_cpu, d_min, d_max) 142 | img_merge = np.hstack([rgb, depth_target_col, depth_pred_col]) 143 | 144 | return img_merge -------------------------------------------------------------------------------- /resnet_dilated.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | import urllib 5 | 6 | # ------------------------------------------------------------------------- 7 | # This version of Resnet101 was used in PSPnet and DORN (according to https://github.com/hufu6371/DORN) 8 | # Code based on https://github.com/speedinghzl/pytorch-segmentation-toolbox/blob/master/networks/pspnet.py 9 | 10 | def conv3x3(in_planes, out_planes, stride=1): 11 | "3x3 convolution with padding" 12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 
padding=1, bias=False) 13 | 14 | 15 | class Bottleneck(nn.Module): 16 | expansion = 4 17 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, fist_dilation=1): 18 | super(Bottleneck, self).__init__() 19 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 20 | self.bn1 = nn.BatchNorm2d(planes) 21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 22 | padding=dilation, dilation=dilation, bias=False) 23 | self.bn2 = nn.BatchNorm2d(planes) 24 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 25 | self.bn3 = nn.BatchNorm2d(planes * 4) 26 | self.relu = nn.ReLU(inplace=False) 27 | self.relu_inplace = nn.ReLU(inplace=True) 28 | self.downsample = downsample 29 | self.dilation = dilation 30 | self.stride = stride 31 | 32 | def forward(self, x): 33 | residual = x 34 | 35 | out = self.conv1(x) 36 | out = self.bn1(out) 37 | out = self.relu(out) 38 | 39 | out = self.conv2(out) 40 | out = self.bn2(out) 41 | out = self.relu(out) 42 | 43 | out = self.conv3(out) 44 | out = self.bn3(out) 45 | 46 | if self.downsample is not None: 47 | residual = self.downsample(x) 48 | 49 | out = out + residual 50 | out = self.relu_inplace(out) 51 | 52 | return out 53 | 54 | 55 | class ResNet(nn.Module): 56 | def __init__(self, block, layers): 57 | super(ResNet, self).__init__() 58 | self.inplanes = 128 59 | 60 | self.conv1 = conv3x3(3, 64, stride=2) 61 | self.bn1 = nn.BatchNorm2d(64) 62 | self.relu1 = nn.ReLU(inplace=False) 63 | 64 | self.conv2 = conv3x3(64, 64) 65 | self.bn2 = nn.BatchNorm2d(64) 66 | self.relu2 = nn.ReLU(inplace=False) 67 | 68 | self.conv3 = conv3x3(64, 128) 69 | self.bn3 = nn.BatchNorm2d(128) 70 | self.relu3 = nn.ReLU(inplace=False) 71 | 72 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 73 | 74 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True) # change 75 | self.layer1 = self._make_layer(block, 64, layers[0]) 76 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 77 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2) 78 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4) 79 | 80 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 81 | downsample = None 82 | if stride != 1 or self.inplanes != planes * block.expansion: 83 | downsample = nn.Sequential( 84 | nn.Conv2d(self.inplanes, planes * block.expansion, 85 | kernel_size=1, stride=stride, bias=False), 86 | nn.BatchNorm2d(planes * block.expansion,affine = True)) 87 | 88 | layers = [] 89 | layers.append(block(self.inplanes, planes, stride, dilation=dilation, downsample=downsample)) 90 | self.inplanes = planes * block.expansion 91 | for i in range(1, blocks): 92 | layers.append(block(self.inplanes, planes, dilation=dilation)) 93 | 94 | return nn.Sequential(*layers) 95 | 96 | def forward(self, x): 97 | x = self.relu1(self.bn1(self.conv1(x))) 98 | x = self.relu2(self.bn2(self.conv2(x))) 99 | x = self.relu3(self.bn3(self.conv3(x))) 100 | x = self.maxpool(x) 101 | x = self.layer1(x) 102 | x = self.layer2(x) 103 | x = self.layer3(x) 104 | x = self.layer4(x) 105 | return x 106 | 107 | 108 | def resnet101dilated(pretrained=False): 109 | model = ResNet(Bottleneck,[3, 4, 23, 3]) 110 | 111 | if pretrained: 112 | # Download pretrained model if it does not exist 113 | # This is ADE20K-pretrained encoder from https://github.com/CSAILVision/semantic-segmentation-pytorch 114 | # It is the only pretrained resnet101dilated I could find online 115 | # (with 
the same architecture as in the paper - see https://github.com/hufu6371/DORN/tree/master/models) 116 | filename = './pretrained/ade20k_resnet101dilated_encoder_epoch_25.pth' 117 | if not os.path.isfile(filename): 118 | os.makedirs(os.path.dirname(filename)) 119 | print('Pretrained feature extractor not found. Downloading...') 120 | url = 'http://sceneparsing.csail.mit.edu/model/pytorch/ade20k-resnet101dilated-ppm_deepsup/encoder_epoch_25.pth' 121 | urllib.request.urlretrieve(url, filename) 122 | print('Download completed:', filename) 123 | 124 | # Load pretrained parameters 125 | saved_state_dict = torch.load(filename, map_location='cpu') 126 | 127 | new_params = model.state_dict().copy() 128 | for i in saved_state_dict: 129 | i_parts = i.split('.') 130 | if not i_parts[0] == 'fc' and i.find('._') == -1: 131 | new_params['.'.join(i_parts[0:])] = saved_state_dict[i] 132 | 133 | model.load_state_dict(new_params) 134 | 135 | return model -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | from model import DORN 4 | from data import get_dataloaders 5 | from loss import OrdinalLoss 6 | from lr_decay import PolynomialLRDecay 7 | from discritization import SID 8 | from progress_tracking import AverageMeter, Result, ImageBuilder 9 | from tensorboardX import SummaryWriter 10 | from datetime import datetime 11 | import os, socket 12 | 13 | LOG_IMAGES = 3 # number of images per epoch to log with tensorboard 14 | 15 | # Parse arguments 16 | parser = argparse.ArgumentParser(description='DORN depth estimation in PyTorch') 17 | parser.add_argument('--dataset', default='nyu', type=str, help='dataset: kitti or nyu (default: nyu)') 18 | parser.add_argument('--data-path', default='./nyu_official', type=str, help='path to the dataset') 19 | parser.add_argument("--pretrained", action='store_true', help="use a pretrained feature extractor") 20 | parser.add_argument('--epochs', default=200, type=int, help='n of epochs (default: 200)') 21 | parser.add_argument('--bs', default=3, type=int, help='[train] batch size(default: 3)') 22 | parser.add_argument('--bs-test', default=3, type=int, help='[test] batch size (default: 3)') 23 | parser.add_argument('--lr', default=1e-4, type=float, help='learning rate (default: 1e-4)') 24 | parser.add_argument('--gpu', default='0', type=str, help='GPU id to use (default: 0)') 25 | args = parser.parse_args() 26 | print(args) 27 | 28 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 29 | train_loader, val_loader = get_dataloaders(args.dataset, args.data_path, args.bs, args.bs_test) 30 | model = DORN(dataset=args.dataset, pretrained=args.pretrained).cuda() 31 | train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr}, {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}] 32 | optimizer = torch.optim.SGD(train_params, lr=args.lr, momentum=0.9, weight_decay=0.0005) 33 | lr_decay = PolynomialLRDecay(optimizer, args.epochs, args.lr * 1e-2) 34 | criterion = OrdinalLoss() 35 | sid = SID(args.dataset) 36 | 37 | # Create logger 38 | log_dir = os.path.join(os.path.abspath(os.getcwd()), 'logs', datetime.now().strftime('%b%d_%H-%M-%S_') + socket.gethostname()) 39 | os.makedirs(log_dir) 40 | logger = SummaryWriter(log_dir) 41 | # Log arguments 42 | with open(os.path.join(log_dir, "args.txt"), "a") as f: 43 | print(args, file=f) 44 | 45 | for epoch in range(args.epochs): 46 | # log learning rate 47 | for i, param_group in 
enumerate(optimizer.param_groups): 48 | logger.add_scalar('Lr/lr_' + str(i), float(param_group['lr']), epoch) 49 | 50 | print('Epoch', epoch, 'train in progress...') 51 | model.train() 52 | average_meter = AverageMeter() 53 | image_builder = ImageBuilder() 54 | for i, (input, target) in enumerate(train_loader): 55 | input, target = input.cuda(), target.cuda() 56 | 57 | pred_labels, pred_softmax = model(input) 58 | target_labels = sid.depth2labels(target) # get ground truth ordinal labels using SID 59 | loss = criterion(pred_softmax, target_labels) 60 | optimizer.zero_grad() 61 | loss.backward() 62 | optimizer.step() 63 | 64 | # track performance scores 65 | depth = sid.labels2depth(pred_labels) 66 | result = Result() 67 | result.evaluate(depth.data, target.data) 68 | average_meter.update(result, input.size(0)) 69 | if i <= LOG_IMAGES: 70 | image_builder.add_row(input[0,:,:,:], target[0,:,:], depth[0,:,:]) 71 | 72 | # log performance scores with tensorboard 73 | average_meter.log(logger, epoch, 'Train') 74 | if LOG_IMAGES: 75 | logger.add_image('Train/Image', image_builder.get_image(), epoch) 76 | 77 | lr_decay.step() 78 | 79 | print('Epoch', epoch, 'test in progress...') 80 | model.eval() 81 | average_meter = AverageMeter() 82 | image_builder = ImageBuilder() 83 | for i, (input, target) in enumerate(val_loader): 84 | input, target = input.cuda(), target.cuda() 85 | 86 | with torch.no_grad(): 87 | pred_labels, _ = model(input) 88 | 89 | # track performance scores 90 | pred = sid.labels2depth(pred_labels) 91 | result = Result() 92 | result.evaluate(pred.data, target.data) 93 | average_meter.update(result, input.size(0)) 94 | if i <= LOG_IMAGES: 95 | image_builder.add_row(input[0,:,:,:], target[0,:,:], pred[0,:,:]) 96 | 97 | # log performance scores with tensorboard 98 | average_meter.log(logger, epoch, 'Test') 99 | if LOG_IMAGES: 100 | logger.add_image('Test/Image', image_builder.get_image(), epoch) 101 | 102 | print() 103 | 104 | logger.close() 105 | --------------------------------------------------------------------------------