├── .gitignore ├── LICENSE ├── README.md ├── callbacks ├── logger.py └── utils │ ├── flow_vis.py │ └── visualization.py ├── config ├── dataset │ ├── base.yaml │ ├── dsec.yaml │ └── multiflow_regen.yaml ├── experiment │ ├── dsec │ │ └── raft_spline │ │ │ ├── E_I_LU4_BD2_lowpyramid.yaml │ │ │ └── E_LU4_BD2_lowpyramid.yaml │ └── multiflow │ │ └── raft_spline │ │ ├── E_I_LU5_BD10_lowpyramid.yaml │ │ └── E_LU5_BD10_lowpyramid.yaml ├── general.yaml ├── model │ ├── base.yaml │ ├── raft-spline.yaml │ └── raft_base.yaml ├── train.yaml └── val.yaml ├── data ├── dsec │ ├── eventslicer.py │ ├── provider.py │ ├── sequence.py │ └── subsequence │ │ ├── base.py │ │ └── twostep.py ├── multiflow2d │ ├── datasubset.py │ ├── provider.py │ └── sample.py └── utils │ ├── augmentor.py │ ├── generic.py │ ├── keys.py │ ├── provider.py │ └── representations.py ├── loggers └── wandb_logger.py ├── models ├── raft_spline │ ├── bezier.py │ ├── raft.py │ └── update.py └── raft_utils │ ├── corr.py │ ├── extractor.py │ └── utils.py ├── modules ├── data_loading.py ├── raft_spline.py └── utils.py ├── train.py ├── utils ├── general.py ├── losses.py ├── metrics.py └── timers.py └── val.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dense Continuous-Time Optical Flow from Event Cameras 2 | 3 | ![readme](https://github.com/uzh-rpg/bflow/assets/6841681/2b8c7a7f-3c75-49d4-85cd-51c78b0884d3) 4 | 5 | This is the official Pytorch implementation of the TPAMI 2024 paper [Dense Continuous-Time Optical Flow from Event Cameras](https://ieeexplore.ieee.org/document/10419040). 
6 |
7 | If you find this code useful, please cite us:
8 | ```bibtex
9 | @Article{Gehrig2024pami,
10 |  author  = {Mathias Gehrig and Manasi Muglikar and Davide Scaramuzza},
11 |  title   = {Dense Continuous-Time Optical Flow from Event Cameras},
12 |  journal = {{IEEE} Trans. Pattern Anal. Mach. Intell. (T-PAMI)},
13 |  year = 2024
14 | }
15 | ```
16 |
17 | ## Conda Installation
18 | We highly recommend using [Mambaforge](https://github.com/conda-forge/miniforge#mambaforge) to reduce the installation time.
19 | ```Bash
20 | conda create -y -n bflow python=3.11 pip
21 | conda activate bflow
22 | conda config --set channel_priority flexible
23 |
24 | CUDA_VERSION=12.1
25 |
26 | conda install -y h5py=3.10.0 blosc-hdf5-plugin=1.0.0 llvm-openmp=15.0.7 \
27 | hydra-core=1.3.2 einops=0.7 tqdm numba \
28 | pytorch=2.1.2 torchvision pytorch-cuda=$CUDA_VERSION \
29 | -c pytorch -c nvidia -c conda-forge
30 |
31 | python -m pip install pytorch-lightning==2.1.3 wandb==0.16.1 \
32 | opencv-python==4.8.1.78 imageio==2.33.1 lpips==0.1.4 \
33 | pandas==2.1.4 plotly==5.18.0 moviepy==1.0.3 tabulate==0.9.0 \
34 | loguru==0.7.2 matplotlib==3.8.2 scikit-image==0.22.0 kaleido
35 | ```
36 | ## Data
37 | ### MultiFlow
38 |
|                       | Train    | Val      |
| --------------------- | -------- | -------- |
| pre-processed dataset | download | download |
47 |
48 | ### DSEC
|                       | Train    | Test (input) |
| --------------------- | -------- | ------------ |
| pre-processed dataset | download | download     |
crc32c1b618fcffbacb7e
62 |
63 | ## Checkpoints
64 |
65 | ### MultiFlow
|                        | Events only | Events + Images |
| ---------------------- | ----------- | --------------- |
| pre-trained checkpoint | download    | download        |
md561e1022ce3aa
80 |
81 | ### DSEC
|                        | Events only | Events + Images |
| ---------------------- | ----------- | --------------- |
| pre-trained checkpoint | download    | download        |
md5d1700205770b
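
The md5 and crc32c rows above list digests of the pre-processed data and the pre-trained checkpoints. As a minimal sketch (the file name below is hypothetical), the MD5 digest of a downloaded file can be computed locally and compared against the listed value:

```python
import hashlib
from pathlib import Path

def file_md5(path: Path) -> str:
    """Stream the file in 1 MiB chunks and return its hex MD5 digest."""
    md5 = hashlib.md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            md5.update(chunk)
    return md5.hexdigest()

# Hypothetical file name: compare the printed digest against the table entry.
print(file_md5(Path('checkpoints/dsec_events_only.ckpt')))
```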
96 |
97 |
98 | ## Training
99 |
100 | ### MultiFlow
101 | - Set `DATA_DIR` as the path to the MultiFlow dataset (parent of train and val dir)
102 | - Set
103 |   - `MDL_CFG=E_I_LU5_BD10_lowpyramid` to use both events and frames, or
104 |   - `MDL_CFG=E_LU5_BD10_lowpyramid` to use only events
105 | - Set `LOG_ONLY_NUMBERS=true` to avoid logging images (this can require a lot of space). It is set to false by default.
106 |
107 | ```Bash
108 | GPU_ID=0
109 | python train.py model=raft-spline dataset=multiflow_regen dataset.path=${DATA_DIR} wandb.group_name=multiflow \
110 | hardware.gpus=${GPU_ID} hardware.num_workers=6 +experiment/multiflow/raft_spline=${MDL_CFG} \
111 | logging.only_numbers=${LOG_ONLY_NUMBERS}
112 | ```
113 |
114 | ### DSEC
115 | - Set `DATA_DIR` as the path to the DSEC dataset (parent of train and test dir)
116 |
117 | - Set
118 |   - `MDL_CFG=E_I_LU4_BD2_lowpyramid` to use both events and frames, or
119 |   - `MDL_CFG=E_LU4_BD2_lowpyramid` to use only events
120 | - Set `LOG_ONLY_NUMBERS=true` to avoid logging images (this can require a lot of space). It is set to false by default.
121 |
122 | ```Bash
123 | GPU_ID=0
124 | python train.py model=raft-spline dataset=dsec dataset.path=${DATA_DIR} wandb.group_name=dsec \
125 | hardware.gpus=${GPU_ID} hardware.num_workers=6 +experiment/dsec/raft_spline=${MDL_CFG} \
126 | logging.only_numbers=${LOG_ONLY_NUMBERS}
127 | ```
128 |
129 | ## Evaluation
130 |
131 | ### MultiFlow
132 | - Set `DATA_DIR` as the path to the MultiFlow dataset (parent of train and val dir)
133 | - Set
134 |   - `MDL_CFG=E_I_LU5_BD10_lowpyramid` to use both events and frames, or
135 |   - `MDL_CFG=E_LU5_BD10_lowpyramid` to use only events
136 | - Set `CKPT` to the path of the correct checkpoint (a quick sanity check for checkpoint files is sketched at the end of this README)
137 |
138 | ```Bash
139 | GPU_ID=0
140 | python val.py model=raft-spline dataset=multiflow_regen dataset.path=${DATA_DIR} hardware.gpus=${GPU_ID} \
141 | +experiment/multiflow/raft_spline=${MDL_CFG} checkpoint=${CKPT}
142 | ```
143 |
144 | ### DSEC
145 |
146 | work in progress
147 |
148 | ## Code Acknowledgments
149 | This project uses code from [RAFT](https://github.com/princeton-vl/RAFT) for parts of the model architecture.
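
As referenced in the Evaluation section, a downloaded checkpoint can be sanity-checked before it is passed to `val.py` by opening it with PyTorch and inspecting its keys. This is only a sketch: it assumes a standard PyTorch Lightning checkpoint layout and uses a hypothetical file name.

```python
import torch

# Hypothetical checkpoint path; map_location='cpu' avoids needing a GPU for inspection.
ckpt = torch.load('checkpoints/multiflow_events_only.ckpt', map_location='cpu')

# PyTorch Lightning checkpoints are dictionaries; 'state_dict' holds the model weights.
print(sorted(ckpt.keys()))
print(f"{len(ckpt['state_dict'])} tensors in the state dict")
```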
150 | -------------------------------------------------------------------------------- /callbacks/logger.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Dict, Union, Any, Optional, List 3 | 4 | import pytorch_lightning as pl 5 | from pytorch_lightning.callbacks import Callback 6 | import wandb 7 | 8 | import torch 9 | 10 | from callbacks.utils.visualization import ( 11 | create_summary_img, 12 | get_grad_flow_figure, 13 | multi_plot_bezier_array, 14 | ev_repr_reduced_to_img_grayscale, 15 | img_torch_to_numpy) 16 | from utils.general import is_cpu 17 | from data.utils.keys import DataLoading, DataSetType 18 | 19 | from loggers.wandb_logger import WandbLogger 20 | from models.raft_spline.bezier import BezierCurves 21 | 22 | 23 | class WandBImageLoggingCallback(Callback): 24 | FLOW_GT: str = 'flow_gt_img' 25 | FLOW_PRED: str = 'flow_pred_img' 26 | FLOW_VALID: str = 'flow_valid_img' 27 | EV_REPR_REDUCED: str = 'ev_repr_reduced' 28 | EV_REPR_REDUCED_M1: str = 'ev_repr_reduced_m1' 29 | VAL_BATCH_IDX: str = 'val_batch_idx' 30 | BEZIER_PARAMS: str = 'bezier_prediction' 31 | IMAGES: str = 'images' 32 | 33 | MAX_FLOW_ERROR = { 34 | int(DataSetType.DSEC): 2.0, 35 | int(DataSetType.MULTIFLOW2D): 3.0, 36 | } 37 | 38 | def __init__(self, logging_params: Dict[str, Any], deterministic: bool=True): 39 | super().__init__() 40 | log_every_n_train_steps = logging_params['log_every_n_steps'] 41 | log_n_val_predictions = logging_params['log_n_val_predictions'] 42 | self.log_only_numbers = logging_params['only_numbers'] 43 | assert log_every_n_train_steps > 0 44 | assert log_n_val_predictions > 0 45 | self.log_every_n_train_steps = log_every_n_train_steps 46 | self.log_n_val_predictions = log_n_val_predictions 47 | self.deterministic = deterministic 48 | 49 | self._clear_val_data() 50 | self._training_started = False 51 | self._val_batch_indices = None 52 | 53 | self._dataset_type: Optional[DataSetType] = None 54 | 55 | def enable_immediate_validation(self): 56 | self._training_started = True 57 | 58 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) -> None: 59 | if self.log_only_numbers: 60 | return 61 | if not self._training_started: 62 | self._training_started = True 63 | global_step = trainer.global_step 64 | do_log = True 65 | do_log &= global_step >= self.log_every_n_train_steps 66 | # BUG: wandb bug? If we log metrics on the same step as logging images in this function 67 | # then no metrics are logged at all, only the images. Unclear why this happens. 68 | # So we introduce a small step margin to ensure that we don't log at the same time metrics and what we log here. 
69 | delta_step = 2 70 | do_log &= (global_step - delta_step) % self.log_every_n_train_steps == 0 71 | if not do_log: 72 | return 73 | 74 | if self._dataset_type is None: 75 | self._dataset_type = batch[DataLoading.DATASET_TYPE][0].cpu().item() 76 | 77 | logger: WandbLogger = trainer.logger 78 | 79 | if isinstance(outputs['gt'], list): 80 | flow_gt = [x.detach().cpu() for x in outputs['gt']] 81 | else: 82 | flow_gt = outputs['gt'].detach().cpu() 83 | flow_pred = outputs['pred'].detach().cpu() 84 | flow_valid = outputs['gt_valid'].detach().cpu() if 'gt_valid' in outputs else None 85 | 86 | ev_repr_reduced = None 87 | if 'ev_repr_reduced' in outputs.keys(): 88 | ev_repr_reduced = outputs['ev_repr_reduced'].detach().cpu() 89 | ev_repr_reduced_m1 = None 90 | if 'ev_repr_reduced_m1' in outputs.keys(): 91 | ev_repr_reduced_m1 = outputs['ev_repr_reduced_m1'].detach().cpu() 92 | images = None 93 | if 'images' in outputs.keys(): 94 | images = [x.detach().cpu() for x in outputs['images']] 95 | 96 | summary_img = create_summary_img( 97 | flow_pred, 98 | flow_gt[-1] if isinstance(flow_gt, list) else flow_gt, 99 | valid_mask=flow_valid, 100 | ev_repr_reduced=ev_repr_reduced, 101 | ev_repr_reduced_m1=ev_repr_reduced_m1, 102 | images=images, 103 | max_error=self.MAX_FLOW_ERROR[self._dataset_type]) 104 | wandb_flow_img = wandb.Image(summary_img) 105 | logger.log_metrics({'train/flow': wandb_flow_img}, step=global_step) 106 | 107 | if 'bezier_prediction' in outputs.keys(): 108 | bezier_prediction: BezierCurves 109 | bezier_prediction = outputs['bezier_prediction'].detach(cpu=True) 110 | assert not bezier_prediction.requires_grad 111 | 112 | if images is not None: 113 | # images[0] is assumed to be the reference image 114 | background_img = img_torch_to_numpy(images[0]) 115 | else: 116 | assert ev_repr_reduced is not None 117 | background_img = ev_repr_reduced_to_img_grayscale(ev_repr_reduced) 118 | 119 | bezier_img = multi_plot_bezier_array( 120 | bezier_prediction, 121 | background_img, 122 | multi_flow_gt= flow_gt if isinstance(flow_gt, list) else [flow_gt], # Could also put [flow_gt] if not list. 123 | num_t=10, 124 | x_add_margin=30, 125 | y_add_margin=30) 126 | wandb_bezier_img = wandb.Image(bezier_img) 127 | logger.log_metrics({'train/bezier': wandb_bezier_img}, step=global_step) 128 | 129 | def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: 130 | global_step = trainer.global_step 131 | if global_step % self.log_every_n_train_steps != 0: 132 | return 133 | named_parameters = pl_module.named_parameters() 134 | figure = get_grad_flow_figure(named_parameters) 135 | trainer.logger.log_metrics({'train/gradients': figure}, step=global_step) 136 | 137 | def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: 138 | if self.log_only_numbers: 139 | return 140 | if not self._training_started: 141 | # Don't log before the training started. 142 | # In PL, there is a validation sanity check. 
143 | return 144 | if self._val_batch_indices is None: 145 | val_batch_indices = self._set_val_batch_indices() 146 | self._subsample_val_data(val_batch_indices) 147 | 148 | flow_gt = self._val_data[self.FLOW_GT] 149 | flow_pred = self._val_data[self.FLOW_PRED] 150 | flow_valid = self._val_data[self.FLOW_VALID] 151 | ev_repr_reduced = self._val_data[self.EV_REPR_REDUCED] 152 | ev_repr_reduced_m1 = self._val_data[self.EV_REPR_REDUCED_M1] 153 | bezier_params = self._val_data[self.BEZIER_PARAMS] 154 | images = self._val_data[self.IMAGES] 155 | 156 | if flow_gt: 157 | assert flow_pred 158 | # Stack in batch dimension from list for make_grid 159 | if isinstance(flow_gt[0], list): 160 | # [[(2, H, W), ...], ...] 161 | # outer: val batch indices 162 | # inner: number of gt samples in time 163 | # We only want the inner loop but batched 164 | # This is the same as transposing the list -> use pytorch because we can just use their transpose 165 | # V: number of val samples 166 | # T: number of gt samples in time 167 | 168 | # V, T, 2, H, W 169 | new_flow_gt = torch.stack([torch.stack([single_gt_map for single_gt_map in val_sample]) for val_sample in flow_gt]) 170 | # V, T, 2, H, W -> T, V, 2, H, W 171 | new_flow_gt = torch.transpose(new_flow_gt, 0, 1) 172 | new_flow_gt = torch.split(new_flow_gt, [1]*new_flow_gt.shape[0], dim=0) 173 | # [(V, 2, H, W), ... ] T times :) the last item is the predictions for the final prediction 174 | flow_gt = [x.squeeze() for x in new_flow_gt] 175 | else: 176 | assert isinstance(flow_gt[0], torch.Tensor) 177 | flow_gt = torch.stack(flow_gt) 178 | flow_pred = torch.stack(flow_pred) 179 | flow_valid = torch.stack(flow_valid) if len(flow_valid) > 0 else None 180 | if len(ev_repr_reduced) == 0: 181 | ev_repr_reduced = None 182 | else: 183 | ev_repr_reduced = torch.stack(ev_repr_reduced) 184 | if len(ev_repr_reduced_m1) == 0: 185 | ev_repr_reduced_m1 = None 186 | else: 187 | ev_repr_reduced_m1 = torch.stack(ev_repr_reduced_m1) 188 | if len(images) == 0: 189 | images = None 190 | else: 191 | images = [torch.stack([x[0] for x in images]), torch.stack([x[1] for x in images])] 192 | 193 | summary_img = create_summary_img( 194 | flow_pred, 195 | flow_gt[-1] if isinstance(flow_gt, list) else flow_gt, 196 | flow_valid, 197 | ev_repr_reduced=ev_repr_reduced, 198 | ev_repr_reduced_m1=ev_repr_reduced_m1, 199 | images=images, 200 | max_error=self.MAX_FLOW_ERROR[self._dataset_type]) 201 | wandb_flow_img = wandb.Image(summary_img) 202 | 203 | global_step = trainer.global_step 204 | logger: WandbLogger = trainer.logger 205 | logger.log_metrics({'val/flow': wandb_flow_img}, step=global_step) 206 | 207 | if len(bezier_params) > 0: 208 | bezier_params = torch.stack(bezier_params) 209 | bezier_curves = BezierCurves(bezier_params) 210 | 211 | if images is not None: 212 | # images[0] is assumed to be the reference image 213 | background_img = img_torch_to_numpy(images[0]) 214 | else: 215 | assert ev_repr_reduced is not None 216 | background_img = ev_repr_reduced_to_img_grayscale(ev_repr_reduced) 217 | 218 | bezier_img = multi_plot_bezier_array( 219 | bezier_curves, 220 | background_img, 221 | multi_flow_gt= flow_gt if isinstance(flow_gt, list) else [flow_gt], # Could also put [flow_gt] if not list. 
222 | num_t=10, 223 | x_add_margin=30, 224 | y_add_margin=30) 225 | wandb_bezier_img = wandb.Image(bezier_img) 226 | logger.log_metrics({'val/bezier': wandb_bezier_img}, step=global_step) 227 | 228 | self._clear_val_data() 229 | 230 | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): 231 | if self.log_only_numbers: 232 | return 233 | # NOTE: How to resolve the growing memory issue throughout the validation run? 234 | # A hack would be to only return what we want to log in the validation_step function of the LightningModule, 235 | # and then log the data in the validation_epoch_end function. 236 | if not self._training_started: 237 | # Don't log before the training started. 238 | # In PL, there is a validation sanity check. 239 | return 240 | if self._dataset_type is None: 241 | self._dataset_type = batch[DataLoading.DATASET_TYPE][0].cpu().item() 242 | if self._val_batch_indices is not None: 243 | # NOTE: For the first validation run, we still save everything which can lead to crashes due to full RAM. 244 | # Once we have set the validation batch indices, 245 | # we only save data from those. 246 | if batch_idx not in self._val_batch_indices: 247 | return 248 | 249 | if isinstance(outputs['gt'], list): 250 | flow_gt = [x[0].cpu() for x in outputs['gt']] 251 | # -> list(list(tensors)) will available after end of validation epoch 252 | else: 253 | flow_gt = outputs['gt'][0].cpu() 254 | # -> list(tensors) will available after end of validation epoch 255 | flow_pred = outputs['pred'][0].cpu() 256 | flow_valid = outputs['gt_valid'][0].cpu() if 'gt_valid' in outputs else None 257 | ev_repr_reduced = None 258 | if 'ev_repr_reduced' in outputs.keys(): 259 | ev_repr_reduced = outputs['ev_repr_reduced'][0].cpu() 260 | ev_repr_reduced_m1 = None 261 | if 'ev_repr_reduced_m1' in outputs.keys(): 262 | ev_repr_reduced_m1 = outputs['ev_repr_reduced_m1'][0].cpu() 263 | images = None 264 | if 'images' in outputs.keys(): 265 | images = [x[0].detach().cpu() for x in outputs['images']] 266 | assert len(images) == 2 267 | 268 | bezier_params = None 269 | if 'bezier_prediction' in outputs.keys(): 270 | bezier_prediction: BezierCurves = outputs['bezier_prediction'] 271 | assert not bezier_prediction.requires_grad 272 | bezier_params = bezier_prediction.get_params()[0].cpu() 273 | 274 | self._save_val_data(self.FLOW_GT, flow_gt) 275 | self._save_val_data(self.FLOW_PRED, flow_pred) 276 | if flow_valid is not None: 277 | self._save_val_data(self.FLOW_VALID, flow_valid) 278 | if ev_repr_reduced is not None: 279 | self._save_val_data(self.EV_REPR_REDUCED, ev_repr_reduced) 280 | if ev_repr_reduced_m1 is not None: 281 | self._save_val_data(self.EV_REPR_REDUCED_M1, ev_repr_reduced_m1) 282 | if bezier_params is not None: 283 | self._save_val_data(self.BEZIER_PARAMS, bezier_params) 284 | if images is not None: 285 | self._save_val_data(self.IMAGES, images) 286 | self._save_val_data(self.VAL_BATCH_IDX, batch_idx) 287 | 288 | def _set_val_batch_indices(self): 289 | val_indices = self._val_data[self.VAL_BATCH_IDX] 290 | assert val_indices 291 | num_samples = min(len(val_indices), self.log_n_val_predictions) 292 | 293 | if self.deterministic: 294 | random.seed(0) 295 | sampled_indices = random.sample(val_indices, num_samples) 296 | self._val_batch_indices = set(sampled_indices) 297 | return self._val_batch_indices 298 | 299 | def _clear_val_data(self): 300 | self._val_data = { 301 | self.FLOW_GT: list(), 302 | self.FLOW_PRED: list(), 303 | self.FLOW_VALID: list(), 304 | self.EV_REPR_REDUCED: 
list(), 305 | self.EV_REPR_REDUCED_M1: list(), 306 | self.VAL_BATCH_IDX: list(), 307 | self.BEZIER_PARAMS: list(), 308 | self.IMAGES: list(), 309 | } 310 | 311 | def _subsample_val_data(self, val_indices: Union[list, set]): 312 | for k, v in self._val_data.items(): 313 | if k == self.VAL_BATCH_IDX: 314 | continue 315 | assert isinstance(v, list) 316 | subsampled_list = list() 317 | for idx, x in enumerate(v): 318 | if idx not in val_indices: 319 | continue 320 | assert is_cpu(x) 321 | subsampled_list.append(x) 322 | self._val_data[k] = subsampled_list 323 | self._val_data[self.VAL_BATCH_IDX] = list(val_indices) 324 | 325 | def _save_val_data(self, key: str, data: Union[torch.Tensor, int, float, List[torch.Tensor]]): 326 | assert key in self._val_data.keys() 327 | if isinstance(data, torch.Tensor) or isinstance(data, list): 328 | assert is_cpu(data) 329 | self._val_data[key].append(data) 330 | -------------------------------------------------------------------------------- /callbacks/utils/flow_vis.py: -------------------------------------------------------------------------------- 1 | # From: https://github.com/tomrunia/OpticalFlow_Visualization 2 | # 3 | # MIT License 4 | # 5 | # Copyright (c) 2018 Tom Runia 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to conditions. 13 | # 14 | # Author: Tom Runia 15 | # Date Created: 2018-08-03 16 | 17 | import numpy as np 18 | 19 | def make_colorwheel(): 20 | """ 21 | Generates a color wheel for optical flow visualization as presented in: 22 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 23 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 24 | Code follows the original C++ source code of Daniel Scharstein. 25 | Code follows the the Matlab source code of Deqing Sun. 26 | Returns: 27 | np.ndarray: Color wheel 28 | """ 29 | 30 | RY = 15 31 | YG = 6 32 | GC = 4 33 | CB = 11 34 | BM = 13 35 | MR = 6 36 | 37 | ncols = RY + YG + GC + CB + BM + MR 38 | colorwheel = np.zeros((ncols, 3)) 39 | col = 0 40 | 41 | # RY 42 | colorwheel[0:RY, 0] = 255 43 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 44 | col = col+RY 45 | # YG 46 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 47 | colorwheel[col:col+YG, 1] = 255 48 | col = col+YG 49 | # GC 50 | colorwheel[col:col+GC, 1] = 255 51 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 52 | col = col+GC 53 | # CB 54 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 55 | colorwheel[col:col+CB, 2] = 255 56 | col = col+CB 57 | # BM 58 | colorwheel[col:col+BM, 2] = 255 59 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 60 | col = col+BM 61 | # MR 62 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 63 | colorwheel[col:col+MR, 0] = 255 64 | return colorwheel 65 | 66 | 67 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 68 | """ 69 | Applies the flow color wheel to (possibly clipped) flow components u and v. 
70 | According to the C++ source code of Daniel Scharstein 71 | According to the Matlab source code of Deqing Sun 72 | Args: 73 | u (np.ndarray): Input horizontal flow of shape [H,W] 74 | v (np.ndarray): Input vertical flow of shape [H,W] 75 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 76 | Returns: 77 | np.ndarray: Flow visualization image of shape [H,W,3] 78 | """ 79 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 80 | colorwheel = make_colorwheel() # shape [55x3] 81 | ncols = colorwheel.shape[0] 82 | rad = np.sqrt(np.square(u) + np.square(v)) 83 | a = np.arctan2(-v, -u)/np.pi 84 | fk = (a+1) / 2*(ncols-1) 85 | k0 = np.floor(fk).astype(np.int32) 86 | k1 = k0 + 1 87 | k1[k1 == ncols] = 0 88 | f = fk - k0 89 | for i in range(colorwheel.shape[1]): 90 | tmp = colorwheel[:,i] 91 | col0 = tmp[k0] / 255.0 92 | col1 = tmp[k1] / 255.0 93 | col = (1-f)*col0 + f*col1 94 | idx = (rad <= 1) 95 | col[idx] = 1 - rad[idx] * (1-col[idx]) 96 | col[~idx] = col[~idx] * 0.75 # out of range 97 | # Note the 2-i => BGR instead of RGB 98 | ch_idx = 2-i if convert_to_bgr else i 99 | flow_image[:,:,ch_idx] = np.floor(255 * col) 100 | return flow_image 101 | 102 | 103 | def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False): 104 | """ 105 | Expects a two dimensional flow image of shape. 106 | Args: 107 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 108 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 109 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 110 | Returns: 111 | np.ndarray: Flow visualization image of shape [H,W,3] 112 | """ 113 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 114 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 115 | if clip_flow is not None: 116 | flow_uv = np.clip(flow_uv, 0, clip_flow) 117 | u = flow_uv[:,:,0] 118 | v = flow_uv[:,:,1] 119 | rad = np.sqrt(np.square(u) + np.square(v)) 120 | rad_max = np.max(rad) 121 | epsilon = 1e-5 122 | u = u / (rad_max + epsilon) 123 | v = v / (rad_max + epsilon) 124 | return flow_uv_to_colors(u, v, convert_to_bgr) -------------------------------------------------------------------------------- /config/dataset/base.yaml: -------------------------------------------------------------------------------- 1 | load_voxel_grid: True 2 | normalize_voxel_grid: True 3 | path: ??? 4 | photo_augm: False 5 | return_ev: True 6 | return_img: True -------------------------------------------------------------------------------- /config/dataset/dsec.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | 4 | name: dsec 5 | extended_voxel_grid: True # If events outside the borders should be included. May slightly increase the performance. -------------------------------------------------------------------------------- /config/dataset/multiflow_regen.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | 4 | name: multiflow_regen 5 | extended_voxel_grid: True # If events outside the borders should be included. May slightly increase the performance. 
6 | downsample: False 7 | flow_every_n_ms: 50 8 | -------------------------------------------------------------------------------- /config/experiment/dsec/raft_spline/E_I_LU4_BD2_lowpyramid.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: raft-spline 4 | 5 | training: 6 | limit_val_batches: 0 # no validation 7 | multi_loss: false 8 | max_steps: 250000 9 | lr_scheduler: 10 | total_steps: ${..max_steps} 11 | model: 12 | bezier_degree: 2 13 | use_boundary_images: true 14 | use_events: true 15 | correlation: 16 | ev: 17 | target_indices: [1, 2, 3, 4] 18 | levels: [1, 1, 1, 4] 19 | radius: [4, 4, 4, 4] 20 | img: 21 | levels: 4 22 | radius: 4 23 | -------------------------------------------------------------------------------- /config/experiment/dsec/raft_spline/E_LU4_BD2_lowpyramid.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: raft-spline 4 | 5 | training: 6 | limit_val_batches: 0 # no validation 7 | multi_loss: false 8 | max_steps: 250000 9 | lr_scheduler: 10 | total_steps: ${..max_steps} 11 | model: 12 | bezier_degree: 2 13 | use_boundary_images: false 14 | use_events: true 15 | correlation: 16 | ev: 17 | target_indices: [1, 2, 3, 4] 18 | levels: [1, 1, 1, 4] 19 | radius: [4, 4, 4, 4] 20 | img: 21 | levels: null 22 | radius: null 23 | -------------------------------------------------------------------------------- /config/experiment/multiflow/raft_spline/E_I_LU5_BD10_lowpyramid.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: raft-spline 4 | 5 | training: 6 | multi_loss: true 7 | batch_size: 3 8 | learning_rate: 0.0001 9 | weight_decay: 0.0001 10 | lr_scheduler: 11 | use: true 12 | model: 13 | num_bins: 14 | context: 41 15 | correlation: 25 16 | bezier_degree: 10 17 | use_boundary_images: true 18 | use_events: true 19 | correlation: 20 | ev: 21 | #target_indices: [1, 2, 3, 4, 5] # for 6 context bins 22 | #target_indices: [2, 4, 6, 8, 10] # for 11 context bins 23 | target_indices: [8, 16, 24, 32, 40] # for 41 context bins 24 | levels: [1, 1, 1, 1, 4] 25 | radius: [4, 4, 4, 4, 4] 26 | img: 27 | levels: 4 28 | radius: 4 29 | -------------------------------------------------------------------------------- /config/experiment/multiflow/raft_spline/E_LU5_BD10_lowpyramid.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: raft-spline 4 | 5 | training: 6 | multi_loss: true 7 | batch_size: 3 8 | learning_rate: 0.0001 9 | weight_decay: 0.0001 10 | lr_scheduler: 11 | use: true 12 | model: 13 | num_bins: 14 | context: 41 15 | correlation: 25 16 | bezier_degree: 10 17 | use_boundary_images: false 18 | use_events: true 19 | correlation: 20 | ev: 21 | #target_indices: [1, 2, 3, 4, 5] # for 6 context bins 22 | #target_indices: [2, 4, 6, 8, 10] # for 11 context bins 23 | target_indices: [8, 16, 24, 32, 40] # for 41 context bins 24 | levels: [1, 1, 1, 1, 4] 25 | radius: [4, 4, 4, 4, 4] 26 | -------------------------------------------------------------------------------- /config/general.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | multi_loss: true # ATTENTION: This can only be used for the MultiFlow 3 | batch_size: 3 4 | max_epochs: 1000 5 | max_steps: 200000 6 | 
learning_rate: 0.0001 7 | weight_decay: 0.0001 8 | gradient_clip_val: 1 9 | limit_train_batches: 1 10 | limit_val_batches: 1 11 | lr_scheduler: 12 | use: true 13 | total_steps: ${..max_steps} 14 | pct_start: 0.01 15 | hardware: 16 | num_workers: null # Number of workers. By default will take twice the maximum batch size 17 | gpus: 0 # Either a single integer (e.g. 3) or a list of integers (e.g. [3, 5, 6]) 18 | logging: 19 | only_numbers: False 20 | ckpt_every_n_epochs: 1 21 | log_every_n_steps: 5000 22 | flush_logs_every_n_steps: 1000 23 | log_n_val_predictions: 2 24 | wandb: 25 | wandb_runpath: null # Specify WandB run path if you wish to resume from that run. E.g. magehrig/eRAFT/1lmicg6t 26 | artifact_runpath: null # Specify WandB run path if you wish to resume with a checkpoint/artifact from that run. Will take current wandb runpath if not specified 27 | artifact_name: null # Name of the checkpoint/artifact from which to resume. E.g. checkpoint-1ae609sb:v5 28 | resume_only_weights: False # If artifact is provided, you can choose to resume only the weights. Otherwise, the full training state is restored 29 | project_name: contflow # Specify group name of the run 30 | group_name: ??? # Specify group name of the run 31 | debugging: 32 | test_cpu_dataloading: False 33 | profiler: null # {None, simple, advanced, pytorch} 34 | -------------------------------------------------------------------------------- /config/model/base.yaml: -------------------------------------------------------------------------------- 1 | num_bins: 2 | context: 5 3 | -------------------------------------------------------------------------------- /config/model/raft-spline.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - raft_base 3 | 4 | name: raft-spline 5 | detach_bezier: false 6 | bezier_degree: 2 7 | use_boundary_images: true 8 | use_events: true 9 | use_gma: false 10 | correlation: 11 | ev: 12 | target_indices: [1, 2, 3, 4] # 0 idx is the reference. num_bins_context - 1 is the maximum idx. 13 | levels: [1, 2, 3, 4] # Number of pyramid levels. Must have the same length as target_indices. 14 | radius: [4, 4, 4, 4] # Look-up radius. Must have the same length as target_indices. 15 | img: 16 | levels: 4 17 | radius: 4 -------------------------------------------------------------------------------- /config/model/raft_base.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | 4 | num_bins: 5 | correlation: null 6 | num_iter: 7 | train: 12 8 | test: 12 9 | correlation: 10 | use_cosine_sim: false 11 | hidden: 12 | dim: 128 13 | context: 14 | dim: 128 15 | norm: batch 16 | feature: 17 | dim: 256 18 | norm: instance 19 | motion: 20 | dim: 128 21 | -------------------------------------------------------------------------------- /config/train.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - general 3 | - dataset: ??? 4 | - model: ??? -------------------------------------------------------------------------------- /config/val.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: ??? 3 | - model: ??? 4 | - _self_ 5 | 6 | checkpoint: ??? 
7 | hardware: 8 | num_workers: 4 9 | gpus: 0 # GPU idx (multi-gpu not supported for validation) 10 | batch_size: 8 -------------------------------------------------------------------------------- /data/dsec/eventslicer.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Dict, Tuple 3 | 4 | import h5py 5 | from numba import jit 6 | import numpy as np 7 | 8 | 9 | class EventSlicer: 10 | def __init__(self, h5f: h5py.File): 11 | self.h5f = h5f 12 | 13 | self.events = dict() 14 | for dset_str in ['p', 'x', 'y', 't']: 15 | self.events[dset_str] = self.h5f['events/{}'.format(dset_str)] 16 | 17 | # This is the mapping from milliseconds to event index: 18 | # It is defined such that 19 | # (1) t[ms_to_idx[ms]] >= ms*1000 20 | # (2) t[ms_to_idx[ms] - 1] < ms*1000 21 | # ,where 'ms' is the time in milliseconds and 't' the event timestamps in microseconds. 22 | # 23 | # As an example, given 't' and 'ms': 24 | # t: 0 500 2100 5000 5000 7100 7200 7200 8100 9000 25 | # ms: 0 1 2 3 4 5 6 7 8 9 26 | # 27 | # we get 28 | # 29 | # ms_to_idx: 30 | # 0 2 2 3 3 3 5 5 8 9 31 | self.ms_to_idx = np.asarray(self.h5f['ms_to_idx'], dtype='int64') 32 | 33 | self.t_offset = int(h5f['t_offset'][()]) 34 | self.t_final = int(self.events['t'][-1]) + self.t_offset 35 | 36 | def get_start_time_us(self): 37 | return self.t_offset 38 | 39 | def get_final_time_us(self): 40 | return self.t_final 41 | 42 | def get_events(self, t_start_us: int, t_end_us: int) -> Dict[str, np.ndarray]: 43 | """Get events (p, x, y, t) within the specified time window 44 | Parameters 45 | ---------- 46 | t_start_us: start time in microseconds 47 | t_end_us: end time in microseconds 48 | Returns 49 | ------- 50 | events: dictionary of (p, x, y, t) or None if the time window cannot be retrieved 51 | """ 52 | assert t_start_us < t_end_us 53 | 54 | # We assume that the times are top-off-day, hence subtract offset: 55 | t_start_us -= self.t_offset 56 | t_end_us -= self.t_offset 57 | 58 | t_start_ms, t_end_ms = self.get_conservative_window_ms(t_start_us, t_end_us) 59 | t_start_ms_idx = self.ms2idx(t_start_ms) 60 | t_end_ms_idx = self.ms2idx(t_end_ms) 61 | 62 | if t_start_ms_idx is None or t_end_ms_idx is None: 63 | # Cannot guarantee window size anymore 64 | return None 65 | 66 | events = dict() 67 | time_array_conservative = np.asarray(self.events['t'][t_start_ms_idx:t_end_ms_idx]) 68 | idx_start_offset, idx_end_offset = self.get_time_indices_offsets(time_array_conservative, t_start_us, t_end_us) 69 | t_start_us_idx = t_start_ms_idx + idx_start_offset 70 | t_end_us_idx = t_start_ms_idx + idx_end_offset 71 | # Again add t_offset to get gps time 72 | events['t'] = time_array_conservative[idx_start_offset:idx_end_offset] + self.t_offset 73 | for dset_str in ['p', 'x', 'y']: 74 | events[dset_str] = np.asarray(self.events[dset_str][t_start_us_idx:t_end_us_idx]) 75 | assert events[dset_str].size == events['t'].size 76 | return events 77 | 78 | 79 | @staticmethod 80 | def get_conservative_window_ms(ts_start_us: int, ts_end_us) -> Tuple[int, int]: 81 | """Compute a conservative time window of time with millisecond resolution. 82 | We have a time to index mapping for each millisecond. Hence, we need 83 | to compute the lower and upper millisecond to retrieve events. 
84 | Parameters 85 | ---------- 86 | ts_start_us: start time in microseconds 87 | ts_end_us: end time in microseconds 88 | Returns 89 | ------- 90 | window_start_ms: conservative start time in milliseconds 91 | window_end_ms: conservative end time in milliseconds 92 | """ 93 | assert ts_end_us > ts_start_us 94 | window_start_ms = math.floor(ts_start_us/1000) 95 | window_end_ms = math.ceil(ts_end_us/1000) 96 | return window_start_ms, window_end_ms 97 | 98 | @staticmethod 99 | @jit(nopython=True) 100 | def get_time_indices_offsets( 101 | time_array: np.ndarray, 102 | time_start_us: int, 103 | time_end_us: int) -> Tuple[int, int]: 104 | """Compute index offset of start and end timestamps in microseconds 105 | Parameters 106 | ---------- 107 | time_array: timestamps (in us) of the events 108 | time_start_us: start timestamp (in us) 109 | time_end_us: end timestamp (in us) 110 | Returns 111 | ------- 112 | idx_start: Index within this array corresponding to time_start_us 113 | idx_end: Index within this array corresponding to time_end_us 114 | such that (in non-edge cases) 115 | time_array[idx_start] >= time_start_us 116 | time_array[idx_end] >= time_end_us 117 | time_array[idx_start - 1] < time_start_us 118 | time_array[idx_end - 1] < time_end_us 119 | this means that 120 | time_start_us <= time_array[idx_start:idx_end] < time_end_us 121 | """ 122 | 123 | assert time_array.ndim == 1 124 | 125 | idx_start = -1 126 | if time_array[-1] < time_start_us: 127 | # This can happen in extreme corner cases. E.g. 128 | # time_array[0] = 1016 129 | # time_array[-1] = 1984 130 | # time_start_us = 1990 131 | # time_end_us = 2000 132 | 133 | # Return same index twice: array[x:x] is empty. 134 | return time_array.size, time_array.size 135 | else: 136 | # TODO(magehrig): use binary search for speedup 137 | for idx_from_start in range(0, time_array.size, 1): 138 | if time_array[idx_from_start] >= time_start_us: 139 | idx_start = idx_from_start 140 | break 141 | assert idx_start >= 0 142 | 143 | idx_end = time_array.size 144 | # TODO(magehrig): use binary search for speedup 145 | for idx_from_end in range(time_array.size - 1, -1, -1): 146 | if time_array[idx_from_end] >= time_end_us: 147 | idx_end = idx_from_end 148 | else: 149 | break 150 | 151 | assert time_array[idx_start] >= time_start_us 152 | if idx_end < time_array.size: 153 | assert time_array[idx_end] >= time_end_us 154 | if idx_start > 0: 155 | assert time_array[idx_start - 1] < time_start_us 156 | if idx_end > 0: 157 | assert time_array[idx_end - 1] < time_end_us 158 | return idx_start, idx_end 159 | 160 | def ms2idx(self, time_ms: int) -> int: 161 | assert time_ms >= 0 162 | if time_ms >= self.ms_to_idx.size: 163 | return None 164 | return self.ms_to_idx[time_ms] 165 | -------------------------------------------------------------------------------- /data/dsec/provider.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from pathlib import Path 3 | from typing import Dict, Any 4 | 5 | import torch 6 | 7 | from data.dsec.sequence import generate_sequence 8 | from data.dsec.subsequence.twostep import TwoStepSubSequence 9 | from data.utils.provider import DatasetProviderBase 10 | 11 | 12 | class DatasetProvider(DatasetProviderBase): 13 | def __init__(self, 14 | dataset_params: Dict[str, Any], 15 | nbins_context: int): 16 | dataset_path = Path(dataset_params['path']) 17 | 18 | train_path = dataset_path / 'train' 19 | test_path = dataset_path / 'test' 20 | assert dataset_path.is_dir(), str(dataset_path) 
21 | assert train_path.is_dir(), str(train_path) 22 | assert test_path.is_dir(), str(test_path) 23 | 24 | # NOTE: For now, we assume that the number of bins for the correlation are the same as for the context. 25 | self.nbins = nbins_context 26 | 27 | subseq_class = TwoStepSubSequence 28 | 29 | base_args = { 30 | 'num_bins': self.nbins, 31 | 'load_voxel_grid': dataset_params['load_voxel_grid'], 32 | 'extended_voxel_grid': dataset_params['extended_voxel_grid'], 33 | 'normalize_voxel_grid': dataset_params['normalize_voxel_grid'], 34 | } 35 | base_args.update({'merge_grids': True}) 36 | train_args = copy.deepcopy(base_args) 37 | train_args.update({'data_augm': True}) 38 | test_args = copy.deepcopy(base_args) 39 | test_args.update({'data_augm': False}) 40 | 41 | train_sequences = list() 42 | for child in train_path.iterdir(): 43 | sequence = generate_sequence(child, subseq_class, train_args) 44 | if sequence is not None: 45 | train_sequences.append(sequence) 46 | 47 | self.train_dataset = torch.utils.data.ConcatDataset(train_sequences) 48 | 49 | # TODO: write specialized test sequence 50 | #test_sequences = list() 51 | #for child in test_path.iterdir(): 52 | # sequence = generate_test_sequence(child, subseq_class, test_args) 53 | # if sequence is not None: 54 | # test_sequences.append(sequence) 55 | #self.test_dataset = torch.utils.data.ConcatDataset(test_sequences) 56 | self.test_dataset = None 57 | 58 | def get_train_dataset(self): 59 | return self.train_dataset 60 | 61 | def get_val_dataset(self): 62 | raise NotImplementedError 63 | 64 | def get_test_dataset(self): 65 | return self.test_dataset 66 | 67 | def get_nbins_context(self): 68 | return self.nbins 69 | 70 | def get_nbins_correlation(self): 71 | return self.nbins 72 | -------------------------------------------------------------------------------- /data/dsec/sequence.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Type 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from data.dsec.subsequence.base import BaseSubSequence 8 | 9 | 10 | class SubSequenceGenerator: 11 | # seq_name (e.g. zurich_city_10_a) 12 | # ├── flow 13 | # │ ├── backward (for train only) 14 | # │ │ ├── xxxxxx.png 15 | # │ │ └── ... 16 | # │ ├── backward_timestamps.txt (for train only) 17 | # │ ├── forward (for train only) 18 | # │ │ ├── xxxxxx.png 19 | # │ │ └── ... 
20 | # │ └── forward_timestamps.txt (for train and test) 21 | # └── events 22 | # ├── left 23 | # │ ├── events.h5 24 | # │ └── rectify_map.h5 25 | # └── right 26 | # ├── events.h5 27 | # └── rectify_map.h5 28 | def __init__(self, 29 | seq_path: Path, 30 | subseq_class: Type[BaseSubSequence], 31 | args: dict): 32 | 33 | self.args = args 34 | self.seq_path = seq_path 35 | 36 | self.seqseq_class = subseq_class 37 | 38 | # load forward optical flow timestamps 39 | assert sequence_has_flow(seq_path), seq_path 40 | flow_dir = seq_path / 'flow' 41 | assert flow_dir.is_dir(), str(flow_dir) 42 | forward_timestamps_file = flow_dir / 'forward_timestamps.txt' 43 | assert forward_timestamps_file.is_file() 44 | self.forward_flow_timestamps = np.loadtxt(str(forward_timestamps_file), dtype='int64', delimiter=',') 45 | assert self.forward_flow_timestamps.ndim == 2 46 | assert self.forward_flow_timestamps.shape[1] == 2 47 | 48 | # load forward optical flow paths 49 | forward_flow_dir = flow_dir / 'forward' 50 | assert forward_flow_dir.is_dir() 51 | forward_flow_list = list() 52 | for entry in forward_flow_dir.iterdir(): 53 | assert str(entry.name).endswith('.png') 54 | forward_flow_list.append(entry) 55 | forward_flow_list.sort() 56 | self.forward_flow_list = forward_flow_list 57 | 58 | # Extract start indices of sub-sequences 59 | from_ts = self.forward_flow_timestamps[:, 0] 60 | to_ts = self.forward_flow_timestamps[:, 1] 61 | 62 | is_start_subseq = from_ts[1:] != to_ts[:-1] 63 | # Add first index as start index too. 64 | is_start_subseq = np.concatenate((np.array((True,), dtype='bool'), is_start_subseq)) 65 | self.start_indices = list(np.where(is_start_subseq)[0]) 66 | 67 | self.subseq_idx = 0 68 | 69 | def __enter__(self): 70 | return self 71 | 72 | def __iter__(self): 73 | return self 74 | 75 | def __len__(self): 76 | return len(self.start_indices) 77 | 78 | def __next__(self) -> BaseSubSequence: 79 | if self.subseq_idx >= len(self.start_indices): 80 | raise StopIteration 81 | final_subseq = self.subseq_idx == len(self.start_indices) - 1 82 | 83 | start_idx = self.start_indices[self.subseq_idx] 84 | end_p1_idx = None if final_subseq else self.start_indices[self.subseq_idx + 1] 85 | 86 | forward_flow_timestamps = self.forward_flow_timestamps[start_idx:end_p1_idx, :] 87 | forward_flow_list = self.forward_flow_list[start_idx:end_p1_idx] 88 | 89 | self.subseq_idx += 1 90 | 91 | return self.seqseq_class(self.seq_path, forward_flow_timestamps, forward_flow_list, **self.args) 92 | 93 | 94 | def sequence_has_flow(seq_path: Path): 95 | return (seq_path / 'flow').is_dir() 96 | 97 | 98 | def generate_sequence( 99 | seq_path: Path, 100 | subseq_class: Type[BaseSubSequence], 101 | args: dict): 102 | if not sequence_has_flow(seq_path): 103 | return None 104 | 105 | subseq_list = list() 106 | 107 | for subseq in SubSequenceGenerator(seq_path, subseq_class, args): 108 | subseq_list.append(subseq) 109 | 110 | return torch.utils.data.ConcatDataset(subseq_list) 111 | -------------------------------------------------------------------------------- /data/dsec/subsequence/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import List, Optional 4 | import weakref 5 | 6 | import h5py 7 | import imageio as iio 8 | import numpy as np 9 | from PIL import Image 10 | import torch 11 | from torch.utils.data import Dataset 12 | 13 | from data.dsec.eventslicer import EventSlicer 14 | from data.utils.augmentor import FlowAugmentor 15 | 
from data.utils.generic import np_array_to_h5, h5_to_np_array 16 | from data.utils.representations import VoxelGrid, norm_voxel_grid 17 | 18 | 19 | class BaseSubSequence(Dataset): 20 | # seq_name (e.g. zurich_city_10_a) 21 | # ├── flow 22 | # │ ├── backward (for train only) 23 | # │ │ ├── xxxxxx.png 24 | # │ │ └── ... 25 | # │ ├── backward_timestamps.txt (for train only) 26 | # │ ├── forward (for train only) 27 | # │ │ ├── xxxxxx.png 28 | # │ │ └── ... 29 | # │ └── forward_timestamps.txt (for train and test) 30 | # └── events 31 | # ├── left 32 | # │ ├── events.h5 33 | # │ └── rectify_map.h5 34 | # └── right 35 | # ├── events.h5 36 | # └── rectify_map.h5 37 | # 38 | # For now this class 39 | # - only returns forward flow and the corresponding event representation (of the left event camera) 40 | # - does not implement recurrent loading of data 41 | # - only returns a forward rectified voxel grid with a fixed number of bins 42 | # - is not implemented to offer multi-dataset training schedules (like in RAFT) 43 | def __init__(self, 44 | seq_path: Path, 45 | forward_flow_timestamps: np.ndarray, 46 | forward_flow_paths: List[Path], 47 | data_augm: bool, 48 | num_bins: int=15, 49 | load_voxel_grid: bool=True, 50 | extended_voxel_grid: bool=True, 51 | normalize_voxel_grid: bool=False): 52 | assert num_bins >= 1 53 | assert seq_path.is_dir() 54 | 55 | # Save output dimensions 56 | self.height = 480 57 | self.width = 640 58 | self.num_bins = num_bins 59 | 60 | self.augmentor = FlowAugmentor(crop_size_hw=(288, 384)) if data_augm else None 61 | 62 | # Set event representation 63 | self.voxel_grid = VoxelGrid(self.num_bins, self.height, self.width) 64 | self.normalize_voxel_grid: Optional[norm_voxel_grid] = norm_voxel_grid if normalize_voxel_grid else None 65 | 66 | assert len(forward_flow_paths) == forward_flow_timestamps.shape[0] 67 | 68 | self.forward_flow_timestamps = forward_flow_timestamps 69 | 70 | for entry in forward_flow_paths: 71 | assert entry.exists() 72 | assert str(entry.name).endswith('.png') 73 | self.forward_flow_list = forward_flow_paths 74 | 75 | ### prepare loading of event data and load rectification map 76 | self.ev_dir = seq_path / 'events' / 'left' 77 | assert self.ev_dir.is_dir() 78 | self.ev_file = self.ev_dir / 'events.h5' 79 | assert self.ev_file.exists() 80 | 81 | rectify_events_map_file = self.ev_dir / 'rectify_map.h5' 82 | assert rectify_events_map_file.exists() 83 | with h5py.File(str(rectify_events_map_file), 'r') as h5_rect: 84 | self.rectify_events_map = h5_rect['rectify_map'][()] 85 | 86 | self.h5f: Optional[h5py.File] = None 87 | self.event_slicer: Optional[EventSlicer] = None 88 | self.h5f_opened = False 89 | 90 | ### prepare loading of image data (in left event camera frame) 91 | img_dir_ev_left = seq_path / 'images' / 'left' / 'ev_inf' 92 | self.img_dir_ev_left = None if not img_dir_ev_left.is_dir() else img_dir_ev_left 93 | 94 | ### Voxel Grid Saving 95 | # Version 0: Without considering the boundary properly but strictly causal 96 | # Version 1: Considering the boundary effect but loads a few events "of the future" 97 | self.version = 1 if extended_voxel_grid else 0 98 | self.voxel_grid_dir = self.ev_dir / f'voxel_grids_v{self.version}_100ms_forward_{num_bins}_bins' 99 | self.load_voxel_grid = load_voxel_grid 100 | if self.load_voxel_grid: 101 | if not self.voxel_grid_dir.exists(): 102 | os.mkdir(self.voxel_grid_dir) 103 | else: 104 | assert self.voxel_grid_dir.is_dir() 105 | 106 | def __open_h5f(self): 107 | assert self.h5f is None 108 | assert 
self.event_slicer is None 109 | 110 | self.h5f = h5py.File(str(self.ev_file), 'r') 111 | self.event_slicer = EventSlicer(self.h5f) 112 | 113 | self._finalizer = weakref.finalize(self, self.__close_callback, self.h5f) 114 | self.h5f_opened = True 115 | 116 | def _events_to_voxel_grid(self, x, y, p, t, t0_center: int=None, t1_center: int=None): 117 | t = t.astype('int64') 118 | x = x.astype('float32') 119 | y = y.astype('float32') 120 | pol = p.astype('float32') 121 | return self.voxel_grid.convert( 122 | torch.from_numpy(x), 123 | torch.from_numpy(y), 124 | torch.from_numpy(pol), 125 | torch.from_numpy(t), 126 | t0_center, 127 | t1_center) 128 | 129 | def getHeightAndWidth(self): 130 | return self.height, self.width 131 | 132 | @staticmethod 133 | def __close_callback(h5f: h5py.File): 134 | assert h5f is not None 135 | h5f.close() 136 | 137 | def _rectify_events(self, x: np.ndarray, y: np.ndarray): 138 | # From distorted to undistorted 139 | rectify_map = self.rectify_events_map 140 | assert rectify_map.shape == (self.height, self.width, 2), rectify_map.shape 141 | assert x.max() < self.width 142 | assert y.max() < self.height 143 | return rectify_map[y, x] 144 | 145 | def _get_ev_left_img(self, img_file_idx: int): 146 | if self.img_dir_ev_left is None: 147 | return None 148 | img_filename = f'{img_file_idx}'.zfill(6) + '.png' 149 | img_filepath = self.img_dir_ev_left / img_filename 150 | if not img_filepath.exists(): 151 | return None 152 | img = np.asarray(iio.imread(str(img_filepath), format='PNG-FI')) 153 | # img: (h, w, c) -> (c, h, w) 154 | img = np.moveaxis(img, -1, 0) 155 | return img 156 | 157 | def _get_events(self, ts_from: int, ts_to: int, rectify: bool): 158 | if not self.h5f_opened: 159 | self.__open_h5f() 160 | 161 | start_time_us = self.event_slicer.get_start_time_us() 162 | final_time_us = self.event_slicer.get_final_time_us() 163 | assert ts_from > start_time_us - 50000, 'Do not request more than 50 ms before the minimum time. Otherwise, something might be wrong.' 164 | assert ts_to < final_time_us + 50000, 'Do not request more than 50 ms past the maximum time. Otherwise, something might be wrong.' 165 | if ts_from < start_time_us: 166 | # Take the minimum time instead to avoid assertions in the event slicer. 167 | ts_from = start_time_us 168 | if ts_to > final_time_us: 169 | # Take the maximum time instead to avoid assertions in the event slicer. 
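                # NOTE: combined with the 50 ms assertions above, this clamping keeps the
                # requested window inside the recording while shifting it by at most 50 ms.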
170 | ts_to = final_time_us 171 | assert ts_from < ts_to 172 | 173 | event_data = self.event_slicer.get_events(ts_from, ts_to) 174 | 175 | pol = event_data['p'] 176 | time = event_data['t'] 177 | x = event_data['x'] 178 | y = event_data['y'] 179 | 180 | if rectify: 181 | xy_rect = self._rectify_events(x, y) 182 | x = xy_rect[:, 0] 183 | y = xy_rect[:, 1] 184 | assert pol.shape == time.shape == x.shape == y.shape 185 | 186 | out = { 187 | 'pol': pol, 188 | 'time': time, 189 | 'x': x, 190 | 'y': y, 191 | } 192 | return out 193 | 194 | def _construct_voxel_grid(self, ts_from: int, ts_to: int, rectify: bool=True): 195 | if self.version == 1: 196 | t_start, t_end = self.voxel_grid.get_extended_time_window(ts_from, ts_to) 197 | assert (ts_from - t_start) < 50000, f'ts_from: {ts_from}, t_start: {t_start}' 198 | assert (t_end - ts_to) < 50000, f't_end: {t_end}, ts_to: {ts_to}' 199 | event_data = self._get_events(t_start, t_end, rectify=rectify) 200 | voxel_grid = self._events_to_voxel_grid(event_data['x'], event_data['y'], event_data['pol'], event_data['time'], ts_from, ts_to) 201 | return voxel_grid 202 | elif self.version == 0: 203 | event_data = self._get_events(ts_from, ts_to, rectify=rectify) 204 | voxel_grid = self._events_to_voxel_grid(event_data['x'], event_data['y'], event_data['pol'], event_data['time']) 205 | return voxel_grid 206 | raise NotImplementedError 207 | 208 | def _load_voxel_grid(self, ts_from: int, ts_to: int, file_index: int): 209 | assert file_index >= 0 210 | 211 | # Assuming we want to load the 'forward' voxel grid. 212 | voxel_grid_file = self.voxel_grid_dir / (f'{file_index}'.zfill(6) + '.h5') 213 | if not voxel_grid_file.exists(): 214 | voxel_grid = self._construct_voxel_grid(ts_from, ts_to) 215 | np_array_to_h5(voxel_grid.numpy(), voxel_grid_file) 216 | return voxel_grid 217 | return torch.from_numpy(h5_to_np_array(voxel_grid_file)) 218 | 219 | def _get_voxel_grid(self, ts_from: int, ts_to: int, file_index: int): 220 | if self.load_voxel_grid: 221 | return self._load_voxel_grid(ts_from, ts_to, file_index) 222 | return self._construct_voxel_grid(ts_from, ts_to) -------------------------------------------------------------------------------- /data/dsec/subsequence/twostep.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import List 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from data.dsec.subsequence.base import BaseSubSequence 8 | from data.utils.generic import load_flow 9 | from data.utils.keys import DataLoading, DataSetType 10 | 11 | 12 | class TwoStepSubSequence(BaseSubSequence): 13 | def __init__(self, 14 | seq_path: Path, 15 | forward_flow_timestamps: np.ndarray, 16 | forward_flow_paths: List[Path], 17 | data_augm: bool, 18 | num_bins: int, 19 | load_voxel_grid: bool, 20 | extended_voxel_grid: bool, 21 | normalize_voxel_grid: bool, 22 | merge_grids: bool): 23 | super().__init__( 24 | seq_path, 25 | forward_flow_timestamps, 26 | forward_flow_paths, 27 | data_augm, 28 | num_bins, 29 | load_voxel_grid, 30 | extended_voxel_grid=extended_voxel_grid, 31 | normalize_voxel_grid=normalize_voxel_grid) 32 | self.merge_grids = merge_grids 33 | 34 | def __len__(self): 35 | return len(self.forward_flow_list) 36 | 37 | def __getitem__(self, index): 38 | forward_flow_gt_path = self.forward_flow_list[index] 39 | flow_file_index = int(forward_flow_gt_path.stem) 40 | forward_flow, forward_flow_valid2D = load_flow(forward_flow_gt_path) 41 | # forward_flow: (h, w, 2) -> (2, h, w) 42 | 
forward_flow = np.moveaxis(forward_flow, -1, 0) 43 | 44 | # Events during the flow duration. 45 | ev_repr_list = list() 46 | ts_from = None 47 | ts_to = None 48 | for idx in [index, index - 1]: 49 | if self._is_index_valid(idx): 50 | ts = self.forward_flow_timestamps[idx] 51 | ts_from = ts[0] 52 | ts_to = ts[1] 53 | else: 54 | assert idx == index - 1 55 | assert ts_from is not None 56 | assert ts_to is not None 57 | dt = ts_to - ts_from 58 | ts_to = ts_from 59 | ts_from = ts_from - dt 60 | # Hardcoded assumption about 100ms steps and filenames. 61 | file_index = flow_file_index if idx == index else flow_file_index - 2 62 | ev_repr = self._get_voxel_grid(ts_from, ts_to, file_index) 63 | ev_repr_list.append(ev_repr) 64 | 65 | imgs_list = None 66 | img_idx_reference = flow_file_index 67 | img_reference = self._get_ev_left_img(img_idx_reference) 68 | if img_reference is not None: 69 | # Assume 100ms steps (take every second frame). 70 | # Assume forward flow (target frame in the future) 71 | img_idx_target = img_idx_reference + 2 72 | img_target = self._get_ev_left_img(img_idx_target) 73 | assert img_target is not None 74 | imgs_list = [img_reference, img_target] 75 | 76 | # 0: previous, 1: current 77 | ev_repr_list.reverse() 78 | if self.merge_grids: 79 | ev_repr_0 = ev_repr_list[0] 80 | ev_repr_1 = ev_repr_list[1] 81 | assert (ev_repr_0[-1] - ev_repr_1[0]).flatten().abs().max() < 0.5, f'{(ev_repr_0[-1] - ev_repr_1[0]).flatten().abs().max()}' 82 | # Remove the redundant temporal slice. 83 | event_representations = torch.cat((ev_repr_0, ev_repr_1[1:, ...]), dim=0) 84 | if self.normalize_voxel_grid is not None: 85 | event_representations = self.normalize_voxel_grid(event_representations) 86 | else: 87 | if self.normalize_voxel_grid is not None: 88 | ev_repr_list = [self.normalize_voxel_grid(voxel_grid) for voxel_grid in ev_repr_list] 89 | event_representations = torch.stack(ev_repr_list) 90 | 91 | if self.augmentor is not None: 92 | event_representations, forward_flow, forward_flow_valid2D, imgs_list = self.augmentor(event_representations, forward_flow, forward_flow_valid2D, imgs_list) 93 | 94 | output = { 95 | DataLoading.FLOW: forward_flow, 96 | DataLoading.FLOW_VALID: forward_flow_valid2D, 97 | DataLoading.FILE_INDEX: flow_file_index, 98 | # For now returns: 2 x bins x height x width or (2*bins-1) x height x width 99 | DataLoading.EV_REPR: event_representations, 100 | DataLoading.DATASET_TYPE: DataSetType.DSEC, 101 | } 102 | if imgs_list is not None: 103 | output.update({DataLoading.IMG: imgs_list}) 104 | 105 | return output 106 | 107 | def _is_index_valid(self, index): 108 | return index >= 0 and index < len(self) 109 | -------------------------------------------------------------------------------- /data/multiflow2d/datasubset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional, List 3 | 4 | from torch.utils.data import Dataset 5 | 6 | from data.utils.augmentor import FlowAugmentor, PhotoAugmentor 7 | from data.utils.keys import DataLoading, DataSetType 8 | from data.utils.representations import norm_voxel_grid 9 | 10 | from data.multiflow2d.sample import Sample 11 | 12 | 13 | class Datasubset(Dataset): 14 | def __init__(self, 15 | train_or_val_path: Path, 16 | data_augm: bool, 17 | num_bins_context: int, 18 | flow_every_n_ms: int, 19 | load_voxel_grid: bool=True, 20 | extended_voxel_grid: bool=True, 21 | normalize_voxel_grid: bool=False, 22 | downsample: bool=False, 23 | photo_augm: bool=False, 
24 | return_img: bool=True, 25 | return_ev: bool=True): 26 | assert train_or_val_path.is_dir() 27 | assert train_or_val_path.name in ('train', 'val') 28 | 29 | # Save output dimensions 30 | original_height = 384 31 | original_width = 512 32 | 33 | crop_height = 368 34 | crop_width = 496 35 | 36 | self.return_img = return_img 37 | if not self.return_img: 38 | raise NotImplementedError 39 | self.return_ev = return_ev 40 | 41 | if downsample: 42 | crop_height = crop_height // 2 43 | crop_width = crop_width // 2 44 | 45 | self.delta_ts_flow_ms = flow_every_n_ms 46 | 47 | self.spatial_augmentor = FlowAugmentor( 48 | crop_size_hw=(crop_height, crop_width), 49 | h_flip_prob=0.5, 50 | v_flip_prob=0.5) if data_augm else None 51 | self.photo_augmentor = PhotoAugmentor( 52 | brightness=0.4, 53 | contrast=0.4, 54 | saturation=0.4, 55 | hue=0.5/3.14, 56 | probability_color=0.2, 57 | noise_variance_range=(0.001, 0.01), 58 | probability_noise=0.2) if data_augm and photo_augm else None 59 | self.normalize_voxel_grid: Optional[norm_voxel_grid] = norm_voxel_grid if normalize_voxel_grid else None 60 | 61 | sample_list: List[Sample] = list() 62 | for sample_path in train_or_val_path.iterdir(): 63 | if not sample_path.is_dir(): 64 | continue 65 | sample_list.append( 66 | Sample(sample_path, original_height, original_width, num_bins_context, load_voxel_grid, extended_voxel_grid, downsample) 67 | ) 68 | self.sample_list = sample_list 69 | 70 | def get_num_bins_context(self): 71 | return self.sample_list[0].num_bins_context 72 | 73 | def get_num_bins_correlation(self): 74 | return self.sample_list[0].num_bins_correlation 75 | 76 | def get_num_bins_total(self): 77 | return self.sample_list[0].num_bins_total 78 | 79 | def _voxel_grid_bin_idx_for_reference(self) -> int: 80 | return self.sample_list[0].voxel_grid_bin_idx_for_reference() 81 | 82 | def __len__(self): 83 | return len(self.sample_list) 84 | 85 | def __getitem__(self, index): 86 | sample = self.sample_list[index] 87 | 88 | voxel_grid = sample.get_voxel_grid() if self.return_ev else None 89 | if voxel_grid is not None and self.normalize_voxel_grid is not None: 90 | voxel_grid = self.normalize_voxel_grid(voxel_grid) 91 | 92 | gt_flow_dict = sample.get_flow_gt(self.delta_ts_flow_ms) 93 | gt_flow = gt_flow_dict['flow'] 94 | gt_flow_ts = gt_flow_dict['timestamps'] 95 | 96 | imgs_with_ts = sample.get_images() 97 | imgs = imgs_with_ts['images'] 98 | img_ts = imgs_with_ts['timestamps'] 99 | 100 | # normalize image timestamps from 0 to 1 101 | assert len(img_ts) == 2 102 | ts_start = img_ts[0] 103 | ts_end = img_ts[1] 104 | assert ts_end > ts_start 105 | img_ts = [(x - ts_start)/(ts_end - ts_start) for x in img_ts] 106 | assert img_ts[0] == 0 107 | assert img_ts[1] == 1 108 | 109 | # we assume that img_ts[0] refers to reference time and img_ts[1] to the final target time 110 | gt_flow_ts = [(x - ts_start)/(ts_end - ts_start) for x in gt_flow_ts] 111 | assert gt_flow_ts[-1] == 1 112 | assert len(gt_flow_ts) == len(gt_flow) 113 | 114 | if self.spatial_augmentor is not None: 115 | if voxel_grid is None: 116 | gt_flow, imgs = self.spatial_augmentor(flow=gt_flow, images=imgs) 117 | else: 118 | voxel_grid, gt_flow, imgs = self.spatial_augmentor(ev_repr=voxel_grid, flow=gt_flow, images=imgs) 119 | if self.photo_augmentor is not None: 120 | imgs = self.photo_augmentor(imgs) 121 | out = { 122 | DataLoading.BIN_META: { 123 | 'bin_idx_for_reference': self._voxel_grid_bin_idx_for_reference(), 124 | 'nbins_context': self.get_num_bins_context(), 125 | 'nbins_correlation': 
self.get_num_bins_correlation(), 126 | 'nbins_total': self.get_num_bins_total(), 127 | }, 128 | DataLoading.FLOW: gt_flow, 129 | DataLoading.FLOW_TIMESTAMPS: gt_flow_ts, 130 | DataLoading.IMG: imgs, 131 | DataLoading.IMG_TIMESTAMPS: img_ts, 132 | DataLoading.DATASET_TYPE: DataSetType.MULTIFLOW2D, 133 | } 134 | if voxel_grid is not None: 135 | out.update({DataLoading.EV_REPR: voxel_grid}) 136 | 137 | return out 138 | -------------------------------------------------------------------------------- /data/multiflow2d/provider.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from pathlib import Path 3 | from typing import Dict, Any 4 | 5 | import torch.utils.data 6 | 7 | from data.utils.provider import DatasetProviderBase 8 | from data.multiflow2d.datasubset import Datasubset 9 | 10 | 11 | class DatasetProvider(DatasetProviderBase): 12 | def __init__(self, 13 | dataset_params: Dict[str, Any], 14 | nbins_context: int): 15 | dataset_path = Path(dataset_params['path']) 16 | train_path = dataset_path / 'train' 17 | val_path = dataset_path / 'val' 18 | assert dataset_path.is_dir(), str(dataset_path) 19 | assert train_path.is_dir(), str(train_path) 20 | assert val_path.is_dir(), str(val_path) 21 | 22 | return_img = True 23 | return_img_key = 'return_img' 24 | if return_img_key in dataset_params: 25 | return_img = dataset_params[return_img_key] 26 | return_ev = True 27 | return_ev_key = 'return_ev' 28 | if return_ev_key in dataset_params: 29 | return_ev = dataset_params[return_ev_key] 30 | base_args = { 31 | 'num_bins_context': nbins_context, 32 | 'load_voxel_grid': dataset_params['load_voxel_grid'], 33 | 'normalize_voxel_grid': dataset_params['normalize_voxel_grid'], 34 | 'extended_voxel_grid': dataset_params['extended_voxel_grid'], 35 | 'flow_every_n_ms': dataset_params['flow_every_n_ms'], 36 | 'downsample': dataset_params['downsample'], 37 | 'photo_augm': dataset_params['photo_augm'], 38 | return_img_key: return_img, 39 | return_ev_key: return_ev, 40 | } 41 | train_args = copy.deepcopy(base_args) 42 | train_args.update({'data_augm': True}) 43 | val_test_args = copy.deepcopy(base_args) 44 | val_test_args.update({'data_augm': False}) 45 | 46 | train_dataset = Datasubset(train_path, **train_args) 47 | self.nbins_context = train_dataset.get_num_bins_context() 48 | self.nbins_correlation = train_dataset.get_num_bins_correlation() 49 | 50 | self.train_dataset = train_dataset 51 | self.val_dataset = Datasubset(val_path, **val_test_args) 52 | assert self.val_dataset.get_num_bins_context() == self.nbins_context 53 | assert self.val_dataset.get_num_bins_correlation() == self.nbins_correlation 54 | 55 | def get_train_dataset(self): 56 | return self.train_dataset 57 | 58 | def get_val_dataset(self): 59 | return self.val_dataset 60 | 61 | def get_test_dataset(self): 62 | raise NotImplementedError 63 | 64 | def get_nbins_context(self): 65 | return self.nbins_context 66 | 67 | def get_nbins_correlation(self): 68 | return self.nbins_correlation 69 | -------------------------------------------------------------------------------- /data/multiflow2d/sample.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Dict 3 | 4 | import h5py 5 | import imageio as iio 6 | import numpy as np 7 | import torch 8 | from torch.nn.functional import interpolate 9 | 10 | from data.utils.generic import np_array_to_h5, h5_to_np_array 11 | from data.utils.representations import VoxelGrid 12 | 13 | 14 
| class Sample: 15 | # Assumes the following structure: 16 | # seq* 17 | # ├── events 18 | # │ └── events.h5 19 | # ├── flow 20 | # │ ├── 0500000.h5 21 | # │ ├── ... 22 | # │ └── 0900000.h5 23 | # └── images 24 | # ├── 0400000.png 25 | # ├── ... 26 | # └── 0900000.png 27 | 28 | def __init__(self, 29 | sample_path: Path, 30 | height: int, 31 | width: int, 32 | num_bins_context: int, 33 | load_voxel_grid: bool=True, 34 | extended_voxel_grid: bool=True, 35 | downsample: bool=False, 36 | ) -> None: 37 | assert sample_path.is_dir() 38 | assert num_bins_context >= 1 39 | assert sample_path.is_dir() 40 | 41 | nbins_context2corr = { 42 | 6: 4, 43 | 11: 7, 44 | 21: 13, 45 | 41: 25, 46 | } 47 | nbins_context2deltatime = { 48 | 6: 100000, 49 | 11: 50000, 50 | 21: 25000, 51 | 41: 12500, 52 | } 53 | 54 | ### To downsample or to not downsample 55 | self.downsample = downsample 56 | 57 | ### Voxel Grid 58 | assert num_bins_context in nbins_context2corr.keys() 59 | self.num_bins_context = num_bins_context 60 | self.num_bins_correlation = nbins_context2corr[num_bins_context] 61 | # We subtract one because the bin at the reference time is redundant 62 | self.num_bins_total = self.num_bins_context + self.num_bins_correlation - 1 63 | 64 | self.voxel_grid = VoxelGrid(self.num_bins_total, height, width) 65 | 66 | # Image data 67 | ref_time_us = 400*1000 68 | target_time_us = 900*1000 69 | img_ref_path = sample_path / 'images' / (f'{ref_time_us}'.zfill(7) + '.png') 70 | assert img_ref_path.exists() 71 | img_target_path = sample_path / 'images' / (f'{target_time_us}'.zfill(7) + '.png') 72 | assert img_target_path.exists() 73 | self.img_filepaths = [img_ref_path, img_target_path] 74 | self.img_ts = [int(x.stem) for x in self.img_filepaths] 75 | 76 | # Extract timestamps for later retrieving event data 77 | self.bin_0_time = self.img_ts[0] - (self.num_bins_correlation - 1)*nbins_context2deltatime[num_bins_context] 78 | assert self.bin_0_time >= 0 79 | self.bin_target_time = self.img_ts[1] 80 | 81 | # Flow data 82 | self.flow_ref_ts_us = ref_time_us 83 | flow_dir = sample_path / 'flow' 84 | assert flow_dir.is_dir() 85 | flow_filepaths = list() 86 | for flow_file in flow_dir.iterdir(): 87 | assert flow_file.suffix == '.h5' 88 | flow_filepaths.append(flow_file) 89 | flow_filepaths.sort() 90 | self.flow_filepaths = flow_filepaths 91 | self.flow_ts_us = [int(x.stem) for x in self.flow_filepaths] 92 | 93 | # Event data 94 | ev_dir = sample_path / 'events' 95 | assert ev_dir.is_dir() 96 | self.event_filepath = ev_dir / 'events.h5' 97 | assert self.event_filepath.exists() 98 | 99 | ### Voxel Grid Saving 100 | self.version = 1 if extended_voxel_grid else 0 101 | downsample_str = '_downsampled' if self.downsample else '' 102 | self.voxel_grid_file = ev_dir / f'voxel_grid_v{self.version}_{self.num_bins_total}_bins{downsample_str}.h5' 103 | self.load_voxel_grid_from_disk = load_voxel_grid 104 | 105 | def downsample_tensor(self, input_tensor: torch.Tensor): 106 | assert input_tensor.ndim == 3 107 | assert self.downsample 108 | ch, ht, wd = input_tensor.shape 109 | input_tensor = input_tensor.float() 110 | return interpolate(input_tensor[None, ...], size=(ht//2, wd//2), align_corners=True, mode='bilinear').squeeze() 111 | 112 | def get_flow_gt(self, flow_every_n_ms: int): 113 | assert flow_every_n_ms > 0 114 | assert flow_every_n_ms % 10 == 0, 'must be a multiple of 10' 115 | delta_ts_us = flow_every_n_ms*1000 116 | out = { 117 | 'flow': list(), 118 | 'timestamps': list(), 119 | } 120 | for flow_ts_us, flow_filepath in 
zip(self.flow_ts_us, self.flow_filepaths): 121 | if (flow_ts_us - self.flow_ref_ts_us) % delta_ts_us != 0: 122 | continue 123 | out['timestamps'].append(flow_ts_us) 124 | with h5py.File(str(flow_filepath), 'r') as h5f: 125 | flow = np.asarray(h5f['flow']) 126 | # h, w, c -> c, h, w 127 | flow = np.moveaxis(flow, -1, 0) 128 | flow = torch.from_numpy(flow) 129 | if self.downsample: 130 | flow = self.downsample_tensor(flow) 131 | flow = flow/2 132 | out['flow'].append(flow) 133 | return out 134 | 135 | def get_images(self): 136 | out = { 137 | 'images': list(), 138 | 'timestamps': self.img_ts, 139 | } 140 | for img_path in self.img_filepaths: 141 | img = np.asarray(iio.imread(str(img_path), format='PNG-FI')) 142 | # img: (h, w, c) -> (c, h, w) 143 | img = np.moveaxis(img, -1, 0) 144 | img = torch.from_numpy(img) 145 | if self.downsample: 146 | img = self.downsample_tensor(img) 147 | out['images'].append(img) 148 | return out 149 | 150 | def _get_events(self, t_start: int, t_end: int): 151 | assert t_start >= 0 152 | assert t_end <= 1000000 153 | assert t_end > t_start 154 | with h5py.File(str(self.event_filepath), 'r') as h5f: 155 | time = np.asarray(h5f['t']) 156 | first_idx = np.searchsorted(time, t_start, side='left') 157 | last_idx_p1 = np.searchsorted(time, t_end, side='right') 158 | out = { 159 | 'x': np.asarray(h5f['x'][first_idx:last_idx_p1]), 160 | 'y': np.asarray(h5f['y'][first_idx:last_idx_p1]), 161 | 'p': np.asarray(h5f['p'][first_idx:last_idx_p1]), 162 | 't': time[first_idx:last_idx_p1], 163 | } 164 | return out 165 | 166 | def _events_to_voxel_grid(self, event_dict: Dict[str, np.ndarray], t0_center: int=None, t1_center: int=None) -> torch.Tensor: 167 | # int32 is enough for this dataset as the timestamps are in us and start at 0 for every sample sequence 168 | t = event_dict['t'].astype('int32') 169 | x = event_dict['x'].astype('int16') 170 | y = event_dict['y'].astype('int16') 171 | pol = event_dict['p'].astype('int8') 172 | return self.voxel_grid.convert( 173 | torch.from_numpy(x), 174 | torch.from_numpy(y), 175 | torch.from_numpy(pol), 176 | torch.from_numpy(t), 177 | t0_center, 178 | t1_center) 179 | 180 | def _construct_voxel_grid(self, ts_from: int, ts_to: int): 181 | if self.version == 1: 182 | t_start, t_end = self.voxel_grid.get_extended_time_window(ts_from, ts_to) 183 | assert (ts_from - t_start) <= 100000, f'ts_from: {ts_from}, t_start: {t_start}' 184 | assert (t_end - ts_to) <= 100000, f't_end: {t_end}, ts_to: {ts_to}' 185 | event_data = self._get_events(t_start, t_end) 186 | voxel_grid = self._events_to_voxel_grid(event_data, ts_from, ts_to) 187 | elif self.version == 0: 188 | event_data = self._get_events(ts_from, ts_to) 189 | voxel_grid = self._events_to_voxel_grid(event_data) 190 | else: 191 | raise NotImplementedError 192 | if self.downsample: 193 | voxel_grid = self.downsample_tensor(voxel_grid) 194 | return voxel_grid 195 | 196 | def _load_or_save_voxel_grid(self, ts_from: int, ts_to: int) -> torch.Tensor: 197 | if self.voxel_grid_file.exists(): 198 | voxel_grid_numpy = h5_to_np_array(self.voxel_grid_file) 199 | if voxel_grid_numpy is not None: 200 | # Squeeze because we may have saved it with batch dim 1 before (by mistake) 201 | return torch.from_numpy(voxel_grid_numpy).squeeze() 202 | # If None is returned, it means that the file was corrupt and we overwrite it. 
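        # Cache miss or corrupt file: rebuild the voxel grid from the raw events and write it
        # back to disk (compressed H5) so that later epochs can load it directly.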
203 | voxel_grid = self._construct_voxel_grid(ts_from, ts_to) 204 | np_array_to_h5(voxel_grid.numpy(), self.voxel_grid_file) 205 | return voxel_grid 206 | 207 | def get_voxel_grid(self) -> torch.Tensor: 208 | ts_from = self.bin_0_time 209 | ts_to = self.bin_target_time 210 | if self.load_voxel_grid_from_disk: 211 | return self._load_or_save_voxel_grid(ts_from, ts_to).squeeze() 212 | return self._construct_voxel_grid(ts_from, ts_to).squeeze() 213 | 214 | def voxel_grid_bin_idx_for_reference(self) -> int: 215 | return self.num_bins_correlation - 1 -------------------------------------------------------------------------------- /data/utils/augmentor.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union, Any, Tuple 2 | 3 | import numpy as np 4 | import skimage 5 | from skimage import img_as_ubyte 6 | import torch 7 | from torchvision.transforms import ColorJitter 8 | 9 | from utils.general import inputs_to_tensor, wrap_unwrap_lists_for_class_method 10 | 11 | UAT = Union[np.ndarray, torch.Tensor] 12 | ULISTT = Union[List[torch.Tensor], torch.Tensor] 13 | ULISTAT = Union[List[UAT], UAT] 14 | 15 | def torch_img_to_numpy(torch_img: torch.Tensor): 16 | ch, ht, wd = torch_img.shape 17 | assert ch == 3 18 | numpy_img = torch_img.numpy() 19 | numpy_img = np.moveaxis(numpy_img, 0, -1) 20 | return numpy_img 21 | 22 | def numpy_img_to_torch(numpy_img: np.ndarray): 23 | ht, wd, ch = numpy_img.shape 24 | assert ch == 3 25 | numpy_img = np.moveaxis(numpy_img, -1, 0) 26 | torch_img = torch.from_numpy(numpy_img) 27 | return torch_img 28 | 29 | class PhotoAugmentor: 30 | def __init__(self, 31 | brightness: float, 32 | contrast: float, 33 | saturation: float, 34 | hue: float, 35 | probability_color: float, 36 | noise_variance_range: Tuple[float, float], 37 | probability_noise: float): 38 | assert 0 <= probability_color <= 1 39 | assert 0 <= probability_noise <= 1 40 | assert len(noise_variance_range) == 2 41 | self.photo_augm = ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue) 42 | self.probability_color = probability_color 43 | self.probability_noise = probability_noise 44 | self.var_min = noise_variance_range[0] 45 | self.var_max = noise_variance_range[1] 46 | assert self.var_max > self.var_min 47 | 48 | self.seed = torch.randint(low=0, high=2**32, size=(1,))[0].item() 49 | 50 | @staticmethod 51 | def sample_uniform(min_value: float=0, max_value: float=1) -> float: 52 | assert max_value > min_value 53 | uni_sample = torch.rand(1)[0].item() 54 | return (max_value - min_value)*uni_sample + min_value 55 | 56 | @wrap_unwrap_lists_for_class_method 57 | def _apply_jitter(self, images: ULISTT): 58 | assert isinstance(images, list) 59 | 60 | for idx, entry in enumerate(images): 61 | images[idx] = self.photo_augm(entry) 62 | 63 | return images 64 | 65 | @wrap_unwrap_lists_for_class_method 66 | def _apply_noise(self, images: ULISTT): 67 | assert isinstance(images, list) 68 | variance = self.sample_uniform(min_value=0.001, max_value=0.01) 69 | 70 | for idx, entry in enumerate(images): 71 | assert isinstance(entry, torch.Tensor) 72 | numpy_img = torch_img_to_numpy(entry) 73 | noisy_img = skimage.util.random_noise(numpy_img, mode='speckle', var=variance, clip=True, seed=self.seed) # return float64 in [0, 1] 74 | noisy_img = img_as_ubyte(noisy_img) 75 | torch_img = numpy_img_to_torch(noisy_img) 76 | images[idx] = torch_img 77 | 78 | return images 79 | 80 | @inputs_to_tensor 81 | def __call__(self, images: 
ULISTAT): 82 | if self.probability_color > torch.rand(1).item(): 83 | images = self._apply_jitter(images) 84 | if self.probability_noise > torch.rand(1).item(): 85 | images = self._apply_noise(images) 86 | return images 87 | 88 | class FlowAugmentor: 89 | def __init__(self, 90 | crop_size_hw, 91 | h_flip_prob: float=0.5, 92 | v_flip_prob: float=0.1): 93 | assert crop_size_hw[0] > 0 94 | assert crop_size_hw[1] > 0 95 | assert 0 <= h_flip_prob <= 1 96 | assert 0 <= v_flip_prob <= 1 97 | 98 | self.h_flip_prob = h_flip_prob 99 | self.v_flip_prob = v_flip_prob 100 | self.crop_size_hw = crop_size_hw 101 | 102 | @wrap_unwrap_lists_for_class_method 103 | def _random_cropping(self, ev_repr: Optional[ULISTT], flow: Optional[ULISTT], valid: Optional[ULISTT]=None, images: Optional[ULISTT]=None): 104 | if ev_repr is not None: 105 | assert isinstance(ev_repr, list) 106 | height, width = ev_repr[0].shape[-2:] 107 | elif images is not None: 108 | assert isinstance(images, list) 109 | height, width = images[0].shape[-2:] 110 | else: 111 | raise NotImplementedError 112 | 113 | y0 = torch.randint(0, height - self.crop_size_hw[0], (1,)).item() 114 | x0 = torch.randint(0, width - self.crop_size_hw[1], (1,)).item() 115 | 116 | if ev_repr is not None: 117 | assert isinstance(ev_repr, list) 118 | nbins = ev_repr[0].shape[-3] 119 | for idx, entry in enumerate(ev_repr): 120 | assert entry.shape[-3:] == (nbins, height, width), f'actual (nbins, h, w) = ({entry.shape[-3]}, {entry.shape[-2]}, {entry.shape[-1]}), expected (nbins, h, w) = ({nbins}, {height}, {width})' 121 | # NOTE: Elements of a range-based for loop do not directly modify the original list in Python! 122 | ev_repr[idx] = entry[..., y0:y0+self.crop_size_hw[0], x0:x0+self.crop_size_hw[1]] 123 | 124 | if flow is not None: 125 | assert isinstance(flow, list) 126 | for idx, entry in enumerate(flow): 127 | assert entry.shape[-3:] == (2, height, width), f'actual (c, h, w) = ({entry.shape[-3]}, {entry.shape[-2]}, {entry.shape[-1]}), expected (2, h, w) = (2, {height}, {width})' 128 | flow[idx] = entry[..., y0:y0+self.crop_size_hw[0], x0:x0+self.crop_size_hw[1]] 129 | 130 | if valid is not None: 131 | assert isinstance(valid, list) 132 | for idx, entry in enumerate(valid): 133 | assert entry.shape[-2:] == (height, width), f'actual (h, w) = ({entry.shape[-2]}, {entry.shape[-1]}), expected (h, w) = ({height}, {width})' 134 | valid[idx] = entry[y0:y0+self.crop_size_hw[0], x0:x0+self.crop_size_hw[1]] 135 | 136 | if images is not None: 137 | assert isinstance(images, list) 138 | for idx, entry in enumerate(images): 139 | assert entry.shape[-2:] == (height, width), f'actual (h, w) = ({entry.shape[-2]}, {entry.shape[-1]}), expected (h, w) = ({height}, {width})' 140 | images[idx] = entry[..., y0:y0+self.crop_size_hw[0], x0:x0+self.crop_size_hw[1]] 141 | 142 | return ev_repr, flow, valid, images 143 | 144 | @wrap_unwrap_lists_for_class_method 145 | def _horizontal_flipping(self, ev_repr: Optional[ULISTT], flow: Optional[ULISTT], valid: Optional[ULISTT]=None, images: Optional[ULISTT]=None): 146 | # flip last axis which is assumed to be width 147 | if ev_repr is not None: 148 | assert isinstance(ev_repr, list) 149 | ev_repr = [x.flip(-1) for x in ev_repr] 150 | if images is not None: 151 | assert isinstance(images, list) 152 | images = [x.flip(-1) for x in images] 153 | if valid is not None: 154 | assert isinstance(valid, list) 155 | valid = [x.flip(-1) for x in valid] 156 | if flow is not None: 157 | assert isinstance(flow, list) 158 | flow = [x.flip(-1) for x in flow] 159 
| # also flip the sign of the x component of the flow 160 | for idx, entry in enumerate(flow): 161 | flow[idx][0] = -1 * entry[0] 162 | 163 | return ev_repr, flow, valid, images 164 | 165 | @wrap_unwrap_lists_for_class_method 166 | def _vertical_flipping(self, ev_repr: Optional[ULISTT], flow: Optional[ULISTT], valid: Optional[ULISTT]=None, images: Optional[ULISTT]=None): 167 | # flip second last axis which is assumed to be height 168 | if ev_repr is not None: 169 | assert isinstance(ev_repr, list) 170 | ev_repr = [x.flip(-2) for x in ev_repr] 171 | if images is not None: 172 | assert isinstance(images, list) 173 | images = [x.flip(-2) for x in images] 174 | if valid is not None: 175 | assert isinstance(valid, list) 176 | valid = [x.flip(-2) for x in valid] 177 | if flow is not None: 178 | assert isinstance(flow, list) 179 | flow = [x.flip(-2) for x in flow] 180 | # also flip the sign of the y component of the flow 181 | for idx, entry in enumerate(flow): 182 | flow[idx][1] = -1 * entry[1] 183 | 184 | return ev_repr, flow, valid, images 185 | 186 | @inputs_to_tensor 187 | def __call__(self, 188 | ev_repr: Optional[ULISTAT]=None, 189 | flow: Optional[ULISTAT]=None, 190 | valid: Optional[ULISTAT]=None, 191 | images: Optional[ULISTAT]=None): 192 | 193 | if self.h_flip_prob > torch.rand(1).item(): 194 | ev_repr, flow, valid, images = self._horizontal_flipping(ev_repr, flow, valid, images) 195 | 196 | if self.v_flip_prob > torch.rand(1).item(): 197 | ev_repr, flow, valid, images = self._vertical_flipping(ev_repr, flow, valid, images) 198 | 199 | ev_repr, flow, valid, images = self._random_cropping(ev_repr, flow, valid, images) 200 | 201 | out = (ev_repr, flow, valid, images) 202 | out = (x for x in out if x is not None) 203 | return out 204 | -------------------------------------------------------------------------------- /data/utils/generic.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Union 3 | 4 | import cv2 5 | import h5py 6 | import numpy as np 7 | 8 | 9 | def _flow_16bit_to_float(flow_16bit: np.ndarray): 10 | assert flow_16bit.dtype == np.uint16, flow_16bit.dtype 11 | assert flow_16bit.ndim == 3 12 | h, w, c = flow_16bit.shape 13 | assert c == 3 14 | 15 | valid2D = flow_16bit[..., 2] == 1 16 | assert valid2D.shape == (h, w) 17 | assert np.all(flow_16bit[~valid2D, -1] == 0) 18 | valid_map = np.where(valid2D) 19 | flow_16bit = flow_16bit.astype('float32') 20 | flow_map = np.zeros((h, w, 2), dtype='float32') 21 | flow_map[valid_map[0], valid_map[1], 0] = (flow_16bit[valid_map[0], valid_map[1], 0] - 2 ** 15) / 128 22 | flow_map[valid_map[0], valid_map[1], 1] = (flow_16bit[valid_map[0], valid_map[1], 1] - 2 ** 15) / 128 23 | return flow_map, valid2D 24 | 25 | 26 | def load_flow(flowfile: Path): 27 | assert flowfile.exists() 28 | assert flowfile.suffix == '.png' 29 | # flow_16bit = np.array(Image.open(str(flowfile))) 30 | flow_16bit = cv2.imread(str(flowfile), cv2.IMREAD_ANYDEPTH | cv2.IMREAD_ANYCOLOR) 31 | flow, valid2D = _flow_16bit_to_float(flow_16bit) 32 | return flow, valid2D 33 | 34 | 35 | def _blosc_opts(complevel=1, complib='blosc:zstd', shuffle='byte'): 36 | shuffle = 2 if shuffle == 'bit' else 1 if shuffle == 'byte' else 0 37 | compressors = ['blosclz', 'lz4', 'lz4hc', 'snappy', 'zlib', 'zstd'] 38 | complib = ['blosc:' + c for c in compressors].index(complib) 39 | args = { 40 | 'compression': 32001, 41 | 'compression_opts': (0, 0, 0, 0, complevel, shuffle, complib), 42 | } 43 | if shuffle > 0: 44 
| # Do not use h5py shuffle if blosc shuffle is enabled. 45 | args['shuffle'] = False 46 | return args 47 | 48 | 49 | def np_array_to_h5(array: np.ndarray, outpath: Path) -> None: 50 | isinstance(array, np.ndarray) 51 | assert outpath.suffix == '.h5' 52 | 53 | with h5py.File(str(outpath), 'w') as h5f: 54 | h5f.create_dataset('voxel_grid', data=array, shape=array.shape, dtype=array.dtype, 55 | **_blosc_opts(complevel=1, shuffle='byte')) 56 | 57 | 58 | def h5_to_np_array(inpath: Path) -> Union[np.ndarray, None]: 59 | assert inpath.suffix == '.h5' 60 | assert inpath.exists() 61 | 62 | try: 63 | with h5py.File(str(inpath), 'r') as h5f: 64 | array = np.asarray(h5f['voxel_grid']) 65 | return array 66 | except OSError as e: 67 | print(f'Error loading {inpath}') 68 | return None 69 | -------------------------------------------------------------------------------- /data/utils/keys.py: -------------------------------------------------------------------------------- 1 | from enum import Enum, auto, IntEnum 2 | 3 | class DataSetType(IntEnum): 4 | DSEC = auto() 5 | MULTIFLOW2D = auto() 6 | 7 | class DataLoading(Enum): 8 | FLOW = auto() 9 | FLOW_TIMESTAMPS = auto() 10 | FLOW_VALID = auto() 11 | FILE_INDEX = auto() 12 | EV_REPR = auto() 13 | BIN_META = auto() 14 | IMG = auto() 15 | IMG_TIMESTAMPS = auto() 16 | DATASET_TYPE = auto() -------------------------------------------------------------------------------- /data/utils/provider.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class DatasetProviderBase(abc.ABC): 5 | @abc.abstractmethod 6 | def get_train_dataset(self): 7 | raise NotImplementedError 8 | 9 | @abc.abstractmethod 10 | def get_val_dataset(self): 11 | raise NotImplementedError 12 | 13 | @abc.abstractmethod 14 | def get_test_dataset(self): 15 | raise NotImplementedError 16 | 17 | @abc.abstractmethod 18 | def get_nbins_context(self): 19 | raise NotImplementedError 20 | 21 | @abc.abstractmethod 22 | def get_nbins_correlation(self): 23 | raise NotImplementedError -------------------------------------------------------------------------------- /data/utils/representations.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional 3 | 4 | import torch 5 | torch.set_num_threads(1) # intraop parallelism (this can be a good option) 6 | torch.set_num_interop_threads(1) # interop parallelism 7 | 8 | 9 | def norm_voxel_grid(voxel_grid: torch.Tensor): 10 | mask = torch.nonzero(voxel_grid, as_tuple=True) 11 | if mask[0].size()[0] > 0: 12 | mean = voxel_grid[mask].mean() 13 | std = voxel_grid[mask].std() 14 | if std > 0: 15 | voxel_grid[mask] = (voxel_grid[mask] - mean) / std 16 | else: 17 | voxel_grid[mask] = voxel_grid[mask] - mean 18 | return voxel_grid 19 | 20 | 21 | class EventRepresentation: 22 | def convert(self, x: torch.Tensor, y: torch.Tensor, pol: torch.Tensor, time: torch.Tensor, t_from: Optional[int]=None, t_to: Optional[int]=None): 23 | raise NotImplementedError 24 | 25 | 26 | class VoxelGrid(EventRepresentation): 27 | def __init__(self, channels: int, height: int, width: int): 28 | assert channels > 1 29 | assert height > 1 30 | assert width > 1 31 | self.nb_channels = channels 32 | self.height = height 33 | self.width = width 34 | 35 | def get_extended_time_window(self, t0_center: int, t1_center: int): 36 | dt = self._get_dt(t0_center, t1_center) 37 | t_start = math.floor(t0_center - dt) 38 | t_end = math.ceil(t1_center + dt) 39 | return t_start, t_end 40 | 
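    # NOTE on get_extended_time_window above (illustrative numbers, not tied to a dataset):
    # with t0_center = 100_000 us, t1_center = 900_000 us and 17 channels, _get_dt yields
    # dt = 800_000 / 16 = 50_000 us, so the returned window is (50_000, 950_000).
    # Events in this slightly larger window still receive non-zero interpolation weights
    # for the first and last bins in convert().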
41 | def _construct_empty_voxel_grid(self): 42 | return torch.zeros( 43 | (self.nb_channels, self.height, self.width), 44 | dtype=torch.float, 45 | requires_grad=False, 46 | device=torch.device('cpu')) 47 | 48 | def _get_dt(self, t0_center: int, t1_center: int): 49 | assert t1_center > t0_center 50 | return (t1_center - t0_center)/(self.nb_channels - 1) 51 | 52 | def _normalize_time(self, time: torch.Tensor, t0_center: int, t1_center: int): 53 | # time_norm < t0_center will be negative 54 | # time_norm == t0_center is 0 55 | # time_norm > t0_center is positive 56 | # time_norm == t1_center is (nb_channels - 1) 57 | # time_norm > t1_center is greater than (nb_channels - 1) 58 | return (time - t0_center)/(t1_center - t0_center)*(self.nb_channels - 1) 59 | 60 | @staticmethod 61 | def _is_int_tensor(tensor: torch.Tensor) -> bool: 62 | return not torch.is_floating_point(tensor) and not torch.is_complex(tensor) 63 | 64 | def convert(self, x: torch.Tensor, y: torch.Tensor, pol: torch.Tensor, time: torch.Tensor, t0_center: Optional[int]=None, t1_center: Optional[int]=None): 65 | assert x.device == y.device == pol.device == time.device == torch.device('cpu') 66 | assert type(t0_center) == type(t1_center) 67 | assert x.shape == y.shape == pol.shape == time.shape 68 | assert x.ndim == 1 69 | assert self._is_int_tensor(time) 70 | 71 | is_int_xy = self._is_int_tensor(x) 72 | if is_int_xy: 73 | assert self._is_int_tensor(y) 74 | 75 | voxel_grid = self._construct_empty_voxel_grid() 76 | ch, ht, wd = self.nb_channels, self.height, self.width 77 | with torch.no_grad(): 78 | t0_center = t0_center if t0_center is not None else time[0] 79 | t1_center = t1_center if t1_center is not None else time[-1] 80 | t_norm = self._normalize_time(time, t0_center, t1_center) 81 | 82 | t0 = t_norm.floor().int() 83 | value = 2*pol.float()-1 84 | 85 | if is_int_xy: 86 | for tlim in [t0,t0+1]: 87 | mask = (tlim >= 0) & (tlim < ch) 88 | interp_weights = value * (1 - (tlim - t_norm).abs()) 89 | 90 | index = ht * wd * tlim.long() + \ 91 | wd * y.long() + \ 92 | x.long() 93 | 94 | voxel_grid.put_(index[mask], interp_weights[mask], accumulate=True) 95 | else: 96 | x0 = x.floor().int() 97 | y0 = y.floor().int() 98 | for xlim in [x0,x0+1]: 99 | for ylim in [y0,y0+1]: 100 | for tlim in [t0,t0+1]: 101 | 102 | mask = (xlim < wd) & (xlim >= 0) & (ylim < ht) & (ylim >= 0) & (tlim >= 0) & (tlim < ch) 103 | interp_weights = value * (1 - (xlim-x).abs()) * (1 - (ylim-y).abs()) * (1 - (tlim - t_norm).abs()) 104 | 105 | index = ht * wd * tlim.long() + \ 106 | wd * ylim.long() + \ 107 | xlim.long() 108 | 109 | voxel_grid.put_(index[mask], interp_weights[mask], accumulate=True) 110 | 111 | return voxel_grid 112 | -------------------------------------------------------------------------------- /loggers/wandb_logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is a modified version of the Pytorch Lightning logger 3 | """ 4 | 5 | import time 6 | from argparse import Namespace 7 | from pathlib import Path 8 | from typing import Any, Dict, List, Optional, Union 9 | from weakref import ReferenceType 10 | 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | import wandb 15 | from lightning_fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict, _sanitize_callable_params 16 | from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint 17 | from pytorch_lightning.loggers.logger import rank_zero_experiment, Logger 18 | from 
pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn 19 | from wandb.sdk.lib import RunDisabled 20 | from wandb.wandb_run import Run 21 | 22 | 23 | class WandbLogger(Logger): 24 | LOGGER_JOIN_CHAR = "-" 25 | STEP_METRIC = "trainer/global_step" 26 | 27 | def __init__( 28 | self, 29 | name: Optional[str] = None, 30 | project: Optional[str] = None, 31 | group: Optional[str] = None, 32 | wandb_id: Optional[str] = None, 33 | prefix: Optional[str] = "", 34 | log_model: Optional[bool] = True, 35 | save_last_only_final: Optional[bool] = False, 36 | config_args: Optional[Dict[str, Any]] = None, 37 | **kwargs, 38 | ): 39 | super().__init__() 40 | self._experiment = None 41 | self._log_model = log_model 42 | self._prefix = prefix 43 | self._logged_model_time = {} 44 | self._checkpoint_callback = None 45 | # Save last is determined by the checkpoint callback argument 46 | self._save_last = None 47 | # Whether to save the last checkpoint continuously (more storage) or only when the run is aborted 48 | self._save_last_only_final = save_last_only_final 49 | # Save the configuration args (e.g. parsed arguments) and log it in wandb 50 | self._config_args = config_args 51 | # set wandb init arguments 52 | self._wandb_init = dict( 53 | name=name, 54 | project=project, 55 | group=group, 56 | id=wandb_id, 57 | resume="allow", 58 | save_code=True, 59 | ) 60 | self._wandb_init.update(**kwargs) 61 | # extract parameters 62 | self._name = self._wandb_init.get("name") 63 | self._id = self._wandb_init.get("id") 64 | # for save_top_k 65 | self._public_run = None 66 | 67 | # start wandb run (to create an attach_id for distributed modes) 68 | wandb.require("service") 69 | _ = self.experiment 70 | 71 | def get_checkpoint(self, artifact_name: str, artifact_filepath: Optional[Path] = None) -> Path: 72 | artifact = self.experiment.use_artifact(artifact_name) 73 | if artifact_filepath is None: 74 | assert artifact is not None, 'You are probably using DDP, ' \ 75 | 'in which case you should provide an artifact filepath.' 76 | # TODO: specify download directory 77 | artifact_dir = artifact.download() 78 | artifact_filepath = next(Path(artifact_dir).iterdir()) 79 | assert artifact_filepath.exists() 80 | assert artifact_filepath.suffix == '.ckpt' 81 | return artifact_filepath 82 | 83 | def __getstate__(self) -> Dict[str, Any]: 84 | state = self.__dict__.copy() 85 | # args needed to reload correct experiment 86 | if self._experiment is not None: 87 | state["_id"] = getattr(self._experiment, "id", None) 88 | state["_attach_id"] = getattr(self._experiment, "_attach_id", None) 89 | state["_name"] = self._experiment.name 90 | 91 | # cannot be pickled 92 | state["_experiment"] = None 93 | return state 94 | 95 | @property 96 | @rank_zero_experiment 97 | def experiment(self) -> Run: 98 | if self._experiment is None: 99 | attach_id = getattr(self, "_attach_id", None) 100 | if wandb.run is not None: 101 | # wandb process already created in this instance 102 | rank_zero_warn( 103 | "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse" 104 | " this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`." 
105 | ) 106 | self._experiment = wandb.run 107 | elif attach_id is not None and hasattr(wandb, "_attach"): 108 | # attach to wandb process referenced 109 | self._experiment = wandb._attach(attach_id) 110 | else: 111 | # create new wandb process 112 | self._experiment = wandb.init(**self._wandb_init) 113 | if self._config_args is not None: 114 | self._experiment.config.update(self._config_args, allow_val_change=True) 115 | 116 | # define default x-axis 117 | if isinstance(self._experiment, (Run, RunDisabled)) and getattr( 118 | self._experiment, "define_metric", None 119 | ): 120 | self._experiment.define_metric(self.STEP_METRIC) 121 | self._experiment.define_metric("*", step_metric=self.STEP_METRIC, step_sync=True) 122 | 123 | assert isinstance(self._experiment, (Run, RunDisabled)) 124 | return self._experiment 125 | 126 | def watch(self, model: nn.Module, log: str = 'all', log_freq: int = 100, log_graph: bool = True): 127 | self.experiment.watch(model, log=log, log_freq=log_freq, log_graph=log_graph) 128 | 129 | def add_step_metric(self, input_dict: dict, step: int) -> None: 130 | input_dict.update({self.STEP_METRIC: step}) 131 | 132 | @rank_zero_only 133 | def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: 134 | params = _convert_params(params) 135 | params = _flatten_dict(params) 136 | params = _sanitize_callable_params(params) 137 | self.experiment.config.update(params, allow_val_change=True) 138 | 139 | @rank_zero_only 140 | def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None: 141 | assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" 142 | 143 | metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) 144 | if step is not None: 145 | self.add_step_metric(metrics, step) 146 | self.experiment.log({**metrics}, step=step) 147 | else: 148 | self.experiment.log(metrics) 149 | 150 | @rank_zero_only 151 | def log_images(self, key: str, images: List[Any], step: Optional[int] = None, **kwargs: str) -> None: 152 | """Log images (tensors, numpy arrays, PIL Images or file paths). 153 | Optional kwargs are lists passed to each image (ex: caption, masks, boxes). 
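        Illustrative call (argument values assumed): logger.log_images(key='val/predictions', images=[img0, img1], caption=['sample 0', 'sample 1'])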
154 | 155 | How to use: https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.loggers.wandb.html#weights-and-biases-logger 156 | Taken from: https://github.com/PyTorchLightning/pytorch-lightning/blob/11e289ad9f95f5fe23af147fa4edcc9794f9b9a7/pytorch_lightning/loggers/wandb.py#L420 157 | """ 158 | if not isinstance(images, list): 159 | raise TypeError(f'Expected a list as "images", found {type(images)}') 160 | n = len(images) 161 | for k, v in kwargs.items(): 162 | if len(v) != n: 163 | raise ValueError(f"Expected {n} items but only found {len(v)} for {k}") 164 | kwarg_list = [{k: kwargs[k][i] for k in kwargs.keys()} for i in range(n)] 165 | metrics = {key: [wandb.Image(img, **kwarg) for img, kwarg in zip(images, kwarg_list)]} 166 | self.log_metrics(metrics, step) 167 | 168 | @rank_zero_only 169 | def log_videos(self, 170 | key: str, 171 | videos: List[Union[np.ndarray, str]], 172 | step: Optional[int] = None, 173 | captions: Optional[List[str]] = None, 174 | fps: int = 4, 175 | format_: Optional[str] = None): 176 | """ 177 | :param video: List[(T,C,H,W)] or List[(N,T,C,H,W)] 178 | :param captions: List[str] or None 179 | 180 | More info: https://docs.wandb.ai/ref/python/data-types/video and 181 | https://docs.wandb.ai/guides/track/log/media#other-media 182 | """ 183 | assert isinstance(videos, list) 184 | if captions is not None: 185 | assert isinstance(captions, list) 186 | assert len(captions) == len(videos) 187 | wandb_videos = list() 188 | for idx, video in enumerate(videos): 189 | caption = captions[idx] if captions is not None else None 190 | wandb_videos.append(wandb.Video(data_or_path=video, caption=caption, fps=fps, format=format_)) 191 | self.log_metrics(metrics={key: wandb_videos}, step=step) 192 | 193 | @property 194 | def name(self) -> Optional[str]: 195 | # This function seems to be only relevant if LoggerCollection is used. 196 | # don't create an experiment if we don't have one 197 | return self._experiment.project_name() if self._experiment else self._name 198 | 199 | @property 200 | def version(self) -> Optional[str]: 201 | # This function seems to be only relevant if LoggerCollection is used. 
202 | # don't create an experiment if we don't have one 203 | return self._experiment.id if self._experiment else self._id 204 | 205 | @rank_zero_only 206 | def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None: 207 | # log checkpoints as artifacts 208 | if self._checkpoint_callback is None: 209 | self._checkpoint_callback = checkpoint_callback 210 | self._save_last = checkpoint_callback.save_last 211 | if self._log_model: 212 | self._scan_and_log_checkpoints(checkpoint_callback, self._save_last and not self._save_last_only_final) 213 | 214 | @rank_zero_only 215 | def finalize(self, status: str) -> None: 216 | # log checkpoints as artifacts 217 | if self._checkpoint_callback and self._log_model: 218 | self._scan_and_log_checkpoints(self._checkpoint_callback, self._save_last) 219 | 220 | def _get_public_run(self): 221 | if self._public_run is None: 222 | experiment = self.experiment 223 | runpath = experiment._entity + '/' + experiment._project + '/' + experiment._run_id 224 | api = wandb.Api() 225 | self._public_run = api.run(path=runpath) 226 | return self._public_run 227 | 228 | def _num_logged_artifact(self): 229 | public_run = self._get_public_run() 230 | return len(public_run.logged_artifacts()) 231 | 232 | def _scan_and_log_checkpoints(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]", save_last: bool) -> None: 233 | assert self._log_model 234 | if self._checkpoint_callback is None: 235 | self._checkpoint_callback = checkpoint_callback 236 | self._save_last = checkpoint_callback.save_last 237 | 238 | checkpoints = { 239 | checkpoint_callback.best_model_path: checkpoint_callback.best_model_score, 240 | **checkpoint_callback.best_k_models, 241 | } 242 | assert len(checkpoints) <= max(checkpoint_callback.save_top_k, 0) 243 | 244 | if save_last: 245 | last_model_path = Path(checkpoint_callback.last_model_path) 246 | if last_model_path.exists(): 247 | checkpoints.update({checkpoint_callback.last_model_path: checkpoint_callback.current_score}) 248 | else: 249 | print(f'last model checkpoint not found at {checkpoint_callback.last_model_path}') 250 | 251 | checkpoints = sorted( 252 | ((Path(path).stat().st_mtime, path, score) for path, score in checkpoints.items() if Path(path).is_file()), 253 | key=lambda x: x[0]) 254 | # Retain only checkpoints that we have not logged before with one exception: 255 | # If the name is the same (e.g. 
last checkpoint which should be overwritten), 256 | # make sure that they are newer than the previously saved checkpoint by checking their modification time 257 | checkpoints = [ckpt for ckpt in checkpoints if 258 | ckpt[1] not in self._logged_model_time.keys() or self._logged_model_time[ckpt[1]] < ckpt[0]] 259 | # remove checkpoints with undefined (None) score 260 | checkpoints = [x for x in checkpoints if x[2] is not None] 261 | 262 | num_ckpt_logged_before = self._num_logged_artifact() 263 | num_new_cktps = len(checkpoints) 264 | 265 | if num_new_cktps == 0: 266 | return 267 | 268 | # log iteratively all new checkpoints 269 | for time_, path, score in checkpoints: 270 | score = score.item() if isinstance(score, torch.Tensor) else score 271 | is_best = path == checkpoint_callback.best_model_path 272 | is_last = path == checkpoint_callback.last_model_path 273 | metadata = ({ 274 | "score": score, 275 | "original_filename": Path(path).name, 276 | "ModelCheckpoint": { 277 | k: getattr(checkpoint_callback, k) 278 | for k in [ 279 | "monitor", 280 | "mode", 281 | "save_last", 282 | "save_top_k", 283 | "save_weights_only", 284 | ] 285 | # ensure it does not break if `ModelCheckpoint` args change 286 | if hasattr(checkpoint_callback, k) 287 | }, 288 | } 289 | ) 290 | aliases = [] 291 | if is_best: 292 | aliases.append('best') 293 | if is_last: 294 | aliases.append('last') 295 | artifact_name = f'checkpoint-{self.experiment.id}-' + ('last' if is_last else 'topK') 296 | artifact = wandb.Artifact(name=artifact_name, type='model', metadata=metadata) 297 | assert Path(path).exists() 298 | artifact.add_file(path, name=f'{self.experiment.id}.ckpt') 299 | self.experiment.log_artifact(artifact, aliases=aliases) 300 | # remember logged model - timestamp needed in case filename didn't change (last.ckpt or custom name) 301 | self._logged_model_time[path] = time_ 302 | 303 | timeout = 20 304 | time_spent = 0 305 | while self._num_logged_artifact() < num_ckpt_logged_before + num_new_cktps: 306 | time.sleep(1) 307 | time_spent += 1 308 | if time_spent >= timeout: 309 | rank_zero_warn("Timeout: Num logged artifacts never reached expected value.") 310 | print(f'self._num_logged_artifact() = {self._num_logged_artifact()}') 311 | print(f'num_ckpt_logged_before = {num_ckpt_logged_before}') 312 | print(f'num_new_cktps = {num_new_cktps}') 313 | break 314 | try: 315 | self._rm_but_top_k(checkpoint_callback.save_top_k) 316 | except KeyError: 317 | pass 318 | 319 | def _rm_but_top_k(self, top_k: int): 320 | # top_k == -1: save all model 321 | # top_k == 0: no model saved at all. The checkpoint callback does not return checkpoints. 
322 | # top_k > 0: keep only top k model (last and best will not be deleted) 323 | def is_last(artifact): 324 | return 'last' in artifact.aliases 325 | 326 | def is_best(artifact): 327 | return 'best' in artifact.aliases 328 | 329 | def try_delete(artifact): 330 | try: 331 | artifact.delete(delete_aliases=True) 332 | except wandb.errors.CommError: 333 | print(f'Failed to delete artifact {artifact.name} due to wandb.errors.CommError') 334 | 335 | public_run = self._get_public_run() 336 | 337 | score2art = list() 338 | for artifact in public_run.logged_artifacts(): 339 | score = artifact.metadata['score'] 340 | original_filename = artifact.metadata['original_filename'] 341 | if score == 'Infinity': 342 | print( 343 | f'removing INF artifact (name, score, original_filename): ({artifact.name}, {score}, {original_filename})') 344 | try_delete(artifact) 345 | continue 346 | if score is None: 347 | print( 348 | f'removing None artifact (name, score, original_filename): ({artifact.name}, {score}, {original_filename})') 349 | try_delete(artifact) 350 | continue 351 | score2art.append((score, artifact)) 352 | 353 | # From high score to low score 354 | score2art.sort(key=lambda x: x[0], reverse=True) 355 | 356 | count = 0 357 | for score, artifact in score2art: 358 | original_filename = artifact.metadata['original_filename'] 359 | if 'last' in original_filename and not is_last(artifact): 360 | try_delete(artifact) 361 | continue 362 | if is_last(artifact): 363 | continue 364 | count += 1 365 | if is_best(artifact): 366 | continue 367 | # if top_k == -1, we do not delete anything 368 | if 0 <= top_k < count: 369 | try_delete(artifact) -------------------------------------------------------------------------------- /models/raft_spline/bezier.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import math 4 | from typing import Union, List 5 | 6 | import numpy as np 7 | import torch as th 8 | from numba import jit 9 | 10 | has_scipy_special = True 11 | try: 12 | from scipy import special 13 | except ImportError: 14 | has_scipy_special = False 15 | 16 | from models.raft_utils.utils import cvx_upsample 17 | 18 | class BezierCurves: 19 | # Each ctrl point lives in R^2 20 | CTRL_DIM: int = 2 21 | 22 | def __init__(self, bezier_params: th.Tensor): 23 | # bezier_params: batch, ctrl_dim*(n_ctrl_pts - 1), height, width 24 | assert bezier_params.ndim == 4 25 | self._params = bezier_params 26 | 27 | # some helpful meta-data: 28 | self.batch, channels, self.ht, self.wd = self._params.shape 29 | assert channels % 2 == 0 30 | # P0 is always zeros as it corresponds to the pixel locations. 31 | # Consequently, we only compute P1, P2, ... 
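        # For intuition (hypothetical numbers, not taken from any config): a cubic curve has four
        # control points P0..P3; since P0 is implicit, only P1..P3 are stored, i.e.
        # channels = 2 * 3 = 6, and the line below recovers n_ctrl_pts = 6 // 2 + 1 = 4.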
32 | self.n_ctrl_pts = channels // self.CTRL_DIM + 1 33 | assert self.n_ctrl_pts > 0 34 | 35 | # math.comb is only available in python 3.8 or higher 36 | self.use_math_comb = hasattr(math, 'comb') 37 | if not self.use_math_comb: 38 | assert has_scipy_special 39 | assert hasattr(special, 'comb') 40 | 41 | def comb(self, n: int, k:int): 42 | if self.use_math_comb: 43 | return math.comb(n, k) 44 | return special.comb(n, k) 45 | 46 | @classmethod 47 | def create_from_specification(cls, batch_size: int, n_ctrl_pts: int, height: int, width: int, device: th.device) -> BezierCurves: 48 | assert batch_size > 0 49 | assert n_ctrl_pts > 1 50 | assert height > 0 51 | assert width > 0 52 | params = th.zeros(batch_size, cls.CTRL_DIM * (n_ctrl_pts - 1), height, width, device=device) 53 | return cls(params) 54 | 55 | @classmethod 56 | def from_2view(cls, flow_tensor: th.Tensor) -> BezierCurves: 57 | # This function has been written to visualize 2-view predictions for our paper. 58 | batch_size, channel_size, height, width = flow_tensor.shape 59 | assert channel_size == 2 == cls.CTRL_DIM 60 | return cls(flow_tensor) 61 | 62 | @classmethod 63 | def create_from_voxel_grid(cls, voxel_grid: th.Tensor, downsample_factor: int=8, bezier_degree: int=2) -> BezierCurves: 64 | assert isinstance(downsample_factor, int) 65 | assert downsample_factor >= 1 66 | batch, _, ht, wd = voxel_grid.shape 67 | assert ht % 8 == 0 68 | assert wd % 8 == 0 69 | ht, wd = ht//downsample_factor, wd//downsample_factor 70 | n_ctrl_pts = bezier_degree + 1 71 | return cls.create_from_specification(batch_size=batch, n_ctrl_pts=n_ctrl_pts, height=ht, width=wd, device=voxel_grid.device) 72 | 73 | @property 74 | def device(self): 75 | return self._params.device 76 | 77 | @property 78 | def dtype(self): 79 | return self._params.dtype 80 | 81 | def create_upsampled(self, mask: th.Tensor) -> BezierCurves: 82 | """ Upsample params [N, dim, H/8, W/8] -> [N, dim, H, W] using convex combination """ 83 | up_params = cvx_upsample(self._params, mask) 84 | return BezierCurves(up_params) 85 | 86 | def detach(self, clone: bool=False, cpu: bool=False) -> BezierCurves: 87 | params = self._params.detach() 88 | if cpu: 89 | return BezierCurves(params.cpu()) 90 | if clone: 91 | params = params.clone() 92 | return BezierCurves(params) 93 | 94 | def detach_(self, cpu: bool=False) -> None: 95 | # Detaches the bezier parameters in-place! 96 | self._params = self._params.detach() 97 | if cpu: 98 | self._params = self._params.cpu() 99 | 100 | def cpu(self) -> BezierCurves: 101 | return BezierCurves(self._params.cpu()) 102 | 103 | def cpu_(self) -> None: 104 | # Puts the bezier parameters to CPU in-place! 
105 | self._params = self._params.cpu() 106 | 107 | @property 108 | def requires_grad(self): 109 | return self._params.requires_grad 110 | 111 | @property 112 | def batch_size(self): 113 | return self._params.shape[0] 114 | 115 | @property 116 | def degree(self): 117 | return self.n_ctrl_pts - 1 118 | 119 | @property 120 | def dim(self): 121 | return self._params.shape[1] 122 | 123 | @property 124 | def height(self): 125 | return self._params.shape[-2] 126 | 127 | @property 128 | def width(self): 129 | return self._params.shape[-1] 130 | 131 | def get_params(self) -> th.Tensor: 132 | return self._params 133 | 134 | def _param_view(self) -> th.Tensor: 135 | return self._params.view(self.batch, self.CTRL_DIM, self.degree, self.ht, self.wd) 136 | 137 | def delta_update_params(self, delta_bezier: th.Tensor) -> None: 138 | assert delta_bezier.shape == self._params.shape 139 | self._params = self._params + delta_bezier 140 | 141 | @staticmethod 142 | def _get_binom_coeffs(degree: int): 143 | n = degree 144 | k = np.arange(degree) + 1 145 | return special.binom(n, k) 146 | 147 | @staticmethod 148 | @jit(nopython=True) 149 | def _get_time_coeffs(timestamps: np.ndarray, degree: int): 150 | assert timestamps.min() >= 0 151 | assert timestamps.max() <= 1 152 | assert timestamps.ndim == 1 153 | # I would like to check ensure float64 dtype but have not found a way to check in jit 154 | #assert timestamps.dtype == np.dtype('float64') 155 | 156 | num_ts = timestamps.size 157 | out = np.zeros((num_ts, degree)) 158 | for t_idx in range(num_ts): 159 | for d_idx in range(degree): 160 | time = timestamps[t_idx] 161 | i = d_idx + 1 162 | out[t_idx, d_idx] = (1 - time)**(degree - i)*time**i 163 | return out 164 | 165 | def _compute_flow_from_timestamps(self, timestamps: Union[List[float], np.ndarray]): 166 | if isinstance(timestamps, list): 167 | timestamps = np.asarray(timestamps) 168 | else: 169 | assert isinstance(timestamps, np.ndarray) 170 | assert timestamps.dtype == 'float64' 171 | assert timestamps.size > 0 172 | assert np.min(timestamps) >= 0 173 | assert np.max(timestamps) <= 1 174 | 175 | degree = self.degree 176 | binom_coeffs = self._get_binom_coeffs(degree) 177 | time_coeffs = self._get_time_coeffs(timestamps, degree) 178 | # poly coeffs: time, degree 179 | polynomial_coeffs = np.einsum('j,ij->ij', binom_coeffs, time_coeffs) 180 | polynomial_coeffs = th.from_numpy(polynomial_coeffs).float().to(device=self.device) 181 | 182 | # params: batch, dim, degree, height, width 183 | params = self._param_view() 184 | # flow: timestamps, batch, dim, height, width 185 | flow = th.einsum('bdphw,tp->tbdhw', params, polynomial_coeffs) 186 | return flow 187 | 188 | def get_flow_from_reference(self, time: Union[float, int, List[float], np.ndarray]) -> th.Tensor: 189 | params = self._param_view() 190 | batch, dim, degree, height, width = params.shape 191 | time_is_scalar = isinstance(time, int) or isinstance(time, float) 192 | if time_is_scalar: 193 | assert time >= 0.0 194 | assert time <= 1.0 195 | if time == 1: 196 | P_end = params[:, :, -1, ...] 
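            # With P0 = 0, the Bernstein expansion gives B(1) = P_n and B(0) = P_0 = 0,
            # so the flow at time 1 is exactly the last control point; no evaluation is needed.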
197 | return P_end 198 | if time == 0: 199 | return th.zeros((batch, dim, height, width), dtype=self.dtype, device=self.device) 200 | time = np.array([time], dtype='float64') 201 | elif isinstance(time, list): 202 | time = np.asarray(time, dtype='float64') 203 | else: 204 | assert isinstance(time, np.ndarray) 205 | assert time.dtype == 'float64' 206 | assert time.size > 0 207 | assert np.min(time) >= 0 208 | assert np.max(time) <= 1 209 | 210 | # flow is coords1 - coords0 211 | # flows: timestamps, batch, dim, height, width 212 | flows = self._compute_flow_from_timestamps(timestamps=time) 213 | if time_is_scalar: 214 | assert flows.shape[0] == 1 215 | return flows[0] 216 | return flows -------------------------------------------------------------------------------- /models/raft_spline/raft.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any, Optional, List 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from models.raft_spline.bezier import BezierCurves 7 | from models.raft_spline.update import BasicUpdateBlock 8 | from models.raft_utils.extractor import BasicEncoder 9 | from models.raft_utils.corr import CorrComputation, CorrBlockParallelMultiTarget 10 | from models.raft_utils.utils import coords_grid 11 | from utils.timers import CudaTimerDummy as CudaTimer 12 | 13 | 14 | class RAFTSpline(nn.Module): 15 | def __init__(self, model_params: Dict[str, Any]): 16 | super().__init__() 17 | nbins_context = model_params['num_bins']['context'] 18 | nbins_correlation = model_params['num_bins']['correlation'] 19 | self.bezier_degree = model_params['bezier_degree'] 20 | self.detach_bezier = model_params['detach_bezier'] 21 | print(f'Detach Bezier curves: {self.detach_bezier}') 22 | 23 | assert nbins_correlation > 0 and nbins_context > 0 24 | assert self.bezier_degree >= 1 25 | self.nbins_context = nbins_context 26 | self.nbins_corr = nbins_correlation 27 | 28 | print('RAFT-Spline config:') 29 | print(f'Num bins context: {nbins_context}') 30 | print(f'Num bins correlation: {nbins_correlation}') 31 | 32 | corr_params = model_params['correlation'] 33 | self.corr_use_cosine_sim = corr_params['use_cosine_sim'] 34 | 35 | ev_corr_params = corr_params['ev'] 36 | self.ev_corr_target_indices = ev_corr_params['target_indices'] 37 | self.ev_corr_levels = ev_corr_params['levels'] 38 | # TODO: fix this in the config 39 | #self.ev_corr_radius = ev_corr_params['radius'] 40 | self.ev_corr_radius = 4 41 | 42 | self.img_corr_params = None 43 | if model_params['use_boundary_images']: 44 | print('Using images') 45 | self.img_corr_params = corr_params['img'] 46 | assert 'levels' in self.img_corr_params 47 | assert 'radius' in self.img_corr_params 48 | 49 | self.hidden_dim = hdim = model_params['hidden']['dim'] 50 | self.context_dim = cdim = model_params['context']['dim'] 51 | cnorm = model_params['context']['norm'] 52 | feature_dim = model_params['feature']['dim'] 53 | fnorm = model_params['feature']['norm'] 54 | 55 | # feature network, context network, and update block 56 | context_dim = 0 57 | self.fnet_img = None 58 | if self.img_corr_params is not None: 59 | self.fnet_img = BasicEncoder(input_dim=3, output_dim=feature_dim, norm_fn=fnorm) 60 | context_dim += 3 61 | self.fnet_ev = None 62 | if model_params['use_events']: 63 | print('Using events') 64 | assert 0 not in self.ev_corr_target_indices 65 | assert len(self.ev_corr_target_indices) > 0 66 | assert max(self.ev_corr_target_indices) < self.nbins_context 67 | assert 
len(self.ev_corr_target_indices) == len(self.ev_corr_levels) 68 | self.fnet_ev = BasicEncoder(input_dim=nbins_correlation, output_dim=feature_dim, norm_fn=fnorm) 69 | context_dim += nbins_context 70 | assert self.fnet_ev is not None or self.fnet_img is not None 71 | self.cnet = BasicEncoder(input_dim=context_dim, output_dim=hdim + cdim, norm_fn=cnorm) 72 | 73 | self.update_block = BasicUpdateBlock(model_params, hidden_dim=hdim) 74 | 75 | def freeze_bn(self): 76 | for m in self.modules(): 77 | if isinstance(m, nn.BatchNorm2d): 78 | m.eval() 79 | 80 | def initialize_flow(self, input_): 81 | N, _, H, W = input_.shape 82 | # batch, 2, ht, wd 83 | downsample_factor = 8 84 | coords0 = coords_grid(N, H//downsample_factor, W//downsample_factor, device=input_.device) 85 | bezier = BezierCurves.create_from_voxel_grid(input_, downsample_factor=downsample_factor, bezier_degree=self.bezier_degree) 86 | return coords0, bezier 87 | 88 | def gen_voxel_grids(self, input_: torch.Tensor): 89 | # input_: N, nbins_context + nbins_corr - 1 , H, W 90 | assert self.nbins_context + self.nbins_corr - 1 == input_.shape[-3] 91 | corr_grids = list() 92 | # We need to add the reference index (which is 0). 93 | indices_with_reference = [0] 94 | indices_with_reference.extend(self.ev_corr_target_indices) 95 | for idx in indices_with_reference: 96 | slice_ = input_[:, idx:idx+self.nbins_corr, ...] 97 | corr_grids.append(slice_) 98 | context_grid = input_[:, -self.nbins_context:, ...] 99 | return corr_grids, context_grid 100 | 101 | def forward(self, 102 | voxel_grid: Optional[torch.Tensor]=None, 103 | images: Optional[List[torch.Tensor]]=None, 104 | iters: int=12, 105 | flow_init: Optional[BezierCurves]=None, 106 | test_mode: bool=False): 107 | assert voxel_grid is not None or images is not None 108 | assert iters > 0 109 | 110 | hdim = self.hidden_dim 111 | cdim = self.context_dim 112 | current_device = voxel_grid.device if voxel_grid is not None else images[0].device 113 | 114 | corr_computation_events = None 115 | context_input = None 116 | with CudaTimer(current_device, 'fnet_ev'): 117 | if self.fnet_ev is not None: 118 | assert voxel_grid is not None 119 | voxel_grid = voxel_grid.contiguous() 120 | corr_grids, context_input = self.gen_voxel_grids(voxel_grid) 121 | fmaps_ev = self.fnet_ev(corr_grids) 122 | fmaps_ev = [x.float() for x in fmaps_ev] 123 | fmap1_ev = fmaps_ev[0] 124 | fmap2_ev = torch.stack(fmaps_ev[1:], dim=0) 125 | corr_computation_events = CorrComputation(fmap1_ev, fmap2_ev, num_levels_per_target=self.ev_corr_levels) 126 | 127 | corr_computation_frames = None 128 | with CudaTimer(current_device, 'fnet_img'): 129 | if self.fnet_img is not None: 130 | assert self.img_corr_params is not None 131 | assert len(images) == 2 132 | # images[0]: at reference time 133 | # images[1]: at target time 134 | images = [2 * (x.float().contiguous() / 255) - 1 for x in images] 135 | fmaps_img = self.fnet_img(images) 136 | corr_computation_frames = CorrComputation(fmaps_img[0], fmaps_img[1], num_levels_per_target=self.img_corr_params['levels']) 137 | if context_input is not None: 138 | context_input = torch.cat((context_input, images[0]), dim=-3) 139 | else: 140 | context_input = images[0] 141 | assert context_input is not None 142 | 143 | with CudaTimer(current_device, 'cnet'): 144 | cnet = self.cnet(context_input) 145 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 146 | net = torch.tanh(net) 147 | inp = torch.relu(inp) 148 | 149 | # (batch, 2, ht, wd), ... 
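        # coords0 is the reference pixel grid at 1/8 resolution; the Bezier curve starts with
        # all-zero control points, i.e. the initial estimate is zero flow for every timestamp.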
150 | coords0, bezier = self.initialize_flow(context_input) 151 | 152 | if flow_init is not None: 153 | bezier.delta_update_params(flow_init.get_params()) 154 | 155 | bezier_up_predictions = [] 156 | dt = 1/(self.nbins_context - 1) 157 | 158 | with CudaTimer(current_device, 'corr computation'): 159 | corr_block = CorrBlockParallelMultiTarget( 160 | corr_computation_events=corr_computation_events, 161 | corr_computation_frames=corr_computation_frames) 162 | with CudaTimer(current_device, 'all iters'): 163 | for itr in range(iters): 164 | # NOTE: original RAFT detaches the flow (bezier) here from the graph. 165 | # Our experiments with bezier curves indicate that detaching is lowering the validation EPE by up to 5% on DSEC. 166 | with CudaTimer(current_device, '1 iter'): 167 | if self.detach_bezier: 168 | bezier.detach_() 169 | 170 | lookup_timestamps = list() 171 | if corr_computation_events is not None: 172 | for tindex in self.ev_corr_target_indices: 173 | # 0 < time <= 1 174 | time = dt*tindex 175 | lookup_timestamps.append(time) 176 | if corr_computation_frames is not None: 177 | lookup_timestamps.append(1) 178 | 179 | with CudaTimer(current_device, 'get_flow (per iter)'): 180 | flows = bezier.get_flow_from_reference(time=lookup_timestamps) 181 | coords1 = coords0 + flows 182 | 183 | with CudaTimer(current_device, 'corr lookup (per iter)'): 184 | corr_total = corr_block(coords1) 185 | 186 | with CudaTimer(current_device, 'update (per iter)'): 187 | bezier_params = bezier.get_params() 188 | net, up_mask, delta_bezier = self.update_block(net, inp, corr_total, bezier_params) 189 | 190 | # B(k+1) = B(k) + \Delta(B) 191 | bezier.delta_update_params(delta_bezier=delta_bezier) 192 | 193 | if not test_mode or itr == iters - 1: 194 | bezier_up = bezier.create_upsampled(up_mask) 195 | bezier_up_predictions.append(bezier_up) 196 | 197 | if test_mode: 198 | return bezier, bezier_up 199 | 200 | return bezier_up_predictions 201 | -------------------------------------------------------------------------------- /models/raft_spline/update.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any, Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class BezierHead(nn.Module): 9 | def __init__(self, bezier_degree: int, input_dim=128, hidden_dim=256): 10 | super().__init__() 11 | output_dim = bezier_degree * 2 12 | # TODO: figure out if we need to increase the capacity of this head due to bezier curves 13 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 14 | self.conv2 = nn.Conv2d(hidden_dim, output_dim, 3, padding=1) 15 | self.relu = nn.ReLU(inplace=True) 16 | 17 | def forward(self, x): 18 | return self.conv2(self.relu(self.conv1(x))) 19 | 20 | 21 | class SepConvGRU(nn.Module): 22 | def __init__(self, hidden_dim=128, input_dim=192+128): 23 | super().__init__() 24 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 25 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 26 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 27 | 28 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 29 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 30 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 31 | 32 | 33 | def forward(self, h, x): 34 | # horizontal 35 | hx = torch.cat([h, x], dim=1) 36 | z = 
torch.sigmoid(self.convz1(hx)) 37 | r = torch.sigmoid(self.convr1(hx)) 38 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 39 | h = (1-z) * h + z * q 40 | 41 | # vertical 42 | hx = torch.cat([h, x], dim=1) 43 | z = torch.sigmoid(self.convz2(hx)) 44 | r = torch.sigmoid(self.convr2(hx)) 45 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 46 | h = (1-z) * h + z * q 47 | 48 | return h 49 | 50 | class BasicMotionEncoder(nn.Module): 51 | def __init__(self, model_params: Dict[str, Any], output_dim: int=128): 52 | super().__init__() 53 | cor_planes = self._num_cor_planes(model_params['correlation'], model_params['use_boundary_images'], model_params['use_events']) 54 | 55 | # TODO: Are two layers enough for this? Because the number of input channels grew substantially 56 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 57 | c2_out = 192 58 | self.convc2 = nn.Conv2d(256, c2_out, 3, padding=1) 59 | 60 | # TODO: Consider upgrading the number of channels of the flow encoders (dealing now with bezier curves) 61 | bezier_planes = model_params['bezier_degree'] * 2 62 | self.convf1 = nn.Conv2d(bezier_planes, 128, 7, padding=3) 63 | f2_out = 64 64 | self.convf2 = nn.Conv2d(128, f2_out, 3, padding=1) 65 | 66 | combined_channels = f2_out+c2_out 67 | self.conv = nn.Conv2d(combined_channels, output_dim-bezier_planes, 3, padding=1) 68 | 69 | @staticmethod 70 | def _num_cor_planes(corr_params: Dict[str, Any], use_boundary_images: bool, use_events: bool): 71 | assert use_events or use_boundary_images 72 | out = 0 73 | if use_events: 74 | ev_params = corr_params['ev'] 75 | ev_corr_levels = ev_params['levels'] 76 | ev_corr_radius = ev_params['radius'] 77 | assert len(ev_corr_levels) > 0 78 | assert len(ev_corr_radius) > 0 79 | assert len(ev_corr_levels) == len(ev_corr_radius) 80 | for lvl, rad in zip(ev_corr_levels, ev_corr_radius): 81 | out += lvl * (2*rad + 1)**2 82 | if use_boundary_images: 83 | img_corr_levels = corr_params['img']['levels'] 84 | img_corr_radius = corr_params['img']['radius'] 85 | out += img_corr_levels * (2*img_corr_radius + 1)**2 86 | return out 87 | 88 | def forward(self, bezier, corr): 89 | cor = F.relu(self.convc1(corr)) 90 | cor = F.relu(self.convc2(cor)) 91 | bez = F.relu(self.convf1(bezier)) 92 | bez = F.relu(self.convf2(bez)) 93 | 94 | cor_bez = torch.cat([cor, bez], dim=1) 95 | 96 | out = F.relu(self.conv(cor_bez)) 97 | return torch.cat([out, bezier], dim=1) 98 | 99 | 100 | class BasicUpdateBlock(nn.Module): 101 | def __init__(self, model_params: Dict[str, Any], hidden_dim: int=128): 102 | super().__init__() 103 | motion_encoder_output_dim = model_params['motion']['dim'] 104 | context_dim = model_params['context']['dim'] 105 | bezier_degree = model_params['bezier_degree'] 106 | self.encoder = BasicMotionEncoder(model_params, output_dim=motion_encoder_output_dim) 107 | gru_input_dim = context_dim + motion_encoder_output_dim 108 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=gru_input_dim) 109 | self.bezier_head = BezierHead(bezier_degree, input_dim=hidden_dim, hidden_dim=256) 110 | 111 | self.mask = nn.Sequential( 112 | nn.Conv2d(hidden_dim, 256, 3, padding=1), 113 | nn.ReLU(inplace=True), 114 | nn.Conv2d(256, 64*9, 1, padding=0)) 115 | 116 | def forward(self, net, inp, corr, bezier): 117 | # TODO: check if we can simplify this similar to the DROID-SLAM update block 118 | motion_features = self.encoder(bezier, corr) 119 | inp = torch.cat([inp, motion_features], dim=1) 120 | 121 | net = self.gru(net, inp) 122 | delta_bezier = self.bezier_head(net) 123 | 
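        # The mask head emits 64*9 = 8*8*9 channels: nine convex-combination weights for each
        # position of the 8x8 upsampling window that cvx_upsample consumes later.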
124 | # scale mask to balance gradients 125 | mask = .25 * self.mask(net) 126 | return net, mask, delta_bezier 127 | -------------------------------------------------------------------------------- /models/raft_utils/corr.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Union, List, Tuple, Optional 4 | 5 | import torch 6 | import torch as th 7 | import torch.nn.functional as F 8 | from omegaconf import ListConfig 9 | 10 | from models.raft_utils.utils import bilinear_sampler 11 | 12 | 13 | class CorrData: 14 | def __init__(self, 15 | corr: th.Tensor, 16 | batch_size: int): 17 | assert isinstance(corr, th.Tensor) 18 | num_targets, bhw, dim, ht, wd = corr.shape 19 | assert dim == 1 20 | assert isinstance(batch_size, int) 21 | assert batch_size >= 1 22 | 23 | # corr: num_targets, batch_size*ht*wd, 1, ht, wd 24 | self._corr = corr 25 | self._batch_size = batch_size 26 | self._target_indices = None 27 | 28 | @property 29 | def corr(self) -> th.Tensor: 30 | assert self._corr.ndim == 5 31 | return self._corr 32 | 33 | @property 34 | def corr_batched(self) -> th.Tensor: 35 | num_targets, bhw, dim, ht, wd = self.corr.shape 36 | assert dim == 1 37 | return self.corr.view(-1, dim, ht, wd) 38 | 39 | @property 40 | def batch_size(self) -> int: 41 | return self._batch_size 42 | 43 | @property 44 | def device(self) -> th.device: 45 | return self.corr.device 46 | 47 | @property 48 | def target_indices(self) -> th.Tensor: 49 | assert self._target_indices is not None 50 | return self._target_indices 51 | 52 | @target_indices.setter 53 | def target_indices(self, values: Union[List[int], th.Tensor]): 54 | if isinstance(values, list): 55 | values = sorted(values) 56 | values = th.Tensor(values, device=self.device) 57 | else: 58 | assert values.device == self.device 59 | assert values.dtype == th.int64 60 | assert values.ndim == 1 61 | assert values.numel() > 0 62 | assert values.min() >= 0 63 | assert th.all(values[1:] - values[:-1] > 0) 64 | assert self.corr.shape[0] == values.numel() 65 | self._target_indices = values 66 | 67 | def init_target_indices(self): 68 | num_targets = self.corr.shape[0] 69 | self.target_indices = torch.arange(num_targets, device=self.device) 70 | 71 | @staticmethod 72 | def _extract_database_indices_from_query_matches(query_values: torch.Tensor, database_values: torch.Tensor) -> th.Tensor: 73 | ''' 74 | query_values: set of unique integers (assumed in ascending order) 75 | database_values: set of unique integers (assumed in ascending order) 76 | 77 | This function returns the database indices where query values match the database values. 
78 | Example: 79 | query_values: torch.tensor([4, 5, 9]) 80 | database_values: torch.tensor([1, 4, 5, 6, 9]) 81 | returns: torch.tensor([1, 2, 4]) 82 | ''' 83 | assert database_values.dtype == torch.int64 84 | assert query_values.dtype == torch.int64 85 | assert database_values.ndim == 1 86 | assert query_values.ndim == 1 87 | assert torch.equal(database_values, torch.unique_consecutive(database_values)) 88 | assert torch.equal(query_values, torch.unique_consecutive(query_values)) 89 | device = query_values.device 90 | assert device == database_values.device 91 | 92 | num_db = database_values.numel() 93 | num_q = query_values.numel() 94 | 95 | database_values_expanded = database_values.expand(num_q, -1) 96 | query_values_expanded = query_values.expand(num_db, -1).transpose(0, 1) 97 | 98 | compare = torch.eq(database_values_expanded, query_values_expanded) 99 | 100 | indices = torch.arange(num_db, device=device) 101 | indices = indices.expand(num_q, -1) 102 | assert torch.all(compare.sum(1) == 1), f'compare =\n{compare}' 103 | # If the previous assertion fails it likely is the case that the query values are not a subset of database values. 104 | out = indices[compare] 105 | assert torch.equal(query_values, database_values[out]) 106 | return out 107 | 108 | def get_downsampled(self, target_indices: Union[List[int], th.Tensor]) -> CorrData: 109 | if isinstance(target_indices, list): 110 | target_indices = sorted(target_indices) 111 | target_indices = th.tensor(target_indices, device=self.device) 112 | 113 | indices_to_select = self._extract_database_indices_from_query_matches(target_indices, self.target_indices) 114 | corr_selected = th.index_select(self.corr, dim=0, index=indices_to_select) 115 | 116 | num_new_targets, bhw, dim, ht, wd = corr_selected.shape 117 | assert dim == 1 118 | corr_selected = corr_selected.reshape(-1, dim, ht, wd) 119 | corr_down = F.avg_pool2d(corr_selected, 2, stride=2) 120 | _, _, ht_down, wd_down = corr_down.shape 121 | corr_down = corr_down.view(num_new_targets, bhw, dim, ht_down, wd_down) 122 | 123 | out = CorrData(corr_down, self.batch_size) 124 | out.target_indices = target_indices 125 | return out 126 | 127 | class CorrComputation: 128 | def __init__(self, 129 | fmap1: Union[th.Tensor, List[th.Tensor]], 130 | fmap2: Union[th.Tensor, List[th.Tensor]], 131 | num_levels_per_target: Union[int, List[int], List[th.Tensor]]): 132 | ''' 133 | fmap1: batch, dim, ht, wd OR 134 | List[tensor(batch, dim, ht, wd)] 135 | fmap2: batch, dim, ht, wd OR 136 | num_targets, batch, dim, ht, wd OR 137 | List[tensor(num_targets, batch, dim, ht, wd)] 138 | -> in case of list input, len(fmap1) == len(fmap2) is enforced 139 | num_levels_per_target: int OR 140 | List[int] SUCH THAT len(...) 
== num_targets OR 141 | List[tensor(num_targets)] 142 | ''' 143 | self._has_single_reference = isinstance(fmap1, th.Tensor) 144 | if isinstance(num_levels_per_target, int): 145 | num_levels_per_target = [num_levels_per_target] 146 | else: 147 | assert isinstance(num_levels_per_target, list) or isinstance(num_levels_per_target, ListConfig) 148 | if self._has_single_reference: 149 | assert fmap1.ndim == 4 150 | assert isinstance(fmap2, th.Tensor) 151 | if fmap2.ndim == 4: 152 | # This is the case where we also only have a single target 153 | fmap2 = fmap2.unsqueeze(0) 154 | else: 155 | assert fmap2.ndim == 5 156 | assert fmap1.shape == fmap2.shape[1:] 157 | fmap1 = [fmap1] 158 | fmap2 = [fmap2] 159 | assert len(num_levels_per_target) == fmap2[0].shape[0] 160 | for x in num_levels_per_target: 161 | assert isinstance(x, int) 162 | num_levels_per_target = [th.tensor(num_levels_per_target)] 163 | else: 164 | assert isinstance(fmap1, list) 165 | assert isinstance(fmap2, list) 166 | assert len(fmap1) == len(fmap2) 167 | assert len(num_levels_per_target) == len(fmap2) 168 | for f1, f2, num_lvls in zip(fmap1, fmap2, num_levels_per_target): 169 | assert isinstance(f1, th.Tensor) 170 | assert f1.ndim == 4 171 | assert isinstance(f2, th.Tensor) 172 | assert f2.ndim == 5 173 | assert f1.shape == f2.shape[1:] 174 | assert isinstance(num_lvls, th.Tensor) 175 | assert num_lvls.dtype == th.int64 176 | assert num_lvls.numel() == f2.shape[0] 177 | self._fmap1 = fmap1 178 | self._fmap2 = fmap2 179 | 180 | self._num_targets_per_reference = [x.shape[0] for x in fmap2] 181 | self._num_targets_overall = sum(self._num_targets_per_reference) 182 | assert self._num_targets_overall == th.cat(num_levels_per_target).numel() 183 | self._num_levels_per_target = num_levels_per_target 184 | 185 | self._bdhw = fmap1[0].shape 186 | 187 | @property 188 | def batch(self) -> int: 189 | return self._bdhw[0] 190 | 191 | @property 192 | def dim(self) -> int: 193 | return self._bdhw[1] 194 | 195 | @property 196 | def height(self) -> int: 197 | return self._bdhw[2] 198 | 199 | @property 200 | def width(self) -> int: 201 | return self._bdhw[3] 202 | 203 | @property 204 | def num_references(self) -> int: 205 | return len(self._fmap1) 206 | 207 | @property 208 | def num_levels_per_target(self) -> List[th.Tensor]: 209 | return self._num_levels_per_target 210 | 211 | @property 212 | def num_levels_per_target_merged(self) -> th.Tensor: 213 | return th.cat(self._num_levels_per_target) 214 | 215 | @property 216 | def num_targets_per_reference(self) -> List[int]: 217 | return self._num_targets_per_reference 218 | 219 | @property 220 | def num_targets_overall(self) -> int: 221 | return self._num_targets_overall 222 | 223 | def __add__(self, other: CorrComputation) -> CorrComputation: 224 | fmap1 = self._fmap1 + other._fmap1 225 | fmap2 = self._fmap2 + other._fmap2 226 | num_levels_per_target = self._num_levels_per_target + other._num_levels_per_target 227 | return CorrComputation(fmap1=fmap1, fmap2=fmap2, num_levels_per_target=num_levels_per_target) 228 | 229 | def get_correlation_volume(self) -> CorrData: 230 | # Note: we could also just use the more general N to N code but this should be slightly faster and easier to understand. 
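        # Shape sketch of the 1-to-N path below (B: batch, D: dim, T: number of targets):
        #   fmap1: (B, D, H*W), fmap2: (T, B, D, H*W)
        #   corr = fmap1^T @ fmap2 / sqrt(D)  ->  (T, B, H*W, H*W)
        #   reshaped to (T, B*H*W, 1, H, W), i.e. one full cost map per query pixel.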
231 | if self._has_single_reference: 232 | # case: 1 to many, which includes the 1 to 1 special case 233 | return self._corr_dot_prod_1_to_N() 234 | # case: many to many 235 | return self._corr_dot_prod_M_to_N() 236 | 237 | def _corr_dot_prod_1_to_N(self) -> CorrData: 238 | # 1 to 1 if num_targets_sum is 1, which is a special case of this code. 239 | assert len(self._fmap1) == 1 240 | assert len(self._fmap2) == 1 241 | num_targets_sum = self.num_targets_overall 242 | fmap1 = self._fmap1[0].view(self.batch, self.dim, self.height*self.width) 243 | fmap2 = self._fmap2[0].view(num_targets_sum, self.batch, self.dim, self.height*self.width) 244 | 245 | corr_data = self._corr_dot_prod_util(fmap1, fmap2) 246 | return corr_data 247 | 248 | def _corr_dot_prod_M_to_N(self) -> CorrData: 249 | assert len(self._fmap1) > 1 250 | assert len(self._fmap2) > 1 251 | assert len(self._fmap1) == len(self._fmap2) 252 | 253 | # fmap{1,2}: num_targets_overall, batch, dim, height, width 254 | fmap1 = th.cat([x.expand(num_trgts, -1, -1, -1, -1) for x, num_trgts in zip(self._fmap1, self.num_targets_per_reference)], dim=0) 255 | fmap2 = th.cat(self._fmap2, dim=0) 256 | 257 | num_targets_overall = self.num_targets_overall 258 | fmap1 = fmap1.view(num_targets_overall, self.batch, self.dim, self.height*self.width) 259 | fmap2 = fmap2.view(num_targets_overall, self.batch, self.dim, self.height*self.width) 260 | 261 | corr_data = self._corr_dot_prod_util(fmap1, fmap2) 262 | return corr_data 263 | 264 | def _corr_dot_prod_util(self, fmap1: th.Tensor, fmap2: th.Tensor) -> CorrData: 265 | # corr: num_targets_sum, batch, ht*wd, ht*wd 266 | corr = fmap1.transpose(-1, -2) @ fmap2 267 | corr = corr / th.sqrt(th.tensor(self.dim, device=corr.device).float()) 268 | corr = corr.view(self.num_targets_overall, self.batch * self.height * self.width, 1, self.height, self.width) 269 | 270 | out = CorrData(corr=corr, batch_size=self.batch) 271 | out.init_target_indices() 272 | return out 273 | 274 | 275 | class CorrBlockParallelMultiTarget: 276 | def __init__(self, 277 | corr_computation_events: Optional[CorrComputation]=None, 278 | corr_computation_frames: Optional[CorrComputation]=None, 279 | radius: int=4): 280 | do_events = corr_computation_events is not None 281 | do_frames = corr_computation_frames is not None 282 | assert do_events or do_frames 283 | assert radius >= 1 284 | 285 | if do_events and not do_frames: 286 | corr_computation = corr_computation_events 287 | elif do_frames and not do_events: 288 | corr_computation = corr_computation_frames 289 | else: 290 | assert do_events and do_frames 291 | corr_computation = corr_computation_events + corr_computation_frames 292 | 293 | num_levels_per_target = corr_computation.num_levels_per_target_merged.tolist() 294 | self._num_targets_base = len(num_levels_per_target) 295 | self._radius = radius 296 | 297 | corr_base_data = corr_computation.get_correlation_volume() 298 | 299 | max_num_levels = max(num_levels_per_target) 300 | 301 | self._corr_pyramid = [corr_base_data] 302 | for num_levels in range(2, max_num_levels + 1): 303 | target_idx_list = [idx for idx, val in enumerate(num_levels_per_target) if val >= num_levels] 304 | corr_data = self._corr_pyramid[-1].get_downsampled(target_indices=target_idx_list) 305 | self._corr_pyramid.append(corr_data) 306 | 307 | def __call__(self, coords: Union[th.Tensor, List[th.Tensor], Tuple[th.Tensor]]) -> th.Tensor: 308 | if isinstance(coords, list) or isinstance(coords, tuple): 309 | # num_targets_base, N, 2, H, W 310 | coords = th.stack(coords, 
dim=0) 311 | assert coords.ndim == 5 312 | # num_targets_base, N, H, W, 2 313 | coords = coords.permute(0, 1, 3, 4, 2) 314 | num_targets_base, batch, h1, w1, _ = coords.shape 315 | assert num_targets_base == self._num_targets_base 316 | 317 | r = self._radius 318 | out_pyramid = [] 319 | for idx, corr_data in enumerate(self._corr_pyramid): 320 | target_indices = corr_data.target_indices 321 | assert target_indices.ndim == 1 322 | coords_selected = th.index_select(coords, dim=0, index=target_indices) 323 | num_targets = target_indices.numel() 324 | 325 | dx = th.linspace(-r, r, 2*r+1, device=coords.device) 326 | dy = th.linspace(-r, r, 2*r+1, device=coords.device) 327 | # delta: 2*r+1, 2*r+1, 2 328 | # NOTE: Unlike in the original implementation, we change the order 329 | # such that delta[..., 0] corresponds to x and delta[..., 1] corresponds to y 330 | # In fact, it does not matter since the same targets are looked up and then flattened and fed as channels. 331 | delta = th.stack(th.meshgrid(dy, dx)[::-1], dim=-1) 332 | 333 | centroid_lvl = coords_selected.reshape(num_targets*batch*h1*w1, 1, 1, 2) / 2**idx 334 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) 335 | # coords_lvl: num_targets*batch*h1*w1, 2*r+1, 2*r+1, 2 336 | coords_lvl = centroid_lvl + delta_lvl 337 | 338 | # (reminder) corr_batched: num_targets*batch*h1*w1, 1, h2, w2 339 | # corr_feat: num_targets*batch*h1*w1, 1, 2*r+1, 2*r+1 340 | corr_feat = bilinear_sampler(corr_data.corr_batched, coords_lvl) 341 | # corr_feat: num_targets, batch, h1, w1, (2*r+1)**2 342 | corr_feat = corr_feat.view(num_targets, batch, h1, w1, -1) 343 | out_pyramid.append(corr_feat) 344 | 345 | # out: (num_targets_at_lvl_0 + ... + num_targets_at_lvl_final), batch, h1, w1, (2*r+1)**2 346 | out = th.cat(out_pyramid, dim=0) 347 | # out: batch, (num_targets_at_lvl_0 + ... + num_targets_at_lvl_final), (2*r+1)**2, h1, w1 348 | out = out.permute(1, 0, 4, 2, 3) 349 | # out: batch, (num_targets_at_lvl_0 + ... 
+ num_targets_at_lvl_final)*(2*r+1)**2, h1, w1 350 | out = out.reshape(batch, -1, h1, w1).float() 351 | return out -------------------------------------------------------------------------------- /models/raft_utils/extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ResidualBlock(nn.Module): 6 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 7 | super(ResidualBlock, self).__init__() 8 | 9 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) 10 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 11 | self.relu = nn.ReLU(inplace=True) 12 | 13 | num_groups = planes // 8 14 | 15 | if norm_fn == 'group': 16 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 17 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 18 | if not stride == 1: 19 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 20 | 21 | elif norm_fn == 'batch': 22 | self.norm1 = nn.BatchNorm2d(planes) 23 | self.norm2 = nn.BatchNorm2d(planes) 24 | if not stride == 1: 25 | self.norm3 = nn.BatchNorm2d(planes) 26 | 27 | elif norm_fn == 'instance': 28 | self.norm1 = nn.InstanceNorm2d(planes) 29 | self.norm2 = nn.InstanceNorm2d(planes) 30 | if not stride == 1: 31 | self.norm3 = nn.InstanceNorm2d(planes) 32 | 33 | elif norm_fn == 'none': 34 | self.norm1 = nn.Sequential() 35 | self.norm2 = nn.Sequential() 36 | if not stride == 1: 37 | self.norm3 = nn.Sequential() 38 | 39 | if stride == 1: 40 | self.downsample = None 41 | 42 | else: 43 | self.downsample = nn.Sequential( 44 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 45 | 46 | 47 | def forward(self, x): 48 | y = x 49 | y = self.relu(self.norm1(self.conv1(y))) 50 | y = self.relu(self.norm2(self.conv2(y))) 51 | 52 | if self.downsample is not None: 53 | x = self.downsample(x) 54 | 55 | return self.relu(x+y) 56 | 57 | 58 | class BasicEncoder(nn.Module): 59 | def __init__(self, input_dim=3, output_dim=128, norm_fn='batch'): 60 | super(BasicEncoder, self).__init__() 61 | self.norm_fn = norm_fn 62 | 63 | if self.norm_fn == 'group': 64 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 65 | elif self.norm_fn == 'batch': 66 | self.norm1 = nn.BatchNorm2d(64) 67 | elif self.norm_fn == 'instance': 68 | self.norm1 = nn.InstanceNorm2d(64) 69 | elif self.norm_fn == 'none': 70 | self.norm1 = nn.Sequential() 71 | else: 72 | raise NotImplementedError 73 | 74 | self.conv1 = nn.Conv2d(input_dim, 64, kernel_size=7, stride=2, padding=3) 75 | self.relu1 = nn.ReLU(inplace=True) 76 | 77 | self.in_planes = 64 78 | self.layer1 = self._make_layer(64, stride=1) 79 | self.layer2 = self._make_layer(96, stride=2) 80 | self.layer3 = self._make_layer(128, stride=2) 81 | 82 | # output convolution 83 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) 84 | 85 | for m in self.modules(): 86 | if isinstance(m, nn.Conv2d): 87 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 88 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 89 | if m.weight is not None: 90 | nn.init.constant_(m.weight, 1) 91 | if m.bias is not None: 92 | nn.init.constant_(m.bias, 0) 93 | 94 | def _make_layer(self, dim, stride=1): 95 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) 96 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 97 | layers = (layer1, layer2) 98 | 99 | self.in_planes = dim 100 | return 
nn.Sequential(*layers) 101 | 102 | 103 | def forward(self, x): 104 | 105 | # if input is list, combine batch dimension 106 | is_list = isinstance(x, tuple) or isinstance(x, list) 107 | if is_list: 108 | batch_dim = x[0].shape[0] 109 | length = len(x) 110 | x = torch.cat(x, dim=0) 111 | 112 | x = self.conv1(x) 113 | x = self.norm1(x) 114 | x = self.relu1(x) 115 | 116 | x = self.layer1(x) 117 | x = self.layer2(x) 118 | x = self.layer3(x) 119 | 120 | x = self.conv2(x) 121 | 122 | if is_list: 123 | x = torch.split(x, [batch_dim]*length, dim=0) 124 | 125 | return x -------------------------------------------------------------------------------- /models/raft_utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def bilinear_sampler(img, coords): 6 | """ Wrapper for grid_sample, uses pixel coordinates """ 7 | # img (corr): batch*h1*w1, 1, h2, w2 8 | # coords: batch*h1*w1, 2*r+1, 2*r+1, 2 9 | H, W = img.shape[-2:] 10 | # *grid: batch*h1*w1, 2*r+1, 2*r+1, 1 11 | xgrid, ygrid = coords.split([1,1], dim=-1) 12 | # map *grid from [0, N-1] to [-1, 1] 13 | xgrid = 2*xgrid/(W-1) - 1 14 | ygrid = 2*ygrid/(H-1) - 1 15 | 16 | # grid: batch*h1*w1, 2*r+1, 2*r+1, 2 17 | grid = torch.cat([xgrid, ygrid], dim=-1) 18 | # img: batch*h1*w1, 1, 2*r+1, 2*r+1 19 | img = F.grid_sample(img, grid, align_corners=True) 20 | 21 | return img 22 | 23 | 24 | def coords_grid(batch, ht, wd, device): 25 | # ((ht, wd), (ht, wd)) 26 | coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device)) 27 | # 2, ht, wd 28 | coords = torch.stack(coords[::-1], dim=0).float() 29 | # batch, 2, ht, wd 30 | return coords[None].repeat(batch, 1, 1, 1) 31 | 32 | 33 | def cvx_upsample(data: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: 34 | """ Upsample data [N, dim, H/8, W/8] -> [N, dim, H, W] using convex combination """ 35 | N, dim, H, W = data.shape 36 | mask = mask.view(N, 1, 9, 8, 8, H, W) 37 | mask = torch.softmax(mask, dim=2) 38 | 39 | # NOTE: multiply by 8 due to the change in resolution. 
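    # Flow-like quantities are measured in pixels, so moving from 1/8 to full resolution scales
    # their magnitude by 8. The unfold below gathers the 3x3 coarse neighbourhood that the nine
    # softmax weights above combine into each upsampled value.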
40 | up_data = F.unfold(8 * data, [3, 3], padding=1) 41 | up_data = up_data.view(N, dim, 9, 1, 1, H, W) 42 | 43 | # N, dim, 8, 8, H, W 44 | up_data = torch.sum(mask * up_data, dim=2) 45 | # N, dim, H, 8, W, 8 46 | up_data = up_data.permute(0, 1, 4, 2, 5, 3) 47 | # N, dim, H*8, W*8 48 | return up_data.reshape(N, dim, 8*H, 8*W) 49 | -------------------------------------------------------------------------------- /modules/data_loading.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, Any, Union 3 | 4 | import pytorch_lightning as pl 5 | from omegaconf import DictConfig 6 | from torch.utils.data import DataLoader 7 | 8 | from data.dsec.provider import DatasetProvider as DatasetProviderDSEC 9 | from data.multiflow2d.provider import DatasetProvider as DatasetProviderMULTIFLOW2D 10 | 11 | 12 | class DataModule(pl.LightningDataModule): 13 | DSEC_STR = 'dsec' 14 | MULTIFLOW2D_REGEN_STR = 'multiflow_regen' 15 | 16 | def __init__(self, 17 | config: Union[Dict[str, Any], DictConfig], 18 | batch_size_train: int = 1, 19 | batch_size_val: int = 1): 20 | super().__init__() 21 | dataset_params = config['dataset'] 22 | dataset_type = dataset_params['name'] 23 | num_workers = config['hardware']['num_workers'] 24 | 25 | assert dataset_type in {self.DSEC_STR, self.MULTIFLOW2D_REGEN_STR} 26 | self.dataset_type = dataset_type 27 | 28 | self.batch_size_train = batch_size_train 29 | self.batch_size_val = batch_size_val 30 | 31 | assert self.batch_size_train >= 1 32 | assert self.batch_size_val >= 1 33 | 34 | if num_workers is None: 35 | num_workers = 2*max([batch_size_train, batch_size_val]) 36 | num_workers = min([num_workers, os.cpu_count()]) 37 | print(f'num_workers: {num_workers}') 38 | 39 | self.num_workers = num_workers 40 | assert self.num_workers >= 0 41 | 42 | nbins_context = config['model']['num_bins']['context'] 43 | 44 | if dataset_type == self.DSEC_STR: 45 | dataset_provider = DatasetProviderDSEC(dataset_params, nbins_context) 46 | elif dataset_type == self.MULTIFLOW2D_REGEN_STR: 47 | dataset_provider = DatasetProviderMULTIFLOW2D(dataset_params, nbins_context) 48 | else: 49 | raise NotImplementedError 50 | self.train_dataset = dataset_provider.get_train_dataset() 51 | if dataset_type == self.DSEC_STR: 52 | self.val_dataset = None 53 | self.test_dataset = dataset_provider.get_test_dataset() 54 | else: 55 | self.val_dataset = dataset_provider.get_val_dataset() 56 | self.test_dataset = None 57 | 58 | self.nbins_context = dataset_provider.get_nbins_context() 59 | self.nbins_correlation = dataset_provider.get_nbins_correlation() 60 | 61 | assert self.nbins_context == nbins_context 62 | # Fill in nbins_correlation here because it can depend on the dataset. 
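        # If num_bins.correlation is left unset (None) in the config, it is filled in from the
        # dataset provider below; if it is set explicitly, it must match the provider's value.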
63 | if 'correlation' in config['model']['num_bins']: 64 | nbins_correlation = config['model']['num_bins']['correlation'] 65 | if nbins_correlation is None: 66 | config['model']['num_bins']['correlation'] = self.nbins_correlation 67 | else: 68 | assert nbins_correlation == self.nbins_correlation 69 | 70 | def train_dataloader(self): 71 | return DataLoader( 72 | dataset=self.train_dataset, 73 | batch_size=self.batch_size_train, 74 | shuffle=True, 75 | num_workers=self.num_workers, 76 | pin_memory=True, 77 | drop_last=True) 78 | 79 | def val_dataloader(self): 80 | assert self.val_dataset is not None, f'No validation data found for {self.dataset_type} dataset' 81 | return DataLoader( 82 | dataset=self.val_dataset, 83 | batch_size=self.batch_size_val, 84 | shuffle=False, 85 | num_workers=self.num_workers, 86 | pin_memory=True, 87 | drop_last=True) 88 | 89 | def test_dataloader(self): 90 | assert self.test_dataset is not None, f'No test data found for {self.dataset_type} dataset' 91 | return DataLoader( 92 | dataset=self.test_dataset, 93 | batch_size=1, 94 | shuffle=False, 95 | num_workers=self.num_workers, 96 | pin_memory=True, 97 | drop_last=False) 98 | 99 | def get_nbins_context(self): 100 | return self.nbins_context 101 | 102 | def get_nbins_correlation(self): 103 | return self.nbins_correlation 104 | -------------------------------------------------------------------------------- /modules/utils.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Mapping 2 | from functools import wraps 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | def detach_tensors(cpu: bool=False): 9 | ''' Detach tensors and optionally move to cpu. Only detaches torch.Tensor instances ''' 10 | # Decorator factory to enable decorator arguments: https://stackoverflow.com/a/50538967 11 | def allow_detach(key: str, value, train: bool): 12 | if train and key == 'loss': 13 | return False 14 | return isinstance(value, torch.Tensor) 15 | def detach_tensor(input_tensor: torch.Tensor, cpu: bool): 16 | assert isinstance(input_tensor, torch.Tensor) 17 | if cpu: 18 | return input_tensor.detach().cpu() 19 | return input_tensor.detach() 20 | def decorator(func): 21 | train = 'train' in func.__name__ 22 | 23 | @wraps(func) 24 | def inner(*args, **kwargs): 25 | output = func(*args, **kwargs) 26 | if isinstance(output, Mapping): 27 | return {k: (detach_tensor(v, cpu) if allow_detach(k, v, train) else v) for k, v in output.items()} 28 | assert isinstance(output, torch.Tensor) 29 | if train: 30 | # Do not detach because this will be the loss function of the training hook, which must not be detached. 31 | return output 32 | return detach_tensor(output, cpu) 33 | return inner 34 | return decorator 35 | 36 | 37 | def reduce_ev_repr(ev_repr: torch.Tensor) -> torch.Tensor: 38 | # This function is useful to reduce the overhead of moving an event representation 39 | # to CPU for visualization. 40 | # For now simply sum up the time dimension to reduce the memory. 
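    # For example, an event voxel grid of shape (N, num_bins, H, W) is collapsed along dim=1
    # to (N, H, W), which is much cheaper to transfer to the CPU for visualization.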
41 |     assert isinstance(ev_repr, torch.Tensor)
42 |     assert ev_repr.ndim == 4
43 |     assert ev_repr.is_cuda
44 | 
45 |     return torch.sum(ev_repr, dim=1)
46 | 
47 | 
48 | class InputPadder:
49 |     """ Pads input tensor such that the last two dimensions are divisible by min_size """
50 |     def __init__(self, min_size: int=8, no_top_padding: bool=False):
51 |         assert min_size > 0
52 |         self.min_size = min_size
53 |         self.no_top_padding = no_top_padding
54 |         self._pad = None
55 | 
56 |     def requires_padding(self, input_tensor: torch.Tensor):
57 |         ht, wd = input_tensor.shape[-2:]
58 |         # Padding is required as soon as one spatial dimension is not divisible by min_size.
59 |         answer = ht % self.min_size != 0
60 |         answer |= wd % self.min_size != 0
61 |         return answer
62 | 
63 |     def pad(self, input_tensor: torch.Tensor):
64 |         ht, wd = input_tensor.shape[-2:]
65 |         pad_ht = (((ht // self.min_size) + 1) * self.min_size - ht) % self.min_size
66 |         pad_wd = (((wd // self.min_size) + 1) * self.min_size - wd) % self.min_size
67 |         if self.no_top_padding:
68 |             # Pad only bottom instead of top
69 |             # RAFT uses this for KITTI
70 |             pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
71 |         else:
72 |             # RAFT uses this for SINTEL (as default too)
73 |             pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
74 |         if self._pad is None:
75 |             self._pad = pad
76 |         else:
77 |             assert self._pad == pad
78 |         return F.pad(input_tensor, self._pad, mode='replicate')
79 | 
80 |     def unpad(self, input_tensor: torch.Tensor):
81 |         ht, wd = input_tensor.shape[-2:]
82 |         c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
83 |         return input_tensor[..., c[0]:c[1], c[2]:c[3]]
84 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | 
4 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
5 | os.environ["OMP_NUM_THREADS"] = "1"
6 | os.environ["OPENBLAS_NUM_THREADS"] = "1"
7 | os.environ["MKL_NUM_THREADS"] = "1"
8 | os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
9 | os.environ["NUMEXPR_NUM_THREADS"] = "1"
10 | 
11 | import hydra
12 | from omegaconf import DictConfig, OmegaConf
13 | import pytorch_lightning as pl
14 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelSummary, ModelCheckpoint
15 | from pytorch_lightning.strategies import DDPStrategy
16 | import torch
17 | import wandb
18 | 
19 | from callbacks.logger import WandBImageLoggingCallback
20 | from utils.general import get_ckpt_callback
21 | from loggers.wandb_logger import WandbLogger
22 | from modules.data_loading import DataModule
23 | from modules.raft_spline import RAFTSplineModule
24 | 
25 | 
26 | @hydra.main(config_path='config', config_name='train', version_base='1.3')
27 | def main(cfg: DictConfig):
28 |     print('------ Configuration ------\n')
29 |     print(OmegaConf.to_yaml(cfg))
30 |     print('---------------------------\n')
31 |     config = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
32 |     # ------------
33 |     # Args
34 |     # ------------
35 |     gpu_devices = config['hardware']['gpus']
36 |     gpus = gpu_devices if isinstance(gpu_devices, list) else [gpu_devices]
37 |     num_gpus = len(gpus)
38 | 
39 |     batch_size: int = config['training']['batch_size']
40 |     assert batch_size > 0
41 |     per_gpu_batch_size = batch_size
42 |     if num_gpus == 1:
43 |         strategy = 'auto'
44 |         ddp_active = False
45 |     else:
46 |         strategy = DDPStrategy(process_group_backend='nccl',
47 |                                find_unused_parameters=False,
48 |                                gradient_as_bucket_view=True)
49 |         ddp_active = True
50 |         per_gpu_batch_size = batch_size // num_gpus
51 |         assert_info = 'Batch size ({}) must be divisible by number of gpus ({})'.format(batch_size, num_gpus)
52 |         assert per_gpu_batch_size * num_gpus == batch_size, assert_info
53 | 
54 |     limit_train_batches = float(config['training']['limit_train_batches'])
55 |     if limit_train_batches > 1.0:
56 |         limit_train_batches = int(limit_train_batches)
57 | 
58 |     limit_val_batches = float(config['training']['limit_val_batches'])
59 |     if limit_val_batches > 1.0:
60 |         limit_val_batches = int(limit_val_batches)
61 | 
62 |     # ------------
63 |     # Data
64 |     # ------------
65 |     data_module = DataModule(
66 |         config,
67 |         batch_size_train=per_gpu_batch_size,
68 |         batch_size_val=per_gpu_batch_size)
69 | 
70 |     num_bins_context = data_module.get_nbins_context()
71 |     num_bins_corr = data_module.get_nbins_correlation()
72 |     print(f'num_bins:\n\tcontext: {num_bins_context}\n\tcorrelation: {num_bins_corr}')
73 | 
74 |     # ------------
75 |     # Logging
76 |     # ------------
77 |     wandb_config = config['wandb']
78 |     wandb_runpath = wandb_config['wandb_runpath']
79 |     if wandb_runpath is None:
80 |         wandb_id = wandb.util.generate_id()
81 |         print(f'new run: generating id {wandb_id}')
82 |     else:
83 |         wandb_id = Path(wandb_runpath).name
84 |         print(f'using provided id {wandb_id}')
85 |     logger = WandbLogger(
86 |         project=wandb_config['project_name'],
87 |         group=wandb_config['group_name'],
88 |         wandb_id=wandb_id,
89 |         log_model=True,
90 |         save_last_only_final=False,
91 |         save_code=True,
92 |         config_args=config)
93 |     resume_path = None
94 |     if wandb_config['artifact_name'] is not None:
95 |         artifact_runpath = wandb_config['artifact_runpath']
96 |         if artifact_runpath is None:
97 |             artifact_runpath = wandb_runpath
98 |         if artifact_runpath is None:
99 |             print(
100 |                 'must specify wandb_runpath or artifact_runpath to restore a checkpoint/artifact. 
Cannot load artifact.') 101 | else: 102 | artifact_name = wandb_config['artifact_name'] 103 | print(f'resuming checkpoint from runpath {artifact_runpath} and artifact name {artifact_name}') 104 | resume_path = logger.get_checkpoint(artifact_runpath, artifact_name) 105 | assert resume_path.exists() 106 | assert resume_path.suffix == '.ckpt', resume_path.suffix 107 | 108 | # ------------ 109 | # Checkpoints 110 | # ------------ 111 | checkpoint_callback = get_ckpt_callback(config=config) 112 | 113 | # ------------ 114 | # Other Callbacks 115 | # ------------ 116 | image_callback = WandBImageLoggingCallback(config['logging']) 117 | 118 | callback_list = None 119 | if config['debugging']['profiler'] is None: 120 | callback_list = [checkpoint_callback, image_callback] 121 | if config['training']['lr_scheduler']['use']: 122 | callback_list.append(LearningRateMonitor(logging_interval='step')) 123 | # ------------ 124 | # Model 125 | # ------------ 126 | 127 | if resume_path is not None and wandb_config['resume_only_weights']: 128 | print('Resuming only the weights instead of the full training state') 129 | net = RAFTSplineModule.load_from_checkpoint(str(resume_path), **{"config": config}) 130 | resume_path = None 131 | else: 132 | net = RAFTSplineModule(config) 133 | 134 | # ------------ 135 | # Training 136 | # ------------ 137 | logger.watch(net, log='all', log_freq=int(config['logging']['log_every_n_steps']), log_graph=True) 138 | 139 | gradient_clip_val = config['training']['gradient_clip_val'] 140 | if gradient_clip_val is not None and gradient_clip_val > 0: 141 | for param in net.parameters(): 142 | param.register_hook(lambda grad: torch.clamp(grad, -gradient_clip_val, gradient_clip_val)) 143 | 144 | # --------------------- 145 | # Callbacks and Misc 146 | # --------------------- 147 | callbacks = list() 148 | if config['training']['lr_scheduler']['use']: 149 | callbacks.append(LearningRateMonitor(logging_interval='step')) 150 | callbacks.append(ModelSummary(max_depth=2)) 151 | 152 | trainer = pl.Trainer( 153 | accelerator='gpu', 154 | strategy=strategy, 155 | enable_checkpointing=True, 156 | sync_batchnorm=True if ddp_active else False, 157 | devices=gpus, 158 | logger=logger, 159 | precision=32, 160 | max_epochs=int(config['training']['max_epochs']), 161 | max_steps=int(config['training']['max_steps']), 162 | profiler=config['debugging']['profiler'], 163 | limit_train_batches=limit_train_batches, 164 | limit_val_batches=limit_val_batches, 165 | log_every_n_steps=int(config['logging']['log_every_n_steps']), 166 | callbacks=callback_list) 167 | trainer.fit(model=net, ckpt_path=resume_path, datamodule=data_module) 168 | 169 | 170 | if __name__ == '__main__': 171 | main() 172 | -------------------------------------------------------------------------------- /utils/general.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | from functools import wraps 3 | from typing import Any, Dict, Union, List 4 | 5 | import numpy as np 6 | import torch 7 | from pytorch_lightning.callbacks import ModelCheckpoint 8 | 9 | 10 | def is_cpu(input_: Union[torch.Tensor, List[torch.Tensor]]) -> bool: 11 | if isinstance(input_, torch.Tensor): 12 | return input_.device == torch.device('cpu') 13 | assert isinstance(input_, list) 14 | on_cpu = True 15 | for x in input_: 16 | assert isinstance(x, torch.Tensor) 17 | on_cpu &= x.device == torch.device('cpu') 18 | return on_cpu 19 | 20 | 21 | def _convert_to_tensor(input_: Any): 22 | if input_ is None or 
isinstance(input_, torch.Tensor): 23 | return input_ 24 | if isinstance(input_, np.ndarray): 25 | return torch.from_numpy(input_) 26 | if isinstance(input_, numbers.Number): 27 | return torch.tensor(input_) 28 | if isinstance(input_, dict): 29 | return {k: _convert_to_tensor(v) for k, v in input_.items()} 30 | if isinstance(input_, list): 31 | return [_convert_to_tensor(x) for x in input_] 32 | if isinstance(input_, tuple): 33 | return (_convert_to_tensor(x) for x in input_) 34 | return input_ 35 | 36 | 37 | def inputs_to_tensor(func): 38 | @wraps(func) 39 | def inner(*args, **kwargs): 40 | args = _convert_to_tensor(args) 41 | kwargs = _convert_to_tensor(kwargs) 42 | return func(*args, **kwargs) 43 | 44 | return inner 45 | 46 | 47 | def _obj_has_function(obj, func_name: str): 48 | return hasattr(obj, func_name) and callable(getattr(obj, func_name)) 49 | 50 | 51 | def _data_to_cpu(input_: Any): 52 | if input_ is None: 53 | return input_ 54 | if isinstance(input_, torch.Tensor): 55 | return input_.cpu() 56 | if isinstance(input_, dict): 57 | return {k: _data_to_cpu(v) for k, v in input_.items()} 58 | if isinstance(input_, list): 59 | return [_data_to_cpu(x) for x in input_] 60 | if isinstance(input_, tuple): 61 | return (_data_to_cpu(x) for x in input_) 62 | assert _obj_has_function(input_, 'cpu') 63 | return input_.cpu() 64 | 65 | 66 | def to_cpu(func): 67 | ''' Move stuff to cpu ''' 68 | 69 | @wraps(func) 70 | def inner(*args, **kwargs): 71 | output = func(*args, **kwargs) 72 | output = _data_to_cpu(output) 73 | return output 74 | 75 | return inner 76 | 77 | 78 | def _unwrap_len1_list_or_tuple(input_: Any): 79 | if isinstance(input_, tuple): 80 | if len(input_) == 1: 81 | return input_[0] 82 | return (_unwrap_len1_list_or_tuple(x) for x in input_) 83 | if isinstance(input_, list): 84 | if len(input_) == 1: 85 | return input_[0] 86 | return [_unwrap_len1_list_or_tuple(x) for x in input_] 87 | if isinstance(input_, dict): 88 | return {k: _unwrap_len1_list_or_tuple(v) for k, v in input_.items()} 89 | return input_ 90 | 91 | 92 | def wrap_unwrap_lists_for_class_method(func): 93 | # We add "self" such that it can (only) be used on class methods 94 | @wraps(func) 95 | def inner(self, *args, **kwargs): 96 | # The reason why we have to explicitly add self in the arguments is that 97 | # we wrap the inputs in a list and would need to detect whether the input arg is self or not. 
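# Illustrative note (added comment, not part of the original module): the round trip below is
# wrap -> call -> unwrap. Every non-list, non-None positional and keyword argument is wrapped
# into a one-element list before func is invoked, and _unwrap_len1_list_or_tuple reduces any
# length-1 list or tuple in the result back to its bare element. A method written purely in
# terms of lists can therefore still be called with, and return, single tensors.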
98 | args = [arg if isinstance(arg, list) or arg is None else [arg] for arg in args] 99 | kwargs = {k: (v if isinstance(v, list) or v is None else [v]) for k, v in kwargs.items()} 100 | out = func(self, *args, **kwargs) 101 | out = _unwrap_len1_list_or_tuple(out) 102 | return out 103 | 104 | return inner 105 | 106 | 107 | def get_ckpt_callback(config: Dict) -> ModelCheckpoint: 108 | dataset_name = config['dataset']['name'] 109 | assert dataset_name in {'dsec', 'multiflow_regen'} 110 | 111 | if dataset_name == 'dsec': 112 | ckpt_callback_monitor = 'global_step' 113 | ckpt_filename = 'epoch={epoch:03d}-step={' + ckpt_callback_monitor + ':.0f}' 114 | mode = 'max' # because of global_step 115 | else: 116 | prefix = 'val' 117 | metric = 'epe_multi' 118 | ckpt_callback_monitor = f'{prefix}/{metric}' 119 | filename_monitor_string = f'{prefix}_{metric}' 120 | ckpt_filename = 'epoch={epoch:03d}-step={step}-' + filename_monitor_string + '={' + ckpt_callback_monitor + ':.2f}' 121 | mode = 'min' 122 | 123 | checkpoint_callback = ModelCheckpoint( 124 | monitor=ckpt_callback_monitor, 125 | filename=ckpt_filename, 126 | auto_insert_metric_name=False, 127 | save_top_k=1, 128 | mode=mode, 129 | every_n_epochs=config['logging']['ckpt_every_n_epochs'], 130 | save_last=True, 131 | verbose=True, 132 | ) 133 | checkpoint_callback.CHECKPOINT_NAME_LAST = 'last_epoch={epoch:03d}-step={step}' 134 | return checkpoint_callback -------------------------------------------------------------------------------- /utils/losses.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch as th 4 | 5 | 6 | def l1_loss_channel_masked(source: th.Tensor, target: th.Tensor, valid_mask: Optional[th.Tensor]=None): 7 | # source: (N, C, *) 8 | # target: (N, C, *), source.shape == target.shape 9 | # valid_mask: (N, *), where the channel dimension is missing. I.e. valid_mask.ndim == target.ndim - 1 10 | assert source.ndim > 2 11 | assert source.shape == target.shape 12 | 13 | loss = th.abs(source - target).sum(1) 14 | if valid_mask is not None: 15 | assert valid_mask.shape[0] == target.shape[0] 16 | assert valid_mask.ndim == target.ndim - 1 17 | assert valid_mask.dtype == th.bool 18 | assert loss.shape == valid_mask.shape 19 | loss_masked = loss[valid_mask].sum() / valid_mask.sum() 20 | return loss_masked 21 | return th.mean(loss) 22 | 23 | 24 | def l1_seq_loss_channel_masked(source_list: List[th.Tensor], target: th.Tensor, valid_mask: Optional[th.Tensor]=None, gamma: float=0.8): 25 | # source: [(N, C, *), ...], I predictions from I iterations 26 | # target: (N, C, *), source.shape == target.shape 27 | # valid_mask: (N, *), where the channel dimension is missing. I.e. 
valid_mask.ndim == target.ndim - 1 28 | 29 | # Adopted from https://github.com/princeton-vl/RAFT/blob/224320502d66c356d88e6c712f38129e60661e80/train.py#L47 30 | 31 | n_predictions = len(source_list) 32 | loss = 0 33 | 34 | for i in range(n_predictions): 35 | i_weight = gamma**(n_predictions - i - 1) 36 | i_loss = l1_loss_channel_masked(source_list[i], target, valid_mask) 37 | loss += i_weight * i_loss 38 | 39 | return loss 40 | 41 | def l1_multi_seq_loss_channel_masked(src_list_list: List[List[th.Tensor]], target_list: List[th.Tensor], valid_mask_list: Optional[List[th.Tensor]]=None, gamma: float=0.8): 42 | # src_list_list: [[(N, C, *), ...], ...], I*M predictions -> I iterations (outer) and M supervision targets (inner) 43 | # target_list: [(N, C, *), ...], M supervision targets 44 | # valid_mask_list: [(N, *), ...], M supervision targets 45 | 46 | loss = 0 47 | 48 | num_iters = len(src_list_list) 49 | for iter_idx, sources_per_iter in enumerate(src_list_list): 50 | # iteration over RAFT iterations 51 | num_targets = len(sources_per_iter) 52 | assert num_targets > 0 53 | assert num_targets == len(target_list) 54 | i_loss = 0 55 | for tindex, source in enumerate(sources_per_iter): 56 | # iteration over prediction times 57 | i_loss += l1_loss_channel_masked(source, target_list[tindex], valid_mask_list[tindex] if valid_mask_list is not None else None) 58 | 59 | i_loss /= num_targets 60 | i_weight = gamma**(num_iters - iter_idx - 1) 61 | loss += i_weight * i_loss 62 | 63 | return loss -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Optional 3 | 4 | import torch as th 5 | from torchmetrics import Metric 6 | 7 | from utils.losses import l1_loss_channel_masked 8 | 9 | 10 | class L1ChannelMasked(Metric): 11 | def __init__(self, dist_sync_on_step=False): 12 | super().__init__(dist_sync_on_step=dist_sync_on_step) 13 | 14 | self.add_state("l1", default=th.tensor(0, dtype=th.float64), dist_reduce_fx="sum") 15 | self.add_state("total", default=th.tensor(0, dtype=th.int64), dist_reduce_fx="sum") 16 | 17 | def update(self, source: th.Tensor, target: th.Tensor, valid_mask: Optional[th.Tensor]=None): 18 | # source (prediction): (N, C, *), 19 | # target (ground truth): (N, C, *), source.shape == target.shape 20 | # valid_mask: (N, *), where the channel dimension is missing. I.e. valid_mask.ndim == target.ndim - 1 21 | 22 | self.l1 += l1_loss_channel_masked(source, target, valid_mask).double() 23 | self.total += 1 24 | 25 | def compute(self): 26 | assert self.total > 0 27 | return (self.l1 / self.total).float() 28 | 29 | 30 | class EPE(Metric): 31 | def __init__(self, dist_sync_on_step=False): 32 | super().__init__(dist_sync_on_step=dist_sync_on_step) 33 | 34 | self.add_state("epe", default=th.tensor(0, dtype=th.float64), dist_reduce_fx="sum") 35 | self.add_state("total", default=th.tensor(0, dtype=th.int64), dist_reduce_fx="sum") 36 | 37 | def update(self, source: th.Tensor, target: th.Tensor, valid_mask: Optional[th.Tensor]=None): 38 | # source (prediction): (N, C, *), 39 | # target (ground truth): (N, C, *), source.shape == target.shape 40 | # valid_mask: (N, *), where the channel dimension is missing. I.e. 
valid_mask.ndim == target.ndim - 1 41 | 42 | epe = epe_masked(source, target, valid_mask) 43 | if epe is not None: 44 | self.epe += epe.double() 45 | self.total += 1 46 | 47 | def compute(self): 48 | assert self.total > 0 49 | return (self.epe / self.total).float() 50 | 51 | class EPE_MULTI(Metric): 52 | def __init__(self, dist_sync_on_step=False, min_traj_len=None, max_traj_len=None): 53 | super().__init__(dist_sync_on_step=dist_sync_on_step) 54 | 55 | self.add_state("epe", default=th.tensor(0, dtype=th.float64), dist_reduce_fx="sum") 56 | self.add_state("total", default=th.tensor(0, dtype=th.int64), dist_reduce_fx="sum") 57 | self.min_traj_len = min_traj_len 58 | self.max_traj_len = max_traj_len 59 | 60 | @staticmethod 61 | def compute_traj_len(target: List[th.Tensor]): 62 | target_stack = th.stack(target, dim=0) 63 | diff = target_stack[1:] - target_stack[:-1] 64 | return diff.square().sum(dim=2).sqrt().sum(dim=0) 65 | 66 | def get_true_mask(self, target: List[th.Tensor], device: th.device): 67 | valid_shape = (target[0].shape[0],) + target[0].shape[2:] 68 | return th.ones(valid_shape, dtype=th.bool, device=device) 69 | 70 | def update(self, source: List[th.Tensor], target: List[th.Tensor], valid_mask: Optional[List[th.Tensor]]=None): 71 | # source_lst: [(N, C, *), ...], M evaluation/predictions 72 | # target_lst: [(N, C, *), ...], source_lst[*].shape == target_lst[*].shape 73 | # valid_mask_lst: [(N, *), ...], where the channel dimension is missing. I.e. valid_mask_lst[*].ndim == target_lst[*].ndim - 1 74 | 75 | if self.min_traj_len is not None or self.max_traj_len is not None: 76 | traj_len = self.compute_traj_len(target=target) 77 | valid_len = self.get_true_mask(target=target, device=target[0].device) 78 | if self.min_traj_len is not None: 79 | valid_len &= (traj_len >= self.min_traj_len) 80 | if self.max_traj_len is not None: 81 | valid_len &= (traj_len <= self.max_traj_len) 82 | if valid_mask is None: 83 | valid_mask = [valid_len.clone() for _ in range(len(target))] 84 | else: 85 | valid_mask = [valid_mask[idx] & valid_len for idx in range(len(target))] 86 | 87 | epe = epe_masked_multi(source, target, valid_mask) 88 | if epe is not None: 89 | self.epe += epe.double() 90 | self.total += 1 91 | 92 | def compute(self): 93 | assert self.total > 0 94 | return (self.epe / self.total).float() 95 | 96 | class AE(Metric): 97 | def __init__(self, degrees: bool=True, dist_sync_on_step=False): 98 | super().__init__(dist_sync_on_step=dist_sync_on_step) 99 | 100 | self.degrees = degrees 101 | 102 | self.add_state("ae", default=th.tensor(0, dtype=th.float64), dist_reduce_fx="sum") 103 | self.add_state("total", default=th.tensor(0, dtype=th.int64), dist_reduce_fx="sum") 104 | 105 | def update(self, source: th.Tensor, target: th.Tensor, valid_mask: Optional[th.Tensor]=None): 106 | # source (prediction): (N, C, *), 107 | # target (ground truth): (N, C, *), source.shape == target.shape 108 | # valid_mask: (N, *), where the channel dimension is missing. I.e. 
valid_mask.ndim == target.ndim - 1 109 | 110 | self.ae += ae_masked(source, target, valid_mask, degrees=self.degrees).double() 111 | self.total += 1 112 | 113 | def compute(self): 114 | assert self.total > 0 115 | return (self.ae / self.total).float() 116 | 117 | 118 | class AE_MULTI(Metric): 119 | def __init__(self, degrees: bool=True, dist_sync_on_step=False): 120 | super().__init__(dist_sync_on_step=dist_sync_on_step) 121 | 122 | self.degrees = degrees 123 | 124 | self.add_state("ae", default=th.tensor(0, dtype=th.float64), dist_reduce_fx="sum") 125 | self.add_state("total", default=th.tensor(0, dtype=th.int64), dist_reduce_fx="sum") 126 | 127 | def update(self, source: List[th.Tensor], target: List[th.Tensor], valid_mask: Optional[List[th.Tensor]]=None): 128 | # source_lst: [(N, C, *), ...], M evaluation/predictions 129 | # target_lst: [(N, C, *), ...], source_lst[*].shape == target_lst[*].shape 130 | # valid_mask_lst: [(N, *), ...], where the channel dimension is missing. I.e. valid_mask_lst[*].ndim == target_lst[*].ndim - 1 131 | 132 | self.ae += ae_masked_multi(source, target, valid_mask, degrees=self.degrees).double() 133 | self.total += 1 134 | 135 | def compute(self): 136 | assert self.total > 0 137 | return (self.ae / self.total).float() 138 | 139 | class NPE(Metric): 140 | def __init__(self, n_pixels: float, dist_sync_on_step=False): 141 | super().__init__(dist_sync_on_step=dist_sync_on_step) 142 | 143 | assert n_pixels > 0 144 | self.n_pixels = n_pixels 145 | 146 | self.add_state("npe", default=th.tensor(0, dtype=th.float64), dist_reduce_fx="sum") 147 | self.add_state("total", default=th.tensor(0, dtype=th.int64), dist_reduce_fx="sum") 148 | 149 | def update(self, source: th.Tensor, target: th.Tensor, valid_mask: Optional[th.Tensor]=None): 150 | # source (prediction): (N, C, *), 151 | # target (ground truth): (N, C, *), source.shape == target.shape 152 | # valid_mask: (N, *), where the channel dimension is missing. I.e. valid_mask.ndim == target.ndim - 1 153 | 154 | self.npe += n_pixel_error_masked(source, target, valid_mask, self.n_pixels).double() 155 | self.total += 1 156 | 157 | def compute(self): 158 | assert self.total > 0 159 | return (self.npe / self.total).float() 160 | 161 | def n_pixel_error_masked(source: th.Tensor, target: th.Tensor, valid_mask: Optional[th.Tensor], n_pixels: float): 162 | # source: (N, C, *), 163 | # target: (N, C, *), source.shape == target.shape 164 | # valid_mask: (N, *), where the channel dimension is missing. I.e. 
valid_mask.ndim == target.ndim - 1 165 | assert source.ndim > 2 166 | assert source.shape == target.shape 167 | 168 | if valid_mask is not None: 169 | assert valid_mask.shape[0] == target.shape[0] 170 | assert valid_mask.ndim == target.ndim - 1 171 | assert valid_mask.dtype == th.bool 172 | 173 | num_valid = th.sum(valid_mask) 174 | assert num_valid > 0 175 | 176 | gt_flow_magn = th.linalg.norm(target, dim=1) 177 | error_magn = th.linalg.norm(source - target, dim=1) 178 | 179 | if valid_mask is not None: 180 | rel_error = th.zeros_like(error_magn) 181 | rel_error[valid_mask] = error_magn[valid_mask] / th.clip(gt_flow_magn[valid_mask], min=1e-6) 182 | else: 183 | rel_error = error_magn / th.clip(gt_flow_magn, min=1e-6) 184 | 185 | error_map = (error_magn > n_pixels) & (rel_error >= 0.05) 186 | 187 | if valid_mask is not None: 188 | error = error_map[valid_mask].sum() / num_valid 189 | else: 190 | error = th.mean(error_map.float()) 191 | 192 | error *= 100 193 | return error 194 | 195 | 196 | def epe_masked(source: th.Tensor, target: th.Tensor, valid_mask: Optional[th.Tensor] = None) -> Optional[th.Tensor]: 197 | # source: (N, C, *), 198 | # target: (N, C, *), source.shape == target.shape 199 | # valid_mask: (N, *), where the channel dimension is missing. I.e. valid_mask.ndim == target.ndim - 1 200 | assert source.ndim > 2 201 | assert source.shape == target.shape 202 | 203 | epe = th.sqrt(th.square(source - target).sum(1)) 204 | if valid_mask is not None: 205 | assert valid_mask.shape[0] == target.shape[0] 206 | assert valid_mask.ndim == target.ndim - 1 207 | assert valid_mask.dtype == th.bool 208 | assert epe.shape == valid_mask.shape 209 | denominator = valid_mask.sum() 210 | if denominator == 0: 211 | return None 212 | return epe[valid_mask].sum() / denominator 213 | return th.mean(epe) 214 | 215 | 216 | def epe_masked_multi(source_lst: List[th.Tensor], target_lst: List[th.Tensor], valid_mask_lst: Optional[List[th.Tensor]] = None) -> Optional[th.Tensor]: 217 | # source_lst: [(N, C, *), ...], M evaluation/predictions 218 | # target_lst: [(N, C, *), ...], source_lst[*].shape == target_lst[*].shape 219 | # valid_mask_lst: [(N, *), ...], where the channel dimension is missing. I.e. valid_mask_lst[*].ndim == target_lst[*].ndim - 1 220 | 221 | num_preds = len(source_lst) 222 | assert num_preds > 0 223 | assert len(target_lst) == num_preds, len(target_lst) 224 | if valid_mask_lst is not None: 225 | assert len(valid_mask_lst) == num_preds, len(valid_mask_lst) 226 | else: 227 | valid_mask_lst = [None]*num_preds 228 | epe_sum = 0 229 | denominator = 0 230 | for source, target, valid_mask in zip(source_lst, target_lst, valid_mask_lst): 231 | epe = epe_masked(source, target, valid_mask) 232 | if epe is not None: 233 | epe_sum += epe 234 | denominator += 1 235 | if denominator == 0: 236 | return None 237 | epe_sum: th.Tensor 238 | return epe_sum / denominator 239 | 240 | def ae_masked_multi(source_lst: List[th.Tensor], target_lst: List[th.Tensor], valid_mask_lst: Optional[List[th.Tensor]]=None, degrees: bool=True) -> th.Tensor: 241 | # source_lst: [(N, C, *), ...], M evaluation/predictions 242 | # target_lst: [(N, C, *), ...], source_lst[*].shape == target_lst[*].shape 243 | # valid_mask_lst: [(N, *), ...], where the channel dimension is missing. I.e. 
valid_mask_lst[*].ndim == target_lst[*].ndim - 1 244 | 245 | num_preds = len(source_lst) 246 | assert num_preds > 0 247 | assert len(target_lst) == num_preds, len(target_lst) 248 | if valid_mask_lst is not None: 249 | assert len(valid_mask_lst) == num_preds, len(valid_mask_lst) 250 | else: 251 | valid_mask_lst = [None]*num_preds 252 | ae_sum = 0 253 | for source, target, valid_mask in zip(source_lst, target_lst, valid_mask_lst): 254 | ae_sum += ae_masked(source, target, valid_mask, degrees) 255 | ae_sum: th.Tensor 256 | return ae_sum / num_preds 257 | 258 | 259 | def ae_masked(source: th.Tensor, target: th.Tensor, valid_mask: Optional[th.Tensor]=None, degrees: bool=True) -> th.Tensor: 260 | # source: (N, C, *), 261 | # target: (N, C, *), source.shape == target.shape 262 | # valid_mask: (N, *), where the channel dimension is missing. I.e. valid_mask.ndim == target.ndim - 1 263 | assert source.ndim > 2 264 | assert source.shape == target.shape 265 | 266 | shape = list(source.shape) 267 | extension_shape = shape 268 | extension_shape[1] = 1 269 | extension = th.ones(extension_shape, device=source.device) 270 | 271 | source_ext = th.cat((source, extension), dim=1) 272 | target_ext = th.cat((target, extension), dim=1) 273 | 274 | # according to https://vision.middlebury.edu/flow/floweval-ijcv2011.pdf 275 | 276 | nominator = th.sum(source_ext * target_ext, dim=1) 277 | denominator = th.linalg.norm(source_ext, dim=1) * th.linalg.norm(target_ext, dim=1) 278 | 279 | tmp = th.div(nominator, denominator) 280 | 281 | # Somehow this seems necessary 282 | tmp[tmp > 1.0] = 1.0 283 | tmp[tmp < -1.0] = -1.0 284 | 285 | ae = th.acos(tmp) 286 | if degrees: 287 | ae = ae/math.pi*180 288 | 289 | if valid_mask is not None: 290 | assert valid_mask.shape[0] == target.shape[0] 291 | assert valid_mask.ndim == target.ndim - 1 292 | assert valid_mask.dtype == th.bool 293 | assert ae.shape == valid_mask.shape 294 | ae_masked = ae[valid_mask].sum() / valid_mask.sum() 295 | return ae_masked 296 | return th.mean(ae) 297 | 298 | def predictions_from_lin_assumption(source: th.Tensor, target_timestamps: List[float]) -> List[th.Tensor]: 299 | assert max(target_timestamps) <= 1 300 | assert 0 <= min(target_timestamps) 301 | 302 | output = list() 303 | for target_ts in target_timestamps: 304 | output.append(target_ts * source) 305 | return output 306 | -------------------------------------------------------------------------------- /utils/timers.py: -------------------------------------------------------------------------------- 1 | import atexit 2 | import time 3 | 4 | import numpy as np 5 | import torch 6 | 7 | cuda_timers = {} 8 | timers = {} 9 | 10 | 11 | class CudaTimer: 12 | # NOTE: This timer seems to work fine. 13 | # However, it's unclear whether I have to synchronize just after entering the __enter__ function. 
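# Illustrative note (added comment, not part of the original module): the synchronize in
# __enter__ ensures that GPU work queued before the timed region is not attributed to it,
# while the synchronize in __exit__ ensures that all kernels launched inside the region have
# finished before the wall-clock end time is read. A lower-overhead alternative could be a
# pair of torch.cuda.Event(enable_timing=True) events queried via elapsed_time(), at the
# cost of slightly more bookkeeping.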
14 | def __init__(self, device: torch.device, timer_name: str = ''): 15 | self.timer_name = timer_name 16 | if self.timer_name not in cuda_timers: 17 | cuda_timers[self.timer_name] = [] 18 | 19 | self.device = device 20 | self.start = None 21 | self.end = None 22 | 23 | def __enter__(self): 24 | torch.cuda.synchronize(device=self.device) 25 | self.start = time.time() 26 | return self 27 | 28 | def __exit__(self, *args): 29 | assert self.start is not None 30 | torch.cuda.synchronize(device=self.device) 31 | end = time.time() 32 | cuda_timers[self.timer_name].append(end - self.start) 33 | 34 | 35 | class CudaTimerDummy: 36 | def __init__(self, *args, **kwargs): 37 | pass 38 | 39 | def __enter__(self): 40 | pass 41 | 42 | def __exit__(self, *args): 43 | pass 44 | 45 | 46 | class Timer: 47 | def __init__(self, timer_name=''): 48 | self.timer_name = timer_name 49 | if self.timer_name not in timers: 50 | timers[self.timer_name] = [] 51 | 52 | def __enter__(self): 53 | self.start = time.time() 54 | return self 55 | 56 | def __exit__(self, *args): 57 | end = time.time() 58 | time_diff_s = end - self.start # measured in seconds 59 | timers[self.timer_name].append(time_diff_s) 60 | 61 | 62 | def print_timing_info(): 63 | print('== Timing statistics ==') 64 | skip_warmup = 2 65 | for timer_name, timing_values in [*cuda_timers.items(), *timers.items()]: 66 | if len(timing_values) <= skip_warmup: 67 | continue 68 | values = timing_values[skip_warmup:] 69 | timing_value_s = np.mean(np.array(values)) 70 | timing_value_ms = timing_value_s * 1000 71 | if timing_value_ms > 1000: 72 | print('{}: {:.2f} s'.format(timer_name, timing_value_s)) 73 | else: 74 | print('{}: {:.2f} ms'.format(timer_name, timing_value_ms)) 75 | 76 | 77 | # this will print all the timer values upon termination of any program that imported this file 78 | atexit.register(print_timing_info) 79 | -------------------------------------------------------------------------------- /val.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 5 | os.environ["OMP_NUM_THREADS"] = "1" 6 | os.environ["OPENBLAS_NUM_THREADS"] = "1" 7 | os.environ["MKL_NUM_THREADS"] = "1" 8 | os.environ["VECLIB_MAXIMUM_THREADS"] = "1" 9 | os.environ["NUMEXPR_NUM_THREADS"] = "1" 10 | 11 | import hydra 12 | from omegaconf import DictConfig, OmegaConf 13 | import pytorch_lightning as pl 14 | from pytorch_lightning.callbacks import ModelSummary 15 | from pytorch_lightning.loggers import CSVLogger 16 | import torch 17 | 18 | from modules.data_loading import DataModule 19 | from modules.raft_spline import RAFTSplineModule 20 | 21 | 22 | @hydra.main(config_path='config', config_name='val', version_base='1.3') 23 | def main(config: DictConfig): 24 | print('------ Configuration ------\n') 25 | print(OmegaConf.to_yaml(config)) 26 | print('---------------------------\n') 27 | OmegaConf.to_container(config, resolve=True, throw_on_missing=True) 28 | 29 | # ------------ 30 | # GPU Options 31 | # ------------ 32 | gpus = config.hardware.gpus 33 | assert isinstance(gpus, int), 'no more than 1 gpu supported' 34 | gpus = [gpus] 35 | 36 | batch_size: int = config.batch_size 37 | assert batch_size > 0 38 | 39 | # ------------ 40 | # Data 41 | # ------------ 42 | data_module = DataModule(config, batch_size_train=batch_size, batch_size_val=batch_size) 43 | 44 | num_bins_context = data_module.get_nbins_context() 45 | num_bins_corr = 
data_module.get_nbins_correlation() 46 | print(f'num_bins:\n\tcontext: {num_bins_context}\n\tcorrelation: {num_bins_corr}') 47 | 48 | # --------------------- 49 | # Logging and Checkpoints 50 | # --------------------- 51 | logger = CSVLogger(save_dir='./validation_logs') 52 | ckpt_path = Path(config.checkpoint) 53 | 54 | # ------------ 55 | # Model 56 | # ------------ 57 | 58 | module = RAFTSplineModule.load_from_checkpoint(str(ckpt_path), **{'config': config}) 59 | 60 | # --------------------- 61 | # Callbacks and Misc 62 | # --------------------- 63 | callbacks = [ModelSummary(max_depth=2)] 64 | 65 | trainer = pl.Trainer( 66 | accelerator='gpu', 67 | callbacks=callbacks, 68 | default_root_dir=None, 69 | devices=gpus, 70 | logger=logger, 71 | log_every_n_steps=100, 72 | precision=32, 73 | ) 74 | 75 | with torch.inference_mode(): 76 | trainer.validate(model=module, datamodule=data_module, ckpt_path=str(ckpt_path)) 77 | 78 | 79 | if __name__ == '__main__': 80 | main() 81 | --------------------------------------------------------------------------------
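Usage sketch (illustrative, assuming the repository root is on PYTHONPATH and torchmetrics is installed): the masked endpoint-error utilities from utils/metrics.py can be exercised on dummy tensors as follows.

import torch as th
from utils.metrics import EPE, epe_masked

pred = th.zeros(2, 2, 8, 8)                 # (N, C, H, W) predicted flow
gt = th.ones(2, 2, 8, 8)                    # (N, C, H, W) ground-truth flow
valid = th.ones(2, 8, 8, dtype=th.bool)     # (N, H, W) validity mask, channel dim dropped

print(epe_masked(pred, gt, valid))          # functional form, ~1.414 for this dummy data

metric = EPE()                              # torchmetrics-style running average
metric.update(pred, gt, valid)
print(metric.compute())

Validation itself would typically be launched through Hydra, e.g. python val.py checkpoint=/path/to/model.ckpt hardware.gpus=0 batch_size=4; the exact set of required overrides depends on config/val.yaml.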