├── .gitignore
├── README.md
├── criteria.py
├── dataloaders
    ├── __init__.py
    ├── dataloader.py
    ├── kitti_dataloader.py
    ├── nyu_dataloader.py
    ├── path.py
    └── transforms.py
├── main.py
├── metrics.py
├── network
    ├── FCRN.py
    └── __init__.py
├── result
    ├── kitti.png
    ├── nyu.png
    └── result.png
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | .idea/
29 | /.idea
30 | *.iml
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | .hypothesis/
53 | .pytest_cache/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # celery beat schedule file
88 | celerybeat-schedule
89 |
90 | # SageMath parsed files
91 | *.sage.py
92 |
93 | # Environments
94 | .env
95 | .venv
96 | env/
97 | venv/
98 | ENV/
99 | env.bak/
100 | venv.bak/
101 |
102 | # Spyder project settings
103 | .spyderproject
104 | .spyproject
105 |
106 | # Rope project settings
107 | .ropeproject
108 |
109 | # mkdocs documentation
110 | /site
111 |
112 | # mypy
113 | .mypy_cache/
114 | .dmypy.json
115 | dmypy.json
116 |
117 | # Pyre type checker
118 | .pyre/
119 |
120 | # JPG PNG
121 | *.jpg
122 | *.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FCRN implemented in PyTorch 0.4.1
2 |
3 |
4 | ### Introduction
5 | This is a PyTorch (0.4.1) implementation of [Deeper Depth Prediction with Fully Convolutional Residual Networks](http://ieeexplore.ieee.org/document/7785097/), which
6 | uses a fully convolutional residual network for monocular depth prediction. Currently, FCRN can be trained
7 | on the NYU Depth v2 and KITTI Odometry datasets.
8 |
9 |
10 | ### Result
11 |
12 | #### NYU Depthv2
13 |
14 | The code was tested with Python 3.5 and PyTorch 0.4.1 on a 12 GB TITAN X. We train for 60 epochs with batch size 16. The trained model can be downloaded from [BaiduYun](https://pan.baidu.com/s/1A3lq0ntPKBOH-En818bo8A).
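For reference, a downloaded checkpoint can be restored along the lines of the minimal sketch below; the file name is a placeholder, and the `'model'` key is the one `main.py` saves and reloads when resuming:

```python
import torch

# Minimal sketch: restore a trained FCRN checkpoint (file name is a placeholder).
checkpoint = torch.load('model_best.pth.tar', map_location='cpu')
model = checkpoint['model']  # main.py stores the full model object, not just a state_dict
model.eval()
```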
15 |
16 | Method | rml | rmse | log10 | Delta1 | Delta2 | Delta3
17 | :-------| :------: | :------: | :------: | :------: | :------: | :------:
18 | FCRN | 0.127 | 0.573 | 0.055 | 0.811 | 0.953 | 0.988
19 | FCRN_ours | 0.149 | 0.527 | 0.062 | 0.805 | 0.954 | 0.987
20 |
21 | ![Image text](https://github.com/dontLoveBugs/FCRN_pytorch/blob/master/result/nyu.png)
22 |
23 | #### KITTI Odometry
24 | Method | rml | rmse | log10 | Delta1 | Delta2 | Delta3
25 | :-------| :------: | :------: | :------: | :------: | :------: | :------:
26 | FCRN_ours | 0.113 | 4.801 | 0.048 | 0.865 | 0.957 | 0.984
27 |
28 | ![Image text](https://github.com/dontLoveBugs/FCRN_pytorch/blob/master/result/kitti.png)
29 |
30 | ### Installation
31 | The code was tested with Python 3.5 and PyTorch 0.4.1 on 2 TITAN X GPUs.
32 |
33 | 0. Clone the repo:
34 | ```Shell
35 | git clone git@github.com:dontLoveBugs/FCRN_pytorch.git
36 | cd FCRN_pytorch
37 | ```
38 |
39 | 1. Install dependencies:
40 |
41 | For the PyTorch dependency, see [pytorch.org](https://pytorch.org/) for more details.
42 |
43 | For custom dependencies:
44 | ```Shell
45 | pip install matplotlib pillow tensorboardX
46 | ```
47 |
48 | 2. Configure your dataset paths in "dataloaders/path.py".
49 |
50 | 3. Training
51 |
52 | To train on NYU Depth v2, run:
53 | ```Shell
54 | python main.py --dataset nyu
55 | ```
56 |
57 | To train on KITTI, run:
58 | ```Shell
59 | python main.py --dataset kitti
60 | ```
61 |
62 |
63 |
64 |
65 |
--------------------------------------------------------------------------------
/criteria.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2018/10/23 20:04
3 | # @Author : Wang Xin
4 | # @Email : wangxin_buaa@163.com
5 |
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 |
11 | class MaskedMSELoss(nn.Module):
12 |     def __init__(self):
13 |         super(MaskedMSELoss, self).__init__()
14 |
15 |     def forward(self, pred, target):
16 |         assert pred.dim() == target.dim(), "inconsistent dimensions"
17 |         valid_mask = (target > 0).detach()  # only pixels with ground-truth depth contribute
18 |         diff = target - pred
19 |         diff = diff[valid_mask]
20 |         self.loss = (diff ** 2).mean()
21 |         return self.loss
22 |
23 |
24 | class MaskedL1Loss(nn.Module):
25 |     def __init__(self):
26 |         super(MaskedL1Loss, self).__init__()
27 |
28 |     def forward(self, pred, target):
29 |         assert pred.dim() == target.dim(), "inconsistent dimensions"
30 |         valid_mask = (target > 0).detach()
31 |         diff = target - pred
32 |         diff = diff[valid_mask]
33 |         self.loss = diff.abs().mean()
34 |         return self.loss
35 |
36 |
37 | class berHuLoss(nn.Module):
38 |     def __init__(self):
39 |         super(berHuLoss, self).__init__()
40 |
41 |     def forward(self, pred, target):
42 |         assert pred.dim() == target.dim(), "inconsistent dimensions"
43 |
44 |         huber_c = torch.max(pred - target)
45 |         huber_c = 0.2 * huber_c  # berHu cutoff: 20% of the largest residual in the batch
46 |
47 |         valid_mask = (target > 0).detach()
48 |         diff = target - pred
49 |         diff = diff[valid_mask]
50 |         diff = diff.abs()
51 |
52 |         huber_mask = (diff > huber_c).detach()
53 |
54 |         diff2 = diff[huber_mask]
55 |         diff2 = diff2 ** 2
56 |
57 |         self.loss = torch.cat((diff, diff2)).mean()
58 |
59 |         return self.loss
60 |
--------------------------------------------------------------------------------
/dataloaders/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2018/10/21 20:43
3 | # @Author : Wang Xin
4 | # @Email : wangxin_buaa@163.com
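As a quick, made-up illustration of how the masked criteria in criteria.py behave (the tensors below are arbitrary): pixels whose ground-truth depth is 0 are dropped before averaging.

```python
import torch
from criteria import MaskedL1Loss

# Arbitrary 1x1x2x2 example: the pixel with target depth 0 is excluded by the valid mask.
pred = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])
target = torch.tensor([[[[1.5, 0.0], [2.0, 4.0]]]])
loss = MaskedL1Loss()(pred, target)
print(loss.item())  # mean of |0.5|, |-1.0|, |0.0| -> 0.5
```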
-------------------------------------------------------------------------------- /dataloaders/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import numpy as np 4 | import torch.utils.data as data 5 | import h5py 6 | import dataloaders.transforms as transforms 7 | 8 | IMG_EXTENSIONS = ['.h5', ] 9 | 10 | 11 | def is_image_file(filename): 12 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 13 | 14 | 15 | def find_classes(dir): 16 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 17 | classes.sort() 18 | class_to_idx = {classes[i]: i for i in range(len(classes))} 19 | return classes, class_to_idx 20 | 21 | 22 | def make_dataset(dir, class_to_idx): 23 | images = [] 24 | dir = os.path.expanduser(dir) 25 | for target in sorted(os.listdir(dir)): 26 | d = os.path.join(dir, target) 27 | if not os.path.isdir(d): 28 | continue 29 | for root, _, fnames in sorted(os.walk(d)): 30 | for fname in sorted(fnames): 31 | if is_image_file(fname): 32 | path = os.path.join(root, fname) 33 | item = (path, class_to_idx[target]) 34 | images.append(item) 35 | return images 36 | 37 | 38 | def h5_loader(path): 39 | h5f = h5py.File(path, "r") 40 | rgb = np.array(h5f['rgb']) 41 | rgb = np.transpose(rgb, (1, 2, 0)) 42 | depth = np.array(h5f['depth']) 43 | return rgb, depth 44 | 45 | 46 | # def rgb2grayscale(rgb): 47 | # return rgb[:,:,0] * 0.2989 + rgb[:,:,1] * 0.587 + rgb[:,:,2] * 0.114 48 | 49 | to_tensor = transforms.ToTensor() 50 | 51 | 52 | class MyDataloader(data.Dataset): 53 | modality_names = ['rgb', 'rgbd', 'd'] # , 'g', 'gd' 54 | color_jitter = transforms.ColorJitter(0.4, 0.4, 0.4) 55 | 56 | def __init__(self, root, type, sparsifier=None, modality='rgb', loader=h5_loader): 57 | classes, class_to_idx = find_classes(root) 58 | imgs = make_dataset(root, class_to_idx) 59 | assert len(imgs) > 0, "Found 0 images in subfolders of: " + root + "\n" 60 | print("Found {} images in {} folder.".format(len(imgs), type)) 61 | self.root = root 62 | self.imgs = imgs 63 | self.classes = classes 64 | self.class_to_idx = class_to_idx 65 | if type == 'train': 66 | self.transform = self.train_transform 67 | elif type == 'val': 68 | self.transform = self.val_transform 69 | else: 70 | raise (RuntimeError("Invalid dataset type: " + type + "\n" 71 | "Supported dataset types are: train, val")) 72 | self.loader = loader 73 | self.sparsifier = sparsifier 74 | 75 | assert (modality in self.modality_names), "Invalid modality type: " + modality + "\n" + \ 76 | "Supported dataset types are: " + ''.join(self.modality_names) 77 | self.modality = modality 78 | 79 | def train_transform(self, rgb, depth): 80 | raise (RuntimeError("train_transform() is not implemented. ")) 81 | 82 | def val_transform(rgb, depth): 83 | raise (RuntimeError("val_transform() is not implemented.")) 84 | 85 | def create_sparse_depth(self, rgb, depth): 86 | if self.sparsifier is None: 87 | return depth 88 | else: 89 | mask_keep = self.sparsifier.dense_to_sparse(rgb, depth) 90 | sparse_depth = np.zeros(depth.shape) 91 | sparse_depth[mask_keep] = depth[mask_keep] 92 | return sparse_depth 93 | 94 | def create_rgbd(self, rgb, depth): 95 | sparse_depth = self.create_sparse_depth(rgb, depth) 96 | rgbd = np.append(rgb, np.expand_dims(sparse_depth, axis=2), axis=2) 97 | return rgbd 98 | 99 | def __getraw__(self, index): 100 | """ 101 | Args: 102 | index (int): Index 103 | 104 | Returns: 105 | tuple: (rgb, depth) the raw data. 
106 | """ 107 | path, target = self.imgs[index] 108 | rgb, depth = self.loader(path) 109 | return rgb, depth 110 | 111 | def __getitem__(self, index): 112 | rgb, depth = self.__getraw__(index) 113 | if self.transform is not None: 114 | rgb_np, depth_np = self.transform(rgb, depth) 115 | else: 116 | raise (RuntimeError("transform not defined")) 117 | 118 | if self.modality == 'rgb': 119 | input_np = rgb_np 120 | elif self.modality == 'rgbd': 121 | input_np = self.create_rgbd(rgb_np, depth_np) 122 | elif self.modality == 'd': 123 | input_np = self.create_sparse_depth(rgb_np, depth_np) 124 | 125 | input_tensor = to_tensor(input_np) 126 | while input_tensor.dim() < 3: 127 | input_tensor = input_tensor.unsqueeze(0) 128 | depth_tensor = to_tensor(depth_np) 129 | depth_tensor = depth_tensor.unsqueeze(0) 130 | 131 | return input_tensor, depth_tensor 132 | 133 | def __len__(self): 134 | return len(self.imgs) 135 | -------------------------------------------------------------------------------- /dataloaders/kitti_dataloader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import dataloaders.transforms as transforms 3 | from dataloaders.dataloader import MyDataloader 4 | 5 | 6 | class KITTIDataset(MyDataloader): 7 | def __init__(self, root, type, sparsifier=None, modality='rgb'): 8 | super(KITTIDataset, self).__init__(root, type, sparsifier, modality) 9 | self.output_size = (228, 912) 10 | 11 | def train_transform(self, rgb, depth): 12 | s = np.random.uniform(1.0, 1.5) # random scaling 13 | depth_np = depth / s 14 | angle = np.random.uniform(-5.0, 5.0) # random rotation degrees 15 | do_flip = np.random.uniform(0.0, 1.0) < 0.5 # random horizontal flip 16 | 17 | # perform 1st step of data augmentation 18 | transform = transforms.Compose([ 19 | transforms.Crop(130, 10, 240, 1200), 20 | transforms.Rotate(angle), 21 | transforms.Resize(s), 22 | transforms.CenterCrop(self.output_size), 23 | transforms.HorizontalFlip(do_flip) 24 | ]) 25 | rgb_np = transform(rgb) 26 | rgb_np = self.color_jitter(rgb_np) # random color jittering 27 | rgb_np = np.asfarray(rgb_np, dtype='float') / 255 28 | # Scipy affine_transform produced RuntimeError when the depth map was 29 | # given as a 'numpy.ndarray' 30 | depth_np = np.asfarray(depth_np, dtype='float32') 31 | depth_np = transform(depth_np) 32 | 33 | return rgb_np, depth_np 34 | 35 | def val_transform(self, rgb, depth): 36 | depth_np = depth 37 | transform = transforms.Compose([ 38 | transforms.Crop(130, 10, 240, 1200), 39 | transforms.CenterCrop(self.output_size), 40 | ]) 41 | rgb_np = transform(rgb) 42 | rgb_np = np.asfarray(rgb_np, dtype='float') / 255 43 | depth_np = np.asfarray(depth_np, dtype='float32') 44 | depth_np = transform(depth_np) 45 | 46 | return rgb_np, depth_np 47 | -------------------------------------------------------------------------------- /dataloaders/nyu_dataloader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import dataloaders.transforms as transforms 3 | from dataloaders.dataloader import MyDataloader 4 | 5 | iheight, iwidth = 480, 640 # raw image size 6 | 7 | 8 | class NYUDataset(MyDataloader): 9 | def __init__(self, root, type, sparsifier=None, modality='rgb'): 10 | super(NYUDataset, self).__init__(root, type, sparsifier, modality) 11 | self.output_size = (228, 304) 12 | 13 | def train_transform(self, rgb, depth): 14 | s = np.random.uniform(1.0, 1.5) # random scaling 15 | depth_np = depth / s 16 | angle = 
np.random.uniform(-5.0, 5.0) # random rotation degrees 17 | do_flip = np.random.uniform(0.0, 1.0) < 0.5 # random horizontal flip 18 | 19 | # perform 1st step of data augmentation 20 | transform = transforms.Compose([ 21 | transforms.Resize(250.0 / iheight), # this is for computational efficiency, since rotation can be slow 22 | transforms.Rotate(angle), 23 | transforms.Resize(s), 24 | transforms.CenterCrop(self.output_size), 25 | transforms.HorizontalFlip(do_flip) 26 | ]) 27 | rgb_np = transform(rgb) 28 | rgb_np = self.color_jitter(rgb_np) # random color jittering 29 | rgb_np = np.asfarray(rgb_np, dtype='float') / 255 30 | depth_np = transform(depth_np) 31 | 32 | return rgb_np, depth_np 33 | 34 | def val_transform(self, rgb, depth): 35 | depth_np = depth 36 | transform = transforms.Compose([ 37 | transforms.Resize(240.0 / iheight), 38 | transforms.CenterCrop(self.output_size), 39 | ]) 40 | rgb_np = transform(rgb) 41 | rgb_np = np.asfarray(rgb_np, dtype='float') / 255 42 | depth_np = transform(depth_np) 43 | 44 | return rgb_np, depth_np 45 | -------------------------------------------------------------------------------- /dataloaders/path.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Time : 2019/1/21 22:07 4 | @Author : Wang Xin 5 | @Email : wangxin_buaa@163.com 6 | """ 7 | 8 | 9 | class Path(object): 10 | @staticmethod 11 | def db_root_dir(database): 12 | if database == 'nyu': 13 | return '/home/data/model/wangxin/nyudepthv2' 14 | elif database == 'kitti': 15 | return '/home/data/UnsupervisedDepth/wangixn/kitti' 16 | else: 17 | print('Database {} not available.'.format(database)) 18 | raise NotImplementedError 19 | -------------------------------------------------------------------------------- /dataloaders/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import math 4 | import random 5 | 6 | from PIL import Image, ImageOps, ImageEnhance 7 | try: 8 | import accimage 9 | except ImportError: 10 | accimage = None 11 | 12 | import numpy as np 13 | import numbers 14 | import types 15 | import collections 16 | import warnings 17 | 18 | import scipy.ndimage.interpolation as itpl 19 | import scipy.misc as misc 20 | 21 | 22 | def _is_numpy_image(img): 23 | return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) 24 | 25 | def _is_pil_image(img): 26 | if accimage is not None: 27 | return isinstance(img, (Image.Image, accimage.Image)) 28 | else: 29 | return isinstance(img, Image.Image) 30 | 31 | def _is_tensor_image(img): 32 | return torch.is_tensor(img) and img.ndimension() == 3 33 | 34 | def adjust_brightness(img, brightness_factor): 35 | """Adjust brightness of an Image. 36 | 37 | Args: 38 | img (PIL Image): PIL Image to be adjusted. 39 | brightness_factor (float): How much to adjust the brightness. Can be 40 | any non negative number. 0 gives a black image, 1 gives the 41 | original image while 2 increases the brightness by a factor of 2. 42 | 43 | Returns: 44 | PIL Image: Brightness adjusted image. 45 | """ 46 | if not _is_pil_image(img): 47 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 48 | 49 | enhancer = ImageEnhance.Brightness(img) 50 | img = enhancer.enhance(brightness_factor) 51 | return img 52 | 53 | 54 | def adjust_contrast(img, contrast_factor): 55 | """Adjust contrast of an Image. 56 | 57 | Args: 58 | img (PIL Image): PIL Image to be adjusted. 
59 | contrast_factor (float): How much to adjust the contrast. Can be any 60 | non negative number. 0 gives a solid gray image, 1 gives the 61 | original image while 2 increases the contrast by a factor of 2. 62 | 63 | Returns: 64 | PIL Image: Contrast adjusted image. 65 | """ 66 | if not _is_pil_image(img): 67 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 68 | 69 | enhancer = ImageEnhance.Contrast(img) 70 | img = enhancer.enhance(contrast_factor) 71 | return img 72 | 73 | 74 | def adjust_saturation(img, saturation_factor): 75 | """Adjust color saturation of an image. 76 | 77 | Args: 78 | img (PIL Image): PIL Image to be adjusted. 79 | saturation_factor (float): How much to adjust the saturation. 0 will 80 | give a black and white image, 1 will give the original image while 81 | 2 will enhance the saturation by a factor of 2. 82 | 83 | Returns: 84 | PIL Image: Saturation adjusted image. 85 | """ 86 | if not _is_pil_image(img): 87 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 88 | 89 | enhancer = ImageEnhance.Color(img) 90 | img = enhancer.enhance(saturation_factor) 91 | return img 92 | 93 | 94 | def adjust_hue(img, hue_factor): 95 | """Adjust hue of an image. 96 | 97 | The image hue is adjusted by converting the image to HSV and 98 | cyclically shifting the intensities in the hue channel (H). 99 | The image is then converted back to original image mode. 100 | 101 | `hue_factor` is the amount of shift in H channel and must be in the 102 | interval `[-0.5, 0.5]`. 103 | 104 | See https://en.wikipedia.org/wiki/Hue for more details on Hue. 105 | 106 | Args: 107 | img (PIL Image): PIL Image to be adjusted. 108 | hue_factor (float): How much to shift the hue channel. Should be in 109 | [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in 110 | HSV space in positive and negative direction respectively. 111 | 0 means no shift. Therefore, both -0.5 and 0.5 will give an image 112 | with complementary colors while 0 gives the original image. 113 | 114 | Returns: 115 | PIL Image: Hue adjusted image. 116 | """ 117 | if not(-0.5 <= hue_factor <= 0.5): 118 | raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor)) 119 | 120 | if not _is_pil_image(img): 121 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 122 | 123 | input_mode = img.mode 124 | if input_mode in {'L', '1', 'I', 'F'}: 125 | return img 126 | 127 | h, s, v = img.convert('HSV').split() 128 | 129 | np_h = np.array(h, dtype=np.uint8) 130 | # uint8 addition take cares of rotation across boundaries 131 | with np.errstate(over='ignore'): 132 | np_h += np.uint8(hue_factor * 255) 133 | h = Image.fromarray(np_h, 'L') 134 | 135 | img = Image.merge('HSV', (h, s, v)).convert(input_mode) 136 | return img 137 | 138 | 139 | def adjust_gamma(img, gamma, gain=1): 140 | """Perform gamma correction on an image. 141 | 142 | Also known as Power Law Transform. Intensities in RGB mode are adjusted 143 | based on the following equation: 144 | 145 | I_out = 255 * gain * ((I_in / 255) ** gamma) 146 | 147 | See https://en.wikipedia.org/wiki/Gamma_correction for more details. 148 | 149 | Args: 150 | img (PIL Image): PIL Image to be adjusted. 151 | gamma (float): Non negative real number. gamma larger than 1 make the 152 | shadows darker, while gamma smaller than 1 make dark regions 153 | lighter. 154 | gain (float): The constant multiplier. 155 | """ 156 | if not _is_pil_image(img): 157 | raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) 158 | 159 | if gamma < 0: 160 | raise ValueError('Gamma should be a non-negative real number') 161 | 162 | input_mode = img.mode 163 | img = img.convert('RGB') 164 | 165 | np_img = np.array(img, dtype=np.float32) 166 | np_img = 255 * gain * ((np_img / 255) ** gamma) 167 | np_img = np.uint8(np.clip(np_img, 0, 255)) 168 | 169 | img = Image.fromarray(np_img, 'RGB').convert(input_mode) 170 | return img 171 | 172 | 173 | class Compose(object): 174 | """Composes several transforms together. 175 | 176 | Args: 177 | transforms (list of ``Transform`` objects): list of transforms to compose. 178 | 179 | Example: 180 | >>> transforms.Compose([ 181 | >>> transforms.CenterCrop(10), 182 | >>> transforms.ToTensor(), 183 | >>> ]) 184 | """ 185 | 186 | def __init__(self, transforms): 187 | self.transforms = transforms 188 | 189 | def __call__(self, img): 190 | for t in self.transforms: 191 | img = t(img) 192 | return img 193 | 194 | 195 | class ToTensor(object): 196 | """Convert a ``numpy.ndarray`` to tensor. 197 | 198 | Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W). 199 | """ 200 | 201 | def __call__(self, img): 202 | """Convert a ``numpy.ndarray`` to tensor. 203 | 204 | Args: 205 | img (numpy.ndarray): Image to be converted to tensor. 206 | 207 | Returns: 208 | Tensor: Converted image. 209 | """ 210 | if not(_is_numpy_image(img)): 211 | raise TypeError('img should be ndarray. Got {}'.format(type(img))) 212 | 213 | if isinstance(img, np.ndarray): 214 | # handle numpy array 215 | if img.ndim == 3: 216 | img = torch.from_numpy(img.transpose((2, 0, 1)).copy()) 217 | elif img.ndim == 2: 218 | img = torch.from_numpy(img.copy()) 219 | else: 220 | raise RuntimeError('img should be ndarray with 2 or 3 dimensions. Got {}'.format(img.ndim)) 221 | 222 | # backward compatibility 223 | # return img.float().div(255) 224 | return img.float() 225 | 226 | 227 | class NormalizeNumpyArray(object): 228 | """Normalize a ``numpy.ndarray`` with mean and standard deviation. 229 | Given mean: ``(M1,...,Mn)`` and std: ``(M1,..,Mn)`` for ``n`` channels, this transform 230 | will normalize each channel of the input ``numpy.ndarray`` i.e. 231 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` 232 | 233 | Args: 234 | mean (sequence): Sequence of means for each channel. 235 | std (sequence): Sequence of standard deviations for each channel. 236 | """ 237 | 238 | def __init__(self, mean, std): 239 | self.mean = mean 240 | self.std = std 241 | 242 | def __call__(self, img): 243 | """ 244 | Args: 245 | img (numpy.ndarray): Image of size (H, W, C) to be normalized. 246 | 247 | Returns: 248 | Tensor: Normalized image. 249 | """ 250 | if not(_is_numpy_image(img)): 251 | raise TypeError('img should be ndarray. Got {}'.format(type(img))) 252 | # TODO: make efficient 253 | print(img.shape) 254 | for i in range(3): 255 | img[:,:,i] = (img[:,:,i] - self.mean[i]) / self.std[i] 256 | return img 257 | 258 | class NormalizeTensor(object): 259 | """Normalize an tensor image with mean and standard deviation. 260 | Given mean: ``(M1,...,Mn)`` and std: ``(M1,..,Mn)`` for ``n`` channels, this transform 261 | will normalize each channel of the input ``torch.*Tensor`` i.e. 262 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` 263 | 264 | Args: 265 | mean (sequence): Sequence of means for each channel. 266 | std (sequence): Sequence of standard deviations for each channel. 
267 | """ 268 | 269 | def __init__(self, mean, std): 270 | self.mean = mean 271 | self.std = std 272 | 273 | def __call__(self, tensor): 274 | """ 275 | Args: 276 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 277 | 278 | Returns: 279 | Tensor: Normalized Tensor image. 280 | """ 281 | if not _is_tensor_image(tensor): 282 | raise TypeError('tensor is not a torch image.') 283 | # TODO: make efficient 284 | for t, m, s in zip(tensor, self.mean, self.std): 285 | t.sub_(m).div_(s) 286 | return tensor 287 | 288 | class Rotate(object): 289 | """Rotates the given ``numpy.ndarray``. 290 | 291 | Args: 292 | angle (float): The rotation angle in degrees. 293 | """ 294 | 295 | def __init__(self, angle): 296 | self.angle = angle 297 | 298 | def __call__(self, img): 299 | """ 300 | Args: 301 | img (numpy.ndarray (C x H x W)): Image to be rotated. 302 | 303 | Returns: 304 | img (numpy.ndarray (C x H x W)): Rotated image. 305 | """ 306 | 307 | # order=0 means nearest-neighbor type interpolation 308 | return itpl.rotate(img, self.angle, reshape=False, prefilter=False, order=0) 309 | 310 | 311 | class Resize(object): 312 | """Resize the the given ``numpy.ndarray`` to the given size. 313 | Args: 314 | size (sequence or int): Desired output size. If size is a sequence like 315 | (h, w), output size will be matched to this. If size is an int, 316 | smaller edge of the image will be matched to this number. 317 | i.e, if height > width, then image will be rescaled to 318 | (size * height / width, size) 319 | interpolation (int, optional): Desired interpolation. Default is 320 | ``PIL.Image.BILINEAR`` 321 | """ 322 | 323 | def __init__(self, size, interpolation='nearest'): 324 | assert isinstance(size, int) or isinstance(size, float) or \ 325 | (isinstance(size, collections.Iterable) and len(size) == 2) 326 | self.size = size 327 | self.interpolation = interpolation 328 | 329 | def __call__(self, img): 330 | """ 331 | Args: 332 | img (PIL Image): Image to be scaled. 333 | Returns: 334 | PIL Image: Rescaled image. 335 | """ 336 | if img.ndim == 3: 337 | return misc.imresize(img, self.size, self.interpolation) 338 | elif img.ndim == 2: 339 | return misc.imresize(img, self.size, self.interpolation, 'F') 340 | else: 341 | RuntimeError('img should be ndarray with 2 or 3 dimensions. Got {}'.format(img.ndim)) 342 | 343 | 344 | class CenterCrop(object): 345 | """Crops the given ``numpy.ndarray`` at the center. 346 | 347 | Args: 348 | size (sequence or int): Desired output size of the crop. If size is an 349 | int instead of sequence like (h, w), a square crop (size, size) is 350 | made. 351 | """ 352 | 353 | def __init__(self, size): 354 | if isinstance(size, numbers.Number): 355 | self.size = (int(size), int(size)) 356 | else: 357 | self.size = size 358 | 359 | @staticmethod 360 | def get_params(img, output_size): 361 | """Get parameters for ``crop`` for center crop. 362 | 363 | Args: 364 | img (numpy.ndarray (C x H x W)): Image to be cropped. 365 | output_size (tuple): Expected output size of the crop. 366 | 367 | Returns: 368 | tuple: params (i, j, h, w) to be passed to ``crop`` for center crop. 
369 | """ 370 | h = img.shape[0] 371 | w = img.shape[1] 372 | th, tw = output_size 373 | i = int(round((h - th) / 2.)) 374 | j = int(round((w - tw) / 2.)) 375 | 376 | # # randomized cropping 377 | # i = np.random.randint(i-3, i+4) 378 | # j = np.random.randint(j-3, j+4) 379 | 380 | return i, j, th, tw 381 | 382 | def __call__(self, img): 383 | """ 384 | Args: 385 | img (numpy.ndarray (C x H x W)): Image to be cropped. 386 | 387 | Returns: 388 | img (numpy.ndarray (C x H x W)): Cropped image. 389 | """ 390 | i, j, h, w = self.get_params(img, self.size) 391 | 392 | """ 393 | i: Upper pixel coordinate. 394 | j: Left pixel coordinate. 395 | h: Height of the cropped image. 396 | w: Width of the cropped image. 397 | """ 398 | if not(_is_numpy_image(img)): 399 | raise TypeError('img should be ndarray. Got {}'.format(type(img))) 400 | if img.ndim == 3: 401 | return img[i:i+h, j:j+w, :] 402 | elif img.ndim == 2: 403 | return img[i:i + h, j:j + w] 404 | else: 405 | raise RuntimeError('img should be ndarray with 2 or 3 dimensions. Got {}'.format(img.ndim)) 406 | 407 | 408 | class Lambda(object): 409 | """Apply a user-defined lambda as a transform. 410 | 411 | Args: 412 | lambd (function): Lambda/function to be used for transform. 413 | """ 414 | 415 | def __init__(self, lambd): 416 | assert isinstance(lambd, types.LambdaType) 417 | self.lambd = lambd 418 | 419 | def __call__(self, img): 420 | return self.lambd(img) 421 | 422 | 423 | class HorizontalFlip(object): 424 | """Horizontally flip the given ``numpy.ndarray``. 425 | 426 | Args: 427 | do_flip (boolean): whether or not do horizontal flip. 428 | 429 | """ 430 | 431 | def __init__(self, do_flip): 432 | self.do_flip = do_flip 433 | 434 | def __call__(self, img): 435 | """ 436 | Args: 437 | img (numpy.ndarray (C x H x W)): Image to be flipped. 438 | 439 | Returns: 440 | img (numpy.ndarray (C x H x W)): flipped image. 441 | """ 442 | if not(_is_numpy_image(img)): 443 | raise TypeError('img should be ndarray. Got {}'.format(type(img))) 444 | 445 | if self.do_flip: 446 | return np.fliplr(img) 447 | else: 448 | return img 449 | 450 | 451 | class ColorJitter(object): 452 | """Randomly change the brightness, contrast and saturation of an image. 453 | 454 | Args: 455 | brightness (float): How much to jitter brightness. brightness_factor 456 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. 457 | contrast (float): How much to jitter contrast. contrast_factor 458 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. 459 | saturation (float): How much to jitter saturation. saturation_factor 460 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. 461 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from 462 | [-hue, hue]. Should be >=0 and <= 0.5. 463 | """ 464 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 465 | self.brightness = brightness 466 | self.contrast = contrast 467 | self.saturation = saturation 468 | self.hue = hue 469 | 470 | @staticmethod 471 | def get_params(brightness, contrast, saturation, hue): 472 | """Get a randomized transform to be applied on image. 473 | 474 | Arguments are same as that of __init__. 475 | 476 | Returns: 477 | Transform which randomly adjusts brightness, contrast and 478 | saturation in a random order. 
479 | """ 480 | transforms = [] 481 | if brightness > 0: 482 | brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness) 483 | transforms.append(Lambda(lambda img: adjust_brightness(img, brightness_factor))) 484 | 485 | if contrast > 0: 486 | contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast) 487 | transforms.append(Lambda(lambda img: adjust_contrast(img, contrast_factor))) 488 | 489 | if saturation > 0: 490 | saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation) 491 | transforms.append(Lambda(lambda img: adjust_saturation(img, saturation_factor))) 492 | 493 | if hue > 0: 494 | hue_factor = np.random.uniform(-hue, hue) 495 | transforms.append(Lambda(lambda img: adjust_hue(img, hue_factor))) 496 | 497 | np.random.shuffle(transforms) 498 | transform = Compose(transforms) 499 | 500 | return transform 501 | 502 | def __call__(self, img): 503 | """ 504 | Args: 505 | img (numpy.ndarray (C x H x W)): Input image. 506 | 507 | Returns: 508 | img (numpy.ndarray (C x H x W)): Color jittered image. 509 | """ 510 | if not(_is_numpy_image(img)): 511 | raise TypeError('img should be ndarray. Got {}'.format(type(img))) 512 | 513 | pil = Image.fromarray(img) 514 | transform = self.get_params(self.brightness, self.contrast, 515 | self.saturation, self.hue) 516 | return np.array(transform(pil)) 517 | 518 | class Crop(object): 519 | """Crops the given PIL Image to a rectangular region based on a given 520 | 4-tuple defining the left, upper pixel coordinated, hight and width size. 521 | 522 | Args: 523 | a tuple: (upper pixel coordinate, left pixel coordinate, hight, width)-tuple 524 | """ 525 | 526 | def __init__(self, i, j, h, w): 527 | """ 528 | i: Upper pixel coordinate. 529 | j: Left pixel coordinate. 530 | h: Height of the cropped image. 531 | w: Width of the cropped image. 532 | """ 533 | self.i = i 534 | self.j = j 535 | self.h = h 536 | self.w = w 537 | 538 | def __call__(self, img): 539 | """ 540 | Args: 541 | img (numpy.ndarray (C x H x W)): Image to be cropped. 542 | Returns: 543 | img (numpy.ndarray (C x H x W)): Cropped image. 544 | """ 545 | 546 | i, j, h, w = self.i, self.j, self.h, self.w 547 | 548 | if not(_is_numpy_image(img)): 549 | raise TypeError('img should be ndarray. Got {}'.format(type(img))) 550 | if img.ndim == 3: 551 | return img[i:i + h, j:j + w, :] 552 | elif img.ndim == 2: 553 | return img[i:i + h, j:j + w] 554 | else: 555 | raise RuntimeError( 556 | 'img should be ndarray with 2 or 3 dimensions. 
Got {}'.format(img.ndim))
557 |
558 |     def __repr__(self):
559 |         return self.__class__.__name__ + '(i={0},j={1},h={2},w={3})'.format(
560 |             self.i, self.j, self.h, self.w)
561 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @Time : 2019/1/21 15:25
4 | @Author : Wang Xin
5 | @Email : wangxin_buaa@163.com
6 | """
7 |
8 | from datetime import datetime
9 | import shutil
10 | import socket
11 | import time
12 | import torch
13 | from tensorboardX import SummaryWriter
14 | from torch.optim import lr_scheduler
15 |
16 | from dataloaders import kitti_dataloader, nyu_dataloader
17 | from dataloaders.path import Path
18 | from metrics import AverageMeter, Result
19 | import utils
20 | import criteria
21 | import os
22 | import torch.nn as nn
23 |
24 | import numpy as np
25 |
26 | from network import FCRN
27 |
28 | # os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # use single GPU
29 |
30 | args = utils.parse_command()
31 | print(args)
32 |
33 | best_result = Result()
34 | best_result.set_to_worst()
35 |
36 |
37 | def create_loader(args):
38 |     traindir = os.path.join(Path.db_root_dir(args.dataset), 'train')
39 |     if os.path.exists(traindir):
40 |         print('Train dataset "{}" exists.'.format(traindir))
41 |     else:
42 |         print('Train dataset "{}" does not exist!'.format(traindir))
43 |         exit(-1)
44 |
45 |     valdir = os.path.join(Path.db_root_dir(args.dataset), 'val')
46 |     if os.path.exists(valdir):
47 |         print('Val dataset "{}" exists.'.format(valdir))
48 |     else:
49 |         print('Val dataset "{}" does not exist!'.format(valdir))
50 |         exit(-1)
51 |
52 |     if args.dataset == 'kitti':
53 |         train_set = kitti_dataloader.KITTIDataset(traindir, type='train')
54 |         val_set = kitti_dataloader.KITTIDataset(valdir, type='val')
55 |
56 |         # sample 3200 pictures for validation from val set
57 |         weights = [1 for i in range(len(val_set))]
58 |         print('weights:', len(weights))
59 |         sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=3200)
60 |     elif args.dataset == 'nyu':
61 |         train_set = nyu_dataloader.NYUDataset(traindir, type='train')
62 |         val_set = nyu_dataloader.NYUDataset(valdir, type='val')
63 |     else:
64 |         print('Unknown dataset:', args.dataset)
65 |         exit(-1)
66 |
67 |     train_loader = torch.utils.data.DataLoader(
68 |         train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
69 |
70 |     if args.dataset == 'kitti':
71 |         val_loader = torch.utils.data.DataLoader(
72 |             val_set, batch_size=args.batch_size, sampler=sampler, num_workers=args.workers, pin_memory=True)
73 |     else:
74 |         val_loader = torch.utils.data.DataLoader(
75 |             val_set, batch_size=1, shuffle=False, num_workers=args.workers, pin_memory=True)
76 |
77 |     return train_loader, val_loader
78 |
79 |
80 | def main():
81 |     global args, best_result, output_directory
82 |
83 |     # set random seed
84 |     torch.manual_seed(args.manual_seed)
85 |
86 |     if torch.cuda.device_count() > 1:
87 |         print("Let's use", torch.cuda.device_count(), "GPUs!")
88 |         args.batch_size = args.batch_size * torch.cuda.device_count()
89 |     else:
90 |         print("Let's use GPU ", torch.cuda.current_device())
91 |
92 |     train_loader, val_loader = create_loader(args)
93 |
94 |     if args.resume:
95 |         assert os.path.isfile(args.resume), \
96 |             "=> no checkpoint found at '{}'".format(args.resume)
97 |         print("=> loading checkpoint '{}'".format(args.resume))
98 |         checkpoint = torch.load(args.resume)
99 |
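        # Note: the checkpoint written by utils.save_checkpoint (used further down in this file)
        # stores 'args', 'epoch', 'model', 'best_result' and 'optimizer'; they are unpacked below.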
100 | start_epoch = checkpoint['epoch'] + 1 101 | best_result = checkpoint['best_result'] 102 | optimizer = checkpoint['optimizer'] 103 | 104 | # model_dict = checkpoint['model'].module.state_dict() # to load the trained model using multi-GPUs 105 | # model = FCRN.ResNet(output_size=train_loader.dataset.output_size, pretrained=False) 106 | # model.load_state_dict(model_dict) 107 | 108 | # solve 'out of memory' 109 | model = checkpoint['model'] 110 | 111 | print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch'])) 112 | 113 | # clear memory 114 | del checkpoint 115 | # del model_dict 116 | torch.cuda.empty_cache() 117 | else: 118 | print("=> creating Model") 119 | model = FCRN.ResNet(output_size=train_loader.dataset.output_size) 120 | print("=> model created.") 121 | start_epoch = 0 122 | 123 | # different modules have different learning rate 124 | train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr}, 125 | {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}] 126 | 127 | optimizer = torch.optim.SGD(train_params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 128 | 129 | # You can use DataParallel() whether you use Multi-GPUs or not 130 | model = nn.DataParallel(model).cuda() 131 | 132 | # when training, use reduceLROnPlateau to reduce learning rate 133 | scheduler = lr_scheduler.ReduceLROnPlateau( 134 | optimizer, 'min', patience=args.lr_patience) 135 | 136 | # loss function 137 | criterion = criteria.MaskedL1Loss() 138 | 139 | # create directory path 140 | output_directory = utils.get_output_directory(args) 141 | if not os.path.exists(output_directory): 142 | os.makedirs(output_directory) 143 | best_txt = os.path.join(output_directory, 'best.txt') 144 | config_txt = os.path.join(output_directory, 'config.txt') 145 | 146 | # write training parameters to config file 147 | if not os.path.exists(config_txt): 148 | with open(config_txt, 'w') as txtfile: 149 | args_ = vars(args) 150 | args_str = '' 151 | for k, v in args_.items(): 152 | args_str = args_str + str(k) + ':' + str(v) + ',\t\n' 153 | txtfile.write(args_str) 154 | 155 | # create log 156 | log_path = os.path.join(output_directory, 'logs', 157 | datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname()) 158 | if os.path.isdir(log_path): 159 | shutil.rmtree(log_path) 160 | os.makedirs(log_path) 161 | logger = SummaryWriter(log_path) 162 | 163 | for epoch in range(start_epoch, args.epochs): 164 | 165 | # remember change of the learning rate 166 | for i, param_group in enumerate(optimizer.param_groups): 167 | old_lr = float(param_group['lr']) 168 | logger.add_scalar('Lr/lr_' + str(i), old_lr, epoch) 169 | 170 | train(train_loader, model, criterion, optimizer, epoch, logger) # train for one epoch 171 | result, img_merge = validate(val_loader, model, epoch, logger) # evaluate on validation set 172 | 173 | # remember best rmse and save checkpoint 174 | is_best = result.rmse < best_result.rmse 175 | if is_best: 176 | best_result = result 177 | with open(best_txt, 'w') as txtfile: 178 | txtfile.write( 179 | "epoch={}, rmse={:.3f}, rml={:.3f}, log10={:.3f}, d1={:.3f}, d2={:.3f}, dd31={:.3f}, " 180 | "t_gpu={:.4f}". 
181 | format(epoch, result.rmse, result.absrel, result.lg10, result.delta1, result.delta2, 182 | result.delta3, 183 | result.gpu_time)) 184 | if img_merge is not None: 185 | img_filename = output_directory + '/comparison_best.png' 186 | utils.save_image(img_merge, img_filename) 187 | 188 | # save checkpoint for each epoch 189 | utils.save_checkpoint({ 190 | 'args': args, 191 | 'epoch': epoch, 192 | 'model': model, 193 | 'best_result': best_result, 194 | 'optimizer': optimizer, 195 | }, is_best, epoch, output_directory) 196 | 197 | # when rml doesn't fall, reduce learning rate 198 | scheduler.step(result.absrel) 199 | 200 | logger.close() 201 | 202 | 203 | # train 204 | def train(train_loader, model, criterion, optimizer, epoch, logger): 205 | average_meter = AverageMeter() 206 | model.train() # switch to train mode 207 | end = time.time() 208 | 209 | batch_num = len(train_loader) 210 | 211 | for i, (input, target) in enumerate(train_loader): 212 | 213 | # itr_count += 1 214 | input, target = input.cuda(), target.cuda() 215 | # print('input size = ', input.size()) 216 | # print('target size = ', target.size()) 217 | torch.cuda.synchronize() 218 | data_time = time.time() - end 219 | 220 | # compute pred 221 | end = time.time() 222 | 223 | pred = model(input) # @wx 注意输出 224 | 225 | # print('pred size = ', pred.size()) 226 | # print('target size = ', target.size()) 227 | 228 | loss = criterion(pred, target) 229 | optimizer.zero_grad() 230 | loss.backward() # compute gradient and do SGD step 231 | optimizer.step() 232 | torch.cuda.synchronize() 233 | gpu_time = time.time() - end 234 | 235 | # measure accuracy and record loss 236 | result = Result() 237 | result.evaluate(pred.data, target.data) 238 | average_meter.update(result, gpu_time, data_time, input.size(0)) 239 | end = time.time() 240 | 241 | if (i + 1) % args.print_freq == 0: 242 | print('=> output: {}'.format(output_directory)) 243 | print('Train Epoch: {0} [{1}/{2}]\t' 244 | 't_Data={data_time:.3f}({average.data_time:.3f}) ' 245 | 't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\n\t' 246 | 'Loss={Loss:.5f} ' 247 | 'RMSE={result.rmse:.2f}({average.rmse:.2f}) ' 248 | 'RML={result.absrel:.2f}({average.absrel:.2f}) ' 249 | 'Log10={result.lg10:.3f}({average.lg10:.3f}) ' 250 | 'Delta1={result.delta1:.3f}({average.delta1:.3f}) ' 251 | 'Delta2={result.delta2:.3f}({average.delta2:.3f}) ' 252 | 'Delta3={result.delta3:.3f}({average.delta3:.3f})'.format( 253 | epoch, i + 1, len(train_loader), data_time=data_time, 254 | gpu_time=gpu_time, Loss=loss.item(), result=result, average=average_meter.average())) 255 | current_step = epoch * batch_num + i 256 | logger.add_scalar('Train/RMSE', result.rmse, current_step) 257 | logger.add_scalar('Train/rml', result.absrel, current_step) 258 | logger.add_scalar('Train/Log10', result.lg10, current_step) 259 | logger.add_scalar('Train/Delta1', result.delta1, current_step) 260 | logger.add_scalar('Train/Delta2', result.delta2, current_step) 261 | logger.add_scalar('Train/Delta3', result.delta3, current_step) 262 | 263 | avg = average_meter.average() 264 | 265 | 266 | # validation 267 | def validate(val_loader, model, epoch, logger): 268 | average_meter = AverageMeter() 269 | 270 | model.eval() # switch to evaluate mode 271 | 272 | end = time.time() 273 | 274 | skip = len(val_loader) // 9 # save images every skip iters 275 | 276 | for i, (input, target) in enumerate(val_loader): 277 | 278 | input, target = input.cuda(), target.cuda() 279 | torch.cuda.synchronize() 280 | data_time = time.time() - end 281 | 282 | # 
compute output 283 | end = time.time() 284 | with torch.no_grad(): 285 | pred = model(input) 286 | 287 | torch.cuda.synchronize() 288 | gpu_time = time.time() - end 289 | 290 | # measure accuracy and record loss 291 | result = Result() 292 | result.evaluate(pred.data, target.data) 293 | 294 | average_meter.update(result, gpu_time, data_time, input.size(0)) 295 | end = time.time() 296 | 297 | # save 8 images for visualization 298 | if args.dataset == 'kitti': 299 | rgb = input[0] 300 | pred = pred[0] 301 | target = target[0] 302 | else: 303 | rgb = input 304 | 305 | if i == 0: 306 | img_merge = utils.merge_into_row(rgb, target, pred) 307 | elif (i < 8 * skip) and (i % skip == 0): 308 | row = utils.merge_into_row(rgb, target, pred) 309 | img_merge = utils.add_row(img_merge, row) 310 | elif i == 8 * skip: 311 | filename = output_directory + '/comparison_' + str(epoch) + '.png' 312 | utils.save_image(img_merge, filename) 313 | 314 | if (i + 1) % args.print_freq == 0: 315 | print('Test: [{0}/{1}]\t' 316 | 't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\n\t' 317 | 'RMSE={result.rmse:.2f}({average.rmse:.2f}) ' 318 | 'RML={result.absrel:.2f}({average.absrel:.2f}) ' 319 | 'Log10={result.lg10:.3f}({average.lg10:.3f}) ' 320 | 'Delta1={result.delta1:.3f}({average.delta1:.3f}) ' 321 | 'Delta2={result.delta2:.3f}({average.delta2:.3f}) ' 322 | 'Delta3={result.delta3:.3f}({average.delta3:.3f})'.format( 323 | i + 1, len(val_loader), gpu_time=gpu_time, result=result, average=average_meter.average())) 324 | 325 | avg = average_meter.average() 326 | 327 | print('\n*\n' 328 | 'RMSE={average.rmse:.3f}\n' 329 | 'Rel={average.absrel:.3f}\n' 330 | 'Log10={average.lg10:.3f}\n' 331 | 'Delta1={average.delta1:.3f}\n' 332 | 'Delta2={average.delta2:.3f}\n' 333 | 'Delta3={average.delta3:.3f}\n' 334 | 't_GPU={time:.3f}\n'.format( 335 | average=avg, time=avg.gpu_time)) 336 | 337 | logger.add_scalar('Test/rmse', avg.rmse, epoch) 338 | logger.add_scalar('Test/Rel', avg.absrel, epoch) 339 | logger.add_scalar('Test/log10', avg.lg10, epoch) 340 | logger.add_scalar('Test/Delta1', avg.delta1, epoch) 341 | logger.add_scalar('Test/Delta2', avg.delta2, epoch) 342 | logger.add_scalar('Test/Delta3', avg.delta3, epoch) 343 | return avg, img_merge 344 | 345 | 346 | if __name__ == '__main__': 347 | main() 348 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2018/10/23 19:53 3 | # @Author : Wang Xin 4 | # @Email : wangxin_buaa@163.com 5 | 6 | 7 | import torch 8 | import math 9 | import numpy as np 10 | 11 | 12 | def log10(x): 13 | """Convert a new tensor with the base-10 logarithm of the elements of x. 
""" 14 | return torch.log(x) / math.log(10) 15 | 16 | 17 | class Result(object): 18 | def __init__(self): 19 | self.irmse, self.imae = 0, 0 20 | self.mse, self.rmse, self.mae = 0, 0, 0 21 | self.absrel, self.lg10 = 0, 0 22 | self.delta1, self.delta2, self.delta3 = 0, 0, 0 23 | self.data_time, self.gpu_time = 0, 0 24 | 25 | def set_to_worst(self): 26 | self.irmse, self.imae = np.inf, np.inf 27 | self.mse, self.rmse, self.mae = np.inf, np.inf, np.inf 28 | self.absrel, self.lg10 = np.inf, np.inf 29 | self.delta1, self.delta2, self.delta3 = 0, 0, 0 30 | self.data_time, self.gpu_time = 0, 0 31 | 32 | def update(self, irmse, imae, mse, rmse, mae, absrel, lg10, delta1, delta2, delta3, gpu_time, data_time): 33 | self.irmse, self.imae = irmse, imae 34 | self.mse, self.rmse, self.mae = mse, rmse, mae 35 | self.absrel, self.lg10 = absrel, lg10 36 | self.delta1, self.delta2, self.delta3 = delta1, delta2, delta3 37 | self.data_time, self.gpu_time = data_time, gpu_time 38 | 39 | def evaluate(self, output, target): 40 | valid_mask = target > 0 41 | output = output[valid_mask] 42 | target = target[valid_mask] 43 | 44 | abs_diff = (output - target).abs() 45 | 46 | self.mse = float((torch.pow(abs_diff, 2)).mean()) 47 | self.rmse = math.sqrt(self.mse) 48 | self.mae = float(abs_diff.mean()) 49 | self.lg10 = float((log10(output) - log10(target)).abs().mean()) 50 | self.absrel = float((abs_diff / target).mean()) 51 | 52 | maxRatio = torch.max(output / target, target / output) 53 | self.delta1 = float((maxRatio < 1.25).float().mean()) 54 | self.delta2 = float((maxRatio < 1.25 ** 2).float().mean()) 55 | self.delta3 = float((maxRatio < 1.25 ** 3).float().mean()) 56 | self.data_time = 0 57 | self.gpu_time = 0 58 | 59 | inv_output = 1 / output 60 | inv_target = 1 / target 61 | abs_inv_diff = (inv_output - inv_target).abs() 62 | self.irmse = math.sqrt((torch.pow(abs_inv_diff, 2)).mean()) 63 | self.imae = float(abs_inv_diff.mean()) 64 | 65 | 66 | class AverageMeter(object): 67 | def __init__(self): 68 | self.reset() 69 | 70 | def reset(self): 71 | self.count = 0.0 72 | 73 | self.sum_irmse, self.sum_imae = 0, 0 74 | self.sum_mse, self.sum_rmse, self.sum_mae = 0, 0, 0 75 | self.sum_absrel, self.sum_lg10 = 0, 0 76 | self.sum_delta1, self.sum_delta2, self.sum_delta3 = 0, 0, 0 77 | self.sum_data_time, self.sum_gpu_time = 0, 0 78 | 79 | def update(self, result, gpu_time, data_time, n=1): 80 | self.count += n 81 | 82 | self.sum_irmse += n * result.irmse 83 | self.sum_imae += n * result.imae 84 | self.sum_mse += n * result.mse 85 | self.sum_rmse += n * result.rmse 86 | self.sum_mae += n * result.mae 87 | self.sum_absrel += n * result.absrel 88 | self.sum_lg10 += n * result.lg10 89 | self.sum_delta1 += n * result.delta1 90 | self.sum_delta2 += n * result.delta2 91 | self.sum_delta3 += n * result.delta3 92 | self.sum_data_time += n * data_time 93 | self.sum_gpu_time += n * gpu_time 94 | 95 | def average(self): 96 | avg = Result() 97 | avg.update( 98 | self.sum_irmse / self.count, self.sum_imae / self.count, 99 | self.sum_mse / self.count, self.sum_rmse / self.count, self.sum_mae / self.count, 100 | self.sum_absrel / self.count, self.sum_lg10 / self.count, 101 | self.sum_delta1 / self.count, self.sum_delta2 / self.count, self.sum_delta3 / self.count, 102 | self.sum_gpu_time / self.count, self.sum_data_time / self.count) 103 | return avg 104 | -------------------------------------------------------------------------------- /network/FCRN.py: -------------------------------------------------------------------------------- 1 | # -*- 
coding: utf-8 -*- 2 | # @Time : 2018/11/22 12:33 3 | # @Author : Wang Xin 4 | # @Email : wangxin_buaa@163.com 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torchvision.models 10 | import collections 11 | import math 12 | 13 | 14 | def weights_init(m): 15 | # Initialize filters with Gaussian random weights 16 | if isinstance(m, nn.Conv2d): 17 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 18 | m.weight.data.normal_(0, math.sqrt(2. / n)) 19 | if m.bias is not None: 20 | m.bias.data.zero_() 21 | elif isinstance(m, nn.ConvTranspose2d): 22 | n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels 23 | m.weight.data.normal_(0, math.sqrt(2. / n)) 24 | if m.bias is not None: 25 | m.bias.data.zero_() 26 | elif isinstance(m, nn.BatchNorm2d): 27 | m.weight.data.fill_(1) 28 | m.bias.data.zero_() 29 | 30 | 31 | class Unpool(nn.Module): 32 | # Unpool: 2*2 unpooling with zero padding 33 | def __init__(self, num_channels, stride=2): 34 | super(Unpool, self).__init__() 35 | 36 | self.num_channels = num_channels 37 | self.stride = stride 38 | 39 | def forward(self, x): 40 | weights = torch.zeros(self.num_channels, 1, self.stride, self.stride) 41 | if torch.cuda.is_available(): 42 | weights = weights.cuda() 43 | weights[:, :, 0, 0] = 1 44 | return F.conv_transpose2d(x, weights, stride=self.stride, groups=self.num_channels) 45 | 46 | 47 | class Decoder(nn.Module): 48 | # Decoder is the base class for all decoders 49 | 50 | names = ['deconv2', 'deconv3', 'upconv', 'upproj'] 51 | 52 | def __init__(self): 53 | super(Decoder, self).__init__() 54 | 55 | self.layer1 = None 56 | self.layer2 = None 57 | self.layer3 = None 58 | self.layer4 = None 59 | 60 | def forward(self, x): 61 | x = self.layer1(x) 62 | x = self.layer2(x) 63 | x = self.layer3(x) 64 | x = self.layer4(x) 65 | return x 66 | 67 | 68 | class DeConv(Decoder): 69 | def __init__(self, in_channels, kernel_size): 70 | assert kernel_size >= 2, "kernel_size out of range: {}".format(kernel_size) 71 | super(DeConv, self).__init__() 72 | 73 | def convt(in_channels): 74 | stride = 2 75 | padding = (kernel_size - 1) // 2 76 | output_padding = kernel_size % 2 77 | assert -2 - 2 * padding + kernel_size + output_padding == 0, "deconv parameters incorrect" 78 | 79 | module_name = "deconv{}".format(kernel_size) 80 | return nn.Sequential(collections.OrderedDict([ 81 | (module_name, nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size, 82 | stride, padding, output_padding, bias=False)), 83 | ('batchnorm', nn.BatchNorm2d(in_channels // 2)), 84 | ('relu', nn.ReLU(inplace=True)), 85 | ])) 86 | 87 | self.layer1 = convt(in_channels) 88 | self.layer2 = convt(in_channels // 2) 89 | self.layer3 = convt(in_channels // (2 ** 2)) 90 | self.layer4 = convt(in_channels // (2 ** 3)) 91 | 92 | 93 | class UpConv(Decoder): 94 | # UpConv decoder consists of 4 upconv modules with decreasing number of channels and increasing feature map size 95 | def upconv_module(self, in_channels): 96 | # UpConv module: unpool -> 5*5 conv -> batchnorm -> ReLU 97 | upconv = nn.Sequential(collections.OrderedDict([ 98 | ('unpool', Unpool(in_channels)), 99 | ('conv', nn.Conv2d(in_channels, in_channels // 2, kernel_size=5, stride=1, padding=2, bias=False)), 100 | ('batchnorm', nn.BatchNorm2d(in_channels // 2)), 101 | ('relu', nn.ReLU()), 102 | ])) 103 | return upconv 104 | 105 | def __init__(self, in_channels): 106 | super(UpConv, self).__init__() 107 | self.layer1 = self.upconv_module(in_channels) 108 | self.layer2 = 
self.upconv_module(in_channels // 2) 109 | self.layer3 = self.upconv_module(in_channels // 4) 110 | self.layer4 = self.upconv_module(in_channels // 8) 111 | 112 | 113 | class FasterUpConv(Decoder): 114 | # Faster Upconv using pixelshuffle 115 | 116 | class faster_upconv_module(nn.Module): 117 | 118 | def __init__(self, in_channel): 119 | super(FasterUpConv.faster_upconv_module, self).__init__() 120 | 121 | self.conv1_ = nn.Sequential(collections.OrderedDict([ 122 | ('conv1', nn.Conv2d(in_channel, in_channel // 2, kernel_size=3)), 123 | ('bn1', nn.BatchNorm2d(in_channel // 2)), 124 | ])) 125 | 126 | self.conv2_ = nn.Sequential(collections.OrderedDict([ 127 | ('conv1', nn.Conv2d(in_channel, in_channel // 2, kernel_size=(2, 3))), 128 | ('bn1', nn.BatchNorm2d(in_channel // 2)), 129 | ])) 130 | 131 | self.conv3_ = nn.Sequential(collections.OrderedDict([ 132 | ('conv1', nn.Conv2d(in_channel, in_channel // 2, kernel_size=(3, 2))), 133 | ('bn1', nn.BatchNorm2d(in_channel // 2)), 134 | ])) 135 | 136 | self.conv4_ = nn.Sequential(collections.OrderedDict([ 137 | ('conv1', nn.Conv2d(in_channel, in_channel // 2, kernel_size=2)), 138 | ('bn1', nn.BatchNorm2d(in_channel // 2)), 139 | ])) 140 | 141 | self.ps = nn.PixelShuffle(2) 142 | self.relu = nn.ReLU(inplace=True) 143 | 144 | def forward(self, x): 145 | # print('Upmodule x size = ', x.size()) 146 | x1 = self.conv1_(nn.functional.pad(x, (1, 1, 1, 1))) 147 | x2 = self.conv2_(nn.functional.pad(x, (1, 1, 0, 1))) 148 | x3 = self.conv3_(nn.functional.pad(x, (0, 1, 1, 1))) 149 | x4 = self.conv4_(nn.functional.pad(x, (0, 1, 0, 1))) 150 | 151 | x = torch.cat((x1, x2, x3, x4), dim=1) 152 | 153 | output = self.ps(x) 154 | output = self.relu(output) 155 | 156 | return output 157 | 158 | def __init__(self, in_channel): 159 | super(FasterUpConv, self).__init__() 160 | 161 | self.layer1 = self.faster_upconv_module(in_channel) 162 | self.layer2 = self.faster_upconv_module(in_channel // 2) 163 | self.layer3 = self.faster_upconv_module(in_channel // 4) 164 | self.layer4 = self.faster_upconv_module(in_channel // 8) 165 | 166 | 167 | class UpProj(Decoder): 168 | # UpProj decoder consists of 4 upproj modules with decreasing number of channels and increasing feature map size 169 | 170 | class UpProjModule(nn.Module): 171 | # UpProj module has two branches, with a Unpool at the start and a ReLu at the end 172 | # upper branch: 5*5 conv -> batchnorm -> ReLU -> 3*3 conv -> batchnorm 173 | # bottom branch: 5*5 conv -> batchnorm 174 | 175 | def __init__(self, in_channels): 176 | super(UpProj.UpProjModule, self).__init__() 177 | out_channels = in_channels // 2 178 | self.unpool = Unpool(in_channels) 179 | self.upper_branch = nn.Sequential(collections.OrderedDict([ 180 | ('conv1', nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=2, bias=False)), 181 | ('batchnorm1', nn.BatchNorm2d(out_channels)), 182 | ('relu', nn.ReLU()), 183 | ('conv2', nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)), 184 | ('batchnorm2', nn.BatchNorm2d(out_channels)), 185 | ])) 186 | self.bottom_branch = nn.Sequential(collections.OrderedDict([ 187 | ('conv', nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=2, bias=False)), 188 | ('batchnorm', nn.BatchNorm2d(out_channels)), 189 | ])) 190 | self.relu = nn.ReLU() 191 | 192 | def forward(self, x): 193 | x = self.unpool(x) 194 | x1 = self.upper_branch(x) 195 | x2 = self.bottom_branch(x) 196 | x = x1 + x2 197 | x = self.relu(x) 198 | return x 199 | 200 | def __init__(self, in_channels): 
201 |         super(UpProj, self).__init__()
202 |         self.layer1 = self.UpProjModule(in_channels)
203 |         self.layer2 = self.UpProjModule(in_channels // 2)
204 |         self.layer3 = self.UpProjModule(in_channels // 4)
205 |         self.layer4 = self.UpProjModule(in_channels // 8)
206 | 
207 | 
208 | class FasterUpProj(Decoder):
209 |     # Faster UpProj decoder using PixelShuffle
210 | 
211 |     class faster_upconv(nn.Module):
212 | 
213 |         def __init__(self, in_channel):
214 |             super(FasterUpProj.faster_upconv, self).__init__()
215 | 
216 |             self.conv1_ = nn.Sequential(collections.OrderedDict([
217 |                 ('conv1', nn.Conv2d(in_channel, in_channel // 2, kernel_size=3)),
218 |                 ('bn1', nn.BatchNorm2d(in_channel // 2)),
219 |             ]))
220 | 
221 |             self.conv2_ = nn.Sequential(collections.OrderedDict([
222 |                 ('conv1', nn.Conv2d(in_channel, in_channel // 2, kernel_size=(2, 3))),
223 |                 ('bn1', nn.BatchNorm2d(in_channel // 2)),
224 |             ]))
225 | 
226 |             self.conv3_ = nn.Sequential(collections.OrderedDict([
227 |                 ('conv1', nn.Conv2d(in_channel, in_channel // 2, kernel_size=(3, 2))),
228 |                 ('bn1', nn.BatchNorm2d(in_channel // 2)),
229 |             ]))
230 | 
231 |             self.conv4_ = nn.Sequential(collections.OrderedDict([
232 |                 ('conv1', nn.Conv2d(in_channel, in_channel // 2, kernel_size=2)),
233 |                 ('bn1', nn.BatchNorm2d(in_channel // 2)),
234 |             ]))
235 | 
236 |             self.ps = nn.PixelShuffle(2)
237 |             self.relu = nn.ReLU(inplace=True)
238 | 
239 |         def forward(self, x):
240 |             # print('Upmodule x size = ', x.size())
241 |             x1 = self.conv1_(nn.functional.pad(x, (1, 1, 1, 1)))
242 |             x2 = self.conv2_(nn.functional.pad(x, (1, 1, 0, 1)))
243 |             x3 = self.conv3_(nn.functional.pad(x, (0, 1, 1, 1)))
244 |             x4 = self.conv4_(nn.functional.pad(x, (0, 1, 0, 1)))
245 |             # print(x1.size(), x2.size(), x3.size(), x4.size())
246 | 
247 |             x = torch.cat((x1, x2, x3, x4), dim=1)
248 | 
249 |             x = self.ps(x)
250 |             return x
251 | 
252 |     class FasterUpProjModule(nn.Module):
253 |         def __init__(self, in_channels):
254 |             super(FasterUpProj.FasterUpProjModule, self).__init__()
255 |             out_channels = in_channels // 2
256 | 
257 |             self.upper_branch = nn.Sequential(collections.OrderedDict([
258 |                 ('faster_upconv', FasterUpProj.faster_upconv(in_channels)),
259 |                 ('relu', nn.ReLU(inplace=True)),
260 |                 ('conv', nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)),
261 |                 ('batchnorm', nn.BatchNorm2d(out_channels)),
262 |             ]))
263 |             self.bottom_branch = FasterUpProj.faster_upconv(in_channels)
264 |             self.relu = nn.ReLU(inplace=True)
265 | 
266 |         def forward(self, x):
267 |             x1 = self.upper_branch(x)
268 |             x2 = self.bottom_branch(x)
269 |             x = x1 + x2
270 |             x = self.relu(x)
271 |             return x
272 | 
273 |     def __init__(self, in_channel):
274 |         super(FasterUpProj, self).__init__()
275 | 
276 |         self.layer1 = self.FasterUpProjModule(in_channel)
277 |         self.layer2 = self.FasterUpProjModule(in_channel // 2)
278 |         self.layer3 = self.FasterUpProjModule(in_channel // 4)
279 |         self.layer4 = self.FasterUpProjModule(in_channel // 8)
280 | 
281 | 
282 | def choose_decoder(decoder, in_channels):
283 |     if decoder[:6] == 'deconv':
284 |         assert len(decoder) == 7
285 |         kernel_size = int(decoder[6])
286 |         return DeConv(in_channels, kernel_size)
287 |     elif decoder == "upproj":
288 |         return UpProj(in_channels)
289 |     elif decoder == "upconv":
290 |         return UpConv(in_channels)
291 |     elif decoder == "fasterupproj":
292 |         return FasterUpProj(in_channels)
293 |     else:
294 |         assert False, "invalid option for decoder: {}".format(decoder)
295 | 
296 | 
297 | class ResNet(nn.Module):
298 |     def __init__(self, dataset='kitti', layers=50, decoder='upproj', output_size=(228, 304), in_channels=3, pretrained=True):
299 | 
300 |         if layers not in [18, 34, 50, 101, 152]:
301 |             raise RuntimeError(
302 |                 'Only 18, 34, 50, 101, and 152 layer models are defined for ResNet. Got {}'.format(layers))
303 | 
304 |         super(ResNet, self).__init__()
305 |         pretrained_model = torchvision.models.__dict__['resnet{}'.format(layers)](pretrained=pretrained)
306 | 
307 |         if in_channels == 3:
308 |             self.conv1 = pretrained_model._modules['conv1']
309 |             self.bn1 = pretrained_model._modules['bn1']
310 |         else:
311 |             self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
312 |             self.bn1 = nn.BatchNorm2d(64)
313 |             weights_init(self.conv1)
314 |             weights_init(self.bn1)
315 | 
316 |         self.output_size = output_size
317 | 
318 |         self.relu = pretrained_model._modules['relu']
319 |         self.maxpool = pretrained_model._modules['maxpool']
320 |         self.layer1 = pretrained_model._modules['layer1']
321 |         self.layer2 = pretrained_model._modules['layer2']
322 |         self.layer3 = pretrained_model._modules['layer3']
323 |         self.layer4 = pretrained_model._modules['layer4']
324 | 
325 |         # clear memory
326 |         del pretrained_model
327 | 
328 |         # define number of intermediate channels
329 |         if layers <= 34:
330 |             num_channels = 512
331 |         elif layers >= 50:
332 |             num_channels = 2048
333 | 
334 |         self.conv2 = nn.Conv2d(num_channels, num_channels // 2, kernel_size=1, bias=False)
335 |         self.bn2 = nn.BatchNorm2d(num_channels // 2)
336 | 
337 |         self.upSample = choose_decoder(decoder, num_channels // 2)
338 | 
339 |         # setting bias=True doesn't improve accuracy
340 |         self.conv3 = nn.Conv2d(num_channels // 32, 1, kernel_size=3, stride=1, padding=1, bias=False)
341 |         self.bilinear = nn.Upsample(size=self.output_size, mode='bilinear', align_corners=True)
342 | 
343 |         # weight init
344 |         self.conv2.apply(weights_init)
345 |         self.bn2.apply(weights_init)
346 | 
347 |         self.upSample.apply(weights_init)
348 | 
349 |         self.conv3.apply(weights_init)
350 | 
351 |     def forward(self, x):
352 |         # ResNet encoder
353 |         x = self.conv1(x)
354 |         x = self.bn1(x)
355 |         x = self.relu(x)
356 |         x = self.maxpool(x)
357 |         x1 = self.layer1(x)
358 |         x2 = self.layer2(x1)
359 |         x3 = self.layer3(x2)
360 |         x4 = self.layer4(x3)
361 | 
362 |         x = self.conv2(x4)
363 |         x = self.bn2(x)
364 | 
365 |         # upsampling (decoder)
366 |         x = self.upSample(x)
367 | 
368 |         x = self.conv3(x)
369 |         x = self.bilinear(x)
370 | 
371 |         return x
372 | 
373 |     def get_1x_lr_params(self):
374 |         """
375 |         This generator returns all the parameters of the pretrained ResNet encoder layers, which are trained with the base learning rate (1x lr).
376 |         """
377 |         b = [self.conv1, self.bn1, self.relu, self.maxpool, self.layer1, self.layer2, self.layer3, self.layer4]
378 |         for i in range(len(b)):
379 |             for k in b[i].parameters():
380 |                 if k.requires_grad:
381 |                     yield k
382 | 
383 |     def get_10x_lr_params(self):
384 |         """
385 |         This generator returns all the parameters of the newly added decoder layers, which are trained with 10x the base learning rate.
386 | """ 387 | b = [self.conv2, self.bn2, self.upSample, self.conv3, self.bilinear] 388 | for j in range(len(b)): 389 | for k in b[j].parameters(): 390 | if k.requires_grad: 391 | yield k 392 | 393 | 394 | import time 395 | 396 | if __name__ == "__main__": 397 | model = ResNet(layers=50, output_size=(228, 912)) 398 | model = model.cuda() 399 | model.eval() 400 | image = torch.randn(8, 3, 228, 912) 401 | image = image.cuda() 402 | 403 | gpu_time = time.time() 404 | with torch.no_grad(): 405 | output = model(image) 406 | gpu_time = time.time() - gpu_time 407 | print('gpu_time = ', gpu_time) 408 | print(output.size()) 409 | print(output[0].size()) 410 | -------------------------------------------------------------------------------- /network/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Time : 2019/1/21 15:11 4 | @Author : Wang Xin 5 | @Email : wangxin_buaa@163.com 6 | """ -------------------------------------------------------------------------------- /result/kitti.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dontLoveBugs/FCRN_pytorch/e42845d3fd1f72772912b5696cff65ef1b06be76/result/kitti.png -------------------------------------------------------------------------------- /result/nyu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dontLoveBugs/FCRN_pytorch/e42845d3fd1f72772912b5696cff65ef1b06be76/result/nyu.png -------------------------------------------------------------------------------- /result/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dontLoveBugs/FCRN_pytorch/e42845d3fd1f72772912b5696cff65ef1b06be76/result/result.png -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2018/10/21 20:57 3 | # @Author : Wang Xin 4 | # @Email : wangxin_buaa@163.com 5 | import glob 6 | import os 7 | import torch 8 | import shutil 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | from PIL import Image 12 | 13 | cmap = plt.cm.jet 14 | 15 | 16 | def parse_command(): 17 | modality_names = ['rgb', 'rgbd', 'd'] 18 | 19 | import argparse 20 | parser = argparse.ArgumentParser(description='FCRN') 21 | parser.add_argument('--decoder', default='upproj', type=str) 22 | parser.add_argument('--resume', 23 | default=None, 24 | type=str, metavar='PATH', 25 | help='path to latest checkpoint (default: ./run/run_1/checkpoint-5.pth.tar)') 26 | parser.add_argument('-b', '--batch-size', default=16, type=int, help='mini-batch size (default: 4)') 27 | parser.add_argument('--epochs', default=200, type=int, metavar='N', 28 | help='number of total epochs to run (default: 15)') 29 | parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, 30 | metavar='LR', help='initial learning rate (default 0.0001)') 31 | parser.add_argument('--lr_patience', default=2, type=int, help='Patience of LR scheduler. 
32 |                         'See documentation of ReduceLROnPlateau.')
33 |     parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
34 |                         help='momentum')
35 |     parser.add_argument('--weight_decay', '--wd', default=0.0005, type=float,
36 |                         metavar='W', help='weight decay (default: 5e-4)')
37 |     parser.add_argument('-j', '--workers', default=0, type=int, metavar='N',
38 |                         help='number of data loading workers (default: 0)')
39 |     parser.add_argument('--dataset', type=str, default="nyu")
40 |     parser.add_argument('--manual_seed', default=1, type=int, help='Manually set random seed')
41 |     parser.add_argument('--print-freq', '-p', default=10, type=int,
42 |                         metavar='N', help='print frequency (default: 10)')
43 |     args = parser.parse_args()
44 |     return args
45 | 
46 | 
47 | def get_output_directory(args):
48 |     if args.resume:
49 |         return os.path.dirname(args.resume)
50 |     else:
51 |         save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__)))
52 |         save_dir_root = os.path.join(save_dir_root, 'result', args.decoder)
53 |         runs = sorted(glob.glob(os.path.join(save_dir_root, 'run_*')))
54 |         run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0
55 | 
56 |         save_dir = os.path.join(save_dir_root, 'run_' + str(run_id))
57 |         return save_dir
58 | 
59 | 
60 | # save a training checkpoint
61 | def save_checkpoint(state, is_best, epoch, output_directory):
62 |     checkpoint_filename = os.path.join(output_directory, 'checkpoint-' + str(epoch) + '.pth.tar')
63 |     torch.save(state, checkpoint_filename)
64 |     if is_best:
65 |         best_filename = os.path.join(output_directory, 'model_best.pth.tar')
66 |         shutil.copyfile(checkpoint_filename, best_filename)
67 | 
68 | 
69 | def colored_depthmap(depth, d_min=None, d_max=None):
70 |     if d_min is None:
71 |         d_min = np.min(depth)
72 |     if d_max is None:
73 |         d_max = np.max(depth)
74 |     depth_relative = (depth - d_min) / (d_max - d_min)
75 |     return 255 * cmap(depth_relative)[:, :, :3]  # H, W, C
76 | 
77 | 
78 | def merge_into_row(input, depth_target, depth_pred):
79 |     rgb = 255 * np.transpose(np.squeeze(input.cpu().numpy()), (1, 2, 0))  # H, W, C
80 |     depth_target_cpu = np.squeeze(depth_target.cpu().numpy())
81 |     depth_pred_cpu = np.squeeze(depth_pred.data.cpu().numpy())
82 | 
83 |     d_min = min(np.min(depth_target_cpu), np.min(depth_pred_cpu))
84 |     d_max = max(np.max(depth_target_cpu), np.max(depth_pred_cpu))
85 |     depth_target_col = colored_depthmap(depth_target_cpu, d_min, d_max)
86 |     depth_pred_col = colored_depthmap(depth_pred_cpu, d_min, d_max)
87 |     img_merge = np.hstack([rgb, depth_target_col, depth_pred_col])
88 | 
89 |     return img_merge
90 | 
91 | 
92 | def merge_into_row_with_gt(input, depth_input, depth_target, depth_pred):
93 |     rgb = 255 * np.transpose(np.squeeze(input.cpu().numpy()), (1, 2, 0))  # H, W, C
94 |     depth_input_cpu = np.squeeze(depth_input.cpu().numpy())
95 |     depth_target_cpu = np.squeeze(depth_target.cpu().numpy())
96 |     depth_pred_cpu = np.squeeze(depth_pred.data.cpu().numpy())
97 | 
98 |     d_min = min(np.min(depth_input_cpu), np.min(depth_target_cpu), np.min(depth_pred_cpu))
99 |     d_max = max(np.max(depth_input_cpu), np.max(depth_target_cpu), np.max(depth_pred_cpu))
100 |     depth_input_col = colored_depthmap(depth_input_cpu, d_min, d_max)
101 |     depth_target_col = colored_depthmap(depth_target_cpu, d_min, d_max)
102 |     depth_pred_col = colored_depthmap(depth_pred_cpu, d_min, d_max)
103 | 
104 |     img_merge = np.hstack([rgb, depth_input_col, depth_target_col, depth_pred_col])
105 | 
106 |     return img_merge
107 | 
108 | 
109 | def add_row(img_merge, row):
110 |     return np.vstack([img_merge, row])
111 | 
112 | 
113 | def save_image(img_merge, filename):
114 |     img_merge = Image.fromarray(img_merge.astype('uint8'))
115 |     img_merge.save(filename)
--------------------------------------------------------------------------------
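
End of the repository dump. For orientation, the sketch below is illustrative only and is not a file in the repository: it shows one way the pieces above are assumed to fit together -- building the FCRN ResNet model, feeding the 1x/10x parameter groups exposed by `get_1x_lr_params`/`get_10x_lr_params` to an optimizer, and saving a qualitative result with the helpers in `utils.py`. The commented checkpoint path and state-dict key are hypothetical placeholders, and `main.py` (not shown here) may wire these pieces up differently.

```python
# Illustrative sketch only -- not part of the repository. Assumes the repo root is
# on PYTHONPATH; the checkpoint path/key below are hypothetical placeholders.
import torch
import utils
from network.FCRN import ResNet

args = utils.parse_command()

# Build the encoder-decoder model (NYU Depth v2 output resolution).
model = ResNet(layers=50, decoder=args.decoder, output_size=(228, 304), pretrained=False)
# checkpoint = torch.load('result/upproj/run_0/model_best.pth.tar')  # hypothetical path
# model.load_state_dict(checkpoint['model'])                         # key depends on how main.py saves it

# The two generators split encoder vs. decoder parameters so the decoder can be
# trained with 10x the base learning rate, as their docstrings describe.
optimizer = torch.optim.SGD(
    [{'params': model.get_1x_lr_params(), 'lr': args.lr},
     {'params': model.get_10x_lr_params(), 'lr': 10 * args.lr}],
    momentum=args.momentum, weight_decay=args.weight_decay)

# Qualitative visualization with the utils.py helpers, using random stand-in data.
model.eval()
rgb = torch.rand(1, 3, 228, 304)     # stand-in RGB input
target = torch.rand(1, 1, 228, 304)  # stand-in ground-truth depth
with torch.no_grad():
    pred = model(rgb)
row = utils.merge_into_row(rgb, target, pred)  # RGB | ground truth | prediction
utils.save_image(row, 'result.png')
```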