├── .gitignore ├── README.md ├── assets ├── concept-show.png ├── overview.png └── smr.png ├── loader └── loader_dsec.py └── utils ├── dsec_utils.py ├── gen_dist_map.py └── setupTensor.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | .idea/ 161 | *.DS_Store 162 | .vscode 163 | /runs/** -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EVA-Flow: Towards Anytime Optical Flow Estimation with Event Cameras 2 | The official implementation code repository for [EVA-Flow: Towards Anytime Optical Flow Estimation with Event Cameras](https://arxiv.org/abs/2307.05033) 3 | 4 | 5 | 6 | 7 | ``` 8 | @misc{ye2023anytime, 9 | title={Towards Anytime Optical Flow Estimation with Event Cameras}, 10 | author={Yaozu Ye and Hao Shi and Kailun Yang and Ze Wang and Xiaoting Yin and Yaonan Wang and Kaiwei Wang}, 11 | year={2023}, 12 | eprint={2307.05033}, 13 | archivePrefix={arXiv}, 14 | primaryClass={cs.CV} 15 | } 16 | ``` 17 | 18 | ## Environment 19 | 20 | ```bash 21 | # create and activate conda environment 22 | conda create -n anyflow python=3.9 23 | conda activate anyflow 24 | 25 | # install dependencies for hdf5 26 | conda install blosc-hdf5-plugin=1.0.0 -c conda-forge 27 | conda install pytables 28 | pip install numba h5py hdf5plugin 29 | 30 | # install pytorch, torchvision, tensorboard 31 | # torch version: 1.12.1 or higher 32 | # torchvision version: 0.13.1 or higher 33 | pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 34 | pip install tensorboard 35 | 36 | # imageio depends on freeimage 37 | sudo apt install libfreeimage-dev 38 | # install dependencies for others 39 | pip install tqdm imageio opencv-python pyyaml matplotlib 40 | ``` 41 | 42 | ## Dataset 43 | 44 | ### DSEC Dataset 45 | 46 | 1.Download the DSEC dataset. Dataset Structure is as follows: 47 | 48 | ```text 49 | ├── DSEC 50 | ├── Test 51 | │   ├── test_calibration 52 | │   │   ├── interlaken_00_a 53 | │   │   ├── interlaken_00_b 54 | │   │   ├── ... 55 | │   ├── test_events 56 | │   │   ├── interlaken_00_a 57 | │   │   ├── interlaken_00_b 58 | │   │   ├── ... 59 | │   └── test_forward_optical_flow_timestamps 60 | └── Train 61 | ├── train_calibration 62 | │   ├── interlaken_00_c 63 | │   ├── interlaken_00_d 64 | │   ├── ... 65 | ├── train_events 66 | │   ├── interlaken_00_c 67 | │   ├── interlaken_00_d 68 | │   ├── ... 69 | └── train_optical_flow 70 | ├── thun_00_a 71 | ├── zurich_city_01_a 72 | ├── ... 73 | ``` 74 | 75 | 2. Generate distortion maps for DSEC dataset 76 | 77 | ```bash 78 | python ./utils/gen_dist_map.py -d 'path/to/dataset/DSEC' 79 | ``` -------------------------------------------------------------------------------- /assets/concept-show.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yaozhuwa/EVA-Flow/5534f5238c9b6ea1039919a0a032bce2f161530d/assets/concept-show.png -------------------------------------------------------------------------------- /assets/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yaozhuwa/EVA-Flow/5534f5238c9b6ea1039919a0a032bce2f161530d/assets/overview.png -------------------------------------------------------------------------------- /assets/smr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yaozhuwa/EVA-Flow/5534f5238c9b6ea1039919a0a032bce2f161530d/assets/smr.png -------------------------------------------------------------------------------- /loader/loader_dsec.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.utils.data import Dataset, DataLoader 5 | from torch.utils.data.dataset import Subset 6 | from torchvision import transforms 7 | import numpy as np 8 | import h5py 9 | import weakref 10 | import imageio 11 | import random 12 | from utils.dsec_utils import flow_16bit_to_float, EventSlicer, VoxelGrid 13 | import argparse 14 | import tqdm 15 | 16 | ''' 17 | DSEC Dataset 18 | └── Train 19 | ├── train_events 20 | │   ├── thun_00_a 21 | │   │   ├── events 22 | │   │   │   ├── left 23 | │   │   │   │   ├── events.h5 24 | │   │   │   │   └── rectify_map.h5 25 | │   │   │   │   └── rect2dist_map.npy 26 | ├── train_optical_flow 27 | │   ├── thun_00_a 28 | │   │   ├── flow 29 | │   │   │   ├── forward 30 | │   │   │   │   ├── xxxxxx.png 31 | │   │   │   ├── forward_timestamps.txt 32 | ''' 33 | 34 | class VoxelGridSequenceDSEC(Dataset): 35 | def __init__(self, events_sequence_path: Path, flow_sequence_path: Path, 36 | crop_size=None): 37 | assert events_sequence_path.is_dir() 38 | assert flow_sequence_path.is_dir() 39 | ''' 40 | Directory Structure: 41 | 42 | Dataset 43 | └── Train 44 | ├── train_events 45 | │   ├── thun_00_a (events_sequence_path) 46 | │   │   ├── events 47 | │   │   │   ├── left 48 | │   │   │   │   ├── events.h5 49 | │   │   │   │   └── rectify_map.h5 50 | ├── train_optical_flow 51 | │   ├── thun_00_a (flow_sequence_path) 52 | │   │   ├── flow 53 | │   │   │   ├── forward 54 | │   │   │   │   ├── xxxxxx.png 55 | │   │   │   ├── forward_timestamps.txt 56 | ''' 57 | forward_timestamps_file = flow_sequence_path / 'flow' / 'forward_timestamps.txt' 58 | assert forward_timestamps_file.is_file() 59 | self.forward_ts_pair = np.genfromtxt(forward_timestamps_file, delimiter=',') 60 | self.forward_ts_pair = self.forward_ts_pair.astype('int64') 61 | ''' 62 | ----- forward_timestamps.txt ----- 63 | # from_timestamp_us, to_timestamp_us 64 | 49599300523, 49599400524 65 | 49599400524, 49599500511 66 | 49599500511, 49599600529 67 | 49599600529, 49599700535 68 | 49599700535, 49599800517 69 | ... 70 | ''' 71 | self.height = 480 72 | self.width = 640 73 | self.voxel_grid = VoxelGrid((15, self.height, self.width), normalize=True) 74 | # 光流的时间间隔 75 | self.delta_t_us = 100000 76 | 77 | ev_dir = events_sequence_path / 'events' / 'left' 78 | ev_data_file = ev_dir / 'events.h5' 79 | ev_rect_file = ev_dir / 'rectify_map.h5' 80 | 81 | h5f_location = h5py.File(str(ev_data_file), 'r') 82 | self.h5f = h5f_location 83 | self.event_slicer = EventSlicer(h5f_location) 84 | with h5py.File(str(ev_rect_file), 'r') as h5_rect: 85 | self.rectify_ev_map = h5_rect['rectify_map'][()] 86 | 87 | self._finalizer = weakref.finalize(self, self.close_callback, self.h5f) 88 | 89 | flow_images_path = flow_sequence_path / 'flow' / 'forward' 90 | flow_images_files = [x for x in flow_images_path.iterdir()] 91 | self.gt_flow_files = sorted(flow_images_files, key=lambda x: int(x.stem)) 92 | 93 | # self.transform = transforms 94 | self.crop_size = crop_size 95 | if self.crop_size is not None: 96 | assert self.crop_size[0] <= self.crop_size[1] 97 | 98 | def events_to_voxel_grid(self, p, t, x, y, device: str = 'cpu'): 99 | t = (t - t[0]).astype('float32') 100 | t = (t / t[-1]) 101 | x = x.astype('float32') 102 | y = y.astype('float32') 103 | pol = p.astype('float32') 104 | event_data_torch = { 105 | 'p': torch.from_numpy(pol), 106 | 't': torch.from_numpy(t), 107 | 'x': torch.from_numpy(x), 108 | 'y': torch.from_numpy(y), 109 | } 110 | return self.voxel_grid.convert(event_data_torch) 111 | 112 | @staticmethod 113 | def load_flow(flowfile: Path): 114 | assert flowfile.exists() 115 | assert flowfile.suffix == '.png' 116 | flow_16bit = imageio.imread(str(flowfile), format='PNG-FI') 117 | flow, valid2D = flow_16bit_to_float(flow_16bit) 118 | return flow, valid2D 119 | 120 | @staticmethod 121 | def close_callback(h5f): 122 | h5f.close() 123 | 124 | def get_image_width_height(self): 125 | return self.height, self.width 126 | 127 | def __len__(self): 128 | return len(self.forward_ts_pair) 129 | 130 | def rectify_events(self, x: np.ndarray, y: np.ndarray): 131 | # assert location in self.locations 132 | # From distorted to undistorted 133 | rectify_map = self.rectify_ev_map 134 | assert rectify_map.shape == (self.height, self.width, 2), rectify_map.shape 135 | assert x.max() < self.width 136 | assert y.max() < self.height 137 | return rectify_map[y, x] 138 | 139 | def get_data_sample(self, index, crop_window=None): 140 | ''' 141 | output.keys = ['event_volume_old', 'event_volume_new', 'flow', 'valid2D'] 142 | ''' 143 | # First entry corresponds to all events BEFORE the flow map 144 | # Second entry corresponds to all events AFTER the flow map (corresponding to the actual fwd flow) 145 | names = ['event_volume_old', 'event_volume_new'] 146 | ts_begin_i = self.forward_ts_pair[index, 0] 147 | ts_start = [ts_begin_i - self.delta_t_us, ts_begin_i] 148 | ts_end = [ts_begin_i, ts_begin_i + self.delta_t_us] 149 | output = {} 150 | flow, valid2D = VoxelGridSequenceDSEC.load_flow(self.gt_flow_files[index]) 151 | output['flow'] = torch.from_numpy(flow) 152 | # print('flow-size', output['flow'].shape) 153 | output['valid2D'] = torch.from_numpy(valid2D) 154 | # print('valid-size', output['valid2D'].shape) 155 | # 获取事件的 voxel_gird 156 | for i in range(len(names)): 157 | event_data = self.event_slicer.get_events(ts_start[i], ts_end[i]) 158 | 159 | p = event_data['p'] 160 | t = event_data['t'] 161 | x = event_data['x'] 162 | y = event_data['y'] 163 | 164 | # 得到畸变矫正之后的 x_rect 和 y_rect 165 | xy_rect = self.rectify_events(x, y) 166 | x_rect = xy_rect[:, 0] 167 | y_rect = xy_rect[:, 1] 168 | 169 | # 窗口裁剪,去掉一些事件,并非数据增强中的裁剪 170 | if crop_window is not None: 171 | # Cropping (+- 2 for safety reasons) 172 | x_mask = (x_rect >= crop_window['start_x'] - 2) & ( 173 | x_rect < crop_window['start_x'] + crop_window['crop_width'] + 2) 174 | y_mask = (y_rect >= crop_window['start_y'] - 2) & ( 175 | y_rect < crop_window['start_y'] + crop_window['crop_height'] + 2) 176 | mask_combined = x_mask & y_mask 177 | p = p[mask_combined] 178 | t = t[mask_combined] 179 | x_rect = x_rect[mask_combined] 180 | y_rect = y_rect[mask_combined] 181 | 182 | if self.voxel_grid is None: 183 | raise NotImplementedError 184 | else: 185 | event_representation = self.events_to_voxel_grid(p, t, x_rect, y_rect) 186 | output[names[i]] = event_representation 187 | 188 | # random crop and random flip 189 | output['flow'] = output['flow'].permute(2, 0, 1) 190 | if self.crop_size is not None: 191 | # get rand h_start and w_start 192 | rand_max_h = self.height - self.crop_size[0] 193 | rand_max_w = self.width - self.crop_size[1] 194 | h_start = random.randrange(0, rand_max_h + 1) 195 | w_start = random.randrange(0, rand_max_w + 1) 196 | # random flip transform 197 | p = random.randint(0, 1) 198 | flip_param = 1 - 2 * p 199 | flip = transforms.RandomHorizontalFlip(p) 200 | # apply transform 201 | for key in output.keys(): 202 | # print(key, output[key].shape) 203 | output[key] = output[key][..., h_start:h_start + self.crop_size[0], w_start:w_start + self.crop_size[1]] 204 | output[key] = flip(output[key]) 205 | # flip the flow value of u 206 | output['flow'][0] = flip_param * output['flow'][0] 207 | 208 | return output['event_volume_old'], output['event_volume_new'], output['flow'], output['valid2D'].float() 209 | 210 | def __getitem__(self, idx): 211 | return self.get_data_sample(idx) 212 | 213 | 214 | class VoxelGridSequenceDSECTest(Dataset): 215 | def __init__(self, events_sequence_path: Path, forward_ts_file: Path, bins=15): 216 | assert events_sequence_path.is_dir() 217 | assert forward_ts_file.is_file() 218 | ''' 219 | Directory Structure: 220 | 221 | Dataset 222 | └── Train 223 | ├── train_events 224 | │   ├── thun_00_a (events_sequence_path) 225 | │   │   ├── events 226 | │   │   │   ├── left 227 | │   │   │   │   ├── events.h5 228 | │   │   │   │   └── rectify_map.h5 229 | ├── train_optical_flow 230 | │   ├── thun_00_a (flow_sequence_path) 231 | │   │   ├── flow 232 | │   │   │   ├── forward 233 | │   │   │   │   ├── xxxxxx.png 234 | │   │   │   ├── forward_timestamps.txt 235 | ''' 236 | self.forward_ts_pair_idx = np.genfromtxt(forward_ts_file, delimiter=',') 237 | self.forward_ts_pair_idx = self.forward_ts_pair_idx.astype('int64') 238 | ''' 239 | ----- forward_timestamps.txt ----- 240 | # from_timestamp_us, to_timestamp_us 241 | 49599300523, 49599400524 242 | 49599400524, 49599500511 243 | 49599500511, 49599600529 244 | 49599600529, 49599700535 245 | 49599700535, 49599800517 246 | ... 247 | ''' 248 | self.height = 480 249 | self.width = 640 250 | self.bins = bins 251 | self.voxel_grid = VoxelGrid((bins, self.height, self.width), normalize=True) 252 | # 光流的时间间隔 253 | self.delta_t_us = 100000 254 | 255 | ev_dir = events_sequence_path / 'events' / 'left' 256 | ev_data_file = ev_dir / 'events.h5' 257 | ev_rect_file = ev_dir / 'rectify_map.h5' 258 | 259 | h5f_location = h5py.File(str(ev_data_file), 'r') 260 | self.h5f = h5f_location 261 | self.event_slicer = EventSlicer(h5f_location) 262 | with h5py.File(str(ev_rect_file), 'r') as h5_rect: 263 | self.rectify_ev_map = h5_rect['rectify_map'][()] 264 | 265 | self._finalizer = weakref.finalize(self, self.close_callback, self.h5f) 266 | 267 | 268 | def events_to_voxel_grid(self, p, t, x, y, device: str = 'cpu'): 269 | t = (t - t[0]).astype('float32') 270 | t = (t / t[-1]) 271 | x = x.astype('float32') 272 | y = y.astype('float32') 273 | pol = p.astype('float32') 274 | event_data_torch = { 275 | 'p': torch.from_numpy(pol), 276 | 't': torch.from_numpy(t), 277 | 'x': torch.from_numpy(x), 278 | 'y': torch.from_numpy(y), 279 | } 280 | return self.voxel_grid.convert(event_data_torch) 281 | 282 | @staticmethod 283 | def load_flow(flowfile: Path): 284 | assert flowfile.exists() 285 | assert flowfile.suffix == '.png' 286 | flow_16bit = imageio.imread(str(flowfile), format='PNG-FI') 287 | flow, valid2D = flow_16bit_to_float(flow_16bit) 288 | return flow, valid2D 289 | 290 | @staticmethod 291 | def close_callback(h5f): 292 | h5f.close() 293 | 294 | def get_image_width_height(self): 295 | return self.height, self.width 296 | 297 | def __len__(self): 298 | return len(self.forward_ts_pair_idx) 299 | 300 | def rectify_events(self, x: np.ndarray, y: np.ndarray): 301 | # assert location in self.locations 302 | # From distorted to undistorted 303 | rectify_map = self.rectify_ev_map 304 | assert rectify_map.shape == (self.height, self.width, 2), rectify_map.shape 305 | assert x.max() < self.width 306 | assert y.max() < self.height 307 | return rectify_map[y, x] 308 | 309 | def get_data_sample(self, index, crop_window=None): 310 | ''' 311 | output.keys = ['event_volume_old', 'event_volume_new', 'flow', 'valid2D'] 312 | ''' 313 | # First entry corresponds to all events BEFORE the flow map 314 | # Second entry corresponds to all events AFTER the flow map (corresponding to the actual fwd flow) 315 | names = ['event_volume_old', 'event_volume_new'] 316 | ts_begin_i = self.forward_ts_pair_idx[index, 0] 317 | output_name_idx = self.forward_ts_pair_idx[index, 2] 318 | ts_start = [ts_begin_i - self.delta_t_us, ts_begin_i] 319 | ts_end = [ts_begin_i, ts_begin_i + self.delta_t_us] 320 | output = {} 321 | # 获取事件的 voxel_gird 322 | for i in range(len(names)): 323 | event_data = self.event_slicer.get_events(ts_start[i], ts_end[i]) 324 | 325 | p = event_data['p'] 326 | t = event_data['t'] 327 | x = event_data['x'] 328 | y = event_data['y'] 329 | 330 | # 得到畸变矫正之后的 x_rect 和 y_rect 331 | xy_rect = self.rectify_events(x, y) 332 | x_rect = xy_rect[:, 0] 333 | y_rect = xy_rect[:, 1] 334 | 335 | # 窗口裁剪,去掉一些事件,并非数据增强中的裁剪 336 | if crop_window is not None: 337 | # Cropping (+- 2 for safety reasons) 338 | x_mask = (x_rect >= crop_window['start_x'] - 2) & ( 339 | x_rect < crop_window['start_x'] + crop_window['crop_width'] + 2) 340 | y_mask = (y_rect >= crop_window['start_y'] - 2) & ( 341 | y_rect < crop_window['start_y'] + crop_window['crop_height'] + 2) 342 | mask_combined = x_mask & y_mask 343 | p = p[mask_combined] 344 | t = t[mask_combined] 345 | x_rect = x_rect[mask_combined] 346 | y_rect = y_rect[mask_combined] 347 | 348 | if self.voxel_grid is None: 349 | raise NotImplementedError 350 | else: 351 | event_representation = self.events_to_voxel_grid(p, t, x_rect, y_rect) 352 | output[names[i]] = event_representation 353 | 354 | return output['event_volume_old'], output['event_volume_new'], output_name_idx 355 | 356 | def __getitem__(self, idx): 357 | return self.get_data_sample(idx) 358 | 359 | 360 | class VoxelGridDatasetProviderDSEC: 361 | def __init__(self, dataset_path: Path, crop_size=None, random_split_seed: int = 42, train_ratio=0.8): 362 | events_sequence_path = dataset_path / 'Train' / 'train_events' 363 | flow_sequence_path = dataset_path / 'Train' / 'train_optical_flow' 364 | events_sequences = set([x.stem for x in events_sequence_path.iterdir()]) 365 | flow_sequences = set([x.stem for x in flow_sequence_path.iterdir()]) 366 | valid_sequences = events_sequences.intersection(flow_sequences) 367 | valid_sequences = list(valid_sequences) 368 | valid_sequences.sort() 369 | 370 | dataset_sequences = list() 371 | dataset_sequences_cropped = list() 372 | for sequence in valid_sequences: 373 | dataset_sequences.append(VoxelGridSequenceDSEC(events_sequence_path / sequence, 374 | flow_sequence_path / sequence, 375 | crop_size=None)) 376 | dataset_sequences_cropped.append(VoxelGridSequenceDSEC(events_sequence_path / sequence, 377 | flow_sequence_path / sequence, 378 | crop_size=crop_size)) 379 | 380 | self.dataset = torch.utils.data.ConcatDataset(dataset_sequences) 381 | self.dataset_cropped = torch.utils.data.ConcatDataset(dataset_sequences_cropped) 382 | self.full_size = len(self.dataset) 383 | self.train_size = int(train_ratio * self.full_size) 384 | self.valid_size = self.full_size - self.train_size 385 | 386 | generator = torch.Generator().manual_seed(random_split_seed) 387 | lengths = [self.train_size, self.valid_size] 388 | indices = torch.randperm(sum(lengths), generator=generator).tolist() 389 | 390 | self.train_indices = indices[0: self.train_size] 391 | self.train_indices.sort() 392 | 393 | self.valid_indices = indices[self.train_size:self.full_size] 394 | self.valid_indices.sort() 395 | self.train_set = Subset(self.dataset, self.train_indices) 396 | self.valid_set = Subset(self.dataset, self.valid_indices) 397 | self.train_set_cropped = Subset(self.dataset_cropped, self.train_indices) 398 | self.valid_set_cropped = Subset(self.dataset_cropped, self.valid_indices) 399 | 400 | 401 | class SingleVoxelGridSequenceDSEC(Dataset): 402 | def __init__(self, events_sequence_path: Path, flow_sequence_path: Path, 403 | bins=15, crop_size=None, return_raw=False, unified=True, norm=True): 404 | assert events_sequence_path.is_dir() 405 | assert flow_sequence_path.is_dir() 406 | ''' 407 | Directory Structure: 408 | 409 | Dataset 410 | └── Train 411 | ├── train_events 412 | │   ├── thun_00_a (events_sequence_path) 413 | │   │   ├── events 414 | │   │   │   ├── left 415 | │   │   │   │   ├── events.h5 416 | │   │   │   │   └── rectify_map.h5 417 | ├── train_optical_flow 418 | │   ├── thun_00_a (flow_sequence_path) 419 | │   │   ├── flow 420 | │   │   │   ├── forward 421 | │   │   │   │   ├── xxxxxx.png 422 | │   │   │   ├── forward_timestamps.txt 423 | ''' 424 | forward_timestamps_file = flow_sequence_path / 'flow' / 'forward_timestamps.txt' 425 | assert forward_timestamps_file.is_file() 426 | self.forward_ts_pair = np.genfromtxt(forward_timestamps_file, delimiter=',') 427 | self.forward_ts_pair = self.forward_ts_pair.astype('int64') 428 | self.bins = bins 429 | self.return_raw = return_raw 430 | self.unified = unified 431 | self.norm = norm 432 | ''' 433 | ----- forward_timestamps.txt ----- 434 | # from_timestamp_us, to_timestamp_us 435 | 49599300523, 49599400524 436 | 49599400524, 49599500511 437 | 49599500511, 49599600529 438 | 49599600529, 49599700535 439 | 49599700535, 49599800517 440 | ... 441 | ''' 442 | self.height = 480 443 | self.width = 640 444 | # 光流的时间间隔 445 | self.delta_t_us = 100000 446 | 447 | ev_dir = events_sequence_path / 'events' / 'left' 448 | ev_data_file = ev_dir / 'events.h5' 449 | ev_rect_file = ev_dir / 'rectify_map.h5' 450 | rect2dist_map_path = ev_dir / 'rect2dist_map.npy' 451 | self.rect2dist_map = torch.from_numpy(np.load(str(rect2dist_map_path))) 452 | rect2dist_map = self.rect2dist_map.clone() 453 | rect2dist_map[..., 0] = 2 * rect2dist_map[..., 0] / (self.width - 1) - 1 454 | rect2dist_map[..., 1] = 2 * rect2dist_map[..., 1] / (self.height - 1) - 1 455 | self.grid = rect2dist_map.unsqueeze(0) 456 | 457 | h5f_location = h5py.File(str(ev_data_file), 'r') 458 | self.h5f = h5f_location 459 | self.event_slicer = EventSlicer(h5f_location) 460 | with h5py.File(str(ev_rect_file), 'r') as h5_rect: 461 | self.rectify_ev_map = h5_rect['rectify_map'][()] 462 | 463 | self._finalizer = weakref.finalize(self, self.close_callback, self.h5f) 464 | 465 | flow_images_path = flow_sequence_path / 'flow' / 'forward' 466 | flow_images_files = [x for x in flow_images_path.iterdir()] 467 | self.gt_flow_files = sorted(flow_images_files, key=lambda x: int(x.stem)) 468 | 469 | # self.transform = transforms 470 | self.crop_size = crop_size 471 | if self.crop_size is not None: 472 | assert self.crop_size[0] <= self.crop_size[1] 473 | 474 | def events_to_unified_voxel_grid(self, p, t, x, y, bins): 475 | t_norm = torch.from_numpy((t - t[0]).astype('float32')) 476 | # t_norm = (self.bins - 1) * (t_norm / t_norm[-1]) 477 | bin_time = self.delta_t_us / (self.bins - 1) 478 | total_t = 2 * bin_time + self.delta_t_us 479 | assert total_t > t_norm[-1] 480 | t_norm = (self.bins - 1 + 2) * (t_norm / total_t) 481 | x = torch.from_numpy(x.astype('float32')).int() 482 | y = torch.from_numpy(y.astype('float32')).int() 483 | # convert p (0, 1) → (-1, 1) 484 | p_value = 2 * torch.from_numpy(p.astype('float32')) - 1 485 | 486 | t0 = t_norm.int() 487 | H, W = self.height, self.width 488 | fast_voxel_grid = torch.zeros((bins + 2, H, W), dtype=torch.float, requires_grad=False) 489 | for tlim in [t0, t0 + 1]: 490 | mask = tlim < self.bins + 2 491 | index = H * W * tlim.long() + \ 492 | W * y.long() + \ 493 | x.long() 494 | interp_weights = p_value * (1 - (t_norm - tlim).abs()) 495 | fast_voxel_grid.put_(index[mask], interp_weights[mask], accumulate=True) 496 | 497 | return fast_voxel_grid[1:-1, :, :] 498 | 499 | def events_to_voxel_grid(self, p, t, x, y, bins): 500 | t_norm = torch.from_numpy((t - t[0]).astype('float32')) 501 | # t_norm = (self.bins - 1) * (t_norm / t_norm[-1]) 502 | assert self.delta_t_us > t_norm[-1] 503 | t_norm = (self.bins - 1) * (t_norm / self.delta_t_us) 504 | x = torch.from_numpy(x.astype('float32')).int() 505 | y = torch.from_numpy(y.astype('float32')).int() 506 | # convert p (0, 1) → (-1, 1) 507 | p_value = 2 * torch.from_numpy(p.astype('float32')) - 1 508 | 509 | t0 = t_norm.int() 510 | H, W = self.height, self.width 511 | fast_voxel_grid = torch.zeros((bins, H, W), dtype=torch.float, requires_grad=False) 512 | for tlim in [t0, t0 + 1]: 513 | mask = tlim < self.bins 514 | index = H * W * tlim.long() + \ 515 | W * y.long() + \ 516 | x.long() 517 | interp_weights = p_value * (1 - (t_norm - tlim).abs()) 518 | fast_voxel_grid.put_(index[mask], interp_weights[mask], accumulate=True) 519 | 520 | return fast_voxel_grid 521 | 522 | @staticmethod 523 | def load_flow(flowfile: Path): 524 | assert flowfile.exists() 525 | assert flowfile.suffix == '.png' 526 | flow_16bit = imageio.imread(str(flowfile), format='PNG-FI') 527 | flow, valid2D = flow_16bit_to_float(flow_16bit) 528 | return flow, valid2D 529 | 530 | @staticmethod 531 | def close_callback(h5f): 532 | h5f.close() 533 | 534 | def get_image_width_height(self): 535 | return self.height, self.width 536 | 537 | def __len__(self): 538 | return len(self.forward_ts_pair) 539 | 540 | def rectify_events(self, x: np.ndarray, y: np.ndarray): 541 | # assert location in self.locations 542 | # From distorted to undistorted 543 | rectify_map = self.rectify_ev_map 544 | assert rectify_map.shape == (self.height, self.width, 2), rectify_map.shape 545 | assert x.max() < self.width 546 | assert y.max() < self.height 547 | return rectify_map[y, x] 548 | 549 | def warp_rectify(self, voxel_gird): 550 | rect_grid = F.grid_sample(voxel_gird.unsqueeze(0), self.grid, align_corners=True) 551 | return rect_grid.squeeze(0) 552 | 553 | def get_data_sample(self, index): 554 | ''' 555 | output.keys = ['events_vg', 'flow', 'valid2D'] 556 | ''' 557 | ts_begin_i = self.forward_ts_pair[index, 0] 558 | bin_time = self.delta_t_us / (self.bins - 1) 559 | ts_start_extend = ts_begin_i - bin_time 560 | ts_end_extend = ts_begin_i + self.delta_t_us + bin_time 561 | output = {} 562 | flow, valid2D = VoxelGridSequenceDSEC.load_flow(self.gt_flow_files[index]) 563 | output['flow'] = torch.from_numpy(flow) 564 | # print('flow-size', output['flow'].shape) 565 | output['valid2D'] = torch.from_numpy(valid2D) 566 | # print('valid-size', output['valid2D'].shape) 567 | raw_events = None 568 | 569 | if self.crop_size is not None: 570 | # get rand h_start and w_start 571 | rand_max_h = self.height - self.crop_size[0] 572 | rand_max_w = self.width - self.crop_size[1] 573 | h_start = random.randrange(0, rand_max_h + 1) 574 | w_start = random.randrange(0, rand_max_w + 1) 575 | W = self.crop_size[1] 576 | H = self.crop_size[0] 577 | else: 578 | h_start = 0 579 | w_start = 0 580 | W = self.width 581 | H = self.height 582 | 583 | is_unified = self.unified 584 | event_data = None 585 | if self.unified: 586 | # 获取事件的 voxel_gird_extend 587 | event_data = self.event_slicer.get_events(ts_start_extend, ts_end_extend) 588 | if (event_data is None) or (not self.unified): 589 | """ 590 | if time extend the scope, return voxel_grid rather than voxel_grid_extend 591 | """ 592 | event_data = self.event_slicer.get_events(ts_begin_i, ts_begin_i + self.delta_t_us) 593 | is_unified = False 594 | 595 | p = event_data['p'] 596 | t = event_data['t'] 597 | x = event_data['x'] 598 | y = event_data['y'] 599 | xy_rect = self.rectify_events(event_data['x'], event_data['y']) 600 | x_rect = xy_rect[:, 0] 601 | y_rect = xy_rect[:, 1] 602 | 603 | x_mask = (x_rect >= w_start - 2) & ( 604 | x_rect < w_start + W + 2) 605 | y_mask = (y_rect >= h_start - 2) & ( 606 | y_rect < h_start + H + 2) 607 | mask_combined = x_mask & y_mask 608 | x = x[mask_combined] 609 | y = y[mask_combined] 610 | p = p[mask_combined] 611 | t = t[mask_combined] 612 | x_rect = x_rect[mask_combined] 613 | y_rect = y_rect[mask_combined] 614 | 615 | if is_unified: 616 | fast_voxel_grid = self.events_to_unified_voxel_grid(p, t, x, y, self.bins) 617 | mask_time = (t >= ts_begin_i) & (t < ts_begin_i + self.delta_t_us) 618 | x_rect = x_rect[mask_time] 619 | y_rect = y_rect[mask_time] 620 | p = p[mask_time] 621 | t = t[mask_time] 622 | else: 623 | fast_voxel_grid = self.events_to_voxel_grid(p, t, x, y, self.bins) 624 | 625 | event_len = t.size 626 | if self.return_raw and event_len != 0: 627 | raw_events = torch.zeros(event_len, 4) 628 | raw_events[:, 0] = torch.from_numpy(x_rect) 629 | raw_events[:, 1] = torch.from_numpy(y_rect) 630 | raw_events[:, 3] = 2 * torch.from_numpy(p.astype('float32')) - 1 631 | raw_events[:, 2] = torch.from_numpy( 632 | (t - ts_begin_i).astype('float32')) / self.delta_t_us * (self.bins - 1) 633 | 634 | 635 | # warp to rectify distortion 636 | output['events_vg'] = self.warp_rectify(fast_voxel_grid) 637 | 638 | p_flip = 0 639 | # random crop and random flip 640 | output['flow'] = output['flow'].permute(2, 0, 1) 641 | if self.crop_size is not None: 642 | # random flip transform 643 | p_flip = random.randint(0, 1) 644 | flip_param = 1 - 2 * p_flip 645 | flip = transforms.RandomHorizontalFlip(p_flip) 646 | # apply transform 647 | for key in output.keys(): 648 | # print(key, output[key].shape) 649 | output[key] = output[key][..., h_start:h_start + self.crop_size[0], w_start:w_start + self.crop_size[1]] 650 | output[key] = flip(output[key]) 651 | # flip the flow value of u 652 | output['flow'][0] = flip_param * output['flow'][0] 653 | 654 | if self.return_raw and (raw_events is not None) and raw_events.size()[0] != 0: 655 | crop_mask = (raw_events[:, 0] >= w_start) & (raw_events[:, 0] < w_start + W-1) 656 | crop_mask &= (raw_events[:, 1] >= h_start) & (raw_events[:, 1] < h_start + H-1) 657 | if torch.any(crop_mask): 658 | raw_events = raw_events[crop_mask] 659 | raw_events[:, 0] -= w_start 660 | raw_events[:, 1] -= h_start 661 | if p_flip == 1: 662 | raw_events[:, 0] = W - 1 - raw_events[:, 0] 663 | # exchange [x, y, t, p] to [y, x, t, p] 664 | raw_events = raw_events[:, [1, 0, 2, 3]] 665 | else: 666 | raw_events = None 667 | else: 668 | raw_events = None 669 | 670 | if self.norm: 671 | mask = torch.nonzero(output['events_vg'], as_tuple=True) 672 | if mask[0].size()[0] > 0: 673 | mean = output['events_vg'][mask].mean() 674 | std = output['events_vg'][mask].std() 675 | if std > 0: 676 | output['events_vg'][mask] = (output['events_vg'][mask] - mean) / std 677 | else: 678 | output['events_vg'][mask] = output['events_vg'][mask] - mean 679 | 680 | # raw_events的维度信息和含义: 681 | # 维度: [N, 4], 其中N是事件的数量 682 | # 每一行的4个值分别表示: 683 | # [0]: y 坐标 (高度方向) 684 | # [1]: x 坐标 (宽度方向) 685 | # [2]: 归一化的时间戳 (范围0到bins-1) 686 | # [3]: 极性 (-1或1) 687 | if self.return_raw: 688 | return output['events_vg'], output['flow'], output['valid2D'].float(), raw_events 689 | return output['events_vg'], output['flow'], output['valid2D'].float() 690 | 691 | def __getitem__(self, idx): 692 | return self.get_data_sample(idx) 693 | 694 | class SingleVoxelGridDatasetProviderDSEC: 695 | def __init__(self, dataset_path: Path, bins=15, crop_size=[288, 384], 696 | random_split_seed: int = 42, train_ratio=0.8, return_raw=False, unified=True, norm=True): 697 | events_sequence_path = dataset_path / 'Train' / 'train_events' 698 | flow_sequence_path = dataset_path / 'Train' / 'train_optical_flow' 699 | events_sequences = set([x.stem for x in events_sequence_path.iterdir()]) 700 | flow_sequences = set([x.stem for x in flow_sequence_path.iterdir()]) 701 | valid_sequences = events_sequences.intersection(flow_sequences) 702 | valid_sequences = list(valid_sequences) 703 | valid_sequences.sort() 704 | 705 | dataset_sequences = list() 706 | dataset_sequences_cropped = list() 707 | for sequence in valid_sequences: 708 | dataset_sequences.append(SingleVoxelGridSequenceDSEC(events_sequence_path / sequence, 709 | flow_sequence_path / sequence, 710 | bins=bins, 711 | crop_size=None, 712 | return_raw=return_raw, 713 | unified=unified, 714 | norm=norm)) 715 | dataset_sequences_cropped.append(SingleVoxelGridSequenceDSEC(events_sequence_path / sequence, 716 | flow_sequence_path / sequence, 717 | bins=bins, 718 | crop_size=crop_size, 719 | return_raw=return_raw, 720 | unified=unified, 721 | norm=norm)) 722 | 723 | self.dataset = torch.utils.data.ConcatDataset(dataset_sequences) 724 | self.dataset_cropped = torch.utils.data.ConcatDataset(dataset_sequences_cropped) 725 | self.full_size = len(self.dataset) 726 | self.train_size = int(train_ratio * self.full_size) 727 | self.valid_size = self.full_size - self.train_size 728 | 729 | generator = torch.Generator().manual_seed(random_split_seed) 730 | lengths = [self.train_size, self.valid_size] 731 | indices = torch.randperm(sum(lengths), generator=generator).tolist() 732 | 733 | self.train_indices = indices[0: self.train_size] 734 | self.train_indices.sort() 735 | 736 | self.valid_indices = indices[self.train_size:self.full_size] 737 | self.valid_indices.sort() 738 | self.train_set = Subset(self.dataset, self.train_indices) 739 | self.valid_set = Subset(self.dataset, self.valid_indices) 740 | self.train_set_cropped = Subset(self.dataset_cropped, self.train_indices) 741 | self.valid_set_cropped = Subset(self.dataset_cropped, self.valid_indices) 742 | 743 | class SingleVoxelGridTestSequenceDSEC(Dataset): 744 | def __init__(self, events_sequence_path: Path, forward_ts_file: Path, bins=15, return_raw=False, unified=True, norm=True): 745 | """ 746 | get a sequence 747 | :param events_sequence_path: Path, events_sequence_path 748 | :param flow_sequence_path: Path, flow_sequence_path 749 | :param raster_channels: compressed events channels 750 | """ 751 | assert events_sequence_path.is_dir() 752 | assert forward_ts_file.is_file() 753 | ''' 754 | Directory Structure: 755 | 756 | Dataset 757 | └── Train 758 | ├── train_events 759 | │   ├── thun_00_a (events_sequence_path) 760 | │   │   ├── events 761 | │   │   │   ├── left 762 | │   │   │   │   ├── events.h5 763 | │   │   │   │   └── rectify_map.h5 764 | │   │   │   │   └── rect2dist_map.npy 765 | ''' 766 | self.forward_ts_pair_idx = np.genfromtxt(forward_ts_file, delimiter=',') 767 | self.forward_ts_pair_idx = self.forward_ts_pair_idx.astype('int64') 768 | ''' 769 | ----- forward_ts_file ----- 770 | # from_timestamp_us, to_timestamp_us, file_index 771 | 51648500652, 51648600574, 820 772 | 51649000383, 51649100410, 830 773 | 51649500439, 51649600452, 840 774 | 51650000446, 51650100510, 850 775 | 51650500682, 51650600786, 860 776 | 51651002403, 51651103123, 870 777 | 51651507548, 51651607608, 880 778 | 51652007591, 51652107617, 890 779 | 51652507627, 51652607642, 900 780 | ... 781 | ''' 782 | self.height = 480 783 | self.width = 640 784 | self.bins = bins 785 | # 光流的时间间隔 100000 us 786 | self.delta_t_us = 100000 787 | self.return_raw = return_raw 788 | self.unified = unified 789 | self.norm = norm 790 | 791 | ev_dir = events_sequence_path / 'events' / 'left' 792 | ev_data_file = ev_dir / 'events.h5' 793 | ev_rect_file = ev_dir / 'rectify_map.h5' 794 | 795 | h5f_location = h5py.File(str(ev_data_file), 'r') 796 | self.h5f = h5f_location 797 | self.event_slicer = EventSlicer(h5f_location) 798 | with h5py.File(str(ev_rect_file), 'r') as h5_rect: 799 | self.rectify_ev_map = h5_rect['rectify_map'][()] 800 | self._finalizer = weakref.finalize(self, self.close_callback, self.h5f) 801 | 802 | rect2dist_map_path = ev_dir / 'rect2dist_map.npy' 803 | self.rect2dist_map = torch.from_numpy(np.load(str(rect2dist_map_path))) 804 | rect2dist_map = self.rect2dist_map.clone() 805 | rect2dist_map[..., 0] = 2 * rect2dist_map[..., 0] / (self.width - 1) - 1 806 | rect2dist_map[..., 1] = 2 * rect2dist_map[..., 1] / (self.height - 1) - 1 807 | self.grid = rect2dist_map.unsqueeze(0) 808 | 809 | def events_to_unified_voxel_grid(self, p, t, x, y, bins): 810 | t_norm = torch.from_numpy((t - t[0]).astype('float32')) 811 | # t_norm = (self.bins - 1) * (t_norm / t_norm[-1]) 812 | bin_time = self.delta_t_us / (self.bins - 1) 813 | total_t = 2 * bin_time + self.delta_t_us 814 | assert total_t > t_norm[-1] 815 | t_norm = (self.bins - 1 + 2) * (t_norm / total_t) 816 | x = torch.from_numpy(x.astype('float32')).int() 817 | y = torch.from_numpy(y.astype('float32')).int() 818 | # convert p (0, 1) → (-1, 1) 819 | p_value = 2 * torch.from_numpy(p.astype('float32')) - 1 820 | 821 | t0 = t_norm.int() 822 | H, W = self.height, self.width 823 | fast_voxel_grid = torch.zeros((bins + 2, H, W), dtype=torch.float, requires_grad=False) 824 | for tlim in [t0, t0 + 1]: 825 | mask = tlim < self.bins + 2 826 | index = H * W * tlim.long() + \ 827 | W * y.long() + \ 828 | x.long() 829 | interp_weights = p_value * (1 - (t_norm - tlim).abs()) 830 | fast_voxel_grid.put_(index[mask], interp_weights[mask], accumulate=True) 831 | 832 | return fast_voxel_grid[1:-1, :, :] 833 | 834 | def events_to_voxel_grid(self, p, t, x, y, bins): 835 | t_norm = torch.from_numpy((t - t[0]).astype('float32')) 836 | # t_norm = (self.bins - 1) * (t_norm / t_norm[-1]) 837 | assert self.delta_t_us > t_norm[-1] 838 | t_norm = (self.bins - 1) * (t_norm / self.delta_t_us) 839 | x = torch.from_numpy(x.astype('float32')).int() 840 | y = torch.from_numpy(y.astype('float32')).int() 841 | # convert p (0, 1) → (-1, 1) 842 | p_value = 2 * torch.from_numpy(p.astype('float32')) - 1 843 | 844 | t0 = t_norm.int() 845 | H, W = self.height, self.width 846 | fast_voxel_grid = torch.zeros((bins, H, W), dtype=torch.float, requires_grad=False) 847 | for tlim in [t0, t0 + 1]: 848 | mask = tlim < self.bins 849 | index = H * W * tlim.long() + \ 850 | W * y.long() + \ 851 | x.long() 852 | interp_weights = p_value * (1 - (t_norm - tlim).abs()) 853 | fast_voxel_grid.put_(index[mask], interp_weights[mask], accumulate=True) 854 | 855 | return fast_voxel_grid 856 | 857 | @staticmethod 858 | def load_flow(flowfile: Path): 859 | assert flowfile.exists() 860 | assert flowfile.suffix == '.png' 861 | flow_16bit = imageio.imread(str(flowfile), format='PNG-FI') 862 | flow, valid2D = flow_16bit_to_float(flow_16bit) 863 | return flow, valid2D 864 | 865 | @staticmethod 866 | def close_callback(h5f): 867 | h5f.close() 868 | 869 | def get_image_width_height(self): 870 | return self.height, self.width 871 | 872 | def __len__(self): 873 | return len(self.forward_ts_pair_idx) 874 | 875 | def rectify_events(self, x: np.ndarray, y: np.ndarray): 876 | # assert location in self.locations 877 | # From distorted to undistorted 878 | rectify_map = self.rectify_ev_map 879 | assert rectify_map.shape == (self.height, self.width, 2), rectify_map.shape 880 | assert x.max() < self.width 881 | assert y.max() < self.height 882 | return rectify_map[y, x] 883 | 884 | def warp_rectify(self, voxel_gird): 885 | rect_grid = F.grid_sample(voxel_gird.unsqueeze(0), self.grid, align_corners=True) 886 | return rect_grid.squeeze(0) 887 | 888 | def get_data_sample(self, index): 889 | ''' 890 | output.keys = ['events_vg', 'name'] 891 | ''' 892 | ts_begin_i = self.forward_ts_pair_idx[index, 0] 893 | output_name_idx = self.forward_ts_pair_idx[index, 2] 894 | ts_end_i = self.forward_ts_pair_idx[index, 1] 895 | bin_time = self.delta_t_us / (self.bins - 1) 896 | ts_start_extend = ts_begin_i - bin_time 897 | ts_end_extend = ts_begin_i + self.delta_t_us + bin_time 898 | output = {} 899 | raw_events = None 900 | 901 | h_start = 0 902 | w_start = 0 903 | W = self.width 904 | H = self.height 905 | 906 | is_unified = self.unified 907 | event_data = None 908 | if self.unified: 909 | # 获取事件的 voxel_gird_extend 910 | event_data = self.event_slicer.get_events(ts_start_extend, ts_end_extend) 911 | if (event_data is None) or (not self.unified): 912 | """ 913 | if time extend the scope, return voxel_grid rather than voxel_grid_extend 914 | """ 915 | event_data = self.event_slicer.get_events(ts_begin_i, ts_begin_i + self.delta_t_us) 916 | is_unified = False 917 | 918 | p = event_data['p'] 919 | t = event_data['t'] 920 | x = event_data['x'] 921 | y = event_data['y'] 922 | xy_rect = self.rectify_events(event_data['x'], event_data['y']) 923 | x_rect = xy_rect[:, 0] 924 | y_rect = xy_rect[:, 1] 925 | 926 | x_mask = (x_rect >= w_start - 2) & ( 927 | x_rect < w_start + W + 2) 928 | y_mask = (y_rect >= h_start - 2) & ( 929 | y_rect < h_start + H + 2) 930 | mask_combined = x_mask & y_mask 931 | x = x[mask_combined] 932 | y = y[mask_combined] 933 | p = p[mask_combined] 934 | t = t[mask_combined] 935 | x_rect = x_rect[mask_combined] 936 | y_rect = y_rect[mask_combined] 937 | 938 | if is_unified: 939 | fast_voxel_grid = self.events_to_unified_voxel_grid(p, t, x, y, self.bins) 940 | mask_time = (t >= ts_begin_i) & (t < ts_begin_i + self.delta_t_us) 941 | x_rect = x_rect[mask_time] 942 | y_rect = y_rect[mask_time] 943 | p = p[mask_time] 944 | t = t[mask_time] 945 | else: 946 | fast_voxel_grid = self.events_to_voxel_grid(p, t, x, y, self.bins) 947 | 948 | # warp to rectify distortion 949 | out_voxel_grid = self.warp_rectify(fast_voxel_grid) 950 | 951 | event_len = t.size 952 | if self.return_raw and event_len != 0: 953 | raw_events = torch.zeros(event_len, 4) 954 | raw_events[:, 0] = torch.from_numpy(x_rect) 955 | raw_events[:, 1] = torch.from_numpy(y_rect) 956 | raw_events[:, 3] = 2 * torch.from_numpy(p.astype('float32')) - 1 957 | raw_events[:, 2] = torch.from_numpy( 958 | (t - ts_begin_i).astype('float32')) / self.delta_t_us * (self.bins - 1) 959 | 960 | if self.return_raw and (raw_events is not None) and raw_events.size()[0] != 0: 961 | crop_mask = (raw_events[:, 0] >= w_start) & (raw_events[:, 0] < w_start + W-1) 962 | crop_mask &= (raw_events[:, 1] >= h_start) & (raw_events[:, 1] < h_start + H-1) 963 | if torch.any(crop_mask): 964 | raw_events = raw_events[crop_mask] 965 | raw_events[:, 0] -= w_start 966 | raw_events[:, 1] -= h_start 967 | # exchange [x, y, t, p] to [y, x, t, p] for event_image_converter 968 | raw_events = raw_events[:, [1, 0, 2, 3]] 969 | else: 970 | raw_events = None 971 | else: 972 | raw_events = None 973 | 974 | if self.norm: 975 | mask = torch.nonzero(out_voxel_grid, as_tuple=True) 976 | if mask[0].size()[0] > 0: 977 | mean = out_voxel_grid[mask].mean() 978 | std = out_voxel_grid[mask].std() 979 | if std > 0: 980 | out_voxel_grid[mask] = (out_voxel_grid[mask] - mean) / std 981 | else: 982 | out_voxel_grid[mask] = out_voxel_grid[mask] - mean 983 | 984 | if self.return_raw: 985 | return out_voxel_grid, output_name_idx, raw_events 986 | return out_voxel_grid, output_name_idx 987 | 988 | def __getitem__(self, idx): 989 | return self.get_data_sample(idx) 990 | 991 | 992 | def collate_raw_events(data): 993 | from torch.utils.data.dataloader import default_collate 994 | voxel_grid_list = list() 995 | flow_list = list() 996 | mask_list = list() 997 | raw_events_list = list() 998 | # 遍历数据列表中的每个元素 999 | for i, d in enumerate(data): 1000 | voxel_grid_list.append(d[0]) # 将体素网格数据添加到列表中 1001 | flow_list.append(d[1]) # 将光流数据添加到列表中 1002 | mask_list.append(d[2]) # 将掩码数据添加到列表中 1003 | if d[3] is not None: 1004 | # 如果原始事件数据存在,将批次索引与事件数据拼接后添加到列表中 1005 | raw_events_list.append(torch.cat([i*torch.ones(len(d[3]), 1), d[3]], 1)) 1006 | 1007 | # 使用default_collate函数将列表中的数据合并成张量 1008 | out_voxel_grid = default_collate(voxel_grid_list) 1009 | out_flow = default_collate(flow_list) 1010 | out_mask = default_collate(mask_list) 1011 | 1012 | out_raw_events = None 1013 | if len(raw_events_list)!=0: 1014 | # 如果存在原始事件数据,将所有批次的事件数据拼接成一个大张量 1015 | out_raw_events = torch.cat(raw_events_list, dim=0) 1016 | # out_raw_events的维度信息和含义: 1017 | # 维度: [N, 5], 其中N是所有批次中事件的总数量 1018 | # 每一行的5个值分别表示: 1019 | # [0]: 批次索引 1020 | # [1]: y 坐标 (高度方向) 1021 | # [2]: x 坐标 (宽度方向) 1022 | # [3]: 归一化的时间戳 (范围0到bins-1) 1023 | # [4]: 极性 (-1或1) 1024 | 1025 | # 返回处理后的数据 1026 | return out_voxel_grid, out_flow, out_mask, out_raw_events 1027 | 1028 | 1029 | if __name__ == '__main__': 1030 | parser = argparse.ArgumentParser() 1031 | parser.add_argument('--dataset_path', type=str, default="/media/yyz/FastDisk/Dataset/DSEC") 1032 | parser.add_argument('--bins', type=int, default=15) 1033 | parser.add_argument('--crop_size', type=int, default=[288, 384]) 1034 | parser.add_argument('--random_split_seed', type=int, default=42) 1035 | parser.add_argument('--train_ratio', type=float, default=0.8) 1036 | parser.add_argument('--return_raw', type=bool, default=True) 1037 | parser.add_argument('--unified', type=bool, default=True) 1038 | parser.add_argument('--norm', type=bool, default=True) 1039 | args = parser.parse_args() 1040 | 1041 | import utils.setupTensor 1042 | 1043 | dsec_provider = SingleVoxelGridDatasetProviderDSEC(Path(args.dataset_path), bins=args.bins, 1044 | crop_size=args.crop_size, 1045 | random_split_seed=args.random_split_seed, 1046 | train_ratio=args.train_ratio, 1047 | return_raw=args.return_raw, 1048 | unified=args.unified, 1049 | norm=args.norm) 1050 | 1051 | dataset = dsec_provider.train_set_cropped 1052 | loader = DataLoader(dataset, batch_size=3, num_workers=4, drop_last=False, shuffle=True, collate_fn=collate_raw_events) 1053 | for i, data_blob in enumerate(tqdm.tqdm(loader)): 1054 | voxel_grid, flow, valid_mask, events = data_blob 1055 | voxel_grid, flow, valid_mask = [x.cuda() for x in [voxel_grid, flow, valid_mask]] 1056 | if events is not None: 1057 | events = events.cuda() 1058 | 1059 | -------------------------------------------------------------------------------- /utils/dsec_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import hdf5plugin 4 | import h5py 5 | from typing import Dict, Tuple 6 | from numba import jit 7 | import math 8 | from pathlib import Path 9 | 10 | def flow_16bit_to_float(flow_16bit: np.ndarray): 11 | assert flow_16bit.dtype == np.uint16 12 | assert flow_16bit.ndim == 3 13 | h, w, c = flow_16bit.shape 14 | assert c == 3 15 | 16 | valid2D = flow_16bit[..., 2] == 1 17 | assert valid2D.shape == (h, w) 18 | assert np.all(flow_16bit[~valid2D, -1] == 0) 19 | valid_map = np.where(valid2D) 20 | 21 | # to actually compute something useful: 22 | flow_16bit = flow_16bit.astype('float') 23 | 24 | flow_map = np.zeros((h, w, 2)) 25 | flow_map[valid_map[0], valid_map[1], 0] = (flow_16bit[valid_map[0], valid_map[1], 0] - 2 ** 15) / 128 26 | flow_map[valid_map[0], valid_map[1], 1] = (flow_16bit[valid_map[0], valid_map[1], 1] - 2 ** 15) / 128 27 | return flow_map, valid2D 28 | 29 | 30 | class EventSlicer: 31 | def __init__(self, h5f: h5py.File = None, folder: Path = None): 32 | if folder is None: 33 | self.h5f = h5f 34 | self.events = dict() 35 | for dset_str in ['p', 'x', 'y', 't']: 36 | self.events[dset_str] = self.h5f['events/{}'.format(dset_str)] 37 | 38 | # This is the mapping from milliseconds to event index: 39 | # It is defined such that 40 | # (1) t[ms_to_idx[ms]] >= ms*1000 41 | # (2) t[ms_to_idx[ms] - 1] < ms*1000 42 | # ,where 'ms' is the time in milliseconds and 't' the event timestamps in microseconds. 43 | # 44 | # As an example, given 't' and 'ms': 45 | # t: 0 500 2100 5000 5000 7100 7200 7200 8100 9000 46 | # ms: 0 1 2 3 4 5 6 7 8 9 47 | # 48 | # we get 49 | # 50 | # ms_to_idx: 51 | # 0 2 2 3 3 3 5 5 8 9 52 | self.ms_to_idx = np.asarray(self.h5f['ms_to_idx'], dtype='int64') 53 | 54 | self.t_offset = int(h5f['t_offset'][()]) 55 | self.t_final = int(self.events['t'][-1]) + self.t_offset 56 | else: 57 | self.events = dict() 58 | for dset_str in ['p', 'x', 'y', 't']: 59 | file_name = "events_" + dset_str + ".npy" 60 | self.events[dset_str] = np.load(str(folder / file_name)) 61 | self.ms_to_idx = np.load(str(folder / 'ms_to_idx.npy')) 62 | self.t_offset = int(np.load(str(folder/'t_offset.npy'))[()]) 63 | self.t_final = int(self.events['t'][-1]) + self.t_offset 64 | 65 | def get_final_time_us(self): 66 | return self.t_final 67 | 68 | def get_events(self, t_start_us: int, t_end_us: int) -> Dict[str, np.ndarray]: 69 | """Get events (p, x, y, t) within the specified time window 70 | Parameters 71 | ---------- 72 | t_start_us: start time in microseconds 73 | t_end_us: end time in microseconds 74 | Returns 75 | ------- 76 | events: dictionary of (p, x, y, t) or None if the time window cannot be retrieved 77 | """ 78 | assert t_start_us < t_end_us 79 | 80 | # We assume that the times are top-off-day, hence subtract offset: 81 | t_start_us -= self.t_offset 82 | t_end_us -= self.t_offset 83 | 84 | t_start_ms, t_end_ms = self.get_conservative_window_ms(t_start_us, t_end_us) 85 | t_start_ms_idx = self.ms2idx(t_start_ms) 86 | t_end_ms_idx = self.ms2idx(t_end_ms) 87 | 88 | if t_start_ms_idx is None or t_end_ms_idx is None: 89 | # Cannot guarantee window size anymore 90 | return None 91 | 92 | events = dict() 93 | time_array_conservative = np.asarray(self.events['t'][t_start_ms_idx:t_end_ms_idx]) 94 | idx_start_offset, idx_end_offset = self.get_time_indices_offsets(time_array_conservative, t_start_us, t_end_us) 95 | t_start_us_idx = t_start_ms_idx + idx_start_offset 96 | t_end_us_idx = t_start_ms_idx + idx_end_offset 97 | # Again add t_offset to get gps time 98 | events['t'] = time_array_conservative[idx_start_offset:idx_end_offset] + self.t_offset 99 | for dset_str in ['p', 'x', 'y']: 100 | events[dset_str] = np.asarray(self.events[dset_str][t_start_us_idx:t_end_us_idx]) 101 | assert events[dset_str].size == events['t'].size 102 | return events 103 | 104 | 105 | @staticmethod 106 | def get_conservative_window_ms(ts_start_us: int, ts_end_us) -> Tuple[int, int]: 107 | """Compute a conservative time window of time with millisecond resolution. 108 | We have a time to index mapping for each millisecond. Hence, we need 109 | to compute the lower and upper millisecond to retrieve events. 110 | Parameters 111 | ---------- 112 | ts_start_us: start time in microseconds 113 | ts_end_us: end time in microseconds 114 | Returns 115 | ------- 116 | window_start_ms: conservative start time in milliseconds 117 | window_end_ms: conservative end time in milliseconds 118 | """ 119 | assert ts_end_us > ts_start_us 120 | window_start_ms = math.floor(ts_start_us/1000) 121 | window_end_ms = math.ceil(ts_end_us/1000) 122 | return window_start_ms, window_end_ms 123 | 124 | @staticmethod 125 | @jit(nopython=True) 126 | def get_time_indices_offsets( 127 | time_array: np.ndarray, 128 | time_start_us: int, 129 | time_end_us: int) -> Tuple[int, int]: 130 | """Compute index offset of start and end timestamps in microseconds 131 | Parameters 132 | ---------- 133 | time_array: timestamps (in us) of the events 134 | time_start_us: start timestamp (in us) 135 | time_end_us: end timestamp (in us) 136 | Returns 137 | ------- 138 | idx_start: Index within this array corresponding to time_start_us 139 | idx_end: Index within this array corresponding to time_end_us 140 | such that (in non-edge cases) 141 | time_array[idx_start] >= time_start_us 142 | time_array[idx_end] >= time_end_us 143 | time_array[idx_start - 1] < time_start_us 144 | time_array[idx_end - 1] < time_end_us 145 | this means that 146 | time_start_us <= time_array[idx_start:idx_end] < time_end_us 147 | """ 148 | 149 | assert time_array.ndim == 1 150 | 151 | idx_start = -1 152 | if time_array[-1] < time_start_us: 153 | # This can happen in extreme corner cases. E.g. 154 | # time_array[0] = 1016 155 | # time_array[-1] = 1984 156 | # time_start_us = 1990 157 | # time_end_us = 2000 158 | 159 | # Return same index twice: array[x:x] is empty. 160 | return time_array.size, time_array.size 161 | else: 162 | for idx_from_start in range(0, time_array.size, 1): 163 | if time_array[idx_from_start] >= time_start_us: 164 | idx_start = idx_from_start 165 | break 166 | assert idx_start >= 0 167 | 168 | idx_end = time_array.size 169 | for idx_from_end in range(time_array.size - 1, -1, -1): 170 | if time_array[idx_from_end] >= time_end_us: 171 | idx_end = idx_from_end 172 | else: 173 | break 174 | 175 | assert time_array[idx_start] >= time_start_us 176 | if idx_end < time_array.size: 177 | assert time_array[idx_end] >= time_end_us 178 | if idx_start > 0: 179 | assert time_array[idx_start - 1] < time_start_us 180 | if idx_end > 0: 181 | assert time_array[idx_end - 1] < time_end_us 182 | return idx_start, idx_end 183 | 184 | def ms2idx(self, time_ms: int) -> int: 185 | assert time_ms >= 0 186 | if time_ms >= self.ms_to_idx.size: 187 | return None 188 | return self.ms_to_idx[time_ms] 189 | 190 | from enum import Enum, auto 191 | 192 | 193 | class RepresentationType(Enum): 194 | VOXEL = auto() 195 | STEPAN = auto() 196 | 197 | 198 | class EventRepresentation: 199 | def __init__(self): 200 | pass 201 | 202 | def convert(self, events): 203 | raise NotImplementedError 204 | 205 | 206 | class VoxelGrid(EventRepresentation): 207 | def __init__(self, input_size: tuple, normalize: bool): 208 | assert len(input_size) == 3 209 | self.voxel_grid = torch.zeros((input_size), dtype=torch.float, requires_grad=False) 210 | self.nb_channels = input_size[0] 211 | self.normalize = normalize 212 | 213 | def convert(self, events): 214 | C, H, W = self.voxel_grid.shape 215 | with torch.no_grad(): 216 | self.voxel_grid = self.voxel_grid.to(events['p'].device) 217 | voxel_grid = self.voxel_grid.clone() 218 | 219 | t_norm = events['t'] 220 | t_norm = (C - 1) * (t_norm-t_norm[0]) / (t_norm[-1]-t_norm[0]) 221 | 222 | x0 = events['x'].int() 223 | y0 = events['y'].int() 224 | t0 = t_norm.int() 225 | 226 | value = 2*events['p']-1 227 | 228 | # voxel_grid 的线性插值(妙啊!) 229 | for xlim in [x0,x0+1]: 230 | for ylim in [y0,y0+1]: 231 | for tlim in [t0,t0+1]: 232 | mask = (xlim < W) & (xlim >= 0) & (ylim < H) & (ylim >= 0) & (tlim >= 0) & (tlim < self.nb_channels) 233 | interp_weights = value * (1 - (xlim-events['x']).abs()) * (1 - (ylim-events['y']).abs()) * (1 - (tlim - t_norm).abs()) 234 | 235 | index = H * W * tlim.long() + \ 236 | W * ylim.long() + \ 237 | xlim.long() 238 | 239 | voxel_grid.put_(index[mask], interp_weights[mask], accumulate=True) 240 | 241 | # for valid 242 | # tlim = t0 243 | # mask = (xlim < W) & (xlim >= 0) & (ylim < H) & (ylim >= 0) & (tlim >= 0) & (tlim < self.nb_channels) 244 | # interp_weights = value * (1 - (xlim - events['x']).abs()) * (1 - (ylim - events['y']).abs()) 245 | # 246 | # index = H * W * tlim.long() + \ 247 | # W * ylim.long() + \ 248 | # xlim.long() 249 | 250 | voxel_grid.put_(index[mask], interp_weights[mask], accumulate=True) 251 | 252 | # 对 voxel_grid 中非零的元素归一化 253 | if self.normalize: 254 | mask = torch.nonzero(voxel_grid, as_tuple=True) 255 | if mask[0].size()[0] > 0: 256 | mean = voxel_grid[mask].mean() 257 | std = voxel_grid[mask].std() 258 | if std > 0: 259 | voxel_grid[mask] = (voxel_grid[mask] - mean) / std 260 | else: 261 | voxel_grid[mask] = voxel_grid[mask] - mean 262 | 263 | return voxel_grid 264 | -------------------------------------------------------------------------------- /utils/gen_dist_map.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import yaml 4 | from pathlib import Path 5 | import random 6 | import argparse 7 | 8 | def generate_dist_map(event_seq_path: Path, calib_seq_path: Path): 9 | calib_sequences = [x.stem for x in calib_seq_path.iterdir() if x.is_dir()] 10 | for seq in calib_sequences: 11 | calib_file_path = calib_seq_path / seq / "calibration/cam_to_cam.yaml" 12 | rectify_map_path = event_seq_path / seq / "events/left/rectify_map.h5" 13 | dist_map_path = event_seq_path / seq / "events/left/rect2dist_map.npy" 14 | if calib_file_path.exists() and rectify_map_path.exists(): 15 | calibration = yaml.load(open(calib_file_path), Loader=yaml.FullLoader) 16 | intrinsic = calibration['intrinsics']['cam0']['camera_matrix'] 17 | dist_params = calibration['intrinsics']['cam0']['distortion_coeffs'] 18 | rect_intrinsic = calibration['intrinsics']['camRect0']['camera_matrix'] 19 | R_rect0 = calibration['extrinsics']["R_rect0"] 20 | fx, fy, cx, cy = intrinsic 21 | camera_matrix = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) 22 | fx, fy, cx, cy = rect_intrinsic 23 | rect_matrix = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) 24 | distortion_coeffs = np.array(dist_params) 25 | R = np.array(R_rect0) 26 | 27 | if not dist_map_path.is_file(): 28 | rect2dist_map, _ = cv2.initUndistortRectifyMap(camera_matrix, distortion_coeffs, R, rect_matrix, (640, 480), 29 | cv2.CV_32FC2) 30 | np.save(str(dist_map_path), rect2dist_map) 31 | print("save", str(dist_map_path)) 32 | else: 33 | rect2dist_map = np.load(str(dist_map_path)) 34 | randx = random.randint(0, 639) 35 | randy = random.randint(0, 479) 36 | dist_point = rect2dist_map[randy, randx] 37 | rect_point = cv2.undistortPoints(dist_point, camera_matrix, distortion_coeffs, R=R, P=rect_matrix) 38 | rect_x = rect_point[0, 0, 0] 39 | rect_y = rect_point[0, 0, 1] 40 | max_error = 0.01 41 | if abs(rect_x-randx)