├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── config ├── exe │ ├── inference_offline │ │ ├── event_kubric.yaml │ │ └── evimo2.yaml │ ├── inference_online │ │ ├── e2d2.yaml │ │ ├── feature_tracking.yaml │ │ └── penn_aviary.yaml │ └── prepare_event_representations │ │ ├── e2d2.yaml │ │ ├── ec.yaml │ │ ├── eds.yaml │ │ ├── event_kubric.yaml │ │ └── evimo2.yaml └── misc │ ├── ec │ └── gt_tracks │ │ ├── boxes_rotation.gt.txt │ │ ├── boxes_translation.gt.txt │ │ ├── shapes_6dof.gt.txt │ │ ├── shapes_rotation.gt.txt │ │ └── shapes_translation.gt.txt │ ├── eds │ ├── calib.yaml │ └── gt_tracks │ │ ├── 01_peanuts_light.gt.txt │ │ ├── 02_rocket_earth_light.gt.txt │ │ ├── 08_peanuts_running.gt.txt │ │ └── 14_ziggy_in_the_arena.gt.txt │ └── evimo2 │ └── val_samples.csv ├── data_pipeline ├── README.md ├── annotate.py ├── annotator.py ├── compress.py ├── convert.py ├── converter.py ├── decompress.py └── sample.py ├── docs ├── flowchart.png ├── pred_e2d2_fidget.gif ├── pred_eds_peanuts_running.gif ├── pred_event_kubric.gif ├── pred_evimo2.gif └── thumbnail.png ├── requirements.txt ├── scripts ├── benchmark_feature_tracking.py ├── benchmark_tap.py ├── create_e2d2_fidget_spinner_gt.py ├── create_evimo2_track_gt.py ├── demo.py ├── download_eds.sh ├── inference_offline.py ├── inference_online.py └── prepare_event_representations.py └── src ├── __init__.py ├── data ├── __init__.py ├── modules │ ├── __init__.py │ ├── e2d2.py │ ├── event_kubric.py │ ├── evimo2.py │ ├── feature_tracking_online.py │ └── penn_aviary.py └── utils │ ├── __init__.py │ └── collate.py ├── model └── etap │ ├── core │ ├── __init__.py │ ├── cotracker │ │ ├── __init__.py │ │ └── blocks.py │ ├── embeddings.py │ └── model_utils.py │ └── model.py ├── representations ├── __init__.py ├── event_stack.py └── voxel_grid.py └── utils ├── __init__.py ├── metrics.py ├── misc.py ├── supported_seqs_feature_tracking.py ├── track_utils.py └── visualizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | /data 2 | output/ 3 | weights/ 4 | launch.json 5 | .vscode 6 | __pycache__ 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | cover/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | .pybuilder/ 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | # For a library or package, you might want to ignore these files since the code is 94 | # intended to run in multiple environments; otherwise, check them in: 95 | # .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # UV 105 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 106 | # This is especially recommended for binary packages to ensure reproducibility, and is more 107 | # commonly ignored for libraries. 108 | #uv.lock 109 | 110 | # poetry 111 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 112 | # This is especially recommended for binary packages to ensure reproducibility, and is more 113 | # commonly ignored for libraries. 114 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 115 | #poetry.lock 116 | 117 | # pdm 118 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 119 | #pdm.lock 120 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 121 | # in version control. 122 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 123 | .pdm.toml 124 | .pdm-python 125 | .pdm-build/ 126 | 127 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 128 | __pypackages__/ 129 | 130 | # Celery stuff 131 | celerybeat-schedule 132 | celerybeat.pid 133 | 134 | # SageMath parsed files 135 | *.sage.py 136 | 137 | # Environments 138 | .env 139 | .venv 140 | env/ 141 | venv/ 142 | ENV/ 143 | env.bak/ 144 | venv.bak/ 145 | 146 | # Spyder project settings 147 | .spyderproject 148 | .spyproject 149 | 150 | # Rope project settings 151 | .ropeproject 152 | 153 | # mkdocs documentation 154 | /site 155 | 156 | # mypy 157 | .mypy_cache/ 158 | .dmypy.json 159 | dmypy.json 160 | 161 | # Pyre type checker 162 | .pyre/ 163 | 164 | # pytype static type analyzer 165 | .pytype/ 166 | 167 | # Cython debug symbols 168 | cython_debug/ 169 | 170 | # PyCharm 171 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 172 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 173 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 174 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 175 | #.idea/ 176 | 177 | # PyPI configuration file 178 | .pypirc 179 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "data_pipeline/rpg_vid2e"] 2 | path = data_pipeline/rpg_vid2e 3 | url = git@github.com:filbert14/rpg_vid2e.git 4 | [submodule "data_pipeline/kubric"] 5 | path = data_pipeline/kubric 6 | url = git@github.com:filbert14/kubric.git 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ETAP: Event-based Tracking of Any Point (CVPR'25 Highlight) 2 | 3 | [![Paper](https://img.shields.io/badge/arXiv-2409.03358-b31b1b.svg)](https://arxiv.org/pdf/2412.00133) 4 | [![Data](https://img.shields.io/badge/Dataset-GoogleDrive-4285F4.svg)](https://drive.google.com/drive/folders/1Mprj-vOiTP5IgXE9iuu4-4bazcZUswpp?usp=drive_link) 5 | [![Video](https://img.shields.io/badge/Video-YouTube-FF0000.svg?style=flat&labelColor=444444&color=FF0000)](https://youtu.be/LaeA9WJ7ptc) 6 | [![License](https://img.shields.io/badge/License-CC_BY--NC_4.0-lightgrey.svg)](https://creativecommons.org/licenses/by-nc/4.0/) 7 | 8 | ## Introduction 9 | 10 | This is the official repository for [**ETAP: Event-based Tracking of Any Point**](https://arxiv.org/pdf/2412.00133), by [Friedhelm Hamann](https://friedhelmhamann.github.io/), [Daniel Gehrig](https://danielgehrig18.github.io/), [Filbert Febryanto](https://github.com/filbert14), [Kostas Daniilidis](https://www.cis.upenn.edu/~kostas/) and [Guillermo Gallego](http://www.guillermogallego.es/). 11 | 12 |

13 | 14 | ETAP_thumbnail 15 | 16 |

17 | 18 | ### Key Features 19 | 20 | - The first event-only point tracking method with strong cross-dataset generalization. 21 | - Robust tracking in challenging conditions (high-speed motion, lighting changes). 22 | - Evaluation benchmark based on six event camera datasets: EDS, EC, EVIMO2, EventKubric, E2D2 and PennAviary 23 | 24 | ### Example Predictions 25 | 26 |
27 | 28 | 29 | 33 | 37 | 38 | 39 | 43 | 47 | 48 |
30 | Example 1 31 |
Example 1: Synthetic dataset EventKubric 32 |
34 | Example 2 35 |
Example 2: Feature Tracking on EDS 36 |
40 | Example 3 41 |
Example 3: E2D2 42 |
44 | Example 4 45 |
Example 4: EVIMO2 46 |
49 |
50 | 51 | ## Table of Contents 52 | 53 | - [Quickstart Demo](#quickstart-demo) 54 | - [Installation](#installation) 55 | - [Model Selection](#model-selection) 56 | - [Evaluation Tasks](#evaluation-tasks) 57 | - [Feature Tracking (EDS, EC)](#feature-tracking-eds-ec) 58 | - [EVIMO2](#evaluation-evimo2) 59 | - [EventKubric](#evaluation-eventkubric) 60 | - [E2D2](#evaluation-e2d2) 61 | - [PennAviary](#evaluation-pennaviary-qualitative) 62 | - [Synthetic Data Generation](#synthetic-data-generation-eventkubric) 63 | - [Acknowledgements](#acknowledgements) 64 | - [Citation](#citation) 65 | - [Additional Resources](#additional-resources) 66 | - [License](#license) 67 | 68 | ## Quickstart Demo 69 | 70 | The quickest way to try ETAP is using our demo: 71 | 72 | 1. Clone the repository: 73 | ```bash 74 | git clone https://github.com/tub-rip/ETAP.git 75 | cd ETAP 76 | ``` 77 | 78 | 2. Download [the model weights](https://drive.google.com/file/d/18mnwu8CsrVJvDXeRtvU0Wgp6shdxn0HD/view?usp=drive_link) and save to `weights/ETAP_v1_cvpr25.pth` 79 | 3. Download [the demo example](https://drive.google.com/drive/folders/1d5Yi1q6ZFom3Q_VrELzXjrxE5aC2ezOk?usp=drive_link) (30MB) and extract to `data/demo_example` 80 | 4. Run the demo: 81 | ```bash 82 | python scripts/demo.py 83 | ``` 84 | 85 | This demo requires only basic dependencies: `torch`, `numpy`, `tqdm`, `matplotlib`, `imageio`, and `pillow`. No dataset preprocessing needed! 86 | 87 | ## Installation 88 | 89 | 1. Clone the repository: 90 | ```bash 91 | git clone git@github.com:tub-rip/ETAP.git 92 | cd ETAP 93 | ``` 94 | 95 | 2. Set up the environment: 96 | ```bash 97 | conda create --name ETAP python=3.10 98 | conda activate ETAP 99 | ``` 100 | 101 | 3. Install PyTorch (choose a command compatible with your CUDA version from the [PyTorch website](https://pytorch.org/get-started/locally/)), e.g.: 102 | ```bash 103 | conda install pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia 104 | ``` 105 | 106 | 4. Install other dependencies: 107 | ```bash 108 | pip install -r requirements.txt 109 | ``` 110 | 111 | ## Model Selection 112 | 113 | Download the [pre-trained model](https://drive.google.com/drive/folders/17YNqf603b3dEdETmNht-ih7cwosKyACN?usp=drive_link) and move it into the folder `/weights/`. 114 | 115 | To reproduce the paper results, use the model `ETAP_v1_cvpr25.pth`. 116 | 117 | ## Evaluation Tasks and Datasets 118 | 119 | ### Evaluation: Feature Tracking (EDS, EC) 120 | 121 | #### Preparations 122 | 123 | ##### Download EDS (Prophesee Gen3 640 x 480 px) 124 | 125 | The four evaluation sequences of the "Event-aided Direct Sparse Odometry Dataset" (EDS) can be downloaded in two ways: 126 | 127 | **Option 1** 128 | 129 | Download the four evaluation sequences of the "Event-aided Direct Sparse Odometry Dataset" (EDS) from the [official web page](https://rpg.ifi.uzh.ch/eds.html): 130 | 131 | * `01_peanuts_light` 132 | * `02_rocket_earth_light` 133 | * `08_peanuts_running` 134 | * `14_ziggy_in_the_arena` 135 | 136 | Choose the archive file option, which contains the events as an hdf5 file. Place all sequences in a common folder. 137 | 138 | **Option 2:** Use our download script: 139 | ```bash 140 | bash scripts/download_eds.sh 141 | ``` 142 | 143 | We also use the calibration data provided by EDS. No action is required, as it is included in this repository at `config/misc/eds/calib.yaml`. This is the same file as in the `00_calib` results from the official source. 
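For programmatic access, the snippet below is a minimal sketch (assuming PyYAML and NumPy are installed, and the Kalibr-style `[fx, fy, cx, cy]` intrinsics ordering used in the file) that loads `config/misc/eds/calib.yaml` and assembles the intrinsic matrix of the event camera:

```python
import numpy as np
import yaml

# Load the EDS calibration shipped with this repository
with open("config/misc/eds/calib.yaml", "r") as f:
    calib = yaml.safe_load(f)

# cam1 is the event camera (640 x 480 px); cam0 is the RGB camera
fx, fy, cx, cy = calib["cam1"]["intrinsics"]
K = np.array([[fx, 0.0, cx],
              [0.0, fy, cy],
              [0.0, 0.0, 1.0]])
dist_coeffs = np.array(calib["cam1"]["distortion_coeffs"])  # radtan model

print("Event camera intrinsics:\n", K)
print("Distortion coefficients:", dist_coeffs)
```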
144 | 145 | The evaluation was introduced in [DDFT](https://github.com/uzh-rpg/deep_ev_tracker). As with the calibration data, we have hardcoded the ground truth tracks at `/config/misc/eds/gt_tracks`, so no additional steps are necessary. If you are interested in how the tracks are created, please refer to the DDFT repository. 146 | 147 | Create a symbolic link to your data root into `/data`, or alternatively you can change the paths in the config files. The setup should look something like this: 148 | 149 | ``` 150 | data/eds/ 151 | ├── 01_peanuts_light 152 | │ └── events.h5 153 | ├── 02_rocket_earth_light 154 | │ └── events.h5 155 | ├── 08_peanuts_running 156 | │ └── events.h5 157 | └── 14_ziggy_in_the_arena 158 | └── events.h5 159 | ``` 160 | 161 | ##### Download EC (DAVIS240C 240 x 180 px) 162 | 163 | Download the five evaluation sequences of the "Event Camera Dataset" (EC) from the [official source](https://rpg.ifi.uzh.ch/davis_data.html). Download the option `Text (zip)`. Unzip the sequences into a folder structure like this: 164 | 165 | ``` 166 | data/ec/ 167 | ├── boxes_rotation 168 | │ ├── calib.txt 169 | │ ├── events.txt 170 | │ ├── groundtruth.txt 171 | │ ├── images 172 | │ ├── images.txt 173 | │ └── imu.txt 174 | ├── boxes_translation 175 | │ ├── events.txt 176 | │ ├── ... 177 | ├── shapes_6dof 178 | │ ├── events.txt 179 | │ ├── ... 180 | ├── shapes_rotation 181 | │ ├── events.txt 182 | │ ├── ... 183 | └── shapes_translation 184 | ├── events.txt 185 | ├── ... 186 | ``` 187 | 188 | As with EDS, the ground truth tracks are from the evaluation introduced in DDFT but we have included them at `config/misc/ec/gt` for convenience. 189 | 190 | ##### Preprocessing 191 | 192 | Preprocess the data by transforming the raw events into event stacks with the following commands: 193 | 194 | ```bash 195 | # For EDS dataset 196 | python scripts/prepare_event_representations.py --dataset eds --config config/exe/prepare_event_representations/eds.yaml 197 | 198 | # For EC dataset 199 | python scripts/prepare_event_representations.py --dataset ec --config config/exe/prepare_event_representations/ec.yaml 200 | ``` 201 | 202 | #### Inference 203 | 204 | Run the tracking inference with: 205 | 206 | ```bash 207 | python scripts/inference_online.py --config config/exe/inference_online/feature_tracking.yaml 208 | ``` 209 | 210 | #### Evaluation 211 | 212 | Run the benchmarking script to evaluate the tracking results: 213 | 214 | ```bash 215 | python scripts/benchmark_feature_tracking.py feature_tracking_eds_ec 216 | ``` 217 | 218 | ### Evaluation: EVIMO2 (Samsung DVS Gen3 640 x 480 px) 219 | 220 | #### Preparations 221 | 222 | 1. Download the required EVIMO2 sequences from the [official source](https://better-flow.github.io/evimo/download_evimo_2.html#imo). You only need **Motion Segmentation / Object Recognition** sequences for the **samsung_mono** camera in **.npz** format (2.4GB). Unzip them and move them into the data directory. 223 | 224 | 2. Download the precomputed tracks [here](https://drive.google.com/drive/folders/1d5Yi1q6ZFom3Q_VrELzXjrxE5aC2ezOk?usp=drive_link) and merge them into the data directory. 
225 | 226 | The result should look like this: 227 | 228 | ``` 229 | data/evimo/ 230 | └── samsung_mono 231 | └── imo 232 | └── eval 233 | ├── scene13_dyn_test_00_000000 234 | │ ├── dataset_classical.npz 235 | │ ├── dataset_depth.npz 236 | │ ├── dataset_events_p.npy 237 | │ ├── dataset_events_t.npy 238 | │ ├── dataset_events_xy.npy 239 | │ ├── dataset_info.npz 240 | │ ├── dataset_mask.npz 241 | │ └── dataset_tracks.h5 242 | ├── scene13_dyn_test_05_000000 243 | │ ├── dataset_classical.npz 244 | ... ... 245 | ``` 246 | 247 | 3. Precompute the event stacks: 248 | 249 | ```bash 250 | python scripts/prepare_event_representations.py --dataset evimo2 --config config/exe/prepare_event_representations/evimo2.yaml 251 | ``` 252 | 253 | #### Inference & Evaluation 254 | 255 | Run inference and evaluation with a single command: 256 | 257 | ```bash 258 | python scripts/inference_offline.py --config config/exe/inference_offline/evimo2.yaml 259 | ``` 260 | 261 | #### Ground Truth Track Generation (Optional) 262 | 263 | If you want to generate the point tracks yourself instead of using the precomputed ones: 264 | 265 | ```bash 266 | python scripts/create_evimo2_track_gt.py --config config/misc/evimo2/val_samples.csv --data_root data/evimo2 267 | ``` 268 | 269 | ### Evaluation: EventKubric (synthetic 512 x 512 px) 270 | 271 | #### Preparations 272 | 273 | 1. Download [the event_kubric test set](https://drive.google.com/drive/folders/1d5Yi1q6ZFom3Q_VrELzXjrxE5aC2ezOk?usp=drive_link) and move it to the data directory: 274 | 275 | ``` 276 | data/event_kubric 277 | └── test 278 | ├── sample_000042 279 | │ ├── annotations.npy 280 | │ ├── data.hdf5 281 | │ ├── data_ranges.json 282 | │ ├── events 283 | │ │ ├── 0000000000.npz 284 | │ │ ├── 0000000001.npz 285 | │ │ ... 286 | │ ├── events.json 287 | │ └── metadata.json 288 | ├── sample_000576 289 | │ ├── annotations.npy 290 | │ ... 291 | ... 292 | ``` 293 | 294 | 2. Prepare the event stacks: 295 | 296 | ```bash 297 | python scripts/prepare_event_representations.py --dataset event_kubric --config config/exe/prepare_event_representations/event_kubric.yaml 298 | ``` 299 | 300 | #### Inference & Evaluation 301 | 302 | Run inference and evaluation with a single command: 303 | 304 | ```bash 305 | python scripts/inference_offline.py --config config/exe/inference_offline/event_kubric.yaml 306 | ``` 307 | 308 | ### Evaluation: E2D2 (SilkyEvCam 640 x 480 px) 309 | 310 | #### Preparations 311 | 312 | 1. Download [the E2D2 fidget spinner sequence](https://drive.google.com/drive/folders/1d5Yi1q6ZFom3Q_VrELzXjrxE5aC2ezOk?usp=drive_link) and move it to the data directory: 313 | 314 | ``` 315 | data/e2d2/ 316 | └── 231025_110210_fidget5_high_exposure 317 | ├── gt_positions.npy 318 | ├── gt_timestamps.npy 319 | ├── queries.npy 320 | └── seq.h5 321 | ``` 322 | 323 | 2. Prepare the event stacks: 324 | 325 | ```bash 326 | python scripts/prepare_event_representations.py --dataset e2d2 --config config/exe/prepare_event_representations/e2d2.yaml 327 | ``` 328 | 329 | #### Inference 330 | 331 | ```bash 332 | python scripts/inference_online.py --config config/exe/inference_online/e2d2.yaml 333 | ``` 334 | 335 | #### Evaluation 336 | 337 | ```bash 338 | python scripts/benchmark_tap.py --gt_dir data/e2d2/231025_110210_fidget5_high_exposure --pred_dir output/inference/tap_e2d2 339 | ``` 340 | 341 | #### Ground Truth Generation (Optional) 342 | 343 | The ground truth is calculated from the turning speed of the fidget spinner and is provided for download. 
To calculate the ground truth tracks yourself, run: 344 | 345 | ```bash 346 | python scripts/create_e2d2_fidget_spinner_gt.py 347 | ``` 348 | 349 | ### Evaluation: PennAviary (SilkyEvCam 640 x 480 px, Qualitative) 350 | 351 | #### Preparations 352 | 353 | Download [the penn_aviary sequence](https://drive.google.com/drive/folders/1d5Yi1q6ZFom3Q_VrELzXjrxE5aC2ezOk?usp=drive_link) and move it to the data directory: 354 | 355 | ``` 356 | data/penn_aviary/ 357 | └── 231018_174107_view2 358 | ├── mask00082.png 359 | └── seq.h5 360 | ``` 361 | 362 | #### Inference 363 | 364 | Run the inference with: 365 | 366 | ```bash 367 | python scripts/inference_online.py --config config/exe/inference_online/penn_aviary.yaml 368 | ``` 369 | 370 | ## Synthetic Data Generation (EventKubric) 371 | 372 | We provide a [10 sample test set of EventKubric](https://drive.google.com/drive/folders/1d5Yi1q6ZFom3Q_VrELzXjrxE5aC2ezOk?usp=drive_link) for quick evaluation. The complete dataset consists of approximately 10,000 samples. 373 | 374 | To generate your own synthetic event data, please refer to the [Data Pipeline Instructions](data_pipeline/README.md). 375 | 376 | ## Acknowledgements 377 | 378 | We gratefully appreciate the following repositories and thank the authors for their excellent work: 379 | 380 | - [CoTracker](https://github.com/facebookresearch/co-tracker) 381 | - [TapVid](https://github.com/google-deepmind/tapnet/tree/main/tapnet/tapvid) 382 | - [DDFT](https://github.com/uzh-rpg/deep_ev_tracker) 383 | 384 | ## Citation 385 | 386 | If you find this work useful in your research, please consider citing: 387 | 388 | ```bibtex 389 | @InProceedings{Hamann25cvpr, 390 | author={Friedhelm Hamann and Daniel Gehrig and Filbert Febryanto and Kostas Daniilidis and Guillermo Gallego}, 391 | title={{ETAP}: Event-based Tracking of Any Point}, 392 | booktitle={{IEEE/CVF} Conf. Computer Vision and Pattern Recognition ({CVPR})}, 393 | year=2025, 394 | } 395 | ``` 396 | 397 | ## Additional Resources 398 | 399 | * [Research page (TU Berlin, RIP lab)](https://sites.google.com/view/guillermogallego/research/event-based-vision) 400 | * [Course at TU Berlin](https://sites.google.com/view/guillermogallego/teaching/event-based-robot-vision) 401 | * [Survey paper](http://rpg.ifi.uzh.ch/docs/EventVisionSurvey.pdf) 402 | * [List of Event-based Vision Resources](https://github.com/uzh-rpg/event-based_vision_resources) 403 | 404 | ## License 405 | 406 | This project is licensed under the Creative Commons Attribution-NonCommercial 4.0 International License - see the [LICENSE](LICENSE) file for details. This means you are free to share and adapt the material for non-commercial purposes, provided you give appropriate credit and indicate if changes were made. 
-------------------------------------------------------------------------------- /config/exe/inference_offline/event_kubric.yaml: -------------------------------------------------------------------------------- 1 | common: 2 | height: 512 3 | width: 512 4 | exp_name: event_kubric 5 | checkpoint: weights/ETAP_v1_cvpr25.pth 6 | 7 | model: 8 | num_in_channels: 10 9 | stride: 4 10 | window_len: 8 11 | num_virtual_tracks: 64 12 | 13 | data: 14 | data_root: data/event_kubric 15 | dataset_name: event_kubric 16 | preprocessed_name: event_stack_v1 17 | seq_len: 24 18 | traj_per_sample: 500 -------------------------------------------------------------------------------- /config/exe/inference_offline/evimo2.yaml: -------------------------------------------------------------------------------- 1 | common: 2 | height: 480 3 | width: 640 4 | exp_name: evimo2 5 | checkpoint: weights/ETAP_v1_cvpr25.pth 6 | 7 | model: 8 | num_in_channels: 10 9 | stride: 4 10 | window_len: 8 11 | num_virtual_tracks: 64 12 | 13 | data: 14 | data_root: data/evimo2 15 | dataset_name: evimo2 16 | preprocessed_name: event_stack_v1 17 | metadata_path: config/misc/evimo2/val_samples.csv -------------------------------------------------------------------------------- /config/exe/inference_online/e2d2.yaml: -------------------------------------------------------------------------------- 1 | common: 2 | exp_name: tap_e2d2 3 | height: 512 4 | width: 512 5 | ckp_path: weights/ETAP_v1_cvpr25.pth 6 | add_support_points: true 7 | support_point_stride: 20 8 | 9 | data: 10 | data_root: data/e2d2 11 | dataset_name: e2d2 12 | preprocessed_name: event_stack_v1 13 | sequences: 14 | - 231025_110210_fidget5_high_exposure 15 | 16 | model: 17 | num_in_channels: 10 18 | stride: 4 19 | window_len: 8 20 | num_virtual_tracks: 64 21 | -------------------------------------------------------------------------------- /config/exe/inference_online/feature_tracking.yaml: -------------------------------------------------------------------------------- 1 | common: 2 | exp_name: feature_tracking_eds_ec 3 | ckp_path: weights/ETAP_v1_cvpr25.pth 4 | 5 | data: 6 | data_root: data/ 7 | dataset_name: feature_tracking_online 8 | preprocessed_name: event_stack_v1 9 | 10 | model: 11 | num_in_channels: 10 12 | stride: 4 13 | window_len: 8 14 | num_virtual_tracks: 64 15 | -------------------------------------------------------------------------------- /config/exe/inference_online/penn_aviary.yaml: -------------------------------------------------------------------------------- 1 | common: 2 | exp_name: qualitative_penn_aviary 3 | height: 480 4 | width: 640 5 | ckp_path: weights/ETAP_v1_cvpr25.pth 6 | add_support_points: false 7 | 8 | data: 9 | data_root: data/penn_aviary 10 | dataset_name: penn_aviary 11 | load_rgb: false 12 | sequences: 13 | - 231018_174107_view2 14 | 15 | repr_config: 16 | representation_name: event_stack 17 | num_stacks: 10 18 | interpolation: bilinear 19 | channel_overlap: true 20 | centered_channels: false 21 | image_shape: [480, 640] # Height and width of data, in common is for inference 22 | 23 | sequence_data: 24 | 231018_174107_view2: 25 | start_time_s: 2.7099975 # frame 35 26 | duration_s: 0.5 27 | step_time_s: 0.0033 28 | num_events: 200000 29 | query_stride: 8 30 | mask_name: mask00082.png 31 | 32 | model: 33 | num_in_channels: 10 34 | stride: 4 35 | window_len: 8 36 | num_virtual_tracks: 64 37 | -------------------------------------------------------------------------------- /config/exe/prepare_event_representations/e2d2.yaml: 
-------------------------------------------------------------------------------- 1 | common: 2 | height: 480 3 | width: 640 4 | save_prefix: v1 5 | num_events: 200000 6 | data_root: data/e2d2 7 | sequences: 8 | - 231025_110210_fidget5_high_exposure 9 | 10 | event_representation: 11 | representation_name: event_stack 12 | num_stacks: 10 13 | interpolation: bilinear 14 | channel_overlap: true 15 | centered_channels: false 16 | image_shape: [480, 640] -------------------------------------------------------------------------------- /config/exe/prepare_event_representations/ec.yaml: -------------------------------------------------------------------------------- 1 | common: 2 | dataset_path: data/ec 3 | num_events: 100000 4 | height: 180 5 | width: 240 6 | save_prefix: v1 7 | 8 | event_representation: 9 | representation_name: event_stack 10 | num_stacks: 10 11 | interpolation: bilinear 12 | channel_overlap: true 13 | centered_channels: false 14 | -------------------------------------------------------------------------------- /config/exe/prepare_event_representations/eds.yaml: -------------------------------------------------------------------------------- 1 | common: 2 | dataset_path: data/eds 3 | num_events: 800000 4 | height: 480 5 | width: 640 6 | save_prefix: v1 7 | 8 | event_representation: 9 | representation_name: event_stack 10 | num_stacks: 10 11 | interpolation: bilinear 12 | channel_overlap: true 13 | centered_channels: false 14 | -------------------------------------------------------------------------------- /config/exe/prepare_event_representations/event_kubric.yaml: -------------------------------------------------------------------------------- 1 | common: 2 | dataset_path: data/event_kubric/test 3 | num_events: 400000 4 | sequence_length: 24 5 | height: 512 6 | width: 512 7 | save_prefix: v1 8 | create_time_inverted: false 9 | 10 | event_representation: 11 | representation_name: event_stack 12 | num_stacks: 10 13 | interpolation: bilinear 14 | channel_overlap: true 15 | centered_channels: false -------------------------------------------------------------------------------- /config/exe/prepare_event_representations/evimo2.yaml: -------------------------------------------------------------------------------- 1 | common: 2 | data_path: data/evimo2 3 | sample_config: config/misc/evimo2/val_samples.csv 4 | num_events: 60000 5 | height: 480 6 | width: 640 7 | save_prefix: v1 8 | 9 | event_representation: 10 | representation_name: event_stack 11 | num_stacks: 10 12 | interpolation: bilinear 13 | channel_overlap: true 14 | centered_channels: false -------------------------------------------------------------------------------- /config/misc/eds/calib.yaml: -------------------------------------------------------------------------------- 1 | cam0: # This is the RGB camera 2 | cam_overlaps: [1] 3 | camera_model: pinhole 4 | distortion_coeffs: [-0.36965913545735024, 0.17414034009883844, 0.003915245015812422, 5 | 0.003666687416655559] 6 | distortion_model: radtan 7 | intrinsics: [766.536025127154, 767.5749459126396, 291.0503512057777, 227.4060484950132] 8 | resolution: [640, 480] 9 | rostopic: /cam0/image_raw 10 | flip: True 11 | cam1: # This is the event camera 12 | T_cn_cnm1: 13 | - [0.9998964430808897, -0.0020335804041023736, -0.014246672065022661, -0.00011238613157578769] 14 | - [0.001703024953250547, 0.9997299470300024, -0.023176123864880376, -0.0005981481496958399] 15 | - [0.014289955220253567, 0.02314946137886846, 0.9996298813149167, -0.004416681577516066] 16 | - [0.0, 0.0, 0.0, 1.0] 
17 | cam_overlaps: [0] 18 | camera_model: pinhole 19 | distortion_coeffs: [-0.09776467241921379, 0.2143738428636279, -0.004710710105172864, 20 | -0.004215916089401789] 21 | distortion_model: radtan 22 | intrinsics: [560.8520948927032, 560.6295819972383, 313.00733235019237, 217.32858679842997] 23 | resolution: [640, 480] 24 | rostopic: /cam1/image_raw 25 | -------------------------------------------------------------------------------- /config/misc/evimo2/val_samples.csv: -------------------------------------------------------------------------------- 1 | name,t_start,t_end 2 | samsung_mono/imo/eval/scene13_dyn_test_00_000000,0.033333,1.516667 3 | samsung_mono/imo/eval/scene14_dyn_test_03_000000,1.0288,2.572 4 | samsung_mono/imo/eval/scene14_dyn_test_05_000000,0.711,1.59975 5 | samsung_mono/imo/eval/scene15_dyn_test_01_000000,1.406668,3.516667 6 | samsung_mono/imo/eval/scene15_dyn_test_02_000000,0.033333,0.966667 7 | samsung_mono/imo/eval/scene15_dyn_test_05_000000,0.7,1.4 -------------------------------------------------------------------------------- /data_pipeline/README.md: -------------------------------------------------------------------------------- 1 | # EventKubric Data Generation Pipeline 2 | 3 | ## Overview 4 | 5 | This document provides instructions for generating synthetic data samples for the EventKubric dataset. The pipeline leverages two key tools: 6 | - [Kubric](https://github.com/google-research/kubric) - For generating synthetic scenes and rendering frames 7 | - [Vid2E](https://github.com/uzh-rpg/rpg_vid2e) - For converting video frames to event data 8 | 9 | Both tools are included as Git submodules in this repository. To access them, you can either: 10 | 11 | Clone with the `--recursive` flag (if starting fresh): 12 | ```bash 13 | git clone --recursive git@github.com:tub-rip/ETAP.git 14 | ``` 15 | 16 | Or if you've already cloned the repository without the `--recursive` flag, initialize and update the submodules: 17 | ```bash 18 | # Navigate to your existing repository 19 | cd ETAP 20 | 21 | # Initialize and fetch all submodules 22 | git submodule update --init --recursive 23 | ``` 24 | 25 | For reference samples, see the [EventKubric test set](https://drive.google.com/drive/folders/1v8dYA-D7OOCAw9TimTxj74nz8t5mbsqz). 26 | 27 | ## Data Generation Pipeline 28 | 29 | ### 1. Setting Up Kubric 30 | 31 | #### CPU Setup (Standard) 32 | Pull the Kubric Docker image: 33 | ```bash 34 | docker pull kubricdockerhub/kubruntu 35 | ``` 36 | 37 | Set up and enter Docker container: 38 | ```bash 39 | docker run --interactive --user $(id -u):$(id -g) --volume "$(pwd):/kubric" kubricdockerhub/kubruntu 40 | sudo docker exec -it bash 41 | ``` 42 | 43 | #### GPU Setup (Optional, for faster rendering) 44 | Inside the `/kubric` directory: 45 | 1. Build the wheel file: 46 | ```bash 47 | python setup.py bdist_wheel 48 | ``` 49 | 50 | 2. Build Docker images with GPU support: 51 | ```bash 52 | docker build -f docker/Blender.Dockerfile -t kubricdockerhub/blender-gpu:latest . 53 | docker build -f docker/Kubruntu.Dockerfile -t kubricdockerhub/kubruntu-gpu:latest . 54 | ``` 55 | 56 | 3. Run container with GPU access: 57 | ```bash 58 | docker run --interactive --gpus all --env KUBRIC_USE_GPU=1 --volume "$(pwd):/kubric" kubricdockerhub/kubruntu-gpu 59 | sudo docker exec -it bash 60 | ``` 61 | 62 | ### 2. Generating Synthetic Samples 63 | 64 | The complete pipeline consists of four main steps as visualized in the flow chart: 65 | 66 |

67 | flowchart_event_kubric 68 |

69 | 70 | #### Step 1: Generate Training Examples 71 | ```bash 72 | python3 sample.py --output_dir= --start_index= --end_index= --worker_script=event_kubric_worker 73 | ``` 74 | This generates `end_index - start_index` training examples with indices in the range `[start_index, end_index - 1]`. 75 | 76 | **Optional flags:** 77 | - `--panning`: Include camera panning motions (similar to [TAPIR](https://deepmind-tapir.github.io/)) 78 | 79 | #### Step 2: Generate Ground Truth Point Tracks 80 | ```bash 81 | python3 annotate.py --dataset_path= --start_index= --end_index= --resolution=512 --num_frames=96 --tracks_to_sample=2048 82 | ``` 83 | 84 | #### Step 3: Generate Events from Frames 85 | ```bash 86 | python3 convert.py --dataset_path= --start_index= --end_index= --frame_rate=48 --num_frames=96 --ct_lower=0.16 --ct_upper=0.34 --ref_period=0 87 | ``` 88 | 89 | #### Step 4: Compress Data (Optional) 90 | ```bash 91 | python3 compress.py --dataset_root= 92 | ``` 93 | This compresses raw files into a single `.hdf5` file per sample. For additional options (e.g., subsampling), refer to `compress.py`. 94 | 95 | To decompress a sample: 96 | ```bash 97 | python3 decompress.py --hdf5_path= --output_folder= 98 | ``` 99 | 100 | ## Dataset Structure 101 | 102 | After generation, your dataset will have the following structure: 103 | 104 | ``` 105 | dataset_root/ 106 | ├── 00000000/ 107 | ├── 00000001/ 108 | │ ├── annotations.npy # Ground truth point tracks 109 | │ ├── data.hdf5 # Compressed data 110 | │ ├── events/ # Event data 111 | │ │ ├── 0000000000.npz 112 | │ │ ├── ... 113 | │ │ └── 0000000211.npz 114 | │ └── raw/ # Raw Kubric outputs 115 | │ ├── backward_flow_*.png 116 | │ ├── data_ranges.json 117 | │ ├── depth_*.tiff 118 | │ ├── events.json 119 | │ ├── forward_flow_*.png 120 | │ ├── metadata.json 121 | │ ├── normal_*.png 122 | │ ├── object_coordinates_*.png 123 | │ ├── rgba_*.png 124 | │ └── segmentation_*.png 125 | ├── 00000002/ 126 | └── ... 
127 | ``` 128 | 129 | ## Data Format Conventions 130 | 131 | - **Events**: `[y, x, t, p]` where: 132 | - `y, x`: Pixel coordinates 133 | - `t`: Timestamp 134 | - `p`: Polarity (positive/negative) 135 | 136 | - **Annotations**: 137 | - `annotations["video"]`: RGB frames `[N, C, H, W]` 138 | - `annotations["target_points"]`: Point track coordinates `[N, P, 2]` 139 | - `annotations["occluded"]`: Occlusion masks `[N, P]` 140 | 141 | Where: 142 | - `N`: Number of frames 143 | - `C`: Number of channels 144 | - `H, W`: Height and width 145 | - `P`: Number of tracked points 146 | -------------------------------------------------------------------------------- /data_pipeline/annotate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import subprocess 4 | 5 | def main(): 6 | ap = argparse.ArgumentParser() 7 | ap.add_argument("--dataset_path" , required=True) 8 | ap.add_argument("--start_index" , type=int, required=True) 9 | ap.add_argument("--end_index" , type=int, required=True) 10 | 11 | ap.add_argument("--resolution" , type=int, required=True) 12 | ap.add_argument("--num_frames" , type=int, required=True) 13 | ap.add_argument("--tracks_to_sample", type=int, default=2048) 14 | 15 | args = vars(ap.parse_args()) 16 | dataset_path = args["dataset_path"] 17 | start_index = args["start_index"] 18 | end_index = args["end_index"] 19 | 20 | resolution = args["resolution"] 21 | num_frames = args["num_frames"] 22 | tracks_to_sample= args["tracks_to_sample"] 23 | 24 | example_counter = start_index 25 | 26 | def get_current_example_path(): 27 | return os.path.join(dataset_path, f"{example_counter:08d}") 28 | 29 | while example_counter < end_index: 30 | print(f"Annotating example {example_counter}") 31 | 32 | script = ["python3", 33 | "annotator.py", 34 | f"--scene_dir={os.path.join(get_current_example_path(), 'raw')}", 35 | f"--resolution={resolution}", 36 | f"--num_frames={num_frames}", 37 | f"--output_dir={get_current_example_path()}", 38 | f"--tracks_to_sample={tracks_to_sample}"] 39 | 40 | annotate_result = subprocess.run(script, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) 41 | 42 | if annotate_result.returncode == 0: 43 | print(f"Successfully annotated example {example_counter}") 44 | else: 45 | print(f"Failed to annotate example {example_counter}, return code: {annotate_result.returncode}") 46 | break 47 | 48 | example_counter += 1 49 | 50 | if __name__ == "__main__": 51 | main() -------------------------------------------------------------------------------- /data_pipeline/annotator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import shutil 4 | import json 5 | import tensorflow as tf 6 | import sys 7 | import numpy as np 8 | 9 | sys.path.append(os.path.join("kubric", "challenges")) 10 | 11 | from PIL import Image 12 | from point_tracking.dataset import add_tracks 13 | 14 | from movi.movi_f import subsample_nearest_neighbor 15 | from movi.movi_f import subsample_avg 16 | from movi.movi_f import read_png 17 | from movi.movi_f import read_tiff 18 | from movi.movi_f import convert_float_to_uint16 19 | 20 | def main(): 21 | # Parse arguments 22 | ap = argparse.ArgumentParser() 23 | ap.add_argument("--scene_dir", required=True) 24 | ap.add_argument("--resolution", type=int, required=True) 25 | ap.add_argument("--num_frames", type=int, required=True) 26 | ap.add_argument("--output_dir", required=True) 27 | ap.add_argument("--tracks_to_sample", 
type=int, default=2048) 28 | args = vars(ap.parse_args()) 29 | 30 | _scene_dir = args["scene_dir"] 31 | _width, _height = args["resolution"], args["resolution"] 32 | _num_frames = args["num_frames"] 33 | _target_size = (_height, _width) 34 | _output_dir = args["output_dir"] 35 | _tracks_to_sample= args["tracks_to_sample"] 36 | layers = ("rgba", "segmentation", "depth", "normal", "object_coordinates") 37 | 38 | # Load simulation output 39 | with tf.io.gfile.GFile(os.path.join(_scene_dir, 'metadata.json'), "r") as fp: 40 | metadata = json.load(fp) 41 | paths = { 42 | key: [os.path.join(_scene_dir, (f"{key}_{f:05d}.png")) for f in range (_num_frames)] 43 | for key in layers if key != "depth" 44 | } 45 | 46 | # Gather relevant data for point tracking annotation 47 | result = {} 48 | result["normal"] = tf.convert_to_tensor([subsample_nearest_neighbor(read_png(frame_path), _target_size) for frame_path in paths["normal"]], dtype=float) 49 | result["object_coordinates"] = tf.convert_to_tensor([subsample_nearest_neighbor(read_png(frame_path), _target_size) for frame_path in paths["object_coordinates"]]) 50 | result["segmentations"] = tf.convert_to_tensor([subsample_nearest_neighbor(read_png(frame_path), _target_size) for frame_path in paths["segmentation"]]) 51 | result["video"] = tf.convert_to_tensor([subsample_avg(read_png(frame_path), _target_size)[..., :3] for frame_path in paths["rgba"]]) 52 | result["metadata"] = {} 53 | 54 | depth_paths = [os.path.join(_scene_dir, f"depth_{f:05d}.tiff") for f in range(_num_frames)] 55 | depth_frames = np.array([subsample_nearest_neighbor(read_tiff(frame_path), _target_size) for frame_path in depth_paths]) 56 | depth_min, depth_max = np.min(depth_frames), np.max(depth_frames) 57 | 58 | result["depth"] = convert_float_to_uint16(depth_frames, depth_min, depth_max) 59 | result["metadata"]["depth_range"] = [depth_min, depth_max] 60 | 61 | result["instances"] = {} 62 | result["instances"]["bboxes_3d"] = tf.convert_to_tensor([np.array(obj["bboxes_3d"], np.float32) for obj in metadata["instances"]]) 63 | result["instances"]["quaternions"] = tf.convert_to_tensor([np.array(obj["quaternions"], np.float32) for obj in metadata["instances"]]) 64 | result["camera"] = {} 65 | 66 | result["camera"]["focal_length"] = metadata["camera"]["focal_length"] 67 | result["camera"]["sensor_width"] = metadata["camera"]["sensor_width"] 68 | result["camera"]["positions"] = np.array(metadata["camera"]["positions"], np.float32) 69 | result["camera"]["quaternions"] = np.array(metadata["camera"]["quaternions"], np.float32) 70 | 71 | # Annotate using add_tracks 72 | point_tracking = add_tracks(result, train_size=_target_size, random_crop=False, tracks_to_sample=_tracks_to_sample) 73 | video = point_tracking["video"].numpy() 74 | target_points = point_tracking["target_points"].numpy() 75 | occluded = point_tracking["occluded"].numpy() 76 | 77 | # Save annotations 78 | annotations = {"target_points": target_points, "occluded":occluded} 79 | np.save(os.path.join(_output_dir, "annotations.npy"), annotations) 80 | 81 | if __name__ == "__main__": 82 | main() -------------------------------------------------------------------------------- /data_pipeline/compress.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import h5py 4 | import numpy as np 5 | import cv2 6 | import glob 7 | import argparse 8 | import multiprocessing 9 | from functools import partial 10 | import hdf5plugin 11 | import yaml 12 | from tqdm import tqdm 13 | 14 | 
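# Output layout produced by this script: each sample directory gains a single data.hdf5 with
# gzip-compressed datasets (backward_flow, depth, forward_flow, normal, object_coordinates,
# rgba, segmentation, subsample_indices; rgba frames are never subsampled) and with the JSON
# contents of events.json / metadata.json (and metadata.yaml, if present) stored as file attributes.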
15 | def load_image(file_path): 16 | img = cv2.imread(file_path, cv2.IMREAD_UNCHANGED) 17 | if img is None: 18 | raise ValueError(f"Failed to load image: {file_path}") 19 | 20 | # Convert BGR to RGB for 3-channel color images 21 | if len(img.shape) == 3 and img.shape[2] == 3: 22 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 23 | # Convert BGRA to RGBA for 4-channel images 24 | elif len(img.shape) == 3 and img.shape[2] == 4: 25 | img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA) 26 | 27 | return img 28 | 29 | 30 | def calculate_original_size(sample_path): 31 | total_size = 0 32 | for file_name in os.listdir(sample_path): 33 | if file_name != 'data.hdf5': # Exclude the HDF5 file if it exists 34 | file_path = os.path.join(sample_path, file_name) 35 | if os.path.isfile(file_path): 36 | total_size += os.path.getsize(file_path) 37 | return total_size 38 | 39 | 40 | def process_sample(sample_path, verbose, subsample_factor, compression_level=9): 41 | output_file = os.path.join(sample_path, 'data.hdf5') 42 | sample_path = os.path.join(sample_path, "raw") 43 | 44 | # Skip if HDF5 file already exists 45 | if os.path.exists(output_file): 46 | if verbose: 47 | print(f"Skipping {sample_path}: HDF5 file already exists") 48 | return None 49 | 50 | try: 51 | original_size = calculate_original_size(sample_path) 52 | 53 | with h5py.File(output_file, 'w') as hf: 54 | for data_type in ['backward_flow', 'depth', 'forward_flow', 'normal', 'object_coordinates', 'rgba', 55 | 'segmentation']: 56 | file_pattern = f"{data_type}_*.{'tiff' if data_type == 'depth' else 'png'}" 57 | files = sorted(glob.glob(os.path.join(sample_path, file_pattern))) 58 | 59 | if not files: 60 | print(f"Warning: No {data_type} files found in {sample_path}") 61 | continue 62 | 63 | if data_type != 'rgba': 64 | # Subsample for all data types except 'rgba' 65 | files = files[::subsample_factor] 66 | 67 | images = [load_image(f) for f in files] 68 | dataset = np.stack(images) 69 | 70 | hf.create_dataset(data_type, data=dataset, compression="gzip", compression_opts=compression_level) 71 | 72 | # Store subsampling indices 73 | if len(files) > 0: 74 | total_frames = len(glob.glob(os.path.join(sample_path, 'rgba_*.png'))) 75 | subsample_indices = np.arange(0, total_frames, subsample_factor) 76 | hf.create_dataset('subsample_indices', data=subsample_indices, compression="gzip", 77 | compression_opts=compression_level) 78 | 79 | # Add events and metadata as attributes 80 | with open(os.path.join(sample_path, 'events.json'), 'r') as f: 81 | events = json.load(f) 82 | hf.attrs['events'] = json.dumps(events) 83 | 84 | with open(os.path.join(sample_path, 'metadata.json'), 'r') as f: 85 | metadata = json.load(f) 86 | hf.attrs['metadata'] = json.dumps(metadata) 87 | 88 | # Check for and add additional metadata from metadata.yaml 89 | yaml_path = os.path.join(sample_path, 'metadata.yaml') 90 | if os.path.exists(yaml_path): 91 | with open(yaml_path, 'r') as f: 92 | additional_metadata = yaml.safe_load(f) 93 | hf.attrs['additional_metadata'] = json.dumps(additional_metadata) 94 | 95 | if verbose: 96 | compressed_size = os.path.getsize(output_file) 97 | compression_factor = original_size / compressed_size 98 | return { 99 | 'sample': os.path.basename(sample_path), 100 | 'original_size': original_size, 101 | 'compressed_size': compressed_size, 102 | 'compression_factor': compression_factor 103 | } 104 | return None 105 | except Exception as e: 106 | return {'sample': os.path.basename(sample_path), 'error': str(e)} 107 | 108 | 109 | def process_dataset(dataset_root, 
verbose, num_processes, subsample_factor, compression_level=9): 110 | sample_paths = [os.path.join(dataset_root, d) for d in os.listdir(dataset_root) if 111 | os.path.isdir(os.path.join(dataset_root, d))] 112 | 113 | with multiprocessing.Pool(processes=num_processes) as pool: 114 | results = list(tqdm(pool.imap(partial(process_sample, verbose=verbose, subsample_factor=subsample_factor, 115 | compression_level=compression_level), sample_paths), 116 | total=len(sample_paths), desc="Processing samples")) 117 | 118 | failed_samples = [result['sample'] for result in results if result and 'error' in result] 119 | 120 | if verbose: 121 | for result in results: 122 | if result and 'error' not in result: 123 | print(f"Sample: {result['sample']}") 124 | print(f"Original size: {result['original_size'] / 1024:.2f} KB") 125 | print(f"Compressed size: {result['compressed_size'] / 1024:.2f} KB") 126 | print(f"Compression factor: {result['compression_factor']:.2f}x") 127 | print("--------------------") 128 | 129 | # Print and log failed samples 130 | if failed_samples: 131 | print("Failed samples:") 132 | for sample in failed_samples: 133 | print(sample) 134 | 135 | log_file = os.path.join(dataset_root, 'log.txt') 136 | with open(log_file, 'w') as f: 137 | f.write("Failed samples:\n") 138 | for sample in failed_samples: 139 | f.write(f"{sample}\n") 140 | print(f"Failed samples have been logged in {log_file}") 141 | 142 | 143 | if __name__ == "__main__": 144 | parser = argparse.ArgumentParser( 145 | description="Convert dataset to HDF5 format using gzip compression with multiprocessing and subsampling") 146 | parser.add_argument("dataset_root", help="Path to the dataset root") 147 | parser.add_argument("-v", "--verbose", action="store_true", help="Print compression information for each sample") 148 | parser.add_argument("-j", "--jobs", type=int, default=multiprocessing.cpu_count(), 149 | help="Number of parallel jobs to run (default: number of CPU cores)") 150 | parser.add_argument("-s", "--subsample", type=int, default=1, 151 | help="Subsample factor for non-rgba data (default: 1)") 152 | parser.add_argument("-c", "--compression", type=int, default=4, choices=range(0, 10), 153 | help="Gzip compression level (0-9, default: 9)") 154 | args = parser.parse_args() 155 | 156 | process_dataset(args.dataset_root, args.verbose, args.jobs, args.subsample, args.compression) 157 | print("Conversion completed.") 158 | -------------------------------------------------------------------------------- /data_pipeline/convert.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import subprocess 3 | import os 4 | 5 | def main(): 6 | ap = argparse.ArgumentParser() 7 | ap.add_argument("--dataset_path", required=True) 8 | ap.add_argument("--start_index" , type=int, required=True) 9 | ap.add_argument("--end_index" , type=int, required=True) 10 | 11 | # Event generation parameters 12 | ap.add_argument("--frame_rate", type=int , default=48) 13 | ap.add_argument("--num_frames", type=int , default=96) 14 | ap.add_argument("--ct_lower", type=float, default=0.16) 15 | ap.add_argument("--ct_upper", type=float, default=0.34) 16 | ap.add_argument("--ref_period", type=int, default=0) 17 | 18 | args = vars(ap.parse_args()) 19 | dataset_path = args["dataset_path"] 20 | start_index = args["start_index"] 21 | end_index = args["end_index"] 22 | 23 | frame_rate = args["frame_rate"] 24 | num_frames = args["num_frames"] 25 | ct_lower = args["ct_lower"] 26 | ct_upper = args["ct_upper"] 27 | 
ref_period = args["ref_period"] 28 | 29 | example_counter = start_index 30 | 31 | def get_current_example_path(): 32 | return os.path.join(dataset_path, f"{example_counter:08d}") 33 | 34 | while example_counter < end_index: 35 | print(f"Converting example {example_counter}") 36 | 37 | script = ["python3", 38 | "converter.py", 39 | f"--scene_dir={os.path.join(get_current_example_path(), 'raw')}", 40 | f"--output_dir={get_current_example_path()}", 41 | f"--frame_rate={frame_rate}", 42 | f"--num_frames={num_frames}", 43 | f"--ct_lower={ct_lower}", 44 | f"--ct_upper={ct_upper}", 45 | f"--ref_period={ref_period}"] 46 | 47 | convert_result = subprocess.run(script, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) 48 | 49 | if convert_result.returncode == 0: 50 | print(f"Successfully converted example {example_counter}") 51 | else: 52 | print(f"Failed to convert example {example_counter}, return code: {convert_result.returncode}") 53 | break 54 | 55 | example_counter += 1 56 | 57 | if __name__ == "__main__": 58 | main() -------------------------------------------------------------------------------- /data_pipeline/converter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from rpg_vid2e.upsampling.utils import Upsampler 3 | from rpg_vid2e.esim_torch.scripts.generate_events import process_dir 4 | 5 | import argparse 6 | import shutil 7 | import os 8 | 9 | def main(): 10 | ap = argparse.ArgumentParser() 11 | ap.add_argument("--scene_dir", required=True) 12 | ap.add_argument("--output_dir", required=True) 13 | 14 | # Event generation parameters 15 | ap.add_argument("--frame_rate", type=int , default=48) 16 | ap.add_argument("--num_frames", type=int , default=96) 17 | ap.add_argument("--ct_lower", type=float, default=0.16) 18 | ap.add_argument("--ct_upper", type=float, default=0.34) 19 | ap.add_argument("--ref_period", type=int, default=0) 20 | 21 | ## Only for vid2e's process_dir 22 | ap.add_argument("--contrast_threshold_negative", "-cp", type=float, default=0.2) 23 | ap.add_argument("--contrast_threshold_positive", "-cn", type=float, default=0.2) 24 | ap.add_argument("--refractory_period_ns", "-rp", type=int, default=0) 25 | 26 | args = vars(ap.parse_args()) 27 | scene_dir = args["scene_dir"] 28 | output_dir = args["output_dir"] 29 | 30 | frame_rate = args["frame_rate"] 31 | num_frames = args["num_frames"] 32 | ct_lower = args["ct_lower"] 33 | ct_upper = args["ct_upper"] 34 | ref_period = args["ref_period"] 35 | 36 | tmpf = os.path.join(output_dir, "tmp") 37 | os.makedirs(os.path.join(tmpf, "seq", "imgs")) 38 | 39 | rgbs = [f"rgba_{i:05d}.png" for i in range(num_frames)] 40 | 41 | for rgb in rgbs: 42 | shutil.copy(os.path.join(scene_dir, rgb), 43 | os.path.join(tmpf, "seq", "imgs", rgb.split("_")[1])) 44 | 45 | # Upsample frames 46 | fpsf = open(os.path.join(tmpf, "seq", "fps.txt"), "w") 47 | fpsf.write(str(frame_rate)) 48 | fpsf.close() 49 | 50 | upsampler = Upsampler(input_dir=os.path.join(tmpf, "seq"), 51 | output_dir=os.path.join(tmpf, "seq_upsampled")) 52 | upsampler.upsample() 53 | 54 | # Generate events 55 | vid2e_args = ap.parse_args() 56 | vid2e_args.contrast_threshold_positive = np.random.uniform(ct_lower, ct_upper) 57 | vid2e_args.contrast_threshold_negative = np.random.uniform(ct_lower, ct_upper) 58 | vid2e_args.refractory_period_ns = ref_period 59 | 60 | process_dir(os.path.join(output_dir, "events"), 61 | os.path.join(tmpf, "seq_upsampled"), 62 | vid2e_args) 63 | 64 | # Remove temporary files 65 | shutil.rmtree(tmpf) 
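# At this point, the generated events for this example are stored as .npz chunks under
# <output_dir>/events (written by process_dir; see the dataset structure in data_pipeline/README.md).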
66 | 67 | if __name__ == "__main__": 68 | main() -------------------------------------------------------------------------------- /data_pipeline/decompress.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import h5py 4 | import numpy as np 5 | import cv2 6 | import argparse 7 | 8 | def save_image(data, file_path, data_type): 9 | if file_path.endswith('.tiff') or file_path.endswith('.tif'): 10 | cv2.imwrite(file_path, data) 11 | else: 12 | if data_type in ['backward_flow', 'forward_flow'] or data.dtype == np.uint16: 13 | cv2.imwrite(file_path, cv2.cvtColor(data, cv2.COLOR_RGB2BGR)) 14 | #cv2.imwrite(file_path, data) 15 | elif len(data.shape) == 2: # Grayscale image 16 | cv2.imwrite(file_path, data) 17 | elif len(data.shape) == 3: 18 | if data.shape[2] == 3: # RGB image 19 | cv2.imwrite(file_path, cv2.cvtColor(data, cv2.COLOR_RGB2BGR)) 20 | elif data.shape[2] == 4: # RGBA image 21 | cv2.imwrite(file_path, cv2.cvtColor(data, cv2.COLOR_RGBA2BGRA)) 22 | else: 23 | raise ValueError(f"Unsupported image shape: {data.shape}") 24 | else: 25 | raise ValueError(f"Unsupported image shape: {data.shape}") 26 | 27 | def extract_sample(hdf5_path, output_folder): 28 | # Create output folder 29 | os.makedirs(output_folder, exist_ok=True) 30 | 31 | with h5py.File(hdf5_path, 'r') as hf: 32 | # Extract datasets 33 | for data_type in ['backward_flow', 'depth', 'forward_flow', 'normal', 'object_coordinates', 'rgba', 'segmentation']: 34 | if data_type in hf: 35 | data = hf[data_type][:] 36 | if len(data.shape) == 3: # Single image or multiple 2D frames 37 | for i, frame in enumerate(data): 38 | file_name = f"{data_type}_{i:06d}.{'tiff' if data_type == 'depth' else 'png'}" 39 | file_path = os.path.join(output_folder, file_name) 40 | save_image(frame, file_path, data_type) 41 | elif len(data.shape) == 4: # Multiple color frames 42 | for i, frame in enumerate(data): 43 | file_name = f"{data_type}_{i:06d}.{'tiff' if data_type == 'depth' else 'png'}" 44 | file_path = os.path.join(output_folder, file_name) 45 | save_image(frame, file_path, data_type) 46 | else: 47 | print(f"Warning: Unexpected shape for {data_type}: {data.shape}") 48 | else: 49 | print(f"Warning: {data_type} not found in HDF5 file") 50 | 51 | # Extract metadata and events 52 | if 'metadata' in hf.attrs: 53 | metadata = json.loads(hf.attrs['metadata']) 54 | with open(os.path.join(output_folder, 'metadata.json'), 'w') as f: 55 | json.dump(metadata, f, indent=2) 56 | 57 | if 'events' in hf.attrs: 58 | events = json.loads(hf.attrs['events']) 59 | with open(os.path.join(output_folder, 'events.json'), 'w') as f: 60 | json.dump(events, f, indent=2) 61 | 62 | if __name__ == "__main__": 63 | parser = argparse.ArgumentParser(description="Extract sample data from HDF5 file to original file structure") 64 | parser.add_argument("hdf5_path", help="Path to the input HDF5 file") 65 | parser.add_argument("output_folder", help="Path to the output folder for extracted files") 66 | args = parser.parse_args() 67 | 68 | extract_sample(args.hdf5_path, args.output_folder) 69 | print("Extraction completed.") 70 | -------------------------------------------------------------------------------- /data_pipeline/sample.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import subprocess 3 | import shutil 4 | import time 5 | import os 6 | 7 | def main(): 8 | ap = argparse.ArgumentParser() 9 | ap.add_argument("--output_dir" , required=True) 10 | 
ap.add_argument("--start_index", type=int, required=True) 11 | ap.add_argument("--end_index" , type=int, required=True) 12 | ap.add_argument("--worker_script", required=True) 13 | ap.add_argument("--panning" , action="store_true") 14 | 15 | args = vars(ap.parse_args()) 16 | output_dir = args["output_dir"] 17 | start_index = args["start_index"] 18 | end_index = args["end_index"] 19 | worker_script = args["worker_script"] + ".py" 20 | panning = args["panning"] 21 | 22 | example_counter = start_index 23 | os.makedirs(output_dir, exist_ok=True) 24 | 25 | def get_current_output_dir(): 26 | return os.path.join(output_dir, f"{example_counter:08d}") 27 | 28 | while example_counter < end_index: 29 | print(f"Generating example {example_counter}") 30 | 31 | script = ["python3", 32 | os.path.join("kubric", "challenges", "movi", worker_script), 33 | f"--job-dir={os.path.join(get_current_output_dir(), 'raw')}"] 34 | 35 | if panning: 36 | script.append("--camera=linear_movement_linear_lookat") 37 | else: 38 | script.append("--camera=linear_movement") 39 | 40 | # Generate training example 41 | kubric_result = subprocess.run(script, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) 42 | 43 | # Regenerate training example on error 44 | if kubric_result.returncode == 0: 45 | print(f"Successfully generated example {example_counter}") 46 | else: 47 | print(f"Failed to generate example {example_counter}, return code: {kubric_result.returncode}") 48 | print("Retrying in 10 seconds . . .") 49 | 50 | if os.path.exists(get_current_output_dir()): 51 | shutil.rmtree(get_current_output_dir()) 52 | 53 | time.sleep(10) 54 | continue 55 | 56 | example_counter += 1 57 | 58 | if __name__ == "__main__": 59 | main() 60 | -------------------------------------------------------------------------------- /docs/flowchart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tub-rip/ETAP/e13875244b45e651c1ebcce1e769ed05ffbf7acd/docs/flowchart.png -------------------------------------------------------------------------------- /docs/pred_e2d2_fidget.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tub-rip/ETAP/e13875244b45e651c1ebcce1e769ed05ffbf7acd/docs/pred_e2d2_fidget.gif -------------------------------------------------------------------------------- /docs/pred_eds_peanuts_running.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tub-rip/ETAP/e13875244b45e651c1ebcce1e769ed05ffbf7acd/docs/pred_eds_peanuts_running.gif -------------------------------------------------------------------------------- /docs/pred_event_kubric.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tub-rip/ETAP/e13875244b45e651c1ebcce1e769ed05ffbf7acd/docs/pred_event_kubric.gif -------------------------------------------------------------------------------- /docs/pred_evimo2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tub-rip/ETAP/e13875244b45e651c1ebcce1e769ed05ffbf7acd/docs/pred_evimo2.gif -------------------------------------------------------------------------------- /docs/thumbnail.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tub-rip/ETAP/e13875244b45e651c1ebcce1e769ed05ffbf7acd/docs/thumbnail.png 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.19.0 2 | pyyaml>=5.1.0 3 | tqdm>=4.42.0 4 | h5py>=3.1.0 5 | hdf5plugin 6 | opencv-python>=4.5.0 7 | pandas>=2.2.0 8 | pillow>=10.4.0 9 | pytorch-lightning>=2.2.5 10 | matplotlib 11 | imageio[ffmpeg] 12 | prettytable 13 | scipy -------------------------------------------------------------------------------- /scripts/benchmark_feature_tracking.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | import os 4 | import sys 5 | from enum import Enum 6 | import csv 7 | 8 | import matplotlib.pyplot as plt 9 | 10 | import numpy as np 11 | from prettytable import PrettyTable 12 | 13 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 14 | 15 | from src.utils import compute_tracking_errors, read_txt_results, SUPPORTED_SEQUENCES_FEATURE_TRACKING 16 | 17 | def calculate_mean_by_dataset_type(table, dataset_type): 18 | values = [] 19 | for row in table.rows: 20 | if dataset_type == EvalDatasetType.EDS: 21 | if row[0] in SUPPORTED_SEQUENCES_FEATURE_TRACKING['eds']: 22 | values.append(float(row[1])) 23 | else: # EvalDatasetType.EC 24 | if row[0] in SUPPORTED_SEQUENCES_FEATURE_TRACKING['ec']: 25 | values.append(float(row[1])) 26 | return np.mean(values) if values else 0 27 | 28 | def create_summary_csv(tables, output_path): 29 | header = ['fa_5_eds_mean', 'fa_5_ec_mean', 'te_5_eds_mean', 'te_5_ec_mean', 'inliers_eds_mean', 'inliers_ec_mean'] 30 | data = {} 31 | 32 | for k in tables.keys(): 33 | eds_mean = calculate_mean_by_dataset_type(tables[k], EvalDatasetType.EDS) 34 | ec_mean = calculate_mean_by_dataset_type(tables[k], EvalDatasetType.EC) 35 | 36 | if k.startswith('age_5'): 37 | data['fa_5_eds_mean'] = eds_mean 38 | data['fa_5_ec_mean'] = ec_mean 39 | elif k.startswith('te_5'): 40 | data['te_5_eds_mean'] = eds_mean 41 | data['te_5_ec_mean'] = ec_mean 42 | elif k.startswith('inliers'): 43 | data['inliers_eds_mean'] = eds_mean 44 | data['inliers_ec_mean'] = ec_mean 45 | 46 | # Add columns for individual sequences 47 | all_sequences = [] 48 | for sequence_name in SUPPORTED_SEQUENCES_FEATURE_TRACKING['eds']: 49 | all_sequences.append((sequence_name, EvalDatasetType.EDS)) 50 | for sequence_name in SUPPORTED_SEQUENCES_FEATURE_TRACKING['ec']: 51 | all_sequences.append((sequence_name, EvalDatasetType.EC)) 52 | 53 | for i, eval_sequence in enumerate(all_sequences): 54 | sequence_name = eval_sequence[0] 55 | for k in ['age_5', 'te_5', 'inliers']: 56 | column_name = f"{k}_{sequence_name}" 57 | header.append(column_name) 58 | data[column_name] = tables[f"{k}_mu"].rows[i][1] 59 | 60 | with open(output_path, 'w', newline='') as f: 61 | writer = csv.DictWriter(f, fieldnames=header) 62 | writer.writeheader() 63 | writer.writerow(data) 64 | 65 | 66 | # Data Classes for Inference 67 | class EvalDatasetType(Enum): 68 | EC = 0 69 | EDS = 1 70 | 71 | plt.rcParams["font.family"] = "serif" 72 | 73 | def parse_args(): 74 | """ 75 | Parse command-line arguments for the script. 76 | 77 | Returns: 78 | argparse.Namespace: An object containing parsed arguments. 79 | """ 80 | parser = argparse.ArgumentParser( 81 | description="Parse paths for results and output directories." 82 | ) 83 | parser.add_argument( 84 | 'method', 85 | type=Path, 86 | help="Path to the output directory." 
87 | ) 88 | 89 | parser.add_argument( 90 | '--results_dir', 91 | type=Path, 92 | default=Path("output/inference"), 93 | help="Path to the results directory." 94 | ) 95 | 96 | args = parser.parse_args() 97 | return args 98 | 99 | 100 | if __name__ == "__main__": 101 | args = parse_args() 102 | error_threshold_range = np.arange(1, 32, 1) 103 | methods = [args.method] 104 | 105 | table_keys = [ 106 | "age_5_mu", 107 | "age_5_std", 108 | "te_5_mu", 109 | "te_5_std", 110 | "age_mu", 111 | "age_std", 112 | "inliers_mu", 113 | "inliers_std", 114 | "expected_age", 115 | ] 116 | tables = {} 117 | for k in table_keys: 118 | tables[k] = PrettyTable() 119 | tables[k].title = k 120 | tables[k].field_names = ["Sequence Name"] + methods 121 | 122 | # Create a list of sequences with their dataset types 123 | eval_sequences = [] 124 | for sequence_name in SUPPORTED_SEQUENCES_FEATURE_TRACKING['eds']: 125 | eval_sequences.append((sequence_name, EvalDatasetType.EDS)) 126 | for sequence_name in SUPPORTED_SEQUENCES_FEATURE_TRACKING['ec']: 127 | eval_sequences.append((sequence_name, EvalDatasetType.EC)) 128 | 129 | for eval_sequence in eval_sequences: 130 | sequence_name = eval_sequence[0] 131 | sequence_type = eval_sequence[1] 132 | 133 | gt_folder_name = 'eds' if sequence_type == EvalDatasetType.EDS else 'ec' 134 | track_data_gt = read_txt_results( 135 | Path('config/misc') / gt_folder_name / 'gt_tracks' / f"{sequence_name}.gt.txt" 136 | ) 137 | 138 | rows = {} 139 | for k in tables.keys(): 140 | rows[k] = [sequence_name] 141 | 142 | for method in methods: 143 | inlier_ratio_arr, fa_rel_nz_arr = [], [] 144 | 145 | track_data_pred = read_txt_results( 146 | str(args.results_dir / f"{method}" / f"{sequence_name}.txt") 147 | ) 148 | 149 | if track_data_pred[0, 1] != track_data_gt[0, 1]: 150 | raise ValueError # TODO: double check if this case occurs 151 | track_data_pred[:, 1] += -track_data_pred[0, 1] + track_data_gt[0, 1] 152 | 153 | for thresh in error_threshold_range: 154 | fa_rel, _ = compute_tracking_errors( 155 | track_data_pred, 156 | track_data_gt, 157 | error_threshold=thresh, 158 | asynchronous=False, 159 | ) 160 | 161 | inlier_ratio = np.sum(fa_rel > 0) / len(fa_rel) 162 | if inlier_ratio > 0: 163 | fa_rel_nz = fa_rel[np.nonzero(fa_rel)[0]] 164 | else: 165 | fa_rel_nz = [0] 166 | inlier_ratio_arr.append(inlier_ratio) 167 | fa_rel_nz_arr.append(np.mean(fa_rel_nz)) 168 | 169 | mean_inlier_ratio, std_inlier_ratio = np.mean(inlier_ratio_arr), np.std( 170 | inlier_ratio_arr 171 | ) 172 | mean_fa_rel_nz, std_fa_rel_nz = np.mean(fa_rel_nz_arr), np.std(fa_rel_nz_arr) 173 | expected_age = np.mean(np.array(inlier_ratio_arr) * np.array(fa_rel_nz_arr)) 174 | 175 | rows["age_mu"].append(mean_fa_rel_nz) 176 | rows["age_std"].append(std_fa_rel_nz) 177 | rows["inliers_mu"].append(mean_inlier_ratio) 178 | rows["inliers_std"].append(std_inlier_ratio) 179 | rows["expected_age"].append(expected_age) 180 | 181 | fa_rel, te = compute_tracking_errors( 182 | track_data_pred, track_data_gt, error_threshold=5, asynchronous=False 183 | ) 184 | inlier_ratio = np.sum(fa_rel > 0) / len(fa_rel) 185 | if inlier_ratio > 0: 186 | fa_rel_nz = fa_rel[np.nonzero(fa_rel)[0]] 187 | else: 188 | fa_rel_nz = [0] 189 | te = [0] 190 | 191 | mean_fa_rel_nz, std_fa_rel_nz = np.mean(fa_rel_nz), np.std(fa_rel_nz) 192 | mean_te, std_te = np.mean(te), np.std(te) 193 | rows["age_5_mu"].append(mean_fa_rel_nz) 194 | rows["age_5_std"].append(std_fa_rel_nz) 195 | rows["te_5_mu"].append(mean_te) 196 | rows["te_5_std"].append(std_te) 197 | 198 | # Load 
results 199 | for k in tables.keys(): 200 | tables[k].add_row(rows[k]) 201 | 202 | with open((args.results_dir / f"{method}" / f"benchmarking_results.csv"), "w") as f: 203 | for k in tables.keys(): 204 | f.write(f"{k}\n") 205 | f.write(tables[k].get_csv_string()) 206 | 207 | # Calculate and write mean values for EDS and EC 208 | eds_mean = calculate_mean_by_dataset_type(tables[k], EvalDatasetType.EDS) 209 | ec_mean = calculate_mean_by_dataset_type(tables[k], EvalDatasetType.EC) 210 | 211 | f.write(f"EDS Mean,{eds_mean}\n") 212 | f.write(f"EC Mean,{ec_mean}\n\n") 213 | 214 | print(tables[k].get_string()) 215 | print(f"EDS Mean: {eds_mean}") 216 | print(f"EC Mean: {ec_mean}\n") 217 | 218 | summary_csv_path = args.results_dir / f"{method}" / f"summary_results.csv" 219 | create_summary_csv(tables, summary_csv_path) 220 | 221 | print(f"Summary results written to {summary_csv_path}") -------------------------------------------------------------------------------- /scripts/benchmark_tap.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | from pathlib import Path 4 | import numpy as np 5 | sys.path.append(str(Path(__file__).parent.parent)) 6 | import src.utils as utils 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser(description='Compute TAPVid metrics for point tracking') 10 | parser.add_argument('--gt_dir', type=str, default='data/e2d2/231025_110210_fidget5_high_exposure', 11 | help='Directory containing ground truth files') 12 | parser.add_argument('--pred_dir', type=str, default='output/inference/tap_e2d2', 13 | help='Directory containing prediction files') 14 | return parser.parse_args() 15 | 16 | if __name__ == '__main__': 17 | args = parse_args() 18 | gt_dir = Path(args.gt_dir) 19 | pred_dir = Path(args.pred_dir) 20 | error_threshold_range = 2 * np.array([1, 2, 4, 8, 16]) 21 | 22 | gt_tracks_path = gt_dir / 'gt_positions.npy' 23 | gt_tracks = np.load(gt_tracks_path) 24 | 25 | query_path = gt_dir / 'queries.npy' 26 | query_points = np.load(query_path) 27 | query_t = np.zeros((query_points.shape[0], 1)) # All query points are at t = 0 28 | query_points = np.concatenate([query_t, query_points], axis=1) 29 | 30 | pred_path = Path(pred_dir) / '231025_110210_fidget5_high_exposure.npz' 31 | pred = np.load(pred_path) 32 | coords_predicted = pred['coords_predicted'] 33 | vis_logits = pred['vis_logits'] 34 | vis_predicted = vis_logits > 0.8 35 | 36 | gt_tracks_formatted = np.expand_dims(np.transpose(gt_tracks, (1, 0, 2)), axis=0) 37 | coords_predicted_formatted = np.expand_dims(np.transpose(coords_predicted, (1, 0, 2)), axis=0) 38 | 39 | # Create occlusion masks (assuming no occlusions in gt, using vis_predicted for predictions) 40 | num_points, num_frames = gt_tracks.shape[1], gt_tracks.shape[0] 41 | occluded_gt = np.zeros((1, num_points, num_frames), dtype=bool) 42 | occluded_pred = np.expand_dims(~np.transpose(vis_predicted, (1, 0)), axis=0) # Assuming vis_predicted is visibility, not occlusion 43 | 44 | tap_metrics = utils.compute_tapvid_metrics( 45 | query_points=np.expand_dims(query_points, axis=0), # Add batch dimension 46 | gt_occluded=occluded_gt, 47 | gt_tracks=gt_tracks_formatted, 48 | pred_occluded=occluded_pred, 49 | pred_tracks=coords_predicted_formatted, 50 | query_mode='first', 51 | thresholds=error_threshold_range 52 | ) 53 | print("TAPVid Metrics:") 54 | for k, v in tap_metrics.items(): 55 | print(f"{k}: {v}") 56 | -------------------------------------------------------------------------------- 
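Both benchmark scripts above consume outputs of scripts/inference_online.py: benchmark_feature_tracking.py reads the per-sequence .txt files, while benchmark_tap.py reads the .npz archives. A quick way to inspect such an archive, using the default paths from benchmark_tap.py (the directory name depends on the exp_name chosen for inference):

import numpy as np

pred = np.load("output/inference/tap_e2d2/231025_110210_fidget5_high_exposure.npz")
coords = pred["coords_predicted"]   # [T, N, 2] predicted (x, y) per frame and query point
vis = pred["vis_logits"] > 0.8      # same visibility threshold as in benchmark_tap.py
timestamps = pred["timestamps"]     # [T] timestamps in seconds
print(coords.shape, vis.shape, timestamps.shape)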
/scripts/create_e2d2_fidget_spinner_gt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import h5py 5 | from tqdm import tqdm 6 | import torch 7 | from PIL import Image, ImageDraw, ImageFont 8 | from pathlib import Path 9 | import matplotlib.pyplot as plt 10 | from scipy.signal import find_peaks 11 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 12 | from src.representations import VoxelGrid 13 | from src.utils import Visualizer 14 | 15 | def make_grid(height, width, stride=1): 16 | x = np.arange(0, width, stride) 17 | y = np.arange(0, height, stride) 18 | X, Y = np.meshgrid(x, y) 19 | return np.stack([X.flatten(), Y.flatten()], axis=1) 20 | 21 | def read_binary_mask(path): 22 | img = np.array(Image.open(path).convert('L')) 23 | mask = (img != 255).astype(np.uint8) 24 | return mask 25 | 26 | def interpolate_wheel_points(query_points, center_point, timestamps, valleys, t_start, target_timestamps=None): 27 | """ 28 | Interpolate point positions on a rotating wheel. 29 | 30 | Args: 31 | query_points: numpy array of shape [N, 2] containing initial (x, y) positions at t_start 32 | center_point: tuple or array (x, y) representing wheel center 33 | timestamps: array of all timestamps used for valley detection 34 | valleys: indices of valleys in timestamps array (when wheel rotates exactly 1/3) 35 | t_start: initial timestamp 36 | target_timestamps: optional array of timestamps at which to calculate positions 37 | 38 | Returns: 39 | numpy array of shape [len(target_timestamps), N, 2] containing interpolated positions 40 | """ 41 | # Convert everything to numpy arrays and ensure correct shapes 42 | query_points = np.array(query_points) 43 | center_point = np.array(center_point) 44 | 45 | # Use target_timestamps if provided, otherwise use original timestamps 46 | target_timestamps = timestamps if target_timestamps is None else target_timestamps 47 | 48 | N = len(query_points) 49 | n_timestamps = len(target_timestamps) 50 | 51 | # Initialize output array 52 | new_positions = np.zeros((n_timestamps, N, 2)) 53 | 54 | # Calculate initial angles for each point 55 | initial_vectors = query_points - center_point 56 | initial_angles = np.arctan2(initial_vectors[:, 1], initial_vectors[:, 0]) 57 | radii = np.linalg.norm(initial_vectors, axis=1) 58 | 59 | # Calculate angular velocity between each pair of valleys using original timestamps 60 | valley_times = timestamps[valleys] 61 | angular_velocity = -2 * np.pi / 3 # negative for clockwise rotation 62 | 63 | # For each target timestamp, calculate the angle and new position 64 | for i, t in enumerate(target_timestamps): 65 | # Find the appropriate valley interval using original valley times 66 | if t < valley_times[0]: 67 | # Before first valley - interpolate from start 68 | delta_t = t - timestamps[0] 69 | fraction = delta_t / (valley_times[0] - timestamps[0]) 70 | angle_change = fraction * angular_velocity 71 | elif t > valley_times[-1]: 72 | # After last valley - extrapolate from last interval 73 | delta_t = t - valley_times[-1] 74 | last_interval = valley_times[-1] - valley_times[-2] 75 | angle_change = angular_velocity * (len(valleys) - 1 + delta_t / last_interval) 76 | else: 77 | # Between valleys - find appropriate interval and interpolate 78 | valley_idx = np.searchsorted(valley_times, t) - 1 79 | valley_idx = max(0, valley_idx) 80 | delta_t = t - valley_times[valley_idx] 81 | interval = valley_times[valley_idx + 1] - valley_times[valley_idx] 82 
| fraction = delta_t / interval 83 | angle_change = angular_velocity * (valley_idx + fraction) 84 | 85 | # Calculate new angles and positions 86 | new_angles = initial_angles + angle_change 87 | 88 | # Convert polar coordinates back to Cartesian 89 | new_positions[i, :, 0] = center_point[0] + radii * np.cos(new_angles) 90 | new_positions[i, :, 1] = center_point[1] + radii * np.sin(new_angles) 91 | 92 | return new_positions 93 | 94 | if __name__ == '__main__': 95 | data_dir = Path('data/e2d2/231025_110210_fidget5_high_exposure') 96 | output_dir = Path('output/e2d2_gt/231025_110210_fidget5_high_exposure') 97 | 98 | data_path = data_dir / 'seq.h5' 99 | sequence_name = '231025_110210_fidget5_high_exposure' 100 | t_start = 3.3961115 101 | duration = 0.5 102 | t_delta = 0.001 103 | timestamps = np.arange(t_start, t_start + duration, t_delta) 104 | timestamps = (timestamps * 1e6).astype(int) 105 | N_events = 20000 106 | image_shape = (480, 640) 107 | converter = VoxelGrid(image_shape, num_bins=1) 108 | threshold_l2_norm = 300 109 | mask_path = data_dir / 'mask00004.png' 110 | query_stride = 40 111 | center_point = np.array([369, 229]) 112 | t_delta_gt = 0.0033 113 | timestamps_gt = np.arange(t_start, t_start + duration, t_delta_gt) 114 | 115 | output_dir = output_dir / f'{str(int(1e6 * t_delta_gt)).zfill(8)}' 116 | os.makedirs(output_dir, exist_ok=True) 117 | #os.makedirs(output_dir / 'histograms', exist_ok=True) 118 | 119 | # List to store L2 norms 120 | l2_norms = [] 121 | first_frame = None 122 | 123 | with h5py.File(data_path, 'r') as f: 124 | ev_idx = np.searchsorted(f['t'][:], timestamps) - 1 125 | ev_idx_start = ev_idx - N_events // 2 126 | ev_idx_end = ev_idx + N_events // 2 127 | 128 | vid = [] 129 | 130 | for i_start, i_end, t in tqdm(zip(ev_idx_start, ev_idx_end, timestamps), total=len(timestamps)): 131 | events = np.stack([f['y'][i_start:i_end], 132 | f['x'][i_start:i_end], 133 | f['t'][i_start:i_end], 134 | f['p'][i_start:i_end]], axis=-1) 135 | 136 | repr = converter(events) 137 | repr = repr[0] 138 | 139 | if first_frame is None: 140 | first_frame = repr.copy() 141 | 142 | l2_norm = np.sqrt(np.sum((repr - first_frame) ** 2)) 143 | l2_norms.append(l2_norm) 144 | 145 | if repr.max() != repr.min(): 146 | repr_norm = ((repr - repr.min()) / (repr.max() - repr.min()) * 255).astype(np.uint8) 147 | else: 148 | repr_norm = np.zeros_like(repr, dtype=np.uint8) 149 | 150 | # Convert to PIL Image 151 | img = Image.fromarray(repr_norm) 152 | 153 | # Add timestamp to image 154 | draw = ImageDraw.Draw(img) 155 | try: 156 | font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 20) 157 | except: 158 | font = ImageFont.load_default() 159 | 160 | timestamp_sec = t / 1e6 # Convert to seconds for display 161 | draw.text((10, 10), f"{timestamp_sec:.3f}s", font=font, fill=255) 162 | 163 | frame_with_timestamp = np.array(img) 164 | vid.append(frame_with_timestamp) 165 | 166 | #filename = f"{t:012d}.png" 167 | #filepath = output_dir / 'histograms' / filename 168 | #img.save(filepath) 169 | 170 | video = np.stack(vid, axis=0) # Shape: [T, H, W] 171 | video = np.float32(video) 172 | video = np.stack([video, video, video], axis=1) # Shape: [T, 3, H, W] 173 | video = 255 * (video - video.min()) / (video.max() - video.min()) 174 | 175 | # Convert to numpy array and invert the signal 176 | l2_norms = np.array(l2_norms) 177 | inverted_signal = -l2_norms 178 | 179 | # Find valleys (peaks in the inverted signal) 180 | valleys, _ = find_peaks(inverted_signal, 181 | 
height=(-threshold_l2_norm, 0), 182 | distance=10, 183 | prominence=10) 184 | 185 | valleys = np.concatenate([np.zeros(1, dtype=valleys.dtype), valleys]) 186 | 187 | # Plot detected minima with scientific formatting 188 | plt.figure(figsize=(12, 6)) 189 | timestamps_sec = timestamps / 1e6 190 | plt.rcParams.update({ 191 | 'font.family': 'DejaVu Serif', 192 | 'font.size': 14, 193 | 'axes.labelsize': 16, 194 | 'axes.titlesize': 16, 195 | 'xtick.labelsize': 14, 196 | 'ytick.labelsize': 14, 197 | 'legend.fontsize': 14 198 | }) 199 | 200 | plt.plot(timestamps_sec, l2_norms, label='L2 Norm') 201 | plt.plot(timestamps_sec[valleys], l2_norms[valleys], 'r*', label='Detected Valleys') 202 | plt.axhline(y=threshold_l2_norm, color='r', linestyle='--', 203 | label=f'Threshold ({threshold_l2_norm})') 204 | plt.xlabel('Time (s)') 205 | plt.ylabel('L2 Norm') 206 | plt.grid(True) 207 | plt.legend() 208 | plt.gca().xaxis.set_major_formatter(plt.FormatStrFormatter('%.3f')) 209 | plt.xticks(rotation=45) 210 | plt.tight_layout() 211 | 212 | plt.savefig(output_dir / 'l2_norms.pdf', format='pdf', bbox_inches='tight') 213 | plt.close() 214 | 215 | # Make queries 216 | height, width = image_shape 217 | query_xy = make_grid(height, width, stride=query_stride) 218 | num_queries = query_xy.shape[0] 219 | 220 | if mask_path is not None: 221 | segm_mask = read_binary_mask(mask_path) 222 | query_x = query_xy[:, 0] 223 | query_y = query_xy[:, 1] 224 | segm_mask = segm_mask[query_y, query_x] 225 | query_xy = query_xy[segm_mask == 1] 226 | 227 | # Interpolate positions for visualization 228 | new_positions = interpolate_wheel_points(query_xy, center_point, timestamps, valleys, t_start) 229 | 230 | # Calculate positions at GT timestamps 231 | timestamps_gt_us = (timestamps_gt * 1e6).astype(int) # Convert to same units as timestamps 232 | new_positions_gt = interpolate_wheel_points(query_xy, center_point, timestamps, valleys, t_start, 233 | target_timestamps=timestamps_gt_us) 234 | np.save(output_dir / 'gt_positions.npy', new_positions_gt) 235 | np.save(output_dir / 'gt_timestamps.npy', timestamps_gt_us) 236 | np.save(output_dir / 'queries.npy', query_xy) 237 | 238 | viz = Visualizer(save_dir=str(output_dir), 239 | pad_value=0, 240 | linewidth=2, 241 | tracks_leave_trace=-1, 242 | show_first_frame=0) 243 | rgbs = viz.visualize( 244 | torch.from_numpy(video[None]), 245 | torch.from_numpy(new_positions[None]) 246 | ) 247 | 248 | print('Done.') -------------------------------------------------------------------------------- /scripts/create_evimo2_track_gt.py: -------------------------------------------------------------------------------- 1 | # This script calculates the ground truth point tracks for EVIMO datasets. 
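# Example invocation (the EVIMO2 directory layout under --data_root is assumed, not prescribed here):
#   python scripts/create_evimo2_track_gt.py --data_root data/evimo2 \
#       --config config/misc/evimo2/val_samples.csv --overwrite
# For every sequence listed in the 'name' column of the CSV, a dataset_tracks.h5 file is written
# next to that sequence's dataset_info.npz, containing the datasets 'tracks', 'occlusions',
# 'initial_points', 'initial_masks' and 'times'. It can be read back with, e.g.:
#   with h5py.File(sequence_dir / 'dataset_tracks.h5', 'r') as f:
#       tracks, occlusions = f['tracks'][:], f['occlusions'][:]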
2 | import argparse 3 | import multiprocessing 4 | import os 5 | from pathlib import Path 6 | os.environ['OPENBLAS_NUM_THREADS'] = '1' 7 | from multiprocessing import Pool 8 | import cv2 9 | import numpy as np 10 | from scipy.spatial.transform import Rotation 11 | from scipy.spatial.transform import Slerp 12 | from scipy.linalg import logm 13 | from tqdm import tqdm 14 | import h5py 15 | import multiprocessing.pool as mpp 16 | import yaml 17 | import pandas as pd 18 | 19 | def istarmap(self, func, iterable, chunksize=1): 20 | self._check_running() 21 | if chunksize < 1: 22 | raise ValueError( 23 | "Chunksize must be 1+, not {0:n}".format( 24 | chunksize)) 25 | 26 | task_batches = mpp.Pool._get_tasks(func, iterable, chunksize) 27 | result = mpp.IMapIterator(self) 28 | self._taskqueue.put( 29 | ( 30 | self._guarded_task_generation(result._job, 31 | mpp.starmapstar, 32 | task_batches), 33 | result._set_length 34 | )) 35 | return (item for chunk in result for item in chunk) 36 | 37 | mpp.Pool.istarmap = istarmap 38 | 39 | def interpolate_pose(t, pose): 40 | right_i = np.searchsorted(pose[:, 0], t) 41 | if right_i==pose.shape[0]: 42 | return None 43 | if right_i==0: 44 | return None 45 | 46 | left_t = pose[right_i-1, 0] 47 | right_t = pose[right_i, 0] 48 | 49 | alpha = (t - left_t) / (right_t - left_t) 50 | if alpha>1: 51 | return None 52 | elif alpha < 0: 53 | return None 54 | 55 | left_position = pose[right_i - 1, 1:4] 56 | right_position = pose[right_i, 1:4] 57 | 58 | position_interp = alpha * (right_position - left_position) + left_position 59 | 60 | left_right_rot_stack = Rotation.from_quat((pose[right_i - 1, 4:8], 61 | pose[right_i, 4:8])) 62 | 63 | slerp = Slerp((0, 1), left_right_rot_stack) 64 | R_interp = slerp(alpha) 65 | 66 | return np.array([t,] + list(position_interp) + list(R_interp.as_quat())) 67 | 68 | def apply_transform(T_cb, T_ba): 69 | R_ba = Rotation.from_quat(T_ba[4:8]) 70 | t_ba = T_ba[1:4] 71 | 72 | R_cb = Rotation.from_quat(T_cb[4:8]) 73 | t_cb = T_cb[1:4] 74 | 75 | R_ca = R_cb * R_ba 76 | t_ca = R_cb.as_matrix() @ t_ba + t_cb 77 | return np.array([T_ba[0],] + list(t_ca) + list(R_ca.as_quat())) 78 | 79 | def inv_transform(T_ba): 80 | R_ba = Rotation.from_quat(T_ba[4:8]) 81 | t_ba = T_ba[1:4] 82 | 83 | R_ab = R_ba.inv() 84 | t_ab = -R_ba.inv().as_matrix() @ t_ba 85 | 86 | return np.array([T_ba[0],] + list(t_ab) + list(R_ab.as_quat())) 87 | 88 | def project_points_radtan(points, 89 | fx, fy, cx, cy, 90 | k1, k2, p1, p2): 91 | x_ = np.divide(points[:, :, 0], points[:, :, 2], out=np.zeros_like(points[:, :, 0]), where=points[:, :, 2]!=0) 92 | y_ = np.divide(points[:, :, 1], points[:, :, 2], out=np.zeros_like(points[:, :, 1]), where=points[:, :, 2]!=0) 93 | 94 | r2 = np.square(x_) + np.square(y_) 95 | r4 = np.square(r2) 96 | 97 | dist = (1.0 + k1 * r2 + k2 * r4) 98 | 99 | x__ = x_ * dist + 2.0 * p1 * x_ * y_ + p2 * (r2 + 2.0 * x_ * x_) 100 | y__ = y_ * dist + 2.0 * p2 * x_ * y_ + p1 * (r2 + 2.0 * y_ * y_) 101 | 102 | u = fx * x__ + cx 103 | v = fy * y__ + cy 104 | 105 | return u, v 106 | 107 | def get_all_poses(meta): 108 | vicon_pose_samples = len(meta['full_trajectory']) 109 | 110 | poses = {} 111 | key_i = {} 112 | for key in meta['full_trajectory'][0].keys(): 113 | if key == 'id' or key == 'ts' or key == 'gt_frame': 114 | continue 115 | poses[key] = np.zeros((vicon_pose_samples, 1+3+4)) 116 | key_i[key] = 0 117 | 118 | for all_pose in meta['full_trajectory']: 119 | for key in poses.keys(): 120 | if key == 'id' or key == 'ts' or key == 'gt_frame': 121 | continue 122 | 123 
| if key in all_pose: 124 | i = key_i[key] 125 | poses[key][i, 0] = all_pose['ts'] 126 | poses[key][i, 1] = all_pose[key]['pos']['t']['x'] 127 | poses[key][i, 2] = all_pose[key]['pos']['t']['y'] 128 | poses[key][i, 3] = all_pose[key]['pos']['t']['z'] 129 | poses[key][i, 4] = all_pose[key]['pos']['q']['x'] 130 | poses[key][i, 5] = all_pose[key]['pos']['q']['y'] 131 | poses[key][i, 6] = all_pose[key]['pos']['q']['z'] 132 | poses[key][i, 7] = all_pose[key]['pos']['q']['w'] 133 | key_i[key] += 1 134 | 135 | for key in poses.keys(): 136 | poses[key] = poses[key][:key_i[key], :] 137 | 138 | return poses 139 | 140 | def get_intrinsics(meta): 141 | meta_meta = meta['meta'] 142 | K = np.array(((meta_meta['fx'], 0, meta_meta['cx']), 143 | ( 0, meta_meta['fy'], meta_meta['cy']), 144 | ( 0, 0, 1))) 145 | 146 | dist_coeffs = np.array((meta_meta['k1'], 147 | meta_meta['k2'], 148 | meta_meta['p1'], 149 | meta_meta['p2'])) 150 | 151 | return K, dist_coeffs 152 | 153 | def load_data(file): 154 | meta = np.load(Path(file) / 'dataset_info.npz', allow_pickle=True)['meta'].item() 155 | depth = np.load(Path(file) / 'dataset_depth.npz') 156 | mask = np.load(Path(file) / 'dataset_mask.npz') 157 | return meta, depth, mask 158 | 159 | def convert(file, overwrite=False, max_m_per_s=9.0, max_norm_deg_per_s=6.25*360): 160 | cv2.setNumThreads(1) 161 | 162 | h5_tracks_file_name = Path(file) / 'dataset_tracks.h5' 163 | 164 | if not overwrite and h5_tracks_file_name.exists(): 165 | print(f'skipping {file} because {h5_tracks_file_name} exists') 166 | return 167 | 168 | # Load data 169 | meta, depth, mask = load_data(file) 170 | all_poses = get_all_poses(meta) 171 | K, dist_coeffs = get_intrinsics(meta) 172 | 173 | # Get depth shape for map initialization 174 | first_depth_key = 'depth_' + str(0).rjust(10, '0') 175 | depth_shape = depth[first_depth_key].shape 176 | 177 | # Initialize undistortion maps 178 | map1, map2 = cv2.initInverseRectificationMap( 179 | K, 180 | dist_coeffs, 181 | np.eye(3), 182 | np.eye(3), 183 | (depth_shape[1], depth_shape[0]), 184 | cv2.CV_32FC1) 185 | 186 | # Get info from initial frame 187 | initial_frame_idx = 1 # For frame 0 there are sometimes not previous poses so we start from frame 1 188 | first_frame_id = meta['frames'][initial_frame_idx]['id'] # Use second frame 189 | first_frame_info = meta['frames'][initial_frame_idx] 190 | first_depth_key = 'depth_' + str(initial_frame_idx).rjust(10, '0') # Use depth from second frame 191 | first_mask_key = 'mask_' + str(initial_frame_idx).rjust(10, '0') # Use mask from second frame 192 | 193 | depth_frame = depth[first_depth_key].astype(np.float32) / 1000.0 # convert to meters 194 | mask_frame = mask[first_mask_key] 195 | 196 | # Initialize points to track (e.g., all points with valid depth) 197 | valid_points = np.where(depth_frame > 0) # [y, x] 198 | valid_points = (valid_points[0][::1000], valid_points[1][::1000]) # for debugging subsample a little bit 199 | 200 | initial_points = np.stack([valid_points[1], valid_points[0]], axis=1) # nx2 array of x,y coords 201 | initial_depths = depth_frame[valid_points] 202 | initial_masks = mask_frame[valid_points] 203 | 204 | # Initialize storage for tracks 205 | num_frames = len(meta['frames']) 206 | num_points = len(initial_points) 207 | tracks = np.full((num_frames, num_points, 2), np.nan, dtype=np.float32) # store x,y coordinates 208 | occlusions = np.ones((num_frames, num_points), dtype=bool) # True means occluded 209 | times = np.full(num_frames, np.nan, dtype=np.float32) 210 | 211 | # Store initial 
positions at index 0 212 | tracks[initial_frame_idx] = initial_points 213 | occlusions[initial_frame_idx] = False 214 | times[initial_frame_idx] = meta['frames'][initial_frame_idx]['ts'] 215 | 216 | # Unproject initial points to 3D 217 | Z_m = initial_depths 218 | X_m = map1[valid_points] * Z_m 219 | Y_m = map2[valid_points] * Z_m 220 | initial_XYZ = np.stack((X_m, Y_m, Z_m), axis=1) 221 | 222 | depth_keys = sorted([k for k in depth.files if k.startswith('depth_')]) 223 | frame_numbers = [int(k.split('_')[1]) for k in depth_keys] 224 | 225 | # Get initial poses for each object 226 | initial_time = first_frame_info['cam']['ts'] 227 | initial_poses = {} 228 | for key in all_poses: 229 | if key != 'cam': 230 | initial_poses[key] = interpolate_pose(initial_time, all_poses[key]) 231 | 232 | # Track through all frames starting from the third frame 233 | found_gap = False 234 | for frame_idx, frame_info in enumerate(meta['frames'][initial_frame_idx + 1:]): 235 | frame_idx += (initial_frame_idx + 1) # Start from index 1 since we used frame 1 as our initial frame 236 | 237 | if frame_numbers[frame_idx] != frame_numbers[frame_idx - 1] + 1: 238 | print(f"Gap found in depth maps between frames {frame_numbers[frame_idx-1]} and {frame_numbers[frame_idx]}") 239 | found_gap = True 240 | break 241 | 242 | times[frame_idx] = frame_info['ts'] 243 | frame_id = frame_info['id'] - first_frame_id 244 | current_depth_key = 'depth_' + str(frame_id).rjust(10, '0') 245 | current_mask_key = 'mask_' + str(frame_id).rjust(10, '0') 246 | 247 | current_depth = depth[current_depth_key].astype(np.float32) / 1000.0 248 | current_mask = mask[current_mask_key] 249 | 250 | # Get current poses for each object 251 | frame_time = frame_info['cam']['ts'] 252 | frame_poses = {} 253 | for key in all_poses: 254 | if key != 'cam': 255 | frame_poses[key] = interpolate_pose(frame_time, all_poses[key]) 256 | 257 | # Transform and project points 258 | for point_idx in range(num_points): 259 | object_id = initial_masks[point_idx] // 1000 260 | object_key = str(object_id) 261 | 262 | if object_key not in frame_poses or object_key not in initial_poses: 263 | continue 264 | 265 | # Get initial and current object poses 266 | T_c1o = initial_poses[object_key] 267 | T_c2o = frame_poses[object_key] 268 | 269 | if T_c1o is None or T_c2o is None: 270 | continue 271 | 272 | # Calculate relative transform from initial camera frame to current camera frame 273 | T_c2c1 = apply_transform(T_c2o, inv_transform(T_c1o)) 274 | 275 | # Check velocity to detect potential tracking loss 276 | dt = frame_time - initial_time 277 | if dt > 0: 278 | v = np.linalg.norm(T_c2c1[1:4]) / dt 279 | R_matrix = Rotation.from_quat(T_c2c1[4:8]).as_matrix() 280 | w_matrix = logm(R_matrix) / dt 281 | w = np.array([-w_matrix[1,2], w_matrix[0, 2], -w_matrix[0, 1]]) 282 | w_deg = np.linalg.norm(w) * 180 / np.pi 283 | 284 | if v > max_m_per_s or w_deg > max_norm_deg_per_s: 285 | continue 286 | 287 | # Transform point using relative transform 288 | point_3d = initial_XYZ[point_idx] 289 | R = Rotation.from_quat(T_c2c1[4:8]).as_matrix() 290 | t = T_c2c1[1:4] 291 | transformed_point = (R @ point_3d) + t 292 | 293 | # Project to image plane 294 | px, py = project_points_radtan(transformed_point[None, None], 295 | K[0, 0], K[1,1], K[0, 2], K[1, 2], 296 | *dist_coeffs) 297 | px = px[0, 0] 298 | py = py[0, 0] 299 | 300 | is_occluded = False 301 | 302 | # Check if point projects outside image bounds 303 | if px < 0 or px >= current_depth.shape[1] or py < 0 or py >= current_depth.shape[0]: 
304 | is_occluded = True 305 | else: 306 | px_int, py_int = int(px), int(py) 307 | 308 | # Check if point is occluded using depth test 309 | actual_depth = current_depth[py_int, px_int] 310 | if actual_depth > 0 and abs(actual_depth - transformed_point[2]) > 0.03: # 3cm threshold 311 | is_occluded = True 312 | 313 | # Check if point projects onto the correct object 314 | if current_mask[py_int, px_int] != initial_masks[point_idx]: 315 | is_occluded = True 316 | 317 | tracks[frame_idx, point_idx] = [px, py] 318 | occlusions[frame_idx, point_idx] = is_occluded 319 | 320 | # Find the last valid timestamp if we found a gap 321 | if found_gap: 322 | valid_indices = ~np.isnan(times) 323 | last_valid_idx = np.where(valid_indices)[0][-1] 324 | else: 325 | last_valid_idx = num_frames - 1 326 | 327 | times = times[initial_frame_idx:last_valid_idx + 1] 328 | tracks = tracks[initial_frame_idx:last_valid_idx + 1] 329 | occlusions = occlusions[initial_frame_idx:last_valid_idx + 1] 330 | 331 | with h5py.File(h5_tracks_file_name, 'w') as f: 332 | f.create_dataset('tracks', data=tracks, compression='gzip', compression_opts=9) 333 | f.create_dataset('occlusions', data=occlusions, compression='gzip', compression_opts=9) 334 | f.create_dataset('initial_points', data=initial_points, compression='gzip', compression_opts=9) 335 | f.create_dataset('initial_masks', data=initial_masks, compression='gzip', compression_opts=9) 336 | f.create_dataset('times', data=times, compression='gzip', compression_opts=9) 337 | 338 | def process_with_error_handling(args): 339 | try: 340 | convert(*args) 341 | return True 342 | except Exception as e: 343 | print(f"Error processing {args[0]}: {str(e)}") 344 | return False 345 | 346 | if __name__ == '__main__': 347 | parser = argparse.ArgumentParser(epilog='Calculates optical flow from EVIMO datasets.') 348 | parser.add_argument('--dt', dest='dt', type=float, default=0.01, 349 | help='dt for flow approximation') 350 | parser.add_argument('--overwrite', dest='overwrite', action='store_true', 351 | help='Overwrite existing output files') 352 | parser.add_argument('--max_m_per_s', dest='max_m_per_s', type=float, default=9.0, 353 | help='Maximum meters per second of linear velocity') 354 | parser.add_argument('--max_norm_deg_per_s', dest='max_norm_deg_per_s', type=float, default=6.25*360, 355 | help='Maximum normed degrees per second of angular velocity') 356 | parser.add_argument('--data_root', type=str, required=True, 357 | help='Root directory for dataset paths') 358 | parser.add_argument('--config', type=str, default='config/misc/evimo2/val_samples.csv', 359 | help='YAML config file with dataset paths') 360 | parser.add_argument('--split', type=str, default='val_samples', 361 | help='Dataset split to process from YAML') 362 | 363 | args = parser.parse_args() 364 | files = list(pd.read_csv(args.config)['name']) 365 | 366 | p_args_list = [[ 367 | Path(args.data_root) / f, 368 | args.overwrite, 369 | args.max_m_per_s, 370 | args.max_norm_deg_per_s 371 | ] for f in files] 372 | 373 | with Pool(multiprocessing.cpu_count()) as p: 374 | results = list(tqdm( 375 | p.imap_unordered(process_with_error_handling, p_args_list), 376 | total=len(p_args_list), 377 | desc='Sequences' 378 | )) 379 | 380 | successes = sum(results) 381 | print(f"Completed {successes}/{len(p_args_list)} sequences") -------------------------------------------------------------------------------- /scripts/demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from 
pathlib import Path 3 | 4 | from tqdm import tqdm 5 | import numpy as np 6 | import torch 7 | 8 | sys.path.append(str(Path(__file__).parent.parent)) 9 | from src.representations import MixedDensityEventStack 10 | from src.model.etap.model import Etap 11 | from src.utils import Visualizer 12 | 13 | def normalize_and_expand_channels(image): 14 | if not isinstance(image, torch.Tensor): 15 | raise TypeError("Input must be a torch tensor") 16 | 17 | *batch_dims, height, width = image.shape 18 | 19 | if len(image.shape) < 2: 20 | raise ValueError("Input tensor must have at least shape (..., height, width)") 21 | 22 | image_flat = image.view(-1, height, width) 23 | 24 | min_val = image_flat.min() 25 | max_val = image_flat.max() 26 | range_val = max_val - min_val 27 | range_val[range_val == 0] = 1 28 | 29 | image_normalized = (image_flat - min_val) / range_val * 255 30 | image_rgb_flat = image_normalized.to(torch.uint8).unsqueeze(1).repeat(1, 3, 1, 1) 31 | image_rgb = image_rgb_flat.view(*batch_dims, 3, height, width) 32 | 33 | return image_rgb 34 | 35 | if __name__ == '__main__': 36 | device = 'cpu' 37 | data_dir = Path('data/demo_example') 38 | ckpt_path = Path('weights/ETAP_v1_cvpr25.pth') 39 | output_dir = Path(f'output/{data_dir.name}') 40 | num_bins = 10 41 | num_events = 60000 42 | height, width = 480, 640 43 | t_start = 1/30 # seconds 44 | t_end = 1.5 # seconds 45 | t_delta = 1/60 # seconds 46 | 47 | # Object to convert raw event data into grid representations 48 | converter = MixedDensityEventStack( 49 | image_shape=(height, width), 50 | num_stacks=num_bins, 51 | interpolation='bilinear', 52 | channel_overlap=True, 53 | centered_channels=False 54 | ) 55 | 56 | # Load the model 57 | tracker = Etap(num_in_channels=num_bins, stride=4, window_len=8) 58 | weights = torch.load(ckpt_path, map_location='cpu', weights_only=True) 59 | tracker.load_state_dict(weights) 60 | tracker = tracker.to(device) 61 | tracker.eval() 62 | 63 | # Let's choose some timestamps at which we create the frame representations 64 | tracking_timestamps = np.arange(t_start, t_end, t_delta) # seconds 65 | 66 | # Load raw event data 67 | xy = np.load(data_dir / 'dataset_events_xy.npy', mmap_mode='r') 68 | p = np.load(data_dir / 'dataset_events_p.npy', mmap_mode='r') 69 | t = np.load(data_dir / 'dataset_events_t.npy', mmap_mode='r') 70 | 71 | assert t_start > t[0], "Start time must be greater than the first event timestamp" 72 | assert t_end < t[-1], "End time must be less than the last event timestamp" 73 | assert t_delta > 0, "Time delta must be greater than zero" 74 | assert t_start < t_end, "Start time must be less than end time" 75 | assert xy.shape[0] == p.shape[0] == t.shape[0], "Event data arrays must have the same length" 76 | 77 | event_indices = np.searchsorted(t, tracking_timestamps) 78 | event_representations = [] 79 | 80 | # At each tracking timestep, we take the last num_events events and convert 81 | # them into a grid representation. 82 | for i_end in tqdm(event_indices, desc='Creating grid representations'): 83 | i_start = max(i_end - num_events, 0) 84 | 85 | events = np.stack([xy[i_start:i_end, 1], 86 | xy[i_start:i_end, 0], 87 | t[i_start:i_end], 88 | p[i_start:i_end]], axis=1) 89 | ev_repr = converter(events) 90 | event_representations.append(ev_repr) 91 | 92 | voxels = np.stack(event_representations, axis=0) 93 | voxels = torch.from_numpy(voxels)[None].float().to(device) 94 | 95 | # Now let's determine some queries, meaning the initial positions 96 | # of the points to track. 
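# Each query row has the form (t, x, y): the index of the event representation at which tracking
# starts, followed by the pixel position at that index. Here all points are queried at t = 0, but
# the tracker also accepts later start indices (the EventKubric data module, for instance, sets t
# to the first frame in which a point is visible).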
97 | x = np.arange(0, width, 32) 98 | y = np.arange(0, height, 32) 99 | X, Y = np.meshgrid(x, y) 100 | grid = np.stack([X.flatten(), Y.flatten()], axis=1) 101 | query_xy = torch.from_numpy(grid).float().to(device) 102 | num_queries = query_xy.shape[0] 103 | # We track all points from the beginning (t=0) 104 | query_t = torch.zeros(num_queries, dtype=torch.int64, device=device) 105 | queries = torch.cat([query_t[:, None], query_xy], dim=1) 106 | queries = queries[None].to(device) 107 | 108 | with torch.no_grad(): 109 | result = tracker(voxels, queries, iters=6) 110 | predictions, visibility = result['coords_predicted'], result['vis_predicted'] 111 | 112 | visibility = visibility > 0.8 113 | 114 | # Visualization 115 | projection = voxels.sum(2).cpu() 116 | lower_bound = torch.tensor(np.percentile(projection.cpu().numpy(), 2)) 117 | upper_bound = torch.tensor(np.percentile(projection.cpu().numpy(), 98)) 118 | projection_clipped = torch.clamp(projection, lower_bound, upper_bound) 119 | video = normalize_and_expand_channels(projection_clipped) 120 | 121 | viz = Visualizer(save_dir=output_dir, tracks_leave_trace=-1) 122 | viz.visualize( 123 | video=video, 124 | tracks=predictions, 125 | visibility=visibility, 126 | filename=f"pred", 127 | ) 128 | print('Done.') -------------------------------------------------------------------------------- /scripts/download_eds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Script to download and extract EDS dataset sequences 4 | # Usage: ./download_eds_data.sh [output_directory] 5 | 6 | set -e # Exit on error 7 | 8 | OUTPUT_DIR="$(pwd)/data/eds" 9 | if [ "$1" != "" ]; then 10 | OUTPUT_DIR="$1" 11 | fi 12 | 13 | mkdir -p "$OUTPUT_DIR" 14 | echo "Downloading datasets to: $OUTPUT_DIR" 15 | 16 | BASE_URL="https://download.ifi.uzh.ch/rpg/eds/dataset" 17 | 18 | SEQUENCES=( 19 | "01_peanuts_light" 20 | "02_rocket_earth_light" 21 | "08_peanuts_running" 22 | "14_ziggy_in_the_arena" 23 | ) 24 | 25 | for seq in "${SEQUENCES[@]}"; do 26 | echo "=======================================================" 27 | echo "Processing sequence: $seq" 28 | 29 | mkdir -p "$OUTPUT_DIR/$seq" 30 | 31 | URL="$BASE_URL/$seq/$seq.tgz" 32 | TGZ_FILE="$OUTPUT_DIR/$seq/$seq.tgz" 33 | 34 | echo "Downloading from: $URL" 35 | wget -c "$URL" -O "$TGZ_FILE" 36 | 37 | echo "Extracting..." 38 | tar -xzf "$TGZ_FILE" -C "$OUTPUT_DIR/$seq" 39 | 40 | echo "Removing archive..." 41 | rm "$TGZ_FILE" 42 | 43 | echo "Done with $seq" 44 | echo "" 45 | done 46 | 47 | echo "All sequences have been downloaded and extracted." 
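# Note: for evaluation, scripts/benchmark_feature_tracking.py reads the ground-truth tracks of
# these four sequences from config/misc/eds/gt_tracks/<sequence>.gt.txt, which are bundled with
# this repository.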
48 | echo "Data is available in: $OUTPUT_DIR" 49 | 50 | echo "=======================================================" 51 | echo "Summary of downloaded data:" 52 | for seq in "${SEQUENCES[@]}"; do 53 | echo "$seq:" 54 | ls -la "$OUTPUT_DIR/$seq" | grep -v "^total" 55 | echo "" 56 | done -------------------------------------------------------------------------------- /scripts/inference_offline.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | import sys 4 | 5 | import yaml 6 | from tqdm import tqdm 7 | import torch 8 | import numpy as np 9 | import pandas as pd 10 | 11 | sys.path.append(str(Path(__file__).parent.parent)) 12 | from src.data.modules import DataModuleFactory 13 | from src.model.etap.model import Etap 14 | from src.utils import Visualizer, compute_tapvid_metrics 15 | 16 | 17 | def normalize_and_expand_channels(image): 18 | if not isinstance(image, torch.Tensor): 19 | raise TypeError("Input must be a torch tensor") 20 | 21 | *batch_dims, height, width = image.shape 22 | 23 | if len(image.shape) < 2: 24 | raise ValueError("Input tensor must have at least shape (..., height, width)") 25 | 26 | image_flat = image.view(-1, height, width) 27 | 28 | min_val = image_flat.min() 29 | max_val = image_flat.max() 30 | range_val = max_val - min_val 31 | range_val[range_val == 0] = 1 32 | 33 | image_normalized = (image_flat - min_val) / range_val * 255 34 | image_rgb_flat = image_normalized.to(torch.uint8).unsqueeze(1).repeat(1, 3, 1, 1) 35 | image_rgb = image_rgb_flat.view(*batch_dims, 3, height, width) 36 | 37 | return image_rgb 38 | 39 | def main(): 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument('--config', type=str, default='config/exe/test_event_kubric/debug.yaml') 42 | parser.add_argument('--device', type=str, default='cpu') 43 | args = parser.parse_args() 44 | 45 | # Load and process config 46 | with open(args.config, "r") as f: 47 | config = yaml.safe_load(f) 48 | 49 | output_dir = config['common'].get('output_dir', 'output/inference') 50 | output_dir = Path(output_dir) / config['common']['exp_name'] 51 | output_dir.mkdir(parents=True, exist_ok=True) 52 | checkpoint_path = config['common'].get('checkpoint') 53 | 54 | data_module = DataModuleFactory.create(config['data']) 55 | data_module.prepare_data() 56 | test_set = data_module.test_dataset 57 | 58 | model_config = config['model'] 59 | model_config['model_resolution'] = (512, 512) 60 | tracker = Etap(**model_config) 61 | weights = torch.load(checkpoint_path, map_location='cpu', weights_only=True) 62 | tracker.load_state_dict(weights) 63 | tracker = tracker.to(args.device) 64 | tracker.eval() 65 | 66 | viz = Visualizer(save_dir=output_dir, tracks_leave_trace=-1) 67 | 68 | seq_names = [] 69 | sequence_metrics = [] 70 | 71 | for i, sample in tqdm(enumerate(test_set), total=len(test_set), desc="Testing"): 72 | voxels = sample.voxels[None].to(args.device) 73 | trajs_g = sample.trajectory[None].to(args.device) 74 | vis_g = sample.visibility[None].to(args.device) 75 | queries = sample.query_points[None].to(args.device) 76 | B, T, C, H, W = voxels.shape 77 | _, _, N, D = trajs_g.shape 78 | 79 | with torch.no_grad(): 80 | result = tracker(voxels, queries, iters=6) 81 | predictions, visibility = result['coords_predicted'], result['vis_predicted'] 82 | 83 | if visibility.dtype != torch.bool: 84 | visibility = visibility > 0.8 85 | 86 | # Visualization 87 | projection = voxels.sum(2).cpu() 88 | lower_bound = 
torch.tensor(np.percentile(projection.cpu().numpy(), 2)) 89 | upper_bound = torch.tensor(np.percentile(projection.cpu().numpy(), 98)) 90 | projection_clipped = torch.clamp(projection, lower_bound, upper_bound) 91 | video = normalize_and_expand_channels(projection_clipped) 92 | 93 | viz.visualize( 94 | video=video, 95 | tracks=predictions, 96 | visibility=visibility, 97 | filename=f"pred_{sample.seq_name}", 98 | ) 99 | viz.visualize( 100 | video=video, 101 | tracks=trajs_g, 102 | visibility=vis_g, 103 | filename=f"gt_{sample.seq_name}", 104 | ) 105 | 106 | # Calculate metrics for this sequence 107 | queries_np = queries.cpu().numpy() 108 | trajs_g_np = trajs_g.cpu().numpy().transpose(0, 2, 1, 3) 109 | gt_occluded_np = ~vis_g.cpu().numpy().transpose(0, 2, 1) 110 | trajs_pred_np = predictions.cpu().numpy().transpose(0, 2, 1, 3) 111 | pred_occluded_np = ~visibility.cpu().numpy().transpose(0, 2, 1) 112 | 113 | seq_metrics = compute_tapvid_metrics( 114 | query_points=queries_np, 115 | gt_occluded=gt_occluded_np, 116 | gt_tracks=trajs_g_np, 117 | pred_occluded=pred_occluded_np, 118 | pred_tracks=trajs_pred_np, 119 | query_mode="first" 120 | ) 121 | 122 | seq_names.append(sample.seq_name) 123 | sequence_metrics.append(seq_metrics) 124 | 125 | data = [] 126 | for seq_name, seq_metrics in zip(seq_names, sequence_metrics): 127 | row = {'sequence': seq_name} 128 | row.update(seq_metrics) 129 | data.append(row) 130 | 131 | avg_metrics = {} 132 | metric_keys = sequence_metrics[0].keys() 133 | for key in metric_keys: 134 | values = [metrics[key] for metrics in sequence_metrics] 135 | avg_metrics[key] = np.mean(values) 136 | 137 | avg_row = {'sequence': 'average'} 138 | avg_row.update(avg_metrics) 139 | data.append(avg_row) 140 | 141 | df = pd.DataFrame(data) 142 | csv_path = output_dir / 'sequence_metrics.csv' 143 | df.to_csv(csv_path, index=False) 144 | 145 | print(f'Metrics saved to: {csv_path}') 146 | print('Done.') 147 | 148 | if __name__ == '__main__': 149 | main() -------------------------------------------------------------------------------- /scripts/inference_online.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import subprocess 3 | from pathlib import Path 4 | import sys 5 | import numpy as np 6 | 7 | import yaml 8 | import torch 9 | from tqdm import tqdm 10 | 11 | sys.path.append(str(Path(__file__).parent.parent)) 12 | 13 | from src.data.modules import DataModuleFactory 14 | from src.model.etap.model import Etap 15 | from src.utils import Visualizer, normalize_and_expand_channels, make_grid 16 | 17 | torch.set_float32_matmul_precision('high') 18 | 19 | 20 | def write_points_to_file(points, timestamps, filepath): 21 | """Write tracking points to a file.""" 22 | T, N, _ = points.shape 23 | 24 | with open(filepath, 'w') as f: 25 | for t in range(T): 26 | for n in range(N): 27 | x, y = points[t, n] 28 | f.write(f"{n} {timestamps[t]:.9f} {x:.9f} {y:.9f}\n") 29 | 30 | 31 | def get_git_commit_hash(): 32 | """Get the current git commit hash.""" 33 | try: 34 | commit_hash = subprocess.check_output(['git', 'rev-parse', 'HEAD'], stderr=subprocess.STDOUT).decode().strip() 35 | return commit_hash 36 | except subprocess.CalledProcessError as e: 37 | print(f"Error obtaining git commit hash: {e.output.decode().strip()}") 38 | return "unknown" 39 | 40 | 41 | def normalize_voxels(voxels): 42 | """Perform channelwise std-mean normalization on voxels.""" 43 | mask = voxels != 0 44 | mean = voxels.sum(dim=(0, 2, 3), keepdim=True) / mask.sum(dim=(0, 2, 
3), keepdim=True) 45 | var = ((voxels - mean)**2 * mask).sum(dim=(0, 2, 3), keepdim=True) / mask.sum(dim=(0, 2, 3), keepdim=True) 46 | std = torch.sqrt(var + 1e-8) 47 | return torch.where(mask, (voxels - mean) / std, voxels) 48 | 49 | 50 | def main(): 51 | parser = argparse.ArgumentParser() 52 | parser.add_argument('--config', type=str, default='config/exe/inference_online/feature_tracking.yaml') 53 | parser.add_argument('--device', type=str, default='cuda:0') 54 | args = parser.parse_args() 55 | 56 | config_path = Path(args.config) 57 | with open(config_path, "r") as f: 58 | config = yaml.safe_load(f) 59 | 60 | project_root = Path(__file__).parent.parent 61 | save_dir = project_root / 'output' / 'inference' / config['common']['exp_name'] 62 | save_dir.mkdir(parents=True, exist_ok=True) 63 | 64 | config['runtime_info'] = { 65 | 'command': ' '.join(sys.argv), 66 | 'git_commit': get_git_commit_hash() 67 | } 68 | config_save_path = save_dir / 'config.yaml' 69 | with open(config_save_path, 'w') as f: 70 | yaml.dump(config, f, default_flow_style=False) 71 | 72 | add_support_points = config['common'].get('add_support_points', False) 73 | 74 | if add_support_points: 75 | support_point_stride = config['common'].get('support_point_stride', 20) 76 | height, width = config['common']['height'], config['common']['width'] 77 | 78 | device = torch.device(args.device) 79 | data_module = DataModuleFactory.create(config['data']) 80 | data_module.prepare_data() 81 | 82 | model_config = config['model'] 83 | model_config['model_resolution'] = (512, 512) 84 | checkpoint_path = Path(config['common']['ckp_path']) 85 | 86 | tracker = Etap(**model_config) 87 | weights = torch.load(checkpoint_path, map_location='cpu', weights_only=True) 88 | tracker.load_state_dict(weights) 89 | tracker = tracker.to(device) 90 | tracker.eval() 91 | 92 | viz = Visualizer(save_dir=save_dir, pad_value=0, linewidth=1, 93 | tracks_leave_trace=-1, show_first_frame=5) 94 | 95 | for dataset in data_module.test_datasets: 96 | sequence_name = dataset.subsequence_name 97 | tracker.init_video_online_processing() 98 | timestamps_s = None 99 | 100 | original_queries = dataset.query_points.to(device) 101 | 102 | if add_support_points: 103 | support_query_xy = torch.from_numpy(make_grid(height, width, stride=support_point_stride)).float().to(device) 104 | support_num_queries = support_query_xy.shape[0] 105 | 106 | support_query_t = torch.zeros(support_num_queries, dtype=torch.int64, device=device) 107 | support_queries = torch.cat([support_query_t[:, None], support_query_xy], dim=1) 108 | 109 | queries = torch.cat([original_queries, support_queries]) 110 | print(f"Added {support_num_queries} support points to {original_queries.shape[0]} original queries") 111 | else: 112 | queries = original_queries 113 | support_num_queries = 0 114 | 115 | event_visus = None 116 | 117 | for sample, start_idx in tqdm(dataset, desc=f'Predicting {sequence_name}'): 118 | assert start_idx == tracker.online_ind 119 | voxels = sample.voxels.to(device) 120 | step = voxels.shape[0] // 2 121 | 122 | if timestamps_s is None: 123 | timestamps_s = sample.timestamps 124 | else: 125 | timestamps_s = torch.cat([timestamps_s, sample.timestamps[-step:]]) 126 | 127 | voxels = normalize_voxels(voxels) 128 | 129 | with torch.no_grad(): 130 | results = tracker( 131 | video=voxels[None], 132 | queries=queries[None], 133 | is_online=True, 134 | iters=6 135 | ) 136 | 137 | coords_predicted = results['coords_predicted'].clone() 138 | vis_logits = results['vis_predicted'] 139 | 140 | # 
Remove support points 141 | if support_num_queries > 0: 142 | coords_predicted = coords_predicted[:, :, :-support_num_queries] 143 | vis_logits = vis_logits[:, :, :-support_num_queries] 144 | 145 | event_visu = normalize_and_expand_channels(voxels.sum(dim=1)) 146 | event_visus = torch.cat([event_visus, event_visu[-step:]]) if event_visus is not None else event_visu 147 | 148 | # Save predictions 149 | output_file = save_dir / f'{sequence_name}.npz' 150 | np.savez( 151 | output_file, 152 | coords_predicted=coords_predicted[0].cpu().numpy(), 153 | vis_logits=vis_logits[0].cpu().numpy(), 154 | timestamps=timestamps_s.cpu().numpy() 155 | ) 156 | 157 | # Save predictions in feature tracking format 158 | output_file = save_dir / f'{sequence_name}.txt' 159 | write_points_to_file( 160 | coords_predicted.cpu().numpy()[0], 161 | timestamps_s, 162 | output_file 163 | ) 164 | 165 | viz.visualize( 166 | event_visus[None], 167 | coords_predicted, 168 | filename=sequence_name 169 | ) 170 | 171 | print('Done.') 172 | 173 | if __name__ == '__main__': 174 | main() -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tub-rip/ETAP/e13875244b45e651c1ebcce1e769ed05ffbf7acd/src/__init__.py -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tub-rip/ETAP/e13875244b45e651c1ebcce1e769ed05ffbf7acd/src/data/__init__.py -------------------------------------------------------------------------------- /src/data/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .feature_tracking_online import FeatureTrackingDataModule 2 | from .e2d2 import E2d2DataModule 3 | from .penn_aviary import PennAviaryDataModule 4 | from .event_kubric import EventKubricDataModule 5 | from .evimo2 import Evimo2DataModule 6 | 7 | class DataModuleFactory: 8 | @staticmethod 9 | def create(data_config): 10 | dataset_name = data_config['dataset_name'] 11 | 12 | if dataset_name == 'feature_tracking_online': 13 | return FeatureTrackingDataModule(**data_config) 14 | elif dataset_name == 'e2d2': 15 | return E2d2DataModule(**data_config) 16 | elif dataset_name == 'penn_aviary': 17 | return PennAviaryDataModule(**data_config) 18 | elif dataset_name == 'event_kubric': 19 | return EventKubricDataModule(**data_config) 20 | elif dataset_name == 'evimo2': 21 | return Evimo2DataModule(**data_config) 22 | else: 23 | raise ValueError("Unsupported dataset_name.") 24 | -------------------------------------------------------------------------------- /src/data/modules/e2d2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from pathlib import Path 4 | import pytorch_lightning as pl 5 | 6 | from ..utils import EtapData 7 | 8 | 9 | class E2d2DataModule(pl.LightningDataModule): 10 | def __init__(self, dataset_name, data_root, preprocessed_name, sequences=None): 11 | super().__init__() 12 | self.save_hyperparameters() 13 | self.test_datasets = [] 14 | 15 | def prepare_data(self): 16 | """Prepare test datasets using preprocessed event representations.""" 17 | if self.hparams.sequences is None: 18 | sequences = [d for d in self.hparams.data_root.iterdir() if d.is_dir()] 19 | else: 20 | sequences = 
[Path(self.hparams.data_root) / seq for seq in self.hparams.sequences] 21 | 22 | for sequence_path in sequences: 23 | if not (sequence_path / 'seq.h5').exists(): 24 | print(f"Warning: seq.h5 not found in {sequence_path}, skipping...") 25 | continue 26 | 27 | self.test_datasets.append(E2D2InferenceDataset( 28 | sequence_path=sequence_path, 29 | preprocessed_name=self.hparams.preprocessed_name, 30 | stride=4, 31 | sliding_window_len=8 32 | )) 33 | 34 | 35 | class E2D2InferenceDataset(torch.utils.data.Dataset): 36 | def __init__( 37 | self, 38 | sequence_path, 39 | preprocessed_name, 40 | stride=4, 41 | sliding_window_len=8 42 | ): 43 | super().__init__() 44 | self.sequence_path = Path(sequence_path) 45 | self.subsequence_name = self.sequence_path.name 46 | self.stride = stride 47 | self.sliding_window_len = sliding_window_len 48 | 49 | self.ev_repr_dir = self.sequence_path / 'event_representations' / preprocessed_name 50 | 51 | if not self.ev_repr_dir.exists(): 52 | raise FileNotFoundError(f"Preprocessed event representations not found in {self.sequence_path}") 53 | 54 | print(f"Using preprocessed event representations from {self.ev_repr_dir}") 55 | 56 | self.ev_repr_paths = sorted(path for path in self.ev_repr_dir.iterdir() 57 | if path.is_file() and path.suffix == '.npy') 58 | 59 | if not self.ev_repr_paths: 60 | raise FileNotFoundError(f"No .npy files found in {self.ev_repr_dir}") 61 | 62 | self.timestamps = np.array([int(path.stem) for path in self.ev_repr_paths]) 63 | 64 | gt_path = self.sequence_path / 'gt_positions.npy' 65 | if gt_path.exists(): 66 | self.gt_tracks = torch.from_numpy(np.load(gt_path)) 67 | else: 68 | self.gt_tracks = None 69 | 70 | # Load query points from file 71 | query_xy = torch.from_numpy(np.load(self.sequence_path / 'queries.npy')) 72 | query_t = torch.zeros(query_xy.shape[0], dtype=torch.int64) 73 | self.query_points = torch.cat([query_t[:, None], query_xy], dim=1).float() 74 | 75 | self.start_indices = np.arange(0, len(self.timestamps) - self.sliding_window_len + 1, self.stride) 76 | 77 | def load_event_representations(self, start_idx, end_idx): 78 | """Load event representations for the given index range.""" 79 | ev_repr = [] 80 | 81 | for i in range(start_idx, end_idx): 82 | ev_repr_path = self.ev_repr_paths[i] 83 | sample = np.load(ev_repr_path) 84 | ev_repr.append(sample) 85 | 86 | ev_repr = torch.from_numpy(np.stack(ev_repr, axis=0)).float() 87 | return ev_repr 88 | 89 | def __len__(self): 90 | """Return the number of start indices.""" 91 | return len(self.start_indices) 92 | 93 | def __getitem__(self, idx): 94 | """Get a data sample for the given index.""" 95 | start_idx = self.start_indices[idx] 96 | end_idx = start_idx + self.sliding_window_len 97 | 98 | ev_repr = self.load_event_representations(start_idx, end_idx) 99 | gt_tracks = self.gt_tracks[start_idx:end_idx] if self.gt_tracks is not None else None 100 | visibility = torch.ones((self.sliding_window_len, gt_tracks.shape[1])) if gt_tracks is not None else None 101 | 102 | sample = EtapData( 103 | voxels=ev_repr, # [T, C, H, W] 104 | rgbs=None, 105 | trajectory=gt_tracks, # [T, N, 2] or None 106 | visibility=visibility, # [T, N] or None 107 | timestamps=torch.from_numpy(self.timestamps[start_idx:end_idx]).float() / 1e6, # [T] 108 | ) 109 | return sample, start_idx -------------------------------------------------------------------------------- /src/data/modules/event_kubric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 
| import h5py 4 | from pathlib import Path 5 | import pytorch_lightning as pl 6 | from ..utils import EtapData 7 | 8 | def calculate_frame_times(num_frames=96, total_time=2.0): 9 | time_step = total_time / (num_frames - 1) 10 | frame_times = np.arange(num_frames) * time_step 11 | return frame_times 12 | 13 | class EventKubricDataModule(pl.LightningDataModule): 14 | def __init__(self, data_root, seq_len, traj_per_sample, 15 | dataset_name, preprocessed_name=None): 16 | super().__init__() 17 | self.save_hyperparameters() 18 | 19 | def prepare_data(self): 20 | test_path = Path(self.hparams.data_root) / 'test' 21 | self.test_dataset = EventKubricDataset( 22 | data_root=test_path, 23 | seq_len=self.hparams.seq_len, 24 | traj_per_sample=self.hparams.traj_per_sample, 25 | preprocessed_name=self.hparams.preprocessed_name, 26 | ) 27 | 28 | class EventKubricDataset(torch.utils.data.Dataset): 29 | def __init__( 30 | self, 31 | data_root, 32 | seq_len=24, 33 | traj_per_sample=768, 34 | sample_vis_1st_frame=False, 35 | preprocessed_name=None, 36 | ): 37 | super(EventKubricDataset, self).__init__() 38 | self.data_root = data_root 39 | self.traj_per_sample = traj_per_sample 40 | self.sample_vis_1st_frame = sample_vis_1st_frame 41 | self.seq_len = seq_len 42 | 43 | self.pad_bounds = [0, 25] 44 | self.resize_lim = [0.75, 1.25] 45 | self.resize_delta = 0.05 46 | self.max_crop_offset = 15 47 | self.preprocessed_name = preprocessed_name 48 | self.indices = np.arange(3, 93, 3)[3:-3] 49 | 50 | data_root_path = Path(self.data_root) 51 | self.samples = [d for d in data_root_path.iterdir() if d.is_dir()] 52 | 53 | # Validate samples 54 | valid_samples = [] 55 | for sample_path in self.samples: 56 | gt_path = sample_path / 'annotations.npy' 57 | if not gt_path.exists(): 58 | continue 59 | gt_data = np.load(str(gt_path), allow_pickle=True).item() 60 | visibility = gt_data['occluded'] 61 | if len(visibility) > self.traj_per_sample: 62 | valid_samples.append(sample_path.name) 63 | 64 | self.samples = valid_samples 65 | print(f"Found {len(self.samples)} valid samples") 66 | 67 | def rgba_to_rgb(self, rgba): 68 | if rgba.shape[-1] == 3: 69 | return rgba 70 | 71 | rgb = rgba[..., :3] 72 | alpha = rgba[..., 3:4] 73 | alpha = alpha.astype(np.float32) / 255.0 74 | rgb = rgb.astype(np.float32) * alpha + 255.0 * (1.0 - alpha) 75 | 76 | return rgb.astype(np.uint8) 77 | 78 | def load_rgb_frames(self, seq_path): 79 | seq_path = Path(seq_path) 80 | h5_path = seq_path / 'data.hdf5' 81 | if h5_path.exists(): 82 | with h5py.File(str(h5_path), 'r') as f: 83 | if 'rgba' in f: 84 | rgba = f['rgba'][self.indices] 85 | rgb = self.rgba_to_rgb(rgba) 86 | # Convert to (N, C, H, W) for PyTorch and correct channel order 87 | rgb = torch.from_numpy(rgb).permute(0, 3, 1, 2).float() 88 | rgb = rgb.flip(1) # Flip the channels to correct the order (BGR -> RGB) 89 | return rgb 90 | else: 91 | raise KeyError("'rgba' dataset not found in data.hdf5") 92 | raise FileNotFoundError(f"Could not find data.hdf5 in {seq_path}") 93 | 94 | def load_ground_truth(self, seq_path, seq_name): 95 | """ 96 | Load ground truth trajectory data from annotations file. 
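        The annotations file is assumed to hold a pickled dict with 'target_points'
        of shape [num_points, num_frames, 2] and 'occluded' of shape
        [num_points, num_frames]; both are transposed to time-major order, subsampled
        with `self.indices`, and the occlusion flags are inverted to obtain visibility.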
97 | 98 | Args: 99 | seq_path: Path to sequence directory 100 | seq_name: Name of the sequence 101 | 102 | Returns: 103 | tuple: (trajectory data, visibility data) 104 | """ 105 | data_root_path = Path(self.data_root) 106 | gt_path = data_root_path / seq_name / 'annotations.npy' 107 | gt_data = np.load(str(gt_path), allow_pickle=True).item() 108 | 109 | # Extract and process trajectory data 110 | traj_2d = gt_data['target_points'] 111 | visibility = gt_data['occluded'] # Here a value of 1 means point is visible 112 | 113 | traj_2d = np.transpose(traj_2d, (1, 0, 2))[self.indices] 114 | visibility = np.transpose(np.logical_not(visibility), (1, 0))[self.indices] 115 | 116 | visibility = torch.from_numpy(visibility) 117 | traj_2d = torch.from_numpy(traj_2d) 118 | 119 | visibile_pts_first_frame_inds = (visibility[0]).nonzero(as_tuple=False)[:, 0] 120 | 121 | if self.sample_vis_1st_frame: 122 | visibile_pts_inds = visibile_pts_first_frame_inds 123 | else: 124 | visibile_pts_mid_frame_inds = (visibility[self.seq_len // 2]).nonzero(as_tuple=False)[:, 0] 125 | visibile_pts_inds = torch.cat( 126 | (visibile_pts_first_frame_inds, visibile_pts_mid_frame_inds), dim=0 127 | ) 128 | 129 | return traj_2d, visibility, visibile_pts_inds 130 | 131 | def load_preprocessed_representations(self, seq_path): 132 | seq_path = Path(seq_path) 133 | ev_path = seq_path / 'event_representations' / f'{self.preprocessed_name}.h5' 134 | 135 | with h5py.File(str(ev_path), 'r') as f: 136 | ev_repr = f['representations'][:] 137 | return torch.from_numpy(ev_repr).float() 138 | 139 | def normalize_representation(self, repr): 140 | mask = repr != 0 141 | mean = repr.sum(dim=(0, 2, 3), keepdim=True) / mask.sum(dim=(0, 2, 3), keepdim=True) 142 | var = ((repr - mean)**2 * mask).sum(dim=(0, 2, 3), keepdim=True) / mask.sum(dim=(0, 2, 3), keepdim=True) 143 | std = torch.sqrt(var + 1e-8) 144 | return torch.where(mask, (repr - mean) / std, repr) 145 | 146 | def __getitem__(self, index): 147 | seq_name = self.samples[index] 148 | seq_path = Path(self.data_root) / seq_name 149 | 150 | traj_2d, visibility, visibile_pts_inds = self.load_ground_truth(seq_path, seq_name) 151 | 152 | # Select points 153 | point_inds = torch.arange(min(len(visibile_pts_inds), self.traj_per_sample)) 154 | visible_inds_sampled = visibile_pts_inds[point_inds] 155 | 156 | trajs = traj_2d[:, visible_inds_sampled].float() 157 | visibles = visibility[:, visible_inds_sampled] 158 | valids = torch.ones((self.seq_len, self.traj_per_sample)) 159 | 160 | rgbs = self.load_rgb_frames(seq_path) 161 | ev_repr = self.load_preprocessed_representations(seq_path) 162 | 163 | # Channelwise std-mean normalization 164 | ev_repr = self.normalize_representation(ev_repr) 165 | 166 | _, first_positive_inds = torch.max(visibles, dim=0) # Find first frame where each point is visible 167 | num_points = visibles.shape[1] 168 | query_coords = torch.zeros((num_points, 2), dtype=trajs.dtype) 169 | 170 | for p in range(num_points): 171 | first_frame = first_positive_inds[p].item() 172 | query_coords[p] = trajs[first_frame, p, :2] 173 | 174 | queries = torch.cat([first_positive_inds.unsqueeze(1), query_coords], dim=1) 175 | 176 | sample = EtapData( 177 | voxels=ev_repr, 178 | rgbs=rgbs, 179 | trajectory=trajs, 180 | visibility=visibles, 181 | valid=valids, 182 | query_points=queries, 183 | seq_name=seq_name, 184 | ) 185 | return sample 186 | 187 | def __len__(self): 188 | return len(self.samples) 189 | 190 | def get_full_sample(self, index): 191 | '''Helper function to retrieve full sample 
data''' 192 | sample = {} 193 | seq_name = self.samples[index] 194 | seq_path = Path(self.data_root) / seq_name 195 | h5_path = seq_path / 'data.hdf5' 196 | 197 | # Get Kubric data and ground truth 198 | with h5py.File(str(h5_path), 'r') as f: 199 | indices = f['subsample_indices'][:] 200 | sample['rgba'] = f['rgba'][indices] 201 | sample['depths'] = f['depth'][:] 202 | sample['forward_flow'] = f['forward_flow'][:] 203 | sample['normal'] = f['normal'][:] 204 | sample['object_coordinates'] = f['object_coordinates'][:] 205 | sample['segmentation'] = f['segmentation'][:] 206 | 207 | # Get timestamps 208 | t_min, t_max = 0.0, 2.0 209 | timestamps = calculate_frame_times(num_frames=96, total_time=t_max) 210 | sample['timestamps'] = timestamps[indices] 211 | 212 | # Get Events 213 | event_root = seq_path / 'events' 214 | event_paths = sorted(event_root.iterdir()) 215 | events = [] 216 | for event_path in event_paths: 217 | event_mini_batch = np.load(str(event_path)) 218 | events.append(np.stack([ 219 | event_mini_batch['y'], 220 | event_mini_batch['x'], 221 | event_mini_batch['t'], 222 | event_mini_batch['p'], 223 | ], axis=1)) 224 | events = np.concatenate(events, axis=0) 225 | sample['events'] = events 226 | 227 | # Get Point Tracks 228 | gt_path = Path(self.data_root) / seq_name / 'annotations.npy' 229 | gt_data = np.load(str(gt_path), allow_pickle=True).item() 230 | traj_2d = gt_data['target_points'] 231 | visibility = gt_data['occluded'] # Here a value of 1 means point is visible 232 | traj_2d = np.transpose(traj_2d, (1, 0, 2))[indices] 233 | visibility = np.transpose(np.logical_not(visibility), (1, 0))[indices] 234 | visibility = torch.from_numpy(visibility) 235 | traj_2d = torch.from_numpy(traj_2d) 236 | sample['point_tracks'] = traj_2d 237 | sample['visibility'] = visibility 238 | 239 | return sample -------------------------------------------------------------------------------- /src/data/modules/evimo2.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import h5py 4 | import torch 5 | import pytorch_lightning as pl 6 | import pandas as pd 7 | from torch.utils.data import ConcatDataset 8 | import numpy as np 9 | 10 | from ..utils import EtapData 11 | 12 | class Evimo2DataModule(pl.LightningDataModule): 13 | def __init__(self, dataset_name, data_root, preprocessed_name, metadata_path): 14 | super().__init__() 15 | self.save_hyperparameters() 16 | 17 | def prepare_data(self): 18 | samples = [] 19 | 20 | df = pd.read_csv(self.hparams.metadata_path) 21 | 22 | for _, row in df.iterrows(): 23 | sample_path = Path(self.hparams.data_root) / row['name'] 24 | samples.append(Evimo2SequenceDataset( 25 | data_root=sample_path, 26 | preprocessed_name=self.hparams.preprocessed_name, 27 | t_start=row['t_start'], 28 | t_end=row['t_end'], 29 | )) 30 | 31 | self.test_dataset = ConcatDataset(samples) 32 | 33 | class Evimo2SequenceDataset(torch.utils.data.Dataset): 34 | def __init__( 35 | self, 36 | data_root: Path, 37 | preprocessed_name: str, 38 | t_start: float, 39 | t_end: float, 40 | ): 41 | super(Evimo2SequenceDataset, self).__init__() 42 | self.data_root = data_root 43 | self.preprocessed_name = preprocessed_name 44 | self.t_start = t_start 45 | self.t_end = t_end 46 | 47 | # Load timestamps to determine valid indices 48 | repr_path = self.data_root / 'event_representations' / f'{self.preprocessed_name}.h5' 49 | with h5py.File(repr_path, 'r') as f: 50 | self.timestamps = f['timestamps'][:] 51 | 52 | self.valid_indices = np.where( 
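            # Keep only the representation frames whose timestamps fall inside the
            # annotated [t_start, t_end] window of this sequence.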
53 | (self.timestamps >= self.t_start) & 54 | (self.timestamps <= self.t_end) 55 | )[0] 56 | 57 | if len(self.valid_indices) == 0: 58 | raise ValueError( 59 | f"No timestamps found in range [{t_start}, {t_end}] " 60 | f"for sequence {self.data_root.name}" 61 | ) 62 | 63 | def __len__(self): 64 | return 1 65 | 66 | def __getitem__(self, idx): 67 | repr_path = self.data_root / 'event_representations' / f'{self.preprocessed_name}.h5' 68 | 69 | # Event representations 70 | with h5py.File(repr_path, 'r') as f: 71 | # Only load data for valid time range 72 | representations = f['representations'][self.valid_indices] 73 | timestamps = f['timestamps'][self.valid_indices] 74 | 75 | representations = torch.from_numpy(representations).float() 76 | timestamps = torch.from_numpy(timestamps) 77 | 78 | # GT data 79 | gt_path = self.data_root / 'dataset_tracks.h5' 80 | with h5py.File(gt_path, 'r') as f: 81 | gt_tracks = f['tracks'][self.valid_indices] 82 | gt_occlusion = f['occlusions'][self.valid_indices] 83 | 84 | gt_tracks = torch.from_numpy(gt_tracks).float() 85 | gt_visibility = torch.from_numpy(~gt_occlusion).bool() 86 | 87 | # Queries 88 | T, N, _ = gt_tracks.shape 89 | first_positive_inds = torch.argmax(gt_visibility.int(), dim=0) 90 | query_xy = gt_tracks[first_positive_inds, torch.arange(N)] 91 | queries = torch.cat([first_positive_inds[..., None], query_xy], dim=-1) 92 | 93 | return EtapData( 94 | voxels=representations, # [T', C, H, W] 95 | rgbs=None, 96 | trajectory=gt_tracks, # [T', N, 2] 97 | visibility=gt_visibility,# [T', N] 98 | query_points=queries, 99 | seq_name=self.data_root.name, 100 | timestamps=timestamps, # [T'] 101 | ) 102 | -------------------------------------------------------------------------------- /src/data/modules/feature_tracking_online.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import pytorch_lightning as pl 5 | import numpy as np 6 | 7 | from ..utils import EtapData 8 | from src.utils import SUPPORTED_SEQUENCES_FEATURE_TRACKING 9 | 10 | class FeatureTrackingDataModule(pl.LightningDataModule): 11 | def __init__(self, data_root, 12 | dataset_name, 13 | preprocessed_name=None): 14 | super().__init__() 15 | self.save_hyperparameters() 16 | self.test_datasets = [] 17 | 18 | def prepare_data(self): 19 | if self.hparams.dataset_name == 'eds': 20 | supported_sequences = SUPPORTED_SEQUENCES_FEATURE_TRACKING['eds'] 21 | data_roots = [self.hparams.data_root for _ in SUPPORTED_SEQUENCES_FEATURE_TRACKING['eds']] 22 | elif self.hparams.dataset_name == 'ec': 23 | supported_sequences = SUPPORTED_SEQUENCES_FEATURE_TRACKING['ec'] 24 | data_roots = [self.hparams.data_root for _ in SUPPORTED_SEQUENCES_FEATURE_TRACKING['ec']] 25 | elif self.hparams.dataset_name == 'feature_tracking_online': 26 | supported_sequences = SUPPORTED_SEQUENCES_FEATURE_TRACKING['eds'] + SUPPORTED_SEQUENCES_FEATURE_TRACKING['ec'] 27 | data_root_eds = os.path.join(self.hparams.data_root, 'eds') 28 | data_root_ec = os.path.join(self.hparams.data_root, 'ec') 29 | data_roots = [data_root_eds for _ in SUPPORTED_SEQUENCES_FEATURE_TRACKING['eds']] \ 30 | + [data_root_ec for _ in SUPPORTED_SEQUENCES_FEATURE_TRACKING['ec']] 31 | else: 32 | raise ValueError(f"Unsupported dataset_name: {self.hparams.dataset_name}") 33 | 34 | for subsequence_name, data_root in zip(supported_sequences, data_roots): 35 | self.test_datasets.append(FeatureTrackingInferenceDataset( 36 | data_root=data_root, 37 | subsequence_name=subsequence_name, 38 | 
preprocessed_name=self.hparams.preprocessed_name 39 | )) 40 | 41 | class FeatureTrackingInferenceDataset(torch.utils.data.Dataset): 42 | """Dataset for a sequence of the EDS or EC dataset for point tracking 43 | as used in https://arxiv.org/pdf/2211.12826. 44 | This dataset is implemented for use in an online manner. 45 | Each item provides the data chunk for 1 step (e.g. 8 frames) 46 | of the sequence. The next data will start from the previous + stride frames. 47 | """ 48 | def __init__( 49 | self, 50 | data_root, 51 | subsequence_name, 52 | stride=4, 53 | sliding_window_len=8, 54 | preprocessed_name=None 55 | ): 56 | self.subsequence_name = subsequence_name 57 | self.stride = stride 58 | self.sliding_window_len = sliding_window_len 59 | self.seq_root = os.path.join(data_root, subsequence_name) 60 | assert preprocessed_name is not None, 'online processing of raw events not supported.' 61 | self.preprocessed_name = preprocessed_name 62 | 63 | events_path = os.path.join(self.seq_root, 'events', preprocessed_name) 64 | self.samples, gt_ts_for_sanity_check = self.load_event_representations(events_path) 65 | 66 | gt_path = os.path.join('config/misc', os.path.basename(os.path.normpath(data_root)), 'gt_tracks', f'{subsequence_name}.gt.txt') 67 | gt_tracks = np.genfromtxt(gt_path) # [id, t, x, y] 68 | self.gt_tracks = torch.from_numpy(self.reformat_tracks(gt_tracks)).float() 69 | 70 | self.gt_times_s = np.unique(gt_tracks[:, 1]) 71 | self.gt_times_us = (1e6 * self.gt_times_s).astype(int) 72 | assert np.allclose(gt_ts_for_sanity_check, self.gt_times_us) 73 | 74 | self.start_indices = np.arange(0, len(self.gt_times_us), stride) 75 | 76 | # Queries 77 | N = self.gt_tracks.shape[1] 78 | query_xy = self.gt_tracks[0, :] 79 | query_t = torch.zeros(N, dtype=torch.int64, device=query_xy.device) 80 | self.query_points = torch.cat([query_t[:, None], query_xy], dim=-1) 81 | 82 | def __len__(self): 83 | return len(self.start_indices) 84 | 85 | def __getitem__(self, idx): 86 | start_idx = self.start_indices[idx] 87 | end_idx = start_idx + self.sliding_window_len 88 | ev_repr = [] 89 | 90 | for t in range(start_idx, end_idx): 91 | sample = np.load(os.path.join(self.seq_root, 'events', self.preprocessed_name, 92 | f'{self.gt_times_us[t]}.npy')) 93 | ev_repr.append(sample) 94 | 95 | ev_repr = torch.from_numpy(np.stack(ev_repr, axis=0)).float() 96 | 97 | gt_tracks = self.gt_tracks[start_idx:end_idx] 98 | height, width = ev_repr.shape[-2:] 99 | visibility = self.generate_visibility_mask(gt_tracks, height, width) 100 | 101 | sample = EtapData( 102 | voxels=ev_repr, # [T, C, H, W] 103 | rgbs=None, # [T, 3, H, W], only for visualization 104 | trajectory=gt_tracks, # [T, N, 2] 105 | visibility=visibility, # [T, N] 106 | timestamps=torch.from_numpy(self.gt_times_s[start_idx:end_idx]) # [T] 107 | ) 108 | return sample, start_idx 109 | 110 | def load_event_representations(self, events_path): 111 | repr_files = [f for f in os.listdir(events_path) if f.endswith('.npy')] 112 | 113 | # Extract the integer numbers from the file names 114 | timestamps = [] 115 | for file in repr_files: 116 | number = int(os.path.splitext(file)[0]) 117 | timestamps.append(number) 118 | 119 | return repr_files, sorted(timestamps) 120 | 121 | def reformat_tracks(self, tracks): 122 | # Extract unique timestamps and point IDs 123 | timestamps = np.unique(tracks[:, 1]) 124 | point_ids = np.unique(tracks[:, 0]).astype(int) 125 | 126 | T = len(timestamps) 127 | N = len(point_ids) 128 | output = np.full((T, N, 2), np.nan) 129 | time_to_index = {t: i 
for i, t in enumerate(timestamps)} 130 | 131 | for row in tracks: 132 | point_id, timestamp, x, y = row 133 | t_index = time_to_index[timestamp] 134 | p_index = int(point_id) 135 | output[t_index, p_index, :] = [x, y] 136 | 137 | return output 138 | 139 | def generate_visibility_mask(self, points, height, width): 140 | x = points[..., 0] 141 | y = points[..., 1] 142 | x_in_bounds = (x >= 0) & (x < width) 143 | y_in_bounds = (y >= 0) & (y < height) 144 | mask = x_in_bounds & y_in_bounds 145 | return mask 146 | -------------------------------------------------------------------------------- /src/data/modules/penn_aviary.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import h5py 4 | import pytorch_lightning as pl 5 | import numpy as np 6 | from tqdm import tqdm 7 | from PIL import Image 8 | from pathlib import Path 9 | 10 | from src.representations import EventRepresentationFactory 11 | from src.utils import make_grid 12 | from ..utils import EtapData 13 | 14 | 15 | def pad_images_to_timestamps(images, img_ts, target_timestamps): 16 | """Match images to the nearest timestamp.""" 17 | indices = np.searchsorted(img_ts, target_timestamps, side='right') - 1 18 | indices = np.clip(indices, 0, len(images) - 1) 19 | imgs_padded = images[indices] 20 | return imgs_padded 21 | 22 | 23 | def read_binary_mask(path): 24 | """Read binary mask from file.""" 25 | img = np.array(Image.open(path).convert('L')) 26 | mask = (img != 255).astype(np.uint8) 27 | return mask 28 | 29 | class PennAviaryDataModule(pl.LightningDataModule): 30 | def __init__(self, dataset_name, data_root, sequence_data, repr_config=None, 31 | load_rgb=False, sequences=None): 32 | super().__init__() 33 | self.save_hyperparameters() 34 | self.test_datasets = [] 35 | 36 | if self.hparams.repr_config is not None: 37 | self.hparams.repr_config['image_shape'] = tuple(self.hparams.repr_config['image_shape']) 38 | self.converter = EventRepresentationFactory.create(self.hparams.repr_config) 39 | else: 40 | self.converter = None 41 | 42 | def prepare_data(self): 43 | data_root = Path(self.hparams.data_root) 44 | if self.hparams.sequences is None: 45 | sequences = [d for d in data_root.iterdir() if d.is_dir()] 46 | else: 47 | sequences = [data_root / seq for seq in self.hparams.sequences] 48 | 49 | for sequence_path in sequences: 50 | subsequence_name = sequence_path.name 51 | 52 | if not (sequence_path / 'seq.h5').exists(): 53 | print(f"Warning: seq.h5 not found in {sequence_path}, skipping...") 54 | continue 55 | 56 | self.test_datasets.append(PennAviaryDataset( 57 | sequence_path=sequence_path, 58 | num_events=self.hparams.sequence_data[subsequence_name]['num_events'], 59 | start_time_s=self.hparams.sequence_data[subsequence_name]['start_time_s'], 60 | duration_s=self.hparams.sequence_data[subsequence_name]['duration_s'], 61 | step_time_s=self.hparams.sequence_data[subsequence_name]['step_time_s'], 62 | converter=self.converter, 63 | load_rgb=self.hparams.load_rgb, 64 | query_stride=self.hparams.sequence_data[subsequence_name].get('query_stride', 40), 65 | mask_name=self.hparams.sequence_data[subsequence_name]['mask_name'] 66 | )) 67 | 68 | 69 | class PennAviaryDataset(torch.utils.data.Dataset): 70 | def __init__( 71 | self, 72 | sequence_path, 73 | start_time_s, 74 | duration_s, 75 | step_time_s, 76 | num_events, 77 | stride=4, 78 | sliding_window_len=8, 79 | load_rgb=False, 80 | converter=None, 81 | query_stride=40, 82 | mask_name=None 83 | ): 84 | super().__init__() 85 | 
self.sequence_path = Path(sequence_path) 86 | self.subsequence_name = self.sequence_path.name 87 | self.num_events = num_events 88 | self.stride = stride 89 | self.sliding_window_len = sliding_window_len 90 | self.load_rgb = load_rgb 91 | 92 | # Define tracking timestamps 93 | start_time_us = start_time_s * 1e6 94 | end_time_us = start_time_us + duration_s * 1e6 95 | step_time_us = step_time_s * 1e6 96 | self.timestamps = np.arange(start_time_us, end_time_us, step_time_us) 97 | self.gt_tracks = None 98 | 99 | # Generate query points grid 100 | height, width = 480, 640 101 | query_xy = torch.from_numpy(make_grid(height, width, stride=query_stride)) 102 | num_queries = query_xy.shape[0] 103 | query_t = torch.zeros(num_queries, dtype=torch.int64) 104 | self.query_points = torch.cat([query_t[:, None], query_xy], dim=1).float() 105 | 106 | # Only query points within mask 107 | if mask_name is not None: 108 | mask_path = self.sequence_path / mask_name 109 | segm_mask = torch.from_numpy(read_binary_mask(mask_path)) 110 | query_x = self.query_points[:, 1].int() 111 | query_y = self.query_points[:, 2].int() 112 | segm_mask = segm_mask[query_y, query_x] 113 | self.query_points = self.query_points[segm_mask == 1] 114 | 115 | self.h5_path = self.sequence_path / 'seq.h5' 116 | 117 | if self.load_rgb: 118 | with h5py.File(self.h5_path, 'r') as f: 119 | images = f['images'][:] 120 | img_ts = f['img_ts'][:] 121 | self.imgs_padded = pad_images_to_timestamps(images, img_ts, self.timestamps) 122 | 123 | load_ev_repr_from_file = True if len(self.timestamps) > 1000 else False 124 | 125 | if converter is not None and not load_ev_repr_from_file: 126 | self.load_ev_repr = False 127 | self.ev_repr = self.create_representations(self.h5_path, converter, self.timestamps) 128 | else: 129 | self.load_ev_repr = True 130 | ev_repr_dir = self.sequence_path / str(int(1e6 * step_time_s)).zfill(9) 131 | self.ev_repr_paths = sorted(path for path in ev_repr_dir.iterdir() if path.is_file()) 132 | assert len(self.ev_repr_paths) == len(self.timestamps), f"Expected {len(self.timestamps)} event representation files, but found {len(self.ev_repr_paths)}" 133 | 134 | self.start_indices = np.arange(0, len(self.timestamps) - self.stride, self.stride) 135 | 136 | def create_representations(self, h5_path, converter, timestamps): 137 | with h5py.File(h5_path, 'r') as f: 138 | indices_end = np.searchsorted(f['t'], timestamps) - 1 139 | indices_start = indices_end - self.num_events 140 | 141 | representations = [] 142 | 143 | for i_start, i_end in tqdm(zip(indices_start, indices_end), 144 | desc=f"{self.subsequence_name}: creating event representations", 145 | total=len(indices_start)): 146 | events = np.stack([f['y'][i_start:i_end], 147 | f['x'][i_start:i_end], 148 | f['t'][i_start:i_end], 149 | f['p'][i_start:i_end]], axis=-1) 150 | repr = converter(events, t_mid=None) 151 | representations.append(repr) 152 | 153 | representations = np.stack(representations) 154 | 155 | return representations 156 | 157 | def load_event_representations(self, start_idx, end_idx): 158 | ev_repr = [] 159 | 160 | for i in range(start_idx, end_idx): 161 | ev_repr_path = self.ev_repr_paths[i] 162 | sample = np.load(ev_repr_path) 163 | ev_repr.append(sample) 164 | 165 | ev_repr = torch.from_numpy(np.stack(ev_repr, axis=0)).float() 166 | return ev_repr 167 | 168 | def __len__(self): 169 | return len(self.start_indices) 170 | 171 | def __getitem__(self, idx): 172 | # Load event representation 173 | start_idx = self.start_indices[idx] 174 | end_idx = start_idx + 
self.sliding_window_len 175 | 176 | if self.load_ev_repr: 177 | ev_repr = self.load_event_representations(start_idx, end_idx) 178 | else: 179 | ev_repr = self.ev_repr[start_idx:end_idx] 180 | ev_repr = torch.from_numpy(np.stack(ev_repr, axis=0)).float() 181 | 182 | if self.load_rgb: 183 | rgbs = self.imgs_padded[start_idx:end_idx] 184 | rgbs = torch.from_numpy(rgbs).float().permute(0, 3, 1, 2) 185 | else: 186 | rgbs = None 187 | 188 | sample = EtapData( 189 | voxels=ev_repr, # [T, C, H, W] 190 | rgbs=rgbs, # [T, 3, H, W], only for visualization 191 | trajectory=None, 192 | visibility=None, 193 | timestamps=torch.from_numpy(self.timestamps[start_idx:end_idx]).float() / 1e6, # [T] 194 | ) 195 | return sample, start_idx -------------------------------------------------------------------------------- /src/data/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .collate import EtapData -------------------------------------------------------------------------------- /src/data/utils/collate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import dataclasses 3 | from dataclasses import dataclass 4 | from typing import Any, List, Optional, Tuple 5 | 6 | 7 | @dataclass(eq=False) 8 | class EtapData: 9 | """ 10 | Dataclass for storing video tracks data. 11 | """ 12 | voxels: torch.Tensor # B, S, C, H, W 13 | trajectory: torch.Tensor # B, S, N, 2 14 | visibility: torch.Tensor # B, S, N 15 | # optional data 16 | rgbs: Optional[torch.Tensor] # B, S, 3, H, W, only for visualization 17 | valid: Optional[torch.Tensor] = None # B, S, N 18 | segmentation: Optional[torch.Tensor] = None # B, S, 1, H, W 19 | seq_name: Optional[str] = None 20 | query_points: Optional[torch.Tensor] = None # TapVID evaluation format 21 | timestamps: Optional[torch.Tensor] = None # For EDS evaluation 22 | pair_indices: Optional[torch.Tensor] = None # For contrastive loss 23 | video: Optional[torch.Tensor] = None # B, S, C, H, W 24 | e2vid: Optional[torch.Tensor] = None # B, S, C, H, W 25 | dataset_name: Optional[str] = None 26 | 27 | 28 | def collate_fn(batch: List[EtapData]) -> EtapData: 29 | """ 30 | Collate function for video tracks data. 31 | """ 32 | voxels = torch.stack([b.voxels for b in batch], dim=0) 33 | trajectory = torch.stack([b.trajectory for b in batch], dim=0) 34 | visibility = torch.stack([b.visibility for b in batch], dim=0) 35 | query_points = segmentation = None 36 | if batch[0].query_points is not None: 37 | query_points = torch.stack([b.query_points for b in batch], dim=0) 38 | if batch[0].segmentation is not None: 39 | segmentation = torch.stack([b.segmentation for b in batch], dim=0) 40 | seq_name = [b.seq_name for b in batch] 41 | 42 | return EtapData( 43 | voxels=voxels, 44 | trajectory=trajectory, 45 | visibility=visibility, 46 | segmentation=segmentation, 47 | seq_name=seq_name, 48 | query_points=query_points, 49 | ) 50 | 51 | 52 | def collate_fn_train(batch: List[Tuple[EtapData, bool]]) -> Tuple[EtapData, List[bool]]: 53 | """ 54 | Collate function for video tracks data during training. 
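    Each element of `batch` is a `(EtapData, gotit)` tuple; the per-sample tensors are
    stacked along a new batch dimension and returned together with the list of `gotit`
    flags.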
55 | """ 56 | gotit = [gotit for _, gotit in batch] 57 | voxels = torch.stack([b.voxels for b, _ in batch], dim=0) 58 | rgbs = torch.stack([b.rgbs for b, _ in batch], dim=0) if batch[0][0].rgbs is not None else None 59 | trajectory = torch.stack([b.trajectory for b, _ in batch], dim=0) 60 | visibility = torch.stack([b.visibility for b, _ in batch], dim=0) 61 | valid = torch.stack([b.valid for b, _ in batch], dim=0) 62 | seq_name = [b.seq_name for b, _ in batch] 63 | return ( 64 | EtapData( 65 | voxels=voxels, 66 | rgbs=rgbs, 67 | trajectory=trajectory, 68 | visibility=visibility, 69 | valid=valid, 70 | seq_name=seq_name, 71 | ), 72 | gotit, 73 | ) 74 | 75 | 76 | def collate_fn_load_inverted(batch: List[Tuple[EtapData, EtapData, bool]]) -> Tuple[EtapData, List[bool]]: 77 | """ 78 | Collate function for video tracks data with load_inverted case. 79 | Combines original and inverted samples into a single EtapData object. 80 | """ 81 | # Separate original and inverted samples 82 | orig_samples = [b[0] for b in batch] 83 | inv_samples = [b[1] for b in batch] 84 | gotit = [gotit for _, _, gotit in batch] 85 | 86 | # Combine original and inverted samples 87 | combined_samples = orig_samples + inv_samples 88 | 89 | voxels = torch.stack([b.voxels for b in combined_samples], dim=0) 90 | rgbs = torch.stack([b.rgbs for b in combined_samples], dim=0) if combined_samples[0].rgbs is not None else None 91 | trajectory = torch.stack([b.trajectory for b in combined_samples], dim=0) 92 | visibility = torch.stack([b.visibility for b in combined_samples], dim=0) 93 | valid = torch.stack([b.valid for b in combined_samples], dim=0) 94 | seq_name = [b.seq_name for b in combined_samples] 95 | 96 | # Create a tensor to keep track of paired samples using explicit indices 97 | batch_size = len(orig_samples) 98 | pair_indices = torch.arange(batch_size).repeat(2) 99 | 100 | return EtapData( 101 | voxels=voxels, 102 | rgbs=rgbs, 103 | trajectory=trajectory, 104 | visibility=visibility, 105 | valid=valid, 106 | seq_name=seq_name, 107 | pair_indices=pair_indices, 108 | ), gotit 109 | 110 | 111 | def try_to_cuda(t: Any) -> Any: 112 | """ 113 | Try to move the input variable `t` to a cuda device. 114 | 115 | Args: 116 | t: Input tensor or other object 117 | 118 | Returns: 119 | t_cuda: `t` moved to a cuda device, if supported 120 | """ 121 | try: 122 | t = t.float().cuda() 123 | except AttributeError: 124 | pass 125 | return t 126 | 127 | 128 | def dataclass_to_cuda_(obj: Any) -> Any: 129 | """ 130 | Move all contents of a dataclass to cuda inplace if supported. 
131 | 132 | Args: 133 | obj: Input dataclass object to move to CUDA 134 | 135 | Returns: 136 | obj: The same object with its tensor fields moved to CUDA 137 | """ 138 | for f in dataclasses.fields(obj): 139 | setattr(obj, f.name, try_to_cuda(getattr(obj, f.name))) 140 | return obj 141 | -------------------------------------------------------------------------------- /src/model/etap/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tub-rip/ETAP/e13875244b45e651c1ebcce1e769ed05ffbf7acd/src/model/etap/core/__init__.py -------------------------------------------------------------------------------- /src/model/etap/core/cotracker/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tub-rip/ETAP/e13875244b45e651c1ebcce1e769ed05ffbf7acd/src/model/etap/core/cotracker/__init__.py -------------------------------------------------------------------------------- /src/model/etap/core/cotracker/blocks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Modified by F. Hamann (TU Berlin) 8 | # Copyright (c) 2025 F. Hamann (TU Berlin) 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from functools import partial 14 | from typing import Callable 15 | import collections 16 | from itertools import repeat 17 | 18 | from ..model_utils import bilinear_sampler 19 | 20 | 21 | # From PyTorch internals 22 | def _ntuple(n): 23 | def parse(x): 24 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): 25 | return tuple(x) 26 | return tuple(repeat(x, n)) 27 | 28 | return parse 29 | 30 | 31 | def exists(val): 32 | return val is not None 33 | 34 | 35 | def default(val, d): 36 | return val if exists(val) else d 37 | 38 | 39 | to_2tuple = _ntuple(2) 40 | 41 | 42 | class Mlp(nn.Module): 43 | """MLP as used in Vision Transformer, MLP-Mixer and related networks""" 44 | 45 | def __init__( 46 | self, 47 | in_features, 48 | hidden_features=None, 49 | out_features=None, 50 | act_layer=nn.GELU, 51 | norm_layer=None, 52 | bias=True, 53 | drop=0.0, 54 | use_conv=False, 55 | ): 56 | super().__init__() 57 | out_features = out_features or in_features 58 | hidden_features = hidden_features or in_features 59 | bias = to_2tuple(bias) 60 | drop_probs = to_2tuple(drop) 61 | linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear 62 | 63 | self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) 64 | self.act = act_layer() 65 | self.drop1 = nn.Dropout(drop_probs[0]) 66 | self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity() 67 | self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) 68 | self.drop2 = nn.Dropout(drop_probs[1]) 69 | 70 | def forward(self, x): 71 | x = self.fc1(x) 72 | x = self.act(x) 73 | x = self.drop1(x) 74 | x = self.fc2(x) 75 | x = self.drop2(x) 76 | return x 77 | 78 | 79 | class ResidualBlock(nn.Module): 80 | def __init__(self, in_planes, planes, norm_fn="group", stride=1): 81 | super(ResidualBlock, self).__init__() 82 | 83 | self.conv1 = nn.Conv2d( 84 | in_planes, 85 | planes, 86 | kernel_size=3, 87 | padding=1, 88 | stride=stride, 89 | padding_mode="zeros", 90 | ) 91 | 
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, padding_mode="zeros") 92 | self.relu = nn.ReLU(inplace=True) 93 | 94 | num_groups = planes // 8 95 | 96 | if norm_fn == "group": 97 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 98 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 99 | if not stride == 1: 100 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 101 | 102 | elif norm_fn == "batch": 103 | self.norm1 = nn.BatchNorm2d(planes) 104 | self.norm2 = nn.BatchNorm2d(planes) 105 | if not stride == 1: 106 | self.norm3 = nn.BatchNorm2d(planes) 107 | 108 | elif norm_fn == "instance": 109 | self.norm1 = nn.InstanceNorm2d(planes) 110 | self.norm2 = nn.InstanceNorm2d(planes) 111 | if not stride == 1: 112 | self.norm3 = nn.InstanceNorm2d(planes) 113 | 114 | elif norm_fn == "none": 115 | self.norm1 = nn.Sequential() 116 | self.norm2 = nn.Sequential() 117 | if not stride == 1: 118 | self.norm3 = nn.Sequential() 119 | 120 | if stride == 1: 121 | self.downsample = None 122 | 123 | else: 124 | self.downsample = nn.Sequential( 125 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3 126 | ) 127 | 128 | def forward(self, x): 129 | y = x 130 | y = self.relu(self.norm1(self.conv1(y))) 131 | y = self.relu(self.norm2(self.conv2(y))) 132 | 133 | if self.downsample is not None: 134 | x = self.downsample(x) 135 | 136 | return self.relu(x + y) 137 | 138 | 139 | class BasicEncoder(nn.Module): 140 | def __init__(self, input_dim=3, output_dim=128, stride=4): 141 | super(BasicEncoder, self).__init__() 142 | self.stride = stride 143 | self.norm_fn = "instance" 144 | self.in_planes = output_dim // 2 145 | 146 | self.norm1 = nn.InstanceNorm2d(self.in_planes) 147 | self.norm2 = nn.InstanceNorm2d(output_dim * 2) 148 | 149 | self.conv1 = nn.Conv2d( 150 | input_dim, 151 | self.in_planes, 152 | kernel_size=7, 153 | stride=2, 154 | padding=3, 155 | padding_mode="zeros", 156 | ) 157 | self.relu1 = nn.ReLU(inplace=True) 158 | self.layer1 = self._make_layer(output_dim // 2, stride=1) 159 | self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2) 160 | self.layer3 = self._make_layer(output_dim, stride=2) 161 | self.layer4 = self._make_layer(output_dim, stride=2) 162 | 163 | self.conv2 = nn.Conv2d( 164 | output_dim * 3 + output_dim // 4, 165 | output_dim * 2, 166 | kernel_size=3, 167 | padding=1, 168 | padding_mode="zeros", 169 | ) 170 | self.relu2 = nn.ReLU(inplace=True) 171 | self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1) 172 | for m in self.modules(): 173 | if isinstance(m, nn.Conv2d): 174 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 175 | elif isinstance(m, (nn.InstanceNorm2d)): 176 | if m.weight is not None: 177 | nn.init.constant_(m.weight, 1) 178 | if m.bias is not None: 179 | nn.init.constant_(m.bias, 0) 180 | 181 | def _make_layer(self, dim, stride=1): 182 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) 183 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 184 | layers = (layer1, layer2) 185 | 186 | self.in_planes = dim 187 | return nn.Sequential(*layers) 188 | 189 | def forward(self, x): 190 | _, _, H, W = x.shape 191 | 192 | x = self.conv1(x) 193 | x = self.norm1(x) 194 | x = self.relu1(x) 195 | 196 | a = self.layer1(x) 197 | b = self.layer2(a) 198 | c = self.layer3(b) 199 | d = self.layer4(c) 200 | 201 | def _bilinear_intepolate(x): 202 | return F.interpolate( 203 | x, 204 | (H // self.stride, W // self.stride), 205 | 
mode="bilinear", 206 | align_corners=True, 207 | ) 208 | 209 | a = _bilinear_intepolate(a) 210 | b = _bilinear_intepolate(b) 211 | c = _bilinear_intepolate(c) 212 | d = _bilinear_intepolate(d) 213 | 214 | x = self.conv2(torch.cat([a, b, c, d], dim=1)) 215 | x = self.norm2(x) 216 | x = self.relu2(x) 217 | x = self.conv3(x) 218 | return x 219 | 220 | 221 | class CorrBlock: 222 | def __init__( 223 | self, 224 | fmaps, 225 | num_levels=4, 226 | radius=4, 227 | multiple_track_feats=False, 228 | padding_mode="zeros", 229 | appearance_fact_flow_dim=None, 230 | ): 231 | B, S, C, H, W = fmaps.shape 232 | self.S, self.C, self.H, self.W = S, C, H, W 233 | self.padding_mode = padding_mode 234 | self.num_levels = num_levels 235 | self.radius = radius 236 | self.fmaps_pyramid = [] 237 | self.multiple_track_feats = multiple_track_feats 238 | self.appearance_fact_flow_dim = appearance_fact_flow_dim 239 | 240 | self.fmaps_pyramid.append(fmaps) 241 | for i in range(self.num_levels - 1): 242 | fmaps_ = fmaps.reshape(B * S, C, H, W) 243 | fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2) 244 | _, _, H, W = fmaps_.shape 245 | fmaps = fmaps_.reshape(B, S, C, H, W) 246 | self.fmaps_pyramid.append(fmaps) 247 | 248 | def sample(self, coords): 249 | r = self.radius 250 | B, S, N, D = coords.shape 251 | assert D == 2 252 | 253 | H, W = self.H, self.W 254 | out_pyramid = [] 255 | for i in range(self.num_levels): 256 | corrs = self.corrs_pyramid[i] # B, S, N, H, W 257 | *_, H, W = corrs.shape 258 | 259 | dx = torch.linspace(-r, r, 2 * r + 1) 260 | dy = torch.linspace(-r, r, 2 * r + 1) 261 | delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(coords.device) 262 | 263 | centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i 264 | delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) 265 | coords_lvl = centroid_lvl + delta_lvl 266 | 267 | corrs = bilinear_sampler( 268 | corrs.reshape(B * S * N, 1, H, W), 269 | coords_lvl, 270 | padding_mode=self.padding_mode, 271 | ) 272 | corrs = corrs.view(B, S, N, -1) 273 | out_pyramid.append(corrs) 274 | 275 | out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2 276 | out = out.permute(0, 2, 1, 3).contiguous().view(B * N, S, -1).float() 277 | return out 278 | 279 | def corr(self, targets, coords=None, gt_flow=None, use_gt_flow=False, use_flow_tokens=False, 280 | use_af_high_dim=False, interaction_network=None): 281 | assert sum([coords is not None, 282 | use_gt_flow, 283 | use_flow_tokens, 284 | use_af_high_dim, 285 | interaction_network is not None]) <= 1, \ 286 | "Exactly one of coords, use_gt_flow, use_flow_tokens, use_af_high_dim, or interaction_network must be specified." 
287 | assert not use_flow_tokens or self.appearance_fact_flow_dim is not None 288 | 289 | # Appearance factorization 290 | if coords is not None: 291 | B, S, N, D2 = targets.shape 292 | targets = targets.reshape(B, S, N, 2, D2 // 2) 293 | if use_gt_flow: 294 | flow = gt_flow 295 | else: 296 | flow = coords[:, 1:] - coords[:, :-1] 297 | flow = torch.cat([flow[:, 0:1], flow], dim=1) 298 | flow = gt_flow if use_gt_flow else flow 299 | targets = flow[..., 0:1] * targets[..., 0, :] + flow[..., 1:2] * targets[..., 1, :] 300 | 301 | ###### DEBUG ######## 302 | # flow: [B, S, N, 2] 303 | # targets: [B, S, N, 2, feat_dim], remember feat_dim == latent_dim / 2 304 | 305 | # Appearane factorization with flow tokens 306 | if use_flow_tokens: 307 | fdim = self.appearance_fact_flow_dim 308 | flow = targets[..., -fdim:] 309 | targets = targets[..., :-fdim] 310 | B, S, N, D2 = targets.shape 311 | targets = targets.reshape(B, S, N, fdim, D2 // fdim) 312 | #targets = flow[..., 0:1] * targets[..., 0, :] + flow[..., 1:2] * targets[..., 1, :] 313 | targets = torch.einsum('btnji,btnj->btni', targets, flow) 314 | 315 | if use_af_high_dim: 316 | fdim = self.appearance_fact_flow_dim 317 | flow = targets[..., -fdim:] 318 | targets = targets[..., :-fdim] 319 | targets = flow * targets 320 | 321 | #if interaction_network is not None: 322 | # flow_feat = targets[..., :self.appearance_fact_flow_dim] 323 | # targets = targets[..., self.appearance_fact_flow_dim:] 324 | 325 | # Correlation 326 | B, S, N, C = targets.shape 327 | #assert C == self.C 328 | assert S == self.S 329 | 330 | fmap1 = targets 331 | 332 | self.corrs_pyramid = [] 333 | for _, fmaps in enumerate(self.fmaps_pyramid): 334 | *_, H, W = fmaps.shape 335 | fmap2s = fmaps.view(B, S, C, H * W) # B S C H W -> B S C (H W) 336 | corrs = torch.matmul(fmap1, fmap2s) 337 | corrs = corrs.view(B, S, N, H, W) # B S N (H W) -> B S N H W 338 | corrs = corrs / torch.sqrt(torch.tensor(C).float()) 339 | self.corrs_pyramid.append(corrs) 340 | 341 | def sample_fmap(self, coords): 342 | '''Sample at coords directly from scaled feature map. 
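        Unlike `sample`, which reads from the correlation pyramid, this gathers features
        from every level of `fmaps_pyramid` at the rescaled coordinates and concatenates
        them along the channel dimension, so the output has shape [B, S, N, num_levels * C].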
343 | ''' 344 | B, S, N, D = coords.shape 345 | assert D == 2 346 | 347 | H, W = self.H, self.W 348 | out_pyramid = [] 349 | for i in range(self.num_levels): 350 | 351 | fmap = self.fmaps_pyramid[i] # B, S, C, H, W 352 | B, S, C, H, W = fmap.shape 353 | 354 | coords_lvl = coords / 2**i 355 | 356 | fmap = fmap.reshape(B * S, C, H, W) 357 | coords_lvl = coords_lvl.reshape(B * S, N, 2) 358 | 359 | coords_normalized = coords_lvl.clone() 360 | coords_normalized[..., 0] = coords_normalized[..., 0] / (W - 1) * 2 - 1 361 | coords_normalized[..., 1] = coords_normalized[..., 1] / (H - 1) * 2 - 1 362 | coords_normalized = coords_normalized.unsqueeze(1) 363 | feature_at_coords = F.grid_sample(fmap, coords_normalized, 364 | mode='bilinear', align_corners=True) 365 | feature_at_coords = feature_at_coords.permute(0, 2, 3, 1).view(B, S, N, C) 366 | out_pyramid.append(feature_at_coords) 367 | 368 | out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2 369 | return out 370 | 371 | class Attention(nn.Module): 372 | def __init__(self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False): 373 | super().__init__() 374 | inner_dim = dim_head * num_heads 375 | context_dim = default(context_dim, query_dim) 376 | self.scale = dim_head**-0.5 377 | self.heads = num_heads 378 | 379 | self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias) 380 | self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias) 381 | self.to_out = nn.Linear(inner_dim, query_dim) 382 | 383 | def forward(self, x, context=None, attn_bias=None): 384 | B, N1, C = x.shape 385 | h = self.heads 386 | 387 | q = self.to_q(x).reshape(B, N1, h, C // h).permute(0, 2, 1, 3) 388 | context = default(context, x) 389 | k, v = self.to_kv(context).chunk(2, dim=-1) 390 | 391 | N2 = context.shape[1] 392 | k = k.reshape(B, N2, h, C // h).permute(0, 2, 1, 3) 393 | v = v.reshape(B, N2, h, C // h).permute(0, 2, 1, 3) 394 | 395 | sim = (q @ k.transpose(-2, -1)) * self.scale 396 | 397 | if attn_bias is not None: 398 | sim = sim + attn_bias 399 | attn = sim.softmax(dim=-1) 400 | 401 | x = (attn @ v).transpose(1, 2).reshape(B, N1, C) 402 | return self.to_out(x) 403 | 404 | 405 | class AttnBlock(nn.Module): 406 | def __init__( 407 | self, 408 | hidden_size, 409 | num_heads, 410 | attn_class: Callable[..., nn.Module] = Attention, 411 | mlp_ratio=4.0, 412 | **block_kwargs 413 | ): 414 | super().__init__() 415 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 416 | self.attn = attn_class(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) 417 | 418 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 419 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 420 | approx_gelu = lambda: nn.GELU(approximate="tanh") 421 | self.mlp = Mlp( 422 | in_features=hidden_size, 423 | hidden_features=mlp_hidden_dim, 424 | act_layer=approx_gelu, 425 | drop=0, 426 | ) 427 | 428 | def forward(self, x, mask=None): 429 | attn_bias = mask 430 | if mask is not None: 431 | mask = ( 432 | (mask[:, None] * mask[:, :, None]) 433 | .unsqueeze(1) 434 | .expand(-1, self.attn.num_heads, -1, -1) 435 | ) 436 | max_neg_value = -torch.finfo(x.dtype).max 437 | attn_bias = (~mask) * max_neg_value 438 | x = x + self.attn(self.norm1(x), attn_bias=attn_bias) 439 | x = x + self.mlp(self.norm2(x)) 440 | return x 441 | -------------------------------------------------------------------------------- /src/model/etap/core/embeddings.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 
Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Modified by F. Hamann (TU Berlin) 8 | # Copyright (c) 2025 F. Hamann (TU Berlin) 9 | 10 | from typing import Tuple, Union 11 | import torch 12 | 13 | 14 | def get_2d_sincos_pos_embed( 15 | embed_dim: int, grid_size: Union[int, Tuple[int, int]] 16 | ) -> torch.Tensor: 17 | """ 18 | This function initializes a grid and generates a 2D positional embedding using sine and cosine functions. 19 | It is a wrapper of get_2d_sincos_pos_embed_from_grid. 20 | Args: 21 | - embed_dim: The embedding dimension. 22 | - grid_size: The grid size. 23 | Returns: 24 | - pos_embed: The generated 2D positional embedding. 25 | """ 26 | if isinstance(grid_size, tuple): 27 | grid_size_h, grid_size_w = grid_size 28 | else: 29 | grid_size_h = grid_size_w = grid_size 30 | grid_h = torch.arange(grid_size_h, dtype=torch.float) 31 | grid_w = torch.arange(grid_size_w, dtype=torch.float) 32 | grid = torch.meshgrid(grid_w, grid_h, indexing="xy") 33 | grid = torch.stack(grid, dim=0) 34 | grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) 35 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 36 | return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2) 37 | 38 | 39 | def get_2d_sincos_pos_embed_from_grid( 40 | embed_dim: int, grid: torch.Tensor 41 | ) -> torch.Tensor: 42 | """ 43 | This function generates a 2D positional embedding from a given grid using sine and cosine functions. 44 | 45 | Args: 46 | - embed_dim: The embedding dimension. 47 | - grid: The grid to generate the embedding from. 48 | 49 | Returns: 50 | - emb: The generated 2D positional embedding. 51 | """ 52 | assert embed_dim % 2 == 0 53 | 54 | # use half of dimensions to encode grid_h 55 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 56 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 57 | 58 | emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D) 59 | return emb 60 | 61 | 62 | def get_1d_sincos_pos_embed_from_grid( 63 | embed_dim: int, pos: torch.Tensor 64 | ) -> torch.Tensor: 65 | """ 66 | This function generates a 1D positional embedding from a given grid using sine and cosine functions. 67 | 68 | Args: 69 | - embed_dim: The embedding dimension. 70 | - pos: The position to generate the embedding from. 71 | 72 | Returns: 73 | - emb: The generated 1D positional embedding. 74 | """ 75 | assert embed_dim % 2 == 0 76 | omega = torch.arange(embed_dim // 2, dtype=torch.double) 77 | omega /= embed_dim / 2.0 78 | omega = 1.0 / 10000**omega # (D/2,) 79 | 80 | pos = pos.reshape(-1) # (M,) 81 | out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product 82 | 83 | emb_sin = torch.sin(out) # (M, D/2) 84 | emb_cos = torch.cos(out) # (M, D/2) 85 | 86 | emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) 87 | return emb[None].float() 88 | 89 | 90 | def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor: 91 | """ 92 | This function generates a 2D positional embedding from given coordinates using sine and cosine functions. 93 | 94 | Args: 95 | - xy: The coordinates to generate the embedding from. 96 | - C: The size of the embedding. 97 | - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding. 98 | 99 | Returns: 100 | - pe: The generated 2D positional embedding. 
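    The returned embedding has shape (B, N, 2*C), or (B, N, 2*C + 2) when `cat_coords`
    is True (the raw coordinates are prepended).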
101 | """ 102 | B, N, D = xy.shape 103 | assert D == 2 104 | 105 | x = xy[:, :, 0:1] 106 | y = xy[:, :, 1:2] 107 | div_term = ( 108 | torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C) 109 | ).reshape(1, 1, int(C / 2)) 110 | 111 | pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) 112 | pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) 113 | 114 | pe_x[:, :, 0::2] = torch.sin(x * div_term) 115 | pe_x[:, :, 1::2] = torch.cos(x * div_term) 116 | 117 | pe_y[:, :, 0::2] = torch.sin(y * div_term) 118 | pe_y[:, :, 1::2] = torch.cos(y * div_term) 119 | 120 | pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3) 121 | if cat_coords: 122 | pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3) 123 | return pe 124 | -------------------------------------------------------------------------------- /src/model/etap/core/model_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Modified by F. Hamann (TU Berlin) 8 | # Copyright (c) 2025 F. Hamann (TU Berlin) 9 | 10 | from __future__ import annotations 11 | import torch 12 | import torch.nn.functional as F 13 | from typing import Optional, Tuple 14 | 15 | EPS = 1e-6 16 | 17 | 18 | def smart_cat(tensor1: Optional[torch.Tensor], tensor2: torch.Tensor, dim: int) -> torch.Tensor: 19 | if tensor1 is None: 20 | return tensor2 21 | return torch.cat([tensor1, tensor2], dim=dim) 22 | 23 | 24 | def get_points_on_a_grid( 25 | size: int, 26 | extent: Tuple[float, ...], 27 | center: Optional[Tuple[float, ...]] = None, 28 | device: Optional[torch.device] = torch.device("cpu"), 29 | ) -> torch.Tensor: 30 | r"""Get a grid of points covering a rectangular region 31 | 32 | `get_points_on_a_grid(size, extent)` generates a :attr:`size` by 33 | :attr:`size` grid fo points distributed to cover a rectangular area 34 | specified by `extent`. 35 | 36 | The `extent` is a pair of integer :math:`(H,W)` specifying the height 37 | and width of the rectangle. 38 | 39 | Optionally, the :attr:`center` can be specified as a pair :math:`(c_y,c_x)` 40 | specifying the vertical and horizontal center coordinates. The center 41 | defaults to the middle of the extent. 42 | 43 | Points are distributed uniformly within the rectangle leaving a margin 44 | :math:`m=W/64` from the border. 45 | 46 | It returns a :math:`(1, \text{size} \times \text{size}, 2)` tensor of 47 | points :math:`P_{ij}=(x_i, y_i)` where 48 | 49 | .. math:: 50 | P_{ij} = \left( 51 | c_x + m -\frac{W}{2} + \frac{W - 2m}{\text{size} - 1}\, j,~ 52 | c_y + m -\frac{H}{2} + \frac{H - 2m}{\text{size} - 1}\, i 53 | \right) 54 | 55 | Points are returned in row-major order. 56 | 57 | Args: 58 | size (int): grid size. 59 | extent (tuple): height and with of the grid extent. 60 | center (tuple, optional): grid center. 61 | device (str, optional): Defaults to `"cpu"`. 62 | 63 | Returns: 64 | Tensor: grid. 
65 | """ 66 | if size == 1: 67 | return torch.tensor([extent[1] / 2, extent[0] / 2], device=device)[None, None] 68 | 69 | if center is None: 70 | center = [extent[0] / 2, extent[1] / 2] 71 | 72 | margin = extent[1] / 64 73 | range_y = (margin - extent[0] / 2 + center[0], extent[0] / 2 + center[0] - margin) 74 | range_x = (margin - extent[1] / 2 + center[1], extent[1] / 2 + center[1] - margin) 75 | grid_y, grid_x = torch.meshgrid( 76 | torch.linspace(*range_y, size, device=device), 77 | torch.linspace(*range_x, size, device=device), 78 | indexing="ij", 79 | ) 80 | return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2) 81 | 82 | 83 | def reduce_masked_mean(input: torch.Tensor, mask: torch.Tensor, dim: Optional[int] = None, keepdim: bool = False) -> torch.Tensor: 84 | r"""Masked mean 85 | 86 | `reduce_masked_mean(x, mask)` computes the mean of a tensor :attr:`input` 87 | over a mask :attr:`mask`, returning 88 | 89 | .. math:: 90 | \text{output} = 91 | \frac 92 | {\sum_{i=1}^N \text{input}_i \cdot \text{mask}_i} 93 | {\epsilon + \sum_{i=1}^N \text{mask}_i} 94 | 95 | where :math:`N` is the number of elements in :attr:`input` and 96 | :attr:`mask`, and :math:`\epsilon` is a small constant to avoid 97 | division by zero. 98 | 99 | `reduced_masked_mean(x, mask, dim)` computes the mean of a tensor 100 | :attr:`input` over a mask :attr:`mask` along a dimension :attr:`dim`. 101 | Optionally, the dimension can be kept in the output by setting 102 | :attr:`keepdim` to `True`. Tensor :attr:`mask` must be broadcastable to 103 | the same dimension as :attr:`input`. 104 | 105 | The interface is similar to `torch.mean()`. 106 | 107 | Args: 108 | input (Tensor): input tensor. 109 | mask (Tensor): mask. 110 | dim (int, optional): Dimension to sum over. Defaults to None. 111 | keepdim (bool, optional): Keep the summed dimension. Defaults to False. 112 | 113 | Returns: 114 | Tensor: mean tensor. 115 | """ 116 | 117 | mask = mask.expand_as(input) 118 | 119 | prod = input * mask 120 | 121 | if dim is None: 122 | numer = torch.sum(prod) 123 | denom = torch.sum(mask) 124 | else: 125 | numer = torch.sum(prod, dim=dim, keepdim=keepdim) 126 | denom = torch.sum(mask, dim=dim, keepdim=keepdim) 127 | 128 | mean = numer / (EPS + denom) 129 | return mean 130 | 131 | 132 | def bilinear_sampler(input: torch.Tensor, coords: torch.Tensor, align_corners: bool = True, padding_mode: str = "border") -> torch.Tensor: 133 | r"""Sample a tensor using bilinear interpolation 134 | 135 | `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at 136 | coordinates :attr:`coords` using bilinear interpolation. It is the same 137 | as `torch.nn.functional.grid_sample()` but with a different coordinate 138 | convention. 139 | 140 | The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where 141 | :math:`B` is the batch size, :math:`C` is the number of channels, 142 | :math:`H` is the height of the image, and :math:`W` is the width of the 143 | image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is 144 | interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`. 145 | 146 | Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`, 147 | in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note 148 | that in this case the order of the components is slightly different 149 | from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`. 
150 | 151 | If `align_corners` is `True`, the coordinate :math:`x` is assumed to be 152 | in the range :math:`[0,W-1]`, with 0 corresponding to the center of the 153 | left-most image pixel :math:`W-1` to the center of the right-most 154 | pixel. 155 | 156 | If `align_corners` is `False`, the coordinate :math:`x` is assumed to 157 | be in the range :math:`[0,W]`, with 0 corresponding to the left edge of 158 | the left-most pixel :math:`W` to the right edge of the right-most 159 | pixel. 160 | 161 | Similar conventions apply to the :math:`y` for the range 162 | :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range 163 | :math:`[0,T-1]` and :math:`[0,T]`. 164 | 165 | Args: 166 | input (Tensor): batch of input images. 167 | coords (Tensor): batch of coordinates. 168 | align_corners (bool, optional): Coordinate convention. Defaults to `True`. 169 | padding_mode (str, optional): Padding mode. Defaults to `"border"`. 170 | 171 | Returns: 172 | Tensor: sampled points. 173 | """ 174 | 175 | sizes = input.shape[2:] 176 | 177 | assert len(sizes) in [2, 3] 178 | 179 | if len(sizes) == 3: 180 | # t x y -> x y t to match dimensions T H W in grid_sample 181 | coords = coords[..., [1, 2, 0]] 182 | 183 | if align_corners: 184 | coords = coords * torch.tensor( 185 | [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device 186 | ) 187 | else: 188 | coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device) 189 | 190 | coords -= 1 191 | 192 | return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode) 193 | 194 | 195 | def sample_features4d(input: torch.Tensor, coords: torch.Tensor) -> torch.Tensor: 196 | r"""Sample spatial features 197 | 198 | `sample_features4d(input, coords)` samples the spatial features 199 | :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`. 200 | 201 | The field is sampled at coordinates :attr:`coords` using bilinear 202 | interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R, 203 | 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the 204 | same convention as :func:`bilinear_sampler` with `align_corners=True`. 205 | 206 | The output tensor has one feature per point, and has shape :math:`(B, 207 | R, C)`. 208 | 209 | Args: 210 | input (Tensor): spatial features. 211 | coords (Tensor): points. 212 | 213 | Returns: 214 | Tensor: sampled features. 215 | """ 216 | 217 | B, _, _, _ = input.shape 218 | 219 | # B R 2 -> B R 1 2 220 | coords = coords.unsqueeze(2) 221 | 222 | # B C R 1 223 | feats = bilinear_sampler(input, coords) 224 | 225 | return feats.permute(0, 2, 1, 3).view( 226 | B, -1, feats.shape[1] * feats.shape[3] 227 | ) # B C R 1 -> B R C 228 | 229 | 230 | def sample_features5d(input: torch.Tensor, coords: torch.Tensor) -> torch.Tensor: 231 | r"""Sample spatio-temporal features 232 | 233 | `sample_features5d(input, coords)` works in the same way as 234 | :func:`sample_features4d` but for spatio-temporal features and points: 235 | :attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is 236 | a :math:`(B, R1, R2, 3)` tensor of spatio-temporal point :math:`(t_i, 237 | x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`. 238 | 239 | Args: 240 | input (Tensor): spatio-temporal features. 241 | coords (Tensor): spatio-temporal points. 242 | 243 | Returns: 244 | Tensor: sampled features. 
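    Example (a minimal shape check; the random coordinates stay in the valid
    (t, x, y) ranges because they are drawn from [0, 1)):

        >>> feats = sample_features5d(torch.rand(1, 8, 128, 32, 32),
        ...                           torch.rand(1, 100, 1, 3))
        >>> feats.shape
        torch.Size([1, 100, 1, 128])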
245 | """ 246 | 247 | B, T, _, _, _ = input.shape 248 | 249 | # B T C H W -> B C T H W 250 | input = input.permute(0, 2, 1, 3, 4) 251 | 252 | # B R1 R2 3 -> B R1 R2 1 3 253 | coords = coords.unsqueeze(3) 254 | 255 | # B C R1 R2 1 256 | feats = bilinear_sampler(input, coords) 257 | 258 | return feats.permute(0, 2, 3, 1, 4).view( 259 | B, feats.shape[2], feats.shape[3], feats.shape[1] 260 | ) # B C R1 R2 1 -> B R1 R2 C 261 | -------------------------------------------------------------------------------- /src/representations/__init__.py: -------------------------------------------------------------------------------- 1 | from .event_stack import MixedDensityEventStack 2 | from .voxel_grid import VoxelGrid 3 | 4 | class EventRepresentationFactory: 5 | @staticmethod 6 | def create(representation_config): 7 | config = representation_config.copy() 8 | representation_name = config.pop('representation_name') 9 | 10 | if representation_name == 'event_stack': 11 | return MixedDensityEventStack(**config) 12 | elif representation_name == 'voxel_grid': 13 | return VoxelGrid(**config) 14 | else: 15 | raise ValueError("Unsupported representation_name.") -------------------------------------------------------------------------------- /src/representations/event_stack.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | 5 | 6 | class MixedDensityEventStack: 7 | """Create an mixed density event stack from events. 8 | Implementation inspired by https://github.com/yonseivnl/se-cff. 9 | 10 | Args: 11 | image_shape: (height, width) 12 | num_stacks: number of channels 13 | interpolation: interpolation method to use when building histograms 14 | """ 15 | def __init__(self, image_shape: Tuple[int, int], num_stacks: int, 16 | interpolation: str = 'nearest_neighbor', channel_overlap=False, 17 | centered_channels=False) -> None: 18 | assert image_shape[0] > 0 19 | assert image_shape[1] > 0 20 | assert num_stacks > 0 21 | self.image_shape = image_shape 22 | self.num_stacks = num_stacks 23 | self.interpolation = interpolation 24 | self.channel_overlap = channel_overlap 25 | self.centered_channels = centered_channels 26 | assert self.interpolation in ['nearest_neighbor', 'bilinear'] 27 | assert not self.centered_channels or (self.centered_channels and self.channel_overlap), "Centered channels require channel overlap" 28 | 29 | def __call__(self, events: np.ndarray, t_mid=None) -> np.ndarray: 30 | """Create mixed density event stack. 31 | 32 | Args: 33 | events: events: a NumPy array of size [n x d], where n is the number of events and d = 4. 34 | Every event is encoded with 4 values (y, x, t, p). 35 | 36 | Returns: 37 | A mixed density event stack representation of the event data. 
38 | """ 39 | assert events.shape[1] == 4 40 | assert not self.centered_channels or t_mid is not None, "Centered channels require t_mid" 41 | stacked_event_list = [] 42 | curr_num_events = len(events) 43 | 44 | for _ in range(self.num_stacks): 45 | if self.interpolation == 'nearest_neighbor': 46 | stacked_event = self.stack_data_nearest_neighbor(events) 47 | elif self.interpolation == 'bilinear': 48 | stacked_event = self.stack_data_bilinear(events) 49 | 50 | stacked_event_list.append(stacked_event) 51 | curr_num_events = curr_num_events // 2 52 | 53 | if self.centered_channels: 54 | i_mid = np.searchsorted(events[:, 2], t_mid) 55 | i_start = max(i_mid - curr_num_events // 2, 0) 56 | i_end = min(i_mid + curr_num_events // 2, len(events)) 57 | events = events[i_start:i_end] 58 | else: 59 | events = events[curr_num_events:] 60 | 61 | if not self.channel_overlap: 62 | for stack_idx in range(self.num_stacks - 1): 63 | prev_stack_polarity = stacked_event_list[stack_idx] 64 | next_stack_polarity = stacked_event_list[stack_idx + 1] 65 | diff_stack_polarity = prev_stack_polarity - next_stack_polarity 66 | stacked_event_list[stack_idx] = diff_stack_polarity 67 | 68 | return np.stack(stacked_event_list, axis=0) 69 | 70 | def stack_data_nearest_neighbor(self, events): 71 | height, width = self.image_shape 72 | y = events[:, 0].astype(int) 73 | x = events[:, 1].astype(int) 74 | p = events[:, 3] 75 | p[p == 0] = -1 76 | 77 | stacked_polarity = np.zeros([height, width], dtype=np.int8) 78 | index = (y * width) + x 79 | stacked_polarity.put(index, p) 80 | 81 | return stacked_polarity 82 | 83 | def stack_data_bilinear(self, events): 84 | if len(events.shape) == 2: 85 | events = events[None, ...] # 1 x n x 4 86 | 87 | weight = events[..., 3] 88 | weight[weight == 0] = -1 89 | 90 | h, w = self.image_shape 91 | nb = len(events) 92 | image = np.zeros((nb, h * w), dtype=np.float64) 93 | 94 | floor_xy = np.floor(events[..., :2] + 1e-8) 95 | floor_to_xy = events[..., :2] - floor_xy 96 | 97 | x1 = floor_xy[..., 1] 98 | y1 = floor_xy[..., 0] 99 | inds = np.concatenate( 100 | [ 101 | x1 + y1 * w, 102 | x1 + (y1 + 1) * w, 103 | (x1 + 1) + y1 * w, 104 | (x1 + 1) + (y1 + 1) * w, 105 | ], 106 | axis=-1, 107 | ) 108 | inds_mask = np.concatenate( 109 | [ 110 | (0 <= x1) * (x1 < w) * (0 <= y1) * (y1 < h), 111 | (0 <= x1) * (x1 < w) * (0 <= y1 + 1) * (y1 + 1 < h), 112 | (0 <= x1 + 1) * (x1 + 1 < w) * (0 <= y1) * (y1 < h), 113 | (0 <= x1 + 1) * (x1 + 1 < w) * (0 <= y1 + 1) * (y1 + 1 < h), 114 | ], 115 | axis=-1, 116 | ) 117 | w_pos0 = (1 - floor_to_xy[..., 0]) * (1 - floor_to_xy[..., 1]) * weight 118 | w_pos1 = floor_to_xy[..., 0] * (1 - floor_to_xy[..., 1]) * weight 119 | w_pos2 = (1 - floor_to_xy[..., 0]) * floor_to_xy[..., 1] * weight 120 | w_pos3 = floor_to_xy[..., 0] * floor_to_xy[..., 1] * weight 121 | vals = np.concatenate([w_pos0, w_pos1, w_pos2, w_pos3], axis=-1) 122 | inds = (inds * inds_mask).astype(np.int64) 123 | vals = vals * inds_mask 124 | for i in range(nb): 125 | np.add.at(image[i], inds[i], vals[i]) 126 | return image.reshape((nb,) + self.image_shape).squeeze() 127 | -------------------------------------------------------------------------------- /src/representations/voxel_grid.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | 5 | 6 | class VoxelGrid: 7 | """Create a voxel grid from events. 8 | Implementation inspired by https://github.com/uzh-rpg/rpg_e2vid. 
9 | 10 | Args: 11 | image_shape: (height, width) 12 | num_bins: number of bins in the temporal axis of the voxel grid 13 | """ 14 | def __init__(self, image_shape: Tuple[int, int], num_bins: int) -> None: 15 | assert image_shape[0] > 0 16 | assert image_shape[1] > 0 17 | assert num_bins > 0 18 | self.image_shape = image_shape 19 | self.num_bins = num_bins 20 | 21 | def __call__(self, events: np.ndarray) -> np.ndarray: 22 | """Create voxel grid. 23 | 24 | Args: 25 | events: a NumPy array of size [n x d], where n is the number of events and d = 4. 26 | Every event is encoded with 4 values (y, x, t, p). 27 | 28 | Returns: 29 | A voxelized representation of the event data. 30 | """ 31 | assert events.shape[1] == 4 32 | height, width = self.image_shape 33 | voxel_grid = np.zeros((self.num_bins, height, width), np.float32).ravel() 34 | y = events[:, 0].astype(int) 35 | x = events[:, 1].astype(int) 36 | t = events[:, 2] 37 | p = events[:, 3] 38 | p[p == 0] = -1 39 | 40 | # Normalize the event timestamps so that they lie between 0 and num_bins - 1 41 | t_min, t_max = np.amin(t), np.amax(t) 42 | delta_t = t_max - t_min 43 | 44 | if delta_t == 0: 45 | delta_t = 1 46 | 47 | t = (self.num_bins - 1) * (t - t_min) / delta_t 48 | tis = t.astype(int) 49 | dts = t - tis 50 | vals_left = p * (1.0 - dts) 51 | vals_right = p * dts 52 | valid_indices = tis < self.num_bins 53 | 54 | np.add.at( 55 | voxel_grid, 56 | x[valid_indices] + y[valid_indices] * width + tis[valid_indices] * width * height, 57 | vals_left[valid_indices], 58 | ) 59 | 60 | valid_indices = (tis + 1) < self.num_bins 61 | np.add.at( 62 | voxel_grid, 63 | x[valid_indices] + y[valid_indices] * width + (tis[valid_indices] + 1) * width * height, 64 | vals_right[valid_indices], 65 | ) 66 | 67 | voxel_grid = np.reshape(voxel_grid, (self.num_bins, height, width)) 68 | 69 | return voxel_grid 70 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .supported_seqs_feature_tracking import SUPPORTED_SEQUENCES_FEATURE_TRACKING 2 | from .visualizer import Visualizer, normalize_and_expand_channels 3 | from .track_utils import compute_tracking_errors, read_txt_results 4 | from .metrics import compute_tapvid_metrics 5 | from .misc import make_grid -------------------------------------------------------------------------------- /src/utils/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | from typing import Mapping, List 16 | 17 | import numpy as np 18 | 19 | 20 | def compute_tapvid_metrics( 21 | query_points: np.ndarray, 22 | gt_occluded: np.ndarray, 23 | gt_tracks: np.ndarray, 24 | pred_occluded: np.ndarray, 25 | pred_tracks: np.ndarray, 26 | query_mode: str, 27 | get_trackwise_metrics: bool = False, 28 | thresholds: List[int] = [1, 2, 4, 8, 16], 29 | ) -> Mapping[str, np.ndarray]: 30 | """Computes TAP-Vid metrics (Jaccard, Pts. Within Thresh, Occ. Acc.). 31 | 32 | 33 | 34 | 35 | 36 | See the TAP-Vid paper for details on the metric computation. All inputs are 37 | given in raster coordinates. The first three arguments should be the direct 38 | outputs of the reader: the 'query_points', 'occluded', and 'target_points'. 39 | The paper metrics assume these are scaled relative to 256x256 images. 40 | pred_occluded and pred_tracks are your algorithm's predictions. 41 | 42 | This function takes a batch of inputs, and computes metrics separately for 43 | each video. The metrics for the full benchmark are a simple mean of the 44 | metrics across the full set of videos. These numbers are between 0 and 1, 45 | but the paper multiplies them by 100 to ease reading. 46 | 47 | Args: 48 | query_points: The query points, in the format [t, y, x]. Its size is 49 | [b, n, 3], where b is the batch size and n is the number of queries. 50 | gt_occluded: A boolean array of shape [b, n, t], where t is the number of 51 | frames. True indicates that the point is occluded. 52 | gt_tracks: The target points, of shape [b, n, t, 2]. Each point is in the 53 | format [x, y] 54 | pred_occluded: A boolean array of predicted occlusions, in the same format 55 | as gt_occluded. 56 | pred_tracks: An array of track predictions from your algorithm, in the same 57 | format as gt_tracks. 58 | query_mode: Either 'first' or 'strided', depending on how queries are 59 | sampled. If 'first', we assume the prior knowledge that all points 60 | before the query point are occluded, and these are removed from the 61 | evaluation. 62 | get_trackwise_metrics: if True, the metrics will be computed for every 63 | track (rather than every video, which is the default). This means 64 | every output tensor will have an extra axis [batch, num_tracks] rather 65 | than simply (batch). 66 | 67 | Returns: 68 | A dict with the following keys: 69 | 70 | occlusion_accuracy: Accuracy at predicting occlusion. 71 | pts_within_{x} for x in [1, 2, 4, 8, 16]: Fraction of points 72 | predicted to be within the given pixel threshold, ignoring occlusion 73 | prediction. 
74 | jaccard_{x} for x in [1, 2, 4, 8, 16]: Jaccard metric for the given 75 | threshold 76 | average_pts_within_thresh: average across pts_within_{x} 77 | average_jaccard: average across jaccard_{x} 78 | """ 79 | 80 | summing_axis = (2,) if get_trackwise_metrics else (1, 2) 81 | 82 | metrics = {} 83 | 84 | eye = np.eye(gt_tracks.shape[2], dtype=np.int32) 85 | if query_mode == 'first': 86 | # evaluate frames after the query frame 87 | query_frame_to_eval_frames = np.cumsum(eye, axis=1) - eye 88 | elif query_mode == 'strided': 89 | # evaluate all frames except the query frame 90 | query_frame_to_eval_frames = 1 - eye 91 | else: 92 | raise ValueError('Unknown query mode ' + query_mode) 93 | 94 | query_frame = query_points[..., 0] 95 | query_frame = np.round(query_frame).astype(np.int32) 96 | evaluation_points = query_frame_to_eval_frames[query_frame] > 0 97 | 98 | # Occlusion accuracy is simply how often the predicted occlusion equals the 99 | # ground truth. 100 | occ_acc = np.sum( 101 | np.equal(pred_occluded, gt_occluded) & evaluation_points, 102 | axis=summing_axis, 103 | ) / np.sum(evaluation_points, axis=summing_axis) 104 | metrics['occlusion_accuracy'] = occ_acc 105 | 106 | # Next, convert the predictions and ground truth positions into pixel 107 | # coordinates. 108 | visible = np.logical_not(gt_occluded) 109 | pred_visible = np.logical_not(pred_occluded) 110 | all_frac_within = [] 111 | all_jaccard = [] 112 | for thresh in thresholds: 113 | # True positives are points that are within the threshold and where both 114 | # the prediction and the ground truth are listed as visible. 115 | within_dist = np.sum( 116 | np.square(pred_tracks - gt_tracks), 117 | axis=-1, 118 | ) < np.square(thresh) 119 | is_correct = np.logical_and(within_dist, visible) 120 | 121 | # Compute the frac_within_threshold, which is the fraction of points 122 | # within the threshold among points that are visible in the ground truth, 123 | # ignoring whether they're predicted to be visible. 124 | count_correct = np.sum( 125 | is_correct & evaluation_points, 126 | axis=summing_axis, 127 | ) 128 | count_visible_points = np.sum( 129 | visible & evaluation_points, axis=summing_axis 130 | ) 131 | frac_correct = count_correct / count_visible_points 132 | metrics['pts_within_' + str(thresh)] = frac_correct 133 | all_frac_within.append(frac_correct) 134 | 135 | true_positives = np.sum( 136 | is_correct & pred_visible & evaluation_points, axis=summing_axis 137 | ) 138 | 139 | # The denominator of the jaccard metric is the true positives plus 140 | # false positives plus false negatives. However, note that true positives 141 | # plus false negatives is simply the number of points in the ground truth 142 | # which is easier to compute than trying to compute all three quantities. 143 | # Thus we just add the number of points in the ground truth to the number 144 | # of false positives. 145 | # 146 | # False positives are simply points that are predicted to be visible, 147 | # but the ground truth is not visible or too far from the prediction. 
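    # Toy numbers (not from the paper) for intuition: with 6 ground-truth-visible
    # points at the evaluated frames, 4 of them predicted visible and within the
    # threshold, and 2 spurious visible predictions, jaccard = 4 / (6 + 2) = 0.5.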
148 | gt_positives = np.sum(visible & evaluation_points, axis=summing_axis) 149 | false_positives = (~visible) & pred_visible 150 | false_positives = false_positives | ((~within_dist) & pred_visible) 151 | false_positives = np.sum( 152 | false_positives & evaluation_points, axis=summing_axis 153 | ) 154 | jaccard = true_positives / (gt_positives + false_positives) 155 | metrics['jaccard_' + str(thresh)] = jaccard 156 | all_jaccard.append(jaccard) 157 | metrics['average_jaccard'] = np.mean( 158 | np.stack(all_jaccard, axis=1), 159 | axis=1, 160 | ) 161 | metrics['average_pts_within_thresh'] = np.mean( 162 | np.stack(all_frac_within, axis=1), 163 | axis=1, 164 | ) 165 | return metrics 166 | 167 | 168 | def latex_table(mean_scalars: Mapping[str, float]) -> str: 169 | """Generate a latex table for displaying TAP-Vid and PCK metrics.""" 170 | if 'average_jaccard' in mean_scalars: 171 | latex_fields = [ 172 | 'average_jaccard', 173 | 'average_pts_within_thresh', 174 | 'occlusion_accuracy', 175 | 'jaccard_1', 176 | 'jaccard_2', 177 | 'jaccard_4', 178 | 'jaccard_8', 179 | 'jaccard_16', 180 | 'pts_within_1', 181 | 'pts_within_2', 182 | 'pts_within_4', 183 | 'pts_within_8', 184 | 'pts_within_16', 185 | ] 186 | header = ( 187 | 'AJ & $<\\delta^{x}_{avg}$ & OA & Jac. $\\delta^{0}$ & ' 188 | + 'Jac. $\\delta^{1}$ & Jac. $\\delta^{2}$ & ' 189 | + 'Jac. $\\delta^{3}$ & Jac. $\\delta^{4}$ & $<\\delta^{0}$ & ' 190 | + '$<\\delta^{1}$ & $<\\delta^{2}$ & $<\\delta^{3}$ & ' 191 | + '$<\\delta^{4}$' 192 | ) 193 | else: 194 | latex_fields = ['PCK@0.1', 'PCK@0.2', 'PCK@0.3', 'PCK@0.4', 'PCK@0.5'] 195 | header = ' & '.join(latex_fields) 196 | 197 | body = ' & '.join( 198 | [f'{float(np.array(mean_scalars[x]*100)):.3}' for x in latex_fields] 199 | ) 200 | return '\n'.join([header, body]) 201 | 202 | 203 | def sample_queries_strided( 204 | target_occluded: np.ndarray, 205 | target_points: np.ndarray, 206 | frames: np.ndarray, 207 | query_stride: int = 5, 208 | ) -> Mapping[str, np.ndarray]: 209 | """Package a set of frames and tracks for use in TAPNet evaluations. 210 | 211 | Given a set of frames and tracks with no query points, sample queries 212 | strided every query_stride frames, ignoring points that are not visible 213 | at the selected frames. 214 | 215 | Args: 216 | target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames], 217 | where True indicates occluded. 218 | target_points: Position, of shape [n_tracks, n_frames, 2], where each point 219 | is [x,y] scaled between 0 and 1. 220 | frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between 221 | -1 and 1. 222 | query_stride: When sampling query points, search for un-occluded points 223 | every query_stride frames and convert each one into a query. 224 | 225 | Returns: 226 | A dict with the keys: 227 | video: Video tensor of shape [1, n_frames, height, width, 3]. The video 228 | has floats scaled to the range [-1, 1]. 229 | query_points: Query points of shape [1, n_queries, 3] where 230 | each point is [t, y, x] scaled to the range [-1, 1]. 231 | target_points: Target points of shape [1, n_queries, n_frames, 2] where 232 | each point is [x, y] scaled to the range [-1, 1]. 233 | trackgroup: Index of the original track that each query point was 234 | sampled from. This is useful for visualization. 
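        Example (added for clarity; the shapes are illustrative and every track is visible):
            occ = np.zeros((10, 50), dtype=bool)               # n_tracks=10, n_frames=50
            pts = np.random.rand(10, 50, 2)
            frames = np.zeros((50, 256, 256, 3))
            out = sample_queries_strided(occ, pts, frames)
            # queries are sampled at frames 0, 5, ..., 45, so
            # out['query_points'].shape == (1, 100, 3)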
235 | """ 236 | tracks = [] 237 | occs = [] 238 | queries = [] 239 | trackgroups = [] 240 | total = 0 241 | trackgroup = np.arange(target_occluded.shape[0]) 242 | for i in range(0, target_occluded.shape[1], query_stride): 243 | mask = target_occluded[:, i] == 0 244 | query = np.stack( 245 | [ 246 | i * np.ones(target_occluded.shape[0:1]), 247 | target_points[:, i, 1], 248 | target_points[:, i, 0], 249 | ], 250 | axis=-1, 251 | ) 252 | queries.append(query[mask]) 253 | tracks.append(target_points[mask]) 254 | occs.append(target_occluded[mask]) 255 | trackgroups.append(trackgroup[mask]) 256 | total += np.array(np.sum(target_occluded[:, i] == 0)) 257 | 258 | return { 259 | 'video': frames[np.newaxis, ...], 260 | 'query_points': np.concatenate(queries, axis=0)[np.newaxis, ...], 261 | 'target_points': np.concatenate(tracks, axis=0)[np.newaxis, ...], 262 | 'occluded': np.concatenate(occs, axis=0)[np.newaxis, ...], 263 | 'trackgroup': np.concatenate(trackgroups, axis=0)[np.newaxis, ...], 264 | } 265 | 266 | 267 | def sample_queries_first( 268 | target_occluded: np.ndarray, 269 | target_points: np.ndarray, 270 | frames: np.ndarray, 271 | ) -> Mapping[str, np.ndarray]: 272 | """Package a set of frames and tracks for use in TAPNet evaluations. 273 | 274 | Given a set of frames and tracks with no query points, use the first 275 | visible point in each track as the query. 276 | 277 | Args: 278 | target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames], 279 | where True indicates occluded. 280 | target_points: Position, of shape [n_tracks, n_frames, 2], where each point 281 | is [x,y] scaled between 0 and 1. 282 | frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between 283 | -1 and 1. 284 | 285 | Returns: 286 | A dict with the keys: 287 | video: Video tensor of shape [1, n_frames, height, width, 3] 288 | query_points: Query points of shape [1, n_queries, 3] where 289 | each point is [t, y, x] scaled to the range [-1, 1] 290 | target_points: Target points of shape [1, n_queries, n_frames, 2] where 291 | each point is [x, y] scaled to the range [-1, 1] 292 | """ 293 | 294 | valid = np.sum(~target_occluded, axis=1) > 0 295 | target_points = target_points[valid, :] 296 | target_occluded = target_occluded[valid, :] 297 | 298 | query_points = [] 299 | for i in range(target_points.shape[0]): 300 | index = np.where(target_occluded[i] == 0)[0][0] 301 | x, y = target_points[i, index, 0], target_points[i, index, 1] 302 | query_points.append(np.array([index, y, x])) # [t, y, x] 303 | query_points = np.stack(query_points, axis=0) 304 | 305 | return { 306 | 'video': frames[np.newaxis, ...], 307 | 'query_points': query_points[np.newaxis, ...], 308 | 'target_points': target_points[np.newaxis, ...], 309 | 'occluded': target_occluded[np.newaxis, ...], 310 | } 311 | -------------------------------------------------------------------------------- /src/utils/misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def make_grid(height, width, stride=1): 4 | """ 5 | Create a grid of points with the given stride. 
6 | 7 | Args: 8 | height (int): Grid height 9 | width (int): Grid width 10 | stride (int): Spacing between points 11 | 12 | Returns: 13 | np.ndarray: Grid points of shape ((height//stride)*(width//stride), 2) 14 | """ 15 | x = np.arange(0, width, stride) 16 | y = np.arange(0, height, stride) 17 | X, Y = np.meshgrid(x, y) 18 | return np.stack([X.flatten(), Y.flatten()], axis=1) -------------------------------------------------------------------------------- /src/utils/supported_seqs_feature_tracking.py: -------------------------------------------------------------------------------- 1 | SUPPORTED_SEQUENCES_FEATURE_TRACKING = { 2 | 'eds': [ 3 | '01_peanuts_light', 4 | '02_rocket_earth_light', 5 | '08_peanuts_running', 6 | '14_ziggy_in_the_arena', 7 | 8 | ], 9 | 'ec': [ 10 | 'shapes_rotation', 11 | 'boxes_rotation', 12 | 'boxes_translation', 13 | 'shapes_6dof', 14 | 'shapes_translation' 15 | ] 16 | } -------------------------------------------------------------------------------- /src/utils/visualizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Modified by F. Hamann (TU Berlin) 8 | # Copyright (c) 2025 F. Hamann (TU Berlin) 9 | import os 10 | import numpy as np 11 | import imageio 12 | import torch 13 | 14 | from matplotlib import cm 15 | import torch.nn.functional as F 16 | import torchvision.transforms as transforms 17 | import matplotlib.pyplot as plt 18 | from PIL import Image, ImageDraw 19 | 20 | 21 | def normalize_and_expand_channels(image): 22 | if not isinstance(image, torch.Tensor): 23 | raise TypeError("Input must be a torch tensor") 24 | 25 | *batch_dims, height, width = image.shape 26 | 27 | if len(image.shape) < 2: 28 | raise ValueError("Input tensor must have at least shape (..., height, width)") 29 | 30 | image_flat = image.view(-1, height, width) 31 | 32 | min_val = image_flat.min() 33 | max_val = image_flat.max() 34 | range_val = max_val - min_val 35 | range_val[range_val == 0] = 1 36 | 37 | image_normalized = (image_flat - min_val) / range_val * 255 38 | image_rgb_flat = image_normalized.to(torch.uint8).unsqueeze(1).repeat(1, 3, 1, 1) 39 | image_rgb = image_rgb_flat.view(*batch_dims, 3, height, width) 40 | 41 | return image_rgb 42 | 43 | 44 | def read_video_from_path(path): 45 | try: 46 | reader = imageio.get_reader(path) 47 | except Exception as e: 48 | print("Error opening video file: ", e) 49 | return None 50 | frames = [] 51 | for i, im in enumerate(reader): 52 | frames.append(np.array(im)) 53 | return np.stack(frames) 54 | 55 | 56 | def draw_circle(rgb, coord, radius, color=(255, 0, 0), visible=True): 57 | # Create a draw object 58 | draw = ImageDraw.Draw(rgb) 59 | # Calculate the bounding box of the circle 60 | left_up_point = (coord[0] - radius, coord[1] - radius) 61 | right_down_point = (coord[0] + radius, coord[1] + radius) 62 | # Draw the circle 63 | draw.ellipse( 64 | [left_up_point, right_down_point], 65 | fill=tuple(color) if visible else None, 66 | outline=tuple(color), 67 | ) 68 | return rgb 69 | 70 | 71 | def draw_line(rgb, coord_y, coord_x, color, linewidth): 72 | draw = ImageDraw.Draw(rgb, 'RGBA') # Use RGBA mode for better antialiasing 73 | draw.line( 74 | (coord_y[0], coord_y[1], coord_x[0], coord_x[1]), 75 | fill=color, 76 | width=linewidth, 77 | joint="round" # Add round joints to reduce 
artifacts at line intersections 78 | ) 79 | return rgb 80 | 81 | 82 | def add_weighted(rgb, alpha, original, beta, gamma): 83 | return (rgb * alpha + original * beta + gamma).astype("uint8") 84 | 85 | 86 | class Visualizer: 87 | def __init__( 88 | self, 89 | save_dir: str = "./results", 90 | grayscale: bool = False, 91 | pad_value: int = 0, 92 | fps: int = 10, 93 | mode: str = "rainbow", # 'cool', 'optical_flow' 94 | linewidth: int = 2, 95 | show_first_frame: int = 10, 96 | tracks_leave_trace: int = 0, # -1 for infinite 97 | consistent_colors: bool = False, 98 | ): 99 | self.mode = mode 100 | self.save_dir = save_dir 101 | if mode == "rainbow": 102 | self.color_map = cm.get_cmap("gist_rainbow") 103 | elif mode == "cool": 104 | self.color_map = cm.get_cmap(mode) 105 | self.show_first_frame = show_first_frame 106 | self.grayscale = grayscale 107 | self.tracks_leave_trace = tracks_leave_trace 108 | self.pad_value = pad_value 109 | self.linewidth = linewidth 110 | self.fps = fps 111 | self.consistent_colors = consistent_colors 112 | self.color = None 113 | 114 | def visualize( 115 | self, 116 | video: torch.Tensor, # (B,T,C,H,W) 117 | tracks: torch.Tensor, # (B,T,N,2) 118 | visibility: torch.Tensor = None, # (B, T, N, 1) bool 119 | gt_tracks: torch.Tensor = None, # (B,T,N,2) 120 | segm_mask: torch.Tensor = None, # (B,1,H,W) 121 | filename: str = "video", 122 | writer=None, # tensorboard Summary Writer, used for visualization during training 123 | step: int = 0, 124 | query_frame: int = 0, 125 | save_video: bool = True, 126 | compensate_for_camera_motion: bool = False, 127 | mask_segm_tracks: bool = False, 128 | ): 129 | if compensate_for_camera_motion: 130 | assert segm_mask is not None 131 | if segm_mask is not None: 132 | coords = tracks[0, query_frame].round().long() 133 | segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long() 134 | 135 | video = F.pad( 136 | video, 137 | (self.pad_value, self.pad_value, self.pad_value, self.pad_value), 138 | "constant", 139 | 255, 140 | ) 141 | tracks = tracks + self.pad_value 142 | 143 | if self.grayscale: 144 | transform = transforms.Grayscale() 145 | video = transform(video) 146 | video = video.repeat(1, 1, 3, 1, 1) 147 | 148 | if mask_segm_tracks: 149 | tracks = tracks[:, :, segm_mask != 0] 150 | visibility = visibility[:, :, segm_mask != 0] if visibility is not None else None 151 | segm_mask = None 152 | 153 | res_video = self.draw_tracks_on_video( 154 | video=video, 155 | tracks=tracks, 156 | visibility=visibility, 157 | segm_mask=segm_mask, 158 | gt_tracks=gt_tracks, 159 | query_frame=query_frame, 160 | compensate_for_camera_motion=compensate_for_camera_motion, 161 | ) 162 | if save_video: 163 | self.save_video(res_video, filename=filename, writer=writer, step=step) 164 | return res_video 165 | 166 | def save_video(self, video, filename, writer=None, step=0): 167 | if writer is not None: 168 | writer.add_video( 169 | filename, 170 | video.to(torch.uint8), 171 | global_step=step, 172 | fps=self.fps, 173 | ) 174 | else: 175 | os.makedirs(self.save_dir, exist_ok=True) 176 | wide_list = list(video.unbind(1)) 177 | wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list] 178 | 179 | # Prepare the video file path 180 | save_path = os.path.join(self.save_dir, f"{filename}.mp4") 181 | 182 | # Create a writer object 183 | video_writer = imageio.get_writer(save_path, fps=self.fps, macro_block_size=1) 184 | 185 | # Write frames to the video file 186 | for frame in wide_list[2:-1]: 187 | video_writer.append_data(frame) 188 | 
189 | video_writer.close() 190 | 191 | #print(f"Video saved to {save_path}") 192 | 193 | def draw_tracks_on_video( 194 | self, 195 | video: torch.Tensor, 196 | tracks: torch.Tensor, 197 | visibility: torch.Tensor = None, 198 | segm_mask: torch.Tensor = None, 199 | gt_tracks=None, 200 | query_frame: int = 0, 201 | compensate_for_camera_motion=False, 202 | ): 203 | B, T, C, H, W = video.shape 204 | _, _, N, D = tracks.shape 205 | 206 | assert D == 2 207 | assert C == 3 208 | video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy() # S, H, W, C 209 | valid_mask = ~torch.all(torch.isnan(tracks)[0], dim=-1).cpu().numpy() 210 | tracks = tracks[0].long().detach().cpu().numpy() # S, N, 2 211 | if gt_tracks is not None: 212 | gt_tracks = gt_tracks[0].detach().cpu().numpy() 213 | 214 | res_video = [] 215 | 216 | # process input video 217 | for rgb in video: 218 | res_video.append(rgb.copy()) 219 | vector_colors = np.zeros((T, N, 3)) 220 | 221 | if self.mode == "optical_flow": 222 | import flow_vis 223 | 224 | vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None]) 225 | elif segm_mask is None: 226 | if self.mode == "rainbow": 227 | if self.color is None or not self.consistent_colors: 228 | y_min, y_max = ( 229 | tracks[query_frame, :, 1].min(), 230 | tracks[query_frame, :, 1].max(), 231 | ) 232 | norm = plt.Normalize(y_min, y_max) 233 | for n in range(N): 234 | color = self.color_map(norm(tracks[query_frame, n, 1])) 235 | color = np.array(color[:3])[None] * 255 236 | vector_colors[:, n] = np.repeat(color, T, axis=0) 237 | 238 | if self.color is None: 239 | self.color = vector_colors 240 | else: 241 | vector_colors = self.color 242 | else: 243 | # color changes with time 244 | for t in range(T): 245 | color = np.array(self.color_map(t / T)[:3])[None] * 255 246 | vector_colors[t] = np.repeat(color, N, axis=0) 247 | else: 248 | if self.mode == "rainbow": 249 | vector_colors[:, segm_mask <= 0, :] = 255 250 | 251 | y_min, y_max = ( 252 | tracks[0, segm_mask > 0, 1].min(), 253 | tracks[0, segm_mask > 0, 1].max(), 254 | ) 255 | norm = plt.Normalize(y_min, y_max) 256 | for n in range(N): 257 | if segm_mask[n] > 0: 258 | color = self.color_map(norm(tracks[0, n, 1])) 259 | color = np.array(color[:3])[None] * 255 260 | vector_colors[:, n] = np.repeat(color, T, axis=0) 261 | 262 | else: 263 | # color changes with segm class 264 | segm_mask = segm_mask.cpu() 265 | color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32) 266 | color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0 267 | color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0 268 | vector_colors = np.repeat(color[None], T, axis=0) 269 | 270 | # draw tracks 271 | if self.tracks_leave_trace != 0: 272 | for t in range(query_frame + 1, T): 273 | first_ind = ( 274 | max(0, t - self.tracks_leave_trace) if self.tracks_leave_trace >= 0 else 0 275 | ) 276 | curr_tracks = tracks[first_ind : t + 1] 277 | curr_colors = vector_colors[first_ind : t + 1] 278 | if compensate_for_camera_motion: 279 | diff = ( 280 | tracks[first_ind : t + 1, segm_mask <= 0] 281 | - tracks[t : t + 1, segm_mask <= 0] 282 | ).mean(1)[:, None] 283 | 284 | curr_tracks = curr_tracks - diff 285 | curr_tracks = curr_tracks[:, segm_mask > 0] 286 | curr_colors = curr_colors[:, segm_mask > 0] 287 | 288 | res_video[t] = self._draw_pred_tracks( 289 | res_video[t], 290 | curr_tracks, 291 | curr_colors, 292 | ) 293 | if gt_tracks is not None: 294 | res_video[t] = self._draw_gt_tracks(res_video[t], gt_tracks[first_ind : t + 1]) 295 | 296 | # draw 
points 297 | for t in range(query_frame, T): 298 | img = Image.fromarray(np.uint8(res_video[t])) 299 | for i in range(N): 300 | coord = (tracks[t, i, 0], tracks[t, i, 1]) 301 | visible = True 302 | if visibility is not None: 303 | visible = visibility[0, t, i] 304 | if coord[0] != 0 and coord[1] != 0: 305 | if not compensate_for_camera_motion or ( 306 | compensate_for_camera_motion and segm_mask[i] > 0 307 | ): 308 | if valid_mask[t, i]: 309 | img = draw_circle( 310 | img, 311 | coord=coord, 312 | radius=int(self.linewidth * 2), 313 | color=vector_colors[t, i].astype(int), 314 | visible=visible, 315 | ) 316 | res_video[t] = np.array(img) 317 | 318 | # construct the final rgb sequence 319 | if self.show_first_frame > 0: 320 | res_video = [res_video[0]] * self.show_first_frame + res_video[1:] 321 | return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte() 322 | 323 | def _draw_pred_tracks( 324 | self, 325 | rgb: np.ndarray, # H x W x 3 326 | tracks: np.ndarray, # T x N x 2 327 | vector_colors: np.ndarray, 328 | base_alpha: float = 0.5, 329 | ): 330 | T, N, _ = tracks.shape 331 | # Convert to PIL Image once at the start 332 | rgb = Image.fromarray(np.uint8(rgb)) 333 | 334 | # Create a single composite image for all tracks 335 | for s in range(T - 1): 336 | vector_color = vector_colors[s] 337 | time_alpha = (s / T) ** 2 if self.tracks_leave_trace > 0 else 1.0 338 | 339 | for i in range(N): 340 | coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1])) 341 | coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1])) 342 | 343 | if coord_y[0] != 0 and coord_y[1] != 0 and coord_x[0] != 0 and coord_x[1] != 0: 344 | # Only draw valid coordinates 345 | color = (*vector_color[i].astype(int),) # Convert to tuple 346 | rgb = draw_line( 347 | rgb, 348 | coord_y, 349 | coord_x, 350 | color, 351 | self.linewidth, 352 | ) 353 | 354 | return np.array(rgb) 355 | 356 | def _draw_gt_tracks( 357 | self, 358 | rgb: np.ndarray, # H x W x 3, 359 | gt_tracks: np.ndarray, # T x N x 2 360 | ): 361 | T, N, _ = gt_tracks.shape 362 | color = np.array((211, 0, 0)) 363 | rgb = Image.fromarray(np.uint8(rgb)) 364 | for t in range(T): 365 | for i in range(N): 366 | gt_track = gt_tracks[t][i] 367 | # draw a red cross 368 | if gt_track[0] > 0 and gt_track[1] > 0: 369 | length = self.linewidth * 3 370 | coord_y = (int(gt_track[0]) + length, int(gt_track[1]) + length) 371 | coord_x = (int(gt_track[0]) - length, int(gt_track[1]) - length) 372 | rgb = draw_line( 373 | rgb, 374 | coord_y, 375 | coord_x, 376 | color, 377 | self.linewidth, 378 | ) 379 | coord_y = (int(gt_track[0]) - length, int(gt_track[1]) + length) 380 | coord_x = (int(gt_track[0]) + length, int(gt_track[1]) - length) 381 | rgb = draw_line( 382 | rgb, 383 | coord_y, 384 | coord_x, 385 | color, 386 | self.linewidth, 387 | ) 388 | rgb = np.array(rgb) 389 | return rgb 390 | --------------------------------------------------------------------------------
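# Usage sketch (not part of the repository): how the Visualizer above might be driven.
# Tensor shapes follow the docstrings in visualizer.py; the file name, track count, and
# import path (assuming the repo root is on PYTHONPATH) are illustrative assumptions.
import torch
from src.utils import Visualizer

video = torch.randint(0, 255, (1, 20, 3, 256, 256), dtype=torch.uint8)  # (B, T, C, H, W)
tracks = torch.rand(1, 20, 16, 2) * 255.0                               # (B, T, N, 2), (x, y) in pixels
visibility = torch.ones(1, 20, 16, 1, dtype=torch.bool)                 # (B, T, N, 1)

vis = Visualizer(save_dir="./results", fps=10, mode="rainbow", linewidth=2)
vis.visualize(video, tracks, visibility=visibility, filename="demo_tracks")
# Writes ./results/demo_tracks.mp4 and returns the annotated frames as a uint8 tensor.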