├── .gitignore ├── LICENSE ├── README.md ├── common ├── constants.py └── utils.py ├── data ├── apolloscape.py ├── consistency.py ├── dataset.py ├── kitti.py ├── nuscenes.py └── utils.py ├── kitti ├── kitti_dataset.py ├── kitti_trainer.py └── train_mask.py ├── metrics └── metrics.py ├── monodepth ├── LICENSE ├── datasets │ ├── __init__.py │ ├── apollo_dataset.py │ ├── kitti_dataset.py │ └── mono_dataset.py ├── evaluate_depth.py ├── evaluate_pose.py ├── export_gt_depth.py ├── kitti_utils.py ├── layers.py ├── networks │ ├── __init__.py │ ├── depth_decoder.py │ ├── pose_cnn.py │ ├── pose_decoder.py │ └── resnet_encoder.py ├── options.py ├── test_simple.py ├── train.py ├── trainer.py └── utils.py ├── neural ├── dynamics │ └── dynamics_factory.py ├── layers.py ├── layers_3d.py ├── losses.py ├── model.py ├── resnet.py └── utils.py ├── options.py ├── run_training.py ├── trainer.py └── visualisation ├── utils.py └── visualisation.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | .idea/ 107 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 iclr-2020-embedding 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or 
substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # spatio-temporal-embedding 2 | Code accompanying the ICLR 2020 submission: Learning a Spatio-Temporal Embedding for Video Instance Segmentation. 3 | -------------------------------------------------------------------------------- /common/constants.py: -------------------------------------------------------------------------------- 1 | MEAN = [0.485, 0.456, 0.406] 2 | STD = [0.229, 0.224, 0.225] 3 | 4 | ID_FILTER = 5 5 | N_CLASSES = 1 6 | COST_THRESHOLD = 0.5 7 | NUSCENES_ROOT = '' 8 | SENSOR = 'CAM_FRONT' 9 | 10 | MAX_INSTANCES = 32 11 | MAX_INSTANCES_SCENE = 256 12 | 13 | # Meanshfit bandwidth 14 | BANDWIDTH = 1.0 15 | # Causal clustering cost threshold 16 | CLUSTERING_COST_THRESHOLD = 1.5 17 | # Filter detections containing less than MIN_PIXEL_THRESHOLD pixels 18 | MIN_PIXEL_THRESHOLD = 200 19 | # How long to keep old centers 20 | CLUSTER_MEAN_LIFE = 10 21 | -------------------------------------------------------------------------------- /common/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import collections 4 | import yaml 5 | import skvideo.io 6 | 7 | import numpy as np 8 | from yaml.constructor import ConstructorError 9 | from yaml.nodes import MappingNode 10 | from sklearn.metrics import confusion_matrix 11 | 12 | LABEL_NAMES = np.asarray([ 13 | 'road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light', 'traffic sign', 'vegetation', 14 | 'terrain', 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', 'bicycle']) 15 | 16 | 17 | def cummean(x: np.array) -> np.array: 18 | """ 19 | Computes the cumulative mean up to each position in a NaN sensitive way 20 | - If all values are NaN return an array of ones. 21 | - If some values are NaN, accumulate arrays discording those entries. 22 | """ 23 | if sum(np.isnan(x)) == len(x): 24 | # Is all numbers in array are NaN's. 25 | return np.ones(len(x)) # If all errors are NaN set to error to 1 for all operating points. 26 | else: 27 | # Accumulate in a nan-aware manner. 28 | sum_vals = np.nancumsum(x.astype(float)) # Cumulative sum ignoring nans. 29 | count_vals = np.cumsum(~np.isnan(x)) # Number of non-nans up to each position. 30 | return np.divide(sum_vals, count_vals, out=np.zeros_like(sum_vals), where=count_vals != 0) 31 | 32 | 33 | def compute_miou(gt, pred, n_classes=19): 34 | """ Calculate the mean IOU defined as TP / (TP + FN + FP). 35 | Parameters 36 | ---------- 37 | gt: np.array (batch_size, H, W) 38 | pred: np.array (batch_size, H, W) 39 | """ 40 | # Compute confusion matrix. IGNORED_ID being equal to 255, it will be ignored. 
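# Per-class IoU below is tp / (tp + fn + fp), computed from the confusion matrix:
# tp = cm[l, l], fn = cm[l, :].sum() - tp, fp = cm[:, l].sum() - tp. Classes absent
# from both gt and pred give a NaN IoU and are excluded from the mean.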
41 | cm = confusion_matrix(gt.ravel(), pred.ravel(), np.arange(n_classes)) 42 | 43 | # Calculate mean IOU 44 | miou_dict = {} 45 | miou = 0 46 | actual_n_classes = 0 47 | for l in range(n_classes): 48 | tp = cm[l, l] 49 | fn = cm[l, :].sum() - tp 50 | fp = cm[:, l].sum() - tp 51 | denom = tp + fn + fp 52 | if denom == 0: 53 | iou = float('nan') 54 | else: 55 | iou = tp / denom 56 | if not (np.isnan(iou)): 57 | miou_dict[LABEL_NAMES[l]] = iou 58 | miou += iou 59 | actual_n_classes += 1 60 | 61 | miou /= actual_n_classes 62 | 63 | miou_dict['miou'] = miou 64 | return miou_dict 65 | 66 | 67 | def get_iou(bb1, bb2): 68 | """ 69 | Calculate the Intersection over Union (IoU) of two bounding boxes. 70 | 71 | Parameters 72 | ---------- 73 | bb1 : dict 74 | Keys: {'x1', 'x2', 'y1', 'y2'} 75 | The (x1, y1) position is at the top left corner, 76 | the (x2, y2) position is at the bottom right corner 77 | bb2 : dict 78 | Keys: {'x1', 'x2', 'y1', 'y2'} 79 | The (x, y) position is at the top left corner, 80 | the (x2, y2) position is at the bottom right corner 81 | 82 | Returns 83 | ------- 84 | float 85 | in [0, 1] 86 | """ 87 | assert bb1['x1'] < bb1['x2'] 88 | assert bb1['y1'] < bb1['y2'] 89 | assert bb2['x1'] < bb2['x2'] 90 | assert bb2['y1'] < bb2['y2'] 91 | 92 | # determine the coordinates of the intersection rectangle 93 | x_left = max(bb1['x1'], bb2['x1']) 94 | y_top = max(bb1['y1'], bb2['y1']) 95 | x_right = min(bb1['x2'], bb2['x2']) 96 | y_bottom = min(bb1['y2'], bb2['y2']) 97 | 98 | if x_right < x_left or y_bottom < y_top: 99 | return 0.0 100 | 101 | # The intersection of two axis-aligned bounding boxes is always an 102 | # axis-aligned bounding box 103 | intersection_area = (x_right - x_left) * (y_bottom - y_top) 104 | 105 | # compute the area of both AABBs 106 | bb1_area = (bb1['x2'] - bb1['x1']) * (bb1['y2'] - bb1['y1']) 107 | bb2_area = (bb2['x2'] - bb2['x1']) * (bb2['y2'] - bb2['y1']) 108 | 109 | # compute the intersection over union by taking the intersection 110 | # area and dividing it by the sum of prediction + ground-truth 111 | # areas - the interesection area 112 | iou = intersection_area / float(bb1_area + bb2_area - intersection_area) 113 | assert iou >= 0.0 114 | assert iou <= 1.0 115 | return iou 116 | 117 | 118 | class Logger(): 119 | """ Writes on both terminal and output file.""" 120 | # TODO: add tensoflow logging 121 | def __init__(self, filename): 122 | self.terminal = sys.stdout 123 | self.log = open(filename, 'w', buffering=1) 124 | 125 | def write(self, message): 126 | self.terminal.write(message) 127 | self.log.write(message) 128 | 129 | def flush(self): 130 | # Needed for compatibility 131 | pass 132 | 133 | def close(self): 134 | self.log.flush() 135 | os.fsync(self.log.fileno()) 136 | self.log.close() 137 | 138 | 139 | def _update_dict_recursive(dict_1, dict_2): 140 | if isinstance(dict_1, dict) and isinstance(dict_2, collections.Mapping): 141 | for key, value in dict_2.items(): 142 | dict_1[key] = _update_dict_recursive(dict_1.get(key, None), value) 143 | return dict_1 144 | return dict_2 145 | 146 | 147 | def _ordered_dict_constructor(loader, node): 148 | pairs = loader.construct_pairs(node) 149 | res = collections.OrderedDict() 150 | for key, value in pairs: 151 | res[key] = _update_dict_recursive(res.get(key, None), value) 152 | return res 153 | 154 | 155 | def _ordered_dict_representer(dumper, data): 156 | return yaml.nodes.MappingNode(yaml.SafeDumper.DEFAULT_MAPPING_TAG, 157 | [(dumper.represent_data(k), dumper.represent_data(v)) 158 | for k, v in 
data.items()]) 159 | 160 | 161 | class _Loader(yaml.SafeLoader): 162 | def __init__(self, config_path): 163 | self.config_path = config_path 164 | super().__init__(open(config_path, 'r')) 165 | self.yaml_constructors[self.DEFAULT_MAPPING_TAG] = _ordered_dict_constructor 166 | yaml.SafeDumper.yaml_representers[collections.OrderedDict] = _ordered_dict_representer 167 | 168 | def construct_pairs(self, node, deep=False): 169 | if not isinstance(node, MappingNode): 170 | raise ConstructorError(None, None, 171 | "expected a mapping node, but found %s" % node.id, 172 | node.start_mark) 173 | pairs = [] 174 | for key_node, value_node in node.value: 175 | key = self.construct_object(key_node, deep=deep) 176 | value = self.construct_object(value_node, deep=deep) 177 | if key == "import": 178 | import_path = os.path.join(os.path.dirname(self.config_path), value) 179 | if not os.path.isfile(import_path): 180 | raise FileNotFoundError('cannot load referenced yml at: {}'.format(import_path)) 181 | imported_config = load_config(import_path) 182 | for imported_key, imported_value in imported_config.items(): 183 | pairs.append((imported_key, imported_value)) 184 | else: 185 | pairs.append((key, value)) 186 | return pairs 187 | 188 | def dispose(self): 189 | self.stream.close() 190 | super().dispose() 191 | 192 | 193 | def _load_config(config_path): 194 | loader = _Loader(config_path) 195 | try: 196 | return loader.get_single_data() 197 | finally: 198 | loader.dispose() 199 | 200 | 201 | def load_config(config_path): 202 | config = _load_config(config_path) 203 | config['config_path'] = config_path 204 | return config 205 | 206 | 207 | def write_mp4_file(video, output_filename, fps='5'): 208 | """ Lossless mp4 video creation. 209 | Parameters 210 | ---------- 211 | video: list 212 | each array must be (h, w, 3) RGB 213 | output_filename: str 214 | fps: str 215 | """ 216 | video_writer = skvideo.io.FFmpegWriter(output_filename, inputdict={'-r': fps}, 217 | outputdict={ 218 | '-vcodec': 'libx264', #use the h.264 codec 219 | '-crf': '0', #set the constant rate factor to 0, which is lossless 220 | '-preset':'veryslow', #the slower the better compression, in princple, try 221 | #other options see https://trac.ffmpeg.org/wiki/Encode/H.264 222 | '-r': fps, 223 | }) 224 | for j in range(len(video)): 225 | video_writer.writeFrame(video[j]) 226 | 227 | video_writer.close() 228 | 229 | 230 | def normalise_numpy_image(x): 231 | x_min = x.min() 232 | x_max = x.max() 233 | 234 | d = (x_max - x_min) if x_max != x_min else 1e-5 235 | return (x - x_min) / d 236 | -------------------------------------------------------------------------------- /data/apolloscape.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import numpy as np 5 | import pandas as pd 6 | 7 | from glob import glob 8 | from tqdm import tqdm 9 | from PIL import Image 10 | from functools import partial 11 | from multiprocessing import Pool, cpu_count 12 | 13 | from common.constants import MAX_INSTANCES, MAX_INSTANCES_SCENE 14 | from data.consistency import enforce_consistency, make_ids_consecutives 15 | from data.utils import crop_and_resize 16 | 17 | ######################### 18 | ### Dataset creation 19 | # 1. Create apollo.csv with `create_apollo_df` 20 | # 2. Call `split_dataset` to split into train/val, containing folders {road_nb}_{record}_{camera_nb} 21 | # with img (.jpg), instance_seg (.png), semantic_seg (.png) 22 | # 3. 
Call `preprocess_dataset` to make instance ids consistent and resize to 128x256. 23 | ######################### 24 | 25 | ROOT = '' 26 | CSV_FILENAME = 'apollo.csv' 27 | CAMERA = 5 28 | IMG_SIZE = (128, 256) 29 | LABEL_ID = 33 # ids of car 30 | MIN_PIXEL_INSTANCE = 100 31 | 32 | 33 | def split_dataset(root_save=''): 34 | pool = Pool(cpu_count() - 1) 35 | df_apollo = pd.read_csv(os.path.join(ROOT, CSV_FILENAME)) 36 | 37 | for mode in ['train', 'val']: 38 | unique_road_nb = np.unique(df_apollo[df_apollo['split'] == mode]['road_nb']) 39 | 40 | for road_nb in unique_road_nb: # TODO remove 41 | df_apollo_road_nb = df_apollo[(df_apollo['road_nb'] == road_nb) & (df_apollo['split'] == mode)] 42 | unique_records = np.unique(df_apollo_road_nb['record']) 43 | 44 | for _ in tqdm(pool.imap_unordered( 45 | partial(split_dataset_iter, df_apollo_road_nb=df_apollo_road_nb, root_save=root_save, mode=mode, 46 | road_nb=road_nb), unique_records), total=len(unique_records)): 47 | pass 48 | 49 | 50 | def split_dataset_iter(record, df_apollo_road_nb, root_save, mode, road_nb): 51 | path = os.path.join(root_save, mode, road_nb + '_' + record + '_' + str(CAMERA)) 52 | os.makedirs(path, exist_ok=True) 53 | filter_mask = (df_apollo_road_nb['record'] == record) & (df_apollo_road_nb['camera'] == CAMERA) 54 | tuple_filenames = zip(df_apollo_road_nb[filter_mask]['img_path'], 55 | df_apollo_road_nb[filter_mask]['sem_path'], 56 | df_apollo_road_nb[filter_mask]['inst_path'] 57 | ) 58 | print('Road: {}, {}'.format(road_nb, record)) 59 | count = 0 60 | for f_img, f_semantic, f_instance_seg in tuple_filenames: 61 | shutil.copy(src=os.path.join(ROOT, f_img), 62 | dst=os.path.join(path, '{:04d}_image_tmp.jpg'.format(count))) # TODO: remove these tmp 63 | 64 | shutil.copy(src=os.path.join(ROOT, f_semantic), 65 | dst=os.path.join(path, '{:04d}_semantic_tmp.png'.format(count))) 66 | 67 | shutil.copy(src=os.path.join(ROOT, f_instance_seg), 68 | dst=os.path.join(path, '{:04d}_instance_seg_tmp.png'.format(count))) 69 | 70 | count += 1 71 | 72 | 73 | def preprocess_dataset(root='', 74 | root_save=''): 75 | pool = Pool(cpu_count() - 1) 76 | for mode in ['train', 'val']: 77 | all_scene_dir = sorted(glob(os.path.join(root, mode, '*'))) 78 | for _ in tqdm(pool.imap_unordered( 79 | partial(preprocess_dataset_iter, root=root, mode=mode, root_save=root_save), all_scene_dir), 80 | total=len(all_scene_dir)): 81 | pass 82 | 83 | 84 | def preprocess_dataset_iter(scene_dir, root, mode, root_save): 85 | print('Scene: {}'.format(scene_dir)) 86 | img_filenames = sorted(glob(os.path.join(root, mode, os.path.basename(scene_dir), '*_image_tmp.jpg'))) 87 | 88 | available_keys = set(range(1, MAX_INSTANCES_SCENE)) # max of 256 instances 89 | dict_ids = None 90 | prev_instance_seg = None 91 | for img_fname in img_filenames: 92 | basename = img_fname[:-len('image_tmp.jpg')] 93 | img = Image.open(img_fname) 94 | #  Much more compact to load and resize images compared to numpy arrays 95 | semantic = Image.open(basename + 'semantic_tmp.png') 96 | instance_seg = Image.open(basename + 'instance_seg_tmp.png') 97 | 98 | ##### 99 | #  Resize 100 | img = crop_and_resize(img, IMG_SIZE) 101 | instance_seg = crop_and_resize(instance_seg, IMG_SIZE, order=0) 102 | semantic = crop_and_resize(semantic, IMG_SIZE, order=0) 103 | # depth = crop_and_resize(depth, IMG_SIZE, order=1) 104 | 105 | ##### 106 | # Filter instance ids 107 | instance_seg = np.array(instance_seg) 108 | unique_ids = np.unique(instance_seg) 109 | # The relevant ids start with '33' 110 | relevant_ids = [x 
for x in unique_ids if str(x).startswith(str(LABEL_ID))][:MAX_INSTANCES] 111 | # Remove too small instances 112 | relevant_ids = [x for x in relevant_ids if np.sum(instance_seg == x) > MIN_PIXEL_INSTANCE] 113 | mask = np.isin(instance_seg, relevant_ids) 114 | instance_seg[~mask] = 0 115 | 116 | instance_seg = make_ids_consecutives(instance_seg) 117 | 118 | #  N_CLASSES = 1 119 | instance_seg = instance_seg[None, :, :] 120 | 121 | ###### 122 | # Enforce consistency 123 | consistent_instance_seg, available_keys, dict_ids = enforce_consistency(instance_seg, prev_instance_seg, 124 | available_keys, dict_ids) 125 | 126 | #  The algorithm only works when for each frame, instance ids are in [0, max_n_instance[ 127 | prev_instance_seg = instance_seg 128 | 129 | ###### 130 | #  Save 131 | save_path = os.path.join(root_save, mode, os.path.basename(scene_dir)) 132 | os.makedirs(save_path, exist_ok=True) 133 | img.save(os.path.join(save_path, os.path.basename(basename) + 'image.jpg')) 134 | np.save(os.path.join(save_path, os.path.basename(basename) + 'instance_seg.npy'), consistent_instance_seg) 135 | np.save(os.path.join(save_path, os.path.basename(basename) + 'semantic.npy'), np.array(semantic)) 136 | 137 | 138 | def create_apollo_df(): 139 | """ Create initial Apollo dataframe containing the train/val split as defined by ApolloScape. (only do it once) 140 | 141 | Parameters 142 | ---------- 143 | ROOT: directory containing the ApolloScape dataset 144 | CSV_FILENAME: filename of the saved dataframe as a csv file 145 | """ 146 | apollo_df = pd.DataFrame(columns=['split', 'road_nb', 'record', 'camera', 147 | 'img_path', 'sem_path', 'inst_path']) 148 | 149 | for split in ['train', 'val']: 150 | for road_nb in ['road01_ins', 'road02_ins', 'road03_ins']: 151 | print(road_nb) 152 | split_df = pd.read_csv(os.path.join(ROOT, 'dataset_splits', road_nb + '_' + split + '.lst'), sep='\t', 153 | header=None, names=['image_path', 'semantic_path']) 154 | print(split_df.shape) 155 | 156 | # All the image and label filenames (relative path) 157 | im_filenames = glob(os.path.join(ROOT, road_nb, 'ColorImage/*/*/*')) 158 | im_filenames = set([os.path.relpath(x, ROOT) for x in im_filenames]) 159 | label_filenames = glob(os.path.join(ROOT, road_nb, 'Label/*/*/*')) 160 | label_filenames = set([os.path.relpath(x, ROOT) for x in label_filenames]) 161 | 162 | inst_not_found = 0 163 | for im, sem in split_df[['image_path', 'semantic_path']].values: 164 | # Check that the files exist 165 | assert im in im_filenames, 'Image not found' 166 | assert sem in label_filenames, 'Semantic seg. not found' 167 | 168 | # Extract record, camera, image id 169 | _, _, record, camera, im_id = os.path.normpath(im).split(os.sep) 170 | camera_nb = camera[-1] 171 | im_id = im_id.split('.jpg')[0] # gives something like '170908_085500604_Camera_6' 172 | 173 | # Deduce instance ids filename 174 | inst = os.path.join(os.path.dirname(sem), im_id + '_instanceIds.png') 175 | if inst not in label_filenames: 176 | inst_not_found += 1 177 | 178 | apollo_df = apollo_df.append( 179 | pd.DataFrame({'split': split, 'road_nb': road_nb, 'record': record, 'camera': camera_nb, 180 | 'img_path': im, 'sem_path': sem, 'inst_path': inst}, index=[0]), 181 | sort=False) 182 | 183 | print('Instance not found: ', inst_not_found) 184 | 185 | apollo_df = apollo_df.reset_index(drop=True) 186 | apollo_df.to_csv(os.path.join(ROOT, CSV_FILENAME), index=False) 187 | print('Finished! 
Saved as csv file') 188 | -------------------------------------------------------------------------------- /data/consistency.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import confusion_matrix 3 | from sklearn.metrics.pairwise import euclidean_distances 4 | from scipy.optimize import linear_sum_assignment 5 | 6 | from common.constants import MAX_INSTANCES_SCENE 7 | 8 | 9 | def temporally_align_sequence(instance_seg, iou_threshold=0.1): 10 | """ 11 | Parameters 12 | ---------- 13 | instance_seg: np.ndarray (batch_size, seq_len, N_CLASSES, height, width) 14 | """ 15 | aligned_instance_seg = np.zeros_like(instance_seg) 16 | batch_size, seq_len = instance_seg.shape[:2] 17 | 18 | for i in range(batch_size): 19 | available_keys = set(range(1, MAX_INSTANCES_SCENE)) # max of 256 instances 20 | dict_ids = None 21 | prev_instance_seg_t = None 22 | for t in range(seq_len): 23 | instance_seg_t = instance_seg[i, t] 24 | 25 | # The algorithm only works when for each frame, instance ids are in [0, max_n_instance[ 26 | instance_seg_t = make_ids_consecutives(instance_seg_t) 27 | # Enforce consistency 28 | consistent_instance_seg_t, available_keys, dict_ids = enforce_consistency(instance_seg_t, 29 | prev_instance_seg_t, 30 | available_keys, dict_ids, 31 | cost_threshold=(1 - iou_threshold)) 32 | aligned_instance_seg[i, t] = consistent_instance_seg_t 33 | prev_instance_seg_t = instance_seg_t 34 | 35 | return aligned_instance_seg 36 | 37 | 38 | def enforce_consistency(inst_seg, prev_inst_seg, available_keys, dict_ids, cost_threshold=0.99, use_barycenter=False, 39 | centers=None, prev_centers=None): 40 | """ 41 | TODO: remove center barycenter parameters 42 | Make the instance ids consistent the following way: 43 | 44 | Step 1: For each instance in the current frame, try to assign it to an instance in the previous frame. 45 | If one instance is assigned, but with cost 1.0, ignore it (as there is no overlap at all) 46 | Step 2: REMOVED FOR NOW. For the remaining instances, compare to all the instances that previously existed using 47 | barycenter distance. 48 | Match them if the distance is below some threshold. The barycenter can only be used to bridge 3 frames. 
49 | ie usually, one frame in the middle where the instance disappears 50 | Step 3: If there still is some unassigned instances, assign to a new unique id 51 | 52 | Parameters 53 | ---------- 54 | inst_seg: np.ndarray (h, w) 55 | prev_inst_seg: np.ndarray (h, w) 56 | centers: np.ndarray (n_instances, emb_dim) 57 | centers of the current frame 58 | prev_centers: np.ndarray (n_instances_prev, emb_dim) 59 | centers of the previous frame 60 | """ 61 | # Skip if first element 62 | if prev_inst_seg is None: 63 | return inst_seg, available_keys, dict_ids 64 | 65 | # Update available keys 66 | available_keys = available_keys.difference(np.unique(prev_inst_seg)) 67 | 68 | # Only background 69 | if len(np.unique(inst_seg)) == 1: 70 | return inst_seg, available_keys, dict_ids 71 | 72 | # Compute cost matrix: 1 - IoU for each instance in frame 1 and frame2 73 | if not use_barycenter: 74 | cost_matrix = compute_IoU_instances(prev_inst_seg, inst_seg) 75 | else: 76 | cost_matrix = euclidean_distances(prev_centers, centers) 77 | # Apply step 1 and step 3 78 | inst_seg, dict_ids = sync_ids(inst_seg, cost_matrix, dict_ids, available_keys, cost_threshold) 79 | 80 | if len(available_keys) == 0: 81 | # Reset keys, since enough timeframes separate the ids, we can reuse without any chance of overlapping 82 | print('Reset instance id keys.') 83 | available_keys = set(range(1, MAX_INSTANCES_SCENE)) 84 | 85 | return inst_seg, available_keys, dict_ids 86 | 87 | 88 | def make_ids_consecutives(x): 89 | unique_ids = np.unique(x) 90 | if unique_ids[0] == 0: 91 | dict_ids = dict(zip(unique_ids, np.arange(len(unique_ids)))) 92 | else: # no background 93 | dict_ids = dict(zip(unique_ids, np.arange(1, len(unique_ids) + 1))) 94 | 95 | return np.vectorize(dict_ids.__getitem__)(x).astype(np.uint8) 96 | 97 | 98 | def sync_ids(current_ids, cost_matrix, old_dict_ids=None, available_keys=None, cost_threshold=0.99): 99 | """ Synchronise ids with the previous ids. 100 | 101 | Parameters 102 | ---------- 103 | current_ids: np.ndarray (N_CLASSES, height, width) 104 | cost_matrix: np.ndarray (n_instances_prev, n_instances_current) 105 | old_dict_ids: dict 106 | keys mapping previous frame instance ids to their original id 107 | available_keys: set 108 | available keys (256 max) 109 | """ 110 | assert cost_matrix.ndim == 2, 'Cost matrix is not two dimensional: {}'.format(cost_matrix) 111 | new_ids_to_old_ids = {} 112 | 113 | # Step 1: for each instance, try to map to another in the previous frame 114 | dict_existing_to_old, assigned_cost = hungarian_algo(cost_matrix, old_dict_ids) 115 | for key, value in dict_existing_to_old.items(): 116 | # If there is atleast some overlap 117 | if assigned_cost[key] < cost_threshold: 118 | new_ids_to_old_ids[key] = value 119 | 120 | # Step 3: assign remaining instances to a new unique id 121 | for j in range(cost_matrix.shape[1]): 122 | new_id = j + 1 123 | if new_id not in new_ids_to_old_ids: 124 | new_ids_to_old_ids[new_id] = available_keys.pop() 125 | 126 | new_frame = np.vectorize(new_ids_to_old_ids.__getitem__, otypes=[np.uint8])(current_ids) 127 | return new_frame, new_ids_to_old_ids 128 | 129 | 130 | def hungarian_algo(cost_matrix, old_dict_ids=None): 131 | """ Compute the optimal assignment given a cost matrix. 132 | 133 | Parameters 134 | ---------- 135 | cost_matrix: np.ndarray (n_inst, new_n_inst) 136 | cost matrix for the hungarian algorithm. n_inst and new_n_inst need not be equal. 137 | 138 | Returns 139 | ------- 140 | dict_new_to_old: mapping from new ids to old ids. 
141 | assigned_cost: mapping from new ids to cost 142 | """ 143 | ids, new_ids = linear_sum_assignment(cost_matrix) 144 | assigned_cost = dict(zip(new_ids + 1, cost_matrix[ids, new_ids])) 145 | assigned_cost[0] = 0 # background 146 | # need to synchronise with the ids with the original ids (the first that appeared) 147 | if old_dict_ids is not None and len(ids) > 0: 148 | ids = np.vectorize(old_dict_ids.__getitem__)(ids + 1) - 1 149 | 150 | # add one to indices, to account for background (index 0) 151 | dict_new_to_old = dict(zip(new_ids + 1, ids + 1)) 152 | # Background id does not change 153 | dict_new_to_old[0] = 0 154 | return dict_new_to_old, assigned_cost 155 | 156 | 157 | def compute_IoU_instances(frame1, frame2): 158 | """ 159 | Parameters 160 | ---------- 161 | frame1: np.ndarray (N_CLASSES, height, width) 162 | instance ids taking values from 0 (background) to n_instances1 included 163 | frame2: np.ndarray (N_CLASSES, height, width) 164 | instance ids taking values from 0 (background) to n_instances2 included 165 | 166 | Returns 167 | ------- 168 | cost_matrix: np.ndarray (n_instances1, n_instances2) 169 | dissimilarity matrix between each instance in frame1 and frame2, based on IoU (background removed) 170 | """ 171 | # Compute IoU matrix 172 | unique_id_frame1 = np.unique(frame1) 173 | unique_id_frame2 = np.unique(frame2) 174 | assert np.all(unique_id_frame1== np.arange(len(unique_id_frame1))) and \ 175 | np.all(unique_id_frame2 == np.arange(len(unique_id_frame2))) 176 | 177 | cm = confusion_matrix(frame1.ravel(), frame2.ravel()) 178 | normalising_array = np.ones_like(cm) 179 | # row 180 | normalising_array += cm.sum(axis=0).reshape((1, -1)) 181 | # column 182 | normalising_array += cm.sum(axis=1).reshape((-1, 1)) 183 | # substract array to remove values appearing twice 184 | normalising_array -= cm 185 | 186 | row_indices = np.unique(frame1)[1:] # remove background 187 | col_indices = np.unique(frame2)[1:] 188 | # Compute IOU, amd remove row + colum related to background 189 | cm = (cm / normalising_array)[row_indices[:, None], col_indices].reshape((len(row_indices), len(col_indices))) 190 | 191 | cost_matrix = 1 - cm 192 | return cost_matrix 193 | 194 | 195 | def increment_life_clusters(dict_centers, cluster_mean_life): 196 | for existing_id, (life, existing_center) in dict_centers.items(): 197 | #  If too old, delete 198 | if life + 1 == cluster_mean_life: 199 | print('Delete id {}'.format(existing_id)) 200 | dict_centers.pop(existing_id) 201 | dict_centers[existing_id] = (life + 1, existing_center) 202 | 203 | 204 | def enforce_consistency_centers(inst_seg, centers, dict_centers, available_keys, cost_threshold=1.5, 205 | cluster_mean_life=10, verbose=False): 206 | """ 207 | Parameters 208 | ---------- 209 | inst_seg: np.ndarray (N_CLASSES, h, w) 210 | centers: np.ndarray (n_instances, emb_dim) 211 | centers of the current frame 212 | dict_centers : dict (id) -> (life, mean) 213 | existing_id: int 214 | unique id of the instance 215 | life: int 216 | from 1 to cluster_mean_life 217 | existing_center: np.ndarray (emb_dim,) 218 | """ 219 | # mapping from current_id to first_appeared_id 220 | id_mapping = {} 221 | #  Instance ids start at 1 222 | unique_ids = np.unique(inst_seg) 223 | assert 0 not in unique_ids, 'instance ids must start at 1' 224 | assert len(unique_ids) == len(centers), '{} unique ids for {} centers'.format(len(unique_ids), len(centers)) 225 | dict_id_to_center = dict(zip(unique_ids, centers)) 226 | remaining_ids = set(unique_ids) 227 | 228 | #  Initialise 
dict_centers 229 | if len(dict_centers) == 0: 230 | for id in remaining_ids: 231 | id_mapping[id] = available_keys.pop() 232 | # Life starts at 1 233 | dict_centers[id_mapping[id]] = (1, dict_id_to_center[id]) 234 | 235 | inst_seg = np.vectorize(id_mapping.__getitem__, otypes=[np.uint8])(inst_seg) 236 | return inst_seg, available_keys, dict_centers 237 | 238 | to_remove = [] 239 | prev_centers = [] 240 | map_prev_hunga_id_to_existing_id = {} 241 | j = 0 242 | for existing_id, (life, existing_center) in dict_centers.items(): 243 | if life == 1: 244 | prev_centers.append(existing_center) 245 | map_prev_hunga_id_to_existing_id[j] = existing_id 246 | 247 | j += 1 248 | 249 | if len(prev_centers) > 0: 250 | prev_centers = np.stack(prev_centers, axis=0) 251 | cost_matrix = euclidean_distances(prev_centers, centers) 252 | prev_hunga_ids, hunga_ids = linear_sum_assignment(cost_matrix) 253 | for i in range(len(prev_hunga_ids)): 254 | if cost_matrix[prev_hunga_ids[i], hunga_ids[i]] < cost_threshold: 255 | existing_id = map_prev_hunga_id_to_existing_id[prev_hunga_ids[i]] 256 | if verbose: 257 | print('Considered id: {}'.format(hunga_ids[i] + 1)) 258 | print('Correspondence found with existing id {} and cost: {:.2f}'.format( 259 | existing_id, cost_matrix[prev_hunga_ids[i], hunga_ids[i]])) 260 | # Update mapping 261 | id_mapping[hunga_ids[i] + 1] = existing_id 262 | # Update mean 263 | dict_centers[existing_id] = (0, dict_id_to_center[hunga_ids[i] + 1]) 264 | to_remove.append(hunga_ids[i] + 1) 265 | else: 266 | if verbose: 267 | print('Cost too high: {:.2f}'.format(cost_matrix[prev_hunga_ids[i], hunga_ids[i]])) 268 | 269 | remaining_ids = remaining_ids.difference(to_remove) 270 | 271 | to_remove = [] 272 | for id in remaining_ids: 273 | if verbose: 274 | print('----------') 275 | print('Remaining considered id {}'.format(id)) 276 | for existing_id, (life, existing_center) in dict_centers.items(): 277 | best_cost = float('inf') 278 | best_existing_id = None 279 | if life > 1: 280 | cost = np.linalg.norm(dict_id_to_center[id] - existing_center, ord=2) 281 | if verbose: 282 | print('Existing id {}, distance: {:.3f}'.format(existing_id, cost)) 283 | if cost < best_cost: 284 | best_cost = cost 285 | best_existing_id = existing_id 286 | if best_cost < cost_threshold: 287 | # Assign to existing id 288 | id_mapping[id] = best_existing_id 289 | # Update mean 290 | dict_centers[best_existing_id] = (0, dict_id_to_center[id]) 291 | # Update remaining ids 292 | to_remove.append(id) 293 | 294 | remaining_ids = remaining_ids.difference(to_remove) 295 | 296 | #  Remaining ids get assigned a new unique id 297 | for id in remaining_ids: 298 | id_mapping[id] = available_keys.pop() 299 | dict_centers[id_mapping[id]] = (0, dict_id_to_center[id]) 300 | 301 | #  Increase life by one 302 | increment_life_clusters(dict_centers, cluster_mean_life) 303 | 304 | # Map to tracked id 305 | inst_seg = np.vectorize(id_mapping.__getitem__, otypes=[np.uint8])(inst_seg) 306 | 307 | if len(available_keys) < 50: 308 | #  Reset keys, since enough timeframes separate the ids, we can reuse without any chance of overlapping 309 | print('Reset instance id keys.') 310 | available_keys = set(range(1, MAX_INSTANCES_SCENE)) 311 | 312 | return inst_seg, available_keys, dict_centers 313 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | import numpy as np 5 | import 
torchvision.transforms as transforms 6 | 7 | from skimage.transform import resize 8 | from PIL import Image 9 | from glob import glob 10 | from multiprocessing import Pool, cpu_count 11 | from functools import partial 12 | from tqdm import tqdm 13 | 14 | from torch.utils.data import Dataset 15 | 16 | from common.constants import MAX_INSTANCES, N_CLASSES, MEAN, STD 17 | from data.utils import pil_loader 18 | 19 | 20 | class MotionDataset(Dataset): 21 | def __init__(self, root, dataset='nuscenes', mode='', seq_len=5, h=128, w=256, load_depth_inputs=False, 22 | num_scales=3, saved_numpy=False): 23 | assert dataset in ['nuscenes', 'apollo', 'kitti', 'kitti_ped', 'davis'], 'Not recognised dataset.' 24 | assert seq_len >= 3, 'Sequence length={} but must be greater of equal than 3.'.format(seq_len) 25 | 26 | self.dataset = dataset 27 | self.seq_len = seq_len 28 | # Original image in nuscenes is 900 x 1600. Divide by factor of 8 -> 112 x 200 29 | self.h = h 30 | self.w = w 31 | self.load_depth_inputs = load_depth_inputs 32 | self.saved_numpy = saved_numpy 33 | 34 | self.data_transforms = transforms.Compose([transforms.ToTensor(), 35 | transforms.Normalize(mean=MEAN, 36 | std=STD)]) 37 | 38 | self.total_n_sequences = 0 39 | self.scene_paths = sorted(glob(os.path.join(root, mode, '*'))) 40 | self.n_scenes = len(self.scene_paths) 41 | 42 | self.dict_scene_to_n_sequences = {} 43 | self.dict_scene_to_filenames = {} 44 | 45 | for i, path in enumerate(self.scene_paths): 46 | filenames = sorted(glob(os.path.join(path, '*_image.jpg'))) 47 | self.dict_scene_to_n_sequences[i] = len(filenames) - seq_len + 1 48 | self.dict_scene_to_filenames[i] = filenames 49 | 50 | self.total_n_sequences += len(filenames) - seq_len + 1 51 | 52 | # Depth 53 | if self.load_depth_inputs: 54 | self.frame_ids = [0, -1, 1] 55 | self.num_scales = num_scales 56 | self.interp = Image.ANTIALIAS 57 | self.img_ext = '.jpg' 58 | self.loader = pil_loader 59 | 60 | self.resize = {} 61 | for i in range(self.num_scales): 62 | s = 2 ** i 63 | self.resize[i] = transforms.Resize((self.h // s, self.w // s), 64 | interpolation=self.interp) 65 | 66 | self.load_depth = False 67 | self.is_train = mode == 'train' 68 | 69 | if self.dataset == 'apollo': 70 | # Defined as in apollo/utilities/intrinsics.txt 71 | # fx, 0, Cx / 3384 and fy, 0, Cy / 2710 72 | self.K = np.array([[0.68101, 0, 0.49830, 0], 73 | [0, 0.85088, 0.49999, 0], 74 | [0, 0, 1, 0], 75 | [0, 0, 0, 1]], dtype=np.float32) 76 | elif self.dataset in ['kitti', 'kitti_ped', 'davis']: 77 | self.K = np.array([[0.58, 0, 0.5, 0], 78 | [0, 1.92, 0.5, 0], 79 | [0, 0, 1, 0], 80 | [0, 0, 0, 1]], dtype=np.float32) 81 | 82 | def __len__(self): 83 | return self.total_n_sequences 84 | 85 | def __getitem__(self, idx): 86 | """ 87 | Returns 88 | ------- 89 | data: dict with keys 90 | img: torch.tensor (T, 3, H, W) 91 | instance_seg: torch.tensor (T, H, W) 92 | input_depth: dict with keys: 93 | ("color", , ): torch.tensor (T, 3, H, W) 94 | raw colour images 95 | ("color_aug", , ): torch.tensor (T, 3, H, W) 96 | augmented colour images 97 | ("K", scale) or ("inv_K", scale): torch.tensor (T, 4, 4) 98 | camera intrinsics 99 | 100 | depth: torch.tensor (T, H, W) # Only in nuscenes for now 101 | position: torch.tensor (T, MAX_INSTANCES, 3) # Only in nuscenes for now 102 | velocity: torch.tensor (T, MAX_INSTANCES, 3) # Only in nuscenes for now 103 | intrinsics: torch.tensor (T, 3, 3) # Only in nuscenes for now 104 | 105 | """ 106 | scene_number, position_in_seq = self.get_scene_number(idx) 107 | 108 | img_filenames 
= self.dict_scene_to_filenames[scene_number][position_in_seq:(position_in_seq + self.seq_len)] 109 | 110 | data = {} 111 | if self.dataset == 'nuscenes': 112 | keys = ['img', 'instance_seg', 'depth', 'position', 'velocity', 'intrinsics'] 113 | elif self.dataset in ['kitti', 'kitti_ped', 'davis', 'apollo']: 114 | keys = ['img', 'instance_seg', 'input_depth'] 115 | 116 | for key in keys: 117 | data[key] = [] 118 | 119 | for t in range(self.seq_len): 120 | data_one_frame = self.get_single_data(img_filenames[t]) 121 | for key in keys: 122 | if key != 'input_depth': 123 | data[key].append(data_one_frame[key]) 124 | 125 | if self.load_depth_inputs: 126 | #  Depth input data 127 | # Exclude first and last frame, as one past frame and one future frame is needed 128 | for t in range(1, self.seq_len - 1): 129 | triplet_img_filename = {-1: img_filenames[t-1], 130 | 0: img_filenames[t], 131 | 1: img_filenames[t+1] 132 | } 133 | data['input_depth'].append(self.get_depth_input(triplet_img_filename)) 134 | 135 | # Add dummy values for first and last time index 136 | dummy_input = {key: torch.zeros_like(value) for key, value in data['input_depth'][0].items()} 137 | data['input_depth'].insert(0, dummy_input) 138 | data['input_depth'].append(dummy_input) 139 | 140 | # Stack tensor in time dimension 141 | for key in keys: 142 | if key != 'input_depth': 143 | data[key] = torch.stack(data[key], dim=0) 144 | else: 145 | if self.load_depth_inputs: 146 | input_depth_dict = {} 147 | for depth_key in data[key][0].keys(): 148 | input_depth_dict[depth_key] = torch.stack([data[key][t][depth_key] for t in range(self.seq_len)], 149 | dim=0) 150 | data[key] = input_depth_dict 151 | return data 152 | 153 | def get_single_data(self, img_filename): 154 | base_filename = img_filename[:-len('image.jpg')] 155 | img = pil_loader(img_filename) 156 | instance_seg = np.load(base_filename + 'instance_seg.npy') 157 | 158 | if self.dataset == 'nuscenes': 159 | depth = np.load(base_filename + 'disp.npy') 160 | position = np.load(base_filename + 'position.npy') 161 | velocity = np.load(base_filename + 'velocity.npy') 162 | intrinsics = np.load(base_filename + 'intrinsics.npy') 163 | 164 | # TODO: Remove this check in future 165 | instance_seg[instance_seg >= MAX_INSTANCES] = 0 166 | 167 | if not self.saved_numpy: 168 | img, instance_seg, depth, intrinsics = resize_nuscenes_data(img, instance_seg, depth, intrinsics, 169 | h_target=self.h, w_target=self.w) 170 | 171 | # Convert to pytorch 172 | img = self.data_transforms(img) 173 | instance_seg = torch.from_numpy(instance_seg).to(torch.uint8) 174 | 175 | if self.dataset == 'nuscenes': 176 | depth = torch.from_numpy(depth).float() 177 | position = torch.from_numpy(position).float() 178 | velocity = torch.from_numpy(velocity).float() 179 | intrinsics = torch.from_numpy(intrinsics).float() 180 | 181 | data_one_frame = {'img': img, 182 | 'instance_seg': instance_seg, 183 | } 184 | 185 | if self.dataset == 'nuscenes': 186 | data_one_frame['depth'] = depth 187 | data_one_frame['position'] = position 188 | data_one_frame['velocity'] = velocity 189 | data_one_frame['intrinsics'] = intrinsics 190 | 191 | return data_one_frame 192 | 193 | def get_depth_input(self, triplet_img_filename): 194 | """ 195 | Parameters 196 | ---------- 197 | triplet_img_filenames: dict 198 | contains past frame (key -1), current frame (key 0), future frame (key 1) filenames 199 | 200 | Returns a single training item from the dataset as a dictionary. 201 | 202 | Values correspond to torch tensors. 
203 | Keys in the dictionary are either strings or tuples: 204 | 205 | ("color", , ) for raw colour images, 206 | ("color_aug", , ) for augmented colour images, 207 | ("K", scale) or ("inv_K", scale) for camera intrinsics, 208 | "stereo_T" for camera extrinsics, and 209 | "depth_gt" for ground truth depth maps. 210 | 211 | is either: 212 | an integer (e.g. 0, -1, or 1) representing the temporal step relative to 'index', 213 | or 214 | "s" for the opposite image in the stereo pair. 215 | 216 | is an integer representing the scale of the image relative to the fullsize image: 217 | -1 images at native resolution as loaded from disk 218 | 0 images resized to (self.width, self.height ) 219 | 1 images resized to (self.width // 2, self.height // 2) 220 | 2 images resized to (self.width // 4, self.height // 4) 221 | 3 images resized to (self.width // 8, self.height // 8) 222 | """ 223 | inputs = {} 224 | 225 | do_color_aug = False 226 | do_flip = False 227 | 228 | for i in self.frame_ids: 229 | inputs[("color", i, -1)] = self.loader(triplet_img_filename[i]) 230 | 231 | # adjusting intrinsics to match each scale in the pyramid 232 | for scale in range(self.num_scales): 233 | K = self.K.copy() 234 | 235 | K[0, :] *= self.w // (2 ** scale) 236 | K[1, :] *= self.h // (2 ** scale) 237 | 238 | inv_K = np.linalg.pinv(K) 239 | 240 | inputs[("K", scale)] = torch.from_numpy(K) 241 | inputs[("inv_K", scale)] = torch.from_numpy(inv_K) 242 | 243 | if do_color_aug: 244 | color_aug = transforms.ColorJitter.get_params( 245 | self.brightness, self.contrast, self.saturation, self.hue) 246 | else: 247 | color_aug = (lambda x: x) 248 | 249 | self.preprocess_depth_input(inputs, color_aug) 250 | 251 | for i in self.frame_ids: 252 | del inputs[("color", i, -1)] 253 | del inputs[("color_aug", i, -1)] 254 | 255 | if self.load_depth: 256 | pass 257 | 258 | return inputs 259 | 260 | def preprocess_depth_input(self, inputs, color_aug): 261 | """Resize colour images to the required scales and augment if required 262 | 263 | We create the color_aug object in advance and apply the same augmentation to all 264 | images in this item. This ensures that all images input to the pose network receive the 265 | same augmentation. 
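Scale i is produced by downsampling the scale i - 1 image, halving the
resolution at each step. With the default working size of 128 x 256 and
num_scales=3, the pyramid is therefore 128x256, 64x128 and 32x64.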
266 | """ 267 | for k in list(inputs): 268 | if "color" in k: 269 | n, im, i = k 270 | for i in range(self.num_scales): 271 | inputs[(n, im, i)] = self.resize[i](inputs[(n, im, i - 1)]) 272 | 273 | for k in list(inputs): 274 | f = inputs[k] 275 | if "color" in k: 276 | n, im, i = k 277 | inputs[(n, im, i)] = self.data_transforms(f) 278 | inputs[(n + "_aug", im, i)] = self.data_transforms(color_aug(f)) 279 | 280 | def get_scene_number(self, idx): 281 | start = 0 282 | for scene_number in range(self.n_scenes): 283 | end = start + self.dict_scene_to_n_sequences[scene_number] 284 | if start <= idx < end: 285 | position_in_seq = idx - start 286 | return scene_number, position_in_seq 287 | start = end 288 | 289 | raise ValueError('Index {} not found in dataset with {} sequences'.format(idx, self.total_n_sequences)) 290 | 291 | 292 | def resize_nuscenes_data(img, instance_seg, depth, intrinsics, h_target, w_target): 293 | # Resize 294 | original_h, original_w = instance_seg.shape[-2:] 295 | resize_scale = h_target / original_h 296 | h, w = h_target, int(np.ceil(original_w * resize_scale)) 297 | assert (h == h_target) and (w == w_target), 'Mismatch in w: size {} but expected {}'.format(w, w_target) 298 | 299 | img = transforms.Resize((h, w), interpolation=Image.BILINEAR)(img) 300 | instance_seg = (255 * resize(instance_seg, (N_CLASSES, h, w), order=0, anti_aliasing=None)).astype(np.uint8) 301 | depth = resize(depth, (h, w), order=1, anti_aliasing=None) 302 | # Intrinsics 303 | # If resize scale is different for x and y, need to adapt. 304 | intrinsics[0, 0] *= resize_scale 305 | intrinsics[0, 2] *= resize_scale 306 | intrinsics[1, 1] *= resize_scale 307 | intrinsics[1, 2] *= resize_scale 308 | 309 | return img, instance_seg, depth, intrinsics 310 | 311 | 312 | def resize_one_item_multiprocessing(img_fname, new_root, mode, h_target, w_target): 313 | scene = os.path.basename(os.path.dirname(img_fname)) 314 | save_path = os.path.join(new_root, mode, scene) 315 | prefix = os.path.basename(img_fname)[:-len('image.jpg')] 316 | 317 | base_filename = img_fname[:-len('image.jpg')] 318 | img = pil_loader(img_fname) 319 | instance_seg = np.load(base_filename + 'instance_seg.npy') 320 | depth = np.load(base_filename + 'disp.npy') 321 | position = np.load(base_filename + 'position.npy') 322 | velocity = np.load(base_filename + 'velocity.npy') 323 | intrinsics = np.load(base_filename + 'intrinsics.npy') 324 | 325 | img, instance_seg, depth, intrinsics = resize_nuscenes_data(img, instance_seg, depth, intrinsics, 326 | h_target=h_target, w_target=w_target) 327 | os.makedirs(save_path, exist_ok=True) 328 | img.save(os.path.join(save_path, prefix + 'image.jpg')) 329 | np.save(os.path.join(save_path, prefix + 'instance_seg.npy'), instance_seg) 330 | np.save(os.path.join(save_path, prefix + 'disp.npy'), depth) 331 | np.save(os.path.join(save_path, prefix + 'position.npy'), position) 332 | np.save(os.path.join(save_path, prefix + 'velocity.npy'), velocity) 333 | np.save(os.path.join(save_path, prefix + 'intrinsics.npy'), intrinsics) 334 | 335 | 336 | def save_dataset_into_disk(h_target=112, w_target=200, root='', mode='train', 337 | new_name='debug_112x200'): 338 | if root[-1] == '/': 339 | dirname = os.path.dirname(os.path.dirname(root)) 340 | else: 341 | dirname = os.path.dirname(root) 342 | new_root = os.path.join(dirname, new_name) 343 | img_filenames = sorted(glob(os.path.join(root, mode, '*', '*_image.jpg'))) 344 | 345 | pool = Pool(cpu_count() - 1) 346 | for _ in 
tqdm(pool.imap_unordered(partial(resize_one_item_multiprocessing, new_root=new_root, mode=mode, 347 | h_target=h_target, w_target=w_target), 348 | img_filenames), total=len(img_filenames)): 349 | pass 350 | -------------------------------------------------------------------------------- /data/kitti.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | from glob import glob 6 | from PIL import Image 7 | from tqdm import tqdm 8 | from functools import partial 9 | from multiprocessing import Pool, cpu_count 10 | 11 | from data.utils import crop_and_resize, pil_loader 12 | 13 | 14 | def preprocess_dataset(root='', root_save='', 15 | img_size=(192, 640), keep_id=1, extension='.png'): 16 | pool = Pool(cpu_count() - 1) 17 | for mode in ['train', 'val']: 18 | all_scene_dir = sorted(glob(os.path.join(root, 'images', mode, '*'))) 19 | for _ in tqdm(pool.imap_unordered( 20 | partial(preprocess_dataset_iter, root=root, mode=mode, root_save=root_save, img_size=img_size, 21 | keep_id=keep_id, extension=extension), all_scene_dir), total=len(all_scene_dir)): 22 | pass 23 | 24 | 25 | def preprocess_dataset_iter(scene_dir, root, mode, root_save, img_size=(192, 640), keep_id=1, extension='.png'): 26 | """ keep_id 1 is car, 2 is pedestrian""" 27 | print('Scene: {}'.format(scene_dir)) 28 | img_filenames = sorted(glob(os.path.join(root, 'images', mode, os.path.basename(scene_dir), '*' + extension))) 29 | 30 | for img_fname in img_filenames: 31 | basename = os.path.basename(img_fname) 32 | img = pil_loader(img_fname) # Open and convert to RGB 33 | #  Much more compact to load and resize images compared to numpy arrays 34 | instance_seg = Image.open(os.path.join(root, 'instances', mode, os.path.basename(scene_dir), 35 | basename[:-len(extension)] + '.png')) 36 | assert instance_seg.mode == 'I' 37 | ##### 38 | #  Resize 39 | img = crop_and_resize(img, img_size, crop=False) 40 | instance_seg = crop_and_resize(instance_seg, img_size, order=0, crop=False) 41 | 42 | # Filter non-cars 43 | instance_seg = np.array(instance_seg) 44 | unique_ids = np.unique(instance_seg) 45 | # Only keep cars, remove 10000 (ignore regions) and 200x (pedestrians) 46 | keep_ids = [id for id in unique_ids if (str(id).startswith(str(keep_id)) and len(str(id)) == 4)] 47 | instance_seg[~np.isin(instance_seg, keep_ids)] = 0 48 | instance_seg = (instance_seg % 1000).astype(np.uint8) 49 | 50 | #  N_CLASSES = 1 51 | instance_seg = instance_seg[None, :, :] 52 | 53 | save_path = os.path.join(root_save, mode, os.path.basename(scene_dir)) 54 | os.makedirs(save_path, exist_ok=True) 55 | img.save(os.path.join(save_path, basename[:-len(extension)] + '_image.jpg')) 56 | np.save(os.path.join(save_path, basename[:-len(extension)] + '_instance_seg.npy'), instance_seg) 57 | -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torchvision.transforms as transforms 4 | 5 | from PIL import Image 6 | from skimage.transform import resize 7 | 8 | from common.constants import MEAN, STD 9 | 10 | 11 | def crop_and_resize(x, out_size=(128, 256), order=0, crop=True): 12 | """ Opens and resizes an image as out_size (h, w). 
Also add option to crop top part of the image so that the ratio 13 | w/h = 2.0 14 | Usually resize to (128, 256) 15 | """ 16 | if isinstance(x, Image.Image): # PIL Image 17 | # Crop the top part of the image 18 | w, h = x.size 19 | if crop: 20 | assert (w, h) == (3384, 2710), 'Image size is wrong.' 21 | x = x.crop((0, h - 1692, w, h)) 22 | 23 | if order == 0: 24 | interpolation = Image.NEAREST 25 | else: 26 | interpolation = Image.BILINEAR 27 | out = transforms.Resize(out_size, interpolation=interpolation)(x) 28 | elif isinstance(x, np.ndarray): 29 | h, w = x.shape 30 | if crop: 31 | x = x[(h - 1692):, :] 32 | out = resize(x, out_size, order=order, anti_aliasing=None) 33 | else: 34 | raise ValueError('Can resize PIL Image or np.ndarray objects, but received {}'.format(type(x))) 35 | 36 | return out 37 | 38 | 39 | def pil_loader(path): 40 | # open path as file to avoid ResourceWarning 41 | # (https://github.com/python-pillow/Pillow/issues/835) 42 | with open(path, 'rb') as f: 43 | with Image.open(f) as img: 44 | return img.convert('RGB') 45 | 46 | 47 | def torch_img_to_numpy(img): 48 | """ 49 | Parameters 50 | ---------- 51 | img: torch.tensor (batch_size, seq_len, 3, H, W) 52 | 53 | Returns 54 | ------- 55 | img_np = np.array (batch_size, seq_len, H, W, 3) 56 | """ 57 | mean_np = np.array(MEAN).reshape((1, 1, 3, 1, 1)) 58 | std_np = np.array(STD).reshape((1, 1, 3, 1, 1)) 59 | 60 | img_np = img.detach().cpu().numpy() 61 | 62 | img_np = std_np * img_np + mean_np 63 | img_np = (255 * img_np).astype(np.uint8) 64 | img_np = img_np.transpose((0, 1, 3, 4, 2)) 65 | return img_np 66 | 67 | 68 | def denormalise(x, dimension=5): 69 | mean = torch.tensor(MEAN) 70 | std = torch.tensor(STD) 71 | 72 | if dimension == 5: 73 | mean = mean.view(1, 1, 3, 1, 1) 74 | std = std.view(1, 1, 3, 1, 1) 75 | elif dimension == 4: 76 | mean = mean.view(1, 3, 1, 1) 77 | std = std.view(1, 3, 1, 1) 78 | else: 79 | raise ValueError('Wrong dimension {}'.format(dimension)) 80 | 81 | return std * x + mean -------------------------------------------------------------------------------- /kitti/kitti_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import numpy as np 5 | import torchvision.transforms as transforms 6 | 7 | from PIL import Image 8 | from glob import glob 9 | 10 | from torch.utils.data import Dataset, DataLoader 11 | 12 | from common.constants import MEAN, STD 13 | 14 | DATA_ROOT = '' 15 | 16 | 17 | class KittiMaskDataset(Dataset): 18 | def __init__(self, split='train', data_transforms=None): 19 | super().__init__() 20 | self.data_transforms = data_transforms 21 | 22 | self.img_filenames = sorted(glob(os.path.join(DATA_ROOT, split, '*', '*_image.jpg'))) 23 | self.label_filenames = sorted(glob(os.path.join(DATA_ROOT, split, '*', '*_instance_seg.npy'))) 24 | 25 | def __len__(self): 26 | return len(self.img_filenames) 27 | 28 | def __getitem__(self, idx): 29 | """ 30 | Returns 31 | ------- 32 | img: torch Tensor of shape (H, W, 3) 33 | label: np.array of shape (H, W) 34 | """ 35 | img = Image.open(self.img_filenames[idx]) 36 | label = np.load(self.label_filenames[idx]) 37 | 38 | if self.data_transforms: 39 | img = self.data_transforms(img) 40 | else: 41 | img = transforms.ToTensor()(img) 42 | 43 | # Shape (1, h, w) 44 | # Binary mask 45 | label = (torch.from_numpy(label[0]) > 0).long() 46 | return img, label 47 | 48 | 49 | def get_kitti_mask_dataloaders(batch_size=2): 50 | data_transforms = transforms.Compose([transforms.ToTensor(), 51 
| transforms.Normalize(mean=MEAN, 52 | std=STD)]) 53 | 54 | train_dataset = KittiMaskDataset(split='train', data_transforms=data_transforms) 55 | val_dataset = KittiMaskDataset(split='val', data_transforms=data_transforms) 56 | train_iterator = DataLoader(train_dataset, batch_size, shuffle=True) 57 | val_iterator = DataLoader(val_dataset, batch_size, shuffle=False) 58 | print('Train size: {}'.format(len(train_dataset))) 59 | print('Val size: {}'.format(len(val_dataset))) 60 | 61 | return train_iterator, val_iterator 62 | -------------------------------------------------------------------------------- /kitti/kitti_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from kitti.kitti_dataset import get_kitti_mask_dataloaders 3 | from cityscapes.cityscapes_trainer import SemanticTrainer 4 | 5 | 6 | class KittiMaskTrainer(SemanticTrainer): 7 | def create_loss(self): 8 | self.loss = torch.nn.CrossEntropyLoss(weight=torch.tensor([1, 10]).float()) 9 | 10 | def create_data(self): 11 | self.h = self.params['h'] 12 | self.w = self.params['w'] 13 | self.train_iterator, self.val_iterator = get_kitti_mask_dataloaders(batch_size=self.params['batch_size']) 14 | -------------------------------------------------------------------------------- /kitti/train_mask.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from kitti.kitti_trainer import KittiMaskTrainer 5 | 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--batch_size", 9 | type=int, 10 | help="batch size", 11 | default=16) 12 | parser.add_argument("--output_path", 13 | type=str, 14 | help='output path', 15 | default='') 16 | parser.add_argument("--model_name", 17 | type=str, 18 | help="model name", 19 | default='resnet', 20 | choices=['resnet', 'deeplab']) 21 | parser.add_argument('--tag', 22 | type=str, 23 | help='session tag', 24 | default='baseline') 25 | 26 | options = parser.parse_args() 27 | 28 | params = {'batch_size': options.batch_size, 29 | 'output_dir': options.output_path, 30 | 'tag': options.tag, 31 | 'device': torch.device('cuda'), 32 | 'model_name': options.model_name, 33 | 'n_classes': 2, 34 | 'pretrained_path': '', 35 | 'h': 192, 36 | 'w': 640 37 | } 38 | 39 | trainer = KittiMaskTrainer(params) 40 | trainer = trainer.to(params['device']) 41 | 42 | trainer.train_model(n_epochs=5, save_every=100) -------------------------------------------------------------------------------- /metrics/metrics.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | 4 | import numpy as np 5 | 6 | from nuscenes import NuScenes 7 | from nuscenes.eval.detection import NuScenesEval 8 | from nuscenes.eval.detection.data_classes import DetectionConfig 9 | from nuscenes.eval.detection.config import config_factory 10 | 11 | from common.utils import cummean 12 | 13 | 14 | class MotionMetrics(): 15 | def __init__(self, config, tensorboard, min_recall=0.1, min_precision=0.1): 16 | self.config = config 17 | self.tensorboard = tensorboard 18 | self.threshold = self.config['metrics_threshold'] 19 | self.min_recall = min_recall 20 | self.min_precision = min_precision 21 | 22 | self.metrics = {} 23 | 24 | self.tp = None # True positives 25 | self.fp = None # False positives 26 | self.conf = None # Confidence score 27 | self.n_positive = 0 # Number of ground truth instances seen so far 28 | 29 | # Position/velocity error for true positives only 30 | 
self.match_data = None 31 | 32 | # MOTS metrics 33 | self.tp_mots = None 34 | self.fp_mots = None 35 | self.fn_mots = None 36 | self.soft_tp_mots = None 37 | self.n_switches = None 38 | 39 | self.reset() 40 | 41 | def update(self, batch, output): 42 | """ 43 | Parameters 44 | ---------- 45 | batch: dict with keys: 46 | instance_seg: torch.tensor (b, seq_len, N_CLASSES, h, w) 47 | position: torch.tensor (b, seq_len, MAX_INSTANCES, 3) 48 | velocity: torch.tensor (b, seq_len, MAX_INSTANCES, 3) 49 | 50 | output: dict with keys: 51 | instance_seg: torch.tensor (b, seq_len, N_CLASSES, h, w) 52 | position: torch.tensor (b, seq_len, MAX_INSTANCES, 3) 53 | velocity: torch.tensor (b, seq_len, MAX_INSTANCES, 3) 54 | """ 55 | receptive_field = self.config['receptive_field'] 56 | if not self.config['instance_loss']: 57 | return 58 | batch_keys = ['img', 'instance_seg'] 59 | output_keys = ['instance_seg'] 60 | 61 | if self.config['motion_loss']: 62 | batch_keys += ['position', 'velocity'] 63 | output_keys += ['position', 'velocity'] 64 | 65 | batch_np = {key: batch[key].detach().cpu().numpy() for key in batch_keys} 66 | output_np = {key: output[key].detach().cpu().numpy() for key in output_keys} 67 | 68 | b, seq_len = batch_np['img'].shape[:2] 69 | 70 | dict_prev_id = {} 71 | for i in range(b): 72 | for t in range(receptive_field - 1, seq_len): 73 | pred_unique_ids = np.unique(output_np['instance_seg'][i, t])[1:] 74 | gt_unique_ids = np.unique(batch_np['instance_seg'][i, t])[1:] 75 | self.n_positive += len(gt_unique_ids) 76 | 77 | taken_gt_ids = set() 78 | used_pred_ids = set() 79 | for id in pred_unique_ids: 80 | mask = output_np['instance_seg'][i, t] == id 81 | best_iou = 0 82 | best_gt_id = None 83 | for gt_id in gt_unique_ids: 84 | if gt_id in taken_gt_ids: 85 | continue 86 | 87 | gt_mask = batch_np['instance_seg'][i, t] == gt_id 88 | inter = (mask & gt_mask).sum() 89 | union = (mask | gt_mask).sum() 90 | iou = inter / union 91 | if iou > best_iou: 92 | best_iou = iou 93 | best_gt_id = gt_id 94 | 95 | conf = np.random.random() 96 | if best_iou > self.threshold: 97 | self.tp.append(1) 98 | self.fp.append(0) 99 | taken_gt_ids.add(best_gt_id) 100 | used_pred_ids.add(id) 101 | 102 | if self.config['motion_loss']: 103 | # 2D velocity error 104 | self.match_data['vel_err'].append(calc_vel_err(output_np['velocity'][i, t, gt_unique_ids], 105 | batch_np['velocity'][i, t, gt_unique_ids])) 106 | self.match_data['pos_err'].append(np.linalg.norm((output_np['position'][i, t, gt_unique_ids] 107 | - batch_np['position'][i, t, gt_unique_ids]))) 108 | self.match_data['conf'].append(conf) 109 | 110 | self.tp_mots += 1 111 | self.soft_tp_mots += best_iou 112 | 113 | if best_gt_id in dict_prev_id and id != dict_prev_id[best_gt_id]: 114 | self.n_switches += 1 115 | dict_prev_id[best_gt_id] = id 116 | 117 | else: 118 | self.tp.append(0) 119 | self.fp.append(1) 120 | self.conf.append(conf) # TODO: add confidence score 121 | 122 | self.fp_mots += len(set(pred_unique_ids).difference(used_pred_ids)) 123 | self.fn_mots += len(set(gt_unique_ids).difference(taken_gt_ids)) 124 | 125 | def evaluate(self, global_step, mode): 126 | if not self.config['instance_loss']: 127 | return 0.0 128 | if len(self.tp) == 0: 129 | print('No accumulated metrics') 130 | self.reset() 131 | return 0 132 | 133 | # Sort by decreasing confidence score 134 | self.conf = np.array(self.conf) 135 | indices = np.argsort(-self.conf) 136 | self.tp = np.array(self.tp)[indices] 137 | self.fp = np.array(self.fp)[indices] 138 | self.conf = self.conf[indices] 139 
| 140 | if self.config['motion_loss']: 141 | match_data_indices = np.argsort(-np.array(self.match_data['conf'])) 142 | for key in self.match_data.keys(): 143 | if key == 'conf': 144 | continue 145 | self.match_data[key] = np.array(self.match_data[key])[match_data_indices] 146 | self.match_data['conf'] = np.array(self.match_data['conf'])[match_data_indices] 147 | 148 | # Compute Average Precision 149 | self.tp = np.cumsum(self.tp) 150 | self.fp = np.cumsum(self.fp) 151 | precision = self.tp / (self.tp + self.fp) 152 | recall = self.tp / max(1, self.n_positive) 153 | 154 | # Interpolate to a equally spaced recall values [0-1] with 0.01 increment 155 | recall_interp = np.linspace(0, 1, 101) 156 | precision = np.interp(recall_interp, recall, precision, right=0) 157 | self.conf = np.interp(recall_interp, recall, self.conf, right=0) 158 | 159 | if self.config['motion_loss']: 160 | for key in self.match_data.keys(): 161 | if key == 'conf': 162 | continue 163 | tmp = cummean(self.match_data[key]) 164 | self.match_data[key] = np.interp(self.conf[::-1], self.match_data['conf'][::-1], tmp[::-1])[::-1] 165 | 166 | # Average Precision: area under the precision/recall curve for recall and precision over 10% 167 | self.metrics['ap'] = calc_ap(precision, self.min_recall, self.min_precision) 168 | if self.config['motion_loss']: 169 | self.metrics['vel_err'] = calc_tp(self.match_data['vel_err'], self.conf, self.min_recall) 170 | self.metrics['pos_err'] = calc_tp(self.match_data['pos_err'], self.conf, self.min_recall) 171 | 172 | # MOTS metrics 173 | self.metrics['motsa'] = (self.tp_mots - self.fp_mots - self.n_switches) / max(1, self.n_positive) 174 | self.metrics['motsp'] = self.soft_tp_mots / max(1, self.tp_mots) 175 | self.metrics['smotsa'] = (self.soft_tp_mots - self.fp_mots - self.n_switches) / max(1, self.n_positive) 176 | self.metrics['n_switches'] = self.n_switches 177 | self.metrics['tp_mots'] = self.tp_mots 178 | self.metrics['fp_mots'] = self.fp_mots 179 | self.metrics['fn_mots'] = self.fn_mots 180 | self.metrics['soft_tp_mots'] = self.soft_tp_mots 181 | self.metrics['n_positive'] = self.n_positive 182 | self.metrics['recall'] = self.tp_mots / max(1, self.n_positive) 183 | self.metrics['precision'] = self.tp_mots / max(1, (self.tp_mots + self.fp_mots)) 184 | 185 | for key, value in self.metrics.items(): 186 | print('{}: {:.3f}'.format(key, value)) 187 | self.tensorboard.add_scalar(mode + '/' + key, value, global_step) 188 | 189 | metric_score = self.score() 190 | self.reset() 191 | return metric_score 192 | 193 | def score(self): 194 | return self.metrics['motsa'] 195 | 196 | def reset(self): 197 | self.metrics = {} 198 | 199 | self.tp = [] 200 | self.fp = [] 201 | self.conf = [] 202 | self.n_positive = 0 203 | self.match_data = {'vel_err': [], 204 | 'pos_err': [], 205 | 'conf': [], 206 | } 207 | 208 | self.tp_mots = 0 209 | self.fp_mots = 0 210 | self.fn_mots = 0 211 | self.soft_tp_mots = 0 212 | self.n_switches = 0 213 | 214 | 215 | def calc_vel_err(pred_vel, gt_vel): 216 | # 2D velocity error 217 | vel_err = np.linalg.norm(pred_vel[:, [0, 2]] - gt_vel[:, [0, 2]]) 218 | return vel_err 219 | 220 | 221 | def calc_ap(precision_interp, min_recall: float, min_precision: float) -> float: 222 | """ Calculated average precision. """ 223 | 224 | assert 0 <= min_precision < 1 225 | assert 0 <= min_recall <= 1 226 | 227 | prec = np.copy(precision_interp) 228 | prec = prec[round(100 * min_recall) + 1:] # Clip low recalls. +1 to exclude the min recall bin. 
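# Added clarifying comment (this appears to mirror the nuScenes detection AP):
# with min_recall = 0.1 and min_precision = 0.1, the mean below is taken over
# the recall bins 0.11 ... 1.00, min_precision is subtracted from each
# precision value, negatives are clipped to zero, and the result is divided by
# 0.9 so that a perfect detector still reaches an AP of 1.0.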
229 | prec -= min_precision # Clip low precision 230 | prec[prec < 0] = 0 231 | return float(np.mean(prec)) / (1.0 - min_precision) 232 | 233 | 234 | def calc_tp(tp_metric, confidence, min_recall: float) -> float: 235 | """ Calculates true positive errors. """ 236 | 237 | first_ind = round(100 * min_recall) + 1 # +1 to exclude the error at min recall. 238 | last_ind = max_recall_ind(confidence) # First instance of confidence = 0 is index of max achieved recall. 239 | if last_ind < first_ind: 240 | return 1.0 # Assign 1 here. If this happens for all classes, the score for that TP metric will be 0. 241 | else: 242 | return float(np.mean(tp_metric[first_ind: last_ind + 1])) # +1 to include error at max recall 243 | 244 | 245 | def max_recall_ind(confidence): 246 | """ Returns index of max recall achieved. """ 247 | 248 | # Last instance of confidence > 0 is index of max achieved recall. 249 | non_zero = np.nonzero(confidence)[0] 250 | if len(non_zero) == 0: # If there are no matches, all the confidence values will be zero. 251 | max_recall_ind = 0 252 | else: 253 | max_recall_ind = non_zero[-1] 254 | 255 | return max_recall_ind 256 | 257 | 258 | if __name__ == '__main__': 259 | result_path_ = '' 260 | output_dir_ = '' 261 | eval_set_ = 'val' 262 | dataroot_ = '' 263 | version_ = 'v1.0-trainval' 264 | config_path = '' 265 | plot_examples_ = 0 266 | render_curves_ = False 267 | verbose_ = True 268 | 269 | if config_path == '': 270 | cfg_ = config_factory('cvpr_2019') 271 | else: 272 | with open(config_path, 'r') as f: 273 | cfg_ = DetectionConfig.deserialize(json.load(f)) 274 | 275 | nusc_ = NuScenes(version=version_, verbose=verbose_, dataroot=dataroot_) 276 | nusc_eval = NuScenesEval(nusc_, config=cfg_, result_path=result_path_, eval_set=eval_set_, 277 | output_dir=output_dir_, verbose=verbose_) 278 | nusc_eval.main(plot_examples=plot_examples_, render_curves=render_curves_) -------------------------------------------------------------------------------- /monodepth/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright © Niantic, Inc. 2018. Patent Pending. 2 | 3 | All rights reserved. 4 | 5 | 6 | 7 | ================================================================================ 8 | 9 | 10 | 11 | This Software is licensed under the terms of the following Monodepth2 license 12 | which allows for non-commercial use only. For any other use of the software not 13 | covered by the terms of this license, please contact partnerships@nianticlabs.com 14 | 15 | 16 | 17 | ================================================================================ 18 | 19 | 20 | 21 | Monodepth v2 License 22 | 23 | 24 | This Agreement is made by and between the Licensor and the Licensee as 25 | defined and identified below. 26 | 27 | 28 | 1. Definitions. 29 | 30 | In this Agreement (“the Agreement”) the following words shall have the 31 | following meanings: 32 | 33 | "Authors" shall mean C. Godard, O. Mac Aodha, M. Firman, G. Brostow 34 | "Licensee" Shall mean the person or organization agreeing to use the 35 | Software in accordance with these terms and conditions. 36 | "Licensor" shall mean Niantic Inc., a company organized and existing under 37 | the laws of Delaware, whose principal place of business is at 1 Ferry Building, 38 | Suite 200, San Francisco, 94111. 
39 | "Software" shall mean the MonoDepth v2 Software uploaded by Licensor to the 40 | GitHub repository at [URL] on [DATE] in source code or object code form and any 41 | accompanying documentation as well as any modifications or additions uploaded 42 | to the same GitHub repository by Licensor. 43 | 44 | 45 | 2. License. 46 | 47 | 2.1 The Licensor has all necessary rights to grant a license under: (i) 48 | copyright and rights in the nature of copyright subsisting in the Software; and 49 | (ii) certain patent rights resulting from a patent application filed by the 50 | Licensor in the United States in connection with the Software. The Licensor 51 | grants the Licensee for the duration of this Agreement, a free of charge, 52 | non-sublicenseable, non-exclusive, non-transferable copyright and patent 53 | license (in consequence of said patent application) to use the Software for 54 | non-commercial purpose only, including teaching and research at educational 55 | institutions and research at not-for-profit research institutions in accordance 56 | with the provisions of this Agreement. Non-commercial use expressly excludes 57 | any profit-making or commercial activities, including without limitation sale, 58 | license, manufacture or development of commercial products, use in 59 | commercially-sponsored research, use at a laboratory or other facility owned or 60 | controlled (whether in whole or in part) by a commercial entity, provision of 61 | consulting service, use for or on behalf of any commercial entity, and use in 62 | research where a commercial party obtains rights to research results or any 63 | other benefit. Any use of the Software for any purpose other than 64 | non-commercial research shall automatically terminate this License. 65 | 66 | 67 | 2.2 The Licensee is permitted to make modifications to the Software 68 | provided that any distribution of such modifications is in accordance with 69 | Clause 3. 70 | 71 | 2.3 Except as expressly permitted by this Agreement and save to the 72 | extent and in the circumstances expressly required to be permitted by law, the 73 | Licensee is not permitted to rent, lease, sell, offer to sell, or loan the 74 | Software or its associated documentation. 75 | 76 | 77 | 3. Redistribution and modifications 78 | 79 | 3.1 The Licensee may reproduce and distribute copies of the Software, with 80 | or without modifications, in source format only and only to this same GitHub 81 | repository , and provided that any and every distribution is accompanied by an 82 | unmodified copy of this License and that the following copyright notice is 83 | always displayed in an obvious manner: Copyright © Niantic, Inc. 2018. All 84 | rights reserved. 85 | 86 | 87 | 3.2 In the case where the Software has been modified, any distribution must 88 | include prominent notices indicating which files have been changed. 89 | 90 | 3.3 The Licensee shall cause any work that it distributes or publishes, 91 | that in whole or in part contains or is derived from the Software or any part 92 | thereof (“Work based on the Software”), to be licensed as a whole at no charge 93 | to all third parties entitled to a license to the Software under the terms of 94 | this License and on the same terms provided in this License. 95 | 96 | 97 | 4. Duration. 98 | 99 | This Agreement is effective until the Licensee terminates it by destroying 100 | the Software, any Work based on the Software, and its documentation together 101 | with all copies. 
It will also terminate automatically if the Licensee fails to 102 | abide by its terms. Upon automatic termination the Licensee agrees to destroy 103 | all copies of the Software, Work based on the Software, and its documentation. 104 | 105 | 106 | 5. Disclaimer of Warranties. 107 | 108 | The Software is provided as is. To the maximum extent permitted by law, 109 | Licensor provides no warranties or conditions of any kind, either express or 110 | implied, including without limitation, any warranties or condition of title, 111 | non-infringement or fitness for a particular purpose. 112 | 113 | 114 | 6. LIMITATION OF LIABILITY. 115 | 116 | IN NO EVENT SHALL THE LICENSOR AND/OR AUTHORS BE LIABLE FOR ANY DIRECT, 117 | INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY OR CONSEQUENTIAL DAMAGES (INCLUDING 118 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 119 | DATA OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 120 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE 121 | OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 122 | ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 123 | 124 | 125 | 7. Indemnity. 126 | 127 | The Licensee shall indemnify the Licensor and/or Authors against all third 128 | party claims that may be asserted against or suffered by the Licensor and/or 129 | Authors and which relate to use of the Software by the Licensee. 130 | 131 | 132 | 8. Intellectual Property. 133 | 134 | 8.1 As between the Licensee and Licensor, copyright and all other 135 | intellectual property rights subsisting in or in connection with the Software 136 | and supporting information shall remain at all times the property of the 137 | Licensor. The Licensee shall acquire no rights in any such material except as 138 | expressly provided in this Agreement. 139 | 140 | 8.2 No permission is granted to use the trademarks or product names of the 141 | Licensor except as required for reasonable and customary use in describing the 142 | origin of the Software and for the purposes of abiding by the terms of Clause 143 | 3.1. 144 | 145 | 8.3 The Licensee shall promptly notify the Licensor of any improvement or 146 | new use of the Software (“Improvements”) in sufficient detail for Licensor to 147 | evaluate the Improvements. The Licensee hereby grants the Licensor and its 148 | affiliates a non-exclusive, fully paid-up, royalty-free, irrevocable and 149 | perpetual license to all Improvements for non-commercial academic research and 150 | teaching purposes upon creation of such improvements. 151 | 152 | 8.4 The Licensee grants an exclusive first option to the Licensor to be 153 | exercised by the Licensor within three (3) years of the date of notification of 154 | an Improvement under Clause 8.3 to use any the Improvement for commercial 155 | purposes on terms to be negotiated and agreed by Licensee and Licensor in good 156 | faith within a period of six (6) months from the date of exercise of the said 157 | option (including without limitation any royalty share in net income from such 158 | commercialization payable to the Licensee, as the case may be). 159 | 160 | 161 | 9. Acknowledgements. 162 | 163 | The Licensee shall acknowledge the Authors and use of the Software in the 164 | publication of any work that uses, or results that are achieved through, the 165 | use of the Software. 
The following citation shall be included in the 166 | acknowledgement: “Digging Into Self-Supervised Monocular Depth Estimation, 167 | by C. Godard, O. Mac Aodha, M. Firman, G. Brostow, arXiv:1806.01260”. 168 | 169 | 170 | 10. Governing Law. 171 | 172 | This Agreement shall be governed by, construed and interpreted in 173 | accordance with English law and the parties submit to the exclusive 174 | jurisdiction of the English courts. 175 | 176 | 177 | 11. Termination. 178 | 179 | Upon termination of this Agreement, the licenses granted hereunder will 180 | terminate and Sections 5, 6, 7, 8, 9, 10 and 11 shall survive any termination 181 | of this Agreement. 182 | -------------------------------------------------------------------------------- /monodepth/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .kitti_dataset import KITTIRAWDataset, KITTIOdomDataset, KITTIDepthDataset 2 | from .apollo_dataset import ApolloDataset 3 | -------------------------------------------------------------------------------- /monodepth/datasets/apollo_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. All rights reserved. 2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | import os 10 | import numpy as np 11 | import PIL.Image as pil 12 | 13 | from glob import glob 14 | from tqdm import tqdm 15 | 16 | from monodepth.datasets.mono_dataset import MonoDataset 17 | 18 | 19 | class ApolloDataset(MonoDataset): 20 | """Superclass for different types of KITTI dataset loaders 21 | """ 22 | def __init__(self, *args, **kwargs): 23 | super().__init__(*args, **kwargs) 24 | 25 | self.mode = 'train' if self.is_train else 'val' 26 | 27 | # Defined as in apollo/utilities/intrinsics.txt 28 | # fx, 0, Cx / 3384 and fy, 0, Cy / 2710 29 | self.K = np.array([[0.68101, 0, 0.49830, 0], 30 | [0, 0.85088, 0.49999, 0], 31 | [0, 0, 1, 0], 32 | [0, 0, 0, 1]], dtype=np.float32) 33 | 34 | self.full_res_shape = (256, 128) # (width, height) 35 | 36 | def check_depth(self): 37 | # Do not load ground truth depth map 38 | return False 39 | 40 | def get_color(self, folder, frame_index, side, do_flip): 41 | color = self.loader(self.get_image_path(folder, frame_index, side)) 42 | 43 | if do_flip: 44 | color = color.transpose(pil.FLIP_LEFT_RIGHT) 45 | 46 | return color 47 | 48 | def get_image_path(self, folder, frame_index, side): 49 | f_str = "{:04d}_image{}".format(frame_index, self.img_ext) 50 | image_path = os.path.join( 51 | self.data_path, self.mode, folder, f_str) 52 | return image_path 53 | 54 | 55 | def generate_split(): 56 | root = '' 57 | side = 'l' 58 | 59 | for mode in ['train', 'val']: 60 | output_file = ''.format(mode) 61 | scene_names = glob(os.path.join(root, mode, '*')) 62 | for scene in tqdm(scene_names, total=len(scene_names)): 63 | img_filenames = sorted(glob(os.path.join(scene, '*.jpg'))) 64 | assert '{:04d}'.format(len(img_filenames) - 1) == os.path.basename(img_filenames[-1])[:-len('_image.jpg')] 65 | # Remove first and last element 66 | img_filenames = img_filenames[1:-1] 67 | base_folder = os.path.basename(scene) 68 | 69 | with open(output_file, 'a') as f: 70 | for t in range(len(img_filenames)): 71 | f.write('{} {} {}\n'.format(base_folder, t + 1, side)) 
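# Added note: each split line is "<scene folder> <frame index> <side>", the
# format MonoDataset.__getitem__ expects. The first and last image of every
# scene are dropped above so that the temporal neighbours at frame_index - 1
# and frame_index + 1 always exist.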
72 | 73 | 74 | if __name__ == '__main__': 75 | generate_split() 76 | -------------------------------------------------------------------------------- /monodepth/datasets/kitti_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. All rights reserved. 2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | import os 10 | import skimage.transform 11 | import numpy as np 12 | import PIL.Image as pil 13 | 14 | from monodepth.kitti_utils import generate_depth_map 15 | from monodepth.datasets.mono_dataset import MonoDataset 16 | 17 | 18 | class KITTIDataset(MonoDataset): 19 | """Superclass for different types of KITTI dataset loaders 20 | """ 21 | def __init__(self, *args, **kwargs): 22 | super(KITTIDataset, self).__init__(*args, **kwargs) 23 | 24 | self.K = np.array([[0.58, 0, 0.5, 0], 25 | [0, 1.92, 0.5, 0], 26 | [0, 0, 1, 0], 27 | [0, 0, 0, 1]], dtype=np.float32) 28 | 29 | self.full_res_shape = (1242, 375) 30 | self.side_map = {"2": 2, "3": 3, "l": 2, "r": 3} 31 | 32 | def check_depth(self): 33 | line = self.filenames[0].split() 34 | scene_name = line[0] 35 | frame_index = int(line[1]) 36 | 37 | velo_filename = os.path.join( 38 | self.data_path, 39 | scene_name, 40 | "velodyne_points/data/{:010d}.bin".format(int(frame_index))) 41 | 42 | return os.path.isfile(velo_filename) 43 | 44 | def get_color(self, folder, frame_index, side, do_flip): 45 | color = self.loader(self.get_image_path(folder, frame_index, side)) 46 | 47 | if do_flip: 48 | color = color.transpose(pil.FLIP_LEFT_RIGHT) 49 | 50 | return color 51 | 52 | 53 | class KITTIRAWDataset(KITTIDataset): 54 | """KITTI dataset which loads the original velodyne depth maps for ground truth 55 | """ 56 | def __init__(self, *args, **kwargs): 57 | super(KITTIRAWDataset, self).__init__(*args, **kwargs) 58 | 59 | def get_image_path(self, folder, frame_index, side): 60 | f_str = "{:010d}{}".format(frame_index, self.img_ext) 61 | image_path = os.path.join( 62 | self.data_path, folder, "image_0{}/data".format(self.side_map[side]), f_str) 63 | return image_path 64 | 65 | def get_depth(self, folder, frame_index, side, do_flip): 66 | calib_path = os.path.join(self.data_path, folder.split("/")[0]) 67 | 68 | velo_filename = os.path.join( 69 | self.data_path, 70 | folder, 71 | "velodyne_points/data/{:010d}.bin".format(int(frame_index))) 72 | 73 | depth_gt = generate_depth_map(calib_path, velo_filename, self.side_map[side]) 74 | depth_gt = skimage.transform.resize( 75 | depth_gt, self.full_res_shape[::-1], order=0, preserve_range=True, mode='constant') 76 | 77 | if do_flip: 78 | depth_gt = np.fliplr(depth_gt) 79 | 80 | return depth_gt 81 | 82 | 83 | class KITTIOdomDataset(KITTIDataset): 84 | """KITTI dataset for odometry training and testing 85 | """ 86 | def __init__(self, *args, **kwargs): 87 | super(KITTIOdomDataset, self).__init__(*args, **kwargs) 88 | 89 | def get_image_path(self, folder, frame_index, side): 90 | f_str = "{:06d}{}".format(frame_index, self.img_ext) 91 | image_path = os.path.join( 92 | self.data_path, 93 | "sequences/{:02d}".format(int(folder)), 94 | "image_{}".format(self.side_map[side]), 95 | f_str) 96 | return image_path 97 | 98 | 99 | class KITTIDepthDataset(KITTIDataset): 100 | """KITTI dataset which uses the updated ground truth 
depth maps 101 | """ 102 | def __init__(self, *args, **kwargs): 103 | super(KITTIDepthDataset, self).__init__(*args, **kwargs) 104 | 105 | def get_image_path(self, folder, frame_index, side): 106 | f_str = "{:010d}{}".format(frame_index, self.img_ext) 107 | image_path = os.path.join( 108 | self.data_path, 109 | folder, 110 | "image_0{}/data".format(self.side_map[side]), 111 | f_str) 112 | return image_path 113 | 114 | def get_depth(self, folder, frame_index, side, do_flip): 115 | f_str = "{:010d}.png".format(frame_index) 116 | depth_path = os.path.join( 117 | self.data_path, 118 | folder, 119 | "proj_depth/groundtruth/image_0{}".format(self.side_map[side]), 120 | f_str) 121 | 122 | depth_gt = pil.open(depth_path) 123 | depth_gt = depth_gt.resize(self.full_res_shape, pil.NEAREST) 124 | depth_gt = np.array(depth_gt).astype(np.float32) / 256 125 | 126 | if do_flip: 127 | depth_gt = np.fliplr(depth_gt) 128 | 129 | return depth_gt 130 | -------------------------------------------------------------------------------- /monodepth/datasets/mono_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. All rights reserved. 2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | import os 10 | import random 11 | import numpy as np 12 | import copy 13 | from PIL import Image # using pillow-simd for increased speed 14 | 15 | import torch 16 | import torch.utils.data as data 17 | from torchvision import transforms 18 | 19 | from common.constants import MEAN, STD 20 | 21 | 22 | def pil_loader(path): 23 | # open path as file to avoid ResourceWarning 24 | # (https://github.com/python-pillow/Pillow/issues/835) 25 | with open(path, 'rb') as f: 26 | with Image.open(f) as img: 27 | return img.convert('RGB') 28 | 29 | 30 | class MonoDataset(data.Dataset): 31 | """Superclass for monocular dataloaders 32 | 33 | Args: 34 | data_path 35 | filenames 36 | height 37 | width 38 | frame_idxs 39 | num_scales 40 | is_train 41 | img_ext 42 | """ 43 | def __init__(self, 44 | data_path, 45 | filenames, 46 | height, 47 | width, 48 | frame_idxs, 49 | num_scales, 50 | is_train=False, 51 | img_ext='.jpg'): 52 | super(MonoDataset, self).__init__() 53 | 54 | self.data_path = data_path 55 | self.filenames = filenames 56 | self.height = height 57 | self.width = width 58 | self.num_scales = num_scales 59 | self.interp = Image.ANTIALIAS 60 | 61 | self.frame_idxs = frame_idxs 62 | 63 | self.is_train = is_train 64 | self.img_ext = img_ext 65 | 66 | self.loader = pil_loader 67 | self.to_tensor = transforms.Compose([transforms.ToTensor(), 68 | transforms.Normalize(mean=MEAN, 69 | std=STD)]) 70 | 71 | # We need to specify augmentations differently in newer versions of torchvision. 
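# (Added note: newer torchvision expects (min, max) ranges such as
# brightness=(0.8, 1.2), while older releases expect a single scalar such as
# brightness=0.2.)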
72 | # We first try the newer tuple version; if this fails we fall back to scalars 73 | try: 74 | self.brightness = (0.8, 1.2) 75 | self.contrast = (0.8, 1.2) 76 | self.saturation = (0.8, 1.2) 77 | self.hue = (-0.1, 0.1) 78 | transforms.ColorJitter.get_params( 79 | self.brightness, self.contrast, self.saturation, self.hue) 80 | except TypeError: 81 | self.brightness = 0.2 82 | self.contrast = 0.2 83 | self.saturation = 0.2 84 | self.hue = 0.1 85 | 86 | self.resize = {} 87 | for i in range(self.num_scales): 88 | s = 2 ** i 89 | self.resize[i] = transforms.Resize((self.height // s, self.width // s), 90 | interpolation=self.interp) 91 | 92 | self.load_depth = self.check_depth() 93 | 94 | def preprocess(self, inputs, color_aug): 95 | """Resize colour images to the required scales and augment if required 96 | 97 | We create the color_aug object in advance and apply the same augmentation to all 98 | images in this item. This ensures that all images input to the pose network receive the 99 | same augmentation. 100 | """ 101 | for k in list(inputs): 102 | frame = inputs[k] 103 | if "color" in k: 104 | n, im, i = k 105 | for i in range(self.num_scales): 106 | inputs[(n, im, i)] = self.resize[i](inputs[(n, im, i - 1)]) 107 | 108 | for k in list(inputs): 109 | f = inputs[k] 110 | if "color" in k: 111 | n, im, i = k 112 | inputs[(n, im, i)] = self.to_tensor(f) 113 | inputs[(n + "_aug", im, i)] = self.to_tensor(color_aug(f)) 114 | 115 | def __len__(self): 116 | return len(self.filenames) 117 | 118 | def __getitem__(self, index): 119 | """Returns a single training item from the dataset as a dictionary. 120 | 121 | Values correspond to torch tensors. 122 | Keys in the dictionary are either strings or tuples: 123 | 124 | ("color", , ) for raw colour images, 125 | ("color_aug", , ) for augmented colour images, 126 | ("K", scale) or ("inv_K", scale) for camera intrinsics, 127 | "stereo_T" for camera extrinsics, and 128 | "depth_gt" for ground truth depth maps. 129 | 130 | is either: 131 | an integer (e.g. 0, -1, or 1) representing the temporal step relative to 'index', 132 | or 133 | "s" for the opposite image in the stereo pair. 
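For example, ("color", -1, <scale>) is the frame one step before the current
index, and ("color", "s", <scale>) is the same-index image from the other
camera of the stereo pair.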
134 | 135 | is an integer representing the scale of the image relative to the fullsize image: 136 | -1 images at native resolution as loaded from disk 137 | 0 images resized to (self.width, self.height ) 138 | 1 images resized to (self.width // 2, self.height // 2) 139 | 2 images resized to (self.width // 4, self.height // 4) 140 | 3 images resized to (self.width // 8, self.height // 8) 141 | """ 142 | inputs = {} 143 | 144 | do_color_aug = self.is_train and random.random() > 0.5 145 | do_flip = self.is_train and random.random() > 0.5 146 | 147 | line = self.filenames[index].split() 148 | folder = line[0] 149 | 150 | if len(line) == 3: 151 | frame_index = int(line[1]) 152 | else: 153 | frame_index = 0 154 | 155 | if len(line) == 3: 156 | side = line[2] 157 | else: 158 | side = None 159 | 160 | for i in self.frame_idxs: 161 | if i == "s": 162 | other_side = {"r": "l", "l": "r"}[side] 163 | inputs[("color", i, -1)] = self.get_color(folder, frame_index, other_side, do_flip) 164 | else: 165 | inputs[("color", i, -1)] = self.get_color(folder, frame_index + i, side, do_flip) 166 | 167 | # adjusting intrinsics to match each scale in the pyramid 168 | for scale in range(self.num_scales): 169 | K = self.K.copy() 170 | 171 | K[0, :] *= self.width // (2 ** scale) 172 | K[1, :] *= self.height // (2 ** scale) 173 | 174 | inv_K = np.linalg.pinv(K) 175 | 176 | inputs[("K", scale)] = torch.from_numpy(K) 177 | inputs[("inv_K", scale)] = torch.from_numpy(inv_K) 178 | 179 | if do_color_aug: 180 | color_aug = transforms.ColorJitter.get_params( 181 | self.brightness, self.contrast, self.saturation, self.hue) 182 | else: 183 | color_aug = (lambda x: x) 184 | 185 | self.preprocess(inputs, color_aug) 186 | 187 | for i in self.frame_idxs: 188 | del inputs[("color", i, -1)] 189 | del inputs[("color_aug", i, -1)] 190 | 191 | if self.load_depth: 192 | depth_gt = self.get_depth(folder, frame_index, side, do_flip) 193 | inputs["depth_gt"] = np.expand_dims(depth_gt, 0) 194 | inputs["depth_gt"] = torch.from_numpy(inputs["depth_gt"].astype(np.float32)) 195 | 196 | if "s" in self.frame_idxs: 197 | stereo_T = np.eye(4, dtype=np.float32) 198 | baseline_sign = -1 if do_flip else 1 199 | side_sign = -1 if side == "l" else 1 200 | stereo_T[0, 3] = side_sign * baseline_sign * 0.1 201 | 202 | inputs["stereo_T"] = torch.from_numpy(stereo_T) 203 | 204 | return inputs 205 | 206 | def get_color(self, folder, frame_index, side, do_flip): 207 | raise NotImplementedError 208 | 209 | def check_depth(self): 210 | raise NotImplementedError 211 | 212 | def get_depth(self, folder, frame_index, side, do_flip): 213 | raise NotImplementedError 214 | -------------------------------------------------------------------------------- /monodepth/evaluate_depth.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import os 4 | import cv2 5 | import numpy as np 6 | 7 | import torch 8 | from torch.utils.data import DataLoader 9 | 10 | from monodepth.layers import disp_to_depth 11 | from monodepth.utils import readlines 12 | from monodepth.options import MonodepthOptions 13 | import monodepth.datasets as datasets 14 | import monodepth.networks as networks 15 | 16 | cv2.setNumThreads(0) # This speeds up evaluation 5x on our unix systems (OpenCV 3.3.1) 17 | 18 | 19 | splits_dir = os.path.join(os.path.dirname(__file__), "splits") 20 | 21 | # Models which were trained with stereo supervision were trained with a nominal 22 | # baseline of 0.1 units. 
The KITTI rig has a baseline of 54cm. Therefore, 23 | # to convert our stereo predictions to real-world scale we multiply our depths by 5.4. 24 | STEREO_SCALE_FACTOR = 5.4 25 | 26 | 27 | def compute_errors(gt, pred): 28 | """Computation of error metrics between predicted and ground truth depths 29 | """ 30 | thresh = np.maximum((gt / pred), (pred / gt)) 31 | a1 = (thresh < 1.25 ).mean() 32 | a2 = (thresh < 1.25 ** 2).mean() 33 | a3 = (thresh < 1.25 ** 3).mean() 34 | 35 | rmse = (gt - pred) ** 2 36 | rmse = np.sqrt(rmse.mean()) 37 | 38 | rmse_log = (np.log(gt) - np.log(pred)) ** 2 39 | rmse_log = np.sqrt(rmse_log.mean()) 40 | 41 | abs_rel = np.mean(np.abs(gt - pred) / gt) 42 | 43 | sq_rel = np.mean(((gt - pred) ** 2) / gt) 44 | 45 | return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3 46 | 47 | 48 | def batch_post_process_disparity(l_disp, r_disp): 49 | """Apply the disparity post-processing method as introduced in Monodepthv1 50 | """ 51 | _, h, w = l_disp.shape 52 | m_disp = 0.5 * (l_disp + r_disp) 53 | l, _ = np.meshgrid(np.linspace(0, 1, w), np.linspace(0, 1, h)) 54 | l_mask = (1.0 - np.clip(20 * (l - 0.05), 0, 1))[None, ...] 55 | r_mask = l_mask[:, :, ::-1] 56 | return r_mask * l_disp + l_mask * r_disp + (1.0 - l_mask - r_mask) * m_disp 57 | 58 | 59 | def evaluate(opt): 60 | """Evaluates a pretrained model using a specified test set 61 | """ 62 | MIN_DEPTH = 1e-3 63 | MAX_DEPTH = 80 64 | 65 | assert sum((opt.eval_mono, opt.eval_stereo)) == 1, \ 66 | "Please choose mono or stereo evaluation by setting either --eval_mono or --eval_stereo" 67 | 68 | if opt.ext_disp_to_eval is None: 69 | 70 | opt.load_weights_folder = os.path.expanduser(opt.load_weights_folder) 71 | 72 | assert os.path.isdir(opt.load_weights_folder), \ 73 | "Cannot find a folder at {}".format(opt.load_weights_folder) 74 | 75 | print("-> Loading weights from {}".format(opt.load_weights_folder)) 76 | 77 | filenames = readlines(os.path.join(splits_dir, opt.eval_split, "test_files.txt")) 78 | encoder_path = os.path.join(opt.load_weights_folder, "encoder.pth") 79 | decoder_path = os.path.join(opt.load_weights_folder, "depth.pth") 80 | 81 | encoder_dict = torch.load(encoder_path) 82 | 83 | dataset = datasets.KITTIRAWDataset(opt.data_path, filenames, 84 | encoder_dict['height'], encoder_dict['width'], 85 | [0], 4, is_train=False) 86 | dataloader = DataLoader(dataset, 16, shuffle=False, num_workers=opt.num_workers, 87 | pin_memory=True, drop_last=False) 88 | 89 | encoder = networks.ResnetEncoder(opt.num_layers, False) 90 | depth_decoder = networks.DepthDecoder(encoder.num_ch_enc) 91 | 92 | model_dict = encoder.state_dict() 93 | encoder.load_state_dict({k: v for k, v in encoder_dict.items() if k in model_dict}) 94 | depth_decoder.load_state_dict(torch.load(decoder_path)) 95 | 96 | encoder.cuda() 97 | encoder.eval() 98 | depth_decoder.cuda() 99 | depth_decoder.eval() 100 | 101 | pred_disps = [] 102 | 103 | print("-> Computing predictions with size {}x{}".format( 104 | encoder_dict['width'], encoder_dict['height'])) 105 | 106 | with torch.no_grad(): 107 | for data in dataloader: 108 | input_color = data[("color", 0, 0)].cuda() 109 | 110 | if opt.post_process: 111 | # Post-processed results require each image to have two forward passes 112 | input_color = torch.cat((input_color, torch.flip(input_color, [3])), 0) 113 | 114 | output = depth_decoder(encoder(input_color)) 115 | 116 | pred_disp, _ = disp_to_depth(output[("disp", 0)], opt.min_depth, opt.max_depth) 117 | pred_disp = pred_disp.cpu()[:, 0].numpy() 118 | 119 | if 
opt.post_process: 120 | N = pred_disp.shape[0] // 2 121 | pred_disp = batch_post_process_disparity(pred_disp[:N], pred_disp[N:, :, ::-1]) 122 | 123 | pred_disps.append(pred_disp) 124 | 125 | pred_disps = np.concatenate(pred_disps) 126 | 127 | else: 128 | # Load predictions from file 129 | print("-> Loading predictions from {}".format(opt.ext_disp_to_eval)) 130 | pred_disps = np.load(opt.ext_disp_to_eval) 131 | 132 | if opt.eval_eigen_to_benchmark: 133 | eigen_to_benchmark_ids = np.load( 134 | os.path.join(splits_dir, "benchmark", "eigen_to_benchmark_ids.npy")) 135 | 136 | pred_disps = pred_disps[eigen_to_benchmark_ids] 137 | 138 | if opt.save_pred_disps: 139 | output_path = os.path.join( 140 | opt.load_weights_folder, "disps_{}_split.npy".format(opt.eval_split)) 141 | print("-> Saving predicted disparities to ", output_path) 142 | np.save(output_path, pred_disps) 143 | 144 | if opt.no_eval: 145 | print("-> Evaluation disabled. Done.") 146 | quit() 147 | 148 | elif opt.eval_split == 'benchmark': 149 | save_dir = os.path.join(opt.load_weights_folder, "benchmark_predictions") 150 | print("-> Saving out benchmark predictions to {}".format(save_dir)) 151 | if not os.path.exists(save_dir): 152 | os.makedirs(save_dir) 153 | 154 | for idx in range(len(pred_disps)): 155 | disp_resized = cv2.resize(pred_disps[idx], (1216, 352)) 156 | depth = STEREO_SCALE_FACTOR / disp_resized 157 | depth = np.clip(depth, 0, 80) 158 | depth = np.uint16(depth * 256) 159 | save_path = os.path.join(save_dir, "{:010d}.png".format(idx)) 160 | cv2.imwrite(save_path, depth) 161 | 162 | print("-> No ground truth is available for the KITTI benchmark, so not evaluating. Done.") 163 | quit() 164 | 165 | gt_path = os.path.join(splits_dir, opt.eval_split, "gt_depths.npz") 166 | gt_depths = np.load(gt_path, fix_imports=True, encoding='latin1')["data"] 167 | 168 | print("-> Evaluating") 169 | 170 | if opt.eval_stereo: 171 | print(" Stereo evaluation - " 172 | "disabling median scaling, scaling by {}".format(STEREO_SCALE_FACTOR)) 173 | opt.disable_median_scaling = True 174 | opt.pred_depth_scale_factor = STEREO_SCALE_FACTOR 175 | else: 176 | print(" Mono evaluation - using median scaling") 177 | 178 | errors = [] 179 | ratios = [] 180 | 181 | for i in range(pred_disps.shape[0]): 182 | 183 | gt_depth = gt_depths[i] 184 | gt_height, gt_width = gt_depth.shape[:2] 185 | 186 | pred_disp = pred_disps[i] 187 | pred_disp = cv2.resize(pred_disp, (gt_width, gt_height)) 188 | pred_depth = 1 / pred_disp 189 | 190 | if opt.eval_split == "eigen": 191 | mask = np.logical_and(gt_depth > MIN_DEPTH, gt_depth < MAX_DEPTH) 192 | 193 | crop = np.array([0.40810811 * gt_height, 0.99189189 * gt_height, 194 | 0.03594771 * gt_width, 0.96405229 * gt_width]).astype(np.int32) 195 | crop_mask = np.zeros(mask.shape) 196 | crop_mask[crop[0]:crop[1], crop[2]:crop[3]] = 1 197 | mask = np.logical_and(mask, crop_mask) 198 | 199 | else: 200 | mask = gt_depth > 0 201 | 202 | pred_depth = pred_depth[mask] 203 | gt_depth = gt_depth[mask] 204 | 205 | pred_depth *= opt.pred_depth_scale_factor 206 | if not opt.disable_median_scaling: 207 | ratio = np.median(gt_depth) / np.median(pred_depth) 208 | ratios.append(ratio) 209 | pred_depth *= ratio 210 | 211 | pred_depth[pred_depth < MIN_DEPTH] = MIN_DEPTH 212 | pred_depth[pred_depth > MAX_DEPTH] = MAX_DEPTH 213 | 214 | errors.append(compute_errors(gt_depth, pred_depth)) 215 | 216 | if not opt.disable_median_scaling: 217 | ratios = np.array(ratios) 218 | med = np.median(ratios) 219 | print(" Scaling ratios | med: {:0.3f} | std: 
{:0.3f}".format(med, np.std(ratios / med))) 220 | 221 | mean_errors = np.array(errors).mean(0) 222 | 223 | print("\n " + ("{:>8} | " * 7).format("abs_rel", "sq_rel", "rmse", "rmse_log", "a1", "a2", "a3")) 224 | print(("&{: 8.3f} " * 7).format(*mean_errors.tolist()) + "\\\\") 225 | print("\n-> Done!") 226 | 227 | 228 | if __name__ == "__main__": 229 | options = MonodepthOptions() 230 | evaluate(options.parse()) 231 | -------------------------------------------------------------------------------- /monodepth/evaluate_pose.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. All rights reserved. 2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | import os 10 | import numpy as np 11 | 12 | import torch 13 | from torch.utils.data import DataLoader 14 | 15 | from monodepth.layers import transformation_from_parameters 16 | from monodepth.utils import readlines 17 | from monodepth.options import MonodepthOptions 18 | from monodepth.datasets import KITTIOdomDataset 19 | import monodepth.networks as networks 20 | 21 | 22 | # from https://github.com/tinghuiz/SfMLearner 23 | def dump_xyz(source_to_target_transformations): 24 | xyzs = [] 25 | cam_to_world = np.eye(4) 26 | xyzs.append(cam_to_world[:3, 3]) 27 | for source_to_target_transformation in source_to_target_transformations: 28 | cam_to_world = np.dot(cam_to_world, source_to_target_transformation) 29 | xyzs.append(cam_to_world[:3, 3]) 30 | return xyzs 31 | 32 | 33 | # from https://github.com/tinghuiz/SfMLearner 34 | def compute_ate(gtruth_xyz, pred_xyz_o): 35 | 36 | # Make sure that the first matched frames align (no need for rotational alignment as 37 | # all the predicted/ground-truth snippets have been converted to use the same coordinate 38 | # system with the first frame of the snippet being the origin). 
39 | offset = gtruth_xyz[0] - pred_xyz_o[0] 40 | pred_xyz = pred_xyz_o + offset[None, :] 41 | 42 | # Optimize the scaling factor 43 | scale = np.sum(gtruth_xyz * pred_xyz) / np.sum(pred_xyz ** 2) 44 | alignment_error = pred_xyz * scale - gtruth_xyz 45 | rmse = np.sqrt(np.sum(alignment_error ** 2)) / gtruth_xyz.shape[0] 46 | return rmse 47 | 48 | 49 | def evaluate(opt): 50 | """Evaluate odometry on the KITTI dataset 51 | """ 52 | assert os.path.isdir(opt.load_weights_folder), \ 53 | "Cannot find a folder at {}".format(opt.load_weights_folder) 54 | 55 | assert opt.eval_split == "odom_9" or opt.eval_split == "odom_10", \ 56 | "eval_split should be either odom_9 or odom_10" 57 | 58 | sequence_id = int(opt.eval_split.split("_")[1]) 59 | 60 | filenames = readlines( 61 | os.path.join(os.path.dirname(__file__), "splits", "odom", 62 | "test_files_{:02d}.txt".format(sequence_id))) 63 | 64 | dataset = KITTIOdomDataset(opt.data_path, filenames, opt.height, opt.width, 65 | [0, 1], 4, is_train=False) 66 | dataloader = DataLoader(dataset, opt.batch_size, shuffle=False, 67 | num_workers=opt.num_workers, pin_memory=True, drop_last=False) 68 | 69 | pose_encoder_path = os.path.join(opt.load_weights_folder, "pose_encoder.pth") 70 | pose_decoder_path = os.path.join(opt.load_weights_folder, "pose.pth") 71 | 72 | pose_encoder = networks.ResnetEncoder(opt.num_layers, False, 2) 73 | pose_encoder.load_state_dict(torch.load(pose_encoder_path)) 74 | 75 | pose_decoder = networks.PoseDecoder(pose_encoder.num_ch_enc, 1, 2) 76 | pose_decoder.load_state_dict(torch.load(pose_decoder_path)) 77 | 78 | pose_encoder.cuda() 79 | pose_encoder.eval() 80 | pose_decoder.cuda() 81 | pose_decoder.eval() 82 | 83 | pred_poses = [] 84 | 85 | print("-> Computing pose predictions") 86 | 87 | opt.frame_ids = [0, 1] # pose network only takes two frames as input 88 | 89 | with torch.no_grad(): 90 | for inputs in dataloader: 91 | for key, ipt in inputs.items(): 92 | inputs[key] = ipt.cuda() 93 | 94 | all_color_aug = torch.cat([inputs[("color_aug", i, 0)] for i in opt.frame_ids], 1) 95 | 96 | features = [pose_encoder(all_color_aug)] 97 | axisangle, translation = pose_decoder(features) 98 | 99 | pred_poses.append( 100 | transformation_from_parameters(axisangle[:, 0], translation[:, 0]).cpu().numpy()) 101 | 102 | pred_poses = np.concatenate(pred_poses) 103 | 104 | gt_poses_path = os.path.join(opt.data_path, "poses", "{:02d}.txt".format(sequence_id)) 105 | gt_global_poses = np.loadtxt(gt_poses_path).reshape(-1, 3, 4) 106 | gt_global_poses = np.concatenate( 107 | (gt_global_poses, np.zeros((gt_global_poses.shape[0], 1, 4))), 1) 108 | gt_global_poses[:, 3, 3] = 1 109 | gt_xyzs = gt_global_poses[:, :3, 3] 110 | 111 | gt_local_poses = [] 112 | for i in range(1, len(gt_global_poses)): 113 | gt_local_poses.append( 114 | np.linalg.inv(np.dot(np.linalg.inv(gt_global_poses[i - 1]), gt_global_poses[i]))) 115 | 116 | ates = [] 117 | num_frames = gt_xyzs.shape[0] 118 | track_length = 5 119 | for i in range(0, num_frames - 1): 120 | local_xyzs = np.array(dump_xyz(pred_poses[i:i + track_length - 1])) 121 | gt_local_xyzs = np.array(dump_xyz(gt_local_poses[i:i + track_length - 1])) 122 | 123 | ates.append(compute_ate(gt_local_xyzs, local_xyzs)) 124 | 125 | print("\n Trajectory error: {:0.3f}, std: {:0.3f}\n".format(np.mean(ates), np.std(ates))) 126 | 127 | save_path = os.path.join(opt.load_weights_folder, "poses.npy") 128 | np.save(save_path, pred_poses) 129 | print("-> Predictions saved to", save_path) 130 | 131 | 132 | if __name__ == "__main__": 133 | options 
= MonodepthOptions() 134 | evaluate(options.parse()) 135 | -------------------------------------------------------------------------------- /monodepth/export_gt_depth.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. All rights reserved. 2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | import os 10 | 11 | import argparse 12 | import numpy as np 13 | import PIL.Image as pil 14 | 15 | from monodepth.utils import readlines 16 | from monodepth.kitti_utils import generate_depth_map 17 | 18 | 19 | def export_gt_depths_kitti(): 20 | 21 | parser = argparse.ArgumentParser(description='export_gt_depth') 22 | 23 | parser.add_argument('--data_path', 24 | type=str, 25 | help='path to the root of the KITTI data', 26 | required=True) 27 | parser.add_argument('--split', 28 | type=str, 29 | help='which split to export gt from', 30 | required=True, 31 | choices=["eigen", "eigen_benchmark"]) 32 | opt = parser.parse_args() 33 | 34 | split_folder = os.path.join(os.path.dirname(__file__), "splits", opt.split) 35 | lines = readlines(os.path.join(split_folder, "test_files.txt")) 36 | 37 | print("Exporting ground truth depths for {}".format(opt.split)) 38 | 39 | gt_depths = [] 40 | for line in lines: 41 | 42 | folder, frame_id, _ = line.split() 43 | frame_id = int(frame_id) 44 | 45 | if opt.split == "eigen": 46 | calib_dir = os.path.join(opt.data_path, folder.split("/")[0]) 47 | velo_filename = os.path.join(opt.data_path, folder, 48 | "velodyne_points/data", "{:010d}.bin".format(frame_id)) 49 | gt_depth = generate_depth_map(calib_dir, velo_filename, 2, True) 50 | elif opt.split == "eigen_benchmark": 51 | gt_depth_path = os.path.join(opt.data_path, folder, "proj_depth", 52 | "groundtruth", "image_02", "{:010d}.png".format(frame_id)) 53 | gt_depth = np.array(pil.open(gt_depth_path)).astype(np.float32) / 256 54 | 55 | gt_depths.append(gt_depth.astype(np.float32)) 56 | 57 | output_path = os.path.join(split_folder, "gt_depths.npz") 58 | 59 | print("Saving to {}".format(opt.split)) 60 | 61 | np.savez_compressed(output_path, data=np.array(gt_depths)) 62 | 63 | 64 | if __name__ == "__main__": 65 | export_gt_depths_kitti() 66 | -------------------------------------------------------------------------------- /monodepth/kitti_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import os 4 | import numpy as np 5 | from collections import Counter 6 | 7 | 8 | def load_velodyne_points(filename): 9 | """Load 3D point cloud from KITTI file format 10 | (adapted from https://github.com/hunse/kitti) 11 | """ 12 | points = np.fromfile(filename, dtype=np.float32).reshape(-1, 4) 13 | points[:, 3] = 1.0 # homogeneous 14 | return points 15 | 16 | 17 | def read_calib_file(path): 18 | """Read KITTI calibration file 19 | (from https://github.com/hunse/kitti) 20 | """ 21 | float_chars = set("0123456789.e+- ") 22 | data = {} 23 | with open(path, 'r') as f: 24 | for line in f.readlines(): 25 | key, value = line.split(':', 1) 26 | value = value.strip() 27 | data[key] = value 28 | if float_chars.issuperset(value): 29 | # try to cast to float array 30 | try: 31 | data[key] = np.array(list(map(float, value.split(' ')))) 32 | 
except ValueError: 33 | # casting error: data[key] already eq. value, so pass 34 | pass 35 | 36 | return data 37 | 38 | 39 | def sub2ind(matrixSize, rowSub, colSub): 40 | """Convert row, col matrix subscripts to linear indices 41 | """ 42 | m, n = matrixSize 43 | return rowSub * (n-1) + colSub - 1 44 | 45 | 46 | def generate_depth_map(calib_dir, velo_filename, cam=2, vel_depth=False): 47 | """Generate a depth map from velodyne data 48 | """ 49 | # load calibration files 50 | cam2cam = read_calib_file(os.path.join(calib_dir, 'calib_cam_to_cam.txt')) 51 | velo2cam = read_calib_file(os.path.join(calib_dir, 'calib_velo_to_cam.txt')) 52 | velo2cam = np.hstack((velo2cam['R'].reshape(3, 3), velo2cam['T'][..., np.newaxis])) 53 | velo2cam = np.vstack((velo2cam, np.array([0, 0, 0, 1.0]))) 54 | 55 | # get image shape 56 | im_shape = cam2cam["S_rect_02"][::-1].astype(np.int32) 57 | 58 | # compute projection matrix velodyne->image plane 59 | R_cam2rect = np.eye(4) 60 | R_cam2rect[:3, :3] = cam2cam['R_rect_00'].reshape(3, 3) 61 | P_rect = cam2cam['P_rect_0'+str(cam)].reshape(3, 4) 62 | P_velo2im = np.dot(np.dot(P_rect, R_cam2rect), velo2cam) 63 | 64 | # load velodyne points and remove all behind image plane (approximation) 65 | # each row of the velodyne data is forward, left, up, reflectance 66 | velo = load_velodyne_points(velo_filename) 67 | velo = velo[velo[:, 0] >= 0, :] 68 | 69 | # project the points to the camera 70 | velo_pts_im = np.dot(P_velo2im, velo.T).T 71 | velo_pts_im[:, :2] = velo_pts_im[:, :2] / velo_pts_im[:, 2][..., np.newaxis] 72 | 73 | if vel_depth: 74 | velo_pts_im[:, 2] = velo[:, 0] 75 | 76 | # check if in bounds 77 | # use minus 1 to get the exact same value as KITTI matlab code 78 | velo_pts_im[:, 0] = np.round(velo_pts_im[:, 0]) - 1 79 | velo_pts_im[:, 1] = np.round(velo_pts_im[:, 1]) - 1 80 | val_inds = (velo_pts_im[:, 0] >= 0) & (velo_pts_im[:, 1] >= 0) 81 | val_inds = val_inds & (velo_pts_im[:, 0] < im_shape[1]) & (velo_pts_im[:, 1] < im_shape[0]) 82 | velo_pts_im = velo_pts_im[val_inds, :] 83 | 84 | # project to image 85 | depth = np.zeros((im_shape[:2])) 86 | depth[velo_pts_im[:, 1].astype(np.int), velo_pts_im[:, 0].astype(np.int)] = velo_pts_im[:, 2] 87 | 88 | # find the duplicate points and choose the closest depth 89 | inds = sub2ind(depth.shape, velo_pts_im[:, 1], velo_pts_im[:, 0]) 90 | dupe_inds = [item for item, count in Counter(inds).items() if count > 1] 91 | for dd in dupe_inds: 92 | pts = np.where(inds == dd)[0] 93 | x_loc = int(velo_pts_im[pts[0], 0]) 94 | y_loc = int(velo_pts_im[pts[0], 1]) 95 | depth[y_loc, x_loc] = velo_pts_im[pts, 2].min() 96 | depth[depth < 0] = 0 97 | 98 | return depth 99 | -------------------------------------------------------------------------------- /monodepth/layers.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. All rights reserved. 2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | import numpy as np 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | def disp_to_depth(disp, min_depth, max_depth): 17 | """Convert network's sigmoid output into depth prediction 18 | The formula for this conversion is given in the 'additional considerations' 19 | section of the paper. 
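Concretely, depth = 1 / (1 / max_depth + (1 / min_depth - 1 / max_depth) * disp),
so disp = 0 maps to max_depth and disp = 1 maps to min_depth.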
20 | """ 21 | min_disp = 1 / max_depth 22 | max_disp = 1 / min_depth 23 | scaled_disp = min_disp + (max_disp - min_disp) * disp 24 | depth = 1 / scaled_disp 25 | return scaled_disp, depth 26 | 27 | 28 | def transformation_from_parameters(axisangle, translation, invert=False): 29 | """Convert the network's (axisangle, translation) output into a 4x4 matrix 30 | """ 31 | R = rot_from_axisangle(axisangle) 32 | t = translation.clone() 33 | 34 | if invert: 35 | R = R.transpose(1, 2) 36 | t *= -1 37 | 38 | T = get_translation_matrix(t) 39 | 40 | if invert: 41 | M = torch.matmul(R, T) 42 | else: 43 | M = torch.matmul(T, R) 44 | 45 | return M 46 | 47 | 48 | def get_translation_matrix(translation_vector): 49 | """Convert a translation vector into a 4x4 transformation matrix 50 | """ 51 | T = torch.zeros(translation_vector.shape[0], 4, 4).to(device=translation_vector.device) 52 | 53 | t = translation_vector.contiguous().view(-1, 3, 1) 54 | 55 | T[:, 0, 0] = 1 56 | T[:, 1, 1] = 1 57 | T[:, 2, 2] = 1 58 | T[:, 3, 3] = 1 59 | T[:, :3, 3, None] = t 60 | 61 | return T 62 | 63 | 64 | def rot_from_axisangle(vec): 65 | """Convert an axisangle rotation into a 4x4 transformation matrix 66 | (adapted from https://github.com/Wallacoloo/printipi) 67 | Input 'vec' has to be Bx1x3 68 | """ 69 | angle = torch.norm(vec, 2, 2, True) 70 | axis = vec / (angle + 1e-7) 71 | 72 | ca = torch.cos(angle) 73 | sa = torch.sin(angle) 74 | C = 1 - ca 75 | 76 | x = axis[..., 0].unsqueeze(1) 77 | y = axis[..., 1].unsqueeze(1) 78 | z = axis[..., 2].unsqueeze(1) 79 | 80 | xs = x * sa 81 | ys = y * sa 82 | zs = z * sa 83 | xC = x * C 84 | yC = y * C 85 | zC = z * C 86 | xyC = x * yC 87 | yzC = y * zC 88 | zxC = z * xC 89 | 90 | rot = torch.zeros((vec.shape[0], 4, 4)).to(device=vec.device) 91 | 92 | rot[:, 0, 0] = torch.squeeze(x * xC + ca) 93 | rot[:, 0, 1] = torch.squeeze(xyC - zs) 94 | rot[:, 0, 2] = torch.squeeze(zxC + ys) 95 | rot[:, 1, 0] = torch.squeeze(xyC + zs) 96 | rot[:, 1, 1] = torch.squeeze(y * yC + ca) 97 | rot[:, 1, 2] = torch.squeeze(yzC - xs) 98 | rot[:, 2, 0] = torch.squeeze(zxC - ys) 99 | rot[:, 2, 1] = torch.squeeze(yzC + xs) 100 | rot[:, 2, 2] = torch.squeeze(z * zC + ca) 101 | rot[:, 3, 3] = 1 102 | 103 | return rot 104 | 105 | 106 | class ConvBlock(nn.Module): 107 | """Layer to perform a convolution followed by ELU 108 | """ 109 | def __init__(self, in_channels, out_channels): 110 | super(ConvBlock, self).__init__() 111 | 112 | self.conv = Conv3x3(in_channels, out_channels) 113 | self.nonlin = nn.ELU(inplace=True) 114 | 115 | def forward(self, x): 116 | out = self.conv(x) 117 | out = self.nonlin(out) 118 | return out 119 | 120 | 121 | class Conv3x3(nn.Module): 122 | """Layer to pad and convolve input 123 | """ 124 | def __init__(self, in_channels, out_channels, use_refl=True): 125 | super(Conv3x3, self).__init__() 126 | 127 | if use_refl: 128 | self.pad = nn.ReflectionPad2d(1) 129 | else: 130 | self.pad = nn.ZeroPad2d(1) 131 | self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3) 132 | 133 | def forward(self, x): 134 | out = self.pad(x) 135 | out = self.conv(out) 136 | return out 137 | 138 | 139 | class BackprojectDepth(nn.Module): 140 | """Layer to transform a depth image into a point cloud 141 | """ 142 | def __init__(self, batch_size, height, width): 143 | super(BackprojectDepth, self).__init__() 144 | 145 | self.batch_size = batch_size 146 | self.height = height 147 | self.width = width 148 | 149 | meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy') 150 | self.id_coords = 
np.stack(meshgrid, axis=0).astype(np.float32) 151 | self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords), 152 | requires_grad=False) 153 | 154 | self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width), 155 | requires_grad=False) 156 | 157 | self.pix_coords = torch.unsqueeze(torch.stack( 158 | [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0) 159 | self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1) 160 | self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, self.ones], 1), 161 | requires_grad=False) 162 | 163 | def forward(self, depth, inv_K): 164 | cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords) 165 | cam_points = depth.view(self.batch_size, 1, -1) * cam_points 166 | cam_points = torch.cat([cam_points, self.ones], 1) 167 | 168 | return cam_points 169 | 170 | 171 | class Project3D(nn.Module): 172 | """Layer which projects 3D points into a camera with intrinsics K and at position T 173 | """ 174 | def __init__(self, batch_size, height, width, eps=1e-7): 175 | super(Project3D, self).__init__() 176 | 177 | self.batch_size = batch_size 178 | self.height = height 179 | self.width = width 180 | self.eps = eps 181 | 182 | def forward(self, points, K, T): 183 | P = torch.matmul(K, T)[:, :3, :] 184 | 185 | cam_points = torch.matmul(P, points) 186 | 187 | pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps) 188 | pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width) 189 | pix_coords = pix_coords.permute(0, 2, 3, 1) 190 | pix_coords[..., 0] /= self.width - 1 191 | pix_coords[..., 1] /= self.height - 1 192 | pix_coords = (pix_coords - 0.5) * 2 193 | return pix_coords 194 | 195 | 196 | def upsample(x): 197 | """Upsample input tensor by a factor of 2 198 | """ 199 | return F.interpolate(x, scale_factor=2, mode="nearest") 200 | 201 | 202 | def get_smooth_loss(disp, img): 203 | """Computes the smoothness loss for a disparity image 204 | The color image is used for edge-aware smoothness 205 | """ 206 | grad_disp_x = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:]) 207 | grad_disp_y = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :]) 208 | 209 | grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True) 210 | grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True) 211 | 212 | grad_disp_x *= torch.exp(-grad_img_x) 213 | grad_disp_y *= torch.exp(-grad_img_y) 214 | 215 | return grad_disp_x.mean() + grad_disp_y.mean() 216 | 217 | 218 | class SSIM(nn.Module): 219 | """Layer to compute the SSIM loss between a pair of images 220 | """ 221 | def __init__(self): 222 | super(SSIM, self).__init__() 223 | self.mu_x_pool = nn.AvgPool2d(3, 1) 224 | self.mu_y_pool = nn.AvgPool2d(3, 1) 225 | self.sig_x_pool = nn.AvgPool2d(3, 1) 226 | self.sig_y_pool = nn.AvgPool2d(3, 1) 227 | self.sig_xy_pool = nn.AvgPool2d(3, 1) 228 | 229 | self.refl = nn.ReflectionPad2d(1) 230 | 231 | self.C1 = 0.01 ** 2 232 | self.C2 = 0.03 ** 2 233 | 234 | def forward(self, x, y): 235 | x = self.refl(x) 236 | y = self.refl(y) 237 | 238 | mu_x = self.mu_x_pool(x) 239 | mu_y = self.mu_y_pool(y) 240 | 241 | sigma_x = self.sig_x_pool(x ** 2) - mu_x ** 2 242 | sigma_y = self.sig_y_pool(y ** 2) - mu_y ** 2 243 | sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y 244 | 245 | SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2) 246 | SSIM_d = (mu_x ** 2 + mu_y ** 2 + self.C1) * (sigma_x + sigma_y + self.C2) 247 | 248 | return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 
0, 1) 249 | 250 | 251 | def compute_depth_errors(gt, pred): 252 | """Computation of error metrics between predicted and ground truth depths 253 | """ 254 | thresh = torch.max((gt / pred), (pred / gt)) 255 | a1 = (thresh < 1.25 ).float().mean() 256 | a2 = (thresh < 1.25 ** 2).float().mean() 257 | a3 = (thresh < 1.25 ** 3).float().mean() 258 | 259 | rmse = (gt - pred) ** 2 260 | rmse = torch.sqrt(rmse.mean()) 261 | 262 | rmse_log = (torch.log(gt) - torch.log(pred)) ** 2 263 | rmse_log = torch.sqrt(rmse_log.mean()) 264 | 265 | abs_rel = torch.mean(torch.abs(gt - pred) / gt) 266 | 267 | sq_rel = torch.mean((gt - pred) ** 2 / gt) 268 | 269 | return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3 270 | -------------------------------------------------------------------------------- /monodepth/networks/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet_encoder import ResnetEncoder 2 | from .depth_decoder import DepthDecoder 3 | from .pose_decoder import PoseDecoder 4 | from .pose_cnn import PoseCNN 5 | -------------------------------------------------------------------------------- /monodepth/networks/depth_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. All rights reserved. 2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | from collections import OrderedDict 10 | from monodepth.layers import * 11 | 12 | 13 | class DepthDecoder(nn.Module): 14 | def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True): 15 | super(DepthDecoder, self).__init__() 16 | 17 | self.num_output_channels = num_output_channels 18 | self.use_skips = use_skips 19 | self.upsample_mode = 'nearest' 20 | self.scales = scales 21 | 22 | self.num_ch_enc = num_ch_enc 23 | self.num_ch_dec = np.array([16, 32, 64, 128, 256]) 24 | 25 | # decoder 26 | self.convs = OrderedDict() 27 | for i in range(4, -1, -1): 28 | # upconv_0 29 | num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1] 30 | num_ch_out = self.num_ch_dec[i] 31 | self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out) 32 | 33 | # upconv_1 34 | num_ch_in = self.num_ch_dec[i] 35 | if self.use_skips and i > 0: 36 | num_ch_in += self.num_ch_enc[i - 1] 37 | num_ch_out = self.num_ch_dec[i] 38 | self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out) 39 | 40 | for s in self.scales: 41 | self.convs[("dispconv", s)] = Conv3x3(self.num_ch_dec[s], self.num_output_channels) 42 | 43 | self.decoder = nn.ModuleList(list(self.convs.values())) 44 | self.sigmoid = nn.Sigmoid() 45 | 46 | def forward(self, input_features): 47 | self.outputs = {} 48 | 49 | # decoder 50 | x = input_features[-1] 51 | for i in range(4, -1, -1): 52 | x = self.convs[("upconv", i, 0)](x) 53 | x = [upsample(x)] 54 | if self.use_skips and i > 0: 55 | x += [input_features[i - 1]] 56 | x = torch.cat(x, 1) 57 | x = self.convs[("upconv", i, 1)](x) 58 | if i in self.scales: 59 | self.outputs[("disp", i)] = self.sigmoid(self.convs[("dispconv", i)](x)) 60 | 61 | return self.outputs 62 | -------------------------------------------------------------------------------- /monodepth/networks/pose_cnn.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 
2019. Patent Pending. All rights reserved. 2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | 13 | class PoseCNN(nn.Module): 14 | def __init__(self, num_input_frames): 15 | super(PoseCNN, self).__init__() 16 | 17 | self.num_input_frames = num_input_frames 18 | 19 | self.convs = {} 20 | self.convs[0] = nn.Conv2d(3 * num_input_frames, 16, 7, 2, 3) 21 | self.convs[1] = nn.Conv2d(16, 32, 5, 2, 2) 22 | self.convs[2] = nn.Conv2d(32, 64, 3, 2, 1) 23 | self.convs[3] = nn.Conv2d(64, 128, 3, 2, 1) 24 | self.convs[4] = nn.Conv2d(128, 256, 3, 2, 1) 25 | self.convs[5] = nn.Conv2d(256, 256, 3, 2, 1) 26 | self.convs[6] = nn.Conv2d(256, 256, 3, 2, 1) 27 | 28 | self.pose_conv = nn.Conv2d(256, 6 * (num_input_frames - 1), 1) 29 | 30 | self.num_convs = len(self.convs) 31 | 32 | self.relu = nn.ReLU(True) 33 | 34 | self.net = nn.ModuleList(list(self.convs.values())) 35 | 36 | def forward(self, out): 37 | 38 | for i in range(self.num_convs): 39 | out = self.convs[i](out) 40 | out = self.relu(out) 41 | 42 | out = self.pose_conv(out) 43 | out = out.mean(3).mean(2) 44 | 45 | out = 0.01 * out.view(-1, self.num_input_frames - 1, 1, 6) 46 | 47 | axisangle = out[..., :3] 48 | translation = out[..., 3:] 49 | 50 | return axisangle, translation 51 | -------------------------------------------------------------------------------- /monodepth/networks/pose_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. All rights reserved. 2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 
6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | import torch 10 | import torch.nn as nn 11 | from collections import OrderedDict 12 | 13 | 14 | class PoseDecoder(nn.Module): 15 | def __init__(self, num_ch_enc, num_input_features, num_frames_to_predict_for=None, stride=1): 16 | super(PoseDecoder, self).__init__() 17 | 18 | self.num_ch_enc = num_ch_enc 19 | self.num_input_features = num_input_features 20 | 21 | if num_frames_to_predict_for is None: 22 | num_frames_to_predict_for = num_input_features - 1 23 | self.num_frames_to_predict_for = num_frames_to_predict_for 24 | 25 | self.convs = OrderedDict() 26 | self.convs[("squeeze")] = nn.Conv2d(self.num_ch_enc[-1], 256, 1) 27 | self.convs[("pose", 0)] = nn.Conv2d(num_input_features * 256, 256, 3, stride, 1) 28 | self.convs[("pose", 1)] = nn.Conv2d(256, 256, 3, stride, 1) 29 | self.convs[("pose", 2)] = nn.Conv2d(256, 6 * num_frames_to_predict_for, 1) 30 | 31 | self.relu = nn.ReLU() 32 | 33 | self.net = nn.ModuleList(list(self.convs.values())) 34 | 35 | def forward(self, input_features): 36 | last_features = [f[-1] for f in input_features] 37 | 38 | cat_features = [self.relu(self.convs["squeeze"](f)) for f in last_features] 39 | cat_features = torch.cat(cat_features, 1) 40 | 41 | out = cat_features 42 | for i in range(3): 43 | out = self.convs[("pose", i)](out) 44 | if i != 2: 45 | out = self.relu(out) 46 | 47 | out = out.mean(3).mean(2) 48 | 49 | out = 0.01 * out.view(-1, self.num_frames_to_predict_for, 1, 6) 50 | 51 | axisangle = out[..., :3] 52 | translation = out[..., 3:] 53 | 54 | return axisangle, translation 55 | -------------------------------------------------------------------------------- /monodepth/networks/resnet_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. All rights reserved. 2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | import numpy as np 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torchvision.models as models 14 | import torch.utils.model_zoo as model_zoo 15 | 16 | 17 | class ResNetMultiImageInput(models.ResNet): 18 | """Constructs a resnet model with varying number of input images. 
19 | Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 20 | """ 21 | def __init__(self, block, layers, num_classes=1000, num_input_images=1): 22 | super(ResNetMultiImageInput, self).__init__(block, layers) 23 | self.inplanes = 64 24 | self.conv1 = nn.Conv2d( 25 | num_input_images * 3, 64, kernel_size=7, stride=2, padding=3, bias=False) 26 | self.bn1 = nn.BatchNorm2d(64) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 29 | self.layer1 = self._make_layer(block, 64, layers[0]) 30 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 31 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 32 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 33 | 34 | for m in self.modules(): 35 | if isinstance(m, nn.Conv2d): 36 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 37 | elif isinstance(m, nn.BatchNorm2d): 38 | nn.init.constant_(m.weight, 1) 39 | nn.init.constant_(m.bias, 0) 40 | 41 | 42 | def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1): 43 | """Constructs a ResNet model. 44 | Args: 45 | num_layers (int): Number of resnet layers. Must be 18 or 50 46 | pretrained (bool): If True, returns a model pre-trained on ImageNet 47 | num_input_images (int): Number of frames stacked as input 48 | """ 49 | assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet" 50 | blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers] 51 | block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers] 52 | model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images) 53 | 54 | if pretrained: 55 | loaded = model_zoo.load_url(models.resnet.model_urls['resnet{}'.format(num_layers)]) 56 | loaded['conv1.weight'] = torch.cat( 57 | [loaded['conv1.weight']] * num_input_images, 1) / num_input_images 58 | model.load_state_dict(loaded) 59 | return model 60 | 61 | 62 | class ResnetEncoder(nn.Module): 63 | """Pytorch module for a resnet encoder 64 | """ 65 | def __init__(self, num_layers, pretrained, num_input_images=1): 66 | super(ResnetEncoder, self).__init__() 67 | 68 | self.num_ch_enc = np.array([64, 64, 128, 256, 512]) 69 | 70 | resnets = {18: models.resnet18, 71 | 34: models.resnet34, 72 | 50: models.resnet50, 73 | 101: models.resnet101, 74 | 152: models.resnet152} 75 | 76 | if num_layers not in resnets: 77 | raise ValueError("{} is not a valid number of resnet layers".format(num_layers)) 78 | 79 | if num_input_images > 1: 80 | self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images) 81 | else: 82 | self.encoder = resnets[num_layers](pretrained) 83 | 84 | if num_layers > 34: 85 | self.num_ch_enc[1:] *= 4 86 | 87 | def forward(self, input_image): 88 | self.features = [] 89 | x = input_image 90 | x = self.encoder.conv1(x) 91 | x = self.encoder.bn1(x) 92 | self.features.append(self.encoder.relu(x)) 93 | self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1]))) 94 | self.features.append(self.encoder.layer2(self.features[-1])) 95 | self.features.append(self.encoder.layer3(self.features[-1])) 96 | self.features.append(self.encoder.layer4(self.features[-1])) 97 | 98 | return self.features 99 | -------------------------------------------------------------------------------- /monodepth/options.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. 
All rights reserved. 2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | import os 10 | import argparse 11 | 12 | file_dir = os.path.dirname(__file__) # the directory that options.py resides in 13 | 14 | 15 | class MonodepthOptions: 16 | def __init__(self): 17 | self.parser = argparse.ArgumentParser(description="Monodepthv2 options") 18 | 19 | # PATHS 20 | self.parser.add_argument("--data_path", 21 | type=str, 22 | help="path to the training data", 23 | default='') 24 | self.parser.add_argument("--log_dir", 25 | type=str, 26 | help="log directory", 27 | default='') 28 | 29 | # TRAINING options 30 | self.parser.add_argument("--model_name", 31 | type=str, 32 | help="the name of the folder to save the model in", 33 | default="debug") 34 | self.parser.add_argument("--split", 35 | type=str, 36 | help="which training split to use", 37 | choices=["eigen_zhou", "eigen_full", "odom", "benchmark", "apollo"], 38 | default="eigen_zhou") 39 | self.parser.add_argument("--num_layers", 40 | type=int, 41 | help="number of resnet layers", 42 | default=18, 43 | choices=[18, 34, 50, 101, 152]) 44 | self.parser.add_argument("--dataset", 45 | type=str, 46 | help="dataset to train on", 47 | default="kitti", 48 | choices=["kitti", "kitti_odom", "kitti_depth", "kitti_test", "apollo"]) 49 | self.parser.add_argument("--png", 50 | help="if set, trains from raw KITTI png files (instead of jpgs)", 51 | action="store_true") 52 | self.parser.add_argument("--height", 53 | type=int, 54 | help="input image height", 55 | default=192) 56 | self.parser.add_argument("--width", 57 | type=int, 58 | help="input image width", 59 | default=640) 60 | self.parser.add_argument("--disparity_smoothness", 61 | type=float, 62 | help="disparity smoothness weight", 63 | default=1e-3) 64 | self.parser.add_argument("--scales", 65 | nargs="+", 66 | type=int, 67 | help="scales used in the loss", 68 | default=[0, 1, 2]) 69 | self.parser.add_argument("--min_depth", 70 | type=float, 71 | help="minimum depth", 72 | default=0.1) 73 | self.parser.add_argument("--max_depth", 74 | type=float, 75 | help="maximum depth", 76 | default=100.0) 77 | self.parser.add_argument("--use_stereo", 78 | help="if set, uses stereo pair for training", 79 | action="store_true") 80 | self.parser.add_argument("--frame_ids", 81 | nargs="+", 82 | type=int, 83 | help="frames to load", 84 | default=[0, -1, 1]) 85 | self.parser.add_argument("--own_resnet", 86 | help="if set, use own resnet encoder/decoder", 87 | action="store_true") 88 | 89 | # OPTIMIZATION options 90 | self.parser.add_argument("--batch_size", 91 | type=int, 92 | help="batch size", 93 | default=12) 94 | self.parser.add_argument("--learning_rate", 95 | type=float, 96 | help="learning rate", 97 | default=1e-4) 98 | self.parser.add_argument("--num_epochs", 99 | type=int, 100 | help="number of epochs", 101 | default=20) 102 | self.parser.add_argument("--scheduler_step_size", 103 | type=int, 104 | help="step size of the scheduler", 105 | default=15) 106 | 107 | # ABLATION options 108 | self.parser.add_argument("--v1_multiscale", 109 | help="if set, uses monodepth v1 multiscale", 110 | action="store_true") 111 | self.parser.add_argument("--avg_reprojection", 112 | help="if set, uses average reprojection loss", 113 | action="store_true") 114 | self.parser.add_argument("--disable_automasking", 115 
| help="if set, doesn't do auto-masking", 116 | action="store_true") 117 | self.parser.add_argument("--predictive_mask", 118 | help="if set, uses a predictive masking scheme as in Zhou et al", 119 | action="store_true") 120 | self.parser.add_argument("--no_ssim", 121 | help="if set, disables ssim in the loss", 122 | action="store_true") 123 | self.parser.add_argument("--weights_init", 124 | type=str, 125 | help="pretrained or scratch", 126 | default="pretrained", 127 | choices=["pretrained", "scratch"]) 128 | self.parser.add_argument("--pose_model_input", 129 | type=str, 130 | help="how many images the pose network gets", 131 | default="pairs", 132 | choices=["pairs", "all"]) 133 | self.parser.add_argument("--pose_model_type", 134 | type=str, 135 | help="normal or shared", 136 | default="separate_resnet", 137 | choices=["posecnn", "separate_resnet", "shared"]) 138 | 139 | # SYSTEM options 140 | self.parser.add_argument("--no_cuda", 141 | help="if set disables CUDA", 142 | action="store_true") 143 | self.parser.add_argument("--num_workers", 144 | type=int, 145 | help="number of dataloader workers", 146 | default=6) 147 | 148 | # LOADING options 149 | self.parser.add_argument("--load_weights_folder", 150 | type=str, 151 | help="name of model to load") 152 | self.parser.add_argument("--models_to_load", 153 | nargs="+", 154 | type=str, 155 | help="models to load", 156 | default=["encoder", "depth", "pose_encoder", "pose"]) 157 | 158 | # LOGGING options 159 | self.parser.add_argument("--log_frequency", 160 | type=int, 161 | help="number of batches between each tensorboard log", 162 | default=250) 163 | self.parser.add_argument("--save_frequency", 164 | type=int, 165 | help="number of epochs between each save", 166 | default=1) 167 | 168 | # EVALUATION options 169 | self.parser.add_argument("--eval_stereo", 170 | help="if set evaluates in stereo mode", 171 | action="store_true") 172 | self.parser.add_argument("--eval_mono", 173 | help="if set evaluates in mono mode", 174 | action="store_true") 175 | self.parser.add_argument("--disable_median_scaling", 176 | help="if set disables median scaling in evaluation", 177 | action="store_true") 178 | self.parser.add_argument("--pred_depth_scale_factor", 179 | help="if set multiplies predictions by this number", 180 | type=float, 181 | default=1) 182 | self.parser.add_argument("--ext_disp_to_eval", 183 | type=str, 184 | help="optional path to a .npy disparities file to evaluate") 185 | self.parser.add_argument("--eval_split", 186 | type=str, 187 | default="eigen", 188 | choices=[ 189 | "eigen", "eigen_benchmark", "benchmark", "odom_9", "odom_10"], 190 | help="which split to run eval on") 191 | self.parser.add_argument("--save_pred_disps", 192 | help="if set saves predicted disparities", 193 | action="store_true") 194 | self.parser.add_argument("--no_eval", 195 | help="if set disables evaluation", 196 | action="store_true") 197 | self.parser.add_argument("--eval_eigen_to_benchmark", 198 | help="if set assume we are loading eigen results from npy but " 199 | "we want to evaluate using the new benchmark.", 200 | action="store_true") 201 | self.parser.add_argument("--eval_out_dir", 202 | help="if set will output the disparities to this folder", 203 | type=str) 204 | self.parser.add_argument("--post_process", 205 | help="if set will perform the flipping post processing " 206 | "from the original monodepth paper", 207 | action="store_true") 208 | 209 | def parse(self): 210 | self.options = self.parser.parse_args() 211 | return self.options 212 | 
-------------------------------------------------------------------------------- /monodepth/test_simple.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. All rights reserved. 2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | import os 10 | import sys 11 | import glob 12 | import argparse 13 | import numpy as np 14 | import PIL.Image as pil 15 | import matplotlib as mpl 16 | import matplotlib.cm as cm 17 | 18 | import torch 19 | from torchvision import transforms, datasets 20 | 21 | import monodepth.networks as networks 22 | from monodepth.layers import disp_to_depth 23 | from monodepth.utils import download_model_if_doesnt_exist 24 | 25 | 26 | def parse_args(): 27 | parser = argparse.ArgumentParser( 28 | description='Simple testing funtion for Monodepthv2 models.') 29 | 30 | parser.add_argument('--image_path', type=str, 31 | help='path to a test image or folder of images', required=True) 32 | parser.add_argument('--model_name', type=str, 33 | help='name of a pretrained model to use', 34 | choices=[ 35 | "mono_640x192", 36 | "stereo_640x192", 37 | "mono+stereo_640x192", 38 | "mono_no_pt_640x192", 39 | "stereo_no_pt_640x192", 40 | "mono+stereo_no_pt_640x192", 41 | "mono_1024x320", 42 | "stereo_1024x320", 43 | "mono+stereo_1024x320"]) 44 | parser.add_argument('--ext', type=str, 45 | help='image extension to search for in folder', default="jpg") 46 | parser.add_argument("--no_cuda", 47 | help='if set, disables CUDA', 48 | action='store_true') 49 | 50 | return parser.parse_args() 51 | 52 | 53 | def test_simple(args): 54 | """Function to predict for a single image or folder of images 55 | """ 56 | assert args.model_name is not None, \ 57 | "You must specify the --model_name parameter; see README.md for an example" 58 | 59 | if torch.cuda.is_available() and not args.no_cuda: 60 | device = torch.device("cuda") 61 | else: 62 | device = torch.device("cpu") 63 | 64 | download_model_if_doesnt_exist(args.model_name) 65 | model_path = os.path.join("models", args.model_name) 66 | print("-> Loading model from ", model_path) 67 | encoder_path = os.path.join(model_path, "encoder.pth") 68 | depth_decoder_path = os.path.join(model_path, "depth.pth") 69 | 70 | # LOADING PRETRAINED MODEL 71 | print(" Loading pretrained encoder") 72 | encoder = networks.ResnetEncoder(18, False) 73 | loaded_dict_enc = torch.load(encoder_path, map_location=device) 74 | 75 | # extract the height and width of image that this model was trained with 76 | feed_height = loaded_dict_enc['height'] 77 | feed_width = loaded_dict_enc['width'] 78 | filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in encoder.state_dict()} 79 | encoder.load_state_dict(filtered_dict_enc) 80 | encoder.to(device) 81 | encoder.eval() 82 | 83 | print(" Loading pretrained decoder") 84 | depth_decoder = networks.DepthDecoder( 85 | num_ch_enc=encoder.num_ch_enc, scales=range(4)) 86 | 87 | loaded_dict = torch.load(depth_decoder_path, map_location=device) 88 | depth_decoder.load_state_dict(loaded_dict) 89 | 90 | depth_decoder.to(device) 91 | depth_decoder.eval() 92 | 93 | # FINDING INPUT IMAGES 94 | if os.path.isfile(args.image_path): 95 | # Only testing on a single image 96 | paths = [args.image_path] 97 | output_directory = 
os.path.dirname(args.image_path) 98 | elif os.path.isdir(args.image_path): 99 | # Searching folder for images 100 | paths = glob.glob(os.path.join(args.image_path, '*.{}'.format(args.ext))) 101 | output_directory = args.image_path 102 | else: 103 | raise Exception("Can not find args.image_path: {}".format(args.image_path)) 104 | 105 | print("-> Predicting on {:d} test images".format(len(paths))) 106 | 107 | # PREDICTING ON EACH IMAGE IN TURN 108 | with torch.no_grad(): 109 | for idx, image_path in enumerate(paths): 110 | 111 | if image_path.endswith("_disp.jpg"): 112 | # don't try to predict disparity for a disparity image! 113 | continue 114 | 115 | # Load image and preprocess 116 | input_image = pil.open(image_path).convert('RGB') 117 | original_width, original_height = input_image.size 118 | input_image = input_image.resize((feed_width, feed_height), pil.LANCZOS) 119 | input_image = transforms.ToTensor()(input_image).unsqueeze(0) 120 | 121 | # PREDICTION 122 | input_image = input_image.to(device) 123 | features = encoder(input_image) 124 | outputs = depth_decoder(features) 125 | 126 | disp = outputs[("disp", 0)] 127 | disp_resized = torch.nn.functional.interpolate( 128 | disp, (original_height, original_width), mode="bilinear", align_corners=False) 129 | 130 | # Saving numpy file 131 | output_name = os.path.splitext(os.path.basename(image_path))[0] 132 | name_dest_npy = os.path.join(output_directory, "{}_disp.npy".format(output_name)) 133 | scaled_disp, _ = disp_to_depth(disp, 0.1, 100) 134 | np.save(name_dest_npy, scaled_disp.cpu().numpy()) 135 | 136 | # Saving colormapped depth image 137 | disp_resized_np = disp_resized.squeeze().cpu().numpy() 138 | vmax = np.percentile(disp_resized_np, 95) 139 | normalizer = mpl.colors.Normalize(vmin=disp_resized_np.min(), vmax=vmax) 140 | mapper = cm.ScalarMappable(norm=normalizer, cmap='magma') 141 | colormapped_im = (mapper.to_rgba(disp_resized_np)[:, :, :3] * 255).astype(np.uint8) 142 | im = pil.fromarray(colormapped_im) 143 | 144 | name_dest_im = os.path.join(output_directory, "{}_disp.jpeg".format(output_name)) 145 | im.save(name_dest_im) 146 | 147 | print(" Processed {:d} of {:d} images - saved prediction to {}".format( 148 | idx + 1, len(paths), name_dest_im)) 149 | 150 | print('-> Done!') 151 | 152 | 153 | if __name__ == '__main__': 154 | args = parse_args() 155 | test_simple(args) 156 | -------------------------------------------------------------------------------- /monodepth/train.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. All rights reserved. 2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | from monodepth.trainer import Trainer 10 | from monodepth.options import MonodepthOptions 11 | 12 | options = MonodepthOptions() 13 | opts = options.parse() 14 | 15 | 16 | if __name__ == "__main__": 17 | trainer = Trainer(opts) 18 | trainer.train() 19 | -------------------------------------------------------------------------------- /monodepth/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. All rights reserved. 
2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 6 | 7 | from __future__ import absolute_import, division, print_function 8 | import os 9 | import hashlib 10 | import zipfile 11 | from six.moves import urllib 12 | 13 | 14 | def readlines(filename): 15 | """Read all the lines in a text file and return as a list 16 | """ 17 | with open(filename, 'r') as f: 18 | lines = f.read().splitlines() 19 | return lines 20 | 21 | 22 | def normalize_image(x): 23 | """Rescale image pixels to span range [0, 1] 24 | """ 25 | ma = float(x.max().cpu().data) 26 | mi = float(x.min().cpu().data) 27 | d = ma - mi if ma != mi else 1e5 28 | return (x - mi) / d 29 | 30 | 31 | def sec_to_hm(t): 32 | """Convert time in seconds to time in hours, minutes and seconds 33 | e.g. 10239 -> (2, 50, 39) 34 | """ 35 | t = int(t) 36 | s = t % 60 37 | t //= 60 38 | m = t % 60 39 | t //= 60 40 | return t, m, s 41 | 42 | 43 | def sec_to_hm_str(t): 44 | """Convert time in seconds to a nice string 45 | e.g. 10239 -> '02h50m39s' 46 | """ 47 | h, m, s = sec_to_hm(t) 48 | return "{:02d}h{:02d}m{:02d}s".format(h, m, s) 49 | 50 | 51 | def download_model_if_doesnt_exist(model_name): 52 | """If pretrained kitti model doesn't exist, download and unzip it 53 | """ 54 | # values are tuples of (, ) 55 | download_paths = { 56 | "mono_640x192": 57 | ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono_640x192.zip", 58 | "a964b8356e08a02d009609d9e3928f7c"), 59 | "stereo_640x192": 60 | ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/stereo_640x192.zip", 61 | "3dfb76bcff0786e4ec07ac00f658dd07"), 62 | "mono+stereo_640x192": 63 | ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono%2Bstereo_640x192.zip", 64 | "c024d69012485ed05d7eaa9617a96b81"), 65 | "mono_no_pt_640x192": 66 | ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono_no_pt_640x192.zip", 67 | "9c2f071e35027c895a4728358ffc913a"), 68 | "stereo_no_pt_640x192": 69 | ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/stereo_no_pt_640x192.zip", 70 | "41ec2de112905f85541ac33a854742d1"), 71 | "mono+stereo_no_pt_640x192": 72 | ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono%2Bstereo_no_pt_640x192.zip", 73 | "46c3b824f541d143a45c37df65fbab0a"), 74 | "mono_1024x320": 75 | ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono_1024x320.zip", 76 | "0ab0766efdfeea89a0d9ea8ba90e1e63"), 77 | "stereo_1024x320": 78 | ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/stereo_1024x320.zip", 79 | "afc2f2126d70cf3fdf26b550898b501a"), 80 | "mono+stereo_1024x320": 81 | ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono%2Bstereo_1024x320.zip", 82 | "cdc5fc9b23513c07d5b19235d9ef08f7"), 83 | } 84 | 85 | if not os.path.exists("models"): 86 | os.makedirs("models") 87 | 88 | model_path = os.path.join("models", model_name) 89 | 90 | def check_file_matches_md5(checksum, fpath): 91 | if not os.path.exists(fpath): 92 | return False 93 | with open(fpath, 'rb') as f: 94 | current_md5checksum = hashlib.md5(f.read()).hexdigest() 95 | return current_md5checksum == checksum 96 | 97 | # see if we have the model already downloaded... 
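    # Flow below: when <model_path>/encoder.pth is missing, download <model_path>.zip,
    # check its MD5 against the table above, and unzip it into models/<model_name>.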
98 | if not os.path.exists(os.path.join(model_path, "encoder.pth")): 99 | 100 | model_url, required_md5checksum = download_paths[model_name] 101 | 102 | if not check_file_matches_md5(required_md5checksum, model_path + ".zip"): 103 | print("-> Downloading pretrained model to {}".format(model_path + ".zip")) 104 | urllib.request.urlretrieve(model_url, model_path + ".zip") 105 | 106 | if not check_file_matches_md5(required_md5checksum, model_path + ".zip"): 107 | print(" Failed to download a file which matches the checksum - quitting") 108 | quit() 109 | 110 | print(" Unzipping model...") 111 | with zipfile.ZipFile(model_path + ".zip", 'r') as f: 112 | f.extractall(model_path) 113 | 114 | print(" Model unzipped to {}".format(model_path)) 115 | -------------------------------------------------------------------------------- /neural/dynamics/dynamics_factory.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from neural.layers_3d import CausalConv3d, Bottleneck3D 4 | from neural.layers import NormActivation 5 | from neural.utils import print_model_spec 6 | 7 | 8 | class TemporalModel(nn.Module): 9 | def __init__(self, input_features=128, model_name='baseline'): 10 | super().__init__() 11 | self.model_name = model_name 12 | self.model, receptive_field = DynamicsFactory(model_name=model_name, input_features=input_features).get_model() 13 | self.receptive_field = receptive_field 14 | self.output_features = input_features 15 | 16 | print_model_spec(self, 'Temporal') 17 | 18 | def forward(self, x): 19 | if self.model_name == 'gru': 20 | return self.model(x) 21 | 22 | x = x.permute(0, 2, 1, 3, 4) 23 | x = self.model(x) 24 | x = x.permute(0, 2, 1, 3, 4).contiguous() 25 | return x 26 | 27 | 28 | class DynamicsFactory: 29 | def __init__(self, model_name, input_features): 30 | self.model_name = model_name 31 | self.input_features = input_features 32 | self.receptive_field = 1 33 | 34 | def get_model(self): 35 | input_features = self.input_features 36 | if self.model_name == 'no_temporal': 37 | return nn.Sequential(), self.receptive_field 38 | elif self.model_name == 'baseline': 39 | self.receptive_field = 2 40 | return nn.Sequential(CausalConv3d(input_features, input_features, (2, 3, 3), dilation=(1, 1, 1)), 41 | NormActivation(input_features, dimension='3d', activation='leaky_relu'), 42 | CausalConv3d(input_features, input_features, (1, 3, 3), dilation=(1, 1, 1)), 43 | NormActivation(input_features, dimension='3d', activation='leaky_relu'), 44 | CausalConv3d(input_features, input_features, (1, 3, 3), dilation=(1, 1, 1)), 45 | NormActivation(input_features, dimension='3d', activation='leaky_relu'), 46 | ), self.receptive_field 47 | elif self.model_name == 'small': 48 | self.receptive_field = 3 49 | model = [] 50 | for i in range(1): 51 | model.append(Bottleneck3D(input_features, kernel_size=(2, 3, 3), dilation=(1, 1, 1))) 52 | for i in range(10): 53 | model.append(Bottleneck3D(input_features, kernel_size=(1, 3, 3), dilation=(1, 1, 1))) 54 | model.append(Bottleneck3D(input_features, kernel_size=(2, 3, 3), dilation=(1, 1, 1))) 55 | 56 | return nn.Sequential(*model), self.receptive_field 57 | elif self.model_name == 'medium': 58 | self.receptive_field = 3 59 | model = [] 60 | for i in range(1): 61 | model.append(Bottleneck3D(input_features, kernel_size=(2, 3, 3), dilation=(1, 1, 1))) 62 | for i in range(20): 63 | model.append(Bottleneck3D(input_features, kernel_size=(1, 3, 3), dilation=(1, 1, 1))) 64 | 65 | 
model.append(Bottleneck3D(input_features, kernel_size=(2, 3, 3), dilation=(1, 1, 1))) 66 | 67 | return nn.Sequential(*model), self.receptive_field 68 | elif self.model_name == 'large': 69 | self.receptive_field = 3 70 | model = [] 71 | for i in range(1): 72 | model.append(Bottleneck3D(input_features, kernel_size=(2, 3, 3), dilation=(1, 1, 1))) 73 | for i in range(40): 74 | model.append(Bottleneck3D(input_features, kernel_size=(1, 3, 3), dilation=(1, 1, 1))) 75 | 76 | model.append(Bottleneck3D(input_features, kernel_size=(2, 3, 3), dilation=(1, 1, 1))) 77 | 78 | return nn.Sequential(*model), self.receptive_field 79 | else: 80 | raise ValueError('Dynamics model {} not implemented.'.format(self.model_name)) 81 | 82 | -------------------------------------------------------------------------------- /neural/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torch.nn as nn 4 | 5 | from functools import partial 6 | 7 | 8 | class NormActivation(nn.Module): 9 | def __init__(self, num_features, dimension='2d', activation='none', momentum=0.05, slope=0.01): 10 | super().__init__() 11 | 12 | if dimension == '1d': 13 | self.norm = nn.BatchNorm1d(num_features=num_features, momentum=momentum) 14 | elif dimension =='2d': 15 | self.norm = nn.BatchNorm2d(num_features=num_features, momentum=momentum) 16 | elif dimension == '3d': 17 | self.norm = nn.BatchNorm3d(num_features=num_features, momentum=momentum) 18 | else: 19 | raise ValueError('Dimension={} not handled.'.format(dimension)) 20 | 21 | if activation == "relu": 22 | self.activation_fn = lambda x: nn.functional.relu(x, inplace=True) 23 | elif activation == "leaky_relu": 24 | self.activation_fn = lambda x: nn.functional.leaky_relu(x, negative_slope=slope, inplace=True) 25 | elif activation == "elu": 26 | self.activation_fn = lambda x: nn.functional.elu(x, inplace=True) 27 | elif activation == "none": 28 | self.activation_fn = lambda x: x 29 | else: 30 | raise ValueError('Activation={} not handled.'.format(self.activation)) 31 | 32 | def forward(self, x): 33 | x = self.norm(x) 34 | x = self.activation_fn(x) 35 | return x 36 | 37 | 38 | class ConvBlock(nn.Module): 39 | """ Conv and optional (BN - ReLU) 40 | """ 41 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, norm='none', activation='none', bias=False, 42 | transpose=False): 43 | super().__init__() 44 | padding = int((kernel_size - 1) / 2) 45 | self.conv = nn.Conv2d if not transpose else partial(nn.ConvTranspose2d, output_padding=1) 46 | 47 | if norm == 'in': 48 | self.norm = nn.InstanceNorm2d(out_channels) 49 | elif norm =='bn': 50 | self.norm = nn.BatchNorm2d(out_channels) 51 | elif norm == 'none': 52 | self.norm = None 53 | else: 54 | raise ValueError('Not recognised norm {}'.format(norm)) 55 | 56 | if activation == 'lrelu': 57 | self.activation = nn.LeakyReLU(0.2) 58 | elif activation == 'relu': 59 | self.activation = nn.ReLU() 60 | elif activation == 'tanh': 61 | self.activation = nn.Tanh() 62 | elif activation == 'none': 63 | self.activation = None 64 | else: 65 | raise ValueError('Not recognised activation {}'.format(activation)) 66 | 67 | self.conv = self.conv(in_channels, out_channels, kernel_size, stride, padding=padding, bias=bias) 68 | 69 | def forward(self, x): 70 | x = self.conv(x) 71 | 72 | if self.norm is not None: 73 | x = self.norm(x) 74 | 75 | if self.activation is not None: 76 | x = self.activation(x) 77 | 78 | return x 79 | 80 | 81 | class ResBlock(nn.Module): 82 | """ Conv - BN - ReLU - 
Conv - BN - ADD and then ReLU 83 | Same number of channels in and out. 84 | """ 85 | def __init__(self, channels, norm='in', activation='lrelu', bias=False, last_block=False): 86 | super().__init__() 87 | if activation == 'lrelu': 88 | self.activation = nn.LeakyReLU(0.2) 89 | elif activation == 'relu': 90 | self.activation = nn.ReLU() 91 | elif activation == 'none': 92 | self.activation = None 93 | else: 94 | raise ValueError('Not recognised activation {}'.format(activation)) 95 | 96 | self.model = [] 97 | 98 | self.model.append(ConvBlock(channels, channels, 3, 1, norm=norm, activation=activation, bias=bias)) 99 | if last_block: 100 | norm = 'none' 101 | bias = True 102 | self.activation = None 103 | self.model.append(ConvBlock(channels, channels, 3, 1, norm=norm, activation='none', bias=bias)) 104 | 105 | self.model = nn.Sequential(*self.model) 106 | 107 | def forward(self, x): 108 | identity = x 109 | x = self.model(x) 110 | x += identity 111 | if self.activation: 112 | x = self.activation(x) 113 | return x 114 | -------------------------------------------------------------------------------- /neural/layers_3d.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import torch.nn as nn 3 | 4 | from neural.layers import NormActivation 5 | 6 | 7 | class CausalConv3d(nn.Module): 8 | def __init__(self, in_channels, out_channels, kernel_size=(2, 3, 3), dilation=(1, 1, 1), bias=False): 9 | super().__init__() 10 | assert len(kernel_size) == 3, 'kernel_size must be a 3-tuple.' 11 | time_pad = (kernel_size[0] - 1) * dilation[0] 12 | height_pad = ((kernel_size[1] - 1) * dilation[1]) // 2 13 | width_pad = ((kernel_size[2] - 1) * dilation[2]) // 2 14 | 15 | # Pad temporally on the left 16 | self.pad = nn.ConstantPad3d(padding=(width_pad, width_pad, height_pad, height_pad, time_pad, 0), value=0) 17 | self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, dilation=dilation, stride=1, padding=0, bias=bias) 18 | 19 | # pylint: disable=arguments-differ 20 | def forward(self, x): 21 | x = self.pad(x) 22 | x = self.conv(x) 23 | return x 24 | 25 | 26 | class Bottleneck3D(nn.Module): 27 | """ 28 | Defines a 3D bottleneck module with a residual connection. 29 | """ 30 | def __init__(self, in_channels, out_channels=None, kernel_size=(2, 3, 3), dilation=(1, 1, 1), low_rank=False, upsample=False, 31 | downsample=False): 32 | super().__init__() 33 | self.in_channels = in_channels 34 | self.half_channels = int(in_channels / 2) 35 | self.kernel_size = kernel_size 36 | self.dilation = dilation 37 | self.low_rank = low_rank 38 | self.upsample = upsample 39 | self.downsample = downsample 40 | self.out_channels = out_channels or self.in_channels 41 | 42 | # Define the main conv operation 43 | assert not (low_rank and upsample), 'Error, both upsample and low rank is not supported.' 44 | assert not (low_rank and downsample), 'Error, both downsample and low rank is not supported.' 45 | assert not (upsample and downsample), 'Error, both downsample and upsample is not supported.' 
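        # Bottleneck layout: a 1x1 conv halves the channels, the causal 3D conv below runs at
        # half width, and a final 1x1 conv projects back to out_channels; each stage is followed
        # by BatchNorm3d with a leaky-ReLU activation, with dropout at the end as a regulariser.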
46 | 47 | if self.low_rank: 48 | raise NotImplementedError() 49 | elif self.upsample: 50 | raise NotImplementedError() 51 | elif self.downsample: 52 | raise NotImplementedError() 53 | else: 54 | bottleneck_conv = CausalConv3d(self.half_channels, self.half_channels, kernel_size=self.kernel_size, 55 | dilation=self.dilation, bias=False) 56 | 57 | self.layers = nn.Sequential(collections.OrderedDict([ 58 | # First projection with 1x1 kernel 59 | ('conv_down_project', nn.Conv3d(self.in_channels, self.half_channels, kernel_size=1, bias=False)), 60 | ('abn_down_project', NormActivation(num_features=self.half_channels, dimension='3d', activation='leaky_relu')), 61 | # Second conv block 62 | ('conv', bottleneck_conv), 63 | ('abn', NormActivation(num_features=self.half_channels, dimension='3d', activation='leaky_relu')), 64 | # Final projection with 1x1 kernel 65 | ('conv_up_project', nn.Conv3d(self.half_channels, self.out_channels, kernel_size=1, bias=False)), 66 | ('abn_up_project', NormActivation(num_features=self.out_channels, dimension='3d', activation='leaky_relu')), 67 | # Regulariser 68 | ('dropout', nn.Dropout2d(p=0.2)) 69 | ])) 70 | 71 | if self.out_channels != self.in_channels: 72 | raise NotImplementedError() 73 | else: 74 | self.projection = None 75 | 76 | # pylint: disable=arguments-differ 77 | def forward(self, *args): 78 | x, = args 79 | x_residual = self.layers(x) 80 | if self.downsample: 81 | x = nn.functional.max_pool3d(x, kernel_size=2, stride=2) 82 | if self.upsample: 83 | x = nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) 84 | if self.out_channels != self.in_channels: 85 | x = self.projection(x) 86 | return x + x_residual -------------------------------------------------------------------------------- /neural/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def discriminative_loss_static(y_batch, label_batch, config, device): 5 | """ 6 | Parameters 7 | ---------- 8 | y_batch: torch.tensor shape (batch_size, emb_dim, h, w) 9 | label_batch: torch.tensor shape (batch_size, h, w) 10 | config: dict 11 | 12 | Returns 13 | ------- 14 | total_loss: torch.tensor 15 | Discriminative loss for instance segmentation 16 | """ 17 | delta_v = config['delta_v'] 18 | delta_d = config['delta_d'] 19 | batch_size = y_batch.size(0) 20 | 21 | total_v_loss = torch.tensor(0.0).to(device) 22 | total_d_loss = torch.tensor(0.0).to(device) 23 | total_reg_loss = torch.tensor(0.0).to(device) 24 | for i in range(batch_size): 25 | x = y_batch[i] 26 | label = label_batch[i] 27 | mask = label != 0 28 | emb_dim = x.size(0) 29 | 30 | # Mask out the background 31 | x = x[:, mask] 32 | label = label[mask] 33 | # Labels start at 1, make it start at 0 34 | label = (label - 1).byte() 35 | 36 | unique_labels = torch.unique(label) 37 | n_classes = len(unique_labels) 38 | assert torch.equal(unique_labels, torch.arange(n_classes, dtype=torch.uint8, device=device)), \ 39 | 'Unique labels are not consecutive' 40 | 41 | # Variance loss 42 | # Compute mean vector for each class 43 | indices = label.repeat(emb_dim, 1).long() 44 | mus = torch.zeros(emb_dim, n_classes, device=device).scatter_add(1, indices, x) 45 | counts = torch.zeros(emb_dim, n_classes, device=device).scatter_add(1, indices, torch.ones_like(x, device=device)) 46 | mus = mus / counts # shape (emb_dim, n_classes) 47 | 48 | v_loss = (torch.norm(x - torch.gather(mus, dim=1, index=indices), p=2, dim=0) - delta_v).clamp(min=0) 49 | v_loss = torch.pow(v_loss, 2) 
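        # Hinged variance term: only embeddings further than delta_v from their instance mean are penalised (squared).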
50 | # Divide by pixel count of each instance 51 | v_loss /= torch.gather(counts[0, :], dim=0, index=label.long()) 52 | v_loss = torch.sum(v_loss) / n_classes 53 | 54 | # Distance loss 55 | mus_repeat = mus.view(emb_dim, n_classes, 1).repeat(1, 1, n_classes) 56 | mus_repeat_t = mus.view(emb_dim, 1, n_classes).repeat(1, n_classes, 1) 57 | 58 | mus_matrix = (2 * delta_d - torch.norm(mus_repeat - mus_repeat_t, p=2, dim=0)).clamp(min=0) 59 | mus_matrix = torch.pow(mus_matrix, 2) 60 | 61 | # If contains more than one class 62 | if n_classes > 1: 63 | d_loss = mus_matrix[1 - torch.eye(n_classes, dtype=torch.uint8, device=device)].mean() 64 | else: 65 | d_loss = torch.tensor(0, dtype=torch.float, device=device) 66 | 67 | # Regularisation loss 68 | reg_loss = torch.norm(mus, p=2, dim=0).mean() 69 | 70 | total_v_loss += v_loss 71 | total_d_loss += d_loss 72 | total_reg_loss += reg_loss 73 | 74 | total_v_loss = config['lambda_v'] * total_v_loss / batch_size 75 | total_d_loss = config['lambda_d'] * total_d_loss / batch_size 76 | total_reg_loss = config['lambda_reg'] * total_reg_loss / batch_size 77 | 78 | losses = {'v_loss': total_v_loss, 79 | 'd_loss': total_d_loss, 80 | 'reg_loss': total_reg_loss} 81 | return losses 82 | 83 | 84 | def discriminative_loss_static_loopy(y_batch, label_batch, config, device): 85 | """ Discriminative loss with loops, ignoring the background (id=0) 86 | 87 | Parameters 88 | ---------- 89 | y_batch: torch.tensor shape (batch_size, emb_dim, h, w) 90 | label_batch: torch.tensor shape (batch_size, h, w) 91 | config: dict with keys 'delta_v', 'delta_d', 'lambda_v', 'lambda_d', 'lambda_reg' 92 | 93 | Returns 94 | ------- 95 | losses: dict 96 | """ 97 | delta_v = config['delta_v'] 98 | delta_d = config['delta_d'] 99 | batch_size = y_batch.size(0) 100 | 101 | total_v_loss = torch.tensor(0.0).to(device) 102 | total_d_loss = torch.tensor(0.0).to(device) 103 | total_reg_loss = torch.tensor(0.0).to(device) 104 | 105 | for i in range(batch_size): 106 | # Variance loss 107 | x = y_batch[i] 108 | label = label_batch[i] 109 | 110 | v_loss = 0 111 | d_loss = 0 112 | reg_loss = 0 113 | 114 | unique_labels = torch.unique(label) 115 | # Remove background 116 | assert 0 in unique_labels 117 | unique_labels = unique_labels[1:] 118 | C = len(unique_labels) 119 | 120 | if C > 0: 121 | for c in unique_labels: 122 | x_masked = x[:, (label == c)] 123 | mu = x_masked.mean(dim=-1, keepdim=True) 124 | v_loss_current = (torch.norm(x_masked - mu, 2, dim=0) - delta_v).clamp(min=0) 125 | v_loss_current = torch.pow(v_loss_current, 2) 126 | v_loss += torch.mean(v_loss_current) 127 | 128 | v_loss /= C 129 | 130 | # Distance loss 131 | mus = [] 132 | for c in unique_labels: 133 | x_masked = x[:, (label == c)] 134 | mu = x_masked.mean(dim=-1) 135 | mus.append(mu) 136 | 137 | # shape (C, emb_dim) 138 | mus = torch.stack(mus, dim=0) 139 | for i in range(C): 140 | for j in range(C): 141 | if i == j: 142 | continue 143 | dist = (2 * delta_d - torch.norm(mus[i] - mus[j], 2)).clamp(min=0) 144 | dist = torch.pow(dist, 2) 145 | d_loss += dist 146 | 147 | d_loss /= torch.tensor(max(C * (C - 1), 1)) # so that d_loss is a torch.tensor (when C=1) 148 | 149 | # Regularisation loss 150 | for mu in mus: 151 | reg_loss += torch.norm(mu, 2) 152 | reg_loss /= C 153 | 154 | total_v_loss += v_loss 155 | total_d_loss += d_loss 156 | total_reg_loss += reg_loss 157 | 158 | total_v_loss = config['lambda_v'] * total_v_loss / batch_size 159 | total_d_loss = config['lambda_d'] * total_d_loss / batch_size 160 | total_reg_loss = 
config['lambda_reg'] * total_reg_loss / batch_size 161 | 162 | losses = {'v_loss': total_v_loss, 163 | 'd_loss': total_d_loss, 164 | 'reg_loss': total_reg_loss} 165 | return losses 166 | 167 | 168 | def discriminative_loss_sequence_static(batch, output, config, device): 169 | """ Discriminative loss with loops, ignoring the background (id=0) (static frames) 170 | 171 | Parameters 172 | ---------- 173 | batch: dict with key: 174 | instance_seg: torch.tensor shape (batch_size, T, N_CLASSES, h, w) 175 | output: dict with key: 176 | y: torch.tensor shape (batch_size, T, emb_dim, h, w) 177 | config: dict with keys 'delta_v', 'delta_d', 'lambda_v', 'lambda_d', 'lambda_reg' 178 | 179 | Returns 180 | ------- 181 | losses: dict 182 | """ 183 | y_batch = output['y'] 184 | label_batch = batch['instance_seg'].squeeze(2) 185 | seq_len = y_batch.size(1) 186 | 187 | losses = {'v_loss': torch.tensor(0.0).to(device), 188 | 'd_loss': torch.tensor(0.0).to(device), 189 | 'reg_loss': torch.tensor(0.0).to(device)} 190 | 191 | for t in range(seq_len): 192 | losses_t = discriminative_loss_static_loopy(y_batch[:, t], label_batch[:, t], config, device) 193 | for key in losses.keys(): 194 | losses[key] += losses_t[key] 195 | 196 | for key in losses.keys(): 197 | losses[key] /= seq_len 198 | 199 | return losses 200 | 201 | 202 | def discriminative_loss_loopy(batch, output, config, device): 203 | """ Discriminative loss with loops, ignoring the background (id=0) 204 | 205 | Parameters 206 | ---------- 207 | batch: dict with key: 208 | instance_seg: torch.tensor shape (batch_size, T, N_CLASSES, h, w) 209 | output: dict with key: 210 | y: torch.tensor shape (batch_size, T, emb_dim, h, w) 211 | config: dict with keys 'delta_v', 'delta_d', 'lambda_v', 'lambda_d', 'lambda_reg', 'receptive_field' 212 | 213 | Returns 214 | ------- 215 | losses: dict 216 | """ 217 | y_batch = output['y'] 218 | label_batch = batch['instance_seg'].squeeze(2) 219 | delta_v = config['delta_v'] 220 | delta_d = config['delta_d'] 221 | receptive_field = config['receptive_field'] 222 | batch_size = y_batch.size(0) 223 | 224 | total_v_loss = torch.tensor(0.0).to(device) 225 | total_d_loss = torch.tensor(0.0).to(device) 226 | total_reg_loss = torch.tensor(0.0).to(device) 227 | 228 | for i in range(batch_size): 229 | # Variance loss 230 | x = y_batch[i].permute(1, 0, 2, 3)[:, (receptive_field-1):] 231 | label = label_batch[i][(receptive_field-1):] 232 | 233 | v_loss = 0 234 | d_loss = 0 235 | reg_loss = 0 236 | 237 | unique_labels = torch.unique(label) 238 | # Remove background 239 | assert 0 in unique_labels 240 | unique_labels = unique_labels[1:] 241 | C = len(unique_labels) 242 | 243 | if C > 0: 244 | for c in unique_labels: 245 | x_masked = x[:, (label == c)] 246 | mu = x_masked.mean(dim=-1, keepdim=True) 247 | v_loss_current = (torch.norm(x_masked - mu, 2, dim=0) - delta_v).clamp(min=0) 248 | v_loss_current = torch.pow(v_loss_current, 2) 249 | v_loss += torch.mean(v_loss_current) 250 | 251 | v_loss /= C 252 | 253 | # Distance loss 254 | mus = [] 255 | for c in unique_labels: 256 | x_masked = x[:, (label == c)] 257 | mu = x_masked.mean(dim=-1) 258 | mus.append(mu) 259 | 260 | # shape (C, emb_dim) 261 | mus = torch.stack(mus, dim=0) 262 | for i in range(C): 263 | for j in range(C): 264 | if i == j: 265 | continue 266 | dist = (2 * delta_d - torch.norm(mus[i] - mus[j], 2)).clamp(min=0) 267 | dist = torch.pow(dist, 2) 268 | d_loss += dist 269 | 270 | d_loss /= torch.tensor(max(C * (C - 1), 1)) # so that d_loss is a torch.tensor (when C=1) 271 | 272 | # 
Regularisation loss 273 | for mu in mus: 274 | reg_loss += torch.norm(mu, 2) 275 | reg_loss /= C 276 | 277 | total_v_loss += v_loss 278 | total_d_loss += d_loss 279 | total_reg_loss += reg_loss 280 | 281 | total_v_loss = config['lambda_v'] * total_v_loss / batch_size 282 | total_d_loss = config['lambda_d'] * total_d_loss / batch_size 283 | total_reg_loss = config['lambda_reg'] * total_reg_loss / batch_size 284 | 285 | losses = {'v_loss': total_v_loss, 286 | 'd_loss': total_d_loss, 287 | 'reg_loss': total_reg_loss} 288 | return losses 289 | 290 | 291 | def mask_loss(batch, output): 292 | """ Cross-entropy loss 293 | 294 | Parameters 295 | ---------- 296 | batch: dict with key: 297 | 'instance_seg' 298 | output: dict with key: 299 | 'mask' 300 | """ 301 | b, t, c, h, w = output['mask_logits'].shape 302 | logits = output['mask_logits'].view(b*t, c, h, w) 303 | # TODO: N_CLASSES 304 | labels = (batch['instance_seg'].squeeze(2) > 0).view(b*t, h, w).long() 305 | mask_loss = torch.nn.functional.cross_entropy(input=logits, target=labels) 306 | 307 | losses = {'mask_loss': mask_loss} 308 | return losses 309 | 310 | 311 | def motion_loss(batch, output, device): 312 | """ 313 | Parameters 314 | ---------- 315 | batch: dict with keys: 316 | position: torch.tensor (B, T, MAX_INSTANCES, 3) 317 | velocity: torch.tensor (B, T, MAX_INSTANCES, 3) 318 | output: dict with keys: 319 | position: torch.tensor (B, T, MAX_INSTANCES, 3) 320 | velocity: torch.tensor (B, T, MAX_INSTANCES, 3) 321 | """ 322 | losses = {} 323 | position_loss = torch.tensor(0.0).to(device) 324 | velocity_loss = torch.tensor(0.0).to(device) 325 | 326 | batch_size = batch['img'].size(0) 327 | for i in range(batch_size): 328 | unique_ids = torch.unique(batch['instance_seg'][i])[1:].long() 329 | 330 | if len(unique_ids) > 0: 331 | # With the current model, we can only estimate z position (depth) 332 | position_loss += torch.dist(output['position'][i, :, unique_ids, 2], 333 | batch['position'][i, :, unique_ids, 2], p=2) 334 | # Only penalise 2D velocity 335 | velocity_loss += torch.dist(output['velocity'][i, :, unique_ids][:, :, [0, 2]], 336 | batch['velocity'][i, :, unique_ids][:, :, [0, 2]], p=2) 337 | 338 | position_loss /= batch_size 339 | velocity_loss /= batch_size 340 | 341 | losses['position_loss'] = position_loss 342 | losses['velocity_loss'] = velocity_loss 343 | return losses 344 | -------------------------------------------------------------------------------- /neural/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | 4 | import torch.nn as nn 5 | 6 | from neural.resnet import ResnetEncoder, ResnetDecoder 7 | from neural.dynamics.dynamics_factory import TemporalModel 8 | from neural.utils import print_model_spec, require_grad 9 | from monodepth.layers import ConvBlock, Conv3x3 10 | 11 | 12 | class TemporalEncoder(nn.Module): 13 | """ Encoding + temporal model. 
Temporal model can be set to identity to have a static model.""" 14 | def __init__(self, config, device): 15 | super().__init__() 16 | self.config = config 17 | self.device = device 18 | 19 | self.encoder = Encoder( 20 | encoder_name=self.config['encoder_name'], 21 | pretrained_encoder_path=self.config['pretrained_encoder_path'] 22 | ) 23 | self.temporal_model = TemporalModel(input_features=self.encoder.output_features[-1], 24 | model_name=config['dynamics_model_name']) 25 | self.receptive_field = self.temporal_model.receptive_field 26 | self.output_features = self.encoder.output_features # works as long as the temporal model does not change 27 | # the number of channels 28 | 29 | def forward(self, x): 30 | """ 31 | Parameters 32 | ---------- 33 | x: torch.tensor (B, T, 3, H, W) 34 | 35 | Returns 36 | ------- 37 | z: torch.tensor (B, T, C, H, W) 38 | temporal embedding 39 | """ 40 | b, seq_len, c, h, w = x.shape 41 | 42 | encoder_outputs = self.encoder(x.view(b * seq_len, c, h, w)) 43 | encoder_outputs = [encoder_outputs[i].view(b, seq_len, *encoder_outputs[i].shape[1:]) 44 | for i in range(len(encoder_outputs))] 45 | z = self.temporal_model(encoder_outputs[-1]) 46 | 47 | return encoder_outputs[:-1] + [z] 48 | 49 | 50 | class Encoder(nn.Module): 51 | def __init__(self, encoder_name='', pretrained_encoder_path=''): 52 | super().__init__() 53 | self.output_features = None 54 | self.encoder_name = encoder_name 55 | 56 | if self.encoder_name == 'resnet': 57 | self.model = ResnetEncoder() 58 | self.output_features = self.model.output_features 59 | 60 | if pretrained_encoder_path: 61 | print('Loading encoder weights from {}'.format(pretrained_encoder_path)) 62 | checkpoint = torch.load(pretrained_encoder_path) 63 | self.model.load_state_dict(checkpoint['encoder']) 64 | else: 65 | deeplab = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True) 66 | self.backbone = deeplab.backbone 67 | self.aspp = deeplab.classifier[0] 68 | require_grad(self, False) 69 | self.output_features = 256 70 | 71 | print_model_spec(self, 'Encoder') 72 | 73 | def forward(self, x): 74 | if self.encoder_name in ['resnet']: 75 | # Returns a list of elements 76 | x = self.model(x) 77 | else: 78 | # Returns one element 79 | x = self.backbone(x)['out'] 80 | x = self.aspp(x) 81 | return x 82 | 83 | 84 | class InstanceDecoder(nn.Module): 85 | def __init__(self, decoder_name='resnet', emb_dim=8, instance=False, mask=False, config=None): 86 | super().__init__() 87 | self.config = config 88 | if decoder_name == 'resnet': 89 | self.model = ResnetDecoder(num_output_channels=emb_dim, instance=instance, mask=mask) 90 | if self.config['pretrained_encoder_path']: 91 | print('Loading decoder weights from {}'.format(self.config['pretrained_encoder_path'])) 92 | checkpoint = torch.load(self.config['pretrained_encoder_path']) 93 | # last layer do not have same nb of channels 94 | try: 95 | self.model.load_state_dict(checkpoint['decoder']) 96 | except RuntimeError: 97 | print('Not loading weights from the last layer of the decoder.') 98 | checkpoint['decoder'].pop('decoder.6.conv.weight', None) 99 | checkpoint['decoder'].pop('decoder.6.conv.bias', None) 100 | self.model.load_state_dict(checkpoint['decoder'], strict=False) 101 | 102 | print_model_spec(self, 'Instance decoder') 103 | 104 | def forward(self, input_features): 105 | """ 106 | Parameters 107 | ---------- 108 | z: torch.tensor (B, T, C, H, W) 109 | temporal embedding 110 | """ 111 | b, seq_len = input_features[-1].shape[:2] 112 | input_features = 
[input_features[i].view(b * seq_len, *input_features[i].shape[2:]) 113 | for i in range(len(input_features))] 114 | output = self.model(input_features) 115 | for key in output.keys(): 116 | if output[key] is not None: 117 | output[key] = output[key].view(b, seq_len, *output[key].shape[1:]) 118 | 119 | output['y'] = output['instance'] 120 | output.pop('instance', None) 121 | 122 | return output 123 | 124 | 125 | class DepthEmbedding(nn.Module): 126 | def __init__(self, config): 127 | super().__init__() 128 | self.config = config 129 | self.conv1 = ConvBlock(in_channels=(self.config['emb_dim'] + 1), 130 | out_channels=self.config['emb_dim']) 131 | self.conv2 = Conv3x3(in_channels=self.config['emb_dim'], 132 | out_channels=self.config['emb_dim']) 133 | 134 | def forward(self, output): 135 | """ 136 | Parameters 137 | ---------- 138 | output: dict with keys: 139 | y: torch.tensor (batch_size, seq_len, emb_dim, H, W) 140 | depth: torch.tensor (batch_size, seq_len, 1, H, W) 141 | """ 142 | b, seq_len, emb_dim, h, w = output['y'].shape 143 | 144 | y = output['y'].view(b*seq_len, emb_dim, h, w) 145 | depth = output['depth'].view(b*seq_len, 1, h, w) 146 | 147 | depth_y = torch.cat([y, depth], dim=1) 148 | 149 | # Convolution 1 150 | x = self.conv1(depth_y) 151 | # Convolution 2 152 | x = self.conv2(x) 153 | 154 | x = x.view(b, seq_len, emb_dim, h, w) 155 | 156 | return {'y': x} 157 | -------------------------------------------------------------------------------- /neural/resnet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torchvision.models as models 6 | 7 | from collections import OrderedDict 8 | 9 | from monodepth.layers import ConvBlock, Conv3x3, upsample 10 | 11 | ENCODER_CHANNELS = np.array([64, 64, 128, 256, 512]) 12 | DECODER_CHANNELS = np.array([16, 32, 64, 128, 256]) 13 | 14 | 15 | class ResnetEncoder(nn.Module): 16 | """Pytorch module for a resnet encoder 17 | """ 18 | def __init__(self, num_layers=18, pretrained=True, use_skips=True): 19 | super().__init__() 20 | 21 | self.use_skips = use_skips 22 | self.num_ch_enc = ENCODER_CHANNELS 23 | self.num_ch_dec = DECODER_CHANNELS 24 | 25 | resnets = {18: models.resnet18, 26 | 34: models.resnet34, 27 | 50: models.resnet50, 28 | 101: models.resnet101, 29 | 152: models.resnet152} 30 | 31 | if num_layers not in resnets: 32 | raise ValueError("{} is not a valid number of resnet layers".format(num_layers)) 33 | 34 | self.encoder = resnets[num_layers](pretrained) 35 | 36 | if num_layers > 34: 37 | self.num_ch_enc[1:] *= 4 38 | 39 | # Upsample twice 40 | self.convs = OrderedDict() 41 | for i in range(4, 2, -1): 42 | # upconv_0 43 | num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1] 44 | num_ch_out = self.num_ch_dec[i] 45 | self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out) 46 | 47 | # upconv_1 48 | num_ch_in = self.num_ch_dec[i] 49 | if self.use_skips and i > 0: 50 | num_ch_in += self.num_ch_enc[i - 1] 51 | num_ch_out = self.num_ch_dec[i] 52 | self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out) 53 | self.decoder = nn.ModuleList(list(self.convs.values())) 54 | 55 | self.output_features = [64, 64, 128] 56 | 57 | def forward(self, input_image): 58 | """ 59 | Returns 60 | ------- 61 | features0, features1 (from encoder) and output features after 2 steps of decoding 62 | """ 63 | features = [] 64 | x = input_image # 3x96x320 65 | x = self.encoder.conv1(x) 66 | x = self.encoder.bn1(x) 67 | 
features.append(self.encoder.relu(x)) # 64x48x160 68 | features.append(self.encoder.layer1(self.encoder.maxpool(features[-1]))) # 64x24x80 69 | features.append(self.encoder.layer2(features[-1])) # 128x12x40 70 | features.append(self.encoder.layer3(features[-1])) # 256x6x20 71 | features.append(self.encoder.layer4(features[-1])) # 512x3x10 72 | 73 | x = features[-1] 74 | for i in range(4, 2, -1): 75 | x = self.convs[("upconv", i, 0)](x) 76 | x = [upsample(x)] 77 | if self.use_skips and i > 0: 78 | x += [features[i - 1]] 79 | x = torch.cat(x, 1) 80 | x = self.convs[("upconv", i, 1)](x) 81 | 82 | output_features = [features[0], features[1], x] # x is shape 128x12x40 83 | 84 | return output_features 85 | 86 | 87 | class ResnetDecoder(nn.Module): 88 | def __init__(self, num_output_channels=1, use_skips=True, depth=False, instance=False, segmentation=False, 89 | mask=False, scales=range(3), n_classes=14): 90 | super().__init__() 91 | 92 | self.num_output_channels = num_output_channels # number of output channels of the instance embedding head 93 | self.use_skips = use_skips 94 | self.depth = depth 95 | self.instance = instance 96 | self.segmentation = segmentation 97 | self.mask = mask 98 | self.scales = scales 99 | self.n_classes = n_classes 100 | self.upsample_mode = 'nearest' 101 | 102 | self.num_ch_enc = ENCODER_CHANNELS 103 | self.num_ch_dec = DECODER_CHANNELS 104 | 105 | # decoder 106 | self.convs = OrderedDict() 107 | for i in range(2, -1, -1): 108 | # upconv_0 109 | num_ch_in = self.num_ch_enc[2] if i == 2 else self.num_ch_dec[i + 1] 110 | num_ch_out = self.num_ch_dec[i] 111 | self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out) 112 | 113 | # upconv_1 114 | num_ch_in = self.num_ch_dec[i] 115 | if self.use_skips and i > 0: 116 | num_ch_in += self.num_ch_enc[i - 1] 117 | num_ch_out = self.num_ch_dec[i] 118 | self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out) 119 | 120 | if self.depth: 121 | for s in self.scales: 122 | self.convs[("dispconv", s)] = Conv3x3(self.num_ch_dec[s], 1) 123 | self.sigmoid = nn.Sigmoid() 124 | 125 | if self.instance: 126 | self.convs[("instance_conv")] = Conv3x3(self.num_ch_dec[0], self.num_output_channels) 127 | 128 | if self.segmentation: 129 | self.convs[("segmentation_conv")] = Conv3x3(self.num_ch_dec[0], self.n_classes) 130 | 131 | if self.mask: 132 | self.convs[("mask_conv")] = Conv3x3(self.num_ch_dec[0], 2) 133 | 134 | self.decoder = nn.ModuleList(list(self.convs.values())) 135 | 136 | def forward(self, input_features): 137 | output = {} 138 | 139 | # decoder 140 | x = input_features[-1] 141 | for i in range(2, -1, -1): 142 | x = self.convs[("upconv", i, 0)](x) 143 | x = [upsample(x)] 144 | if self.use_skips and i > 0: 145 | x += [input_features[i - 1]] 146 | x = torch.cat(x, 1) 147 | x = self.convs[("upconv", i, 1)](x) 148 | if i in self.scales and self.depth: 149 | output[("disp", i)] = self.sigmoid(self.convs[("dispconv", i)](x)) 150 | 151 | if self.instance: 152 | output['instance'] = self.convs[("instance_conv")](x) 153 | 154 | if self.segmentation: 155 | output['segmentation'] = self.convs[("segmentation_conv")](x) 156 | 157 | if self.mask: 158 | output['mask_logits'] = self.convs[('mask_conv')](x) 159 | 160 | return output 161 | -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | class TemporalModelOptions: 5 | def __init__(self): 6 | self.parser =
argparse.ArgumentParser(description="Temporal instance segmentation and depth options") 7 | 8 | # PATHS 9 | self.parser.add_argument('--config', 10 | type=str, 11 | default='', 12 | help='Path of the config file') 13 | 14 | self.parser.add_argument('--restore', 15 | type=str, 16 | default='', 17 | help='Path of the model to restore (weights, optimiser)') 18 | 19 | # OPTIMIZATION options 20 | self.parser.add_argument("--batch_size", 21 | type=int, 22 | help="batch size", 23 | default=6) 24 | self.parser.add_argument("--learning_rate", 25 | type=float, 26 | help="learning rate", 27 | default=1e-4) 28 | self.parser.add_argument("--seq_len", 29 | type=int, 30 | help="sequence length", 31 | default=5) 32 | 33 | # TRAINING options 34 | self.parser.add_argument("--num_layers", 35 | type=int, 36 | help="number of resnet layers", 37 | default=18, 38 | choices=[18, 34, 50]) 39 | self.parser.add_argument("--disparity_smoothness", 40 | type=float, 41 | help="disparity smoothness weight", 42 | default=1e-3) 43 | self.parser.add_argument("--scales", 44 | nargs="+", 45 | type=int, 46 | help="scales used in the loss", 47 | default=[0, 1, 2]) 48 | self.parser.add_argument("--min_depth", 49 | type=float, 50 | help="minimum depth", 51 | default=0.1) 52 | self.parser.add_argument("--max_depth", 53 | type=float, 54 | help="maximum depth", 55 | default=100.0) 56 | self.parser.add_argument("--frame_ids", 57 | nargs="+", 58 | type=int, 59 | help="frames to load", 60 | default=[0, -1, 1]) 61 | 62 | def parse(self): 63 | self.options = self.parser.parse_args() 64 | return self.options 65 | -------------------------------------------------------------------------------- /run_training.py: -------------------------------------------------------------------------------- 1 | from options import TemporalModelOptions 2 | from trainer import Trainer 3 | 4 | 5 | if __name__ == '__main__': 6 | options = TemporalModelOptions() 7 | opt = options.parse() 8 | 9 | trainer = Trainer(opt) 10 | trainer.train() 11 | -------------------------------------------------------------------------------- /visualisation/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('') 3 | 4 | import torch 5 | import numpy as np 6 | import matplotlib 7 | 8 | matplotlib.use('Agg') 9 | 10 | import matplotlib.pyplot as plt 11 | from matplotlib.pylab import cm 12 | from PIL import Image 13 | from sklearn.decomposition import PCA 14 | 15 | from nuscenes.utils.geometry_utils import view_points 16 | 17 | DEFAULT_COLORMAP = cm.magma 18 | 19 | 20 | def create_colormap(): 21 | """Creates a colormap for visualisation of instances. 22 | """ 23 | colormap = np.zeros((256, 3), dtype=np.uint8) 24 | ind = np.arange(256, dtype=np.uint8) 25 | 26 | def bit_get(val, idx): 27 | """Gets the bit value. 28 | Args: 29 | val: Input value, int or numpy int array. 30 | idx: Which bit of the input val. 31 | Returns: 32 | The "idx"-th bit of input val. 33 | """ 34 | return (val >> idx) & 1 35 | 36 | for shift in reversed(range(8)): 37 | for channel in range(3): 38 | colormap[:, channel] |= bit_get(ind, channel) << shift 39 | ind >>= 3 40 | 41 | return colormap 42 | 43 | 44 | def hex_to_rgb(rgb_hex_str): 45 | """ converts string '0xFFFFFF' to a list of RGB values. 
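For example, '0x0000FF' gives [0, 0, 255].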
""" 46 | rgb_int = int(rgb_hex_str, 16) 47 | r = rgb_int // (256 * 256) 48 | g = rgb_int // 256 % 256 49 | b = rgb_int % 256 % 256 50 | return [r, g, b] 51 | 52 | 53 | def apply_colormap(image, cmap=DEFAULT_COLORMAP, autoscale=False): 54 | """ 55 | Applies a colormap to the given 1 or 2 channel numpy image. if 2 channel, must be 2xHxW. Returns a HxWx3 numpy image 56 | """ 57 | if image.ndim == 2 or (image.ndim == 3 and image.shape[0] == 1): 58 | if image.ndim == 3: 59 | image = image[0] 60 | # grayscale scalar image 61 | if autoscale: 62 | image = _normalise(image) 63 | return cmap(image)[:, :, :3] 64 | if image.shape[0] == 3: 65 | # normalise rgb channels 66 | if autoscale: 67 | image = _normalise(image) 68 | return np.transpose(image, axes=[1, 2, 0]) 69 | raise Exception('Image must be 1, 2 or 3 channel to convert to colormap (CxHxW)') 70 | 71 | 72 | def _normalise(image): 73 | lower = np.min(image) 74 | delta = np.max(image) - lower 75 | if delta == 0: 76 | delta = 1 77 | image = (image.astype(np.float32) - lower) / delta 78 | return image 79 | 80 | 81 | def heatmap_image(image, cmap=DEFAULT_COLORMAP, autoscale=True, output_pil=False): 82 | """ 83 | Colourise a 1 or 2 channel image with a colourmap. 84 | """ 85 | image_cmap = apply_colormap(image, cmap=cmap, autoscale=autoscale) 86 | if output_pil: 87 | image_cmap = np.uint8(image_cmap * 255) 88 | return Image.fromarray(image_cmap) 89 | return image_cmap 90 | 91 | 92 | def image_to_tensor(pic): 93 | if isinstance(pic, np.ndarray): 94 | # handle numpy array 95 | img = torch.from_numpy(pic.transpose((2, 0, 1))) 96 | if isinstance(img, torch.ByteTensor): 97 | return img.float().div(255) 98 | if isinstance(img, torch.DoubleTensor): 99 | return img.float() 100 | return img 101 | 102 | # handle PIL Image 103 | if pic.mode == 'I': 104 | img = torch.from_numpy(np.array(pic, np.int32, copy=False)) 105 | elif pic.mode == 'I;16': 106 | img = torch.from_numpy(np.array(pic, np.int16, copy=False)) 107 | else: 108 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 109 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK 110 | if pic.mode == 'YCbCr': 111 | nchannel = 3 112 | elif pic.mode == 'I;16': 113 | nchannel = 1 114 | else: 115 | nchannel = len(pic.mode) 116 | img = img.view(pic.size[1], pic.size[0], nchannel) 117 | # put it from HWC to CHW format 118 | # yikes, this transpose takes 80% of the loading time/CPU 119 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 120 | if isinstance(img, torch.ByteTensor): 121 | return img.float().div(255) 122 | return img 123 | 124 | 125 | def convert_figure_numpy(figure): 126 | """ Convert figure to numpy image """ 127 | figure_np = np.frombuffer(figure.canvas.tostring_rgb(), dtype=np.uint8) 128 | figure_np = figure_np.reshape(figure.canvas.get_width_height()[::-1] + (3,)) 129 | return figure_np 130 | 131 | 132 | def plot_labels_on_image(img, instance_seg, position=None, velocity=None, intrinsics=None, dpi=100, 133 | alpha=0.8, id_legend=True): 134 | """ 135 | Parameters 136 | ---------- 137 | img_copy: np.array shape (H, W, 3) 138 | instance_seg: np.array shape (N_CLASSES, H, W) 139 | position: np.array shape (MAX_INSTANCES, 3) 140 | velocity: np.array shape (MAX_INSTANCES, 3) 141 | intrinsics: np.array shape (3, 3) 142 | """ 143 | img_copy = img.copy() 144 | height, width = img_copy.shape[0:2] 145 | fig = plt.figure(figsize=(width / dpi, height / dpi), dpi=dpi) 146 | ax = fig.gca() 147 | 148 | # Overlay instance segmentation on rgb image 149 | colormap = create_colormap() 150 | 
instance_seg = instance_seg.squeeze(0) # TODO: N_CLASSES 151 | mask = instance_seg != 0 152 | img_copy[mask] = ((1-alpha) * img_copy[mask] + alpha * colormap[instance_seg][mask]).astype(np.uint8) 153 | 154 | # Print all ids 155 | unique_ids = np.unique(instance_seg) 156 | if id_legend: 157 | ax.plot([], [], ' ', label='IDs: ' + ', '.join([str(x) for x in unique_ids[1:]])) 158 | ax.legend(loc='upper left', prop={'size': 8}) 159 | # Plot image 160 | ax.set_axis_off() 161 | ax.imshow(img_copy) 162 | if position is not None: 163 | # Convert position from camera reference frame to image plane 164 | image_position = view_points(position.T, intrinsics, True)[:2] 165 | unique_ids = np.unique(instance_seg) 166 | for inst_id in unique_ids[1:]: 167 | inst_id = inst_id - 1 168 | col, row = image_position[:, inst_id] 169 | 170 | text = 'xyz: {:.1f}/{:.1f}/{:.1f}\nv_xyz: {:.1f}/{:.1f}/{:.1f}'.format(*position[inst_id], 171 | *velocity[inst_id]) 172 | text_plot = ax.text(col, row, text, fontsize=5, fontweight='bold', color='black') 173 | text_plot.set_bbox(dict(facecolor='white', alpha=0.5, edgecolor=colormap[inst_id + 1] / 255)) 174 | 175 | fig.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0) 176 | plt.draw() 177 | fig_np = convert_figure_numpy(fig) 178 | plt.close('all') 179 | 180 | return fig_np 181 | 182 | 183 | def plot_embedding_clustering(y, instance_seg, mask, config, dpi=100): 184 | """ 185 | Parameters 186 | ---------- 187 | y: np.array shape (emb_dim, H, W) 188 | instance_seg: np.array shape (N_CLASSES, H, W) 189 | mask: np.array shape (H, W) 190 | """ 191 | height, width = y.shape[1:] 192 | fig = plt.figure(figsize=(width / dpi, height / dpi), dpi=dpi) 193 | ax = fig.gca() 194 | 195 | y = np.transpose(y, (1, 2, 0)) 196 | instance_seg = instance_seg.squeeze(0) # TODO: N_CLASSES 197 | mask = mask.astype(bool) 198 | 199 | colormap = create_colormap() 200 | 201 | try: 202 | if mask.sum() > 0: 203 | y = y[mask] 204 | instance_seg = instance_seg[mask] 205 | 206 | pca = PCA(n_components=2) 207 | pca.fit(y) 208 | #print('Explained variance: {}'.format(pca.explained_variance_ratio_)) 209 | 210 | y_two_d = pca.transform(y) 211 | for id in np.unique(instance_seg): 212 | if id == 0: 213 | continue 214 | mask_id = (instance_seg == id) 215 | y1 = y_two_d[:, 0][mask_id] 216 | y2 = y_two_d[:, 1][mask_id] 217 | 218 | ax.scatter(y1, y2, c=(colormap[id] / 255).reshape((1, 3)), alpha=0.2) 219 | intra_cluster = plt.Circle((y1.mean(), y2.mean()), config['delta_v'], color='black', linestyle='--', linewidth=2, 220 | fill=False) 221 | inter_cluster = plt.Circle((y1.mean(), y2.mean()), config['delta_d'], color='black', linestyle='--', linewidth=2, 222 | alpha=0.5, fill=False) 223 | ax.add_artist(intra_cluster) 224 | ax.add_artist(inter_cluster) 225 | except Exception: # PCA or the cluster circles can fail (e.g. too few foreground pixels); keep the figure empty 226 | pass 227 | 228 | fig.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0) 229 | plt.draw() 230 | fig_np = convert_figure_numpy(fig) 231 | plt.close('all') 232 | 233 | return fig_np 234 | 235 | 236 | def compute_pixel_barycenter(instance_seg, id): 237 | """ Compute the pixel barycenter of an instance. 238 | 239 | Parameters 240 | ---------- 241 | instance_seg: np.ndarray (height, width) 242 | id: int 243 | considered id 244 | 245 | Returns 246 | ------- 247 | barycenter: np.ndarray (2) 248 | barycenter of the instance in pixel space, i.e. axes are height and width.
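For example, an instance covering rows 10-12 and columns 4-6 has barycenter [11., 5.].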
249 | """ 250 | height, width = instance_seg.shape 251 | mgrid = np.moveaxis(np.mgrid[:height, :width], source=0, destination=-1) 252 | instance_mask = (instance_seg == id) 253 | pixel_coords = mgrid[instance_mask] 254 | return pixel_coords.mean(axis=0) 255 | -------------------------------------------------------------------------------- /visualisation/visualisation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import matplotlib 3 | matplotlib.use('Agg') 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import torchvision.transforms as transforms 7 | 8 | from matplotlib import gridspec 9 | from visualisation.utils import create_colormap 10 | 11 | from common.utils import LABEL_NAMES 12 | 13 | 14 | def convert_to_pil(x): 15 | mean = torch.tensor([0.485, 0.456, 0.406])[:, None, None] 16 | std = torch.tensor([0.229, 0.224, 0.225])[:, None, None] 17 | x = (x * std) + mean 18 | return transforms.ToPILImage()(x) 19 | 20 | 21 | ####### 22 | # Cityscapes semantic segmentation visualisation 23 | ####### 24 | # Segmentation visualisation 25 | def visualise_sem_seg(img, gt, y, save_filename=''): 26 | y = y.cpu().data.numpy() 27 | y = np.transpose(y, (1, 2, 0)) 28 | predicted_labels = np.argmax(y, axis=-1) 29 | gt = gt.cpu().data.numpy() 30 | 31 | colormap = create_colormap() 32 | plt.figure(figsize=(20, 10)) 33 | grid_spec = gridspec.GridSpec(1, 4, width_ratios=[6, 6, 6, 1]) 34 | 35 | plt.subplot(grid_spec[0]) 36 | plt.imshow(convert_to_pil(img.cpu())) 37 | plt.axis('off') 38 | plt.title('Image') 39 | 40 | plt.subplot(grid_spec[1]) 41 | plt.imshow(colormap[gt]) 42 | plt.axis('off') 43 | plt.title('Ground truth seg') 44 | 45 | plt.subplot(grid_spec[2]) 46 | plt.imshow(colormap[predicted_labels]) 47 | plt.axis('off') 48 | plt.title('Predicted seg') 49 | 50 | unique_labels = np.unique(predicted_labels) 51 | ax = plt.subplot(grid_spec[3]) 52 | # Legend 53 | full_color_map = colormap[np.arange(len(LABEL_NAMES))[:, None]] 54 | 55 | plt.imshow( 56 | full_color_map[unique_labels].astype(np.uint8)) 57 | ax.yaxis.tick_right() 58 | plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels]) 59 | plt.xticks([], []) 60 | ax.tick_params(width=0.0) 61 | plt.grid(False) 62 | 63 | if save_filename: 64 | plt.savefig(save_filename) 65 | plt.close() 66 | else: 67 | plt.show() 68 | 69 | 70 | def compare_bbox_instance_seg(img, nuscenes_box, instance_seg): 71 | idx = 0 72 | _, ax = plt.subplots(1, 1, figsize=(9, 16)) 73 | # Show image. 
74 | ax.imshow(img) 75 | 76 | def draw_rect(selected_corners, color): 77 | prev = selected_corners[-1] 78 | for corner in selected_corners: 79 | ax.plot([prev[0], corner[0]], [prev[1], corner[1]], color=color, linewidth=2) 80 | prev = corner 81 | 82 | corresponding_box = nuscenes_box 83 | x_min = corresponding_box['x1'] 84 | x_max = corresponding_box['x2'] 85 | y_min = corresponding_box['y1'] 86 | y_max = corresponding_box['y2'] 87 | 88 | bounding_box_2d = np.array([[x_min, y_min], 89 | [x_min, y_max], 90 | [x_max, y_max], 91 | [x_max, y_min]]) 92 | 93 | draw_rect(bounding_box_2d, 'b') 94 | # draw_rect(corners.T[:4], color) 95 | plt.show() 96 | plt.figure(figsize=(9, 16)) 97 | plt.imshow((instance_seg[idx]).squeeze(), cmap='gray') 98 | plt.show() 99 | 100 | 101 | def visualise_nuscenes_3D(nusc): 102 | SENSOR = 'CAM_FRONT' 103 | 104 | scene = nusc.scene[0] 105 | sample_token = scene['first_sample_token'] 106 | count = 0 107 | 108 | while sample_token: 109 | print(sample_token) 110 | sample = nusc.get('sample', sample_token) 111 | data_token = sample['data'][SENSOR] 112 | data_path, boxes, camera_intrinsic = nusc.get_sample_data(data_token) 113 | 114 | nusc.render_sample_data(data_token) 115 | plt.show() 116 | sample_token = sample['next'] 117 | count += 1 118 | --------------------------------------------------------------------------------
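A minimal usage sketch of the ResnetEncoder / ResnetDecoder pair defined in neural/resnet.py above: the 2x3x96x320 input size is taken from the inline shape comments, and the import assumes the snippet is run from the repository root. It is an illustration of how the two modules fit together, not an entry point of the repository.

import torch

from neural.resnet import ResnetEncoder, ResnetDecoder

# Encoder returns a list of feature maps: [64 x H/2 x W/2, 64 x H/4 x W/4, 128 x H/8 x W/8].
encoder = ResnetEncoder(num_layers=18, pretrained=False)
# Decoder heads: 3 disparity scales, an 8-channel instance embedding and a 2-channel mask.
decoder = ResnetDecoder(num_output_channels=8, depth=True, instance=True, mask=True)

x = torch.randn(2, 3, 96, 320)   # (batch, 3, H, W)
features = encoder(x)
outputs = decoder(features)      # keys: ('disp', 0..2), 'instance', 'mask_logits'

print(outputs[('disp', 0)].shape)    # torch.Size([2, 1, 96, 320])
print(outputs['instance'].shape)     # torch.Size([2, 8, 96, 320])
print(outputs['mask_logits'].shape)  # torch.Size([2, 2, 96, 320])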