├── .gitignore
├── README.md
├── assets
├── concept-show.png
├── overview.png
└── smr.png
├── loader
└── loader_dsec.py
└── utils
├── dsec_utils.py
├── gen_dist_map.py
└── setupTensor.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | .idea/
161 | *.DS_Store
162 | .vscode
163 | /runs/**
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # EVA-Flow: Towards Anytime Optical Flow Estimation with Event Cameras
2 | The official implementation of [EVA-Flow: Towards Anytime Optical Flow Estimation with Event Cameras](https://arxiv.org/abs/2307.05033).
3 |
4 |
5 |
6 |
7 | ```
8 | @misc{ye2023anytime,
9 | title={Towards Anytime Optical Flow Estimation with Event Cameras},
10 | author={Yaozu Ye and Hao Shi and Kailun Yang and Ze Wang and Xiaoting Yin and Yaonan Wang and Kaiwei Wang},
11 | year={2023},
12 | eprint={2307.05033},
13 | archivePrefix={arXiv},
14 | primaryClass={cs.CV}
15 | }
16 | ```
17 |
18 | ## Environment
19 |
20 | ```bash
21 | # create and activate conda environment
22 | conda create -n anyflow python=3.9
23 | conda activate anyflow
24 |
25 | # install dependencies for hdf5
26 | conda install blosc-hdf5-plugin=1.0.0 -c conda-forge
27 | conda install pytables
28 | pip install numba h5py hdf5plugin
29 |
30 | # install pytorch, torchvision, tensorboard
31 | # torch version: 1.12.1 or higher
32 | # torchvision version: 0.13.1 or higher
33 | pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
34 | pip install tensorboard
35 |
36 | # imageio depends on freeimage
37 | sudo apt install libfreeimage-dev
38 | # install the remaining dependencies
39 | pip install tqdm imageio opencv-python pyyaml matplotlib
40 | ```
41 |
42 | ## Dataset
43 |
44 | ### DSEC Dataset
45 |
46 | 1. Download the DSEC dataset. The dataset structure is as follows:
47 |
48 | ```text
49 | ├── DSEC
50 | ├── Test
51 | │ ├── test_calibration
52 | │ │ ├── interlaken_00_a
53 | │ │ ├── interlaken_00_b
54 | │ │ ├── ...
55 | │ ├── test_events
56 | │ │ ├── interlaken_00_a
57 | │ │ ├── interlaken_00_b
58 | │ │ ├── ...
59 | │ └── test_forward_optical_flow_timestamps
60 | └── Train
61 | ├── train_calibration
62 | │ ├── interlaken_00_c
63 | │ ├── interlaken_00_d
64 | │ ├── ...
65 | ├── train_events
66 | │ ├── interlaken_00_c
67 | │ ├── interlaken_00_d
68 | │ ├── ...
69 | └── train_optical_flow
70 | ├── thun_00_a
71 | ├── zurich_city_01_a
72 | ├── ...
73 | ```
74 |
75 | 2. Generate distortion maps for the DSEC dataset:
76 |
77 | ```bash
78 | python ./utils/gen_dist_map.py -d 'path/to/dataset/DSEC'
79 | ```
--------------------------------------------------------------------------------
/assets/concept-show.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yaozhuwa/EVA-Flow/5534f5238c9b6ea1039919a0a032bce2f161530d/assets/concept-show.png
--------------------------------------------------------------------------------
/assets/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yaozhuwa/EVA-Flow/5534f5238c9b6ea1039919a0a032bce2f161530d/assets/overview.png
--------------------------------------------------------------------------------
/assets/smr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yaozhuwa/EVA-Flow/5534f5238c9b6ea1039919a0a032bce2f161530d/assets/smr.png
--------------------------------------------------------------------------------
/loader/loader_dsec.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import torch
3 | import torch.nn.functional as F
4 | from torch.utils.data import Dataset, DataLoader
5 | from torch.utils.data.dataset import Subset
6 | from torchvision import transforms
7 | import numpy as np
8 | import h5py
9 | import weakref
10 | import imageio
11 | import random
12 | from utils.dsec_utils import flow_16bit_to_float, EventSlicer, VoxelGrid
13 | import argparse
14 | import tqdm
15 |
16 | '''
17 | DSEC Dataset
18 | └── Train
19 | ├── train_events
20 | │ ├── thun_00_a
21 | │ │ ├── events
22 | │ │ │ ├── left
23 | │ │ │ │ ├── events.h5
24 | │ │ │ │ └── rectify_map.h5
25 | │ │ │ │ └── rect2dist_map.npy
26 | ├── train_optical_flow
27 | │ ├── thun_00_a
28 | │ │ ├── flow
29 | │ │ │ ├── forward
30 | │ │ │ │ ├── xxxxxx.png
31 | │ │ │ ├── forward_timestamps.txt
32 | '''
33 |
class VoxelGridSequenceDSEC(Dataset):
    """Paired-voxel-grid view of one DSEC training sequence.

    Sample ``i`` is anchored at ``t0``, the first column of row ``i`` of
    ``flow/forward_timestamps.txt``. It consists of:

    * ``event_volume_old``: voxel grid of the events in ``[t0 - 100 ms, t0)``,
    * ``event_volume_new``: voxel grid of the events in ``[t0, t0 + 100 ms)``,
    * the ground-truth forward flow (channels first) and its validity mask.

    Expected layout::

        Train
        ├── train_events
        │   └── <sequence>            (events_sequence_path)
        │       └── events/left/{events.h5, rectify_map.h5}
        └── train_optical_flow
            └── <sequence>            (flow_sequence_path)
                └── flow/{forward/xxxxxx.png, forward_timestamps.txt}

    ``forward_timestamps.txt`` is comma separated, for example::

        # from_timestamp_us, to_timestamp_us
        49599300523, 49599400524
        49599400524, 49599500511

    When ``crop_size`` (``(height, width)``, height <= width) is given, each
    sample is randomly cropped and, with probability 1/2, horizontally flipped.
    """

    def __init__(self, events_sequence_path: Path, flow_sequence_path: Path,
                 crop_size=None):
        assert events_sequence_path.is_dir()
        assert flow_sequence_path.is_dir()

        flow_dir = flow_sequence_path / 'flow'
        ts_file = flow_dir / 'forward_timestamps.txt'
        assert ts_file.is_file()
        # One row per ground-truth flow map: (from_us, to_us).
        self.forward_ts_pair = np.genfromtxt(ts_file, delimiter=',').astype('int64')

        self.height, self.width = 480, 640
        self.voxel_grid = VoxelGrid((15, self.height, self.width), normalize=True)
        # Length of each event window; equals the optical-flow interval (100 ms).
        self.delta_t_us = 100000

        left_dir = events_sequence_path / 'events' / 'left'
        self.h5f = h5py.File(str(left_dir / 'events.h5'), 'r')
        self.event_slicer = EventSlicer(self.h5f)
        with h5py.File(str(left_dir / 'rectify_map.h5'), 'r') as rect_h5:
            self.rectify_ev_map = rect_h5['rectify_map'][()]

        # Close the events file once this dataset object is garbage collected.
        self._finalizer = weakref.finalize(self, self.close_callback, self.h5f)

        # Flow PNGs are named by an integer index; order them numerically.
        self.gt_flow_files = sorted((flow_dir / 'forward').iterdir(),
                                    key=lambda png: int(png.stem))

        self.crop_size = crop_size
        if crop_size is not None:
            assert crop_size[0] <= crop_size[1]

    def events_to_voxel_grid(self, p, t, x, y, device: str = 'cpu'):
        """Convert events to a voxel grid using ``self.voxel_grid``.

        Timestamps are shifted to start at 0 and divided by the last shifted
        timestamp before conversion. ``device`` is accepted but unused.
        """
        t_rel = (t - t[0]).astype('float32')
        t_rel = t_rel / t_rel[-1]
        columns = (
            ('p', p.astype('float32')),
            ('t', t_rel),
            ('x', x.astype('float32')),
            ('y', y.astype('float32')),
        )
        return self.voxel_grid.convert({key: torch.from_numpy(arr) for key, arr in columns})

    @staticmethod
    def load_flow(flowfile: Path):
        """Read a 16-bit DSEC flow PNG and return ``(flow, valid2D)``."""
        assert flowfile.exists()
        assert flowfile.suffix == '.png'
        decoded = flow_16bit_to_float(imageio.imread(str(flowfile), format='PNG-FI'))
        flow, valid = decoded
        return flow, valid

    @staticmethod
    def close_callback(h5f):
        """Finalizer hook: close the events HDF5 file."""
        h5f.close()

    def get_image_width_height(self):
        """Return the sensor size as ``(height, width)`` (despite the name)."""
        return self.height, self.width

    def __len__(self):
        return len(self.forward_ts_pair)

    def rectify_events(self, x: np.ndarray, y: np.ndarray):
        """Map distorted pixel coordinates to rectified ones.

        Returns ``rectify_ev_map[y, x]``: column 0 is the rectified x,
        column 1 the rectified y.
        """
        lut = self.rectify_ev_map
        assert lut.shape == (self.height, self.width, 2), lut.shape
        assert x.max() < self.width
        assert y.max() < self.height
        return lut[y, x]

    def _voxelize_window(self, t_lo, t_hi, crop_window):
        """Slice events in ``[t_lo, t_hi)``, rectify, optionally crop, voxelize."""
        events = self.event_slicer.get_events(t_lo, t_hi)
        pol = events['p']
        stamps = events['t']
        xy = self.rectify_events(events['x'], events['y'])
        xs = xy[:, 0]
        ys = xy[:, 1]

        # Drop events outside the crop window (a 2 px margin is kept for safety).
        # This only discards events; it is not the data-augmentation crop.
        if crop_window is not None:
            x_lo = crop_window['start_x'] - 2
            x_hi = crop_window['start_x'] + crop_window['crop_width'] + 2
            y_lo = crop_window['start_y'] - 2
            y_hi = crop_window['start_y'] + crop_window['crop_height'] + 2
            keep = (xs >= x_lo) & (xs < x_hi) & (ys >= y_lo) & (ys < y_hi)
            pol, stamps, xs, ys = pol[keep], stamps[keep], xs[keep], ys[keep]

        if self.voxel_grid is None:
            raise NotImplementedError
        return self.events_to_voxel_grid(pol, stamps, xs, ys)

    def get_data_sample(self, index, crop_window=None):
        """Return ``(event_volume_old, event_volume_new, flow, valid2D)``.

        ``event_volume_old`` covers the 100 ms before the flow interval and
        ``event_volume_new`` the flow interval itself. ``flow`` is channels-first
        and ``valid2D`` is returned as float.
        """
        t0 = self.forward_ts_pair[index, 0]
        flow_np, valid_np = VoxelGridSequenceDSEC.load_flow(self.gt_flow_files[index])
        sample = {
            'flow': torch.from_numpy(flow_np),
            'valid2D': torch.from_numpy(valid_np),
        }

        windows = (
            ('event_volume_old', t0 - self.delta_t_us, t0),
            ('event_volume_new', t0, t0 + self.delta_t_us),
        )
        for name, t_lo, t_hi in windows:
            sample[name] = self._voxelize_window(t_lo, t_hi, crop_window)

        # Random crop and random horizontal flip (training augmentation).
        sample['flow'] = sample['flow'].permute(2, 0, 1)
        if self.crop_size is not None:
            crop_h = self.crop_size[0]
            crop_w = self.crop_size[1]
            top = random.randrange(0, self.height - crop_h + 1)
            left = random.randrange(0, self.width - crop_w + 1)
            mirrored = random.randint(0, 1)
            hflip = transforms.RandomHorizontalFlip(mirrored)
            for key in list(sample):
                patch = sample[key][..., top:top + crop_h, left:left + crop_w]
                sample[key] = hflip(patch)
            # Mirroring the image negates the horizontal (u) flow component.
            sample['flow'][0] = (1 - 2 * mirrored) * sample['flow'][0]

        return (sample['event_volume_old'], sample['event_volume_new'],
                sample['flow'], sample['valid2D'].float())

    def __getitem__(self, idx):
        return self.get_data_sample(idx)
212 |
213 |
class VoxelGridSequenceDSECTest(Dataset):
    """Paired-voxel-grid view of one DSEC test sequence (no ground truth).

    Row ``i`` of ``forward_ts_file`` supplies the window anchor ``t0``
    (column 0) and an integer that is returned unchanged with the sample
    (column 2; presumably DSEC's ``file_index`` -- confirm against the test
    timestamp files). The file is parsed with ``np.genfromtxt(..., delimiter=',')``.

    Each sample is ``(event_volume_old, event_volume_new, column_2_value)``
    where the volumes voxelize the events in ``[t0 - 100 ms, t0)`` and
    ``[t0, t0 + 100 ms)``.

    ``events_sequence_path`` must contain
    ``events/left/{events.h5, rectify_map.h5}``.
    """

    def __init__(self, events_sequence_path: Path, forward_ts_file: Path, bins=15):
        assert events_sequence_path.is_dir()
        assert forward_ts_file.is_file()
        self.forward_ts_pair_idx = np.genfromtxt(forward_ts_file, delimiter=',').astype('int64')

        self.height, self.width = 480, 640
        self.bins = bins
        self.voxel_grid = VoxelGrid((bins, self.height, self.width), normalize=True)
        # Length of each event window; equals the optical-flow interval (100 ms).
        self.delta_t_us = 100000

        left_dir = events_sequence_path / 'events' / 'left'
        self.h5f = h5py.File(str(left_dir / 'events.h5'), 'r')
        self.event_slicer = EventSlicer(self.h5f)
        with h5py.File(str(left_dir / 'rectify_map.h5'), 'r') as rect_h5:
            self.rectify_ev_map = rect_h5['rectify_map'][()]

        # Close the events file once this dataset object is garbage collected.
        self._finalizer = weakref.finalize(self, self.close_callback, self.h5f)

    def events_to_voxel_grid(self, p, t, x, y, device: str = 'cpu'):
        """Convert events to a voxel grid using ``self.voxel_grid``.

        Timestamps are shifted to start at 0 and divided by the last shifted
        timestamp before conversion. ``device`` is accepted but unused.
        """
        t_rel = (t - t[0]).astype('float32')
        t_rel = t_rel / t_rel[-1]
        columns = (
            ('p', p.astype('float32')),
            ('t', t_rel),
            ('x', x.astype('float32')),
            ('y', y.astype('float32')),
        )
        return self.voxel_grid.convert({key: torch.from_numpy(arr) for key, arr in columns})

    @staticmethod
    def load_flow(flowfile: Path):
        """Read a 16-bit DSEC flow PNG and return ``(flow, valid2D)``."""
        assert flowfile.exists()
        assert flowfile.suffix == '.png'
        decoded = flow_16bit_to_float(imageio.imread(str(flowfile), format='PNG-FI'))
        flow, valid = decoded
        return flow, valid

    @staticmethod
    def close_callback(h5f):
        """Finalizer hook: close the events HDF5 file."""
        h5f.close()

    def get_image_width_height(self):
        """Return the sensor size as ``(height, width)`` (despite the name)."""
        return self.height, self.width

    def __len__(self):
        return len(self.forward_ts_pair_idx)

    def rectify_events(self, x: np.ndarray, y: np.ndarray):
        """Map distorted pixel coordinates to rectified ones.

        Returns ``rectify_ev_map[y, x]``: column 0 is the rectified x,
        column 1 the rectified y.
        """
        lut = self.rectify_ev_map
        assert lut.shape == (self.height, self.width, 2), lut.shape
        assert x.max() < self.width
        assert y.max() < self.height
        return lut[y, x]

    def _voxelize_window(self, t_lo, t_hi, crop_window):
        """Slice events in ``[t_lo, t_hi)``, rectify, optionally crop, voxelize."""
        events = self.event_slicer.get_events(t_lo, t_hi)
        pol = events['p']
        stamps = events['t']
        xy = self.rectify_events(events['x'], events['y'])
        xs = xy[:, 0]
        ys = xy[:, 1]

        # Drop events outside the crop window (a 2 px margin is kept for safety).
        # This only discards events; it is not a data-augmentation crop.
        if crop_window is not None:
            x_lo = crop_window['start_x'] - 2
            x_hi = crop_window['start_x'] + crop_window['crop_width'] + 2
            y_lo = crop_window['start_y'] - 2
            y_hi = crop_window['start_y'] + crop_window['crop_height'] + 2
            keep = (xs >= x_lo) & (xs < x_hi) & (ys >= y_lo) & (ys < y_hi)
            pol, stamps, xs, ys = pol[keep], stamps[keep], xs[keep], ys[keep]

        if self.voxel_grid is None:
            raise NotImplementedError
        return self.events_to_voxel_grid(pol, stamps, xs, ys)

    def get_data_sample(self, index, crop_window=None):
        """Return ``(event_volume_old, event_volume_new, output_index)``.

        ``event_volume_old`` covers the 100 ms before ``t0`` and
        ``event_volume_new`` the 100 ms starting at ``t0``; ``output_index`` is
        column 2 of the timestamp row.
        """
        t0 = self.forward_ts_pair_idx[index, 0]
        output_index = self.forward_ts_pair_idx[index, 2]

        volumes = {}
        windows = (
            ('event_volume_old', t0 - self.delta_t_us, t0),
            ('event_volume_new', t0, t0 + self.delta_t_us),
        )
        for name, t_lo, t_hi in windows:
            volumes[name] = self._voxelize_window(t_lo, t_hi, crop_window)

        return volumes['event_volume_old'], volumes['event_volume_new'], output_index

    def __getitem__(self, idx):
        return self.get_data_sample(idx)
358 |
359 |
class VoxelGridDatasetProviderDSEC:
    """Random train/validation split over every DSEC training sequence.

    Only sequences present in both ``Train/train_events`` and
    ``Train/train_optical_flow`` are used (sorted by name). Each sequence is
    opened twice: once without cropping (``dataset``) and once with
    ``crop_size`` (``dataset_cropped``). A permutation drawn from a generator
    seeded with ``random_split_seed`` assigns ``int(train_ratio * N)`` samples
    to training and the rest to validation; both index lists are sorted.

    Attributes:
        dataset, dataset_cropped: concatenations of all sequences.
        full_size, train_size, valid_size: sample counts.
        train_indices, valid_indices: sorted global sample indices.
        train_set, valid_set, train_set_cropped, valid_set_cropped: subsets.
    """

    def __init__(self, dataset_path: Path, crop_size=None, random_split_seed: int = 42, train_ratio=0.8):
        train_root = dataset_path / 'Train'
        events_root = train_root / 'train_events'
        flow_root = train_root / 'train_optical_flow'
        shared = ({entry.stem for entry in events_root.iterdir()}
                  & {entry.stem for entry in flow_root.iterdir()})
        sequence_names = sorted(shared)

        full_sequences = []
        cropped_sequences = []
        for name in sequence_names:
            ev_path = events_root / name
            fl_path = flow_root / name
            full_sequences.append(VoxelGridSequenceDSEC(ev_path, fl_path, crop_size=None))
            cropped_sequences.append(VoxelGridSequenceDSEC(ev_path, fl_path, crop_size=crop_size))

        self.dataset = torch.utils.data.ConcatDataset(full_sequences)
        self.dataset_cropped = torch.utils.data.ConcatDataset(cropped_sequences)
        self.full_size = len(self.dataset)
        self.train_size = int(train_ratio * self.full_size)
        self.valid_size = self.full_size - self.train_size

        # Seeded permutation so the split is reproducible across runs.
        rng = torch.Generator().manual_seed(random_split_seed)
        permutation = torch.randperm(self.train_size + self.valid_size, generator=rng).tolist()
        self.train_indices = sorted(permutation[:self.train_size])
        self.valid_indices = sorted(permutation[self.train_size:self.full_size])

        self.train_set = Subset(self.dataset, self.train_indices)
        self.valid_set = Subset(self.dataset, self.valid_indices)
        self.train_set_cropped = Subset(self.dataset_cropped, self.train_indices)
        self.valid_set_cropped = Subset(self.dataset_cropped, self.valid_indices)
399 |
400 |
class SingleVoxelGridSequenceDSEC(Dataset):
    """Single-voxel-grid view of one DSEC training sequence.

    Sample ``i`` covers ``[t0, t0 + 100 ms)`` where ``t0`` is the first column
    of row ``i`` of ``flow/forward_timestamps.txt``. Events are accumulated
    into a ``bins``-channel voxel grid at their *distorted* pixel positions.
    The grid is then warped onto the rectified image plane with
    ``rect2dist_map.npy``; the README shows these maps being generated with
    ``utils/gen_dist_map.py``.

    Args:
        events_sequence_path: ``Train/train_events/<sequence>``; must contain
            ``events/left/{events.h5, rectify_map.h5, rect2dist_map.npy}``.
        flow_sequence_path: ``Train/train_optical_flow/<sequence>``; must
            contain ``flow/forward/*.png`` and ``flow/forward_timestamps.txt``.
        bins: number of temporal bins (must be >= 2).
        crop_size: optional ``(height, width)`` random crop (height <= width).
            Enables a random horizontal flip as well.
        return_raw: also return the rectified raw events of the window.
        unified: slice ``[t0 - bin_time, t0 + 100 ms + bin_time)`` so the
            outermost bins also receive interpolation weight from events just
            outside the window. Falls back to the plain window when the event
            slicer returns ``None`` for the extended range.
        norm: standardize non-zero voxel values to zero mean / unit std.
    """

    def __init__(self, events_sequence_path: Path, flow_sequence_path: Path,
                 bins=15, crop_size=None, return_raw=False, unified=True, norm=True):
        assert events_sequence_path.is_dir()
        assert flow_sequence_path.is_dir()
        forward_timestamps_file = flow_sequence_path / 'flow' / 'forward_timestamps.txt'
        assert forward_timestamps_file.is_file()
        # One row per ground-truth flow map, e.g. "49599300523, 49599400524"
        # (from_timestamp_us, to_timestamp_us); '#' lines are treated as comments.
        self.forward_ts_pair = np.genfromtxt(forward_timestamps_file, delimiter=',')
        self.forward_ts_pair = self.forward_ts_pair.astype('int64')
        self.bins = bins
        self.return_raw = return_raw
        self.unified = unified
        self.norm = norm
        self.height = 480
        self.width = 640
        # Time span of one optical-flow interval (100 ms).
        self.delta_t_us = 100000

        ev_dir = events_sequence_path / 'events' / 'left'
        ev_data_file = ev_dir / 'events.h5'
        ev_rect_file = ev_dir / 'rectify_map.h5'
        rect2dist_map_path = ev_dir / 'rect2dist_map.npy'
        # For each rectified pixel, its (x, y) position in the distorted image;
        # presumably shape (H, W, 2) -- TODO confirm against gen_dist_map.py.
        self.rect2dist_map = torch.from_numpy(np.load(str(rect2dist_map_path)))
        rect2dist_map = self.rect2dist_map.clone()
        # Normalize pixel coordinates to [-1, 1] as grid_sample expects with
        # align_corners=True.
        rect2dist_map[..., 0] = 2 * rect2dist_map[..., 0] / (self.width - 1) - 1
        rect2dist_map[..., 1] = 2 * rect2dist_map[..., 1] / (self.height - 1) - 1
        self.grid = rect2dist_map.unsqueeze(0)

        h5f_location = h5py.File(str(ev_data_file), 'r')
        self.h5f = h5f_location
        self.event_slicer = EventSlicer(h5f_location)
        with h5py.File(str(ev_rect_file), 'r') as h5_rect:
            self.rectify_ev_map = h5_rect['rectify_map'][()]

        # Close the events file once this dataset object is garbage collected.
        self._finalizer = weakref.finalize(self, self.close_callback, self.h5f)

        flow_images_path = flow_sequence_path / 'flow' / 'forward'
        # Flow PNGs are named by an integer index; order them numerically.
        self.gt_flow_files = sorted(flow_images_path.iterdir(), key=lambda f: int(f.stem))

        self.crop_size = crop_size
        if self.crop_size is not None:
            assert self.crop_size[0] <= self.crop_size[1]

    def events_to_unified_voxel_grid(self, p, t, x, y, bins, t_start=None):
        """Voxelize events over ``bins + 2`` bins and drop the two padding bins.

        With ``bin_time = delta_t_us / (bins - 1)``, padded bin ``k`` is centred
        at ``t_start + k * bin_time``. Events must lie before
        ``t_start + delta_t_us + 2 * bin_time``. Each event is linearly
        interpolated between its two neighbouring bins with weight +/-1 by
        polarity. Bins ``1..bins`` are returned, so returned bin ``j`` sits at
        ``t_start + (j + 1) * bin_time``.

        Args:
            p, t, x, y: equally long numpy arrays. Polarity is mapped with
                ``2 * p - 1`` (presumably p in {0, 1}); x, y are integer
                distorted pixel coordinates; t is in microseconds.
            bins: number of returned temporal bins.
            t_start: time origin of the grid. Defaults to ``t[0]`` (legacy
                behaviour); pass the start of the sliced window so the bins are
                tied to absolute time instead of to the first surviving event.

        Returns:
            float tensor of shape ``(bins, height, width)``; all zeros when
            there are no events.
        """
        H, W = self.height, self.width
        if t.size == 0:
            return torch.zeros((bins, H, W), dtype=torch.float)
        if t_start is None:
            t_start = t[0]
        # Subtract in int64/float64 before casting to keep microsecond precision.
        t_norm = torch.from_numpy((t - t_start).astype('float32'))
        bin_time = self.delta_t_us / (bins - 1)
        total_t = 2 * bin_time + self.delta_t_us
        assert total_t > t_norm[-1]
        t_norm = (bins - 1 + 2) * (t_norm / total_t)
        x = torch.from_numpy(x.astype('float32')).int()
        y = torch.from_numpy(y.astype('float32')).int()
        # Convert polarity p (0, 1) -> (-1, 1).
        p_value = 2 * torch.from_numpy(p.astype('float32')) - 1

        t0 = t_norm.int()
        fast_voxel_grid = torch.zeros((bins + 2, H, W), dtype=torch.float, requires_grad=False)
        for tlim in [t0, t0 + 1]:
            mask = (tlim >= 0) & (tlim < bins + 2)
            index = H * W * tlim.long() + W * y.long() + x.long()
            interp_weights = p_value * (1 - (t_norm - tlim).abs())
            fast_voxel_grid.put_(index[mask], interp_weights[mask], accumulate=True)

        return fast_voxel_grid[1:-1, :, :]

    def events_to_voxel_grid(self, p, t, x, y, bins, t_start=None):
        """Voxelize events in ``[t_start, t_start + delta_t_us)`` into ``bins`` bins.

        Bin ``k`` is centred at ``t_start + k * delta_t_us / (bins - 1)``. Each
        event is linearly interpolated between its two neighbouring bins with
        weight +/-1 by polarity.

        Args:
            p, t, x, y: equally long numpy arrays (see
                ``events_to_unified_voxel_grid``).
            bins: number of temporal bins.
            t_start: time origin; defaults to ``t[0]`` (legacy behaviour).

        Returns:
            float tensor of shape ``(bins, height, width)``; all zeros when
            there are no events.
        """
        H, W = self.height, self.width
        if t.size == 0:
            return torch.zeros((bins, H, W), dtype=torch.float)
        if t_start is None:
            t_start = t[0]
        t_norm = torch.from_numpy((t - t_start).astype('float32'))
        assert self.delta_t_us > t_norm[-1]
        t_norm = (bins - 1) * (t_norm / self.delta_t_us)
        x = torch.from_numpy(x.astype('float32')).int()
        y = torch.from_numpy(y.astype('float32')).int()
        # Convert polarity p (0, 1) -> (-1, 1).
        p_value = 2 * torch.from_numpy(p.astype('float32')) - 1

        t0 = t_norm.int()
        fast_voxel_grid = torch.zeros((bins, H, W), dtype=torch.float, requires_grad=False)
        for tlim in [t0, t0 + 1]:
            mask = (tlim >= 0) & (tlim < bins)
            index = H * W * tlim.long() + W * y.long() + x.long()
            interp_weights = p_value * (1 - (t_norm - tlim).abs())
            fast_voxel_grid.put_(index[mask], interp_weights[mask], accumulate=True)

        return fast_voxel_grid

    @staticmethod
    def load_flow(flowfile: Path):
        """Read a 16-bit DSEC flow PNG and return ``(flow, valid2D)``."""
        assert flowfile.exists()
        assert flowfile.suffix == '.png'
        flow_16bit = imageio.imread(str(flowfile), format='PNG-FI')
        flow, valid2D = flow_16bit_to_float(flow_16bit)
        return flow, valid2D

    @staticmethod
    def close_callback(h5f):
        """Finalizer hook: close the events HDF5 file."""
        h5f.close()

    def get_image_width_height(self):
        """Return the sensor size as ``(height, width)`` (despite the name)."""
        return self.height, self.width

    def __len__(self):
        return len(self.forward_ts_pair)

    def rectify_events(self, x: np.ndarray, y: np.ndarray):
        """Map distorted pixel coordinates to rectified ones.

        Returns ``rectify_ev_map[y, x]`` (column 0: rectified x, column 1:
        rectified y); an empty input yields an empty ``(0, 2)`` result.
        """
        rectify_map = self.rectify_ev_map
        assert rectify_map.shape == (self.height, self.width, 2), rectify_map.shape
        if x.size > 0:
            assert x.max() < self.width
            assert y.max() < self.height
        return rectify_map[y, x]

    def warp_rectify(self, voxel_gird):
        """Resample a distorted-plane voxel grid onto the rectified image plane."""
        rect_grid = F.grid_sample(voxel_gird.unsqueeze(0), self.grid, align_corners=True)
        return rect_grid.squeeze(0)

    def get_data_sample(self, index):
        """Assemble one sample.

        Returns ``(events_vg, flow, valid2D)``, or
        ``(events_vg, flow, valid2D, raw_events)`` when ``return_raw`` is set.

        ``raw_events`` is an ``[N, 4]`` tensor (or ``None`` when no event
        survives the crop) with columns:
          [0] rectified y coordinate (relative to the crop),
          [1] rectified x coordinate (relative to the crop, flipped if the
              sample was flipped),
          [2] timestamp normalized to ``[0, bins - 1)`` over the 100 ms window,
          [3] polarity in {-1, 1}.
        """
        ts_begin_i = self.forward_ts_pair[index, 0]
        bin_time = self.delta_t_us / (self.bins - 1)
        ts_start_extend = ts_begin_i - bin_time
        ts_end_extend = ts_begin_i + self.delta_t_us + bin_time
        output = {}
        flow, valid2D = self.load_flow(self.gt_flow_files[index])
        output['flow'] = torch.from_numpy(flow)
        output['valid2D'] = torch.from_numpy(valid2D)
        raw_events = None

        # Choose the crop window up front so events outside it can be dropped.
        if self.crop_size is not None:
            rand_max_h = self.height - self.crop_size[0]
            rand_max_w = self.width - self.crop_size[1]
            h_start = random.randrange(0, rand_max_h + 1)
            w_start = random.randrange(0, rand_max_w + 1)
            W = self.crop_size[1]
            H = self.crop_size[0]
        else:
            h_start = 0
            w_start = 0
            W = self.width
            H = self.height

        is_unified = self.unified
        event_data = None
        if self.unified:
            # Events for the time-extended voxel grid.
            event_data = self.event_slicer.get_events(ts_start_extend, ts_end_extend)
        if (event_data is None) or (not self.unified):
            # The extended range is unavailable (or not requested): fall back to
            # the plain window and a plain voxel grid.
            event_data = self.event_slicer.get_events(ts_begin_i, ts_begin_i + self.delta_t_us)
            is_unified = False
        # Time origin of the voxel grid = start of the window actually sliced,
        # so bins line up with absolute time and with raw_events[:, 2].
        t_origin = ts_start_extend if is_unified else ts_begin_i

        p = event_data['p']
        t = event_data['t']
        x = event_data['x']
        y = event_data['y']
        xy_rect = self.rectify_events(x, y)
        x_rect = xy_rect[:, 0]
        y_rect = xy_rect[:, 1]

        # Drop events whose rectified position is outside the crop (2 px margin).
        x_mask = (x_rect >= w_start - 2) & (x_rect < w_start + W + 2)
        y_mask = (y_rect >= h_start - 2) & (y_rect < h_start + H + 2)
        mask_combined = x_mask & y_mask
        x = x[mask_combined]
        y = y[mask_combined]
        p = p[mask_combined]
        t = t[mask_combined]
        x_rect = x_rect[mask_combined]
        y_rect = y_rect[mask_combined]

        if is_unified:
            fast_voxel_grid = self.events_to_unified_voxel_grid(p, t, x, y, self.bins, t_start=t_origin)
            # Raw events only cover the un-extended window.
            mask_time = (t >= ts_begin_i) & (t < ts_begin_i + self.delta_t_us)
            x_rect = x_rect[mask_time]
            y_rect = y_rect[mask_time]
            p = p[mask_time]
            t = t[mask_time]
        else:
            fast_voxel_grid = self.events_to_voxel_grid(p, t, x, y, self.bins, t_start=t_origin)

        event_len = t.size
        if self.return_raw and event_len != 0:
            raw_events = torch.zeros(event_len, 4)
            raw_events[:, 0] = torch.from_numpy(x_rect)
            raw_events[:, 1] = torch.from_numpy(y_rect)
            raw_events[:, 3] = 2 * torch.from_numpy(p.astype('float32')) - 1
            raw_events[:, 2] = torch.from_numpy(
                (t - ts_begin_i).astype('float32')) / self.delta_t_us * (self.bins - 1)

        # Warp to rectify the lens distortion.
        output['events_vg'] = self.warp_rectify(fast_voxel_grid)

        p_flip = 0
        # Random crop and random horizontal flip.
        output['flow'] = output['flow'].permute(2, 0, 1)
        if self.crop_size is not None:
            p_flip = random.randint(0, 1)
            flip_param = 1 - 2 * p_flip
            flip = transforms.RandomHorizontalFlip(p_flip)
            for key in output.keys():
                output[key] = output[key][..., h_start:h_start + self.crop_size[0], w_start:w_start + self.crop_size[1]]
                output[key] = flip(output[key])
            # Mirroring the image negates the horizontal (u) flow component.
            output['flow'][0] = flip_param * output['flow'][0]

        if self.return_raw and (raw_events is not None) and raw_events.size()[0] != 0:
            crop_mask = (raw_events[:, 0] >= w_start) & (raw_events[:, 0] < w_start + W - 1)
            crop_mask &= (raw_events[:, 1] >= h_start) & (raw_events[:, 1] < h_start + H - 1)
            if torch.any(crop_mask):
                raw_events = raw_events[crop_mask]
                raw_events[:, 0] -= w_start
                raw_events[:, 1] -= h_start
                if p_flip == 1:
                    raw_events[:, 0] = W - 1 - raw_events[:, 0]
                # Reorder [x, y, t, p] -> [y, x, t, p].
                raw_events = raw_events[:, [1, 0, 2, 3]]
            else:
                raw_events = None
        else:
            raw_events = None

        if self.norm:
            mask = torch.nonzero(output['events_vg'], as_tuple=True)
            if mask[0].size()[0] > 0:
                mean = output['events_vg'][mask].mean()
                std = output['events_vg'][mask].std()
                if std > 0:
                    output['events_vg'][mask] = (output['events_vg'][mask] - mean) / std
                else:
                    output['events_vg'][mask] = output['events_vg'][mask] - mean

        if self.return_raw:
            return output['events_vg'], output['flow'], output['valid2D'].float(), raw_events
        return output['events_vg'], output['flow'], output['valid2D'].float()

    def __getitem__(self, idx):
        return self.get_data_sample(idx)
693 |
class SingleVoxelGridDatasetProviderDSEC:
    """Train/validation splits over all DSEC training sequences.

    A sequence is used when it has a directory both under
    ``<dataset_path>/Train/train_events`` and under
    ``<dataset_path>/Train/train_optical_flow``. For every such sequence two
    ``SingleVoxelGridSequenceDSEC`` datasets are built: one with
    ``crop_size=None`` and one with ``crop_size``. Both collections are
    concatenated and split with the same seeded permutation. As a result
    ``train_set`` / ``train_set_cropped`` (and likewise ``valid_set`` /
    ``valid_set_cropped``) index the same samples.
    """

    def __init__(self, dataset_path: Path, bins=15, crop_size=(288, 384),
                 random_split_seed: int = 42, train_ratio=0.8, return_raw=False, unified=True, norm=True):
        """
        :param dataset_path: DSEC root directory.
        :param bins: number of voxel-grid channels.
        :param crop_size: (height, width) used by the ``*_cropped`` datasets. The
            default is a tuple so that one mutable list is not shared between calls.
        :param random_split_seed: seed of the train/validation permutation.
        :param train_ratio: fraction of all samples assigned to the training split.
        :param return_raw: forwarded to every ``SingleVoxelGridSequenceDSEC``.
        :param unified: forwarded to every ``SingleVoxelGridSequenceDSEC``.
        :param norm: forwarded to every ``SingleVoxelGridSequenceDSEC``.
        """
        events_sequence_path = dataset_path / 'Train' / 'train_events'
        flow_sequence_path = dataset_path / 'Train' / 'train_optical_flow'

        def sequence_dirs(root):
            # Only sub-directories are sequences; stray files (e.g. downloaded
            # archives) are ignored. ``name`` rather than ``stem`` keeps
            # directory names that contain a dot intact.
            return {entry.name for entry in root.iterdir() if entry.is_dir()}

        valid_sequences = sorted(sequence_dirs(events_sequence_path) & sequence_dirs(flow_sequence_path))

        shared_kwargs = dict(bins=bins, return_raw=return_raw, unified=unified, norm=norm)
        dataset_sequences = list()
        dataset_sequences_cropped = list()
        for sequence in valid_sequences:
            events_dir = events_sequence_path / sequence
            flow_dir = flow_sequence_path / sequence
            dataset_sequences.append(
                SingleVoxelGridSequenceDSEC(events_dir, flow_dir, crop_size=None, **shared_kwargs))
            dataset_sequences_cropped.append(
                SingleVoxelGridSequenceDSEC(events_dir, flow_dir, crop_size=crop_size, **shared_kwargs))

        self.dataset = torch.utils.data.ConcatDataset(dataset_sequences)
        self.dataset_cropped = torch.utils.data.ConcatDataset(dataset_sequences_cropped)
        self.full_size = len(self.dataset)
        self.train_size = int(train_ratio * self.full_size)
        self.valid_size = self.full_size - self.train_size

        # One seeded permutation shared by the cropped and the uncropped datasets.
        generator = torch.Generator().manual_seed(random_split_seed)
        indices = torch.randperm(self.full_size, generator=generator).tolist()

        self.train_indices = sorted(indices[:self.train_size])
        self.valid_indices = sorted(indices[self.train_size:])
        self.train_set = Subset(self.dataset, self.train_indices)
        self.valid_set = Subset(self.dataset, self.valid_indices)
        self.train_set_cropped = Subset(self.dataset_cropped, self.train_indices)
        self.valid_set_cropped = Subset(self.dataset_cropped, self.valid_indices)
742 |
class SingleVoxelGridTestSequenceDSEC(Dataset):
    """Voxel-grid dataset over one DSEC test sequence (event input only, no flow).

    Every item corresponds to one row of ``forward_ts_file``. The events of
    the left event camera inside ``[from_timestamp_us, from_timestamp_us +
    delta_t_us)`` are accumulated into a ``bins``-channel voxel grid in the
    distorted sensor frame. The grid is then warped into the rectified frame
    with ``rect2dist_map.npy``.
    """

    def __init__(self, events_sequence_path: Path, forward_ts_file: Path, bins=15, return_raw=False, unified=True, norm=True):
        """
        :param events_sequence_path: directory of one sequence; must contain
            ``events/left/events.h5``, ``events/left/rectify_map.h5`` and
            ``events/left/rect2dist_map.npy``.
        :param forward_ts_file: CSV with rows ``from_timestamp_us, to_timestamp_us, file_index``.
            Lines starting with ``#`` are skipped by ``np.genfromtxt``.
        :param bins: number of temporal channels of the voxel grid.
        :param return_raw: also return the rectified raw events of each window.
        :param unified: build the grid from a window extended by one bin on each side
            and drop the two outer bins. Falls back to the plain window when the
            extended window cannot be read.
        :param norm: standardise the non-zero cells of the output grid.
        """
        assert events_sequence_path.is_dir()
        assert forward_ts_file.is_file()
        # Layout of one sequence directory:
        #   <events_sequence_path>/events/left/events.h5
        #   <events_sequence_path>/events/left/rectify_map.h5
        #   <events_sequence_path>/events/left/rect2dist_map.npy
        #
        # forward_ts_file, e.g.:
        #   # from_timestamp_us, to_timestamp_us, file_index
        #   51648500652, 51648600574, 820
        #   51649000383, 51649100410, 830
        #   ...
        # reshape(-1, 3) keeps a (rows, 3) table even when the file has a single row
        # (genfromtxt returns a 1-D array then).
        self.forward_ts_pair_idx = np.genfromtxt(forward_ts_file, delimiter=',').astype('int64').reshape(-1, 3)

        self.height = 480
        self.width = 640
        self.bins = bins
        # Time span of one optical-flow sample: 100000 us.
        self.delta_t_us = 100000
        self.return_raw = return_raw
        self.unified = unified
        self.norm = norm

        ev_dir = events_sequence_path / 'events' / 'left'
        ev_data_file = ev_dir / 'events.h5'
        ev_rect_file = ev_dir / 'rectify_map.h5'

        self.h5f = h5py.File(str(ev_data_file), 'r')
        # Register the close callback right away so the file is released even if
        # a later step of __init__ raises.
        self._finalizer = weakref.finalize(self, self.close_callback, self.h5f)
        self.event_slicer = EventSlicer(self.h5f)
        with h5py.File(str(ev_rect_file), 'r') as h5_rect:
            self.rectify_ev_map = h5_rect['rectify_map'][()]

        rect2dist_map_path = ev_dir / 'rect2dist_map.npy'
        self.rect2dist_map = torch.from_numpy(np.load(str(rect2dist_map_path)))
        # Pixel coordinates -> [-1, 1], the convention of grid_sample(align_corners=True).
        rect2dist_map = self.rect2dist_map.clone()
        rect2dist_map[..., 0] = 2 * rect2dist_map[..., 0] / (self.width - 1) - 1
        rect2dist_map[..., 1] = 2 * rect2dist_map[..., 1] / (self.height - 1) - 1
        self.grid = rect2dist_map.unsqueeze(0)

    def events_to_unified_voxel_grid(self, p, t, x, y, bins):
        """Accumulate events of the extended window and return the ``bins`` inner channels.

        Timestamps are measured from the first event ``t[0]``. They are scaled so
        that the extended span ``delta_t_us + 2 * bin_time`` corresponds to
        ``bins + 1`` bin steps. Each event adds ``2 * p - 1``, split linearly
        between its two neighbouring bins at pixel ``(y, x)``. The first and last
        of the ``bins + 2`` bins are dropped.
        NOTE(review): the bins are aligned to ``t[0]``, not to the requested window
        start, so a late first event shifts the binning slightly.

        :param p: polarities (0/1) as a numpy array.
        :param t: timestamps in microseconds, ascending, as a numpy array.
        :param x: integer pixel columns in the distorted frame.
        :param y: integer pixel rows in the distorted frame.
        :param bins: number of returned channels (callers pass ``self.bins``).
        :return: float tensor of shape (bins, height, width).
        """
        H, W = self.height, self.width
        if t.size == 0:
            # No events: an empty grid instead of an IndexError on t[0].
            return torch.zeros((bins, H, W), dtype=torch.float)
        t_norm = torch.from_numpy((t - t[0]).astype('float32'))
        bin_time = self.delta_t_us / (self.bins - 1)
        total_t = 2 * bin_time + self.delta_t_us
        assert total_t > t_norm[-1]
        t_norm = (self.bins - 1 + 2) * (t_norm / total_t)
        x = torch.from_numpy(x.astype('float32')).int()
        y = torch.from_numpy(y.astype('float32')).int()
        # Polarity (0, 1) -> (-1, 1).
        p_value = 2 * torch.from_numpy(p.astype('float32')) - 1

        t0 = t_norm.int()
        fast_voxel_grid = torch.zeros((bins + 2, H, W), dtype=torch.float, requires_grad=False)
        for tlim in [t0, t0 + 1]:
            mask = tlim < self.bins + 2
            index = H * W * tlim.long() + W * y.long() + x.long()
            interp_weights = p_value * (1 - (t_norm - tlim).abs())
            fast_voxel_grid.put_(index[mask], interp_weights[mask], accumulate=True)

        # Drop the two padding bins that only cover the extension.
        return fast_voxel_grid[1:-1, :, :]

    def events_to_voxel_grid(self, p, t, x, y, bins):
        """Accumulate events of the plain window into a (bins, height, width) grid.

        Timestamps are measured from the first event ``t[0]`` and scaled so that
        ``delta_t_us`` corresponds to ``bins - 1`` bin steps. Each event adds
        ``2 * p - 1``, split linearly between its two neighbouring bins.
        Arguments are as for ``events_to_unified_voxel_grid``.
        """
        H, W = self.height, self.width
        if t.size == 0:
            # No events: an empty grid instead of an IndexError on t[0].
            return torch.zeros((bins, H, W), dtype=torch.float)
        t_norm = torch.from_numpy((t - t[0]).astype('float32'))
        assert self.delta_t_us > t_norm[-1]
        t_norm = (self.bins - 1) * (t_norm / self.delta_t_us)
        x = torch.from_numpy(x.astype('float32')).int()
        y = torch.from_numpy(y.astype('float32')).int()
        # Polarity (0, 1) -> (-1, 1).
        p_value = 2 * torch.from_numpy(p.astype('float32')) - 1

        t0 = t_norm.int()
        fast_voxel_grid = torch.zeros((bins, H, W), dtype=torch.float, requires_grad=False)
        for tlim in [t0, t0 + 1]:
            mask = tlim < self.bins
            index = H * W * tlim.long() + W * y.long() + x.long()
            interp_weights = p_value * (1 - (t_norm - tlim).abs())
            fast_voxel_grid.put_(index[mask], interp_weights[mask], accumulate=True)

        return fast_voxel_grid

    @staticmethod
    def load_flow(flowfile: Path):
        """Read a 16-bit DSEC flow PNG and decode it with ``flow_16bit_to_float``."""
        assert flowfile.exists()
        assert flowfile.suffix == '.png'
        flow_16bit = imageio.imread(str(flowfile), format='PNG-FI')
        flow, valid2D = flow_16bit_to_float(flow_16bit)
        return flow, valid2D

    @staticmethod
    def close_callback(h5f):
        """Close the events HDF5 file (registered with ``weakref.finalize``)."""
        h5f.close()

    def get_image_width_height(self):
        """Return ``(height, width)``; note the order, despite the method name."""
        return self.height, self.width

    def __len__(self):
        """Number of rows in the timestamp file."""
        return len(self.forward_ts_pair_idx)

    def rectify_events(self, x: np.ndarray, y: np.ndarray):
        """Map distorted pixel positions to rectified coordinates.

        :return: array of shape (N, 2) holding the rectified (x, y) of each event.
        """
        rectify_map = self.rectify_ev_map
        assert rectify_map.shape == (self.height, self.width, 2), rectify_map.shape
        if x.size > 0:
            # max() of an empty array raises, so bounds are checked only when there are events.
            assert x.max() < self.width
            assert y.max() < self.height
        return rectify_map[y, x]

    def warp_rectify(self, voxel_gird):
        """Resample a (C, H, W) grid so that output pixel (i, j) reads the input at ``rect2dist_map[i, j]``."""
        rect_grid = F.grid_sample(voxel_gird.unsqueeze(0), self.grid, align_corners=True)
        return rect_grid.squeeze(0)

    def _build_raw_events(self, x_rect, y_rect, p, t, ts_begin_i, h_start, w_start, H, W):
        """Pack the window's events as an (N, 4) ``[y, x, t, p]`` tensor, or None if none remain.

        ``t`` is rescaled to ``(t - ts_begin_i) / delta_t_us * (bins - 1)`` and
        ``p`` to ``2 * p - 1``. Only events inside the crop
        ``[w_start, w_start + W - 1) x [h_start, h_start + H - 1)`` are kept, and
        their coordinates are made relative to the crop origin.
        """
        if t.size == 0:
            return None
        raw_events = torch.zeros(t.size, 4)
        raw_events[:, 0] = torch.from_numpy(x_rect)
        raw_events[:, 1] = torch.from_numpy(y_rect)
        raw_events[:, 2] = torch.from_numpy(
            (t - ts_begin_i).astype('float32')) / self.delta_t_us * (self.bins - 1)
        raw_events[:, 3] = 2 * torch.from_numpy(p.astype('float32')) - 1

        crop_mask = (raw_events[:, 0] >= w_start) & (raw_events[:, 0] < w_start + W - 1)
        crop_mask &= (raw_events[:, 1] >= h_start) & (raw_events[:, 1] < h_start + H - 1)
        if not torch.any(crop_mask):
            return None
        raw_events = raw_events[crop_mask]
        raw_events[:, 0] -= w_start
        raw_events[:, 1] -= h_start
        # Exchange [x, y, t, p] to [y, x, t, p] for event_image_converter.
        return raw_events[:, [1, 0, 2, 3]]

    @staticmethod
    def _normalize_nonzero(voxel_grid):
        """Standardise the non-zero cells of ``voxel_grid`` in place.

        Subtracts the mean of those cells and divides by their std when it is > 0.
        """
        mask = torch.nonzero(voxel_grid, as_tuple=True)
        if mask[0].size()[0] == 0:
            return
        values = voxel_grid[mask]
        mean = values.mean()
        std = values.std()
        if std > 0:
            voxel_grid[mask] = (values - mean) / std
        else:
            voxel_grid[mask] = values - mean

    def get_data_sample(self, index):
        """Build the voxel grid for row ``index`` of the timestamp file.

        :return: ``(voxel_grid, file_index)``, or, when ``return_raw`` is set,
            ``(voxel_grid, file_index, raw_events)``.
            - ``voxel_grid`` has shape (bins, height, width).
            - ``raw_events`` is None when no event falls inside the frame.
              Otherwise it is an (N, 4) tensor with columns ``[y, x, t, p]``:
              rectified coordinates, time scaled to ``[0, bins - 1)`` and
              polarity ``2 * p - 1``.
        :raises RuntimeError: if not even the plain window can be read.
        """
        ts_begin_i = self.forward_ts_pair_idx[index, 0]
        output_name_idx = self.forward_ts_pair_idx[index, 2]
        # Column 1 (to_timestamp_us) is not used: the window length is fixed to delta_t_us.
        bin_time = self.delta_t_us / (self.bins - 1)

        # No cropping at test time: the "crop" is the whole frame.
        h_start = 0
        w_start = 0
        H = self.height
        W = self.width

        is_unified = self.unified
        event_data = None
        if self.unified:
            # Window extended by one bin on each side (see events_to_unified_voxel_grid).
            event_data = self.event_slicer.get_events(ts_begin_i - bin_time,
                                                      ts_begin_i + self.delta_t_us + bin_time)
        if event_data is None:
            # Unified mode is off, or the extended window is not available:
            # use the plain window and the plain voxel grid.
            event_data = self.event_slicer.get_events(ts_begin_i, ts_begin_i + self.delta_t_us)
            is_unified = False
        if event_data is None:
            raise RuntimeError('no events available for the window starting at {} us (sample {})'.format(
                ts_begin_i, index))

        p = event_data['p']
        t = event_data['t']
        x = event_data['x']
        y = event_data['y']
        xy_rect = self.rectify_events(x, y)
        x_rect = xy_rect[:, 0]
        y_rect = xy_rect[:, 1]

        # Keep events whose rectified position lies in the frame plus a 2-pixel margin.
        in_frame = ((x_rect >= w_start - 2) & (x_rect < w_start + W + 2) &
                    (y_rect >= h_start - 2) & (y_rect < h_start + H + 2))
        x, y, p, t = x[in_frame], y[in_frame], p[in_frame], t[in_frame]
        x_rect, y_rect = x_rect[in_frame], y_rect[in_frame]

        if is_unified:
            fast_voxel_grid = self.events_to_unified_voxel_grid(p, t, x, y, self.bins)
            # Raw events are reported for the plain window only.
            in_window = (t >= ts_begin_i) & (t < ts_begin_i + self.delta_t_us)
            x_rect, y_rect, p, t = x_rect[in_window], y_rect[in_window], p[in_window], t[in_window]
        else:
            fast_voxel_grid = self.events_to_voxel_grid(p, t, x, y, self.bins)

        # Warp from the distorted to the rectified frame.
        out_voxel_grid = self.warp_rectify(fast_voxel_grid)

        raw_events = None
        if self.return_raw:
            raw_events = self._build_raw_events(x_rect, y_rect, p, t, ts_begin_i, h_start, w_start, H, W)

        if self.norm:
            self._normalize_nonzero(out_voxel_grid)

        if self.return_raw:
            return out_voxel_grid, output_name_idx, raw_events
        return out_voxel_grid, output_name_idx

    def __getitem__(self, idx):
        """Return ``get_data_sample(idx)``."""
        return self.get_data_sample(idx)
990 |
991 |
def collate_raw_events(data):
    """Collate ``(voxel_grid, flow, valid_mask, raw_events)`` samples into one batch.

    Voxel grids, flows and masks are stacked with ``default_collate``. Raw
    events have a different length per sample, so they are concatenated along
    the event axis instead. The sample's position in the batch is prepended as
    an extra first column.

    :param data: sequence of samples; element 3 (``raw_events``) may be None.
    :return: ``(voxel_grid, flow, valid_mask, raw_events)``. ``raw_events`` is
        None when no sample carried events; otherwise it is an (N, 5) tensor
        with columns ``[batch_idx, y, x, t, p]``, N being the total number of
        events in the batch.
    """
    from torch.utils.data.dataloader import default_collate

    voxel_grids, flows, masks = [], [], []
    events_with_batch_idx = []
    for batch_idx, sample in enumerate(data):
        voxel_grids.append(sample[0])
        flows.append(sample[1])
        masks.append(sample[2])
        raw_events = sample[3]
        if raw_events is not None:
            # Build the batch-index column with the events' own dtype and device so
            # the concatenation also works for non-float32 or non-CPU event tensors.
            batch_column = torch.full((len(raw_events), 1), batch_idx,
                                      dtype=raw_events.dtype, device=raw_events.device)
            events_with_batch_idx.append(torch.cat([batch_column, raw_events], 1))

    out_voxel_grid = default_collate(voxel_grids)
    out_flow = default_collate(flows)
    out_mask = default_collate(masks)
    out_raw_events = torch.cat(events_with_batch_idx, dim=0) if events_with_batch_idx else None
    return out_voxel_grid, out_flow, out_mask, out_raw_events
1027 |
1028 |
def str2bool(value):
    """Parse a command-line boolean.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``"False"``) into ``True``. This converter accepts the usual spellings instead.

    :param value: the raw argument string (a bool is returned unchanged).
    :raises argparse.ArgumentTypeError: for unrecognised strings.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_path', type=str, default="/media/yyz/FastDisk/Dataset/DSEC")
    parser.add_argument('--bins', type=int, default=15)
    # Two integers: crop height and crop width.
    parser.add_argument('--crop_size', type=int, nargs=2, default=[288, 384])
    parser.add_argument('--random_split_seed', type=int, default=42)
    parser.add_argument('--train_ratio', type=float, default=0.8)
    parser.add_argument('--return_raw', type=str2bool, default=True)
    parser.add_argument('--unified', type=str2bool, default=True)
    parser.add_argument('--norm', type=str2bool, default=True)
    args = parser.parse_args()

    # Imported only for its side effects; NOTE(review): what it configures is not visible here.
    import utils.setupTensor

    dsec_provider = SingleVoxelGridDatasetProviderDSEC(Path(args.dataset_path), bins=args.bins,
                                                       crop_size=args.crop_size,
                                                       random_split_seed=args.random_split_seed,
                                                       train_ratio=args.train_ratio,
                                                       return_raw=args.return_raw,
                                                       unified=args.unified,
                                                       norm=args.norm)

    dataset = dsec_provider.train_set_cropped
    loader = DataLoader(dataset, batch_size=3, num_workers=4, drop_last=False, shuffle=True,
                        collate_fn=collate_raw_events)
    # Smoke test: iterate once over the cropped training split and move every batch to the GPU.
    for voxel_grid, flow, valid_mask, events in tqdm.tqdm(loader):
        voxel_grid, flow, valid_mask = [x.cuda() for x in (voxel_grid, flow, valid_mask)]
        if events is not None:
            events = events.cuda()
1058 |
1059 |
--------------------------------------------------------------------------------
/utils/dsec_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import hdf5plugin
4 | import h5py
5 | from typing import Dict, Tuple
6 | from numba import jit
7 | import math
8 | from pathlib import Path
9 |
def flow_16bit_to_float(flow_16bit: np.ndarray):
    """Decode a DSEC 16-bit optical-flow array.

    :param flow_16bit: uint16 array of shape (H, W, 3). Channels 0 and 1 hold the
        encoded flow components; channel 2 is 1 for valid pixels and 0 otherwise.
    :return: ``(flow_map, valid2D)``.
        - ``flow_map``: float64 array (H, W, 2) holding ``(raw - 2**15) / 128``
          at valid pixels and 0 elsewhere.
        - ``valid2D``: boolean (H, W) validity mask.
    """
    assert flow_16bit.dtype == np.uint16
    assert flow_16bit.ndim == 3
    height, width, channels = flow_16bit.shape
    assert channels == 3

    valid2D = flow_16bit[..., 2] == 1
    assert valid2D.shape == (height, width)
    # The validity channel may only contain 0 or 1.
    assert np.all(flow_16bit[~valid2D, -1] == 0)

    flow_map = np.zeros((height, width, 2))
    encoded = flow_16bit[valid2D][:, :2].astype('float')
    flow_map[valid2D] = (encoded - 2 ** 15) / 128
    return flow_map, valid2D
28 |
29 |
class EventSlicer:
    """Read the events that fall inside a time window.

    Events are read from one of two sources:
    - an open DSEC ``events.h5`` file, using the datasets ``events/p``,
      ``events/x``, ``events/y``, ``events/t``, ``ms_to_idx`` and ``t_offset``;
    - when ``folder`` is given, the ``events_{p,x,y,t}.npy``, ``ms_to_idx.npy``
      and ``t_offset.npy`` files in that folder.

    Stored timestamps are relative to ``t_offset``. ``get_events`` takes and
    returns absolute microseconds.
    """

    def __init__(self, h5f: h5py.File = None, folder: Path = None):
        """
        :param h5f: open HDF5 events file; used when ``folder`` is None.
        :param folder: directory with the ``.npy`` export; takes precedence over ``h5f``.
        :raises ValueError: if neither ``h5f`` nor ``folder`` is given.
        """
        if folder is None:
            if h5f is None:
                raise ValueError('EventSlicer requires either an h5py file or a folder')
            self.h5f = h5f
            self.events = dict()
            for dset_str in ['p', 'x', 'y', 't']:
                self.events[dset_str] = self.h5f['events/{}'.format(dset_str)]

            # This is the mapping from milliseconds to event index:
            # It is defined such that
            # (1) t[ms_to_idx[ms]] >= ms*1000
            # (2) t[ms_to_idx[ms] - 1] < ms*1000
            # ,where 'ms' is the time in milliseconds and 't' the event timestamps in microseconds.
            #
            # As an example, given 't' and 'ms':
            # t:    0     500    2100    5000    5000    7100    7200    7200    8100    9000
            # ms:   0       1       2       3       4       5       6       7       8       9
            #
            # we get
            #
            # ms_to_idx:
            #       0       2       2       3       3       3       5       5       8       9
            self.ms_to_idx = np.asarray(self.h5f['ms_to_idx'], dtype='int64')

            self.t_offset = int(h5f['t_offset'][()])
            self.t_final = int(self.events['t'][-1]) + self.t_offset
        else:
            self.events = dict()
            for dset_str in ['p', 'x', 'y', 't']:
                file_name = "events_" + dset_str + ".npy"
                self.events[dset_str] = np.load(str(folder / file_name))
            self.ms_to_idx = np.load(str(folder / 'ms_to_idx.npy'))
            self.t_offset = int(np.load(str(folder / 't_offset.npy'))[()])
            self.t_final = int(self.events['t'][-1]) + self.t_offset

    def get_final_time_us(self):
        """Absolute timestamp (us) of the last stored event."""
        return self.t_final

    def get_events(self, t_start_us: int, t_end_us: int) -> Dict[str, np.ndarray]:
        """Get events (p, x, y, t) within the specified time window
        Parameters
        ----------
        t_start_us: start time in microseconds
        t_end_us: end time in microseconds
        Returns
        -------
        events: dictionary of (p, x, y, t) with ``t_start_us <= t < t_end_us``, or None
            if the window cannot be retrieved (it starts before the first or ends after
            the last millisecond covered by ``ms_to_idx``)
        """
        assert t_start_us < t_end_us

        # We assume that the times are top-off-day, hence subtract offset:
        t_start_us -= self.t_offset
        t_end_us -= self.t_offset

        t_start_ms, t_end_ms = self.get_conservative_window_ms(t_start_us, t_end_us)
        t_start_ms_idx = self.ms2idx(t_start_ms)
        t_end_ms_idx = self.ms2idx(t_end_ms)

        if t_start_ms_idx is None or t_end_ms_idx is None:
            # Cannot guarantee window size anymore
            return None

        events = dict()
        time_array_conservative = np.asarray(self.events['t'][t_start_ms_idx:t_end_ms_idx])
        idx_start_offset, idx_end_offset = self.get_time_indices_offsets(time_array_conservative, t_start_us, t_end_us)
        t_start_us_idx = t_start_ms_idx + idx_start_offset
        t_end_us_idx = t_start_ms_idx + idx_end_offset
        # Again add t_offset to get gps time
        events['t'] = time_array_conservative[idx_start_offset:idx_end_offset] + self.t_offset
        for dset_str in ['p', 'x', 'y']:
            events[dset_str] = np.asarray(self.events[dset_str][t_start_us_idx:t_end_us_idx])
            assert events[dset_str].size == events['t'].size
        return events

    @staticmethod
    def get_conservative_window_ms(ts_start_us: int, ts_end_us) -> Tuple[int, int]:
        """Compute a conservative time window of time with millisecond resolution.
        We have a time to index mapping for each millisecond. Hence, we need
        to compute the lower and upper millisecond to retrieve events.
        Parameters
        ----------
        ts_start_us: start time in microseconds
        ts_end_us: end time in microseconds
        Returns
        -------
        window_start_ms: conservative start time in milliseconds
        window_end_ms: conservative end time in milliseconds
        """
        assert ts_end_us > ts_start_us
        window_start_ms = math.floor(ts_start_us / 1000)
        window_end_ms = math.ceil(ts_end_us / 1000)
        return window_start_ms, window_end_ms

    @staticmethod
    @jit(nopython=True)
    def get_time_indices_offsets(
            time_array: np.ndarray,
            time_start_us: int,
            time_end_us: int) -> Tuple[int, int]:
        """Compute index offset of start and end timestamps in microseconds
        Parameters
        ----------
        time_array:     timestamps (in us) of the events
        time_start_us:  start timestamp (in us)
        time_end_us:    end timestamp (in us)
        Returns
        -------
        idx_start:  Index within this array corresponding to time_start_us
        idx_end:    Index within this array corresponding to time_end_us
        such that (in non-edge cases)
        time_array[idx_start] >= time_start_us
        time_array[idx_end] >= time_end_us
        time_array[idx_start - 1] < time_start_us
        time_array[idx_end - 1] < time_end_us
        this means that
        time_start_us <= time_array[idx_start:idx_end] < time_end_us
        An empty ``time_array`` yields (0, 0).
        """

        assert time_array.ndim == 1

        if time_array.size == 0:
            # No events in the conservative window. time_array[-1] below would read
            # out of bounds (numba does not bounds-check by default).
            return 0, 0

        idx_start = -1
        if time_array[-1] < time_start_us:
            # This can happen in extreme corner cases. E.g.
            # time_array[0] = 1016
            # time_array[-1] = 1984
            # time_start_us = 1990
            # time_end_us = 2000

            # Return same index twice: array[x:x] is empty.
            return time_array.size, time_array.size
        else:
            for idx_from_start in range(0, time_array.size, 1):
                if time_array[idx_from_start] >= time_start_us:
                    idx_start = idx_from_start
                    break
        assert idx_start >= 0

        idx_end = time_array.size
        for idx_from_end in range(time_array.size - 1, -1, -1):
            if time_array[idx_from_end] >= time_end_us:
                idx_end = idx_from_end
            else:
                break

        assert time_array[idx_start] >= time_start_us
        if idx_end < time_array.size:
            assert time_array[idx_end] >= time_end_us
        if idx_start > 0:
            assert time_array[idx_start - 1] < time_start_us
        if idx_end > 0:
            assert time_array[idx_end - 1] < time_end_us
        return idx_start, idx_end

    def ms2idx(self, time_ms: int) -> int:
        """Map a millisecond (relative to ``t_offset``) to an event index.

        Returns None when ``time_ms`` is outside ``ms_to_idx`` (negative or past
        the end). ``get_events`` then reports the window as unavailable instead
        of failing. A negative index would otherwise wrap around to the end of
        ``ms_to_idx`` and select the wrong events.
        """
        if time_ms < 0 or time_ms >= self.ms_to_idx.size:
            return None
        return self.ms_to_idx[time_ms]
189 |
190 | from enum import Enum, auto
191 |
192 |
class RepresentationType(Enum):
    """Event representations known to this module."""

    VOXEL = 1
    STEPAN = 2
196 |
197 |
class EventRepresentation:
    """Base class for converters from a batch of events to a tensor representation."""

    def __init__(self):
        pass

    def convert(self, events):
        """Return the representation of ``events``; subclasses must override this."""
        raise NotImplementedError


class VoxelGrid(EventRepresentation):
    """Trilinear voxel-grid accumulation of events, optionally standardised.

    :param input_size: (channels, height, width) of the grid.
    :param normalize: standardise the non-zero cells (subtract their mean, and
        divide by their std when it is > 0).
    """

    def __init__(self, input_size: tuple, normalize: bool):
        assert len(input_size) == 3
        self.voxel_grid = torch.zeros((input_size), dtype=torch.float, requires_grad=False)
        self.nb_channels = input_size[0]
        self.normalize = normalize

    def convert(self, events):
        """Accumulate ``events`` into a fresh (C, H, W) grid.

        :param events: mapping with 1-D tensors ``'p'`` (polarity 0/1), ``'t'``
            (ascending timestamps), ``'x'`` and ``'y'`` (pixel coordinates,
            possibly fractional), all on the same device.
        :return: float grid on the events' device.
            - Each event adds ``2 * p - 1``, split trilinearly over the 8
              surrounding (t, y, x) cells.
            - Timestamps are rescaled so that the first event lies on channel 0
              and the last on channel C - 1; when all timestamps are equal,
              every event goes to channel 0.
            - Contributions falling outside the grid are dropped.
            - Empty input yields a zero grid.
        """
        C, H, W = self.voxel_grid.shape
        with torch.no_grad():
            # The zero template is moved to (and kept on) the events' device.
            self.voxel_grid = self.voxel_grid.to(events['p'].device)
            voxel_grid = self.voxel_grid.clone()

            t = events['t']
            if t.numel() == 0:
                # Nothing to accumulate (and t[0] would raise).
                return voxel_grid
            duration = t[-1] - t[0]
            if duration == 0:
                # All events share one timestamp: avoid 0/0 = NaN.
                t_norm = torch.zeros_like(t, dtype=torch.float)
            else:
                t_norm = (C - 1) * (t - t[0]) / duration

            x0 = events['x'].int()
            y0 = events['y'].int()
            t0 = t_norm.int()

            # Polarity (0, 1) -> (-1, 1).
            value = 2 * events['p'] - 1

            # Trilinear interpolation: each event contributes once to each of its
            # 8 neighbouring cells (exactly one put_ per corner).
            for xlim in (x0, x0 + 1):
                for ylim in (y0, y0 + 1):
                    for tlim in (t0, t0 + 1):
                        mask = (xlim < W) & (xlim >= 0) & (ylim < H) & (ylim >= 0) & (tlim >= 0) & (tlim < self.nb_channels)
                        interp_weights = value * (1 - (xlim - events['x']).abs()) * (1 - (ylim - events['y']).abs()) * (1 - (tlim - t_norm).abs())

                        index = H * W * tlim.long() + \
                                W * ylim.long() + \
                                xlim.long()

                        # put_ requires source and destination to share a dtype.
                        voxel_grid.put_(index[mask], interp_weights[mask].to(voxel_grid.dtype), accumulate=True)

            # Standardise the non-zero cells of the grid.
            if self.normalize:
                mask = torch.nonzero(voxel_grid, as_tuple=True)
                if mask[0].size()[0] > 0:
                    mean = voxel_grid[mask].mean()
                    std = voxel_grid[mask].std()
                    if std > 0:
                        voxel_grid[mask] = (voxel_grid[mask] - mean) / std
                    else:
                        voxel_grid[mask] = voxel_grid[mask] - mean

        return voxel_grid
264 |
--------------------------------------------------------------------------------
/utils/gen_dist_map.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import yaml
4 | from pathlib import Path
5 | import random
6 | import argparse
7 |
8 | def generate_dist_map(event_seq_path: Path, calib_seq_path: Path):
9 | calib_sequences = [x.stem for x in calib_seq_path.iterdir() if x.is_dir()]
10 | for seq in calib_sequences:
11 | calib_file_path = calib_seq_path / seq / "calibration/cam_to_cam.yaml"
12 | rectify_map_path = event_seq_path / seq / "events/left/rectify_map.h5"
13 | dist_map_path = event_seq_path / seq / "events/left/rect2dist_map.npy"
14 | if calib_file_path.exists() and rectify_map_path.exists():
15 | calibration = yaml.load(open(calib_file_path), Loader=yaml.FullLoader)
16 | intrinsic = calibration['intrinsics']['cam0']['camera_matrix']
17 | dist_params = calibration['intrinsics']['cam0']['distortion_coeffs']
18 | rect_intrinsic = calibration['intrinsics']['camRect0']['camera_matrix']
19 | R_rect0 = calibration['extrinsics']["R_rect0"]
20 | fx, fy, cx, cy = intrinsic
21 | camera_matrix = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
22 | fx, fy, cx, cy = rect_intrinsic
23 | rect_matrix = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
24 | distortion_coeffs = np.array(dist_params)
25 | R = np.array(R_rect0)
26 |
27 | if not dist_map_path.is_file():
28 | rect2dist_map, _ = cv2.initUndistortRectifyMap(camera_matrix, distortion_coeffs, R, rect_matrix, (640, 480),
29 | cv2.CV_32FC2)
30 | np.save(str(dist_map_path), rect2dist_map)
31 | print("save", str(dist_map_path))
32 | else:
33 | rect2dist_map = np.load(str(dist_map_path))
34 | randx = random.randint(0, 639)
35 | randy = random.randint(0, 479)
36 | dist_point = rect2dist_map[randy, randx]
37 | rect_point = cv2.undistortPoints(dist_point, camera_matrix, distortion_coeffs, R=R, P=rect_matrix)
38 | rect_x = rect_point[0, 0, 0]
39 | rect_y = rect_point[0, 0, 1]
40 | max_error = 0.01
41 | if abs(rect_x-randx)