├── LICENSE ├── Poster_ICRA2019.pdf ├── README.md ├── custom_transforms.py ├── data ├── cityscapes_loader.py ├── kitti_raw_loader.py ├── prepare_train_data.py ├── static_frames.txt └── test_scenes.txt ├── datasets ├── general_sequence_folders.py ├── sequence_folders.py ├── stacked_sequence_folders.py ├── validation_flow.py ├── validation_flow_video.py └── validation_folders.py ├── flowutils ├── flow_io.py ├── flow_viz.py ├── flowlib.py └── pfm.py ├── inverse_warp_summary.py ├── kitti_eval ├── depth_evaluation_utils.py ├── pose_evaluation_utils.py ├── test_files_eigen.txt └── validation_flow_video.py ├── logger.py ├── loss_functions_summary.py ├── models ├── DispNetS.py ├── DispNetS6.py ├── DispResNet6.py ├── DispResNetS6.py ├── FlowNetC6.py ├── MaskNet6.py ├── MaskResNet6.py ├── PoseExpNet.py ├── PoseNet6.py ├── PoseNetB6.py ├── __init__.py ├── back2future.py ├── submodules.py └── utils.py ├── requirements.txt ├── ssim.py ├── stillbox_eval ├── depth_evaluation_utils.py └── test_files_90.txt ├── test_disp.py ├── test_flow.py ├── test_pose.py ├── train.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Guangming Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Poster_ICRA2019.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guangmingw/DOPlearning/cd7616205410142fceee6eac9b802dc42774a259/Poster_ICRA2019.pdf -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unsupervised Learning of Depth, Optical Flow and Pose with Occlusion from 3D Geometry 2 | 3 | Code for the papers: 4 | 5 | *G. Wang, H. Wang, Y. Liu, and W. Chen,* [**Unsupervised Learning of Monocular Depth and Ego-Motion Using Multiple Masks**](https://ieeexplore.ieee.org/abstract/document/8793622), in International Conference on Robotics and Automation, pp. 4724-4730, 2019. [[Poster]](https://github.com/guangmingw/DOPlearning/blob/master/Poster_ICRA2019.pdf) 6 | 7 | *G. Wang, C. Zhang, H. Wang, J. Wang, Y. Wang, and X. 
Wang,* [**Unsupervised Learning of Depth, Optical Flow and Pose with Occlusion from 3D Geometry**](https://ieeexplore.ieee.org/document/9152137), in IEEE Transactions on Intelligent Transportation Systems, doi: 10.1109/TITS.2020.3010418. 8 | 9 | ## Prerequisites 10 | 11 | Python 3 and PyTorch are required. Install the remaining dependencies by running: 12 | ``` 13 | pip3 install -r requirements.txt 14 | ``` 15 | 16 | ## Preparing training data 17 | 18 | #### KITTI 19 | For [KITTI](http://www.cvlibs.net/datasets/kitti/raw_data.php), first download the dataset using this [script](http://www.cvlibs.net/download.php?file=raw_data_downloader.zip) provided on the official website, and then run the following command. 20 | 21 | ```bash 22 | python3 data/prepare_train_data.py /path/to/raw/kitti/dataset/ --dataset-format 'kitti' --dump-root /path/to/resulting/formatted/data/ --width 832 --height 256 --num-threads 1 --static-frames data/static_frames.txt --with-gt 23 | ``` 24 | 25 | For evaluating optical flow against ground truth on KITTI, download the [KITTI2015](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) dataset. You need to download 1) the `stereo 2015/flow 2015/scene flow 2015` data set (2 GB), 2) the `multi-view extension` (14 GB), and 3) the `calibration files` (1 MB). You should have the following directory structure: 26 | ``` 27 | kitti2015 28 | | data_scene_flow 29 | | data_scene_flow_calib 30 | | data_scene_flow_multiview 31 | ``` 32 | 33 | #### Cityscapes 34 | 35 | For [Cityscapes](https://www.cityscapes-dataset.com/), download the following packages: 1) `leftImg8bit_sequence_trainvaltest.zip`, 2) `camera_trainvaltest.zip`. You will probably need to contact the administrators to obtain them. 36 | 37 | ```bash 38 | python3 data/prepare_train_data.py /path/to/cityscapes/dataset/ --dataset-format 'cityscapes' --dump-root /path/to/resulting/formatted/data/ --width 832 --height 342 --num-threads 1 39 | ``` 40 | 41 | Notice that for Cityscapes `img_height` is set to 342 because the bottom part of the image, which contains the car logo, is cropped off, leaving images of height 256 (342 × 0.75 ≈ 256). 42 | 43 | ## Training 44 | 45 | ``` 46 | python3 train.py /path/to/prepared/data \ 47 | --dispnet DispResNetS6 --posenet PoseNetB6 --flownet Back2Future \ 48 | -b 4 -pc 1.0 -pf 0.0 -m 0.0 -c 0.0 -s 0.2 \ 49 | --epoch-size 100 --log-output -f 30 --nlevels 6 --lr 1e-4 -wssim 0.85 --epochs 4000 \ 50 | --smoothness-type edgeaware --fix-masknet --fix-flownet --with-depth-gt --log-terminal \ 51 | --spatial-normalize-max --workers 8 --kitti-dir /data/to/kitti --add-less-than-mean-mask \ 52 | --add-maskp01 --using-none-mask --name demo \ 53 | --pretrained-disp /path/to/disp/model \ 54 | --pretrained-pose /path/to/pose/model 55 | ``` 56 | 57 | TensorBoard can be opened with the command: 58 | ``` 59 | tensorboard --logdir=./ 60 | ``` 61 | and the training progress can be monitored by opening http://localhost:6006 in your browser.
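
Before launching a long training run, it can also help to sanity-check the prepared data with the repository's own loader. Below is a minimal sketch, assuming the data was prepared at 832×256 as above; the path and the normalization values are illustrative, not prescribed by the repo:

```python
# Sanity-check one training sample using the repo's SequenceFolder loader
# (run from the repository root; /path/to/prepared/data is a placeholder).
import custom_transforms
from datasets.sequence_folders import SequenceFolder

train_transform = custom_transforms.Compose([
    custom_transforms.RandomHorizontalFlip(),
    custom_transforms.RandomScaleCrop(),
    custom_transforms.ArrayToTensor(),
    custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # assumed values
])
train_set = SequenceFolder('/path/to/prepared/data', seed=0, train=True,
                           sequence_length=3, transform=train_transform)
tgt_img, ref_imgs, intrinsics, intrinsics_inv = train_set[0]
print(tgt_img.shape, len(ref_imgs), intrinsics.shape)  # e.g. torch.Size([3, 256, 832]) 2 (3, 3)
```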
62 | 63 | ## Evaluation 64 | 65 | #### Disparity 66 | 67 | ``` 68 | python3 test_disp.py --dispnet DispResNetS6 --pretrained-dispnet /path/to/dispnet --pretrained-posenet /path/to/posenet --dataset-dir /path/to/KITTI_raw --dataset-list /path/to/test_files_list 69 | ``` 70 | 71 | #### Pose 72 | 73 | ``` 74 | python test_pose.py pretrained/pose_model_best.pth.tar --img-width 832 --img-height 256 --dataset-dir /path/to/kitti/odometry/ --sequences 09 --posenet PoseNetB6 75 | ``` 76 | 77 | 78 | #### Optical Flow 79 | 80 | ``` 81 | python test_flow.py --pretrained-disp /path/to/dispnet --pretrained-pose /path/to/posenet --pretrained-mask /path/to/masknet --pretrained-flow /path/to/flownet --kitti-dir /path/to/kitti2015/dataset 82 | ``` 83 | 84 | ## Downloads 85 | #### Pretrained Models 86 | - [DispNet, PoseNet, and FlowNet](https://jbox.sjtu.edu.cn/l/6uq1SX) trained jointly for unsupervised learning of depth, pose and optical flow. 87 | 88 | 89 | ## Acknowledgements 90 | We are grateful to Anurag Ranjan for his [GitHub repository](https://github.com/anuragranj/cc). Our code is based on it. 91 | 92 | ## References 93 | 94 | *G. Wang, H. Wang, Y. Liu, and W. Chen,* **Unsupervised Learning of Monocular Depth and Ego-Motion Using Multiple Masks**, in International Conference on Robotics and Automation, pp. 4724-4730, 2019. 95 | 96 | *G. Wang, C. Zhang, H. Wang, J. Wang, Y. Wang, and X. Wang,* **Unsupervised Learning of Depth, Optical Flow and Pose with Occlusion from 3D Geometry**, in IEEE Transactions on Intelligent Transportation Systems, doi: 10.1109/TITS.2020.3010418. 97 | -------------------------------------------------------------------------------- /custom_transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import random 4 | import numpy as np 5 | from scipy.misc import imresize, imrotate 6 | 7 | '''Set of random transform routines that take a list of images (plus an intrinsics matrix) as arguments, 8 | so that the same random transformation is applied coherently to every input.''' 9 | 10 | 11 | class Compose(object): 12 | def __init__(self, transforms): 13 | self.transforms = transforms 14 | 15 | def __call__(self, images, intrinsics): 16 | for t in self.transforms: 17 | images, intrinsics = t(images, intrinsics) 18 | return images, intrinsics 19 | 20 | 21 | class Normalize(object): 22 | def __init__(self, mean, std): 23 | self.mean = mean 24 | self.std = std 25 | 26 | def __call__(self, images, intrinsics): 27 | for tensor in images: 28 | for t, m, s in zip(tensor, self.mean, self.std): 29 | t.sub_(m).div_(s) 30 | return images, intrinsics 31 | 32 | 33 | class NormalizeLocally(object): 34 | 35 | def __call__(self, images, intrinsics): 36 | image_tensor = torch.stack(images) 37 | assert(image_tensor.size(1)==3) # 3-channel images 38 | mean = image_tensor.transpose(0,1).contiguous().view(3, -1).mean(1) 39 | std = image_tensor.transpose(0,1).contiguous().view(3, -1).std(1) 40 | 41 | for tensor in images: 42 | for t, m, s in zip(tensor, mean, std): 43 | t.sub_(m).div_(s) 44 | return images, intrinsics 45 | 46 | 47 | class ArrayToTensor(object): 48 | """Converts a list of numpy.ndarray images (H x W x C) into a list of torch.FloatTensor of shape (C x H x W) scaled to [0, 1]; the intrinsics matrix is passed through unchanged.""" 49 | 50 | def __call__(self, images, intrinsics): 51 | tensors = [] 52 | for im in images: 53 | # put it from HWC to CHW format 54 | im = np.transpose(im, (2, 0, 1)) 55 | # handle numpy array 56 | tensors.append(torch.from_numpy(im).float()/255) 57 | 
return tensors, intrinsics 58 | 59 | 60 | class RandomHorizontalFlip(object): 61 | """Randomly horizontally flips the given numpy array with a probability of 0.5""" 62 | 63 | def __call__(self, images, intrinsics): 64 | assert intrinsics is not None 65 | if random.random() < 0.5: 66 | output_intrinsics = np.copy(intrinsics) 67 | output_images = [np.copy(np.fliplr(im)) for im in images] 68 | w = output_images[0].shape[1] 69 | output_intrinsics[0,2] = w - output_intrinsics[0,2] 70 | else: 71 | output_images = images 72 | output_intrinsics = intrinsics 73 | return output_images, output_intrinsics 74 | 75 | class RandomRotate(object): 76 | """Randomly rotates images up to 10 degrees and crop them to keep same size as before.""" 77 | def __call__(self, images, intrinsics): 78 | if np.random.random() > 0.5: 79 | return images, intrinsics 80 | else: 81 | assert intrinsics is not None 82 | rot = np.random.uniform(0,10) 83 | rotated_images = [imrotate(im, rot) for im in images] 84 | 85 | return rotated_images, intrinsics 86 | 87 | 88 | 89 | 90 | class RandomScaleCrop(object): 91 | """Randomly zooms images up to 15% and crop them to keep same size as before.""" 92 | def __init__(self, h=0, w=0): 93 | self.h = h 94 | self.w = w 95 | 96 | def __call__(self, images, intrinsics): 97 | assert intrinsics is not None 98 | output_intrinsics = np.copy(intrinsics) 99 | 100 | in_h, in_w, _ = images[0].shape 101 | x_scaling, y_scaling = np.random.uniform(1,1.1,2) 102 | scaled_h, scaled_w = int(in_h * y_scaling), int(in_w * x_scaling) 103 | 104 | output_intrinsics[0] *= x_scaling 105 | output_intrinsics[1] *= y_scaling 106 | scaled_images = [imresize(im, (scaled_h, scaled_w)) for im in images] 107 | 108 | if self.h and self.w: 109 | in_h, in_w = self.h, self.w 110 | 111 | offset_y = np.random.randint(scaled_h - in_h + 1) 112 | offset_x = np.random.randint(scaled_w - in_w + 1) 113 | cropped_images = [im[offset_y:offset_y + in_h, offset_x:offset_x + in_w] for im in scaled_images] 114 | 115 | output_intrinsics[0,2] -= offset_x 116 | output_intrinsics[1,2] -= offset_y 117 | 118 | return cropped_images, output_intrinsics 119 | 120 | class Scale(object): 121 | """Scales images to a particular size""" 122 | def __init__(self, h, w): 123 | self.h = h 124 | self.w = w 125 | 126 | def __call__(self, images, intrinsics): 127 | assert intrinsics is not None 128 | output_intrinsics = np.copy(intrinsics) 129 | 130 | in_h, in_w, _ = images[0].shape 131 | scaled_h, scaled_w = self.h , self.w 132 | 133 | output_intrinsics[0] *= (scaled_w / in_w) 134 | output_intrinsics[1] *= (scaled_h / in_h) 135 | scaled_images = [imresize(im, (scaled_h, scaled_w)) for im in images] 136 | 137 | return scaled_images, output_intrinsics 138 | -------------------------------------------------------------------------------- /data/cityscapes_loader.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import json 3 | import numpy as np 4 | import scipy.misc 5 | from path import Path 6 | from tqdm import tqdm 7 | 8 | 9 | class cityscapes_loader(object): 10 | def __init__(self, 11 | dataset_dir, 12 | split='train', 13 | crop_bottom=True, # Get rid of the car logo 14 | img_height=171, 15 | img_width=416): 16 | self.dataset_dir = Path(dataset_dir) 17 | self.split = split 18 | # Crop out the bottom 25% of the image to remove the car logo 19 | self.crop_bottom = crop_bottom 20 | self.img_height = img_height 21 | self.img_width = img_width 22 | self.min_speed = 2 23 | self.scenes = 
(self.dataset_dir/'leftImg8bit_sequence'/split).dirs() 24 | print('Total scenes collected: {}'.format(len(self.scenes))) 25 | 26 | def collect_scenes(self, city): 27 | img_files = sorted(city.files('*.png')) 28 | scenes = {} 29 | connex_scenes = {} 30 | connex_scene_data_list = [] 31 | for f in img_files: 32 | scene_id,frame_id = f.basename().split('_')[1:3] 33 | if scene_id not in scenes.keys(): 34 | scenes[scene_id] = [] 35 | scenes[scene_id].append(frame_id) 36 | 37 | # divide scenes into connexe sequences 38 | for scene_id in scenes.keys(): 39 | previous = None 40 | connex_scenes[scene_id] = [] 41 | for id in scenes[scene_id]: 42 | if previous is None or int(id) - int(previous) > 1: 43 | current_list = [] 44 | connex_scenes[scene_id].append(current_list) 45 | current_list.append(id) 46 | previous = id 47 | 48 | # create scene data dicts, and subsample scene every two frames 49 | for scene_id in connex_scenes.keys(): 50 | intrinsics = self.load_intrinsics(city, scene_id) 51 | for subscene in connex_scenes[scene_id]: 52 | frame_speeds = [self.load_speed(city, scene_id, frame_id) for frame_id in subscene] 53 | connex_scene_data_list.append({'city':city, 54 | 'scene_id': scene_id, 55 | 'rel_path': city.basename()+'_'+scene_id+'_'+subscene[0]+'_0', 56 | 'intrinsics': intrinsics, 57 | 'frame_ids':subscene[0::2], 58 | 'speeds':frame_speeds[0::2]}) 59 | connex_scene_data_list.append({'city':city, 60 | 'scene_id': scene_id, 61 | 'rel_path': city.basename()+'_'+scene_id+'_'+subscene[0]+'_1', 62 | 'intrinsics': intrinsics, 63 | 'frame_ids': subscene[1::2], 64 | 'speeds': frame_speeds[1::2]}) 65 | return connex_scene_data_list 66 | 67 | def load_intrinsics(self, city, scene_id): 68 | city_name = city.basename() 69 | camera_folder = self.dataset_dir/'camera'/self.split/city_name 70 | camera_file = camera_folder.files('{}_{}_*_camera.json'.format(city_name, scene_id))[0] 71 | frame_id = camera_file.split('_')[2] 72 | frame_path = city/'{}_{}_{}_leftImg8bit.png'.format(city_name, scene_id, frame_id) 73 | 74 | with open(camera_file, 'r') as f: 75 | camera = json.load(f) 76 | fx = camera['intrinsic']['fx'] 77 | fy = camera['intrinsic']['fy'] 78 | u0 = camera['intrinsic']['u0'] 79 | v0 = camera['intrinsic']['v0'] 80 | intrinsics = np.array([[fx, 0, u0], 81 | [0, fy, v0], 82 | [0, 0, 1]]) 83 | 84 | img = scipy.misc.imread(frame_path) 85 | h,w,_ = img.shape 86 | zoom_y = self.img_height/h 87 | zoom_x = self.img_width/w 88 | 89 | intrinsics[0] *= zoom_x 90 | intrinsics[1] *= zoom_y 91 | return intrinsics 92 | 93 | def load_speed(self, city, scene_id, frame_id): 94 | city_name = city.basename() 95 | vehicle_folder = self.dataset_dir/'vehicle_sequence'/self.split/city_name 96 | vehicle_file = vehicle_folder/'{}_{}_{}_vehicle.json'.format(city_name, scene_id, frame_id) 97 | with open(vehicle_file, 'r') as f: 98 | vehicle = json.load(f) 99 | return vehicle['speed'] 100 | 101 | def get_scene_imgs(self, scene_data): 102 | cum_speed = np.zeros(3) 103 | print(scene_data['city'].basename(), scene_data['scene_id'], scene_data['frame_ids'][0]) 104 | for i,frame_id in enumerate(scene_data['frame_ids']): 105 | cum_speed += scene_data['speeds'][i] 106 | speed_mag = np.linalg.norm(cum_speed) 107 | if speed_mag > self.min_speed: 108 | yield self.load_image(scene_data['city'], scene_data['scene_id'], frame_id), frame_id 109 | cum_speed *= 0 110 | 111 | def load_image(self, city, scene_id, frame_id): 112 | img_file = city/'{}_{}_{}_leftImg8bit.png'.format(city.basename(), 113 | scene_id, 114 | frame_id) 115 | if not 
img_file.isfile(): 116 | return None 117 | img = scipy.misc.imread(img_file) 118 | img = scipy.misc.imresize(img, (self.img_height, self.img_width))[:int(self.img_height*0.75)] 119 | return img 120 | -------------------------------------------------------------------------------- /data/kitti_raw_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from path import Path 3 | import scipy.misc 4 | from collections import Counter 5 | 6 | 7 | class KittiRawLoader(object): 8 | def __init__(self, 9 | dataset_dir, 10 | static_frames_file=None, 11 | img_height=128, 12 | img_width=416, 13 | min_speed=2, 14 | get_gt=False): 15 | dir_path = Path(__file__).realpath().dirname() 16 | test_scene_file = dir_path/'test_scenes.txt' 17 | 18 | self.from_speed = static_frames_file is None 19 | if static_frames_file is not None: 20 | static_frames_file = Path(static_frames_file) 21 | self.collect_static_frames(static_frames_file) 22 | 23 | with open(test_scene_file, 'r') as f: 24 | test_scenes = f.readlines() 25 | self.test_scenes = [t[:-1] for t in test_scenes] 26 | self.dataset_dir = Path(dataset_dir) 27 | self.img_height = img_height 28 | self.img_width = img_width 29 | self.cam_ids = ['02', '03'] 30 | self.date_list = ['2011_09_26', '2011_09_28', '2011_09_29', '2011_09_30', '2011_10_03'] 31 | self.min_speed = min_speed 32 | self.get_gt = get_gt 33 | self.collect_train_folders() 34 | 35 | def collect_static_frames(self, static_frames_file): 36 | with open(static_frames_file, 'r') as f: 37 | frames = f.readlines() 38 | self.static_frames = {} 39 | for fr in frames: 40 | if fr == '\n': 41 | continue 42 | date, drive, frame_id = fr.split(' ') 43 | curr_fid = '%.10d' % (np.int(frame_id[:-1])) 44 | if drive not in self.static_frames.keys(): 45 | self.static_frames[drive] = [] 46 | self.static_frames[drive].append(curr_fid) 47 | 48 | def collect_train_folders(self): 49 | self.scenes = [] 50 | for date in self.date_list: 51 | drive_set = (self.dataset_dir/date).dirs() 52 | for dr in drive_set: 53 | if dr.name[:-5] not in self.test_scenes: 54 | self.scenes.append(dr) 55 | 56 | def collect_scenes(self, drive): 57 | train_scenes = [] 58 | for c in self.cam_ids: 59 | oxts = sorted((drive/'oxts'/'data').files('*.txt')) 60 | scene_data = {'cid': c, 'dir': drive, 'speed': [], 'frame_id': [], 'rel_path': drive.name + '_' + c} 61 | for n, f in enumerate(oxts): 62 | metadata = np.genfromtxt(f) 63 | speed = metadata[8:11] 64 | scene_data['speed'].append(speed) 65 | scene_data['frame_id'].append('{:010d}'.format(n)) 66 | sample = self.load_image(scene_data, 0) 67 | if sample is None: 68 | return [] 69 | scene_data['P_rect'] = self.get_P_rect(scene_data, sample[1], sample[2]) 70 | scene_data['intrinsics'] = scene_data['P_rect'][:,:3] 71 | 72 | train_scenes.append(scene_data) 73 | return train_scenes 74 | 75 | def get_scene_imgs(self, scene_data): 76 | def construct_sample(scene_data, i, frame_id): 77 | sample = [self.load_image(scene_data, i)[0], frame_id] 78 | if self.get_gt: 79 | sample.append(self.generate_depth_map(scene_data, i)) 80 | return sample 81 | 82 | if self.from_speed: 83 | cum_speed = np.zeros(3) 84 | for i, speed in enumerate(scene_data['speed']): 85 | cum_speed += speed 86 | speed_mag = np.linalg.norm(cum_speed) 87 | if speed_mag > self.min_speed: 88 | frame_id = scene_data['frame_id'][i] 89 | yield construct_sample(scene_data, i, frame_id) 90 | cum_speed *= 0 91 | else: # from static frame file 92 | drive = str(scene_data['dir'].name) 93 | for 
(i,frame_id) in enumerate(scene_data['frame_id']): 94 | if (drive not in self.static_frames.keys()) or (frame_id not in self.static_frames[drive]): 95 | yield construct_sample(scene_data, i, frame_id) 96 | 97 | def get_P_rect(self, scene_data, zoom_x, zoom_y): 98 | #print(zoom_x, zoom_y) 99 | calib_file = scene_data['dir'].parent/'calib_cam_to_cam.txt' 100 | 101 | filedata = self.read_raw_calib_file(calib_file) 102 | P_rect = np.reshape(filedata['P_rect_' + scene_data['cid']], (3, 4)) 103 | P_rect[0] *= zoom_x 104 | P_rect[1] *= zoom_y 105 | return P_rect 106 | 107 | def load_image(self, scene_data, tgt_idx): 108 | img_file = scene_data['dir']/'image_{}'.format(scene_data['cid'])/'data'/scene_data['frame_id'][tgt_idx]+'.png' 109 | if not img_file.isfile(): 110 | return None 111 | img = scipy.misc.imread(img_file) 112 | zoom_y = self.img_height/img.shape[0] 113 | zoom_x = self.img_width/img.shape[1] 114 | img = scipy.misc.imresize(img, (self.img_height, self.img_width)) 115 | return img, zoom_x, zoom_y 116 | 117 | def read_raw_calib_file(self, filepath): 118 | # From https://github.com/utiasSTARS/pykitti/blob/master/pykitti/utils.py 119 | """Read in a calibration file and parse into a dictionary.""" 120 | data = {} 121 | 122 | with open(filepath, 'r') as f: 123 | for line in f.readlines(): 124 | key, value = line.split(':', 1) 125 | # The only non-float values in these files are dates, which 126 | # we don't care about anyway 127 | try: 128 | data[key] = np.array([float(x) for x in value.split()]) 129 | except ValueError: 130 | pass 131 | return data 132 | 133 | def generate_depth_map(self, scene_data, tgt_idx): 134 | # compute projection matrix velodyne->image plane 135 | 136 | def sub2ind(matrixSize, rowSub, colSub): 137 | m, n = matrixSize 138 | return rowSub * (n-1) + colSub - 1 139 | 140 | R_cam2rect = np.eye(4) 141 | 142 | calib_dir = scene_data['dir'].parent 143 | cam2cam = self.read_raw_calib_file(calib_dir/'calib_cam_to_cam.txt') 144 | velo2cam = self.read_raw_calib_file(calib_dir/'calib_velo_to_cam.txt') 145 | velo2cam = np.hstack((velo2cam['R'].reshape(3,3), velo2cam['T'][..., np.newaxis])) 146 | velo2cam = np.vstack((velo2cam, np.array([0, 0, 0, 1.0]))) 147 | P_rect = scene_data['P_rect'] 148 | R_cam2rect[:3,:3] = cam2cam['R_rect_00'].reshape(3,3) 149 | 150 | P_velo2im = np.dot(np.dot(P_rect, R_cam2rect), velo2cam) 151 | 152 | velo_file_name = scene_data['dir']/'velodyne_points'/'data'/'{}.bin'.format(scene_data['frame_id'][tgt_idx]) 153 | 154 | # load velodyne points and remove all behind image plane (approximation) 155 | # each row of the velodyne data is forward, left, up, reflectance 156 | velo = np.fromfile(velo_file_name, dtype=np.float32).reshape(-1, 4) 157 | velo[:,3] = 1 158 | velo = velo[velo[:, 0] >= 0, :] 159 | 160 | # project the points to the camera 161 | velo_pts_im = np.dot(P_velo2im, velo.T).T 162 | velo_pts_im[:, :2] = velo_pts_im[:,:2] / velo_pts_im[:,-1:] 163 | 164 | # check if in bounds 165 | # use minus 1 to get the exact same value as KITTI matlab code 166 | velo_pts_im[:, 0] = np.round(velo_pts_im[:,0]) - 1 167 | velo_pts_im[:, 1] = np.round(velo_pts_im[:,1]) - 1 168 | 169 | val_inds = (velo_pts_im[:, 0] >= 0) & (velo_pts_im[:, 1] >= 0) 170 | val_inds = val_inds & (velo_pts_im[:,0] < self.img_width) & (velo_pts_im[:,1] < self.img_height) 171 | velo_pts_im = velo_pts_im[val_inds, :] 172 | 173 | # project to image 174 | depth = np.zeros((self.img_height, self.img_width)).astype(np.float32) 175 | depth[velo_pts_im[:, 1].astype(np.int), velo_pts_im[:, 
0].astype(np.int)] = velo_pts_im[:, 2] 176 | 177 | # find the duplicate points and choose the closest depth 178 | inds = sub2ind(depth.shape, velo_pts_im[:, 1], velo_pts_im[:, 0]) 179 | dupe_inds = [item for item, count in Counter(inds).items() if count > 1] 180 | for dd in dupe_inds: 181 | pts = np.where(inds == dd)[0] 182 | x_loc = int(velo_pts_im[pts[0], 0]) 183 | y_loc = int(velo_pts_im[pts[0], 1]) 184 | depth[y_loc, x_loc] = velo_pts_im[pts, 2].min() 185 | depth[depth < 0] = 0 186 | return depth 187 | -------------------------------------------------------------------------------- /data/prepare_train_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import argparse 3 | import scipy.misc 4 | import numpy as np 5 | from joblib import Parallel, delayed 6 | from tqdm import tqdm 7 | from path import Path 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("dataset_dir", metavar='DIR', 11 | help='path to original dataset') 12 | parser.add_argument("--dataset-format", type=str, required=True, choices=["kitti", "cityscapes"]) 13 | parser.add_argument("--static-frames", default=None, 14 | help="list of imgs to discard for being static, if not set will discard them based on speed \ 15 | (careful, on KITTI some frames have incorrect speed)") 16 | parser.add_argument("--with-gt", action='store_true', 17 | help="If available (e.g. with KITTI), will store ground truth along with images, for validation") 18 | parser.add_argument("--dump-root", type=str, required=True, help="Where to dump the data") 19 | parser.add_argument("--height", type=int, default=128, help="image height") 20 | parser.add_argument("--width", type=int, default=416, help="image width") 21 | parser.add_argument("--num-threads", type=int, default=4, help="number of threads to use") 22 | 23 | args = parser.parse_args() 24 | 25 | 26 | def dump_example(scene): 27 | scene_list = data_loader.collect_scenes(scene) 28 | for scene_data in scene_list: 29 | dump_dir = args.dump_root/scene_data['rel_path'] 30 | dump_dir.makedirs_p() 31 | intrinsics = scene_data['intrinsics'] 32 | fx = intrinsics[0, 0] 33 | fy = intrinsics[1, 1] 34 | cx = intrinsics[0, 2] 35 | cy = intrinsics[1, 2] 36 | 37 | dump_cam_file = dump_dir/'cam.txt' 38 | with open(dump_cam_file, 'w') as f: 39 | f.write('%f,0.,%f,0.,%f,%f,0.,0.,1.' 
% (fx, cx, fy, cy)) 40 | 41 | for sample in data_loader.get_scene_imgs(scene_data): 42 | assert(len(sample) >= 2) 43 | img, frame_nb = sample[0], sample[1] 44 | dump_img_file = dump_dir/'{}.jpg'.format(frame_nb) 45 | scipy.misc.imsave(dump_img_file, img) 46 | if len(sample) == 3: 47 | dump_depth_file = dump_dir/'{}.npy'.format(frame_nb) 48 | np.save(dump_depth_file, sample[2]) 49 | 50 | if len(dump_dir.files('*.jpg')) < 3: 51 | dump_dir.rmtree() 52 | 53 | 54 | def main(): 55 | args.dump_root = Path(args.dump_root) 56 | args.dump_root.mkdir_p() 57 | 58 | global data_loader 59 | 60 | if args.dataset_format == 'kitti': 61 | from kitti_raw_loader import KittiRawLoader 62 | data_loader = KittiRawLoader(args.dataset_dir, 63 | static_frames_file=args.static_frames, 64 | img_height=args.height, 65 | img_width=args.width, 66 | get_gt=args.with_gt) 67 | 68 | if args.dataset_format == 'cityscapes': 69 | from cityscapes_loader import cityscapes_loader 70 | data_loader = cityscapes_loader(args.dataset_dir, 71 | img_height=args.height, 72 | img_width=args.width) 73 | 74 | print('Retrieving frames') 75 | Parallel(n_jobs=args.num_threads)(delayed(dump_example)(scene) for scene in tqdm(data_loader.scenes)) 76 | # Split into train/val 77 | print('Generating train val lists') 78 | np.random.seed(8964) 79 | subfolders = args.dump_root.dirs() 80 | with open(args.dump_root / 'train.txt', 'w') as tf: 81 | with open(args.dump_root / 'val.txt', 'w') as vf: 82 | for s in tqdm(subfolders): 83 | if np.random.random() < 0.1: 84 | vf.write('{}\n'.format(s.name)) 85 | else: 86 | tf.write('{}\n'.format(s.name)) 87 | # remove useless groundtruth data for training comment if you don't want to erase it 88 | for gt_file in s.files('*.npy'): 89 | gt_file.remove_p() 90 | 91 | 92 | if __name__ == '__main__': 93 | main() 94 | -------------------------------------------------------------------------------- /data/test_scenes.txt: -------------------------------------------------------------------------------- 1 | 2011_09_26_drive_0117 2 | 2011_09_28_drive_0002 3 | 2011_09_26_drive_0052 4 | 2011_09_30_drive_0016 5 | 2011_09_26_drive_0059 6 | 2011_09_26_drive_0027 7 | 2011_09_26_drive_0020 8 | 2011_09_26_drive_0009 9 | 2011_09_26_drive_0013 10 | 2011_09_26_drive_0101 11 | 2011_09_26_drive_0046 12 | 2011_09_26_drive_0029 13 | 2011_09_26_drive_0064 14 | 2011_09_26_drive_0048 15 | 2011_10_03_drive_0027 16 | 2011_09_26_drive_0002 17 | 2011_09_26_drive_0036 18 | 2011_09_29_drive_0071 19 | 2011_10_03_drive_0047 20 | 2011_09_30_drive_0027 21 | 2011_09_26_drive_0086 22 | 2011_09_26_drive_0084 23 | 2011_09_26_drive_0096 24 | 2011_09_30_drive_0018 25 | 2011_09_26_drive_0106 26 | 2011_09_26_drive_0056 27 | 2011_09_26_drive_0023 28 | 2011_09_26_drive_0093 29 | -------------------------------------------------------------------------------- /datasets/general_sequence_folders.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import numpy as np 3 | from scipy.misc import imread 4 | from path import Path 5 | import random 6 | 7 | def crawl_folders(folders_list, sequence_length): 8 | sequence_set = [] 9 | demi_length = (sequence_length-1)//2 10 | for folder in folders_list: 11 | #intrinsics = np.genfromtxt(folder/'cam.txt', delimiter=',').astype(np.float32).reshape((3, 3)) 12 | imgs = sorted(folder.files('*.jpg')) 13 | if len(imgs) < sequence_length: 14 | continue 15 | for i in range(demi_length, len(imgs)-demi_length): 16 | sample = {'tgt': imgs[i], 'ref_imgs': []} 17 | for 
j in range(-demi_length, demi_length + 1): 18 | if j != 0: 19 | sample['ref_imgs'].append(imgs[i+j]) 20 | sequence_set.append(sample) 21 | random.shuffle(sequence_set) 22 | return sequence_set 23 | 24 | 25 | def load_as_float(path): 26 | return imread(path).astype(np.float32) 27 | 28 | 29 | class SequenceFolder(data.Dataset): 30 | """A sequence data loader where the files are arranged in this way: 31 | root/scene_1/0000000.jpg 32 | root/scene_1/0000001.jpg 33 | .. 34 | root/scene_1/cam.txt 35 | root/scene_2/0000000.jpg 36 | . 37 | 38 | transform functions must take in a list a images and a numpy array (usually intrinsics matrix) 39 | """ 40 | 41 | def __init__(self, root, seed=None, train=True, sequence_length=3, transform=None, target_transform=None): 42 | np.random.seed(seed) 43 | random.seed(seed) 44 | self.root = Path(root) 45 | #scene_list_path = self.root/'train.txt' if train else self.root/'val.txt' 46 | self.scenes = self.root.dirs() 47 | self.samples = crawl_folders(self.scenes, sequence_length) 48 | self.transform = transform 49 | 50 | def __getitem__(self, index): 51 | sample = self.samples[index] 52 | tgt_img = load_as_float(sample['tgt']) 53 | ref_imgs = [load_as_float(ref_img) for ref_img in sample['ref_imgs']] 54 | if self.transform is not None: 55 | imgs, intrinsics = self.transform([tgt_img] + ref_imgs, np.copy(sample['intrinsics'])) 56 | tgt_img = imgs[0] 57 | ref_imgs = imgs[1:] 58 | else: 59 | intrinsics = np.copy(sample['intrinsics']) 60 | return tgt_img, ref_imgs, intrinsics, np.linalg.inv(intrinsics) 61 | 62 | def __len__(self): 63 | return len(self.samples) 64 | -------------------------------------------------------------------------------- /datasets/sequence_folders.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import numpy as np 3 | from scipy.misc import imread 4 | from path import Path 5 | import random 6 | 7 | 8 | def crawl_folders(folders_list, sequence_length):#sequence_length = 3 9 | sequence_set = [] 10 | demi_length = (sequence_length-1)//2 # 1 11 | for folder in folders_list: 12 | intrinsics = np.genfromtxt(folder/'cam.txt', delimiter=',').astype(np.float32).reshape((3, 3)) 13 | imgs = sorted(folder.files('*.jpg')) 14 | if len(imgs) < sequence_length: 15 | continue 16 | for i in range(demi_length, len(imgs)-demi_length): 17 | sample = {'intrinsics': intrinsics, 'tgt': imgs[i], 'ref_imgs': []} 18 | for j in range(-demi_length, demi_length + 1):#-1,0,1 19 | if j != 0: 20 | sample['ref_imgs'].append(imgs[i+j]) 21 | sequence_set.append(sample) 22 | random.shuffle(sequence_set) 23 | return sequence_set 24 | 25 | 26 | def load_as_float(path): 27 | return imread(path).astype(np.float32) 28 | 29 | 30 | class SequenceFolder(data.Dataset): 31 | """A sequence data loader where the files are arranged in this way: 32 | root/scene_1/0000000.jpg 33 | root/scene_1/0000001.jpg 34 | .. 35 | root/scene_1/cam.txt 36 | root/scene_2/0000000.jpg 37 | . 
38 | 39 | transform functions must take in a list a images and a numpy array (usually intrinsics matrix) 40 | """ 41 | 42 | def __init__(self, root, seed=None, train=True, sequence_length=3, transform=None, target_transform=None): 43 | np.random.seed(seed) 44 | random.seed(seed) 45 | self.root = Path(root) 46 | scene_list_path = self.root/'train.txt' if train else self.root/'val.txt' 47 | self.scenes = [self.root/folder[:-1] for folder in open(scene_list_path)] 48 | self.samples = crawl_folders(self.scenes, sequence_length) 49 | self.transform = transform 50 | 51 | def __getitem__(self, index): 52 | sample = self.samples[index] 53 | tgt_img = load_as_float(sample['tgt']) 54 | ref_imgs = [load_as_float(ref_img) for ref_img in sample['ref_imgs']] 55 | if self.transform is not None: 56 | imgs, intrinsics = self.transform([tgt_img] + ref_imgs, np.copy(sample['intrinsics'])) 57 | tgt_img = imgs[0] 58 | ref_imgs = imgs[1:] 59 | else: 60 | intrinsics = np.copy(sample['intrinsics']) 61 | return tgt_img, ref_imgs, intrinsics, np.linalg.inv(intrinsics) 62 | 63 | def __len__(self): 64 | return len(self.samples) 65 | -------------------------------------------------------------------------------- /datasets/stacked_sequence_folders.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import numpy as np 3 | from scipy.misc import imread 4 | from path import Path 5 | import random 6 | 7 | 8 | def crawl_folders(folders_list, sequence_length): 9 | sequence_set = [] 10 | demi_length = (sequence_length-1)//2 11 | for folder in folders_list: 12 | intrinsics = [np.genfromtxt(cam_file, delimiter=',').astype(np.float32).reshape((3, 3)) for cam_file in sorted(folder.files('*_cam.txt'))] 13 | imgs = sorted(folder.files('*.jpg')) 14 | for i in range(len(imgs)): 15 | sample = {'intrinsics': intrinsics[i], 'img_stack': imgs[i]} 16 | sequence_set.append(sample) 17 | random.shuffle(sequence_set) 18 | return sequence_set 19 | 20 | 21 | def load_as_float(path, sequence_length): 22 | stack = imread(path).astype(np.float32) 23 | h,w,_ = stack.shape 24 | w_img = int(w/(sequence_length)) 25 | imgs = [stack[:,i*w_img:(i+1)*w_img] for i in range(sequence_length)] 26 | tgt_index = sequence_length//2 27 | return([imgs[tgt_index]] + imgs[:tgt_index] + imgs[tgt_index+1:]) 28 | 29 | 30 | class SequenceFolder(data.Dataset): 31 | """A sequence data loader where the images are arranged in this way: 32 | root/scene_1/0000000.jpg 33 | root/scene_1/0000000_cam.txt 34 | root/scene_1/0000001.jpg 35 | root/scene_1/0000001_cam.txt 36 | . 
37 | root/scene_2/0000000.jpg 38 | root/scene_2/0000000_cam.txt 39 | """ 40 | 41 | def __init__(self, root, seed=None, train=True, sequence_length=3, transform=None, target_transform=None): 42 | np.random.seed(seed) 43 | random.seed(seed) 44 | self.root = Path(root) 45 | self.samples = [] 46 | frames_list_path = self.root/'train.txt' if train else self.root/'val.txt' 47 | self.scenes = self.root.dirs() 48 | self.sequence_length = sequence_length 49 | for frame_path in open(frames_list_path): 50 | a,b = frame_path[:-1].split(' ') 51 | base_path = (self.root/a)/b 52 | intrinsics = np.genfromtxt(base_path+'_cam.txt', delimiter=',').astype(np.float32).reshape((3, 3)) 53 | sample = {'intrinsics': intrinsics, 'img_stack': base_path+'.jpg'} 54 | self.samples.append(sample) 55 | self.transform = transform 56 | 57 | def __getitem__(self, index): 58 | sample = self.samples[index] 59 | imgs = load_as_float(sample['img_stack'], self.sequence_length) 60 | if self.transform is not None: 61 | imgs, intrinsics = self.transform(imgs, np.copy(sample['intrinsics'])) 62 | else: 63 | intrinsics = sample['intrinsics'] 64 | return imgs[0], imgs[1:], intrinsics, np.linalg.inv(intrinsics) 65 | 66 | def __len__(self): 67 | return len(self.samples) 68 | -------------------------------------------------------------------------------- /datasets/validation_flow.py: -------------------------------------------------------------------------------- 1 | # Author: Anurag Ranjan 2 | # Copyright (c) 2019, Anurag Ranjan 3 | # All rights reserved. 4 | # based on github.com/ClementPinard/SfMLearner-Pytorch 5 | 6 | import torch.utils.data as data 7 | import numpy as np 8 | from scipy.misc import imread 9 | from PIL import Image 10 | from path import Path 11 | from flowutils import flow_io 12 | import torch 13 | import os 14 | from skimage import transform as sktransform 15 | 16 | def crawl_folders(folders_list): 17 | imgs = [] 18 | depth = [] 19 | for folder in folders_list: 20 | current_imgs = sorted(folder.files('*.jpg')) 21 | current_depth = [] 22 | for img in current_imgs: 23 | d = img.dirname()/(img.name[:-4] + '.npy') 24 | assert(d.isfile()), "depth file {} not found".format(str(d)) 25 | depth.append(d) 26 | imgs.extend(current_imgs) 27 | depth.extend(current_depth) 28 | return imgs, depth 29 | 30 | 31 | def load_as_float(path): 32 | return imread(path).astype(np.float32) 33 | 34 | def get_intrinsics(calib_file, cid='02'): 35 | #print(zoom_x, zoom_y) 36 | filedata = read_raw_calib_file(calib_file) 37 | P_rect = np.reshape(filedata['P_rect_' + cid], (3, 4)) 38 | return P_rect[:,:3] 39 | 40 | 41 | def read_raw_calib_file(filepath): 42 | # From https://github.com/utiasSTARS/pykitti/blob/master/pykitti/utils.py 43 | """Read in a calibration file and parse into a dictionary.""" 44 | data = {} 45 | 46 | with open(filepath, 'r') as f: 47 | for line in f.readlines(): 48 | key, value = line.split(':', 1) 49 | # The only non-float values in these files are dates, which 50 | # we don't care about anyway 51 | try: 52 | data[key] = np.array([float(x) for x in value.split()]) 53 | except ValueError: 54 | pass 55 | return data 56 | 57 | class KITTI2015Test(data.Dataset): 58 | """ 59 | Kitti 2015 flow loader 60 | transform functions must take in a list a images and a numpy array which can be None 61 | """ 62 | 63 | def __init__(self, root, sequence_length, transform=None, N=200, phase='testing'): 64 | self.root = Path(root) 65 | self.sequence_length = sequence_length 66 | self.N = N 67 | self.transform = transform 68 | self.phase = phase 69 
| seq_ids = list(range(-int(sequence_length/2), int(sequence_length/2)+1)) 70 | seq_ids.remove(0) 71 | self.seq_ids = [x+10 for x in seq_ids] 72 | 73 | def __getitem__(self, index): 74 | tgt_img_path = self.root.joinpath('data_scene_flow_multiview', self.phase, 'image_2',str(index).zfill(6)+'_10.png') 75 | ref_img_paths = [self.root.joinpath('data_scene_flow_multiview', self.phase, 'image_2',str(index).zfill(6)+'_'+str(k).zfill(2)+'.png') for k in self.seq_ids] 76 | cam_calib_path = self.root.joinpath('data_scene_flow_calib', self.phase, 'calib_cam_to_cam', str(index).zfill(6)+'.txt') 77 | 78 | tgt_img_original = load_as_float(tgt_img_path) 79 | tgt_img = load_as_float(tgt_img_path) 80 | ref_imgs = [load_as_float(ref_img) for ref_img in ref_img_paths] 81 | intrinsics = get_intrinsics(cam_calib_path).astype('float32') 82 | tgt_img_original = torch.FloatTensor(tgt_img_original.transpose(2,0,1)) 83 | 84 | if self.transform is not None: 85 | imgs, intrinsics = self.transform([tgt_img] + ref_imgs, np.copy(intrinsics)) 86 | tgt_img = imgs[0] 87 | ref_imgs = imgs[1:] 88 | else: 89 | intrinsics = np.copy(intrinsics) 90 | return tgt_img, ref_imgs, intrinsics, np.linalg.inv(intrinsics), tgt_img_original 91 | 92 | def __len__(self): 93 | return self.N 94 | 95 | class ValidationFlow(data.Dataset): 96 | """ 97 | Kitti 2015 flow loader 98 | transform functions must take in a list a images and a numpy array which can be None 99 | """ 100 | 101 | def __init__(self, root, sequence_length, transform=None, N=200, phase='training', occ='flow_occ'): 102 | self.root = Path(root) 103 | self.sequence_length = sequence_length 104 | self.N = N 105 | self.transform = transform 106 | self.phase = phase 107 | seq_ids = list(range(-int(sequence_length/2), int(sequence_length/2)+1)) 108 | seq_ids.remove(0) 109 | self.seq_ids = [x+10 for x in seq_ids] 110 | self.occ = occ 111 | 112 | def __getitem__(self, index): 113 | tgt_img_path = self.root.joinpath('data_scene_flow_multiview', self.phase, 'image_2',str(index).zfill(6)+'_10.png') 114 | ref_img_paths = [self.root.joinpath('data_scene_flow_multiview', self.phase, 'image_2',str(index).zfill(6)+'_'+str(k).zfill(2)+'.png') for k in self.seq_ids] 115 | gt_flow_path = self.root.joinpath('data_scene_flow', self.phase, self.occ, str(index).zfill(6)+'_10.png') 116 | cam_calib_path = self.root.joinpath('data_scene_flow_calib', self.phase, 'calib_cam_to_cam', str(index).zfill(6)+'.txt') 117 | obj_map_path = self.root.joinpath('data_scene_flow', self.phase, 'obj_map', str(index).zfill(6)+'_10.png') 118 | 119 | tgt_img = load_as_float(tgt_img_path) 120 | ref_imgs = [load_as_float(ref_img) for ref_img in ref_img_paths] 121 | if os.path.isfile(obj_map_path): 122 | obj_map = load_as_float(obj_map_path) 123 | else: 124 | obj_map = np.ones((tgt_img.shape[0], tgt_img.shape[1])) 125 | u,v,valid = flow_io.flow_read_png(gt_flow_path) 126 | gtFlow = np.dstack((u,v,valid)) 127 | #gtFlow = scale_flow(np.dstack((u,v,valid)), h=self.flow_h, w=self.flow_w) 128 | gtFlow = torch.FloatTensor(gtFlow.transpose(2,0,1)) 129 | intrinsics = get_intrinsics(cam_calib_path).astype('float32') 130 | 131 | if self.transform is not None: 132 | imgs, intrinsics = self.transform([tgt_img] + ref_imgs, np.copy(intrinsics)) 133 | tgt_img = imgs[0] 134 | ref_imgs = imgs[1:] 135 | else: 136 | intrinsics = np.copy(intrinsics) 137 | return tgt_img, ref_imgs, intrinsics, np.linalg.inv(intrinsics), gtFlow, obj_map 138 | 139 | def __len__(self): 140 | return self.N 141 | 142 | class ValidationMask(data.Dataset): 143 | """ 
144 | Kitti 2015 flow loader 145 | transform functions must take in a list a images and a numpy array which can be None 146 | """ 147 | 148 | def __init__(self, root, sequence_length, transform=None, N=200, phase='training'): 149 | self.root = Path(root) 150 | self.sequence_length = sequence_length 151 | self.N = N 152 | self.transform = transform 153 | self.phase = phase 154 | seq_ids = list(range(-int(sequence_length/2), int(sequence_length/2)+1)) 155 | seq_ids.remove(0) 156 | self.seq_ids = [x+10 for x in seq_ids] 157 | 158 | def __getitem__(self, index): 159 | tgt_img_path = self.root.joinpath('data_scene_flow_multiview', self.phase, 'image_2',str(index).zfill(6)+'_10.png') 160 | ref_img_paths = [self.root.joinpath('data_scene_flow_multiview', self.phase, 'image_2',str(index).zfill(6)+'_'+str(k).zfill(2)+'.png') for k in self.seq_ids] 161 | gt_flow_path = self.root.joinpath('data_scene_flow', self.phase, 'flow_occ', str(index).zfill(6)+'_10.png') 162 | cam_calib_path = self.root.joinpath('data_scene_flow_calib', self.phase, 'calib_cam_to_cam', str(index).zfill(6)+'.txt') 163 | obj_map_path = self.root.joinpath('data_scene_flow', self.phase, 'obj_map', str(index).zfill(6)+'_10.png') 164 | semantic_map_path = self.root.joinpath('semantic_labels', self.phase, 'semantic', str(index).zfill(6)+'_10.png') 165 | 166 | tgt_img = load_as_float(tgt_img_path) 167 | ref_imgs = [load_as_float(ref_img) for ref_img in ref_img_paths] 168 | obj_map = torch.LongTensor(np.array(Image.open(obj_map_path))) 169 | semantic_map = torch.LongTensor(np.array(Image.open(semantic_map_path))) 170 | u,v,valid = flow_io.flow_read_png(gt_flow_path) 171 | gtFlow = np.dstack((u,v,valid)) 172 | #gtFlow = scale_flow(np.dstack((u,v,valid)), h=self.flow_h, w=self.flow_w) 173 | gtFlow = torch.FloatTensor(gtFlow.transpose(2,0,1)) 174 | intrinsics = get_intrinsics(cam_calib_path).astype('float32') 175 | 176 | if self.transform is not None: 177 | imgs, intrinsics = self.transform([tgt_img] + ref_imgs, np.copy(intrinsics)) 178 | tgt_img = imgs[0] 179 | ref_imgs = imgs[1:] 180 | else: 181 | intrinsics = np.copy(intrinsics) 182 | return tgt_img, ref_imgs, intrinsics, np.linalg.inv(intrinsics), gtFlow, obj_map, semantic_map 183 | 184 | def __len__(self): 185 | return self.N 186 | 187 | class ValidationFlowKitti2012(data.Dataset): 188 | """ 189 | Kitti 2012 flow loader 190 | transform functions must take in a list a images and a numpy array which can be None 191 | """ 192 | 193 | def __init__(self, root, sequence_length=5, transform=None, N=194, flow_w=1024, flow_h=384, phase='training'): 194 | self.root = Path(root) 195 | self.sequence_length = sequence_length 196 | self.N = N 197 | self.transform = transform 198 | self.phase = phase 199 | self.flow_h = flow_h 200 | self.flow_w = flow_w 201 | 202 | def __getitem__(self, index): 203 | tgt_img_path = self.root.joinpath('data_stereo_flow', self.phase, 'colored_0',str(index).zfill(6)+'_10.png') 204 | ref_img_path = self.root.joinpath('data_stereo_flow', self.phase, 'colored_0',str(index).zfill(6)+'_11.png') 205 | gt_flow_path = self.root.joinpath('data_stereo_flow', self.phase, 'flow_occ', str(index).zfill(6)+'_10.png') 206 | 207 | tgt_img = load_as_float(tgt_img_path) 208 | ref_img = load_as_float(ref_img_path) 209 | 210 | u,v,valid = flow_io.flow_read_png(gt_flow_path) 211 | #gtFlow = scale_flow(np.dstack((u,v,valid)), h=self.flow_h, w=self.flow_w) 212 | gtFlow = np.dstack((u,v,valid)) 213 | gtFlow = torch.FloatTensor(gtFlow.transpose(2,0,1)) 214 | 215 | intrinsics = np.eye(3) 216 | 
if self.transform is not None: 217 | imgs, intrinsics = self.transform([tgt_img] + [ref_img], np.copy(intrinsics)) 218 | tgt_img = imgs[0] 219 | ref_img = imgs[1] 220 | else: 221 | intrinsics = np.copy(intrinsics) 222 | return tgt_img, ref_img, intrinsics, np.linalg.inv(intrinsics), gtFlow 223 | 224 | def __len__(self): 225 | return self.N 226 | -------------------------------------------------------------------------------- /datasets/validation_flow_video.py: -------------------------------------------------------------------------------- 1 | # Author: Anurag Ranjan 2 | # Copyright (c) 2019, Anurag Ranjan 3 | # All rights reserved. 4 | # based on github.com/ClementPinard/SfMLearner-Pytorch 5 | 6 | import torch.utils.data as data 7 | import numpy as np 8 | from scipy.misc import imread 9 | from PIL import Image 10 | from path import Path 11 | from flowutils import flow_io 12 | import torch 13 | import os 14 | from skimage import transform as sktransform 15 | 16 | def crawl_folders(folders_list): 17 | imgs = [] 18 | depth = [] 19 | for folder in folders_list: 20 | current_imgs = sorted(folder.files('*.jpg')) 21 | current_depth = [] 22 | for img in current_imgs: 23 | d = img.dirname()/(img.name[:-4] + '.npy') 24 | assert(d.isfile()), "depth file {} not found".format(str(d)) 25 | depth.append(d) 26 | imgs.extend(current_imgs) 27 | depth.extend(current_depth) 28 | return imgs, depth 29 | 30 | 31 | def load_as_float(path): 32 | return imread(path).astype(np.float32) 33 | 34 | def get_intrinsics(calib_file, cid='02'): 35 | #print(zoom_x, zoom_y) 36 | filedata = read_raw_calib_file(calib_file) 37 | P_rect = np.reshape(filedata['P_rect_' + cid], (3, 4)) 38 | return P_rect[:,:3] 39 | 40 | 41 | def read_raw_calib_file(filepath): 42 | # From https://github.com/utiasSTARS/pykitti/blob/master/pykitti/utils.py 43 | """Read in a calibration file and parse into a dictionary.""" 44 | data = {} 45 | 46 | with open(filepath, 'r') as f: 47 | for line in f.readlines(): 48 | key, value = line.split(':', 1) 49 | # The only non-float values in these files are dates, which 50 | # we don't care about anyway 51 | try: 52 | data[key] = np.array([float(x) for x in value.split()]) 53 | except ValueError: 54 | pass 55 | return data 56 | 57 | class KITTI2015Test(data.Dataset): 58 | """ 59 | Kitti 2015 flow loader 60 | transform functions must take in a list a images and a numpy array which can be None 61 | """ 62 | 63 | def __init__(self, root, sequence_length, transform=None, N=200, phase='testing'): 64 | self.root = Path(root) 65 | self.sequence_length = sequence_length 66 | self.N = N 67 | self.transform = transform 68 | self.phase = phase 69 | seq_ids = list(range(-int(sequence_length/2), int(sequence_length/2)+1)) 70 | seq_ids.remove(0) 71 | self.seq_ids = [x+10 for x in seq_ids] 72 | 73 | def __getitem__(self, index): 74 | tgt_img_path = self.root.joinpath('data_scene_flow_multiview', self.phase, 'image_2',str(index).zfill(6)+'_10.png') 75 | ref_img_paths = [self.root.joinpath('data_scene_flow_multiview', self.phase, 'image_2',str(index).zfill(6)+'_'+str(k).zfill(2)+'.png') for k in self.seq_ids] 76 | cam_calib_path = self.root.joinpath('data_scene_flow_calib', self.phase, 'calib_cam_to_cam', str(index).zfill(6)+'.txt') 77 | 78 | tgt_img_original = load_as_float(tgt_img_path) 79 | tgt_img = load_as_float(tgt_img_path) 80 | ref_imgs = [load_as_float(ref_img) for ref_img in ref_img_paths] 81 | intrinsics = get_intrinsics(cam_calib_path).astype('float32') 82 | tgt_img_original = 
torch.FloatTensor(tgt_img_original.transpose(2,0,1)) 83 | 84 | if self.transform is not None: 85 | imgs, intrinsics = self.transform([tgt_img] + ref_imgs, np.copy(intrinsics)) 86 | tgt_img = imgs[0] 87 | ref_imgs = imgs[1:] 88 | else: 89 | intrinsics = np.copy(intrinsics) 90 | return tgt_img, ref_imgs, intrinsics, np.linalg.inv(intrinsics), tgt_img_original 91 | 92 | def __len__(self): 93 | return self.N 94 | 95 | class ValidationFlow(data.Dataset): 96 | """ 97 | Kitti 2015 flow loader 98 | transform functions must take in a list a images and a numpy array which can be None 99 | """ 100 | 101 | def __init__(self, root, sequence_length, transform=None, N=200, phase='training', occ='flow_occ'): 102 | self.root = Path(root) 103 | self.gt_root = Path('/data/kitti/kitti2015') 104 | self.sequence_length = sequence_length 105 | self.N = N 106 | self.transform = transform 107 | self.phase = phase 108 | seq_ids = list(range(-int(sequence_length/2), int(sequence_length/2)+1)) 109 | seq_ids.remove(0) 110 | self.seq_ids = [x for x in seq_ids] 111 | self.occ = occ 112 | 113 | def __getitem__(self, index): 114 | 115 | path_list=os.listdir(self.root) 116 | path_list.sort() 117 | index += 2 118 | 119 | tgt_img_path = self.root.joinpath(str(index).zfill(10)+'.png') 120 | ref_img_paths = [self.root.joinpath(str(index+k).zfill(10)+'.png') for k in self.seq_ids] 121 | 122 | # /data/raw/2011_09_28/2011_09_28_drive_0002_sync/image_02/data 123 | cam_calib_path = Path(self.root[0:len('/data/raw/2011_09_28')]).joinpath('calib_cam_to_cam.txt') 124 | 125 | tgt_img = load_as_float(tgt_img_path) 126 | ref_imgs = [load_as_float(ref_img) for ref_img in ref_img_paths] 127 | 128 | intrinsics = get_intrinsics(cam_calib_path).astype('float32') 129 | 130 | if self.transform is not None: 131 | imgs, intrinsics = self.transform([tgt_img] + ref_imgs, np.copy(intrinsics)) 132 | tgt_img = imgs[0] 133 | ref_imgs = imgs[1:] 134 | else: 135 | intrinsics = np.copy(intrinsics) 136 | return tgt_img, ref_imgs, intrinsics, np.linalg.inv(intrinsics) 137 | 138 | def __len__(self): 139 | return self.N 140 | 141 | class ValidationMask(data.Dataset): 142 | """ 143 | Kitti 2015 flow loader 144 | transform functions must take in a list a images and a numpy array which can be None 145 | """ 146 | 147 | def __init__(self, root, sequence_length, transform=None, N=200, phase='training'): 148 | self.root = Path(root) 149 | self.sequence_length = sequence_length 150 | self.N = N 151 | self.transform = transform 152 | self.phase = phase 153 | seq_ids = list(range(-int(sequence_length/2), int(sequence_length/2)+1)) 154 | seq_ids.remove(0) 155 | self.seq_ids = [x+10 for x in seq_ids] 156 | 157 | def __getitem__(self, index): 158 | tgt_img_path = self.root.joinpath('data_scene_flow_multiview', self.phase, 'image_2',str(index).zfill(6)+'_10.png') 159 | ref_img_paths = [self.root.joinpath('data_scene_flow_multiview', self.phase, 'image_2',str(index).zfill(6)+'_'+str(k).zfill(2)+'.png') for k in self.seq_ids] 160 | 161 | gt_flow_path = self.root.joinpath('data_scene_flow', self.phase, 'flow_occ', str(index).zfill(6)+'_10.png') 162 | cam_calib_path = self.root.joinpath('data_scene_flow_calib', self.phase, 'calib_cam_to_cam', str(index).zfill(6)+'.txt') 163 | obj_map_path = self.root.joinpath('data_scene_flow', self.phase, 'obj_map', str(index).zfill(6)+'_10.png') 164 | semantic_map_path = self.root.joinpath('semantic_labels', self.phase, 'semantic', str(index).zfill(6)+'_10.png') 165 | 166 | tgt_img = load_as_float(tgt_img_path) 167 | ref_imgs = 
[load_as_float(ref_img) for ref_img in ref_img_paths] 168 | obj_map = torch.LongTensor(np.array(Image.open(obj_map_path))) 169 | semantic_map = torch.LongTensor(np.array(Image.open(semantic_map_path))) 170 | u,v,valid = flow_io.flow_read_png(gt_flow_path) 171 | gtFlow = np.dstack((u,v,valid)) 172 | #gtFlow = scale_flow(np.dstack((u,v,valid)), h=self.flow_h, w=self.flow_w) 173 | gtFlow = torch.FloatTensor(gtFlow.transpose(2,0,1)) 174 | intrinsics = get_intrinsics(cam_calib_path).astype('float32') 175 | 176 | if self.transform is not None: 177 | imgs, intrinsics = self.transform([tgt_img] + ref_imgs, np.copy(intrinsics)) 178 | tgt_img = imgs[0] 179 | ref_imgs = imgs[1:] 180 | else: 181 | intrinsics = np.copy(intrinsics) 182 | return tgt_img, ref_imgs, intrinsics, np.linalg.inv(intrinsics), gtFlow, obj_map, semantic_map 183 | 184 | def __len__(self): 185 | return self.N 186 | 187 | class ValidationFlowKitti2012(data.Dataset): 188 | """ 189 | Kitti 2012 flow loader 190 | transform functions must take in a list a images and a numpy array which can be None 191 | """ 192 | 193 | def __init__(self, root, sequence_length=5, transform=None, N=194, flow_w=1024, flow_h=384, phase='training'): 194 | self.root = Path(root) 195 | self.sequence_length = sequence_length 196 | self.N = N 197 | self.transform = transform 198 | self.phase = phase 199 | self.flow_h = flow_h 200 | self.flow_w = flow_w 201 | 202 | def __getitem__(self, index): 203 | tgt_img_path = self.root.joinpath('data_stereo_flow', self.phase, 'colored_0',str(index).zfill(6)+'_10.png') 204 | ref_img_path = self.root.joinpath('data_stereo_flow', self.phase, 'colored_0',str(index).zfill(6)+'_11.png') 205 | gt_flow_path = self.root.joinpath('data_stereo_flow', self.phase, 'flow_occ', str(index).zfill(6)+'_10.png') 206 | 207 | tgt_img = load_as_float(tgt_img_path) 208 | ref_img = load_as_float(ref_img_path) 209 | 210 | u,v,valid = flow_io.flow_read_png(gt_flow_path) 211 | #gtFlow = scale_flow(np.dstack((u,v,valid)), h=self.flow_h, w=self.flow_w) 212 | gtFlow = np.dstack((u,v,valid)) 213 | gtFlow = torch.FloatTensor(gtFlow.transpose(2,0,1)) 214 | 215 | intrinsics = np.eye(3) 216 | if self.transform is not None: 217 | imgs, intrinsics = self.transform([tgt_img] + [ref_img], np.copy(intrinsics)) 218 | tgt_img = imgs[0] 219 | ref_img = imgs[1] 220 | else: 221 | intrinsics = np.copy(intrinsics) 222 | return tgt_img, ref_img, intrinsics, np.linalg.inv(intrinsics), gtFlow 223 | 224 | def __len__(self): 225 | return self.N 226 | -------------------------------------------------------------------------------- /datasets/validation_folders.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import numpy as np 3 | from scipy.misc import imread 4 | from path import Path 5 | import torch 6 | 7 | 8 | def crawl_folders(folders_list): 9 | imgs = [] 10 | depth = [] 11 | for folder in folders_list: 12 | current_imgs = sorted(folder.files('*.jpg')) 13 | current_depth = [] 14 | for img in current_imgs: 15 | d = img.dirname()/(img.name[:-4] + '.npy') 16 | assert(d.isfile()), "depth file {} not found".format(str(d)) 17 | depth.append(d) 18 | imgs.extend(current_imgs) 19 | depth.extend(current_depth) 20 | return imgs, depth 21 | 22 | def crawl_folders_seq(folders_list, sequence_length): 23 | imgs1 = [] 24 | imgs2 = [] 25 | depth = [] 26 | for folder in folders_list: 27 | current_imgs = sorted(folder.files('*.jpg')) 28 | current_imgs1 = current_imgs[:-1] 29 | current_imgs2 = current_imgs[1:] 30 | 
current_depth = [] 31 | for (img1,img2) in zip(current_imgs1, current_imgs2): 32 | d = img1.dirname()/(img1.name[:-4] + '.npy') 33 | assert(d.isfile()), "depth file {} not found".format(str(d)) 34 | depth.append(d) 35 | imgs1.extend(current_imgs1) 36 | imgs2.extend(current_imgs2) 37 | depth.extend(current_depth) 38 | return imgs1, imgs2, depth 39 | 40 | 41 | def load_as_float(path): 42 | return imread(path).astype(np.float32) 43 | 44 | 45 | class ValidationSet(data.Dataset): 46 | """A sequence data loader where the files are arranged in this way: 47 | root/scene_1/0000000.jpg 48 | root/scene_1/0000000.npy 49 | root/scene_1/0000001.jpg 50 | root/scene_1/0000001.npy 51 | .. 52 | root/scene_2/0000000.jpg 53 | root/scene_2/0000000.npy 54 | . 55 | 56 | transform functions must take in a list of images and a numpy array which can be None 57 | """ 58 | 59 | def __init__(self, root, transform=None): 60 | self.root = Path(root) 61 | scene_list_path = self.root/'val.txt' 62 | self.scenes = [self.root/folder[:-1] for folder in open(scene_list_path)] 63 | self.imgs, self.depth = crawl_folders(self.scenes) 64 | self.transform = transform 65 | 66 | def __getitem__(self, index): 67 | img = load_as_float(self.imgs[index]) 68 | depth = np.load(self.depth[index]).astype(np.float32) 69 | if self.transform is not None: 70 | img, _ = self.transform([img], None) 71 | img = img[0] 72 | return img, depth 73 | 74 | def __len__(self): 75 | return len(self.imgs) 76 | 77 | class ValidationSetSeq(data.Dataset): 78 | """A sequence data loader where the files are arranged in this way: 79 | root/scene_1/0000000.jpg 80 | root/scene_1/0000000.npy 81 | root/scene_1/0000001.jpg 82 | root/scene_1/0000001.npy 83 | .. 84 | root/scene_2/0000000.jpg 85 | root/scene_2/0000000.npy 86 | . 87 | 88 | transform functions must take in a list of images and a numpy array which can be None 89 | """ 90 | 91 | def __init__(self, root, transform=None): 92 | self.root = Path(root) 93 | scene_list_path = self.root/'val.txt' 94 | self.scenes = [self.root/folder[:-1] for folder in open(scene_list_path)] 95 | self.imgs1, self.imgs2, self.depth = crawl_folders_seq(self.scenes, sequence_length=2) # sequence_length is unused by crawl_folders_seq but required by its signature 96 | self.transform = transform 97 | 98 | def __getitem__(self, index): 99 | img1 = load_as_float(self.imgs1[index]) 100 | img2 = load_as_float(self.imgs2[index]) 101 | depth = np.load(self.depth[index]).astype(np.float32) 102 | if self.transform is not None: 103 | img, _ = self.transform([img1, img2], None) 104 | img1, img2 = img[0], img[1] 105 | return (img1, img2), depth 106 | 107 | def __len__(self): 108 | return len(self.imgs1) 109 | -------------------------------------------------------------------------------- /flowutils/flow_io.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | 3 | """ 4 | I/O script to save and load the data coming with the MPI-Sintel low-level 5 | computer vision benchmark. 6 | 7 | For more details about the benchmark, please visit www.mpi-sintel.de 8 | 9 | CHANGELOG: 10 | v1.0 (2015/02/03): First release 11 | 12 | Copyright (c) 2015 Jonas Wulff 13 | Max Planck Institute for Intelligent Systems, Tuebingen, Germany 14 | 15 | """ 16 | 17 | # Requirements: Numpy and PIL/Pillow 18 | import numpy as np 19 | try: 20 | import png 21 | has_png = True 22 | except ImportError: 23 | has_png = False 24 | png=None 25 | 26 | 27 | 28 | # Check for endianness, based on Daniel Scharstein's optical flow code. 29 | # Using little-endian architecture, these two should be equal.
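# (On a little-endian machine the float 202021.25 is stored as the bytes 'P','I','E','H', so TAG_FLOAT and TAG_CHAR below encode the same magic tag.)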
30 | TAG_FLOAT = 202021.25 31 | TAG_CHAR = 'PIEH'.encode() 32 | 33 | def flow_read(filename, return_validity=False): 34 | """ Read optical flow from file, return (U,V) tuple. 35 | 36 | Original code by Deqing Sun, adapted from Daniel Scharstein. 37 | """ 38 | f = open(filename,'rb') 39 | check = np.fromfile(f,dtype=np.float32,count=1)[0] 40 | assert check == TAG_FLOAT, ' flow_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(TAG_FLOAT,check) 41 | width = np.fromfile(f,dtype=np.int32,count=1)[0] 42 | height = np.fromfile(f,dtype=np.int32,count=1)[0] 43 | size = width*height 44 | assert width > 0 and height > 0 and size > 1 and size < 100000000, ' flow_read:: Wrong input size (width = {0}, height = {1}).'.format(width,height) 45 | tmp = np.fromfile(f,dtype=np.float32,count=-1).reshape((height,width*2)) 46 | u = tmp[:,np.arange(width)*2] 47 | v = tmp[:,np.arange(width)*2 + 1] 48 | 49 | if return_validity: 50 | valid = u<1e19 51 | u[valid==0] = 0 52 | v[valid==0] = 0 53 | return u,v,valid 54 | else: 55 | return u,v 56 | 57 | def flow_write(filename,uv,v=None): 58 | """ Write optical flow to file. 59 | 60 | If v is None, uv is assumed to contain both u and v channels, 61 | stacked in depth. 62 | 63 | Original code by Deqing Sun, adapted from Daniel Scharstein. 64 | """ 65 | nBands = 2 66 | 67 | if v is None: 68 | uv_ = np.array(uv) 69 | assert(uv_.ndim==3) 70 | if uv_.shape[0] == 2: 71 | u = uv_[0,:,:] 72 | v = uv_[1,:,:] 73 | elif uv_.shape[2] == 2: 74 | u = uv_[:,:,0] 75 | v = uv_[:,:,1] 76 | else: 77 | raise UVError('Wrong format for flow input') 78 | else: 79 | u = uv 80 | 81 | assert(u.shape == v.shape) 82 | height,width = u.shape 83 | f = open(filename,'wb') 84 | # write the header 85 | f.write(TAG_CHAR) 86 | np.array(width).astype(np.int32).tofile(f) 87 | np.array(height).astype(np.int32).tofile(f) 88 | # arrange into matrix form 89 | tmp = np.zeros((height, width*nBands)) 90 | tmp[:,np.arange(width)*2] = u 91 | tmp[:,np.arange(width)*2 + 1] = v 92 | tmp.astype(np.float32).tofile(f) 93 | f.close() 94 | 95 | 96 | def flow_read_png(fpath): 97 | """ 98 | Read KITTI optical flow, returns u,v,valid mask 99 | 100 | """ 101 | if not has_png: 102 | print('Error. Please install the PyPNG library') 103 | return 104 | 105 | R = png.Reader(fpath) 106 | width,height,data,_ = R.asDirect() 107 | # This only worked with python2. 108 | #I = np.array(map(lambda x:x,data)).reshape((height,width,3)) 109 | I = np.array([x for x in data]).reshape((height,width,3)) 110 | u_ = I[:,:,0] 111 | v_ = I[:,:,1] 112 | valid = I[:,:,2] 113 | 114 | u = (u_.astype('float64')-2**15)/64.0 115 | v = (v_.astype('float64')-2**15)/64.0 116 | 117 | return u,v,valid 118 | 119 | 120 | def flow_write_png(fpath,u,v,valid=None): 121 | """ 122 | Write KITTI optical flow. 123 | 124 | """ 125 | if not has_png: 126 | print('Error. 
Please install the PyPNG library') 127 | return 128 | 129 | 130 | if valid==None: 131 | valid_ = np.ones(u.shape,dtype='uint16') 132 | else: 133 | valid_ = valid.astype('uint16') 134 | 135 | 136 | u = u.astype('float64') 137 | v = v.astype('float64') 138 | 139 | u_ = ((u*64.0)+2**15).astype('uint16') 140 | v_ = ((v*64.0)+2**15).astype('uint16') 141 | 142 | I = np.dstack((u_,v_,valid_)) 143 | 144 | W = png.Writer(width=u.shape[1], 145 | height=u.shape[0], 146 | bitdepth=16, 147 | planes=3) 148 | 149 | with open(fpath,'wb') as fil: 150 | W.write(fil,I.reshape((-1,3*u.shape[1]))) 151 | -------------------------------------------------------------------------------- /flowutils/flow_viz.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torchvision.transforms import ToTensor 4 | 5 | def batchComputeFlowImage(uv): 6 | flow_im = torch.zeros(uv.size(0), 3, uv.size(2), uv.size(3) ) 7 | uv_np = uv.numpy() 8 | for i in range(uv.size(0)): 9 | flow_im[i] = ToTensor()(computeFlowImage(uv_np[i][0], uv_np[i][1])) 10 | return flow_im 11 | 12 | def computeFlowImage(u,v,logscale=True,scaledown=6,output=False): 13 | """ 14 | topleft is zero, u is horiz, v is vertical 15 | red is 3 o'clock, yellow is 6, light blue is 9, blue/purple is 12 16 | """ 17 | colorwheel = makecolorwheel() 18 | ncols = colorwheel.shape[0] 19 | 20 | radius = np.sqrt(u**2 + v**2) 21 | if output: 22 | print("Maximum flow magnitude: %04f" % np.max(radius)) 23 | if logscale: 24 | radius = np.log(radius + 1) 25 | if output: 26 | print("Maximum flow magnitude (after log): %0.4f" % np.max(radius)) 27 | radius = radius / scaledown 28 | if output: 29 | print("Maximum flow magnitude (after scaledown): %0.4f" % np.max(radius)) 30 | rot = np.arctan2(-v, -u) / np.pi 31 | 32 | fk = (rot+1)/2 * (ncols-1) # -1~1 maped to 0~ncols 33 | k0 = fk.astype(np.uint8) # 0, 1, 2, ..., ncols 34 | 35 | k1 = k0+1 36 | k1[k1 == ncols] = 0 37 | 38 | f = fk - k0 39 | 40 | ncolors = colorwheel.shape[1] 41 | img = np.zeros(u.shape+(ncolors,)) 42 | for i in range(ncolors): 43 | tmp = colorwheel[:,i] 44 | col0 = tmp[k0] 45 | col1 = tmp[k1] 46 | col = (1-f)*col0 + f*col1 47 | 48 | idx = radius <= 1 49 | # increase saturation with radius 50 | col[idx] = 1 - radius[idx]*(1-col[idx]) 51 | # out of range 52 | col[~idx] *= 0.75 53 | img[:,:,i] = np.floor(255*col).astype(np.uint8) 54 | 55 | return img.astype(np.uint8) 56 | 57 | 58 | def makecolorwheel(): 59 | # Create a colorwheel for visualization 60 | RY = 15 61 | YG = 6 62 | GC = 4 63 | CB = 11 64 | BM = 13 65 | MR = 6 66 | 67 | ncols = RY + YG + GC + CB + BM + MR 68 | 69 | colorwheel = np.zeros((ncols,3)) 70 | 71 | col = 0 72 | # RY 73 | colorwheel[0:RY,0] = 1 74 | colorwheel[0:RY,1] = np.arange(0,1,1./RY) 75 | col += RY 76 | 77 | # YG 78 | colorwheel[col:col+YG,0] = np.arange(1,0,-1./YG) 79 | colorwheel[col:col+YG,1] = 1 80 | col += YG 81 | 82 | # GC 83 | colorwheel[col:col+GC,1] = 1 84 | colorwheel[col:col+GC,2] = np.arange(0,1,1./GC) 85 | col += GC 86 | 87 | # CB 88 | colorwheel[col:col+CB,1] = np.arange(1,0,-1./CB) 89 | colorwheel[col:col+CB,2] = 1 90 | col += CB 91 | 92 | # BM 93 | colorwheel[col:col+BM,2] = 1 94 | colorwheel[col:col+BM,0] = np.arange(0,1,1./BM) 95 | col += BM 96 | 97 | # MR 98 | colorwheel[col:col+MR,2] = np.arange(1,0,-1./MR) 99 | colorwheel[col:col+MR,0] = 1 100 | 101 | return colorwheel 102 | -------------------------------------------------------------------------------- /flowutils/flowlib.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | """ 3 | # ============================== 4 | # flowlib.py 5 | # library for optical flow processing 6 | # Author: Ruoteng Li 7 | # Date: 6th Aug 2016 8 | # ============================== 9 | """ 10 | import png 11 | from . import pfm 12 | import numpy as np 13 | #import matplotlib.colors as cl 14 | #import matplotlib.pyplot as plt 15 | from PIL import Image 16 | import torch 17 | from torchvision.transforms import ToTensor 18 | 19 | 20 | UNKNOWN_FLOW_THRESH = 1e7 21 | SMALLFLOW = 0.0 22 | LARGEFLOW = 1e8 23 | 24 | """ 25 | ============= 26 | Flow Section 27 | ============= 28 | """ 29 | 30 | def batchComputeFlowImage(uv): 31 | flow_im = torch.zeros(uv.size(0), 3, uv.size(2), uv.size(3) ) 32 | uv_np = uv.numpy() 33 | for i in range(uv.size(0)): 34 | flow_im[i] = ToTensor()(flow_to_image(np.dstack((uv_np[i][0], uv_np[i][1])))) 35 | return flow_im 36 | 37 | def read_flow(filename): 38 | """ 39 | read optical flow data from flow file 40 | :param filename: name of the flow file 41 | :return: optical flow data in numpy array 42 | """ 43 | if filename.endswith('.flo'): 44 | flow = read_flo_file(filename) 45 | elif filename.endswith('.png'): 46 | flow = read_png_file(filename) 47 | elif filename.endswith('.pfm'): 48 | flow = read_pfm_file(filename) 49 | else: 50 | raise Exception('Invalid flow file format!') 51 | 52 | return flow 53 | 54 | 55 | def write_flow(flow, filename): 56 | """ 57 | write optical flow in Middlebury .flo format 58 | :param flow: optical flow map 59 | :param filename: optical flow file path to be saved 60 | :return: None 61 | """ 62 | f = open(filename, 'wb') 63 | magic = np.array([202021.25], dtype=np.float32) 64 | (height, width) = flow.shape[0:2] 65 | w = np.array([width], dtype=np.int32) 66 | h = np.array([height], dtype=np.int32) 67 | magic.tofile(f) 68 | w.tofile(f) 69 | h.tofile(f) 70 | flow.tofile(f) 71 | f.close() 72 | 73 | 74 | def save_flow_image(flow, image_file): 75 | """ 76 | save flow visualization into image file 77 | :param flow: optical flow data 78 | :param flow_fil 79 | :return: None 80 | """ 81 | flow_img = flow_to_image(flow) 82 | img_out = Image.fromarray(flow_img) 83 | img_out.save(image_file) 84 | 85 | 86 | def flowfile_to_imagefile(flow_file, image_file): 87 | """ 88 | convert flowfile into image file 89 | :param flow: optical flow data 90 | :param flow_fil 91 | :return: None 92 | """ 93 | flow = read_flow(flow_file) 94 | save_flow_image(flow, image_file) 95 | 96 | 97 | def segment_flow(flow): 98 | h = flow.shape[0] 99 | w = flow.shape[1] 100 | u = flow[:, :, 0] 101 | v = flow[:, :, 1] 102 | 103 | idx = ((abs(u) > LARGEFLOW) | (abs(v) > LARGEFLOW)) 104 | idx2 = (abs(u) == SMALLFLOW) 105 | class0 = (v == 0) & (u == 0) 106 | u[idx2] = 0.00001 107 | tan_value = v / u 108 | 109 | class1 = (tan_value < 1) & (tan_value >= 0) & (u > 0) & (v >= 0) 110 | class2 = (tan_value >= 1) & (u >= 0) & (v >= 0) 111 | class3 = (tan_value < -1) & (u <= 0) & (v >= 0) 112 | class4 = (tan_value < 0) & (tan_value >= -1) & (u < 0) & (v >= 0) 113 | class8 = (tan_value >= -1) & (tan_value < 0) & (u > 0) & (v <= 0) 114 | class7 = (tan_value < -1) & (u >= 0) & (v <= 0) 115 | class6 = (tan_value >= 1) & (u <= 0) & (v <= 0) 116 | class5 = (tan_value >= 0) & (tan_value < 1) & (u < 0) & (v <= 0) 117 | 118 | seg = np.zeros((h, w)) 119 | 120 | seg[class1] = 1 121 | seg[class2] = 2 122 | seg[class3] = 3 123 | seg[class4] = 4 124 | seg[class5] = 5 125 | seg[class6] = 6 126 | 
seg[class7] = 7 127 | seg[class8] = 8 128 | seg[class0] = 0 129 | seg[idx] = 0 130 | 131 | return seg 132 | 133 | 134 | def flow_error(tu, tv, u, v): 135 | """ 136 | Calculate average end point error 137 | :param tu: ground-truth horizontal flow map 138 | :param tv: ground-truth vertical flow map 139 | :param u: estimated horizontal flow map 140 | :param v: estimated vertical flow map 141 | :return: End point error of the estimated flow 142 | """ 143 | smallflow = 0.0 144 | ''' 145 | stu = tu[bord+1:end-bord,bord+1:end-bord] 146 | stv = tv[bord+1:end-bord,bord+1:end-bord] 147 | su = u[bord+1:end-bord,bord+1:end-bord] 148 | sv = v[bord+1:end-bord,bord+1:end-bord] 149 | ''' 150 | stu = tu[:] 151 | stv = tv[:] 152 | su = u[:] 153 | sv = v[:] 154 | 155 | idxUnknow = (abs(stu) > UNKNOWN_FLOW_THRESH) | (abs(stv) > UNKNOWN_FLOW_THRESH) 156 | stu[idxUnknow] = 0 157 | stv[idxUnknow] = 0 158 | su[idxUnknow] = 0 159 | sv[idxUnknow] = 0 160 | 161 | ind2 = [(np.absolute(stu) > smallflow) | (np.absolute(stv) > smallflow)] 162 | index_su = su[ind2] 163 | index_sv = sv[ind2] 164 | an = 1.0 / np.sqrt(index_su ** 2 + index_sv ** 2 + 1) 165 | un = index_su * an 166 | vn = index_sv * an 167 | 168 | index_stu = stu[ind2] 169 | index_stv = stv[ind2] 170 | tn = 1.0 / np.sqrt(index_stu ** 2 + index_stv ** 2 + 1) 171 | tun = index_stu * tn 172 | tvn = index_stv * tn 173 | 174 | ''' 175 | angle = un * tun + vn * tvn + (an * tn) 176 | index = [angle == 1.0] 177 | angle[index] = 0.999 178 | ang = np.arccos(angle) 179 | mang = np.mean(ang) 180 | mang = mang * 180 / np.pi 181 | ''' 182 | 183 | epe = np.sqrt((stu - su) ** 2 + (stv - sv) ** 2) 184 | epe = epe[ind2] 185 | mepe = np.mean(epe) 186 | return mepe 187 | 188 | 189 | def flow_to_image(flow): 190 | """ 191 | Convert flow into middlebury color code image 192 | :param flow: optical flow map 193 | :return: optical flow image in middlebury color 194 | """ 195 | u = flow[0] 196 | v = flow[1] 197 | 198 | maxu = -999. 199 | maxv = -999. 200 | minu = 999. 201 | minv = 999. 202 | 203 | idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH) 204 | u[idxUnknow] = 0 205 | v[idxUnknow] = 0 206 | 207 | maxu = max(maxu, np.max(u)) 208 | minu = min(minu, np.min(u)) 209 | 210 | maxv = max(maxv, np.max(v)) 211 | minv = min(minv, np.min(v)) 212 | 213 | rad = np.sqrt(u ** 2 + v ** 2) 214 | maxrad = max(-1, np.max(rad)) 215 | 216 | #print "max flow: %.4f\nflow range:\nu = %.3f .. %.3f\nv = %.3f .. %.3f" % (maxrad, minu,maxu, minv, maxv) 217 | 218 | u = u/(maxrad + np.finfo(float).eps) 219 | v = v/(maxrad + np.finfo(float).eps) 220 | 221 | img = compute_color(u, v) 222 | 223 | idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2) 224 | img[idx] = 0 225 | 226 | return img.transpose(2,0,1)/255. 
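# Usage sketch for flow_to_image (the flow values and shape below are made-up examples):
#   flow = np.random.randn(2, 256, 832).astype(np.float32)  # channels-first (u, v) flow
#   rgb = flow_to_image(flow)                                # 3 x 256 x 832 float image with values in [0, 1]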
227 | 228 | 229 | def evaluate_flow_file(gt_file, pred_file): 230 | """ 231 | evaluate the estimated optical flow end point error according to ground truth provided 232 | :param gt_file: ground truth file path 233 | :param pred_file: estimated optical flow file path 234 | :return: end point error, float32 235 | """ 236 | # Read flow files and calculate the errors 237 | gt_flow = read_flow(gt_file) # ground truth flow 238 | eva_flow = read_flow(pred_file) # predicted flow 239 | # Calculate errors 240 | average_pe = flow_error(gt_flow[:, :, 0], gt_flow[:, :, 1], eva_flow[:, :, 0], eva_flow[:, :, 1]) 241 | return average_pe 242 | 243 | 244 | def evaluate_flow(gt_flow, pred_flow): 245 | """ 246 | gt: ground-truth flow 247 | pred: estimated flow 248 | """ 249 | average_pe = flow_error(gt_flow[:, :, 0], gt_flow[:, :, 1], pred_flow[:, :, 0], pred_flow[:, :, 1]) 250 | return average_pe 251 | 252 | 253 | """ 254 | ============== 255 | Disparity Section 256 | ============== 257 | """ 258 | 259 | 260 | def read_disp_png(file_name): 261 | """ 262 | Read disparity map from KITTI .png file 263 | :param file_name: name of the disparity file 264 | :return: disparity data in matrix 265 | """ 266 | image_object = png.Reader(filename=file_name) 267 | image_direct = image_object.asDirect() 268 | image_data = list(image_direct[2]) 269 | (w, h) = image_direct[3]['size'] 270 | channel = len(image_data[0]) // w 271 | flow = np.zeros((h, w, channel), dtype=np.uint16) 272 | for i in range(len(image_data)): 273 | for j in range(channel): 274 | flow[i, :, j] = image_data[i][j::channel] 275 | return flow[:, :, 0] / 256 276 | 277 | 278 | def disp_to_flowfile(disp, filename): 279 | """ 280 | Convert a KITTI disparity map into a .flo flow file (disparity in the u channel, zeros in v) 281 | :param disp: disparity matrix 282 | :param filename: the flow file name to save 283 | :return: None 284 | """ 285 | f = open(filename, 'wb') 286 | magic = np.array([202021.25], dtype=np.float32) 287 | (height, width) = disp.shape[0:2] 288 | w = np.array([width], dtype=np.int32) 289 | h = np.array([height], dtype=np.int32) 290 | empty_map = np.zeros((height, width), dtype=np.float32) 291 | data = np.dstack((disp, empty_map)) 292 | magic.tofile(f) 293 | w.tofile(f) 294 | h.tofile(f) 295 | data.tofile(f) 296 | f.close() 297 | 298 | 299 | """ 300 | ============== 301 | Image Section 302 | ============== 303 | """ 304 | 305 | 306 | def read_image(filename): 307 | """ 308 | Read normal image of any format 309 | :param filename: name of the image file 310 | :return: image data in matrix uint8 type 311 | """ 312 | img = Image.open(filename) 313 | im = np.array(img) 314 | return im 315 | 316 | 317 | """ 318 | ============== 319 | Others 320 | ============== 321 | """ 322 | 323 | def pfm_to_flo(pfm_file): 324 | flow_filename = pfm_file[0:pfm_file.find('.pfm')] + '.flo' 325 | (data, scale) = pfm.readPFM(pfm_file) 326 | flow = data[:, :, 0:2] 327 | write_flow(flow, flow_filename) 328 | 329 | 330 | def scale_image(image, new_range): 331 | """ 332 | Linearly scale the image into desired range 333 | :param image: input image 334 | :param new_range: the new range to be aligned 335 | :return: image normalized in new range 336 | """ 337 | min_val = np.min(image).astype(np.float32) 338 | max_val = np.max(image).astype(np.float32) 339 | min_val_new = np.array(min(new_range), dtype=np.float32) 340 | max_val_new = np.array(max(new_range), dtype=np.float32) 341 | scaled_image = (image - min_val) / (max_val - min_val) * (max_val_new - min_val_new) + min_val_new 342 | return scaled_image.astype(np.uint8) 343 | 344 |
345 | def compute_color(u, v): 346 | """ 347 | compute optical flow color map 348 | :param u: optical flow horizontal map 349 | :param v: optical flow vertical map 350 | :return: optical flow in color code 351 | """ 352 | [h, w] = u.shape 353 | img = np.zeros([h, w, 3]) 354 | nanIdx = np.isnan(u) | np.isnan(v) 355 | u[nanIdx] = 0 356 | v[nanIdx] = 0 357 | 358 | colorwheel = make_color_wheel() 359 | ncols = np.size(colorwheel, 0) 360 | 361 | rad = np.sqrt(u**2+v**2) 362 | 363 | a = np.arctan2(-v, -u) / np.pi 364 | 365 | fk = (a+1) / 2 * (ncols - 1) + 1 366 | 367 | k0 = np.floor(fk).astype(int) 368 | 369 | k1 = k0 + 1 370 | k1[k1 == ncols+1] = 1 371 | f = fk - k0 372 | 373 | for i in range(0, np.size(colorwheel,1)): 374 | tmp = colorwheel[:, i] 375 | col0 = tmp[k0-1] / 255 376 | col1 = tmp[k1-1] / 255 377 | col = (1-f) * col0 + f * col1 378 | 379 | idx = rad <= 1 380 | col[idx] = 1-rad[idx]*(1-col[idx]) 381 | notidx = np.logical_not(idx) 382 | 383 | col[notidx] *= 0.75 384 | img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx))) 385 | 386 | return img 387 | 388 | 389 | def make_color_wheel(): 390 | """ 391 | Generate color wheel according Middlebury color code 392 | :return: Color wheel 393 | """ 394 | RY = 15 395 | YG = 6 396 | GC = 4 397 | CB = 11 398 | BM = 13 399 | MR = 6 400 | 401 | ncols = RY + YG + GC + CB + BM + MR 402 | 403 | colorwheel = np.zeros([ncols, 3]) 404 | 405 | col = 0 406 | 407 | # RY 408 | colorwheel[0:RY, 0] = 255 409 | colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY)) 410 | col += RY 411 | 412 | # YG 413 | colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG)) 414 | colorwheel[col:col+YG, 1] = 255 415 | col += YG 416 | 417 | # GC 418 | colorwheel[col:col+GC, 1] = 255 419 | colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC)) 420 | col += GC 421 | 422 | # CB 423 | colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB)) 424 | colorwheel[col:col+CB, 2] = 255 425 | col += CB 426 | 427 | # BM 428 | colorwheel[col:col+BM, 2] = 255 429 | colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM)) 430 | col += + BM 431 | 432 | # MR 433 | colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR)) 434 | colorwheel[col:col+MR, 0] = 255 435 | 436 | return colorwheel 437 | 438 | 439 | def read_flo_file(filename): 440 | """ 441 | Read from Middlebury .flo file 442 | :param flow_file: name of the flow file 443 | :return: optical flow data in matrix 444 | """ 445 | f = open(filename, 'rb') 446 | magic = np.fromfile(f, np.float32, count=1) 447 | data2d = None 448 | 449 | if 202021.25 != magic: 450 | print('Magic number incorrect. 
Invalid .flo file') 451 | else: 452 | w = np.fromfile(f, np.int32, count=1) 453 | h = np.fromfile(f, np.int32, count=1) 454 | print("Reading %d x %d flow file in .flo format" % (h, w)) 455 | data2d = np.fromfile(f, np.float32, count=2 * w * h) 456 | # reshape data into 3D array (columns, rows, channels) 457 | data2d = np.resize(data2d, (h[0], w[0], 2)) 458 | f.close() 459 | return data2d 460 | 461 | 462 | def read_png_file(flow_file): 463 | """ 464 | Read from KITTI .png file 465 | :param flow_file: name of the flow file 466 | :return: optical flow data in matrix 467 | """ 468 | flow_object = png.Reader(filename=flow_file) 469 | flow_direct = flow_object.asDirect() 470 | flow_data = list(flow_direct[2]) 471 | (w, h) = flow_direct[3]['size'] 472 | print("Reading %d x %d flow file in .png format" % (h, w)) 473 | flow = np.zeros((h, w, 3), dtype=np.float64) 474 | for i in range(len(flow_data)): 475 | flow[i, :, 0] = flow_data[i][0::3] 476 | flow[i, :, 1] = flow_data[i][1::3] 477 | flow[i, :, 2] = flow_data[i][2::3] 478 | 479 | invalid_idx = (flow[:, :, 2] == 0) 480 | flow[:, :, 0:2] = (flow[:, :, 0:2] - 2 ** 15) / 64.0 481 | flow[invalid_idx, 0] = 0 482 | flow[invalid_idx, 1] = 0 483 | return flow 484 | 485 | 486 | def read_pfm_file(flow_file): 487 | """ 488 | Read from .pfm file 489 | :param flow_file: name of the flow file 490 | :return: optical flow data in matrix 491 | """ 492 | import pfm 493 | (data, scale) = pfm.readPFM(flow_file) 494 | return data 495 | 496 | 497 | # fast resample layer 498 | def resample(img, sz): 499 | """ 500 | img: flow map to be resampled 501 | sz: new flow map size. Must be [height,weight] 502 | """ 503 | original_image_size = img.shape 504 | in_height = img.shape[0] 505 | in_width = img.shape[1] 506 | out_height = sz[0] 507 | out_width = sz[1] 508 | out_flow = np.zeros((out_height, out_width, 2)) 509 | # find scale 510 | height_scale = float(in_height) / float(out_height) 511 | width_scale = float(in_width) / float(out_width) 512 | 513 | [x,y] = np.meshgrid(range(out_width), range(out_height)) 514 | xx = x * width_scale 515 | yy = y * height_scale 516 | x0 = np.floor(xx).astype(np.int32) 517 | x1 = x0 + 1 518 | y0 = np.floor(yy).astype(np.int32) 519 | y1 = y0 + 1 520 | 521 | x0 = np.clip(x0,0,in_width-1) 522 | x1 = np.clip(x1,0,in_width-1) 523 | y0 = np.clip(y0,0,in_height-1) 524 | y1 = np.clip(y1,0,in_height-1) 525 | 526 | Ia = img[y0,x0,:] 527 | Ib = img[y1,x0,:] 528 | Ic = img[y0,x1,:] 529 | Id = img[y1,x1,:] 530 | 531 | wa = (y1-yy) * (x1-xx) 532 | wb = (yy-y0) * (x1-xx) 533 | wc = (y1-yy) * (xx-x0) 534 | wd = (yy-y0) * (xx-x0) 535 | out_flow[:,:,0] = (Ia[:,:,0]*wa + Ib[:,:,0]*wb + Ic[:,:,0]*wc + Id[:,:,0]*wd) * out_width / in_width 536 | out_flow[:,:,1] = (Ia[:,:,1]*wa + Ib[:,:,1]*wb + Ic[:,:,1]*wc + Id[:,:,1]*wd) * out_height / in_height 537 | 538 | return out_flow 539 | -------------------------------------------------------------------------------- /flowutils/pfm.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | import sys 4 | 5 | 6 | def readPFM(file): 7 | file = open(file, 'rb') 8 | 9 | color = None 10 | width = None 11 | height = None 12 | scale = None 13 | endian = None 14 | 15 | header = file.readline().rstrip() 16 | if header == 'PF': 17 | color = True 18 | elif header == 'Pf': 19 | color = False 20 | else: 21 | raise Exception('Not a PFM file.') 22 | 23 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline()) 24 | if dim_match: 25 | width, height = map(int, 
dim_match.groups()) 26 | else: 27 | raise Exception('Malformed PFM header.') 28 | 29 | scale = float(file.readline().rstrip()) 30 | if scale < 0: # little-endian 31 | endian = '<' 32 | scale = -scale 33 | else: 34 | endian = '>' # big-endian 35 | 36 | data = np.fromfile(file, endian + 'f') 37 | shape = (height, width, 3) if color else (height, width) 38 | 39 | data = np.reshape(data, shape) 40 | data = np.flipud(data) 41 | return data, scale 42 | 43 | 44 | def writePFM(file, image, scale=1): 45 | file = open(file, 'wb') 46 | 47 | color = None 48 | 49 | if image.dtype.name != 'float32': 50 | raise Exception('Image dtype must be float32.') 51 | 52 | image = np.flipud(image) 53 | 54 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 55 | color = True 56 | elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale 57 | color = False 58 | else: 59 | raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.') 60 | 61 | file.write('PF\n' if color else 'Pf\n') 62 | file.write('%d %d\n' % (image.shape[1], image.shape[0])) 63 | 64 | endian = image.dtype.byteorder 65 | 66 | if endian == '<' or endian == '=' and sys.byteorder == 'little': 67 | scale = -scale 68 | 69 | file.write('%f\n' % scale) 70 | 71 | image.tofile(file) -------------------------------------------------------------------------------- /kitti_eval/depth_evaluation_utils.py: -------------------------------------------------------------------------------- 1 | # Mostly based on the code written by Clement Godard: 2 | # https://github.com/mrharicot/monodepth/blob/master/utils/evaluation_utils.py 3 | import numpy as np 4 | # import pandas as pd 5 | import datetime 6 | from collections import Counter 7 | from path import Path 8 | from scipy.misc import imread 9 | from tqdm import tqdm 10 | 11 | width_to_focal = dict() 12 | width_to_focal[1242] = 721.5377 13 | width_to_focal[1241] = 718.856 14 | width_to_focal[1224] = 707.0493 15 | width_to_focal[1238] = 718.3351 16 | 17 | 18 | class test_framework_KITTI(object): 19 | def __init__(self, root, test_files, seq_length=3, min_depth=1e-3, max_depth=100, step=1): 20 | self.root = root 21 | self.min_depth, self.max_depth = min_depth, max_depth 22 | self.calib_dirs, self.gt_files, self.img_files, self.displacements, self.cams = read_scene_data(self.root, test_files, seq_length, step) 23 | 24 | def __getitem__(self, i): 25 | tgt = imread(self.img_files[i][0]).astype(np.float32) 26 | depth = generate_depth_map(self.calib_dirs[i], self.gt_files[i], tgt.shape[:2], self.cams[i]) 27 | return {'tgt': tgt, 28 | 'ref': [imread(img).astype(np.float32) for img in self.img_files[i][1]], 29 | 'path':self.img_files[i][0], 30 | 'gt_depth': depth, 31 | 'displacements': np.array(self.displacements[i]), 32 | 'mask': generate_mask(depth, self.min_depth, self.max_depth) 33 | } 34 | 35 | def __len__(self): 36 | return len(self.img_files) 37 | 38 | 39 | ############################################################################### 40 | # EIGEN 41 | 42 | def read_text_lines(file_path): 43 | f = open(file_path, 'r') 44 | lines = f.readlines() 45 | f.close() 46 | lines = [l.rstrip() for l in lines] 47 | return lines 48 | 49 | 50 | def get_displacements(oxts_root, index, shifts): 51 | with open(oxts_root/'timestamps.txt') as f: 52 | timestamps = [datetime.datetime.strptime(ts[:-3], "%Y-%m-%d %H:%M:%S.%f").timestamp() for ts in f.read().splitlines()] 53 | oxts_data = np.genfromtxt(oxts_root/'data'/'{:010d}.txt'.format(index)) 54 | speed = 
np.linalg.norm(oxts_data[8:11]) 55 | assert(all(index+shift < len(timestamps) and index+shift >= 0 for shift in shifts)), str([index+shift for shift in shifts]) 56 | return [speed*abs(timestamps[index] - timestamps[index + shift]) for shift in shifts] 57 | 58 | 59 | def read_scene_data(data_root, test_list, seq_length=3, step=1): 60 | data_root = Path(data_root) 61 | gt_files = [] 62 | calib_dirs = [] 63 | im_files = [] 64 | cams = [] 65 | displacements = [] 66 | demi_length = (seq_length - 1) // 2 67 | shift_range = [step*i for i in list(range(-demi_length,0)) + list(range(1, demi_length + 1))] 68 | 69 | print('getting test metadata ... ') 70 | for sample in tqdm(test_list): 71 | tgt_img_path = data_root/sample 72 | date, scene, cam_id, _, index = sample[:-4].split('/') 73 | 74 | ref_imgs_path = [tgt_img_path.dirname()/'{:010d}.png'.format(int(index) + shift) for shift in shift_range] 75 | 76 | caped_shift_range = shift_range[:] # ensures ref_imgs are present, if not, set shift to 0 so that it will be discarded later 77 | for i,img in enumerate(ref_imgs_path): 78 | if not img.isfile(): 79 | ref_imgs_path[i] = tgt_img_path 80 | caped_shift_range[i] = 0 81 | 82 | vel_path = data_root/date/scene/'velodyne_points'/'data'/'{}.bin'.format(index[:10]) 83 | 84 | if tgt_img_path.isfile(): 85 | gt_files.append(vel_path) 86 | calib_dirs.append(data_root/date) 87 | im_files.append([tgt_img_path,ref_imgs_path]) 88 | cams.append(int(cam_id[-2:])) 89 | displacements.append(get_displacements(data_root/date/scene/'oxts', int(index), caped_shift_range)) 90 | else: 91 | print('{} missing'.format(tgt_img_path)) 92 | # print(num_probs, 'files missing') 93 | 94 | return calib_dirs, gt_files, im_files, displacements, cams 95 | 96 | 97 | def load_velodyne_points(file_name): 98 | # adapted from https://github.com/hunse/kitti 99 | points = np.fromfile(file_name, dtype=np.float32).reshape(-1, 4) 100 | points[:,3] = 1 101 | return points 102 | 103 | 104 | def read_calib_file(path): 105 | # taken from https://github.com/hunse/kitti 106 | float_chars = set("0123456789.e+- ") 107 | data = {} 108 | with open(path, 'r') as f: 109 | for line in f.readlines(): 110 | key, value = line.split(':', 1) 111 | value = value.strip() 112 | data[key] = value 113 | if float_chars.issuperset(value): 114 | # try to cast to float array 115 | try: 116 | data[key] = np.array(list(map(float, value.split(' ')))) 117 | except ValueError: 118 | # casting error: data[key] already eq. 
value, so pass 119 | pass 120 | 121 | return data 122 | 123 | 124 | def get_focal_length_baseline(calib_dir, cam=2): 125 | cam2cam = read_calib_file(calib_dir + 'calib_cam_to_cam.txt') 126 | P2_rect = cam2cam['P_rect_02'].reshape(3,4) 127 | P3_rect = cam2cam['P_rect_03'].reshape(3,4) 128 | 129 | # cam 2 is left of camera 0 -6cm 130 | # cam 3 is to the right +54cm 131 | b2 = P2_rect[0,3] / -P2_rect[0,0] 132 | b3 = P3_rect[0,3] / -P3_rect[0,0] 133 | baseline = b3-b2 134 | 135 | if cam == 2: 136 | focal_length = P2_rect[0,0] 137 | elif cam == 3: 138 | focal_length = P3_rect[0,0] 139 | 140 | return focal_length, baseline 141 | 142 | 143 | def sub2ind(matrixSize, rowSub, colSub): 144 | m, n = matrixSize 145 | return rowSub * (n-1) + colSub - 1 146 | 147 | 148 | def generate_depth_map(calib_dir, velo_file_name, im_shape, cam=2): 149 | # load calibration files 150 | cam2cam = read_calib_file(calib_dir/'calib_cam_to_cam.txt') 151 | velo2cam = read_calib_file(calib_dir/'calib_velo_to_cam.txt') 152 | velo2cam = np.hstack((velo2cam['R'].reshape(3,3), velo2cam['T'][..., np.newaxis])) 153 | velo2cam = np.vstack((velo2cam, np.array([0, 0, 0, 1.0]))) 154 | 155 | # compute projection matrix velodyne->image plane 156 | R_cam2rect = np.eye(4) 157 | R_cam2rect[:3,:3] = cam2cam['R_rect_00'].reshape(3,3) 158 | P_rect = cam2cam['P_rect_0'+str(cam)].reshape(3,4) 159 | P_velo2im = np.dot(np.dot(P_rect, R_cam2rect), velo2cam) 160 | 161 | # load velodyne points and remove all behind image plane (approximation) 162 | # each row of the velodyne data is forward, left, up, reflectance 163 | velo = load_velodyne_points(velo_file_name) 164 | velo = velo[velo[:, 0] >= 0, :] 165 | 166 | # project the points to the camera 167 | velo_pts_im = np.dot(P_velo2im, velo.T).T 168 | velo_pts_im[:, :2] = velo_pts_im[:,:2] / velo_pts_im[:,-1:] 169 | 170 | # check if in bounds 171 | # use minus 1 to get the exact same value as KITTI matlab code 172 | velo_pts_im[:, 0] = np.round(velo_pts_im[:,0]) - 1 173 | velo_pts_im[:, 1] = np.round(velo_pts_im[:,1]) - 1 174 | val_inds = (velo_pts_im[:, 0] >= 0) & (velo_pts_im[:, 1] >= 0) 175 | val_inds = val_inds & (velo_pts_im[:,0] < im_shape[1]) & (velo_pts_im[:,1] < im_shape[0]) 176 | velo_pts_im = velo_pts_im[val_inds, :] 177 | 178 | # project to image 179 | depth = np.zeros((im_shape)) 180 | depth[velo_pts_im[:, 1].astype(int), velo_pts_im[:, 0].astype(int)] = velo_pts_im[:, 2] 181 | 182 | # find the duplicate points and choose the closest depth 183 | inds = sub2ind(depth.shape, velo_pts_im[:, 1], velo_pts_im[:, 0]) 184 | dupe_inds = [item for item, count in Counter(inds).items() if count > 1] 185 | for dd in dupe_inds: 186 | pts = np.where(inds == dd)[0] 187 | x_loc = int(velo_pts_im[pts[0], 0]) 188 | y_loc = int(velo_pts_im[pts[0], 1]) 189 | depth[y_loc, x_loc] = velo_pts_im[pts, 2].min() 190 | depth[depth < 0] = 0 191 | return depth 192 | 193 | 194 | def generate_mask(gt_depth, min_depth, max_depth): 195 | mask = np.logical_and(gt_depth > min_depth, 196 | gt_depth < max_depth) 197 | # crop used by Garg ECCV16 to reproduce Eigen NIPS14 results 198 | # if used on gt_size 370x1224 produces a crop of [-218, -3, 44, 1180] 199 | gt_height, gt_width = gt_depth.shape 200 | crop = np.array([0.40810811 * gt_height, 0.99189189 * gt_height, 201 | 0.03594771 * gt_width, 0.96405229 * gt_width]).astype(np.int32) 202 | 203 | crop_mask = np.zeros(mask.shape) 204 | crop_mask[crop[0]:crop[1],crop[2]:crop[3]] = 1 205 | mask = np.logical_and(mask, crop_mask) 206 | return mask 207 |
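# Usage sketch (the paths and image size below are hypothetical):
#   from path import Path
#   calib_dir = Path('/data/kitti/2011_09_26')                            # folder holding calib_cam_to_cam.txt and calib_velo_to_cam.txt
#   depth = generate_depth_map(calib_dir, velo_file, (375, 1242), cam=2)  # velo_file: path to a velodyne .bin scan
#   mask = generate_mask(depth, min_depth=1e-3, max_depth=100)            # Garg/Eigen crop plus depth range check
#   valid_gt = depth[mask]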
-------------------------------------------------------------------------------- /kitti_eval/pose_evaluation_utils.py: -------------------------------------------------------------------------------- 1 | # Mostly based on the code written by Clement Godard: 2 | # https://github.com/mrharicot/monodepth/blob/master/utils/evaluation_utils.py 3 | import numpy as np 4 | # import pandas as pd 5 | from path import Path 6 | from scipy.misc import imread 7 | from tqdm import tqdm 8 | 9 | 10 | class test_framework_KITTI(object): 11 | def __init__(self, root, sequence_set, seq_length=3, step=1): 12 | self.root = root 13 | self.img_files, self.poses, self.sample_indices = read_scene_data(self.root, sequence_set, seq_length, step) 14 | 15 | def generator(self): 16 | for img_list, pose_list, sample_list in zip(self.img_files, self.poses, self.sample_indices): 17 | for snippet_indices in sample_list: 18 | imgs = [imread(img_list[i]).astype(np.float32) for i in snippet_indices] 19 | 20 | poses = np.stack([pose_list[i] for i in snippet_indices]) 21 | first_pose = poses[0] 22 | poses[:,:,-1] -= first_pose[:,-1] 23 | compensated_poses = np.linalg.inv(first_pose[:,:3]) @ poses 24 | 25 | yield {'imgs': imgs, 26 | 'path': img_list[0], 27 | 'poses': compensated_poses 28 | } 29 | 30 | def __iter__(self): 31 | return self.generator() 32 | 33 | def __len__(self): 34 | return sum(len(imgs) for imgs in self.img_files) 35 | 36 | 37 | def read_scene_data(data_root, sequence_set, seq_length=3, step=1): 38 | data_root = Path(data_root) 39 | im_sequences = [] 40 | poses_sequences = [] 41 | indices_sequences = [] 42 | demi_length = (seq_length - 1) // 2 43 | shift_range = np.array([step*i for i in range(-demi_length, demi_length + 1)]).reshape(1, -1) 44 | 45 | sequences = set() 46 | for seq in sequence_set: 47 | corresponding_dirs = set((data_root/'sequences').dirs(seq)) 48 | sequences = sequences | corresponding_dirs 49 | 50 | print('getting test metadata for these sequences : {}'.format(sequences)) 51 | for sequence in tqdm(sequences): 52 | poses = np.genfromtxt(data_root/'poses'/'{}.txt'.format(sequence.name)).astype(np.float64).reshape(-1, 3, 4) 53 | imgs = sorted((sequence/'image_2').files('*.png')) 54 | # construct seq_length-long snippet sequences 55 | tgt_indices = np.arange(demi_length, len(imgs) - demi_length).reshape(-1, 1) 56 | snippet_indices = shift_range + tgt_indices 57 | im_sequences.append(imgs) 58 | poses_sequences.append(poses) 59 | indices_sequences.append(snippet_indices) 60 | return im_sequences, poses_sequences, indices_sequences -------------------------------------------------------------------------------- /kitti_eval/validation_flow_video.py: -------------------------------------------------------------------------------- 1 | # Author: Anurag Ranjan 2 | # Copyright (c) 2019, Anurag Ranjan 3 | # All rights reserved.
4 | # based on github.com/ClementPinard/SfMLearner-Pytorch 5 | 6 | import torch.utils.data as data 7 | import numpy as np 8 | from scipy.misc import imread 9 | from PIL import Image 10 | from path import Path 11 | from flowutils import flow_io 12 | import torch 13 | import os 14 | from skimage import transform as sktransform 15 | 16 | def crawl_folders(folders_list): 17 | imgs = [] 18 | depth = [] 19 | for folder in folders_list: 20 | current_imgs = sorted(folder.files('*.jpg')) 21 | current_depth = [] 22 | for img in current_imgs: 23 | d = img.dirname()/(img.name[:-4] + '.npy') 24 | assert(d.isfile()), "depth file {} not found".format(str(d)) 25 | depth.append(d) 26 | imgs.extend(current_imgs) 27 | depth.extend(current_depth) 28 | return imgs, depth 29 | 30 | 31 | def load_as_float(path): 32 | return imread(path).astype(np.float32) 33 | 34 | def get_intrinsics(calib_file, cid='02'): 35 | #print(zoom_x, zoom_y) 36 | filedata = read_raw_calib_file(calib_file) 37 | P_rect = np.reshape(filedata['P_rect_' + cid], (3, 4)) 38 | return P_rect[:,:3] 39 | 40 | 41 | def read_raw_calib_file(filepath): 42 | # From https://github.com/utiasSTARS/pykitti/blob/master/pykitti/utils.py 43 | """Read in a calibration file and parse into a dictionary.""" 44 | data = {} 45 | 46 | with open(filepath, 'r') as f: 47 | for line in f.readlines(): 48 | key, value = line.split(':', 1) 49 | # The only non-float values in these files are dates, which 50 | # we don't care about anyway 51 | try: 52 | data[key] = np.array([float(x) for x in value.split()]) 53 | except ValueError: 54 | pass 55 | return data 56 | 57 | class KITTI2015Test(data.Dataset): 58 | """ 59 | Kitti 2015 flow loader 60 | transform functions must take in a list a images and a numpy array which can be None 61 | """ 62 | 63 | def __init__(self, root, sequence_length, transform=None, N=200, phase='testing'): 64 | self.root = Path(root) 65 | self.sequence_length = sequence_length 66 | self.N = N 67 | self.transform = transform 68 | self.phase = phase 69 | seq_ids = list(range(-int(sequence_length/2), int(sequence_length/2)+1)) 70 | seq_ids.remove(0) 71 | self.seq_ids = [x+10 for x in seq_ids] 72 | 73 | def __getitem__(self, index): 74 | tgt_img_path = self.root.joinpath('data_scene_flow_multiview', self.phase, 'image_2',str(index).zfill(6)+'_10.png') 75 | ref_img_paths = [self.root.joinpath('data_scene_flow_multiview', self.phase, 'image_2',str(index).zfill(6)+'_'+str(k).zfill(2)+'.png') for k in self.seq_ids] 76 | cam_calib_path = self.root.joinpath('data_scene_flow_calib', self.phase, 'calib_cam_to_cam', str(index).zfill(6)+'.txt') 77 | 78 | tgt_img_original = load_as_float(tgt_img_path) 79 | tgt_img = load_as_float(tgt_img_path) 80 | ref_imgs = [load_as_float(ref_img) for ref_img in ref_img_paths] 81 | intrinsics = get_intrinsics(cam_calib_path).astype('float32') 82 | tgt_img_original = torch.FloatTensor(tgt_img_original.transpose(2,0,1)) 83 | 84 | if self.transform is not None: 85 | imgs, intrinsics = self.transform([tgt_img] + ref_imgs, np.copy(intrinsics)) 86 | tgt_img = imgs[0] 87 | ref_imgs = imgs[1:] 88 | else: 89 | intrinsics = np.copy(intrinsics) 90 | return tgt_img, ref_imgs, intrinsics, np.linalg.inv(intrinsics), tgt_img_original 91 | 92 | def __len__(self): 93 | return self.N 94 | 95 | class ValidationFlow(data.Dataset): 96 | """ 97 | Kitti 2015 flow loader 98 | transform functions must take in a list a images and a numpy array which can be None 99 | """ 100 | 101 | def __init__(self, root, sequence_length, transform=None, N=200, 
phase='training', occ='flow_occ'): 102 | self.root = Path(root) 103 | self.sequence_length = sequence_length 104 | self.N = N 105 | self.transform = transform 106 | self.phase = phase 107 | seq_ids = list(range(-int(sequence_length/2), int(sequence_length/2)+1)) 108 | seq_ids.remove(0) 109 | self.seq_ids = [x for x in seq_ids] 110 | self.occ = occ 111 | 112 | def __getitem__(self, index): 113 | 114 | path_list=os.listdir(self.root) 115 | path_list.sort() 116 | index += 2 117 | 118 | tgt_img_path = self.root.joinpath(str(index).zfill(10)+'.png') 119 | ref_img_paths = [self.root.joinpath(str(index+k).zfill(10)+'.png') for k in self.seq_ids] 120 | 121 | gt_flow_path = self.root.joinpath('data_scene_flow', self.phase, self.occ, str(index).zfill(6)+'_10.png') 122 | cam_calib_path = self.root.joinpath('data_scene_flow_calib', self.phase, 'calib_cam_to_cam', str(index).zfill(6)+'.txt') 123 | obj_map_path = self.root.joinpath('data_scene_flow', self.phase, 'obj_map', str(index).zfill(6)+'_10.png') 124 | 125 | tgt_img = load_as_float(tgt_img_path) 126 | ref_imgs = [load_as_float(ref_img) for ref_img in ref_img_paths] 127 | if os.path.isfile(obj_map_path): 128 | obj_map = load_as_float(obj_map_path) 129 | else: 130 | obj_map = np.ones((tgt_img.shape[0], tgt_img.shape[1])) 131 | u,v,valid = flow_io.flow_read_png(gt_flow_path) 132 | gtFlow = np.dstack((u,v,valid)) 133 | #gtFlow = scale_flow(np.dstack((u,v,valid)), h=self.flow_h, w=self.flow_w) 134 | gtFlow = torch.FloatTensor(gtFlow.transpose(2,0,1)) 135 | intrinsics = get_intrinsics(cam_calib_path).astype('float32') 136 | 137 | if self.transform is not None: 138 | imgs, intrinsics = self.transform([tgt_img] + ref_imgs, np.copy(intrinsics)) 139 | tgt_img = imgs[0] 140 | ref_imgs = imgs[1:] 141 | else: 142 | intrinsics = np.copy(intrinsics) 143 | return tgt_img, ref_imgs, intrinsics, np.linalg.inv(intrinsics), gtFlow, obj_map 144 | 145 | def __len__(self): 146 | return self.N 147 | 148 | class ValidationMask(data.Dataset): 149 | """ 150 | Kitti 2015 flow loader 151 | transform functions must take in a list a images and a numpy array which can be None 152 | """ 153 | 154 | def __init__(self, root, sequence_length, transform=None, N=200, phase='training'): 155 | self.root = Path(root) 156 | self.sequence_length = sequence_length 157 | self.N = N 158 | self.transform = transform 159 | self.phase = phase 160 | seq_ids = list(range(-int(sequence_length/2), int(sequence_length/2)+1)) 161 | seq_ids.remove(0) 162 | self.seq_ids = [x+10 for x in seq_ids] 163 | 164 | def __getitem__(self, index): 165 | tgt_img_path = self.root.joinpath('data_scene_flow_multiview', self.phase, 'image_2',str(index).zfill(6)+'_10.png') 166 | ref_img_paths = [self.root.joinpath('data_scene_flow_multiview', self.phase, 'image_2',str(index).zfill(6)+'_'+str(k).zfill(2)+'.png') for k in self.seq_ids] 167 | 168 | gt_flow_path = self.root.joinpath('data_scene_flow', self.phase, 'flow_occ', str(index).zfill(6)+'_10.png') 169 | cam_calib_path = self.root.joinpath('data_scene_flow_calib', self.phase, 'calib_cam_to_cam', str(index).zfill(6)+'.txt') 170 | obj_map_path = self.root.joinpath('data_scene_flow', self.phase, 'obj_map', str(index).zfill(6)+'_10.png') 171 | semantic_map_path = self.root.joinpath('semantic_labels', self.phase, 'semantic', str(index).zfill(6)+'_10.png') 172 | 173 | tgt_img = load_as_float(tgt_img_path) 174 | ref_imgs = [load_as_float(ref_img) for ref_img in ref_img_paths] 175 | obj_map = torch.LongTensor(np.array(Image.open(obj_map_path))) 176 | semantic_map = 
torch.LongTensor(np.array(Image.open(semantic_map_path))) 177 | u,v,valid = flow_io.flow_read_png(gt_flow_path) 178 | gtFlow = np.dstack((u,v,valid)) 179 | #gtFlow = scale_flow(np.dstack((u,v,valid)), h=self.flow_h, w=self.flow_w) 180 | gtFlow = torch.FloatTensor(gtFlow.transpose(2,0,1)) 181 | intrinsics = get_intrinsics(cam_calib_path).astype('float32') 182 | 183 | if self.transform is not None: 184 | imgs, intrinsics = self.transform([tgt_img] + ref_imgs, np.copy(intrinsics)) 185 | tgt_img = imgs[0] 186 | ref_imgs = imgs[1:] 187 | else: 188 | intrinsics = np.copy(intrinsics) 189 | return tgt_img, ref_imgs, intrinsics, np.linalg.inv(intrinsics), gtFlow, obj_map, semantic_map 190 | 191 | def __len__(self): 192 | return self.N 193 | 194 | class ValidationFlowKitti2012(data.Dataset): 195 | """ 196 | Kitti 2012 flow loader 197 | transform functions must take in a list a images and a numpy array which can be None 198 | """ 199 | 200 | def __init__(self, root, sequence_length=5, transform=None, N=194, flow_w=1024, flow_h=384, phase='training'): 201 | self.root = Path(root) 202 | self.sequence_length = sequence_length 203 | self.N = N 204 | self.transform = transform 205 | self.phase = phase 206 | self.flow_h = flow_h 207 | self.flow_w = flow_w 208 | 209 | def __getitem__(self, index): 210 | tgt_img_path = self.root.joinpath('data_stereo_flow', self.phase, 'colored_0',str(index).zfill(6)+'_10.png') 211 | ref_img_path = self.root.joinpath('data_stereo_flow', self.phase, 'colored_0',str(index).zfill(6)+'_11.png') 212 | gt_flow_path = self.root.joinpath('data_stereo_flow', self.phase, 'flow_occ', str(index).zfill(6)+'_10.png') 213 | 214 | tgt_img = load_as_float(tgt_img_path) 215 | ref_img = load_as_float(ref_img_path) 216 | 217 | u,v,valid = flow_io.flow_read_png(gt_flow_path) 218 | #gtFlow = scale_flow(np.dstack((u,v,valid)), h=self.flow_h, w=self.flow_w) 219 | gtFlow = np.dstack((u,v,valid)) 220 | gtFlow = torch.FloatTensor(gtFlow.transpose(2,0,1)) 221 | 222 | intrinsics = np.eye(3) 223 | if self.transform is not None: 224 | imgs, intrinsics = self.transform([tgt_img] + [ref_img], np.copy(intrinsics)) 225 | tgt_img = imgs[0] 226 | ref_img = imgs[1] 227 | else: 228 | intrinsics = np.copy(intrinsics) 229 | return tgt_img, ref_img, intrinsics, np.linalg.inv(intrinsics), gtFlow 230 | 231 | def __len__(self): 232 | return self.N 233 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | from blessings import Terminal 2 | import progressbar 3 | import sys 4 | 5 | 6 | class TermLogger(object): 7 | def __init__(self, n_epochs, train_size, valid_size): 8 | self.n_epochs = n_epochs 9 | self.train_size = train_size 10 | self.valid_size = valid_size 11 | self.t = Terminal() 12 | s = 10 13 | e = 1 # epoch bar position 14 | tr = 3 # train bar position 15 | ts = 6 # valid bar position 16 | h = self.t.height 17 | 18 | for i in range(10): 19 | print('') 20 | self.epoch_bar = progressbar.ProgressBar(maxval=n_epochs, fd=Writer(self.t, (0, h-s+e))) 21 | 22 | self.train_writer = Writer(self.t, (0, h-s+tr)) 23 | self.train_bar_writer = Writer(self.t, (0, h-s+tr+1)) 24 | 25 | self.valid_writer = Writer(self.t, (0, h-s+ts)) 26 | self.valid_bar_writer = Writer(self.t, (0, h-s+ts+1)) 27 | 28 | self.reset_train_bar() 29 | self.reset_valid_bar() 30 | 31 | def reset_train_bar(self): 32 | self.train_bar = progressbar.ProgressBar(maxval=self.train_size, fd=self.train_bar_writer).start() 33 | 34 | def 
reset_valid_bar(self): 35 | self.valid_bar = progressbar.ProgressBar(maxval=self.valid_size, fd=self.valid_bar_writer).start() 36 | 37 | 38 | class Writer(object): 39 | """Create an object with a write method that writes to a 40 | specific place on the screen, defined at instantiation. 41 | 42 | This is the glue between blessings and progressbar. 43 | """ 44 | 45 | def __init__(self, t, location): 46 | """ 47 | Input: location - tuple of ints (x, y), the position 48 | of the bar in the terminal 49 | """ 50 | self.location = location 51 | self.t = t 52 | 53 | def write(self, string): 54 | with self.t.location(*self.location): 55 | sys.stdout.write("\033[K") 56 | print(string) 57 | 58 | def flush(self): 59 | return 60 | 61 | 62 | class AverageMeter(object): 63 | """Computes and stores the average and current value""" 64 | 65 | def __init__(self, i=1, precision=3): 66 | self.meters = i 67 | self.precision = precision 68 | self.reset(self.meters) 69 | 70 | def reset(self, i): 71 | self.val = [0]*i 72 | self.avg = [0]*i 73 | self.sum = [0]*i 74 | self.count = 0 75 | 76 | def update(self, val, n=1): 77 | if not isinstance(val, list): 78 | val = [val] 79 | assert(len(val) == self.meters) 80 | self.count += n 81 | for i,v in enumerate(val): 82 | self.val[i] = v 83 | self.sum[i] += v * n 84 | self.avg[i] = self.sum[i] / self.count 85 | 86 | def __repr__(self): 87 | val = ' '.join(['{:.{}f}'.format(v, self.precision) for v in self.val]) 88 | avg = ' '.join(['{:.{}f}'.format(a, self.precision) for a in self.avg]) 89 | return '{} ({})'.format(val, avg) 90 | -------------------------------------------------------------------------------- /models/DispNetS.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def downsample_conv(in_planes, out_planes, kernel_size=3): 6 | return nn.Sequential( 7 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, padding=(kernel_size-1)//2), 8 | nn.ReLU(inplace=True), 9 | nn.Conv2d(out_planes, out_planes, kernel_size=kernel_size, padding=(kernel_size-1)//2), 10 | nn.ReLU(inplace=True) 11 | ) 12 | 13 | 14 | def predict_disp(in_planes): 15 | return nn.Sequential( 16 | nn.Conv2d(in_planes, 1, kernel_size=3, padding=1), 17 | nn.Sigmoid() 18 | ) 19 | 20 | 21 | def conv(in_planes, out_planes): 22 | return nn.Sequential( 23 | nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=1), 24 | nn.ReLU(inplace=True) 25 | ) 26 | 27 | 28 | def upconv(in_planes, out_planes): 29 | return nn.Sequential( 30 | nn.ConvTranspose2d(in_planes, out_planes, kernel_size=3, stride=2, padding=1, output_padding=1), 31 | nn.ReLU(inplace=True) 32 | ) 33 | 34 | 35 | def crop_like(input, ref): 36 | assert(input.size(2) >= ref.size(2) and input.size(3) >= ref.size(3)) 37 | return input[:, :, :ref.size(2), :ref.size(3)] 38 | 39 | 40 | class DispNetS(nn.Module): 41 | 42 | def __init__(self, alpha=10, beta=0.01): 43 | super(DispNetS, self).__init__() 44 | 45 | self.alpha = alpha 46 | self.beta = beta 47 | 48 | conv_planes = [32, 64, 128, 256, 512, 512, 512] 49 | self.conv1 = downsample_conv(3, conv_planes[0], kernel_size=7) 50 | self.conv2 = downsample_conv(conv_planes[0], conv_planes[1], kernel_size=5) 51 | self.conv3 = downsample_conv(conv_planes[1], conv_planes[2]) 52 | self.conv4 = downsample_conv(conv_planes[2], conv_planes[3]) 53 | self.conv5 = downsample_conv(conv_planes[3], conv_planes[4]) 54 | self.conv6 = downsample_conv(conv_planes[4], conv_planes[5]) 55 | self.conv7 = 
downsample_conv(conv_planes[5], conv_planes[6]) 56 | 57 | upconv_planes = [512, 512, 256, 128, 64, 32, 16] 58 | self.upconv7 = upconv(conv_planes[6], upconv_planes[0]) 59 | self.upconv6 = upconv(upconv_planes[0], upconv_planes[1]) 60 | self.upconv5 = upconv(upconv_planes[1], upconv_planes[2]) 61 | self.upconv4 = upconv(upconv_planes[2], upconv_planes[3]) 62 | self.upconv3 = upconv(upconv_planes[3], upconv_planes[4]) 63 | self.upconv2 = upconv(upconv_planes[4], upconv_planes[5]) 64 | self.upconv1 = upconv(upconv_planes[5], upconv_planes[6]) 65 | 66 | self.iconv7 = conv(upconv_planes[0] + conv_planes[5], upconv_planes[0]) 67 | self.iconv6 = conv(upconv_planes[1] + conv_planes[4], upconv_planes[1]) 68 | self.iconv5 = conv(upconv_planes[2] + conv_planes[3], upconv_planes[2]) 69 | self.iconv4 = conv(upconv_planes[3] + conv_planes[2], upconv_planes[3]) 70 | self.iconv3 = conv(1 + upconv_planes[4] + conv_planes[1], upconv_planes[4]) 71 | self.iconv2 = conv(1 + upconv_planes[5] + conv_planes[0], upconv_planes[5]) 72 | self.iconv1 = conv(1 + upconv_planes[6], upconv_planes[6]) 73 | 74 | self.predict_disp4 = predict_disp(upconv_planes[3]) 75 | self.predict_disp3 = predict_disp(upconv_planes[4]) 76 | self.predict_disp2 = predict_disp(upconv_planes[5]) 77 | self.predict_disp1 = predict_disp(upconv_planes[6]) 78 | 79 | def init_weights(self): 80 | for m in self.modules(): 81 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 82 | nn.init.xavier_uniform(m.weight.data) 83 | if m.bias is not None: 84 | m.bias.data.zero_() 85 | 86 | def forward(self, x): 87 | out_conv1 = self.conv1(x) 88 | out_conv2 = self.conv2(out_conv1) 89 | out_conv3 = self.conv3(out_conv2) 90 | out_conv4 = self.conv4(out_conv3) 91 | out_conv5 = self.conv5(out_conv4) 92 | out_conv6 = self.conv6(out_conv5) 93 | out_conv7 = self.conv7(out_conv6) 94 | 95 | out_upconv7 = crop_like(self.upconv7(out_conv7), out_conv6) 96 | concat7 = torch.cat((out_upconv7, out_conv6), 1) 97 | out_iconv7 = self.iconv7(concat7) 98 | 99 | out_upconv6 = crop_like(self.upconv6(out_iconv7), out_conv5) 100 | concat6 = torch.cat((out_upconv6, out_conv5), 1) 101 | out_iconv6 = self.iconv6(concat6) 102 | 103 | out_upconv5 = crop_like(self.upconv5(out_iconv6), out_conv4) 104 | concat5 = torch.cat((out_upconv5, out_conv4), 1) 105 | out_iconv5 = self.iconv5(concat5) 106 | 107 | out_upconv4 = crop_like(self.upconv4(out_iconv5), out_conv3) 108 | concat4 = torch.cat((out_upconv4, out_conv3), 1) 109 | out_iconv4 = self.iconv4(concat4) 110 | disp4 = self.alpha * self.predict_disp4(out_iconv4) + self.beta 111 | 112 | out_upconv3 = crop_like(self.upconv3(out_iconv4), out_conv2) 113 | disp4_up = crop_like(nn.functional.upsample(disp4, scale_factor=2, mode='bilinear'), out_conv2) 114 | concat3 = torch.cat((out_upconv3, out_conv2, disp4_up), 1) 115 | out_iconv3 = self.iconv3(concat3) 116 | disp3 = self.alpha * self.predict_disp3(out_iconv3) + self.beta 117 | 118 | out_upconv2 = crop_like(self.upconv2(out_iconv3), out_conv1) 119 | disp3_up = crop_like(nn.functional.upsample(disp3, scale_factor=2, mode='bilinear'), out_conv1) 120 | concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1) 121 | out_iconv2 = self.iconv2(concat2) 122 | disp2 = self.alpha * self.predict_disp2(out_iconv2) + self.beta 123 | 124 | out_upconv1 = crop_like(self.upconv1(out_iconv2), x) 125 | disp2_up = crop_like(nn.functional.upsample(disp2, scale_factor=2, mode='bilinear'), x) 126 | concat1 = torch.cat((out_upconv1, disp2_up), 1) 127 | out_iconv1 = self.iconv1(concat1) 128 | disp1 = 
self.alpha * self.predict_disp1(out_iconv1) + self.beta 129 | 130 | if self.training: 131 | return disp1, disp2, disp3, disp4 132 | else: 133 | return disp1 134 | -------------------------------------------------------------------------------- /models/DispNetS6.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def downsample_conv(in_planes, out_planes, kernel_size=3): 6 | return nn.Sequential( 7 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, padding=(kernel_size-1)//2), 8 | nn.ReLU(inplace=True), 9 | nn.Conv2d(out_planes, out_planes, kernel_size=kernel_size, padding=(kernel_size-1)//2), 10 | nn.ReLU(inplace=True) 11 | ) 12 | 13 | 14 | def predict_disp(in_planes): 15 | return nn.Sequential( 16 | nn.Conv2d(in_planes, 1, kernel_size=3, padding=1), 17 | nn.Sigmoid() 18 | ) 19 | 20 | 21 | def conv(in_planes, out_planes): 22 | return nn.Sequential( 23 | nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=1), 24 | nn.ReLU(inplace=True) 25 | ) 26 | 27 | 28 | def upconv(in_planes, out_planes): 29 | return nn.Sequential( 30 | nn.ConvTranspose2d(in_planes, out_planes, kernel_size=3, stride=2, padding=1, output_padding=1), 31 | nn.ReLU(inplace=True) 32 | ) 33 | 34 | 35 | def crop_like(input, ref): 36 | assert(input.size(2) >= ref.size(2) and input.size(3) >= ref.size(3)) 37 | return input[:, :, :ref.size(2), :ref.size(3)] 38 | 39 | 40 | class DispNetS6(nn.Module): 41 | 42 | def __init__(self, alpha=10, beta=0.01): 43 | super(DispNetS6, self).__init__() 44 | 45 | self.alpha = alpha 46 | self.beta = beta 47 | 48 | conv_planes = [32, 64, 128, 256, 512, 512, 512] 49 | self.conv1 = downsample_conv(3, conv_planes[0], kernel_size=7) 50 | self.conv2 = downsample_conv(conv_planes[0], conv_planes[1], kernel_size=5) 51 | self.conv3 = downsample_conv(conv_planes[1], conv_planes[2]) 52 | self.conv4 = downsample_conv(conv_planes[2], conv_planes[3]) 53 | self.conv5 = downsample_conv(conv_planes[3], conv_planes[4]) 54 | self.conv6 = downsample_conv(conv_planes[4], conv_planes[5]) 55 | self.conv7 = downsample_conv(conv_planes[5], conv_planes[6]) 56 | 57 | upconv_planes = [512, 512, 256, 128, 64, 32, 16] 58 | self.upconv7 = upconv(conv_planes[6], upconv_planes[0]) 59 | self.upconv6 = upconv(upconv_planes[0], upconv_planes[1]) 60 | self.upconv5 = upconv(upconv_planes[1], upconv_planes[2]) 61 | self.upconv4 = upconv(upconv_planes[2], upconv_planes[3]) 62 | self.upconv3 = upconv(upconv_planes[3], upconv_planes[4]) 63 | self.upconv2 = upconv(upconv_planes[4], upconv_planes[5]) 64 | self.upconv1 = upconv(upconv_planes[5], upconv_planes[6]) 65 | 66 | self.iconv7 = conv(upconv_planes[0] + conv_planes[5], upconv_planes[0]) 67 | self.iconv6 = conv(upconv_planes[1] + conv_planes[4], upconv_planes[1]) 68 | self.iconv5 = conv(upconv_planes[2] + conv_planes[3], upconv_planes[2]) 69 | self.iconv4 = conv(upconv_planes[3] + conv_planes[2], upconv_planes[3]) 70 | self.iconv3 = conv(1 + upconv_planes[4] + conv_planes[1], upconv_planes[4]) 71 | self.iconv2 = conv(1 + upconv_planes[5] + conv_planes[0], upconv_planes[5]) 72 | self.iconv1 = conv(1 + upconv_planes[6], upconv_planes[6]) 73 | 74 | self.predict_disp6 = predict_disp(upconv_planes[1]) 75 | self.predict_disp5 = predict_disp(upconv_planes[2]) 76 | self.predict_disp4 = predict_disp(upconv_planes[3]) 77 | self.predict_disp3 = predict_disp(upconv_planes[4]) 78 | self.predict_disp2 = predict_disp(upconv_planes[5]) 79 | self.predict_disp1 = predict_disp(upconv_planes[6]) 
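# Note: every predict_disp head ends in a sigmoid, so with the default alpha=10, beta=0.01 each predicted disparity lies in (0.01, 10.01); disp1 is at the input resolution and disp2..disp6 are at 1/2 .. 1/32 of it.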
80 | 81 | def init_weights(self): 82 | for m in self.modules(): 83 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 84 | nn.init.xavier_uniform(m.weight.data) 85 | if m.bias is not None: 86 | m.bias.data.zero_() 87 | 88 | def forward(self, x): 89 | out_conv1 = self.conv1(x) 90 | out_conv2 = self.conv2(out_conv1) 91 | out_conv3 = self.conv3(out_conv2) 92 | out_conv4 = self.conv4(out_conv3) 93 | out_conv5 = self.conv5(out_conv4) 94 | out_conv6 = self.conv6(out_conv5) 95 | out_conv7 = self.conv7(out_conv6) 96 | 97 | out_upconv7 = crop_like(self.upconv7(out_conv7), out_conv6) 98 | concat7 = torch.cat((out_upconv7, out_conv6), 1) 99 | out_iconv7 = self.iconv7(concat7) 100 | 101 | out_upconv6 = crop_like(self.upconv6(out_iconv7), out_conv5) 102 | concat6 = torch.cat((out_upconv6, out_conv5), 1) 103 | out_iconv6 = self.iconv6(concat6) 104 | disp6 = self.alpha * self.predict_disp6(out_iconv6) + self.beta 105 | 106 | out_upconv5 = crop_like(self.upconv5(out_iconv6), out_conv4) 107 | concat5 = torch.cat((out_upconv5, out_conv4), 1) 108 | out_iconv5 = self.iconv5(concat5) 109 | disp5 = self.alpha * self.predict_disp5(out_iconv5) + self.beta 110 | 111 | out_upconv4 = crop_like(self.upconv4(out_iconv5), out_conv3) 112 | concat4 = torch.cat((out_upconv4, out_conv3), 1) 113 | out_iconv4 = self.iconv4(concat4) 114 | disp4 = self.alpha * self.predict_disp4(out_iconv4) + self.beta 115 | 116 | out_upconv3 = crop_like(self.upconv3(out_iconv4), out_conv2) 117 | disp4_up = crop_like(nn.functional.upsample(disp4, scale_factor=2, mode='bilinear'), out_conv2) 118 | concat3 = torch.cat((out_upconv3, out_conv2, disp4_up), 1) 119 | out_iconv3 = self.iconv3(concat3) 120 | disp3 = self.alpha * self.predict_disp3(out_iconv3) + self.beta 121 | 122 | out_upconv2 = crop_like(self.upconv2(out_iconv3), out_conv1) 123 | disp3_up = crop_like(nn.functional.upsample(disp3, scale_factor=2, mode='bilinear'), out_conv1) 124 | concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1) 125 | out_iconv2 = self.iconv2(concat2) 126 | disp2 = self.alpha * self.predict_disp2(out_iconv2) + self.beta 127 | 128 | out_upconv1 = crop_like(self.upconv1(out_iconv2), x) 129 | disp2_up = crop_like(nn.functional.upsample(disp2, scale_factor=2, mode='bilinear'), x) 130 | concat1 = torch.cat((out_upconv1, disp2_up), 1) 131 | out_iconv1 = self.iconv1(concat1) 132 | disp1 = self.alpha * self.predict_disp1(out_iconv1) + self.beta 133 | 134 | if self.training: 135 | return disp1, disp2, disp3, disp4, disp5, disp6 136 | else: 137 | return disp1 138 | -------------------------------------------------------------------------------- /models/DispResNet6.py: -------------------------------------------------------------------------------- 1 | # Author: Anurag Ranjan 2 | # Copyright (c) 2019, Anurag Ranjan 3 | # All rights reserved. 
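# DispResNet6 keeps the encoder/decoder layout and the six-scale disparity outputs of
# DispNetS6, but builds its stages from residual BasicBlocks (with batch norm disabled)
# instead of plain stride-2 convolution pairs.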
4 | # based on github.com/ClementPinard/SfMLearner-Pytorch 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | def conv3x3(in_planes, out_planes, stride=1): 10 | """3x3 convolution with padding""" 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 12 | padding=1, bias=False) 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, inplanes, planes, stride=1, downsample=None): 18 | super(BasicBlock, self).__init__() 19 | self.conv1 = conv3x3(inplanes, planes, stride) 20 | #self.bn1 = nn.BatchNorm2d(planes) 21 | self.relu = nn.ReLU(inplace=True) 22 | self.conv2 = conv3x3(planes, planes) 23 | #self.bn2 = nn.BatchNorm2d(planes) 24 | self.downsample = downsample 25 | self.stride = stride 26 | 27 | def forward(self, x): 28 | residual = x 29 | 30 | out = self.conv1(x) 31 | #out = self.bn1(out) 32 | out = self.relu(out) 33 | 34 | out = self.conv2(out) 35 | #out = self.bn2(out) 36 | 37 | if self.downsample is not None: 38 | residual = self.downsample(x) 39 | 40 | out += residual 41 | out = self.relu(out) 42 | 43 | return out 44 | 45 | def make_layer(inplanes, block, planes, blocks, stride=1): 46 | downsample = None 47 | if stride != 1 or inplanes != planes * block.expansion: 48 | downsample = nn.Sequential( 49 | nn.Conv2d(inplanes, planes * block.expansion, 50 | kernel_size=1, stride=stride, bias=False), 51 | nn.BatchNorm2d(planes * block.expansion), 52 | ) 53 | 54 | layers = [] 55 | layers.append(block(inplanes, planes, stride, downsample)) 56 | inplanes = planes * block.expansion 57 | for i in range(1, blocks): 58 | layers.append(block(inplanes, planes)) 59 | 60 | return nn.Sequential(*layers) 61 | 62 | def downsample_conv(in_planes, out_planes, kernel_size=3): 63 | return nn.Sequential( 64 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, padding=(kernel_size-1)//2), 65 | nn.ReLU(inplace=True), 66 | nn.Conv2d(out_planes, out_planes, kernel_size=kernel_size, padding=(kernel_size-1)//2), 67 | nn.ReLU(inplace=True) 68 | ) 69 | 70 | 71 | def predict_disp(in_planes): 72 | return nn.Sequential( 73 | nn.Conv2d(in_planes, 1, kernel_size=3, padding=1), 74 | nn.Sigmoid() 75 | ) 76 | 77 | 78 | def conv(in_planes, out_planes): 79 | return nn.Sequential( 80 | nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=1), 81 | nn.ReLU(inplace=True) 82 | ) 83 | 84 | 85 | def upconv(in_planes, out_planes): 86 | return nn.Sequential( 87 | nn.ConvTranspose2d(in_planes, out_planes, kernel_size=3, stride=2, padding=1, output_padding=1), 88 | nn.ReLU(inplace=True) 89 | ) 90 | 91 | 92 | def crop_like(input, ref): 93 | assert(input.size(2) >= ref.size(2) and input.size(3) >= ref.size(3)) 94 | return input[:, :, :ref.size(2), :ref.size(3)] 95 | 96 | 97 | class DispResNet6(nn.Module): 98 | 99 | def __init__(self, alpha=10, beta=0.01): 100 | super(DispResNet6, self).__init__() 101 | 102 | self.alpha = alpha 103 | self.beta = beta 104 | 105 | conv_planes = [32, 64, 128, 256, 512, 512, 512] 106 | self.conv1 = downsample_conv(3, conv_planes[0], kernel_size=7) 107 | self.conv2 = make_layer(conv_planes[0], BasicBlock, conv_planes[1], blocks=2, stride=2) 108 | self.conv3 = make_layer(conv_planes[1], BasicBlock, conv_planes[2], blocks=2, stride=2) 109 | self.conv4 = make_layer(conv_planes[2], BasicBlock, conv_planes[3], blocks=2, stride=2) 110 | self.conv5 = make_layer(conv_planes[3], BasicBlock, conv_planes[4], blocks=2, stride=2) 111 | self.conv6 = make_layer(conv_planes[4], BasicBlock, conv_planes[5], blocks=2, stride=2) 112 | self.conv7 = 
make_layer(conv_planes[5], BasicBlock, conv_planes[6], blocks=2, stride=2) 113 | 114 | upconv_planes = [512, 512, 256, 128, 64, 32, 16] 115 | self.upconv7 = upconv(conv_planes[6], upconv_planes[0]) 116 | self.upconv6 = upconv(upconv_planes[0], upconv_planes[1]) 117 | self.upconv5 = upconv(upconv_planes[1], upconv_planes[2]) 118 | self.upconv4 = upconv(upconv_planes[2], upconv_planes[3]) 119 | self.upconv3 = upconv(upconv_planes[3], upconv_planes[4]) 120 | self.upconv2 = upconv(upconv_planes[4], upconv_planes[5]) 121 | self.upconv1 = upconv(upconv_planes[5], upconv_planes[6]) 122 | 123 | self.iconv7 = make_layer(upconv_planes[0] + conv_planes[5], BasicBlock, upconv_planes[0], blocks=1, stride=1) 124 | self.iconv6 = make_layer(upconv_planes[1] + conv_planes[4], BasicBlock, upconv_planes[1], blocks=1, stride=1) 125 | self.iconv5 = make_layer(upconv_planes[2] + conv_planes[3], BasicBlock, upconv_planes[2], blocks=1, stride=1) 126 | self.iconv4 = make_layer(upconv_planes[3] + conv_planes[2], BasicBlock, upconv_planes[3], blocks=1, stride=1) 127 | self.iconv3 = make_layer(1 + upconv_planes[4] + conv_planes[1], BasicBlock, upconv_planes[4], blocks=1, stride=1) 128 | self.iconv2 = make_layer(1 + upconv_planes[5] + conv_planes[0], BasicBlock, upconv_planes[5], blocks=1, stride=1) 129 | self.iconv1 = make_layer(1 + upconv_planes[6], BasicBlock, upconv_planes[6], blocks=1, stride=1) 130 | 131 | self.predict_disp6 = predict_disp(upconv_planes[1]) 132 | self.predict_disp5 = predict_disp(upconv_planes[2]) 133 | self.predict_disp4 = predict_disp(upconv_planes[3]) 134 | self.predict_disp3 = predict_disp(upconv_planes[4]) 135 | self.predict_disp2 = predict_disp(upconv_planes[5]) 136 | self.predict_disp1 = predict_disp(upconv_planes[6]) 137 | 138 | def init_weights(self): 139 | for m in self.modules(): 140 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 141 | nn.init.xavier_uniform(m.weight.data) 142 | if m.bias is not None: 143 | m.bias.data.zero_() 144 | 145 | def forward(self, x): 146 | out_conv1 = self.conv1(x) 147 | out_conv2 = self.conv2(out_conv1) 148 | out_conv3 = self.conv3(out_conv2) 149 | out_conv4 = self.conv4(out_conv3) 150 | out_conv5 = self.conv5(out_conv4) 151 | out_conv6 = self.conv6(out_conv5) 152 | out_conv7 = self.conv7(out_conv6) 153 | 154 | out_upconv7 = crop_like(self.upconv7(out_conv7), out_conv6) 155 | concat7 = torch.cat((out_upconv7, out_conv6), 1) 156 | out_iconv7 = self.iconv7(concat7) 157 | 158 | out_upconv6 = crop_like(self.upconv6(out_iconv7), out_conv5) 159 | concat6 = torch.cat((out_upconv6, out_conv5), 1) 160 | out_iconv6 = self.iconv6(concat6) 161 | disp6 = self.alpha * self.predict_disp6(out_iconv6) + self.beta 162 | 163 | out_upconv5 = crop_like(self.upconv5(out_iconv6), out_conv4) 164 | concat5 = torch.cat((out_upconv5, out_conv4), 1) 165 | out_iconv5 = self.iconv5(concat5) 166 | disp5 = self.alpha * self.predict_disp5(out_iconv5) + self.beta 167 | 168 | out_upconv4 = crop_like(self.upconv4(out_iconv5), out_conv3) 169 | concat4 = torch.cat((out_upconv4, out_conv3), 1) 170 | out_iconv4 = self.iconv4(concat4) 171 | disp4 = self.alpha * self.predict_disp4(out_iconv4) + self.beta 172 | 173 | out_upconv3 = crop_like(self.upconv3(out_iconv4), out_conv2) 174 | disp4_up = crop_like(nn.functional.upsample(disp4, scale_factor=2, mode='bilinear'), out_conv2) 175 | concat3 = torch.cat((out_upconv3, out_conv2, disp4_up), 1) 176 | out_iconv3 = self.iconv3(concat3) 177 | disp3 = self.alpha * self.predict_disp3(out_iconv3) + self.beta 178 | 179 | out_upconv2 = 
crop_like(self.upconv2(out_iconv3), out_conv1) 180 | disp3_up = crop_like(nn.functional.upsample(disp3, scale_factor=2, mode='bilinear'), out_conv1) 181 | concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1) 182 | out_iconv2 = self.iconv2(concat2) 183 | disp2 = self.alpha * self.predict_disp2(out_iconv2) + self.beta 184 | 185 | out_upconv1 = crop_like(self.upconv1(out_iconv2), x) 186 | disp2_up = crop_like(nn.functional.upsample(disp2, scale_factor=2, mode='bilinear'), x) 187 | concat1 = torch.cat((out_upconv1, disp2_up), 1) 188 | out_iconv1 = self.iconv1(concat1) 189 | disp1 = self.alpha * self.predict_disp1(out_iconv1) + self.beta 190 | 191 | if self.training: 192 | return disp1, disp2, disp3, disp4, disp5, disp6 193 | else: 194 | return disp1 -------------------------------------------------------------------------------- /models/DispResNetS6.py: -------------------------------------------------------------------------------- 1 | # Author: Anurag Ranjan 2 | # Copyright (c) 2019, Anurag Ranjan 3 | # All rights reserved. 4 | # based on github.com/ClementPinard/SfMLearner-Pytorch 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | def conv3x3(in_planes, out_planes, stride=1): 10 | """3x3 convolution with padding""" 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 12 | padding=1, bias=False) 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, inplanes, planes, stride=1, downsample=None): 18 | super(BasicBlock, self).__init__() 19 | self.conv1 = conv3x3(inplanes, planes, stride) 20 | #self.bn1 = nn.BatchNorm2d(planes) 21 | self.relu = nn.ReLU(inplace=True) 22 | self.conv2 = conv3x3(planes, planes) 23 | #self.bn2 = nn.BatchNorm2d(planes) 24 | self.downsample = downsample 25 | self.stride = stride 26 | 27 | def forward(self, x): 28 | residual = x 29 | 30 | out = self.conv1(x) 31 | #out = self.bn1(out) 32 | out = self.relu(out) 33 | 34 | out = self.conv2(out) 35 | #out = self.bn2(out) 36 | 37 | if self.downsample is not None: 38 | residual = self.downsample(x) 39 | 40 | out += residual 41 | out = self.relu(out) 42 | 43 | return out 44 | 45 | def make_layer(inplanes, block, planes, blocks, stride=1): 46 | downsample = None 47 | if stride != 1 or inplanes != planes * block.expansion: 48 | downsample = nn.Sequential( 49 | nn.Conv2d(inplanes, planes * block.expansion, 50 | kernel_size=1, stride=stride, bias=False), 51 | nn.BatchNorm2d(planes * block.expansion), 52 | ) 53 | 54 | layers = [] 55 | layers.append(block(inplanes, planes, stride, downsample)) 56 | inplanes = planes * block.expansion 57 | for i in range(1, blocks): 58 | layers.append(block(inplanes, planes)) 59 | 60 | return nn.Sequential(*layers) 61 | 62 | def downsample_conv(in_planes, out_planes, kernel_size=3): 63 | return nn.Sequential( 64 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, padding=(kernel_size-1)//2), 65 | nn.ReLU(inplace=True), 66 | nn.Conv2d(out_planes, out_planes, kernel_size=kernel_size, padding=(kernel_size-1)//2), 67 | nn.ReLU(inplace=True) 68 | ) 69 | 70 | 71 | def predict_disp(in_planes): 72 | return nn.Sequential( 73 | nn.Conv2d(in_planes, 1, kernel_size=3, padding=1), 74 | nn.Sigmoid() 75 | ) 76 | 77 | 78 | def conv(in_planes, out_planes): 79 | return nn.Sequential( 80 | nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=1), 81 | nn.ReLU(inplace=True) 82 | ) 83 | 84 | 85 | def upconv(in_planes, out_planes): 86 | return nn.Sequential( 87 | nn.ConvTranspose2d(in_planes, out_planes, kernel_size=3, stride=2, padding=1, 
output_padding=1), 88 | nn.ReLU(inplace=True) 89 | ) 90 | 91 | 92 | def crop_like(input, ref): 93 | assert(input.size(2) >= ref.size(2) and input.size(3) >= ref.size(3)) 94 | return input[:, :, :ref.size(2), :ref.size(3)] 95 | 96 | 97 | class DispResNetS6(nn.Module): 98 | 99 | def __init__(self, alpha=10, beta=0.01): 100 | super(DispResNetS6, self).__init__() 101 | 102 | self.alpha = alpha 103 | self.beta = beta 104 | 105 | conv_planes = [32, 64, 128, 256, 512, 512, 512] 106 | self.conv1 = downsample_conv(3, conv_planes[0], kernel_size=7) 107 | self.conv2 = make_layer(conv_planes[0], BasicBlock, conv_planes[1], blocks=2, stride=2) 108 | self.conv3 = make_layer(conv_planes[1], BasicBlock, conv_planes[2], blocks=2, stride=2) 109 | self.conv4 = make_layer(conv_planes[2], BasicBlock, conv_planes[3], blocks=3, stride=2) 110 | self.conv5 = make_layer(conv_planes[3], BasicBlock, conv_planes[4], blocks=3, stride=2) 111 | self.conv6 = make_layer(conv_planes[4], BasicBlock, conv_planes[5], blocks=3, stride=2) 112 | self.conv7 = make_layer(conv_planes[5], BasicBlock, conv_planes[6], blocks=3, stride=2) 113 | 114 | upconv_planes = [512, 512, 256, 128, 64, 32, 16] 115 | self.upconv7 = upconv(conv_planes[6], upconv_planes[0]) 116 | self.upconv6 = upconv(upconv_planes[0], upconv_planes[1]) 117 | self.upconv5 = upconv(upconv_planes[1], upconv_planes[2]) 118 | self.upconv4 = upconv(upconv_planes[2], upconv_planes[3]) 119 | self.upconv3 = upconv(upconv_planes[3], upconv_planes[4]) 120 | self.upconv2 = upconv(upconv_planes[4], upconv_planes[5]) 121 | self.upconv1 = upconv(upconv_planes[5], upconv_planes[6]) 122 | 123 | self.iconv7 = make_layer(upconv_planes[0] + conv_planes[5], BasicBlock, upconv_planes[0], blocks=2, stride=1) 124 | self.iconv6 = make_layer(upconv_planes[1] + conv_planes[4], BasicBlock, upconv_planes[1], blocks=2, stride=1) 125 | self.iconv5 = make_layer(upconv_planes[2] + conv_planes[3], BasicBlock, upconv_planes[2], blocks=2, stride=1) 126 | self.iconv4 = make_layer(upconv_planes[3] + conv_planes[2], BasicBlock, upconv_planes[3], blocks=2, stride=1) 127 | self.iconv3 = make_layer(1 + upconv_planes[4] + conv_planes[1], BasicBlock, upconv_planes[4], blocks=1, stride=1) 128 | self.iconv2 = make_layer(1 + upconv_planes[5] + conv_planes[0], BasicBlock, upconv_planes[5], blocks=1, stride=1) 129 | self.iconv1 = make_layer(1 + upconv_planes[6], BasicBlock, upconv_planes[6], blocks=1, stride=1) 130 | 131 | self.predict_disp6 = predict_disp(upconv_planes[1]) 132 | self.predict_disp5 = predict_disp(upconv_planes[2]) 133 | self.predict_disp4 = predict_disp(upconv_planes[3]) 134 | self.predict_disp3 = predict_disp(upconv_planes[4]) 135 | self.predict_disp2 = predict_disp(upconv_planes[5]) 136 | self.predict_disp1 = predict_disp(upconv_planes[6]) 137 | 138 | def init_weights(self): 139 | for m in self.modules(): 140 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 141 | nn.init.xavier_uniform_(m.weight.data) 142 | if m.bias is not None: 143 | m.bias.data.zero_() 144 | 145 | def forward(self, x): 146 | out_conv1 = self.conv1(x) 147 | out_conv2 = self.conv2(out_conv1) 148 | out_conv3 = self.conv3(out_conv2) 149 | out_conv4 = self.conv4(out_conv3) 150 | out_conv5 = self.conv5(out_conv4) 151 | out_conv6 = self.conv6(out_conv5) 152 | out_conv7 = self.conv7(out_conv6) 153 | 154 | out_upconv7 = crop_like(self.upconv7(out_conv7), out_conv6) 155 | concat7 = torch.cat((out_upconv7, out_conv6), 1) 156 | out_iconv7 = self.iconv7(concat7) 157 | 158 | out_upconv6 = 
crop_like(self.upconv6(out_iconv7), out_conv5) 159 | concat6 = torch.cat((out_upconv6, out_conv5), 1) 160 | out_iconv6 = self.iconv6(concat6) 161 | disp6 = self.alpha * self.predict_disp6(out_iconv6) + self.beta 162 | 163 | out_upconv5 = crop_like(self.upconv5(out_iconv6), out_conv4) 164 | concat5 = torch.cat((out_upconv5, out_conv4), 1) 165 | out_iconv5 = self.iconv5(concat5) 166 | disp5 = self.alpha * self.predict_disp5(out_iconv5) + self.beta 167 | 168 | out_upconv4 = crop_like(self.upconv4(out_iconv5), out_conv3) 169 | concat4 = torch.cat((out_upconv4, out_conv3), 1) 170 | out_iconv4 = self.iconv4(concat4) 171 | disp4 = self.alpha * self.predict_disp4(out_iconv4) + self.beta 172 | 173 | out_upconv3 = crop_like(self.upconv3(out_iconv4), out_conv2) 174 | disp4_up = crop_like(nn.functional.upsample(disp4, scale_factor=2, mode='bilinear'), out_conv2) 175 | concat3 = torch.cat((out_upconv3, out_conv2, disp4_up), 1) 176 | out_iconv3 = self.iconv3(concat3) 177 | disp3 = self.alpha * self.predict_disp3(out_iconv3) + self.beta 178 | 179 | out_upconv2 = crop_like(self.upconv2(out_iconv3), out_conv1) 180 | disp3_up = crop_like(nn.functional.upsample(disp3, scale_factor=2, mode='bilinear'), out_conv1) 181 | concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1) 182 | out_iconv2 = self.iconv2(concat2) 183 | disp2 = self.alpha * self.predict_disp2(out_iconv2) + self.beta 184 | 185 | out_upconv1 = crop_like(self.upconv1(out_iconv2), x) 186 | disp2_up = crop_like(nn.functional.upsample(disp2, scale_factor=2, mode='bilinear'), x) 187 | concat1 = torch.cat((out_upconv1, disp2_up), 1) 188 | out_iconv1 = self.iconv1(concat1) 189 | disp1 = self.alpha * self.predict_disp1(out_iconv1) + self.beta 190 | 191 | if self.training: 192 | return disp1, disp2, disp3, disp4, disp5, disp6 193 | else: 194 | return disp1 195 | -------------------------------------------------------------------------------- /models/FlowNetC6.py: -------------------------------------------------------------------------------- 1 | # Author: Anurag Ranjan 2 | # Copyright (c) 2019, Anurag Ranjan 3 | # All rights reserved. 
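# FlowNetC6 is a FlowNetC-style two-stream network: both inputs share conv1-conv3, the two
# feature maps are matched with spatial_correlation_sample (patch_size=21 -> 441 correlation
# channels), and concatenating the 32-channel conv_redir branch yields the 473 input channels
# of conv3_1. Six flow scales are predicted and, at full resolution, scaled by div_flow (20).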
4 | # based on github.com/NVIDIA/FlowNet2-Pytorch 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn import init 9 | 10 | import math 11 | import numpy as np 12 | 13 | # from .correlation_package.modules.correlation import Correlation 14 | from spatial_correlation_sampler import spatial_correlation_sample 15 | from .submodules import conv, deconv, predict_flow 16 | 'Parameter count , 39,175,298 ' 17 | 18 | def correlate(input1, input2): 19 | out_corr = spatial_correlation_sample(input1, 20 | input2, 21 | kernel_size=1, 22 | patch_size=21, 23 | stride=1, 24 | padding=0, 25 | dilation_patch=2) 26 | # collate dimensions 1 and 2 in order to be treated as a 27 | # regular 4D tensor 28 | b, ph, pw, h, w = out_corr.size() 29 | out_corr = out_corr.view(b, ph * pw, h, w)/input1.size(1) 30 | return out_corr 31 | 32 | class FlowNetC6(nn.Module): 33 | def __init__(self, nlevels=5, batchNorm=False, div_flow = 20, full_res=True, pretrained=True): 34 | super(FlowNetC6,self).__init__() 35 | 36 | #assert(nlevels==5) 37 | self.batchNorm = batchNorm 38 | self.div_flow = div_flow 39 | self.full_res = full_res 40 | 41 | self.conv1 = conv(self.batchNorm, 3, 64, kernel_size=7, stride=2) 42 | self.conv2 = conv(self.batchNorm, 64, 128, kernel_size=5, stride=2) 43 | self.conv3 = conv(self.batchNorm, 128, 256, kernel_size=5, stride=2) 44 | self.conv_redir = conv(self.batchNorm, 256, 32, kernel_size=1, stride=1) 45 | 46 | # if args.fp16: 47 | # self.corr = nn.Sequential( 48 | # tofp32(), 49 | # Correlation(pad_size=20, kernel_size=1, max_displacement=20, stride1=1, stride2=2, corr_multiply=1), 50 | # tofp16()) 51 | # else: 52 | self.corr = correlate # Correlation(pad_size=20, kernel_size=1, max_displacement=20, stride1=1, stride2=2, corr_multiply=1) 53 | 54 | self.corr_activation = nn.LeakyReLU(0.1,inplace=True) 55 | self.conv3_1 = conv(self.batchNorm, 473, 256) 56 | self.conv4 = conv(self.batchNorm, 256, 512, stride=2) 57 | self.conv4_1 = conv(self.batchNorm, 512, 512) 58 | self.conv5 = conv(self.batchNorm, 512, 512, stride=2) 59 | self.conv5_1 = conv(self.batchNorm, 512, 512) 60 | self.conv6 = conv(self.batchNorm, 512, 1024, stride=2) 61 | self.conv6_1 = conv(self.batchNorm,1024, 1024) 62 | 63 | self.deconv5 = deconv(1024,512) 64 | self.deconv4 = deconv(1026,256) 65 | self.deconv3 = deconv(770,128) 66 | self.deconv2 = deconv(386,64) 67 | self.deconv1 = deconv(194,32) 68 | 69 | self.predict_flow6 = predict_flow(1024) 70 | self.predict_flow5 = predict_flow(1026) 71 | self.predict_flow4 = predict_flow(770) 72 | self.predict_flow3 = predict_flow(386) 73 | self.predict_flow2 = predict_flow(194) 74 | self.predict_flow1 = predict_flow(98) 75 | 76 | self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True) 77 | self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True) 78 | self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True) 79 | self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True) 80 | self.upsampled_flow2_to_1 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True) 81 | 82 | self.upsample1 = nn.Upsample(scale_factor=2, mode='bilinear') 83 | 84 | def init_weights(self): 85 | for m in self.modules(): 86 | if isinstance(m, nn.Conv2d): 87 | if m.bias is not None: 88 | init.uniform(m.bias) 89 | init.xavier_uniform(m.weight) 90 | 91 | if isinstance(m, nn.ConvTranspose2d): 92 | if m.bias is not None: 93 | init.uniform(m.bias) 94 | init.xavier_uniform(m.weight) 95 | # init_deconv_bilinear(m.weight) 96 | 97 | 98 | 99 | def forward(self, x1,x2): 100 | 101 | 
out_conv1a = self.conv1(x1) 102 | out_conv2a = self.conv2(out_conv1a) 103 | out_conv3a = self.conv3(out_conv2a) 104 | 105 | # FlownetC bottom input stream 106 | out_conv1b = self.conv1(x2) 107 | out_conv2b = self.conv2(out_conv1b) 108 | out_conv3b = self.conv3(out_conv2b) 109 | 110 | # Merge streams 111 | out_corr = self.corr(out_conv3a, out_conv3b) 112 | out_corr = self.corr_activation(out_corr) 113 | 114 | # Redirect top input stream and concatenate 115 | out_conv_redir = self.conv_redir(out_conv3a) 116 | 117 | in_conv3_1 = torch.cat((out_conv_redir, out_corr), 1) 118 | 119 | # Merged conv layers 120 | out_conv3_1 = self.conv3_1(in_conv3_1) 121 | out_conv4 = self.conv4_1(self.conv4(out_conv3_1)) 122 | out_conv5 = self.conv5_1(self.conv5(out_conv4)) 123 | out_conv6 = self.conv6_1(self.conv6(out_conv5)) 124 | 125 | flow6 = self.predict_flow6(out_conv6) 126 | out_deconv5 = self.deconv5(out_conv6) 127 | flow6_up = self.upsampled_flow6_to_5(flow6) 128 | 129 | concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1) 130 | 131 | flow5 = self.predict_flow5(concat5) 132 | out_deconv4 = self.deconv4(concat5) 133 | flow5_up = self.upsampled_flow5_to_4(flow5) 134 | concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1) 135 | 136 | flow4 = self.predict_flow4(concat4) 137 | out_deconv3 = self.deconv3(concat4) 138 | flow4_up = self.upsampled_flow4_to_3(flow4) 139 | concat3 = torch.cat((out_conv3_1,out_deconv3,flow4_up),1) 140 | 141 | flow3 = self.predict_flow3(concat3) 142 | out_deconv2 = self.deconv2(concat3) 143 | flow3_up = self.upsampled_flow3_to_2(flow3) 144 | concat2 = torch.cat((out_conv2a,out_deconv2,flow3_up),1) 145 | 146 | flow2 = self.predict_flow2(concat2) 147 | out_deconv1 = self.deconv1(concat2) 148 | flow2_up = self.upsampled_flow2_to_1(flow2) 149 | concat1 = torch.cat((out_conv1a,out_deconv1,flow2_up), 1) 150 | 151 | flow1 = self.predict_flow1(concat1) 152 | #out_convs = [out_conv2a, out_conv2b, out_conv3a, out_conv3b] 153 | if self.full_res: 154 | flow1 = self.div_flow*self.upsample1(flow1) 155 | flow2 = self.div_flow*self.upsample1(flow2) 156 | flow3 = self.div_flow*self.upsample1(flow3) 157 | flow4 = self.div_flow*self.upsample1(flow4) 158 | flow5 = self.div_flow*self.upsample1(flow5) 159 | flow6 = self.div_flow*self.upsample1(flow6) 160 | 161 | if self.training: 162 | return flow1, flow2,flow3,flow4,flow5,flow6 #, out_convs 163 | else: 164 | return flow1 165 | -------------------------------------------------------------------------------- /models/MaskNet6.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def conv(in_planes, out_planes, kernel_size=3): 6 | return nn.Sequential( 7 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, padding=(kernel_size-1)//2, stride=2), 8 | nn.ReLU(inplace=True) 9 | ) 10 | 11 | 12 | def upconv(in_planes, out_planes): 13 | return nn.Sequential( 14 | nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1), 15 | nn.ReLU(inplace=True) 16 | ) 17 | 18 | 19 | class MaskNet6(nn.Module): 20 | 21 | def __init__(self, nb_ref_imgs=4, output_exp=True): 22 | super(MaskNet6, self).__init__() 23 | self.nb_ref_imgs = nb_ref_imgs 24 | self.output_exp = output_exp 25 | 26 | conv_planes = [16, 32, 64, 128, 256, 256, 256, 256] 27 | self.conv1 = conv(3*(1+self.nb_ref_imgs), conv_planes[0], kernel_size=7) 28 | self.conv2 = conv(conv_planes[0], conv_planes[1], kernel_size=5) 29 | self.conv3 = conv(conv_planes[1], conv_planes[2]) 30 | self.conv4 = 
conv(conv_planes[2], conv_planes[3]) 31 | self.conv5 = conv(conv_planes[3], conv_planes[4]) 32 | self.conv6 = conv(conv_planes[4], conv_planes[5]) 33 | #self.conv7 = conv(conv_planes[5], conv_planes[6]) 34 | #self.conv8 = conv(conv_planes[6], conv_planes[7]) 35 | 36 | #self.pose_pred = nn.Conv2d(conv_planes[7], 6*self.nb_ref_imgs, kernel_size=1, padding=0) 37 | 38 | if self.output_exp: 39 | upconv_planes = [256, 256, 128, 64, 32, 16] 40 | self.deconv6 = upconv(conv_planes[5], upconv_planes[0]) 41 | self.deconv5 = upconv(upconv_planes[0]+conv_planes[4], upconv_planes[1]) 42 | self.deconv4 = upconv(upconv_planes[1]+conv_planes[3], upconv_planes[2]) 43 | self.deconv3 = upconv(upconv_planes[2]+conv_planes[2], upconv_planes[3]) 44 | self.deconv2 = upconv(upconv_planes[3]+conv_planes[1], upconv_planes[4]) 45 | self.deconv1 = upconv(upconv_planes[4]+conv_planes[0], upconv_planes[5]) 46 | 47 | self.pred_mask6 = nn.Conv2d(upconv_planes[0], self.nb_ref_imgs, kernel_size=3, padding=1) 48 | self.pred_mask5 = nn.Conv2d(upconv_planes[1], self.nb_ref_imgs, kernel_size=3, padding=1) 49 | self.pred_mask4 = nn.Conv2d(upconv_planes[2], self.nb_ref_imgs, kernel_size=3, padding=1) 50 | self.pred_mask3 = nn.Conv2d(upconv_planes[3], self.nb_ref_imgs, kernel_size=3, padding=1) 51 | self.pred_mask2 = nn.Conv2d(upconv_planes[4], self.nb_ref_imgs, kernel_size=3, padding=1) 52 | self.pred_mask1 = nn.Conv2d(upconv_planes[5], self.nb_ref_imgs, kernel_size=3, padding=1) 53 | 54 | def init_weights(self): 55 | for m in self.modules(): 56 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 57 | nn.init.xavier_uniform(m.weight.data) 58 | if m.bias is not None: 59 | m.bias.data.zero_() 60 | 61 | def init_mask_weights(self): 62 | for m in self.modules(): 63 | if isinstance(m, nn.ConvTranspose2d): 64 | nn.init.xavier_uniform(m.weight.data) 65 | if m.bias is not None: 66 | m.bias.data.zero_() 67 | 68 | for module in [self.pred_mask1, self.pred_mask2, self.pred_mask3, self.pred_mask4, self.pred_mask5, self.pred_mask6]: 69 | for m in module.modules(): 70 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 71 | nn.init.xavier_uniform(m.weight.data) 72 | if m.bias is not None: 73 | m.bias.data.zero_() 74 | 75 | # for mod in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8, self.pose_pred]: 76 | # for fparams in mod.parameters(): 77 | # fparams.requires_grad = False 78 | 79 | 80 | def forward(self, target_image, ref_imgs): 81 | assert(len(ref_imgs) == self.nb_ref_imgs) 82 | input = [target_image] 83 | input.extend(ref_imgs) 84 | input = torch.cat(input, 1) 85 | out_conv1 = self.conv1(input) 86 | out_conv2 = self.conv2(out_conv1) 87 | out_conv3 = self.conv3(out_conv2) 88 | out_conv4 = self.conv4(out_conv3) 89 | out_conv5 = self.conv5(out_conv4) 90 | out_conv6 = self.conv6(out_conv5) 91 | #out_conv7 = self.conv7(out_conv6) 92 | #out_conv8 = self.conv8(out_conv7) 93 | 94 | #pose = self.pose_pred(out_conv8) 95 | #pose = pose.mean(3).mean(2) 96 | #pose = 0.01 * pose.view(pose.size(0), self.nb_ref_imgs, 6) 97 | 98 | if self.output_exp: 99 | out_upconv6 = self.deconv6(out_conv6 )#[:, :, 0:out_conv5.size(2), 0:out_conv5.size(3)] 100 | out_upconv5 = self.deconv5(torch.cat((out_upconv6, out_conv5), 1))#[:, :, 0:out_conv4.size(2), 0:out_conv4.size(3)] 101 | out_upconv4 = self.deconv4(torch.cat((out_upconv5, out_conv4), 1))#[:, :, 0:out_conv3.size(2), 0:out_conv3.size(3)] 102 | out_upconv3 = self.deconv3(torch.cat((out_upconv4, out_conv3), 1))#[:, :, 
0:out_conv2.size(2), 0:out_conv2.size(3)] 103 | out_upconv2 = self.deconv2(torch.cat((out_upconv3, out_conv2), 1))#[:, :, 0:out_conv1.size(2), 0:out_conv1.size(3)] 104 | out_upconv1 = self.deconv1(torch.cat((out_upconv2, out_conv1), 1))#[:, :, 0:input.size(2), 0:input.size(3)] 105 | 106 | exp_mask6 = nn.functional.sigmoid(self.pred_mask6(out_upconv6)) 107 | exp_mask5 = nn.functional.sigmoid(self.pred_mask5(out_upconv5)) 108 | exp_mask4 = nn.functional.sigmoid(self.pred_mask4(out_upconv4)) 109 | exp_mask3 = nn.functional.sigmoid(self.pred_mask3(out_upconv3)) 110 | exp_mask2 = nn.functional.sigmoid(self.pred_mask2(out_upconv2)) 111 | exp_mask1 = nn.functional.sigmoid(self.pred_mask1(out_upconv1)) 112 | else: 113 | exp_mask6 = None 114 | exp_mask5 = None 115 | exp_mask4 = None 116 | exp_mask3 = None 117 | exp_mask2 = None 118 | exp_mask1 = None 119 | 120 | if self.training: 121 | return exp_mask1, exp_mask2, exp_mask3, exp_mask4, exp_mask5, exp_mask6 122 | else: 123 | return exp_mask1 124 | -------------------------------------------------------------------------------- /models/MaskResNet6.py: -------------------------------------------------------------------------------- 1 | # Author: Anurag Ranjan 2 | # Copyright (c) 2019, Anurag Ranjan 3 | # All rights reserved. 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | """3x3 convolution with padding""" 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 11 | padding=1, bias=False) 12 | 13 | def conv(in_planes, out_planes, kernel_size=3, stride=2): 14 | return nn.Sequential( 15 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, padding=(kernel_size-1)//2, stride=stride), 16 | nn.ReLU(inplace=True) 17 | ) 18 | 19 | 20 | def upconv(in_planes, out_planes): 21 | return nn.Sequential( 22 | nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1), 23 | nn.ReLU(inplace=True) 24 | ) 25 | 26 | class BasicBlock(nn.Module): 27 | expansion = 1 28 | 29 | def __init__(self, inplanes, planes, stride=1, downsample=None): 30 | super(BasicBlock, self).__init__() 31 | self.conv1 = conv3x3(inplanes, planes, stride) 32 | self.relu = nn.ReLU(inplace=True) 33 | self.conv2 = conv3x3(planes, planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | residual = x 39 | 40 | out = self.conv1(x) 41 | out = self.relu(out) 42 | out = self.conv2(out) 43 | 44 | if self.downsample is not None: 45 | residual = self.downsample(x) 46 | 47 | out += residual 48 | out = self.relu(out) 49 | 50 | return out 51 | 52 | def make_layer(inplanes, block, planes, blocks, stride=1): 53 | downsample = None 54 | if stride != 1 or inplanes != planes * block.expansion: 55 | downsample = nn.Sequential( 56 | nn.Conv2d(inplanes, planes * block.expansion, 57 | kernel_size=1, stride=stride, bias=False), 58 | nn.BatchNorm2d(planes * block.expansion), 59 | ) 60 | 61 | layers = [] 62 | layers.append(block(inplanes, planes, stride, downsample)) 63 | inplanes = planes * block.expansion 64 | for i in range(1, blocks): 65 | layers.append(block(inplanes, planes)) 66 | 67 | return nn.Sequential(*layers) 68 | 69 | class MaskResNet6(nn.Module): 70 | 71 | def __init__(self, nb_ref_imgs=4, output_exp=True): 72 | super(MaskResNet6, self).__init__() 73 | self.nb_ref_imgs = nb_ref_imgs 74 | self.output_exp = output_exp 75 | 76 | conv_planes = [16, 32, 64, 128, 256, 256, 256, 256] 77 | self.conv1 = conv(3*(1+self.nb_ref_imgs), conv_planes[0], kernel_size=7, stride=2) 78 
| self.conv2 = make_layer(conv_planes[0], BasicBlock, conv_planes[1], blocks=2, stride=2) 79 | self.conv3 = make_layer(conv_planes[1], BasicBlock, conv_planes[2], blocks=2, stride=2) 80 | self.conv4 = make_layer(conv_planes[2], BasicBlock, conv_planes[3], blocks=2, stride=2) 81 | self.conv5 = make_layer(conv_planes[3], BasicBlock, conv_planes[4], blocks=2, stride=2) 82 | self.conv6 = make_layer(conv_planes[4], BasicBlock, conv_planes[5], blocks=2, stride=2) 83 | 84 | if self.output_exp: 85 | upconv_planes = [256, 256, 128, 64, 32, 16] 86 | self.deconv6 = upconv(conv_planes[5], upconv_planes[0]) 87 | self.deconv5 = upconv(upconv_planes[0]+conv_planes[4], upconv_planes[1]) 88 | self.deconv4 = upconv(upconv_planes[1]+conv_planes[3], upconv_planes[2]) 89 | self.deconv3 = upconv(upconv_planes[2]+conv_planes[2], upconv_planes[3]) 90 | self.deconv2 = upconv(upconv_planes[3]+conv_planes[1], upconv_planes[4]) 91 | self.deconv1 = upconv(upconv_planes[4]+conv_planes[0], upconv_planes[5]) 92 | 93 | self.pred_mask6 = nn.Conv2d(upconv_planes[0], self.nb_ref_imgs, kernel_size=3, padding=1) 94 | self.pred_mask5 = nn.Conv2d(upconv_planes[1], self.nb_ref_imgs, kernel_size=3, padding=1) 95 | self.pred_mask4 = nn.Conv2d(upconv_planes[2], self.nb_ref_imgs, kernel_size=3, padding=1) 96 | self.pred_mask3 = nn.Conv2d(upconv_planes[3], self.nb_ref_imgs, kernel_size=3, padding=1) 97 | self.pred_mask2 = nn.Conv2d(upconv_planes[4], self.nb_ref_imgs, kernel_size=3, padding=1) 98 | self.pred_mask1 = nn.Conv2d(upconv_planes[5], self.nb_ref_imgs, kernel_size=3, padding=1) 99 | 100 | def init_weights(self): 101 | for m in self.modules(): 102 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 103 | nn.init.xavier_uniform(m.weight.data) 104 | if m.bias is not None: 105 | m.bias.data.zero_() 106 | 107 | def init_mask_weights(self): 108 | for m in self.modules(): 109 | if isinstance(m, nn.ConvTranspose2d): 110 | nn.init.xavier_uniform(m.weight.data) 111 | if m.bias is not None: 112 | m.bias.data.zero_() 113 | 114 | for module in [self.pred_mask1, self.pred_mask2, self.pred_mask3, self.pred_mask4, self.pred_mask5, self.pred_mask6]: 115 | for m in module.modules(): 116 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 117 | nn.init.xavier_uniform(m.weight.data) 118 | if m.bias is not None: 119 | m.bias.data.zero_() 120 | 121 | 122 | 123 | def forward(self, target_image, ref_imgs): 124 | assert(len(ref_imgs) == self.nb_ref_imgs) 125 | input = [target_image] 126 | input.extend(ref_imgs) 127 | input = torch.cat(input, 1) 128 | out_conv1 = self.conv1(input) 129 | out_conv2 = self.conv2(out_conv1) 130 | out_conv3 = self.conv3(out_conv2) 131 | out_conv4 = self.conv4(out_conv3) 132 | out_conv5 = self.conv5(out_conv4) 133 | out_conv6 = self.conv6(out_conv5) 134 | 135 | if self.output_exp: 136 | out_upconv6 = self.deconv6(out_conv6 )#[:, :, 0:out_conv5.size(2), 0:out_conv5.size(3)] 137 | out_upconv5 = self.deconv5(torch.cat((out_upconv6, out_conv5), 1))#[:, :, 0:out_conv4.size(2), 0:out_conv4.size(3)] 138 | out_upconv4 = self.deconv4(torch.cat((out_upconv5, out_conv4), 1))#[:, :, 0:out_conv3.size(2), 0:out_conv3.size(3)] 139 | out_upconv3 = self.deconv3(torch.cat((out_upconv4, out_conv3), 1))#[:, :, 0:out_conv2.size(2), 0:out_conv2.size(3)] 140 | out_upconv2 = self.deconv2(torch.cat((out_upconv3, out_conv2), 1))#[:, :, 0:out_conv1.size(2), 0:out_conv1.size(3)] 141 | out_upconv1 = self.deconv1(torch.cat((out_upconv2, out_conv1), 1))#[:, :, 0:input.size(2), 0:input.size(3)] 142 | 143 | exp_mask6 = 
nn.functional.sigmoid(self.pred_mask6(out_upconv6)) 144 | exp_mask5 = nn.functional.sigmoid(self.pred_mask5(out_upconv5)) 145 | exp_mask4 = nn.functional.sigmoid(self.pred_mask4(out_upconv4)) 146 | exp_mask3 = nn.functional.sigmoid(self.pred_mask3(out_upconv3)) 147 | exp_mask2 = nn.functional.sigmoid(self.pred_mask2(out_upconv2)) 148 | exp_mask1 = nn.functional.sigmoid(self.pred_mask1(out_upconv1)) 149 | else: 150 | exp_mask6 = None 151 | exp_mask5 = None 152 | exp_mask4 = None 153 | exp_mask3 = None 154 | exp_mask2 = None 155 | exp_mask1 = None 156 | 157 | if self.training: 158 | return exp_mask1, exp_mask2, exp_mask3, exp_mask4, exp_mask5, exp_mask6 159 | else: 160 | return exp_mask1 161 | -------------------------------------------------------------------------------- /models/PoseExpNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def conv(in_planes, out_planes, kernel_size=3): 6 | return nn.Sequential( 7 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, padding=(kernel_size-1)//2, stride=2), 8 | nn.ReLU(inplace=True) 9 | ) 10 | 11 | 12 | def upconv(in_planes, out_planes): 13 | return nn.Sequential( 14 | nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1), 15 | nn.ReLU(inplace=True) 16 | ) 17 | 18 | 19 | class PoseExpNet(nn.Module): 20 | 21 | def __init__(self, nb_ref_imgs=2, output_exp=False): 22 | super(PoseExpNet, self).__init__() 23 | self.nb_ref_imgs = nb_ref_imgs 24 | self.output_exp = output_exp 25 | 26 | conv_planes = [16, 32, 64, 128, 256, 256, 256] 27 | self.conv1 = conv(3*(1+self.nb_ref_imgs), conv_planes[0], kernel_size=7) 28 | self.conv2 = conv(conv_planes[0], conv_planes[1], kernel_size=5) 29 | self.conv3 = conv(conv_planes[1], conv_planes[2]) 30 | self.conv4 = conv(conv_planes[2], conv_planes[3]) 31 | self.conv5 = conv(conv_planes[3], conv_planes[4]) 32 | self.conv6 = conv(conv_planes[4], conv_planes[5]) 33 | self.conv7 = conv(conv_planes[5], conv_planes[6]) 34 | 35 | self.pose_pred = nn.Conv2d(conv_planes[6], 6*self.nb_ref_imgs, kernel_size=1, padding=0) 36 | 37 | if self.output_exp: 38 | upconv_planes = [256, 128, 64, 32, 16] 39 | self.upconv5 = upconv(conv_planes[4], upconv_planes[0]) 40 | self.upconv4 = upconv(upconv_planes[0], upconv_planes[1]) 41 | self.upconv3 = upconv(upconv_planes[1], upconv_planes[2]) 42 | self.upconv2 = upconv(upconv_planes[2], upconv_planes[3]) 43 | self.upconv1 = upconv(upconv_planes[3], upconv_planes[4]) 44 | 45 | self.predict_mask4 = nn.Conv2d(upconv_planes[1], self.nb_ref_imgs, kernel_size=3, padding=1) 46 | self.predict_mask3 = nn.Conv2d(upconv_planes[2], self.nb_ref_imgs, kernel_size=3, padding=1) 47 | self.predict_mask2 = nn.Conv2d(upconv_planes[3], self.nb_ref_imgs, kernel_size=3, padding=1) 48 | self.predict_mask1 = nn.Conv2d(upconv_planes[4], self.nb_ref_imgs, kernel_size=3, padding=1) 49 | 50 | def init_weights(self): 51 | for m in self.modules(): 52 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 53 | nn.init.xavier_uniform(m.weight.data) 54 | if m.bias is not None: 55 | m.bias.data.zero_() 56 | 57 | def forward(self, target_image, ref_imgs): 58 | assert(len(ref_imgs) == self.nb_ref_imgs) 59 | input = [target_image] 60 | input.extend(ref_imgs) 61 | input = torch.cat(input, 1) 62 | out_conv1 = self.conv1(input) 63 | out_conv2 = self.conv2(out_conv1) 64 | out_conv3 = self.conv3(out_conv2) 65 | out_conv4 = self.conv4(out_conv3) 66 | out_conv5 = self.conv5(out_conv4) 67 | out_conv6 = 
self.conv6(out_conv5) 68 | out_conv7 = self.conv7(out_conv6) 69 | 70 | pose = self.pose_pred(out_conv7) 71 | pose = pose.mean(3).mean(2) 72 | pose = 0.01 * pose.view(pose.size(0), self.nb_ref_imgs, 6) 73 | 74 | if self.output_exp: 75 | out_upconv5 = self.upconv5(out_conv5 )[:, :, 0:out_conv4.size(2), 0:out_conv4.size(3)] 76 | out_upconv4 = self.upconv4(out_upconv5)[:, :, 0:out_conv3.size(2), 0:out_conv3.size(3)] 77 | out_upconv3 = self.upconv3(out_upconv4)[:, :, 0:out_conv2.size(2), 0:out_conv2.size(3)] 78 | out_upconv2 = self.upconv2(out_upconv3)[:, :, 0:out_conv1.size(2), 0:out_conv1.size(3)] 79 | out_upconv1 = self.upconv1(out_upconv2)[:, :, 0:input.size(2), 0:input.size(3)] 80 | 81 | exp_mask4 = nn.functional.sigmoid(self.predict_mask4(out_upconv4)) 82 | exp_mask3 = nn.functional.sigmoid(self.predict_mask3(out_upconv3)) 83 | exp_mask2 = nn.functional.sigmoid(self.predict_mask2(out_upconv2)) 84 | exp_mask1 = nn.functional.sigmoid(self.predict_mask1(out_upconv1)) 85 | else: 86 | exp_mask4 = None 87 | exp_mask3 = None 88 | exp_mask2 = None 89 | exp_mask1 = None 90 | 91 | if self.training: 92 | return [exp_mask1, exp_mask2, exp_mask3, exp_mask4], pose 93 | else: 94 | return exp_mask1, pose 95 | -------------------------------------------------------------------------------- /models/PoseNet6.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def conv(in_planes, out_planes, kernel_size=3): 6 | return nn.Sequential( 7 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, padding=(kernel_size-1)//2, stride=2), 8 | nn.ReLU(inplace=True) 9 | ) 10 | 11 | 12 | def upconv(in_planes, out_planes): 13 | return nn.Sequential( 14 | nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1), 15 | nn.ReLU(inplace=True) 16 | ) 17 | 18 | 19 | class PoseNet6(nn.Module): 20 | 21 | def __init__(self, nb_ref_imgs=2): 22 | super(PoseNet6, self).__init__() 23 | self.nb_ref_imgs = nb_ref_imgs 24 | 25 | conv_planes = [16, 32, 64, 128, 256, 256, 256] 26 | self.conv0 = conv(3*(1+self.nb_ref_imgs), 3*(1+self.nb_ref_imgs), kernel_size=3) 27 | self.conv1 = conv(3*(1+self.nb_ref_imgs), conv_planes[0], kernel_size=7) 28 | self.conv2 = conv(conv_planes[0], conv_planes[1], kernel_size=5) 29 | self.conv3 = conv(conv_planes[1], conv_planes[2]) 30 | self.conv4 = conv(conv_planes[2], conv_planes[3]) 31 | self.conv5 = conv(conv_planes[3], conv_planes[4]) 32 | self.conv6 = conv(conv_planes[4], conv_planes[5]) 33 | self.conv7 = conv(conv_planes[5], conv_planes[6]) 34 | 35 | self.pose_pred = nn.Conv2d(conv_planes[6], 6*self.nb_ref_imgs, kernel_size=1, padding=0) 36 | 37 | def init_weights(self): 38 | for m in self.modules(): 39 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 40 | nn.init.xavier_uniform(m.weight.data) 41 | if m.bias is not None: 42 | m.bias.data.zero_() 43 | 44 | def forward(self, target_image, ref_imgs): 45 | assert(len(ref_imgs) == self.nb_ref_imgs) 46 | input = [target_image] 47 | input.extend(ref_imgs) 48 | input = torch.cat(input, 1) 49 | out_conv0 = self.conv0(input) 50 | out_conv1 = self.conv1(out_conv0) 51 | out_conv2 = self.conv2(out_conv1) 52 | out_conv3 = self.conv3(out_conv2) 53 | out_conv4 = self.conv4(out_conv3) 54 | out_conv5 = self.conv5(out_conv4) 55 | out_conv6 = self.conv6(out_conv5) 56 | out_conv7 = self.conv7(out_conv6) 57 | 58 | pose = self.pose_pred(out_conv7) 59 | pose = pose.mean(3).mean(2) 60 | pose = 0.01 * pose.view(pose.size(0), self.nb_ref_imgs, 6) 61 | 
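# pose holds one 6-DoF motion vector per reference image (3 translation + 3 rotation
# components); the 1x1 prediction is averaged over the spatial dimensions and scaled by
# 0.01 so the initial estimates stay close to the identity transform.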
62 | return pose 63 | -------------------------------------------------------------------------------- /models/PoseNetB6.py: -------------------------------------------------------------------------------- 1 | # Author: Anurag Ranjan 2 | # Copyright (c) 2019, Anurag Ranjan 3 | # All rights reserved. 4 | # based on github.com/ClementPinard/SfMLearner-Pytorch 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | def conv(in_planes, out_planes, kernel_size=3): 11 | return nn.Sequential( 12 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, padding=(kernel_size-1)//2, stride=2), 13 | nn.ReLU(inplace=True) 14 | ) 15 | 16 | 17 | def upconv(in_planes, out_planes): 18 | return nn.Sequential( 19 | nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1), 20 | nn.ReLU(inplace=True) 21 | ) 22 | 23 | 24 | class PoseNetB6(nn.Module): 25 | 26 | def __init__(self, nb_ref_imgs=2): 27 | super(PoseNetB6, self).__init__() 28 | self.nb_ref_imgs = nb_ref_imgs 29 | 30 | conv_planes = [16, 32, 64, 128, 256, 256, 256, 256] 31 | self.conv1 = conv(3*(1+self.nb_ref_imgs), conv_planes[0], kernel_size=7) 32 | self.conv2 = conv(conv_planes[0], conv_planes[1], kernel_size=5) 33 | self.conv3 = conv(conv_planes[1], conv_planes[2]) 34 | self.conv4 = conv(conv_planes[2], conv_planes[3]) 35 | self.conv5 = conv(conv_planes[3], conv_planes[4]) 36 | self.conv6 = conv(conv_planes[4], conv_planes[5]) 37 | self.conv7 = conv(conv_planes[5], conv_planes[6]) 38 | self.conv8 = conv(conv_planes[6], conv_planes[7]) 39 | 40 | self.pose_pred = nn.Conv2d(conv_planes[7], 6*self.nb_ref_imgs, kernel_size=1, padding=0) 41 | 42 | 43 | def init_weights(self): 44 | for m in self.modules(): 45 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 46 | nn.init.xavier_uniform_(m.weight.data) 47 | if m.bias is not None: 48 | m.bias.data.zero_() 49 | 50 | def init_mask_weights(self): 51 | for m in self.modules(): 52 | if isinstance(m, nn.ConvTranspose2d): 53 | nn.init.xavier_uniform_(m.weight.data) 54 | if m.bias is not None: 55 | m.bias.data.zero_() 56 | 57 | for module in [self.pred_mask1, self.pred_mask2, self.pred_mask3, self.pred_mask4, self.pred_mask5, self.pred_mask6]: 58 | for m in module.modules(): 59 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 60 | nn.init.xavier_uniform_(m.weight.data) 61 | if m.bias is not None: 62 | m.bias.data.zero_() 63 | 64 | 65 | def forward(self, target_image, ref_imgs): 66 | assert(len(ref_imgs) == self.nb_ref_imgs) 67 | input = [target_image] 68 | input.extend(ref_imgs) 69 | input = torch.cat(input, 1) 70 | out_conv1 = self.conv1(input) 71 | out_conv2 = self.conv2(out_conv1) 72 | out_conv3 = self.conv3(out_conv2) 73 | out_conv4 = self.conv4(out_conv3) 74 | out_conv5 = self.conv5(out_conv4) 75 | out_conv6 = self.conv6(out_conv5) 76 | out_conv7 = self.conv7(out_conv6) 77 | out_conv8 = self.conv8(out_conv7) 78 | 79 | pose = self.pose_pred(out_conv8) 80 | pose = pose.mean(3).mean(2) 81 | pose = 0.01 * pose.view(pose.size(0), self.nb_ref_imgs, 6) 82 | 83 | return pose 84 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .back2future import Model as Back2Future 2 | from .DispNetS import DispNetS 3 | from .DispNetS6 import DispNetS6 4 | from .DispResNet6 import DispResNet6 5 | from .DispResNetS6 import DispResNetS6 6 | from .FlowNetC6 import FlowNetC6 7 | from .MaskNet6 import MaskNet6 8 | from .MaskResNet6 
import MaskResNet6 9 | from .PoseExpNet import PoseExpNet 10 | from .PoseNet6 import PoseNet6 11 | from .PoseNetB6 import PoseNetB6 12 | -------------------------------------------------------------------------------- /models/back2future.py: -------------------------------------------------------------------------------- 1 | # Author: Anurag Ranjan 2 | # Copyright (c) 2019, Anurag Ranjan 3 | # All rights reserved. 4 | 5 | import os 6 | import numpy as np 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.nn import init 11 | import torch.nn.functional as F 12 | from torch.autograd import Variable 13 | from spatial_correlation_sampler import spatial_correlation_sample 14 | 15 | def correlate(input1, input2): 16 | out_corr = spatial_correlation_sample(input1, 17 | input2, 18 | kernel_size=1, 19 | patch_size=9, 20 | stride=1) 21 | # collate dimensions 1 and 2 in order to be treated as a 22 | # regular 4D tensor 23 | b, ph, pw, h, w = out_corr.size() 24 | out_corr = out_corr.view(b, ph * pw, h, w)/input1.size(1) 25 | return out_corr 26 | 27 | def conv_feat_block(nIn, nOut): 28 | return nn.Sequential( 29 | nn.Conv2d(nIn, nOut, kernel_size=3, stride=2, padding=1), 30 | nn.LeakyReLU(0.2), 31 | nn.Conv2d(nOut, nOut, kernel_size=3, stride=1, padding=1), 32 | nn.LeakyReLU(0.2) 33 | ) 34 | 35 | def conv_dec_block(nIn): 36 | return nn.Sequential( 37 | nn.Conv2d(nIn, 128, kernel_size=3, stride=1, padding=1), 38 | nn.LeakyReLU(0.2), 39 | nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), 40 | nn.LeakyReLU(0.2), 41 | nn.Conv2d(128, 96, kernel_size=3, stride=1, padding=1), 42 | nn.LeakyReLU(0.2), 43 | nn.Conv2d(96, 64, kernel_size=3, stride=1, padding=1), 44 | nn.LeakyReLU(0.2), 45 | nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1), 46 | nn.LeakyReLU(0.2), 47 | nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1) 48 | ) 49 | 50 | 51 | class Model(nn.Module): 52 | def __init__(self, nlevels): 53 | super(Model, self).__init__() 54 | 55 | self.nlevels = nlevels 56 | idx = [list(range(n, -1, -9)) for n in range(80,71,-1)] 57 | idx = list(np.array(idx).flatten()) 58 | self.idx_fwd = Variable(torch.LongTensor(np.array(idx)).cuda(), requires_grad=False) 59 | self.idx_bwd = Variable(torch.LongTensor(np.array(list(reversed(idx)))).cuda(), requires_grad=False) 60 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear') 61 | self.softmax2d = nn.Softmax2d() 62 | 63 | self.conv1a = conv_feat_block(3,16) 64 | self.conv1b = conv_feat_block(3,16) 65 | self.conv1c = conv_feat_block(3,16) 66 | 67 | self.conv2a = conv_feat_block(16,32) 68 | self.conv2b = conv_feat_block(16,32) 69 | self.conv2c = conv_feat_block(16,32) 70 | 71 | self.conv3a = conv_feat_block(32,64) 72 | self.conv3b = conv_feat_block(32,64) 73 | self.conv3c = conv_feat_block(32,64) 74 | 75 | self.conv4a = conv_feat_block(64,96) 76 | self.conv4b = conv_feat_block(64,96) 77 | self.conv4c = conv_feat_block(64,96) 78 | 79 | self.conv5a = conv_feat_block(96,128) 80 | self.conv5b = conv_feat_block(96,128) 81 | self.conv5c = conv_feat_block(96,128) 82 | 83 | self.conv6a = conv_feat_block(128,192) 84 | self.conv6b = conv_feat_block(128,192) 85 | self.conv6c = conv_feat_block(128,192) 86 | 87 | self.corr = correlate # Correlation(pad_size=4, kernel_size=1, max_displacement=4, stride1=1, stride2=1, corr_multiply=1) 88 | 89 | self.decoder_fwd6 = conv_dec_block(162) 90 | self.decoder_bwd6 = conv_dec_block(162) 91 | self.decoder_fwd5 = conv_dec_block(292) 92 | self.decoder_bwd5 = conv_dec_block(292) 93 | self.decoder_fwd4 = conv_dec_block(260) 
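# Decoder input widths follow from the correlation layer: patch_size=9 gives 81 channels per
# direction (162 for forward+backward); from level 5 downwards the level's feature map and
# the 2-channel upsampled flow are also concatenated, giving 162+128+2=292, 162+96+2=260,
# 162+64+2=228 and 162+32+2=196.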
94 | self.decoder_bwd4 = conv_dec_block(260) 95 | self.decoder_fwd3 = conv_dec_block(228) 96 | self.decoder_bwd3 = conv_dec_block(228) 97 | self.decoder_fwd2 = conv_dec_block(196) 98 | self.decoder_bwd2 = conv_dec_block(196) 99 | 100 | self.decoder_occ6 = conv_dec_block(354) 101 | self.decoder_occ5 = conv_dec_block(292) 102 | self.decoder_occ4 = conv_dec_block(260) 103 | self.decoder_occ3 = conv_dec_block(228) 104 | self.decoder_occ2 = conv_dec_block(196) 105 | 106 | def init_weights(self): 107 | for m in self.modules(): 108 | if isinstance(m, nn.Conv2d): 109 | if m.bias is not None: 110 | init.uniform_(m.bias) 111 | init.xavier_uniform_(m.weight) 112 | 113 | if isinstance(m, nn.ConvTranspose2d): 114 | if m.bias is not None: 115 | init.uniform(m.bias) 116 | init.xavier_uniform_(m.weight) 117 | 118 | def normalize(self, ims): 119 | imt = [] 120 | for im in ims: 121 | im = im * 0.5 122 | im = im + 0.5 123 | im[:,0,:,:] = im[:,0,:,:] - 0.485 # Red 124 | im[:,1,:,:] = im[:,1,:,:] - 0.456 # Green 125 | im[:,2,:,:] = im[:,2,:,:] - 0.406 # Blue 126 | 127 | im[:,0,:,:] = im[:,0,:,:] / 0.229 # Red 128 | im[:,1,:,:] = im[:,1,:,:] / 0.224 # Green 129 | im[:,2,:,:] = im[:,2,:,:] / 0.225 # Blue 130 | 131 | imt.append(im) 132 | return imt 133 | 134 | def forward(self, im_tar, im_refs): 135 | ''' 136 | inputS: 137 | im_tar: Middle Frame, I_0 138 | im_refs: Adjecent Frames in the order, [I-, I+] 139 | 140 | outputs: 141 | At self.nlevels different scales: 142 | flow_fwd: optical flow from I_0 to I+ 143 | flow_bwd: optical flow from I_0 to I+ 144 | occ : occlusions 145 | ''' 146 | # im = Variable(torch.zeros(1,9,512,512).cuda()) 147 | # ima = im[:, :3, :, :] + 0.2 # I_0 148 | # imb = im[:, 3:6, :, :] + 0.3 # I_+ 149 | # imc = im[:, 6:, :, :] + 0.1 # I_- 150 | im_norm = self.normalize([im_tar] + im_refs) 151 | 152 | feat1a = self.conv1a(im_norm[0]) 153 | feat2a = self.conv2a(feat1a) 154 | feat3a = self.conv3a(feat2a) 155 | feat4a = self.conv4a(feat3a) 156 | feat5a = self.conv5a(feat4a) 157 | feat6a = self.conv6a(feat5a) 158 | 159 | feat1b = self.conv1b(im_norm[2]) 160 | feat2b = self.conv2b(feat1b) 161 | feat3b = self.conv3b(feat2b) 162 | feat4b = self.conv4b(feat3b) 163 | feat5b = self.conv5b(feat4b) 164 | feat6b = self.conv6b(feat5b) 165 | 166 | feat1c = self.conv1c(im_norm[1]) 167 | feat2c = self.conv2c(feat1c) 168 | feat3c = self.conv3c(feat2c) 169 | feat4c = self.conv4c(feat3c) 170 | feat5c = self.conv5c(feat4c) 171 | feat6c = self.conv6c(feat5c) 172 | 173 | corr6_fwd = self.corr(feat6a, feat6b) 174 | corr6_fwd = corr6_fwd.index_select(1,self.idx_fwd) 175 | corr6_bwd = self.corr(feat6a, feat6c) 176 | corr6_bwd = corr6_bwd.index_select(1,self.idx_bwd) 177 | corr6 = torch.cat((corr6_fwd, corr6_bwd), 1) 178 | 179 | flow6_fwd = self.decoder_fwd6(corr6) 180 | flow6_fwd_up = self.upsample(flow6_fwd) 181 | flow6_bwd = self.decoder_bwd6(corr6) 182 | flow6_bwd_up = self.upsample(flow6_bwd) 183 | feat5b_warped = self.warp(feat5b, 0.625*flow6_fwd_up) 184 | feat5c_warped = self.warp(feat5c, -0.625*flow6_fwd_up) 185 | 186 | occ6_feat = torch.cat((corr6, feat6a), 1) 187 | occ6 = self.softmax2d(self.decoder_occ6(occ6_feat)) 188 | 189 | corr5_fwd = self.corr(feat5a, feat5b_warped) 190 | corr5_fwd = corr5_fwd.index_select(1,self.idx_fwd) 191 | corr5_bwd = self.corr(feat5a, feat5c_warped) 192 | corr5_bwd = corr5_bwd.index_select(1,self.idx_bwd) 193 | corr5 = torch.cat((corr5_fwd, corr5_bwd), 1) 194 | 195 | upfeat5_fwd = torch.cat((corr5, feat5a, flow6_fwd_up), 1) 196 | flow5_fwd = self.decoder_fwd5(upfeat5_fwd) 197 | 
flow5_fwd_up = self.upsample(flow5_fwd) 198 | upfeat5_bwd = torch.cat((corr5, feat5a, flow6_bwd_up),1) 199 | flow5_bwd = self.decoder_bwd5(upfeat5_bwd) 200 | flow5_bwd_up = self.upsample(flow5_bwd) 201 | feat4b_warped = self.warp(feat4b, 1.25*flow5_fwd_up) 202 | feat4c_warped = self.warp(feat4c, -1.25*flow5_fwd_up) 203 | 204 | occ5 = self.softmax2d(self.decoder_occ5(upfeat5_fwd)) 205 | 206 | corr4_fwd = self.corr(feat4a, feat4b_warped) 207 | corr4_fwd = corr4_fwd.index_select(1,self.idx_fwd) 208 | corr4_bwd = self.corr(feat4a, feat4c_warped) 209 | corr4_bwd = corr4_bwd.index_select(1,self.idx_bwd) 210 | corr4 = torch.cat((corr4_fwd, corr4_bwd), 1) 211 | 212 | upfeat4_fwd = torch.cat((corr4, feat4a, flow5_fwd_up), 1) 213 | flow4_fwd = self.decoder_fwd4(upfeat4_fwd) 214 | flow4_fwd_up = self.upsample(flow4_fwd) 215 | upfeat4_bwd = torch.cat((corr4, feat4a, flow5_bwd_up),1) 216 | flow4_bwd = self.decoder_bwd4(upfeat4_bwd) 217 | flow4_bwd_up = self.upsample(flow4_bwd) 218 | feat3b_warped = self.warp(feat3b, 2.5*flow4_fwd_up) 219 | feat3c_warped = self.warp(feat3c, -2.5*flow4_fwd_up) 220 | 221 | occ4 = self.softmax2d(self.decoder_occ4(upfeat4_fwd)) 222 | 223 | corr3_fwd = self.corr(feat3a, feat3b_warped) 224 | corr3_fwd = corr3_fwd.index_select(1,self.idx_fwd) 225 | corr3_bwd = self.corr(feat3a, feat3c_warped) 226 | corr3_bwd = corr3_bwd.index_select(1,self.idx_bwd) 227 | corr3 = torch.cat((corr3_fwd, corr3_bwd), 1) 228 | 229 | upfeat3_fwd = torch.cat((corr3, feat3a, flow4_fwd_up), 1) 230 | flow3_fwd = self.decoder_fwd3(upfeat3_fwd) 231 | flow3_fwd_up = self.upsample(flow3_fwd) 232 | upfeat3_bwd = torch.cat((corr3, feat3a, flow4_bwd_up),1) 233 | flow3_bwd = self.decoder_bwd3(upfeat3_bwd) 234 | flow3_bwd_up = self.upsample(flow3_bwd) 235 | feat2b_warped = self.warp(feat2b, 5.0*flow3_fwd_up) 236 | feat2c_warped = self.warp(feat2c, -5.0*flow3_fwd_up) 237 | 238 | occ3 = self.softmax2d(self.decoder_occ3(upfeat3_fwd)) 239 | 240 | corr2_fwd = self.corr(feat2a, feat2b_warped) 241 | corr2_fwd = corr2_fwd.index_select(1,self.idx_fwd) 242 | corr2_bwd = self.corr(feat2a, feat2c_warped) 243 | corr2_bwd = corr2_bwd.index_select(1,self.idx_bwd) 244 | corr2 = torch.cat((corr2_fwd, corr2_bwd), 1) 245 | 246 | upfeat2_fwd = torch.cat((corr2, feat2a, flow3_fwd_up), 1) 247 | flow2_fwd = self.decoder_fwd2(upfeat2_fwd) 248 | flow2_fwd_up = self.upsample(flow2_fwd) 249 | upfeat2_bwd = torch.cat((corr2, feat2a, flow3_bwd_up),1) 250 | flow2_bwd = self.decoder_bwd2(upfeat2_bwd) 251 | flow2_bwd_up = self.upsample(flow2_bwd) 252 | 253 | occ2 = self.softmax2d(self.decoder_occ2(upfeat2_fwd)) 254 | 255 | flow2_fwd_fullres = 20*self.upsample(flow2_fwd_up) 256 | flow3_fwd_fullres = 10*self.upsample(flow3_fwd_up) 257 | flow4_fwd_fullres = 5*self.upsample(flow4_fwd_up) 258 | flow5_fwd_fullres = 2.5*self.upsample(flow5_fwd_up) 259 | flow6_fwd_fullres = 1.25*self.upsample(flow6_fwd_up) 260 | 261 | flow2_bwd_fullres = -20*self.upsample(flow2_bwd_up) 262 | flow3_bwd_fullres = -10*self.upsample(flow3_bwd_up) 263 | flow4_bwd_fullres = -5*self.upsample(flow4_bwd_up) 264 | flow5_bwd_fullres = -2.5*self.upsample(flow5_bwd_up) 265 | flow6_bwd_fullres = -1.25*self.upsample(flow6_bwd_up) 266 | 267 | occ2_fullres = F.upsample(occ2, scale_factor=4) 268 | occ3_fullres = F.upsample(occ3, scale_factor=4) 269 | occ4_fullres = F.upsample(occ4, scale_factor=4) 270 | occ5_fullres = F.upsample(occ5, scale_factor=4) 271 | occ6_fullres = F.upsample(occ6, scale_factor=4) 272 | 273 | if self.training: 274 | flow_fwd = [flow2_fwd_fullres, 
flow3_fwd_fullres, flow4_fwd_fullres, flow5_fwd_fullres, flow6_fwd_fullres] 275 | flow_bwd = [flow2_bwd_fullres, flow3_bwd_fullres, flow4_bwd_fullres, flow5_bwd_fullres, flow6_bwd_fullres] 276 | occ = [occ2_fullres, occ3_fullres, occ4_fullres, occ5_fullres, occ6_fullres] 277 | 278 | if self.nlevels==6: 279 | flow_fwd.append(0.625*flow6_fwd_up) 280 | flow_bwd.append(-0.625*flow6_bwd_up) 281 | occ.append(F.upsample(occ6, scale_factor=2)) 282 | 283 | return flow_fwd, flow_bwd, occ 284 | else: 285 | return flow2_fwd_fullres, flow2_bwd_fullres, occ2_fullres 286 | 287 | def warp(self, x, flo): 288 | """ 289 | warp an image/tensor (im2) back to im1, according to the optical flow 290 | x: [B, C, H, W] (im2) 291 | flo: [B, 2, H, W] flow 292 | """ 293 | B, C, H, W = x.size() 294 | # mesh grid 295 | xx = torch.arange(0, W).view(1,-1).repeat(H,1) 296 | yy = torch.arange(0, H).view(-1,1).repeat(1,W) 297 | xx = xx.view(1,1,H,W).repeat(B,1,1,1) 298 | yy = yy.view(1,1,H,W).repeat(B,1,1,1) 299 | grid = torch.cat((xx,yy),1).float() 300 | 301 | if x.is_cuda: 302 | grid = grid.cuda() 303 | vgrid = Variable(grid) + flo 304 | 305 | # scale grid to [-1,1] 306 | vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:].clone()/max(W-1,1)-1.0 307 | vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:].clone()/max(H-1,1)-1.0 308 | 309 | vgrid = vgrid.permute(0,2,3,1) 310 | output = nn.functional.grid_sample(x, vgrid, padding_mode='border') 311 | mask = torch.autograd.Variable(torch.ones(x.size()), requires_grad=False).cuda() 312 | mask = nn.functional.grid_sample(mask, vgrid) 313 | 314 | # if W==128: 315 | # np.save('mask.npy', mask.cpu().data.numpy()) 316 | # np.save('warp.npy', output.cpu().data.numpy()) 317 | 318 | mask[mask.data<0.9999] = 0 319 | mask[mask.data>0] = 1 320 | 321 | return output#*mask 322 | -------------------------------------------------------------------------------- /models/submodules.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import numpy as np 4 | 5 | def conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1): 6 | if batchNorm: 7 | return nn.Sequential( 8 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True), 9 | nn.BatchNorm2d(out_planes), 10 | #_leaky_relu() 11 | nn.LeakyReLU(0.1,inplace=True) 12 | ) 13 | else: 14 | return nn.Sequential( 15 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True), 16 | #_leaky_relu() 17 | nn.LeakyReLU(0.1,inplace=True) 18 | ) 19 | 20 | def i_conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1, bias = True): 21 | if batchNorm: 22 | return nn.Sequential( 23 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=bias), 24 | nn.BatchNorm2d(out_planes), 25 | ) 26 | else: 27 | return nn.Sequential( 28 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=bias), 29 | ) 30 | 31 | def predict_flow(in_planes): 32 | return nn.Conv2d(in_planes,2,kernel_size=3,stride=1,padding=1,bias=True) 33 | 34 | def deconv(in_planes, out_planes): 35 | return nn.Sequential( 36 | nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True), 37 | #_leaky_relu() 38 | nn.LeakyReLU(0.1,inplace=True) 39 | ) 40 | 41 | class tofp16(nn.Module): 42 | def __init__(self): 43 | super(tofp16, self).__init__() 44 | 45 | def forward(self, input): 46 | return input.half() 47 | 48 | 
class _leaky_relu(nn.Module): 49 | def __init__(self): 50 | super(_leaky_relu, self).__init__() 51 | 52 | def forward(self, x): 53 | x_neg = 0.1*x 54 | return torch.max(x_neg, x) 55 | 56 | class tofp32(nn.Module): 57 | def __init__(self): 58 | super(tofp32, self).__init__() 59 | 60 | def forward(self, input): 61 | return input.float() 62 | 63 | 64 | def save_grad(grads, name): 65 | def hook(grad): 66 | grads[name] = grad 67 | return hook 68 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | def conv(in_planes, out_planes, stride=1, batch_norm=False): 9 | if batch_norm: 10 | return nn.Sequential( 11 | nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False), 12 | nn.BatchNorm2d(out_planes, eps=1e-3), 13 | nn.ReLU(inplace=True) 14 | ) 15 | else: 16 | return nn.Sequential( 17 | nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True), 18 | nn.ReLU(inplace=True) 19 | ) 20 | 21 | 22 | def deconv(in_planes, out_planes, batch_norm=False): 23 | if batch_norm: 24 | return nn.Sequential( 25 | nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True), 26 | nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False), 27 | nn.BatchNorm2d(out_planes, eps=1e-3), 28 | nn.ReLU(inplace=True) 29 | ) 30 | else: 31 | return nn.Sequential( 32 | nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True), 33 | nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=True), 34 | nn.ReLU(inplace=True) 35 | ) 36 | 37 | 38 | def predict_depth(in_planes, with_confidence): 39 | return nn.Conv2d(in_planes, 2 if with_confidence else 1, kernel_size=3, stride=1, padding=1, bias=True) 40 | 41 | 42 | def post_process_depth(depth, activation_function=None, clamp=False): 43 | if activation_function is not None: 44 | depth = activation_function(depth) 45 | 46 | if clamp: 47 | depth = depth.clamp(10, 80) 48 | 49 | return depth[:,0] 50 | 51 | 52 | def adaptative_cat(out_conv, out_deconv, out_depth_up): 53 | out_deconv = out_deconv[:, :, :out_conv.size(2), :out_conv.size(3)] 54 | out_depth_up = out_depth_up[:, :, :out_conv.size(2), :out_conv.size(3)] 55 | return torch.cat((out_conv, out_deconv, out_depth_up), 1) 56 | 57 | 58 | def init_modules(net): 59 | for m in net.modules(): 60 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 61 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 62 | m.weight.data.normal_(0, math.sqrt(2/n)) 63 | if m.bias is not None: 64 | m.bias.data.zero_() 65 | elif isinstance(m, nn.BatchNorm2d): 66 | m.weight.data.fill_(1) 67 | m.bias.data.zero_() 68 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torchvision 2 | scipy 3 | argparse 4 | tensorboardX 5 | blessings 6 | progressbar2 7 | path.py 8 | matplotlib 9 | opencv-python 10 | scikit-image 11 | pypng 12 | tqdm 13 | spatial-correlation-sampler -------------------------------------------------------------------------------- /ssim.py: -------------------------------------------------------------------------------- 1 | # Author: Jonas Wulff 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from 
torch.autograd import Variable 6 | import numpy as np 7 | from math import exp 8 | 9 | def gaussian(window_size, sigma): 10 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 11 | return gauss/gauss.sum() 12 | 13 | def create_window(window_size, channel): 14 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 15 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 16 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous(), requires_grad=False) 17 | return window 18 | 19 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 20 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 21 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 22 | 23 | mu1_sq = mu1.pow(2) 24 | mu2_sq = mu2.pow(2) 25 | mu1_mu2 = mu1*mu2 26 | 27 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 28 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 29 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 30 | 31 | C1 = 0.01**2 32 | C2 = 0.03**2 33 | 34 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 35 | 36 | return ssim_map 37 | #if size_average: 38 | # return ssim_map.mean() 39 | #else: 40 | # return ssim_map.mean(1).mean(1).mean(1) 41 | 42 | class SSIM(torch.nn.Module): 43 | def __init__(self, window_size = 11, size_average = True): 44 | super(SSIM, self).__init__() 45 | self.window_size = window_size 46 | self.size_average = size_average 47 | self.channel = 1 48 | self.window = create_window(window_size, self.channel) 49 | 50 | def forward(self, img1, img2): 51 | (_, channel, _, _) = img1.size() 52 | 53 | if channel == self.channel and self.window.data.type() == img1.data.type(): 54 | window = self.window 55 | else: 56 | window = create_window(self.window_size, channel) 57 | 58 | if img1.is_cuda: 59 | window = window.cuda(img1.get_device()) 60 | window = window.type_as(img1) 61 | 62 | self.window = window 63 | self.channel = channel 64 | 65 | 66 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 67 | 68 | def ssim(img1, img2, window_size = 13, size_average = True): 69 | (_, channel, _, _) = img1.size() 70 | window = create_window(window_size, channel) 71 | 72 | if img1.is_cuda: 73 | window = window.cuda(img1.get_device()) 74 | window = window.type_as(img1) 75 | 76 | return _ssim(img1, img2, window, window_size, channel, size_average) 77 | -------------------------------------------------------------------------------- /stillbox_eval/depth_evaluation_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | from path import Path 4 | from scipy.misc import imread 5 | from tqdm import tqdm 6 | 7 | 8 | class test_framework_stillbox(object): 9 | def __init__(self, root, test_files, seq_length=3, min_depth=1e-3, max_depth=80, step=1): 10 | self.root = root 11 | self.min_depth, self.max_depth = min_depth, max_depth 12 | self.gt_files, self.img_files, self.displacements = read_scene_data(self.root, test_files, seq_length, step) 13 | 14 | def __getitem__(self, i): 15 | tgt = imread(self.img_files[i][0]).astype(np.float32) 16 | depth = np.load(self.gt_files[i]) 17 | return {'tgt': tgt, 18 | 'ref': [imread(img).astype(np.float32) for img in self.img_files[i][1]], 19 
| 'path':self.img_files[i][0], 20 | 'gt_depth': depth, 21 | 'displacements': np.array(self.displacements[i]), 22 | 'mask': generate_mask(depth, self.min_depth, self.max_depth) 23 | } 24 | 25 | def __len__(self): 26 | return len(self.img_files) 27 | 28 | 29 | def get_displacements(scene, index, ref_indices): 30 | speed = np.around(np.linalg.norm(scene['speed']), decimals=3) 31 | assert(all(i < scene['length'] and i >= 0 for i in ref_indices)), str(ref_indices) 32 | return [speed*scene['time_step']*abs(index - i) for i in ref_indices] 33 | 34 | 35 | def read_scene_data(data_root, test_list, seq_length=3, step=1): 36 | data_root = Path(data_root) 37 | metadata_files = {} 38 | for folder in data_root.dirs(): 39 | with open(folder/'metadata.json', 'r') as f: 40 | metadata_files[str(folder.name)] = json.load(f) 41 | gt_files = [] 42 | im_files = [] 43 | displacements = [] 44 | demi_length = (seq_length - 1) // 2 45 | shift_range = [step*i for i in list(range(-demi_length,0)) + list(range(1, demi_length + 1))] 46 | 47 | print('getting test metadata ... ') 48 | for sample in tqdm(test_list): 49 | folder, file = sample.split('/') 50 | _, scene_index, index = file[:-4].split('_') # filename is in the form 'RGB_XXXX_XX.jpg' 51 | index = int(index) 52 | scene = metadata_files[folder]['scenes'][int(scene_index)] 53 | tgt_img_path = data_root/sample 54 | folder_path = data_root/folder 55 | if tgt_img_path.isfile(): 56 | capped_indices_range = list(map(lambda x: min(max(0, index + x), scene['length'] - 1), shift_range)) 57 | ref_imgs_path = [folder_path/'{}'.format(scene['imgs'][ref_index]) for ref_index in capped_indices_range] 58 | 59 | gt_files.append(folder_path/'{}'.format(scene['depth'][index])) 60 | im_files.append([tgt_img_path,ref_imgs_path]) 61 | displacements.append(get_displacements(scene, index, capped_indices_range)) 62 | else: 63 | print('{} missing'.format(tgt_img_path)) 64 | 65 | return gt_files, im_files, displacements 66 | 67 | 68 | def generate_mask(gt_depth, min_depth, max_depth): 69 | mask = np.logical_and(gt_depth > min_depth, 70 | gt_depth < max_depth) 71 | # crop gt to exclude border values 72 | # if used on gt_size 100x100 produces a crop of [-95, -5, 5, 95] 73 | gt_height, gt_width = gt_depth.shape 74 | crop = np.array([0.05 * gt_height, 0.95 * gt_height, 75 | 0.05 * gt_width, 0.95 * gt_width]).astype(np.int32) 76 | 77 | crop_mask = np.zeros(mask.shape) 78 | crop_mask[crop[0]:crop[1],crop[2]:crop[3]] = 1 79 | mask = np.logical_and(mask, crop_mask) 80 | return mask 81 | -------------------------------------------------------------------------------- /stillbox_eval/test_files_90.txt: -------------------------------------------------------------------------------- 1 | 15/RGB_112_008.jpg 2 | 15/RGB_178_002.jpg 3 | 15/RGB_167_006.jpg 4 | 15/RGB_153_007.jpg 5 | 15/RGB_119_002.jpg 6 | 15/RGB_135_003.jpg 7 | 15/RGB_44_006.jpg 8 | 15/RGB_32_002.jpg 9 | 15/RGB_171_001.jpg 10 | 15/RGB_114_009.jpg 11 | 15/RGB_89_003.jpg 12 | 15/RGB_197_009.jpg 13 | 15/RGB_105_000.jpg 14 | 15/RGB_72_004.jpg 15 | 15/RGB_66_003.jpg 16 | 15/RGB_25_007.jpg 17 | 15/RGB_58_004.jpg 18 | 15/RGB_28_003.jpg 19 | 15/RGB_25_004.jpg 20 | 15/RGB_140_003.jpg 21 | 15/RGB_59_008.jpg 22 | 15/RGB_19_001.jpg 23 | 15/RGB_186_003.jpg 24 | 15/RGB_113_009.jpg 25 | 15/RGB_54_002.jpg 26 | 15/RGB_130_003.jpg 27 | 15/RGB_153_003.jpg 28 | 15/RGB_103_007.jpg 29 | 15/RGB_04_007.jpg 30 | 15/RGB_110_008.jpg 31 | 15/RGB_78_005.jpg 32 | 15/RGB_26_005.jpg 33 | 15/RGB_43_007.jpg 34 | 15/RGB_190_003.jpg 35 | 15/RGB_122_002.jpg 36 | 
15/RGB_102_008.jpg 37 | 15/RGB_187_004.jpg 38 | 15/RGB_03_005.jpg 39 | 15/RGB_58_007.jpg 40 | 15/RGB_37_004.jpg 41 | 15/RGB_125_003.jpg 42 | 15/RGB_190_002.jpg 43 | 15/RGB_52_006.jpg 44 | 15/RGB_37_005.jpg 45 | 15/RGB_196_001.jpg 46 | 15/RGB_53_003.jpg 47 | 15/RGB_129_008.jpg 48 | 15/RGB_74_003.jpg 49 | 15/RGB_167_000.jpg 50 | 15/RGB_195_002.jpg 51 | 15/RGB_10_007.jpg 52 | 15/RGB_131_003.jpg 53 | 15/RGB_37_003.jpg 54 | 15/RGB_38_009.jpg 55 | 15/RGB_115_004.jpg 56 | 15/RGB_91_008.jpg 57 | 15/RGB_43_004.jpg 58 | 15/RGB_187_005.jpg 59 | 15/RGB_112_003.jpg 60 | 15/RGB_19_002.jpg 61 | 15/RGB_170_008.jpg 62 | 15/RGB_17_000.jpg 63 | 15/RGB_62_005.jpg 64 | 15/RGB_148_004.jpg 65 | 15/RGB_12_008.jpg 66 | 15/RGB_169_004.jpg 67 | 15/RGB_112_004.jpg 68 | 15/RGB_71_001.jpg 69 | 15/RGB_103_001.jpg 70 | 15/RGB_178_005.jpg 71 | 15/RGB_92_006.jpg 72 | 15/RGB_40_009.jpg 73 | 15/RGB_138_006.jpg 74 | 15/RGB_146_005.jpg 75 | 15/RGB_04_006.jpg 76 | 15/RGB_02_008.jpg 77 | 15/RGB_101_009.jpg 78 | 15/RGB_103_009.jpg 79 | 15/RGB_21_002.jpg 80 | 15/RGB_144_008.jpg 81 | 15/RGB_163_007.jpg 82 | 15/RGB_06_001.jpg 83 | 15/RGB_105_004.jpg 84 | 15/RGB_199_009.jpg 85 | 15/RGB_149_005.jpg 86 | 15/RGB_63_008.jpg 87 | 15/RGB_21_004.jpg 88 | 15/RGB_03_002.jpg 89 | 15/RGB_51_008.jpg 90 | 15/RGB_110_001.jpg 91 | 15/RGB_172_009.jpg 92 | 15/RGB_158_005.jpg 93 | 15/RGB_49_004.jpg 94 | 15/RGB_173_008.jpg 95 | 15/RGB_99_004.jpg 96 | 15/RGB_24_001.jpg 97 | 15/RGB_03_009.jpg 98 | 15/RGB_41_009.jpg 99 | 15/RGB_91_002.jpg 100 | 15/RGB_132_001.jpg 101 | 15/RGB_95_003.jpg 102 | 15/RGB_167_005.jpg 103 | 15/RGB_176_000.jpg 104 | 15/RGB_142_008.jpg 105 | 15/RGB_107_009.jpg 106 | 15/RGB_122_005.jpg 107 | 15/RGB_48_001.jpg 108 | 15/RGB_103_005.jpg 109 | 15/RGB_98_009.jpg 110 | 15/RGB_162_001.jpg 111 | 15/RGB_08_006.jpg 112 | 15/RGB_169_002.jpg 113 | 15/RGB_57_002.jpg 114 | 15/RGB_86_004.jpg 115 | 15/RGB_138_001.jpg 116 | 15/RGB_05_005.jpg 117 | 15/RGB_95_002.jpg 118 | 15/RGB_28_002.jpg 119 | 15/RGB_110_002.jpg 120 | 15/RGB_102_002.jpg 121 | 15/RGB_136_009.jpg 122 | 15/RGB_28_007.jpg 123 | 15/RGB_43_005.jpg 124 | 15/RGB_39_006.jpg 125 | 15/RGB_126_003.jpg 126 | 15/RGB_62_001.jpg 127 | 15/RGB_82_003.jpg 128 | 15/RGB_75_008.jpg 129 | 15/RGB_16_005.jpg 130 | 15/RGB_94_005.jpg 131 | 15/RGB_198_002.jpg 132 | 15/RGB_90_001.jpg 133 | 15/RGB_22_001.jpg 134 | 15/RGB_90_000.jpg 135 | 15/RGB_155_006.jpg 136 | 15/RGB_124_007.jpg 137 | 15/RGB_168_004.jpg 138 | 15/RGB_96_008.jpg 139 | 15/RGB_100_002.jpg 140 | 15/RGB_131_008.jpg 141 | 15/RGB_74_002.jpg 142 | 15/RGB_141_007.jpg 143 | 15/RGB_139_001.jpg 144 | 15/RGB_102_005.jpg 145 | 15/RGB_182_009.jpg 146 | 15/RGB_37_002.jpg 147 | 15/RGB_67_003.jpg 148 | 15/RGB_60_001.jpg 149 | 15/RGB_186_001.jpg 150 | 15/RGB_171_002.jpg 151 | 15/RGB_155_004.jpg 152 | 15/RGB_50_008.jpg 153 | 15/RGB_34_002.jpg 154 | 15/RGB_132_003.jpg 155 | 15/RGB_147_005.jpg 156 | 15/RGB_99_008.jpg 157 | 15/RGB_110_000.jpg 158 | 15/RGB_114_008.jpg 159 | 15/RGB_159_002.jpg 160 | 15/RGB_76_007.jpg 161 | 15/RGB_116_005.jpg 162 | 15/RGB_67_002.jpg 163 | 15/RGB_80_003.jpg 164 | 15/RGB_30_000.jpg 165 | 15/RGB_137_009.jpg 166 | 15/RGB_130_002.jpg 167 | 15/RGB_90_002.jpg 168 | 15/RGB_34_008.jpg 169 | 15/RGB_137_007.jpg 170 | 15/RGB_45_001.jpg 171 | 15/RGB_131_004.jpg 172 | 15/RGB_06_000.jpg 173 | 15/RGB_68_005.jpg 174 | 15/RGB_104_008.jpg 175 | 15/RGB_193_008.jpg 176 | 15/RGB_182_000.jpg 177 | 15/RGB_129_006.jpg 178 | 15/RGB_107_005.jpg 179 | 15/RGB_158_007.jpg 180 | 15/RGB_192_001.jpg 181 | 15/RGB_18_005.jpg 182 | 15/RGB_90_009.jpg 183 | 
15/RGB_18_007.jpg 184 | 15/RGB_94_000.jpg 185 | 15/RGB_09_002.jpg 186 | 15/RGB_94_001.jpg 187 | 15/RGB_46_004.jpg 188 | 15/RGB_126_000.jpg 189 | 15/RGB_146_002.jpg 190 | 15/RGB_161_006.jpg 191 | 15/RGB_154_008.jpg 192 | 15/RGB_94_003.jpg 193 | -------------------------------------------------------------------------------- /test_disp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | from PIL import Image 4 | from scipy import interpolate 5 | from scipy.misc import imresize 6 | from scipy.ndimage.interpolation import zoom 7 | import numpy as np 8 | from path import Path 9 | import argparse 10 | from tqdm import tqdm 11 | from utils import tensor2array 12 | import models 13 | from loss_functions_summary import spatial_normalize 14 | 15 | parser = argparse.ArgumentParser(description='Script for DispNet testing with corresponding groundTruth', 16 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 17 | parser.add_argument("--dispnet", dest='dispnet', type=str, default='DispResNet6', help='dispnet architecture') 18 | parser.add_argument("--posenet", dest='posenet', type=str, default='PoseExpNet', help='posenet architecture') 19 | parser.add_argument("--pretrained-dispnet", required=True, type=str, help="pretrained DispNet path") 20 | parser.add_argument("--pretrained-posenet", default=None, type=str, help="pretrained PoseNet path (for scale factor)") 21 | parser.add_argument("--img-height", default=256, type=int, help="Image height") 22 | parser.add_argument("--img-width", default=832, type=int, help="Image width") 23 | parser.add_argument("--no-resize", action='store_true', help="no resizing is done") 24 | parser.add_argument("--spatial-normalize", action='store_true', help="spatial normalization") 25 | parser.add_argument("--min-depth", default=1e-3) 26 | parser.add_argument("--max-depth", default=80, type=float) 27 | 28 | parser.add_argument("--dataset-dir", default='.', type=str, help="Dataset directory") 29 | parser.add_argument("--dataset-list", default=None, type=str, help="Dataset list file") 30 | parser.add_argument("--output-dir", default=None, type=str, help="Output directory for saving predictions in a big 3D numpy file") 31 | 32 | parser.add_argument("--gt-type", default='KITTI', type=str, help="GroundTruth data type", choices=['npy', 'png', 'KITTI', 'stillbox']) 33 | parser.add_argument("--img-exts", default=['png', 'jpg', 'bmp'], nargs='*', type=str, help="images extensions to glob") 34 | 35 | 36 | def main(): 37 | args = parser.parse_args() 38 | if args.gt_type == 'KITTI': 39 | from kitti_eval.depth_evaluation_utils import test_framework_KITTI as test_framework 40 | elif args.gt_type == 'stillbox': 41 | from stillbox_eval.depth_evaluation_utils import test_framework_stillbox as test_framework 42 | 43 | disp_net = getattr(models, args.dispnet)().cuda() 44 | weights = torch.load(args.pretrained_dispnet) 45 | disp_net.load_state_dict(weights['state_dict']) 46 | disp_net.eval() 47 | 48 | if args.pretrained_posenet is None: 49 | print('no PoseNet specified, scale_factor will be determined by median ratio, which is kiiinda cheating\ 50 | (but consistent with original paper)') 51 | seq_length = 0 52 | else: 53 | weights = torch.load(args.pretrained_posenet) 54 | seq_length = int(weights['state_dict']['conv1.0.weight'].size(1)/3) 55 | pose_net = getattr(models, args.posenet)(nb_ref_imgs=seq_length - 1, output_exp=False).cuda() 56 | pose_net.load_state_dict(weights['state_dict'], 
strict=False) 57 | 58 | dataset_dir = Path(args.dataset_dir) 59 | if args.dataset_list is not None: 60 | with open(args.dataset_list, 'r') as f: 61 | test_files = list(f.read().splitlines()) 62 | else: 63 | test_files = [file.relpathto(dataset_dir) for file in sum([dataset_dir.files('*.{}'.format(ext)) for ext in args.img_exts], [])] 64 | 65 | framework = test_framework(dataset_dir, test_files, seq_length, args.min_depth, args.max_depth) 66 | 67 | print('{} files to test'.format(len(test_files))) 68 | errors = np.zeros((2, 7, len(test_files)), np.float32) 69 | if args.output_dir is not None: 70 | output_dir = Path(args.output_dir) 71 | viz_dir = output_dir/'viz' 72 | output_dir.makedirs_p() 73 | viz_dir.makedirs_p() 74 | 75 | for j, sample in enumerate(tqdm(framework)): 76 | tgt_img = sample['tgt'] 77 | 78 | ref_imgs = sample['ref'] 79 | 80 | h,w,_ = tgt_img.shape 81 | if (not args.no_resize) and (h != args.img_height or w != args.img_width): 82 | tgt_img = imresize(tgt_img, (args.img_height, args.img_width)).astype(np.float32) 83 | ref_imgs = [imresize(img, (args.img_height, args.img_width)).astype(np.float32) for img in ref_imgs] 84 | 85 | tgt_img = np.transpose(tgt_img, (2, 0, 1)) 86 | ref_imgs = [np.transpose(img, (2,0,1)) for img in ref_imgs] 87 | 88 | tgt_img = torch.from_numpy(tgt_img).unsqueeze(0) 89 | tgt_img = ((tgt_img/255 - 0.5)/0.5).cuda() 90 | tgt_img_var = Variable(tgt_img, volatile=True) 91 | 92 | ref_imgs_var = [] 93 | for i, img in enumerate(ref_imgs): 94 | img = torch.from_numpy(img).unsqueeze(0) 95 | img = ((img/255 - 0.5)/0.5).cuda() 96 | ref_imgs_var.append(Variable(img, volatile=True)) 97 | 98 | pred_disp = disp_net(tgt_img_var) 99 | if args.spatial_normalize: 100 | pred_disp = spatial_normalize(pred_disp) 101 | pred_disp = pred_disp.data.cpu().numpy()[0,0] 102 | gt_depth = sample['gt_depth'] 103 | 104 | if args.output_dir is not None: 105 | if j == 0: 106 | predictions = np.zeros((len(test_files), *pred_disp.shape)) 107 | predictions[j] = 1/pred_disp 108 | gt_viz = interp_gt_disp(gt_depth) 109 | gt_viz = torch.FloatTensor(gt_viz) 110 | gt_viz[gt_viz == 0] = 1000 111 | gt_viz = (1/gt_viz).clamp(0,10) 112 | 113 | tgt_img_viz = tensor2array(tgt_img[0].cpu()) 114 | depth_viz = tensor2array(torch.FloatTensor(pred_disp), max_value=None, colormap='hot') 115 | gt_viz = tensor2array(gt_viz, max_value=None, colormap='hot') 116 | tgt_img_viz_im = Image.fromarray((255*tgt_img_viz).astype('uint8')) 117 | tgt_img_viz_im.save(viz_dir/str(j).zfill(4)+'img.png') 118 | depth_viz_im = Image.fromarray((255*depth_viz).astype('uint8')) 119 | depth_viz_im.save(viz_dir/str(j).zfill(4)+'depth.png') 120 | gt_viz_im = Image.fromarray((255*gt_viz).astype('uint8')) 121 | gt_viz_im.save(viz_dir/str(j).zfill(4)+'gt.png') 122 | 123 | 124 | pred_depth = 1/pred_disp 125 | pred_depth_zoomed = zoom(pred_depth, (gt_depth.shape[0]/pred_depth.shape[0],gt_depth.shape[1]/pred_depth.shape[1])).clip(args.min_depth, args.max_depth) 126 | if sample['mask'] is not None: 127 | pred_depth_zoomed = pred_depth_zoomed[sample['mask']] 128 | gt_depth = gt_depth[sample['mask']] 129 | 130 | if seq_length > 0: 131 | _, poses = pose_net(tgt_img_var, ref_imgs_var) 132 | displacements = poses[0,:,:3].norm(2,1).cpu().data.numpy() # shape [1 - seq_length] 133 | 134 | scale_factors = [s1/s2 for s1, s2 in zip(sample['displacements'], displacements) if s1 > 0] 135 | scale_factor = np.mean(scale_factors) if len(scale_factors) > 0 else 0 136 | if len(scale_factors) == 0: 137 | print('not good ! 
', sample['path'], sample['displacements']) 138 | errors[0,:,j] = compute_errors(gt_depth, pred_depth_zoomed*scale_factor) 139 | 140 | scale_factor = np.median(gt_depth)/np.median(pred_depth_zoomed) 141 | errors[1,:,j] = compute_errors(gt_depth, pred_depth_zoomed*scale_factor) 142 | 143 | mean_errors = errors.mean(2) 144 | error_names = ['abs_rel','sq_rel','rms','log_rms','a1','a2','a3'] 145 | if args.pretrained_posenet: 146 | print("Results with scale factor determined by PoseNet : ") 147 | print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(*error_names)) 148 | print("{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".format(*mean_errors[0])) 149 | 150 | print("Results with scale factor determined by GT/prediction ratio (like the original paper) : ") 151 | print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(*error_names)) 152 | print("{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".format(*mean_errors[1])) 153 | 154 | if args.output_dir is not None: 155 | np.save(output_dir/'predictions.npy', predictions) 156 | 157 | def interp_gt_disp(mat, mask_val=0): 158 | mat[mat==mask_val] = np.nan 159 | x = np.arange(0, mat.shape[1]) 160 | y = np.arange(0, mat.shape[0]) 161 | mat = np.ma.masked_invalid(mat) 162 | xx, yy = np.meshgrid(x, y) 163 | #get only the valid values 164 | x1 = xx[~mat.mask] 165 | y1 = yy[~mat.mask] 166 | newarr = mat[~mat.mask] 167 | 168 | GD1 = interpolate.griddata((x1, y1), newarr.ravel(), (xx, yy), method='linear', fill_value=mask_val) 169 | return GD1 170 | 171 | def compute_errors(gt, pred): 172 | thresh = np.maximum((gt / pred), (pred / gt)) 173 | a1 = (thresh < 1.25 ).mean() 174 | a2 = (thresh < 1.25 ** 2).mean() 175 | a3 = (thresh < 1.25 ** 3).mean() 176 | 177 | rmse = (gt - pred) ** 2 178 | rmse = np.sqrt(rmse.mean()) 179 | 180 | rmse_log = (np.log(gt) - np.log(pred)) ** 2 181 | rmse_log = np.sqrt(rmse_log.mean()) 182 | 183 | abs_rel = np.mean(np.abs(gt - pred) / gt) 184 | 185 | sq_rel = np.mean(((gt - pred)**2) / gt) 186 | 187 | return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3 188 | 189 | 190 | if __name__ == '__main__': 191 | main() 192 | -------------------------------------------------------------------------------- /test_pose.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | 4 | from scipy.misc import imresize 5 | import numpy as np 6 | from path import Path 7 | import argparse 8 | from tqdm import tqdm 9 | 10 | import models 11 | from inverse_warp_summary import pose_vec2mat 12 | 13 | 14 | parser = argparse.ArgumentParser(description='Script for PoseNet testing with corresponding groundTruth from KITTI Odometry', 15 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 16 | parser.add_argument("pretrained_posenet", type=str, help="pretrained PoseNet path") 17 | parser.add_argument("--posenet", type=str, default="PoseNetB6", help="PoseNet model path") 18 | parser.add_argument("--img-height", default=256, type=int, help="Image height") 19 | parser.add_argument("--img-width", default=832, type=int, help="Image width") 20 | parser.add_argument("--no-resize", action='store_true', help="no resizing is done") 21 | parser.add_argument("--min-depth", default=1e-3) 22 | parser.add_argument("--max-depth", default=80) 23 | 24 | parser.add_argument("--dataset-dir", default='.', type=str, help="Dataset directory") 25 | parser.add_argument("--sequences", default=['09'], type=str, nargs='*', 
help="sequences to test") 26 | parser.add_argument("--output-dir", default=None, type=str, help="Output directory for saving predictions in a big 3D numpy file") 27 | parser.add_argument("--img-exts", default=['png', 'jpg', 'bmp'], nargs='*', type=str, help="images extensions to glob") 28 | parser.add_argument("--rotation-mode", default='euler', choices=['euler', 'quat'], type=str) 29 | 30 | 31 | def main(): 32 | args = parser.parse_args() 33 | from kitti_eval.pose_evaluation_utils import test_framework_KITTI as test_framework 34 | 35 | weights = torch.load(args.pretrained_posenet) 36 | seq_length = int(weights['state_dict']['conv1.0.weight'].size(1)/3) 37 | pose_net = getattr(models, args.posenet)(nb_ref_imgs=seq_length - 1).cuda() 38 | pose_net.load_state_dict(weights['state_dict'], strict=False) 39 | 40 | dataset_dir = Path(args.dataset_dir) 41 | framework = test_framework(dataset_dir, args.sequences, seq_length) 42 | 43 | print('{} snippets to test'.format(len(framework))) 44 | errors = np.zeros((len(framework), 2), np.float32) 45 | if args.output_dir is not None: 46 | output_dir = Path(args.output_dir) 47 | output_dir.makedirs_p() 48 | predictions_array = np.zeros((len(framework), seq_length, 3, 4)) 49 | 50 | for j, sample in enumerate(tqdm(framework)): 51 | imgs = sample['imgs'] 52 | 53 | h,w,_ = imgs[0].shape 54 | if (not args.no_resize) and (h != args.img_height or w != args.img_width): 55 | imgs = [imresize(img, (args.img_height, args.img_width)).astype(np.float32) for img in imgs] 56 | 57 | imgs = [np.transpose(img, (2,0,1)) for img in imgs] 58 | 59 | ref_imgs_var = [] 60 | for i, img in enumerate(imgs): 61 | img = torch.from_numpy(img).unsqueeze(0) 62 | img = ((img/255 - 0.5)/0.5).cuda() 63 | img_var = Variable(img, volatile=True) 64 | if i == len(imgs)//2: 65 | tgt_img_var = img_var 66 | else: 67 | ref_imgs_var.append(Variable(img, volatile=True)) 68 | 69 | if args.posenet in ["PoseNet6", "PoseNetB6"]: 70 | poses = pose_net(tgt_img_var, ref_imgs_var) 71 | else: 72 | _, poses = pose_net(tgt_img_var, ref_imgs_var) 73 | 74 | poses = poses.cpu().data[0] 75 | poses = torch.cat([poses[:len(imgs)//2], torch.zeros(1,6).float(), poses[len(imgs)//2:]]) 76 | 77 | inv_transform_matrices = pose_vec2mat(Variable(poses), rotation_mode=args.rotation_mode).data.numpy().astype(np.float64) 78 | 79 | rot_matrices = np.linalg.inv(inv_transform_matrices[:,:,:3]) 80 | tr_vectors = -rot_matrices @ inv_transform_matrices[:,:,-1:] 81 | 82 | transform_matrices = np.concatenate([rot_matrices, tr_vectors], axis=-1) 83 | 84 | first_inv_transform = inv_transform_matrices[0] 85 | final_poses = first_inv_transform[:,:3] @ transform_matrices 86 | final_poses[:,:,-1:] += first_inv_transform[:,-1:] 87 | 88 | if args.output_dir is not None: 89 | predictions_array[j] = final_poses 90 | 91 | ATE, RE = compute_pose_error(sample['poses'], final_poses) 92 | errors[j] = ATE, RE 93 | 94 | mean_errors = errors.mean(0) 95 | std_errors = errors.std(0) 96 | error_names = ['ATE','RE'] 97 | print('') 98 | print("Results") 99 | print("\t {:>10}, {:>10}".format(*error_names)) 100 | print("mean \t {:10.4f}, {:10.4f}".format(*mean_errors)) 101 | print("std \t {:10.4f}, {:10.4f}".format(*std_errors)) 102 | 103 | if args.output_dir is not None: 104 | np.save(output_dir/'predictions.npy', predictions_array) 105 | 106 | 107 | def compute_pose_error(gt, pred): 108 | RE = 0 109 | snippet_length = gt.shape[0] 110 | scale_factor = np.sum(gt[:,:,-1] * pred[:,:,-1])/np.sum(pred[:,:,-1] ** 2) 111 | ATE = np.linalg.norm((gt[:,:,-1] - 
scale_factor * pred[:,:,-1]).reshape(-1)) 112 | for gt_pose, pred_pose in zip(gt, pred): 113 | # Residual matrix to which we compute angle's sin and cos 114 | R = gt_pose[:,:3] @ np.linalg.inv(pred_pose[:,:3]) 115 | s = np.linalg.norm([R[0,1]-R[1,0], 116 | R[1,2]-R[2,1], 117 | R[0,2]-R[2,0]]) 118 | c = np.trace(R) - 1 119 | # Note: we actually compute double of cos and sin, but arctan2 is invariant to scale 120 | RE += np.arctan2(s,c) 121 | 122 | return ATE/snippet_length, RE/snippet_length 123 | 124 | 125 | if __name__ == '__main__': 126 | main() 127 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import shutil 3 | import numpy as np 4 | import torch 5 | from matplotlib import cm 6 | from matplotlib.colors import ListedColormap, LinearSegmentedColormap 7 | 8 | def high_res_colormap(low_res_cmap, resolution=1000, max_value=1): 9 | # Construct the list colormap, with interpolated values for higer resolution 10 | # For a linear segmented colormap, you can just specify the number of point in 11 | # cm.get_cmap(name, lutsize) with the parameter lutsize 12 | x = np.linspace(0,1,low_res_cmap.N) 13 | low_res = low_res_cmap(x) 14 | new_x = np.linspace(0,max_value,resolution) 15 | high_res = np.stack([np.interp(new_x, x, low_res[:,i]) for i in range(low_res.shape[1])], axis=1) 16 | return ListedColormap(high_res) 17 | 18 | 19 | def opencv_rainbow(resolution=1000): 20 | # Construct the opencv equivalent of Rainbow 21 | opencv_rainbow_data = ( 22 | (0.000, (1.00, 0.00, 0.00)), 23 | (0.400, (1.00, 1.00, 0.00)), 24 | (0.600, (0.00, 1.00, 0.00)), 25 | (0.800, (0.00, 0.00, 1.00)), 26 | (1.000, (0.60, 0.00, 1.00)) 27 | ) 28 | 29 | return LinearSegmentedColormap.from_list('opencv_rainbow', opencv_rainbow_data, resolution) 30 | 31 | 32 | COLORMAPS = {'rainbow': opencv_rainbow(), 33 | 'magma': high_res_colormap(cm.get_cmap('magma')), 34 | 'bone': cm.get_cmap('bone', 10000)} 35 | 36 | 37 | def tensor2array(tensor, max_value=None, colormap='rainbow'): 38 | tensor = tensor.detach().cpu() 39 | if max_value is None: 40 | max_value = tensor.max().item() 41 | if tensor.ndimension() == 2 or tensor.size(0) == 1: 42 | norm_array = tensor.squeeze().numpy()/max_value 43 | array = COLORMAPS[colormap](norm_array).astype(np.float32) 44 | array = array[:,:,:3] 45 | array = array.transpose(2, 0, 1) 46 | 47 | elif tensor.ndimension() == 3: 48 | if (tensor.size(0) == 3): 49 | array = 0.5 + tensor.numpy()*0.5 50 | elif (tensor.size(0) == 2): 51 | array = tensor.numpy() 52 | 53 | return array 54 | 55 | def save_checkpoint(save_path, dispnet_state, posenet_state, masknet_state, flownet_state, optimizer_state, is_best, filename='checkpoint.pth.tar'): 56 | file_prefixes = ['dispnet', 'posenet', 'masknet', 'flownet', 'optimizer'] 57 | states = [dispnet_state, posenet_state, masknet_state, flownet_state, optimizer_state] 58 | for (prefix, state) in zip(file_prefixes, states): 59 | torch.save(state, save_path/'{}_{}'.format(prefix,filename)) 60 | if is_best: 61 | for prefix in file_prefixes: 62 | shutil.copyfile(save_path/'{}_{}'.format(prefix,filename), save_path/'{}_model_best.pth.tar'.format(prefix)) 63 | 64 | 65 | 66 | def save_checkpoint_best(save_path, dispnet_state, posenet_state, masknet_state, flownet_state, optimizer_state, name, filename='checkpoint.pth.tar'): 67 | file_prefixes = ['dispnet', 'posenet', 'masknet', 'flownet', 'optimizer'] 68 | states = 
[dispnet_state, posenet_state, masknet_state, flownet_state, optimizer_state] 69 | path_for_save = save_path/name 70 | path_for_save.makedirs_p() 71 | 72 | for (prefix, state) in zip(file_prefixes, states): 73 | torch.save(state, path_for_save/'{}_{}'.format(prefix,filename)) 74 | 75 | 76 | 77 | 78 | --------------------------------------------------------------------------------
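As a quick-start reference (not part of the repository), the snippet below is a minimal single-image depth inference sketch assembled from the calls used in `test_disp.py` above: the network is built with `getattr(models, 'DispResNet6')()`, the checkpoint's `state_dict` is loaded, the input is normalized with `(img/255 - 0.5)/0.5`, and depth is taken as the inverse of the predicted disparity (defined only up to scale unless a PoseNet or ground-truth median is used for scaling). The file paths are placeholders, and the PIL-based loading/resizing and `torch.no_grad()` are assumptions written against a recent PyTorch; `test_disp.py` remains the authoritative evaluation entry point.

```python
# Minimal depth-inference sketch based on test_disp.py.
# Paths are placeholders; PIL loading/resizing and torch.no_grad() are assumptions
# (recent PyTorch, executed from the repository root so that `import models` resolves).
import numpy as np
import torch
from PIL import Image

import models

disp_net = getattr(models, 'DispResNet6')().cuda()
weights = torch.load('/path/to/dispnet_model_best.pth.tar')   # hypothetical checkpoint path
disp_net.load_state_dict(weights['state_dict'])
disp_net.eval()

img = Image.open('/path/to/image.png').convert('RGB').resize((832, 256))  # training width x height
img = np.asarray(img, dtype=np.float32)

tensor = torch.from_numpy(np.transpose(img, (2, 0, 1))).unsqueeze(0)
tensor = ((tensor / 255 - 0.5) / 0.5).cuda()                   # same normalization as in test_disp.py

with torch.no_grad():
    pred_disp = disp_net(tensor)                               # [1, 1, 256, 832] disparity in eval mode

depth = 1 / pred_disp.cpu().numpy()[0, 0]                      # depth = 1/disparity, up to an unknown scale
```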