├── .gitignore
├── LICENSE
├── README.md
├── criteria.py
├── dataloaders
│   ├── calib_cam_to_cam.txt
│   ├── kitti_loader.py
│   ├── pose_estimator.py
│   └── transforms.py
├── download
│   ├── rgb_train_downloader.sh
│   └── rgb_val_downloader.sh
├── helper.py
├── inverse_warp.py
├── main.py
├── metrics.py
├── model.py
└── vis_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | 2011_*
2 | *fuse_hidden*
3 | data
4 | results
5 | .DS_Store*
6 |
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | env/
18 | build/
19 | develop-eggs/
20 | dist/
21 | downloads/
22 | eggs/
23 | .eggs/
24 | lib/
25 | lib64/
26 | parts/
27 | sdist/
28 | var/
29 | wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | .hypothesis/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 |
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 |
67 | # Scrapy stuff:
68 | .scrapy
69 |
70 | # Sphinx documentation
71 | docs/_build/
72 |
73 | # PyBuilder
74 | target/
75 |
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 |
79 | # pyenv
80 | .python-version
81 |
82 | # celery beat schedule file
83 | celerybeat-schedule
84 |
85 | # SageMath parsed files
86 | *.sage.py
87 |
88 | # dotenv
89 | .env
90 |
91 | # virtualenv
92 | .venv
93 | venv/
94 | ENV/
95 |
96 | # Spyder project settings
97 | .spyderproject
98 | .spyproject
99 |
100 | # Rope project settings
101 | .ropeproject
102 |
103 | # mkdocs documentation
104 | /site
105 |
106 | # mypy
107 | .mypy_cache/
108 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Fangchang Ma
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # self-supervised-depth-completion
2 |
3 | This repo is the PyTorch implementation of our ICRA'19 paper on ["Self-supervised Sparse-to-Dense: Self-supervised Depth Completion from LiDAR and Monocular Camera"](https://arxiv.org/pdf/1807.00275.pdf), developed by [Fangchang Ma](http://www.mit.edu/~fcma/), Guilherme Venturelli Cavalheiro, and [Sertac Karaman](http://karaman.mit.edu/) at MIT. A video demonstration is available on [YouTube](https://youtu.be/bGXfvF261pc).
4 |
5 |
6 |
7 |
8 |
9 | Our network is trained on the KITTI dataset alone, without pretraining on Cityscapes or other similar driving datasets (either synthetic or real). The use of additional data is likely to further improve the accuracy.
10 |
11 | Please create a new issue for code-related questions.
12 |
13 | ## Contents
14 | 1. [Dependency](#dependency)
15 | 0. [Data](#data)
16 | 0. [Trained Models](#trained-models)
17 | 0. [Commands](#commands)
18 | 0. [Citation](#citation)
19 |
20 |
21 | ## Dependency
22 | This code was tested with Python 3 and PyTorch 1.0 on Ubuntu 16.04.
23 | ```bash
24 | pip install numpy matplotlib Pillow
25 | pip install torch torchvision # pytorch
26 |
27 | # self-supervised training requires OpenCV, along with the contrib modules
28 | pip install opencv-contrib-python==3.4.2.16
29 | ```
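Optionally, sanity-check the environment; `cv2.xfeatures2d.SIFT_create` is the contrib API used by `dataloaders/pose_estimator.py`, so it should be importable:
```bash
# verify that torch, torchvision, and the opencv-contrib build import correctly
python -c "import torch, torchvision, cv2; print(torch.__version__, cv2.__version__)"
# SIFT lives in the contrib modules; this line fails on a plain opencv-python install
python -c "import cv2; cv2.xfeatures2d.SIFT_create(100)"
```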
30 |
31 | ## Data
32 | - Download the [KITTI Depth](http://www.cvlibs.net/datasets/kitti/eval_depth.php?benchmark=depth_completion) Dataset from their website. Use the following scripts to download the KITTI raw sequences and extract the corresponding RGB images:
33 | ```bash
34 | ./download/rgb_train_downloader.sh
35 | ./download/rgb_val_downloader.sh
36 | ```
37 | The downloaded RGB files will be stored in the `../data/data_rgb` folder. The overall directory structure for code, data, and results is as follows (updated on Oct 1, 2019):
38 | ```
39 | .
40 | ├── self-supervised-depth-completion
41 | ├── data
42 | | ├── data_depth_annotated
43 | | | ├── train
44 | | | ├── val
45 | | ├── data_depth_velodyne
46 | | | ├── train
47 | | | ├── val
48 | | ├── depth_selection
49 | | | ├── test_depth_completion_anonymous
50 | | | ├── test_depth_prediction_anonymous
51 | | | ├── val_selection_cropped
52 | | └── data_rgb
53 | | | ├── train
54 | | | ├── val
55 | ├── results
56 | ```
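A minimal sketch of how the depth archives can be unpacked to match this layout, assuming the standard KITTI depth-completion archive names (adjust to the files you actually downloaded):
```bash
# archive names below are assumptions; substitute the files downloaded from the KITTI site
mkdir -p ../data && cd ../data
unzip data_depth_annotated.zip    # -> data_depth_annotated/{train,val}
unzip data_depth_velodyne.zip     # -> data_depth_velodyne/{train,val}
unzip data_depth_selection.zip    # -> depth_selection/...
cd ../self-supervised-depth-completion
```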
57 |
58 | ## Trained Models
59 | Download our trained models at http://datasets.lids.mit.edu/self-supervised-depth-completion to a folder of your choice.
60 | - supervised training (i.e., models trained with semi-dense lidar ground truth): http://datasets.lids.mit.edu/self-supervised-depth-completion/supervised/
61 | - self-supervised (i.e., photometric loss + sparse depth loss + smoothness loss): http://datasets.lids.mit.edu/self-supervised-depth-completion/self-supervised/
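For example, a checkpoint can be fetched with `wget` and passed to the evaluation command described under [Commands](#commands); `[checkpoint-file]` below is a placeholder for a file name from the directory listing above:
```bash
# [checkpoint-file] is a placeholder; pick an actual file from the listing above
wget http://datasets.lids.mit.edu/self-supervised-depth-completion/self-supervised/[checkpoint-file]
python main.py --evaluate [checkpoint-file] --val select
```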
62 |
63 | ## Commands
64 | A complete list of training options is available with
65 | ```bash
66 | python main.py -h
67 | ```
68 | For instance,
69 | ```bash
70 | # train with the KITTI semi-dense annotations, rgbd input, and batch size 1
71 | python main.py --train-mode dense -b 1 --input rgbd
72 |
73 | # train with the self-supervised framework, not using ground truth
74 | python main.py --train-mode sparse+photo
75 |
76 | # resume previous training
77 | python main.py --resume [checkpoint-path]
78 |
79 | # test the trained model on the val_selection_cropped data
80 | python main.py --evaluate [checkpoint-path] --val select
81 | ```
82 |
83 | ## Citation
84 | If you use our code or method in your work, please cite the following:
85 |
86 | @inproceedings{ma2018self,
87 | title={Self-supervised Sparse-to-Dense: Self-supervised Depth Completion from LiDAR and Monocular Camera},
88 | author={Ma, Fangchang and Cavalheiro, Guilherme Venturelli and Karaman, Sertac},
89 | booktitle={ICRA},
90 | year={2019}
91 | }
92 | @inproceedings{Ma2017SparseToDense,
93 | title={Sparse-to-Dense: Depth Prediction from Sparse Depth Samples and a Single Image},
94 | author={Ma, Fangchang and Karaman, Sertac},
95 | booktitle={ICRA},
96 | year={2018}
97 | }
98 |
99 |
--------------------------------------------------------------------------------
/criteria.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | loss_names = ['l1', 'l2']
5 |
6 |
7 | class MaskedMSELoss(nn.Module):
8 | def __init__(self):
9 | super(MaskedMSELoss, self).__init__()
10 |
11 | def forward(self, pred, target):
12 | assert pred.dim() == target.dim(), "inconsistent dimensions"
13 | valid_mask = (target > 0).detach()
14 | diff = target - pred
15 | diff = diff[valid_mask]
16 | self.loss = (diff**2).mean()
17 | return self.loss
18 |
19 |
20 | class MaskedL1Loss(nn.Module):
21 | def __init__(self):
22 | super(MaskedL1Loss, self).__init__()
23 |
24 | def forward(self, pred, target, weight=None):
25 | assert pred.dim() == target.dim(), "inconsistent dimensions"
26 | valid_mask = (target > 0).detach()
27 | diff = target - pred
28 | diff = diff[valid_mask]
29 | self.loss = diff.abs().mean()
30 | return self.loss
31 |
32 |
33 | class PhotometricLoss(nn.Module):
34 | def __init__(self):
35 | super(PhotometricLoss, self).__init__()
36 |
37 | def forward(self, target, recon, mask=None):
38 |
39 | assert recon.dim(
40 | ) == 4, "expected recon dimension to be 4, but instead got {}.".format(
41 | recon.dim())
42 | assert target.dim(
43 | ) == 4, "expected target dimension to be 4, but instead got {}.".format(
44 | target.dim())
45 | assert recon.size()==target.size(), "expected recon and target to have the same size, but got {} and {} instead"\
46 | .format(recon.size(), target.size())
47 | diff = (target - recon).abs()
48 | diff = torch.sum(diff, 1) # sum along the color channel
49 |
50 | # compare only pixels that are not black
51 | valid_mask = (torch.sum(recon, 1) > 0).float() * (torch.sum(target, 1)
52 | > 0).float()
53 | if mask is not None:
54 | valid_mask = valid_mask * torch.squeeze(mask).float()
55 | valid_mask = valid_mask.byte().detach()
56 | if valid_mask.numel() > 0:
57 | diff = diff[valid_mask]
58 | if diff.nelement() > 0:
59 | self.loss = diff.mean()
60 | else:
61 | print(
62 |                 "warning: diff.nelement()==0 in PhotometricLoss (this is expected during the early stage of training; try a larger batch size)."
63 | )
64 | self.loss = 0
65 | else:
66 | print("warning: 0 valid pixel in PhotometricLoss")
67 | self.loss = 0
68 | return self.loss
69 |
70 |
71 | class SmoothnessLoss(nn.Module):
72 | def __init__(self):
73 | super(SmoothnessLoss, self).__init__()
74 |
75 | def forward(self, depth):
76 | def second_derivative(x):
77 | assert x.dim(
78 | ) == 4, "expected 4-dimensional data, but instead got {}".format(
79 | x.dim())
80 | horizontal = 2 * x[:, :, 1:-1, 1:-1] - x[:, :, 1:-1, :
81 | -2] - x[:, :, 1:-1, 2:]
82 | vertical = 2 * x[:, :, 1:-1, 1:-1] - x[:, :, :-2, 1:
83 | -1] - x[:, :, 2:, 1:-1]
84 | der_2nd = horizontal.abs() + vertical.abs()
85 | return der_2nd.mean()
86 |
87 | self.loss = second_derivative(depth)
88 | return self.loss
89 |
--------------------------------------------------------------------------------
/dataloaders/calib_cam_to_cam.txt:
--------------------------------------------------------------------------------
1 | calib_time: 09-Jan-2012 13:57:47
2 | corner_dist: 9.950000e-02
3 | S_00: 1.392000e+03 5.120000e+02
4 | K_00: 9.842439e+02 0.000000e+00 6.900000e+02 0.000000e+00 9.808141e+02 2.331966e+02 0.000000e+00 0.000000e+00 1.000000e+00
5 | D_00: -3.728755e-01 2.037299e-01 2.219027e-03 1.383707e-03 -7.233722e-02
6 | R_00: 1.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 1.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 1.000000e+00
7 | T_00: 2.573699e-16 -1.059758e-16 1.614870e-16
8 | S_rect_00: 1.242000e+03 3.750000e+02
9 | R_rect_00: 9.999239e-01 9.837760e-03 -7.445048e-03 -9.869795e-03 9.999421e-01 -4.278459e-03 7.402527e-03 4.351614e-03 9.999631e-01
10 | P_rect_00: 7.215377e+02 0.000000e+00 6.095593e+02 0.000000e+00 0.000000e+00 7.215377e+02 1.728540e+02 0.000000e+00 0.000000e+00 0.000000e+00 1.000000e+00 0.000000e+00
11 | S_01: 1.392000e+03 5.120000e+02
12 | K_01: 9.895267e+02 0.000000e+00 7.020000e+02 0.000000e+00 9.878386e+02 2.455590e+02 0.000000e+00 0.000000e+00 1.000000e+00
13 | D_01: -3.644661e-01 1.790019e-01 1.148107e-03 -6.298563e-04 -5.314062e-02
14 | R_01: 9.993513e-01 1.860866e-02 -3.083487e-02 -1.887662e-02 9.997863e-01 -8.421873e-03 3.067156e-02 8.998467e-03 9.994890e-01
15 | T_01: -5.370000e-01 4.822061e-03 -1.252488e-02
16 | S_rect_01: 1.242000e+03 3.750000e+02
17 | R_rect_01: 9.996878e-01 -8.976826e-03 2.331651e-02 8.876121e-03 9.999508e-01 4.418952e-03 -2.335503e-02 -4.210612e-03 9.997184e-01
18 | P_rect_01: 7.215377e+02 0.000000e+00 6.095593e+02 -3.875744e+02 0.000000e+00 7.215377e+02 1.728540e+02 0.000000e+00 0.000000e+00 0.000000e+00 1.000000e+00 0.000000e+00
19 | S_02: 1.392000e+03 5.120000e+02
20 | K_02: 9.597910e+02 0.000000e+00 6.960217e+02 0.000000e+00 9.569251e+02 2.241806e+02 0.000000e+00 0.000000e+00 1.000000e+00
21 | D_02: -3.691481e-01 1.968681e-01 1.353473e-03 5.677587e-04 -6.770705e-02
22 | R_02: 9.999758e-01 -5.267463e-03 -4.552439e-03 5.251945e-03 9.999804e-01 -3.413835e-03 4.570332e-03 3.389843e-03 9.999838e-01
23 | T_02: 5.956621e-02 2.900141e-04 2.577209e-03
24 | S_rect_02: 1.242000e+03 3.750000e+02
25 | R_rect_02: 9.998817e-01 1.511453e-02 -2.841595e-03 -1.511724e-02 9.998853e-01 -9.338510e-04 2.827154e-03 9.766976e-04 9.999955e-01
26 | P_rect_02: 7.215377e+02 0.000000e+00 6.095593e+02 4.485728e+01 0.000000e+00 7.215377e+02 1.728540e+02 2.163791e-01 0.000000e+00 0.000000e+00 1.000000e+00 2.745884e-03
27 | S_03: 1.392000e+03 5.120000e+02
28 | K_03: 9.037596e+02 0.000000e+00 6.957519e+02 0.000000e+00 9.019653e+02 2.242509e+02 0.000000e+00 0.000000e+00 1.000000e+00
29 | D_03: -3.639558e-01 1.788651e-01 6.029694e-04 -3.922424e-04 -5.382460e-02
30 | R_03: 9.995599e-01 1.699522e-02 -2.431313e-02 -1.704422e-02 9.998531e-01 -1.809756e-03 2.427880e-02 2.223358e-03 9.997028e-01
31 | T_03: -4.731050e-01 5.551470e-03 -5.250882e-03
32 | S_rect_03: 1.242000e+03 3.750000e+02
33 | R_rect_03: 9.998321e-01 -7.193136e-03 1.685599e-02 7.232804e-03 9.999712e-01 -2.293585e-03 -1.683901e-02 2.415116e-03 9.998553e-01
34 | P_rect_03: 7.215377e+02 0.000000e+00 6.095593e+02 -3.395242e+02 0.000000e+00 7.215377e+02 1.728540e+02 2.199936e+00 0.000000e+00 0.000000e+00 1.000000e+00 2.729905e-03
35 |
--------------------------------------------------------------------------------
/dataloaders/kitti_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import glob
4 | import fnmatch # pattern matching
5 | import numpy as np
6 | from numpy import linalg as LA
7 | from random import choice
8 | from PIL import Image
9 | import torch
10 | import torch.utils.data as data
11 | import cv2
12 | from dataloaders import transforms
13 | from dataloaders.pose_estimator import get_pose_pnp
14 |
15 | input_options = ['d', 'rgb', 'rgbd', 'g', 'gd']
16 |
17 |
18 | def load_calib():
19 | """
20 | Temporarily hardcoding the calibration matrix using calib file from 2011_09_26
21 | """
22 | calib = open("dataloaders/calib_cam_to_cam.txt", "r")
23 | lines = calib.readlines()
24 | P_rect_line = lines[25]
25 |
26 | Proj_str = P_rect_line.split(":")[1].split(" ")[1:]
27 | Proj = np.reshape(np.array([float(p) for p in Proj_str]),
28 | (3, 4)).astype(np.float32)
29 | K = Proj[:3, :3] # camera matrix
30 |
31 | # note: we will take the center crop of the images during augmentation
32 | # that changes the optical centers, but not focal lengths
33 | K[0, 2] = K[
34 | 0,
35 | 2] - 13 # from width = 1242 to 1216, with a 13-pixel cut on both sides
36 | K[1, 2] = K[
37 | 1,
38 |         2] - 11.5  # from height = 375 to 352, with an 11.5-pixel cut on both sides
39 | return K
40 |
41 |
42 | def get_paths_and_transform(split, args):
43 | assert (args.use_d or args.use_rgb
44 | or args.use_g), 'no proper input selected'
45 |
46 | if split == "train":
47 | transform = train_transform
48 | glob_d = os.path.join(
49 | args.data_folder,
50 | 'data_depth_velodyne/train/*_sync/proj_depth/velodyne_raw/image_0[2,3]/*.png'
51 | )
52 | glob_gt = os.path.join(
53 | args.data_folder,
54 | 'data_depth_annotated/train/*_sync/proj_depth/groundtruth/image_0[2,3]/*.png'
55 | )
56 |
57 | def get_rgb_paths(p):
58 | ps = p.split('/')
59 | pnew = '/'.join([args.data_folder] + ['data_rgb'] + ps[-6:-4] +
60 | ps[-2:-1] + ['data'] + ps[-1:])
61 | return pnew
62 | elif split == "val":
63 | if args.val == "full":
64 | transform = val_transform
65 | glob_d = os.path.join(
66 | args.data_folder,
67 | 'data_depth_velodyne/val/*_sync/proj_depth/velodyne_raw/image_0[2,3]/*.png'
68 | )
69 | glob_gt = os.path.join(
70 | args.data_folder,
71 | 'data_depth_annotated/val/*_sync/proj_depth/groundtruth/image_0[2,3]/*.png'
72 | )
73 | def get_rgb_paths(p):
74 | ps = p.split('/')
75 | pnew = '/'.join(ps[:-7] +
76 | ['data_rgb']+ps[-6:-4]+ps[-2:-1]+['data']+ps[-1:])
77 | return pnew
78 | elif args.val == "select":
79 | transform = no_transform
80 | glob_d = os.path.join(
81 | args.data_folder,
82 | "depth_selection/val_selection_cropped/velodyne_raw/*.png")
83 | glob_gt = os.path.join(
84 | args.data_folder,
85 | "depth_selection/val_selection_cropped/groundtruth_depth/*.png"
86 | )
87 | def get_rgb_paths(p):
88 | return p.replace("groundtruth_depth","image")
89 | elif split == "test_completion":
90 | transform = no_transform
91 | glob_d = os.path.join(
92 | args.data_folder,
93 | "depth_selection/test_depth_completion_anonymous/velodyne_raw/*.png"
94 | )
95 | glob_gt = None #"test_depth_completion_anonymous/"
96 | glob_rgb = os.path.join(
97 | args.data_folder,
98 | "depth_selection/test_depth_completion_anonymous/image/*.png")
99 | elif split == "test_prediction":
100 | transform = no_transform
101 | glob_d = None
102 | glob_gt = None #"test_depth_completion_anonymous/"
103 | glob_rgb = os.path.join(
104 | args.data_folder,
105 | "depth_selection/test_depth_prediction_anonymous/image/*.png")
106 | else:
107 | raise ValueError("Unrecognized split " + str(split))
108 |
109 | if glob_gt is not None:
110 | # train or val-full or val-select
111 | paths_d = sorted(glob.glob(glob_d))
112 | paths_gt = sorted(glob.glob(glob_gt))
113 | paths_rgb = [get_rgb_paths(p) for p in paths_gt]
114 | else:
115 | # test only has d or rgb
116 | paths_rgb = sorted(glob.glob(glob_rgb))
117 | paths_gt = [None] * len(paths_rgb)
118 | if split == "test_prediction":
119 | paths_d = [None] * len(
120 | paths_rgb) # test_prediction has no sparse depth
121 | else:
122 | paths_d = sorted(glob.glob(glob_d))
123 |
124 | if len(paths_d) == 0 and len(paths_rgb) == 0 and len(paths_gt) == 0:
125 | raise (RuntimeError("Found 0 images under {}".format(glob_gt)))
126 | if len(paths_d) == 0 and args.use_d:
127 | raise (RuntimeError("Requested sparse depth but none was found"))
128 | if len(paths_rgb) == 0 and args.use_rgb:
129 | raise (RuntimeError("Requested rgb images but none was found"))
130 | if len(paths_rgb) == 0 and args.use_g:
131 | raise (RuntimeError("Requested gray images but no rgb was found"))
132 | if len(paths_rgb) != len(paths_d) or len(paths_rgb) != len(paths_gt):
133 | raise (RuntimeError("Produced different sizes for datasets"))
134 |
135 | paths = {"rgb": paths_rgb, "d": paths_d, "gt": paths_gt}
136 | return paths, transform
137 |
138 |
139 | def rgb_read(filename):
140 | assert os.path.exists(filename), "file not found: {}".format(filename)
141 | img_file = Image.open(filename)
142 | # rgb_png = np.array(img_file, dtype=float) / 255.0 # scale pixels to the range [0,1]
143 | rgb_png = np.array(img_file, dtype='uint8') # in the range [0,255]
144 | img_file.close()
145 | return rgb_png
146 |
147 |
148 | def depth_read(filename):
149 | # loads depth map D from png file
150 | # and returns it as a numpy array,
151 | # for details see readme.txt
152 | assert os.path.exists(filename), "file not found: {}".format(filename)
153 | img_file = Image.open(filename)
154 | depth_png = np.array(img_file, dtype=int)
155 | img_file.close()
156 | # make sure we have a proper 16bit depth map here.. not 8bit!
157 | assert np.max(depth_png) > 255, \
158 | "np.max(depth_png)={}, path={}".format(np.max(depth_png),filename)
159 |
160 |     depth = depth_png.astype(np.float64) / 256.
161 | # depth[depth_png == 0] = -1.
162 | depth = np.expand_dims(depth, -1)
163 | return depth
164 |
165 |
166 | oheight, owidth = 352, 1216
167 |
168 |
169 | def drop_depth_measurements(depth, prob_keep):
170 | mask = np.random.binomial(1, prob_keep, depth.shape)
171 | depth *= mask
172 | return depth
173 |
174 |
175 | def train_transform(rgb, sparse, target, rgb_near, args):
176 | # s = np.random.uniform(1.0, 1.5) # random scaling
177 | # angle = np.random.uniform(-5.0, 5.0) # random rotation degrees
178 | do_flip = np.random.uniform(0.0, 1.0) < 0.5 # random horizontal flip
179 |
180 | transform_geometric = transforms.Compose([
181 | # transforms.Rotate(angle),
182 | # transforms.Resize(s),
183 | transforms.BottomCrop((oheight, owidth)),
184 | transforms.HorizontalFlip(do_flip)
185 | ])
186 | if sparse is not None:
187 | sparse = transform_geometric(sparse)
188 | target = transform_geometric(target)
189 | if rgb is not None:
190 | brightness = np.random.uniform(max(0, 1 - args.jitter),
191 | 1 + args.jitter)
192 | contrast = np.random.uniform(max(0, 1 - args.jitter), 1 + args.jitter)
193 | saturation = np.random.uniform(max(0, 1 - args.jitter),
194 | 1 + args.jitter)
195 | transform_rgb = transforms.Compose([
196 | transforms.ColorJitter(brightness, contrast, saturation, 0),
197 | transform_geometric
198 | ])
199 | rgb = transform_rgb(rgb)
200 | if rgb_near is not None:
201 | rgb_near = transform_rgb(rgb_near)
202 | # sparse = drop_depth_measurements(sparse, 0.9)
203 |
204 | return rgb, sparse, target, rgb_near
205 |
206 |
207 | def val_transform(rgb, sparse, target, rgb_near, args):
208 | transform = transforms.Compose([
209 | transforms.BottomCrop((oheight, owidth)),
210 | ])
211 | if rgb is not None:
212 | rgb = transform(rgb)
213 | if sparse is not None:
214 | sparse = transform(sparse)
215 | if target is not None:
216 | target = transform(target)
217 | if rgb_near is not None:
218 | rgb_near = transform(rgb_near)
219 | return rgb, sparse, target, rgb_near
220 |
221 |
222 | def no_transform(rgb, sparse, target, rgb_near, args):
223 | return rgb, sparse, target, rgb_near
224 |
225 |
226 | to_tensor = transforms.ToTensor()
227 | to_float_tensor = lambda x: to_tensor(x).float()
228 |
229 |
230 | def handle_gray(rgb, args):
231 | if rgb is None:
232 | return None, None
233 | if not args.use_g:
234 | return rgb, None
235 | else:
236 | img = np.array(Image.fromarray(rgb).convert('L'))
237 | img = np.expand_dims(img, -1)
238 | if not args.use_rgb:
239 | rgb_ret = None
240 | else:
241 | rgb_ret = rgb
242 | return rgb_ret, img
243 |
244 |
245 | def get_rgb_near(path, args):
246 | assert path is not None, "path is None"
247 |
248 | def extract_frame_id(filename):
249 | head, tail = os.path.split(filename)
250 | number_string = tail[0:tail.find('.')]
251 | number = int(number_string)
252 | return head, number
253 |
254 | def get_nearby_filename(filename, new_id):
255 | head, _ = os.path.split(filename)
256 | new_filename = os.path.join(head, '%010d.png' % new_id)
257 | return new_filename
258 |
259 | head, number = extract_frame_id(path)
260 | count = 0
261 | max_frame_diff = 3
262 | candidates = [
263 | i - max_frame_diff for i in range(max_frame_diff * 2 + 1)
264 | if i - max_frame_diff != 0
265 | ]
266 | while True:
267 | random_offset = choice(candidates)
268 | path_near = get_nearby_filename(path, number + random_offset)
269 | if os.path.exists(path_near):
270 | break
271 | assert count < 20, "cannot find a nearby frame in 20 trials for {}".format(
272 | path)
273 | count += 1
274 |
275 | return rgb_read(path_near)
276 |
277 |
278 | class KittiDepth(data.Dataset):
279 | """A data loader for the Kitti dataset
280 | """
281 | def __init__(self, split, args):
282 | self.args = args
283 | self.split = split
284 | paths, transform = get_paths_and_transform(split, args)
285 | self.paths = paths
286 | self.transform = transform
287 | self.K = load_calib()
288 | self.threshold_translation = 0.1
289 |
290 | def __getraw__(self, index):
291 | rgb = rgb_read(self.paths['rgb'][index]) if \
292 | (self.paths['rgb'][index] is not None and (self.args.use_rgb or self.args.use_g)) else None
293 | sparse = depth_read(self.paths['d'][index]) if \
294 | (self.paths['d'][index] is not None and self.args.use_d) else None
295 | target = depth_read(self.paths['gt'][index]) if \
296 | self.paths['gt'][index] is not None else None
297 | rgb_near = get_rgb_near(self.paths['rgb'][index], self.args) if \
298 | self.split == 'train' and self.args.use_pose else None
299 | return rgb, sparse, target, rgb_near
300 |
301 | def __getitem__(self, index):
302 | rgb, sparse, target, rgb_near = self.__getraw__(index)
303 | rgb, sparse, target, rgb_near = self.transform(rgb, sparse, target,
304 | rgb_near, self.args)
305 | r_mat, t_vec = None, None
306 | if self.split == 'train' and self.args.use_pose:
307 | success, r_vec, t_vec = get_pose_pnp(rgb, rgb_near, sparse, self.K)
308 | # discard if translation is too small
309 | success = success and LA.norm(t_vec) > self.threshold_translation
310 | if success:
311 | r_mat, _ = cv2.Rodrigues(r_vec)
312 | else:
313 | # return the same image and no motion when PnP fails
314 | rgb_near = rgb
315 | t_vec = np.zeros((3, 1))
316 | r_mat = np.eye(3)
317 |
318 | rgb, gray = handle_gray(rgb, self.args)
319 | candidates = {"rgb":rgb, "d":sparse, "gt":target, \
320 | "g":gray, "r_mat":r_mat, "t_vec":t_vec, "rgb_near":rgb_near}
321 | items = {
322 | key: to_float_tensor(val)
323 | for key, val in candidates.items() if val is not None
324 | }
325 |
326 | return items
327 |
328 | def __len__(self):
329 | return len(self.paths['gt'])
330 |
--------------------------------------------------------------------------------
/dataloaders/pose_estimator.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 |
5 | def rgb2gray(rgb):
6 | return np.dot(rgb[..., :3], [0.299, 0.587, 0.114])
7 |
8 |
9 | def convert_2d_to_3d(u, v, z, K):
10 | v0 = K[1][2]
11 | u0 = K[0][2]
12 | fy = K[1][1]
13 | fx = K[0][0]
14 | x = (u - u0) * z / fx
15 | y = (v - v0) * z / fy
16 | return (x, y, z)
17 |
18 |
19 | def feature_match(img1, img2):
20 | r''' Find features on both images and match them pairwise
21 | '''
22 | max_n_features = 1000
23 | # max_n_features = 500
24 | use_flann = False # better not use flann
25 |
26 | detector = cv2.xfeatures2d.SIFT_create(max_n_features)
27 |
28 | # find the keypoints and descriptors with SIFT
29 | kp1, des1 = detector.detectAndCompute(img1, None)
30 | kp2, des2 = detector.detectAndCompute(img2, None)
31 | if (des1 is None) or (des2 is None):
32 | return [], []
33 | des1 = des1.astype(np.float32)
34 | des2 = des2.astype(np.float32)
35 |
36 | if use_flann:
37 | # FLANN parameters
38 | FLANN_INDEX_KDTREE = 0
39 | index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
40 | search_params = dict(checks=50)
41 | flann = cv2.FlannBasedMatcher(index_params, search_params)
42 | matches = flann.knnMatch(des1, des2, k=2)
43 | else:
44 | matcher = cv2.DescriptorMatcher().create('BruteForce')
45 | matches = matcher.knnMatch(des1, des2, k=2)
46 |
47 | good = []
48 | pts1 = []
49 | pts2 = []
50 | # ratio test as per Lowe's paper
51 | for i, (m, n) in enumerate(matches):
52 | if m.distance < 0.8 * n.distance:
53 | good.append(m)
54 | pts2.append(kp2[m.trainIdx].pt)
55 | pts1.append(kp1[m.queryIdx].pt)
56 |
57 | pts1 = np.int32(pts1)
58 | pts2 = np.int32(pts2)
59 | return pts1, pts2
60 |
61 |
62 | def get_pose_pnp(rgb_curr, rgb_near, depth_curr, K):
63 | gray_curr = rgb2gray(rgb_curr).astype(np.uint8)
64 | gray_near = rgb2gray(rgb_near).astype(np.uint8)
65 | height, width = gray_curr.shape
66 |
67 | pts2d_curr, pts2d_near = feature_match(gray_curr,
68 | gray_near) # feature matching
69 |
70 | # dilation of depth
71 | kernel = np.ones((4, 4), np.uint8)
72 | depth_curr_dilated = cv2.dilate(depth_curr, kernel)
73 |
74 | # extract 3d pts
75 | pts3d_curr = []
76 | pts2d_near_filtered = [
77 | ] # keep only feature points with depth in the current frame
78 | for i, pt2d in enumerate(pts2d_curr):
79 | # print(pt2d)
80 | u, v = pt2d[0], pt2d[1]
81 | z = depth_curr_dilated[v, u]
82 | if z > 0:
83 | xyz_curr = convert_2d_to_3d(u, v, z, K)
84 | pts3d_curr.append(xyz_curr)
85 | pts2d_near_filtered.append(pts2d_near[i])
86 |
87 | # the minimal number of points accepted by solvePnP is 4:
88 | if len(pts3d_curr) >= 4 and len(pts2d_near_filtered) >= 4:
89 | pts3d_curr = np.expand_dims(np.array(pts3d_curr).astype(np.float32),
90 | axis=1)
91 | pts2d_near_filtered = np.expand_dims(
92 | np.array(pts2d_near_filtered).astype(np.float32), axis=1)
93 |
94 | # ransac
95 | ret = cv2.solvePnPRansac(pts3d_curr,
96 | pts2d_near_filtered,
97 | K,
98 | distCoeffs=None)
99 | success = ret[0]
100 | rotation_vector = ret[1]
101 | translation_vector = ret[2]
102 | return (success, rotation_vector, translation_vector)
103 | else:
104 | return (0, None, None)
105 |
--------------------------------------------------------------------------------
/dataloaders/transforms.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import torch
3 | import math
4 | import random
5 |
6 | from PIL import Image, ImageOps, ImageEnhance
7 | try:
8 | import accimage
9 | except ImportError:
10 | accimage = None
11 |
12 | import numpy as np
13 | import numbers
14 | import types
15 | import collections
16 | import warnings
17 |
18 | import scipy.ndimage.interpolation as itpl
19 | import skimage.transform
20 |
21 |
22 | def _is_numpy_image(img):
23 | return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
24 |
25 |
26 | def _is_pil_image(img):
27 | if accimage is not None:
28 | return isinstance(img, (Image.Image, accimage.Image))
29 | else:
30 | return isinstance(img, Image.Image)
31 |
32 |
33 | def _is_tensor_image(img):
34 | return torch.is_tensor(img) and img.ndimension() == 3
35 |
36 |
37 | def adjust_brightness(img, brightness_factor):
38 | """Adjust brightness of an Image.
39 |
40 | Args:
41 | img (PIL Image): PIL Image to be adjusted.
42 | brightness_factor (float): How much to adjust the brightness. Can be
43 | any non negative number. 0 gives a black image, 1 gives the
44 | original image while 2 increases the brightness by a factor of 2.
45 |
46 | Returns:
47 | PIL Image: Brightness adjusted image.
48 | """
49 | if not _is_pil_image(img):
50 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
51 |
52 | enhancer = ImageEnhance.Brightness(img)
53 | img = enhancer.enhance(brightness_factor)
54 | return img
55 |
56 |
57 | def adjust_contrast(img, contrast_factor):
58 | """Adjust contrast of an Image.
59 |
60 | Args:
61 | img (PIL Image): PIL Image to be adjusted.
62 | contrast_factor (float): How much to adjust the contrast. Can be any
63 | non negative number. 0 gives a solid gray image, 1 gives the
64 | original image while 2 increases the contrast by a factor of 2.
65 |
66 | Returns:
67 | PIL Image: Contrast adjusted image.
68 | """
69 | if not _is_pil_image(img):
70 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
71 |
72 | enhancer = ImageEnhance.Contrast(img)
73 | img = enhancer.enhance(contrast_factor)
74 | return img
75 |
76 |
77 | def adjust_saturation(img, saturation_factor):
78 | """Adjust color saturation of an image.
79 |
80 | Args:
81 | img (PIL Image): PIL Image to be adjusted.
82 | saturation_factor (float): How much to adjust the saturation. 0 will
83 | give a black and white image, 1 will give the original image while
84 | 2 will enhance the saturation by a factor of 2.
85 |
86 | Returns:
87 | PIL Image: Saturation adjusted image.
88 | """
89 | if not _is_pil_image(img):
90 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
91 |
92 | enhancer = ImageEnhance.Color(img)
93 | img = enhancer.enhance(saturation_factor)
94 | return img
95 |
96 |
97 | def adjust_hue(img, hue_factor):
98 | """Adjust hue of an image.
99 |
100 | The image hue is adjusted by converting the image to HSV and
101 | cyclically shifting the intensities in the hue channel (H).
102 | The image is then converted back to original image mode.
103 |
104 | `hue_factor` is the amount of shift in H channel and must be in the
105 | interval `[-0.5, 0.5]`.
106 |
107 | See https://en.wikipedia.org/wiki/Hue for more details on Hue.
108 |
109 | Args:
110 | img (PIL Image): PIL Image to be adjusted.
111 | hue_factor (float): How much to shift the hue channel. Should be in
112 | [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
113 | HSV space in positive and negative direction respectively.
114 | 0 means no shift. Therefore, both -0.5 and 0.5 will give an image
115 | with complementary colors while 0 gives the original image.
116 |
117 | Returns:
118 | PIL Image: Hue adjusted image.
119 | """
120 | if not (-0.5 <= hue_factor <= 0.5):
121 | raise ValueError(
122 |             'hue_factor {} is not in [-0.5, 0.5].'.format(hue_factor))
123 |
124 | if not _is_pil_image(img):
125 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
126 |
127 | input_mode = img.mode
128 | if input_mode in {'L', '1', 'I', 'F'}:
129 | return img
130 |
131 | h, s, v = img.convert('HSV').split()
132 |
133 | np_h = np.array(h, dtype=np.uint8)
134 |     # uint8 addition takes care of rotation across boundaries
135 | with np.errstate(over='ignore'):
136 | np_h += np.uint8(hue_factor * 255)
137 | h = Image.fromarray(np_h, 'L')
138 |
139 | img = Image.merge('HSV', (h, s, v)).convert(input_mode)
140 | return img
141 |
142 |
143 | def adjust_gamma(img, gamma, gain=1):
144 | """Perform gamma correction on an image.
145 |
146 | Also known as Power Law Transform. Intensities in RGB mode are adjusted
147 | based on the following equation:
148 |
149 | I_out = 255 * gain * ((I_in / 255) ** gamma)
150 |
151 | See https://en.wikipedia.org/wiki/Gamma_correction for more details.
152 |
153 | Args:
154 | img (PIL Image): PIL Image to be adjusted.
155 | gamma (float): Non negative real number. gamma larger than 1 make the
156 | shadows darker, while gamma smaller than 1 make dark regions
157 | lighter.
158 | gain (float): The constant multiplier.
159 | """
160 | if not _is_pil_image(img):
161 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
162 |
163 | if gamma < 0:
164 | raise ValueError('Gamma should be a non-negative real number')
165 |
166 | input_mode = img.mode
167 | img = img.convert('RGB')
168 |
169 | np_img = np.array(img, dtype=np.float32)
170 | np_img = 255 * gain * ((np_img / 255)**gamma)
171 | np_img = np.uint8(np.clip(np_img, 0, 255))
172 |
173 | img = Image.fromarray(np_img, 'RGB').convert(input_mode)
174 | return img
175 |
176 |
177 | class Compose(object):
178 | """Composes several transforms together.
179 |
180 | Args:
181 | transforms (list of ``Transform`` objects): list of transforms to compose.
182 |
183 | Example:
184 | >>> transforms.Compose([
185 | >>> transforms.CenterCrop(10),
186 | >>> transforms.ToTensor(),
187 | >>> ])
188 | """
189 | def __init__(self, transforms):
190 | self.transforms = transforms
191 |
192 | def __call__(self, img):
193 | for t in self.transforms:
194 | img = t(img)
195 | return img
196 |
197 |
198 | class ToTensor(object):
199 | """Convert a ``numpy.ndarray`` to tensor.
200 |
201 | Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W).
202 | """
203 | def __call__(self, img):
204 | """Convert a ``numpy.ndarray`` to tensor.
205 |
206 | Args:
207 | img (numpy.ndarray): Image to be converted to tensor.
208 |
209 | Returns:
210 | Tensor: Converted image.
211 | """
212 | if not (_is_numpy_image(img)):
213 | raise TypeError('img should be ndarray. Got {}'.format(type(img)))
214 |
215 | if isinstance(img, np.ndarray):
216 | # handle numpy array
217 | if img.ndim == 3:
218 | img = torch.from_numpy(img.transpose((2, 0, 1)).copy())
219 | elif img.ndim == 2:
220 | img = torch.from_numpy(img.copy())
221 | else:
222 | raise RuntimeError(
223 | 'img should be ndarray with 2 or 3 dimensions. Got {}'.
224 | format(img.ndim))
225 |
226 | return img
227 |
228 |
229 | class NormalizeNumpyArray(object):
230 | """Normalize a ``numpy.ndarray`` with mean and standard deviation.
231 | Given mean: ``(M1,...,Mn)`` and std: ``(M1,..,Mn)`` for ``n`` channels, this transform
232 | will normalize each channel of the input ``numpy.ndarray`` i.e.
233 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
234 |
235 | Args:
236 | mean (sequence): Sequence of means for each channel.
237 | std (sequence): Sequence of standard deviations for each channel.
238 | """
239 | def __init__(self, mean, std):
240 | self.mean = mean
241 | self.std = std
242 |
243 | def __call__(self, img):
244 | """
245 | Args:
246 | img (numpy.ndarray): Image of size (H, W, C) to be normalized.
247 |
248 | Returns:
249 | Tensor: Normalized image.
250 | """
251 | if not (_is_numpy_image(img)):
252 | raise TypeError('img should be ndarray. Got {}'.format(type(img)))
253 | # TODO: make efficient
254 | print(img.shape)
255 | for i in range(3):
256 | img[:, :, i] = (img[:, :, i] - self.mean[i]) / self.std[i]
257 | return img
258 |
259 |
260 | class NormalizeTensor(object):
261 | """Normalize an tensor image with mean and standard deviation.
262 | Given mean: ``(M1,...,Mn)`` and std: ``(M1,..,Mn)`` for ``n`` channels, this transform
263 | will normalize each channel of the input ``torch.*Tensor`` i.e.
264 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
265 |
266 | Args:
267 | mean (sequence): Sequence of means for each channel.
268 | std (sequence): Sequence of standard deviations for each channel.
269 | """
270 | def __init__(self, mean, std):
271 | self.mean = mean
272 | self.std = std
273 |
274 | def __call__(self, tensor):
275 | """
276 | Args:
277 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
278 |
279 | Returns:
280 | Tensor: Normalized Tensor image.
281 | """
282 | if not _is_tensor_image(tensor):
283 | raise TypeError('tensor is not a torch image.')
284 | # TODO: make efficient
285 | for t, m, s in zip(tensor, self.mean, self.std):
286 | t.sub_(m).div_(s)
287 | return tensor
288 |
289 |
290 | class Rotate(object):
291 | """Rotates the given ``numpy.ndarray``.
292 |
293 | Args:
294 | angle (float): The rotation angle in degrees.
295 | """
296 | def __init__(self, angle):
297 | self.angle = angle
298 |
299 | def __call__(self, img):
300 | """
301 | Args:
302 | img (numpy.ndarray (C x H x W)): Image to be rotated.
303 |
304 | Returns:
305 | img (numpy.ndarray (C x H x W)): Rotated image.
306 | """
307 |
308 | # order=0 means nearest-neighbor type interpolation
309 | return skimage.transform.rotate(img, self.angle, resize=False, order=0)
310 |
311 |
312 | class Resize(object):
313 | """Rescale the given ``numpy.ndarray`` by a scale factor.
314 |     Args:
315 |         size (float): Scale factor applied to both the height and the width
316 |             of the image. The implementation calls
317 |             ``skimage.transform.rescale`` with nearest-neighbor
318 |             interpolation (``order=0``).
319 |         interpolation (str, optional): Desired interpolation. Default is
320 |             ``'nearest'``; the value is stored but not used by ``__call__``.
321 |     """
323 | def __init__(self, size, interpolation='nearest'):
324 | assert isinstance(size, float)
325 | self.size = size
326 | self.interpolation = interpolation
327 |
328 | def __call__(self, img):
329 | """
330 | Args:
331 | img (numpy.ndarray (C x H x W)): Image to be scaled.
332 | Returns:
333 | img (numpy.ndarray (C x H x W)): Rescaled image.
334 | """
335 | if img.ndim == 3:
336 | return skimage.transform.rescale(img, self.size, order=0)
337 | elif img.ndim == 2:
338 | return skimage.transform.rescale(img, self.size, order=0)
339 | else:
340 |             raise RuntimeError(
341 | 'img should be ndarray with 2 or 3 dimensions. Got {}'.format(
342 | img.ndim))
343 |
344 |
345 | class CenterCrop(object):
346 | """Crops the given ``numpy.ndarray`` at the center.
347 |
348 | Args:
349 | size (sequence or int): Desired output size of the crop. If size is an
350 | int instead of sequence like (h, w), a square crop (size, size) is
351 | made.
352 | """
353 | def __init__(self, size):
354 | if isinstance(size, numbers.Number):
355 | self.size = (int(size), int(size))
356 | else:
357 | self.size = size
358 |
359 | @staticmethod
360 | def get_params(img, output_size):
361 | """Get parameters for ``crop`` for center crop.
362 |
363 | Args:
364 | img (numpy.ndarray (C x H x W)): Image to be cropped.
365 | output_size (tuple): Expected output size of the crop.
366 |
367 | Returns:
368 | tuple: params (i, j, h, w) to be passed to ``crop`` for center crop.
369 | """
370 | h = img.shape[0]
371 | w = img.shape[1]
372 | th, tw = output_size
373 | i = int(round((h - th) / 2.))
374 | j = int(round((w - tw) / 2.))
375 |
376 | # # randomized cropping
377 | # i = np.random.randint(i-3, i+4)
378 | # j = np.random.randint(j-3, j+4)
379 |
380 | return i, j, th, tw
381 |
382 | def __call__(self, img):
383 | """
384 | Args:
385 | img (numpy.ndarray (C x H x W)): Image to be cropped.
386 |
387 | Returns:
388 | img (numpy.ndarray (C x H x W)): Cropped image.
389 | """
390 | i, j, h, w = self.get_params(img, self.size)
391 | """
392 | i: Upper pixel coordinate.
393 | j: Left pixel coordinate.
394 | h: Height of the cropped image.
395 | w: Width of the cropped image.
396 | """
397 | if not (_is_numpy_image(img)):
398 | raise TypeError('img should be ndarray. Got {}'.format(type(img)))
399 | if img.ndim == 3:
400 | return img[i:i + h, j:j + w, :]
401 | elif img.ndim == 2:
402 | return img[i:i + h, j:j + w]
403 | else:
404 | raise RuntimeError(
405 | 'img should be ndarray with 2 or 3 dimensions. Got {}'.format(
406 | img.ndim))
407 |
408 |
409 | class BottomCrop(object):
410 | """Crops the given ``numpy.ndarray`` at the bottom.
411 |
412 | Args:
413 | size (sequence or int): Desired output size of the crop. If size is an
414 | int instead of sequence like (h, w), a square crop (size, size) is
415 | made.
416 | """
417 | def __init__(self, size):
418 | if isinstance(size, numbers.Number):
419 | self.size = (int(size), int(size))
420 | else:
421 | self.size = size
422 |
423 | @staticmethod
424 | def get_params(img, output_size):
425 | """Get parameters for ``crop`` for bottom crop.
426 |
427 | Args:
428 | img (numpy.ndarray (C x H x W)): Image to be cropped.
429 | output_size (tuple): Expected output size of the crop.
430 |
431 | Returns:
432 | tuple: params (i, j, h, w) to be passed to ``crop`` for bottom crop.
433 | """
434 | h = img.shape[0]
435 | w = img.shape[1]
436 | th, tw = output_size
437 | i = h - th
438 | j = int(round((w - tw) / 2.))
439 |
440 | # randomized left and right cropping
441 | # i = np.random.randint(i-3, i+4)
442 | # j = np.random.randint(j-1, j+1)
443 |
444 | return i, j, th, tw
445 |
446 | def __call__(self, img):
447 | """
448 | Args:
449 | img (numpy.ndarray (C x H x W)): Image to be cropped.
450 |
451 | Returns:
452 | img (numpy.ndarray (C x H x W)): Cropped image.
453 | """
454 | i, j, h, w = self.get_params(img, self.size)
455 | """
456 | i: Upper pixel coordinate.
457 | j: Left pixel coordinate.
458 | h: Height of the cropped image.
459 | w: Width of the cropped image.
460 | """
461 | if not (_is_numpy_image(img)):
462 | raise TypeError('img should be ndarray. Got {}'.format(type(img)))
463 | if img.ndim == 3:
464 | return img[i:i + h, j:j + w, :]
465 | elif img.ndim == 2:
466 | return img[i:i + h, j:j + w]
467 | else:
468 | raise RuntimeError(
469 | 'img should be ndarray with 2 or 3 dimensions. Got {}'.format(
470 | img.ndim))
471 |
472 |
473 | class Crop(object):
474 | """Crops the given ``numpy.ndarray`` to an explicit region.
475 | 
476 |     Args:
477 |         crop (tuple): Crop boundaries (x_l, x_r, y_b, y_t). The crop keeps
478 |             rows ``y_b:y_t`` and columns ``x_l:x_r`` of the input
479 |             image.
480 |     """
481 | def __init__(self, crop):
482 | self.crop = crop
483 |
484 | @staticmethod
485 | def get_params(img, crop):
486 |         """Validate the crop boundaries.
487 | 
488 |         Args:
489 |             img (numpy.ndarray): Image to be cropped.
490 |             crop (tuple): Crop boundaries (x_l, x_r, y_b, y_t).
491 | 
492 |         Returns:
493 |             tuple: the validated (x_l, x_r, y_b, y_t) boundaries.
494 |         """
495 | x_l, x_r, y_b, y_t = crop
496 | h = img.shape[0]
497 | w = img.shape[1]
498 | assert x_l >= 0 and x_l < w
499 | assert x_r >= 0 and x_r < w
500 | assert y_b >= 0 and y_b < h
501 | assert y_t >= 0 and y_t < h
502 | assert x_l < x_r and y_b < y_t
503 |
504 | return x_l, x_r, y_b, y_t
505 |
506 | def __call__(self, img):
507 | """
508 | Args:
509 | img (numpy.ndarray (C x H x W)): Image to be cropped.
510 |
511 | Returns:
512 | img (numpy.ndarray (C x H x W)): Cropped image.
513 | """
514 | x_l, x_r, y_b, y_t = self.get_params(img, self.crop)
515 |         """
516 |         x_l: Left column index of the crop.
517 |         x_r: Right column index of the crop (exclusive).
518 |         y_b: First row index of the crop.
519 |         y_t: Last row index of the crop (exclusive).
520 |         """
521 | if not (_is_numpy_image(img)):
522 | raise TypeError('img should be ndarray. Got {}'.format(type(img)))
523 | if img.ndim == 3:
524 | return img[y_b:y_t, x_l:x_r, :]
525 | elif img.ndim == 2:
526 | return img[y_b:y_t, x_l:x_r]
527 | else:
528 | raise RuntimeError(
529 | 'img should be ndarray with 2 or 3 dimensions. Got {}'.format(
530 | img.ndim))
531 |
532 |
533 | class Lambda(object):
534 | """Apply a user-defined lambda as a transform.
535 |
536 | Args:
537 | lambd (function): Lambda/function to be used for transform.
538 | """
539 | def __init__(self, lambd):
540 | assert isinstance(lambd, types.LambdaType)
541 | self.lambd = lambd
542 |
543 | def __call__(self, img):
544 | return self.lambd(img)
545 |
546 |
547 | class HorizontalFlip(object):
548 | """Horizontally flip the given ``numpy.ndarray``.
549 |
550 | Args:
551 | do_flip (boolean): whether or not do horizontal flip.
552 |
553 | """
554 | def __init__(self, do_flip):
555 | self.do_flip = do_flip
556 |
557 | def __call__(self, img):
558 | """
559 | Args:
560 | img (numpy.ndarray (C x H x W)): Image to be flipped.
561 |
562 | Returns:
563 | img (numpy.ndarray (C x H x W)): flipped image.
564 | """
565 | if not (_is_numpy_image(img)):
566 | raise TypeError('img should be ndarray. Got {}'.format(type(img)))
567 |
568 | if self.do_flip:
569 | return np.fliplr(img)
570 | else:
571 | return img
572 |
573 |
574 | class ColorJitter(object):
575 | """Randomly change the brightness, contrast and saturation of an image.
576 |
577 | Args:
578 | brightness (float): How much to jitter brightness. brightness_factor
579 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
580 | contrast (float): How much to jitter contrast. contrast_factor
581 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
582 | saturation (float): How much to jitter saturation. saturation_factor
583 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
584 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from
585 | [-hue, hue]. Should be >=0 and <= 0.5.
586 | """
587 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
588 | transforms = []
589 | transforms.append(
590 | Lambda(lambda img: adjust_brightness(img, brightness)))
591 | transforms.append(Lambda(lambda img: adjust_contrast(img, contrast)))
592 | transforms.append(
593 | Lambda(lambda img: adjust_saturation(img, saturation)))
594 | transforms.append(Lambda(lambda img: adjust_hue(img, hue)))
595 | np.random.shuffle(transforms)
596 | self.transform = Compose(transforms)
597 |
598 | def __call__(self, img):
599 | """
600 | Args:
601 | img (numpy.ndarray (C x H x W)): Input image.
602 |
603 | Returns:
604 | img (numpy.ndarray (C x H x W)): Color jittered image.
605 | """
606 | if not (_is_numpy_image(img)):
607 | raise TypeError('img should be ndarray. Got {}'.format(type(img)))
608 |
609 | pil = Image.fromarray(img)
610 | return np.array(self.transform(pil))
611 |
--------------------------------------------------------------------------------
/download/rgb_train_downloader.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | files=(
4 | # 2011_09_26_calib.zip
5 | 2011_09_26_drive_0001
6 | # 2011_09_26_drive_0002
7 | # 2011_09_26_drive_0005
8 | 2011_09_26_drive_0009
9 | 2011_09_26_drive_0011
10 | # 2011_09_26_drive_0013
11 | 2011_09_26_drive_0014
12 | 2011_09_26_drive_0015
13 | 2011_09_26_drive_0017
14 | 2011_09_26_drive_0018
15 | 2011_09_26_drive_0019
16 | # 2011_09_26_drive_0020
17 | 2011_09_26_drive_0022
18 | # 2011_09_26_drive_0023
19 | 2011_09_26_drive_0027
20 | 2011_09_26_drive_0028
21 | 2011_09_26_drive_0029
22 | 2011_09_26_drive_0032
23 | 2011_09_26_drive_0035
24 | # 2011_09_26_drive_0036
25 | 2011_09_26_drive_0039
26 | 2011_09_26_drive_0046
27 | 2011_09_26_drive_0048
28 | 2011_09_26_drive_0051
29 | 2011_09_26_drive_0052
30 | 2011_09_26_drive_0056
31 | 2011_09_26_drive_0057
32 | 2011_09_26_drive_0059
33 | 2011_09_26_drive_0060
34 | 2011_09_26_drive_0061
35 | 2011_09_26_drive_0064
36 | 2011_09_26_drive_0070
37 | # 2011_09_26_drive_0079
38 | 2011_09_26_drive_0084
39 | 2011_09_26_drive_0086
40 | 2011_09_26_drive_0087
41 | 2011_09_26_drive_0091
42 | 2011_09_26_drive_0093
43 | # 2011_09_26_drive_0095
44 | 2011_09_26_drive_0096
45 | 2011_09_26_drive_0101
46 | 2011_09_26_drive_0104
47 | 2011_09_26_drive_0106
48 | # 2011_09_26_drive_0113
49 | 2011_09_26_drive_0117
50 | # 2011_09_26_drive_0119
51 | # 2011_09_28_calib.zip
52 | 2011_09_28_drive_0001
53 | 2011_09_28_drive_0002
54 | 2011_09_28_drive_0016
55 | 2011_09_28_drive_0021
56 | 2011_09_28_drive_0034
57 | 2011_09_28_drive_0035
58 | # 2011_09_28_drive_0037
59 | 2011_09_28_drive_0038
60 | 2011_09_28_drive_0039
61 | 2011_09_28_drive_0043
62 | 2011_09_28_drive_0045
63 | 2011_09_28_drive_0047
64 | 2011_09_28_drive_0053
65 | 2011_09_28_drive_0054
66 | 2011_09_28_drive_0057
67 | 2011_09_28_drive_0065
68 | 2011_09_28_drive_0066
69 | 2011_09_28_drive_0068
70 | 2011_09_28_drive_0070
71 | 2011_09_28_drive_0071
72 | 2011_09_28_drive_0075
73 | 2011_09_28_drive_0077
74 | 2011_09_28_drive_0078
75 | 2011_09_28_drive_0080
76 | 2011_09_28_drive_0082
77 | 2011_09_28_drive_0086
78 | 2011_09_28_drive_0087
79 | 2011_09_28_drive_0089
80 | 2011_09_28_drive_0090
81 | 2011_09_28_drive_0094
82 | 2011_09_28_drive_0095
83 | 2011_09_28_drive_0096
84 | 2011_09_28_drive_0098
85 | 2011_09_28_drive_0100
86 | 2011_09_28_drive_0102
87 | 2011_09_28_drive_0103
88 | 2011_09_28_drive_0104
89 | 2011_09_28_drive_0106
90 | 2011_09_28_drive_0108
91 | 2011_09_28_drive_0110
92 | 2011_09_28_drive_0113
93 | 2011_09_28_drive_0117
94 | 2011_09_28_drive_0119
95 | 2011_09_28_drive_0121
96 | 2011_09_28_drive_0122
97 | 2011_09_28_drive_0125
98 | 2011_09_28_drive_0126
99 | 2011_09_28_drive_0128
100 | 2011_09_28_drive_0132
101 | 2011_09_28_drive_0134
102 | 2011_09_28_drive_0135
103 | 2011_09_28_drive_0136
104 | 2011_09_28_drive_0138
105 | 2011_09_28_drive_0141
106 | 2011_09_28_drive_0143
107 | 2011_09_28_drive_0145
108 | 2011_09_28_drive_0146
109 | 2011_09_28_drive_0149
110 | 2011_09_28_drive_0153
111 | 2011_09_28_drive_0154
112 | 2011_09_28_drive_0155
113 | 2011_09_28_drive_0156
114 | 2011_09_28_drive_0160
115 | 2011_09_28_drive_0161
116 | 2011_09_28_drive_0162
117 | 2011_09_28_drive_0165
118 | 2011_09_28_drive_0166
119 | 2011_09_28_drive_0167
120 | 2011_09_28_drive_0168
121 | 2011_09_28_drive_0171
122 | 2011_09_28_drive_0174
123 | 2011_09_28_drive_0177
124 | 2011_09_28_drive_0179
125 | 2011_09_28_drive_0183
126 | 2011_09_28_drive_0184
127 | 2011_09_28_drive_0185
128 | 2011_09_28_drive_0186
129 | 2011_09_28_drive_0187
130 | 2011_09_28_drive_0191
131 | 2011_09_28_drive_0192
132 | 2011_09_28_drive_0195
133 | 2011_09_28_drive_0198
134 | 2011_09_28_drive_0199
135 | 2011_09_28_drive_0201
136 | 2011_09_28_drive_0204
137 | 2011_09_28_drive_0205
138 | 2011_09_28_drive_0208
139 | 2011_09_28_drive_0209
140 | 2011_09_28_drive_0214
141 | 2011_09_28_drive_0216
142 | 2011_09_28_drive_0220
143 | 2011_09_28_drive_0222
144 | # 2011_09_28_drive_0225
145 | # 2011_09_29_calib.zip
146 | 2011_09_29_drive_0004
147 | # 2011_09_29_drive_0026
148 | 2011_09_29_drive_0071
149 | # 2011_09_29_drive_0108
150 | # 2011_09_30_calib.zip
151 | # 2011_09_30_drive_0016
152 | 2011_09_30_drive_0018
153 | 2011_09_30_drive_0020
154 | 2011_09_30_drive_0027
155 | 2011_09_30_drive_0028
156 | 2011_09_30_drive_0033
157 | 2011_09_30_drive_0034
158 | # 2011_09_30_drive_0072
159 | # 2011_10_03_calib.zip
160 | 2011_10_03_drive_0027
161 | 2011_10_03_drive_0034
162 | 2011_10_03_drive_0042
163 | # 2011_10_03_drive_0047
164 | # 2011_10_03_drive_0058
165 | )
166 |
167 | basedir='../data/data_rgb/train/'
168 | mkdir -p $basedir
169 | echo "Saving to "$basedir
170 | for i in ${files[@]}; do
171 | datadate="${i%%_drive_*}"
172 | echo $datadate
173 | shortname=$i'_sync.zip'
174 | fullname=$i'/'$i'_sync.zip'
175 | rm -f $shortname # remove zip file
176 | echo "Downloading: "$shortname
177 |
178 | 	wget 'https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/'$fullname
179 | unzip -o $shortname
180 | mv $datadate'/'$i'_sync' $basedir$i'_sync'
181 | rmdir $datadate
182 | rm -rf $basedir$i'_sync/image_00' $basedir$i'_sync/image_01' $basedir$i'_sync/velodyne_points' $basedir$i'_sync/oxts'
183 | rm $shortname # remove zip file
184 | done
185 |
186 |
187 |
--------------------------------------------------------------------------------
/download/rgb_val_downloader.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | files=(
4 | # 2011_09_26_calib.zip
5 | # 2011_09_26_drive_0001
6 | 2011_09_26_drive_0002
7 | 2011_09_26_drive_0005
8 | # 2011_09_26_drive_0009
9 | # 2011_09_26_drive_0011
10 | 2011_09_26_drive_0013
11 | # 2011_09_26_drive_0014
12 | # 2011_09_26_drive_0015
13 | # 2011_09_26_drive_0017
14 | # 2011_09_26_drive_0018
15 | # 2011_09_26_drive_0019
16 | 2011_09_26_drive_0020
17 | # 2011_09_26_drive_0022
18 | 2011_09_26_drive_0023
19 | # 2011_09_26_drive_0027
20 | # 2011_09_26_drive_0028
21 | # 2011_09_26_drive_0029
22 | # 2011_09_26_drive_0032
23 | # 2011_09_26_drive_0035
24 | 2011_09_26_drive_0036
25 | # 2011_09_26_drive_0039
26 | # 2011_09_26_drive_0046
27 | # 2011_09_26_drive_0048
28 | # 2011_09_26_drive_0051
29 | # 2011_09_26_drive_0052
30 | # 2011_09_26_drive_0056
31 | # 2011_09_26_drive_0057
32 | # 2011_09_26_drive_0059
33 | # 2011_09_26_drive_0060
34 | # 2011_09_26_drive_0061
35 | # 2011_09_26_drive_0064
36 | # 2011_09_26_drive_0070
37 | 2011_09_26_drive_0079
38 | # 2011_09_26_drive_0084
39 | # 2011_09_26_drive_0086
40 | # 2011_09_26_drive_0087
41 | # 2011_09_26_drive_0091
42 | # 2011_09_26_drive_0093
43 | 2011_09_26_drive_0095
44 | # 2011_09_26_drive_0096
45 | # 2011_09_26_drive_0101
46 | # 2011_09_26_drive_0104
47 | # 2011_09_26_drive_0106
48 | 2011_09_26_drive_0113
49 | # 2011_09_26_drive_0117
50 | 2011_09_26_drive_0119
51 | # 2011_09_28_calib.zip
52 | # 2011_09_28_drive_0001
53 | # 2011_09_28_drive_0002
54 | # 2011_09_28_drive_0016
55 | # 2011_09_28_drive_0021
56 | # 2011_09_28_drive_0034
57 | # 2011_09_28_drive_0035
58 | 2011_09_28_drive_0037
59 | # 2011_09_28_drive_0038
60 | # 2011_09_28_drive_0039
61 | # 2011_09_28_drive_0043
62 | # 2011_09_28_drive_0045
63 | # 2011_09_28_drive_0047
64 | # 2011_09_28_drive_0053
65 | # 2011_09_28_drive_0054
66 | # 2011_09_28_drive_0057
67 | # 2011_09_28_drive_0065
68 | # 2011_09_28_drive_0066
69 | # 2011_09_28_drive_0068
70 | # 2011_09_28_drive_0070
71 | # 2011_09_28_drive_0071
72 | # 2011_09_28_drive_0075
73 | # 2011_09_28_drive_0077
74 | # 2011_09_28_drive_0078
75 | # 2011_09_28_drive_0080
76 | # 2011_09_28_drive_0082
77 | # 2011_09_28_drive_0086
78 | # 2011_09_28_drive_0087
79 | # 2011_09_28_drive_0089
80 | # 2011_09_28_drive_0090
81 | # 2011_09_28_drive_0094
82 | # 2011_09_28_drive_0095
83 | # 2011_09_28_drive_0096
84 | # 2011_09_28_drive_0098
85 | # 2011_09_28_drive_0100
86 | # 2011_09_28_drive_0102
87 | # 2011_09_28_drive_0103
88 | # 2011_09_28_drive_0104
89 | # 2011_09_28_drive_0106
90 | # 2011_09_28_drive_0108
91 | # 2011_09_28_drive_0110
92 | # 2011_09_28_drive_0113
93 | # 2011_09_28_drive_0117
94 | # 2011_09_28_drive_0119
95 | # 2011_09_28_drive_0121
96 | # 2011_09_28_drive_0122
97 | # 2011_09_28_drive_0125
98 | # 2011_09_28_drive_0126
99 | # 2011_09_28_drive_0128
100 | # 2011_09_28_drive_0132
101 | # 2011_09_28_drive_0134
102 | # 2011_09_28_drive_0135
103 | # 2011_09_28_drive_0136
104 | # 2011_09_28_drive_0138
105 | # 2011_09_28_drive_0141
106 | # 2011_09_28_drive_0143
107 | # 2011_09_28_drive_0145
108 | # 2011_09_28_drive_0146
109 | # 2011_09_28_drive_0149
110 | # 2011_09_28_drive_0153
111 | # 2011_09_28_drive_0154
112 | # 2011_09_28_drive_0155
113 | # 2011_09_28_drive_0156
114 | # 2011_09_28_drive_0160
115 | # 2011_09_28_drive_0161
116 | # 2011_09_28_drive_0162
117 | # 2011_09_28_drive_0165
118 | # 2011_09_28_drive_0166
119 | # 2011_09_28_drive_0167
120 | # 2011_09_28_drive_0168
121 | # 2011_09_28_drive_0171
122 | # 2011_09_28_drive_0174
123 | # 2011_09_28_drive_0177
124 | # 2011_09_28_drive_0179
125 | # 2011_09_28_drive_0183
126 | # 2011_09_28_drive_0184
127 | # 2011_09_28_drive_0185
128 | # 2011_09_28_drive_0186
129 | # 2011_09_28_drive_0187
130 | # 2011_09_28_drive_0191
131 | # 2011_09_28_drive_0192
132 | # 2011_09_28_drive_0195
133 | # 2011_09_28_drive_0198
134 | # 2011_09_28_drive_0199
135 | # 2011_09_28_drive_0201
136 | # 2011_09_28_drive_0204
137 | # 2011_09_28_drive_0205
138 | # 2011_09_28_drive_0208
139 | # 2011_09_28_drive_0209
140 | # 2011_09_28_drive_0214
141 | # 2011_09_28_drive_0216
142 | # 2011_09_28_drive_0220
143 | # 2011_09_28_drive_0222
144 | 2011_09_28_drive_0225
145 | # 2011_09_29_calib.zip
146 | # 2011_09_29_drive_0004
147 | 2011_09_29_drive_0026
148 | # 2011_09_29_drive_0071
149 | 2011_09_29_drive_0108
150 | # 2011_09_30_calib.zip
151 | 2011_09_30_drive_0016
152 | # 2011_09_30_drive_0018
153 | # 2011_09_30_drive_0020
154 | # 2011_09_30_drive_0027
155 | # 2011_09_30_drive_0028
156 | # 2011_09_30_drive_0033
157 | # 2011_09_30_drive_0034
158 | 2011_09_30_drive_0072
159 | # 2011_10_03_calib.zip
160 | # 2011_10_03_drive_0027
161 | # 2011_10_03_drive_0034
162 | # 2011_10_03_drive_0042
163 | 2011_10_03_drive_0047
164 | 2011_10_03_drive_0058
165 | )
166 |
167 | basedir='../data/data_rgb/val/'
168 | mkdir -p $basedir
169 | echo "Saving to "$basedir
170 | for i in ${files[@]}; do
171 | datadate="${i%%_drive_*}"
172 | echo $datadate
173 | shortname=$i'_sync.zip'
174 | fullname=$i'/'$i'_sync.zip'
175 |     rm -f $shortname # remove any pre-existing zip before downloading
176 | echo "Downloading: "$shortname
177 |
178 | wget 's3.eu-central-1.amazonaws.com/avg-kitti/raw_data/'$fullname
179 | unzip -o $shortname
180 | mv $datadate'/'$i'_sync' $basedir$i'_sync'
181 | rmdir $datadate
182 | rm -rf $basedir$i'_sync/image_00' $basedir$i'_sync/image_01' $basedir$i'_sync/velodyne_points' $basedir$i'_sync/oxts'
183 | rm $shortname # remove zip file
184 | done
185 |
186 |
187 |
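After the loop finishes, each selected validation drive sits under the default `basedir` with only the two color cameras left (image_00, image_01, velodyne_points and oxts are deleted above). A minimal sketch, assuming the standard KITTI raw-sync folder layout and that the downloads succeeded, to spot-check the result:

```python
# Sketch only: verify the layout produced by rgb_val_downloader.sh.
# The glob pattern assumes the usual KITTI image_02/image_03 (color camera) folders.
import glob

frames = sorted(glob.glob('../data/data_rgb/val/*_sync/image_0[23]/data/*.png'))
print(len(frames), frames[:2])  # a few left/right color frames per downloaded drive
```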
--------------------------------------------------------------------------------
/helper.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os, time
3 | import shutil
4 | import torch
5 | import csv
6 | import vis_utils
7 | from metrics import Result
8 |
9 | fieldnames = [
10 | 'epoch', 'rmse', 'photo', 'mae', 'irmse', 'imae', 'mse', 'absrel', 'lg10',
11 | 'silog', 'squared_rel', 'delta1', 'delta2', 'delta3', 'data_time',
12 | 'gpu_time'
13 | ]
14 |
15 |
16 | class logger:
17 | def __init__(self, args, prepare=True):
18 | self.args = args
19 | output_directory = get_folder_name(args)
20 | self.output_directory = output_directory
21 | self.best_result = Result()
22 | self.best_result.set_to_worst()
23 |
24 | if not prepare:
25 | return
26 | if not os.path.exists(output_directory):
27 | os.makedirs(output_directory)
28 | self.train_csv = os.path.join(output_directory, 'train.csv')
29 | self.val_csv = os.path.join(output_directory, 'val.csv')
30 | self.best_txt = os.path.join(output_directory, 'best.txt')
31 |
32 | # backup the source code
33 | if args.resume == '':
34 | print("=> creating source code backup ...")
35 | backup_directory = os.path.join(output_directory, "code_backup")
36 | self.backup_directory = backup_directory
37 | backup_source_code(backup_directory)
38 | # create new csv files with only header
39 | with open(self.train_csv, 'w') as csvfile:
40 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
41 | writer.writeheader()
42 | with open(self.val_csv, 'w') as csvfile:
43 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
44 | writer.writeheader()
45 | print("=> finished creating source code backup.")
46 |
47 | def conditional_print(self, split, i, epoch, lr, n_set, blk_avg_meter,
48 | avg_meter):
49 | if (i + 1) % self.args.print_freq == 0:
50 | avg = avg_meter.average()
51 | blk_avg = blk_avg_meter.average()
52 | print('=> output: {}'.format(self.output_directory))
53 | print(
54 | '{split} Epoch: {0} [{1}/{2}]\tlr={lr} '
55 | 't_Data={blk_avg.data_time:.3f}({average.data_time:.3f}) '
56 | 't_GPU={blk_avg.gpu_time:.3f}({average.gpu_time:.3f})\n\t'
57 | 'RMSE={blk_avg.rmse:.2f}({average.rmse:.2f}) '
58 | 'MAE={blk_avg.mae:.2f}({average.mae:.2f}) '
59 | 'iRMSE={blk_avg.irmse:.2f}({average.irmse:.2f}) '
60 | 'iMAE={blk_avg.imae:.2f}({average.imae:.2f})\n\t'
61 | 'silog={blk_avg.silog:.2f}({average.silog:.2f}) '
62 | 'squared_rel={blk_avg.squared_rel:.2f}({average.squared_rel:.2f}) '
63 | 'Delta1={blk_avg.delta1:.3f}({average.delta1:.3f}) '
64 | 'REL={blk_avg.absrel:.3f}({average.absrel:.3f})\n\t'
65 | 'Lg10={blk_avg.lg10:.3f}({average.lg10:.3f}) '
66 | 'Photometric={blk_avg.photometric:.3f}({average.photometric:.3f}) '
67 | .format(epoch,
68 | i + 1,
69 | n_set,
70 | lr=lr,
71 | blk_avg=blk_avg,
72 | average=avg,
73 | split=split.capitalize()))
74 | blk_avg_meter.reset()
75 |
76 | def conditional_save_info(self, split, average_meter, epoch):
77 | avg = average_meter.average()
78 | if split == "train":
79 | csvfile_name = self.train_csv
80 | elif split == "val":
81 | csvfile_name = self.val_csv
82 | elif split == "eval":
83 | eval_filename = os.path.join(self.output_directory, 'eval.txt')
84 | self.save_single_txt(eval_filename, avg, epoch)
85 | return avg
86 | elif "test" in split:
87 | return avg
88 | else:
89 | raise ValueError("wrong split provided to logger")
90 | with open(csvfile_name, 'a') as csvfile:
91 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
92 | writer.writerow({
93 | 'epoch': epoch,
94 | 'rmse': avg.rmse,
95 | 'photo': avg.photometric,
96 | 'mae': avg.mae,
97 | 'irmse': avg.irmse,
98 | 'imae': avg.imae,
99 | 'mse': avg.mse,
100 | 'silog': avg.silog,
101 | 'squared_rel': avg.squared_rel,
102 | 'absrel': avg.absrel,
103 | 'lg10': avg.lg10,
104 | 'delta1': avg.delta1,
105 | 'delta2': avg.delta2,
106 | 'delta3': avg.delta3,
107 | 'gpu_time': avg.gpu_time,
108 | 'data_time': avg.data_time
109 | })
110 | return avg
111 |
112 | def save_single_txt(self, filename, result, epoch):
113 | with open(filename, 'w') as txtfile:
114 | txtfile.write(
115 | ("rank_metric={}\n" + "epoch={}\n" + "rmse={:.3f}\n" +
116 | "mae={:.3f}\n" + "silog={:.3f}\n" + "squared_rel={:.3f}\n" +
117 | "irmse={:.3f}\n" + "imae={:.3f}\n" + "mse={:.3f}\n" +
118 | "absrel={:.3f}\n" + "lg10={:.3f}\n" + "delta1={:.3f}\n" +
119 | "t_gpu={:.4f}").format(self.args.rank_metric, epoch,
120 | result.rmse, result.mae, result.silog,
121 | result.squared_rel, result.irmse,
122 | result.imae, result.mse, result.absrel,
123 | result.lg10, result.delta1,
124 | result.gpu_time))
125 |
126 | def save_best_txt(self, result, epoch):
127 | self.save_single_txt(self.best_txt, result, epoch)
128 |
129 | def _get_img_comparison_name(self, mode, epoch, is_best=False):
130 | if mode == 'eval':
131 | return self.output_directory + '/comparison_eval.png'
132 | if mode == 'val':
133 | if is_best:
134 | return self.output_directory + '/comparison_best.png'
135 | else:
136 | return self.output_directory + '/comparison_' + str(
137 | epoch) + '.png'
138 |
139 | def conditional_save_img_comparison(self, mode, i, ele, pred, epoch):
140 | # save 8 images for visualization
141 | if mode == 'val' or mode == 'eval':
142 | skip = 100
143 | if i == 0:
144 | self.img_merge = vis_utils.merge_into_row(ele, pred)
145 | elif i % skip == 0 and i < 8 * skip:
146 | row = vis_utils.merge_into_row(ele, pred)
147 | self.img_merge = vis_utils.add_row(self.img_merge, row)
148 | elif i == 8 * skip:
149 | filename = self._get_img_comparison_name(mode, epoch)
150 | vis_utils.save_image(self.img_merge, filename)
151 |
152 | def save_img_comparison_as_best(self, mode, epoch):
153 | if mode == 'val':
154 | filename = self._get_img_comparison_name(mode, epoch, is_best=True)
155 | vis_utils.save_image(self.img_merge, filename)
156 |
157 | def get_ranking_error(self, result):
158 | return getattr(result, self.args.rank_metric)
159 |
160 | def rank_conditional_save_best(self, mode, result, epoch):
161 | error = self.get_ranking_error(result)
162 | best_error = self.get_ranking_error(self.best_result)
163 | is_best = error < best_error
164 | if is_best and mode == "val":
165 | self.old_best_result = self.best_result
166 | self.best_result = result
167 | self.save_best_txt(result, epoch)
168 | return is_best
169 |
170 | def conditional_save_pred(self, mode, i, pred, epoch):
171 | if ("test" in mode or mode == "eval") and self.args.save_pred:
172 |
173 | # save images for visualization/ testing
174 | image_folder = os.path.join(self.output_directory,
175 | mode + "_output")
176 | if not os.path.exists(image_folder):
177 | os.makedirs(image_folder)
178 | img = torch.squeeze(pred.data.cpu()).numpy()
179 | filename = os.path.join(image_folder, '{0:010d}.png'.format(i))
180 | vis_utils.save_depth_as_uint16png(img, filename)
181 |
182 | def conditional_summarize(self, mode, avg, is_best):
183 | print("\n*\nSummary of ", mode, "round")
184 | print(''
185 | 'RMSE={average.rmse:.3f}\n'
186 | 'MAE={average.mae:.3f}\n'
187 | 'Photo={average.photometric:.3f}\n'
188 | 'iRMSE={average.irmse:.3f}\n'
189 | 'iMAE={average.imae:.3f}\n'
190 |           'squared_rel={average.squared_rel:.3f}\n'
191 |           'silog={average.silog:.3f}\n'
192 | 'Delta1={average.delta1:.3f}\n'
193 | 'REL={average.absrel:.3f}\n'
194 | 'Lg10={average.lg10:.3f}\n'
195 | 't_GPU={time:.3f}'.format(average=avg, time=avg.gpu_time))
196 | if is_best and mode == "val":
197 | print("New best model by %s (was %.3f)" %
198 | (self.args.rank_metric,
199 | self.get_ranking_error(self.old_best_result)))
200 | elif mode == "val":
201 | print("(best %s is %.3f)" %
202 | (self.args.rank_metric,
203 | self.get_ranking_error(self.best_result)))
204 | print("*\n")
205 |
206 |
207 | ignore_hidden = shutil.ignore_patterns(".", "..", ".git*", "*pycache*",
208 | "*build", "*.fuse*", "*_drive_*")
209 |
210 |
211 | def backup_source_code(backup_directory):
212 | if os.path.exists(backup_directory):
213 | shutil.rmtree(backup_directory)
214 | shutil.copytree('.', backup_directory, ignore=ignore_hidden)
215 |
216 |
217 | def adjust_learning_rate(lr_init, optimizer, epoch):
218 | """Sets the learning rate to the initial LR decayed by 10 every 5 epochs"""
219 | lr = lr_init * (0.1**(epoch // 5))
220 | for param_group in optimizer.param_groups:
221 | param_group['lr'] = lr
222 | return lr
223 |
224 |
225 | def save_checkpoint(state, is_best, epoch, output_directory):
226 | checkpoint_filename = os.path.join(output_directory,
227 | 'checkpoint-' + str(epoch) + '.pth.tar')
228 | torch.save(state, checkpoint_filename)
229 | if is_best:
230 | best_filename = os.path.join(output_directory, 'model_best.pth.tar')
231 | shutil.copyfile(checkpoint_filename, best_filename)
232 | if epoch > 0:
233 | prev_checkpoint_filename = os.path.join(
234 | output_directory, 'checkpoint-' + str(epoch - 1) + '.pth.tar')
235 | if os.path.exists(prev_checkpoint_filename):
236 | os.remove(prev_checkpoint_filename)
237 |
238 |
239 | def get_folder_name(args):
240 | current_time = time.strftime('%Y-%m-%d@%H-%M')
241 | if args.use_pose:
242 | prefix = "mode={}.w1={}.w2={}.".format(args.train_mode, args.w1,
243 | args.w2)
244 | else:
245 | prefix = "mode={}.".format(args.train_mode)
246 | return os.path.join(args.result,
247 | prefix + 'input={}.resnet{}.criterion={}.lr={}.bs={}.wd={}.pretrained={}.jitter={}.time={}'.
248 | format(args.input, args.layers, args.criterion, \
249 | args.lr, args.batch_size, args.weight_decay, \
250 | args.pretrained, args.jitter, current_time
251 | ))
252 |
253 |
254 | avgpool = torch.nn.AvgPool2d(kernel_size=2, stride=2).cuda()
255 |
256 |
257 | def multiscale(img):
258 | img1 = avgpool(img)
259 | img2 = avgpool(img1)
260 | img3 = avgpool(img2)
261 | img4 = avgpool(img3)
262 | img5 = avgpool(img4)
263 | return img5, img4, img3, img2, img1
264 |
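`multiscale` returns a five-level average-pooling pyramid, coarsest scale first; main.py relies on that ordering when it weights the photometric loss per scale. A minimal sketch, assuming the 352x1216 KITTI crop implied by the feature-map sizes noted in model.py:

```python
# Sketch only: shapes of the pyramid produced by helper.multiscale.
import torch
import helper

img = torch.rand(1, 3, 352, 1216)
if torch.cuda.is_available():
    img = img.cuda()  # match the device main.py would use
pyramid = helper.multiscale(img)
for k, level in enumerate(pyramid):
    print(k, tuple(level.shape))
# 0 (1, 3, 11, 38)    <- coarsest, 1/32 resolution
# ...
# 4 (1, 3, 176, 608)  <- finest returned, 1/2 resolution
```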
--------------------------------------------------------------------------------
/inverse_warp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | class Intrinsics:
6 | def __init__(self, width, height, fu, fv, cu=0, cv=0):
7 | self.height, self.width = height, width
8 | self.fu, self.fv = fu, fv # fu, fv: focal length along the horizontal and vertical axes
9 |
10 | # cu, cv: optical center along the horizontal and vertical axes
11 | self.cu = cu if cu > 0 else (width - 1) / 2.0
12 | self.cv = cv if cv > 0 else (height - 1) / 2.0
13 |
14 | # U, V represent the homogeneous horizontal and vertical coordinates in the pixel space
15 | self.U = torch.arange(start=0, end=width).expand(height, width).float()
16 | self.V = torch.arange(start=0, end=height).expand(width,
17 | height).t().float()
18 |
19 | # X_cam, Y_cam represent the homogeneous x, y coordinates (assuming depth z=1) in the camera coordinate system
20 | self.X_cam = (self.U - self.cu) / self.fu
21 | self.Y_cam = (self.V - self.cv) / self.fv
22 |
23 | self.is_cuda = False
24 |
25 | def cuda(self):
26 | self.X_cam.data = self.X_cam.data.cuda()
27 | self.Y_cam.data = self.Y_cam.data.cuda()
28 | self.is_cuda = True
29 | return self
30 |
31 | def scale(self, height, width):
32 | # return a new set of corresponding intrinsic parameters for the scaled image
33 | ratio_u = float(width) / self.width
34 | ratio_v = float(height) / self.height
35 | fu = ratio_u * self.fu
36 | fv = ratio_v * self.fv
37 | cu = ratio_u * self.cu
38 | cv = ratio_v * self.cv
39 | new_intrinsics = Intrinsics(width, height, fu, fv, cu, cv)
40 | if self.is_cuda:
41 | new_intrinsics.cuda()
42 | return new_intrinsics
43 |
44 | def __print__(self):
45 | print('size=({},{})\nfocal length=({},{})\noptical center=({},{})'.
46 | format(self.height, self.width, self.fv, self.fu, self.cv,
47 | self.cu))
48 |
49 |
50 | def image_to_pointcloud(depth, intrinsics):
51 | assert depth.dim() == 4
52 | assert depth.size(1) == 1
53 |
54 | X = depth * intrinsics.X_cam
55 | Y = depth * intrinsics.Y_cam
56 | return torch.cat((X, Y, depth), dim=1)
57 |
58 |
59 | def pointcloud_to_image(pointcloud, intrinsics):
60 | assert pointcloud.dim() == 4
61 |
62 | batch_size = pointcloud.size(0)
63 | X = pointcloud[:, 0, :, :] #.view(batch_size, -1)
64 | Y = pointcloud[:, 1, :, :] #.view(batch_size, -1)
65 | Z = pointcloud[:, 2, :, :].clamp(min=1e-3) #.view(batch_size, -1)
66 |
67 | # compute pixel coordinates
68 | U_proj = intrinsics.fu * X / Z + intrinsics.cu # horizontal pixel coordinate
69 | V_proj = intrinsics.fv * Y / Z + intrinsics.cv # vertical pixel coordinate
70 |
71 | # normalization to [-1, 1], required by torch.nn.functional.grid_sample
72 | U_proj_normalized = (2 * U_proj / (intrinsics.width - 1) - 1).view(
73 | batch_size, -1)
74 | V_proj_normalized = (2 * V_proj / (intrinsics.height - 1) - 1).view(
75 | batch_size, -1)
76 |
77 | # This was important since PyTorch didn't do as it claimed for points out of boundary
78 | # See https://github.com/ClementPinard/SfmLearner-Pytorch/blob/master/inverse_warp.py
79 | # Might not be necessary any more
80 | U_proj_mask = ((U_proj_normalized > 1) + (U_proj_normalized < -1)).detach()
81 | U_proj_normalized[U_proj_mask] = 2
82 | V_proj_mask = ((V_proj_normalized > 1) + (V_proj_normalized < -1)).detach()
83 | V_proj_normalized[V_proj_mask] = 2
84 |
85 | pixel_coords = torch.stack([U_proj_normalized, V_proj_normalized],
86 | dim=2) # [B, H*W, 2]
87 | return pixel_coords.view(batch_size, intrinsics.height, intrinsics.width,
88 | 2)
89 |
90 |
91 | def batch_multiply(batch_scalar, batch_matrix):
92 | # input: batch_scalar of size b, batch_matrix of size b * 3 * 3
93 | # output: batch_matrix of size b * 3 * 3
94 | batch_size = batch_scalar.size(0)
95 | output = batch_matrix.clone()
96 | for i in range(batch_size):
97 | output[i] = batch_scalar[i] * batch_matrix[i]
98 | return output
99 |
100 |
101 | def transform_curr_to_near(pointcloud_curr, r_mat, t_vec, intrinsics):
102 | # translation and rotmat represent the transformation from tgt pose to src pose
103 | batch_size = pointcloud_curr.size(0)
104 | XYZ_ = torch.bmm(r_mat, pointcloud_curr.view(batch_size, 3, -1))
105 |
106 | X = (XYZ_[:, 0, :] + t_vec[:, 0].unsqueeze(1)).view(
107 | -1, 1, intrinsics.height, intrinsics.width)
108 | Y = (XYZ_[:, 1, :] + t_vec[:, 1].unsqueeze(1)).view(
109 | -1, 1, intrinsics.height, intrinsics.width)
110 | Z = (XYZ_[:, 2, :] + t_vec[:, 2].unsqueeze(1)).view(
111 | -1, 1, intrinsics.height, intrinsics.width)
112 |
113 | pointcloud_near = torch.cat((X, Y, Z), dim=1)
114 |
115 | return pointcloud_near
116 |
117 |
118 | def homography_from(rgb_near, depth_curr, r_mat, t_vec, intrinsics):
119 | # inverse warp the RGB image from the nearby frame to the current frame
120 |
121 | # to ensure dimension consistency
122 | r_mat = r_mat.view(-1, 3, 3)
123 | t_vec = t_vec.view(-1, 3)
124 |
125 | # compute source pixel coordinate
126 | pointcloud_curr = image_to_pointcloud(depth_curr, intrinsics)
127 | pointcloud_near = transform_curr_to_near(pointcloud_curr, r_mat, t_vec,
128 | intrinsics)
129 | pixel_coords_near = pointcloud_to_image(pointcloud_near, intrinsics)
130 |
131 | # the warping
132 |     warped = F.grid_sample(rgb_near, pixel_coords_near, align_corners=True)
133 |
134 | return warped
135 |
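The warping is a plain pinhole model: a pixel (u, v) with depth Z back-projects to X = Z*(u - cu)/fu and Y = Z*(v - cv)/fv, and `pointcloud_to_image` re-projects with u = fu*X/Z + cu before normalizing to the [-1, 1] range expected by `grid_sample`. A minimal sketch with placeholder intrinsics (the real values come from the KITTI calibration file), showing that a constant-depth plane round-trips to the identity sampling grid:

```python
# Sketch only: back-project a constant-depth plane and re-project it.
# The intrinsics below are placeholders, not the KITTI calibration.
import torch
from inverse_warp import Intrinsics, image_to_pointcloud, pointcloud_to_image

intr = Intrinsics(width=1216, height=352, fu=721.5, fv=721.5, cu=607.5, cv=175.5)

depth = torch.full((1, 1, 352, 1216), 10.0)   # 10 m everywhere
cloud = image_to_pointcloud(depth, intr)      # (1, 3, H, W) camera-frame points
grid = pointcloud_to_image(cloud, intr)       # (1, H, W, 2), normalized to [-1, 1]
print(grid.shape)
print(grid[0, 0, 0], grid[0, -1, -1])         # ~(-1, -1) and ~(1, 1): identity warp
```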
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 |
5 | import torch
6 | import torch.nn.parallel
7 | import torch.optim
8 | import torch.utils.data
9 |
10 | from dataloaders.kitti_loader import load_calib, oheight, owidth, input_options, KittiDepth
11 | from model import DepthCompletionNet
12 | from metrics import AverageMeter, Result
13 | import criteria
14 | import helper
15 | from inverse_warp import Intrinsics, homography_from
16 |
17 | parser = argparse.ArgumentParser(description='Sparse-to-Dense')
18 | parser.add_argument('-w',
19 | '--workers',
20 | default=4,
21 | type=int,
22 | metavar='N',
23 | help='number of data loading workers (default: 4)')
24 | parser.add_argument('--epochs',
25 | default=11,
26 | type=int,
27 | metavar='N',
28 | help='number of total epochs to run (default: 11)')
29 | parser.add_argument('--start-epoch',
30 | default=0,
31 | type=int,
32 | metavar='N',
33 | help='manual epoch number (useful on restarts)')
34 | parser.add_argument('-c',
35 | '--criterion',
36 | metavar='LOSS',
37 | default='l2',
38 | choices=criteria.loss_names,
39 |                     help='loss function: ' + ' | '.join(criteria.loss_names) +
40 | ' (default: l2)')
41 | parser.add_argument('-b',
42 | '--batch-size',
43 | default=1,
44 | type=int,
45 | help='mini-batch size (default: 1)')
46 | parser.add_argument('--lr',
47 | '--learning-rate',
48 | default=1e-5,
49 | type=float,
50 | metavar='LR',
51 | help='initial learning rate (default 1e-5)')
52 | parser.add_argument('--weight-decay',
53 | '--wd',
54 | default=0,
55 | type=float,
56 | metavar='W',
57 | help='weight decay (default: 0)')
58 | parser.add_argument('--print-freq',
59 | '-p',
60 | default=10,
61 | type=int,
62 | metavar='N',
63 | help='print frequency (default: 10)')
64 | parser.add_argument('--resume',
65 | default='',
66 | type=str,
67 | metavar='PATH',
68 | help='path to latest checkpoint (default: none)')
69 | parser.add_argument('--data-folder',
70 | default='../data',
71 | type=str,
72 | metavar='PATH',
73 | help='data folder (default: none)')
74 | parser.add_argument('-i',
75 | '--input',
76 | type=str,
77 | default='gd',
78 | choices=input_options,
79 |                     help='input: ' + ' | '.join(input_options))
80 | parser.add_argument('-l',
81 | '--layers',
82 | type=int,
83 | default=34,
84 |                     help='number of resnet layers: 18, 34, 50, 101, or 152 (default: 34)')
85 | parser.add_argument('--pretrained',
86 | action="store_true",
87 | help='use ImageNet pre-trained weights')
88 | parser.add_argument('--val',
89 | type=str,
90 | default="select",
91 | choices=["select", "full"],
92 | help='full or select validation set')
93 | parser.add_argument('--jitter',
94 | type=float,
95 | default=0.1,
96 | help='color jitter for images')
97 | parser.add_argument(
98 | '--rank-metric',
99 | type=str,
100 | default='rmse',
101 | choices=[m for m in dir(Result()) if not m.startswith('_')],
102 |     help='metric by which the best result is selected')
103 | parser.add_argument(
104 | '-m',
105 | '--train-mode',
106 | type=str,
107 | default="dense",
108 | choices=["dense", "sparse", "photo", "sparse+photo", "dense+photo"],
109 | help='dense | sparse | photo | sparse+photo | dense+photo')
110 | parser.add_argument('-e', '--evaluate', default='', type=str, metavar='PATH', help='path to checkpoint to evaluate')
111 | parser.add_argument('--cpu', action="store_true", help='run on cpu')
112 |
113 | args = parser.parse_args()
114 | args.use_pose = ("photo" in args.train_mode)
115 | # args.pretrained = not args.no_pretrained
116 | args.result = os.path.join('..', 'results')
117 | args.use_rgb = ('rgb' in args.input) or args.use_pose
118 | args.use_d = 'd' in args.input
119 | args.use_g = 'g' in args.input
120 | if args.use_pose:
121 | args.w1, args.w2 = 0.1, 0.1
122 | else:
123 | args.w1, args.w2 = 0, 0
124 | print(args)
125 |
126 | cuda = torch.cuda.is_available() and not args.cpu
127 | if cuda:
128 | import torch.backends.cudnn as cudnn
129 | cudnn.benchmark = True
130 | device = torch.device("cuda")
131 | else:
132 | device = torch.device("cpu")
133 | print("=> using '{}' for computation.".format(device))
134 |
135 | # define loss functions
136 | depth_criterion = criteria.MaskedMSELoss() if (
137 | args.criterion == 'l2') else criteria.MaskedL1Loss()
138 | photometric_criterion = criteria.PhotometricLoss()
139 | smoothness_criterion = criteria.SmoothnessLoss()
140 |
141 | if args.use_pose:
142 | # hard-coded KITTI camera intrinsics
143 | K = load_calib()
144 | fu, fv = float(K[0, 0]), float(K[1, 1])
145 | cu, cv = float(K[0, 2]), float(K[1, 2])
146 | kitti_intrinsics = Intrinsics(owidth, oheight, fu, fv, cu, cv)
147 | if cuda:
148 | kitti_intrinsics = kitti_intrinsics.cuda()
149 |
150 |
151 | def iterate(mode, args, loader, model, optimizer, logger, epoch):
152 | block_average_meter = AverageMeter()
153 | average_meter = AverageMeter()
154 | meters = [block_average_meter, average_meter]
155 |
156 | # switch to appropriate mode
157 | assert mode in ["train", "val", "eval", "test_prediction", "test_completion"], \
158 | "unsupported mode: {}".format(mode)
159 | if mode == 'train':
160 | model.train()
161 | lr = helper.adjust_learning_rate(args.lr, optimizer, epoch)
162 | else:
163 | model.eval()
164 | lr = 0
165 |
166 | for i, batch_data in enumerate(loader):
167 | start = time.time()
168 | batch_data = {
169 | key: val.to(device)
170 | for key, val in batch_data.items() if val is not None
171 | }
172 | gt = batch_data[
173 | 'gt'] if mode != 'test_prediction' and mode != 'test_completion' else None
174 | data_time = time.time() - start
175 |
176 | start = time.time()
177 | pred = model(batch_data)
178 | depth_loss, photometric_loss, smooth_loss, mask = 0, 0, 0, None
179 | if mode == 'train':
180 | # Loss 1: the direct depth supervision from ground truth label
181 |             # mask=1 indicates that a pixel does not have a ground truth label
182 | if 'sparse' in args.train_mode:
183 | depth_loss = depth_criterion(pred, batch_data['d'])
184 | mask = (batch_data['d'] < 1e-3).float()
185 | elif 'dense' in args.train_mode:
186 | depth_loss = depth_criterion(pred, gt)
187 | mask = (gt < 1e-3).float()
188 |
189 | # Loss 2: the self-supervised photometric loss
190 | if args.use_pose:
191 | # create multi-scale pyramids
192 | pred_array = helper.multiscale(pred)
193 | rgb_curr_array = helper.multiscale(batch_data['rgb'])
194 | rgb_near_array = helper.multiscale(batch_data['rgb_near'])
195 | if mask is not None:
196 | mask_array = helper.multiscale(mask)
197 | num_scales = len(pred_array)
198 |
199 | # compute photometric loss at multiple scales
200 | for scale in range(len(pred_array)):
201 | pred_ = pred_array[scale]
202 | rgb_curr_ = rgb_curr_array[scale]
203 | rgb_near_ = rgb_near_array[scale]
204 | mask_ = None
205 | if mask is not None:
206 | mask_ = mask_array[scale]
207 |
208 | # compute the corresponding intrinsic parameters
209 | height_, width_ = pred_.size(2), pred_.size(3)
210 | intrinsics_ = kitti_intrinsics.scale(height_, width_)
211 |
212 | # inverse warp from a nearby frame to the current frame
213 | warped_ = homography_from(rgb_near_, pred_,
214 | batch_data['r_mat'],
215 | batch_data['t_vec'], intrinsics_)
216 | photometric_loss += photometric_criterion(
217 | rgb_curr_, warped_, mask_) * (2**(scale - num_scales))
218 |
219 | # Loss 3: the depth smoothness loss
220 | smooth_loss = smoothness_criterion(pred) if args.w2 > 0 else 0
221 |
222 | # backprop
223 | loss = depth_loss + args.w1 * photometric_loss + args.w2 * smooth_loss
224 | optimizer.zero_grad()
225 | loss.backward()
226 | optimizer.step()
227 |
228 | gpu_time = time.time() - start
229 |
230 | # measure accuracy and record loss
231 | with torch.no_grad():
232 | mini_batch_size = next(iter(batch_data.values())).size(0)
233 | result = Result()
234 | if mode != 'test_prediction' and mode != 'test_completion':
235 | result.evaluate(pred.data, gt.data, photometric_loss)
236 | [
237 | m.update(result, gpu_time, data_time, mini_batch_size)
238 | for m in meters
239 | ]
240 | logger.conditional_print(mode, i, epoch, lr, len(loader),
241 | block_average_meter, average_meter)
242 | logger.conditional_save_img_comparison(mode, i, batch_data, pred,
243 | epoch)
244 | logger.conditional_save_pred(mode, i, pred, epoch)
245 |
246 | avg = logger.conditional_save_info(mode, average_meter, epoch)
247 | is_best = logger.rank_conditional_save_best(mode, avg, epoch)
248 | if is_best and not (mode == "train"):
249 | logger.save_img_comparison_as_best(mode, epoch)
250 | logger.conditional_summarize(mode, avg, is_best)
251 |
252 | return avg, is_best
253 |
254 |
255 | def main():
256 | global args
257 | checkpoint = None
258 | is_eval = False
259 | if args.evaluate:
260 | args_new = args
261 | if os.path.isfile(args.evaluate):
262 | print("=> loading checkpoint '{}' ... ".format(args.evaluate),
263 | end='')
264 | checkpoint = torch.load(args.evaluate, map_location=device)
265 | args = checkpoint['args']
266 | args.data_folder = args_new.data_folder
267 | args.val = args_new.val
268 | is_eval = True
269 | print("Completed.")
270 | else:
271 | print("No model found at '{}'".format(args.evaluate))
272 | return
273 | elif args.resume: # optionally resume from a checkpoint
274 | args_new = args
275 | if os.path.isfile(args.resume):
276 | print("=> loading checkpoint '{}' ... ".format(args.resume),
277 | end='')
278 | checkpoint = torch.load(args.resume, map_location=device)
279 | args.start_epoch = checkpoint['epoch'] + 1
280 | args.data_folder = args_new.data_folder
281 | args.val = args_new.val
282 | print("Completed. Resuming from epoch {}.".format(
283 | checkpoint['epoch']))
284 | else:
285 | print("No checkpoint found at '{}'".format(args.resume))
286 | return
287 |
288 | print("=> creating model and optimizer ... ", end='')
289 | model = DepthCompletionNet(args).to(device)
290 | model_named_params = [
291 | p for _, p in model.named_parameters() if p.requires_grad
292 | ]
293 | optimizer = torch.optim.Adam(model_named_params,
294 | lr=args.lr,
295 | weight_decay=args.weight_decay)
296 | print("completed.")
297 | if checkpoint is not None:
298 | model.load_state_dict(checkpoint['model'])
299 | optimizer.load_state_dict(checkpoint['optimizer'])
300 | print("=> checkpoint state loaded.")
301 |
302 | model = torch.nn.DataParallel(model)
303 |
304 | # Data loading code
305 | print("=> creating data loaders ... ")
306 | if not is_eval:
307 | train_dataset = KittiDepth('train', args)
308 | train_loader = torch.utils.data.DataLoader(train_dataset,
309 | batch_size=args.batch_size,
310 | shuffle=True,
311 | num_workers=args.workers,
312 | pin_memory=True,
313 | sampler=None)
314 | print("\t==> train_loader size:{}".format(len(train_loader)))
315 | val_dataset = KittiDepth('val', args)
316 | val_loader = torch.utils.data.DataLoader(
317 | val_dataset,
318 | batch_size=1,
319 | shuffle=False,
320 | num_workers=2,
321 | pin_memory=True) # set batch size to be 1 for validation
322 | print("\t==> val_loader size:{}".format(len(val_loader)))
323 |
324 | # create backups and results folder
325 | logger = helper.logger(args)
326 | if checkpoint is not None:
327 | logger.best_result = checkpoint['best_result']
328 | print("=> logger created.")
329 |
330 | if is_eval:
331 | print("=> starting model evaluation ...")
332 | result, is_best = iterate("val", args, val_loader, model, None, logger,
333 | checkpoint['epoch'])
334 | return
335 |
336 | # main loop
337 | print("=> starting main loop ...")
338 | for epoch in range(args.start_epoch, args.epochs):
339 | print("=> starting training epoch {} ..".format(epoch))
340 | iterate("train", args, train_loader, model, optimizer, logger,
341 | epoch) # train for one epoch
342 | result, is_best = iterate("val", args, val_loader, model, None, logger,
343 | epoch) # evaluate on validation set
344 | helper.save_checkpoint({ # save checkpoint
345 | 'epoch': epoch,
346 | 'model': model.module.state_dict(),
347 | 'best_result': logger.best_result,
348 | 'optimizer' : optimizer.state_dict(),
349 | 'args' : args,
350 | }, is_best, epoch, logger.output_directory)
351 |
352 |
353 | if __name__ == '__main__':
354 | main()
355 |
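When `--train-mode` contains "photo", the total loss assembled in `iterate` is `depth_loss + w1 * photometric_loss + w2 * smooth_loss` with `w1 = w2 = 0.1`, and the photometric term is itself a weighted sum over the five pyramid scales. A minimal sketch of just that per-scale weighting (plain numbers, no tensors):

```python
# Sketch only: the 2**(scale - num_scales) weights applied to the photometric
# loss in iterate(); scale 0 is the coarsest level of the pyramid.
num_scales = 5
weights = [2 ** (scale - num_scales) for scale in range(num_scales)]
print(weights)        # [0.03125, 0.0625, 0.125, 0.25, 0.5]
print(sum(weights))   # 0.96875 -- finer scales dominate the photometric term
```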
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import numpy as np
4 |
5 | lg_e_10 = math.log(10)
6 |
7 |
8 | def log10(x):
9 | """Convert a new tensor with the base-10 logarithm of the elements of x. """
10 | return torch.log(x) / lg_e_10
11 |
12 |
13 | class Result(object):
14 | def __init__(self):
15 | self.irmse = 0
16 | self.imae = 0
17 | self.mse = 0
18 | self.rmse = 0
19 | self.mae = 0
20 | self.absrel = 0
21 | self.squared_rel = 0
22 | self.lg10 = 0
23 | self.delta1 = 0
24 | self.delta2 = 0
25 | self.delta3 = 0
26 | self.data_time = 0
27 | self.gpu_time = 0
28 | self.silog = 0 # Scale invariant logarithmic error [log(m)*100]
29 | self.photometric = 0
30 |
31 | def set_to_worst(self):
32 | self.irmse = np.inf
33 | self.imae = np.inf
34 | self.mse = np.inf
35 | self.rmse = np.inf
36 | self.mae = np.inf
37 | self.absrel = np.inf
38 | self.squared_rel = np.inf
39 | self.lg10 = np.inf
40 | self.silog = np.inf
41 | self.delta1 = 0
42 | self.delta2 = 0
43 | self.delta3 = 0
44 | self.data_time = 0
45 | self.gpu_time = 0
46 |
47 | def update(self, irmse, imae, mse, rmse, mae, absrel, squared_rel, lg10, \
48 | delta1, delta2, delta3, gpu_time, data_time, silog, photometric=0):
49 | self.irmse = irmse
50 | self.imae = imae
51 | self.mse = mse
52 | self.rmse = rmse
53 | self.mae = mae
54 | self.absrel = absrel
55 | self.squared_rel = squared_rel
56 | self.lg10 = lg10
57 | self.delta1 = delta1
58 | self.delta2 = delta2
59 | self.delta3 = delta3
60 | self.data_time = data_time
61 | self.gpu_time = gpu_time
62 | self.silog = silog
63 | self.photometric = photometric
64 |
65 | def evaluate(self, output, target, photometric=0):
66 | valid_mask = target > 0.1
67 |
68 | # convert from meters to mm
69 | output_mm = 1e3 * output[valid_mask]
70 | target_mm = 1e3 * target[valid_mask]
71 |
72 | abs_diff = (output_mm - target_mm).abs()
73 |
74 | self.mse = float((torch.pow(abs_diff, 2)).mean())
75 | self.rmse = math.sqrt(self.mse)
76 | self.mae = float(abs_diff.mean())
77 | self.lg10 = float((log10(output_mm) - log10(target_mm)).abs().mean())
78 | self.absrel = float((abs_diff / target_mm).mean())
79 | self.squared_rel = float(((abs_diff / target_mm)**2).mean())
80 |
81 | maxRatio = torch.max(output_mm / target_mm, target_mm / output_mm)
82 | self.delta1 = float((maxRatio < 1.25).float().mean())
83 | self.delta2 = float((maxRatio < 1.25**2).float().mean())
84 | self.delta3 = float((maxRatio < 1.25**3).float().mean())
85 | self.data_time = 0
86 | self.gpu_time = 0
87 |
88 | # silog uses meters
89 | err_log = torch.log(target[valid_mask]) - torch.log(output[valid_mask])
90 | normalized_squared_log = (err_log**2).mean()
91 | log_mean = err_log.mean()
92 | self.silog = math.sqrt(normalized_squared_log -
93 | log_mean * log_mean) * 100
94 |
95 | # convert from meters to km
96 | inv_output_km = (1e-3 * output[valid_mask])**(-1)
97 | inv_target_km = (1e-3 * target[valid_mask])**(-1)
98 | abs_inv_diff = (inv_output_km - inv_target_km).abs()
99 | self.irmse = math.sqrt((torch.pow(abs_inv_diff, 2)).mean())
100 | self.imae = float(abs_inv_diff.mean())
101 |
102 | self.photometric = float(photometric)
103 |
104 |
105 | class AverageMeter(object):
106 | def __init__(self):
107 | self.reset()
108 |
109 | def reset(self):
110 | self.count = 0.0
111 | self.sum_irmse = 0
112 | self.sum_imae = 0
113 | self.sum_mse = 0
114 | self.sum_rmse = 0
115 | self.sum_mae = 0
116 | self.sum_absrel = 0
117 | self.sum_squared_rel = 0
118 | self.sum_lg10 = 0
119 | self.sum_delta1 = 0
120 | self.sum_delta2 = 0
121 | self.sum_delta3 = 0
122 | self.sum_data_time = 0
123 | self.sum_gpu_time = 0
124 | self.sum_photometric = 0
125 | self.sum_silog = 0
126 |
127 | def update(self, result, gpu_time, data_time, n=1):
128 | self.count += n
129 | self.sum_irmse += n * result.irmse
130 | self.sum_imae += n * result.imae
131 | self.sum_mse += n * result.mse
132 | self.sum_rmse += n * result.rmse
133 | self.sum_mae += n * result.mae
134 | self.sum_absrel += n * result.absrel
135 | self.sum_squared_rel += n * result.squared_rel
136 | self.sum_lg10 += n * result.lg10
137 | self.sum_delta1 += n * result.delta1
138 | self.sum_delta2 += n * result.delta2
139 | self.sum_delta3 += n * result.delta3
140 | self.sum_data_time += n * data_time
141 | self.sum_gpu_time += n * gpu_time
142 | self.sum_silog += n * result.silog
143 | self.sum_photometric += n * result.photometric
144 |
145 | def average(self):
146 | avg = Result()
147 | if self.count > 0:
148 | avg.update(
149 | self.sum_irmse / self.count, self.sum_imae / self.count,
150 | self.sum_mse / self.count, self.sum_rmse / self.count,
151 | self.sum_mae / self.count, self.sum_absrel / self.count,
152 | self.sum_squared_rel / self.count, self.sum_lg10 / self.count,
153 | self.sum_delta1 / self.count, self.sum_delta2 / self.count,
154 | self.sum_delta3 / self.count, self.sum_gpu_time / self.count,
155 | self.sum_data_time / self.count, self.sum_silog / self.count,
156 | self.sum_photometric / self.count)
157 | return avg
158 |
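`Result.evaluate` expects depths in meters: it converts to millimeters for RMSE/MAE/REL, to inverse kilometers for iRMSE/iMAE, and masks out pixels whose ground truth is below 0.1 m. A minimal sketch on a two-pixel toy example, just to make the unit conventions concrete:

```python
# Sketch only: toy depths in meters, shaped 1x1x1x2 like the dataloader's tensors.
import torch
from metrics import Result

pred = torch.tensor([[[[2.0, 4.0]]]])
gt   = torch.tensor([[[[2.0, 5.0]]]])

r = Result()
r.evaluate(pred, gt)
print(round(r.rmse, 1))   # 707.1 mm = sqrt((0**2 + 1000**2) / 2)
print(round(r.mae, 1))    # 500.0 mm
print(r.delta1)           # 0.5: only the first pixel is within a factor of 1.25
```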
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torchvision.models import resnet
6 |
7 |
8 | def init_weights(m):
9 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
10 | m.weight.data.normal_(0, 1e-3)
11 | if m.bias is not None:
12 | m.bias.data.zero_()
13 | elif isinstance(m, nn.ConvTranspose2d):
14 | m.weight.data.normal_(0, 1e-3)
15 | if m.bias is not None:
16 | m.bias.data.zero_()
17 | elif isinstance(m, nn.BatchNorm2d):
18 | m.weight.data.fill_(1)
19 | m.bias.data.zero_()
20 |
21 | def conv_bn_relu(in_channels, out_channels, kernel_size, \
22 | stride=1, padding=0, bn=True, relu=True):
23 | bias = not bn
24 | layers = []
25 | layers.append(
26 | nn.Conv2d(in_channels,
27 | out_channels,
28 | kernel_size,
29 | stride,
30 | padding,
31 | bias=bias))
32 | if bn:
33 | layers.append(nn.BatchNorm2d(out_channels))
34 | if relu:
35 | layers.append(nn.LeakyReLU(0.2, inplace=True))
36 | layers = nn.Sequential(*layers)
37 |
38 | # initialize the weights
39 | for m in layers.modules():
40 | init_weights(m)
41 |
42 | return layers
43 |
44 | def convt_bn_relu(in_channels, out_channels, kernel_size, \
45 | stride=1, padding=0, output_padding=0, bn=True, relu=True):
46 | bias = not bn
47 | layers = []
48 | layers.append(
49 | nn.ConvTranspose2d(in_channels,
50 | out_channels,
51 | kernel_size,
52 | stride,
53 | padding,
54 | output_padding,
55 | bias=bias))
56 | if bn:
57 | layers.append(nn.BatchNorm2d(out_channels))
58 | if relu:
59 | layers.append(nn.LeakyReLU(0.2, inplace=True))
60 | layers = nn.Sequential(*layers)
61 |
62 | # initialize the weights
63 | for m in layers.modules():
64 | init_weights(m)
65 |
66 | return layers
67 |
68 |
69 | class DepthCompletionNet(nn.Module):
70 | def __init__(self, args):
71 | assert (
72 | args.layers in [18, 34, 50, 101, 152]
73 | ), 'Only layers 18, 34, 50, 101, and 152 are defined, but got {}'.format(
74 |             args.layers)
75 | super(DepthCompletionNet, self).__init__()
76 | self.modality = args.input
77 |
78 | if 'd' in self.modality:
79 | channels = 64 // len(self.modality)
80 | self.conv1_d = conv_bn_relu(1,
81 | channels,
82 | kernel_size=3,
83 | stride=1,
84 | padding=1)
85 | if 'rgb' in self.modality:
86 | channels = 64 * 3 // len(self.modality)
87 | self.conv1_img = conv_bn_relu(3,
88 | channels,
89 | kernel_size=3,
90 | stride=1,
91 | padding=1)
92 | elif 'g' in self.modality:
93 | channels = 64 // len(self.modality)
94 | self.conv1_img = conv_bn_relu(1,
95 | channels,
96 | kernel_size=3,
97 | stride=1,
98 | padding=1)
99 |
100 | pretrained_model = resnet.__dict__['resnet{}'.format(
101 | args.layers)](pretrained=args.pretrained)
102 | if not args.pretrained:
103 | pretrained_model.apply(init_weights)
104 | #self.maxpool = pretrained_model._modules['maxpool']
105 | self.conv2 = pretrained_model._modules['layer1']
106 | self.conv3 = pretrained_model._modules['layer2']
107 | self.conv4 = pretrained_model._modules['layer3']
108 | self.conv5 = pretrained_model._modules['layer4']
109 | del pretrained_model # clear memory
110 |
111 | # define number of intermediate channels
112 | if args.layers <= 34:
113 | num_channels = 512
114 | elif args.layers >= 50:
115 | num_channels = 2048
116 | self.conv6 = conv_bn_relu(num_channels,
117 | 512,
118 | kernel_size=3,
119 | stride=2,
120 | padding=1)
121 |
122 | # decoding layers
123 | kernel_size = 3
124 | stride = 2
125 | self.convt5 = convt_bn_relu(in_channels=512,
126 | out_channels=256,
127 | kernel_size=kernel_size,
128 | stride=stride,
129 | padding=1,
130 | output_padding=1)
131 | self.convt4 = convt_bn_relu(in_channels=768,
132 | out_channels=128,
133 | kernel_size=kernel_size,
134 | stride=stride,
135 | padding=1,
136 | output_padding=1)
137 | self.convt3 = convt_bn_relu(in_channels=(256 + 128),
138 | out_channels=64,
139 | kernel_size=kernel_size,
140 | stride=stride,
141 | padding=1,
142 | output_padding=1)
143 | self.convt2 = convt_bn_relu(in_channels=(128 + 64),
144 | out_channels=64,
145 | kernel_size=kernel_size,
146 | stride=stride,
147 | padding=1,
148 | output_padding=1)
149 | self.convt1 = convt_bn_relu(in_channels=128,
150 | out_channels=64,
151 | kernel_size=kernel_size,
152 | stride=1,
153 | padding=1)
154 | self.convtf = conv_bn_relu(in_channels=128,
155 | out_channels=1,
156 | kernel_size=1,
157 | stride=1,
158 | bn=False,
159 | relu=False)
160 |
161 | def forward(self, x):
162 | # first layer
163 | if 'd' in self.modality:
164 | conv1_d = self.conv1_d(x['d'])
165 | if 'rgb' in self.modality:
166 | conv1_img = self.conv1_img(x['rgb'])
167 | elif 'g' in self.modality:
168 | conv1_img = self.conv1_img(x['g'])
169 |
170 | if self.modality == 'rgbd' or self.modality == 'gd':
171 | conv1 = torch.cat((conv1_d, conv1_img), 1)
172 | else:
173 | conv1 = conv1_d if (self.modality == 'd') else conv1_img
174 |
175 | conv2 = self.conv2(conv1)
176 | conv3 = self.conv3(conv2) # batchsize * ? * 176 * 608
177 | conv4 = self.conv4(conv3) # batchsize * ? * 88 * 304
178 | conv5 = self.conv5(conv4) # batchsize * ? * 44 * 152
179 | conv6 = self.conv6(conv5) # batchsize * ? * 22 * 76
180 |
181 | # decoder
182 | convt5 = self.convt5(conv6)
183 | y = torch.cat((convt5, conv5), 1)
184 |
185 | convt4 = self.convt4(y)
186 | y = torch.cat((convt4, conv4), 1)
187 |
188 | convt3 = self.convt3(y)
189 | y = torch.cat((convt3, conv3), 1)
190 |
191 | convt2 = self.convt2(y)
192 | y = torch.cat((convt2, conv2), 1)
193 |
194 | convt1 = self.convt1(y)
195 | y = torch.cat((convt1, conv1), 1)
196 |
197 | y = self.convtf(y)
198 |
199 | if self.training:
200 | return 100 * y
201 | else:
202 | min_distance = 0.9
203 | return F.relu(
204 | 100 * y - min_distance
205 | ) + min_distance # the minimum range of Velodyne is around 3 feet ~= 0.9m
206 |
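The encoder halves the resolution four times (layer2 through layer4, plus conv6) and the decoder restores it through the stride-2 transposed convolutions, so inputs whose height and width are multiples of 16 come back at full resolution. A minimal sketch of a forward pass with random `gd` input at the 352x1216 size referenced in the comments above; `argparse.Namespace` stands in for the parsed command-line arguments, and torchvision is assumed to still accept the `pretrained` flag:

```python
# Sketch only: construct the network and run one forward pass on random input.
import argparse
import torch
from model import DepthCompletionNet

args = argparse.Namespace(layers=18, input='gd', pretrained=False)
net = DepthCompletionNet(args).eval()

batch = {
    'd': torch.rand(1, 1, 352, 1216),   # sparse depth
    'g': torch.rand(1, 1, 352, 1216),   # grayscale image
}
with torch.no_grad():
    pred = net(batch)
print(pred.shape)          # torch.Size([1, 1, 352, 1216])
print(float(pred.min()))   # >= 0.9 in eval mode (minimum-range clamp above)
```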
--------------------------------------------------------------------------------
/vis_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | if not ("DISPLAY" in os.environ):
3 | import matplotlib as mpl
4 | mpl.use('Agg')
5 | import matplotlib.pyplot as plt
6 | from PIL import Image
7 | import numpy as np
8 | import cv2
9 |
10 | cmap = plt.cm.jet
11 |
12 |
13 | def depth_colorize(depth):
14 | depth = (depth - np.min(depth)) / (np.max(depth) - np.min(depth))
15 | depth = 255 * cmap(depth)[:, :, :3] # H, W, C
16 | return depth.astype('uint8')
17 |
18 |
19 | def merge_into_row(ele, pred):
20 | def preprocess_depth(x):
21 | y = np.squeeze(x.data.cpu().numpy())
22 | return depth_colorize(y)
23 |
24 |     # if the input is grayscale, convert it to RGB for visualization
25 | img_list = []
26 | if 'rgb' in ele:
27 | rgb = np.squeeze(ele['rgb'][0, ...].data.cpu().numpy())
28 | rgb = np.transpose(rgb, (1, 2, 0))
29 | img_list.append(rgb)
30 | elif 'g' in ele:
31 | g = np.squeeze(ele['g'][0, ...].data.cpu().numpy())
32 | g = np.array(Image.fromarray(g).convert('RGB'))
33 | img_list.append(g)
34 | if 'd' in ele:
35 | img_list.append(preprocess_depth(ele['d'][0, ...]))
36 | img_list.append(preprocess_depth(pred[0, ...]))
37 | if 'gt' in ele:
38 | img_list.append(preprocess_depth(ele['gt'][0, ...]))
39 |
40 | img_merge = np.hstack(img_list)
41 | return img_merge.astype('uint8')
42 |
43 |
44 | def add_row(img_merge, row):
45 | return np.vstack([img_merge, row])
46 |
47 |
48 | def save_image(img_merge, filename):
49 | image_to_write = cv2.cvtColor(img_merge, cv2.COLOR_RGB2BGR)
50 | cv2.imwrite(filename, image_to_write)
51 |
52 |
53 | def save_depth_as_uint16png(img, filename):
54 | img = (img * 256).astype('uint16')
55 | cv2.imwrite(filename, img)
56 |
57 |
58 | if ("DISPLAY" in os.environ):
59 | f, axarr = plt.subplots(4, 1)
60 | plt.tight_layout()
61 | plt.ion()
62 |
63 |
64 | def display_warping(rgb_tgt, pred_tgt, warped):
65 | def preprocess(rgb_tgt, pred_tgt, warped):
66 | rgb_tgt = 255 * np.transpose(np.squeeze(rgb_tgt.data.cpu().numpy()),
67 | (1, 2, 0)) # H, W, C
68 | # depth = np.squeeze(depth.cpu().numpy())
69 | # depth = depth_colorize(depth)
70 |
71 | # convert to log-scale
72 | pred_tgt = np.squeeze(pred_tgt.data.cpu().numpy())
73 | # pred_tgt[pred_tgt<=0] = 0.9 # remove negative predictions
74 | # pred_tgt = np.log10(pred_tgt)
75 |
76 | pred_tgt = depth_colorize(pred_tgt)
77 |
78 | warped = 255 * np.transpose(np.squeeze(warped.data.cpu().numpy()),
79 | (1, 2, 0)) # H, W, C
80 | recon_err = np.absolute(
81 | warped.astype('float') - rgb_tgt.astype('float')) * (warped > 0)
82 | recon_err = recon_err[:, :, 0] + recon_err[:, :, 1] + recon_err[:, :, 2]
83 | recon_err = depth_colorize(recon_err)
84 | return rgb_tgt.astype('uint8'), warped.astype(
85 | 'uint8'), recon_err, pred_tgt
86 |
87 | rgb_tgt, warped, recon_err, pred_tgt = preprocess(rgb_tgt, pred_tgt,
88 | warped)
89 |
90 | # 1st column
91 | column = 0
92 | axarr[0].imshow(rgb_tgt)
93 | axarr[0].axis('off')
94 | axarr[0].axis('equal')
95 | # axarr[0, column].set_title('rgb_tgt')
96 |
97 | axarr[1].imshow(warped)
98 | axarr[1].axis('off')
99 | axarr[1].axis('equal')
100 | # axarr[1, column].set_title('warped')
101 |
102 | axarr[2].imshow(recon_err, 'hot')
103 | axarr[2].axis('off')
104 | axarr[2].axis('equal')
105 | # axarr[2, column].set_title('recon_err error')
106 |
107 | axarr[3].imshow(pred_tgt, 'hot')
108 | axarr[3].axis('off')
109 | axarr[3].axis('equal')
110 | # axarr[3, column].set_title('pred_tgt')
111 |
112 | # plt.show()
113 | plt.pause(0.001)
114 |
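`save_depth_as_uint16png` follows the KITTI depth-completion convention: depth in meters is multiplied by 256 and stored as a 16-bit PNG, so a pixel value of 0 means "no depth". A minimal sketch of the round trip (the output path is just a placeholder):

```python
# Sketch only: write and re-read a depth map with the same x256 uint16 convention.
import numpy as np
import cv2

depth_m = np.array([[0.0, 5.0, 85.0]], dtype=np.float32)
png = (depth_m * 256).astype('uint16')          # what save_depth_as_uint16png writes
cv2.imwrite('/tmp/example_depth.png', png)      # placeholder path
back = cv2.imread('/tmp/example_depth.png', cv2.IMREAD_UNCHANGED).astype(np.float32) / 256.0
print(back)                                     # [[ 0.  5. 85.]]
```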
--------------------------------------------------------------------------------