├── .gitignore
├── LICENSE
├── README.md
├── callbacks
│   ├── logger.py
│   └── utils
│       ├── flow_vis.py
│       └── visualization.py
├── config
│   ├── dataset
│   │   ├── base.yaml
│   │   ├── dsec.yaml
│   │   └── multiflow_regen.yaml
│   ├── experiment
│   │   ├── dsec
│   │   │   └── raft_spline
│   │   │       ├── E_I_LU4_BD2_lowpyramid.yaml
│   │   │       └── E_LU4_BD2_lowpyramid.yaml
│   │   └── multiflow
│   │       └── raft_spline
│   │           ├── E_I_LU5_BD10_lowpyramid.yaml
│   │           └── E_LU5_BD10_lowpyramid.yaml
│   ├── general.yaml
│   ├── model
│   │   ├── base.yaml
│   │   ├── raft-spline.yaml
│   │   └── raft_base.yaml
│   ├── train.yaml
│   └── val.yaml
├── data
│   ├── dsec
│   │   ├── eventslicer.py
│   │   ├── provider.py
│   │   ├── sequence.py
│   │   └── subsequence
│   │       ├── base.py
│   │       └── twostep.py
│   ├── multiflow2d
│   │   ├── datasubset.py
│   │   ├── provider.py
│   │   └── sample.py
│   └── utils
│       ├── augmentor.py
│       ├── generic.py
│       ├── keys.py
│       ├── provider.py
│       └── representations.py
├── loggers
│   └── wandb_logger.py
├── models
│   ├── raft_spline
│   │   ├── bezier.py
│   │   ├── raft.py
│   │   └── update.py
│   └── raft_utils
│       ├── corr.py
│       ├── extractor.py
│       └── utils.py
├── modules
│   ├── data_loading.py
│   ├── raft_spline.py
│   └── utils.py
├── train.py
├── utils
│   ├── general.py
│   ├── losses.py
│   ├── metrics.py
│   └── timers.py
└── val.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dense Continuous-Time Optical Flow from Event Cameras
2 |
3 | 
4 |
5 | This is the official Pytorch implementation of the TPAMI 2024 paper [Dense Continuous-Time Optical Flow from Event Cameras](https://ieeexplore.ieee.org/document/10419040).
6 |
7 | If you find this code useful, please cite us:
8 | ```bibtex
9 | @Article{Gehrig2024pami,
10 | author = {Mathias Gehrig and Manasi Muglikar and Davide Scaramuzza},
11 | title = {Dense Continuous-Time Optical Flow from Event Cameras},
12 | journal = {{IEEE} Trans. Pattern Anal. Mach. Intell. (T-PAMI)},
13 | year = 2024
14 | }
15 | ```
16 |
17 | ## Conda Installation
18 | We highly recommend using [Mambaforge](https://github.com/conda-forge/miniforge#mambaforge) to reduce the installation time.
19 | ```Bash
20 | conda create -y -n bflow python=3.11 pip
21 | conda activate bflow
22 | conda config --set channel_priority flexible
23 |
24 | CUDA_VERSION=12.1
25 |
26 | conda install -y h5py=3.10.0 blosc-hdf5-plugin=1.0.0 llvm-openmp=15.0.7 \
27 | hydra-core=1.3.2 einops=0.7 tqdm numba \
28 | pytorch=2.1.2 torchvision pytorch-cuda=$CUDA_VERSION \
29 | -c pytorch -c nvidia -c conda-forge
30 |
31 | python -m pip install pytorch-lightning==2.1.3 wandb==0.16.1 \
32 | opencv-python==4.8.1.78 imageio==2.33.1 lpips==0.1.4 \
33 | pandas==2.1.4 plotly==5.18.0 moviepy==1.0.3 tabulate==0.9.0 \
34 | loguru==0.7.2 matplotlib==3.8.2 scikit-image==0.22.0 kaleido
35 | ```
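
As a quick, optional sanity check (a minimal sketch, assuming the `bflow` environment is active), you can verify that PyTorch and CUDA are visible:
```Bash
# Prints the installed PyTorch version and whether a CUDA device is available.
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```
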
36 | ## Data
37 | ### MultiFlow
38 |
39 | |  | Train | Val |
40 | |:---|:---:|:---:|
41 | | pre-processed dataset | download | download |
42 |
47 |
48 | ### DSEC
49 | |  | Train | Test (input) |
50 | |:---|:---:|:---:|
51 | | pre-processed dataset | download | download |
52 | | crc32 | c1b618fc | ffbacb7e |
61 |
62 |
63 | ## Checkpoints
64 |
65 | ### MultiFlow
66 |
67 | |  | Events only | Events + Images |
68 | |:---|:---:|:---:|
69 | | pre-trained checkpoint | download | download |
70 | | md5 | 61e102 | 2ce3aa |
79 |
80 |
81 | ### DSEC
82 |
83 | |  | Events only | Events + Images |
84 | |:---|:---:|:---:|
85 | | pre-trained checkpoint | download | download |
86 | | md5 | d17002 | 05770b |
95 |
96 |
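To sanity-check the downloads against the tables above, you can compare checksums. A minimal sketch (the filenames are placeholders for whatever you downloaded):
```Bash
# md5: compare the first characters with the md5 row of the checkpoint tables.
md5sum multiflow_events_images.ckpt
# crc32: compare with the crc32 row of the DSEC table (reads the whole file into memory).
python -c "import sys, zlib; print(format(zlib.crc32(open(sys.argv[1],'rb').read()), '08x'))" dsec_train.zip
```
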
97 |
98 | ## Training
99 |
100 | ### MultiFlow
101 | - Set `DATA_DIR` as the path to the MultiFlow dataset (parent of train and val dir)
102 | - Set
103 | - `MDL_CFG=E_I_LU5_BD10_lowpyramid` to use both events and frames, or
104 | - `MDL_CFG=E_LU5_BD10_lowpyramid` to use only events
105 | - Set `LOG_ONLY_NUMBERS=true` to avoid logging images (which can require a lot of space). It defaults to false. Example values for these variables are sketched below.
106 |
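For example, the variables could be set as follows (the dataset path is a placeholder):
```Bash
DATA_DIR=/path/to/multiflow   # placeholder: parent directory of train/ and val/
MDL_CFG=E_I_LU5_BD10_lowpyramid
LOG_ONLY_NUMBERS=true
```
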
107 | ```Bash
108 | GPU_ID=0
109 | python train.py model=raft-spline dataset=multiflow_regen dataset.path=${DATA_DIR} wandb.group_name=multiflow \
110 | hardware.gpus=${GPU_ID} hardware.num_workers=6 +experiment/multiflow/raft_spline=${MDL_CFG} \
111 | logging.only_numbers=${LOG_ONLY_NUMBERS}
112 | ```
113 |
114 | ### DSEC
115 | - Set `DATA_DIR` as the path to the DSEC dataset (parent of train and test dir)
117 | - Set
118 | - `MDL_CFG=E_I_LU4_BD2_lowpyramid` to use both events and frames, or
119 | - `MDL_CFG=E_LU4_BD2_lowpyramid` to use only events
120 | - Set `LOG_ONLY_NUMBERS=true` to avoid logging images (which can require a lot of space). It defaults to false.
121 |
122 | ```Bash
123 | GPU_ID=0
124 | python train.py model=raft-spline dataset=dsec dataset.path=${DATA_DIR} wandb.group_name=dsec \
125 | hardware.gpus=${GPU_ID} hardware.num_workers=6 +experiment/dsec/raft_spline=${MDL_CFG} \
126 | logging.only_numbers=${LOG_ONLY_NUMBERS}
127 | ```
128 |
129 | ## Evaluation
130 |
131 | ### MultiFlow
132 | - Set `DATA_DIR` as the path to the MultiFlow dataset (parent of train and val dir)
133 | - Set
134 | - `MDL_CFG=E_I_LU5_BD10_lowpyramid` to use both events and frames, or
135 | - `MDL_CFG=E_LU5_BD10_lowpyramid` to use only events
136 | - Set `CKPT` to the path of the correct checkpoint
137 |
138 | ```Bash
139 | GPU_ID=0
140 | python val.py model=raft-spline dataset=multiflow_regen dataset.path=${DATA_DIR} hardware.gpus=${GPU_ID} \
141 | +experiment/multiflow/raft_spline=${MDL_CFG} checkpoint=${CKPT}
142 | ```
143 |
144 | ### DSEC
145 |
146 | work in progress
147 |
148 | ## Code Acknowledgments
149 | This project has used code from [RAFT](https://github.com/princeton-vl/RAFT) for parts of the model architecture.
150 |
--------------------------------------------------------------------------------
/callbacks/logger.py:
--------------------------------------------------------------------------------
1 | import random
2 | from typing import Dict, Union, Any, Optional, List
3 |
4 | import pytorch_lightning as pl
5 | from pytorch_lightning.callbacks import Callback
6 | import wandb
7 |
8 | import torch
9 |
10 | from callbacks.utils.visualization import (
11 | create_summary_img,
12 | get_grad_flow_figure,
13 | multi_plot_bezier_array,
14 | ev_repr_reduced_to_img_grayscale,
15 | img_torch_to_numpy)
16 | from utils.general import is_cpu
17 | from data.utils.keys import DataLoading, DataSetType
18 |
19 | from loggers.wandb_logger import WandbLogger
20 | from models.raft_spline.bezier import BezierCurves
21 |
22 |
23 | class WandBImageLoggingCallback(Callback):
24 | FLOW_GT: str = 'flow_gt_img'
25 | FLOW_PRED: str = 'flow_pred_img'
26 | FLOW_VALID: str = 'flow_valid_img'
27 | EV_REPR_REDUCED: str = 'ev_repr_reduced'
28 | EV_REPR_REDUCED_M1: str = 'ev_repr_reduced_m1'
29 | VAL_BATCH_IDX: str = 'val_batch_idx'
30 | BEZIER_PARAMS: str = 'bezier_prediction'
31 | IMAGES: str = 'images'
32 |
33 | MAX_FLOW_ERROR = {
34 | int(DataSetType.DSEC): 2.0,
35 | int(DataSetType.MULTIFLOW2D): 3.0,
36 | }
37 |
38 | def __init__(self, logging_params: Dict[str, Any], deterministic: bool=True):
39 | super().__init__()
40 | log_every_n_train_steps = logging_params['log_every_n_steps']
41 | log_n_val_predictions = logging_params['log_n_val_predictions']
42 | self.log_only_numbers = logging_params['only_numbers']
43 | assert log_every_n_train_steps > 0
44 | assert log_n_val_predictions > 0
45 | self.log_every_n_train_steps = log_every_n_train_steps
46 | self.log_n_val_predictions = log_n_val_predictions
47 | self.deterministic = deterministic
48 |
49 | self._clear_val_data()
50 | self._training_started = False
51 | self._val_batch_indices = None
52 |
53 | self._dataset_type: Optional[DataSetType] = None
54 |
55 | def enable_immediate_validation(self):
56 | self._training_started = True
57 |
58 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) -> None:
59 | if self.log_only_numbers:
60 | return
61 | if not self._training_started:
62 | self._training_started = True
63 | global_step = trainer.global_step
64 | do_log = True
65 | do_log &= global_step >= self.log_every_n_train_steps
66 | # BUG: wandb bug? If we log metrics on the same step as logging images in this function
67 | # then no metrics are logged at all, only the images. Unclear why this happens.
68 | # So we introduce a small step margin to ensure that metrics and the images logged here are not written at the same step.
69 | delta_step = 2
70 | do_log &= (global_step - delta_step) % self.log_every_n_train_steps == 0
71 | if not do_log:
72 | return
73 |
74 | if self._dataset_type is None:
75 | self._dataset_type = batch[DataLoading.DATASET_TYPE][0].cpu().item()
76 |
77 | logger: WandbLogger = trainer.logger
78 |
79 | if isinstance(outputs['gt'], list):
80 | flow_gt = [x.detach().cpu() for x in outputs['gt']]
81 | else:
82 | flow_gt = outputs['gt'].detach().cpu()
83 | flow_pred = outputs['pred'].detach().cpu()
84 | flow_valid = outputs['gt_valid'].detach().cpu() if 'gt_valid' in outputs else None
85 |
86 | ev_repr_reduced = None
87 | if 'ev_repr_reduced' in outputs.keys():
88 | ev_repr_reduced = outputs['ev_repr_reduced'].detach().cpu()
89 | ev_repr_reduced_m1 = None
90 | if 'ev_repr_reduced_m1' in outputs.keys():
91 | ev_repr_reduced_m1 = outputs['ev_repr_reduced_m1'].detach().cpu()
92 | images = None
93 | if 'images' in outputs.keys():
94 | images = [x.detach().cpu() for x in outputs['images']]
95 |
96 | summary_img = create_summary_img(
97 | flow_pred,
98 | flow_gt[-1] if isinstance(flow_gt, list) else flow_gt,
99 | valid_mask=flow_valid,
100 | ev_repr_reduced=ev_repr_reduced,
101 | ev_repr_reduced_m1=ev_repr_reduced_m1,
102 | images=images,
103 | max_error=self.MAX_FLOW_ERROR[self._dataset_type])
104 | wandb_flow_img = wandb.Image(summary_img)
105 | logger.log_metrics({'train/flow': wandb_flow_img}, step=global_step)
106 |
107 | if 'bezier_prediction' in outputs.keys():
108 | bezier_prediction: BezierCurves
109 | bezier_prediction = outputs['bezier_prediction'].detach(cpu=True)
110 | assert not bezier_prediction.requires_grad
111 |
112 | if images is not None:
113 | # images[0] is assumed to be the reference image
114 | background_img = img_torch_to_numpy(images[0])
115 | else:
116 | assert ev_repr_reduced is not None
117 | background_img = ev_repr_reduced_to_img_grayscale(ev_repr_reduced)
118 |
119 | bezier_img = multi_plot_bezier_array(
120 | bezier_prediction,
121 | background_img,
122 | multi_flow_gt=flow_gt if isinstance(flow_gt, list) else [flow_gt], # Wrap a single tensor in a list.
123 | num_t=10,
124 | x_add_margin=30,
125 | y_add_margin=30)
126 | wandb_bezier_img = wandb.Image(bezier_img)
127 | logger.log_metrics({'train/bezier': wandb_bezier_img}, step=global_step)
128 |
129 | def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
130 | global_step = trainer.global_step
131 | if global_step % self.log_every_n_train_steps != 0:
132 | return
133 | named_parameters = pl_module.named_parameters()
134 | figure = get_grad_flow_figure(named_parameters)
135 | trainer.logger.log_metrics({'train/gradients': figure}, step=global_step)
136 |
137 | def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
138 | if self.log_only_numbers:
139 | return
140 | if not self._training_started:
141 | # Don't log before the training started.
142 | # In PL, there is a validation sanity check.
143 | return
144 | if self._val_batch_indices is None:
145 | val_batch_indices = self._set_val_batch_indices()
146 | self._subsample_val_data(val_batch_indices)
147 |
148 | flow_gt = self._val_data[self.FLOW_GT]
149 | flow_pred = self._val_data[self.FLOW_PRED]
150 | flow_valid = self._val_data[self.FLOW_VALID]
151 | ev_repr_reduced = self._val_data[self.EV_REPR_REDUCED]
152 | ev_repr_reduced_m1 = self._val_data[self.EV_REPR_REDUCED_M1]
153 | bezier_params = self._val_data[self.BEZIER_PARAMS]
154 | images = self._val_data[self.IMAGES]
155 |
156 | if flow_gt:
157 | assert flow_pred
158 | # Stack in batch dimension from list for make_grid
159 | if isinstance(flow_gt[0], list):
160 | # [[(2, H, W), ...], ...]
161 | # outer: val batch indices
162 | # inner: number of gt samples in time
163 | # We only want the inner loop but batched
164 | # This is the same as transposing the list -> use pytorch because we can just use their transpose
165 | # V: number of val samples
166 | # T: number of gt samples in time
167 |
168 | # V, T, 2, H, W
169 | new_flow_gt = torch.stack([torch.stack([single_gt_map for single_gt_map in val_sample]) for val_sample in flow_gt])
170 | # V, T, 2, H, W -> T, V, 2, H, W
171 | new_flow_gt = torch.transpose(new_flow_gt, 0, 1)
172 | new_flow_gt = torch.split(new_flow_gt, [1]*new_flow_gt.shape[0], dim=0)
173 | # [(V, 2, H, W), ...] T times; the last item corresponds to the final prediction
174 | flow_gt = [x.squeeze() for x in new_flow_gt]
175 | else:
176 | assert isinstance(flow_gt[0], torch.Tensor)
177 | flow_gt = torch.stack(flow_gt)
178 | flow_pred = torch.stack(flow_pred)
179 | flow_valid = torch.stack(flow_valid) if len(flow_valid) > 0 else None
180 | if len(ev_repr_reduced) == 0:
181 | ev_repr_reduced = None
182 | else:
183 | ev_repr_reduced = torch.stack(ev_repr_reduced)
184 | if len(ev_repr_reduced_m1) == 0:
185 | ev_repr_reduced_m1 = None
186 | else:
187 | ev_repr_reduced_m1 = torch.stack(ev_repr_reduced_m1)
188 | if len(images) == 0:
189 | images = None
190 | else:
191 | images = [torch.stack([x[0] for x in images]), torch.stack([x[1] for x in images])]
192 |
193 | summary_img = create_summary_img(
194 | flow_pred,
195 | flow_gt[-1] if isinstance(flow_gt, list) else flow_gt,
196 | flow_valid,
197 | ev_repr_reduced=ev_repr_reduced,
198 | ev_repr_reduced_m1=ev_repr_reduced_m1,
199 | images=images,
200 | max_error=self.MAX_FLOW_ERROR[self._dataset_type])
201 | wandb_flow_img = wandb.Image(summary_img)
202 |
203 | global_step = trainer.global_step
204 | logger: WandbLogger = trainer.logger
205 | logger.log_metrics({'val/flow': wandb_flow_img}, step=global_step)
206 |
207 | if len(bezier_params) > 0:
208 | bezier_params = torch.stack(bezier_params)
209 | bezier_curves = BezierCurves(bezier_params)
210 |
211 | if images is not None:
212 | # images[0] is assumed to be the reference image
213 | background_img = img_torch_to_numpy(images[0])
214 | else:
215 | assert ev_repr_reduced is not None
216 | background_img = ev_repr_reduced_to_img_grayscale(ev_repr_reduced)
217 |
218 | bezier_img = multi_plot_bezier_array(
219 | bezier_curves,
220 | background_img,
221 | multi_flow_gt=flow_gt if isinstance(flow_gt, list) else [flow_gt], # Wrap a single tensor in a list.
222 | num_t=10,
223 | x_add_margin=30,
224 | y_add_margin=30)
225 | wandb_bezier_img = wandb.Image(bezier_img)
226 | logger.log_metrics({'val/bezier': wandb_bezier_img}, step=global_step)
227 |
228 | self._clear_val_data()
229 |
230 | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
231 | if self.log_only_numbers:
232 | return
233 | # NOTE: How to resolve the growing memory issue throughout the validation run?
234 | # A hack would be to only return what we want to log in the validation_step function of the LightningModule,
235 | # and then log the data in the validation_epoch_end function.
236 | if not self._training_started:
237 | # Don't log before the training started.
238 | # In PL, there is a validation sanity check.
239 | return
240 | if self._dataset_type is None:
241 | self._dataset_type = batch[DataLoading.DATASET_TYPE][0].cpu().item()
242 | if self._val_batch_indices is not None:
243 | # NOTE: For the first validation run, we still save everything which can lead to crashes due to full RAM.
244 | # Once we have set the validation batch indices,
245 | # we only save data from those.
246 | if batch_idx not in self._val_batch_indices:
247 | return
248 |
249 | if isinstance(outputs['gt'], list):
250 | flow_gt = [x[0].cpu() for x in outputs['gt']]
251 | # -> list(list(tensors)) will be available after the end of the validation epoch
252 | else:
253 | flow_gt = outputs['gt'][0].cpu()
254 | # -> list(tensors) will be available after the end of the validation epoch
255 | flow_pred = outputs['pred'][0].cpu()
256 | flow_valid = outputs['gt_valid'][0].cpu() if 'gt_valid' in outputs else None
257 | ev_repr_reduced = None
258 | if 'ev_repr_reduced' in outputs.keys():
259 | ev_repr_reduced = outputs['ev_repr_reduced'][0].cpu()
260 | ev_repr_reduced_m1 = None
261 | if 'ev_repr_reduced_m1' in outputs.keys():
262 | ev_repr_reduced_m1 = outputs['ev_repr_reduced_m1'][0].cpu()
263 | images = None
264 | if 'images' in outputs.keys():
265 | images = [x[0].detach().cpu() for x in outputs['images']]
266 | assert len(images) == 2
267 |
268 | bezier_params = None
269 | if 'bezier_prediction' in outputs.keys():
270 | bezier_prediction: BezierCurves = outputs['bezier_prediction']
271 | assert not bezier_prediction.requires_grad
272 | bezier_params = bezier_prediction.get_params()[0].cpu()
273 |
274 | self._save_val_data(self.FLOW_GT, flow_gt)
275 | self._save_val_data(self.FLOW_PRED, flow_pred)
276 | if flow_valid is not None:
277 | self._save_val_data(self.FLOW_VALID, flow_valid)
278 | if ev_repr_reduced is not None:
279 | self._save_val_data(self.EV_REPR_REDUCED, ev_repr_reduced)
280 | if ev_repr_reduced_m1 is not None:
281 | self._save_val_data(self.EV_REPR_REDUCED_M1, ev_repr_reduced_m1)
282 | if bezier_params is not None:
283 | self._save_val_data(self.BEZIER_PARAMS, bezier_params)
284 | if images is not None:
285 | self._save_val_data(self.IMAGES, images)
286 | self._save_val_data(self.VAL_BATCH_IDX, batch_idx)
287 |
288 | def _set_val_batch_indices(self):
289 | val_indices = self._val_data[self.VAL_BATCH_IDX]
290 | assert val_indices
291 | num_samples = min(len(val_indices), self.log_n_val_predictions)
292 |
293 | if self.deterministic:
294 | random.seed(0)
295 | sampled_indices = random.sample(val_indices, num_samples)
296 | self._val_batch_indices = set(sampled_indices)
297 | return self._val_batch_indices
298 |
299 | def _clear_val_data(self):
300 | self._val_data = {
301 | self.FLOW_GT: list(),
302 | self.FLOW_PRED: list(),
303 | self.FLOW_VALID: list(),
304 | self.EV_REPR_REDUCED: list(),
305 | self.EV_REPR_REDUCED_M1: list(),
306 | self.VAL_BATCH_IDX: list(),
307 | self.BEZIER_PARAMS: list(),
308 | self.IMAGES: list(),
309 | }
310 |
311 | def _subsample_val_data(self, val_indices: Union[list, set]):
312 | for k, v in self._val_data.items():
313 | if k == self.VAL_BATCH_IDX:
314 | continue
315 | assert isinstance(v, list)
316 | subsampled_list = list()
317 | for idx, x in enumerate(v):
318 | if idx not in val_indices:
319 | continue
320 | assert is_cpu(x)
321 | subsampled_list.append(x)
322 | self._val_data[k] = subsampled_list
323 | self._val_data[self.VAL_BATCH_IDX] = list(val_indices)
324 |
325 | def _save_val_data(self, key: str, data: Union[torch.Tensor, int, float, List[torch.Tensor]]):
326 | assert key in self._val_data.keys()
327 | if isinstance(data, torch.Tensor) or isinstance(data, list):
328 | assert is_cpu(data)
329 | self._val_data[key].append(data)
330 |
--------------------------------------------------------------------------------
/callbacks/utils/flow_vis.py:
--------------------------------------------------------------------------------
1 | # From: https://github.com/tomrunia/OpticalFlow_Visualization
2 | #
3 | # MIT License
4 | #
5 | # Copyright (c) 2018 Tom Runia
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to conditions.
13 | #
14 | # Author: Tom Runia
15 | # Date Created: 2018-08-03
16 |
17 | import numpy as np
18 |
19 | def make_colorwheel():
20 | """
21 | Generates a color wheel for optical flow visualization as presented in:
22 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
23 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
24 | Code follows the original C++ source code of Daniel Scharstein.
25 | Code follows the Matlab source code of Deqing Sun.
26 | Returns:
27 | np.ndarray: Color wheel
28 | """
29 |
30 | RY = 15
31 | YG = 6
32 | GC = 4
33 | CB = 11
34 | BM = 13
35 | MR = 6
36 |
37 | ncols = RY + YG + GC + CB + BM + MR
38 | colorwheel = np.zeros((ncols, 3))
39 | col = 0
40 |
41 | # RY
42 | colorwheel[0:RY, 0] = 255
43 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
44 | col = col+RY
45 | # YG
46 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
47 | colorwheel[col:col+YG, 1] = 255
48 | col = col+YG
49 | # GC
50 | colorwheel[col:col+GC, 1] = 255
51 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
52 | col = col+GC
53 | # CB
54 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
55 | colorwheel[col:col+CB, 2] = 255
56 | col = col+CB
57 | # BM
58 | colorwheel[col:col+BM, 2] = 255
59 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
60 | col = col+BM
61 | # MR
62 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
63 | colorwheel[col:col+MR, 0] = 255
64 | return colorwheel
65 |
66 |
67 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
68 | """
69 | Applies the flow color wheel to (possibly clipped) flow components u and v.
70 | According to the C++ source code of Daniel Scharstein
71 | According to the Matlab source code of Deqing Sun
72 | Args:
73 | u (np.ndarray): Input horizontal flow of shape [H,W]
74 | v (np.ndarray): Input vertical flow of shape [H,W]
75 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
76 | Returns:
77 | np.ndarray: Flow visualization image of shape [H,W,3]
78 | """
79 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
80 | colorwheel = make_colorwheel() # shape [55x3]
81 | ncols = colorwheel.shape[0]
82 | rad = np.sqrt(np.square(u) + np.square(v))
83 | a = np.arctan2(-v, -u)/np.pi
84 | fk = (a+1) / 2*(ncols-1)
85 | k0 = np.floor(fk).astype(np.int32)
86 | k1 = k0 + 1
87 | k1[k1 == ncols] = 0
88 | f = fk - k0
89 | for i in range(colorwheel.shape[1]):
90 | tmp = colorwheel[:,i]
91 | col0 = tmp[k0] / 255.0
92 | col1 = tmp[k1] / 255.0
93 | col = (1-f)*col0 + f*col1
94 | idx = (rad <= 1)
95 | col[idx] = 1 - rad[idx] * (1-col[idx])
96 | col[~idx] = col[~idx] * 0.75 # out of range
97 | # Note the 2-i => BGR instead of RGB
98 | ch_idx = 2-i if convert_to_bgr else i
99 | flow_image[:,:,ch_idx] = np.floor(255 * col)
100 | return flow_image
101 |
102 |
103 | def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False):
104 | """
105 | Expects a two-dimensional flow image of shape [H,W,2].
106 | Args:
107 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
108 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
109 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
110 | Returns:
111 | np.ndarray: Flow visualization image of shape [H,W,3]
112 | """
113 | assert flow_uv.ndim == 3, 'input flow must have three dimensions'
114 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
115 | if clip_flow is not None:
116 | flow_uv = np.clip(flow_uv, 0, clip_flow)
117 | u = flow_uv[:,:,0]
118 | v = flow_uv[:,:,1]
119 | rad = np.sqrt(np.square(u) + np.square(v))
120 | rad_max = np.max(rad)
121 | epsilon = 1e-5
122 | u = u / (rad_max + epsilon)
123 | v = v / (rad_max + epsilon)
124 | return flow_uv_to_colors(u, v, convert_to_bgr)
--------------------------------------------------------------------------------
/config/dataset/base.yaml:
--------------------------------------------------------------------------------
1 | load_voxel_grid: True
2 | normalize_voxel_grid: True
3 | path: ???
4 | photo_augm: False
5 | return_ev: True
6 | return_img: True
--------------------------------------------------------------------------------
/config/dataset/dsec.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - base
3 |
4 | name: dsec
5 | extended_voxel_grid: True # Whether to include events outside the borders. May slightly increase performance.
--------------------------------------------------------------------------------
/config/dataset/multiflow_regen.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - base
3 |
4 | name: multiflow_regen
5 | extended_voxel_grid: True # Whether to include events outside the borders. May slightly increase performance.
6 | downsample: False
7 | flow_every_n_ms: 50
8 |
--------------------------------------------------------------------------------
/config/experiment/dsec/raft_spline/E_I_LU4_BD2_lowpyramid.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - override /model: raft-spline
4 |
5 | training:
6 | limit_val_batches: 0 # no validation
7 | multi_loss: false
8 | max_steps: 250000
9 | lr_scheduler:
10 | total_steps: ${..max_steps}
11 | model:
12 | bezier_degree: 2
13 | use_boundary_images: true
14 | use_events: true
15 | correlation:
16 | ev:
17 | target_indices: [1, 2, 3, 4]
18 | levels: [1, 1, 1, 4]
19 | radius: [4, 4, 4, 4]
20 | img:
21 | levels: 4
22 | radius: 4
23 |
--------------------------------------------------------------------------------
/config/experiment/dsec/raft_spline/E_LU4_BD2_lowpyramid.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - override /model: raft-spline
4 |
5 | training:
6 | limit_val_batches: 0 # no validation
7 | multi_loss: false
8 | max_steps: 250000
9 | lr_scheduler:
10 | total_steps: ${..max_steps}
11 | model:
12 | bezier_degree: 2
13 | use_boundary_images: false
14 | use_events: true
15 | correlation:
16 | ev:
17 | target_indices: [1, 2, 3, 4]
18 | levels: [1, 1, 1, 4]
19 | radius: [4, 4, 4, 4]
20 | img:
21 | levels: null
22 | radius: null
23 |
--------------------------------------------------------------------------------
/config/experiment/multiflow/raft_spline/E_I_LU5_BD10_lowpyramid.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - override /model: raft-spline
4 |
5 | training:
6 | multi_loss: true
7 | batch_size: 3
8 | learning_rate: 0.0001
9 | weight_decay: 0.0001
10 | lr_scheduler:
11 | use: true
12 | model:
13 | num_bins:
14 | context: 41
15 | correlation: 25
16 | bezier_degree: 10
17 | use_boundary_images: true
18 | use_events: true
19 | correlation:
20 | ev:
21 | #target_indices: [1, 2, 3, 4, 5] # for 6 context bins
22 | #target_indices: [2, 4, 6, 8, 10] # for 11 context bins
23 | target_indices: [8, 16, 24, 32, 40] # for 41 context bins
24 | levels: [1, 1, 1, 1, 4]
25 | radius: [4, 4, 4, 4, 4]
26 | img:
27 | levels: 4
28 | radius: 4
29 |
--------------------------------------------------------------------------------
/config/experiment/multiflow/raft_spline/E_LU5_BD10_lowpyramid.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - override /model: raft-spline
4 |
5 | training:
6 | multi_loss: true
7 | batch_size: 3
8 | learning_rate: 0.0001
9 | weight_decay: 0.0001
10 | lr_scheduler:
11 | use: true
12 | model:
13 | num_bins:
14 | context: 41
15 | correlation: 25
16 | bezier_degree: 10
17 | use_boundary_images: false
18 | use_events: true
19 | correlation:
20 | ev:
21 | #target_indices: [1, 2, 3, 4, 5] # for 6 context bins
22 | #target_indices: [2, 4, 6, 8, 10] # for 11 context bins
23 | target_indices: [8, 16, 24, 32, 40] # for 41 context bins
24 | levels: [1, 1, 1, 1, 4]
25 | radius: [4, 4, 4, 4, 4]
26 |
--------------------------------------------------------------------------------
/config/general.yaml:
--------------------------------------------------------------------------------
1 | training:
2 | multi_loss: true # ATTENTION: This can only be used for the MultiFlow dataset.
3 | batch_size: 3
4 | max_epochs: 1000
5 | max_steps: 200000
6 | learning_rate: 0.0001
7 | weight_decay: 0.0001
8 | gradient_clip_val: 1
9 | limit_train_batches: 1
10 | limit_val_batches: 1
11 | lr_scheduler:
12 | use: true
13 | total_steps: ${..max_steps}
14 | pct_start: 0.01
15 | hardware:
16 | num_workers: null # Number of workers. Defaults to twice the maximum batch size.
17 | gpus: 0 # Either a single integer (e.g. 3) or a list of integers (e.g. [3, 5, 6])
18 | logging:
19 | only_numbers: False
20 | ckpt_every_n_epochs: 1
21 | log_every_n_steps: 5000
22 | flush_logs_every_n_steps: 1000
23 | log_n_val_predictions: 2
24 | wandb:
25 | wandb_runpath: null # Specify WandB run path if you wish to resume from that run. E.g. magehrig/eRAFT/1lmicg6t
26 | artifact_runpath: null # Specify WandB run path if you wish to resume with a checkpoint/artifact from that run. Will take current wandb runpath if not specified
27 | artifact_name: null # Name of the checkpoint/artifact from which to resume. E.g. checkpoint-1ae609sb:v5
28 | resume_only_weights: False # If artifact is provided, you can choose to resume only the weights. Otherwise, the full training state is restored
29 | project_name: contflow # Specify the project name of the run
30 | group_name: ??? # Specify the group name of the run
31 | debugging:
32 | test_cpu_dataloading: False
33 | profiler: null # {None, simple, advanced, pytorch}
34 |
--------------------------------------------------------------------------------
/config/model/base.yaml:
--------------------------------------------------------------------------------
1 | num_bins:
2 | context: 5
3 |
--------------------------------------------------------------------------------
/config/model/raft-spline.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - raft_base
3 |
4 | name: raft-spline
5 | detach_bezier: false
6 | bezier_degree: 2
7 | use_boundary_images: true
8 | use_events: true
9 | use_gma: false
10 | correlation:
11 | ev:
12 | target_indices: [1, 2, 3, 4] # 0 idx is the reference. num_bins_context - 1 is the maximum idx.
13 | levels: [1, 2, 3, 4] # Number of pyramid levels. Must have the same length as target_indices.
14 | radius: [4, 4, 4, 4] # Look-up radius. Must have the same length as target_indices.
15 | img:
16 | levels: 4
17 | radius: 4
--------------------------------------------------------------------------------
/config/model/raft_base.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - base
3 |
4 | num_bins:
5 | correlation: null
6 | num_iter:
7 | train: 12
8 | test: 12
9 | correlation:
10 | use_cosine_sim: false
11 | hidden:
12 | dim: 128
13 | context:
14 | dim: 128
15 | norm: batch
16 | feature:
17 | dim: 256
18 | norm: instance
19 | motion:
20 | dim: 128
21 |
--------------------------------------------------------------------------------
/config/train.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - general
3 | - dataset: ???
4 | - model: ???
--------------------------------------------------------------------------------
/config/val.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset: ???
3 | - model: ???
4 | - _self_
5 |
6 | checkpoint: ???
7 | hardware:
8 | num_workers: 4
9 | gpus: 0 # GPU idx (multi-gpu not supported for validation)
10 | batch_size: 8
--------------------------------------------------------------------------------
/data/dsec/eventslicer.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Dict, Optional, Tuple
3 |
4 | import h5py
5 | from numba import jit
6 | import numpy as np
7 |
8 |
9 | class EventSlicer:
10 | def __init__(self, h5f: h5py.File):
11 | self.h5f = h5f
12 |
13 | self.events = dict()
14 | for dset_str in ['p', 'x', 'y', 't']:
15 | self.events[dset_str] = self.h5f['events/{}'.format(dset_str)]
16 |
17 | # This is the mapping from milliseconds to event index:
18 | # It is defined such that
19 | # (1) t[ms_to_idx[ms]] >= ms*1000
20 | # (2) t[ms_to_idx[ms] - 1] < ms*1000
21 | # ,where 'ms' is the time in milliseconds and 't' the event timestamps in microseconds.
22 | #
23 | # As an example, given 't' and 'ms':
24 | # t: 0 500 2100 5000 5000 7100 7200 7200 8100 9000
25 | # ms: 0 1 2 3 4 5 6 7 8 9
26 | #
27 | # we get
28 | #
29 | # ms_to_idx:
30 | # 0 2 2 3 3 3 5 5 8 9
31 | self.ms_to_idx = np.asarray(self.h5f['ms_to_idx'], dtype='int64')
32 |
33 | self.t_offset = int(h5f['t_offset'][()])
34 | self.t_final = int(self.events['t'][-1]) + self.t_offset
35 |
36 | def get_start_time_us(self):
37 | return self.t_offset
38 |
39 | def get_final_time_us(self):
40 | return self.t_final
41 |
42 | def get_events(self, t_start_us: int, t_end_us: int) -> Dict[str, np.ndarray]:
43 | """Get events (p, x, y, t) within the specified time window
44 | Parameters
45 | ----------
46 | t_start_us: start time in microseconds
47 | t_end_us: end time in microseconds
48 | Returns
49 | -------
50 | events: dictionary of (p, x, y, t) or None if the time window cannot be retrieved
51 | """
52 | assert t_start_us < t_end_us
53 |
54 | # We assume that the given times are in the global time frame (i.e. they include t_offset), hence subtract the offset:
55 | t_start_us -= self.t_offset
56 | t_end_us -= self.t_offset
57 |
58 | t_start_ms, t_end_ms = self.get_conservative_window_ms(t_start_us, t_end_us)
59 | t_start_ms_idx = self.ms2idx(t_start_ms)
60 | t_end_ms_idx = self.ms2idx(t_end_ms)
61 |
62 | if t_start_ms_idx is None or t_end_ms_idx is None:
63 | # Cannot guarantee window size anymore
64 | return None
65 |
66 | events = dict()
67 | time_array_conservative = np.asarray(self.events['t'][t_start_ms_idx:t_end_ms_idx])
68 | idx_start_offset, idx_end_offset = self.get_time_indices_offsets(time_array_conservative, t_start_us, t_end_us)
69 | t_start_us_idx = t_start_ms_idx + idx_start_offset
70 | t_end_us_idx = t_start_ms_idx + idx_end_offset
71 | # Again add t_offset to get gps time
72 | events['t'] = time_array_conservative[idx_start_offset:idx_end_offset] + self.t_offset
73 | for dset_str in ['p', 'x', 'y']:
74 | events[dset_str] = np.asarray(self.events[dset_str][t_start_us_idx:t_end_us_idx])
75 | assert events[dset_str].size == events['t'].size
76 | return events
77 |
78 |
79 | @staticmethod
80 | def get_conservative_window_ms(ts_start_us: int, ts_end_us) -> Tuple[int, int]:
81 | """Compute a conservative time window of time with millisecond resolution.
82 | We have a time to index mapping for each millisecond. Hence, we need
83 | to compute the lower and upper millisecond to retrieve events.
84 | Parameters
85 | ----------
86 | ts_start_us: start time in microseconds
87 | ts_end_us: end time in microseconds
88 | Returns
89 | -------
90 | window_start_ms: conservative start time in milliseconds
91 | window_end_ms: conservative end time in milliseconds
92 | """
93 | assert ts_end_us > ts_start_us
94 | window_start_ms = math.floor(ts_start_us/1000)
95 | window_end_ms = math.ceil(ts_end_us/1000)
96 | return window_start_ms, window_end_ms
97 |
98 | @staticmethod
99 | @jit(nopython=True)
100 | def get_time_indices_offsets(
101 | time_array: np.ndarray,
102 | time_start_us: int,
103 | time_end_us: int) -> Tuple[int, int]:
104 | """Compute index offset of start and end timestamps in microseconds
105 | Parameters
106 | ----------
107 | time_array: timestamps (in us) of the events
108 | time_start_us: start timestamp (in us)
109 | time_end_us: end timestamp (in us)
110 | Returns
111 | -------
112 | idx_start: Index within this array corresponding to time_start_us
113 | idx_end: Index within this array corresponding to time_end_us
114 | such that (in non-edge cases)
115 | time_array[idx_start] >= time_start_us
116 | time_array[idx_end] >= time_end_us
117 | time_array[idx_start - 1] < time_start_us
118 | time_array[idx_end - 1] < time_end_us
119 | this means that
120 | time_start_us <= time_array[idx_start:idx_end] < time_end_us
121 | """
122 |
123 | assert time_array.ndim == 1
124 |
125 | idx_start = -1
126 | if time_array[-1] < time_start_us:
127 | # This can happen in extreme corner cases. E.g.
128 | # time_array[0] = 1016
129 | # time_array[-1] = 1984
130 | # time_start_us = 1990
131 | # time_end_us = 2000
132 |
133 | # Return same index twice: array[x:x] is empty.
134 | return time_array.size, time_array.size
135 | else:
136 | # TODO(magehrig): use binary search for speedup
137 | for idx_from_start in range(0, time_array.size, 1):
138 | if time_array[idx_from_start] >= time_start_us:
139 | idx_start = idx_from_start
140 | break
141 | assert idx_start >= 0
142 |
143 | idx_end = time_array.size
144 | # TODO(magehrig): use binary search for speedup
145 | for idx_from_end in range(time_array.size - 1, -1, -1):
146 | if time_array[idx_from_end] >= time_end_us:
147 | idx_end = idx_from_end
148 | else:
149 | break
150 |
151 | assert time_array[idx_start] >= time_start_us
152 | if idx_end < time_array.size:
153 | assert time_array[idx_end] >= time_end_us
154 | if idx_start > 0:
155 | assert time_array[idx_start - 1] < time_start_us
156 | if idx_end > 0:
157 | assert time_array[idx_end - 1] < time_end_us
158 | return idx_start, idx_end
159 |
160 | def ms2idx(self, time_ms: int) -> Optional[int]:
161 | assert time_ms >= 0
162 | if time_ms >= self.ms_to_idx.size:
163 | return None
164 | return self.ms_to_idx[time_ms]
165 |
--------------------------------------------------------------------------------
/data/dsec/provider.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from pathlib import Path
3 | from typing import Dict, Any
4 |
5 | import torch
6 |
7 | from data.dsec.sequence import generate_sequence
8 | from data.dsec.subsequence.twostep import TwoStepSubSequence
9 | from data.utils.provider import DatasetProviderBase
10 |
11 |
12 | class DatasetProvider(DatasetProviderBase):
13 | def __init__(self,
14 | dataset_params: Dict[str, Any],
15 | nbins_context: int):
16 | dataset_path = Path(dataset_params['path'])
17 |
18 | train_path = dataset_path / 'train'
19 | test_path = dataset_path / 'test'
20 | assert dataset_path.is_dir(), str(dataset_path)
21 | assert train_path.is_dir(), str(train_path)
22 | assert test_path.is_dir(), str(test_path)
23 |
24 | # NOTE: For now, we assume that the number of bins for the correlation is the same as for the context.
25 | self.nbins = nbins_context
26 |
27 | subseq_class = TwoStepSubSequence
28 |
29 | base_args = {
30 | 'num_bins': self.nbins,
31 | 'load_voxel_grid': dataset_params['load_voxel_grid'],
32 | 'extended_voxel_grid': dataset_params['extended_voxel_grid'],
33 | 'normalize_voxel_grid': dataset_params['normalize_voxel_grid'],
34 | }
35 | base_args.update({'merge_grids': True})
36 | train_args = copy.deepcopy(base_args)
37 | train_args.update({'data_augm': True})
38 | test_args = copy.deepcopy(base_args)
39 | test_args.update({'data_augm': False})
40 |
41 | train_sequences = list()
42 | for child in train_path.iterdir():
43 | sequence = generate_sequence(child, subseq_class, train_args)
44 | if sequence is not None:
45 | train_sequences.append(sequence)
46 |
47 | self.train_dataset = torch.utils.data.ConcatDataset(train_sequences)
48 |
49 | # TODO: write specialized test sequence
50 | #test_sequences = list()
51 | #for child in test_path.iterdir():
52 | # sequence = generate_test_sequence(child, subseq_class, test_args)
53 | # if sequence is not None:
54 | # test_sequences.append(sequence)
55 | #self.test_dataset = torch.utils.data.ConcatDataset(test_sequences)
56 | self.test_dataset = None
57 |
58 | def get_train_dataset(self):
59 | return self.train_dataset
60 |
61 | def get_val_dataset(self):
62 | raise NotImplementedError
63 |
64 | def get_test_dataset(self):
65 | return self.test_dataset
66 |
67 | def get_nbins_context(self):
68 | return self.nbins
69 |
70 | def get_nbins_correlation(self):
71 | return self.nbins
72 |
--------------------------------------------------------------------------------
/data/dsec/sequence.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Type
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from data.dsec.subsequence.base import BaseSubSequence
8 |
9 |
10 | class SubSequenceGenerator:
11 | # seq_name (e.g. zurich_city_10_a)
12 | # ├── flow
13 | # │ ├── backward (for train only)
14 | # │ │ ├── xxxxxx.png
15 | # │ │ └── ...
16 | # │ ├── backward_timestamps.txt (for train only)
17 | # │ ├── forward (for train only)
18 | # │ │ ├── xxxxxx.png
19 | # │ │ └── ...
20 | # │ └── forward_timestamps.txt (for train and test)
21 | # └── events
22 | # ├── left
23 | # │ ├── events.h5
24 | # │ └── rectify_map.h5
25 | # └── right
26 | # ├── events.h5
27 | # └── rectify_map.h5
28 | def __init__(self,
29 | seq_path: Path,
30 | subseq_class: Type[BaseSubSequence],
31 | args: dict):
32 |
33 | self.args = args
34 | self.seq_path = seq_path
35 |
36 | self.subseq_class = subseq_class
37 |
38 | # load forward optical flow timestamps
39 | assert sequence_has_flow(seq_path), seq_path
40 | flow_dir = seq_path / 'flow'
41 | assert flow_dir.is_dir(), str(flow_dir)
42 | forward_timestamps_file = flow_dir / 'forward_timestamps.txt'
43 | assert forward_timestamps_file.is_file()
44 | self.forward_flow_timestamps = np.loadtxt(str(forward_timestamps_file), dtype='int64', delimiter=',')
45 | assert self.forward_flow_timestamps.ndim == 2
46 | assert self.forward_flow_timestamps.shape[1] == 2
47 |
48 | # load forward optical flow paths
49 | forward_flow_dir = flow_dir / 'forward'
50 | assert forward_flow_dir.is_dir()
51 | forward_flow_list = list()
52 | for entry in forward_flow_dir.iterdir():
53 | assert str(entry.name).endswith('.png')
54 | forward_flow_list.append(entry)
55 | forward_flow_list.sort()
56 | self.forward_flow_list = forward_flow_list
57 |
58 | # Extract start indices of sub-sequences
59 | from_ts = self.forward_flow_timestamps[:, 0]
60 | to_ts = self.forward_flow_timestamps[:, 1]
61 |
62 | is_start_subseq = from_ts[1:] != to_ts[:-1]
63 | # Add first index as start index too.
64 | is_start_subseq = np.concatenate((np.array((True,), dtype='bool'), is_start_subseq))
65 | self.start_indices = list(np.where(is_start_subseq)[0])
66 |
67 | self.subseq_idx = 0
68 |
69 | def __enter__(self):
70 | return self
71 |
72 | def __iter__(self):
73 | return self
74 |
75 | def __len__(self):
76 | return len(self.start_indices)
77 |
78 | def __next__(self) -> BaseSubSequence:
79 | if self.subseq_idx >= len(self.start_indices):
80 | raise StopIteration
81 | final_subseq = self.subseq_idx == len(self.start_indices) - 1
82 |
83 | start_idx = self.start_indices[self.subseq_idx]
84 | end_p1_idx = None if final_subseq else self.start_indices[self.subseq_idx + 1]
85 |
86 | forward_flow_timestamps = self.forward_flow_timestamps[start_idx:end_p1_idx, :]
87 | forward_flow_list = self.forward_flow_list[start_idx:end_p1_idx]
88 |
89 | self.subseq_idx += 1
90 |
91 | return self.subseq_class(self.seq_path, forward_flow_timestamps, forward_flow_list, **self.args)
92 |
93 |
94 | def sequence_has_flow(seq_path: Path):
95 | return (seq_path / 'flow').is_dir()
96 |
97 |
98 | def generate_sequence(
99 | seq_path: Path,
100 | subseq_class: Type[BaseSubSequence],
101 | args: dict):
102 | if not sequence_has_flow(seq_path):
103 | return None
104 |
105 | subseq_list = list()
106 |
107 | for subseq in SubSequenceGenerator(seq_path, subseq_class, args):
108 | subseq_list.append(subseq)
109 |
110 | return torch.utils.data.ConcatDataset(subseq_list)
111 |
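A small, self-contained illustration of the start-index logic in `SubSequenceGenerator.__init__`: a new sub-sequence starts wherever consecutive ground-truth flow intervals are not contiguous. The timestamps below are made up for the example:

import numpy as np

# Hypothetical (from_ts, to_ts) pairs in microseconds; the third interval does not
# start where the second one ends, so it opens a new sub-sequence.
forward_flow_timestamps = np.array([
    [  0, 100],
    [100, 200],
    [300, 400],   # gap between 200 and 300 -> new sub-sequence
    [400, 500],
], dtype='int64')

from_ts = forward_flow_timestamps[:, 0]
to_ts = forward_flow_timestamps[:, 1]
is_start_subseq = np.concatenate((np.array((True,), dtype='bool'), from_ts[1:] != to_ts[:-1]))
start_indices = list(np.where(is_start_subseq)[0])
assert start_indices == [0, 2]  # two sub-sequences: rows [0, 1] and rows [2, 3]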
--------------------------------------------------------------------------------
/data/dsec/subsequence/base.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | from typing import Callable, List, Optional
4 | import weakref
5 |
6 | import h5py
7 | import imageio as iio
8 | import numpy as np
9 | from PIL import Image
10 | import torch
11 | from torch.utils.data import Dataset
12 |
13 | from data.dsec.eventslicer import EventSlicer
14 | from data.utils.augmentor import FlowAugmentor
15 | from data.utils.generic import np_array_to_h5, h5_to_np_array
16 | from data.utils.representations import VoxelGrid, norm_voxel_grid
17 |
18 |
19 | class BaseSubSequence(Dataset):
20 | # seq_name (e.g. zurich_city_10_a)
21 | # ├── flow
22 | # │ ├── backward (for train only)
23 | # │ │ ├── xxxxxx.png
24 | # │ │ └── ...
25 | # │ ├── backward_timestamps.txt (for train only)
26 | # │ ├── forward (for train only)
27 | # │ │ ├── xxxxxx.png
28 | # │ │ └── ...
29 | # │ └── forward_timestamps.txt (for train and test)
30 | # └── events
31 | # ├── left
32 | # │ ├── events.h5
33 | # │ └── rectify_map.h5
34 | # └── right
35 | # ├── events.h5
36 | # └── rectify_map.h5
37 | #
38 | # For now this class
39 | # - only returns forward flow and the corresponding event representation (of the left event camera)
40 | # - does not implement recurrent loading of data
41 | # - only returns a forward rectified voxel grid with a fixed number of bins
42 | # - is not implemented to offer multi-dataset training schedules (like in RAFT)
43 | def __init__(self,
44 | seq_path: Path,
45 | forward_flow_timestamps: np.ndarray,
46 | forward_flow_paths: List[Path],
47 | data_augm: bool,
48 | num_bins: int=15,
49 | load_voxel_grid: bool=True,
50 | extended_voxel_grid: bool=True,
51 | normalize_voxel_grid: bool=False):
52 | assert num_bins >= 1
53 | assert seq_path.is_dir()
54 |
55 | # Save output dimensions
56 | self.height = 480
57 | self.width = 640
58 | self.num_bins = num_bins
59 |
60 | self.augmentor = FlowAugmentor(crop_size_hw=(288, 384)) if data_augm else None
61 |
62 | # Set event representation
63 | self.voxel_grid = VoxelGrid(self.num_bins, self.height, self.width)
64 | self.normalize_voxel_grid: Optional[Callable] = norm_voxel_grid if normalize_voxel_grid else None
65 |
66 | assert len(forward_flow_paths) == forward_flow_timestamps.shape[0]
67 |
68 | self.forward_flow_timestamps = forward_flow_timestamps
69 |
70 | for entry in forward_flow_paths:
71 | assert entry.exists()
72 | assert str(entry.name).endswith('.png')
73 | self.forward_flow_list = forward_flow_paths
74 |
75 | ### prepare loading of event data and load rectification map
76 | self.ev_dir = seq_path / 'events' / 'left'
77 | assert self.ev_dir.is_dir()
78 | self.ev_file = self.ev_dir / 'events.h5'
79 | assert self.ev_file.exists()
80 |
81 | rectify_events_map_file = self.ev_dir / 'rectify_map.h5'
82 | assert rectify_events_map_file.exists()
83 | with h5py.File(str(rectify_events_map_file), 'r') as h5_rect:
84 | self.rectify_events_map = h5_rect['rectify_map'][()]
85 |
86 | self.h5f: Optional[h5py.File] = None
87 | self.event_slicer: Optional[EventSlicer] = None
88 | self.h5f_opened = False
89 |
90 | ### prepare loading of image data (in left event camera frame)
91 | img_dir_ev_left = seq_path / 'images' / 'left' / 'ev_inf'
92 | self.img_dir_ev_left = None if not img_dir_ev_left.is_dir() else img_dir_ev_left
93 |
94 | ### Voxel Grid Saving
95 | # Version 0: strictly causal, but does not handle the temporal boundary properly
96 | # Version 1: handles the boundary effect, but loads a few events "from the future"
97 | self.version = 1 if extended_voxel_grid else 0
98 | self.voxel_grid_dir = self.ev_dir / f'voxel_grids_v{self.version}_100ms_forward_{num_bins}_bins'
99 | self.load_voxel_grid = load_voxel_grid
100 | if self.load_voxel_grid:
101 | if not self.voxel_grid_dir.exists():
102 | os.mkdir(self.voxel_grid_dir)
103 | else:
104 | assert self.voxel_grid_dir.is_dir()
105 |
106 | def __open_h5f(self):
107 | assert self.h5f is None
108 | assert self.event_slicer is None
109 |
110 | self.h5f = h5py.File(str(self.ev_file), 'r')
111 | self.event_slicer = EventSlicer(self.h5f)
112 |
113 | self._finalizer = weakref.finalize(self, self.__close_callback, self.h5f)
114 | self.h5f_opened = True
115 |
116 | def _events_to_voxel_grid(self, x, y, p, t, t0_center: Optional[int]=None, t1_center: Optional[int]=None):
117 | t = t.astype('int64')
118 | x = x.astype('float32')
119 | y = y.astype('float32')
120 | pol = p.astype('float32')
121 | return self.voxel_grid.convert(
122 | torch.from_numpy(x),
123 | torch.from_numpy(y),
124 | torch.from_numpy(pol),
125 | torch.from_numpy(t),
126 | t0_center,
127 | t1_center)
128 |
129 | def getHeightAndWidth(self):
130 | return self.height, self.width
131 |
132 | @staticmethod
133 | def __close_callback(h5f: h5py.File):
134 | assert h5f is not None
135 | h5f.close()
136 |
137 | def _rectify_events(self, x: np.ndarray, y: np.ndarray):
138 | # From distorted to undistorted
139 | rectify_map = self.rectify_events_map
140 | assert rectify_map.shape == (self.height, self.width, 2), rectify_map.shape
141 | assert x.max() < self.width
142 | assert y.max() < self.height
143 | return rectify_map[y, x]
144 |
145 | def _get_ev_left_img(self, img_file_idx: int):
146 | if self.img_dir_ev_left is None:
147 | return None
148 | img_filename = f'{img_file_idx}'.zfill(6) + '.png'
149 | img_filepath = self.img_dir_ev_left / img_filename
150 | if not img_filepath.exists():
151 | return None
152 | img = np.asarray(iio.imread(str(img_filepath), format='PNG-FI'))
153 | # img: (h, w, c) -> (c, h, w)
154 | img = np.moveaxis(img, -1, 0)
155 | return img
156 |
157 | def _get_events(self, ts_from: int, ts_to: int, rectify: bool):
158 | if not self.h5f_opened:
159 | self.__open_h5f()
160 |
161 | start_time_us = self.event_slicer.get_start_time_us()
162 | final_time_us = self.event_slicer.get_final_time_us()
163 | assert ts_from > start_time_us - 50000, 'Do not request more than 50 ms before the minimum time. Otherwise, something might be wrong.'
164 | assert ts_to < final_time_us + 50000, 'Do not request more than 50 ms past the maximum time. Otherwise, something might be wrong.'
165 | if ts_from < start_time_us:
166 | # Take the minimum time instead to avoid assertions in the event slicer.
167 | ts_from = start_time_us
168 | if ts_to > final_time_us:
169 | # Take the maximum time instead to avoid assertions in the event slicer.
170 | ts_to = final_time_us
171 | assert ts_from < ts_to
172 |
173 | event_data = self.event_slicer.get_events(ts_from, ts_to)
174 |
175 | pol = event_data['p']
176 | time = event_data['t']
177 | x = event_data['x']
178 | y = event_data['y']
179 |
180 | if rectify:
181 | xy_rect = self._rectify_events(x, y)
182 | x = xy_rect[:, 0]
183 | y = xy_rect[:, 1]
184 | assert pol.shape == time.shape == x.shape == y.shape
185 |
186 | out = {
187 | 'pol': pol,
188 | 'time': time,
189 | 'x': x,
190 | 'y': y,
191 | }
192 | return out
193 |
194 | def _construct_voxel_grid(self, ts_from: int, ts_to: int, rectify: bool=True):
195 | if self.version == 1:
196 | t_start, t_end = self.voxel_grid.get_extended_time_window(ts_from, ts_to)
197 | assert (ts_from - t_start) < 50000, f'ts_from: {ts_from}, t_start: {t_start}'
198 | assert (t_end - ts_to) < 50000, f't_end: {t_end}, ts_to: {ts_to}'
199 | event_data = self._get_events(t_start, t_end, rectify=rectify)
200 | voxel_grid = self._events_to_voxel_grid(event_data['x'], event_data['y'], event_data['pol'], event_data['time'], ts_from, ts_to)
201 | return voxel_grid
202 | elif self.version == 0:
203 | event_data = self._get_events(ts_from, ts_to, rectify=rectify)
204 | voxel_grid = self._events_to_voxel_grid(event_data['x'], event_data['y'], event_data['pol'], event_data['time'])
205 | return voxel_grid
206 | raise NotImplementedError
207 |
208 | def _load_voxel_grid(self, ts_from: int, ts_to: int, file_index: int):
209 | assert file_index >= 0
210 |
211 | # Assuming we want to load the 'forward' voxel grid.
212 | voxel_grid_file = self.voxel_grid_dir / (f'{file_index}'.zfill(6) + '.h5')
213 | if not voxel_grid_file.exists():
214 | voxel_grid = self._construct_voxel_grid(ts_from, ts_to)
215 | np_array_to_h5(voxel_grid.numpy(), voxel_grid_file)
216 | return voxel_grid
217 | return torch.from_numpy(h5_to_np_array(voxel_grid_file))
218 |
219 | def _get_voxel_grid(self, ts_from: int, ts_to: int, file_index: int):
220 | if self.load_voxel_grid:
221 | return self._load_voxel_grid(ts_from, ts_to, file_index)
222 | return self._construct_voxel_grid(ts_from, ts_to)
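A short numeric sketch of what the two voxel-grid versions amount to for the default 15-bin grid. The window below is a hypothetical 100 ms interval, and only the time-window arithmetic is shown:

import math

from data.utils.representations import VoxelGrid

grid = VoxelGrid(channels=15, height=480, width=640)
ts_from, ts_to = 1_000_000, 1_100_000  # hypothetical 100 ms window in microseconds

# Version 1 (extended_voxel_grid=True): pad the window by one bin spacing on each side so that
# the boundary bins also receive contributions from events just outside [ts_from, ts_to].
t_start, t_end = grid.get_extended_time_window(ts_from, ts_to)
dt = (ts_to - ts_from) / (15 - 1)  # bin spacing used by the grid
assert t_start == math.floor(ts_from - dt) and t_end == math.ceil(ts_to + dt)

# Version 0 (extended_voxel_grid=False): strictly causal, events are sliced from [ts_from, ts_to] only.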
--------------------------------------------------------------------------------
/data/dsec/subsequence/twostep.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import List
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from data.dsec.subsequence.base import BaseSubSequence
8 | from data.utils.generic import load_flow
9 | from data.utils.keys import DataLoading, DataSetType
10 |
11 |
12 | class TwoStepSubSequence(BaseSubSequence):
13 | def __init__(self,
14 | seq_path: Path,
15 | forward_flow_timestamps: np.ndarray,
16 | forward_flow_paths: List[Path],
17 | data_augm: bool,
18 | num_bins: int,
19 | load_voxel_grid: bool,
20 | extended_voxel_grid: bool,
21 | normalize_voxel_grid: bool,
22 | merge_grids: bool):
23 | super().__init__(
24 | seq_path,
25 | forward_flow_timestamps,
26 | forward_flow_paths,
27 | data_augm,
28 | num_bins,
29 | load_voxel_grid,
30 | extended_voxel_grid=extended_voxel_grid,
31 | normalize_voxel_grid=normalize_voxel_grid)
32 | self.merge_grids = merge_grids
33 |
34 | def __len__(self):
35 | return len(self.forward_flow_list)
36 |
37 | def __getitem__(self, index):
38 | forward_flow_gt_path = self.forward_flow_list[index]
39 | flow_file_index = int(forward_flow_gt_path.stem)
40 | forward_flow, forward_flow_valid2D = load_flow(forward_flow_gt_path)
41 | # forward_flow: (h, w, 2) -> (2, h, w)
42 | forward_flow = np.moveaxis(forward_flow, -1, 0)
43 |
44 | # Events during the flow duration.
45 | ev_repr_list = list()
46 | ts_from = None
47 | ts_to = None
48 | for idx in [index, index - 1]:
49 | if self._is_index_valid(idx):
50 | ts = self.forward_flow_timestamps[idx]
51 | ts_from = ts[0]
52 | ts_to = ts[1]
53 | else:
54 | assert idx == index - 1
55 | assert ts_from is not None
56 | assert ts_to is not None
57 | dt = ts_to - ts_from
58 | ts_to = ts_from
59 | ts_from = ts_from - dt
60 | # Hardcoded assumption about 100ms steps and filenames.
61 | file_index = flow_file_index if idx == index else flow_file_index - 2
62 | ev_repr = self._get_voxel_grid(ts_from, ts_to, file_index)
63 | ev_repr_list.append(ev_repr)
64 |
65 | imgs_list = None
66 | img_idx_reference = flow_file_index
67 | img_reference = self._get_ev_left_img(img_idx_reference)
68 | if img_reference is not None:
69 | # Assume 100ms steps (take every second frame).
70 | # Assume forward flow (target frame in the future)
71 | img_idx_target = img_idx_reference + 2
72 | img_target = self._get_ev_left_img(img_idx_target)
73 | assert img_target is not None
74 | imgs_list = [img_reference, img_target]
75 |
76 | # 0: previous, 1: current
77 | ev_repr_list.reverse()
78 | if self.merge_grids:
79 | ev_repr_0 = ev_repr_list[0]
80 | ev_repr_1 = ev_repr_list[1]
81 | assert (ev_repr_0[-1] - ev_repr_1[0]).flatten().abs().max() < 0.5, f'{(ev_repr_0[-1] - ev_repr_1[0]).flatten().abs().max()}'
82 | # Remove the redundant temporal slice.
83 | event_representations = torch.cat((ev_repr_0, ev_repr_1[1:, ...]), dim=0)
84 | if self.normalize_voxel_grid is not None:
85 | event_representations = self.normalize_voxel_grid(event_representations)
86 | else:
87 | if self.normalize_voxel_grid is not None:
88 | ev_repr_list = [self.normalize_voxel_grid(voxel_grid) for voxel_grid in ev_repr_list]
89 | event_representations = torch.stack(ev_repr_list)
90 |
91 | if self.augmentor is not None:
92 | event_representations, forward_flow, forward_flow_valid2D, imgs_list = self.augmentor(event_representations, forward_flow, forward_flow_valid2D, imgs_list)
93 |
94 | output = {
95 | DataLoading.FLOW: forward_flow,
96 | DataLoading.FLOW_VALID: forward_flow_valid2D,
97 | DataLoading.FILE_INDEX: flow_file_index,
98 | # For now returns: 2 x bins x height x width or (2*bins-1) x height x width
99 | DataLoading.EV_REPR: event_representations,
100 | DataLoading.DATASET_TYPE: DataSetType.DSEC,
101 | }
102 | if imgs_list is not None:
103 | output.update({DataLoading.IMG: imgs_list})
104 |
105 | return output
106 |
107 | def _is_index_valid(self, index):
108 | return 0 <= index < len(self)
109 |
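A quick numeric check of the `merge_grids` branch in `__getitem__` above: the previous and current voxel grids share one temporal slice at their boundary, so concatenating them while dropping the redundant slice yields `2*num_bins - 1` channels. Random tensors stand in for real grids:

import torch

num_bins, height, width = 15, 480, 640
ev_repr_0 = torch.randn(num_bins, height, width)   # previous voxel grid
ev_repr_1 = torch.randn(num_bins, height, width)   # current voxel grid
ev_repr_1[0] = ev_repr_0[-1]                       # the boundary slice is shared, as the assert above expects

merged = torch.cat((ev_repr_0, ev_repr_1[1:, ...]), dim=0)
assert merged.shape == (2 * num_bins - 1, height, width)  # 29 channels for 15 bins per grid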
--------------------------------------------------------------------------------
/data/multiflow2d/datasubset.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Callable, Optional, List
3 |
4 | from torch.utils.data import Dataset
5 |
6 | from data.utils.augmentor import FlowAugmentor, PhotoAugmentor
7 | from data.utils.keys import DataLoading, DataSetType
8 | from data.utils.representations import norm_voxel_grid
9 |
10 | from data.multiflow2d.sample import Sample
11 |
12 |
13 | class Datasubset(Dataset):
14 | def __init__(self,
15 | train_or_val_path: Path,
16 | data_augm: bool,
17 | num_bins_context: int,
18 | flow_every_n_ms: int,
19 | load_voxel_grid: bool=True,
20 | extended_voxel_grid: bool=True,
21 | normalize_voxel_grid: bool=False,
22 | downsample: bool=False,
23 | photo_augm: bool=False,
24 | return_img: bool=True,
25 | return_ev: bool=True):
26 | assert train_or_val_path.is_dir()
27 | assert train_or_val_path.name in ('train', 'val')
28 |
29 | # Save output dimensions
30 | original_height = 384
31 | original_width = 512
32 |
33 | crop_height = 368
34 | crop_width = 496
35 |
36 | self.return_img = return_img
37 | if not self.return_img:
38 | raise NotImplementedError
39 | self.return_ev = return_ev
40 |
41 | if downsample:
42 | crop_height = crop_height // 2
43 | crop_width = crop_width // 2
44 |
45 | self.delta_ts_flow_ms = flow_every_n_ms
46 |
47 | self.spatial_augmentor = FlowAugmentor(
48 | crop_size_hw=(crop_height, crop_width),
49 | h_flip_prob=0.5,
50 | v_flip_prob=0.5) if data_augm else None
51 | self.photo_augmentor = PhotoAugmentor(
52 | brightness=0.4,
53 | contrast=0.4,
54 | saturation=0.4,
55 | hue=0.5/3.14,
56 | probability_color=0.2,
57 | noise_variance_range=(0.001, 0.01),
58 | probability_noise=0.2) if data_augm and photo_augm else None
59 | self.normalize_voxel_grid: Optional[Callable] = norm_voxel_grid if normalize_voxel_grid else None
60 |
61 | sample_list: List[Sample] = list()
62 | for sample_path in train_or_val_path.iterdir():
63 | if not sample_path.is_dir():
64 | continue
65 | sample_list.append(
66 | Sample(sample_path, original_height, original_width, num_bins_context, load_voxel_grid, extended_voxel_grid, downsample)
67 | )
68 | self.sample_list = sample_list
69 |
70 | def get_num_bins_context(self):
71 | return self.sample_list[0].num_bins_context
72 |
73 | def get_num_bins_correlation(self):
74 | return self.sample_list[0].num_bins_correlation
75 |
76 | def get_num_bins_total(self):
77 | return self.sample_list[0].num_bins_total
78 |
79 | def _voxel_grid_bin_idx_for_reference(self) -> int:
80 | return self.sample_list[0].voxel_grid_bin_idx_for_reference()
81 |
82 | def __len__(self):
83 | return len(self.sample_list)
84 |
85 | def __getitem__(self, index):
86 | sample = self.sample_list[index]
87 |
88 | voxel_grid = sample.get_voxel_grid() if self.return_ev else None
89 | if voxel_grid is not None and self.normalize_voxel_grid is not None:
90 | voxel_grid = self.normalize_voxel_grid(voxel_grid)
91 |
92 | gt_flow_dict = sample.get_flow_gt(self.delta_ts_flow_ms)
93 | gt_flow = gt_flow_dict['flow']
94 | gt_flow_ts = gt_flow_dict['timestamps']
95 |
96 | imgs_with_ts = sample.get_images()
97 | imgs = imgs_with_ts['images']
98 | img_ts = imgs_with_ts['timestamps']
99 |
100 | # normalize image timestamps from 0 to 1
101 | assert len(img_ts) == 2
102 | ts_start = img_ts[0]
103 | ts_end = img_ts[1]
104 | assert ts_end > ts_start
105 | img_ts = [(x - ts_start)/(ts_end - ts_start) for x in img_ts]
106 | assert img_ts[0] == 0
107 | assert img_ts[1] == 1
108 |
109 | # we assume that img_ts[0] refers to reference time and img_ts[1] to the final target time
110 | gt_flow_ts = [(x - ts_start)/(ts_end - ts_start) for x in gt_flow_ts]
111 | assert gt_flow_ts[-1] == 1
112 | assert len(gt_flow_ts) == len(gt_flow)
113 |
114 | if self.spatial_augmentor is not None:
115 | if voxel_grid is None:
116 | gt_flow, imgs = self.spatial_augmentor(flow=gt_flow, images=imgs)
117 | else:
118 | voxel_grid, gt_flow, imgs = self.spatial_augmentor(ev_repr=voxel_grid, flow=gt_flow, images=imgs)
119 | if self.photo_augmentor is not None:
120 | imgs = self.photo_augmentor(imgs)
121 | out = {
122 | DataLoading.BIN_META: {
123 | 'bin_idx_for_reference': self._voxel_grid_bin_idx_for_reference(),
124 | 'nbins_context': self.get_num_bins_context(),
125 | 'nbins_correlation': self.get_num_bins_correlation(),
126 | 'nbins_total': self.get_num_bins_total(),
127 | },
128 | DataLoading.FLOW: gt_flow,
129 | DataLoading.FLOW_TIMESTAMPS: gt_flow_ts,
130 | DataLoading.IMG: imgs,
131 | DataLoading.IMG_TIMESTAMPS: img_ts,
132 | DataLoading.DATASET_TYPE: DataSetType.MULTIFLOW2D,
133 | }
134 | if voxel_grid is not None:
135 | out.update({DataLoading.EV_REPR: voxel_grid})
136 |
137 | return out
138 |
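A tiny numeric example of the timestamp normalization in `__getitem__` above, using the 400 ms reference and 900 ms target image times of this dataset and a hypothetical 100 ms flow spacing:

img_ts = [400_000, 900_000]  # microseconds: reference image and final target image
gt_flow_ts = [500_000, 600_000, 700_000, 800_000, 900_000]

ts_start, ts_end = img_ts
img_ts = [(x - ts_start) / (ts_end - ts_start) for x in img_ts]
gt_flow_ts = [(x - ts_start) / (ts_end - ts_start) for x in gt_flow_ts]

assert img_ts == [0.0, 1.0]
assert gt_flow_ts == [0.2, 0.4, 0.6, 0.8, 1.0]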
--------------------------------------------------------------------------------
/data/multiflow2d/provider.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from pathlib import Path
3 | from typing import Dict, Any
4 |
5 | import torch.utils.data
6 |
7 | from data.utils.provider import DatasetProviderBase
8 | from data.multiflow2d.datasubset import Datasubset
9 |
10 |
11 | class DatasetProvider(DatasetProviderBase):
12 | def __init__(self,
13 | dataset_params: Dict[str, Any],
14 | nbins_context: int):
15 | dataset_path = Path(dataset_params['path'])
16 | train_path = dataset_path / 'train'
17 | val_path = dataset_path / 'val'
18 | assert dataset_path.is_dir(), str(dataset_path)
19 | assert train_path.is_dir(), str(train_path)
20 | assert val_path.is_dir(), str(val_path)
21 |
22 | return_img = True
23 | return_img_key = 'return_img'
24 | if return_img_key in dataset_params:
25 | return_img = dataset_params[return_img_key]
26 | return_ev = True
27 | return_ev_key = 'return_ev'
28 | if return_ev_key in dataset_params:
29 | return_ev = dataset_params[return_ev_key]
30 | base_args = {
31 | 'num_bins_context': nbins_context,
32 | 'load_voxel_grid': dataset_params['load_voxel_grid'],
33 | 'normalize_voxel_grid': dataset_params['normalize_voxel_grid'],
34 | 'extended_voxel_grid': dataset_params['extended_voxel_grid'],
35 | 'flow_every_n_ms': dataset_params['flow_every_n_ms'],
36 | 'downsample': dataset_params['downsample'],
37 | 'photo_augm': dataset_params['photo_augm'],
38 | return_img_key: return_img,
39 | return_ev_key: return_ev,
40 | }
41 | train_args = copy.deepcopy(base_args)
42 | train_args.update({'data_augm': True})
43 | val_test_args = copy.deepcopy(base_args)
44 | val_test_args.update({'data_augm': False})
45 |
46 | train_dataset = Datasubset(train_path, **train_args)
47 | self.nbins_context = train_dataset.get_num_bins_context()
48 | self.nbins_correlation = train_dataset.get_num_bins_correlation()
49 |
50 | self.train_dataset = train_dataset
51 | self.val_dataset = Datasubset(val_path, **val_test_args)
52 | assert self.val_dataset.get_num_bins_context() == self.nbins_context
53 | assert self.val_dataset.get_num_bins_correlation() == self.nbins_correlation
54 |
55 | def get_train_dataset(self):
56 | return self.train_dataset
57 |
58 | def get_val_dataset(self):
59 | return self.val_dataset
60 |
61 | def get_test_dataset(self):
62 | raise NotImplementedError
63 |
64 | def get_nbins_context(self):
65 | return self.nbins_context
66 |
67 | def get_nbins_correlation(self):
68 | return self.nbins_correlation
69 |
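A hedged sketch of the `dataset_params` dictionary this provider expects; the keys are exactly the ones read above, while the values and the dataset path are placeholders:

from data.multiflow2d.provider import DatasetProvider

dataset_params = {
    'path': '/path/to/multiflow2d',  # must contain 'train' and 'val' sub-directories
    'load_voxel_grid': True,
    'normalize_voxel_grid': False,
    'extended_voxel_grid': True,
    'flow_every_n_ms': 100,          # must be a multiple of 10 (asserted in Sample.get_flow_gt)
    'downsample': False,
    'photo_augm': False,
    'return_img': True,              # optional, defaults to True
    'return_ev': True,               # optional, defaults to True
}

# 11 context bins is one of the values supported by Sample (6, 11, 21, or 41).
provider = DatasetProvider(dataset_params, nbins_context=11)
val_dataset = provider.get_val_dataset()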
--------------------------------------------------------------------------------
/data/multiflow2d/sample.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Dict
3 |
4 | import h5py
5 | import imageio as iio
6 | import numpy as np
7 | import torch
8 | from torch.nn.functional import interpolate
9 |
10 | from data.utils.generic import np_array_to_h5, h5_to_np_array
11 | from data.utils.representations import VoxelGrid
12 |
13 |
14 | class Sample:
15 | # Assumes the following structure:
16 | # seq*
17 | # ├── events
18 | # │ └── events.h5
19 | # ├── flow
20 | # │ ├── 0500000.h5
21 | # │ ├── ...
22 | # │ └── 0900000.h5
23 | # └── images
24 | # ├── 0400000.png
25 | # ├── ...
26 | # └── 0900000.png
27 |
28 | def __init__(self,
29 | sample_path: Path,
30 | height: int,
31 | width: int,
32 | num_bins_context: int,
33 | load_voxel_grid: bool=True,
34 | extended_voxel_grid: bool=True,
35 | downsample: bool=False,
36 | ) -> None:
37 | assert sample_path.is_dir()
38 | assert num_bins_context >= 1
39 | assert sample_path.is_dir()
40 |
41 | nbins_context2corr = {
42 | 6: 4,
43 | 11: 7,
44 | 21: 13,
45 | 41: 25,
46 | }
47 | nbins_context2deltatime = {
48 | 6: 100000,
49 | 11: 50000,
50 | 21: 25000,
51 | 41: 12500,
52 | }
53 |
54 | ### To downsample or to not downsample
55 | self.downsample = downsample
56 |
57 | ### Voxel Grid
58 | assert num_bins_context in nbins_context2corr.keys()
59 | self.num_bins_context = num_bins_context
60 | self.num_bins_correlation = nbins_context2corr[num_bins_context]
61 | # We subtract one because the bin at the reference time is redundant
62 | self.num_bins_total = self.num_bins_context + self.num_bins_correlation - 1
63 |
64 | self.voxel_grid = VoxelGrid(self.num_bins_total, height, width)
65 |
66 | # Image data
67 | ref_time_us = 400*1000
68 | target_time_us = 900*1000
69 | img_ref_path = sample_path / 'images' / (f'{ref_time_us}'.zfill(7) + '.png')
70 | assert img_ref_path.exists()
71 | img_target_path = sample_path / 'images' / (f'{target_time_us}'.zfill(7) + '.png')
72 | assert img_target_path.exists()
73 | self.img_filepaths = [img_ref_path, img_target_path]
74 | self.img_ts = [int(x.stem) for x in self.img_filepaths]
75 |
76 | # Extract timestamps for later retrieving event data
77 | self.bin_0_time = self.img_ts[0] - (self.num_bins_correlation - 1)*nbins_context2deltatime[num_bins_context]
78 | assert self.bin_0_time >= 0
79 | self.bin_target_time = self.img_ts[1]
80 |
81 | # Flow data
82 | self.flow_ref_ts_us = ref_time_us
83 | flow_dir = sample_path / 'flow'
84 | assert flow_dir.is_dir()
85 | flow_filepaths = list()
86 | for flow_file in flow_dir.iterdir():
87 | assert flow_file.suffix == '.h5'
88 | flow_filepaths.append(flow_file)
89 | flow_filepaths.sort()
90 | self.flow_filepaths = flow_filepaths
91 | self.flow_ts_us = [int(x.stem) for x in self.flow_filepaths]
92 |
93 | # Event data
94 | ev_dir = sample_path / 'events'
95 | assert ev_dir.is_dir()
96 | self.event_filepath = ev_dir / 'events.h5'
97 | assert self.event_filepath.exists()
98 |
99 | ### Voxel Grid Saving
100 | self.version = 1 if extended_voxel_grid else 0
101 | downsample_str = '_downsampled' if self.downsample else ''
102 | self.voxel_grid_file = ev_dir / f'voxel_grid_v{self.version}_{self.num_bins_total}_bins{downsample_str}.h5'
103 | self.load_voxel_grid_from_disk = load_voxel_grid
104 |
105 | def downsample_tensor(self, input_tensor: torch.Tensor):
106 | assert input_tensor.ndim == 3
107 | assert self.downsample
108 | ch, ht, wd = input_tensor.shape
109 | input_tensor = input_tensor.float()
110 | return interpolate(input_tensor[None, ...], size=(ht//2, wd//2), align_corners=True, mode='bilinear').squeeze()
111 |
112 | def get_flow_gt(self, flow_every_n_ms: int):
113 | assert flow_every_n_ms > 0
114 | assert flow_every_n_ms % 10 == 0, 'must be a multiple of 10'
115 | delta_ts_us = flow_every_n_ms*1000
116 | out = {
117 | 'flow': list(),
118 | 'timestamps': list(),
119 | }
120 | for flow_ts_us, flow_filepath in zip(self.flow_ts_us, self.flow_filepaths):
121 | if (flow_ts_us - self.flow_ref_ts_us) % delta_ts_us != 0:
122 | continue
123 | out['timestamps'].append(flow_ts_us)
124 | with h5py.File(str(flow_filepath), 'r') as h5f:
125 | flow = np.asarray(h5f['flow'])
126 | # h, w, c -> c, h, w
127 | flow = np.moveaxis(flow, -1, 0)
128 | flow = torch.from_numpy(flow)
129 | if self.downsample:
130 | flow = self.downsample_tensor(flow)
131 | flow = flow/2
132 | out['flow'].append(flow)
133 | return out
134 |
135 | def get_images(self):
136 | out = {
137 | 'images': list(),
138 | 'timestamps': self.img_ts,
139 | }
140 | for img_path in self.img_filepaths:
141 | img = np.asarray(iio.imread(str(img_path), format='PNG-FI'))
142 | # img: (h, w, c) -> (c, h, w)
143 | img = np.moveaxis(img, -1, 0)
144 | img = torch.from_numpy(img)
145 | if self.downsample:
146 | img = self.downsample_tensor(img)
147 | out['images'].append(img)
148 | return out
149 |
150 | def _get_events(self, t_start: int, t_end: int):
151 | assert t_start >= 0
152 | assert t_end <= 1000000
153 | assert t_end > t_start
154 | with h5py.File(str(self.event_filepath), 'r') as h5f:
155 | time = np.asarray(h5f['t'])
156 | first_idx = np.searchsorted(time, t_start, side='left')
157 | last_idx_p1 = np.searchsorted(time, t_end, side='right')
158 | out = {
159 | 'x': np.asarray(h5f['x'][first_idx:last_idx_p1]),
160 | 'y': np.asarray(h5f['y'][first_idx:last_idx_p1]),
161 | 'p': np.asarray(h5f['p'][first_idx:last_idx_p1]),
162 | 't': time[first_idx:last_idx_p1],
163 | }
164 | return out
165 |
166 | def _events_to_voxel_grid(self, event_dict: Dict[str, np.ndarray], t0_center: int=None, t1_center: int=None) -> torch.Tensor:
167 | # int32 is enough for this dataset as the timestamps are in us and start at 0 for every sample sequence
168 | t = event_dict['t'].astype('int32')
169 | x = event_dict['x'].astype('int16')
170 | y = event_dict['y'].astype('int16')
171 | pol = event_dict['p'].astype('int8')
172 | return self.voxel_grid.convert(
173 | torch.from_numpy(x),
174 | torch.from_numpy(y),
175 | torch.from_numpy(pol),
176 | torch.from_numpy(t),
177 | t0_center,
178 | t1_center)
179 |
180 | def _construct_voxel_grid(self, ts_from: int, ts_to: int):
181 | if self.version == 1:
182 | t_start, t_end = self.voxel_grid.get_extended_time_window(ts_from, ts_to)
183 | assert (ts_from - t_start) <= 100000, f'ts_from: {ts_from}, t_start: {t_start}'
184 | assert (t_end - ts_to) <= 100000, f't_end: {t_end}, ts_to: {ts_to}'
185 | event_data = self._get_events(t_start, t_end)
186 | voxel_grid = self._events_to_voxel_grid(event_data, ts_from, ts_to)
187 | elif self.version == 0:
188 | event_data = self._get_events(ts_from, ts_to)
189 | voxel_grid = self._events_to_voxel_grid(event_data)
190 | else:
191 | raise NotImplementedError
192 | if self.downsample:
193 | voxel_grid = self.downsample_tensor(voxel_grid)
194 | return voxel_grid
195 |
196 | def _load_or_save_voxel_grid(self, ts_from: int, ts_to: int) -> torch.Tensor:
197 | if self.voxel_grid_file.exists():
198 | voxel_grid_numpy = h5_to_np_array(self.voxel_grid_file)
199 | if voxel_grid_numpy is not None:
200 | # Squeeze because we may have saved it with batch dim 1 before (by mistake)
201 | return torch.from_numpy(voxel_grid_numpy).squeeze()
202 | # If None is returned, it means that the file was corrupt and we overwrite it.
203 | voxel_grid = self._construct_voxel_grid(ts_from, ts_to)
204 | np_array_to_h5(voxel_grid.numpy(), self.voxel_grid_file)
205 | return voxel_grid
206 |
207 | def get_voxel_grid(self) -> torch.Tensor:
208 | ts_from = self.bin_0_time
209 | ts_to = self.bin_target_time
210 | if self.load_voxel_grid_from_disk:
211 | return self._load_or_save_voxel_grid(ts_from, ts_to).squeeze()
212 | return self._construct_voxel_grid(ts_from, ts_to).squeeze()
213 |
214 | def voxel_grid_bin_idx_for_reference(self) -> int:
215 | return self.num_bins_correlation - 1
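A worked example of the bin bookkeeping above for `num_bins_context=11`, using the values from the lookup tables in `__init__`:

num_bins_context = 11
num_bins_correlation = 7            # nbins_context2corr[11]
delta_time_us = 50_000              # nbins_context2deltatime[11]
ref_time_us = 400_000               # reference image timestamp used in __init__

num_bins_total = num_bins_context + num_bins_correlation - 1  # the reference bin is shared, so count it once
bin_0_time = ref_time_us - (num_bins_correlation - 1) * delta_time_us
bin_idx_for_reference = num_bins_correlation - 1

assert num_bins_total == 17
assert bin_0_time == 100_000       # the voxel grid starts 300 ms before the reference image
assert bin_idx_for_reference == 6  # bin 6 sits at the 400 ms reference image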
--------------------------------------------------------------------------------
/data/utils/augmentor.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Union, Any, Tuple
2 |
3 | import numpy as np
4 | import skimage
5 | from skimage import img_as_ubyte
6 | import torch
7 | from torchvision.transforms import ColorJitter
8 |
9 | from utils.general import inputs_to_tensor, wrap_unwrap_lists_for_class_method
10 |
11 | UAT = Union[np.ndarray, torch.Tensor]
12 | ULISTT = Union[List[torch.Tensor], torch.Tensor]
13 | ULISTAT = Union[List[UAT], UAT]
14 |
15 | def torch_img_to_numpy(torch_img: torch.Tensor):
16 | ch, ht, wd = torch_img.shape
17 | assert ch == 3
18 | numpy_img = torch_img.numpy()
19 | numpy_img = np.moveaxis(numpy_img, 0, -1)
20 | return numpy_img
21 |
22 | def numpy_img_to_torch(numpy_img: np.ndarray):
23 | ht, wd, ch = numpy_img.shape
24 | assert ch == 3
25 | numpy_img = np.moveaxis(numpy_img, -1, 0)
26 | torch_img = torch.from_numpy(numpy_img)
27 | return torch_img
28 |
29 | class PhotoAugmentor:
30 | def __init__(self,
31 | brightness: float,
32 | contrast: float,
33 | saturation: float,
34 | hue: float,
35 | probability_color: float,
36 | noise_variance_range: Tuple[float, float],
37 | probability_noise: float):
38 | assert 0 <= probability_color <= 1
39 | assert 0 <= probability_noise <= 1
40 | assert len(noise_variance_range) == 2
41 | self.photo_augm = ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
42 | self.probability_color = probability_color
43 | self.probability_noise = probability_noise
44 | self.var_min = noise_variance_range[0]
45 | self.var_max = noise_variance_range[1]
46 | assert self.var_max > self.var_min
47 |
48 | self.seed = torch.randint(low=0, high=2**32, size=(1,))[0].item()
49 |
50 | @staticmethod
51 | def sample_uniform(min_value: float=0, max_value: float=1) -> float:
52 | assert max_value > min_value
53 | uni_sample = torch.rand(1)[0].item()
54 | return (max_value - min_value)*uni_sample + min_value
55 |
56 | @wrap_unwrap_lists_for_class_method
57 | def _apply_jitter(self, images: ULISTT):
58 | assert isinstance(images, list)
59 |
60 | for idx, entry in enumerate(images):
61 | images[idx] = self.photo_augm(entry)
62 |
63 | return images
64 |
65 | @wrap_unwrap_lists_for_class_method
66 | def _apply_noise(self, images: ULISTT):
67 | assert isinstance(images, list)
68 | variance = self.sample_uniform(min_value=self.var_min, max_value=self.var_max)
69 |
70 | for idx, entry in enumerate(images):
71 | assert isinstance(entry, torch.Tensor)
72 | numpy_img = torch_img_to_numpy(entry)
73 | noisy_img = skimage.util.random_noise(numpy_img, mode='speckle', var=variance, clip=True, seed=self.seed) # return float64 in [0, 1]
74 | noisy_img = img_as_ubyte(noisy_img)
75 | torch_img = numpy_img_to_torch(noisy_img)
76 | images[idx] = torch_img
77 |
78 | return images
79 |
80 | @inputs_to_tensor
81 | def __call__(self, images: ULISTAT):
82 | if self.probability_color > torch.rand(1).item():
83 | images = self._apply_jitter(images)
84 | if self.probability_noise > torch.rand(1).item():
85 | images = self._apply_noise(images)
86 | return images
87 |
88 | class FlowAugmentor:
89 | def __init__(self,
90 | crop_size_hw,
91 | h_flip_prob: float=0.5,
92 | v_flip_prob: float=0.1):
93 | assert crop_size_hw[0] > 0
94 | assert crop_size_hw[1] > 0
95 | assert 0 <= h_flip_prob <= 1
96 | assert 0 <= v_flip_prob <= 1
97 |
98 | self.h_flip_prob = h_flip_prob
99 | self.v_flip_prob = v_flip_prob
100 | self.crop_size_hw = crop_size_hw
101 |
102 | @wrap_unwrap_lists_for_class_method
103 | def _random_cropping(self, ev_repr: Optional[ULISTT], flow: Optional[ULISTT], valid: Optional[ULISTT]=None, images: Optional[ULISTT]=None):
104 | if ev_repr is not None:
105 | assert isinstance(ev_repr, list)
106 | height, width = ev_repr[0].shape[-2:]
107 | elif images is not None:
108 | assert isinstance(images, list)
109 | height, width = images[0].shape[-2:]
110 | else:
111 | raise NotImplementedError
112 |
113 | y0 = torch.randint(0, height - self.crop_size_hw[0], (1,)).item()
114 | x0 = torch.randint(0, width - self.crop_size_hw[1], (1,)).item()
115 |
116 | if ev_repr is not None:
117 | assert isinstance(ev_repr, list)
118 | nbins = ev_repr[0].shape[-3]
119 | for idx, entry in enumerate(ev_repr):
120 | assert entry.shape[-3:] == (nbins, height, width), f'actual (nbins, h, w) = ({entry.shape[-3]}, {entry.shape[-2]}, {entry.shape[-1]}), expected (nbins, h, w) = ({nbins}, {height}, {width})'
121 | # NOTE: Elements of a range-based for loop do not directly modify the original list in Python!
122 | ev_repr[idx] = entry[..., y0:y0+self.crop_size_hw[0], x0:x0+self.crop_size_hw[1]]
123 |
124 | if flow is not None:
125 | assert isinstance(flow, list)
126 | for idx, entry in enumerate(flow):
127 | assert entry.shape[-3:] == (2, height, width), f'actual (c, h, w) = ({entry.shape[-3]}, {entry.shape[-2]}, {entry.shape[-1]}), expected (2, h, w) = (2, {height}, {width})'
128 | flow[idx] = entry[..., y0:y0+self.crop_size_hw[0], x0:x0+self.crop_size_hw[1]]
129 |
130 | if valid is not None:
131 | assert isinstance(valid, list)
132 | for idx, entry in enumerate(valid):
133 | assert entry.shape[-2:] == (height, width), f'actual (h, w) = ({entry.shape[-2]}, {entry.shape[-1]}), expected (h, w) = ({height}, {width})'
134 | valid[idx] = entry[y0:y0+self.crop_size_hw[0], x0:x0+self.crop_size_hw[1]]
135 |
136 | if images is not None:
137 | assert isinstance(images, list)
138 | for idx, entry in enumerate(images):
139 | assert entry.shape[-2:] == (height, width), f'actual (h, w) = ({entry.shape[-2]}, {entry.shape[-1]}), expected (h, w) = ({height}, {width})'
140 | images[idx] = entry[..., y0:y0+self.crop_size_hw[0], x0:x0+self.crop_size_hw[1]]
141 |
142 | return ev_repr, flow, valid, images
143 |
144 | @wrap_unwrap_lists_for_class_method
145 | def _horizontal_flipping(self, ev_repr: Optional[ULISTT], flow: Optional[ULISTT], valid: Optional[ULISTT]=None, images: Optional[ULISTT]=None):
146 | # flip last axis which is assumed to be width
147 | if ev_repr is not None:
148 | assert isinstance(ev_repr, list)
149 | ev_repr = [x.flip(-1) for x in ev_repr]
150 | if images is not None:
151 | assert isinstance(images, list)
152 | images = [x.flip(-1) for x in images]
153 | if valid is not None:
154 | assert isinstance(valid, list)
155 | valid = [x.flip(-1) for x in valid]
156 | if flow is not None:
157 | assert isinstance(flow, list)
158 | flow = [x.flip(-1) for x in flow]
159 | # also flip the sign of the x component of the flow
160 | for idx, entry in enumerate(flow):
161 | flow[idx][0] = -1 * entry[0]
162 |
163 | return ev_repr, flow, valid, images
164 |
165 | @wrap_unwrap_lists_for_class_method
166 | def _vertical_flipping(self, ev_repr: Optional[ULISTT], flow: Optional[ULISTT], valid: Optional[ULISTT]=None, images: Optional[ULISTT]=None):
167 | # flip second last axis which is assumed to be height
168 | if ev_repr is not None:
169 | assert isinstance(ev_repr, list)
170 | ev_repr = [x.flip(-2) for x in ev_repr]
171 | if images is not None:
172 | assert isinstance(images, list)
173 | images = [x.flip(-2) for x in images]
174 | if valid is not None:
175 | assert isinstance(valid, list)
176 | valid = [x.flip(-2) for x in valid]
177 | if flow is not None:
178 | assert isinstance(flow, list)
179 | flow = [x.flip(-2) for x in flow]
180 | # also flip the sign of the y component of the flow
181 | for idx, entry in enumerate(flow):
182 | flow[idx][1] = -1 * entry[1]
183 |
184 | return ev_repr, flow, valid, images
185 |
186 | @inputs_to_tensor
187 | def __call__(self,
188 | ev_repr: Optional[ULISTAT]=None,
189 | flow: Optional[ULISTAT]=None,
190 | valid: Optional[ULISTAT]=None,
191 | images: Optional[ULISTAT]=None):
192 |
193 | if self.h_flip_prob > torch.rand(1).item():
194 | ev_repr, flow, valid, images = self._horizontal_flipping(ev_repr, flow, valid, images)
195 |
196 | if self.v_flip_prob > torch.rand(1).item():
197 | ev_repr, flow, valid, images = self._vertical_flipping(ev_repr, flow, valid, images)
198 |
199 | ev_repr, flow, valid, images = self._random_cropping(ev_repr, flow, valid, images)
200 |
201 | out = (ev_repr, flow, valid, images)
202 | out = tuple(x for x in out if x is not None)  # drop None entries so callers can unpack exactly the inputs they provided
203 | return out
204 |
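A minimal check of the flipping convention implemented by `_horizontal_flipping`: mirroring along the width axis also negates the x component of the flow, while the y component is only mirrored. Small random tensors are used, so no augmentor state is needed:

import torch

flow = torch.randn(2, 4, 6)    # (2, h, w): channel 0 is x-flow, channel 1 is y-flow

flipped = flow.flip(-1)        # mirror along the width axis
flipped[0] = -1 * flipped[0]   # x-displacements change sign under a horizontal flip

assert torch.equal(flipped[0], -flow[0].flip(-1))
assert torch.equal(flipped[1], flow[1].flip(-1))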
--------------------------------------------------------------------------------
/data/utils/generic.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Union
3 |
4 | import cv2
5 | import h5py
6 | import numpy as np
7 |
8 |
9 | def _flow_16bit_to_float(flow_16bit: np.ndarray):
10 | assert flow_16bit.dtype == np.uint16, flow_16bit.dtype
11 | assert flow_16bit.ndim == 3
12 | h, w, c = flow_16bit.shape
13 | assert c == 3
14 |
15 | valid2D = flow_16bit[..., 2] == 1
16 | assert valid2D.shape == (h, w)
17 | assert np.all(flow_16bit[~valid2D, -1] == 0)
18 | valid_map = np.where(valid2D)
19 | flow_16bit = flow_16bit.astype('float32')
20 | flow_map = np.zeros((h, w, 2), dtype='float32')
21 | flow_map[valid_map[0], valid_map[1], 0] = (flow_16bit[valid_map[0], valid_map[1], 0] - 2 ** 15) / 128
22 | flow_map[valid_map[0], valid_map[1], 1] = (flow_16bit[valid_map[0], valid_map[1], 1] - 2 ** 15) / 128
23 | return flow_map, valid2D
24 |
25 |
26 | def load_flow(flowfile: Path):
27 | assert flowfile.exists()
28 | assert flowfile.suffix == '.png'
29 | # flow_16bit = np.array(Image.open(str(flowfile)))
30 | flow_16bit = cv2.imread(str(flowfile), cv2.IMREAD_ANYDEPTH | cv2.IMREAD_ANYCOLOR)
31 | flow, valid2D = _flow_16bit_to_float(flow_16bit)
32 | return flow, valid2D
33 |
34 |
35 | def _blosc_opts(complevel=1, complib='blosc:zstd', shuffle='byte'):
36 | shuffle = 2 if shuffle == 'bit' else 1 if shuffle == 'byte' else 0
37 | compressors = ['blosclz', 'lz4', 'lz4hc', 'snappy', 'zlib', 'zstd']
38 | complib = ['blosc:' + c for c in compressors].index(complib)
39 | args = {
40 | 'compression': 32001,
41 | 'compression_opts': (0, 0, 0, 0, complevel, shuffle, complib),
42 | }
43 | if shuffle > 0:
44 | # Do not use h5py shuffle if blosc shuffle is enabled.
45 | args['shuffle'] = False
46 | return args
47 |
48 |
49 | def np_array_to_h5(array: np.ndarray, outpath: Path) -> None:
50 | assert isinstance(array, np.ndarray)
51 | assert outpath.suffix == '.h5'
52 |
53 | with h5py.File(str(outpath), 'w') as h5f:
54 | h5f.create_dataset('voxel_grid', data=array, shape=array.shape, dtype=array.dtype,
55 | **_blosc_opts(complevel=1, shuffle='byte'))
56 |
57 |
58 | def h5_to_np_array(inpath: Path) -> Union[np.ndarray, None]:
59 | assert inpath.suffix == '.h5'
60 | assert inpath.exists()
61 |
62 | try:
63 | with h5py.File(str(inpath), 'r') as h5f:
64 | array = np.asarray(h5f['voxel_grid'])
65 | return array
66 | except OSError as e:
67 | print(f'Error loading {inpath}: {e}')
68 | return None
69 |
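A tiny numeric example of the DSEC 16-bit flow decoding above: valid pixels store `flow * 128 + 2**15` per component plus a validity flag in the third channel, so decoding subtracts the offset and divides by 128. The pixel values below are made up:

import numpy as np

from data.utils.generic import _flow_16bit_to_float

# One valid pixel (third channel == 1) and one invalid pixel (third channel == 0).
flow_16bit = np.zeros((1, 2, 3), dtype='uint16')
flow_16bit[0, 0] = (2**15 + 128, 2**15 - 64, 1)  # encodes an (x, y) flow of (+1.0, -0.5) pixels

flow_map, valid2D = _flow_16bit_to_float(flow_16bit)
assert valid2D.tolist() == [[True, False]]
assert flow_map[0, 0].tolist() == [1.0, -0.5]
assert flow_map[0, 1].tolist() == [0.0, 0.0]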
--------------------------------------------------------------------------------
/data/utils/keys.py:
--------------------------------------------------------------------------------
1 | from enum import Enum, auto, IntEnum
2 |
3 | class DataSetType(IntEnum):
4 | DSEC = auto()
5 | MULTIFLOW2D = auto()
6 |
7 | class DataLoading(Enum):
8 | FLOW = auto()
9 | FLOW_TIMESTAMPS = auto()
10 | FLOW_VALID = auto()
11 | FILE_INDEX = auto()
12 | EV_REPR = auto()
13 | BIN_META = auto()
14 | IMG = auto()
15 | IMG_TIMESTAMPS = auto()
16 | DATASET_TYPE = auto()
--------------------------------------------------------------------------------
/data/utils/provider.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 |
4 | class DatasetProviderBase(abc.ABC):
5 | @abc.abstractmethod
6 | def get_train_dataset(self):
7 | raise NotImplementedError
8 |
9 | @abc.abstractmethod
10 | def get_val_dataset(self):
11 | raise NotImplementedError
12 |
13 | @abc.abstractmethod
14 | def get_test_dataset(self):
15 | raise NotImplementedError
16 |
17 | @abc.abstractmethod
18 | def get_nbins_context(self):
19 | raise NotImplementedError
20 |
21 | @abc.abstractmethod
22 | def get_nbins_correlation(self):
23 | raise NotImplementedError
--------------------------------------------------------------------------------
/data/utils/representations.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Optional
3 |
4 | import torch
5 | torch.set_num_threads(1)  # limit intra-op parallelism to a single thread
6 | torch.set_num_interop_threads(1)  # limit inter-op parallelism to a single thread
7 |
8 |
9 | def norm_voxel_grid(voxel_grid: torch.Tensor):
10 | mask = torch.nonzero(voxel_grid, as_tuple=True)
11 | if mask[0].size()[0] > 0:
12 | mean = voxel_grid[mask].mean()
13 | std = voxel_grid[mask].std()
14 | if std > 0:
15 | voxel_grid[mask] = (voxel_grid[mask] - mean) / std
16 | else:
17 | voxel_grid[mask] = voxel_grid[mask] - mean
18 | return voxel_grid
19 |
20 |
21 | class EventRepresentation:
22 | def convert(self, x: torch.Tensor, y: torch.Tensor, pol: torch.Tensor, time: torch.Tensor, t_from: Optional[int]=None, t_to: Optional[int]=None):
23 | raise NotImplementedError
24 |
25 |
26 | class VoxelGrid(EventRepresentation):
27 | def __init__(self, channels: int, height: int, width: int):
28 | assert channels > 1
29 | assert height > 1
30 | assert width > 1
31 | self.nb_channels = channels
32 | self.height = height
33 | self.width = width
34 |
35 | def get_extended_time_window(self, t0_center: int, t1_center: int):
36 | dt = self._get_dt(t0_center, t1_center)
37 | t_start = math.floor(t0_center - dt)
38 | t_end = math.ceil(t1_center + dt)
39 | return t_start, t_end
40 |
41 | def _construct_empty_voxel_grid(self):
42 | return torch.zeros(
43 | (self.nb_channels, self.height, self.width),
44 | dtype=torch.float,
45 | requires_grad=False,
46 | device=torch.device('cpu'))
47 |
48 | def _get_dt(self, t0_center: int, t1_center: int):
49 | assert t1_center > t0_center
50 | return (t1_center - t0_center)/(self.nb_channels - 1)
51 |
52 | def _normalize_time(self, time: torch.Tensor, t0_center: int, t1_center: int):
53 | # time_norm < t0_center will be negative
54 | # time_norm == t0_center is 0
55 | # time_norm > t0_center is positive
56 | # time_norm == t1_center is (nb_channels - 1)
57 | # time_norm > t1_center is greater than (nb_channels - 1)
58 | return (time - t0_center)/(t1_center - t0_center)*(self.nb_channels - 1)
59 |
60 | @staticmethod
61 | def _is_int_tensor(tensor: torch.Tensor) -> bool:
62 | return not torch.is_floating_point(tensor) and not torch.is_complex(tensor)
63 |
64 | def convert(self, x: torch.Tensor, y: torch.Tensor, pol: torch.Tensor, time: torch.Tensor, t0_center: Optional[int]=None, t1_center: Optional[int]=None):
65 | assert x.device == y.device == pol.device == time.device == torch.device('cpu')
66 | assert type(t0_center) == type(t1_center)
67 | assert x.shape == y.shape == pol.shape == time.shape
68 | assert x.ndim == 1
69 | assert self._is_int_tensor(time)
70 |
71 | is_int_xy = self._is_int_tensor(x)
72 | if is_int_xy:
73 | assert self._is_int_tensor(y)
74 |
75 | voxel_grid = self._construct_empty_voxel_grid()
76 | ch, ht, wd = self.nb_channels, self.height, self.width
77 | with torch.no_grad():
78 | t0_center = t0_center if t0_center is not None else time[0]
79 | t1_center = t1_center if t1_center is not None else time[-1]
80 | t_norm = self._normalize_time(time, t0_center, t1_center)
81 |
82 | t0 = t_norm.floor().int()
83 | value = 2*pol.float()-1
84 |
85 | if is_int_xy:
86 | for tlim in [t0,t0+1]:
87 | mask = (tlim >= 0) & (tlim < ch)
88 | interp_weights = value * (1 - (tlim - t_norm).abs())
89 |
90 | index = ht * wd * tlim.long() + \
91 | wd * y.long() + \
92 | x.long()
93 |
94 | voxel_grid.put_(index[mask], interp_weights[mask], accumulate=True)
95 | else:
96 | x0 = x.floor().int()
97 | y0 = y.floor().int()
98 | for xlim in [x0,x0+1]:
99 | for ylim in [y0,y0+1]:
100 | for tlim in [t0,t0+1]:
101 |
102 | mask = (xlim < wd) & (xlim >= 0) & (ylim < ht) & (ylim >= 0) & (tlim >= 0) & (tlim < ch)
103 | interp_weights = value * (1 - (xlim-x).abs()) * (1 - (ylim-y).abs()) * (1 - (tlim - t_norm).abs())
104 |
105 | index = ht * wd * tlim.long() + \
106 | wd * ylim.long() + \
107 | xlim.long()
108 |
109 | voxel_grid.put_(index[mask], interp_weights[mask], accumulate=True)
110 |
111 | return voxel_grid
112 |
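A small usage sketch of `VoxelGrid.convert` with a handful of hypothetical events on integer pixel coordinates: each event is split linearly between its two nearest temporal bins, and polarity 0/1 is mapped to -1/+1:

import torch

from data.utils.representations import VoxelGrid

grid = VoxelGrid(channels=5, height=4, width=4)

x = torch.tensor([1, 2, 3], dtype=torch.int64)
y = torch.tensor([0, 1, 2], dtype=torch.int64)
pol = torch.tensor([1, 0, 1], dtype=torch.int64)
t = torch.tensor([0, 50, 100], dtype=torch.int64)  # microseconds

voxel = grid.convert(x, y, pol, t)  # t0_center/t1_center default to the first/last timestamp
assert voxel.shape == (5, 4, 4)
assert voxel[0, 0, 1] == 1.0   # first event lands exactly on bin 0 with polarity +1
assert voxel[2, 1, 2] == -1.0  # middle event lands exactly on bin 2 with polarity -1
assert voxel[4, 2, 3] == 1.0   # last event lands exactly on bin 4 with polarity +1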
--------------------------------------------------------------------------------
/loggers/wandb_logger.py:
--------------------------------------------------------------------------------
1 | """
2 | This is a modified version of the PyTorch Lightning wandb logger
3 | """
4 |
5 | import time
6 | from argparse import Namespace
7 | from pathlib import Path
8 | from typing import Any, Dict, List, Optional, Union
9 | from weakref import ReferenceType
10 |
11 | import numpy as np
12 | import torch
13 | import torch.nn as nn
14 | import wandb
15 | from lightning_fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict, _sanitize_callable_params
16 | from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
17 | from pytorch_lightning.loggers.logger import rank_zero_experiment, Logger
18 | from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
19 | from wandb.sdk.lib import RunDisabled
20 | from wandb.wandb_run import Run
21 |
22 |
23 | class WandbLogger(Logger):
24 | LOGGER_JOIN_CHAR = "-"
25 | STEP_METRIC = "trainer/global_step"
26 |
27 | def __init__(
28 | self,
29 | name: Optional[str] = None,
30 | project: Optional[str] = None,
31 | group: Optional[str] = None,
32 | wandb_id: Optional[str] = None,
33 | prefix: Optional[str] = "",
34 | log_model: Optional[bool] = True,
35 | save_last_only_final: Optional[bool] = False,
36 | config_args: Optional[Dict[str, Any]] = None,
37 | **kwargs,
38 | ):
39 | super().__init__()
40 | self._experiment = None
41 | self._log_model = log_model
42 | self._prefix = prefix
43 | self._logged_model_time = {}
44 | self._checkpoint_callback = None
45 | # Save last is determined by the checkpoint callback argument
46 | self._save_last = None
47 | # Whether to save the last checkpoint continuously (more storage) or only when the run is aborted
48 | self._save_last_only_final = save_last_only_final
49 | # Save the configuration args (e.g. parsed arguments) and log it in wandb
50 | self._config_args = config_args
51 | # set wandb init arguments
52 | self._wandb_init = dict(
53 | name=name,
54 | project=project,
55 | group=group,
56 | id=wandb_id,
57 | resume="allow",
58 | save_code=True,
59 | )
60 | self._wandb_init.update(**kwargs)
61 | # extract parameters
62 | self._name = self._wandb_init.get("name")
63 | self._id = self._wandb_init.get("id")
64 | # for save_top_k
65 | self._public_run = None
66 |
67 | # start wandb run (to create an attach_id for distributed modes)
68 | wandb.require("service")
69 | _ = self.experiment
70 |
71 | def get_checkpoint(self, artifact_name: str, artifact_filepath: Optional[Path] = None) -> Path:
72 | artifact = self.experiment.use_artifact(artifact_name)
73 | if artifact_filepath is None:
74 | assert artifact is not None, 'You are probably using DDP, ' \
75 | 'in which case you should provide an artifact filepath.'
76 | # TODO: specify download directory
77 | artifact_dir = artifact.download()
78 | artifact_filepath = next(Path(artifact_dir).iterdir())
79 | assert artifact_filepath.exists()
80 | assert artifact_filepath.suffix == '.ckpt'
81 | return artifact_filepath
82 |
83 | def __getstate__(self) -> Dict[str, Any]:
84 | state = self.__dict__.copy()
85 | # args needed to reload correct experiment
86 | if self._experiment is not None:
87 | state["_id"] = getattr(self._experiment, "id", None)
88 | state["_attach_id"] = getattr(self._experiment, "_attach_id", None)
89 | state["_name"] = self._experiment.name
90 |
91 | # cannot be pickled
92 | state["_experiment"] = None
93 | return state
94 |
95 | @property
96 | @rank_zero_experiment
97 | def experiment(self) -> Run:
98 | if self._experiment is None:
99 | attach_id = getattr(self, "_attach_id", None)
100 | if wandb.run is not None:
101 | # wandb process already created in this instance
102 | rank_zero_warn(
103 | "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse"
104 | " this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`."
105 | )
106 | self._experiment = wandb.run
107 | elif attach_id is not None and hasattr(wandb, "_attach"):
108 | # attach to wandb process referenced
109 | self._experiment = wandb._attach(attach_id)
110 | else:
111 | # create new wandb process
112 | self._experiment = wandb.init(**self._wandb_init)
113 | if self._config_args is not None:
114 | self._experiment.config.update(self._config_args, allow_val_change=True)
115 |
116 | # define default x-axis
117 | if isinstance(self._experiment, (Run, RunDisabled)) and getattr(
118 | self._experiment, "define_metric", None
119 | ):
120 | self._experiment.define_metric(self.STEP_METRIC)
121 | self._experiment.define_metric("*", step_metric=self.STEP_METRIC, step_sync=True)
122 |
123 | assert isinstance(self._experiment, (Run, RunDisabled))
124 | return self._experiment
125 |
126 | def watch(self, model: nn.Module, log: str = 'all', log_freq: int = 100, log_graph: bool = True):
127 | self.experiment.watch(model, log=log, log_freq=log_freq, log_graph=log_graph)
128 |
129 | def add_step_metric(self, input_dict: dict, step: int) -> None:
130 | input_dict.update({self.STEP_METRIC: step})
131 |
132 | @rank_zero_only
133 | def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
134 | params = _convert_params(params)
135 | params = _flatten_dict(params)
136 | params = _sanitize_callable_params(params)
137 | self.experiment.config.update(params, allow_val_change=True)
138 |
139 | @rank_zero_only
140 | def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
141 | assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
142 |
143 | metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
144 | if step is not None:
145 | self.add_step_metric(metrics, step)
146 | self.experiment.log({**metrics}, step=step)
147 | else:
148 | self.experiment.log(metrics)
149 |
150 | @rank_zero_only
151 | def log_images(self, key: str, images: List[Any], step: Optional[int] = None, **kwargs: str) -> None:
152 | """Log images (tensors, numpy arrays, PIL Images or file paths).
153 | Optional kwargs are lists passed to each image (ex: caption, masks, boxes).
154 |
155 | How to use: https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.loggers.wandb.html#weights-and-biases-logger
156 | Taken from: https://github.com/PyTorchLightning/pytorch-lightning/blob/11e289ad9f95f5fe23af147fa4edcc9794f9b9a7/pytorch_lightning/loggers/wandb.py#L420
157 | """
158 | if not isinstance(images, list):
159 | raise TypeError(f'Expected a list as "images", found {type(images)}')
160 | n = len(images)
161 | for k, v in kwargs.items():
162 | if len(v) != n:
163 | raise ValueError(f"Expected {n} items but only found {len(v)} for {k}")
164 | kwarg_list = [{k: kwargs[k][i] for k in kwargs.keys()} for i in range(n)]
165 | metrics = {key: [wandb.Image(img, **kwarg) for img, kwarg in zip(images, kwarg_list)]}
166 | self.log_metrics(metrics, step)
167 |
168 | @rank_zero_only
169 | def log_videos(self,
170 | key: str,
171 | videos: List[Union[np.ndarray, str]],
172 | step: Optional[int] = None,
173 | captions: Optional[List[str]] = None,
174 | fps: int = 4,
175 | format_: Optional[str] = None):
176 | """
177 | :param videos: List[(T,C,H,W)] or List[(N,T,C,H,W)]
178 | :param captions: List[str] or None
179 |
180 | More info: https://docs.wandb.ai/ref/python/data-types/video and
181 | https://docs.wandb.ai/guides/track/log/media#other-media
182 | """
183 | assert isinstance(videos, list)
184 | if captions is not None:
185 | assert isinstance(captions, list)
186 | assert len(captions) == len(videos)
187 | wandb_videos = list()
188 | for idx, video in enumerate(videos):
189 | caption = captions[idx] if captions is not None else None
190 | wandb_videos.append(wandb.Video(data_or_path=video, caption=caption, fps=fps, format=format_))
191 | self.log_metrics(metrics={key: wandb_videos}, step=step)
192 |
193 | @property
194 | def name(self) -> Optional[str]:
195 | # This function seems to be only relevant if LoggerCollection is used.
196 | # don't create an experiment if we don't have one
197 | return self._experiment.project_name() if self._experiment else self._name
198 |
199 | @property
200 | def version(self) -> Optional[str]:
201 | # This function seems to be only relevant if LoggerCollection is used.
202 | # don't create an experiment if we don't have one
203 | return self._experiment.id if self._experiment else self._id
204 |
205 | @rank_zero_only
206 | def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
207 | # log checkpoints as artifacts
208 | if self._checkpoint_callback is None:
209 | self._checkpoint_callback = checkpoint_callback
210 | self._save_last = checkpoint_callback.save_last
211 | if self._log_model:
212 | self._scan_and_log_checkpoints(checkpoint_callback, self._save_last and not self._save_last_only_final)
213 |
214 | @rank_zero_only
215 | def finalize(self, status: str) -> None:
216 | # log checkpoints as artifacts
217 | if self._checkpoint_callback and self._log_model:
218 | self._scan_and_log_checkpoints(self._checkpoint_callback, self._save_last)
219 |
220 | def _get_public_run(self):
221 | if self._public_run is None:
222 | experiment = self.experiment
223 | runpath = experiment._entity + '/' + experiment._project + '/' + experiment._run_id
224 | api = wandb.Api()
225 | self._public_run = api.run(path=runpath)
226 | return self._public_run
227 |
228 | def _num_logged_artifact(self):
229 | public_run = self._get_public_run()
230 | return len(public_run.logged_artifacts())
231 |
232 | def _scan_and_log_checkpoints(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]", save_last: bool) -> None:
233 | assert self._log_model
234 | if self._checkpoint_callback is None:
235 | self._checkpoint_callback = checkpoint_callback
236 | self._save_last = checkpoint_callback.save_last
237 |
238 | checkpoints = {
239 | checkpoint_callback.best_model_path: checkpoint_callback.best_model_score,
240 | **checkpoint_callback.best_k_models,
241 | }
242 | assert len(checkpoints) <= max(checkpoint_callback.save_top_k, 0)
243 |
244 | if save_last:
245 | last_model_path = Path(checkpoint_callback.last_model_path)
246 | if last_model_path.exists():
247 | checkpoints.update({checkpoint_callback.last_model_path: checkpoint_callback.current_score})
248 | else:
249 | print(f'last model checkpoint not found at {checkpoint_callback.last_model_path}')
250 |
251 | checkpoints = sorted(
252 | ((Path(path).stat().st_mtime, path, score) for path, score in checkpoints.items() if Path(path).is_file()),
253 | key=lambda x: x[0])
254 |         # Keep only checkpoints that we have not logged before, with one exception:
255 |         # if the filename is the same (e.g. the last checkpoint, which should be overwritten),
256 |         # keep it only if it is newer than the previously logged one (compared via modification time).
257 | checkpoints = [ckpt for ckpt in checkpoints if
258 | ckpt[1] not in self._logged_model_time.keys() or self._logged_model_time[ckpt[1]] < ckpt[0]]
259 | # remove checkpoints with undefined (None) score
260 | checkpoints = [x for x in checkpoints if x[2] is not None]
261 |
262 | num_ckpt_logged_before = self._num_logged_artifact()
263 |         num_new_ckpts = len(checkpoints)
264 |
265 |         if num_new_ckpts == 0:
266 | return
267 |
268 |         # iteratively log all new checkpoints
269 | for time_, path, score in checkpoints:
270 | score = score.item() if isinstance(score, torch.Tensor) else score
271 | is_best = path == checkpoint_callback.best_model_path
272 | is_last = path == checkpoint_callback.last_model_path
273 | metadata = ({
274 | "score": score,
275 | "original_filename": Path(path).name,
276 | "ModelCheckpoint": {
277 | k: getattr(checkpoint_callback, k)
278 | for k in [
279 | "monitor",
280 | "mode",
281 | "save_last",
282 | "save_top_k",
283 | "save_weights_only",
284 | ]
285 | # ensure it does not break if `ModelCheckpoint` args change
286 | if hasattr(checkpoint_callback, k)
287 | },
288 | }
289 | )
290 | aliases = []
291 | if is_best:
292 | aliases.append('best')
293 | if is_last:
294 | aliases.append('last')
295 | artifact_name = f'checkpoint-{self.experiment.id}-' + ('last' if is_last else 'topK')
296 | artifact = wandb.Artifact(name=artifact_name, type='model', metadata=metadata)
297 | assert Path(path).exists()
298 | artifact.add_file(path, name=f'{self.experiment.id}.ckpt')
299 | self.experiment.log_artifact(artifact, aliases=aliases)
300 | # remember logged model - timestamp needed in case filename didn't change (last.ckpt or custom name)
301 | self._logged_model_time[path] = time_
302 |
303 | timeout = 20
304 | time_spent = 0
305 |         while self._num_logged_artifact() < num_ckpt_logged_before + num_new_ckpts:
306 | time.sleep(1)
307 | time_spent += 1
308 | if time_spent >= timeout:
309 | rank_zero_warn("Timeout: Num logged artifacts never reached expected value.")
310 | print(f'self._num_logged_artifact() = {self._num_logged_artifact()}')
311 | print(f'num_ckpt_logged_before = {num_ckpt_logged_before}')
312 |                 print(f'num_new_ckpts = {num_new_ckpts}')
313 | break
314 | try:
315 | self._rm_but_top_k(checkpoint_callback.save_top_k)
316 | except KeyError:
317 | pass
318 |
319 | def _rm_but_top_k(self, top_k: int):
320 |         # top_k == -1: keep all models (nothing is deleted)
321 |         # top_k == 0: no models are saved at all; the checkpoint callback does not return checkpoints.
322 |         # top_k > 0: keep only the top k models ('last' and 'best' are never deleted)
323 | def is_last(artifact):
324 | return 'last' in artifact.aliases
325 |
326 | def is_best(artifact):
327 | return 'best' in artifact.aliases
328 |
329 | def try_delete(artifact):
330 | try:
331 | artifact.delete(delete_aliases=True)
332 | except wandb.errors.CommError:
333 | print(f'Failed to delete artifact {artifact.name} due to wandb.errors.CommError')
334 |
335 | public_run = self._get_public_run()
336 |
337 | score2art = list()
338 | for artifact in public_run.logged_artifacts():
339 | score = artifact.metadata['score']
340 | original_filename = artifact.metadata['original_filename']
341 | if score == 'Infinity':
342 | print(
343 | f'removing INF artifact (name, score, original_filename): ({artifact.name}, {score}, {original_filename})')
344 | try_delete(artifact)
345 | continue
346 | if score is None:
347 | print(
348 | f'removing None artifact (name, score, original_filename): ({artifact.name}, {score}, {original_filename})')
349 | try_delete(artifact)
350 | continue
351 | score2art.append((score, artifact))
352 |
353 | # From high score to low score
354 | score2art.sort(key=lambda x: x[0], reverse=True)
355 |
356 | count = 0
357 | for score, artifact in score2art:
358 | original_filename = artifact.metadata['original_filename']
359 | if 'last' in original_filename and not is_last(artifact):
360 | try_delete(artifact)
361 | continue
362 | if is_last(artifact):
363 | continue
364 | count += 1
365 | if is_best(artifact):
366 | continue
367 | # if top_k == -1, we do not delete anything
368 | if 0 <= top_k < count:
369 | try_delete(artifact)
--------------------------------------------------------------------------------
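The retention logic in `_rm_but_top_k` above always keeps the artifacts aliased 'last' and 'best' and, among the remaining ones, the first `top_k` artifacts in descending score order. Below is a simplified, standalone sketch of that rule; it does not use the wandb API (the `FakeArtifact` type is purely illustrative) and omits the pre-filtering of `None`/`Infinity` scores and of stale `last` files:

    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class FakeArtifact:
        name: str
        score: float
        aliases: List[str] = field(default_factory=list)

    def kept_artifacts(artifacts: List[FakeArtifact], top_k: int) -> List[str]:
        # Sort from high score to low score, mirroring _rm_but_top_k.
        artifacts = sorted(artifacts, key=lambda a: a.score, reverse=True)
        kept, count = [], 0
        for art in artifacts:
            if 'last' in art.aliases:
                kept.append(art.name)       # 'last' is never deleted and does not count towards top_k
                continue
            count += 1
            if 'best' in art.aliases:
                kept.append(art.name)       # 'best' is never deleted
                continue
            if top_k < 0 or count <= top_k:
                kept.append(art.name)       # within the top-k budget (top_k == -1 keeps everything)
        return kept

    arts = [FakeArtifact('a', 3.0), FakeArtifact('b', 2.0),
            FakeArtifact('c', 1.0, aliases=['best']), FakeArtifact('d', 1.5, aliases=['last'])]
    print(kept_artifacts(arts, top_k=1))    # ['a', 'd', 'c']: 'last', 'best' and one remaining artifact survive
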
/models/raft_spline/bezier.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import math
4 | from typing import Union, List
5 |
6 | import numpy as np
7 | import torch as th
8 | from numba import jit
9 |
10 | has_scipy_special = True
11 | try:
12 | from scipy import special
13 | except ImportError:
14 | has_scipy_special = False
15 |
16 | from models.raft_utils.utils import cvx_upsample
17 |
18 | class BezierCurves:
19 | # Each ctrl point lives in R^2
20 | CTRL_DIM: int = 2
21 |
22 | def __init__(self, bezier_params: th.Tensor):
23 | # bezier_params: batch, ctrl_dim*(n_ctrl_pts - 1), height, width
24 | assert bezier_params.ndim == 4
25 | self._params = bezier_params
26 |
27 | # some helpful meta-data:
28 | self.batch, channels, self.ht, self.wd = self._params.shape
29 | assert channels % 2 == 0
30 | # P0 is always zeros as it corresponds to the pixel locations.
31 | # Consequently, we only compute P1, P2, ...
32 | self.n_ctrl_pts = channels // self.CTRL_DIM + 1
33 | assert self.n_ctrl_pts > 0
34 |
35 | # math.comb is only available in python 3.8 or higher
36 | self.use_math_comb = hasattr(math, 'comb')
37 | if not self.use_math_comb:
38 | assert has_scipy_special
39 | assert hasattr(special, 'comb')
40 |
41 | def comb(self, n: int, k:int):
42 | if self.use_math_comb:
43 | return math.comb(n, k)
44 | return special.comb(n, k)
45 |
46 | @classmethod
47 | def create_from_specification(cls, batch_size: int, n_ctrl_pts: int, height: int, width: int, device: th.device) -> BezierCurves:
48 | assert batch_size > 0
49 | assert n_ctrl_pts > 1
50 | assert height > 0
51 | assert width > 0
52 | params = th.zeros(batch_size, cls.CTRL_DIM * (n_ctrl_pts - 1), height, width, device=device)
53 | return cls(params)
54 |
55 | @classmethod
56 | def from_2view(cls, flow_tensor: th.Tensor) -> BezierCurves:
57 | # This function has been written to visualize 2-view predictions for our paper.
58 | batch_size, channel_size, height, width = flow_tensor.shape
59 | assert channel_size == 2 == cls.CTRL_DIM
60 | return cls(flow_tensor)
61 |
62 | @classmethod
63 | def create_from_voxel_grid(cls, voxel_grid: th.Tensor, downsample_factor: int=8, bezier_degree: int=2) -> BezierCurves:
64 | assert isinstance(downsample_factor, int)
65 | assert downsample_factor >= 1
66 | batch, _, ht, wd = voxel_grid.shape
67 |         assert ht % downsample_factor == 0
68 |         assert wd % downsample_factor == 0
69 | ht, wd = ht//downsample_factor, wd//downsample_factor
70 | n_ctrl_pts = bezier_degree + 1
71 | return cls.create_from_specification(batch_size=batch, n_ctrl_pts=n_ctrl_pts, height=ht, width=wd, device=voxel_grid.device)
72 |
73 | @property
74 | def device(self):
75 | return self._params.device
76 |
77 | @property
78 | def dtype(self):
79 | return self._params.dtype
80 |
81 | def create_upsampled(self, mask: th.Tensor) -> BezierCurves:
82 | """ Upsample params [N, dim, H/8, W/8] -> [N, dim, H, W] using convex combination """
83 | up_params = cvx_upsample(self._params, mask)
84 | return BezierCurves(up_params)
85 |
86 | def detach(self, clone: bool=False, cpu: bool=False) -> BezierCurves:
87 | params = self._params.detach()
88 | if cpu:
89 | return BezierCurves(params.cpu())
90 | if clone:
91 | params = params.clone()
92 | return BezierCurves(params)
93 |
94 | def detach_(self, cpu: bool=False) -> None:
95 | # Detaches the bezier parameters in-place!
96 | self._params = self._params.detach()
97 | if cpu:
98 | self._params = self._params.cpu()
99 |
100 | def cpu(self) -> BezierCurves:
101 | return BezierCurves(self._params.cpu())
102 |
103 | def cpu_(self) -> None:
104 | # Puts the bezier parameters to CPU in-place!
105 | self._params = self._params.cpu()
106 |
107 | @property
108 | def requires_grad(self):
109 | return self._params.requires_grad
110 |
111 | @property
112 | def batch_size(self):
113 | return self._params.shape[0]
114 |
115 | @property
116 | def degree(self):
117 | return self.n_ctrl_pts - 1
118 |
119 | @property
120 | def dim(self):
121 | return self._params.shape[1]
122 |
123 | @property
124 | def height(self):
125 | return self._params.shape[-2]
126 |
127 | @property
128 | def width(self):
129 | return self._params.shape[-1]
130 |
131 | def get_params(self) -> th.Tensor:
132 | return self._params
133 |
134 | def _param_view(self) -> th.Tensor:
135 | return self._params.view(self.batch, self.CTRL_DIM, self.degree, self.ht, self.wd)
136 |
137 | def delta_update_params(self, delta_bezier: th.Tensor) -> None:
138 | assert delta_bezier.shape == self._params.shape
139 | self._params = self._params + delta_bezier
140 |
141 | @staticmethod
142 | def _get_binom_coeffs(degree: int):
143 |         k = np.arange(degree) + 1
144 |         # Use scipy when available; otherwise fall back to math.comb (Python 3.8+), since the scipy import is optional.
145 |         return special.binom(degree, k) if has_scipy_special else np.array([math.comb(degree, int(i)) for i in k], dtype='float64')
146 |
147 | @staticmethod
148 | @jit(nopython=True)
149 | def _get_time_coeffs(timestamps: np.ndarray, degree: int):
150 | assert timestamps.min() >= 0
151 | assert timestamps.max() <= 1
152 | assert timestamps.ndim == 1
153 |         # We would like to ensure float64 dtype here but have not found a way to check it inside a numba-jitted function.
154 | #assert timestamps.dtype == np.dtype('float64')
155 |
156 | num_ts = timestamps.size
157 | out = np.zeros((num_ts, degree))
158 | for t_idx in range(num_ts):
159 | for d_idx in range(degree):
160 | time = timestamps[t_idx]
161 | i = d_idx + 1
162 | out[t_idx, d_idx] = (1 - time)**(degree - i)*time**i
163 | return out
164 |
165 | def _compute_flow_from_timestamps(self, timestamps: Union[List[float], np.ndarray]):
166 | if isinstance(timestamps, list):
167 | timestamps = np.asarray(timestamps)
168 | else:
169 | assert isinstance(timestamps, np.ndarray)
170 | assert timestamps.dtype == 'float64'
171 | assert timestamps.size > 0
172 | assert np.min(timestamps) >= 0
173 | assert np.max(timestamps) <= 1
174 |
175 | degree = self.degree
176 | binom_coeffs = self._get_binom_coeffs(degree)
177 | time_coeffs = self._get_time_coeffs(timestamps, degree)
178 | # poly coeffs: time, degree
179 | polynomial_coeffs = np.einsum('j,ij->ij', binom_coeffs, time_coeffs)
180 | polynomial_coeffs = th.from_numpy(polynomial_coeffs).float().to(device=self.device)
181 |
182 | # params: batch, dim, degree, height, width
183 | params = self._param_view()
184 | # flow: timestamps, batch, dim, height, width
185 | flow = th.einsum('bdphw,tp->tbdhw', params, polynomial_coeffs)
186 | return flow
187 |
188 | def get_flow_from_reference(self, time: Union[float, int, List[float], np.ndarray]) -> th.Tensor:
189 | params = self._param_view()
190 | batch, dim, degree, height, width = params.shape
191 | time_is_scalar = isinstance(time, int) or isinstance(time, float)
192 | if time_is_scalar:
193 | assert time >= 0.0
194 | assert time <= 1.0
195 | if time == 1:
196 | P_end = params[:, :, -1, ...]
197 | return P_end
198 | if time == 0:
199 | return th.zeros((batch, dim, height, width), dtype=self.dtype, device=self.device)
200 | time = np.array([time], dtype='float64')
201 | elif isinstance(time, list):
202 | time = np.asarray(time, dtype='float64')
203 | else:
204 | assert isinstance(time, np.ndarray)
205 | assert time.dtype == 'float64'
206 | assert time.size > 0
207 | assert np.min(time) >= 0
208 | assert np.max(time) <= 1
209 |
210 | # flow is coords1 - coords0
211 | # flows: timestamps, batch, dim, height, width
212 | flows = self._compute_flow_from_timestamps(timestamps=time)
213 | if time_is_scalar:
214 | assert flows.shape[0] == 1
215 | return flows[0]
216 | return flows
--------------------------------------------------------------------------------
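The parameter tensor above only stores P1..Pn (P0 is the pixel location itself), so a curve with bezier_degree 2 carries 4 channels per pixel. A minimal usage sketch, run from the repository root; shapes and values are illustrative only:

    import torch as th
    from models.raft_spline.bezier import BezierCurves

    # degree-2 curve -> 3 control points, of which only P1 and P2 are stored (2 coordinates each = 4 channels)
    bez = BezierCurves.create_from_specification(
        batch_size=1, n_ctrl_pts=3, height=4, width=4, device=th.device('cpu'))
    bez.delta_update_params(th.randn(1, 4, 4, 4))

    flow_mid = bez.get_flow_from_reference(0.5)             # (1, 2, 4, 4): flow from t=0 to t=0.5
    flows = bez.get_flow_from_reference([0.25, 0.5, 1.0])   # (3, 1, 2, 4, 4): one flow map per timestamp
    # At t=1 the flow equals the last stored control point, so the scalar shortcut and the
    # Bernstein-evaluation path agree:
    assert th.allclose(bez.get_flow_from_reference(1.0), flows[-1])
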
/models/raft_spline/raft.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Any, Optional, List
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from models.raft_spline.bezier import BezierCurves
7 | from models.raft_spline.update import BasicUpdateBlock
8 | from models.raft_utils.extractor import BasicEncoder
9 | from models.raft_utils.corr import CorrComputation, CorrBlockParallelMultiTarget
10 | from models.raft_utils.utils import coords_grid
11 | from utils.timers import CudaTimerDummy as CudaTimer
12 |
13 |
14 | class RAFTSpline(nn.Module):
15 | def __init__(self, model_params: Dict[str, Any]):
16 | super().__init__()
17 | nbins_context = model_params['num_bins']['context']
18 | nbins_correlation = model_params['num_bins']['correlation']
19 | self.bezier_degree = model_params['bezier_degree']
20 | self.detach_bezier = model_params['detach_bezier']
21 | print(f'Detach Bezier curves: {self.detach_bezier}')
22 |
23 | assert nbins_correlation > 0 and nbins_context > 0
24 | assert self.bezier_degree >= 1
25 | self.nbins_context = nbins_context
26 | self.nbins_corr = nbins_correlation
27 |
28 | print('RAFT-Spline config:')
29 | print(f'Num bins context: {nbins_context}')
30 | print(f'Num bins correlation: {nbins_correlation}')
31 |
32 | corr_params = model_params['correlation']
33 | self.corr_use_cosine_sim = corr_params['use_cosine_sim']
34 |
35 | ev_corr_params = corr_params['ev']
36 | self.ev_corr_target_indices = ev_corr_params['target_indices']
37 | self.ev_corr_levels = ev_corr_params['levels']
38 | # TODO: fix this in the config
39 | #self.ev_corr_radius = ev_corr_params['radius']
40 | self.ev_corr_radius = 4
41 |
42 | self.img_corr_params = None
43 | if model_params['use_boundary_images']:
44 | print('Using images')
45 | self.img_corr_params = corr_params['img']
46 | assert 'levels' in self.img_corr_params
47 | assert 'radius' in self.img_corr_params
48 |
49 | self.hidden_dim = hdim = model_params['hidden']['dim']
50 | self.context_dim = cdim = model_params['context']['dim']
51 | cnorm = model_params['context']['norm']
52 | feature_dim = model_params['feature']['dim']
53 | fnorm = model_params['feature']['norm']
54 |
55 | # feature network, context network, and update block
56 | context_dim = 0
57 | self.fnet_img = None
58 | if self.img_corr_params is not None:
59 | self.fnet_img = BasicEncoder(input_dim=3, output_dim=feature_dim, norm_fn=fnorm)
60 | context_dim += 3
61 | self.fnet_ev = None
62 | if model_params['use_events']:
63 | print('Using events')
64 | assert 0 not in self.ev_corr_target_indices
65 | assert len(self.ev_corr_target_indices) > 0
66 | assert max(self.ev_corr_target_indices) < self.nbins_context
67 | assert len(self.ev_corr_target_indices) == len(self.ev_corr_levels)
68 | self.fnet_ev = BasicEncoder(input_dim=nbins_correlation, output_dim=feature_dim, norm_fn=fnorm)
69 | context_dim += nbins_context
70 | assert self.fnet_ev is not None or self.fnet_img is not None
71 | self.cnet = BasicEncoder(input_dim=context_dim, output_dim=hdim + cdim, norm_fn=cnorm)
72 |
73 | self.update_block = BasicUpdateBlock(model_params, hidden_dim=hdim)
74 |
75 | def freeze_bn(self):
76 | for m in self.modules():
77 | if isinstance(m, nn.BatchNorm2d):
78 | m.eval()
79 |
80 | def initialize_flow(self, input_):
81 | N, _, H, W = input_.shape
82 | # batch, 2, ht, wd
83 | downsample_factor = 8
84 | coords0 = coords_grid(N, H//downsample_factor, W//downsample_factor, device=input_.device)
85 | bezier = BezierCurves.create_from_voxel_grid(input_, downsample_factor=downsample_factor, bezier_degree=self.bezier_degree)
86 | return coords0, bezier
87 |
88 | def gen_voxel_grids(self, input_: torch.Tensor):
89 | # input_: N, nbins_context + nbins_corr - 1 , H, W
90 | assert self.nbins_context + self.nbins_corr - 1 == input_.shape[-3]
91 | corr_grids = list()
92 | # We need to add the reference index (which is 0).
93 | indices_with_reference = [0]
94 | indices_with_reference.extend(self.ev_corr_target_indices)
95 | for idx in indices_with_reference:
96 | slice_ = input_[:, idx:idx+self.nbins_corr, ...]
97 | corr_grids.append(slice_)
98 | context_grid = input_[:, -self.nbins_context:, ...]
99 | return corr_grids, context_grid
100 |
101 | def forward(self,
102 | voxel_grid: Optional[torch.Tensor]=None,
103 | images: Optional[List[torch.Tensor]]=None,
104 | iters: int=12,
105 | flow_init: Optional[BezierCurves]=None,
106 | test_mode: bool=False):
107 | assert voxel_grid is not None or images is not None
108 | assert iters > 0
109 |
110 | hdim = self.hidden_dim
111 | cdim = self.context_dim
112 | current_device = voxel_grid.device if voxel_grid is not None else images[0].device
113 |
114 | corr_computation_events = None
115 | context_input = None
116 | with CudaTimer(current_device, 'fnet_ev'):
117 | if self.fnet_ev is not None:
118 | assert voxel_grid is not None
119 | voxel_grid = voxel_grid.contiguous()
120 | corr_grids, context_input = self.gen_voxel_grids(voxel_grid)
121 | fmaps_ev = self.fnet_ev(corr_grids)
122 | fmaps_ev = [x.float() for x in fmaps_ev]
123 | fmap1_ev = fmaps_ev[0]
124 | fmap2_ev = torch.stack(fmaps_ev[1:], dim=0)
125 | corr_computation_events = CorrComputation(fmap1_ev, fmap2_ev, num_levels_per_target=self.ev_corr_levels)
126 |
127 | corr_computation_frames = None
128 | with CudaTimer(current_device, 'fnet_img'):
129 | if self.fnet_img is not None:
130 | assert self.img_corr_params is not None
131 | assert len(images) == 2
132 | # images[0]: at reference time
133 | # images[1]: at target time
134 | images = [2 * (x.float().contiguous() / 255) - 1 for x in images]
135 | fmaps_img = self.fnet_img(images)
136 | corr_computation_frames = CorrComputation(fmaps_img[0], fmaps_img[1], num_levels_per_target=self.img_corr_params['levels'])
137 | if context_input is not None:
138 | context_input = torch.cat((context_input, images[0]), dim=-3)
139 | else:
140 | context_input = images[0]
141 | assert context_input is not None
142 |
143 | with CudaTimer(current_device, 'cnet'):
144 | cnet = self.cnet(context_input)
145 | net, inp = torch.split(cnet, [hdim, cdim], dim=1)
146 | net = torch.tanh(net)
147 | inp = torch.relu(inp)
148 |
149 | # (batch, 2, ht, wd), ...
150 | coords0, bezier = self.initialize_flow(context_input)
151 |
152 | if flow_init is not None:
153 | bezier.delta_update_params(flow_init.get_params())
154 |
155 | bezier_up_predictions = []
156 | dt = 1/(self.nbins_context - 1)
157 |
158 | with CudaTimer(current_device, 'corr computation'):
159 | corr_block = CorrBlockParallelMultiTarget(
160 | corr_computation_events=corr_computation_events,
161 | corr_computation_frames=corr_computation_frames)
162 | with CudaTimer(current_device, 'all iters'):
163 | for itr in range(iters):
164 | # NOTE: original RAFT detaches the flow (bezier) here from the graph.
165 |                 # Our experiments with bezier curves indicate that detaching lowers the validation EPE by up to 5% on DSEC.
166 | with CudaTimer(current_device, '1 iter'):
167 | if self.detach_bezier:
168 | bezier.detach_()
169 |
170 | lookup_timestamps = list()
171 | if corr_computation_events is not None:
172 | for tindex in self.ev_corr_target_indices:
173 | # 0 < time <= 1
174 | time = dt*tindex
175 | lookup_timestamps.append(time)
176 | if corr_computation_frames is not None:
177 | lookup_timestamps.append(1)
178 |
179 | with CudaTimer(current_device, 'get_flow (per iter)'):
180 | flows = bezier.get_flow_from_reference(time=lookup_timestamps)
181 | coords1 = coords0 + flows
182 |
183 | with CudaTimer(current_device, 'corr lookup (per iter)'):
184 | corr_total = corr_block(coords1)
185 |
186 | with CudaTimer(current_device, 'update (per iter)'):
187 | bezier_params = bezier.get_params()
188 | net, up_mask, delta_bezier = self.update_block(net, inp, corr_total, bezier_params)
189 |
190 | # B(k+1) = B(k) + \Delta(B)
191 | bezier.delta_update_params(delta_bezier=delta_bezier)
192 |
193 | if not test_mode or itr == iters - 1:
194 | bezier_up = bezier.create_upsampled(up_mask)
195 | bezier_up_predictions.append(bezier_up)
196 |
197 | if test_mode:
198 | return bezier, bezier_up
199 |
200 | return bezier_up_predictions
201 |
--------------------------------------------------------------------------------
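A minimal CPU smoke-test sketch for RAFTSpline as defined above. The hyper-parameter values below are illustrative assumptions chosen only so that the tensor shapes are consistent; they are not the values from the experiment configs:

    import torch
    from models.raft_spline.raft import RAFTSpline

    model_params = {
        'num_bins': {'context': 5, 'correlation': 5},
        'bezier_degree': 2,
        'detach_bezier': True,
        'use_events': True,
        'use_boundary_images': False,
        'correlation': {
            'use_cosine_sim': False,
            # one correlation pyramid per target index; len(levels) must equal len(target_indices)
            'ev': {'target_indices': [2, 4], 'levels': [1, 2], 'radius': [4, 4]},
        },
        'hidden': {'dim': 128},
        'context': {'dim': 128, 'norm': 'instance'},
        'feature': {'dim': 128, 'norm': 'instance'},
        'motion': {'dim': 128},
    }
    net = RAFTSpline(model_params).eval()

    # channels = nbins_context + nbins_correlation - 1 = 9 (overlapping correlation slices, see gen_voxel_grids)
    voxel_grid = torch.randn(1, 9, 64, 64)
    with torch.no_grad():
        bezier_lowres, bezier_up = net(voxel_grid=voxel_grid, iters=2, test_mode=True)
    print(bezier_up.get_params().shape)                   # torch.Size([1, 4, 64, 64]): 2 control points x 2 coords per pixel
    print(bezier_up.get_flow_from_reference(1.0).shape)   # torch.Size([1, 2, 64, 64])
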
/models/raft_spline/update.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Any, Optional
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class BezierHead(nn.Module):
9 | def __init__(self, bezier_degree: int, input_dim=128, hidden_dim=256):
10 | super().__init__()
11 | output_dim = bezier_degree * 2
12 | # TODO: figure out if we need to increase the capacity of this head due to bezier curves
13 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
14 | self.conv2 = nn.Conv2d(hidden_dim, output_dim, 3, padding=1)
15 | self.relu = nn.ReLU(inplace=True)
16 |
17 | def forward(self, x):
18 | return self.conv2(self.relu(self.conv1(x)))
19 |
20 |
21 | class SepConvGRU(nn.Module):
22 | def __init__(self, hidden_dim=128, input_dim=192+128):
23 | super().__init__()
24 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
25 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
26 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
27 |
28 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
29 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
30 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
31 |
32 |
33 | def forward(self, h, x):
34 | # horizontal
35 | hx = torch.cat([h, x], dim=1)
36 | z = torch.sigmoid(self.convz1(hx))
37 | r = torch.sigmoid(self.convr1(hx))
38 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
39 | h = (1-z) * h + z * q
40 |
41 | # vertical
42 | hx = torch.cat([h, x], dim=1)
43 | z = torch.sigmoid(self.convz2(hx))
44 | r = torch.sigmoid(self.convr2(hx))
45 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
46 | h = (1-z) * h + z * q
47 |
48 | return h
49 |
50 | class BasicMotionEncoder(nn.Module):
51 | def __init__(self, model_params: Dict[str, Any], output_dim: int=128):
52 | super().__init__()
53 | cor_planes = self._num_cor_planes(model_params['correlation'], model_params['use_boundary_images'], model_params['use_events'])
54 |
55 | # TODO: Are two layers enough for this? Because the number of input channels grew substantially
56 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
57 | c2_out = 192
58 | self.convc2 = nn.Conv2d(256, c2_out, 3, padding=1)
59 |
60 | # TODO: Consider upgrading the number of channels of the flow encoders (dealing now with bezier curves)
61 | bezier_planes = model_params['bezier_degree'] * 2
62 | self.convf1 = nn.Conv2d(bezier_planes, 128, 7, padding=3)
63 | f2_out = 64
64 | self.convf2 = nn.Conv2d(128, f2_out, 3, padding=1)
65 |
66 | combined_channels = f2_out+c2_out
67 | self.conv = nn.Conv2d(combined_channels, output_dim-bezier_planes, 3, padding=1)
68 |
69 | @staticmethod
70 | def _num_cor_planes(corr_params: Dict[str, Any], use_boundary_images: bool, use_events: bool):
71 | assert use_events or use_boundary_images
72 | out = 0
73 | if use_events:
74 | ev_params = corr_params['ev']
75 | ev_corr_levels = ev_params['levels']
76 | ev_corr_radius = ev_params['radius']
77 | assert len(ev_corr_levels) > 0
78 | assert len(ev_corr_radius) > 0
79 | assert len(ev_corr_levels) == len(ev_corr_radius)
80 | for lvl, rad in zip(ev_corr_levels, ev_corr_radius):
81 | out += lvl * (2*rad + 1)**2
82 | if use_boundary_images:
83 | img_corr_levels = corr_params['img']['levels']
84 | img_corr_radius = corr_params['img']['radius']
85 | out += img_corr_levels * (2*img_corr_radius + 1)**2
86 | return out
87 |
88 | def forward(self, bezier, corr):
89 | cor = F.relu(self.convc1(corr))
90 | cor = F.relu(self.convc2(cor))
91 | bez = F.relu(self.convf1(bezier))
92 | bez = F.relu(self.convf2(bez))
93 |
94 | cor_bez = torch.cat([cor, bez], dim=1)
95 |
96 | out = F.relu(self.conv(cor_bez))
97 | return torch.cat([out, bezier], dim=1)
98 |
99 |
100 | class BasicUpdateBlock(nn.Module):
101 | def __init__(self, model_params: Dict[str, Any], hidden_dim: int=128):
102 | super().__init__()
103 | motion_encoder_output_dim = model_params['motion']['dim']
104 | context_dim = model_params['context']['dim']
105 | bezier_degree = model_params['bezier_degree']
106 | self.encoder = BasicMotionEncoder(model_params, output_dim=motion_encoder_output_dim)
107 | gru_input_dim = context_dim + motion_encoder_output_dim
108 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=gru_input_dim)
109 | self.bezier_head = BezierHead(bezier_degree, input_dim=hidden_dim, hidden_dim=256)
110 |
111 | self.mask = nn.Sequential(
112 | nn.Conv2d(hidden_dim, 256, 3, padding=1),
113 | nn.ReLU(inplace=True),
114 | nn.Conv2d(256, 64*9, 1, padding=0))
115 |
116 | def forward(self, net, inp, corr, bezier):
117 | # TODO: check if we can simplify this similar to the DROID-SLAM update block
118 | motion_features = self.encoder(bezier, corr)
119 | inp = torch.cat([inp, motion_features], dim=1)
120 |
121 | net = self.gru(net, inp)
122 | delta_bezier = self.bezier_head(net)
123 |
124 | # scale mask to balance gradients
125 | mask = .25 * self.mask(net)
126 | return net, mask, delta_bezier
127 |
--------------------------------------------------------------------------------
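The number of correlation channels entering the motion encoder is fully determined by the pyramid configuration: each target contributes levels * (2*radius + 1)**2 channels. A small sketch of `_num_cor_planes` with assumed values (not taken from any config):

    from models.raft_spline.update import BasicMotionEncoder

    corr_params = {'ev': {'levels': [1, 2], 'radius': [4, 4]},
                   'img': {'levels': 4, 'radius': 4}}
    # events only: 1*(2*4+1)**2 + 2*(2*4+1)**2 = 243
    print(BasicMotionEncoder._num_cor_planes(corr_params, use_boundary_images=False, use_events=True))  # 243
    # events + boundary images: 243 + 4*(2*4+1)**2 = 567
    print(BasicMotionEncoder._num_cor_planes(corr_params, use_boundary_images=True, use_events=True))   # 567
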
/models/raft_utils/corr.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Union, List, Tuple, Optional
4 |
5 | import torch
6 | import torch as th
7 | import torch.nn.functional as F
8 | from omegaconf import ListConfig
9 |
10 | from models.raft_utils.utils import bilinear_sampler
11 |
12 |
13 | class CorrData:
14 | def __init__(self,
15 | corr: th.Tensor,
16 | batch_size: int):
17 | assert isinstance(corr, th.Tensor)
18 | num_targets, bhw, dim, ht, wd = corr.shape
19 | assert dim == 1
20 | assert isinstance(batch_size, int)
21 | assert batch_size >= 1
22 |
23 | # corr: num_targets, batch_size*ht*wd, 1, ht, wd
24 | self._corr = corr
25 | self._batch_size = batch_size
26 | self._target_indices = None
27 |
28 | @property
29 | def corr(self) -> th.Tensor:
30 | assert self._corr.ndim == 5
31 | return self._corr
32 |
33 | @property
34 | def corr_batched(self) -> th.Tensor:
35 | num_targets, bhw, dim, ht, wd = self.corr.shape
36 | assert dim == 1
37 | return self.corr.view(-1, dim, ht, wd)
38 |
39 | @property
40 | def batch_size(self) -> int:
41 | return self._batch_size
42 |
43 | @property
44 | def device(self) -> th.device:
45 | return self.corr.device
46 |
47 | @property
48 | def target_indices(self) -> th.Tensor:
49 | assert self._target_indices is not None
50 | return self._target_indices
51 |
52 | @target_indices.setter
53 | def target_indices(self, values: Union[List[int], th.Tensor]):
54 | if isinstance(values, list):
55 | values = sorted(values)
56 |             values = th.tensor(values, dtype=th.int64, device=self.device)
57 | else:
58 | assert values.device == self.device
59 | assert values.dtype == th.int64
60 | assert values.ndim == 1
61 | assert values.numel() > 0
62 | assert values.min() >= 0
63 | assert th.all(values[1:] - values[:-1] > 0)
64 | assert self.corr.shape[0] == values.numel()
65 | self._target_indices = values
66 |
67 | def init_target_indices(self):
68 | num_targets = self.corr.shape[0]
69 | self.target_indices = torch.arange(num_targets, device=self.device)
70 |
71 | @staticmethod
72 | def _extract_database_indices_from_query_matches(query_values: torch.Tensor, database_values: torch.Tensor) -> th.Tensor:
73 | '''
74 | query_values: set of unique integers (assumed in ascending order)
75 | database_values: set of unique integers (assumed in ascending order)
76 |
77 | This function returns the database indices where query values match the database values.
78 | Example:
79 | query_values: torch.tensor([4, 5, 9])
80 | database_values: torch.tensor([1, 4, 5, 6, 9])
81 | returns: torch.tensor([1, 2, 4])
82 | '''
83 | assert database_values.dtype == torch.int64
84 | assert query_values.dtype == torch.int64
85 | assert database_values.ndim == 1
86 | assert query_values.ndim == 1
87 | assert torch.equal(database_values, torch.unique_consecutive(database_values))
88 | assert torch.equal(query_values, torch.unique_consecutive(query_values))
89 | device = query_values.device
90 | assert device == database_values.device
91 |
92 | num_db = database_values.numel()
93 | num_q = query_values.numel()
94 |
95 | database_values_expanded = database_values.expand(num_q, -1)
96 | query_values_expanded = query_values.expand(num_db, -1).transpose(0, 1)
97 |
98 | compare = torch.eq(database_values_expanded, query_values_expanded)
99 |
100 | indices = torch.arange(num_db, device=device)
101 | indices = indices.expand(num_q, -1)
102 | assert torch.all(compare.sum(1) == 1), f'compare =\n{compare}'
103 | # If the previous assertion fails it likely is the case that the query values are not a subset of database values.
104 | out = indices[compare]
105 | assert torch.equal(query_values, database_values[out])
106 | return out
107 |
108 | def get_downsampled(self, target_indices: Union[List[int], th.Tensor]) -> CorrData:
109 | if isinstance(target_indices, list):
110 | target_indices = sorted(target_indices)
111 | target_indices = th.tensor(target_indices, device=self.device)
112 |
113 | indices_to_select = self._extract_database_indices_from_query_matches(target_indices, self.target_indices)
114 | corr_selected = th.index_select(self.corr, dim=0, index=indices_to_select)
115 |
116 | num_new_targets, bhw, dim, ht, wd = corr_selected.shape
117 | assert dim == 1
118 | corr_selected = corr_selected.reshape(-1, dim, ht, wd)
119 | corr_down = F.avg_pool2d(corr_selected, 2, stride=2)
120 | _, _, ht_down, wd_down = corr_down.shape
121 | corr_down = corr_down.view(num_new_targets, bhw, dim, ht_down, wd_down)
122 |
123 | out = CorrData(corr_down, self.batch_size)
124 | out.target_indices = target_indices
125 | return out
126 |
127 | class CorrComputation:
128 | def __init__(self,
129 | fmap1: Union[th.Tensor, List[th.Tensor]],
130 | fmap2: Union[th.Tensor, List[th.Tensor]],
131 | num_levels_per_target: Union[int, List[int], List[th.Tensor]]):
132 | '''
133 | fmap1: batch, dim, ht, wd OR
134 | List[tensor(batch, dim, ht, wd)]
135 | fmap2: batch, dim, ht, wd OR
136 | num_targets, batch, dim, ht, wd OR
137 | List[tensor(num_targets, batch, dim, ht, wd)]
138 | -> in case of list input, len(fmap1) == len(fmap2) is enforced
139 | num_levels_per_target: int OR
140 | List[int] SUCH THAT len(...) == num_targets OR
141 | List[tensor(num_targets)]
142 | '''
143 | self._has_single_reference = isinstance(fmap1, th.Tensor)
144 | if isinstance(num_levels_per_target, int):
145 | num_levels_per_target = [num_levels_per_target]
146 | else:
147 | assert isinstance(num_levels_per_target, list) or isinstance(num_levels_per_target, ListConfig)
148 | if self._has_single_reference:
149 | assert fmap1.ndim == 4
150 | assert isinstance(fmap2, th.Tensor)
151 | if fmap2.ndim == 4:
152 | # This is the case where we also only have a single target
153 | fmap2 = fmap2.unsqueeze(0)
154 | else:
155 | assert fmap2.ndim == 5
156 | assert fmap1.shape == fmap2.shape[1:]
157 | fmap1 = [fmap1]
158 | fmap2 = [fmap2]
159 | assert len(num_levels_per_target) == fmap2[0].shape[0]
160 | for x in num_levels_per_target:
161 | assert isinstance(x, int)
162 | num_levels_per_target = [th.tensor(num_levels_per_target)]
163 | else:
164 | assert isinstance(fmap1, list)
165 | assert isinstance(fmap2, list)
166 | assert len(fmap1) == len(fmap2)
167 | assert len(num_levels_per_target) == len(fmap2)
168 | for f1, f2, num_lvls in zip(fmap1, fmap2, num_levels_per_target):
169 | assert isinstance(f1, th.Tensor)
170 | assert f1.ndim == 4
171 | assert isinstance(f2, th.Tensor)
172 | assert f2.ndim == 5
173 | assert f1.shape == f2.shape[1:]
174 | assert isinstance(num_lvls, th.Tensor)
175 | assert num_lvls.dtype == th.int64
176 | assert num_lvls.numel() == f2.shape[0]
177 | self._fmap1 = fmap1
178 | self._fmap2 = fmap2
179 |
180 | self._num_targets_per_reference = [x.shape[0] for x in fmap2]
181 | self._num_targets_overall = sum(self._num_targets_per_reference)
182 | assert self._num_targets_overall == th.cat(num_levels_per_target).numel()
183 | self._num_levels_per_target = num_levels_per_target
184 |
185 | self._bdhw = fmap1[0].shape
186 |
187 | @property
188 | def batch(self) -> int:
189 | return self._bdhw[0]
190 |
191 | @property
192 | def dim(self) -> int:
193 | return self._bdhw[1]
194 |
195 | @property
196 | def height(self) -> int:
197 | return self._bdhw[2]
198 |
199 | @property
200 | def width(self) -> int:
201 | return self._bdhw[3]
202 |
203 | @property
204 | def num_references(self) -> int:
205 | return len(self._fmap1)
206 |
207 | @property
208 | def num_levels_per_target(self) -> List[th.Tensor]:
209 | return self._num_levels_per_target
210 |
211 | @property
212 | def num_levels_per_target_merged(self) -> th.Tensor:
213 | return th.cat(self._num_levels_per_target)
214 |
215 | @property
216 | def num_targets_per_reference(self) -> List[int]:
217 | return self._num_targets_per_reference
218 |
219 | @property
220 | def num_targets_overall(self) -> int:
221 | return self._num_targets_overall
222 |
223 | def __add__(self, other: CorrComputation) -> CorrComputation:
224 | fmap1 = self._fmap1 + other._fmap1
225 | fmap2 = self._fmap2 + other._fmap2
226 | num_levels_per_target = self._num_levels_per_target + other._num_levels_per_target
227 | return CorrComputation(fmap1=fmap1, fmap2=fmap2, num_levels_per_target=num_levels_per_target)
228 |
229 | def get_correlation_volume(self) -> CorrData:
230 | # Note: we could also just use the more general N to N code but this should be slightly faster and easier to understand.
231 | if self._has_single_reference:
232 | # case: 1 to many, which includes the 1 to 1 special case
233 | return self._corr_dot_prod_1_to_N()
234 | # case: many to many
235 | return self._corr_dot_prod_M_to_N()
236 |
237 | def _corr_dot_prod_1_to_N(self) -> CorrData:
238 | # 1 to 1 if num_targets_sum is 1, which is a special case of this code.
239 | assert len(self._fmap1) == 1
240 | assert len(self._fmap2) == 1
241 | num_targets_sum = self.num_targets_overall
242 | fmap1 = self._fmap1[0].view(self.batch, self.dim, self.height*self.width)
243 | fmap2 = self._fmap2[0].view(num_targets_sum, self.batch, self.dim, self.height*self.width)
244 |
245 | corr_data = self._corr_dot_prod_util(fmap1, fmap2)
246 | return corr_data
247 |
248 | def _corr_dot_prod_M_to_N(self) -> CorrData:
249 | assert len(self._fmap1) > 1
250 | assert len(self._fmap2) > 1
251 | assert len(self._fmap1) == len(self._fmap2)
252 |
253 | # fmap{1,2}: num_targets_overall, batch, dim, height, width
254 | fmap1 = th.cat([x.expand(num_trgts, -1, -1, -1, -1) for x, num_trgts in zip(self._fmap1, self.num_targets_per_reference)], dim=0)
255 | fmap2 = th.cat(self._fmap2, dim=0)
256 |
257 | num_targets_overall = self.num_targets_overall
258 | fmap1 = fmap1.view(num_targets_overall, self.batch, self.dim, self.height*self.width)
259 | fmap2 = fmap2.view(num_targets_overall, self.batch, self.dim, self.height*self.width)
260 |
261 | corr_data = self._corr_dot_prod_util(fmap1, fmap2)
262 | return corr_data
263 |
264 | def _corr_dot_prod_util(self, fmap1: th.Tensor, fmap2: th.Tensor) -> CorrData:
265 | # corr: num_targets_sum, batch, ht*wd, ht*wd
266 | corr = fmap1.transpose(-1, -2) @ fmap2
267 | corr = corr / th.sqrt(th.tensor(self.dim, device=corr.device).float())
268 | corr = corr.view(self.num_targets_overall, self.batch * self.height * self.width, 1, self.height, self.width)
269 |
270 | out = CorrData(corr=corr, batch_size=self.batch)
271 | out.init_target_indices()
272 | return out
273 |
274 |
275 | class CorrBlockParallelMultiTarget:
276 | def __init__(self,
277 | corr_computation_events: Optional[CorrComputation]=None,
278 | corr_computation_frames: Optional[CorrComputation]=None,
279 | radius: int=4):
280 | do_events = corr_computation_events is not None
281 | do_frames = corr_computation_frames is not None
282 | assert do_events or do_frames
283 | assert radius >= 1
284 |
285 | if do_events and not do_frames:
286 | corr_computation = corr_computation_events
287 | elif do_frames and not do_events:
288 | corr_computation = corr_computation_frames
289 | else:
290 | assert do_events and do_frames
291 | corr_computation = corr_computation_events + corr_computation_frames
292 |
293 | num_levels_per_target = corr_computation.num_levels_per_target_merged.tolist()
294 | self._num_targets_base = len(num_levels_per_target)
295 | self._radius = radius
296 |
297 | corr_base_data = corr_computation.get_correlation_volume()
298 |
299 | max_num_levels = max(num_levels_per_target)
300 |
301 | self._corr_pyramid = [corr_base_data]
302 | for num_levels in range(2, max_num_levels + 1):
303 | target_idx_list = [idx for idx, val in enumerate(num_levels_per_target) if val >= num_levels]
304 | corr_data = self._corr_pyramid[-1].get_downsampled(target_indices=target_idx_list)
305 | self._corr_pyramid.append(corr_data)
306 |
307 | def __call__(self, coords: Union[th.Tensor, List[th.Tensor], Tuple[th.Tensor]]) -> th.Tensor:
308 | if isinstance(coords, list) or isinstance(coords, tuple):
309 | # num_targets_base, N, 2, H, W
310 | coords = th.stack(coords, dim=0)
311 | assert coords.ndim == 5
312 | # num_targets_base, N, H, W, 2
313 | coords = coords.permute(0, 1, 3, 4, 2)
314 | num_targets_base, batch, h1, w1, _ = coords.shape
315 | assert num_targets_base == self._num_targets_base
316 |
317 | r = self._radius
318 | out_pyramid = []
319 | for idx, corr_data in enumerate(self._corr_pyramid):
320 | target_indices = corr_data.target_indices
321 | assert target_indices.ndim == 1
322 | coords_selected = th.index_select(coords, dim=0, index=target_indices)
323 | num_targets = target_indices.numel()
324 |
325 | dx = th.linspace(-r, r, 2*r+1, device=coords.device)
326 | dy = th.linspace(-r, r, 2*r+1, device=coords.device)
327 | # delta: 2*r+1, 2*r+1, 2
328 | # NOTE: Unlike in the original implementation, we change the order
329 | # such that delta[..., 0] corresponds to x and delta[..., 1] corresponds to y
330 | # In fact, it does not matter since the same targets are looked up and then flattened and fed as channels.
331 | delta = th.stack(th.meshgrid(dy, dx)[::-1], dim=-1)
332 |
333 | centroid_lvl = coords_selected.reshape(num_targets*batch*h1*w1, 1, 1, 2) / 2**idx
334 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
335 | # coords_lvl: num_targets*batch*h1*w1, 2*r+1, 2*r+1, 2
336 | coords_lvl = centroid_lvl + delta_lvl
337 |
338 | # (reminder) corr_batched: num_targets*batch*h1*w1, 1, h2, w2
339 | # corr_feat: num_targets*batch*h1*w1, 1, 2*r+1, 2*r+1
340 | corr_feat = bilinear_sampler(corr_data.corr_batched, coords_lvl)
341 | # corr_feat: num_targets, batch, h1, w1, (2*r+1)**2
342 | corr_feat = corr_feat.view(num_targets, batch, h1, w1, -1)
343 | out_pyramid.append(corr_feat)
344 |
345 | # out: (num_targets_at_lvl_0 + ... + num_targets_at_lvl_final), batch, h1, w1, (2*r+1)**2
346 | out = th.cat(out_pyramid, dim=0)
347 | # out: batch, (num_targets_at_lvl_0 + ... + num_targets_at_lvl_final), (2*r+1)**2, h1, w1
348 | out = out.permute(1, 0, 4, 2, 3)
349 | # out: batch, (num_targets_at_lvl_0 + ... + num_targets_at_lvl_final)*(2*r+1)**2, h1, w1
350 | out = out.reshape(batch, -1, h1, w1).float()
351 | return out
--------------------------------------------------------------------------------
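The helper `_extract_database_indices_from_query_matches` is what lets a downsampled pyramid level select the right subset of target volumes. A quick sketch reproducing the example from its docstring, run from the repository root:

    import torch
    from models.raft_utils.corr import CorrData

    query = torch.tensor([4, 5, 9])
    database = torch.tensor([1, 4, 5, 6, 9])
    idx = CorrData._extract_database_indices_from_query_matches(query, database)
    print(idx)              # tensor([1, 2, 4])
    print(database[idx])    # tensor([4, 5, 9]) -> equals the query values
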
/models/raft_utils/extractor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class ResidualBlock(nn.Module):
6 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
7 | super(ResidualBlock, self).__init__()
8 |
9 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
10 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
11 | self.relu = nn.ReLU(inplace=True)
12 |
13 | num_groups = planes // 8
14 |
15 | if norm_fn == 'group':
16 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
17 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
18 | if not stride == 1:
19 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
20 |
21 | elif norm_fn == 'batch':
22 | self.norm1 = nn.BatchNorm2d(planes)
23 | self.norm2 = nn.BatchNorm2d(planes)
24 | if not stride == 1:
25 | self.norm3 = nn.BatchNorm2d(planes)
26 |
27 | elif norm_fn == 'instance':
28 | self.norm1 = nn.InstanceNorm2d(planes)
29 | self.norm2 = nn.InstanceNorm2d(planes)
30 | if not stride == 1:
31 | self.norm3 = nn.InstanceNorm2d(planes)
32 |
33 | elif norm_fn == 'none':
34 | self.norm1 = nn.Sequential()
35 | self.norm2 = nn.Sequential()
36 | if not stride == 1:
37 | self.norm3 = nn.Sequential()
38 |
39 | if stride == 1:
40 | self.downsample = None
41 |
42 | else:
43 | self.downsample = nn.Sequential(
44 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
45 |
46 |
47 | def forward(self, x):
48 | y = x
49 | y = self.relu(self.norm1(self.conv1(y)))
50 | y = self.relu(self.norm2(self.conv2(y)))
51 |
52 | if self.downsample is not None:
53 | x = self.downsample(x)
54 |
55 | return self.relu(x+y)
56 |
57 |
58 | class BasicEncoder(nn.Module):
59 | def __init__(self, input_dim=3, output_dim=128, norm_fn='batch'):
60 | super(BasicEncoder, self).__init__()
61 | self.norm_fn = norm_fn
62 |
63 | if self.norm_fn == 'group':
64 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
65 | elif self.norm_fn == 'batch':
66 | self.norm1 = nn.BatchNorm2d(64)
67 | elif self.norm_fn == 'instance':
68 | self.norm1 = nn.InstanceNorm2d(64)
69 | elif self.norm_fn == 'none':
70 | self.norm1 = nn.Sequential()
71 | else:
72 | raise NotImplementedError
73 |
74 | self.conv1 = nn.Conv2d(input_dim, 64, kernel_size=7, stride=2, padding=3)
75 | self.relu1 = nn.ReLU(inplace=True)
76 |
77 | self.in_planes = 64
78 | self.layer1 = self._make_layer(64, stride=1)
79 | self.layer2 = self._make_layer(96, stride=2)
80 | self.layer3 = self._make_layer(128, stride=2)
81 |
82 | # output convolution
83 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
84 |
85 | for m in self.modules():
86 | if isinstance(m, nn.Conv2d):
87 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
88 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
89 | if m.weight is not None:
90 | nn.init.constant_(m.weight, 1)
91 | if m.bias is not None:
92 | nn.init.constant_(m.bias, 0)
93 |
94 | def _make_layer(self, dim, stride=1):
95 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
96 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
97 | layers = (layer1, layer2)
98 |
99 | self.in_planes = dim
100 | return nn.Sequential(*layers)
101 |
102 |
103 | def forward(self, x):
104 |
105 | # if input is list, combine batch dimension
106 | is_list = isinstance(x, tuple) or isinstance(x, list)
107 | if is_list:
108 | batch_dim = x[0].shape[0]
109 | length = len(x)
110 | x = torch.cat(x, dim=0)
111 |
112 | x = self.conv1(x)
113 | x = self.norm1(x)
114 | x = self.relu1(x)
115 |
116 | x = self.layer1(x)
117 | x = self.layer2(x)
118 | x = self.layer3(x)
119 |
120 | x = self.conv2(x)
121 |
122 | if is_list:
123 | x = torch.split(x, [batch_dim]*length, dim=0)
124 |
125 | return x
--------------------------------------------------------------------------------
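BasicEncoder downsamples by a factor of 8 overall (a stride-2 stem followed by two stride-2 residual stages). A small shape-check sketch with arbitrary input sizes:

    import torch
    from models.raft_utils.extractor import BasicEncoder

    enc = BasicEncoder(input_dim=3, output_dim=128, norm_fn='instance')
    x = torch.randn(2, 3, 64, 96)
    print(enc(x).shape)                       # torch.Size([2, 128, 8, 12])

    # list inputs are concatenated along the batch dimension and split again on output
    y0, y1 = enc([torch.randn(2, 3, 64, 96), torch.randn(2, 3, 64, 96)])
    print(y0.shape, y1.shape)                 # torch.Size([2, 128, 8, 12]) twice
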
/models/raft_utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def bilinear_sampler(img, coords):
6 | """ Wrapper for grid_sample, uses pixel coordinates """
7 | # img (corr): batch*h1*w1, 1, h2, w2
8 | # coords: batch*h1*w1, 2*r+1, 2*r+1, 2
9 | H, W = img.shape[-2:]
10 | # *grid: batch*h1*w1, 2*r+1, 2*r+1, 1
11 | xgrid, ygrid = coords.split([1,1], dim=-1)
12 | # map *grid from [0, N-1] to [-1, 1]
13 | xgrid = 2*xgrid/(W-1) - 1
14 | ygrid = 2*ygrid/(H-1) - 1
15 |
16 | # grid: batch*h1*w1, 2*r+1, 2*r+1, 2
17 | grid = torch.cat([xgrid, ygrid], dim=-1)
18 | # img: batch*h1*w1, 1, 2*r+1, 2*r+1
19 | img = F.grid_sample(img, grid, align_corners=True)
20 |
21 | return img
22 |
23 |
24 | def coords_grid(batch, ht, wd, device):
25 | # ((ht, wd), (ht, wd))
26 | coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device))
27 | # 2, ht, wd
28 | coords = torch.stack(coords[::-1], dim=0).float()
29 | # batch, 2, ht, wd
30 | return coords[None].repeat(batch, 1, 1, 1)
31 |
32 |
33 | def cvx_upsample(data: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
34 | """ Upsample data [N, dim, H/8, W/8] -> [N, dim, H, W] using convex combination """
35 | N, dim, H, W = data.shape
36 | mask = mask.view(N, 1, 9, 8, 8, H, W)
37 | mask = torch.softmax(mask, dim=2)
38 |
39 | # NOTE: multiply by 8 due to the change in resolution.
40 | up_data = F.unfold(8 * data, [3, 3], padding=1)
41 | up_data = up_data.view(N, dim, 9, 1, 1, H, W)
42 |
43 | # N, dim, 8, 8, H, W
44 | up_data = torch.sum(mask * up_data, dim=2)
45 | # N, dim, H, 8, W, 8
46 | up_data = up_data.permute(0, 1, 4, 2, 5, 3)
47 | # N, dim, H*8, W*8
48 | return up_data.reshape(N, dim, 8*H, 8*W)
49 |
--------------------------------------------------------------------------------
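`cvx_upsample` expects a mask with 9*8*8 = 576 channels per pixel: a softmax over the 3x3 neighborhood for each of the 8x8 output positions. A shape-only sketch with arbitrary values:

    import torch
    from models.raft_utils.utils import cvx_upsample, coords_grid

    data = torch.randn(1, 2, 8, 8)         # e.g. a flow field at 1/8 resolution
    mask = torch.randn(1, 576, 8, 8)       # 9 * 8 * 8 convex-combination weights
    print(cvx_upsample(data, mask).shape)  # torch.Size([1, 2, 64, 64])

    coords = coords_grid(batch=1, ht=8, wd=8, device=torch.device('cpu'))
    print(coords.shape)                    # torch.Size([1, 2, 8, 8]); coords[:, 0] is x, coords[:, 1] is y
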
/modules/data_loading.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Dict, Any, Union
3 |
4 | import pytorch_lightning as pl
5 | from omegaconf import DictConfig
6 | from torch.utils.data import DataLoader
7 |
8 | from data.dsec.provider import DatasetProvider as DatasetProviderDSEC
9 | from data.multiflow2d.provider import DatasetProvider as DatasetProviderMULTIFLOW2D
10 |
11 |
12 | class DataModule(pl.LightningDataModule):
13 | DSEC_STR = 'dsec'
14 | MULTIFLOW2D_REGEN_STR = 'multiflow_regen'
15 |
16 | def __init__(self,
17 | config: Union[Dict[str, Any], DictConfig],
18 | batch_size_train: int = 1,
19 | batch_size_val: int = 1):
20 | super().__init__()
21 | dataset_params = config['dataset']
22 | dataset_type = dataset_params['name']
23 | num_workers = config['hardware']['num_workers']
24 |
25 | assert dataset_type in {self.DSEC_STR, self.MULTIFLOW2D_REGEN_STR}
26 | self.dataset_type = dataset_type
27 |
28 | self.batch_size_train = batch_size_train
29 | self.batch_size_val = batch_size_val
30 |
31 | assert self.batch_size_train >= 1
32 | assert self.batch_size_val >= 1
33 |
34 | if num_workers is None:
35 | num_workers = 2*max([batch_size_train, batch_size_val])
36 | num_workers = min([num_workers, os.cpu_count()])
37 | print(f'num_workers: {num_workers}')
38 |
39 | self.num_workers = num_workers
40 | assert self.num_workers >= 0
41 |
42 | nbins_context = config['model']['num_bins']['context']
43 |
44 | if dataset_type == self.DSEC_STR:
45 | dataset_provider = DatasetProviderDSEC(dataset_params, nbins_context)
46 | elif dataset_type == self.MULTIFLOW2D_REGEN_STR:
47 | dataset_provider = DatasetProviderMULTIFLOW2D(dataset_params, nbins_context)
48 | else:
49 | raise NotImplementedError
50 | self.train_dataset = dataset_provider.get_train_dataset()
51 | if dataset_type == self.DSEC_STR:
52 | self.val_dataset = None
53 | self.test_dataset = dataset_provider.get_test_dataset()
54 | else:
55 | self.val_dataset = dataset_provider.get_val_dataset()
56 | self.test_dataset = None
57 |
58 | self.nbins_context = dataset_provider.get_nbins_context()
59 | self.nbins_correlation = dataset_provider.get_nbins_correlation()
60 |
61 | assert self.nbins_context == nbins_context
62 | # Fill in nbins_correlation here because it can depend on the dataset.
63 | if 'correlation' in config['model']['num_bins']:
64 | nbins_correlation = config['model']['num_bins']['correlation']
65 | if nbins_correlation is None:
66 | config['model']['num_bins']['correlation'] = self.nbins_correlation
67 | else:
68 | assert nbins_correlation == self.nbins_correlation
69 |
70 | def train_dataloader(self):
71 | return DataLoader(
72 | dataset=self.train_dataset,
73 | batch_size=self.batch_size_train,
74 | shuffle=True,
75 | num_workers=self.num_workers,
76 | pin_memory=True,
77 | drop_last=True)
78 |
79 | def val_dataloader(self):
80 | assert self.val_dataset is not None, f'No validation data found for {self.dataset_type} dataset'
81 | return DataLoader(
82 | dataset=self.val_dataset,
83 | batch_size=self.batch_size_val,
84 | shuffle=False,
85 | num_workers=self.num_workers,
86 | pin_memory=True,
87 | drop_last=True)
88 |
89 | def test_dataloader(self):
90 | assert self.test_dataset is not None, f'No test data found for {self.dataset_type} dataset'
91 | return DataLoader(
92 | dataset=self.test_dataset,
93 | batch_size=1,
94 | shuffle=False,
95 | num_workers=self.num_workers,
96 | pin_memory=True,
97 | drop_last=False)
98 |
99 | def get_nbins_context(self):
100 | return self.nbins_context
101 |
102 | def get_nbins_correlation(self):
103 | return self.nbins_correlation
104 |
--------------------------------------------------------------------------------
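DataModule only reads a handful of top-level config keys directly; everything inside config['dataset'] is forwarded to the dataset providers. A sketch of the minimal structure it expects (the providers additionally need dataset-specific entries such as data paths, which are omitted here as they are not visible in this file):

    config = {
        'dataset': {'name': 'multiflow_regen'},   # or 'dsec'; plus provider-specific keys (paths, ...)
        'hardware': {'num_workers': 4},           # None -> derived from the batch sizes and the cpu count
        'model': {'num_bins': {'context': 5, 'correlation': None}},  # correlation is filled in from the dataset
    }
    # data_module = DataModule(config, batch_size_train=4, batch_size_val=4)  # requires the dataset on disk
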
/modules/utils.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Mapping
2 | from functools import wraps
3 |
4 | import torch
5 | import torch.nn.functional as F
6 |
7 |
8 | def detach_tensors(cpu: bool=False):
9 | ''' Detach tensors and optionally move to cpu. Only detaches torch.Tensor instances '''
10 | # Decorator factory to enable decorator arguments: https://stackoverflow.com/a/50538967
11 | def allow_detach(key: str, value, train: bool):
12 | if train and key == 'loss':
13 | return False
14 | return isinstance(value, torch.Tensor)
15 | def detach_tensor(input_tensor: torch.Tensor, cpu: bool):
16 | assert isinstance(input_tensor, torch.Tensor)
17 | if cpu:
18 | return input_tensor.detach().cpu()
19 | return input_tensor.detach()
20 | def decorator(func):
21 | train = 'train' in func.__name__
22 |
23 | @wraps(func)
24 | def inner(*args, **kwargs):
25 | output = func(*args, **kwargs)
26 | if isinstance(output, Mapping):
27 | return {k: (detach_tensor(v, cpu) if allow_detach(k, v, train) else v) for k, v in output.items()}
28 | assert isinstance(output, torch.Tensor)
29 | if train:
30 | # Do not detach because this will be the loss function of the training hook, which must not be detached.
31 | return output
32 | return detach_tensor(output, cpu)
33 | return inner
34 | return decorator
35 |
36 |
37 | def reduce_ev_repr(ev_repr: torch.Tensor) -> torch.Tensor:
38 | # This function is useful to reduce the overhead of moving an event representation
39 | # to CPU for visualization.
40 | # For now simply sum up the time dimension to reduce the memory.
41 | assert isinstance(ev_repr, torch.Tensor)
42 | assert ev_repr.ndim == 4
43 | assert ev_repr.is_cuda
44 |
45 | return torch.sum(ev_repr, dim=1)
46 |
47 |
48 | class InputPadder:
49 | """ Pads input tensor such that the last two dimensions are divisible by min_size """
50 | def __init__(self, min_size: int=8, no_top_padding: bool=False):
51 | assert min_size > 0
52 | self.min_size = min_size
53 | self.no_top_padding = no_top_padding
54 | self._pad = None
55 |
56 | def requires_padding(self, input_tensor: torch.Tensor):
57 | ht, wd = input_tensor.shape[-2:]
58 |         # Padding is required if either spatial dimension is not divisible by min_size.
59 |         answer = ht % self.min_size != 0
60 |         answer |= wd % self.min_size != 0
61 |         return answer
62 |
63 | def pad(self, input_tensor: torch.Tensor):
64 | ht, wd = input_tensor.shape[-2:]
65 | pad_ht = (((ht // self.min_size) + 1) * self.min_size - ht) % self.min_size
66 | pad_wd = (((wd // self.min_size) + 1) * self.min_size - wd) % self.min_size
67 | if self.no_top_padding:
68 | # Pad only bottom instead of top
69 | # RAFT uses this for KITTI
70 | pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
71 | else:
72 | # RAFT uses this for SINTEL (as default too)
73 | pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
74 | if self._pad is None:
75 | self._pad = pad
76 | else:
77 | assert self._pad == pad
78 | return F.pad(input_tensor, self._pad, mode='replicate')
79 |
80 | def unpad(self, input_tensor: torch.Tensor):
81 | ht, wd = input_tensor.shape[-2:]
82 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
83 | return input_tensor[..., c[0]:c[1], c[2]:c[3]]
84 |
--------------------------------------------------------------------------------
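A short usage sketch for InputPadder with arbitrary sizes: the padding amounts are computed once and reused, so the same instance can pad several tensors of identical spatial size and undo the padding afterwards.

    import torch
    from modules.utils import InputPadder

    padder = InputPadder(min_size=8)
    x = torch.randn(1, 3, 100, 150)
    print(padder.requires_padding(x))    # True: neither 100 nor 150 is divisible by 8
    x_pad = padder.pad(x)
    print(x_pad.shape)                   # torch.Size([1, 3, 104, 152])
    print(padder.unpad(x_pad).shape)     # torch.Size([1, 3, 100, 150])
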
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
5 | os.environ["OMP_NUM_THREADS"] = "1"
6 | os.environ["OPENBLAS_NUM_THREADS"] = "1"
7 | os.environ["MKL_NUM_THREADS"] = "1"
8 | os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
9 | os.environ["NUMEXPR_NUM_THREADS"] = "1"
10 |
11 | import hydra
12 | from omegaconf import DictConfig, OmegaConf
13 | import pytorch_lightning as pl
14 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelSummary, ModelCheckpoint
15 | from pytorch_lightning.strategies import DDPStrategy
16 | import torch
17 | import wandb
18 |
19 | from callbacks.logger import WandBImageLoggingCallback
20 | from utils.general import get_ckpt_callback
21 | from loggers.wandb_logger import WandbLogger
22 | from modules.data_loading import DataModule
23 | from modules.raft_spline import RAFTSplineModule
24 |
25 |
26 | @hydra.main(config_path='config', config_name='train', version_base='1.3')
27 | def main(cfg: DictConfig):
28 | print('------ Configuration ------\n')
29 | print(OmegaConf.to_yaml(cfg))
30 | print('---------------------------\n')
31 | config = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
32 | # ------------
33 | # Args
34 | # ------------
35 | gpu_devices = config['hardware']['gpus']
36 | gpus = gpu_devices if isinstance(gpu_devices, list) else [gpu_devices]
37 | num_gpus = len(gpus)
38 |
39 | batch_size: int = config['training']['batch_size']
40 | assert batch_size > 0
41 | per_gpu_batch_size = batch_size
42 | if num_gpus == 1:
43 | strategy = 'auto'
44 | ddp_active = False
45 | else:
46 | strategy = DDPStrategy(process_group_backend='nccl',
47 | find_unused_parameters=False,
48 | gradient_as_bucket_view=True)
49 | ddp_active = True
50 | per_gpu_batch_size = batch_size // num_gpus
51 |         assert_info = 'Batch size ({}) must be divisible by the number of gpus ({})'.format(batch_size, num_gpus)
52 |         assert per_gpu_batch_size * num_gpus == batch_size, assert_info
53 |
54 | limit_train_batches = float(config['training']['limit_train_batches'])
55 | if limit_train_batches > 1.0:
56 | limit_train_batches = int(limit_train_batches)
57 |
58 | limit_val_batches = float(config['training']['limit_val_batches'])
59 | if limit_val_batches > 1.0:
60 | limit_val_batches = int(limit_val_batches)
61 |
62 | # ------------
63 | # Data
64 | # ------------
65 | data_module = DataModule(
66 | config,
67 | batch_size_train=per_gpu_batch_size,
68 | batch_size_val=per_gpu_batch_size)
69 |
70 | num_bins_context = data_module.get_nbins_context()
71 | num_bins_corr = data_module.get_nbins_correlation()
72 | print(f'num_bins:\n\tcontext: {num_bins_context}\n\tcorrelation: {num_bins_corr}')
73 |
74 | # ------------
75 | # Logging
76 | # ------------
77 | wandb_config = config['wandb']
78 | wandb_runpath = wandb_config['wandb_runpath']
79 | if wandb_runpath is None:
80 | wandb_id = wandb.util.generate_id()
81 | print(f'new run: generating id {wandb_id}')
82 | else:
83 | wandb_id = Path(wandb_runpath).name
84 | print(f'using provided id {wandb_id}')
85 | logger = WandbLogger(
86 | project=wandb_config['project_name'],
87 | group=wandb_config['group_name'],
88 | wandb_id=wandb_id,
89 | log_model=True,
90 | save_last_only_final=False,
91 | save_code=True,
92 | config_args=config)
93 | resume_path = None
94 | if wandb_config['artifact_name'] is not None:
95 | artifact_runpath = wandb_config['artifact_runpath']
96 | if artifact_runpath is None:
97 | artifact_runpath = wandb_runpath
98 | if artifact_runpath is None:
99 | print(
100 | 'must specify wandb_runpath or artifact_runpath to restore a checkpoint/artifact. Cannot load artifact.')
101 | else:
102 | artifact_name = wandb_config['artifact_name']
103 | print(f'resuming checkpoint from runpath {artifact_runpath} and artifact name {artifact_name}')
104 | resume_path = logger.get_checkpoint(artifact_runpath, artifact_name)
105 | assert resume_path.exists()
106 | assert resume_path.suffix == '.ckpt', resume_path.suffix
107 |
108 | # ------------
109 | # Checkpoints
110 | # ------------
111 | checkpoint_callback = get_ckpt_callback(config=config)
112 |
113 | # ------------
114 | # Other Callbacks
115 | # ------------
116 | image_callback = WandBImageLoggingCallback(config['logging'])
117 |
118 | callback_list = None
119 | if config['debugging']['profiler'] is None:
120 | callback_list = [checkpoint_callback, image_callback]
121 | if config['training']['lr_scheduler']['use']:
122 | callback_list.append(LearningRateMonitor(logging_interval='step'))
123 | # ------------
124 | # Model
125 | # ------------
126 |
127 | if resume_path is not None and wandb_config['resume_only_weights']:
128 | print('Resuming only the weights instead of the full training state')
129 | net = RAFTSplineModule.load_from_checkpoint(str(resume_path), **{"config": config})
130 | resume_path = None
131 | else:
132 | net = RAFTSplineModule(config)
133 |
134 | # ------------
135 | # Training
136 | # ------------
137 | logger.watch(net, log='all', log_freq=int(config['logging']['log_every_n_steps']), log_graph=True)
138 |
139 | gradient_clip_val = config['training']['gradient_clip_val']
140 | if gradient_clip_val is not None and gradient_clip_val > 0:
141 | for param in net.parameters():
142 | param.register_hook(lambda grad: torch.clamp(grad, -gradient_clip_val, gradient_clip_val))
143 |
144 | # ---------------------
145 | # Callbacks and Misc
146 | # ---------------------
147 |     # Append the model summary to the active callback list; the LearningRateMonitor
148 |     # is already added above when the LR scheduler is enabled.
149 |     if callback_list is not None:
150 |         callback_list.append(ModelSummary(max_depth=2))
151 |
152 | trainer = pl.Trainer(
153 | accelerator='gpu',
154 | strategy=strategy,
155 | enable_checkpointing=True,
156 |         sync_batchnorm=ddp_active,
157 | devices=gpus,
158 | logger=logger,
159 | precision=32,
160 | max_epochs=int(config['training']['max_epochs']),
161 | max_steps=int(config['training']['max_steps']),
162 | profiler=config['debugging']['profiler'],
163 | limit_train_batches=limit_train_batches,
164 | limit_val_batches=limit_val_batches,
165 | log_every_n_steps=int(config['logging']['log_every_n_steps']),
166 | callbacks=callback_list)
167 | trainer.fit(model=net, ckpt_path=resume_path, datamodule=data_module)
168 |
169 |
170 | if __name__ == '__main__':
171 | main()
172 |
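
A minimal sketch (not part of the repository) of the device and batch-size bookkeeping at the top of main(): the global batch size is split evenly across the requested GPUs and must be divisible by their number.

from typing import List, Tuple, Union

def split_batch(gpu_devices: Union[int, List[int]], batch_size: int) -> Tuple[List[int], int]:
    # Normalize a single GPU id to a list, as train.py does.
    gpus = gpu_devices if isinstance(gpu_devices, list) else [gpu_devices]
    per_gpu_batch_size = batch_size // len(gpus)
    assert per_gpu_batch_size * len(gpus) == batch_size, \
        f'Batch size ({batch_size}) must be divisible by number of gpus ({len(gpus)})'
    return gpus, per_gpu_batch_size

print(split_batch([0, 1], 8))  # -> ([0, 1], 4)
print(split_batch(0, 4))       # -> ([0], 4)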
--------------------------------------------------------------------------------
/utils/general.py:
--------------------------------------------------------------------------------
1 | import numbers
2 | from functools import wraps
3 | from typing import Any, Dict, Union, List
4 |
5 | import numpy as np
6 | import torch
7 | from pytorch_lightning.callbacks import ModelCheckpoint
8 |
9 |
10 | def is_cpu(input_: Union[torch.Tensor, List[torch.Tensor]]) -> bool:
11 | if isinstance(input_, torch.Tensor):
12 | return input_.device == torch.device('cpu')
13 | assert isinstance(input_, list)
14 | on_cpu = True
15 | for x in input_:
16 | assert isinstance(x, torch.Tensor)
17 | on_cpu &= x.device == torch.device('cpu')
18 | return on_cpu
19 |
20 |
21 | def _convert_to_tensor(input_: Any):
22 | if input_ is None or isinstance(input_, torch.Tensor):
23 | return input_
24 | if isinstance(input_, np.ndarray):
25 | return torch.from_numpy(input_)
26 | if isinstance(input_, numbers.Number):
27 | return torch.tensor(input_)
28 | if isinstance(input_, dict):
29 | return {k: _convert_to_tensor(v) for k, v in input_.items()}
30 | if isinstance(input_, list):
31 | return [_convert_to_tensor(x) for x in input_]
32 | if isinstance(input_, tuple):
33 |         return tuple(_convert_to_tensor(x) for x in input_)
34 | return input_
35 |
36 |
37 | def inputs_to_tensor(func):
38 | @wraps(func)
39 | def inner(*args, **kwargs):
40 | args = _convert_to_tensor(args)
41 | kwargs = _convert_to_tensor(kwargs)
42 | return func(*args, **kwargs)
43 |
44 | return inner
45 |
46 |
47 | def _obj_has_function(obj, func_name: str):
48 | return hasattr(obj, func_name) and callable(getattr(obj, func_name))
49 |
50 |
51 | def _data_to_cpu(input_: Any):
52 | if input_ is None:
53 | return input_
54 | if isinstance(input_, torch.Tensor):
55 | return input_.cpu()
56 | if isinstance(input_, dict):
57 | return {k: _data_to_cpu(v) for k, v in input_.items()}
58 | if isinstance(input_, list):
59 | return [_data_to_cpu(x) for x in input_]
60 | if isinstance(input_, tuple):
61 |         return tuple(_data_to_cpu(x) for x in input_)
62 | assert _obj_has_function(input_, 'cpu')
63 | return input_.cpu()
64 |
65 |
66 | def to_cpu(func):
67 | ''' Move stuff to cpu '''
68 |
69 | @wraps(func)
70 | def inner(*args, **kwargs):
71 | output = func(*args, **kwargs)
72 | output = _data_to_cpu(output)
73 | return output
74 |
75 | return inner
76 |
77 |
78 | def _unwrap_len1_list_or_tuple(input_: Any):
79 | if isinstance(input_, tuple):
80 | if len(input_) == 1:
81 | return input_[0]
82 |         return tuple(_unwrap_len1_list_or_tuple(x) for x in input_)
83 | if isinstance(input_, list):
84 | if len(input_) == 1:
85 | return input_[0]
86 | return [_unwrap_len1_list_or_tuple(x) for x in input_]
87 | if isinstance(input_, dict):
88 | return {k: _unwrap_len1_list_or_tuple(v) for k, v in input_.items()}
89 | return input_
90 |
91 |
92 | def wrap_unwrap_lists_for_class_method(func):
93 |     # "self" is listed explicitly, so this decorator can only be used on class methods.
94 | @wraps(func)
95 | def inner(self, *args, **kwargs):
96 |         # "self" is kept as a separate argument so that it is not wrapped in a list;
97 |         # otherwise we would have to detect which positional argument is self.
98 | args = [arg if isinstance(arg, list) or arg is None else [arg] for arg in args]
99 | kwargs = {k: (v if isinstance(v, list) or v is None else [v]) for k, v in kwargs.items()}
100 | out = func(self, *args, **kwargs)
101 | out = _unwrap_len1_list_or_tuple(out)
102 | return out
103 |
104 | return inner
105 |
106 |
107 | def get_ckpt_callback(config: Dict) -> ModelCheckpoint:
108 | dataset_name = config['dataset']['name']
109 | assert dataset_name in {'dsec', 'multiflow_regen'}
110 |
111 | if dataset_name == 'dsec':
112 | ckpt_callback_monitor = 'global_step'
113 | ckpt_filename = 'epoch={epoch:03d}-step={' + ckpt_callback_monitor + ':.0f}'
114 | mode = 'max' # because of global_step
115 | else:
116 | prefix = 'val'
117 | metric = 'epe_multi'
118 | ckpt_callback_monitor = f'{prefix}/{metric}'
119 | filename_monitor_string = f'{prefix}_{metric}'
120 | ckpt_filename = 'epoch={epoch:03d}-step={step}-' + filename_monitor_string + '={' + ckpt_callback_monitor + ':.2f}'
121 | mode = 'min'
122 |
123 | checkpoint_callback = ModelCheckpoint(
124 | monitor=ckpt_callback_monitor,
125 | filename=ckpt_filename,
126 | auto_insert_metric_name=False,
127 | save_top_k=1,
128 | mode=mode,
129 | every_n_epochs=config['logging']['ckpt_every_n_epochs'],
130 | save_last=True,
131 | verbose=True,
132 | )
133 | checkpoint_callback.CHECKPOINT_NAME_LAST = 'last_epoch={epoch:03d}-step={step}'
134 | return checkpoint_callback
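
A minimal usage sketch for the conversion decorators above (not part of the repository; assumes the repository root is on PYTHONPATH): inputs_to_tensor converts numpy arrays and Python numbers in the arguments to torch tensors, and to_cpu moves the returned tensors back to the CPU.

import numpy as np
import torch

from utils.general import inputs_to_tensor, to_cpu

@to_cpu
@inputs_to_tensor
def add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Receives tensors even when called with numpy/scalar inputs.
    return a + b

result = add(np.ones(3), 2.0)
print(result, result.device)  # tensor([3., 3., 3.], dtype=torch.float64) cpu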
--------------------------------------------------------------------------------
/utils/losses.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | import torch as th
4 |
5 |
6 | def l1_loss_channel_masked(source: th.Tensor, target: th.Tensor, valid_mask: Optional[th.Tensor]=None):
7 | # source: (N, C, *)
8 | # target: (N, C, *), source.shape == target.shape
9 | # valid_mask: (N, *), where the channel dimension is missing. I.e. valid_mask.ndim == target.ndim - 1
10 | assert source.ndim > 2
11 | assert source.shape == target.shape
12 |
13 | loss = th.abs(source - target).sum(1)
14 | if valid_mask is not None:
15 | assert valid_mask.shape[0] == target.shape[0]
16 | assert valid_mask.ndim == target.ndim - 1
17 | assert valid_mask.dtype == th.bool
18 | assert loss.shape == valid_mask.shape
19 | loss_masked = loss[valid_mask].sum() / valid_mask.sum()
20 | return loss_masked
21 | return th.mean(loss)
22 |
23 |
24 | def l1_seq_loss_channel_masked(source_list: List[th.Tensor], target: th.Tensor, valid_mask: Optional[th.Tensor]=None, gamma: float=0.8):
25 | # source: [(N, C, *), ...], I predictions from I iterations
26 | # target: (N, C, *), source.shape == target.shape
27 | # valid_mask: (N, *), where the channel dimension is missing. I.e. valid_mask.ndim == target.ndim - 1
28 |
29 |     # Adapted from https://github.com/princeton-vl/RAFT/blob/224320502d66c356d88e6c712f38129e60661e80/train.py#L47
30 |
31 | n_predictions = len(source_list)
32 | loss = 0
33 |
34 | for i in range(n_predictions):
35 | i_weight = gamma**(n_predictions - i - 1)
36 | i_loss = l1_loss_channel_masked(source_list[i], target, valid_mask)
37 | loss += i_weight * i_loss
38 |
39 | return loss
40 |
41 | def l1_multi_seq_loss_channel_masked(src_list_list: List[List[th.Tensor]], target_list: List[th.Tensor], valid_mask_list: Optional[List[th.Tensor]]=None, gamma: float=0.8):
42 | # src_list_list: [[(N, C, *), ...], ...], I*M predictions -> I iterations (outer) and M supervision targets (inner)
43 | # target_list: [(N, C, *), ...], M supervision targets
44 | # valid_mask_list: [(N, *), ...], M supervision targets
45 |
46 | loss = 0
47 |
48 | num_iters = len(src_list_list)
49 | for iter_idx, sources_per_iter in enumerate(src_list_list):
50 | # iteration over RAFT iterations
51 | num_targets = len(sources_per_iter)
52 | assert num_targets > 0
53 | assert num_targets == len(target_list)
54 | i_loss = 0
55 | for tindex, source in enumerate(sources_per_iter):
56 | # iteration over prediction times
57 | i_loss += l1_loss_channel_masked(source, target_list[tindex], valid_mask_list[tindex] if valid_mask_list is not None else None)
58 |
59 | i_loss /= num_targets
60 | i_weight = gamma**(num_iters - iter_idx - 1)
61 | loss += i_weight * i_loss
62 |
63 | return loss
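
A minimal sketch (not part of the repository; assumes the repository root is on PYTHONPATH) of calling the multi-target sequence loss with I = 3 RAFT iterations and M = 2 supervision timestamps on toy tensors; later iterations receive exponentially larger weights gamma**(I - i - 1).

import torch as th

from utils.losses import l1_multi_seq_loss_channel_masked

N, C, H, W = 2, 2, 8, 8
target_list = [th.randn(N, C, H, W) for _ in range(2)]               # M = 2 targets
src_list_list = [[t + 0.1 * th.randn_like(t) for t in target_list]   # I = 3 iterations,
                 for _ in range(3)]                                   # each with M predictions
valid_mask_list = [th.rand(N, H, W) > 0.5 for _ in range(2)]          # per-target validity masks
loss = l1_multi_seq_loss_channel_masked(src_list_list, target_list, valid_mask_list, gamma=0.8)
print(loss)  # scalar tensor; the iteration weights here are 0.64, 0.8 and 1.0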
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import List, Optional
3 |
4 | import torch as th
5 | from torchmetrics import Metric
6 |
7 | from utils.losses import l1_loss_channel_masked
8 |
9 |
10 | class L1ChannelMasked(Metric):
11 | def __init__(self, dist_sync_on_step=False):
12 | super().__init__(dist_sync_on_step=dist_sync_on_step)
13 |
14 | self.add_state("l1", default=th.tensor(0, dtype=th.float64), dist_reduce_fx="sum")
15 | self.add_state("total", default=th.tensor(0, dtype=th.int64), dist_reduce_fx="sum")
16 |
17 | def update(self, source: th.Tensor, target: th.Tensor, valid_mask: Optional[th.Tensor]=None):
18 | # source (prediction): (N, C, *),
19 | # target (ground truth): (N, C, *), source.shape == target.shape
20 | # valid_mask: (N, *), where the channel dimension is missing. I.e. valid_mask.ndim == target.ndim - 1
21 |
22 | self.l1 += l1_loss_channel_masked(source, target, valid_mask).double()
23 | self.total += 1
24 |
25 | def compute(self):
26 | assert self.total > 0
27 | return (self.l1 / self.total).float()
28 |
29 |
30 | class EPE(Metric):
31 | def __init__(self, dist_sync_on_step=False):
32 | super().__init__(dist_sync_on_step=dist_sync_on_step)
33 |
34 | self.add_state("epe", default=th.tensor(0, dtype=th.float64), dist_reduce_fx="sum")
35 | self.add_state("total", default=th.tensor(0, dtype=th.int64), dist_reduce_fx="sum")
36 |
37 | def update(self, source: th.Tensor, target: th.Tensor, valid_mask: Optional[th.Tensor]=None):
38 | # source (prediction): (N, C, *),
39 | # target (ground truth): (N, C, *), source.shape == target.shape
40 | # valid_mask: (N, *), where the channel dimension is missing. I.e. valid_mask.ndim == target.ndim - 1
41 |
42 | epe = epe_masked(source, target, valid_mask)
43 | if epe is not None:
44 | self.epe += epe.double()
45 | self.total += 1
46 |
47 | def compute(self):
48 | assert self.total > 0
49 | return (self.epe / self.total).float()
50 |
51 | class EPE_MULTI(Metric):
52 | def __init__(self, dist_sync_on_step=False, min_traj_len=None, max_traj_len=None):
53 | super().__init__(dist_sync_on_step=dist_sync_on_step)
54 |
55 | self.add_state("epe", default=th.tensor(0, dtype=th.float64), dist_reduce_fx="sum")
56 | self.add_state("total", default=th.tensor(0, dtype=th.int64), dist_reduce_fx="sum")
57 | self.min_traj_len = min_traj_len
58 | self.max_traj_len = max_traj_len
59 |
60 | @staticmethod
61 | def compute_traj_len(target: List[th.Tensor]):
62 | target_stack = th.stack(target, dim=0)
63 | diff = target_stack[1:] - target_stack[:-1]
64 | return diff.square().sum(dim=2).sqrt().sum(dim=0)
65 |
66 | def get_true_mask(self, target: List[th.Tensor], device: th.device):
67 | valid_shape = (target[0].shape[0],) + target[0].shape[2:]
68 | return th.ones(valid_shape, dtype=th.bool, device=device)
69 |
70 | def update(self, source: List[th.Tensor], target: List[th.Tensor], valid_mask: Optional[List[th.Tensor]]=None):
71 | # source_lst: [(N, C, *), ...], M evaluation/predictions
72 | # target_lst: [(N, C, *), ...], source_lst[*].shape == target_lst[*].shape
73 | # valid_mask_lst: [(N, *), ...], where the channel dimension is missing. I.e. valid_mask_lst[*].ndim == target_lst[*].ndim - 1
74 |
75 | if self.min_traj_len is not None or self.max_traj_len is not None:
76 | traj_len = self.compute_traj_len(target=target)
77 | valid_len = self.get_true_mask(target=target, device=target[0].device)
78 | if self.min_traj_len is not None:
79 | valid_len &= (traj_len >= self.min_traj_len)
80 | if self.max_traj_len is not None:
81 | valid_len &= (traj_len <= self.max_traj_len)
82 | if valid_mask is None:
83 | valid_mask = [valid_len.clone() for _ in range(len(target))]
84 | else:
85 | valid_mask = [valid_mask[idx] & valid_len for idx in range(len(target))]
86 |
87 | epe = epe_masked_multi(source, target, valid_mask)
88 | if epe is not None:
89 | self.epe += epe.double()
90 | self.total += 1
91 |
92 | def compute(self):
93 | assert self.total > 0
94 | return (self.epe / self.total).float()
95 |
96 | class AE(Metric):
97 | def __init__(self, degrees: bool=True, dist_sync_on_step=False):
98 | super().__init__(dist_sync_on_step=dist_sync_on_step)
99 |
100 | self.degrees = degrees
101 |
102 | self.add_state("ae", default=th.tensor(0, dtype=th.float64), dist_reduce_fx="sum")
103 | self.add_state("total", default=th.tensor(0, dtype=th.int64), dist_reduce_fx="sum")
104 |
105 | def update(self, source: th.Tensor, target: th.Tensor, valid_mask: Optional[th.Tensor]=None):
106 | # source (prediction): (N, C, *),
107 | # target (ground truth): (N, C, *), source.shape == target.shape
108 | # valid_mask: (N, *), where the channel dimension is missing. I.e. valid_mask.ndim == target.ndim - 1
109 |
110 | self.ae += ae_masked(source, target, valid_mask, degrees=self.degrees).double()
111 | self.total += 1
112 |
113 | def compute(self):
114 | assert self.total > 0
115 | return (self.ae / self.total).float()
116 |
117 |
118 | class AE_MULTI(Metric):
119 | def __init__(self, degrees: bool=True, dist_sync_on_step=False):
120 | super().__init__(dist_sync_on_step=dist_sync_on_step)
121 |
122 | self.degrees = degrees
123 |
124 | self.add_state("ae", default=th.tensor(0, dtype=th.float64), dist_reduce_fx="sum")
125 | self.add_state("total", default=th.tensor(0, dtype=th.int64), dist_reduce_fx="sum")
126 |
127 | def update(self, source: List[th.Tensor], target: List[th.Tensor], valid_mask: Optional[List[th.Tensor]]=None):
128 | # source_lst: [(N, C, *), ...], M evaluation/predictions
129 | # target_lst: [(N, C, *), ...], source_lst[*].shape == target_lst[*].shape
130 | # valid_mask_lst: [(N, *), ...], where the channel dimension is missing. I.e. valid_mask_lst[*].ndim == target_lst[*].ndim - 1
131 |
132 | self.ae += ae_masked_multi(source, target, valid_mask, degrees=self.degrees).double()
133 | self.total += 1
134 |
135 | def compute(self):
136 | assert self.total > 0
137 | return (self.ae / self.total).float()
138 |
139 | class NPE(Metric):
140 | def __init__(self, n_pixels: float, dist_sync_on_step=False):
141 | super().__init__(dist_sync_on_step=dist_sync_on_step)
142 |
143 | assert n_pixels > 0
144 | self.n_pixels = n_pixels
145 |
146 | self.add_state("npe", default=th.tensor(0, dtype=th.float64), dist_reduce_fx="sum")
147 | self.add_state("total", default=th.tensor(0, dtype=th.int64), dist_reduce_fx="sum")
148 |
149 | def update(self, source: th.Tensor, target: th.Tensor, valid_mask: Optional[th.Tensor]=None):
150 | # source (prediction): (N, C, *),
151 | # target (ground truth): (N, C, *), source.shape == target.shape
152 | # valid_mask: (N, *), where the channel dimension is missing. I.e. valid_mask.ndim == target.ndim - 1
153 |
154 | self.npe += n_pixel_error_masked(source, target, valid_mask, self.n_pixels).double()
155 | self.total += 1
156 |
157 | def compute(self):
158 | assert self.total > 0
159 | return (self.npe / self.total).float()
160 |
161 | def n_pixel_error_masked(source: th.Tensor, target: th.Tensor, valid_mask: Optional[th.Tensor], n_pixels: float):
162 | # source: (N, C, *),
163 | # target: (N, C, *), source.shape == target.shape
164 | # valid_mask: (N, *), where the channel dimension is missing. I.e. valid_mask.ndim == target.ndim - 1
165 | assert source.ndim > 2
166 | assert source.shape == target.shape
167 |
168 | if valid_mask is not None:
169 | assert valid_mask.shape[0] == target.shape[0]
170 | assert valid_mask.ndim == target.ndim - 1
171 | assert valid_mask.dtype == th.bool
172 |
173 | num_valid = th.sum(valid_mask)
174 | assert num_valid > 0
175 |
176 | gt_flow_magn = th.linalg.norm(target, dim=1)
177 | error_magn = th.linalg.norm(source - target, dim=1)
178 |
179 | if valid_mask is not None:
180 | rel_error = th.zeros_like(error_magn)
181 | rel_error[valid_mask] = error_magn[valid_mask] / th.clip(gt_flow_magn[valid_mask], min=1e-6)
182 | else:
183 | rel_error = error_magn / th.clip(gt_flow_magn, min=1e-6)
184 |
185 | error_map = (error_magn > n_pixels) & (rel_error >= 0.05)
186 |
187 | if valid_mask is not None:
188 | error = error_map[valid_mask].sum() / num_valid
189 | else:
190 | error = th.mean(error_map.float())
191 |
192 | error *= 100
193 | return error
194 |
195 |
196 | def epe_masked(source: th.Tensor, target: th.Tensor, valid_mask: Optional[th.Tensor] = None) -> Optional[th.Tensor]:
197 | # source: (N, C, *),
198 | # target: (N, C, *), source.shape == target.shape
199 | # valid_mask: (N, *), where the channel dimension is missing. I.e. valid_mask.ndim == target.ndim - 1
200 | assert source.ndim > 2
201 | assert source.shape == target.shape
202 |
203 | epe = th.sqrt(th.square(source - target).sum(1))
204 | if valid_mask is not None:
205 | assert valid_mask.shape[0] == target.shape[0]
206 | assert valid_mask.ndim == target.ndim - 1
207 | assert valid_mask.dtype == th.bool
208 | assert epe.shape == valid_mask.shape
209 | denominator = valid_mask.sum()
210 | if denominator == 0:
211 | return None
212 | return epe[valid_mask].sum() / denominator
213 | return th.mean(epe)
214 |
215 |
216 | def epe_masked_multi(source_lst: List[th.Tensor], target_lst: List[th.Tensor], valid_mask_lst: Optional[List[th.Tensor]] = None) -> Optional[th.Tensor]:
217 | # source_lst: [(N, C, *), ...], M evaluation/predictions
218 | # target_lst: [(N, C, *), ...], source_lst[*].shape == target_lst[*].shape
219 | # valid_mask_lst: [(N, *), ...], where the channel dimension is missing. I.e. valid_mask_lst[*].ndim == target_lst[*].ndim - 1
220 |
221 | num_preds = len(source_lst)
222 | assert num_preds > 0
223 | assert len(target_lst) == num_preds, len(target_lst)
224 | if valid_mask_lst is not None:
225 | assert len(valid_mask_lst) == num_preds, len(valid_mask_lst)
226 | else:
227 | valid_mask_lst = [None]*num_preds
228 | epe_sum = 0
229 | denominator = 0
230 | for source, target, valid_mask in zip(source_lst, target_lst, valid_mask_lst):
231 | epe = epe_masked(source, target, valid_mask)
232 | if epe is not None:
233 | epe_sum += epe
234 | denominator += 1
235 | if denominator == 0:
236 | return None
237 | epe_sum: th.Tensor
238 | return epe_sum / denominator
239 |
240 | def ae_masked_multi(source_lst: List[th.Tensor], target_lst: List[th.Tensor], valid_mask_lst: Optional[List[th.Tensor]]=None, degrees: bool=True) -> th.Tensor:
241 | # source_lst: [(N, C, *), ...], M evaluation/predictions
242 | # target_lst: [(N, C, *), ...], source_lst[*].shape == target_lst[*].shape
243 | # valid_mask_lst: [(N, *), ...], where the channel dimension is missing. I.e. valid_mask_lst[*].ndim == target_lst[*].ndim - 1
244 |
245 | num_preds = len(source_lst)
246 | assert num_preds > 0
247 | assert len(target_lst) == num_preds, len(target_lst)
248 | if valid_mask_lst is not None:
249 | assert len(valid_mask_lst) == num_preds, len(valid_mask_lst)
250 | else:
251 | valid_mask_lst = [None]*num_preds
252 | ae_sum = 0
253 | for source, target, valid_mask in zip(source_lst, target_lst, valid_mask_lst):
254 | ae_sum += ae_masked(source, target, valid_mask, degrees)
255 | ae_sum: th.Tensor
256 | return ae_sum / num_preds
257 |
258 |
259 | def ae_masked(source: th.Tensor, target: th.Tensor, valid_mask: Optional[th.Tensor]=None, degrees: bool=True) -> th.Tensor:
260 | # source: (N, C, *),
261 | # target: (N, C, *), source.shape == target.shape
262 | # valid_mask: (N, *), where the channel dimension is missing. I.e. valid_mask.ndim == target.ndim - 1
263 | assert source.ndim > 2
264 | assert source.shape == target.shape
265 |
266 |     # extend the flow with a constant channel of ones (3D angular error formulation)
267 |     extension_shape = list(source.shape)
268 | extension_shape[1] = 1
269 | extension = th.ones(extension_shape, device=source.device)
270 |
271 | source_ext = th.cat((source, extension), dim=1)
272 | target_ext = th.cat((target, extension), dim=1)
273 |
274 | # according to https://vision.middlebury.edu/flow/floweval-ijcv2011.pdf
275 |
276 |     numerator = th.sum(source_ext * target_ext, dim=1)
277 | denominator = th.linalg.norm(source_ext, dim=1) * th.linalg.norm(target_ext, dim=1)
278 |
279 |     tmp = th.div(numerator, denominator)
280 |
281 |     # Clamp to [-1, 1] to avoid NaNs from acos due to floating point error
282 | tmp[tmp > 1.0] = 1.0
283 | tmp[tmp < -1.0] = -1.0
284 |
285 | ae = th.acos(tmp)
286 | if degrees:
287 | ae = ae/math.pi*180
288 |
289 | if valid_mask is not None:
290 | assert valid_mask.shape[0] == target.shape[0]
291 | assert valid_mask.ndim == target.ndim - 1
292 | assert valid_mask.dtype == th.bool
293 | assert ae.shape == valid_mask.shape
294 |         ae_mean = ae[valid_mask].sum() / valid_mask.sum()
295 |         return ae_mean
296 | return th.mean(ae)
297 |
298 | def predictions_from_lin_assumption(source: th.Tensor, target_timestamps: List[float]) -> List[th.Tensor]:
299 | assert max(target_timestamps) <= 1
300 | assert 0 <= min(target_timestamps)
301 |
302 | output = list()
303 | for target_ts in target_timestamps:
304 | output.append(target_ts * source)
305 | return output
306 |
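
A minimal sketch (not part of the repository; assumes the repository root is on PYTHONPATH) of accumulating the masked end-point error over two batches with the EPE metric defined above; batches whose valid mask is empty are skipped by update().

import torch as th

from utils.metrics import EPE

metric = EPE()
for _ in range(2):
    pred = th.randn(2, 2, 32, 32)      # (N, C=2, H, W) predicted flow
    gt = th.randn(2, 2, 32, 32)        # ground-truth flow of the same shape
    valid = th.rand(2, 32, 32) > 0.5   # (N, H, W) boolean validity mask
    metric.update(pred, gt, valid)
print(metric.compute())                # mean of the per-batch masked EPE values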
--------------------------------------------------------------------------------
/utils/timers.py:
--------------------------------------------------------------------------------
1 | import atexit
2 | import time
3 |
4 | import numpy as np
5 | import torch
6 |
7 | cuda_timers = {}
8 | timers = {}
9 |
10 |
11 | class CudaTimer:
12 | # NOTE: This timer seems to work fine.
13 |     # However, it is unclear whether the synchronization at the start of __enter__ is strictly necessary.
14 | def __init__(self, device: torch.device, timer_name: str = ''):
15 | self.timer_name = timer_name
16 | if self.timer_name not in cuda_timers:
17 | cuda_timers[self.timer_name] = []
18 |
19 | self.device = device
20 | self.start = None
21 | self.end = None
22 |
23 | def __enter__(self):
24 | torch.cuda.synchronize(device=self.device)
25 | self.start = time.time()
26 | return self
27 |
28 | def __exit__(self, *args):
29 | assert self.start is not None
30 | torch.cuda.synchronize(device=self.device)
31 | end = time.time()
32 | cuda_timers[self.timer_name].append(end - self.start)
33 |
34 |
35 | class CudaTimerDummy:
36 | def __init__(self, *args, **kwargs):
37 | pass
38 |
39 | def __enter__(self):
40 | pass
41 |
42 | def __exit__(self, *args):
43 | pass
44 |
45 |
46 | class Timer:
47 | def __init__(self, timer_name=''):
48 | self.timer_name = timer_name
49 | if self.timer_name not in timers:
50 | timers[self.timer_name] = []
51 |
52 | def __enter__(self):
53 | self.start = time.time()
54 | return self
55 |
56 | def __exit__(self, *args):
57 | end = time.time()
58 | time_diff_s = end - self.start # measured in seconds
59 | timers[self.timer_name].append(time_diff_s)
60 |
61 |
62 | def print_timing_info():
63 | print('== Timing statistics ==')
64 | skip_warmup = 2
65 | for timer_name, timing_values in [*cuda_timers.items(), *timers.items()]:
66 | if len(timing_values) <= skip_warmup:
67 | continue
68 | values = timing_values[skip_warmup:]
69 | timing_value_s = np.mean(np.array(values))
70 | timing_value_ms = timing_value_s * 1000
71 | if timing_value_ms > 1000:
72 | print('{}: {:.2f} s'.format(timer_name, timing_value_s))
73 | else:
74 | print('{}: {:.2f} ms'.format(timer_name, timing_value_ms))
75 |
76 |
77 | # this will print all the timer values upon termination of any program that imported this file
78 | atexit.register(print_timing_info)
79 |
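
A minimal usage sketch (not part of the repository; assumes the repository root is on PYTHONPATH): both timers are context managers that append their measurements to the module-level dicts; print_timing_info is registered with atexit and skips the first two measurements of each timer as warm-up.

import torch

from utils.timers import CudaTimer, Timer

for _ in range(5):
    with Timer('cpu_matmul'):
        _ = torch.randn(512, 512) @ torch.randn(512, 512)

if torch.cuda.is_available():
    device = torch.device('cuda:0')
    x = torch.randn(512, 512, device=device)
    for _ in range(5):
        with CudaTimer(device=device, timer_name='gpu_matmul'):
            _ = x @ x
# Timing statistics are printed automatically when the program exits.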
--------------------------------------------------------------------------------
/val.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
5 | os.environ["OMP_NUM_THREADS"] = "1"
6 | os.environ["OPENBLAS_NUM_THREADS"] = "1"
7 | os.environ["MKL_NUM_THREADS"] = "1"
8 | os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
9 | os.environ["NUMEXPR_NUM_THREADS"] = "1"
10 |
11 | import hydra
12 | from omegaconf import DictConfig, OmegaConf
13 | import pytorch_lightning as pl
14 | from pytorch_lightning.callbacks import ModelSummary
15 | from pytorch_lightning.loggers import CSVLogger
16 | import torch
17 |
18 | from modules.data_loading import DataModule
19 | from modules.raft_spline import RAFTSplineModule
20 |
21 |
22 | @hydra.main(config_path='config', config_name='val', version_base='1.3')
23 | def main(config: DictConfig):
24 | print('------ Configuration ------\n')
25 | print(OmegaConf.to_yaml(config))
26 | print('---------------------------\n')
27 |     OmegaConf.to_container(config, resolve=True, throw_on_missing=True)  # fail early on missing config values
28 |
29 | # ------------
30 | # GPU Options
31 | # ------------
32 | gpus = config.hardware.gpus
33 |     assert isinstance(gpus, int), 'only a single GPU is supported'
34 | gpus = [gpus]
35 |
36 | batch_size: int = config.batch_size
37 | assert batch_size > 0
38 |
39 | # ------------
40 | # Data
41 | # ------------
42 | data_module = DataModule(config, batch_size_train=batch_size, batch_size_val=batch_size)
43 |
44 | num_bins_context = data_module.get_nbins_context()
45 | num_bins_corr = data_module.get_nbins_correlation()
46 | print(f'num_bins:\n\tcontext: {num_bins_context}\n\tcorrelation: {num_bins_corr}')
47 |
48 | # ---------------------
49 | # Logging and Checkpoints
50 | # ---------------------
51 | logger = CSVLogger(save_dir='./validation_logs')
52 | ckpt_path = Path(config.checkpoint)
53 |
54 | # ------------
55 | # Model
56 | # ------------
57 |
58 | module = RAFTSplineModule.load_from_checkpoint(str(ckpt_path), **{'config': config})
59 |
60 | # ---------------------
61 | # Callbacks and Misc
62 | # ---------------------
63 | callbacks = [ModelSummary(max_depth=2)]
64 |
65 | trainer = pl.Trainer(
66 | accelerator='gpu',
67 | callbacks=callbacks,
68 | default_root_dir=None,
69 | devices=gpus,
70 | logger=logger,
71 | log_every_n_steps=100,
72 | precision=32,
73 | )
74 |
75 | with torch.inference_mode():
76 | trainer.validate(model=module, datamodule=data_module, ckpt_path=str(ckpt_path))
77 |
78 |
79 | if __name__ == '__main__':
80 | main()
81 |
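
val.py passes config.checkpoint both to load_from_checkpoint and to trainer.validate. A small sketch (not part of the repository) of the checkpoint path sanity checks that train.py applies before resuming, which could equally be applied here before building the module:

from pathlib import Path

def check_ckpt(path_str: str) -> Path:
    # Mirrors the checks in train.py: the file must exist and carry the .ckpt suffix.
    ckpt_path = Path(path_str)
    assert ckpt_path.exists(), ckpt_path
    assert ckpt_path.suffix == '.ckpt', ckpt_path.suffix
    return ckpt_path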
--------------------------------------------------------------------------------