├── .gitignore ├── LICENSE ├── README.md ├── configs ├── data │ ├── mf.yaml │ ├── pose_ec.yaml │ └── pose_eds.yaml ├── eval_real_defaults.yaml ├── model │ └── correlation3_unscaled.yaml ├── optim │ └── adam.yaml ├── train_defaults.yaml └── training │ ├── pose_finetuning_train_ec.yaml │ ├── pose_finetuning_train_eds.yaml │ └── supervised_train.yaml ├── data_preparation ├── colmap.py ├── real │ ├── eds_rectify_events_and_frames.py │ ├── generate_even_event_representation.py │ ├── generate_event_representation.py │ ├── prepare_ec_pose_supervision.py │ ├── prepare_ec_subseq.py │ ├── prepare_eds_pose_supervision.py │ ├── prepare_eds_subseq.py │ └── rectify_ec.py └── synthetic │ ├── generate_event_representations.py │ └── generate_tracks.py ├── disp_training ├── README.md ├── __init__.py ├── disp_configs │ ├── data │ │ └── m3ed.yaml │ ├── m3ed_test.yaml │ ├── m3ed_train.yaml │ ├── model │ │ └── correlation3_unscaled_disp.yaml │ └── optim │ │ └── adam.yaml ├── disp_data_preparation │ ├── __init__.py │ ├── feature_filter_depth.py │ ├── prepare_m3ed_data.py │ ├── track_storage.py │ └── utils.py ├── disp_dataloader │ ├── __init__.py │ └── m3ed_loader.py ├── disp_model │ ├── __init__.py │ ├── correlation3_unscaled_disp.py │ └── disp_template.py ├── disp_scripts │ ├── __init__.py │ └── benchmark.py ├── disp_utils │ ├── __init__.py │ └── disp_utils_torch.py ├── doc │ └── thumbnail.png ├── evaluate.py ├── requirements.txt └── train.py ├── doc ├── shapes_6dof_485_565_tracks.gif ├── thumbnail.PNG └── ziggy_in_the_arena_1350_1650-opt.gif ├── evaluate_real.py ├── models ├── common.py ├── correlation3_unscaled.py └── template.py ├── requirements.txt ├── scripts ├── benchmark.py └── visualize_eds_data.py ├── train.py └── utils ├── augmentations.py ├── callbacks.py ├── dataset.py ├── losses.py ├── representations.py ├── timers.py ├── torch_utils.py ├── track_utils.py ├── utils.py └── visualization.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Jupyter Notebook 2 | .ipynb_checkpoints 3 | __pycache__/ 4 | .idea/ 5 | *.pyc 6 | -------------------------------------------------------------------------------- /configs/data/mf.yaml: -------------------------------------------------------------------------------- 1 | name: mf 2 | 3 | _target_: utils.dataset.MFDataModule 4 | data_dir: 5 | extra_dir: 6 | dt: 0.0200 7 | num_workers: 4 8 | 9 | # For tracks 10 | batch_size: 32 11 | n_train: 30000 12 | n_val: 1000 13 | 14 | augment: True 15 | global_mode: False 16 | mixed_dt: True 17 | -------------------------------------------------------------------------------- /configs/data/pose_ec.yaml: -------------------------------------------------------------------------------- 1 | name: pose 2 | 3 | _target_: utils.dataset.PoseDataModule 4 | 5 | root_dir: 6 | dataset_type: EC 7 | n_frames_skip: 3 8 | n_event_representations_per_frame: 5 9 | 10 | num_workers: 0 11 | batch_size: 16 12 | n_train: 2500 13 | n_val: 200 14 | -------------------------------------------------------------------------------- /configs/data/pose_eds.yaml: -------------------------------------------------------------------------------- 1 | name: pose 2 | 3 | _target_: utils.dataset.PoseDataModule 4 | 5 | root_dir: 6 | dataset_type: EDS 7 | n_frames_skip: 8 8 | n_event_representations_per_frame: 3 9 | 10 | num_workers: 0 11 | batch_size: 16 12 | n_train: 2500 13 | n_val: 200 14 | -------------------------------------------------------------------------------- /configs/eval_real_defaults.yaml: 
-------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: /${model.name}/${now:%Y-%m-%d_%H%M%S} 4 | 5 | gt_path: /gt_tracks 6 | running_locally: False 7 | 8 | 9 | weights_path: 10 | 11 | track_name: shitomasi_custom 12 | representation: time_surfaces_v2_5 13 | patch_size: 31 14 | visualize: False 15 | dt_track_vis: 0.2 16 | 17 | # Composing nested config with default 18 | defaults: 19 | - model: correlation3_unscaled 20 | # Pytorch lightning trainer's argument 21 | trainer: 22 | gpus: [0] 23 | -------------------------------------------------------------------------------- /configs/model/correlation3_unscaled.yaml: -------------------------------------------------------------------------------- 1 | name: correlation3_unscaled 2 | 3 | _target_: models.correlation3_unscaled.TrackerNetC 4 | patch_size: 31 5 | feature_dim: 384 6 | defaults: 7 | - /optim/adam.yaml@optimizer -------------------------------------------------------------------------------- /configs/optim/adam.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.Adam 2 | lr: 1e-4 # [1e-4 for supervised, 1e-6 for finetuning] 3 | -------------------------------------------------------------------------------- /configs/train_defaults.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir:/${data.name}/${model.name}/${experiment}/${now:%Y-%m-%d_%H%M%S} 4 | 5 | # Composing nested config with default 6 | experiment: open_source 7 | track_name: shitomasi_custom 8 | 9 | representation: time_surfaces_v2_5 10 | patch_size: 31 11 | 12 | debug: False 13 | n_vis: 2 14 | logging: True 15 | 16 | # Do not forget to set the learning rate for supervised or for pose finetuning in configs/optim/adam.yaml 17 | defaults: 18 | - data: mf # [mf, pose_eds, pose_ec] 19 | - model: correlation3_unscaled 20 | - training: supervised_train # [supervised_train, pose_finetuning_train_ec, pose_finetuning_train_eds] 21 | 22 | # Pytorch lightning trainer's argument 23 | trainer: 24 | benchmark: True 25 | log_every_n_steps: 10 26 | max_epochs: 40000 27 | num_processes: 1 28 | num_sanity_val_steps: 1 29 | -------------------------------------------------------------------------------- /configs/training/pose_finetuning_train_ec.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | checkpoint_path: 4 | init_unrolls: 8 5 | max_unrolls: 8 6 | unroll_factor: 8 7 | unroll_schedule: [200] 8 | -------------------------------------------------------------------------------- /configs/training/pose_finetuning_train_eds.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | checkpoint_path: 4 | init_unrolls: 8 5 | max_unrolls: 8 6 | unroll_factor: 8 7 | unroll_schedule: [100] 8 | -------------------------------------------------------------------------------- /configs/training/supervised_train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | checkpoint_path: none 4 | init_unrolls: 4 5 | max_unrolls: 23 6 | unroll_factor: 4 7 | unroll_schedule: [80000, 120000, 140000] 8 | -------------------------------------------------------------------------------- /data_preparation/colmap.py: -------------------------------------------------------------------------------- 1 | """ 2 | EXTRACT: 3 | Convert colmap's images.txt file 
to a stamped_groundtruth poses file 4 | colmap entries are: 5 | # IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME 6 | # POINTS2D[] as (X, Y, POINT3D_ID) 7 | 8 | stamped_ground_truth pose entries are: 9 | #timestamp[seconds] px py pz qx qy qz qw 10 | 11 | GENERATE: 12 | Create the images.txt file for colmap 13 | Create an empty points.txt file 14 | Create the cameras.txt file with a single camera formatted like : 15 | (1 PINHOLE 640 480 766.536025127154 767.5749459126396 291.0503512057777 227.4060484950132) 16 | """ 17 | 18 | 19 | import os 20 | from pathlib import Path 21 | 22 | import fire 23 | import numpy as np 24 | from scipy.spatial.transform import Rotation 25 | 26 | from utils.dataset import ECPoseSegmentDataset, EDSPoseSegmentDataset 27 | 28 | 29 | def read_colmap_data(colmap_data_path, skip_header=4): 30 | """ 31 | Returns the colmap data from images.txt file 32 | :param colmap_data_path: path to images.txt 33 | :param skip_header: how many columns to skip 34 | :return: image_indices, poses (x, y, z, qx, qy, qz, qw) 35 | """ 36 | pose_data = [] 37 | with open(colmap_data_path, "r") as colmap_data_f: 38 | for i_row, data_row in enumerate(colmap_data_f): 39 | if (i_row > skip_header) and (i_row - skip_header) % 2 == 0: 40 | data_row = data_row.split(" ") 41 | pose_data.append([data_row[i] for i in [-1, 5, 6, 7, 2, 3, 4, 1]]) 42 | pose_data = sorted(pose_data, key=lambda x: x[0]) 43 | image_idxs = [ 44 | int(pose[0].replace("frame_", "").replace(".png\n", "")) for pose in pose_data 45 | ] 46 | pose_data = np.array([pose[1:] for pose in pose_data]).astype(np.float32) 47 | return image_idxs, pose_data 48 | 49 | 50 | def extract(sequence_dir, dataset_type): 51 | assert dataset_type in ["EC", "EDS"], "Dataset type must be one of EC, EDS" 52 | if dataset_type == "EC": 53 | dataset_class = ECPoseSegmentDataset 54 | else: 55 | dataset_class = EDSPoseSegmentDataset 56 | sequence_dir = Path(sequence_dir) 57 | 58 | # Read image timestamps 59 | image_ts = dataset_class.get_frame_timestamps(sequence_dir).reshape((-1, 1)) 60 | 61 | # Read colmap poses 62 | colmap_data_path = sequence_dir / "colmap" / "images.txt" 63 | image_idxs, colmap_data = read_colmap_data(str(colmap_data_path)) 64 | 65 | # Invert the poses bc colmap transforms are world->camera instead of camera->world 66 | inverted_poses = [] 67 | for pose in colmap_data: 68 | T_C_W = np.eye(4) 69 | T_C_W[:3, :3] = Rotation.from_quat(pose[3:]).as_matrix() 70 | T_C_W[0, 3] = pose[0] 71 | T_C_W[1, 3] = pose[1] 72 | T_C_W[2, 3] = pose[2] 73 | T_W_C = np.linalg.inv(T_C_W) 74 | quat = Rotation.from_matrix(T_W_C[:3, :3]).as_quat() 75 | inverted_poses.append( 76 | [T_W_C[0, 3], T_W_C[1, 3], T_W_C[2, 3], quat[0], quat[1], quat[2], quat[3]] 77 | ) 78 | inverted_poses = np.array(inverted_poses) 79 | 80 | colmap_poses = np.concatenate([image_ts[image_idxs, 0:1], inverted_poses], axis=1) 81 | 82 | output_path = sequence_dir / "colmap" / "stamped_groundtruth.txt" 83 | np.savetxt( 84 | str(output_path), 85 | colmap_poses, 86 | header="#timestamp[seconds] px py pz qx qy qz qw", 87 | ) 88 | 89 | 90 | def generate(sequence_dir, dataset_type): 91 | assert dataset_type in ["EC", "EDS"], "Dataset type must be one of EC, EDS" 92 | if dataset_type == "EC": 93 | dataset_class = ECPoseSegmentDataset 94 | else: 95 | dataset_class = EDSPoseSegmentDataset 96 | seq_dir = Path(sequence_dir) 97 | 98 | # Read poses and image names 99 | pose_interpolator = dataset_class.get_pose_interpolator(seq_dir) 100 | 101 | image_dir = seq_dir / "images_corrected" 102 | if not 
image_dir.exists(): 103 | print("Rectified image directory not found") 104 | exit() 105 | else: 106 | image_paths = dataset_class.get_frame_paths(seq_dir) 107 | image_names = [os.path.split(image_path)[1] for image_path in image_paths] 108 | 109 | image_timestamps = dataset_class.get_frame_timestamps(seq_dir) 110 | 111 | # Write formatted poses to images.txt 112 | colmap_dir = seq_dir / "colmap" 113 | if not colmap_dir.exists(): 114 | colmap_dir.mkdir() 115 | 116 | colmap_poses_path = colmap_dir / "images.txt" 117 | with open(colmap_poses_path, "w") as colmap_poses_f: 118 | for image_idx, (image_ts, image_name) in enumerate( 119 | zip(image_timestamps, image_names) 120 | ): 121 | image_pose = pose_interpolator.interpolate_colmap(image_ts) 122 | colmap_poses_f.write( 123 | f"{image_idx+1} {image_pose[6]} {image_pose[3]} {image_pose[4]} {image_pose[5]} {image_pose[0]} {image_pose[1]} {image_pose[2]} 1 {image_name}\n\n" 124 | ) 125 | 126 | # Write empty points file 127 | with open(str(colmap_dir / "points3D.txt"), "w") as _: 128 | pass 129 | 130 | # Write cameras 131 | camera_matrix, _, _ = dataset_class.get_calibration(seq_dir) 132 | with open(str(colmap_dir / "cameras.txt"), "w") as colmap_cameras_f: 133 | colmap_cameras_f.write( 134 | f"1 PINHOLE {dataset_class.resolution[0]} {dataset_class.resolution[1]} " 135 | f"{camera_matrix[0, 0]} {camera_matrix[1, 1]} {camera_matrix[0, 2]} {camera_matrix[1, 2]}" 136 | ) 137 | 138 | 139 | if __name__ == "__main__": 140 | fire.Fire({"generate": generate, "extract": extract}) 141 | -------------------------------------------------------------------------------- /data_preparation/real/eds_rectify_events_and_frames.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import sys 5 | from os.path import join 6 | 7 | import cv2 8 | import h5py 9 | import hdf5plugin 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import yaml 13 | from tqdm import tqdm 14 | 15 | from utils.utils import blosc_opts 16 | 17 | 18 | class Camera: 19 | def __init__(self, data): 20 | self.intrinsics = np.eye(3) 21 | self.intrinsics[[0, 1, 0, 1], [0, 1, 2, 2]] = data["intrinsics"] 22 | 23 | # distortion 24 | self.distortion_coeffs = np.array(data["distortion_coeffs"]) 25 | self.distortion_model = data["distortion_model"] 26 | self.resolution = data["resolution"] 27 | 28 | if "T_cn_cnm1" not in data: 29 | self.R = np.eye(3) 30 | else: 31 | self.R = np.array(data["T_cn_cnm1"])[:3, :3] 32 | 33 | self.K = self.intrinsics 34 | 35 | @property 36 | def num_pixels(self): 37 | return np.prod(self.resolution) 38 | 39 | 40 | class CameraSystem: 41 | def __init__(self, data, fix_rotation=False): 42 | # load calibration 43 | 44 | self.cam0 = Camera(data["cam0"]) 45 | self.cam1 = Camera( 46 | data["cam1"] 47 | ) # if cam0.num_pixels > cam1.num_pixels else (cam1, cam0) 48 | 49 | self.newK = self.cam0.K 50 | self.newR = self.cam1.R 51 | self.newRes = self.cam0.resolution 52 | 53 | if not fix_rotation: 54 | # camera chain parameters 55 | self.newK = self.event_cam.K 56 | 57 | # tmp = cv2.stereoRectify(self.cam.K, self.cam.distortion_coeffs, 58 | # self.event_cam.K, self.event_cam.distortion_coeffs, 59 | # self.event_cam.resolution, T[:3, :3], T[:3, 3]) 60 | # find new extrinsics 61 | self.t = T[:3, 3] 62 | r3_cam0 = self.cam.R[:, 2] 63 | 64 | r1 = self.t / np.linalg.norm(self.t) 65 | r2 = np.cross(r3_cam0, r1) 66 | r3 = np.cross(r1, r2) 67 | self.newR = np.stack([r1, r2, r3], -1) 68 | 
print("distance: %s" % (np.linalg.norm(self.t) * self.newK[0, 0])) 69 | else: 70 | self.newR = self.cam.R 71 | self.newK = self.event_cam.K 72 | 73 | 74 | def vizloop(kwargs, callbacks, image_fun): 75 | kwargs = {**kwargs, "index": 0} 76 | 77 | while True: 78 | image = image_fun(kwargs) 79 | cv2.imshow("Viz", image) 80 | 81 | c = cv2.waitKey(3) 82 | key = chr(c & 255) 83 | 84 | for k, callback in callbacks.items(): 85 | if key == k: 86 | ret = callback(kwargs) 87 | if ret is not None: 88 | kwargs.update(ret) 89 | 90 | if c == 27: # 'q' or 'Esc': Quit 91 | break 92 | 93 | cv2.destroyAllWindows() 94 | 95 | kwargs.pop("index") 96 | return kwargs 97 | 98 | 99 | def _remap_events(events, map, rotate, shape): 100 | mx, my = map 101 | x, y = mx[events["y"], events["x"]], my[events["y"], events["x"]] 102 | p = np.array(events["p"]) 103 | t = np.array(events["t"]) 104 | 105 | target_width, target_height = shape 106 | 107 | if rotate: 108 | x = target_width - 1 - x 109 | y = target_height - 1 - y 110 | 111 | mask = (x >= 0) & (x <= target_width - 1) & (y >= 0) & (y <= target_height - 1) 112 | 113 | x = x[mask] 114 | y = y[mask] 115 | t = t[mask] 116 | p = p[mask] 117 | 118 | return {"x": x, "y": y, "t": t, "p": p} 119 | 120 | 121 | def process_events(file, output, maps, shape, rotate=False): 122 | events = h5py.File(file) 123 | events = _remap_events(events, maps, rotate, shape) 124 | with h5py.File(output, "w") as h5f_out: 125 | h5f_out.create_dataset( 126 | "x", data=events["x"], **blosc_opts(complevel=1, shuffle="byte") 127 | ) 128 | h5f_out.create_dataset( 129 | "y", data=events["y"], **blosc_opts(complevel=1, shuffle="byte") 130 | ) 131 | h5f_out.create_dataset( 132 | "p", data=events["p"], **blosc_opts(complevel=1, shuffle="byte") 133 | ) 134 | h5f_out.create_dataset( 135 | "t", data=events["t"], **blosc_opts(complevel=1, shuffle="byte") 136 | ) 137 | 138 | 139 | def _remap_img(img, map, flip, rotate): 140 | if flip: 141 | img = img[:, ::-1] 142 | mx, my = map 143 | img_remapped = cv2.remap(img, mx, my, cv2.INTER_CUBIC) 144 | if rotate: 145 | img_remapped = cv2.rotate(img_remapped, cv2.ROTATE_180) 146 | return img_remapped 147 | 148 | 149 | def process_img(img_file, output_folder, distortion_maps, flip, rotate): 150 | img = cv2.imread(img_file, cv2.IMREAD_GRAYSCALE).astype(np.uint8) 151 | img_remapped = _remap_img(img, distortion_maps, flip, rotate) 152 | output_path = os.path.join(output_folder, os.path.basename(img_file)) 153 | cv2.imwrite(output_path, img_remapped) 154 | 155 | 156 | def getRemapping(camsys: CameraSystem): 157 | # undistort image 158 | img_mapx, img_mapy = cv2.initUndistortRectifyMap( 159 | camsys.cam0.K, 160 | camsys.cam0.distortion_coeffs, 161 | camsys.newR @ camsys.cam0.R.T, 162 | camsys.newK, 163 | camsys.newRes, 164 | cv2.CV_32FC1, 165 | ) 166 | 167 | ev_mapx, ev_mapy = cv2.initUndistortRectifyMap( 168 | camsys.cam1.K, 169 | camsys.cam1.distortion_coeffs, 170 | camsys.newR @ camsys.cam1.R.T, 171 | camsys.newK, 172 | camsys.newRes, 173 | cv2.CV_32FC1, 174 | ) 175 | 176 | W, H = camsys.cam1.resolution 177 | coords = ( 178 | np.stack(np.meshgrid(np.arange(W), np.arange(H))) 179 | .reshape((2, -1)) 180 | .T.reshape((-1, 1, 2)) 181 | .astype("float32") 182 | ) 183 | points = cv2.undistortPoints( 184 | coords, 185 | camsys.cam1.K, 186 | camsys.cam1.distortion_coeffs, 187 | None, 188 | camsys.newR @ camsys.cam1.R.T, 189 | camsys.newK, 190 | ) 191 | inv_maps = points.reshape((H, W, 2)) 192 | 193 | return { 194 | "img_mapx": img_mapx, 195 | "img_mapy": img_mapy, 196 | "ev_mapx": 
ev_mapx, 197 | "ev_mapy": ev_mapy, 198 | "inv_mapx": inv_maps[..., 0], 199 | "inv_mapy": inv_maps[..., 1], 200 | } 201 | 202 | 203 | if __name__ == "__main__": 204 | parser = argparse.ArgumentParser("""Remap images to be aligned with frames""") 205 | parser.add_argument("sequence_name") 206 | parser.add_argument("--data_dir") 207 | parser.add_argument("--rotate", action="store_true") 208 | parser.add_argument("--flip", action="store_true") 209 | parser.add_argument("--map_key", default="") 210 | parser.add_argument("--debug", action="store_true", default="false") 211 | parser.add_argument( 212 | "-n", "--num_processes", help="Number of workers", type=int, default=8 213 | ) 214 | 215 | args = parser.parse_args() 216 | 217 | map_key = args.map_key 218 | 219 | # search for image folder 220 | image_dir = os.path.join(args.data_dir, args.sequence_name, "images") 221 | image_paths = sorted(glob.glob(join(image_dir, "*.png"))) 222 | output_image_dir = image_dir + "_corrected" 223 | if not os.path.exists(output_image_dir): 224 | os.makedirs(output_image_dir) 225 | 226 | # search for calibration file 227 | calibration_path = os.path.join(args.data_dir, "calib.yaml") 228 | with open(calibration_path, "r") as fh: 229 | cam_data = yaml.load(fh, Loader=yaml.SafeLoader) 230 | 231 | fix_rotation = False 232 | if args.rotate: 233 | fix_rotation = True 234 | 235 | camsys = CameraSystem(cam_data, fix_rotation) 236 | maps = getRemapping(camsys) 237 | 238 | for f in tqdm(image_paths, desc="Processing images..."): 239 | process_img( 240 | f, 241 | output_image_dir, 242 | (maps["img_mapx"], maps["img_mapy"]), 243 | args.flip, 244 | args.rotate, 245 | ) 246 | 247 | events_path = os.path.join(args.data_dir, args.sequence_name, "events.h5") 248 | output_events_path = events_path.replace(".h5", "_corrected.h5") 249 | print("Processing events...") 250 | process_events( 251 | events_path, 252 | output_events_path, 253 | (maps["inv_mapx"], maps["inv_mapy"]), 254 | tuple(camsys.cam1.resolution), 255 | args.rotate, 256 | ) 257 | -------------------------------------------------------------------------------- /data_preparation/real/generate_even_event_representation.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from pathlib import Path 3 | 4 | import fire 5 | import h5py 6 | import numpy as np 7 | from tqdm import tqdm 8 | 9 | from data_preparation.synthetic.generate_voxel_grids import ( 10 | blosc_opts, 11 | events_to_voxel_grid, 12 | ) 13 | from utils.dataset import ECSubseqDatasetV2, EDSSubseqDatasetV2 14 | from utils.representations import ( 15 | EventStack, 16 | TimeSurface, 17 | VoxelGrid, 18 | events_to_event_stack, 19 | events_to_time_surface, 20 | ) 21 | 22 | 23 | def generate(sequence_dir, sequence_type, representation_type, r): 24 | """ 25 | Generates event representations for pose supervision. 26 | Subdivides the events between frames into r bins and constructs a dense event representation for 27 | the events inside each bin. 
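    :param sequence_dir: path to the prepared sequence directory
    :param sequence_type: "EDS" or "EC"
    :param representation_type: one of "time_surface", "voxel_grid", "event_stack"
    :param r: number of representations (bins) generated between consecutive frames

    Example invocation (a sketch only; the path below is a placeholder, and the CLI
    arguments map onto this function's parameters via python-fire):
        python data_preparation/real/generate_even_event_representation.py \
            /path/to/<sequence_dir> EDS time_surface 5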
28 | """ 29 | sequence_dir = Path(sequence_dir) 30 | output_dir = sequence_dir / "events" / f"pose_{r:.0f}" / f"{representation_type}s_5" 31 | if output_dir.exists(): 32 | shutil.rmtree(output_dir) 33 | output_dir.mkdir(parents=True) 34 | 35 | if sequence_type == "EDS": 36 | dataset_class = EDSSubseqDatasetV2 37 | elif sequence_type == "EC": 38 | dataset_class = ECSubseqDatasetV2 39 | else: 40 | raise NotImplementedError(f"No dataset class for {sequence_type}") 41 | 42 | if representation_type == "time_surface": 43 | generation_function = events_to_time_surface 44 | representation = TimeSurface( 45 | (10, dataset_class.resolution[1], dataset_class.resolution[0]) 46 | ) 47 | elif representation_type == "voxel_grid": 48 | generation_function = events_to_voxel_grid 49 | representation = VoxelGrid( 50 | (5, dataset_class.resolution[1], dataset_class.resolution[0]), False 51 | ) 52 | elif representation_type == "event_stack": 53 | generation_function = events_to_event_stack 54 | representation = EventStack( 55 | (5, dataset_class.resolution[1], dataset_class.resolution[0]) 56 | ) 57 | else: 58 | raise NotImplementedError(f"No generation function for {representation_type}") 59 | 60 | ev_iterator = dataset_class.get_even_events_iterator(sequence_dir, r=r) 61 | 62 | for dt_evs, evs in tqdm(ev_iterator, desc="Generating ev reps..."): 63 | # Generate 64 | rep_tensor = generation_function( 65 | representation, evs["p"], evs["t"], evs["x"], evs["y"] 66 | ).numpy() 67 | rep_np = np.transpose(rep_tensor, (1, 2, 0)) 68 | 69 | # Write to disk 70 | output_path = output_dir / f"{str(int(round(dt_evs*1e6))).zfill(7)}.h5" 71 | if output_path.exists(): 72 | continue 73 | 74 | with h5py.File(output_path, "w") as h5f_out: 75 | h5f_out.create_dataset( 76 | f"{representation_type}", 77 | data=rep_np, 78 | shape=rep_np.shape, 79 | dtype=np.float32, 80 | **blosc_opts(complevel=1, shuffle="byte"), 81 | ) 82 | 83 | 84 | if __name__ == "__main__": 85 | fire.Fire(generate) 86 | -------------------------------------------------------------------------------- /data_preparation/real/generate_event_representation.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import fire 4 | import h5py 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | from data_preparation.synthetic.generate_voxel_grids import ( 9 | blosc_opts, 10 | events_to_voxel_grid, 11 | ) 12 | from utils.dataset import ECPoseSegmentDataset, EDSPoseSegmentDataset 13 | from utils.representations import ( 14 | EventStack, 15 | VoxelGrid, 16 | events_to_event_stack, 17 | events_to_time_surface, 18 | ) 19 | 20 | 21 | def generate(sequence_dir, sequence_type, representation_type, dt): 22 | """ 23 | :param sequence_dir: 24 | :param sequence_type: 25 | :param representation_type: 26 | :param dt: 27 | :return: 28 | """ 29 | 30 | sequence_dir = Path(sequence_dir) 31 | output_dir = sequence_dir / "events" / f"{dt:.4f}" / f"{representation_type}s_5" 32 | if not output_dir.exists(): 33 | output_dir.mkdir(parents=True) 34 | 35 | if sequence_type == "EDS": 36 | dataset_class = EDSPoseSegmentDataset 37 | elif sequence_type == "EC": 38 | dataset_class = ECPoseSegmentDataset 39 | else: 40 | raise NotImplementedError(f"No dataset class for {sequence_type}") 41 | # dataset_class = FPVPoseSegmentDataset 42 | 43 | if representation_type == "time_surface": 44 | generation_function = events_to_time_surface 45 | if representation_type == "voxel_grid": 46 | generation_function = events_to_voxel_grid 47 | 
representation = VoxelGrid( 48 | (5, dataset_class.resolution[1], dataset_class.resolution[0]), False 49 | ) 50 | elif representation_type == "event_stack": 51 | generation_function = events_to_event_stack 52 | representation = EventStack( 53 | (5, dataset_class.resolution[1], dataset_class.resolution[0]) 54 | ) 55 | else: 56 | raise NotImplementedError(f"No generation function for {representation_type}") 57 | 58 | ev_iterator = dataset_class.get_events_iterator(sequence_dir, dt=dt) 59 | 60 | for dt_evs, evs in tqdm(ev_iterator, desc="Generating ev reps..."): 61 | # Generate 62 | rep_tensor = generation_function( 63 | representation, evs["p"], evs["t"], evs["x"], evs["y"] 64 | ).numpy() 65 | rep_np = np.transpose(rep_tensor, (1, 2, 0)) 66 | 67 | # Write to disk 68 | output_path = output_dir / f"{str(int(round(dt_evs*1e6))).zfill(7)}.h5" 69 | with h5py.File(output_path, "w") as h5f_out: 70 | h5f_out.create_dataset( 71 | f"{representation_type}", 72 | data=rep_np, 73 | shape=rep_np.shape, 74 | dtype=np.float32, 75 | **blosc_opts(complevel=1, shuffle="byte"), 76 | ) 77 | 78 | 79 | if __name__ == "__main__": 80 | fire.Fire(generate) 81 | -------------------------------------------------------------------------------- /data_preparation/real/prepare_ec_pose_supervision.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate input event representations for pose refinement. 3 | Between each image, generate N representations. 4 | """ 5 | 6 | import multiprocessing 7 | import os 8 | from pathlib import Path 9 | 10 | import cv2 11 | import fire 12 | import h5py 13 | import hdf5plugin 14 | import numpy as np 15 | from matplotlib import pyplot as plt 16 | from pandas import read_csv 17 | from tqdm import tqdm 18 | 19 | from utils.utils import blosc_opts 20 | 21 | IMG_H = 180 22 | IMG_W = 240 23 | OUTPUT_DIR = None 24 | 25 | 26 | def generate_time_surfaces(sequence_dir, r=5, n_bins=5): 27 | sequence_dir = Path(sequence_dir) 28 | output_dir = sequence_dir / "events" / f"pose_{r}" 29 | if not output_dir.exists(): 30 | output_dir.mkdir(parents=True) 31 | 32 | # Read image timestamps 33 | frame_ts_arr = np.genfromtxt(sequence_dir / "images.txt", usecols=[0]) 34 | 35 | # Read events 36 | events = read_csv( 37 | str(sequence_dir / "events_corrected.txt"), delimiter=" " 38 | ).to_numpy() 39 | events_times = events[:, 0] * 1e6 40 | 41 | # Debug images 42 | debug_dir = sequence_dir / "events" / f"pose_{r}_debug" 43 | if not debug_dir.exists(): 44 | debug_dir.mkdir() 45 | 46 | # Generate time surfaces 47 | idx_ts = 1 48 | for i in tqdm(range(len(frame_ts_arr) - 1)): 49 | dt_us = (frame_ts_arr[i + 1] - frame_ts_arr[i]) * 1e6 // r 50 | dt_bin_us = dt_us / n_bins 51 | 52 | t0 = frame_ts_arr[i] * 1e6 53 | for j in range(r): 54 | if j == r - 1: 55 | t1 = frame_ts_arr[i + 1] * 1e6 56 | else: 57 | t1 = t0 + dt_us 58 | 59 | output_path = output_dir / f"{int(t1)}.h5" 60 | idx_ts += 1 61 | if output_path.exists(): 62 | continue 63 | 64 | time_surface = np.zeros((IMG_H, IMG_W, n_bins * 2), dtype=np.uint64) 65 | 66 | # iterate over bins 67 | for i_bin in range(5): 68 | t0_bin = t0 + i_bin * dt_bin_us 69 | if i_bin == 4: 70 | t1_bin = t1 71 | else: 72 | t1_bin = t0_bin + dt_bin_us 73 | 74 | first_idx = np.searchsorted(events_times, t0_bin, side="left") 75 | last_idx_p1 = np.searchsorted(events_times, t1_bin, side="right") 76 | 77 | x_bin = np.rint(np.array(events[first_idx:last_idx_p1, 1])).astype(int) 78 | y_bin = np.rint(np.array(events[first_idx:last_idx_p1, 
2])).astype(int) 79 | p_bin = np.array(events[first_idx:last_idx_p1, 3]) 80 | t_bin = np.array(events[first_idx:last_idx_p1, 0]) * 1e6 81 | 82 | n_events = len(x_bin) 83 | for i_e in range(n_events): 84 | time_surface[ 85 | y_bin[i_e], x_bin[i_e], 2 * i_bin + int(p_bin[i_e]) 86 | ] = (t_bin[i_e] - t0) 87 | time_surface = np.divide(time_surface, dt_us) 88 | 89 | # Write to disk 90 | with h5py.File(output_path, "w") as h5f_out: 91 | h5f_out.create_dataset( 92 | "time_surface", 93 | data=time_surface, 94 | shape=time_surface.shape, 95 | dtype=np.float32, 96 | **blosc_opts(complevel=1, shuffle="byte"), 97 | ) 98 | 99 | t0 = t1 100 | 101 | 102 | if __name__ == "__main__": 103 | fire.Fire(generate_time_surfaces) 104 | -------------------------------------------------------------------------------- /data_preparation/real/prepare_ec_subseq.py: -------------------------------------------------------------------------------- 1 | """ 2 | Prepare data for a subset of an Event Camera Dataset sequence 3 | - Undistort images and events 4 | - Create time surfaces 5 | - Create an output directory with undistorted images, undistorted event txt, and time surfaces 6 | """ 7 | import os 8 | import shutil 9 | from glob import glob 10 | from pathlib import Path 11 | 12 | import cv2 13 | import h5py 14 | import hdf5plugin 15 | import numpy as np 16 | from fire import Fire 17 | from matplotlib import pyplot as plt 18 | from pandas import read_csv 19 | from tqdm import tqdm 20 | 21 | from utils.utils import blosc_opts 22 | 23 | 24 | def prepare_data(root_dir, sequence_name, start_idx, end_idx): 25 | sequence_dir = Path(root_dir) / sequence_name 26 | if not sequence_dir.exists(): 27 | print(f"Sequence directory does not exist for {sequence_name}") 28 | exit() 29 | 30 | # Read calib 31 | calib_data = np.genfromtxt(str(sequence_dir / "calib.txt")) 32 | camera_matrix = calib_data[:4] 33 | distortion_coeffs = calib_data[4:] 34 | camera_matrix = np.array( 35 | [ 36 | [camera_matrix[0], 0, camera_matrix[2]], 37 | [0, camera_matrix[1], camera_matrix[3]], 38 | [0, 0, 1], 39 | ] 40 | ) 41 | print("Calibration loaded") 42 | 43 | # Create output directory 44 | subseq_dir = Path(root_dir) / f"{sequence_name}_{start_idx}_{end_idx}" 45 | subseq_dir.mkdir(exist_ok=True) 46 | 47 | # Undistort images 48 | images_dir = sequence_dir / "images_corrected" 49 | if not images_dir.exists(): 50 | images_dir.mkdir() 51 | for img_idx, img_path in enumerate( 52 | tqdm( 53 | sorted(glob(os.path.join(str(sequence_dir / "images" / "*.png")))), 54 | desc="Undistorting images...", 55 | ) 56 | ): 57 | img = cv2.imread(img_path) 58 | img = cv2.undistort( 59 | img, cameraMatrix=camera_matrix, distCoeffs=distortion_coeffs 60 | ) 61 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 62 | filename = f"frame_{str(img_idx).zfill(8)}.png" 63 | cv2.imwrite(os.path.join(str(images_dir / filename)), img) 64 | img_tmp = cv2.imread(str(images_dir / "frame_00000000.png")) 65 | H_img, W_img = img_tmp.shape[:2] 66 | 67 | # Remove first entry in image timestamps 68 | image_timestamps = np.genfromtxt(str(sequence_dir / "images.txt"), usecols=[0]) 69 | image_timestamps = image_timestamps[1:] 70 | np.savetxt(str(sequence_dir / "images.txt"), image_timestamps) 71 | 72 | # Undistort events 73 | events_corrected_path = sequence_dir / "events_corrected.txt" 74 | if not events_corrected_path.exists(): 75 | events = read_csv( 76 | str(sequence_dir / "events.txt"), header=None, delimiter=" " 77 | ).to_numpy() 78 | print("Raw events loaded") 79 | 80 | events[:, 1:3] = 
cv2.undistortPoints( 81 | events[:, 1:3].reshape((-1, 1, 2)), 82 | camera_matrix, 83 | distortion_coeffs, 84 | P=camera_matrix, 85 | ).reshape( 86 | (-1, 2), 87 | ) 88 | events[:, 1:3] = np.rint(events[:, 1:3]) 89 | 90 | inbounds_mask = np.logical_and(events[:, 1] >= 0, events[:, 1] < W_img) 91 | inbounds_mask = np.logical_and(inbounds_mask, events[:, 2] >= 0) 92 | inbounds_mask = np.logical_and(inbounds_mask, events[:, 2] < H_img) 93 | events = events[inbounds_mask, :] 94 | 95 | print("Events undistorted") 96 | np.savetxt(events_corrected_path, events, ["%.9f", "%i", "%i", "%i"]) 97 | else: 98 | events = read_csv( 99 | str(events_corrected_path), header=None, delimiter=" " 100 | ).to_numpy() 101 | t_events = events[:, 0] 102 | 103 | subseq_images_dir = subseq_dir / "images_corrected" 104 | if not subseq_images_dir.exists(): 105 | subseq_images_dir.mkdir() 106 | 107 | for i in range(start_idx, end_idx + 1): 108 | shutil.copy( 109 | str(images_dir / f"frame_{str(i).zfill(8)}.png"), 110 | str(subseq_images_dir / f"frame_{str(i-start_idx).zfill(8)}.png"), 111 | ) 112 | 113 | # Get image dimensions 114 | IMG_H, IMG_W = cv2.imread( 115 | str(images_dir / "frame_00000001.png"), cv2.IMREAD_GRAYSCALE 116 | ).shape 117 | 118 | # Read image timestamps 119 | image_timestamps = np.genfromtxt(sequence_dir / "images.txt", usecols=[0]) 120 | image_timestamps = image_timestamps[start_idx : end_idx + 1] 121 | np.savetxt(str(subseq_dir / "images.txt"), image_timestamps) 122 | print( 123 | f"Image timestamps are in range [{image_timestamps[0]}, {image_timestamps[-1]}]" 124 | ) 125 | print(f"Event timestamps are in range [{t_events.min()}, {t_events.max()}]") 126 | 127 | # Copy calib and poses 128 | shutil.copy(str(sequence_dir / "calib.txt"), str(subseq_dir / "calib.txt")) 129 | shutil.copy( 130 | str(sequence_dir / "groundtruth.txt"), str(subseq_dir / "groundtruth.txt") 131 | ) 132 | 133 | # Generate debug frames 134 | debug_dir = sequence_dir / "debug_frames" 135 | debug_dir.mkdir(exist_ok=True) 136 | n_frames_debug = 0 137 | dt = 0.005 138 | for i in range(n_frames_debug): 139 | # Events 140 | t1 = image_timestamps[i] 141 | t0 = t1 - dt 142 | time_mask = np.logical_and(events[:, 0] >= t0, events[:, 0] < t1) 143 | events_slice = events[time_mask, :] 144 | 145 | on_mask = events_slice[:, 3] == 1 146 | off_mask = events_slice[:, 3] == 0 147 | events_slice_on = events_slice[on_mask, :] 148 | events_slice_off = events_slice[off_mask, :] 149 | 150 | # Image 151 | img = cv2.imread( 152 | str(images_dir / f"frame_{str(i).zfill(8)}.png"), cv2.IMREAD_GRAYSCALE 153 | ) 154 | 155 | fig = plt.figure() 156 | ax = fig.add_subplot() 157 | ax.imshow(img, cmap="gray") 158 | ax.scatter(events_slice_on[:, 1], events_slice_on[:, 2], s=5, c="green") 159 | ax.scatter(events_slice_off[:, 1], events_slice_off[:, 2], s=5, c="red") 160 | plt.show() 161 | fig.savefig(str(debug_dir / f"frame_{str(i).zfill(8)}.png")) 162 | fig.close() 163 | 164 | # Generate time surfaces 165 | for dt in [0.01, 0.02]: 166 | for n_bins in [1, 5]: 167 | dt_bin = dt / n_bins 168 | output_ts_dir = ( 169 | subseq_dir / "events" / f"{dt:.4f}" / f"time_surfaces_v2_{n_bins}" 170 | ) 171 | if not output_ts_dir.exists(): 172 | output_ts_dir.mkdir(parents=True, exist_ok=True) 173 | 174 | debug_dir = subseq_dir / f"debug_events_{n_bins}" 175 | debug_dir.mkdir(exist_ok=True) 176 | for i, t1 in tqdm( 177 | enumerate( 178 | np.arange(image_timestamps[0], image_timestamps[-1] + dt, dt) 179 | ), 180 | total=int((image_timestamps[-1] - image_timestamps[0]) / dt), 181 | 
desc="Generating time surfaces...", 182 | ): 183 | output_ts_path = ( 184 | output_ts_dir / f"{str(int(i * (dt * 1e6))).zfill(7)}.h5" 185 | ) 186 | if output_ts_path.exists(): 187 | continue 188 | 189 | time_surface = np.zeros((IMG_H, IMG_W, 2 * n_bins), dtype=np.float64) 190 | t0 = t1 - dt 191 | 192 | # iterate over bins 193 | for i_bin in range(n_bins): 194 | t0_bin = t0 + i_bin * dt_bin 195 | t1_bin = t0_bin + dt_bin 196 | 197 | time_mask = np.logical_and( 198 | events[:, 0] >= t0_bin, events[:, 0] < t1_bin 199 | ) 200 | events_slice = events[time_mask, :] 201 | 202 | for i in range(events_slice.shape[0]): 203 | if ( 204 | 0 <= events_slice[i, 2] < IMG_H 205 | and 0 <= events_slice[i, 1] < IMG_W 206 | ): 207 | time_surface[ 208 | int(events_slice[i, 2]), 209 | int(events_slice[i, 1]), 210 | 2 * i_bin + int(events_slice[i, 3]), 211 | ] = ( 212 | events_slice[i, 0] - t0 213 | ) 214 | time_surface = np.divide(time_surface, dt) 215 | 216 | with h5py.File(output_ts_path, "w") as h5f_out: 217 | h5f_out.create_dataset( 218 | "time_surface", 219 | data=time_surface, 220 | shape=time_surface.shape, 221 | dtype=np.float32, 222 | **blosc_opts(complevel=1, shuffle="byte"), 223 | ) 224 | 225 | # Visualize 226 | debug_event_frame = ((time_surface[:, :, 0] > 0) * 255).astype(np.uint8) 227 | cv2.imwrite( 228 | str(debug_dir / f"{str(int(i * dt * 1e6)).zfill(7)}.png"), 229 | debug_event_frame, 230 | ) 231 | 232 | 233 | if __name__ == "__main__": 234 | Fire(prepare_data) 235 | -------------------------------------------------------------------------------- /data_preparation/real/prepare_eds_pose_supervision.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate input event representations for pose refinement. 3 | Between each image, generate N representations. 
4 | """ 5 | import multiprocessing 6 | import os 7 | import timeit 8 | from pathlib import Path 9 | 10 | import cv2 11 | import fire 12 | import h5py 13 | import hdf5plugin 14 | import numpy as np 15 | from tqdm import tqdm 16 | 17 | from utils.utils import blosc_opts 18 | 19 | IMG_H = 480 20 | IMG_W = 640 21 | OUTPUT_DIR = None 22 | 23 | 24 | def generate_time_surfaces(sequence_dir, r=3, n_bins=5): 25 | count = 0 26 | sequence_dir = Path(sequence_dir) 27 | output_dir = sequence_dir / "events" / f"pose_{r}" 28 | if not output_dir.exists(): 29 | output_dir.mkdir(parents=True) 30 | 31 | # Read image timestamps 32 | frame_ts_arr = np.genfromtxt(str(sequence_dir / "images_timestamps.txt")) 33 | 34 | # Read events 35 | events_file = h5py.File(str(sequence_dir / "events_corrected.h5"), "r") 36 | events_times = np.array(events_file["t"]) 37 | print(f"Last event at: {events_times[-1]}") 38 | 39 | for i in tqdm(range(len(frame_ts_arr) - 1)): 40 | print(f"{i}/{len(frame_ts_arr)} at {timeit.default_timer()}") 41 | dt_us = (frame_ts_arr[i + 1] - frame_ts_arr[i]) // r 42 | dt_bin_us = dt_us / n_bins 43 | 44 | t0 = frame_ts_arr[i] 45 | for j in range(r): 46 | count += 1 47 | if j == r - 1: 48 | t1 = frame_ts_arr[i + 1] 49 | else: 50 | t1 = t0 + dt_us 51 | 52 | output_path = output_dir / f"{int(t1)}.h5" 53 | if output_path.exists(): 54 | continue 55 | 56 | time_surface = np.zeros((IMG_H, IMG_W, n_bins * 2), dtype=np.uint64) 57 | 58 | # iterate over bins 59 | for i_bin in range(5): 60 | t0_bin = t0 + i_bin * dt_bin_us 61 | if i_bin == 4: 62 | t1_bin = t1 63 | else: 64 | t1_bin = t0_bin + dt_bin_us 65 | 66 | first_idx = np.searchsorted(events_times, t0_bin, side="left") 67 | last_idx_p1 = np.searchsorted(events_times, t1_bin, side="right") 68 | 69 | x_bin = np.rint( 70 | np.array(events_file["x"][first_idx:last_idx_p1]) 71 | ).astype(int) 72 | y_bin = np.rint( 73 | np.array(events_file["y"][first_idx:last_idx_p1]) 74 | ).astype(int) 75 | p_bin = np.array(events_file["p"][first_idx:last_idx_p1]) 76 | t_bin = np.array(events_file["t"][first_idx:last_idx_p1]) 77 | 78 | n_events = len(x_bin) 79 | for i_e in range(n_events): 80 | time_surface[ 81 | y_bin[i_e], x_bin[i_e], 2 * i_bin + int(p_bin[i_e]) 82 | ] = (t_bin[i_e] - t0) 83 | time_surface = np.divide(time_surface, dt_us) 84 | 85 | # Write to disk 86 | with h5py.File(output_path, "w") as h5f_out: 87 | h5f_out.create_dataset( 88 | "time_surface", 89 | data=time_surface, 90 | shape=time_surface.shape, 91 | dtype=np.float32, 92 | **blosc_opts(complevel=1, shuffle="byte", complib="blosc:zstd"), 93 | ) 94 | t0 = t1 95 | 96 | 97 | if __name__ == "__main__": 98 | fire.Fire(generate_time_surfaces) 99 | -------------------------------------------------------------------------------- /data_preparation/real/prepare_eds_subseq.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from pathlib import Path 4 | from shutil import copy 5 | 6 | import fire 7 | import h5py 8 | import hdf5plugin 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | from tqdm import tqdm 12 | 13 | from utils.utils import blosc_opts 14 | 15 | IMG_H = 480 16 | IMG_W = 640 17 | 18 | 19 | def generate_subseq(seq_name, start_idx, end_idx, dt): 20 | """ 21 | Generate a subsequence of an EDS sequence with folder structure: 22 | events 23 |
24 | 25 | 0000000.h5 26 | ... 27 | images 28 | frame_.png 29 | frame_.png 30 | ... 31 | image_timestamps.txt 32 | stamped_groundtruth.txt 33 | :param seq_name: 34 | :param start_idx: Image index to start 35 | :param end_idx: Terminal image index (non-inclusive) 36 | :param dt: time delete used for time surface generation 37 | """ 38 | 39 | # Pathing 40 | input_dir = Path(f"/{seq_name}") 41 | output_dir = Path( 42 | f"/{seq_name}_{start_idx}_{end_idx}" 43 | ) 44 | if not output_dir.exists(): 45 | output_dir.mkdir(parents=True, exist_ok=True) 46 | 47 | # Copy pose data 48 | copy( 49 | str(input_dir / "stamped_groundtruth.txt"), 50 | str(output_dir / "stamped_groundtruth.txt"), 51 | ) 52 | 53 | # Filter timestamps 54 | image_timestamps = np.genfromtxt(str(input_dir / "images_timestamps.txt"))[ 55 | start_idx:end_idx 56 | ] 57 | np.savetxt(str(output_dir / "images_timestamps.txt"), image_timestamps, fmt="%i") 58 | 59 | # Copy images 60 | output_image_dir = output_dir / "images_corrected" 61 | if not output_image_dir.exists(): 62 | output_image_dir.mkdir() 63 | 64 | for idx in tqdm(range(start_idx, end_idx), desc="Copying images..."): 65 | copy( 66 | str(input_dir / "images_corrected" / f"frame_{str(idx).zfill(10)}.png"), 67 | str( 68 | output_dir 69 | / "images_corrected" 70 | / f"frame_{str(idx-start_idx).zfill(10)}.png" 71 | ), 72 | ) 73 | 74 | # Generate time surfaces 75 | dt_us = dt * 1e6 76 | for n_bins in [5]: 77 | dt_bin_us = dt_us / n_bins 78 | output_ts_dir = ( 79 | output_dir / "events" / f"{dt:.4f}" / f"time_surfaces_v2_{n_bins}" 80 | ) 81 | if not output_ts_dir.exists(): 82 | output_ts_dir.mkdir(parents=True, exist_ok=True) 83 | 84 | with h5py.File(str(input_dir / "events_corrected.h5"), "r") as h5f: 85 | time = np.asarray(h5f["t"]) 86 | 87 | for i, t1 in tqdm( 88 | enumerate( 89 | np.arange(image_timestamps[0], image_timestamps[-1] + dt_us, dt_us) 90 | ), 91 | total=int((image_timestamps[-1] - image_timestamps[0]) / dt_us), 92 | desc="Generating time surfaces...", 93 | ): 94 | output_ts_path = output_ts_dir / f"{str(int(i*dt_us)).zfill(7)}.h5" 95 | if output_ts_path.exists(): 96 | continue 97 | 98 | time_surface = np.zeros((IMG_H, IMG_W, 2 * n_bins), dtype=np.uint64) 99 | t0 = t1 - dt_us 100 | 101 | # iterate over bins 102 | for i_bin in range(n_bins): 103 | t0_bin = t0 + i_bin * dt_bin_us 104 | t1_bin = t0_bin + dt_bin_us 105 | 106 | first_idx = np.searchsorted(time, t0_bin, side="left") 107 | last_idx_p1 = np.searchsorted(time, t1_bin, side="right") 108 | out = { 109 | "x": np.rint( 110 | np.asarray(h5f["x"][first_idx:last_idx_p1]) 111 | ).astype(int), 112 | "y": np.rint( 113 | np.asarray(h5f["y"][first_idx:last_idx_p1]) 114 | ).astype(int), 115 | "p": np.asarray(h5f["p"][first_idx:last_idx_p1]), 116 | "t": time[first_idx:last_idx_p1], 117 | } 118 | n_events = out["x"].shape[0] 119 | 120 | for i in range(n_events): 121 | time_surface[ 122 | out["y"][i], out["x"][i], 2 * i_bin + int(out["p"][i]) 123 | ] = (out["t"][i] - t0) 124 | time_surface = np.divide(time_surface, dt_us) 125 | with h5py.File(output_ts_path, "w") as h5f_out: 126 | h5f_out.create_dataset( 127 | "time_surface", 128 | data=time_surface, 129 | shape=time_surface.shape, 130 | dtype=np.float32, 131 | **blosc_opts(complevel=1, shuffle="byte"), 132 | ) 133 | 134 | # Visualize 135 | for i in range(n_bins): 136 | plt.imshow((time_surface[:, :, i] * 255).astype(np.uint8)) 137 | plt.show() 138 | 139 | # Storing events in one cropped file 140 | first_t, last_t = image_timestamps[0], image_timestamps[-1] 141 | event_idx = 
np.searchsorted(time, np.asarray([first_t, last_t]), side="left") 142 | output_path = output_ts_dir / "events.h5" 143 | with h5py.File(output_path, "w") as h5f_out: 144 | h5f_out.create_dataset( 145 | "x", data=h5f["x"][event_idx[0] : event_idx[1]].astype(np.uint16) 146 | ) 147 | h5f_out.create_dataset( 148 | "y", data=h5f["y"][event_idx[0] : event_idx[1]].astype(np.uint16) 149 | ) 150 | h5f_out.create_dataset("p", data=h5f["p"][event_idx[0] : event_idx[1]]) 151 | h5f_out.create_dataset("t", data=time[event_idx[0] : event_idx[1]]) 152 | 153 | 154 | if __name__ == "__main__": 155 | fire.Fire(generate_subseq) 156 | -------------------------------------------------------------------------------- /data_preparation/real/rectify_ec.py: -------------------------------------------------------------------------------- 1 | """ 2 | Prepare data for a subset of an Event Camera Dataset sequence 3 | - Undistort images and events 4 | - Create time surfaces 5 | - Create an output directory with undistorted images, undistorted event txt, and time surfaces 6 | """ 7 | from pathlib import Path 8 | import os 9 | from glob import glob 10 | from fire import Fire 11 | from tqdm import tqdm 12 | 13 | from pandas import read_csv 14 | import numpy as np 15 | import cv2 16 | 17 | 18 | def prepare_data(root_dir, sequence_name): 19 | sequence_dir = Path(root_dir) / sequence_name 20 | if not sequence_dir.exists(): 21 | print(f"Sequence directory does not exist for {sequence_name}") 22 | exit() 23 | 24 | # Read calib 25 | calib_data = np.genfromtxt(str(sequence_dir / 'calib.txt')) 26 | camera_matrix = calib_data[:4] 27 | distortion_coeffs = calib_data[4:] 28 | camera_matrix = np.array([[camera_matrix[0], 0, camera_matrix[2]], 29 | [0, camera_matrix[1], camera_matrix[3]], 30 | [0, 0, 1]]) 31 | print("Calibration loaded") 32 | 33 | # Undistort images 34 | images_dir = sequence_dir / 'images_corrected' 35 | images_dir.mkdir() 36 | for img_idx, img_path in enumerate(tqdm(sorted(glob(str(sequence_dir / 'images' / '*.png'))), 37 | desc="Undistorting images...")): 38 | img = cv2.imread(img_path) 39 | img = cv2.undistort(img, cameraMatrix=camera_matrix, distCoeffs=distortion_coeffs) 40 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 41 | filename = f'frame_{str(img_idx).zfill(8)}.png' 42 | cv2.imwrite(os.path.join(str(images_dir / filename)), img) 43 | img_tmp = cv2.imread(str(images_dir / 'frame_00000000.png')) 44 | H_img, W_img = img_tmp.shape[:2] 45 | 46 | # Remove first entry in image timestamps if not already 47 | image_timestamps = np.genfromtxt(str(sequence_dir / 'images.txt'), usecols=[0]) 48 | if len(image_timestamps) == len(glob(str(sequence_dir / 'images' / '*.png'))): 49 | np.savetxt(str(sequence_dir / 'images.txt'), image_timestamps) 50 | 51 | # Undistort events 52 | events_corrected_path = sequence_dir / 'events_corrected.txt' 53 | if not events_corrected_path.exists(): 54 | events = read_csv(str(sequence_dir / 'events.txt'), header=None, delimiter=' ').to_numpy() 55 | print("Raw events loaded") 56 | 57 | events[:, 1:3] = cv2.undistortPoints(events[:, 1:3].reshape((-1, 1, 2)), 58 | camera_matrix, distortion_coeffs, P=camera_matrix).reshape((-1, 2),) 59 | events[:, 1:3] = np.rint(events[:, 1:3]) 60 | 61 | inbounds_mask = np.logical_and(events[:, 1] >= 0, events[:, 1] < W_img) 62 | inbounds_mask = np.logical_and(inbounds_mask, events[:, 2] >= 0) 63 | inbounds_mask = np.logical_and(inbounds_mask, events[:, 2] < H_img) 64 | events = events[inbounds_mask, :] 65 | 66 | print("Events undistorted") 67 | 
np.savetxt(events_corrected_path, events, ["%.9f", "%i", "%i", "%i"]) 68 | 69 | 70 | if __name__ == '__main__': 71 | Fire(prepare_data) 72 | -------------------------------------------------------------------------------- /data_preparation/synthetic/generate_event_representations.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import os 3 | from pathlib import Path 4 | 5 | import cv2 6 | import fire 7 | import h5py 8 | import hdf5plugin 9 | import numpy as np 10 | from tqdm import tqdm 11 | 12 | from utils.representations import VoxelGrid, events_to_voxel_grid 13 | from utils.utils import blosc_opts 14 | 15 | IMG_H = 384 16 | IMG_W = 512 17 | VOXEL_GRID_CONSTRUCTOR = VoxelGrid((5, 384, 512), True) 18 | 19 | 20 | def generate_event_count_images_single( 21 | input_seq_dir, output_dir, visualize=False, dt=0.01, **kwargs 22 | ): 23 | """ 24 | For each ts in [0.4:0.01:0.9], generate the event count image for a Multiflow sequence 25 | :param seq_dir: 26 | :return: 27 | """ 28 | input_seq_dir = Path(input_seq_dir) 29 | split = input_seq_dir.parents[0].stem 30 | output_seq_dir = output_dir / split / input_seq_dir.stem 31 | output_dir = output_seq_dir / "events" / "count_images" 32 | output_dir.mkdir(exist_ok=True, parents=True) 33 | dt_us = dt * 1e6 34 | 35 | with h5py.File(str(input_seq_dir / "events" / "events.h5"), "r") as h5f: 36 | time = np.asarray(h5f["t"]) 37 | 38 | for t1 in np.arange(400000, 900000 + dt_us, dt_us): 39 | output_path = output_dir / f"0{t1}.npy" 40 | if output_path.exists(): 41 | continue 42 | 43 | t0 = t1 - dt_us 44 | first_idx = np.searchsorted(time, t0, side="left") 45 | last_idx_p1 = np.searchsorted(time, t1, side="right") 46 | out = { 47 | "x": np.asarray(h5f["x"][first_idx:last_idx_p1]), 48 | "y": np.asarray(h5f["y"][first_idx:last_idx_p1]), 49 | "p": np.asarray(h5f["p"][first_idx:last_idx_p1]), 50 | "t": time[first_idx:last_idx_p1], 51 | } 52 | n_events = out["x"].shape[0] 53 | img_counts = np.zeros((IMG_H, IMG_W, 2), dtype=np.uint8) 54 | for i in range(n_events): 55 | img_counts[out["y"][i], out["x"][i], out["p"][i]] += 1 56 | 57 | # Write to disk 58 | np.save(str(output_path), img_counts) 59 | 60 | # Visualize 61 | if visualize: 62 | img_vis = np.interp(img_counts, (0, img_counts.max()), (0, 255)).astype( 63 | np.uint8 64 | ) 65 | img_vis = np.concatenate([img_vis, np.zeros((IMG_H, IMG_W, 1))], axis=2) 66 | cv2.imshow("Count Image", img_vis) 67 | cv2.waitKey(1) 68 | 69 | 70 | def generate_sbt_single( 71 | input_seq_dir, output_dir, visualize=False, n_bins=5, dt=0.01, **kwargs 72 | ): 73 | """ 74 | For each ts in [0.4:0.02:0.9], generate the event count image 75 | :param seq_dir: 76 | :return: 77 | """ 78 | input_seq_dir = Path(input_seq_dir) 79 | split = input_seq_dir.parents[0].stem 80 | output_seq_dir = output_dir / split / input_seq_dir.stem 81 | output_dir = output_seq_dir / "events" / f"{dt:.4f}" / f"event_stacks_{n_bins}" 82 | output_dir.mkdir(exist_ok=True, parents=True) 83 | dt_us = dt * 1e6 84 | dt_us_bin = dt_us / n_bins 85 | 86 | with h5py.File(str(input_seq_dir / "events" / "events.h5"), "r") as h5f: 87 | x, y, p, time = ( 88 | np.asarray(h5f["x"]), 89 | np.asarray(h5f["y"]), 90 | np.asarray(h5f["p"]), 91 | np.asarray(h5f["t"]), 92 | ) 93 | 94 | # dt of labels 95 | for t1 in np.arange(400000, 900000 + dt_us, dt_us): 96 | output_path = output_dir / f"0{t1}.h5" 97 | if output_path.exists(): 98 | continue 99 | 100 | time_surface = np.zeros((IMG_H, IMG_W, n_bins), dtype=np.int64) 101 | t0 = 
t1 - dt_us 102 | 103 | # iterate over bins 104 | for i_bin in range(n_bins): 105 | t0_bin = t0 + i_bin * dt_us_bin 106 | t1_bin = t0_bin + dt_us_bin 107 | idx0 = np.searchsorted(time, t0_bin, side="left") 108 | idx1 = np.searchsorted(time, t1_bin, side="right") 109 | x_bin = x[idx0:idx1] 110 | y_bin = y[idx0:idx1] 111 | p_bin = p[idx0:idx1] * 2 - 1 112 | 113 | n_events = len(x_bin) 114 | for i in range(n_events): 115 | time_surface[y_bin[i], x_bin[i], i_bin] += p_bin[i] 116 | 117 | # Write to disk 118 | with h5py.File(output_path, "w") as h5f_out: 119 | h5f_out.create_dataset( 120 | "event_stack", 121 | data=time_surface, 122 | shape=time_surface.shape, 123 | dtype=np.float32, 124 | **blosc_opts(complevel=1, shuffle="byte"), 125 | ) 126 | # Visualize 127 | if visualize: 128 | for i in range(n_bins): 129 | cv2.imshow( 130 | f"Time Surface Bin {i}", 131 | (time_surface[:, :, i] * 255).astype(np.uint8), 132 | ) 133 | cv2.waitKey(0) 134 | 135 | 136 | def generate_time_surface_single( 137 | input_seq_dir, output_dir, visualize=False, n_bins=5, dt=0.01, **kwargs 138 | ): 139 | """ 140 | For each ts in [0.4:0.02:0.9], generate the event count image 141 | :param seq_dir: 142 | :return: 143 | """ 144 | input_seq_dir = Path(input_seq_dir) 145 | split = input_seq_dir.parents[0].stem 146 | output_seq_dir = output_dir / split / input_seq_dir.stem 147 | output_dir = output_seq_dir / "events" / "0.0200" / f"time_surfaces_v2_{n_bins}" 148 | output_dir.mkdir(exist_ok=True, parents=True) 149 | dt_us = dt * 1e6 150 | dt_us_bin = dt_us / n_bins 151 | 152 | with h5py.File(str(input_seq_dir / "events" / "events.h5"), "r") as h5f: 153 | time = np.asarray(h5f["t"]) 154 | idxs_sorted = np.argsort(time) 155 | x, y, p, time = ( 156 | np.asarray(h5f["x"])[idxs_sorted], 157 | np.asarray(h5f["y"])[idxs_sorted], 158 | np.asarray(h5f["p"])[idxs_sorted], 159 | np.asarray(h5f["t"])[idxs_sorted], 160 | ) 161 | 162 | # dt of labels 163 | for t1 in np.arange(400000, 900000 + dt_us, dt_us): 164 | output_path = output_dir / f"0{t1}.h5" 165 | if output_path.exists(): 166 | continue 167 | 168 | time_surface = np.zeros((IMG_H, IMG_W, n_bins * 2), dtype=np.uint64) 169 | t0 = t1 - dt_us 170 | 171 | # iterate over bins 172 | for i_bin in range(n_bins): 173 | t0_bin = t0 + i_bin * dt_us_bin 174 | t1_bin = t0_bin + dt_us_bin 175 | mask_t = np.logical_and(time > t0_bin, time <= t1_bin) 176 | x_bin, y_bin, p_bin, t_bin = ( 177 | x[mask_t], 178 | y[mask_t], 179 | p[mask_t], 180 | time[mask_t], 181 | ) 182 | n_events = len(x_bin) 183 | for i in range(n_events): 184 | time_surface[y_bin[i], x_bin[i], 2 * i_bin + int(p_bin[i])] = ( 185 | t_bin[i] - t0 186 | ) 187 | time_surface = np.divide(time_surface, dt_us) 188 | 189 | # Write to disk 190 | with h5py.File(output_path, "w") as h5f_out: 191 | h5f_out.create_dataset( 192 | "time_surface", 193 | data=time_surface, 194 | shape=time_surface.shape, 195 | dtype=np.float32, 196 | **blosc_opts(complevel=1, shuffle="byte"), 197 | ) 198 | # Visualize 199 | if visualize: 200 | for i in range(n_bins): 201 | cv2.imshow( 202 | f"Time Surface Bin {i}", 203 | (time_surface[:, :, i] * 255).astype(np.uint8), 204 | ) 205 | cv2.waitKey(0) 206 | 207 | 208 | def generate_voxel_grid_single(input_seq_dir, output_dir, n_bins=5, dt=0.01, **kwargs): 209 | """ 210 | For each ts in [0.4:0.02:0.9], generate the event count image 211 | :param seq_dir: 212 | :return: 213 | """ 214 | input_seq_dir = Path(input_seq_dir) 215 | split = input_seq_dir.parents[0].stem 216 | output_seq_dir = output_dir / split / 
input_seq_dir.stem 217 | output_dir = output_seq_dir / "events" / f"{dt:.4f}" / f"voxel_grids_{n_bins}" 218 | output_dir.mkdir(exist_ok=True, parents=True) 219 | 220 | dt_us = dt * 1e6 221 | 222 | with h5py.File(str(input_seq_dir / "events" / "events.h5"), "r") as h5f: 223 | time = np.asarray(h5f["t"]) 224 | idxs_sorted = np.argsort(time) 225 | x, y, p, time = ( 226 | np.asarray(h5f["x"])[idxs_sorted], 227 | np.asarray(h5f["y"])[idxs_sorted], 228 | np.asarray(h5f["p"])[idxs_sorted], 229 | np.asarray(h5f["t"])[idxs_sorted], 230 | ) 231 | 232 | # dt of labels 233 | for t1 in np.arange(400000, 900000 + dt_us, dt_us): 234 | output_path = output_dir / f"0{int(t1)}.h5" 235 | if output_path.exists(): 236 | continue 237 | 238 | t0 = t1 - dt_us 239 | mask_t = np.logical_and(time > t0, time <= t1) 240 | x_bin, y_bin, p_bin, t_bin = x[mask_t], y[mask_t], p[mask_t], time[mask_t] 241 | curr_voxel_grid = events_to_voxel_grid( 242 | VOXEL_GRID_CONSTRUCTOR, p_bin, t_bin, x_bin, y_bin 243 | ) 244 | curr_voxel_grid = curr_voxel_grid.numpy() 245 | curr_voxel_grid = np.transpose(curr_voxel_grid, (1, 2, 0)) 246 | 247 | # Write to disk 248 | with h5py.File(output_path, "w") as h5f_out: 249 | h5f_out.create_dataset( 250 | "voxel_grid", 251 | data=curr_voxel_grid, 252 | shape=curr_voxel_grid.shape, 253 | dtype=np.float32, 254 | **blosc_opts(complevel=1, shuffle="byte"), 255 | ) 256 | 257 | 258 | def generate( 259 | input_dir, 260 | output_dir, 261 | representation_type, 262 | dts=(0.01, 0.02), 263 | n_bins=5, 264 | visualize=False, 265 | **kwargs, 266 | ): 267 | input_dir = Path(input_dir) 268 | output_dir = Path(output_dir) 269 | 270 | if representation_type == "time_surface": 271 | generation_function = generate_time_surface_single 272 | elif representation_type == "voxel_grid": 273 | generation_function = generate_voxel_grid_single 274 | elif representation_type == "event_stack": 275 | generation_function = generate_sbt_single 276 | elif representation_type == "event_count": 277 | generation_function = generate_event_count_images_single 278 | else: 279 | raise NotImplementedError(f"No generation function for {representation_type}") 280 | 281 | for split in ["train", "test"]: 282 | split_dir = input_dir / split 283 | n_seqs = len(os.listdir(str(split_dir))) 284 | print(f"Generate representations for {split}") 285 | 286 | for input_seq_dir in tqdm(split_dir.iterdir(), total=n_seqs): 287 | for dt in dts: 288 | generation_function( 289 | input_seq_dir, output_dir, visualize=visualize, n_bins=n_bins, dt=dt 290 | ) 291 | 292 | 293 | if __name__ == "__main__": 294 | fire.Fire(generate) 295 | -------------------------------------------------------------------------------- /data_preparation/synthetic/generate_tracks.py: -------------------------------------------------------------------------------- 1 | """ Script for generating feature tracks from the Multiflow dataset. 
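Typical invocation (a sketch; both paths are placeholders, and the two positional
arguments map onto generate_tracks(dataset_dir, output_dir) via python-fire):
    python data_preparation/synthetic/generate_tracks.py /path/to/multiflow /path/to/output_tracks
Ground-truth tracks are written per sequence to <output_dir>/<split>/<sequence>/tracks/shitomasi_custom_v5.gt.txt.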
""" 2 | import multiprocessing 3 | import os 4 | import shutil 5 | from pathlib import Path 6 | 7 | import cv2 8 | import fire 9 | import h5py 10 | import hdf5plugin 11 | import numpy as np 12 | from matplotlib import pyplot as plt 13 | from tqdm import tqdm 14 | 15 | # MultiFlow Image Dims 16 | IMG_H = 384 17 | IMG_W = 512 18 | 19 | # Corner Parameters (Global for Multiprocessing) 20 | MAX_CORNERS = 30 21 | MIN_DISTANCE = 30 22 | QUALITY_LEVEL = 0.3 23 | BLOCK_SIZE = 25 24 | K = 0.15 25 | USE_HARRIS_DETECTOR = False 26 | OUTPUT_DIR = None 27 | TRACK_NAME = "shitomasi_custom_v5" 28 | 29 | # Filtering 30 | MIN_TRACK_DISPLACEMENT = 5 31 | displacements_all = [] 32 | 33 | 34 | def generate_single_track(seq_dir, dt=0.01): 35 | tracks = [] 36 | dt_us = dt * 1e6 37 | 38 | # Get split 39 | split = seq_dir.parents[0].stem 40 | 41 | # Load reference image 42 | img_t0_p = seq_dir / "images" / "0400000.png" 43 | if img_t0_p.exists(): 44 | img_t0 = cv2.imread( 45 | str(seq_dir / "images" / "0400000.png"), cv2.IMREAD_GRAYSCALE 46 | ) 47 | else: 48 | print(f"Sequence {seq_dir} has no reference image.") 49 | return 50 | 51 | # Detect corners 52 | corners = cv2.goodFeaturesToTrack( 53 | img_t0, 54 | MAX_CORNERS, 55 | QUALITY_LEVEL, 56 | MIN_DISTANCE, 57 | k=K, 58 | useHarrisDetector=USE_HARRIS_DETECTOR, 59 | blockSize=BLOCK_SIZE, 60 | ) 61 | 62 | # Initialize tracks 63 | for i_track in range(corners.shape[0]): 64 | track = np.array([0.4, corners[i_track, 0, 0], corners[i_track, 0, 1]]) 65 | tracks.append(track.reshape((1, 3))) 66 | 67 | # Read flow 68 | for ts_us in np.arange(400000 + dt_us, 900000 + dt_us, dt_us): 69 | flow_path = seq_dir / "flow" / f"0{ts_us:.0f}.h5" 70 | with h5py.File(str(flow_path), "r") as h5f: 71 | flow = np.asarray(h5f["flow"]) 72 | 73 | for i_corner, corner in enumerate(corners): 74 | x_init, y_init = corners[i_corner, 0, 0], corners[i_corner, 0, 1] 75 | new_track_entry = np.array([ts_us * 1e-6, x_init, y_init]) 76 | new_track_entry[1] += flow[int(y_init), int(x_init), 0] 77 | new_track_entry[2] += flow[int(y_init), int(x_init), 1] 78 | new_track_entry = new_track_entry.reshape((1, 3)) 79 | tracks[i_corner] = np.append(tracks[i_corner], new_track_entry, axis=0) 80 | 81 | # Filter tracks by minimum motion and OOB 82 | filtered_tracks = [] 83 | for i_corner in range(len(tracks)): 84 | track = tracks[i_corner][:, 1:] 85 | # Displacement 86 | start_pt = track[0, :] 87 | end_pt = track[-1, :] 88 | displacement = np.linalg.norm(end_pt - start_pt) 89 | displacements_all.append(displacement) 90 | if displacement < MIN_TRACK_DISPLACEMENT: 91 | continue 92 | 93 | # OOB 94 | x_inbounds = np.logical_and(track[:, 0] > 0, track[:, 0] < IMG_W - 1).all() 95 | y_inbounds = np.logical_and(track[:, 1] > 0, track[:, 1] < IMG_H - 1).all() 96 | if not (x_inbounds and y_inbounds): 97 | continue 98 | 99 | filtered_tracks.append(tracks[i_corner]) 100 | 101 | if len(filtered_tracks) == 0: 102 | return 103 | print(f"Remaining tracks after filtering: {len(filtered_tracks)}") 104 | 105 | for track_idx in range(len(filtered_tracks)): 106 | track = filtered_tracks[track_idx] 107 | track_idx_column = track_idx * np.ones((track.shape[0], 1), dtype=track.dtype) 108 | filtered_tracks[track_idx] = np.concatenate([track_idx_column, track], axis=1) 109 | 110 | # Sort row entries 111 | filtered_tracks = np.concatenate(filtered_tracks, axis=0) 112 | sorted_idxs = np.lexsort((filtered_tracks[:, 0], filtered_tracks[:, 1])) 113 | filtered_tracks = filtered_tracks[sorted_idxs] 114 | 115 | # Write tracks to disk 116 | 
tracks_dir = OUTPUT_DIR / split / seq_dir.stem / "tracks" 117 | if not tracks_dir.exists(): 118 | tracks_dir.mkdir() 119 | output_path = tracks_dir / f"{TRACK_NAME}.gt.txt" 120 | np.savetxt(output_path, filtered_tracks) 121 | 122 | 123 | def generate_tracks(dataset_dir, output_dir): 124 | """ 125 | - For both the train and test splits:\n 126 | - For each sequence: 127 | - Detect harris corners at t=0.4 128 | - For each corner, i_corner: 129 | - Read displacement from flow images (0.41 <= t <= 0.9) to obtain the track 130 | - Write track to output_dir//tracks/i_corner.txt 131 | :param dataset_dir: Directory path to multiflow dataset 132 | :param output_dir: Output directory to obtained tracks 133 | """ 134 | global OUTPUT_DIR 135 | 136 | # Input and output pathing 137 | dataset_dir = Path(dataset_dir) 138 | assert dataset_dir.exists(), "Path to Multiflow dataset does not exist." 139 | 140 | OUTPUT_DIR = Path(output_dir) 141 | if not OUTPUT_DIR.exists(): 142 | OUTPUT_DIR.mkdir() 143 | 144 | # Generate tracks 145 | for split in ["test", "train"]: 146 | split_dir = dataset_dir / split 147 | print(f"Generate tracks for {split}") 148 | 149 | global displacements_all 150 | displacements_all = [] 151 | 152 | n_seqs = len(os.listdir(str(split_dir))) 153 | with multiprocessing.Pool(10) as p: 154 | list(tqdm(p.imap(generate_single_track, split_dir.iterdir()), total=n_seqs)) 155 | 156 | 157 | if __name__ == "__main__": 158 | fire.Fire(generate_tracks) 159 | -------------------------------------------------------------------------------- /disp_training/README.md: -------------------------------------------------------------------------------- 1 | # Data-driven Feature Tracking for Event Cameras with and without Frames 2 | 3 |

4 | youtube_video 5 |

6 | 7 | This is the code for the T-PAMI 2025 paper **Data-driven Feature Tracking for Event Cameras with and without Frames** 8 | ([PDF](https://rpg.ifi.uzh.ch/docs/Arxiv24_Messikommer.pdf)) by [Nico Messikommer](https://messikommernico.github.io/), [Carter Fang](https://ctyfang.github.io/), [Mathis Gehrig](https://magehrig.github.io/), [Giovanni Cioffi](https://giovanni-cioffi.netlify.app/), and [Davide Scaramuzza](http://rpg.ifi.uzh.ch/people_scaramuzza.html). 9 | 10 | This subdirectory `disp_training/` represents a separate code base for the disparity estimation method presented in the T-PAMI 2025 paper **Data-driven Feature Tracking for Event Cameras with and without Frames**. 11 | This code is independent of the code base used in the CVPR23 paper "Data-driven Feature Tracking for Event Cameras" located in the parent directory. 12 | 13 | If you use any of this code, please cite the following publication: 14 | 15 | ```bibtex 16 | @Article{Messikommer25tpami, 17 | author = {Nico Messikommer and Carter Fang and Mathias Gehrig and Giovanni Cioffi and Davide Scaramuzza}, 18 | title = {Data-driven Feature Tracking for Event Cameras with and without Frames}, 19 | journal = {{IEEE} Trans. Pattern Anal. Mach. Intell. (T-PAMI)}, 20 | year = {2025}, 21 | } 22 | ``` 23 | 24 | ## Abstract 25 | 26 | Because of their high temporal resolution, increased resilience to motion blur, and very sparse output, event cameras have 27 | been shown to be ideal for low-latency and low-bandwidth feature tracking, even in challenging scenarios. Existing feature tracking 28 | methods for event cameras are either handcrafted or derived from first principles but require extensive parameter tuning, are sensitive 29 | to noise, and do not generalize to different scenarios due to unmodeled effects. To tackle these deficiencies, we introduce the first 30 | data-driven feature tracker for event cameras, which leverages low-latency events to track features detected in an intensity frame. We 31 | achieve robust performance via a novel frame attention module, which shares information across feature tracks. Our tracker is 32 | designed to operate in two distinct configurations: solely with events or in a hybrid mode incorporating both events and frames. The 33 | hybrid model offers two setups: an aligned configuration where the event and frame cameras share the same viewpoint, and a hybrid 34 | stereo configuration where the event camera and the standard camera are positioned side-by-side. This side-by-side arrangement is 35 | particularly valuable as it provides depth information for each feature track, enhancing its utility in applications such as visual odometry 36 | and simultaneous localization and mapping. 37 | 38 | --- 39 | 40 | ## Content 41 | 42 | This document describes the usage and installation for the sparse disparity method.
43 | 44 | 1. [Installation](#Installation)
45 | 2. [Pretrained Weights](#Pretrained-Weights)
46 | 3. [Preparing Disparity Data](#Preparing-Disparity-Data)
47 | 4. [Training](#Training)
48 | 5. [Evaluation](#Evaluation)
49 | 50 | --- 51 | 52 | ## Installation 53 | 54 | The following code uses Python 3.9.7.
55 | All the files referenced below are located in the directory `disp_training/` to avoid compatibility problems with the event tracker published in the CVPR23 paper "Data-driven Feature Tracking for Event Cameras".
56 | 57 | 1. If desired, a conda environment can be created using the following command: 58 | 59 | ```bash 60 | conda create -n <env_name> 61 | ``` 62 | 63 | 2. Install the dependencies via the `requirements.txt` file.
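For example, assuming the current directory is the repository root, a minimal install step could be:

```bash
cd disp_training
pip install -r requirements.txt
```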
64 | 65 | Dependencies for training: 66 |
67 | - PyTorch
68 | - Torch Lightning
69 | - Hydra
70 | 
71 | 72 | Dependencies for pre-processing: 73 |
74 | - numpy
75 | - OpenCV
76 | - H5Py and HDF5Plugin
77 | 
78 | 79 | Dependencies for visualization: 80 |
81 | - matplotlib
82 | - seaborn
83 | - imageio
84 | 
85 | --- 86 | 87 | ## Pretrained Weights 88 | 89 | We provide the [network weights](https://download.ifi.uzh.ch/rpg/CVPR23_deep_ev_tracker/disp_pretrained_weights.ckpt) trained on the M3ED dataset 90 | 91 | 92 | --- 93 | 94 | ## Preparing Disparity Data 95 | 96 | ### Download M3ED Dataset 97 | 98 | Download link [M3ED](https://m3ed.io/download/) 99 | 100 | To generate the ground truth disparity and feature tracks for the M3ED dataset, we need the `_data.h5` and `_depth_gt.h5` files. 101 | 102 | If you use the M3ED dataset in an academic context, please cite: 103 | 104 | ```bibtex 105 | @InProceedings{Chaney_2023_CVPR, 106 | author = {Chaney, Kenneth and Cladera, Fernando and Wang, Ziyun and Bisulco, Anthony and Hsieh, M. Ani and Korpela, Christopher and Kumar, Vijay and Taylor, Camillo J. and Daniilidis, Kostas}, 107 | title = {M3ED: Multi-Robot, Multi-Sensor, Multi-Environment Event Dataset}, 108 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops}, 109 | month = {June}, 110 | year = {2023}, 111 | pages = {4015-4022} 112 | } 113 | ``` 114 | 115 | ### Pre-Processing Instructions 116 | 117 | The generation of the ground truth disparity, as well as the feature track, is done in the `data_preparation/real/prepare_m3ed.py` script. 118 | Before running the preprocessing script, it is necessary to set the paths to the M3ED dataset and output directory inside the script as global values. 119 | The script can be run with the following command: 120 | 121 | ```bash 122 | python -m disp_training.disp_data_preparation.prepare_m3ed_data 123 | ``` 124 | 125 | 126 | The resulting directory structure should look like:
127 | 128 | ``` 129 | / 130 | ├─ / 131 | │ ├─ rectification_calibration.yaml 132 | │ ├─ rectification_maps.h5 133 | │ ├─ rectified_data.h5 134 | │ ├─ rectified_tracks.h5 135 | ``` 136 | 137 | --- 138 | 139 | ## Training 140 | 141 | The training config is located at `disp_configs/m3ed_train.yaml`.
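A minimal sketch of the fields discussed below, as they appear in `disp_configs/m3ed_train.yaml` (the log root is a placeholder that must be replaced with your own output path):

```yaml
hydra:
  run:
    # Placeholder log root: prepend your own output directory
    dir: <log_root>/${data.name}/${model.name}/${experiment}/${now:%Y-%m-%d_%H%M%S}

# Augment the indices map concatenated to the image input (no image-space augmentations)
augment: True
```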
142 | 143 | To configure the training, the following parameters are important: 144 | 145 |
146 | - `dir` - Log directory storing the training metrics and network weights. The directory will be created at the beginning of the training.
147 | - `augment` - Whether to augment the indices map concatenated to the image input. 148 | No augmentations are applied in the image space.
149 | 
150 | 151 | Additionally, inside the `disp_configs/data/m3ed.yaml`, the path to the preprocessed M3ED dataset `data_dir` must be set. 152 | 153 | With everything configured, we can begin training by running 154 | 155 | ```bash 156 | CUDA_VISIBLE_DEVICES=<gpu_id> python train.py 157 | ``` 158 | 159 | Hydra will then instantiate the dataloader and model. 160 | PyTorch Lightning will handle the training and validation loops. 161 | All outputs (checkpoints, etc.) will be written to the log directory.
162 | 163 | The disp_correlation_unscaled model inherits from `disp_model/disp_template.py`, which contains the core logic for training and validation. 164 | 165 | To inspect models during training, we can launch an instance of tensorboard for the log directory: 166 | `tensorboard --logdir <log_dir>`. 167 | 168 | --- 169 | 170 | ## Evaluation 171 | 172 | To test the disparity method on the M3ED dataset, we first need to run the evaluation script `evaluate.py`. 173 | In a second step, we can run the benchmark script `benchmark.py` to compute the performance metrics. 174 | 175 | ### Inference 176 | 177 | To run the evaluation script, the following parameters should be specified in the `disp_configs/m3ed_test.yaml` file. 178 | 179 |
180 | - `dir` - Log directory where the evaluation files, such as the predicted disparity (`results.npz`) and the ground truth disparity (`ground_truth.npz`), will be stored.
181 | - `checkpoint_path` - Path to the trained network weights.
182 | 
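With these values set, inference can be launched in the same way as training; for example (the GPU id is an assumption, and the script picks up `disp_configs/m3ed_test.yaml` via Hydra):

```bash
cd disp_training
CUDA_VISIBLE_DEVICES=0 python evaluate.py
```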
183 | 184 | 185 | ### Benchmark Metrics 186 | 187 | After obtaining the predicted disparities, we can compute the performance using `disp_scripts/benchmark.py`. 188 | This script loads the predicted and the ground truth disparities to compute the performance metrics. 189 | Inside the `disp_scripts/benchmark.py`, the path to the `results.npz` and the `ground_truth.npz` (both output of evaluate.py), and the output directory need to be specified. 190 | It is important to include the paths to the `results.npz` and the `ground_truth.npz` files that are generated by the same `evaluate.py` run since there might be some order mismatch otherwise. 191 | 192 | The results are printed to the console and written to a CSV in the output directory. 193 | -------------------------------------------------------------------------------- /disp_training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uzh-rpg/deep_ev_tracker/fd09bcea0be9870905f4883466e2d16d8b45b84b/disp_training/__init__.py -------------------------------------------------------------------------------- /disp_training/disp_configs/data/m3ed.yaml: -------------------------------------------------------------------------------- 1 | name: m3ed 2 | 3 | _target_: disp_dataloader.m3ed_loader.M3EDDataModule 4 | data_dir: 5 | 6 | num_workers: 8 7 | 8 | # For tracks 9 | batch_size: 4 10 | 11 | n_train: 3000000 # Max number 12 | n_val: 10000 # Max number 13 | 14 | augment: True -------------------------------------------------------------------------------- /disp_training/disp_configs/m3ed_test.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: /${data.name}/${now:%Y-%m-%d_%H%M%S} 4 | 5 | # Composing nested config with default 6 | experiment: disparity_method 7 | 8 | patch_size: 63 9 | min_track_length: 20 10 | tracks_per_sample: 16 11 | disp_patch_range: 122 12 | 13 | checkpoint_path: 14 | 15 | defaults: 16 | - data: m3ed 17 | - model: correlation3_unscaled_disp 18 | -------------------------------------------------------------------------------- /disp_training/disp_configs/m3ed_train.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: /${data.name}/${model.name}/${experiment}/${now:%Y-%m-%d_%H%M%S} 4 | 5 | # Composing nested config with default 6 | experiment: disparity_evaluation 7 | 8 | patch_size: 63 9 | min_track_length: 20 10 | min_tracks_per_sample: 4 11 | max_tracks_per_sample: 12 12 | disp_patch_range: 122 13 | augment: True 14 | 15 | debug: False 16 | n_vis: 2 17 | logging: True 18 | 19 | # Train on M3ED 20 | checkpoint_path: none 21 | 22 | defaults: 23 | - data: m3ed 24 | - model: correlation3_unscaled_disp 25 | 26 | # Pytorch lightning trainer's argument 27 | trainer: 28 | benchmark: True 29 | log_every_n_steps: 10 30 | max_epochs: 40000 31 | num_sanity_val_steps: 10 -------------------------------------------------------------------------------- /disp_training/disp_configs/model/correlation3_unscaled_disp.yaml: -------------------------------------------------------------------------------- 1 | name: correlation3_unscaled_disp 2 | 3 | _target_: disp_model.correlation3_unscaled_disp.TrackerNetC 4 | feature_dim: 384 5 | defaults: 6 | - /optim/adam.yaml@optimizer -------------------------------------------------------------------------------- /disp_training/disp_configs/optim/adam.yaml: 
-------------------------------------------------------------------------------- 1 | _target_: torch.optim.Adam 2 | lr: 1e-4 -------------------------------------------------------------------------------- /disp_training/disp_data_preparation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uzh-rpg/deep_ev_tracker/fd09bcea0be9870905f4883466e2d16d8b45b84b/disp_training/disp_data_preparation/__init__.py -------------------------------------------------------------------------------- /disp_training/disp_data_preparation/feature_filter_depth.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import scipy.ndimage 4 | import matplotlib 5 | 6 | from disp_training.disp_data_preparation.track_storage import TrackStorage 7 | 8 | 9 | class FeatureFilter: 10 | def __init__(self, reprojection_thresh): 11 | self.forward_counter = 0 12 | self.backward_counter = 0 13 | self.max_corners = 1000 14 | self.quality_level = 0.001 # Lower means more points are detected 15 | self.minimum_distance = 31 16 | self.block_size = None 17 | self.k = 0.01 18 | self.minEigThreshold = 1e-7 # Lower threshold means less points are tracked 19 | 20 | self.track_storage = TrackStorage(reprojection_thresh) 21 | self.prev_img = None 22 | 23 | def forward_step(self, depth_image, image, T_W_C, K, frame_id): 24 | image = cv2.GaussianBlur(image, (3, 3), 0) 25 | if image.ndim == 3: 26 | gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 27 | else: 28 | gray_image = image 29 | # Visual Tracking 30 | if self.forward_counter == 0: 31 | # High Blurring since event camera has lower resolution and more noise 32 | detect_image = cv2.medianBlur(gray_image.copy(), 15) 33 | new_points = cv2.goodFeaturesToTrack(detect_image, 34 | self.max_corners, self.quality_level, self.minimum_distance, 35 | k=self.k, 36 | useHarrisDetector=True, 37 | blockSize=self.block_size).squeeze(1) 38 | else: 39 | # Tracking 40 | nextPoints = self.track_storage.tracked_points.astype(np.float32) 41 | if T_W_C is not None: 42 | projection_mask, projected_points = self.track_storage.get_projected_points_with_storage_idx( 43 | self.track_storage.tracked_point_to_storage_id, T_W_C, K) 44 | nextPoints[projection_mask, :] = projected_points 45 | 46 | tracked_points, status, err = cv2.calcOpticalFlowPyrLK(prevImg=self.prev_img, 47 | nextImg=gray_image, 48 | prevPts=self.track_storage.tracked_points.astype(np.float32), 49 | nextPts=nextPoints if nextPoints.shape[0] > 0 else None, 50 | flags=cv2.OPTFLOW_USE_INITIAL_FLOW, 51 | minEigThreshold=self.minEigThreshold, 52 | winSize=(25, 25), maxLevel=5) 53 | 54 | # Remove points tracked outside image 55 | if tracked_points is not None: 56 | _, in_frame_mask = self.get_image_points(tracked_points, gray_image.shape[0], gray_image.shape[1]) 57 | status *= in_frame_mask[:, None].astype(np.uint8) 58 | self.track_storage.update_tracked_points(tracked_points, status, depth_image, T_W_C, K, frame_id) 59 | 60 | # Detection 61 | # High Blurring since event camera has lower resolution and more noise 62 | detect_image = cv2.medianBlur(gray_image.copy(), 15) 63 | new_points = cv2.goodFeaturesToTrack(detect_image, 64 | self.max_corners, self.quality_level, self.minimum_distance, 65 | k=self.k, 66 | useHarrisDetector=True, 67 | blockSize=self.block_size) 68 | 69 | if new_points is not None: 70 | new_points = new_points.squeeze(1) 71 | 72 | if tracked_points is not None: 73 | tracked_points, _ = 
self.get_image_points(tracked_points[status.astype(bool).squeeze(1), :], 74 | image.shape[0], image.shape[1]) 75 | 76 | track_mask = self.create_nms_mask(tracked_points, image.shape[0], image.shape[1]) 77 | 78 | new_points_int = self.get_image_points(new_points, image.shape[0], image.shape[1])[0].astype(int) 79 | untracked_mask = track_mask[new_points_int[:, 1], new_points_int[:, 0]] 80 | new_points = new_points[untracked_mask, :] 81 | 82 | if new_points is not None and new_points.shape[0] > 0: 83 | self.track_storage.add_new_points(new_points, depth_image, T_W_C, K, frame_id) 84 | 85 | self.prev_img = gray_image 86 | self.forward_counter += 1 87 | 88 | viz_dict = None 89 | return viz_dict 90 | 91 | def backward_step(self, depth_image, image, T_W_C, K, frame_id): 92 | image = cv2.GaussianBlur(image, (3, 3), 0) 93 | if image.ndim == 3: 94 | gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 95 | else: 96 | gray_image = image 97 | 98 | if self.backward_counter == 0: 99 | self.prev_img = gray_image 100 | self.backward_counter += 1 101 | self.prev_frame_id = frame_id 102 | viz_dict = None 103 | return viz_dict 104 | 105 | # Image points tracked in the forward loop will be overwritten 106 | prev_points, prev_point_to_storage_id = self.track_storage.get_tracked_points_reverse(self.prev_frame_id) 107 | nextPoints = prev_points.copy() 108 | if T_W_C is not None: 109 | projection_mask, projected_points = self.track_storage.get_projected_points_with_storage_idx( 110 | prev_point_to_storage_id, T_W_C, K) 111 | nextPoints[projection_mask, :] = projected_points 112 | tracked_points, status, err = cv2.calcOpticalFlowPyrLK(prevImg=self.prev_img, 113 | nextImg=gray_image, 114 | prevPts=prev_points, 115 | nextPts=nextPoints, 116 | flags=cv2.OPTFLOW_USE_INITIAL_FLOW, 117 | minEigThreshold=self.minEigThreshold, 118 | winSize=(25, 25), maxLevel=5) 119 | _, in_frame_mask = self.get_image_points(tracked_points, gray_image.shape[0], gray_image.shape[1]) 120 | status *= in_frame_mask[:, None].astype(np.uint8) 121 | 122 | self.track_storage.tracked_point_to_storage_id = prev_point_to_storage_id 123 | self.track_storage.update_tracked_points(tracked_points, status, depth_image, T_W_C, K, frame_id) 124 | 125 | self.prev_img = gray_image 126 | self.backward_counter += 1 127 | self.prev_frame_id = frame_id 128 | 129 | # Visualization 130 | viz_dict = None 131 | 132 | return viz_dict 133 | 134 | def visualization_step(self, depth_image, image, T_W_C, K, frame_id): 135 | viz_image = image.copy() 136 | viz_dict = {} 137 | 138 | if T_W_C is not None: 139 | # Projected Points 140 | viz_projected_points, proj_points_to_storage_id = self.track_storage.get_projected_points_with_frame_id(frame_id, T_W_C, K) 141 | if viz_projected_points is not None: 142 | viz_projected_points, in_frame_mask = self.get_image_points(viz_projected_points, image.shape[0], image.shape[1]) 143 | proj_points_to_storage_id = proj_points_to_storage_id[in_frame_mask] 144 | viz_dict['projected_features'] = self.draw_corners(viz_projected_points, viz_image, 145 | proj_points_to_storage_id) 146 | 147 | # Tracked Points 148 | viz_tracked_points, tracked_points_to_storage_id = self.track_storage.get_points_with_frame_id(frame_id) 149 | if viz_projected_points is not None: 150 | viz_tracked_points, in_frame_mask = self.get_image_points(viz_tracked_points, image.shape[0], image.shape[1]) 151 | tracked_points_to_storage_id = tracked_points_to_storage_id[in_frame_mask] 152 | viz_dict['tracked_features'] = self.draw_corners(viz_tracked_points, viz_image, 
tracked_points_to_storage_id) 153 | 154 | if viz_projected_points is not None and viz_tracked_points is not None: 155 | viz_dict['combined_features'] = self.draw_corners(np.concatenate([viz_projected_points, viz_tracked_points], axis=0), 156 | viz_image, 157 | np.concatenate([proj_points_to_storage_id, 158 | tracked_points_to_storage_id], axis=0)) 159 | 160 | return viz_dict 161 | 162 | def get_image_points(self, points, height, width): 163 | in_frame_mask = np.logical_and(points[:, 0] < width, 164 | points[:, 0] > 0) 165 | in_frame_mask = np.logical_and(in_frame_mask, points[:, 1] < height) 166 | in_frame_mask = np.logical_and(in_frame_mask, points[:, 1] > 0) 167 | 168 | return points[in_frame_mask, :], in_frame_mask 169 | 170 | def create_nms_mask(self, tracked_points, img_h, img_w): 171 | track_mask = np.zeros([img_h, img_w]) 172 | track_mask[tracked_points[:, 1].astype(int), 173 | tracked_points[:, 0].astype(int)] = 1 174 | pooled_track_mask = scipy.ndimage.maximum_filter(track_mask, 175 | (self.minimum_distance * 2, self.minimum_distance * 2)) 176 | pooled_track_mask[:self.minimum_distance, :] = 1 177 | pooled_track_mask[-self.minimum_distance:, :] = 1 178 | pooled_track_mask[:, :self.minimum_distance] = 1 179 | pooled_track_mask[:, -self.minimum_distance:] = 1 180 | track_mask = (1 - pooled_track_mask).astype(bool) 181 | 182 | return track_mask 183 | 184 | def draw_corners(self, corners, image, corner_track_id=None): 185 | if corner_track_id is not None: 186 | cmap = matplotlib.cm.get_cmap('hsv') 187 | colors = [] 188 | for track_id in corner_track_id: 189 | track_id = track_id % 256 190 | color = cmap(track_id) 191 | colors.append((color[0] * 255, color[1] * 255, color[2] * 255)) 192 | else: 193 | colors = [(0, 255, 0) for _ in range(len(corners))] 194 | if image.shape[2] == 1: 195 | image = np.tile(image, [1, 1, 3]) 196 | 197 | viz_img = image.copy() 198 | corners = corners.astype('int') 199 | for i_point, point in enumerate(corners): 200 | cv2.circle(viz_img, tuple(point), radius=3, color=colors[i_point], thickness=-1) 201 | 202 | return viz_img 203 | -------------------------------------------------------------------------------- /disp_training/disp_data_preparation/track_storage.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from disp_training.disp_data_preparation.utils import project_points, reproject_points 4 | 5 | 6 | class TrackStorage: 7 | def __init__(self, reprojection_thresh): 8 | self.track_storage = [] 9 | self.track_counter = 0 10 | self.tracked_points = np.zeros([0, 2], dtype=np.float32) 11 | self.tracked_point_to_storage_id = np.zeros([0], dtype='int') 12 | self.reprojection_thresh = reprojection_thresh 13 | 14 | def add_new_points(self, new_points, depth_image, T_W_C, K, frame_id): 15 | if T_W_C is not None: 16 | P_W_array, points_with_depth_mask = reproject_points(new_points[:, 0], new_points[:, 1], 17 | depth_image, np.linalg.inv(T_W_C), K) 18 | else: 19 | points_with_depth_mask = np.zeros([new_points.shape[0]], dtype='bool') 20 | 21 | new_points_to_track_id = [] 22 | P_W_counter = 0 23 | for i_point in range(new_points.shape[0]): 24 | if points_with_depth_mask[i_point]: 25 | P_W_list = [P_W_array[P_W_counter]] 26 | P_W_counter += 1 27 | else: 28 | P_W_list = [] 29 | track_dict = { 30 | 'P_W': P_W_list, 31 | 'track': {int(frame_id): new_points[i_point, :]} 32 | } 33 | self.track_storage.append(track_dict) 34 | new_points_to_track_id.append(self.track_counter) 35 | self.track_counter += 1 36 | 
37 | self.tracked_points = np.concatenate([self.tracked_points, new_points], axis=0) 38 | self.tracked_point_to_storage_id = np.concatenate([self.tracked_point_to_storage_id, 39 | np.array(new_points_to_track_id)], axis=0) 40 | 41 | def update_tracked_points(self, tracked_points, status, depth_image, T_W_C, K, frame_id): 42 | success_tracking = status.squeeze(1).astype('bool') 43 | self.tracked_points = tracked_points[success_tracking, :] 44 | self.tracked_point_to_storage_id = self.tracked_point_to_storage_id[success_tracking] 45 | 46 | # Triangulate 3D points for tracked points with depth information and make reprojection error check 47 | if T_W_C is not None: 48 | # Reprojection error check 49 | valid_tracks = np.ones([self.tracked_points.shape[0]], dtype='bool') 50 | valid_P_W = np.zeros([self.tracked_points.shape[0]], dtype='bool') 51 | P_W_tracked_points = np.zeros([self.tracked_points.shape[0], 4, 1]) 52 | for i_point in range(self.tracked_points.shape[0]): 53 | storage_id = self.tracked_point_to_storage_id[i_point] 54 | # Get 3D point corresponding to track 55 | if len(self.track_storage[storage_id]['P_W']) == 0: 56 | continue 57 | 58 | P_W_track = np.stack(self.track_storage[storage_id]['P_W'], axis=0) 59 | P_W_mean = self.get_aggregated_P_W(P_W_track) 60 | P_W_tracked_points[i_point] = P_W_mean 61 | valid_P_W[i_point] = True 62 | 63 | if valid_P_W.sum() != 0: 64 | projected_P_W, in_front_camera_mask = project_points(P_W_tracked_points[valid_P_W, :, 0], 65 | np.linalg.inv(T_W_C), K) 66 | if in_front_camera_mask.sum() != 0: 67 | project_point_idx = valid_P_W.nonzero()[0][in_front_camera_mask] 68 | reprojection_error = np.sqrt(((projected_P_W[:, :2] - self.tracked_points[project_point_idx, :])**2).sum(1)) 69 | valid_tracks[project_point_idx] *= reprojection_error < self.reprojection_thresh 70 | 71 | # Add 3D point if depth information is available 72 | P_W_array, points_with_depth_mask = reproject_points(self.tracked_points[:, 0], self.tracked_points[:, 1], 73 | depth_image, np.linalg.inv(T_W_C), K) 74 | points_with_depth_idx = points_with_depth_mask.nonzero()[0] 75 | for i_point in range(self.tracked_points.shape[0]): 76 | if points_with_depth_mask[i_point] and valid_tracks[i_point]: 77 | storage_id = self.tracked_point_to_storage_id[i_point] 78 | self.track_storage[storage_id]['P_W'].append(P_W_array[points_with_depth_idx == i_point, :, :].squeeze(0)) 79 | 80 | self.tracked_points = self.tracked_points[valid_tracks, :] 81 | self.tracked_point_to_storage_id = self.tracked_point_to_storage_id[valid_tracks] 82 | 83 | # Add 2D point 84 | for i_point in range(self.tracked_points.shape[0]): 85 | storage_id = self.tracked_point_to_storage_id[i_point] 86 | self.track_storage[storage_id]['track'][int(frame_id)] = self.tracked_points[i_point] 87 | 88 | def get_tracked_points_reverse(self, prev_frame_id): 89 | prev_points = [] 90 | prev_point_to_storage_id = [] 91 | for i_track, track_dict in enumerate(self.track_storage): 92 | if prev_frame_id in track_dict['track']: 93 | prev_points.append(track_dict['track'][prev_frame_id]) 94 | prev_point_to_storage_id.append(i_track) 95 | 96 | return np.stack(prev_points), np.array(prev_point_to_storage_id) 97 | 98 | def get_aggregated_P_W(self, P_W_track): 99 | nr_P_W = P_W_track.shape[0] 100 | P_W_distances = ((P_W_track[:, None, :3, 0] - P_W_track[None, :, :3, 0]) ** 2).sum(2) 101 | P_W_distances = P_W_distances.mean(axis=1) 102 | nr_cluster_samples = max((nr_P_W // 3) * 2, 1) 103 | cluster_P_W_idx = P_W_distances.argsort()[:nr_cluster_samples] 104 | 
P_W_mean = np.mean(P_W_track[cluster_P_W_idx], axis=0) 105 | 106 | return P_W_mean 107 | 108 | def get_points_with_frame_id(self, frame_id): 109 | points = [] 110 | points_to_storage_id = [] 111 | for i_track, track_dict in enumerate(self.track_storage): 112 | if len(track_dict['P_W']) == 0: 113 | continue 114 | if frame_id in track_dict['track']: 115 | points.append(track_dict['track'][frame_id]) 116 | points_to_storage_id.append(i_track) 117 | 118 | if len(points) == 0: 119 | return None, None 120 | 121 | return np.stack(points), np.array(points_to_storage_id) 122 | 123 | def get_projected_points_with_frame_id(self, frame_id, T_W_C, K): 124 | points_to_storage_id = [] 125 | P_W = [] 126 | for i_track, track_dict in enumerate(self.track_storage): 127 | if len(track_dict['P_W']) == 0: 128 | continue 129 | if frame_id in track_dict['track']: 130 | P_W_track = np.stack(track_dict['P_W'], axis=0) 131 | P_W.append(self.get_aggregated_P_W(P_W_track)) 132 | points_to_storage_id.append(i_track) 133 | 134 | if len(P_W) == 0: 135 | return None, None 136 | 137 | P_W = np.stack(P_W, axis=0) 138 | projected_points, in_front_camera_mask = project_points(P_W[:, :, 0], np.linalg.inv(T_W_C), K) 139 | 140 | return projected_points[:, :2], np.array(points_to_storage_id)[in_front_camera_mask] 141 | 142 | def get_projected_points_with_storage_idx(self, storage_idxs, T_W_C, K): 143 | P_W = [] 144 | projection_mask = np.zeros([storage_idxs.shape[0]], dtype='bool') 145 | for i_idx, storage_idx in enumerate(storage_idxs): 146 | track_dict = self.track_storage[storage_idx] 147 | if len(track_dict['P_W']) == 0: 148 | continue 149 | P_W_track = np.stack(track_dict['P_W'], axis=0) 150 | P_W.append(self.get_aggregated_P_W(P_W_track)) 151 | projection_mask[i_idx] = True 152 | 153 | if len(P_W) == 0: 154 | return projection_mask, None 155 | 156 | P_W = np.stack(P_W, axis=0) 157 | projected_points, in_front_camera_mask = project_points(P_W[:, :, 0], np.linalg.inv(T_W_C), K) 158 | projection_mask[projection_mask] = in_front_camera_mask 159 | 160 | return projection_mask, projected_points[:, :2] 161 | -------------------------------------------------------------------------------- /disp_training/disp_data_preparation/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import matplotlib 4 | 5 | 6 | def compute_rectify_map(target_group, source_group): 7 | """Adapted from: https://github.com/daniilidis-group/m3ed/blob/main/build_system/semantics/internimage.py""" 8 | target_T_to_prophesee_left = target_group['T_to_prophesee_left'][...] 9 | source_T_to_prophesee_left = source_group['T_to_prophesee_left'][...] 10 | 11 | source_T_target = source_T_to_prophesee_left @ np.linalg.inv( target_T_to_prophesee_left ) 12 | target_dist_coeffs = target_group['distortion_coeffs'][...] 13 | target_intrinsics = target_group['intrinsics'][...] 14 | target_res = target_group['resolution'][...] 15 | target_Size = target_res 16 | 17 | target_K = np.eye(3) 18 | target_K[0,0] = target_intrinsics[0] 19 | target_K[1,1] = target_intrinsics[1] 20 | target_K[0,2] = target_intrinsics[2] 21 | target_K[1,2] = target_intrinsics[3] 22 | 23 | target_P = np.zeros((3,4)) 24 | target_P[:3,:3] = target_K 25 | 26 | source_dist_coeffs = source_group['distortion_coeffs'][...] 27 | source_intrinsics = source_group['intrinsics'][...] 28 | source_res = source_group['resolution'][...] 
29 | source_Size = source_res 30 | 31 | source_K = np.eye(3) 32 | source_K[0,0] = source_intrinsics[0] 33 | source_K[1,1] = source_intrinsics[1] 34 | source_K[0,2] = source_intrinsics[2] 35 | source_K[1,2] = source_intrinsics[3] 36 | 37 | # Image is already undistorted, this only works for the M3ED loading 38 | target_dist_coeffs *= 0 39 | out = cv2.stereoRectify(cameraMatrix1=source_K, 40 | distCoeffs1=source_dist_coeffs, 41 | cameraMatrix2=target_K, 42 | distCoeffs2=target_dist_coeffs, 43 | imageSize=target_Size, 44 | newImageSize=target_Size, 45 | T=source_T_target[:3, 3], 46 | R=source_T_target[:3, :3], 47 | alpha=0, 48 | ) 49 | rot_source, rot_target, proj_source, proj_target, Q, validPixROI1, validPixROI2 = out 50 | map_target = np.stack(cv2.initUndistortRectifyMap(target_K, target_dist_coeffs, rot_target, proj_target, source_Size, cv2.CV_32FC1), axis=-1) 51 | map_source = np.stack(cv2.initUndistortRectifyMap(source_K, source_dist_coeffs, rot_source, proj_source, source_Size, cv2.CV_32FC1), axis=-1) 52 | 53 | inv_map_target = np.stack(cv2.initInverseRectificationMap(target_K, target_dist_coeffs, rot_target, proj_target, target_Size, cv2.CV_32FC1), axis=-1) 54 | inv_map_source = invert_map(map_source) 55 | 56 | return map_target, map_source, inv_map_target, inv_map_source, proj_target, proj_source, rot_target, rot_source, Q 57 | 58 | 59 | def invert_map(F): 60 | # shape is (h, w, 2), a "xymap" 61 | (h, w) = F.shape[:2] 62 | I = np.zeros_like(F) 63 | I[:,:,1], I[:,:,0] = np.indices((h, w)) # identity map 64 | P = np.copy(I) 65 | for i in range(10): 66 | correction = I - cv2.remap(F, P, None, interpolation=cv2.INTER_LINEAR) 67 | P += correction * 0.5 68 | return P 69 | 70 | 71 | def reproject_points(u, v, depth_image, T_C_W, K, z=None): 72 | if z is None: 73 | z = depth_image[v.astype('int'), u.astype('int')] 74 | valid_mask = z != 0 75 | 76 | u, v, z = u[valid_mask], v[valid_mask], z[valid_mask] 77 | p = np.stack([u, v, np.ones([u.shape[0]])], axis=1) 78 | P_C = np.matmul(np.linalg.inv(K)[None, :, :], p[:, :, None]) * z[:, None, None] 79 | P_W = np.matmul(np.linalg.inv(T_C_W), 80 | np.concatenate([P_C, np.ones([P_C.shape[0], 1, 1])], axis=1)) 81 | 82 | return P_W, valid_mask 83 | 84 | 85 | def project_points(P_W, T_C_W, K): 86 | assert T_C_W.ndim == 2 87 | P_C = np.matmul(T_C_W[:3, :], P_W[:, :, None]) 88 | in_front_camera_mask = P_C[:, 2, 0] > 0 89 | P_C = P_C[in_front_camera_mask, :, :] 90 | 91 | points = np.matmul(K[:, :], P_C).squeeze(-1) 92 | points = points / points[:, 2, None] 93 | 94 | return points, in_front_camera_mask 95 | 96 | 97 | def visualize_depth_image(depth_image, rgb_image, file_path): 98 | rgb_image = rgb_image.copy() 99 | cmap = matplotlib.colormaps.get_cmap('gist_ncar') 100 | v, u = np.nonzero(depth_image != 0) 101 | img_depth = depth_image[v, u] 102 | depth_colors = cmap((img_depth - img_depth.min()) / (img_depth.max() - img_depth.min()))[:, :3] 103 | rgb_image[v, u, :] = depth_colors * 255 104 | cv2.imwrite(file_path, rgb_image) 105 | 106 | 107 | def remap_events(events, map, rotate, shape=None, valid_region=None): 108 | mx, my = map 109 | x, y = mx[events['y'], events['x']], my[events['y'], events['x']] 110 | p = np.array(events['p']) 111 | t = np.array(events['t']) 112 | 113 | if rotate: 114 | target_width, target_height = shape 115 | x = target_width - 1 - x 116 | y = target_height - 1 - y 117 | 118 | if valid_region is not None: 119 | mask = ((x >= valid_region[0]) & (x < valid_region[2]) & 120 | (y >= valid_region[1]) & (y < valid_region[3])) 121 | x = x - 
valid_region[0] 122 | y = y - valid_region[1] 123 | else: 124 | target_width, target_height = shape 125 | mask = (x >= 0) & (x <= target_width - 1) & (y >= 0) & (y <= target_height - 1) 126 | 127 | return {'x': x[mask], 'y': y[mask], 't': t[mask], 'p': p[mask]} 128 | -------------------------------------------------------------------------------- /disp_training/disp_dataloader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uzh-rpg/deep_ev_tracker/fd09bcea0be9870905f4883466e2d16d8b45b84b/disp_training/disp_dataloader/__init__.py -------------------------------------------------------------------------------- /disp_training/disp_model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uzh-rpg/deep_ev_tracker/fd09bcea0be9870905f4883466e2d16d8b45b84b/disp_training/disp_model/__init__.py -------------------------------------------------------------------------------- /disp_training/disp_model/correlation3_unscaled_disp.py: -------------------------------------------------------------------------------- 1 | import torch.nn.init 2 | 3 | from disp_model.disp_template import DistTemplate 4 | from models.common import * 5 | from utils.losses import * 6 | 7 | 8 | class FPNEncoder(nn.Module): 9 | def __init__(self, in_channels=1, out_channels=512, recurrent=False): 10 | super(FPNEncoder, self).__init__() 11 | self.conv_bottom_0 = ConvBlock3(in_channels=in_channels, out_channels=32, n_convs=2, 12 | kernel_size=1, padding=0, downsample=False) 13 | 14 | self.conv_bottom_0_2 = ConvBlock3(in_channels=32, out_channels=32, n_convs=2, 15 | kernel_size=3, padding=1, downsample=True) 16 | 17 | self.conv_bottom_1 = ConvBlock3(in_channels=32, out_channels=64, n_convs=2, 18 | kernel_size=5, padding=0, 19 | downsample=False) 20 | self.conv_bottom_2 = ConvBlock3(in_channels=64, out_channels=128, n_convs=2, 21 | kernel_size=5, padding=0, 22 | downsample=False) 23 | self.conv_bottom_3 = ConvBlock3(in_channels=128, out_channels=256, n_convs=2, 24 | kernel_size=3, padding=0, 25 | downsample=True) 26 | self.conv_bottom_4 = ConvBlock3(in_channels=256, out_channels=out_channels, n_convs=2, 27 | kernel_size=3, padding=0, downsample=False) 28 | 29 | self.recurrent = recurrent 30 | if self.recurrent: 31 | self.conv_rnn = ConvLSTMCell(out_channels, out_channels, 1) 32 | 33 | self.conv_lateral_3 = nn.Conv2d(in_channels=256, out_channels=out_channels, kernel_size=1, bias=True) 34 | self.conv_lateral_2 = nn.Conv2d(in_channels=128, out_channels=out_channels, kernel_size=1, bias=True) 35 | self.conv_lateral_1 = nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=1, bias=True) 36 | self.conv_lateral_0 = nn.Conv2d(in_channels=32, out_channels=out_channels, kernel_size=1, bias=True) 37 | 38 | self.conv_dealias_3 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=True) 39 | self.conv_dealias_2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=True) 40 | self.conv_dealias_1 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=True) 41 | self.conv_dealias_0 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=True) 42 | self.conv_out = nn.Sequential( 43 | ConvBlock3(in_channels=out_channels, out_channels=out_channels, 44 | n_convs=1, kernel_size=3, padding=1, downsample=False), 45 | 
nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1, 46 | bias=True) 47 | ) 48 | 49 | self.conv_bottleneck_out = nn.Sequential( 50 | ConvBlock3(in_channels=out_channels, out_channels=out_channels, 51 | n_convs=1, kernel_size=3, padding=1, downsample=False), 52 | nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1, 53 | bias=True) 54 | ) 55 | 56 | def reset(self): 57 | if self.recurrent: 58 | self.conv_rnn.reset() 59 | 60 | def forward(self, x): 61 | """ 62 | :param x: 63 | :return: (highest res feature map, lowest res feature map) 64 | """ 65 | 66 | # Bottom-up pathway 67 | c0_0 = self.conv_bottom_0(x) # 61x61 Tracker: 31x31 68 | 69 | c0 = self.conv_bottom_0_2(c0_0) # 31x31 70 | 71 | c1 = self.conv_bottom_1(c0) # 23x23 72 | c2 = self.conv_bottom_2(c1) # 15x15 73 | c3 = self.conv_bottom_3(c2) # 5x5 74 | c4 = self.conv_bottom_4(c3) # 1x1 75 | 76 | # Top-down pathway (with lateral cnx and de-aliasing) 77 | p4 = c4 78 | p3 = self.conv_dealias_3(self.conv_lateral_3(c3) + F.interpolate(p4, (c3.shape[2], c3.shape[3]), mode='bilinear')) 79 | p2 = self.conv_dealias_2(self.conv_lateral_2(c2) + F.interpolate(p3, (c2.shape[2], c2.shape[3]), mode='bilinear')) 80 | p1 = self.conv_dealias_1(self.conv_lateral_1(c1) + F.interpolate(p2, (c1.shape[2], c1.shape[3]), mode='bilinear')) 81 | p0 = self.conv_dealias_0(self.conv_lateral_0(c0) + F.interpolate(p1, (c0.shape[2], c0.shape[3]), mode='bilinear')) 82 | 83 | if self.recurrent: 84 | p0 = self.conv_rnn(p0) 85 | 86 | return self.conv_out(p0), self.conv_bottleneck_out(c4) 87 | 88 | 89 | class JointEncoder(nn.Module): 90 | def __init__(self, in_channels, out_channels): 91 | super(JointEncoder, self).__init__() 92 | 93 | self.conv1 = ConvBlock3(in_channels=in_channels, out_channels=64, n_convs=2, downsample=True) 94 | self.conv2 = ConvBlock3(in_channels=64, out_channels=128, n_convs=2, downsample=True) 95 | self.convlstm0 = ConvLSTMCell(128, 128, 3) 96 | self.conv3 = ConvBlock3(in_channels=128, out_channels=256, n_convs=2, downsample=True) 97 | self.conv4 = ConvBlock3(in_channels=256, out_channels=256, kernel_size=3, padding=0, 98 | n_convs=1, downsample=False) 99 | 100 | # Transformer Addition 101 | self.flatten = nn.Flatten() 102 | embed_dim = 256 103 | num_heads = 8 104 | self.multihead_attention0 = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True) 105 | 106 | self.prev_x_res = None 107 | self.gates = nn.Linear(2*embed_dim, embed_dim) 108 | self.ls_layer = LayerScale(embed_dim) 109 | 110 | # Attention Mask Transformer 111 | self.fusion_layer0 = nn.Sequential(nn.Linear(embed_dim*2, embed_dim), 112 | nn.LeakyReLU(0.1), 113 | nn.Linear(embed_dim, embed_dim), 114 | nn.LeakyReLU(0.1) 115 | ) 116 | self.output_layers = nn.Sequential(nn.Linear(embed_dim, 512), 117 | nn.LeakyReLU(0.1) 118 | ) 119 | 120 | def reset(self): 121 | self.convlstm0.reset() 122 | self.prev_x_res = None 123 | 124 | def forward(self, x, attn_mask=None): 125 | x = self.conv1(x) 126 | x = self.conv2(x) 127 | x = self.convlstm0(x) 128 | x = self.conv3(x) 129 | x = self.conv4(x) 130 | x = self.flatten(x) 131 | 132 | if self.prev_x_res is None: 133 | self.prev_x_res = Variable(torch.zeros_like(x)) 134 | 135 | x = self.fusion_layer0(torch.cat((x, self.prev_x_res), 1)) 136 | 137 | x_attn = x[None, :, :].detach() 138 | x_attn = self.multihead_attention0(query=x_attn, 139 | key=x_attn, 140 | value=x_attn, 141 | attn_mask=attn_mask.bool())[0].squeeze(0) 142 | x = x + self.ls_layer(x_attn) 143 | 144 | gate_weight = 
torch.sigmoid(self.gates(torch.cat((self.prev_x_res, x), 1))) 145 | x = self.prev_x_res * gate_weight + x * (1 - gate_weight) 146 | 147 | self.prev_x_res = x 148 | 149 | x = self.output_layers(x) 150 | 151 | return x 152 | 153 | 154 | class LayerScale(nn.Module): 155 | def __init__(self, dim, init_values=1e-5, inplace=False): 156 | super().__init__() 157 | self.inplace = inplace 158 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 159 | 160 | def forward(self, x): 161 | gamma = self.gamma 162 | return x.mul_(gamma) if self.inplace else x * gamma 163 | 164 | 165 | class TrackerNetC(DistTemplate): 166 | def __init__(self, min_track_length, disp_patch_range, n_vis=8, patch_size=31, feature_dim=1024, **kwargs): 167 | super(TrackerNetC, self).__init__(min_track_length=min_track_length, n_vis=n_vis, patch_size=patch_size, 168 | **kwargs) 169 | # Configuration 170 | self.grayscale_ref = True 171 | self.channels_in_per_patch = 10 172 | 173 | # Architecture 174 | self.disp_patch_range = disp_patch_range 175 | self.feature_dim = feature_dim 176 | self.redir_dim = 128 177 | 178 | self.reference_encoder = FPNEncoder(3, self.feature_dim) 179 | 180 | self.target_encoder = FPNEncoder(self.channels_in_per_patch, self.feature_dim) 181 | 182 | # Correlation3 had k=1, p=0 183 | self.reference_redir = nn.Conv2d(self.feature_dim, self.redir_dim, kernel_size=3, padding=1) 184 | self.target_redir = nn.Conv2d(self.feature_dim, self.redir_dim, kernel_size=3, padding=1) 185 | self.softmax = nn.Softmax(dim=2) 186 | 187 | self.joint_encoder = JointEncoder(in_channels=1+2*self.redir_dim, out_channels=512) 188 | 189 | # Disp Adjustment 190 | self.target_vertical_reshape = nn.Conv2d(self.redir_dim + 1, self.redir_dim + 1, 191 | stride=(2, 1), kernel_size=3, padding=1) 192 | self.predictor = nn.Linear(in_features=512, out_features=2, bias=False) 193 | 194 | # Operational 195 | self.loss = nn.L1Loss(reduction='none') 196 | 197 | self.correlation_maps = [] 198 | self.inputs = [] 199 | self.refs = [] 200 | 201 | def init_weights(self): 202 | torch.nn.init.xavier_uniform(self.fc_out.weight) 203 | 204 | def reset(self, _): 205 | self.joint_encoder.reset() 206 | 207 | def forward(self, frame_patches, event_patches, attn_mask=None): 208 | # Feature Extraction 209 | e_f0, _ = self.target_encoder(event_patches) 210 | f_f0, f_bottleneck = self.reference_encoder(frame_patches) 211 | f_f0 = self.reference_redir(f_f0) 212 | 213 | # Correlation and softmax 214 | f_corr = (e_f0 * f_bottleneck).sum(dim=1, keepdim=True) 215 | f_corr = self.softmax(f_corr.view(-1, 1, self.disp_patch_range // 2 * 31)).view(-1, 1, self.disp_patch_range // 2, 31) 216 | 217 | # Feature re-direction 218 | e_f0 = self.target_redir(e_f0) 219 | 220 | # Disp Adjustment 221 | f_combined_corr_e_f0 = self.target_vertical_reshape(torch.cat([e_f0, f_corr], dim=1)) 222 | 223 | f = torch.cat([f_combined_corr_e_f0, f_f0], dim=1) 224 | f = self.joint_encoder(f, attn_mask) 225 | 226 | f = self.predictor(f) 227 | 228 | return f 229 | -------------------------------------------------------------------------------- /disp_training/disp_model/disp_template.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import torch.optim.lr_scheduler 3 | from pytorch_lightning import LightningModule 4 | import numpy as np 5 | 6 | import matplotlib 7 | matplotlib.use('Agg') 8 | import matplotlib.pyplot as plt 9 | from matplotlib import cm 10 | 11 | from utils.losses import * 12 | 13 | 14 | class DistTemplate(LightningModule): 15 
| def __init__(self, min_track_length, n_vis=8, patch_size=31, debug=True, **kwargs): 16 | super(DistTemplate, self).__init__() 17 | self.save_hyperparameters() 18 | 19 | # High level model config 20 | self.patch_size = patch_size 21 | self.min_track_length = min_track_length 22 | self.debug = debug 23 | 24 | # Determine num channels from representation name 25 | self.channels_in_per_patch = 1 26 | 27 | # Loss Function 28 | self.loss = None 29 | 30 | # Training variables 31 | self.n_vis = n_vis 32 | self.colormap = cm.get_cmap('inferno') 33 | self.graymap = cm.get_cmap('gray') 34 | 35 | def create_attention_mask(self, seq_frame_idx): 36 | attn_mask = torch.from_numpy(seq_frame_idx[:, None] == seq_frame_idx[None, :]).to(self.device) 37 | attn_mask = torch.logical_not(attn_mask).bool() 38 | 39 | return attn_mask 40 | 41 | def configure_optimizers(self): 42 | if not self.debug: 43 | opt = hydra.utils.instantiate(self.hparams.optimizer, params=self.parameters()) 44 | return {'optimizer': opt, 45 | 'lr_scheduler': { 46 | "scheduler": torch.optim.lr_scheduler.OneCycleLR(opt, self.hparams.optimizer.lr, 47 | total_steps=1000000, 48 | pct_start=0.002), 49 | "interval": "step", 50 | "frequency": 1, 51 | "strict": True, 52 | "name": "lr"} 53 | } 54 | else: 55 | return hydra.utils.instantiate(self.hparams.optimizer, params=self.parameters()) 56 | 57 | def forward(self, frame_patches, event_patches, attn_mask=None): 58 | return None 59 | 60 | def on_train_epoch_end(self, *args): 61 | return 62 | 63 | def training_step(self, batch_sample, batch_nb): 64 | # Get data 65 | seq_frame_patches, seq_event_patches, seq_y_disps, seq_frame_idx = batch_sample 66 | n_samples = seq_frame_patches.shape[0] 67 | 68 | # Create attention mask for frame attention module 69 | attn_mask = self.create_attention_mask(seq_frame_idx) 70 | 71 | seq_frame_patches = torch.from_numpy(seq_frame_patches).permute([0, 1, 4, 2, 3]).to(self.device) 72 | seq_event_patches = torch.from_numpy(seq_event_patches).permute([0, 1, 4, 2, 3]).to(self.device) 73 | seq_y_disps = torch.from_numpy(seq_y_disps).to(self.device) 74 | 75 | # Unroll network 76 | loss_total = torch.zeros(n_samples, dtype=torch.float32, device=self.device) 77 | self.reset(n_samples) 78 | 79 | for i_unroll in range(self.min_track_length): 80 | # Get data 81 | frame_patches = seq_frame_patches[:, i_unroll, :, :, :] 82 | event_patches = seq_event_patches[:, i_unroll, :, :, :] 83 | y_disps = seq_y_disps[:, i_unroll] 84 | 85 | # Inference 86 | y_disp_pred = self.forward(frame_patches, event_patches, attn_mask) 87 | 88 | # Accumulate losses 89 | loss = self.loss(y_disps, y_disp_pred[:, 1]) 90 | loss_total += loss 91 | 92 | loss_total = loss_total.mean() / self.min_track_length 93 | 94 | self.log("loss/train", loss_total, on_step=True, on_epoch=True, prog_bar=True, batch_size=1) 95 | 96 | return loss_total 97 | 98 | def on_validation_epoch_start(self): 99 | self.metrics = {'disp_error': []} 100 | 101 | def validation_step(self, batch_sample, batch_nb): 102 | # Get data 103 | seq_frame_patches, seq_event_patches, seq_y_disps, seq_frame_idx = batch_sample 104 | n_samples = seq_frame_patches.shape[0] 105 | 106 | seq_frame_patches = torch.from_numpy(seq_frame_patches).permute([0, 1, 4, 2, 3]).to(self.device) 107 | seq_event_patches = torch.from_numpy(seq_event_patches).permute([0, 1, 4, 2, 3]).to(self.device) 108 | seq_y_disps = torch.from_numpy(seq_y_disps).to(self.device) 109 | 110 | # Flow history visualization for first batch 111 | if batch_nb == 0: 112 | x_hat_hist = [] 113 | 
x_ref_hist = [] 114 | 115 | # Unroll network 116 | loss_total = torch.zeros(n_samples, dtype=torch.float32, device=self.device) 117 | self.reset(n_samples) 118 | 119 | # Create attention mask for frame attention module 120 | attn_mask = self.create_attention_mask(seq_frame_idx) 121 | 122 | # Rollout 123 | seq_disp_error = np.zeros([n_samples, self.min_track_length]) 124 | for i_unroll in range(self.min_track_length): 125 | # Construct x and y 126 | # Get data 127 | frame_patches = seq_frame_patches[:, i_unroll, :, :, :] 128 | event_patches = seq_event_patches[:, i_unroll, :, :, :] 129 | y_disps = seq_y_disps[:, i_unroll] 130 | 131 | # Inference 132 | y_disp_pred = self.forward(frame_patches, event_patches, attn_mask) 133 | 134 | loss = self.loss(y_disps, y_disp_pred[:, 1]) 135 | loss_total += loss 136 | 137 | # Patch visualizations for first batch 138 | if batch_nb == 0: 139 | x_hat_hist.append(torch.max(event_patches[0, :, :, :], dim=0, keepdim=True)[0].detach().clone()) 140 | x_ref_hist.append(frame_patches[0, 0, None, :, :].detach().clone()) 141 | 142 | seq_disp_error[:, i_unroll] = torch.abs(y_disps - y_disp_pred[:, 1]).detach().cpu().numpy() 143 | 144 | self.metrics['disp_error'].append(seq_disp_error) 145 | 146 | loss_total = loss_total.mean() / self.min_track_length 147 | # Log loss for both training modes 148 | self.log("loss/val", loss_total.detach(), on_step=False, on_epoch=True, prog_bar=True, batch_size=1) 149 | 150 | # Log predicted patches for both training modes 151 | if batch_nb == 0: 152 | for i_vis in range(self.n_vis): 153 | # Patches 154 | ev_patch = x_hat_hist[i_vis].cpu().squeeze(0).numpy() 155 | ev_patch = torch.from_numpy(self.colormap(ev_patch)[:, :, :3]) 156 | self.logger.experiment.add_image(f'input/event_patch_{i_vis}', 157 | ev_patch, 158 | self.global_step, dataformats='HWC') 159 | 160 | # Reference 161 | img_patch = x_ref_hist[i_vis].cpu().squeeze(0).numpy() 162 | img_patch = torch.from_numpy(self.graymap(img_patch)[:, :, :3]) 163 | self.logger.experiment.add_image(f'input/frame_patch_{i_vis}', 164 | img_patch, 165 | self.global_step, dataformats='HWC') 166 | 167 | return loss_total 168 | 169 | def on_validation_epoch_end(self): 170 | # Disparity error visualization 171 | disp_errors = np.concatenate(self.metrics['disp_error'], axis=0) 172 | 173 | # Cumulative Error Plot 174 | with plt.style.context('ggplot'): 175 | fig = plt.figure() 176 | x, counts = np.unique(disp_errors[:, -1], return_counts=True) 177 | y = np.cumsum(counts) / np.sum(counts) 178 | ax = fig.add_subplot() 179 | ax.plot(x, y) 180 | ax.set_xlabel('EPE (px)') 181 | ax.set_ylabel('Proportion') 182 | self.logger.experiment.add_figure("cumulative_error/val", fig, self.global_step) 183 | plt.close("all") 184 | 185 | # Mean Error Plot 186 | with plt.style.context('ggplot'): 187 | fig = plt.figure() 188 | mean_error = disp_errors.mean(axis=0) 189 | 190 | y = mean_error 191 | x = np.arange(mean_error.shape[0]) 192 | ax = fig.add_subplot() 193 | ax.plot(x, y) 194 | ax.set_xlabel('Timesteps') 195 | ax.set_ylabel('Error') 196 | self.logger.experiment.add_figure("mean_error_seq/val", fig, self.global_step) 197 | plt.close("all") 198 | 199 | self.log("mean_epe/val", np.mean(disp_errors[:, -1])) 200 | self.log("var_epe/val", np.var(disp_errors[:, -1])) 201 | self.log("mean_error/val", np.mean(disp_errors)) 202 | self.log("var_error/val", np.var(disp_errors)) 203 | -------------------------------------------------------------------------------- /disp_training/disp_scripts/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/uzh-rpg/deep_ev_tracker/fd09bcea0be9870905f4883466e2d16d8b45b84b/disp_training/disp_scripts/__init__.py -------------------------------------------------------------------------------- /disp_training/disp_scripts/benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import numpy as np 4 | 5 | RESULTS_DICT = { 6 | "Ours": /results.npz, 7 | } 8 | GT_PATH = /ground_truth.npz 9 | DATASET_PATH = 10 | OUT_DIR = 11 | 12 | 13 | def save_metrics(methods_avg, methods_end_point): 14 | with open(os.path.join(OUT_DIR, "metrics.csv"), 'w') as file: 15 | writer = csv.writer(file) 16 | 17 | writer.writerow(["Overall Average"] + [method_key for method_key in methods_avg.keys()]) 18 | 19 | first_method = list(methods_avg.keys())[0] 20 | for metric_key in methods_avg[first_method].keys(): 21 | row = [metric_key] 22 | for method in methods_avg: 23 | row.append(methods_avg[method][metric_key]) 24 | writer.writerow(row) 25 | 26 | writer.writerow([""]) 27 | 28 | writer.writerow(["End-Point Average"] + [method_key for method_key in methods_avg.keys()]) 29 | for metric_key in methods_end_point[first_method].keys(): 30 | row = [metric_key] 31 | for method in methods_end_point: 32 | row.append(methods_end_point[method][metric_key]) 33 | writer.writerow(row) 34 | 35 | def compute_metrics(disp_gt, disp_pred, depth_gt, depth_pred): 36 | eps = 1e-5 37 | metrics = {} 38 | 39 | # Disparity Metrics 40 | disp_abs_diff = np.abs(disp_gt-disp_pred) 41 | metrics["disp/MAE"] = disp_abs_diff.mean(axis=0) 42 | metrics["disp/RMSE"] = np.sqrt((disp_abs_diff ** 2).mean(axis=0)) 43 | metrics["disp/ratio_pe_1"] = np.mean(disp_abs_diff > 1, axis=0) 44 | metrics["disp/ratio_pe_2"] = np.mean(disp_abs_diff > 2, axis=0) 45 | metrics["disp/ratio_pe_3"] = np.mean(disp_abs_diff > 3, axis=0) 46 | metrics["disp/ratio_pe_4"] = np.mean(disp_abs_diff > 4, axis=0) 47 | metrics["disp/ratio_pe_5"] = np.mean(disp_abs_diff > 5, axis=0) 48 | 49 | # Depth Metrics 50 | depth_abs_diff = np.abs(depth_gt-depth_pred) 51 | metrics["depth/MAE"] = depth_abs_diff.mean(axis=0) 52 | metrics["depth/MAE_rel"] = (depth_abs_diff/(depth_gt+eps)).mean(axis=0) 53 | metrics["depth/RMS"] = np.sqrt((depth_abs_diff ** 2).mean(axis=0)) 54 | ratio = np.max(np.stack([depth_gt / (depth_pred + eps), depth_pred / (depth_gt + eps)]), axis=0) 55 | metrics["depth/ratio_delta_1.25"] = np.mean(ratio <= 1.25, axis=0) 56 | metrics["depth/ratio_delta_1.25^2"] = np.mean(ratio <= 1.25 ** 2, axis=0) 57 | metrics["depth/ratio_delta_1.25^3"] = np.mean(ratio <= 1.25 ** 3, axis=0) 58 | 59 | return metrics 60 | 61 | def process_method(method_name, method_path): 62 | ground_truth = np.load(GT_PATH) 63 | disp_gt = ground_truth["disparity_gt"] 64 | seq_names_gt = ground_truth["seq_names"] 65 | 66 | if method_name == 'Ours': 67 | predictions = np.load(method_path) 68 | disp_pred = predictions["disparity_pred"] 69 | seq_names_pred = predictions["seq_names"] 70 | assert np.all(seq_names_pred == seq_names_gt), "Sequence names do not match" 71 | 72 | else: 73 | raise ValueError("Method not supported") 74 | 75 | 76 | depth_gt = np.zeros_like(disp_gt) 77 | depth_pred = np.zeros_like(disp_pred) 78 | 79 | # Metric Calculations 80 | avg_metrics = compute_metrics(disp_gt.flatten(), disp_pred.flatten(), depth_gt.flatten(), depth_pred.flatten()) 81 | end_point_metrics = compute_metrics(disp_gt[:, -1], disp_pred[:, -1], 
depth_gt[:, -1], depth_pred[:, -1]) 82 | 83 | return avg_metrics, end_point_metrics 84 | 85 | 86 | def main(): 87 | methods_avg, methods_end_point = {}, {} 88 | for method_name, method_path in RESULTS_DICT.items(): 89 | avg_metrics, end_point_metrics = process_method(method_name, method_path) 90 | methods_avg[method_name] = avg_metrics 91 | methods_end_point[method_name] = end_point_metrics 92 | 93 | # Save metrics 94 | save_metrics(methods_avg, methods_end_point) 95 | 96 | # Print 97 | for method_name in methods_avg: 98 | print(f"{method_name}:") 99 | for metric_key in methods_avg[method_name]: 100 | if "disp" in metric_key: 101 | print(f"\t{metric_key}: {methods_avg[method_name][metric_key]}") 102 | 103 | if __name__ == "__main__": 104 | main() 105 | -------------------------------------------------------------------------------- /disp_training/disp_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uzh-rpg/deep_ev_tracker/fd09bcea0be9870905f4883466e2d16d8b45b84b/disp_training/disp_utils/__init__.py -------------------------------------------------------------------------------- /disp_training/disp_utils/disp_utils_torch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def get_patches(representation, u_center, patch_size): 5 | center = np.rint(u_center).astype(int) 6 | h, w = representation.shape[:2] 7 | 8 | patch_uv = np.stack(np.meshgrid(np.arange(patch_size), np.arange(patch_size)), axis=2) 9 | patch_uv = patch_uv - patch_size // 2 10 | grid_coords = center[:, None, None, :] + patch_uv[None, :, :, :] 11 | 12 | assert grid_coords.min() >= 0 13 | assert grid_coords[:, :, :, 0].max() < w 14 | assert grid_coords[:, :, :, 1].max() < h 15 | 16 | if representation.ndim == 2: 17 | patches = representation[grid_coords[:, :, :, 1], grid_coords[:, :, :, 0]] 18 | elif representation.ndim == 3: 19 | patches = representation[grid_coords[:, :, :, 1], grid_coords[:, :, :, 0], :] 20 | 21 | return patches 22 | 23 | 24 | def get_event_patches(representation, u_center, patch_size, disp_patch_range): 25 | center = np.rint(u_center).astype(int) 26 | h, w = representation.shape[:2] 27 | 28 | patch_uv = np.stack(np.meshgrid(np.arange(patch_size), np.arange(-(disp_patch_range-1), 1)), axis=2) 29 | patch_uv[:, :, 0] = patch_uv[:, :, 0] - patch_size // 2 30 | patch_uv[:, :, 1] = patch_uv[:, :, 1] + patch_size // 2 31 | grid_coords = center[:, None, None, :] + patch_uv[None, :, :, :] 32 | 33 | grid_coords[:, :, :, 1] = np.clip(grid_coords[:, :, :, 1], 0, h-1) 34 | 35 | assert grid_coords.min() >= 0 36 | assert grid_coords[:, :, :, 0].max() < w 37 | assert grid_coords[:, :, :, 1].max() < h 38 | 39 | if representation.ndim == 2: 40 | patches = representation[grid_coords[:, :, :, 1], grid_coords[:, :, :, 0]] 41 | elif representation.ndim == 3: 42 | patches = representation[grid_coords[:, :, :, 1], grid_coords[:, :, :, 0], :] 43 | 44 | return patches 45 | 46 | -------------------------------------------------------------------------------- /disp_training/doc/thumbnail.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uzh-rpg/deep_ev_tracker/fd09bcea0be9870905f4883466e2d16d8b45b84b/disp_training/doc/thumbnail.png -------------------------------------------------------------------------------- /disp_training/evaluate.py: 
-------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import os 4 | import hydra 5 | import numpy as np 6 | import pytorch_lightning as pl 7 | import torch 8 | import sys 9 | import tqdm 10 | 11 | sys.path.append('../') 12 | 13 | from utils.utils import * 14 | from disp_dataloader.m3ed_loader import M3EDTestDataModule 15 | 16 | 17 | logger = logging.getLogger(__name__) 18 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 19 | torch.set_num_threads(1) 20 | torch.backends.cudnn.benchmark = True 21 | 22 | 23 | def propagate_keys_disp(cfg): 24 | OmegaConf.set_struct(cfg, True) 25 | 26 | with open_dict(cfg): 27 | cfg.data.patch_size = cfg.patch_size 28 | cfg.data.min_track_length = cfg.min_track_length 29 | cfg.data.tracks_per_sample = cfg.tracks_per_sample 30 | cfg.data.disp_patch_range = cfg.disp_patch_range 31 | 32 | cfg.model.patch_size = cfg.patch_size 33 | cfg.model.min_track_length = cfg.min_track_length 34 | cfg.model.tracks_per_sample = cfg.tracks_per_sample 35 | cfg.model.disp_patch_range = cfg.disp_patch_range 36 | 37 | 38 | def create_attn_mask(seq_frame_idx, device): 39 | attn_mask = torch.from_numpy(seq_frame_idx[:, None] == seq_frame_idx[None, :]).to(device) 40 | attn_mask = torch.logical_not(attn_mask).bool() 41 | 42 | return attn_mask 43 | 44 | 45 | def test_run(model, dataloader, cfg): 46 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 47 | model = model.eval() 48 | model = model.to(device) 49 | 50 | # Create attention mask 51 | n_samples = cfg.data.batch_size * cfg.tracks_per_sample 52 | attn_mask = None 53 | 54 | # Iterate over test dataset 55 | list_disp_pred, list_disp_gt, list_seq_names, list_img_points = [], [], [], [] 56 | for batch_sample in tqdm(dataloader): 57 | seq_frame_patches, seq_event_patches, seq_y_gt_disp_samples, seq_frame_idx, seq_names, img_points = batch_sample 58 | seq_frame_patches = torch.from_numpy(seq_frame_patches).permute([0, 1, 4, 2, 3]).to(device) 59 | seq_event_patches = torch.from_numpy(seq_event_patches).permute([0, 1, 4, 2, 3]).to(device) 60 | n_samples = seq_frame_patches.shape[0] 61 | 62 | assert cfg.min_track_length == seq_frame_patches.shape[1] 63 | y_pred_disp_samples = np.zeros([n_samples, cfg.min_track_length]) 64 | 65 | if attn_mask is None or seq_frame_patches.shape[0] != attn_mask.shape[0]: 66 | attn_mask = create_attn_mask(seq_frame_idx, device) 67 | 68 | model.reset(None) 69 | for i_unroll in range(cfg.min_track_length): 70 | frame_patches = seq_frame_patches[:, i_unroll, :, :, :] 71 | event_patches = seq_event_patches[:, i_unroll, :, :, :] 72 | 73 | # Inference 74 | y_disp_pred = model.forward(frame_patches, event_patches, attn_mask) 75 | y_pred_disp_samples[:, i_unroll] = y_disp_pred[:, 1].detach().cpu().numpy() 76 | 77 | list_disp_pred.append(y_pred_disp_samples) 78 | list_disp_gt.append(seq_y_gt_disp_samples) 79 | list_seq_names.append(seq_names) 80 | list_img_points.append(img_points) 81 | 82 | # Save results 83 | np.savez_compressed('results.npz', 84 | disparity_pred=np.concatenate(list_disp_pred, axis=0), 85 | seq_names=np.concatenate(list_seq_names, axis=0).flatten()) 86 | 87 | np.savez_compressed('ground_truth.npz', 88 | disparity_gt=np.concatenate(list_disp_gt, axis=0), 89 | image_points=np.concatenate(list_img_points, axis=0), 90 | seq_names=np.concatenate(list_seq_names, axis=0).flatten()) 91 | 92 | 93 | @hydra.main(config_path="disp_configs", config_name="m3ed_test") 94 | def test(cfg): 95 | pl.seed_everything(1234) 96 | 97 | # Update 
configuration dicts with common keys 98 | propagate_keys_disp(cfg) 99 | logger.info("\n" + OmegaConf.to_yaml(cfg)) 100 | 101 | with open('test_config.yaml', 'w') as outfile: 102 | OmegaConf.save(cfg, outfile) 103 | 104 | # Instantiate model 105 | model = hydra.utils.instantiate( 106 | cfg.model, 107 | _recursive_=False, 108 | ) 109 | if cfg.checkpoint_path.lower() == 'none': 110 | print("Provide Checkpoints") 111 | 112 | # Load weights 113 | checkpoint = torch.load(cfg.checkpoint_path) 114 | model.load_state_dict(checkpoint['state_dict'], strict=True) 115 | 116 | data_module = M3EDTestDataModule(**cfg.data) 117 | data_module.setup() 118 | dataloader = data_module.test_dataloader() 119 | 120 | with torch.no_grad(): 121 | test_run(model, dataloader, cfg) 122 | 123 | 124 | if __name__ == '__main__': 125 | test() 126 | -------------------------------------------------------------------------------- /disp_training/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | addict==2.4.0 3 | aiohttp==3.9.3 4 | aiosignal==1.3.1 5 | ansi2html==1.8.0 6 | antlr4-python3-runtime==4.9.3 7 | asttokens @ file:///opt/conda/conda-bld/asttokens_1646925590279/work 8 | async-timeout==4.0.3 9 | attrs==23.1.0 10 | backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work 11 | beautifulsoup4==4.12.2 12 | Brotli @ file:///tmp/abs_ecyw11_7ze/croots/recipe/brotli-split_1659616059936/work 13 | certifi==2023.7.22 14 | chardet==5.2.0 15 | charset-normalizer==3.2.0 16 | click @ file:///croot/click_1698129812380/work 17 | cloudpickle @ file:///croot/cloudpickle_1683040006038/work 18 | colorama @ file:///croot/colorama_1672386526460/work 19 | comm==0.1.4 20 | ConfigArgParse==1.7 21 | contourpy @ file:///croot/contourpy_1700583582875/work 22 | cycler @ file:///tmp/build/80754af9/cycler_1637851556182/work 23 | cytoolz @ file:///croot/cytoolz_1701723583781/work 24 | dash==2.13.0 25 | dash-core-components==2.0.0 26 | dash-html-components==2.0.0 27 | dash-table==5.0.0 28 | dask @ file:///croot/dask-core_1701396095060/work 29 | decorator @ file:///opt/conda/conda-bld/decorator_1643638310831/work 30 | einops==0.7.0 31 | executing @ file:///opt/conda/conda-bld/executing_1646925071911/work 32 | fastjsonschema==2.18.0 33 | filelock==3.12.2 34 | Flask==2.2.5 35 | fonttools==4.25.0 36 | frozenlist==1.4.1 37 | fsspec @ file:///croot/fsspec_1701286474621/work 38 | gdown==4.7.1 39 | grpcio==1.60.1 40 | h5py @ file:///croot/h5py_1691589708553/work 41 | hdf5plugin @ file:///home/conda/feedstock_root/build_artifacts/hdf5plugin_1651556479333/work 42 | hydra-core==1.3.2 43 | idna==3.6 44 | imagecodecs @ file:///croot/imagecodecs_1695064943445/work 45 | imageio @ file:///croot/imageio_1707247282708/work 46 | importlib-metadata==6.8.0 47 | importlib-resources @ file:///croot/importlib_resources-suite_1704281845041/work 48 | ipython @ file:///croot/ipython_1674681422581/work 49 | ipywidgets==8.1.0 50 | itsdangerous==2.1.2 51 | jedi @ file:///tmp/build/80754af9/jedi_1644297102865/work 52 | Jinja2 @ file:///croot/jinja2_1666908132255/work 53 | joblib==1.3.2 54 | jsonschema==4.19.0 55 | jsonschema-specifications==2023.7.1 56 | jupyter_core==5.3.1 57 | jupyterlab-widgets==3.0.8 58 | kiwisolver @ file:///croot/kiwisolver_1672387140495/work 59 | lightning==2.2.0.post0 60 | lightning-utilities==0.10.1 61 | locket @ file:///opt/conda/conda-bld/locket_1652903118915/work 62 | Markdown==3.5.2 63 | MarkupSafe @ file:///croot/markupsafe_1704205993651/work 64 | matplotlib @ 
file:///croot/matplotlib-suite_1679593461707/work 65 | matplotlib-inline @ file:///opt/conda/conda-bld/matplotlib-inline_1662014470464/work 66 | mpmath==1.3.0 67 | multidict==6.0.5 68 | munkres==1.1.4 69 | nbformat==5.7.0 70 | nest-asyncio==1.5.7 71 | networkx @ file:///croot/networkx_1690561992265/work 72 | numpy==1.25.2 73 | nvidia-cublas-cu12==12.1.3.1 74 | nvidia-cuda-cupti-cu12==12.1.105 75 | nvidia-cuda-nvrtc-cu12==12.1.105 76 | nvidia-cuda-runtime-cu12==12.1.105 77 | nvidia-cudnn-cu12==8.9.2.26 78 | nvidia-cufft-cu12==11.0.2.54 79 | nvidia-curand-cu12==10.3.2.106 80 | nvidia-cusolver-cu12==11.4.5.107 81 | nvidia-cusparse-cu12==12.1.0.106 82 | nvidia-nccl-cu12==2.19.3 83 | nvidia-nvjitlink-cu12==12.3.101 84 | nvidia-nvtx-cu12==12.1.105 85 | omegaconf==2.3.0 86 | open3d==0.17.0 87 | packaging @ file:///croot/packaging_1693575174725/work 88 | pandas==2.1.0 89 | parso @ file:///opt/conda/conda-bld/parso_1641458642106/work 90 | partd @ file:///croot/partd_1698702562572/work 91 | pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work 92 | pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work 93 | pillow @ file:///croot/pillow_1707233021655/work 94 | platformdirs==3.10.0 95 | plotly==5.16.1 96 | ply==3.11 97 | prompt-toolkit @ file:///croot/prompt-toolkit_1704404351921/work 98 | protobuf==4.25.3 99 | ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl 100 | pure-eval @ file:///opt/conda/conda-bld/pure_eval_1646925070566/work 101 | Pygments @ file:///croot/pygments_1684279966437/work 102 | pyparsing @ file:///opt/conda/conda-bld/pyparsing_1661452539315/work 103 | PyQt5==5.15.10 104 | PyQt5-sip @ file:///croot/pyqt-split_1698769088074/work/pyqt_sip 105 | pyquaternion==0.9.9 106 | PySocks @ file:///tmp/build/80754af9/pysocks_1605305812635/work 107 | python-dateutil @ file:///tmp/build/80754af9/python-dateutil_1626374649649/work 108 | pytorch-lightning==2.2.0.post0 109 | pytz==2023.3 110 | pywavelets @ file:///croot/pywavelets_1705049820073/work 111 | PyYAML @ file:///croot/pyyaml_1698096049011/work 112 | referencing==0.30.2 113 | requests==2.31.0 114 | retrying==1.3.4 115 | rpds-py==0.10.0 116 | SciencePlots==2.1.1 117 | scikit-image @ file:///croot/scikit-image_1669241743693/work 118 | scikit-learn==1.3.0 119 | scipy==1.11.2 120 | sip @ file:///croot/sip_1698675935381/work 121 | six @ file:///tmp/build/80754af9/six_1644875935023/work 122 | soupsieve==2.4.1 123 | stack-data @ file:///opt/conda/conda-bld/stack_data_1646927590127/work 124 | sympy==1.12 125 | tenacity==8.2.3 126 | tensorboard==2.16.2 127 | tensorboard-data-server==0.7.2 128 | tensorboardX==2.6.2.2 129 | threadpoolctl==3.2.0 130 | tifffile @ file:///croot/tifffile_1695107451082/work 131 | tomli @ file:///opt/conda/conda-bld/tomli_1657175507142/work 132 | toolz @ file:///croot/toolz_1667464077321/work 133 | torch==2.2.0 134 | torchmetrics==1.3.1 135 | torchvision==0.17.0 136 | tornado @ file:///croot/tornado_1696936946304/work 137 | tqdm==4.66.1 138 | traitlets @ file:///croot/traitlets_1671143879854/work 139 | triton==2.2.0 140 | typing_extensions==4.9.0 141 | tzdata==2023.3 142 | urllib3==2.0.4 143 | wcwidth @ file:///Users/ktietz/demo/mc3/conda-bld/wcwidth_1629357192024/work 144 | Werkzeug==2.2.3 145 | widgetsnbextension==4.0.8 146 | yarl==1.9.4 147 | zipp @ file:///croot/zipp_1704206909481/work 148 | -------------------------------------------------------------------------------- /disp_training/train.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append('../') 4 | import logging 5 | import hydra 6 | import pytorch_lightning as pl 7 | import torch 8 | 9 | from utils.utils import * 10 | 11 | 12 | logger = logging.getLogger(__name__) 13 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 14 | torch.set_num_threads(1) 15 | torch.backends.cudnn.benchmark = True 16 | 17 | 18 | def propagate_keys_disp(cfg, testing=False): 19 | OmegaConf.set_struct(cfg, True) 20 | 21 | with open_dict(cfg): 22 | cfg.data.patch_size = cfg.patch_size 23 | cfg.data.min_track_length = cfg.min_track_length 24 | cfg.data.min_tracks_per_sample = cfg.min_tracks_per_sample 25 | cfg.data.max_tracks_per_sample = cfg.max_tracks_per_sample 26 | cfg.data.disp_patch_range = cfg.disp_patch_range 27 | cfg.data.augment = cfg.augment 28 | 29 | cfg.model.patch_size = cfg.patch_size 30 | cfg.model.min_track_length = cfg.min_track_length 31 | cfg.model.disp_patch_range = cfg.disp_patch_range 32 | 33 | if not testing: 34 | cfg.model.n_vis = cfg.n_vis 35 | cfg.model.debug = cfg.debug 36 | 37 | 38 | @hydra.main(config_path="disp_configs", config_name="m3ed_train") 39 | def train(cfg): 40 | pl.seed_everything(1234) 41 | 42 | # Update configuration dicts with common keys 43 | propagate_keys_disp(cfg) 44 | logger.info("\n" + OmegaConf.to_yaml(cfg)) 45 | 46 | # Instantiate model and dataloaders 47 | model = hydra.utils.instantiate( 48 | cfg.model, 49 | _recursive_=False, 50 | ) 51 | if cfg.checkpoint_path.lower() != 'none': 52 | # Load weights 53 | model = model.load_from_checkpoint(checkpoint_path=cfg.checkpoint_path) 54 | 55 | data_module = hydra.utils.instantiate(cfg.data) 56 | 57 | # Logging 58 | if cfg.logging: 59 | training_logger = pl.loggers.TensorBoardLogger(".", "", "", log_graph=True, default_hp_metric=False) 60 | else: 61 | training_logger = None 62 | 63 | # Training schedule 64 | callbacks = [pl.callbacks.LearningRateMonitor(logging_interval='epoch'), 65 | pl.callbacks.ModelCheckpoint(save_top_k=-1)] 66 | 67 | trainer = pl.Trainer( 68 | **OmegaConf.to_container(cfg.trainer), 69 | devices=[0], 70 | accelerator='gpu', 71 | callbacks=callbacks, 72 | logger=training_logger 73 | ) 74 | 75 | trainer.fit(model, datamodule=data_module) 76 | 77 | 78 | if __name__ == '__main__': 79 | train() 80 | -------------------------------------------------------------------------------- /doc/shapes_6dof_485_565_tracks.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uzh-rpg/deep_ev_tracker/fd09bcea0be9870905f4883466e2d16d8b45b84b/doc/shapes_6dof_485_565_tracks.gif -------------------------------------------------------------------------------- /doc/thumbnail.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uzh-rpg/deep_ev_tracker/fd09bcea0be9870905f4883466e2d16d8b45b84b/doc/thumbnail.PNG -------------------------------------------------------------------------------- /doc/ziggy_in_the_arena_1350_1650-opt.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uzh-rpg/deep_ev_tracker/fd09bcea0be9870905f4883466e2d16d8b45b84b/doc/ziggy_in_the_arena_1350_1650-opt.gif -------------------------------------------------------------------------------- /evaluate_real.py: -------------------------------------------------------------------------------- 1 | """ Predict tracks for a 
sequence with a network """ 2 | import logging 3 | import os 4 | from pathlib import Path 5 | 6 | import hydra 7 | import imageio 8 | import IPython 9 | import numpy as np 10 | import pytorch_lightning as pl 11 | import torch 12 | from omegaconf import OmegaConf, open_dict 13 | from prettytable import PrettyTable 14 | from tqdm import tqdm 15 | 16 | from utils.dataset import CornerConfig, ECSubseq, EDSSubseq, EvalDatasetType 17 | from utils.timers import CudaTimer, cuda_timers 18 | from utils.track_utils import ( 19 | TrackObserver, 20 | get_gt_corners, 21 | ) 22 | from utils.visualization import generate_track_colors, render_pred_tracks, render_tracks 23 | 24 | # Configure GPU order 25 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 26 | 27 | # Logging 28 | logger = logging.getLogger(__name__) 29 | results_table = PrettyTable() 30 | results_table.field_names = ["Inference Time"] 31 | 32 | # Configure datasets 33 | corner_config = CornerConfig(30, 0.3, 15, 0.15, False, 11) 34 | 35 | EvalDatasetConfigDict = { 36 | EvalDatasetType.EC: {"dt": 0.010, "root_dir": ""}, 37 | EvalDatasetType.EDS: {"dt": 0.005, "root_dir": ""}, 38 | } 39 | 40 | EVAL_DATASETS = [ 41 | ("peanuts_light_160_386", EvalDatasetType.EDS), 42 | ("rocket_earth_light_338_438", EvalDatasetType.EDS), 43 | ("ziggy_in_the_arena_1350_1650", EvalDatasetType.EDS), 44 | ("peanuts_running_2360_2460", EvalDatasetType.EDS), 45 | ("shapes_translation_8_88", EvalDatasetType.EC), 46 | ("shapes_rotation_165_245", EvalDatasetType.EC), 47 | ("shapes_6dof_485_565", EvalDatasetType.EC), 48 | ("boxes_translation_330_410", EvalDatasetType.EC), 49 | ("boxes_rotation_198_278", EvalDatasetType.EC), 50 | ] 51 | 52 | 53 | def evaluate(model, sequence_dataset, dt_track_vis, sequence_name, visualize): 54 | tracks_pred = TrackObserver( 55 | t_init=sequence_dataset.t_init, u_centers_init=sequence_dataset.u_centers 56 | ) 57 | 58 | model.reset(sequence_dataset.n_tracks) 59 | event_generator = sequence_dataset.events() 60 | 61 | cuda_timer = CudaTimer(model.device, sequence_dataset.sequence_name) 62 | 63 | with torch.no_grad(): 64 | # Predict network tracks 65 | for t, x in tqdm( 66 | event_generator, 67 | total=sequence_dataset.n_events - 1, 68 | desc="Predicting tracks with network...", 69 | ): 70 | with cuda_timer: 71 | x = x.to(model.device) 72 | y_hat = model.forward(x) 73 | 74 | sequence_dataset.accumulate_y_hat(y_hat) 75 | tracks_pred.add_observation(t, sequence_dataset.u_centers.cpu().numpy()) 76 | 77 | if visualize: 78 | # Visualize network tracks 79 | gif_img_arr = [] 80 | tracks_pred_interp = tracks_pred.get_interpolators() 81 | track_colors = generate_track_colors(sequence_dataset.n_tracks) 82 | for i, (t, img_now) in enumerate( 83 | tqdm( 84 | sequence_dataset.frames(), 85 | total=sequence_dataset.n_frames - 1, 86 | desc="Rendering predicted tracks... 
", 87 | ) 88 | ): 89 | fig_arr = render_pred_tracks( 90 | tracks_pred_interp, t, img_now, track_colors, dt_track=dt_track_vis 91 | ) 92 | gif_img_arr.append(fig_arr) 93 | imageio.mimsave(f"{sequence_name}_tracks_pred.gif", gif_img_arr) 94 | 95 | # Save predicted tracks 96 | np.savetxt( 97 | f"{sequence_name}.txt", 98 | tracks_pred.track_data, 99 | fmt=["%i", "%.9f", "%i", "%i"], 100 | delimiter=" ", 101 | ) 102 | 103 | metrics = {} 104 | metrics["latency"] = sum(cuda_timers[sequence_dataset.sequence_name]) 105 | 106 | return metrics 107 | 108 | 109 | @hydra.main(config_path="configs", config_name="eval_real_defaults") 110 | def track(cfg): 111 | pl.seed_everything(1234) 112 | OmegaConf.set_struct(cfg, True) 113 | with open_dict(cfg): 114 | cfg.model.representation = cfg.representation 115 | logger.info("\n" + OmegaConf.to_yaml(cfg)) 116 | 117 | # Configure model 118 | model = hydra.utils.instantiate(cfg.model, _recursive_=False) 119 | 120 | state_dict = torch.load(cfg.weights_path, map_location="cuda:0")["state_dict"] 121 | model.load_state_dict(state_dict) 122 | if torch.cuda.is_available(): 123 | model = model.cuda() 124 | model.eval() 125 | 126 | # Run evaluation on each dataset 127 | for seq_name, seq_type in EVAL_DATASETS: 128 | if seq_type == EvalDatasetType.EC: 129 | dataset_class = ECSubseq 130 | elif seq_type == EvalDatasetType.EDS: 131 | dataset_class = EDSSubseq 132 | else: 133 | raise ValueError 134 | 135 | dataset = dataset_class( 136 | EvalDatasetConfigDict[seq_type]["root_dir"], 137 | seq_name, 138 | -1, 139 | cfg.patch_size, 140 | cfg.representation, 141 | EvalDatasetConfigDict[seq_type]["dt"], 142 | corner_config, 143 | ) 144 | 145 | # Load ground truth corners for this seq and override initialization 146 | gt_features_path = str(Path(cfg.gt_path) / f"{seq_name}.gt.txt") 147 | gt_start_corners = get_gt_corners(gt_features_path) 148 | 149 | dataset.override_keypoints(gt_start_corners) 150 | 151 | metrics = evaluate(model, dataset, cfg.dt_track_vis, seq_name, cfg.visualize) 152 | 153 | logger.info(f"=== DATASET: {seq_name} ===") 154 | logger.info(f"Latency: {metrics['latency']} s") 155 | 156 | results_table.add_row([metrics["latency"]]) 157 | 158 | logger.info(f"\n{results_table.get_string()}") 159 | 160 | 161 | if __name__ == "__main__": 162 | track() 163 | -------------------------------------------------------------------------------- /models/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | 6 | class ConvLSTMCell(nn.Module): 7 | """ 8 | Generate a convolutional LSTM cell 9 | From: https://github.com/Atcold/pytorch-CortexNet/blob/master/model/ConvLSTMCell.py 10 | """ 11 | 12 | def __init__(self, in_channels, hidden_channels, kernel_size): 13 | super().__init__() 14 | self.in_channels = in_channels 15 | self.hidden_channels = hidden_channels 16 | self.Gates = nn.Conv2d( 17 | in_channels + hidden_channels, 18 | 4 * hidden_channels, 19 | kernel_size, 20 | padding=kernel_size // 2, 21 | ) 22 | self.prev_state = None 23 | self.init_weights() 24 | 25 | def init_weights(self): 26 | torch.nn.init.xavier_uniform_(self.Gates.weight) 27 | 28 | def reset(self): 29 | self.prev_state = None 30 | 31 | def forward(self, input_): 32 | # get batch and spatial sizes 33 | batch_size = input_.data.size()[0] 34 | spatial_size = input_.data.size()[2:] 35 | 36 | # generate empty prev_state, if None is provided 37 | if self.prev_state is None: 38 | state_size = 
[batch_size, self.hidden_channels] + list(spatial_size) 39 | self.prev_state = ( 40 | Variable(torch.zeros(state_size, device=input_.device)), 41 | Variable(torch.zeros(state_size, device=input_.device)), 42 | ) 43 | 44 | prev_hidden, prev_cell = self.prev_state 45 | 46 | # data size is [batch, channel, height, width] 47 | stacked_inputs = torch.cat((input_, prev_hidden), 1) 48 | gates = self.Gates(stacked_inputs) 49 | 50 | # chunk across channel dimension 51 | in_gate, remember_gate, out_gate, cell_gate = gates.chunk(4, 1) 52 | 53 | # apply sigmoid non linearity 54 | in_gate = torch.sigmoid(in_gate) 55 | remember_gate = torch.sigmoid(remember_gate) 56 | out_gate = torch.sigmoid(out_gate) 57 | 58 | # apply tanh non linearity 59 | cell_gate = torch.tanh(cell_gate) 60 | 61 | # compute current cell and hidden state 62 | cell = (remember_gate * prev_cell) + (in_gate * cell_gate) 63 | hidden = out_gate * torch.tanh(cell) 64 | 65 | self.prev_state = (hidden, cell) 66 | return hidden 67 | 68 | 69 | class ConvBlock(nn.Module): 70 | def __init__( 71 | self, 72 | in_channels, 73 | out_channels, 74 | n_convs=3, 75 | kernel_size=3, 76 | stride=1, 77 | padding=1, 78 | downsample=True, 79 | dilation=1, 80 | ): 81 | super(ConvBlock, self).__init__() 82 | self.modules = [] 83 | 84 | c_in = in_channels 85 | c_out = out_channels 86 | for i in range(n_convs): 87 | self.modules.append( 88 | nn.Conv2d( 89 | in_channels=c_in, 90 | out_channels=c_out, 91 | kernel_size=kernel_size, 92 | padding=padding, 93 | stride=stride, 94 | bias=False, 95 | dilation=dilation, 96 | ) 97 | ) 98 | self.modules.append(nn.BatchNorm2d(num_features=out_channels)) 99 | self.modules.append(nn.LeakyReLU(0.1)) 100 | c_in = c_out 101 | 102 | if downsample: 103 | self.modules.append( 104 | nn.Conv2d( 105 | in_channels=out_channels, 106 | out_channels=out_channels, 107 | kernel_size=2, 108 | stride=2, 109 | ) 110 | ) 111 | self.modules.append(nn.ReLU()) 112 | # self.modules.append(nn.MaxPool2d(kernel_size=2, stride=2)) 113 | 114 | self.model = nn.Sequential(*self.modules) 115 | self.init_weights() 116 | 117 | def init_weights(self): 118 | for m in self.model.modules(): 119 | if isinstance(m, nn.Conv2d): 120 | torch.nn.init.xavier_uniform_(m.weight) 121 | 122 | def forward(self, x): 123 | return self.model(x) 124 | -------------------------------------------------------------------------------- /models/correlation3_unscaled.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch.nn.init 3 | 4 | from models.common import * 5 | from models.template import Template 6 | from utils.losses import * 7 | 8 | 9 | class FPNEncoder(nn.Module): 10 | def __init__(self, in_channels=1, out_channels=512, recurrent=False): 11 | super(FPNEncoder, self).__init__() 12 | 13 | self.conv_bottom_0 = ConvBlock( 14 | in_channels=in_channels, 15 | out_channels=32, 16 | n_convs=2, 17 | kernel_size=1, 18 | padding=0, 19 | downsample=False, 20 | ) 21 | self.conv_bottom_1 = ConvBlock( 22 | in_channels=32, 23 | out_channels=64, 24 | n_convs=2, 25 | kernel_size=5, 26 | padding=0, 27 | downsample=False, 28 | ) 29 | self.conv_bottom_2 = ConvBlock( 30 | in_channels=64, 31 | out_channels=128, 32 | n_convs=2, 33 | kernel_size=5, 34 | padding=0, 35 | downsample=False, 36 | ) 37 | self.conv_bottom_3 = ConvBlock( 38 | in_channels=128, 39 | out_channels=256, 40 | n_convs=2, 41 | kernel_size=3, 42 | padding=0, 43 | downsample=True, 44 | ) 45 | self.conv_bottom_4 = ConvBlock( 46 | in_channels=256, 47 | 
out_channels=out_channels, 48 | n_convs=2, 49 | kernel_size=3, 50 | padding=0, 51 | downsample=False, 52 | ) 53 | 54 | self.recurrent = recurrent 55 | if self.recurrent: 56 | self.conv_rnn = ConvLSTMCell(out_channels, out_channels, 1) 57 | 58 | self.conv_lateral_3 = nn.Conv2d( 59 | in_channels=256, out_channels=out_channels, kernel_size=1, bias=True 60 | ) 61 | self.conv_lateral_2 = nn.Conv2d( 62 | in_channels=128, out_channels=out_channels, kernel_size=1, bias=True 63 | ) 64 | self.conv_lateral_1 = nn.Conv2d( 65 | in_channels=64, out_channels=out_channels, kernel_size=1, bias=True 66 | ) 67 | self.conv_lateral_0 = nn.Conv2d( 68 | in_channels=32, out_channels=out_channels, kernel_size=1, bias=True 69 | ) 70 | 71 | self.conv_dealias_3 = nn.Conv2d( 72 | in_channels=out_channels, 73 | out_channels=out_channels, 74 | kernel_size=3, 75 | padding=1, 76 | bias=True, 77 | ) 78 | self.conv_dealias_2 = nn.Conv2d( 79 | in_channels=out_channels, 80 | out_channels=out_channels, 81 | kernel_size=3, 82 | padding=1, 83 | bias=True, 84 | ) 85 | self.conv_dealias_1 = nn.Conv2d( 86 | in_channels=out_channels, 87 | out_channels=out_channels, 88 | kernel_size=3, 89 | padding=1, 90 | bias=True, 91 | ) 92 | self.conv_dealias_0 = nn.Conv2d( 93 | in_channels=out_channels, 94 | out_channels=out_channels, 95 | kernel_size=3, 96 | padding=1, 97 | bias=True, 98 | ) 99 | self.conv_out = nn.Sequential( 100 | ConvBlock( 101 | in_channels=out_channels, 102 | out_channels=out_channels, 103 | n_convs=1, 104 | kernel_size=3, 105 | padding=1, 106 | downsample=False, 107 | ), 108 | nn.Conv2d( 109 | in_channels=out_channels, 110 | out_channels=out_channels, 111 | kernel_size=3, 112 | padding=1, 113 | bias=True, 114 | ), 115 | ) 116 | 117 | self.conv_bottleneck_out = nn.Sequential( 118 | ConvBlock( 119 | in_channels=out_channels, 120 | out_channels=out_channels, 121 | n_convs=1, 122 | kernel_size=3, 123 | padding=1, 124 | downsample=False, 125 | ), 126 | nn.Conv2d( 127 | in_channels=out_channels, 128 | out_channels=out_channels, 129 | kernel_size=3, 130 | padding=1, 131 | bias=True, 132 | ), 133 | ) 134 | 135 | def reset(self): 136 | if self.recurrent: 137 | self.conv_rnn.reset() 138 | 139 | def forward(self, x): 140 | """ 141 | :param x: 142 | :return: (highest res feature map, lowest res feature map) 143 | """ 144 | 145 | # Bottom-up pathway 146 | c0 = self.conv_bottom_0(x) # 31x31 147 | c1 = self.conv_bottom_1(c0) # 23x23 148 | c2 = self.conv_bottom_2(c1) # 15x15 149 | c3 = self.conv_bottom_3(c2) # 5x5 150 | c4 = self.conv_bottom_4(c3) # 1x1 151 | 152 | # Top-down pathway (with lateral cnx and de-aliasing) 153 | p4 = c4 154 | p3 = self.conv_dealias_3( 155 | self.conv_lateral_3(c3) 156 | + F.interpolate(p4, (c3.shape[2], c3.shape[3]), mode="bilinear") 157 | ) 158 | p2 = self.conv_dealias_2( 159 | self.conv_lateral_2(c2) 160 | + F.interpolate(p3, (c2.shape[2], c2.shape[3]), mode="bilinear") 161 | ) 162 | p1 = self.conv_dealias_1( 163 | self.conv_lateral_1(c1) 164 | + F.interpolate(p2, (c1.shape[2], c1.shape[3]), mode="bilinear") 165 | ) 166 | p0 = self.conv_dealias_0( 167 | self.conv_lateral_0(c0) 168 | + F.interpolate(p1, (c0.shape[2], c0.shape[3]), mode="bilinear") 169 | ) 170 | 171 | if self.recurrent: 172 | p0 = self.conv_rnn(p0) 173 | 174 | return self.conv_out(p0), self.conv_bottleneck_out(c4) 175 | 176 | 177 | class JointEncoder(nn.Module): 178 | def __init__(self, in_channels, out_channels): 179 | super(JointEncoder, self).__init__() 180 | 181 | self.conv1 = ConvBlock( 182 | in_channels=in_channels, out_channels=64, 
n_convs=2, downsample=True 183 | ) 184 | self.conv2 = ConvBlock( 185 | in_channels=64, out_channels=128, n_convs=2, downsample=True 186 | ) 187 | self.convlstm0 = ConvLSTMCell(128, 128, 3) 188 | self.conv3 = ConvBlock( 189 | in_channels=128, out_channels=256, n_convs=2, downsample=True 190 | ) 191 | self.conv4 = ConvBlock( 192 | in_channels=256, 193 | out_channels=256, 194 | kernel_size=3, 195 | padding=0, 196 | n_convs=1, 197 | downsample=False, 198 | ) 199 | 200 | # Transformer Addition 201 | self.flatten = nn.Flatten() 202 | embed_dim = 256 203 | num_heads = 8 204 | self.multihead_attention0 = nn.MultiheadAttention( 205 | embed_dim, num_heads, batch_first=True 206 | ) 207 | 208 | self.prev_x_res = None 209 | self.gates = nn.Linear(2 * embed_dim, embed_dim) 210 | self.ls_layer = LayerScale(embed_dim) 211 | 212 | # Attention Mask Transformer 213 | self.fusion_layer0 = nn.Sequential( 214 | nn.Linear(embed_dim * 2, embed_dim), 215 | nn.LeakyReLU(0.1), 216 | nn.Linear(embed_dim, embed_dim), 217 | nn.LeakyReLU(0.1), 218 | ) 219 | self.output_layers = nn.Sequential(nn.Linear(embed_dim, 512), nn.LeakyReLU(0.1)) 220 | 221 | def reset(self): 222 | self.convlstm0.reset() 223 | self.prev_x_res = None 224 | 225 | def forward(self, x, attn_mask=None): 226 | x = self.conv1(x) 227 | x = self.conv2(x) 228 | x = self.convlstm0(x) 229 | x = self.conv3(x) 230 | x = self.conv4(x) 231 | x = self.flatten(x) 232 | 233 | if self.prev_x_res is None: 234 | self.prev_x_res = Variable(torch.zeros_like(x)) 235 | 236 | x = self.fusion_layer0(torch.cat((x, self.prev_x_res), 1)) 237 | 238 | x_attn = x[None, :, :].detach() 239 | if self.training: 240 | x_attn = self.multihead_attention0( 241 | query=x_attn, key=x_attn, value=x_attn, attn_mask=attn_mask.bool() 242 | )[0].squeeze(0) 243 | else: 244 | x_attn = self.multihead_attention0(query=x_attn, key=x_attn, value=x_attn)[ 245 | 0 246 | ].squeeze(0) 247 | x = x + self.ls_layer(x_attn) 248 | 249 | gate_weight = torch.sigmoid(self.gates(torch.cat((self.prev_x_res, x), 1))) 250 | x = self.prev_x_res * gate_weight + x * (1 - gate_weight) 251 | 252 | self.prev_x_res = x 253 | 254 | x = self.output_layers(x) 255 | 256 | return x 257 | 258 | 259 | class LayerScale(nn.Module): 260 | def __init__(self, dim, init_values=1e-5, inplace=False): 261 | super().__init__() 262 | self.inplace = inplace 263 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 264 | 265 | def forward(self, x): 266 | gamma = self.gamma 267 | return x.mul_(gamma) if self.inplace else x * gamma 268 | 269 | 270 | class TrackerNetC(Template): 271 | def __init__( 272 | self, 273 | representation="time_surfaces_1", 274 | max_unrolls=16, 275 | n_vis=8, 276 | feature_dim=1024, 277 | patch_size=31, 278 | init_unrolls=1, 279 | input_channels=None, 280 | **kwargs, 281 | ): 282 | super(TrackerNetC, self).__init__( 283 | representation=representation, 284 | max_unrolls=max_unrolls, 285 | init_unrolls=init_unrolls, 286 | n_vis=n_vis, 287 | patch_size=patch_size, 288 | **kwargs, 289 | ) 290 | # Configuration 291 | self.grayscale_ref = True 292 | if not isinstance(input_channels, type(None)): 293 | self.channels_in_per_patch = input_channels 294 | 295 | # Architecture 296 | self.feature_dim = feature_dim 297 | self.redir_dim = 128 298 | 299 | self.reference_encoder = FPNEncoder(1, self.feature_dim) 300 | self.target_encoder = FPNEncoder(self.channels_in_per_patch, self.feature_dim) 301 | 302 | # Correlation3 had k=1, p=0 303 | self.reference_redir = nn.Conv2d( 304 | self.feature_dim, self.redir_dim, kernel_size=3, 
padding=1 305 | ) 306 | self.target_redir = nn.Conv2d( 307 | self.feature_dim, self.redir_dim, kernel_size=3, padding=1 308 | ) 309 | self.softmax = nn.Softmax(dim=2) 310 | 311 | self.joint_encoder = JointEncoder( 312 | in_channels=1 + 2 * self.redir_dim, out_channels=512 313 | ) 314 | self.predictor = nn.Linear(in_features=512, out_features=2, bias=False) 315 | self.flatten = nn.Flatten() 316 | 317 | # Operational 318 | self.loss = L1Truncated(patch_size=patch_size) 319 | self.name = f"corr_{self.representation}" 320 | 321 | # Persistent Tensors 322 | self.f_ref, self.d_ref = None, None 323 | 324 | self.correlation_maps = [] 325 | self.inputs = [] 326 | self.refs = [] 327 | 328 | def init_weights(self): 329 | torch.nn.init.xavier_uniform(self.fc_out.weight) 330 | 331 | def reset(self, _): 332 | self.d_ref, self.f_ref = None, None 333 | self.joint_encoder.reset() 334 | 335 | def forward(self, x, attn_mask=None): 336 | # Feature Extraction 337 | f0, _ = self.target_encoder(x[:, : self.channels_in_per_patch, :, :]) 338 | if isinstance(self.f_ref, type(None)): 339 | self.f_ref, self.d_ref = self.reference_encoder( 340 | x[:, self.channels_in_per_patch :, :, :] 341 | ) 342 | self.f_ref = self.reference_redir(self.f_ref) 343 | 344 | # Correlation and softmax 345 | f_corr = (f0 * self.d_ref).sum(dim=1, keepdim=True) 346 | f_corr = self.softmax( 347 | f_corr.view(-1, 1, self.patch_size * self.patch_size) 348 | ).view(-1, 1, self.patch_size, self.patch_size) 349 | 350 | # Feature re-direction 351 | f = torch.cat([f_corr, self.target_redir(f0), self.f_ref], dim=1) 352 | f = self.joint_encoder(f, attn_mask) 353 | 354 | f = self.predictor(f) 355 | 356 | return f 357 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fire==0.4.0 2 | h5py==3.6.0 3 | hdf5plugin==3.2.0 4 | hydra-core==1.2.0 5 | imageio==2.19.3 6 | omegaconf==2.2.2 7 | opencv-python==4.5.5.64 8 | pandas==1.4.3 9 | pytorch-lightning==1.7.4 10 | PyYAML==6.0 11 | scipy==1.8.1 12 | seaborn==0.11.2 13 | tensorboard==2.10.0 14 | tensorboard-data-server==0.6.1 15 | tensorboard-plugin-wit==1.8.1 16 | termcolor==1.1.0 17 | tf-estimator-nightly==2.8.0.dev2021122109 18 | torch==1.12.1+cu113 19 | torchaudio==0.12.1+cu113 20 | torchmetrics==0.9.3 21 | torchvision==0.13.1+cu113 22 | tqdm==4.63.0 23 | prettytable==3.4.1 24 | -------------------------------------------------------------------------------- /scripts/benchmark.py: -------------------------------------------------------------------------------- 1 | """ 2 | Compare our results against KLT with reduced frame-rate 3 | python -m scripts.benchmark 4 | """ 5 | from pathlib import Path 6 | 7 | import matplotlib.pyplot as plt 8 | 9 | # from sklearn.metrics import auc as compute_auc 10 | import numpy as np 11 | from prettytable import PrettyTable 12 | from tqdm import tqdm 13 | 14 | from utils.dataset import EvalDatasetType 15 | from utils.track_utils import compute_tracking_errors, read_txt_results 16 | 17 | plt.rcParams["font.family"] = "serif" 18 | 19 | EVAL_DATASETS = [ 20 | ("peanuts_light_160_386", EvalDatasetType.EDS), 21 | ("rocket_earth_light_338_438", EvalDatasetType.EDS), 22 | ("ziggy_in_the_arena_1350_1650", EvalDatasetType.EDS), 23 | ("peanuts_running_2360_2460", EvalDatasetType.EDS), 24 | ("shapes_translation_8_88", EvalDatasetType.EC), 25 | ("shapes_rotation_165_245", EvalDatasetType.EC), 26 | ("shapes_6dof_485_565", EvalDatasetType.EC), 27 | 
("boxes_translation_330_410", EvalDatasetType.EC), 28 | ("boxes_rotation_198_278", EvalDatasetType.EC), 29 | ] 30 | 31 | error_threshold_range = np.arange(1, 32, 1) 32 | results_dir = Path( 33 | "/benchmark_data" 34 | ) 35 | out_dir = Path( 36 | "/benchmark_results" 37 | ) 38 | methods = ["network_pred"] 39 | 40 | table_keys = [ 41 | "age_5_mu", 42 | "age_5_std", 43 | "te_5_mu", 44 | "te_5_std", 45 | "age_mu", 46 | "age_std", 47 | "inliers_mu", 48 | "inliers_std", 49 | "expected_age", 50 | ] 51 | tables = {} 52 | for k in table_keys: 53 | tables[k] = PrettyTable() 54 | tables[k].title = k 55 | tables[k].field_names = ["Sequence Name"] + methods 56 | 57 | for eval_sequence in EVAL_DATASETS: 58 | sequence_name = eval_sequence[0] 59 | track_data_gt = read_txt_results( 60 | str(results_dir / "gt" / f"{sequence_name}.gt.txt") 61 | ) 62 | 63 | rows = {} 64 | for k in tables.keys(): 65 | rows[k] = [sequence_name] 66 | 67 | for method in methods: 68 | inlier_ratio_arr, fa_rel_nz_arr = [], [] 69 | 70 | track_data_pred = read_txt_results( 71 | str(results_dir / f"{method}" / f"{sequence_name}.txt") 72 | ) 73 | 74 | if track_data_pred[0, 1] != track_data_gt[0, 1]: 75 | track_data_pred[:, 1] += -track_data_pred[0, 1] + track_data_gt[0, 1] 76 | 77 | for thresh in error_threshold_range: 78 | fa_rel, _ = compute_tracking_errors( 79 | track_data_pred, 80 | track_data_gt, 81 | error_threshold=thresh, 82 | asynchronous=False, 83 | ) 84 | 85 | inlier_ratio = np.sum(fa_rel > 0) / len(fa_rel) 86 | if inlier_ratio > 0: 87 | fa_rel_nz = fa_rel[np.nonzero(fa_rel)[0]] 88 | else: 89 | fa_rel_nz = [0] 90 | inlier_ratio_arr.append(inlier_ratio) 91 | fa_rel_nz_arr.append(np.mean(fa_rel_nz)) 92 | 93 | mean_inlier_ratio, std_inlier_ratio = np.mean(inlier_ratio_arr), np.std( 94 | inlier_ratio_arr 95 | ) 96 | mean_fa_rel_nz, std_fa_rel_nz = np.mean(fa_rel_nz_arr), np.std(fa_rel_nz_arr) 97 | expected_age = np.mean(np.array(inlier_ratio_arr) * np.array(fa_rel_nz_arr)) 98 | 99 | rows["age_mu"].append(mean_fa_rel_nz) 100 | rows["age_std"].append(std_fa_rel_nz) 101 | rows["inliers_mu"].append(mean_inlier_ratio) 102 | rows["inliers_std"].append(std_inlier_ratio) 103 | rows["expected_age"].append(expected_age) 104 | 105 | fa_rel, te = compute_tracking_errors( 106 | track_data_pred, track_data_gt, error_threshold=5, asynchronous=False 107 | ) 108 | inlier_ratio = np.sum(fa_rel > 0) / len(fa_rel) 109 | if inlier_ratio > 0: 110 | fa_rel_nz = fa_rel[np.nonzero(fa_rel)[0]] 111 | else: 112 | fa_rel_nz = [0] 113 | te = [0] 114 | 115 | mean_fa_rel_nz, std_fa_rel_nz = np.mean(fa_rel_nz), np.std(fa_rel_nz) 116 | mean_te, std_te = np.mean(te), np.std(te) 117 | rows["age_5_mu"].append(mean_fa_rel_nz) 118 | rows["age_5_std"].append(std_fa_rel_nz) 119 | rows["te_5_mu"].append(mean_te) 120 | rows["te_5_std"].append(std_te) 121 | 122 | # Load results 123 | for k in tables.keys(): 124 | tables[k].add_row(rows[k]) 125 | 126 | with open((out_dir / f"benchmarking_results.csv"), "w") as f: 127 | for k in tables.keys(): 128 | f.write(f"{k}\n") 129 | f.write(tables[k].get_csv_string()) 130 | 131 | print(tables[k].get_string()) 132 | -------------------------------------------------------------------------------- /scripts/visualize_eds_data.py: -------------------------------------------------------------------------------- 1 | """ Visualize corrected EDS events and images """ 2 | import os 3 | from glob import glob 4 | 5 | import cv2 6 | import h5py 7 | import hdf5plugin 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | 11 | 
sequence_name = "all_characters" 12 | corrected_image_dir = f"/{sequence_name}/images_corrected" 13 | image_ts_path = f"/{sequence_name}/images_timestamps.txt" 14 | corrected_events_path = ( 15 | f"/{sequence_name}/events_corrected.h5" 16 | ) 17 | dt_event_slice = 500 18 | 19 | # Read times 20 | image_ts = np.genfromtxt(image_ts_path, skip_header=False) 21 | 22 | # Load events 23 | with h5py.File(corrected_events_path, "r") as h5f: 24 | event_times = np.array(h5f["t"]) 25 | print(f"Event time range: {event_times.min()} - {event_times.max()}") 26 | 27 | # plt.ion() 28 | # Iterate over images 29 | for image_idx, image_p in enumerate( 30 | sorted(glob(os.path.join(corrected_image_dir, "*.png"))) 31 | ): 32 | # Load grayscale frame 33 | image = cv2.imread(image_p, cv2.IMREAD_COLOR) 34 | plt.imshow(image) 35 | 36 | # Get relevant events 37 | t1 = image_ts[image_idx] 38 | t0 = t1 - dt_event_slice 39 | first_idx = np.searchsorted(event_times, t0, side="left") 40 | last_idx_p1 = np.searchsorted(event_times, t1, side="right") 41 | x = np.asarray(h5f["x"][first_idx:last_idx_p1]) 42 | y = np.asarray(h5f["y"][first_idx:last_idx_p1]) 43 | p = np.asarray(h5f["p"][first_idx:last_idx_p1]) 44 | on_mask = p == 1 45 | off_mask = p == 0 46 | 47 | # Draw events 48 | plt.scatter(x[on_mask], y[on_mask], s=1, c="green") 49 | plt.scatter(x[off_mask], y[off_mask], s=1, c="red") 50 | plt.title(f"Image Time: {t1*1e-6}") 51 | plt.axis("off") 52 | plt.draw() 53 | plt.pause(0.0001) 54 | plt.clf() 55 | # plt.show() 56 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import hydra 5 | import pytorch_lightning as pl 6 | import torch 7 | 8 | from utils.callbacks import IncreaseSequenceLengthCallback 9 | from utils.utils import * 10 | 11 | logger = logging.getLogger(__name__) 12 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 13 | torch.set_num_threads(1) 14 | torch.backends.cudnn.benchmark = True 15 | 16 | 17 | @hydra.main(config_path="configs", config_name="train_defaults") 18 | def train(cfg): 19 | pl.seed_everything(1234) 20 | 21 | # Update configuration dicts with common keys 22 | propagate_keys(cfg) 23 | logger.info("\n" + OmegaConf.to_yaml(cfg)) 24 | 25 | # Instantiate model and dataloaders 26 | model = hydra.utils.instantiate( 27 | cfg.model, 28 | _recursive_=False, 29 | ) 30 | if cfg.checkpoint_path.lower() != "none": 31 | # Load weights 32 | model = model.load_from_checkpoint(checkpoint_path=cfg.checkpoint_path) 33 | 34 | # Override stuff for fine-tuning 35 | model.hparams.optimizer.lr = cfg.model.optimizer.lr 36 | model.hparams.optimizer._target_ = cfg.model.optimizer._target_ 37 | model.debug = True 38 | model.unrolls = cfg.init_unrolls 39 | model.max_unrolls = cfg.max_unrolls 40 | model.pose_mode = cfg.model.pose_mode 41 | 42 | data_module = hydra.utils.instantiate(cfg.data) 43 | 44 | # Logging 45 | if cfg.logging: 46 | training_logger = pl.loggers.TensorBoardLogger( 47 | ".", "", "", log_graph=True, default_hp_metric=False 48 | ) 49 | else: 50 | training_logger = None 51 | 52 | # Training schedule 53 | callbacks = [ 54 | IncreaseSequenceLengthCallback( 55 | unroll_factor=cfg.unroll_factor, schedule=cfg.unroll_schedule 56 | ), 57 | pl.callbacks.LearningRateMonitor(logging_interval="epoch"), 58 | ] 59 | 60 | trainer = pl.Trainer( 61 | **OmegaConf.to_container(cfg.trainer), 62 | devices=[0], 63 | accelerator="gpu", 64 | callbacks=callbacks, 65 | 
logger=training_logger 66 | ) 67 | 68 | trainer.fit(model, datamodule=data_module) 69 | 70 | 71 | if __name__ == "__main__": 72 | train() 73 | -------------------------------------------------------------------------------- /utils/augmentations.py: -------------------------------------------------------------------------------- 1 | import random 2 | from math import pi 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torchvision.transforms import InterpolationMode 8 | from torchvision.transforms.functional import ( 9 | _get_perspective_coeffs, 10 | perspective, 11 | resize, 12 | rotate, 13 | ) 14 | 15 | 16 | def augment_rotation(x, y, max_rotation_deg=15, rotation_deg=None): 17 | """ 18 | Augment a target patch by rotating it. One of [max_rotation_deg, rotation_deg] must be given. 19 | If max_rotation_deg is given, an angle is sampled from [-max, max] 20 | If rotation_deg is given, that angle is directly applied. 21 | :param x: (C, P, P) tensor of the event representation (patch) 22 | :param y: (2,) tensor of the gt displacement 23 | :param max_rotation_deg: int, max rotation angle (+/-) in degrees 24 | :param rotation_deg: int, rotation angle in degrees to apply 25 | :return: x_aug, y_aug, angle 26 | """ 27 | 28 | if not isinstance(max_rotation_deg, type(None)): 29 | angle = random.randint(-max_rotation_deg, max_rotation_deg) 30 | else: 31 | angle = rotation_deg 32 | 33 | x_aug = rotate(x, angle, interpolation=InterpolationMode.NEAREST) 34 | phi = torch.tensor(-angle * pi / 180) 35 | s = torch.sin(phi) 36 | c = torch.cos(phi) 37 | rot = torch.stack([torch.stack([c, -s]), torch.stack([s, c])]) 38 | y_aug = torch.reshape((rot @ torch.reshape(y, (2, 1))), (2,)) 39 | return x_aug, y_aug, angle 40 | 41 | 42 | def unaugment_rotation(y, rotation_deg=None): 43 | """ 44 | Undo the rotation augmentation on a displacement label. 45 | The displacement is rotated back by the angle that was applied in augment_rotation, 46 | so the label is expressed in the original (un-rotated) patch frame again. 47 | :param y: (2,) tensor of the (augmented) gt displacement 48 | :param rotation_deg: int, rotation angle in degrees that was applied during 49 | augmentation (as returned by augment_rotation) 50 | :return: y_aug: (2,) tensor of the displacement expressed in the 51 | original (un-rotated) frame 52 | """ 53 | 54 | angle = -rotation_deg 55 | phi = torch.tensor(-angle * pi / 180) 56 | s = torch.sin(phi) 57 | c = torch.cos(phi) 58 | rot = torch.stack([torch.stack([c, -s]), torch.stack([s, c])]) 59 | y_aug = torch.reshape((rot @ torch.reshape(y, (2, 1))), (2,)) 60 | return y_aug 61 | 62 | 63 | def augment_scale(x, y, max_scale_percentage=10, scale_percentage=None): 64 | """ 65 | Augment a target patch by scaling it.
Scale percentage is uniformly sampled from [-MAX, MAX] 66 | :param x: (C, P, P) tensor of the event representation (patch) 67 | :param y: (2,) tensor of the gt displacement 68 | :param max_scale_percentage: int, max scale change (+/-) in percentage 69 | :return: x_aug, y_aug 70 | """ 71 | _, patch_size_old, _ = x.shape 72 | if not isinstance(max_scale_percentage, type(None)): 73 | scaling = ( 74 | 1.0 75 | + float(random.randint(-max_scale_percentage, max_scale_percentage)) / 100.0 76 | ) 77 | patch_size_new = int(round(patch_size_old * scaling)) 78 | 79 | # Enforce odd patch size 80 | if patch_size_new % 2 == 0: 81 | patch_size_new += 1 82 | 83 | scaling = patch_size_new / patch_size_old 84 | else: 85 | scaling = scale_percentage 86 | patch_size_new = int(patch_size_old * scaling) 87 | 88 | x_aug = resize(x, [patch_size_new], interpolation=InterpolationMode.NEAREST) 89 | 90 | if scaling < 1.0: 91 | # Pad with zeros 92 | padding = patch_size_old // 2 - patch_size_new // 2 93 | x_aug = F.pad(x_aug, (padding, padding, padding, padding)) 94 | 95 | elif scaling > 1.0: 96 | # Center crop 97 | x_aug = x_aug[ 98 | :, 99 | patch_size_new // 2 100 | - patch_size_old // 2 : patch_size_new // 2 101 | + patch_size_old // 2 102 | + 1, 103 | patch_size_new // 2 104 | - patch_size_old // 2 : patch_size_new // 2 105 | + patch_size_old // 2 106 | + 1, 107 | ] 108 | y_aug = y * scaling 109 | 110 | return x_aug, y_aug, scaling 111 | 112 | 113 | def unaugment_scale(y, scale_percentage): 114 | """ 115 | Augment a target patch by scaling it. Scale percentage is uniformly sampled from [-MAX, MAX] 116 | :param x: (C, P, P) tensor of the event representation (patch) 117 | :param y: (2,) tensor of the gt displacement 118 | :param max_scale_percentage: int, max scale change (+/-) in percentage 119 | :return: x_aug, y_aug 120 | """ 121 | scaling = 1.0 / scale_percentage 122 | y_aug = y * scaling 123 | return y_aug 124 | 125 | 126 | def augment_perspective(x, y, theta=0.1, displacements=None): 127 | """ 128 | Sample displacements for the corners 129 | x_tl, x_tr, x_bl, x_br in [0, theta*P] 130 | y_tl, y_tr, y_bl, y_br in [0, theta*P] 131 | :param x: (C, P, P) tensor of the event representation (patch) 132 | :param y: (2,) tensor of the gt displacement 133 | :param theta: parameter to adjust maximum extent of warping 134 | :param displacements: [(x_tl, x_tr, x_bl, x_br), (y_tl, y_tr, y_bl, y_br)] 135 | :return: 136 | """ 137 | _, patch_size, _ = x.shape 138 | if not isinstance(theta, type(None)): 139 | max_delta = int(round(theta * patch_size)) 140 | x_tl = random.randint(0, max_delta) 141 | x_tr = random.randint(0, max_delta) 142 | x_bl = random.randint(0, max_delta) 143 | x_br = random.randint(0, max_delta) 144 | y_tl = random.randint(0, max_delta) 145 | y_tr = random.randint(0, max_delta) 146 | y_bl = random.randint(0, max_delta) 147 | y_br = random.randint(0, max_delta) 148 | 149 | else: 150 | x_tl, x_tr, x_bl, x_br = displacements[0] 151 | y_tl, y_tr, y_bl, y_br = displacements[1] 152 | 153 | start_points = [ 154 | [0, 0], 155 | [patch_size - 1, 0], 156 | [patch_size - 1, patch_size - 1], 157 | [0, patch_size - 1], 158 | ] 159 | end_points = [ 160 | [x_tl, y_tl], 161 | [patch_size - 1 - x_tr, y_tr], 162 | [patch_size - 1 - x_br, patch_size - 1 - y_br], 163 | [x_bl, patch_size - 1 - y_bl], 164 | ] 165 | x_aug = perspective( 166 | x, start_points, end_points, interpolation=InterpolationMode.NEAREST 167 | ) 168 | 169 | coeffs = _get_perspective_coeffs(start_points, end_points) 170 | # (x, y) -> ( (ax + by + c) / (gx 
+ hy + 1), (dx + ey + f) / (gx + hy + 1) ) 171 | scale = coeffs[6] * y[0] + coeffs[7] * y[1] + 1 172 | y_aug = y.clone() 173 | y_aug[0] = (coeffs[0] * y[0] + coeffs[1] * y[1] + coeffs[2]) / scale 174 | y_aug[1] = (coeffs[3] * y[0] + coeffs[4] * y[1] + coeffs[5]) / scale 175 | 176 | return x_aug, y_aug, (scale.item(), coeffs) 177 | 178 | 179 | def unaugment_perspective(y, scale, coeffs): 180 | H = np.array( 181 | [ 182 | [coeffs[0], coeffs[1], coeffs[2]], 183 | [coeffs[3], coeffs[4], coeffs[5]], 184 | [coeffs[6], coeffs[7], 1], 185 | ] 186 | ) 187 | H_inv = np.linalg.inv(H) 188 | 189 | y = np.array([y[0], y[1], 1]).reshape((3, 1)) * scale 190 | y_aug = H_inv @ y 191 | return torch.from_numpy( 192 | np.array([y_aug[0], y_aug[1]], dtype=np.float32).reshape((2,)) 193 | ) 194 | 195 | 196 | def augment_track(track_data, flipped_lr, flipped_ud, rotation_angle, image_size): 197 | """ 198 | Augment tracks by flipped LR, UP, then rotating 199 | :param track_data: Nx2 array of feature locations over time with time increasing in row dimension 200 | :param flipped_lr: bool 201 | :param flipped_ud: bool 202 | :param rotation_angle: numeric 203 | :param image_size: (W, H) 204 | :return: augmented_track_data: Nx2 array of augmented feature locs 205 | """ 206 | image_center = ((image_size[0] - 1.0) / 2.0, (image_size[1] - 1.0) / 2.0) 207 | 208 | # Offset the track data wrt center of image 209 | track_data_aug = np.copy(track_data) 210 | track_data_aug[:, 0] -= image_center[0] 211 | track_data_aug[:, 1] -= image_center[1] 212 | 213 | # Apply augs 214 | if flipped_lr: 215 | track_data_aug[:, 0] *= -1 216 | if flipped_ud: 217 | track_data_aug[:, 1] *= -1 218 | if rotation_angle > 0: 219 | pass 220 | 221 | # Restore coordinate frame 222 | track_data_aug[:, 0] += image_center[0] 223 | track_data_aug[:, 1] += image_center[1] 224 | return track_data_aug 225 | 226 | 227 | def augment_input(input, flipped_lr, flipped_ud, rotation_angle): 228 | """ 229 | :param input: array-like of shape (H, W), or (H, W, C) 230 | :param flipped_lr: 231 | :param flipped_ud: 232 | :param rotation_angle: 233 | :return: 234 | """ 235 | 236 | if flipped_lr: 237 | input = np.fliplr(input) 238 | if flipped_ud: 239 | input = np.flipud(input) 240 | if rotation_angle > 0: 241 | pass 242 | 243 | return input 244 | -------------------------------------------------------------------------------- /utils/callbacks.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning import Callback 2 | 3 | 4 | class IncreaseSequenceLengthCallback(Callback): 5 | def __init__(self, unroll_factor=4, schedule=[40000, 20000, 10000]): 6 | self.unroll_factor = unroll_factor 7 | self.schedule = schedule 8 | self.idx_schedule = 0 9 | 10 | def on_train_batch_end(self, *args): 11 | if ( 12 | self.idx_schedule < len(self.schedule) 13 | and args[0].global_step > self.schedule[self.idx_schedule] 14 | ): 15 | args[1].unrolls = min( 16 | args[1].max_unrolls, self.unroll_factor * args[1].unrolls 17 | ) 18 | self.idx_schedule += 1 19 | print( 20 | f"Increasing unrolls: {self.idx_schedule}, {self.schedule[self.idx_schedule]}, {args[0].global_step}" 21 | ) 22 | -------------------------------------------------------------------------------- /utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class L1Truncated(nn.Module): 6 | """ 7 | L1 Loss, but zero if label is outside the patch 8 | """ 9 | 10 | def __init__(self, 
patch_size=31): 11 | super(L1Truncated, self).__init__() 12 | self.patch_size = patch_size 13 | self.L1 = nn.L1Loss(reduction="none") 14 | 15 | def forward(self, y, y_hat): 16 | self.mask = ( 17 | (torch.abs(y) <= self.patch_size / 2.0) 18 | .all(dim=1) 19 | .float() 20 | .detach() 21 | .requires_grad_(True) 22 | ) 23 | loss = self.L1(y, y_hat).sum(1) 24 | loss *= self.mask 25 | return loss, self.mask 26 | 27 | 28 | class ReprojectionError: 29 | def __init__(self, threshold=15): 30 | self.threshold = threshold 31 | 32 | def forward(self, projection_matrices, u_centers_hat, training=True): 33 | """ 34 | :param projection_matrices: (B, T, 3, 4) 35 | :param u_centers_hat: (B, T, 2) 36 | :return: (N, T) re-projection errors, (N, T) masks 37 | """ 38 | e_reproj, masks, u_centers_reproj = [], [], [] 39 | 40 | for idx_track in range(u_centers_hat.size(0)): 41 | A_rows = [] 42 | 43 | # Triangulate 44 | for idx_obs in range(u_centers_hat.size(1)): 45 | A_rows.append( 46 | u_centers_hat[idx_track, idx_obs, 0] 47 | * projection_matrices[idx_track, idx_obs, 2:3, :] 48 | - projection_matrices[idx_track, idx_obs, 0:1, :] 49 | ) 50 | A_rows.append( 51 | u_centers_hat[idx_track, idx_obs, 1] 52 | * projection_matrices[idx_track, idx_obs, 2:3, :] 53 | - projection_matrices[idx_track, idx_obs, 1:2, :] 54 | ) 55 | A = torch.cat(A_rows, dim=0) 56 | _, s, vh = torch.linalg.svd(A) 57 | X_init = vh[-1, :].view(4, 1) 58 | X_init = X_init / X_init[3, 0] 59 | 60 | # Re-project 61 | ( 62 | e_reproj_track, 63 | mask_track, 64 | x_proj_track, 65 | ) = ( 66 | [], 67 | [], 68 | [], 69 | ) 70 | for idx_obs in range(u_centers_hat.size(1)): 71 | x_proj = torch.matmul( 72 | projection_matrices[idx_track, idx_obs, :, :], X_init 73 | ) 74 | x_proj = x_proj / x_proj[2, 0] 75 | x_proj_track.append(x_proj[:2, :].detach().view(1, 1, 2)) 76 | err = torch.linalg.norm( 77 | x_proj[:2, 0].view(1, 2).detach() 78 | - u_centers_hat[idx_track, idx_obs, :].view(1, 2), 79 | dim=1, 80 | ) 81 | e_reproj_track.append(err.view(1, 1)) 82 | mask_track.append((err < self.threshold).view(1, 1)) 83 | e_reproj.append(torch.cat(e_reproj_track, dim=1)) 84 | u_centers_reproj.append(torch.cat(x_proj_track, dim=1)) 85 | 86 | mask_track = torch.cat(mask_track, dim=1) 87 | # if X_init[2, 0] < 0 or s[-1] > 20: 88 | # if s[-1] > 20: 89 | # mask_track = torch.zeros_like(mask_track) 90 | masks.append(mask_track) 91 | 92 | e_reproj = torch.cat(e_reproj, dim=0) 93 | masks = torch.cat(masks, dim=0).detach() 94 | 95 | e_reproj *= masks 96 | 97 | if training: 98 | return e_reproj, masks 99 | else: 100 | u_centers_reproj = torch.cat(u_centers_reproj, dim=0) 101 | return e_reproj, masks, u_centers_reproj 102 | 103 | 104 | class L2Distance(nn.Module): 105 | def __init__(self): 106 | super(L2Distance, self).__init__() 107 | 108 | def forward(self, y, y_hat): 109 | diff = y - y_hat 110 | diff = diff**2 111 | return torch.sqrt(torch.sum(diff, dim=list(range(1, len(y.size()))))) 112 | -------------------------------------------------------------------------------- /utils/representations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from enum import Enum, auto 4 | 5 | import hdf5plugin 6 | import h5py 7 | 8 | 9 | class EventRepresentationTypes(Enum): 10 | time_surface = 0 11 | voxel_grid = 1 12 | event_stack = 2 13 | 14 | 15 | class EventRepresentation: 16 | def __init__(self): 17 | pass 18 | 19 | def convert(self, events): 20 | raise NotImplementedError 21 | 22 | 23 | class 
TimeSurface(EventRepresentation): 24 | def __init__(self, input_size: tuple): 25 | assert len(input_size) == 3 26 | self.input_size = input_size 27 | self.time_surface = torch.zeros(input_size, dtype=torch.float, requires_grad=False) 28 | self.n_bins = input_size[0] // 2 29 | 30 | def convert(self, events): 31 | _, H, W = self.time_surface.shape 32 | with torch.no_grad(): 33 | self.time_surface = torch.zeros(self.input_size, dtype=torch.float, requires_grad=False, 34 | device=events['p'].device) 35 | time_surface = self.time_surface.clone() 36 | 37 | t = events['t'].cpu().numpy() 38 | dt_bin = 1. / self.n_bins 39 | x0 = events['x'].int() 40 | y0 = events['y'].int() 41 | p0 = events['p'].int() 42 | t0 = events['t'] 43 | 44 | # iterate over bins 45 | for i_bin in range(self.n_bins): 46 | t0_bin = i_bin * dt_bin 47 | t1_bin = t0_bin + dt_bin 48 | 49 | # mask_t = np.logical_and(time > t0_bin, time <= t1_bin) 50 | # x_bin, y_bin, p_bin, t_bin = x[mask_t], y[mask_t], p[mask_t], time[mask_t] 51 | idx0 = np.searchsorted(t, t0_bin, side='left') 52 | idx1 = np.searchsorted(t, t1_bin, side='right') 53 | x_bin = x0[idx0:idx1] 54 | y_bin = y0[idx0:idx1] 55 | p_bin = p0[idx0:idx1] 56 | t_bin = t0[idx0:idx1] 57 | 58 | n_events = len(x_bin) 59 | for i in range(n_events): 60 | if 0 <= x_bin[i] < W and 0 <= y_bin[i] < H: 61 | time_surface[2*i_bin+p_bin[i], y_bin[i], x_bin[i]] = t_bin[i] 62 | 63 | return time_surface 64 | 65 | 66 | class VoxelGrid(EventRepresentation): 67 | def __init__(self, input_size: tuple, normalize: bool): 68 | assert len(input_size) == 3 69 | self.voxel_grid = torch.zeros((input_size), dtype=torch.float, requires_grad=False) 70 | self.input_size = input_size 71 | self.nb_channels = input_size[0] 72 | self.normalize = normalize 73 | 74 | def convert(self, events): 75 | C, H, W = self.voxel_grid.shape 76 | with torch.no_grad(): 77 | self.voxel_grid = torch.zeros((self.input_size), dtype=torch.float, requires_grad=False, 78 | device=events['p'].device) 79 | voxel_grid = self.voxel_grid.clone() 80 | 81 | t_norm = events['t'] 82 | t_norm = (C - 1) * (t_norm-t_norm[0]) / (t_norm[-1]-t_norm[0]) 83 | 84 | x0 = events['x'].int() 85 | y0 = events['y'].int() 86 | t0 = t_norm.int() 87 | 88 | value = 2*events['p']-1 89 | 90 | for xlim in [x0,x0+1]: 91 | for ylim in [y0,y0+1]: 92 | for tlim in [t0,t0+1]: 93 | 94 | mask = (xlim < W) & (xlim >= 0) & (ylim < H) & (ylim >= 0) & (tlim >= 0) & (tlim < self.nb_channels) 95 | interp_weights = value * (1 - (xlim-events['x']).abs()) * (1 - (ylim-events['y']).abs()) * (1 - (tlim - t_norm).abs()) 96 | 97 | index = H * W * tlim.long() + \ 98 | W * ylim.long() + \ 99 | xlim.long() 100 | 101 | voxel_grid.put_(index[mask], interp_weights[mask], accumulate=True) 102 | 103 | if self.normalize: 104 | mask = torch.nonzero(voxel_grid, as_tuple=True) 105 | if mask[0].size()[0] > 0: 106 | mean = voxel_grid[mask].mean() 107 | std = voxel_grid[mask].std() 108 | if std > 0: 109 | voxel_grid[mask] = (voxel_grid[mask] - mean) / std 110 | else: 111 | voxel_grid[mask] = voxel_grid[mask] - mean 112 | 113 | return voxel_grid 114 | 115 | 116 | class EventStack(EventRepresentation): 117 | def __init__(self, input_size: tuple): 118 | """ 119 | :param input_size: (C, H, W) 120 | """ 121 | assert len(input_size) == 3 122 | self.input_size = input_size 123 | self.event_stack = torch.zeros((input_size), dtype=torch.float, requires_grad=False) 124 | self.nb_channels = input_size[0] 125 | 126 | def convert(self, events): 127 | C, H, W = self.event_stack.shape 128 | with torch.no_grad(): 
129 | self.event_stack = torch.zeros((self.input_size), dtype=torch.float, requires_grad=False, 130 | device=events['p'].device) 131 | event_stack = self.event_stack.clone() 132 | 133 | t = events['t'].cpu().numpy() 134 | dt_bin = 1. / self.nb_channels 135 | x0 = events['x'].int() 136 | y0 = events['y'].int() 137 | p0 = 2*events['p'].int()-1 138 | t0 = events['t'] 139 | 140 | # iterate over bins 141 | for i_bin in range(self.nb_channels): 142 | t0_bin = i_bin * dt_bin 143 | t1_bin = t0_bin + dt_bin 144 | 145 | # mask_t = np.logical_and(time > t0_bin, time <= t1_bin) 146 | # x_bin, y_bin, p_bin, t_bin = x[mask_t], y[mask_t], p[mask_t], time[mask_t] 147 | idx0 = np.searchsorted(t, t0_bin, side='left') 148 | idx1 = np.searchsorted(t, t1_bin, side='right') 149 | x_bin = x0[idx0:idx1] 150 | y_bin = y0[idx0:idx1] 151 | p_bin = p0[idx0:idx1] 152 | 153 | n_events = len(x_bin) 154 | for i in range(n_events): 155 | if 0 <= x_bin[i] < W and 0 <= y_bin[i] < H: 156 | event_stack[i_bin, y_bin[i], x_bin[i]] += p_bin[i] 157 | 158 | return event_stack 159 | 160 | 161 | def events_to_time_surface(time_surface, p, t, x, y): 162 | t = (t - t[0]).astype('float32') 163 | t = (t/t[-1]) 164 | x = x.astype('float32') 165 | y = y.astype('float32') 166 | pol = p.astype('float32') 167 | event_data_torch = { 168 | 'p': torch.from_numpy(pol), 169 | 't': torch.from_numpy(t), 170 | 'x': torch.from_numpy(x), 171 | 'y': torch.from_numpy(y), 172 | } 173 | return time_surface.convert(event_data_torch) 174 | 175 | def events_to_event_stack(event_stack, p, t, x, y): 176 | t = (t - t[0]).astype('float32') 177 | t = (t/t[-1]) 178 | x = x.astype('float32') 179 | y = y.astype('float32') 180 | pol = p.astype('float32') 181 | event_data_torch = { 182 | 'p': torch.from_numpy(pol), 183 | 't': torch.from_numpy(t), 184 | 'x': torch.from_numpy(x), 185 | 'y': torch.from_numpy(y), 186 | } 187 | return event_stack.convert(event_data_torch) 188 | 189 | 190 | def events_to_voxel_grid(voxel_grid, p, t, x, y): 191 | t = (t - t[0]).astype('float32') 192 | t = (t/t[-1]) 193 | x = x.astype('float32') 194 | y = y.astype('float32') 195 | pol = p.astype('float32') 196 | event_data_torch = { 197 | 'p': torch.from_numpy(pol), 198 | 't': torch.from_numpy(t), 199 | 'x': torch.from_numpy(x), 200 | 'y': torch.from_numpy(y), 201 | } 202 | return voxel_grid.convert(event_data_torch) 203 | -------------------------------------------------------------------------------- /utils/timers.py: -------------------------------------------------------------------------------- 1 | import atexit 2 | import time 3 | 4 | import numpy as np 5 | import torch 6 | 7 | cuda_timers = {} 8 | timers = {} 9 | 10 | 11 | class CudaTimer: 12 | def __init__(self, device: torch.device, timer_name: str = ""): 13 | self.timer_name = timer_name 14 | if self.timer_name not in cuda_timers: 15 | cuda_timers[self.timer_name] = [] 16 | 17 | self.device = device 18 | self.start = None 19 | self.end = None 20 | 21 | def __enter__(self): 22 | torch.cuda.synchronize(device=self.device) 23 | self.start = time.time() 24 | return self 25 | 26 | def __exit__(self, *args): 27 | assert self.start is not None 28 | torch.cuda.synchronize(device=self.device) 29 | end = time.time() 30 | cuda_timers[self.timer_name].append(end - self.start) 31 | 32 | 33 | class CudaTimerWithEvents: 34 | def __init__(self, device: torch.device, timer_name: str = ""): 35 | self.timer_name = timer_name 36 | if self.timer_name not in cuda_timers: 37 | cuda_timers[self.timer_name] = [] 38 | 39 | self.device = device 40 | 
self.start = torch.cuda.Event(enable_timing=True) 41 | self.end = torch.cuda.Event(enable_timing=True) 42 | 43 | def __enter__(self): 44 | self.start.record(stream=torch.cuda.current_stream(self.device)) 45 | return self 46 | 47 | def __exit__(self, *args): 48 | self.end.record(stream=torch.cuda.current_stream(self.device)) 49 | self.end.synchronize() 50 | # For some reason calling "elapsed_time" uses memory on GPU 0. 51 | cuda_timers[self.timer_name].append(self.start.elapsed_time(self.end) / 1000) 52 | 53 | 54 | class TimerDummy: 55 | def __init__(self, *args, **kwargs): 56 | pass 57 | 58 | def __enter__(self): 59 | pass 60 | 61 | def __exit__(self, *args): 62 | pass 63 | 64 | 65 | class Timer: 66 | def __init__(self, timer_name=""): 67 | self.timer_name = timer_name 68 | if self.timer_name not in timers: 69 | timers[self.timer_name] = [] 70 | 71 | def __enter__(self): 72 | self.start = time.time() 73 | return self 74 | 75 | def __exit__(self, *args): 76 | end = time.time() 77 | time_diff_s = end - self.start # measured in seconds 78 | timers[self.timer_name].append(time_diff_s) 79 | 80 | 81 | def print_timing_info(): 82 | print("== Timing statistics ==") 83 | skip_warmup = 2 84 | for timer_name, timing_values in [*cuda_timers.items(), *timers.items()]: 85 | if len(timing_values) <= skip_warmup: 86 | continue 87 | values = timing_values[skip_warmup:] 88 | timing_value_s = np.mean(np.array(values)) 89 | timing_value_ms = timing_value_s * 1000 90 | if timing_value_ms > 1000: 91 | print("{}: {:.2f} s".format(timer_name, timing_value_s)) 92 | else: 93 | print("{}: {:.2f} ms".format(timer_name, timing_value_ms)) 94 | 95 | 96 | # this will print all the timer values upon termination of any program that imported this file 97 | atexit.register(print_timing_info) 98 | -------------------------------------------------------------------------------- /utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def array_to_tensor(array): 6 | # Get patch inputs 7 | array = np.array(array) 8 | if len(array.shape) == 2: 9 | array = np.expand_dims(array, 0) 10 | array = np.transpose(array, (2, 0, 1)) 11 | return torch.from_numpy(array) 12 | 13 | 14 | def get_patch(time_surface, u_center, patch_size): 15 | center = np.rint(u_center).astype(int) 16 | h, w = time_surface.shape 17 | c = 1 18 | 19 | # Check out-of-bounds 20 | if not ((0 <= center[0] < w) and (0 <= center[1] < h)): 21 | return torch.zeros((c, patch_size, patch_size), dtype=torch.float32) 22 | 23 | r_lims, c_lims, pad_lrud = compute_padding(center, patch_size, (w, h)) 24 | 25 | x = np.array(time_surface[r_lims[0] : r_lims[1], c_lims[0] : c_lims[1]]).astype( 26 | np.float32 27 | ) 28 | x = np.expand_dims(x, axis=0) 29 | x = torch.from_numpy(x) 30 | x = torch.nn.functional.pad(x, pad_lrud) 31 | return x 32 | 33 | 34 | def compute_padding(center, patch_size, resolution): 35 | """ 36 | Return patch crop area and required padding 37 | :param center: Integer center coordinates of desired patch crop 38 | :param resolution: Image res (w, h) 39 | :return: 40 | """ 41 | w, h = resolution 42 | 43 | # Crop around the patch 44 | r_min = int(max(0, center[1] - patch_size // 2)) 45 | r_max = int(min(h - 1, center[1] + patch_size // 2 + 1)) 46 | c_min = int(max(0, center[0] - patch_size // 2)) 47 | c_max = int(min(w - 1, center[0] + patch_size // 2 + 1)) 48 | 49 | # Determine padding 50 | pad_l, pad_r, pad_u, pad_d = 0, 0, 0, 0 51 | if center[1] - patch_size // 2 < 
0: 52 | pad_u = abs(center[1] - patch_size // 2) 53 | if center[1] + patch_size // 2 + 1 > h - 1: 54 | pad_d = center[1] + patch_size // 2 + 1 - (h - 1) 55 | if center[0] - patch_size // 2 < 0: 56 | pad_l = abs(center[0] - patch_size // 2) 57 | if center[0] + patch_size // 2 + 1 > w - 1: 58 | pad_r = center[0] + patch_size // 2 + 1 - (w - 1) 59 | 60 | return ( 61 | (r_min, r_max), 62 | (c_min, c_max), 63 | (int(pad_l), int(pad_r), int(pad_u), int(pad_d)), 64 | ) 65 | 66 | 67 | def get_patch_tensor(input_tensor, center, patch_size): 68 | """ 69 | 70 | :param input_tensor: (1, c, h, w) 71 | :param u_center: 72 | :param patch_size: 73 | :return: 74 | """ 75 | # center = np.rint(u_center).astype(int) 76 | _, c, h, w = input_tensor.shape 77 | 78 | # Check out-of-bounds 79 | if not ((0 <= center[0] < w) and (0 <= center[1] < h)): 80 | return torch.zeros( 81 | (1, c, patch_size, patch_size), 82 | dtype=torch.float32, 83 | device=input_tensor.device, 84 | ) 85 | 86 | r_lims, c_lims, pad_lrud = compute_padding(center, patch_size, (w, h)) 87 | 88 | x = input_tensor[0:1, :, r_lims[0] : r_lims[1], c_lims[0] : c_lims[1]] 89 | x = torch.nn.functional.pad(x, pad_lrud) 90 | return x 91 | 92 | 93 | def get_patch_voxel(voxel_grid, u_center, patch_size): 94 | center = np.rint(u_center).astype(int).reshape((2,)) 95 | if len(voxel_grid.shape) == 2: 96 | c = 1 97 | h, w = voxel_grid.shape 98 | else: 99 | h, w, c = voxel_grid.shape 100 | 101 | # Check out-of-bounds 102 | if not ((0 <= center[0] < w) and (0 <= center[1] < h)): 103 | return torch.zeros((c, patch_size, patch_size), dtype=torch.float32) 104 | 105 | r_lims, c_lims, pad_lrud = compute_padding(center, patch_size, (w, h)) 106 | 107 | if len(voxel_grid.shape) == 2: 108 | x = np.array(voxel_grid[r_lims[0] : r_lims[1], c_lims[0] : c_lims[1]]).astype( 109 | np.float32 110 | ) 111 | x = np.expand_dims(x, axis=2) 112 | else: 113 | x = np.array( 114 | voxel_grid[r_lims[0] : r_lims[1], c_lims[0] : c_lims[1], :] 115 | ).astype(np.float32) 116 | x = np.transpose(x, (2, 0, 1)) 117 | x = torch.from_numpy(x) 118 | x = torch.nn.functional.pad(x, pad_lrud) 119 | return x 120 | 121 | 122 | def get_patch_voxel2(voxel_grid, u_center, patch_size, padding=10): 123 | """ 124 | get_patch_voxel but using extract glimpse (no roundidng of center coords) 125 | :param voxel_grid: 126 | :param u_center: 127 | :param patch_size: 128 | :return: (C, P, P) 129 | """ 130 | # Extract expanded patches from the h5 file 131 | u_center = u_center.reshape((2,)) 132 | u_center_rounded = np.rint(u_center) 133 | 134 | u_center_offset = u_center - u_center_rounded + ((patch_size + padding) // 2.0) 135 | x_patch_expanded = get_patch_voxel( 136 | voxel_grid, u_center_rounded.reshape((-1,)), patch_size + padding 137 | ).unsqueeze(0) 138 | return extract_glimpse( 139 | x_patch_expanded, 140 | (patch_size, patch_size), 141 | torch.from_numpy(u_center_offset.astype(np.float32)).view((1, 2)) + 0.5, 142 | mode="bilinear", 143 | ).squeeze(0) 144 | 145 | 146 | def get_patch_voxel_pairs(voxel_grid_0, voxel_grid_1, u_center, patch_size): 147 | center = np.rint(u_center).astype(int) 148 | if len(voxel_grid_0.shape) == 2: 149 | c = 1 150 | h, w = voxel_grid_0.shape 151 | else: 152 | h, w, c = voxel_grid_0.shape 153 | 154 | # Check out-of-bounds 155 | if not ((0 <= center[0] < w) and (0 <= center[1] < h)): 156 | return torch.zeros((c * 2, patch_size, patch_size), dtype=torch.float32) 157 | 158 | r_lims, c_lims, pad_lrud = compute_padding(center, patch_size, (w, h)) 159 | 160 | if len(voxel_grid_0.shape) == 2: 
161 | x0 = np.array( 162 | voxel_grid_0[r_lims[0] : r_lims[1], c_lims[0] : c_lims[1]] 163 | ).astype(np.float32) 164 | x1 = np.array( 165 | voxel_grid_1[r_lims[0] : r_lims[1], c_lims[0] : c_lims[1]] 166 | ).astype(np.float32) 167 | x0 = np.expand_dims(x0, axis=2) 168 | x1 = np.expand_dims(x1, axis=2) 169 | else: 170 | x0 = np.array( 171 | voxel_grid_0[r_lims[0] : r_lims[1], c_lims[0] : c_lims[1], :] 172 | ).astype(np.float32) 173 | x1 = np.array( 174 | voxel_grid_1[r_lims[0] : r_lims[1], c_lims[0] : c_lims[1], :] 175 | ).astype(np.float32) 176 | x = np.concatenate([x0, x1], axis=2) 177 | x = np.transpose(x, (2, 0, 1)) 178 | x = torch.from_numpy(x) 179 | x = torch.nn.functional.pad(x, pad_lrud) 180 | return x 181 | 182 | 183 | def get_patch_pairs(time_surface_0, time_surface_1, u_center, patch_size): 184 | center = np.rint(u_center).astype(int) 185 | h, w = time_surface_0.shape 186 | c = 1 187 | 188 | # Check out-of-bounds 189 | if not ((0 <= center[0] < w) and (0 <= center[1] < h)): 190 | return torch.zeros((c * 2, patch_size, patch_size), dtype=torch.float32) 191 | 192 | r_lims, c_lims, pad_lrud = compute_padding(center, patch_size, (w, h)) 193 | 194 | x0 = np.array(time_surface_0[r_lims[0] : r_lims[1], c_lims[0] : c_lims[1]]).astype( 195 | np.float32 196 | ) 197 | x1 = np.array(time_surface_1[r_lims[0] : r_lims[1], c_lims[0] : c_lims[1]]).astype( 198 | np.float32 199 | ) 200 | x0 = np.expand_dims(x0, axis=0) 201 | x1 = np.expand_dims(x1, axis=0) 202 | x = np.concatenate([x0, x1], axis=0) 203 | x = torch.from_numpy(x) 204 | x = torch.nn.functional.pad(x, pad_lrud) 205 | return x 206 | 207 | 208 | def extract_glimpse( 209 | input, 210 | size, 211 | offsets, 212 | centered=False, 213 | normalized=False, 214 | mode="nearest", 215 | padding_mode="zeros", 216 | ): 217 | """Returns a set of windows called glimpses extracted at location offsets 218 | from the input tensor. If the windows only partially overlaps the inputs, 219 | the non-overlapping areas are handled as defined by :attr:`padding_mode`. 220 | Options of :attr:`padding_mode` refers to `torch.grid_sample`'s document. 221 | The result is a 4-D tensor of shape [N, C, h, w]. The channels and batch 222 | dimensions are the same as that of the input tensor. The height and width 223 | of the output windows are specified in the size parameter. 224 | The argument normalized and centered controls how the windows are built: 225 | * If the coordinates are normalized but not centered, 0.0 and 1.0 correspond 226 | to the minimum and maximum of each height and width dimension. 227 | * If the coordinates are both normalized and centered, they range from 228 | -1.0 to 1.0. The coordinates (-1.0, -1.0) correspond to the upper left 229 | corner, the lower right corner is located at (1.0, 1.0) and the center 230 | is at (0, 0). 231 | * If the coordinates are not normalized they are interpreted as numbers 232 | of pixels. 233 | Args: 234 | input (Tensor): A Tensor of type float32. A 4-D float tensor of shape 235 | [N, C, H, W]. 236 | size (tuple): 2-element integer tuple specified the 237 | output glimpses' size. The glimpse height must be specified first, 238 | following by the glimpse width. 239 | offsets (Tensor): A Tensor of type float32. A 2-D integer tensor of 240 | shape [N, 2] containing the x, y locations of the center 241 | of each window. 242 | centered (bool, optional): An optional bool. Defaults to True. 
indicates 243 | if the offset coordinates are centered relative to the image, in 244 | which case the (0, 0) offset is relative to the center of the input 245 | images. If false, the (0,0) offset corresponds to the upper left 246 | corner of the input images. 247 | normalized (bool, optional): An optional bool. Defaults to True. indicates 248 | if the offset coordinates are normalized. 249 | mode (str, optional): Interpolation mode to calculate output values. 250 | Defaults to 'bilinear'. 251 | padding_mode (str, optional): padding mode for values outside the input. 252 | Raises: 253 | ValueError: When normalized set False but centered set True 254 | Returns: 255 | output (Tensor): A Tensor of same type with input. 256 | """ 257 | W, H = input.size(-1), input.size(-2) 258 | 259 | if normalized and centered: 260 | offsets = (offsets + 1) * offsets.new_tensor([W / 2, H / 2]) 261 | elif normalized: 262 | offsets = offsets * offsets.new_tensor([W, H]) 263 | elif centered: 264 | raise ValueError("Invalid parameter that offsets centered but not normlized") 265 | 266 | h, w = size 267 | xs = torch.arange(0, w, dtype=input.dtype, device=input.device) - (w - 1) / 2.0 268 | ys = torch.arange(0, h, dtype=input.dtype, device=input.device) - (h - 1) / 2.0 269 | 270 | # vy, vx = torch.meshgrid(ys, xs) 271 | vy, vx = torch.meshgrid(ys, xs, indexing="ij") 272 | grid = torch.stack([vx, vy], dim=-1) # h, w, 2 273 | 274 | offsets_grid = offsets[:, None, None, :] + grid[None, ...] 275 | 276 | # normalised grid to [-1, 1] 277 | offsets_grid = ( 278 | offsets_grid - offsets_grid.new_tensor([W / 2, H / 2]) 279 | ) / offsets_grid.new_tensor([W / 2, H / 2]) 280 | 281 | return torch.nn.functional.grid_sample( 282 | input, offsets_grid, mode=mode, align_corners=True, padding_mode=padding_mode 283 | ) 284 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from multiprocessing import Pool 3 | 4 | import h5py 5 | import hdf5plugin 6 | import numpy as np 7 | from cv2 import IMREAD_GRAYSCALE, imread 8 | from omegaconf import OmegaConf, open_dict 9 | from tqdm import tqdm 10 | 11 | 12 | def blosc_opts(complevel=1, complib="blosc:zstd", shuffle="byte"): 13 | # Inspired by: https://github.com/h5py/h5py/issues/611#issuecomment-353694301 14 | # More info on options: https://github.com/Blosc/c-blosc/blob/7435f28dd08606bd51ab42b49b0e654547becac4/blosc/blosc.h#L55-L79 15 | shuffle = 2 if shuffle == "bit" else 1 if shuffle == "byte" else 0 16 | compressors = ["blosclz", "lz4", "lz4hc", "snappy", "zlib", "zstd"] 17 | complib = ["blosc:" + c for c in compressors].index(complib) 18 | args = { 19 | "compression": 32001, 20 | "compression_opts": (0, 0, 0, 0, complevel, shuffle, complib), 21 | } 22 | if shuffle > 0: 23 | # Do not use h5py shuffle if blosc shuffle is enabled. 24 | args["shuffle"] = False 25 | return args 26 | 27 | 28 | def query_events(events_h5, events_t, t0, t1): 29 | """ 30 | Return a numpy array of events in temporal range [t0, t1) 31 | :param events_h5: h5 object with events. {x, y, p, t} as keys. 
32 | :param events_t: np array of the uncompressed event times 33 | :param t0: start time of slice in us 34 | :param t1: terminal time of slice in us 35 | :return: (-1, 4) np array 36 | """ 37 | first_idx = np.searchsorted(events_t, t0, side="left") 38 | last_idx_p1 = np.searchsorted(events_t, t1, side="right") 39 | x = np.asarray(events_h5["x"][first_idx:last_idx_p1]) 40 | y = np.asarray(events_h5["y"][first_idx:last_idx_p1]) 41 | p = np.asarray(events_h5["p"][first_idx:last_idx_p1]) 42 | t = np.asarray(events_h5["t"][first_idx:last_idx_p1]) 43 | return {"x": x, "y": y, "p": p, "t": t, "n_events": len(x)} 44 | 45 | 46 | def events2time_surface(events_h5, events_t, t0, t1, resolution): 47 | """ 48 | Build a timesurface from events in temporal range [t0, t1) 49 | :param events_h5: h5 object with events. {x, y, p, t} as keys. 50 | :param events_t: np array of the uncompressed event times 51 | :param t0: start time of slice in us 52 | :param t1: terminal time of slice in us 53 | :param resolution: 2-element tuple (W, H) 54 | :return: (H, W) np array 55 | """ 56 | time_surface = np.zeros((resolution[1], resolution[0]), dtype=np.float64) 57 | events_dict = query_events(events_h5, events_t, t0, t1) 58 | 59 | for i in range(events_dict["n_events"]): 60 | x = int(np.rint(events_dict["x"][i])) 61 | y = int(np.rint(events_dict["y"][i])) 62 | 63 | if 0 <= x < resolution[0] and 0 <= y < resolution[1]: 64 | time_surface[y, x] = (events_dict["t"][i] - t0) / (t1 - t0) 65 | 66 | return time_surface 67 | 68 | 69 | def read_input(input_path, representation): 70 | input_path = str(input_path) 71 | 72 | assert os.path.exists(input_path), f"Path to input file {input_path} doesn't exist." 73 | 74 | if "time_surface" in representation: 75 | return h5py.File(input_path, "r")["time_surface"] 76 | 77 | elif "voxel" in representation: 78 | return h5py.File(input_path, "r")["voxel_grid"] 79 | 80 | elif "event_stack" in representation: 81 | return h5py.File(input_path, "r")["event_stack"] 82 | 83 | elif "grayscale" in representation: 84 | return imread(input_path, IMREAD_GRAYSCALE).astype(np.float32) / 255.0 85 | 86 | else: 87 | print("Unsupported representation") 88 | exit() 89 | 90 | 91 | def propagate_keys(cfg, testing=False): 92 | OmegaConf.set_struct(cfg, True) 93 | with open_dict(cfg): 94 | cfg.data.representation = cfg.representation 95 | cfg.data.track_name = cfg.track_name 96 | cfg.data.patch_size = cfg.patch_size 97 | 98 | cfg.model.representation = cfg.representation 99 | cfg.data.patch_size = cfg.patch_size 100 | 101 | if not testing: 102 | cfg.model.n_vis = cfg.n_vis 103 | cfg.model.init_unrolls = cfg.init_unrolls 104 | cfg.model.max_unrolls = cfg.max_unrolls 105 | cfg.model.debug = cfg.debug 106 | 107 | cfg.model.pose_mode = cfg.data.name == "pose" 108 | 109 | 110 | def skew(x): 111 | return np.array([[0, -x[2], x[1]], [x[2], 0, -x[0]], [-x[1], x[0], 0]]) 112 | -------------------------------------------------------------------------------- /utils/visualization.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | random.seed(1234) 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | 8 | 9 | def generate_track_colors(n_tracks): 10 | track_colors = [] 11 | for i_track in range(n_tracks): 12 | track_colors.append( 13 | ( 14 | random.randint(0, 255) / 255.0, 15 | random.randint(0, 255) / 255.0, 16 | random.randint(0, 255) / 255.0, 17 | ) 18 | ) 19 | return track_colors 20 | 21 | 22 | def fig_to_img(fig): 23 | fig.canvas.draw() 24 | data 
= np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)  # np.fromstring is deprecated for binary buffers 25 | w, h = fig.canvas.get_width_height() 26 | return data.reshape((h, w, 3)) 27 | 28 | 29 | def render_tracks( 30 | pred_track_interpolator, 31 | gt_track_interpolator, 32 | t, 33 | img, 34 | dt_track=0.2, 35 | error_threshold=15, 36 | track_counter=False, 37 | ): 38 | """ 39 | Plot pred and gt tracks on an image with past tracks in the time window [t-dt_track, t]. 40 | Predicted tracks that exceed the error threshold are not drawn. 41 | :param pred_track_interpolator: 42 | :param gt_track_interpolator: 43 | :param t: 44 | :param img: 45 | :param dt_track: 46 | :param error_threshold: 47 | :return: 48 | """ 49 | h, w = img.shape[:2] 50 | # Create figure 51 | fig = plt.figure() 52 | ax = fig.add_subplot() 53 | if img.ndim == 3: 54 | ax.imshow(img) 55 | else: 56 | ax.imshow(img, cmap="gray") 57 | ax.autoscale(False) 58 | 59 | active_pred_tracks = 0 60 | active_gt_tracks = 0 61 | 62 | # Draw each track 63 | for track_id in range(gt_track_interpolator.n_corners): 64 | gt_track_data_curr = gt_track_interpolator.interpolate(track_id, t) 65 | pred_track_data_curr = pred_track_interpolator.interpolate(track_id, t) 66 | 67 | # Check time 68 | if gt_track_data_curr is not None: 69 | out_of_frame = ( 70 | (gt_track_data_curr[1] >= h) 71 | or (gt_track_data_curr[0] >= w) 72 | or (gt_track_data_curr[1] < 0) 73 | or (gt_track_data_curr[0] < 0) 74 | ) 75 | if isinstance(gt_track_data_curr, type(None)) or out_of_frame: 76 | # print(f"No gt tracks at queried time for track idx {track_id}.") 77 | continue 78 | 79 | else: 80 | active_gt_tracks += 1 81 | 82 | # Draw tracks at query time 83 | ax.scatter( 84 | gt_track_data_curr[0], 85 | gt_track_data_curr[1], 86 | # color=[0, 1, 0], alpha=1., linewidth=1, s=30, marker='o') 87 | color=[255 / 255.0, 255 / 255.0, 0], 88 | alpha=1.0, 89 | linewidth=1, 90 | s=30, 91 | marker="o", 92 | ) 93 | 94 | gt_track_data_hist = gt_track_interpolator.history(track_id, t, dt_track) 95 | ax.plot( 96 | gt_track_data_hist[:, 0], 97 | gt_track_data_hist[:, 1], 98 | # color=[0, 1, 0], alpha=0.5, linewidth=4, linestyle='solid') 99 | color=[255 / 255.0, 255 / 255.0, 0], 100 | alpha=0.5, 101 | linewidth=4, 102 | linestyle="solid", 103 | ) 104 | 105 | if ( 106 | not isinstance(pred_track_data_curr, type(None)) 107 | and np.linalg.norm(pred_track_data_curr - gt_track_data_curr) 108 | < error_threshold 109 | ): 110 | ax.scatter( 111 | pred_track_data_curr[0], 112 | pred_track_data_curr[1], 113 | # color=[0, 0, 1], alpha=1., linewidth=1, s=30, marker='o') 114 | color=[0 / 255.0, 255 / 255.0, 255 / 255.0], 115 | alpha=1.0, 116 | linewidth=1, 117 | s=30, 118 | marker="o", 119 | ) 120 | 121 | pred_track_data_hist = pred_track_interpolator.history( 122 | track_id, t, dt_track 123 | ) 124 | ax.plot( 125 | pred_track_data_hist[:, 0], 126 | pred_track_data_hist[:, 1], 127 | # color=[0, 0, 1], alpha=0.5, linewidth=4, linestyle='solid') 128 | color=[0 / 255.0, 255 / 255.0, 255 / 255.0], 129 | alpha=0.5, 130 | linewidth=4, 131 | linestyle="solid", 132 | ) 133 | 134 | active_pred_tracks += 1 135 | 136 | if track_counter: 137 | # fig = plt.figure() 138 | # ax = fig.add_subplot() 139 | # ax.imshow(img, cmap='gray') 140 | # ax.autoscale(False) 141 | # ax.text(2.5, 10, 'Active Tracks: {} / {}'.format(active_pred_tracks, active_gt_tracks), 142 | ax.text( 143 | 8, 144 | 28.5, 145 | "Active Tracks: {} / {}".format(active_pred_tracks, active_gt_tracks), 146 | fontsize=15, 147 | c="yellow", 148 | bbox=dict(facecolor="black", alpha=0.75), 149
| ) 150 | # ax.axis('off') 151 | # plt.savefig("tmp.png") 152 | 153 | ax.axis("off") 154 | fig_array = fig_to_img(fig) 155 | plt.close(fig) 156 | return fig_array 157 | 158 | 159 | def render_pred_tracks(pred_track_interpolator, t, img, track_colors, dt_track=0.0025): 160 | # Create figure 161 | fig = plt.figure() 162 | ax = fig.add_subplot() 163 | ax.imshow(img, cmap="gray") 164 | ax.autoscale(False) 165 | 166 | for track_id in range(pred_track_interpolator.n_corners): 167 | pred_track_data_curr = pred_track_interpolator.interpolate(track_id, t) 168 | 169 | if not isinstance(pred_track_data_curr, type(None)): 170 | ax.scatter( 171 | pred_track_data_curr[0], 172 | pred_track_data_curr[1], 173 | color=track_colors[track_id], 174 | alpha=0.5, 175 | linewidth=1.0, 176 | s=30, 177 | marker="o", 178 | ) 179 | 180 | pred_track_data_hist = pred_track_interpolator.history( 181 | track_id, t, dt_track 182 | ) 183 | 184 | # ToDo: Change back 185 | pred_track_data_hist = np.concatenate( 186 | [pred_track_data_hist, pred_track_data_curr[None, :]], axis=0 187 | ) 188 | 189 | ax.plot( 190 | pred_track_data_hist[:, 0], 191 | pred_track_data_hist[:, 1], 192 | color=track_colors[track_id], 193 | alpha=0.5, 194 | linewidth=4.0, 195 | linestyle="solid", 196 | ) 197 | 198 | ax.axis("off") 199 | fig.subplots_adjust(bottom=0, top=1, left=0, right=1) 200 | fig_array = fig_to_img(fig) 201 | plt.close(fig) 202 | return fig_array 203 | --------------------------------------------------------------------------------
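Usage sketch: a minimal, hedged example of how the utilities above fit together, assuming the repository root is on the Python path so the files are importable as the utils package, and assuming events arrive as timestamp-sorted numpy arrays. The sensor resolution, bin count, and synthetic event values below are illustrative placeholders, not values taken from the repository.

import numpy as np

from utils.representations import EventStack, events_to_event_stack
from utils.timers import Timer
from utils.torch_utils import get_patch_voxel

H, W = 180, 240                 # illustrative sensor resolution (height, width)
n_bins, patch_size = 5, 31      # channels of the event stack, crop size in pixels

# Synthetic events, sorted by timestamp (placeholders for real data).
rng = np.random.default_rng(0)
n = 2000
x = rng.integers(0, W, n).astype(np.float32)
y = rng.integers(0, H, n).astype(np.float32)
p = rng.integers(0, 2, n).astype(np.float32)
t = np.sort(rng.uniform(0.0, 0.01, n)).astype(np.float32)

# Build an (n_bins, H, W) event-stack representation from this event slice.
event_stack = EventStack((n_bins, H, W))
with Timer("event_stack"):
    stack = events_to_event_stack(event_stack, p, t, x, y)

# Crop a patch around a (possibly sub-pixel) feature location;
# get_patch_voxel expects an (H, W, C) array, so the tensor is permuted first.
u_center = np.array([120.3, 64.7], dtype=np.float32)  # (x, y) in pixels
patch = get_patch_voxel(stack.permute(1, 2, 0).numpy(), u_center, patch_size)
print(stack.shape, patch.shape)  # torch.Size([5, 180, 240]) torch.Size([5, 31, 31])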