├── .gitignore
├── LICENSE
├── Makefile
├── README.md
├── configs
│   ├── mvsec_indoor_burgers.yaml
│   └── mvsec_indoor_no_timeaware.yaml
├── conftest.py
├── datasets
│   ├── .gitkeep
│   └── README.md
├── docs
│   └── img
│       ├── 2024_TPAMI_SecretsOfEVFlow_poster.pdf
│       └── secretsevflow_eccv22.jpg
├── main.py
├── outputs
│   └── .gitkeep
├── pyproject.toml
├── src
│   ├── __init__.py
│   ├── costs
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── gradient_magnitude.py
│   │   ├── hybrid.py
│   │   ├── image_variance.py
│   │   ├── multi_focal_normalized_gradient_magnitude.py
│   │   ├── multi_focal_normalized_image_variance.py
│   │   ├── normalized_gradient_magnitude.py
│   │   ├── normalized_image_variance.py
│   │   └── total_variation.py
│   ├── data_loader
│   │   ├── __init__.py
│   │   ├── base.py
│   │   └── mvsec.py
│   ├── event_image_converter.py
│   ├── feature_calculator.py
│   ├── solver
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── nnmodels
│   │   │   ├── __init__.py
│   │   │   ├── basic_layers.py
│   │   │   └── ev_flownet.py
│   │   ├── patch_contrast_base.py
│   │   ├── patch_contrast_mixed.py
│   │   ├── patch_contrast_pyramid.py
│   │   ├── scipy_autograd
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── base_wrapper.py
│   │   │   ├── scipy_minimize.py
│   │   │   └── torch_wrapper.py
│   │   └── time_aware_patch_contrast.py
│   ├── types
│   │   ├── __init__.py
│   │   └── flow_patch.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── event_utils.py
│   │   ├── flow_utils.py
│   │   ├── misc.py
│   │   └── stat_utils.py
│   ├── visualizer.py
│   └── warp.py
└── tests
    ├── costs
    │   ├── test_gradient_magnitude.py
    │   ├── test_hybrid.py
    │   └── test_image_variance.py
    ├── test_event_image_converter.py
    ├── test_warp.py
    └── utils
        ├── test_event_utils.py
        └── test_flow_utils.py

/.gitignore:
--------------------------------------------------------------------------------
 1 | .DS_Store
 2 | 
 3 | # VS Code files for those working on multiple tools
 4 | .vscode
 5 | !.vscode/settings.json
 6 | !.vscode/tasks.json
 7 | !.vscode/launch.json
 8 | !.vscode/extensions.json
 9 | *.code-workspace
10 | 
11 | # Datasets and output large files
12 | datasets/*
13 | !datasets/.gitkeep
14 | !datasets/README.md
15 | outputs/*
16 | !outputs/.gitkeep
17 | 
18 | *.log
19 | *.prof
20 | **/__pycache__/*
21 | **/*.pyc
22 | *.mp4
23 | *.avi
24 | 
25 | wandb/*
26 | jupyter/*
27 | .ipynb_checkpoints
28 | 
29 | **build
30 | **dist
31 | **.egg-info
32 | 
33 | # Because we cannot assume the pytorch version (CUDA, CPU, etc.), do not commit the .lock file.
34 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
35 | poetry.lock
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
 1 | lint:
 2 | 	poetry run python3 -m mypy --install-types
 3 | 	poetry run python3 -m mypy src/ main.py
 4 | 
 5 | fmt:
 6 | 	poetry run python3 -m black src/ tests/ main.py
 7 | 	poetry run python3 -m isort ./src ./tests main.py
 8 | 
 9 | run:
10 | 	poetry run python3 main.py --config_file ./configs/mvsec_indoor_no_timeaware.yaml
11 | 
12 | test:
13 | 	poetry run pytest -vvv -c tests/ -o "testpaths=tests" -W ignore::DeprecationWarning
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | 👀 **The extension paper has been accepted to IEEE T-PAMI! ([Paper](https://doi.org/10.1109/TPAMI.2024.3396116))**
 2 | 
 3 | 👀 **We are now working to make this method available as more generic, easy-to-use functions (`flow = useful_function(events)`).
Stay tuned!**
 5 | 
 6 | # Secrets of Event-Based Optical Flow (T-PAMI 2024, ECCV 2022)
 7 | 
 8 | This is the official repository for [**Secrets of Event-Based Optical Flow**](https://arxiv.org/abs/2207.10022), **ECCV 2022 Oral** by
 9 | [Shintaro Shiba](http://shibashintaro.com/), [Yoshimitsu Aoki](https://aoki-medialab.jp/aokiyoshimitsu-en/) and [Guillermo Gallego](http://www.guillermogallego.es).
10 | 
11 | We have extended this paper to a journal version: [**Secrets of Event-based Optical Flow, Depth and Ego-motion Estimation by Contrast Maximization**](https://doi.org/10.1109/TPAMI.2024.3396116), **IEEE T-PAMI 2024**.
12 | 
16 | 
17 | 

18 | 
19 | [Paper (IEEE T-PAMI 2024)](https://hal.science/hal-04655247v1/document) | [Paper (ECCV 2022)](https://arxiv.org/pdf/2207.10022) | [Video](https://youtu.be/nUb2ZRPdbWk) | [Poster](docs/img/2024_TPAMI_SecretsOfEVFlow_poster.pdf)
20 | 

21 | 
22 | [![Secrets of Event-Based Optical Flow](docs/img/secretsevflow_eccv22.jpg)](https://youtu.be/nUb2ZRPdbWk)
23 | 
24 | 
25 | If you use this work in your research, please cite it (see also [here](#citation)):
26 | 
27 | ```bibtex
28 | @Article{Shiba24pami,
29 |   author  = {Shintaro Shiba and Yannick Klose and Yoshimitsu Aoki and Guillermo Gallego},
30 |   title   = {Secrets of Event-based Optical Flow, Depth, and Ego-Motion by Contrast Maximization},
31 |   journal = {IEEE Trans. Pattern Anal. Mach. Intell. (T-PAMI)},
32 |   year    = 2024,
33 |   pages   = {1--18},
34 |   doi     = {10.1109/TPAMI.2024.3396116}
35 | }
36 | 
37 | @InProceedings{Shiba22eccv,
38 |   author    = {Shintaro Shiba and Yoshimitsu Aoki and Guillermo Gallego},
39 |   title     = {Secrets of Event-based Optical Flow},
40 |   booktitle = {European Conference on Computer Vision (ECCV)},
41 |   pages     = {628--645},
42 |   doi       = {10.1007/978-3-031-19797-0_36},
43 |   year      = 2022
44 | }
45 | ```
46 | 
47 | ## **List of datasets that the flow estimation is tested on**
48 | 
49 | Although this codebase releases just MVSEC examples,
50 | I have tested that the flow estimation is roughly good on the datasets below.
51 | The list is being updated, and if you test new datasets please let us know.
52 | 
53 | - [MVSEC](https://daniilidis-group.github.io/mvsec/)
54 | - [DSEC](https://dsec.ifi.uzh.ch/dsec-datasets/download/)
55 | - [ECD, both simulation and real data](http://rpg.ifi.uzh.ch/davis_data.html)
56 | - [TUM VIE](https://cvg.cit.tum.de/data/datasets/visual-inertial-event-dataset)
57 | - [UZH-FPV Drone Racing Dataset](https://fpv.ifi.uzh.ch/)
58 | - [EDS](https://rpg.ifi.uzh.ch/eds.html#dataset)
59 | - [M3ED](https://m3ed.io/)
60 | 
61 | The above are all public datasets; in our paper (T-PAMI 2024) we also used some non-public datasets from previous works.
62 | 
63 | -------
64 | # Setup
65 | 
66 | ## Requirements
67 | 
68 | Although not all versions are strictly tested, the following should work.
69 | 
70 | - python: 3.8.x, 3.9.x, 3.10.x
71 | 
72 | GPU is entirely optional.
73 | If `torch.cuda.is_available()`, the code automatically switches to the GPU.
74 | I'd recommend using a GPU for the time-aware solutions, but in my tests a CPU is fine for the no-timeaware method.
75 | 
76 | ### Tested environments
77 | 
78 | - Mac OS Monterey (both M1 and non-M1)
79 | - Ubuntu (CUDA 11.1, 11.3, 11.8)
80 | - PyTorch 1.9-1.12.1, or PyTorch 2.0 (1.13 raises an error during Burgers).
81 | 
82 | ## Installation
83 | 
84 | I strongly recommend using venv: `python3 -m venv <path_to_venv>`
85 | Alternatively, you can use [poetry](https://python-poetry.org/).
86 | 
87 | - Install pytorch **< 1.13** or **>= 2.0** and torchvision for your environment. Make sure you install the correct CUDA version if you want to use it.
88 | 
89 | - If you use poetry, `poetry install`. If you use only venv, check the dependency libraries and install them from [here](./pyproject.toml).
90 | 
91 | - If you are having trouble installing pytorch with CUDA using poetry, refer to this [link](https://github.com/python-poetry/poetry/issues/6409).
92 | 
93 | ## Download dataset
94 | 
95 | Download each dataset under the `./datasets` directory.
96 | Optionally you can specify another root directory:
97 | please check the [dataset readme](./datasets/README.md) for the details.
98 | 
99 | # Execution
100 | 
101 | ```shell
102 | python3 main.py --config_file ./configs/mvsec_indoor_no_timeaware.yaml
103 | ```
104 | 
105 | If you use poetry, simply add `poetry run` at the beginning.
106 | Please run with the `-h` option to learn more about the other options.
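
If you prefer to embed the optimization in your own script instead of using the CLI, the following is a minimal sketch of what `main.py` does internally (illustrative only, not a stable API; all names are taken from `main.py` below, and the config keys mirror `configs/mvsec_indoor_no_timeaware.yaml`):

```python
import yaml

from src import data_loader, solver, visualizer

# Minimal sketch of the main.py pipeline (illustrative, not a stable API).
with open("./configs/mvsec_indoor_no_timeaware.yaml") as f:
    config = yaml.safe_load(f)
data_config = config["data"]
image_shape = (data_config["height"], data_config["width"])

# Loaders and solvers are looked up by name in registry dictionaries.
loader = data_loader.collections[data_config["dataset"]](config=data_config)
loader.set_sequence(data_config["sequence"])
viz = visualizer.Visualizer(
    image_shape, show=False, save=True, save_dir=config["output"]["output_dir"]
)
solv = solver.collections[config["solver"]["method"]](
    image_shape,
    calibration_parameter=loader.load_calib(),
    solver_config=config["solver"],
    optimizer_config=config["optimizer"],
    output_config=config["output"],
    visualize_module=viz,
)

# Optimize one batch of events; column 2 holds timestamps, shifted to start at 0.
batch = loader.load_event(data_config["ind1"], data_config["ind2"])
batch[..., 2] -= batch[..., 2].min()
best_motion = solv.optimize(batch)
```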
107 | 
108 | ## Config file
109 | 
110 | The config (.yaml) file specifies various experimental settings.
111 | Please check and change parameters as you like.
112 | 
113 | ### Optional tasks (for me)
114 | 
115 | **The code here is already runnable, and explains the ideas of the paper sufficiently.** (Please report any bugs.)
116 | 
117 | Rather than releasing all of my (sometimes too experimental) code,
118 | I published just a minimal set of the codebase needed to reproduce the results.
119 | So the following tasks are more optional for me.
120 | But if it helps you, I can publish other parts as well. For example:
121 | 
122 | - Other data loaders
123 | 
124 | - Some other cost functions
125 | 
126 | - Pretrained model checkpoint file ✔️ [released for MVSEC](https://drive.google.com/file/d/13m-waAt5X0C7f0JLBwb6KAApYxgXoA2J/view?usp=sharing)
127 | 
128 | - Other solvers (especially DNN)
129 | 
130 | - The implementation of [the Sensors paper](https://www.mdpi.com/1424-8220/22/14/5190)
131 | 
132 | Your feedback is helpful for prioritizing these tasks, so please contact me or raise issues.
133 | The code is well modularized, so if you want to contribute, it should be easy too.
134 | 
135 | # Citation
136 | 
137 | If you use this work in your research, please cite it **as stated above**, below the video.
138 | 
139 | This code also includes part of the implementation of the [following paper about event collapse](https://www.mdpi.com/1424-8220/22/14/5190).
140 | Please check it :)
141 | 
142 | ```bibtex
143 | @Article{Shiba22sensors,
144 |   author         = {Shintaro Shiba and Yoshimitsu Aoki and Guillermo Gallego},
145 |   title          = {Event Collapse in Contrast Maximization Frameworks},
146 |   journal        = {Sensors},
147 |   year           = 2022,
148 |   volume         = 22,
149 |   number         = 14,
150 |   pages          = {1--20},
151 |   article-number = 5190,
152 |   doi            = {10.3390/s22145190}
153 | }
154 | ```
155 | 
156 | # Author
157 | 
158 | Shintaro Shiba [@shiba24](https://github.com/shiba24)
159 | 
160 | ## LICENSE
161 | 
162 | Please check [License](./LICENSE).
163 | 
164 | ## Acknowledgement
165 | 
166 | I appreciate the following repositories for inspiration:
167 | 
168 | - [autograd-minimize](https://github.com/brunorigal/autograd-minimize)
169 | - [EVFlowNet-pytorch](https://github.com/CyrilSterling/EVFlowNet-pytorch)
170 | 
171 | -------
172 | # Additional Resources
173 | 
174 | * [Motion-prior Contrast Maximization (ECCV 2024)](https://github.com/tub-rip/MotionPriorCMax)
175 | * [EVILIP: Event-based Image Reconstruction as a Linear Inverse Problem (TPAMI 2022)](https://github.com/tub-rip/event_based_image_rec_inverse_problem)
176 | * [Event Collapse in Contrast Maximization Frameworks](https://github.com/tub-rip/event_collapse)
177 | * [CMax-SLAM (TRO 2024)](https://github.com/tub-rip/cmax_slam)
178 | * [EBOS: Event-based Background-Oriented Schlieren (TPAMI 2023)](https://github.com/tub-rip/event_based_bos)
179 | * [EPBA: Event-based Photometric Bundle Adjustment](https://github.com/tub-rip/epba)
180 | * [ES-PTAM: Event-based Stereo Parallel Tracking and Mapping](https://github.com/tub-rip/ES-PTAM)
181 | * [Research page (TU Berlin, RIP lab)](https://sites.google.com/view/guillermogallego/research/event-based-vision)
182 | * [Research page (Keio University, Aoki Media Lab)](https://aoki-medialab.jp/home-en/)
183 | * [Course at TU Berlin](https://sites.google.com/view/guillermogallego/teaching/event-based-robot-vision)
184 | * [Survey paper](http://rpg.ifi.uzh.ch/docs/EventVisionSurvey.pdf)
185 | * [List of Resources](https://github.com/uzh-rpg/event-based_vision_resources)
--------------------------------------------------------------------------------
/configs/mvsec_indoor_burgers.yaml:
--------------------------------------------------------------------------------
 1 | is_dnn: false
 2 | 
 3 | data:
 4 |   eval_dt: 4  # for MVSEC evaluation
 5 |   root: "~/local/event_based_optical_flow/datasets/MVSEC/hdf5"
 6 |   dataset: "MVSEC"
 7 |   sequence: "indoor_flying1"
 8 |   height: 260
 9 |   width: 346
10 |   load_gt_flow: True
11 |   gt: "~/local/event_based_optical_flow/datasets/MVSEC/gt_flow"
12 |   n_events_per_batch: 30000
13 |   ind1: 1130000
14 |   ind2: 1160000
15 | 
16 | output:
17 |   output_dir: "./outputs/paper/burgers/indoor_flying1_dt4"
18 |   show_interactive_result: false
19 | 
20 | solver:
21 |   method: "pyramidal_patch_contrast_maximization"
22 |   time_aware: True
23 |   time_bin: 10
24 |   flow_interpolation: "burgers"
25 |   t0_flow_location: "middle"
26 |   patch:
27 |     initialize: "random"
28 |     scale: 5
29 |     crop_height: 256
30 |     crop_width: 336
31 |     filter_type: "bilinear"
32 |   motion_model: "2d-translation"
33 |   warp_direction: "first"
34 |   parameters:
35 |     - "trans_x"
36 |     - "trans_y"
37 |   cost: "hybrid"
38 |   outer_padding: 0
39 |   cost_with_weight:
40 |     multi_focal_normalized_gradient_magnitude: 1.
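    # a weight may also be the string "inv", which adds the reciprocal of that cost (see src/costs/hybrid.py)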
41 |     total_variation: 0.01
42 |   iwe:
43 |     method: "bilinear_vote"
44 |     blur_sigma: 1
45 | 
46 | optimizer:
47 |   n_iter: 40
48 |   method: "Newton-CG"
49 |   max_iter: 25
50 |   parameters:
51 |     trans_x:
52 |       min: -150
53 |       max: 150
54 |     trans_y:
55 |       min: -150
56 |       max: 150
--------------------------------------------------------------------------------
/configs/mvsec_indoor_no_timeaware.yaml:
--------------------------------------------------------------------------------
 1 | is_dnn: false
 2 | 
 3 | data:
 4 |   eval_dt: 4  # for MVSEC evaluation
 5 |   root: "~/local/event_based_optical_flow/datasets/MVSEC/hdf5"
 6 |   dataset: "MVSEC"
 7 |   sequence: "indoor_flying1"
 8 |   height: 260
 9 |   width: 346
10 |   load_gt_flow: True
11 |   gt: "~/local/event_based_optical_flow/datasets/MVSEC/gt_flow"
12 |   n_events_per_batch: 30000
13 |   ind1: 1130000
14 |   ind2: 1160000
15 | 
16 | output:
17 |   output_dir: "./outputs/paper/no_timeaware/indoor_flying1_dt4"
18 |   show_interactive_result: false
19 | 
20 | solver:
21 |   method: "pyramidal_patch_contrast_maximization"
22 |   time_aware: False
23 |   patch:
24 |     initialize: "random"
25 |     scale: 5
26 |     crop_height: 256
27 |     crop_width: 336
28 |     filter_type: "bilinear"
29 |   motion_model: "2d-translation"
30 |   warp_direction: "first"
31 |   parameters:
32 |     - "trans_x"
33 |     - "trans_y"
34 |   cost: "hybrid"
35 |   outer_padding: 0
36 |   cost_with_weight:
37 |     multi_focal_normalized_gradient_magnitude: 1.
38 |     total_variation: 0.01
39 |   iwe:
40 |     method: "bilinear_vote"
41 |     blur_sigma: 1
42 | 
43 | optimizer:
44 |   n_iter: 40
45 |   method: "Newton-CG"
46 |   max_iter: 25
47 |   parameters:
48 |     trans_x:
49 |       min: -150
50 |       max: 150
51 |     trans_y:
52 |       min: -150
53 |       max: 150
--------------------------------------------------------------------------------
/conftest.py:
--------------------------------------------------------------------------------
 1 | import pytest
--------------------------------------------------------------------------------
/datasets/.gitkeep:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/tub-rip/event_based_optical_flow/8c12d21a91c22922bb150acb9b1bada2f4b23def/datasets/.gitkeep
--------------------------------------------------------------------------------
/datasets/README.md:
--------------------------------------------------------------------------------
 1 | # Folder structure
 2 | 
 3 | The folder structure should be as follows:
 4 | 
 5 | ```shell
 6 | tree -L 3
 7 | .
 8 | ├── MVSEC
 9 | │   ├── hdf5
10 | │   │   ├── indoor_flying1_data.hdf5
11 | │   │   ...
12 | │   └── gt_flow
13 | │       ├── indoor_flying1_gt_flow_dist.npz
14 | │       ...
15 | └── README.md # this readme
16 | ```
17 | 
18 | Please download datasets accordingly.
19 | 
20 | - MVSEC data from https://drive.google.com/drive/folders/1gDy2PwVOu_FPOsEZjojdWEB2ZHmpio8D
21 | 
22 | # Your own dataset location
23 | 
24 | Optionally, you can keep the files in another location
25 | and choose your own dataset root directory.
26 | 
27 | You need to specify the root folder with the config yaml file.
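
For reference, this is roughly how the root is resolved at load time (a sketch of `DataLoaderBase.__init__` in `src/data_loader/base.py`; an empty `root` falls back to the repository's `./datasets` directory, and `~` is expanded):

```python
import os

import yaml

# Sketch of how the dataset root from the config is resolved (see src/data_loader/base.py).
with open("./configs/mvsec_indoor_no_timeaware.yaml") as f:
    data_config = yaml.safe_load(f)["data"]

root = data_config["root"] or "./datasets"  # fallback to the repo's datasets directory
root = os.path.expanduser(root)             # allows "~/..." style roots
dataset_dir = os.path.join(root, data_config["dataset"])  # the loader reads <root>/<dataset>
print(dataset_dir)
```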
28 | 
29 | 
--------------------------------------------------------------------------------
/docs/img/2024_TPAMI_SecretsOfEVFlow_poster.pdf:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/tub-rip/event_based_optical_flow/8c12d21a91c22922bb150acb9b1bada2f4b23def/docs/img/2024_TPAMI_SecretsOfEVFlow_poster.pdf
--------------------------------------------------------------------------------
/docs/img/secretsevflow_eccv22.jpg:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/tub-rip/event_based_optical_flow/8c12d21a91c22922bb150acb9b1bada2f4b23def/docs/img/secretsevflow_eccv22.jpg
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
  1 | # type: ignore
  2 | import argparse
  3 | import logging
  4 | import os
  5 | import shutil
  6 | import sys
  7 | 
  8 | import numpy as np
  9 | import yaml
 10 | from tqdm import tqdm
 11 | 
 12 | from src import data_loader, solver, utils, visualizer
 13 | 
 14 | 
 15 | def parse_args():
 16 |     parser = argparse.ArgumentParser()
 17 |     parser.add_argument(
 18 |         "--config_file",
 19 |         default="./configs/mvsec_indoor_no_timeaware.yaml",
 20 |         help="Config file yaml path",
 21 |         type=str,
 22 |     )
 23 |     parser.add_argument(
 24 |         "--eval",
 25 |         help="Add for evaluation run",
 26 |         action="store_true",
 27 |     )
 28 |     parser.add_argument(
 29 |         "--log", help="Log level: [debug, info, warning, error, critical]", type=str, default="info"
 30 |     )
 31 |     args = parser.parse_args()
 32 |     with open(args.config_file, "r") as f:
 33 |         config = yaml.safe_load(f)
 34 |     return config, args
 35 | 
 36 | 
 37 | def save_config(save_dir: str, file_name: str, log_level=logging.INFO):
 38 |     """Save configuration"""
 39 |     if not os.path.exists(save_dir):
 40 |         os.makedirs(save_dir)
 41 |     shutil.copy(file_name, save_dir)
 42 |     logging.basicConfig(
 43 |         handlers=[
 44 |             logging.FileHandler(f"{save_dir}/main.log", mode="w"),
 45 |             logging.StreamHandler(sys.stdout),
 46 |         ],
 47 |         level=log_level,
 48 |         format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
 49 |     )
 50 | 
 51 | 
 52 | def evaluate_mvsec_dataset_with_gt(eval_frame_time_stamp_list, data_config, loader, solv):
 53 |     logger.info("Evaluation pipeline")
 54 |     eval_dt = data_config["eval_dt"]
 55 |     assert eval_dt == 1 or eval_dt == 4
 56 |     logger.info(f"dt (for MVSEC) is {eval_dt}")
 57 |     n_events = data_config["n_events_per_batch"]
 58 | 
 59 |     for i1 in tqdm(range(len(eval_frame_time_stamp_list) - eval_dt)):
 60 |         logger.info(f"Frame {i1} of {len(eval_frame_time_stamp_list)}")
 61 |         try:
 62 |             if i1 < data_config["ind1"] or i1 > data_config["ind2"]:
 63 |                 continue  # cutoff
 64 |         except KeyError:
 65 |             pass
 66 |         t1 = eval_frame_time_stamp_list[i1]
 67 |         t2 = eval_frame_time_stamp_list[i1 + eval_dt]
 68 |         ind1 = loader.time_to_index(t1)  # event index
 69 |         ind2 = loader.time_to_index(t2)
 70 | 
 71 |         # Flow error metrics calculation is based on GT flow + events between the consecutive GT flow frames
 72 |         batch_for_gt_slice = loader.load_event(ind1, ind2)
 73 |         gt_flow = loader.load_optical_flow(t1, t2)
 74 |         flow_time = t2 - t1
 75 |         batch_for_gt_slice[..., 2] -= np.min(batch_for_gt_slice[..., 2])
 76 | 
 77 |         # Optimization is based on fixed number of events
 78 |         if ind2 - ind1 < n_events:
 79 |             logger.info(
 80 |                 f"Less events in one GT flow sequence. Events: {ind2-ind1} / Expected: {n_events}"
 81 |             )
 82 |             insufficient = n_events - (ind2 - ind1)
 83 |             ind1 -= insufficient // 2
 84 |             ind2 += insufficient // 2
 85 |         elif ind2 - ind1 > n_events:
 86 |             logger.info(
 87 |                 f"Too many events in one GT flow sequence. Events: {ind2-ind1} / Expected: {n_events}"
 88 |             )
 89 |             ind1 = ind2 - n_events
 90 | 
 91 |         batch_for_optimization = loader.load_event(max(ind1, 0), min(ind2, len(loader)))
 92 |         batch_for_optimization[..., 2] -= np.min(batch_for_optimization[..., 2])
 93 | 
 94 |         if utils.check_key_and_bool(data_config, "remove_car"):
 95 |             logger.info("Remove car-body pixels")
 96 |             batch_for_optimization = utils.crop_event(batch_for_optimization, 0, 193, 0, 346)
 97 | 
 98 |         best_motion = solv.optimize(batch_for_optimization)
 99 |         solv.set_previous_frame_best_estimation(best_motion)
100 |         # mask with event
101 |         flow_error_with_mask = solv.calculate_flow_error(best_motion, gt_flow, timescale=flow_time, events=batch_for_gt_slice)  # type: ignore
102 |         solv.save_flow_error_as_text(i1, flow_error_with_mask, "flow_error_per_frame_with_mask.txt")  # type: ignore
103 | 
104 |         # Visualization
105 |         solv.visualize_original_sequential(batch_for_gt_slice)
106 |         solv.visualize_pred_sequential(batch_for_gt_slice, best_motion)
107 |         solv.visualize_gt_sequential(batch_for_gt_slice, gt_flow)
108 | 
109 | 
110 | if __name__ == "__main__":
111 |     config, args = parse_args()
112 |     data_config: dict = config["data"]
113 |     out_config: dict = config["output"]
114 |     log_level = getattr(logging, args.log.upper(), None)
115 |     if not isinstance(log_level, int):
116 |         raise ValueError("Invalid log level: %s" % log_level)
117 |     save_config(out_config["output_dir"], args.config_file, log_level)
118 |     logger = logging.getLogger(__name__)
119 | 
120 |     if utils.check_key_and_bool(config, "fix_random_seed"):
121 |         utils.fix_random_seed()
122 | 
123 |     # Visualizer
124 |     image_shape = (data_config["height"], data_config["width"])
125 |     if config["is_dnn"] and "crop" in data_config["preprocess"].keys():
126 |         image_shape = (data_config["preprocess"]["crop"]["height"], data_config["preprocess"]["crop"]["width"])  # type: ignore
127 | 
128 |     viz = visualizer.Visualizer(
129 |         image_shape,
130 |         show=out_config["show_interactive_result"],
131 |         save=True,
132 |         save_dir=out_config["output_dir"],
133 |     )
134 | 
135 |     # Loader
136 |     loader = data_loader.collections[data_config["dataset"]](config=data_config)
137 |     loader.set_sequence(data_config["sequence"])
138 | 
139 |     # Solver
140 |     method_name = config["solver"]["method"]
141 |     solv: solver.SolverBase = solver.collections[method_name](
142 |         image_shape,
143 |         calibration_parameter=loader.load_calib(),
144 |         solver_config=config["solver"],
145 |         optimizer_config=config["optimizer"],
146 |         output_config=config["output"],
147 |         visualize_module=viz,
148 |     )
149 | 
150 |     if args.eval:  # Run evaluation pipeline.
151 |         if config["is_dnn"]:
152 |             e = "DNN code is not published."
153 |             logger.error(e)
154 |             raise NotImplementedError(e)
155 |         else:
156 |             logger.info("Sequential optimization")
157 |             assert loader.gt_flow_available  # evaluate with GT flow
158 |             logger.info("evaluation with GT")
159 |             eval_frame_time_stamp_list = loader.eval_frame_time_list()
160 |             evaluate_mvsec_dataset_with_gt(eval_frame_time_stamp_list, data_config, loader, solv)
161 |             logger.info(f"Evaluation done! {data_config['sequence']}")
162 |             exit()
163 | 
164 |     # Not evaluation - single frame optimization
165 |     if config["is_dnn"]:
166 |         e = "DNN code is not published."
167 |         logger.error(e)
168 |         raise NotImplementedError(e)
169 |     else:  # For non-DNN method
170 |         logger.info("Single-frame optimization")
171 |         ind1, ind2 = data_config["ind1"], data_config["ind2"]
172 |         batch: np.ndarray = loader.load_event(ind1, ind2)
173 |         batch[..., 2] -= np.min(batch[..., 2])
174 | 
175 |         if utils.check_key_and_bool(data_config, "remove_car"):
176 |             batch = utils.crop_event(batch, 0, 193, 0, 346)  # remove MVSEC car
177 | 
178 |         solv.visualize_one_batch_warp(batch)
179 |         best_motion: np.ndarray = solv.optimize(batch)
180 |         solv.visualize_one_batch_warp(batch, best_motion)
181 | 
182 |         # Calculate Flow error when GT is available
183 |         if loader.gt_flow_available:
184 |             t1 = loader.index_to_time(ind1)
185 |             t2 = loader.index_to_time(ind2)
186 |             gt_flow = loader.load_optical_flow(t1, t2)
187 | 
188 |             solv.visualize_one_batch_warp_gt(batch, gt_flow)
189 |             solv.calculate_flow_error(best_motion, gt_flow, t2 - t1, batch)
--------------------------------------------------------------------------------
/outputs/.gitkeep:
--------------------------------------------------------------------------------
 1 | !gallego_cvpr2018
 2 | !gallego_cvpr2019
 3 | !mitrokhin_iros2018
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
 1 | [tool.poetry]
 2 | name = "event-based-optical-flow"
 3 | version = "0.1.0"
 4 | description = ""
 5 | authors = ["Shintaro Shiba "]
 6 | 
 7 | [tool.poetry.dependencies]
 8 | python = ">=3.8,<3.11"
 9 | numpy = "^1.21.2"
10 | torch = ">=1.9.1"
11 | torchvision = ">=0.10.1"
12 | opencv-python = "^4.5.3"
13 | matplotlib = "^3.4.3"
14 | argparse = "^1.4.0"
15 | PyYAML = ">=5.4.1"
16 | Pillow = "^8.3.2"
17 | optuna = "^2.10.0"
18 | sklearn = "^0.0"
19 | wandb = "^0.12.4"
20 | hdf5plugin = "^3.2.0"
21 | h5py = "^3.5.0"
22 | plotly = "^5.4.0"
23 | scikit-image = "^0.19.1"
24 | ffmpeg-python = "^0.2.0"
25 | 
26 | [tool.poetry.dev-dependencies]
27 | pytest = "^6.2.5"
28 | black = "^21.9b0"
29 | mypy = "^0.910"
30 | 
31 | [build-system]
32 | requires = ["poetry-core>=1.0.0"]
33 | build-backend = "poetry.core.masonry.api"
34 | 
35 | [tool.black]
36 | line-length = 100
37 | target-version = ['py39']
38 | include = '\.pyi?$'
39 | 
40 | [tool.mypy]
41 | ignore_missing_imports = true
42 | 
43 | [tool.isort]
44 | profile = "black"
45 | line_length = 119
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/tub-rip/event_based_optical_flow/8c12d21a91c22922bb150acb9b1bada2f4b23def/src/__init__.py
--------------------------------------------------------------------------------
/src/costs/__init__.py:
--------------------------------------------------------------------------------
 1 | """isort:skip_file
 2 | """
 3 | # Basics
 4 | from .base import CostBase
 5 | from .gradient_magnitude import GradientMagnitude
 6 | from .image_variance import ImageVariance
 7 | 
 8 | # from .zhu_average_timestamp import ZhuAverageTimestamp
 9 | # from .paredes_average_timestamp import ParedesAverageTimestamp
10 | 
11 | # Flow related
12 | from .total_variation import TotalVariation
13 | 
14 | # Normalized ~
15 | from .normalized_image_variance import NormalizedImageVariance
16 | from .normalized_gradient_magnitude import NormalizedGradientMagnitude
17 | 
18 | # Multi-reference ~
19 | from .multi_focal_normalized_image_variance import MultiFocalNormalizedImageVariance
20 | from .multi_focal_normalized_gradient_magnitude import MultiFocalNormalizedGradientMagnitude
21 | 
22 | 
23 | def inheritors(klass):
24 |     subclasses = set()
25 |     work = [klass]
26 |     while work:
27 |         parent = work.pop()
28 |         for child in parent.__subclasses__():
29 |             if child not in subclasses:
30 |                 subclasses.add(child)
31 |                 work.append(child)
32 |     return subclasses
33 | 
34 | 
35 | functions = {k.name: k for k in inheritors(CostBase)}
36 | 
37 | # For hybrid loss
38 | from .hybrid import HybridCost
--------------------------------------------------------------------------------
/src/costs/base.py:
--------------------------------------------------------------------------------
 1 | import logging
 2 | from typing import Dict, List
 3 | 
 4 | import torch
 5 | 
 6 | from ..types import FLOAT_TORCH
 7 | 
 8 | logger = logging.getLogger(__name__)
 9 | 
10 | 
11 | class CostBase(object):
12 |     """Base of the Cost class.
13 |     Args:
14 |         direction (str) ... 'minimize' or 'maximize' or `natural`.
15 |             Defines the objective function. If natural, it returns more interpretable value.
16 |     """
17 | 
18 |     required_keys: List[str] = []
19 | 
20 |     def __init__(self, direction="minimize", store_history: bool = False, *args, **kwargs):
21 |         if direction not in ["minimize", "maximize", "natural"]:
22 |             e = f"direction should be minimize, maximize, or natural. Got {direction}."
23 |             logger.error(e)
24 |             raise ValueError(e)
25 |         self.direction = direction
26 |         self.store_history = store_history
27 |         self.clear_history()
28 | 
29 |     def catch_key_error(func):
30 |         """Wrapper utility function to catch the key error."""
31 | 
32 |         def wrapper(self, arg: dict):
33 |             try:
34 |                 return func(self, arg)  # type: ignore
35 |             except KeyError as e:
36 |                 logger.error("Input for the cost needs keys of:")
37 |                 logger.error(self.required_keys)
38 |                 raise e
39 | 
40 |         return wrapper
41 | 
42 |     def register_history(func):
43 |         """Register history of the loss."""
44 | 
45 |         def wrapper(self, arg: dict):
46 |             loss = func(self, arg)  # type: ignore
47 |             if self.store_history:
48 |                 self.history["loss"].append(self.get_item(loss))
49 |             return loss
50 | 
51 |         return wrapper
52 | 
53 |     def get_item(self, loss: FLOAT_TORCH) -> float:
54 |         if isinstance(loss, torch.Tensor):
55 |             return loss.item()
56 |         return loss
57 | 
58 |     def clear_history(self) -> None:
59 |         self.history: Dict[str, list] = {"loss": []}
60 | 
61 |     def get_history(self) -> dict:
62 |         return self.history.copy()
63 | 
64 |     def enable_history_register(self) -> None:
65 |         self.store_history = True
66 | 
67 |     def disable_history_register(self) -> None:
68 |         self.store_history = False
69 | 
70 |     # Every subclass needs to implement calculate()
71 |     @register_history  # type: ignore
72 |     @catch_key_error  # type: ignore
73 |     def calculate(self, arg: dict) -> FLOAT_TORCH:
74 |         raise NotImplementedError
75 | 
76 |     catch_key_error = staticmethod(catch_key_error)  # type: ignore
77 |     register_history = staticmethod(register_history)  # type: ignore
--------------------------------------------------------------------------------
/src/costs/gradient_magnitude.py:
--------------------------------------------------------------------------------
 1 | import logging
 2 | 
 3 | import cv2
 4 | import numpy as np
 5 | import torch
 6 | 
 7 | from ..types import FLOAT_TORCH
 8 | from ..utils import SobelTorch
 9 | from . import CostBase
10 | 
11 | logger = logging.getLogger(__name__)
12 | 
13 | 
14 | class GradientMagnitude(CostBase):
15 |     """Gradient Magnitude loss from Gallego et al. CVPR 2019.
16 | 
17 |     Args:
18 |         direction (str) ... 'minimize' or 'maximize' or `natural`.
19 |             Defines the objective function. If natural, it returns more interpretable value.
20 |     """
21 | 
22 |     name = "gradient_magnitude"
23 |     required_keys = ["iwe", "omit_boundary"]
24 | 
25 |     def __init__(
26 |         self,
27 |         direction="minimize",
28 |         store_history: bool = False,
29 |         cuda_available=False,
30 |         precision="32",
31 |         *args,
32 |         **kwargs,
33 |     ):
34 |         super().__init__(direction=direction, store_history=store_history)
35 |         self.precision = precision
36 |         self.torch_sobel = SobelTorch(
37 |             ksize=3, in_channels=1, cuda_available=cuda_available, precision=precision
38 |         )
39 | 
40 |     @CostBase.register_history  # type: ignore
41 |     @CostBase.catch_key_error  # type: ignore
42 |     def calculate(self, arg: dict) -> FLOAT_TORCH:
43 |         """Calculate gradient of IWE.
44 |         Inputs:
45 |             iwe (np.ndarray or torch.Tensor) ... [W, H]. Image of warped events
46 |             omit_boundary (bool) ... Omit boundary if True.
47 | 
48 |         Returns:
49 |             (Union[float, torch.Tensor]) ... Magnitude of gradient
50 |         """
51 |         iwe = arg["iwe"]
52 |         if isinstance(iwe, torch.Tensor):
53 |             return self.calculate_torch(iwe, arg["omit_boundary"])
54 |         elif isinstance(iwe, np.ndarray):
55 |             return self.calculate_numpy(iwe, arg["omit_boundary"])
56 |         e = f"Unsupported input type. {type(iwe)}."
57 |         logger.error(e)
58 |         raise NotImplementedError(e)
59 | 
60 |     def calculate_torch(self, iwe: torch.Tensor, omit_boundary: bool) -> torch.Tensor:
61 |         if len(iwe.shape) == 2:
62 |             iwe = iwe[None, None, ...]
63 |         elif len(iwe.shape) == 3:
64 |             iwe = iwe[:, None, ...]
65 |         if self.precision == "64":
66 |             iwe = iwe.double()
67 |         iwe_sobel = self.torch_sobel.forward(iwe) / 8.0
68 |         gx = iwe_sobel[:, 0]
69 |         gy = iwe_sobel[:, 1]
70 |         if omit_boundary:
71 |             gx = gx[..., 1:-1, 1:-1]
72 |             gy = gy[..., 1:-1, 1:-1]
73 |         magnitude = torch.mean(torch.square(gx) + torch.square(gy))
74 |         if self.direction == "minimize":
75 |             return -magnitude
76 |         return magnitude
77 | 
78 |     def calculate_numpy(self, iwe: np.ndarray, omit_boundary: bool) -> float:
79 |         """Calculate gradient magnitude of the count image.
80 |         Inputs:
81 |             iwe (np.ndarray) ... [W, H]. Image of warped events
82 |             omit_boundary (bool) ... Omit boundary if True.
83 | 
84 |         Returns:
85 |             (float) ... magnitude of gradient.
86 |         """
87 |         gx = cv2.Sobel(iwe, cv2.CV_64F, 1, 0, ksize=3) / 8.0
88 |         gy = cv2.Sobel(iwe, cv2.CV_64F, 0, 1, ksize=3) / 8.0
89 |         if omit_boundary:
90 |             gx = gx[..., 1:-1, 1:-1]
91 |             gy = gy[..., 1:-1, 1:-1]
92 |         magnitude = np.mean(np.square(gx) + np.square(gy))
93 |         if self.direction == "minimize":
94 |             return -magnitude
95 |         return magnitude
--------------------------------------------------------------------------------
/src/costs/hybrid.py:
--------------------------------------------------------------------------------
 1 | import logging
 2 | from typing import Union
 3 | 
 4 | import numpy as np
 5 | import torch
 6 | 
 7 | from . import CostBase, functions
 8 | 
 9 | logger = logging.getLogger(__name__)
10 | 
11 | 
12 | class HybridCost(CostBase):
13 |     """Hybrid cost function with arbitrary weight.
14 | 
15 |     Args:
16 |         direction (str) ... 'minimize' or 'maximize'.
17 |         cost_with_weight (dict) ... key is the name of the cost, value is its weight.
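            Example (illustrative values, mirroring the shipped configs; a weight may
            also be the string "inv", which adds the reciprocal of that cost -- see calculate()):
                cost_with_weight = {"multi_focal_normalized_gradient_magnitude": 1.0, "total_variation": 0.01}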
18 |     """
19 | 
20 |     name = "hybrid"
21 | 
22 |     def __init__(
23 |         self, direction: str, cost_with_weight: dict, store_history: bool = False, *args, **kwargs
24 |     ):
25 |         logger.info(f"Loss functions are a mix of {cost_with_weight}")
26 |         self.cost_func = {
27 |             key: {
28 |                 "func": functions[key](
29 |                     direction=direction, store_history=store_history, *args, **kwargs
30 |                 ),
31 |                 "weight": value,
32 |             }
33 |             for key, value in cost_with_weight.items()
34 |         }
35 |         super().__init__(direction=direction, store_history=store_history)
36 | 
37 |         self.required_keys = []
38 |         for name in self.cost_func.keys():
39 |             self.required_keys.extend(self.cost_func[name]["func"].required_keys)
40 | 
41 |     def update_weight(self, cost_with_weight):
42 |         assert set(self.cost_func.keys()) == set(cost_with_weight.keys())
43 |         for key in cost_with_weight.keys():
44 |             self.cost_func[key]["weight"] = cost_with_weight[key]
45 | 
46 |     @CostBase.register_history  # type: ignore
47 |     @CostBase.catch_key_error  # type: ignore
48 |     def calculate(self, arg: dict) -> Union[float, torch.Tensor]:
49 |         loss = 0.0
50 |         for name in self.cost_func.keys():
51 |             if self.cost_func[name]["weight"] == "inv":
52 |                 _l = 1.0 / self.cost_func[name]["func"].calculate(arg)
53 |                 loss += _l
54 |             else:
55 |                 _l = self.cost_func[name]["weight"] * self.cost_func[name]["func"].calculate(arg)
56 |                 loss += _l
57 |         return loss
58 | 
59 |     # For hybrid cost function, need to store with its name
60 |     def clear_history(self) -> None:
61 |         self.history = {"loss": []}
62 |         for name in self.cost_func.keys():
63 |             self.cost_func[name]["func"].clear_history()
64 | 
65 |     def get_history(self) -> dict:
66 |         dic = self.history.copy()
67 |         for name in self.cost_func.keys():
68 |             dic.update({name: self.cost_func[name]["func"].get_history()["loss"]})
69 |         return dic
70 | 
71 |     def enable_history_register(self) -> None:
72 |         self.store_history = True
73 |         for name in self.cost_func.keys():
74 |             self.cost_func[name]["func"].store_history = True
75 | 
76 |     def disable_history_register(self) -> None:
77 |         self.store_history = False
78 |         for name in self.cost_func.keys():
79 |             self.cost_func[name]["func"].store_history = False
--------------------------------------------------------------------------------
/src/costs/image_variance.py:
--------------------------------------------------------------------------------
 1 | import logging
 2 | from typing import Union
 3 | 
 4 | import numpy as np
 5 | import torch
 6 | 
 7 | from . import CostBase
 8 | 
 9 | logger = logging.getLogger(__name__)
10 | 
11 | 
12 | class ImageVariance(CostBase):
13 |     """Image Variance from Gallego et al. CVPR 2018.
14 |     Args:
15 |         direction (str) ... 'minimize' or 'maximize' or `natural`.
16 |             Defines the objective function. If natural, it returns more interpretable value.
17 |     """
18 | 
19 |     name = "image_variance"
20 |     required_keys = ["iwe", "omit_boundary"]
21 | 
22 |     def __init__(self, direction="minimize", store_history: bool = False, *args, **kwargs):
23 |         super().__init__(direction=direction, store_history=store_history)
24 | 
25 |     @CostBase.register_history  # type: ignore
26 |     @CostBase.catch_key_error  # type: ignore
27 |     def calculate(self, arg: dict) -> Union[float, torch.Tensor]:
28 |         """Calculate contrast of the IWE.
29 |         Inputs:
30 |             iwe (np.ndarray or torch.Tensor) ... [W, H]. Image of warped events
31 |             omit_boundary (bool) ... Omit boundary if True.
32 | 
33 |         Returns:
34 |             contrast (Union[float, torch.Tensor]) ... contrast of the image.
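            Example (sketch; `iwe` is a hypothetical image of warped events):
                loss = ImageVariance(direction="minimize").calculate({"iwe": iwe, "omit_boundary": True})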
35 |         """
36 |         iwe = arg["iwe"]
37 |         if arg["omit_boundary"]:
38 |             iwe = iwe[..., 1:-1, 1:-1]  # omit boundary
39 |         if isinstance(iwe, torch.Tensor):
40 |             return self.calculate_torch(iwe)
41 |         elif isinstance(iwe, np.ndarray):
42 |             return self.calculate_numpy(iwe)
43 |         e = f"Unsupported input type. {type(iwe)}."
44 |         logger.error(e)
45 |         raise NotImplementedError(e)
46 | 
47 |     def calculate_torch(self, iwe: torch.Tensor) -> torch.Tensor:
48 |         """Calculate contrast of the IWE.
49 |         Inputs:
50 |             iwe (torch.Tensor) ... [W, H]. Image of warped events
51 | 
52 |         Returns:
53 |             loss (torch.Tensor) ... contrast of the image.
54 |         """
55 |         loss = torch.var(iwe)
56 |         if self.direction == "minimize":
57 |             return -loss
58 |         return loss
59 | 
60 |     def calculate_numpy(self, iwe: np.ndarray) -> float:
61 |         """Calculate contrast of the IWE.
62 |         Inputs:
63 |             iwe (np.ndarray) ... [W, H]. Image of warped events
64 | 
65 |         Returns:
66 |             contrast (float) ... contrast of the image.
67 |         """
68 |         loss = np.var(iwe)
69 |         if self.direction == "minimize":
70 |             return -loss
71 |         return loss
--------------------------------------------------------------------------------
/src/costs/multi_focal_normalized_gradient_magnitude.py:
--------------------------------------------------------------------------------
 1 | import logging
 2 | from typing import Optional
 3 | 
 4 | import numpy as np
 5 | import torch
 6 | 
 7 | from ..types import FLOAT_TORCH
 8 | from . import CostBase, NormalizedGradientMagnitude
 9 | 
10 | logger = logging.getLogger(__name__)
11 | 
12 | 
13 | class MultiFocalNormalizedGradientMagnitude(CostBase):
14 |     """Multi-focus normalized gradient magnitude, Shiba et al. ECCV 2022.
15 | 
16 |     Args:
17 |         direction (str) ... 'minimize' or 'maximize' or `natural`.
18 |             Defines the objective function. If natural, it returns more interpretable value.
19 |     """
20 | 
21 |     name = "multi_focal_normalized_gradient_magnitude"
22 |     required_keys = ["forward_iwe", "backward_iwe", "middle_iwe", "omit_boundary", "orig_iwe"]
23 | 
24 |     def __init__(
25 |         self,
26 |         direction="minimize",
27 |         store_history: bool = False,
28 |         cuda_available=False,
29 |         precision="32",
30 |         *args,
31 |         **kwargs,
32 |     ):
33 |         super().__init__(direction=direction, store_history=store_history)
34 |         self.gradient_loss = NormalizedGradientMagnitude(
35 |             direction=direction, cuda_available=cuda_available, precision=precision
36 |         )
37 | 
38 |     @CostBase.register_history  # type: ignore
39 |     @CostBase.catch_key_error  # type: ignore
40 |     def calculate(self, arg: dict) -> FLOAT_TORCH:
41 |         """Calculate cost.
42 |         Inputs:
43 |             orig_iwe (np.ndarray or torch.Tensor) ... Original IWE (before any warp).
44 |             forward_iwe (np.ndarray or torch.Tensor) ... IWE to forward warp.
45 |             backward_iwe (np.ndarray or torch.Tensor) ... IWE to backward warp.
46 |             middle_iwe (Optional[np.ndarray or torch.Tensor]) ... IWE to middle warp.
47 |             omit_boundary (bool) ... Omit boundary if True.
48 | 
49 |         Returns:
50 |             loss (Union[float, torch.Tensor]) ... The multi-focal normalized gradient magnitude loss.
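            Example (sketch; the IWE variables are hypothetical arrays produced by the solver's warps):
                loss = cost.calculate({"orig_iwe": orig, "forward_iwe": fwd, "backward_iwe": bwd,
                                       "middle_iwe": mid, "omit_boundary": True})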
51 |         """
52 |         orig_iwe = arg["orig_iwe"]
53 |         forward_iwe = arg["forward_iwe"]
54 |         if "middle_iwe" in arg.keys():
55 |             middle_iwe = arg["middle_iwe"]
56 |         else:
57 |             middle_iwe = None
58 |         backward_iwe = arg["backward_iwe"]
59 |         omit_boundary = arg["omit_boundary"]
60 | 
61 |         if isinstance(forward_iwe, torch.Tensor):
62 |             return self.calculate_torch(
63 |                 orig_iwe, forward_iwe, backward_iwe, middle_iwe, omit_boundary
64 |             )
65 |         elif isinstance(forward_iwe, np.ndarray):
66 |             return self.calculate_numpy(
67 |                 orig_iwe, forward_iwe, backward_iwe, middle_iwe, omit_boundary
68 |             )
69 |         e = f"Unsupported input type. {type(forward_iwe)}."
70 |         logger.error(e)
71 |         raise NotImplementedError(e)
72 | 
73 |     def calculate_torch(
74 |         self,
75 |         orig_iwe: torch.Tensor,
76 |         forward_iwe: torch.Tensor,
77 |         backward_iwe: torch.Tensor,
78 |         middle_iwe: Optional[torch.Tensor],
79 |         omit_boundary: bool,
80 |     ) -> torch.Tensor:
81 |         """Calculate cost for torch tensor.
82 |         Inputs:
83 |             orig_iwe (torch.Tensor) ... Original IWE (before warp).
84 |             forward_iwe (torch.Tensor) ... IWE to forward warp.
85 |             backward_iwe (torch.Tensor) ... IWE to backward warp.
86 |             middle_iwe (Optional[torch.Tensor]) ... IWE to middle warp.
87 | 
88 |         Returns:
89 |             loss (torch.Tensor) ... the multi-focal normalized gradient magnitude loss.
90 |         """
91 |         forward_loss = self.gradient_loss.calculate_torch(forward_iwe, orig_iwe, omit_boundary)
92 |         backward_loss = self.gradient_loss.calculate_torch(backward_iwe, orig_iwe, omit_boundary)
93 |         loss = forward_loss + backward_loss
94 | 
95 |         if middle_iwe is not None:
96 |             loss += self.gradient_loss.calculate_torch(middle_iwe, orig_iwe, omit_boundary) * 2
97 | 
98 |         if self.direction in ["minimize", "natural"]:
99 |             return loss
100 |         logger.warning("The loss is specified as maximize direction")
101 |         return -loss
102 | 
103 |     def calculate_numpy(
104 |         self,
105 |         orig_iwe: np.ndarray,
106 |         forward_iwe: np.ndarray,
107 |         backward_iwe: np.ndarray,
108 |         middle_iwe: Optional[np.ndarray],
109 |         omit_boundary: bool,
110 |     ) -> float:
111 |         """Calculate cost for numpy array.
112 |         Inputs:
113 |             orig_iwe (np.ndarray) ... Original IWE (before warp).
114 |             forward_iwe (np.ndarray) ... IWE to forward warp.
115 |             backward_iwe (np.ndarray) ... IWE to backward warp.
116 |             middle_iwe (Optional[np.ndarray]) ... IWE to middle warp.
117 | 
118 |         Returns:
119 |             loss (float) ... the multi-focal normalized gradient magnitude loss.
120 |         """
121 |         forward_loss = self.gradient_loss.calculate_numpy(forward_iwe, orig_iwe, omit_boundary)
122 |         backward_loss = self.gradient_loss.calculate_numpy(backward_iwe, orig_iwe, omit_boundary)
123 |         loss = forward_loss + backward_loss
124 | 
125 |         if middle_iwe is not None:
126 |             loss += self.gradient_loss.calculate_numpy(middle_iwe, orig_iwe, omit_boundary) * 2
127 | 
128 |         if self.direction in ["minimize", "natural"]:
129 |             return loss
130 |         logger.warning("The loss is specified as maximize direction")
131 |         return -loss
--------------------------------------------------------------------------------
/src/costs/multi_focal_normalized_image_variance.py:
--------------------------------------------------------------------------------
 1 | import logging
 2 | from typing import Optional
 3 | 
 4 | import numpy as np
 5 | import torch
 6 | 
 7 | from ..types import FLOAT_TORCH
 8 | from . import CostBase, NormalizedImageVariance
 9 | 
10 | logger = logging.getLogger(__name__)
11 | 
12 | 
13 | class MultiFocalNormalizedImageVariance(CostBase):
14 |     """Multi-focus normalized image variance, Shiba et al. ECCV 2022.
15 | 
16 |     Args:
17 |         direction (str) ... 'minimize' or 'maximize' or `natural`.
18 |             Defines the objective function. If natural, it returns more interpretable value.
19 |     """
20 | 
21 |     name = "multi_focal_normalized_image_variance"
22 |     required_keys = ["forward_iwe", "backward_iwe", "middle_iwe", "omit_boundary", "orig_iwe"]
23 | 
24 |     def __init__(self, direction="minimize", store_history: bool = False, *args, **kwargs):
25 |         super().__init__(direction=direction, store_history=store_history)
26 |         self.variance_loss = NormalizedImageVariance(direction=direction)
27 | 
28 |     @CostBase.register_history  # type: ignore
29 |     @CostBase.catch_key_error  # type: ignore
30 |     def calculate(self, arg: dict) -> FLOAT_TORCH:
31 |         """Calculate multi-focus normalized image variance.
32 |         Inputs:
33 |             orig_iwe (np.ndarray or torch.Tensor) ... Original IWE (before any warp).
34 |             forward_iwe (np.ndarray or torch.Tensor) ... IWE to forward warp.
35 |             backward_iwe (np.ndarray or torch.Tensor) ... IWE to backward warp.
36 |             middle_iwe (Optional[np.ndarray or torch.Tensor]) ... IWE to middle warp.
37 |             omit_boundary (bool) ... Omit boundary if True.
38 | 
39 |         Returns:
40 |             loss (Union[float, torch.Tensor]) ... The multi-focal normalized image variance loss.
41 |         """
42 |         orig_iwe = arg["orig_iwe"]
43 |         forward_iwe = arg["forward_iwe"]
44 |         if "middle_iwe" in arg.keys():
45 |             middle_iwe = arg["middle_iwe"]
46 |         else:
47 |             middle_iwe = None
48 |         backward_iwe = arg["backward_iwe"]
49 |         omit_boundary = arg["omit_boundary"]
50 |         if omit_boundary:
51 |             forward_iwe = forward_iwe[..., 1:-1, 1:-1]  # omit boundary
52 |             backward_iwe = backward_iwe[..., 1:-1, 1:-1]  # omit boundary
53 |             if middle_iwe is not None:
54 |                 middle_iwe = middle_iwe[..., 1:-1, 1:-1]  # omit boundary
55 | 
56 |         if isinstance(forward_iwe, torch.Tensor):
57 |             return self.calculate_torch(orig_iwe, forward_iwe, backward_iwe, middle_iwe)
58 |         elif isinstance(forward_iwe, np.ndarray):
59 |             return self.calculate_numpy(orig_iwe, forward_iwe, backward_iwe, middle_iwe)
60 |         e = f"Unsupported input type. {type(forward_iwe)}."
61 |         logger.error(e)
62 |         raise NotImplementedError(e)
63 | 
64 |     def calculate_torch(
65 |         self,
66 |         orig_iwe: torch.Tensor,
67 |         forward_iwe: torch.Tensor,
68 |         backward_iwe: torch.Tensor,
69 |         middle_iwe: Optional[torch.Tensor],
70 |     ) -> torch.Tensor:
71 |         """Calculate cost for torch tensor.
72 |         Inputs:
73 |             orig_iwe (torch.Tensor) ... Original IWE (before warp).
74 |             forward_iwe (torch.Tensor) ... IWE to forward warp.
75 |             backward_iwe (torch.Tensor) ... IWE to backward warp.
76 |             middle_iwe (Optional[torch.Tensor]) ... IWE to middle warp.
77 | 
78 |         Returns:
79 |             loss (torch.Tensor) ... the multi-focal normalized image variance loss.
80 |         """
81 |         forward_loss = self.variance_loss.calculate_torch(forward_iwe, orig_iwe)
82 |         backward_loss = self.variance_loss.calculate_torch(backward_iwe, orig_iwe)
83 |         loss = forward_loss + backward_loss
84 | 
85 |         if middle_iwe is not None:
86 |             loss += self.variance_loss.calculate_torch(middle_iwe, orig_iwe) * 2
87 | 
88 |         if self.direction in ["minimize", "natural"]:
89 |             return loss
90 |         logger.warning("The loss is specified as maximize direction")
91 |         return -loss
92 | 
93 |     def calculate_numpy(
94 |         self,
95 |         orig_iwe: np.ndarray,
96 |         forward_iwe: np.ndarray,
97 |         backward_iwe: np.ndarray,
98 |         middle_iwe: Optional[np.ndarray],
99 |     ) -> float:
100 |         """Calculate cost for numpy array.
101 |         Inputs:
102 |             orig_iwe (np.ndarray) ... Original IWE (before warp).
103 |             forward_iwe (np.ndarray) ... IWE to forward warp.
104 |             backward_iwe (np.ndarray) ... IWE to backward warp.
105 |             middle_iwe (Optional[np.ndarray]) ... IWE to middle warp.
106 | 
107 |         Returns:
108 |             loss (float) ... the multi-focal normalized image variance loss.
109 |         """
110 |         forward_loss = self.variance_loss.calculate_numpy(forward_iwe, orig_iwe)
111 |         backward_loss = self.variance_loss.calculate_numpy(backward_iwe, orig_iwe)
112 |         loss = forward_loss + backward_loss
113 | 
114 |         if middle_iwe is not None:
115 |             loss += self.variance_loss.calculate_numpy(middle_iwe, orig_iwe) * 2
116 | 
117 |         if self.direction in ["minimize", "natural"]:
118 |             return loss
119 |         logger.warning("The loss is specified as maximize direction")
120 |         return -loss
--------------------------------------------------------------------------------
/src/costs/normalized_gradient_magnitude.py:
--------------------------------------------------------------------------------
 1 | import logging
 2 | 
 3 | import numpy as np
 4 | import torch
 5 | 
 6 | from ..types import FLOAT_TORCH
 7 | from . import CostBase, GradientMagnitude
 8 | 
 9 | logger = logging.getLogger(__name__)
10 | 
11 | 
12 | class NormalizedGradientMagnitude(CostBase):
13 |     """Normalized gradient magnitude.
14 | 
15 |     Args:
16 |         direction (str) ... 'minimize' or 'maximize' or `natural`.
17 |             Defines the objective function. If natural, it returns more interpretable value.
18 |     """
19 | 
20 |     name = "normalized_gradient_magnitude"
21 |     required_keys = ["orig_iwe", "iwe", "omit_boundary"]
22 | 
23 |     def __init__(
24 |         self,
25 |         direction="minimize",
26 |         store_history: bool = False,
27 |         cuda_available=False,
28 |         precision="32",
29 |         *args,
30 |         **kwargs,
31 |     ):
32 |         super().__init__(direction=direction, store_history=store_history)
33 |         self.gradient_magnitude = GradientMagnitude(
34 |             direction=direction,
35 |             store_history=store_history,
36 |             cuda_available=cuda_available,
37 |             precision=precision,
38 |         )
39 | 
40 |     @CostBase.register_history  # type: ignore
41 |     @CostBase.catch_key_error  # type: ignore
42 |     def calculate(self, arg: dict) -> FLOAT_TORCH:
43 |         """Calculate normalized gradient magnitude of IWE.
44 |         Inputs:
45 |             iwe (np.ndarray or torch.Tensor) ... [W, H]. Image of warped events
46 |             orig_iwe (np.ndarray or torch.Tensor) ... [W, H]. Image of original events
47 |             omit_boundary (bool) ... Omit boundary if True.
48 | 
49 |         Returns:
50 |             contrast (FLOAT_TORCH) ... normalized gradient magnitude of the image.
51 |         """
52 |         iwe = arg["iwe"]
53 |         orig_iwe = arg["orig_iwe"]
54 |         omit_boundary = arg["omit_boundary"]
55 |         if isinstance(iwe, torch.Tensor):
56 |             return self.calculate_torch(iwe, orig_iwe, omit_boundary)
57 |         elif isinstance(iwe, np.ndarray):
58 |             return self.calculate_numpy(iwe, orig_iwe, omit_boundary)
59 |         e = f"Unsupported input type. {type(iwe)}."
60 |         logger.error(e)
61 |         raise NotImplementedError(e)
62 | 
63 |     def calculate_torch(
64 |         self, iwe: torch.Tensor, orig_iwe: torch.Tensor, omit_boundary: bool
65 |     ) -> torch.Tensor:
66 |         """
67 |         Inputs:
68 |             iwe (torch.Tensor) ... [W, H]. Image of warped events
69 |             orig_iwe (torch.Tensor) ... [W, H]. Image of original events
70 | 
71 |         Returns:
72 |             loss (torch.Tensor) ... normalized gradient magnitude of the image.
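        Note: in the "minimize" direction this returns the ratio (gradient magnitude of
            orig_iwe) / (gradient magnitude of iwe), so sharper warped images give smaller loss.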
73 |         """
74 |         loss1 = self.gradient_magnitude.calculate_torch(iwe, omit_boundary)
75 |         loss2 = self.gradient_magnitude.calculate_torch(orig_iwe, omit_boundary)
76 |         if self.direction == "minimize":
77 |             return loss2 / loss1
78 |         logger.warning("The loss is specified as maximize direction")
79 |         return loss1 / loss2
80 | 
81 |     def calculate_numpy(self, iwe: np.ndarray, orig_iwe: np.ndarray, omit_boundary: bool) -> float:
82 |         """
83 |         Inputs:
84 |             iwe (np.ndarray) ... [W, H]. Image of warped events
85 |             orig_iwe (np.ndarray) ... [W, H]. Image of original events
86 | 
87 |         Returns:
88 |             contrast (float) ... normalized gradient magnitude of the image.
89 |         """
90 |         loss1 = self.gradient_magnitude.calculate_numpy(iwe, omit_boundary)
91 |         loss2 = self.gradient_magnitude.calculate_numpy(orig_iwe, omit_boundary)
92 |         if self.direction == "minimize":
93 |             return loss2 / loss1
94 |         return loss1 / loss2
--------------------------------------------------------------------------------
/src/costs/normalized_image_variance.py:
--------------------------------------------------------------------------------
 1 | import logging
 2 | from typing import Union
 3 | 
 4 | import numpy as np
 5 | import torch
 6 | 
 7 | from . import CostBase
 8 | 
 9 | logger = logging.getLogger(__name__)
10 | 
11 | 
12 | class NormalizedImageVariance(CostBase):
13 |     """Normalized image variance,
14 |     a.k.a. FWL (Flow Warp Loss) by Stoffregen et al. ECCV 2020.
15 |     Args:
16 |         direction (str) ... 'minimize' or 'maximize' or `natural`.
17 |             Defines the objective function. If natural, it returns more interpretable value.
18 |     """
19 | 
20 |     name = "normalized_image_variance"
21 |     required_keys = ["orig_iwe", "iwe", "omit_boundary"]
22 | 
23 |     def __init__(self, direction="minimize", store_history: bool = False, *args, **kwargs):
24 |         super().__init__(direction=direction, store_history=store_history)
25 | 
26 |     @CostBase.register_history  # type: ignore
27 |     @CostBase.catch_key_error  # type: ignore
28 |     def calculate(self, arg: dict) -> Union[float, torch.Tensor]:
29 |         """Calculate the normalized contrast of the IWE.
30 |         Inputs:
31 |             iwe (np.ndarray or torch.Tensor) ... [W, H]. Image of warped events
32 |             orig_iwe (np.ndarray or torch.Tensor) ... [W, H]. Image of original events
33 |             omit_boundary (bool) ... Omit boundary if True.
34 | 
35 |         Returns:
36 |             contrast (Union[float, torch.Tensor]) ... normalized contrast of the image.
37 |         """
38 |         iwe = arg["iwe"]
39 |         orig_iwe = arg["orig_iwe"]
40 |         if arg["omit_boundary"]:
41 |             iwe = iwe[..., 1:-1, 1:-1]  # omit boundary
42 |         if isinstance(iwe, torch.Tensor):
43 |             return self.calculate_torch(iwe, orig_iwe)
44 |         elif isinstance(iwe, np.ndarray):
45 |             return self.calculate_numpy(iwe, orig_iwe)
46 |         e = f"Unsupported input type. {type(iwe)}."
47 |         logger.error(e)
48 |         raise NotImplementedError(e)
49 | 
50 |     def calculate_torch(self, iwe: torch.Tensor, orig_iwe: torch.Tensor) -> torch.Tensor:
51 |         """Calculate the normalized contrast of the IWE.
52 |         Inputs:
53 |             iwe (torch.Tensor) ... [W, H]. Image of warped events
54 |             orig_iwe (torch.Tensor) ... [W, H]. Image of original events
55 | 
56 |         Returns:
57 |             loss (torch.Tensor) ... contrast of the image.
58 |         """
59 |         loss1 = torch.var(iwe)
60 |         loss2 = torch.var(orig_iwe)
61 |         if self.direction == "minimize":
62 |             return loss2 / loss1
63 |         logger.warning("The loss is specified as maximize direction")
64 |         return loss1 / loss2
65 | 
66 |     def calculate_numpy(self, iwe: np.ndarray, orig_iwe: np.ndarray) -> float:
67 |         """Calculate the normalized contrast of the IWE.
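        In the "minimize" direction the value is Var(orig_iwe) / Var(iwe), so sharper IWEs give smaller loss.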
68 |         Inputs:
69 |             iwe (np.ndarray) ... [W, H]. Image of warped events
70 |             orig_iwe (np.ndarray) ... [W, H]. Image of original events
71 | 
72 |         Returns:
73 |             contrast (float) ... contrast of the image.
74 |         """
75 |         loss1 = np.var(iwe)
76 |         loss2 = np.var(orig_iwe)
77 |         if self.direction == "minimize":
78 |             return loss2 / loss1
79 |         return loss1 / loss2
--------------------------------------------------------------------------------
/src/costs/total_variation.py:
--------------------------------------------------------------------------------
 1 | import logging
 2 | from typing import Union
 3 | 
 4 | import cv2
 5 | import numpy as np
 6 | import torch
 7 | 
 8 | from ..utils import SobelTorch
 9 | from . import CostBase
10 | 
11 | logger = logging.getLogger(__name__)
12 | 
13 | 
14 | class TotalVariation(CostBase):
15 |     """Total Variation, used as a regularizer.
16 |     Args:
17 |         direction (str) ... 'minimize' or 'maximize' or `natural`.
18 |             Defines the objective function. If natural, it returns more interpretable value.
19 |     """
20 | 
21 |     name = "total_variation"
22 |     required_keys = ["flow", "omit_boundary"]
23 | 
24 |     def __init__(
25 |         self,
26 |         direction="minimize",
27 |         store_history: bool = False,
28 |         cuda_available=False,
29 |         precision="32",
30 |         *args,
31 |         **kwargs,
32 |     ):
33 |         super().__init__(direction=direction, store_history=store_history)
34 |         self.torch_sobel = SobelTorch(
35 |             ksize=3, in_channels=2, cuda_available=cuda_available, precision=precision
36 |         )
37 | 
38 |     @CostBase.register_history  # type: ignore
39 |     @CostBase.catch_key_error  # type: ignore
40 |     def calculate(self, arg: dict) -> Union[float, torch.Tensor]:
41 |         """Calculate total variation of the flow.
42 |         Inputs:
43 |             flow (np.ndarray or torch.Tensor) ... [(b,) 2, W, H]. Flow of the image.
44 |             omit_boundary (bool) ... Omit boundary if True.
45 | 
46 |         Returns:
47 |             (Union[float, torch.Tensor]) ... Total variation of the flow.
48 |         """
49 |         flow = arg["flow"]
50 |         omit_boundary = arg["omit_boundary"]
51 | 
52 |         if isinstance(flow, torch.Tensor):
53 |             return self.calculate_torch(flow, omit_boundary)
54 |         elif isinstance(flow, np.ndarray):
55 |             return self.calculate_numpy(flow, omit_boundary)
56 |         e = f"Unsupported input type. {type(flow)}."
57 |         logger.error(e)
58 |         raise NotImplementedError(e)
59 | 
60 |     def calculate_torch(self, flow: torch.Tensor, omit_boundary: bool) -> torch.Tensor:
61 |         """Calculate cost
62 |         Inputs:
63 |             flow (torch.Tensor) ... [(b,) 2, W, H]. Optical flow
64 | 
65 |         Returns:
66 |             loss (torch.Tensor) ... Total Variation.
67 |         """
68 |         sobel = self.get_sobel_image_torch(flow, omit_boundary)
69 | 
70 |         loss = torch.mean(torch.abs(sobel))
71 | 
72 |         if self.direction == "minimize":
73 |             return loss
74 |         logger.warning("The loss is specified as maximize direction")
75 |         return -loss
76 | 
77 |     def calculate_numpy(self, flow: np.ndarray, omit_boundary: bool) -> float:
78 |         """Calculate cost
79 |         Inputs:
80 |             flow (np.ndarray) ... [(b,) 2, W, H]. Optical flow
81 | 
82 |         Returns:
83 |             loss (float) ... Total variation.
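        Note: this is an L1 total variation, i.e. the mean absolute value of the four
            Sobel derivative images (x/y flow components differentiated in x and y).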
84 |         """
85 |         sobelxx, sobelxy, sobelyx, sobelyy = self.get_sobel_image_numpy(flow, omit_boundary)
86 | 
87 |         loss = np.mean(
88 |             np.abs(sobelxx) + np.abs(sobelxy) + np.abs(sobelyx) + np.abs(sobelyy)
89 |         )  # L1 version
90 | 
91 |         if self.direction == "minimize":
92 |             return loss
93 |         return -loss
94 | 
95 |     def visualize_sobel_image(self, sobel_image):
96 |         sobel_image = np.abs(sobel_image)
97 |         if len(sobel_image.shape) == 4:
98 |             sobel_image = sobel_image[0]
99 |         if len(sobel_image.shape) == 3:
100 |             sobel_image = np.concatenate(
101 |                 [sobel_image[0], sobel_image[1], sobel_image[2], sobel_image[3]]
102 |             )
103 |         sobel_image = (
104 |             (sobel_image - sobel_image.min()) / (sobel_image.max() - sobel_image.min()) * 255
105 |         )
106 |         self.inter_visualizer.visualize_image(
107 |             sobel_image.astype(np.uint8), file_prefix="total_variation"
108 |         )
109 | 
110 |     def get_sobel_image_torch(self, flow: torch.Tensor, omit_boundary: bool) -> torch.Tensor:
111 |         """Calculate sobel of the flow.
112 |         Inputs:
113 |             flow (torch.Tensor) ... [(b,) 2, W, H]. Optical flow
114 | 
115 |         Returns:
116 |             sobel (torch.Tensor) ... [(b,) 4, W, H]. The 4 channels are
117 |                 [x-component dx, x-component dy, y-component dx, y-component dy]
118 |         """
119 |         if len(flow.shape) == 3:
120 |             flow = flow[None]  # 1, 2, W, H
121 |         sobel = self.torch_sobel(flow) / 8.0
122 | 
123 |         if omit_boundary:
124 |             if sobel.shape[2] > 2 and sobel.shape[3] > 2:
125 |                 sobel = sobel[..., 1:-1, 1:-1]
126 |         return sobel
127 | 
128 |     def get_sobel_image_numpy(self, flow: np.ndarray, omit_boundary: bool) -> tuple:
129 |         """Calculate sobel images of the flow.
130 |         Inputs:
131 |             flow (np.ndarray) ... [(b,) 2, W, H]. Optical flow
132 | 
133 |         Returns:
134 |             (tuple) ... [x-component dx, x-component dy, y-component dx, y-component dy]
135 |         """
136 |         if len(flow.shape) == 4:
137 |             raise NotImplementedError
138 |         sobelxx = cv2.Sobel(flow[0], cv2.CV_64F, 1, 0, ksize=3)
139 |         sobelxy = cv2.Sobel(flow[0], cv2.CV_64F, 0, 1, ksize=3)
140 |         sobelyx = cv2.Sobel(flow[1], cv2.CV_64F, 1, 0, ksize=3)
141 |         sobelyy = cv2.Sobel(flow[1], cv2.CV_64F, 0, 1, ksize=3)
142 | 
143 |         if omit_boundary:
144 |             # Only for 3-dim array
145 |             if sobelxx.shape[0] > 1 and sobelxx.shape[1] > 1:
146 |                 sobelxx = sobelxx[1:-1, 1:-1]
147 |                 sobelxy = sobelxy[1:-1, 1:-1]
148 |                 sobelyx = sobelyx[1:-1, 1:-1]
149 |                 sobelyy = sobelyy[1:-1, 1:-1]
150 | 
151 |         return sobelxx, sobelxy, sobelyx, sobelyy
--------------------------------------------------------------------------------
/src/data_loader/__init__.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | 
 3 | DATASET_ROOT_DIR = os.path.join(
 4 |     os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "datasets"
 5 | )
 6 | 
 7 | from .base import DataLoaderBase
 8 | 
 9 | # from .davis_data import DavisDataLoader
10 | # from .dsec import DsecDataLoader  # TODO comes later..
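# Only the MVSEC loader is released; any DataLoaderBase subclass imported here is
# auto-registered into `collections` by inheritors() below.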
11 | from .mvsec import MvsecDataLoader
12 | 
13 | 
14 | # List of supported datasets
15 | def inheritors(klass):
16 |     subclasses = set()
17 |     work = [klass]
18 |     while work:
19 |         parent = work.pop()
20 |         for child in parent.__subclasses__():
21 |             if child not in subclasses:
22 |                 subclasses.add(child)
23 |                 work.append(child)
24 |     return subclasses
25 | 
26 | 
27 | collections = {k.NAME: k for k in inheritors(DataLoaderBase)}
28 | 
29 | # DNN
30 | # TODO comes later
31 | 
--------------------------------------------------------------------------------
/src/data_loader/base.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | 
4 | import numpy as np
5 | 
6 | from .. import utils
7 | from . import DATASET_ROOT_DIR
8 | 
9 | logger = logging.getLogger(__name__)
10 | 
11 | 
12 | class DataLoaderBase(object):
13 |     """Base of the DataLoader class.
14 |     Please make sure to implement
15 |     - load_event()
16 |     - get_sequence()
17 |     in child classes.
18 |     """
19 | 
20 |     NAME = "example"
21 | 
22 |     def __init__(self, config: dict = {}):
23 |         self._HEIGHT = config["height"]
24 |         self._WIDTH = config["width"]
25 | 
26 |         root_dir: str = config["root"] if config["root"] else DATASET_ROOT_DIR
27 |         self.root_dir: str = os.path.expanduser(root_dir)
28 |         data_dir: str = config["dataset"] if config["dataset"] else self.NAME
29 | 
30 |         self.dataset_dir: str = os.path.join(self.root_dir, data_dir)
31 |         self.__dataset_files: dict = {}
32 |         logger.info(f"Loading directory in {self.dataset_dir}")
33 | 
34 |         self.gt_flow_available: bool
35 |         if utils.check_key_and_bool(config, "load_gt_flow"):
36 |             self.gt_flow_dir: str = os.path.expanduser(config["gt"])
37 |             self.gt_flow_available = utils.check_file_utils(self.gt_flow_dir)
38 |         else:
39 |             self.gt_flow_available = False
40 | 
41 |         if utils.check_key_and_bool(config, "undistort"):
42 |             logger.info("Undistort events when load_event.")
43 |             self.auto_undistort = True
44 |         else:
45 |             logger.info("No undistortion.")
46 |             self.auto_undistort = False
47 | 
48 |     @property
49 |     def dataset_files(self) -> dict:
50 |         return self.__dataset_files
51 | 
52 |     @dataset_files.setter
53 |     def dataset_files(self, sequence: dict):
54 |         self.__dataset_files = sequence
55 | 
56 |     def set_sequence(self, sequence_name: str) -> None:
57 |         logger.info(f"Use sequence {sequence_name}")
58 |         self.sequence_name = sequence_name
59 |         self.dataset_files = self.get_sequence(sequence_name)
60 | 
61 |     def get_sequence(self, sequence_name: str) -> dict:
62 |         raise NotImplementedError
63 | 
64 |     def load_event(
65 |         self, start_index: int, end_index: int, cam: str = "left", *args, **kwargs
66 |     ) -> np.ndarray:
67 |         raise NotImplementedError
68 | 
69 |     def load_calib(self) -> dict:
70 |         raise NotImplementedError
71 | 
72 |     def load_optical_flow(self, t1: float, t2: float, *args, **kwargs) -> np.ndarray:
73 |         raise NotImplementedError
74 | 
75 |     def index_to_time(self, index: int) -> float:
76 |         raise NotImplementedError
77 | 
78 |     def time_to_index(self, time: float) -> int:
79 |         raise NotImplementedError
80 | 
--------------------------------------------------------------------------------
/src/data_loader/mvsec.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from typing import Tuple
4 | 
5 | import h5py
6 | import numpy as np
7 | 
8 | from ..utils import estimate_corresponding_gt_flow, undistort_events
9 | from . import DataLoaderBase
10 | 
11 | logger = logging.getLogger(__name__)
12 | 
13 | 
14 | # hdf5 data loader
15 | def h5py_loader(path: str):
16 |     """Basic loader for .hdf5 files.
17 |     Args:
18 |         path (str) ... Path to the .hdf5 file.
19 | 
20 |     Returns:
21 |         timestamp (dict) ... Dictionary of numpy arrays. Keys are "left" / "right".
22 |         davis_left (dict) ... "event": np.ndarray.
23 |         davis_right (dict) ... "event": np.ndarray.
24 |     """
25 |     data = h5py.File(path, "r")
26 |     event_timestamp = get_timestamp_index(data)
27 |     r = {
28 |         "event": np.array(data["davis"]["right"]["events"], dtype=np.int16),
29 |     }
30 |     # 'gray_ts': np.array(data['davis']['right']['image_raw_ts'], dtype=np.float64)
31 |     l = {
32 |         "event": np.array(data["davis"]["left"]["events"], dtype=np.int16),
33 |         "gray_ts": np.array(data["davis"]["left"]["image_raw_ts"], dtype=np.float64),
34 |     }
35 |     data.close()
36 |     return event_timestamp, l, r
37 | 
38 | 
39 | def get_timestamp_index(h5py_data):
40 |     """Timestamp loader for pre-fetching before the actual sensor data loading.
41 |     This is necessary to sync the sensors and to decide which timestamp to
42 |     use as the ground clock.
43 |     Args:
44 |         h5py_data ... h5py object.
45 |     Returns:
46 |         timestamp (dict) ... Dictionary of numpy arrays. Keys are "left" / "right".
47 |     """
48 |     timestamp = {}
49 |     timestamp["right"] = np.array(h5py_data["davis"]["right"]["events"][:, 2])
50 |     timestamp["left"] = np.array(h5py_data["davis"]["left"]["events"][:, 2])
51 |     return timestamp
52 | 
53 | 
54 | class MvsecDataLoader(DataLoaderBase):
55 |     """Dataloader class for the MVSEC dataset."""
56 | 
57 |     NAME = "MVSEC"
58 | 
59 |     def __init__(self, config: dict = {}):
60 |         super().__init__(config)
61 | 
62 |     # Override
63 |     def set_sequence(self, sequence_name: str, undistort: bool = False) -> None:
64 |         logger.info(f"Use sequence {sequence_name}")
65 |         self.sequence_name = sequence_name
66 |         logger.info(f"Undistort events = {undistort}")
67 | 
68 |         self.dataset_files = self.get_sequence(sequence_name)
69 |         ts, l_event, r_event = h5py_loader(self.dataset_files["event"])
70 |         self.left_event = l_event["event"]  # int16 .. for smaller memory consumption.
71 |         self.left_ts = ts["left"]  # float64
72 |         self.left_gray_ts = l_event["gray_ts"]  # float64
73 |         # self.right_event = r_event["event"]
74 |         # self.right_ts = ts["right"]
75 |         # self.right_gray_ts = r_event["gray_ts"]  # float64
76 | 
77 |         # Setup GT
78 |         if self.gt_flow_available:
79 |             self.setup_gt_flow(os.path.join(self.gt_flow_dir, sequence_name))
80 |             self.omit_invalid_data(sequence_name)
81 | 
82 |         # Undistort - most likely necessary to run evaluation with GT.
83 |         self.undistort = undistort
84 |         if self.undistort:
85 |             self.calib_map_x, self.calib_map_y = self.get_calib_map(
86 |                 self.dataset_files["calib_map_x"], self.dataset_files["calib_map_y"]
87 |             )
88 | 
89 |         # Setting up time duration statistics
90 |         self.min_ts = self.left_ts.min()
91 |         self.max_ts = self.left_ts.max()
92 |         # self.min_ts = np.max([self.left_ts.min(), self.right_ts.min()])
93 |         # self.max_ts = np.min([self.left_ts.max(), self.right_ts.max()]) - 10.0  # do not use the last 10 sec
94 |         self.data_duration = self.max_ts - self.min_ts
95 | 
96 |     def get_sequence(self, sequence_name: str) -> dict:
97 |         """Get data inside a sequence.
98 | 
99 |         Inputs:
100 |             sequence_name (str) ... name of the sequence. ex) `outdoor_day2`.
101 | 
102 |         Returns:
103 |             sequence_file (dict) ... dictionary of the filenames for the sequence.
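        Example (illustrative, for sequence_name = "indoor_flying1"; note that
        the trailing digit is stripped for the calibration map filenames):
            {
                "event": "<root_dir>/indoor_flying1_data.hdf5",
                "calib_map_x": "<root_dir>/indoor_flying_left_x_map.txt",
                "calib_map_y": "<root_dir>/indoor_flying_left_y_map.txt",
            }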
104 | """ 105 | data_path: str = os.path.join(self.root_dir, sequence_name) 106 | event_file = data_path + "_data.hdf5" 107 | calib_file_x = data_path[:-1] + "_left_x_map.txt" 108 | calib_file_y = data_path[:-1] + "_left_y_map.txt" 109 | sequence_file = { 110 | "event": event_file, 111 | "calib_map_x": calib_file_x, 112 | "calib_map_y": calib_file_y, 113 | } 114 | return sequence_file 115 | 116 | def setup_gt_flow(self, path): 117 | path = path + "_gt_flow_dist.npz" 118 | logger.info(f"Loading ground truth flow {path}") 119 | gt = np.load(path) 120 | self.gt_timestamps = gt["timestamps"] 121 | self.U_gt_all = gt["x_flow_dist"] 122 | self.V_gt_all = gt["y_flow_dist"] 123 | 124 | def free_up_flow(self): 125 | del self.gt_timestamps, self.U_gt_all, self.V_gt_all 126 | 127 | def omit_invalid_data(self, sequence_name: str): 128 | logger.info(f"Use only valid frames.") 129 | first_valid_gt_frame = 0 130 | last_valid_gt_frame = -1 131 | if "indoor_flying1" in sequence_name: 132 | first_valid_gt_frame = 60 133 | last_valid_gt_frame = 1340 134 | elif "indoor_flying2" in sequence_name: 135 | first_valid_gt_frame = 140 136 | last_valid_gt_frame = 1500 137 | elif "indoor_flying3" in sequence_name: 138 | first_valid_gt_frame = 100 139 | last_valid_gt_frame = 1711 140 | elif "indoor_flying4" in sequence_name: 141 | first_valid_gt_frame = 104 142 | last_valid_gt_frame = 380 143 | elif "outdoor_day1" in sequence_name: 144 | last_valid_gt_frame = 5020 145 | elif "outdoor_day2" in sequence_name: 146 | first_valid_gt_frame = 30 147 | # last_valid_gt_frame = 5020 148 | 149 | self.gt_timestamps = self.gt_timestamps[first_valid_gt_frame:last_valid_gt_frame] 150 | self.U_gt_all = self.U_gt_all[first_valid_gt_frame:last_valid_gt_frame] 151 | self.V_gt_all = self.V_gt_all[first_valid_gt_frame:last_valid_gt_frame] 152 | 153 | # Update event list 154 | first_event_index = self.time_to_index(self.gt_timestamps[0]) 155 | last_event_index = self.time_to_index(self.gt_timestamps[-1]) 156 | self.left_event = self.left_event[first_event_index:last_event_index] 157 | self.left_ts = self.left_ts[first_event_index:last_event_index] 158 | 159 | self.min_ts = self.left_ts.min() 160 | self.max_ts = self.left_ts.max() 161 | 162 | # Update gray frame ts 163 | self.left_gray_ts = self.left_gray_ts[ 164 | (self.gt_timestamps[0] < self.left_gray_ts) 165 | & (self.gt_timestamps[-1] > self.left_gray_ts) 166 | ] 167 | 168 | # self.right_event = self.right_event[first_event_index:last_event_index] 169 | # self.right_ts = self.right_ts[first_event_index:last_event_index] 170 | # self.right_gray_ts = self.right_gray_ts[ 171 | # (self.gt_timestamps[0] < self.right_gray_ts) 172 | # & (self.gt_timestamps[-1] > self.right_gray_ts) 173 | # ] 174 | 175 | def __len__(self): 176 | return len(self.left_event) 177 | 178 | def load_event(self, start_index: int, end_index: int, cam: str = "left") -> np.ndarray: 179 | """Load events. 180 | The original hdf5 file contains (x, y, t, p), 181 | where x means in width direction, and y means in height direction. p is -1 or 1. 182 | 183 | Returns: 184 | events (np.ndarray) ... Events. [x, y, t, p] where x is height. 185 | t is absolute value, in sec. p is [-1, 1]. 186 | """ 187 | n_events = end_index - start_index 188 | events = np.zeros((n_events, 4), dtype=np.float64) 189 | 190 | if cam == "left": 191 | if len(self.left_event) <= start_index: 192 | logger.error( 193 | f"Specified {start_index} to {end_index} index for {len(self.left_event)}." 
194 | ) 195 | raise IndexError 196 | events[:, 0] = self.left_event[start_index:end_index, 1] 197 | events[:, 1] = self.left_event[start_index:end_index, 0] 198 | events[:, 2] = self.left_ts[start_index:end_index] 199 | events[:, 3] = self.left_event[start_index:end_index, 3] 200 | elif cam == "right": 201 | logger.error("Please select `left`as `cam` parameter.") 202 | raise NotImplementedError 203 | if self.undistort: 204 | events = undistort_events( 205 | events, self.calib_map_x, self.calib_map_y, self._HEIGHT, self._WIDTH 206 | ) 207 | return events 208 | 209 | # Optical flow (GT) 210 | def gt_time_list(self): 211 | return self.gt_timestamps 212 | 213 | def eval_frame_time_list(self): 214 | # In MVSEC, evaluation is based on gray frame timestamp. 215 | return self.left_gray_ts 216 | 217 | def index_to_time(self, index: int) -> float: 218 | return self.left_ts[index] 219 | 220 | def time_to_index(self, time: float) -> int: 221 | # inds = np.where(self.left_ts > time)[0] 222 | # if len(inds) == 0: 223 | # return len(self.left_ts) - 1 224 | # return inds[0] - 1 225 | ind = np.searchsorted(self.left_ts, time) 226 | return ind - 1 227 | 228 | def get_gt_time(self, index: int) -> tuple: 229 | """Get GT flow timestamp [floor, ceil] for a given index. 230 | 231 | Args: 232 | index (int): Index of the event 233 | 234 | Returns: 235 | tuple: [floor_gt, ceil_gt]. Both are synced with GT optical flow. 236 | """ 237 | inds = np.where(self.gt_timestamps > self.index_to_time(index))[0] 238 | if len(inds) == 0: 239 | return (self.gt_timestamps[-1], None) 240 | elif len(inds) == len(self.gt_timestamps): 241 | return (None, self.gt_timestamps[0]) 242 | else: 243 | return (self.gt_timestamps[inds[0] - 1], self.gt_timestamps[inds[0]]) 244 | 245 | def load_optical_flow(self, t1: float, t2: float) -> np.ndarray: 246 | """Load GT Optical flow based on timestamp. 247 | Note: this is pixel displacement. 248 | Note: the args are not indices, but timestamps. 249 | 250 | Args: 251 | t1 (float): [description] 252 | t2 (float): [description] 253 | 254 | Returns: 255 | [np.ndarray]: H x W x 2. Be careful that the 2 ch is [height, width] direction component. 256 | """ 257 | U_gt, V_gt = estimate_corresponding_gt_flow( 258 | self.U_gt_all, 259 | self.V_gt_all, 260 | self.gt_timestamps, 261 | t1, 262 | t2, 263 | ) 264 | gt_flow = np.stack((V_gt, U_gt), axis=2) 265 | return gt_flow 266 | 267 | def load_calib(self) -> dict: 268 | """Load calibration file. 269 | 270 | Outputs: 271 | (dict) ... {"K": camera_matrix, "D": distortion_coeff} 272 | camera_matrix (np.ndarray) ... [3 x 3] matrix. 273 | distortion_coeff (np.array) ... [5] array. 274 | """ 275 | logger.warning("directly load calib_param is not implemented!! please use rectify instead.") 276 | outdoor_K = np.array( 277 | [ 278 | [223.9940010790056, 0, 170.7684322973841, 0], 279 | [0, 223.61783486959376, 128.18711828338436, 0], 280 | [0, 0, 1, 0], 281 | [0, 0, 0, 1], 282 | ], 283 | dtype=np.float32, 284 | ) 285 | return {"K": outdoor_K} 286 | 287 | def get_calib_map(self, map_txt_x, map_txt_y): 288 | """Intrinsic calibration parameter file loader. 289 | Args: 290 | map_txt... file path. 291 | Returns 292 | map_array (np.array)... map array. 
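        Example (illustrative sketch; load_map_txt below expects one row of
        WIDTH whitespace-separated floats per image row, HEIGHT rows in total):
            map_x, map_y = self.get_calib_map(
                "indoor_flying_left_x_map.txt", "indoor_flying_left_y_map.txt"
            )
            # map_x.shape == map_y.shape == (self._HEIGHT, self._WIDTH)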
293 | """ 294 | map_x = self.load_map_txt(map_txt_x) 295 | map_y = self.load_map_txt(map_txt_y) 296 | return map_x, map_y 297 | 298 | def load_map_txt(self, map_txt): 299 | f = open(map_txt, "r") 300 | line = f.readlines() 301 | map_array = np.zeros((self._HEIGHT, self._WIDTH)) 302 | for i, l in enumerate(line): 303 | map_array[i] = np.array([float(k) for k in l.split()]) 304 | return map_array 305 | -------------------------------------------------------------------------------- /src/event_image_converter.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Optional, Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | from scipy.ndimage.filters import gaussian_filter 7 | from torchvision.transforms.functional import gaussian_blur 8 | 9 | from .types import FLOAT_TORCH, NUMPY_TORCH, is_numpy, is_torch 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class EventImageConverter(object): 15 | """Converter class of image into many different representations as an image. 16 | 17 | Args: 18 | image_size (tuple)... (H, W) 19 | outer_padding (int, or tuple) ... Padding to outer to the conversion. This tries to 20 | avoid events go out of the image. 21 | """ 22 | 23 | def __init__(self, image_size: tuple, outer_padding: Union[int, Tuple[int, int]] = 0): 24 | if isinstance(outer_padding, (int, float)): 25 | self.outer_padding = (int(outer_padding), int(outer_padding)) 26 | else: 27 | self.outer_padding = outer_padding 28 | self.image_size = tuple(int(i + p * 2) for i, p in zip(image_size, self.outer_padding)) 29 | 30 | def update_property( 31 | self, 32 | image_size: Optional[tuple] = None, 33 | outer_padding: Optional[Union[int, Tuple[int, int]]] = None, 34 | ): 35 | if image_size is not None: 36 | self.image_size = image_size 37 | if outer_padding is not None: 38 | if isinstance(outer_padding, int): 39 | self.outer_padding = (outer_padding, outer_padding) 40 | else: 41 | self.outer_padding = outer_padding 42 | self.image_size = tuple(i + p for i, p in zip(self.image_size, self.outer_padding)) 43 | 44 | # Higher layer functions 45 | def create_iwe( 46 | self, 47 | events: NUMPY_TORCH, 48 | method: str = "bilinear_vote", 49 | sigma: int = 1, 50 | ) -> NUMPY_TORCH: 51 | """Create Image of Warped Events (IWE). 52 | 53 | Args: 54 | events (NUMPY_TORCH): [(b,) n_events, 4] 55 | method (str): [description] 56 | sigma (float): [description] 57 | 58 | Returns: 59 | NUMPY_TORCH: [(b,) H, W] 60 | """ 61 | if is_numpy(events): 62 | return self.create_image_from_events_numpy(events, method, sigma=sigma) 63 | elif is_torch(events): 64 | return self.create_image_from_events_tensor(events, method, sigma=sigma) 65 | e = f"Non-supported type of events. {type(events)}" 66 | logger.error(e) 67 | raise RuntimeError(e) 68 | 69 | def create_eventmask(self, events: NUMPY_TORCH) -> NUMPY_TORCH: 70 | """Create mask image where at least one event exists. 
71 | 72 | Args: 73 | events (NUMPY_TORCH): [(b,) n_events, 4] 74 | 75 | Returns: 76 | NUMPY_TORCH: [(b,) 1, H, W] 77 | """ 78 | if is_numpy(events): 79 | return (0 != self.create_image_from_events_numpy(events, sigma=0))[..., None, :, :] 80 | elif is_torch(events): 81 | return (0 != self.create_image_from_events_tensor(events, sigma=0))[..., None, :, :] 82 | raise RuntimeError 83 | 84 | # Lower layer functions 85 | # Image creation functions 86 | def create_image_from_events_numpy( 87 | self, 88 | events: np.ndarray, 89 | method: str = "bilinear_vote", 90 | weight: Union[float, np.ndarray] = 1.0, 91 | sigma: int = 1, 92 | ) -> np.ndarray: 93 | """Create image of events for numpy array. 94 | 95 | Inputs: 96 | events (np.ndarray) ... [(b,) n_events, 4] Batch of events. 4 is (x, y, t, p). Attention that (x, y) could float. 97 | Also, x is height dimension and y is the width dimension. 98 | method (str) ... method to accumulate events. "count", "bilinear_vote", "polarity", etc. 99 | weight (float or np.ndarray) ... Only applicable when method = "bilinear_vote". 100 | sigma (int) ... Sigma for the gaussian blur. 101 | 102 | Returns: 103 | image ... [(b,) H, W]. Each index indicates the sum of the event, based on the specified method. 104 | """ 105 | if method == "count": 106 | image = self.count_event_numpy(events) 107 | elif method == "bilinear_vote": 108 | image = self.bilinear_vote_numpy(events, weight=weight) 109 | elif method == "polarity": # TODO implement 110 | pos_flag = events[..., 3] > 0 111 | if is_numpy(weight): 112 | pos_image = self.bilinear_vote_numpy(events[pos_flag], weight=weight[pos_flag]) 113 | neg_image = self.bilinear_vote_numpy(events[~pos_flag], weight=weight[~pos_flag]) 114 | else: 115 | pos_image = self.bilinear_vote_numpy(events[pos_flag], weight=weight) 116 | neg_image = self.bilinear_vote_numpy(events[~pos_flag], weight=weight) 117 | image = np.stack([pos_image, neg_image], axis=-3) 118 | else: 119 | e = f"{method = } is not supported." 120 | logger.error(e) 121 | raise NotImplementedError(e) 122 | if sigma > 0: 123 | image = gaussian_filter(image, sigma) 124 | return image 125 | 126 | def create_image_from_events_tensor( 127 | self, 128 | events: torch.Tensor, 129 | method: str = "bilinear_vote", 130 | weight: FLOAT_TORCH = 1.0, 131 | sigma: int = 0, 132 | ) -> torch.Tensor: 133 | """Create image of events for tensor array. 134 | 135 | Inputs: 136 | events (torch.Tensor) ... [(b, ) n_events, 4] Batch of events. 4 is (x, y, t, p). Attention that (x, y) could float. 137 | Also, x is the width dimension and y is the height dimension. 138 | method (str) ... method to accumulate events. "count", "bilinear_vote", "polarity", etc. 139 | weight (float or torch.Tensor) ... Only applicable when method = "bilinear_vote". 140 | sigma (int) ... Sigma for the gaussian blur. 141 | 142 | Returns: 143 | image ... [(b, ) W, H]. Each index indicates the sum of the event, based on the specified method. 144 | """ 145 | if method == "count": 146 | image = self.count_event_tensor(events) 147 | elif method == "bilinear_vote": 148 | image = self.bilinear_vote_tensor(events, weight=weight) 149 | else: 150 | e = f"{method = } is not implemented" 151 | logger.error(e) 152 | raise NotImplementedError(e) 153 | if sigma > 0: 154 | if len(image.shape) == 2: 155 | image = image[None, None, ...] 156 | elif len(image.shape) == 3: 157 | image = image[:, None, ...] 
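            # Added note: torchvision's gaussian_blur expects a (..., C, H, W)
            # tensor, so the reshapes above lift [H, W] / [b, H, W] images to
            # 4D before blurring; the squeeze() below undoes this.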
158 | image = gaussian_blur(image, kernel_size=3, sigma=sigma) 159 | return torch.squeeze(image) 160 | 161 | def count_event_numpy(self, events: np.ndarray): 162 | """Count event and make image. 163 | 164 | Args: 165 | events ... [(b,) n_events, 4] Batch of events. 4 is (x, y, t, p). Attention that (x, y) could float. 166 | 167 | Returns: 168 | image ... [(b,) W, H]. Each index indicates the sum of the event, just counting. 169 | """ 170 | if len(events.shape) == 2: 171 | events = events[None, ...] # 1 x n x 4 172 | 173 | # x-y is height-width 174 | ph, pw = self.outer_padding 175 | h, w = self.image_size 176 | nb = len(events) 177 | image = np.zeros((nb, h * w), dtype=np.float64) 178 | 179 | floor_xy = np.floor(events[..., :2] + 1e-8) 180 | floor_to_xy = events[..., :2] - floor_xy 181 | 182 | x1 = floor_xy[..., 1] + pw 183 | y1 = floor_xy[..., 0] + ph 184 | inds = np.concatenate( 185 | [ 186 | x1 + y1 * w, 187 | x1 + (y1 + 1) * w, 188 | (x1 + 1) + y1 * w, 189 | (x1 + 1) + (y1 + 1) * w, 190 | ], 191 | axis=-1, 192 | ) 193 | inds_mask = np.concatenate( 194 | [ 195 | (0 <= x1) * (x1 < w) * (0 <= y1) * (y1 < h), 196 | (0 <= x1) * (x1 < w) * (0 <= y1 + 1) * (y1 + 1 < h), 197 | (0 <= x1 + 1) * (x1 + 1 < w) * (0 <= y1) * (y1 < h), 198 | (0 <= x1 + 1) * (x1 + 1 < w) * (0 <= y1 + 1) * (y1 + 1 < h), 199 | ], 200 | axis=-1, 201 | ) 202 | vals = np.ones_like(inds) 203 | inds = (inds * inds_mask).astype(np.int64) 204 | vals = vals * inds_mask 205 | for i in range(nb): 206 | np.add.at(image[i], inds[i], vals[i]) 207 | return image.reshape((nb,) + self.image_size).squeeze() 208 | 209 | def count_event_tensor(self, events: torch.Tensor): 210 | """Tensor version of `count_event_numpy().` 211 | 212 | Args: 213 | events (torch.Tensor) ... [(b,) n_events, 4] Batch of events. 4 is (x, y, t, p). Attention that (x, y) could float. 214 | 215 | Returns: 216 | image ... [(b,) H, W]. Each index indicates the bilinear vote result. If the outer_padding is set, 217 | the return size will be [H + outer_padding, W + outer_padding]. 218 | """ 219 | if len(events.shape) == 2: 220 | events = events[None, ...] # 1 x n x 4 221 | 222 | ph, pw = self.outer_padding 223 | h, w = self.image_size 224 | nb = len(events) 225 | image = events.new_zeros((nb, h * w)) 226 | 227 | floor_xy = torch.floor(events[..., :2] + 1e-6) 228 | floor_to_xy = events[..., :2] - floor_xy 229 | floor_xy = floor_xy.long() 230 | 231 | x1 = floor_xy[..., 1] + pw 232 | y1 = floor_xy[..., 0] + ph 233 | inds = torch.cat( 234 | [ 235 | x1 + y1 * w, 236 | x1 + (y1 + 1) * w, 237 | (x1 + 1) + y1 * w, 238 | (x1 + 1) + (y1 + 1) * w, 239 | ], 240 | dim=-1, 241 | ) # [(b, ) n_events x 4] 242 | inds_mask = torch.cat( 243 | [ 244 | (0 <= x1) * (x1 < w) * (0 <= y1) * (y1 < h), 245 | (0 <= x1) * (x1 < w) * (0 <= y1 + 1) * (y1 + 1 < h), 246 | (0 <= x1 + 1) * (x1 + 1 < w) * (0 <= y1) * (y1 < h), 247 | (0 <= x1 + 1) * (x1 + 1 < w) * (0 <= y1 + 1) * (y1 + 1 < h), 248 | ], 249 | axis=-1, 250 | ) 251 | vals = torch.ones_like(inds) 252 | inds = (inds * inds_mask).long() 253 | vals = vals * inds_mask 254 | image.scatter_add_(1, inds, vals) 255 | return image.reshape((nb,) + self.image_size).squeeze() 256 | 257 | def bilinear_vote_numpy(self, events: np.ndarray, weight: Union[float, np.ndarray] = 1.0): 258 | """Use bilinear voting to and make image. 259 | 260 | Args: 261 | events (np.ndarray) ... [(b, ) n_events, 4] Batch of events. 4 is (x, y, t, p). Attention that (x, y) could float. 262 | weight (float or np.ndarray) ... Weight to multiply to the voting value. 
263 | If scalar, the weight is all the same among events. 264 | If it's array-like, it should be the shape of [n_events]. 265 | Defaults to 1.0. 266 | 267 | Returns: 268 | image ... [(b, ) H, W]. Each index indicates the bilinear vote result. If the outer_padding is set, 269 | the return size will be [H + outer_padding, W + outer_padding]. 270 | """ 271 | if type(weight) == np.ndarray: 272 | assert weight.shape == events.shape[:-1] 273 | if len(events.shape) == 2: 274 | events = events[None, ...] # 1 x n x 4 275 | 276 | # x-y is height-width 277 | ph, pw = self.outer_padding 278 | h, w = self.image_size 279 | nb = len(events) 280 | image = np.zeros((nb, h * w), dtype=np.float64) 281 | 282 | floor_xy = np.floor(events[..., :2] + 1e-8) 283 | floor_to_xy = events[..., :2] - floor_xy 284 | 285 | x1 = floor_xy[..., 1] + pw 286 | y1 = floor_xy[..., 0] + ph 287 | inds = np.concatenate( 288 | [ 289 | x1 + y1 * w, 290 | x1 + (y1 + 1) * w, 291 | (x1 + 1) + y1 * w, 292 | (x1 + 1) + (y1 + 1) * w, 293 | ], 294 | axis=-1, 295 | ) 296 | inds_mask = np.concatenate( 297 | [ 298 | (0 <= x1) * (x1 < w) * (0 <= y1) * (y1 < h), 299 | (0 <= x1) * (x1 < w) * (0 <= y1 + 1) * (y1 + 1 < h), 300 | (0 <= x1 + 1) * (x1 + 1 < w) * (0 <= y1) * (y1 < h), 301 | (0 <= x1 + 1) * (x1 + 1 < w) * (0 <= y1 + 1) * (y1 + 1 < h), 302 | ], 303 | axis=-1, 304 | ) 305 | w_pos0 = (1 - floor_to_xy[..., 0]) * (1 - floor_to_xy[..., 1]) * weight 306 | w_pos1 = floor_to_xy[..., 0] * (1 - floor_to_xy[..., 1]) * weight 307 | w_pos2 = (1 - floor_to_xy[..., 0]) * floor_to_xy[..., 1] * weight 308 | w_pos3 = floor_to_xy[..., 0] * floor_to_xy[..., 1] * weight 309 | vals = np.concatenate([w_pos0, w_pos1, w_pos2, w_pos3], axis=-1) 310 | inds = (inds * inds_mask).astype(np.int64) 311 | vals = vals * inds_mask 312 | for i in range(nb): 313 | np.add.at(image[i], inds[i], vals[i]) 314 | return image.reshape((nb,) + self.image_size).squeeze() 315 | 316 | def bilinear_vote_tensor(self, events: torch.Tensor, weight: FLOAT_TORCH = 1.0): 317 | """Tensor version of `bilinear_vote_numpy().` 318 | 319 | Args: 320 | events (torch.Tensor) ... [(b,) n_events, 4] Batch of events. 4 is (x, y, t, p). Attention that (x, y) could float. 321 | weight (float or torch.Tensor) ... Weight to multiply to the voting value. 322 | If scalar, the weight is all the same among events. 323 | If it's array-like, it should be the shape of [(b,) n_events]. 324 | Defaults to 1.0. 325 | 326 | Returns: 327 | image ... [(b,) H, W]. Each index indicates the bilinear vote result. If the outer_padding is set, 328 | the return size will be [H + outer_padding, W + outer_padding]. 329 | """ 330 | if type(weight) == torch.Tensor: 331 | assert weight.shape == events.shape[:-1] 332 | if len(events.shape) == 2: 333 | events = events[None, ...] 
# 1 x n x 4 334 | 335 | ph, pw = self.outer_padding 336 | h, w = self.image_size 337 | nb = len(events) 338 | image = events.new_zeros((nb, h * w)) 339 | 340 | floor_xy = torch.floor(events[..., :2] + 1e-6) 341 | floor_to_xy = events[..., :2] - floor_xy 342 | floor_xy = floor_xy.long() 343 | 344 | x1 = floor_xy[..., 1] + pw 345 | y1 = floor_xy[..., 0] + ph 346 | inds = torch.cat( 347 | [ 348 | x1 + y1 * w, 349 | x1 + (y1 + 1) * w, 350 | (x1 + 1) + y1 * w, 351 | (x1 + 1) + (y1 + 1) * w, 352 | ], 353 | dim=-1, 354 | ) # [(b, ) n_events x 4] 355 | inds_mask = torch.cat( 356 | [ 357 | (0 <= x1) * (x1 < w) * (0 <= y1) * (y1 < h), 358 | (0 <= x1) * (x1 < w) * (0 <= y1 + 1) * (y1 + 1 < h), 359 | (0 <= x1 + 1) * (x1 + 1 < w) * (0 <= y1) * (y1 < h), 360 | (0 <= x1 + 1) * (x1 + 1 < w) * (0 <= y1 + 1) * (y1 + 1 < h), 361 | ], 362 | axis=-1, 363 | ) 364 | 365 | w_pos0 = (1 - floor_to_xy[..., 0]) * (1 - floor_to_xy[..., 1]) * weight 366 | w_pos1 = floor_to_xy[..., 0] * (1 - floor_to_xy[..., 1]) * weight 367 | w_pos2 = (1 - floor_to_xy[..., 0]) * floor_to_xy[..., 1] * weight 368 | w_pos3 = floor_to_xy[..., 0] * floor_to_xy[..., 1] * weight 369 | vals = torch.cat([w_pos0, w_pos1, w_pos2, w_pos3], dim=-1) # [(b,) n_events x 4] 370 | 371 | inds = (inds * inds_mask).long() 372 | vals = vals * inds_mask 373 | image.scatter_add_(1, inds, vals) 374 | return image.reshape((nb,) + self.image_size).squeeze() 375 | -------------------------------------------------------------------------------- /src/feature_calculator.py: -------------------------------------------------------------------------------- 1 | # Mock code 2 | # Feature calculation is not necessary 3 | import logging 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | class FeatureCalculatorMock: 9 | def __init__(self, *args, **kwargs): 10 | """Mock class -- please ignore.""" 11 | logger.warning("Feature calculation is disabled in this source code.") 12 | pass 13 | 14 | def skip(self): 15 | feature = { 16 | "none": {"per_event": True, "value": None}, 17 | } 18 | return feature 19 | 20 | def calculate_feature(self, *args, skip: bool = False, **kwargs) -> dict: 21 | """Mock function.""" 22 | return self.skip() 23 | -------------------------------------------------------------------------------- /src/solver/__init__.py: -------------------------------------------------------------------------------- 1 | """isort:skip_file 2 | """ 3 | 4 | # Non DNN 5 | from .base import SolverBase 6 | 7 | # from .contrast_maximization import ContrastMaximization 8 | from .patch_contrast_mixed import MixedPatchContrastMaximization 9 | from .time_aware_patch_contrast import TimeAwarePatchContrastMaximization 10 | from .patch_contrast_pyramid import PyramidalPatchContrastMaximization 11 | 12 | 13 | # List of supported solver - non DNN 14 | collections = { 15 | # CMax and variants 16 | # "contrast_maximization": ContrastMaximization, 17 | "pyramidal_patch_contrast_maximization": PyramidalPatchContrastMaximization, 18 | "time_aware_mixed_patch_contrast_maximization": TimeAwarePatchContrastMaximization, 19 | } 20 | -------------------------------------------------------------------------------- /src/solver/nnmodels/__init__.py: -------------------------------------------------------------------------------- 1 | from .ev_flownet import EVFlowNet 2 | -------------------------------------------------------------------------------- /src/solver/nnmodels/basic_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from 
torch import nn 3 | 4 | 5 | class build_resnet_block(nn.Module): 6 | """ 7 | a resnet block which includes two general_conv2d 8 | """ 9 | 10 | def __init__(self, channels, layers=2, do_batch_norm=False): 11 | super().__init__() 12 | self._channels = channels 13 | self._layers = layers 14 | 15 | self.res_block = nn.Sequential( 16 | *[ 17 | general_conv2d( 18 | in_channels=self._channels, 19 | out_channels=self._channels, 20 | strides=1, 21 | do_batch_norm=do_batch_norm, 22 | ) 23 | for i in range(self._layers) 24 | ] 25 | ) 26 | 27 | def forward(self, input_res): 28 | inputs = input_res.clone() 29 | input_res = self.res_block(input_res) 30 | return input_res + inputs 31 | 32 | 33 | class upsample_conv2d_and_predict_flow_or_depth(nn.Module): 34 | """ 35 | an upsample convolution layer which includes a nearest interpolate and a general_conv2d 36 | """ 37 | 38 | def __init__( 39 | self, in_channels, out_channels, ksize=3, do_batch_norm=False, type="flow", scale=256.0 40 | ): 41 | super().__init__() 42 | self._in_channels = in_channels 43 | self._out_channels = out_channels 44 | self._ksize = ksize 45 | self._do_batch_norm = do_batch_norm 46 | self._flow_or_depth = type 47 | self._scale = scale 48 | 49 | self.general_conv2d = general_conv2d( 50 | in_channels=self._in_channels, 51 | out_channels=self._out_channels, 52 | ksize=self._ksize, 53 | strides=1, 54 | do_batch_norm=self._do_batch_norm, 55 | padding=0, 56 | ) 57 | 58 | self.pad = nn.ReflectionPad2d( 59 | padding=( 60 | int((self._ksize - 1) / 2), 61 | int((self._ksize - 1) / 2), 62 | int((self._ksize - 1) / 2), 63 | int((self._ksize - 1) / 2), 64 | ) 65 | ) 66 | 67 | if self._flow_or_depth == "flow": 68 | self.predict = general_conv2d( 69 | in_channels=self._out_channels, 70 | out_channels=2, 71 | ksize=1, 72 | strides=1, 73 | padding=0, 74 | activation="tanh", 75 | ) 76 | elif self._flow_or_depth == "depth": 77 | self.predict = general_conv2d( 78 | in_channels=self._out_channels, 79 | out_channels=1, 80 | ksize=1, 81 | strides=1, 82 | padding=0, 83 | activation="sigmoid", 84 | ) 85 | else: 86 | raise NotImplementedError("flow or depth?") 87 | 88 | def forward(self, conv): 89 | """ 90 | Returns: 91 | feature 92 | pred (tensor) ... [N, ch, H, W]; 2 ch (flow) or 1 ch (depth). 
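            Example shapes (illustrative, for type="flow" with out_channels=O):
                conv [N, in_ch, H, W] -> upsample x2 -> pad + conv -> [N, O, 2H, 2W],
                so the return is ([N, O + 2, 2H, 2W], pred [N, 2, 2H, 2W]).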
93 | """ 94 | shape = conv.shape 95 | conv = nn.functional.interpolate( 96 | # conv, size=[shape[2] * 2, shape[3] * 2], mode="nearest" 97 | conv, 98 | size=[shape[2] * 2, shape[3] * 2], 99 | mode="bilinear", 100 | ) 101 | conv = self.pad(conv) 102 | conv = self.general_conv2d(conv) 103 | pred = self.predict(conv) * self._scale 104 | return torch.cat([conv, pred.clone()], dim=1), pred 105 | 106 | 107 | def general_conv2d( 108 | in_channels, out_channels, ksize=3, strides=2, padding=1, do_batch_norm=False, activation="relu" 109 | ): 110 | """ 111 | a general convolution layer which includes a conv2d, a relu and a batch_normalize 112 | """ 113 | if activation == "relu": 114 | if do_batch_norm: 115 | conv2d = nn.Sequential( 116 | nn.Conv2d( 117 | in_channels=in_channels, 118 | out_channels=out_channels, 119 | kernel_size=ksize, 120 | stride=strides, 121 | padding=padding, 122 | ), 123 | nn.ReLU(inplace=True), 124 | nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.99), 125 | ) 126 | else: 127 | conv2d = nn.Sequential( 128 | nn.Conv2d( 129 | in_channels=in_channels, 130 | out_channels=out_channels, 131 | kernel_size=ksize, 132 | stride=strides, 133 | padding=padding, 134 | ), 135 | nn.ReLU(inplace=True), 136 | ) 137 | elif activation == "tanh": 138 | if do_batch_norm: 139 | conv2d = nn.Sequential( 140 | nn.Conv2d( 141 | in_channels=in_channels, 142 | out_channels=out_channels, 143 | kernel_size=ksize, 144 | stride=strides, 145 | padding=padding, 146 | ), 147 | nn.Tanh(), 148 | nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.99), 149 | ) 150 | else: 151 | conv2d = nn.Sequential( 152 | nn.Conv2d( 153 | in_channels=in_channels, 154 | out_channels=out_channels, 155 | kernel_size=ksize, 156 | stride=strides, 157 | padding=padding, 158 | ), 159 | nn.Tanh(), 160 | ) 161 | elif activation == "sigmoid": 162 | if do_batch_norm: 163 | conv2d = nn.Sequential( 164 | nn.Conv2d( 165 | in_channels=in_channels, 166 | out_channels=out_channels, 167 | kernel_size=ksize, 168 | stride=strides, 169 | padding=padding, 170 | ), 171 | nn.Sigmoid(), 172 | nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.99), 173 | ) 174 | else: 175 | conv2d = nn.Sequential( 176 | nn.Conv2d( 177 | in_channels=in_channels, 178 | out_channels=out_channels, 179 | kernel_size=ksize, 180 | stride=strides, 181 | padding=padding, 182 | ), 183 | nn.Sigmoid(), 184 | ) 185 | return conv2d 186 | 187 | 188 | import torch 189 | import torch.nn as nn 190 | import torch.nn.functional as f 191 | 192 | 193 | class ConvLayer(nn.Module): 194 | """ 195 | Convolutional layer. 196 | Default: bias, ReLU, no downsampling, no batch norm. 
197 | """ 198 | 199 | def __init__( 200 | self, 201 | in_channels, 202 | out_channels, 203 | kernel_size, 204 | stride=1, 205 | padding=0, 206 | activation="relu", 207 | norm=None, 208 | BN_momentum=0.1, 209 | ): 210 | super(ConvLayer, self).__init__() 211 | 212 | bias = False if norm == "BN" else True 213 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) 214 | if activation is not None: 215 | self.activation = getattr(torch, activation) 216 | else: 217 | self.activation = None 218 | 219 | self.norm = norm 220 | if norm == "BN": 221 | self.norm_layer = nn.BatchNorm2d(out_channels, momentum=BN_momentum) 222 | elif norm == "IN": 223 | self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True) 224 | 225 | def forward(self, x): 226 | out = self.conv2d(x) 227 | 228 | if self.norm in ["BN", "IN"]: 229 | out = self.norm_layer(out) 230 | 231 | if self.activation is not None: 232 | out = self.activation(out) 233 | 234 | return out 235 | 236 | 237 | class TransposedConvLayer(nn.Module): 238 | """ 239 | Transposed convolutional layer to increase spatial resolution (x2) in a decoder. 240 | Default: bias, ReLU, no downsampling, no batch norm. 241 | """ 242 | 243 | def __init__( 244 | self, 245 | in_channels, 246 | out_channels, 247 | kernel_size, 248 | padding=0, 249 | activation="relu", 250 | norm=None, 251 | ): 252 | super(TransposedConvLayer, self).__init__() 253 | 254 | bias = False if norm == "BN" else True 255 | self.transposed_conv2d = nn.ConvTranspose2d( 256 | in_channels, 257 | out_channels, 258 | kernel_size, 259 | stride=2, 260 | padding=padding, 261 | output_padding=1, 262 | bias=bias, 263 | ) 264 | 265 | if activation is not None: 266 | self.activation = getattr(torch, activation) 267 | else: 268 | self.activation = None 269 | 270 | self.norm = norm 271 | if norm == "BN": 272 | self.norm_layer = nn.BatchNorm2d(out_channels) 273 | elif norm == "IN": 274 | self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True) 275 | 276 | def forward(self, x): 277 | out = self.transposed_conv2d(x) 278 | 279 | if self.norm in ["BN", "IN"]: 280 | out = self.norm_layer(out) 281 | 282 | if self.activation is not None: 283 | out = self.activation(out) 284 | 285 | return out 286 | 287 | 288 | class UpsampleConvLayer(nn.Module): 289 | """ 290 | Upsampling layer (bilinear interpolation + Conv2d) to increase spatial resolution (x2) in a decoder. 291 | Default: bias, ReLU, no downsampling, no batch norm. 
292 | """ 293 | 294 | def __init__( 295 | self, 296 | in_channels, 297 | out_channels, 298 | kernel_size, 299 | stride=1, 300 | padding=0, 301 | activation="relu", 302 | norm=None, 303 | ): 304 | super(UpsampleConvLayer, self).__init__() 305 | 306 | bias = False if norm == "BN" else True 307 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) 308 | 309 | if activation is not None: 310 | self.activation = getattr(torch, activation) 311 | else: 312 | self.activation = None 313 | 314 | self.norm = norm 315 | if norm == "BN": 316 | self.norm_layer = nn.BatchNorm2d(out_channels) 317 | elif norm == "IN": 318 | self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True) 319 | 320 | def forward(self, x): 321 | x_upsampled = f.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False) 322 | out = self.conv2d(x_upsampled) 323 | 324 | if self.norm in ["BN", "IN"]: 325 | out = self.norm_layer(out) 326 | 327 | if self.activation is not None: 328 | out = self.activation(out) 329 | 330 | return out 331 | 332 | 333 | class ResidualBlock(nn.Module): 334 | """ 335 | Residual block as in "Deep residual learning for image recognition", He et al. 2016. 336 | Default: bias, ReLU, no downsampling, no batch norm, ConvLSTM. 337 | """ 338 | 339 | def __init__( 340 | self, 341 | in_channels, 342 | out_channels, 343 | stride=1, 344 | downsample=None, 345 | norm=None, 346 | BN_momentum=0.1, 347 | ): 348 | super(ResidualBlock, self).__init__() 349 | bias = False if norm == "BN" else True 350 | self.conv1 = nn.Conv2d( 351 | in_channels, 352 | out_channels, 353 | kernel_size=3, 354 | stride=stride, 355 | padding=1, 356 | bias=bias, 357 | ) 358 | self.norm = norm 359 | if norm == "BN": 360 | self.bn1 = nn.BatchNorm2d(out_channels, momentum=BN_momentum) 361 | self.bn2 = nn.BatchNorm2d(out_channels, momentum=BN_momentum) 362 | elif norm == "IN": 363 | self.bn1 = nn.InstanceNorm2d(out_channels) 364 | self.bn2 = nn.InstanceNorm2d(out_channels) 365 | 366 | self.relu = nn.ReLU(inplace=True) 367 | self.conv2 = nn.Conv2d( 368 | out_channels, 369 | out_channels, 370 | kernel_size=3, 371 | stride=1, 372 | padding=1, 373 | bias=bias, 374 | ) 375 | self.downsample = downsample 376 | 377 | def forward(self, x): 378 | residual = x 379 | out = self.conv1(x) 380 | if self.norm in ["BN", "IN"]: 381 | out = self.bn1(out) 382 | out = self.relu(out) 383 | out = self.conv2(out) 384 | if self.norm in ["BN", "IN"]: 385 | out = self.bn2(out) 386 | 387 | if self.downsample: 388 | residual = self.downsample(x) 389 | 390 | out += residual 391 | out = self.relu(out) 392 | return out 393 | 394 | 395 | def skip_concat(x1, x2): 396 | diffY = x2.size()[2] - x1.size()[2] 397 | diffX = x2.size()[3] - x1.size()[3] 398 | padding = nn.ZeroPad2d((diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2)) 399 | x1 = padding(x1) 400 | return torch.cat([x1, x2], dim=1) 401 | 402 | 403 | def skip_sum(x1, x2): 404 | diffY = x2.size()[2] - x1.size()[2] 405 | diffX = x2.size()[3] - x1.size()[3] 406 | padding = nn.ZeroPad2d((diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2)) 407 | x1 = padding(x1) 408 | return x1 + x2 409 | -------------------------------------------------------------------------------- /src/solver/nnmodels/ev_flownet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from ... import utils 5 | from . 
import basic_layers 6 | 7 | _BASE_CHANNELS = 64 8 | 9 | 10 | class EVFlowNet(nn.Module): 11 | """EV-FlowNet definition 12 | Code is obtained from https://github.com/CyrilSterling/EVFlowNet-pytorch 13 | Thanks to the author @CyrilSterling (and @alexzhu for the original paper!) 14 | """ 15 | def __init__(self, nn_param: dict = {}): 16 | super().__init__() 17 | self.no_batch_norm = nn_param["no_batch_norm"] 18 | 19 | # Parameters for event voxel input. 20 | self.n_channel = nn_param["n_bin"] if "n_bin" in nn_param.keys() else 4 21 | self.scale_bin_time = nn_param["scale_time"] if "scale_time" in nn_param.keys() else 128.0 22 | 23 | self.encoder1 = basic_layers.general_conv2d( 24 | in_channels=self.n_channel, 25 | out_channels=_BASE_CHANNELS, 26 | do_batch_norm=not self.no_batch_norm, 27 | ) 28 | self.encoder2 = basic_layers.general_conv2d( 29 | in_channels=_BASE_CHANNELS, 30 | out_channels=2 * _BASE_CHANNELS, 31 | do_batch_norm=not self.no_batch_norm, 32 | ) 33 | self.encoder3 = basic_layers.general_conv2d( 34 | in_channels=2 * _BASE_CHANNELS, 35 | out_channels=4 * _BASE_CHANNELS, 36 | do_batch_norm=not self.no_batch_norm, 37 | ) 38 | self.encoder4 = basic_layers.general_conv2d( 39 | in_channels=4 * _BASE_CHANNELS, 40 | out_channels=8 * _BASE_CHANNELS, 41 | do_batch_norm=not self.no_batch_norm, 42 | ) 43 | 44 | self.resnet_block = nn.Sequential( 45 | *[ 46 | basic_layers.build_resnet_block( 47 | 8 * _BASE_CHANNELS, do_batch_norm=not self.no_batch_norm 48 | ) 49 | for i in range(2) 50 | ] 51 | ) 52 | 53 | self.decoder1 = basic_layers.upsample_conv2d_and_predict_flow_or_depth( 54 | in_channels=16 * _BASE_CHANNELS, 55 | out_channels=4 * _BASE_CHANNELS, 56 | do_batch_norm=not self.no_batch_norm, 57 | type="flow", 58 | scale=self.scale_bin_time, 59 | ) 60 | 61 | self.decoder2 = basic_layers.upsample_conv2d_and_predict_flow_or_depth( 62 | in_channels=8 * _BASE_CHANNELS + 2, 63 | out_channels=2 * _BASE_CHANNELS, 64 | do_batch_norm=not self.no_batch_norm, 65 | type="flow", 66 | scale=self.scale_bin_time, 67 | ) 68 | 69 | self.decoder3 = basic_layers.upsample_conv2d_and_predict_flow_or_depth( 70 | in_channels=4 * _BASE_CHANNELS + 2, 71 | out_channels=_BASE_CHANNELS, 72 | do_batch_norm=not self.no_batch_norm, 73 | type="flow", 74 | scale=self.scale_bin_time, 75 | ) 76 | 77 | self.decoder4 = basic_layers.upsample_conv2d_and_predict_flow_or_depth( 78 | in_channels=2 * _BASE_CHANNELS + 2, 79 | out_channels=int(_BASE_CHANNELS / 2), 80 | do_batch_norm=not self.no_batch_norm, 81 | type="flow", 82 | scale=self.scale_bin_time, 83 | ) 84 | 85 | @utils.profile( 86 | output_file="dnn_forward.prof", sort_by="cumulative", lines_to_print=300, strip_dirs=True 87 | ) 88 | def forward(self, inputs: torch.Tensor) -> dict: 89 | """ 90 | Args: 91 | inputs (torch.Tensor) ... [n_batch, n_bin, height, width] 92 | 93 | Returns 94 | flow_dict (dict) ... 
"flow3": [n_batch, 1, height, width] 95 | "flow0": [n_batch, 1, height // 2**3, width // 2**3] 96 | """ 97 | # encoder 98 | skip_connections = {} 99 | inputs = self.encoder1(inputs) 100 | skip_connections["skip0"] = inputs.clone() 101 | inputs = self.encoder2(inputs) 102 | skip_connections["skip1"] = inputs.clone() 103 | inputs = self.encoder3(inputs) 104 | skip_connections["skip2"] = inputs.clone() 105 | inputs = self.encoder4(inputs) 106 | skip_connections["skip3"] = inputs.clone() 107 | 108 | # transition 109 | inputs = self.resnet_block(inputs) 110 | 111 | # decoder 112 | flow_dict = {} 113 | inputs = torch.cat([inputs, skip_connections["skip3"]], dim=1) 114 | inputs, flow = self.decoder1(inputs) 115 | flow_dict["flow0"] = flow.clone() 116 | 117 | inputs = torch.cat([inputs, skip_connections["skip2"]], dim=1) 118 | inputs, flow = self.decoder2(inputs) 119 | flow_dict["flow1"] = flow.clone() 120 | 121 | inputs = torch.cat([inputs, skip_connections["skip1"]], dim=1) 122 | inputs, flow = self.decoder3(inputs) 123 | flow_dict["flow2"] = flow.clone() 124 | 125 | inputs = torch.cat([inputs, skip_connections["skip0"]], dim=1) 126 | inputs, flow = self.decoder4(inputs) 127 | flow_dict["flow3"] = flow.clone() 128 | 129 | return flow_dict 130 | -------------------------------------------------------------------------------- /src/solver/patch_contrast_mixed.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import shutil 3 | from typing import Any, Dict, List, Optional, Tuple 4 | 5 | import numpy as np 6 | import scipy 7 | import torch 8 | 9 | from .. import utils, visualizer 10 | from . import scipy_autograd 11 | from .base import SCIPY_OPTIMIZERS 12 | from .patch_contrast_base import PatchContrastMaximization 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class MixedPatchContrastMaximization(PatchContrastMaximization): 18 | """Mixed patch-based CMax. 19 | 20 | Params: 21 | image_shape (tuple) ... (H, W) 22 | calibration_parameter (dict) ... dictionary of the calibration parameter 23 | solver_config (dict) ... solver configuration 24 | optimizer_config (dict) ... optimizer configuration 25 | visualize_module ... visualizer.Visualizer 26 | """ 27 | 28 | def __init__( 29 | self, 30 | image_shape: tuple, 31 | calibration_parameter: dict, 32 | solver_config: dict = {}, 33 | optimizer_config: dict = {}, 34 | output_config: dict = {}, 35 | visualize_module: Optional[visualizer.Visualizer] = None, 36 | ): 37 | super().__init__( 38 | image_shape, 39 | calibration_parameter, 40 | solver_config, 41 | optimizer_config, 42 | output_config, 43 | visualize_module, 44 | ) 45 | self.set_patch_size_and_sliding_window() 46 | self.patches, self.patch_image_size = self.prepare_patch( 47 | image_shape, self.patch_size, self.sliding_window 48 | ) 49 | self.n_patch = len(self.patches.keys()) 50 | 51 | # internal variable 52 | self._patch_motion_model_keys = [ 53 | f"patch{i}_{k}" for i in range(self.n_patch) for k in self.motion_model_keys 54 | ] 55 | 56 | def optimize(self, events: np.ndarray) -> np.ndarray: 57 | """Run optimization. 58 | 59 | Inputs: 60 | events (np.ndarray) ... [n_events x 4] event array. Should be (x, y, t, p). 61 | n_iteration (int) ... How many iterations to run. 
62 | 63 | """ 64 | # Preprocessings 65 | logger.info("Start optimization.") 66 | logger.info(f"DoF is {self.motion_vector_size * self.n_patch}") 67 | 68 | if self.opt_method == "optuna": 69 | opt_result = self.run_optuna(events) 70 | logger.info(f"End optimization.") 71 | best_motion = self.get_motion_array_optuna(opt_result.best_params) 72 | elif self.opt_method in SCIPY_OPTIMIZERS: 73 | opt_result = self.run_scipy(events) 74 | logger.info(f"End optimization.\n Best parameters: {opt_result}") 75 | best_motion = opt_result.x.reshape( 76 | ((self.motion_vector_size,) + self.patch_image_size) 77 | ) # / 1000 78 | 79 | logger.info("Profile file saved.") 80 | if self.visualizer: 81 | shutil.copy("optimize.prof", self.visualizer.save_dir) 82 | if self.opt_method in SCIPY_OPTIMIZERS: 83 | self.visualizer.visualize_scipy_history( 84 | self.cost_func.get_history(), self.cost_weight 85 | ) 86 | 87 | logger.info(f"{best_motion}") 88 | return best_motion 89 | 90 | # Optuna functions 91 | def objective(self, trial, events: np.ndarray): 92 | # Parameters setting 93 | params = {k: self.sampling(trial, k) for k in self._patch_motion_model_keys} 94 | motion_array = self.get_motion_array_optuna(params) # 2 x H x W 95 | if self.normalize_t_in_batch: 96 | t_scale = np.max(events[:, 2]) - np.min(events[:, 2]) 97 | motion_array *= t_scale 98 | dense_flow = self.motion_to_dense_flow(motion_array) 99 | 100 | loss = self.calculate_cost(events, dense_flow, self.motion_model_for_dense_warp) 101 | logger.info(f"{trial.number = } / {loss = }") 102 | return loss 103 | 104 | def sampling(self, trial, key: str): 105 | """Sampling function for mixed type patch solution. 106 | 107 | Args: 108 | trial ([type]): [description] 109 | key (str): [description] 110 | 111 | Returns: 112 | [type]: [description] 113 | """ 114 | key_suffix = key[key.find("_") + 1 :] 115 | return trial.suggest_uniform( 116 | key, 117 | self.opt_config["parameters"][key_suffix]["min"], 118 | self.opt_config["parameters"][key_suffix]["max"], 119 | ) 120 | 121 | def get_motion_array_optuna(self, params: dict) -> np.ndarray: 122 | # Returns [n_patch x n_motion_paremter] 123 | motion_array = np.zeros((self.motion_vector_size, self.n_patch)) 124 | for i in range(self.n_patch): 125 | param = {k: params[f"patch{i}_{k}"] for k in self.motion_model_keys} 126 | motion_array[:, i] = self.motion_model_to_motion(param) 127 | return motion_array.reshape((self.motion_vector_size,) + self.patch_image_size) 128 | 129 | # Scipy 130 | @utils.profile( 131 | output_file="optimize.prof", sort_by="cumulative", lines_to_print=300, strip_dirs=True 132 | ) 133 | def run_scipy(self, events: np.ndarray) -> scipy.optimize.OptimizeResult: 134 | if self.previous_frame_best_estimation is not None: 135 | motion0 = np.copy(self.previous_frame_best_estimation) 136 | else: 137 | # Initialize with various methods 138 | if self.slv_config["patch"]["initialize"] == "random": 139 | motion0 = self.initialize_random() 140 | elif self.slv_config["patch"]["initialize"] == "zero": 141 | motion0 = self.initialize_zeros() 142 | elif self.slv_config["patch"]["initialize"] == "global-best": 143 | logger.info("sampling initialization") 144 | best_guess = self.initialize_guess_from_whole_image(events) 145 | if isinstance(best_guess, torch.Tensor): 146 | motion0 = torch.tile(best_guess[None], (self.n_patch, 1)).T.reshape(-1) 147 | elif isinstance(best_guess, np.ndarray): 148 | motion0 = np.tile(best_guess[None], (self.n_patch, 1)).T.reshape(-1) 149 | elif self.slv_config["patch"]["initialize"] == 
"grid-best": 150 | logger.info("sampling initialization") 151 | best_guess = self.initialize_guess_from_patch( 152 | events, patch_index=self.n_patch // 2 - 1 153 | ) 154 | if isinstance(best_guess, torch.Tensor): 155 | motion0 = torch.tile(best_guess[None], (self.n_patch, 1)).T.reshape(-1) 156 | elif isinstance(best_guess, np.ndarray): 157 | motion0 = np.tile(best_guess[None], (self.n_patch, 1)).T.reshape(-1) 158 | # motion0 += ( 159 | # np.random.rand(self.motion_vector_size * self.n_patch).astype(np.double64) * 10 - 5 160 | # ) 161 | elif self.slv_config["patch"]["initialize"] == "optuna-sampling": 162 | logger.info("Optuna intelligent sampling initialization") 163 | motion0 = self.initialize_guess_from_optuna_sampling(events) 164 | self.cost_func.clear_history() 165 | 166 | self.events = torch.from_numpy(events).double().requires_grad_().to(self._device) 167 | result = scipy_autograd.minimize( 168 | self.objective_scipy, 169 | motion0, 170 | method=self.opt_method, 171 | options={ 172 | "gtol": 1e-7, 173 | "disp": True, 174 | "maxiter": self.opt_config["max_iter"], 175 | "eps": 0.01, 176 | }, 177 | precision="float64", 178 | torch_device=self._device, 179 | # TODO support bounds 180 | # bounds=[(-300, 300), (-300, 300)] 181 | ) 182 | return result 183 | 184 | def objective_scipy(self, motion_array: np.ndarray, suppress_log: bool = False): 185 | """ 186 | Args: 187 | motion_array (np.ndarray): [2 * n_patches] array 188 | 189 | Returns: 190 | [type]: [description] 191 | """ 192 | if self.normalize_t_in_batch: 193 | t_scale = self.events[:, 2].max() - self.events[:, 2].min() 194 | else: 195 | t_scale = 1.0 196 | 197 | events = self.events.clone() 198 | dense_flow = self.motion_to_dense_flow(motion_array * t_scale) 199 | 200 | loss = self.calculate_cost( 201 | events, 202 | dense_flow, 203 | self.motion_model_for_dense_warp, 204 | motion_array.reshape((self.motion_vector_size,) + self.patch_image_size), 205 | ) 206 | if not suppress_log: 207 | logger.info(f"{loss = }") 208 | return loss 209 | -------------------------------------------------------------------------------- /src/solver/scipy_autograd/README.md: -------------------------------------------------------------------------------- 1 | 2 | Code is cominig from 3 | https://github.com/brunorigal/autograd-minimize 4 | 5 | -------------------------------------------------------------------------------- /src/solver/scipy_autograd/__init__.py: -------------------------------------------------------------------------------- 1 | from .scipy_minimize import minimize 2 | -------------------------------------------------------------------------------- /src/solver/scipy_autograd/base_wrapper.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import scipy.optimize as sopt 5 | import torch 6 | 7 | 8 | class BaseWrapper(ABC): 9 | def get_input(self, input_var): 10 | self.input_type = type(input_var) 11 | assert self.input_type in [ 12 | dict, 13 | list, 14 | np.ndarray, 15 | torch.Tensor, 16 | ], "The initial input to your optimized function should be one of dict, list or np.ndarray" 17 | input_, self.shapes = self._concat(input_var) 18 | self.var_num = input_.shape[0] 19 | return input_ 20 | 21 | def get_output(self, output_var): 22 | assert "shapes" in dir(self), "You must first call get input to define the tensors shapes." 
23 | output_var_ = self._unconcat(output_var, self.shapes) 24 | return output_var_ 25 | 26 | def get_bounds(self, bounds): 27 | 28 | if bounds is not None: 29 | if isinstance(bounds, tuple) and not ( 30 | isinstance(bounds[0], tuple) or isinstance(bounds[0], sopt.Bounds) 31 | ): 32 | assert len(bounds) == 2 33 | new_bounds = [bounds] * self.var_num 34 | 35 | elif isinstance(bounds, sopt.Bounds): 36 | new_bounds = [bounds] * self.var_num 37 | 38 | elif type(bounds) in [list, tuple, np.ndarray]: 39 | if self.input_type in [list, tuple]: 40 | assert len(self.shapes) == len(bounds) 41 | new_bounds = [] 42 | for sh, bounds_ in zip(self.shapes, bounds): 43 | new_bounds += format_bounds(bounds_, sh) 44 | elif self.input_type in [np.ndarray]: 45 | new_bounds = bounds 46 | elif self.input_type in [torch.Tensor]: 47 | new_bounds = bounds.detach().cpu().numpy() 48 | 49 | elif isinstance(bounds, dict): 50 | assert self.input_type == dict 51 | assert set(bounds.keys()).issubset(self.shapes.keys()) 52 | 53 | new_bounds = [] 54 | for k in self.shapes.keys(): 55 | if k in bounds.keys(): 56 | new_bounds += format_bounds(bounds[k], self.shapes[k]) 57 | else: 58 | new_bounds += [(None, None)] ** np.prod(self.shapes[k], dtype=np.int32) 59 | else: 60 | new_bounds = bounds 61 | return new_bounds 62 | 63 | def get_constraints(self, constraints, method): 64 | if constraints is not None and not isinstance(constraints, sopt.LinearConstraint): 65 | assert isinstance(constraints, dict) 66 | assert "fun" in constraints.keys() 67 | self.ctr_func = constraints["fun"] 68 | use_autograd = constraints.get("use_autograd", True) 69 | if method in ["trust-constr"]: 70 | 71 | constraints = sopt.NonlinearConstraint( 72 | self._eval_ctr_func, 73 | lb=constraints.get("lb", -np.inf), 74 | ub=constraints.get("ub", np.inf), 75 | jac=self.get_ctr_jac if use_autograd else "2-point", 76 | keep_feasible=constraints.get("keep_feasible", False), 77 | ) 78 | elif method in ["COBYLA", "SLSQP"]: 79 | constraints = { 80 | "type": constraints.get("type", "eq"), 81 | "fun": self._eval_ctr_func, 82 | } 83 | if use_autograd: 84 | constraints["jac"] = self.get_ctr_jac 85 | else: 86 | raise NotImplementedError 87 | elif constraints is None: 88 | constraints = () 89 | return constraints 90 | 91 | @abstractmethod 92 | def get_value_and_grad(self, input_var): 93 | return 94 | 95 | @abstractmethod 96 | def get_hvp(self, input_var, vector): 97 | return 98 | 99 | @abstractmethod 100 | def get_hess(self, input_var): 101 | return 102 | 103 | def _eval_func(self, input_var): 104 | if isinstance(input_var, dict): 105 | loss = self.func(**input_var) 106 | elif isinstance(input_var, list) or isinstance(input_var, tuple): 107 | loss = self.func(*input_var) 108 | else: 109 | loss = self.func(input_var) 110 | return loss 111 | 112 | def _eval_ctr_func(self, input_var): 113 | input_var = self._unconcat(input_var, self.shapes) 114 | if isinstance(input_var, dict): 115 | ctr_val = self.ctr_func(**input_var) 116 | elif isinstance(input_var, list) or isinstance(input_var, tuple): 117 | ctr_val = self.ctr_func(*input_var) 118 | else: 119 | ctr_val = self.ctr_func(input_var) 120 | return ctr_val 121 | 122 | @abstractmethod 123 | def get_ctr_jac(self, input_var): 124 | return 125 | 126 | def _concat(self, ten_vals): 127 | ten = [] 128 | if isinstance(ten_vals, dict): 129 | shapes = {} 130 | for k, t in ten_vals.items(): 131 | if t is not None: 132 | if isinstance(t, (np.floating, float, int)): 133 | t = np.array(t) 134 | shapes[k] = t.shape 135 | 
136 | ten = self._tconcat(ten, 0)
137 |
138 | elif isinstance(ten_vals, list) or isinstance(ten_vals, tuple):
139 | shapes = []
140 | for t in ten_vals:
141 | if t is not None:
142 | if isinstance(t, (np.floating, float, int)):
143 | t = np.array(t)
144 | shapes.append(t.shape)
145 | ten.append(self._reshape(t, [-1]))
146 | ten = self._tconcat(ten, 0)
147 |
148 | elif isinstance(ten_vals, (np.floating, float, int)):
149 | ten_vals = np.array(ten_vals)
150 | shapes = np.array(ten_vals).shape
151 | ten = self._reshape(np.array(ten_vals), [-1])
152 | elif isinstance(ten_vals, torch.Tensor):
153 | ten_vals = ten_vals.detach().cpu()
154 | shapes = ten_vals.shape
155 | ten = self._reshape(ten_vals, [-1])
156 | else:
157 | ten_vals = ten_vals
158 | shapes = ten_vals.shape
159 | ten = self._reshape(ten_vals, [-1])
160 | return ten, shapes
161 |
162 | def _unconcat(self, ten, shapes):
163 | current_ind = 0
164 | if isinstance(shapes, dict):
165 | ten_vals = {}
166 | for k, sh in shapes.items():
167 | next_ind = current_ind + np.prod(sh, dtype=np.int32)
168 | ten_vals[k] = self._reshape(self._gather(ten, current_ind, next_ind), sh)
169 |
170 | current_ind = next_ind
171 |
172 | elif isinstance(shapes, list) or isinstance(shapes, tuple):
173 | if isinstance(shapes[0], int):
174 | ten_vals = self._reshape(ten, shapes)
175 | else:
176 | ten_vals = []
177 | for sh in shapes:
178 | next_ind = current_ind + np.prod(sh, dtype=np.int32)
179 | ten_vals.append(self._reshape(self._gather(ten, current_ind, next_ind), sh))
180 |
181 | current_ind = next_ind
182 |
183 | elif shapes is None:
184 | ten_vals = ten
185 |
186 | return ten_vals
187 |
188 | @abstractmethod
189 | def _reshape(self, t, sh):
190 | return
191 |
192 | @abstractmethod
193 | def _tconcat(self, t_list, dim=0):
194 | return
195 |
196 | @abstractmethod
197 | def _gather(self, t, i, j):
198 | return
199 |
200 |
201 | def format_bounds(bounds_, sh):
202 | if isinstance(bounds_, tuple):
203 | assert len(bounds_) == 2
204 | return [bounds_] * np.prod(sh, dtype=np.int32)
205 | elif isinstance(bounds_, sopt.Bounds):
206 | return [bounds_] * np.prod(sh, dtype=np.int32)
207 | elif isinstance(bounds_, list):
208 | assert np.prod(sh) == len(bounds_)
209 | return bounds_
210 | elif isinstance(bounds_, np.ndarray):
211 | assert np.prod(sh) == np.prod(np.array(bounds_).shape)
212 | return np.reshape(bounds_, -1).tolist()  # np.concatenate on an already-flat array would raise
213 | else:
214 | raise TypeError
215 |
--------------------------------------------------------------------------------
/src/solver/scipy_autograd/scipy_minimize.py:
--------------------------------------------------------------------------------
1 | import scipy.optimize as sopt
2 |
3 | from .torch_wrapper import TorchWrapper
4 |
5 |
6 | def minimize(
7 | fun,
8 | x0,
9 | args=(),
10 | precision="float32",
11 | method=None,
12 | hvp_type=None,
13 | torch_device="cpu",
14 | bounds=None,
15 | constraints=None,
16 | tol=None,
17 | callback=None,
18 | options=None,
19 | ):
20 | """
21 | wrapper around the [minimize](https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html)
22 | function of scipy which includes an automatic computation of gradients,
23 | hessian vector product or hessian with tensorflow or torch backends.
24 | :param fun: function to be minimized, its signature can be a tensor, a list of tensors or a dict of tensors.
25 | :type fun: tensorflow or torch function
26 | :param x0: input to the function, it must match the signature of the function.
27 | :type x0: np.ndarray, list of arrays or dict of arrays.
28 | :param precision: one of 'float32' or 'float64', defaults to 'float32'
29 | :type precision: str, optional
30 | :param method: method used by the optimizer, it should be one of:
31 | 'Nelder-Mead',
32 | 'Powell',
33 | 'CG',
34 | 'BFGS',
35 | 'Newton-CG',
36 | 'L-BFGS-B',
37 | 'TNC',
38 | 'COBYLA',
39 | 'SLSQP',
40 | 'trust-constr',
41 | 'dogleg', # requires positive semi-definite hessian
42 | 'trust-ncg',
43 | 'trust-exact', # requires hessian
44 | 'trust-krylov'
45 | , defaults to None
46 | :type method: str, optional
47 | :param hvp_type: type of computation scheme for the hessian vector product
48 | for the torch backend it is one of hvp and vhp (vhp is faster according to the [doc](https://pytorch.org/docs/stable/autograd.html))
49 | for the tf backend it is one of 'forward_over_back', 'back_over_forward', 'tf_gradients_forward_over_back' and 'back_over_back'
50 | Some info about the most interesting schemes is given [here](https://www.tensorflow.org/api_docs/python/tf/autodiff/ForwardAccumulator)
51 | , defaults to None
52 | :type hvp_type: str, optional
53 | :param torch_device: device used by torch for the gradients computation,
54 | if the backend is not torch, this parameter is ignored, defaults to 'cpu'
55 | :type torch_device: str, optional
56 | :param bounds: Bounds on the input variables, only available for L-BFGS-B, TNC, SLSQP, Powell, and trust-constr methods.
57 | It can be:
58 | * a tuple (min, max), None indicates no bounds, in this case the same bound is applied to all variables.
59 | * An instance of the [Bounds](https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.Bounds.html#scipy.optimize.Bounds) class, in this case the same bound is applied to all variables.
60 | * A numpy array of bounds (if the optimized function has a single numpy array as input)
61 | * A list or dict of bounds with the same format as the optimized function signature.
62 | , defaults to None
63 | :type bounds: tuple, list, dict or np.ndarray, optional
64 | :param constraints: It has to be a dict with the following keys:
65 | * fun: a callable computing the constraint function
66 | * lb and ub: the lower and upper bounds, if equal, the constraint is an equality, use np.inf if there is no upper bound. Only used if method is trust-constr.
67 | * type: 'eq' or 'ineq' only used if method is one of COBYLA, SLSQP.
68 | * keep_feasible: see [here](https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.NonlinearConstraint.html#scipy.optimize.NonlinearConstraint) 69 | , defaults to None 70 | :type constraints: dict, optional 71 | :param tol: Tolerance for termination, defaults to None 72 | :type tol: float, optional 73 | :param callback: Called after each iteration, defaults to None 74 | :type callback: callable, optional 75 | :param options: solver options, defaults to None 76 | :type options: dict, optional 77 | :return: dict of optimization results 78 | :rtype: dict 79 | """ 80 | 81 | wrapper = TorchWrapper(fun, precision=precision, hvp_type=hvp_type, device=torch_device) 82 | 83 | if bounds is not None: 84 | assert method in [ 85 | None, 86 | "L-BFGS-B", 87 | "TNC", 88 | "SLSQP", 89 | "Powell", 90 | "trust-constr", 91 | ], "bounds are only available for L-BFGS-B, TNC, SLSQP, Powell, trust-constr" 92 | 93 | if constraints is not None: 94 | assert method in [ 95 | "COBYLA", 96 | "SLSQP", 97 | "trust-constr", 98 | ], "Constraints are only available for COBYLA, SLSQP and trust-constr" 99 | 100 | optim_res = sopt.minimize( 101 | wrapper.get_value_and_grad, 102 | wrapper.get_input(x0), 103 | args=args, 104 | method=method, 105 | jac=True, 106 | hessp=wrapper.get_hvp 107 | if method in ["Newton-CG", "trust-ncg", "trust-krylov", "trust-constr"] 108 | else None, 109 | hess=wrapper.get_hess if method in ["dogleg", "trust-exact"] else None, 110 | bounds=wrapper.get_bounds(bounds), 111 | constraints=wrapper.get_constraints(constraints, method), 112 | tol=tol, 113 | callback=callback, 114 | options=options, 115 | ) 116 | 117 | optim_res.x = wrapper.get_output(optim_res.x) 118 | 119 | if "jac" in optim_res.keys() and len(optim_res.jac) > 0: 120 | try: 121 | optim_res.jac = wrapper.get_output(optim_res.jac[0]) 122 | except: 123 | pass 124 | 125 | return optim_res 126 | -------------------------------------------------------------------------------- /src/solver/scipy_autograd/torch_wrapper.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Dict, List, Tuple, Union 2 | 3 | import numpy as np 4 | import torch 5 | from torch import Tensor, nn 6 | from torch.autograd.functional import hessian, hvp, vhp 7 | 8 | from .base_wrapper import BaseWrapper 9 | 10 | 11 | class TorchWrapper(BaseWrapper): 12 | def __init__(self, func, precision="float32", hvp_type="vhp", device="cpu"): 13 | self.func = func 14 | 15 | # Not very clean... 16 | if "device" in dir(func): 17 | self.device = func.device 18 | else: 19 | self.device = torch.device(device) 20 | 21 | if precision == "float32": 22 | self.precision = torch.float32 23 | elif precision == "float64": 24 | self.precision = torch.float64 25 | else: 26 | raise ValueError 27 | 28 | self.hvp_func = hvp if hvp_type == "hvp" else vhp 29 | 30 | def get_value_and_grad(self, input_var): 31 | assert "shapes" in dir(self), "You must first call get input to define the tensors shapes." 
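# (Added note) scipy_minimize.py calls sopt.minimize(..., jac=True), so this method
# must return the pair (loss, gradient). Both are converted to float64 numpy values
# at the end of this method, because scipy operates on numpy arrays while the
# autograd graph lives in torch.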
32 |
33 | input_var_ = self._unconcat(
34 | torch.tensor(input_var, dtype=self.precision, requires_grad=True, device=self.device),
35 | self.shapes,
36 | )
37 |
38 | loss = self._eval_func(input_var_)
39 | input_var_grad = input_var_.values() if isinstance(input_var_, dict) else input_var_
40 | grads = torch.autograd.grad(loss, input_var_grad)
41 | # print('=====', torch.linalg.norm(grads[0], ord=1))
42 |
43 | if isinstance(input_var_, dict):
44 | grads = {k: v for k, v in zip(input_var_.keys(), grads)}
45 |
46 | return [
47 | loss.cpu().detach().numpy().astype(np.float64),
48 | self._concat(grads)[0].cpu().detach().numpy().astype(np.float64),
49 | ]
50 |
51 | def get_hvp(self, input_var, vector):
52 | assert "shapes" in dir(self), "You must first call get_input to define the tensor shapes."
53 |
54 | input_var_ = self._unconcat(
55 | torch.tensor(input_var, dtype=self.precision, device=self.device), self.shapes
56 | )
57 | vector_ = self._unconcat(
58 | torch.tensor(vector, dtype=self.precision, device=self.device), self.shapes
59 | )
60 |
61 | if isinstance(input_var_, dict):
62 | input_var_ = tuple(input_var_.values())
63 | if isinstance(vector_, dict):
64 | vector_ = tuple(vector_.values())
65 |
66 | if isinstance(input_var_, list):
67 | input_var_ = tuple(input_var_)
68 | if isinstance(vector_, list):
69 | vector_ = tuple(vector_)
70 |
71 | loss, vhp_res = self.hvp_func(self.func, input_var_, v=vector_)
72 |
73 | return self._concat(vhp_res)[0].cpu().detach().numpy().astype(np.float64)
74 |
75 | def get_hess(self, input_var):
76 | assert "shapes" in dir(self), "You must first call get_input to define the tensor shapes."
77 | input_var_ = torch.tensor(input_var, dtype=self.precision, device=self.device)
78 |
79 | def func(inp):
80 | return self._eval_func(self._unconcat(inp, self.shapes))
81 |
82 | hess = hessian(func, input_var_, vectorize=False)
83 |
84 | return hess.cpu().detach().numpy().astype(np.float64)
85 |
86 | def get_ctr_jac(self, input_var):
87 | assert "shapes" in dir(self), "You must first call get_input to define the tensor shapes."
88 |
89 | input_var_ = self._unconcat(
90 | torch.tensor(input_var, dtype=self.precision, requires_grad=True, device=self.device),
91 | self.shapes,
92 | )
93 |
94 | ctr_val = self._eval_ctr_func(input_var_)
95 | input_var_grad = input_var_.values() if isinstance(input_var_, dict) else input_var_
96 | grads = torch.autograd.grad(ctr_val, input_var_grad)
97 | # torch.autograd.grad returns a tuple of gradients; flatten it as get_value_and_grad does.
98 | return self._concat(grads)[0].cpu().detach().numpy().astype(np.float64)
99 |
100 | def _reshape(self, t, sh):
101 | if torch.is_tensor(t):
102 | return t.reshape(sh)
103 | elif isinstance(t, np.ndarray):
104 | return np.reshape(t, sh)
105 | else:
106 | raise NotImplementedError
107 |
108 | def _tconcat(self, t_list, dim=0):
109 | if torch.is_tensor(t_list[0]):
110 | return torch.cat(t_list, dim)
111 | elif isinstance(t_list[0], np.ndarray):
112 | return np.concatenate(t_list, dim)
113 | else:
114 | raise NotImplementedError
115 |
116 | def _gather(self, t, i, j):
117 | if isinstance(t, np.ndarray) or torch.is_tensor(t):
118 | return t[i:j]
119 | else:
120 | raise NotImplementedError
121 |
122 |
123 | def torch_function_factory(model, loss, train_x, train_y, precision="float32", optimized_vars=None):
124 | """
125 | A factory to create a function of the torch model parameters.
126 | :param model: torch model
127 | :type model: torch.nn.Module
128 | :param loss: a function with signature loss_value = loss(pred_y, true_y).
129 | :type loss: function 130 | :param train_x: dataset used as input of the model 131 | :type train_x: np.ndarray 132 | :param train_y: dataset used as ground truth input of the loss 133 | :type train_y: np.ndarray 134 | :return: (function of the parameters, list of parameters, names of parameters) 135 | :rtype: tuple 136 | """ 137 | # named_params = {k: var.cpu().detach().numpy() for k, var in model.named_parameters()} 138 | params, names = extract_weights(model) 139 | device = params[0].device 140 | 141 | prec_ = torch.float32 if precision == "float32" else torch.float64 142 | if isinstance(train_x, np.ndarray): 143 | train_x = torch.tensor(train_x, dtype=prec_, device=device) 144 | if isinstance(train_y, np.ndarray): 145 | train_y = torch.tensor(train_y, dtype=prec_, device=device) 146 | 147 | def func(*new_params): 148 | load_weights(model, {k: v for k, v in zip(names, new_params)}) 149 | out = apply_func(model, train_x) 150 | 151 | return loss(out, train_y) 152 | 153 | func.device = device 154 | 155 | return func, [p.cpu().detach().numpy() for p in params], names 156 | 157 | 158 | def apply_func(func, input_): 159 | if isinstance(input_, dict): 160 | return func(**input_) 161 | elif isinstance(input_, list) or isinstance(input_, tuple): 162 | return func(*input_) 163 | else: 164 | return func(input_) 165 | 166 | 167 | # Adapted from https://github.com/pytorch/pytorch/blob/21c04b4438a766cd998fddb42247d4eb2e010f9a/benchmarks/functional_autograd_benchmark/functional_autograd_benchmark.py 168 | 169 | # Utilities to make nn.Module "functional" 170 | # In particular the goal is to be able to provide a function that takes as input 171 | # the parameters and evaluate the nn.Module using fixed inputs. 172 | 173 | 174 | def _del_nested_attr(obj: nn.Module, names: List[str]) -> None: 175 | """ 176 | Deletes the attribute specified by the given list of names. 177 | For example, to delete the attribute obj.conv.weight, 178 | use _del_nested_attr(obj, ['conv', 'weight']) 179 | """ 180 | if len(names) == 1: 181 | delattr(obj, names[0]) 182 | else: 183 | _del_nested_attr(getattr(obj, names[0]), names[1:]) 184 | 185 | 186 | def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None: 187 | """ 188 | Set the attribute specified by the given list of names to value. 189 | For example, to set the attribute obj.conv.weight, 190 | use _del_nested_attr(obj, ['conv', 'weight'], value) 191 | """ 192 | if len(names) == 1: 193 | setattr(obj, names[0], value) 194 | else: 195 | _set_nested_attr(getattr(obj, names[0]), names[1:], value) 196 | 197 | 198 | def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]: 199 | """ 200 | This function removes all the Parameters from the model and 201 | return them as a tuple as well as their original attribute names. 202 | The weights must be re-loaded with `load_weights` before the model 203 | can be used again. 204 | Note that this function modifies the model in place and after this 205 | call, mod.parameters() will be empty. 
206 | """ 207 | 208 | orig_params = [p for p in mod.parameters() if p.requires_grad] 209 | # Remove all the parameters in the model 210 | names = [] 211 | for name, p in list(mod.named_parameters()): 212 | if p.requires_grad: 213 | _del_nested_attr(mod, name.split(".")) 214 | names.append(name) 215 | 216 | # Make params regular Tensors instead of nn.Parameter 217 | params = tuple(p.detach().requires_grad_() for p in orig_params) 218 | return params, names 219 | 220 | 221 | def load_weights(mod: nn.Module, params: Dict[str, Tensor]) -> None: 222 | """ 223 | Reload a set of weights so that `mod` can be used again to perform a forward pass. 224 | Note that the `params` are regular Tensors (that can have history) and so are left 225 | as Tensors. This means that mod.parameters() will still be empty after this call. 226 | """ 227 | for name, p in params.items(): 228 | _set_nested_attr(mod, name.split("."), p) 229 | -------------------------------------------------------------------------------- /src/solver/time_aware_patch_contrast.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any, Dict, List, Optional, Tuple 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from .. import solver, types, utils, visualizer 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class TimeAwarePatchContrastMaximization(solver.MixedPatchContrastMaximization): 13 | """Time-aware patch-based CMax. 14 | 15 | Params: 16 | image_shape (tuple) ... (H, W) 17 | calibration_parameter (dict) ... dictionary of the calibration parameter 18 | solver_config (dict) ... solver configuration 19 | optimizer_config (dict) ... optimizer configuration 20 | visualize_module ... visualizer.Visualizer 21 | """ 22 | 23 | def __init__( 24 | self, 25 | image_shape: tuple, 26 | calibration_parameter: dict, 27 | solver_config: dict = {}, 28 | optimizer_config: dict = {}, 29 | output_config: dict = {}, 30 | visualize_module: Optional[visualizer.Visualizer] = None, 31 | ): 32 | super().__init__( 33 | image_shape, 34 | calibration_parameter, 35 | solver_config, 36 | optimizer_config, 37 | output_config, 38 | visualize_module, 39 | ) 40 | assert self.is_time_aware 41 | 42 | def motion_to_dense_flow(self, motion_array: types.NUMPY_TORCH) -> types.NUMPY_TORCH: 43 | """Returns dense flow at quantized time voxel. 44 | TODO eventually I should be able to remove this entire class! 45 | 46 | Args: 47 | motion_array (types.NUMPY_TORCH): [2 x h_patch x w_patch] Flow array. 
48 | 49 | Returns: 50 | types.NUMPY_TORCH: [time_bin x 2 x H x W] 51 | """ 52 | if self.scale_later: 53 | scale = motion_array.max() 54 | else: 55 | scale = 1.0 56 | if isinstance(motion_array, np.ndarray): 57 | dense_flow_t0 = self.interpolate_dense_flow_from_patch_numpy(motion_array) 58 | return ( 59 | utils.construct_dense_flow_voxel_numpy( 60 | dense_flow_t0 / scale, 61 | self.time_bin, 62 | self.flow_interpolation, 63 | t0_location=self.t0_flow_location, 64 | ) 65 | * scale 66 | ) 67 | elif isinstance(motion_array, torch.Tensor): 68 | dense_flow_t0 = self.interpolate_dense_flow_from_patch_tensor(motion_array) 69 | return ( 70 | utils.construct_dense_flow_voxel_torch( 71 | dense_flow_t0 / scale, 72 | self.time_bin, 73 | self.flow_interpolation, 74 | t0_location=self.t0_flow_location, 75 | ) 76 | * scale 77 | ) 78 | e = f"Unsupported type: {type(motion_array)}" 79 | raise TypeError(e) 80 | -------------------------------------------------------------------------------- /src/types/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Tuple, Union 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from .flow_patch import FlowPatch 7 | 8 | NUMPY_TORCH = Union[np.ndarray, torch.Tensor] 9 | FLOAT_TORCH = Union[float, torch.Tensor] 10 | 11 | 12 | def is_torch(arr: Any) -> bool: 13 | return isinstance(arr, torch.Tensor) 14 | 15 | 16 | def is_numpy(arr: Any) -> bool: 17 | return isinstance(arr, np.ndarray) 18 | 19 | 20 | def nt_max(array: NUMPY_TORCH, dim: int) -> NUMPY_TORCH: 21 | """max function compatible for numpy ndarray and torch tensor. 22 | 23 | Args: 24 | array (NUMPY_TORCH): 25 | 26 | Returns: 27 | NUMPY_TORCH: _description_ 28 | """ 29 | if is_numpy(array): 30 | return array.max(axis=dim) # type: ignore 31 | return torch.max(array, dim).values 32 | 33 | 34 | def nt_min(array: NUMPY_TORCH, dim: int) -> NUMPY_TORCH: 35 | """Min function compatible for numpy ndarray and torch tensor. 
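For example, `nt_min(arr, 0)` behaves like `arr.min(axis=0)` for a numpy array
and like `torch.min(arr, 0).values` for a torch tensor.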
36 | 37 | Args: 38 | array (NUMPY_TORCH): 39 | 40 | Returns: 41 | NUMPY_TORCH: _description_ 42 | """ 43 | if is_numpy(array): 44 | return array.min(axis=dim) # type: ignore 45 | return torch.min(array, dim).values 46 | -------------------------------------------------------------------------------- /src/types/flow_patch.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from dataclasses import dataclass 3 | from typing import Any, List, Optional 4 | 5 | import numpy as np 6 | 7 | 8 | @dataclass 9 | class FlowPatch: 10 | """Dataclass for flow patch""" 11 | 12 | # center of coordinates 13 | x: np.int16 # height 14 | y: np.int16 # width 15 | shape: tuple # (height, width) 16 | # flow (pixel displacement) value of the flow at the location 17 | u: float = 0.0 # height 18 | v: float = 0.0 # width 19 | 20 | @property 21 | def h(self) -> int: 22 | return self.shape[0] 23 | 24 | @property 25 | def w(self) -> int: 26 | return self.shape[1] 27 | 28 | @property 29 | def x_min(self) -> int: 30 | return int(self.x - np.ceil(self.h / 2)) 31 | 32 | @property 33 | def x_max(self) -> int: 34 | return int(self.x + np.floor(self.h / 2)) 35 | 36 | @property 37 | def y_min(self) -> int: 38 | return int(self.y - np.ceil(self.w / 2)) 39 | 40 | @property 41 | def y_max(self) -> int: 42 | return int(self.y + np.floor(self.w / 2)) 43 | 44 | @property 45 | def position(self) -> np.ndarray: 46 | return np.array([self.x, self.y]) 47 | 48 | @property 49 | def flow(self) -> np.ndarray: 50 | return np.array([self.u, self.v]) 51 | 52 | def update_flow(self, u: float, v: float): 53 | self.u = u 54 | self.v = v 55 | 56 | def new_ones(self): 57 | return np.ones(self.shape) 58 | 59 | def copy(self) -> Any: 60 | return copy.deepcopy(self) 61 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .event_utils import crop_event, generate_events, set_event_origin_to_zero, undistort_events 2 | from .flow_utils import ( 3 | calculate_flow_error_numpy, 4 | calculate_flow_error_tensor, 5 | construct_dense_flow_voxel_numpy, 6 | construct_dense_flow_voxel_torch, 7 | estimate_corresponding_gt_flow, 8 | generate_dense_optical_flow, 9 | inviscid_burger_flow_to_voxel_numpy, 10 | inviscid_burger_flow_to_voxel_torch, 11 | upwind_flow_to_voxel_numpy, 12 | upwind_flow_to_voxel_torch, 13 | ) 14 | from .misc import ( 15 | SingleThreadInMemoryStorage, 16 | check_file_utils, 17 | check_key_and_bool, 18 | fetch_runtime_information, 19 | fix_random_seed, 20 | profile, 21 | ) 22 | from .stat_utils import SobelTorch 23 | -------------------------------------------------------------------------------- /src/utils/event_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Optional, Tuple 3 | 4 | import numpy as np 5 | 6 | from ..types import FLOAT_TORCH, NUMPY_TORCH, is_numpy, is_torch 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | try: 11 | import torch 12 | except ImportError: 13 | e = "Torch is disabled." 14 | logger.warning(e) 15 | 16 | 17 | # Simulator module 18 | def generate_events( 19 | n_events: int, 20 | height: int, 21 | width: int, 22 | tmin: float = 0.0, 23 | tmax: float = 0.5, 24 | dist: str = "uniform", 25 | ) -> np.ndarray: 26 | """Generate random events. 27 | 28 | Args: 29 | n_events (int) ... num of events 30 | height (int) ... 
height of the camera 31 | width (int) ... width of the camera 32 | tmin (float) ... timestamp min 33 | tmax (float) ... timestamp max 34 | dist (str) ... currently only "uniform" is supported. 35 | 36 | Returns: 37 | events (np.ndarray) ... [n_events x 4] numpy array. (x, y, t, p) 38 | x indicates height direction. 39 | """ 40 | x = np.random.randint(0, height, n_events) 41 | y = np.random.randint(0, width, n_events) 42 | t = np.random.uniform(tmin, tmax, n_events) 43 | t = np.sort(t) 44 | p = np.random.randint(0, 2, n_events) 45 | 46 | events = np.concatenate([x[..., None], y[..., None], t[..., None], p[..., None]], axis=1) 47 | return events 48 | 49 | 50 | def crop_event(events: NUMPY_TORCH, x0: int, x1: int, y0: int, y1: int) -> NUMPY_TORCH: 51 | """Crop events. 52 | 53 | Args: 54 | events (NUMPY_TORCH): [n x 4]. [x, y, t, p]. 55 | x0 (int): Start of the crop, at row[0] 56 | x1 (int): End of the crop, at row[0] 57 | y0 (int): Start of the crop, at row[1] 58 | y1 (int): End of the crop, at row[1] 59 | 60 | Returns: 61 | NUMPY_TORCH: Cropped events. 62 | """ 63 | mask = ( 64 | (x0 <= events[..., 0]) 65 | * (events[..., 0] < x1) 66 | * (y0 <= events[..., 1]) 67 | * (events[..., 1] < y1) 68 | ) 69 | cropped = events[mask] 70 | return cropped 71 | 72 | 73 | def set_event_origin_to_zero(events: np.ndarray, x0: int, y0: int, t0: float = 0.0) -> np.ndarray: 74 | """Set each origin of each row to 0. 75 | 76 | Args: 77 | events (np.ndarray): [n x 4]. [x, y, t, p]. 78 | x0 (int): x origin 79 | y0 (int): y origin 80 | t0 (float): t origin 81 | 82 | Returns: 83 | np.ndarray: [n x 4]. x is in [0, xmax - x0], and so on. 84 | """ 85 | basis = np.array([x0, y0, t0, 0.0]) 86 | if is_torch(events): 87 | basis = torch.from_numpy(basis) 88 | return events - basis 89 | 90 | 91 | def undistort_events(events, map_x, map_y, h, w): 92 | """Undistort (rectify) events. 93 | Args: 94 | events ... [x, y, t, p]. X is height direction. 95 | map_x, map_y... meshgrid 96 | 97 | Returns: 98 | events... events that is in the camera plane after undistortion. 
99 | TODO check overflow
100 | """
101 | # k = np.int32(map_y[np.int16(events[:, 1]), np.int16(events[:, 0])])
102 | # l = np.int32(map_x[np.int16(events[:, 1]), np.int16(events[:, 0])])
103 | # k = np.int32(map_y[events[:, 1].astype(np.int32), events[:, 0].astype(np.int32)])
104 | # l = np.int32(map_x[events[:, 1].astype(np.int32), events[:, 0].astype(np.int32)])
105 | # undistort_events = np.copy(events)
106 | # undistort_events[:, 0] = l
107 | # undistort_events[:, 1] = k
108 | # return undistort_events[((0 <= k) & (k < h)) & ((0 <= l) & (l < w))]
109 |
110 | k = np.int32(map_y[events[:, 0].astype(np.int32), events[:, 1].astype(np.int32)])
111 | l = np.int32(map_x[events[:, 0].astype(np.int32), events[:, 1].astype(np.int32)])
112 | undistort_events = np.copy(events)
113 | undistort_events[:, 0] = k
114 | undistort_events[:, 1] = l
115 | return undistort_events[((0 <= k) & (k < h)) & ((0 <= l) & (l < w))]
116 |
--------------------------------------------------------------------------------
/src/utils/misc.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import cProfile
3 | import logging
4 | import os
5 | import pstats
6 | import random
7 | import subprocess
8 | from functools import wraps
9 | from typing import Dict
10 |
11 | import numpy as np
12 | import optuna
13 | import torch
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 |
18 | def fix_random_seed(seed=46) -> None:
19 | """Fix random seed"""
20 | logger.info(f"Fix random seed: {seed}")
21 | np.random.seed(seed)
22 | random.seed(seed)
23 | torch.manual_seed(seed)
24 | torch.cuda.manual_seed_all(seed)
25 |
26 |
27 | def check_file_utils(filename: str) -> bool:
28 | """Return True if the file exists.
29 |
30 | Args:
31 | filename (str): Path to the file to check.
32 |
33 | Returns:
34 | bool: True if the file exists.
35 | """
36 | logger.debug(f"Check {filename}")
37 | res = os.path.exists(filename)
38 | if not res:
39 | logger.warning(f"{filename} does not exist!")
40 | return res
41 |
42 |
43 | def check_key_and_bool(config: dict, key: str) -> bool:
44 | """Check the existence of the key and whether its value is True.
45 |
46 | Args:
47 | config (dict): dict.
48 | key (str): Key name to be checked.
49 |
50 | Returns:
51 | bool: Return True only if the key exists in the dict and its value is True.
52 | Otherwise returns False.
53 | """
54 | return bool(key in config.keys() and config[key])
55 |
56 |
57 | def fetch_runtime_information() -> dict:
58 | """Fetch information of the experiment at runtime.
59 |
60 | Returns:
61 | dict: Runtime information (commit id and server name).
62 | """
63 | config = {}
64 | config["commit"] = fetch_commit_id()
65 | config["server"] = get_server_name()
66 | return config
67 |
68 |
69 | def fetch_commit_id() -> str:
70 | """Get the latest commit ID of the repository.
71 |
72 | Returns:
73 | str: commit hash
74 | """
75 | label = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip()
76 | return label.decode("utf-8")
77 |
78 |
79 | def get_server_name() -> str:
80 | """Always returns `unknown` for the public code :)
81 |
82 | Returns:
83 | str: Server name.
84 | """
85 | return "unknown"
86 |
87 |
88 | def profile(output_file=None, sort_by="cumulative", lines_to_print=None, strip_dirs=False):
89 | """A time profiler decorator.
90 | Inspired by: http://code.activestate.com/recipes/577817-profile-decorator/
91 |
92 | Usage:
93 | ```
94 | @profile(output_file= ...)
95 | def your_function():
96 | ...
97 | ```
98 | Then you will get the profile automatically after the function call is finished.
99 | 100 | Args: 101 | output_file: str or None. Default is None 102 | Path of the output file. If only name of the file is given, it's 103 | saved in the current directory. 104 | If it's None, the name of the decorated function is used. 105 | sort_by: str or SortKey enum or tuple/list of str/SortKey enum 106 | Sorting criteria for the Stats object. 107 | For a list of valid string and SortKey refer to: 108 | https://docs.python.org/3/library/profile.html#pstats.Stats.sort_stats 109 | lines_to_print: int or None 110 | Number of lines to print. Default (None) is for all the lines. 111 | This is useful in reducing the size of the printout, especially 112 | that sorting by 'cumulative', the time consuming operations 113 | are printed toward the top of the file. 114 | strip_dirs: bool 115 | Whether to remove the leading path info from file names. 116 | This is also useful in reducing the size of the printout 117 | Returns: 118 | Profile of the decorated function 119 | """ 120 | 121 | def inner(func): 122 | @wraps(func) 123 | def wrapper(*args, **kwargs): 124 | _output_file = output_file or func.__name__ + ".prof" 125 | pr = cProfile.Profile() 126 | pr.enable() 127 | retval = func(*args, **kwargs) 128 | pr.disable() 129 | pr.dump_stats(_output_file) 130 | 131 | with open(_output_file, "w") as f: 132 | ps = pstats.Stats(pr, stream=f) 133 | if strip_dirs: 134 | ps.strip_dirs() 135 | if isinstance(sort_by, (tuple, list)): 136 | ps.sort_stats(*sort_by) 137 | else: 138 | ps.sort_stats(sort_by) 139 | ps.print_stats(lines_to_print) 140 | return retval 141 | 142 | return wrapper 143 | 144 | return inner 145 | 146 | 147 | class SingleThreadInMemoryStorage(optuna.storages.InMemoryStorage): 148 | """This is faster version of in-memory storage only when the study n_jobs = 1 (single thread). 149 | Adopted from https://github.com/optuna/optuna/issues/3151 150 | 151 | Args: 152 | optuna ([type]): [description] 153 | """ 154 | 155 | def set_trial_param( 156 | self, 157 | trial_id: int, 158 | param_name: str, 159 | param_value_internal: float, 160 | distribution: optuna.distributions.BaseDistribution, 161 | ) -> None: 162 | with self._lock: 163 | trial = self._get_trial(trial_id) 164 | self.check_trial_is_updatable(trial_id, trial.state) 165 | 166 | study_id = self._trial_id_to_study_id_and_number[trial_id][0] 167 | # Check param distribution compatibility with previous trial(s). 168 | if param_name in self._studies[study_id].param_distribution: 169 | optuna.distributions.check_distribution_compatibility( 170 | self._studies[study_id].param_distribution[param_name], distribution 171 | ) 172 | # Set param distribution. 173 | self._studies[study_id].param_distribution[param_name] = distribution 174 | 175 | # Set param. 176 | trial.params[param_name] = distribution.to_external_repr(param_value_internal) 177 | trial.distributions[param_name] = distribution 178 | -------------------------------------------------------------------------------- /src/utils/stat_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | 4 | import numpy as np 5 | import scipy 6 | import scipy.fftpack 7 | import torch 8 | from torch import nn 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class SobelTorch(nn.Module): 14 | """Sobel operator for pytorch, for divergence calculation. 
15 | This is an equivalent implementation of
16 | ```
17 | sobelx = cv2.Sobel(flow[0], cv2.CV_64F, 1, 0, ksize=3)
18 | sobely = cv2.Sobel(flow[1], cv2.CV_64F, 0, 1, ksize=3)
19 | dxy = (sobelx + sobely) / 8.0
20 | ```
21 | Args:
22 | ksize (int) ... Kernel size of the convolution operation.
23 | in_channels (int) ... Input channels.
24 | cuda_available (bool) ... True if cuda is available.
25 | """
26 |
27 | def __init__(
28 | self, ksize: int = 3, in_channels: int = 2, cuda_available: bool = False, precision="32"
29 | ):
30 | super().__init__()
31 | self.cuda_available = cuda_available
32 | self.in_channels = in_channels
33 | self.filter_dx = nn.Conv2d(
34 | in_channels=in_channels,
35 | out_channels=1,
36 | kernel_size=ksize,
37 | stride=1,
38 | padding=1,
39 | bias=False,
40 | )
41 | self.filter_dy = nn.Conv2d(
42 | in_channels=in_channels,
43 | out_channels=1,
44 | kernel_size=ksize,
45 | stride=1,
46 | padding=1,
47 | bias=False,
48 | )
49 | # x in height direction
50 | if precision == "64":
51 | Gx = torch.tensor([[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]]).double()
52 | Gy = torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]).double()
53 | else:
54 | Gx = torch.tensor([[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]])
55 | Gy = torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]])
56 |
57 | if self.cuda_available:
58 | Gx = Gx.cuda()
59 | Gy = Gy.cuda()
60 |
61 | self.filter_dx.weight = nn.Parameter(Gx.unsqueeze(0).unsqueeze(0), requires_grad=False)
62 | self.filter_dy.weight = nn.Parameter(Gy.unsqueeze(0).unsqueeze(0), requires_grad=False)
63 |
64 | def forward(self, img):
65 | """
66 | Args:
67 | img (torch.Tensor) ... [b x (2 or 1) x H x W]. The 2 ch is [h, w] direction.
68 |
69 | Returns:
70 | sobel (torch.Tensor) ... [b x (4 or 2) x H x W] (padding=1 keeps the spatial size).
71 | 4ch means Sobel_x on xdim, Sobel_y on ydim, Sobel_x on ydim, and Sobel_y on xdim.
72 | To make it divergence, run `(sobel[:, 0] + sobel[:, 1]) / 8.0`.
73 | """
74 | if self.in_channels == 2:
75 | dxx = self.filter_dx(img[..., [0], :, :])
76 | dyy = self.filter_dy(img[..., [1], :, :])
77 | dyx = self.filter_dx(img[..., [1], :, :])
78 | dxy = self.filter_dy(img[..., [0], :, :])
79 | return torch.cat([dxx, dyy, dyx, dxy], dim=1)
80 | elif self.in_channels == 1:
81 | dx = self.filter_dx(img[..., [0], :, :])
82 | dy = self.filter_dy(img[..., [0], :, :])
83 | return torch.cat([dx, dy], dim=1)
84 |
--------------------------------------------------------------------------------
/src/visualizer.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from typing import Any, Dict, List, Optional
4 |
5 | import cv2
6 | import numpy as np
7 | import plotly.graph_objects as go
8 | from matplotlib import pyplot as plt
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | from PIL import Image, ImageDraw
14 |
15 | from . import event_image_converter, types, warp
16 |
17 | TRANSPARENCY = 0.25 # Degree of transparency, 0-100%
18 | OPACITY = int(255 * TRANSPARENCY)
19 |
20 |
21 | class Visualizer:
22 | """Visualization class with multiple utilities. It includes visualization of
23 | - Events (polarity-based or event-based, 2D or 3D, etc...)
24 | - Images
25 | - Optical flow
26 | - Optimization history, loss function
27 | - Matplotlib figure
28 | etc.
29 | It also handles generically whether the figures are shown and/or saved.
30 |
31 | Args:
32 | image_shape (tuple) ... [H, W]. Image shape is necessary to visualize events.
33 | show (bool) ... If True, it shows the visualization results whenever any function is called.
34 | save (bool) ... If True, it saves the results under `save_dir` without any duplication.
35 | save_dir (str) ... Applicable when `save` is True. The root directory for the save.
36 |
37 | """
38 |
39 | def __init__(self, image_shape: tuple, show=False, save=False, save_dir=None) -> None:
40 | super().__init__()
41 | self.update_image_shape(image_shape)
42 | self._show = show
43 | self._save = save
44 | if save_dir is None:
45 | save_dir = "./"
46 | self.update_save_dir(save_dir)
47 | self.default_prefix = "" # default file prefix
48 | self.default_save_count = 0 # default save count
49 | self.prefixed_save_count: Dict[str, int] = {}
50 |
51 | def update_image_shape(self, image_shape):
52 | self._image_size = image_shape # H, W
53 | self._image_height = image_shape[0]
54 | self._image_width = image_shape[1]
55 | self.imager = event_image_converter.EventImageConverter(image_shape)
56 |
57 | def update_save_dir(self, new_dir: str) -> None:
58 | """Update the save directory. Creates it if it does not exist.
59 |
60 | Args:
61 | new_dir (str): New directory
62 | """
63 | self.save_dir = new_dir
64 | if not os.path.exists(self.save_dir):
65 | os.makedirs(self.save_dir)
66 |
67 | def get_filename_from_prefix(
68 | self, prefix: Optional[str] = None, file_format: str = "png"
69 | ) -> str:
70 | """Helper function: returns the expected filename from the prefix.
71 | It makes sure the output filename is saved without any duplication.
72 |
73 | Args:
74 | prefix (Optional[str], optional): Prefix. Defaults to None.
75 | file_format (str) ... File format. Defaults to png.
76 |
77 | Returns:
78 | str: ${save_dir}/{prefix}{count}.png. Count automatically goes up.
79 | """
80 | if prefix is None or prefix == "":
81 | file_name = os.path.join(
82 | self.save_dir, f"{self.default_prefix}{self.default_save_count}.{file_format}"
83 | )
84 | self.default_save_count += 1
85 | else:
86 | try:
87 | self.prefixed_save_count[prefix] += 1
88 | except KeyError:
89 | self.prefixed_save_count[prefix] = 0
90 | file_name = os.path.join(
91 | self.save_dir, f"{prefix}{self.prefixed_save_count[prefix]}.{file_format}"
92 | )
93 | return file_name
94 |
95 | def rollback_save_count(self, prefix: Optional[str] = None):
96 | """Helper function:
97 | # hack - needs a consistent number between .png and .npy
98 |
99 | Args:
100 | prefix (Optional[str], optional): Prefix. Defaults to None.
101 | """
102 | if prefix is None or prefix == "":
103 | self.default_save_count -= 1
104 | else:
105 | try:
106 | self.prefixed_save_count[prefix] -= 1
107 | except KeyError:
108 | raise ValueError("The visualization save count error")
109 |
110 | def reset_save_count(self, file_prefix: Optional[str] = None):
111 | if file_prefix is None or file_prefix == "":
112 | self.default_save_count = 0
113 | elif file_prefix == "all":
114 | self.default_save_count = 0
115 | self.prefixed_save_count = {}
116 | else:
117 | del self.prefixed_save_count[file_prefix]
118 |
119 | def _show_or_save_image(
120 | self, image: Any, file_prefix: Optional[str] = None, fixed_file_name: Optional[str] = None
121 | ):
122 | """Helper function - save and/or show the image.
123 |
124 | Args:
125 | image (Any): PIL.Image
126 | file_prefix (Optional[str], optional): File prefix. Defaults to None.
127 | If specified, the save location will be `save_dir/{prefix}_{unique}.png`.
128 | """ 129 | if self._show: 130 | if image.mode == "RGBA": 131 | image = image.convert("RGB") # Back to RGB 132 | image.show() 133 | if self._save: 134 | if image.mode == "RGBA": 135 | image = image.convert("RGB") # Back to RGB 136 | if fixed_file_name is not None: 137 | image.save(os.path.join(self.save_dir, f"{fixed_file_name}.png")) 138 | else: 139 | image.save(self.get_filename_from_prefix(file_prefix)) 140 | 141 | # Image related 142 | def load_image(self, image: Any) -> Image.Image: 143 | """A wrapper function to get image and returns PIL Image object. 144 | 145 | Args: 146 | image (str or np.ndarray): If it is str, open and load the image. 147 | If it is numpy array, it converts to PIL.Image. 148 | 149 | Returns: 150 | Image.Image: PIl Image object. 151 | """ 152 | if type(image) == str: 153 | image = Image.open(image) 154 | elif type(image) == np.ndarray: 155 | image = Image.fromarray(image) 156 | return image 157 | 158 | def visualize_image(self, image: Any, file_prefix: Optional[str] = None) -> Image.Image: 159 | """Visualize image. 160 | 161 | Args: 162 | image (Any): str, np.ndarray, or PIL Image. 163 | file_prefix (Optional[str], optional): [description]. Defaults to None. 164 | If specified, the save location will be `save_dir/{prefix}_{unique}.png`. 165 | 166 | Returns: 167 | Image.Image: PIL Image object 168 | """ 169 | image = self.load_image(image) 170 | self._show_or_save_image(image, file_prefix) 171 | return image 172 | 173 | def create_clipped_iwe_for_visualization(self, events, max_scale=50): 174 | """Utility function for clipped IWE. Same one in solver. 175 | 176 | Args: 177 | events (_type_): _description_ 178 | max_scale (int, optional): _description_. Defaults to 50. 179 | 180 | Returns: 181 | _type_: _description_ 182 | """ 183 | im = self.imager.create_image_from_events_numpy(events, method="bilinear_vote", sigma=0) 184 | clipped_iwe = 255 - np.clip(max_scale * im, 0, 255).astype(np.uint8) 185 | return clipped_iwe 186 | 187 | # Optical flow 188 | def visualize_optical_flow( 189 | self, 190 | flow_x: np.ndarray, 191 | flow_y: np.ndarray, 192 | visualize_color_wheel: bool = True, 193 | file_prefix: Optional[str] = None, 194 | save_flow: bool = False, 195 | ord: float = 0.5, 196 | ): 197 | """Visualize optical flow. 198 | Args: 199 | flow_x (numpy.ndarray) ... [H x W], height direction. 200 | flow_y (numpy.ndarray) ... [H x W], width direction. 201 | visualize_color_wheel (bool) ... If True, it also visualizes the color wheel (legend for OF). 202 | file_prefix (Optional[str], optional): [description]. Defaults to None. 203 | If specified, the save location will be `save_dir/{prefix}_{unique}.png`. 204 | 205 | Returns: 206 | image (PIL.Image) ... PIL image. 
207 | """ 208 | if save_flow: 209 | save_name = self.get_filename_from_prefix(file_prefix).replace("png", "npy") 210 | np.save(save_name, np.stack([flow_x, flow_y], axis=0)) 211 | self.rollback_save_count(file_prefix) 212 | flow_rgb, color_wheel, _ = self.color_optical_flow(flow_x, flow_y, ord=ord) 213 | image = Image.fromarray(flow_rgb) 214 | self._show_or_save_image(image, file_prefix) 215 | 216 | if visualize_color_wheel: 217 | wheel = Image.fromarray(color_wheel) 218 | self._show_or_save_image(wheel, fixed_file_name="color_wheel") 219 | return image 220 | 221 | # Combined with events 222 | def visualize_overlay_optical_flow_on_event( 223 | self, 224 | flow: np.ndarray, 225 | events: np.ndarray, 226 | file_prefix: Optional[str] = None, 227 | ord: float = 0.5, 228 | ): 229 | """Visualize optical flow on event data. 230 | Args: 231 | flow (numpy.ndarray) ... [2 x H x W] 232 | events (np.ndarray) ... event_image (H x W) or raw events (n_events x 4). 233 | file_prefix (Optional[str], optional): [description]. Defaults to None. 234 | If specified, the save location will be `save_dir/{prefix}_{unique}.png`. 235 | 236 | Returns: 237 | image (PIL.Image) ... PIL image. 238 | """ 239 | _show, _save = self._show, self._save 240 | self._show, self._save = False, False 241 | flow_image = self.visualize_optical_flow(flow[0], flow[1], ord=ord) 242 | flow_ratio = 0.8 243 | flow_image.putalpha(int(255 * flow_ratio)) 244 | if events.shape[1] == 4: # raw events 245 | event_image = self.visualize_event(events, grayscale=False).convert("RGB") 246 | else: 247 | event_image = self.visualize_image(events).convert("RGB") 248 | event_image.putalpha(255 - int(255 * flow_ratio)) 249 | flow_image.paste(event_image, None, event_image) 250 | self._show, self._save = _show, _save 251 | self._show_or_save_image(flow_image, file_prefix) 252 | return flow_image 253 | 254 | def visualize_optical_flow_on_event_mask( 255 | self, 256 | flow: np.ndarray, 257 | events: np.ndarray, 258 | file_prefix: Optional[str] = None, 259 | ord: float = 0.5, 260 | max_color_on_mask: bool = True, 261 | ): 262 | """Visualize optical flow only where event exists. 263 | Args: 264 | flow (numpy.ndarray) ... [2 x H x W] 265 | events (np.ndarray) ... [n_events x 4] 266 | file_prefix (Optional[str], optional): [description]. Defaults to None. 267 | If specified, the save location will be `save_dir/{prefix}_{unique}.png`. 268 | 269 | max_color_on_mask (bool) ... If True, the max magnitude is based on the masked flow. If False, it is based on the raw (dense) flow. 270 | 271 | Returns: 272 | image (PIL.Image) ... PIL image. 
273 | """
274 | _show, _save = self._show, self._save
275 | self._show, self._save = False, False
276 | mask = self.imager.create_eventmask(events)
277 | if max_color_on_mask:
278 | masked_flow = flow * mask
279 | image = self.visualize_optical_flow(
280 | masked_flow[0],
281 | masked_flow[1],
282 | visualize_color_wheel=False,
283 | file_prefix=file_prefix,
284 | ord=ord,
285 | )
286 | else:
287 | image = self.visualize_optical_flow(
288 | flow[0], flow[1], visualize_color_wheel=False, file_prefix=file_prefix, ord=ord
289 | )
290 | mask = Image.fromarray((~mask)[0]).convert("1")
291 | white = Image.new("RGB", image.size, (255, 255, 255))
292 | masked_flow = Image.composite(white, image, mask)
293 | self._show, self._save = _show, _save
294 | self._show_or_save_image(masked_flow, file_prefix)
295 | return masked_flow
296 |
297 | def visualize_optical_flow_pred_and_gt(
298 | self,
299 | flow_pred: np.ndarray,
300 | flow_gt: np.ndarray,
301 | visualize_color_wheel: bool = True,
302 | pred_file_prefix: Optional[str] = None,
303 | gt_file_prefix: Optional[str] = None,
304 | ord: float = 0.5,
305 | ):
306 | """Visualize optical flow for both the prediction and the GT.
307 | Args:
308 | flow_pred (numpy.ndarray) ... [2 x H x W]
309 | flow_gt (numpy.ndarray) ... [2 x H x W]
310 | visualize_color_wheel (bool) ... If True, it also visualizes the color wheel (legend for OF).
311 | pred_file_prefix / gt_file_prefix (Optional[str], optional): File prefixes. Defaults to None.
312 | If specified, the save location will be `save_dir/{prefix}_{unique}.png`.
313 |
314 | Returns:
315 | image (PIL.Image) ... PIL image.
316 | """
317 | # get largest magnitude in both pred and gt
318 | _, _, max_pred = self.color_optical_flow(flow_pred[0], flow_pred[1], ord=ord)
319 | _, _, max_gt = self.color_optical_flow(flow_gt[0], flow_gt[1], ord=ord)
320 | max_magnitude = np.max([max_pred, max_gt])
321 | color_pred, _, _ = self.color_optical_flow(
322 | flow_pred[0], flow_pred[1], max_magnitude, ord=ord
323 | )
324 | color_gt, color_wheel, _ = self.color_optical_flow(
325 | flow_gt[0], flow_gt[1], max_magnitude, ord=ord
326 | )
327 |
328 | image = Image.fromarray(color_pred)
329 | self._show_or_save_image(image, pred_file_prefix)
330 | image = Image.fromarray(color_gt)
331 | self._show_or_save_image(image, gt_file_prefix)
332 | if visualize_color_wheel:
333 | wheel = Image.fromarray(color_wheel)
334 | self._show_or_save_image(wheel, fixed_file_name="color_wheel")
335 |
336 | def color_optical_flow(
337 | self, flow_x: np.ndarray, flow_y: np.ndarray, max_magnitude=None, ord=1.0
338 | ):
339 | """Color optical flow.
340 | Args:
341 | flow_x (numpy.ndarray) ... [H x W], height direction.
342 | flow_y (numpy.ndarray) ... [H x W], width direction.
343 | max_magnitude (float, optional) ... Max magnitude used for the colorization. Defaults to None.
344 | ord (float) ... 1: our usual choice, 0.5: DSEC colorizing.
345 |
346 | Returns:
347 | flow_rgb (np.ndarray) ... [H x W x 3] colored flow image.
348 | color_wheel (np.ndarray) ... [H x H x 3] color wheel legend.
349 | max_magnitude (float) ... max magnitude of the flow.
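Note: in the HSV construction below, hue encodes the flow direction via
(arctan2(flow_y, flow_x) + pi), saturation is fixed to 255, and value encodes
the magnitude ||flow||**ord normalized by max_magnitude.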
350 | """
351 | flows = np.stack((flow_x, flow_y), axis=2)
352 | flows[np.isinf(flows)] = 0
353 | flows[np.isnan(flows)] = 0
354 | mag = np.linalg.norm(flows, axis=2) ** ord
355 | ang = (np.arctan2(flow_y, flow_x) + np.pi) * 180.0 / np.pi / 2.0
356 | ang = ang.astype(np.uint8)
357 | hsv = np.zeros([flow_x.shape[0], flow_x.shape[1], 3], dtype=np.uint8)
358 | hsv[:, :, 0] = ang
359 | hsv[:, :, 1] = 255
360 | if max_magnitude is None:
361 | max_magnitude = mag.max()
362 | hsv[:, :, 2] = (255 * mag / max_magnitude).astype(np.uint8)
363 | # hsv[:, :, 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
364 | flow_rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
365 |
366 | # Color wheel
367 | hsv = np.zeros([flow_x.shape[0], flow_x.shape[0], 3], dtype=np.uint8)
368 | xx, yy = np.meshgrid(
369 | np.linspace(-1, 1, flow_x.shape[0]), np.linspace(-1, 1, flow_x.shape[0])
370 | )
371 | mag = np.linalg.norm(np.stack((xx, yy), axis=2), axis=2)
372 | # ang = (np.arctan2(yy, xx) + np.pi) * 180 / np.pi / 2.0
373 | ang = (np.arctan2(xx, yy) + np.pi) * 180 / np.pi / 2.0
374 | hsv[:, :, 0] = ang.astype(np.uint8)
375 | hsv[:, :, 1] = 255
376 | hsv[:, :, 2] = (255 * mag / mag.max()).astype(np.uint8)
377 | # hsv[:, :, 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
378 | color_wheel = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
379 |
380 | return flow_rgb, color_wheel, max_magnitude
381 |
382 | # Event related
383 | def visualize_event(
384 | self,
385 | events: Any,
386 | grayscale: bool = True,
387 | background_color: int = 127,
388 | ignore_polarity: bool = False,
389 | file_prefix: Optional[str] = None,
390 | ) -> Image.Image:
391 | """Visualize event as image.
392 | # TODO the function is messy - cleanup.
393 |
394 | Args:
395 | events (Any): [n_events x 4] event array (x, y, t, p).
396 | grayscale (bool, optional): If True, create a grayscale image. Defaults to True.
397 | background_color (int, optional): Background color where there are no events.
398 | Only effective when grayscale is True; defaults to 127. If non-grayscale, it is set to 255.
399 | ignore_polarity (bool, optional): If True, create a polarity-ignoring image. Defaults to False.
400 |
401 | Returns:
402 | Image.Image: The event visualization image.
403 | """
404 | if grayscale:
405 | image = np.ones((self._image_size[0], self._image_size[1]))
406 | else:
407 | background_color = 255
408 | image = (
409 | np.ones((self._image_size[0], self._image_size[1], 3), dtype=np.uint8)
410 | * background_color
411 | ) # RGBA channel
412 |
413 | # events = events[0 <= events[:, 0] < self._image_size[0]]
414 | # events = events[events[:, 1] < self._image_size[1]]
415 | events[:, 0] = np.clip(events[:, 0], 0, self._image_size[0] - 1)
416 | events[:, 1] = np.clip(events[:, 1], 0, self._image_size[1] - 1)
417 | if grayscale:
418 | indice = (events[:, 0].astype(np.int32), events[:, 1].astype(np.int32))
419 | if ignore_polarity:
420 | np.add.at(image, indice, np.ones_like(events[:, 3], dtype=np.int16))
421 | else:
422 | if np.min(events[:, 3]) == 0:
423 | pol = events[:, 3] * 2 - 1
424 | else:
425 | pol = events[:, 3]
426 | np.add.at(image, indice, pol)
427 | return self.visualize_event_image(image, background_color)
428 | else:
429 | colors = np.array([(255, 0, 0) if e[3] == 1 else (0, 0, 255) for e in events])
430 | image[events[:, 0].astype(np.int32), events[:, 1].astype(np.int32), :] = colors
431 |
432 | image = Image.fromarray(image)
433 | self._show_or_save_image(image, file_prefix)
434 | return image
435 |
436 | def save_array(
437 | self,
438 | array: np.ndarray,
439 | file_prefix: Optional[str] = None,
440 | new_prefix: bool = False,
441 | ) -> None:
442 | """Helper function to save a numpy array. It belongs to this visualizer class
443 | because it follows the naming rule of the visualized files.
444 |
445 | Args:
446 | array (np.ndarray): Numpy array to save.
447 | file_prefix (Optional[str]): Prefix of the file. Defaults to None.
448 | new_prefix (bool): If True, rollback_save_count is skipped. Set to True if
449 | there is no corresponding .png file with the prefix. Defaults to False.
450 |
451 | Returns:
452 | None
453 | """
454 | save_name = self.get_filename_from_prefix(file_prefix).replace("png", "npy")
455 | np.save(save_name, array)
456 | if not new_prefix:
457 | self.rollback_save_count(file_prefix)
458 |
459 | def visualize_event_image(
460 | self, eventimage: np.ndarray, background_color: int = 255, file_prefix: Optional[str] = None
461 | ) -> Image.Image:
462 | """Visualize event on white image"""
463 | # max_value_abs = np.max(np.abs(eventimage))
464 | # eventimage = (255 * eventimage / (2 * max_value_abs) + 127).astype(np.uint8)
465 | background = eventimage == 0
466 | eventimage = (
467 | 255 * (eventimage - eventimage.min()) / (eventimage.max() - eventimage.min())
468 | ).astype(np.uint8)
469 | if background_color == 255:
470 | eventimage = 255 - eventimage
471 | else:
472 | eventimage[background] = background_color
473 | eventimage = Image.fromarray(eventimage)
474 | self._show_or_save_image(eventimage, file_prefix)
475 | return eventimage
476 |
477 | # Scipy history visualizer
478 | def visualize_scipy_history(self, cost_history: dict, cost_weight: Optional[dict] = None):
479 | """Visualizing scipy optimization history.
480 | 481 | Args: 482 | cost_history (dict): [description] 483 | """ 484 | plt.figure() 485 | for k in cost_history.keys(): 486 | if k == "loss" or cost_weight is None: 487 | plt.plot(np.array(cost_history[k]), label=k) 488 | else: 489 | plt.plot(np.array(cost_history[k]) * cost_weight[k], label=k) 490 | plt.legend() 491 | if self._save: 492 | plt.savefig(self.get_filename_from_prefix("optimization_steps")) 493 | if self._show: 494 | plt.show(block=False) 495 | plt.close() 496 | -------------------------------------------------------------------------------- /tests/costs/test_gradient_magnitude.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | 5 | from src import utils 6 | from src.costs import GradientMagnitude 7 | from src.event_image_converter import EventImageConverter 8 | from src.warp import Warp 9 | 10 | 11 | # minimum is different 12 | def test_calculate_store_history(): 13 | size = (260, 346) 14 | events = utils.generate_events(1000, size[0], size[1], tmin=0.1, tmax=0.9).astype(np.float32) 15 | events_torch = torch.from_numpy(events) 16 | imager = EventImageConverter(size) 17 | 18 | cost = GradientMagnitude(direction="minimize", store_history=True) 19 | # Calculate numpy 20 | count = imager.create_image_from_events_numpy(events, "bilinear_vote", weight=1.0, sigma=0) 21 | _ = cost.calculate({"iwe": count, "omit_boundary": True}) 22 | 23 | # Calculate torch 24 | count = imager.create_image_from_events_tensor(events_torch, "bilinear_vote", weight=1.0) 25 | _ = cost.calculate({"iwe": count, "omit_boundary": True}) 26 | 27 | history = cost.get_history() 28 | assert len(history["loss"]) == 2 29 | 30 | 31 | def test_calculate_not_store_history(): 32 | size = (260, 346) 33 | events = utils.generate_events(1000, size[0], size[1], tmin=0.1, tmax=0.9).astype(np.float32) 34 | events_torch = torch.from_numpy(events) 35 | imager = EventImageConverter(size) 36 | 37 | cost = GradientMagnitude(direction="minimize", store_history=False) 38 | # Calculate numpy 39 | count = imager.create_image_from_events_numpy(events, "bilinear_vote", weight=1.0, sigma=0) 40 | _ = cost.calculate({"iwe": count, "omit_boundary": True}) 41 | 42 | # Calculate torch 43 | count = imager.create_image_from_events_tensor(events_torch, "bilinear_vote", weight=1.0) 44 | _ = cost.calculate({"iwe": count, "omit_boundary": True}) 45 | 46 | history = cost.get_history() 47 | assert len(history["loss"]) == 0 48 | 49 | 50 | @pytest.mark.parametrize( 51 | "direction,is_small", 52 | [["natural", True], ["minimize", False], ["maximize", True]], 53 | ) 54 | def test_calculate_blur_is_small(direction, is_small): 55 | size = (10, 40) 56 | imager = EventImageConverter(size) 57 | cost = GradientMagnitude(direction=direction, store_history=False) 58 | 59 | events = np.array( 60 | [ 61 | [5.0, 10.0], 62 | [8.0, 3.0], 63 | [2.0, 2.0], 64 | ] 65 | ) 66 | grad_blur = cost.calculate({"iwe": imager.create_iwe(events), "omit_boundary": False}) 67 | events = np.array( 68 | [ 69 | [5.0, 10.0], 70 | [5.0, 10.0], 71 | [2.0, 2.0], 72 | ] 73 | ) 74 | grad_sharp = cost.calculate({"iwe": imager.create_iwe(events), "omit_boundary": False}) 75 | 76 | assert (grad_blur < grad_sharp) == is_small 77 | 78 | 79 | def test_calculate_np_torch(): 80 | size = (10, 20) 81 | imager = EventImageConverter(size) 82 | cost = GradientMagnitude(direction="natural", store_history=False) 83 | 84 | events = np.array( 85 | [ 86 | [12.0, 10.0], 87 | [8.0, 3.0], 88 | [2.0, 2.0], 89 | ] 
90 | ).astype(np.float32) 91 | events_torch = torch.from_numpy(events) 92 | var_blur_numpy = cost.calculate( 93 | {"iwe": imager.create_iwe(events, sigma=0), "omit_boundary": True} 94 | ) 95 | var_blur_torch = cost.calculate( 96 | {"iwe": imager.create_iwe(events_torch, sigma=0), "omit_boundary": True} 97 | ) 98 | 99 | np.testing.assert_allclose(var_blur_torch.item(), var_blur_numpy, rtol=1e-2) 100 | -------------------------------------------------------------------------------- /tests/costs/test_hybrid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | from PIL.Image import Image 5 | 6 | from src import utils 7 | from src.costs import HybridCost, ImageVariance 8 | from src.event_image_converter import EventImageConverter 9 | from src.warp import Warp 10 | 11 | 12 | # minimum is different 13 | def test_hybrid_cost_store_history(): 14 | size = (20, 34) 15 | imager = EventImageConverter(size) 16 | cost_with_weight = {"image_variance": 1.0, "gradient_magnitude": 2.4} 17 | 18 | cost = HybridCost(direction="minimize", cost_with_weight=cost_with_weight, store_history=True) 19 | variance = ImageVariance(store_history=True) 20 | 21 | # Calculate numpy 22 | events = utils.generate_events(1000, size[0], size[1]) 23 | count = imager.create_image_from_events_numpy(events, sigma=0) 24 | _ = cost.calculate({"iwe": count, "omit_boundary": True}) 25 | _ = variance.calculate({"iwe": count, "omit_boundary": True}) 26 | events = utils.generate_events(1000, size[0], size[1]) 27 | count = imager.create_image_from_events_numpy(events, sigma=0) 28 | _ = cost.calculate({"iwe": count, "omit_boundary": True}) 29 | _ = variance.calculate({"iwe": count, "omit_boundary": True}) 30 | events = utils.generate_events(1000, size[0], size[1]) 31 | count = imager.create_image_from_events_numpy(events, sigma=0) 32 | _ = cost.calculate({"iwe": count, "omit_boundary": True}) 33 | _ = variance.calculate({"iwe": count, "omit_boundary": True}) 34 | 35 | history = cost.get_history() 36 | keys = ["loss", "image_variance", "gradient_magnitude"] 37 | assert history.keys() == set(keys) 38 | for k in keys: 39 | assert len(history[k]) == 3 40 | np.testing.assert_allclose( 41 | history["image_variance"], variance.get_history()["loss"], rtol=1e-5, atol=1e-5 42 | ) 43 | 44 | 45 | def test_hybrid_cost_without_store_history(): 46 | size = (20, 34) 47 | imager = EventImageConverter(size) 48 | cost_with_weight = {"image_variance": 1.0, "gradient_magnitude": 2.4} 49 | 50 | cost = HybridCost(direction="minimize", cost_with_weight=cost_with_weight, store_history=False) 51 | 52 | # Calculate numpy 53 | events = utils.generate_events(1000, size[0], size[1]) 54 | count = imager.create_image_from_events_numpy(events, "bilinear_vote", sigma=0) 55 | _ = cost.calculate({"iwe": count, "omit_boundary": True}) 56 | events = utils.generate_events(1000, size[0], size[1]) 57 | count = imager.create_image_from_events_numpy(events, "bilinear_vote", sigma=0) 58 | _ = cost.calculate({"iwe": count, "omit_boundary": True}) 59 | events = utils.generate_events(1000, size[0], size[1]) 60 | count = imager.create_image_from_events_numpy(events, "bilinear_vote", sigma=0) 61 | _ = cost.calculate({"iwe": count, "omit_boundary": True}) 62 | 63 | history = cost.get_history() 64 | keys = ["loss", "image_variance", "gradient_magnitude"] 65 | assert history.keys() == set(keys) 66 | for k in keys: 67 | assert len(history[k]) == 0 68 | 
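

# --- Added sketch (not part of the original test suite) ---
# The `cost_with_weight` dict used above suggests that the hybrid loss combines
# the sub-costs as a weighted sum. The helper below is hypothetical (it is not
# part of src.costs); it only illustrates that assumed contract with plain Python.
def _combine_costs_sketch(cost_values: dict, cost_with_weight: dict) -> float:
    # e.g. cost_values = {"image_variance": 2.0, "gradient_magnitude": 0.5}
    return sum(cost_with_weight[k] * cost_values[k] for k in cost_with_weight)


def test_combine_costs_sketch():
    combined = _combine_costs_sketch(
        {"image_variance": 2.0, "gradient_magnitude": 0.5},
        {"image_variance": 1.0, "gradient_magnitude": 2.4},
    )
    assert combined == pytest.approx(2.0 * 1.0 + 0.5 * 2.4)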
-------------------------------------------------------------------------------- /tests/costs/test_image_variance.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | 5 | from src import utils 6 | from src.costs import ImageVariance 7 | from src.event_image_converter import EventImageConverter 8 | from src.warp import Warp 9 | 10 | 11 | # minimum is different 12 | def test_calculate_store_history(): 13 | size = (260, 346) 14 | events = utils.generate_events(1000, size[0], size[1], tmin=0.1, tmax=0.9) 15 | events_torch = torch.from_numpy(events) 16 | imager = EventImageConverter(size) 17 | 18 | cost = ImageVariance(direction="minimize", store_history=True) 19 | # Calculate numpy 20 | count = imager.create_image_from_events_numpy(events, "bilinear_vote", weight=1.0, sigma=0) 21 | _ = cost.calculate({"iwe": count, "omit_boundary": True}) 22 | 23 | # Calculate torch 24 | count = imager.create_image_from_events_tensor(events_torch, "bilinear_vote", weight=1.0) 25 | _ = cost.calculate({"iwe": count, "omit_boundary": True}) 26 | 27 | history = cost.get_history() 28 | assert len(history["loss"]) == 2 29 | np.testing.assert_allclose(history["loss"][0], history["loss"][1], rtol=1e-5, atol=1e-5) 30 | 31 | 32 | def test_calculate_not_store_history(): 33 | size = (260, 346) 34 | events = utils.generate_events(1000, size[0], size[1], tmin=0.1, tmax=0.9) 35 | events_torch = torch.from_numpy(events) 36 | imager = EventImageConverter(size) 37 | 38 | cost = ImageVariance(direction="minimize", store_history=False) 39 | # Calculate numpy 40 | count = imager.create_image_from_events_numpy(events, "bilinear_vote", weight=1.0, sigma=0) 41 | _ = cost.calculate({"iwe": count, "omit_boundary": True}) 42 | 43 | # Calculate torch 44 | count = imager.create_image_from_events_tensor(events_torch, "bilinear_vote", weight=1.0) 45 | _ = cost.calculate({"iwe": count, "omit_boundary": True}) 46 | 47 | history = cost.get_history() 48 | assert len(history["loss"]) == 0 49 | 50 | 51 | @pytest.mark.parametrize( 52 | "direction,is_small", 53 | [["natural", True], ["minimize", False], ["maximize", True]], 54 | ) 55 | def test_calculate_blur_is_small(direction, is_small): 56 | size = (10, 40) 57 | imager = EventImageConverter(size) 58 | cost = ImageVariance(direction=direction, store_history=False) 59 | 60 | events = np.array( 61 | [ 62 | [5.0, 10.0], 63 | [8.0, 3.0], 64 | [2.0, 2.0], 65 | ] 66 | ) 67 | var_blur = cost.calculate({"iwe": imager.create_iwe(events), "omit_boundary": False}) 68 | events = np.array( 69 | [ 70 | [5.0, 10.0], 71 | [5.0, 10.0], 72 | [2.0, 2.0], 73 | ] 74 | ) 75 | var_sharp = cost.calculate({"iwe": imager.create_iwe(events), "omit_boundary": False}) 76 | 77 | assert (var_blur < var_sharp) == is_small 78 | 79 | 80 | def test_calculate_np_torch(): 81 | size = (10, 20) 82 | imager = EventImageConverter(size) 83 | cost = ImageVariance(direction="natural", store_history=False) 84 | 85 | events = np.array( 86 | [ 87 | [12.0, 10.0], 88 | [8.0, 3.0], 89 | [2.0, 2.0], 90 | ] 91 | ).astype(np.float64) 92 | events_torch = torch.from_numpy(events) 93 | var_blur_numpy = cost.calculate( 94 | {"iwe": imager.create_iwe(events, sigma=0), "omit_boundary": True} 95 | ) 96 | var_blur_torch = cost.calculate( 97 | {"iwe": imager.create_iwe(events_torch, sigma=0), "omit_boundary": True} 98 | ) 99 | 100 | np.testing.assert_allclose(var_blur_torch.item(), var_blur_numpy, rtol=1e-2) 101 | 
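The parametrized blur test above fixes the sign convention of these costs: the raw IWE variance is larger for a sharper image (events concentrated on fewer pixels), `natural` and `maximize` report that raw value, and `minimize` flips the sign so that a generic minimizer drives the warp toward sharper images. A rough numpy-only sketch of that behavior, assuming plain whole-image variance (the actual `ImageVariance` additionally handles torch tensors and the `omit_boundary` flag):

```python
import numpy as np


def image_variance_loss(iwe: np.ndarray, direction: str = "natural") -> float:
    # Contrast of the image of warped events (IWE): Var(I) = mean((I - mean(I))^2).
    var = float(np.mean((iwe - iwe.mean()) ** 2))
    # Sharper IWEs have larger variance; "minimize" negates the value so that
    # minimizing the loss maximizes contrast. "natural"/"maximize" keep it raw.
    return -var if direction == "minimize" else var
```

This reproduces the parametrization `[["natural", True], ["minimize", False], ["maximize", True]]`: once the sign is flipped, the blurred image no longer scores lower than the sharp one.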
-------------------------------------------------------------------------------- /tests/test_event_image_converter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from src import event_image_converter, utils 5 | 6 | 7 | def test_create_iwe(): 8 |     image_shape = (100, 200) 9 |     imager = event_image_converter.EventImageConverter(image_shape) 10 |     events = np.array( 11 |         [utils.generate_events(100, image_shape[0] - 1, image_shape[1] - 1) for _ in range(4)] 12 |     ) 13 |     iwe = imager.create_iwe(events) 14 |     assert iwe.shape == (4, 100, 200) 15 | 16 | 17 | def test_bilinear_vote_integer(): 18 |     image_shape = (3, 4) 19 |     imager = event_image_converter.EventImageConverter(image_shape) 20 | 21 |     events = np.array( 22 |         [ 23 |             [1.0, 2], 24 |             [0, 1], 25 |             [1, 0], 26 |         ] 27 |     ) 28 |     weights = np.array([1, 2, 0.8]) 29 |     img = imager.bilinear_vote_numpy(events, weight=weights) 30 |     expected = np.array( 31 |         [ 32 |             [0, 2, 0, 0], 33 |             [0.8, 0, 1, 0], 34 |             [0, 0, 0, 0], 35 |         ] 36 |     ) 37 |     # Numpy 38 |     np.testing.assert_array_equal(img, expected) 39 | 40 |     # Torch 41 |     img = imager.bilinear_vote_tensor(torch.from_numpy(events), weight=torch.from_numpy(weights)) 42 |     assert torch.allclose(img, torch.from_numpy(expected)) 43 | 44 | 45 | def test_bilinear_vote_float(): 46 |     image_shape = (3, 4) 47 |     imager = event_image_converter.EventImageConverter(image_shape) 48 |     events = np.array( 49 |         [ 50 |             [1.2, 2], 51 |             [0, 1.9], 52 |             [0.5, 0.6], 53 |         ] 54 |     ) 55 |     weights = np.array([-1.0, 1.0, 1.5]) 56 |     img = imager.bilinear_vote_numpy(events, weight=weights)  # each sub-pixel event splats its weight onto the 4 neighbors, e.g. (0.5, 0.6) with weight 1.5 adds 1.5*0.5*0.4 = 0.3 to pixel (0, 0) and 1.5*0.5*0.6 = 0.45 to (1, 1) 57 |     expected = np.array( 58 |         [ 59 |             [0.3, 0.55, 0.9, 0], 60 |             [0.3, 0.45, -0.8, 0], 61 |             [0, 0, -0.2, 0], 62 |         ] 63 |     ) 64 |     # Numpy 65 |     np.testing.assert_allclose(img, expected) 66 | 67 |     # Torch 68 |     img = imager.bilinear_vote_tensor(torch.from_numpy(events), weight=torch.from_numpy(weights)) 69 |     assert torch.allclose(img, torch.from_numpy(expected)) 70 | 71 | 72 | def test_bilinear_vote_batch(): 73 |     image_shape = (3, 4) 74 |     imager = event_image_converter.EventImageConverter(image_shape) 75 |     events = np.array( 76 |         [ 77 |             [ 78 |                 [1, 2], 79 |                 [0, 1], 80 |                 [1, 0], 81 |             ], 82 |             [ 83 |                 [1.2, 2], 84 |                 [0, 1.9], 85 |                 [0.5, 0.6], 86 |             ], 87 |         ] 88 |     ) 89 |     weights = np.array([[1.0, 2.0, 0.8], [-1.0, 1.0, 1.5]]) 90 |     img = imager.bilinear_vote_numpy(events, weight=weights) 91 |     expected = np.array( 92 |         [ 93 |             [ 94 |                 [0, 2, 0, 0], 95 |                 [0.8, 0, 1, 0], 96 |                 [0, 0, 0, 0], 97 |             ], 98 |             [ 99 |                 [0.3, 0.55, 0.9, 0], 100 |                 [0.3, 0.45, -0.8, 0], 101 |                 [0, 0, -0.2, 0], 102 |             ], 103 |         ] 104 |     ) 105 |     # Numpy 106 |     np.testing.assert_allclose(img, expected) 107 | 108 |     # Torch 109 |     img = imager.bilinear_vote_tensor(torch.from_numpy(events), weight=torch.from_numpy(weights)) 110 |     assert torch.allclose(img, torch.from_numpy(expected)) 111 | 112 | 113 | def test_bilinear_vote_accuracy(): 114 |     image_shape = (10, 20) 115 |     imager = event_image_converter.EventImageConverter(image_shape) 116 |     events = utils.generate_events(100, image_shape[0] - 1, image_shape[1] - 1) 117 |     events += np.random.rand(100)[..., None] 118 |     events = events.astype(np.float32) 119 |     img = imager.bilinear_vote_numpy(events, weight=1.0) 120 | 121 |     img_t = imager.bilinear_vote_tensor(torch.from_numpy(events), weight=1.0) 122 |     np.testing.assert_allclose(img, img_t.numpy(), rtol=1e-5) 123 | -------------------------------------------------------------------------------- /tests/test_warp.py:
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | 5 | from src import utils, warp 6 | 7 | 8 | @pytest.mark.parametrize( 9 | "model,size", 10 | [["2d-translation", 2]], 11 | ) 12 | def test_get_motion_vector_size(model, size): 13 | warper = warp.Warp((100, 200), normalize_t=True) 14 | assert size == warper.get_motion_vector_size(model) 15 | 16 | 17 | def test_calculate_dt_normalize(): 18 | image_size = (100, 200) 19 | warper = warp.Warp(image_size, normalize_t=True) 20 | 21 | events = utils.generate_events(300, image_size[0], image_size[1], tmin=1, tmax=2) 22 | dt = warper.calculate_dt(events, 1.0) 23 | np.testing.assert_allclose(dt.min(), 0.0, rtol=1e-2, atol=0.1) 24 | np.testing.assert_allclose(dt.max(), 1.0, rtol=1e-2, atol=0.1) 25 | 26 | events = utils.generate_events(300, image_size[0], image_size[1], tmin=0, tmax=0.5) 27 | dt = warper.calculate_dt(events, 0) 28 | np.testing.assert_allclose(dt.min(), 0.0, rtol=1e-2, atol=0.1) 29 | np.testing.assert_allclose(dt.max(), 1.0, rtol=1e-2, atol=0.1) 30 | 31 | events = utils.generate_events(300, image_size[0], image_size[1], tmin=-1, tmax=1) 32 | dt = warper.calculate_dt(events, 0) 33 | np.testing.assert_allclose(dt.min(), -0.5, rtol=1e-2, atol=0.1) 34 | np.testing.assert_allclose(dt.max(), 0.5, rtol=1e-2, atol=0.1) 35 | 36 | events = utils.generate_events(300, image_size[0], image_size[1], tmin=-1, tmax=1) 37 | dt = warper.calculate_dt(events, -1) 38 | np.testing.assert_allclose(dt.min(), 0.0, rtol=1e-2, atol=0.1) 39 | np.testing.assert_allclose(dt.max(), 1.0, rtol=1e-2, atol=0.1) 40 | 41 | 42 | def test_calculate_dt_non_normalize(): 43 | image_size = (10, 20) 44 | warper = warp.Warp(image_size, normalize_t=False) 45 | events = utils.generate_events(300, image_size[0], image_size[1], tmin=1, tmax=2) 46 | dt = warper.calculate_dt(events, 1.0) 47 | np.testing.assert_allclose(dt.max(), 1.0, rtol=1e-2, atol=0.1) 48 | np.testing.assert_allclose(dt.min(), 0.0, rtol=1e-2, atol=0.1) 49 | 50 | events = utils.generate_events(300, image_size[0], image_size[1], tmin=0, tmax=0.5) 51 | dt = warper.calculate_dt(events, 0) 52 | np.testing.assert_allclose(dt.max(), 0.5, rtol=1e-2, atol=0.1) 53 | np.testing.assert_allclose(dt.min(), 0.0, rtol=1e-2, atol=0.1) 54 | 55 | events = utils.generate_events(300, image_size[0], image_size[1], tmin=-1, tmax=1) 56 | dt = warper.calculate_dt(events, 0) 57 | np.testing.assert_allclose(dt.max(), 1.0, rtol=1e-2, atol=0.1) 58 | np.testing.assert_allclose(dt.min(), -1.0, rtol=1e-2, atol=0.1) 59 | 60 | events = utils.generate_events(300, image_size[0], image_size[1], tmin=-1, tmax=1) 61 | dt = warper.calculate_dt(events, -1) 62 | np.testing.assert_allclose(dt.max(), 2.0, rtol=1e-2, atol=0.1) 63 | np.testing.assert_allclose(dt.min(), 0.0, rtol=1e-2, atol=0.1) 64 | 65 | 66 | def test_calculate_dt_numpy_batch(): 67 | image_size = (10, 20) 68 | warper = warp.Warp(image_size, normalize_t=True) 69 | events = np.array( 70 | [ 71 | utils.generate_events(300, image_size[0], image_size[1], tmin=1, tmax=i + 2) 72 | for i in range(4) 73 | ] 74 | ) 75 | dt = warper.calculate_dt(events, 1.0) 76 | np.testing.assert_allclose(dt.max(axis=-1), 1.0, rtol=1e-2, atol=0.1) 77 | np.testing.assert_allclose(dt.min(axis=-1), 0.0, rtol=1e-2, atol=0.1) 78 | assert dt.shape == (4, 300) 79 | 80 | 81 | def test_calculate_dt_torch_batch(): 82 | image_size = (10, 20) 83 | warper = warp.Warp(image_size, normalize_t=True) 84 | events = np.array( 85 | [ 86 | 
utils.generate_events(300, image_size[0], image_size[1], tmin=1, tmax=i + 2) 87 |             for i in range(4) 88 |         ] 89 |     ) 90 |     dt = warper.calculate_dt(torch.from_numpy(events), 1.0).numpy() 91 |     np.testing.assert_allclose(dt.max(axis=-1), 1.0, rtol=1e-2, atol=0.1) 92 |     np.testing.assert_allclose(dt.min(axis=-1), 0.0, rtol=1e-2, atol=0.1) 93 |     assert dt.shape == (4, 300) 94 | 95 | 96 | def test_warp_event_dense_flow(): 97 |     image_size = (3, 4) 98 |     warper = warp.Warp(image_size, normalize_t=True) 99 | 100 |     events = np.array( 101 |         [ 102 |             [1, 2, 0], 103 |             [2, 3, 0.2], 104 |             [0, 1, 0.6], 105 |             [1, 0, 1.0], 106 |         ] 107 |     ) 108 |     flow = np.array( 109 |         [ 110 |             [ 111 |                 [1.0, -0.5, 2, 8], 112 |                 [-2, 0, 2.0, 0], 113 |                 [2, 1, -2, 0], 114 |             ], 115 |             [ 116 |                 [-10, 1.0, 3, 2], 117 |                 [0, 2, -0.9, 0], 118 |                 [0, 10, -3, 0], 119 |             ], 120 |         ] 121 |     ) 122 | 123 |     expected = np.array( 124 |         [ 125 |             [1.0, 2.0, 0], 126 |             [2.0, 3.0, 0.2], 127 |             [0.3, 0.4, 0.6], 128 |             [3, 0, 1.0], 129 |         ] 130 |     ) 131 |     # Numpy 132 |     warped, _ = warper.warp_event(events, flow, "dense-flow")  # each event moves by -dt * flow at its pixel; events at the reference time t = 0 stay put, as the expected array shows 133 |     np.testing.assert_allclose(warped, expected) 134 | 135 |     # Torch 136 |     warped_torch, _ = warper.warp_event( 137 |         torch.from_numpy(events), torch.from_numpy(flow), "dense-flow" 138 |     ) 139 |     assert torch.allclose(warped_torch, torch.from_numpy(expected)) 140 | 141 | 142 | def test_warp_event_dense_flow_batch(): 143 |     image_size = (3, 4) 144 |     warper = warp.Warp(image_size, normalize_t=True) 145 | 146 |     events = np.array( 147 |         [ 148 |             [[1, 2, 0], [2, 3, 0.2]], 149 |             [[0, 1, 0.6], [1, 0, 1.2]], 150 |         ] 151 |     ) 152 |     flow = np.array( 153 |         [ 154 |             [ 155 |                 [ 156 |                     [1.0, -0.5, 2, 8], 157 |                     [-2, 0, 2.0, 0], 158 |                     [2, 1, -2, 0], 159 |                 ], 160 |                 [ 161 |                     [-10, 1.0, 3, 2], 162 |                     [0, 2, -0.9, 0], 163 |                     [0, 10, -3, 0], 164 |                 ], 165 |             ], 166 |             [ 167 |                 [ 168 |                     [1.0, -0.5, 2, 8], 169 |                     [-2, 0, 2.0, 0], 170 |                     [2, 1, -2, 0], 171 |                 ], 172 |                 [ 173 |                     [-10, 1.0, 3, 2], 174 |                     [0, 2, -0.9, 0], 175 |                     [0, 10, -3, 0], 176 |                 ], 177 |             ], 178 |         ] 179 |     ) 180 | 181 |     expected = np.array( 182 |         [ 183 |             [[1.0, 2.0, 0], [2, 3, 1.0]], 184 |             [[0, 1, 0], [3, 0, 1.0]], 185 |         ] 186 |     ) 187 |     # Numpy 188 |     warped, _ = warper.warp_event(events, flow, "dense-flow") 189 |     np.testing.assert_allclose(warped, expected) 190 | 191 |     # Torch 192 |     warped_torch, _ = warper.warp_event( 193 |         torch.from_numpy(events), torch.from_numpy(flow), "dense-flow" 194 |     ) 195 |     assert torch.allclose(warped_torch, torch.from_numpy(expected)) 196 | 197 | 198 | def test_warp_event_dense_flow_accuracy(): 199 |     image_size = (10, 20) 200 |     warper = warp.Warp(image_size, normalize_t=True) 201 |     events = np.array( 202 |         [ 203 |             utils.generate_events(300, image_size[0], image_size[1], tmin=1, tmax=i + 2) 204 |             for i in range(4) 205 |         ] 206 |     ).astype(np.float32) 207 |     flow = np.array( 208 |         [utils.generate_dense_optical_flow(image_size, max_val=10) for i in range(4)] 209 |     ).astype(np.float32) 210 |     warped, _ = warper.warp_event(events, flow, "dense-flow") 211 |     warped_torch, _ = warper.warp_event( 212 |         torch.from_numpy(events), torch.from_numpy(flow), "dense-flow" 213 |     ) 214 |     np.testing.assert_allclose(warped, warped_torch.numpy(), rtol=1e-5, atol=1e-5) 215 | -------------------------------------------------------------------------------- /tests/utils/test_event_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from src import utils 5 | 6 | 7 | def test_crop_events(): 8 |     events = utils.generate_events(100, 30, 20) 9 | 
cropped = utils.crop_event(events, -10, 40, -20, 30) 10 | assert len(events) == len(cropped) 11 | 12 | cropped_torch = utils.crop_event(torch.from_numpy(events), -10, 40, -20, 30) 13 | assert len(events) == len(cropped) == len(cropped_torch) 14 | 15 | cropped = utils.crop_event(events, 5, 29, 2, 11) 16 | cropped_torch = utils.crop_event(torch.from_numpy(events), 5, 29, 2, 11) 17 | assert len(cropped) == len(cropped_torch) 18 | -------------------------------------------------------------------------------- /tests/utils/test_flow_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | 5 | from src import utils 6 | 7 | 8 | def test_flow_error_numpy_torch(): 9 | imsize = (100, 200) 10 | flow_pred_np = utils.generate_dense_optical_flow(imsize, max_val=20).astype(np.float32) 11 | flow_pred_np = flow_pred_np[None, ...] 12 | flow_gt_np = utils.generate_dense_optical_flow(imsize, max_val=20).astype(np.float32) 13 | flow_gt_np = flow_gt_np[None, ...] 14 | 15 | flow_pred_th = torch.from_numpy(flow_pred_np) 16 | flow_gt_th = torch.from_numpy(flow_gt_np) 17 | 18 | error_np = utils.calculate_flow_error_numpy(flow_gt_np, flow_pred_np) 19 | error_th = utils.calculate_flow_error_tensor(flow_gt_th, flow_pred_th) 20 | 21 | for k in error_np.keys(): 22 | np.testing.assert_almost_equal(error_np[k], error_th[k].numpy(), decimal=5) 23 | 24 | 25 | def test_flow_error_different_batch_size(): 26 | # Test with bsize=1 vs. bsize=8 27 | imsize = (100, 200) 28 | flow_pred = utils.generate_dense_optical_flow(imsize, max_val=20).astype(np.float32) 29 | flow_gt = utils.generate_dense_optical_flow(imsize, max_val=20).astype(np.float32) 30 | flow_pred = np.tile(flow_pred, (8, 1, 1, 1)) 31 | flow_gt = np.tile(flow_gt, (8, 1, 1, 1)) 32 | mask = np.random.rand(8, 1, imsize[0], imsize[1]) > 0.1 33 | 34 | # Numpy 35 | error_batch = utils.calculate_flow_error_numpy(flow_gt, flow_pred, mask) 36 | error_one = utils.calculate_flow_error_numpy(flow_gt[[0]], flow_pred[[0]]) 37 | 38 | for k in error_one.keys(): 39 | np.testing.assert_almost_equal(error_one[k], error_batch[k], decimal=1) 40 | 41 | # Torch 42 | flow_gt = torch.from_numpy(flow_gt) 43 | flow_pred = torch.from_numpy(flow_pred) 44 | mask = torch.from_numpy(mask) 45 | error_batch = utils.calculate_flow_error_tensor(flow_gt, flow_pred, mask) 46 | error_one = utils.calculate_flow_error_tensor(flow_gt[[0]], flow_pred[[0]]) 47 | 48 | for k in error_one.keys(): 49 | np.testing.assert_almost_equal(error_one[k].numpy(), error_batch[k].numpy(), decimal=1) 50 | 51 | 52 | @pytest.mark.parametrize("scheme", ["upwind", "burgers"]) 53 | def test_construct_dense_flow_voxel_numpy(scheme): 54 | imsize = (100, 200) 55 | n_bin = 60 56 | flow = utils.generate_dense_optical_flow(imsize, max_val=20).astype(np.float64) 57 | voxel_flow = utils.construct_dense_flow_voxel_numpy(flow, 1) 58 | np.testing.assert_almost_equal(voxel_flow[0], flow, decimal=8) 59 | 60 | voxel_flow = utils.construct_dense_flow_voxel_numpy(flow, n_bin, scheme, t0_location="middle") 61 | np.testing.assert_almost_equal(voxel_flow[n_bin // 2], flow, decimal=8) 62 | voxel_flow = utils.construct_dense_flow_voxel_numpy(flow, n_bin, scheme, t0_location="first") 63 | np.testing.assert_almost_equal(voxel_flow[0], flow, decimal=8) 64 | 65 | voxel_flow = utils.construct_dense_flow_voxel_numpy(flow, n_bin, scheme, t0_location="middle") 66 | np.testing.assert_almost_equal(voxel_flow[n_bin // 2], flow, decimal=8) 67 | voxel_flow = 
utils.construct_dense_flow_voxel_numpy(flow, n_bin, scheme, t0_location="first") 68 | np.testing.assert_almost_equal(voxel_flow[0], flow, decimal=8) 69 | 70 | 71 | @pytest.mark.parametrize("scheme", ["upwind", "burgers"]) 72 | def test_construct_dense_flow_voxel_upwind_torch(scheme): 73 | imsize = (100, 200) 74 | n_bin = 60 75 | flow = utils.generate_dense_optical_flow(imsize, max_val=20).astype(np.float64) 76 | 77 | flow_torch = torch.from_numpy(flow).double() 78 | voxel_flow_torch = utils.construct_dense_flow_voxel_torch( 79 | flow_torch, n_bin, scheme, t0_location="middle" 80 | ) 81 | np.testing.assert_almost_equal( 82 | voxel_flow_torch.numpy()[n_bin // 2], flow_torch.numpy(), decimal=8 83 | ) 84 | voxel_flow_torch = utils.construct_dense_flow_voxel_torch( 85 | flow_torch, n_bin, scheme, t0_location="first" 86 | ) 87 | np.testing.assert_almost_equal(voxel_flow_torch.numpy()[0], flow_torch.numpy(), decimal=8) 88 | 89 | 90 | @pytest.mark.parametrize( 91 | "scheme,t0_location", 92 | [["upwind", "middle"], ["upwind", "first"], ["burgers", "middle"], ["burgers", "first"]], 93 | ) 94 | def test_construct_dense_flow_voxel_numerical(scheme, t0_location): 95 | imsize = (100, 200) 96 | n_bin = 100 97 | flow = utils.generate_dense_optical_flow(imsize, max_val=10).astype(np.float64) 98 | 99 | flow_torch = torch.from_numpy(flow).double() 100 | 101 | voxel_flow = utils.construct_dense_flow_voxel_numpy(flow, n_bin, scheme, t0_location) 102 | voxel_flow_torch = utils.construct_dense_flow_voxel_torch( 103 | flow_torch, n_bin, scheme, t0_location 104 | ) 105 | np.testing.assert_almost_equal(voxel_flow_torch.numpy(), voxel_flow, decimal=6) 106 | 107 | 108 | def test_inviscid_burger_flow_to_voxel_numerical(): 109 | imsize = (100, 200) 110 | dt = 0.01 111 | flow = utils.generate_dense_optical_flow(imsize, max_val=1).astype(np.float64) 112 | flow_torch = torch.from_numpy(flow).double() 113 | 114 | flow1 = utils.inviscid_burger_flow_to_voxel_numpy(flow, dt, 1, 1) 115 | flow1_torch = utils.inviscid_burger_flow_to_voxel_torch(flow_torch, dt, 1, 1) 116 | np.testing.assert_almost_equal(flow1_torch.numpy(), flow1, decimal=6) 117 | 118 | flow1 = utils.inviscid_burger_flow_to_voxel_numpy(flow, -dt, 1, 1) 119 | flow1_torch = utils.inviscid_burger_flow_to_voxel_torch(flow_torch, -dt, 1, 1) 120 | np.testing.assert_almost_equal(flow1_torch.numpy(), flow1, decimal=6) 121 | --------------------------------------------------------------------------------
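As a closing note, the flow-error tests above only check that the numpy and torch paths return matching metric dictionaries; the source does not spell the metrics out. A hedged sketch of the most common such metric, the average endpoint error (AEE), with masking as used in `test_flow_error_different_batch_size`; `average_endpoint_error` is a hypothetical name for illustration, not the actual keys or API of `calculate_flow_error_*`:

```python
from typing import Optional

import numpy as np


def average_endpoint_error(
    flow_gt: np.ndarray, flow_pred: np.ndarray, mask: Optional[np.ndarray] = None
) -> float:
    """AEE for flows shaped (B, 2, H, W); mask is an optional (B, 1, H, W) bool map."""
    # Endpoint error: per-pixel Euclidean distance between GT and predicted flow.
    epe = np.linalg.norm(flow_gt - flow_pred, axis=1, keepdims=True)  # (B, 1, H, W)
    return float(epe.mean() if mask is None else epe[mask].mean())
```

Because the batch test tiles a single flow field eight times, a metric computed this way is essentially identical for batch size 1 and 8, which is what its loose `decimal=1` comparisons allow for.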