├── .gitignore
├── README.md
├── criteria.py
├── dataloaders
│   ├── dataloader.py
│   ├── deepscene_dataloader.py
│   ├── dense_to_sparse.py
│   ├── kitti_dataloader.py
│   ├── nyu_dataloader.py
│   ├── sun_dataloader.py
│   ├── transforms.py
│   └── zed_dataloader.py
├── imagenet
│   ├── __init__.py
│   └── mobilenet.py
├── main.py
├── metrics.py
├── models.py
├── models_fast.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | results
2 | data
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | env/
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 |
52 | # Translations
53 | *.mo
54 | *.pot
55 |
56 | # Django stuff:
57 | *.log
58 | local_settings.py
59 |
60 | # Flask stuff:
61 | instance/
62 | .webassets-cache
63 |
64 | # Scrapy stuff:
65 | .scrapy
66 |
67 | # Sphinx documentation
68 | docs/_build/
69 |
70 | # PyBuilder
71 | target/
72 |
73 | # Jupyter Notebook
74 | .ipynb_checkpoints
75 |
76 | # pyenv
77 | .python-version
78 |
79 | # celery beat schedule file
80 | celerybeat-schedule
81 |
82 | # SageMath parsed files
83 | *.sage.py
84 |
85 | # dotenv
86 | .env
87 |
88 | # virtualenv
89 | .venv
90 | venv/
91 | ENV/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | sparse-to-dense.pytorch
2 | ============================
3 |
4 | This repo implements the training and testing of deep regression neural networks for ["Sparse-to-Dense: Depth Prediction from Sparse Depth Samples and a Single Image"](https://arxiv.org/pdf/1709.07492.pdf) by [Fangchang Ma](http://www.mit.edu/~fcma) and [Sertac Karaman](http://karaman.mit.edu/) at MIT. A video demonstration is available on [YouTube](https://youtu.be/vNIIT_M7x7Y).
5 |
6 |
7 |
8 |
9 |
10 | This repo can be used for training and testing of
11 | - RGB (or grayscale image) based depth prediction
12 | - sparse depth based depth prediction
13 | - RGBd (i.e., both RGB and sparse depth) based depth prediction
14 |
15 | The original Torch implementation of the paper can be found [here](https://github.com/fangchangma/sparse-to-dense).
16 |
17 | ## Contents
18 | 0. [Requirements](#requirements)
19 | 0. [Training](#training)
20 | 0. [Testing](#testing)
21 | 0. [Trained Models](#trained-models)
22 | 0. [Benchmark](#benchmark)
23 | 0. [Citation](#citation)
24 |
25 | ## Requirements
26 | This code was tested with Python 3.6 and PyTorch 1.2.0.
27 | - Install [PyTorch](http://pytorch.org/) on a machine with CUDA GPU.
28 | - Install the [HDF5](https://en.wikipedia.org/wiki/Hierarchical_Data_Format) and other dependencies (files in our pre-processed datasets are in HDF5 formats).
29 | ```bash
30 | sudo apt-get update
31 | sudo apt-get install -y libhdf5-serial-dev hdf5-tools
32 | pip3 install h5py matplotlib imageio scipy==1.2.2 scikit-image==0.15.0 opencv-python
33 | ```
34 | - Download the preprocessed [NYU Depth V2](http://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html) and/or [KITTI Odometry](http://www.cvlibs.net/datasets/kitti/eval_odometry.php) dataset in HDF5 formats, and place them under the `data` folder. The downloading process might take an hour or so. The NYU dataset requires 32G of storage space, and KITTI requires 81G.
35 | ```bash
36 | mkdir data; cd data
37 | wget http://datasets.lids.mit.edu/sparse-to-dense/data/nyudepthv2.tar.gz
38 | tar -xvf nyudepthv2.tar.gz && rm -f nyudepthv2.tar.gz
39 | wget http://datasets.lids.mit.edu/sparse-to-dense/data/kitti.tar.gz
40 | tar -xvf kitti.tar.gz && rm -f kitti.tar.gz
41 | cd ..
42 | ```
43 | ## Training
44 | The training scripts come with several options, which can be listed with the `--help` flag.
45 | ```bash
46 | python3 main.py --help
47 | ```
48 |
49 | For instance, run the following command to train a network with ResNet50 as the encoder, deconvolutions of kernel size 3 as the decoder, and both RGB and 100 random sparse depth samples as the input to the network.
50 | ```bash
51 | python3 main.py -a resnet50 -d deconv3 -m rgbd -s 100 --data nyudepthv2
52 | ```
53 |
54 | Training results will be saved under the `results` folder. To resume a previous training, run
55 | ```bash
56 | python3 main.py --resume [path_to_previous_model]
57 | ```
58 |
59 | ## Testing
60 | To test the performance of a trained model without training, simply run main.py with the `-e` option. For instance,
61 | ```bash
62 | python3 main.py --evaluate [path_to_trained_model]
63 | ```
64 |
65 | ## Trained Models
66 | A number of trained models are available [here](http://datasets.lids.mit.edu/sparse-to-dense.pytorch/results/).
67 |
68 | ## Benchmark
69 | The following numbers are from the original Torch repo.
70 | - Error metrics on NYU Depth v2:
71 |
72 | | RGB | rms | rel | delta1 | delta2 | delta3 |
73 | |-----------------------------|:-----:|:-----:|:-----:|:-----:|:-----:|
74 | | [Roy & Todorovic](http://web.engr.oregonstate.edu/~sinisa/research/publications/cvpr16_NRF.pdf) (_CVPR 2016_) | 0.744 | 0.187 | - | - | - |
75 | | [Eigen & Fergus](http://cs.nyu.edu/~deigen/dnl/) (_ICCV 2015_) | 0.641 | 0.158 | 76.9 | 95.0 | 98.8 |
76 | | [Laina et al](https://arxiv.org/pdf/1606.00373.pdf) (_3DV 2016_) | 0.573 | **0.127** | **81.1** | 95.3 | 98.8 |
77 | | Ours-RGB | **0.514** | 0.143 | 81.0 | **95.9** | **98.9** |
78 |
79 | | RGBd-#samples | rms | rel | delta1 | delta2 | delta3 |
80 | |-----------------------------|:-----:|:-----:|:-----:|:-----:|:-----:|
81 | | [Liao et al](https://arxiv.org/abs/1611.02174) (_ICRA 2017_)-225 | 0.442 | 0.104 | 87.8 | 96.4 | 98.9 |
82 | | Ours-20 | 0.351 | 0.078 | 92.8 | 98.4 | 99.6 |
83 | | Ours-50 | 0.281 | 0.059 | 95.5 | 99.0 | 99.7 |
84 | | Ours-200| **0.230** | **0.044** | **97.1** | **99.4** | **99.8** |
85 |
86 |
87 |
88 | - Error metrics on KITTI dataset:
89 |
90 | | RGB | rms | rel | delta1 | delta2 | delta3 |
91 | |-----------------------------|:-----:|:-----:|:-----:|:-----:|:-----:|
92 | | [Make3D](http://papers.nips.cc/paper/5539-depth-map-prediction-from-a-single-image-using-a-multi-scale-deep-network.pdf) | 8.734 | 0.280 | 60.1 | 82.0 | 92.6 |
93 | | [Mancini et al](https://arxiv.org/pdf/1607.06349.pdf) (_IROS 2016_) | 7.508 | - | 31.8 | 61.7 | 81.3 |
94 | | [Eigen et al](http://papers.nips.cc/paper/5539-depth-map-prediction-from-a-single-image-using-a-multi-scale-deep-network.pdf) (_NIPS 2014_) | 7.156 | **0.190** | **69.2** | 89.9 | **96.7** |
95 | | Ours-RGB | **6.266** | 0.208 | 59.1 | **90.0** | 96.2 |
96 |
97 | | RGBd-#samples | rms | rel | delta1 | delta2 | delta3 |
98 | |-----------------------------|:-----:|:-----:|:-----:|:-----:|:-----:|
99 | | [Cadena et al](https://pdfs.semanticscholar.org/18d5/f0747a23706a344f1d15b032ea22795324fa.pdf) (_RSS 2016_)-650 | 7.14 | 0.179 | 70.9 | 88.8 | 95.6 |
100 | | Ours-50 | 4.884 | 0.109 | 87.1 | 95.2 | 97.9 |
101 | | [Liao et al](https://arxiv.org/abs/1611.02174) (_ICRA 2017_)-225 | 4.50 | 0.113 | 87.4 | 96.0 | 98.4 |
102 | | Ours-100 | 4.303 | 0.095 | 90.0 | 96.3 | 98.3 |
103 | | Ours-200 | 3.851 | 0.083 | 91.9 | 97.0 | 98.6 |
104 | | Ours-500| **3.378** | **0.073** | **93.5** | **97.6** | **98.9** |
105 |
106 |
107 |
108 | Note: our networks are trained on the KITTI odometry dataset, using only sparse labels from laser measurements.
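For reference, the `rms`, `rel` and `delta` columns above follow the standard depth-estimation metric definitions; a minimal sketch is below (the repo's own implementation lives in `metrics.py`, not shown in this excerpt, and the function name here is illustrative only).

```python
# Reference sketch of the standard metrics. `pred` and `target` are depth tensors;
# zeros in `target` mark missing ground truth and are excluded.
import torch

def depth_metrics(pred, target):
    valid = target > 0
    pred, target = pred[valid], target[valid]
    rms = torch.sqrt(((pred - target) ** 2).mean())          # root mean squared error
    rel = ((pred - target).abs() / target).mean()            # mean absolute relative error
    ratio = torch.max(pred / target, target / pred)
    deltas = [(ratio < 1.25 ** k).float().mean() * 100 for k in (1, 2, 3)]  # % inlier pixels
    return rms, rel, deltas
```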
109 |
110 | ## Citation
111 | If you use our code or method in your work, please consider citing the following:
112 |
113 | @inproceedings{Ma2017SparseToDense,
114 | title={Sparse-to-Dense: Depth Prediction from Sparse Depth Samples and a Single Image},
115 | author={Ma, Fangchang and Karaman, Sertac},
116 | booktitle={ICRA},
117 | year={2018}
118 | }
119 | @article{ma2018self,
120 | title={Self-supervised Sparse-to-Dense: Self-supervised Depth Completion from LiDAR and Monocular Camera},
121 | author={Ma, Fangchang and Cavalheiro, Guilherme Venturelli and Karaman, Sertac},
122 | journal={arXiv preprint arXiv:1807.00275},
123 | year={2018}
124 | }
125 |
126 | Please create a new issue for code-related questions. Pull requests are welcome.
127 |
--------------------------------------------------------------------------------
/criteria.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 |
5 | class MaskedMSELoss(nn.Module):
6 | def __init__(self):
7 | super(MaskedMSELoss, self).__init__()
8 |
9 | def forward(self, pred, target):
10 | assert pred.dim() == target.dim(), "inconsistent dimensions"
11 | valid_mask = (target>0).detach()
12 | diff = target - pred
13 | diff = diff[valid_mask]
14 | self.loss = (diff ** 2).mean()
15 | return self.loss
16 |
17 | class MaskedL1Loss(nn.Module):
18 | def __init__(self):
19 | super(MaskedL1Loss, self).__init__()
20 |
21 | def forward(self, pred, target):
22 | assert pred.dim() == target.dim(), "inconsistent dimensions"
23 | valid_mask = (target>0).detach()
24 | diff = target - pred
25 | diff = diff[valid_mask]
26 | self.loss = diff.abs().mean()
27 | return self.loss
--------------------------------------------------------------------------------
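A minimal usage sketch for the masked losses above (not part of the repo): pixels where the ground-truth depth is zero, i.e. missing, are excluded from the loss.

```python
# Hypothetical usage of MaskedMSELoss; tensor shapes and the 0.3 cutoff are placeholders.
import torch
from criteria import MaskedMSELoss

criterion = MaskedMSELoss()
pred = torch.rand(4, 1, 228, 304, requires_grad=True)   # network output
target = torch.rand(4, 1, 228, 304)                     # ground-truth depth
target[target < 0.3] = 0                                # simulate missing measurements
loss = criterion(pred, target)                          # mean over pixels with target > 0 only
loss.backward()
```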
/dataloaders/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import numpy as np
4 | import torch.utils.data as data
5 | import h5py
6 | import dataloaders.transforms as transforms
7 |
8 | IMG_EXTENSIONS = ['.h5',]
9 |
10 | def is_image_file(filename):
11 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
12 |
13 | def find_classes(dir):
14 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
15 | classes.sort()
16 | class_to_idx = {classes[i]: i for i in range(len(classes))}
17 | return classes, class_to_idx
18 |
19 | def make_dataset(dir, class_to_idx):
20 | images = []
21 | dir = os.path.expanduser(dir)
22 | for target in sorted(os.listdir(dir)):
23 | d = os.path.join(dir, target)
24 | if not os.path.isdir(d):
25 | continue
26 | for root, _, fnames in sorted(os.walk(d)):
27 | for fname in sorted(fnames):
28 | if is_image_file(fname):
29 | path = os.path.join(root, fname)
30 | item = (path, class_to_idx[target])
31 | images.append(item)
32 | return images
33 |
34 | def h5_loader(path):
35 | h5f = h5py.File(path, "r")
36 | rgb = np.array(h5f['rgb'])
37 | rgb = np.transpose(rgb, (1, 2, 0))
38 | depth = np.array(h5f['depth'])
39 | return rgb, depth
40 |
41 | # def rgb2grayscale(rgb):
42 | # return rgb[:,:,0] * 0.2989 + rgb[:,:,1] * 0.587 + rgb[:,:,2] * 0.114
43 |
44 | to_tensor = transforms.ToTensor()
45 |
46 | class MyDataloader(data.Dataset):
47 | modality_names = ['rgb', 'rgbd', 'd'] # , 'g', 'gd'
48 | color_jitter = transforms.ColorJitter(0.4, 0.4, 0.4)
49 |
50 | def __init__(self, root, type, sparsifier=None, modality='rgb', loader=h5_loader):
51 | classes, class_to_idx = find_classes(root)
52 | imgs = make_dataset(root, class_to_idx)
53 | assert len(imgs)>0, "Found 0 images in subfolders of: " + root + "\n"
54 | print("Found {} images in {} folder.".format(len(imgs), type))
55 | self.root = root
56 | self.imgs = imgs
57 | self.classes = classes
58 | self.class_to_idx = class_to_idx
59 | if type == 'train':
60 | self.transform = self.train_transform
61 | elif type == 'val':
62 | self.transform = self.val_transform
63 | else:
64 | raise (RuntimeError("Invalid dataset type: " + type + "\n"
65 | "Supported dataset types are: train, val"))
66 | self.loader = loader
67 | self.sparsifier = sparsifier
68 |
69 | assert (modality in self.modality_names), "Invalid modality type: " + modality + "\n" + \
70 | "Supported modalities are: " + ', '.join(self.modality_names)
71 | self.modality = modality
72 |
73 | def train_transform(self, rgb, depth):
74 | raise (RuntimeError("train_transform() is not implemented. "))
75 |
76 | def val_transform(self, rgb, depth):
77 | raise (RuntimeError("val_transform() is not implemented."))
78 |
79 | def create_sparse_depth(self, rgb, depth):
80 | if self.sparsifier is None:
81 | return depth
82 | else:
83 | mask_keep = self.sparsifier.dense_to_sparse(rgb, depth)
84 | sparse_depth = np.zeros(depth.shape)
85 | sparse_depth[mask_keep] = depth[mask_keep]
86 | return sparse_depth
87 |
88 | def create_rgbd(self, rgb, depth):
89 | sparse_depth = self.create_sparse_depth(rgb, depth)
90 | rgbd = np.append(rgb, np.expand_dims(sparse_depth, axis=2), axis=2)
91 | return rgbd
92 |
93 | def __getraw__(self, index):
94 | """
95 | Args:
96 | index (int): Index
97 |
98 | Returns:
99 | tuple: (rgb, depth) the raw data.
100 | """
101 | path, target = self.imgs[index]
102 | rgb, depth = self.loader(path)
103 | return rgb, depth
104 |
105 | def __getitem__(self, index):
106 | rgb, depth = self.__getraw__(index)
107 |
108 | #print('{:04d} min={:f} max={:f} shape='.format(index, np.amin(depth), np.amax(depth)) + str(depth.shape))
109 |
110 | if self.transform is not None:
111 | rgb_np, depth_np = self.transform(rgb, depth)
112 | else:
113 | raise(RuntimeError("transform not defined"))
114 |
115 | #print('{:04d} min={:f} max={:f} shape='.format(index, np.amin(depth_np), np.amax(depth_np)) + str(depth_np.shape))
116 |
117 | # color normalization
118 | # rgb_tensor = normalize_rgb(rgb_tensor)
119 | # rgb_np = normalize_np(rgb_np)
120 |
121 | if self.modality == 'rgb':
122 | input_np = rgb_np
123 | elif self.modality == 'rgbd':
124 | input_np = self.create_rgbd(rgb_np, depth_np)
125 | elif self.modality == 'd':
126 | input_np = self.create_sparse_depth(rgb_np, depth_np)
127 |
128 | input_tensor = to_tensor(input_np)
129 | while input_tensor.dim() < 3:
130 | input_tensor = input_tensor.unsqueeze(0)
131 | depth_tensor = to_tensor(depth_np)
132 | #print('{:04d} '.format(index) + str(depth_tensor.shape))
133 | depth_tensor = depth_tensor.unsqueeze(0)
134 | #print('{:04d} '.format(index) + str(depth_tensor.shape))
135 |
136 | return input_tensor, depth_tensor
137 |
138 | def __len__(self):
139 | return len(self.imgs)
140 |
141 | # def __get_all_item__(self, index):
142 | # """
143 | # Args:
144 | # index (int): Index
145 |
146 | # Returns:
147 | # tuple: (input_tensor, depth_tensor, input_np, depth_np)
148 | # """
149 | # rgb, depth = self.__getraw__(index)
150 | # if self.transform is not None:
151 | # rgb_np, depth_np = self.transform(rgb, depth)
152 | # else:
153 | # raise(RuntimeError("transform not defined"))
154 |
155 | # # color normalization
156 | # # rgb_tensor = normalize_rgb(rgb_tensor)
157 | # # rgb_np = normalize_np(rgb_np)
158 |
159 | # if self.modality == 'rgb':
160 | # input_np = rgb_np
161 | # elif self.modality == 'rgbd':
162 | # input_np = self.create_rgbd(rgb_np, depth_np)
163 | # elif self.modality == 'd':
164 | # input_np = self.create_sparse_depth(rgb_np, depth_np)
165 |
166 | # input_tensor = to_tensor(input_np)
167 | # while input_tensor.dim() < 3:
168 | # input_tensor = input_tensor.unsqueeze(0)
169 | # depth_tensor = to_tensor(depth_np)
170 | # depth_tensor = depth_tensor.unsqueeze(0)
171 |
172 | # return input_tensor, depth_tensor, input_np, depth_np
173 |
--------------------------------------------------------------------------------
/dataloaders/deepscene_dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import numpy as np
4 | import dataloaders.transforms as transforms
5 |
6 | from imageio import imread
7 | from torch.utils.data import Dataset, DataLoader
8 |
9 | to_tensor = transforms.ToTensor()
10 |
11 | iheight, iwidth = 472, 872 # original image size (there is some variation in DeepScene dataset)
12 |
13 | class DeepSceneDataset(Dataset):
14 | def __init__(self, root, type='train', train_extra=True):
15 | self.root = root
16 | self.output_size = (224, 224) #(224, 448)
17 |
18 | # search for images
19 | self.rgb_files, self.depth_files = self.gather_images(os.path.join(root, 'rgb'),
20 | os.path.join(root, 'depth_gray'))
21 |
22 | if type == 'train' and train_extra:
23 | extra_root = root + 'extra'
24 | extra_rgb, extra_depth = self.gather_images(os.path.join(extra_root, 'rgb'), os.path.join(extra_root, 'depth_gray'))
25 | self.rgb_files.extend(extra_rgb); self.depth_files.extend(extra_depth)  # fold the extra pairs into the training set
26 |
27 | if len(self.rgb_files) == 0:
28 | raise (RuntimeError("Empty dataset - found no image pairs under \n" + root))
29 |
30 | # determine if 16-bit or 8-bit depth images
31 | self.depth_16 = False
32 |
33 | if imread(self.depth_files[0]).dtype.type is np.uint16:
34 | self.depth_16 = True
35 | self.depth_16_max = 5000 #20000
36 |
37 | print('found {:d} image pairs with {:s}-bit depth under {:s}'.format(len(self.rgb_files), "16" if self.depth_16 else "8", root))
38 |
39 | # setup transforms
40 | if type == 'train':
41 | self.transform = self.train_transform
42 | elif type == 'val':
43 | self.transform = self.val_transform
44 | else:
45 | raise (RuntimeError("Invalid dataset type: " + type + "\n"
46 | "Supported dataset types are: train, val"))
47 |
48 | def gather_images(self, images_path, labels_path):
49 | def sorted_alphanumeric(data):
50 | convert = lambda text: int(text) if text.isdigit() else text.lower()
51 | alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ]
52 | return sorted(data, key=alphanum_key)
53 |
54 | #print('searching for images under: ')
55 | #print(' ' + images_path)
56 | #print(' ' + labels_path)
57 |
58 | image_files = sorted_alphanumeric(os.listdir(images_path))
59 | label_files = sorted_alphanumeric(os.listdir(labels_path))
60 |
61 | if len(image_files) != len(label_files):
62 | print('warning: images path has a different number of files than labels path')
63 | print(' ({:d} files) - {:s}'.format(len(image_files), images_path))
64 | print(' ({:d} files) - {:s}'.format(len(label_files), labels_path))
65 |
66 | for n in range(len(image_files)):
67 | image_files[n] = os.path.join(images_path, image_files[n])
68 | label_files[n] = os.path.join(labels_path, label_files[n])
69 |
70 | #print('{:s} -> {:s}'.format(image_files[n], label_files[n]))
71 |
72 | return image_files, label_files
73 |
74 | def train_transform(self, rgb, depth):
75 | s = np.random.uniform(1.0, 1.5) # random scaling
76 | depth_np = depth #/ s
77 | angle = np.random.uniform(-5.0, 5.0) # random rotation degrees
78 | do_flip = np.random.uniform(0.0, 1.0) < 0.5 # random horizontal flip
79 |
80 | # perform 1st step of data augmentation
81 | transform = transforms.Compose([
82 | #transforms.Resize(240.0 / iheight), # this is for computational efficiency, since rotation can be slow
83 | #transforms.Rotate(angle),
84 | #transforms.Resize(s),
85 | #transforms.CenterCrop(self.output_size),
86 | #transforms.HorizontalFlip(do_flip)
87 | transforms.Resize(self.output_size)
88 | ])
89 |
90 | rgb_np = transform(rgb)
91 | #rgb_np = self.color_jitter(rgb_np) # random color jittering
92 | rgb_np = np.asfarray(rgb_np, dtype='float') / 255
93 |
94 | depth_np = transform(depth_np)
95 | depth_np = np.asfarray(depth_np, dtype='float')
96 |
97 | if self.depth_16:
98 | depth_np = depth_np / self.depth_16_max
99 | else:
100 | depth_np = depth_np / 255
101 |
102 | return rgb_np, depth_np
103 |
104 | def val_transform(self, rgb, depth):
105 | depth_np = depth
106 |
107 | transform = transforms.Compose([
108 | #transforms.Resize(240.0 / iheight),
109 | #transforms.CenterCrop(self.output_size),
110 | transforms.Resize(self.output_size)
111 | ])
112 |
113 | rgb_np = transform(rgb)
114 | rgb_np = np.asfarray(rgb_np, dtype='float') / 255
115 |
116 | depth_np = transform(depth_np)
117 | depth_np = np.asfarray(depth_np, dtype='float')
118 |
119 | if self.depth_16:
120 | depth_np = depth_np / self.depth_16_max
121 | else:
122 | depth_np = depth_np / 255
123 |
124 | return rgb_np, depth_np
125 |
126 | def load_rgb(self, index):
127 | return imread(self.rgb_files[index], as_gray=False, pilmode="RGB")
128 |
129 | def load_depth(self, index):
130 | if self.depth_16:
131 | depth = imread(self.depth_files[index])
132 | depth[depth == 65535] = 0 # map 'invalid' to 0
133 | return depth
134 | else:
135 | depth = imread(self.depth_files[index], as_gray=False, pilmode="L")
136 | #depth[depth == 0] = 255 # map 0 -> 255
137 | return depth
138 |
139 | def __len__(self):
140 | return len(self.rgb_files)
141 |
142 | def __getitem__(self, index):
143 | rgb = self.load_rgb(index)
144 | depth = self.load_depth(index)
145 |
146 | #print(self.rgb_files[index] + str(rgb.shape))
147 | #print(self.depth_files[index] + str(depth.shape))
148 | #print(depth)
149 |
150 | # apply train/val transforms
151 | if self.transform is not None:
152 | rgb_np, depth_np = self.transform(rgb, depth)
153 | else:
154 | raise(RuntimeError("transform not defined"))
155 |
156 | # convert from numpy to torch tensors
157 | input_tensor = to_tensor(rgb_np)
158 |
159 | while input_tensor.dim() < 3:
160 | input_tensor = input_tensor.unsqueeze(0)
161 |
162 | depth_tensor = to_tensor(depth_np)
163 | depth_tensor = depth_tensor.unsqueeze(0)
164 |
165 | #print("{:04d} rgb = ".format(index) + str(input_tensor.shape))
166 | #print("{:04d} depth = ".format(index) + str(depth_tensor.shape))
167 |
168 | return input_tensor, depth_tensor
169 |
170 |
--------------------------------------------------------------------------------
/dataloaders/dense_to_sparse.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 |
4 |
5 | def rgb2grayscale(rgb):
6 | return rgb[:, :, 0] * 0.2989 + rgb[:, :, 1] * 0.587 + rgb[:, :, 2] * 0.114
7 |
8 |
9 | class DenseToSparse:
10 | def __init__(self):
11 | pass
12 |
13 | def dense_to_sparse(self, rgb, depth):
14 | pass
15 |
16 | def __repr__(self):
17 | pass
18 |
19 | class UniformSampling(DenseToSparse):
20 | name = "uar"
21 | def __init__(self, num_samples, max_depth=np.inf):
22 | DenseToSparse.__init__(self)
23 | self.num_samples = num_samples
24 | self.max_depth = max_depth
25 |
26 | def __repr__(self):
27 | return "%s{ns=%d,md=%f}" % (self.name, self.num_samples, self.max_depth)
28 |
29 | def dense_to_sparse(self, rgb, depth):
30 | """
31 | Samples pixels with `num_samples`/#pixels probability in `depth`.
32 | Only pixels with a maximum depth of `max_depth` are considered.
33 | If no `max_depth` is given, all pixels with valid (non-zero) depth are considered.
34 | """
35 | mask_keep = depth > 0
36 | if self.max_depth is not np.inf:
37 | mask_keep = np.bitwise_and(mask_keep, depth <= self.max_depth)
38 | n_keep = np.count_nonzero(mask_keep)
39 | if n_keep == 0:
40 | return mask_keep
41 | else:
42 | prob = float(self.num_samples) / n_keep
43 | return np.bitwise_and(mask_keep, np.random.uniform(0, 1, depth.shape) < prob)
44 |
45 |
46 | class SimulatedStereo(DenseToSparse):
47 | name = "sim_stereo"
48 |
49 | def __init__(self, num_samples, max_depth=np.inf, dilate_kernel=3, dilate_iterations=1):
50 | DenseToSparse.__init__(self)
51 | self.num_samples = num_samples
52 | self.max_depth = max_depth
53 | self.dilate_kernel = dilate_kernel
54 | self.dilate_iterations = dilate_iterations
55 |
56 | def __repr__(self):
57 | return "%s{ns=%d,md=%f,dil=%d.%d}" % \
58 | (self.name, self.num_samples, self.max_depth, self.dilate_kernel, self.dilate_iterations)
59 |
60 | # We do not use cv2.Canny, since that applies non max suppression
61 | # So we simply do
62 | # RGB to intensities
63 | # Smooth with gaussian
64 | # Take simple sobel gradients
65 | # Threshold the edge gradient
66 | # Dilate
67 | def dense_to_sparse(self, rgb, depth):
68 | gray = rgb2grayscale(rgb)
69 | blurred = cv2.GaussianBlur(gray, (5, 5), 0)
70 | gx = cv2.Sobel(blurred, cv2.CV_64F, 1, 0, ksize=5)
71 | gy = cv2.Sobel(blurred, cv2.CV_64F, 0, 1, ksize=5)
72 |
73 | depth_mask = np.bitwise_and(depth != 0.0, depth <= self.max_depth)
74 |
75 | edge_fraction = float(self.num_samples) / np.size(depth)
76 |
77 | mag = cv2.magnitude(gx, gy)
78 | min_mag = np.percentile(mag[depth_mask], 100 * (1.0 - edge_fraction))
79 | mag_mask = mag >= min_mag
80 |
81 | if self.dilate_iterations >= 0:
82 | kernel = np.ones((self.dilate_kernel, self.dilate_kernel), dtype=np.uint8)
83 | mag_mask = cv2.dilate(mag_mask.astype(np.uint8), kernel, iterations=self.dilate_iterations).astype(bool)  # assign the result; cv2.dilate is not in-place
84 |
85 | mask = np.bitwise_and(mag_mask, depth_mask)
86 | return mask
87 |
--------------------------------------------------------------------------------
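A small sketch (not part of the repo) of how a sparsifier's boolean mask is consumed, mirroring `MyDataloader.create_sparse_depth` in dataloaders/dataloader.py; the depth map here is a random placeholder.

```python
import numpy as np
from dataloaders.dense_to_sparse import UniformSampling

depth = np.random.uniform(0.5, 10.0, size=(228, 304))          # placeholder dense depth map
sparsifier = UniformSampling(num_samples=200, max_depth=np.inf)

mask_keep = sparsifier.dense_to_sparse(rgb=None, depth=depth)   # boolean mask; rgb is unused here
sparse_depth = np.zeros(depth.shape)
sparse_depth[mask_keep] = depth[mask_keep]
print('kept {} of {} pixels'.format(mask_keep.sum(), depth.size))
```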
/dataloaders/kitti_dataloader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import dataloaders.transforms as transforms
3 | from dataloaders.dataloader import MyDataloader
4 |
5 | class KITTIDataset(MyDataloader):
6 | def __init__(self, root, type, sparsifier=None, modality='rgb'):
7 | super(KITTIDataset, self).__init__(root, type, sparsifier, modality)
8 | self.output_size = (228, 912)
9 |
10 | def train_transform(self, rgb, depth):
11 | s = np.random.uniform(1.0, 1.5) # random scaling
12 | depth_np = depth / s
13 | angle = np.random.uniform(-5.0, 5.0) # random rotation degrees
14 | do_flip = np.random.uniform(0.0, 1.0) < 0.5 # random horizontal flip
15 |
16 | # perform 1st step of data augmentation
17 | transform = transforms.Compose([
18 | transforms.Crop(130, 10, 240, 1200),
19 | transforms.Rotate(angle),
20 | transforms.Resize(s),
21 | transforms.CenterCrop(self.output_size),
22 | transforms.HorizontalFlip(do_flip)
23 | ])
24 | rgb_np = transform(rgb)
25 | rgb_np = self.color_jitter(rgb_np) # random color jittering
26 | rgb_np = np.asfarray(rgb_np, dtype='float') / 255
27 | # Scipy affine_transform produced RuntimeError when the depth map was
28 | # given as a 'numpy.ndarray'
29 | depth_np = np.asfarray(depth_np, dtype='float32')
30 | depth_np = transform(depth_np)
31 |
32 | return rgb_np, depth_np
33 |
34 | def val_transform(self, rgb, depth):
35 | depth_np = depth
36 | transform = transforms.Compose([
37 | transforms.Crop(130, 10, 240, 1200),
38 | transforms.CenterCrop(self.output_size),
39 | ])
40 | rgb_np = transform(rgb)
41 | rgb_np = np.asfarray(rgb_np, dtype='float') / 255
42 | depth_np = np.asfarray(depth_np, dtype='float32')
43 | depth_np = transform(depth_np)
44 |
45 | return rgb_np, depth_np
46 |
47 |
--------------------------------------------------------------------------------
/dataloaders/nyu_dataloader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import dataloaders.transforms as transforms
3 | from dataloaders.dataloader import MyDataloader
4 |
5 | iheight, iwidth = 480, 640 # raw image size
6 |
7 | class NYUDataset(MyDataloader):
8 | def __init__(self, root, type, sparsifier=None, modality='rgb'):
9 | super(NYUDataset, self).__init__(root, type, sparsifier, modality)
10 | self.output_size = (448, 448) #(224, 224) #(228, 304) #(iheight, iwidth)
11 |
12 | def train_transform(self, rgb, depth):
13 | s = np.random.uniform(1.0, 1.5) # random scaling
14 | depth_np = depth / s
15 | angle = np.random.uniform(-5.0, 5.0) # random rotation degrees
16 | do_flip = np.random.uniform(0.0, 1.0) < 0.5 # random horizontal flip
17 |
18 | # perform 1st step of data augmentation
19 | transform = transforms.Compose([
20 | transforms.Resize(480.0 / iheight), #250.0 / iheight), # this is for computational efficiency, since rotation can be slow
21 | transforms.Rotate(angle),
22 | #transforms.Resize(s), # disabled for 448x448
23 | transforms.CenterCrop(self.output_size),
24 | transforms.HorizontalFlip(do_flip)
25 | ])
26 | rgb_np = transform(rgb)
27 | rgb_np = self.color_jitter(rgb_np) # random color jittering
28 | rgb_np = np.asfarray(rgb_np, dtype='float') / 255
29 | depth_np = transform(depth_np)
30 |
31 | return rgb_np, depth_np
32 |
33 | def val_transform(self, rgb, depth):
34 | depth_np = depth
35 | transform = transforms.Compose([
36 | transforms.Resize(480.0 / iheight), #240.0 / iheight),
37 | transforms.CenterCrop(self.output_size),
38 | ])
39 | rgb_np = transform(rgb)
40 | rgb_np = np.asfarray(rgb_np, dtype='float') / 255
41 | depth_np = transform(depth_np)
42 |
43 | return rgb_np, depth_np
44 |
--------------------------------------------------------------------------------
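A minimal end-to-end sketch (not part of the repo; the real entry point is main.py) combining the dataset, a sparsifier and a PyTorch DataLoader. The `data/nyudepthv2/train` path and the batch size are placeholder assumptions based on the download step in the README.

```python
import torch
from dataloaders.nyu_dataloader import NYUDataset
from dataloaders.dense_to_sparse import UniformSampling

sparsifier = UniformSampling(num_samples=100, max_depth=10.0)     # ~100 random depth samples
train_set = NYUDataset('data/nyudepthv2/train', type='train',
                       sparsifier=sparsifier, modality='rgbd')
train_loader = torch.utils.data.DataLoader(train_set, batch_size=8, shuffle=True,
                                           num_workers=4, pin_memory=True)

inputs, targets = next(iter(train_loader))
# inputs:  B x 4 x H x W  (RGB plus one sparse-depth channel, since modality='rgbd')
# targets: B x 1 x H x W  (dense ground-truth depth)
```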
/dataloaders/sun_dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import dataloaders.transforms as transforms
4 |
5 | from imageio import imread
6 | from torch.utils.data import Dataset, DataLoader
7 |
8 | to_tensor = transforms.ToTensor()
9 |
10 | class SunRGBDDataset(Dataset):
11 | def __init__(self, root, type='train', train_extra=True):
12 | self.root = root
13 | self.output_size = (224, 224) #(224, 448)
14 |
15 | # search for images
16 | self.rgb_files, self.depth_files = self.gather_images(os.path.join(root, 'images'),
17 | os.path.join(root, 'depth'))
18 |
19 | if type == 'train' and train_extra:
20 | extra_root = root + 'extra'
21 | extra_rgb, extra_depth = self.gather_images(os.path.join(extra_root, 'images'), os.path.join(extra_root, 'depth'))
22 | self.rgb_files.extend(extra_rgb); self.depth_files.extend(extra_depth)  # fold the extra pairs into the training set
23 |
24 | if len(self.rgb_files) == 0:
25 | raise (RuntimeError("Empty dataset - found no image pairs under \n" + root))
26 |
27 | # determine if 16-bit or 8-bit depth images
28 | self.depth_16 = False
29 |
30 | if imread(self.depth_files[0]).dtype.type is np.uint16:
31 | self.depth_16 = True
32 | self.depth_16_max = 10000
33 |
34 | print('found {:d} image pairs with {:s}-bit depth under {:s}'.format(len(self.rgb_files), "16" if self.depth_16 else "8", root))
35 |
36 | # setup transforms
37 | if type == 'train':
38 | self.transform = self.train_transform
39 | elif type == 'val':
40 | self.transform = self.val_transform
41 | else:
42 | raise (RuntimeError("Invalid dataset type: " + type + "\n"
43 | "Supported dataset types are: train, val"))
44 |
45 | def gather_images(self, images_path, labels_path, max_images=5500):
46 | image_files = []
47 | label_files = []
48 |
49 | for n in range(max_images):
50 | image_filename = os.path.join(images_path, 'img-{:06d}.jpg'.format(n))
51 | label_filename = os.path.join(labels_path, '{:d}.png'.format(n))
52 |
53 | if os.path.isfile(image_filename) and os.path.isfile(label_filename):
54 | image_files.append(image_filename)
55 | label_files.append(label_filename)
56 |
57 | return image_files, label_files
58 |
59 | def train_transform(self, rgb, depth):
60 | s = np.random.uniform(1.0, 1.5) # random scaling
61 | depth_np = depth #/ s
62 | angle = np.random.uniform(-5.0, 5.0) # random rotation degrees
63 | do_flip = np.random.uniform(0.0, 1.0) < 0.5 # random horizontal flip
64 |
65 | # perform 1st step of data augmentation
66 | transform = transforms.Compose([
67 | #transforms.Resize(240.0 / iheight), # this is for computational efficiency, since rotation can be slow
68 | #transforms.Rotate(angle),
69 | #transforms.Resize(s),
70 | #transforms.CenterCrop(self.output_size),
71 | #transforms.HorizontalFlip(do_flip)
72 | transforms.Resize(self.output_size)
73 | ])
74 |
75 | rgb_np = transform(rgb)
76 | #rgb_np = self.color_jitter(rgb_np) # random color jittering
77 | rgb_np = np.asfarray(rgb_np, dtype='float') / 255
78 |
79 | depth_np = transform(depth_np)
80 | depth_np = np.asfarray(depth_np, dtype='float')
81 |
82 | if self.depth_16:
83 | depth_np = depth_np / self.depth_16_max
84 | else:
85 | depth_np = depth_np / 255
86 |
87 | return rgb_np, depth_np
88 |
89 | def val_transform(self, rgb, depth):
90 | depth_np = depth
91 |
92 | transform = transforms.Compose([
93 | #transforms.Resize(240.0 / iheight),
94 | #transforms.CenterCrop(self.output_size),
95 | transforms.Resize(self.output_size)
96 | ])
97 |
98 | rgb_np = transform(rgb)
99 | rgb_np = np.asfarray(rgb_np, dtype='float') / 255
100 |
101 | depth_np = transform(depth_np)
102 | depth_np = np.asfarray(depth_np, dtype='float')
103 |
104 | if self.depth_16:
105 | depth_np = depth_np / self.depth_16_max
106 | else:
107 | depth_np = depth_np / 255
108 |
109 | return rgb_np, depth_np
110 |
111 | def load_rgb(self, index):
112 | return imread(self.rgb_files[index], as_gray=False, pilmode="RGB")
113 |
114 | def load_depth(self, index):
115 | if self.depth_16:
116 | depth = imread(self.depth_files[index])
117 | depth[depth == 65535] = 0 # map 'invalid' to 0
118 | return depth
119 | else:
120 | depth = imread(self.depth_files[index], as_gray=False, pilmode="L")
121 | #depth[depth == 0] = 255 # map 0 -> 255
122 | return depth
123 |
124 | def __len__(self):
125 | return len(self.rgb_files)
126 |
127 | def __getitem__(self, index):
128 | rgb = self.load_rgb(index)
129 | depth = self.load_depth(index)
130 |
131 | #print(self.rgb_files[index] + str(rgb.shape))
132 | #print(self.depth_files[index] + str(depth.shape))
133 | #print(depth)
134 |
135 | # apply train/val transforms
136 | if self.transform is not None:
137 | rgb_np, depth_np = self.transform(rgb, depth)
138 | else:
139 | raise(RuntimeError("transform not defined"))
140 |
141 | # convert from numpy to torch tensors
142 | input_tensor = to_tensor(rgb_np)
143 |
144 | while input_tensor.dim() < 3:
145 | input_tensor = input_tensor.unsqueeze(0)
146 |
147 | depth_tensor = to_tensor(depth_np)
148 | depth_tensor = depth_tensor.unsqueeze(0)
149 |
150 | #print("{:04d} rgb = ".format(index) + str(input_tensor.shape))
151 | #print("{:04d} depth = ".format(index) + str(depth_tensor.shape))
152 |
153 | return input_tensor, depth_tensor
154 |
155 |
--------------------------------------------------------------------------------
/dataloaders/transforms.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import torch
3 | import math
4 | import random
5 |
6 | from PIL import Image, ImageOps, ImageEnhance
7 | try:
8 | import accimage
9 | except ImportError:
10 | accimage = None
11 |
12 | import numpy as np
13 | import numbers
14 | import types
15 | import collections
16 | import warnings
17 |
18 | import scipy.ndimage.interpolation as itpl
19 | import scipy.misc as misc
20 |
21 |
22 | def _is_numpy_image(img):
23 | return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
24 |
25 | def _is_pil_image(img):
26 | if accimage is not None:
27 | return isinstance(img, (Image.Image, accimage.Image))
28 | else:
29 | return isinstance(img, Image.Image)
30 |
31 | def _is_tensor_image(img):
32 | return torch.is_tensor(img) and img.ndimension() == 3
33 |
34 | def adjust_brightness(img, brightness_factor):
35 | """Adjust brightness of an Image.
36 |
37 | Args:
38 | img (PIL Image): PIL Image to be adjusted.
39 | brightness_factor (float): How much to adjust the brightness. Can be
40 | any non negative number. 0 gives a black image, 1 gives the
41 | original image while 2 increases the brightness by a factor of 2.
42 |
43 | Returns:
44 | PIL Image: Brightness adjusted image.
45 | """
46 | if not _is_pil_image(img):
47 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
48 |
49 | enhancer = ImageEnhance.Brightness(img)
50 | img = enhancer.enhance(brightness_factor)
51 | return img
52 |
53 |
54 | def adjust_contrast(img, contrast_factor):
55 | """Adjust contrast of an Image.
56 |
57 | Args:
58 | img (PIL Image): PIL Image to be adjusted.
59 | contrast_factor (float): How much to adjust the contrast. Can be any
60 | non negative number. 0 gives a solid gray image, 1 gives the
61 | original image while 2 increases the contrast by a factor of 2.
62 |
63 | Returns:
64 | PIL Image: Contrast adjusted image.
65 | """
66 | if not _is_pil_image(img):
67 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
68 |
69 | enhancer = ImageEnhance.Contrast(img)
70 | img = enhancer.enhance(contrast_factor)
71 | return img
72 |
73 |
74 | def adjust_saturation(img, saturation_factor):
75 | """Adjust color saturation of an image.
76 |
77 | Args:
78 | img (PIL Image): PIL Image to be adjusted.
79 | saturation_factor (float): How much to adjust the saturation. 0 will
80 | give a black and white image, 1 will give the original image while
81 | 2 will enhance the saturation by a factor of 2.
82 |
83 | Returns:
84 | PIL Image: Saturation adjusted image.
85 | """
86 | if not _is_pil_image(img):
87 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
88 |
89 | enhancer = ImageEnhance.Color(img)
90 | img = enhancer.enhance(saturation_factor)
91 | return img
92 |
93 |
94 | def adjust_hue(img, hue_factor):
95 | """Adjust hue of an image.
96 |
97 | The image hue is adjusted by converting the image to HSV and
98 | cyclically shifting the intensities in the hue channel (H).
99 | The image is then converted back to original image mode.
100 |
101 | `hue_factor` is the amount of shift in H channel and must be in the
102 | interval `[-0.5, 0.5]`.
103 |
104 | See https://en.wikipedia.org/wiki/Hue for more details on Hue.
105 |
106 | Args:
107 | img (PIL Image): PIL Image to be adjusted.
108 | hue_factor (float): How much to shift the hue channel. Should be in
109 | [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
110 | HSV space in positive and negative direction respectively.
111 | 0 means no shift. Therefore, both -0.5 and 0.5 will give an image
112 | with complementary colors while 0 gives the original image.
113 |
114 | Returns:
115 | PIL Image: Hue adjusted image.
116 | """
117 | if not(-0.5 <= hue_factor <= 0.5):
118 | raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
119 |
120 | if not _is_pil_image(img):
121 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
122 |
123 | input_mode = img.mode
124 | if input_mode in {'L', '1', 'I', 'F'}:
125 | return img
126 |
127 | h, s, v = img.convert('HSV').split()
128 |
129 | np_h = np.array(h, dtype=np.uint8)
130 | # uint8 addition takes care of rotation across boundaries
131 | with np.errstate(over='ignore'):
132 | np_h += np.uint8(hue_factor * 255)
133 | h = Image.fromarray(np_h, 'L')
134 |
135 | img = Image.merge('HSV', (h, s, v)).convert(input_mode)
136 | return img
137 |
138 |
139 | def adjust_gamma(img, gamma, gain=1):
140 | """Perform gamma correction on an image.
141 |
142 | Also known as Power Law Transform. Intensities in RGB mode are adjusted
143 | based on the following equation:
144 |
145 | I_out = 255 * gain * ((I_in / 255) ** gamma)
146 |
147 | See https://en.wikipedia.org/wiki/Gamma_correction for more details.
148 |
149 | Args:
150 | img (PIL Image): PIL Image to be adjusted.
151 | gamma (float): Non negative real number. gamma larger than 1 make the
152 | shadows darker, while gamma smaller than 1 make dark regions
153 | lighter.
154 | gain (float): The constant multiplier.
155 | """
156 | if not _is_pil_image(img):
157 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
158 |
159 | if gamma < 0:
160 | raise ValueError('Gamma should be a non-negative real number')
161 |
162 | input_mode = img.mode
163 | img = img.convert('RGB')
164 |
165 | np_img = np.array(img, dtype=np.float32)
166 | np_img = 255 * gain * ((np_img / 255) ** gamma)
167 | np_img = np.uint8(np.clip(np_img, 0, 255))
168 |
169 | img = Image.fromarray(np_img, 'RGB').convert(input_mode)
170 | return img
171 |
172 |
173 | class Compose(object):
174 | """Composes several transforms together.
175 |
176 | Args:
177 | transforms (list of ``Transform`` objects): list of transforms to compose.
178 |
179 | Example:
180 | >>> transforms.Compose([
181 | >>> transforms.CenterCrop(10),
182 | >>> transforms.ToTensor(),
183 | >>> ])
184 | """
185 |
186 | def __init__(self, transforms):
187 | self.transforms = transforms
188 |
189 | def __call__(self, img):
190 | for t in self.transforms:
191 | img = t(img)
192 | return img
193 |
194 |
195 | class ToTensor(object):
196 | """Convert a ``numpy.ndarray`` to tensor.
197 |
198 | Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W).
199 | """
200 |
201 | def __call__(self, img):
202 | """Convert a ``numpy.ndarray`` to tensor.
203 |
204 | Args:
205 | img (numpy.ndarray): Image to be converted to tensor.
206 |
207 | Returns:
208 | Tensor: Converted image.
209 | """
210 | if not(_is_numpy_image(img)):
211 | raise TypeError('img should be ndarray. Got {}'.format(type(img)))
212 |
213 | if isinstance(img, np.ndarray):
214 | # handle numpy array
215 | if img.ndim == 3:
216 | img = torch.from_numpy(img.transpose((2, 0, 1)).copy())
217 | elif img.ndim == 2:
218 | img = torch.from_numpy(img.copy())
219 | else:
220 | raise RuntimeError('img should be ndarray with 2 or 3 dimensions. Got {}'.format(img.ndim))
221 |
222 | # backward compatibility
223 | # return img.float().div(255)
224 | return img.float()
225 |
226 |
227 | class NormalizeNumpyArray(object):
228 | """Normalize a ``numpy.ndarray`` with mean and standard deviation.
229 | Given mean: ``(M1,...,Mn)`` and std: ``(S1,...,Sn)`` for ``n`` channels, this transform
230 | will normalize each channel of the input ``numpy.ndarray`` i.e.
231 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
232 |
233 | Args:
234 | mean (sequence): Sequence of means for each channel.
235 | std (sequence): Sequence of standard deviations for each channel.
236 | """
237 |
238 | def __init__(self, mean, std):
239 | self.mean = mean
240 | self.std = std
241 |
242 | def __call__(self, img):
243 | """
244 | Args:
245 | img (numpy.ndarray): Image of size (H, W, C) to be normalized.
246 |
247 | Returns:
248 | numpy.ndarray: Normalized image.
249 | """
250 | if not(_is_numpy_image(img)):
251 | raise TypeError('img should be ndarray. Got {}'.format(type(img)))
252 | # TODO: make efficient
253 | print(img.shape)
254 | for i in range(3):
255 | img[:,:,i] = (img[:,:,i] - self.mean[i]) / self.std[i]
256 | return img
257 |
258 | class NormalizeTensor(object):
259 | """Normalize a tensor image with mean and standard deviation.
260 | Given mean: ``(M1,...,Mn)`` and std: ``(S1,...,Sn)`` for ``n`` channels, this transform
261 | will normalize each channel of the input ``torch.*Tensor`` i.e.
262 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
263 |
264 | Args:
265 | mean (sequence): Sequence of means for each channel.
266 | std (sequence): Sequence of standard deviations for each channel.
267 | """
268 |
269 | def __init__(self, mean, std):
270 | self.mean = mean
271 | self.std = std
272 |
273 | def __call__(self, tensor):
274 | """
275 | Args:
276 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
277 |
278 | Returns:
279 | Tensor: Normalized Tensor image.
280 | """
281 | if not _is_tensor_image(tensor):
282 | raise TypeError('tensor is not a torch image.')
283 | # TODO: make efficient
284 | for t, m, s in zip(tensor, self.mean, self.std):
285 | t.sub_(m).div_(s)
286 | return tensor
287 |
288 | class Rotate(object):
289 | """Rotates the given ``numpy.ndarray``.
290 |
291 | Args:
292 | angle (float): The rotation angle in degrees.
293 | """
294 |
295 | def __init__(self, angle):
296 | self.angle = angle
297 |
298 | def __call__(self, img):
299 | """
300 | Args:
301 | img (numpy.ndarray (C x H x W)): Image to be rotated.
302 |
303 | Returns:
304 | img (numpy.ndarray (C x H x W)): Rotated image.
305 | """
306 |
307 | # order=0 means nearest-neighbor type interpolation
308 | return itpl.rotate(img, self.angle, reshape=False, prefilter=False, order=0)
309 |
310 |
311 | class Resize(object):
312 | """Resize the given ``numpy.ndarray`` to the given size.
313 | Args:
314 | size (sequence or int): Desired output size. If size is a sequence like
315 | (h, w), output size will be matched to this. If size is an int,
316 | smaller edge of the image will be matched to this number.
317 | i.e, if height > width, then image will be rescaled to
318 | (size * height / width, size)
319 | interpolation (str, optional): Desired interpolation mode passed to
320 | ``scipy.misc.imresize``. Default is ``'nearest'``.
321 | """
322 |
323 | def __init__(self, size, interpolation='nearest'):
324 | assert isinstance(size, int) or isinstance(size, float) or \
325 | (isinstance(size, collections.Iterable) and len(size) == 2)
326 | self.size = size
327 | self.interpolation = interpolation
328 |
329 | def __call__(self, img):
330 | """
331 | Args:
332 | img (numpy.ndarray): Image to be scaled.
333 | Returns:
334 | numpy.ndarray: Rescaled image.
335 | """
336 | if img.ndim == 3:
337 | return misc.imresize(img, self.size, self.interpolation)
338 | elif img.ndim == 2:
339 | return misc.imresize(img, self.size, self.interpolation, 'F')
340 | else:
341 | raise RuntimeError('img should be ndarray with 2 or 3 dimensions. Got {}'.format(img.ndim))
342 |
343 |
344 | class CenterCrop(object):
345 | """Crops the given ``numpy.ndarray`` at the center.
346 |
347 | Args:
348 | size (sequence or int): Desired output size of the crop. If size is an
349 | int instead of sequence like (h, w), a square crop (size, size) is
350 | made.
351 | """
352 |
353 | def __init__(self, size):
354 | if isinstance(size, numbers.Number):
355 | self.size = (int(size), int(size))
356 | else:
357 | self.size = size
358 |
359 | @staticmethod
360 | def get_params(img, output_size):
361 | """Get parameters for ``crop`` for center crop.
362 |
363 | Args:
364 | img (numpy.ndarray (C x H x W)): Image to be cropped.
365 | output_size (tuple): Expected output size of the crop.
366 |
367 | Returns:
368 | tuple: params (i, j, h, w) to be passed to ``crop`` for center crop.
369 | """
370 | h = img.shape[0]
371 | w = img.shape[1]
372 | th, tw = output_size
373 | i = int(round((h - th) / 2.))
374 | j = int(round((w - tw) / 2.))
375 |
376 | # # randomized cropping
377 | # i = np.random.randint(i-3, i+4)
378 | # j = np.random.randint(j-3, j+4)
379 |
380 | return i, j, th, tw
381 |
382 | def __call__(self, img):
383 | """
384 | Args:
385 | img (numpy.ndarray (C x H x W)): Image to be cropped.
386 |
387 | Returns:
388 | img (numpy.ndarray (C x H x W)): Cropped image.
389 | """
390 | i, j, h, w = self.get_params(img, self.size)
391 |
392 | """
393 | i: Upper pixel coordinate.
394 | j: Left pixel coordinate.
395 | h: Height of the cropped image.
396 | w: Width of the cropped image.
397 | """
398 | if not(_is_numpy_image(img)):
399 | raise TypeError('img should be ndarray. Got {}'.format(type(img)))
400 | if img.ndim == 3:
401 | return img[i:i+h, j:j+w, :]
402 | elif img.ndim == 2:
403 | return img[i:i + h, j:j + w]
404 | else:
405 | raise RuntimeError('img should be ndarray with 2 or 3 dimensions. Got {}'.format(img.ndim))
406 |
407 |
408 | class Lambda(object):
409 | """Apply a user-defined lambda as a transform.
410 |
411 | Args:
412 | lambd (function): Lambda/function to be used for transform.
413 | """
414 |
415 | def __init__(self, lambd):
416 | assert isinstance(lambd, types.LambdaType)
417 | self.lambd = lambd
418 |
419 | def __call__(self, img):
420 | return self.lambd(img)
421 |
422 |
423 | class HorizontalFlip(object):
424 | """Horizontally flip the given ``numpy.ndarray``.
425 |
426 | Args:
427 | do_flip (boolean): whether or not do horizontal flip.
428 |
429 | """
430 |
431 | def __init__(self, do_flip):
432 | self.do_flip = do_flip
433 |
434 | def __call__(self, img):
435 | """
436 | Args:
437 | img (numpy.ndarray (C x H x W)): Image to be flipped.
438 |
439 | Returns:
440 | img (numpy.ndarray (C x H x W)): flipped image.
441 | """
442 | if not(_is_numpy_image(img)):
443 | raise TypeError('img should be ndarray. Got {}'.format(type(img)))
444 |
445 | if self.do_flip:
446 | return np.fliplr(img)
447 | else:
448 | return img
449 |
450 |
451 | class ColorJitter(object):
452 | """Randomly change the brightness, contrast and saturation of an image.
453 |
454 | Args:
455 | brightness (float): How much to jitter brightness. brightness_factor
456 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
457 | contrast (float): How much to jitter contrast. contrast_factor
458 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
459 | saturation (float): How much to jitter saturation. saturation_factor
460 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
461 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from
462 | [-hue, hue]. Should be >=0 and <= 0.5.
463 | """
464 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
465 | self.brightness = brightness
466 | self.contrast = contrast
467 | self.saturation = saturation
468 | self.hue = hue
469 |
470 | @staticmethod
471 | def get_params(brightness, contrast, saturation, hue):
472 | """Get a randomized transform to be applied on image.
473 |
474 | Arguments are same as that of __init__.
475 |
476 | Returns:
477 | Transform which randomly adjusts brightness, contrast and
478 | saturation in a random order.
479 | """
480 | transforms = []
481 | if brightness > 0:
482 | brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness)
483 | transforms.append(Lambda(lambda img: adjust_brightness(img, brightness_factor)))
484 |
485 | if contrast > 0:
486 | contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast)
487 | transforms.append(Lambda(lambda img: adjust_contrast(img, contrast_factor)))
488 |
489 | if saturation > 0:
490 | saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation)
491 | transforms.append(Lambda(lambda img: adjust_saturation(img, saturation_factor)))
492 |
493 | if hue > 0:
494 | hue_factor = np.random.uniform(-hue, hue)
495 | transforms.append(Lambda(lambda img: adjust_hue(img, hue_factor)))
496 |
497 | np.random.shuffle(transforms)
498 | transform = Compose(transforms)
499 |
500 | return transform
501 |
502 | def __call__(self, img):
503 | """
504 | Args:
505 | img (numpy.ndarray (C x H x W)): Input image.
506 |
507 | Returns:
508 | img (numpy.ndarray (C x H x W)): Color jittered image.
509 | """
510 | if not(_is_numpy_image(img)):
511 | raise TypeError('img should be ndarray. Got {}'.format(type(img)))
512 |
513 | pil = Image.fromarray(img)
514 | transform = self.get_params(self.brightness, self.contrast,
515 | self.saturation, self.hue)
516 | return np.array(transform(pil))
517 |
518 | class Crop(object):
519 | """Crops the given ``numpy.ndarray`` to a rectangular region based on a given
520 | 4-tuple defining the upper and left pixel coordinates, height and width.
521 |
522 | Args:
523 | a tuple: (upper pixel coordinate, left pixel coordinate, height, width)-tuple
524 | """
525 |
526 | def __init__(self, i, j, h, w):
527 | """
528 | i: Upper pixel coordinate.
529 | j: Left pixel coordinate.
530 | h: Height of the cropped image.
531 | w: Width of the cropped image.
532 | """
533 | self.i = i
534 | self.j = j
535 | self.h = h
536 | self.w = w
537 |
538 | def __call__(self, img):
539 | """
540 | Args:
541 | img (numpy.ndarray (C x H x W)): Image to be cropped.
542 | Returns:
543 | img (numpy.ndarray (C x H x W)): Cropped image.
544 | """
545 |
546 | i, j, h, w = self.i, self.j, self.h, self.w
547 |
548 | if not(_is_numpy_image(img)):
549 | raise TypeError('img should be ndarray. Got {}'.format(type(img)))
550 | if img.ndim == 3:
551 | return img[i:i + h, j:j + w, :]
552 | elif img.ndim == 2:
553 | return img[i:i + h, j:j + w]
554 | else:
555 | raise RuntimeError(
556 | 'img should be ndarray with 2 or 3 dimensions. Got {}'.format(img.ndim))
557 |
558 | def __repr__(self):
559 | return self.__class__.__name__ + '(i={0},j={1},h={2},w={3})'.format(
560 | self.i, self.j, self.h, self.w)
561 |
--------------------------------------------------------------------------------
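A short sketch (not part of the repo) of how these numpy-based transforms compose, in the same way the KITTI and NYU dataloaders build their augmentation pipelines.

```python
import numpy as np
import dataloaders.transforms as transforms

img = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)   # placeholder H x W x 3 frame
augment = transforms.Compose([
    transforms.Rotate(2.5),               # degrees; nearest-neighbour, no reshape
    transforms.CenterCrop((228, 304)),
    transforms.HorizontalFlip(True),
])
out = augment(img)                         # numpy.ndarray of shape (228, 304, 3)
tensor = transforms.ToTensor()(out)        # torch.FloatTensor of shape (3, 228, 304)
```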
/dataloaders/zed_dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import dataloaders.transforms as transforms
4 |
5 | from imageio import imread
6 | from torch.utils.data import Dataset, DataLoader
7 |
8 | to_tensor = transforms.ToTensor()
9 |
10 | iheight, iwidth = 720, 1280 # raw image size
11 |
12 | class ZEDDataset(Dataset):
13 | def __init__(self, root, type='train'):
14 | self.root = root
15 | self.output_size = (224, 224) #(228, 304)
16 |
17 | # search for images
18 | self.rgb_files = []
19 | self.depth_files = []
20 |
21 | self.gather_images(root, max_images=100000)
22 |
23 | if len(self.rgb_files) == 0:
24 | raise (RuntimeError("Empty dataset - found no image pairs under \n" + root))
25 |
26 | # determine if 16-bit or 8-bit depth images
27 | self.depth_16 = False
28 |
29 | if imread(self.depth_files[0]).dtype.type is np.uint16:
30 | self.depth_16 = True
31 | self.depth_16_max = 5000 #20000
32 |
33 | print('found {:d} image pairs with {:s}-bit depth under {:s}'.format(len(self.rgb_files), "16" if self.depth_16 else "8", root))
34 |
35 | # setup transforms
36 | if type == 'train':
37 | self.transform = self.train_transform
38 | elif type == 'val':
39 | self.transform = self.val_transform
40 | else:
41 | raise (RuntimeError("Invalid dataset type: " + type + "\n"
42 | "Supported dataset types are: train, val"))
43 |
44 | def gather_images(self, img_dir, max_images=999999):
45 | rgb_files = []
46 | depth_files = []
47 |
48 | # search in the current directory
49 | for n in range(max_images):
50 | img_name_rgb = 'left{:06d}.png'.format(n)
51 | img_path_rgb = os.path.join(img_dir, img_name_rgb)
52 |
53 | img_name_depth = 'depth{:06d}.png'.format(n)
54 | img_path_depth = os.path.join(img_dir, img_name_depth)
55 |
56 | if os.path.isfile(img_path_rgb) and os.path.isfile(img_path_depth):
57 | self.rgb_files.append(img_path_rgb)
58 | self.depth_files.append(img_path_depth)
59 |
60 | # search in subdirectories
61 | dir_files = os.listdir(img_dir)
62 |
63 | for dir_name in dir_files:
64 | dir_path = os.path.join(img_dir, dir_name)
65 |
66 | if os.path.isdir(dir_path):
67 | self.gather_images(dir_path, max_images)
68 |
69 | def train_transform(self, rgb, depth):
70 | s = np.random.uniform(1.0, 1.5) # random scaling
71 | depth_np = depth #/ s
72 | angle = np.random.uniform(-5.0, 5.0) # random rotation degrees
73 | do_flip = np.random.uniform(0.0, 1.0) < 0.5 # random horizontal flip
74 |
75 | # perform 1st step of data augmentation
76 | transform = transforms.Compose([
77 | transforms.Resize(240.0 / iheight), # this is for computational efficiency, since rotation can be slow
78 | #transforms.Rotate(angle),
79 | #transforms.Resize(s),
80 | transforms.CenterCrop(self.output_size),
81 | transforms.HorizontalFlip(do_flip)
82 | ])
83 |
84 | rgb_np = transform(rgb)
85 | #rgb_np = self.color_jitter(rgb_np) # random color jittering
86 | rgb_np = np.asfarray(rgb_np, dtype='float') / 255
87 |
88 | depth_np = transform(depth_np)
89 | depth_np = np.asfarray(depth_np, dtype='float')
90 |
91 | if self.depth_16:
92 | depth_np = depth_np / self.depth_16_max
93 | else:
94 | depth_np = (255 - depth_np) / 255
95 |
96 | return rgb_np, depth_np
97 |
98 | def val_transform(self, rgb, depth):
99 | depth_np = depth
100 |
101 | transform = transforms.Compose([
102 | transforms.Resize(240.0 / iheight),
103 | transforms.CenterCrop(self.output_size),
104 | ])
105 |
106 | rgb_np = transform(rgb)
107 | rgb_np = np.asfarray(rgb_np, dtype='float') / 255
108 |
109 | depth_np = transform(depth_np)
110 | depth_np = np.asfarray(depth_np, dtype='float')
111 |
112 | if self.depth_16:
113 | depth_np = depth_np / self.depth_16_max
114 | else:
115 | depth_np = (255 - depth_np) / 255
116 |
117 | return rgb_np, depth_np
118 |
119 | def load_rgb(self, index):
120 | return imread(self.rgb_files[index], as_gray=False, pilmode="RGB")
121 |
122 | def load_depth(self, index):
123 | if self.depth_16:
124 | depth = imread(self.depth_files[index])
125 | depth[depth == 65535] = 0 # map 'invalid' to 0
126 | return depth
127 | else:
128 | depth = imread(self.depth_files[index], as_gray=False, pilmode="L")
129 | #depth[depth == 0] = 255 # map 0 -> 255
130 | return depth
131 |
132 | def __len__(self):
133 | return len(self.rgb_files)
134 |
135 | def __getitem__(self, index):
136 | rgb = self.load_rgb(index)
137 | depth = self.load_depth(index)
138 |
139 | #print(self.depth_files[index] + str(depth.shape))
140 | #print(depth)
141 |
142 | # apply train/val transforms
143 | if self.transform is not None:
144 | rgb_np, depth_np = self.transform(rgb, depth)
145 | else:
146 | raise(RuntimeError("transform not defined"))
147 |
148 | # convert from numpy to torch tensors
149 | input_tensor = to_tensor(rgb_np)
150 |
151 | while input_tensor.dim() < 3:
152 | input_tensor = input_tensor.unsqueeze(0)
153 |
154 | depth_tensor = to_tensor(depth_np)
155 | depth_tensor = depth_tensor.unsqueeze(0)
156 |
157 | #print("{:04d} rgb = ".format(index) + str(input_tensor.shape))
158 | #print("{:04d} depth = ".format(index) + str(depth_tensor.shape))
159 |
160 | return input_tensor, depth_tensor
161 |
162 |
--------------------------------------------------------------------------------
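A minimal usage sketch for the dataloader above, assuming it is the ZEDDataset class that main.py imports from dataloaders/zed_dataloader.py (the left{:06d}.png / depth{:06d}.png naming matches ZED captures); the data path follows the data/<name>/train layout used by main.py, and the shapes are illustrative:

    # illustrative sketch only -- not part of the repository
    import torch
    from dataloaders.zed_dataloader import ZEDDataset

    train_dataset = ZEDDataset('data/zed/train', type='train')  # recursively pairs left*.png / depth*.png
    rgb, depth = train_dataset[0]                               # rgb: 3xHxW in [0,1], depth: 1xHxW (normalized)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True,
                                               num_workers=4, pin_memory=True)
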
/imagenet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dusty-nv/pytorch-depth/41d6440dc0a64a4c59dff3daaaea50a1212897b1/imagenet/__init__.py
--------------------------------------------------------------------------------
/imagenet/mobilenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import time
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.parallel
8 | import torch.backends.cudnn as cudnn
9 | import torch.optim
10 | import torch.utils.data
11 |
12 | class MobileNet(nn.Module):
13 | def __init__(self, relu6=True):
14 | super(MobileNet, self).__init__()
15 |
16 | def relu(relu6):
17 | if relu6:
18 | return nn.ReLU6(inplace=True)
19 | else:
20 | return nn.ReLU(inplace=True)
21 |
22 | def conv_bn(inp, oup, stride, relu6):
23 | return nn.Sequential(
24 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
25 | nn.BatchNorm2d(oup),
26 | relu(relu6),
27 | )
28 |
29 | def conv_dw(inp, oup, stride, relu6):
30 | return nn.Sequential(
31 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
32 | nn.BatchNorm2d(inp),
33 | relu(relu6),
34 |
35 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
36 | nn.BatchNorm2d(oup),
37 | relu(relu6),
38 | )
39 |
40 | self.model = nn.Sequential(
41 | conv_bn( 3, 32, 2, relu6),
42 | conv_dw( 32, 64, 1, relu6),
43 | conv_dw( 64, 128, 2, relu6),
44 | conv_dw(128, 128, 1, relu6),
45 | conv_dw(128, 256, 2, relu6),
46 | conv_dw(256, 256, 1, relu6),
47 | conv_dw(256, 512, 2, relu6),
48 | conv_dw(512, 512, 1, relu6),
49 | conv_dw(512, 512, 1, relu6),
50 | conv_dw(512, 512, 1, relu6),
51 | conv_dw(512, 512, 1, relu6),
52 | conv_dw(512, 512, 1, relu6),
53 | conv_dw(512, 1024, 2, relu6),
54 | conv_dw(1024, 1024, 1, relu6),
55 | nn.AvgPool2d(7),
56 | )
57 | self.fc = nn.Linear(1024, 1000)
58 |
59 | def forward(self, x):
60 | x = self.model(x)
61 | #print('pre-view size: ' + str(x.size()))
62 | x = x.view(-1, 1024)
63 | #print('post-view size: ' + str(x.size()))
64 | x = self.fc(x)
65 | return x
66 |
67 | def main():
68 | import torchvision.models
69 | model = MobileNet(relu6=True)
70 | model = torch.nn.DataParallel(model).cuda()
71 | model_filename = os.path.join('results', 'imagenet.arch=mobilenet.lr=0.1.bs=256', 'model_best.pth.tar')
72 | if os.path.isfile(model_filename):
73 | print("=> loading Imagenet pretrained model '{}'".format(model_filename))
74 | checkpoint = torch.load(model_filename)
75 | epoch = checkpoint['epoch']
76 | best_prec1 = checkpoint['best_prec1']
77 | model.load_state_dict(checkpoint['state_dict'])
78 | print("=> loaded Imagenet pretrained model '{}' (epoch {}). best_prec1={}".format(model_filename, epoch, best_prec1))
79 |
80 | if __name__ == '__main__':
81 | main()
82 |
--------------------------------------------------------------------------------
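The classifier above downsamples by a factor of 32 (one stride-2 conv_bn followed by five stride-2 conv_dw stages) before the AvgPool2d(7), so it expects 224x224 inputs. A minimal shape-check sketch (CPU is fine here):

    # illustrative sketch only -- verifies the 224x224 -> 1000-way output shape
    import torch
    from imagenet.mobilenet import MobileNet

    model = MobileNet(relu6=True).eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)  # torch.Size([1, 1000])
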
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import csv
4 | import numpy as np
5 |
6 | import torch
7 | import torch.backends.cudnn as cudnn
8 | import torch.optim
9 |
10 | cudnn.benchmark = True
11 |
12 | from models import ResNet
13 | from models_fast import MobileNetSkipAdd
14 | from metrics import AverageMeter, Result
15 | from dataloaders.dense_to_sparse import UniformSampling, SimulatedStereo
16 |
17 | import criteria
18 | import utils
19 |
20 | args = utils.parse_command()
21 | print(args)
22 |
23 | fieldnames = ['mse', 'rmse', 'absrel', 'lg10', 'mae',
24 | 'delta1', 'delta2', 'delta3',
25 | 'data_time', 'gpu_time']
26 | best_result = Result()
27 | best_result.set_to_worst()
28 |
29 | def create_data_loaders(args):
30 | # Data loading code
31 | print("=> creating data loaders ...")
32 | traindir = os.path.join('data', args.data, 'train')
33 | valdir = os.path.join('data', args.data, 'val')
34 | train_loader = None
35 | val_loader = None
36 |
37 | # sparsifier is a class for generating random sparse depth input from the ground truth
38 | sparsifier = None
39 | max_depth = args.max_depth if args.max_depth >= 0.0 else np.inf
40 | if args.sparsifier == UniformSampling.name:
41 | sparsifier = UniformSampling(num_samples=args.num_samples, max_depth=max_depth)
42 | elif args.sparsifier == SimulatedStereo.name:
43 | sparsifier = SimulatedStereo(num_samples=args.num_samples, max_depth=max_depth)
44 |
45 | if args.data == 'nyudepthv2':
46 | from dataloaders.nyu_dataloader import NYUDataset
47 | if not args.evaluate:
48 | train_dataset = NYUDataset(traindir, type='train',
49 | modality=args.modality, sparsifier=sparsifier)
50 | val_dataset = NYUDataset(valdir, type='val',
51 | modality=args.modality, sparsifier=sparsifier)
52 |
53 | elif args.data == 'kitti':
54 | from dataloaders.kitti_dataloader import KITTIDataset
55 | if not args.evaluate:
56 | train_dataset = KITTIDataset(traindir, type='train',
57 | modality=args.modality, sparsifier=sparsifier)
58 | val_dataset = KITTIDataset(valdir, type='val',
59 | modality=args.modality, sparsifier=sparsifier)
60 |
61 | elif args.data == 'deepscene':
62 | from dataloaders.deepscene_dataloader import DeepSceneDataset
63 | if not args.evaluate:
64 | train_dataset = DeepSceneDataset(traindir, type='train')
65 |
66 | val_dataset = DeepSceneDataset(valdir, type='val')
67 |
68 | elif args.data == 'sun':
69 | from dataloaders.sun_dataloader import SunRGBDDataset
70 | if not args.evaluate:
71 | train_dataset = SunRGBDDataset(traindir, type='train')
72 |
73 | val_dataset = SunRGBDDataset(valdir, type='val')
74 |
75 | elif args.data == 'zed':
76 | from dataloaders.zed_dataloader import ZEDDataset
77 | if not args.evaluate:
78 | train_dataset = ZEDDataset(traindir, type='train')
79 |
80 | val_dataset = ZEDDataset(valdir, type='val')
81 |
82 |
83 | else:
84 |         raise RuntimeError('Dataset not found. ' +
85 |                            'The dataset must be one of: nyudepthv2, kitti, deepscene, sun, zed.')
86 |
87 | # set batch size to be 1 for validation
88 | val_loader = torch.utils.data.DataLoader(val_dataset,
89 | batch_size=1, shuffle=False, num_workers=args.workers, pin_memory=True)
90 |
91 |     # construct the train loader here so that evaluation-only runs can skip it
92 | if not args.evaluate:
93 | train_loader = torch.utils.data.DataLoader(
94 | train_dataset, batch_size=args.batch_size, shuffle=True,
95 | num_workers=args.workers, pin_memory=True, sampler=None,
96 | worker_init_fn=lambda work_id:np.random.seed(work_id))
97 | # worker_init_fn ensures different sampling patterns for each data loading thread
98 |
99 | print("=> data loaders created.")
100 | return train_loader, val_loader
101 |
102 | def main():
103 | global args, best_result, output_directory, train_csv, test_csv
104 |
105 | # evaluation mode
106 | start_epoch = 0
107 | if args.evaluate:
108 | assert os.path.isfile(args.evaluate), \
109 | "=> no best model found at '{}'".format(args.evaluate)
110 | print("=> loading best model '{}'".format(args.evaluate))
111 | checkpoint = torch.load(args.evaluate)
112 | output_directory = os.path.dirname(args.evaluate)
113 | args = checkpoint['args']
114 | start_epoch = checkpoint['epoch'] + 1
115 | best_result = checkpoint['best_result']
116 | model = checkpoint['model']
117 | print("=> loaded best model (epoch {})".format(checkpoint['epoch']))
118 | _, val_loader = create_data_loaders(args)
119 | args.evaluate = True
120 | validate(val_loader, model, checkpoint['epoch'], write_to_file=False)
121 | return
122 |
123 | # export to ONNX
124 | elif args.export:
125 | assert os.path.isfile(args.export), \
126 | "=> no best model found at '{}'".format(args.export)
127 | print("=> loading best model '{}'".format(args.export))
128 | checkpoint = torch.load(args.export)
129 | output_directory = os.path.dirname(args.export)
130 | output_filename = args.export + '.onnx'
131 | args = checkpoint['args']
132 | start_epoch = checkpoint['epoch'] + 1
133 | best_result = checkpoint['best_result']
134 | model = checkpoint['model']
135 | model.export = True
136 | print("=> loaded best model (epoch {})".format(checkpoint['epoch']))
137 | export(model, output_filename, args.data)
138 | return
139 |
140 | # optionally resume from a checkpoint
141 | elif args.resume:
142 | chkpt_path = args.resume
143 | assert os.path.isfile(chkpt_path), \
144 | "=> no checkpoint found at '{}'".format(chkpt_path)
145 | print("=> loading checkpoint '{}'".format(chkpt_path))
146 | checkpoint = torch.load(chkpt_path)
147 | args = checkpoint['args']
148 | start_epoch = checkpoint['epoch'] + 1
149 | best_result = checkpoint['best_result']
150 | model = checkpoint['model']
151 | optimizer = checkpoint['optimizer']
152 | output_directory = os.path.dirname(os.path.abspath(chkpt_path))
153 | print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
154 | train_loader, val_loader = create_data_loaders(args)
155 | args.resume = True
156 |
157 | # transfer learning from a checkpoint
158 | elif args.checkpoint:
159 | chkpt_path = args.checkpoint
160 | assert os.path.isfile(chkpt_path), \
161 | "=> no checkpoint found at '{}'".format(chkpt_path)
162 | print("=> loading checkpoint '{}'".format(chkpt_path))
163 | checkpoint = torch.load(chkpt_path)
164 | model = checkpoint['model']
165 | optimizer = checkpoint['optimizer']
166 | print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
167 | train_loader, val_loader = create_data_loaders(args)
168 |
169 | # create new model
170 | else:
171 | train_loader, val_loader = create_data_loaders(args)
172 | print("=> creating Model ({}-{}) ...".format(args.arch, args.decoder))
173 | in_channels = len(args.modality)
174 | if args.arch == 'resnet50':
175 | model = ResNet(layers=50, decoder=args.decoder, output_size=train_loader.dataset.output_size,
176 | in_channels=in_channels, pretrained=args.pretrained)
177 | elif args.arch == 'resnet18':
178 | model = ResNet(layers=18, decoder=args.decoder, output_size=train_loader.dataset.output_size,
179 | in_channels=in_channels, pretrained=args.pretrained)
180 | elif args.arch == 'mobilenet':
181 | model = MobileNetSkipAdd(output_size=train_loader.dataset.output_size,
182 | pretrained=args.pretrained)
183 |
184 | print("=> model created " + str(train_loader.dataset.output_size))
185 |
186 | optimizer = torch.optim.SGD(model.parameters(), args.lr, \
187 | momentum=args.momentum, weight_decay=args.weight_decay)
188 |
189 | # model = torch.nn.DataParallel(model).cuda() # for multi-gpu training
190 | model = model.cuda()
191 |
192 | # define loss function (criterion) and optimizer
193 | if args.criterion == 'l2':
194 | criterion = criteria.MaskedMSELoss().cuda()
195 | elif args.criterion == 'l1':
196 | criterion = criteria.MaskedL1Loss().cuda()
197 |
198 |     # create the results folder if it does not already exist
199 | output_directory = utils.get_output_directory(args)
200 | if not os.path.exists(output_directory):
201 | os.makedirs(output_directory)
202 | train_csv = os.path.join(output_directory, 'train.csv')
203 | test_csv = os.path.join(output_directory, 'test.csv')
204 | best_txt = os.path.join(output_directory, 'best.txt')
205 |
206 | # create new csv files with only header
207 | if not args.resume:
208 | with open(train_csv, 'w') as csvfile:
209 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
210 | writer.writeheader()
211 | with open(test_csv, 'w') as csvfile:
212 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
213 | writer.writeheader()
214 |
215 | for epoch in range(start_epoch, args.epochs):
216 | utils.adjust_learning_rate(optimizer, epoch, args.lr)
217 | train(train_loader, model, criterion, optimizer, epoch) # train for one epoch
218 | result, img_merge = validate(val_loader, model, epoch) # evaluate on validation set
219 |
220 | # remember best rmse and save checkpoint
221 | is_best = result.rmse < best_result.rmse
222 | if is_best:
223 | best_result = result
224 | with open(best_txt, 'w') as txtfile:
225 | txtfile.write("epoch={}\nmse={:.3f}\nrmse={:.3f}\nabsrel={:.3f}\nlg10={:.3f}\nmae={:.3f}\ndelta1={:.3f}\nt_gpu={:.4f}\n".
226 | format(epoch, result.mse, result.rmse, result.absrel, result.lg10, result.mae, result.delta1, result.gpu_time))
227 | if img_merge is not None:
228 | img_filename = output_directory + '/comparison_best.png'
229 | utils.save_image(img_merge, img_filename)
230 |
231 | utils.save_checkpoint({
232 | 'args': args,
233 | 'epoch': epoch,
234 | 'arch': args.arch,
235 | 'model': model,
236 | 'best_result': best_result,
237 | 'optimizer' : optimizer,
238 | }, is_best, epoch, output_directory)
239 |
240 |
241 | def train(train_loader, model, criterion, optimizer, epoch):
242 | average_meter = AverageMeter()
243 | model.train() # switch to train mode
244 | end = time.time()
245 | for i, (input, target) in enumerate(train_loader):
246 |
247 | input, target = input.cuda(), target.cuda()
248 | torch.cuda.synchronize()
249 | data_time = time.time() - end
250 |
251 | # compute pred
252 | end = time.time()
253 | pred = model(input)
254 | loss = criterion(pred, target)
255 | optimizer.zero_grad()
256 |         loss.backward()  # compute gradients
257 |         optimizer.step() # take an SGD step
258 | torch.cuda.synchronize()
259 | gpu_time = time.time() - end
260 |
261 | #print('input size: ' + str(input.size()))
262 | #print('output size: ' + str(pred.size()))
263 |
264 | # measure accuracy and record loss
265 | result = Result()
266 | result.evaluate(pred.data, target.data)
267 | average_meter.update(result, gpu_time, data_time, input.size(0))
268 | end = time.time()
269 |
270 | if (i + 1) % args.print_freq == 0:
271 | print('=> output: {}'.format(output_directory))
272 | print('Train Epoch: {0} [{1}/{2}]\t'
273 | 't_Data={data_time:.3f}({average.data_time:.3f}) '
274 | 't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\n\t'
275 | 'RMSE={result.rmse:.2f}({average.rmse:.2f}) '
276 | 'MAE={result.mae:.2f}({average.mae:.2f}) '
277 | 'Delta1={result.delta1:.3f}({average.delta1:.3f}) '
278 | 'REL={result.absrel:.3f}({average.absrel:.3f}) '
279 | 'Lg10={result.lg10:.3f}({average.lg10:.3f}) '.format(
280 | epoch, i+1, len(train_loader), data_time=data_time,
281 | gpu_time=gpu_time, result=result, average=average_meter.average()))
282 |
283 | avg = average_meter.average()
284 | with open(train_csv, 'a') as csvfile:
285 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
286 | writer.writerow({'mse': avg.mse, 'rmse': avg.rmse, 'absrel': avg.absrel, 'lg10': avg.lg10,
287 | 'mae': avg.mae, 'delta1': avg.delta1, 'delta2': avg.delta2, 'delta3': avg.delta3,
288 | 'gpu_time': avg.gpu_time, 'data_time': avg.data_time})
289 |
290 |
291 | def validate(val_loader, model, epoch, write_to_file=True):
292 | average_meter = AverageMeter()
293 | model.eval() # switch to evaluate mode
294 | end = time.time()
295 | for i, (input, target) in enumerate(val_loader):
296 | input, target = input.cuda(), target.cuda()
297 | torch.cuda.synchronize()
298 | data_time = time.time() - end
299 |
300 | # compute output
301 | end = time.time()
302 | with torch.no_grad():
303 | pred = model(input)
304 | torch.cuda.synchronize()
305 | gpu_time = time.time() - end
306 |
307 | #print('input size: ' + str(input.size()))
308 | #print('output size: ' + str(pred.size()))
309 |
310 | # measure accuracy and record loss
311 | result = Result()
312 | result.evaluate(pred.data, target.data)
313 | average_meter.update(result, gpu_time, data_time, input.size(0))
314 | end = time.time()
315 |
316 | # save 8 images for visualization
317 | skip = 10 if args.data == 'deepscene' else 50
318 | if args.modality == 'd':
319 | img_merge = None
320 | else:
321 | if args.modality == 'rgb':
322 | rgb = input
323 | elif args.modality == 'rgbd':
324 | rgb = input[:,:3,:,:]
325 | depth = input[:,3:,:,:]
326 |
327 | if i == 0:
328 | if args.modality == 'rgbd':
329 | img_merge = utils.merge_into_row_with_gt(rgb, depth, target, pred)
330 | else:
331 | img_merge = utils.merge_into_row(rgb, target, pred)
332 | elif (i < 8*skip) and (i % skip == 0):
333 | if args.modality == 'rgbd':
334 | row = utils.merge_into_row_with_gt(rgb, depth, target, pred)
335 | else:
336 | row = utils.merge_into_row(rgb, target, pred)
337 | img_merge = utils.add_row(img_merge, row)
338 | elif i == 8*skip:
339 | filename = output_directory + '/comparison_' + str(epoch) + '.png'
340 | utils.save_image(img_merge, filename)
341 |
342 | if (i+1) % args.print_freq == 0:
343 | print('Test: [{0}/{1}]\t'
344 | 't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\n\t'
345 | 'RMSE={result.rmse:.2f}({average.rmse:.2f}) '
346 | 'MAE={result.mae:.2f}({average.mae:.2f}) '
347 | 'Delta1={result.delta1:.3f}({average.delta1:.3f}) '
348 | 'REL={result.absrel:.3f}({average.absrel:.3f}) '
349 | 'Lg10={result.lg10:.3f}({average.lg10:.3f}) '.format(
350 | i+1, len(val_loader), gpu_time=gpu_time, result=result, average=average_meter.average()))
351 |
352 | avg = average_meter.average()
353 |
354 | print('\n*\n'
355 | 'RMSE={average.rmse:.3f}\n'
356 | 'MAE={average.mae:.3f}\n'
357 | 'Delta1={average.delta1:.3f}\n'
358 | 'REL={average.absrel:.3f}\n'
359 | 'Lg10={average.lg10:.3f}\n'
360 | 't_GPU={time:.3f}\n'.format(
361 | average=avg, time=avg.gpu_time))
362 |
363 | if write_to_file:
364 | with open(test_csv, 'a') as csvfile:
365 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
366 | writer.writerow({'mse': avg.mse, 'rmse': avg.rmse, 'absrel': avg.absrel, 'lg10': avg.lg10,
367 | 'mae': avg.mae, 'delta1': avg.delta1, 'delta2': avg.delta2, 'delta3': avg.delta3,
368 | 'data_time': avg.data_time, 'gpu_time': avg.gpu_time})
369 | return avg, img_merge
370 |
371 |
372 | # export model to ONNX
373 | def export(model, path, dataset):
374 | print('=> exporting ONNX model to: ' + path)
375 | model.eval()
376 |
377 | # set the input size from the dataset
378 |     input_size = (1, 3, 448, 448)   # alternatives: (1, 3, 224, 224), (1, 3, 480, 640), (1, 3, 228, 304) for nyudepthv2
379 |
380 | if dataset == "kitti":
381 | input_size = (1, 3, 228, 912)
382 |
383 | input = torch.ones(input_size).cuda()
384 | print('=> input resolution: ' + str(input_size))
385 |
386 | # set the input/output layer names
387 | input_names = [ "input_0" ]
388 | output_names = [ "output_0" ]
389 |
390 | print(model)
391 |
392 | # export the model
393 | torch.onnx.export(model, input, path, verbose=True, input_names=input_names, output_names=output_names)
394 | print('=> ONNX model exported to: ' + path)
395 |
396 | if __name__ == '__main__':
397 | main()
398 |
--------------------------------------------------------------------------------
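Checkpoints written by utils.save_checkpoint() above contain the keys 'args', 'epoch', 'arch', 'model', 'best_result' and 'optimizer', with the full nn.Module pickled rather than a bare state_dict. A sketch of inspecting one offline; the file name and path are hypothetical, since they depend on utils.get_output_directory() and utils.save_checkpoint():

    # illustrative sketch only -- the checkpoint path is hypothetical
    import torch

    checkpoint = torch.load('results/<run_name>/model_best.pth.tar', map_location='cpu')
    model = checkpoint['model']                      # full nn.Module, ready for model.eval()
    print(checkpoint['epoch'], checkpoint['arch'])
    print('best RMSE:', checkpoint['best_result'].rmse)
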
/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import numpy as np
4 |
5 | def log10(x):
6 |     """Return a new tensor with the base-10 logarithm of the elements of x."""
7 | return torch.log(x) / math.log(10)
8 |
9 | class Result(object):
10 | def __init__(self):
11 | self.irmse, self.imae = 0, 0
12 | self.mse, self.rmse, self.mae = 0, 0, 0
13 | self.absrel, self.lg10 = 0, 0
14 | self.delta1, self.delta2, self.delta3 = 0, 0, 0
15 | self.data_time, self.gpu_time = 0, 0
16 |
17 | def set_to_worst(self):
18 | self.irmse, self.imae = np.inf, np.inf
19 | self.mse, self.rmse, self.mae = np.inf, np.inf, np.inf
20 | self.absrel, self.lg10 = np.inf, np.inf
21 | self.delta1, self.delta2, self.delta3 = 0, 0, 0
22 | self.data_time, self.gpu_time = 0, 0
23 |
24 | def update(self, irmse, imae, mse, rmse, mae, absrel, lg10, delta1, delta2, delta3, gpu_time, data_time):
25 | self.irmse, self.imae = irmse, imae
26 | self.mse, self.rmse, self.mae = mse, rmse, mae
27 | self.absrel, self.lg10 = absrel, lg10
28 | self.delta1, self.delta2, self.delta3 = delta1, delta2, delta3
29 | self.data_time, self.gpu_time = data_time, gpu_time
30 |
31 | def evaluate(self, output, target):
32 | valid_mask = target>0
33 | output = output[valid_mask]
34 | target = target[valid_mask]
35 |
36 | abs_diff = (output - target).abs()
37 |
38 | self.mse = float((torch.pow(abs_diff, 2)).mean())
39 | self.rmse = math.sqrt(self.mse)
40 | self.mae = float(abs_diff.mean())
41 | self.lg10 = float((log10(output) - log10(target)).abs().mean())
42 | self.absrel = float((abs_diff / target).mean())
43 |
44 | maxRatio = torch.max(output / target, target / output)
45 | self.delta1 = float((maxRatio < 1.25).float().mean())
46 | self.delta2 = float((maxRatio < 1.25 ** 2).float().mean())
47 | self.delta3 = float((maxRatio < 1.25 ** 3).float().mean())
48 | self.data_time = 0
49 | self.gpu_time = 0
50 |
51 | inv_output = 1 / output
52 | inv_target = 1 / target
53 | abs_inv_diff = (inv_output - inv_target).abs()
54 | self.irmse = math.sqrt((torch.pow(abs_inv_diff, 2)).mean())
55 | self.imae = float(abs_inv_diff.mean())
56 |
57 |
58 | class AverageMeter(object):
59 | def __init__(self):
60 | self.reset()
61 |
62 | def reset(self):
63 | self.count = 0.0
64 |
65 | self.sum_irmse, self.sum_imae = 0, 0
66 | self.sum_mse, self.sum_rmse, self.sum_mae = 0, 0, 0
67 | self.sum_absrel, self.sum_lg10 = 0, 0
68 | self.sum_delta1, self.sum_delta2, self.sum_delta3 = 0, 0, 0
69 | self.sum_data_time, self.sum_gpu_time = 0, 0
70 |
71 | def update(self, result, gpu_time, data_time, n=1):
72 | self.count += n
73 |
74 | self.sum_irmse += n*result.irmse
75 | self.sum_imae += n*result.imae
76 | self.sum_mse += n*result.mse
77 | self.sum_rmse += n*result.rmse
78 | self.sum_mae += n*result.mae
79 | self.sum_absrel += n*result.absrel
80 | self.sum_lg10 += n*result.lg10
81 | self.sum_delta1 += n*result.delta1
82 | self.sum_delta2 += n*result.delta2
83 | self.sum_delta3 += n*result.delta3
84 | self.sum_data_time += n*data_time
85 | self.sum_gpu_time += n*gpu_time
86 |
87 | def average(self):
88 | avg = Result()
89 | avg.update(
90 | self.sum_irmse / self.count, self.sum_imae / self.count,
91 | self.sum_mse / self.count, self.sum_rmse / self.count, self.sum_mae / self.count,
92 | self.sum_absrel / self.count, self.sum_lg10 / self.count,
93 | self.sum_delta1 / self.count, self.sum_delta2 / self.count, self.sum_delta3 / self.count,
94 | self.sum_gpu_time / self.count, self.sum_data_time / self.count)
95 | return avg
--------------------------------------------------------------------------------
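Result.evaluate() above only scores pixels where the ground truth is positive (valid_mask = target > 0), and deltaN is the fraction of valid pixels whose prediction is within a factor of 1.25**N of the target. A tiny worked sketch with synthetic values:

    # illustrative sketch only -- one of the three valid pixels exceeds the 1.25 ratio, so delta1 = 2/3
    import torch
    from metrics import Result

    target = torch.tensor([1.0, 2.0, 0.0, 4.0])  # the 0.0 entry is invalid and ignored
    pred   = torch.tensor([1.1, 2.6, 9.9, 4.0])  # 2.6 / 2.0 = 1.3 > 1.25

    result = Result()
    result.evaluate(pred, target)
    print(result.delta1)  # ~0.667
    print(result.rmse)    # RMSE over the three valid pixels only
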
/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torchvision.models
5 | import collections
6 | import math
7 |
8 | class Unpool(nn.Module):
9 | # Unpool: 2*2 unpooling with zero padding
10 | def __init__(self, num_channels, stride=2):
11 | super(Unpool, self).__init__()
12 |
13 | self.num_channels = num_channels
14 | self.stride = stride
15 |
16 | # create kernel [1, 0; 0, 0]
17 | self.weights = torch.autograd.Variable(torch.zeros(num_channels, 1, stride, stride).cuda()) # currently not compatible with running on CPU
18 | self.weights[:,:,0,0] = 1
19 |
20 | def forward(self, x):
21 | return F.conv_transpose2d(x, self.weights, stride=self.stride, groups=self.num_channels)
22 |
23 | def weights_init(m):
24 | # Initialize filters with Gaussian random weights
25 | if isinstance(m, nn.Conv2d):
26 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
27 | m.weight.data.normal_(0, math.sqrt(2. / n))
28 | if m.bias is not None:
29 | m.bias.data.zero_()
30 | elif isinstance(m, nn.ConvTranspose2d):
31 | n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
32 | m.weight.data.normal_(0, math.sqrt(2. / n))
33 | if m.bias is not None:
34 | m.bias.data.zero_()
35 | elif isinstance(m, nn.BatchNorm2d):
36 | m.weight.data.fill_(1)
37 | m.bias.data.zero_()
38 |
39 | class Decoder(nn.Module):
40 | # Decoder is the base class for all decoders
41 |
42 | names = ['deconv2', 'deconv3', 'upconv', 'upproj']
43 |
44 | def __init__(self):
45 | super(Decoder, self).__init__()
46 |
47 | self.layer1 = None
48 | self.layer2 = None
49 | self.layer3 = None
50 | self.layer4 = None
51 |
52 | def forward(self, x):
53 | x = self.layer1(x)
54 | x = self.layer2(x)
55 | x = self.layer3(x)
56 | x = self.layer4(x)
57 | return x
58 |
59 | class DeConv(Decoder):
60 | def __init__(self, in_channels, kernel_size):
61 | assert kernel_size>=2, "kernel_size out of range: {}".format(kernel_size)
62 | super(DeConv, self).__init__()
63 |
64 | def convt(in_channels):
65 | stride = 2
66 | padding = (kernel_size - 1) // 2
67 | output_padding = kernel_size % 2
68 | assert -2 - 2*padding + kernel_size + output_padding == 0, "deconv parameters incorrect"
69 |
70 | module_name = "deconv{}".format(kernel_size)
71 | return nn.Sequential(collections.OrderedDict([
72 | (module_name, nn.ConvTranspose2d(in_channels,in_channels//2,kernel_size,
73 | stride,padding,output_padding,bias=False)),
74 | ('batchnorm', nn.BatchNorm2d(in_channels//2)),
75 | ('relu', nn.ReLU(inplace=True)),
76 | ]))
77 |
78 | self.layer1 = convt(in_channels)
79 | self.layer2 = convt(in_channels // 2)
80 | self.layer3 = convt(in_channels // (2 ** 2))
81 | self.layer4 = convt(in_channels // (2 ** 3))
82 |
83 | class UpConv(Decoder):
84 | # UpConv decoder consists of 4 upconv modules with decreasing number of channels and increasing feature map size
85 | def upconv_module(self, in_channels):
86 | # UpConv module: unpool -> 5*5 conv -> batchnorm -> ReLU
87 | upconv = nn.Sequential(collections.OrderedDict([
88 | ('unpool', Unpool(in_channels)),
89 | ('conv', nn.Conv2d(in_channels,in_channels//2,kernel_size=5,stride=1,padding=2,bias=False)),
90 | ('batchnorm', nn.BatchNorm2d(in_channels//2)),
91 | ('relu', nn.ReLU()),
92 | ]))
93 | return upconv
94 |
95 | def __init__(self, in_channels):
96 | super(UpConv, self).__init__()
97 | self.layer1 = self.upconv_module(in_channels)
98 | self.layer2 = self.upconv_module(in_channels//2)
99 | self.layer3 = self.upconv_module(in_channels//4)
100 | self.layer4 = self.upconv_module(in_channels//8)
101 |
102 | class UpProj(Decoder):
103 | # UpProj decoder consists of 4 upproj modules with decreasing number of channels and increasing feature map size
104 |
105 | class UpProjModule(nn.Module):
106 |         # UpProj module has two branches, with an Unpool at the start and a ReLU at the end
107 | # upper branch: 5*5 conv -> batchnorm -> ReLU -> 3*3 conv -> batchnorm
108 | # bottom branch: 5*5 conv -> batchnorm
109 |
110 | def __init__(self, in_channels):
111 | super(UpProj.UpProjModule, self).__init__()
112 | out_channels = in_channels//2
113 | self.unpool = Unpool(in_channels)
114 | self.upper_branch = nn.Sequential(collections.OrderedDict([
115 | ('conv1', nn.Conv2d(in_channels,out_channels,kernel_size=5,stride=1,padding=2,bias=False)),
116 | ('batchnorm1', nn.BatchNorm2d(out_channels)),
117 | ('relu', nn.ReLU()),
118 | ('conv2', nn.Conv2d(out_channels,out_channels,kernel_size=3,stride=1,padding=1,bias=False)),
119 | ('batchnorm2', nn.BatchNorm2d(out_channels)),
120 | ]))
121 | self.bottom_branch = nn.Sequential(collections.OrderedDict([
122 | ('conv', nn.Conv2d(in_channels,out_channels,kernel_size=5,stride=1,padding=2,bias=False)),
123 | ('batchnorm', nn.BatchNorm2d(out_channels)),
124 | ]))
125 | self.relu = nn.ReLU()
126 |
127 | def forward(self, x):
128 | x = self.unpool(x)
129 | x1 = self.upper_branch(x)
130 | x2 = self.bottom_branch(x)
131 | x = x1 + x2
132 | x = self.relu(x)
133 | return x
134 |
135 | def __init__(self, in_channels):
136 | super(UpProj, self).__init__()
137 | self.layer1 = self.UpProjModule(in_channels)
138 | self.layer2 = self.UpProjModule(in_channels//2)
139 | self.layer3 = self.UpProjModule(in_channels//4)
140 | self.layer4 = self.UpProjModule(in_channels//8)
141 |
142 | def choose_decoder(decoder, in_channels):
143 | # iheight, iwidth = 10, 8
144 | if decoder[:6] == 'deconv':
145 | assert len(decoder)==7
146 | kernel_size = int(decoder[6])
147 | return DeConv(in_channels, kernel_size)
148 | elif decoder == "upproj":
149 | return UpProj(in_channels)
150 | elif decoder == "upconv":
151 | return UpConv(in_channels)
152 | else:
153 | assert False, "invalid option for decoder: {}".format(decoder)
154 |
155 |
156 | class ResNet(nn.Module):
157 | def __init__(self, layers, decoder, output_size, in_channels=3, pretrained=True, export=False):
158 |
159 | if layers not in [18, 34, 50, 101, 152]:
160 |             raise RuntimeError('Only 18, 34, 50, 101, and 152 layer models are defined for ResNet. Got {}'.format(layers))
161 |
162 | super(ResNet, self).__init__()
163 | pretrained_model = torchvision.models.__dict__['resnet{}'.format(layers)](pretrained=pretrained)
164 |
165 | if in_channels == 3:
166 | self.conv1 = pretrained_model._modules['conv1']
167 | self.bn1 = pretrained_model._modules['bn1']
168 | else:
169 | self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
170 | self.bn1 = nn.BatchNorm2d(64)
171 | weights_init(self.conv1)
172 | weights_init(self.bn1)
173 |
174 | self.output_size = output_size
175 | self.export = export
176 |
177 | self.relu = pretrained_model._modules['relu']
178 | self.maxpool = pretrained_model._modules['maxpool']
179 | self.layer1 = pretrained_model._modules['layer1']
180 | self.layer2 = pretrained_model._modules['layer2']
181 | self.layer3 = pretrained_model._modules['layer3']
182 | self.layer4 = pretrained_model._modules['layer4']
183 |
184 | # clear memory
185 | del pretrained_model
186 |
187 | # define number of intermediate channels
188 | if layers <= 34:
189 | num_channels = 512
190 | elif layers >= 50:
191 | num_channels = 2048
192 |
193 | self.conv2 = nn.Conv2d(num_channels,num_channels//2,kernel_size=1,bias=False)
194 | self.bn2 = nn.BatchNorm2d(num_channels//2)
195 | self.decoder = choose_decoder(decoder, num_channels//2)
196 |
197 | # setting bias=true doesn't improve accuracy
198 | self.conv3 = nn.Conv2d(num_channels//32,1,kernel_size=3,stride=1,padding=1,bias=False)
199 | self.bilinear = nn.Upsample(size=self.output_size, mode='bilinear', align_corners=True)
200 |
201 | # weight init
202 | self.conv2.apply(weights_init)
203 | self.bn2.apply(weights_init)
204 | self.decoder.apply(weights_init)
205 | self.conv3.apply(weights_init)
206 |
207 | def forward(self, x):
208 | # resnet
209 | x = self.conv1(x)
210 | x = self.bn1(x)
211 | x = self.relu(x)
212 | x = self.maxpool(x)
213 | x = self.layer1(x)
214 | x = self.layer2(x)
215 | x = self.layer3(x)
216 | x = self.layer4(x)
217 |
218 | x = self.conv2(x)
219 | x = self.bn2(x)
220 |
221 | # decoder
222 | x = self.decoder(x)
223 | x = self.conv3(x)
224 |
225 | if not hasattr(self, 'export') or not self.export:
226 | x = self.bilinear(x) # comment out for --export to ONNX mode
227 |
228 | return x
229 |
--------------------------------------------------------------------------------
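In the ResNet model above the encoder reduces the input to 1/32 resolution, the four-stage decoder upsamples by 16x, and the final bilinear layer interpolates to output_size, so predictions come out at whatever size the dataloader reports. A minimal forward-pass sketch; a CUDA GPU is required because Unpool builds its kernel with .cuda(), and the 228x304 resolution is the nyudepthv2 size noted in main.py:

    # illustrative sketch only -- ResNet-18 encoder with the 'upproj' decoder
    import torch
    from models import ResNet

    model = ResNet(layers=18, decoder='upproj', output_size=(228, 304),
                   in_channels=3, pretrained=False).cuda().eval()
    with torch.no_grad():
        depth = model(torch.randn(1, 3, 228, 304).cuda())
    print(depth.shape)  # torch.Size([1, 1, 228, 304])
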
/models_fast.py:
--------------------------------------------------------------------------------
1 | #
2 | # these are the models from FastDepth, which use the training code from sparse-to-dense:
3 | #
4 | # - https://github.com/dwofk/fast-depth/blob/master/models.py
5 | # - https://github.com/dwofk/fast-depth/issues/3#issuecomment-510545490
6 | #
7 | import os
8 | import torch
9 | import torch.nn as nn
10 | import torchvision.models
11 | import collections
12 | import math
13 | import torch.nn.functional as F
14 | import imagenet.mobilenet
15 |
16 | class Identity(nn.Module):
17 | # a dummy identity module
18 | def __init__(self):
19 | super(Identity, self).__init__()
20 |
21 | def forward(self, x):
22 | return x
23 |
24 | class Unpool(nn.Module):
25 | # Unpool: 2*2 unpooling with zero padding
26 | def __init__(self, stride=2):
27 | super(Unpool, self).__init__()
28 |
29 | self.stride = stride
30 |
31 | # create kernel [1, 0; 0, 0]
32 | self.mask = torch.zeros(1, 1, stride, stride)
33 | self.mask[:,:,0,0] = 1
34 |
35 | def forward(self, x):
36 | assert x.dim() == 4
37 | num_channels = x.size(1)
38 | return F.conv_transpose2d(x,
39 | self.mask.detach().type_as(x).expand(num_channels, 1, -1, -1),
40 | stride=self.stride, groups=num_channels)
41 |
42 | def weights_init(m):
43 | # Initialize kernel weights with Gaussian distributions
44 | if isinstance(m, nn.Conv2d):
45 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
46 | m.weight.data.normal_(0, math.sqrt(2. / n))
47 | if m.bias is not None:
48 | m.bias.data.zero_()
49 | elif isinstance(m, nn.ConvTranspose2d):
50 | n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
51 | m.weight.data.normal_(0, math.sqrt(2. / n))
52 | if m.bias is not None:
53 | m.bias.data.zero_()
54 | elif isinstance(m, nn.BatchNorm2d):
55 | m.weight.data.fill_(1)
56 | m.bias.data.zero_()
57 |
58 | def conv(in_channels, out_channels, kernel_size):
59 | padding = (kernel_size-1) // 2
60 | assert 2*padding == kernel_size-1, "parameters incorrect. kernel={}, padding={}".format(kernel_size, padding)
61 | return nn.Sequential(
62 | nn.Conv2d(in_channels,out_channels,kernel_size,stride=1,padding=padding,bias=False),
63 | nn.BatchNorm2d(out_channels),
64 | nn.ReLU(inplace=True),
65 | )
66 |
67 | def depthwise(in_channels, kernel_size):
68 | padding = (kernel_size-1) // 2
69 | assert 2*padding == kernel_size-1, "parameters incorrect. kernel={}, padding={}".format(kernel_size, padding)
70 | return nn.Sequential(
71 | nn.Conv2d(in_channels,in_channels,kernel_size,stride=1,padding=padding,bias=False,groups=in_channels),
72 | nn.BatchNorm2d(in_channels),
73 | nn.ReLU(inplace=True),
74 | )
75 |
76 | def pointwise(in_channels, out_channels):
77 | return nn.Sequential(
78 | nn.Conv2d(in_channels,out_channels,1,1,0,bias=False),
79 | nn.BatchNorm2d(out_channels),
80 | nn.ReLU(inplace=True),
81 | )
82 |
83 | def convt(in_channels, out_channels, kernel_size):
84 | stride = 2
85 | padding = (kernel_size - 1) // 2
86 | output_padding = kernel_size % 2
87 | assert -2 - 2*padding + kernel_size + output_padding == 0, "deconv parameters incorrect"
88 | return nn.Sequential(
89 | nn.ConvTranspose2d(in_channels,out_channels,kernel_size,
90 | stride,padding,output_padding,bias=False),
91 | nn.BatchNorm2d(out_channels),
92 | nn.ReLU(inplace=True),
93 | )
94 |
95 | def convt_dw(channels, kernel_size):
96 | stride = 2
97 | padding = (kernel_size - 1) // 2
98 | output_padding = kernel_size % 2
99 | assert -2 - 2*padding + kernel_size + output_padding == 0, "deconv parameters incorrect"
100 | return nn.Sequential(
101 | nn.ConvTranspose2d(channels,channels,kernel_size,
102 | stride,padding,output_padding,bias=False,groups=channels),
103 | nn.BatchNorm2d(channels),
104 | nn.ReLU(inplace=True),
105 | )
106 |
107 | def upconv(in_channels, out_channels):
108 | return nn.Sequential(
109 | Unpool(2),
110 | nn.Conv2d(in_channels,out_channels,kernel_size=5,stride=1,padding=2,bias=False),
111 | nn.BatchNorm2d(out_channels),
112 | nn.ReLU(),
113 | )
114 |
115 | class upproj(nn.Module):
116 |     # UpProj module has two branches, with an Unpool at the start and a ReLU at the end
117 | # upper branch: 5*5 conv -> batchnorm -> ReLU -> 3*3 conv -> batchnorm
118 | # bottom branch: 5*5 conv -> batchnorm
119 |
120 | def __init__(self, in_channels, out_channels):
121 | super(upproj, self).__init__()
122 | self.unpool = Unpool(2)
123 | self.branch1 = nn.Sequential(
124 | nn.Conv2d(in_channels,out_channels,kernel_size=5,stride=1,padding=2,bias=False),
125 | nn.BatchNorm2d(out_channels),
126 | nn.ReLU(inplace=True),
127 | nn.Conv2d(out_channels,out_channels,kernel_size=3,stride=1,padding=1,bias=False),
128 | nn.BatchNorm2d(out_channels),
129 | )
130 | self.branch2 = nn.Sequential(
131 | nn.Conv2d(in_channels,out_channels,kernel_size=5,stride=1,padding=2,bias=False),
132 | nn.BatchNorm2d(out_channels),
133 | )
134 |
135 | def forward(self, x):
136 | x = self.unpool(x)
137 | x1 = self.branch1(x)
138 | x2 = self.branch2(x)
139 | return F.relu(x1 + x2)
140 |
141 | class Decoder(nn.Module):
142 | names = ['deconv{}{}'.format(i,dw) for i in range(3,10,2) for dw in ['', 'dw']]
143 | names.append("upconv")
144 | names.append("upproj")
145 | for i in range(3,10,2):
146 | for dw in ['', 'dw']:
147 | names.append("nnconv{}{}".format(i, dw))
148 | names.append("blconv{}{}".format(i, dw))
149 | names.append("shuffle{}{}".format(i, dw))
150 |
151 | class DeConv(nn.Module):
152 |
153 | def __init__(self, kernel_size, dw):
154 | super(DeConv, self).__init__()
155 | if dw:
156 | self.convt1 = nn.Sequential(
157 | convt_dw(1024, kernel_size),
158 | pointwise(1024, 512))
159 | self.convt2 = nn.Sequential(
160 | convt_dw(512, kernel_size),
161 | pointwise(512, 256))
162 | self.convt3 = nn.Sequential(
163 | convt_dw(256, kernel_size),
164 | pointwise(256, 128))
165 | self.convt4 = nn.Sequential(
166 | convt_dw(128, kernel_size),
167 | pointwise(128, 64))
168 | self.convt5 = nn.Sequential(
169 | convt_dw(64, kernel_size),
170 | pointwise(64, 32))
171 | else:
172 | self.convt1 = convt(1024, 512, kernel_size)
173 | self.convt2 = convt(512, 256, kernel_size)
174 | self.convt3 = convt(256, 128, kernel_size)
175 | self.convt4 = convt(128, 64, kernel_size)
176 | self.convt5 = convt(64, 32, kernel_size)
177 | self.convf = pointwise(32, 1)
178 |
179 | def forward(self, x):
180 | x = self.convt1(x)
181 | x = self.convt2(x)
182 | x = self.convt3(x)
183 | x = self.convt4(x)
184 | x = self.convt5(x)
185 | x = self.convf(x)
186 | return x
187 |
188 |
189 | class UpConv(nn.Module):
190 |
191 | def __init__(self):
192 | super(UpConv, self).__init__()
193 | self.upconv1 = upconv(1024, 512)
194 | self.upconv2 = upconv(512, 256)
195 | self.upconv3 = upconv(256, 128)
196 | self.upconv4 = upconv(128, 64)
197 | self.upconv5 = upconv(64, 32)
198 | self.convf = pointwise(32, 1)
199 |
200 | def forward(self, x):
201 | x = self.upconv1(x)
202 | x = self.upconv2(x)
203 | x = self.upconv3(x)
204 | x = self.upconv4(x)
205 | x = self.upconv5(x)
206 | x = self.convf(x)
207 | return x
208 |
209 | class UpProj(nn.Module):
210 |     # UpProj decoder consists of 5 upproj modules with decreasing number of channels and increasing feature map size
211 |
212 | def __init__(self):
213 | super(UpProj, self).__init__()
214 | self.upproj1 = upproj(1024, 512)
215 | self.upproj2 = upproj(512, 256)
216 | self.upproj3 = upproj(256, 128)
217 | self.upproj4 = upproj(128, 64)
218 | self.upproj5 = upproj(64, 32)
219 | self.convf = pointwise(32, 1)
220 |
221 | def forward(self, x):
222 | x = self.upproj1(x)
223 | x = self.upproj2(x)
224 | x = self.upproj3(x)
225 | x = self.upproj4(x)
226 | x = self.upproj5(x)
227 | x = self.convf(x)
228 | return x
229 |
230 | class NNConv(nn.Module):
231 |
232 | def __init__(self, kernel_size, dw):
233 | super(NNConv, self).__init__()
234 | if dw:
235 | self.conv1 = nn.Sequential(
236 | depthwise(1024, kernel_size),
237 | pointwise(1024, 512))
238 | self.conv2 = nn.Sequential(
239 | depthwise(512, kernel_size),
240 | pointwise(512, 256))
241 | self.conv3 = nn.Sequential(
242 | depthwise(256, kernel_size),
243 | pointwise(256, 128))
244 | self.conv4 = nn.Sequential(
245 | depthwise(128, kernel_size),
246 | pointwise(128, 64))
247 | self.conv5 = nn.Sequential(
248 | depthwise(64, kernel_size),
249 | pointwise(64, 32))
250 | self.conv6 = pointwise(32, 1)
251 | else:
252 | self.conv1 = conv(1024, 512, kernel_size)
253 | self.conv2 = conv(512, 256, kernel_size)
254 | self.conv3 = conv(256, 128, kernel_size)
255 | self.conv4 = conv(128, 64, kernel_size)
256 | self.conv5 = conv(64, 32, kernel_size)
257 | self.conv6 = pointwise(32, 1)
258 |
259 | def forward(self, x):
260 | x = self.conv1(x)
261 | x = F.interpolate(x, scale_factor=2, mode='nearest')
262 |
263 | x = self.conv2(x)
264 | x = F.interpolate(x, scale_factor=2, mode='nearest')
265 |
266 | x = self.conv3(x)
267 | x = F.interpolate(x, scale_factor=2, mode='nearest')
268 |
269 | x = self.conv4(x)
270 | x = F.interpolate(x, scale_factor=2, mode='nearest')
271 |
272 | x = self.conv5(x)
273 | x = F.interpolate(x, scale_factor=2, mode='nearest')
274 |
275 | x = self.conv6(x)
276 | return x
277 |
278 | class BLConv(NNConv):
279 |
280 | def __init__(self, kernel_size, dw):
281 | super(BLConv, self).__init__(kernel_size, dw)
282 |
283 | def forward(self, x):
284 | x = self.conv1(x)
285 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
286 |
287 | x = self.conv2(x)
288 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
289 |
290 | x = self.conv3(x)
291 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
292 |
293 | x = self.conv4(x)
294 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
295 |
296 | x = self.conv5(x)
297 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
298 |
299 | x = self.conv6(x)
300 | return x
301 |
302 | class ShuffleConv(nn.Module):
303 |
304 | def __init__(self, kernel_size, dw):
305 | super(ShuffleConv, self).__init__()
306 | if dw:
307 | self.conv1 = nn.Sequential(
308 | depthwise(256, kernel_size),
309 | pointwise(256, 256))
310 | self.conv2 = nn.Sequential(
311 | depthwise(64, kernel_size),
312 | pointwise(64, 64))
313 | self.conv3 = nn.Sequential(
314 | depthwise(16, kernel_size),
315 | pointwise(16, 16))
316 | self.conv4 = nn.Sequential(
317 | depthwise(4, kernel_size),
318 | pointwise(4, 4))
319 | else:
320 | self.conv1 = conv(256, 256, kernel_size)
321 | self.conv2 = conv(64, 64, kernel_size)
322 | self.conv3 = conv(16, 16, kernel_size)
323 | self.conv4 = conv(4, 4, kernel_size)
324 |
325 | def forward(self, x):
326 | x = F.pixel_shuffle(x, 2)
327 | x = self.conv1(x)
328 |
329 | x = F.pixel_shuffle(x, 2)
330 | x = self.conv2(x)
331 |
332 | x = F.pixel_shuffle(x, 2)
333 | x = self.conv3(x)
334 |
335 | x = F.pixel_shuffle(x, 2)
336 | x = self.conv4(x)
337 |
338 | x = F.pixel_shuffle(x, 2)
339 | return x
340 |
341 | def choose_decoder(decoder):
342 | depthwise = ('dw' in decoder)
343 | if decoder[:6] == 'deconv':
344 | assert len(decoder)==7 or (len(decoder)==9 and 'dw' in decoder)
345 | kernel_size = int(decoder[6])
346 | model = DeConv(kernel_size, depthwise)
347 | elif decoder == "upproj":
348 | model = UpProj()
349 | elif decoder == "upconv":
350 | model = UpConv()
351 | elif decoder[:7] == 'shuffle':
352 | assert len(decoder)==8 or (len(decoder)==10 and 'dw' in decoder)
353 | kernel_size = int(decoder[7])
354 | model = ShuffleConv(kernel_size, depthwise)
355 | elif decoder[:6] == 'nnconv':
356 | assert len(decoder)==7 or (len(decoder)==9 and 'dw' in decoder)
357 | kernel_size = int(decoder[6])
358 | model = NNConv(kernel_size, depthwise)
359 | elif decoder[:6] == 'blconv':
360 | assert len(decoder)==7 or (len(decoder)==9 and 'dw' in decoder)
361 | kernel_size = int(decoder[6])
362 | model = BLConv(kernel_size, depthwise)
363 | else:
364 | assert False, "invalid option for decoder: {}".format(decoder)
365 | model.apply(weights_init)
366 | return model
367 |
368 |
369 | class ResNet(nn.Module):
370 | def __init__(self, layers, decoder, output_size, in_channels=3, pretrained=True):
371 |
372 | if layers not in [18, 34, 50, 101, 152]:
373 |             raise RuntimeError('Only 18, 34, 50, 101, and 152 layer models are defined for ResNet. Got {}'.format(layers))
374 |
375 | super(ResNet, self).__init__()
376 | self.output_size = output_size
377 | pretrained_model = torchvision.models.__dict__['resnet{}'.format(layers)](pretrained=pretrained)
378 | if not pretrained:
379 | pretrained_model.apply(weights_init)
380 |
381 | if in_channels == 3:
382 | self.conv1 = pretrained_model._modules['conv1']
383 | self.bn1 = pretrained_model._modules['bn1']
384 | else:
385 | self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
386 | self.bn1 = nn.BatchNorm2d(64)
387 | weights_init(self.conv1)
388 | weights_init(self.bn1)
389 |
390 | self.relu = pretrained_model._modules['relu']
391 | self.maxpool = pretrained_model._modules['maxpool']
392 | self.layer1 = pretrained_model._modules['layer1']
393 | self.layer2 = pretrained_model._modules['layer2']
394 | self.layer3 = pretrained_model._modules['layer3']
395 | self.layer4 = pretrained_model._modules['layer4']
396 |
397 | # clear memory
398 | del pretrained_model
399 |
400 | # define number of intermediate channels
401 | if layers <= 34:
402 | num_channels = 512
403 | elif layers >= 50:
404 | num_channels = 2048
405 | self.conv2 = nn.Conv2d(num_channels, 1024, 1)
406 | weights_init(self.conv2)
407 | self.decoder = choose_decoder(decoder)
408 |
409 | def forward(self, x):
410 | # resnet
411 | x = self.conv1(x)
412 | x = self.bn1(x)
413 | x = self.relu(x)
414 | x = self.maxpool(x)
415 | x = self.layer1(x)
416 | x = self.layer2(x)
417 | x = self.layer3(x)
418 | x = self.layer4(x)
419 | x = self.conv2(x)
420 |
421 | # decoder
422 | x = self.decoder(x)
423 |
424 | return x
425 |
426 | class MobileNet(nn.Module):
427 | def __init__(self, decoder, output_size, in_channels=3, pretrained=True):
428 |
429 | super(MobileNet, self).__init__()
430 | self.output_size = output_size
431 | mobilenet = imagenet.mobilenet.MobileNet()
432 | if pretrained:
433 | pretrained_path = os.path.join('imagenet', 'results', 'imagenet.arch=mobilenet.lr=0.1.bs=256', 'model_best.pth.tar')
434 | checkpoint = torch.load(pretrained_path)
435 | state_dict = checkpoint['state_dict']
436 |
437 | from collections import OrderedDict
438 | new_state_dict = OrderedDict()
439 | for k, v in state_dict.items():
440 | name = k[7:] # remove `module.`
441 | new_state_dict[name] = v
442 | mobilenet.load_state_dict(new_state_dict)
443 | else:
444 | mobilenet.apply(weights_init)
445 |
446 | if in_channels == 3:
447 | self.mobilenet = nn.Sequential(*(mobilenet.model[i] for i in range(14)))
448 | else:
449 | def conv_bn(inp, oup, stride):
450 | return nn.Sequential(
451 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
452 | nn.BatchNorm2d(oup),
453 | nn.ReLU6(inplace=True)
454 | )
455 |
456 | self.mobilenet = nn.Sequential(
457 | conv_bn(in_channels, 32, 2),
458 | *(mobilenet.model[i] for i in range(1,14))
459 | )
460 |
461 | self.decoder = choose_decoder(decoder)
462 |
463 | def forward(self, x):
464 | x = self.mobilenet(x)
465 | x = self.decoder(x)
466 | return x
467 |
468 | class ResNetSkipAdd(nn.Module):
469 | def __init__(self, layers, output_size, in_channels=3, pretrained=True):
470 |
471 | if layers not in [18, 34, 50, 101, 152]:
472 |             raise RuntimeError('Only 18, 34, 50, 101, and 152 layer models are defined for ResNet. Got {}'.format(layers))
473 |
474 | super(ResNetSkipAdd, self).__init__()
475 | self.output_size = output_size
476 | pretrained_model = torchvision.models.__dict__['resnet{}'.format(layers)](pretrained=pretrained)
477 | if not pretrained:
478 | pretrained_model.apply(weights_init)
479 |
480 | if in_channels == 3:
481 | self.conv1 = pretrained_model._modules['conv1']
482 | self.bn1 = pretrained_model._modules['bn1']
483 | else:
484 | self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
485 | self.bn1 = nn.BatchNorm2d(64)
486 | weights_init(self.conv1)
487 | weights_init(self.bn1)
488 |
489 | self.relu = pretrained_model._modules['relu']
490 | self.maxpool = pretrained_model._modules['maxpool']
491 | self.layer1 = pretrained_model._modules['layer1']
492 | self.layer2 = pretrained_model._modules['layer2']
493 | self.layer3 = pretrained_model._modules['layer3']
494 | self.layer4 = pretrained_model._modules['layer4']
495 |
496 | # clear memory
497 | del pretrained_model
498 |
499 | # define number of intermediate channels
500 | if layers <= 34:
501 | num_channels = 512
502 | elif layers >= 50:
503 | num_channels = 2048
504 | self.conv2 = nn.Conv2d(num_channels, 1024, 1)
505 | weights_init(self.conv2)
506 |
507 | kernel_size = 5
508 | self.decode_conv1 = conv(1024, 512, kernel_size)
509 | self.decode_conv2 = conv(512, 256, kernel_size)
510 | self.decode_conv3 = conv(256, 128, kernel_size)
511 | self.decode_conv4 = conv(128, 64, kernel_size)
512 | self.decode_conv5 = conv(64, 32, kernel_size)
513 | self.decode_conv6 = pointwise(32, 1)
514 | weights_init(self.decode_conv1)
515 | weights_init(self.decode_conv2)
516 | weights_init(self.decode_conv3)
517 | weights_init(self.decode_conv4)
518 | weights_init(self.decode_conv5)
519 | weights_init(self.decode_conv6)
520 |
521 | def forward(self, x):
522 | # resnet
523 | x = self.conv1(x)
524 | x = self.bn1(x)
525 | x1 = self.relu(x)
526 | # print("x1", x1.size())
527 | x2 = self.maxpool(x1)
528 | # print("x2", x2.size())
529 | x3 = self.layer1(x2)
530 | # print("x3", x3.size())
531 | x4 = self.layer2(x3)
532 | # print("x4", x4.size())
533 | x5 = self.layer3(x4)
534 | # print("x5", x5.size())
535 | x6 = self.layer4(x5)
536 | # print("x6", x6.size())
537 | x7 = self.conv2(x6)
538 |
539 | # decoder
540 | y10 = self.decode_conv1(x7)
541 | # print("y10", y10.size())
542 | y9 = F.interpolate(y10 + x6, scale_factor=2, mode='nearest')
543 | # print("y9", y9.size())
544 | y8 = self.decode_conv2(y9)
545 | # print("y8", y8.size())
546 | y7 = F.interpolate(y8 + x5, scale_factor=2, mode='nearest')
547 | # print("y7", y7.size())
548 | y6 = self.decode_conv3(y7)
549 | # print("y6", y6.size())
550 | y5 = F.interpolate(y6 + x4, scale_factor=2, mode='nearest')
551 | # print("y5", y5.size())
552 | y4 = self.decode_conv4(y5)
553 | # print("y4", y4.size())
554 | y3 = F.interpolate(y4 + x3, scale_factor=2, mode='nearest')
555 | # print("y3", y3.size())
556 | y2 = self.decode_conv5(y3 + x1)
557 | # print("y2", y2.size())
558 | y1 = F.interpolate(y2, scale_factor=2, mode='nearest')
559 | # print("y1", y1.size())
560 | y = self.decode_conv6(y1)
561 |
562 | return y
563 |
564 | class ResNetSkipConcat(nn.Module):
565 | def __init__(self, layers, output_size, in_channels=3, pretrained=True):
566 |
567 | if layers not in [18, 34, 50, 101, 152]:
568 |             raise RuntimeError('Only 18, 34, 50, 101, and 152 layer models are defined for ResNet. Got {}'.format(layers))
569 |
570 | super(ResNetSkipConcat, self).__init__()
571 | self.output_size = output_size
572 | pretrained_model = torchvision.models.__dict__['resnet{}'.format(layers)](pretrained=pretrained)
573 | if not pretrained:
574 | pretrained_model.apply(weights_init)
575 |
576 | if in_channels == 3:
577 | self.conv1 = pretrained_model._modules['conv1']
578 | self.bn1 = pretrained_model._modules['bn1']
579 | else:
580 | self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
581 | self.bn1 = nn.BatchNorm2d(64)
582 | weights_init(self.conv1)
583 | weights_init(self.bn1)
584 |
585 | self.relu = pretrained_model._modules['relu']
586 | self.maxpool = pretrained_model._modules['maxpool']
587 | self.layer1 = pretrained_model._modules['layer1']
588 | self.layer2 = pretrained_model._modules['layer2']
589 | self.layer3 = pretrained_model._modules['layer3']
590 | self.layer4 = pretrained_model._modules['layer4']
591 |
592 | # clear memory
593 | del pretrained_model
594 |
595 | # define number of intermediate channels
596 | if layers <= 34:
597 | num_channels = 512
598 | elif layers >= 50:
599 | num_channels = 2048
600 | self.conv2 = nn.Conv2d(num_channels, 1024, 1)
601 | weights_init(self.conv2)
602 |
603 | kernel_size = 5
604 | self.decode_conv1 = conv(1024, 512, kernel_size)
605 | self.decode_conv2 = conv(768, 256, kernel_size)
606 | self.decode_conv3 = conv(384, 128, kernel_size)
607 | self.decode_conv4 = conv(192, 64, kernel_size)
608 | self.decode_conv5 = conv(128, 32, kernel_size)
609 | self.decode_conv6 = pointwise(32, 1)
610 | weights_init(self.decode_conv1)
611 | weights_init(self.decode_conv2)
612 | weights_init(self.decode_conv3)
613 | weights_init(self.decode_conv4)
614 | weights_init(self.decode_conv5)
615 | weights_init(self.decode_conv6)
616 |
617 | def forward(self, x):
618 | # resnet
619 | x = self.conv1(x)
620 | x = self.bn1(x)
621 | x1 = self.relu(x)
622 | # print("x1", x1.size())
623 | x2 = self.maxpool(x1)
624 | # print("x2", x2.size())
625 | x3 = self.layer1(x2)
626 | # print("x3", x3.size())
627 | x4 = self.layer2(x3)
628 | # print("x4", x4.size())
629 | x5 = self.layer3(x4)
630 | # print("x5", x5.size())
631 | x6 = self.layer4(x5)
632 | # print("x6", x6.size())
633 | x7 = self.conv2(x6)
634 |
635 | # decoder
636 | y10 = self.decode_conv1(x7)
637 | # print("y10", y10.size())
638 | y9 = F.interpolate(y10, scale_factor=2, mode='nearest')
639 | # print("y9", y9.size())
640 | y8 = self.decode_conv2(torch.cat((y9, x5), 1))
641 | # print("y8", y8.size())
642 | y7 = F.interpolate(y8, scale_factor=2, mode='nearest')
643 | # print("y7", y7.size())
644 | y6 = self.decode_conv3(torch.cat((y7, x4), 1))
645 | # print("y6", y6.size())
646 | y5 = F.interpolate(y6, scale_factor=2, mode='nearest')
647 | # print("y5", y5.size())
648 | y4 = self.decode_conv4(torch.cat((y5, x3), 1))
649 | # print("y4", y4.size())
650 | y3 = F.interpolate(y4, scale_factor=2, mode='nearest')
651 | # print("y3", y3.size())
652 | y2 = self.decode_conv5(torch.cat((y3, x1), 1))
653 | # print("y2", y2.size())
654 | y1 = F.interpolate(y2, scale_factor=2, mode='nearest')
655 | # print("y1", y1.size())
656 | y = self.decode_conv6(y1)
657 |
658 | return y
659 |
660 | class MobileNetSkipAdd(nn.Module):
661 | def __init__(self, output_size, pretrained=True):
662 |
663 | super(MobileNetSkipAdd, self).__init__()
664 | self.output_size = output_size
665 | mobilenet = imagenet.mobilenet.MobileNet()
666 | if pretrained:
667 | pretrained_path = os.path.join('imagenet', 'results', 'imagenet.arch=mobilenet.lr=0.1.bs=256', 'model_best.pth.tar')
668 | checkpoint = torch.load(pretrained_path)
669 | state_dict = checkpoint['state_dict']
670 |
671 | from collections import OrderedDict
672 | new_state_dict = OrderedDict()
673 | for k, v in state_dict.items():
674 | name = k[7:] # remove `module.`
675 | new_state_dict[name] = v
676 | mobilenet.load_state_dict(new_state_dict)
677 | else:
678 | mobilenet.apply(weights_init)
679 |
680 | for i in range(14):
681 | setattr( self, 'conv{}'.format(i), mobilenet.model[i])
682 |
683 | kernel_size = 5
684 | # self.decode_conv1 = conv(1024, 512, kernel_size)
685 | # self.decode_conv2 = conv(512, 256, kernel_size)
686 | # self.decode_conv3 = conv(256, 128, kernel_size)
687 | # self.decode_conv4 = conv(128, 64, kernel_size)
688 | # self.decode_conv5 = conv(64, 32, kernel_size)
689 | self.decode_conv1 = nn.Sequential(
690 | depthwise(1024, kernel_size),
691 | pointwise(1024, 512))
692 | self.decode_conv2 = nn.Sequential(
693 | depthwise(512, kernel_size),
694 | pointwise(512, 256))
695 | self.decode_conv3 = nn.Sequential(
696 | depthwise(256, kernel_size),
697 | pointwise(256, 128))
698 | self.decode_conv4 = nn.Sequential(
699 | depthwise(128, kernel_size),
700 | pointwise(128, 64))
701 | self.decode_conv5 = nn.Sequential(
702 | depthwise(64, kernel_size),
703 | pointwise(64, 32))
704 | self.decode_conv6 = pointwise(32, 1)
705 | weights_init(self.decode_conv1)
706 | weights_init(self.decode_conv2)
707 | weights_init(self.decode_conv3)
708 | weights_init(self.decode_conv4)
709 | weights_init(self.decode_conv5)
710 | weights_init(self.decode_conv6)
711 |
712 | def forward(self, x):
713 | # skip connections: dec4: enc1
714 | # dec 3: enc2 or enc3
715 | # dec 2: enc4 or enc5
716 | for i in range(14):
717 | layer = getattr(self, 'conv{}'.format(i))
718 | x = layer(x)
719 | # print("{}: {}".format(i, x.size()))
720 | if i==1:
721 | x1 = x
722 | elif i==3:
723 | x2 = x
724 | elif i==5:
725 | x3 = x
726 | for i in range(1,6):
727 | layer = getattr(self, 'decode_conv{}'.format(i))
728 | x = layer(x)
729 | x = F.interpolate(x, scale_factor=2, mode='nearest')
730 | if i==4:
731 | x = x + x1
732 | elif i==3:
733 | x = x + x2
734 | elif i==2:
735 | x = x + x3
736 | # print("{}: {}".format(i, x.size()))
737 | x = self.decode_conv6(x)
738 | return x
739 |
740 | class MobileNetSkipConcat(nn.Module):
741 | def __init__(self, output_size, pretrained=True):
742 |
743 | super(MobileNetSkipConcat, self).__init__()
744 | self.output_size = output_size
745 | mobilenet = imagenet.mobilenet.MobileNet()
746 | if pretrained:
747 | pretrained_path = os.path.join('imagenet', 'results', 'imagenet.arch=mobilenet.lr=0.1.bs=256', 'model_best.pth.tar')
748 | checkpoint = torch.load(pretrained_path)
749 | state_dict = checkpoint['state_dict']
750 |
751 | from collections import OrderedDict
752 | new_state_dict = OrderedDict()
753 | for k, v in state_dict.items():
754 | name = k[7:] # remove `module.`
755 | new_state_dict[name] = v
756 | mobilenet.load_state_dict(new_state_dict)
757 | else:
758 | mobilenet.apply(weights_init)
759 |
760 | for i in range(14):
761 | setattr( self, 'conv{}'.format(i), mobilenet.model[i])
762 |
763 | kernel_size = 5
764 | # self.decode_conv1 = conv(1024, 512, kernel_size)
765 | # self.decode_conv2 = conv(512, 256, kernel_size)
766 | # self.decode_conv3 = conv(256, 128, kernel_size)
767 | # self.decode_conv4 = conv(128, 64, kernel_size)
768 | # self.decode_conv5 = conv(64, 32, kernel_size)
769 | self.decode_conv1 = nn.Sequential(
770 | depthwise(1024, kernel_size),
771 | pointwise(1024, 512))
772 | self.decode_conv2 = nn.Sequential(
773 | depthwise(512, kernel_size),
774 | pointwise(512, 256))
775 | self.decode_conv3 = nn.Sequential(
776 | depthwise(512, kernel_size),
777 | pointwise(512, 128))
778 | self.decode_conv4 = nn.Sequential(
779 | depthwise(256, kernel_size),
780 | pointwise(256, 64))
781 | self.decode_conv5 = nn.Sequential(
782 | depthwise(128, kernel_size),
783 | pointwise(128, 32))
784 | self.decode_conv6 = pointwise(32, 1)
785 | weights_init(self.decode_conv1)
786 | weights_init(self.decode_conv2)
787 | weights_init(self.decode_conv3)
788 | weights_init(self.decode_conv4)
789 | weights_init(self.decode_conv5)
790 | weights_init(self.decode_conv6)
791 |
792 | def forward(self, x):
793 | # skip connections (channel-wise concat): decoder stage 4 <- encoder conv1 output,
794 | #                                          decoder stage 3 <- encoder conv3 output,
795 | #                                          decoder stage 2 <- encoder conv5 output
796 | for i in range(14):
797 | layer = getattr(self, 'conv{}'.format(i))
798 | x = layer(x)
799 | # print("{}: {}".format(i, x.size()))
800 | if i==1:
801 | x1 = x
802 | elif i==3:
803 | x2 = x
804 | elif i==5:
805 | x3 = x
806 | for i in range(1,6):
807 | layer = getattr(self, 'decode_conv{}'.format(i))
808 | # print("{}a: {}".format(i, x.size()))
809 | x = layer(x)
810 | # print("{}b: {}".format(i, x.size()))
811 | x = F.interpolate(x, scale_factor=2, mode='nearest')
812 | if i==4:
813 | x = torch.cat((x, x1), 1)
814 | elif i==3:
815 | x = torch.cat((x, x2), 1)
816 | elif i==2:
817 | x = torch.cat((x, x3), 1)
818 | # print("{}c: {}".format(i, x.size()))
819 | x = self.decode_conv6(x)
820 | return x
821 |
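822 | 
823 | # A minimal smoke-test sketch, not part of the upstream file: it assumes this
824 | # module is importable as `models_fast`, that the MobileNet encoder follows the
825 | # usual five stride-2 stages (so a 224x224 input reaches the decoder at 7x7),
826 | # and it uses pretrained=False so no ImageNet checkpoint needs to be on disk.
827 | if __name__ == '__main__':
828 |     dummy = torch.rand(1, 3, 224, 224)  # dummy RGB batch, N x C x H x W
829 |     for cls in (MobileNetSkipAdd, MobileNetSkipConcat):
830 |         net = cls(output_size=(224, 224), pretrained=False).eval()
831 |         with torch.no_grad():
832 |             out = net(dummy)
833 |         # five 2x upsamplings restore the input resolution with a single depth channel
834 |         print(cls.__name__, tuple(out.shape))  # expected: (1, 1, 224, 224)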
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import shutil
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | from PIL import Image
7 |
8 | cmap = plt.cm.viridis
9 |
10 | def parse_command():
11 | model_names = ['resnet18', 'resnet50', 'mobilenet']
12 | loss_names = ['l1', 'l2']
13 | data_names = ['nyudepthv2', 'kitti', 'deepscene', 'sun', 'zed']
14 | from dataloaders.dense_to_sparse import UniformSampling, SimulatedStereo
15 | sparsifier_names = [x.name for x in [UniformSampling, SimulatedStereo]]
16 | from models import Decoder
17 | decoder_names = Decoder.names
18 | from dataloaders.dataloader import MyDataloader
19 | modality_names = MyDataloader.modality_names
20 |
21 | import argparse
22 | parser = argparse.ArgumentParser(description='Sparse-to-Dense')
23 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', choices=model_names,
24 | help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet18)')
25 | parser.add_argument('--data', metavar='DATA', default='nyudepthv2',
26 | choices=data_names,
27 | help='dataset: ' + ' | '.join(data_names) + ' (default: nyudepthv2)')
28 | parser.add_argument('--modality', '-m', metavar='MODALITY', default='rgb', choices=modality_names,
29 | help='modality: ' + ' | '.join(modality_names) + ' (default: rgb)')
30 | parser.add_argument('-s', '--num-samples', default=0, type=int, metavar='N',
31 | help='number of sparse depth samples (default: 0)')
32 | parser.add_argument('--max-depth', default=-1.0, type=float, metavar='D',
33 | help='cut-off depth of the sparsifier; negative values mean no limit (default: -1.0, i.e. inf [m])')
34 | parser.add_argument('--sparsifier', metavar='SPARSIFIER', default=UniformSampling.name, choices=sparsifier_names,
35 | help='sparsifier: ' + ' | '.join(sparsifier_names) + ' (default: ' + UniformSampling.name + ')')
36 | parser.add_argument('--decoder', '-d', metavar='DECODER', default='deconv2', choices=decoder_names,
37 | help='decoder: ' + ' | '.join(decoder_names) + ' (default: deconv2)')
38 | parser.add_argument('-j', '--workers', default=10, type=int, metavar='N',
39 | help='number of data loading workers (default: 10)')
40 | parser.add_argument('--epochs', default=15, type=int, metavar='N',
41 | help='number of total epochs to run (default: 15)')
42 | parser.add_argument('-c', '--criterion', metavar='LOSS', default='l1', choices=loss_names,
43 | help='loss function: ' + ' | '.join(loss_names) + ' (default: l1)')
44 | parser.add_argument('-b', '--batch-size', default=8, type=int, help='mini-batch size (default: 8)')
45 | parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
46 | metavar='LR', help='initial learning rate (default: 0.01)')
47 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
48 | help='momentum')
49 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
50 | metavar='W', help='weight decay (default: 1e-4)')
51 | parser.add_argument('--print-freq', '-p', default=10, type=int,
52 | metavar='N', help='print frequency (default: 10)')
53 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
54 | help='path to latest checkpoint (default: none)')
55 | parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
56 | help='path to pretrained checkpoint to begin training from')
57 | parser.add_argument('-e', '--evaluate', dest='evaluate', type=str, default='',
58 | help='evaluate model on validation set')
59 | parser.add_argument('--no-pretrain', dest='pretrained', action='store_false',
60 | help='do not use ImageNet pre-trained weights')
61 | parser.add_argument('--export', default='', type=str, help='path to a pre-trained model to load and export to ONNX')
62 | parser.set_defaults(pretrained=True)
63 | args = parser.parse_args()
64 | if args.modality == 'rgb' and args.num_samples != 0:
65 | print("number of samples is forced to be 0 when input modality is rgb")
66 | args.num_samples = 0
67 | if args.modality == 'rgb' and args.max_depth != 0.0:
68 | print("max depth is forced to be 0.0 when input modality is rgb/rgbd")
69 | args.max_depth = 0.0
70 | return args
71 |
72 | def save_checkpoint(state, is_best, epoch, output_directory):
73 | checkpoint_filename = os.path.join(output_directory, 'checkpoint-' + str(epoch) + '.pth.tar')
74 | torch.save(state, checkpoint_filename)
75 | if is_best:
76 | best_filename = os.path.join(output_directory, 'model_best.pth.tar')
77 | shutil.copyfile(checkpoint_filename, best_filename)
78 | if epoch > 0:
79 | prev_checkpoint_filename = os.path.join(output_directory, 'checkpoint-' + str(epoch-1) + '.pth.tar')
80 | if os.path.exists(prev_checkpoint_filename):
81 | os.remove(prev_checkpoint_filename)
82 |
83 | def adjust_learning_rate(optimizer, epoch, lr_init):
84 | """Sets the learning rate to the initial LR decayed by 10 every 5 epochs"""
85 | lr = lr_init * (0.1 ** (epoch // 5))
86 | for param_group in optimizer.param_groups:
87 | param_group['lr'] = lr
88 |
89 | def get_output_directory(args):
90 | output_directory = os.path.join('results',
91 | '{}.sparsifier={}.samples={}.modality={}.arch={}.decoder={}.criterion={}.lr={}.bs={}.pretrained={}'.
92 | format(args.data, args.sparsifier, args.num_samples, args.modality, \
93 | args.arch, args.decoder, args.criterion, args.lr, args.batch_size, \
94 | args.pretrained))
95 | return output_directory
96 |
97 |
98 | def colored_depthmap(depth, d_min=None, d_max=None):
99 | if d_min is None:
100 | d_min = np.min(depth)
101 | if d_max is None:
102 | d_max = np.max(depth)
103 | depth_relative = (depth - d_min) / (d_max - d_min)
104 | return 255 * cmap(depth_relative)[:,:,:3] # H, W, C
105 |
106 |
107 | def merge_into_row(input, depth_target, depth_pred):
108 | rgb = 255 * np.transpose(np.squeeze(input.cpu().numpy()), (1,2,0)) # H, W, C
109 | depth_target_cpu = np.squeeze(depth_target.cpu().numpy())
110 | depth_pred_cpu = np.squeeze(depth_pred.data.cpu().numpy())
111 |
112 | d_min = min(np.min(depth_target_cpu), np.min(depth_pred_cpu))
113 | d_max = max(np.max(depth_target_cpu), np.max(depth_pred_cpu))
114 |
115 | print('depth_min {:f} depth_max {:f}'.format(d_min, d_max))
116 |
117 | depth_target_col = colored_depthmap(depth_target_cpu, d_min, d_max)
118 | depth_pred_col = colored_depthmap(depth_pred_cpu, d_min, d_max)
119 | img_merge = np.hstack([rgb, depth_target_col, depth_pred_col])
120 |
121 | return img_merge
122 |
123 |
124 | def merge_into_row_with_gt(input, depth_input, depth_target, depth_pred):
125 | rgb = 255 * np.transpose(np.squeeze(input.cpu().numpy()), (1,2,0)) # H, W, C
126 | depth_input_cpu = np.squeeze(depth_input.cpu().numpy())
127 | depth_target_cpu = np.squeeze(depth_target.cpu().numpy())
128 | depth_pred_cpu = np.squeeze(depth_pred.data.cpu().numpy())
129 |
130 | d_min = min(np.min(depth_input_cpu), np.min(depth_target_cpu), np.min(depth_pred_cpu))
131 | d_max = max(np.max(depth_input_cpu), np.max(depth_target_cpu), np.max(depth_pred_cpu))
132 | depth_input_col = colored_depthmap(depth_input_cpu, d_min, d_max)
133 | depth_target_col = colored_depthmap(depth_target_cpu, d_min, d_max)
134 | depth_pred_col = colored_depthmap(depth_pred_cpu, d_min, d_max)
135 |
136 | img_merge = np.hstack([rgb, depth_input_col, depth_target_col, depth_pred_col])
137 |
138 | return img_merge
139 |
140 |
141 | def add_row(img_merge, row):
142 | return np.vstack([img_merge, row])
143 |
144 |
145 | def save_image(img_merge, filename):
146 | img_merge = Image.fromarray(img_merge.astype('uint8'))
147 | img_merge.save(filename)
148 |
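149 | 
150 | 
151 | # A minimal visualization sketch, not part of the upstream file. The tensor
152 | # shapes and value ranges below are illustrative assumptions: a CHW float RGB
153 | # image in [0, 1] and 1 x H x W depth maps, as merge_into_row() expects.
154 | if __name__ == '__main__':
155 |     rgb = torch.rand(3, 228, 304)             # dummy RGB image, C x H x W
156 |     target = torch.rand(1, 228, 304) * 10.0   # dummy ground-truth depth map
157 |     pred = torch.rand(1, 228, 304) * 10.0     # dummy predicted depth map
158 |     row = merge_into_row(rgb, target, pred)   # RGB | coloured target | coloured prediction
159 |     save_image(row, 'comparison.png')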
--------------------------------------------------------------------------------