├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── config
│   ├── exe
│   │   ├── inference_offline
│   │   │   ├── event_kubric.yaml
│   │   │   └── evimo2.yaml
│   │   ├── inference_online
│   │   │   ├── e2d2.yaml
│   │   │   ├── feature_tracking.yaml
│   │   │   └── penn_aviary.yaml
│   │   └── prepare_event_representations
│   │       ├── e2d2.yaml
│   │       ├── ec.yaml
│   │       ├── eds.yaml
│   │       ├── event_kubric.yaml
│   │       └── evimo2.yaml
│   └── misc
│       ├── ec
│       │   └── gt_tracks
│       │       ├── boxes_rotation.gt.txt
│       │       ├── boxes_translation.gt.txt
│       │       ├── shapes_6dof.gt.txt
│       │       ├── shapes_rotation.gt.txt
│       │       └── shapes_translation.gt.txt
│       ├── eds
│       │   ├── calib.yaml
│       │   └── gt_tracks
│       │       ├── 01_peanuts_light.gt.txt
│       │       ├── 02_rocket_earth_light.gt.txt
│       │       ├── 08_peanuts_running.gt.txt
│       │       └── 14_ziggy_in_the_arena.gt.txt
│       └── evimo2
│           └── val_samples.csv
├── data_pipeline
│   ├── README.md
│   ├── annotate.py
│   ├── annotator.py
│   ├── compress.py
│   ├── convert.py
│   ├── converter.py
│   ├── decompress.py
│   └── sample.py
├── docs
│   ├── flowchart.png
│   ├── pred_e2d2_fidget.gif
│   ├── pred_eds_peanuts_running.gif
│   ├── pred_event_kubric.gif
│   ├── pred_evimo2.gif
│   └── thumbnail.png
├── requirements.txt
├── scripts
│   ├── benchmark_feature_tracking.py
│   ├── benchmark_tap.py
│   ├── create_e2d2_fidget_spinner_gt.py
│   ├── create_evimo2_track_gt.py
│   ├── demo.py
│   ├── download_eds.sh
│   ├── inference_offline.py
│   ├── inference_online.py
│   └── prepare_event_representations.py
└── src
    ├── __init__.py
    ├── data
    │   ├── __init__.py
    │   ├── modules
    │   │   ├── __init__.py
    │   │   ├── e2d2.py
    │   │   ├── event_kubric.py
    │   │   ├── evimo2.py
    │   │   ├── feature_tracking_online.py
    │   │   └── penn_aviary.py
    │   └── utils
    │       ├── __init__.py
    │       └── collate.py
    ├── model
    │   └── etap
    │       ├── core
    │       │   ├── __init__.py
    │       │   ├── cotracker
    │       │   │   ├── __init__.py
    │       │   │   └── blocks.py
    │       │   ├── embeddings.py
    │       │   └── model_utils.py
    │       └── model.py
    ├── representations
    │   ├── __init__.py
    │   ├── event_stack.py
    │   └── voxel_grid.py
    └── utils
        ├── __init__.py
        ├── metrics.py
        ├── misc.py
        ├── supported_seqs_feature_tracking.py
        ├── track_utils.py
        └── visualizer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | /data
2 | output/
3 | weights/
4 | launch.json
5 | .vscode
6 | __pycache__
7 |
8 | # Byte-compiled / optimized / DLL files
9 | __pycache__/
10 | *.py[cod]
11 | *$py.class
12 |
13 | # C extensions
14 | *.so
15 |
16 | # Distribution / packaging
17 | .Python
18 | build/
19 | develop-eggs/
20 | dist/
21 | downloads/
22 | eggs/
23 | .eggs/
24 | lib/
25 | lib64/
26 | parts/
27 | sdist/
28 | var/
29 | wheels/
30 | share/python-wheels/
31 | *.egg-info/
32 | .installed.cfg
33 | *.egg
34 | MANIFEST
35 |
36 | # PyInstaller
37 | # Usually these files are written by a python script from a template
38 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
39 | *.manifest
40 | *.spec
41 |
42 | # Installer logs
43 | pip-log.txt
44 | pip-delete-this-directory.txt
45 |
46 | # Unit test / coverage reports
47 | htmlcov/
48 | .tox/
49 | .nox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | *.py,cover
57 | .hypothesis/
58 | .pytest_cache/
59 | cover/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 | db.sqlite3
69 | db.sqlite3-journal
70 |
71 | # Flask stuff:
72 | instance/
73 | .webassets-cache
74 |
75 | # Scrapy stuff:
76 | .scrapy
77 |
78 | # Sphinx documentation
79 | docs/_build/
80 |
81 | # PyBuilder
82 | .pybuilder/
83 | target/
84 |
85 | # Jupyter Notebook
86 | .ipynb_checkpoints
87 |
88 | # IPython
89 | profile_default/
90 | ipython_config.py
91 |
92 | # pyenv
93 | # For a library or package, you might want to ignore these files since the code is
94 | # intended to run in multiple environments; otherwise, check them in:
95 | # .python-version
96 |
97 | # pipenv
98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
101 | # install all needed dependencies.
102 | #Pipfile.lock
103 |
104 | # UV
105 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
106 | # This is especially recommended for binary packages to ensure reproducibility, and is more
107 | # commonly ignored for libraries.
108 | #uv.lock
109 |
110 | # poetry
111 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
112 | # This is especially recommended for binary packages to ensure reproducibility, and is more
113 | # commonly ignored for libraries.
114 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
115 | #poetry.lock
116 |
117 | # pdm
118 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
119 | #pdm.lock
120 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
121 | # in version control.
122 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
123 | .pdm.toml
124 | .pdm-python
125 | .pdm-build/
126 |
127 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
128 | __pypackages__/
129 |
130 | # Celery stuff
131 | celerybeat-schedule
132 | celerybeat.pid
133 |
134 | # SageMath parsed files
135 | *.sage.py
136 |
137 | # Environments
138 | .env
139 | .venv
140 | env/
141 | venv/
142 | ENV/
143 | env.bak/
144 | venv.bak/
145 |
146 | # Spyder project settings
147 | .spyderproject
148 | .spyproject
149 |
150 | # Rope project settings
151 | .ropeproject
152 |
153 | # mkdocs documentation
154 | /site
155 |
156 | # mypy
157 | .mypy_cache/
158 | .dmypy.json
159 | dmypy.json
160 |
161 | # Pyre type checker
162 | .pyre/
163 |
164 | # pytype static type analyzer
165 | .pytype/
166 |
167 | # Cython debug symbols
168 | cython_debug/
169 |
170 | # PyCharm
171 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
172 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
173 | # and can be added to the global gitignore or merged into this file. For a more nuclear
174 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
175 | #.idea/
176 |
177 | # PyPI configuration file
178 | .pypirc
179 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "data_pipeline/rpg_vid2e"]
2 | path = data_pipeline/rpg_vid2e
3 | url = git@github.com:filbert14/rpg_vid2e.git
4 | [submodule "data_pipeline/kubric"]
5 | path = data_pipeline/kubric
6 | url = git@github.com:filbert14/kubric.git
7 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ETAP: Event-based Tracking of Any Point (CVPR'25 Highlight)
2 |
3 | [Paper (arXiv)](https://arxiv.org/pdf/2412.00133)
4 | [Google Drive](https://drive.google.com/drive/folders/1Mprj-vOiTP5IgXE9iuu4-4bazcZUswpp?usp=drive_link)
5 | [Video (YouTube)](https://youtu.be/LaeA9WJ7ptc)
6 | [License: CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/)
7 |
8 | ## Introduction
9 |
10 | This is the official repository for [**ETAP: Event-based Tracking of Any Point**](https://arxiv.org/pdf/2412.00133), by [Friedhelm Hamann](https://friedhelmhamann.github.io/), [Daniel Gehrig](https://danielgehrig18.github.io/), [Filbert Febryanto](https://github.com/filbert14), [Kostas Daniilidis](https://www.cis.upenn.edu/~kostas/) and [Guillermo Gallego](http://www.guillermogallego.es/).
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 | ### Key Features
19 |
20 | - The first event-only point tracking method with strong cross-dataset generalization.
21 | - Robust tracking in challenging conditions (high-speed motion, lighting changes).
22 | - Evaluation benchmark based on six event camera datasets: EDS, EC, EVIMO2, EventKubric, E2D2, and PennAviary.
23 |
24 | ### Example Predictions
25 |
26 |
50 |
51 | ## Table of Contents
52 |
53 | - [Quickstart Demo](#quickstart-demo)
54 | - [Installation](#installation)
55 | - [Model Selection](#model-selection)
56 | - [Evaluation Tasks](#evaluation-tasks-and-datasets)
57 |   - [Feature Tracking (EDS, EC)](#evaluation-feature-tracking-eds-ec)
58 |   - [EVIMO2](#evaluation-evimo2-samsung-dvs-gen3-640-x-480-px)
59 |   - [EventKubric](#evaluation-eventkubric-synthetic-512-x-512-px)
60 |   - [E2D2](#evaluation-e2d2-silkyevcam-640-x-480-px)
61 |   - [PennAviary](#evaluation-pennaviary-silkyevcam-640-x-480-px-qualitative)
62 | - [Synthetic Data Generation](#synthetic-data-generation-eventkubric)
63 | - [Acknowledgements](#acknowledgements)
64 | - [Citation](#citation)
65 | - [Additional Resources](#additional-resources)
66 | - [License](#license)
67 |
68 | ## Quickstart Demo
69 |
70 | The quickest way to try ETAP is using our demo:
71 |
72 | 1. Clone the repository:
73 | ```bash
74 | git clone https://github.com/tub-rip/ETAP.git
75 | cd ETAP
76 | ```
77 |
78 | 2. Download [the model weights](https://drive.google.com/file/d/18mnwu8CsrVJvDXeRtvU0Wgp6shdxn0HD/view?usp=drive_link) and save to `weights/ETAP_v1_cvpr25.pth`
79 | 3. Download [the demo example](https://drive.google.com/drive/folders/1d5Yi1q6ZFom3Q_VrELzXjrxE5aC2ezOk?usp=drive_link) (30MB) and extract to `data/demo_example`
80 | 4. Run the demo:
81 | ```bash
82 | python scripts/demo.py
83 | ```
84 |
85 | This demo requires only basic dependencies: `torch`, `numpy`, `tqdm`, `matplotlib`, `imageio`, and `pillow`. No dataset preprocessing needed!
86 |
87 | ## Installation
88 |
89 | 1. Clone the repository:
90 | ```bash
91 | git clone git@github.com:tub-rip/ETAP.git
92 | cd ETAP
93 | ```
94 |
95 | 2. Set up the environment:
96 | ```bash
97 | conda create --name ETAP python=3.10
98 | conda activate ETAP
99 | ```
100 |
101 | 3. Install PyTorch (choose a command compatible with your CUDA version from the [PyTorch website](https://pytorch.org/get-started/locally/)), e.g.:
102 | ```bash
103 | conda install pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia
104 | ```
105 |
106 | 4. Install other dependencies:
107 | ```bash
108 | pip install -r requirements.txt
109 | ```
110 |
111 | ## Model Selection
112 |
113 | Download the [pre-trained model](https://drive.google.com/drive/folders/17YNqf603b3dEdETmNht-ih7cwosKyACN?usp=drive_link) and move it into the folder `weights/`.
114 |
115 | To reproduce the paper results, use the model `ETAP_v1_cvpr25.pth`.
116 |
117 | ## Evaluation Tasks and Datasets
118 |
119 | ### Evaluation: Feature Tracking (EDS, EC)
120 |
121 | #### Preparations
122 |
123 | ##### Download EDS (Prophesee Gen3 640 x 480 px)
124 |
125 | The four evaluation sequences of the "Event-aided Direct Sparse Odometry Dataset" (EDS) can be downloaded in two ways:
126 |
127 | **Option 1**
128 |
129 | Download the sequences manually from the [official web page](https://rpg.ifi.uzh.ch/eds.html):
130 |
131 | * `01_peanuts_light`
132 | * `02_rocket_earth_light`
133 | * `08_peanuts_running`
134 | * `14_ziggy_in_the_arena`
135 |
136 | Choose the archive file option, which contains the events as an hdf5 file. Place all sequences in a common folder.
137 |
138 | **Option 2:** Use our download script:
139 | ```bash
140 | bash scripts/download_eds.sh
141 | ```
142 |
143 | We also use the calibration data provided by EDS. No action is required, as it is included in this repository at `config/misc/eds/calib.yaml`. This is the same file as in the `00_calib` results from the official source.
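If you need these calibration parameters in your own code, the YAML file can be read directly. Below is a minimal sketch; it assumes the Kalibr-style `[fx, fy, cx, cy]` ordering of the `intrinsics` entries (`cam1` is the event camera, as noted in the file's comments):

```python
import numpy as np
import yaml

# Read the bundled EDS calibration (cam0 = RGB camera, cam1 = event camera).
with open("config/misc/eds/calib.yaml") as f:
    calib = yaml.safe_load(f)

fx, fy, cx, cy = calib["cam1"]["intrinsics"]   # assumed ordering: [fx, fy, cx, cy]
K = np.array([[fx, 0.0, cx],
              [0.0, fy, cy],
              [0.0, 0.0, 1.0]])
print("K =\n", K)
print("distortion:", calib["cam1"]["distortion_model"], calib["cam1"]["distortion_coeffs"])
```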
144 |
145 | The evaluation protocol was introduced in [DDFT](https://github.com/uzh-rpg/deep_ev_tracker). As with the calibration data, the ground truth tracks are included in this repository at `config/misc/eds/gt_tracks`, so no additional steps are necessary. If you are interested in how the tracks are created, please refer to the DDFT repository.
146 |
147 | Create a symbolic link from your data root to `data/` (or adjust the paths in the config files). The setup should look something like this:
148 |
149 | ```
150 | data/eds/
151 | ├── 01_peanuts_light
152 | │ └── events.h5
153 | ├── 02_rocket_earth_light
154 | │ └── events.h5
155 | ├── 08_peanuts_running
156 | │ └── events.h5
157 | └── 14_ziggy_in_the_arena
158 | └── events.h5
159 | ```
160 |
161 | ##### Download EC (DAVIS240C 240 x 180 px)
162 |
163 | Download the five evaluation sequences of the "Event Camera Dataset" (EC) from the [official source](https://rpg.ifi.uzh.ch/davis_data.html). Download the option `Text (zip)`. Unzip the sequences into a folder structure like this:
164 |
165 | ```
166 | data/ec/
167 | ├── boxes_rotation
168 | │ ├── calib.txt
169 | │ ├── events.txt
170 | │ ├── groundtruth.txt
171 | │ ├── images
172 | │ ├── images.txt
173 | │ └── imu.txt
174 | ├── boxes_translation
175 | │ ├── events.txt
176 | │ ├── ...
177 | ├── shapes_6dof
178 | │ ├── events.txt
179 | │ ├── ...
180 | ├── shapes_rotation
181 | │ ├── events.txt
182 | │ ├── ...
183 | └── shapes_translation
184 | ├── events.txt
185 | ├── ...
186 | ```
187 |
188 | As with EDS, the ground truth tracks come from the evaluation introduced in DDFT, but we have included them at `config/misc/ec/gt_tracks` for convenience.
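To sanity-check a downloaded sequence, you can peek at the raw text data. In the EC text format, each row of `events.txt` is assumed to hold a timestamp, the pixel coordinates, and the polarity:

```python
import numpy as np

# Peek at the first few events of one sequence (the files are large, so limit the rows).
events = np.loadtxt("data/ec/shapes_rotation/events.txt", max_rows=5)
print(events)  # assumed column order: t [s], x, y, polarity
```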
189 |
190 | ##### Preprocessing
191 |
192 | Preprocess the data by transforming the raw events into event stacks with the following commands:
193 |
194 | ```bash
195 | # For EDS dataset
196 | python scripts/prepare_event_representations.py --dataset eds --config config/exe/prepare_event_representations/eds.yaml
197 |
198 | # For EC dataset
199 | python scripts/prepare_event_representations.py --dataset ec --config config/exe/prepare_event_representations/ec.yaml
200 | ```
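For intuition: an event stack is a multi-channel image built by splitting a slice of events into temporal bins and accumulating their polarities per pixel. The sketch below illustrates this idea in its simplest form (nearest-pixel accumulation, no channel overlap); it is not the repository's implementation, which lives in `src/representations/event_stack.py` and additionally supports the bilinear interpolation and channel-overlap options seen in the YAML configs:

```python
import numpy as np

def toy_event_stack(x, y, t, p, num_stacks, height, width):
    """Accumulate signed polarities into `num_stacks` temporal bins (simplified illustration)."""
    stack = np.zeros((num_stacks, height, width), dtype=np.float32)
    # Normalize timestamps to [0, 1) and map each event to a channel index.
    t_norm = (t - t.min()) / max(float(t.max() - t.min()), 1e-9)
    ch = np.clip((t_norm * num_stacks).astype(int), 0, num_stacks - 1)
    # Add +1 / -1 at each event's pixel location in its channel.
    np.add.at(stack, (ch, y.astype(int), x.astype(int)), np.where(p > 0, 1.0, -1.0))
    return stack

# Example on random events at EC resolution (180 x 240).
rng = np.random.default_rng(0)
n = 10_000
x, y = rng.integers(0, 240, n), rng.integers(0, 180, n)
t, p = np.sort(rng.uniform(0.0, 0.1, n)), rng.choice([-1, 1], n)
print(toy_event_stack(x, y, t, p, num_stacks=10, height=180, width=240).shape)  # (10, 180, 240)
```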
201 |
202 | #### Inference
203 |
204 | Run the tracking inference with:
205 |
206 | ```bash
207 | python scripts/inference_online.py --config config/exe/inference_online/feature_tracking.yaml
208 | ```
209 |
210 | #### Evaluation
211 |
212 | Run the benchmarking script to evaluate the tracking results:
213 |
214 | ```bash
215 | python scripts/benchmark_feature_tracking.py feature_tracking_eds_ec
216 | ```
217 |
218 | ### Evaluation: EVIMO2 (Samsung DVS Gen3 640 x 480 px)
219 |
220 | #### Preparations
221 |
222 | 1. Download the required EVIMO2 sequences from the [official source](https://better-flow.github.io/evimo/download_evimo_2.html#imo). You only need **Motion Segmentation / Object Recognition** sequences for the **samsung_mono** camera in **.npz** format (2.4GB). Unzip them and move them into the data directory.
223 |
224 | 2. Download the precomputed tracks [here](https://drive.google.com/drive/folders/1d5Yi1q6ZFom3Q_VrELzXjrxE5aC2ezOk?usp=drive_link) and merge them into the data directory.
225 |
226 | The result should look like this:
227 |
228 | ```
229 | data/evimo2/
230 | └── samsung_mono
231 | └── imo
232 | └── eval
233 | ├── scene13_dyn_test_00_000000
234 | │ ├── dataset_classical.npz
235 | │ ├── dataset_depth.npz
236 | │ ├── dataset_events_p.npy
237 | │ ├── dataset_events_t.npy
238 | │ ├── dataset_events_xy.npy
239 | │ ├── dataset_info.npz
240 | │ ├── dataset_mask.npz
241 | │ └── dataset_tracks.h5
242 | ├── scene13_dyn_test_05_000000
243 | │ ├── dataset_classical.npz
244 | ... ...
245 | ```
246 |
247 | 3. Precompute the event stacks:
248 |
249 | ```bash
250 | python scripts/prepare_event_representations.py --dataset evimo2 --config config/exe/prepare_event_representations/evimo2.yaml
251 | ```
252 |
253 | #### Inference & Evaluation
254 |
255 | Run inference and evaluation with a single command:
256 |
257 | ```bash
258 | python scripts/inference_offline.py --config config/exe/inference_offline/evimo2.yaml
259 | ```
260 |
261 | #### Ground Truth Track Generation (Optional)
262 |
263 | If you want to generate the point tracks yourself instead of using the precomputed ones:
264 |
265 | ```bash
266 | python scripts/create_evimo2_track_gt.py --config config/misc/evimo2/val_samples.csv --data_root data/evimo2
267 | ```
268 |
269 | ### Evaluation: EventKubric (synthetic 512 x 512 px)
270 |
271 | #### Preparations
272 |
273 | 1. Download [the event_kubric test set](https://drive.google.com/drive/folders/1d5Yi1q6ZFom3Q_VrELzXjrxE5aC2ezOk?usp=drive_link) and move it to the data directory:
274 |
275 | ```
276 | data/event_kubric
277 | └── test
278 | ├── sample_000042
279 | │ ├── annotations.npy
280 | │ ├── data.hdf5
281 | │ ├── data_ranges.json
282 | │ ├── events
283 | │ │ ├── 0000000000.npz
284 | │ │ ├── 0000000001.npz
285 | │ │ ...
286 | │ ├── events.json
287 | │ └── metadata.json
288 | ├── sample_000576
289 | │ ├── annotations.npy
290 | │ ...
291 | ...
292 | ```
293 |
294 | 2. Prepare the event stacks:
295 |
296 | ```bash
297 | python scripts/prepare_event_representations.py --dataset event_kubric --config config/exe/prepare_event_representations/event_kubric.yaml
298 | ```
299 |
300 | #### Inference & Evaluation
301 |
302 | Run inference and evaluation with a single command:
303 |
304 | ```bash
305 | python scripts/inference_offline.py --config config/exe/inference_offline/event_kubric.yaml
306 | ```
307 |
308 | ### Evaluation: E2D2 (SilkyEvCam 640 x 480 px)
309 |
310 | #### Preparations
311 |
312 | 1. Download [the E2D2 fidget spinner sequence](https://drive.google.com/drive/folders/1d5Yi1q6ZFom3Q_VrELzXjrxE5aC2ezOk?usp=drive_link) and move it to the data directory:
313 |
314 | ```
315 | data/e2d2/
316 | └── 231025_110210_fidget5_high_exposure
317 | ├── gt_positions.npy
318 | ├── gt_timestamps.npy
319 | ├── queries.npy
320 | └── seq.h5
321 | ```
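Optionally, sanity-check the downloaded arrays before preparing the event stacks (file names as in the layout above; the shapes are simply printed rather than assumed):

```python
import numpy as np

seq = "data/e2d2/231025_110210_fidget5_high_exposure"
for name in ("gt_positions.npy", "gt_timestamps.npy", "queries.npy"):
    arr = np.load(f"{seq}/{name}")
    print(name, arr.shape, arr.dtype)
```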
322 |
323 | 2. Prepare the event stacks:
324 |
325 | ```bash
326 | python scripts/prepare_event_representations.py --dataset e2d2 --config config/exe/prepare_event_representations/e2d2.yaml
327 | ```
328 |
329 | #### Inference
330 |
331 | ```bash
332 | python scripts/inference_online.py --config config/exe/inference_online/e2d2.yaml
333 | ```
334 |
335 | #### Evaluation
336 |
337 | ```bash
338 | python scripts/benchmark_tap.py --gt_dir data/e2d2/231025_110210_fidget5_high_exposure --pred_dir output/inference/tap_e2d2
339 | ```
340 |
341 | #### Ground Truth Generation (Optional)
342 |
343 | The ground truth is derived from the rotation speed of the fidget spinner and is included in the download above. To recompute the ground truth tracks yourself, run:
344 |
345 | ```bash
346 | python scripts/create_e2d2_fidget_spinner_gt.py
347 | ```
348 |
349 | ### Evaluation: PennAviary (SilkyEvCam 640 x 480 px, Qualitative)
350 |
351 | #### Preparations
352 |
353 | Download [the penn_aviary sequence](https://drive.google.com/drive/folders/1d5Yi1q6ZFom3Q_VrELzXjrxE5aC2ezOk?usp=drive_link) and move it to the data directory:
354 |
355 | ```
356 | data/penn_aviary/
357 | └── 231018_174107_view2
358 | ├── mask00082.png
359 | └── seq.h5
360 | ```
361 |
362 | #### Inference
363 |
364 | Run the inference with:
365 |
366 | ```bash
367 | python scripts/inference_online.py --config config/exe/inference_online/penn_aviary.yaml
368 | ```
369 |
370 | ## Synthetic Data Generation (EventKubric)
371 |
372 | We provide a [10-sample test set of EventKubric](https://drive.google.com/drive/folders/1d5Yi1q6ZFom3Q_VrELzXjrxE5aC2ezOk?usp=drive_link) for quick evaluation. The complete dataset consists of approximately 10,000 samples.
373 |
374 | To generate your own synthetic event data, please refer to the [Data Pipeline Instructions](data_pipeline/README.md).
375 |
376 | ## Acknowledgements
377 |
378 | We gratefully acknowledge the following repositories and thank the authors for their excellent work:
379 |
380 | - [CoTracker](https://github.com/facebookresearch/co-tracker)
381 | - [TapVid](https://github.com/google-deepmind/tapnet/tree/main/tapnet/tapvid)
382 | - [DDFT](https://github.com/uzh-rpg/deep_ev_tracker)
383 |
384 | ## Citation
385 |
386 | If you find this work useful in your research, please consider citing:
387 |
388 | ```bibtex
389 | @InProceedings{Hamann25cvpr,
390 | author={Friedhelm Hamann and Daniel Gehrig and Filbert Febryanto and Kostas Daniilidis and Guillermo Gallego},
391 | title={{ETAP}: Event-based Tracking of Any Point},
392 | booktitle={{IEEE/CVF} Conf. Computer Vision and Pattern Recognition ({CVPR})},
393 | year=2025,
394 | }
395 | ```
396 |
397 | ## Additional Resources
398 |
399 | * [Research page (TU Berlin, RIP lab)](https://sites.google.com/view/guillermogallego/research/event-based-vision)
400 | * [Course at TU Berlin](https://sites.google.com/view/guillermogallego/teaching/event-based-robot-vision)
401 | * [Survey paper](http://rpg.ifi.uzh.ch/docs/EventVisionSurvey.pdf)
402 | * [List of Event-based Vision Resources](https://github.com/uzh-rpg/event-based_vision_resources)
403 |
404 | ## License
405 |
406 | This project is licensed under the Creative Commons Attribution-NonCommercial 4.0 International License - see the [LICENSE](LICENSE) file for details. This means you are free to share and adapt the material for non-commercial purposes, provided you give appropriate credit and indicate if changes were made.
--------------------------------------------------------------------------------
/config/exe/inference_offline/event_kubric.yaml:
--------------------------------------------------------------------------------
1 | common:
2 | height: 512
3 | width: 512
4 | exp_name: event_kubric
5 | checkpoint: weights/ETAP_v1_cvpr25.pth
6 |
7 | model:
8 | num_in_channels: 10
9 | stride: 4
10 | window_len: 8
11 | num_virtual_tracks: 64
12 |
13 | data:
14 | data_root: data/event_kubric
15 | dataset_name: event_kubric
16 | preprocessed_name: event_stack_v1
17 | seq_len: 24
18 | traj_per_sample: 500
--------------------------------------------------------------------------------
/config/exe/inference_offline/evimo2.yaml:
--------------------------------------------------------------------------------
1 | common:
2 | height: 480
3 | width: 640
4 | exp_name: evimo2
5 | checkpoint: weights/ETAP_v1_cvpr25.pth
6 |
7 | model:
8 | num_in_channels: 10
9 | stride: 4
10 | window_len: 8
11 | num_virtual_tracks: 64
12 |
13 | data:
14 | data_root: data/evimo2
15 | dataset_name: evimo2
16 | preprocessed_name: event_stack_v1
17 | metadata_path: config/misc/evimo2/val_samples.csv
--------------------------------------------------------------------------------
/config/exe/inference_online/e2d2.yaml:
--------------------------------------------------------------------------------
1 | common:
2 | exp_name: tap_e2d2
3 | height: 512
4 | width: 512
5 | ckp_path: weights/ETAP_v1_cvpr25.pth
6 | add_support_points: true
7 | support_point_stride: 20
8 |
9 | data:
10 | data_root: data/e2d2
11 | dataset_name: e2d2
12 | preprocessed_name: event_stack_v1
13 | sequences:
14 | - 231025_110210_fidget5_high_exposure
15 |
16 | model:
17 | num_in_channels: 10
18 | stride: 4
19 | window_len: 8
20 | num_virtual_tracks: 64
21 |
--------------------------------------------------------------------------------
/config/exe/inference_online/feature_tracking.yaml:
--------------------------------------------------------------------------------
1 | common:
2 | exp_name: feature_tracking_eds_ec
3 | ckp_path: weights/ETAP_v1_cvpr25.pth
4 |
5 | data:
6 | data_root: data/
7 | dataset_name: feature_tracking_online
8 | preprocessed_name: event_stack_v1
9 |
10 | model:
11 | num_in_channels: 10
12 | stride: 4
13 | window_len: 8
14 | num_virtual_tracks: 64
15 |
--------------------------------------------------------------------------------
/config/exe/inference_online/penn_aviary.yaml:
--------------------------------------------------------------------------------
1 | common:
2 | exp_name: qualitative_penn_aviary
3 | height: 480
4 | width: 640
5 | ckp_path: weights/ETAP_v1_cvpr25.pth
6 | add_support_points: false
7 |
8 | data:
9 | data_root: data/penn_aviary
10 | dataset_name: penn_aviary
11 | load_rgb: false
12 | sequences:
13 | - 231018_174107_view2
14 |
15 | repr_config:
16 | representation_name: event_stack
17 | num_stacks: 10
18 | interpolation: bilinear
19 | channel_overlap: true
20 | centered_channels: false
21 |   image_shape: [480, 640]  # Height and width of the raw data; the height/width under `common` are used for inference
22 |
23 | sequence_data:
24 | 231018_174107_view2:
25 | start_time_s: 2.7099975 # frame 35
26 | duration_s: 0.5
27 | step_time_s: 0.0033
28 | num_events: 200000
29 | query_stride: 8
30 | mask_name: mask00082.png
31 |
32 | model:
33 | num_in_channels: 10
34 | stride: 4
35 | window_len: 8
36 | num_virtual_tracks: 64
37 |
--------------------------------------------------------------------------------
/config/exe/prepare_event_representations/e2d2.yaml:
--------------------------------------------------------------------------------
1 | common:
2 | height: 480
3 | width: 640
4 | save_prefix: v1
5 | num_events: 200000
6 | data_root: data/e2d2
7 | sequences:
8 | - 231025_110210_fidget5_high_exposure
9 |
10 | event_representation:
11 | representation_name: event_stack
12 | num_stacks: 10
13 | interpolation: bilinear
14 | channel_overlap: true
15 | centered_channels: false
16 | image_shape: [480, 640]
--------------------------------------------------------------------------------
/config/exe/prepare_event_representations/ec.yaml:
--------------------------------------------------------------------------------
1 | common:
2 | dataset_path: data/ec
3 | num_events: 100000
4 | height: 180
5 | width: 240
6 | save_prefix: v1
7 |
8 | event_representation:
9 | representation_name: event_stack
10 | num_stacks: 10
11 | interpolation: bilinear
12 | channel_overlap: true
13 | centered_channels: false
14 |
--------------------------------------------------------------------------------
/config/exe/prepare_event_representations/eds.yaml:
--------------------------------------------------------------------------------
1 | common:
2 | dataset_path: data/eds
3 | num_events: 800000
4 | height: 480
5 | width: 640
6 | save_prefix: v1
7 |
8 | event_representation:
9 | representation_name: event_stack
10 | num_stacks: 10
11 | interpolation: bilinear
12 | channel_overlap: true
13 | centered_channels: false
14 |
--------------------------------------------------------------------------------
/config/exe/prepare_event_representations/event_kubric.yaml:
--------------------------------------------------------------------------------
1 | common:
2 | dataset_path: data/event_kubric/test
3 | num_events: 400000
4 | sequence_length: 24
5 | height: 512
6 | width: 512
7 | save_prefix: v1
8 | create_time_inverted: false
9 |
10 | event_representation:
11 | representation_name: event_stack
12 | num_stacks: 10
13 | interpolation: bilinear
14 | channel_overlap: true
15 | centered_channels: false
--------------------------------------------------------------------------------
/config/exe/prepare_event_representations/evimo2.yaml:
--------------------------------------------------------------------------------
1 | common:
2 | data_path: data/evimo2
3 | sample_config: config/misc/evimo2/val_samples.csv
4 | num_events: 60000
5 | height: 480
6 | width: 640
7 | save_prefix: v1
8 |
9 | event_representation:
10 | representation_name: event_stack
11 | num_stacks: 10
12 | interpolation: bilinear
13 | channel_overlap: true
14 | centered_channels: false
--------------------------------------------------------------------------------
/config/misc/eds/calib.yaml:
--------------------------------------------------------------------------------
1 | cam0: # This is the RGB camera
2 | cam_overlaps: [1]
3 | camera_model: pinhole
4 | distortion_coeffs: [-0.36965913545735024, 0.17414034009883844, 0.003915245015812422,
5 | 0.003666687416655559]
6 | distortion_model: radtan
7 | intrinsics: [766.536025127154, 767.5749459126396, 291.0503512057777, 227.4060484950132]
8 | resolution: [640, 480]
9 | rostopic: /cam0/image_raw
10 | flip: True
11 | cam1: # This is the event camera
12 | T_cn_cnm1:
13 | - [0.9998964430808897, -0.0020335804041023736, -0.014246672065022661, -0.00011238613157578769]
14 | - [0.001703024953250547, 0.9997299470300024, -0.023176123864880376, -0.0005981481496958399]
15 | - [0.014289955220253567, 0.02314946137886846, 0.9996298813149167, -0.004416681577516066]
16 | - [0.0, 0.0, 0.0, 1.0]
17 | cam_overlaps: [0]
18 | camera_model: pinhole
19 | distortion_coeffs: [-0.09776467241921379, 0.2143738428636279, -0.004710710105172864,
20 | -0.004215916089401789]
21 | distortion_model: radtan
22 | intrinsics: [560.8520948927032, 560.6295819972383, 313.00733235019237, 217.32858679842997]
23 | resolution: [640, 480]
24 | rostopic: /cam1/image_raw
25 |
--------------------------------------------------------------------------------
/config/misc/evimo2/val_samples.csv:
--------------------------------------------------------------------------------
1 | name,t_start,t_end
2 | samsung_mono/imo/eval/scene13_dyn_test_00_000000,0.033333,1.516667
3 | samsung_mono/imo/eval/scene14_dyn_test_03_000000,1.0288,2.572
4 | samsung_mono/imo/eval/scene14_dyn_test_05_000000,0.711,1.59975
5 | samsung_mono/imo/eval/scene15_dyn_test_01_000000,1.406668,3.516667
6 | samsung_mono/imo/eval/scene15_dyn_test_02_000000,0.033333,0.966667
7 | samsung_mono/imo/eval/scene15_dyn_test_05_000000,0.7,1.4
--------------------------------------------------------------------------------
/data_pipeline/README.md:
--------------------------------------------------------------------------------
1 | # EventKubric Data Generation Pipeline
2 |
3 | ## Overview
4 |
5 | This document provides instructions for generating synthetic data samples for the EventKubric dataset. The pipeline leverages two key tools:
6 | - [Kubric](https://github.com/google-research/kubric) - For generating synthetic scenes and rendering frames
7 | - [Vid2E](https://github.com/uzh-rpg/rpg_vid2e) - For converting video frames to event data
8 |
9 | Both tools are included as Git submodules in this repository. To access them, you can either:
10 |
11 | Clone with the `--recursive` flag (if starting fresh):
12 | ```bash
13 | git clone --recursive git@github.com:tub-rip/ETAP.git
14 | ```
15 |
16 | Or if you've already cloned the repository without the `--recursive` flag, initialize and update the submodules:
17 | ```bash
18 | # Navigate to your existing repository
19 | cd ETAP
20 |
21 | # Initialize and fetch all submodules
22 | git submodule update --init --recursive
23 | ```
24 |
25 | For reference samples, see the [EventKubric test set](https://drive.google.com/drive/folders/1v8dYA-D7OOCAw9TimTxj74nz8t5mbsqz).
26 |
27 | ## Data Generation Pipeline
28 |
29 | ### 1. Setting Up Kubric
30 |
31 | #### CPU Setup (Standard)
32 | Pull the Kubric Docker image:
33 | ```bash
34 | docker pull kubricdockerhub/kubruntu
35 | ```
36 |
37 | Set up and enter Docker container:
38 | ```bash
39 | docker run --interactive --user $(id -u):$(id -g) --volume "$(pwd):/kubric" kubricdockerhub/kubruntu
40 | sudo docker exec -it <container_name> bash
41 | ```
42 |
43 | #### GPU Setup (Optional, for faster rendering)
44 | Inside the `/kubric` directory:
45 | 1. Build the wheel file:
46 | ```bash
47 | python setup.py bdist_wheel
48 | ```
49 |
50 | 2. Build Docker images with GPU support:
51 | ```bash
52 | docker build -f docker/Blender.Dockerfile -t kubricdockerhub/blender-gpu:latest .
53 | docker build -f docker/Kubruntu.Dockerfile -t kubricdockerhub/kubruntu-gpu:latest .
54 | ```
55 |
56 | 3. Run container with GPU access:
57 | ```bash
58 | docker run --interactive --gpus all --env KUBRIC_USE_GPU=1 --volume "$(pwd):/kubric" kubricdockerhub/kubruntu-gpu
59 | sudo docker exec -it <container_name> bash
60 | ```
61 |
62 | ### 2. Generating Synthetic Samples
63 |
64 | The complete pipeline consists of four main steps as visualized in the flow chart:
65 |
66 |
67 |
68 |
69 |
70 | #### Step 1: Generate Training Examples
71 | ```bash
72 | python3 sample.py --output_dir=<output_dir> --start_index=<start_index> --end_index=<end_index> --worker_script=event_kubric_worker
73 | ```
74 | This generates `end_index - start_index` training examples with indices in the range `[start_index, end_index - 1]`.
75 |
76 | **Optional flags:**
77 | - `--panning`: Include camera panning motions (similar to [TAPIR](https://deepmind-tapir.github.io/))
78 |
79 | #### Step 2: Generate Ground Truth Point Tracks
80 | ```bash
81 | python3 annotate.py --dataset_path=<dataset_path> --start_index=<start_index> --end_index=<end_index> --resolution=512 --num_frames=96 --tracks_to_sample=2048
82 | ```
83 |
84 | #### Step 3: Generate Events from Frames
85 | ```bash
86 | python3 convert.py --dataset_path=<dataset_path> --start_index=<start_index> --end_index=<end_index> --frame_rate=48 --num_frames=96 --ct_lower=0.16 --ct_upper=0.34 --ref_period=0
87 | ```
88 |
89 | #### Step 4: Compress Data (Optional)
90 | ```bash
91 | python3 compress.py <dataset_root>
92 | ```
93 | This compresses raw files into a single `.hdf5` file per sample. For additional options (e.g., subsampling), refer to `compress.py`.
94 |
95 | To decompress a sample:
96 | ```bash
97 | python3 decompress.py <hdf5_path> <output_folder>
98 | ```
99 |
100 | ## Dataset Structure
101 |
102 | After generation, your dataset will have the following structure:
103 |
104 | ```
105 | dataset_root/
106 | ├── 00000000/
107 | ├── 00000001/
108 | │ ├── annotations.npy # Ground truth point tracks
109 | │ ├── data.hdf5 # Compressed data
110 | │ ├── events/ # Event data
111 | │ │ ├── 0000000000.npz
112 | │ │ ├── ...
113 | │ │ └── 0000000211.npz
114 | │ └── raw/ # Raw Kubric outputs
115 | │ ├── backward_flow_*.png
116 | │ ├── data_ranges.json
117 | │ ├── depth_*.tiff
118 | │ ├── events.json
119 | │ ├── forward_flow_*.png
120 | │ ├── metadata.json
121 | │ ├── normal_*.png
122 | │ ├── object_coordinates_*.png
123 | │ ├── rgba_*.png
124 | │ └── segmentation_*.png
125 | ├── 00000002/
126 | └── ...
127 | ```
128 |
129 | ## Data Format Conventions
130 |
131 | - **Events**: `[y, x, t, p]` where:
132 | - `y, x`: Pixel coordinates
133 | - `t`: Timestamp
134 | - `p`: Polarity (positive/negative)
135 |
136 | - **Annotations**:
137 | - `annotations["video"]`: RGB frames `[N, C, H, W]`
138 | - `annotations["target_points"]`: Point track coordinates `[N, P, 2]`
139 | - `annotations["occluded"]`: Occlusion masks `[N, P]`
140 |
141 | Where:
142 | - `N`: Number of frames
143 | - `C`: Number of channels
144 | - `H, W`: Height and width
145 | - `P`: Number of tracked points
146 |
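As a quick check of a generated sample, the annotations and an event chunk can be loaded as follows. This is a sketch based on the conventions above: `annotations.npy` stores a Python dict (hence `allow_pickle=True`), and the key names inside the event `.npz` chunks depend on the vid2e version, so they are listed rather than assumed:

```python
import numpy as np

sample = "dataset_root/00000001"                      # a sample folder as in the structure above

ann = np.load(f"{sample}/annotations.npy", allow_pickle=True).item()
print("annotation keys:", list(ann.keys()))           # e.g. target_points [N, P, 2], occluded [N, P]
print("target_points:", ann["target_points"].shape)
print("occluded:", ann["occluded"].shape)

chunk = np.load(f"{sample}/events/0000000000.npz")    # first event chunk
print("event arrays:", chunk.files)                   # inspect the stored array names
```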
--------------------------------------------------------------------------------
/data_pipeline/annotate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import subprocess
4 |
5 | def main():
6 | ap = argparse.ArgumentParser()
7 | ap.add_argument("--dataset_path" , required=True)
8 | ap.add_argument("--start_index" , type=int, required=True)
9 | ap.add_argument("--end_index" , type=int, required=True)
10 |
11 | ap.add_argument("--resolution" , type=int, required=True)
12 | ap.add_argument("--num_frames" , type=int, required=True)
13 | ap.add_argument("--tracks_to_sample", type=int, default=2048)
14 |
15 | args = vars(ap.parse_args())
16 | dataset_path = args["dataset_path"]
17 | start_index = args["start_index"]
18 | end_index = args["end_index"]
19 |
20 | resolution = args["resolution"]
21 | num_frames = args["num_frames"]
22 | tracks_to_sample= args["tracks_to_sample"]
23 |
24 | example_counter = start_index
25 |
26 | def get_current_example_path():
27 | return os.path.join(dataset_path, f"{example_counter:08d}")
28 |
29 | while example_counter < end_index:
30 | print(f"Annotating example {example_counter}")
31 |
32 | script = ["python3",
33 | "annotator.py",
34 | f"--scene_dir={os.path.join(get_current_example_path(), 'raw')}",
35 | f"--resolution={resolution}",
36 | f"--num_frames={num_frames}",
37 | f"--output_dir={get_current_example_path()}",
38 | f"--tracks_to_sample={tracks_to_sample}"]
39 |
40 | annotate_result = subprocess.run(script, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
41 |
42 | if annotate_result.returncode == 0:
43 | print(f"Successfully annotated example {example_counter}")
44 | else:
45 | print(f"Failed to annotate example {example_counter}, return code: {annotate_result.returncode}")
46 | break
47 |
48 | example_counter += 1
49 |
50 | if __name__ == "__main__":
51 | main()
--------------------------------------------------------------------------------
/data_pipeline/annotator.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import shutil
4 | import json
5 | import tensorflow as tf
6 | import sys
7 | import numpy as np
8 |
9 | sys.path.append(os.path.join("kubric", "challenges"))
10 |
11 | from PIL import Image
12 | from point_tracking.dataset import add_tracks
13 |
14 | from movi.movi_f import subsample_nearest_neighbor
15 | from movi.movi_f import subsample_avg
16 | from movi.movi_f import read_png
17 | from movi.movi_f import read_tiff
18 | from movi.movi_f import convert_float_to_uint16
19 |
20 | def main():
21 | # Parse arguments
22 | ap = argparse.ArgumentParser()
23 | ap.add_argument("--scene_dir", required=True)
24 | ap.add_argument("--resolution", type=int, required=True)
25 | ap.add_argument("--num_frames", type=int, required=True)
26 | ap.add_argument("--output_dir", required=True)
27 | ap.add_argument("--tracks_to_sample", type=int, default=2048)
28 | args = vars(ap.parse_args())
29 |
30 | _scene_dir = args["scene_dir"]
31 | _width, _height = args["resolution"], args["resolution"]
32 | _num_frames = args["num_frames"]
33 | _target_size = (_height, _width)
34 | _output_dir = args["output_dir"]
35 | _tracks_to_sample= args["tracks_to_sample"]
36 | layers = ("rgba", "segmentation", "depth", "normal", "object_coordinates")
37 |
38 | # Load simulation output
39 | with tf.io.gfile.GFile(os.path.join(_scene_dir, 'metadata.json'), "r") as fp:
40 | metadata = json.load(fp)
41 | paths = {
42 | key: [os.path.join(_scene_dir, (f"{key}_{f:05d}.png")) for f in range (_num_frames)]
43 | for key in layers if key != "depth"
44 | }
45 |
46 | # Gather relevant data for point tracking annotation
47 | result = {}
48 | result["normal"] = tf.convert_to_tensor([subsample_nearest_neighbor(read_png(frame_path), _target_size) for frame_path in paths["normal"]], dtype=float)
49 | result["object_coordinates"] = tf.convert_to_tensor([subsample_nearest_neighbor(read_png(frame_path), _target_size) for frame_path in paths["object_coordinates"]])
50 | result["segmentations"] = tf.convert_to_tensor([subsample_nearest_neighbor(read_png(frame_path), _target_size) for frame_path in paths["segmentation"]])
51 | result["video"] = tf.convert_to_tensor([subsample_avg(read_png(frame_path), _target_size)[..., :3] for frame_path in paths["rgba"]])
52 | result["metadata"] = {}
53 |
54 | depth_paths = [os.path.join(_scene_dir, f"depth_{f:05d}.tiff") for f in range(_num_frames)]
55 | depth_frames = np.array([subsample_nearest_neighbor(read_tiff(frame_path), _target_size) for frame_path in depth_paths])
56 | depth_min, depth_max = np.min(depth_frames), np.max(depth_frames)
57 |
58 | result["depth"] = convert_float_to_uint16(depth_frames, depth_min, depth_max)
59 | result["metadata"]["depth_range"] = [depth_min, depth_max]
60 |
61 | result["instances"] = {}
62 | result["instances"]["bboxes_3d"] = tf.convert_to_tensor([np.array(obj["bboxes_3d"], np.float32) for obj in metadata["instances"]])
63 | result["instances"]["quaternions"] = tf.convert_to_tensor([np.array(obj["quaternions"], np.float32) for obj in metadata["instances"]])
64 | result["camera"] = {}
65 |
66 | result["camera"]["focal_length"] = metadata["camera"]["focal_length"]
67 | result["camera"]["sensor_width"] = metadata["camera"]["sensor_width"]
68 | result["camera"]["positions"] = np.array(metadata["camera"]["positions"], np.float32)
69 | result["camera"]["quaternions"] = np.array(metadata["camera"]["quaternions"], np.float32)
70 |
71 | # Annotate using add_tracks
72 | point_tracking = add_tracks(result, train_size=_target_size, random_crop=False, tracks_to_sample=_tracks_to_sample)
73 | video = point_tracking["video"].numpy()
74 | target_points = point_tracking["target_points"].numpy()
75 | occluded = point_tracking["occluded"].numpy()
76 |
77 | # Save annotations
78 | annotations = {"target_points": target_points, "occluded":occluded}
79 | np.save(os.path.join(_output_dir, "annotations.npy"), annotations)
80 |
81 | if __name__ == "__main__":
82 | main()
--------------------------------------------------------------------------------
/data_pipeline/compress.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import h5py
4 | import numpy as np
5 | import cv2
6 | import glob
7 | import argparse
8 | import multiprocessing
9 | from functools import partial
10 | import hdf5plugin
11 | import yaml
12 | from tqdm import tqdm
13 |
14 |
15 | def load_image(file_path):
16 | img = cv2.imread(file_path, cv2.IMREAD_UNCHANGED)
17 | if img is None:
18 | raise ValueError(f"Failed to load image: {file_path}")
19 |
20 | # Convert BGR to RGB for 3-channel color images
21 | if len(img.shape) == 3 and img.shape[2] == 3:
22 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
23 | # Convert BGRA to RGBA for 4-channel images
24 | elif len(img.shape) == 3 and img.shape[2] == 4:
25 | img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA)
26 |
27 | return img
28 |
29 |
30 | def calculate_original_size(sample_path):
31 | total_size = 0
32 | for file_name in os.listdir(sample_path):
33 | if file_name != 'data.hdf5': # Exclude the HDF5 file if it exists
34 | file_path = os.path.join(sample_path, file_name)
35 | if os.path.isfile(file_path):
36 | total_size += os.path.getsize(file_path)
37 | return total_size
38 |
39 |
40 | def process_sample(sample_path, verbose, subsample_factor, compression_level=9):
41 | output_file = os.path.join(sample_path, 'data.hdf5')
42 | sample_path = os.path.join(sample_path, "raw")
43 |
44 | # Skip if HDF5 file already exists
45 | if os.path.exists(output_file):
46 | if verbose:
47 | print(f"Skipping {sample_path}: HDF5 file already exists")
48 | return None
49 |
50 | try:
51 | original_size = calculate_original_size(sample_path)
52 |
53 | with h5py.File(output_file, 'w') as hf:
54 | for data_type in ['backward_flow', 'depth', 'forward_flow', 'normal', 'object_coordinates', 'rgba',
55 | 'segmentation']:
56 | file_pattern = f"{data_type}_*.{'tiff' if data_type == 'depth' else 'png'}"
57 | files = sorted(glob.glob(os.path.join(sample_path, file_pattern)))
58 |
59 | if not files:
60 | print(f"Warning: No {data_type} files found in {sample_path}")
61 | continue
62 |
63 | if data_type != 'rgba':
64 | # Subsample for all data types except 'rgba'
65 | files = files[::subsample_factor]
66 |
67 | images = [load_image(f) for f in files]
68 | dataset = np.stack(images)
69 |
70 | hf.create_dataset(data_type, data=dataset, compression="gzip", compression_opts=compression_level)
71 |
72 | # Store subsampling indices
73 | if len(files) > 0:
74 | total_frames = len(glob.glob(os.path.join(sample_path, 'rgba_*.png')))
75 | subsample_indices = np.arange(0, total_frames, subsample_factor)
76 | hf.create_dataset('subsample_indices', data=subsample_indices, compression="gzip",
77 | compression_opts=compression_level)
78 |
79 | # Add events and metadata as attributes
80 | with open(os.path.join(sample_path, 'events.json'), 'r') as f:
81 | events = json.load(f)
82 | hf.attrs['events'] = json.dumps(events)
83 |
84 | with open(os.path.join(sample_path, 'metadata.json'), 'r') as f:
85 | metadata = json.load(f)
86 | hf.attrs['metadata'] = json.dumps(metadata)
87 |
88 | # Check for and add additional metadata from metadata.yaml
89 | yaml_path = os.path.join(sample_path, 'metadata.yaml')
90 | if os.path.exists(yaml_path):
91 | with open(yaml_path, 'r') as f:
92 | additional_metadata = yaml.safe_load(f)
93 | hf.attrs['additional_metadata'] = json.dumps(additional_metadata)
94 |
95 | if verbose:
96 | compressed_size = os.path.getsize(output_file)
97 | compression_factor = original_size / compressed_size
98 | return {
99 | 'sample': os.path.basename(sample_path),
100 | 'original_size': original_size,
101 | 'compressed_size': compressed_size,
102 | 'compression_factor': compression_factor
103 | }
104 | return None
105 | except Exception as e:
106 | return {'sample': os.path.basename(sample_path), 'error': str(e)}
107 |
108 |
109 | def process_dataset(dataset_root, verbose, num_processes, subsample_factor, compression_level=9):
110 | sample_paths = [os.path.join(dataset_root, d) for d in os.listdir(dataset_root) if
111 | os.path.isdir(os.path.join(dataset_root, d))]
112 |
113 | with multiprocessing.Pool(processes=num_processes) as pool:
114 | results = list(tqdm(pool.imap(partial(process_sample, verbose=verbose, subsample_factor=subsample_factor,
115 | compression_level=compression_level), sample_paths),
116 | total=len(sample_paths), desc="Processing samples"))
117 |
118 | failed_samples = [result['sample'] for result in results if result and 'error' in result]
119 |
120 | if verbose:
121 | for result in results:
122 | if result and 'error' not in result:
123 | print(f"Sample: {result['sample']}")
124 | print(f"Original size: {result['original_size'] / 1024:.2f} KB")
125 | print(f"Compressed size: {result['compressed_size'] / 1024:.2f} KB")
126 | print(f"Compression factor: {result['compression_factor']:.2f}x")
127 | print("--------------------")
128 |
129 | # Print and log failed samples
130 | if failed_samples:
131 | print("Failed samples:")
132 | for sample in failed_samples:
133 | print(sample)
134 |
135 | log_file = os.path.join(dataset_root, 'log.txt')
136 | with open(log_file, 'w') as f:
137 | f.write("Failed samples:\n")
138 | for sample in failed_samples:
139 | f.write(f"{sample}\n")
140 | print(f"Failed samples have been logged in {log_file}")
141 |
142 |
143 | if __name__ == "__main__":
144 | parser = argparse.ArgumentParser(
145 | description="Convert dataset to HDF5 format using gzip compression with multiprocessing and subsampling")
146 | parser.add_argument("dataset_root", help="Path to the dataset root")
147 | parser.add_argument("-v", "--verbose", action="store_true", help="Print compression information for each sample")
148 | parser.add_argument("-j", "--jobs", type=int, default=multiprocessing.cpu_count(),
149 | help="Number of parallel jobs to run (default: number of CPU cores)")
150 | parser.add_argument("-s", "--subsample", type=int, default=1,
151 | help="Subsample factor for non-rgba data (default: 1)")
152 | parser.add_argument("-c", "--compression", type=int, default=4, choices=range(0, 10),
153 |                         help="Gzip compression level (0-9, default: 4)")
154 | args = parser.parse_args()
155 |
156 | process_dataset(args.dataset_root, args.verbose, args.jobs, args.subsample, args.compression)
157 | print("Conversion completed.")
158 |
--------------------------------------------------------------------------------
/data_pipeline/convert.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import subprocess
3 | import os
4 |
5 | def main():
6 | ap = argparse.ArgumentParser()
7 | ap.add_argument("--dataset_path", required=True)
8 | ap.add_argument("--start_index" , type=int, required=True)
9 | ap.add_argument("--end_index" , type=int, required=True)
10 |
11 | # Event generation parameters
12 | ap.add_argument("--frame_rate", type=int , default=48)
13 | ap.add_argument("--num_frames", type=int , default=96)
14 | ap.add_argument("--ct_lower", type=float, default=0.16)
15 | ap.add_argument("--ct_upper", type=float, default=0.34)
16 | ap.add_argument("--ref_period", type=int, default=0)
17 |
18 | args = vars(ap.parse_args())
19 | dataset_path = args["dataset_path"]
20 | start_index = args["start_index"]
21 | end_index = args["end_index"]
22 |
23 | frame_rate = args["frame_rate"]
24 | num_frames = args["num_frames"]
25 | ct_lower = args["ct_lower"]
26 | ct_upper = args["ct_upper"]
27 | ref_period = args["ref_period"]
28 |
29 | example_counter = start_index
30 |
31 | def get_current_example_path():
32 | return os.path.join(dataset_path, f"{example_counter:08d}")
33 |
34 | while example_counter < end_index:
35 | print(f"Converting example {example_counter}")
36 |
37 | script = ["python3",
38 | "converter.py",
39 | f"--scene_dir={os.path.join(get_current_example_path(), 'raw')}",
40 | f"--output_dir={get_current_example_path()}",
41 | f"--frame_rate={frame_rate}",
42 | f"--num_frames={num_frames}",
43 | f"--ct_lower={ct_lower}",
44 | f"--ct_upper={ct_upper}",
45 | f"--ref_period={ref_period}"]
46 |
47 | convert_result = subprocess.run(script, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
48 |
49 | if convert_result.returncode == 0:
50 | print(f"Successfully converted example {example_counter}")
51 | else:
52 | print(f"Failed to convert example {example_counter}, return code: {convert_result.returncode}")
53 | break
54 |
55 | example_counter += 1
56 |
57 | if __name__ == "__main__":
58 | main()
--------------------------------------------------------------------------------
/data_pipeline/converter.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from rpg_vid2e.upsampling.utils import Upsampler
3 | from rpg_vid2e.esim_torch.scripts.generate_events import process_dir
4 |
5 | import argparse
6 | import shutil
7 | import os
8 |
9 | def main():
10 | ap = argparse.ArgumentParser()
11 | ap.add_argument("--scene_dir", required=True)
12 | ap.add_argument("--output_dir", required=True)
13 |
14 | # Event generation parameters
15 | ap.add_argument("--frame_rate", type=int , default=48)
16 | ap.add_argument("--num_frames", type=int , default=96)
17 | ap.add_argument("--ct_lower", type=float, default=0.16)
18 | ap.add_argument("--ct_upper", type=float, default=0.34)
19 | ap.add_argument("--ref_period", type=int, default=0)
20 |
21 | ## Only for vid2e's process_dir
22 |     ap.add_argument("--contrast_threshold_negative", "-cn", type=float, default=0.2)
23 |     ap.add_argument("--contrast_threshold_positive", "-cp", type=float, default=0.2)
24 | ap.add_argument("--refractory_period_ns", "-rp", type=int, default=0)
25 |
26 | args = vars(ap.parse_args())
27 | scene_dir = args["scene_dir"]
28 | output_dir = args["output_dir"]
29 |
30 | frame_rate = args["frame_rate"]
31 | num_frames = args["num_frames"]
32 | ct_lower = args["ct_lower"]
33 | ct_upper = args["ct_upper"]
34 | ref_period = args["ref_period"]
35 |
36 | tmpf = os.path.join(output_dir, "tmp")
37 | os.makedirs(os.path.join(tmpf, "seq", "imgs"))
38 |
39 | rgbs = [f"rgba_{i:05d}.png" for i in range(num_frames)]
40 |
41 | for rgb in rgbs:
42 | shutil.copy(os.path.join(scene_dir, rgb),
43 | os.path.join(tmpf, "seq", "imgs", rgb.split("_")[1]))
44 |
45 | # Upsample frames
46 | fpsf = open(os.path.join(tmpf, "seq", "fps.txt"), "w")
47 | fpsf.write(str(frame_rate))
48 | fpsf.close()
49 |
50 | upsampler = Upsampler(input_dir=os.path.join(tmpf, "seq"),
51 | output_dir=os.path.join(tmpf, "seq_upsampled"))
52 | upsampler.upsample()
53 |
54 | # Generate events
55 | vid2e_args = ap.parse_args()
56 | vid2e_args.contrast_threshold_positive = np.random.uniform(ct_lower, ct_upper)
57 | vid2e_args.contrast_threshold_negative = np.random.uniform(ct_lower, ct_upper)
58 | vid2e_args.refractory_period_ns = ref_period
59 |
60 | process_dir(os.path.join(output_dir, "events"),
61 | os.path.join(tmpf, "seq_upsampled"),
62 | vid2e_args)
63 |
64 | # Remove temporary files
65 | shutil.rmtree(tmpf)
66 |
67 | if __name__ == "__main__":
68 | main()
--------------------------------------------------------------------------------
/data_pipeline/decompress.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import h5py
4 | import numpy as np
5 | import cv2
6 | import argparse
7 |
8 | def save_image(data, file_path, data_type):
9 | if file_path.endswith('.tiff') or file_path.endswith('.tif'):
10 | cv2.imwrite(file_path, data)
11 | else:
12 | if data_type in ['backward_flow', 'forward_flow'] or data.dtype == np.uint16:
13 | cv2.imwrite(file_path, cv2.cvtColor(data, cv2.COLOR_RGB2BGR))
14 | #cv2.imwrite(file_path, data)
15 | elif len(data.shape) == 2: # Grayscale image
16 | cv2.imwrite(file_path, data)
17 | elif len(data.shape) == 3:
18 | if data.shape[2] == 3: # RGB image
19 | cv2.imwrite(file_path, cv2.cvtColor(data, cv2.COLOR_RGB2BGR))
20 | elif data.shape[2] == 4: # RGBA image
21 | cv2.imwrite(file_path, cv2.cvtColor(data, cv2.COLOR_RGBA2BGRA))
22 | else:
23 | raise ValueError(f"Unsupported image shape: {data.shape}")
24 | else:
25 | raise ValueError(f"Unsupported image shape: {data.shape}")
26 |
27 | def extract_sample(hdf5_path, output_folder):
28 | # Create output folder
29 | os.makedirs(output_folder, exist_ok=True)
30 |
31 | with h5py.File(hdf5_path, 'r') as hf:
32 | # Extract datasets
33 | for data_type in ['backward_flow', 'depth', 'forward_flow', 'normal', 'object_coordinates', 'rgba', 'segmentation']:
34 | if data_type in hf:
35 | data = hf[data_type][:]
36 | if len(data.shape) == 3: # Single image or multiple 2D frames
37 | for i, frame in enumerate(data):
38 | file_name = f"{data_type}_{i:06d}.{'tiff' if data_type == 'depth' else 'png'}"
39 | file_path = os.path.join(output_folder, file_name)
40 | save_image(frame, file_path, data_type)
41 | elif len(data.shape) == 4: # Multiple color frames
42 | for i, frame in enumerate(data):
43 | file_name = f"{data_type}_{i:06d}.{'tiff' if data_type == 'depth' else 'png'}"
44 | file_path = os.path.join(output_folder, file_name)
45 | save_image(frame, file_path, data_type)
46 | else:
47 | print(f"Warning: Unexpected shape for {data_type}: {data.shape}")
48 | else:
49 | print(f"Warning: {data_type} not found in HDF5 file")
50 |
51 | # Extract metadata and events
52 | if 'metadata' in hf.attrs:
53 | metadata = json.loads(hf.attrs['metadata'])
54 | with open(os.path.join(output_folder, 'metadata.json'), 'w') as f:
55 | json.dump(metadata, f, indent=2)
56 |
57 | if 'events' in hf.attrs:
58 | events = json.loads(hf.attrs['events'])
59 | with open(os.path.join(output_folder, 'events.json'), 'w') as f:
60 | json.dump(events, f, indent=2)
61 |
62 | if __name__ == "__main__":
63 | parser = argparse.ArgumentParser(description="Extract sample data from HDF5 file to original file structure")
64 | parser.add_argument("hdf5_path", help="Path to the input HDF5 file")
65 | parser.add_argument("output_folder", help="Path to the output folder for extracted files")
66 | args = parser.parse_args()
67 |
68 | extract_sample(args.hdf5_path, args.output_folder)
69 | print("Extraction completed.")
70 |
--------------------------------------------------------------------------------
/data_pipeline/sample.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import subprocess
3 | import shutil
4 | import time
5 | import os
6 |
7 | def main():
8 | ap = argparse.ArgumentParser()
9 | ap.add_argument("--output_dir" , required=True)
10 | ap.add_argument("--start_index", type=int, required=True)
11 | ap.add_argument("--end_index" , type=int, required=True)
12 | ap.add_argument("--worker_script", required=True)
13 | ap.add_argument("--panning" , action="store_true")
14 |
15 | args = vars(ap.parse_args())
16 | output_dir = args["output_dir"]
17 | start_index = args["start_index"]
18 | end_index = args["end_index"]
19 | worker_script = args["worker_script"] + ".py"
20 | panning = args["panning"]
21 |
22 | example_counter = start_index
23 | os.makedirs(output_dir, exist_ok=True)
24 |
25 | def get_current_output_dir():
26 | return os.path.join(output_dir, f"{example_counter:08d}")
27 |
28 | while example_counter < end_index:
29 | print(f"Generating example {example_counter}")
30 |
31 | script = ["python3",
32 | os.path.join("kubric", "challenges", "movi", worker_script),
33 | f"--job-dir={os.path.join(get_current_output_dir(), 'raw')}"]
34 |
35 | if panning:
36 | script.append("--camera=linear_movement_linear_lookat")
37 | else:
38 | script.append("--camera=linear_movement")
39 |
40 | # Generate training example
41 | kubric_result = subprocess.run(script, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
42 |
43 | # Regenerate training example on error
44 | if kubric_result.returncode == 0:
45 | print(f"Successfully generated example {example_counter}")
46 | else:
47 | print(f"Failed to generate example {example_counter}, return code: {kubric_result.returncode}")
48 | print("Retrying in 10 seconds . . .")
49 |
50 | if os.path.exists(get_current_output_dir()):
51 | shutil.rmtree(get_current_output_dir())
52 |
53 | time.sleep(10)
54 | continue
55 |
56 | example_counter += 1
57 |
58 | if __name__ == "__main__":
59 | main()
60 |
--------------------------------------------------------------------------------
/docs/flowchart.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tub-rip/ETAP/e13875244b45e651c1ebcce1e769ed05ffbf7acd/docs/flowchart.png
--------------------------------------------------------------------------------
/docs/pred_e2d2_fidget.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tub-rip/ETAP/e13875244b45e651c1ebcce1e769ed05ffbf7acd/docs/pred_e2d2_fidget.gif
--------------------------------------------------------------------------------
/docs/pred_eds_peanuts_running.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tub-rip/ETAP/e13875244b45e651c1ebcce1e769ed05ffbf7acd/docs/pred_eds_peanuts_running.gif
--------------------------------------------------------------------------------
/docs/pred_event_kubric.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tub-rip/ETAP/e13875244b45e651c1ebcce1e769ed05ffbf7acd/docs/pred_event_kubric.gif
--------------------------------------------------------------------------------
/docs/pred_evimo2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tub-rip/ETAP/e13875244b45e651c1ebcce1e769ed05ffbf7acd/docs/pred_evimo2.gif
--------------------------------------------------------------------------------
/docs/thumbnail.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tub-rip/ETAP/e13875244b45e651c1ebcce1e769ed05ffbf7acd/docs/thumbnail.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.19.0
2 | pyyaml>=5.1.0
3 | tqdm>=4.42.0
4 | h5py>=3.1.0
5 | hdf5plugin
6 | opencv-python>=4.5.0
7 | pandas>=2.2.0
8 | pillow>=10.4.0
9 | pytorch-lightning>=2.2.5
10 | matplotlib
11 | imageio[ffmpeg]
12 | prettytable
13 | scipy
--------------------------------------------------------------------------------
/scripts/benchmark_feature_tracking.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | import os
4 | import sys
5 | from enum import Enum
6 | import csv
7 |
8 | import matplotlib.pyplot as plt
9 |
10 | import numpy as np
11 | from prettytable import PrettyTable
12 |
13 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
14 |
15 | from src.utils import compute_tracking_errors, read_txt_results, SUPPORTED_SEQUENCES_FEATURE_TRACKING
16 |
17 | def calculate_mean_by_dataset_type(table, dataset_type):
18 | values = []
19 | for row in table.rows:
20 | if dataset_type == EvalDatasetType.EDS:
21 | if row[0] in SUPPORTED_SEQUENCES_FEATURE_TRACKING['eds']:
22 | values.append(float(row[1]))
23 | else: # EvalDatasetType.EC
24 | if row[0] in SUPPORTED_SEQUENCES_FEATURE_TRACKING['ec']:
25 | values.append(float(row[1]))
26 | return np.mean(values) if values else 0
27 |
28 | def create_summary_csv(tables, output_path):
29 | header = ['fa_5_eds_mean', 'fa_5_ec_mean', 'te_5_eds_mean', 'te_5_ec_mean', 'inliers_eds_mean', 'inliers_ec_mean']
30 | data = {}
31 |
32 | for k in tables.keys():
33 | eds_mean = calculate_mean_by_dataset_type(tables[k], EvalDatasetType.EDS)
34 | ec_mean = calculate_mean_by_dataset_type(tables[k], EvalDatasetType.EC)
35 |
36 | if k.startswith('age_5'):
37 | data['fa_5_eds_mean'] = eds_mean
38 | data['fa_5_ec_mean'] = ec_mean
39 | elif k.startswith('te_5'):
40 | data['te_5_eds_mean'] = eds_mean
41 | data['te_5_ec_mean'] = ec_mean
42 | elif k.startswith('inliers'):
43 | data['inliers_eds_mean'] = eds_mean
44 | data['inliers_ec_mean'] = ec_mean
45 |
46 | # Add columns for individual sequences
47 | all_sequences = []
48 | for sequence_name in SUPPORTED_SEQUENCES_FEATURE_TRACKING['eds']:
49 | all_sequences.append((sequence_name, EvalDatasetType.EDS))
50 | for sequence_name in SUPPORTED_SEQUENCES_FEATURE_TRACKING['ec']:
51 | all_sequences.append((sequence_name, EvalDatasetType.EC))
52 |
53 | for i, eval_sequence in enumerate(all_sequences):
54 | sequence_name = eval_sequence[0]
55 | for k in ['age_5', 'te_5', 'inliers']:
56 | column_name = f"{k}_{sequence_name}"
57 | header.append(column_name)
58 | data[column_name] = tables[f"{k}_mu"].rows[i][1]
59 |
60 | with open(output_path, 'w', newline='') as f:
61 | writer = csv.DictWriter(f, fieldnames=header)
62 | writer.writeheader()
63 | writer.writerow(data)
64 |
65 |
66 | # Dataset type enum used to split evaluation into EDS and EC sequences
67 | class EvalDatasetType(Enum):
68 | EC = 0
69 | EDS = 1
70 |
71 | plt.rcParams["font.family"] = "serif"
72 |
73 | def parse_args():
74 | """
75 | Parse command-line arguments for the script.
76 |
77 | Returns:
78 | argparse.Namespace: An object containing parsed arguments.
79 | """
80 | parser = argparse.ArgumentParser(
81 | description="Parse paths for results and output directories."
82 | )
83 | parser.add_argument(
84 | 'method',
85 | type=Path,
86 | help="Path to the output directory."
87 | )
88 |
89 | parser.add_argument(
90 | '--results_dir',
91 | type=Path,
92 | default=Path("output/inference"),
93 | help="Path to the results directory."
94 | )
95 |
96 | args = parser.parse_args()
97 | return args
98 |
99 |
100 | if __name__ == "__main__":
101 | args = parse_args()
102 | error_threshold_range = np.arange(1, 32, 1)
103 | methods = [args.method]
104 |
105 | table_keys = [
106 | "age_5_mu",
107 | "age_5_std",
108 | "te_5_mu",
109 | "te_5_std",
110 | "age_mu",
111 | "age_std",
112 | "inliers_mu",
113 | "inliers_std",
114 | "expected_age",
115 | ]
116 | tables = {}
117 | for k in table_keys:
118 | tables[k] = PrettyTable()
119 | tables[k].title = k
120 | tables[k].field_names = ["Sequence Name"] + methods
121 |
122 | # Create a list of sequences with their dataset types
123 | eval_sequences = []
124 | for sequence_name in SUPPORTED_SEQUENCES_FEATURE_TRACKING['eds']:
125 | eval_sequences.append((sequence_name, EvalDatasetType.EDS))
126 | for sequence_name in SUPPORTED_SEQUENCES_FEATURE_TRACKING['ec']:
127 | eval_sequences.append((sequence_name, EvalDatasetType.EC))
128 |
129 | for eval_sequence in eval_sequences:
130 | sequence_name = eval_sequence[0]
131 | sequence_type = eval_sequence[1]
132 |
133 | gt_folder_name = 'eds' if sequence_type == EvalDatasetType.EDS else 'ec'
134 | track_data_gt = read_txt_results(
135 | Path('config/misc') / gt_folder_name / 'gt_tracks' / f"{sequence_name}.gt.txt"
136 | )
137 |
138 | rows = {}
139 | for k in tables.keys():
140 | rows[k] = [sequence_name]
141 |
142 | for method in methods:
143 | inlier_ratio_arr, fa_rel_nz_arr = [], []
144 |
145 | track_data_pred = read_txt_results(
146 | str(args.results_dir / f"{method}" / f"{sequence_name}.txt")
147 | )
148 |
149 | if track_data_pred[0, 1] != track_data_gt[0, 1]:
150 | raise ValueError(f"First prediction timestamp does not match ground truth for sequence {sequence_name}")  # TODO: double check if this case occurs
151 | track_data_pred[:, 1] += -track_data_pred[0, 1] + track_data_gt[0, 1]
152 |
153 | for thresh in error_threshold_range:
154 | fa_rel, _ = compute_tracking_errors(
155 | track_data_pred,
156 | track_data_gt,
157 | error_threshold=thresh,
158 | asynchronous=False,
159 | )
160 |
161 | inlier_ratio = np.sum(fa_rel > 0) / len(fa_rel)
162 | if inlier_ratio > 0:
163 | fa_rel_nz = fa_rel[np.nonzero(fa_rel)[0]]
164 | else:
165 | fa_rel_nz = [0]
166 | inlier_ratio_arr.append(inlier_ratio)
167 | fa_rel_nz_arr.append(np.mean(fa_rel_nz))
168 |
169 | mean_inlier_ratio, std_inlier_ratio = np.mean(inlier_ratio_arr), np.std(
170 | inlier_ratio_arr
171 | )
172 | mean_fa_rel_nz, std_fa_rel_nz = np.mean(fa_rel_nz_arr), np.std(fa_rel_nz_arr)
173 | expected_age = np.mean(np.array(inlier_ratio_arr) * np.array(fa_rel_nz_arr))
174 |
175 | rows["age_mu"].append(mean_fa_rel_nz)
176 | rows["age_std"].append(std_fa_rel_nz)
177 | rows["inliers_mu"].append(mean_inlier_ratio)
178 | rows["inliers_std"].append(std_inlier_ratio)
179 | rows["expected_age"].append(expected_age)
180 |
181 | fa_rel, te = compute_tracking_errors(
182 | track_data_pred, track_data_gt, error_threshold=5, asynchronous=False
183 | )
184 | inlier_ratio = np.sum(fa_rel > 0) / len(fa_rel)
185 | if inlier_ratio > 0:
186 | fa_rel_nz = fa_rel[np.nonzero(fa_rel)[0]]
187 | else:
188 | fa_rel_nz = [0]
189 | te = [0]
190 |
191 | mean_fa_rel_nz, std_fa_rel_nz = np.mean(fa_rel_nz), np.std(fa_rel_nz)
192 | mean_te, std_te = np.mean(te), np.std(te)
193 | rows["age_5_mu"].append(mean_fa_rel_nz)
194 | rows["age_5_std"].append(std_fa_rel_nz)
195 | rows["te_5_mu"].append(mean_te)
196 | rows["te_5_std"].append(std_te)
197 |
198 | # Add the per-sequence results to the tables
199 | for k in tables.keys():
200 | tables[k].add_row(rows[k])
201 |
202 | with open(args.results_dir / f"{method}" / "benchmarking_results.csv", "w") as f:
203 | for k in tables.keys():
204 | f.write(f"{k}\n")
205 | f.write(tables[k].get_csv_string())
206 |
207 | # Calculate and write mean values for EDS and EC
208 | eds_mean = calculate_mean_by_dataset_type(tables[k], EvalDatasetType.EDS)
209 | ec_mean = calculate_mean_by_dataset_type(tables[k], EvalDatasetType.EC)
210 |
211 | f.write(f"EDS Mean,{eds_mean}\n")
212 | f.write(f"EC Mean,{ec_mean}\n\n")
213 |
214 | print(tables[k].get_string())
215 | print(f"EDS Mean: {eds_mean}")
216 | print(f"EC Mean: {ec_mean}\n")
217 |
218 | summary_csv_path = args.results_dir / f"{method}" / "summary_results.csv"
219 | create_summary_csv(tables, summary_csv_path)
220 |
221 | print(f"Summary results written to {summary_csv_path}")
--------------------------------------------------------------------------------
/scripts/benchmark_tap.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 | from pathlib import Path
4 | import numpy as np
5 | sys.path.append(str(Path(__file__).parent.parent))
6 | import src.utils as utils
7 |
8 | def parse_args():
9 | parser = argparse.ArgumentParser(description='Compute TAPVid metrics for point tracking')
10 | parser.add_argument('--gt_dir', type=str, default='data/e2d2/231025_110210_fidget5_high_exposure',
11 | help='Directory containing ground truth files')
12 | parser.add_argument('--pred_dir', type=str, default='output/inference/tap_e2d2',
13 | help='Directory containing prediction files')
14 | return parser.parse_args()
15 |
16 | if __name__ == '__main__':
17 | args = parse_args()
18 | gt_dir = Path(args.gt_dir)
19 | pred_dir = Path(args.pred_dir)
20 | error_threshold_range = 2 * np.array([1, 2, 4, 8, 16])
21 |
22 | gt_tracks_path = gt_dir / 'gt_positions.npy'
23 | gt_tracks = np.load(gt_tracks_path)
24 |
25 | query_path = gt_dir / 'queries.npy'
26 | query_points = np.load(query_path)
27 | query_t = np.zeros((query_points.shape[0], 1)) # All query points are at t = 0
28 | query_points = np.concatenate([query_t, query_points], axis=1)
29 |
30 | pred_path = Path(pred_dir) / '231025_110210_fidget5_high_exposure.npz'
31 | pred = np.load(pred_path)
32 | coords_predicted = pred['coords_predicted']
33 | vis_logits = pred['vis_logits']
34 | vis_predicted = vis_logits > 0.8
35 |
36 | gt_tracks_formatted = np.expand_dims(np.transpose(gt_tracks, (1, 0, 2)), axis=0)
37 | coords_predicted_formatted = np.expand_dims(np.transpose(coords_predicted, (1, 0, 2)), axis=0)
38 |
39 | # Create occlusion masks (assuming no occlusions in gt, using vis_predicted for predictions)
40 | num_points, num_frames = gt_tracks.shape[1], gt_tracks.shape[0]
41 | occluded_gt = np.zeros((1, num_points, num_frames), dtype=bool)
42 | occluded_pred = np.expand_dims(~np.transpose(vis_predicted, (1, 0)), axis=0) # Assuming vis_predicted is visibility, not occlusion
43 |
44 | tap_metrics = utils.compute_tapvid_metrics(
45 | query_points=np.expand_dims(query_points, axis=0), # Add batch dimension
46 | gt_occluded=occluded_gt,
47 | gt_tracks=gt_tracks_formatted,
48 | pred_occluded=occluded_pred,
49 | pred_tracks=coords_predicted_formatted,
50 | query_mode='first',
51 | thresholds=error_threshold_range
52 | )
53 | print("TAPVid Metrics:")
54 | for k, v in tap_metrics.items():
55 | print(f"{k}: {v}")
56 |
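57 | # Hedged usage sketch: run after online inference has written the .npz predictions;
58 | # the paths below are simply the argparse defaults above.
59 | #   python scripts/benchmark_tap.py \
60 | #       --gt_dir data/e2d2/231025_110210_fidget5_high_exposure \
61 | #       --pred_dir output/inference/tap_e2d2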
--------------------------------------------------------------------------------
/scripts/create_e2d2_fidget_spinner_gt.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 | import h5py
5 | from tqdm import tqdm
6 | import torch
7 | from PIL import Image, ImageDraw, ImageFont
8 | from pathlib import Path
9 | import matplotlib.pyplot as plt
10 | from scipy.signal import find_peaks
11 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
12 | from src.representations import VoxelGrid
13 | from src.utils import Visualizer
14 |
15 | def make_grid(height, width, stride=1):
16 | x = np.arange(0, width, stride)
17 | y = np.arange(0, height, stride)
18 | X, Y = np.meshgrid(x, y)
19 | return np.stack([X.flatten(), Y.flatten()], axis=1)
20 |
21 | def read_binary_mask(path):
22 | img = np.array(Image.open(path).convert('L'))
23 | mask = (img != 255).astype(np.uint8)
24 | return mask
25 |
26 | def interpolate_wheel_points(query_points, center_point, timestamps, valleys, t_start, target_timestamps=None):
27 | """
28 | Interpolate point positions on a rotating wheel.
29 |
30 | Args:
31 | query_points: numpy array of shape [N, 2] containing initial (x, y) positions at t_start
32 | center_point: tuple or array (x, y) representing wheel center
33 | timestamps: array of all timestamps used for valley detection
34 | valleys: indices of valleys in timestamps array (when wheel rotates exactly 1/3)
35 | t_start: initial timestamp
36 | target_timestamps: optional array of timestamps at which to calculate positions
37 |
38 | Returns:
39 | numpy array of shape [len(target_timestamps), N, 2] containing interpolated positions
40 | """
41 | # Convert everything to numpy arrays and ensure correct shapes
42 | query_points = np.array(query_points)
43 | center_point = np.array(center_point)
44 |
45 | # Use target_timestamps if provided, otherwise use original timestamps
46 | target_timestamps = timestamps if target_timestamps is None else target_timestamps
47 |
48 | N = len(query_points)
49 | n_timestamps = len(target_timestamps)
50 |
51 | # Initialize output array
52 | new_positions = np.zeros((n_timestamps, N, 2))
53 |
54 | # Calculate initial angles for each point
55 | initial_vectors = query_points - center_point
56 | initial_angles = np.arctan2(initial_vectors[:, 1], initial_vectors[:, 0])
57 | radii = np.linalg.norm(initial_vectors, axis=1)
58 |
59 | # Between consecutive valleys the wheel rotates by exactly 1/3 of a turn (clockwise)
60 | valley_times = timestamps[valleys]
61 | angular_velocity = -2 * np.pi / 3 # negative for clockwise rotation
62 |
63 | # For each target timestamp, calculate the angle and new position
64 | for i, t in enumerate(target_timestamps):
65 | # Find the appropriate valley interval using original valley times
66 | if t < valley_times[0]:
67 | # Before first valley - interpolate from start
68 | delta_t = t - timestamps[0]
69 | fraction = delta_t / (valley_times[0] - timestamps[0])
70 | angle_change = fraction * angular_velocity
71 | elif t > valley_times[-1]:
72 | # After last valley - extrapolate from last interval
73 | delta_t = t - valley_times[-1]
74 | last_interval = valley_times[-1] - valley_times[-2]
75 | angle_change = angular_velocity * (len(valleys) - 1 + delta_t / last_interval)
76 | else:
77 | # Between valleys - find appropriate interval and interpolate
78 | valley_idx = np.searchsorted(valley_times, t) - 1
79 | valley_idx = max(0, valley_idx)
80 | delta_t = t - valley_times[valley_idx]
81 | interval = valley_times[valley_idx + 1] - valley_times[valley_idx]
82 | fraction = delta_t / interval
83 | angle_change = angular_velocity * (valley_idx + fraction)
84 |
85 | # Calculate new angles and positions
86 | new_angles = initial_angles + angle_change
87 |
88 | # Convert polar coordinates back to Cartesian
89 | new_positions[i, :, 0] = center_point[0] + radii * np.cos(new_angles)
90 | new_positions[i, :, 1] = center_point[1] + radii * np.sin(new_angles)
91 |
92 | return new_positions
93 |
94 | if __name__ == '__main__':
95 | data_dir = Path('data/e2d2/231025_110210_fidget5_high_exposure')
96 | output_dir = Path('output/e2d2_gt/231025_110210_fidget5_high_exposure')
97 |
98 | data_path = data_dir / 'seq.h5'
99 | sequence_name = '231025_110210_fidget5_high_exposure'
100 | t_start = 3.3961115
101 | duration = 0.5
102 | t_delta = 0.001
103 | timestamps = np.arange(t_start, t_start + duration, t_delta)
104 | timestamps = (timestamps * 1e6).astype(int)
105 | N_events = 20000
106 | image_shape = (480, 640)
107 | converter = VoxelGrid(image_shape, num_bins=1)
108 | threshold_l2_norm = 300
109 | mask_path = data_dir / 'mask00004.png'
110 | query_stride = 40
111 | center_point = np.array([369, 229])
112 | t_delta_gt = 0.0033
113 | timestamps_gt = np.arange(t_start, t_start + duration, t_delta_gt)
114 |
115 | output_dir = output_dir / f'{str(int(1e6 * t_delta_gt)).zfill(8)}'
116 | os.makedirs(output_dir, exist_ok=True)
117 | #os.makedirs(output_dir / 'histograms', exist_ok=True)
118 |
119 | # List to store L2 norms
120 | l2_norms = []
121 | first_frame = None
122 |
123 | with h5py.File(data_path, 'r') as f:
124 | ev_idx = np.searchsorted(f['t'][:], timestamps) - 1
125 | ev_idx_start = ev_idx - N_events // 2
126 | ev_idx_end = ev_idx + N_events // 2
127 |
128 | vid = []
129 |
130 | for i_start, i_end, t in tqdm(zip(ev_idx_start, ev_idx_end, timestamps), total=len(timestamps)):
131 | events = np.stack([f['y'][i_start:i_end],
132 | f['x'][i_start:i_end],
133 | f['t'][i_start:i_end],
134 | f['p'][i_start:i_end]], axis=-1)
135 |
136 | repr = converter(events)
137 | repr = repr[0]
138 |
139 | if first_frame is None:
140 | first_frame = repr.copy()
141 |
142 | l2_norm = np.sqrt(np.sum((repr - first_frame) ** 2))
143 | l2_norms.append(l2_norm)
144 |
145 | if repr.max() != repr.min():
146 | repr_norm = ((repr - repr.min()) / (repr.max() - repr.min()) * 255).astype(np.uint8)
147 | else:
148 | repr_norm = np.zeros_like(repr, dtype=np.uint8)
149 |
150 | # Convert to PIL Image
151 | img = Image.fromarray(repr_norm)
152 |
153 | # Add timestamp to image
154 | draw = ImageDraw.Draw(img)
155 | try:
156 | font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 20)
157 | except OSError:
158 | font = ImageFont.load_default()
159 |
160 | timestamp_sec = t / 1e6 # Convert to seconds for display
161 | draw.text((10, 10), f"{timestamp_sec:.3f}s", font=font, fill=255)
162 |
163 | frame_with_timestamp = np.array(img)
164 | vid.append(frame_with_timestamp)
165 |
166 | #filename = f"{t:012d}.png"
167 | #filepath = output_dir / 'histograms' / filename
168 | #img.save(filepath)
169 |
170 | video = np.stack(vid, axis=0) # Shape: [T, H, W]
171 | video = np.float32(video)
172 | video = np.stack([video, video, video], axis=1) # Shape: [T, 3, H, W]
173 | video = 255 * (video - video.min()) / (video.max() - video.min())
174 |
175 | # Convert to numpy array and invert the signal
176 | l2_norms = np.array(l2_norms)
177 | inverted_signal = -l2_norms
178 |
179 | # Find valleys (peaks in the inverted signal)
180 | valleys, _ = find_peaks(inverted_signal,
181 | height=(-threshold_l2_norm, 0),
182 | distance=10,
183 | prominence=10)
184 |
185 | valleys = np.concatenate([np.zeros(1, dtype=valleys.dtype), valleys])
186 |
187 | # Plot detected minima with scientific formatting
188 | plt.figure(figsize=(12, 6))
189 | timestamps_sec = timestamps / 1e6
190 | plt.rcParams.update({
191 | 'font.family': 'DejaVu Serif',
192 | 'font.size': 14,
193 | 'axes.labelsize': 16,
194 | 'axes.titlesize': 16,
195 | 'xtick.labelsize': 14,
196 | 'ytick.labelsize': 14,
197 | 'legend.fontsize': 14
198 | })
199 |
200 | plt.plot(timestamps_sec, l2_norms, label='L2 Norm')
201 | plt.plot(timestamps_sec[valleys], l2_norms[valleys], 'r*', label='Detected Valleys')
202 | plt.axhline(y=threshold_l2_norm, color='r', linestyle='--',
203 | label=f'Threshold ({threshold_l2_norm})')
204 | plt.xlabel('Time (s)')
205 | plt.ylabel('L2 Norm')
206 | plt.grid(True)
207 | plt.legend()
208 | plt.gca().xaxis.set_major_formatter(plt.FormatStrFormatter('%.3f'))
209 | plt.xticks(rotation=45)
210 | plt.tight_layout()
211 |
212 | plt.savefig(output_dir / 'l2_norms.pdf', format='pdf', bbox_inches='tight')
213 | plt.close()
214 |
215 | # Make queries
216 | height, width = image_shape
217 | query_xy = make_grid(height, width, stride=query_stride)
218 | num_queries = query_xy.shape[0]
219 |
220 | if mask_path is not None:
221 | segm_mask = read_binary_mask(mask_path)
222 | query_x = query_xy[:, 0]
223 | query_y = query_xy[:, 1]
224 | segm_mask = segm_mask[query_y, query_x]
225 | query_xy = query_xy[segm_mask == 1]
226 |
227 | # Interpolate positions for visualization
228 | new_positions = interpolate_wheel_points(query_xy, center_point, timestamps, valleys, t_start)
229 |
230 | # Calculate positions at GT timestamps
231 | timestamps_gt_us = (timestamps_gt * 1e6).astype(int) # Convert to same units as timestamps
232 | new_positions_gt = interpolate_wheel_points(query_xy, center_point, timestamps, valleys, t_start,
233 | target_timestamps=timestamps_gt_us)
234 | np.save(output_dir / 'gt_positions.npy', new_positions_gt)
235 | np.save(output_dir / 'gt_timestamps.npy', timestamps_gt_us)
236 | np.save(output_dir / 'queries.npy', query_xy)
237 |
238 | viz = Visualizer(save_dir=str(output_dir),
239 | pad_value=0,
240 | linewidth=2,
241 | tracks_leave_trace=-1,
242 | show_first_frame=0)
243 | rgbs = viz.visualize(
244 | torch.from_numpy(video[None]),
245 | torch.from_numpy(new_positions[None])
246 | )
247 |
248 | print('Done.')
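249 |
250 | # Minimal sketch of interpolate_wheel_points on toy data (assumed values, not part
251 | # of the GT generation above): one point at radius 1 around the origin, with the
252 | # valleys marking each 1/3 rotation.
253 | #   pts = np.array([[1.0, 0.0]])
254 | #   ts = np.array([0, 100, 200, 300])
255 | #   vals = np.array([1, 2, 3])  # valley indices into ts
256 | #   out = interpolate_wheel_points(pts, (0.0, 0.0), ts, vals, t_start=0,
257 | #                                  target_timestamps=np.array([50]))
258 | #   # out has shape (1, 1, 2); halfway to the first valley the point has rotated
259 | #   # by -60 degrees, i.e. out[0, 0] is approximately (0.5, -0.866).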
--------------------------------------------------------------------------------
/scripts/create_evimo2_track_gt.py:
--------------------------------------------------------------------------------
1 | # This script calculates the ground truth point tracks for EVIMO datasets.
2 | import argparse
3 | import multiprocessing
4 | import os
5 | from pathlib import Path
6 | os.environ['OPENBLAS_NUM_THREADS'] = '1'
7 | from multiprocessing import Pool
8 | import cv2
9 | import numpy as np
10 | from scipy.spatial.transform import Rotation
11 | from scipy.spatial.transform import Slerp
12 | from scipy.linalg import logm
13 | from tqdm import tqdm
14 | import h5py
15 | import multiprocessing.pool as mpp
16 | import yaml
17 | import pandas as pd
18 |
19 | def istarmap(self, func, iterable, chunksize=1):
20 | self._check_running()
21 | if chunksize < 1:
22 | raise ValueError(
23 | "Chunksize must be 1+, not {0:n}".format(
24 | chunksize))
25 |
26 | task_batches = mpp.Pool._get_tasks(func, iterable, chunksize)
27 | result = mpp.IMapIterator(self)
28 | self._taskqueue.put(
29 | (
30 | self._guarded_task_generation(result._job,
31 | mpp.starmapstar,
32 | task_batches),
33 | result._set_length
34 | ))
35 | return (item for chunk in result for item in chunk)
36 |
37 | mpp.Pool.istarmap = istarmap
38 |
39 | def interpolate_pose(t, pose):
40 | right_i = np.searchsorted(pose[:, 0], t)
41 | if right_i==pose.shape[0]:
42 | return None
43 | if right_i==0:
44 | return None
45 |
46 | left_t = pose[right_i-1, 0]
47 | right_t = pose[right_i, 0]
48 |
49 | alpha = (t - left_t) / (right_t - left_t)
50 | if alpha>1:
51 | return None
52 | elif alpha < 0:
53 | return None
54 |
55 | left_position = pose[right_i - 1, 1:4]
56 | right_position = pose[right_i, 1:4]
57 |
58 | position_interp = alpha * (right_position - left_position) + left_position
59 |
60 | left_right_rot_stack = Rotation.from_quat((pose[right_i - 1, 4:8],
61 | pose[right_i, 4:8]))
62 |
63 | slerp = Slerp((0, 1), left_right_rot_stack)
64 | R_interp = slerp(alpha)
65 |
66 | return np.array([t,] + list(position_interp) + list(R_interp.as_quat()))
67 |
68 | def apply_transform(T_cb, T_ba):
69 | R_ba = Rotation.from_quat(T_ba[4:8])
70 | t_ba = T_ba[1:4]
71 |
72 | R_cb = Rotation.from_quat(T_cb[4:8])
73 | t_cb = T_cb[1:4]
74 |
75 | R_ca = R_cb * R_ba
76 | t_ca = R_cb.as_matrix() @ t_ba + t_cb
77 | return np.array([T_ba[0],] + list(t_ca) + list(R_ca.as_quat()))
78 |
79 | def inv_transform(T_ba):
80 | R_ba = Rotation.from_quat(T_ba[4:8])
81 | t_ba = T_ba[1:4]
82 |
83 | R_ab = R_ba.inv()
84 | t_ab = -R_ba.inv().as_matrix() @ t_ba
85 |
86 | return np.array([T_ba[0],] + list(t_ab) + list(R_ab.as_quat()))
87 |
88 | def project_points_radtan(points,
89 | fx, fy, cx, cy,
90 | k1, k2, p1, p2):
91 | x_ = np.divide(points[:, :, 0], points[:, :, 2], out=np.zeros_like(points[:, :, 0]), where=points[:, :, 2]!=0)
92 | y_ = np.divide(points[:, :, 1], points[:, :, 2], out=np.zeros_like(points[:, :, 1]), where=points[:, :, 2]!=0)
93 |
94 | r2 = np.square(x_) + np.square(y_)
95 | r4 = np.square(r2)
96 |
97 | dist = (1.0 + k1 * r2 + k2 * r4)
98 |
99 | x__ = x_ * dist + 2.0 * p1 * x_ * y_ + p2 * (r2 + 2.0 * x_ * x_)
100 | y__ = y_ * dist + 2.0 * p2 * x_ * y_ + p1 * (r2 + 2.0 * y_ * y_)
101 |
102 | u = fx * x__ + cx
103 | v = fy * y__ + cy
104 |
105 | return u, v
106 |
107 | def get_all_poses(meta):
108 | vicon_pose_samples = len(meta['full_trajectory'])
109 |
110 | poses = {}
111 | key_i = {}
112 | for key in meta['full_trajectory'][0].keys():
113 | if key == 'id' or key == 'ts' or key == 'gt_frame':
114 | continue
115 | poses[key] = np.zeros((vicon_pose_samples, 1+3+4))
116 | key_i[key] = 0
117 |
118 | for all_pose in meta['full_trajectory']:
119 | for key in poses.keys():
120 | if key == 'id' or key == 'ts' or key == 'gt_frame':
121 | continue
122 |
123 | if key in all_pose:
124 | i = key_i[key]
125 | poses[key][i, 0] = all_pose['ts']
126 | poses[key][i, 1] = all_pose[key]['pos']['t']['x']
127 | poses[key][i, 2] = all_pose[key]['pos']['t']['y']
128 | poses[key][i, 3] = all_pose[key]['pos']['t']['z']
129 | poses[key][i, 4] = all_pose[key]['pos']['q']['x']
130 | poses[key][i, 5] = all_pose[key]['pos']['q']['y']
131 | poses[key][i, 6] = all_pose[key]['pos']['q']['z']
132 | poses[key][i, 7] = all_pose[key]['pos']['q']['w']
133 | key_i[key] += 1
134 |
135 | for key in poses.keys():
136 | poses[key] = poses[key][:key_i[key], :]
137 |
138 | return poses
139 |
140 | def get_intrinsics(meta):
141 | meta_meta = meta['meta']
142 | K = np.array(((meta_meta['fx'], 0, meta_meta['cx']),
143 | ( 0, meta_meta['fy'], meta_meta['cy']),
144 | ( 0, 0, 1)))
145 |
146 | dist_coeffs = np.array((meta_meta['k1'],
147 | meta_meta['k2'],
148 | meta_meta['p1'],
149 | meta_meta['p2']))
150 |
151 | return K, dist_coeffs
152 |
153 | def load_data(file):
154 | meta = np.load(Path(file) / 'dataset_info.npz', allow_pickle=True)['meta'].item()
155 | depth = np.load(Path(file) / 'dataset_depth.npz')
156 | mask = np.load(Path(file) / 'dataset_mask.npz')
157 | return meta, depth, mask
158 |
159 | def convert(file, overwrite=False, max_m_per_s=9.0, max_norm_deg_per_s=6.25*360):
160 | cv2.setNumThreads(1)
161 |
162 | h5_tracks_file_name = Path(file) / 'dataset_tracks.h5'
163 |
164 | if not overwrite and h5_tracks_file_name.exists():
165 | print(f'skipping {file} because {h5_tracks_file_name} exists')
166 | return
167 |
168 | # Load data
169 | meta, depth, mask = load_data(file)
170 | all_poses = get_all_poses(meta)
171 | K, dist_coeffs = get_intrinsics(meta)
172 |
173 | # Get depth shape for map initialization
174 | first_depth_key = 'depth_' + str(0).rjust(10, '0')
175 | depth_shape = depth[first_depth_key].shape
176 |
177 | # Initialize undistortion maps
178 | map1, map2 = cv2.initInverseRectificationMap(
179 | K,
180 | dist_coeffs,
181 | np.eye(3),
182 | np.eye(3),
183 | (depth_shape[1], depth_shape[0]),
184 | cv2.CV_32FC1)
185 |
186 | # Get info from initial frame
187 | initial_frame_idx = 1  # For frame 0 there are sometimes no previous poses, so we start from frame 1
188 | first_frame_id = meta['frames'][initial_frame_idx]['id'] # Use second frame
189 | first_frame_info = meta['frames'][initial_frame_idx]
190 | first_depth_key = 'depth_' + str(initial_frame_idx).rjust(10, '0') # Use depth from second frame
191 | first_mask_key = 'mask_' + str(initial_frame_idx).rjust(10, '0') # Use mask from second frame
192 |
193 | depth_frame = depth[first_depth_key].astype(np.float32) / 1000.0 # convert to meters
194 | mask_frame = mask[first_mask_key]
195 |
196 | # Initialize points to track (e.g., all points with valid depth)
197 | valid_points = np.where(depth_frame > 0) # [y, x]
198 | valid_points = (valid_points[0][::1000], valid_points[1][::1000])  # subsample every 1000th point (originally for debugging) to keep the track count manageable
199 |
200 | initial_points = np.stack([valid_points[1], valid_points[0]], axis=1) # nx2 array of x,y coords
201 | initial_depths = depth_frame[valid_points]
202 | initial_masks = mask_frame[valid_points]
203 |
204 | # Initialize storage for tracks
205 | num_frames = len(meta['frames'])
206 | num_points = len(initial_points)
207 | tracks = np.full((num_frames, num_points, 2), np.nan, dtype=np.float32) # store x,y coordinates
208 | occlusions = np.ones((num_frames, num_points), dtype=bool) # True means occluded
209 | times = np.full(num_frames, np.nan, dtype=np.float32)
210 |
211 | # Store initial positions at the initial frame index
212 | tracks[initial_frame_idx] = initial_points
213 | occlusions[initial_frame_idx] = False
214 | times[initial_frame_idx] = meta['frames'][initial_frame_idx]['ts']
215 |
216 | # Unproject initial points to 3D
217 | Z_m = initial_depths
218 | X_m = map1[valid_points] * Z_m
219 | Y_m = map2[valid_points] * Z_m
220 | initial_XYZ = np.stack((X_m, Y_m, Z_m), axis=1)
221 |
222 | depth_keys = sorted([k for k in depth.files if k.startswith('depth_')])
223 | frame_numbers = [int(k.split('_')[1]) for k in depth_keys]
224 |
225 | # Get initial poses for each object
226 | initial_time = first_frame_info['cam']['ts']
227 | initial_poses = {}
228 | for key in all_poses:
229 | if key != 'cam':
230 | initial_poses[key] = interpolate_pose(initial_time, all_poses[key])
231 |
232 | # Track through all frames starting from the third frame
233 | found_gap = False
234 | for frame_idx, frame_info in enumerate(meta['frames'][initial_frame_idx + 1:]):
235 | frame_idx += (initial_frame_idx + 1)  # Offset so frame_idx refers to the absolute frame index (enumeration starts after the initial frame)
236 |
237 | if frame_numbers[frame_idx] != frame_numbers[frame_idx - 1] + 1:
238 | print(f"Gap found in depth maps between frames {frame_numbers[frame_idx-1]} and {frame_numbers[frame_idx]}")
239 | found_gap = True
240 | break
241 |
242 | times[frame_idx] = frame_info['ts']
243 | frame_id = frame_info['id'] - first_frame_id
244 | current_depth_key = 'depth_' + str(frame_id).rjust(10, '0')
245 | current_mask_key = 'mask_' + str(frame_id).rjust(10, '0')
246 |
247 | current_depth = depth[current_depth_key].astype(np.float32) / 1000.0
248 | current_mask = mask[current_mask_key]
249 |
250 | # Get current poses for each object
251 | frame_time = frame_info['cam']['ts']
252 | frame_poses = {}
253 | for key in all_poses:
254 | if key != 'cam':
255 | frame_poses[key] = interpolate_pose(frame_time, all_poses[key])
256 |
257 | # Transform and project points
258 | for point_idx in range(num_points):
259 | object_id = initial_masks[point_idx] // 1000
260 | object_key = str(object_id)
261 |
262 | if object_key not in frame_poses or object_key not in initial_poses:
263 | continue
264 |
265 | # Get initial and current object poses
266 | T_c1o = initial_poses[object_key]
267 | T_c2o = frame_poses[object_key]
268 |
269 | if T_c1o is None or T_c2o is None:
270 | continue
271 |
272 | # Calculate relative transform from initial camera frame to current camera frame
273 | T_c2c1 = apply_transform(T_c2o, inv_transform(T_c1o))
274 |
275 | # Check velocity to detect potential tracking loss
276 | dt = frame_time - initial_time
277 | if dt > 0:
278 | v = np.linalg.norm(T_c2c1[1:4]) / dt
279 | R_matrix = Rotation.from_quat(T_c2c1[4:8]).as_matrix()
280 | w_matrix = logm(R_matrix) / dt
281 | w = np.array([-w_matrix[1,2], w_matrix[0, 2], -w_matrix[0, 1]])
282 | w_deg = np.linalg.norm(w) * 180 / np.pi
283 |
284 | if v > max_m_per_s or w_deg > max_norm_deg_per_s:
285 | continue
286 |
287 | # Transform point using relative transform
288 | point_3d = initial_XYZ[point_idx]
289 | R = Rotation.from_quat(T_c2c1[4:8]).as_matrix()
290 | t = T_c2c1[1:4]
291 | transformed_point = (R @ point_3d) + t
292 |
293 | # Project to image plane
294 | px, py = project_points_radtan(transformed_point[None, None],
295 | K[0, 0], K[1,1], K[0, 2], K[1, 2],
296 | *dist_coeffs)
297 | px = px[0, 0]
298 | py = py[0, 0]
299 |
300 | is_occluded = False
301 |
302 | # Check if point projects outside image bounds
303 | if px < 0 or px >= current_depth.shape[1] or py < 0 or py >= current_depth.shape[0]:
304 | is_occluded = True
305 | else:
306 | px_int, py_int = int(px), int(py)
307 |
308 | # Check if point is occluded using depth test
309 | actual_depth = current_depth[py_int, px_int]
310 | if actual_depth > 0 and abs(actual_depth - transformed_point[2]) > 0.03: # 3cm threshold
311 | is_occluded = True
312 |
313 | # Check if point projects onto the correct object
314 | if current_mask[py_int, px_int] != initial_masks[point_idx]:
315 | is_occluded = True
316 |
317 | tracks[frame_idx, point_idx] = [px, py]
318 | occlusions[frame_idx, point_idx] = is_occluded
319 |
320 | # Find the last valid timestamp if we found a gap
321 | if found_gap:
322 | valid_indices = ~np.isnan(times)
323 | last_valid_idx = np.where(valid_indices)[0][-1]
324 | else:
325 | last_valid_idx = num_frames - 1
326 |
327 | times = times[initial_frame_idx:last_valid_idx + 1]
328 | tracks = tracks[initial_frame_idx:last_valid_idx + 1]
329 | occlusions = occlusions[initial_frame_idx:last_valid_idx + 1]
330 |
331 | with h5py.File(h5_tracks_file_name, 'w') as f:
332 | f.create_dataset('tracks', data=tracks, compression='gzip', compression_opts=9)
333 | f.create_dataset('occlusions', data=occlusions, compression='gzip', compression_opts=9)
334 | f.create_dataset('initial_points', data=initial_points, compression='gzip', compression_opts=9)
335 | f.create_dataset('initial_masks', data=initial_masks, compression='gzip', compression_opts=9)
336 | f.create_dataset('times', data=times, compression='gzip', compression_opts=9)
337 |
338 | def process_with_error_handling(args):
339 | try:
340 | convert(*args)
341 | return True
342 | except Exception as e:
343 | print(f"Error processing {args[0]}: {str(e)}")
344 | return False
345 |
346 | if __name__ == '__main__':
347 | parser = argparse.ArgumentParser(epilog='Calculates ground truth point tracks for EVIMO datasets.')
348 | parser.add_argument('--dt', dest='dt', type=float, default=0.01,
349 | help='dt for flow approximation')
350 | parser.add_argument('--overwrite', dest='overwrite', action='store_true',
351 | help='Overwrite existing output files')
352 | parser.add_argument('--max_m_per_s', dest='max_m_per_s', type=float, default=9.0,
353 | help='Maximum meters per second of linear velocity')
354 | parser.add_argument('--max_norm_deg_per_s', dest='max_norm_deg_per_s', type=float, default=6.25*360,
355 | help='Maximum normed degrees per second of angular velocity')
356 | parser.add_argument('--data_root', type=str, required=True,
357 | help='Root directory for dataset paths')
358 | parser.add_argument('--config', type=str, default='config/misc/evimo2/val_samples.csv',
359 | help='CSV file listing the dataset sequences to process')
360 | parser.add_argument('--split', type=str, default='val_samples',
361 | help='Dataset split to process')
362 |
363 | args = parser.parse_args()
364 | files = list(pd.read_csv(args.config)['name'])
365 |
366 | p_args_list = [[
367 | Path(args.data_root) / f,
368 | args.overwrite,
369 | args.max_m_per_s,
370 | args.max_norm_deg_per_s
371 | ] for f in files]
372 |
373 | with Pool(multiprocessing.cpu_count()) as p:
374 | results = list(tqdm(
375 | p.imap_unordered(process_with_error_handling, p_args_list),
376 | total=len(p_args_list),
377 | desc='Sequences'
378 | ))
379 |
380 | successes = sum(results)
381 | print(f"Completed {successes}/{len(p_args_list)} sequences")
--------------------------------------------------------------------------------
/scripts/demo.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from pathlib import Path
3 |
4 | from tqdm import tqdm
5 | import numpy as np
6 | import torch
7 |
8 | sys.path.append(str(Path(__file__).parent.parent))
9 | from src.representations import MixedDensityEventStack
10 | from src.model.etap.model import Etap
11 | from src.utils import Visualizer
12 |
13 | def normalize_and_expand_channels(image):
14 | if not isinstance(image, torch.Tensor):
15 | raise TypeError("Input must be a torch tensor")
16 |
17 | *batch_dims, height, width = image.shape
18 |
19 | if len(image.shape) < 2:
20 | raise ValueError("Input tensor must have at least shape (..., height, width)")
21 |
22 | image_flat = image.view(-1, height, width)
23 |
24 | min_val = image_flat.min()
25 | max_val = image_flat.max()
26 | range_val = max_val - min_val
27 | range_val[range_val == 0] = 1
28 |
29 | image_normalized = (image_flat - min_val) / range_val * 255
30 | image_rgb_flat = image_normalized.to(torch.uint8).unsqueeze(1).repeat(1, 3, 1, 1)
31 | image_rgb = image_rgb_flat.view(*batch_dims, 3, height, width)
32 |
33 | return image_rgb
34 |
35 | if __name__ == '__main__':
36 | device = 'cpu'
37 | data_dir = Path('data/demo_example')
38 | ckpt_path = Path('weights/ETAP_v1_cvpr25.pth')
39 | output_dir = Path(f'output/{data_dir.name}')
40 | num_bins = 10
41 | num_events = 60000
42 | height, width = 480, 640
43 | t_start = 1/30 # seconds
44 | t_end = 1.5 # seconds
45 | t_delta = 1/60 # seconds
46 |
47 | # Object to convert raw event data into grid representations
48 | converter = MixedDensityEventStack(
49 | image_shape=(height, width),
50 | num_stacks=num_bins,
51 | interpolation='bilinear',
52 | channel_overlap=True,
53 | centered_channels=False
54 | )
55 |
56 | # Load the model
57 | tracker = Etap(num_in_channels=num_bins, stride=4, window_len=8)
58 | weights = torch.load(ckpt_path, map_location='cpu', weights_only=True)
59 | tracker.load_state_dict(weights)
60 | tracker = tracker.to(device)
61 | tracker.eval()
62 |
63 | # Let's choose some timestamps at which we create the frame representations
64 | tracking_timestamps = np.arange(t_start, t_end, t_delta) # seconds
65 |
66 | # Load raw event data
67 | xy = np.load(data_dir / 'dataset_events_xy.npy', mmap_mode='r')
68 | p = np.load(data_dir / 'dataset_events_p.npy', mmap_mode='r')
69 | t = np.load(data_dir / 'dataset_events_t.npy', mmap_mode='r')
70 |
71 | assert t_start > t[0], "Start time must be greater than the first event timestamp"
72 | assert t_end < t[-1], "End time must be less than the last event timestamp"
73 | assert t_delta > 0, "Time delta must be greater than zero"
74 | assert t_start < t_end, "Start time must be less than end time"
75 | assert xy.shape[0] == p.shape[0] == t.shape[0], "Event data arrays must have the same length"
76 |
77 | event_indices = np.searchsorted(t, tracking_timestamps)
78 | event_representations = []
79 |
80 | # At each tracking timestep, we take the last num_events events and convert
81 | # them into a grid representation.
82 | for i_end in tqdm(event_indices, desc='Creating grid representations'):
83 | i_start = max(i_end - num_events, 0)
84 |
85 | events = np.stack([xy[i_start:i_end, 1],
86 | xy[i_start:i_end, 0],
87 | t[i_start:i_end],
88 | p[i_start:i_end]], axis=1)
89 | ev_repr = converter(events)
90 | event_representations.append(ev_repr)
91 |
92 | voxels = np.stack(event_representations, axis=0)
93 | voxels = torch.from_numpy(voxels)[None].float().to(device)
94 |
95 | # Now let's determine some queries, meaning the initial positions
96 | # of the points to track.
97 | x = np.arange(0, width, 32)
98 | y = np.arange(0, height, 32)
99 | X, Y = np.meshgrid(x, y)
100 | grid = np.stack([X.flatten(), Y.flatten()], axis=1)
101 | query_xy = torch.from_numpy(grid).float().to(device)
102 | num_queries = query_xy.shape[0]
103 | # We track all points from the beginning (t=0)
104 | query_t = torch.zeros(num_queries, dtype=torch.int64, device=device)
105 | queries = torch.cat([query_t[:, None], query_xy], dim=1)
106 | queries = queries[None].to(device)
107 |
108 | with torch.no_grad():
109 | result = tracker(voxels, queries, iters=6)
110 | predictions, visibility = result['coords_predicted'], result['vis_predicted']
111 |
112 | visibility = visibility > 0.8
113 |
114 | # Visualization
115 | projection = voxels.sum(2).cpu()
116 | lower_bound = torch.tensor(np.percentile(projection.cpu().numpy(), 2))
117 | upper_bound = torch.tensor(np.percentile(projection.cpu().numpy(), 98))
118 | projection_clipped = torch.clamp(projection, lower_bound, upper_bound)
119 | video = normalize_and_expand_channels(projection_clipped)
120 |
121 | viz = Visualizer(save_dir=output_dir, tracks_leave_trace=-1)
122 | viz.visualize(
123 | video=video,
124 | tracks=predictions,
125 | visibility=visibility,
126 | filename=f"pred",
127 | )
128 | print('Done.')
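129 |
130 | # Hedged note on running the demo: it expects the raw event arrays under
131 | # data/demo_example/ (dataset_events_xy.npy, dataset_events_p.npy,
132 | # dataset_events_t.npy, with timestamps in seconds) and the checkpoint at
133 | # weights/ETAP_v1_cvpr25.pth. Typical invocation from the repository root:
134 | #   python scripts/demo.py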
--------------------------------------------------------------------------------
/scripts/download_eds.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Script to download and extract EDS dataset sequences
4 | # Usage: ./download_eds_data.sh [output_directory]
5 |
6 | set -e # Exit on error
7 |
8 | OUTPUT_DIR="$(pwd)/data/eds"
9 | if [ "$1" != "" ]; then
10 | OUTPUT_DIR="$1"
11 | fi
12 |
13 | mkdir -p "$OUTPUT_DIR"
14 | echo "Downloading datasets to: $OUTPUT_DIR"
15 |
16 | BASE_URL="https://download.ifi.uzh.ch/rpg/eds/dataset"
17 |
18 | SEQUENCES=(
19 | "01_peanuts_light"
20 | "02_rocket_earth_light"
21 | "08_peanuts_running"
22 | "14_ziggy_in_the_arena"
23 | )
24 |
25 | for seq in "${SEQUENCES[@]}"; do
26 | echo "======================================================="
27 | echo "Processing sequence: $seq"
28 |
29 | mkdir -p "$OUTPUT_DIR/$seq"
30 |
31 | URL="$BASE_URL/$seq/$seq.tgz"
32 | TGZ_FILE="$OUTPUT_DIR/$seq/$seq.tgz"
33 |
34 | echo "Downloading from: $URL"
35 | wget -c "$URL" -O "$TGZ_FILE"
36 |
37 | echo "Extracting..."
38 | tar -xzf "$TGZ_FILE" -C "$OUTPUT_DIR/$seq"
39 |
40 | echo "Removing archive..."
41 | rm "$TGZ_FILE"
42 |
43 | echo "Done with $seq"
44 | echo ""
45 | done
46 |
47 | echo "All sequences have been downloaded and extracted."
48 | echo "Data is available in: $OUTPUT_DIR"
49 |
50 | echo "======================================================="
51 | echo "Summary of downloaded data:"
52 | for seq in "${SEQUENCES[@]}"; do
53 | echo "$seq:"
54 | ls -la "$OUTPUT_DIR/$seq" | grep -v "^total"
55 | echo ""
56 | done
--------------------------------------------------------------------------------
/scripts/inference_offline.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | import sys
4 |
5 | import yaml
6 | from tqdm import tqdm
7 | import torch
8 | import numpy as np
9 | import pandas as pd
10 |
11 | sys.path.append(str(Path(__file__).parent.parent))
12 | from src.data.modules import DataModuleFactory
13 | from src.model.etap.model import Etap
14 | from src.utils import Visualizer, compute_tapvid_metrics
15 |
16 |
17 | def normalize_and_expand_channels(image):
18 | if not isinstance(image, torch.Tensor):
19 | raise TypeError("Input must be a torch tensor")
20 |
21 | *batch_dims, height, width = image.shape
22 |
23 | if len(image.shape) < 2:
24 | raise ValueError("Input tensor must have at least shape (..., height, width)")
25 |
26 | image_flat = image.view(-1, height, width)
27 |
28 | min_val = image_flat.min()
29 | max_val = image_flat.max()
30 | range_val = max_val - min_val
31 | range_val[range_val == 0] = 1
32 |
33 | image_normalized = (image_flat - min_val) / range_val * 255
34 | image_rgb_flat = image_normalized.to(torch.uint8).unsqueeze(1).repeat(1, 3, 1, 1)
35 | image_rgb = image_rgb_flat.view(*batch_dims, 3, height, width)
36 |
37 | return image_rgb
38 |
39 | def main():
40 | parser = argparse.ArgumentParser()
41 | parser.add_argument('--config', type=str, default='config/exe/inference_offline/event_kubric.yaml')
42 | parser.add_argument('--device', type=str, default='cpu')
43 | args = parser.parse_args()
44 |
45 | # Load and process config
46 | with open(args.config, "r") as f:
47 | config = yaml.safe_load(f)
48 |
49 | output_dir = config['common'].get('output_dir', 'output/inference')
50 | output_dir = Path(output_dir) / config['common']['exp_name']
51 | output_dir.mkdir(parents=True, exist_ok=True)
52 | checkpoint_path = config['common'].get('checkpoint')
53 |
54 | data_module = DataModuleFactory.create(config['data'])
55 | data_module.prepare_data()
56 | test_set = data_module.test_dataset
57 |
58 | model_config = config['model']
59 | model_config['model_resolution'] = (512, 512)
60 | tracker = Etap(**model_config)
61 | weights = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
62 | tracker.load_state_dict(weights)
63 | tracker = tracker.to(args.device)
64 | tracker.eval()
65 |
66 | viz = Visualizer(save_dir=output_dir, tracks_leave_trace=-1)
67 |
68 | seq_names = []
69 | sequence_metrics = []
70 |
71 | for i, sample in tqdm(enumerate(test_set), total=len(test_set), desc="Testing"):
72 | voxels = sample.voxels[None].to(args.device)
73 | trajs_g = sample.trajectory[None].to(args.device)
74 | vis_g = sample.visibility[None].to(args.device)
75 | queries = sample.query_points[None].to(args.device)
76 | B, T, C, H, W = voxels.shape
77 | _, _, N, D = trajs_g.shape
78 |
79 | with torch.no_grad():
80 | result = tracker(voxels, queries, iters=6)
81 | predictions, visibility = result['coords_predicted'], result['vis_predicted']
82 |
83 | if visibility.dtype != torch.bool:
84 | visibility = visibility > 0.8
85 |
86 | # Visualization
87 | projection = voxels.sum(2).cpu()
88 | lower_bound = torch.tensor(np.percentile(projection.cpu().numpy(), 2))
89 | upper_bound = torch.tensor(np.percentile(projection.cpu().numpy(), 98))
90 | projection_clipped = torch.clamp(projection, lower_bound, upper_bound)
91 | video = normalize_and_expand_channels(projection_clipped)
92 |
93 | viz.visualize(
94 | video=video,
95 | tracks=predictions,
96 | visibility=visibility,
97 | filename=f"pred_{sample.seq_name}",
98 | )
99 | viz.visualize(
100 | video=video,
101 | tracks=trajs_g,
102 | visibility=vis_g,
103 | filename=f"gt_{sample.seq_name}",
104 | )
105 |
106 | # Calculate metrics for this sequence
107 | queries_np = queries.cpu().numpy()
108 | trajs_g_np = trajs_g.cpu().numpy().transpose(0, 2, 1, 3)
109 | gt_occluded_np = ~vis_g.cpu().numpy().transpose(0, 2, 1)
110 | trajs_pred_np = predictions.cpu().numpy().transpose(0, 2, 1, 3)
111 | pred_occluded_np = ~visibility.cpu().numpy().transpose(0, 2, 1)
112 |
113 | seq_metrics = compute_tapvid_metrics(
114 | query_points=queries_np,
115 | gt_occluded=gt_occluded_np,
116 | gt_tracks=trajs_g_np,
117 | pred_occluded=pred_occluded_np,
118 | pred_tracks=trajs_pred_np,
119 | query_mode="first"
120 | )
121 |
122 | seq_names.append(sample.seq_name)
123 | sequence_metrics.append(seq_metrics)
124 |
125 | data = []
126 | for seq_name, seq_metrics in zip(seq_names, sequence_metrics):
127 | row = {'sequence': seq_name}
128 | row.update(seq_metrics)
129 | data.append(row)
130 |
131 | avg_metrics = {}
132 | metric_keys = sequence_metrics[0].keys()
133 | for key in metric_keys:
134 | values = [metrics[key] for metrics in sequence_metrics]
135 | avg_metrics[key] = np.mean(values)
136 |
137 | avg_row = {'sequence': 'average'}
138 | avg_row.update(avg_metrics)
139 | data.append(avg_row)
140 |
141 | df = pd.DataFrame(data)
142 | csv_path = output_dir / 'sequence_metrics.csv'
143 | df.to_csv(csv_path, index=False)
144 |
145 | print(f'Metrics saved to: {csv_path}')
146 | print('Done.')
147 |
148 | if __name__ == '__main__':
149 | main()
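150 |
151 | # Hedged usage sketch (the config below is one of the provided offline configs):
152 | #   python scripts/inference_offline.py \
153 | #       --config config/exe/inference_offline/event_kubric.yaml --device cuda:0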
--------------------------------------------------------------------------------
/scripts/inference_online.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import subprocess
3 | from pathlib import Path
4 | import sys
5 | import numpy as np
6 |
7 | import yaml
8 | import torch
9 | from tqdm import tqdm
10 |
11 | sys.path.append(str(Path(__file__).parent.parent))
12 |
13 | from src.data.modules import DataModuleFactory
14 | from src.model.etap.model import Etap
15 | from src.utils import Visualizer, normalize_and_expand_channels, make_grid
16 |
17 | torch.set_float32_matmul_precision('high')
18 |
19 |
20 | def write_points_to_file(points, timestamps, filepath):
21 | """Write tracking points to a file."""
22 | T, N, _ = points.shape
23 |
24 | with open(filepath, 'w') as f:
25 | for t in range(T):
26 | for n in range(N):
27 | x, y = points[t, n]
28 | f.write(f"{n} {timestamps[t]:.9f} {x:.9f} {y:.9f}\n")
29 |
30 |
31 | def get_git_commit_hash():
32 | """Get the current git commit hash."""
33 | try:
34 | commit_hash = subprocess.check_output(['git', 'rev-parse', 'HEAD'], stderr=subprocess.STDOUT).decode().strip()
35 | return commit_hash
36 | except subprocess.CalledProcessError as e:
37 | print(f"Error obtaining git commit hash: {e.output.decode().strip()}")
38 | return "unknown"
39 |
40 |
41 | def normalize_voxels(voxels):
42 | """Perform channelwise std-mean normalization on voxels."""
43 | mask = voxels != 0
44 | mean = voxels.sum(dim=(0, 2, 3), keepdim=True) / mask.sum(dim=(0, 2, 3), keepdim=True)
45 | var = ((voxels - mean)**2 * mask).sum(dim=(0, 2, 3), keepdim=True) / mask.sum(dim=(0, 2, 3), keepdim=True)
46 | std = torch.sqrt(var + 1e-8)
47 | return torch.where(mask, (voxels - mean) / std, voxels)
48 |
49 |
50 | def main():
51 | parser = argparse.ArgumentParser()
52 | parser.add_argument('--config', type=str, default='config/exe/inference_online/feature_tracking.yaml')
53 | parser.add_argument('--device', type=str, default='cuda:0')
54 | args = parser.parse_args()
55 |
56 | config_path = Path(args.config)
57 | with open(config_path, "r") as f:
58 | config = yaml.safe_load(f)
59 |
60 | project_root = Path(__file__).parent.parent
61 | save_dir = project_root / 'output' / 'inference' / config['common']['exp_name']
62 | save_dir.mkdir(parents=True, exist_ok=True)
63 |
64 | config['runtime_info'] = {
65 | 'command': ' '.join(sys.argv),
66 | 'git_commit': get_git_commit_hash()
67 | }
68 | config_save_path = save_dir / 'config.yaml'
69 | with open(config_save_path, 'w') as f:
70 | yaml.dump(config, f, default_flow_style=False)
71 |
72 | add_support_points = config['common'].get('add_support_points', False)
73 |
74 | if add_support_points:
75 | support_point_stride = config['common'].get('support_point_stride', 20)
76 | height, width = config['common']['height'], config['common']['width']
77 |
78 | device = torch.device(args.device)
79 | data_module = DataModuleFactory.create(config['data'])
80 | data_module.prepare_data()
81 |
82 | model_config = config['model']
83 | model_config['model_resolution'] = (512, 512)
84 | checkpoint_path = Path(config['common']['ckp_path'])
85 |
86 | tracker = Etap(**model_config)
87 | weights = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
88 | tracker.load_state_dict(weights)
89 | tracker = tracker.to(device)
90 | tracker.eval()
91 |
92 | viz = Visualizer(save_dir=save_dir, pad_value=0, linewidth=1,
93 | tracks_leave_trace=-1, show_first_frame=5)
94 |
95 | for dataset in data_module.test_datasets:
96 | sequence_name = dataset.subsequence_name
97 | tracker.init_video_online_processing()
98 | timestamps_s = None
99 |
100 | original_queries = dataset.query_points.to(device)
101 |
102 | if add_support_points:
103 | support_query_xy = torch.from_numpy(make_grid(height, width, stride=support_point_stride)).float().to(device)
104 | support_num_queries = support_query_xy.shape[0]
105 |
106 | support_query_t = torch.zeros(support_num_queries, dtype=torch.int64, device=device)
107 | support_queries = torch.cat([support_query_t[:, None], support_query_xy], dim=1)
108 |
109 | queries = torch.cat([original_queries, support_queries])
110 | print(f"Added {support_num_queries} support points to {original_queries.shape[0]} original queries")
111 | else:
112 | queries = original_queries
113 | support_num_queries = 0
114 |
115 | event_visus = None
116 |
117 | for sample, start_idx in tqdm(dataset, desc=f'Predicting {sequence_name}'):
118 | assert start_idx == tracker.online_ind
119 | voxels = sample.voxels.to(device)
120 | step = voxels.shape[0] // 2
121 |
122 | if timestamps_s is None:
123 | timestamps_s = sample.timestamps
124 | else:
125 | timestamps_s = torch.cat([timestamps_s, sample.timestamps[-step:]])
126 |
127 | voxels = normalize_voxels(voxels)
128 |
129 | with torch.no_grad():
130 | results = tracker(
131 | video=voxels[None],
132 | queries=queries[None],
133 | is_online=True,
134 | iters=6
135 | )
136 |
137 | coords_predicted = results['coords_predicted'].clone()
138 | vis_logits = results['vis_predicted']
139 |
140 | # Remove support points
141 | if support_num_queries > 0:
142 | coords_predicted = coords_predicted[:, :, :-support_num_queries]
143 | vis_logits = vis_logits[:, :, :-support_num_queries]
144 |
145 | event_visu = normalize_and_expand_channels(voxels.sum(dim=1))
146 | event_visus = torch.cat([event_visus, event_visu[-step:]]) if event_visus is not None else event_visu
147 |
148 | # Save predictions
149 | output_file = save_dir / f'{sequence_name}.npz'
150 | np.savez(
151 | output_file,
152 | coords_predicted=coords_predicted[0].cpu().numpy(),
153 | vis_logits=vis_logits[0].cpu().numpy(),
154 | timestamps=timestamps_s.cpu().numpy()
155 | )
156 |
157 | # Save predictions in feature tracking format
158 | output_file = save_dir / f'{sequence_name}.txt'
159 | write_points_to_file(
160 | coords_predicted.cpu().numpy()[0],
161 | timestamps_s,
162 | output_file
163 | )
164 |
165 | viz.visualize(
166 | event_visus[None],
167 | coords_predicted,
168 | filename=sequence_name
169 | )
170 |
171 | print('Done.')
172 |
173 | if __name__ == '__main__':
174 | main()
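175 |
176 | # Illustrative sketch of normalize_voxels on a toy tensor (assumed values, not
177 | # used above): only nonzero bins are standardized per channel; zeros stay zero.
178 | #   v = torch.zeros(2, 3, 4, 4)  # [T, C, H, W]
179 | #   v[0, 0, 0, 0], v[1, 0, 1, 1] = 1.0, 3.0
180 | #   out = normalize_voxels(v)
181 | #   # channel 0 has mean 2.0 and std 1.0 over its two nonzero entries, so they
182 | #   # become -1.0 and 1.0; all other entries remain 0.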
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tub-rip/ETAP/e13875244b45e651c1ebcce1e769ed05ffbf7acd/src/__init__.py
--------------------------------------------------------------------------------
/src/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tub-rip/ETAP/e13875244b45e651c1ebcce1e769ed05ffbf7acd/src/data/__init__.py
--------------------------------------------------------------------------------
/src/data/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .feature_tracking_online import FeatureTrackingDataModule
2 | from .e2d2 import E2d2DataModule
3 | from .penn_aviary import PennAviaryDataModule
4 | from .event_kubric import EventKubricDataModule
5 | from .evimo2 import Evimo2DataModule
6 |
7 | class DataModuleFactory:
8 | @staticmethod
9 | def create(data_config):
10 | dataset_name = data_config['dataset_name']
11 |
12 | if dataset_name == 'feature_tracking_online':
13 | return FeatureTrackingDataModule(**data_config)
14 | elif dataset_name == 'e2d2':
15 | return E2d2DataModule(**data_config)
16 | elif dataset_name == 'penn_aviary':
17 | return PennAviaryDataModule(**data_config)
18 | elif dataset_name == 'event_kubric':
19 | return EventKubricDataModule(**data_config)
20 | elif dataset_name == 'evimo2':
21 | return Evimo2DataModule(**data_config)
22 | else:
23 | raise ValueError("Unsupported dataset_name.")
24 |
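25 | # Hedged usage sketch: the factory is driven by the 'data' section of the exe
26 | # configs; 'my_repr' below is a placeholder preprocessed_name.
27 | #   dm = DataModuleFactory.create({'dataset_name': 'e2d2',
28 | #                                  'data_root': 'data/e2d2',
29 | #                                  'preprocessed_name': 'my_repr'})
30 | #   dm.prepare_data()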
--------------------------------------------------------------------------------
/src/data/modules/e2d2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from pathlib import Path
4 | import pytorch_lightning as pl
5 |
6 | from ..utils import EtapData
7 |
8 |
9 | class E2d2DataModule(pl.LightningDataModule):
10 | def __init__(self, dataset_name, data_root, preprocessed_name, sequences=None):
11 | super().__init__()
12 | self.save_hyperparameters()
13 | self.test_datasets = []
14 |
15 | def prepare_data(self):
16 | """Prepare test datasets using preprocessed event representations."""
17 | if self.hparams.sequences is None:
18 | sequences = [d for d in Path(self.hparams.data_root).iterdir() if d.is_dir()]
19 | else:
20 | sequences = [Path(self.hparams.data_root) / seq for seq in self.hparams.sequences]
21 |
22 | for sequence_path in sequences:
23 | if not (sequence_path / 'seq.h5').exists():
24 | print(f"Warning: seq.h5 not found in {sequence_path}, skipping...")
25 | continue
26 |
27 | self.test_datasets.append(E2D2InferenceDataset(
28 | sequence_path=sequence_path,
29 | preprocessed_name=self.hparams.preprocessed_name,
30 | stride=4,
31 | sliding_window_len=8
32 | ))
33 |
34 |
35 | class E2D2InferenceDataset(torch.utils.data.Dataset):
36 | def __init__(
37 | self,
38 | sequence_path,
39 | preprocessed_name,
40 | stride=4,
41 | sliding_window_len=8
42 | ):
43 | super().__init__()
44 | self.sequence_path = Path(sequence_path)
45 | self.subsequence_name = self.sequence_path.name
46 | self.stride = stride
47 | self.sliding_window_len = sliding_window_len
48 |
49 | self.ev_repr_dir = self.sequence_path / 'event_representations' / preprocessed_name
50 |
51 | if not self.ev_repr_dir.exists():
52 | raise FileNotFoundError(f"Preprocessed event representations not found in {self.sequence_path}")
53 |
54 | print(f"Using preprocessed event representations from {self.ev_repr_dir}")
55 |
56 | self.ev_repr_paths = sorted(path for path in self.ev_repr_dir.iterdir()
57 | if path.is_file() and path.suffix == '.npy')
58 |
59 | if not self.ev_repr_paths:
60 | raise FileNotFoundError(f"No .npy files found in {self.ev_repr_dir}")
61 |
62 | self.timestamps = np.array([int(path.stem) for path in self.ev_repr_paths])
63 |
64 | gt_path = self.sequence_path / 'gt_positions.npy'
65 | if gt_path.exists():
66 | self.gt_tracks = torch.from_numpy(np.load(gt_path))
67 | else:
68 | self.gt_tracks = None
69 |
70 | # Load query points from file
71 | query_xy = torch.from_numpy(np.load(self.sequence_path / 'queries.npy'))
72 | query_t = torch.zeros(query_xy.shape[0], dtype=torch.int64)
73 | self.query_points = torch.cat([query_t[:, None], query_xy], dim=1).float()
74 |
75 | self.start_indices = np.arange(0, len(self.timestamps) - self.sliding_window_len + 1, self.stride)
76 |
77 | def load_event_representations(self, start_idx, end_idx):
78 | """Load event representations for the given index range."""
79 | ev_repr = []
80 |
81 | for i in range(start_idx, end_idx):
82 | ev_repr_path = self.ev_repr_paths[i]
83 | sample = np.load(ev_repr_path)
84 | ev_repr.append(sample)
85 |
86 | ev_repr = torch.from_numpy(np.stack(ev_repr, axis=0)).float()
87 | return ev_repr
88 |
89 | def __len__(self):
90 | """Return the number of start indices."""
91 | return len(self.start_indices)
92 |
93 | def __getitem__(self, idx):
94 | """Get a data sample for the given index."""
95 | start_idx = self.start_indices[idx]
96 | end_idx = start_idx + self.sliding_window_len
97 |
98 | ev_repr = self.load_event_representations(start_idx, end_idx)
99 | gt_tracks = self.gt_tracks[start_idx:end_idx] if self.gt_tracks is not None else None
100 | visibility = torch.ones((self.sliding_window_len, gt_tracks.shape[1])) if gt_tracks is not None else None
101 |
102 | sample = EtapData(
103 | voxels=ev_repr, # [T, C, H, W]
104 | rgbs=None,
105 | trajectory=gt_tracks, # [T, N, 2] or None
106 | visibility=visibility, # [T, N] or None
107 | timestamps=torch.from_numpy(self.timestamps[start_idx:end_idx]).float() / 1e6, # [T]
108 | )
109 | return sample, start_idx
--------------------------------------------------------------------------------
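The `start_indices` arithmetic in E2D2InferenceDataset tiles the sequence with overlapping windows of length `sliding_window_len`, advancing by `stride` frames each step. A small numeric check (values assumed for illustration):

    import numpy as np

    num_frames, window, stride = 20, 8, 4
    start_indices = np.arange(0, num_frames - window + 1, stride)
    windows = [(s, s + window) for s in start_indices]
    print(start_indices)   # [ 0  4  8 12]
    print(windows)         # [(0, 8), (4, 12), (8, 16), (12, 20)] -> neighbours overlap by window - stride frames
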
/src/data/modules/event_kubric.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import h5py
4 | from pathlib import Path
5 | import pytorch_lightning as pl
6 | from ..utils import EtapData
7 |
8 | def calculate_frame_times(num_frames=96, total_time=2.0):
9 | time_step = total_time / (num_frames - 1)
10 | frame_times = np.arange(num_frames) * time_step
11 | return frame_times
12 |
13 | class EventKubricDataModule(pl.LightningDataModule):
14 | def __init__(self, data_root, seq_len, traj_per_sample,
15 | dataset_name, preprocessed_name=None):
16 | super().__init__()
17 | self.save_hyperparameters()
18 |
19 | def prepare_data(self):
20 | test_path = Path(self.hparams.data_root) / 'test'
21 | self.test_dataset = EventKubricDataset(
22 | data_root=test_path,
23 | seq_len=self.hparams.seq_len,
24 | traj_per_sample=self.hparams.traj_per_sample,
25 | preprocessed_name=self.hparams.preprocessed_name,
26 | )
27 |
28 | class EventKubricDataset(torch.utils.data.Dataset):
29 | def __init__(
30 | self,
31 | data_root,
32 | seq_len=24,
33 | traj_per_sample=768,
34 | sample_vis_1st_frame=False,
35 | preprocessed_name=None,
36 | ):
37 | super(EventKubricDataset, self).__init__()
38 | self.data_root = data_root
39 | self.traj_per_sample = traj_per_sample
40 | self.sample_vis_1st_frame = sample_vis_1st_frame
41 | self.seq_len = seq_len
42 |
43 | self.pad_bounds = [0, 25]
44 | self.resize_lim = [0.75, 1.25]
45 | self.resize_delta = 0.05
46 | self.max_crop_offset = 15
47 | self.preprocessed_name = preprocessed_name
48 | self.indices = np.arange(3, 93, 3)[3:-3]
49 |
50 | data_root_path = Path(self.data_root)
51 | self.samples = [d for d in data_root_path.iterdir() if d.is_dir()]
52 |
53 | # Validate samples
54 | valid_samples = []
55 | for sample_path in self.samples:
56 | gt_path = sample_path / 'annotations.npy'
57 | if not gt_path.exists():
58 | continue
59 | gt_data = np.load(str(gt_path), allow_pickle=True).item()
60 | visibility = gt_data['occluded']
61 | if len(visibility) > self.traj_per_sample:
62 | valid_samples.append(sample_path.name)
63 |
64 | self.samples = valid_samples
65 | print(f"Found {len(self.samples)} valid samples")
66 |
67 | def rgba_to_rgb(self, rgba):
68 | if rgba.shape[-1] == 3:
69 | return rgba
70 |
71 | rgb = rgba[..., :3]
72 | alpha = rgba[..., 3:4]
73 | alpha = alpha.astype(np.float32) / 255.0
74 | rgb = rgb.astype(np.float32) * alpha + 255.0 * (1.0 - alpha)
75 |
76 | return rgb.astype(np.uint8)
77 |
78 | def load_rgb_frames(self, seq_path):
79 | seq_path = Path(seq_path)
80 | h5_path = seq_path / 'data.hdf5'
81 | if h5_path.exists():
82 | with h5py.File(str(h5_path), 'r') as f:
83 | if 'rgba' in f:
84 | rgba = f['rgba'][self.indices]
85 | rgb = self.rgba_to_rgb(rgba)
86 | # Convert to (N, C, H, W) for PyTorch and correct channel order
87 | rgb = torch.from_numpy(rgb).permute(0, 3, 1, 2).float()
88 | rgb = rgb.flip(1) # Flip the channels to correct the order (BGR -> RGB)
89 | return rgb
90 | else:
91 | raise KeyError("'rgba' dataset not found in data.hdf5")
92 | raise FileNotFoundError(f"Could not find data.hdf5 in {seq_path}")
93 |
94 | def load_ground_truth(self, seq_path, seq_name):
95 | """
96 | Load ground truth trajectory data from annotations file.
97 |
98 | Args:
99 | seq_path: Path to sequence directory
100 | seq_name: Name of the sequence
101 |
102 | Returns:
103 | tuple: (trajectory data, visibility data)
104 | """
105 | data_root_path = Path(self.data_root)
106 | gt_path = data_root_path / seq_name / 'annotations.npy'
107 | gt_data = np.load(str(gt_path), allow_pickle=True).item()
108 |
109 | # Extract and process trajectory data
110 | traj_2d = gt_data['target_points']
111 |         visibility = gt_data['occluded']  # inverted below so that 1 means visible
112 |
113 | traj_2d = np.transpose(traj_2d, (1, 0, 2))[self.indices]
114 | visibility = np.transpose(np.logical_not(visibility), (1, 0))[self.indices]
115 |
116 | visibility = torch.from_numpy(visibility)
117 | traj_2d = torch.from_numpy(traj_2d)
118 |
119 | visibile_pts_first_frame_inds = (visibility[0]).nonzero(as_tuple=False)[:, 0]
120 |
121 | if self.sample_vis_1st_frame:
122 | visibile_pts_inds = visibile_pts_first_frame_inds
123 | else:
124 | visibile_pts_mid_frame_inds = (visibility[self.seq_len // 2]).nonzero(as_tuple=False)[:, 0]
125 | visibile_pts_inds = torch.cat(
126 | (visibile_pts_first_frame_inds, visibile_pts_mid_frame_inds), dim=0
127 | )
128 |
129 | return traj_2d, visibility, visibile_pts_inds
130 |
131 | def load_preprocessed_representations(self, seq_path):
132 | seq_path = Path(seq_path)
133 | ev_path = seq_path / 'event_representations' / f'{self.preprocessed_name}.h5'
134 |
135 | with h5py.File(str(ev_path), 'r') as f:
136 | ev_repr = f['representations'][:]
137 | return torch.from_numpy(ev_repr).float()
138 |
139 | def normalize_representation(self, repr):
140 | mask = repr != 0
141 | mean = repr.sum(dim=(0, 2, 3), keepdim=True) / mask.sum(dim=(0, 2, 3), keepdim=True)
142 | var = ((repr - mean)**2 * mask).sum(dim=(0, 2, 3), keepdim=True) / mask.sum(dim=(0, 2, 3), keepdim=True)
143 | std = torch.sqrt(var + 1e-8)
144 | return torch.where(mask, (repr - mean) / std, repr)
145 |
146 | def __getitem__(self, index):
147 | seq_name = self.samples[index]
148 | seq_path = Path(self.data_root) / seq_name
149 |
150 | traj_2d, visibility, visibile_pts_inds = self.load_ground_truth(seq_path, seq_name)
151 |
152 | # Select points
153 | point_inds = torch.arange(min(len(visibile_pts_inds), self.traj_per_sample))
154 | visible_inds_sampled = visibile_pts_inds[point_inds]
155 |
156 | trajs = traj_2d[:, visible_inds_sampled].float()
157 | visibles = visibility[:, visible_inds_sampled]
158 | valids = torch.ones((self.seq_len, self.traj_per_sample))
159 |
160 | rgbs = self.load_rgb_frames(seq_path)
161 | ev_repr = self.load_preprocessed_representations(seq_path)
162 |
163 | # Channelwise std-mean normalization
164 | ev_repr = self.normalize_representation(ev_repr)
165 |
166 | _, first_positive_inds = torch.max(visibles, dim=0) # Find first frame where each point is visible
167 | num_points = visibles.shape[1]
168 | query_coords = torch.zeros((num_points, 2), dtype=trajs.dtype)
169 |
170 | for p in range(num_points):
171 | first_frame = first_positive_inds[p].item()
172 | query_coords[p] = trajs[first_frame, p, :2]
173 |
174 | queries = torch.cat([first_positive_inds.unsqueeze(1), query_coords], dim=1)
175 |
176 | sample = EtapData(
177 | voxels=ev_repr,
178 | rgbs=rgbs,
179 | trajectory=trajs,
180 | visibility=visibles,
181 | valid=valids,
182 | query_points=queries,
183 | seq_name=seq_name,
184 | )
185 | return sample
186 |
187 | def __len__(self):
188 | return len(self.samples)
189 |
190 | def get_full_sample(self, index):
191 | '''Helper function to retrieve full sample data'''
192 | sample = {}
193 | seq_name = self.samples[index]
194 | seq_path = Path(self.data_root) / seq_name
195 | h5_path = seq_path / 'data.hdf5'
196 |
197 | # Get Kubric data and ground truth
198 | with h5py.File(str(h5_path), 'r') as f:
199 | indices = f['subsample_indices'][:]
200 | sample['rgba'] = f['rgba'][indices]
201 | sample['depths'] = f['depth'][:]
202 | sample['forward_flow'] = f['forward_flow'][:]
203 | sample['normal'] = f['normal'][:]
204 | sample['object_coordinates'] = f['object_coordinates'][:]
205 | sample['segmentation'] = f['segmentation'][:]
206 |
207 | # Get timestamps
208 | t_min, t_max = 0.0, 2.0
209 | timestamps = calculate_frame_times(num_frames=96, total_time=t_max)
210 | sample['timestamps'] = timestamps[indices]
211 |
212 | # Get Events
213 | event_root = seq_path / 'events'
214 | event_paths = sorted(event_root.iterdir())
215 | events = []
216 | for event_path in event_paths:
217 | event_mini_batch = np.load(str(event_path))
218 | events.append(np.stack([
219 | event_mini_batch['y'],
220 | event_mini_batch['x'],
221 | event_mini_batch['t'],
222 | event_mini_batch['p'],
223 | ], axis=1))
224 | events = np.concatenate(events, axis=0)
225 | sample['events'] = events
226 |
227 | # Get Point Tracks
228 | gt_path = Path(self.data_root) / seq_name / 'annotations.npy'
229 | gt_data = np.load(str(gt_path), allow_pickle=True).item()
230 | traj_2d = gt_data['target_points']
231 |         visibility = gt_data['occluded']  # inverted below so that 1 means visible
232 | traj_2d = np.transpose(traj_2d, (1, 0, 2))[indices]
233 | visibility = np.transpose(np.logical_not(visibility), (1, 0))[indices]
234 | visibility = torch.from_numpy(visibility)
235 | traj_2d = torch.from_numpy(traj_2d)
236 | sample['point_tracks'] = traj_2d
237 | sample['visibility'] = visibility
238 |
239 | return sample
--------------------------------------------------------------------------------
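`normalize_representation` standardizes each channel using only its non-zero entries, so event-free pixels remain exactly zero. A stand-alone sketch of the same masked computation on a toy tensor (shapes assumed):

    import torch

    repr = torch.randn(8, 2, 4, 4)       # [T, C, H, W], toy shape
    repr[repr.abs() < 0.5] = 0.0         # emulate sparse event pixels

    mask = repr != 0
    mean = repr.sum(dim=(0, 2, 3), keepdim=True) / mask.sum(dim=(0, 2, 3), keepdim=True)
    var = ((repr - mean) ** 2 * mask).sum(dim=(0, 2, 3), keepdim=True) / mask.sum(dim=(0, 2, 3), keepdim=True)
    out = torch.where(mask, (repr - mean) / torch.sqrt(var + 1e-8), repr)

    print(out[mask].mean().item())          # ~0: non-zero entries are standardized per channel
    print((out[~mask] == 0).all().item())   # True: zero pixels are left untouched
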
/src/data/modules/evimo2.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import h5py
4 | import torch
5 | import pytorch_lightning as pl
6 | import pandas as pd
7 | from torch.utils.data import ConcatDataset
8 | import numpy as np
9 |
10 | from ..utils import EtapData
11 |
12 | class Evimo2DataModule(pl.LightningDataModule):
13 | def __init__(self, dataset_name, data_root, preprocessed_name, metadata_path):
14 | super().__init__()
15 | self.save_hyperparameters()
16 |
17 | def prepare_data(self):
18 | samples = []
19 |
20 | df = pd.read_csv(self.hparams.metadata_path)
21 |
22 | for _, row in df.iterrows():
23 | sample_path = Path(self.hparams.data_root) / row['name']
24 | samples.append(Evimo2SequenceDataset(
25 | data_root=sample_path,
26 | preprocessed_name=self.hparams.preprocessed_name,
27 | t_start=row['t_start'],
28 | t_end=row['t_end'],
29 | ))
30 |
31 | self.test_dataset = ConcatDataset(samples)
32 |
33 | class Evimo2SequenceDataset(torch.utils.data.Dataset):
34 | def __init__(
35 | self,
36 | data_root: Path,
37 | preprocessed_name: str,
38 | t_start: float,
39 | t_end: float,
40 | ):
41 | super(Evimo2SequenceDataset, self).__init__()
42 | self.data_root = data_root
43 | self.preprocessed_name = preprocessed_name
44 | self.t_start = t_start
45 | self.t_end = t_end
46 |
47 | # Load timestamps to determine valid indices
48 | repr_path = self.data_root / 'event_representations' / f'{self.preprocessed_name}.h5'
49 | with h5py.File(repr_path, 'r') as f:
50 | self.timestamps = f['timestamps'][:]
51 |
52 | self.valid_indices = np.where(
53 | (self.timestamps >= self.t_start) &
54 | (self.timestamps <= self.t_end)
55 | )[0]
56 |
57 | if len(self.valid_indices) == 0:
58 | raise ValueError(
59 | f"No timestamps found in range [{t_start}, {t_end}] "
60 | f"for sequence {self.data_root.name}"
61 | )
62 |
63 | def __len__(self):
64 | return 1
65 |
66 | def __getitem__(self, idx):
67 | repr_path = self.data_root / 'event_representations' / f'{self.preprocessed_name}.h5'
68 |
69 | # Event representations
70 | with h5py.File(repr_path, 'r') as f:
71 | # Only load data for valid time range
72 | representations = f['representations'][self.valid_indices]
73 | timestamps = f['timestamps'][self.valid_indices]
74 |
75 | representations = torch.from_numpy(representations).float()
76 | timestamps = torch.from_numpy(timestamps)
77 |
78 | # GT data
79 | gt_path = self.data_root / 'dataset_tracks.h5'
80 | with h5py.File(gt_path, 'r') as f:
81 | gt_tracks = f['tracks'][self.valid_indices]
82 | gt_occlusion = f['occlusions'][self.valid_indices]
83 |
84 | gt_tracks = torch.from_numpy(gt_tracks).float()
85 | gt_visibility = torch.from_numpy(~gt_occlusion).bool()
86 |
87 | # Queries
88 | T, N, _ = gt_tracks.shape
89 | first_positive_inds = torch.argmax(gt_visibility.int(), dim=0)
90 | query_xy = gt_tracks[first_positive_inds, torch.arange(N)]
91 | queries = torch.cat([first_positive_inds[..., None], query_xy], dim=-1)
92 |
93 | return EtapData(
94 | voxels=representations, # [T', C, H, W]
95 | rgbs=None,
96 | trajectory=gt_tracks, # [T', N, 2]
97 | visibility=gt_visibility,# [T', N]
98 | query_points=queries,
99 | seq_name=self.data_root.name,
100 | timestamps=timestamps, # [T']
101 | )
102 |
--------------------------------------------------------------------------------
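The queries built in `__getitem__` pair each point with the first frame in which it is visible, then gather that frame's coordinates. A toy sketch of this argmax-and-gather pattern (values assumed):

    import torch

    # Toy visibility [T=4, N=3]: point 0 visible from t=0, point 1 from t=2, point 2 from t=1
    vis = torch.tensor([[1, 0, 0],
                        [1, 0, 1],
                        [1, 1, 1],
                        [1, 1, 1]], dtype=torch.bool)
    tracks = torch.rand(4, 3, 2) * 100                          # [T, N, 2] toy coordinates

    first_t = torch.argmax(vis.int(), dim=0)                    # tensor([0, 2, 1])
    query_xy = tracks[first_t, torch.arange(3)]                 # xy at each point's first visible frame
    queries = torch.cat([first_t[:, None].float(), query_xy], dim=1)  # [N, 3] rows of (t, x, y)
    print(queries.shape)   # torch.Size([3, 3])
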
/src/data/modules/feature_tracking_online.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import pytorch_lightning as pl
5 | import numpy as np
6 |
7 | from ..utils import EtapData
8 | from src.utils import SUPPORTED_SEQUENCES_FEATURE_TRACKING
9 |
10 | class FeatureTrackingDataModule(pl.LightningDataModule):
11 | def __init__(self, data_root,
12 | dataset_name,
13 | preprocessed_name=None):
14 | super().__init__()
15 | self.save_hyperparameters()
16 | self.test_datasets = []
17 |
18 | def prepare_data(self):
19 | if self.hparams.dataset_name == 'eds':
20 | supported_sequences = SUPPORTED_SEQUENCES_FEATURE_TRACKING['eds']
21 | data_roots = [self.hparams.data_root for _ in SUPPORTED_SEQUENCES_FEATURE_TRACKING['eds']]
22 | elif self.hparams.dataset_name == 'ec':
23 | supported_sequences = SUPPORTED_SEQUENCES_FEATURE_TRACKING['ec']
24 | data_roots = [self.hparams.data_root for _ in SUPPORTED_SEQUENCES_FEATURE_TRACKING['ec']]
25 | elif self.hparams.dataset_name == 'feature_tracking_online':
26 | supported_sequences = SUPPORTED_SEQUENCES_FEATURE_TRACKING['eds'] + SUPPORTED_SEQUENCES_FEATURE_TRACKING['ec']
27 | data_root_eds = os.path.join(self.hparams.data_root, 'eds')
28 | data_root_ec = os.path.join(self.hparams.data_root, 'ec')
29 | data_roots = [data_root_eds for _ in SUPPORTED_SEQUENCES_FEATURE_TRACKING['eds']] \
30 | + [data_root_ec for _ in SUPPORTED_SEQUENCES_FEATURE_TRACKING['ec']]
31 | else:
32 | raise ValueError(f"Unsupported dataset_name: {self.hparams.dataset_name}")
33 |
34 | for subsequence_name, data_root in zip(supported_sequences, data_roots):
35 | self.test_datasets.append(FeatureTrackingInferenceDataset(
36 | data_root=data_root,
37 | subsequence_name=subsequence_name,
38 | preprocessed_name=self.hparams.preprocessed_name
39 | ))
40 |
41 | class FeatureTrackingInferenceDataset(torch.utils.data.Dataset):
42 |     """Dataset for a single sequence of the EDS or EC benchmark for point tracking,
43 |     as used in https://arxiv.org/pdf/2211.12826.
44 |     The dataset is designed for online use: each item provides the data chunk
45 |     for one step (e.g. 8 frames) of the sequence, and the next item starts
46 |     `stride` frames after the previous one.
47 | """
48 | def __init__(
49 | self,
50 | data_root,
51 | subsequence_name,
52 | stride=4,
53 | sliding_window_len=8,
54 | preprocessed_name=None
55 | ):
56 | self.subsequence_name = subsequence_name
57 | self.stride = stride
58 | self.sliding_window_len = sliding_window_len
59 | self.seq_root = os.path.join(data_root, subsequence_name)
60 | assert preprocessed_name is not None, 'online processing of raw events not supported.'
61 | self.preprocessed_name = preprocessed_name
62 |
63 | events_path = os.path.join(self.seq_root, 'events', preprocessed_name)
64 | self.samples, gt_ts_for_sanity_check = self.load_event_representations(events_path)
65 |
66 | gt_path = os.path.join('config/misc', os.path.basename(os.path.normpath(data_root)), 'gt_tracks', f'{subsequence_name}.gt.txt')
67 | gt_tracks = np.genfromtxt(gt_path) # [id, t, x, y]
68 | self.gt_tracks = torch.from_numpy(self.reformat_tracks(gt_tracks)).float()
69 |
70 | self.gt_times_s = np.unique(gt_tracks[:, 1])
71 | self.gt_times_us = (1e6 * self.gt_times_s).astype(int)
72 | assert np.allclose(gt_ts_for_sanity_check, self.gt_times_us)
73 |
74 | self.start_indices = np.arange(0, len(self.gt_times_us), stride)
75 |
76 | # Queries
77 | N = self.gt_tracks.shape[1]
78 | query_xy = self.gt_tracks[0, :]
79 | query_t = torch.zeros(N, dtype=torch.int64, device=query_xy.device)
80 | self.query_points = torch.cat([query_t[:, None], query_xy], dim=-1)
81 |
82 | def __len__(self):
83 | return len(self.start_indices)
84 |
85 | def __getitem__(self, idx):
86 | start_idx = self.start_indices[idx]
87 | end_idx = start_idx + self.sliding_window_len
88 | ev_repr = []
89 |
90 | for t in range(start_idx, end_idx):
91 | sample = np.load(os.path.join(self.seq_root, 'events', self.preprocessed_name,
92 | f'{self.gt_times_us[t]}.npy'))
93 | ev_repr.append(sample)
94 |
95 | ev_repr = torch.from_numpy(np.stack(ev_repr, axis=0)).float()
96 |
97 | gt_tracks = self.gt_tracks[start_idx:end_idx]
98 | height, width = ev_repr.shape[-2:]
99 | visibility = self.generate_visibility_mask(gt_tracks, height, width)
100 |
101 | sample = EtapData(
102 | voxels=ev_repr, # [T, C, H, W]
103 | rgbs=None, # [T, 3, H, W], only for visualization
104 | trajectory=gt_tracks, # [T, N, 2]
105 | visibility=visibility, # [T, N]
106 | timestamps=torch.from_numpy(self.gt_times_s[start_idx:end_idx]) # [T]
107 | )
108 | return sample, start_idx
109 |
110 | def load_event_representations(self, events_path):
111 | repr_files = [f for f in os.listdir(events_path) if f.endswith('.npy')]
112 |
113 | # Extract the integer numbers from the file names
114 | timestamps = []
115 | for file in repr_files:
116 | number = int(os.path.splitext(file)[0])
117 | timestamps.append(number)
118 |
119 | return repr_files, sorted(timestamps)
120 |
121 | def reformat_tracks(self, tracks):
122 | # Extract unique timestamps and point IDs
123 | timestamps = np.unique(tracks[:, 1])
124 | point_ids = np.unique(tracks[:, 0]).astype(int)
125 |
126 | T = len(timestamps)
127 | N = len(point_ids)
128 | output = np.full((T, N, 2), np.nan)
129 | time_to_index = {t: i for i, t in enumerate(timestamps)}
130 |
131 | for row in tracks:
132 | point_id, timestamp, x, y = row
133 | t_index = time_to_index[timestamp]
134 | p_index = int(point_id)
135 | output[t_index, p_index, :] = [x, y]
136 |
137 | return output
138 |
139 | def generate_visibility_mask(self, points, height, width):
140 | x = points[..., 0]
141 | y = points[..., 1]
142 | x_in_bounds = (x >= 0) & (x < width)
143 | y_in_bounds = (y >= 0) & (y < height)
144 | mask = x_in_bounds & y_in_bounds
145 | return mask
146 |
--------------------------------------------------------------------------------
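`reformat_tracks` converts the flat [id, t, x, y] rows of the *.gt.txt files into a dense [T, N, 2] array indexed by timestamp and point id. A worked toy example (values assumed):

    import numpy as np

    # Flat GT rows: [point_id, timestamp_s, x, y]
    tracks = np.array([
        [0, 0.00, 10.0, 20.0],
        [1, 0.00, 30.0, 40.0],
        [0, 0.01, 11.0, 21.0],
        [1, 0.01, 31.0, 41.0],
    ])

    timestamps = np.unique(tracks[:, 1])             # [0.0, 0.01]
    point_ids = np.unique(tracks[:, 0]).astype(int)  # [0, 1]
    out = np.full((len(timestamps), len(point_ids), 2), np.nan)
    t_to_i = {t: i for i, t in enumerate(timestamps)}
    for pid, t, x, y in tracks:
        out[t_to_i[t], int(pid)] = [x, y]

    print(out.shape)   # (2, 2, 2)
    print(out[1, 0])   # [11. 21.] -> point 0 at the second timestamp
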
/src/data/modules/penn_aviary.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import h5py
4 | import pytorch_lightning as pl
5 | import numpy as np
6 | from tqdm import tqdm
7 | from PIL import Image
8 | from pathlib import Path
9 |
10 | from src.representations import EventRepresentationFactory
11 | from src.utils import make_grid
12 | from ..utils import EtapData
13 |
14 |
15 | def pad_images_to_timestamps(images, img_ts, target_timestamps):
16 | """Match images to the nearest timestamp."""
17 | indices = np.searchsorted(img_ts, target_timestamps, side='right') - 1
18 | indices = np.clip(indices, 0, len(images) - 1)
19 | imgs_padded = images[indices]
20 | return imgs_padded
21 |
22 |
23 | def read_binary_mask(path):
24 | """Read binary mask from file."""
25 | img = np.array(Image.open(path).convert('L'))
26 | mask = (img != 255).astype(np.uint8)
27 | return mask
28 |
29 | class PennAviaryDataModule(pl.LightningDataModule):
30 | def __init__(self, dataset_name, data_root, sequence_data, repr_config=None,
31 | load_rgb=False, sequences=None):
32 | super().__init__()
33 | self.save_hyperparameters()
34 | self.test_datasets = []
35 |
36 | if self.hparams.repr_config is not None:
37 | self.hparams.repr_config['image_shape'] = tuple(self.hparams.repr_config['image_shape'])
38 | self.converter = EventRepresentationFactory.create(self.hparams.repr_config)
39 | else:
40 | self.converter = None
41 |
42 | def prepare_data(self):
43 | data_root = Path(self.hparams.data_root)
44 | if self.hparams.sequences is None:
45 | sequences = [d for d in data_root.iterdir() if d.is_dir()]
46 | else:
47 | sequences = [data_root / seq for seq in self.hparams.sequences]
48 |
49 | for sequence_path in sequences:
50 | subsequence_name = sequence_path.name
51 |
52 | if not (sequence_path / 'seq.h5').exists():
53 | print(f"Warning: seq.h5 not found in {sequence_path}, skipping...")
54 | continue
55 |
56 | self.test_datasets.append(PennAviaryDataset(
57 | sequence_path=sequence_path,
58 | num_events=self.hparams.sequence_data[subsequence_name]['num_events'],
59 | start_time_s=self.hparams.sequence_data[subsequence_name]['start_time_s'],
60 | duration_s=self.hparams.sequence_data[subsequence_name]['duration_s'],
61 | step_time_s=self.hparams.sequence_data[subsequence_name]['step_time_s'],
62 | converter=self.converter,
63 | load_rgb=self.hparams.load_rgb,
64 | query_stride=self.hparams.sequence_data[subsequence_name].get('query_stride', 40),
65 | mask_name=self.hparams.sequence_data[subsequence_name]['mask_name']
66 | ))
67 |
68 |
69 | class PennAviaryDataset(torch.utils.data.Dataset):
70 | def __init__(
71 | self,
72 | sequence_path,
73 | start_time_s,
74 | duration_s,
75 | step_time_s,
76 | num_events,
77 | stride=4,
78 | sliding_window_len=8,
79 | load_rgb=False,
80 | converter=None,
81 | query_stride=40,
82 | mask_name=None
83 | ):
84 | super().__init__()
85 | self.sequence_path = Path(sequence_path)
86 | self.subsequence_name = self.sequence_path.name
87 | self.num_events = num_events
88 | self.stride = stride
89 | self.sliding_window_len = sliding_window_len
90 | self.load_rgb = load_rgb
91 |
92 | # Define tracking timestamps
93 | start_time_us = start_time_s * 1e6
94 | end_time_us = start_time_us + duration_s * 1e6
95 | step_time_us = step_time_s * 1e6
96 | self.timestamps = np.arange(start_time_us, end_time_us, step_time_us)
97 | self.gt_tracks = None
98 |
99 | # Generate query points grid
100 | height, width = 480, 640
101 | query_xy = torch.from_numpy(make_grid(height, width, stride=query_stride))
102 | num_queries = query_xy.shape[0]
103 | query_t = torch.zeros(num_queries, dtype=torch.int64)
104 | self.query_points = torch.cat([query_t[:, None], query_xy], dim=1).float()
105 |
106 | # Only query points within mask
107 | if mask_name is not None:
108 | mask_path = self.sequence_path / mask_name
109 | segm_mask = torch.from_numpy(read_binary_mask(mask_path))
110 | query_x = self.query_points[:, 1].int()
111 | query_y = self.query_points[:, 2].int()
112 | segm_mask = segm_mask[query_y, query_x]
113 | self.query_points = self.query_points[segm_mask == 1]
114 |
115 | self.h5_path = self.sequence_path / 'seq.h5'
116 |
117 | if self.load_rgb:
118 | with h5py.File(self.h5_path, 'r') as f:
119 | images = f['images'][:]
120 | img_ts = f['img_ts'][:]
121 | self.imgs_padded = pad_images_to_timestamps(images, img_ts, self.timestamps)
122 |
123 |         load_ev_repr_from_file = len(self.timestamps) > 1000
124 |
125 | if converter is not None and not load_ev_repr_from_file:
126 | self.load_ev_repr = False
127 | self.ev_repr = self.create_representations(self.h5_path, converter, self.timestamps)
128 | else:
129 | self.load_ev_repr = True
130 | ev_repr_dir = self.sequence_path / str(int(1e6 * step_time_s)).zfill(9)
131 | self.ev_repr_paths = sorted(path for path in ev_repr_dir.iterdir() if path.is_file())
132 | assert len(self.ev_repr_paths) == len(self.timestamps), f"Expected {len(self.timestamps)} event representation files, but found {len(self.ev_repr_paths)}"
133 |
134 | self.start_indices = np.arange(0, len(self.timestamps) - self.stride, self.stride)
135 |
136 | def create_representations(self, h5_path, converter, timestamps):
137 | with h5py.File(h5_path, 'r') as f:
138 | indices_end = np.searchsorted(f['t'], timestamps) - 1
139 | indices_start = indices_end - self.num_events
140 |
141 | representations = []
142 |
143 | for i_start, i_end in tqdm(zip(indices_start, indices_end),
144 | desc=f"{self.subsequence_name}: creating event representations",
145 | total=len(indices_start)):
146 | events = np.stack([f['y'][i_start:i_end],
147 | f['x'][i_start:i_end],
148 | f['t'][i_start:i_end],
149 | f['p'][i_start:i_end]], axis=-1)
150 | repr = converter(events, t_mid=None)
151 | representations.append(repr)
152 |
153 | representations = np.stack(representations)
154 |
155 | return representations
156 |
157 | def load_event_representations(self, start_idx, end_idx):
158 | ev_repr = []
159 |
160 | for i in range(start_idx, end_idx):
161 | ev_repr_path = self.ev_repr_paths[i]
162 | sample = np.load(ev_repr_path)
163 | ev_repr.append(sample)
164 |
165 | ev_repr = torch.from_numpy(np.stack(ev_repr, axis=0)).float()
166 | return ev_repr
167 |
168 | def __len__(self):
169 | return len(self.start_indices)
170 |
171 | def __getitem__(self, idx):
172 | # Load event representation
173 | start_idx = self.start_indices[idx]
174 | end_idx = start_idx + self.sliding_window_len
175 |
176 | if self.load_ev_repr:
177 | ev_repr = self.load_event_representations(start_idx, end_idx)
178 | else:
179 | ev_repr = self.ev_repr[start_idx:end_idx]
180 | ev_repr = torch.from_numpy(np.stack(ev_repr, axis=0)).float()
181 |
182 | if self.load_rgb:
183 | rgbs = self.imgs_padded[start_idx:end_idx]
184 | rgbs = torch.from_numpy(rgbs).float().permute(0, 3, 1, 2)
185 | else:
186 | rgbs = None
187 |
188 | sample = EtapData(
189 | voxels=ev_repr, # [T, C, H, W]
190 | rgbs=rgbs, # [T, 3, H, W], only for visualization
191 | trajectory=None,
192 | visibility=None,
193 | timestamps=torch.from_numpy(self.timestamps[start_idx:end_idx]).float() / 1e6, # [T]
194 | )
195 | return sample, start_idx
--------------------------------------------------------------------------------
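`pad_images_to_timestamps` selects, for each tracking timestamp, the most recent frame at or before it. A toy check of the searchsorted indexing (values assumed):

    import numpy as np

    img_ts = np.array([0, 100, 200, 300])      # frame timestamps (us)
    targets = np.array([50, 100, 250, 999])    # tracking timestamps (us)

    idx = np.searchsorted(img_ts, targets, side='right') - 1
    idx = np.clip(idx, 0, len(img_ts) - 1)
    print(idx)   # [0 1 2 3] -> the frame at or just before each target timestamp
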
/src/data/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .collate import EtapData
--------------------------------------------------------------------------------
/src/data/utils/collate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import dataclasses
3 | from dataclasses import dataclass
4 | from typing import Any, List, Optional, Tuple
5 |
6 |
7 | @dataclass(eq=False)
8 | class EtapData:
9 | """
10 | Dataclass for storing video tracks data.
11 | """
12 | voxels: torch.Tensor # B, S, C, H, W
13 | trajectory: torch.Tensor # B, S, N, 2
14 | visibility: torch.Tensor # B, S, N
15 | # optional data
16 |     rgbs: Optional[torch.Tensor] = None  # B, S, 3, H, W, only for visualization
17 | valid: Optional[torch.Tensor] = None # B, S, N
18 | segmentation: Optional[torch.Tensor] = None # B, S, 1, H, W
19 | seq_name: Optional[str] = None
20 | query_points: Optional[torch.Tensor] = None # TapVID evaluation format
21 | timestamps: Optional[torch.Tensor] = None # For EDS evaluation
22 | pair_indices: Optional[torch.Tensor] = None # For contrastive loss
23 | video: Optional[torch.Tensor] = None # B, S, C, H, W
24 | e2vid: Optional[torch.Tensor] = None # B, S, C, H, W
25 | dataset_name: Optional[str] = None
26 |
27 |
28 | def collate_fn(batch: List[EtapData]) -> EtapData:
29 | """
30 | Collate function for video tracks data.
31 | """
32 | voxels = torch.stack([b.voxels for b in batch], dim=0)
33 | trajectory = torch.stack([b.trajectory for b in batch], dim=0)
34 | visibility = torch.stack([b.visibility for b in batch], dim=0)
35 | query_points = segmentation = None
36 | if batch[0].query_points is not None:
37 | query_points = torch.stack([b.query_points for b in batch], dim=0)
38 | if batch[0].segmentation is not None:
39 | segmentation = torch.stack([b.segmentation for b in batch], dim=0)
40 | seq_name = [b.seq_name for b in batch]
41 |
42 | return EtapData(
43 | voxels=voxels,
44 | trajectory=trajectory,
45 | visibility=visibility,
46 | segmentation=segmentation,
47 | seq_name=seq_name,
48 | query_points=query_points,
49 | )
50 |
51 |
52 | def collate_fn_train(batch: List[Tuple[EtapData, bool]]) -> Tuple[EtapData, List[bool]]:
53 | """
54 | Collate function for video tracks data during training.
55 | """
56 | gotit = [gotit for _, gotit in batch]
57 | voxels = torch.stack([b.voxels for b, _ in batch], dim=0)
58 | rgbs = torch.stack([b.rgbs for b, _ in batch], dim=0) if batch[0][0].rgbs is not None else None
59 | trajectory = torch.stack([b.trajectory for b, _ in batch], dim=0)
60 | visibility = torch.stack([b.visibility for b, _ in batch], dim=0)
61 | valid = torch.stack([b.valid for b, _ in batch], dim=0)
62 | seq_name = [b.seq_name for b, _ in batch]
63 | return (
64 | EtapData(
65 | voxels=voxels,
66 | rgbs=rgbs,
67 | trajectory=trajectory,
68 | visibility=visibility,
69 | valid=valid,
70 | seq_name=seq_name,
71 | ),
72 | gotit,
73 | )
74 |
75 |
76 | def collate_fn_load_inverted(batch: List[Tuple[EtapData, EtapData, bool]]) -> Tuple[EtapData, List[bool]]:
77 | """
78 | Collate function for video tracks data with load_inverted case.
79 | Combines original and inverted samples into a single EtapData object.
80 | """
81 | # Separate original and inverted samples
82 | orig_samples = [b[0] for b in batch]
83 | inv_samples = [b[1] for b in batch]
84 | gotit = [gotit for _, _, gotit in batch]
85 |
86 | # Combine original and inverted samples
87 | combined_samples = orig_samples + inv_samples
88 |
89 | voxels = torch.stack([b.voxels for b in combined_samples], dim=0)
90 | rgbs = torch.stack([b.rgbs for b in combined_samples], dim=0) if combined_samples[0].rgbs is not None else None
91 | trajectory = torch.stack([b.trajectory for b in combined_samples], dim=0)
92 | visibility = torch.stack([b.visibility for b in combined_samples], dim=0)
93 | valid = torch.stack([b.valid for b in combined_samples], dim=0)
94 | seq_name = [b.seq_name for b in combined_samples]
95 |
96 | # Create a tensor to keep track of paired samples using explicit indices
97 | batch_size = len(orig_samples)
98 | pair_indices = torch.arange(batch_size).repeat(2)
99 |
100 | return EtapData(
101 | voxels=voxels,
102 | rgbs=rgbs,
103 | trajectory=trajectory,
104 | visibility=visibility,
105 | valid=valid,
106 | seq_name=seq_name,
107 | pair_indices=pair_indices,
108 | ), gotit
109 |
110 |
111 | def try_to_cuda(t: Any) -> Any:
112 | """
113 | Try to move the input variable `t` to a cuda device.
114 |
115 | Args:
116 | t: Input tensor or other object
117 |
118 | Returns:
119 | t_cuda: `t` moved to a cuda device, if supported
120 | """
121 | try:
122 | t = t.float().cuda()
123 | except AttributeError:
124 | pass
125 | return t
126 |
127 |
128 | def dataclass_to_cuda_(obj: Any) -> Any:
129 | """
130 | Move all contents of a dataclass to cuda inplace if supported.
131 |
132 | Args:
133 | obj: Input dataclass object to move to CUDA
134 |
135 | Returns:
136 | obj: The same object with its tensor fields moved to CUDA
137 | """
138 | for f in dataclasses.fields(obj):
139 | setattr(obj, f.name, try_to_cuda(getattr(obj, f.name)))
140 | return obj
141 |
--------------------------------------------------------------------------------
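A minimal sketch of plugging `collate_fn` into a DataLoader; the toy dataset below is a hypothetical stand-in for something like EventKubricDataset and only fills the fields that `collate_fn` stacks:

    import torch
    from torch.utils.data import DataLoader
    from src.data.utils.collate import EtapData, collate_fn

    class ToyDataset(torch.utils.data.Dataset):
        """Stand-in dataset emitting randomly filled EtapData samples."""
        def __len__(self):
            return 8

        def __getitem__(self, idx):
            return EtapData(
                voxels=torch.randn(8, 10, 64, 64),   # [T, C, H, W]
                trajectory=torch.randn(8, 16, 2),    # [T, N, 2]
                visibility=torch.ones(8, 16),        # [T, N]
                rgbs=None,                           # not needed for collation
            )

    loader = DataLoader(ToyDataset(), batch_size=4, collate_fn=collate_fn)
    batch = next(iter(loader))
    print(batch.voxels.shape)       # torch.Size([4, 8, 10, 64, 64])
    print(batch.trajectory.shape)   # torch.Size([4, 8, 16, 2])
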
/src/model/etap/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tub-rip/ETAP/e13875244b45e651c1ebcce1e769ed05ffbf7acd/src/model/etap/core/__init__.py
--------------------------------------------------------------------------------
/src/model/etap/core/cotracker/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tub-rip/ETAP/e13875244b45e651c1ebcce1e769ed05ffbf7acd/src/model/etap/core/cotracker/__init__.py
--------------------------------------------------------------------------------
/src/model/etap/core/cotracker/blocks.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # Modified by F. Hamann (TU Berlin)
8 | # Copyright (c) 2025 F. Hamann (TU Berlin)
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | from functools import partial
14 | from typing import Callable
15 | import collections
16 | from itertools import repeat
17 |
18 | from ..model_utils import bilinear_sampler
19 |
20 |
21 | # From PyTorch internals
22 | def _ntuple(n):
23 | def parse(x):
24 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
25 | return tuple(x)
26 | return tuple(repeat(x, n))
27 |
28 | return parse
29 |
30 |
31 | def exists(val):
32 | return val is not None
33 |
34 |
35 | def default(val, d):
36 | return val if exists(val) else d
37 |
38 |
39 | to_2tuple = _ntuple(2)
40 |
41 |
42 | class Mlp(nn.Module):
43 | """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
44 |
45 | def __init__(
46 | self,
47 | in_features,
48 | hidden_features=None,
49 | out_features=None,
50 | act_layer=nn.GELU,
51 | norm_layer=None,
52 | bias=True,
53 | drop=0.0,
54 | use_conv=False,
55 | ):
56 | super().__init__()
57 | out_features = out_features or in_features
58 | hidden_features = hidden_features or in_features
59 | bias = to_2tuple(bias)
60 | drop_probs = to_2tuple(drop)
61 | linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
62 |
63 | self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
64 | self.act = act_layer()
65 | self.drop1 = nn.Dropout(drop_probs[0])
66 | self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
67 | self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
68 | self.drop2 = nn.Dropout(drop_probs[1])
69 |
70 | def forward(self, x):
71 | x = self.fc1(x)
72 | x = self.act(x)
73 | x = self.drop1(x)
74 | x = self.fc2(x)
75 | x = self.drop2(x)
76 | return x
77 |
78 |
79 | class ResidualBlock(nn.Module):
80 | def __init__(self, in_planes, planes, norm_fn="group", stride=1):
81 | super(ResidualBlock, self).__init__()
82 |
83 | self.conv1 = nn.Conv2d(
84 | in_planes,
85 | planes,
86 | kernel_size=3,
87 | padding=1,
88 | stride=stride,
89 | padding_mode="zeros",
90 | )
91 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, padding_mode="zeros")
92 | self.relu = nn.ReLU(inplace=True)
93 |
94 | num_groups = planes // 8
95 |
96 | if norm_fn == "group":
97 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
98 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
99 | if not stride == 1:
100 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
101 |
102 | elif norm_fn == "batch":
103 | self.norm1 = nn.BatchNorm2d(planes)
104 | self.norm2 = nn.BatchNorm2d(planes)
105 | if not stride == 1:
106 | self.norm3 = nn.BatchNorm2d(planes)
107 |
108 | elif norm_fn == "instance":
109 | self.norm1 = nn.InstanceNorm2d(planes)
110 | self.norm2 = nn.InstanceNorm2d(planes)
111 | if not stride == 1:
112 | self.norm3 = nn.InstanceNorm2d(planes)
113 |
114 | elif norm_fn == "none":
115 | self.norm1 = nn.Sequential()
116 | self.norm2 = nn.Sequential()
117 | if not stride == 1:
118 | self.norm3 = nn.Sequential()
119 |
120 | if stride == 1:
121 | self.downsample = None
122 |
123 | else:
124 | self.downsample = nn.Sequential(
125 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
126 | )
127 |
128 | def forward(self, x):
129 | y = x
130 | y = self.relu(self.norm1(self.conv1(y)))
131 | y = self.relu(self.norm2(self.conv2(y)))
132 |
133 | if self.downsample is not None:
134 | x = self.downsample(x)
135 |
136 | return self.relu(x + y)
137 |
138 |
139 | class BasicEncoder(nn.Module):
140 | def __init__(self, input_dim=3, output_dim=128, stride=4):
141 | super(BasicEncoder, self).__init__()
142 | self.stride = stride
143 | self.norm_fn = "instance"
144 | self.in_planes = output_dim // 2
145 |
146 | self.norm1 = nn.InstanceNorm2d(self.in_planes)
147 | self.norm2 = nn.InstanceNorm2d(output_dim * 2)
148 |
149 | self.conv1 = nn.Conv2d(
150 | input_dim,
151 | self.in_planes,
152 | kernel_size=7,
153 | stride=2,
154 | padding=3,
155 | padding_mode="zeros",
156 | )
157 | self.relu1 = nn.ReLU(inplace=True)
158 | self.layer1 = self._make_layer(output_dim // 2, stride=1)
159 | self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
160 | self.layer3 = self._make_layer(output_dim, stride=2)
161 | self.layer4 = self._make_layer(output_dim, stride=2)
162 |
163 | self.conv2 = nn.Conv2d(
164 | output_dim * 3 + output_dim // 4,
165 | output_dim * 2,
166 | kernel_size=3,
167 | padding=1,
168 | padding_mode="zeros",
169 | )
170 | self.relu2 = nn.ReLU(inplace=True)
171 | self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
172 | for m in self.modules():
173 | if isinstance(m, nn.Conv2d):
174 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
175 | elif isinstance(m, (nn.InstanceNorm2d)):
176 | if m.weight is not None:
177 | nn.init.constant_(m.weight, 1)
178 | if m.bias is not None:
179 | nn.init.constant_(m.bias, 0)
180 |
181 | def _make_layer(self, dim, stride=1):
182 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
183 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
184 | layers = (layer1, layer2)
185 |
186 | self.in_planes = dim
187 | return nn.Sequential(*layers)
188 |
189 | def forward(self, x):
190 | _, _, H, W = x.shape
191 |
192 | x = self.conv1(x)
193 | x = self.norm1(x)
194 | x = self.relu1(x)
195 |
196 | a = self.layer1(x)
197 | b = self.layer2(a)
198 | c = self.layer3(b)
199 | d = self.layer4(c)
200 |
201 | def _bilinear_intepolate(x):
202 | return F.interpolate(
203 | x,
204 | (H // self.stride, W // self.stride),
205 | mode="bilinear",
206 | align_corners=True,
207 | )
208 |
209 | a = _bilinear_intepolate(a)
210 | b = _bilinear_intepolate(b)
211 | c = _bilinear_intepolate(c)
212 | d = _bilinear_intepolate(d)
213 |
214 | x = self.conv2(torch.cat([a, b, c, d], dim=1))
215 | x = self.norm2(x)
216 | x = self.relu2(x)
217 | x = self.conv3(x)
218 | return x
219 |
220 |
221 | class CorrBlock:
222 | def __init__(
223 | self,
224 | fmaps,
225 | num_levels=4,
226 | radius=4,
227 | multiple_track_feats=False,
228 | padding_mode="zeros",
229 | appearance_fact_flow_dim=None,
230 | ):
231 | B, S, C, H, W = fmaps.shape
232 | self.S, self.C, self.H, self.W = S, C, H, W
233 | self.padding_mode = padding_mode
234 | self.num_levels = num_levels
235 | self.radius = radius
236 | self.fmaps_pyramid = []
237 | self.multiple_track_feats = multiple_track_feats
238 | self.appearance_fact_flow_dim = appearance_fact_flow_dim
239 |
240 | self.fmaps_pyramid.append(fmaps)
241 | for i in range(self.num_levels - 1):
242 | fmaps_ = fmaps.reshape(B * S, C, H, W)
243 | fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
244 | _, _, H, W = fmaps_.shape
245 | fmaps = fmaps_.reshape(B, S, C, H, W)
246 | self.fmaps_pyramid.append(fmaps)
247 |
248 | def sample(self, coords):
249 | r = self.radius
250 | B, S, N, D = coords.shape
251 | assert D == 2
252 |
253 | H, W = self.H, self.W
254 | out_pyramid = []
255 | for i in range(self.num_levels):
256 | corrs = self.corrs_pyramid[i] # B, S, N, H, W
257 | *_, H, W = corrs.shape
258 |
259 | dx = torch.linspace(-r, r, 2 * r + 1)
260 | dy = torch.linspace(-r, r, 2 * r + 1)
261 | delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(coords.device)
262 |
263 | centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i
264 | delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
265 | coords_lvl = centroid_lvl + delta_lvl
266 |
267 | corrs = bilinear_sampler(
268 | corrs.reshape(B * S * N, 1, H, W),
269 | coords_lvl,
270 | padding_mode=self.padding_mode,
271 | )
272 | corrs = corrs.view(B, S, N, -1)
273 | out_pyramid.append(corrs)
274 |
275 |         out = torch.cat(out_pyramid, dim=-1)  # B, S, N, L*(2*r+1)**2
276 | out = out.permute(0, 2, 1, 3).contiguous().view(B * N, S, -1).float()
277 | return out
278 |
279 | def corr(self, targets, coords=None, gt_flow=None, use_gt_flow=False, use_flow_tokens=False,
280 | use_af_high_dim=False, interaction_network=None):
281 | assert sum([coords is not None,
282 | use_gt_flow,
283 | use_flow_tokens,
284 | use_af_high_dim,
285 | interaction_network is not None]) <= 1, \
286 |             "At most one of coords, use_gt_flow, use_flow_tokens, use_af_high_dim, or interaction_network may be specified."
287 | assert not use_flow_tokens or self.appearance_fact_flow_dim is not None
288 |
289 | # Appearance factorization
290 | if coords is not None:
291 | B, S, N, D2 = targets.shape
292 | targets = targets.reshape(B, S, N, 2, D2 // 2)
293 | if use_gt_flow:
294 | flow = gt_flow
295 | else:
296 | flow = coords[:, 1:] - coords[:, :-1]
297 | flow = torch.cat([flow[:, 0:1], flow], dim=1)
298 | flow = gt_flow if use_gt_flow else flow
299 | targets = flow[..., 0:1] * targets[..., 0, :] + flow[..., 1:2] * targets[..., 1, :]
300 |
301 | ###### DEBUG ########
302 | # flow: [B, S, N, 2]
303 | # targets: [B, S, N, 2, feat_dim], remember feat_dim == latent_dim / 2
304 |
305 | # Appearane factorization with flow tokens
306 | if use_flow_tokens:
307 | fdim = self.appearance_fact_flow_dim
308 | flow = targets[..., -fdim:]
309 | targets = targets[..., :-fdim]
310 | B, S, N, D2 = targets.shape
311 | targets = targets.reshape(B, S, N, fdim, D2 // fdim)
312 | #targets = flow[..., 0:1] * targets[..., 0, :] + flow[..., 1:2] * targets[..., 1, :]
313 | targets = torch.einsum('btnji,btnj->btni', targets, flow)
314 |
315 | if use_af_high_dim:
316 | fdim = self.appearance_fact_flow_dim
317 | flow = targets[..., -fdim:]
318 | targets = targets[..., :-fdim]
319 | targets = flow * targets
320 |
321 | #if interaction_network is not None:
322 | # flow_feat = targets[..., :self.appearance_fact_flow_dim]
323 | # targets = targets[..., self.appearance_fact_flow_dim:]
324 |
325 | # Correlation
326 | B, S, N, C = targets.shape
327 | #assert C == self.C
328 | assert S == self.S
329 |
330 | fmap1 = targets
331 |
332 | self.corrs_pyramid = []
333 | for _, fmaps in enumerate(self.fmaps_pyramid):
334 | *_, H, W = fmaps.shape
335 | fmap2s = fmaps.view(B, S, C, H * W) # B S C H W -> B S C (H W)
336 | corrs = torch.matmul(fmap1, fmap2s)
337 | corrs = corrs.view(B, S, N, H, W) # B S N (H W) -> B S N H W
338 | corrs = corrs / torch.sqrt(torch.tensor(C).float())
339 | self.corrs_pyramid.append(corrs)
340 |
341 | def sample_fmap(self, coords):
342 | '''Sample at coords directly from scaled feature map.
343 | '''
344 | B, S, N, D = coords.shape
345 | assert D == 2
346 |
347 | H, W = self.H, self.W
348 | out_pyramid = []
349 | for i in range(self.num_levels):
350 |
351 | fmap = self.fmaps_pyramid[i] # B, S, C, H, W
352 | B, S, C, H, W = fmap.shape
353 |
354 | coords_lvl = coords / 2**i
355 |
356 | fmap = fmap.reshape(B * S, C, H, W)
357 | coords_lvl = coords_lvl.reshape(B * S, N, 2)
358 |
359 | coords_normalized = coords_lvl.clone()
360 | coords_normalized[..., 0] = coords_normalized[..., 0] / (W - 1) * 2 - 1
361 | coords_normalized[..., 1] = coords_normalized[..., 1] / (H - 1) * 2 - 1
362 | coords_normalized = coords_normalized.unsqueeze(1)
363 | feature_at_coords = F.grid_sample(fmap, coords_normalized,
364 | mode='bilinear', align_corners=True)
365 | feature_at_coords = feature_at_coords.permute(0, 2, 3, 1).view(B, S, N, C)
366 | out_pyramid.append(feature_at_coords)
367 |
368 |         out = torch.cat(out_pyramid, dim=-1)  # B, S, N, L*C
369 | return out
370 |
371 | class Attention(nn.Module):
372 | def __init__(self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False):
373 | super().__init__()
374 | inner_dim = dim_head * num_heads
375 | context_dim = default(context_dim, query_dim)
376 | self.scale = dim_head**-0.5
377 | self.heads = num_heads
378 |
379 | self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
380 | self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
381 | self.to_out = nn.Linear(inner_dim, query_dim)
382 |
383 | def forward(self, x, context=None, attn_bias=None):
384 | B, N1, C = x.shape
385 | h = self.heads
386 |
387 | q = self.to_q(x).reshape(B, N1, h, C // h).permute(0, 2, 1, 3)
388 | context = default(context, x)
389 | k, v = self.to_kv(context).chunk(2, dim=-1)
390 |
391 | N2 = context.shape[1]
392 | k = k.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
393 | v = v.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
394 |
395 | sim = (q @ k.transpose(-2, -1)) * self.scale
396 |
397 | if attn_bias is not None:
398 | sim = sim + attn_bias
399 | attn = sim.softmax(dim=-1)
400 |
401 | x = (attn @ v).transpose(1, 2).reshape(B, N1, C)
402 | return self.to_out(x)
403 |
404 |
405 | class AttnBlock(nn.Module):
406 | def __init__(
407 | self,
408 | hidden_size,
409 | num_heads,
410 | attn_class: Callable[..., nn.Module] = Attention,
411 | mlp_ratio=4.0,
412 | **block_kwargs
413 | ):
414 | super().__init__()
415 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
416 | self.attn = attn_class(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
417 |
418 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
419 | mlp_hidden_dim = int(hidden_size * mlp_ratio)
420 | approx_gelu = lambda: nn.GELU(approximate="tanh")
421 | self.mlp = Mlp(
422 | in_features=hidden_size,
423 | hidden_features=mlp_hidden_dim,
424 | act_layer=approx_gelu,
425 | drop=0,
426 | )
427 |
428 | def forward(self, x, mask=None):
429 | attn_bias = mask
430 | if mask is not None:
431 | mask = (
432 | (mask[:, None] * mask[:, :, None])
433 | .unsqueeze(1)
434 | .expand(-1, self.attn.num_heads, -1, -1)
435 | )
436 | max_neg_value = -torch.finfo(x.dtype).max
437 | attn_bias = (~mask) * max_neg_value
438 | x = x + self.attn(self.norm1(x), attn_bias=attn_bias)
439 | x = x + self.mlp(self.norm2(x))
440 | return x
441 |
--------------------------------------------------------------------------------
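A hypothetical shape check for the attention blocks above (assuming the repository root is on PYTHONPATH). Note that `hidden_size` must equal `num_heads * dim_head` (with `dim_head` defaulting to 48) for the per-head reshape inside Attention to be valid:

    import torch
    from src.model.etap.core.cotracker.blocks import AttnBlock

    block = AttnBlock(hidden_size=384, num_heads=8)   # 8 heads * 48 dims/head = 384
    tokens = torch.randn(2, 100, 384)                 # [B, N, hidden_size]
    out = block(tokens)
    print(out.shape)                                  # torch.Size([2, 100, 384])
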
/src/model/etap/core/embeddings.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # Modified by F. Hamann (TU Berlin)
8 | # Copyright (c) 2025 F. Hamann (TU Berlin)
9 |
10 | from typing import Tuple, Union
11 | import torch
12 |
13 |
14 | def get_2d_sincos_pos_embed(
15 | embed_dim: int, grid_size: Union[int, Tuple[int, int]]
16 | ) -> torch.Tensor:
17 | """
18 | This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
19 | It is a wrapper of get_2d_sincos_pos_embed_from_grid.
20 | Args:
21 | - embed_dim: The embedding dimension.
22 | - grid_size: The grid size.
23 | Returns:
24 | - pos_embed: The generated 2D positional embedding.
25 | """
26 | if isinstance(grid_size, tuple):
27 | grid_size_h, grid_size_w = grid_size
28 | else:
29 | grid_size_h = grid_size_w = grid_size
30 | grid_h = torch.arange(grid_size_h, dtype=torch.float)
31 | grid_w = torch.arange(grid_size_w, dtype=torch.float)
32 | grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
33 | grid = torch.stack(grid, dim=0)
34 | grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
35 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
36 | return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
37 |
38 |
39 | def get_2d_sincos_pos_embed_from_grid(
40 | embed_dim: int, grid: torch.Tensor
41 | ) -> torch.Tensor:
42 | """
43 | This function generates a 2D positional embedding from a given grid using sine and cosine functions.
44 |
45 | Args:
46 | - embed_dim: The embedding dimension.
47 | - grid: The grid to generate the embedding from.
48 |
49 | Returns:
50 | - emb: The generated 2D positional embedding.
51 | """
52 | assert embed_dim % 2 == 0
53 |
54 | # use half of dimensions to encode grid_h
55 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
56 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
57 |
58 | emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
59 | return emb
60 |
61 |
62 | def get_1d_sincos_pos_embed_from_grid(
63 | embed_dim: int, pos: torch.Tensor
64 | ) -> torch.Tensor:
65 | """
66 | This function generates a 1D positional embedding from a given grid using sine and cosine functions.
67 |
68 | Args:
69 | - embed_dim: The embedding dimension.
70 | - pos: The position to generate the embedding from.
71 |
72 | Returns:
73 | - emb: The generated 1D positional embedding.
74 | """
75 | assert embed_dim % 2 == 0
76 | omega = torch.arange(embed_dim // 2, dtype=torch.double)
77 | omega /= embed_dim / 2.0
78 | omega = 1.0 / 10000**omega # (D/2,)
79 |
80 | pos = pos.reshape(-1) # (M,)
81 | out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
82 |
83 | emb_sin = torch.sin(out) # (M, D/2)
84 | emb_cos = torch.cos(out) # (M, D/2)
85 |
86 | emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
87 | return emb[None].float()
88 |
89 |
90 | def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
91 | """
92 | This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
93 |
94 | Args:
95 | - xy: The coordinates to generate the embedding from.
96 | - C: The size of the embedding.
97 | - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
98 |
99 | Returns:
100 | - pe: The generated 2D positional embedding.
101 | """
102 | B, N, D = xy.shape
103 | assert D == 2
104 |
105 | x = xy[:, :, 0:1]
106 | y = xy[:, :, 1:2]
107 | div_term = (
108 | torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
109 | ).reshape(1, 1, int(C / 2))
110 |
111 | pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
112 | pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
113 |
114 | pe_x[:, :, 0::2] = torch.sin(x * div_term)
115 | pe_x[:, :, 1::2] = torch.cos(x * div_term)
116 |
117 | pe_y[:, :, 0::2] = torch.sin(y * div_term)
118 | pe_y[:, :, 1::2] = torch.cos(y * div_term)
119 |
120 |     pe = torch.cat([pe_x, pe_y], dim=2)  # (B, N, 2*C)
121 |     if cat_coords:
122 |         pe = torch.cat([xy, pe], dim=2)  # (B, N, 2*C + 2)
123 | return pe
124 |
--------------------------------------------------------------------------------
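A short shape sketch for the embedding helpers above (assuming the repository root is on PYTHONPATH):

    import torch
    from src.model.etap.core.embeddings import get_2d_sincos_pos_embed, get_2d_embedding

    # Dense positional embedding over a feature grid, returned as [1, D, H, W].
    pos = get_2d_sincos_pos_embed(embed_dim=128, grid_size=(48, 64))
    print(pos.shape)   # torch.Size([1, 128, 48, 64])

    # Per-point embedding for 2D coordinates: [B, N, 2] -> [B, N, 2*C + 2] with cat_coords=True.
    xy = torch.rand(2, 5, 2) * 100
    pe = get_2d_embedding(xy, C=64, cat_coords=True)
    print(pe.shape)    # torch.Size([2, 5, 130])
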
/src/model/etap/core/model_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # Modified by F. Hamann (TU Berlin)
8 | # Copyright (c) 2025 F. Hamann (TU Berlin)
9 |
10 | from __future__ import annotations
11 | import torch
12 | import torch.nn.functional as F
13 | from typing import Optional, Tuple
14 |
15 | EPS = 1e-6
16 |
17 |
18 | def smart_cat(tensor1: Optional[torch.Tensor], tensor2: torch.Tensor, dim: int) -> torch.Tensor:
19 | if tensor1 is None:
20 | return tensor2
21 | return torch.cat([tensor1, tensor2], dim=dim)
22 |
23 |
24 | def get_points_on_a_grid(
25 | size: int,
26 | extent: Tuple[float, ...],
27 | center: Optional[Tuple[float, ...]] = None,
28 | device: Optional[torch.device] = torch.device("cpu"),
29 | ) -> torch.Tensor:
30 | r"""Get a grid of points covering a rectangular region
31 |
32 | `get_points_on_a_grid(size, extent)` generates a :attr:`size` by
33 |     :attr:`size` grid of points distributed to cover a rectangular area
34 | specified by `extent`.
35 |
36 |     The `extent` is a pair of integers :math:`(H,W)` specifying the height
37 | and width of the rectangle.
38 |
39 | Optionally, the :attr:`center` can be specified as a pair :math:`(c_y,c_x)`
40 | specifying the vertical and horizontal center coordinates. The center
41 | defaults to the middle of the extent.
42 |
43 | Points are distributed uniformly within the rectangle leaving a margin
44 | :math:`m=W/64` from the border.
45 |
46 | It returns a :math:`(1, \text{size} \times \text{size}, 2)` tensor of
47 | points :math:`P_{ij}=(x_i, y_i)` where
48 |
49 | .. math::
50 | P_{ij} = \left(
51 | c_x + m -\frac{W}{2} + \frac{W - 2m}{\text{size} - 1}\, j,~
52 | c_y + m -\frac{H}{2} + \frac{H - 2m}{\text{size} - 1}\, i
53 | \right)
54 |
55 | Points are returned in row-major order.
56 |
57 | Args:
58 | size (int): grid size.
59 |         extent (tuple): height and width of the grid extent.
60 | center (tuple, optional): grid center.
61 | device (str, optional): Defaults to `"cpu"`.
62 |
63 | Returns:
64 | Tensor: grid.
65 | """
66 | if size == 1:
67 | return torch.tensor([extent[1] / 2, extent[0] / 2], device=device)[None, None]
68 |
69 | if center is None:
70 | center = [extent[0] / 2, extent[1] / 2]
71 |
72 | margin = extent[1] / 64
73 | range_y = (margin - extent[0] / 2 + center[0], extent[0] / 2 + center[0] - margin)
74 | range_x = (margin - extent[1] / 2 + center[1], extent[1] / 2 + center[1] - margin)
75 | grid_y, grid_x = torch.meshgrid(
76 | torch.linspace(*range_y, size, device=device),
77 | torch.linspace(*range_x, size, device=device),
78 | indexing="ij",
79 | )
80 | return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2)
81 |
82 |
83 | def reduce_masked_mean(input: torch.Tensor, mask: torch.Tensor, dim: Optional[int] = None, keepdim: bool = False) -> torch.Tensor:
84 | r"""Masked mean
85 |
86 | `reduce_masked_mean(x, mask)` computes the mean of a tensor :attr:`input`
87 | over a mask :attr:`mask`, returning
88 |
89 | .. math::
90 | \text{output} =
91 | \frac
92 | {\sum_{i=1}^N \text{input}_i \cdot \text{mask}_i}
93 | {\epsilon + \sum_{i=1}^N \text{mask}_i}
94 |
95 | where :math:`N` is the number of elements in :attr:`input` and
96 | :attr:`mask`, and :math:`\epsilon` is a small constant to avoid
97 | division by zero.
98 |
99 |     `reduce_masked_mean(x, mask, dim)` computes the mean of a tensor
100 | :attr:`input` over a mask :attr:`mask` along a dimension :attr:`dim`.
101 | Optionally, the dimension can be kept in the output by setting
102 | :attr:`keepdim` to `True`. Tensor :attr:`mask` must be broadcastable to
103 | the same dimension as :attr:`input`.
104 |
105 | The interface is similar to `torch.mean()`.
106 |
107 | Args:
108 | input (Tensor): input tensor.
109 | mask (Tensor): mask.
110 | dim (int, optional): Dimension to sum over. Defaults to None.
111 | keepdim (bool, optional): Keep the summed dimension. Defaults to False.
112 |
113 | Returns:
114 | Tensor: mean tensor.
115 | """
116 |
117 | mask = mask.expand_as(input)
118 |
119 | prod = input * mask
120 |
121 | if dim is None:
122 | numer = torch.sum(prod)
123 | denom = torch.sum(mask)
124 | else:
125 | numer = torch.sum(prod, dim=dim, keepdim=keepdim)
126 | denom = torch.sum(mask, dim=dim, keepdim=keepdim)
127 |
128 | mean = numer / (EPS + denom)
129 | return mean
130 |
131 |
132 | def bilinear_sampler(input: torch.Tensor, coords: torch.Tensor, align_corners: bool = True, padding_mode: str = "border") -> torch.Tensor:
133 | r"""Sample a tensor using bilinear interpolation
134 |
135 | `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
136 | coordinates :attr:`coords` using bilinear interpolation. It is the same
137 | as `torch.nn.functional.grid_sample()` but with a different coordinate
138 | convention.
139 |
140 | The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
141 | :math:`B` is the batch size, :math:`C` is the number of channels,
142 | :math:`H` is the height of the image, and :math:`W` is the width of the
143 | image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
144 | interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
145 |
146 | Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
147 | in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
148 | that in this case the order of the components is slightly different
149 | from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
150 |
151 | If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
152 | in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
153 |     left-most image pixel and :math:`W-1` to the center of the right-most
154 | pixel.
155 |
156 | If `align_corners` is `False`, the coordinate :math:`x` is assumed to
157 | be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
158 |     the left-most pixel and :math:`W` to the right edge of the right-most
159 | pixel.
160 |
161 | Similar conventions apply to the :math:`y` for the range
162 | :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
163 | :math:`[0,T-1]` and :math:`[0,T]`.
164 |
165 | Args:
166 | input (Tensor): batch of input images.
167 | coords (Tensor): batch of coordinates.
168 | align_corners (bool, optional): Coordinate convention. Defaults to `True`.
169 | padding_mode (str, optional): Padding mode. Defaults to `"border"`.
170 |
171 | Returns:
172 | Tensor: sampled points.
173 | """
174 |
175 | sizes = input.shape[2:]
176 |
177 | assert len(sizes) in [2, 3]
178 |
179 | if len(sizes) == 3:
180 | # t x y -> x y t to match dimensions T H W in grid_sample
181 | coords = coords[..., [1, 2, 0]]
182 |
183 | if align_corners:
184 | coords = coords * torch.tensor(
185 | [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device
186 | )
187 | else:
188 | coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device)
189 |
190 | coords -= 1
191 |
192 | return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
193 |
194 |
195 | def sample_features4d(input: torch.Tensor, coords: torch.Tensor) -> torch.Tensor:
196 | r"""Sample spatial features
197 |
198 | `sample_features4d(input, coords)` samples the spatial features
199 | :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
200 |
201 | The field is sampled at coordinates :attr:`coords` using bilinear
202 | interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
203 | 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
204 | same convention as :func:`bilinear_sampler` with `align_corners=True`.
205 |
206 | The output tensor has one feature per point, and has shape :math:`(B,
207 | R, C)`.
208 |
209 | Args:
210 | input (Tensor): spatial features.
211 | coords (Tensor): points.
212 |
213 | Returns:
214 | Tensor: sampled features.
215 | """
216 |
217 | B, _, _, _ = input.shape
218 |
219 | # B R 2 -> B R 1 2
220 | coords = coords.unsqueeze(2)
221 |
222 | # B C R 1
223 | feats = bilinear_sampler(input, coords)
224 |
225 | return feats.permute(0, 2, 1, 3).view(
226 | B, -1, feats.shape[1] * feats.shape[3]
227 | ) # B C R 1 -> B R C
228 |
229 |
230 | def sample_features5d(input: torch.Tensor, coords: torch.Tensor) -> torch.Tensor:
231 | r"""Sample spatio-temporal features
232 |
233 | `sample_features5d(input, coords)` works in the same way as
234 | :func:`sample_features4d` but for spatio-temporal features and points:
235 | :attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is
236 | a :math:`(B, R1, R2, 3)` tensor of spatio-temporal point :math:`(t_i,
237 | x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`.
238 |
239 | Args:
240 | input (Tensor): spatio-temporal features.
241 | coords (Tensor): spatio-temporal points.
242 |
243 | Returns:
244 | Tensor: sampled features.
245 | """
246 |
247 | B, T, _, _, _ = input.shape
248 |
249 | # B T C H W -> B C T H W
250 | input = input.permute(0, 2, 1, 3, 4)
251 |
252 | # B R1 R2 3 -> B R1 R2 1 3
253 | coords = coords.unsqueeze(3)
254 |
255 | # B C R1 R2 1
256 | feats = bilinear_sampler(input, coords)
257 |
258 | return feats.permute(0, 2, 3, 1, 4).view(
259 | B, feats.shape[2], feats.shape[3], feats.shape[1]
260 | ) # B C R1 R2 1 -> B R1 R2 C
261 |
--------------------------------------------------------------------------------
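
As a quick orientation for how these helpers fit together, here is a minimal, self-contained usage sketch (not part of the repository; shapes and the import path are illustrative and assume the repository root is on PYTHONPATH). It builds a query grid with `get_points_on_a_grid`, samples one feature vector per query via `sample_features4d` (which uses `bilinear_sampler` internally), and averages them with `reduce_masked_mean`:

import torch

from src.model.etap.core.model_utils import (
    get_points_on_a_grid,
    reduce_masked_mean,
    sample_features4d,
)

B, C, H, W = 1, 64, 128, 160
fmap = torch.randn(B, C, H, W)                       # dummy feature map

# 8x8 query grid covering the (H, W) extent -> (1, 64, 2) tensor of (x, y) points
queries = get_points_on_a_grid(8, (H, W))

# one bilinearly interpolated feature vector per query -> (1, 64, C)
feats = sample_features4d(fmap, queries)

# masked mean over the query dimension, e.g. averaging only the first 32 queries
mask = torch.zeros(B, queries.shape[1], C)
mask[:, :32] = 1.0
mean_feat = reduce_masked_mean(feats, mask, dim=1)   # (1, C)

print(queries.shape, feats.shape, mean_feat.shape)
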
/src/representations/__init__.py:
--------------------------------------------------------------------------------
1 | from .event_stack import MixedDensityEventStack
2 | from .voxel_grid import VoxelGrid
3 |
4 | class EventRepresentationFactory:
5 | @staticmethod
6 | def create(representation_config):
7 | config = representation_config.copy()
8 | representation_name = config.pop('representation_name')
9 |
10 | if representation_name == 'event_stack':
11 | return MixedDensityEventStack(**config)
12 | elif representation_name == 'voxel_grid':
13 | return VoxelGrid(**config)
14 | else:
15 |             raise ValueError(f"Unsupported representation_name: {representation_name}")
--------------------------------------------------------------------------------
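
A short usage sketch of the factory (illustrative only; the config dicts below are hypothetical, and their keys simply mirror the constructor arguments of `MixedDensityEventStack` and `VoxelGrid`):

from src.representations import EventRepresentationFactory

# Hypothetical configs; keys mirror the constructor arguments of the two classes.
voxel_cfg = {'representation_name': 'voxel_grid', 'image_shape': (480, 640), 'num_bins': 5}
stack_cfg = {'representation_name': 'event_stack', 'image_shape': (480, 640), 'num_stacks': 10}

voxel_grid = EventRepresentationFactory.create(voxel_cfg)
event_stack = EventRepresentationFactory.create(stack_cfg)
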
/src/representations/event_stack.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import numpy as np
4 |
5 |
6 | class MixedDensityEventStack:
7 | """Create an mixed density event stack from events.
8 | Implementation inspired by https://github.com/yonseivnl/se-cff.
9 |
10 | Args:
11 | image_shape: (height, width)
12 | num_stacks: number of channels
13 | interpolation: interpolation method to use when building histograms
14 | """
15 | def __init__(self, image_shape: Tuple[int, int], num_stacks: int,
16 | interpolation: str = 'nearest_neighbor', channel_overlap=False,
17 | centered_channels=False) -> None:
18 | assert image_shape[0] > 0
19 | assert image_shape[1] > 0
20 | assert num_stacks > 0
21 | self.image_shape = image_shape
22 | self.num_stacks = num_stacks
23 | self.interpolation = interpolation
24 | self.channel_overlap = channel_overlap
25 | self.centered_channels = centered_channels
26 | assert self.interpolation in ['nearest_neighbor', 'bilinear']
27 | assert not self.centered_channels or (self.centered_channels and self.channel_overlap), "Centered channels require channel overlap"
28 |
29 | def __call__(self, events: np.ndarray, t_mid=None) -> np.ndarray:
30 | """Create mixed density event stack.
31 |
32 | Args:
33 |             events: a NumPy array of size [n x d], where n is the number of events and d = 4.
34 |                 Every event is encoded with 4 values (y, x, t, p).
35 |             t_mid: optional reference timestamp; required when `centered_channels` is True.
36 | Returns:
37 | A mixed density event stack representation of the event data.
38 | """
39 | assert events.shape[1] == 4
40 | assert not self.centered_channels or t_mid is not None, "Centered channels require t_mid"
41 | stacked_event_list = []
42 | curr_num_events = len(events)
43 |
44 | for _ in range(self.num_stacks):
45 | if self.interpolation == 'nearest_neighbor':
46 | stacked_event = self.stack_data_nearest_neighbor(events)
47 | elif self.interpolation == 'bilinear':
48 | stacked_event = self.stack_data_bilinear(events)
49 |
50 | stacked_event_list.append(stacked_event)
51 | curr_num_events = curr_num_events // 2
52 |
53 | if self.centered_channels:
54 | i_mid = np.searchsorted(events[:, 2], t_mid)
55 | i_start = max(i_mid - curr_num_events // 2, 0)
56 | i_end = min(i_mid + curr_num_events // 2, len(events))
57 | events = events[i_start:i_end]
58 | else:
59 | events = events[curr_num_events:]
60 |
61 | if not self.channel_overlap:
62 | for stack_idx in range(self.num_stacks - 1):
63 | prev_stack_polarity = stacked_event_list[stack_idx]
64 | next_stack_polarity = stacked_event_list[stack_idx + 1]
65 | diff_stack_polarity = prev_stack_polarity - next_stack_polarity
66 | stacked_event_list[stack_idx] = diff_stack_polarity
67 |
68 | return np.stack(stacked_event_list, axis=0)
69 |
70 | def stack_data_nearest_neighbor(self, events):
71 | height, width = self.image_shape
72 | y = events[:, 0].astype(int)
73 | x = events[:, 1].astype(int)
74 | p = events[:, 3]
75 | p[p == 0] = -1
76 |
77 | stacked_polarity = np.zeros([height, width], dtype=np.int8)
78 | index = (y * width) + x
79 | stacked_polarity.put(index, p)
80 |
81 | return stacked_polarity
82 |
83 | def stack_data_bilinear(self, events):
84 | if len(events.shape) == 2:
85 | events = events[None, ...] # 1 x n x 4
86 |
87 | weight = events[..., 3]
88 | weight[weight == 0] = -1
89 |
90 | h, w = self.image_shape
91 | nb = len(events)
92 | image = np.zeros((nb, h * w), dtype=np.float64)
93 |
94 | floor_xy = np.floor(events[..., :2] + 1e-8)
95 | floor_to_xy = events[..., :2] - floor_xy
96 |
97 | x1 = floor_xy[..., 1]
98 | y1 = floor_xy[..., 0]
99 | inds = np.concatenate(
100 | [
101 | x1 + y1 * w,
102 | x1 + (y1 + 1) * w,
103 | (x1 + 1) + y1 * w,
104 | (x1 + 1) + (y1 + 1) * w,
105 | ],
106 | axis=-1,
107 | )
108 | inds_mask = np.concatenate(
109 | [
110 | (0 <= x1) * (x1 < w) * (0 <= y1) * (y1 < h),
111 | (0 <= x1) * (x1 < w) * (0 <= y1 + 1) * (y1 + 1 < h),
112 | (0 <= x1 + 1) * (x1 + 1 < w) * (0 <= y1) * (y1 < h),
113 | (0 <= x1 + 1) * (x1 + 1 < w) * (0 <= y1 + 1) * (y1 + 1 < h),
114 | ],
115 | axis=-1,
116 | )
117 | w_pos0 = (1 - floor_to_xy[..., 0]) * (1 - floor_to_xy[..., 1]) * weight
118 | w_pos1 = floor_to_xy[..., 0] * (1 - floor_to_xy[..., 1]) * weight
119 | w_pos2 = (1 - floor_to_xy[..., 0]) * floor_to_xy[..., 1] * weight
120 | w_pos3 = floor_to_xy[..., 0] * floor_to_xy[..., 1] * weight
121 | vals = np.concatenate([w_pos0, w_pos1, w_pos2, w_pos3], axis=-1)
122 | inds = (inds * inds_mask).astype(np.int64)
123 | vals = vals * inds_mask
124 | for i in range(nb):
125 | np.add.at(image[i], inds[i], vals[i])
126 | return image.reshape((nb,) + self.image_shape).squeeze()
127 |
--------------------------------------------------------------------------------
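
A minimal example of building a mixed-density stack from synthetic events (illustrative only; the (y, x, t, p) column layout follows the docstring above, and timestamps are sorted because each later channel keeps the more recent events):

import numpy as np

from src.representations.event_stack import MixedDensityEventStack

rng = np.random.default_rng(0)
h, w, n = 64, 96, 5000
events = np.stack([
    rng.uniform(0, h, n),             # y
    rng.uniform(0, w, n),             # x
    np.sort(rng.uniform(0, 1.0, n)),  # t, sorted in time
    rng.integers(0, 2, n),            # p in {0, 1}
], axis=1)

stack = MixedDensityEventStack((h, w), num_stacks=4, interpolation='bilinear')
rep = stack(events)
print(rep.shape)  # (4, 64, 96); each later channel is built from roughly half as many (more recent) events

With the default channel_overlap=False, consecutive channels are then differenced, so, roughly speaking, each channel except the last keeps only the events that are not already contained in the next, sparser channel.
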
/src/representations/voxel_grid.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import numpy as np
4 |
5 |
6 | class VoxelGrid:
7 | """Create a voxel grid from events.
8 | Implementation inspired by https://github.com/uzh-rpg/rpg_e2vid.
9 |
10 | Args:
11 | image_shape: (height, width)
12 | num_bins: number of bins in the temporal axis of the voxel grid
13 | """
14 | def __init__(self, image_shape: Tuple[int, int], num_bins: int) -> None:
15 | assert image_shape[0] > 0
16 | assert image_shape[1] > 0
17 | assert num_bins > 0
18 | self.image_shape = image_shape
19 | self.num_bins = num_bins
20 |
21 | def __call__(self, events: np.ndarray) -> np.ndarray:
22 | """Create voxel grid.
23 |
24 | Args:
25 |             events: a NumPy array of size [n x d], where n is the number of events and d = 4.
26 | Every event is encoded with 4 values (y, x, t, p).
27 |
28 | Returns:
29 | A voxelized representation of the event data.
30 | """
31 | assert events.shape[1] == 4
32 | height, width = self.image_shape
33 | voxel_grid = np.zeros((self.num_bins, height, width), np.float32).ravel()
34 | y = events[:, 0].astype(int)
35 | x = events[:, 1].astype(int)
36 | t = events[:, 2]
37 | p = events[:, 3]
38 | p[p == 0] = -1
39 |
40 | # Normalize the event timestamps so that they lie between 0 and num_bins
41 | t_min, t_max = np.amin(t), np.amax(t)
42 | delta_t = t_max - t_min
43 |
44 | if delta_t == 0:
45 | delta_t = 1
46 |
47 | t = (self.num_bins - 1) * (t - t_min) / delta_t
48 | tis = t.astype(int)
49 | dts = t - tis
50 | vals_left = p * (1.0 - dts)
51 | vals_right = p * dts
52 | valid_indices = tis < self.num_bins
53 |
54 | np.add.at(
55 | voxel_grid,
56 | x[valid_indices] + y[valid_indices] * width + tis[valid_indices] * width * height,
57 | vals_left[valid_indices],
58 | )
59 |
60 | valid_indices = (tis + 1) < self.num_bins
61 | np.add.at(
62 | voxel_grid,
63 | x[valid_indices] + y[valid_indices] * width + (tis[valid_indices] + 1) * width * height,
64 | vals_right[valid_indices],
65 | )
66 |
67 | voxel_grid = np.reshape(voxel_grid, (self.num_bins, height, width))
68 |
69 | return voxel_grid
70 |
--------------------------------------------------------------------------------
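
An analogous sketch for the voxel grid (illustrative only; as implemented above, each event's polarity is bilinearly split between its two nearest temporal bins):

import numpy as np

from src.representations.voxel_grid import VoxelGrid

rng = np.random.default_rng(0)
h, w, n = 64, 96, 5000
events = np.stack([
    rng.uniform(0, h, n),               # y
    rng.uniform(0, w, n),               # x
    np.sort(rng.uniform(0, 0.01, n)),   # t (absolute timestamps; only their relative position matters)
    rng.integers(0, 2, n),              # p in {0, 1}, remapped to {-1, +1} internally
], axis=1)

grid = VoxelGrid((h, w), num_bins=5)(events)
print(grid.shape)  # (5, 64, 96)
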
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .supported_seqs_feature_tracking import SUPPORTED_SEQUENCES_FEATURE_TRACKING
2 | from .visualizer import Visualizer, normalize_and_expand_channels
3 | from .track_utils import compute_tracking_errors, read_txt_results
4 | from .metrics import compute_tapvid_metrics
5 | from .misc import make_grid
--------------------------------------------------------------------------------
/src/utils/metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | from typing import Mapping, List
16 |
17 | import numpy as np
18 |
19 |
20 | def compute_tapvid_metrics(
21 | query_points: np.ndarray,
22 | gt_occluded: np.ndarray,
23 | gt_tracks: np.ndarray,
24 | pred_occluded: np.ndarray,
25 | pred_tracks: np.ndarray,
26 | query_mode: str,
27 | get_trackwise_metrics: bool = False,
28 | thresholds: List[int] = [1, 2, 4, 8, 16],
29 | ) -> Mapping[str, np.ndarray]:
30 | """Computes TAP-Vid metrics (Jaccard, Pts.
31 |
32 | Within Thresh, Occ.
33 |
34 | Acc.)
35 |
36 | See the TAP-Vid paper for details on the metric computation. All inputs are
37 | given in raster coordinates. The first three arguments should be the direct
38 | outputs of the reader: the 'query_points', 'occluded', and 'target_points'.
39 | The paper metrics assume these are scaled relative to 256x256 images.
40 | pred_occluded and pred_tracks are your algorithm's predictions.
41 |
42 | This function takes a batch of inputs, and computes metrics separately for
43 | each video. The metrics for the full benchmark are a simple mean of the
44 | metrics across the full set of videos. These numbers are between 0 and 1,
45 | but the paper multiplies them by 100 to ease reading.
46 |
47 | Args:
48 |     query_points: The query points, in the format [t, y, x]. Its size is
49 |       [b, n, 3], where b is the batch size and n is the number of queries.
50 | gt_occluded: A boolean array of shape [b, n, t], where t is the number of
51 | frames. True indicates that the point is occluded.
52 | gt_tracks: The target points, of shape [b, n, t, 2]. Each point is in the
53 | format [x, y]
54 | pred_occluded: A boolean array of predicted occlusions, in the same format
55 | as gt_occluded.
56 | pred_tracks: An array of track predictions from your algorithm, in the same
57 | format as gt_tracks.
58 | query_mode: Either 'first' or 'strided', depending on how queries are
59 | sampled. If 'first', we assume the prior knowledge that all points
60 | before the query point are occluded, and these are removed from the
61 | evaluation.
62 | get_trackwise_metrics: if True, the metrics will be computed for every
63 | track (rather than every video, which is the default). This means
64 | every output tensor will have an extra axis [batch, num_tracks] rather
65 | than simply (batch).
66 |
67 | Returns:
68 | A dict with the following keys:
69 |
70 | occlusion_accuracy: Accuracy at predicting occlusion.
71 | pts_within_{x} for x in [1, 2, 4, 8, 16]: Fraction of points
72 | predicted to be within the given pixel threshold, ignoring occlusion
73 | prediction.
74 | jaccard_{x} for x in [1, 2, 4, 8, 16]: Jaccard metric for the given
75 | threshold
76 | average_pts_within_thresh: average across pts_within_{x}
77 | average_jaccard: average across jaccard_{x}
78 | """
79 |
80 | summing_axis = (2,) if get_trackwise_metrics else (1, 2)
81 |
82 | metrics = {}
83 |
84 | eye = np.eye(gt_tracks.shape[2], dtype=np.int32)
85 | if query_mode == 'first':
86 | # evaluate frames after the query frame
87 | query_frame_to_eval_frames = np.cumsum(eye, axis=1) - eye
88 | elif query_mode == 'strided':
89 | # evaluate all frames except the query frame
90 | query_frame_to_eval_frames = 1 - eye
91 | else:
92 | raise ValueError('Unknown query mode ' + query_mode)
93 |
94 | query_frame = query_points[..., 0]
95 | query_frame = np.round(query_frame).astype(np.int32)
96 | evaluation_points = query_frame_to_eval_frames[query_frame] > 0
97 |
98 | # Occlusion accuracy is simply how often the predicted occlusion equals the
99 | # ground truth.
100 | occ_acc = np.sum(
101 | np.equal(pred_occluded, gt_occluded) & evaluation_points,
102 | axis=summing_axis,
103 | ) / np.sum(evaluation_points, axis=summing_axis)
104 | metrics['occlusion_accuracy'] = occ_acc
105 |
106 | # Next, convert the predictions and ground truth positions into pixel
107 | # coordinates.
108 | visible = np.logical_not(gt_occluded)
109 | pred_visible = np.logical_not(pred_occluded)
110 | all_frac_within = []
111 | all_jaccard = []
112 | for thresh in thresholds:
113 | # True positives are points that are within the threshold and where both
114 | # the prediction and the ground truth are listed as visible.
115 | within_dist = np.sum(
116 | np.square(pred_tracks - gt_tracks),
117 | axis=-1,
118 | ) < np.square(thresh)
119 | is_correct = np.logical_and(within_dist, visible)
120 |
121 | # Compute the frac_within_threshold, which is the fraction of points
122 | # within the threshold among points that are visible in the ground truth,
123 | # ignoring whether they're predicted to be visible.
124 | count_correct = np.sum(
125 | is_correct & evaluation_points,
126 | axis=summing_axis,
127 | )
128 | count_visible_points = np.sum(
129 | visible & evaluation_points, axis=summing_axis
130 | )
131 | frac_correct = count_correct / count_visible_points
132 | metrics['pts_within_' + str(thresh)] = frac_correct
133 | all_frac_within.append(frac_correct)
134 |
135 | true_positives = np.sum(
136 | is_correct & pred_visible & evaluation_points, axis=summing_axis
137 | )
138 |
139 | # The denominator of the jaccard metric is the true positives plus
140 | # false positives plus false negatives. However, note that true positives
141 | # plus false negatives is simply the number of points in the ground truth
142 | # which is easier to compute than trying to compute all three quantities.
143 | # Thus we just add the number of points in the ground truth to the number
144 | # of false positives.
145 | #
146 | # False positives are simply points that are predicted to be visible,
147 | # but the ground truth is not visible or too far from the prediction.
148 | gt_positives = np.sum(visible & evaluation_points, axis=summing_axis)
149 | false_positives = (~visible) & pred_visible
150 | false_positives = false_positives | ((~within_dist) & pred_visible)
151 | false_positives = np.sum(
152 | false_positives & evaluation_points, axis=summing_axis
153 | )
154 | jaccard = true_positives / (gt_positives + false_positives)
155 | metrics['jaccard_' + str(thresh)] = jaccard
156 | all_jaccard.append(jaccard)
157 | metrics['average_jaccard'] = np.mean(
158 | np.stack(all_jaccard, axis=1),
159 | axis=1,
160 | )
161 | metrics['average_pts_within_thresh'] = np.mean(
162 | np.stack(all_frac_within, axis=1),
163 | axis=1,
164 | )
165 | return metrics
166 |
167 |
168 | def latex_table(mean_scalars: Mapping[str, float]) -> str:
169 | """Generate a latex table for displaying TAP-Vid and PCK metrics."""
170 | if 'average_jaccard' in mean_scalars:
171 | latex_fields = [
172 | 'average_jaccard',
173 | 'average_pts_within_thresh',
174 | 'occlusion_accuracy',
175 | 'jaccard_1',
176 | 'jaccard_2',
177 | 'jaccard_4',
178 | 'jaccard_8',
179 | 'jaccard_16',
180 | 'pts_within_1',
181 | 'pts_within_2',
182 | 'pts_within_4',
183 | 'pts_within_8',
184 | 'pts_within_16',
185 | ]
186 | header = (
187 | 'AJ & $<\\delta^{x}_{avg}$ & OA & Jac. $\\delta^{0}$ & '
188 | + 'Jac. $\\delta^{1}$ & Jac. $\\delta^{2}$ & '
189 | + 'Jac. $\\delta^{3}$ & Jac. $\\delta^{4}$ & $<\\delta^{0}$ & '
190 | + '$<\\delta^{1}$ & $<\\delta^{2}$ & $<\\delta^{3}$ & '
191 | + '$<\\delta^{4}$'
192 | )
193 | else:
194 | latex_fields = ['PCK@0.1', 'PCK@0.2', 'PCK@0.3', 'PCK@0.4', 'PCK@0.5']
195 | header = ' & '.join(latex_fields)
196 |
197 | body = ' & '.join(
198 | [f'{float(np.array(mean_scalars[x]*100)):.3}' for x in latex_fields]
199 | )
200 | return '\n'.join([header, body])
201 |
202 |
203 | def sample_queries_strided(
204 | target_occluded: np.ndarray,
205 | target_points: np.ndarray,
206 | frames: np.ndarray,
207 | query_stride: int = 5,
208 | ) -> Mapping[str, np.ndarray]:
209 | """Package a set of frames and tracks for use in TAPNet evaluations.
210 |
211 | Given a set of frames and tracks with no query points, sample queries
212 | strided every query_stride frames, ignoring points that are not visible
213 | at the selected frames.
214 |
215 | Args:
216 | target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
217 | where True indicates occluded.
218 | target_points: Position, of shape [n_tracks, n_frames, 2], where each point
219 | is [x,y] scaled between 0 and 1.
220 | frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
221 | -1 and 1.
222 | query_stride: When sampling query points, search for un-occluded points
223 | every query_stride frames and convert each one into a query.
224 |
225 | Returns:
226 | A dict with the keys:
227 | video: Video tensor of shape [1, n_frames, height, width, 3]. The video
228 | has floats scaled to the range [-1, 1].
229 | query_points: Query points of shape [1, n_queries, 3] where
230 | each point is [t, y, x] scaled to the range [-1, 1].
231 | target_points: Target points of shape [1, n_queries, n_frames, 2] where
232 | each point is [x, y] scaled to the range [-1, 1].
233 | trackgroup: Index of the original track that each query point was
234 | sampled from. This is useful for visualization.
235 | """
236 | tracks = []
237 | occs = []
238 | queries = []
239 | trackgroups = []
240 | total = 0
241 | trackgroup = np.arange(target_occluded.shape[0])
242 | for i in range(0, target_occluded.shape[1], query_stride):
243 | mask = target_occluded[:, i] == 0
244 | query = np.stack(
245 | [
246 | i * np.ones(target_occluded.shape[0:1]),
247 | target_points[:, i, 1],
248 | target_points[:, i, 0],
249 | ],
250 | axis=-1,
251 | )
252 | queries.append(query[mask])
253 | tracks.append(target_points[mask])
254 | occs.append(target_occluded[mask])
255 | trackgroups.append(trackgroup[mask])
256 | total += np.array(np.sum(target_occluded[:, i] == 0))
257 |
258 | return {
259 | 'video': frames[np.newaxis, ...],
260 | 'query_points': np.concatenate(queries, axis=0)[np.newaxis, ...],
261 | 'target_points': np.concatenate(tracks, axis=0)[np.newaxis, ...],
262 | 'occluded': np.concatenate(occs, axis=0)[np.newaxis, ...],
263 | 'trackgroup': np.concatenate(trackgroups, axis=0)[np.newaxis, ...],
264 | }
265 |
266 |
267 | def sample_queries_first(
268 | target_occluded: np.ndarray,
269 | target_points: np.ndarray,
270 | frames: np.ndarray,
271 | ) -> Mapping[str, np.ndarray]:
272 | """Package a set of frames and tracks for use in TAPNet evaluations.
273 |
274 | Given a set of frames and tracks with no query points, use the first
275 | visible point in each track as the query.
276 |
277 | Args:
278 | target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
279 | where True indicates occluded.
280 | target_points: Position, of shape [n_tracks, n_frames, 2], where each point
281 | is [x,y] scaled between 0 and 1.
282 | frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
283 | -1 and 1.
284 |
285 | Returns:
286 | A dict with the keys:
287 | video: Video tensor of shape [1, n_frames, height, width, 3]
288 | query_points: Query points of shape [1, n_queries, 3] where
289 | each point is [t, y, x] scaled to the range [-1, 1]
290 | target_points: Target points of shape [1, n_queries, n_frames, 2] where
291 | each point is [x, y] scaled to the range [-1, 1]
292 | """
293 |
294 | valid = np.sum(~target_occluded, axis=1) > 0
295 | target_points = target_points[valid, :]
296 | target_occluded = target_occluded[valid, :]
297 |
298 | query_points = []
299 | for i in range(target_points.shape[0]):
300 | index = np.where(target_occluded[i] == 0)[0][0]
301 | x, y = target_points[i, index, 0], target_points[i, index, 1]
302 | query_points.append(np.array([index, y, x])) # [t, y, x]
303 | query_points = np.stack(query_points, axis=0)
304 |
305 | return {
306 | 'video': frames[np.newaxis, ...],
307 | 'query_points': query_points[np.newaxis, ...],
308 | 'target_points': target_points[np.newaxis, ...],
309 | 'occluded': target_occluded[np.newaxis, ...],
310 | }
311 |
--------------------------------------------------------------------------------
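
For reference, a toy invocation of compute_tapvid_metrics on synthetic data (not from any benchmark; the repository root is assumed to be on PYTHONPATH). The prediction is offset by 3 px in both x and y, i.e. about 4.24 px, so it passes only the 8- and 16-pixel thresholds:

import numpy as np

from src.utils.metrics import compute_tapvid_metrics

b, n, t = 1, 3, 10
rng = np.random.default_rng(0)
gt_tracks = rng.uniform(0, 256, size=(b, n, t, 2))    # [x, y] in 256x256 raster coordinates
gt_occluded = np.zeros((b, n, t), dtype=bool)          # all points visible
query_points = np.stack(
    [np.zeros((b, n)), gt_tracks[:, :, 0, 1], gt_tracks[:, :, 0, 0]], axis=-1
)  # [t, y, x] queries taken from frame 0

pred_tracks = gt_tracks + 3.0                          # constant 3 px offset in x and y
pred_occluded = np.zeros_like(gt_occluded)             # never predicts occlusion

metrics = compute_tapvid_metrics(
    query_points, gt_occluded, gt_tracks, pred_occluded, pred_tracks, query_mode='first'
)
print(metrics['occlusion_accuracy'])         # -> [1.]
print(metrics['average_pts_within_thresh'])  # -> [0.4] (within 8 and 16 px, not 1/2/4)
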
/src/utils/misc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def make_grid(height, width, stride=1):
4 | """
5 | Create a grid of points with the given stride.
6 |
7 | Args:
8 | height (int): Grid height
9 | width (int): Grid width
10 | stride (int): Spacing between points
11 |
12 | Returns:
13 |         np.ndarray: Grid points of shape (ceil(height/stride) * ceil(width/stride), 2), each row an (x, y) pair
14 | """
15 | x = np.arange(0, width, stride)
16 | y = np.arange(0, height, stride)
17 | X, Y = np.meshgrid(x, y)
18 | return np.stack([X.flatten(), Y.flatten()], axis=1)
--------------------------------------------------------------------------------
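
Usage is straightforward (illustrative values):

from src.utils.misc import make_grid

pts = make_grid(480, 640, stride=16)
print(pts.shape)  # (1200, 2); each row is an (x, y) coordinate
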
/src/utils/supported_seqs_feature_tracking.py:
--------------------------------------------------------------------------------
1 | SUPPORTED_SEQUENCES_FEATURE_TRACKING = {
2 | 'eds': [
3 | '01_peanuts_light',
4 | '02_rocket_earth_light',
5 | '08_peanuts_running',
6 | '14_ziggy_in_the_arena',
7 |
8 | ],
9 | 'ec': [
10 | 'shapes_rotation',
11 | 'boxes_rotation',
12 | 'boxes_translation',
13 | 'shapes_6dof',
14 | 'shapes_translation'
15 | ]
16 | }
--------------------------------------------------------------------------------
/src/utils/visualizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # Modified by F. Hamann (TU Berlin)
8 | # Copyright (c) 2025 F. Hamann (TU Berlin)
9 | import os
10 | import numpy as np
11 | import imageio
12 | import torch
13 |
14 | from matplotlib import cm
15 | import torch.nn.functional as F
16 | import torchvision.transforms as transforms
17 | import matplotlib.pyplot as plt
18 | from PIL import Image, ImageDraw
19 |
20 |
21 | def normalize_and_expand_channels(image):
22 | if not isinstance(image, torch.Tensor):
23 | raise TypeError("Input must be a torch tensor")
24 |
25 | *batch_dims, height, width = image.shape
26 |
27 | if len(image.shape) < 2:
28 | raise ValueError("Input tensor must have at least shape (..., height, width)")
29 |
30 | image_flat = image.view(-1, height, width)
31 |
32 | min_val = image_flat.min()
33 | max_val = image_flat.max()
34 | range_val = max_val - min_val
35 | range_val[range_val == 0] = 1
36 |
37 | image_normalized = (image_flat - min_val) / range_val * 255
38 | image_rgb_flat = image_normalized.to(torch.uint8).unsqueeze(1).repeat(1, 3, 1, 1)
39 | image_rgb = image_rgb_flat.view(*batch_dims, 3, height, width)
40 |
41 | return image_rgb
42 |
43 |
44 | def read_video_from_path(path):
45 | try:
46 | reader = imageio.get_reader(path)
47 | except Exception as e:
48 | print("Error opening video file: ", e)
49 | return None
50 | frames = []
51 | for i, im in enumerate(reader):
52 | frames.append(np.array(im))
53 | return np.stack(frames)
54 |
55 |
56 | def draw_circle(rgb, coord, radius, color=(255, 0, 0), visible=True):
57 | # Create a draw object
58 | draw = ImageDraw.Draw(rgb)
59 | # Calculate the bounding box of the circle
60 | left_up_point = (coord[0] - radius, coord[1] - radius)
61 | right_down_point = (coord[0] + radius, coord[1] + radius)
62 | # Draw the circle
63 | draw.ellipse(
64 | [left_up_point, right_down_point],
65 | fill=tuple(color) if visible else None,
66 | outline=tuple(color),
67 | )
68 | return rgb
69 |
70 |
71 | def draw_line(rgb, coord_y, coord_x, color, linewidth):
72 | draw = ImageDraw.Draw(rgb, 'RGBA') # Use RGBA mode for better antialiasing
73 | draw.line(
74 | (coord_y[0], coord_y[1], coord_x[0], coord_x[1]),
75 | fill=color,
76 | width=linewidth,
77 | joint="round" # Add round joints to reduce artifacts at line intersections
78 | )
79 | return rgb
80 |
81 |
82 | def add_weighted(rgb, alpha, original, beta, gamma):
83 | return (rgb * alpha + original * beta + gamma).astype("uint8")
84 |
85 |
86 | class Visualizer:
87 | def __init__(
88 | self,
89 | save_dir: str = "./results",
90 | grayscale: bool = False,
91 | pad_value: int = 0,
92 | fps: int = 10,
93 | mode: str = "rainbow", # 'cool', 'optical_flow'
94 | linewidth: int = 2,
95 | show_first_frame: int = 10,
96 | tracks_leave_trace: int = 0, # -1 for infinite
97 | consistent_colors: bool = False,
98 | ):
99 | self.mode = mode
100 | self.save_dir = save_dir
101 | if mode == "rainbow":
102 | self.color_map = cm.get_cmap("gist_rainbow")
103 | elif mode == "cool":
104 | self.color_map = cm.get_cmap(mode)
105 | self.show_first_frame = show_first_frame
106 | self.grayscale = grayscale
107 | self.tracks_leave_trace = tracks_leave_trace
108 | self.pad_value = pad_value
109 | self.linewidth = linewidth
110 | self.fps = fps
111 | self.consistent_colors = consistent_colors
112 | self.color = None
113 |
114 | def visualize(
115 | self,
116 | video: torch.Tensor, # (B,T,C,H,W)
117 | tracks: torch.Tensor, # (B,T,N,2)
118 | visibility: torch.Tensor = None, # (B, T, N, 1) bool
119 | gt_tracks: torch.Tensor = None, # (B,T,N,2)
120 | segm_mask: torch.Tensor = None, # (B,1,H,W)
121 | filename: str = "video",
122 | writer=None, # tensorboard Summary Writer, used for visualization during training
123 | step: int = 0,
124 | query_frame: int = 0,
125 | save_video: bool = True,
126 | compensate_for_camera_motion: bool = False,
127 | mask_segm_tracks: bool = False,
128 | ):
129 | if compensate_for_camera_motion:
130 | assert segm_mask is not None
131 | if segm_mask is not None:
132 | coords = tracks[0, query_frame].round().long()
133 | segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()
134 |
135 | video = F.pad(
136 | video,
137 | (self.pad_value, self.pad_value, self.pad_value, self.pad_value),
138 | "constant",
139 | 255,
140 | )
141 | tracks = tracks + self.pad_value
142 |
143 | if self.grayscale:
144 | transform = transforms.Grayscale()
145 | video = transform(video)
146 | video = video.repeat(1, 1, 3, 1, 1)
147 |
148 | if mask_segm_tracks:
149 | tracks = tracks[:, :, segm_mask != 0]
150 | visibility = visibility[:, :, segm_mask != 0] if visibility is not None else None
151 | segm_mask = None
152 |
153 | res_video = self.draw_tracks_on_video(
154 | video=video,
155 | tracks=tracks,
156 | visibility=visibility,
157 | segm_mask=segm_mask,
158 | gt_tracks=gt_tracks,
159 | query_frame=query_frame,
160 | compensate_for_camera_motion=compensate_for_camera_motion,
161 | )
162 | if save_video:
163 | self.save_video(res_video, filename=filename, writer=writer, step=step)
164 | return res_video
165 |
166 | def save_video(self, video, filename, writer=None, step=0):
167 | if writer is not None:
168 | writer.add_video(
169 | filename,
170 | video.to(torch.uint8),
171 | global_step=step,
172 | fps=self.fps,
173 | )
174 | else:
175 | os.makedirs(self.save_dir, exist_ok=True)
176 | wide_list = list(video.unbind(1))
177 | wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]
178 |
179 | # Prepare the video file path
180 | save_path = os.path.join(self.save_dir, f"{filename}.mp4")
181 |
182 | # Create a writer object
183 | video_writer = imageio.get_writer(save_path, fps=self.fps, macro_block_size=1)
184 |
185 | # Write frames to the video file
186 | for frame in wide_list[2:-1]:
187 | video_writer.append_data(frame)
188 |
189 | video_writer.close()
190 |
191 | #print(f"Video saved to {save_path}")
192 |
193 | def draw_tracks_on_video(
194 | self,
195 | video: torch.Tensor,
196 | tracks: torch.Tensor,
197 | visibility: torch.Tensor = None,
198 | segm_mask: torch.Tensor = None,
199 | gt_tracks=None,
200 | query_frame: int = 0,
201 | compensate_for_camera_motion=False,
202 | ):
203 | B, T, C, H, W = video.shape
204 | _, _, N, D = tracks.shape
205 |
206 | assert D == 2
207 | assert C == 3
208 | video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy() # S, H, W, C
209 | valid_mask = ~torch.all(torch.isnan(tracks)[0], dim=-1).cpu().numpy()
210 | tracks = tracks[0].long().detach().cpu().numpy() # S, N, 2
211 | if gt_tracks is not None:
212 | gt_tracks = gt_tracks[0].detach().cpu().numpy()
213 |
214 | res_video = []
215 |
216 | # process input video
217 | for rgb in video:
218 | res_video.append(rgb.copy())
219 | vector_colors = np.zeros((T, N, 3))
220 |
221 | if self.mode == "optical_flow":
222 | import flow_vis
223 |
224 | vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])
225 | elif segm_mask is None:
226 | if self.mode == "rainbow":
227 | if self.color is None or not self.consistent_colors:
228 | y_min, y_max = (
229 | tracks[query_frame, :, 1].min(),
230 | tracks[query_frame, :, 1].max(),
231 | )
232 | norm = plt.Normalize(y_min, y_max)
233 | for n in range(N):
234 | color = self.color_map(norm(tracks[query_frame, n, 1]))
235 | color = np.array(color[:3])[None] * 255
236 | vector_colors[:, n] = np.repeat(color, T, axis=0)
237 |
238 | if self.color is None:
239 | self.color = vector_colors
240 | else:
241 | vector_colors = self.color
242 | else:
243 | # color changes with time
244 | for t in range(T):
245 | color = np.array(self.color_map(t / T)[:3])[None] * 255
246 | vector_colors[t] = np.repeat(color, N, axis=0)
247 | else:
248 | if self.mode == "rainbow":
249 | vector_colors[:, segm_mask <= 0, :] = 255
250 |
251 | y_min, y_max = (
252 | tracks[0, segm_mask > 0, 1].min(),
253 | tracks[0, segm_mask > 0, 1].max(),
254 | )
255 | norm = plt.Normalize(y_min, y_max)
256 | for n in range(N):
257 | if segm_mask[n] > 0:
258 | color = self.color_map(norm(tracks[0, n, 1]))
259 | color = np.array(color[:3])[None] * 255
260 | vector_colors[:, n] = np.repeat(color, T, axis=0)
261 |
262 | else:
263 | # color changes with segm class
264 | segm_mask = segm_mask.cpu()
265 | color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
266 | color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
267 | color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
268 | vector_colors = np.repeat(color[None], T, axis=0)
269 |
270 | # draw tracks
271 | if self.tracks_leave_trace != 0:
272 | for t in range(query_frame + 1, T):
273 | first_ind = (
274 | max(0, t - self.tracks_leave_trace) if self.tracks_leave_trace >= 0 else 0
275 | )
276 | curr_tracks = tracks[first_ind : t + 1]
277 | curr_colors = vector_colors[first_ind : t + 1]
278 | if compensate_for_camera_motion:
279 | diff = (
280 | tracks[first_ind : t + 1, segm_mask <= 0]
281 | - tracks[t : t + 1, segm_mask <= 0]
282 | ).mean(1)[:, None]
283 |
284 | curr_tracks = curr_tracks - diff
285 | curr_tracks = curr_tracks[:, segm_mask > 0]
286 | curr_colors = curr_colors[:, segm_mask > 0]
287 |
288 | res_video[t] = self._draw_pred_tracks(
289 | res_video[t],
290 | curr_tracks,
291 | curr_colors,
292 | )
293 | if gt_tracks is not None:
294 | res_video[t] = self._draw_gt_tracks(res_video[t], gt_tracks[first_ind : t + 1])
295 |
296 | # draw points
297 | for t in range(query_frame, T):
298 | img = Image.fromarray(np.uint8(res_video[t]))
299 | for i in range(N):
300 | coord = (tracks[t, i, 0], tracks[t, i, 1])
302 |                 visible = True
302 | if visibility is not None:
304 |                     visible = visibility[0, t, i]
304 | if coord[0] != 0 and coord[1] != 0:
305 | if not compensate_for_camera_motion or (
306 | compensate_for_camera_motion and segm_mask[i] > 0
307 | ):
308 | if valid_mask[t, i]:
309 | img = draw_circle(
310 | img,
311 | coord=coord,
312 | radius=int(self.linewidth * 2),
313 | color=vector_colors[t, i].astype(int),
314 |                             visible=visible,
315 | )
316 | res_video[t] = np.array(img)
317 |
318 | # construct the final rgb sequence
319 | if self.show_first_frame > 0:
320 | res_video = [res_video[0]] * self.show_first_frame + res_video[1:]
321 | return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()
322 |
323 | def _draw_pred_tracks(
324 | self,
325 | rgb: np.ndarray, # H x W x 3
326 |         tracks: np.ndarray,  # T x N x 2
327 | vector_colors: np.ndarray,
328 | base_alpha: float = 0.5,
329 | ):
330 | T, N, _ = tracks.shape
331 | # Convert to PIL Image once at the start
332 | rgb = Image.fromarray(np.uint8(rgb))
333 |
334 | # Create a single composite image for all tracks
335 | for s in range(T - 1):
336 | vector_color = vector_colors[s]
337 | time_alpha = (s / T) ** 2 if self.tracks_leave_trace > 0 else 1.0
338 |
339 | for i in range(N):
340 | coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
341 | coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
342 |
343 | if coord_y[0] != 0 and coord_y[1] != 0 and coord_x[0] != 0 and coord_x[1] != 0:
344 | # Only draw valid coordinates
345 | color = (*vector_color[i].astype(int),) # Convert to tuple
346 | rgb = draw_line(
347 | rgb,
348 | coord_y,
349 | coord_x,
350 | color,
351 | self.linewidth,
352 | )
353 |
354 | return np.array(rgb)
355 |
356 | def _draw_gt_tracks(
357 | self,
358 | rgb: np.ndarray, # H x W x 3,
359 |         gt_tracks: np.ndarray,  # T x N x 2
360 | ):
361 | T, N, _ = gt_tracks.shape
362 | color = np.array((211, 0, 0))
363 | rgb = Image.fromarray(np.uint8(rgb))
364 | for t in range(T):
365 | for i in range(N):
366 |                 gt_track = gt_tracks[t][i]
367 |                 # draw a red cross
368 |                 if gt_track[0] > 0 and gt_track[1] > 0:
369 |                     length = self.linewidth * 3
370 |                     coord_y = (int(gt_track[0]) + length, int(gt_track[1]) + length)
371 |                     coord_x = (int(gt_track[0]) - length, int(gt_track[1]) - length)
372 |                     rgb = draw_line(
373 |                         rgb,
374 |                         coord_y,
375 |                         coord_x,
376 |                         color,
377 |                         self.linewidth,
378 |                     )
379 |                     coord_y = (int(gt_track[0]) - length, int(gt_track[1]) + length)
380 |                     coord_x = (int(gt_track[0]) + length, int(gt_track[1]) - length)
381 | rgb = draw_line(
382 | rgb,
383 | coord_y,
384 | coord_x,
385 | color,
386 | self.linewidth,
387 | )
388 | rgb = np.array(rgb)
389 | return rgb
390 |
--------------------------------------------------------------------------------
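
Finally, a minimal sketch of driving the Visualizer with random data (illustrative shapes only; writing the .mp4 requires imageio's ffmpeg backend, and the repository root is assumed to be on PYTHONPATH):

import torch

from src.utils.visualizer import Visualizer

B, T, N, H, W = 1, 12, 16, 128, 160
video = torch.randint(0, 256, (B, T, 3, H, W)).float()            # (B, T, C, H, W)
tracks = torch.rand(B, T, N, 2) * torch.tensor([W - 1., H - 1.])  # (B, T, N, 2) in (x, y)
visibility = torch.ones(B, T, N, dtype=torch.bool)                # (B, T, N)

vis = Visualizer(save_dir='./results', fps=10, linewidth=1)
vis.visualize(video, tracks, visibility=visibility, filename='demo_tracks')
# -> ./results/demo_tracks.mp4
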