├── 1.png
├── test.py
├── criteria.py
├── utils.py
├── metrics.py
├── models.py
├── README.md
├── nyu_dataloader.py
├── transforms.py
└── main.py

/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lishunkai/DenseDepthMapCreationFromSparsePoints/HEAD/1.png
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | 
4 | rgb = cv2.imread('1.png')
5 | gray = cv2.cvtColor(rgb,cv2.COLOR_BGR2GRAY)
6 | #gray = rgb[:,:,0] * 0.2989 + rgb[:,:,1] * 0.587 + rgb[:,:,2] * 0.114
7 | orb = cv2.ORB_create(500)
8 | kp = orb.detect(gray, None)
9 | print(len(kp))
10 | 
11 | sparse_depth = np.zeros(rgb.shape)
12 | 
13 | for i in range(len(kp)):
14 |     tu = kp[i].pt
15 |     x = int(tu[0])
16 |     y = int(tu[1])
17 |     print(i,' ',x,' ',y)
18 |     sparse_depth[y,x] = rgb[y,x] # note: (x, y) is (column, row), so array indexing is [y, x]
19 | 
20 | 
21 | # print(len(kp))     # number of keypoints
22 | # print(kp[0])       # one index per detected corner
23 | # tu = kp[0].pt      # .pt is the coordinate tuple (x, y)
24 | # print(tu[0],tu[1]) # print the x, y coordinates of the first keypoint
25 | 
26 | 
27 | # def create_sparse_depth_ORB(self, rgb, depth, prob):
28 | #     num_samples = int(prob * depth.size)
29 | #     gray = rgb2grayscale(rgb)
30 | #     orb = cv2.ORB_create(num_samples)
31 | #     kp = orb.detect(gray, None)
32 | #     print(len(kp))
33 | 
34 | #     mask_keep = np.random.uniform(0, 1, depth.shape) < prob # random 0-1 mask: keep the entries whose random draw falls below the preset probability
35 | #     sparse_depth = np.zeros(depth.shape)
36 | #     sparse_depth[mask_keep] = depth[mask_keep] # copy the depth values selected by mask_keep into the corresponding entries of sparse_depth
37 | #     return sparse_depth
--------------------------------------------------------------------------------
/criteria.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | 
5 | # L2 norm
6 | class MaskedMSELoss(nn.Module):
7 |     # nn.Module is the base class every PyTorch network inherits from; subclassing it makes building networks straightforward
8 |     # see https://blog.csdn.net/u012436149/article/details/78281553
9 |     def __init__(self):
10 |         # super() delegates to the parent (super) class
11 |         # see https://www.cnblogs.com/HoMe-Lin/p/5745297.html
12 |         super(MaskedMSELoss, self).__init__()
13 | 
14 |     def forward(self, pred, target):
15 |         # Until a program is fully debugged we cannot know where it may fail. Rather than crash at the
16 |         # end of a long run, it is better to fail as soon as an error condition appears, which is what assert is for.
17 |         # see https://www.cnblogs.com/liuchunxiao83/p/5298016.html
18 |         assert pred.dim() == target.dim(), "inconsistent dimensions"
19 |         valid_mask = (target>0).detach() # only pixels with valid (positive) ground-truth depth contribute to the loss
20 |         diff = target - pred
21 |         diff = diff[valid_mask]
22 |         self.loss = (diff ** 2).mean()
23 |         return self.loss
24 | 
25 | # L1 norm
26 | class MaskedL1Loss(nn.Module):
27 |     def __init__(self):
28 |         super(MaskedL1Loss, self).__init__()
29 | 
30 |     def forward(self, pred, target):
31 |         assert pred.dim() == target.dim(), "inconsistent dimensions"
32 |         valid_mask = (target>0).detach()
33 |         diff = target - pred
34 |         diff = diff[valid_mask]
35 |         self.loss = diff.abs().mean()
36 |         return self.loss
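37 | 
38 | # A minimal usage sketch (not part of the original repo): both losses take a
39 | # prediction and a ground-truth depth map of the same shape and ignore all
40 | # pixels whose target depth is 0 (i.e. missing measurements).
41 | if __name__ == '__main__':
42 |     pred = torch.rand(1, 1, 4, 4)          # fake network output
43 |     target = torch.rand(1, 1, 4, 4)
44 |     target[..., :2] = 0                    # mark some pixels as invalid
45 |     print(MaskedMSELoss()(pred, target))   # mean over valid pixels only
46 |     print(MaskedL1Loss()(pred, target))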
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # Helper functions for plotting and saving images
2 | 
3 | import numpy as np
4 | import matplotlib.pyplot as plt # plotting utilities
5 | from PIL import Image # Python Imaging Library
6 | 
7 | cmap = plt.cm.jet # cmap: colormap; see https://blog.csdn.net/haoji007/article/details/52063168
8 | 
9 | def merge_into_row(input, target, depth_pred):
10 |     # np.squeeze(): remove single-dimensional entries from the shape of an array, i.e. drop all dimensions of size 1
11 |     # for high-dimensional arrays, transpose takes a tuple of axis indices; see https://www.cnblogs.com/sunshinewang/p/6893503.html
12 |     rgb = 255 * np.transpose(np.squeeze(input.cpu().numpy()), (1,2,0)) # H, W, C
13 |     depth = np.squeeze(target.cpu().numpy())
14 |     depth = (depth - np.min(depth)) / (np.max(depth) - np.min(depth))
15 |     depth = 255 * cmap(depth)[:,:,:3] # H, W, C; cmap returns RGBA, so [:,:,:3] keeps only the RGB channels
16 |     pred = np.squeeze(depth_pred.data.cpu().numpy())
17 |     pred = (pred - np.min(pred)) / (np.max(pred) - np.min(pred))
18 |     pred = 255 * cmap(pred)[:,:,:3] # H, W, C
19 |     img_merge = np.hstack([rgb, depth, pred]) # stack the arrays horizontally, in input order
20 | 
21 |     # img_merge.save(output_directory + '/comparison_' + str(epoch) + '.png')
22 |     return img_merge
23 | 
24 | def add_row(img_merge, row):
25 |     return np.vstack([img_merge, row]) # stack the arrays vertically, in input order
26 | 
27 | def save_image(img_merge, filename):
28 |     img_merge = Image.fromarray(img_merge.astype('uint8')) # set the data format for saving
29 |     img_merge.save(filename)
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | # Evaluation metrics for the model
2 | 
3 | import torch
4 | import math
5 | import numpy as np
6 | 
7 | def log10(x):
8 |     """Convert a new tensor with the base-10 logarithm of the elements of x. """
9 |     return torch.log(x) / math.log(10)
10 | 
11 | # new-style class, see https://www.cnblogs.com/liulipeng/p/7069004.html
12 | class Result(object):
13 |     # names starting with a double underscore are private and cannot be accessed directly from outside the class
14 |     # __init__ supports parameterized initialization and declares the attributes of the class; its first parameter must be self
15 |     # inside a class, def defines a method, and every method must take self as its first parameter
16 |     # Python functions must be defined before they are called
17 |     # self plays the same role as the this pointer in C++ or the this parameter in Java/C#:
18 |     # it refers to the instance being operated on, so attribute values and method results differ per instance
19 |     # see https://blog.csdn.net/ly_ysys629/article/details/54893185
20 |     def __init__(self):
21 |         # multiple assignment on one line
22 |         self.irmse, self.imae = 0, 0
23 |         self.mse, self.rmse, self.mae = 0, 0, 0
24 |         self.absrel, self.lg10 = 0, 0
25 |         self.delta1, self.delta2, self.delta3 = 0, 0, 0
26 |         self.data_time, self.gpu_time = 0, 0
27 | 
28 |     def set_to_worst(self): # initialize the worst case to infinity
29 |         self.irmse, self.imae = np.inf, np.inf
30 |         self.mse, self.rmse, self.mae = np.inf, np.inf, np.inf
31 |         self.absrel, self.lg10 = np.inf, np.inf
32 |         self.delta1, self.delta2, self.delta3 = 0, 0, 0
33 |         self.data_time, self.gpu_time = 0, 0
34 | 
35 |     def update(self, irmse, imae, mse, rmse, mae, absrel, lg10, delta1, delta2, delta3, gpu_time, data_time):
36 |         self.irmse, self.imae = irmse, imae
37 |         self.mse, self.rmse, self.mae = mse, rmse, mae
38 |         self.absrel, self.lg10 = absrel, lg10
39 |         self.delta1, self.delta2, self.delta3 = delta1, delta2, delta3
40 |         self.data_time, self.gpu_time = data_time, gpu_time
41 | 
42 |     def evaluate(self, output, target):
43 |         valid_mask = target>0 # valid_mask holds True/False per element
44 |         output = output[valid_mask] # boolean-mask indexing returns a flat sequence of the valid elements
45 |         target = target[valid_mask]
46 | 
47 |         abs_diff = (output - target).abs()
48 | 
49 |         self.mse = (torch.pow(abs_diff, 2)).mean() # mean squared error; torch.pow() and Python's ** are equivalent on tensors
50 |         self.rmse = math.sqrt(self.mse) # root mean squared error
51 |         self.mae = abs_diff.mean() # Mean Absolute Error
52 |         self.lg10 = (log10(output) - log10(target)).abs().mean()
53 |         self.absrel = (abs_diff / target).mean()
54 | 
55 |         maxRatio = torch.max(output / target, target / output)
56 |         self.delta1 = (maxRatio < 1.25).float().mean()
57 |         self.delta2 = (maxRatio < 1.25 ** 2).float().mean()
58 |         self.delta3 = (maxRatio < 1.25 ** 3).float().mean()
59 |         self.data_time = 0
60 |         self.gpu_time = 0
61 | 
62 |         inv_output = 1 / output
63 |         inv_target = 1 / target
64 |         abs_inv_diff = (inv_output - inv_target).abs()
65 |         self.irmse = math.sqrt((torch.pow(abs_inv_diff, 2)).mean())
66 |         self.imae = abs_inv_diff.mean()
67 | 
68 | 
69 | class AverageMeter(object):
70 |     def __init__(self):
71 |         self.reset()
72 | 
73 |     def reset(self):
74 |         self.count = 0.0
75 | 
76 |         self.sum_irmse, self.sum_imae = 0, 0
77 |         self.sum_mse, self.sum_rmse, self.sum_mae = 0, 0, 0
78 |         self.sum_absrel, self.sum_lg10 = 0, 0
79 |         self.sum_delta1, self.sum_delta2, self.sum_delta3 = 0, 0, 0
80 |         self.sum_data_time, self.sum_gpu_time = 0, 0
81 | 
82 |     def update(self, result, gpu_time, data_time, n=1):
83 |         self.count += n
84 | 
85 |         self.sum_irmse += n*result.irmse
86 |         self.sum_imae += n*result.imae
87 |         self.sum_mse += n*result.mse
88 |         self.sum_rmse += n*result.rmse
89 |         self.sum_mae += n*result.mae
90 |         self.sum_absrel += n*result.absrel
91 |         self.sum_lg10 += n*result.lg10
92 |         self.sum_delta1 += n*result.delta1
93 |         self.sum_delta2 += n*result.delta2
94 |         self.sum_delta3 += n*result.delta3
95 |         self.sum_data_time += n*data_time
96 |         self.sum_gpu_time += n*gpu_time
97 | 
98 |     def average(self):
99 |         avg = Result() # instantiate a Result object
100 |         # call update on the avg object
101 |         avg.update(
102 |             self.sum_irmse / self.count, self.sum_imae / self.count,
103 |             self.sum_mse / self.count, self.sum_rmse / self.count, self.sum_mae / self.count,
104 |             self.sum_absrel / self.count, self.sum_lg10 / self.count,
105 |             self.sum_delta1 / self.count, self.sum_delta2 / self.count, self.sum_delta3 / self.count,
106 |             self.sum_gpu_time / self.count, self.sum_data_time / self.count)
107 |         return avg
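108 | 
109 | # A minimal usage sketch (not part of the original repo): evaluate one fake
110 | # prediction and fold it into an AverageMeter, mirroring what main.py does
111 | # once per batch. Depths must be positive for the log/inverse metrics.
112 | if __name__ == '__main__':
113 |     output = torch.rand(100) + 0.5
114 |     target = torch.rand(100) + 0.5
115 |     result = Result()
116 |     result.evaluate(output, target)
117 |     meter = AverageMeter()
118 |     meter.update(result, gpu_time=0.05, data_time=0.01, n=1)
119 |     avg = meter.average()
120 |     print('RMSE={:.3f} Delta1={:.3f}'.format(avg.rmse, float(avg.delta1)))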
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torchvision.models
5 | import collections
6 | import math
7 | # for a detailed walk-through of implementing CNNs in PyTorch, see http://developer.51cto.com/art/201708/548220.htm
8 | # on the role of encoder layers in deep learning, see https://zhuanlan.zhihu.com/p/27549418
9 | 
10 | oheight, owidth = 228, 304
11 | 
12 | def weights_init(m):
13 |     # Initialize filters with Gaussian random weights
14 |     if isinstance(m, nn.Conv2d):
15 |         n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
16 |         m.weight.data.normal_(0, math.sqrt(2. / n))
17 |         if m.bias is not None:
18 |             m.bias.data.zero_()
19 |     elif isinstance(m, nn.ConvTranspose2d): # ConvTranspose2d is the reverse of a convolution, and in a loose sense can be treated as a deconvolution
20 |         n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
21 |         m.weight.data.normal_(0, math.sqrt(2. / n))
22 |         if m.bias is not None:
23 |             m.bias.data.zero_()
24 |     elif isinstance(m, nn.BatchNorm2d):
25 |         m.weight.data.fill_(1)
26 |         m.bias.data.zero_()
27 | 
28 | class Decoder(nn.Module):
29 |     # Decoder is the base class for all decoders
30 | 
31 |     # nn.Module is the base class every PyTorch network inherits from; subclassing it makes building networks straightforward
32 |     # see https://blog.csdn.net/u012436149/article/details/78281553
33 | 
34 |     names = ['deconv{}'.format(i) for i in range(2,10)]
35 | 
36 |     def __init__(self):
37 |         super(Decoder, self).__init__()
38 |         # super() delegates to the parent (super) class
39 |         # see https://www.cnblogs.com/HoMe-Lin/p/5745297.html
40 | 
41 |         self.layer1 = None
42 |         self.layer2 = None
43 |         self.layer3 = None
44 |         self.layer4 = None
45 | 
46 |     def forward(self, x):
47 |         x = self.layer1(x)
48 |         x = self.layer2(x)
49 |         x = self.layer3(x)
50 |         x = self.layer4(x)
51 |         return x
52 | 
53 | class DeConv(Decoder):
54 |     def __init__(self, in_channels, kernel_size):
55 |         assert kernel_size>=2, "kernel_size out of range: {}".format(kernel_size)
56 |         super(DeConv, self).__init__()
57 | 
58 |         def convt(in_channels):
59 |             stride = 2
60 |             padding = (kernel_size - 1) // 2
61 |             output_padding = kernel_size % 2
62 |             assert -2 - 2*padding + kernel_size + output_padding == 0, "deconv parameters incorrect"
63 | 
64 |             module_name = "deconv{}".format(kernel_size)
65 |             return nn.Sequential(collections.OrderedDict([
66 |                 (module_name, nn.ConvTranspose2d(in_channels,in_channels//2,kernel_size,
67 |                     stride,padding,output_padding,bias=False)),
68 |                 ('batchnorm', nn.BatchNorm2d(in_channels//2)),
69 |                 ('relu', nn.ReLU(inplace=True)),
70 |             ]))
71 | 
72 |         self.layer1 = convt(in_channels)
73 |         self.layer2 = convt(in_channels // 2)
74 |         self.layer3 = convt(in_channels // (2 ** 2))
75 |         self.layer4 = convt(in_channels // (2 ** 3))
76 | 
77 | 
78 | def choose_decoder(decoder):
79 |     assert decoder[:6] == 'deconv' # [:6]: the first six characters; [1:]: everything from index 1 onwards
80 |     assert len(decoder)==7
81 | 
82 |     num_channels = 512
83 |     iheight, iwidth = 10, 8
84 |     kernel_size = int(decoder[6]) # the 7th character of the decoder string; as in C++, Python indexing starts at 0
85 |     return DeConv(num_channels, kernel_size)
86 | 
87 | 
88 | class ResNet(nn.Module):
89 |     def __init__(self, layers, decoder, in_channels=3, out_channels=1, pretrained=True):
90 | 
91 |         if layers not in [18, 34, 50, 101, 152]: # the standard ResNet depths
92 |             raise RuntimeError('Only 18, 34, 50, 101, and 152 layer models are defined for ResNet. Got {}'.format(layers))
93 | 
94 |         super(ResNet, self).__init__()
95 |         pretrained_model = torchvision.models.__dict__['resnet{}'.format(layers)](pretrained=pretrained)
96 | 
97 |         if in_channels == 3:
98 |             self.conv1 = pretrained_model._modules['conv1']
99 |             self.bn1 = pretrained_model._modules['bn1']
100 |         else:
101 |             self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
102 |             self.bn1 = nn.BatchNorm2d(64)
103 |             weights_init(self.conv1)
104 |             weights_init(self.bn1)
105 | 
106 |         self.relu = pretrained_model._modules['relu']
107 |         self.maxpool = pretrained_model._modules['maxpool']
108 |         self.layer1 = pretrained_model._modules['layer1']
109 |         self.layer2 = pretrained_model._modules['layer2']
110 |         self.layer3 = pretrained_model._modules['layer3']
111 |         self.layer4 = pretrained_model._modules['layer4']
112 | 
113 |         # clear memory
114 |         del pretrained_model
115 | 
116 |         # define number of intermediate channels
117 |         if layers <= 34:
118 |             num_channels = 512
119 |         elif layers >= 50:
120 |             num_channels = 2048
121 | 
122 |         self.conv2 = nn.Conv2d(num_channels,512,kernel_size=1,bias=False)
123 |         self.bn2 = nn.BatchNorm2d(512)
124 |         self.decoder = choose_decoder(decoder)
125 | 
126 |         # setting bias=true doesn't improve accuracy
127 |         self.conv3 = nn.Conv2d(32,out_channels,kernel_size=3,stride=1,padding=1,bias=False)
128 |         self.bilinear = nn.Upsample(size=(oheight, owidth), mode='bilinear')
129 | 
130 |         # weight init
131 |         self.conv2.apply(weights_init)
132 |         self.bn2.apply(weights_init)
133 |         self.decoder.apply(weights_init)
134 |         self.conv3.apply(weights_init)
135 | 
136 |     def forward(self, x):
137 |         # resnet
138 |         x = self.conv1(x)
139 |         x = self.bn1(x)
140 |         x = self.relu(x)
141 |         x = self.maxpool(x)
142 |         x = self.layer1(x)
143 |         x = self.layer2(x)
144 |         x = self.layer3(x)
145 |         x = self.layer4(x)
146 | 
147 |         x = self.conv2(x)
148 |         x = self.bn2(x)
149 | 
150 |         # decoder
151 |         x = self.decoder(x)
152 |         x = self.conv3(x)
153 |         x = self.bilinear(x)
154 | 
155 |         return x
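156 | 
157 | # A minimal shape-check sketch (not part of the original repo). pretrained=False
158 | # avoids downloading the ImageNet weights; in_channels=4 matches the 'rgbd' modality.
159 | if __name__ == '__main__':
160 |     net = ResNet(layers=18, decoder='deconv3', in_channels=4,
161 |                  out_channels=1, pretrained=False)
162 |     x = torch.randn(1, 4, oheight, owidth)
163 |     y = net(x)
164 |     print(y.shape)  # expected: (1, 1, 228, 304), thanks to the final bilinear upsampling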
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DenseDepthMapCreationFromSparsePoints
2 | 
3 | sparse-to-dense.pytorch
4 | ============================
5 | 
6 | This repo implements the training and testing of deep regression neural networks for ["Sparse-to-Dense: Depth Prediction from Sparse Depth Samples and a Single Image"](https://arxiv.org/pdf/1709.07492.pdf) by [Fangchang Ma](http://www.mit.edu/~fcma) and [Sertac Karaman](http://karaman.mit.edu/) at MIT. A video demonstration is available on [YouTube](https://youtu.be/vNIIT_M7x7Y).
7 | 
8 | photo not available
9 | photo not available
10 | 
11 | 
12 | This repo can be used for training and testing of
13 | - RGB (or grayscale image) based depth prediction
14 | - sparse depth based depth prediction
15 | - RGBd (i.e., both RGB and sparse depth) based depth prediction
16 | 
17 | The original Torch implementation of the paper can be found [here](https://github.com/fangchangma/sparse-to-dense). This PyTorch version is under development and is subject to major modifications in the future.
18 | 
19 | ## Contents
20 | 0. [Requirements](#requirements)
21 | 0. [Training](#training)
22 | 0. [Testing](#testing)
23 | 0. [Trained Models](#trained-models)
24 | 0. [Benchmark](#benchmark)
25 | 0. [Citation](#citation)
26 | 
27 | ## Requirements
28 | - Install [PyTorch](http://pytorch.org/) on a machine with a CUDA GPU.
29 | - Install [HDF5](https://en.wikipedia.org/wiki/Hierarchical_Data_Format) and other dependencies (files in our pre-processed datasets are in HDF5 format).
30 | ```bash
31 | sudo apt-get update
32 | sudo apt-get install -y libhdf5-serial-dev hdf5-tools
33 | pip install h5py matplotlib imageio scikit-image
34 | ```
35 | - Download the preprocessed [NYU Depth V2](http://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html) dataset in HDF5 format, and place it under the `data` folder. The download might take an hour or so. The NYU dataset requires 32G of storage space.
36 | ```bash
37 | mkdir data
38 | cd data
39 | wget http://datasets.lids.mit.edu/sparse-to-dense/data/nyudepthv2.tar.gz
40 | tar -xvf nyudepthv2.tar.gz && rm -f nyudepthv2.tar.gz
41 | cd ..
42 | ```
43 | ## Training
44 | The training scripts come with several options, which can be listed with the `--help` flag. Currently this repo only supports training on the NYU dataset, and deconvolution with different kernel sizes (no `upconv` or `upproj`, since we found them to be inefficient compared with a simple `deconv` with larger kernel sizes).
45 | ```bash
46 | python3 main.py --help
47 | ```
48 | 
49 | For instance, run the following command to train a network with ResNet50 as the encoder, deconvolutions of kernel size 3 as the decoder, and both RGB and 100 random sparse depth samples as the input to the network.
50 | ```bash
51 | python3 main.py -a resnet50 -d deconv3 -m rgbd -s 100
52 | ```
53 | 
54 | Training results will be saved under the `results` folder.
55 | 
56 | 
57 | ## Testing
58 | To test the performance of a trained model, simply run main.py with the `-e` option, along with other model options. For instance,
59 | ```bash
60 | python3 main.py -e
61 | ```
62 | 
63 | ## Trained Models
64 | Trained models will be released later.
65 | 
66 | ## Benchmark
67 | The following numbers are from the original Torch repo.
68 | - Error metrics on NYU Depth v2:
69 | 
70 | | RGB | rms | rel | delta1 | delta2 | delta3 |
71 | |-----------------------------|:-----:|:-----:|:-----:|:-----:|:-----:|
72 | | [Roy & Todorovic](http://web.engr.oregonstate.edu/~sinisa/research/publications/cvpr16_NRF.pdf) (_CVPR 2016_) | 0.744 | 0.187 | - | - | - |
73 | | [Eigen & Fergus](http://cs.nyu.edu/~deigen/dnl/) (_ICCV 2015_) | 0.641 | 0.158 | 76.9 | 95.0 | 98.8 |
74 | | [Laina et al](https://arxiv.org/pdf/1606.00373.pdf) (_3DV 2016_) | 0.573 | **0.127** | **81.1** | 95.3 | 98.8 |
75 | | Ours-RGB | **0.514** | 0.143 | 81.0 | **95.9** | **98.9** |
76 | 
77 | | RGBd-#samples | rms | rel | delta1 | delta2 | delta3 |
78 | |-----------------------------|:-----:|:-----:|:-----:|:-----:|:-----:|
79 | | [Liao et al](https://arxiv.org/abs/1611.02174) (_ICRA 2017_)-225 | 0.442 | 0.104 | 87.8 | 96.4 | 98.9 |
80 | | Ours-20 | 0.351 | 0.078 | 92.8 | 98.4 | 99.6 |
81 | | Ours-50 | 0.281 | 0.059 | 95.5 | 99.0 | 99.7 |
82 | | Ours-200| **0.230** | **0.044** | **97.1** | **99.4** | **99.8** |
83 | 
84 | photo not available
85 | 
86 | - Error metrics on KITTI dataset:
87 | 
88 | | RGB | rms | rel | delta1 | delta2 | delta3 |
89 | |-----------------------------|:-----:|:-----:|:-----:|:-----:|:-----:|
90 | | [Make3D](http://papers.nips.cc/paper/5539-depth-map-prediction-from-a-single-image-using-a-multi-scale-deep-network.pdf) | 8.734 | 0.280 | 60.1 | 82.0 | 92.6 |
91 | | [Mancini et al](https://arxiv.org/pdf/1607.06349.pdf) (_IROS 2016_) | 7.508 | - | 31.8 | 61.7 | 81.3 |
92 | | [Eigen et al](http://papers.nips.cc/paper/5539-depth-map-prediction-from-a-single-image-using-a-multi-scale-deep-network.pdf) (_NIPS 2014_) | 7.156 | **0.190** | **69.2** | 89.9 | **96.7** |
93 | | Ours-RGB | **6.266** | 0.208 | 59.1 | **90.0** | 96.2 |
94 | 
95 | | RGBd-#samples | rms | rel | delta1 | delta2 | delta3 |
96 | |-----------------------------|:-----:|:-----:|:-----:|:-----:|:-----:|
97 | | [Cadena et al](https://pdfs.semanticscholar.org/18d5/f0747a23706a344f1d15b032ea22795324fa.pdf) (_RSS 2016_)-650 | 7.14 | 0.179 | 70.9 | 88.8 | 95.6 |
98 | | Ours-50 | 4.884 | 0.109 | 87.1 | 95.2 | 97.9 |
99 | | [Liao et al](https://arxiv.org/abs/1611.02174) (_ICRA 2017_)-225 | 4.50 | 0.113 | 87.4 | 96.0 | 98.4 |
100 | | Ours-100 | 4.303 | 0.095 | 90.0 | 96.3 | 98.3 |
101 | | Ours-200 | 3.851 | 0.083 | 91.9 | 97.0 | 98.6 |
102 | | Ours-500| **3.378** | **0.073** | **93.5** | **97.6** | **98.9** |
103 | 
104 | photo not available
105 | 
106 | Note: our networks are trained on the KITTI odometry dataset, using only sparse labels from laser measurements.
107 | 
108 | ## Citation
109 | If you use our code or method in your work, please cite:
110 | 
111 | 	@article{Ma2017SparseToDense,
112 | 		title={Sparse-to-Dense: Depth Prediction from Sparse Depth Samples and a Single Image},
113 | 		author={Ma, Fangchang and Karaman, Sertac},
114 | 		journal={arXiv preprint arXiv:1709.07492},
115 | 		year={2017}
116 | 	}
117 | 
118 | Please direct any questions to [Fangchang Ma](http://www.mit.edu/~fcma) at fcma@mit.edu.
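119 | 
120 | ## Quick Data Check
121 | As a sanity check after downloading (a suggested snippet, not part of the original repo), each `.h5` file can be inspected with `h5py`; the loader in `nyu_dataloader.py` reads the `rgb` and `depth` keys in exactly this way. The file path below is only an example.
122 | ```python
123 | import h5py
124 | import numpy as np
125 | 
126 | # point this at any file under data/nyudepthv2
127 | h5f = h5py.File('data/nyudepthv2/train/basement_0001a/00001.h5', 'r')
128 | rgb = np.transpose(np.array(h5f['rgb']), (1, 2, 0))  # stored C,H,W -> H,W,C
129 | depth = np.array(h5f['depth'])                       # depth in meters
130 | print(rgb.shape, depth.shape)                        # expected: (480, 640, 3) (480, 640)
131 | ```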
--------------------------------------------------------------------------------
/nyu_dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import cv2 # the Python bindings for OpenCV
4 | import numpy as np
5 | import torch.utils.data as data
6 | import h5py # the Python bindings for HDF5
7 | import transforms
8 | from PIL import Image
9 | 
10 | IMG_EXTENSIONS = [
11 |     '.h5',
12 | ]
13 | 
14 | def is_image_file(filename):
15 |     # return any(): returns True as soon as one element of the iterable is truthy
16 |     # see https://blog.csdn.net/heatdeath/article/details/70178511
17 |     # and https://blog.csdn.net/u013630349/article/details/47374333
18 |     # where ...for...in... is a generator expression (an iterable)
19 |     return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
20 | 
21 | def find_classes(dir):
22 |     # ...for...in...if... keeps the elements after `in` that satisfy the `if` condition and collects them into a new list
23 |     # see http://www.jb51.net/article/86987.htm
24 |     classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
25 |     classes.sort() # ascending order
26 |     # the difference between {}, [] and () in Python:
27 |     # {} is a dict, [] is a list, () is a tuple
28 |     # list values can be changed after creation
29 |     # tuple values cannot be changed once set
30 |     # see https://zhidao.baidu.com/question/484920124.html
31 |     class_to_idx = {classes[i]: i for i in range(len(classes))}
32 |     return classes, class_to_idx
33 | 
34 | # this function can also be used to build your own dataset
35 | def make_dataset(dir, class_to_idx):
36 |     images = [] # list holding the dataset entries
37 |     # dir is the root path
38 |     dir = os.path.expanduser(dir)
39 |     for target in sorted(os.listdir(dir)): # sorted(): return the list sorted in ascending order
40 |         # print(target)
41 |         # target is just the folder name without the path; use os.path.join(dirpath, name) to get the full path
42 |         d = os.path.join(dir, target)
43 |         if not os.path.isdir(d):
44 |             continue
45 | 
46 |         for root, _, fnames in sorted(os.walk(d)): # os.walk yields (dirpath, dirnames, filenames) tuples
47 |             for fname in sorted(fnames):
48 |                 if is_image_file(fname):
49 |                     path = os.path.join(root, fname)
50 |                     item = (path, class_to_idx[target])
51 |                     images.append(item) # append() adds an object at the end of the list, like push_back() in C++
52 | 
53 |     return images
54 | 
55 | 
56 | def h5_loader(path):
57 |     # on HDF5 file operations, see https://blog.csdn.net/yudf2010/article/details/50353292
58 |     h5f = h5py.File(path, "r") # "r" opens the file read-only
59 |     rgb = np.array(h5f['rgb']) # array() converts array-like data into an ndarray; matrix() would create a matrix.
60 |     # By convention, array is normally preferred over matrix for representing matrices.
61 |     rgb = np.transpose(rgb, (1, 2, 0)) # on np.transpose() for high-dimensional arrays, see https://www.cnblogs.com/sunshinewang/p/6893503.html
62 |     depth = np.array(h5f['depth'])
63 | 
64 |     return rgb, depth
65 | 
66 | iheight, iwidth = 480, 640 # raw image size
67 | oheight, owidth = 228, 304 # image size after pre-processing
68 | color_jitter = transforms.ColorJitter(0.4, 0.4, 0.4)
69 | 
70 | # data augmentation for training (the method described in the paper)
71 | def train_transform(rgb, depth):
72 |     s = np.random.uniform(1.0, 1.5) # random scaling
73 |     # print("scale factor s={}".format(s))
74 |     depth_np = depth / s
75 |     angle = np.random.uniform(-5.0, 5.0) # random rotation degrees
76 |     do_flip = np.random.uniform(0.0, 1.0) < 0.5 # random horizontal flip
77 | 
78 |     # perform 1st part of data augmentation
79 |     transform = transforms.Compose([
80 |         transforms.Resize(250.0 / iheight), # this is for computational efficiency, since rotation is very slow
81 |         transforms.Rotate(angle),
82 |         transforms.Resize(s),
83 |         transforms.CenterCrop((oheight, owidth)),
84 |         transforms.HorizontalFlip(do_flip)
85 |     ])
86 |     rgb_np = transform(rgb)
87 | 
88 |     # random color jittering
89 |     rgb_np = color_jitter(rgb_np)
90 | 
91 |     rgb_np = np.asfarray(rgb_np, dtype='float') / 255
92 |     depth_np = transform(depth_np)
93 | 
94 |     return rgb_np, depth_np
95 | 
96 | # pre-processing for validation (no random augmentation)
97 | def val_transform(rgb, depth):
98 |     depth_np = depth
99 | 
100 |     # perform 1st part of data augmentation
101 |     transform = transforms.Compose([
102 |         transforms.Resize(240.0 / iheight),
103 |         transforms.CenterCrop((oheight, owidth)),
104 |     ])
105 |     rgb_np = transform(rgb)
106 |     rgb_np = np.asfarray(rgb_np, dtype='float') / 255
107 |     depth_np = transform(depth_np)
108 | 
109 |     return rgb_np, depth_np
110 | 
111 | def rgb2grayscale(rgb):
112 |     return rgb[:,:,0] * 0.2989 + rgb[:,:,1] * 0.587 + rgb[:,:,2] * 0.114
113 | 
114 | # RGB images only
115 | def MatrixTocvMat(data):
116 |     data = data*255
117 |     im = Image.fromarray(data.astype(np.uint8))
118 |     new_im = cv2.cvtColor(np.asarray(im),cv2.COLOR_RGB2BGR)
119 |     return new_im
120 | 
121 | to_tensor = transforms.ToTensor()
122 | 
123 | class NYUDataset(data.Dataset):
124 |     modality_names = ['rgb', 'rgbd', 'd']
125 | 
126 |     def __init__(self, root, type, modality='rgb', num_samples=0, loader=h5_loader):
127 |         classes, class_to_idx = find_classes(root)
128 |         imgs = make_dataset(root, class_to_idx)
129 |         if len(imgs) == 0:
130 |             raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
131 |                 "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
132 | 
133 |         self.root = root
134 |         self.imgs = imgs
135 |         self.classes = classes
136 |         self.class_to_idx = class_to_idx
137 |         if type == 'train':
138 |             self.transform = train_transform
139 |         elif type == 'val':
140 |             self.transform = val_transform
141 |         else:
142 |             raise (RuntimeError("Invalid dataset type: " + type + "\n"
143 |                 "Supported dataset types are: train, val"))
144 |         self.loader = loader
145 | 
146 |         if modality in self.modality_names: # membership test
147 |             self.modality = modality
148 |             if modality in ['rgbd', 'd', 'gd']:
149 |                 if num_samples <= 0:
150 |                     raise (RuntimeError("Invalid number of samples: {}\n".format(num_samples)))
151 |                 self.num_samples = num_samples
152 |             else:
153 |                 self.num_samples = 0
154 |         else:
155 |             raise (RuntimeError("Invalid modality type: " + modality + "\n"
156 |                 "Supported dataset types are: " + ''.join(self.modality_names)))
157 | 
158 |     # create the sparse depth map by uniform random sampling (original version)
159 |     def create_sparse_depth(self, depth, num_samples):
160 |         prob = float(num_samples) / depth.size # sampling probability
161 |         mask_keep = np.random.uniform(0, 1, depth.shape) < prob # random 0-1 mask: keep the entries whose random draw falls below the preset probability
162 |         sparse_depth = np.zeros(depth.shape)
163 |         sparse_depth[mask_keep] = depth[mask_keep] # copy the depth values selected by mask_keep into the corresponding entries of sparse_depth
164 |         return sparse_depth
165 | 
166 |     # create the sparse depth map at ORB feature locations
167 |     def create_sparse_depth_ORB(self, rgb_np, depth, num_samples):
168 |         # num_samples = int(prob * depth.size)
169 |         rgb = MatrixTocvMat(rgb_np)
170 |         gray = cv2.cvtColor(rgb, cv2.COLOR_BGR2GRAY)
171 |         orb = cv2.ORB_create(num_samples, 1.2, 4, 10, 0, 2, cv2.ORB_HARRIS_SCORE, 10) # TODO: detect on a grid so the keypoints spread over the whole image
172 |         kp = orb.detect(gray, None)
173 |         # print("number of ORB KeyPoints:", len(kp))
174 |         # cv2.namedWindow("ORB")
175 |         # rgb_orb = cv2.drawKeypoints(rgb,kp,(255,0,0),-1)
176 |         # cv2.imshow("ORB",rgb_orb)
177 |         # cv2.waitKey(0)
178 | 
179 |         # print(len(kp))     # number of keypoints
180 |         # print(kp[0])       # one index per detected corner
181 |         # tu = kp[0].pt      # .pt is the coordinate tuple (x, y)
182 |         # print(tu[0],tu[1]) # print the x, y coordinates of the first keypoint
183 | 
184 |         sparse_depth = np.zeros(depth.shape)
185 |         for i in range(len(kp)):
186 |             tu = kp[i].pt
187 |             x = int(tu[0])
188 |             y = int(tu[1])
189 |             sparse_depth[y,x] = depth[y,x] # note: (x, y) is (column, row), so array indexing is [y, x]
190 | 
191 |         return sparse_depth
192 | 
193 |     def create_rgbd(self, rgb, depth, num_samples):
194 |         sparse_depth = self.create_sparse_depth_ORB(rgb, depth, num_samples)
195 |         # sparse_depth = self.create_sparse_depth(depth, num_samples)
196 | 
197 |         # rgbd = np.dstack((rgb[:,:,0], rgb[:,:,1], rgb[:,:,2], sparse_depth))
198 |         rgbd = np.append(rgb, np.expand_dims(sparse_depth, axis=2), axis=2) # append() here plays the role of push_back()
199 |         return rgbd
200 | 
201 |     # fetch the raw images
202 |     def __getraw__(self, index):
203 |         """
204 |         Args:
205 |             index (int): Index
206 | 
207 |         Returns:
208 |             tuple: (rgb, depth) the raw data.
209 |         """
210 |         path, target = self.imgs[index]
211 |         rgb, depth = self.loader(path) # this loader is h5_loader
212 |         return rgb, depth
213 | 
214 |     def __get_all_item__(self, index):
215 |         """
216 |         Args:
217 |             index (int): Index
218 | 
219 |         Returns:
220 |             tuple: (input_tensor, depth_tensor, input_np, depth_np)
221 |         """
222 |         rgb, depth = self.__getraw__(index)
223 |         if self.transform is not None:
224 |             rgb_np, depth_np = self.transform(rgb, depth) # the result after the data augmentation step
225 |         else:
226 |             raise(RuntimeError("transform not defined"))
227 | 
228 |         # color normalization
229 |         # rgb_tensor = normalize_rgb(rgb_tensor)
230 |         # rgb_np = normalize_np(rgb_np)
231 | 
232 |         if self.modality == 'rgb':
233 |             input_np = rgb_np
234 |         elif self.modality == 'rgbd':
235 |             input_np = self.create_rgbd(rgb_np, depth_np, self.num_samples)
236 |         elif self.modality == 'd':
237 |             input_np = self.create_sparse_depth_ORB(rgb_np, depth_np, self.num_samples)
238 |             # input_np = self.create_sparse_depth(depth_np, self.num_samples)
239 | 
240 |         input_tensor = to_tensor(input_np)
241 |         while input_tensor.dim() < 3:
242 |             input_tensor = input_tensor.unsqueeze(0)
243 |         depth_tensor = to_tensor(depth_np)
244 |         depth_tensor = depth_tensor.unsqueeze(0)
245 | 
246 |         return input_tensor, depth_tensor, input_np, depth_np
247 | 
248 |     def __getitem__(self, index):
249 |         """
250 |         Args:
251 |             index (int): Index
252 | 
253 |         Returns:
254 |             tuple: (input_tensor, depth_tensor)
255 |         """
256 |         input_tensor, depth_tensor, input_np, depth_np = self.__get_all_item__(index)
257 | 
258 |         return input_tensor, depth_tensor
259 | 
260 |     def __len__(self):
261 |         return len(self.imgs)
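262 | 
263 | # A minimal usage sketch (not part of the original repo); it assumes the
264 | # preprocessed dataset has already been downloaded to data/nyudepthv2.
265 | if __name__ == '__main__':
266 |     dataset = NYUDataset('data/nyudepthv2/val', type='val',
267 |                          modality='rgbd', num_samples=100)
268 |     input_tensor, depth_tensor = dataset[0]
269 |     # 'rgbd' stacks the sparse depth as a 4th channel after the RGB channels
270 |     print(input_tensor.shape, depth_tensor.shape)  # expected: (4, 228, 304) and (1, 228, 304)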
--------------------------------------------------------------------------------
/transforms.py:
--------------------------------------------------------------------------------
1 | # Functions that apply various transformations to images
2 | 
3 | from __future__ import division # import future behaviour; division enables true (floating-point) division, see https://blog.csdn.net/feixingfei/article/details/7081446
4 | import torch
5 | import math
6 | import random
7 | 
8 | from PIL import Image, ImageOps, ImageEnhance
9 | try: # try to do ... and if that fails ...; every compound statement in Python ends its header with a colon
10 |     import accimage
11 | except ImportError:
12 |     accimage = None
13 | 
14 | import numpy as np
15 | import numbers
16 | import types
17 | import collections
18 | import warnings
19 | 
20 | import scipy.ndimage.interpolation as itpl
21 | import scipy.misc as misc
22 | 
23 | 
24 | def _is_numpy_image(img): # image in numpy format
25 |     return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
26 |     # on `return ... and ...` / `return ... or ...`,
27 |     # see https://zhidao.baidu.com/question/1642792031752755300.html
28 |     # and http://www.mamicode.com/info-detail-1765166.html
29 | 
30 | def _is_pil_image(img): # image in PIL format
31 |     if accimage is not None:
32 |         return isinstance(img, (Image.Image, accimage.Image))
33 |     else:
34 |         return isinstance(img, Image.Image)
35 | 
36 | def _is_tensor_image(img): # image in tensor format
37 |     return torch.is_tensor(img) and img.ndimension() == 3
38 | 
39 | def adjust_brightness(img, brightness_factor):
40 |     """Adjust brightness of an Image.
41 | 
42 |     Args:
43 |         img (PIL Image): PIL Image to be adjusted.
44 |         brightness_factor (float): How much to adjust the brightness. Can be
45 |             any non negative number. 0 gives a black image, 1 gives the
46 |             original image while 2 increases the brightness by a factor of 2.
47 | 
48 |     Returns:
49 |         PIL Image: Brightness adjusted image.
50 |     """
51 |     if not _is_pil_image(img):
52 |         # Python raises exceptions automatically on errors, and raise can also be used to signal one explicitly.
53 |         # Once a raise statement executes, the statements after it do not run.
54 |         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
55 | 
56 |     enhancer = ImageEnhance.Brightness(img)
57 |     img = enhancer.enhance(brightness_factor)
58 |     return img
59 | 
60 | 
61 | def adjust_contrast(img, contrast_factor):
62 |     """Adjust contrast of an Image.
63 | 
64 |     Args:
65 |         img (PIL Image): PIL Image to be adjusted.
66 |         contrast_factor (float): How much to adjust the contrast. Can be any
67 |             non negative number. 0 gives a solid gray image, 1 gives the
68 |             original image while 2 increases the contrast by a factor of 2.
69 | 
70 |     Returns:
71 |         PIL Image: Contrast adjusted image.
72 |     """
73 |     if not _is_pil_image(img):
74 |         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
75 | 
76 |     enhancer = ImageEnhance.Contrast(img)
77 |     img = enhancer.enhance(contrast_factor)
78 |     return img
79 | 
80 | 
81 | def adjust_saturation(img, saturation_factor):
82 |     """Adjust color saturation of an image.
83 | 
84 |     Args:
85 |         img (PIL Image): PIL Image to be adjusted.
86 |         saturation_factor (float): How much to adjust the saturation. 0 will
87 |             give a black and white image, 1 will give the original image while
88 |             2 will enhance the saturation by a factor of 2.
89 | 
90 |     Returns:
91 |         PIL Image: Saturation adjusted image.
92 |     """
93 |     if not _is_pil_image(img):
94 |         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
95 | 
96 |     enhancer = ImageEnhance.Color(img)
97 |     img = enhancer.enhance(saturation_factor)
98 |     return img
99 | 
100 | 
101 | def adjust_hue(img, hue_factor):
102 |     """Adjust hue of an image.
103 | 
104 |     The image hue is adjusted by converting the image to HSV and
105 |     cyclically shifting the intensities in the hue channel (H).
106 |     The image is then converted back to original image mode.
107 | 
108 |     `hue_factor` is the amount of shift in H channel and must be in the
109 |     interval `[-0.5, 0.5]`.
110 | 
111 |     See https://en.wikipedia.org/wiki/Hue for more details on Hue.
112 | 
113 |     Args:
114 |         img (PIL Image): PIL Image to be adjusted.
115 |         hue_factor (float): How much to shift the hue channel. Should be in
116 |             [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
117 |             HSV space in positive and negative direction respectively.
118 |             0 means no shift. Therefore, both -0.5 and 0.5 will give an image
119 |             with complementary colors while 0 gives the original image.
120 | 
121 |     Returns:
122 |         PIL Image: Hue adjusted image.
123 |     """
124 |     if not(-0.5 <= hue_factor <= 0.5):
125 |         raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
126 | 
127 |     if not _is_pil_image(img):
128 |         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
129 | 
130 |     input_mode = img.mode
131 |     if input_mode in {'L', '1', 'I', 'F'}:
132 |         return img
133 | 
134 |     h, s, v = img.convert('HSV').split()
135 | 
136 |     np_h = np.array(h, dtype=np.uint8)
137 |     # uint8 addition takes care of rotation across boundaries
138 |     with np.errstate(over='ignore'):
139 |         np_h += np.uint8(hue_factor * 255)
140 |     h = Image.fromarray(np_h, 'L')
141 | 
142 |     img = Image.merge('HSV', (h, s, v)).convert(input_mode)
143 |     return img
144 | 
145 | 
146 | def adjust_gamma(img, gamma, gain=1):
147 |     """Perform gamma correction on an image.
148 | 
149 |     Also known as Power Law Transform. Intensities in RGB mode are adjusted
150 |     based on the following equation:
151 | 
152 |         I_out = 255 * gain * ((I_in / 255) ** gamma)
153 | 
154 |     See https://en.wikipedia.org/wiki/Gamma_correction for more details.
155 | 
156 |     Args:
157 |         img (PIL Image): PIL Image to be adjusted.
158 |         gamma (float): Non negative real number. gamma larger than 1 makes the
159 |             shadows darker, while gamma smaller than 1 makes dark regions
160 |             lighter.
161 |         gain (float): The constant multiplier.
162 |     """
163 |     if not _is_pil_image(img):
164 |         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
165 | 
166 |     if gamma < 0:
167 |         raise ValueError('Gamma should be a non-negative real number')
168 | 
169 |     input_mode = img.mode
170 |     img = img.convert('RGB')
171 | 
172 |     np_img = np.array(img, dtype=np.float32)
173 |     np_img = 255 * gain * ((np_img / 255) ** gamma)
174 |     np_img = np.uint8(np.clip(np_img, 0, 255))
175 | 
176 |     img = Image.fromarray(np_img, 'RGB').convert(input_mode)
177 |     return img
178 | 
179 | 
180 | class Compose(object):
181 |     """Composes several transforms together.
182 | 
183 |     Args:
184 |         transforms (list of ``Transform`` objects): list of transforms to compose.
185 | 
186 |     Example:
187 |         >>> transforms.Compose([
188 |         >>>     transforms.CenterCrop(10),
189 |         >>>     transforms.ToTensor(),
190 |         >>> ])
191 |     """
192 | 
193 |     def __init__(self, transforms):
194 |         self.transforms = transforms
195 | 
196 |     # __call__ lets an instance of the class be used like a function, see https://blog.csdn.net/yaokai_assultmaster/article/details/70256621
197 |     def __call__(self, img):
198 |         for t in self.transforms:
199 |             img = t(img)
200 |         return img
201 | 
202 | 
203 | class ToTensor(object):
204 |     """Convert a ``numpy.ndarray`` to tensor.
205 | 
206 |     Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W).
207 |     """
208 | 
209 |     def __call__(self, img):
210 |         """Convert a ``numpy.ndarray`` to tensor.
211 | 
212 |         Args:
213 |             img (numpy.ndarray): Image to be converted to tensor.
214 | 
215 |         Returns:
216 |             Tensor: Converted image.
217 |         """
218 |         if not(_is_numpy_image(img)):
219 |             raise TypeError('img should be ndarray. Got {}'.format(type(img)))
220 | 
221 |         if isinstance(img, np.ndarray):
222 |             # handle numpy array
223 |             if img.ndim == 3:
224 |                 img = torch.from_numpy(img.transpose((2, 0, 1)).copy())
225 |             elif img.ndim == 2:
226 |                 img = torch.from_numpy(img.copy())
227 |             else:
228 |                 raise RuntimeError('img should be ndarray with 2 or 3 dimensions. Got {}'.format(img.ndim))
229 | 
230 |             # backward compatibility
231 |             # return img.float().div(255)
232 |             return img.float()
233 | 
234 | 
235 | class NormalizeNumpyArray(object):
236 |     """Normalize a ``numpy.ndarray`` with mean and standard deviation.
237 |     Given mean: ``(M1,...,Mn)`` and std: ``(M1,..,Mn)`` for ``n`` channels, this transform
238 |     will normalize each channel of the input ``numpy.ndarray`` i.e.
239 |     ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
240 | 
241 |     Args:
242 |         mean (sequence): Sequence of means for each channel.
243 |         std (sequence): Sequence of standard deviations for each channel.
244 |     """
245 | 
246 |     def __init__(self, mean, std):
247 |         self.mean = mean
248 |         self.std = std
249 | 
250 |     def __call__(self, img):
251 |         """
252 |         Args:
253 |             img (numpy.ndarray): Image of size (H, W, C) to be normalized.
254 | 
255 |         Returns:
256 |             numpy.ndarray: Normalized image.
257 |         """
258 |         if not(_is_numpy_image(img)):
259 |             raise TypeError('img should be ndarray. Got {}'.format(type(img)))
260 |         # TODO: make efficient
261 |         # in Python, TODO is a mnemonic comment describing work that remains to be done
262 |         print(img.shape)
263 |         for i in range(3): # apply the same operation to each channel of the RGB image
264 |             img[:,:,i] = (img[:,:,i] - self.mean[i]) / self.std[i]
265 |         return img
266 | 
267 | class NormalizeTensor(object):
268 |     """Normalize an tensor image with mean and standard deviation.
269 |     Given mean: ``(M1,...,Mn)`` and std: ``(M1,..,Mn)`` for ``n`` channels, this transform
270 |     will normalize each channel of the input ``torch.*Tensor`` i.e.
271 |     ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
272 | 
273 |     Args:
274 |         mean (sequence): Sequence of means for each channel.
275 |         std (sequence): Sequence of standard deviations for each channel.
276 |     """
277 | 
278 |     def __init__(self, mean, std):
279 |         self.mean = mean
280 |         self.std = std
281 | 
282 |     def __call__(self, tensor):
283 |         """
284 |         Args:
285 |             tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
286 | 
287 |         Returns:
288 |             Tensor: Normalized Tensor image.
289 |         """
290 |         if not _is_tensor_image(tensor):
291 |             raise TypeError('tensor is not a torch image.')
292 |         # TODO: make efficient
293 |         # zip() takes iterables as arguments, packs their corresponding elements into tuples, and returns the list of those tuples
294 |         # see http://www.runoob.com/python/python-func-zip.html
295 |         for t, m, s in zip(tensor, self.mean, self.std):
296 |             t.sub_(m).div_(s)
297 |         return tensor
298 | 
299 | class Rotate(object):
300 |     """Rotates the given ``numpy.ndarray``.
301 | 
302 |     Args:
303 |         angle (float): The rotation angle in degrees.
304 |     """
305 | 
306 |     def __init__(self, angle):
307 |         self.angle = angle
308 | 
309 |     def __call__(self, img):
310 |         """
311 |         Args:
312 |             img (numpy.ndarray (C x H x W)): Image to be rotated.
313 | 
314 |         Returns:
315 |             img (numpy.ndarray (C x H x W)): Rotated image.
316 |         """
317 | 
318 |         return itpl.rotate(img, self.angle, reshape=False, prefilter=False)
319 | 
320 | class Resize(object):
321 |     """Resize the given ``numpy.ndarray`` to the given size.
322 |     Args:
323 |         size (sequence or int): Desired output size. If size is a sequence like
324 |             (h, w), output size will be matched to this. If size is an int,
325 |             smaller edge of the image will be matched to this number.
326 |             i.e, if height > width, then image will be rescaled to
327 |             (size * height / width, size)
328 |         interpolation (str, optional): Desired interpolation. Default is
329 |             ``'nearest'``.
330 |     """
331 | 
332 |     def __init__(self, size, interpolation='nearest'):
333 |         assert isinstance(size, int) or isinstance(size, float) or \
334 |             (isinstance(size, collections.Iterable) and len(size) == 2)
335 |         self.size = size
336 |         self.interpolation = interpolation
337 | 
338 |     def __call__(self, img):
339 |         """
340 |         Args:
341 |             img (numpy.ndarray): Image to be scaled.
342 |         Returns:
343 |             numpy.ndarray: Rescaled image.
344 |         """
345 |         if img.ndim == 3:
346 |             return misc.imresize(img, self.size, self.interpolation)
347 |         elif img.ndim == 2:
348 |             return misc.imresize(img, self.size, self.interpolation, 'F')
349 |         else:
350 |             raise RuntimeError('img should be ndarray with 2 or 3 dimensions. Got {}'.format(img.ndim))
351 | 
352 | 
353 | class CenterCrop(object):
354 |     """Crops the given ``numpy.ndarray`` at the center.
355 | 
356 |     Args:
357 |         size (sequence or int): Desired output size of the crop. If size is an
358 |             int instead of sequence like (h, w), a square crop (size, size) is
359 |             made.
360 |     """
361 | 
362 |     def __init__(self, size):
363 |         if isinstance(size, numbers.Number):
364 |             self.size = (int(size), int(size))
365 |         else:
366 |             self.size = size
367 | 
368 |     # on @staticmethod, see https://www.zhihu.com/question/20021164
369 |     @staticmethod
370 |     def get_params(img, output_size):
371 |         """Get parameters for ``crop`` for center crop.
372 | 
373 |         Args:
374 |             img (numpy.ndarray (C x H x W)): Image to be cropped.
375 |             output_size (tuple): Expected output size of the crop.
376 | 
377 |         Returns:
378 |             tuple: params (i, j, h, w) to be passed to ``crop`` for center crop.
379 |         """
380 |         h = img.shape[0]
381 |         w = img.shape[1]
382 |         th, tw = output_size
383 |         i = int(round((h - th) / 2.))
384 |         j = int(round((w - tw) / 2.))
385 | 
386 |         # # randomized cropping
387 |         # i = np.random.randint(i-3, i+4)
388 |         # j = np.random.randint(j-3, j+4)
389 | 
390 |         return i, j, th, tw
391 | 
392 |     def __call__(self, img):
393 |         """
394 |         Args:
395 |             img (numpy.ndarray (C x H x W)): Image to be cropped.
396 | 
397 |         Returns:
398 |             img (numpy.ndarray (C x H x W)): Cropped image.
399 |         """
400 |         i, j, h, w = self.get_params(img, self.size)
401 | 
402 |         """
403 |         i: Upper pixel coordinate.
404 |         j: Left pixel coordinate.
405 |         h: Height of the cropped image.
406 |         w: Width of the cropped image.
407 |         """
408 |         if not(_is_numpy_image(img)):
409 |             raise TypeError('img should be ndarray. Got {}'.format(type(img)))
410 |         if img.ndim == 3:
411 |             return img[i:i+h, j:j+w, :]
412 |         elif img.ndim == 2:
413 |             return img[i:i + h, j:j + w]
414 |         else:
415 |             raise RuntimeError('img should be ndarray with 2 or 3 dimensions. Got {}'.format(img.ndim))
416 | 
417 | 
418 | class Lambda(object):
419 |     """Apply a user-defined lambda as a transform.
420 | 
421 |     Args:
422 |         lambd (function): Lambda/function to be used for transform.
423 |     """
424 | 
425 |     def __init__(self, lambd):
426 |         assert isinstance(lambd, types.LambdaType)
427 |         self.lambd = lambd
428 | 
429 |     def __call__(self, img):
430 |         return self.lambd(img)
431 | 
432 | 
433 | class HorizontalFlip(object):
434 |     """Horizontally flip the given ``numpy.ndarray``.
435 | 
436 |     Args:
437 |         do_flip (boolean): whether or not do horizontal flip.
438 | 
439 |     """
440 | 
441 |     def __init__(self, do_flip):
442 |         self.do_flip = do_flip
443 | 
444 |     def __call__(self, img):
445 |         """
446 |         Args:
447 |             img (numpy.ndarray (C x H x W)): Image to be flipped.
448 | 
449 |         Returns:
450 |             img (numpy.ndarray (C x H x W)): flipped image.
451 |         """
452 |         if not(_is_numpy_image(img)):
453 |             raise TypeError('img should be ndarray. Got {}'.format(type(img)))
454 | 
455 |         if self.do_flip:
456 |             return np.fliplr(img)
457 |         else:
458 |             return img
459 | 
460 | 
461 | class ColorJitter(object): # jitter: a small random perturbation
462 |     """Randomly change the brightness, contrast and saturation of an image.
463 | 
464 |     Args:
465 |         brightness (float): How much to jitter brightness. brightness_factor
466 |             is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
467 |         contrast (float): How much to jitter contrast. contrast_factor
468 |             is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
469 |         saturation (float): How much to jitter saturation. saturation_factor
470 |             is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
471 |         hue(float): How much to jitter hue. hue_factor is chosen uniformly from
472 |             [-hue, hue]. Should be >=0 and <= 0.5.
473 |     """
474 |     def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
475 |         self.brightness = brightness
476 |         self.contrast = contrast
477 |         self.saturation = saturation
478 |         self.hue = hue
479 | 
480 |     @staticmethod
481 |     def get_params(brightness, contrast, saturation, hue):
482 |         """Get a randomized transform to be applied on image.
483 | 
484 |         Arguments are same as that of __init__.
485 | 
486 |         Returns:
487 |             Transform which randomly adjusts brightness, contrast and
488 |             saturation in a random order.
489 |         """
490 |         transforms = []
491 |         if brightness > 0:
492 |             brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness)
493 |             transforms.append(Lambda(lambda img: adjust_brightness(img, brightness_factor)))
494 | 
495 |         if contrast > 0:
496 |             contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast)
497 |             transforms.append(Lambda(lambda img: adjust_contrast(img, contrast_factor)))
498 | 
499 |         if saturation > 0:
500 |             saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation)
501 |             transforms.append(Lambda(lambda img: adjust_saturation(img, saturation_factor)))
502 | 
503 |         if hue > 0:
504 |             hue_factor = np.random.uniform(-hue, hue)
505 |             transforms.append(Lambda(lambda img: adjust_hue(img, hue_factor)))
506 | 
507 |         np.random.shuffle(transforms)
508 |         transform = Compose(transforms)
509 | 
510 |         return transform
511 | 
512 |     def __call__(self, img):
513 |         """
514 |         Args:
515 |             img (numpy.ndarray (C x H x W)): Input image.
516 | 
517 |         Returns:
518 |             img (numpy.ndarray (C x H x W)): Color jittered image.
519 |         """
520 |         if not(_is_numpy_image(img)):
521 |             raise TypeError('img should be ndarray. Got {}'.format(type(img)))
522 | 
523 |         pil = Image.fromarray(img)
524 |         transform = self.get_params(self.brightness, self.contrast,
525 |                                     self.saturation, self.hue)
526 |         return np.array(transform(pil))
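527 | 
528 | # A minimal usage sketch (not part of the original repo): ColorJitter expects a
529 | # uint8 H x W x 3 array, converts it to PIL internally, and applies the jitter
530 | # transforms in random order. nyu_dataloader.py uses ColorJitter(0.4, 0.4, 0.4).
531 | if __name__ == '__main__':
532 |     img = (np.random.rand(32, 32, 3) * 255).astype(np.uint8)
533 |     jitter = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4)
534 |     out = jitter(img)
535 |     print(out.shape, out.dtype)  # expected: (32, 32, 3) uint8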
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # general-purpose modules
2 | import argparse # command-line argument parser, see https://blog.csdn.net/ali197294332/article/details/51180628
3 | import os # basic operating-system services
4 | import shutil # high-level file operations, see https://www.cnblogs.com/MnCu8261/p/5494807.html
5 | import time # time-related functions, see http://www.jb51.net/article/87721.htm
6 | import sys # system-specific information
7 | import csv # reading and writing csv files, see https://www.cnblogs.com/yanglang/p/7126660.html
8 | 
9 | # pytorch modules
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.parallel # module-level parallelism: the forward pass is split across GPUs and the gradients are merged back into the original module during backward
13 | import torch.backends.cudnn as cudnn # cudnn: CUDA Deep Neural Network library; compared with plain CUDA it optimizes common operations such as convolution, pooling, normalization and activations
14 | import torch.optim
15 | import torch.utils.data
16 | 
17 | # project modules
18 | from nyu_dataloader import NYUDataset
19 | from models import Decoder, ResNet
20 | from metrics import AverageMeter, Result
21 | import criteria
22 | import utils
23 | 
24 | model_names = ['resnet18', 'resnet50']
25 | loss_names = ['l1', 'l2']
26 | data_names = ['NYUDataset']
27 | decoder_names = Decoder.names
28 | modality_names = NYUDataset.modality_names
29 | 
30 | cudnn.benchmark = True
31 | 
32 | # create the argument parser
33 | parser = argparse.ArgumentParser(description='Sparse-to-Dense Training')
34 | # parser.add_argument('--data', metavar='DIR', help='path to dataset',
35 | #                     default="data/NYUDataset")
36 | 
37 | # register the command-line options
38 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
39 |                     choices=model_names,
40 |                     help='model architecture: ' +
41 |                         ' | '.join(model_names) +
42 |                         ' (default: resnet18)')
43 | # metavar: placeholder string used in the help output in place of the option's value
44 | # join: concatenates the elements of a string/tuple/list with the given separator into a new string, see https://blog.csdn.net/zmdzbzbhss123/article/details/52279008
45 | parser.add_argument('--data', metavar='DATA', default='nyudepthv2',
46 |                     choices=data_names,
47 |                     help='dataset: ' +
48 |                         ' | '.join(data_names) +
49 |                         ' (default: nyudepthv2)')
50 | parser.add_argument('--modality', '-m', metavar='MODALITY', default='rgb',
51 |                     choices=modality_names,
52 |                     help='modality: ' +
53 |                         ' | '.join(modality_names) +
54 |                         ' (default: rgb)')
55 | parser.add_argument('-s', '--num-samples', default=0, type=int, metavar='N',
56 |                     help='number of sparse depth samples (default: 0)')
57 | parser.add_argument('--decoder', '-d', metavar='DECODER', default='deconv2',
58 |                     choices=decoder_names,
59 |                     help='decoder: ' +
60 |                         ' | '.join(decoder_names) +
61 |                         ' (default: deconv2)')
62 | parser.add_argument('-j', '--workers', default=10, type=int, metavar='N',
63 |                     help='number of data loading workers (default: 10)')
64 | parser.add_argument('--epochs', default=30, type=int, metavar='N',
65 |                     help='number of total epochs to run (default: 30)')
66 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
67 |                     help='manual epoch number (useful on restarts)')
68 | parser.add_argument('-c', '--criterion', metavar='LOSS', default='l1',
69 |                     choices=loss_names,
70 |                     help='loss function: ' +
71 |                         ' | '.join(loss_names) +
72 |                         ' (default: l1)')
73 | parser.add_argument('-b', '--batch-size', default=8, type=int,
74 |                     help='mini-batch size (default: 8)')
75 | parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
76 |                     metavar='LR', help='initial learning rate (default 0.01)')
77 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
78 |                     help='momentum')
79 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
80 |                     metavar='W', help='weight decay (default: 1e-4)')
81 | parser.add_argument('--print-freq', '-p', default=10, type=int,
82 |                     metavar='N', help='print frequency (default: 10)')
83 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
84 |                     help='path to latest checkpoint (default: none)')
85 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
86 |                     help='evaluate model on validation set')
87 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
88 |                     default=True, help='use ImageNet pre-trained weights (default: True)')
89 | 
90 | fieldnames = ['mse', 'rmse', 'absrel', 'lg10', 'mae',
91 |               'delta1', 'delta2', 'delta3',
92 |               'data_time', 'gpu_time']
93 | 
94 | # Result() is the class defined in metrics.py
95 | best_result = Result()
96 | best_result.set_to_worst()
97 | 
98 | # defining a function
99 | # basic form:
100 | # def function_name(parameters):
101 | #     expressions
102 | # Python uses indentation to delimit scope
103 | def main():
104 |     global args, best_result, output_directory, train_csv, test_csv # global variables
105 |     args = parser.parse_args() # parse the command-line arguments
106 |     args.data = os.path.join('data', args.data)
107 |     # os.path.join(): joins multiple path components and returns the result
108 |     # syntax: os.path.join(path1[,path2[,......]])
109 |     # note: any component before the first absolute path is discarded
110 |     # note the colon after the if condition
111 |     # args.modality holds the value of the modality option defined above
112 |     if args.modality == 'rgb' and args.num_samples != 0:
113 |         print("number of samples is forced to be 0 when input modality is rgb")
114 |         args.num_samples = 0
115 |         # for RGB-only sparse-to-dense, the number of sparse depth samples is set to 0 when generating training data
116 | 
117 |     # create results folder, if not already exists
118 |     output_directory = os.path.join('results',
119 |         'NYUDataset.modality={}.nsample={}.arch={}.decoder={}.criterion={}.lr={}.bs={}'.
120 |         format(args.modality, args.num_samples, args.arch, args.decoder, args.criterion, args.lr, args.batch_size)) # naming scheme of the output folder
121 | 
122 |     # if the path does not exist yet
123 |     if not os.path.exists(output_directory):
124 |         os.makedirs(output_directory)
125 | 
126 |     train_csv = os.path.join(output_directory, 'train.csv')
127 |     test_csv = os.path.join(output_directory, 'test.csv')
128 |     best_txt = os.path.join(output_directory, 'best.txt')
129 | 
130 |     # define loss function (criterion) and optimizer
131 |     if args.criterion == 'l2':
132 |         criterion = criteria.MaskedMSELoss().cuda() # when calling code from another .py file, a function is called by its name, while a class is instantiated as shown here
133 |         out_channels = 1
134 |     # elif: else if
135 |     elif args.criterion == 'l1':
136 |         criterion = criteria.MaskedL1Loss().cuda()
137 |         out_channels = 1
138 | 
139 |     # Data loading code
140 |     print("=> creating data loaders ...")
141 |     traindir = os.path.join(args.data, 'train')
142 |     valdir = os.path.join(args.data, 'val')
143 | 
144 |     train_dataset = NYUDataset(traindir, type='train',
145 |         modality=args.modality, num_samples=args.num_samples)
146 |     # DataLoader is the class that loads the data
147 |     train_loader = torch.utils.data.DataLoader(
148 |         train_dataset, batch_size=args.batch_size, shuffle=True,
149 |         num_workers=args.workers, pin_memory=True, sampler=None)
150 | 
151 |     # set batch size to be 1 for validation
152 |     val_dataset = NYUDataset(valdir, type='val',
153 |         modality=args.modality, num_samples=args.num_samples)
154 |     val_loader = torch.utils.data.DataLoader(val_dataset,
155 |         batch_size=1, shuffle=False, num_workers=args.workers, pin_memory=True)
156 | 
157 |     print("=> data loaders created.")
158 | 
159 |     # evaluation mode
160 |     if args.evaluate:
161 |         best_model_filename = os.path.join(output_directory, 'model_best.pth.tar')
162 |         if os.path.isfile(best_model_filename):
163 |             print("=> loading best model '{}'".format(best_model_filename))
164 |             checkpoint = torch.load(best_model_filename)
165 |             args.start_epoch = checkpoint['epoch']
166 |             best_result = checkpoint['best_result']
167 |             model = checkpoint['model']
168 |             print("=> loaded best model (epoch {})".format(checkpoint['epoch']))
169 |         else: # note that else also needs a colon
170 |             print("=> no best model found at '{}'".format(best_model_filename))
171 |             return # without a loaded model there is nothing to evaluate
172 |         validate(val_loader, model, checkpoint['epoch'], write_to_file=False)
173 |         return
174 | 
175 |     # optionally resume from a checkpoint
176 |     elif args.resume:
177 |         if os.path.isfile(args.resume):
178 |             print("=> loading checkpoint '{}'".format(args.resume))
179 |             checkpoint = torch.load(args.resume)
180 |             args.start_epoch = checkpoint['epoch']+1
181 |             best_result = checkpoint['best_result']
182 |             model = checkpoint['model']
183 |             optimizer = checkpoint['optimizer']
184 |             print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
185 |         else:
186 |             print("=> no checkpoint found at '{}'".format(args.resume))
187 | 
188 |     # create new model
189 |     else:
190 |         # define model
191 |         print("=> creating Model ({}-{}) ...".format(args.arch, args.decoder))
192 |         in_channels = len(args.modality) # len() returns the length (number of items) of an object
193 |         if args.arch == 'resnet50':
194 |             model = ResNet(layers=50, decoder=args.decoder, in_channels=in_channels,
195 |                 out_channels=out_channels, pretrained=args.pretrained)
196 |         elif args.arch == 'resnet18':
197 |             model = ResNet(layers=18, decoder=args.decoder, in_channels=in_channels,
198 |                 out_channels=out_channels, pretrained=args.pretrained)
199 |         print("=> model created.")
200 | 
201 |         optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
202 | 
203 |     # create new csv files with only header
204 |     # on the `with open() as xxx:` idiom, see https://www.cnblogs.com/ymjyqsx/p/6554817.html
205 |     with open(train_csv, 'w') as csvfile:
206 |         writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
207 |         writer.writeheader()
208 |     with open(test_csv, 'w') as csvfile:
209 |         writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
210 |         writer.writeheader()
211 | 
212 |     # model = torch.nn.DataParallel(model).cuda()
213 |     model = model.cuda()
214 |     print(model)
215 |     print("=> model transferred to GPU.")
216 | 
217 |     # for loops also end their header with a colon
218 |     # in general, use while when the number of iterations is unknown, and for when it is known
219 |     for epoch in range(args.start_epoch, args.epochs):
220 |         adjust_learning_rate(optimizer, epoch)
221 | 
222 |         # train for one epoch
223 |         train(train_loader, model, criterion, optimizer, epoch)
224 | 
225 |         # evaluate on validation set
226 |         result, img_merge = validate(val_loader, model, epoch)
227 |         # a Python return statement can return multiple values
228 | 
229 |         # remember best rmse and save checkpoint
230 |         is_best = result.rmse < best_result.rmse
231 |         if is_best:
232 |             best_result = result
233 |             with open(best_txt, 'w') as txtfile:
234 |                 # formatted string output
235 |                 # in :3f, 3 is the field width and f means float; shorter values are right-aligned and padded with spaces,
236 |                 # while longer values are printed at their actual width
237 |                 # in :.3f, .3 sets the number of digits after the decimal point
238 |                 txtfile.write("epoch={}\nmse={:.3f}\nrmse={:.3f}\nabsrel={:.3f}\nlg10={:.3f}\nmae={:.3f}\ndelta1={:.3f}\nt_gpu={:.4f}\n".
239 |                     format(epoch, result.mse, result.rmse, result.absrel, result.lg10, result.mae, result.delta1, result.gpu_time))
240 |             # None is a special value in Python representing an empty object. None is not 0: 0 is meaningful, None is a special empty value.
241 |             # None can be assigned to any variable, and any variable can be assigned to a None-valued object.
242 |             # None evaluates to False in conditions.
243 |             # NULL is the null character, which is not the same as None.
244 |             if img_merge is not None:
245 |                 img_filename = output_directory + '/comparison_best.png'
246 |                 utils.save_image(img_merge, img_filename)
247 | 
248 |         # In Python everything is an object, and all operations act on objects. An object has two kinds of features:
249 |         # attributes, which describe it, and methods, which are its behaviours.
250 |         # So: object = attributes + methods (a method is really just a callable attribute, distinct from data attributes)
251 | 
252 |         save_checkpoint({
253 |             'epoch': epoch,
254 |             'arch': args.arch,
255 |             'model': model,
256 |             'best_result': best_result,
257 |             'optimizer' : optimizer,
258 |         }, is_best, epoch)
259 | 
260 | 
261 | def train(train_loader, model, criterion, optimizer, epoch):
262 |     average_meter = AverageMeter()
263 | 
264 |     # switch to train mode
265 |     model.train()
266 | 
267 |     end = time.time() # start timing
268 |     # enumerate() pairs each element of an iterable (list, tuple, string, ...) with its index; commonly used in for loops
269 |     for i, (input, target) in enumerate(train_loader):
270 | 
271 |         input, target = input.cuda(), target.cuda()
272 |         # torch.autograd provides classes and functions for automatic differentiation of arbitrary scalar-valued functions. It wraps all tensors in Variable objects.
273 |         # A Variable is a thin wrapper around a Tensor that also holds the gradient and a reference to the function that created it.
274 |         # That reference allows the whole chain of operations that created the data to be traced back; any network that needs backprop computes through Variables.
275 |         # All computation in pytorch is Tensor-based; a Variable is just a wrapper, and computing with a Variable really computes with the Tensor inside.
276 |         # see https://blog.csdn.net/KGzhang/article/details/77483383
277 |         input_var = torch.autograd.Variable(input)
278 |         target_var = torch.autograd.Variable(target)
279 |         torch.cuda.synchronize() # wait for all kernels in all streams on the current device to finish
280 |         data_time = time.time() - end # measure elapsed time
281 | 
282 |         # compute depth_pred
283 |         end = time.time()
284 |         depth_pred = model(input_var)
285 |         loss = criterion(depth_pred, target_var)
286 |         # the optimizer package updates the parameters during training
287 |         optimizer.zero_grad() # zero the gradient buffers; this is mandatory
288 |         # during backprop pytorch stores the gradient of a Variable in the Variable object itself, accessible via Variable.grad;
289 |         # when a Variable is created its grad attribute is initialized to 0.0
290 |         loss.backward() # compute gradient and do SGD step
291 |         optimizer.step() # update the parameters
292 |         torch.cuda.synchronize()
293 |         gpu_time = time.time() - end
294 | 
295 |         # measure accuracy and record loss
296 |         result = Result()
297 |         output1 = torch.index_select(depth_pred.data, 1, torch.cuda.LongTensor([0]))
298 |         result.evaluate(output1, target)
299 |         average_meter.update(result, gpu_time, data_time, input.size(0))
300 |         end = time.time()
301 | 
302 |         if (i + 1) % args.print_freq == 0:
303 |             print('=> output: {}'.format(output_directory))
304 |             print('Train Epoch: {0} [{1}/{2}]\t'
305 |                   't_Data={data_time:.3f}({average.data_time:.3f}) '
306 |                   't_GPU={gpu_time:.3f}({average.gpu_time:.3f}) '
307 |                   'RMSE={result.rmse:.2f}({average.rmse:.2f}) '
308 |                   'MAE={result.mae:.2f}({average.mae:.2f}) '
309 |                   'Delta1={result.delta1:.3f}({average.delta1:.3f}) '
310 |                   'REL={result.absrel:.3f}({average.absrel:.3f}) '
311 |                   'Lg10={result.lg10:.3f}({average.lg10:.3f}) '.format(
312 |                   epoch, i+1, len(train_loader), data_time=data_time,
313 |                   gpu_time=gpu_time, result=result, average=average_meter.average()))
314 | 
315 |     avg = average_meter.average()
316 |     with open(train_csv, 'a') as csvfile:
317 |         writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
318 |         writer.writerow({'mse': avg.mse, 'rmse': avg.rmse, 'absrel': avg.absrel, 'lg10': avg.lg10,
319 |                          'mae': avg.mae, 'delta1': avg.delta1, 'delta2': avg.delta2, 'delta3': avg.delta3,
320 |                          'gpu_time': avg.gpu_time, 'data_time': avg.data_time})
321 | 
322 | 
323 | def validate(val_loader, model, epoch, write_to_file=True):
324 |     average_meter = AverageMeter()
325 | 
326 |     # switch to evaluate mode
327 |     model.eval()
328 | 
329 |     end = time.time()
330 |     for i, (input, target) in enumerate(val_loader):
331 |         input, target = input.cuda(), target.cuda() # target is the ground-truth depth map, as the code below shows
332 |         input_var = torch.autograd.Variable(input)
333 |         target_var = torch.autograd.Variable(target)
334 |         torch.cuda.synchronize()
335 |         data_time = time.time() - end
336 | 
337 |         # compute output
338 |         end = time.time()
339 |         depth_pred = model(input_var)
340 |         torch.cuda.synchronize()
341 |         gpu_time = time.time() - end
342 | 
343 |         # measure accuracy and record loss
344 |         result = Result()
345 |         output1 = torch.index_select(depth_pred.data, 1, torch.cuda.LongTensor([0]))
346 |         result.evaluate(output1, target)
347 |         average_meter.update(result, gpu_time, data_time, input.size(0))
348 |         end = time.time()
349 | 
350 |         # save 8 images for visualization
351 |         skip = 50
352 |         if args.modality == 'd':
353 |             img_merge = None
354 |         else:
355 |             if args.modality == 'rgb':
356 |                 rgb = input
357 |             elif args.modality == 'rgbd':
358 |                 rgb = input[:,:3,:,:]
359 | 
360 |             if i == 0:
361 |                 img_merge = utils.merge_into_row(rgb, target, depth_pred)
362 |             # take one image out of every `skip` (= 50) batches for the visualization
363 |             elif (i < 8*skip) and (i % skip == 0): # Python's "and" is the equivalent of && in C++
364 |                 row = utils.merge_into_row(rgb, target, depth_pred)
365 |                 img_merge = utils.add_row(img_merge, row) # append one more row
366 |             elif i == 8*skip: # keep only 8 rows; once 8 have been collected, write the image out
367 |                 filename = output_directory + '/comparison_' + str(epoch) + '.png' # str() converts its argument to a string
368 |                 utils.save_image(img_merge, filename) # tip: put frequently used helpers like this in a dedicated script file and call them this way
369 | 
370 |         if (i+1) % args.print_freq == 0:
371 |             print('Test: [{0}/{1}]\t'
372 |                   't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\t'
373 |                   'RMSE={result.rmse:.2f}({average.rmse:.2f}) '
374 |                   'MAE={result.mae:.2f}({average.mae:.2f}) '
375 |                   'Delta1={result.delta1:.3f}({average.delta1:.3f}) '
376 |                   'REL={result.absrel:.3f}({average.absrel:.3f}) '
377 |                   'Lg10={result.lg10:.3f}({average.lg10:.3f}) '.format(
378 |                   i+1, len(val_loader), gpu_time=gpu_time, result=result, average=average_meter.average()))
379 | 
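    # AverageMeter (presumably defined in metrics.py, which this file imports from)
    # accumulates the per-batch results so that `avg` below summarizes the whole
    # validation set. The general pattern behind such meters is roughly the
    # following sketch, given here as an illustration, not the actual implementation:
    #
    #   class RunningMeter:
    #       def __init__(self):
    #           self.sum, self.count = 0.0, 0
    #       def update(self, value, n=1):  # n: batch size, used to weight value
    #           self.sum += value * n
    #           self.count += n
    #       def average(self):
    #           return self.sum / max(self.count, 1)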
380 |     avg = average_meter.average()
381 | 
382 |     print('\n*\n'
383 |           'RMSE={average.rmse:.3f}\n'
384 |           'MAE={average.mae:.3f}\n'
385 |           'Delta1={average.delta1:.3f}\n'
386 |           'REL={average.absrel:.3f}\n'
387 |           'Lg10={average.lg10:.3f}\n'
388 |           't_GPU={time:.3f}\n'.format(
389 |           average=avg, time=avg.gpu_time))
390 | 
391 |     if write_to_file:
392 |         with open(test_csv, 'a') as csvfile:
393 |             writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
394 |             writer.writerow({'mse': avg.mse, 'rmse': avg.rmse, 'absrel': avg.absrel, 'lg10': avg.lg10,
395 |                              'mae': avg.mae, 'delta1': avg.delta1, 'delta2': avg.delta2, 'delta3': avg.delta3,
396 |                              'data_time': avg.data_time, 'gpu_time': avg.gpu_time})
397 | 
398 |     return avg, img_merge
399 | 
400 | def save_checkpoint(state, is_best, epoch):
401 |     checkpoint_filename = os.path.join(output_directory, 'checkpoint-' + str(epoch) + '.pth.tar')
402 |     torch.save(state, checkpoint_filename)
403 |     if is_best:
404 |         best_filename = os.path.join(output_directory, 'model_best.pth.tar')
405 |         shutil.copyfile(checkpoint_filename, best_filename)
406 |     if epoch > 0:
407 |         prev_checkpoint_filename = os.path.join(output_directory, 'checkpoint-' + str(epoch-1) + '.pth.tar')
408 |         if os.path.exists(prev_checkpoint_filename):
409 |             os.remove(prev_checkpoint_filename)
410 | 
411 | def adjust_learning_rate(optimizer, epoch):
412 |     # the text between """ """ is the function's docstring; editors display it when you hover over the function
413 |     """Sets the learning rate to the initial LR decayed by 10 every 10 epochs"""
414 |     # //: integer (floor) division, returning the largest integer not greater than the exact result
415 |     # /: true division; in Python 3 it always returns a float, whereas in Python 2 (and C++) dividing two integers yields an integer
416 |     # in Python, * is multiplication and ** is exponentiation
417 |     lr = args.lr * (0.1 ** (epoch // 10))
418 |     for param_group in optimizer.param_groups:
419 |         param_group['lr'] = lr
420 | 
421 | # this is the entry point of the program:
422 | # when the .py file is run directly, the block under if __name__ == '__main__' executes; when the file is imported as a module, it does not.
423 | # see https://blog.csdn.net/yjk13703623757/article/details/77918633
424 | if __name__ == '__main__':
425 |     main()
--------------------------------------------------------------------------------
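A worked example of the step decay implemented in adjust_learning_rate above (an
initial learning rate of 0.01 is assumed here purely for illustration; the real
value comes from args.lr):

    lr = 0.01 * (0.1 ** (epoch // 10))
    # epochs  0-9  -> 0.01 * 0.1**0 = 0.01
    # epochs 10-19 -> 0.01 * 0.1**1 = 0.001
    # epochs 20-29 -> 0.01 * 0.1**2 = 0.0001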