├── .gitignore
├── LICENSE
├── Makefile
├── README.md
├── configs
├── mvsec_indoor_burgers.yaml
└── mvsec_indoor_no_timeaware.yaml
├── conftest.py
├── datasets
├── .gitkeep
└── README.md
├── docs
└── img
│ ├── 2024_TPAMI_SecretsOfEVFlow_poster.pdf
│ └── secretsevflow_eccv22.jpg
├── main.py
├── outputs
└── .gitkeep
├── pyproject.toml
├── src
├── __init__.py
├── costs
│ ├── __init__.py
│ ├── base.py
│ ├── gradient_magnitude.py
│ ├── hybrid.py
│ ├── image_variance.py
│ ├── multi_focal_normalized_gradient_magnitude.py
│ ├── multi_focal_normalized_image_variance.py
│ ├── normalized_gradient_magnitude.py
│ ├── normalized_image_variance.py
│ └── total_variation.py
├── data_loader
│ ├── __init__.py
│ ├── base.py
│ └── mvsec.py
├── event_image_converter.py
├── feature_calculator.py
├── solver
│ ├── __init__.py
│ ├── base.py
│ ├── nnmodels
│ │ ├── __init__.py
│ │ ├── basic_layers.py
│ │ └── ev_flownet.py
│ ├── patch_contrast_base.py
│ ├── patch_contrast_mixed.py
│ ├── patch_contrast_pyramid.py
│ ├── scipy_autograd
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── base_wrapper.py
│ │ ├── scipy_minimize.py
│ │ └── torch_wrapper.py
│ └── time_aware_patch_contrast.py
├── types
│ ├── __init__.py
│ └── flow_patch.py
├── utils
│ ├── __init__.py
│ ├── event_utils.py
│ ├── flow_utils.py
│ ├── misc.py
│ └── stat_utils.py
├── visualizer.py
└── warp.py
└── tests
├── costs
├── test_gradient_magnitude.py
├── test_hybrid.py
└── test_image_variance.py
├── test_event_image_converter.py
├── test_warp.py
└── utils
├── test_event_utils.py
└── test_flow_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 |
3 | # VS Code files for those working on multiple tools
4 | .vscode
5 | !.vscode/settings.json
6 | !.vscode/tasks.json
7 | !.vscode/launch.json
8 | !.vscode/extensions.json
9 | *.code-workspace
10 |
11 | # Dataset and output large files
12 | datasets/*
13 | !datasets/.gitkeep
14 | !datasets/README.md
15 | outputs/*
16 | !outputs/.gitkeep
17 |
18 | *.log
19 | *.prof
20 | **/__pycache__/*
21 | **/*.pyc
22 | *.mp4
23 | *.avi
24 |
25 | wandb/*
26 | jupyter/*
27 | .ipynb_checkpoints
28 |
29 | **build
30 | **dist
31 | **.egg-info
32 |
33 | # Because we cannot assume the pytorch version (CUDA, CPU, etc.), we do not commit the .lock file.
34 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
35 | poetry.lock
36 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | lint:
2 | poetry run python3 -m mypy --install-types
3 | poetry run python3 -m mypy src/ main.py
4 |
5 | fmt:
6 | poetry run python3 -m black src/ tests/ main.py
7 | poetry run python3 -m isort ./src ./tests main.py
8 |
9 | run:
10 | poetry run python3 main.py --config_file ./configs/mvsec_indoor_no_timeaware.yaml
11 |
12 | test:
13 | poetry run pytest -vvv -c tests/ -o "testpaths=tests" -W ignore::DeprecationWarning
14 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 👀 **The extension paper has been accepted to IEEE T-PAMI! ([Paper](https://doi.org/10.1109/TPAMI.2024.3396116))**
2 |
3 | 👀 **We are now working to turn this method into more generic, easy-to-use functions (`flow = useful_function(events)`). Stay tuned!**
4 |
5 | # Secrets of Event-Based Optical Flow (T-PAMI 2024, ECCV 2022)
6 |
7 | This is the official repository for [**Secrets of Event-Based Optical Flow**](https://arxiv.org/abs/2207.10022), **ECCV 2022 Oral** by
8 | [Shintaro Shiba](http://shibashintaro.com/), [Yoshimitsu Aoki](https://aoki-medialab.jp/aokiyoshimitsu-en/) and [Guillermo Gallego](http://www.guillermogallego.es).
9 |
10 | We have extended this paper to a journal version: [**Secrets of Event-based Optical Flow, Depth and Ego-motion Estimation by Contrast Maximization**](https://doi.org/10.1109/TPAMI.2024.3396116), **IEEE T-PAMI 2024**.
11 |
12 |
16 |
17 |
18 |
19 | [Paper (IEEE T-PAMI 2024)](https://hal.science/hal-04655247v1/document) | [Paper (ECCV 2022)](https://arxiv.org/pdf/2207.10022) | [Video](https://youtu.be/nUb2ZRPdbWk) | [Poster](docs/img/2024_TPAMI_SecretsOfEVFlow_poster.pdf)
20 |
21 |
22 | [Watch the video](https://youtu.be/nUb2ZRPdbWk)
23 |
24 |
25 | If you use this work in your research, please cite it (see also [here](#citation)):
26 |
27 | ```bibtex
28 | @Article{Shiba24pami,
29 | author = {Shintaro Shiba and Yannick Klose and Yoshimitsu Aoki and Guillermo Gallego},
30 | title = {Secrets of Event-based Optical Flow, Depth, and Ego-Motion by Contrast Maximization},
31 | journal = {IEEE Trans. Pattern Anal. Mach. Intell. (T-PAMI)},
32 | year = 2024,
33 | pages = {1--18},
34 | doi = {10.1109/TPAMI.2024.3396116}
35 | }
36 |
37 | @InProceedings{Shiba22eccv,
38 | author = {Shintaro Shiba and Yoshimitsu Aoki and Guillermo Gallego},
39 | title = {Secrets of Event-based Optical Flow},
40 | booktitle = {European Conference on Computer Vision (ECCV)},
41 | pages = {628--645},
42 | doi = {10.1007/978-3-031-19797-0_36},
43 | year = 2022
44 | }
45 | ```
46 |
47 | ## **List of datasets that the flow estimation is tested on**
48 |
49 | Although this codebase releases just the MVSEC examples,
50 | I have confirmed that the flow estimation works reasonably well on the datasets below.
51 | The list is being updated; if you test it on new datasets, please let us know.
52 |
53 | - [MVSEC](https://daniilidis-group.github.io/mvsec/)
54 | - [DSEC](https://dsec.ifi.uzh.ch/dsec-datasets/download/)
55 | - [ECD, both simulation and real data](http://rpg.ifi.uzh.ch/davis_data.html)
56 | - [TUM VIE](https://cvg.cit.tum.de/data/datasets/visual-inertial-event-dataset)
57 | - [UZH-FPV Drone Racing Dataset](https://fpv.ifi.uzh.ch/)
58 | - [EDS](https://rpg.ifi.uzh.ch/eds.html#dataset)
59 | - [M3ED](https://m3ed.io/)
60 |
61 | The above are all public datasets; in our paper (T-PAMI 2024) we also used some non-public datasets from previous works.
62 |
63 | -------
64 | # Setup
65 |
66 | ## Requirements
67 |
68 | Although not all versions are strictly tested, the following should work.
69 |
70 | - python: 3.8.x, 3.9.x, 3.10.x
71 |
72 | GPU is entirely optional.
73 | If `torch.cuda.is_available()`, the code automatically switches to the GPU.
74 | I'd recommend a GPU for the time-aware methods, but as far as I have tested, a CPU is fine for the non-time-aware method.
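
A minimal sketch of this device-selection pattern (illustrative; the repository handles this internally):

```python
import torch

# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```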
75 |
76 | ### Tested environments
77 |
78 | - Mac OS Monterey (both M1 and non-M1)
79 | - Ubuntu (CUDA 11.1, 11.3, 11.8)
80 | - PyTorch 1.9-1.12.1, or PyTorch 2.0 (1.13 raises an error during Burgers).
81 |
82 | ## Installation
83 |
84 | I strongly recommend using venv: `python3 -m venv <venv-dir>`
85 | Also, you can use [poetry](https://python-poetry.org/).
86 |
87 | - Install pytorch **< 1.13** or **>= 2.0** and torchvision for your environment. Make sure you install the correct CUDA version if you want to use it.
88 |
89 | - If you use poetry, run `poetry install`. If you use only venv, check the dependency libraries in [pyproject.toml](./pyproject.toml) and install them yourself (see the sketch below).
90 |
91 | - If you have trouble installing pytorch with CUDA using poetry, refer to this [link](https://github.com/python-poetry/poetry/issues/6409).
92 |
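For the venv route, a minimal sketch (the exact PyTorch build depends on your environment; see [pyproject.toml](./pyproject.toml) for the authoritative dependency list):

```shell
python3 -m venv .venv
source .venv/bin/activate
# Install a compatible PyTorch first (pick the right CUDA/CPU build for your machine)
pip install "torch<1.13" torchvision
# Then the remaining dependencies
pip install numpy opencv-python matplotlib PyYAML Pillow optuna h5py hdf5plugin plotly scikit-image ffmpeg-python
```
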
93 | ## Download dataset
94 |
95 | Download each dataset into the `./datasets` directory.
96 | Optionally, you can specify another root directory;
97 | please check the [dataset readme](./datasets/README.md) for details.
98 |
99 | # Execution
100 |
101 | ```shell
102 | python3 main.py --config_file ./configs/mvsec_indoor_no_timeaware.yaml
103 | ```
104 |
105 | If you use poetry, simply prepend `poetry run`.
106 | Run with the `-h` option to learn about the other options.
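
For example, an evaluation run with verbose logging (the `--eval` and `--log` flags are defined in [main.py](./main.py)):

```shell
python3 main.py --config_file ./configs/mvsec_indoor_no_timeaware.yaml --eval --log debug
```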
107 |
108 | ## Config file
109 |
110 | The config (.yaml) file specifies various experimental settings.
111 | Please check and change parameters as you like.
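
For example, to change the event slice or the output location, edit the corresponding keys (values below are illustrative; see [configs/mvsec_indoor_no_timeaware.yaml](./configs/mvsec_indoor_no_timeaware.yaml) for the full file):

```yaml
data:
  n_events_per_batch: 30000  # number of events per optimization batch
  ind1: 1130000              # start event index
  ind2: 1160000              # end event index

output:
  output_dir: "./outputs/my_experiment"
  show_interactive_result: false
```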
112 |
113 | ### Optional tasks (for me)
114 |
115 | **The code here is already runnable and explains the ideas of the paper sufficiently.** (Please report any bugs.)
116 |
117 | Rather than releasing all of my (sometimes too experimental) code,
118 | I published just a minimal set of the codebase needed for reproduction.
119 | So the following tasks are optional for me,
120 | but if it helps you, I can publish other parts as well. For example:
121 |
122 | - Other data loader
123 |
124 | - Some other cost functions
125 |
126 | - Pretrained model checkpoint file ✔️ [released for MVSEC](https://drive.google.com/file/d/13m-waAt5X0C7f0JLBwb6KAApYxgXoA2J/view?usp=sharing)
127 |
128 | - Other solver (especially DNN)
129 |
130 | - The implementation of [the Sensors paper](https://www.mdpi.com/1424-8220/22/14/5190)
131 |
132 | Your feedback helps to prioritize these tasks, so please contact me or raise issues.
133 | The code is well modularized, so contributing should be easy too.
134 |
135 | # Citation
136 |
137 | If you use this work in your research, please cite it **as stated above**, below the video.
138 |
139 | This code also includes part of the implementation of the [following paper, which details event collapse](https://www.mdpi.com/1424-8220/22/14/5190).
140 | Please check it :)
141 |
142 | ```bibtex
143 | @Article{Shiba22sensors,
144 | author = {Shintaro Shiba and Yoshimitsu Aoki and Guillermo Gallego},
145 | title = {Event Collapse in Contrast Maximization Frameworks},
146 | journal = {Sensors},
147 | year = 2022,
148 | volume = 22,
149 | number = 14,
150 | pages = {1--20},
151 | article-number= 5190,
152 | doi = {10.3390/s22145190}
153 | }
154 | ```
155 |
156 | # Author
157 |
158 | Shintaro Shiba [@shiba24](https://github.com/shiba24)
159 |
160 | ## LICENSE
161 |
162 | Please check [License](./LICENSE).
163 |
164 | ## Acknowledgement
165 |
166 | I appreciate the following repositories for the inspiration:
167 |
168 | - [autograd-minimize](https://github.com/brunorigal/autograd-minimize)
169 | - [EVFlowNet-pytorch](https://github.com/CyrilSterling/EVFlowNet-pytorch)
170 |
171 | -------
172 | # Additional Resources
173 |
174 | * [Motion-prior Contrast Maximization (ECCV 2024)](https://github.com/tub-rip/MotionPriorCMax)
175 | * [EVILIP: Event-based Image Reconstruction as a Linear Inverse Problem (TPAMI 2022)](https://github.com/tub-rip/event_based_image_rec_inverse_problem)
176 | * [Event Collapse in Contrast Maximization Frameworks](https://github.com/tub-rip/event_collapse)
177 | * [CMax-SLAM (TRO 2024)](https://github.com/tub-rip/cmax_slam)
178 | * [EBOS: Event-based Background-Oriented Schlieren (TPAMI 2023)](https://github.com/tub-rip/event_based_bos)
179 | * [EPBA: Event-based Photometric Bundle Adjustment](https://github.com/tub-rip/epba)
180 | * [ES-PTAM: Event-based Stereo Parallel Tracking and Mapping](https://github.com/tub-rip/ES-PTAM)
181 | * [Research page (TU Berlin, RIP lab)](https://sites.google.com/view/guillermogallego/research/event-based-vision)
182 | * [Research page (Keio University, Aoki Media Lab)](https://aoki-medialab.jp/home-en/)
183 | * [Course at TU Berlin](https://sites.google.com/view/guillermogallego/teaching/event-based-robot-vision)
184 | * [Survey paper](http://rpg.ifi.uzh.ch/docs/EventVisionSurvey.pdf)
185 | * [List of Resources](https://github.com/uzh-rpg/event-based_vision_resources)
186 |
--------------------------------------------------------------------------------
/configs/mvsec_indoor_burgers.yaml:
--------------------------------------------------------------------------------
1 | is_dnn: false
2 |
3 | data:
4 | eval_dt: 4 # for MVSEC evaluation
5 | root: "~/local/event_based_optical_flow/datasets/MVSEC/hdf5"
6 | dataset: "MVSEC"
7 | sequence: "indoor_flying1"
8 | height: 260
9 | width: 346
10 | load_gt_flow: True
11 | gt: "~/local/event_based_optical_flow/datasets/MVSEC/gt_flow"
12 | n_events_per_batch: 30000
13 | ind1: 1130000
14 | ind2: 1160000
15 |
16 | output:
17 | output_dir: "./outputs/paper/burgers/indoor_flying1_dt4"
18 | show_interactive_result: false
19 |
20 | solver:
21 | method: "pyramidal_patch_contrast_maximization"
22 | time_aware: True
23 | time_bin: 10
24 | flow_interpolation: "burgers"
25 | t0_flow_location: "middle"
26 | patch:
27 | initialize: "random"
28 | scale: 5
29 | crop_height: 256
30 | crop_width: 336
31 | filter_type: "bilinear"
32 | motion_model: "2d-translation"
33 | warp_direction: "first"
34 | parameters:
35 | - "trans_x"
36 | - "trans_y"
37 | cost: "hybrid"
38 | outer_padding: 0
39 | cost_with_weight:
40 | multi_focal_normalized_gradient_magnitude: 1.
41 | total_variation: 0.01
42 | iwe:
43 | method: "bilinear_vote"
44 | blur_sigma: 1
45 |
46 | optimizer:
47 | n_iter: 40
48 | method: "Newton-CG"
49 | max_iter: 25
50 | parameters:
51 | trans_x:
52 | min: -150
53 | max: 150
54 | trans_y:
55 | min: -150
56 | max: 150
57 |
--------------------------------------------------------------------------------
/configs/mvsec_indoor_no_timeaware.yaml:
--------------------------------------------------------------------------------
1 | is_dnn: false
2 |
3 | data:
4 | eval_dt: 4 # for MVSEC evaluation
5 | root: "~/local/event_based_optical_flow/datasets/MVSEC/hdf5"
6 | dataset: "MVSEC"
7 | sequence: "indoor_flying1"
8 | height: 260
9 | width: 346
10 | load_gt_flow: True
11 | gt: "~/local/event_based_optical_flow/datasets/MVSEC/gt_flow"
12 | n_events_per_batch: 30000
13 | ind1: 1130000
14 | ind2: 1160000
15 |
16 | output:
17 | output_dir: "./outputs/paper/no_timeaware/indoor_flying1_dt4"
18 | show_interactive_result: false
19 |
20 | solver:
21 | method: "pyramidal_patch_contrast_maximization"
22 | time_aware: False
23 | patch:
24 | initialize: "random"
25 | scale: 5
26 | crop_height: 256
27 | crop_width: 336
28 | filter_type: "bilinear"
29 | motion_model: "2d-translation"
30 | warp_direction: "first"
31 | parameters:
32 | - "trans_x"
33 | - "trans_y"
34 | cost: "hybrid"
35 | outer_padding: 0
36 | cost_with_weight:
37 | multi_focal_normalized_gradient_magnitude: 1.
38 | total_variation: 0.01
39 | iwe:
40 | method: "bilinear_vote"
41 | blur_sigma: 1
42 |
43 | optimizer:
44 | n_iter: 40
45 | method: "Newton-CG"
46 | max_iter: 25
47 | parameters:
48 | trans_x:
49 | min: -150
50 | max: 150
51 | trans_y:
52 | min: -150
53 | max: 150
54 |
--------------------------------------------------------------------------------
/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
--------------------------------------------------------------------------------
/datasets/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tub-rip/event_based_optical_flow/8c12d21a91c22922bb150acb9b1bada2f4b23def/datasets/.gitkeep
--------------------------------------------------------------------------------
/datasets/README.md:
--------------------------------------------------------------------------------
1 | # Folder structure
2 |
3 | The folder structure should be as follows:
4 |
5 | ```shell
6 | tree -L 3
7 | .
8 | ├── MVSEC
9 | │ ├── hdf5
10 | │ ├── indoor_flying1_data.hdf5
11 | │ ...
12 | │ └── gt_flow
13 | │ ├── indoor_flying1_gt_flow_dist.npz
14 | │ ...
15 | └── README.md # this readme
16 | ```
17 |
18 | Please download datasets accordingly.
19 |
20 | - MVSEC data from (https://drive.google.com/drive/folders/1gDy2PwVOu_FPOsEZjojdWEB2ZHmpio8D)
21 |
22 | # Your own dataset location
23 |
24 | Optionally, you don't have to place the files here;
25 | you can instead choose your own dataset root directory.
26 |
27 | In that case, specify the root folder in the config yaml file.
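
For example, for MVSEC (the paths are illustrative):

```yaml
data:
  root: "/path/to/your/datasets/MVSEC/hdf5"
  gt: "/path/to/your/datasets/MVSEC/gt_flow"
```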
28 |
29 |
--------------------------------------------------------------------------------
/docs/img/2024_TPAMI_SecretsOfEVFlow_poster.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tub-rip/event_based_optical_flow/8c12d21a91c22922bb150acb9b1bada2f4b23def/docs/img/2024_TPAMI_SecretsOfEVFlow_poster.pdf
--------------------------------------------------------------------------------
/docs/img/secretsevflow_eccv22.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tub-rip/event_based_optical_flow/8c12d21a91c22922bb150acb9b1bada2f4b23def/docs/img/secretsevflow_eccv22.jpg
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # type: ignore
2 | import argparse
3 | import logging
4 | import os
5 | import shutil
6 | import sys
7 |
8 | import numpy as np
9 | import yaml
10 | from tqdm import tqdm
11 |
12 | from src import data_loader, solver, utils, visualizer
13 |
14 |
15 | def parse_args():
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument(
18 | "--config_file",
19 | default="./configs/mvsec_indoor_no_timeaware.yaml",
20 | help="Config file yaml path",
21 | type=str,
22 | )
23 | parser.add_argument(
24 | "--eval",
25 | help="Add for evaluation run",
26 | action="store_true",
27 | )
28 | parser.add_argument(
29 | "--log", help="Log level: [debug, info, warning, error, critical]", type=str, default="info"
30 | )
31 | args = parser.parse_args()
32 | with open(args.config_file, "r") as f:
33 | config = yaml.safe_load(f)
34 | return config, args
35 |
36 |
37 | def save_config(save_dir: str, file_name: str, log_level=logging.INFO):
38 | """Save configuration"""
39 | if not os.path.exists(save_dir):
40 | os.makedirs(save_dir)
41 | shutil.copy(file_name, save_dir)
42 | logging.basicConfig(
43 | handlers=[
44 | logging.FileHandler(f"{save_dir}/main.log", mode="w"),
45 | logging.StreamHandler(sys.stdout),
46 | ],
47 | level=log_level,
48 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
49 | )
50 |
51 |
52 | def evaluate_mvsec_dataset_with_gt(eval_frame_time_stamp_list, data_config, loader, solv):
53 | logger.info("Evaluation pipeline")
54 | eval_dt = data_config["eval_dt"]
55 | assert eval_dt == 1 or eval_dt == 4
56 | logger.info(f"dt (for MVSEC) is {eval_dt}")
57 | n_events = data_config["n_events_per_batch"]
58 |
59 | for i1 in tqdm(range(len(eval_frame_time_stamp_list) - eval_dt)):
60 | logger.info(f"Frame {i1} of {len(eval_frame_time_stamp_list)}")
61 | try:
62 | if i1 < data_config["ind1"] or i1 > data_config["ind2"]:
63 |                 continue # cutoff
64 | except KeyError:
65 | pass
66 | t1 = eval_frame_time_stamp_list[i1]
67 | t2 = eval_frame_time_stamp_list[i1 + eval_dt]
68 | ind1 = loader.time_to_index(t1) # event index
69 | ind2 = loader.time_to_index(t2)
70 |
71 |         # Flow error metrics calculation is based on GT flow + events between the consecutive GT flow frames
72 | batch_for_gt_slice = loader.load_event(ind1, ind2)
73 | gt_flow = loader.load_optical_flow(t1, t2)
74 | flow_time = t2 - t1
75 | batch_for_gt_slice[..., 2] -= np.min(batch_for_gt_slice[..., 2])
76 |
77 | # Optimization is based on fixed number of events
78 | if ind2 - ind1 < n_events:
79 | logger.info(
80 | f"Less events in one GT flow sequence. Events: {ind2-ind1} / Expected: {n_events}"
81 | )
82 | insufficient = n_events - (ind2 - ind1)
83 | ind1 -= insufficient // 2
84 | ind2 += insufficient // 2
85 | elif ind2 - ind1 > n_events:
86 | logger.info(
87 | f"Too many events in one GT flow sequence. Events: {ind2-ind1} / Expected: {n_events}"
88 | )
89 | ind1 = ind2 - n_events
90 |
91 | batch_for_optimization = loader.load_event(max(ind1, 0), min(ind2, len(loader)))
92 | batch_for_optimization[..., 2] -= np.min(batch_for_optimization[..., 2])
93 |
94 | if utils.check_key_and_bool(data_config, "remove_car"):
95 |             logger.info("Remove car-body pixels")
96 | batch_for_optimization = utils.crop_event(batch_for_optimization, 0, 193, 0, 346)
97 |
98 | best_motion = solv.optimize(batch_for_optimization)
99 | solv.set_previous_frame_best_estimation(best_motion)
100 | # mask with event
101 | flow_error_with_mask = solv.calculate_flow_error(best_motion, gt_flow, timescale=flow_time, events=batch_for_gt_slice) # type: ignore
102 | solv.save_flow_error_as_text(i1, flow_error_with_mask, "flow_error_per_frame_with_mask.txt") # type: ignore
103 |
104 | # Visualization
105 | solv.visualize_original_sequential(batch_for_gt_slice)
106 | solv.visualize_pred_sequential(batch_for_gt_slice, best_motion)
107 | solv.visualize_gt_sequential(batch_for_gt_slice, gt_flow)
108 |
109 |
110 | if __name__ == "__main__":
111 | config, args = parse_args()
112 | data_config: dict = config["data"]
113 | out_config: dict = config["output"]
114 | log_level = getattr(logging, args.log.upper(), None)
115 | if not isinstance(log_level, int):
116 |         raise ValueError("Invalid log level: %s" % args.log)
117 | save_config(out_config["output_dir"], args.config_file, log_level)
118 | logger = logging.getLogger(__name__)
119 |
120 | if utils.check_key_and_bool(config, "fix_random_seed"):
121 | utils.fix_random_seed()
122 |
123 | # Visualizer
124 | image_shape = (data_config["height"], data_config["width"])
125 | if config["is_dnn"] and "crop" in data_config["preprocess"].keys():
126 | image_shape = (data_config["preprocess"]["crop"]["height"], data_config["preprocess"]["crop"]["width"]) # type: ignore
127 |
128 | viz = visualizer.Visualizer(
129 | image_shape,
130 | show=out_config["show_interactive_result"],
131 | save=True,
132 | save_dir=out_config["output_dir"],
133 | )
134 |
135 | # Loader
136 | loader = data_loader.collections[data_config["dataset"]](config=data_config)
137 | loader.set_sequence(data_config["sequence"])
138 |
139 | # Solver
140 | method_name = config["solver"]["method"]
141 | solv: solver.SolverBase = solver.collections[method_name](
142 | image_shape,
143 | calibration_parameter=loader.load_calib(),
144 | solver_config=config["solver"],
145 | optimizer_config=config["optimizer"],
146 | output_config=config["output"],
147 | visualize_module=viz,
148 | )
149 |
150 |     if args.eval: # Run evaluation pipeline.
151 | if config["is_dnn"]:
152 | e = "DNN code is not published."
153 | logger.error(e)
154 | raise NotImplementedError(e)
155 | else:
156 | logger.info("Sequential optimization")
157 | assert loader.gt_flow_available # evaluate with GT flow
158 | logger.info("evaluation with GT")
159 | eval_frame_time_stamp_list = loader.eval_frame_time_list()
160 | evaluate_mvsec_dataset_with_gt(eval_frame_time_stamp_list, data_config, loader, solv)
161 | logger.info(f"Evaluation done! {data_config['sequence']}")
162 | exit()
163 |
164 | # Not evaluation - single frame optimization
165 | if config["is_dnn"]:
166 | e = "DNN code is not published."
167 | logger.error(e)
168 | raise NotImplementedError(e)
169 | else: # For non-DNN method
170 | logger.info("Single-frame optimization")
171 | ind1, ind2 = data_config["ind1"], data_config["ind2"]
172 | batch: np.ndarray = loader.load_event(ind1, ind2)
173 | batch[..., 2] -= np.min(batch[..., 2])
174 |
175 | if utils.check_key_and_bool(data_config, "remove_car"):
176 |             batch = utils.crop_event(batch, 0, 193, 0, 346) # remove MVSEC car
177 |
178 | solv.visualize_one_batch_warp(batch)
179 | best_motion: np.ndarray = solv.optimize(batch)
180 | solv.visualize_one_batch_warp(batch, best_motion)
181 |
182 | # Calculate Flow error when GT is available
183 | if loader.gt_flow_available:
184 | t1 = loader.index_to_time(ind1)
185 | t2 = loader.index_to_time(ind2)
186 | gt_flow = loader.load_optical_flow(t1, t2)
187 |
188 | solv.visualize_one_batch_warp_gt(batch, gt_flow)
189 | solv.calculate_flow_error(best_motion, gt_flow, t2 - t1, batch)
190 |
--------------------------------------------------------------------------------
/outputs/.gitkeep:
--------------------------------------------------------------------------------
1 | !gallego_cvpr2018
2 | !gallego_cvpr2019
3 | !mitrokhin_iros2018
4 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "event-based-optical-flow"
3 | version = "0.1.0"
4 | description = ""
5 | authors = ["Shintaro Shiba "]
6 |
7 | [tool.poetry.dependencies]
8 | python = ">=3.8,<3.11"
9 | numpy = "^1.21.2"
10 | torch = ">=1.9.1"
11 | torchvision = ">=0.10.1"
12 | opencv-python = "^4.5.3"
13 | matplotlib = "^3.4.3"
14 | argparse = "^1.4.0"
15 | PyYAML = ">=5.4.1"
16 | Pillow = "^8.3.2"
17 | optuna = "^2.10.0"
18 | sklearn = "^0.0"
19 | wandb = "^0.12.4"
20 | hdf5plugin = "^3.2.0"
21 | h5py = "^3.5.0"
22 | plotly = "^5.4.0"
23 | scikit-image = "^0.19.1"
24 | ffmpeg-python = "^0.2.0"
25 |
26 | [tool.poetry.dev-dependencies]
27 | pytest = "^6.2.5"
28 | black = "^21.9b0"
29 | mypy = "^0.910"
30 |
31 | [build-system]
32 | requires = ["poetry-core>=1.0.0"]
33 | build-backend = "poetry.core.masonry.api"
34 |
35 | [tool.black]
36 | line-length = 100
37 | target-version = ['py39']
38 | include = '\.pyi?$'
39 |
40 | [tool.mypy]
41 | ignore_missing_imports = true
42 |
43 | [tool.isort]
44 | profile = "black"
45 | line_length = 119
46 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tub-rip/event_based_optical_flow/8c12d21a91c22922bb150acb9b1bada2f4b23def/src/__init__.py
--------------------------------------------------------------------------------
/src/costs/__init__.py:
--------------------------------------------------------------------------------
1 | """isort:skip_file
2 | """
3 | # Basics
4 | from .base import CostBase
5 | from .gradient_magnitude import GradientMagnitude
6 | from .image_variance import ImageVariance
7 |
8 | # from .zhu_average_timestamp import ZhuAverageTimestamp
9 | # from .paredes_average_timestamp import ParedesAverageTimestamp
10 |
11 | # Flow related
12 | from .total_variation import TotalVariation
13 |
14 | # Normalized ~
15 | from .normalized_image_variance import NormalizedImageVariance
16 | from .normalized_gradient_magnitude import NormalizedGradientMagnitude
17 |
18 | # Multi-reference ~
19 | from .multi_focal_normalized_image_variance import MultiFocalNormalizedImageVariance
20 | from .multi_focal_normalized_gradient_magnitude import MultiFocalNormalizedGradientMagnitude
21 |
22 |
23 | def inheritors(klass):
24 | subclasses = set()
25 | work = [klass]
26 | while work:
27 | parent = work.pop()
28 | for child in parent.__subclasses__():
29 | if child not in subclasses:
30 | subclasses.add(child)
31 | work.append(child)
32 | return subclasses
33 |
34 |
35 | functions = {k.name: k for k in inheritors(CostBase)}
36 |
37 | # For hybrid loss
38 | from .hybrid import HybridCost
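# Usage sketch (illustrative, not part of the library API): `functions` maps
# each cost's `name` attribute to its class, so a config string can be turned
# into a cost object:
#   cost = functions["gradient_magnitude"](direction="minimize")
#   loss = cost.calculate({"iwe": iwe, "omit_boundary": True})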
39 |
--------------------------------------------------------------------------------
/src/costs/base.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Dict, List
3 |
4 | import torch
5 |
6 | from ..types import FLOAT_TORCH
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 |
11 | class CostBase(object):
12 | """Base of the Cost class.
13 | Args:
14 | direction (str) ... 'minimize' or 'maximize' or `natural`.
15 |             Defines the objective function. If natural, it returns a more interpretable value.
16 | """
17 |
18 | required_keys: List[str] = []
19 |
20 | def __init__(self, direction="minimize", store_history: bool = False, *args, **kwargs):
21 | if direction not in ["minimize", "maximize", "natural"]:
22 |             e = f"direction should be minimize, maximize, or natural. Got {direction}."
23 | logger.error(e)
24 | raise ValueError(e)
25 | self.direction = direction
26 | self.store_history = store_history
27 | self.clear_history()
28 |
29 | def catch_key_error(func):
30 | """Wrapper utility function to catch the key error."""
31 |
32 | def wrapper(self, arg: dict):
33 | try:
34 | return func(self, arg) # type: ignore
35 | except KeyError as e:
36 | logger.error("Input for the cost needs keys of:")
37 | logger.error(self.required_keys)
38 | raise e
39 |
40 | return wrapper
41 |
42 | def register_history(func):
43 |         """Register history of the loss."""
44 |
45 | def wrapper(self, arg: dict):
46 | loss = func(self, arg) # type: ignore
47 | if self.store_history:
48 | self.history["loss"].append(self.get_item(loss))
49 | return loss
50 |
51 | return wrapper
52 |
53 | def get_item(self, loss: FLOAT_TORCH) -> float:
54 | if isinstance(loss, torch.Tensor):
55 | return loss.item()
56 | return loss
57 |
58 | def clear_history(self) -> None:
59 | self.history: Dict[str, list] = {"loss": []}
60 |
61 | def get_history(self) -> dict:
62 | return self.history.copy()
63 |
64 | def enable_history_register(self) -> None:
65 | self.store_history = True
66 |
67 | def disable_history_register(self) -> None:
68 | self.store_history = False
69 |
70 | # Every subclass needs to implement calculate()
71 | @register_history # type: ignore
72 | @catch_key_error # type: ignore
73 | def calculate(self, arg: dict) -> FLOAT_TORCH:
74 | raise NotImplementedError
75 |
76 | catch_key_error = staticmethod(catch_key_error) # type: ignore
77 | register_history = staticmethod(register_history) # type: ignore
78 |
--------------------------------------------------------------------------------
/src/costs/gradient_magnitude.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import cv2
4 | import numpy as np
5 | import torch
6 |
7 | from ..types import FLOAT_TORCH
8 | from ..utils import SobelTorch
9 | from . import CostBase
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | class GradientMagnitude(CostBase):
15 | """Gradient Magnitude loss from Gallego et al. CVPR 2019.
16 |
17 | Args:
18 | direction (str) ... 'minimize' or 'maximize' or `natural`.
19 |             Defines the objective function. If natural, it returns a more interpretable value.
20 | """
21 |
22 | name = "gradient_magnitude"
23 | required_keys = ["iwe", "omit_boundary"]
24 |
25 | def __init__(
26 | self,
27 | direction="minimize",
28 | store_history: bool = False,
29 | cuda_available=False,
30 | precision="32",
31 | *args,
32 | **kwargs,
33 | ):
34 | super().__init__(direction=direction, store_history=store_history)
35 | self.precision = precision
36 | self.torch_sobel = SobelTorch(
37 | ksize=3, in_channels=1, cuda_available=cuda_available, precision=precision
38 | )
39 |
40 | @CostBase.register_history # type: ignore
41 | @CostBase.catch_key_error # type: ignore
42 | def calculate(self, arg: dict) -> FLOAT_TORCH:
43 | """Calculate gradient of IWE.
44 | Inputs:
45 | iwe (np.ndarray or torch.Tensor) ... [W, H]. Image of warped events
46 |             omit_boundary (bool) ... Omit boundary if True.
47 |
48 | Returns:
49 | (Union[float, torch.Tensor]) ... Magnitude of gradient
50 | """
51 | iwe = arg["iwe"]
52 | if isinstance(iwe, torch.Tensor):
53 | return self.calculate_torch(iwe, arg["omit_boundary"])
54 | elif isinstance(iwe, np.ndarray):
55 | return self.calculate_numpy(iwe, arg["omit_boundary"])
56 | e = f"Unsupported input type. {type(iwe)}."
57 | logger.error(e)
58 | raise NotImplementedError(e)
59 |
60 | def calculate_torch(self, iwe: torch.Tensor, omit_boundary: bool) -> torch.Tensor:
61 | if len(iwe.shape) == 2:
62 | iwe = iwe[None, None, ...]
63 | elif len(iwe.shape) == 3:
64 | iwe = iwe[:, None, ...]
65 | if self.precision == "64":
66 | iwe = iwe.double()
67 | iwe_sobel = self.torch_sobel.forward(iwe) / 8.0
68 | gx = iwe_sobel[:, 0]
69 | gy = iwe_sobel[:, 1]
70 | if omit_boundary:
71 | gx = gx[..., 1:-1, 1:-1]
72 | gy = gy[..., 1:-1, 1:-1]
73 | magnitude = torch.mean(torch.square(gx) + torch.square(gy))
74 | if self.direction == "minimize":
75 | return -magnitude
76 | return magnitude
77 |
78 | def calculate_numpy(self, iwe: np.ndarray, omit_boundary: bool) -> float:
79 |         """Calculate gradient magnitude of the IWE.
80 | Inputs:
81 | iwe (np.ndarray) ... [W, H]. Image of warped events
82 |             omit_boundary (bool) ... Omit boundary if True.
83 |
84 | Returns:
85 | (float) ... magnitude of gradient.
86 | """
87 | gx = cv2.Sobel(iwe, cv2.CV_64F, 1, 0, ksize=3) / 8.0
88 | gy = cv2.Sobel(iwe, cv2.CV_64F, 0, 1, ksize=3) / 8.0
89 | if omit_boundary:
90 | gx = gx[..., 1:-1, 1:-1]
91 | gy = gy[..., 1:-1, 1:-1]
92 | magnitude = np.mean(np.square(gx) + np.square(gy))
93 | if self.direction == "minimize":
94 | return -magnitude
95 | return magnitude
96 |
--------------------------------------------------------------------------------
/src/costs/hybrid.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Union
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from . import CostBase, functions
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | class HybridCost(CostBase):
13 | """Hybrid cost function with arbitrary weight.
14 |
15 | Args:
16 | direction (str) ... 'minimize' or 'maximize'.
17 | cost_with_weight (dict) ... key is the name of the cost, value is its weight.
18 | """
19 |
20 | name = "hybrid"
21 |
22 | def __init__(
23 | self, direction: str, cost_with_weight: dict, store_history: bool = False, *args, **kwargs
24 | ):
25 |         logger.info(f"Loss functions are a mix of {cost_with_weight}")
26 | self.cost_func = {
27 | key: {
28 | "func": functions[key](
29 | direction=direction, store_history=store_history, *args, **kwargs
30 | ),
31 | "weight": value,
32 | }
33 | for key, value in cost_with_weight.items()
34 | }
35 | super().__init__(direction=direction, store_history=store_history)
36 |
37 | self.required_keys = []
38 | for name in self.cost_func.keys():
39 | self.required_keys.extend(self.cost_func[name]["func"].required_keys)
40 |
41 | def update_weight(self, cost_with_weight):
42 | assert set(self.cost_func.keys()) == set(cost_with_weight.keys())
43 | for key in cost_with_weight.keys():
44 | self.cost_func[key]["weight"] = cost_with_weight[key]
45 |
46 | @CostBase.register_history # type: ignore
47 | @CostBase.catch_key_error # type: ignore
48 | def calculate(self, arg: dict) -> Union[float, torch.Tensor]:
49 | loss = 0.0
50 | for name in self.cost_func.keys():
51 | if self.cost_func[name]["weight"] == "inv":
52 | _l = 1.0 / self.cost_func[name]["func"].calculate(arg)
53 | loss += _l
54 | else:
55 | _l = self.cost_func[name]["weight"] * self.cost_func[name]["func"].calculate(arg)
56 | loss += _l
57 | return loss
58 |
59 | # For hybrid cost function, need to store with its name
60 | def clear_history(self) -> None:
61 | self.history = {"loss": []}
62 | for name in self.cost_func.keys():
63 | self.cost_func[name]["func"].clear_history()
64 |
65 | def get_history(self) -> dict:
66 | dic = self.history.copy()
67 | for name in self.cost_func.keys():
68 | dic.update({name: self.cost_func[name]["func"].get_history()["loss"]})
69 | return dic
70 |
71 | def enable_history_register(self) -> None:
72 | self.store_history = True
73 | for name in self.cost_func.keys():
74 | self.cost_func[name]["func"].store_history = True
75 |
76 | def disable_history_register(self) -> None:
77 | self.store_history = False
78 | for name in self.cost_func.keys():
79 | self.cost_func[name]["func"].store_history = False
80 |
--------------------------------------------------------------------------------
/src/costs/image_variance.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Union
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from . import CostBase
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | class ImageVariance(CostBase):
13 | """Image Variance from Gallego et al. CVPR 2018.
14 | Args:
15 | direction (str) ... 'minimize' or 'maximize' or `natural`.
16 |             Defines the objective function. If natural, it returns a more interpretable value.
17 | """
18 |
19 | name = "image_variance"
20 | required_keys = ["iwe", "omit_boundary"]
21 |
22 | def __init__(self, direction="minimize", store_history: bool = False, *args, **kwargs):
23 | super().__init__(direction=direction, store_history=store_history)
24 |
25 | @CostBase.register_history # type: ignore
26 | @CostBase.catch_key_error # type: ignore
27 | def calculate(self, arg: dict) -> Union[float, torch.Tensor]:
28 | """Calculate contrast of the IWE.
29 | Inputs:
30 | iwe (np.ndarray or torch.Tensor) ... [W, H]. Image of warped events
31 |             omit_boundary (bool) ... Omit boundary if True.
32 |
33 | Returns:
34 | contrast (Union[float, torch.Tensor]) ... contrast of the image.
35 | """
36 | iwe = arg["iwe"]
37 | if arg["omit_boundary"]:
38 | iwe = iwe[..., 1:-1, 1:-1] # omit boundary
39 | if isinstance(iwe, torch.Tensor):
40 | return self.calculate_torch(iwe)
41 | elif isinstance(iwe, np.ndarray):
42 | return self.calculate_numpy(iwe)
43 | e = f"Unsupported input type. {type(iwe)}."
44 | logger.error(e)
45 | raise NotImplementedError(e)
46 |
47 | def calculate_torch(self, iwe: torch.Tensor) -> torch.Tensor:
48 | """Calculate contrast of the IWE.
49 | Inputs:
50 | iwe (torch.Tensor) ... [W, H]. Image of warped events
51 |
52 | Returns:
53 | loss (torch.Tensor) ... contrast of the image.
54 | """
55 | loss = torch.var(iwe)
56 | if self.direction == "minimize":
57 | return -loss
58 | return loss
59 |
60 | def calculate_numpy(self, iwe: np.ndarray) -> float:
61 | """Calculate contrast of the IWE.
62 | Inputs:
63 | iwe (np.ndarray) ... [W, H]. Image of warped events
64 |
65 | Returns:
66 | contrast (float) ... contrast of the image.
67 | """
68 | loss = np.var(iwe)
69 | if self.direction == "minimize":
70 | return -loss
71 | return loss
72 |
--------------------------------------------------------------------------------
/src/costs/multi_focal_normalized_gradient_magnitude.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Optional
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from ..types import FLOAT_TORCH
8 | from . import CostBase, NormalizedGradientMagnitude
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | class MultiFocalNormalizedGradientMagnitude(CostBase):
14 | """Multi-focus normalized gradient magnitude, Shiba et al. ECCV 2022.
15 |
16 | Args:
17 | direction (str) ... 'minimize' or 'maximize' or `natural`.
18 |             Defines the objective function. If natural, it returns a more interpretable value.
19 | """
20 |
21 | name = "multi_focal_normalized_gradient_magnitude"
22 | required_keys = ["forward_iwe", "backward_iwe", "middle_iwe", "omit_boundary", "orig_iwe"]
23 |
24 | def __init__(
25 | self,
26 | direction="minimize",
27 | store_history: bool = False,
28 | cuda_available=False,
29 | precision="32",
30 | *args,
31 | **kwargs,
32 | ):
33 | super().__init__(direction=direction, store_history=store_history)
34 | self.gradient_loss = NormalizedGradientMagnitude(
35 | direction=direction, cuda_available=cuda_available, precision=precision
36 | )
37 |
38 | @CostBase.register_history # type: ignore
39 | @CostBase.catch_key_error # type: ignore
40 | def calculate(self, arg: dict) -> FLOAT_TORCH:
41 | """Calculate cost.
42 | Inputs:
43 | orig_iwe (np.ndarray or torch.Tensor) ... Original IWE (before any warp).
44 | forward_iwe (np.ndarray or torch.Tensor) ... IWE to forward warp.
45 | backward_iwe (np.ndarray or torch.Tensor) ... IWE to backward warp.
46 | middle_iwe (Optional[np.ndarray or torch.Tensor]) ... IWE to middle warp.
47 |             omit_boundary (bool) ... Omit boundary if True.
48 |
49 | Returns:
50 |             loss (Union[float, torch.Tensor]) ... Multi-focal normalized gradient magnitude loss.
51 | """
52 | orig_iwe = arg["orig_iwe"]
53 | forward_iwe = arg["forward_iwe"]
54 | if "middle_iwe" in arg.keys():
55 | middle_iwe = arg["middle_iwe"]
56 | else:
57 | middle_iwe = None
58 | backward_iwe = arg["backward_iwe"]
59 | omit_boundary = arg["omit_boundary"]
60 |
61 | if isinstance(forward_iwe, torch.Tensor):
62 | return self.calculate_torch(
63 | orig_iwe, forward_iwe, backward_iwe, middle_iwe, omit_boundary
64 | )
65 | elif isinstance(forward_iwe, np.ndarray):
66 | return self.calculate_numpy(
67 | orig_iwe, forward_iwe, backward_iwe, middle_iwe, omit_boundary
68 | )
69 | e = f"Unsupported input type. {type(forward_iwe)}."
70 | logger.error(e)
71 | raise NotImplementedError(e)
72 |
73 | def calculate_torch(
74 | self,
75 | orig_iwe: torch.Tensor,
76 | forward_iwe: torch.Tensor,
77 | backward_iwe: torch.Tensor,
78 | middle_iwe: Optional[torch.Tensor],
79 | omit_boundary: bool,
80 | ) -> torch.Tensor:
81 | """Calculate cost for torch tensor.
82 | Inputs:
83 | orig_iwe (torch.Tensor) ... Original IWE (before warp).
84 | forward_iwe (torch.Tensor) ... IWE to forward warp.
85 | backward_iwe (torch.Tensor) ... IWE to backward warp.
86 | middle_iwe (Optional[torch.Tensor]) ... IWE to middle warp.
87 |
88 | Returns:
89 |             loss (torch.Tensor) ... Multi-focal normalized gradient magnitude loss.
90 | """
91 | forward_loss = self.gradient_loss.calculate_torch(forward_iwe, orig_iwe, omit_boundary)
92 | backward_loss = self.gradient_loss.calculate_torch(backward_iwe, orig_iwe, omit_boundary)
93 | loss = forward_loss + backward_loss
94 |
95 | if middle_iwe is not None:
96 | loss += self.gradient_loss.calculate_torch(middle_iwe, orig_iwe, omit_boundary) * 2
97 |
98 | if self.direction in ["minimize", "natural"]:
99 | return loss
100 | logger.warning("The loss is specified as maximize direction")
101 | return -loss
102 |
103 | def calculate_numpy(
104 | self,
105 | orig_iwe: np.ndarray,
106 | forward_iwe: np.ndarray,
107 | backward_iwe: np.ndarray,
108 | middle_iwe: Optional[np.ndarray],
109 | omit_boundary: bool,
110 | ) -> float:
111 | """Calculate cost for numpy array.
112 | Inputs:
113 | orig_iwe (np.ndarray) ... Original IWE (before warp).
114 | forward_iwe (np.ndarray) ... IWE to forward warp.
115 | backward_iwe (np.ndarray) ... IWE to backward warp.
116 | middle_iwe (Optional[np.ndarray]) ... IWE to middle warp.
117 |
118 | Returns:
119 |             loss (float) ... Multi-focal normalized gradient magnitude loss.
120 | """
121 | forward_loss = self.gradient_loss.calculate_numpy(forward_iwe, orig_iwe, omit_boundary)
122 | backward_loss = self.gradient_loss.calculate_numpy(backward_iwe, orig_iwe, omit_boundary)
123 | loss = forward_loss + backward_loss
124 |
125 | if middle_iwe is not None:
126 | loss += self.gradient_loss.calculate_numpy(middle_iwe, orig_iwe, omit_boundary) * 2
127 |
128 | if self.direction in ["minimize", "natural"]:
129 | return loss
130 | logger.warning("The loss is specified as maximize direction")
131 | return -loss
132 |
--------------------------------------------------------------------------------
/src/costs/multi_focal_normalized_image_variance.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Optional
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from ..types import FLOAT_TORCH
8 | from . import CostBase, NormalizedImageVariance
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | class MultiFocalNormalizedImageVariance(CostBase):
14 | """Multi-focus normalized image variance, Shiba et al. ECCV 2022.
15 |
16 | Args:
17 | direction (str) ... 'minimize' or 'maximize' or `natural`.
18 |             Defines the objective function. If natural, it returns a more interpretable value.
19 | """
20 |
21 | name = "multi_focal_normalized_image_variance"
22 | required_keys = ["forward_iwe", "backward_iwe", "middle_iwe", "omit_boundary", "orig_iwe"]
23 |
24 | def __init__(self, direction="minimize", store_history: bool = False, *args, **kwargs):
25 | super().__init__(direction=direction, store_history=store_history)
26 | self.variance_loss = NormalizedImageVariance(direction=direction)
27 |
28 | @CostBase.register_history # type: ignore
29 | @CostBase.catch_key_error # type: ignore
30 | def calculate(self, arg: dict) -> FLOAT_TORCH:
31 | """Calculate multi-focus normalized image variance.
32 | Inputs:
33 | orig_iwe (np.ndarray or torch.Tensor) ... Original IWE (before any warp).
34 | forward_iwe (np.ndarray or torch.Tensor) ... IWE to forward warp.
35 | backward_iwe (np.ndarray or torch.Tensor) ... IWE to backward warp.
36 | middle_iwe (Optional[np.ndarray or torch.Tensor]) ... IWE to middle warp.
37 |             omit_boundary (bool) ... Omit boundary if True.
38 |
39 | Returns:
40 |             loss (Union[float, torch.Tensor]) ... Multi-focal normalized image variance loss.
41 | """
42 | orig_iwe = arg["orig_iwe"]
43 | forward_iwe = arg["forward_iwe"]
44 | if "middle_iwe" in arg.keys():
45 | middle_iwe = arg["middle_iwe"]
46 | else:
47 | middle_iwe = None
48 | backward_iwe = arg["backward_iwe"]
49 | omit_boundary = arg["omit_boundary"]
50 | if omit_boundary:
51 | forward_iwe = forward_iwe[..., 1:-1, 1:-1] # omit boundary
52 | backward_iwe = backward_iwe[..., 1:-1, 1:-1] # omit boundary
53 | if middle_iwe is not None:
54 | middle_iwe = middle_iwe[..., 1:-1, 1:-1] # omit boundary
55 |
56 | if isinstance(forward_iwe, torch.Tensor):
57 | return self.calculate_torch(orig_iwe, forward_iwe, backward_iwe, middle_iwe)
58 | elif isinstance(forward_iwe, np.ndarray):
59 | return self.calculate_numpy(orig_iwe, forward_iwe, backward_iwe, middle_iwe)
60 | e = f"Unsupported input type. {type(forward_iwe)}."
61 | logger.error(e)
62 | raise NotImplementedError(e)
63 |
64 | def calculate_torch(
65 | self,
66 | orig_iwe: torch.Tensor,
67 | forward_iwe: torch.Tensor,
68 | backward_iwe: torch.Tensor,
69 | middle_iwe: Optional[torch.Tensor],
70 | ) -> torch.Tensor:
71 |         """Calculate cost for torch tensor.
72 | Inputs:
73 | orig_iwe (torch.Tensor) ... Original IWE (before warp).
74 | forward_iwe (torch.Tensor) ... IWE to forward warp.
75 | backward_iwe (torch.Tensor) ... IWE to backward warp.
76 | middle_iwe (Optional[torch.Tensor]) ... IWE to middle warp.
77 |
78 | Returns:
79 |             loss (torch.Tensor) ... Multi-focal normalized image variance loss.
80 | """
81 | forward_loss = self.variance_loss.calculate_torch(forward_iwe, orig_iwe)
82 | backward_loss = self.variance_loss.calculate_torch(backward_iwe, orig_iwe)
83 | loss = forward_loss + backward_loss
84 |
85 | if middle_iwe is not None:
86 | loss += self.variance_loss.calculate_torch(middle_iwe, orig_iwe) * 2
87 |
88 | if self.direction in ["minimize", "natural"]:
89 | return loss
90 | logger.warning("The loss is specified as maximize direction")
91 | return -loss
92 |
93 | def calculate_numpy(
94 | self,
95 | orig_iwe: np.ndarray,
96 | forward_iwe: np.ndarray,
97 | backward_iwe: np.ndarray,
98 | middle_iwe: Optional[np.ndarray],
99 | ) -> float:
100 |         """Calculate cost for numpy array.
101 | Inputs:
102 | orig_iwe (np.ndarray) ... Original IWE (before warp).
103 | forward_iwe (np.ndarray) ... IWE to forward warp.
104 | backward_iwe (np.ndarray) ... IWE to backward warp.
105 | middle_iwe (Optional[np.ndarray]) ... IWE to middle warp.
106 |
107 | Returns:
108 |             loss (float) ... Multi-focal normalized image variance loss.
109 | """
110 | forward_loss = self.variance_loss.calculate_numpy(forward_iwe, orig_iwe)
111 | backward_loss = self.variance_loss.calculate_numpy(backward_iwe, orig_iwe)
112 | loss = forward_loss + backward_loss
113 |
114 | if middle_iwe is not None:
115 | loss += self.variance_loss.calculate_numpy(middle_iwe, orig_iwe) * 2
116 |
117 | if self.direction in ["minimize", "natural"]:
118 | return loss
119 | logger.warning("The loss is specified as maximize direction")
120 | return -loss
121 |
--------------------------------------------------------------------------------
/src/costs/normalized_gradient_magnitude.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import numpy as np
4 | import torch
5 |
6 | from ..types import FLOAT_TORCH
7 | from . import CostBase, GradientMagnitude
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | class NormalizedGradientMagnitude(CostBase):
13 | """Normalized gradient magnitude.
14 |
15 | Args:
16 | direction (str) ... 'minimize' or 'maximize' or `natural`.
17 |             Defines the objective function. If natural, it returns a more interpretable value.
18 | """
19 |
20 | name = "normalized_gradient_magnitude"
21 | required_keys = ["orig_iwe", "iwe", "omit_boundary"]
22 |
23 | def __init__(
24 | self,
25 | direction="minimize",
26 | store_history: bool = False,
27 | cuda_available=False,
28 | precision="32",
29 | *args,
30 | **kwargs,
31 | ):
32 | super().__init__(direction=direction, store_history=store_history)
33 | self.gradient_magnitude = GradientMagnitude(
34 | direction=direction,
35 | store_history=store_history,
36 | cuda_available=cuda_available,
37 | precision=precision,
38 | )
39 |
40 | @CostBase.register_history # type: ignore
41 | @CostBase.catch_key_error # type: ignore
42 | def calculate(self, arg: dict) -> FLOAT_TORCH:
43 |         """Calculate normalized gradient magnitude of IWE.
44 | Inputs:
45 | iwe (np.ndarray or torch.Tensor) ... [W, H]. Image of warped events
46 | orig_iwe (np.ndarray or torch.Tensor) ... [W, H]. Image of original events
47 |             omit_boundary (bool) ... Omit boundary if True.
48 |
49 | Returns:
50 | contrast (FLOAT_TORCH) ... contrast of the image.
51 | """
52 | iwe = arg["iwe"]
53 | orig_iwe = arg["orig_iwe"]
54 | omit_boundary = arg["omit_boundary"]
55 | if isinstance(iwe, torch.Tensor):
56 | return self.calculate_torch(iwe, orig_iwe, omit_boundary)
57 | elif isinstance(iwe, np.ndarray):
58 | return self.calculate_numpy(iwe, orig_iwe, omit_boundary)
59 | e = f"Unsupported input type. {type(iwe)}."
60 | logger.error(e)
61 | raise NotImplementedError(e)
62 |
63 | def calculate_torch(
64 | self, iwe: torch.Tensor, orig_iwe: torch.Tensor, omit_boundary: bool
65 | ) -> torch.Tensor:
66 | """
67 | Inputs:
68 | iwe (torch.Tensor) ... [W, H]. Image of warped events
69 | orig_iwe (torch.Tensor) ... [W, H]. Image of original events
70 |
71 | Returns:
72 | loss (torch.Tensor) ... contrast of the image.
73 | """
74 | loss1 = self.gradient_magnitude.calculate_torch(iwe, omit_boundary)
75 | loss2 = self.gradient_magnitude.calculate_torch(orig_iwe, omit_boundary)
76 | if self.direction == "minimize":
77 | return loss2 / loss1
78 | logger.warning("The loss is specified as maximize direction")
79 | return loss1 / loss2
80 |
81 | def calculate_numpy(self, iwe: np.ndarray, orig_iwe: np.ndarray, omit_boundary: bool) -> float:
82 | """
83 | Inputs:
84 | iwe (np.ndarray) ... [W, H]. Image of warped events
85 | orig_iwe (np.ndarray) ... [W, H]. Image of original events
86 |
87 | Returns:
88 | contrast (float) ... contrast of the image.
89 | """
90 | loss1 = self.gradient_magnitude.calculate_numpy(iwe, omit_boundary)
91 | loss2 = self.gradient_magnitude.calculate_numpy(orig_iwe, omit_boundary)
92 | if self.direction == "minimize":
93 | return loss2 / loss1
94 | return loss1 / loss2
95 |
--------------------------------------------------------------------------------
/src/costs/normalized_image_variance.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Union
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from . import CostBase
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | class NormalizedImageVariance(CostBase):
13 | """Normalized image variance,
14 | a.k.a FWP (flow Warp Loss) by Stoffregen et al. ECCV 2020.
15 | Args:
16 | direction (str) ... 'minimize' or 'maximize' or `natural`.
17 |             Defines the objective function. If natural, it returns a more interpretable value.
18 | """
19 |
20 | name = "normalized_image_variance"
21 | required_keys = ["orig_iwe", "iwe", "omit_boundary"]
22 |
23 | def __init__(self, direction="minimize", store_history: bool = False, *args, **kwargs):
24 | super().__init__(direction=direction, store_history=store_history)
25 |
26 | @CostBase.register_history # type: ignore
27 | @CostBase.catch_key_error # type: ignore
28 | def calculate(self, arg: dict) -> Union[float, torch.Tensor]:
29 | """Calculate the normalized contrast of the IWE.
30 | Inputs:
31 | iwe (np.ndarray or torch.Tensor) ... [W, H]. Image of warped events
32 | orig_iwe (np.ndarray or torch.Tensor) ... [W, H]. Image of original events
33 |             omit_boundary (bool) ... Omit boundary if True.
34 |
35 | Returns:
36 | contrast (Union[float, torch.Tensor]) ... normalized contrast of the image.
37 | """
38 | iwe = arg["iwe"]
39 | orig_iwe = arg["orig_iwe"]
40 | if arg["omit_boundary"]:
41 | iwe = iwe[..., 1:-1, 1:-1] # omit boundary
42 | if isinstance(iwe, torch.Tensor):
43 | return self.calculate_torch(iwe, orig_iwe)
44 | elif isinstance(iwe, np.ndarray):
45 | return self.calculate_numpy(iwe, orig_iwe)
46 | e = f"Unsupported input type. {type(iwe)}."
47 | logger.error(e)
48 | raise NotImplementedError(e)
49 |
50 | def calculate_torch(self, iwe: torch.Tensor, orig_iwe: torch.Tensor) -> torch.Tensor:
51 | """Calculate the normalized contrast of the IWE.
52 | Inputs:
53 | iwe (torch.Tensor) ... [W, H]. Image of warped events
54 | orig_iwe (torch.Tensor) ... [W, H]. Image of original events
55 |
56 | Returns:
57 | loss (torch.Tensor) ... contrast of the image.
58 | """
59 | loss1 = torch.var(iwe)
60 | loss2 = torch.var(orig_iwe)
61 | if self.direction == "minimize":
62 | return loss2 / loss1
63 | logger.warning("The loss is specified as maximize direction")
64 | return loss1 / loss2
65 |
66 | def calculate_numpy(self, iwe: np.ndarray, orig_iwe: np.ndarray) -> float:
67 | """Calculate the normalized contrast of the IWE.
68 | Inputs:
69 | iwe (np.ndarray) ... [W, H]. Image of warped events
70 | orig_iwe (np.ndarray) ... [W, H]. Image of original events
71 |
72 | Returns:
73 | contrast (float) ... contrast of the image.
74 | """
75 | loss1 = np.var(iwe)
76 | loss2 = np.var(orig_iwe)
77 | if self.direction == "minimize":
78 | return loss2 / loss1
79 | return loss1 / loss2
80 |
--------------------------------------------------------------------------------
/src/costs/total_variation.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Union
3 |
4 | import cv2
5 | import numpy as np
6 | import torch
7 |
8 | from ..utils import SobelTorch
9 | from . import CostBase
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | class TotalVariation(CostBase):
15 | """Total Variation for regularizer
16 | Args:
17 | direction (str) ... 'minimize' or 'maximize' or `natural`.
18 | Defines the objective function. If natural, it returns more interpretable value.
19 | """
20 |
21 | name = "total_variation"
22 | required_keys = ["flow", "omit_boundary"]
23 |
24 | def __init__(
25 | self,
26 | direction="minimize",
27 | store_history: bool = False,
28 | cuda_available=False,
29 | precision="32",
30 | *args,
31 | **kwargs,
32 | ):
33 | super().__init__(direction=direction, store_history=store_history)
34 | self.torch_sobel = SobelTorch(
35 | ksize=3, in_channels=2, cuda_available=cuda_available, precision=precision
36 | )
37 |
38 | @CostBase.register_history # type: ignore
39 | @CostBase.catch_key_error # type: ignore
40 | def calculate(self, arg: dict) -> Union[float, torch.Tensor]:
41 |         """Calculate total variation of the flow.
42 | Inputs:
43 | flow (np.ndarray or torch.Tensor) ... [(b,) 2, W, H]. Flow of the image.
44 |             omit_boundary (bool) ... Omit boundary if True.
45 |
46 | Returns:
47 | (Union[float, torch.Tensor]) ... Total variation of the flow.
48 | """
49 | flow = arg["flow"]
50 | omit_boundary = arg["omit_boundary"]
51 |
52 | if isinstance(flow, torch.Tensor):
53 | return self.calculate_torch(flow, omit_boundary)
54 | elif isinstance(flow, np.ndarray):
55 | return self.calculate_numpy(flow, omit_boundary)
56 | e = f"Unsupported input type. {type(flow)}."
57 | logger.error(e)
58 | raise NotImplementedError(e)
59 |
60 | def calculate_torch(self, flow: torch.Tensor, omit_boundary: bool) -> torch.Tensor:
61 | """Calculate cost
62 | Inputs:
63 | flow (torch.Tensor) ... [(b,) 2, W, H]. Optical flow
64 |
65 | Returns:
66 | loss (torch.Tensor) ... Total Variation.
67 | """
68 | sobel = self.get_sobel_image_torch(flow, omit_boundary)
69 |
70 | loss = torch.mean(torch.abs(sobel))
71 |
72 | if self.direction == "minimize":
73 | return loss
74 | logger.warning("The loss is specified as maximize direction")
75 | return -loss
76 |
77 | def calculate_numpy(self, flow: np.ndarray, omit_boundary: bool) -> float:
78 | """Calculate cost
79 | Inputs:
80 |             flow (np.ndarray) ... [(b,) 2, W, H]. Optical flow
81 |
82 | Returns:
83 | loss (float) ... Total variation.
84 | """
85 | sobelxx, sobelxy, sobelyx, sobelyy = self.get_sobel_image_numpy(flow, omit_boundary)
86 |
87 | loss = np.mean(
88 | np.abs(sobelxx) + np.abs(sobelxy) + np.abs(sobelyx) + np.abs(sobelyy)
89 | ) # L1 version
90 |
91 | if self.direction == "minimize":
92 | return loss
93 | return -loss
94 |
95 | def visualize_sobel_image(self, sobel_image):
96 | sobel_image = np.abs(sobel_image)
97 | if len(sobel_image.shape) == 4:
98 | sobel_image = sobel_image[0]
99 | if len(sobel_image.shape) == 3:
100 | sobel_image = np.concatenate(
101 | [sobel_image[0], sobel_image[1], sobel_image[2], sobel_image[3]]
102 | )
103 | sobel_image = (
104 | (sobel_image - sobel_image.min()) / (sobel_image.max() - sobel_image.min()) * 255
105 | )
106 | self.inter_visualizer.visualize_image(
107 | sobel_image.astype(np.uint8), file_prefix="total_variation"
108 | )
109 |
110 | def get_sobel_image_torch(self, flow: torch.Tensor, omit_boundary: bool) -> torch.Tensor:
111 | """Calculate sobel of the flow.
112 | Inputs:
113 | flow (torch.Tensor) ... [(b,) 2, W, H]. Optical flow
114 |
115 | Returns:
116 | loss (torch.Tensor) ... [(b,), 4, W, H]. 4ch is
117 | [x-component dx, x-component dy, y-component dx, y-component dy]
118 | """
119 | if len(flow.shape) == 3:
120 | flow = flow[None] # 1, 2, W, H
121 | sobel = self.torch_sobel(flow) / 8.0
122 |
123 | if omit_boundary:
124 | if sobel.shape[2] > 2 and sobel.shape[3] > 2:
125 | sobel = sobel[..., 1:-1, 1:-1]
126 | return sobel
127 |
128 | def get_sobel_image_numpy(self, flow: np.ndarray, omit_boundary: bool) -> tuple:
129 | """Calculate sobel images of the flow.
130 | Inputs:
131 |             flow (np.ndarray) ... [(b,) 2, W, H]. Optical flow
132 |
133 | Returns:
134 | (tuple) ... [x-component dx, x-component dy, y-component dx, y-component dy]
135 | """
136 | if len(flow.shape) == 4:
137 | raise NotImplementedError
138 | sobelxx = cv2.Sobel(flow[0], cv2.CV_64F, 1, 0, ksize=3)
139 | sobelxy = cv2.Sobel(flow[0], cv2.CV_64F, 0, 1, ksize=3)
140 | sobelyx = cv2.Sobel(flow[1], cv2.CV_64F, 1, 0, ksize=3)
141 | sobelyy = cv2.Sobel(flow[1], cv2.CV_64F, 0, 1, ksize=3)
142 |
143 | if omit_boundary:
144 | # Only for 3-dim array
145 | if sobelxx.shape[0] > 1 and sobelxx.shape[1] > 1:
146 | sobelxx = sobelxx[1:-1, 1:-1]
147 | sobelxy = sobelxy[1:-1, 1:-1]
148 | sobelyx = sobelyx[1:-1, 1:-1]
149 | sobelyy = sobelyy[1:-1, 1:-1]
150 |
151 | return sobelxx, sobelxy, sobelyx, sobelyy
152 |
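
A minimal standalone sketch of the numpy path above — `cv2.Sobel` derivatives of each flow component, reduced to an L1 mean with the boundary cropped, mirroring `calculate_numpy()` with `omit_boundary=True`. The random flow is illustrative; nothing here is repo API beyond the cv2/numpy calls.

```python
# Minimal sketch of the numpy total-variation path (illustrative input).
import cv2
import numpy as np

flow = np.random.rand(2, 64, 64)  # [2, H, W] flow field (made-up data)

# Sobel derivatives of each flow component, as in get_sobel_image_numpy().
sobels = [
    cv2.Sobel(flow[c], cv2.CV_64F, dx, dy, ksize=3)
    for c in (0, 1)
    for dx, dy in ((1, 0), (0, 1))
]
# Crop a one-pixel boundary (omit_boundary=True) and take the L1 mean.
tv = np.mean(sum(np.abs(s[1:-1, 1:-1]) for s in sobels))
print(f"total variation = {tv:.4f}")
```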
--------------------------------------------------------------------------------
/src/data_loader/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | DATASET_ROOT_DIR = os.path.join(
4 | os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "datasets"
5 | )
6 |
7 | from .base import DataLoaderBase
8 |
9 | # from .davis_data import DavisDataLoader
10 | # from .dsec import DsecDataLoader # TODO comes later..
11 | from .mvsec import MvsecDataLoader
12 |
13 |
14 | # List of supported dataset
15 | def inheritors(klass):
16 | subclasses = set()
17 | work = [klass]
18 | while work:
19 | parent = work.pop()
20 | for child in parent.__subclasses__():
21 | if child not in subclasses:
22 | subclasses.add(child)
23 | work.append(child)
24 | return subclasses
25 |
26 |
27 | collections = {k.NAME: k for k in inheritors(DataLoaderBase)}
28 |
29 | # DNN
30 | # TODO comes later
31 |
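
The `inheritors()` helper above builds a registry of every imported `DataLoaderBase` subclass, keyed by its `NAME`. A self-contained sketch of the same pattern (class names here are illustrative, not from the repo):

```python
# Self-contained sketch of the subclass-registry pattern used above.
class Base:
    NAME = "example"

class LoaderA(Base):
    NAME = "A"

class LoaderB(LoaderA):  # indirect subclasses are discovered too
    NAME = "B"

def inheritors(klass):
    subclasses, work = set(), [klass]
    while work:
        parent = work.pop()
        for child in parent.__subclasses__():
            if child not in subclasses:
                subclasses.add(child)
                work.append(child)
    return subclasses

collections = {k.NAME: k for k in inheritors(Base)}
print(sorted(collections))  # ['A', 'B'] -- the base class itself is excluded
```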
--------------------------------------------------------------------------------
/src/data_loader/base.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 | import numpy as np
5 |
6 | from .. import utils
7 | from . import DATASET_ROOT_DIR
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | class DataLoaderBase(object):
13 | """Base of the DataLoader class.
14 | Please make sure to implement
15 | - load_event()
16 | - get_sequence()
17 |     in child classes.
18 | """
19 |
20 | NAME = "example"
21 |
22 | def __init__(self, config: dict = {}):
23 | self._HEIGHT = config["height"]
24 | self._WIDTH = config["width"]
25 |
26 | root_dir: str = config["root"] if config["root"] else DATASET_ROOT_DIR
27 | self.root_dir: str = os.path.expanduser(root_dir)
28 | data_dir: str = config["dataset"] if config["dataset"] else self.NAME
29 |
30 | self.dataset_dir: str = os.path.join(self.root_dir, data_dir)
31 | self.__dataset_files: dict = {}
32 | logger.info(f"Loading directory in {self.dataset_dir}")
33 |
34 | self.gt_flow_available: bool
35 | if utils.check_key_and_bool(config, "load_gt_flow"):
36 | self.gt_flow_dir: str = os.path.expanduser(config["gt"])
37 | self.gt_flow_available = utils.check_file_utils(self.gt_flow_dir)
38 | else:
39 | self.gt_flow_available = False
40 |
41 | if utils.check_key_and_bool(config, "undistort"):
42 | logger.info("Undistort events when load_event.")
43 | self.auto_undistort = True
44 | else:
45 | logger.info("No undistortion.")
46 | self.auto_undistort = False
47 |
48 | @property
49 | def dataset_files(self) -> dict:
50 | return self.__dataset_files
51 |
52 | @dataset_files.setter
53 | def dataset_files(self, sequence: dict):
54 | self.__dataset_files = sequence
55 |
56 | def set_sequence(self, sequence_name: str) -> None:
57 | logger.info(f"Use sequence {sequence_name}")
58 | self.sequence_name = sequence_name
59 | self.dataset_files = self.get_sequence(sequence_name)
60 |
61 | def get_sequence(self, sequence_name: str) -> dict:
62 | raise NotImplementedError
63 |
64 | def load_event(
65 | self, start_index: int, end_index: int, cam: str = "left", *args, **kwargs
66 | ) -> np.ndarray:
67 | raise NotImplementedError
68 |
69 | def load_calib(self) -> dict:
70 | raise NotImplementedError
71 |
72 | def load_optical_flow(self, t1: float, t2: float, *args, **kwargs) -> np.ndarray:
73 | raise NotImplementedError
74 |
75 | def index_to_time(self, index: int) -> float:
76 | raise NotImplementedError
77 |
78 | def time_to_index(self, time: float) -> int:
79 | raise NotImplementedError
80 |
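
As the docstring notes, a child class must implement at least `get_sequence()` and `load_event()`. A hypothetical minimal subclass, sketched under the assumption that the repo root is on `PYTHONPATH` (the file layout and key names are illustrative):

```python
# Hypothetical DataLoaderBase child -- a sketch, not a repo class.
import os

import numpy as np

from src.data_loader import DataLoaderBase


class MyDataLoader(DataLoaderBase):
    NAME = "MYDATA"  # key used by the loader registry in __init__.py

    def get_sequence(self, sequence_name: str) -> dict:
        # Map a sequence name to its files; the .npy layout is made up.
        return {"event": os.path.join(self.dataset_dir, sequence_name + ".npy")}

    def load_event(self, start_index: int, end_index: int, cam: str = "left") -> np.ndarray:
        # Return an [n_events, 4] array of (x, y, t, p).
        events = np.load(self.dataset_files["event"])
        return events[start_index:end_index]
```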
--------------------------------------------------------------------------------
/src/data_loader/mvsec.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from typing import Tuple
4 |
5 | import h5py
6 | import numpy as np
7 |
8 | from ..utils import estimate_corresponding_gt_flow, undistort_events
9 | from . import DataLoaderBase
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | # hdf5 data loader
15 | def h5py_loader(path: str):
16 | """Basic loader for .hdf5 files.
17 | Args:
18 | path (str) ... Path to the .hdf5 file.
19 |
20 | Returns:
21 |         timestamp (dict) ... Dictionary of numpy arrays. Keys are "left" / "right".
22 | davis_left (dict) ... "event": np.ndarray.
23 | davis_right (dict) ... "event": np.ndarray.
24 | """
25 | data = h5py.File(path, "r")
26 | event_timestamp = get_timestamp_index(data)
27 | r = {
28 | "event": np.array(data["davis"]["right"]["events"], dtype=np.int16),
29 | }
30 | # 'gray_ts': np.array(data['davis']['right']['image_raw_ts'], dtype=np.float64)
31 | l = {
32 | "event": np.array(data["davis"]["left"]["events"], dtype=np.int16),
33 | "gray_ts": np.array(data["davis"]["left"]["image_raw_ts"], dtype=np.float64),
34 | }
35 | data.close()
36 | return event_timestamp, l, r
37 |
38 |
39 | def get_timestamp_index(h5py_data):
40 | """Timestampm loader for pre-fetching before actual sensor data loading.
41 | This is necessary for sync between sensors and decide which timestamp to
42 | be used as ground clock.
43 | Args:
44 | h5py_data... h5py object.
45 | Returns:
46 |         timestamp (dict) ... Dictionary of numpy arrays. Keys are "left" / "right".
47 | """
48 | timestamp = {}
49 | timestamp["right"] = np.array(h5py_data["davis"]["right"]["events"][:, 2])
50 | timestamp["left"] = np.array(h5py_data["davis"]["left"]["events"][:, 2])
51 | return timestamp
52 |
53 |
54 | class MvsecDataLoader(DataLoaderBase):
55 | """Dataloader class for MVSEC dataset."""
56 |
57 | NAME = "MVSEC"
58 |
59 | def __init__(self, config: dict = {}):
60 | super().__init__(config)
61 |
62 | # Override
63 | def set_sequence(self, sequence_name: str, undistort: bool = False) -> None:
64 | logger.info(f"Use sequence {sequence_name}")
65 | self.sequence_name = sequence_name
66 | logger.info(f"Undistort events = {undistort}")
67 |
68 | self.dataset_files = self.get_sequence(sequence_name)
69 | ts, l_event, r_event = h5py_loader(self.dataset_files["event"])
70 | self.left_event = l_event["event"] # int16 .. for smaller memory consumption.
71 | self.left_ts = ts["left"] # float64
72 | self.left_gray_ts = l_event["gray_ts"] # float64
73 | # self.right_event = r_event["event"]
74 | # self.right_ts = ts["right"]
75 | # self.right_gray_ts = r_event["gray_ts"] # float64
76 |
77 | # Setup GT
78 | if self.gt_flow_available:
79 | self.setup_gt_flow(os.path.join(self.gt_flow_dir, sequence_name))
80 | self.omit_invalid_data(sequence_name)
81 |
82 | # Undistort - most likely necessary to run evaluation with GT.
83 | self.undistort = undistort
84 | if self.undistort:
85 | self.calib_map_x, self.calib_map_y = self.get_calib_map(
86 | self.dataset_files["calib_map_x"], self.dataset_files["calib_map_y"]
87 | )
88 |
89 |         # Setting up time duration statistics
90 | self.min_ts = self.left_ts.min()
91 | self.max_ts = self.left_ts.max()
92 | # self.min_ts = np.max([self.left_ts.min(), self.right_ts.min()])
93 | # self.max_ts = np.min([self.left_ts.max(), self.right_ts.max()]) - 10.0 # not use last 1 sec
94 | self.data_duration = self.max_ts - self.min_ts
95 |
96 | def get_sequence(self, sequence_name: str) -> dict:
97 | """Get data inside a sequence.
98 |
99 | Inputs:
100 |             sequence_name (str) ... name of the sequence. ex) `outdoor_day2`.
101 |
102 | Returns:
103 | sequence_file (dict) ... dictionary of the filenames for the sequence.
104 | """
105 | data_path: str = os.path.join(self.root_dir, sequence_name)
106 | event_file = data_path + "_data.hdf5"
107 | calib_file_x = data_path[:-1] + "_left_x_map.txt"
108 | calib_file_y = data_path[:-1] + "_left_y_map.txt"
109 | sequence_file = {
110 | "event": event_file,
111 | "calib_map_x": calib_file_x,
112 | "calib_map_y": calib_file_y,
113 | }
114 | return sequence_file
115 |
116 | def setup_gt_flow(self, path):
117 | path = path + "_gt_flow_dist.npz"
118 | logger.info(f"Loading ground truth flow {path}")
119 | gt = np.load(path)
120 | self.gt_timestamps = gt["timestamps"]
121 | self.U_gt_all = gt["x_flow_dist"]
122 | self.V_gt_all = gt["y_flow_dist"]
123 |
124 | def free_up_flow(self):
125 | del self.gt_timestamps, self.U_gt_all, self.V_gt_all
126 |
127 | def omit_invalid_data(self, sequence_name: str):
128 |         logger.info("Use only valid frames.")
129 | first_valid_gt_frame = 0
130 | last_valid_gt_frame = -1
131 | if "indoor_flying1" in sequence_name:
132 | first_valid_gt_frame = 60
133 | last_valid_gt_frame = 1340
134 | elif "indoor_flying2" in sequence_name:
135 | first_valid_gt_frame = 140
136 | last_valid_gt_frame = 1500
137 | elif "indoor_flying3" in sequence_name:
138 | first_valid_gt_frame = 100
139 | last_valid_gt_frame = 1711
140 | elif "indoor_flying4" in sequence_name:
141 | first_valid_gt_frame = 104
142 | last_valid_gt_frame = 380
143 | elif "outdoor_day1" in sequence_name:
144 | last_valid_gt_frame = 5020
145 | elif "outdoor_day2" in sequence_name:
146 | first_valid_gt_frame = 30
147 | # last_valid_gt_frame = 5020
148 |
149 | self.gt_timestamps = self.gt_timestamps[first_valid_gt_frame:last_valid_gt_frame]
150 | self.U_gt_all = self.U_gt_all[first_valid_gt_frame:last_valid_gt_frame]
151 | self.V_gt_all = self.V_gt_all[first_valid_gt_frame:last_valid_gt_frame]
152 |
153 | # Update event list
154 | first_event_index = self.time_to_index(self.gt_timestamps[0])
155 | last_event_index = self.time_to_index(self.gt_timestamps[-1])
156 | self.left_event = self.left_event[first_event_index:last_event_index]
157 | self.left_ts = self.left_ts[first_event_index:last_event_index]
158 |
159 | self.min_ts = self.left_ts.min()
160 | self.max_ts = self.left_ts.max()
161 |
162 | # Update gray frame ts
163 | self.left_gray_ts = self.left_gray_ts[
164 | (self.gt_timestamps[0] < self.left_gray_ts)
165 | & (self.gt_timestamps[-1] > self.left_gray_ts)
166 | ]
167 |
168 | # self.right_event = self.right_event[first_event_index:last_event_index]
169 | # self.right_ts = self.right_ts[first_event_index:last_event_index]
170 | # self.right_gray_ts = self.right_gray_ts[
171 | # (self.gt_timestamps[0] < self.right_gray_ts)
172 | # & (self.gt_timestamps[-1] > self.right_gray_ts)
173 | # ]
174 |
175 | def __len__(self):
176 | return len(self.left_event)
177 |
178 | def load_event(self, start_index: int, end_index: int, cam: str = "left") -> np.ndarray:
179 | """Load events.
180 | The original hdf5 file contains (x, y, t, p),
181 |         where x is the width direction and y is the height direction. p is -1 or 1.
182 |
183 | Returns:
184 |             events (np.ndarray) ... Events. [x, y, t, p] where x is the height direction.
185 |                 t is absolute time, in sec. p is -1 or 1.
186 | """
187 | n_events = end_index - start_index
188 | events = np.zeros((n_events, 4), dtype=np.float64)
189 |
190 | if cam == "left":
191 | if len(self.left_event) <= start_index:
192 | logger.error(
193 | f"Specified {start_index} to {end_index} index for {len(self.left_event)}."
194 | )
195 | raise IndexError
196 | events[:, 0] = self.left_event[start_index:end_index, 1]
197 | events[:, 1] = self.left_event[start_index:end_index, 0]
198 | events[:, 2] = self.left_ts[start_index:end_index]
199 | events[:, 3] = self.left_event[start_index:end_index, 3]
200 | elif cam == "right":
201 |             logger.error("Please select `left` as the `cam` parameter.")
202 | raise NotImplementedError
203 | if self.undistort:
204 | events = undistort_events(
205 | events, self.calib_map_x, self.calib_map_y, self._HEIGHT, self._WIDTH
206 | )
207 | return events
208 |
209 | # Optical flow (GT)
210 | def gt_time_list(self):
211 | return self.gt_timestamps
212 |
213 | def eval_frame_time_list(self):
214 | # In MVSEC, evaluation is based on gray frame timestamp.
215 | return self.left_gray_ts
216 |
217 | def index_to_time(self, index: int) -> float:
218 | return self.left_ts[index]
219 |
220 | def time_to_index(self, time: float) -> int:
221 | # inds = np.where(self.left_ts > time)[0]
222 | # if len(inds) == 0:
223 | # return len(self.left_ts) - 1
224 | # return inds[0] - 1
225 | ind = np.searchsorted(self.left_ts, time)
226 | return ind - 1
227 |
228 | def get_gt_time(self, index: int) -> tuple:
229 | """Get GT flow timestamp [floor, ceil] for a given index.
230 |
231 | Args:
232 | index (int): Index of the event
233 |
234 | Returns:
235 | tuple: [floor_gt, ceil_gt]. Both are synced with GT optical flow.
236 | """
237 | inds = np.where(self.gt_timestamps > self.index_to_time(index))[0]
238 | if len(inds) == 0:
239 | return (self.gt_timestamps[-1], None)
240 | elif len(inds) == len(self.gt_timestamps):
241 | return (None, self.gt_timestamps[0])
242 | else:
243 | return (self.gt_timestamps[inds[0] - 1], self.gt_timestamps[inds[0]])
244 |
245 | def load_optical_flow(self, t1: float, t2: float) -> np.ndarray:
246 | """Load GT Optical flow based on timestamp.
247 | Note: this is pixel displacement.
248 | Note: the args are not indices, but timestamps.
249 |
250 | Args:
251 | t1 (float): [description]
252 | t2 (float): [description]
253 |
254 | Returns:
255 | [np.ndarray]: H x W x 2. Be careful that the 2 ch is [height, width] direction component.
256 | """
257 | U_gt, V_gt = estimate_corresponding_gt_flow(
258 | self.U_gt_all,
259 | self.V_gt_all,
260 | self.gt_timestamps,
261 | t1,
262 | t2,
263 | )
264 | gt_flow = np.stack((V_gt, U_gt), axis=2)
265 | return gt_flow
266 |
267 | def load_calib(self) -> dict:
268 | """Load calibration file.
269 |
270 | Outputs:
271 | (dict) ... {"K": camera_matrix, "D": distortion_coeff}
272 | camera_matrix (np.ndarray) ... [3 x 3] matrix.
273 | distortion_coeff (np.array) ... [5] array.
274 | """
275 |         logger.warning("Directly loading calib_param is not implemented! Please use the rectification maps instead.")
276 | outdoor_K = np.array(
277 | [
278 | [223.9940010790056, 0, 170.7684322973841, 0],
279 | [0, 223.61783486959376, 128.18711828338436, 0],
280 | [0, 0, 1, 0],
281 | [0, 0, 0, 1],
282 | ],
283 | dtype=np.float32,
284 | )
285 | return {"K": outdoor_K}
286 |
287 | def get_calib_map(self, map_txt_x, map_txt_y):
288 | """Intrinsic calibration parameter file loader.
289 | Args:
290 |             map_txt ... file path.
291 |         Returns:
292 |             map_array (np.array) ... map array.
293 | """
294 | map_x = self.load_map_txt(map_txt_x)
295 | map_y = self.load_map_txt(map_txt_y)
296 | return map_x, map_y
297 |
298 | def load_map_txt(self, map_txt):
299 | f = open(map_txt, "r")
300 | line = f.readlines()
301 | map_array = np.zeros((self._HEIGHT, self._WIDTH))
302 | for i, l in enumerate(line):
303 | map_array[i] = np.array([float(k) for k in l.split()])
304 | return map_array
305 |
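
`time_to_index()` above is a thin wrapper around `np.searchsorted` on the sorted event timestamps. A quick standalone check of the floor-index behavior it implements (the timestamps are toy data):

```python
# Standalone check of the searchsorted-based time_to_index() logic above.
import numpy as np

left_ts = np.array([0.0, 0.1, 0.2, 0.3, 0.4])  # sorted event timestamps (toy data)

def time_to_index(time: float) -> int:
    return int(np.searchsorted(left_ts, time)) - 1

print(time_to_index(0.25))  # 2 -> last event at or before t=0.25
print(time_to_index(0.2))   # 1 -> exact hits map to the previous index (left-biased)
```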
--------------------------------------------------------------------------------
/src/event_image_converter.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Optional, Tuple, Union
3 |
4 | import numpy as np
5 | import torch
6 | from scipy.ndimage.filters import gaussian_filter
7 | from torchvision.transforms.functional import gaussian_blur
8 |
9 | from .types import FLOAT_TORCH, NUMPY_TORCH, is_numpy, is_torch
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | class EventImageConverter(object):
15 | """Converter class of image into many different representations as an image.
16 |
17 | Args:
18 | image_size (tuple)... (H, W)
19 | outer_padding (int, or tuple) ... Padding to outer to the conversion. This tries to
20 | avoid events go out of the image.
21 | """
22 |
23 | def __init__(self, image_size: tuple, outer_padding: Union[int, Tuple[int, int]] = 0):
24 | if isinstance(outer_padding, (int, float)):
25 | self.outer_padding = (int(outer_padding), int(outer_padding))
26 | else:
27 | self.outer_padding = outer_padding
28 | self.image_size = tuple(int(i + p * 2) for i, p in zip(image_size, self.outer_padding))
29 |
30 | def update_property(
31 | self,
32 | image_size: Optional[tuple] = None,
33 | outer_padding: Optional[Union[int, Tuple[int, int]]] = None,
34 | ):
35 | if image_size is not None:
36 | self.image_size = image_size
37 | if outer_padding is not None:
38 | if isinstance(outer_padding, int):
39 | self.outer_padding = (outer_padding, outer_padding)
40 | else:
41 | self.outer_padding = outer_padding
42 |             self.image_size = tuple(int(i + p * 2) for i, p in zip(self.image_size, self.outer_padding))  # pad both sides, as in __init__
43 |
44 | # Higher layer functions
45 | def create_iwe(
46 | self,
47 | events: NUMPY_TORCH,
48 | method: str = "bilinear_vote",
49 | sigma: int = 1,
50 | ) -> NUMPY_TORCH:
51 | """Create Image of Warped Events (IWE).
52 |
53 | Args:
54 | events (NUMPY_TORCH): [(b,) n_events, 4]
55 | method (str): [description]
56 | sigma (float): [description]
57 |
58 | Returns:
59 | NUMPY_TORCH: [(b,) H, W]
60 | """
61 | if is_numpy(events):
62 | return self.create_image_from_events_numpy(events, method, sigma=sigma)
63 | elif is_torch(events):
64 | return self.create_image_from_events_tensor(events, method, sigma=sigma)
65 | e = f"Non-supported type of events. {type(events)}"
66 | logger.error(e)
67 | raise RuntimeError(e)
68 |
69 | def create_eventmask(self, events: NUMPY_TORCH) -> NUMPY_TORCH:
70 | """Create mask image where at least one event exists.
71 |
72 | Args:
73 | events (NUMPY_TORCH): [(b,) n_events, 4]
74 |
75 | Returns:
76 | NUMPY_TORCH: [(b,) 1, H, W]
77 | """
78 | if is_numpy(events):
79 | return (0 != self.create_image_from_events_numpy(events, sigma=0))[..., None, :, :]
80 | elif is_torch(events):
81 | return (0 != self.create_image_from_events_tensor(events, sigma=0))[..., None, :, :]
82 | raise RuntimeError
83 |
84 | # Lower layer functions
85 | # Image creation functions
86 | def create_image_from_events_numpy(
87 | self,
88 | events: np.ndarray,
89 | method: str = "bilinear_vote",
90 | weight: Union[float, np.ndarray] = 1.0,
91 | sigma: int = 1,
92 | ) -> np.ndarray:
93 | """Create image of events for numpy array.
94 |
95 | Inputs:
96 |             events (np.ndarray) ... [(b,) n_events, 4] Batch of events. 4 is (x, y, t, p). Note that (x, y) can be float.
97 |                 Also, x is the height dimension and y is the width dimension.
98 | method (str) ... method to accumulate events. "count", "bilinear_vote", "polarity", etc.
99 | weight (float or np.ndarray) ... Only applicable when method = "bilinear_vote".
100 | sigma (int) ... Sigma for the gaussian blur.
101 |
102 | Returns:
103 |             image ... [(b,) H, W]. Each pixel contains the sum of events, based on the specified method.
104 | """
105 | if method == "count":
106 | image = self.count_event_numpy(events)
107 | elif method == "bilinear_vote":
108 | image = self.bilinear_vote_numpy(events, weight=weight)
109 | elif method == "polarity": # TODO implement
110 | pos_flag = events[..., 3] > 0
111 | if is_numpy(weight):
112 | pos_image = self.bilinear_vote_numpy(events[pos_flag], weight=weight[pos_flag])
113 | neg_image = self.bilinear_vote_numpy(events[~pos_flag], weight=weight[~pos_flag])
114 | else:
115 | pos_image = self.bilinear_vote_numpy(events[pos_flag], weight=weight)
116 | neg_image = self.bilinear_vote_numpy(events[~pos_flag], weight=weight)
117 | image = np.stack([pos_image, neg_image], axis=-3)
118 | else:
119 | e = f"{method = } is not supported."
120 | logger.error(e)
121 | raise NotImplementedError(e)
122 | if sigma > 0:
123 | image = gaussian_filter(image, sigma)
124 | return image
125 |
126 | def create_image_from_events_tensor(
127 | self,
128 | events: torch.Tensor,
129 | method: str = "bilinear_vote",
130 | weight: FLOAT_TORCH = 1.0,
131 | sigma: int = 0,
132 | ) -> torch.Tensor:
133 | """Create image of events for tensor array.
134 |
135 | Inputs:
136 |             events (torch.Tensor) ... [(b, ) n_events, 4] Batch of events. 4 is (x, y, t, p). Note that (x, y) can be float.
137 |                 Also, x is the height dimension and y is the width dimension.
138 | method (str) ... method to accumulate events. "count", "bilinear_vote", "polarity", etc.
139 | weight (float or torch.Tensor) ... Only applicable when method = "bilinear_vote".
140 | sigma (int) ... Sigma for the gaussian blur.
141 |
142 | Returns:
143 |             image ... [(b, ) H, W]. Each pixel contains the sum of events, based on the specified method.
144 | """
145 | if method == "count":
146 | image = self.count_event_tensor(events)
147 | elif method == "bilinear_vote":
148 | image = self.bilinear_vote_tensor(events, weight=weight)
149 | else:
150 | e = f"{method = } is not implemented"
151 | logger.error(e)
152 | raise NotImplementedError(e)
153 | if sigma > 0:
154 | if len(image.shape) == 2:
155 | image = image[None, None, ...]
156 | elif len(image.shape) == 3:
157 | image = image[:, None, ...]
158 | image = gaussian_blur(image, kernel_size=3, sigma=sigma)
159 | return torch.squeeze(image)
160 |
161 | def count_event_numpy(self, events: np.ndarray):
162 | """Count event and make image.
163 |
164 | Args:
165 |             events ... [(b,) n_events, 4] Batch of events. 4 is (x, y, t, p). Note that (x, y) can be float.
166 | 
167 |         Returns:
168 |             image ... [(b,) H, W]. Each pixel contains the event count.
169 | """
170 | if len(events.shape) == 2:
171 | events = events[None, ...] # 1 x n x 4
172 |
173 | # x-y is height-width
174 | ph, pw = self.outer_padding
175 | h, w = self.image_size
176 | nb = len(events)
177 | image = np.zeros((nb, h * w), dtype=np.float64)
178 |
179 | floor_xy = np.floor(events[..., :2] + 1e-8)
180 | floor_to_xy = events[..., :2] - floor_xy
181 |
182 | x1 = floor_xy[..., 1] + pw
183 | y1 = floor_xy[..., 0] + ph
184 | inds = np.concatenate(
185 | [
186 | x1 + y1 * w,
187 | x1 + (y1 + 1) * w,
188 | (x1 + 1) + y1 * w,
189 | (x1 + 1) + (y1 + 1) * w,
190 | ],
191 | axis=-1,
192 | )
193 | inds_mask = np.concatenate(
194 | [
195 | (0 <= x1) * (x1 < w) * (0 <= y1) * (y1 < h),
196 | (0 <= x1) * (x1 < w) * (0 <= y1 + 1) * (y1 + 1 < h),
197 | (0 <= x1 + 1) * (x1 + 1 < w) * (0 <= y1) * (y1 < h),
198 | (0 <= x1 + 1) * (x1 + 1 < w) * (0 <= y1 + 1) * (y1 + 1 < h),
199 | ],
200 | axis=-1,
201 | )
202 | vals = np.ones_like(inds)
203 | inds = (inds * inds_mask).astype(np.int64)
204 | vals = vals * inds_mask
205 | for i in range(nb):
206 | np.add.at(image[i], inds[i], vals[i])
207 | return image.reshape((nb,) + self.image_size).squeeze()
208 |
209 | def count_event_tensor(self, events: torch.Tensor):
210 | """Tensor version of `count_event_numpy().`
211 |
212 | Args:
213 | events (torch.Tensor) ... [(b,) n_events, 4] Batch of events. 4 is (x, y, t, p). Attention that (x, y) could float.
214 |
215 | Returns:
216 | image ... [(b,) H, W]. Each index indicates the bilinear vote result. If the outer_padding is set,
217 | the return size will be [H + outer_padding, W + outer_padding].
218 | """
219 | if len(events.shape) == 2:
220 | events = events[None, ...] # 1 x n x 4
221 |
222 | ph, pw = self.outer_padding
223 | h, w = self.image_size
224 | nb = len(events)
225 | image = events.new_zeros((nb, h * w))
226 |
227 | floor_xy = torch.floor(events[..., :2] + 1e-6)
228 | floor_to_xy = events[..., :2] - floor_xy
229 | floor_xy = floor_xy.long()
230 |
231 | x1 = floor_xy[..., 1] + pw
232 | y1 = floor_xy[..., 0] + ph
233 | inds = torch.cat(
234 | [
235 | x1 + y1 * w,
236 | x1 + (y1 + 1) * w,
237 | (x1 + 1) + y1 * w,
238 | (x1 + 1) + (y1 + 1) * w,
239 | ],
240 | dim=-1,
241 | ) # [(b, ) n_events x 4]
242 | inds_mask = torch.cat(
243 | [
244 | (0 <= x1) * (x1 < w) * (0 <= y1) * (y1 < h),
245 | (0 <= x1) * (x1 < w) * (0 <= y1 + 1) * (y1 + 1 < h),
246 | (0 <= x1 + 1) * (x1 + 1 < w) * (0 <= y1) * (y1 < h),
247 | (0 <= x1 + 1) * (x1 + 1 < w) * (0 <= y1 + 1) * (y1 + 1 < h),
248 | ],
249 | axis=-1,
250 | )
251 | vals = torch.ones_like(inds)
252 | inds = (inds * inds_mask).long()
253 | vals = vals * inds_mask
254 | image.scatter_add_(1, inds, vals)
255 | return image.reshape((nb,) + self.image_size).squeeze()
256 |
257 | def bilinear_vote_numpy(self, events: np.ndarray, weight: Union[float, np.ndarray] = 1.0):
258 | """Use bilinear voting to and make image.
259 |
260 | Args:
261 | events (np.ndarray) ... [(b, ) n_events, 4] Batch of events. 4 is (x, y, t, p). Attention that (x, y) could float.
262 | weight (float or np.ndarray) ... Weight to multiply to the voting value.
263 | If scalar, the weight is all the same among events.
264 | If it's array-like, it should be the shape of [n_events].
265 | Defaults to 1.0.
266 |
267 | Returns:
268 | image ... [(b, ) H, W]. Each index indicates the bilinear vote result. If the outer_padding is set,
269 | the return size will be [H + outer_padding, W + outer_padding].
270 | """
271 | if type(weight) == np.ndarray:
272 | assert weight.shape == events.shape[:-1]
273 | if len(events.shape) == 2:
274 | events = events[None, ...] # 1 x n x 4
275 |
276 | # x-y is height-width
277 | ph, pw = self.outer_padding
278 | h, w = self.image_size
279 | nb = len(events)
280 | image = np.zeros((nb, h * w), dtype=np.float64)
281 |
282 | floor_xy = np.floor(events[..., :2] + 1e-8)
283 | floor_to_xy = events[..., :2] - floor_xy
284 |
285 | x1 = floor_xy[..., 1] + pw
286 | y1 = floor_xy[..., 0] + ph
287 | inds = np.concatenate(
288 | [
289 | x1 + y1 * w,
290 | x1 + (y1 + 1) * w,
291 | (x1 + 1) + y1 * w,
292 | (x1 + 1) + (y1 + 1) * w,
293 | ],
294 | axis=-1,
295 | )
296 | inds_mask = np.concatenate(
297 | [
298 | (0 <= x1) * (x1 < w) * (0 <= y1) * (y1 < h),
299 | (0 <= x1) * (x1 < w) * (0 <= y1 + 1) * (y1 + 1 < h),
300 | (0 <= x1 + 1) * (x1 + 1 < w) * (0 <= y1) * (y1 < h),
301 | (0 <= x1 + 1) * (x1 + 1 < w) * (0 <= y1 + 1) * (y1 + 1 < h),
302 | ],
303 | axis=-1,
304 | )
305 | w_pos0 = (1 - floor_to_xy[..., 0]) * (1 - floor_to_xy[..., 1]) * weight
306 | w_pos1 = floor_to_xy[..., 0] * (1 - floor_to_xy[..., 1]) * weight
307 | w_pos2 = (1 - floor_to_xy[..., 0]) * floor_to_xy[..., 1] * weight
308 | w_pos3 = floor_to_xy[..., 0] * floor_to_xy[..., 1] * weight
309 | vals = np.concatenate([w_pos0, w_pos1, w_pos2, w_pos3], axis=-1)
310 | inds = (inds * inds_mask).astype(np.int64)
311 | vals = vals * inds_mask
312 | for i in range(nb):
313 | np.add.at(image[i], inds[i], vals[i])
314 | return image.reshape((nb,) + self.image_size).squeeze()
315 |
316 | def bilinear_vote_tensor(self, events: torch.Tensor, weight: FLOAT_TORCH = 1.0):
317 | """Tensor version of `bilinear_vote_numpy().`
318 |
319 | Args:
320 | events (torch.Tensor) ... [(b,) n_events, 4] Batch of events. 4 is (x, y, t, p). Attention that (x, y) could float.
321 | weight (float or torch.Tensor) ... Weight to multiply to the voting value.
322 | If scalar, the weight is all the same among events.
323 | If it's array-like, it should be the shape of [(b,) n_events].
324 | Defaults to 1.0.
325 |
326 | Returns:
327 | image ... [(b,) H, W]. Each index indicates the bilinear vote result. If the outer_padding is set,
328 | the return size will be [H + outer_padding, W + outer_padding].
329 | """
330 | if type(weight) == torch.Tensor:
331 | assert weight.shape == events.shape[:-1]
332 | if len(events.shape) == 2:
333 | events = events[None, ...] # 1 x n x 4
334 |
335 | ph, pw = self.outer_padding
336 | h, w = self.image_size
337 | nb = len(events)
338 | image = events.new_zeros((nb, h * w))
339 |
340 | floor_xy = torch.floor(events[..., :2] + 1e-6)
341 | floor_to_xy = events[..., :2] - floor_xy
342 | floor_xy = floor_xy.long()
343 |
344 | x1 = floor_xy[..., 1] + pw
345 | y1 = floor_xy[..., 0] + ph
346 | inds = torch.cat(
347 | [
348 | x1 + y1 * w,
349 | x1 + (y1 + 1) * w,
350 | (x1 + 1) + y1 * w,
351 | (x1 + 1) + (y1 + 1) * w,
352 | ],
353 | dim=-1,
354 | ) # [(b, ) n_events x 4]
355 | inds_mask = torch.cat(
356 | [
357 | (0 <= x1) * (x1 < w) * (0 <= y1) * (y1 < h),
358 | (0 <= x1) * (x1 < w) * (0 <= y1 + 1) * (y1 + 1 < h),
359 | (0 <= x1 + 1) * (x1 + 1 < w) * (0 <= y1) * (y1 < h),
360 | (0 <= x1 + 1) * (x1 + 1 < w) * (0 <= y1 + 1) * (y1 + 1 < h),
361 | ],
362 | axis=-1,
363 | )
364 |
365 | w_pos0 = (1 - floor_to_xy[..., 0]) * (1 - floor_to_xy[..., 1]) * weight
366 | w_pos1 = floor_to_xy[..., 0] * (1 - floor_to_xy[..., 1]) * weight
367 | w_pos2 = (1 - floor_to_xy[..., 0]) * floor_to_xy[..., 1] * weight
368 | w_pos3 = floor_to_xy[..., 0] * floor_to_xy[..., 1] * weight
369 | vals = torch.cat([w_pos0, w_pos1, w_pos2, w_pos3], dim=-1) # [(b,) n_events x 4]
370 |
371 | inds = (inds * inds_mask).long()
372 | vals = vals * inds_mask
373 | image.scatter_add_(1, inds, vals)
374 | return image.reshape((nb,) + self.image_size).squeeze()
375 |
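
The bilinear vote above distributes each event's unit weight over its four neighboring integer pixels according to the fractional part of its coordinates. A tiny standalone illustration of that weighting, independent of the class (the coordinate values are made up):

```python
# Standalone illustration of bilinear voting for a single event.
import numpy as np

h, w = 4, 4
image = np.zeros((h, w))
x, y = 1.25, 2.75  # event coordinate; x is the height index, y the width index

x0, y0 = int(np.floor(x)), int(np.floor(y))
fx, fy = x - x0, y - y0  # fractional offsets

# The unit weight is shared over the 4 neighbors; the weights sum to 1.
image[x0, y0] += (1 - fx) * (1 - fy)
image[x0, y0 + 1] += (1 - fx) * fy
image[x0 + 1, y0] += fx * (1 - fy)
image[x0 + 1, y0 + 1] += fx * fy
print(image)
```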
--------------------------------------------------------------------------------
/src/feature_calculator.py:
--------------------------------------------------------------------------------
1 | # Mock code
2 | # Feature calculation is not necessary
3 | import logging
4 |
5 | logger = logging.getLogger(__name__)
6 |
7 |
8 | class FeatureCalculatorMock:
9 | def __init__(self, *args, **kwargs):
10 | """Mock class -- please ignore."""
11 | logger.warning("Feature calculation is disabled in this source code.")
12 | pass
13 |
14 | def skip(self):
15 | feature = {
16 | "none": {"per_event": True, "value": None},
17 | }
18 | return feature
19 |
20 | def calculate_feature(self, *args, skip: bool = False, **kwargs) -> dict:
21 | """Mock function."""
22 | return self.skip()
23 |
--------------------------------------------------------------------------------
/src/solver/__init__.py:
--------------------------------------------------------------------------------
1 | """isort:skip_file
2 | """
3 |
4 | # Non DNN
5 | from .base import SolverBase
6 |
7 | # from .contrast_maximization import ContrastMaximization
8 | from .patch_contrast_mixed import MixedPatchContrastMaximization
9 | from .time_aware_patch_contrast import TimeAwarePatchContrastMaximization
10 | from .patch_contrast_pyramid import PyramidalPatchContrastMaximization
11 |
12 |
13 | # List of supported solver - non DNN
14 | collections = {
15 | # CMax and variants
16 | # "contrast_maximization": ContrastMaximization,
17 | "pyramidal_patch_contrast_maximization": PyramidalPatchContrastMaximization,
18 | "time_aware_mixed_patch_contrast_maximization": TimeAwarePatchContrastMaximization,
19 | }
20 |
--------------------------------------------------------------------------------
/src/solver/nnmodels/__init__.py:
--------------------------------------------------------------------------------
1 | from .ev_flownet import EVFlowNet
2 |
--------------------------------------------------------------------------------
/src/solver/nnmodels/basic_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class build_resnet_block(nn.Module):
6 | """
7 | a resnet block which includes two general_conv2d
8 | """
9 |
10 | def __init__(self, channels, layers=2, do_batch_norm=False):
11 | super().__init__()
12 | self._channels = channels
13 | self._layers = layers
14 |
15 | self.res_block = nn.Sequential(
16 | *[
17 | general_conv2d(
18 | in_channels=self._channels,
19 | out_channels=self._channels,
20 | strides=1,
21 | do_batch_norm=do_batch_norm,
22 | )
23 | for i in range(self._layers)
24 | ]
25 | )
26 |
27 | def forward(self, input_res):
28 | inputs = input_res.clone()
29 | input_res = self.res_block(input_res)
30 | return input_res + inputs
31 |
32 |
33 | class upsample_conv2d_and_predict_flow_or_depth(nn.Module):
34 | """
35 | an upsample convolution layer which includes a nearest interpolate and a general_conv2d
36 | """
37 |
38 | def __init__(
39 | self, in_channels, out_channels, ksize=3, do_batch_norm=False, type="flow", scale=256.0
40 | ):
41 | super().__init__()
42 | self._in_channels = in_channels
43 | self._out_channels = out_channels
44 | self._ksize = ksize
45 | self._do_batch_norm = do_batch_norm
46 | self._flow_or_depth = type
47 | self._scale = scale
48 |
49 | self.general_conv2d = general_conv2d(
50 | in_channels=self._in_channels,
51 | out_channels=self._out_channels,
52 | ksize=self._ksize,
53 | strides=1,
54 | do_batch_norm=self._do_batch_norm,
55 | padding=0,
56 | )
57 |
58 | self.pad = nn.ReflectionPad2d(
59 | padding=(
60 | int((self._ksize - 1) / 2),
61 | int((self._ksize - 1) / 2),
62 | int((self._ksize - 1) / 2),
63 | int((self._ksize - 1) / 2),
64 | )
65 | )
66 |
67 | if self._flow_or_depth == "flow":
68 | self.predict = general_conv2d(
69 | in_channels=self._out_channels,
70 | out_channels=2,
71 | ksize=1,
72 | strides=1,
73 | padding=0,
74 | activation="tanh",
75 | )
76 | elif self._flow_or_depth == "depth":
77 | self.predict = general_conv2d(
78 | in_channels=self._out_channels,
79 | out_channels=1,
80 | ksize=1,
81 | strides=1,
82 | padding=0,
83 | activation="sigmoid",
84 | )
85 | else:
86 | raise NotImplementedError("flow or depth?")
87 |
88 | def forward(self, conv):
89 | """
90 | Returns:
91 | feature
92 | pred (tensor) ... [N, ch, H, W]; 2 ch (flow) or 1 ch (depth).
93 | """
94 | shape = conv.shape
95 | conv = nn.functional.interpolate(
96 | # conv, size=[shape[2] * 2, shape[3] * 2], mode="nearest"
97 | conv,
98 | size=[shape[2] * 2, shape[3] * 2],
99 | mode="bilinear",
100 | )
101 | conv = self.pad(conv)
102 | conv = self.general_conv2d(conv)
103 | pred = self.predict(conv) * self._scale
104 | return torch.cat([conv, pred.clone()], dim=1), pred
105 |
106 |
107 | def general_conv2d(
108 | in_channels, out_channels, ksize=3, strides=2, padding=1, do_batch_norm=False, activation="relu"
109 | ):
110 | """
111 |     a general convolution layer which includes a conv2d, an activation (relu / tanh / sigmoid), and an optional batch norm
112 | """
113 | if activation == "relu":
114 | if do_batch_norm:
115 | conv2d = nn.Sequential(
116 | nn.Conv2d(
117 | in_channels=in_channels,
118 | out_channels=out_channels,
119 | kernel_size=ksize,
120 | stride=strides,
121 | padding=padding,
122 | ),
123 | nn.ReLU(inplace=True),
124 | nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.99),
125 | )
126 | else:
127 | conv2d = nn.Sequential(
128 | nn.Conv2d(
129 | in_channels=in_channels,
130 | out_channels=out_channels,
131 | kernel_size=ksize,
132 | stride=strides,
133 | padding=padding,
134 | ),
135 | nn.ReLU(inplace=True),
136 | )
137 | elif activation == "tanh":
138 | if do_batch_norm:
139 | conv2d = nn.Sequential(
140 | nn.Conv2d(
141 | in_channels=in_channels,
142 | out_channels=out_channels,
143 | kernel_size=ksize,
144 | stride=strides,
145 | padding=padding,
146 | ),
147 | nn.Tanh(),
148 | nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.99),
149 | )
150 | else:
151 | conv2d = nn.Sequential(
152 | nn.Conv2d(
153 | in_channels=in_channels,
154 | out_channels=out_channels,
155 | kernel_size=ksize,
156 | stride=strides,
157 | padding=padding,
158 | ),
159 | nn.Tanh(),
160 | )
161 | elif activation == "sigmoid":
162 | if do_batch_norm:
163 | conv2d = nn.Sequential(
164 | nn.Conv2d(
165 | in_channels=in_channels,
166 | out_channels=out_channels,
167 | kernel_size=ksize,
168 | stride=strides,
169 | padding=padding,
170 | ),
171 | nn.Sigmoid(),
172 | nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.99),
173 | )
174 | else:
175 | conv2d = nn.Sequential(
176 | nn.Conv2d(
177 | in_channels=in_channels,
178 | out_channels=out_channels,
179 | kernel_size=ksize,
180 | stride=strides,
181 | padding=padding,
182 | ),
183 | nn.Sigmoid(),
184 | )
185 | return conv2d
186 |
187 |
188 | import torch
189 | import torch.nn as nn
190 | import torch.nn.functional as f
191 |
192 |
193 | class ConvLayer(nn.Module):
194 | """
195 | Convolutional layer.
196 | Default: bias, ReLU, no downsampling, no batch norm.
197 | """
198 |
199 | def __init__(
200 | self,
201 | in_channels,
202 | out_channels,
203 | kernel_size,
204 | stride=1,
205 | padding=0,
206 | activation="relu",
207 | norm=None,
208 | BN_momentum=0.1,
209 | ):
210 | super(ConvLayer, self).__init__()
211 |
212 | bias = False if norm == "BN" else True
213 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
214 | if activation is not None:
215 | self.activation = getattr(torch, activation)
216 | else:
217 | self.activation = None
218 |
219 | self.norm = norm
220 | if norm == "BN":
221 | self.norm_layer = nn.BatchNorm2d(out_channels, momentum=BN_momentum)
222 | elif norm == "IN":
223 | self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True)
224 |
225 | def forward(self, x):
226 | out = self.conv2d(x)
227 |
228 | if self.norm in ["BN", "IN"]:
229 | out = self.norm_layer(out)
230 |
231 | if self.activation is not None:
232 | out = self.activation(out)
233 |
234 | return out
235 |
236 |
237 | class TransposedConvLayer(nn.Module):
238 | """
239 | Transposed convolutional layer to increase spatial resolution (x2) in a decoder.
240 | Default: bias, ReLU, no downsampling, no batch norm.
241 | """
242 |
243 | def __init__(
244 | self,
245 | in_channels,
246 | out_channels,
247 | kernel_size,
248 | padding=0,
249 | activation="relu",
250 | norm=None,
251 | ):
252 | super(TransposedConvLayer, self).__init__()
253 |
254 | bias = False if norm == "BN" else True
255 | self.transposed_conv2d = nn.ConvTranspose2d(
256 | in_channels,
257 | out_channels,
258 | kernel_size,
259 | stride=2,
260 | padding=padding,
261 | output_padding=1,
262 | bias=bias,
263 | )
264 |
265 | if activation is not None:
266 | self.activation = getattr(torch, activation)
267 | else:
268 | self.activation = None
269 |
270 | self.norm = norm
271 | if norm == "BN":
272 | self.norm_layer = nn.BatchNorm2d(out_channels)
273 | elif norm == "IN":
274 | self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True)
275 |
276 | def forward(self, x):
277 | out = self.transposed_conv2d(x)
278 |
279 | if self.norm in ["BN", "IN"]:
280 | out = self.norm_layer(out)
281 |
282 | if self.activation is not None:
283 | out = self.activation(out)
284 |
285 | return out
286 |
287 |
288 | class UpsampleConvLayer(nn.Module):
289 | """
290 | Upsampling layer (bilinear interpolation + Conv2d) to increase spatial resolution (x2) in a decoder.
291 | Default: bias, ReLU, no downsampling, no batch norm.
292 | """
293 |
294 | def __init__(
295 | self,
296 | in_channels,
297 | out_channels,
298 | kernel_size,
299 | stride=1,
300 | padding=0,
301 | activation="relu",
302 | norm=None,
303 | ):
304 | super(UpsampleConvLayer, self).__init__()
305 |
306 | bias = False if norm == "BN" else True
307 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
308 |
309 | if activation is not None:
310 | self.activation = getattr(torch, activation)
311 | else:
312 | self.activation = None
313 |
314 | self.norm = norm
315 | if norm == "BN":
316 | self.norm_layer = nn.BatchNorm2d(out_channels)
317 | elif norm == "IN":
318 | self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True)
319 |
320 | def forward(self, x):
321 | x_upsampled = f.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
322 | out = self.conv2d(x_upsampled)
323 |
324 | if self.norm in ["BN", "IN"]:
325 | out = self.norm_layer(out)
326 |
327 | if self.activation is not None:
328 | out = self.activation(out)
329 |
330 | return out
331 |
332 |
333 | class ResidualBlock(nn.Module):
334 | """
335 | Residual block as in "Deep residual learning for image recognition", He et al. 2016.
336 | Default: bias, ReLU, no downsampling, no batch norm, ConvLSTM.
337 | """
338 |
339 | def __init__(
340 | self,
341 | in_channels,
342 | out_channels,
343 | stride=1,
344 | downsample=None,
345 | norm=None,
346 | BN_momentum=0.1,
347 | ):
348 | super(ResidualBlock, self).__init__()
349 | bias = False if norm == "BN" else True
350 | self.conv1 = nn.Conv2d(
351 | in_channels,
352 | out_channels,
353 | kernel_size=3,
354 | stride=stride,
355 | padding=1,
356 | bias=bias,
357 | )
358 | self.norm = norm
359 | if norm == "BN":
360 | self.bn1 = nn.BatchNorm2d(out_channels, momentum=BN_momentum)
361 | self.bn2 = nn.BatchNorm2d(out_channels, momentum=BN_momentum)
362 | elif norm == "IN":
363 | self.bn1 = nn.InstanceNorm2d(out_channels)
364 | self.bn2 = nn.InstanceNorm2d(out_channels)
365 |
366 | self.relu = nn.ReLU(inplace=True)
367 | self.conv2 = nn.Conv2d(
368 | out_channels,
369 | out_channels,
370 | kernel_size=3,
371 | stride=1,
372 | padding=1,
373 | bias=bias,
374 | )
375 | self.downsample = downsample
376 |
377 | def forward(self, x):
378 | residual = x
379 | out = self.conv1(x)
380 | if self.norm in ["BN", "IN"]:
381 | out = self.bn1(out)
382 | out = self.relu(out)
383 | out = self.conv2(out)
384 | if self.norm in ["BN", "IN"]:
385 | out = self.bn2(out)
386 |
387 | if self.downsample:
388 | residual = self.downsample(x)
389 |
390 | out += residual
391 | out = self.relu(out)
392 | return out
393 |
394 |
395 | def skip_concat(x1, x2):
396 | diffY = x2.size()[2] - x1.size()[2]
397 | diffX = x2.size()[3] - x1.size()[3]
398 | padding = nn.ZeroPad2d((diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2))
399 | x1 = padding(x1)
400 | return torch.cat([x1, x2], dim=1)
401 |
402 |
403 | def skip_sum(x1, x2):
404 | diffY = x2.size()[2] - x1.size()[2]
405 | diffX = x2.size()[3] - x1.size()[3]
406 | padding = nn.ZeroPad2d((diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2))
407 | x1 = padding(x1)
408 | return x1 + x2
409 |
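
A quick shape check for the blocks above, assuming the repo root is on `PYTHONPATH`: `general_conv2d` halves the spatial resolution with its default stride of 2, while `build_resnet_block` preserves it.

```python
# Shape check for the building blocks above (repo root on PYTHONPATH assumed).
import torch

from src.solver.nnmodels import basic_layers

x = torch.randn(1, 16, 32, 32)
down = basic_layers.general_conv2d(in_channels=16, out_channels=32)  # stride 2
res = basic_layers.build_resnet_block(channels=32)

y = down(x)
print(y.shape)       # torch.Size([1, 32, 16, 16]) -- spatial size halved
print(res(y).shape)  # torch.Size([1, 32, 16, 16]) -- preserved by the residual block
```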
--------------------------------------------------------------------------------
/src/solver/nnmodels/ev_flownet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from ... import utils
5 | from . import basic_layers
6 |
7 | _BASE_CHANNELS = 64
8 |
9 |
10 | class EVFlowNet(nn.Module):
11 | """EV-FlowNet definition
12 | Code is obtained from https://github.com/CyrilSterling/EVFlowNet-pytorch
13 | Thanks to the author @CyrilSterling (and @alexzhu for the original paper!)
14 | """
15 | def __init__(self, nn_param: dict = {}):
16 | super().__init__()
17 | self.no_batch_norm = nn_param["no_batch_norm"]
18 |
19 | # Parameters for event voxel input.
20 | self.n_channel = nn_param["n_bin"] if "n_bin" in nn_param.keys() else 4
21 | self.scale_bin_time = nn_param["scale_time"] if "scale_time" in nn_param.keys() else 128.0
22 |
23 | self.encoder1 = basic_layers.general_conv2d(
24 | in_channels=self.n_channel,
25 | out_channels=_BASE_CHANNELS,
26 | do_batch_norm=not self.no_batch_norm,
27 | )
28 | self.encoder2 = basic_layers.general_conv2d(
29 | in_channels=_BASE_CHANNELS,
30 | out_channels=2 * _BASE_CHANNELS,
31 | do_batch_norm=not self.no_batch_norm,
32 | )
33 | self.encoder3 = basic_layers.general_conv2d(
34 | in_channels=2 * _BASE_CHANNELS,
35 | out_channels=4 * _BASE_CHANNELS,
36 | do_batch_norm=not self.no_batch_norm,
37 | )
38 | self.encoder4 = basic_layers.general_conv2d(
39 | in_channels=4 * _BASE_CHANNELS,
40 | out_channels=8 * _BASE_CHANNELS,
41 | do_batch_norm=not self.no_batch_norm,
42 | )
43 |
44 | self.resnet_block = nn.Sequential(
45 | *[
46 | basic_layers.build_resnet_block(
47 | 8 * _BASE_CHANNELS, do_batch_norm=not self.no_batch_norm
48 | )
49 | for i in range(2)
50 | ]
51 | )
52 |
53 | self.decoder1 = basic_layers.upsample_conv2d_and_predict_flow_or_depth(
54 | in_channels=16 * _BASE_CHANNELS,
55 | out_channels=4 * _BASE_CHANNELS,
56 | do_batch_norm=not self.no_batch_norm,
57 | type="flow",
58 | scale=self.scale_bin_time,
59 | )
60 |
61 | self.decoder2 = basic_layers.upsample_conv2d_and_predict_flow_or_depth(
62 | in_channels=8 * _BASE_CHANNELS + 2,
63 | out_channels=2 * _BASE_CHANNELS,
64 | do_batch_norm=not self.no_batch_norm,
65 | type="flow",
66 | scale=self.scale_bin_time,
67 | )
68 |
69 | self.decoder3 = basic_layers.upsample_conv2d_and_predict_flow_or_depth(
70 | in_channels=4 * _BASE_CHANNELS + 2,
71 | out_channels=_BASE_CHANNELS,
72 | do_batch_norm=not self.no_batch_norm,
73 | type="flow",
74 | scale=self.scale_bin_time,
75 | )
76 |
77 | self.decoder4 = basic_layers.upsample_conv2d_and_predict_flow_or_depth(
78 | in_channels=2 * _BASE_CHANNELS + 2,
79 | out_channels=int(_BASE_CHANNELS / 2),
80 | do_batch_norm=not self.no_batch_norm,
81 | type="flow",
82 | scale=self.scale_bin_time,
83 | )
84 |
85 | @utils.profile(
86 | output_file="dnn_forward.prof", sort_by="cumulative", lines_to_print=300, strip_dirs=True
87 | )
88 | def forward(self, inputs: torch.Tensor) -> dict:
89 | """
90 | Args:
91 | inputs (torch.Tensor) ... [n_batch, n_bin, height, width]
92 |
93 | Returns
94 |             flow_dict (dict) ... "flow3": [n_batch, 2, height, width]
95 |                 "flow0": [n_batch, 2, height // 2**3, width // 2**3]
96 | """
97 | # encoder
98 | skip_connections = {}
99 | inputs = self.encoder1(inputs)
100 | skip_connections["skip0"] = inputs.clone()
101 | inputs = self.encoder2(inputs)
102 | skip_connections["skip1"] = inputs.clone()
103 | inputs = self.encoder3(inputs)
104 | skip_connections["skip2"] = inputs.clone()
105 | inputs = self.encoder4(inputs)
106 | skip_connections["skip3"] = inputs.clone()
107 |
108 | # transition
109 | inputs = self.resnet_block(inputs)
110 |
111 | # decoder
112 | flow_dict = {}
113 | inputs = torch.cat([inputs, skip_connections["skip3"]], dim=1)
114 | inputs, flow = self.decoder1(inputs)
115 | flow_dict["flow0"] = flow.clone()
116 |
117 | inputs = torch.cat([inputs, skip_connections["skip2"]], dim=1)
118 | inputs, flow = self.decoder2(inputs)
119 | flow_dict["flow1"] = flow.clone()
120 |
121 | inputs = torch.cat([inputs, skip_connections["skip1"]], dim=1)
122 | inputs, flow = self.decoder3(inputs)
123 | flow_dict["flow2"] = flow.clone()
124 |
125 | inputs = torch.cat([inputs, skip_connections["skip0"]], dim=1)
126 | inputs, flow = self.decoder4(inputs)
127 | flow_dict["flow3"] = flow.clone()
128 |
129 | return flow_dict
130 |
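
A minimal forward-pass sketch, assuming the repo root is on `PYTHONPATH`; the `nn_param` keys follow the constructor above. With four stride-2 encoders, the input height/width should be divisible by 16; `flow0` then comes out at 1/8 resolution and `flow3` at full resolution, each with 2 channels.

```python
# Forward-pass sketch for EVFlowNet (repo root on PYTHONPATH assumed).
import torch

from src.solver.nnmodels import EVFlowNet

net = EVFlowNet({"no_batch_norm": False, "n_bin": 4, "scale_time": 128.0})
voxel = torch.randn(1, 4, 256, 256)  # [n_batch, n_bin, H, W]; H, W divisible by 16

with torch.no_grad():
    flows = net(voxel)
print(flows["flow0"].shape)  # torch.Size([1, 2, 32, 32])   -- 1/8 resolution
print(flows["flow3"].shape)  # torch.Size([1, 2, 256, 256]) -- full resolution
```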
--------------------------------------------------------------------------------
/src/solver/patch_contrast_mixed.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import shutil
3 | from typing import Any, Dict, List, Optional, Tuple
4 |
5 | import numpy as np
6 | import scipy
7 | import torch
8 |
9 | from .. import utils, visualizer
10 | from . import scipy_autograd
11 | from .base import SCIPY_OPTIMIZERS
12 | from .patch_contrast_base import PatchContrastMaximization
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | class MixedPatchContrastMaximization(PatchContrastMaximization):
18 | """Mixed patch-based CMax.
19 |
20 | Params:
21 | image_shape (tuple) ... (H, W)
22 | calibration_parameter (dict) ... dictionary of the calibration parameter
23 | solver_config (dict) ... solver configuration
24 | optimizer_config (dict) ... optimizer configuration
25 | visualize_module ... visualizer.Visualizer
26 | """
27 |
28 | def __init__(
29 | self,
30 | image_shape: tuple,
31 | calibration_parameter: dict,
32 | solver_config: dict = {},
33 | optimizer_config: dict = {},
34 | output_config: dict = {},
35 | visualize_module: Optional[visualizer.Visualizer] = None,
36 | ):
37 | super().__init__(
38 | image_shape,
39 | calibration_parameter,
40 | solver_config,
41 | optimizer_config,
42 | output_config,
43 | visualize_module,
44 | )
45 | self.set_patch_size_and_sliding_window()
46 | self.patches, self.patch_image_size = self.prepare_patch(
47 | image_shape, self.patch_size, self.sliding_window
48 | )
49 | self.n_patch = len(self.patches.keys())
50 |
51 | # internal variable
52 | self._patch_motion_model_keys = [
53 | f"patch{i}_{k}" for i in range(self.n_patch) for k in self.motion_model_keys
54 | ]
55 |
56 | def optimize(self, events: np.ndarray) -> np.ndarray:
57 | """Run optimization.
58 |
59 | Inputs:
60 | events (np.ndarray) ... [n_events x 4] event array. Should be (x, y, t, p).
61 | 
62 |
63 | """
64 | # Preprocessings
65 | logger.info("Start optimization.")
66 | logger.info(f"DoF is {self.motion_vector_size * self.n_patch}")
67 |
68 | if self.opt_method == "optuna":
69 | opt_result = self.run_optuna(events)
70 |             logger.info("End optimization.")
71 | best_motion = self.get_motion_array_optuna(opt_result.best_params)
72 | elif self.opt_method in SCIPY_OPTIMIZERS:
73 | opt_result = self.run_scipy(events)
74 | logger.info(f"End optimization.\n Best parameters: {opt_result}")
75 | best_motion = opt_result.x.reshape(
76 | ((self.motion_vector_size,) + self.patch_image_size)
77 | ) # / 1000
78 |
79 | logger.info("Profile file saved.")
80 | if self.visualizer:
81 | shutil.copy("optimize.prof", self.visualizer.save_dir)
82 | if self.opt_method in SCIPY_OPTIMIZERS:
83 | self.visualizer.visualize_scipy_history(
84 | self.cost_func.get_history(), self.cost_weight
85 | )
86 |
87 | logger.info(f"{best_motion}")
88 | return best_motion
89 |
90 | # Optuna functions
91 | def objective(self, trial, events: np.ndarray):
92 | # Parameters setting
93 | params = {k: self.sampling(trial, k) for k in self._patch_motion_model_keys}
94 | motion_array = self.get_motion_array_optuna(params) # 2 x H x W
95 | if self.normalize_t_in_batch:
96 | t_scale = np.max(events[:, 2]) - np.min(events[:, 2])
97 | motion_array *= t_scale
98 | dense_flow = self.motion_to_dense_flow(motion_array)
99 |
100 | loss = self.calculate_cost(events, dense_flow, self.motion_model_for_dense_warp)
101 | logger.info(f"{trial.number = } / {loss = }")
102 | return loss
103 |
104 | def sampling(self, trial, key: str):
105 | """Sampling function for mixed type patch solution.
106 |
107 | Args:
108 | trial ([type]): [description]
109 | key (str): [description]
110 |
111 | Returns:
112 | [type]: [description]
113 | """
114 | key_suffix = key[key.find("_") + 1 :]
115 | return trial.suggest_uniform(
116 | key,
117 | self.opt_config["parameters"][key_suffix]["min"],
118 | self.opt_config["parameters"][key_suffix]["max"],
119 | )
120 |
121 | def get_motion_array_optuna(self, params: dict) -> np.ndarray:
122 | # Returns [n_patch x n_motion_paremter]
123 | motion_array = np.zeros((self.motion_vector_size, self.n_patch))
124 | for i in range(self.n_patch):
125 | param = {k: params[f"patch{i}_{k}"] for k in self.motion_model_keys}
126 | motion_array[:, i] = self.motion_model_to_motion(param)
127 | return motion_array.reshape((self.motion_vector_size,) + self.patch_image_size)
128 |
129 | # Scipy
130 | @utils.profile(
131 | output_file="optimize.prof", sort_by="cumulative", lines_to_print=300, strip_dirs=True
132 | )
133 | def run_scipy(self, events: np.ndarray) -> scipy.optimize.OptimizeResult:
134 | if self.previous_frame_best_estimation is not None:
135 | motion0 = np.copy(self.previous_frame_best_estimation)
136 | else:
137 | # Initialize with various methods
138 | if self.slv_config["patch"]["initialize"] == "random":
139 | motion0 = self.initialize_random()
140 | elif self.slv_config["patch"]["initialize"] == "zero":
141 | motion0 = self.initialize_zeros()
142 | elif self.slv_config["patch"]["initialize"] == "global-best":
143 | logger.info("sampling initialization")
144 | best_guess = self.initialize_guess_from_whole_image(events)
145 | if isinstance(best_guess, torch.Tensor):
146 | motion0 = torch.tile(best_guess[None], (self.n_patch, 1)).T.reshape(-1)
147 | elif isinstance(best_guess, np.ndarray):
148 | motion0 = np.tile(best_guess[None], (self.n_patch, 1)).T.reshape(-1)
149 | elif self.slv_config["patch"]["initialize"] == "grid-best":
150 | logger.info("sampling initialization")
151 | best_guess = self.initialize_guess_from_patch(
152 | events, patch_index=self.n_patch // 2 - 1
153 | )
154 | if isinstance(best_guess, torch.Tensor):
155 | motion0 = torch.tile(best_guess[None], (self.n_patch, 1)).T.reshape(-1)
156 | elif isinstance(best_guess, np.ndarray):
157 | motion0 = np.tile(best_guess[None], (self.n_patch, 1)).T.reshape(-1)
158 | # motion0 += (
159 | # np.random.rand(self.motion_vector_size * self.n_patch).astype(np.double64) * 10 - 5
160 | # )
161 | elif self.slv_config["patch"]["initialize"] == "optuna-sampling":
162 | logger.info("Optuna intelligent sampling initialization")
163 | motion0 = self.initialize_guess_from_optuna_sampling(events)
164 | self.cost_func.clear_history()
165 |
166 | self.events = torch.from_numpy(events).double().requires_grad_().to(self._device)
167 | result = scipy_autograd.minimize(
168 | self.objective_scipy,
169 | motion0,
170 | method=self.opt_method,
171 | options={
172 | "gtol": 1e-7,
173 | "disp": True,
174 | "maxiter": self.opt_config["max_iter"],
175 | "eps": 0.01,
176 | },
177 | precision="float64",
178 | torch_device=self._device,
179 | # TODO support bounds
180 | # bounds=[(-300, 300), (-300, 300)]
181 | )
182 | return result
183 |
184 | def objective_scipy(self, motion_array: np.ndarray, suppress_log: bool = False):
185 | """
186 | Args:
187 | motion_array (np.ndarray): [2 * n_patches] array
188 |
189 | Returns:
190 | [type]: [description]
191 | """
192 | if self.normalize_t_in_batch:
193 | t_scale = self.events[:, 2].max() - self.events[:, 2].min()
194 | else:
195 | t_scale = 1.0
196 |
197 | events = self.events.clone()
198 | dense_flow = self.motion_to_dense_flow(motion_array * t_scale)
199 |
200 | loss = self.calculate_cost(
201 | events,
202 | dense_flow,
203 | self.motion_model_for_dense_warp,
204 | motion_array.reshape((self.motion_vector_size,) + self.patch_image_size),
205 | )
206 | if not suppress_log:
207 | logger.info(f"{loss = }")
208 | return loss
209 |
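
When `normalize_t_in_batch` is set, `objective_scipy()` above multiplies the motion parameters by the batch's time span, so the optimizer always works in a time-normalized unit regardless of the slice duration. A toy illustration of that rescaling (all values made up):

```python
# Toy illustration of the time normalization in objective_scipy() (made-up values).
import numpy as np

events_t = np.array([10.00, 10.02, 10.05])  # event timestamps in one batch [s]
motion = np.array([1.5, -0.5])              # px per *normalized* batch time

t_scale = events_t.max() - events_t.min()   # 0.05 s
flow_px = motion * t_scale                  # pixel displacement over this batch
print(flow_px)                              # [ 0.075 -0.025]
```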
--------------------------------------------------------------------------------
/src/solver/scipy_autograd/README.md:
--------------------------------------------------------------------------------
1 |
2 | Code comes from
3 | https://github.com/brunorigal/autograd-minimize
4 |
5 |
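
The wrapper in this package follows the autograd-minimize idea: evaluate the objective in torch, backpropagate for an exact gradient, and hand both to `scipy.optimize.minimize` via `jac=True`. A self-contained sketch of that core pattern on a Rosenbrock toy problem (not the repo's wrapper itself):

```python
# Self-contained sketch of the torch-autograd + scipy.optimize pattern (toy problem).
import numpy as np
import scipy.optimize as sopt
import torch

def rosenbrock(x: torch.Tensor) -> torch.Tensor:
    return (1 - x[0]) ** 2 + 100 * (x[1] - x[0] ** 2) ** 2

def value_and_grad(x_np: np.ndarray):
    x = torch.tensor(x_np, dtype=torch.float64, requires_grad=True)
    loss = rosenbrock(x)
    loss.backward()  # autograd supplies the exact gradient for scipy
    return loss.item(), x.grad.numpy()

res = sopt.minimize(value_and_grad, x0=np.zeros(2), jac=True, method="L-BFGS-B")
print(res.x)  # close to [1., 1.]
```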
--------------------------------------------------------------------------------
/src/solver/scipy_autograd/__init__.py:
--------------------------------------------------------------------------------
1 | from .scipy_minimize import minimize
2 |
--------------------------------------------------------------------------------
/src/solver/scipy_autograd/base_wrapper.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import numpy as np
4 | import scipy.optimize as sopt
5 | import torch
6 |
7 |
8 | class BaseWrapper(ABC):
9 | def get_input(self, input_var):
10 | self.input_type = type(input_var)
11 | assert self.input_type in [
12 | dict,
13 | list,
14 | np.ndarray,
15 | torch.Tensor,
16 |         ], "The initial input to your optimized function should be one of dict, list, np.ndarray, or torch.Tensor"
17 | input_, self.shapes = self._concat(input_var)
18 | self.var_num = input_.shape[0]
19 | return input_
20 |
21 | def get_output(self, output_var):
22 | assert "shapes" in dir(self), "You must first call get input to define the tensors shapes."
23 | output_var_ = self._unconcat(output_var, self.shapes)
24 | return output_var_
25 |
26 | def get_bounds(self, bounds):
27 |
28 | if bounds is not None:
29 | if isinstance(bounds, tuple) and not (
30 | isinstance(bounds[0], tuple) or isinstance(bounds[0], sopt.Bounds)
31 | ):
32 | assert len(bounds) == 2
33 | new_bounds = [bounds] * self.var_num
34 |
35 | elif isinstance(bounds, sopt.Bounds):
36 | new_bounds = [bounds] * self.var_num
37 |
38 | elif type(bounds) in [list, tuple, np.ndarray]:
39 | if self.input_type in [list, tuple]:
40 | assert len(self.shapes) == len(bounds)
41 | new_bounds = []
42 | for sh, bounds_ in zip(self.shapes, bounds):
43 | new_bounds += format_bounds(bounds_, sh)
44 | elif self.input_type in [np.ndarray]:
45 | new_bounds = bounds
46 | elif self.input_type in [torch.Tensor]:
47 | new_bounds = bounds.detach().cpu().numpy()
48 |
49 | elif isinstance(bounds, dict):
50 | assert self.input_type == dict
51 | assert set(bounds.keys()).issubset(self.shapes.keys())
52 |
53 | new_bounds = []
54 | for k in self.shapes.keys():
55 | if k in bounds.keys():
56 | new_bounds += format_bounds(bounds[k], self.shapes[k])
57 | else:
58 |                         new_bounds += [(None, None)] * np.prod(self.shapes[k], dtype=np.int32)
59 | else:
60 | new_bounds = bounds
61 | return new_bounds
62 |
63 | def get_constraints(self, constraints, method):
64 | if constraints is not None and not isinstance(constraints, sopt.LinearConstraint):
65 | assert isinstance(constraints, dict)
66 | assert "fun" in constraints.keys()
67 | self.ctr_func = constraints["fun"]
68 | use_autograd = constraints.get("use_autograd", True)
69 | if method in ["trust-constr"]:
70 |
71 | constraints = sopt.NonlinearConstraint(
72 | self._eval_ctr_func,
73 | lb=constraints.get("lb", -np.inf),
74 | ub=constraints.get("ub", np.inf),
75 | jac=self.get_ctr_jac if use_autograd else "2-point",
76 | keep_feasible=constraints.get("keep_feasible", False),
77 | )
78 | elif method in ["COBYLA", "SLSQP"]:
79 | constraints = {
80 | "type": constraints.get("type", "eq"),
81 | "fun": self._eval_ctr_func,
82 | }
83 | if use_autograd:
84 | constraints["jac"] = self.get_ctr_jac
85 | else:
86 | raise NotImplementedError
87 | elif constraints is None:
88 | constraints = ()
89 | return constraints
90 |
91 | @abstractmethod
92 | def get_value_and_grad(self, input_var):
93 | return
94 |
95 | @abstractmethod
96 | def get_hvp(self, input_var, vector):
97 | return
98 |
99 | @abstractmethod
100 | def get_hess(self, input_var):
101 | return
102 |
103 | def _eval_func(self, input_var):
104 | if isinstance(input_var, dict):
105 | loss = self.func(**input_var)
106 | elif isinstance(input_var, list) or isinstance(input_var, tuple):
107 | loss = self.func(*input_var)
108 | else:
109 | loss = self.func(input_var)
110 | return loss
111 |
112 | def _eval_ctr_func(self, input_var):
113 | input_var = self._unconcat(input_var, self.shapes)
114 | if isinstance(input_var, dict):
115 | ctr_val = self.ctr_func(**input_var)
116 | elif isinstance(input_var, list) or isinstance(input_var, tuple):
117 | ctr_val = self.ctr_func(*input_var)
118 | else:
119 | ctr_val = self.ctr_func(input_var)
120 | return ctr_val
121 |
122 | @abstractmethod
123 | def get_ctr_jac(self, input_var):
124 | return
125 |
126 | def _concat(self, ten_vals):
127 | ten = []
128 | if isinstance(ten_vals, dict):
129 | shapes = {}
130 | for k, t in ten_vals.items():
131 | if t is not None:
132 | if isinstance(t, (np.floating, float, int)):
133 | t = np.array(t)
134 | shapes[k] = t.shape
135 | ten.append(self._reshape(t, [-1]))
136 | ten = self._tconcat(ten, 0)
137 |
138 | elif isinstance(ten_vals, list) or isinstance(ten_vals, tuple):
139 | shapes = []
140 | for t in ten_vals:
141 | if t is not None:
142 | if isinstance(t, (np.floating, float, int)):
143 | t = np.array(t)
144 | shapes.append(t.shape)
145 | ten.append(self._reshape(t, [-1]))
146 | ten = self._tconcat(ten, 0)
147 |
148 | elif isinstance(ten_vals, (np.floating, float, int)):
149 | ten_vals = np.array(ten_vals)
150 | shapes = np.array(ten_vals).shape
151 | ten = self._reshape(np.array(ten_vals), [-1])
152 | elif isinstance(ten_vals, torch.Tensor):
153 | ten_vals = ten_vals.detach().cpu()
154 | shapes = ten_vals.shape
155 | ten = self._reshape(ten_vals, [-1])
156 | else:
158 | shapes = ten_vals.shape
159 | ten = self._reshape(ten_vals, [-1])
160 | return ten, shapes
161 |
162 | def _unconcat(self, ten, shapes):
163 | current_ind = 0
164 | if isinstance(shapes, dict):
165 | ten_vals = {}
166 | for k, sh in shapes.items():
167 | next_ind = current_ind + np.prod(sh, dtype=np.int32)
168 | ten_vals[k] = self._reshape(self._gather(ten, current_ind, next_ind), sh)
169 |
170 | current_ind = next_ind
171 |
172 | elif isinstance(shapes, list) or isinstance(shapes, tuple):
173 | if isinstance(shapes[0], int):
174 | ten_vals = self._reshape(ten, shapes)
175 | else:
176 | ten_vals = []
177 | for sh in shapes:
178 | next_ind = current_ind + np.prod(sh, dtype=np.int32)
179 | ten_vals.append(self._reshape(self._gather(ten, current_ind, next_ind), sh))
180 |
181 | current_ind = next_ind
182 |
183 | elif shapes is None:
184 | ten_vals = ten
185 |
186 | return ten_vals
187 |
188 | @abstractmethod
189 | def _reshape(self, t, sh):
190 | return
191 |
192 | @abstractmethod
193 | def _tconcat(self, t_list, dim=0):
194 | return
195 |
196 | @abstractmethod
197 | def _gather(self, t, i, j):
198 | return
199 |
200 |
201 | def format_bounds(bounds_, sh):
202 | if isinstance(bounds_, tuple):
203 | assert len(bounds_) == 2
204 | return [bounds_] * np.prod(sh, dtype=np.int32)
205 | elif isinstance(bounds_, sopt.Bounds):
206 | return [bounds_] * np.prod(sh, dtype=np.int32)
207 | elif isinstance(bounds_, list):
208 | assert np.prod(sh) == len(bounds_)
209 | return bounds_
210 | elif isinstance(bounds_, np.ndarray):
211 | assert np.prod(sh) == np.prod(np.array(bounds_).shape)
212 | return np.concatenate(np.reshape(bounds_, -1)).tolist()
213 | else:
214 | raise TypeError
215 |
--------------------------------------------------------------------------------
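
Usage sketch for `base_wrapper.py` (illustrative; not a file in the repository): the `_concat`/`_unconcat` pair defines how structured inputs (dict, list, array, tensor) are flattened into the single 1-D vector that `scipy.optimize` works on, and how the original shapes are restored afterwards. A minimal round trip through the concrete `TorchWrapper`:

```python
# Round-trip sketch of the flattening protocol in BaseWrapper._concat/_unconcat.
# The lambda objective is a dummy; only the input structure matters here.
import numpy as np

from src.solver.scipy_autograd.torch_wrapper import TorchWrapper

wrapper = TorchWrapper(lambda x, y: (x**2).sum() + (y**2).sum())

# A dict input is flattened into one 1-D vector; shapes are remembered.
x0 = {"x": np.zeros((2, 3)), "y": np.ones(4)}
flat = wrapper.get_input(x0)         # 1-D vector of length 6 + 4 = 10
restored = wrapper.get_output(flat)  # dict with the original shapes

assert flat.shape == (10,)
assert restored["x"].shape == (2, 3) and restored["y"].shape == (4,)
```
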
/src/solver/scipy_autograd/scipy_minimize.py:
--------------------------------------------------------------------------------
1 | import scipy.optimize as sopt
2 |
3 | from .torch_wrapper import TorchWrapper
4 |
5 |
6 | def minimize(
7 | fun,
8 | x0,
9 | args=(),
10 | precision="float32",
11 | method=None,
12 | hvp_type=None,
13 | torch_device="cpu",
14 | bounds=None,
15 | constraints=None,
16 | tol=None,
17 | callback=None,
18 | options=None,
19 | ):
20 | """
21 | wrapper around the [minimize](https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html)
22 | function of scipy which includes an automatic computation of gradients,
23 | hessian vector product or hessian with tensorflow or torch backends.
24 | :param fun: function to be minimized, its signature can be a tensor, a list of tensors or a dict of tensors.
25 |     :type fun: tensorflow or torch function
26 | :param x0: input to the function, it must match the signature of the function.
27 | :type x0: np.ndarray, list of arrays or dict of arrays.
28 | :param precision: one of 'float32' or 'float64', defaults to 'float32'
29 | :type precision: str, optional
30 | :param method: method used by the optimizer, it should be one of:
31 | 'Nelder-Mead',
32 | 'Powell',
33 | 'CG',
34 | 'BFGS',
35 | 'Newton-CG',
36 | 'L-BFGS-B',
37 | 'TNC',
38 | 'COBYLA',
39 | 'SLSQP',
40 | 'trust-constr',
41 |         'dogleg',  # requires a positive semi-definite Hessian
42 | 'trust-ncg',
43 | 'trust-exact', # requires hessian
44 | 'trust-krylov'
45 | , defaults to None
46 | :type method: str, optional
47 | :param hvp_type: type of computation scheme for the hessian vector product
48 | for the torch backend it is one of hvp and vhp (vhp is faster according to the [doc](https://pytorch.org/docs/stable/autograd.html))
49 | for the tf backend it is one of 'forward_over_back', 'back_over_forward', 'tf_gradients_forward_over_back' and 'back_over_back'
50 | Some infos about the most interesting scheme are given [here](https://www.tensorflow.org/api_docs/python/tf/autodiff/ForwardAccumulator)
51 | , defaults to None
52 | :type hvp_type: str, optional
53 | :param torch_device: device used by torch for the gradients computation,
54 | if the backend is not torch, this parameter is ignored, defaults to 'cpu'
55 | :type torch_device: str, optional
56 | :param bounds: Bounds on the input variables, only available for L-BFGS-B, TNC, SLSQP, Powell, and trust-constr methods.
57 | It can be:
58 | * a tuple (min, max), None indicates no bounds, in this case the same bound is applied to all variables.
59 | * An instance of the [Bounds](https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.Bounds.html#scipy.optimize.Bounds) class, in this case the same bound is applied to all variables.
60 | * A numpy array of bounds (if the optimized function has a single numpy array as input)
61 | * A list or dict of bounds with the same format as the optimized function signature.
62 | , defaults to None
63 | :type bounds: tuple, list, dict or np.ndarray, optional
64 | :param constraints: It has to be a dict with the following keys:
65 | * fun: a callable computing the constraint function
66 |         * lb and ub: the lower and upper bounds; if equal, the constraint is an equality, use np.inf if there is no upper bound. Only used if method is trust-constr.
67 | * type: 'eq' or 'ineq' only used if method is one of COBYLA, SLSQP.
68 | * keep_feasible: see [here](https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.NonlinearConstraint.html#scipy.optimize.NonlinearConstraint)
69 | , defaults to None
70 | :type constraints: dict, optional
71 | :param tol: Tolerance for termination, defaults to None
72 | :type tol: float, optional
73 | :param callback: Called after each iteration, defaults to None
74 | :type callback: callable, optional
75 | :param options: solver options, defaults to None
76 | :type options: dict, optional
77 | :return: dict of optimization results
78 | :rtype: dict
79 | """
80 |
81 | wrapper = TorchWrapper(fun, precision=precision, hvp_type=hvp_type, device=torch_device)
82 |
83 | if bounds is not None:
84 | assert method in [
85 | None,
86 | "L-BFGS-B",
87 | "TNC",
88 | "SLSQP",
89 | "Powell",
90 | "trust-constr",
91 | ], "bounds are only available for L-BFGS-B, TNC, SLSQP, Powell, trust-constr"
92 |
93 | if constraints is not None:
94 | assert method in [
95 | "COBYLA",
96 | "SLSQP",
97 | "trust-constr",
98 | ], "Constraints are only available for COBYLA, SLSQP and trust-constr"
99 |
100 | optim_res = sopt.minimize(
101 | wrapper.get_value_and_grad,
102 | wrapper.get_input(x0),
103 | args=args,
104 | method=method,
105 | jac=True,
106 | hessp=wrapper.get_hvp
107 | if method in ["Newton-CG", "trust-ncg", "trust-krylov", "trust-constr"]
108 | else None,
109 | hess=wrapper.get_hess if method in ["dogleg", "trust-exact"] else None,
110 | bounds=wrapper.get_bounds(bounds),
111 | constraints=wrapper.get_constraints(constraints, method),
112 | tol=tol,
113 | callback=callback,
114 | options=options,
115 | )
116 |
117 | optim_res.x = wrapper.get_output(optim_res.x)
118 |
119 | if "jac" in optim_res.keys() and len(optim_res.jac) > 0:
120 | try:
121 | optim_res.jac = wrapper.get_output(optim_res.jac[0])
122 |         except Exception:  # the jacobian may not be reshapeable for every method
123 | pass
124 |
125 | return optim_res
126 |
--------------------------------------------------------------------------------
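
Usage sketch for `scipy_minimize.py` (illustrative; the toy quadratic is assumed, not from the repository). It mirrors the call pattern of the patch-contrast solver above: the objective is written in torch, gradients come from autograd, and scipy sees a plain value-and-gradient function (`jac=True`):

```python
# Assumed example: minimize a torch-differentiable toy objective.
import numpy as np
import torch

from src.solver import scipy_autograd


def objective(x: torch.Tensor) -> torch.Tensor:
    # Convex bowl with its minimum at (1, -2).
    return (x[0] - 1.0) ** 2 + (x[1] + 2.0) ** 2


result = scipy_autograd.minimize(
    objective,
    np.zeros(2),          # initial guess, matching the function signature
    method="BFGS",
    precision="float64",
    torch_device="cpu",
    options={"gtol": 1e-7},
)
print(result.x)  # approximately [1.0, -2.0]
```
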
/src/solver/scipy_autograd/torch_wrapper.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Dict, List, Tuple, Union
2 |
3 | import numpy as np
4 | import torch
5 | from torch import Tensor, nn
6 | from torch.autograd.functional import hessian, hvp, vhp
7 |
8 | from .base_wrapper import BaseWrapper
9 |
10 |
11 | class TorchWrapper(BaseWrapper):
12 | def __init__(self, func, precision="float32", hvp_type="vhp", device="cpu"):
13 | self.func = func
14 |
15 | # Not very clean...
16 | if "device" in dir(func):
17 | self.device = func.device
18 | else:
19 | self.device = torch.device(device)
20 |
21 | if precision == "float32":
22 | self.precision = torch.float32
23 | elif precision == "float64":
24 | self.precision = torch.float64
25 | else:
26 | raise ValueError
27 |
28 | self.hvp_func = hvp if hvp_type == "hvp" else vhp
29 |
30 | def get_value_and_grad(self, input_var):
31 | assert "shapes" in dir(self), "You must first call get input to define the tensors shapes."
32 |
33 | input_var_ = self._unconcat(
34 | torch.tensor(input_var, dtype=self.precision, requires_grad=True, device=self.device),
35 | self.shapes,
36 | )
37 |
38 | loss = self._eval_func(input_var_)
39 | input_var_grad = input_var_.values() if isinstance(input_var_, dict) else input_var_
40 | grads = torch.autograd.grad(loss, input_var_grad)
41 | # print('=====', torch.linalg.norm(grads[0], ord=1))
42 |
43 | if isinstance(input_var_, dict):
44 | grads = {k: v for k, v in zip(input_var_.keys(), grads)}
45 |
46 | return [
47 | loss.cpu().detach().numpy().astype(np.float64),
48 | self._concat(grads)[0].cpu().detach().numpy().astype(np.float64),
49 | ]
50 |
51 | def get_hvp(self, input_var, vector):
52 | assert "shapes" in dir(self), "You must first call get input to define the tensors shapes."
53 |
54 | input_var_ = self._unconcat(
55 | torch.tensor(input_var, dtype=self.precision, device=self.device), self.shapes
56 | )
57 | vector_ = self._unconcat(
58 | torch.tensor(vector, dtype=self.precision, device=self.device), self.shapes
59 | )
60 |
61 | if isinstance(input_var_, dict):
62 | input_var_ = tuple(input_var_.values())
63 | if isinstance(vector_, dict):
64 | vector_ = tuple(vector_.values())
65 |
66 | if isinstance(input_var_, list):
67 | input_var_ = tuple(input_var_)
68 | if isinstance(vector_, list):
69 | vector_ = tuple(vector_)
70 |
71 | loss, vhp_res = self.hvp_func(self.func, input_var_, v=vector_)
72 |
73 | return self._concat(vhp_res)[0].cpu().detach().numpy().astype(np.float64)
74 |
75 | def get_hess(self, input_var):
76 | assert "shapes" in dir(self), "You must first call get input to define the tensors shapes."
77 | input_var_ = torch.tensor(input_var, dtype=self.precision, device=self.device)
78 |
79 | def func(inp):
80 | return self._eval_func(self._unconcat(inp, self.shapes))
81 |
82 | hess = hessian(func, input_var_, vectorize=False)
83 |
84 | return hess.cpu().detach().numpy().astype(np.float64)
85 |
86 | def get_ctr_jac(self, input_var):
87 | assert "shapes" in dir(self), "You must first call get input to define the tensors shapes."
88 |
89 | input_var_ = self._unconcat(
90 | torch.tensor(input_var, dtype=self.precision, requires_grad=True, device=self.device),
91 | self.shapes,
92 | )
93 |
94 | ctr_val = self._eval_ctr_func(input_var_)
95 | input_var_grad = input_var_.values() if isinstance(input_var_, dict) else input_var_
96 | grads = torch.autograd.grad(ctr_val, input_var_grad)
97 |
98 |         return self._concat(grads)[0].cpu().detach().numpy().astype(np.float64)  # grad() returns a tuple
99 |
100 | def _reshape(self, t, sh):
101 | if torch.is_tensor(t):
102 | return t.reshape(sh)
103 | elif isinstance(t, np.ndarray):
104 | return np.reshape(t, sh)
105 | else:
106 | raise NotImplementedError
107 |
108 | def _tconcat(self, t_list, dim=0):
109 | if torch.is_tensor(t_list[0]):
110 | return torch.cat(t_list, dim)
111 | elif isinstance(t_list[0], np.ndarray):
112 | return np.concatenate(t_list, dim)
113 | else:
114 | raise NotImplementedError
115 |
116 | def _gather(self, t, i, j):
117 | if isinstance(t, np.ndarray) or torch.is_tensor(t):
118 | return t[i:j]
119 | else:
120 | raise NotImplementedError
121 |
122 |
123 | def torch_function_factory(model, loss, train_x, train_y, precision="float32", optimized_vars=None):
124 | """
125 | A factory to create a function of the torch parameter model.
126 | :param model: torch model
127 |     :type model: torch.nn.Module
128 | :param loss: a function with signature loss_value = loss(pred_y, true_y).
129 | :type loss: function
130 | :param train_x: dataset used as input of the model
131 | :type train_x: np.ndarray
132 | :param train_y: dataset used as ground truth input of the loss
133 | :type train_y: np.ndarray
134 | :return: (function of the parameters, list of parameters, names of parameters)
135 | :rtype: tuple
136 | """
137 | # named_params = {k: var.cpu().detach().numpy() for k, var in model.named_parameters()}
138 | params, names = extract_weights(model)
139 | device = params[0].device
140 |
141 | prec_ = torch.float32 if precision == "float32" else torch.float64
142 | if isinstance(train_x, np.ndarray):
143 | train_x = torch.tensor(train_x, dtype=prec_, device=device)
144 | if isinstance(train_y, np.ndarray):
145 | train_y = torch.tensor(train_y, dtype=prec_, device=device)
146 |
147 | def func(*new_params):
148 | load_weights(model, {k: v for k, v in zip(names, new_params)})
149 | out = apply_func(model, train_x)
150 |
151 | return loss(out, train_y)
152 |
153 | func.device = device
154 |
155 | return func, [p.cpu().detach().numpy() for p in params], names
156 |
157 |
158 | def apply_func(func, input_):
159 | if isinstance(input_, dict):
160 | return func(**input_)
161 | elif isinstance(input_, list) or isinstance(input_, tuple):
162 | return func(*input_)
163 | else:
164 | return func(input_)
165 |
166 |
167 | # Adapted from https://github.com/pytorch/pytorch/blob/21c04b4438a766cd998fddb42247d4eb2e010f9a/benchmarks/functional_autograd_benchmark/functional_autograd_benchmark.py
168 |
169 | # Utilities to make nn.Module "functional"
170 | # In particular the goal is to be able to provide a function that takes as input
171 | # the parameters and evaluate the nn.Module using fixed inputs.
172 |
173 |
174 | def _del_nested_attr(obj: nn.Module, names: List[str]) -> None:
175 | """
176 | Deletes the attribute specified by the given list of names.
177 | For example, to delete the attribute obj.conv.weight,
178 | use _del_nested_attr(obj, ['conv', 'weight'])
179 | """
180 | if len(names) == 1:
181 | delattr(obj, names[0])
182 | else:
183 | _del_nested_attr(getattr(obj, names[0]), names[1:])
184 |
185 |
186 | def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:
187 | """
188 | Set the attribute specified by the given list of names to value.
189 | For example, to set the attribute obj.conv.weight,
190 | use _del_nested_attr(obj, ['conv', 'weight'], value)
191 | """
192 | if len(names) == 1:
193 | setattr(obj, names[0], value)
194 | else:
195 | _set_nested_attr(getattr(obj, names[0]), names[1:], value)
196 |
197 |
198 | def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
199 | """
200 | This function removes all the Parameters from the model and
201 | return them as a tuple as well as their original attribute names.
202 | The weights must be re-loaded with `load_weights` before the model
203 | can be used again.
204 | Note that this function modifies the model in place and after this
205 | call, mod.parameters() will be empty.
206 | """
207 |
208 | orig_params = [p for p in mod.parameters() if p.requires_grad]
209 | # Remove all the parameters in the model
210 | names = []
211 | for name, p in list(mod.named_parameters()):
212 | if p.requires_grad:
213 | _del_nested_attr(mod, name.split("."))
214 | names.append(name)
215 |
216 | # Make params regular Tensors instead of nn.Parameter
217 | params = tuple(p.detach().requires_grad_() for p in orig_params)
218 | return params, names
219 |
220 |
221 | def load_weights(mod: nn.Module, params: Dict[str, Tensor]) -> None:
222 | """
223 | Reload a set of weights so that `mod` can be used again to perform a forward pass.
224 | Note that the `params` are regular Tensors (that can have history) and so are left
225 | as Tensors. This means that mod.parameters() will still be empty after this call.
226 | """
227 | for name, p in params.items():
228 | _set_nested_attr(mod, name.split("."), p)
229 |
--------------------------------------------------------------------------------
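
Usage sketch for `torch_wrapper.py` (illustrative; the model and data are made up): `torch_function_factory` together with `extract_weights`/`load_weights` turns an `nn.Module` into a pure function of its parameters, which the wrapper can then hand to scipy:

```python
# Assumed example: fit a tiny linear model with scipy's L-BFGS-B through the
# functional interface produced by torch_function_factory.
import numpy as np
import torch

from src.solver.scipy_autograd import minimize
from src.solver.scipy_autograd.torch_wrapper import torch_function_factory

model = torch.nn.Linear(3, 1)
train_x = np.random.randn(32, 3)
train_y = train_x.sum(axis=1, keepdims=True)  # toy target: sum of inputs

func, params, names = torch_function_factory(
    model, torch.nn.functional.mse_loss, train_x, train_y, precision="float64"
)

# `params` is a list of numpy arrays matching the signature of `func`.
result = minimize(func, params, method="L-BFGS-B", precision="float64")
print({name: w.shape for name, w in zip(names, result.x)})
```
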
/src/solver/time_aware_patch_contrast.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Any, Dict, List, Optional, Tuple
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from .. import solver, types, utils, visualizer
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | class TimeAwarePatchContrastMaximization(solver.MixedPatchContrastMaximization):
13 | """Time-aware patch-based CMax.
14 |
15 | Params:
16 | image_shape (tuple) ... (H, W)
17 | calibration_parameter (dict) ... dictionary of the calibration parameter
18 | solver_config (dict) ... solver configuration
19 | optimizer_config (dict) ... optimizer configuration
20 | visualize_module ... visualizer.Visualizer
21 | """
22 |
23 | def __init__(
24 | self,
25 | image_shape: tuple,
26 | calibration_parameter: dict,
27 | solver_config: dict = {},
28 | optimizer_config: dict = {},
29 | output_config: dict = {},
30 | visualize_module: Optional[visualizer.Visualizer] = None,
31 | ):
32 | super().__init__(
33 | image_shape,
34 | calibration_parameter,
35 | solver_config,
36 | optimizer_config,
37 | output_config,
38 | visualize_module,
39 | )
40 | assert self.is_time_aware
41 |
42 | def motion_to_dense_flow(self, motion_array: types.NUMPY_TORCH) -> types.NUMPY_TORCH:
43 | """Returns dense flow at quantized time voxel.
44 | TODO eventually I should be able to remove this entire class!
45 |
46 | Args:
47 | motion_array (types.NUMPY_TORCH): [2 x h_patch x w_patch] Flow array.
48 |
49 | Returns:
50 | types.NUMPY_TORCH: [time_bin x 2 x H x W]
51 | """
52 | if self.scale_later:
53 | scale = motion_array.max()
54 | else:
55 | scale = 1.0
56 | if isinstance(motion_array, np.ndarray):
57 | dense_flow_t0 = self.interpolate_dense_flow_from_patch_numpy(motion_array)
58 | return (
59 | utils.construct_dense_flow_voxel_numpy(
60 | dense_flow_t0 / scale,
61 | self.time_bin,
62 | self.flow_interpolation,
63 | t0_location=self.t0_flow_location,
64 | )
65 | * scale
66 | )
67 | elif isinstance(motion_array, torch.Tensor):
68 | dense_flow_t0 = self.interpolate_dense_flow_from_patch_tensor(motion_array)
69 | return (
70 | utils.construct_dense_flow_voxel_torch(
71 | dense_flow_t0 / scale,
72 | self.time_bin,
73 | self.flow_interpolation,
74 | t0_location=self.t0_flow_location,
75 | )
76 | * scale
77 | )
78 | e = f"Unsupported type: {type(motion_array)}"
79 | raise TypeError(e)
80 |
--------------------------------------------------------------------------------
/src/types/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Tuple, Union
2 |
3 | import numpy as np
4 | import torch
5 |
6 | from .flow_patch import FlowPatch
7 |
8 | NUMPY_TORCH = Union[np.ndarray, torch.Tensor]
9 | FLOAT_TORCH = Union[float, torch.Tensor]
10 |
11 |
12 | def is_torch(arr: Any) -> bool:
13 | return isinstance(arr, torch.Tensor)
14 |
15 |
16 | def is_numpy(arr: Any) -> bool:
17 | return isinstance(arr, np.ndarray)
18 |
19 |
20 | def nt_max(array: NUMPY_TORCH, dim: int) -> NUMPY_TORCH:
21 | """max function compatible for numpy ndarray and torch tensor.
22 |
23 | Args:
24 | array (NUMPY_TORCH):
25 |
26 | Returns:
27 | NUMPY_TORCH: _description_
28 | """
29 | if is_numpy(array):
30 | return array.max(axis=dim) # type: ignore
31 | return torch.max(array, dim).values
32 |
33 |
34 | def nt_min(array: NUMPY_TORCH, dim: int) -> NUMPY_TORCH:
35 | """Min function compatible for numpy ndarray and torch tensor.
36 |
37 | Args:
38 | array (NUMPY_TORCH):
39 |
40 | Returns:
41 | NUMPY_TORCH: _description_
42 | """
43 | if is_numpy(array):
44 | return array.min(axis=dim) # type: ignore
45 | return torch.min(array, dim).values
46 |
--------------------------------------------------------------------------------
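
A quick sketch (illustrative, not from the repository) of the type-dispatching reductions, which let numpy and torch code paths share one call:

```python
# Assumed example for the numpy/torch-agnostic reductions.
import numpy as np
import torch

from src.types import nt_max, nt_min

arr = np.array([[1.0, 5.0], [3.0, 2.0]])
ten = torch.from_numpy(arr)

print(nt_max(arr, dim=1))  # [5. 3.]
print(nt_max(ten, dim=1))  # tensor([5., 3.], dtype=torch.float64)
print(nt_min(arr, dim=0))  # [1. 2.]
```
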
/src/types/flow_patch.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from dataclasses import dataclass
3 | from typing import Any, List, Optional
4 |
5 | import numpy as np
6 |
7 |
8 | @dataclass
9 | class FlowPatch:
10 | """Dataclass for flow patch"""
11 |
12 |     # center coordinates (pixels)
13 | x: np.int16 # height
14 | y: np.int16 # width
15 | shape: tuple # (height, width)
16 | # flow (pixel displacement) value of the flow at the location
17 | u: float = 0.0 # height
18 | v: float = 0.0 # width
19 |
20 | @property
21 | def h(self) -> int:
22 | return self.shape[0]
23 |
24 | @property
25 | def w(self) -> int:
26 | return self.shape[1]
27 |
28 | @property
29 | def x_min(self) -> int:
30 | return int(self.x - np.ceil(self.h / 2))
31 |
32 | @property
33 | def x_max(self) -> int:
34 | return int(self.x + np.floor(self.h / 2))
35 |
36 | @property
37 | def y_min(self) -> int:
38 | return int(self.y - np.ceil(self.w / 2))
39 |
40 | @property
41 | def y_max(self) -> int:
42 | return int(self.y + np.floor(self.w / 2))
43 |
44 | @property
45 | def position(self) -> np.ndarray:
46 | return np.array([self.x, self.y])
47 |
48 | @property
49 | def flow(self) -> np.ndarray:
50 | return np.array([self.u, self.v])
51 |
52 | def update_flow(self, u: float, v: float):
53 | self.u = u
54 | self.v = v
55 |
56 | def new_ones(self):
57 | return np.ones(self.shape)
58 |
59 | def copy(self) -> Any:
60 | return copy.deepcopy(self)
61 |
--------------------------------------------------------------------------------
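
Illustrative sketch (not from the repository) of the `FlowPatch` bookkeeping: the center and shape determine a half-open pixel extent `[x_min, x_max)`:

```python
# Assumed example: a 32 x 32 patch centered at (120, 160).
import numpy as np

from src.types import FlowPatch

patch = FlowPatch(x=np.int16(120), y=np.int16(160), shape=(32, 32), u=1.5, v=-0.5)
print(patch.x_min, patch.x_max)  # 104 136 -> a 32-pixel extent
print(patch.position)            # [120 160]
patch.update_flow(2.0, 0.0)
print(patch.flow)                # [2. 0.]
```
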
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .event_utils import crop_event, generate_events, set_event_origin_to_zero, undistort_events
2 | from .flow_utils import (
3 | calculate_flow_error_numpy,
4 | calculate_flow_error_tensor,
5 | construct_dense_flow_voxel_numpy,
6 | construct_dense_flow_voxel_torch,
7 | estimate_corresponding_gt_flow,
8 | generate_dense_optical_flow,
9 | inviscid_burger_flow_to_voxel_numpy,
10 | inviscid_burger_flow_to_voxel_torch,
11 | upwind_flow_to_voxel_numpy,
12 | upwind_flow_to_voxel_torch,
13 | )
14 | from .misc import (
15 | SingleThreadInMemoryStorage,
16 | check_file_utils,
17 | check_key_and_bool,
18 | fetch_runtime_information,
19 | fix_random_seed,
20 | profile,
21 | )
22 | from .stat_utils import SobelTorch
23 |
--------------------------------------------------------------------------------
/src/utils/event_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Optional, Tuple
3 |
4 | import numpy as np
5 |
6 | from ..types import FLOAT_TORCH, NUMPY_TORCH, is_numpy, is_torch
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 | try:
11 | import torch
12 | except ImportError:
13 | e = "Torch is disabled."
14 | logger.warning(e)
15 |
16 |
17 | # Simulator module
18 | def generate_events(
19 | n_events: int,
20 | height: int,
21 | width: int,
22 | tmin: float = 0.0,
23 | tmax: float = 0.5,
24 | dist: str = "uniform",
25 | ) -> np.ndarray:
26 | """Generate random events.
27 |
28 | Args:
29 | n_events (int) ... num of events
30 | height (int) ... height of the camera
31 | width (int) ... width of the camera
32 | tmin (float) ... timestamp min
33 | tmax (float) ... timestamp max
34 | dist (str) ... currently only "uniform" is supported.
35 |
36 | Returns:
37 | events (np.ndarray) ... [n_events x 4] numpy array. (x, y, t, p)
38 | x indicates height direction.
39 | """
40 | x = np.random.randint(0, height, n_events)
41 | y = np.random.randint(0, width, n_events)
42 | t = np.random.uniform(tmin, tmax, n_events)
43 | t = np.sort(t)
44 | p = np.random.randint(0, 2, n_events)
45 |
46 | events = np.concatenate([x[..., None], y[..., None], t[..., None], p[..., None]], axis=1)
47 | return events
48 |
49 |
50 | def crop_event(events: NUMPY_TORCH, x0: int, x1: int, y0: int, y1: int) -> NUMPY_TORCH:
51 | """Crop events.
52 |
53 | Args:
54 | events (NUMPY_TORCH): [n x 4]. [x, y, t, p].
55 |         x0 (int): Crop start along events[:, 0] (x), inclusive.
56 |         x1 (int): Crop end along events[:, 0] (x), exclusive.
57 |         y0 (int): Crop start along events[:, 1] (y), inclusive.
58 |         y1 (int): Crop end along events[:, 1] (y), exclusive.
59 |
60 | Returns:
61 | NUMPY_TORCH: Cropped events.
62 | """
63 | mask = (
64 | (x0 <= events[..., 0])
65 | * (events[..., 0] < x1)
66 | * (y0 <= events[..., 1])
67 | * (events[..., 1] < y1)
68 | )
69 | cropped = events[mask]
70 | return cropped
71 |
72 |
73 | def set_event_origin_to_zero(events: NUMPY_TORCH, x0: int, y0: int, t0: float = 0.0) -> NUMPY_TORCH:
74 |     """Shift the origin of each column (x, y, t) to 0.
75 |
76 | Args:
77 |         events (NUMPY_TORCH): [n x 4]. [x, y, t, p].
78 | x0 (int): x origin
79 | y0 (int): y origin
80 | t0 (float): t origin
81 |
82 | Returns:
83 |         NUMPY_TORCH: [n x 4]. x is in [0, xmax - x0], and so on.
84 | """
85 | basis = np.array([x0, y0, t0, 0.0])
86 | if is_torch(events):
87 | basis = torch.from_numpy(basis)
88 | return events - basis
89 |
90 |
91 | def undistort_events(events, map_x, map_y, h, w):
92 | """Undistort (rectify) events.
93 | Args:
94 | events ... [x, y, t, p]. X is height direction.
95 | map_x, map_y... meshgrid
96 |
97 | Returns:
98 | events... events that is in the camera plane after undistortion.
99 | TODO check overflow
100 | """
101 | # k = np.int32(map_y[np.int16(events[:, 1]), np.int16(events[:, 0])])
102 | # l = np.int32(map_x[np.int16(events[:, 1]), np.int16(events[:, 0])])
103 | # k = np.int32(map_y[events[:, 1].astype(np.int32), events[:, 0].astype(np.int32)])
104 | # l = np.int32(map_x[events[:, 1].astype(np.int32), events[:, 0].astype(np.int32)])
105 | # undistort_events = np.copy(events)
106 | # undistort_events[:, 0] = l
107 | # undistort_events[:, 1] = k
108 | # return undistort_events[((0 <= k) & (k < h)) & ((0 <= l) & (l < w))]
109 |
110 | k = np.int32(map_y[events[:, 0].astype(np.int32), events[:, 1].astype(np.int32)])
111 | l = np.int32(map_x[events[:, 0].astype(np.int32), events[:, 1].astype(np.int32)])
112 | undistort_events = np.copy(events)
113 | undistort_events[:, 0] = k
114 | undistort_events[:, 1] = l
115 | return undistort_events[((0 <= k) & (k < h)) & ((0 <= l) & (l < w))]
116 |
--------------------------------------------------------------------------------
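
Illustrative sketch (not from the repository) combining the simulator and the crop utility; `crop_event` accepts numpy arrays and torch tensors alike:

```python
# Assumed example: generate synthetic events and crop to the top-left quadrant.
import torch

from src.utils import crop_event, generate_events

events = generate_events(n_events=1000, height=260, width=346)  # [N x 4] (x, y, t, p)
sub = crop_event(events, x0=0, x1=130, y0=0, y1=173)
print(sub.shape)

sub_torch = crop_event(torch.from_numpy(events), 0, 130, 0, 173)
print(sub_torch.shape)
```
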
/src/utils/misc.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import cProfile
3 | import logging
4 | import os
5 | import pstats
6 | import random
7 | import subprocess
8 | from functools import wraps
9 | from typing import Dict
10 |
11 | import numpy as np
12 | import optuna
13 | import torch
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 |
18 | def fix_random_seed(seed=46) -> None:
19 | """Fix random seed"""
20 | logger.info("Fix random Seed: ", seed)
21 | np.random.seed(seed)
22 | random.seed(seed)
23 | torch.manual_seed(seed)
24 | torch.cuda.manual_seed_all(seed)
25 |
26 |
27 | def check_file_utils(filename: str) -> bool:
28 | """Return True if the file exists.
29 |
30 | Args:
31 |         filename (str): Path of the file to check.
32 |
33 |     Returns:
34 |         bool: True if the file exists.
35 | """
36 | logger.debug(f"Check {filename}")
37 | res = os.path.exists(filename)
38 | if not res:
39 | logger.warning(f"{filename} does not exist!")
40 | return res
41 |
42 |
43 | def check_key_and_bool(config: dict, key: str) -> bool:
44 | """Check the existance of the key and if it's True
45 |
46 | Args:
47 | config (dict): dict.
48 | key (str): Key name to be checked.
49 |
50 | Returns:
51 | bool: Return True only if the key exists in the dict and its value is True.
52 | Otherwise returns False.
53 | """
54 | return key in config.keys() and config[key]
55 |
56 |
57 | def fetch_runtime_information() -> dict:
58 | """Fetch information of the experiment at runtime.
59 |
60 | Returns:
61 |         dict: Runtime information (commit hash and server name).
62 | """
63 | config = {}
64 | config["commit"] = fetch_commit_id()
65 | config["server"] = get_server_name()
66 | return config
67 |
68 |
69 | def fetch_commit_id() -> str:
70 | """Get the latest commit ID of the repository.
71 |
72 | Returns:
73 | str: commit hash
74 | """
75 | label = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip()
76 | return label.decode("utf-8")
77 |
78 |
79 | def get_server_name() -> str:
80 | """Always returns `unknown` for the public code :)
81 |
82 | Returns:
83 |         str: Server name.
84 | """
85 | return "unknown"
86 |
87 |
88 | def profile(output_file=None, sort_by="cumulative", lines_to_print=None, strip_dirs=False):
89 | """A time profiler decorator.
90 | Inspired by: http://code.activestate.com/recipes/577817-profile-decorator/
91 |
92 | Usage:
93 | ```
94 | @profile(output_file= ...)
95 | def your_function():
96 | ...
97 | ```
98 | Then you will get the profile automatically after the function call is finished.
99 |
100 | Args:
101 | output_file: str or None. Default is None
102 | Path of the output file. If only name of the file is given, it's
103 | saved in the current directory.
104 | If it's None, the name of the decorated function is used.
105 | sort_by: str or SortKey enum or tuple/list of str/SortKey enum
106 | Sorting criteria for the Stats object.
107 | For a list of valid string and SortKey refer to:
108 | https://docs.python.org/3/library/profile.html#pstats.Stats.sort_stats
109 | lines_to_print: int or None
110 | Number of lines to print. Default (None) is for all the lines.
111 | This is useful in reducing the size of the printout, especially
112 |             since, when sorting by 'cumulative', the time-consuming operations
113 | are printed toward the top of the file.
114 | strip_dirs: bool
115 | Whether to remove the leading path info from file names.
116 |             This is also useful in reducing the size of the printout.
117 | Returns:
118 | Profile of the decorated function
119 | """
120 |
121 | def inner(func):
122 | @wraps(func)
123 | def wrapper(*args, **kwargs):
124 | _output_file = output_file or func.__name__ + ".prof"
125 | pr = cProfile.Profile()
126 | pr.enable()
127 | retval = func(*args, **kwargs)
128 | pr.disable()
129 | pr.dump_stats(_output_file)
130 |
131 | with open(_output_file, "w") as f:
132 | ps = pstats.Stats(pr, stream=f)
133 | if strip_dirs:
134 | ps.strip_dirs()
135 | if isinstance(sort_by, (tuple, list)):
136 | ps.sort_stats(*sort_by)
137 | else:
138 | ps.sort_stats(sort_by)
139 | ps.print_stats(lines_to_print)
140 | return retval
141 |
142 | return wrapper
143 |
144 | return inner
145 |
146 |
147 | class SingleThreadInMemoryStorage(optuna.storages.InMemoryStorage):
148 | """This is faster version of in-memory storage only when the study n_jobs = 1 (single thread).
149 | Adopted from https://github.com/optuna/optuna/issues/3151
150 |
153 | """
154 |
155 | def set_trial_param(
156 | self,
157 | trial_id: int,
158 | param_name: str,
159 | param_value_internal: float,
160 | distribution: optuna.distributions.BaseDistribution,
161 | ) -> None:
162 | with self._lock:
163 | trial = self._get_trial(trial_id)
164 | self.check_trial_is_updatable(trial_id, trial.state)
165 |
166 | study_id = self._trial_id_to_study_id_and_number[trial_id][0]
167 | # Check param distribution compatibility with previous trial(s).
168 | if param_name in self._studies[study_id].param_distribution:
169 | optuna.distributions.check_distribution_compatibility(
170 | self._studies[study_id].param_distribution[param_name], distribution
171 | )
172 | # Set param distribution.
173 | self._studies[study_id].param_distribution[param_name] = distribution
174 |
175 | # Set param.
176 | trial.params[param_name] = distribution.to_external_repr(param_value_internal)
177 | trial.distributions[param_name] = distribution
178 |
--------------------------------------------------------------------------------
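
A hypothetical sketch of using `SingleThreadInMemoryStorage` (the toy objective is made up): pass an instance to `optuna.create_study` and keep the study single-threaded:

```python
# Assumed example: single-threaded Optuna study with the faster storage.
import optuna

from src.utils import SingleThreadInMemoryStorage


def objective(trial: optuna.Trial) -> float:
    x = trial.suggest_float("x", -10.0, 10.0)
    return (x - 2.0) ** 2


study = optuna.create_study(storage=SingleThreadInMemoryStorage(), direction="minimize")
study.optimize(objective, n_trials=50, n_jobs=1)
print(study.best_params)
```
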
/src/utils/stat_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 |
4 | import numpy as np
5 | import scipy
6 | import scipy.fftpack
7 | import torch
8 | from torch import nn
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | class SobelTorch(nn.Module):
14 | """Sobel operator for pytorch, for divergence calculation.
15 | This is equivalent implementation of
16 | ```
17 | sobelx = cv2.Sobel(flow[0], cv2.CV_64F, 1, 0, ksize=3)
18 | sobely = cv2.Sobel(flow[1], cv2.CV_64F, 0, 1, ksize=3)
19 | dxy = (sobelx + sobely) / 8.0
20 | ```
21 | Args:
22 | ksize (int) ... Kernel size of the convolution operation.
23 |         in_channels (int) ... Number of input channels.
24 | cuda_available (bool) ... True if cuda is available.
25 | """
26 |
27 | def __init__(
28 | self, ksize: int = 3, in_channels: int = 2, cuda_available: bool = False, precision="32"
29 | ):
30 | super().__init__()
31 | self.cuda_available = cuda_available
32 | self.in_channels = in_channels
33 | self.filter_dx = nn.Conv2d(
34 | in_channels=in_channels,
35 | out_channels=1,
36 | kernel_size=ksize,
37 | stride=1,
38 | padding=1,
39 | bias=False,
40 | )
41 | self.filter_dy = nn.Conv2d(
42 | in_channels=in_channels,
43 | out_channels=1,
44 | kernel_size=ksize,
45 | stride=1,
46 | padding=1,
47 | bias=False,
48 | )
49 | # x in height direction
50 | if precision == "64":
51 | Gx = torch.tensor([[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]]).double()
52 | Gy = torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]).double()
53 | else:
54 | Gx = torch.tensor([[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]])
55 | Gy = torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]])
56 |
57 | if self.cuda_available:
58 | Gx = Gx.cuda()
59 | Gy = Gy.cuda()
60 |
61 | self.filter_dx.weight = nn.Parameter(Gx.unsqueeze(0).unsqueeze(0), requires_grad=False)
62 | self.filter_dy.weight = nn.Parameter(Gy.unsqueeze(0).unsqueeze(0), requires_grad=False)
63 |
64 | def forward(self, img):
65 | """
66 | Args:
67 | img (torch.Tensor) ... [b x (2 or 1) x H x W]. The 2 ch is [h, w] direction.
68 |
69 | Returns:
70 | sobel (torch.Tensor) ... [b x (4 or 2) x (H - 2) x (W - 2)].
71 | 4ch means Sobel_x on xdim, Sobel_y on ydim, Sobel_x on ydim, and Sobel_y on xdim.
72 | To make it divergence, run `(sobel[:, 0] + sobel[:, 1]) / 8.0`.
73 | """
74 | if self.in_channels == 2:
75 | dxx = self.filter_dx(img[..., [0], :, :])
76 | dyy = self.filter_dy(img[..., [1], :, :])
77 | dyx = self.filter_dx(img[..., [1], :, :])
78 | dxy = self.filter_dy(img[..., [0], :, :])
79 | return torch.cat([dxx, dyy, dyx, dxy], dim=1)
80 | elif self.in_channels == 1:
81 | dx = self.filter_dx(img[..., [0], :, :])
82 | dy = self.filter_dy(img[..., [0], :, :])
83 | return torch.cat([dx, dy], dim=1)
84 |
--------------------------------------------------------------------------------
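
Illustrative sketch (not from the repository) of the divergence recipe given in the `SobelTorch` docstring:

```python
# Assumed example: flow divergence via SobelTorch, per the class docstring.
import torch

from src.utils import SobelTorch

sobel = SobelTorch(ksize=3, in_channels=2, cuda_available=False)
flow = torch.randn(1, 2, 64, 64)  # [b x 2 x H x W]; channels are (h, w) flow
out = sobel(flow)                 # 4 channels: dxx, dyy, dyx, dxy
divergence = (out[:, 0] + out[:, 1]) / 8.0  # matches the cv2.Sobel recipe
print(divergence.shape)
```
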
/src/visualizer.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from typing import Any, Dict, List, Optional
4 |
5 | import cv2
6 | import numpy as np
7 | import plotly.graph_objects as go
8 | from matplotlib import pyplot as plt
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | from PIL import Image, ImageDraw
14 |
15 | from . import event_image_converter, types, warp
16 |
17 | TRANSPARENCY = 0.25 # Degree of transparency, 0-100%
18 | OPACITY = int(255 * TRANSPARENCY)
19 |
20 |
21 | class Visualizer:
22 | """Visualization class for multi utility. It includes visualization of
23 | - Events (polarity-based or event-based, 2D or 3D, etc...)
24 | - Images
25 | - Optical flow
26 | - Optimization history, loss function
27 | - Matplotlib figure
28 | etc.
29 |     It also generically handles whether figures are saved and/or shown.
30 |
31 | Args:
32 | image_shape (tuple) ... [H, W]. Image shape is necessary to visualize events.
33 |         show (bool) ... If True, it shows the visualization results when any function is called.
34 |         save (bool) ... If True, it saves the results under `save_dir` without any duplication.
35 | save_dir (str) ... Applicable when `save` is True. The root directory for the save.
36 |
37 | """
38 |
39 | def __init__(self, image_shape: tuple, show=False, save=False, save_dir=None) -> None:
40 | super().__init__()
41 | self.update_image_shape(image_shape)
42 | self._show = show
43 | self._save = save
44 | if save_dir is None:
45 | save_dir = "./"
46 | self.update_save_dir(save_dir)
47 | self.default_prefix = "" # default file prefix
48 | self.default_save_count = 0 # default save count
49 | self.prefixed_save_count: Dict[str, int] = {}
50 |
51 | def update_image_shape(self, image_shape):
52 | self._image_size = image_shape # H, W
53 | self._image_height = image_shape[0]
54 | self._image_width = image_shape[1]
55 | self.imager = event_image_converter.EventImageConverter(image_shape)
56 |
57 | def update_save_dir(self, new_dir: str) -> None:
58 | """Update save directiry. Creates it if not exist.
59 |
60 | Args:
61 | new_dir (str): New directory
62 | """
63 | self.save_dir = new_dir
64 | if not os.path.exists(self.save_dir):
65 | os.makedirs(self.save_dir)
66 |
67 | def get_filename_from_prefix(
68 | self, prefix: Optional[str] = None, file_format: str = "png"
69 | ) -> str:
70 | """Helper function: returns expected filename from the prefix.
71 | It makes sure to save the output filename without any duplication.
72 |
73 | Args:
74 | prefix (Optional[str], optional): Prefix. Defaults to None.
75 |             file_format (str) ... File format. Defaults to png.
76 |
77 | Returns:
78 | str: ${save_dir}/{prefix}{count}.png. Count automatically goes up.
79 | """
80 | if prefix is None or prefix == "":
81 | file_name = os.path.join(
82 | self.save_dir, f"{self.default_prefix}{self.default_save_count}.{file_format}"
83 | )
84 | self.default_save_count += 1
85 | else:
86 | try:
87 | self.prefixed_save_count[prefix] += 1
88 | except KeyError:
89 | self.prefixed_save_count[prefix] = 0
90 | file_name = os.path.join(
91 | self.save_dir, f"{prefix}{self.prefixed_save_count[prefix]}.{file_format}"
92 | )
93 | return file_name
94 |
95 | def rollback_save_count(self, prefix: Optional[str] = None):
96 | """Helper function:
97 |         # hack - needs a consistent number between .png and .npy
98 |
99 | Args:
100 | prefix (Optional[str], optional): Prefix. Defaults to None.
101 | """
102 | if prefix is None or prefix == "":
103 | self.default_save_count -= 1
104 | else:
105 | try:
106 | self.prefixed_save_count[prefix] -= 1
107 | except KeyError:
108 | raise ValueError("The visualization save count error")
109 |
110 | def reset_save_count(self, file_prefix: Optional[str] = None):
111 | if file_prefix is None or file_prefix == "":
112 | self.default_save_count = 0
113 | elif file_prefix == "all":
114 | self.default_save_count = 0
115 | self.prefixed_save_count = {}
116 | else:
117 | del self.prefixed_save_count[file_prefix]
118 |
119 | def _show_or_save_image(
120 | self, image: Any, file_prefix: Optional[str] = None, fixed_file_name: Optional[str] = None
121 | ):
122 | """Helper function - save and/or show the image.
123 |
124 | Args:
125 | image (Any): PIL.Image
126 | file_prefix (Optional[str], optional): [description]. Defaults to None.
127 | If specified, the save location will be `save_dir/{prefix}_{unique}.png`.
128 | """
129 | if self._show:
130 | if image.mode == "RGBA":
131 | image = image.convert("RGB") # Back to RGB
132 | image.show()
133 | if self._save:
134 | if image.mode == "RGBA":
135 | image = image.convert("RGB") # Back to RGB
136 | if fixed_file_name is not None:
137 | image.save(os.path.join(self.save_dir, f"{fixed_file_name}.png"))
138 | else:
139 | image.save(self.get_filename_from_prefix(file_prefix))
140 |
141 | # Image related
142 | def load_image(self, image: Any) -> Image.Image:
143 | """A wrapper function to get image and returns PIL Image object.
144 |
145 | Args:
146 | image (str or np.ndarray): If it is str, open and load the image.
147 | If it is numpy array, it converts to PIL.Image.
148 |
149 | Returns:
150 |             Image.Image: PIL Image object.
151 | """
152 | if type(image) == str:
153 | image = Image.open(image)
154 | elif type(image) == np.ndarray:
155 | image = Image.fromarray(image)
156 | return image
157 |
158 | def visualize_image(self, image: Any, file_prefix: Optional[str] = None) -> Image.Image:
159 | """Visualize image.
160 |
161 | Args:
162 | image (Any): str, np.ndarray, or PIL Image.
163 | file_prefix (Optional[str], optional): [description]. Defaults to None.
164 | If specified, the save location will be `save_dir/{prefix}_{unique}.png`.
165 |
166 | Returns:
167 | Image.Image: PIL Image object
168 | """
169 | image = self.load_image(image)
170 | self._show_or_save_image(image, file_prefix)
171 | return image
172 |
173 | def create_clipped_iwe_for_visualization(self, events, max_scale=50):
174 | """Utility function for clipped IWE. Same one in solver.
175 |
176 | Args:
177 | events (_type_): _description_
178 | max_scale (int, optional): _description_. Defaults to 50.
179 |
180 | Returns:
181 | _type_: _description_
182 | """
183 | im = self.imager.create_image_from_events_numpy(events, method="bilinear_vote", sigma=0)
184 | clipped_iwe = 255 - np.clip(max_scale * im, 0, 255).astype(np.uint8)
185 | return clipped_iwe
186 |
187 | # Optical flow
188 | def visualize_optical_flow(
189 | self,
190 | flow_x: np.ndarray,
191 | flow_y: np.ndarray,
192 | visualize_color_wheel: bool = True,
193 | file_prefix: Optional[str] = None,
194 | save_flow: bool = False,
195 | ord: float = 0.5,
196 | ):
197 | """Visualize optical flow.
198 | Args:
199 | flow_x (numpy.ndarray) ... [H x W], height direction.
200 | flow_y (numpy.ndarray) ... [H x W], width direction.
201 | visualize_color_wheel (bool) ... If True, it also visualizes the color wheel (legend for OF).
202 | file_prefix (Optional[str], optional): [description]. Defaults to None.
203 | If specified, the save location will be `save_dir/{prefix}_{unique}.png`.
204 |
205 | Returns:
206 | image (PIL.Image) ... PIL image.
207 | """
208 | if save_flow:
209 | save_name = self.get_filename_from_prefix(file_prefix).replace("png", "npy")
210 | np.save(save_name, np.stack([flow_x, flow_y], axis=0))
211 | self.rollback_save_count(file_prefix)
212 | flow_rgb, color_wheel, _ = self.color_optical_flow(flow_x, flow_y, ord=ord)
213 | image = Image.fromarray(flow_rgb)
214 | self._show_or_save_image(image, file_prefix)
215 |
216 | if visualize_color_wheel:
217 | wheel = Image.fromarray(color_wheel)
218 | self._show_or_save_image(wheel, fixed_file_name="color_wheel")
219 | return image
220 |
221 | # Combined with events
222 | def visualize_overlay_optical_flow_on_event(
223 | self,
224 | flow: np.ndarray,
225 | events: np.ndarray,
226 | file_prefix: Optional[str] = None,
227 | ord: float = 0.5,
228 | ):
229 | """Visualize optical flow on event data.
230 | Args:
231 | flow (numpy.ndarray) ... [2 x H x W]
232 | events (np.ndarray) ... event_image (H x W) or raw events (n_events x 4).
233 | file_prefix (Optional[str], optional): [description]. Defaults to None.
234 | If specified, the save location will be `save_dir/{prefix}_{unique}.png`.
235 |
236 | Returns:
237 | image (PIL.Image) ... PIL image.
238 | """
239 | _show, _save = self._show, self._save
240 | self._show, self._save = False, False
241 | flow_image = self.visualize_optical_flow(flow[0], flow[1], ord=ord)
242 | flow_ratio = 0.8
243 | flow_image.putalpha(int(255 * flow_ratio))
244 | if events.shape[1] == 4: # raw events
245 | event_image = self.visualize_event(events, grayscale=False).convert("RGB")
246 | else:
247 | event_image = self.visualize_image(events).convert("RGB")
248 | event_image.putalpha(255 - int(255 * flow_ratio))
249 | flow_image.paste(event_image, None, event_image)
250 | self._show, self._save = _show, _save
251 | self._show_or_save_image(flow_image, file_prefix)
252 | return flow_image
253 |
254 | def visualize_optical_flow_on_event_mask(
255 | self,
256 | flow: np.ndarray,
257 | events: np.ndarray,
258 | file_prefix: Optional[str] = None,
259 | ord: float = 0.5,
260 | max_color_on_mask: bool = True,
261 | ):
262 | """Visualize optical flow only where event exists.
263 | Args:
264 | flow (numpy.ndarray) ... [2 x H x W]
265 | events (np.ndarray) ... [n_events x 4]
266 | file_prefix (Optional[str], optional): [description]. Defaults to None.
267 | If specified, the save location will be `save_dir/{prefix}_{unique}.png`.
268 |
269 | max_color_on_mask (bool) ... If True, the max magnitude is based on the masked flow. If False, it is based on the raw (dense) flow.
270 |
271 | Returns:
272 | image (PIL.Image) ... PIL image.
273 | """
274 | _show, _save = self._show, self._save
275 | self._show, self._save = False, False
276 | mask = self.imager.create_eventmask(events)
277 | if max_color_on_mask:
278 | masked_flow = flow * mask
279 | image = self.visualize_optical_flow(
280 | masked_flow[0],
281 | masked_flow[1],
282 | visualize_color_wheel=False,
283 | file_prefix=file_prefix,
284 | ord=ord,
285 | )
286 | else:
287 | image = self.visualize_optical_flow(
288 | flow[0], flow[1], visualize_color_wheel=False, file_prefix=file_prefix, ord=ord
289 | )
290 | mask = Image.fromarray((~mask)[0]).convert("1")
291 | white = Image.new("RGB", image.size, (255, 255, 255))
292 | masked_flow = Image.composite(white, image, mask)
293 | self._show, self._save = _show, _save
294 | self._show_or_save_image(masked_flow, file_prefix)
295 | return masked_flow
296 |
297 | def visualize_optical_flow_pred_and_gt(
298 | self,
299 | flow_pred: np.ndarray,
300 | flow_gt: np.ndarray,
301 | visualize_color_wheel: bool = True,
302 | pred_file_prefix: Optional[str] = None,
303 | gt_file_prefix: Optional[str] = None,
304 | ord: float = 0.5,
305 | ):
306 | """Visualize optical flow both pred and GT.
307 | Args:
308 | flow_pred (numpy.ndarray) ... [2, H x W]
309 | flow_gt (numpy.ndarray) ... [2, H x W]
310 | visualize_color_wheel (bool) ... If True, it also visualizes the color wheel (legend for OF).
311 | file_prefix (Optional[str], optional): [description]. Defaults to None.
312 | If specified, the save location will be `save_dir/{prefix}_{unique}.png`.
313 |
314 | Returns:
315 | image (PIL.Image) ... PIL image.
316 | """
317 | # get largest magnitude in both pred and gt
318 | _, _, max_pred = self.color_optical_flow(flow_pred[0], flow_pred[1], ord=ord)
319 | _, _, max_gt = self.color_optical_flow(flow_gt[0], flow_gt[1], ord=ord)
320 | max_magnitude = np.max([max_pred, max_gt])
321 | color_pred, _, _ = self.color_optical_flow(
322 | flow_pred[0], flow_pred[1], max_magnitude, ord=ord
323 | )
324 | color_gt, color_wheel, _ = self.color_optical_flow(
325 | flow_gt[0], flow_gt[1], max_magnitude, ord=ord
326 | )
327 |
328 | image = Image.fromarray(color_pred)
329 | self._show_or_save_image(image, pred_file_prefix)
330 | image = Image.fromarray(color_gt)
331 | self._show_or_save_image(image, gt_file_prefix)
332 | if visualize_color_wheel:
333 | wheel = Image.fromarray(color_wheel)
334 | self._show_or_save_image(wheel, fixed_file_name="color_wheel")
335 |
336 | def color_optical_flow(
337 | self, flow_x: np.ndarray, flow_y: np.ndarray, max_magnitude=None, ord=1.0
338 | ):
339 | """Color optical flow.
340 | Args:
341 | flow_x (numpy.ndarray) ... [H x W], height direction.
342 | flow_y (numpy.ndarray) ... [H x W], width direction.
343 | max_magnitude (float, optional) ... Max magnitude used for the colorization. Defaults to None.
344 |             ord (float) ... 1: our usual, 0.5: DSEC colorizing.
345 |
346 | Returns:
347 |             flow_rgb (np.ndarray) ... [H x W x 3] RGB image.
348 |             color_wheel (np.ndarray) ... [H x H x 3] color wheel (legend).
349 | max_magnitude (float) ... max magnitude of the flow.
350 | """
351 | flows = np.stack((flow_x, flow_y), axis=2)
352 | flows[np.isinf(flows)] = 0
353 | flows[np.isnan(flows)] = 0
354 | mag = np.linalg.norm(flows, axis=2) ** ord
355 | ang = (np.arctan2(flow_y, flow_x) + np.pi) * 180.0 / np.pi / 2.0
356 | ang = ang.astype(np.uint8)
357 | hsv = np.zeros([flow_x.shape[0], flow_x.shape[1], 3], dtype=np.uint8)
358 | hsv[:, :, 0] = ang
359 | hsv[:, :, 1] = 255
360 | if max_magnitude is None:
361 | max_magnitude = mag.max()
362 | hsv[:, :, 2] = (255 * mag / max_magnitude).astype(np.uint8)
363 | # hsv[:, :, 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
364 | flow_rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
365 |
366 | # Color wheel
367 | hsv = np.zeros([flow_x.shape[0], flow_x.shape[0], 3], dtype=np.uint8)
368 | xx, yy = np.meshgrid(
369 | np.linspace(-1, 1, flow_x.shape[0]), np.linspace(-1, 1, flow_x.shape[0])
370 | )
371 | mag = np.linalg.norm(np.stack((xx, yy), axis=2), axis=2)
372 | # ang = (np.arctan2(yy, xx) + np.pi) * 180 / np.pi / 2.0
373 | ang = (np.arctan2(xx, yy) + np.pi) * 180 / np.pi / 2.0
374 | hsv[:, :, 0] = ang.astype(np.uint8)
375 | hsv[:, :, 1] = 255
376 | hsv[:, :, 2] = (255 * mag / mag.max()).astype(np.uint8)
377 | # hsv[:, :, 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
378 | color_wheel = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
379 |
380 | return flow_rgb, color_wheel, max_magnitude
381 |
382 | # Event related
383 | def visualize_event(
384 | self,
385 | events: Any,
386 | grayscale: bool = True,
387 | background_color: int = 127,
388 | ignore_polarity: bool = False,
389 | file_prefix: Optional[str] = None,
390 | ) -> Image.Image:
391 | """Visualize event as image.
392 | # TODO the function is messy - cleanup.
393 |
394 | Args:
395 | events (Any): [description]
396 | grayscale (bool, optional): [description]. Defaults to True.
397 |             background_color (int, optional): Background color where no events exist.
398 |                 Only effective when grayscale is True; defaults to 127. If non-grayscale, it is 255.
399 |             ignore_polarity (bool, optional): If True, create a polarity-ignoring image. Defaults to False.
400 |
401 | Returns:
402 | Optional[Image.Image]: [description]
403 | """
404 | if grayscale:
405 | image = np.ones((self._image_size[0], self._image_size[1]))
406 | else:
407 | background_color = 255
408 | image = (
409 | np.ones((self._image_size[0], self._image_size[1], 3), dtype=np.uint8)
410 | * background_color
411 |         )  # RGB channels
412 |
413 | # events = events[0 <= events[:, 0] < self._image_size[0]]
414 | # events = events[events[:, 1] < self._image_size[1]]
415 | events[:, 0] = np.clip(events[:, 0], 0, self._image_size[0] - 1)
416 | events[:, 1] = np.clip(events[:, 1], 0, self._image_size[1] - 1)
417 | if grayscale:
418 | indice = (events[:, 0].astype(np.int32), events[:, 1].astype(np.int32))
419 | if ignore_polarity:
420 | np.add.at(image, indice, np.ones_like(events[:, 3], dtype=np.int16))
421 | else:
422 | if np.min(events[:, 3]) == 0:
423 | pol = events[:, 3] * 2 - 1
424 | else:
425 | pol = events[:, 3]
426 | np.add.at(image, indice, pol)
427 | return self.visualize_event_image(image, background_color)
428 | else:
429 | colors = np.array([(255, 0, 0) if e[3] == 1 else (0, 0, 255) for e in events])
430 | image[events[:, 0].astype(np.int32), events[:, 1].astype(np.int32), :] = colors
431 |
432 | image = Image.fromarray(image)
433 | self._show_or_save_image(image, file_prefix)
434 | return image
435 |
436 | def save_array(
437 | self,
438 | array: np.ndarray,
439 | file_prefix: Optional[str] = None,
440 | new_prefix: bool = False,
441 | ) -> None:
442 | """Helper function to save numpy array. It belongs to this visualizer class
443 | because it associates with the naming rule of visualized files.
444 |
445 | Args:
446 | array (np.ndarray): Numpy array to save.
447 | file_prefix (Optional[str]): Prefix of the file. Defaults to None.
448 | new_prefix (bool): If True, rollback_save_count is skipped. Set to True if
449 |                 there is no corresponding .png file with the prefix. Defaults to False.
450 |
451 | Returns:
452 | Optional[Image.Image]: [description]
453 | """
454 | save_name = self.get_filename_from_prefix(file_prefix).replace("png", "npy")
455 | np.save(save_name, array)
456 | if not new_prefix:
457 | self.rollback_save_count(file_prefix)
458 |
459 | def visualize_event_image(
460 | self, eventimage: np.ndarray, background_color: int = 255, file_prefix: Optional[str] = None
461 | ) -> Image.Image:
462 | """Visualize event on white image"""
463 | # max_value_abs = np.max(np.abs(eventimage))
464 | # eventimage = (255 * eventimage / (2 * max_value_abs) + 127).astype(np.uint8)
465 | background = eventimage == 0
466 | eventimage = (
467 | 255 * (eventimage - eventimage.min()) / (eventimage.max() - eventimage.min())
468 | ).astype(np.uint8)
469 | if background_color == 255:
470 | eventimage = 255 - eventimage
471 | else:
472 | eventimage[background] = background_color
473 | eventimage = Image.fromarray(eventimage)
474 | self._show_or_save_image(eventimage, file_prefix)
475 | return eventimage
476 |
477 | # Scipy history visualizer
478 | def visualize_scipy_history(self, cost_history: dict, cost_weight: Optional[dict] = None):
479 | """Visualizing scipy optimization history.
480 |
481 | Args:
482 | cost_history (dict): [description]
483 | """
484 | plt.figure()
485 | for k in cost_history.keys():
486 | if k == "loss" or cost_weight is None:
487 | plt.plot(np.array(cost_history[k]), label=k)
488 | else:
489 | plt.plot(np.array(cost_history[k]) * cost_weight[k], label=k)
490 | plt.legend()
491 | if self._save:
492 | plt.savefig(self.get_filename_from_prefix("optimization_steps"))
493 | if self._show:
494 | plt.show(block=False)
495 | plt.close()
496 |
--------------------------------------------------------------------------------
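
Illustrative sketch (the output path and random flow are made up) of the main flow-visualization entry point:

```python
# Assumed example: save a colorized optical-flow image plus its color wheel.
import numpy as np

from src.visualizer import Visualizer

viz = Visualizer(image_shape=(260, 346), show=False, save=True, save_dir="./outputs/demo")

flow_x = np.random.randn(260, 346)  # height-direction component
flow_y = np.random.randn(260, 346)  # width-direction component
viz.visualize_optical_flow(flow_x, flow_y, visualize_color_wheel=True, file_prefix="flow")
# -> ./outputs/demo/flow0.png and ./outputs/demo/color_wheel.png
```
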
/tests/costs/test_gradient_magnitude.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import torch
4 |
5 | from src import utils
6 | from src.costs import GradientMagnitude
7 | from src.event_image_converter import EventImageConverter
8 | from src.warp import Warp
9 |
10 |
11 | # minimum is different
12 | def test_calculate_store_history():
13 | size = (260, 346)
14 | events = utils.generate_events(1000, size[0], size[1], tmin=0.1, tmax=0.9).astype(np.float32)
15 | events_torch = torch.from_numpy(events)
16 | imager = EventImageConverter(size)
17 |
18 | cost = GradientMagnitude(direction="minimize", store_history=True)
19 | # Calculate numpy
20 | count = imager.create_image_from_events_numpy(events, "bilinear_vote", weight=1.0, sigma=0)
21 | _ = cost.calculate({"iwe": count, "omit_boundary": True})
22 |
23 | # Calculate torch
24 | count = imager.create_image_from_events_tensor(events_torch, "bilinear_vote", weight=1.0)
25 | _ = cost.calculate({"iwe": count, "omit_boundary": True})
26 |
27 | history = cost.get_history()
28 | assert len(history["loss"]) == 2
29 |
30 |
31 | def test_calculate_not_store_history():
32 | size = (260, 346)
33 | events = utils.generate_events(1000, size[0], size[1], tmin=0.1, tmax=0.9).astype(np.float32)
34 | events_torch = torch.from_numpy(events)
35 | imager = EventImageConverter(size)
36 |
37 | cost = GradientMagnitude(direction="minimize", store_history=False)
38 | # Calculate numpy
39 | count = imager.create_image_from_events_numpy(events, "bilinear_vote", weight=1.0, sigma=0)
40 | _ = cost.calculate({"iwe": count, "omit_boundary": True})
41 |
42 | # Calculate torch
43 | count = imager.create_image_from_events_tensor(events_torch, "bilinear_vote", weight=1.0)
44 | _ = cost.calculate({"iwe": count, "omit_boundary": True})
45 |
46 | history = cost.get_history()
47 | assert len(history["loss"]) == 0
48 |
49 |
50 | @pytest.mark.parametrize(
51 | "direction,is_small",
52 | [["natural", True], ["minimize", False], ["maximize", True]],
53 | )
54 | def test_calculate_blur_is_small(direction, is_small):
55 | size = (10, 40)
56 | imager = EventImageConverter(size)
57 | cost = GradientMagnitude(direction=direction, store_history=False)
58 |
59 | events = np.array(
60 | [
61 | [5.0, 10.0],
62 | [8.0, 3.0],
63 | [2.0, 2.0],
64 | ]
65 | )
66 | grad_blur = cost.calculate({"iwe": imager.create_iwe(events), "omit_boundary": False})
67 | events = np.array(
68 | [
69 | [5.0, 10.0],
70 | [5.0, 10.0],
71 | [2.0, 2.0],
72 | ]
73 | )
74 | grad_sharp = cost.calculate({"iwe": imager.create_iwe(events), "omit_boundary": False})
75 |
76 | assert (grad_blur < grad_sharp) == is_small
77 |
78 |
79 | def test_calculate_np_torch():
80 | size = (10, 20)
81 | imager = EventImageConverter(size)
82 | cost = GradientMagnitude(direction="natural", store_history=False)
83 |
84 | events = np.array(
85 | [
86 | [12.0, 10.0],
87 | [8.0, 3.0],
88 | [2.0, 2.0],
89 | ]
90 | ).astype(np.float32)
91 | events_torch = torch.from_numpy(events)
92 | var_blur_numpy = cost.calculate(
93 | {"iwe": imager.create_iwe(events, sigma=0), "omit_boundary": True}
94 | )
95 | var_blur_torch = cost.calculate(
96 | {"iwe": imager.create_iwe(events_torch, sigma=0), "omit_boundary": True}
97 | )
98 |
99 | np.testing.assert_allclose(var_blur_torch.item(), var_blur_numpy, rtol=1e-2)
100 |
--------------------------------------------------------------------------------
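The tests above pin down the expected ordering for GradientMagnitude (a sharper IWE has larger gradient energy) without spelling out the measure. The following is a rough sketch of the textbook gradient-magnitude contrast of an image of warped events, under the assumption that src/costs/gradient_magnitude.py follows this standard form; it is not a copy of that file:

import numpy as np

def gradient_magnitude(iwe):
    # Mean squared spatial gradient of the image of warped events (IWE);
    # sharper IWEs have stronger edges and hence a larger value.
    gy, gx = np.gradient(iwe)
    return float(np.mean(gx ** 2 + gy ** 2))

# Mirrors the blur-vs-sharp fixtures of test_calculate_blur_is_small.
blurry = np.zeros((10, 40))
blurry[5, 10] = blurry[8, 3] = blurry[2, 2] = 1.0
sharp = np.zeros((10, 40))
sharp[5, 10] = 2.0
sharp[2, 2] = 1.0
assert gradient_magnitude(blurry) < gradient_magnitude(sharp)
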
/tests/costs/test_hybrid.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import torch
4 | from PIL.Image import Image
5 |
6 | from src import utils
7 | from src.costs import HybridCost, ImageVariance
8 | from src.event_image_converter import EventImageConverter
9 | from src.warp import Warp
10 |
11 |
12 | # Note: the weighted terms have different minima; the tests below check the history bookkeeping against the standalone costs.
13 | def test_hybrid_cost_store_history():
14 | size = (20, 34)
15 | imager = EventImageConverter(size)
16 | cost_with_weight = {"image_variance": 1.0, "gradient_magnitude": 2.4}
17 |
18 | cost = HybridCost(direction="minimize", cost_with_weight=cost_with_weight, store_history=True)
19 | variance = ImageVariance(store_history=True)
20 |
21 | # Calculate numpy
22 | events = utils.generate_events(1000, size[0], size[1])
23 | count = imager.create_image_from_events_numpy(events, sigma=0)
24 | _ = cost.calculate({"iwe": count, "omit_boundary": True})
25 | _ = variance.calculate({"iwe": count, "omit_boundary": True})
26 | events = utils.generate_events(1000, size[0], size[1])
27 | count = imager.create_image_from_events_numpy(events, sigma=0)
28 | _ = cost.calculate({"iwe": count, "omit_boundary": True})
29 | _ = variance.calculate({"iwe": count, "omit_boundary": True})
30 | events = utils.generate_events(1000, size[0], size[1])
31 | count = imager.create_image_from_events_numpy(events, sigma=0)
32 | _ = cost.calculate({"iwe": count, "omit_boundary": True})
33 | _ = variance.calculate({"iwe": count, "omit_boundary": True})
34 |
35 | history = cost.get_history()
36 | keys = ["loss", "image_variance", "gradient_magnitude"]
37 | assert history.keys() == set(keys)
38 | for k in keys:
39 | assert len(history[k]) == 3
40 | np.testing.assert_allclose(
41 | history["image_variance"], variance.get_history()["loss"], rtol=1e-5, atol=1e-5
42 | )
43 |
44 |
45 | def test_hybrid_cost_without_store_history():
46 | size = (20, 34)
47 | imager = EventImageConverter(size)
48 | cost_with_weight = {"image_variance": 1.0, "gradient_magnitude": 2.4}
49 |
50 | cost = HybridCost(direction="minimize", cost_with_weight=cost_with_weight, store_history=False)
51 |
52 | # Calculate numpy
53 | events = utils.generate_events(1000, size[0], size[1])
54 | count = imager.create_image_from_events_numpy(events, "bilinear_vote", sigma=0)
55 | _ = cost.calculate({"iwe": count, "omit_boundary": True})
56 | events = utils.generate_events(1000, size[0], size[1])
57 | count = imager.create_image_from_events_numpy(events, "bilinear_vote", sigma=0)
58 | _ = cost.calculate({"iwe": count, "omit_boundary": True})
59 | events = utils.generate_events(1000, size[0], size[1])
60 | count = imager.create_image_from_events_numpy(events, "bilinear_vote", sigma=0)
61 | _ = cost.calculate({"iwe": count, "omit_boundary": True})
62 |
63 | history = cost.get_history()
64 | keys = ["loss", "image_variance", "gradient_magnitude"]
65 | assert history.keys() == set(keys)
66 | for k in keys:
67 | assert len(history[k]) == 0
68 |
--------------------------------------------------------------------------------
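Both hybrid tests configure cost_with_weight and check that each per-term history tracks the standalone cost, which reads as HybridCost being a weighted sum of its registered terms. A plausible, self-contained sketch under that reading (the term implementations here are stand-ins, not src/costs/hybrid.py):

import numpy as np

def image_variance(iwe):
    return float(np.var(iwe))

def gradient_magnitude(iwe):
    gy, gx = np.gradient(iwe)
    return float(np.mean(gx ** 2 + gy ** 2))

# Same weights as the tests above.
cost_with_weight = {"image_variance": 1.0, "gradient_magnitude": 2.4}
terms = {"image_variance": image_variance, "gradient_magnitude": gradient_magnitude}

iwe = np.random.rand(20, 34)  # stand-in for an image of warped events
loss = sum(w * terms[name](iwe) for name, w in cost_with_weight.items())
print(loss)
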
/tests/costs/test_image_variance.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import torch
4 |
5 | from src import utils
6 | from src.costs import ImageVariance
7 | from src.event_image_converter import EventImageConverter
8 | from src.warp import Warp
9 |
10 |
11 | # Note: unlike the gradient cost, the numpy and torch variances are compared directly below.
12 | def test_calculate_store_history():
13 | size = (260, 346)
14 | events = utils.generate_events(1000, size[0], size[1], tmin=0.1, tmax=0.9)
15 | events_torch = torch.from_numpy(events)
16 | imager = EventImageConverter(size)
17 |
18 | cost = ImageVariance(direction="minimize", store_history=True)
19 | # Calculate numpy
20 | count = imager.create_image_from_events_numpy(events, "bilinear_vote", weight=1.0, sigma=0)
21 | _ = cost.calculate({"iwe": count, "omit_boundary": True})
22 |
23 | # Calculate torch
24 | count = imager.create_image_from_events_tensor(events_torch, "bilinear_vote", weight=1.0)
25 | _ = cost.calculate({"iwe": count, "omit_boundary": True})
26 |
27 | history = cost.get_history()
28 | assert len(history["loss"]) == 2
29 | np.testing.assert_allclose(history["loss"][0], history["loss"][1], rtol=1e-5, atol=1e-5)
30 |
31 |
32 | def test_calculate_not_store_history():
33 | size = (260, 346)
34 | events = utils.generate_events(1000, size[0], size[1], tmin=0.1, tmax=0.9)
35 | events_torch = torch.from_numpy(events)
36 | imager = EventImageConverter(size)
37 |
38 | cost = ImageVariance(direction="minimize", store_history=False)
39 | # Calculate numpy
40 | count = imager.create_image_from_events_numpy(events, "bilinear_vote", weight=1.0, sigma=0)
41 | _ = cost.calculate({"iwe": count, "omit_boundary": True})
42 |
43 | # Calculate torch
44 | count = imager.create_image_from_events_tensor(events_torch, "bilinear_vote", weight=1.0)
45 | _ = cost.calculate({"iwe": count, "omit_boundary": True})
46 |
47 | history = cost.get_history()
48 | assert len(history["loss"]) == 0
49 |
50 |
51 | @pytest.mark.parametrize(
52 | "direction,is_small",
53 | [["natural", True], ["minimize", False], ["maximize", True]],
54 | )
55 | def test_calculate_blur_is_small(direction, is_small):
56 | size = (10, 40)
57 | imager = EventImageConverter(size)
58 | cost = ImageVariance(direction=direction, store_history=False)
59 |
60 | events = np.array(
61 | [
62 | [5.0, 10.0],
63 | [8.0, 3.0],
64 | [2.0, 2.0],
65 | ]
66 | )
67 | var_blur = cost.calculate({"iwe": imager.create_iwe(events), "omit_boundary": False})
68 | events = np.array(
69 | [
70 | [5.0, 10.0],
71 | [5.0, 10.0],
72 | [2.0, 2.0],
73 | ]
74 | )
75 | var_sharp = cost.calculate({"iwe": imager.create_iwe(events), "omit_boundary": False})
76 |
77 | assert (var_blur < var_sharp) == is_small
78 |
79 |
80 | def test_calculate_np_torch():
81 | size = (10, 20)
82 | imager = EventImageConverter(size)
83 | cost = ImageVariance(direction="natural", store_history=False)
84 |
85 | events = np.array(
86 | [
87 | [12.0, 10.0],
88 | [8.0, 3.0],
89 | [2.0, 2.0],
90 | ]
91 | ).astype(np.float64)
92 | events_torch = torch.from_numpy(events)
93 | var_blur_numpy = cost.calculate(
94 | {"iwe": imager.create_iwe(events, sigma=0), "omit_boundary": True}
95 | )
96 | var_blur_torch = cost.calculate(
97 | {"iwe": imager.create_iwe(events_torch, sigma=0), "omit_boundary": True}
98 | )
99 |
100 | np.testing.assert_allclose(var_blur_torch.item(), var_blur_numpy, rtol=1e-2)
101 |
--------------------------------------------------------------------------------
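test_calculate_blur_is_small parametrizes over direction, and the expected inequality flips only for "minimize"; this suggests the direction argument negates the raw cost so that a scipy minimizer can maximize contrast. A hedged sketch of that sign convention (the real handling lives in src/costs/base.py, which is not shown in this dump):

import numpy as np

def image_variance(iwe, direction="natural"):
    value = float(np.var(iwe))
    # Assumption: "minimize" negates the raw variance so a minimizer drives
    # contrast up; "natural" and "maximize" keep the raw sign.
    return -value if direction == "minimize" else value

sharp = np.zeros((10, 40))
sharp[5, 10] = 2.0
sharp[2, 2] = 1.0
blurry = np.zeros((10, 40))
blurry[5, 10] = blurry[8, 3] = blurry[2, 2] = 1.0

# Matches the parametrized expectations: blur < sharp for "natural",
# and the inequality flips for "minimize".
assert image_variance(blurry) < image_variance(sharp)
assert image_variance(blurry, "minimize") > image_variance(sharp, "minimize")
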
/tests/test_event_image_converter.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from src import event_image_converter, utils
5 |
6 |
7 | def test_create_iwe():
8 | image_shape = (100, 200)
9 | imager = event_image_converter.EventImageConverter(image_shape)
10 | events = np.array(
11 | [utils.generate_events(100, image_shape[0] - 1, image_shape[1] - 1) for _ in range(4)]
12 | )
13 | iwe = imager.create_iwe(events)
14 | assert iwe.shape == (4, 100, 200)
15 |
16 |
17 | def test_bilinear_vote_integer():
18 | image_shape = (3, 4)
19 | imager = event_image_converter.EventImageConverter(image_shape)
20 |
21 | events = np.array(
22 | [
23 | [1.0, 2],
24 | [0, 1],
25 | [1, 0],
26 | ]
27 | )
28 | weights = np.array([1, 2, 0.8])
29 | img = imager.bilinear_vote_numpy(events, weight=weights)
30 | expected = np.array(
31 | [
32 | [0, 2, 0, 0],
33 | [0.8, 0, 1, 0],
34 | [0, 0, 0, 0],
35 | ]
36 | )
37 | # Numpy
38 | np.testing.assert_array_equal(img, expected)
39 |
40 | # Torch
41 | img = imager.bilinear_vote_tensor(torch.from_numpy(events), weight=torch.from_numpy(weights))
42 | assert torch.allclose(img, torch.from_numpy(expected))
43 |
44 |
45 | def test_bilinear_vote_float():
46 | image_shape = (3, 4)
47 | imager = event_image_converter.EventImageConverter(image_shape)
48 | events = np.array(
49 | [
50 | [1.2, 2],
51 | [0, 1.9],
52 | [0.5, 0.6],
53 | ]
54 | )
55 | weights = np.array([-1.0, 1.0, 1.5])
56 | img = imager.bilinear_vote_numpy(events, weight=weights)
57 | expected = np.array(
58 | [
59 | [0.3, 0.55, 0.9, 0],
60 | [0.3, 0.45, -0.8, 0],
61 | [0, 0, -0.2, 0],
62 | ]
63 | )
64 | # numpy
65 | np.testing.assert_allclose(img, expected)
66 |
67 | # torch
68 | img = imager.bilinear_vote_tensor(torch.from_numpy(events), weight=torch.from_numpy(weights))
69 | assert torch.allclose(img, torch.from_numpy(expected))
70 |
71 |
72 | def test_bilinear_vote_batch():
73 | image_shape = (3, 4)
74 | imager = event_image_converter.EventImageConverter(image_shape)
75 | events = np.array(
76 | [
77 | [
78 | [1, 2],
79 | [0, 1],
80 | [1, 0],
81 | ],
82 | [
83 | [1.2, 2],
84 | [0, 1.9],
85 | [0.5, 0.6],
86 | ],
87 | ]
88 | )
89 | weights = np.array([[1.0, 2.0, 0.8], [-1.0, 1.0, 1.5]])
90 | img = imager.bilinear_vote_numpy(events, weight=weights)
91 | expected = np.array(
92 | [
93 | [
94 | [0, 2, 0, 0],
95 | [0.8, 0, 1, 0],
96 | [0, 0, 0, 0],
97 | ],
98 | [
99 | [0.3, 0.55, 0.9, 0],
100 | [0.3, 0.45, -0.8, 0],
101 | [0, 0, -0.2, 0],
102 | ],
103 | ]
104 | )
105 | # numpy
106 | np.testing.assert_allclose(img, expected)
107 |
108 | # torch
109 | img = imager.bilinear_vote_tensor(torch.from_numpy(events), weight=torch.from_numpy(weights))
110 | assert torch.allclose(img, torch.from_numpy(expected))
111 |
112 |
113 | def test_bilinear_vote_accuracy():
114 | image_shape = (10, 20)
115 | imager = event_image_converter.EventImageConverter(image_shape)
116 | events = utils.generate_events(100, image_shape[0] - 1, image_shape[1] - 1)
117 | events += np.random.rand(100)[..., None]
118 | events = events.astype(np.float32)
119 | img = imager.bilinear_vote_numpy(events, weight=1.0)
120 |
121 | img_t = imager.bilinear_vote_tensor(torch.from_numpy(events), weight=1.0)
122 | np.testing.assert_allclose(img, img_t.numpy(), rtol=1e-5)
123 |
--------------------------------------------------------------------------------
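The hard-coded expectations in test_bilinear_vote_float follow the standard bilinear-voting rule: an event at a fractional coordinate splits its weight over the four surrounding pixels in proportion to overlap. A self-contained scalar sketch that reproduces the fixture above (the repository's implementation is vectorized; only the rule is illustrated here):

import numpy as np

def bilinear_vote(events, weights, image_shape):
    # events: (N, 2) fractional (row, col) coordinates; each event spreads
    # its weight over the four surrounding integer pixels.
    image = np.zeros(image_shape)
    for (x, y), w in zip(events, weights):
        x0, y0 = int(np.floor(x)), int(np.floor(y))
        dx, dy = x - x0, y - y0
        for i, j, frac in [
            (x0, y0, (1 - dx) * (1 - dy)),
            (x0 + 1, y0, dx * (1 - dy)),
            (x0, y0 + 1, (1 - dx) * dy),
            (x0 + 1, y0 + 1, dx * dy),
        ]:
            if 0 <= i < image_shape[0] and 0 <= j < image_shape[1]:
                image[i, j] += w * frac
    return image

# Same fixture as test_bilinear_vote_float.
events = np.array([[1.2, 2.0], [0.0, 1.9], [0.5, 0.6]])
weights = np.array([-1.0, 1.0, 1.5])
img = bilinear_vote(events, weights, (3, 4))
expected = np.array(
    [
        [0.3, 0.55, 0.9, 0.0],
        [0.3, 0.45, -0.8, 0.0],
        [0.0, 0.0, -0.2, 0.0],
    ]
)
np.testing.assert_allclose(img, expected, atol=1e-12)
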
/tests/test_warp.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import torch
4 |
5 | from src import utils, warp
6 |
7 |
8 | @pytest.mark.parametrize(
9 | "model,size",
10 | [["2d-translation", 2]],
11 | )
12 | def test_get_motion_vector_size(model, size):
13 | warper = warp.Warp((100, 200), normalize_t=True)
14 | assert size == warper.get_motion_vector_size(model)
15 |
16 |
17 | def test_calculate_dt_normalize():
18 | image_size = (100, 200)
19 | warper = warp.Warp(image_size, normalize_t=True)
20 |
21 | events = utils.generate_events(300, image_size[0], image_size[1], tmin=1, tmax=2)
22 | dt = warper.calculate_dt(events, 1.0)
23 | np.testing.assert_allclose(dt.min(), 0.0, rtol=1e-2, atol=0.1)
24 | np.testing.assert_allclose(dt.max(), 1.0, rtol=1e-2, atol=0.1)
25 |
26 | events = utils.generate_events(300, image_size[0], image_size[1], tmin=0, tmax=0.5)
27 | dt = warper.calculate_dt(events, 0)
28 | np.testing.assert_allclose(dt.min(), 0.0, rtol=1e-2, atol=0.1)
29 | np.testing.assert_allclose(dt.max(), 1.0, rtol=1e-2, atol=0.1)
30 |
31 | events = utils.generate_events(300, image_size[0], image_size[1], tmin=-1, tmax=1)
32 | dt = warper.calculate_dt(events, 0)
33 | np.testing.assert_allclose(dt.min(), -0.5, rtol=1e-2, atol=0.1)
34 | np.testing.assert_allclose(dt.max(), 0.5, rtol=1e-2, atol=0.1)
35 |
36 | events = utils.generate_events(300, image_size[0], image_size[1], tmin=-1, tmax=1)
37 | dt = warper.calculate_dt(events, -1)
38 | np.testing.assert_allclose(dt.min(), 0.0, rtol=1e-2, atol=0.1)
39 | np.testing.assert_allclose(dt.max(), 1.0, rtol=1e-2, atol=0.1)
40 |
41 |
42 | def test_calculate_dt_non_normalize():
43 | image_size = (10, 20)
44 | warper = warp.Warp(image_size, normalize_t=False)
45 | events = utils.generate_events(300, image_size[0], image_size[1], tmin=1, tmax=2)
46 | dt = warper.calculate_dt(events, 1.0)
47 | np.testing.assert_allclose(dt.max(), 1.0, rtol=1e-2, atol=0.1)
48 | np.testing.assert_allclose(dt.min(), 0.0, rtol=1e-2, atol=0.1)
49 |
50 | events = utils.generate_events(300, image_size[0], image_size[1], tmin=0, tmax=0.5)
51 | dt = warper.calculate_dt(events, 0)
52 | np.testing.assert_allclose(dt.max(), 0.5, rtol=1e-2, atol=0.1)
53 | np.testing.assert_allclose(dt.min(), 0.0, rtol=1e-2, atol=0.1)
54 |
55 | events = utils.generate_events(300, image_size[0], image_size[1], tmin=-1, tmax=1)
56 | dt = warper.calculate_dt(events, 0)
57 | np.testing.assert_allclose(dt.max(), 1.0, rtol=1e-2, atol=0.1)
58 | np.testing.assert_allclose(dt.min(), -1.0, rtol=1e-2, atol=0.1)
59 |
60 | events = utils.generate_events(300, image_size[0], image_size[1], tmin=-1, tmax=1)
61 | dt = warper.calculate_dt(events, -1)
62 | np.testing.assert_allclose(dt.max(), 2.0, rtol=1e-2, atol=0.1)
63 | np.testing.assert_allclose(dt.min(), 0.0, rtol=1e-2, atol=0.1)
64 |
65 |
66 | def test_calculate_dt_numpy_batch():
67 | image_size = (10, 20)
68 | warper = warp.Warp(image_size, normalize_t=True)
69 | events = np.array(
70 | [
71 | utils.generate_events(300, image_size[0], image_size[1], tmin=1, tmax=i + 2)
72 | for i in range(4)
73 | ]
74 | )
75 | dt = warper.calculate_dt(events, 1.0)
76 | np.testing.assert_allclose(dt.max(axis=-1), 1.0, rtol=1e-2, atol=0.1)
77 | np.testing.assert_allclose(dt.min(axis=-1), 0.0, rtol=1e-2, atol=0.1)
78 | assert dt.shape == (4, 300)
79 |
80 |
81 | def test_calculate_dt_torch_batch():
82 | image_size = (10, 20)
83 | warper = warp.Warp(image_size, normalize_t=True)
84 | events = np.array(
85 | [
86 | utils.generate_events(300, image_size[0], image_size[1], tmin=1, tmax=i + 2)
87 | for i in range(4)
88 | ]
89 | )
90 | dt = warper.calculate_dt(torch.from_numpy(events), 1.0).numpy()
91 | np.testing.assert_allclose(dt.max(axis=-1), 1.0, rtol=1e-2, atol=0.1)
92 | np.testing.assert_allclose(dt.min(axis=-1), 0.0, rtol=1e-2, atol=0.1)
93 | assert dt.shape == (4, 300)
94 |
95 |
96 | def test_warp_event_dense_flow():
97 | image_size = (3, 4)
98 | warper = warp.Warp(image_size, normalize_t=True)
99 |
100 | events = np.array(
101 | [
102 | [1, 2, 0],
103 | [2, 3, 0.2],
104 | [0, 1, 0.6],
105 | [1, 0, 1.0],
106 | ]
107 | )
108 | flow = np.array(
109 | [
110 | [
111 | [1.0, -0.5, 2, 8],
112 | [-2, 0, 2.0, 0],
113 | [2, 1, -2, 0],
114 | ],
115 | [
116 | [-10, 1.0, 3, 2],
117 | [0, 2, -0.9, 0],
118 | [0, 10, -3, 0],
119 | ],
120 | ]
121 | )
122 |
123 | expected = np.array(
124 | [
125 | [1.0, 2.0, 0],
126 | [2.0, 3.0, 0.2],
127 | [0.3, 0.4, 0.6],
128 | [3, 0, 1.0],
129 | ]
130 | )
131 | # Numpy
132 | warped, _ = warper.warp_event(events, flow, "dense-flow")
133 | np.testing.assert_allclose(warped, expected)
134 |
135 | # Torch
136 | warped_torch, _ = warper.warp_event(
137 | torch.from_numpy(events), torch.from_numpy(flow), "dense-flow"
138 | )
139 | assert torch.allclose(warped_torch, torch.from_numpy(expected))
140 |
141 |
142 | def test_warp_event_dense_flow_batch():
143 | image_size = (3, 4)
144 | warper = warp.Warp(image_size, normalize_t=True)
145 |
146 | events = np.array(
147 | [
148 | [[1, 2, 0], [2, 3, 0.2]],
149 | [[0, 1, 0.6], [1, 0, 1.2]],
150 | ]
151 | )
152 | flow = np.array(
153 | [
154 | [
155 | [
156 | [1.0, -0.5, 2, 8],
157 | [-2, 0, 2.0, 0],
158 | [2, 1, -2, 0],
159 | ],
160 | [
161 | [-10, 1.0, 3, 2],
162 | [0, 2, -0.9, 0],
163 | [0, 10, -3, 0],
164 | ],
165 | ],
166 | [
167 | [
168 | [1.0, -0.5, 2, 8],
169 | [-2, 0, 2.0, 0],
170 | [2, 1, -2, 0],
171 | ],
172 | [
173 | [-10, 1.0, 3, 2],
174 | [0, 2, -0.9, 0],
175 | [0, 10, -3, 0],
176 | ],
177 | ],
178 | ]
179 | )
180 |
181 | expected = np.array(
182 | [
183 | [[1.0, 2.0, 0], [2, 3, 1.0]],
184 | [[0, 1, 0], [3, 0, 1.0]],
185 | ]
186 | )
187 | # Numpy
188 | warped, _ = warper.warp_event(events, flow, "dense-flow")
189 | np.testing.assert_allclose(warped, expected)
190 |
191 | # Torch
192 | warped_torch, _ = warper.warp_event(
193 | torch.from_numpy(events), torch.from_numpy(flow), "dense-flow"
194 | )
195 | assert torch.allclose(warped_torch, torch.from_numpy(expected))
196 |
197 |
198 | def test_warp_event_dense_flow_accuracy():
199 | image_size = (10, 20)
200 | warper = warp.Warp(image_size, normalize_t=True)
201 | events = np.array(
202 | [
203 | utils.generate_events(300, image_size[0], image_size[1], tmin=1, tmax=i + 2)
204 | for i in range(4)
205 | ]
206 | ).astype(np.float32)
207 | flow = np.array(
208 | [utils.generate_dense_optical_flow(image_size, max_val=10) for i in range(4)]
209 | ).astype(np.float32)
210 | warped, _ = warper.warp_event(events, flow, "dense-flow")
211 | warped_torch, _ = warper.warp_event(
212 | torch.from_numpy(events), torch.from_numpy(flow), "dense-flow"
213 | )
214 | np.testing.assert_allclose(warped, warped_torch.numpy(), rtol=1e-5, atol=1e-5)
215 |
--------------------------------------------------------------------------------
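Read together, test_calculate_dt_normalize and test_calculate_dt_non_normalize pin the behavior down: dt = t - t_ref, divided by the time span of the event slice when normalize_t is set. A hedged one-function sketch consistent with those fixtures (timestamps sit in column 2; this is not src/warp.py verbatim):

import numpy as np

def calculate_dt(events, t_ref, normalize_t=True):
    # events[..., 2] holds the timestamp (x, y, t, ... layout as in the fixtures).
    t = events[..., 2]
    dt = t - t_ref
    if normalize_t:
        # Divide by the per-batch time span so dt covers a unit-length range.
        span = t.max(axis=-1, keepdims=True) - t.min(axis=-1, keepdims=True)
        dt = dt / span
    return dt

t = np.linspace(-1.0, 1.0, 300)
events = np.stack([np.zeros(300), np.zeros(300), t], axis=-1)  # (300, 3)
dt = calculate_dt(events, 0.0)
assert np.isclose(dt.min(), -0.5) and np.isclose(dt.max(), 0.5)
assert np.isclose(calculate_dt(events, -1.0).max(), 1.0)
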
/tests/utils/test_event_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from src import utils
5 |
6 |
7 | def test_crop_events():
8 | events = utils.generate_events(100, 30, 20)
9 | cropped = utils.crop_event(events, -10, 40, -20, 30)
10 | assert len(events) == len(cropped)
11 |
12 | cropped_torch = utils.crop_event(torch.from_numpy(events), -10, 40, -20, 30)
13 | assert len(events) == len(cropped) == len(cropped_torch)
14 |
15 | cropped = utils.crop_event(events, 5, 29, 2, 11)
16 | cropped_torch = utils.crop_event(torch.from_numpy(events), 5, 29, 2, 11)
17 | assert len(cropped) == len(cropped_torch)
18 |
--------------------------------------------------------------------------------
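test_crop_events only establishes that an oversized window keeps every event and that the numpy and torch paths agree; the boundary convention itself is not pinned down. A minimal sketch assuming a half-open [min, max) window on (x, y):

import numpy as np

def crop_event(events, x_min, x_max, y_min, y_max):
    # Boundary convention (half-open [min, max)) is an assumption; the test
    # above only pins down that an oversized window keeps every event.
    x, y = events[:, 0], events[:, 1]
    mask = (x_min <= x) & (x < x_max) & (y_min <= y) & (y < y_max)
    return events[mask]

events = np.random.rand(100, 4) * [30, 20, 1, 1]
assert len(crop_event(events, -10, 40, -20, 30)) == len(events)
assert len(crop_event(events, 5, 29, 2, 11)) <= len(events)
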
/tests/utils/test_flow_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import torch
4 |
5 | from src import utils
6 |
7 |
8 | def test_flow_error_numpy_torch():
9 | imsize = (100, 200)
10 | flow_pred_np = utils.generate_dense_optical_flow(imsize, max_val=20).astype(np.float32)
11 | flow_pred_np = flow_pred_np[None, ...]
12 | flow_gt_np = utils.generate_dense_optical_flow(imsize, max_val=20).astype(np.float32)
13 | flow_gt_np = flow_gt_np[None, ...]
14 |
15 | flow_pred_th = torch.from_numpy(flow_pred_np)
16 | flow_gt_th = torch.from_numpy(flow_gt_np)
17 |
18 | error_np = utils.calculate_flow_error_numpy(flow_gt_np, flow_pred_np)
19 | error_th = utils.calculate_flow_error_tensor(flow_gt_th, flow_pred_th)
20 |
21 | for k in error_np.keys():
22 | np.testing.assert_almost_equal(error_np[k], error_th[k].numpy(), decimal=5)
23 |
24 |
25 | def test_flow_error_different_batch_size():
26 | # Test with bsize=1 vs. bsize=8
27 | imsize = (100, 200)
28 | flow_pred = utils.generate_dense_optical_flow(imsize, max_val=20).astype(np.float32)
29 | flow_gt = utils.generate_dense_optical_flow(imsize, max_val=20).astype(np.float32)
30 | flow_pred = np.tile(flow_pred, (8, 1, 1, 1))
31 | flow_gt = np.tile(flow_gt, (8, 1, 1, 1))
32 | mask = np.random.rand(8, 1, imsize[0], imsize[1]) > 0.1
33 |
34 | # Numpy
35 | error_batch = utils.calculate_flow_error_numpy(flow_gt, flow_pred, mask)
36 | error_one = utils.calculate_flow_error_numpy(flow_gt[[0]], flow_pred[[0]])
37 |
38 | for k in error_one.keys():
39 | np.testing.assert_almost_equal(error_one[k], error_batch[k], decimal=1)
40 |
41 | # Torch
42 | flow_gt = torch.from_numpy(flow_gt)
43 | flow_pred = torch.from_numpy(flow_pred)
44 | mask = torch.from_numpy(mask)
45 | error_batch = utils.calculate_flow_error_tensor(flow_gt, flow_pred, mask)
46 | error_one = utils.calculate_flow_error_tensor(flow_gt[[0]], flow_pred[[0]])
47 |
48 | for k in error_one.keys():
49 | np.testing.assert_almost_equal(error_one[k].numpy(), error_batch[k].numpy(), decimal=1)
50 |
51 |
52 | @pytest.mark.parametrize("scheme", ["upwind", "burgers"])
53 | def test_construct_dense_flow_voxel_numpy(scheme):
54 | imsize = (100, 200)
55 | n_bin = 60
56 | flow = utils.generate_dense_optical_flow(imsize, max_val=20).astype(np.float64)
57 | voxel_flow = utils.construct_dense_flow_voxel_numpy(flow, 1)
58 | np.testing.assert_almost_equal(voxel_flow[0], flow, decimal=8)
59 |
60 | voxel_flow = utils.construct_dense_flow_voxel_numpy(flow, n_bin, scheme, t0_location="middle")
61 | np.testing.assert_almost_equal(voxel_flow[n_bin // 2], flow, decimal=8)
62 | voxel_flow = utils.construct_dense_flow_voxel_numpy(flow, n_bin, scheme, t0_location="first")
63 | np.testing.assert_almost_equal(voxel_flow[0], flow, decimal=8)
64 |
65 | voxel_flow = utils.construct_dense_flow_voxel_numpy(flow, n_bin, scheme, t0_location="middle")
66 | np.testing.assert_almost_equal(voxel_flow[n_bin // 2], flow, decimal=8)
67 | voxel_flow = utils.construct_dense_flow_voxel_numpy(flow, n_bin, scheme, t0_location="first")
68 | np.testing.assert_almost_equal(voxel_flow[0], flow, decimal=8)
69 |
70 |
71 | @pytest.mark.parametrize("scheme", ["upwind", "burgers"])
72 | def test_construct_dense_flow_voxel_torch(scheme):
73 | imsize = (100, 200)
74 | n_bin = 60
75 | flow = utils.generate_dense_optical_flow(imsize, max_val=20).astype(np.float64)
76 |
77 | flow_torch = torch.from_numpy(flow).double()
78 | voxel_flow_torch = utils.construct_dense_flow_voxel_torch(
79 | flow_torch, n_bin, scheme, t0_location="middle"
80 | )
81 | np.testing.assert_almost_equal(
82 | voxel_flow_torch.numpy()[n_bin // 2], flow_torch.numpy(), decimal=8
83 | )
84 | voxel_flow_torch = utils.construct_dense_flow_voxel_torch(
85 | flow_torch, n_bin, scheme, t0_location="first"
86 | )
87 | np.testing.assert_almost_equal(voxel_flow_torch.numpy()[0], flow_torch.numpy(), decimal=8)
88 |
89 |
90 | @pytest.mark.parametrize(
91 | "scheme,t0_location",
92 | [["upwind", "middle"], ["upwind", "first"], ["burgers", "middle"], ["burgers", "first"]],
93 | )
94 | def test_construct_dense_flow_voxel_numerical(scheme, t0_location):
95 | imsize = (100, 200)
96 | n_bin = 100
97 | flow = utils.generate_dense_optical_flow(imsize, max_val=10).astype(np.float64)
98 |
99 | flow_torch = torch.from_numpy(flow).double()
100 |
101 | voxel_flow = utils.construct_dense_flow_voxel_numpy(flow, n_bin, scheme, t0_location)
102 | voxel_flow_torch = utils.construct_dense_flow_voxel_torch(
103 | flow_torch, n_bin, scheme, t0_location
104 | )
105 | np.testing.assert_almost_equal(voxel_flow_torch.numpy(), voxel_flow, decimal=6)
106 |
107 |
108 | def test_inviscid_burger_flow_to_voxel_numerical():
109 | imsize = (100, 200)
110 | dt = 0.01
111 | flow = utils.generate_dense_optical_flow(imsize, max_val=1).astype(np.float64)
112 | flow_torch = torch.from_numpy(flow).double()
113 |
114 | flow1 = utils.inviscid_burger_flow_to_voxel_numpy(flow, dt, 1, 1)
115 | flow1_torch = utils.inviscid_burger_flow_to_voxel_torch(flow_torch, dt, 1, 1)
116 | np.testing.assert_almost_equal(flow1_torch.numpy(), flow1, decimal=6)
117 |
118 | flow1 = utils.inviscid_burger_flow_to_voxel_numpy(flow, -dt, 1, 1)
119 | flow1_torch = utils.inviscid_burger_flow_to_voxel_torch(flow_torch, -dt, 1, 1)
120 | np.testing.assert_almost_equal(flow1_torch.numpy(), flow1, decimal=6)
121 |
--------------------------------------------------------------------------------
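The flow-utility tests treat calculate_flow_error_numpy/_tensor as returning a dict of scalar metrics over (B, 2, H, W) flows, with an optional (B, 1, H, W) validity mask. A self-contained sketch of a masked endpoint-error metric in that layout (the metric names and the 3-pixel outlier threshold are assumptions, not the repository's keys):

import numpy as np

def calculate_flow_error(flow_gt, flow_pred, mask=None):
    # flow_*: (B, 2, H, W); mask: (B, 1, H, W) boolean validity mask.
    epe = np.linalg.norm(flow_gt - flow_pred, axis=1, keepdims=True)
    if mask is None:
        mask = np.ones_like(epe, dtype=bool)
    valid = epe[mask]
    return {"AEE": float(valid.mean()), "outlier_3px": float((valid > 3.0).mean())}

flow_gt = np.random.randn(8, 2, 100, 200).astype(np.float32)
flow_pred = flow_gt + 0.1 * np.random.randn(8, 2, 100, 200).astype(np.float32)
mask = np.random.rand(8, 1, 100, 200) > 0.1
print(calculate_flow_error(flow_gt, flow_pred, mask))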