├── .gitignore ├── LICENSE ├── README.md ├── custom_transforms.py ├── data ├── cityscapes_loader.py ├── kitti_raw_loader.py ├── prepare_train_data.py ├── static_frames.txt └── test_scenes.txt ├── datasets ├── sequence_folders.py ├── shifted_sequence_folders.py ├── stacked_sequence_folders.py └── validation_folders.py ├── depthnet_unravel_bn.py ├── inverse_warp.py ├── kitti_eval ├── depth_evaluation_utils.py ├── pose_evaluation_utils.py ├── test_files_eigen.txt └── test_files_eigen_filtered.txt ├── logger.py ├── loss_functions.py ├── models ├── DepthNet.py ├── DispNetS.py ├── PoseNet.py ├── UpSampleNet.py ├── __init__.py └── utils.py ├── requirements.txt ├── ssim.py ├── stillbox_eval ├── depth_evaluation_utils.py ├── test_files_80.txt └── test_files_90.txt ├── test_depth.py ├── test_pose.py ├── train_flexible_shifts.py ├── train_img_pairs.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__* 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Tinghui Zhou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unsupervised DepthNet 2 | 3 | This codebase implements the system described in the paper: 4 | 5 | Learning Structure From Motion *From Motion* 6 | 7 | [Clement Pinard](http://perso.ensta-paristech.fr/~pinard/), Laure Chevalley, [Antoine Manzanera](http://perso.ensta-paristech.fr/~manzaner/), [David Filliat](http://perso.ensta-paristech.fr/~filliat/eng/) 8 | 9 | [![youtube video](http://img.youtube.com/vi/ZDgWAWTwU7U/0.jpg)](https://www.youtube.com/watch?v=ZDgWAWTwU7U) 10 | 11 | In [GMDL](https://sites.google.com/site/deepgeometry2018/) Workshop @ [ECCV2018](https://eccv2018.org/). 12 | 13 | See the [project webpage](http://perso.ensta-paristech.fr/~pinard/unsupervised-depthnet/) for more details. 
14 | 
15 | If you use this repo in your work, please cite us with the following bibtex:
16 | 
17 | ```
18 | @inproceedings{pinard2018learning,
19 | title={Learning structure-from-motion from motion},
20 | author={Pinard, Cl{\'e}ment and Chevalley, Laure and Manzanera, Antoine and Filliat, David},
21 | booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
22 | year={2018}
23 | }
24 | ```
25 | 
26 | ## Preamble
27 | This codebase was developed and tested with PyTorch 0.4.1, CUDA 9.2 and Ubuntu 16.04.
28 | 
29 | ## Prerequisite
30 | 
31 | ```bash
32 | pip3 install -r requirements.txt
33 | ```
34 | 
35 | or manually install the following packages:
36 | 
37 | ```
38 | pytorch>=0.4.1
39 | scipy
40 | imageio
41 | argparse
42 | tensorboardX
43 | blessings
44 | progressbar2
45 | path.py
46 | ```
47 | 
48 | It is also advised to install the Python 3 OpenCV bindings for tensorboard visualizations.
49 | 
50 | ## Preparing training data
51 | For KITTI, preparation uses roughly the same command as in [SFM Learner](https://github.com/ClementPinard/SfmLearner-Pytorch). Note that here you can get the pose at the same time. While the translation data is not very precise (see the note [here](http://www.cvlibs.net/datasets/kitti/eval_odometry.php)), the rotation used for stabilization is quite accurate, which makes it a drone-like training environment.
52 | 
53 | For StillBox, everything is already set up. For this training, a new version with rotations has been developed, which you will be able to download soon. The rotation-less version can be found [here](http://academictorrents.com/details/4d3a60ad3c9ceac7662735ba8e90fb467b43a3aa) via a torrent link.
54 | 
55 | To get [KITTI](http://www.cvlibs.net/datasets/kitti/raw_data.php), first download the dataset using this [script](http://www.cvlibs.net/download.php?file=raw_data_downloader.zip) provided on the official website, and then run the following command. The `--with-pose` option will save pose matrices, used in particular to supervise rotation compensation. The `--with-depth` option will save resized copies of the depth ground truth for the validation set, to help you set hyperparameters.
56 | 
57 | ```bash
58 | python3 data/prepare_train_data.py /path/to/raw/kitti/dataset/ --dataset-format 'kitti' --dump-root /path/to/resulting/formatted/data/ --width 416 --height 128 --num-threads 4 [--static-frames /path/to/static_frames.txt] [--with-pose] [--with-depth]
59 | ```
60 | 
61 | For [Cityscapes](https://www.cityscapes-dataset.com/), download the following packages: 1) `leftImg8bit_sequence_trainvaltest.zip`, 2) `camera_trainvaltest.zip`. You will probably need to contact the administrators to get them. No pose is currently available, but the metadata could theoretically be used to derive it. Then run the following command:
62 | ```bash
63 | python3 data/prepare_train_data.py /path/to/cityscapes/dataset/ --dataset-format 'cityscapes' --dump-root /path/to/resulting/formatted/data/ --width 416 --height 171 --num-threads 4
64 | ```
65 | Notice that for Cityscapes the `img_height` is set to 171 because we crop out the bottom part of the image that contains the car logo, so the resulting image has height 128. 
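For reference, here is the arithmetic behind the 171 → 128 height reduction, mirroring the resize-then-crop step in `data/cityscapes_loader.py` (a minimal illustrative sketch, with a dummy array standing in for a real frame):

```python
import numpy as np

# Frames are first resized to (img_height, img_width) = (171, 416); the bottom
# 25% of rows (which contain the car logo) is then discarded, leaving 128 rows.
img_height, img_width = 171, 416
resized = np.zeros((img_height, img_width, 3), dtype=np.uint8)  # stand-in for a resized frame
cropped = resized[:int(img_height * 0.75)]                      # keep the top 75% of rows
assert cropped.shape == (128, img_width, 3)
```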
66 | 
67 | ## Training
68 | Once the data are formatted following the above instructions, you should be able to train the model by running the following command:
69 | ```bash
70 | python3 train_img_pairs.py /path/to/the/formatted/data/ -b4 -s3.0 --ssim 0.1 --epoch-size 3000 --sequence-length 3 --log-output [--with-gt] [--supervise-pose]
71 | ```
72 | You can then start a `tensorboard` session in this folder with
73 | ```bash
74 | tensorboard --logdir=checkpoints/
75 | ```
76 | and visualize the training progress by opening [http://localhost:6006](http://localhost:6006) in your browser.
77 | 
78 | ### Some useful options, with points not discussed in the paper
79 | 
80 | * `--rotation-mode`: Lets you switch between Euler angles and quaternions for the rotation parameterization. In practice, this has no noticeable effect.
81 | * `--network-input-size`: Lets you downsample the pictures before feeding them to the pose and depth networks. This is especially useful for large images, which provide more spatial information for the photometric loss while keeping the network inputs small. It has been tested with pictures of size `832 x 256` on KITTI, without much effect.
82 | * `--training-milestones`: During training, I_t and I_r can be any frames within the sequence, but for stability, especially when training from scratch, it can be helpful to fix them at first. The first milestone is the epoch after which I_r is no longer fixed; likewise, the second milestone is for I_t.
83 | 
84 | ### Flexible shifts training
85 | 
86 | As an experimental training mode, you can try flexible shifts. Every N epochs (N is an argument), the optimal shifts for each sample are recomputed. The goal is to avoid sequences with too much disparity (by reducing the shift) or static scenes (by increasing it). A proper dataset has yet to be constructed to check whether this is a good idea. See the SfmLearner equivalent [here](https://github.com/ClementPinard/SfmLearner-Pytorch/blob/master/train_flexible_shifts.py).
87 | 
88 | ```bash
89 | python3 train_flexible_shifts.py /path/to/the/formatted/data/ -b4 -s3.0 --ssim 0.1 --epoch-size 3000 --sequence-length 3 --log-output [--with-gt] [--supervise-pose] -D 30 -r5
90 | ```
91 | 
92 | ## Evaluation
93 | 
94 | Depth evaluation is available:
95 | ```bash
96 | python3 test_depth.py --pretrained-dispnet /path/to/dispnet --pretrained-posenet /path/to/posenet --dataset-dir /path/to/KITTI_raw --dataset-list /path/to/test_files_list
97 | ```
98 | 
99 | The test file list is available in the `kitti_eval` folder. For a fair comparison with the [SfmLearner evaluation code](https://github.com/ClementPinard/SfmLearner-Pytorch/blob/master/test_disp.py), depth should be evaluated with the scale factor deduced from the GT pose, and only on `kitti_eval/test_files_eigen_filtered.txt`, a filtered subset of `kitti_eval/test_files_eigen.txt` restricted to frames for which the GPS accuracy was measured to be good. 
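The tables below report the standard depth error metrics (Abs Rel, Sq Rel, RMSE, RMSE(log) and the three accuracy thresholds). As a reference, here is a minimal sketch of how these metrics are commonly computed on arrays of ground-truth and predicted depth. It uses median scaling purely for illustration (the comparison above instead relies on the scale factor deduced from GT pose), the depth bounds are illustrative defaults, and it is not the repository's exact evaluation code:

```python
import numpy as np

def depth_metrics(gt, pred, min_depth=1e-3, max_depth=80.0):
    """Abs Rel, Sq Rel, RMSE, RMSE(log) and Acc.1/2/3 over valid depth pixels."""
    valid = (gt > min_depth) & (gt < max_depth)
    gt, pred = gt[valid], pred[valid]

    # Median scaling shown for illustration; the README comparison relies on
    # the scale factor deduced from ground-truth pose instead.
    pred = pred * np.median(gt) / np.median(pred)
    pred = np.clip(pred, min_depth, max_depth)

    thresh = np.maximum(gt / pred, pred / gt)
    acc = [np.mean(thresh < 1.25 ** k) for k in (1, 2, 3)]

    abs_rel = np.mean(np.abs(gt - pred) / gt)
    sq_rel = np.mean((gt - pred) ** 2 / gt)
    rmse = np.sqrt(np.mean((gt - pred) ** 2))
    rmse_log = np.sqrt(np.mean((np.log(gt) - np.log(pred)) ** 2))
    return abs_rel, sq_rel, rmse, rmse_log, acc[0], acc[1], acc[2]
```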
100 | 101 | Pose evaluation is available by using [this code](https://github.com/ClementPinard/SfmLearner-Pytorch/blob/master/test_pose.py) 102 | 103 | ## Pretrained Nets 104 | 105 | Soon to be available 106 | 107 | ### Depth Results 108 | 109 | #### KITTI 110 | 111 | | Abs Rel | Sq Rel | RMSE | RMSE(log) | Acc.1 | Acc.2 | Acc.3 | 112 | |---------|--------|-------|-----------|-------|-------|-------| 113 | | 0.294 | 3.992 | 7.573 | 0.356 | 0.609 | 0.833 | 0.909 | 114 | 115 | #### KITTI stabilized 116 | 117 | | Abs Rel | Sq Rel | RMSE | RMSE(log) | Acc.1 | Acc.2 | Acc.3 | 118 | |---------|--------|-------|-----------|-------|-------|-------| 119 | | 0.271 | 4.495 | 7.312 | 0.345 | 0.678 | 0.856 | 0.924 | 120 | 121 | #### Still Box 122 | 123 | | Abs Rel | Sq Rel | RMSE | RMSE(log) | Acc.1 | Acc.2 | Acc.3 | 124 | |---------|--------|--------|-----------|-------|-------|-------| 125 | | 0.468 | 10.924 | 15.756 | 0.544 | 0.452 | 0.573 | 0.714 | 126 | 127 | #### Still Box stabilized 128 | 129 | | Abs Rel | Sq Rel | RMSE | RMSE(log) | Acc.1 | Acc.2 | Acc.3 | 130 | |---------|--------|--------|-----------|-------|-------|-------| 131 | | 0.297 | 5.253 | 10.509 | 0.404 | 0.668 | 0.840 | 0.906 | 132 | 133 | FYI, here are Still Box stabilized results from a supervised training. 134 | 135 | #### Still Box stabilized supervised 136 | 137 | | Abs Rel | Sq Rel | RMSE | RMSE(log) | Acc.1 | Acc.2 | Acc.3 | 138 | |---------|--------|-------|-----------|-------|-------|-------| 139 | | 0.212 | 2.064 | 7.067 | 0.296 | 0.709 | 0.881 | 0.946 | 140 | -------------------------------------------------------------------------------- /custom_transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import random 4 | import numpy as np 5 | from skimage.transform import resize 6 | 7 | '''Set of tranform random routines that takes list of inputs as arguments, 8 | in order to have random but coherent transformations.''' 9 | 10 | 11 | class Compose(object): 12 | def __init__(self, transforms): 13 | self.transforms = transforms 14 | 15 | def __call__(self, images, intrinsics): 16 | for t in self.transforms: 17 | images, intrinsics = t(images, intrinsics) 18 | return images, intrinsics 19 | 20 | 21 | class Normalize(object): 22 | def __init__(self, mean, std): 23 | self.mean = mean 24 | self.std = std 25 | 26 | def __call__(self, images, intrinsics): 27 | for tensor in images: 28 | for t, m, s in zip(tensor, self.mean, self.std): 29 | t.sub_(m).div_(s) 30 | return images, intrinsics 31 | 32 | 33 | class ArrayToTensor(object): 34 | """Converts a list of numpy.ndarray (H x W x C) along with a intrinsics matrix to a list of torch.FloatTensor of shape (C x H x W) with a intrinsics tensor.""" 35 | 36 | def __call__(self, images, intrinsics): 37 | tensors = [] 38 | for im in images: 39 | # put it from HWC to CHW format 40 | im = np.transpose(im, (2, 0, 1)) 41 | # handle numpy array 42 | tensors.append(torch.from_numpy(im).float()/255) 43 | return tensors, intrinsics 44 | 45 | 46 | class RandomHorizontalFlip(object): 47 | """Randomly horizontally flips the given numpy array with a probability of 0.5""" 48 | 49 | def __call__(self, images, intrinsics): 50 | assert intrinsics is not None 51 | if random.random() < 0.5: 52 | output_intrinsics = np.copy(intrinsics) 53 | output_images = [np.copy(np.fliplr(im)) for im in images] 54 | w = output_images[0].shape[1] 55 | output_intrinsics[0,2] = w - output_intrinsics[0,2] 56 | else: 57 | 
output_images = images 58 | output_intrinsics = intrinsics 59 | return output_images, output_intrinsics 60 | 61 | 62 | class RandomScaleCrop(object): 63 | """Randomly zooms images up to 15% and crop them to keep same size as before.""" 64 | 65 | def __call__(self, images, intrinsics): 66 | assert intrinsics is not None 67 | output_intrinsics = np.copy(intrinsics) 68 | 69 | in_h, in_w, _ = images[0].shape 70 | x_scaling, y_scaling = np.random.uniform(1,1.15,2) 71 | scaled_h, scaled_w = int(in_h * y_scaling), int(in_w * x_scaling) 72 | 73 | output_intrinsics[0] *= x_scaling 74 | output_intrinsics[1] *= y_scaling 75 | scaled_images = [resize(im, (scaled_h, scaled_w)) for im in images] 76 | 77 | offset_y = np.random.randint(scaled_h - in_h + 1) 78 | offset_x = np.random.randint(scaled_w - in_w + 1) 79 | cropped_images = [im[offset_y:offset_y + in_h, offset_x:offset_x + in_w] for im in scaled_images] 80 | 81 | output_intrinsics[0,2] -= offset_x 82 | output_intrinsics[1,2] -= offset_y 83 | 84 | return cropped_images, output_intrinsics 85 | -------------------------------------------------------------------------------- /data/cityscapes_loader.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import json 3 | import numpy as np 4 | import scipy.misc 5 | from path import Path 6 | from tqdm import tqdm 7 | 8 | 9 | class cityscapes_loader(object): 10 | def __init__(self, 11 | dataset_dir, 12 | split='train', 13 | crop_bottom=True, # Get rid of the car logo 14 | img_height=171, 15 | img_width=416): 16 | self.dataset_dir = Path(dataset_dir) 17 | self.split = split 18 | # Crop out the bottom 25% of the image to remove the car logo 19 | self.crop_bottom = crop_bottom 20 | self.img_height = img_height 21 | self.img_width = img_width 22 | self.min_speed = 2 23 | self.scenes = (self.dataset_dir/'leftImg8bit_sequence'/split).dirs() 24 | print('Total scenes collected: {}'.format(len(self.scenes))) 25 | 26 | def collect_scenes(self, city): 27 | img_files = sorted(city.files('*.png')) 28 | scenes = {} 29 | connex_scenes = {} 30 | connex_scene_data_list = [] 31 | for f in img_files: 32 | scene_id,frame_id = f.basename().split('_')[1:3] 33 | if scene_id not in scenes.keys(): 34 | scenes[scene_id] = [] 35 | scenes[scene_id].append(frame_id) 36 | 37 | # divide scenes into connexe sequences 38 | for scene_id in scenes.keys(): 39 | previous = None 40 | connex_scenes[scene_id] = [] 41 | for id in scenes[scene_id]: 42 | if previous is None or int(id) - int(previous) > 1: 43 | current_list = [] 44 | connex_scenes[scene_id].append(current_list) 45 | current_list.append(id) 46 | previous = id 47 | 48 | # create scene data dicts, and subsample scene every two frames 49 | for scene_id in connex_scenes.keys(): 50 | intrinsics = self.load_intrinsics(city, scene_id) 51 | for subscene in connex_scenes[scene_id]: 52 | frame_speeds = [self.load_speed(city, scene_id, frame_id) for frame_id in subscene] 53 | connex_scene_data_list.append({'city':city, 54 | 'scene_id': scene_id, 55 | 'rel_path': city.basename()+'_'+scene_id+'_'+subscene[0]+'_0', 56 | 'intrinsics': intrinsics, 57 | 'frame_ids':subscene[0::2], 58 | 'speeds':frame_speeds[0::2]}) 59 | connex_scene_data_list.append({'city':city, 60 | 'scene_id': scene_id, 61 | 'rel_path': city.basename()+'_'+scene_id+'_'+subscene[0]+'_1', 62 | 'intrinsics': intrinsics, 63 | 'frame_ids': subscene[1::2], 64 | 'speeds': frame_speeds[1::2]}) 65 | return connex_scene_data_list 66 | 67 | def load_intrinsics(self, city, 
scene_id): 68 | city_name = city.basename() 69 | camera_folder = self.dataset_dir/'camera'/self.split/city_name 70 | camera_file = camera_folder.files('{}_{}_*_camera.json'.format(city_name, scene_id))[0] 71 | frame_id = camera_file.split('_')[2] 72 | frame_path = city/'{}_{}_{}_leftImg8bit.png'.format(city_name, scene_id, frame_id) 73 | 74 | with open(camera_file, 'r') as f: 75 | camera = json.load(f) 76 | fx = camera['intrinsic']['fx'] 77 | fy = camera['intrinsic']['fy'] 78 | u0 = camera['intrinsic']['u0'] 79 | v0 = camera['intrinsic']['v0'] 80 | intrinsics = np.array([[fx, 0, u0], 81 | [0, fy, v0], 82 | [0, 0, 1]]) 83 | 84 | img = scipy.misc.imread(frame_path) 85 | h,w,_ = img.shape 86 | zoom_y = self.img_height/h 87 | zoom_x = self.img_width/w 88 | 89 | intrinsics[0] *= zoom_x 90 | intrinsics[1] *= zoom_y 91 | return intrinsics 92 | 93 | def load_speed(self, city, scene_id, frame_id): 94 | city_name = city.basename() 95 | vehicle_folder = self.dataset_dir/'vehicle_sequence'/self.split/city_name 96 | vehicle_file = vehicle_folder/'{}_{}_{}_vehicle.json'.format(city_name, scene_id, frame_id) 97 | with open(vehicle_file, 'r') as f: 98 | vehicle = json.load(f) 99 | return vehicle['speed'] 100 | 101 | def get_scene_imgs(self, scene_data): 102 | cum_speed = np.zeros(3) 103 | print(scene_data['city'].basename(), scene_data['scene_id'], scene_data['frame_ids'][0]) 104 | for i,frame_id in enumerate(scene_data['frame_ids']): 105 | cum_speed += scene_data['speeds'][i] 106 | speed_mag = np.linalg.norm(cum_speed) 107 | if speed_mag > self.min_speed: 108 | yield {"img":self.load_image(scene_data['city'], scene_data['scene_id'], frame_id), 109 | "id":frame_id} 110 | cum_speed *= 0 111 | 112 | def load_image(self, city, scene_id, frame_id): 113 | img_file = city/'{}_{}_{}_leftImg8bit.png'.format(city.basename(), 114 | scene_id, 115 | frame_id) 116 | if not img_file.isfile(): 117 | return None 118 | img = scipy.misc.imread(img_file) 119 | img = scipy.misc.imresize(img, (self.img_height, self.img_width))[:int(self.img_height*0.75)] 120 | return img 121 | -------------------------------------------------------------------------------- /data/kitti_raw_loader.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | from path import Path 4 | from imageio import imread 5 | from skimage.transform import resize 6 | from collections import Counter 7 | 8 | 9 | def rotx(t): 10 | """Rotation about the x-axis.""" 11 | c = np.cos(t) 12 | s = np.sin(t) 13 | return np.array([[1, 0, 0], 14 | [0, c, -s], 15 | [0, s, c]]) 16 | 17 | 18 | def roty(t): 19 | """Rotation about the y-axis.""" 20 | c = np.cos(t) 21 | s = np.sin(t) 22 | return np.array([[c, 0, s], 23 | [0, 1, 0], 24 | [-s, 0, c]]) 25 | 26 | 27 | def rotz(t): 28 | """Rotation about the z-axis.""" 29 | c = np.cos(t) 30 | s = np.sin(t) 31 | return np.array([[c, -s, 0], 32 | [s, c, 0], 33 | [0, 0, 1]]) 34 | 35 | 36 | def pose_from_oxts_packet(metadata, scale): 37 | 38 | lat, lon, alt, roll, pitch, yaw = metadata 39 | """Helper method to compute a SE(3) pose matrix from an OXTS packet. 40 | Taken from https://github.com/utiasSTARS/pykitti 41 | """ 42 | 43 | er = 6378137. # earth radius (approx.) in meters 44 | # Use a Mercator projection to get the translation vector 45 | ty = lat * np.pi * er / 180. 46 | 47 | tx = scale * lon * np.pi * er / 180. 48 | # ty = scale * er * \ 49 | # np.log(np.tan((90. 
+ lat) * np.pi / 360.)) 50 | tz = alt 51 | t = np.array([tx, ty, tz]).reshape(-1,1) 52 | 53 | # Use the Euler angles to get the rotation matrix 54 | Rx = rotx(roll) 55 | Ry = roty(pitch) 56 | Rz = rotz(yaw) 57 | R = Rz.dot(Ry.dot(Rx)) 58 | return transform_from_rot_trans(R, t) 59 | 60 | 61 | def read_calib_file(path): 62 | # taken from https://github.com/hunse/kitti 63 | float_chars = set("0123456789.e+- ") 64 | data = {} 65 | with open(path, 'r') as f: 66 | for line in f.readlines(): 67 | key, value = line.split(':', 1) 68 | value = value.strip() 69 | data[key] = value 70 | if float_chars.issuperset(value): 71 | # try to cast to float array 72 | try: 73 | data[key] = np.array(list(map(float, value.split(' ')))) 74 | except ValueError: 75 | # casting error: data[key] already eq. value, so pass 76 | pass 77 | 78 | return data 79 | 80 | 81 | def transform_from_rot_trans(R, t): 82 | """Transforation matrix from rotation matrix and translation vector.""" 83 | R = R.reshape(3, 3) 84 | t = t.reshape(3, 1) 85 | return np.vstack((np.hstack([R, t]), [0, 0, 0, 1])) 86 | 87 | 88 | class KittiRawLoader(object): 89 | def __init__(self, 90 | dataset_dir, 91 | static_frames_file=None, 92 | img_height=128, 93 | img_width=416, 94 | min_speed=2, 95 | get_depth=False, 96 | get_pose=False, 97 | depth_size_ratio=1): 98 | dir_path = Path(__file__).realpath().dirname() 99 | test_scene_file = dir_path/'test_scenes.txt' 100 | 101 | self.from_speed = static_frames_file is None 102 | if static_frames_file is not None: 103 | static_frames_file = Path(static_frames_file) 104 | self.collect_static_frames(static_frames_file) 105 | 106 | with open(test_scene_file, 'r') as f: 107 | test_scenes = f.readlines() 108 | self.test_scenes = [t[:-1] for t in test_scenes] 109 | self.dataset_dir = Path(dataset_dir) 110 | self.img_height = img_height 111 | self.img_width = img_width 112 | self.cam_ids = ['02', '03'] 113 | self.date_list = ['2011_09_26', '2011_09_28', '2011_09_29', '2011_09_30', '2011_10_03'] 114 | self.min_speed = min_speed 115 | self.get_depth = get_depth 116 | self.get_pose = get_pose 117 | self.depth_size_ratio = depth_size_ratio 118 | self.collect_train_folders() 119 | 120 | def collect_static_frames(self, static_frames_file): 121 | with open(static_frames_file, 'r') as f: 122 | frames = f.readlines() 123 | self.static_frames = {} 124 | for fr in frames: 125 | if fr == '\n': 126 | continue 127 | date, drive, frame_id = fr.split(' ') 128 | curr_fid = '%.10d' % (np.int(frame_id[:-1])) 129 | if drive not in self.static_frames.keys(): 130 | self.static_frames[drive] = [] 131 | self.static_frames[drive].append(curr_fid) 132 | 133 | def collect_train_folders(self): 134 | self.scenes = [] 135 | for date in self.date_list: 136 | drive_set = (self.dataset_dir/date).dirs() 137 | for dr in drive_set: 138 | if dr.name[:-5] not in self.test_scenes: 139 | self.scenes.append(dr) 140 | 141 | def collect_scenes(self, drive): 142 | train_scenes = [] 143 | for c in self.cam_ids: 144 | oxts = sorted((drive/'oxts'/'data').files('*.txt')) 145 | scene_data = {'cid': c, 'dir': drive, 'speed': [], 'frame_id': [], 'pose':[], 'rel_path': drive.name + '_' + c} 146 | scale = None 147 | origin = None 148 | imu2velo = read_calib_file(drive.parent/'calib_imu_to_velo.txt') 149 | velo2cam = read_calib_file(drive.parent/'calib_velo_to_cam.txt') 150 | cam2cam = read_calib_file(drive.parent/'calib_cam_to_cam.txt') 151 | 152 | velo2cam_mat = transform_from_rot_trans(velo2cam['R'], velo2cam['T']) 153 | imu2velo_mat = 
transform_from_rot_trans(imu2velo['R'], imu2velo['T']) 154 | cam_2rect_mat = transform_from_rot_trans(cam2cam['R_rect_00'], np.zeros(3)) 155 | 156 | imu2cam = cam_2rect_mat @ velo2cam_mat @ imu2velo_mat 157 | 158 | for n, f in enumerate(oxts): 159 | metadata = np.genfromtxt(f) 160 | speed = metadata[8:11] 161 | scene_data['speed'].append(speed) 162 | scene_data['frame_id'].append('{:010d}'.format(n)) 163 | lat = metadata[0] 164 | 165 | if scale is None: 166 | scale = np.cos(lat * np.pi / 180.) 167 | 168 | pose_matrix = pose_from_oxts_packet(metadata[:6], scale) 169 | if origin is None: 170 | origin = pose_matrix 171 | 172 | odo_pose = imu2cam @ np.linalg.inv(origin) @ pose_matrix @ np.linalg.inv(imu2cam) 173 | scene_data['pose'].append(odo_pose[:3]) 174 | 175 | sample = self.load_image(scene_data, 0) 176 | if sample is None: 177 | return [] 178 | scene_data['P_rect'] = self.get_P_rect(scene_data, sample[1], sample[2]) 179 | scene_data['intrinsics'] = scene_data['P_rect'][:,:3] 180 | 181 | train_scenes.append(scene_data) 182 | return train_scenes 183 | 184 | def get_scene_imgs(self, scene_data): 185 | def construct_sample(scene_data, i, frame_id): 186 | sample = {"img":self.load_image(scene_data, i)[0], "id":frame_id} 187 | 188 | if self.get_depth: 189 | sample['depth'] = self.generate_depth_map(scene_data, i) 190 | if self.get_pose: 191 | sample['pose'] = scene_data['pose'][i] 192 | return sample 193 | 194 | if self.from_speed: 195 | cum_speed = np.zeros(3) 196 | for i, speed in enumerate(scene_data['speed']): 197 | cum_speed += speed 198 | speed_mag = np.linalg.norm(cum_speed) 199 | if speed_mag > self.min_speed: 200 | frame_id = scene_data['frame_id'][i] 201 | yield construct_sample(scene_data, i, frame_id) 202 | cum_speed *= 0 203 | else: # from static frame file 204 | drive = str(scene_data['dir'].name) 205 | for (i,frame_id) in enumerate(scene_data['frame_id']): 206 | if (drive not in self.static_frames.keys()) or (frame_id not in self.static_frames[drive]): 207 | yield construct_sample(scene_data, i, frame_id) 208 | 209 | def get_P_rect(self, scene_data, zoom_x, zoom_y): 210 | calib_file = scene_data['dir'].parent/'calib_cam_to_cam.txt' 211 | 212 | filedata = self.read_raw_calib_file(calib_file) 213 | P_rect = np.reshape(filedata['P_rect_' + scene_data['cid']], (3, 4)) 214 | P_rect[0] *= zoom_x 215 | P_rect[1] *= zoom_y 216 | return P_rect 217 | 218 | def load_image(self, scene_data, tgt_idx): 219 | img_file = scene_data['dir']/'image_{}'.format(scene_data['cid'])/'data'/scene_data['frame_id'][tgt_idx]+'.png' 220 | if not img_file.isfile(): 221 | return None 222 | img = imread(img_file) 223 | zoom_y = self.img_height/img.shape[0] 224 | zoom_x = self.img_width/img.shape[1] 225 | img = resize(img, (self.img_height, self.img_width)) 226 | return img, zoom_x, zoom_y 227 | 228 | def read_raw_calib_file(self, filepath): 229 | # From https://github.com/utiasSTARS/pykitti/blob/master/pykitti/utils.py 230 | """Read in a calibration file and parse into a dictionary.""" 231 | data = {} 232 | 233 | with open(filepath, 'r') as f: 234 | for line in f.readlines(): 235 | key, value = line.split(':', 1) 236 | # The only non-float values in these files are dates, which 237 | # we don't care about anyway 238 | try: 239 | data[key] = np.array([float(x) for x in value.split()]) 240 | except ValueError: 241 | pass 242 | return data 243 | 244 | def generate_depth_map(self, scene_data, tgt_idx): 245 | # compute projection matrix velodyne->image plane 246 | 247 | def sub2ind(matrixSize, rowSub, colSub): 248 
| m, n = matrixSize 249 | return rowSub * (n-1) + colSub - 1 250 | 251 | R_cam2rect = np.eye(4) 252 | 253 | calib_dir = scene_data['dir'].parent 254 | cam2cam = self.read_raw_calib_file(calib_dir/'calib_cam_to_cam.txt') 255 | velo2cam = self.read_raw_calib_file(calib_dir/'calib_velo_to_cam.txt') 256 | velo2cam = np.hstack((velo2cam['R'].reshape(3,3), velo2cam['T'][..., np.newaxis])) 257 | velo2cam = np.vstack((velo2cam, np.array([0, 0, 0, 1.0]))) 258 | P_rect = np.copy(scene_data['P_rect']) 259 | P_rect[0] /= self.depth_size_ratio 260 | P_rect[1] /= self.depth_size_ratio 261 | 262 | R_cam2rect[:3,:3] = cam2cam['R_rect_00'].reshape(3,3) 263 | 264 | P_velo2im = np.dot(np.dot(P_rect, R_cam2rect), velo2cam) 265 | 266 | velo_file_name = scene_data['dir']/'velodyne_points'/'data'/'{}.bin'.format(scene_data['frame_id'][tgt_idx]) 267 | 268 | # load velodyne points and remove all behind image plane (approximation) 269 | # each row of the velodyne data is forward, left, up, reflectance 270 | velo = np.fromfile(velo_file_name, dtype=np.float32).reshape(-1, 4) 271 | velo[:,3] = 1 272 | velo = velo[velo[:, 0] >= 0, :] 273 | 274 | # project the points to the camera 275 | velo_pts_im = np.dot(P_velo2im, velo.T).T 276 | velo_pts_im[:, :2] = velo_pts_im[:,:2] / velo_pts_im[:,-1:] 277 | 278 | # check if in bounds 279 | # use minus 1 to get the exact same value as KITTI matlab code 280 | velo_pts_im[:, 0] = np.round(velo_pts_im[:,0]) - 1 281 | velo_pts_im[:, 1] = np.round(velo_pts_im[:,1]) - 1 282 | 283 | val_inds = (velo_pts_im[:, 0] >= 0) & (velo_pts_im[:, 1] >= 0) 284 | val_inds = val_inds & (velo_pts_im[:,0] < self.img_width/self.depth_size_ratio) 285 | val_inds = val_inds & (velo_pts_im[:,1] < self.img_height/self.depth_size_ratio) 286 | velo_pts_im = velo_pts_im[val_inds, :] 287 | 288 | # project to image 289 | depth = np.zeros((self.img_height // self.depth_size_ratio, self.img_width // self.depth_size_ratio)).astype(np.float32) 290 | depth[velo_pts_im[:, 1].astype(np.int), velo_pts_im[:, 0].astype(np.int)] = velo_pts_im[:, 2] 291 | 292 | # find the duplicate points and choose the closest depth 293 | inds = sub2ind(depth.shape, velo_pts_im[:, 1], velo_pts_im[:, 0]) 294 | dupe_inds = [item for item, count in Counter(inds).items() if count > 1] 295 | for dd in dupe_inds: 296 | pts = np.where(inds == dd)[0] 297 | x_loc = int(velo_pts_im[pts[0], 0]) 298 | y_loc = int(velo_pts_im[pts[0], 1]) 299 | depth[y_loc, x_loc] = velo_pts_im[pts, 2].min() 300 | depth[depth < 0] = 0 301 | return depth 302 | -------------------------------------------------------------------------------- /data/prepare_train_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from imageio import imsave 3 | import numpy as np 4 | from joblib import Parallel, delayed 5 | from tqdm import tqdm 6 | from path import Path 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("dataset_dir", metavar='DIR', 10 | help='path to original dataset') 11 | parser.add_argument("--dataset-format", type=str, required=True, choices=["kitti", "cityscapes"]) 12 | parser.add_argument("--static-frames", default=None, 13 | help="list of imgs to discard for being static, if not set will discard them based on speed \ 14 | (careful, on KITTI some frames have incorrect speed)") 15 | parser.add_argument("--with-depth", action='store_true', 16 | help="If available (e.g. 
with KITTI), will store depth ground truth along with images, for validation") 17 | parser.add_argument("--with-pose", action='store_true', 18 | help="If available (e.g. with KITTI), will store pose ground truth along with images, for validation") 19 | parser.add_argument("--no-train-gt", action='store_true', 20 | help="If selected, will delete ground truth depth to save space") 21 | parser.add_argument("--dump-root", type=str, required=True, help="Where to dump the data") 22 | parser.add_argument("--height", type=int, default=128, help="image height") 23 | parser.add_argument("--width", type=int, default=416, help="image width") 24 | parser.add_argument("--depth-size-ratio", type=int, default=1, help="will divide depth size by that ratio") 25 | parser.add_argument("--num-threads", type=int, default=4, help="number of threads to use") 26 | 27 | args = parser.parse_args() 28 | 29 | 30 | def dump_example(args, scene): 31 | scene_list = data_loader.collect_scenes(scene) 32 | for scene_data in scene_list: 33 | dump_dir = args.dump_root/scene_data['rel_path'] 34 | dump_dir.makedirs_p() 35 | intrinsics = scene_data['intrinsics'] 36 | 37 | dump_cam_file = dump_dir/'cam.txt' 38 | 39 | np.savetxt(dump_cam_file, intrinsics) 40 | poses_file = dump_dir/'poses.txt' 41 | poses = [] 42 | 43 | for sample in data_loader.get_scene_imgs(scene_data): 44 | img, frame_nb = sample["img"], sample["id"] 45 | dump_img_file = dump_dir/'{}.jpg'.format(frame_nb) 46 | imsave(dump_img_file, (256*img).astype(np.uint8)) 47 | if "pose" in sample.keys(): 48 | poses.append(sample["pose"].tolist()) 49 | if "depth" in sample.keys(): 50 | dump_depth_file = dump_dir/'{}.npy'.format(frame_nb) 51 | np.save(dump_depth_file, sample["depth"]) 52 | if len(poses) != 0: 53 | np.savetxt(poses_file, np.array(poses).reshape(-1, 12), fmt='%.6e') 54 | 55 | if len(dump_dir.files('*.jpg')) < 3: 56 | dump_dir.rmtree() 57 | 58 | 59 | def main(): 60 | args.dump_root = Path(args.dump_root) 61 | args.dump_root.mkdir_p() 62 | 63 | global data_loader 64 | 65 | if args.dataset_format == 'kitti': 66 | from kitti_raw_loader import KittiRawLoader 67 | data_loader = KittiRawLoader(args.dataset_dir, 68 | static_frames_file=args.static_frames, 69 | img_height=args.height, 70 | img_width=args.width, 71 | get_depth=args.with_depth, 72 | get_pose=args.with_pose, 73 | depth_size_ratio=args.depth_size_ratio) 74 | 75 | if args.dataset_format == 'cityscapes': 76 | from cityscapes_loader import cityscapes_loader 77 | data_loader = cityscapes_loader(args.dataset_dir, 78 | img_height=args.height, 79 | img_width=args.width) 80 | 81 | print('Retrieving frames') 82 | if args.num_threads == 1: 83 | for scene in tqdm(data_loader.scenes): 84 | dump_example(args, scene) 85 | else: 86 | Parallel(n_jobs=args.num_threads)(delayed(dump_example)(args, scene) for scene in tqdm(data_loader.scenes)) 87 | 88 | print('Generating train val lists') 89 | np.random.seed(8964) 90 | subfolders = args.dump_root.dirs() 91 | with open(args.dump_root / 'train.txt', 'w') as tf: 92 | with open(args.dump_root / 'val.txt', 'w') as vf: 93 | for s in tqdm(subfolders): 94 | if np.random.random() < 0.1: 95 | vf.write('{}\n'.format(s.name)) 96 | else: 97 | tf.write('{}\n'.format(s.name)) 98 | if args.with_depth and args.no_train_gt: 99 | for gt_file in s.files('*.npy'): 100 | gt_file.remove_p() 101 | 102 | 103 | if __name__ == '__main__': 104 | main() 105 | -------------------------------------------------------------------------------- /data/test_scenes.txt: 
-------------------------------------------------------------------------------- 1 | 2011_09_26_drive_0117 2 | 2011_09_28_drive_0002 3 | 2011_09_26_drive_0052 4 | 2011_09_30_drive_0016 5 | 2011_09_26_drive_0059 6 | 2011_09_26_drive_0027 7 | 2011_09_26_drive_0020 8 | 2011_09_26_drive_0009 9 | 2011_09_26_drive_0013 10 | 2011_09_26_drive_0101 11 | 2011_09_26_drive_0046 12 | 2011_09_26_drive_0029 13 | 2011_09_26_drive_0064 14 | 2011_09_26_drive_0048 15 | 2011_10_03_drive_0027 16 | 2011_09_26_drive_0002 17 | 2011_09_26_drive_0036 18 | 2011_09_29_drive_0071 19 | 2011_10_03_drive_0047 20 | 2011_09_30_drive_0027 21 | 2011_09_26_drive_0086 22 | 2011_09_26_drive_0084 23 | 2011_09_26_drive_0096 24 | 2011_09_30_drive_0018 25 | 2011_09_26_drive_0106 26 | 2011_09_26_drive_0056 27 | 2011_09_26_drive_0023 28 | 2011_09_26_drive_0093 29 | -------------------------------------------------------------------------------- /datasets/sequence_folders.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import numpy as np 3 | from imageio import imread 4 | from path import Path 5 | import random 6 | 7 | 8 | def load_as_float(path): 9 | return imread(path).astype(np.float32) 10 | 11 | 12 | class SequenceFolder(data.Dataset): 13 | """A sequence data loader where the files are arranged in this way: 14 | root/scene_1/0000000.jpg 15 | root/scene_1/0000001.jpg 16 | .. 17 | root/scene_1/cam.txt 18 | root/scene_2/0000000.jpg 19 | . 20 | 21 | transform functions must take in a list a images and a numpy array (usually intrinsics matrix) 22 | """ 23 | 24 | def __init__(self, root, seed=None, train=True, sequence_length=3, transform=None, target_transform=None): 25 | np.random.seed(seed) 26 | random.seed(seed) 27 | self.root = Path(root) 28 | scene_list_path = self.root/'train.txt' if train else self.root/'val.txt' 29 | self.scenes = [self.root/folder[:-1] for folder in open(scene_list_path)] 30 | self.transform = transform 31 | self.crawl_folders(sequence_length) 32 | 33 | def crawl_folders(self, sequence_length): 34 | sequence_set = [] 35 | demi_length = (sequence_length-1)//2 36 | shifts = list(range(-demi_length, demi_length + 1)) 37 | shifts.pop(demi_length) 38 | for scene in self.scenes: 39 | intrinsics = np.genfromtxt(scene/'cam.txt', delimiter=',').astype(np.float32).reshape((3, 3)) 40 | imgs = sorted(scene.files('*.jpg')) 41 | if len(imgs) < sequence_length: 42 | continue 43 | for i in range(demi_length, len(imgs)-demi_length): 44 | sample = {'intrinsics': intrinsics, 'tgt': imgs[i], 'ref_imgs': []} 45 | for j in shifts: 46 | sample['ref_imgs'].append(imgs[i+j]) 47 | sequence_set.append(sample) 48 | random.shuffle(sequence_set) 49 | self.samples = sequence_set 50 | 51 | def __getitem__(self, index): 52 | sample = self.samples[index] 53 | tgt_img = load_as_float(sample['tgt']) 54 | ref_imgs = [load_as_float(ref_img) for ref_img in sample['ref_imgs']] 55 | if self.transform is not None: 56 | imgs, intrinsics = self.transform([tgt_img] + ref_imgs, np.copy(sample['intrinsics'])) 57 | tgt_img = imgs[0] 58 | ref_imgs = imgs[1:] 59 | else: 60 | intrinsics = np.copy(sample['intrinsics']) 61 | return tgt_img, ref_imgs, intrinsics, np.linalg.inv(intrinsics) 62 | 63 | def __len__(self): 64 | return len(self.samples) 65 | -------------------------------------------------------------------------------- /datasets/shifted_sequence_folders.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import 
torch.utils.data as data 3 | from imageio import imread 4 | import random 5 | import json 6 | from path import Path 7 | 8 | 9 | def load_as_float(path): 10 | return imread(path).astype(np.float32) 11 | 12 | 13 | def quat2mat(quat): 14 | w, x, y, z = quat[:,0], quat[:,1], quat[:,2], quat[:,3] 15 | w2, x2, y2, z2 = w**2, x**2, y**2, z**2 16 | wx, wy, wz = w*x, w*y, w*z 17 | xy, xz, yz = x*y, x*z, y*z 18 | 19 | rotMat = np.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz, 20 | 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx, 21 | 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], axis=1).reshape(quat.shape[0], 3, 3) 22 | return rotMat 23 | 24 | 25 | class ShiftedSequenceFolder(data.Dataset): 26 | """A sequence data loader where the files are arranged in this way: 27 | root/scene_1/0000000.jpg 28 | root/scene_1/0000001.jpg 29 | .. 30 | root/scene_1/cam.txt 31 | (optional) root/scene_1/shifts.json 32 | root/scene_2/0000000.jpg 33 | . 34 | 35 | transform functions must take in a list a images and a numpy array (usually intrinsics matrix) 36 | """ 37 | 38 | def __init__(self, root, seed=None, train=True, with_pose_gt=False, with_depth_gt=False, 39 | sequence_length=3, target_displacement=0.02, transform=None): 40 | np.random.seed(seed) 41 | random.seed(seed) 42 | self.root = Path(root) 43 | scene_list_path = self.root/'train.txt' if train else self.root/'val.txt' 44 | self.scenes = [self.root/folder[:-1] for folder in open(scene_list_path)] 45 | self.transform = transform 46 | self.target_displacement = target_displacement 47 | self.max_shift = 10 48 | self.adjust = False 49 | self.with_pose_gt = with_pose_gt 50 | self.with_depth_gt = with_depth_gt 51 | self.crawl_folders(sequence_length) 52 | 53 | def crawl_folders(self, sequence_length): 54 | sequence_set = [] 55 | img_sequences = [] 56 | poses_sequences = [] 57 | demi_length = sequence_length//2 58 | for scene in self.scenes: 59 | imgs = sorted(scene.files('*.jpg')) 60 | if len(imgs) < sequence_length: 61 | continue 62 | 63 | shifts_file = scene/'shifts.json' 64 | if shifts_file.isfile(): 65 | with open(shifts_file, 'r') as f: 66 | shifts = json.load(f) 67 | else: 68 | prior_shifts = list(range(-demi_length, 0)) 69 | post_shifts = list(range(1, sequence_length - demi_length)) 70 | shifts = [[prior_shifts[:], post_shifts[:]] for i in imgs] 71 | 72 | if self.with_pose_gt: 73 | pose_file = scene/'poses.txt' 74 | assert pose_file.isfile(), "cannot find ground truth pose file {}".format(pose_file) 75 | poses = np.loadtxt(pose_file).astype(np.float32).reshape(-1, 3, 4) 76 | poses_sequences.append(poses) 77 | img_sequences.append(imgs) 78 | sequence_index = len(img_sequences) - 1 79 | intrinsics = np.loadtxt(scene/'cam.txt').astype(np.float32).reshape(3, 3) 80 | for i in range(demi_length, len(imgs)-demi_length): 81 | sample = {'intrinsics': intrinsics, 82 | 'tgt': i, 83 | 'prior_shifts': shifts[i][0], 84 | 'post_shifts': shifts[i][1], 85 | 'sequence_index': sequence_index} 86 | if self.with_depth_gt: 87 | depth = imgs[i].stripext() + '.npy' 88 | assert depth.isfile(), "cannot find ground truth depth map {}".format(depth) 89 | sample['depth'] = depth 90 | sequence_set.append(sample) 91 | random.shuffle(sequence_set) 92 | self.samples = sequence_set 93 | self.img_sequences = img_sequences 94 | if self.with_pose_gt: 95 | self.poses_sequences = poses_sequences 96 | 97 | def __getitem__(self, index): 98 | sample = self.samples[index] 99 | preprocessed_sample = {} 100 | imgs_paths = self.img_sequences[sample['sequence_index']] 101 | tgt_index = sample['tgt'] 102 
| tgt_img = load_as_float(imgs_paths[tgt_index]) 103 | if self.with_depth_gt: 104 | tgt_depth = np.load(sample['depth']) 105 | preprocessed_sample['depth'] = tgt_depth 106 | 107 | try: 108 | prior_imgs = [load_as_float(imgs_paths[tgt_index + i]) for i in sample['prior_shifts']] 109 | post_imgs = [load_as_float(imgs_paths[tgt_index + i]) for i in sample['post_shifts']] 110 | imgs = prior_imgs + [tgt_img] + post_imgs 111 | if self.with_pose_gt: 112 | poses = self.poses_sequences[sample['sequence_index']] 113 | tgt_pose = poses[tgt_index] 114 | prior_poses = [poses[tgt_index + i] for i in sample['prior_shifts']] 115 | post_poses = [poses[tgt_index + i] for i in sample['post_shifts']] 116 | pose_sequence = np.stack(prior_poses + [tgt_pose] + post_poses) 117 | # neutral pose is defined to be last frame 118 | pose_sequence[:,:,-1] -= pose_sequence[-1,:,-1] 119 | compensated_poses = np.linalg.inv(pose_sequence[-1,:,:3]) @ pose_sequence 120 | preprocessed_sample['pose'] = compensated_poses 121 | except Exception as e: 122 | print(index, sample['tgt'], sample['prior_shifts'], sample['post_shifts'], len(imgs)) 123 | raise e 124 | if self.transform is not None: 125 | imgs, intrinsics = self.transform(imgs, sample['intrinsics']) 126 | else: 127 | intrinsics = sample['intrinsics'] 128 | preprocessed_sample['imgs'] = imgs 129 | preprocessed_sample['intrinsics'] = intrinsics 130 | if self.adjust: 131 | preprocessed_sample['index'] = index 132 | return preprocessed_sample 133 | 134 | def reset_shifts(self, index, prior_ratio, post_ratio): 135 | sample = self.samples[index] 136 | assert(len(sample['prior_shifts']) == len(prior_ratio)) 137 | assert(len(sample['post_shifts']) == len(post_ratio)) 138 | imgs = self.img_sequences[sample['sequence_index']] 139 | tgt_index = sample['tgt'] 140 | 141 | for j, r in enumerate(prior_ratio[::-1]): 142 | 143 | shift_index = len(prior_ratio) - 1 - j 144 | old_shift = sample['prior_shifts'][shift_index] 145 | new_shift = old_shift * r 146 | assert(new_shift < 0), "shift must be negative: {:.3f}, {}, {:.3f}".format(new_shift, old_shift, r) 147 | new_shift = round(new_shift) 148 | ''' Here is how bounds work for prior shifts: 149 | prior shifts must be negative in a strict ascending order in the original list 150 | max_shift (in magnitude) is either tgt (to keep index inside list) or self.max_shift 151 | Let's say you have 2 anterior shifts, which means seq_length is 5 152 | 1st shift can be -max_shift but cannot be 0 as it would mean that 2nd would not be higher than 1st and above 0 153 | 2nd shift cannot be -max_shift as 1st shift would have to be less than -max_shift - 1. 
154 | More generally, shift must be clipped within -max_shift + its index and upper shift - 1 155 | Note that priority is given for shifts closer to tgt_index, they are free to choose the value they want, at the risk of 156 | constraining outside shifts to one only valid value 157 | ''' 158 | 159 | max_shift = min(tgt_index, self.max_shift) 160 | 161 | lower_bound = -max_shift + shift_index 162 | upper_bound = -1 if shift_index == len(prior_ratio) - 1 else sample['prior_shifts'][shift_index + 1] - 1 163 | 164 | sample['prior_shifts'][shift_index] = int(np.clip(new_shift, lower_bound, upper_bound)) 165 | 166 | for j, r in enumerate(post_ratio): 167 | shift_index = j 168 | old_shift = sample['post_shifts'][shift_index] 169 | new_shift = old_shift * r 170 | assert(new_shift > 0), "shift must be positive: {:.3f}, {}, {}".format(new_shift, old_shift, r) 171 | new_shift = round(new_shift) 172 | '''For posterior shifts : 173 | must be postive in a strict descending order 174 | max_shift is either len(imgs) - tgt or self.max_shift 175 | shift must be clipped within upper shift + 1 and max_shift - seq_length + its index 176 | ''' 177 | 178 | max_shift = min(len(imgs) - tgt_index - 1, self.max_shift) 179 | 180 | lower_bound = 1 if shift_index == 0 else sample['post_shifts'][shift_index - 1] + 1 181 | upper_bound = max_shift + shift_index - len(post_ratio) + 1 182 | 183 | sample['post_shifts'][shift_index] = int(np.clip(new_shift, lower_bound, upper_bound)) 184 | 185 | def get_shifts(self, index): 186 | sample = self.samples[index] 187 | prior = sample['prior_shifts'] 188 | post = sample['post_shifts'] 189 | return prior + post 190 | 191 | def __len__(self): 192 | return len(self.samples) 193 | 194 | 195 | class StillBox(ShiftedSequenceFolder): 196 | def crawl_folders(self, sequence_length): 197 | import json 198 | sequence_set = [] 199 | img_sequences = [] 200 | poses_sequences = [] 201 | demi_length = sequence_length//2 202 | for folder in self.scenes: 203 | with open(folder/'metadata.json', 'r') as f: 204 | metadata = json.load(f) 205 | args = metadata['args'] 206 | hfov = args['fov'] 207 | w,h = args['resolution'] 208 | f = w/(2*np.tan(np.pi*hfov/360)) 209 | intrinsics = np.array([[f, 0, w/2], 210 | [0, f, h/2], 211 | [0, 0, 1]]).astype(np.float32) 212 | for scene in metadata['scenes']: 213 | imgs = [folder/i for i in scene['imgs']] 214 | if self.with_depth_gt: 215 | depth = [folder/i for i in scene['depth']] 216 | if len(imgs) < sequence_length: 217 | continue 218 | 219 | prior_shifts = list(range(-demi_length, 0)) 220 | post_shifts = list(range(1, sequence_length - demi_length)) 221 | shifts = [[prior_shifts[:], post_shifts[:]] for i in imgs] 222 | 223 | if self.with_pose_gt: 224 | sl = len(scene['imgs']) 225 | nominal_displacement = np.float32(scene['speed']) * scene['time_step'] 226 | if len(scene['orientation']) == 0: 227 | scene_quaternions = np.float32([1,0,0,0]).reshape(1,4).repeat(sl, axis=0) 228 | else: 229 | scene_quaternions = np.float32(scene['orientation']) 230 | scene_positions = np.arange(sl).astype(np.float32).reshape(sl, 1) * nominal_displacement 231 | orientation_matrices = quat2mat(scene_quaternions).reshape(sl, 3, 3) 232 | pose_matrices = np.concatenate((orientation_matrices, scene_positions.reshape(sl, 3, 1)), axis=2) 233 | poses_sequences.append(pose_matrices) 234 | 235 | img_sequences.append(imgs) 236 | sequence_index = len(img_sequences) - 1 237 | 238 | for i in range(demi_length, len(imgs)-demi_length): 239 | sample = {'intrinsics': intrinsics, 240 | 'tgt': i, 241 | 
'prior_shifts': shifts[i][0], 242 | 'post_shifts': shifts[i][1], 243 | 'sequence_index': sequence_index} 244 | if self.with_depth_gt: 245 | sample['depth'] = depth[i] 246 | 247 | sequence_set.append(sample) 248 | random.shuffle(sequence_set) 249 | self.samples = sequence_set 250 | self.img_sequences = img_sequences 251 | if self.with_pose_gt: 252 | self.poses_sequences = poses_sequences 253 | -------------------------------------------------------------------------------- /datasets/stacked_sequence_folders.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import numpy as np 3 | from imageio import imread 4 | from path import Path 5 | import random 6 | 7 | 8 | def load_as_float(path, sequence_length): 9 | stack = imread(path).astype(np.float32) 10 | h,w,_ = stack.shape 11 | w_img = int(w/(sequence_length)) 12 | imgs = [stack[:,i*w_img:(i+1)*w_img] for i in range(sequence_length)] 13 | tgt_index = sequence_length//2 14 | return([imgs[tgt_index]] + imgs[:tgt_index] + imgs[tgt_index+1:]) 15 | 16 | 17 | class SequenceFolder(data.Dataset): 18 | """A sequence data loader where the images are arranged in this way: 19 | root/scene_1/0000000.jpg 20 | root/scene_1/0000000_cam.txt 21 | root/scene_1/0000001.jpg 22 | root/scene_1/0000001_cam.txt 23 | . 24 | root/scene_2/0000000.jpg 25 | root/scene_2/0000000_cam.txt 26 | """ 27 | 28 | def __init__(self, root, seed=None, train=True, sequence_length=3, transform=None, target_transform=None): 29 | np.random.seed(seed) 30 | random.seed(seed) 31 | self.root = Path(root) 32 | self.samples = [] 33 | frames_list_path = self.root/'train.txt' if train else self.root/'val.txt' 34 | self.scenes = self.root.dirs() 35 | self.sequence_length = sequence_length 36 | for frame_path in open(frames_list_path): 37 | a,b = frame_path[:-1].split(' ') 38 | base_path = (self.root/a)/b 39 | intrinsics = np.genfromtxt(base_path+'_cam.txt', delimiter=',').astype(np.float32).reshape((3, 3)) 40 | sample = {'intrinsics': intrinsics, 'img_stack': base_path+'.jpg'} 41 | self.samples.append(sample) 42 | self.transform = transform 43 | 44 | def __getitem__(self, index): 45 | sample = self.samples[index] 46 | imgs = load_as_float(sample['img_stack'], self.sequence_length) 47 | if self.transform is not None: 48 | imgs, intrinsics = self.transform(imgs, np.copy(sample['intrinsics'])) 49 | else: 50 | intrinsics = sample['intrinsics'] 51 | return {'imgs':imgs, 'intrinsics':intrinsics} 52 | 53 | def __len__(self): 54 | return len(self.samples) 55 | -------------------------------------------------------------------------------- /datasets/validation_folders.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import numpy as np 3 | from imageio import imread 4 | from path import Path 5 | 6 | 7 | def crawl_folders(folders_list): 8 | imgs = [] 9 | depth = [] 10 | for folder in folders_list: 11 | current_imgs = sorted(folder.files('*.jpg')) 12 | current_depth = [] 13 | for img in current_imgs: 14 | d = img.dirname()/(img.name[:-4] + '.npy') 15 | assert(d.isfile()), "depth file {} not found".format(str(d)) 16 | depth.append(d) 17 | imgs.extend(current_imgs) 18 | depth.extend(current_depth) 19 | return imgs, depth 20 | 21 | 22 | def load_as_float(path): 23 | return imread(path).astype(np.float32) 24 | 25 | 26 | class ValidationSet(data.Dataset): 27 | """A sequence data loader where the files are arranged in this way: 28 | root/scene_1/0000000.jpg 29 | 
root/scene_1/0000000.npy 30 | root/scene_1/0000001.jpg 31 | root/scene_1/0000001.npy 32 | .. 33 | root/scene_2/0000000.jpg 34 | root/scene_2/0000000.npy 35 | . 36 | 37 | transform functions must take in a list a images and a numpy array which can be None 38 | """ 39 | 40 | def __init__(self, root, transform=None): 41 | self.root = Path(root) 42 | scene_list_path = self.root/'val.txt' 43 | self.scenes = [self.root/folder[:-1] for folder in open(scene_list_path)] 44 | self.imgs, self.depth = crawl_folders(self.scenes) 45 | self.transform = transform 46 | 47 | def __getitem__(self, index): 48 | img = load_as_float(self.imgs[index]) 49 | depth = np.load(self.depth[index]).astype(np.float32) 50 | if self.transform is not None: 51 | img, _ = self.transform([img], None) 52 | img = img[0] 53 | return img, depth 54 | 55 | def __len__(self): 56 | return len(self.imgs) 57 | -------------------------------------------------------------------------------- /depthnet_unravel_bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import models 3 | import argparse 4 | 5 | 6 | parser = argparse.ArgumentParser(description='DepthNet BN to DepthNet converter', 7 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 8 | 9 | parser.add_argument('depth_bn_path', metavar='PATH', 10 | help='path to depthnet bn weights') 11 | parser.add_argument('--depth_path', default='depthnet.pth.tar', metavar='PATH', 12 | help='where to save depthnet weights') 13 | args = parser.parse_args() 14 | eps = 1e-3 15 | depthnet = models.DepthNet().cuda() 16 | 17 | depth_bn = torch.load(args.depth_bn_path) 18 | 19 | depth_bn_state = depth_bn['state_dict'] 20 | 21 | depthnet.load_state_dict(depth_bn_state, strict=False) 22 | 23 | state_dict = depthnet.state_dict() 24 | 25 | for k in depth_bn_state.keys(): 26 | if 'running_mean' in k: 27 | layer, index, _ = k.split('.') 28 | rm = depth_bn_state['.'.join([layer, index, 'running_mean'])] 29 | rv = depth_bn_state['.'.join([layer, index, 'running_var'])] 30 | w = depth_bn_state['.'.join([layer, index, 'weight'])] 31 | b = depth_bn_state['.'.join([layer, index, 'bias'])] 32 | 33 | conv_w = state_dict['.'.join([layer, str(int(index)-1), 'weight'])] 34 | conv_b = state_dict['.'.join([layer, str(int(index)-1), 'bias'])] 35 | 36 | inv_std = (rv + eps).pow(-0.5) 37 | 38 | conv_w.mul_(inv_std.view(conv_w.size(0), 1, 1, 1)) 39 | conv_b.add_(-rm).mul_(inv_std) 40 | conv_w.mul_(w.view(conv_w.size(0), 1, 1, 1)) 41 | conv_b.mul_(w).add_(b) 42 | 43 | depth_bn['state_dict'] = state_dict 44 | depth_bn['bn'] = False 45 | torch.save(depth_bn, args.depth_path) -------------------------------------------------------------------------------- /inverse_warp.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | pixel_coords = None 6 | 7 | 8 | def set_id_grid(depth): 9 | global pixel_coords 10 | b, h, w = depth.size() 11 | i_range = torch.arange(0, h, dtype=depth.dtype, device=depth.device).view(1, h, 1).expand(1, h, w) # [1, H, W] 12 | j_range = torch.arange(0, w, dtype=depth.dtype, device=depth.device).view(1, 1, w).expand(1, h, w) # [1, H, W] 13 | ones = torch.ones(1, h, w, dtype=depth.dtype, device=depth.device) 14 | 15 | pixel_coords = torch.stack((j_range, i_range, ones), dim=1) # [1, 3, H, W] 16 | 17 | 18 | def check_sizes(input, input_name, expected): 19 | condition = [input.ndimension() == len(expected)] 20 | for i,size in 
enumerate(expected): 21 | if size.isdigit(): 22 | condition.append(input.size(i) == int(size)) 23 | assert(all(condition)), "wrong size for {}, expected {}, got {}".format(input_name, 'x'.join(expected), list(input.size())) 24 | 25 | 26 | @torch.jit.script 27 | def compensate_pose(matrices, ref_matrix): 28 | # check_sizes(matrices, 'matrices', 'BS34') 29 | # check_sizes(ref_matrix, 'reference matrix', 'B34') 30 | translation_vectors = matrices[..., -1:] - ref_matrix[..., -1:].unsqueeze(1) 31 | inverse_rot = ref_matrix[..., :-1].transpose(1, 2).unsqueeze(1) 32 | return inverse_rot @ torch.cat([matrices[..., :-1], translation_vectors], dim=-1) 33 | 34 | 35 | @torch.jit.script 36 | def invert_mat(matrices): 37 | # check_sizes(matrices, 'matrices', 'BS34') 38 | rot_matrices = matrices[..., :-1].transpose(2, 3) 39 | translation_vectors = - rot_matrices @ matrices[..., -1:] 40 | return(torch.cat([rot_matrices, translation_vectors], dim=-1)) 41 | 42 | 43 | def pose_vec2mat(vec, rotation_mode='euler'): 44 | """ 45 | Convert 6DoF parameters to transformation matrix. 46 | 47 | Args:s 48 | vec: 6DoF parameters in the order of tx, ty, tz, rx, ry, rz -- [B, 6] 49 | Returns: 50 | A transformation matrix -- [B, 3, 4] 51 | """ 52 | check_sizes(vec, 'rotation vector', 'BS6') 53 | translation = vec[:, :, :3].unsqueeze(-1) # [B, S, 3, 1] 54 | rot = vec[:, :, 3:] 55 | if rotation_mode == 'euler': 56 | rot_mat = euler2mat(rot) # [B, S, 3, 3] 57 | elif rotation_mode == 'quat': 58 | rot_mat = quat2mat(rot) # [B, S, 3, 3] 59 | transform_mat = torch.cat([rot_mat, translation], dim=-1) # [B, S, 3, 4] 60 | return transform_mat 61 | 62 | 63 | def pixel2cam(depth): 64 | """Transform coordinates in the pixel frame to the camera frame. 65 | Args: 66 | depth: depth maps -- [B, H, W] 67 | Returns: 68 | array of (u,v,1) cam coordinates -- [B, 3, H, W] 69 | """ 70 | global pixel_coords 71 | b, h, w = depth.size() 72 | if (pixel_coords is None) or pixel_coords.size(2) < h: 73 | set_id_grid(depth) 74 | pixel_coords.type_as(depth) 75 | cam_coords = pixel_coords[..., :h, :w].expand(b, 3, h, w) * depth.unsqueeze(1) 76 | return cam_coords.contiguous() 77 | 78 | 79 | @torch.jit.script 80 | def cam2pixel(cam_coords): 81 | """Transform coordinates in the camera frame to the pixel frame. 82 | Args: 83 | cam_coords: pixel coordinates defined in the first camera coordinates system -- [B, 4, H, W] 84 | proj_c2p_rot: rotation matrix of cameras -- [B, 3, 4] 85 | Returns: 86 | array of [-1,1] coordinates -- [B, 2, H, W] 87 | """ 88 | b, _, h, w = cam_coords.size() 89 | pcoords = cam_coords.view(b, 3, -1) # [B, 3, H*W] 90 | 91 | X = pcoords[:, 0] 92 | Y = pcoords[:, 1] 93 | Z = pcoords[:, 2].clamp(min=1e-3) 94 | 95 | X_norm = 2*(X / Z)/(w-1) - 1 # Normalized, -1 if on extreme left, 1 if on extreme right (x = w-1) [B, H*W] 96 | Y_norm = 2*(Y / Z)/(h-1) - 1 # Idem [B, H*W] 97 | 98 | pixel_coords = torch.stack([X_norm, Y_norm], dim=2) # [B, H*W, 2] 99 | return pixel_coords.view(b, h, w, 2) 100 | 101 | 102 | @torch.jit.script 103 | def euler2mat(angle): 104 | """Convert euler angles to rotation matrix. 
105 | 106 | Reference: https://github.com/pulkitag/pycaffe-utils/blob/master/rot_utils.py#L174 107 | 108 | Args: 109 | angle: rotation angle along 3 axis (in radians) -- size = [B, S, 3] 110 | Returns: 111 | Rotation matrix corresponding to the euler angles -- size = [B, S, 3, 3] 112 | """ 113 | B, S = angle.size()[:2] 114 | x, y, z = angle[..., 0], angle[..., 1], angle[..., 2] 115 | 116 | cosz = torch.cos(z) 117 | sinz = torch.sin(z) 118 | 119 | zeros = z.detach() * 0 120 | ones = zeros.detach() + 1 121 | zmat = torch.stack([cosz, -sinz, zeros, 122 | sinz, cosz, zeros, 123 | zeros, zeros, ones], dim=-1).view(B, S, 3, 3) 124 | 125 | cosy = torch.cos(y) 126 | siny = torch.sin(y) 127 | 128 | ymat = torch.stack([cosy, zeros, siny, 129 | zeros, ones, zeros, 130 | -siny, zeros, cosy], dim=-1).view(B, S, 3, 3) 131 | 132 | cosx = torch.cos(x) 133 | sinx = torch.sin(x) 134 | 135 | xmat = torch.stack([ones, zeros, zeros, 136 | zeros, cosx, -sinx, 137 | zeros, sinx, cosx], dim=-1).view(B, S, 3, 3) 138 | rotMat = xmat @ ymat @ zmat 139 | return rotMat 140 | 141 | 142 | @torch.jit.script 143 | def quat2mat(quat): 144 | """Convert quaternion coefficients to rotation matrix. 145 | 146 | Args: 147 | quat: first three coeff of quaternion of rotation. fourth is then computed to have a norm of 1 -- size = [B, S, 3] 148 | Returns: 149 | Rotation matrix corresponding to the quaternion -- size = [B, S, 3, 3] 150 | """ 151 | norm_quat = torch.cat([quat[..., :1].detach() * 0 + 1, quat], dim=1) 152 | norm_quat = norm_quat / norm_quat.norm(p=2, dim=-1, keepdim=True) 153 | w, x, y, z = norm_quat[..., 0], norm_quat[..., 1], norm_quat[..., 2], norm_quat[..., 3] 154 | 155 | B, S = quat.size()[:2] 156 | 157 | w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) 158 | wx, wy, wz = w*x, w*y, w*z 159 | xy, xz, yz = x*y, x*z, y*z 160 | 161 | rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz, 162 | 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx, 163 | 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, S, 3, 3) 164 | return rotMat 165 | 166 | 167 | def inverse_warp(img, depth, pose_matrix, intrinsics, rotation_mode='euler'): 168 | """ 169 | Inverse warp a source image to the target image plane. 
170 | 171 | Args: 172 | img: the source image (where to sample pixels) -- [B, 3, H, W] 173 | depth: depth map of the target image -- [B, H, W] 174 | pose: 6DoF pose parameters from target to source -- [B, 6] 175 | intrinsics: camera intrinsic matrix -- [B, 3, 3] 176 | intrinsics_inv: inverse of the intrinsic matrix -- [B, 3, 3] 177 | Returns: 178 | Source image warped to the target image plane 179 | """ 180 | check_sizes(img, 'img', 'B3HW') 181 | check_sizes(depth, 'depth', 'BHW') 182 | check_sizes(pose_matrix, 'pose', 'B34') 183 | check_sizes(intrinsics, 'intrinsics', 'B33') 184 | intrinsics_inv = intrinsics.inverse() 185 | 186 | b, h, w = depth.shape 187 | batch_size, _, img_height, img_width = img.size() 188 | 189 | point_cloud = pixel2cam(depth) # [B,3,H,W] 190 | 191 | # Get projection matrix for tgt camera frame to source pixel frame 192 | rot = intrinsics @ pose_matrix[:,:,:-1] @ intrinsics_inv # [B, 3, 3] 193 | tr = intrinsics @ pose_matrix[:,:,-1:] 194 | 195 | transformed_points = rot @ point_cloud.view(b, 3, -1) + tr 196 | src_pixel_coords = cam2pixel(transformed_points.view(b, 3, h, w)) # [B,H,W,2] 197 | projected_img = F.grid_sample(img, src_pixel_coords, padding_mode='border', align_corners=True) 198 | 199 | with torch.no_grad(): 200 | valid_points = src_pixel_coords.abs().max(dim=-1)[0] <= 1 201 | 202 | return projected_img, valid_points 203 | 204 | 205 | def inverse_rotate(features, rot_matrix, intrinsics, rotation_mode='euler'): 206 | """ 207 | Inverse warp a source image to the target image plane. 208 | 209 | Args: 210 | features: the source image (where to sample pixels) -- [B, C, H, W] 211 | depth: depth map of the target image -- [B, H, W] 212 | pose: 6DoF pose parameters from target to source -- [B, 6] 213 | intrinsics: camera intrinsic matrix -- [B, 3, 3] 214 | intrinsics_inv: inverse of the intrinsic matrix -- [B, 3, 3] 215 | Returns: 216 | Source image warped to the target image plane 217 | """ 218 | check_sizes(features, 'features', 'BCHW') 219 | check_sizes(rot_matrix, 'rotation matrix', 'B33') 220 | check_sizes(intrinsics, 'intrinsics', 'B33') 221 | 222 | b, _, h, w = features.size() 223 | intrinsics_inv = intrinsics.inverse() 224 | 225 | # construct a fake depth, with 1 everywhere 226 | depth = features.new_ones([b, h, w]) 227 | 228 | cam_coords = pixel2cam(depth) # [B,3,H,W] 229 | 230 | # Get projection matrix for tgt camera frame to source pixel frame 231 | rot = intrinsics @ rot_matrix @ intrinsics_inv # [B, 3, 3] 232 | transformed_points = rot @ cam_coords.view(b, 3, -1) 233 | src_pixel_coords = cam2pixel(transformed_points.view(b, 3, h, w)) # [B,H,W,2] 234 | projected_img = F.grid_sample(features, src_pixel_coords, padding_mode='border', align_corners=True) 235 | return projected_img 236 | -------------------------------------------------------------------------------- /kitti_eval/depth_evaluation_utils.py: -------------------------------------------------------------------------------- 1 | # Mostly based on the code written by Clement Godard: 2 | # https://github.com/mrharicot/monodepth/blob/master/utils/evaluation_utils.py 3 | import numpy as np 4 | # import pandas as pd 5 | from collections import Counter 6 | from path import Path 7 | from imageio import imread 8 | from tqdm import tqdm 9 | 10 | 11 | class test_framework_KITTI(object): 12 | def __init__(self, root, test_files, seq_length=3, min_depth=1e-3, max_depth=100, step=1): 13 | self.root = root 14 | self.min_depth, self.max_depth = min_depth, max_depth 15 | self.calib_dirs, self.gt_files, 
self.img_files, self.tgt_indices, self.poses, self.cams = read_scene_data(self.root, test_files, seq_length, step) 16 | 17 | def __getitem__(self, i): 18 | tgt = imread(self.img_files[i][0]).astype(np.float32) 19 | depth = generate_depth_map(self.calib_dirs[i], self.gt_files[i], tgt.shape[:2], self.cams[i]) 20 | intrinsics = read_calib_file(self.calib_dirs[i]/'calib_cam_to_cam.txt')['P_rect_02'].reshape(3,4)[:,:3] 21 | 22 | return {'imgs': [imread(img).astype(np.float32) for img in self.img_files[i]], 23 | 'tgt_index': self.tgt_indices[i], 24 | 'path': self.img_files[i][0], 25 | 'gt_depth': depth, 26 | 'poses': self.poses[i], 27 | 'mask': generate_mask(depth, self.min_depth, self.max_depth), 28 | 'intrinsics': intrinsics 29 | } 30 | 31 | def __len__(self): 32 | return len(self.img_files) 33 | 34 | 35 | ############################################################################### 36 | # EIGEN 37 | 38 | def transform_from_rot_trans(R, t): 39 | """Transforation matrix from rotation matrix and translation vector.""" 40 | R = R.reshape(3, 3) 41 | t = t.reshape(3, 1) 42 | return np.vstack((np.hstack([R, t]), [0, 0, 0, 1])) 43 | 44 | 45 | def rotx(t): 46 | """Rotation about the x-axis.""" 47 | c = np.cos(t) 48 | s = np.sin(t) 49 | return np.array([[1, 0, 0], 50 | [0, c, -s], 51 | [0, s, c]]) 52 | 53 | 54 | def roty(t): 55 | """Rotation about the y-axis.""" 56 | c = np.cos(t) 57 | s = np.sin(t) 58 | return np.array([[c, 0, s], 59 | [0, 1, 0], 60 | [-s, 0, c]]) 61 | 62 | 63 | def rotz(t): 64 | """Rotation about the z-axis.""" 65 | c = np.cos(t) 66 | s = np.sin(t) 67 | return np.array([[c, -s, 0], 68 | [s, c, 0], 69 | [0, 0, 1]]) 70 | 71 | 72 | def pose_from_oxts_packet(metadata): 73 | 74 | lat, lon, alt, roll, pitch, yaw, *_ = metadata 75 | """Helper method to compute a SE(3) pose matrix from an OXTS packet. 76 | Taken from https://github.com/utiasSTARS/pykitti 77 | """ 78 | 79 | er = 6378137. # earth radius (approx.) in meters 80 | # Use a Mercator projection to get the translation vector 81 | ty = lat * np.pi * er / 180. 82 | 83 | scale = np.cos(lat * np.pi / 180.) 84 | tx = scale * lon * np.pi * er / 180. 85 | # ty = scale * er * \ 86 | # np.log(np.tan((90. + lat) * np.pi / 360.)) 87 | tz = alt 88 | t = np.array([tx, ty, tz]).reshape(-1,1) 89 | 90 | # Use the Euler angles to get the rotation matrix 91 | Rx = rotx(roll) 92 | Ry = roty(pitch) 93 | Rz = rotz(yaw) 94 | R = Rz.dot(Ry.dot(Rx)) 95 | return transform_from_rot_trans(R, t) 96 | 97 | 98 | def get_pose_matrices(oxts_root, imu2cam, indices): 99 | matrices = [] 100 | for index in indices: 101 | oxts_data = np.genfromtxt(oxts_root/'data'/'{:010d}.txt'.format(index)) 102 | pose = pose_from_oxts_packet(oxts_data) 103 | odo_pose = pose @ np.linalg.inv(imu2cam) 104 | matrices.append(odo_pose) 105 | matrices_seq = np.stack(matrices) 106 | matrices_seq[:,:3,-1] -= matrices_seq[-1,:3,-1] 107 | matrices_seq = np.linalg.inv(matrices_seq[-1]) @ matrices_seq 108 | return matrices_seq[:,:3].astype(np.float32) 109 | 110 | 111 | def read_scene_data(data_root, test_list, seq_length=3, step=1): 112 | data_root = Path(data_root) 113 | gt_files = [] 114 | calib_dirs = [] 115 | im_files = [] 116 | cams = [] 117 | poses = [] 118 | tgt_indices = [] 119 | shift_range = step * (np.arange(seq_length)) 120 | 121 | print('getting test metadata ... 
') 122 | for sample in tqdm(test_list): 123 | tgt_img_path = data_root/sample 124 | date, scene, cam_id, _, index = sample[:-4].split('/') 125 | index = int(index) 126 | 127 | imu2velo = read_calib_file(data_root/date/'calib_imu_to_velo.txt') 128 | velo2cam = read_calib_file(data_root/date/'calib_velo_to_cam.txt') 129 | cam2cam = read_calib_file(data_root/date/'calib_cam_to_cam.txt') 130 | 131 | velo2cam_mat = transform_from_rot_trans(velo2cam['R'], velo2cam['T']) 132 | imu2velo_mat = transform_from_rot_trans(imu2velo['R'], imu2velo['T']) 133 | cam_2rect_mat = transform_from_rot_trans(cam2cam['R_rect_00'], np.zeros(3)) 134 | 135 | imu2cam = cam_2rect_mat @ velo2cam_mat @ imu2velo_mat 136 | 137 | scene_length = len(tgt_img_path.parent.files('*.png')) 138 | 139 | # if index is high enough, take only frames before. Otherwise, take only frames after. 140 | if index - shift_range[-1] > 0: 141 | ref_indices = index + shift_range - shift_range[-1] 142 | tgt_index = seq_length - 1 143 | elif index + shift_range[-1] < scene_length: 144 | ref_indices = index + shift_range 145 | tgt_index = 0 146 | else: 147 | raise 148 | 149 | imgs_path = [tgt_img_path.dirname()/'{:010d}.png'.format(i) for i in ref_indices] 150 | vel_path = data_root/date/scene/'velodyne_points'/'data'/'{:010d}.bin'.format(index) 151 | 152 | if tgt_img_path.isfile(): 153 | gt_files.append(vel_path) 154 | calib_dirs.append(data_root/date) 155 | im_files.append(imgs_path) 156 | cams.append(int(cam_id[-2:])) 157 | poses.append(get_pose_matrices(data_root/date/scene/'oxts', imu2cam, ref_indices)) 158 | tgt_indices.append(tgt_index) 159 | else: 160 | print('{} missing'.format(tgt_img_path)) 161 | 162 | return calib_dirs, gt_files, im_files, tgt_indices, poses, cams 163 | 164 | 165 | def load_velodyne_points(file_name): 166 | # adapted from https://github.com/hunse/kitti 167 | points = np.fromfile(file_name, dtype=np.float32).reshape(-1, 4) 168 | points[:,3] = 1 169 | return points 170 | 171 | 172 | def read_calib_file(path): 173 | # taken from https://github.com/hunse/kitti 174 | float_chars = set("0123456789.e+- ") 175 | data = {} 176 | with open(path, 'r') as f: 177 | for line in f.readlines(): 178 | key, value = line.split(':', 1) 179 | value = value.strip() 180 | data[key] = value 181 | if float_chars.issuperset(value): 182 | # try to cast to float array 183 | try: 184 | data[key] = np.float32(list(map(float, value.split(' ')))) 185 | except ValueError: 186 | # casting error: data[key] already eq. 
value, so pass 187 | pass 188 | 189 | return data 190 | 191 | 192 | def get_focal_length_baseline(calib_dir, cam=2): 193 | cam2cam = read_calib_file(calib_dir + 'calib_cam_to_cam.txt') 194 | P2_rect = cam2cam['P_rect_02'].reshape(3,4) 195 | P3_rect = cam2cam['P_rect_03'].reshape(3,4) 196 | 197 | # cam 2 is left of camera 0 -6cm 198 | # cam 3 is to the right +54cm 199 | b2 = P2_rect[0,3] / -P2_rect[0,0] 200 | b3 = P3_rect[0,3] / -P3_rect[0,0] 201 | baseline = b3-b2 202 | 203 | if cam == 2: 204 | focal_length = P2_rect[0,0] 205 | elif cam == 3: 206 | focal_length = P3_rect[0,0] 207 | 208 | return focal_length, baseline 209 | 210 | 211 | def sub2ind(matrixSize, rowSub, colSub): 212 | m, n = matrixSize 213 | return rowSub * (n-1) + colSub - 1 214 | 215 | 216 | def generate_depth_map(calib_dir, velo_file_name, im_shape, cam=2): 217 | # load calibration files 218 | cam2cam = read_calib_file(calib_dir/'calib_cam_to_cam.txt') 219 | velo2cam = read_calib_file(calib_dir/'calib_velo_to_cam.txt') 220 | velo2cam = np.hstack((velo2cam['R'].reshape(3,3), velo2cam['T'][..., np.newaxis])) 221 | velo2cam = np.vstack((velo2cam, np.array([0, 0, 0, 1.0]))) 222 | 223 | # compute projection matrix velodyne->image plane 224 | R_cam2rect = np.eye(4) 225 | R_cam2rect[:3,:3] = cam2cam['R_rect_00'].reshape(3,3) 226 | P_rect = cam2cam['P_rect_0'+str(cam)].reshape(3,4) 227 | P_velo2im = np.dot(np.dot(P_rect, R_cam2rect), velo2cam) 228 | 229 | # load velodyne points and remove all behind image plane (approximation) 230 | # each row of the velodyne data is forward, left, up, reflectance 231 | velo = load_velodyne_points(velo_file_name) 232 | velo = velo[velo[:, 0] >= 0, :] 233 | 234 | # project the points to the camera 235 | velo_pts_im = np.dot(P_velo2im, velo.T).T 236 | velo_pts_im[:, :2] = velo_pts_im[:,:2] / velo_pts_im[:,-1:] 237 | 238 | # check if in bounds 239 | # use minus 1 to get the exact same value as KITTI matlab code 240 | velo_pts_im[:, 0] = np.round(velo_pts_im[:,0]) - 1 241 | velo_pts_im[:, 1] = np.round(velo_pts_im[:,1]) - 1 242 | val_inds = (velo_pts_im[:, 0] >= 0) & (velo_pts_im[:, 1] >= 0) 243 | val_inds = val_inds & (velo_pts_im[:,0] < im_shape[1]) & (velo_pts_im[:,1] < im_shape[0]) 244 | velo_pts_im = velo_pts_im[val_inds, :] 245 | 246 | # project to image 247 | depth = np.zeros((im_shape)) 248 | depth[velo_pts_im[:, 1].astype(np.int), velo_pts_im[:, 0].astype(np.int)] = velo_pts_im[:, 2] 249 | 250 | # find the duplicate points and choose the closest depth 251 | inds = sub2ind(depth.shape, velo_pts_im[:, 1], velo_pts_im[:, 0]) 252 | dupe_inds = [item for item, count in Counter(inds).items() if count > 1] 253 | for dd in dupe_inds: 254 | pts = np.where(inds == dd)[0] 255 | x_loc = int(velo_pts_im[pts[0], 0]) 256 | y_loc = int(velo_pts_im[pts[0], 1]) 257 | depth[y_loc, x_loc] = velo_pts_im[pts, 2].min() 258 | depth[depth < 0] = 0 259 | return depth 260 | 261 | 262 | def generate_mask(gt_depth, min_depth, max_depth): 263 | mask = np.logical_and(gt_depth > min_depth, 264 | gt_depth < max_depth) 265 | # crop used by Garg ECCV16 to reprocude Eigen NIPS14 results 266 | # if used on gt_size 370x1224 produces a crop of [-218, -3, 44, 1180] 267 | gt_height, gt_width = gt_depth.shape 268 | crop = np.array([0.40810811 * gt_height, 0.99189189 * gt_height, 269 | 0.03594771 * gt_width, 0.96405229 * gt_width]).astype(np.int32) 270 | 271 | crop_mask = np.zeros(mask.shape) 272 | crop_mask[crop[0]:crop[1],crop[2]:crop[3]] = 1 273 | mask = np.logical_and(mask, crop_mask) 274 | return mask 275 | 
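The evaluation utilities above are only a data provider: for every test frame, `test_framework_KITTI` yields the image snippet, the sparse LiDAR depth, a validity mask (depth-range check plus the Garg crop) and the calibration. The snippet below is a minimal usage sketch, not the repository's actual evaluation loop: `my_predict_depth` is a hypothetical placeholder for whatever inference function you use, the paths are illustrative, and the unknown monocular scale is resolved here by median matching, which is only one option (`test_depth.py` can instead derive the scale from a pretrained PoseNet and a nominal displacement).

```python
import numpy as np

from kitti_eval.depth_evaluation_utils import test_framework_KITTI


def my_predict_depth(img):
    """Hypothetical placeholder: takes an HxWx3 float image and returns an HxW depth
    map at the ground-truth resolution (here, a constant 10 m plane)."""
    return np.full(img.shape[:2], 10.0, dtype=np.float32)


test_files = [line.rstrip() for line in open('kitti_eval/test_files_eigen.txt')]
framework = test_framework_KITTI('/path/to/raw/kitti', test_files,
                                 seq_length=3, min_depth=1e-3, max_depth=80)

abs_rel = []
for sample in framework:  # iterates via __getitem__
    pred = my_predict_depth(sample['imgs'][sample['tgt_index']])
    gt = sample['gt_depth']    # sparse velodyne reprojection, see generate_depth_map()
    mask = sample['mask']      # depth-range check + Garg crop, see generate_mask()
    if not mask.any():
        continue
    scale = np.median(gt[mask]) / np.median(pred[mask])  # resolve the unknown monocular scale
    pred = np.clip(pred * scale, 1e-3, 80)
    abs_rel.append(np.mean(np.abs(gt[mask] - pred[mask]) / gt[mask]))

print('Abs Rel: {:.4f}'.format(np.mean(abs_rel)))
```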
-------------------------------------------------------------------------------- /kitti_eval/pose_evaluation_utils.py: -------------------------------------------------------------------------------- 1 | # Mostly based on the code written by Clement Godard: 2 | # https://github.com/mrharicot/monodepth/blob/master/utils/evaluation_utils.py 3 | import numpy as np 4 | # import pandas as pd 5 | from path import Path 6 | from scipy.misc import imread 7 | from tqdm import tqdm 8 | 9 | width_to_focal = dict() 10 | width_to_focal[1242] = 721.5377 11 | width_to_focal[1241] = 718.856 12 | width_to_focal[1224] = 707.0493 13 | width_to_focal[1238] = 718.3351 14 | 15 | 16 | class test_framework_KITTI(object): 17 | def __init__(self, root, sequence_set, seq_length=3, step=1): 18 | self.root = root 19 | self.img_files, self.poses, self.sample_indices = read_scene_data(self.root, sequence_set, seq_length, step) 20 | 21 | def generator(self): 22 | for img_list, pose_list, sample_list in zip(self.img_files, self.poses, self.sample_indices): 23 | for snippet_indices in sample_list: 24 | imgs = [imread(img_list[i]).astype(np.float32) for i in snippet_indices] 25 | 26 | poses = np.stack(pose_list[i] for i in snippet_indices) 27 | first_pose = poses[0] 28 | poses[:,:,-1] -= first_pose[:,-1] 29 | compensated_poses = np.linalg.inv(first_pose[:,:3]) @ poses 30 | 31 | yield {'imgs': imgs, 32 | 'path': img_list[0], 33 | 'poses': compensated_poses 34 | } 35 | 36 | def __iter__(self): 37 | return self.generator() 38 | 39 | def __len__(self): 40 | return sum(len(imgs) for imgs in self.img_files) 41 | 42 | 43 | def read_scene_data(data_root, sequence_set, seq_length=3, step=1): 44 | data_root = Path(data_root) 45 | im_sequences = [] 46 | poses_sequences = [] 47 | indices_sequences = [] 48 | demi_length = (seq_length - 1) // 2 49 | shift_range = np.array([step*i for i in range(-demi_length, demi_length + 1)]).reshape(1, -1) 50 | 51 | sequences = set() 52 | for seq in sequence_set: 53 | corresponding_dirs = set((data_root/'sequences').dirs(seq)) 54 | sequences = sequences | corresponding_dirs 55 | 56 | print('getting test metadata for theses sequences : {}'.format(sequences)) 57 | for sequence in tqdm(sequences): 58 | poses = np.genfromtxt(data_root/'poses'/'{}.txt'.format(sequence.name)).astype(np.float32).reshape(-1, 3, 4) 59 | imgs = sorted((sequence/'image_2').files('*.png')) 60 | # construct 5-snippet sequences 61 | tgt_indices = np.arange(demi_length, len(imgs) - demi_length).reshape(-1, 1) 62 | snippet_indices = shift_range + tgt_indices 63 | im_sequences.append(imgs) 64 | poses_sequences.append(poses) 65 | indices_sequences.append(snippet_indices) 66 | return im_sequences, poses_sequences, indices_sequences -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | from blessings import Terminal 2 | import progressbar 3 | import sys 4 | 5 | 6 | class TermLogger(object): 7 | def __init__(self, n_epochs, train_size, valid_size): 8 | self.n_epochs = n_epochs 9 | self.train_size = train_size 10 | self.valid_size = valid_size 11 | self.t = Terminal() 12 | s = 10 13 | e = 1 # epoch bar position 14 | tr = 3 # train bar position 15 | ts = 6 # valid bar position 16 | h = self.t.height 17 | 18 | for i in range(10): 19 | print('') 20 | self.epoch_bar = progressbar.ProgressBar(max_value=n_epochs, fd=Writer(self.t, (0, h-s+e))) 21 | 22 | self.train_writer = Writer(self.t, (0, h-s+tr)) 23 | 
self.train_bar_writer = Writer(self.t, (0, h-s+tr+1)) 24 | 25 | self.valid_writer = Writer(self.t, (0, h-s+ts)) 26 | self.valid_bar_writer = Writer(self.t, (0, h-s+ts+1)) 27 | 28 | self.reset_train_bar() 29 | self.reset_valid_bar() 30 | 31 | def reset_train_bar(self, train_size=None): 32 | self.train_bar = progressbar.ProgressBar(max_value=train_size if train_size is not None else self.train_size, 33 | fd=self.train_bar_writer) 34 | 35 | def reset_valid_bar(self): 36 | self.valid_bar = progressbar.ProgressBar(max_value=self.valid_size, fd=self.valid_bar_writer) 37 | 38 | 39 | class Writer(object): 40 | """Create an object with a write method that writes to a 41 | specific place on the screen, defined at instantiation. 42 | 43 | This is the glue between blessings and progressbar. 44 | """ 45 | 46 | def __init__(self, t, location): 47 | """ 48 | Input: location - tuple of ints (x, y), the position 49 | of the bar in the terminal 50 | """ 51 | self.location = location 52 | self.t = t 53 | 54 | def write(self, string): 55 | with self.t.location(*self.location): 56 | sys.stdout.write("\033[K") 57 | print(string) 58 | 59 | def flush(self): 60 | return 61 | 62 | 63 | class AverageMeter(object): 64 | """Computes and stores the average and current value""" 65 | 66 | def __init__(self, i=1, precision=3): 67 | self.meters = i 68 | self.precision = precision 69 | self.reset(self.meters) 70 | 71 | def reset(self, i): 72 | self.val = [0]*i 73 | self.avg = [0]*i 74 | self.sum = [0]*i 75 | self.count = 0 76 | 77 | def update(self, val, n=1): 78 | if not isinstance(val, list): 79 | val = [val] 80 | assert(len(val) == self.meters) 81 | self.count += n 82 | for i,v in enumerate(val): 83 | self.val[i] = v 84 | self.sum[i] += v * n 85 | self.avg[i] = self.sum[i] / self.count 86 | 87 | def __repr__(self): 88 | val = ' '.join(['{:.{}f}'.format(v, self.precision) for v in self.val]) 89 | avg = ' '.join(['{:.{}f}'.format(a, self.precision) for a in self.avg]) 90 | return '{} ({})'.format(val, avg) 91 | -------------------------------------------------------------------------------- /loss_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | from inverse_warp import inverse_warp 4 | import math 5 | # from ssim import SSIM 6 | 7 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 8 | 9 | # ssim_mapper = SSIM(window_size=3) 10 | 11 | 12 | def photometric_reconstruction_loss(imgs, tgt_indices, ref_indices, 13 | depth, pose, intrinsics, 14 | rotation_mode='euler', ssim_weight=0, 15 | upsample=False): 16 | assert(pose.size(1) == imgs.size(1)) 17 | b, _, h, w = depth.size() 18 | loss = torch.tensor(0, dtype=torch.float32, device=device) 19 | if b == 0: 20 | return loss, None, None 21 | batch_range = torch.arange(b, dtype=torch.int64, device=device) 22 | 23 | b, s, c, hi, wi = imgs.size() 24 | 25 | assert(hi >= h and wi >= w), "Depth size is greater than img size, which is probably not what you want" 26 | if upsample: 27 | imgs_scaled = imgs 28 | intrinsics_scaled = intrinsics 29 | else: 30 | downscale = hi/h 31 | imgs_scaled = F.interpolate(imgs, (c, h, w), mode='area') 32 | intrinsics_scaled = torch.cat((intrinsics[:, 0:2]/downscale, intrinsics[:, 2:]), dim=1) 33 | 34 | tgt_img_scaled = imgs_scaled[batch_range, tgt_indices] 35 | 36 | warped_results, diff, dssim, valid = [], [], [], [] 37 | 38 | for i in range(s - 1): 39 | idx = ref_indices[:, i] 40 | current_pose = pose[batch_range, 
idx] 41 | ref_img = imgs[batch_range, idx] 42 | ref_img_warped, valid_points = inverse_warp(ref_img, 43 | depth[:,0], 44 | current_pose, 45 | intrinsics_scaled, 46 | rotation_mode) 47 | 48 | dssim_loss_map = (0.5*(1-ssim(tgt_img_scaled + 1, ref_img_warped + 1))).clamp(0,1) if ssim_weight > 0 else 0 49 | 50 | diff_map = tgt_img_scaled - ref_img_warped 51 | 52 | loss_map = ssim_weight * dssim_loss_map + (1-ssim_weight) * diff_map.abs() 53 | 54 | valid_loss_values = loss_map.masked_select(valid_points.unsqueeze(1)) 55 | if valid_loss_values.numel() > 0: 56 | loss += valid_loss_values.mean() 57 | 58 | warped_results.append(ref_img_warped[0]) 59 | dssim.append(dssim_loss_map[0]) 60 | diff.append(diff_map[0]) 61 | valid.append(valid_points[0]) 62 | return loss, warped_results, diff, dssim, valid 63 | 64 | 65 | grad_kernel = torch.FloatTensor([[ 1, 2, 1], 66 | [ 0, 0, 0], 67 | [-1,-2,-1]]).view(1,1,3,3).to(device)/4 68 | grad_img_kernel = grad_kernel.expand(3,1,3,3).contiguous() 69 | lapl_kernel = torch.FloatTensor([[-1,-2,-1], 70 | [-2,12,-2], 71 | [-1,-2,-1]]).view(1,1,3,3).to(device)/12 72 | 73 | 74 | def create_gaussian_window(window_size, channel): 75 | def _gaussian(window_size, sigma): 76 | gauss = torch.Tensor([math.exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 77 | return gauss/gauss.sum() 78 | _1D_window = _gaussian(window_size, 1.5).unsqueeze(1) 79 | _2D_window = _1D_window @ (_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 80 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() 81 | return window 82 | 83 | 84 | window_size = 3 85 | gaussian_img_kernel = create_gaussian_window(window_size, 3).float().to(device) 86 | 87 | 88 | def grad_diffusion_loss(pred_disp, img=None, kappa=0.1): 89 | if type(pred_disp) not in [tuple, list]: 90 | pred_disp = [pred_disp] 91 | 92 | loss = 0 93 | weight = 1. 
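    # For each predicted map in pred_disp, penalise its second-order differences dx2/dy2.
    # When an image is given, the penalties are gated by gx/gy = exp(-(|dI|/kappa)^2), a
    # Perona-Malik-like diffusion coefficient computed from the image gradients, so depth
    # discontinuities are tolerated where the image itself has strong edges. Each successive
    # scale contributes with half the weight of the previous one.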
94 | 95 | for scaled_disp in pred_disp: 96 | b, _, h, w = scaled_disp.shape 97 | if img is not None: 98 | with torch.no_grad(): 99 | img_scaled = F.interpolate(img, (h, w), mode='area').norm(p=1, dim=1, keepdim=True) 100 | dx_i = img_scaled[:, :, 2:] - img_scaled[:, :, :-2] 101 | dy_i = img_scaled[:, :, :, 2:] - img_scaled[:, :, :, :-2] 102 | gx = torch.exp(-(dx_i.abs()/kappa)**2) 103 | gy = torch.exp(-(dy_i.abs()/kappa)**2) 104 | else: 105 | gx = gy = 1 106 | 107 | dx2 = scaled_disp[:,:, 2:] - 2 * scaled_disp[:,:,1:-1] + scaled_disp[:,:,:-2] 108 | dy2 = scaled_disp[:,:,:, 2:] - 2 * scaled_disp[:,:,:,1:-1] + scaled_disp[:,:,:,:-2] 109 | dx2 *= gx 110 | dy2 *= gy 111 | loss += (dx2.pow(2).mean() + dy2.pow(2).mean()) * weight 112 | weight /= 2 113 | return loss 114 | 115 | 116 | def ssim(img1, img2): 117 | params = {'weight': gaussian_img_kernel, 'groups':3, 'padding':window_size//2} 118 | mu1 = F.conv2d(img1, **params) 119 | mu2 = F.conv2d(img2, **params) 120 | 121 | mu1_sq = mu1.pow(2) 122 | mu2_sq = mu2.pow(2) 123 | mu1_mu2 = mu1*mu2 124 | 125 | sigma1_sq = F.conv2d(img1*img1, **params) - mu1_sq 126 | sigma2_sq = F.conv2d(img2*img2, **params) - mu2_sq 127 | sigma12 = F.conv2d(img1*img2, **params) - mu1_mu2 128 | 129 | C1 = 0.01**2 130 | C2 = 0.03**2 131 | 132 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 133 | return ssim_map 134 | 135 | 136 | @torch.no_grad() 137 | def compute_depth_errors(gt, pred, max_depth=80, crop=True): 138 | abs_diff, abs_rel, abs_log, a1, a2, a3 = 0,0,0,0,0,0 139 | b, h, w = gt.size() 140 | if pred.size(1) != h: 141 | pred_upscaled = F.interpolate(pred, (h, w), mode='bilinear', align_corners=False)[:,0] 142 | else: 143 | pred_upscaled = pred[0:,] 144 | 145 | ''' 146 | crop used by Garg ECCV16 to reprocude Eigen NIPS14 results 147 | construct a mask of False values, with the same size as target 148 | and then set to True values inside the crop 149 | ''' 150 | if crop: 151 | crop_mask = gt[0] != gt[0] 152 | y1,y2 = int(0.40810811 * gt.size(1)), int(0.99189189 * gt.size(1)) 153 | x1,x2 = int(0.03594771 * gt.size(2)), int(0.96405229 * gt.size(2)) 154 | crop_mask[y1:y2,x1:x2] = 1 155 | 156 | skipped = 0 157 | for current_gt, current_pred in zip(gt, pred_upscaled): 158 | valid = (current_gt > 0) & (current_gt < max_depth) 159 | if crop: 160 | valid = valid & crop_mask 161 | if valid.sum() == 0: 162 | skipped += 1 163 | continue 164 | valid_gt = current_gt[valid] 165 | valid_pred = current_pred[valid].clamp(1e-3, max_depth) 166 | 167 | thresh = torch.max((valid_gt / valid_pred), (valid_pred / valid_gt)) 168 | a1 += (thresh < 1.25).float().mean() 169 | a2 += (thresh < 1.25 ** 2).float().mean() 170 | a3 += (thresh < 1.25 ** 3).float().mean() 171 | 172 | abs_diff += torch.mean(torch.abs(valid_gt - valid_pred)) 173 | abs_rel += torch.mean(torch.abs(valid_gt - valid_pred) / valid_gt) 174 | 175 | abs_log += torch.mean(torch.abs(torch.log(valid_gt) - torch.log(valid_pred))) 176 | if skipped == b: 177 | return None 178 | else: 179 | return [metric / (b - skipped) for metric in [abs_diff, abs_rel, abs_log, a1, a2, a3]] 180 | 181 | 182 | @torch.no_grad() 183 | def compute_pose_error(gt, pred): 184 | ATE = 0 185 | RE = 0 186 | batch_size, seq_length = gt.size()[:2] 187 | for gt_pose_seq, pred_pose_seq in zip(gt, pred): 188 | scale_factor = (gt_pose_seq[:,:,-1] * pred_pose_seq[:,:,-1]).sum()/(pred_pose_seq[:,:,-1] ** 2).sum() 189 | for gt_pose, pred_pose in zip(gt_pose_seq, pred_pose_seq): 190 | ATE += ((gt_pose[:,-1] - 
scale_factor * pred_pose[:,-1]).norm(p=2))/seq_length 191 | 192 | # Residual matrix to which we compute angle's sin and cos 193 | R = gt_pose[:,:3] @ pred_pose[:,:3].inverse() 194 | s = torch.stack([R[0,1]-R[1,0],R[1,2]-R[2,1],R[0,2]-R[2,0]]).norm(p=2) 195 | c = R.trace() - 1 196 | 197 | # Note: we actually compute double of cos and sin, but arctan2 is invariant to scale 198 | RE += torch.atan2(s,c)/seq_length 199 | 200 | return [ATE/batch_size, RE/batch_size] 201 | -------------------------------------------------------------------------------- /models/DepthNet.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from models.utils import conv, deconv, predict_depth, post_process_depth, adaptative_cat, init_modules 5 | 6 | 7 | class DepthNet(nn.Module): 8 | 9 | def __init__(self, batch_norm=False, clamp=False, depth_activation=None, input_size=None, upscale=False): 10 | super(DepthNet, self).__init__() 11 | 12 | self.clamp = clamp 13 | if depth_activation == 'elu': 14 | self.depth_activation = lambda x: nn.functional.elu(x) + 1.01 15 | elif depth_activation == 'sigmoid': 16 | self.depth_activation = lambda x: 100 * nn.functional.sigmoid(0.1*x) + 0.01 17 | else: 18 | self.depth_activation = depth_activation 19 | 20 | self.input_size = input_size 21 | self.upscale = upscale 22 | 23 | self.conv1 = conv( 6, 32, stride=2, batch_norm=batch_norm) 24 | self.conv2 = conv( 32, 64, stride=2, batch_norm=batch_norm) 25 | self.conv3 = conv( 64, 128, stride=2, batch_norm=batch_norm) 26 | self.conv3_1 = conv(128, 128, batch_norm=batch_norm) 27 | self.conv4 = conv(128, 256, stride=2, batch_norm=batch_norm) 28 | self.conv4_1 = conv(256, 256, batch_norm=batch_norm) 29 | self.conv5 = conv(256, 256, stride=2, batch_norm=batch_norm) 30 | self.conv5_1 = conv(256, 256, batch_norm=batch_norm) 31 | self.conv6 = conv(256, 512, stride=2, batch_norm=batch_norm) 32 | self.conv6_1 = conv(512, 512, batch_norm=batch_norm) 33 | 34 | self.deconv5 = deconv(512, 256, batch_norm=batch_norm) 35 | self.deconv4 = deconv(513, 128, batch_norm=batch_norm) 36 | self.deconv3 = deconv(385, 64, batch_norm=batch_norm) 37 | self.deconv2 = deconv(193, 32, batch_norm=batch_norm) 38 | 39 | self.predict_depth6 = predict_depth(512) 40 | self.predict_depth5 = predict_depth(513) 41 | self.predict_depth4 = predict_depth(385) 42 | self.predict_depth3 = predict_depth(193) 43 | self.predict_depth2 = predict_depth( 97) 44 | 45 | self.upsampled_depth6_to_5 = nn.ConvTranspose2d(1, 1, 4, 2, 1, bias=False) 46 | self.upsampled_depth5_to_4 = nn.ConvTranspose2d(1, 1, 4, 2, 1, bias=False) 47 | self.upsampled_depth4_to_3 = nn.ConvTranspose2d(1, 1, 4, 2, 1, bias=False) 48 | self.upsampled_depth3_to_2 = nn.ConvTranspose2d(1, 1, 4, 2, 1, bias=False) 49 | 50 | init_modules(self) 51 | 52 | def forward(self, x): 53 | *_, ih, iw = x.shape 54 | if self.input_size: 55 | h,w = self.input_size 56 | x = F.interpolate(x,(h, w), mode='area') 57 | out_conv2 = self.conv2(self.conv1(x)) 58 | out_conv3 = self.conv3_1(self.conv3(out_conv2)) 59 | out_conv4 = self.conv4_1(self.conv4(out_conv3)) 60 | out_conv5 = self.conv5_1(self.conv5(out_conv4)) 61 | out_conv6 = self.conv6_1(self.conv6(out_conv5)) 62 | 63 | out6 = self.predict_depth6(out_conv6) 64 | #depth6 = post_process_depth(out6, self.depth_activation) 65 | depth6_up = self.upsampled_depth6_to_5(out6) 66 | out_deconv5 = self.deconv5(out_conv6) 67 | 68 | concat5 = adaptative_cat(out_conv5, 
out_deconv5, depth6_up) 69 | out5 = self.predict_depth5(concat5) 70 | #depth5 = post_process_depth(out5, self.depth_activation) 71 | depth5_up = self.upsampled_depth5_to_4(out5) 72 | out_deconv4 = self.deconv4(concat5) 73 | 74 | concat4 = adaptative_cat(out_conv4, out_deconv4, depth5_up) 75 | out4 = self.predict_depth4(concat4) 76 | depth4 = post_process_depth(out4, self.depth_activation) 77 | depth4_up = self.upsampled_depth4_to_3(out4) 78 | out_deconv3 = self.deconv3(concat4) 79 | 80 | concat3 = adaptative_cat(out_conv3, out_deconv3, depth4_up) 81 | out3 = self.predict_depth3(concat3) 82 | depth3 = post_process_depth(out3, self.depth_activation) 83 | depth3_up = self.upsampled_depth3_to_2(out3) 84 | out_deconv2 = self.deconv2(concat3) 85 | 86 | concat2 = adaptative_cat(out_conv2, out_deconv2, depth3_up) 87 | out2 = self.predict_depth2(concat2) 88 | depth2 = post_process_depth(out2, self.depth_activation) 89 | depth0 = F.interpolate(depth2, (ih, iw), mode='bilinear', align_corners=False) 90 | 91 | if self.training: 92 | if self.upscale: 93 | depth4_upscaled = F.interpolate(depth4, (ih, iw), mode='bilinear', align_corners=False) 94 | depth3_upscaled = F.interpolate(depth3, (ih, iw), mode='bilinear', align_corners=False) 95 | return depth0, depth3_upscaled, depth4_upscaled 96 | else: 97 | return depth0, depth2, depth3, depth4 98 | else: 99 | return depth0 100 | -------------------------------------------------------------------------------- /models/DispNetS.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from models.utils import init_modules 5 | 6 | 7 | def downsample_conv(in_planes, out_planes, kernel_size=3): 8 | return nn.Sequential( 9 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, padding=(kernel_size-1)//2), 10 | nn.ReLU(inplace=True), 11 | nn.Conv2d(out_planes, out_planes, kernel_size=kernel_size, padding=(kernel_size-1)//2), 12 | nn.ReLU(inplace=True) 13 | ) 14 | 15 | 16 | def predict_disp(in_planes): 17 | return nn.Sequential( 18 | nn.Conv2d(in_planes, 1, kernel_size=3, padding=1), 19 | nn.Sigmoid() 20 | ) 21 | 22 | 23 | def conv(in_planes, out_planes): 24 | return nn.Sequential( 25 | nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=1), 26 | nn.ReLU(inplace=True) 27 | ) 28 | 29 | 30 | def upconv(in_planes, out_planes): 31 | return nn.Sequential( 32 | nn.ConvTranspose2d(in_planes, out_planes, kernel_size=3, stride=2, padding=1, output_padding=1), 33 | nn.ReLU(inplace=True) 34 | ) 35 | 36 | 37 | def crop_like(input, ref): 38 | assert(input.size(2) >= ref.size(2) and input.size(3) >= ref.size(3)) 39 | return input[:, :, :ref.size(2), :ref.size(3)] 40 | 41 | 42 | class DispNetS(nn.Module): 43 | 44 | def __init__(self, alpha=10, beta=0.01): 45 | super(DispNetS, self).__init__() 46 | 47 | self.alpha = alpha 48 | self.beta = beta 49 | 50 | conv_planes = [32, 64, 128, 256, 512, 512, 512] 51 | self.conv1 = downsample_conv(6, conv_planes[0], kernel_size=7) 52 | self.conv2 = downsample_conv(conv_planes[0], conv_planes[1], kernel_size=5) 53 | self.conv3 = downsample_conv(conv_planes[1], conv_planes[2]) 54 | self.conv4 = downsample_conv(conv_planes[2], conv_planes[3]) 55 | self.conv5 = downsample_conv(conv_planes[3], conv_planes[4]) 56 | self.conv6 = downsample_conv(conv_planes[4], conv_planes[5]) 57 | self.conv7 = downsample_conv(conv_planes[5], conv_planes[6]) 58 | 59 | upconv_planes = [512, 512, 256, 128, 64, 32, 16] 60 | self.upconv7 = 
upconv(conv_planes[6], upconv_planes[0]) 61 | self.upconv6 = upconv(upconv_planes[0], upconv_planes[1]) 62 | self.upconv5 = upconv(upconv_planes[1], upconv_planes[2]) 63 | self.upconv4 = upconv(upconv_planes[2], upconv_planes[3]) 64 | self.upconv3 = upconv(upconv_planes[3], upconv_planes[4]) 65 | self.upconv2 = upconv(upconv_planes[4], upconv_planes[5]) 66 | self.upconv1 = upconv(upconv_planes[5], upconv_planes[6]) 67 | 68 | self.iconv7 = conv(upconv_planes[0] + conv_planes[5], upconv_planes[0]) 69 | self.iconv6 = conv(upconv_planes[1] + conv_planes[4], upconv_planes[1]) 70 | self.iconv5 = conv(upconv_planes[2] + conv_planes[3], upconv_planes[2]) 71 | self.iconv4 = conv(upconv_planes[3] + conv_planes[2], upconv_planes[3]) 72 | self.iconv3 = conv(1 + upconv_planes[4] + conv_planes[1], upconv_planes[4]) 73 | self.iconv2 = conv(1 + upconv_planes[5] + conv_planes[0], upconv_planes[5]) 74 | self.iconv1 = conv(1 + upconv_planes[6], upconv_planes[6]) 75 | 76 | self.predict_disp4 = predict_disp(upconv_planes[3]) 77 | self.predict_disp3 = predict_disp(upconv_planes[4]) 78 | self.predict_disp2 = predict_disp(upconv_planes[5]) 79 | self.predict_disp1 = predict_disp(upconv_planes[6]) 80 | 81 | init_modules(self) 82 | 83 | def forward(self, x): 84 | out_conv1 = self.conv1(x) 85 | out_conv2 = self.conv2(out_conv1) 86 | out_conv3 = self.conv3(out_conv2) 87 | out_conv4 = self.conv4(out_conv3) 88 | out_conv5 = self.conv5(out_conv4) 89 | out_conv6 = self.conv6(out_conv5) 90 | out_conv7 = self.conv7(out_conv6) 91 | 92 | out_upconv7 = crop_like(self.upconv7(out_conv7), out_conv6) 93 | concat7 = torch.cat((out_upconv7, out_conv6), 1) 94 | out_iconv7 = self.iconv7(concat7) 95 | 96 | out_upconv6 = crop_like(self.upconv6(out_iconv7), out_conv5) 97 | concat6 = torch.cat((out_upconv6, out_conv5), 1) 98 | out_iconv6 = self.iconv6(concat6) 99 | 100 | out_upconv5 = crop_like(self.upconv5(out_iconv6), out_conv4) 101 | concat5 = torch.cat((out_upconv5, out_conv4), 1) 102 | out_iconv5 = self.iconv5(concat5) 103 | 104 | out_upconv4 = crop_like(self.upconv4(out_iconv5), out_conv3) 105 | concat4 = torch.cat((out_upconv4, out_conv3), 1) 106 | out_iconv4 = self.iconv4(concat4) 107 | disp4 = self.alpha * self.predict_disp4(out_iconv4) + self.beta 108 | 109 | out_upconv3 = crop_like(self.upconv3(out_iconv4), out_conv2) 110 | disp4_up = crop_like(F.interpolate(disp4, scale_factor=2, mode='bilinear', align_corners=False), out_conv2) 111 | concat3 = torch.cat((out_upconv3, out_conv2, disp4_up), 1) 112 | out_iconv3 = self.iconv3(concat3) 113 | disp3 = self.alpha * self.predict_disp3(out_iconv3) + self.beta 114 | 115 | out_upconv2 = crop_like(self.upconv2(out_iconv3), out_conv1) 116 | disp3_up = crop_like(F.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1) 117 | concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1) 118 | out_iconv2 = self.iconv2(concat2) 119 | disp2 = self.alpha * self.predict_disp2(out_iconv2) + self.beta 120 | 121 | out_upconv1 = crop_like(self.upconv1(out_iconv2), x) 122 | disp2_up = crop_like(F.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x) 123 | concat1 = torch.cat((out_upconv1, disp2_up), 1) 124 | out_iconv1 = self.iconv1(concat1) 125 | disp1 = self.alpha * self.predict_disp1(out_iconv1) + self.beta 126 | 127 | if self.training: 128 | return disp1, disp2, disp3, disp4 129 | else: 130 | return disp1 131 | -------------------------------------------------------------------------------- /models/PoseNet.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .utils import conv,init_modules 5 | 6 | 7 | class PoseNet(nn.Module): 8 | 9 | def __init__(self, seq_length=3, batch_norm=False, input_size=None): 10 | super(PoseNet, self).__init__() 11 | self.seq_length = seq_length 12 | self.input_size = input_size 13 | 14 | conv_planes = [16, 32, 64, 128, 256, 256, 256] 15 | self.conv1 = conv(3*self.seq_length, conv_planes[0], kernel_size=7, batch_norm=batch_norm, stride=2) 16 | self.conv2 = conv(conv_planes[0], conv_planes[1], kernel_size=5, batch_norm=batch_norm, stride=2) 17 | self.conv3 = conv(conv_planes[1], conv_planes[2], batch_norm=batch_norm, stride=2) 18 | self.conv4 = conv(conv_planes[2], conv_planes[3], batch_norm=batch_norm, stride=2) 19 | self.conv5 = conv(conv_planes[3], conv_planes[4], batch_norm=batch_norm, stride=2) 20 | self.conv6 = conv(conv_planes[4], conv_planes[5], batch_norm=batch_norm, stride=2) 21 | self.conv7 = conv(conv_planes[5], conv_planes[6], batch_norm=batch_norm, stride=2) 22 | 23 | self.pose_pred = nn.Conv2d(conv_planes[6], 6*(self.seq_length - 1), kernel_size=1, padding=0) 24 | init_modules(self) 25 | 26 | def forward(self, img_sequence): 27 | b, s, c, h, w = img_sequence.size() 28 | concatenated_imgs = img_sequence.view(b, s*c, h, w) 29 | 30 | if self.input_size: 31 | h,w = self.input_size 32 | concatenated_imgs = F.interpolate(concatenated_imgs,(h, w), mode='area') 33 | 34 | out_conv1 = self.conv1(concatenated_imgs) 35 | out_conv2 = self.conv2(out_conv1) 36 | out_conv3 = self.conv3(out_conv2) 37 | out_conv4 = self.conv4(out_conv3) 38 | out_conv5 = self.conv5(out_conv4) 39 | out_conv6 = self.conv6(out_conv5) 40 | out_conv7 = self.conv7(out_conv6) 41 | 42 | pose = self.pose_pred(out_conv7) 43 | pose = pose.mean(3).mean(2) 44 | pose = pose.view(pose.size(0), self.seq_length - 1, 6) 45 | pose = 0.1 * torch.cat([pose, pose[:,:1].detach()*0], dim=1) # last frame is the Neutral position 46 | 47 | return pose -------------------------------------------------------------------------------- /models/UpSampleNet.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class UpSampleNet(nn.Module): 7 | 8 | def __init__(self, network, input_size=None): 9 | super(UpSampleNet, self).__init__() 10 | self.network = network 11 | self.input_size = input_size 12 | 13 | def forward(self, x): 14 | x_size = x.size()[-2:] 15 | if self.input_size is None: 16 | self.input_size = x_size 17 | downscaled_x = F.interpolate(x, self.input_size, mode='area') 18 | output = self.network(downscaled_x) 19 | 20 | if isinstance(output, tuple): 21 | return (F.interpolate(output[0], x_size, mode='bilinear', align_corners=False), *output) 22 | 23 | else: 24 | return F.interpolate(output, x_size, mode='bilinear', align_corners=False) -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .DispNetS import DispNetS 2 | from .PoseNet import PoseNet 3 | from .DepthNet import DepthNet 4 | from .UpSampleNet import UpSampleNet 5 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ 
import division 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.init import kaiming_normal_, zeros_, constant_ 6 | 7 | 8 | def conv(in_planes, out_planes, kernel_size=3, stride=1, batch_norm=False): 9 | if batch_norm: 10 | return nn.Sequential( 11 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=False), 12 | nn.BatchNorm2d(out_planes, eps=1e-3), 13 | nn.ReLU(inplace=True) 14 | ) 15 | else: 16 | return nn.Sequential( 17 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True), 18 | nn.ReLU(inplace=True) 19 | ) 20 | 21 | 22 | def deconv(in_planes, out_planes, batch_norm=False): 23 | if batch_norm: 24 | return nn.Sequential( 25 | nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True), 26 | nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False), 27 | nn.BatchNorm2d(out_planes, eps=1e-3), 28 | nn.ReLU(inplace=True) 29 | ) 30 | else: 31 | return nn.Sequential( 32 | nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True), 33 | nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=True), 34 | nn.ReLU(inplace=True) 35 | ) 36 | 37 | 38 | def predict_depth(in_planes): 39 | return nn.Conv2d(in_planes, 1, kernel_size=3, stride=1, padding=1, bias=True) 40 | 41 | 42 | def post_process_depth(out, activation_function=None): 43 | depth = out 44 | if activation_function is not None: 45 | depth = activation_function(depth) 46 | 47 | return depth 48 | 49 | 50 | def adaptative_cat(out_conv, out_deconv, out_depth_up): 51 | out_deconv = out_deconv[:, :, :out_conv.size(2), :out_conv.size(3)] 52 | out_depth_up = out_depth_up[:, :, :out_conv.size(2), :out_conv.size(3)] 53 | return torch.cat((out_conv, out_deconv, out_depth_up), 1) 54 | 55 | 56 | def init_modules(net): 57 | for m in net.modules(): 58 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 59 | kaiming_normal_(m.weight, nonlinearity='relu') 60 | if m.bias is not None: 61 | zeros_(m.bias) 62 | elif isinstance(m, nn.BatchNorm2d): 63 | constant_(m.weight, 1) 64 | zeros_(m.bias) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scipy 2 | imageio 3 | argparse 4 | tensorboardX 5 | blessings 6 | progressbar2 7 | scikit-image 8 | path.py -------------------------------------------------------------------------------- /ssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.jit import ScriptModule, script_method, trace 4 | import math 5 | 6 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 7 | 8 | 9 | def create_gaussian_window(window_size, channel): 10 | def _gaussian(window_size, sigma): 11 | gauss = torch.Tensor([math.exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 12 | return gauss/gauss.sum() 13 | _1D_window = _gaussian(window_size, 1.5).unsqueeze(1) 14 | _2D_window = _1D_window@(_1D_window.t()).float() 15 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() 16 | return window 17 | 18 | 19 | class SSIM(ScriptModule): 20 | def __init__(self, window_size=3): 21 | super(SSIM, self).__init__() 22 | 23 | gaussian_img_kernel = {'weight': create_gaussian_window(window_size, 3).float(), 24 | 
'bias': torch.zeros(3)} 25 | gaussian_blur = nn.Conv2d(3,3,window_size, padding=window_size//2, groups=3).to(device) 26 | gaussian_blur.load_state_dict(gaussian_img_kernel) 27 | self.gaussian_blur = trace(gaussian_blur, torch.rand(3, 3, 16, 16, dtype=torch.float32, device=device)) 28 | 29 | @script_method 30 | def forward(self, img1, img2): 31 | mu1 = self.gaussian_blur(img1) 32 | mu2 = self.gaussian_blur(img2) 33 | 34 | mu1_sq = mu1.pow(2) 35 | mu2_sq = mu2.pow(2) 36 | mu1_mu2 = mu1*mu2 37 | 38 | sigma1_sq = self.gaussian_blur(img1*img1) - mu1_sq 39 | sigma2_sq = self.gaussian_blur(img2*img2) - mu2_sq 40 | sigma12 = self.gaussian_blur(img1*img2) - mu1_mu2 41 | 42 | C1 = 0.01**2 43 | C2 = 0.03**2 44 | 45 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 46 | return ssim_map -------------------------------------------------------------------------------- /stillbox_eval/depth_evaluation_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | from path import Path 4 | from scipy.misc import imread 5 | from tqdm import tqdm 6 | 7 | 8 | class test_framework_stillbox(object): 9 | def __init__(self, root, test_files, seq_length=3, min_depth=1e-3, max_depth=80, step=1): 10 | self.root = root 11 | self.min_depth, self.max_depth = min_depth, max_depth 12 | self.gt_files, self.img_files, self.tgt_indices, self.poses, self.intrinsics = read_scene_data(self.root, test_files, seq_length, step) 13 | 14 | def __getitem__(self, i): 15 | depth = np.load(self.gt_files[i]) 16 | return {'imgs': [imread(img).astype(np.float32) for img in self.img_files[i]], 17 | 'tgt_index': self.tgt_indices[i], 18 | 'path':self.img_files[i][0], 19 | 'gt_depth': depth, 20 | 'poses': self.poses[i], 21 | 'mask': generate_mask(depth, self.min_depth, self.max_depth), 22 | 'intrinsics': self.intrinsics 23 | } 24 | 25 | def __len__(self): 26 | return len(self.img_files) 27 | 28 | 29 | def quat2mat(quat): 30 | w, x, y, z = quat[:,0], quat[:,1], quat[:,2], quat[:,3] 31 | w2, x2, y2, z2 = w**2, x**2, y**2, z**2 32 | wx, wy, wz = w*x, w*y, w*z 33 | xy, xz, yz = x*y, x*z, y*z 34 | 35 | rotMat = np.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz, 36 | 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx, 37 | 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], axis=1).reshape(quat.shape[0], 3, 3) 38 | return rotMat 39 | 40 | 41 | def transform_from_rot_trans(R, t): 42 | """Transforation matrix from rotation matrix and translation vector.""" 43 | sl = R.shape[0] 44 | R = R.reshape(sl, 3, 3) 45 | t = t.reshape(sl, 3, 1) 46 | filler = np.array([0,0,0,1]).reshape(1,1,4).repeat(sl, axis=0) 47 | basis = np.concatenate([R,t], axis=-1) 48 | final_matrix = np.concatenate([basis, filler], axis=1) 49 | return final_matrix 50 | 51 | 52 | def get_poses(scene, indices): 53 | nominal_displacement = np.array(scene['speed']) * scene['time_step'] 54 | sl = len(indices) 55 | if len(scene['orientation']) == 0: 56 | scene_quaternions = np.array([1,0,0,0]).reshape(1,4).repeat(sl, axis=0) 57 | else: 58 | scene_quaternions = np.array(scene['orientation'])[indices] 59 | t = np.array(indices).reshape(sl, 1) * nominal_displacement 60 | R = quat2mat(scene_quaternions).reshape(sl, 3, 3) 61 | matrices_seq = transform_from_rot_trans(R,t) 62 | matrices_seq = np.linalg.inv(matrices_seq[-1]) @ matrices_seq 63 | return matrices_seq[:,:3].astype(np.float32) 64 | 65 | 66 | def read_scene_data(data_root, test_list, seq_length=3, step=1): 67 | data_root = 
Path(data_root) 68 | metadata_files = {} 69 | intrinsics = None 70 | for folder in data_root.dirs(): 71 | with open(folder/'metadata.json', 'r') as f: 72 | metadata_files[str(folder.name)] = json.load(f) 73 | if intrinsics is None: 74 | args = metadata_files[str(folder.name)]['args'] 75 | hfov = args['fov'] 76 | w,h = args['resolution'] 77 | f = w/(2*np.tan(np.pi*hfov/360)) 78 | intrinsics = np.array([[f, 0, w/2], 79 | [0, f, h/2], 80 | [0, 0, 1]]).astype(np.float32) 81 | gt_files = [] 82 | im_files = [] 83 | poses = [] 84 | tgt_indices = [] 85 | shift_range = step * (np.arange(seq_length)) 86 | 87 | print('getting test metadata ... ') 88 | for sample in tqdm(test_list): 89 | folder, file = sample.split('/') 90 | _, scene_index, index = file[:-4].split('_') # filename is in the form 'RGB_XXXX_XX.jpg' 91 | index = int(index) 92 | scene = metadata_files[folder]['scenes'][int(scene_index)] 93 | scene_length = len(scene['imgs']) 94 | tgt_img_path = data_root/sample 95 | folder_path = data_root/folder 96 | if tgt_img_path.isfile(): 97 | # if index is high enough, take only frames before. Otherwise, take only frames after. 98 | if index - shift_range[-1] > 0: 99 | ref_indices = index + shift_range - shift_range[-1] 100 | tgt_index = seq_length - 1 101 | elif index + shift_range[-1] < scene_length: 102 | ref_indices = index + shift_range 103 | tgt_index = 0 104 | else: 105 | raise 106 | tgt_indices.append(tgt_index) 107 | imgs_path = [folder_path/'{}'.format(scene['imgs'][ref_index]) for ref_index in ref_indices] 108 | 109 | gt_files.append(folder_path/'{}'.format(scene['depth'][index])) 110 | im_files.append(imgs_path) 111 | poses.append(get_poses(scene, ref_indices)) 112 | else: 113 | print('{} missing'.format(tgt_img_path)) 114 | 115 | return gt_files, im_files, tgt_indices, poses, intrinsics 116 | 117 | 118 | def generate_mask(gt_depth, min_depth, max_depth): 119 | mask = np.logical_and(gt_depth > min_depth, 120 | gt_depth < max_depth) 121 | # crop gt to exclude border values 122 | # if used on gt_size 100x100 produces a crop of [-95, -5, 5, 95] 123 | gt_height, gt_width = gt_depth.shape 124 | crop = np.array([0.05 * gt_height, 0.95 * gt_height, 125 | 0.05 * gt_width, 0.95 * gt_width]).astype(np.int32) 126 | 127 | crop_mask = np.zeros(mask.shape) 128 | crop_mask[crop[0]:crop[1],crop[2]:crop[3]] = 1 129 | mask = np.logical_and(mask, crop_mask) 130 | return mask 131 | -------------------------------------------------------------------------------- /stillbox_eval/test_files_80.txt: -------------------------------------------------------------------------------- 1 | 15/RGB_0112_08.jpg 2 | 15/RGB_0178_02.jpg 3 | 15/RGB_0167_06.jpg 4 | 15/RGB_0153_07.jpg 5 | 15/RGB_0119_02.jpg 6 | 15/RGB_0135_03.jpg 7 | 15/RGB_0044_06.jpg 8 | 15/RGB_0032_02.jpg 9 | 15/RGB_0171_01.jpg 10 | 15/RGB_0114_09.jpg 11 | 15/RGB_0089_03.jpg 12 | 15/RGB_0197_09.jpg 13 | 15/RGB_0105_00.jpg 14 | 15/RGB_0072_04.jpg 15 | 15/RGB_0066_03.jpg 16 | 15/RGB_0025_07.jpg 17 | 15/RGB_0058_04.jpg 18 | 15/RGB_0028_03.jpg 19 | 15/RGB_0025_04.jpg 20 | 15/RGB_0140_03.jpg 21 | 15/RGB_0059_08.jpg 22 | 15/RGB_0019_01.jpg 23 | 15/RGB_0186_03.jpg 24 | 15/RGB_0113_09.jpg 25 | 15/RGB_0054_02.jpg 26 | 15/RGB_0130_03.jpg 27 | 15/RGB_0153_03.jpg 28 | 15/RGB_0103_07.jpg 29 | 15/RGB_0004_07.jpg 30 | 15/RGB_0110_08.jpg 31 | 15/RGB_0078_05.jpg 32 | 15/RGB_0026_05.jpg 33 | 15/RGB_0043_07.jpg 34 | 15/RGB_0190_03.jpg 35 | 15/RGB_0122_02.jpg 36 | 15/RGB_0102_08.jpg 37 | 15/RGB_0187_04.jpg 38 | 15/RGB_0003_05.jpg 39 | 15/RGB_0058_07.jpg 40 | 15/RGB_0037_04.jpg 
41 | 15/RGB_0125_03.jpg 42 | 15/RGB_0190_02.jpg 43 | 15/RGB_0052_06.jpg 44 | 15/RGB_0037_05.jpg 45 | 15/RGB_0196_01.jpg 46 | 15/RGB_0053_03.jpg 47 | 15/RGB_0129_08.jpg 48 | 15/RGB_0074_03.jpg 49 | 15/RGB_0167_00.jpg 50 | 15/RGB_0195_02.jpg 51 | 15/RGB_0010_07.jpg 52 | 15/RGB_0131_03.jpg 53 | 15/RGB_0037_03.jpg 54 | 15/RGB_0038_09.jpg 55 | 15/RGB_0115_04.jpg 56 | 15/RGB_0091_08.jpg 57 | 15/RGB_0043_04.jpg 58 | 15/RGB_0187_05.jpg 59 | 15/RGB_0112_03.jpg 60 | 15/RGB_0019_02.jpg 61 | 15/RGB_0170_08.jpg 62 | 15/RGB_0017_00.jpg 63 | 15/RGB_0062_05.jpg 64 | 15/RGB_0148_04.jpg 65 | 15/RGB_0012_08.jpg 66 | 15/RGB_0169_04.jpg 67 | 15/RGB_0112_04.jpg 68 | 15/RGB_0071_01.jpg 69 | 15/RGB_0103_01.jpg 70 | 15/RGB_0178_05.jpg 71 | 15/RGB_0092_06.jpg 72 | 15/RGB_0040_09.jpg 73 | 15/RGB_0138_06.jpg 74 | 15/RGB_0146_05.jpg 75 | 15/RGB_0004_06.jpg 76 | 15/RGB_0002_08.jpg 77 | 15/RGB_0101_09.jpg 78 | 15/RGB_0103_09.jpg 79 | 15/RGB_0021_02.jpg 80 | 15/RGB_0144_08.jpg 81 | 15/RGB_0163_07.jpg 82 | 15/RGB_0006_01.jpg 83 | 15/RGB_0105_04.jpg 84 | 15/RGB_0199_09.jpg 85 | 15/RGB_0149_05.jpg 86 | 15/RGB_0063_08.jpg 87 | 15/RGB_0021_04.jpg 88 | 15/RGB_0003_02.jpg 89 | 15/RGB_0051_08.jpg 90 | 15/RGB_0110_01.jpg 91 | 15/RGB_0172_09.jpg 92 | 15/RGB_0158_05.jpg 93 | 15/RGB_0049_04.jpg 94 | 15/RGB_0173_08.jpg 95 | 15/RGB_0099_04.jpg 96 | 15/RGB_0024_01.jpg 97 | 15/RGB_0003_09.jpg 98 | 15/RGB_0041_09.jpg 99 | 15/RGB_0091_02.jpg 100 | 15/RGB_0132_01.jpg 101 | 15/RGB_0095_03.jpg 102 | 15/RGB_0167_05.jpg 103 | 15/RGB_0176_00.jpg 104 | 15/RGB_0142_08.jpg 105 | 15/RGB_0107_09.jpg 106 | 15/RGB_0122_05.jpg 107 | 15/RGB_0048_01.jpg 108 | 15/RGB_0103_05.jpg 109 | 15/RGB_0098_09.jpg 110 | 15/RGB_0162_01.jpg 111 | 15/RGB_0008_06.jpg 112 | 15/RGB_0169_02.jpg 113 | 15/RGB_0057_02.jpg 114 | 15/RGB_0086_04.jpg 115 | 15/RGB_0138_01.jpg 116 | 15/RGB_0005_05.jpg 117 | 15/RGB_0095_02.jpg 118 | 15/RGB_0028_02.jpg 119 | 15/RGB_0110_02.jpg 120 | 15/RGB_0102_02.jpg 121 | 15/RGB_0136_09.jpg 122 | 15/RGB_0028_07.jpg 123 | 15/RGB_0043_05.jpg 124 | 15/RGB_0039_06.jpg 125 | 15/RGB_0126_03.jpg 126 | 15/RGB_0062_01.jpg 127 | 15/RGB_0082_03.jpg 128 | 15/RGB_0075_08.jpg 129 | 15/RGB_0016_05.jpg 130 | 15/RGB_0094_05.jpg 131 | 15/RGB_0198_02.jpg 132 | 15/RGB_0090_01.jpg 133 | 15/RGB_0022_01.jpg 134 | 15/RGB_0090_00.jpg 135 | 15/RGB_0155_06.jpg 136 | 15/RGB_0124_07.jpg 137 | 15/RGB_0168_04.jpg 138 | 15/RGB_0096_08.jpg 139 | 15/RGB_0100_02.jpg 140 | 15/RGB_0131_08.jpg 141 | 15/RGB_0074_02.jpg 142 | 15/RGB_0141_07.jpg 143 | 15/RGB_0139_01.jpg 144 | 15/RGB_0102_05.jpg 145 | 15/RGB_0182_09.jpg 146 | 15/RGB_0037_02.jpg 147 | 15/RGB_0067_03.jpg 148 | 15/RGB_0060_01.jpg 149 | 15/RGB_0186_01.jpg 150 | 15/RGB_0171_02.jpg 151 | 15/RGB_0155_04.jpg 152 | 15/RGB_0050_08.jpg 153 | 15/RGB_0034_02.jpg 154 | 15/RGB_0132_03.jpg 155 | 15/RGB_0147_05.jpg 156 | 15/RGB_0099_08.jpg 157 | 15/RGB_0110_00.jpg 158 | 15/RGB_0114_08.jpg 159 | 15/RGB_0159_02.jpg 160 | 15/RGB_0076_07.jpg 161 | 15/RGB_0116_05.jpg 162 | 15/RGB_0067_02.jpg 163 | 15/RGB_0080_03.jpg 164 | 15/RGB_0030_00.jpg 165 | 15/RGB_0137_09.jpg 166 | 15/RGB_0130_02.jpg 167 | 15/RGB_0090_02.jpg 168 | 15/RGB_0034_08.jpg 169 | 15/RGB_0137_07.jpg 170 | 15/RGB_0045_01.jpg 171 | 15/RGB_0131_04.jpg 172 | 15/RGB_0006_00.jpg 173 | 15/RGB_0068_05.jpg 174 | 15/RGB_0104_08.jpg 175 | 15/RGB_0193_08.jpg 176 | 15/RGB_0182_00.jpg 177 | 15/RGB_0129_06.jpg 178 | 15/RGB_0107_05.jpg 179 | 15/RGB_0158_07.jpg 180 | 15/RGB_0192_01.jpg 181 | 15/RGB_0018_05.jpg 182 | 15/RGB_0090_09.jpg 183 | 15/RGB_0018_07.jpg 184 | 15/RGB_0094_00.jpg 185 | 
15/RGB_0009_02.jpg 186 | 15/RGB_0094_01.jpg 187 | 15/RGB_0046_04.jpg 188 | 15/RGB_0126_00.jpg 189 | 15/RGB_0146_02.jpg 190 | 15/RGB_0161_06.jpg 191 | 15/RGB_0154_08.jpg 192 | 15/RGB_0094_03.jpg 193 | -------------------------------------------------------------------------------- /stillbox_eval/test_files_90.txt: -------------------------------------------------------------------------------- 1 | 15/RGB_112_008.jpg 2 | 15/RGB_178_002.jpg 3 | 15/RGB_167_006.jpg 4 | 15/RGB_153_007.jpg 5 | 15/RGB_119_002.jpg 6 | 15/RGB_135_003.jpg 7 | 15/RGB_44_006.jpg 8 | 15/RGB_32_002.jpg 9 | 15/RGB_171_001.jpg 10 | 15/RGB_114_009.jpg 11 | 15/RGB_89_003.jpg 12 | 15/RGB_197_009.jpg 13 | 15/RGB_105_000.jpg 14 | 15/RGB_72_004.jpg 15 | 15/RGB_66_003.jpg 16 | 15/RGB_25_007.jpg 17 | 15/RGB_58_004.jpg 18 | 15/RGB_28_003.jpg 19 | 15/RGB_25_004.jpg 20 | 15/RGB_140_003.jpg 21 | 15/RGB_59_008.jpg 22 | 15/RGB_19_001.jpg 23 | 15/RGB_186_003.jpg 24 | 15/RGB_113_009.jpg 25 | 15/RGB_54_002.jpg 26 | 15/RGB_130_003.jpg 27 | 15/RGB_153_003.jpg 28 | 15/RGB_103_007.jpg 29 | 15/RGB_04_007.jpg 30 | 15/RGB_110_008.jpg 31 | 15/RGB_78_005.jpg 32 | 15/RGB_26_005.jpg 33 | 15/RGB_43_007.jpg 34 | 15/RGB_190_003.jpg 35 | 15/RGB_122_002.jpg 36 | 15/RGB_102_008.jpg 37 | 15/RGB_187_004.jpg 38 | 15/RGB_03_005.jpg 39 | 15/RGB_58_007.jpg 40 | 15/RGB_37_004.jpg 41 | 15/RGB_125_003.jpg 42 | 15/RGB_190_002.jpg 43 | 15/RGB_52_006.jpg 44 | 15/RGB_37_005.jpg 45 | 15/RGB_196_001.jpg 46 | 15/RGB_53_003.jpg 47 | 15/RGB_129_008.jpg 48 | 15/RGB_74_003.jpg 49 | 15/RGB_167_000.jpg 50 | 15/RGB_195_002.jpg 51 | 15/RGB_10_007.jpg 52 | 15/RGB_131_003.jpg 53 | 15/RGB_37_003.jpg 54 | 15/RGB_38_009.jpg 55 | 15/RGB_115_004.jpg 56 | 15/RGB_91_008.jpg 57 | 15/RGB_43_004.jpg 58 | 15/RGB_187_005.jpg 59 | 15/RGB_112_003.jpg 60 | 15/RGB_19_002.jpg 61 | 15/RGB_170_008.jpg 62 | 15/RGB_17_000.jpg 63 | 15/RGB_62_005.jpg 64 | 15/RGB_148_004.jpg 65 | 15/RGB_12_008.jpg 66 | 15/RGB_169_004.jpg 67 | 15/RGB_112_004.jpg 68 | 15/RGB_71_001.jpg 69 | 15/RGB_103_001.jpg 70 | 15/RGB_178_005.jpg 71 | 15/RGB_92_006.jpg 72 | 15/RGB_40_009.jpg 73 | 15/RGB_138_006.jpg 74 | 15/RGB_146_005.jpg 75 | 15/RGB_04_006.jpg 76 | 15/RGB_02_008.jpg 77 | 15/RGB_101_009.jpg 78 | 15/RGB_103_009.jpg 79 | 15/RGB_21_002.jpg 80 | 15/RGB_144_008.jpg 81 | 15/RGB_163_007.jpg 82 | 15/RGB_06_001.jpg 83 | 15/RGB_105_004.jpg 84 | 15/RGB_199_009.jpg 85 | 15/RGB_149_005.jpg 86 | 15/RGB_63_008.jpg 87 | 15/RGB_21_004.jpg 88 | 15/RGB_03_002.jpg 89 | 15/RGB_51_008.jpg 90 | 15/RGB_110_001.jpg 91 | 15/RGB_172_009.jpg 92 | 15/RGB_158_005.jpg 93 | 15/RGB_49_004.jpg 94 | 15/RGB_173_008.jpg 95 | 15/RGB_99_004.jpg 96 | 15/RGB_24_001.jpg 97 | 15/RGB_03_009.jpg 98 | 15/RGB_41_009.jpg 99 | 15/RGB_91_002.jpg 100 | 15/RGB_132_001.jpg 101 | 15/RGB_95_003.jpg 102 | 15/RGB_167_005.jpg 103 | 15/RGB_176_000.jpg 104 | 15/RGB_142_008.jpg 105 | 15/RGB_107_009.jpg 106 | 15/RGB_122_005.jpg 107 | 15/RGB_48_001.jpg 108 | 15/RGB_103_005.jpg 109 | 15/RGB_98_009.jpg 110 | 15/RGB_162_001.jpg 111 | 15/RGB_08_006.jpg 112 | 15/RGB_169_002.jpg 113 | 15/RGB_57_002.jpg 114 | 15/RGB_86_004.jpg 115 | 15/RGB_138_001.jpg 116 | 15/RGB_05_005.jpg 117 | 15/RGB_95_002.jpg 118 | 15/RGB_28_002.jpg 119 | 15/RGB_110_002.jpg 120 | 15/RGB_102_002.jpg 121 | 15/RGB_136_009.jpg 122 | 15/RGB_28_007.jpg 123 | 15/RGB_43_005.jpg 124 | 15/RGB_39_006.jpg 125 | 15/RGB_126_003.jpg 126 | 15/RGB_62_001.jpg 127 | 15/RGB_82_003.jpg 128 | 15/RGB_75_008.jpg 129 | 15/RGB_16_005.jpg 130 | 15/RGB_94_005.jpg 131 | 15/RGB_198_002.jpg 132 | 15/RGB_90_001.jpg 133 | 15/RGB_22_001.jpg 134 | 
15/RGB_90_000.jpg 135 | 15/RGB_155_006.jpg 136 | 15/RGB_124_007.jpg 137 | 15/RGB_168_004.jpg 138 | 15/RGB_96_008.jpg 139 | 15/RGB_100_002.jpg 140 | 15/RGB_131_008.jpg 141 | 15/RGB_74_002.jpg 142 | 15/RGB_141_007.jpg 143 | 15/RGB_139_001.jpg 144 | 15/RGB_102_005.jpg 145 | 15/RGB_182_009.jpg 146 | 15/RGB_37_002.jpg 147 | 15/RGB_67_003.jpg 148 | 15/RGB_60_001.jpg 149 | 15/RGB_186_001.jpg 150 | 15/RGB_171_002.jpg 151 | 15/RGB_155_004.jpg 152 | 15/RGB_50_008.jpg 153 | 15/RGB_34_002.jpg 154 | 15/RGB_132_003.jpg 155 | 15/RGB_147_005.jpg 156 | 15/RGB_99_008.jpg 157 | 15/RGB_110_000.jpg 158 | 15/RGB_114_008.jpg 159 | 15/RGB_159_002.jpg 160 | 15/RGB_76_007.jpg 161 | 15/RGB_116_005.jpg 162 | 15/RGB_67_002.jpg 163 | 15/RGB_80_003.jpg 164 | 15/RGB_30_000.jpg 165 | 15/RGB_137_009.jpg 166 | 15/RGB_130_002.jpg 167 | 15/RGB_90_002.jpg 168 | 15/RGB_34_008.jpg 169 | 15/RGB_137_007.jpg 170 | 15/RGB_45_001.jpg 171 | 15/RGB_131_004.jpg 172 | 15/RGB_06_000.jpg 173 | 15/RGB_68_005.jpg 174 | 15/RGB_104_008.jpg 175 | 15/RGB_193_008.jpg 176 | 15/RGB_182_000.jpg 177 | 15/RGB_129_006.jpg 178 | 15/RGB_107_005.jpg 179 | 15/RGB_158_007.jpg 180 | 15/RGB_192_001.jpg 181 | 15/RGB_18_005.jpg 182 | 15/RGB_90_009.jpg 183 | 15/RGB_18_007.jpg 184 | 15/RGB_94_000.jpg 185 | 15/RGB_09_002.jpg 186 | 15/RGB_94_001.jpg 187 | 15/RGB_46_004.jpg 188 | 15/RGB_126_000.jpg 189 | 15/RGB_146_002.jpg 190 | 15/RGB_161_006.jpg 191 | 15/RGB_154_008.jpg 192 | 15/RGB_94_003.jpg 193 | -------------------------------------------------------------------------------- /test_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from path import Path 5 | import argparse 6 | from tqdm import tqdm 7 | import imageio 8 | 9 | from models import DepthNet, PoseNet 10 | from inverse_warp import pose_vec2mat, compensate_pose, invert_mat, inverse_rotate 11 | from utils import tensor2array 12 | 13 | 14 | parser = argparse.ArgumentParser(description='Script for DispNet testing with corresponding groundTruth', 15 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 16 | parser.add_argument("--pretrained-depthnet", required=True, type=str, help="pretrained DispNet path") 17 | parser.add_argument("--pretrained-posenet", default=None, type=str, help="pretrained PoseNet path (for scale factor)") 18 | parser.add_argument("--img-height", default=128, type=int, help="Image height") 19 | parser.add_argument("--img-width", default=416, type=int, help="Image width") 20 | parser.add_argument("--no-resize", action='store_true', help="no resizing is done") 21 | parser.add_argument("--min-depth", default=1e-3) 22 | parser.add_argument("--max-depth", default=80, type=float) 23 | parser.add_argument("--stabilize-from-GT", action='store_true') 24 | parser.add_argument("--nominal-displacement", type=float, default=0.3) 25 | parser.add_argument("--output-dir", default='.', type=str, help="Output directory for saving") 26 | parser.add_argument("--log-best-worst", action='store_true', help="if selected, will log depthNet outputs") 27 | parser.add_argument("--save-output", action='store_true', help="if selected, will save all predictions in a big 3D numpy file") 28 | 29 | parser.add_argument("--dataset-dir", default='.', type=str, help="Dataset directory") 30 | parser.add_argument("--dataset-list", default=None, type=str, help="Dataset list file") 31 | 32 | parser.add_argument("--gt-type", default='KITTI', type=str, help="GroundTruth data type", choices=['npy', 'png', 
'KITTI', 'stillbox'])
33 | parser.add_argument("--img-exts", default=['png', 'jpg', 'bmp'], nargs='*', type=str, help="images extensions to glob")
34 | parser.add_argument("--rotation-mode", default='euler', choices=['euler', 'quat'], type=str)
35 | 
36 | 
37 | target_mean_depthnet_output = 50
38 | best_error = np.inf
39 | worst_error = 0
40 | 
41 | 
42 | def select_best_map(maps, target_mean):
43 |     unraveled_maps = maps.view(maps.size(0), -1)
44 |     means = unraveled_maps.mean(1) # this should be a 1D tensor
45 |     best_index = torch.min((means-target_mean).abs(), 0)[1].item()
46 |     best_map = maps[best_index,0]
47 |     return best_map, best_index
48 | 
49 | 
50 | def log_result(pred_depth, GT, input_batch, selected_index, folder, prefix):
51 |     def save(path, to_save):
52 |         to_save = (255*to_save.transpose(1,2,0)).astype(np.uint8)
53 |         imageio.imsave(path, to_save)
54 |     pred_to_save = tensor2array(pred_depth, max_value=100)
55 |     gt_to_save = tensor2array(torch.from_numpy(GT), max_value=100)
56 | 
57 |     prefix = folder/prefix
58 |     save('{}_depth_pred.jpg'.format(prefix), pred_to_save)
59 |     save('{}_depth_gt.jpg'.format(prefix), gt_to_save)
60 |     disp_to_save = tensor2array(1/pred_depth, max_value=None, colormap='magma')
61 |     gt_disp = np.zeros_like(GT)
62 |     valid_depth = GT > 0
63 |     gt_disp[valid_depth] = 1/GT[valid_depth]
64 | 
65 |     gt_disp_to_save = tensor2array(torch.from_numpy(gt_disp), max_value=None, colormap='magma')
66 |     save('{}_disp_pred.jpg'.format(prefix), disp_to_save)
67 |     save('{}_disp_gt.jpg'.format(prefix), gt_disp_to_save)
68 |     to_save = tensor2array(input_batch.cpu().data[selected_index,:3])
69 |     save('{}_input0.jpg'.format(prefix), to_save)
70 |     to_save = tensor2array(input_batch.cpu()[selected_index,3:])
71 |     save('{}_input1.jpg'.format(prefix), to_save)
72 |     for i, batch_elem in enumerate(input_batch.cpu().data):
73 |         to_save = tensor2array(batch_elem[:3])
74 |         save('{}_batch_{}_0.jpg'.format(prefix, i), to_save)
75 |         to_save = tensor2array(batch_elem[3:])
76 |         save('{}_batch_{}_1.jpg'.format(prefix, i), to_save)
77 | 
78 | 
79 | @torch.no_grad()
80 | def main():
81 |     global best_error, worst_error
82 |     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
83 |     args = parser.parse_args()
84 |     if args.gt_type == 'KITTI':
85 |         from kitti_eval.depth_evaluation_utils import test_framework_KITTI as test_framework
86 |     elif args.gt_type == 'stillbox':
87 |         from stillbox_eval.depth_evaluation_utils import test_framework_stillbox as test_framework
88 | 
89 |     weights = torch.load(args.pretrained_depthnet)
90 |     depthnet_params = {"depth_activation":"elu",
91 |                        "batch_norm":"bn" in weights.keys() and weights['bn']}
92 |     if not args.no_resize:
93 |         depthnet_params['input_size'] = (args.img_height, args.img_width)
94 |         depthnet_params['upscale'] = True
95 | 
96 |     depth_net = DepthNet(**depthnet_params).to(device)
97 |     depth_net.load_state_dict(weights['state_dict'])
98 |     depth_net.eval()
99 | 
100 |     if args.pretrained_posenet is None:
101 |         args.stabilize_from_GT = True
102 |         print('no PoseNet specified, stab will be done from ground truth')
103 |         seq_length = 5
104 |     else:
105 |         weights = torch.load(args.pretrained_posenet)
106 |         seq_length = int(weights['state_dict']['conv1.0.weight'].size(1)/3)
107 |         posenet_params = {'seq_length':seq_length}
108 |         if not args.no_resize:
109 |             posenet_params['input_size'] = (args.img_height, args.img_width)
110 | 
111 |         pose_net = PoseNet(**posenet_params).to(device)
112 |         pose_net.load_state_dict(weights['state_dict'], strict=False)
113 | 
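    # The block below gathers the frames to evaluate: either read relative paths from --dataset-list,
    # or glob the dataset directory for every extension in --img-exts. The chosen evaluation framework
    # (KITTI or StillBox) then yields, for each target frame, its neighbouring frames along with
    # intrinsics, ground-truth poses and ground-truth depth.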
114 | dataset_dir = Path(args.dataset_dir) 115 | if args.dataset_list is not None: 116 | with open(args.dataset_list, 'r') as f: 117 | test_files = list(f.read().splitlines()) 118 | else: 119 | test_files = [file.relpathto(dataset_dir) for file in sum([dataset_dir.files('*.{}'.format(ext)) for ext in args.img_exts], [])] 120 | 121 | framework = test_framework(dataset_dir, test_files, seq_length, args.min_depth, args.max_depth) 122 | 123 | print('{} files to test'.format(len(test_files))) 124 | errors = np.zeros((9, len(test_files)), np.float32) 125 | 126 | args.output_dir = Path(args.output_dir) 127 | args.output_dir.makedirs_p() 128 | 129 | for j, sample in enumerate(tqdm(framework)): 130 | intrinsics = torch.from_numpy(sample['intrinsics']).unsqueeze(0).to(device) 131 | imgs = sample['imgs'] 132 | imgs = [torch.from_numpy(np.transpose(img, (2,0,1))) for img in imgs] 133 | imgs = torch.stack(imgs).unsqueeze(0).to(device) 134 | imgs = 2*(imgs/255 - 0.5) 135 | 136 | tgt_img = imgs[:,sample['tgt_index']] 137 | 138 | # Construct a batch of all possible stabilized pairs, with PoseNet or with GT orientation, will take the output closest to target mean depth 139 | if args.stabilize_from_GT: 140 | poses_GT = torch.from_numpy(sample['poses']).unsqueeze(0).to(device) 141 | inv_poses_GT = invert_mat(poses_GT) 142 | tgt_pose = inv_poses_GT[:,sample['tgt_index']] 143 | inv_transform_matrices_tgt = compensate_pose(inv_poses_GT, tgt_pose) 144 | else: 145 | poses = pose_net(imgs) 146 | inv_transform_matrices = pose_vec2mat(poses, rotation_mode=args.rotation_mode) 147 | 148 | tgt_pose = inv_transform_matrices[:,sample['tgt_index']] 149 | inv_transform_matrices_tgt = compensate_pose(inv_transform_matrices, tgt_pose) 150 | 151 | stabilized_pairs = [] 152 | corresponding_displ = [] 153 | for i in range(seq_length): 154 | if i == sample['tgt_index']: 155 | continue 156 | img = imgs[:,i] 157 | img_pose = inv_transform_matrices_tgt[:,i] 158 | stab_img = inverse_rotate(img, img_pose[:,:,:3], intrinsics) 159 | pair = torch.cat([stab_img, tgt_img], dim=1) # [1, 6, H, W] 160 | stabilized_pairs.append(pair) 161 | 162 | GT_translations = sample['poses'][:,:,-1] 163 | real_displacement = np.linalg.norm(GT_translations[sample['tgt_index']] - GT_translations[i]) 164 | corresponding_displ.append(real_displacement) 165 | stab_batch = torch.cat(stabilized_pairs) # [seq, 6, H, W] 166 | depth_maps = depth_net(stab_batch) # [seq, 1 , H/4, W/4] 167 | 168 | selected_depth, selected_index = select_best_map(depth_maps, target_mean_depthnet_output) 169 | 170 | pred_depth = selected_depth * corresponding_displ[selected_index] / args.nominal_displacement 171 | 172 | if args.save_output: 173 | if j == 0: 174 | predictions = np.zeros((len(test_files), *pred_depth.shape)) 175 | predictions[j] = 1/pred_depth 176 | 177 | gt_depth = sample['gt_depth'] 178 | pred_depth_zoomed = F.interpolate(pred_depth.view(1,1,*pred_depth.shape), 179 | gt_depth.shape[:2], 180 | mode='bilinear', 181 | align_corners=False).clamp(args.min_depth, args.max_depth)[0,0] 182 | if sample['mask'] is not None: 183 | pred_depth_zoomed_masked = pred_depth_zoomed.cpu().numpy()[sample['mask']] 184 | gt_depth = gt_depth[sample['mask']] 185 | errors[:,j] = compute_errors(gt_depth, pred_depth_zoomed_masked) 186 | if args.log_best_worst: 187 | if best_error > errors[0,j]: 188 | best_error = errors[0,j] 189 | log_result(pred_depth_zoomed, sample['gt_depth'], stab_batch, selected_index, args.output_dir, 'best') 190 | if worst_error < errors[0,j]: 191 | worst_error = 
errors[0,j] 192 | log_result(pred_depth_zoomed, sample['gt_depth'], stab_batch, selected_index, args.output_dir, 'worst') 193 | 194 | mean_errors = errors.mean(1) 195 | error_names = ['mean_abs', 'abs_rel','abs_log','sq_rel','rms','log_rms','a1','a2','a3'] 196 | 197 | print("Results : ") 198 | print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(*error_names)) 199 | print("{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".format(*mean_errors)) 200 | 201 | if args.save_output: 202 | np.save(args.output_dir/'predictions.npy', predictions) 203 | 204 | 205 | def compute_errors(gt, pred): 206 | thresh = np.maximum((gt / pred), (pred / gt)) 207 | a1 = (thresh < 1.25 ).mean() 208 | a2 = (thresh < 1.25 ** 2).mean() 209 | a3 = (thresh < 1.25 ** 3).mean() 210 | 211 | mabs = np.mean(np.abs(gt - pred)) 212 | rmse = (gt - pred) ** 2 213 | rmse = np.sqrt(rmse.mean()) 214 | 215 | rmse_log = (np.log(gt) - np.log(pred)) ** 2 216 | rmse_log = np.sqrt(rmse_log.mean()) 217 | abs_log = np.mean(np.abs(np.log(gt) - np.log(pred))) 218 | 219 | abs_rel = np.mean(np.abs(gt - pred) / gt) 220 | 221 | sq_rel = np.mean(((gt - pred)**2) / gt) 222 | 223 | return mabs, abs_rel, abs_log, sq_rel, rmse, rmse_log, a1, a2, a3 224 | 225 | 226 | if __name__ == '__main__': 227 | main() 228 | -------------------------------------------------------------------------------- /test_pose.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | 4 | from scipy.misc import imresize 5 | import numpy as np 6 | from path import Path 7 | import argparse 8 | from tqdm import tqdm 9 | 10 | from models import PoseNet 11 | from inverse_warp import pose_vec2mat, invert_mat, compensate_pose 12 | 13 | 14 | parser = argparse.ArgumentParser(description='Script for DispNet testing with corresponding groundTruth', 15 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 16 | parser.add_argument("pretrained_posenet", type=str, help="pretrained PoseNet path (for scale factor)") 17 | parser.add_argument("--img-height", default=128, type=int, help="Image height") 18 | parser.add_argument("--img-width", default=416, type=int, help="Image width") 19 | parser.add_argument("--no-resize", action='store_true', help="no resizing is done") 20 | parser.add_argument("--min-depth", default=1e-3) 21 | parser.add_argument("--max-depth", default=80) 22 | 23 | parser.add_argument("--dataset-dir", default='.', type=str, help="Dataset directory") 24 | parser.add_argument("--sequences", default=['09'], type=str, nargs='*', help="sequences to test") 25 | parser.add_argument("--output-dir", default=None, type=str, help="Output directory for saving predictions in a big 3D numpy file") 26 | 27 | parser.add_argument("--gt-type", default='KITTI', type=str, help="GroundTruth data type", choices=['npy', 'png', 'KITTI', 'stillbox']) 28 | parser.add_argument("--img-exts", default=['png', 'jpg', 'bmp'], nargs='*', type=str, help="images extensions to glob") 29 | parser.add_argument("--rotation-mode", default='euler', choices=['euler', 'quat'], type=str) 30 | 31 | 32 | @torch.no_grad() 33 | def main(): 34 | args = parser.parse_args() 35 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 36 | if args.gt_type == 'KITTI': 37 | from kitti_eval.pose_evaluation_utils import test_framework_KITTI as test_framework 38 | elif args.gt_type == 'stillbox': 39 | from stillbox_eval.pose_evaluation_utils 
import test_framework_stillbox as test_framework 40 | 41 | weights = torch.load(args.pretrained_posenet) 42 | seq_length = int(weights['state_dict']['conv1.0.weight'].size(1)/3) 43 | pose_net = PoseNet(seq_length=seq_length).to(device) 44 | pose_net.load_state_dict(weights['state_dict'], strict=False) 45 | 46 | dataset_dir = Path(args.dataset_dir) 47 | framework = test_framework(dataset_dir, args.sequences, seq_length) 48 | 49 | print('{} snippets to test'.format(len(framework))) 50 | errors = np.zeros((len(framework), 2), np.float32) 51 | if args.output_dir is not None: 52 | output_dir = Path(args.output_dir) 53 | output_dir.makedirs_p() 54 | predictions_array = np.zeros((len(framework), seq_length, 3, 4)) 55 | 56 | for j, sample in enumerate(tqdm(framework)): 57 | imgs = sample['imgs'] 58 | 59 | h,w,_ = imgs[0].shape 60 | if (not args.no_resize) and (h != args.img_height or w != args.img_width): 61 | imgs = [imresize(img, (args.img_height, args.img_width)).astype(np.float32) for img in imgs] 62 | 63 | imgs = [torch.from_numpy(np.transpose(img, (2,0,1))) for img in imgs] 64 | imgs = torch.stack(imgs).unsqueeze(0).to(device) 65 | imgs = 2*(imgs/255 - 0.5) 66 | 67 | poses = pose_net(imgs) 68 | 69 | inv_transform_matrices = pose_vec2mat(poses, rotation_mode=args.rotation_mode) 70 | 71 | transform_matrices = invert_mat(inv_transform_matrices) 72 | 73 | # rot_matrices = np.linalg.inv(inv_transform_matrices[:,:,:3]) 74 | # tr_vectors = rot_matrices @ inv_transform_matrices[:,:,-1:] 75 | 76 | # transform_matrices = np.concatenate([rot_matrices, tr_vectors], axis=-1) 77 | 78 | # first_transform = transform_matrices[0] 79 | # final_poses = np.linalg.inv(first_transform[:,:3]) @ transform_matrices 80 | # final_poses[:,:,-1:] -= np.linalg.inv(first_transform[:,:3]) @ first_transform[:,-1:] 81 | 82 | final_poses = compensate_pose(transform_matrices, transform_matrices[:,0])[0].cpu().numpy() 83 | 84 | if args.output_dir is not None: 85 | predictions_array[j] = final_poses 86 | 87 | ATE, RE = compute_pose_error(sample['poses'][1:], final_poses[1:]) 88 | errors[j] = ATE, RE 89 | 90 | mean_errors = errors.mean(0) 91 | std_errors = errors.std(0) 92 | error_names = ['ATE','RE'] 93 | print('') 94 | print("Results") 95 | print("\t {:>10}, {:>10}".format(*error_names)) 96 | print("mean \t {:10.4f}, {:10.4f}".format(*mean_errors)) 97 | print("std \t {:10.4f}, {:10.4f}".format(*std_errors)) 98 | 99 | if args.output_dir is not None: 100 | np.save(output_dir/'predictions.npy', predictions_array) 101 | 102 | 103 | def compute_pose_error(gt, pred): 104 | ATE = 0 105 | RE = 0 106 | snippet_length = gt.shape[0] 107 | scale_factor = np.sum(gt[:,:,-1] * pred[:,:,-1])/np.sum(pred[:,:,-1] ** 2) 108 | for gt_pose, pred_pose in zip(gt, pred): 109 | ATE += np.linalg.norm(gt_pose[:,-1] - scale_factor * pred_pose[:,-1]) 110 | 111 | # Residual matrix to which we compute angle's sin and cos 112 | R = gt_pose[:,:3] @ np.linalg.inv(pred_pose[:,:3]) 113 | s = np.linalg.norm([R[0,1]-R[1,0], 114 | R[1,2]-R[2,1], 115 | R[0,2]-R[2,0]]) 116 | c = np.trace(R) - 1 117 | # Note: we actually compute double of cos and sin, but arctan2 is invariant to scale 118 | RE += np.arctan2(s,c) 119 | 120 | return ATE/snippet_length, RE/snippet_length 121 | 122 | 123 | if __name__ == '__main__': 124 | main() 125 | -------------------------------------------------------------------------------- /train_flexible_shifts.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | import 
torch 5 | import torch.optim 6 | import torch.utils.data 7 | import torch.nn.functional as F 8 | import models 9 | import train_img_pairs 10 | from inverse_warp import compensate_pose, pose_vec2mat, inverse_rotate 11 | from logger import AverageMeter 12 | 13 | train_img_pairs.parser.add_argument('-d', '--target-mean-depth', type=float, 14 | help='equivalent depth to aim at when adjustting shifts, regarding DepthNet output', 15 | metavar='D', default=40) 16 | train_img_pairs.parser.add_argument('-r', '--recompute-frequency', type=int, 17 | help='Will recompute optimal shifts every R epochs', 18 | metavar='R', default=5) 19 | 20 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 21 | 22 | 23 | def main(): 24 | env = train_img_pairs.prepare_environment() 25 | env['adjust_loader'] = torch.utils.data.DataLoader( 26 | env['train_set'], batch_size=env['args'].batch_size, shuffle=False, 27 | num_workers=0, pin_memory=True) # workers is set to 0 to avoid multiple instances to be modified at the same time 28 | launch_training_flexible_shifts(**env) 29 | 30 | 31 | def launch_training_flexible_shifts(scheduler, **env): 32 | logger = env['logger'] 33 | args = env["args"] 34 | train_set = env["train_set"] 35 | env['best_error'] = -1 36 | env['epoch'] = 0 37 | env['n_iter'] = 0 38 | 39 | if args.pretrained_depth or args.evaluate: 40 | train_img_pairs.validate(**env) 41 | 42 | for epoch in range(1, args.epochs + 1): 43 | env['epoch'] = epoch 44 | scheduler.step() 45 | logger.epoch_bar.update(epoch) 46 | 47 | # train for one epoch 48 | train_loss, env['n_iter'] = train_img_pairs.train_one_epoch(**env) 49 | logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss)) 50 | 51 | if epoch % args.recompute_frequency == 0: 52 | train_set.adjust = True 53 | average_shifts = adjust_shifts(**env) 54 | shifts_string = ' '.join(['{:.3f}'.format(s) for s in average_shifts]) 55 | logger.train_writer.write(' * adjusted shifts, average shifts are now : {}'.format(shifts_string)) 56 | train_set.adjust = False 57 | 58 | # evaluate on validation set 59 | error = train_img_pairs.validate(**env) 60 | 61 | env['best_error'] = train_img_pairs.finish_epoch(train_loss, error, **env) 62 | logger.epoch_bar.finish() 63 | 64 | 65 | @torch.no_grad() 66 | def adjust_shifts(args, train_set, adjust_loader, depth_net, pose_net, epoch, logger, training_writer, **env): 67 | batch_time = AverageMeter() 68 | data_time = AverageMeter() 69 | new_shifts = AverageMeter(args.sequence_length-1, precision=2) 70 | pose_net.eval() 71 | depth_net.eval() 72 | upsample_depth_net = models.UpSampleNet(depth_net, args.network_input_size) 73 | 74 | end = time.time() 75 | 76 | mid_index = (args.sequence_length - 1)//2 77 | 78 | # we contrain mean value of depth net output from pair 0 and mid_index 79 | target_values = np.arange(-mid_index, mid_index + 1) / (args.target_mean_depth * mid_index) 80 | target_values = 1/np.abs(np.concatenate([target_values[:mid_index], target_values[mid_index + 1:]])) 81 | 82 | logger.reset_train_bar(len(adjust_loader)) 83 | 84 | for i, sample in enumerate(adjust_loader): 85 | index = sample['index'] 86 | 87 | # measure data loading time 88 | data_time.update(time.time() - end) 89 | imgs = torch.stack(sample['imgs'], dim=1).to(device) 90 | intrinsics = sample['intrinsics'].to(device) 91 | intrinsics_inv = sample['intrinsics_inv'].to(device) 92 | 93 | # compute output 94 | batch_size, seq = imgs.size()[:2] 95 | 96 | if args.network_input_size is not None: 97 | h,w = args.network_input_size 
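            # imgs is a 5D tensor [B, seq, 3, H, W]; 3-D 'area' interpolation to (3, h, w)
            # downsamples the two spatial dimensions to the network input size while leaving
            # the 3 colour channels untouched, so PoseNet sees the whole downscaled sequence.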
98 | downsample_imgs = F.interpolate(imgs, 99 | (3, h, w), 100 | mode='area') 101 | poses = pose_net(downsample_imgs) # [B, seq, 6] 102 | else: 103 | poses = pose_net(imgs) 104 | 105 | pose_matrices = pose_vec2mat(poses, args.rotation_mode) # [B, seq, 3, 4] 106 | 107 | tgt_imgs = imgs[:, mid_index] # [B, 3, H, W] 108 | tgt_poses = pose_matrices[:, mid_index] # [B, 3, 4] 109 | compensated_poses = compensate_pose(pose_matrices, tgt_poses) # [B, seq, 3, 4] tgt_poses are now neutral pose 110 | 111 | ref_indices = list(range(args.sequence_length)) 112 | ref_indices.remove(mid_index) 113 | 114 | mean_depth_batch = [] 115 | 116 | for ref_index in ref_indices: 117 | prior_imgs = imgs[:, ref_index] 118 | prior_poses = compensated_poses[:, ref_index] # [B, 3, 4] 119 | 120 | prior_imgs_compensated = inverse_rotate(prior_imgs, prior_poses[:,:,:3], intrinsics, intrinsics_inv) 121 | input_pair = torch.cat([prior_imgs_compensated, tgt_imgs], dim=1) # [B, 6, W, H] 122 | 123 | depth = upsample_depth_net(input_pair) # [B, 1, H, W] 124 | mean_depth = depth.view(batch_size, -1).mean(-1).cpu().numpy() # B 125 | mean_depth_batch.append(mean_depth) 126 | 127 | for j, mean_values in zip(index, np.stack(mean_depth_batch, axis=-1)): 128 | ratio = mean_values / target_values # if mean value is too high, raise the shift, lower otherwise 129 | train_set.reset_shifts(j, ratio[:mid_index], ratio[mid_index:]) 130 | new_shifts.update(train_set.get_shifts(j)) 131 | 132 | # measure elapsed time 133 | batch_time.update(time.time() - end) 134 | end = time.time() 135 | 136 | logger.train_bar.update(i) 137 | if i % args.print_freq == 0: 138 | logger.train_writer.write('Adjustement:' 139 | 'Time {} Data {} shifts {}'.format(batch_time, data_time, new_shifts)) 140 | 141 | for i, shift in enumerate(new_shifts.avg): 142 | training_writer.add_scalar('shifts{}'.format(i), shift, epoch) 143 | 144 | return new_shifts.avg 145 | 146 | 147 | if __name__ == '__main__': 148 | main() 149 | -------------------------------------------------------------------------------- /train_img_pairs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import csv 4 | 5 | import numpy as np 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | import torch.optim 9 | import torch.utils.data 10 | import torch.nn.functional as F 11 | import custom_transforms 12 | import models 13 | 14 | from collections import OrderedDict 15 | 16 | from utils import tensor2array, save_checkpoint, save_path_formatter, log_output_tensorboard 17 | from inverse_warp import compensate_pose, pose_vec2mat, inverse_rotate, invert_mat 18 | 19 | from loss_functions import photometric_reconstruction_loss, compute_depth_errors, compute_pose_error, grad_diffusion_loss 20 | from logger import TermLogger, AverageMeter 21 | from tensorboardX import SummaryWriter 22 | 23 | parser = argparse.ArgumentParser(description='Structure from Motion Learner training on KITTI and CityScapes Dataset', 24 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 25 | 26 | parser.add_argument('data', metavar='DIR', 27 | help='path to dataset') 28 | parser.add_argument('--dataset-format', default='KITTI', choices=['KITTI', 'StillBox']) 29 | parser.add_argument('--sequence-length', type=int, metavar='N', help='sequence length for training', default=3) 30 | parser.add_argument('--rotation-mode', type=str, choices=['euler', 'quat'], default='euler', 31 | help='rotation mode for PoseExpnet : euler (yaw,pitch,roll) or quaternion (last 
3 coefficients)') 32 | parser.add_argument('--nominal-displacement', type=float, metavar='D', default=0.3, 33 | help='magnitude assumption of DepthNet when given a pair of frames') 34 | parser.add_argument('--supervise-pose', action='store_true', 35 | help='use avalaible gt pose to supervise posenet and perform rotation compensation') 36 | parser.add_argument('--network-input-size', type=int, nargs=2, default=None, 37 | help='size to which images have to be resized before def into network, can only be smaller than raw image size. \ 38 | if not set, will take raw image size') 39 | parser.add_argument('--upscale', action='store_true', help='upscale depth maps from network to match image size \ 40 | if not set, will downscale images to match depth maps') 41 | parser.add_argument('--same-ratio', default=0, type=float, metavar='P', help='probability to pick pairs with the same image, compared to others\ 42 | Only effective after first milestone') 43 | parser.add_argument('--with-gt', action='store_true', help='use ground truth for validation. \ 44 | You need to store depth in npy 2D arrays and pose in 12 columns csv. See data/kitti_raw_loader.py for an example') 45 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 46 | help='number of data loading workers') 47 | parser.add_argument('--epochs', default=200, type=int, metavar='N', 48 | help='number of total epochs to run') 49 | parser.add_argument('--epoch-size', default=1000, type=int, metavar='N', 50 | help='manual epoch size (will match dataset size if set to 0)') 51 | parser.add_argument('--training-milestones', default=[10,20], type=int, metavar='N', nargs=2, 52 | help='epochs at which training switch modes') 53 | parser.add_argument('--lr-decay-frequency', '--lr-df', default=50, type=int, metavar='N', 54 | help='will divide lr by 2 every N epoch') 55 | parser.add_argument('-b', '--batch-size', default=4, type=int, 56 | metavar='N', help='mini-batch size') 57 | parser.add_argument('--lr', '--learning-rate', default=2e-4, type=float, 58 | metavar='LR', help='initial learning rate') 59 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 60 | help='momentum for sgd, alpha parameter for adam') 61 | parser.add_argument('--beta', default=0.999, type=float, metavar='M', 62 | help='beta parameters for adam') 63 | parser.add_argument('--weight-decay', '--wd', default=0, type=float, 64 | metavar='W', help='weight decay') 65 | parser.add_argument('--bn', choices=['none','pose','depth','both'], default='none', 66 | metavar='W', help='To which network batch norm is applied') 67 | parser.add_argument('--print-freq', default=10, type=int, 68 | metavar='N', help='print frequency') 69 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 70 | help='evaluate model on validation set') 71 | parser.add_argument('--pretrained-depth', dest='pretrained_depth', default=None, metavar='PATH', 72 | help='path to pre-trained DepthNet model') 73 | parser.add_argument('--pretrained-pose', dest='pretrained_pose', default=None, metavar='PATH', 74 | help='path to pre-trained Pose net model') 75 | parser.add_argument('--seed', default=0, type=int, help='seed for random functions, and network initialization') 76 | parser.add_argument('--log-summary', default='progress_log_summary.csv', metavar='PATH', 77 | help='csv where to save per-epoch train and valid stats') 78 | parser.add_argument('--log-full', default='progress_log_full.csv', metavar='PATH', 79 | help='csv where to save per-gradient descent train 
stats') 80 | parser.add_argument('-p', '--photo-loss-weight', type=float, help='weight for photometric loss', metavar='W', default=1) 81 | parser.add_argument('--ssim', type=float, help='weight for SSIM loss', metavar='W', default=0.1) 82 | parser.add_argument('-s', '--smooth-loss-weight', type=float, help='weight for disparity smoothness loss', metavar='W', default=30) 83 | parser.add_argument('--kappa', default=1, type=float, help='kappa parameter for diffusion') 84 | parser.add_argument('--log-output', action='store_true', help='will log dispnet outputs and warped imgs at validation step') 85 | parser.add_argument('--max-depth', type=float, help='value to which depth colormap will be capped to', metavar='M', default=100) 86 | parser.add_argument('-f', '--training-output-freq', type=int, metavar='N', default=0, 87 | help='frequence for outputting dispnet outputs and warped imgs at training for all scales if 0 will not output') 88 | 89 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 90 | 91 | 92 | def main(): 93 | env = prepare_environment() 94 | launch_training(**env) 95 | 96 | 97 | def prepare_environment(): 98 | env = {} 99 | args = parser.parse_args() 100 | if args.dataset_format == 'KITTI': 101 | from datasets.shifted_sequence_folders import ShiftedSequenceFolder 102 | elif args.dataset_format == 'StillBox': 103 | from datasets.shifted_sequence_folders import StillBox as ShiftedSequenceFolder 104 | elif args.dataset_format == 'TUM': 105 | from datasets.shifted_sequence_folders import TUM as ShiftedSequenceFolder 106 | save_path = save_path_formatter(args, parser) 107 | args.save_path = 'checkpoints'/save_path 108 | print('=> will save everything to {}'.format(args.save_path)) 109 | args.save_path.makedirs_p() 110 | torch.manual_seed(args.seed) 111 | 112 | args.test_batch_size = 4*args.batch_size 113 | if args.evaluate: 114 | args.epochs = 0 115 | 116 | env['tb_writer'] = SummaryWriter(args.save_path) 117 | env['sample_nb_to_log'] = 3 118 | 119 | # Data loading code 120 | normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], 121 | std=[0.5, 0.5, 0.5]) 122 | train_transform = custom_transforms.Compose([ 123 | # custom_transforms.RandomHorizontalFlip(), 124 | custom_transforms.ArrayToTensor(), 125 | normalize 126 | ]) 127 | 128 | valid_transform = custom_transforms.Compose([custom_transforms.ArrayToTensor(), normalize]) 129 | 130 | print("=> fetching scenes in '{}'".format(args.data)) 131 | train_set = ShiftedSequenceFolder( 132 | args.data, 133 | transform=train_transform, 134 | seed=args.seed, 135 | train=True, 136 | with_depth_gt=False, 137 | with_pose_gt=args.supervise_pose, 138 | sequence_length=args.sequence_length 139 | ) 140 | val_set = ShiftedSequenceFolder( 141 | args.data, 142 | transform=valid_transform, 143 | seed=args.seed, 144 | train=False, 145 | sequence_length=args.sequence_length, 146 | with_depth_gt=args.with_gt, 147 | with_pose_gt=args.with_gt 148 | ) 149 | print('{} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes))) 150 | print('{} samples found in {} valid scenes'.format(len(val_set), len(val_set.scenes))) 151 | train_loader = torch.utils.data.DataLoader( 152 | train_set, batch_size=args.batch_size, shuffle=True, 153 | num_workers=args.workers, pin_memory=True) 154 | val_loader = torch.utils.data.DataLoader( 155 | val_set, batch_size=4*args.batch_size, shuffle=False, 156 | num_workers=args.workers, pin_memory=True) 157 | 158 | env['train_set'] = train_set 159 | env['val_set'] = val_set 160 
| env['train_loader'] = train_loader 161 | env['val_loader'] = val_loader 162 | 163 | if args.epoch_size == 0: 164 | args.epoch_size = len(train_loader) 165 | 166 | # create model 167 | print("=> creating model") 168 | pose_net = models.PoseNet(seq_length=args.sequence_length, 169 | batch_norm=args.bn in ['pose', 'both'], 170 | input_size=args.network_input_size).to(device) 171 | 172 | if args.pretrained_pose: 173 | print("=> using pre-trained weights for pose net") 174 | weights = torch.load(args.pretrained_pose) 175 | pose_net.load_state_dict(weights['state_dict'], strict=False) 176 | 177 | depth_net = models.DepthNet(depth_activation="elu", 178 | batch_norm=args.bn in ['depth', 'both'], 179 | input_size=args.network_input_size, 180 | upscale=args.upscale).to(device) 181 | 182 | if args.pretrained_depth: 183 | print("=> using pre-trained DepthNet model") 184 | data = torch.load(args.pretrained_depth) 185 | depth_net.load_state_dict(data['state_dict']) 186 | 187 | cudnn.benchmark = True 188 | depth_net = torch.nn.DataParallel(depth_net) 189 | pose_net = torch.nn.DataParallel(pose_net) 190 | 191 | env['depth_net'] = depth_net 192 | env['pose_net'] = pose_net 193 | 194 | print('=> setting adam solver') 195 | 196 | optim_params = [ 197 | {'params': depth_net.parameters(), 'lr': args.lr}, 198 | {'params': pose_net.parameters(), 'lr': args.lr} 199 | ] 200 | # parameters = chain(depth_net.parameters(), pose_exp_net.parameters()) 201 | optimizer = torch.optim.Adam(optim_params, 202 | betas=(args.momentum, args.beta), 203 | weight_decay=args.weight_decay) 204 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 205 | args.lr_decay_frequency, 206 | gamma=0.5) 207 | env['optimizer'] = optimizer 208 | env['scheduler'] = scheduler 209 | 210 | with open(args.save_path/args.log_summary, 'w') as csvfile: 211 | writer = csv.writer(csvfile, delimiter='\t') 212 | writer.writerow(['train_loss', 'validation_loss']) 213 | 214 | with open(args.save_path/args.log_full, 'w') as csvfile: 215 | writer = csv.writer(csvfile, delimiter='\t') 216 | writer.writerow(['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss']) 217 | 218 | logger = TermLogger(n_epochs=args.epochs, train_size=min(len(train_loader), args.epoch_size), valid_size=len(val_loader)) 219 | logger.epoch_bar.start() 220 | env['logger'] = logger 221 | 222 | env['args'] = args 223 | 224 | return env 225 | 226 | 227 | def launch_training(scheduler, **env): 228 | logger = env['logger'] 229 | args = env["args"] 230 | env['best_error'] = -1 231 | env['epoch'] = 0 232 | env['n_iter'] = 0 233 | 234 | if args.pretrained_depth or args.evaluate: 235 | validate(**env) 236 | 237 | for epoch in range(1, args.epochs + 1): 238 | env['epoch'] = epoch 239 | logger.epoch_bar.update(epoch) 240 | 241 | # train for one epoch 242 | train_loss, env['n_iter'] = train_one_epoch(**env) 243 | logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss)) 244 | scheduler.step() 245 | 246 | # evaluate on validation set 247 | error = validate(**env) 248 | 249 | env['best_error'] = finish_epoch(train_loss, error, **env) 250 | logger.epoch_bar.finish() 251 | 252 | 253 | def finish_epoch(train_loss, error, best_error, args, epoch, depth_net, pose_net, **env): 254 | if best_error < 0: 255 | best_error = error 256 | 257 | # remember lowest error and save checkpoint 258 | is_best = error < best_error 259 | best_error = min(best_error, error) 260 | save_checkpoint( 261 | args.save_path, { 262 | 'epoch': epoch, 263 | 'state_dict': depth_net.module.state_dict(), 264 | 
'bn': args.bn in ['depth', 'both'], 265 | 'nominal_displacement': args.nominal_displacement 266 | }, { 267 | 'epoch': epoch, 268 | 'bn': args.bn in ['pose', 'both'], 269 | 'state_dict': pose_net.module.state_dict() 270 | }, 271 | is_best) 272 | 273 | with open(args.save_path/args.log_summary, 'a') as csvfile: 274 | writer = csv.writer(csvfile, delimiter='\t') 275 | writer.writerow([train_loss, error]) 276 | return best_error 277 | 278 | 279 | def train_one_epoch(args, train_loader, 280 | depth_net, pose_net, optimizer, 281 | epoch, n_iter, 282 | logger, tb_writer, **env): 283 | global device 284 | logger.reset_train_bar() 285 | batch_time = AverageMeter() 286 | data_time = AverageMeter() 287 | losses = AverageMeter(precision=4) 288 | w1, w2, w3 = args.photo_loss_weight, args.smooth_loss_weight, args.ssim 289 | e1, e2 = args.training_milestones 290 | 291 | # switch to train mode 292 | depth_net.train() 293 | pose_net.train() 294 | 295 | end = time.time() 296 | logger.train_bar.update(0) 297 | 298 | for i, sample in enumerate(train_loader): 299 | 300 | log_losses = i > 0 and n_iter % args.print_freq == 0 301 | log_output = args.training_output_freq > 0 and n_iter % args.training_output_freq == 0 302 | 303 | # measure data loading time 304 | data_time.update(time.time() - end) 305 | imgs = torch.stack(sample['imgs'], dim=1).to(device) 306 | intrinsics = sample['intrinsics'].to(device) 307 | 308 | batch_size, seq = imgs.size()[:2] 309 | 310 | if args.network_input_size is not None: 311 | h,w = args.network_input_size 312 | downsample_imgs = F.interpolate(imgs,(3, h, w), mode='area') 313 | poses = pose_net(downsample_imgs) # [B, seq, 6] 314 | else: 315 | poses = pose_net(imgs) 316 | 317 | pose_matrices = pose_vec2mat(poses, args.rotation_mode) # [B, seq, 3, 4] 318 | 319 | total_indices = torch.arange(seq, dtype=torch.int64, device=device).expand(batch_size, seq) 320 | batch_range = torch.arange(batch_size, dtype=torch.int64, device=device) 321 | 322 | ''' for each element of the batch select a random picture in the sequence to 323 | which we will compute the depth, all poses are then converted so that pose of this 324 | very picture is exactly identity. At first this image is always in the middle of the sequence''' 325 | 326 | if epoch > e2: 327 | tgt_id = torch.randint(0, seq, (batch_size,), device=device) 328 | else: 329 | tgt_id = torch.full_like(batch_range, args.sequence_length//2) 330 | 331 | ref_ids = total_indices[total_indices != tgt_id.unsqueeze(1)].view(batch_size, seq - 1) 332 | 333 | ''' 334 | Select what other picture we are going to feed DepthNet, it must not be the same 335 | as tgt_id. 
At first, it's always first picture of the sequence, it is randomly chosen when first training milestone is reached 336 | ''' 337 | 338 | if epoch > e1: 339 | probs = torch.ones_like(total_indices, dtype=torch.float32) 340 | probs[batch_range, tgt_id] = args.same_ratio 341 | prior_id = torch.multinomial(probs, 1)[:,0] 342 | else: 343 | prior_id = torch.zeros_like(batch_range) 344 | 345 | # Treat the case of prior_id == tgt_id and the depth must be max_depth, regardless of apparent movement 346 | 347 | tgt_imgs = imgs[batch_range, tgt_id] # [B, 3, H, W] 348 | tgt_poses = pose_matrices[batch_range, tgt_id] # [B, 3, 4] 349 | 350 | prior_imgs = imgs[batch_range, prior_id] 351 | 352 | compensated_poses = compensate_pose(pose_matrices, tgt_poses) # [B, seq, 3, 4] tgt_poses are now neutral pose 353 | prior_poses = compensated_poses[batch_range, prior_id] # [B, 3, 4] 354 | 355 | if args.supervise_pose: 356 | from_GT = invert_mat(sample['pose']).to(device) 357 | compensated_GT_poses = compensate_pose(from_GT, from_GT[batch_range, tgt_id]) 358 | prior_GT_poses = compensated_GT_poses[batch_range, prior_id] 359 | prior_imgs_compensated = inverse_rotate(prior_imgs, prior_GT_poses[:,:,:-1], intrinsics) 360 | else: 361 | prior_imgs_compensated = inverse_rotate(prior_imgs, prior_poses[:,:,:-1], intrinsics) 362 | 363 | input_pair = torch.cat([prior_imgs_compensated, tgt_imgs], dim=1) # [B, 6, W, H] 364 | depth = depth_net(input_pair) 365 | 366 | # depth = [sample['depth'].to(device).unsqueeze(1) * 3 / abs(tgt_id[0] - prior_id[0])] 367 | # depth.append(torch.nn.functional.interpolate(depth[0], scale_factor=2)) 368 | disparities = [1/d for d in depth] 369 | 370 | predicted_magnitude = prior_poses[:, :, -1:].norm(p=2, dim=1, keepdim=True).unsqueeze(1) 371 | scale_factor = args.nominal_displacement / (predicted_magnitude + 1e-5) 372 | normalized_translation = compensated_poses[:, :, :, -1:] * scale_factor # [B, seq_length-1, 3] 373 | new_pose_matrices = torch.cat([compensated_poses[:, :, :, :-1], normalized_translation], dim=-1) 374 | 375 | biggest_scale = depth[0].size(-1) 376 | 377 | # Construct valid sequence to compute photometric error, 378 | # make the rest converge to max_depth because nothing moved 379 | vb = batch_range[prior_id != tgt_id] 380 | same_range = batch_range[prior_id == tgt_id] # batch of still pairs 381 | 382 | loss_1 = 0 383 | loss_1_same = 0 384 | for k, scaled_depth in enumerate(depth): 385 | size_ratio = scaled_depth.size(-1) / biggest_scale 386 | 387 | if len(same_range) > 0: 388 | # Frames are identical. The corresponding depth must be infinite. 
Here, we set it to max depth
389 |                 still_depth = scaled_depth[same_range]
390 |                 loss_same = F.smooth_l1_loss(still_depth/args.max_depth, torch.ones_like(still_depth))
391 |             else:
392 |                 loss_same = 0
393 | 
394 |             loss_valid, *to_log = photometric_reconstruction_loss(imgs[vb], tgt_id[vb], ref_ids[vb],
395 |                                                                   scaled_depth[vb], new_pose_matrices[vb],
396 |                                                                   intrinsics[vb],
397 |                                                                   args.rotation_mode,
398 |                                                                   ssim_weight=w3,
399 |                                                                   upsample=args.upscale)
400 | 
401 |             loss_1 += loss_valid * size_ratio
402 |             loss_1_same += loss_same * size_ratio
403 | 
404 |             if log_output and len(vb) > 0:
405 |                 log_output_tensorboard(tb_writer, "train", 0, k, n_iter,
406 |                                        scaled_depth[0], disparities[k][0],
407 |                                        *to_log)
408 |         loss_2 = grad_diffusion_loss(disparities, tgt_imgs, args.kappa)
409 | 
410 |         loss = w1*(loss_1 + loss_1_same) + w2*loss_2
411 |         if args.supervise_pose:
412 |             loss += (from_GT[:,:,:,:3] - pose_matrices[:,:,:,:3]).abs().mean()
413 | 
414 |         if log_losses:
415 |             tb_writer.add_scalar('photometric_error', loss_1.item(), n_iter)
416 |             tb_writer.add_scalar('disparity_smoothness_loss', loss_2.item(), n_iter)
417 |             tb_writer.add_scalar('total_loss', loss.item(), n_iter)
418 | 
419 |         if log_output and len(vb) > 0:
420 |             valid_poses = poses[vb]
421 |             nominal_translation_magnitude = valid_poses[:,-2,:3].norm(p=2, dim=-1)
422 |             # Log the translation magnitude relative to translation magnitude between last and penultimate frames
423 |             # for a perfectly constant displacement magnitude, you should get ratio of 2,3,4 and so forth.
424 |             # last pose is always identity and penultimate translation magnitude is always 1, so you don't need to log them
425 |             for j in range(args.sequence_length - 2):
426 |                 trans_mag = valid_poses[:,j,:3].norm(p=2, dim=-1)
427 |                 tb_writer.add_histogram('tr {}'.format(j),
428 |                                         (trans_mag/nominal_translation_magnitude).detach().cpu().numpy(),
429 |                                         n_iter)
430 |             for j in range(args.sequence_length - 1):
431 |                 # TODO log a better value : this is magnitude of vector (yaw, pitch, roll) which is not a physical value
432 |                 rot_mag = valid_poses[:,j,3:].norm(p=2, dim=-1)
433 |                 tb_writer.add_histogram('rot {}'.format(j),
434 |                                         rot_mag.detach().cpu().numpy(),
435 |                                         n_iter)
436 | 
437 |             tb_writer.add_image('train Input', tensor2array(tgt_imgs[0]), n_iter)
438 | 
439 |         # record loss for average meter
440 |         losses.update(loss.item(), args.batch_size)
441 | 
442 |         # compute gradient and do Adam step
443 |         optimizer.zero_grad()
444 |         loss.backward()
445 |         optimizer.step()
446 | 
447 |         # measure elapsed time
448 |         batch_time.update(time.time() - end)
449 |         end = time.time()
450 | 
451 |         with open(args.save_path/args.log_full, 'a') as csvfile:
452 |             writer = csv.writer(csvfile, delimiter='\t')
453 |             writer.writerow([loss.item(), loss_1.item(), loss_2.item()])
454 |         logger.train_bar.update(i+1)
455 |         if i % args.print_freq == 0:
456 |             logger.train_writer.write('Train: Time {} Data {} Loss {}'.format(batch_time, data_time, losses))
457 |         if i >= args.epoch_size - 1:
458 |             break
459 | 
460 |         n_iter += 1
461 | 
462 |     return losses.avg[0], n_iter
463 | 
464 | 
465 | @torch.no_grad()
466 | def validate(tb_writer, **env):
467 |     env['logger'].reset_valid_bar()
468 |     if env['args'].with_gt:
469 |         errors = validate_with_gt(tb_writer=tb_writer, **env)
470 |         errors_to_log = list(errors.items())[2:9]
471 |         decisive_error = errors["stab abs log"]
472 |     else:
473 |         errors = validate_without_gt(**env)
474 |         errors_to_log = errors.items()
475 |         decisive_error = errors["Total loss"]
476 | 
477 |     for name, error in errors.items():
478 | 
tb_writer.add_scalar(name, error, env['epoch']) 479 | 480 | error_string = ', '.join('{} : {:.3f}'.format(name, error) for name, error in errors_to_log) 481 | env['logger'].valid_writer.write(' * Avg {}'.format(error_string)) 482 | 483 | return decisive_error 484 | 485 | 486 | def validate_without_gt(args, val_loader, depth_net, pose_net, epoch, logger, tb_writer, sample_nb_to_log, **env): 487 | global device 488 | batch_time = AverageMeter() 489 | losses = AverageMeter(i=3, precision=4) 490 | w1, w2, w3 = args.photo_loss_weight, args.smooth_loss_weight, args.ssim 491 | if args.log_output: 492 | poses_values = np.zeros(((len(val_loader) - 1) * args.test_batch_size * (args.sequence_length-1),6)) 493 | disp_values = np.zeros(((len(val_loader) - 1) * args.test_batch_size * 3)) 494 | 495 | # switch to evaluate mode 496 | depth_net.eval() 497 | pose_net.eval() 498 | 499 | end = time.time() 500 | logger.valid_bar.update(0) 501 | 502 | for i, sample in enumerate(val_loader): 503 | log_output = i < sample_nb_to_log 504 | 505 | imgs = torch.stack(sample['imgs'], dim=1).to(device) 506 | intrinsics = sample['intrinsics'].to(device) 507 | 508 | if epoch == 1 and log_output: 509 | for j,img in enumerate(sample['imgs']): 510 | tb_writer.add_image('val Input/{}'.format(i), tensor2array(img[0]), j) 511 | 512 | batch_size, seq = imgs.size()[:2] 513 | poses = pose_net(imgs) 514 | pose_matrices = pose_vec2mat(poses, args.rotation_mode) # [B, seq, 3, 4] 515 | 516 | mid_index = (args.sequence_length - 1)//2 517 | 518 | tgt_imgs = imgs[:, mid_index] # [B, 3, H, W] 519 | tgt_poses = pose_matrices[:, mid_index] # [B, 3, 4] 520 | compensated_poses = compensate_pose(pose_matrices, tgt_poses) # [B, seq, 3, 4] tgt_poses are now neutral pose 521 | 522 | ref_ids = list(range(args.sequence_length)) 523 | ref_ids.remove(mid_index) 524 | 525 | loss_1 = 0 526 | loss_2 = 0 527 | 528 | for ref_index in ref_ids: 529 | prior_imgs = imgs[:, ref_index] 530 | prior_poses = compensated_poses[:, ref_index] # [B, 3, 4] 531 | 532 | prior_imgs_compensated = inverse_rotate(prior_imgs, prior_poses[:,:,:3], intrinsics) 533 | input_pair = torch.cat([prior_imgs_compensated, tgt_imgs], dim=1) # [B, 6, W, H] 534 | 535 | predicted_magnitude = prior_poses[:, :, -1:].norm(p=2, dim=1, keepdim=True).unsqueeze(1) # [B, 1, 1, 1] 536 | scale_factor = args.nominal_displacement / predicted_magnitude 537 | normalized_translation = compensated_poses[:, :, :, -1:] * scale_factor # [B, seq, 3, 1] 538 | new_pose_matrices = torch.cat([compensated_poses[:, :, :, :-1], normalized_translation], dim=-1) 539 | 540 | depth = depth_net(input_pair) 541 | disparity = 1/depth 542 | 543 | tgt_id = torch.full((batch_size,), ref_index, dtype=torch.int64, device=device) 544 | ref_ids_tensor = torch.tensor(ref_ids, dtype=torch.int64, device=device).expand(batch_size, -1) 545 | photo_loss, *to_log = photometric_reconstruction_loss(imgs, tgt_id, ref_ids_tensor, 546 | depth, new_pose_matrices, 547 | intrinsics, 548 | args.rotation_mode, 549 | ssim_weight=w3, upsample=args.upscale) 550 | 551 | loss_1 += photo_loss 552 | 553 | if log_output: 554 | log_output_tensorboard(tb_writer, "train", i, ref_index, epoch, 555 | depth[0], disparity[0], 556 | *to_log) 557 | 558 | loss_2 += grad_diffusion_loss(disparity, tgt_imgs, args.kappa) 559 | 560 | if args.log_output and i < len(val_loader)-1: 561 | step = args.test_batch_size * (args.sequence_length-1) 562 | poses_values[i * step:(i+1) * step] = poses[:, :-1].cpu().view(-1,6).numpy() 563 | step = args.test_batch_size * 3 564 | 
disp_unraveled = disparity.cpu().view(args.test_batch_size, -1) 565 | disp_values[i * step:(i+1) * step] = torch.cat([disp_unraveled.min(-1)[0], 566 | disp_unraveled.median(-1)[0], 567 | disp_unraveled.max(-1)[0]]).numpy() 568 | 569 | loss = w1*loss_1 + w2*loss_2 570 | losses.update([loss.item(), loss_1.item(), loss_2.item()]) 571 | 572 | # measure elapsed time 573 | batch_time.update(time.time() - end) 574 | end = time.time() 575 | logger.valid_bar.update(i+1) 576 | if i % args.print_freq == 0: 577 | logger.valid_writer.write('valid: Time {} Loss {}'.format(batch_time, losses)) 578 | 579 | if args.log_output: 580 | rot_coeffs = ['rx', 'ry', 'rz'] if args.rotation_mode == 'euler' else ['qx', 'qy', 'qz'] 581 | tr_coeffs = ['tx', 'ty', 'tz'] 582 | for k, (coeff_name) in enumerate(tr_coeffs + rot_coeffs): 583 | tb_writer.add_histogram('val poses_{}'.format(coeff_name), poses_values[:,k], epoch) 584 | tb_writer.add_histogram('disp_values', disp_values, epoch) 585 | logger.valid_bar.update(len(val_loader)) 586 | return OrderedDict(zip(['Total loss', 'Photo loss', 'Smooth loss'], losses.avg)) 587 | 588 | 589 | def validate_with_gt(args, val_loader, depth_net, pose_net, epoch, logger, tb_writer, sample_nb_to_log, **env): 590 | global device 591 | batch_time = AverageMeter() 592 | depth_error_names = ['abs diff', 'abs rel', 'abs log', 'a1', 'a2', 'a3'] 593 | stab_depth_errors = AverageMeter(i=len(depth_error_names)) 594 | unstab_depth_errors = AverageMeter(i=len(depth_error_names)) 595 | pose_error_names = ['Absolute Trajectory Error', 'Rotation Error'] 596 | pose_errors = AverageMeter(i=len(pose_error_names)) 597 | 598 | # switch to evaluate mode 599 | depth_net.eval() 600 | pose_net.eval() 601 | 602 | end = time.time() 603 | logger.valid_bar.update(0) 604 | for i, sample in enumerate(val_loader): 605 | log_output = i < sample_nb_to_log 606 | 607 | imgs = torch.stack(sample['imgs'], dim=1).to(device) 608 | 609 | intrinsics = sample['intrinsics'].to(device) 610 | 611 | GT_depth = sample['depth'].to(device) 612 | GT_pose = sample['pose'].to(device) 613 | 614 | batch_size, seq, c, h, w = imgs.shape 615 | dh, dw = GT_depth.shape[-2:] 616 | 617 | mid_index = (args.sequence_length - 1)//2 618 | 619 | tgt_img = imgs[:,mid_index] 620 | 621 | if epoch == 1 and log_output: 622 | for j,img in enumerate(sample['imgs']): 623 | tb_writer.add_image('val Input/{}'.format(i), tensor2array(img[0]), j) 624 | depth_to_show = GT_depth[0].cpu() 625 | # KITTI Like data routine to discard invalid data 626 | depth_to_show[depth_to_show == 0] = 1000 627 | disp_to_show = (1/depth_to_show).clamp(0,10) 628 | tb_writer.add_image('val target Disparity Normalized/{}'.format(i), 629 | tensor2array(disp_to_show, max_value=None, colormap='magma'), 630 | epoch) 631 | 632 | poses = pose_net(imgs) 633 | pose_matrices = pose_vec2mat(poses, args.rotation_mode) # [B, seq, 3, 4] 634 | inverted_pose_matrices = invert_mat(pose_matrices) 635 | ATE, RE = compute_pose_error(GT_pose[:,:-1], inverted_pose_matrices[:,:-1]) 636 | pose_errors.update([ATE.item(), RE.item()]) 637 | 638 | tgt_poses = pose_matrices[:, mid_index] # [B, 3, 4] 639 | compensated_predicted_poses = compensate_pose(pose_matrices, tgt_poses) 640 | compensated_GT_poses = compensate_pose(GT_pose, GT_pose[:,mid_index]) 641 | 642 | for j in range(args.sequence_length): 643 | if j == mid_index: 644 | if log_output and epoch == 1: 645 | tb_writer.add_image('val Input Stabilized/{}'.format(i), tensor2array(sample['imgs'][j][0]), j) 646 | continue 647 | 648 | '''compute displacement 
magnitude for each element of batch, and rescale 649 | depth accordingly.''' 650 | 651 | prior_img = imgs[:,j] 652 | displacement = compensated_GT_poses[:, j, :, -1] # [B,3] 653 | displacement_magnitude = displacement.norm(p=2, dim=1) # [B] 654 | current_GT_depth = (GT_depth * args.nominal_displacement / displacement_magnitude.view(-1, 1, 1)) 655 | 656 | prior_predicted_pose = compensated_predicted_poses[:, j] # [B, 3, 4] 657 | prior_GT_pose = compensated_GT_poses[:, j] 658 | 659 | prior_predicted_rot = prior_predicted_pose[:,:,:-1] 660 | prior_GT_rot = prior_GT_pose[:,:,:-1].transpose(1,2) 661 | 662 | prior_compensated_from_GT = inverse_rotate(prior_img, 663 | prior_GT_rot, 664 | intrinsics) 665 | if log_output and epoch == 1: 666 | depth_to_show = current_GT_depth[0] 667 | tb_writer.add_image('val target Depth {}/{}'.format(j, i), tensor2array(depth_to_show, max_value=args.max_depth), epoch) 668 | tb_writer.add_image('val Input Stabilized/{}'.format(i), tensor2array(prior_compensated_from_GT[0]), j) 669 | 670 | prior_compensated_from_prediction = inverse_rotate(prior_img, prior_predicted_rot, intrinsics) 671 | predicted_input_pair = torch.cat([prior_compensated_from_prediction, tgt_img], dim=1) # [B, 6, W, H] 672 | GT_input_pair = torch.cat([prior_compensated_from_GT, tgt_img], dim=1) # [B, 6, W, H] 673 | 674 | # This is the depth from footage stabilized with GT pose, it should be better than depth from raw footage without any GT info 675 | raw_depth_stab = depth_net(GT_input_pair) 676 | raw_depth_unstab = depth_net(predicted_input_pair) 677 | 678 | # Upsample depth so that it matches GT size 679 | depth_stab = F.interpolate(raw_depth_stab, (dh, dw), mode='bilinear', align_corners=False) 680 | depth_unstab = F.interpolate(raw_depth_unstab, (dh, dw), mode='bilinear', align_corners=False) 681 | 682 | for k, depth in enumerate([depth_stab, depth_unstab]): 683 | disparity = 1/depth 684 | errors = stab_depth_errors if k == 0 else unstab_depth_errors 685 | errors.update(compute_depth_errors(current_GT_depth, depth, crop=True, max_depth=args.max_depth)) 686 | if log_output: 687 | prefix = 'stabilized' if k == 0 else 'unstabilized' 688 | tb_writer.add_image('val {} Dispnet Output Normalized {}/{}'.format(prefix, j, i), 689 | tensor2array(disparity[0],max_value=None, colormap='magma'), 690 | epoch) 691 | tb_writer.add_image('val {} Depth Output {}/{}'.format(prefix, j, i), 692 | tensor2array(depth[0], max_value=args.max_depth), 693 | epoch) 694 | 695 | # measure elapsed time 696 | batch_time.update(time.time() - end) 697 | end = time.time() 698 | logger.valid_bar.update(i+1) 699 | if i % args.print_freq == 0: 700 | logger.valid_writer.write( 701 | 'valid: Time {} ATE Error {:.4f} ({:.4f}), Unstab Rel Abs Error {:.4f} ({:.4f})'.format( 702 | batch_time, pose_errors.val[0], pose_errors.avg[0], 703 | unstab_depth_errors.val[1], unstab_depth_errors.avg[1]) 704 | ) 705 | logger.valid_bar.update(len(val_loader)) 706 | 707 | errors = (*pose_errors.avg, 708 | *unstab_depth_errors.avg, 709 | *stab_depth_errors.avg) 710 | error_names = (*pose_error_names, 711 | *['unstab {}'.format(e) for e in depth_error_names], 712 | *['stab {}'.format(e) for e in depth_error_names]) 713 | 714 | return OrderedDict(zip(error_names, errors)) 715 | 716 | 717 | if __name__ == '__main__': 718 | main() 719 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import 
shutil 3 | import numpy as np 4 | import torch 5 | from path import Path 6 | import datetime 7 | from collections import OrderedDict 8 | from matplotlib import cm 9 | from matplotlib.colors import ListedColormap, LinearSegmentedColormap 10 | 11 | 12 | def high_res_colormap(low_res_cmap, resolution=1000, max_value=1): 13 | # Construct the list colormap, with interpolated values for higer resolution 14 | # For a linear segmented colormap, you can just specify the number of point in 15 | # cm.get_cmap(name, lutsize) with the parameter lutsize 16 | x = np.linspace(0,1,low_res_cmap.N) 17 | low_res = low_res_cmap(x) 18 | new_x = np.linspace(0,max_value,resolution) 19 | high_res = np.stack([np.interp(new_x, x, low_res[:,i]) for i in range(low_res.shape[1])], axis=1) 20 | return ListedColormap(high_res) 21 | 22 | 23 | def opencv_rainbow(resolution=1000): 24 | # Construct the opencv equivalent of Rainbow 25 | opencv_rainbow_data = ( 26 | (0.000, (1.00, 0.00, 0.00)), 27 | (0.400, (1.00, 1.00, 0.00)), 28 | (0.600, (0.00, 1.00, 0.00)), 29 | (0.800, (0.00, 0.00, 1.00)), 30 | (1.000, (0.60, 0.00, 1.00)) 31 | ) 32 | 33 | return LinearSegmentedColormap.from_list('opencv_rainbow', opencv_rainbow_data, resolution) 34 | 35 | 36 | COLORMAPS = {'rainbow':opencv_rainbow(), 37 | 'magma':high_res_colormap(cm.get_cmap('magma'))} 38 | 39 | 40 | def save_path_formatter(args, parser): 41 | def is_default(key, value): 42 | return value == parser.get_default(key) 43 | args_dict = vars(args) 44 | data_folder_name = str(Path(args_dict['data']).normpath().name) 45 | folder_string = [data_folder_name] 46 | if not is_default('epochs', args_dict['epochs']): 47 | folder_string.append('{}epochs'.format(args_dict['epochs'])) 48 | keys_with_prefix = OrderedDict() 49 | keys_with_prefix['training_milestones'] = 'mls' 50 | keys_with_prefix['epoch_size'] = 'epoch_size' 51 | keys_with_prefix['sequence_length'] = 'seq' 52 | keys_with_prefix['rotation_mode'] = 'rot_' 53 | keys_with_prefix['batch_size'] = 'b' 54 | keys_with_prefix['lr'] = 'lr' 55 | keys_with_prefix['weight_decay'] = 'wd' 56 | keys_with_prefix['photo_loss_weight'] = 'p' 57 | keys_with_prefix['smooth_loss_weight'] = 's' 58 | keys_with_prefix['nominal_displacement'] = 'nd' 59 | 60 | for key, prefix in keys_with_prefix.items(): 61 | value = args_dict[key] 62 | if not is_default(key, value): 63 | if isinstance(value, list): 64 | value = ','.join(str(v) for v in value) 65 | folder_string.append('{}{}'.format(prefix, value)) 66 | save_path = Path(','.join(folder_string)) 67 | timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M") 68 | return save_path/timestamp 69 | 70 | 71 | def tensor2array(tensor, max_value=255, colormap='rainbow'): 72 | tensor = tensor.detach().cpu() 73 | if max_value is None: 74 | max_value = tensor.max().item() 75 | if tensor.ndimension() == 2 or tensor.size(0) == 1: 76 | norm_array = tensor.squeeze().numpy()/max_value 77 | array = COLORMAPS[colormap](norm_array).astype(np.float32)[:,:,:3] 78 | array = array.transpose(2, 0, 1) 79 | 80 | elif tensor.ndimension() == 3: 81 | assert(tensor.size(0) == 3) 82 | array = 0.5 + tensor.numpy()*0.5 83 | return array 84 | 85 | 86 | def log_output_tensorboard(writer, prefix, index, suffix, n_iter, depth, disp, warped, diff, dssim, valid): 87 | disp_to_show = tensor2array(disp[0], max_value=None, colormap='magma') 88 | depth_to_show = tensor2array(depth[0], max_value=None) 89 | writer.add_image('{} Dispnet Output Normalized {}/{}'.format(prefix, suffix, index), disp_to_show, n_iter) 90 | writer.add_image('{} Depth 
Output {}/{}'.format(prefix, suffix, index), depth_to_show, n_iter) 91 | # log warped images along with explainability mask 92 | for j, (warped_j, diff_j, dssim_j, valid_j) in enumerate(zip(warped, diff, dssim, valid)): 93 | whole_suffix = '{} {}/{}'.format(suffix, j, index) 94 | warped_to_show = tensor2array(warped_j * valid_j.to(warped_j)) 95 | diff_to_show = tensor2array(0.5*diff_j) 96 | dssim_to_show = tensor2array(2*dssim_j - 1) 97 | writer.add_image('{} Warped Outputs {}'.format(prefix, whole_suffix), warped_to_show, n_iter) 98 | writer.add_image('{} Diff Outputs {}'.format(prefix, whole_suffix), diff_to_show, n_iter) 99 | writer.add_image('{} DSSIM Outputs {}'.format(prefix, whole_suffix), dssim_to_show, n_iter) 100 | 101 | 102 | def save_checkpoint(save_path, depthnet_state, posenet_state, is_best, filename='checkpoint.pth.tar'): 103 | file_prefixes = ['depthnet', 'posenet'] 104 | states = [depthnet_state, posenet_state] 105 | for (prefix, state) in zip(file_prefixes, states): 106 | torch.save(state, save_path/'{}_{}'.format(prefix,filename)) 107 | 108 | if is_best: 109 | for prefix in file_prefixes: 110 | shutil.copyfile(save_path/'{}_{}'.format(prefix,filename), save_path/'{}_model_best.pth.tar'.format(prefix)) --------------------------------------------------------------------------------
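A closing usage note, not from the repository itself: a minimal sketch (assuming `pred_depth` is a single [H, W] depth tensor predicted by DepthNet, as in test_depth.py) of how `tensor2array` from utils.py and imageio combine to write a colorized depth map to disk, mirroring the save() helper of log_result:

```python
import imageio
import numpy as np
from utils import tensor2array

# tensor2array divides by max_value (here 100), applies the rainbow colormap and
# returns a float32 CHW array in [0, 1]; convert to HWC uint8 before writing.
colored = tensor2array(pred_depth, max_value=100)
imageio.imsave('depth_pred.jpg', (255 * colored.transpose(1, 2, 0)).astype(np.uint8))
```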