├── .readme
│   └── title.png
├── LICENSE
├── README.md
├── configs
│   ├── __init__.py
│   ├── eval_dsec.yml
│   ├── eval_flow.yml
│   ├── eval_mvsec.yml
│   ├── parser.py
│   └── train_flow.yml
├── dataloader
│   ├── __init__.py
│   ├── base.py
│   ├── cache.py
│   ├── encodings.py
│   ├── h5.py
│   └── utils.py
├── dsec_submissions
│   ├── interlaken_00_b.txt
│   ├── interlaken_00_b_flag.npy
│   ├── interlaken_01_a.txt
│   ├── interlaken_01_a_flag.npy
│   ├── thun_01_a.txt
│   ├── thun_01_a_flag.npy
│   ├── thun_01_b.txt
│   ├── thun_01_b_flag.npy
│   ├── zurich_city_12_a.txt
│   ├── zurich_city_12_a_flag.npy
│   ├── zurich_city_14_c.txt
│   ├── zurich_city_14_c_flag.npy
│   ├── zurich_city_15_a.txt
│   └── zurich_city_15_a_flag.npy
├── eval_flow.py
├── loss
│   ├── __init__.py
│   ├── flow.py
│   └── flow_val.py
├── models
│   ├── __init__.py
│   ├── arch.py
│   ├── base.py
│   ├── model.py
│   ├── model_util.py
│   └── submodules.py
├── prepare_dsec_submission.py
├── requirements.txt
├── train_flow.py
└── utils
    ├── iwe.py
    ├── mlflow.py
    ├── utils.py
    └── visualization.py
/.readme/title.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tudelft/taming_event_flow/57cbc7e38d6d9f5bd489139ba465609ff83c5a89/.readme/title.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 TU Delft 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Taming Contrast Maximization for Learning Sequential Event-based Optical Flow Estimation 2 | 3 | Work accepted at ICCV'23 [[paper](https://arxiv.org/abs/2303.05214), [video](https://youtu.be/vkYimENc494)]. 4 | 5 | If you use this code in an academic context, please cite our work: 6 | 7 | ```bibtex 8 | @InProceedings{Paredes-Valles_2023_ICCV, 9 | author = {Paredes-Vall\'es, Federico and Scheper, Kirk Y. W. and De Wagter, Christophe and de Croon, Guido C. H. E.}, 10 | title = {Taming Contrast Maximization for Learning Sequential, Low-latency, Event-based Optical Flow}, 11 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, 12 | month = {October}, 13 | year = {2023}, 14 | pages = {9695-9705} 15 | } 16 | 17 | ``` 18 | 19 | This code allows for the reproduction of the experiments leading to the results in Section 4.2.
20 | 21 | ![Alt text](.readme/title.png) 22 | 23 | ---------- 24 | 25 | ## Usage 26 | 27 | ---------- 28 | 29 | This project uses Python >= 3.7.3 and we strongly recommend the use of virtual environments. If you don't have an environment manager yet, we recommend `pyenv`. It can be installed via: 30 | 31 | ``` 32 | curl https://pyenv.run | bash 33 | ``` 34 | 35 | Make sure your `~/.bashrc` or `~/.zshrc` file contains the following: 36 | 37 | ``` 38 | export PATH="$HOME/.pyenv/bin:$PATH" 39 | eval "$(pyenv init -)" 40 | eval "$(pyenv virtualenv-init -)" 41 | ``` 42 | 43 | After that, restart your terminal and run: 44 | 45 | ``` 46 | pyenv update 47 | ``` 48 | 49 | To set up your environment with `pyenv`, first install the required Python distribution and make sure the installation is successful (i.e., no errors or warnings): 50 | 51 | ``` 52 | pyenv install -v 3.7.3 53 | ``` 54 | 55 | Once this is done, set up the environment and install the required libraries: 56 | 57 | ``` 58 | pyenv virtualenv 3.7.3 taming_flow 59 | pyenv activate taming_flow 60 | 61 | pip install --upgrade pip==20.0.2 62 | 63 | cd taming_flow/ 64 | ``` 65 | 66 | To install our dependencies: 67 | 68 | ``` 69 | pip install -r requirements.txt 70 | ``` 71 | 72 | ### Download datasets 73 | 74 | The following is a list of the datasets that are required to train/evaluate our models on the DSEC-Flow and MVSEC datasets: 75 | 76 | - **dsec_train** (1.98 GB): 128x128 random crops of the training partition of DSEC-Flow. Each sequence is two seconds long. Used for training. 77 | - **dsec_benchmark_aug** (15.94 GB): 480x640 test partition of DSEC-Flow. Used for evaluation. 78 | - **mvsec_eval** (653.5 MB): 260x346 outdoor_day_1 sequence, from seconds 222.4 to 240.4. Used for evaluation. 79 | 80 | The datasets can be downloaded from [here](https://1drv.ms/u/s!Ah0kx0CRKrAZjxMxBx4z5HN1CjWv?e=UiayaL) and are expected at: `../datasets/`. 81 | 82 | ### Download models 83 | 84 | The pretrained models can be downloaded from [here](https://1drv.ms/u/s!Ah0kx0CRKrAZjxSwx8-UTUAncgg3?e=yM2g0i), and are expected at `mlruns/`. 85 | 86 | In this project we use MLflow to keep track of the experiments. To visualize the models that are available, alongside other useful details and evaluation metrics, run the following from the home directory of the project: 87 | 88 | ``` 89 | mlflow ui 90 | ``` 91 | 92 | We provide our best-performing models on the DSEC and MVSEC datasets. 93 | 94 | ---------- 95 | 96 | ## Inference 97 | 98 | ---------- 99 | 100 | To estimate optical flow from event sequences from a dataset of your choice, adjust `configs/eval_flow.yml` according to your needs and run: 101 | 102 | ``` 103 | python eval_flow.py <runid> 104 | 105 | # for example: 106 | python eval_flow.py dsec_model 107 | ``` 108 | 109 | where `<runid>` is the name of the MLflow run to be evaluated. Note that, if a run does not have a name (this would be the case for your own trained models), you can evaluate it through its run ID (also visible through MLflow).
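If you prefer to query the available runs programmatically rather than through the UI, the following is a minimal sketch using the MLflow Python client (run names are stored in the `mlflow.runName` tag; the exact entry points may vary slightly with your MLflow version):

```
from mlflow.tracking import MlflowClient

# point the client to the local tracking directory of this project
client = MlflowClient(tracking_uri="file:./mlruns")

# print the ID and name (if any) of every run in every experiment
for exp in client.list_experiments():
    for run in client.search_runs([exp.experiment_id]):
        print(exp.name, run.info.run_id, run.data.tags.get("mlflow.runName"))
```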
110 | 111 | ### MVSEC 112 | 113 | Simply run: 114 | 115 | ``` 116 | python eval_flow.py <runid> --config configs/eval_mvsec.yml 117 | ``` 118 | 119 | ### DSEC-Flow Public Benchmark 120 | If what you want is to generate a submission to the [DSEC-Flow Optical Flow Public Benchmark](https://dsec.ifi.uzh.ch/uzh/dsec-flow-optical-flow-benchmark/), run: 121 | 122 | ``` 123 | python eval_flow.py <runid> --config configs/eval_dsec.yml 124 | mkdir results_inference/<eval_id> 125 | cp -r results_inference/<runid>/results/eval_* results_inference/<eval_id>/ 126 | python prepare_dsec_submission.py --eval_id <eval_id> 127 | ``` 128 | 129 | This will generate a `submission/` folder in the directory with your results. Zip it and submit! 130 | 131 | The DSEC submission associated with our best-performing model can be downloaded for inspection from [here](https://1drv.ms/u/s!Ah0kx0CRKrAZjyfkk6kgwMKgxar_?e=njw0KT). 132 | 133 | ---------- 134 | 135 | ## Training 136 | 137 | ---------- 138 | 139 | Run: 140 | 141 | ``` 142 | python train_flow.py 143 | ``` 144 | 145 | to train a traditional artificial neural network. In `configs/`, you can find the configuration files and vary the training settings (e.g., input settings, model, event warping, activate/deactivate visualization). For other models available, see `models/model.py`. 146 | 147 | **Note that we used a batch size of 8 in our experiments. Depending on your computational resources, you may need to lower this number.** 148 | 149 | During and after the training, information about your run can be visualized through [MLflow](https://www.mlflow.org/docs/latest/index.html#) and [TensorBoard](https://www.tensorflow.org/tensorboard). 150 | 151 | ---------- 152 | 153 | ## Uninstalling pyenv 154 | 155 | ---------- 156 | 157 | Once you finish using our code, you can uninstall `pyenv` from your system by: 158 | 159 | 1. Removing the `pyenv` configuration lines from your `~/.bashrc`. 160 | 2. Removing its root directory.
This will delete all Python versions that were installed under the `$HOME/.pyenv/versions/` directory: 161 | 162 | ``` 163 | rm -rf $HOME/.pyenv/ 164 | ``` 165 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tudelft/taming_event_flow/57cbc7e38d6d9f5bd489139ba465609ff83c5a89/configs/__init__.py -------------------------------------------------------------------------------- /configs/eval_dsec.yml: -------------------------------------------------------------------------------- 1 | data: 2 | path: ../datasets/dsec_benchmark_aug 3 | mode: gtflow # events/time/gtflow 4 | window: 0.1 5 | passes_loss: 10 6 | cache: False 7 | 8 | loader: 9 | resolution: [480, 640] # H x W 10 | augment: [] 11 | max_num_grad_events: Null 12 | gpu: 0 13 | seed: Null # random if Null 14 | 15 | metrics: 16 | warping: Iterative # Linear/Iterative 17 | name: ["FWL", "RSAT", "AEE"] 18 | 19 | vis: 20 | enabled: False 21 | px: 400 22 | bars: True 23 | store: True 24 | mask_output: False 25 | dynamic: True 26 | show: ["flow_bw"] # Null for everything -------------------------------------------------------------------------------- /configs/eval_flow.yml: -------------------------------------------------------------------------------- 1 | data: 2 | path: ../datasets/dsec_benchmark_aug 3 | mode: gtflow # events/time/gtflow 4 | window: 0.1 5 | passes_loss: 10 6 | cache: False 7 | 8 | loader: 9 | resolution: [480, 640] # H x W 10 | augment: [] 11 | max_num_grad_events: Null 12 | gpu: 0 13 | seed: Null # random if Null 14 | 15 | metrics: 16 | warping: Iterative # Linear/Iterative 17 | name: ["FWL", "RSAT", "AEE"] 18 | 19 | vis: 20 | enabled: True 21 | px: 400 22 | bars: True 23 | store: False 24 | mask_output: False 25 | dynamic: True 26 | show: Null # Null for everything 27 | -------------------------------------------------------------------------------- /configs/eval_mvsec.yml: -------------------------------------------------------------------------------- 1 | data: 2 | path: ../datasets/mvsec_eval 3 | mode: gtflow # events/time/gtflow 4 | window: 1 5 | passes_loss: 1 6 | cache: False 7 | 8 | loader: 9 | resolution: [260, 346] # H x W 10 | augment: [] 11 | max_num_grad_events: Null 12 | gpu: 0 13 | seed: Null # random if Null 14 | 15 | metrics: 16 | warping: Iterative # Linear/Iterative 17 | name: ["FWL", "RSAT", "AEE"] 18 | eval_time: [222.4, 240.4] 19 | mask_aee: True 20 | res_aee: [256, 256] 21 | vertical_crop_aee: 190 22 | 23 | vis: 24 | enabled: True 25 | px: 400 26 | bars: True 27 | store: False 28 | mask_output: True 29 | dynamic: False 30 | show: Null # Null for everything 31 | -------------------------------------------------------------------------------- /configs/parser.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import torch 4 | import yaml 5 | 6 | 7 | class YAMLParser: 8 | """ 9 | YAML parser for the config files. 10 | """ 11 | 12 | def __init__(self, config): 13 | self.reset_config() 14 | self.parse_config(config) 15 | self.get_device() 16 | if self._config["loader"]["seed"] is not None: 17 | self.init_seeds() 18 | 19 | def parse_config(self, file): 20 | """ 21 | Load and parse the config file. 
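:param file: path to the YAML file with the configuration settings to load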
22 | """ 23 | with open(file) as fid: 24 | yaml_config = yaml.load(fid, Loader=yaml.FullLoader) 25 | self.parse_dict(yaml_config) 26 | 27 | @property 28 | def config(self): 29 | return self._config 30 | 31 | @property 32 | def device(self): 33 | return self._device 34 | 35 | @property 36 | def loader_kwargs(self): 37 | return self._loader_kwargs 38 | 39 | def reset_config(self): 40 | self._config = {} 41 | 42 | # MLFlow experiment name 43 | self._config["experiment"] = "Default" 44 | 45 | # input data mode 46 | self._config["data"] = {} 47 | self._config["data"]["mode"] = "events" 48 | self._config["data"]["window"] = 5000 49 | 50 | # data loader 51 | self._config["loader"] = {} 52 | self._config["loader"]["resolution"] = [180, 240] 53 | self._config["loader"]["batch_size"] = 1 54 | self._config["loader"]["augment"] = [] 55 | self._config["loader"]["gpu"] = 0 56 | self._config["loader"]["seed"] = 42 57 | 58 | # model 59 | self._config["model"] = {} 60 | 61 | # visualization 62 | self._config["vis"] = {} 63 | self._config["vis"]["bars"] = False 64 | 65 | def update(self, config): 66 | """ 67 | Updates the config with the given config. 68 | :param config: dictionary containing a config to update with 69 | """ 70 | self.reset_config() 71 | self.parse_config(config) 72 | 73 | def parse_dict(self, input_dict, parent=None): 74 | """ 75 | Augments self._config with the given dictionary. 76 | :param input_dict: dictionary to parse and use to update self._config 77 | :param parent: parent dictionary to be updated 78 | """ 79 | if parent is None: 80 | parent = self._config 81 | for key, val in input_dict.items(): 82 | if isinstance(val, dict): 83 | if key not in parent.keys(): 84 | parent[key] = {} 85 | self.parse_dict(val, parent[key]) 86 | else: 87 | parent[key] = val 88 | 89 | def get_device(self): 90 | """ 91 | Get the device to use in the pipeline. 92 | """ 93 | cuda = torch.cuda.is_available() 94 | self._device = torch.device("cuda:" + str(self._config["loader"]["gpu"]) if cuda else "cpu") 95 | self._loader_kwargs = {"num_workers": 0, "pin_memory": True} if cuda else {} 96 | 97 | @staticmethod 98 | # TODO: not using multiple workers anymore, enable it 99 | def worker_init_fn(worker_id): 100 | np.random.seed(np.random.get_state()[1][0] + worker_id) 101 | 102 | def init_seeds(self): 103 | """ 104 | Initialize random seeds. 105 | """ 106 | torch.manual_seed(self._config["loader"]["seed"]) 107 | if torch.cuda.is_available(): 108 | torch.cuda.manual_seed(self._config["loader"]["seed"]) 109 | torch.cuda.manual_seed_all(self._config["loader"]["seed"]) 110 | np.random.seed(self._config["loader"]["seed"]) 111 | random.seed(self._config["loader"]["seed"]) 112 | 113 | def merge_configs(self, run): 114 | """ 115 | Overwrites mlflow metadata with configs. 116 | :param run: mlflow run object 117 | """ 118 | # parse mlflow settings 119 | config = {} 120 | for key in run.keys(): 121 | if len(run[key]) > 0 and run[key][0] == "{": # assume dictionary 122 | config[key] = eval(run[key]) 123 | else: # string 124 | config[key] = run[key] 125 | 126 | # overwrite with config settings 127 | self.parse_dict(self._config, config) 128 | 129 | return config 130 | 131 | @staticmethod 132 | def combine_entries(config): 133 | """ 134 | Combines entries that had to be split because of MLFlow's max character limit. 
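Currently a no-op: the config is returned unchanged.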
135 | :param config: dictionary to combine entries in 136 | """ 137 | return config 138 | -------------------------------------------------------------------------------- /configs/train_flow.yml: -------------------------------------------------------------------------------- 1 | experiment: Default 2 | 3 | data: 4 | path: ../datasets/dsec_train 5 | mode: time # events/time/gtflow 6 | window: 0.01 7 | passes_loss: 10 # length of the loss accumulation window 8 | scales_loss: 1 # temporal scales for loss computation 9 | voxel: Null # number of bins 10 | cache: False 11 | 12 | model: 13 | name: RecEVFlowNet 14 | final_w_scale: 0.01 15 | 16 | loss: 17 | warping: Iterative # Linear/Iterative 18 | iterative_mode: two # one/two/four (see notes) 19 | round_ts: False 20 | flow_scaling: 32 21 | flow_spat_smooth_weight: Null # Null to disable 22 | flow_temp_smooth_weight: Null # Null to disable 23 | clip_grad: 100.0 # set to Null to disable 24 | 25 | optimizer: 26 | name: Adam 27 | lr: 0.00001 28 | 29 | loader: 30 | n_epochs: 500 31 | batch_size: 1 32 | resolution: [128, 128] # H x W 33 | augment: ["Horizontal", "Vertical", "Polarity"] 34 | augment_prob: [0.5, 0.5, 0.5] 35 | max_num_grad_events: 10000 36 | gpu: 0 37 | seed: Null # random if Null 38 | 39 | vis: 40 | verbose: True 41 | enabled: False 42 | store: False 43 | px: 400 44 | -------------------------------------------------------------------------------- /dataloader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tudelft/taming_event_flow/57cbc7e38d6d9f5bd489139ba465609ff83c5a89/dataloader/__init__.py -------------------------------------------------------------------------------- /dataloader/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import cv2 4 | import numpy as np 5 | import random 6 | import torch 7 | 8 | from .encodings import events_to_voxel, events_to_channels 9 | 10 | 11 | class BaseDataLoader(torch.utils.data.Dataset): 12 | """ 13 | Base class for dataloader. 14 | """ 15 | 16 | def __init__(self, config): 17 | self.config = config 18 | self.epoch = 0 19 | self.seq_num = 0 20 | self.samples = 0 21 | self.new_seq = False 22 | self.rectify = False 23 | self.device = self.config["loader"]["device"] 24 | self.res = self.config["loader"]["resolution"] 25 | self.batch_size = self.config["loader"]["batch_size"] 26 | 27 | # batch-specific data augmentation mechanisms 28 | self.batch_augmentation = {} 29 | for mechanism in self.config["loader"]["augment"]: 30 | self.batch_augmentation[mechanism] = [False for i in range(self.config["loader"]["batch_size"])] 31 | 32 | for i, mechanism in enumerate(self.config["loader"]["augment"]): 33 | for batch in range(self.config["loader"]["batch_size"]): 34 | if np.random.random() < self.config["loader"]["augment_prob"][i]: 35 | self.batch_augmentation[mechanism][batch] = True 36 | 37 | @abstractmethod 38 | def __getitem__(self, index): 39 | raise NotImplementedError 40 | 41 | @abstractmethod 42 | def get_events(self, history): 43 | raise NotImplementedError 44 | 45 | def reset_sequence(self, batch): 46 | """ 47 | Reset sequence-specific variables. 
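Advances the sequence counter and re-samples the batch-specific data augmentation mechanisms.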
48 | :param batch: batch index 49 | """ 50 | 51 | self.seq_num += 1 52 | 53 | # data augmentation 54 | for i, mechanism in enumerate(self.config["loader"]["augment"]): 55 | if np.random.random() < self.config["loader"]["augment_prob"][i]: 56 | self.batch_augmentation[mechanism][batch] = True 57 | else: 58 | self.batch_augmentation[mechanism][batch] = False 59 | 60 | def rectification_mapping(self, batch): 61 | """ 62 | Compute the backward rectification map for the input representations. 63 | See https://github.com/uzh-rpg/DSEC/issues/14 for details. 64 | :param batch: batch index 65 | :return K_rect: intrinsic matrix of rectified image 66 | :return mapping: rectification map 67 | :return Q_rect: scaling matrix to convert disparity to depth 68 | """ 69 | 70 | # distorted image 71 | K_dist = eval(self.open_files[batch]["calibration/intrinsics"][()])["cam0"]["camera_matrix"] 72 | 73 | # rectified image 74 | K_rect = eval(self.open_files[batch]["calibration/intrinsics"][()])["camRect0"]["camera_matrix"] 75 | R_rect = eval(self.open_files[batch]["calibration/extrinsics"][()])["R_rect0"] 76 | dist_coeffs = eval(self.open_files[batch]["calibration/intrinsics"][()])["cam0"]["distortion_coeffs"] 77 | 78 | # formatting 79 | K_dist = np.array([[K_dist[0], 0, K_dist[2]], [0, K_dist[1], K_dist[3]], [0, 0, 1]]) 80 | K_rect = np.array([[K_rect[0], 0, K_rect[2]], [0, K_rect[1], K_rect[3]], [0, 0, 1]]) 81 | R_rect = np.array( 82 | [ 83 | [R_rect[0][0], R_rect[0][1], R_rect[0][2]], 84 | [R_rect[1][0], R_rect[1][1], R_rect[1][2]], 85 | [R_rect[2][0], R_rect[2][1], R_rect[2][2]], 86 | ] 87 | ) 88 | dist_coeffs = np.array([dist_coeffs[0], dist_coeffs[1], dist_coeffs[2], dist_coeffs[3]]) 89 | 90 | # backward mapping 91 | mapping = cv2.initUndistortRectifyMap( 92 | K_dist, 93 | dist_coeffs, 94 | R_rect, 95 | K_rect, 96 | (self.res[1], self.res[0]), 97 | cv2.CV_32FC2, 98 | )[0] 99 | 100 | # disparity to depth (only used for evaluation) 101 | Q_rect = eval(self.open_files[batch]["calibration/disparity_to_depth"][()])["cams_03"] 102 | Q_rect = np.array( 103 | [ 104 | [Q_rect[0][0], Q_rect[0][1], Q_rect[0][2], Q_rect[0][3]], 105 | [Q_rect[1][0], Q_rect[1][1], Q_rect[1][2], Q_rect[1][3]], 106 | [Q_rect[2][0], Q_rect[2][1], Q_rect[2][2], Q_rect[2][3]], 107 | [Q_rect[3][0], Q_rect[3][1], Q_rect[3][2], Q_rect[3][3]], 108 | ] 109 | ).astype(np.float32) 110 | 111 | for _, mechanism in enumerate(self.config["loader"]["augment"]): 112 | 113 | if mechanism == "Horizontal": 114 | if self.batch_augmentation["Horizontal"][batch]: 115 | K_rect[0, 2] = self.res[1] - 1 - K_rect[0, 2] 116 | mapping[:, :, 0] = self.res[1] - 1 - mapping[:, :, 0] 117 | mapping = np.flip(mapping, axis=1) 118 | Q_rect[0, 3] = -K_rect[0, 2] 119 | 120 | elif mechanism == "Vertical": 121 | if self.batch_augmentation["Vertical"][batch]: 122 | K_rect[1, 2] = self.res[0] - 1 - K_rect[1, 2] 123 | mapping[:, :, 1] = self.res[0] - 1 - mapping[:, :, 1] 124 | mapping = np.flip(mapping, axis=0) 125 | Q_rect[1, 3] = -K_rect[1, 2] 126 | 127 | return K_rect, mapping, Q_rect 128 | 129 | @staticmethod 130 | def format_intrinsics(K_rect): 131 | """ 132 | Format camera matrices.
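Pads the [3 x 3] intrinsic matrix with a zeros column and a [0, 0, 0, 1] row to obtain a [4 x 4] homogeneous matrix, and computes its pseudo-inverse.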
133 | :param K_rect: [3 x 3] intrinsic matrix (numpy) of rectified image 134 | :return K_rect: [4 x 4] intrinsic matrix (tensor) of rectified image 135 | :return inv_K_rect: [4 x 4] inverse of the intrinsic matrix (tensor) of rectified image 136 | """ 137 | 138 | K_rect = np.c_[K_rect, np.zeros(3)] 139 | K_rect = np.concatenate((K_rect, np.array([[0, 0, 0, 1]])), axis=0) 140 | inv_K_rect = np.linalg.pinv(K_rect) 141 | 142 | K_rect = torch.from_numpy(K_rect.astype(np.float32)) 143 | inv_K_rect = torch.from_numpy(inv_K_rect.astype(np.float32)) 144 | 145 | return K_rect, inv_K_rect 146 | 147 | def event_formatting(self, xs, ys, ts, ps): 148 | """ 149 | Format input events as torch tensors. 150 | :param xs: [N] numpy array with event x location 151 | :param ys: [N] numpy array with event y location 152 | :param ts: [N] numpy array with event timestamp 153 | :param ps: [N] numpy array with event polarity ([-1, 1]) 154 | :return rectified_xs: [N] numpy array with rectified event x location 155 | :return rectified_ys: [N] numpy array with rectified event y location 156 | :return xs: [N] tensor with event x location 157 | :return ys: [N] tensor with event y location 158 | :return ts: [N] tensor with normalized event timestamp 159 | :return ps: [N] tensor with event polarity ([-1, 1]) 160 | """ 161 | 162 | assert len(xs) == len(ys) == len(ts) == len(ps) 163 | 164 | xs = torch.from_numpy(xs.astype(np.float32)).to(self.device) 165 | ys = torch.from_numpy(ys.astype(np.float32)).to(self.device) 166 | ts = torch.from_numpy(ts.astype(np.float32)).to(self.device) 167 | ps = torch.from_numpy(ps.astype(np.float32)).to(self.device) * 2 - 1 168 | if ts.shape[0] > 0: 169 | ts = (ts - ts[0]) / (ts[-1] - ts[0]) 170 | 171 | return xs, ys, ts, ps 172 | 173 | @staticmethod 174 | def rectify_events(rectify_map, xs, ys): 175 | """ 176 | Rectify (and undistort) input events. 177 | :param rectify_map: map used to rectify events 178 | :param xs: [N] numpy array with event x location 179 | :param ys: [N] numpy array with event y location 180 | :return rectified_xs: [N] numpy array with rectified event x location 181 | :return rectified_ys: [N] numpy array with rectified event y location 182 | """ 183 | 184 | rectified_events = rectify_map[ys.long(), xs.long()] 185 | rectified_xs = rectified_events[:, 0] 186 | rectified_ys = rectified_events[:, 1] 187 | 188 | return rectified_xs, rectified_ys 189 | 190 | def augment_events(self, xs, ys, ps, rec_xs, rec_ys, batch): 191 | """ 192 | Augment event sequence with horizontal, vertical, and polarity flips. 
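Only the mechanisms sampled as active for this batch entry (see `batch_augmentation`) are applied.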
193 | :param xs: [N] tensor with event x location 194 | :param ys: [N] tensor with event y location 195 | :param ps: [N] tensor with event polarity ([-1, 1]) 196 | :param rec_xs: [N] tensor with rectified event x location 197 | :param rec_ys: [N] tensor with rectified event y location 198 | :param batch: batch index 199 | :return xs: [N] tensor with augmented event x location 200 | :return ys: [N] tensor with augmented event y location 201 | :return ps: [N] tensor with augmented event polarity ([-1, 1]) 202 | :return rec_xs: [N] tensor with augmented rectified event x location 203 | :return rec_ys: [N] tensor with augmented rectified event y location 204 | """ 205 | 206 | for _, mechanism in enumerate(self.config["loader"]["augment"]): 207 | 208 | if mechanism == "Horizontal": 209 | if self.batch_augmentation["Horizontal"][batch]: 210 | xs = self.res[1] - 1 - xs 211 | if rec_xs is not None: 212 | rec_xs = self.res[1] - 1 - rec_xs 213 | 214 | elif mechanism == "Vertical": 215 | if self.batch_augmentation["Vertical"][batch]: 216 | ys = self.res[0] - 1 - ys 217 | if rec_ys is not None: 218 | rec_ys = self.res[0] - 1 - rec_ys 219 | 220 | elif mechanism == "Polarity": 221 | if self.batch_augmentation["Polarity"][batch]: 222 | ps *= -1 223 | 224 | return xs, ys, ps, rec_xs, rec_ys 225 | 226 | def augment_gt(self, gt, batch): 227 | """ 228 | Augment ground truth data with horizontal and vertical flips. 229 | :param gt: dictionary containing ground truth data 230 | :param batch: batch index 231 | """ 232 | 233 | for _, mechanism in enumerate(self.config["loader"]["augment"]): 234 | 235 | if mechanism == "Horizontal": 236 | if self.batch_augmentation["Horizontal"][batch]: 237 | for key in gt.keys(): 238 | gt[key] = torch.flip(gt[key], dims=[2]) 239 | if key == "gtflow": 240 | gt[key][0, ...] *= -1 241 | 242 | elif mechanism == "Vertical": 243 | if self.batch_augmentation["Vertical"][batch]: 244 | for key in gt.keys(): 245 | gt[key] = torch.flip(gt[key], dims=[1]) 246 | if key == "gtflow": 247 | gt[key][1, ...] *= -1 248 | 249 | return gt 250 | 251 | @staticmethod 252 | def create_list_encoding(xs, ys, ts, ps): 253 | """ 254 | Creates a four-channel tensor with all the events in the input partition. 255 | :param xs: [N] tensor with event x location 256 | :param ys: [N] tensor with event y location 257 | :param ts: [N] tensor with normalized event timestamp 258 | :param ps: [N] tensor with event polarity ([-1, 1]) 259 | :return [4 x N] list event representation 260 | """ 261 | 262 | return torch.stack([ts, ys, xs, ps]) 263 | 264 | @staticmethod 265 | def create_polarity_mask(ps): 266 | """ 267 | Creates a two-channel tensor that acts as a mask for the input event list. 268 | :param ps: [N] tensor with event polarity ([-1, 1]) 269 | :return [2 x N] polarity list event representation 270 | """ 271 | 272 | event_list_pol_mask = torch.stack([ps, ps]) 273 | event_list_pol_mask[0, :][event_list_pol_mask[0, :] < 0] = 0 274 | event_list_pol_mask[0, :][event_list_pol_mask[0, :] > 0] = 1 275 | event_list_pol_mask[1, :][event_list_pol_mask[1, :] < 0] = -1 276 | event_list_pol_mask[1, :][event_list_pol_mask[1, :] > 0] = 0 277 | event_list_pol_mask[1, :] *= -1 278 | return event_list_pol_mask 279 | 280 | def create_cnt_encoding(self, xs, ys, ps, rect_mapping): 281 | """ 282 | Creates a per-pixel and per-polarity event count representation.
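If a rectification map is provided, the count image is remapped to the rectified frame with nearest-neighbor interpolation.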
283 | :param xs: [N] tensor with event x location 284 | :param ys: [N] tensor with event y location 285 | :param ps: [N] tensor with event polarity ([-1, 1]) 286 | :param rect_mapping: map used to rectify events 287 | :return [2 x H x W] rectified event count representation 288 | """ 289 | 290 | # create event count representation and rectify it using backward mapping 291 | event_cnt = events_to_channels(xs, ys, ps, sensor_size=self.res) 292 | if rect_mapping is not None: 293 | event_cnt = event_cnt.permute(1, 2, 0) 294 | event_cnt = cv2.remap(event_cnt.cpu().numpy(), rect_mapping, None, cv2.INTER_NEAREST) 295 | event_cnt = torch.from_numpy(event_cnt.astype(np.float32)).to(self.device) 296 | event_cnt = event_cnt.permute(2, 0, 1) 297 | 298 | return event_cnt 299 | 300 | @staticmethod 301 | def create_mask_encoding(event_cnt): 302 | """ 303 | Creates a per-pixel event mask based on event count. 304 | :param event_cnt: [2 x H x W] event count 305 | :return [H x W] rectified event mask representation 306 | """ 307 | 308 | event_mask = event_cnt.clone() 309 | event_mask = torch.sum(event_mask, dim=0, keepdim=True) 310 | event_mask[event_mask > 0.0] = 1.0 311 | 312 | return event_mask 313 | 314 | def create_voxel_encoding(self, xs, ys, ts, ps, rect_mapping, num_bins=5): 315 | """ 316 | Creates a spatiotemporal voxel grid tensor representation with a certain number of bins, 317 | as described in Section 3.1 of the paper 'Unsupervised Event-based Learning of Optical Flow, 318 | Depth, and Egomotion', Zhu et al., CVPR'19. 319 | Events are distributed to the spatiotemporally closest bins through bilinear interpolation. 320 | Positive events are added as +1, while negative events as -1. 321 | :param xs: [N] tensor with event x location 322 | :param ys: [N] tensor with event y location 323 | :param ts: [N] tensor with normalized event timestamp 324 | :param ps: [N] tensor with event polarity ([-1, 1]) 325 | :param rect_mapping: map used to rectify events 326 | :param num_bins: number of bins in the voxel grid 327 | :return [B x H x W] rectified voxel grid event representation 328 | """ 329 | 330 | event_voxel = events_to_voxel( 331 | xs, 332 | ys, 333 | ts, 334 | ps, 335 | num_bins, 336 | sensor_size=self.res, 337 | ) 338 | 339 | if rect_mapping is not None: 340 | event_voxel = event_voxel.permute(1, 2, 0) 341 | event_voxel = cv2.remap(event_voxel.cpu().numpy(), rect_mapping, None, cv2.INTER_NEAREST) 342 | event_voxel = torch.from_numpy(event_voxel.astype(np.float32)).to(self.device) 343 | event_voxel = event_voxel.permute(2, 0, 1) 344 | 345 | return event_voxel 346 | 347 | @staticmethod 348 | def split_event_list(event_list, event_list_pol_mask, max_num_grad_events): 349 | """ 350 | Splits the event list into two lists, one of them (with a maximum length) to be used for backprop. 351 | This helps reduce (VRAM) memory consumption.
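The events kept for backprop are sampled uniformly at random without replacement.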
352 | :param event_list: [4 x N] list event representation 353 | :param event_list_pol_mask: [2 x N] polarity list event representation 354 | :param max_num_grad_events: maximum number of events to be used for backprop 355 | :return event_list: [4 x N] list event representation to be used for backprop 356 | :return event_list_pol_mask: [2 x N] polarity list event representation to be used for backprop 357 | :return d_event_list: [4 x N] list event representation 358 | :return d_event_list_pol_mask: [2 x N] polarity list event representation 359 | """ 360 | 361 | d_event_list = torch.zeros((4, 0)) 362 | d_event_list_pol_mask = torch.zeros((2, 0)) 363 | if max_num_grad_events is not None and event_list.shape[1] > max_num_grad_events: 364 | probs = torch.ones(event_list.shape[1], dtype=torch.float32) / event_list.shape[1] 365 | sampled_indices = probs.multinomial( 366 | max_num_grad_events, replacement=False 367 | ) # sample indices with equal prob. 368 | 369 | unsampled_indices = torch.ones(event_list.shape[1], dtype=torch.bool) 370 | unsampled_indices[sampled_indices] = False 371 | d_event_list = event_list[:, unsampled_indices] 372 | d_event_list_pol_mask = event_list_pol_mask[:, unsampled_indices] 373 | 374 | event_list = event_list[:, sampled_indices] 375 | event_list_pol_mask = event_list_pol_mask[:, sampled_indices] 376 | 377 | return event_list, event_list_pol_mask, d_event_list, d_event_list_pol_mask 378 | 379 | def __len__(self): 380 | return 1000 # not used 381 | 382 | def shuffle(self, flag=True): 383 | """ 384 | Shuffles the training data. 385 | :param flag: if true, shuffles the data 386 | """ 387 | 388 | if flag: 389 | random.shuffle(self.files) 390 | 391 | @staticmethod 392 | def custom_collate(batch): 393 | """ 394 | Collects the different event representations and stores them together in a dictionary. 395 | :param batch: list of samples returned by the dataloader 396 | :return batch_dict: dictionary with the output of a dataloader iteration 397 | """ 398 | 399 | # create dictionary 400 | batch_dict = {} 401 | for key in batch[0].keys(): 402 | batch_dict[key] = [] 403 | 404 | # collect data 405 | for entry in batch: 406 | for key in entry.keys(): 407 | batch_dict[key].append(entry[key]) 408 | 409 | # create batches 410 | for key in batch_dict.keys(): 411 | 412 | if batch_dict[key][0] is not None: 413 | 414 | # pad entries of different size 415 | N = 0 416 | if key in ["event_list", "event_list_pol_mask", "d_event_list", "d_event_list_pol_mask"]: 417 | for i in range(len(batch_dict[key])): 418 | if N < batch_dict[key][i].shape[1]: 419 | N = batch_dict[key][i].shape[1] 420 | 421 | for i in range(len(batch_dict[key])): 422 | zeros = torch.zeros((batch_dict[key][i].shape[0], N - batch_dict[key][i].shape[1])) 423 | batch_dict[key][i] = torch.cat((batch_dict[key][i], zeros), dim=1) 424 | 425 | # create tensor 426 | item = torch.stack(batch_dict[key]) 427 | if len(item.shape) == 3: 428 | item = item.transpose(2, 1) 429 | batch_dict[key] = item 430 | 431 | else: 432 | batch_dict[key] = None 433 | 434 | return batch_dict 435 | -------------------------------------------------------------------------------- /dataloader/cache.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hdf5plugin 3 | import h5py 4 | import numpy as np 5 | import torch 6 | import yaml 7 | 8 | 9 | class CacheDataset: 10 | """ 11 | Utility class to "cache" the output of the dataloader in hdf5 files
13 | """ 14 | 15 | def __init__(self, config, dir, mode="train"): 16 | self.keys = {} 17 | 18 | data_keys = ["path", "mode", "window", "voxel"] 19 | for key in data_keys: 20 | self.keys[key] = config["data"][key] 21 | 22 | loader_keys = ["resolution", "augment", "augment_prob"] 23 | for key in loader_keys: 24 | self.keys[key] = config["loader"][key] 25 | 26 | if not os.path.exists(dir): 27 | os.system("mkdir " + dir) 28 | 29 | self.dir = dir + "cache_" + mode 30 | dict_file = self.dir + "/dataset_keys.yml" 31 | if os.path.exists(self.dir): 32 | if os.path.isfile(dict_file): 33 | tmp_keys = self.read_yaml(dict_file) 34 | if self.keys != tmp_keys: # there are keys, but diff from current settings 35 | print("Deleting cache dir:", self.dir) 36 | os.system("rm -rf " + self.dir + "/*") 37 | self.write_yaml(dict_file, self.keys) 38 | else: 39 | self.write_yaml(dict_file, self.keys) # no keys, write them 40 | else: 41 | os.system("mkdir " + self.dir) 42 | self.write_yaml(dict_file, self.keys) # no cache, create it 43 | 44 | @staticmethod 45 | def read_yaml(file): 46 | with open(file, "r") as f: 47 | tmp_keys = yaml.load(f, Loader=yaml.FullLoader) 48 | return tmp_keys 49 | 50 | @staticmethod 51 | def write_yaml(file, keys): 52 | with open(file, "w") as outfile: 53 | yaml.dump(keys, outfile, default_flow_style=False) 54 | 55 | def update(self, filename, dict): 56 | """ 57 | Update the cached dataset with the given data. 58 | :param filename: name of the file to update 59 | :param dict: dictionary with the data to update 60 | """ 61 | 62 | filename = self.dir + "/" + filename.split("/")[-1] 63 | if not os.path.isfile(filename): 64 | file = h5py.File(filename, "w") 65 | file.attrs["idx"] = 0 66 | else: 67 | file = h5py.File(filename, "a") 68 | file.attrs["idx"] += 1 69 | 70 | for key in dict: 71 | file.create_dataset( 72 | key + "/{:09d}".format(file.attrs["idx"]), 73 | data=dict[key].numpy(), 74 | dtype=np.dtype(np.float32), 75 | **hdf5plugin.Zstd() 76 | ) 77 | 78 | def load(self, filename, idx): 79 | """ 80 | Load the cached dataset with the given data. 81 | :param filename: name of the file to load 82 | :param idx: index of the data to load 83 | :return data: dictionary with the loaded data 84 | :return success: flag indicating if the data was loaded correctly 85 | """ 86 | 87 | filename = self.dir + "/" + filename.split("/")[-1] 88 | if not os.path.isfile(filename): 89 | return {}, False 90 | 91 | file = h5py.File(filename, "r") 92 | 93 | data = {} 94 | success = True 95 | entry = "{:09d}".format(idx) 96 | for key in file.keys(): 97 | if entry in file[key].keys(): 98 | data[key] = torch.from_numpy(file[key + "/" + entry][:]) 99 | else: 100 | success = False 101 | break 102 | 103 | file.close() 104 | 105 | return data, success 106 | -------------------------------------------------------------------------------- /dataloader/encodings.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from Monash University https://github.com/TimoStoff/events_contrast_maximization 3 | """ 4 | 5 | import torch 6 | 7 | 8 | def events_to_image(xs, ys, ps, sensor_size=(180, 240), accumulate=True): 9 | """ 10 | Accumulate events into an image. 
11 | :param xs: event x coordinates 12 | :param ys: event y coordinates 13 | :param ps: event polarity 14 | :param sensor_size: sensor size 15 | :param accumulate: flag indicating whether to accumulate events into the image 16 | :return img: image containing per-pixel event counts 17 | """ 18 | 19 | device = xs.device 20 | img_size = list(sensor_size) 21 | img = torch.zeros(img_size, device=device) 22 | 23 | if xs.dtype is not torch.long: 24 | xs = xs.long().to(device) 25 | if ys.dtype is not torch.long: 26 | ys = ys.long().to(device) 27 | img.index_put_((ys, xs), ps, accumulate=accumulate) 28 | 29 | return img 30 | 31 | 32 | def events_to_voxel(xs, ys, ts, ps, num_bins, sensor_size=(180, 240)): 33 | """ 34 | Generate a voxel grid from input events using temporal bilinear interpolation. 35 | :param xs: event x coordinates 36 | :param ys: event y coordinates 37 | :param ts: event timestamps 38 | :param ps: event polarity 39 | :param num_bins: number of bins in the voxel grid 40 | :param sensor_size: sensor size 41 | :return: voxel grid representation 42 | """ 43 | 44 | assert len(xs) == len(ys) and len(ys) == len(ts) and len(ts) == len(ps) 45 | 46 | voxel = [] 47 | ts = ts * (num_bins - 1) 48 | device = xs.device 49 | 50 | zeros = torch.zeros(ts.size(), device=device) 51 | for b_idx in range(num_bins): 52 | weights = torch.max(zeros, 1.0 - torch.abs(ts - b_idx)) 53 | voxel_bin = events_to_image(xs, ys, ps * weights, sensor_size=sensor_size) 54 | voxel.append(voxel_bin) 55 | 56 | return torch.stack(voxel) 57 | 58 | 59 | def events_to_channels(xs, ys, ps, sensor_size=(180, 240)): 60 | """ 61 | Generate a two-channel event image containing per-pixel event counters. 62 | :param xs: event x coordinates 63 | :param ys: event y coordinates 64 | :param ps: event polarity 65 | :param sensor_size: sensor size 66 | :return: event image containing per-pixel and per-polarity event counts 67 | """ 68 | 69 | assert len(xs) == len(ys) and len(ys) == len(ps) 70 | 71 | mask_pos = ps.clone() 72 | mask_neg = ps.clone() 73 | mask_pos[ps < 0] = 0 74 | mask_neg[ps > 0] = 0 75 | mask_pos[ps > 0] = 1 76 | mask_neg[ps < 0] = -1 77 | 78 | pos_cnt = events_to_image(xs, ys, ps * mask_pos, sensor_size=sensor_size) 79 | neg_cnt = events_to_image(xs, ys, ps * mask_neg, sensor_size=sensor_size) 80 | 81 | return torch.stack([pos_cnt, neg_cnt]) 82 | -------------------------------------------------------------------------------- /dataloader/h5.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import hdf5plugin 5 | import h5py 6 | import numpy as np 7 | 8 | import torch 9 | 10 | from .base import BaseDataLoader 11 | from .cache import CacheDataset 12 | from .utils import ProgressBar 13 | 14 | parent_dir_name = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 15 | sys.path.append(parent_dir_name) 16 | 17 | from utils.utils import binary_search_array 18 | 19 | 20 | class FlowMaps: 21 | """ 22 | Utility class for reading the ground truth optical flow maps encoded in the HDF5 files. 
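Instances are meant to be passed to h5py's `visititems`, which fills in the name and the (from, to) timestamps of every flow map in the file.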
23 | """ 24 | 25 | def __init__(self): 26 | self.ts_from = [] 27 | self.ts_to = [] 28 | self.names = [] 29 | 30 | def __call__(self, name, h5obj): 31 | if hasattr(h5obj, "dtype") and name not in self.names: 32 | self.names += [name] 33 | self.ts_from += [h5obj.attrs["timestamp_from"]] 34 | self.ts_to += [h5obj.attrs["timestamp_to"]] 35 | 36 | 37 | class H5Loader(BaseDataLoader): 38 | def __init__(self, config, shuffle=False, path_cache=""): 39 | super().__init__(config) 40 | self.input_window = self.config["data"]["window"] 41 | if self.config["data"]["mode"] in ["gtflow"] and self.input_window > 1: 42 | print("DataLoader error: Ground truth data mode cannot be used with window > 1.") 43 | raise AttributeError 44 | 45 | self.ts_jump = False 46 | self.ts_jump_reset = False # used in the inference loops to reset model states 47 | 48 | self.gt_avg_dt = None 49 | self.gt_avg_idx = 0 50 | self.last_proc_timestamp = 0 51 | 52 | # "memory" that goes from forward pass to the next 53 | self.batch_idx = [i for i in range(self.batch_size)] # event sequence 54 | self.batch_row = [0 for i in range(self.batch_size)] # event_idx / time_idx 55 | self.batch_pass = [0 for i in range(self.batch_size)] # forward passes 56 | 57 | # input event sequences 58 | self.files = [] 59 | for root, dirs, files in os.walk(config["data"]["path"]): 60 | for file in files: 61 | if file.endswith(".h5"): 62 | self.files.append(os.path.join(root, file)) 63 | 64 | # shuffle files 65 | if shuffle: 66 | self.shuffle() 67 | 68 | # initialize cache 69 | if self.config["data"]["cache"]: 70 | self.cache = CacheDataset(config, path_cache) 71 | 72 | # open first files 73 | self.open_files = [] 74 | self.batch_rectify_map = [] 75 | self.batch_K_rect = [] 76 | self.batch_Q_rect = [] 77 | self.batch_rect_mapping = [] 78 | for batch in range(self.config["loader"]["batch_size"]): 79 | self.open_files.append(h5py.File(self.files[self.batch_idx[batch] % len(self.files)], "r")) 80 | if "rectification" in self.open_files[-1].keys(): 81 | self.batch_rectify_map.append(self.open_files[-1]["rectification/rectify_map"][:]) 82 | self.batch_rectify_map[-1] = torch.from_numpy(self.batch_rectify_map[-1]).float().to(self.device) 83 | 84 | K_rect, mapping, Q_rect = self.rectification_mapping(-1) 85 | self.batch_K_rect.append(K_rect) 86 | self.batch_Q_rect.append(Q_rect) 87 | self.batch_rect_mapping.append(mapping) 88 | self.rectify = True 89 | else: 90 | self.batch_rect_mapping.append(None) 91 | 92 | # load GT optical flow maps from open files 93 | self.open_files_flowmaps = [] 94 | if config["data"]["mode"] == "gtflow": 95 | for batch in range(self.batch_size): 96 | flowmaps = FlowMaps() 97 | if "flow" in self.open_files[batch].keys(): 98 | self.open_files[batch]["flow"].visititems(flowmaps) 99 | self.open_files_flowmaps.append(flowmaps) 100 | 101 | # progress bars 102 | if self.config["vis"]["bars"]: 103 | self.open_files_bar = [] 104 | for batch in range(self.config["loader"]["batch_size"]): 105 | max_iters = self.get_iters(batch) 106 | self.open_files_bar.append(ProgressBar(self.files[batch].split("/")[-1], max=max_iters)) 107 | 108 | def get_iters(self, batch): 109 | """ 110 | Compute the number of forward passes given a sequence and an input mode and window. 
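Depending on the data mode, the count is derived from the number of events, the sequence duration, or the number of ground truth flow maps, divided by the input window.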
111 | :param batch: batch index 112 | :return: number of forward passes 113 | """ 114 | 115 | if self.config["data"]["mode"] == "events": 116 | max_iters = len(self.open_files[batch]["events/xs"]) 117 | elif self.config["data"]["mode"] == "time": 118 | max_iters = self.open_files[batch].attrs["duration"] 119 | elif self.config["data"]["mode"] == "gtflow": 120 | max_iters = len(self.open_files_flowmaps[batch].ts_to) - 1 121 | else: 122 | print("DataLoader error: Unknown mode.") 123 | raise AttributeError 124 | 125 | return max_iters // self.input_window 126 | 127 | def get_events(self, file, idx0, idx1): 128 | """ 129 | Get all the events in between two indices. 130 | :param file: file to read from 131 | :param idx0: start index 132 | :param idx1: end index 133 | :return xs: [N] numpy array with event x location 134 | :return ys: [N] numpy array with event y location 135 | :return ts: [N] numpy array with event timestamp 136 | :return ps: [N] numpy array with event polarity ([-1, 1]) 137 | """ 138 | 139 | xs = file["events/xs"][idx0:idx1] 140 | ys = file["events/ys"][idx0:idx1] 141 | ts = file["events/ts"][idx0:idx1] 142 | ps = file["events/ps"][idx0:idx1] 143 | ts -= file.attrs["t0"] # sequence starting at t0 = 0 144 | 145 | # check if temporal discontinuity in gt data modes 146 | self.ts_jump = False 147 | if self.config["data"]["mode"] in ["gtflow"]: 148 | dt = ts[-1] - self.last_proc_timestamp 149 | if self.gt_avg_dt is None: 150 | self.gt_avg_dt = dt 151 | self.gt_avg_idx += 1 152 | 153 | if dt >= 2 * self.gt_avg_dt / self.gt_avg_idx: 154 | self.ts_jump = True 155 | self.ts_jump_reset = True 156 | else: 157 | self.gt_avg_dt += dt 158 | self.gt_avg_idx += 1 159 | 160 | if ts.shape[0] > 0: 161 | self.last_proc_timestamp = ts[-1] 162 | return xs, ys, ts, ps 163 | 164 | def get_event_index(self, batch, window=0): 165 | """ 166 | Get all the event indices to be used for reading. 167 | :param batch: batch index 168 | :param window: input window 169 | :return event_idx0: event index (from) 170 | :return event_idx1: event index (to) 171 | """ 172 | 173 | restart = False 174 | event_idx0 = None 175 | event_idx1 = None 176 | if self.config["data"]["mode"] == "events": 177 | event_idx0 = self.batch_row[batch] 178 | event_idx1 = self.batch_row[batch] + window 179 | 180 | elif self.config["data"]["mode"] == "time": 181 | event_idx0 = self.find_ts_index( 182 | self.open_files[batch], self.batch_row[batch] + self.open_files[batch].attrs["t0"] 183 | ) 184 | event_idx1 = self.find_ts_index( 185 | self.open_files[batch], self.batch_row[batch] + self.open_files[batch].attrs["t0"] + window 186 | ) 187 | 188 | elif self.config["data"]["mode"] == "gtflow": 189 | idx1 = int(np.ceil(self.batch_row[batch] + window)) 190 | if np.isclose(self.batch_row[batch] + window, idx1 - 1): 191 | idx1 -= 1 192 | event_idx0 = self.find_ts_index(self.open_files[batch], self.open_files_flowmaps[batch].ts_from[idx1]) 193 | event_idx1 = self.find_ts_index(self.open_files[batch], self.open_files_flowmaps[batch].ts_to[idx1]) 194 | if self.open_files_flowmaps[batch].ts_to[idx1] > self.open_files[batch].attrs["tk"]: 195 | restart = True 196 | 197 | else: 198 | print("DataLoader error: Unknown mode.") 199 | raise AttributeError 200 | 201 | return event_idx0, event_idx1, restart 202 | 203 | def find_ts_index(self, file, timestamp, dataset="events/ts"): 204 | """ 205 | Find closest event index for a given timestamp through binary search. 
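Assumes the timestamps in the dataset are sorted in ascending order.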
206 | :param file: file to read from 207 | :param timestamp: timestamp to find 208 | :param dataset: dataset to search in 209 | :return: event index 210 | """ 211 | 212 | return binary_search_array(file[dataset], timestamp) 213 | 214 | def open_new_h5(self, batch): 215 | """ 216 | Open new H5 event sequence. 217 | :param batch: batch index 218 | """ 219 | 220 | self.ts_jump = False 221 | self.ts_jump_reset = False 222 | 223 | self.gt_avg_dt = None 224 | self.gt_avg_idx = 0 225 | self.last_proc_timestamp = 0 226 | 227 | self.open_files[batch] = h5py.File(self.files[self.batch_idx[batch] % len(self.files)], "r+") 228 | 229 | if self.rectify: 230 | self.batch_rectify_map[batch] = self.open_files[batch]["rectification/rectify_map"][:] 231 | self.batch_rectify_map[batch] = torch.from_numpy(self.batch_rectify_map[batch]).float().to(self.device) 232 | 233 | K_rect, mapping, Q_rect = self.rectification_mapping(batch) 234 | self.batch_K_rect[batch] = K_rect 235 | self.batch_Q_rect[batch] = Q_rect 236 | self.batch_rect_mapping[batch] = mapping 237 | 238 | if self.config["data"]["mode"] == "gtflow": 239 | flowmaps = FlowMaps() 240 | if "flow" in self.open_files[batch].keys(): 241 | self.open_files[batch]["flow"].visititems(flowmaps) 242 | self.open_files_flowmaps[batch] = flowmaps 243 | 244 | if self.config["vis"]["bars"]: 245 | self.open_files_bar[batch].finish() 246 | max_iters = self.get_iters(batch) 247 | self.open_files_bar[batch] = ProgressBar( 248 | self.files[self.batch_idx[batch] % len(self.files)].split("/")[-1], max=max_iters 249 | ) 250 | 251 | if "Playback" in self.batch_augmentation.keys() and self.batch_augmentation["Playback"][batch]: 252 | file = self.open_files[batch] 253 | xs = np.flip(file["events/xs"][:]) 254 | ys = np.flip(file["events/ys"][:]) 255 | ps = np.flip(file["events/ps"][:]) 256 | 257 | ts = np.flip(file["events/ts"][:]) 258 | min_ts = ts[-1] 259 | max_ts = ts[0] 260 | ts = np.absolute((ts - min_ts) / (max_ts - min_ts) - 1) 261 | ts = ts * (max_ts - min_ts) + min_ts 262 | 263 | file["events/xs"][:] = xs 264 | file["events/ys"][:] = ys 265 | file["events/ts"][:] = ts 266 | file["events/ps"][:] = ps 267 | 268 | def __getitem__(self, index): 269 | while True: 270 | batch = index % self.config["loader"]["batch_size"] 271 | 272 | # try loading cached data 273 | if self.config["data"]["cache"]: 274 | output, success = self.cache.load( 275 | self.files[self.batch_idx[batch] % len(self.files)], self.batch_pass[batch] 276 | ) 277 | if success: 278 | self.batch_row[batch] += self.input_window 279 | self.batch_pass[batch] += 1 280 | return output 281 | 282 | # trigger sequence change 283 | len_frames = 0 284 | restart = False 285 | if self.config["data"]["mode"] == "gtflow": 286 | len_frames = len(self.open_files_flowmaps[batch].ts_to) 287 | if int(np.ceil(self.batch_row[batch] + self.input_window)) >= len_frames: 288 | restart = True 289 | 290 | # load events 291 | xs = np.zeros((0)) 292 | ys = np.zeros((0)) 293 | ts = np.zeros((0)) 294 | ps = np.zeros((0)) 295 | if not restart: 296 | idx0, idx1, restart = self.get_event_index(batch, window=self.input_window) 297 | 298 | if self.config["data"]["mode"] in ["gtflow"] and self.input_window < 1.0: 299 | floor_row = int(np.floor(self.batch_row[batch])) 300 | ceil_row = int(np.ceil(self.batch_row[batch] + self.input_window)) 301 | 302 | if np.isclose(self.batch_row[batch], floor_row + 1): 303 | floor_row += 1 304 | if np.isclose(self.batch_row[batch] + self.input_window, ceil_row - 1): 305 | ceil_row -= 1 306 | 307 | idx0_change = 
self.batch_row[batch] - floor_row 308 | idx1_change = self.batch_row[batch] + self.input_window - floor_row 309 | 310 | delta_idx = idx1 - idx0 311 | idx1 = int(idx0 + idx1_change * delta_idx) 312 | idx0 = int(idx0 + idx0_change * delta_idx) 313 | 314 | if not restart: 315 | xs, ys, ts, ps = self.get_events(self.open_files[batch], idx0, idx1) 316 | 317 | # skip gt sample if temporal discontinuity in gt 318 | if self.config["data"]["mode"] in ["gtflow"] and self.ts_jump: 319 | self.batch_row[batch] += self.input_window 320 | self.batch_pass[batch] += 1 321 | continue 322 | 323 | # trigger sequence change 324 | if (self.config["data"]["mode"] == "events" and xs.shape[0] < self.input_window) or ( 325 | self.config["data"]["mode"] == "time" 326 | and self.batch_row[batch] + self.input_window >= self.open_files[batch].attrs["duration"] 327 | ): 328 | restart = True 329 | 330 | # reset sequence if not enough input events 331 | if restart: 332 | self.new_seq = True 333 | self.reset_sequence(batch) 334 | self.batch_row[batch] = 0 335 | self.batch_idx[batch] = max(self.batch_idx) + 1 336 | self.batch_pass[batch] = 0 337 | self.open_files[batch].close() 338 | self.open_new_h5(batch) 339 | continue 340 | 341 | # handle case with very few events 342 | if xs.shape[0] <= 10: 343 | xs = np.empty([0]) 344 | ys = np.empty([0]) 345 | ts = np.empty([0]) 346 | ps = np.empty([0]) 347 | 348 | # event formatting and timestamp normalization 349 | xs, ys, ts, ps = self.event_formatting(xs, ys, ts, ps) 350 | 351 | # rectify input events 352 | rec_xs, rec_ys = None, None 353 | if self.rectify: 354 | rec_xs, rec_ys = self.rectify_events(self.batch_rectify_map[batch], xs, ys) 355 | 356 | # data augmentation 357 | xs, ys, ps, rec_xs, rec_ys = self.augment_events(xs, ys, ps, rec_xs, rec_ys, batch) 358 | 359 | # events to lists 360 | if self.rectify: 361 | event_list = self.create_list_encoding(rec_xs, rec_ys, ts, ps) 362 | else: 363 | event_list = self.create_list_encoding(xs, ys, ts, ps) 364 | event_list_pol_mask = self.create_polarity_mask(ps) 365 | 366 | # create event representations 367 | event_cnt = self.create_cnt_encoding(xs, ys, ps, self.batch_rect_mapping[batch]) 368 | event_mask = self.create_mask_encoding(event_cnt) 369 | if self.config["data"]["voxel"] is not None: 370 | event_voxel = self.create_voxel_encoding( 371 | xs, ys, ts, ps, self.batch_rect_mapping[batch], num_bins=self.config["data"]["voxel"] 372 | ) 373 | 374 | # voxel is the preferred representation for the network's input 375 | if self.config["data"]["voxel"] is None: 376 | net_input = event_cnt.clone() 377 | else: 378 | net_input = event_voxel.clone() 379 | 380 | # load (and augment) GT maps when required 381 | gt = {} 382 | if self.config["data"]["mode"] == "gtflow": 383 | idx = int(np.ceil(self.batch_row[batch] + self.input_window)) 384 | if np.isclose(self.batch_row[batch] + self.input_window, idx - 1): 385 | idx -= 1 386 | flowmap = self.open_files[batch]["flow"][self.open_files_flowmaps[batch].names[idx]][:] 387 | flowmap = flowmap.astype(np.float32) 388 | flowmap = torch.from_numpy(flowmap).permute(2, 0, 1) 389 | gt["gtflow"] = flowmap 390 | gt["gtflow_dt"] = ( 391 | self.open_files_flowmaps[batch].ts_to[idx] - self.open_files_flowmaps[batch].ts_from[idx] 392 | ) 393 | gt["gtflow_dt"] = torch.from_numpy(np.asarray(gt["gtflow_dt"])).float() 394 | 395 | gt = self.augment_gt(gt, batch) 396 | 397 | # update window 398 | self.batch_row[batch] += self.input_window 399 | self.batch_pass[batch] += 1 400 | 401 | # break while loop if everything 
went well 402 | break 403 | 404 | # camera matrix for rectified and cropped events 405 | if self.rectify: 406 | K_rect, inv_K_rect = self.format_intrinsics(self.batch_K_rect[batch].copy()) 407 | 408 | # split event list (events with and without gradients) 409 | event_list, event_list_pol_mask, d_event_list, d_event_list_pol_mask = self.split_event_list( 410 | event_list, event_list_pol_mask, self.config["loader"]["max_num_grad_events"] 411 | ) 412 | 413 | # prepare output 414 | output = {} 415 | output["net_input"] = net_input.cpu() 416 | output["event_cnt"] = event_cnt.cpu() 417 | output["event_mask"] = event_mask.cpu() 418 | output["event_list"] = event_list.cpu() 419 | output["event_list_pol_mask"] = event_list_pol_mask.cpu() 420 | output["d_event_list"] = d_event_list.cpu() 421 | output["d_event_list_pol_mask"] = d_event_list_pol_mask.cpu() 422 | if self.rectify: 423 | output["K_rect"] = K_rect 424 | output["inv_K_rect"] = inv_K_rect 425 | for key in gt.keys(): 426 | output[key] = gt[key] 427 | 428 | if self.config["data"]["cache"]: 429 | self.cache.update(self.files[self.batch_idx[batch] % len(self.files)], output) 430 | 431 | return output 432 | -------------------------------------------------------------------------------- /dataloader/utils.py: -------------------------------------------------------------------------------- 1 | from progress.bar import Bar 2 | 3 | 4 | class ProgressBar(Bar): 5 | suffix = "%(percent).1f%%, ETA: %(eta)ds, %(frequency)fHz" 6 | 7 | @property 8 | def frequency(self): 9 | """ 10 | :return: frequency of the processing 11 | """ 12 | return 1 / self.avg 13 | -------------------------------------------------------------------------------- /dsec_submissions/interlaken_00_b.txt: -------------------------------------------------------------------------------- 1 | # from_timestamp_us, to_timestamp_us, file_index 2 | 51648500652, 51648600574, 820 3 | 51649000383, 51649100410, 830 4 | 51649500439, 51649600452, 840 5 | 51650000446, 51650100510, 850 6 | 51650500682, 51650600786, 860 7 | 51651002403, 51651103123, 870 8 | 51651507548, 51651607608, 880 9 | 51652007591, 51652107617, 890 10 | 51652507627, 51652607642, 900 11 | 51653007592, 51653107618, 910 12 | 51653507640, 51653607635, 920 13 | 51654007593, 51654107627, 930 14 | 51654507641, 51654607644, 940 15 | 51655007590, 51655107624, 950 16 | 51655506576, 51655605216, 960 17 | 51656001980, 51656101583, 970 18 | 51656500712, 51656600624, 980 19 | 51657000414, 51657100436, 990 20 | 51657500453, 51657600462, 1000 21 | 51658000434, 51658100469, 1010 22 | 51658500531, 51658600528, 1020 23 | 51659000538, 51659100590, 1030 24 | 51659500691, 51659600718, 1040 25 | 51660000821, 51660100908, 1050 26 | 51660501244, 51660601294, 1060 27 | 51661001436, 51661101512, 1070 28 | 51661501621, 51661601633, 1080 29 | 51662001670, 51662101723, 1090 30 | 51662501812, 51662601820, 1100 31 | 51663001755, 51663101761, 1110 32 | 51663501734, 51663601698, 1120 33 | 51664001599, 51664101614, 1130 34 | 51664501591, 51664601559, 1140 35 | 51665001449, 51665101458, 1150 36 | 51665501384, 51665601350, 1160 37 | 51666001166, 51666101162, 1170 38 | 51666501061, 51666601014, 1180 39 | 51667000858, 51667100853, 1190 40 | 51667500764, 51667600743, 1200 41 | 51668000625, 51668100640, 1210 42 | 51668500635, 51668600621, 1220 43 | 51669000539, 51669100565, 1230 44 | 51669500595, 51669600612, 1240 45 | 51670000676, 51670100746, 1250 46 | 51670500846, 51670600864, 1260 47 | 51671000856, 51671100904, 1270 48 | 51671500961, 51671600966, 1280 49 | 
51672000911, 51672100933, 1290 50 | 51672500987, 51672601036, 1300 51 | 51673001126, 51673101184, 1310 52 | 51673501146, 51673601151, 1320 53 | 51674001077, 51674101089, 1330 54 | 51674501045, 51674601036, 1340 55 | 51675000975, 51675101031, 1350 56 | 51675501218, 51675601212, 1360 57 | 51676000975, 51676100938, 1370 58 | 51676500808, 51676600844, 1380 59 | -------------------------------------------------------------------------------- /dsec_submissions/interlaken_00_b_flag.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tudelft/taming_event_flow/57cbc7e38d6d9f5bd489139ba465609ff83c5a89/dsec_submissions/interlaken_00_b_flag.npy -------------------------------------------------------------------------------- /dsec_submissions/interlaken_01_a.txt: -------------------------------------------------------------------------------- 1 | # from_timestamp_us, to_timestamp_us, file_index 2 | 53193302164, 53193402021, 20 3 | 53193801533, 53193901444, 30 4 | 53194301052, 53194400996, 40 5 | 53194800836, 53194900802, 50 6 | 53195300686, 53195400691, 60 7 | 53195800652, 53195900640, 70 8 | 53196300560, 53196400544, 80 9 | 53196800490, 53196900470, 90 10 | 53197300401, 53197400387, 100 11 | 53197800326, 53197900314, 110 12 | 53198300254, 53198400248, 120 13 | 53198800240, 53198900238, 130 14 | 53199300244, 53199400246, 140 15 | 53199800252, 53199900254, 150 16 | 53200300240, 53200400238, 160 17 | 53200800248, 53200900246, 170 18 | 53201300244, 53201400246, 180 19 | 53201800255, 53201900261, 190 20 | 53202300230, 53202400233, 200 21 | 53202800270, 53202900274, 210 22 | 53203300256, 53203400258, 220 23 | 53203800280, 53203900262, 230 24 | 53204300264, 53204400262, 240 25 | 53204800271, 53204900280, 250 26 | 53205300286, 53205400293, 260 27 | 53205800320, 53205900330, 270 28 | 53206300358, 53206400362, 280 29 | 53206800402, 53206900410, 290 30 | 53207300409, 53207400391, 300 31 | 53207800405, 53207900403, 310 32 | 53208300373, 53208400387, 320 33 | 53208800346, 53208900340, 330 34 | 53209300320, 53209400317, 340 35 | 53209800292, 53209900309, 350 36 | 53210300302, 53210400306, 360 37 | 53210800326, 53210900333, 370 38 | 53211300310, 53211400308, 380 39 | 53211800295, 53211900286, 390 40 | 53212300259, 53212400236, 400 41 | 53212800232, 53212900232, 410 42 | 53213300216, 53213400214, 420 43 | 53213800226, 53213900221, 430 44 | 53214300208, 53214400212, 440 45 | 53217300184, 53217400206, 500 46 | 53217800200, 53217900201, 510 47 | 53218300180, 53218400186, 520 48 | 53218800210, 53218900210, 530 49 | 53219300187, 53219400185, 540 50 | 53219800195, 53219900197, 550 51 | 53220300187, 53220400189, 560 52 | 53220800203, 53220900204, 570 53 | 53221300183, 53221400208, 580 54 | 53221800204, 53221900210, 590 55 | 53222300192, 53222400198, 600 56 | 53222800204, 53222900201, 610 57 | 53223300199, 53223400181, 620 58 | 53223800202, 53223900208, 630 59 | 53224300208, 53224400208, 640 60 | 53224800202, 53224900200, 650 61 | 53225300186, 53225400192, 660 62 | 53225800203, 53225900214, 670 63 | 53226300200, 53226400210, 680 64 | 53226800244, 53226900246, 690 65 | 53227300236, 53227400237, 700 66 | 53227800246, 53227900224, 710 67 | 53228300204, 53228400198, 720 68 | 53228800196, 53228900198, 730 69 | 53229300168, 53229400149, 740 70 | 53229800171, 53229900177, 750 71 | 53230300160, 53230400158, 760 72 | 53230800168, 53230900170, 770 73 | 53263300127, 53263400129, 1420 74 | 53263800123, 53263900121, 1430 75 | 53264300107, 53264400113, 1440 76 | 
53264800127, 53264900130, 1450 77 | 53265300131, 53265400136, 1460 78 | 53265800136, 53265900141, 1470 79 | 53266300135, 53266400146, 1480 80 | 53266800168, 53266900181, 1490 81 | 53267300186, 53267400188, 1500 82 | 53267800199, 53267900202, 1510 83 | 53268300187, 53268400193, 1520 84 | 53268800211, 53268900216, 1530 85 | 53269300190, 53269400206, 1540 86 | 53269800162, 53269900158, 1550 87 | 53270300127, 53270400128, 1560 88 | 53270800144, 53270900150, 1570 89 | 53271300120, 53271400142, 1580 90 | 53271800136, 53271900138, 1590 91 | 53272300122, 53272400130, 1600 92 | 53272800162, 53272900161, 1610 93 | 53273300127, 53273400149, 1620 94 | 53273800139, 53273900140, 1630 95 | 53274300116, 53274400122, 1640 96 | 53274800124, 53274900126, 1650 97 | 53279800123, 53279900133, 1750 98 | 53280300107, 53280400113, 1760 99 | 53280800111, 53280900113, 1770 100 | 53281300107, 53281400117, 1780 101 | 53281800123, 53281900129, 1790 102 | 53282300124, 53282400128, 1800 103 | 53282800140, 53282900126, 1810 104 | 53283300120, 53283400126, 1820 105 | 53283800128, 53283900138, 1830 106 | 53284300116, 53284400118, 1840 107 | 53284800128, 53284900138, 1850 108 | 53285300112, 53285400124, 1860 109 | 53285800144, 53285900150, 1870 110 | -------------------------------------------------------------------------------- /dsec_submissions/interlaken_01_a_flag.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tudelft/taming_event_flow/57cbc7e38d6d9f5bd489139ba465609ff83c5a89/dsec_submissions/interlaken_01_a_flag.npy -------------------------------------------------------------------------------- /dsec_submissions/thun_01_a.txt: -------------------------------------------------------------------------------- 1 | # from_timestamp_us, to_timestamp_us, file_index 2 | 49740900618, 49741000566, 20 3 | 49741400626, 49741500641, 30 4 | 49741900678, 49742000617, 40 5 | 49742400635, 49742500616, 50 6 | 49742900582, 49743000532, 60 7 | 49743400576, 49743500585, 70 8 | 49743900598, 49744000541, 80 9 | 49744400576, 49744500588, 90 10 | 49744900595, 49745000553, 100 11 | 49745400604, 49745500596, 110 12 | 49745900633, 49746000569, 120 13 | 49746400609, 49746500584, 130 14 | 49746900586, 49747000533, 140 15 | 49747400588, 49747500564, 150 16 | 49747900570, 49748000521, 160 17 | 49748400574, 49748500573, 170 18 | 49748900585, 49749000526, 180 19 | 49749400562, 49749500565, 190 20 | -------------------------------------------------------------------------------- /dsec_submissions/thun_01_a_flag.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tudelft/taming_event_flow/57cbc7e38d6d9f5bd489139ba465609ff83c5a89/dsec_submissions/thun_01_a_flag.npy -------------------------------------------------------------------------------- /dsec_submissions/thun_01_b.txt: -------------------------------------------------------------------------------- 1 | # from_timestamp_us, to_timestamp_us, file_index 2 | 49768600704, 49768700706, 140 3 | 49769100705, 49769200699, 150 4 | 49769600710, 49769700710, 160 5 | 49770100686, 49770200688, 170 6 | 49770600702, 49770700702, 180 7 | 49771100682, 49771200682, 190 8 | 49771600663, 49771700662, 200 9 | 49772100647, 49772200649, 210 10 | 49772600667, 49772700669, 220 11 | 49773100632, 49773200630, 230 12 | 49773600660, 49773700665, 240 13 | 49774100661, 49774200662, 250 14 | 49774600679, 49774700685, 260 15 | 49775100667, 49775200685, 270 16 | 49775600660, 49775700660, 280 
17 | 49776100602, 49776200598, 290 18 | 49776600612, 49776700618, 300 19 | 49777100579, 49777200577, 310 20 | 49777600592, 49777700598, 320 21 | 49778100588, 49778200570, 330 22 | 49778600592, 49778700594, 340 23 | 49779100572, 49779200568, 350 24 | 49779600579, 49779700590, 360 25 | 49780100559, 49780200565, 370 26 | 49780600578, 49780700576, 380 27 | 49781100547, 49781200545, 390 28 | 49781600566, 49781700572, 400 29 | 49782100547, 49782200553, 410 30 | 49782600579, 49782700574, 420 31 | 49783100592, 49783200589, 430 32 | 49783600590, 49783700593, 440 33 | 49784100584, 49784200584, 450 34 | 49784600601, 49784700619, 460 35 | 49785100592, 49785200596, 470 36 | 49785600608, 49785700616, 480 37 | 49786100613, 49786200615, 490 38 | 49786600652, 49786700647, 500 39 | 49787100629, 49787200635, 510 40 | 49787600640, 49787700639, 520 41 | 49788100610, 49788200606, 530 42 | 49788600602, 49788700602, 540 43 | 49789100572, 49789200551, 550 44 | 49789600565, 49789700575, 560 45 | 49790100575, 49790200598, 570 46 | 49790600616, 49790700636, 580 47 | 49791100658, 49791200662, 590 48 | 49791600707, 49791700709, 600 49 | 49792100682, 49792200672, 610 50 | 49792600651, 49792700646, 620 51 | 49793100632, 49793200638, 630 52 | 49793600632, 49793700636, 640 53 | 49794100664, 49794200672, 650 54 | 49794600754, 49794700753, 660 55 | 49795100656, 49795200645, 670 56 | 49795600636, 49795700634, 680 57 | 49796100612, 49796200617, 690 58 | 49796600619, 49796700629, 700 59 | 49797100645, 49797200631, 710 60 | 49797600686, 49797700699, 720 61 | 49798100748, 49798200764, 730 62 | 49798600766, 49798700772, 740 63 | 49799100706, 49799200686, 750 64 | 49799600656, 49799700640, 760 65 | 49800100585, 49800200586, 770 66 | 49800600596, 49800700586, 780 67 | 49801100578, 49801200582, 790 68 | 49801600616, 49801700615, 800 69 | 49802100604, 49802200616, 810 70 | 49802600642, 49802700630, 820 71 | 49812600554, 49812700562, 1020 72 | 49813100570, 49813200582, 1030 73 | 49813600617, 49813700630, 1040 74 | 49814100646, 49814200654, 1050 75 | 49814600698, 49814700687, 1060 76 | 49815100677, 49815200683, 1070 77 | -------------------------------------------------------------------------------- /dsec_submissions/thun_01_b_flag.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tudelft/taming_event_flow/57cbc7e38d6d9f5bd489139ba465609ff83c5a89/dsec_submissions/thun_01_b_flag.npy -------------------------------------------------------------------------------- /dsec_submissions/zurich_city_12_a.txt: -------------------------------------------------------------------------------- 1 | # from_timestamp_us, to_timestamp_us, file_index 2 | 57923607612, 57923707606, 20 3 | 57924107583, 57924207609, 30 4 | 57924607602, 57924707609, 40 5 | 57925107606, 57925207588, 50 6 | 57925607613, 57925707611, 60 7 | 57926107604, 57926207590, 70 8 | 57926607607, 57926707614, 80 9 | 57927107595, 57927207589, 90 10 | 57927607598, 57927707592, 100 11 | 57928107577, 57928207591, 110 12 | 57928607596, 57928707599, 120 13 | 57929107572, 57929207586, 130 14 | 57929607599, 57929707598, 140 15 | 57930107575, 57930207593, 150 16 | 57930607598, 57930707604, 160 17 | 57931107573, 57931207583, 170 18 | 57931607584, 57931707591, 180 19 | 57932107564, 57932207578, 190 20 | 57932607595, 57932707577, 200 21 | 57933107566, 57933207572, 210 22 | 57933607585, 57933707588, 220 23 | 57934107573, 57934207567, 230 24 | 57934607600, 57934707595, 240 25 | 57935107560, 57935207570, 250 26 | 57935607579, 
57935707585, 260 27 | 57936107578, 57936207572, 270 28 | 57936607585, 57936707580, 280 29 | 57937107577, 57937207567, 290 30 | 57937607576, 57937707578, 300 31 | 57938107571, 57938207561, 310 32 | 57938607578, 57938707573, 320 33 | 57939107546, 57939207560, 330 34 | 57939607577, 57939707579, 340 35 | 57940107568, 57940207554, 350 36 | 57940607563, 57940707558, 360 37 | 57941107547, 57941207565, 370 38 | 57941607554, 57941707561, 380 39 | 57942107542, 57942207552, 390 40 | 57942607569, 57942707567, 400 41 | 57943107540, 57943207554, 410 42 | 57943607563, 57943707566, 420 43 | 57944107535, 57944207537, 430 44 | 57944607566, 57944707568, 440 45 | 57945107533, 57945207539, 450 46 | 57945607540, 57945707543, 460 47 | 57946107552, 57946207550, 470 48 | 57946607555, 57946707558, 480 49 | 57947107547, 57947207549, 490 50 | 57947607558, 57947707552, 500 51 | 57948107545, 57948207539, 510 52 | 57948607536, 57948707543, 520 53 | 57949107524, 57949207542, 530 54 | 57949607539, 57949707557, 540 55 | 57950107518, 57950207524, 550 56 | 57950607529, 57950707536, 560 57 | 57951107537, 57951207523, 570 58 | 57951607548, 57951707547, 580 59 | 57952107516, 57952207534, 590 60 | 57952607527, 57952707533, 600 61 | 57953107534, 57953207512, 610 62 | 57953607529, 57953707528, 620 63 | 57954107509, 57954207519, 630 64 | 57954607532, 57954707534, 640 65 | 57955107511, 57955207513, 650 66 | 57955607518, 57955707537, 660 67 | 57956107506, 57956207512, 670 68 | 57956607517, 57956707523, 680 69 | 57957107504, 57957207506, 690 70 | 57957607523, 57957707526, 700 71 | 57958107503, 57958207505, 710 72 | 57958607522, 57958707529, 720 73 | 57959107498, 57959207500, 730 74 | 57959607517, 57959707523, 740 75 | 57960107492, 57960207494, 750 76 | 57960607503, 57960707510, 760 77 | -------------------------------------------------------------------------------- /dsec_submissions/zurich_city_12_a_flag.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tudelft/taming_event_flow/57cbc7e38d6d9f5bd489139ba465609ff83c5a89/dsec_submissions/zurich_city_12_a_flag.npy -------------------------------------------------------------------------------- /dsec_submissions/zurich_city_14_c.txt: -------------------------------------------------------------------------------- 1 | # from_timestamp_us, to_timestamp_us, file_index 2 | 55397507504, 55397607514, 1080 3 | 55398007460, 55398107494, 1090 4 | 55398507508, 55398607514, 1100 5 | 55399007460, 55399107490, 1110 6 | 55399507504, 55399607514, 1120 7 | 55400007460, 55400107490, 1130 8 | 55400507504, 55400607507, 1140 9 | 55401007457, 55401107491, 1150 10 | 55401507501, 55401607515, 1160 11 | 55402007457, 55402107483, 1170 12 | 55402507501, 55402607515, 1180 13 | -------------------------------------------------------------------------------- /dsec_submissions/zurich_city_14_c_flag.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tudelft/taming_event_flow/57cbc7e38d6d9f5bd489139ba465609ff83c5a89/dsec_submissions/zurich_city_14_c_flag.npy -------------------------------------------------------------------------------- /dsec_submissions/zurich_city_15_a.txt: -------------------------------------------------------------------------------- 1 | # from_timestamp_us, to_timestamp_us, file_index 2 | 39550400396, 39550500400, 20 3 | 39550900417, 39551000368, 30 4 | 39551400439, 39551500442, 40 5 | 39551900447, 39552000394, 50 6 | 39552400451, 39552500450, 60 7 | 
39552900474, 39553000409, 70 8 | 39553400460, 39553500464, 80 9 | 39553900463, 39554000410, 90 10 | 39554400455, 39554500453, 100 11 | 39554900458, 39555000410, 110 12 | 39555400446, 39555500449, 120 13 | 39555900444, 39556000390, 130 14 | 39556400436, 39556500434, 140 15 | 39556900438, 39557000386, 150 16 | 39557400440, 39557500447, 160 17 | 39557900460, 39558000407, 170 18 | 39558400456, 39558500454, 180 19 | 39558900459, 39559000392, 190 20 | 39559400426, 39559500428, 200 21 | 39559900463, 39560000410, 210 22 | 39560400503, 39560500556, 220 23 | 39560900668, 39561000636, 230 24 | 39561400782, 39561500822, 240 25 | 39561900931, 39562000912, 250 26 | 39562401091, 39562501127, 260 27 | 39562901152, 39563001078, 270 28 | 39576400440, 39576500437, 540 29 | 39576900419, 39577000374, 550 30 | 39577400424, 39577500404, 560 31 | 39577900427, 39578000362, 570 32 | 39578400403, 39578500405, 580 33 | 39578900427, 39579000366, 590 34 | 39579400430, 39579500442, 600 35 | 39579900485, 39580000438, 610 36 | 39580400473, 39580500470, 620 37 | 39580900500, 39581000446, 630 38 | 39581400479, 39581500496, 640 39 | 39581900538, 39582000487, 650 40 | 39582400499, 39582500521, 660 41 | 39582900494, 39583000418, 670 42 | 39583400427, 39583500420, 680 43 | 39583900403, 39584000340, 690 44 | 39584400411, 39584500417, 700 45 | 39584900466, 39585000425, 710 46 | 39585400526, 39585500549, 720 47 | 39585900558, 39586000501, 730 48 | 39586400541, 39586500531, 740 49 | 39586900498, 39587000439, 750 50 | 39587400514, 39587500534, 760 51 | 39587900570, 39588000514, 770 52 | 39588400538, 39588500534, 780 53 | 39588900502, 39589000464, 790 54 | 39597400886, 39597500906, 960 55 | 39597900884, 39598000823, 970 56 | 39598400904, 39598500924, 980 57 | 39598901012, 39599000955, 990 58 | 39599400910, 39599500918, 1000 59 | 39599900843, 39600000774, 1010 60 | 39600400754, 39600500718, 1020 61 | 39600900666, 39601000607, 1030 62 | 39601400580, 39601500566, 1040 63 | 39601900548, 39602000492, 1050 64 | 39602400522, 39602500528, 1060 65 | 39602900545, 39603000484, 1070 66 | 39603400535, 39603500543, 1080 67 | 39603900588, 39604000542, 1090 68 | 39604400619, 39604500620, 1100 69 | 39608900441, 39609000392, 1190 70 | 39609400429, 39609500440, 1200 71 | 39609900484, 39610000424, 1210 72 | 39610400524, 39610500559, 1220 73 | 39610900589, 39611000536, 1230 74 | -------------------------------------------------------------------------------- /dsec_submissions/zurich_city_15_a_flag.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tudelft/taming_event_flow/57cbc7e38d6d9f5bd489139ba465609ff83c5a89/dsec_submissions/zurich_city_15_a_flag.npy -------------------------------------------------------------------------------- /eval_flow.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import mlflow 4 | import torch 5 | 6 | from configs.parser import YAMLParser 7 | from dataloader.h5 import H5Loader 8 | from loss.flow_val import * 9 | from models.model import * 10 | from utils.iwe import compute_pol_iwe 11 | from utils.utils import load_model, create_model_dir, initialize_quant_results 12 | from utils.mlflow import log_config, log_results 13 | from utils.visualization import Visualization 14 | 15 | 16 | def test(args, config_parser): 17 | """ 18 | Main function of the evaluation pipeline for event-based optical flow estimation. 
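Editor's note: it loads the model and settings of an existing MLflow run, streams event data through it, and accumulates the warping-based validation metrics (e.g., AEE, RSAT, FWL) requested in the config.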
19 | :param args: arguments of the script 20 | :param config_parser: YAMLParser object with config data 21 | """ 22 | 23 | mlflow.set_tracking_uri(args.path_mlflow) 24 | 25 | run = mlflow.get_run(args.runid) 26 | config = config_parser.merge_configs(run.data.params) 27 | config = config_parser.combine_entries(config) 28 | 29 | # configs 30 | config["loader"]["batch_size"] = 1 31 | 32 | # create directory for inference results 33 | path_results = create_model_dir(args.path_results, args.runid) 34 | 35 | # store validation settings 36 | eval_id = log_config(path_results, args.runid, config) 37 | 38 | # initialize settings 39 | device = config_parser.device 40 | kwargs = config_parser.loader_kwargs 41 | config["loader"]["device"] = device 42 | 43 | # visualization tool 44 | vis = Visualization(config, eval_id=eval_id, path_results=path_results) 45 | 46 | # data loader 47 | data = H5Loader(config, shuffle=True) 48 | dataloader = torch.utils.data.DataLoader( 49 | data, 50 | drop_last=True, 51 | batch_size=config["loader"]["batch_size"], 52 | collate_fn=data.custom_collate, 53 | worker_init_fn=config_parser.worker_init_fn, 54 | **kwargs, 55 | ) 56 | 57 | # model initialization and settings 58 | num_bins = 2 if config["data"]["voxel"] is None else config["data"]["voxel"] 59 | model = eval(config["model"]["name"])(config["model"].copy(), num_bins) 60 | model = model.to(device) 61 | model, _ = load_model(args.runid, model, device) 62 | model.eval() 63 | 64 | # validation metric 65 | criteria = eval(config["metrics"]["warping"])(config, device) 66 | val_results = {} 67 | 68 | # inference loop 69 | end_test = False 70 | with torch.no_grad(): 71 | while not end_test: 72 | for inputs in dataloader: 73 | sequence = data.files[data.batch_idx[0] % len(data.files)].split("/")[-1].split(".")[0] 74 | 75 | if data.new_seq: 76 | data.new_seq = False 77 | model.reset_states() 78 | criteria.reset() 79 | 80 | if config["data"]["mode"] in ["gtflow"] and data.ts_jump_reset: 81 | data.ts_jump_reset = False 82 | model.reset_states() 83 | 84 | # finish inference loop 85 | if data.seq_num >= len(data.files): 86 | end_test = True 87 | break 88 | 89 | # forward pass 90 | x = model(inputs["net_input"].to(device)) 91 | for i in range(len(x["flow"])): 92 | x["flow"][i] = x["flow"][i] * config["loss"]["flow_scaling"] 93 | 94 | # mask flow for visualization 95 | flow_vis = x["flow"][-1].clone() 96 | if config["vis"]["mask_output"]: 97 | flow_vis *= inputs["event_mask"].to(device) 98 | 99 | # image of warped events 100 | iwe = None 101 | if (config["vis"]["enabled"] or config["vis"]["store"]) and ( 102 | config["vis"]["show"] is None or "iwe" in config["vis"]["show"] 103 | ): 104 | iwe = compute_pol_iwe( 105 | flow_vis, 106 | inputs["event_list"].to(device), 107 | config["loader"]["resolution"], 108 | inputs["event_list_pol_mask"].to(device), 109 | round_idx=False, 110 | round_flow=False, 111 | ) 112 | 113 | # update validation criteria 114 | criteria.update( 115 | x["flow"], 116 | inputs["event_list"].to(device), 117 | inputs["event_list_pol_mask"].to(device), 118 | inputs["event_mask"].to(device), 119 | ) 120 | 121 | # prepare for visualization 122 | if config["vis"]["enabled"] or config["vis"]["store"]: 123 | 124 | # dynamic windows 125 | if config["data"]["passes_loss"] > 1 and config["vis"]["dynamic"]: 126 | vis.data["events_dynamic"] = criteria.window_events() 127 | vis.data["iwe_fw_dynamic"] = criteria.window_iwe(mode="forward") 128 | vis.data["iwe_bw_dynamic"] = criteria.window_iwe(mode="backward") 129 | 
vis.data["flow_dynamic"] = criteria.window_flow(mode="forward") 130 | 131 | # accumulated windows 132 | if criteria.num_passes > 1 and criteria.num_passes == config["data"]["passes_loss"]: 133 | vis.data["events_window"] = criteria.window_events() 134 | vis.data["iwe_fw_window"] = criteria.window_iwe(mode="forward") 135 | vis.data["iwe_bw_window"] = criteria.window_iwe(mode="backward") 136 | vis.data["flow_window"] = criteria.window_flow(mode="forward") 137 | 138 | # compute error metrics 139 | vis.data["flow_bw"] = None 140 | val_results = initialize_quant_results(val_results, sequence, config["metrics"]["name"]) 141 | if criteria.num_passes == config["data"]["passes_loss"]: 142 | 143 | compute_metrics = True 144 | if "eval_time" in config["metrics"].keys(): 145 | if ( 146 | data.last_proc_timestamp < config["metrics"]["eval_time"][0] 147 | or data.last_proc_timestamp > config["metrics"]["eval_time"][1] 148 | ): 149 | compute_metrics = False 150 | 151 | if compute_metrics: 152 | 153 | # AEE 154 | if config["data"]["mode"] == "gtflow" and "AEE" in config["metrics"]["name"]: 155 | mask_aee = None 156 | if "mask_aee" in config["metrics"].keys() and config["metrics"]["mask_aee"]: 157 | mask_aee = criteria.window_events().clone().to(device) 158 | 159 | vis.data["flow_bw"] = ( 160 | criteria.window_flow(mode="backward", mask=False) * config["data"]["passes_loss"] 161 | ) 162 | aee = criteria.compute_aee(vis.data["flow_bw"], inputs["gtflow"].to(device), mask=mask_aee) 163 | val_results[sequence]["AEE"]["it"] += 1 164 | val_results[sequence]["AEE"]["metric"] += aee.cpu().numpy() 165 | 166 | # deblurring metrics 167 | for metric in config["metrics"]["name"]: 168 | if metric == "RSAT": 169 | rsat = criteria.rsat() 170 | val_results[sequence][metric]["metric"] += rsat[0].cpu().numpy() 171 | val_results[sequence][metric]["it"] += 1 172 | 173 | elif metric == "FWL": 174 | fwl = criteria.fwl() 175 | val_results[sequence][metric]["metric"] += fwl.cpu().numpy() 176 | val_results[sequence][metric]["it"] += 1 177 | 178 | # reset criteria 179 | criteria.reset() 180 | 181 | # visualization 182 | if config["vis"]["bars"]: 183 | for bar in data.open_files_bar: 184 | bar.next() 185 | if config["vis"]["enabled"] or config["vis"]["store"]: 186 | vis.data["iwe"] = iwe 187 | vis.data["flow"] = flow_vis 188 | vis.step( 189 | inputs, 190 | sequence=sequence, 191 | ts=data.last_proc_timestamp, 192 | show=config["vis"]["show"], 193 | ) 194 | 195 | if config["vis"]["bars"]: 196 | for bar in data.open_files_bar: 197 | bar.finish() 198 | 199 | # store validation config and results 200 | results = {} 201 | for metric in config["metrics"]["name"]: 202 | results[metric] = {} 203 | for key in val_results.keys(): 204 | if val_results[key][metric]["it"] > 0: 205 | results[metric][key] = str(val_results[key][metric]["metric"] / val_results[key][metric]["it"]) 206 | log_results(args.runid, results, path_results, eval_id) 207 | print(results) 208 | 209 | 210 | if __name__ == "__main__": 211 | parser = argparse.ArgumentParser() 212 | parser.add_argument("runid", help="mlflow run") 213 | parser.add_argument( 214 | "--config", 215 | default="configs/eval_flow.yml", 216 | help="config file, overwrites mlflow settings", 217 | ) 218 | parser.add_argument( 219 | "--path_mlflow", 220 | default="", 221 | help="location of the mlflow ui", 222 | ) 223 | parser.add_argument("--path_results", default="results_inference/") 224 | args = parser.parse_args() 225 | 226 | # launch testing 227 | test(args, YAMLParser(args.config)) 228 | 
-------------------------------------------------------------------------------- /loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tudelft/taming_event_flow/57cbc7e38d6d9f5bd489139ba465609ff83c5a89/loss/__init__.py -------------------------------------------------------------------------------- /loss/flow.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | import os 3 | import sys 4 | 5 | import math 6 | import torch 7 | 8 | parent_dir_name = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 9 | sys.path.append(parent_dir_name) 10 | 11 | from utils.iwe import event_propagation, get_event_flow, purge_unfeasible, get_interpolation, interpolate 12 | 13 | 14 | class BaseEventWarping(torch.nn.Module): 15 | """ 16 | Base class for the contrast maximization loss. 17 | """ 18 | 19 | def __init__(self, config, device, loss_scaling=True, border_compensation=True): 20 | super(BaseEventWarping, self).__init__() 21 | self.device = device 22 | self.config = config 23 | self.loss_scaling = loss_scaling 24 | self.border_compensation = border_compensation 25 | self.res = config["loader"]["resolution"] 26 | self.batch_size = config["loader"]["batch_size"] 27 | self.flow_spat_smooth_weight = config["loss"]["flow_spat_smooth_weight"] 28 | self.flow_temp_smooth_weight = config["loss"]["flow_temp_smooth_weight"] 29 | 30 | self._passes = 0 31 | self._num_flows = None 32 | self._flow_maps_x = None 33 | self._flow_maps_y = None 34 | 35 | # warping indices (for temporal consistency) 36 | my, mx = torch.meshgrid(torch.arange(self.res[0]), torch.arange(self.res[1])) 37 | indices = torch.stack([my, mx], dim=0).unsqueeze(0) 38 | self.indices = indices.float().to(device) 39 | self.indices_mask = torch.ones(self.indices.shape).to(device) 40 | 41 | # timescales for loss computation 42 | self.passes_loss = [] 43 | for s in range(config["data"]["scales_loss"]): 44 | self.passes_loss.append(config["data"]["passes_loss"] // (2**s)) 45 | 46 | def update_base(self, flow_list): 47 | """ 48 | Initialize/Update container lists of events and flow maps for forward warping. 49 | :param flow_list: [[batch_size x 2 x H x W]] list of optical flow (x, y) maps 50 | """ 51 | 52 | if self._num_flows is None: 53 | self._num_flows = len(flow_list) 54 | 55 | # update optical flow maps 56 | if self._flow_maps_x is None: 57 | self._flow_maps_x = [] 58 | self._flow_maps_y = [] 59 | 60 | for i, flow in enumerate(flow_list): 61 | if i == len(self._flow_maps_x): 62 | self._flow_maps_x.append(flow[:, 0:1, :, :]) 63 | self._flow_maps_y.append(flow[:, 1:2, :, :]) 64 | else: 65 | self._flow_maps_x[i] = torch.cat([self._flow_maps_x[i], flow[:, 0:1, :, :]], dim=1) 66 | self._flow_maps_y[i] = torch.cat([self._flow_maps_y[i], flow[:, 1:2, :, :]], dim=1) 67 | 68 | def reset_base(self): 69 | """ 70 | Reset lists. 71 | """ 72 | 73 | self._passes = 0 74 | self._flow_maps_x = None 75 | self._flow_maps_y = None 76 | 77 | @property 78 | def num_passes(self): 79 | return self._passes 80 | 81 | def iwe_formatting(self, warped_events, pol_mask, ts_list, tref, ts_scaling, interp_zeros=None, iwe_zeros=None): 82 | """ 83 | Computes the images of warped events with event count and (accumulated) event timestamp information. 
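Editor's note: timestamp contributions are normalized so that an event at the reference time tref receives weight 1, decaying linearly to 0 at a distance of ts_scaling from it (see norm_ts below).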
:param warped_events: [batch_size x N x 2] warped events 85 | :param pol_mask: [batch_size x 4*N x 2] polarity mask of the (forward) warped events 86 | :param ts_list: [batch_size x N x 1] event timestamp [0, max_ts] 87 | :param tref: reference timestamp 88 | :param ts_scaling: width of the timestamp normalization window 89 | :return iwe: image of warped events 90 | :return iwe_ts: image of averaged timestamps 91 | """ 92 | 93 | # normalize timestamps (the ref timestamp is 1, the rest decreases linearly towards 0 on the extremes) 94 | norm_ts = ts_list.clone() 95 | norm_ts = 1 - torch.abs(tref - norm_ts) / ts_scaling 96 | 97 | # interpolate forward 98 | idx, weights = get_interpolation(warped_events, self.res, zeros=interp_zeros) 99 | 100 | # per-polarity image of (forward) warped events 101 | iwe_pos = interpolate(idx, weights, self.res, polarity_mask=pol_mask[:, :, 0:1], zeros=iwe_zeros) 102 | iwe_neg = interpolate(idx, weights, self.res, polarity_mask=pol_mask[:, :, 1:2], zeros=iwe_zeros) 103 | iwe = torch.cat([iwe_pos, iwe_neg], dim=1) 104 | 105 | # image of (forward) warped averaged timestamps 106 | iwe_pos_ts = interpolate(idx, weights * norm_ts, self.res, polarity_mask=pol_mask[:, :, 0:1], zeros=iwe_zeros) 107 | iwe_neg_ts = interpolate(idx, weights * norm_ts, self.res, polarity_mask=pol_mask[:, :, 1:2], zeros=iwe_zeros) 108 | iwe_ts = torch.cat([iwe_pos_ts, iwe_neg_ts], dim=1) 109 | 110 | return iwe, iwe_ts 111 | 112 | def focus_loss(self, iwe, iwe_ts): 113 | """ 114 | Focus loss based on per-pixel average timestamps, optionally scaled by the number of pixels with events in the image space. 115 | See "Self-Supervised Learning of Event-Based Optical Flow with Spiking Neural Networks", 116 | Hagenaars and Paredes-Valles et al., NeurIPS 2021 117 | :param iwe: [batch_size x 2 x H x W] image of warped events 118 | :param iwe_ts: [batch_size x 2 x H x W] image of averaged timestamps 119 | :return loss: loss value 120 | """ 121 | 122 | iwe_ts = iwe_ts.view(iwe_ts.shape[0], 2, -1) 123 | loss = torch.sum(iwe_ts[:, 0, :] ** 2, dim=1) + torch.sum(iwe_ts[:, 1, :] ** 2, dim=1) 124 | if self.loss_scaling: 125 | nonzero_px = torch.sum(iwe, dim=1, keepdim=True).bool() 126 | nonzero_px = nonzero_px.view(nonzero_px.shape[0], -1) 127 | loss /= torch.sum(nonzero_px, dim=1) + 1e-9 128 | 129 | return torch.sum(loss) 130 | 131 | def flow_temporal_smoothing(self): 132 | """ 133 | (Temporal) Scaled Charbonnier smoothness prior on the estimated optical flow vectors. 
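Editor's note (conceptual): for every pair of consecutive flow maps f_t and f_{t+1}, each pixel is warped with f_t and the penalty is the Charbonnier distance sqrt((f_t - warped f_{t+1})^2 + eps), averaged over the pixels whose warped location falls inside the image.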
134 | :return loss: smoothing loss value 135 | """ 136 | 137 | loss = 0 138 | for i in range(self._num_flows): 139 | for j in range(self._flow_maps_x[i].shape[1] - 1): 140 | 141 | # compute (backward) warping indices 142 | flow = torch.stack([self._flow_maps_y[i][:, j, ...], self._flow_maps_x[i][:, j, ...]], dim=1) 143 | warped_indices = self.indices + flow 144 | warped_indices = warped_indices.view(self.batch_size, 2, -1).permute(0, 2, 1) 145 | 146 | # ignore pixels that go out of the image space 147 | warping_mask = ( 148 | (warped_indices[..., 0] >= 0) 149 | * (warped_indices[..., 0] <= self.res[0] - 1.0) 150 | * (warped_indices[..., 1] >= 0) 151 | * (warped_indices[..., 1] <= self.res[1] - 1.0) 152 | ) 153 | 154 | # (backward) warp the next flow 155 | warped_flow = get_event_flow( 156 | self._flow_maps_x[i][:, j + 1, ...], self._flow_maps_y[i][:, j + 1, ...], warped_indices 157 | ) 158 | warped_flow = warped_flow.permute(0, 2, 1).view(self.batch_size, 2, self.res[0], self.res[1]) 159 | 160 | # compute flow temporal consistency (charbonnier) 161 | flow_dt = torch.sqrt((flow - warped_flow) ** 2 + 1e-9) 162 | flow_dt = torch.sum(flow_dt, dim=1, keepdim=True).view(self.batch_size, -1) 163 | loss += torch.sum(flow_dt * warping_mask, dim=1) / (torch.sum(warping_mask, dim=1) + 1e-9) 164 | 165 | loss /= self._num_flows 166 | loss /= self._passes - 1 167 | 168 | return self.flow_temp_smooth_weight * loss.sum() 169 | 170 | def flow_spatial_smoothing(self): 171 | """ 172 | (Spatial) Scaled Charbonnier smoothness prior on the estimated optical flow vectors. 173 | :return loss: smoothing loss value 174 | """ 175 | 176 | loss = 0 177 | for i in range(self._num_flows): 178 | 179 | # forward differences (horizontal, vertical, and diagonals) 180 | flow_x_dx = self._flow_maps_x[i][:, :, :, :-1] - self._flow_maps_x[i][:, :, :, 1:] 181 | flow_y_dx = self._flow_maps_y[i][:, :, :, :-1] - self._flow_maps_y[i][:, :, :, 1:] 182 | flow_x_dy = self._flow_maps_x[i][:, :, :-1, :] - self._flow_maps_x[i][:, :, 1:, :] 183 | flow_y_dy = self._flow_maps_y[i][:, :, :-1, :] - self._flow_maps_y[i][:, :, 1:, :] 184 | flow_x_dxdy_dr = self._flow_maps_x[i][:, :, :-1, :-1] - self._flow_maps_x[i][:, :, 1:, 1:] 185 | flow_y_dxdy_dr = self._flow_maps_y[i][:, :, :-1, :-1] - self._flow_maps_y[i][:, :, 1:, 1:] 186 | flow_x_dxdy_ur = self._flow_maps_x[i][:, :, 1:, :-1] - self._flow_maps_x[i][:, :, :-1, 1:] 187 | flow_y_dxdy_ur = self._flow_maps_y[i][:, :, 1:, :-1] - self._flow_maps_y[i][:, :, :-1, 1:] 188 | 189 | # compute flow spatial consistency (charbonnier) 190 | flow_dx = torch.sqrt((flow_x_dx) ** 2 + 1e-6) + torch.sqrt((flow_y_dx) ** 2 + 1e-6) 191 | flow_dy = torch.sqrt((flow_x_dy) ** 2 + 1e-6) + torch.sqrt((flow_y_dy) ** 2 + 1e-6) 192 | flow_dxdy_dr = torch.sqrt((flow_x_dxdy_dr) ** 2 + 1e-6) + torch.sqrt((flow_y_dxdy_dr) ** 2 + 1e-6) 193 | flow_dxdy_ur = torch.sqrt((flow_x_dxdy_ur) ** 2 + 1e-6) + torch.sqrt((flow_y_dxdy_ur) ** 2 + 1e-6) 194 | 195 | flow_dx = flow_dx.view(self.batch_size, self._passes, -1) 196 | flow_dy = flow_dy.view(self.batch_size, self._passes, -1) 197 | flow_dxdy_dr = flow_dxdy_dr.view(self.batch_size, self._passes, -1) 198 | flow_dxdy_ur = flow_dxdy_ur.view(self.batch_size, self._passes, -1) 199 | 200 | flow_dx = flow_dx.mean(2).mean(1) 201 | flow_dy = flow_dy.mean(2).mean(1) 202 | flow_dxdy_ur = flow_dxdy_ur.mean(2).mean(1) 203 | flow_dxdy_dr = flow_dxdy_dr.mean(2).mean(1) 204 | 205 | loss += (flow_dx + flow_dy + flow_dxdy_dr + flow_dxdy_ur) / 4 206 | 207 | loss /= self._num_flows 208 | 209 | return 
self.flow_spat_smooth_weight * loss.sum() 210 | 211 | @abstractmethod 212 | def forward(self): 213 | raise NotImplementedError 214 | 215 | 216 | class Linear(BaseEventWarping): 217 | """ 218 | Contrast maximization loss from Hagenaars and Paredes-Valles et al. (NeurIPS 2021). 219 | """ 220 | 221 | def __init__(self, config, device, loss_scaling=True): 222 | super().__init__(config, device, loss_scaling=loss_scaling) 223 | self._event_ts = [] 224 | self._event_loc = [] 225 | self._event_flow = [] 226 | self._event_pol_mask = [] 227 | 228 | self._d_event_ts = [] 229 | self._d_event_loc = [] 230 | self._d_event_flow = [] 231 | self._d_event_pol_mask = [] 232 | 233 | def update(self, flow_list, event_list, pol_mask, d_event_list, d_pol_mask): 234 | """ 235 | Initialize/Update container lists of events and flow maps for forward warping. 236 | :param flow_list: [[batch_size x 2 x H x W]] list of optical flow (x, y) maps 237 | :param event_list: [batch_size x N x 4] input events (ts, y, x, p) 238 | :param pol_mask: [batch_size x N x 2] polarity mask (pos, neg) 239 | :param d_event_list: [batch_size x N x 4] detached input events (ts, y, x, p) 240 | :param d_pol_mask: [batch_size x N x 2] detached polarity mask (pos, neg) 241 | """ 242 | 243 | # update base lists (flow maps) 244 | self.update_base(flow_list) 245 | 246 | # update event timestamps 247 | event_list[:, :, 0:1] += self._passes 248 | d_event_list[:, :, 0:1] += self._passes 249 | event_ts = event_list[:, :, 0:1].clone() 250 | d_event_ts = d_event_list[:, :, 0:1].clone() 251 | if self.config["loss"]["round_ts"]: 252 | event_ts[...] = event_ts.min() + 0.5 253 | d_event_ts[...] = d_event_ts.min() + 0.5 254 | self._event_ts.append(event_ts) 255 | self._d_event_ts.append(d_event_ts) 256 | 257 | # update event location 258 | self._event_loc.append(event_list[:, :, 1:3].clone()) 259 | self._d_event_loc.append(d_event_list[:, :, 1:3].clone()) 260 | 261 | # update polarity mask 262 | self._event_pol_mask.append(pol_mask.clone()) 263 | self._d_event_pol_mask.append(d_pol_mask.clone()) 264 | 265 | # update per-event flow vector 266 | event_flow = [] 267 | d_event_flow = [] 268 | for i in range(self._num_flows): 269 | event_flow.append( 270 | get_event_flow( 271 | self._flow_maps_x[i][:, -1, ...], 272 | self._flow_maps_y[i][:, -1, ...], 273 | event_list[:, :, 1:3], 274 | ) 275 | ) 276 | with torch.no_grad(): 277 | d_event_flow.append( 278 | get_event_flow( 279 | self._flow_maps_x[i][:, -1, ...], 280 | self._flow_maps_y[i][:, -1, ...], 281 | d_event_list[:, :, 1:3], 282 | ) 283 | ) 284 | self._event_flow.append(event_flow) 285 | self._d_event_flow.append(d_event_flow) 286 | 287 | # update timestamp index 288 | self._passes += 1 289 | 290 | def reset(self): 291 | """ 292 | Reset lists. 
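Editor's note: intended to be called once a loss window has been consumed, since update() appends to the event and flow buffers on every forward pass.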
293 | """ 294 | 295 | self.reset_base() 296 | self._event_ts = [] 297 | self._event_loc = [] 298 | self._event_flow = [] 299 | self._event_pol_mask = [] 300 | 301 | self._d_event_ts = [] 302 | self._d_event_loc = [] 303 | self._d_event_flow = [] 304 | self._d_event_pol_mask = [] 305 | 306 | def forward(self): 307 | 308 | loss = 0 309 | for s, scale in enumerate(self.passes_loss): 310 | 311 | loss_update = 0 312 | for w in range(2**s): 313 | low_pass = w * scale 314 | high_pass = (w + 1) * scale 315 | 316 | event_ts = torch.cat(self._event_ts[low_pass:high_pass], dim=1) 317 | event_loc = torch.cat(self._event_loc[low_pass:high_pass], dim=1) 318 | ts_list = torch.cat([event_ts for _ in range(4)], dim=1) 319 | 320 | d_event_ts = torch.cat(self._d_event_ts[low_pass:high_pass], dim=1) 321 | d_event_loc = torch.cat(self._d_event_loc[low_pass:high_pass], dim=1) 322 | d_ts_list = torch.cat([d_event_ts for _ in range(4)], dim=1) 323 | 324 | if not self.border_compensation: 325 | event_pol_mask = torch.cat(self._event_pol_mask[low_pass:high_pass], dim=1) 326 | event_pol_mask = torch.cat([event_pol_mask for _ in range(4)], dim=1) 327 | d_event_pol_mask = torch.cat(self._d_event_pol_mask[low_pass:high_pass], dim=1) 328 | d_event_pol_mask = torch.cat([d_event_pol_mask for _ in range(4)], dim=1) 329 | 330 | for i in range(self._num_flows): 331 | if self.border_compensation: 332 | event_pol_mask = torch.cat(self._event_pol_mask[low_pass:high_pass], dim=1) 333 | d_event_pol_mask = torch.cat(self._d_event_pol_mask[low_pass:high_pass], dim=1) 334 | 335 | # event propagation (with grads) 336 | event_flow = torch.cat([flow[i] for flow in self._event_flow[low_pass:high_pass]], dim=1) 337 | fw_events = event_propagation(event_ts, event_loc, event_flow, high_pass) 338 | bw_events = event_propagation(event_ts, event_loc, event_flow, low_pass) 339 | 340 | if self.border_compensation: 341 | fw_events, event_pol_mask = purge_unfeasible(fw_events, event_pol_mask, self.res) 342 | bw_events, event_pol_mask = purge_unfeasible(bw_events, event_pol_mask, self.res) 343 | event_pol_mask = torch.cat([event_pol_mask for _ in range(4)], dim=1) 344 | 345 | fw_iwe, fw_iwe_ts = self.iwe_formatting( 346 | fw_events, 347 | event_pol_mask, 348 | ts_list, 349 | high_pass, 350 | scale, 351 | ) 352 | bw_iwe, bw_iwe_ts = self.iwe_formatting( 353 | bw_events, 354 | event_pol_mask, 355 | ts_list, 356 | low_pass, 357 | scale, 358 | ) 359 | 360 | # event propagation (without grads) 361 | d_event_flow = torch.cat([flow[i] for flow in self._d_event_flow[low_pass:high_pass]], dim=1) 362 | d_fw_events = event_propagation(d_event_ts, d_event_loc, d_event_flow, high_pass) 363 | d_bw_events = event_propagation(d_event_ts, d_event_loc, d_event_flow, low_pass) 364 | if self.border_compensation: 365 | d_fw_events, d_event_pol_mask = purge_unfeasible(d_fw_events, d_event_pol_mask, self.res) 366 | d_bw_events, d_event_pol_mask = purge_unfeasible(d_bw_events, d_event_pol_mask, self.res) 367 | d_event_pol_mask = torch.cat([d_event_pol_mask for _ in range(4)], dim=1) 368 | 369 | d_fw_iwe, d_fw_iwe_ts = self.iwe_formatting( 370 | d_fw_events, 371 | d_event_pol_mask, 372 | d_ts_list, 373 | high_pass, 374 | scale, 375 | ) 376 | d_bw_iwe, d_bw_iwe_ts = self.iwe_formatting( 377 | d_bw_events, 378 | d_event_pol_mask, 379 | d_ts_list, 380 | low_pass, 381 | scale, 382 | ) 383 | 384 | # compute loss (forward) 385 | fw_iwe = fw_iwe + d_fw_iwe 386 | fw_iwe_ts = fw_iwe_ts + d_fw_iwe_ts 387 | fw_iwe_ts = fw_iwe_ts / (fw_iwe + 1e-9) # per-pixel and per-polarity 
timestamps 388 | loss_update += self.focus_loss(fw_iwe, fw_iwe_ts) 389 | 390 | # compute loss (backward) 391 | bw_iwe = bw_iwe + d_bw_iwe 392 | bw_iwe_ts = bw_iwe_ts + d_bw_iwe_ts 393 | bw_iwe_ts = bw_iwe_ts / (bw_iwe + 1e-9) # per-pixel and per-polarity timestamps 394 | loss_update += self.focus_loss(bw_iwe, bw_iwe_ts) 395 | 396 | loss_update /= 2**s # number of deblurring windows for a given scale 397 | loss_update /= 2 # number of deblurring points per deblurring window 398 | loss += loss_update 399 | 400 | # average loss over all flow predictions 401 | loss /= self.config["data"]["scales_loss"] 402 | loss /= self._num_flows 403 | 404 | # spatial smoothing of predicted flow vectors 405 | if self.flow_spat_smooth_weight is not None: 406 | loss += self.flow_spatial_smoothing() 407 | 408 | # temporal consistency of predicted flow vectors 409 | if self.flow_temp_smooth_weight is not None and self._passes > 1: 410 | loss += self.flow_temporal_smoothing() 411 | 412 | return loss 413 | 414 | 415 | class Iterative(BaseEventWarping): 416 | """ 417 | Contrast maximization loss from Hagenaars and Paredes-Valles et al. (NeurIPS 2021) but augmented 418 | with iterative event warping, loss computation at all intermediate (time) points, and multiple temporal scales. 419 | """ 420 | 421 | def __init__(self, config, device, loss_scaling=True): 422 | if config["loss"]["iterative_mode"] == "four": 423 | config["data"]["passes_loss"] *= 2 424 | 425 | super().__init__(config, device, loss_scaling=loss_scaling) 426 | self._event_ts = [] 427 | self._event_loc = [] 428 | self._event_pol_mask = [] 429 | 430 | self._d_event_ts = [] 431 | self._d_event_loc = [] 432 | self._d_event_pol_mask = [] 433 | 434 | self.delta_passes = [] 435 | for passes in self.passes_loss: 436 | if config["loss"]["iterative_mode"] == "one": 437 | self.delta_passes.append(passes // 1) 438 | elif config["loss"]["iterative_mode"] == "two": 439 | self.delta_passes.append(passes // 2) 440 | elif config["loss"]["iterative_mode"] == "four": 441 | self.delta_passes.append(passes // 4) 442 | 443 | def update(self, flow_list, event_list, pol_mask, d_event_list, d_pol_mask): 444 | """ 445 | Initialize/Update container lists of events and flow maps for forward warping. 446 | :param flow_list: [[batch_size x 2 x H x W]] list of optical flow (x, y) maps 447 | :param event_list: [batch_size x N x 4] input events (ts, y, x, p) 448 | :param pol_mask: [batch_size x N x 2] polarity mask (pos, neg) 449 | :param d_event_list: [batch_size x N x 4] detached input events (ts, y, x, p) 450 | :param d_pol_mask: [batch_size x N x 2] detached polarity mask (pos, neg) 451 | """ 452 | 453 | # update base lists (event data, flow maps, event masks) 454 | self.update_base(flow_list) 455 | 456 | # update event timestamps 457 | event_list[:, :, 0:1] += self._passes 458 | d_event_list[:, :, 0:1] += self._passes 459 | event_ts = event_list[:, :, 0:1].clone() 460 | d_event_ts = d_event_list[:, :, 0:1].clone() 461 | if self.config["loss"]["round_ts"]: 462 | event_ts[...] = event_ts.min() + 0.5 463 | d_event_ts[...] 
= d_event_ts.min() + 0.5 464 | self._event_ts.append(event_ts) 465 | self._d_event_ts.append(d_event_ts) 466 | 467 | # update event locations 468 | self._event_loc.append(event_list[:, :, 1:3].clone()) 469 | self._d_event_loc.append(d_event_list[:, :, 1:3].clone()) 470 | 471 | # update event polarity masks 472 | self._event_pol_mask.append(pol_mask.clone()) 473 | self._d_event_pol_mask.append(d_pol_mask.clone()) 474 | 475 | # update timestamp index 476 | self._passes += 1 477 | 478 | def reset(self): 479 | """ 480 | Reset lists. 481 | """ 482 | 483 | self.reset_base() 484 | self._event_ts = [] 485 | self._event_loc = [] 486 | self._event_pol_mask = [] 487 | 488 | self._d_event_ts = [] 489 | self._d_event_loc = [] 490 | self._d_event_pol_mask = [] 491 | 492 | def update_warping_indices(self, tref, t, cnt, mode="forward"): 493 | """ 494 | Updates indices for forward/backward iterative event warping. 495 | :param tref: reference timestamp for deblurring 496 | :param t: current (starting) timestamp 497 | :param cnt: counter that counts the number of warping passes a set of events have undergone 498 | :param mode: "forward" or "backward" warping mode 499 | :return cnt: updated counter 500 | :return break_flag: flag indicating whether the warping should be performed 501 | :return sampling_idx: indices for flow sampling 502 | :return warping_ts: reference time for current warping 503 | """ 504 | 505 | warping_ts = t + cnt 506 | sampling_idx = t + cnt 507 | if mode == "forward": 508 | warping_ts += 1 509 | break_flag = t + cnt < tref 510 | cnt += 1 511 | 512 | elif mode == "backward": 513 | break_flag = t + cnt >= tref 514 | cnt -= 1 515 | 516 | else: 517 | raise ValueError("Unknown warping mode: {}".format(mode)) 518 | 519 | return cnt, break_flag, sampling_idx, warping_ts 520 | 521 | def event_warping( 522 | self, 523 | tref, 524 | t, 525 | i, 526 | mode, 527 | buffer_event_loc, 528 | buffer_event_ts, 529 | buffer_event_pol_mask, 530 | warped_events, 531 | warped_events_ts, 532 | warped_events_mask, 533 | ): 534 | """ 535 | Perform forward/backward iterative event warping. 
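Editor's note: events are warped one pass at a time, resampling the flow map of each intermediate timestamp, so trajectories are piecewise linear rather than a single linear extrapolation.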
536 | :param tref: reference timestamp for deblurring 537 | :param t: current (starting) timestamp 538 | :param i: index indicating which optical flow map to use 539 | :param mode: "forward" or "backward" warping mode 540 | :param buffer_event_loc: [batch_size x N x 2] buffer for event locations 541 | :param buffer_event_ts: [batch_size x N x 1] buffer for event timestamps 542 | :param buffer_event_pol_mask: [batch_size x N x 2] buffer for event polarity masks 543 | :param warped_events: [[batch_size x N x 2]] list containing warped events at all trefs 544 | :param warped_events_ts: [[batch_size x N x 1]] list containing warped events timestamp at all trefs 545 | :param warped_events_mask: [[batch_size x N x 2]] list containing polarity masks for warped events at all trefs 546 | :return warped_events: [[batch_size x N x 2]] updated list containing warped events at all trefs 547 | :return warped_events_ts: [[batch_size x N x 1]] updated list containing warped events timestamp at all trefs 548 | :return warped_events_mask: [[batch_size x N x 2]] updated list containing polarity masks for warped events at all trefs 549 | """ 550 | 551 | cnt = 0 552 | event_loc = buffer_event_loc[t].clone() 553 | event_warp_ts = buffer_event_ts[t].clone() 554 | event_pol_mask = buffer_event_pol_mask[t].clone() 555 | while True: 556 | cnt, break_flag, sampling_idx, warping_ts = self.update_warping_indices(tref, t, cnt, mode=mode) 557 | if not break_flag: 558 | break 559 | 560 | # sample optical flow 561 | event_flow = get_event_flow( 562 | self._flow_maps_x[i][:, sampling_idx, ...], 563 | self._flow_maps_y[i][:, sampling_idx, ...], 564 | event_loc, 565 | ) 566 | 567 | # event warping 568 | event_loc = event_propagation( 569 | event_warp_ts, 570 | event_loc, 571 | event_flow, 572 | warping_ts, 573 | ) 574 | event_warp_ts[...] 
= warping_ts 575 | event_loc, event_pol_mask = purge_unfeasible( 576 | event_loc, 577 | event_pol_mask, 578 | self.res, 579 | ) 580 | 581 | # update warping information (when in range) 582 | warped_events[warping_ts][t] = event_loc.clone() 583 | warped_events_ts[warping_ts][t] = buffer_event_ts[t].clone() 584 | warped_events_mask[warping_ts][t] = event_pol_mask.clone() 585 | 586 | return warped_events, warped_events_ts, warped_events_mask 587 | 588 | def forward(self): 589 | 590 | loss = 0 591 | max_passes = max(self.passes_loss) 592 | for i in range(self._num_flows): 593 | none_list = [None for _ in range(max_passes)] 594 | 595 | # iterative event warping 596 | event_ts = [none_list.copy() for _ in range(max_passes + 1)] 597 | event_loc = [none_list.copy() for _ in range(max_passes + 1)] 598 | event_pol_mask = [none_list.copy() for _ in range(max_passes + 1)] 599 | for t in range(max_passes): 600 | event_loc, event_ts, event_pol_mask = self.event_warping( 601 | max_passes, 602 | t, 603 | i, 604 | "forward", 605 | self._event_loc, 606 | self._event_ts, 607 | self._event_pol_mask, 608 | event_loc, 609 | event_ts, 610 | event_pol_mask, 611 | ) 612 | event_loc, event_ts, event_pol_mask = self.event_warping( 613 | 0, 614 | t, 615 | i, 616 | "backward", 617 | self._event_loc, 618 | self._event_ts, 619 | self._event_pol_mask, 620 | event_loc, 621 | event_ts, 622 | event_pol_mask, 623 | ) 624 | 625 | # detached iterative event warping 626 | d_event_ts = [none_list.copy() for _ in range(max_passes + 1)] 627 | d_event_loc = [none_list.copy() for _ in range(max_passes + 1)] 628 | d_event_pol_mask = [none_list.copy() for _ in range(max_passes + 1)] 629 | with torch.no_grad(): 630 | for t in range(max_passes): 631 | d_event_loc, d_event_ts, d_event_pol_mask = self.event_warping( 632 | max_passes, 633 | t, 634 | i, 635 | "forward", 636 | self._d_event_loc, 637 | self._d_event_ts, 638 | self._d_event_pol_mask, 639 | d_event_loc, 640 | d_event_ts, 641 | d_event_pol_mask, 642 | ) 643 | d_event_loc, d_event_ts, d_event_pol_mask = self.event_warping( 644 | 0, 645 | t, 646 | i, 647 | "backward", 648 | self._d_event_loc, 649 | self._d_event_ts, 650 | self._d_event_pol_mask, 651 | d_event_loc, 652 | d_event_ts, 653 | d_event_pol_mask, 654 | ) 655 | 656 | # learning from multiple temporal scales (i.e., different amounts of blur) 657 | for s, scale in enumerate(self.passes_loss): 658 | 659 | loss_update = 0 660 | for w in range(2**s): 661 | low_pass = w * scale 662 | high_pass = (w + 1) * scale 663 | 664 | low_tref = low_pass 665 | high_tref = high_pass + 1 666 | if self.config["loss"]["iterative_mode"] == "four": 667 | low_tref = low_pass + self.delta_passes[s] 668 | high_tref = low_pass + 3 * self.delta_passes[s] + 1 669 | 670 | # merge event masks (to deal with partially-observable edges) 671 | if self.border_compensation: 672 | shared_event_pol_mask = none_list.copy() 673 | shared_d_event_pol_mask = none_list.copy() 674 | for t in range(low_tref, high_tref - 1): 675 | tmp_event_pol_mask = event_pol_mask[low_tref][t].clone() 676 | tmp_d_event_pol_mask = d_event_pol_mask[low_tref][t].clone() 677 | for tref in range(low_tref + 1, high_tref): 678 | tmp_event_pol_mask *= event_pol_mask[tref][t].clone() 679 | tmp_d_event_pol_mask *= d_event_pol_mask[tref][t].clone() 680 | shared_event_pol_mask[t] = tmp_event_pol_mask 681 | shared_d_event_pol_mask[t] = tmp_d_event_pol_mask 682 | 683 | # compute loss at all intermediate points 684 | for tref in range(low_tref, high_tref): 685 | low_extreme = max(low_pass, tref - 
self.delta_passes[s]) 686 | high_extreme = min(high_pass, tref + self.delta_passes[s]) 687 | 688 | # image of warped events (with grads) 689 | ts_list = torch.cat(event_ts[tref][low_extreme:high_extreme], dim=1) 690 | warped_events = torch.cat(event_loc[tref][low_extreme:high_extreme], dim=1) 691 | if self.border_compensation: 692 | warped_pol_mask = torch.cat(shared_event_pol_mask[low_extreme:high_extreme], dim=1) 693 | else: 694 | warped_pol_mask = torch.cat(event_pol_mask[tref][low_extreme:high_extreme], dim=1) 695 | 696 | ts_list = torch.cat([ts_list for _ in range(4)], dim=1) 697 | warped_pol_mask = torch.cat([warped_pol_mask for _ in range(4)], dim=1) 698 | iwe, iwe_ts = self.iwe_formatting( 699 | warped_events, 700 | warped_pol_mask, 701 | ts_list, 702 | tref, 703 | self.delta_passes[s], 704 | ) 705 | 706 | # image of warped events (without grads) 707 | d_ts_list = torch.cat(d_event_ts[tref][low_extreme:high_extreme], dim=1) 708 | d_warped_events = torch.cat(d_event_loc[tref][low_extreme:high_extreme], dim=1) 709 | if self.border_compensation: 710 | d_warped_pol_mask = torch.cat(shared_d_event_pol_mask[low_extreme:high_extreme], dim=1) 711 | else: 712 | d_warped_pol_mask = torch.cat(d_event_pol_mask[tref][low_extreme:high_extreme], dim=1) 713 | 714 | d_ts_list = torch.cat([d_ts_list for _ in range(4)], dim=1) 715 | d_warped_pol_mask = torch.cat([d_warped_pol_mask for _ in range(4)], dim=1) 716 | d_iwe, d_iwe_ts = self.iwe_formatting( 717 | d_warped_events, 718 | d_warped_pol_mask, 719 | d_ts_list, 720 | tref, 721 | self.delta_passes[s], 722 | ) 723 | 724 | # combination of images of warped events (with and without grads) 725 | iwe = iwe + d_iwe 726 | iwe_ts = iwe_ts + d_iwe_ts 727 | iwe_ts = iwe_ts / (iwe + 1e-9) # per-pixel and per-polarity timestamps 728 | loss_update += self.focus_loss(iwe, iwe_ts) 729 | 730 | loss_update /= 2**s # number of deblurring windows for a given scale 731 | loss_update /= 2 * self.delta_passes[s] + 1 # number of deblurring points per deblurring window 732 | loss += loss_update 733 | 734 | # average loss over all flow predictions and deblurring points 735 | loss /= self.config["data"]["scales_loss"] 736 | loss /= self._num_flows 737 | 738 | # spatial smoothing of predicted flow vectors 739 | if self.flow_spat_smooth_weight is not None: 740 | loss += self.flow_spatial_smoothing() 741 | 742 | # temporal consistency of predicted flow vectors 743 | if self.flow_temp_smooth_weight is not None and self._passes > 1: 744 | loss += self.flow_temporal_smoothing() 745 | 746 | return loss 747 | -------------------------------------------------------------------------------- /loss/flow_val.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import torch 5 | 6 | parent_dir_name = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 7 | sys.path.append(parent_dir_name) 8 | 9 | from utils.iwe import event_propagation, get_event_flow, purge_unfeasible, get_interpolation, interpolate 10 | 11 | 12 | class BaseValidation(torch.nn.Module): 13 | """ 14 | Base class for validation metrics. 
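Subclasses accumulate the events, flow maps, and event masks of a validation window and evaluate warping-based metrics (e.g., FWL, RSAT) on the resulting images of warped events.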
15 | """ 16 | 17 | def __init__(self, config, device): 18 | super(BaseValidation, self).__init__() 19 | self.res = config["loader"]["resolution"] 20 | self.device = device 21 | self.config = config 22 | 23 | self._passes = 0 24 | self._event_ts = None 25 | self._event_loc = None 26 | self._event_pol_mask = None 27 | 28 | self._flow_maps_x = None 29 | self._flow_maps_y = None 30 | self._event_mask = None 31 | 32 | # warping indices (for forward-propagated flow) 33 | my, mx = torch.meshgrid(torch.arange(self.res[0]), torch.arange(self.res[1])) 34 | indices = torch.stack([my, mx], dim=0).unsqueeze(0) 35 | self.indices = indices.float().view(1, 2, -1).permute(0, 2, 1).to(device) 36 | self.indices_map = self.indices.clone().permute(0, 2, 1).view(1, 2, self.res[0], self.res[1]) 37 | self.indices_mask = torch.ones((1, self.res[0] * self.res[1], 1)).to(device) 38 | 39 | @property 40 | def num_passes(self): 41 | return self._passes 42 | 43 | def forward_prop_flow(self, i, tref, flow_maps_x, flow_maps_y): 44 | """ 45 | Forward propagation of the estimated optical flow using bilinear interpolation. 46 | :param i: time at which the flow map to be warped is defined 47 | :param tref: reference time for the forward propagation 48 | :return warped_flow_x: [[batch_size x 1 x H x W]] warped, horizontal optical flow map 49 | :return warped_flow_y: [batch_size x 1 x H x W] warped, vertical optical flow map 50 | """ 51 | 52 | # sample per-pixel optical flow 53 | indices_mask = self.indices_mask.clone() 54 | indices_flow = get_event_flow(flow_maps_x[:, i, ...], flow_maps_y[:, i, ...], self.indices) 55 | 56 | # optical flow (forward) propagation 57 | warped_indices = event_propagation(i, self.indices, indices_flow, tref) 58 | warped_indices, indices_mask = purge_unfeasible(warped_indices, indices_mask, self.res) 59 | 60 | # (bilinearly) interpolate forward 61 | indices_mask = torch.cat([indices_mask for _ in range(4)], dim=1) 62 | indices_flow = torch.cat([indices_flow for _ in range(4)], dim=1) 63 | interp_warped_indices, interp_weights = get_interpolation(warped_indices, self.res) 64 | warped_weights = interpolate(interp_warped_indices, interp_weights, self.res, polarity_mask=indices_mask) 65 | warped_flow_y = interpolate( 66 | interp_warped_indices, interp_weights * indices_flow[..., 0:1], self.res, polarity_mask=indices_mask 67 | ) 68 | warped_flow_x = interpolate( 69 | interp_warped_indices, interp_weights * indices_flow[..., 1:2], self.res, polarity_mask=indices_mask 70 | ) 71 | warped_flow_y /= warped_weights + 1e-9 72 | warped_flow_x /= warped_weights + 1e-9 73 | 74 | return warped_flow_x, warped_flow_y 75 | 76 | def update_base(self, flow_list, event_list, pol_mask, event_mask): 77 | """ 78 | Initialize/Update container lists of events and flow maps for forward warping. 79 | :param flow_list: [[batch_size x 2 x H x W]] list of optical flow (x, y) maps 80 | :param event_list: [batch_size x N x 4] input events (ts, y, x, p) 81 | :param pol_mask: [batch_size x N x 2] polarity mask (pos, neg) 82 | :param event_mask: [batch_size x 1 x H x W] event mask 83 | """ 84 | 85 | # update event timestamps 86 | event_list[:, :, 0:1] += self._passes # only nonzero second time 87 | event_ts = event_list[:, :, 0:1].clone() 88 | if self.config["loss"]["round_ts"]: 89 | event_ts[...] 
= event_ts.min() + 0.5 90 | 91 | if self._event_ts is None: 92 | self._event_ts = event_ts 93 | self._event_loc = event_list[:, :, 1:3].clone() 94 | self._event_pol_mask = pol_mask.clone() 95 | else: 96 | self._event_ts = torch.cat([self._event_ts, event_ts], dim=1) 97 | self._event_loc = torch.cat([self._event_loc, event_list[:, :, 1:3].clone()], dim=1) 98 | self._event_pol_mask = torch.cat([self._event_pol_mask, pol_mask.clone()], dim=1) 99 | 100 | # update optical flow maps 101 | flow = flow_list[-1] # only highest resolution flow 102 | if self._flow_maps_x is None: 103 | self._flow_maps_x = flow[:, 0:1, :, :] 104 | self._flow_maps_y = flow[:, 1:2, :, :] 105 | else: 106 | self._flow_maps_x = torch.cat([self._flow_maps_x, flow[:, 0:1, :, :]], dim=1) 107 | self._flow_maps_y = torch.cat([self._flow_maps_y, flow[:, 1:2, :, :]], dim=1) 108 | 109 | # update internal smoothing mask 110 | if self._event_mask is None: 111 | self._event_mask = event_mask 112 | else: 113 | self._event_mask = torch.cat([self._event_mask, event_mask], dim=1) 114 | 115 | def reset_base(self): 116 | """ 117 | Reset lists. 118 | """ 119 | 120 | self._passes = 0 121 | self._event_ts = None 122 | self._event_loc = None 123 | self._event_pol_mask = None 124 | 125 | self._flow_maps_x = None 126 | self._flow_maps_y = None 127 | self._event_mask = None 128 | 129 | def window_events_base(self, round_idx=False): 130 | """ 131 | :param round_idx: if True, round the event coordinates to the nearest integer. 132 | :return: image-like representation of all the events in the validation time/event window. 133 | """ 134 | 135 | pol_mask_list = self._event_pol_mask 136 | if not round_idx: 137 | pol_mask_list = torch.cat([pol_mask_list for i in range(4)], dim=1) 138 | 139 | fw_idx, fw_weights = get_interpolation(self._event_loc, self.res, round_idx=round_idx) 140 | events_pos = interpolate(fw_idx.long(), fw_weights, self.res, polarity_mask=pol_mask_list[:, :, 0:1]) 141 | events_neg = interpolate(fw_idx.long(), fw_weights, self.res, polarity_mask=pol_mask_list[:, :, 1:2]) 142 | 143 | return torch.cat([events_pos, events_neg], dim=1) 144 | 145 | def window_flow_base(self, flow_maps_x, flow_maps_y, mask=False): 146 | """ 147 | :param flow_maps_x: [batch_size x num_passes x H x W] horizontal flow maps to be averaged 148 | :param flow_maps_y: [batch_size x num_passes x H x W] vertical flow maps to be averaged 149 | :return avg_flow: image-like representation of the per-pixel average flow in the validation time/event window. 150 | """ 151 | 152 | flow_x = flow_maps_x[:, 0:1, :, :] 153 | flow_y = flow_maps_y[:, 0:1, :, :] 154 | avg_flow = torch.cat([flow_x, flow_y], dim=1) 155 | flow_mask = (flow_x != 0.0) + (flow_y != 0.0) 156 | cnt = flow_mask.float() 157 | 158 | for i in range(1, flow_maps_x.shape[1]): 159 | flow_x = flow_maps_x[:, i : i + 1, :, :] 160 | flow_y = flow_maps_y[:, i : i + 1, :, :] 161 | avg_flow += torch.cat([flow_x, flow_y], dim=1) 162 | flow_mask = (flow_x != 0.0) + (flow_y != 0.0) 163 | cnt += flow_mask.float() 164 | 165 | if mask: 166 | mask = torch.sum(self._event_mask, dim=1, keepdim=True) > 0.0 167 | avg_flow *= mask.float() 168 | 169 | return avg_flow / (cnt + 1e-9) 170 | 171 | def window_iwe_base(self, round_idx=False): 172 | """ 173 | Assumption: events have NOT been previously warped in a forward fashion in the update() method. 174 | :param round_idx: if True, round the event coordinates to the nearest integer. 
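A minimal sketch of the zero-excluding temporal averaging performed by window_flow_base() above. Shapes and values are assumptions for illustration, and the x/y handling is simplified to the horizontal component only (the method ORs the nonzero masks of both components).

```python
import torch

# Two passes of a 1x2 horizontal flow map; the second pixel has flow only in pass 2.
flow_maps_x = torch.tensor([[[[2.0, 0.0]], [[4.0, 6.0]]]])   # [batch x passes x H x W] = [1 x 2 x 1 x 2]
cnt = (flow_maps_x != 0.0).float().sum(dim=1, keepdim=True)  # nonzero contributions per pixel
avg = flow_maps_x.sum(dim=1, keepdim=True) / (cnt + 1e-9)
print(avg)  # [[[[3., 6.]]]]: passes with zero flow are excluded from the average
```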
175 | :return: image-like representation of the IWE of all the events in the validation time/event window. 176 | """ 177 | 178 | pol_mask_list = self._event_pol_mask 179 | if not round_idx: 180 | pol_mask_list = torch.cat([pol_mask_list for _ in range(4)], dim=1) 181 | 182 | fw_events = event_propagation(self._event_ts, self._event_loc, self._event_flow, self._passes) 183 | fw_idx, fw_weights = get_interpolation(fw_events, self.res, round_idx=round_idx) 184 | fw_iwe_pos = interpolate(fw_idx.long(), fw_weights, self.res, polarity_mask=pol_mask_list[:, :, 0:1]) 185 | fw_iwe_neg = interpolate(fw_idx.long(), fw_weights, self.res, polarity_mask=pol_mask_list[:, :, 1:2]) 186 | 187 | return torch.cat([fw_iwe_pos, fw_iwe_neg], dim=1) 188 | 189 | def compute_fwl(self, fw_events, zero_events, fw_pol_mask, zero_pol_mask): 190 | """ 191 | The Flow Warp Loss (FWL) metric is the ratio of the variance of the image of warped events 192 | and that of the image of (non-warped) events; hence, the higher the value of this metric, 193 | the better the optical flow estimate. 194 | See 'Reducing the Sim-to-Real Gap for Event Cameras', 195 | Stoffregen et al., ECCV 2020. 196 | """ 197 | 198 | # interpolate forward 199 | fw_idx, fw_weights = get_interpolation(fw_events, self.res, round_idx=True) 200 | 201 | # image of (forward) warped events 202 | fw_iwe_pos = interpolate(fw_idx, fw_weights, self.res, polarity_mask=fw_pol_mask[:, :, 0:1]) 203 | fw_iwe_neg = interpolate(fw_idx, fw_weights, self.res, polarity_mask=fw_pol_mask[:, :, 1:2]) 204 | fw_iwe = fw_iwe_pos + fw_iwe_neg 205 | 206 | # image of non-warped events 207 | zero_idx, zero_weights = get_interpolation(zero_events, self.res, round_idx=True) 208 | zero_iwe_pos = interpolate(zero_idx, zero_weights, self.res, polarity_mask=zero_pol_mask[:, :, 0:1]) 209 | zero_iwe_neg = interpolate(zero_idx, zero_weights, self.res, polarity_mask=zero_pol_mask[:, :, 1:2]) 210 | zero_iwe = zero_iwe_pos + zero_iwe_neg 211 | 212 | return fw_iwe.var() / zero_iwe.var() 213 | 214 | def compute_rsat(self, fw_events, zero_events, fw_pol_mask, zero_pol_mask, ts_list): 215 | """ 216 | The Ratio of the Squared Averaged Timestamps (RSAT) metric is the ratio of the squared sum of the per-pixel and 217 | per-polarity average timestamp of the image of warped events and that of the image of (non-warped) events; hence, 218 | the lower the value of this metric, the better the optical flow estimate. 219 | See 'Self-Supervised Learning of Event-Based Optical Flow with Spiking Neural Networks', 220 | Hagenaars and Paredes-Valles et al., NeurIPS 2021.
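To build intuition for the RSAT description above, here is a toy calculation of the idea (assumed values, not using the class methods, and omitting the per-pass timestamp rescaling done in the implementation):

```python
import torch

# Two events with timestamps 0.0 and 1.0. With good flow they land on the same pixel,
# so the per-pixel average timestamp is 0.5; without warping they stay on two separate
# pixels with averages 0.0 and 1.0, yielding a larger per-pixel sum of squared averages.
warped_avg_ts = torch.tensor([0.5])     # one occupied pixel after warping
zero_avg_ts = torch.tensor([0.0, 1.0])  # two occupied pixels without warping
rsat = ((warped_avg_ts**2).sum() / 1) / ((zero_avg_ts**2).sum() / 2)
print(rsat.item())  # 0.5 < 1.0: correct flow lowers the metric
```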
221 | """ 222 | 223 | # interpolate forward 224 | fw_idx, fw_weights = get_interpolation(fw_events, self.res, round_idx=True) 225 | 226 | # image of (forward) warped averaged timestamps 227 | fw_iwe_pos = interpolate(fw_idx, fw_weights, self.res, polarity_mask=fw_pol_mask[:, :, 0:1]) 228 | fw_iwe_neg = interpolate(fw_idx, fw_weights, self.res, polarity_mask=fw_pol_mask[:, :, 1:2]) 229 | fw_iwe_pos_ts = interpolate(fw_idx, fw_weights * ts_list, self.res, polarity_mask=fw_pol_mask[:, :, 0:1]) 230 | fw_iwe_neg_ts = interpolate(fw_idx, fw_weights * ts_list, self.res, polarity_mask=fw_pol_mask[:, :, 1:2]) 231 | fw_iwe_pos_ts /= fw_iwe_pos + 1e-9 232 | fw_iwe_neg_ts /= fw_iwe_neg + 1e-9 233 | fw_iwe_pos_ts = fw_iwe_pos_ts / self._passes 234 | fw_iwe_neg_ts = fw_iwe_neg_ts / self._passes 235 | 236 | # image of non-warped averaged timestamps 237 | zero_idx, zero_weights = get_interpolation(zero_events, self.res, round_idx=True) 238 | zero_iwe_pos = interpolate(zero_idx, zero_weights, self.res, polarity_mask=zero_pol_mask[:, :, 0:1]) 239 | zero_iwe_neg = interpolate(zero_idx, zero_weights, self.res, polarity_mask=zero_pol_mask[:, :, 1:2]) 240 | zero_iwe_pos_ts = interpolate( 241 | zero_idx, zero_weights * ts_list, self.res, polarity_mask=zero_pol_mask[:, :, 0:1] 242 | ) 243 | zero_iwe_neg_ts = interpolate( 244 | zero_idx, zero_weights * ts_list, self.res, polarity_mask=zero_pol_mask[:, :, 1:2] 245 | ) 246 | zero_iwe_pos_ts /= zero_iwe_pos + 1e-9 247 | zero_iwe_neg_ts /= zero_iwe_neg + 1e-9 248 | zero_iwe_pos_ts = zero_iwe_pos_ts / self._passes 249 | zero_iwe_neg_ts = zero_iwe_neg_ts / self._passes 250 | 251 | # (scaled) sum of the squares of the per-pixel and per-polarity average timestamps 252 | fw_iwe_pos_ts = fw_iwe_pos_ts.view(fw_iwe_pos_ts.shape[0], -1) 253 | fw_iwe_neg_ts = fw_iwe_neg_ts.view(fw_iwe_neg_ts.shape[0], -1) 254 | fw_iwe_pos_ts = torch.sum(fw_iwe_pos_ts**2, dim=1) 255 | fw_iwe_neg_ts = torch.sum(fw_iwe_neg_ts**2, dim=1) 256 | fw_ts_sum = fw_iwe_pos_ts + fw_iwe_neg_ts 257 | 258 | fw_nonzero_px = fw_iwe_pos + fw_iwe_neg 259 | fw_nonzero_px[fw_nonzero_px > 0] = 1 260 | fw_nonzero_px = fw_nonzero_px.view(fw_nonzero_px.shape[0], -1) 261 | fw_ts_sum /= torch.sum(fw_nonzero_px, dim=1) 262 | 263 | zero_iwe_pos_ts = zero_iwe_pos_ts.view(zero_iwe_pos_ts.shape[0], -1) 264 | zero_iwe_neg_ts = zero_iwe_neg_ts.view(zero_iwe_neg_ts.shape[0], -1) 265 | zero_iwe_pos_ts = torch.sum(zero_iwe_pos_ts**2, dim=1) 266 | zero_iwe_neg_ts = torch.sum(zero_iwe_neg_ts**2, dim=1) 267 | zero_ts_sum = zero_iwe_pos_ts + zero_iwe_neg_ts 268 | 269 | zero_nonzero_px = zero_iwe_pos + zero_iwe_neg 270 | zero_nonzero_px[zero_nonzero_px > 0] = 1 271 | zero_nonzero_px = zero_nonzero_px.view(zero_nonzero_px.shape[0], -1) 272 | zero_ts_sum /= torch.sum(zero_nonzero_px, dim=1) 273 | 274 | return fw_ts_sum / zero_ts_sum 275 | 276 | def compute_aee(self, pred, gt, mask=None): 277 | """ 278 | Average endpoint error (i.e., Euclidean distance). 
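The masked AEE computed below can be summarized with this minimal sketch (shapes and values are assumptions):

```python
import torch

# Mean Euclidean distance between predicted and ground-truth flow over valid pixels.
pred = torch.tensor([[[[1.0]], [[2.0]]]])  # [batch x 2 x H x W] = [1 x 2 x 1 x 1]
gt = torch.tensor([[[[4.0]], [[6.0]]]])
error = (pred - gt).pow(2).sum(1).sqrt()   # per-pixel endpoint error
valid = (gt[:, 0] != 0.0) | (gt[:, 1] != 0.0)  # pixels with valid ground truth
aee = error[valid].mean()
print(aee.item())  # 5.0 for the (3, 4) error vector
```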
279 | """ 280 | 281 | # compute AEE 282 | batch_size = pred.shape[0] 283 | error = (pred - gt).pow(2).sum(1).sqrt() 284 | 285 | # AEE not computed in pixels without valid ground truth 286 | gtflow_mask = (gt[:, 0, :, :] == 0.0) * (gt[:, 1, :, :] == 0.0) 287 | gtflow_mask = ~gtflow_mask 288 | 289 | # AEE not computed in pixels without input events (MVSEC) 290 | if mask is not None: 291 | mask = torch.sum(mask, axis=1) 292 | mask = mask > 0 293 | 294 | if "res_aee" in self.config["metrics"].keys(): 295 | yoff = (self.res[0] - self.config["metrics"]["res_aee"][0]) // 2 296 | xoff = (self.res[1] - self.config["metrics"]["res_aee"][1]) // 2 297 | mask = mask[:, yoff:-yoff, xoff:-xoff].contiguous() 298 | error = error[:, yoff:-yoff, xoff:-xoff].contiguous() 299 | gtflow_mask = gtflow_mask[:, yoff:-yoff, xoff:-xoff].contiguous() 300 | 301 | if "vertical_crop_aee" in self.config["metrics"].keys(): 302 | mask = mask[:, : self.config["metrics"]["vertical_crop_aee"], :] 303 | error = error[:, : self.config["metrics"]["vertical_crop_aee"], :] 304 | gtflow_mask = gtflow_mask[:, : self.config["metrics"]["vertical_crop_aee"], :] 305 | 306 | gtflow_mask = gtflow_mask * mask 307 | 308 | # compute AEE 309 | error = error.view(batch_size, -1) 310 | gtflow_mask = gtflow_mask.view(batch_size, -1) 311 | error = error[gtflow_mask] 312 | aee = torch.mean(error, dim=0) 313 | 314 | return aee 315 | 316 | 317 | class Linear(BaseValidation): 318 | """ 319 | Linear event warping validation class. 320 | """ 321 | 322 | def __init__(self, config, device): 323 | super().__init__(config, device) 324 | self._event_flow = None 325 | 326 | def update(self, flow_list, event_list, pol_mask, event_mask): 327 | """ 328 | Initialize/Update container lists of events and flow maps for forward warping. 329 | :param flow_list: [[batch_size x 2 x H x W]] list of optical flow (x, y) maps 330 | :param event_list: [batch_size x N x 4] input events (ts, y, x, p) 331 | :param pol_mask: [batch_size x N x 2] polarity mask (pos, neg) 332 | :param event_mask: [batch_size x 1 x H x W] event mask 333 | """ 334 | 335 | # update base lists (event data, flow maps, event masks) 336 | self.update_base(flow_list, event_list, pol_mask, event_mask) 337 | 338 | # get flow for every event in the list 339 | event_flow = get_event_flow( 340 | self._flow_maps_x[:, -1, ...], 341 | self._flow_maps_y[:, -1, ...], 342 | event_list[:, :, 1:3], 343 | ) 344 | 345 | if self._event_flow is None: 346 | self._event_flow = event_flow 347 | else: 348 | self._event_flow = torch.cat([self._event_flow, event_flow], dim=1) 349 | 350 | # update timestamp index 351 | self._passes += 1 352 | 353 | def reset(self): 354 | """ 355 | Reset lists. 356 | """ 357 | 358 | self.reset_base() 359 | self._event_flow = None 360 | 361 | def window_events(self, round_idx=False): 362 | """ 363 | :param round_idx: if True, round the event coordinates to the nearest integer. 364 | :return: image-like representation of all the events in the validation time/event window. 365 | """ 366 | 367 | return self.window_events_base(round_idx) 368 | 369 | def window_flow(self, mode=None, mask=None): 370 | """ 371 | :return avg_flow: image-like representation of the per-pixel average flow in the validation time/event window. 
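A hedged end-to-end sketch of driving the Linear validation class above: the config keys mirror those referenced in this file, but the values, shapes, random events, and zero flow are illustrative assumptions rather than a verified evaluation setup.

```python
import torch

config = {
    "loader": {"resolution": (8, 8)},
    "loss": {"round_ts": False},
    "vis": {"mask_output": False},
    "metrics": {},
}
val = Linear(config, torch.device("cpu"))
for _ in range(2):  # two forward passes of a (hypothetical) model
    flow_list = [torch.zeros(1, 2, 8, 8)]      # highest-resolution flow only
    event_list = torch.rand(1, 16, 4) * 7      # (ts, y, x, p) with y, x inside an 8x8 grid
    event_list[:, :, 0] = torch.rand(1, 16)    # timestamps normalized to [0, 1]
    pol_mask = torch.ones(1, 16, 2)
    event_mask = torch.ones(1, 1, 8, 8)
    val.update(flow_list, event_list, pol_mask, event_mask)
print(val.fwl().item(), val.num_passes)  # 1.0 with zero flow: warped and unwarped images coincide
val.reset()
```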
372 | """ 373 | 374 | if mask is None: 375 | mask = self.config["vis"]["mask_output"] 376 | 377 | # copy flow tensors to prevent overwriting 378 | flow_maps_x = self._flow_maps_x.clone() 379 | flow_maps_y = self._flow_maps_y.clone() 380 | 381 | # forward propagation of the estimated optical flow 382 | for i in range(self._passes - 1): 383 | warped_flow_x, warped_flow_y = self.forward_prop_flow( 384 | i, self._passes - 1, self._flow_maps_x, self._flow_maps_y 385 | ) 386 | 387 | # update lists 388 | flow_maps_x[:, i : i + 1, ...] = warped_flow_x 389 | flow_maps_y[:, i : i + 1, ...] = warped_flow_y 390 | 391 | return self.window_flow_base(flow_maps_x, flow_maps_y, mask=mask) 392 | 393 | def window_iwe(self, mode=None, round_idx=False): 394 | """ 395 | Assumption: events have NOT been previously warped in a forward fashion in the update() method. 396 | :param round_idx: if True, round the event coordinates to the nearest integer. 397 | :return: image-like representation of the IWE of all the events in the validation time/event window. 398 | """ 399 | 400 | return self.window_iwe_base(round_idx) 401 | 402 | def rsat(self): 403 | """ 404 | :return rsat: deblur metric for validation of the estimated optical flow. 405 | """ 406 | 407 | fw_events = event_propagation(self._event_ts, self._event_loc, self._event_flow, self._passes) 408 | return self.compute_rsat(fw_events, self._event_loc, self._event_pol_mask, self._event_pol_mask, self._event_ts) 409 | 410 | def fwl(self): 411 | """ 412 | :return fwl: deblur metric for validation of the estimated optical flow (Stoffregen et al., ECCV 2020). 413 | """ 414 | 415 | fw_events = event_propagation(self._event_ts, self._event_loc, self._event_flow, self._passes) 416 | return self.compute_fwl(fw_events, self._event_loc, self._event_pol_mask, self._event_pol_mask) 417 | 418 | 419 | class Iterative(BaseValidation): 420 | """ 421 | Iterative event warping validation class. 422 | """ 423 | 424 | def __init__(self, config, device): 425 | super().__init__(config, device) 426 | self._fw_event_loc = None 427 | self._fw_event_warp_ts = None 428 | self._fw_event_pol_mask = None 429 | 430 | self._bw_event_loc = None 431 | self._bw_event_pol_mask = None 432 | 433 | self._fw_prop_flow_maps_x = None 434 | self._fw_prop_flow_maps_y = None 435 | 436 | self._accum_flow_map_x = None 437 | self._accum_flow_map_y = None 438 | self._flow_warping_indices = None 439 | self._flow_out_mask = torch.zeros(1, 1, self.res[0], self.res[1]).to(device) 440 | 441 | def update_fw_event_lists(self, event_list, event_pol_mask): 442 | """ 443 | Initialize/Update container lists of events to be updated during forward warping. 444 | :param event_list: [batch_size x N x 4] input events (ts, y, x, p) 445 | :param event_pol_mask: [batch_size x N x 2] event polarity mask 446 | """ 447 | 448 | event_ts = event_list[:, :, 0:1].clone() 449 | if self.config["loss"]["round_ts"]: 450 | event_ts[...]
= event_ts.min() + 0.5 451 | 452 | if self._fw_event_loc is None: 453 | self._fw_event_warp_ts = event_ts 454 | self._fw_event_loc = event_list[:, :, 1:3].clone() 455 | self._fw_event_pol_mask = event_pol_mask.clone() 456 | 457 | else: 458 | self._fw_event_warp_ts = torch.cat([self._fw_event_warp_ts, event_ts], dim=1) 459 | self._fw_event_loc = torch.cat([self._fw_event_loc, event_list[:, :, 1:3].clone()], dim=1) 460 | self._fw_event_pol_mask = torch.cat([self._fw_event_pol_mask, event_pol_mask.clone()], dim=1) 461 | 462 | def update_bw_event_lists(self, event_loc, event_pol_mask): 463 | """ 464 | Initialize/Update container lists of events to be updated during backward warping. 465 | :param event_loc: [batch_size x N x 2] warped event locations (y, x) 466 | :param event_pol_mask: [batch_size x N x 2] event polarity mask 467 | """ 468 | 469 | if self._bw_event_loc is None: 470 | self._bw_event_loc = event_loc.clone() 471 | self._bw_event_pol_mask = event_pol_mask.clone() 472 | 473 | else: 474 | self._bw_event_loc = torch.cat([self._bw_event_loc, event_loc.clone()], dim=1) 475 | self._bw_event_pol_mask = torch.cat([self._bw_event_pol_mask, event_pol_mask.clone()], dim=1) 476 | 477 | def update(self, flow_list, event_list, pol_mask, event_mask): 478 | """ 479 | Initialize/Update container lists of events and flow maps for forward warping. 480 | :param flow_list: [[batch_size x 2 x H x W]] list of optical flow (x, y) maps 481 | :param event_list: [batch_size x N x 4] input events (ts, y, x, p) 482 | :param pol_mask: [batch_size x N x 2] polarity mask (pos, neg) 483 | :param event_mask: [batch_size x 1 x H x W] event mask 484 | """ 485 | 486 | # update base lists (event data, flow maps, event masks) 487 | self.update_base(flow_list, event_list, pol_mask, event_mask) 488 | 489 | ############ 490 | # FORWARD WARPING 491 | ############ 492 | 493 | # initialize and update event lists for fw warping 494 | self.update_fw_event_lists(event_list, pol_mask) 495 | 496 | # sample optical flow 497 | fw_event_flow = get_event_flow( 498 | self._flow_maps_x[:, -1, ...], 499 | self._flow_maps_y[:, -1, ...], 500 | self._fw_event_loc, 501 | ) 502 | 503 | # event warping process 504 | self._fw_event_loc = event_propagation( 505 | self._fw_event_warp_ts, 506 | self._fw_event_loc, 507 | fw_event_flow, 508 | self._passes + 1, 509 | ) 510 | self._fw_event_loc, self._fw_event_pol_mask = purge_unfeasible( 511 | self._fw_event_loc, 512 | self._fw_event_pol_mask, 513 | self.res, 514 | ) 515 | 516 | # update warping times 517 | self._fw_event_warp_ts[...] = self._passes + 1 518 | 519 | ############ 520 | # BACKWARD WARPING 521 | ############ 522 | 523 | bw_event_loc = event_list[:, :, 1:3].clone() 524 | bw_event_pol_mask = pol_mask.clone() 525 | bw_event_warp_ts = event_list[:, :, 0:1].clone() 526 | if self.config["loss"]["round_ts"]: 527 | bw_event_warp_ts[...] = bw_event_warp_ts.min() + 0.5 528 | 529 | cnt = 0 530 | while self._passes + cnt >= 0: 531 | 532 | # sample optical flow 533 | bw_event_flow = get_event_flow( 534 | self._flow_maps_x[:, self._passes + cnt, ...], 535 | self._flow_maps_y[:, self._passes + cnt, ...], 536 | bw_event_loc, 537 | ) 538 | 539 | # event warping process 540 | bw_event_loc = event_propagation( 541 | bw_event_warp_ts, 542 | bw_event_loc, 543 | bw_event_flow, 544 | self._passes + cnt, 545 | ) 546 | bw_event_loc, bw_event_pol_mask = purge_unfeasible( 547 | bw_event_loc, 548 | bw_event_pol_mask, 549 | self.res, 550 | ) 551 | 552 | # update warping times 553 | bw_event_warp_ts[...]
= self._passes + cnt 554 | cnt -= 1 555 | 556 | self.update_bw_event_lists(bw_event_loc, bw_event_pol_mask) 557 | 558 | ######################## 559 | # FORWARD-PROPAGATED FLOW 560 | ######################## 561 | 562 | # forward propagation of the estimated optical flow 563 | flow = flow_list[-1] # only highest resolution flow 564 | if self._fw_prop_flow_maps_x is None: 565 | self._fw_prop_flow_maps_x = flow[:, 0:1, :, :] 566 | self._fw_prop_flow_maps_y = flow[:, 1:2, :, :] 567 | else: 568 | self._fw_prop_flow_maps_x = torch.cat([self._fw_prop_flow_maps_x, flow[:, 0:1, :, :]], dim=1) 569 | self._fw_prop_flow_maps_y = torch.cat([self._fw_prop_flow_maps_y, flow[:, 1:2, :, :]], dim=1) 570 | 571 | for i in range(self._passes): 572 | warped_flow_x, warped_flow_y = self.forward_prop_flow( 573 | i, i + 1, self._fw_prop_flow_maps_x, self._fw_prop_flow_maps_y 574 | ) 575 | self._fw_prop_flow_maps_x[:, i : i + 1, ...] = warped_flow_x 576 | self._fw_prop_flow_maps_y[:, i : i + 1, ...] = warped_flow_y 577 | 578 | ######################## 579 | # ACCUMULATED FLOW (BACKWARD WARPING) 580 | ######################## 581 | 582 | indices = self.indices_map.clone() 583 | if self._flow_warping_indices is not None: 584 | indices = self._flow_warping_indices.clone() 585 | 586 | mask_valid = ( 587 | (indices[:, 0:1, ...] >= 0) 588 | * (indices[:, 0:1, ...] <= self.res[0] - 1.0) 589 | * (indices[:, 1:2, ...] >= 0) 590 | * (indices[:, 1:2, ...] <= self.res[1] - 1.0) 591 | ) 592 | self._flow_out_mask += mask_valid.float() 593 | 594 | curr_flow = get_event_flow( 595 | self._flow_maps_x[:, -1, ...], 596 | self._flow_maps_y[:, -1, ...], 597 | indices.view(1, 2, -1).permute(0, 2, 1), 598 | ) 599 | curr_flow = curr_flow.permute(0, 2, 1).view(1, 2, self.res[0], self.res[1]) 600 | 601 | warped_indices = indices + curr_flow * mask_valid.float() 602 | self._accum_flow_map_x = warped_indices[:, 1:2, :, :] - self.indices_map[:, 1:2, :, :] 603 | self._accum_flow_map_y = warped_indices[:, 0:1, :, :] - self.indices_map[:, 0:1, :, :] 604 | self._flow_warping_indices = warped_indices 605 | 606 | # update timestamp index 607 | self._passes += 1 608 | 609 | def reset(self): 610 | """ 611 | Reset lists. 612 | """ 613 | 614 | self.reset_base() 615 | self._fw_event_loc = None 616 | self._fw_event_warp_ts = None 617 | self._fw_event_pol_mask = None 618 | 619 | self._bw_event_loc = None 620 | self._bw_event_pol_mask = None 621 | 622 | self._fw_prop_flow_maps_x = None 623 | self._fw_prop_flow_maps_y = None 624 | 625 | self._accum_flow_map_x = None 626 | self._accum_flow_map_y = None 627 | self._flow_warping_indices = None 628 | self._flow_out_mask = torch.zeros(1, 1, self.res[0], self.res[1]).to(self.device) 629 | 630 | def window_events(self, round_idx=False): 631 | """ 632 | :param round_idx: if True, round the event coordinates to the nearest integer. 633 | :return: image-like representation of all the events in the validation time/event window. 634 | """ 635 | 636 | return self.window_events_base(round_idx) 637 | 638 | def window_flow(self, mode=None, mask=None): 639 | """ 640 | :return avg_flow: image-like representation of the per-pixel average flow in the validation time/event window. 
641 | """ 642 | 643 | if mask is None: 644 | mask = self.config["vis"]["mask_output"] 645 | 646 | if mode == "forward": 647 | return self.window_flow_base(self._fw_prop_flow_maps_x, self._fw_prop_flow_maps_y, mask=mask) 648 | elif mode == "backward": 649 | return self.window_flow_base( 650 | self._accum_flow_map_x / self._flow_out_mask, self._accum_flow_map_y / self._flow_out_mask, mask=mask 651 | ) 652 | else: 653 | return self.window_flow_base(self._flow_maps_x, self._flow_maps_y, mask=mask) 654 | 655 | def window_iwe(self, mode="forward", round_idx=False): 656 | """ 657 | Assumption: events have been warped in a forward fashion in the update() method. 658 | :param round_idx: if True, round the event coordinates to the nearest integer. 659 | :return: image-like representation of the IWE of all the events in the validation time/event window. 660 | """ 661 | 662 | if mode == "forward": 663 | event_loc = self._fw_event_loc 664 | pol_mask_list = self._fw_event_pol_mask 665 | elif mode == "backward": 666 | event_loc = self._bw_event_loc 667 | pol_mask_list = self._bw_event_pol_mask 668 | else: 669 | raise ValueError("Invalid IWE mode: {}".format(mode)) 670 | 671 | if not round_idx: 672 | pol_mask_list = torch.cat([pol_mask_list for _ in range(4)], dim=1) 673 | 674 | idx, weights = get_interpolation(event_loc, self.res, round_idx=round_idx) 675 | iwe_pos = interpolate(idx.long(), weights, self.res, polarity_mask=pol_mask_list[:, :, 0:1]) 676 | iwe_neg = interpolate(idx.long(), weights, self.res, polarity_mask=pol_mask_list[:, :, 1:2]) 677 | 678 | return torch.cat([iwe_pos, iwe_neg], dim=1) 679 | 680 | def rsat(self): 681 | """ 682 | :return: deblur metric for validation of the estimated optical flow. 683 | """ 684 | 685 | return self.compute_rsat( 686 | self._fw_event_loc, self._event_loc, self._fw_event_pol_mask, self._event_pol_mask, self._event_ts 687 | ) 688 | 689 | def fwl(self): 690 | """ 691 | :return fwl: deblur metric for validation of the estimated optical flow (Stoffregen et al, ECCV 2020). 692 | """ 693 | 694 | return self.compute_fwl(self._fw_event_loc, self._event_loc, self._fw_event_pol_mask, self._event_pol_mask) 695 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tudelft/taming_event_flow/57cbc7e38d6d9f5bd489139ba465609ff83c5a89/models/__init__.py -------------------------------------------------------------------------------- /models/arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .submodules import * 4 | 5 | 6 | class BaseUNet(nn.Module): 7 | """ 8 | Base class for conventional UNet architecture. 9 | Symmetric, skip connections on every encoding layer. 
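For orientation, the encoder channel bookkeeping defined in __init__ below works out as follows. This is a sketch with the default flow-network values, which are assumptions here:

```python
# Encoder channel sizes as computed in BaseUNet.__init__ (illustrative defaults).
base_channels, channel_multiplier, num_encoders = 64, 2, 4
encoder_input_sizes = [int(base_channels * channel_multiplier ** (i - 1)) for i in range(num_encoders)]
encoder_output_sizes = [int(base_channels * channel_multiplier ** i) for i in range(num_encoders)]
print(encoder_input_sizes)   # [32, 64, 128, 256] (the first entry is replaced by num_bins)
print(encoder_output_sizes)  # [64, 128, 256, 512]; max_num_channels = 512
```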
10 | """ 11 | 12 | ff_type = ConvLayer 13 | res_type = ResidualBlock 14 | upsample_type = UpsampleConvLayer 15 | transpose_type = TransposedConvLayer 16 | 17 | def __init__( 18 | self, 19 | num_bins, 20 | base_channels, 21 | num_encoders, 22 | num_residual_blocks, 23 | num_output_channels, 24 | skip_type, 25 | norm, 26 | use_upsample_conv=True, 27 | kernel_size=3, 28 | encoder_stride=2, 29 | channel_multiplier=2, 30 | activations=["relu", None], 31 | final_activation=None, 32 | final_bias=True, 33 | final_w_scale=None, 34 | recurrent_block_type=None, 35 | ): 36 | super(BaseUNet, self).__init__() 37 | self.base_channels = base_channels 38 | self.num_encoders = num_encoders 39 | self.num_residual_blocks = num_residual_blocks 40 | self.num_output_channels = num_output_channels 41 | self.norm = norm 42 | self.num_bins = num_bins 43 | self.recurrent_block_type = recurrent_block_type 44 | self.kernel_size = kernel_size 45 | self.encoder_stride = encoder_stride 46 | self.channel_multiplier = channel_multiplier 47 | self.ff_act, self.rec_act = activations 48 | self.final_activation = final_activation 49 | self.final_bias = final_bias 50 | self.final_w_scale = final_w_scale 51 | 52 | self.skip_type = skip_type 53 | assert self.skip_type is None or self.skip_type in ["sum", "concat"] 54 | 55 | if use_upsample_conv: 56 | self.up_type = self.upsample_type 57 | else: 58 | self.up_type = self.transpose_type 59 | 60 | self.encoder_input_sizes = [ 61 | int(self.base_channels * pow(self.channel_multiplier, i - 1)) for i in range(self.num_encoders) 62 | ] 63 | self.encoder_output_sizes = [ 64 | int(self.base_channels * pow(self.channel_multiplier, i)) for i in range(self.num_encoders) 65 | ] 66 | 67 | self.max_num_channels = self.encoder_output_sizes[-1] 68 | 69 | def skip_fn(self, x, y, mode="sum"): 70 | assert y.shape[2:] <= x.shape[2:] 71 | if x.shape[2:] > y.shape[2:]: 72 | print("Warning: skipping row/col in skip_fn() due to odd dimensions throughout the architecture.") 73 | x = x[:, :, : y.shape[2], : y.shape[3]] # skip last row/col if necessary 74 | 75 | if mode == "sum": 76 | assert x.shape[1] == y.shape[1] 77 | x = x + y 78 | elif mode == "concat": 79 | x = torch.cat([x, y], dim=1) 80 | return x 81 | 82 | def get_axonal_delays(self): 83 | self.delays = 0 84 | 85 | def build_encoders(self): 86 | encoders = nn.ModuleList() 87 | for i, (input_size, output_size) in enumerate(zip(self.encoder_input_sizes, self.encoder_output_sizes)): 88 | if i == 0: 89 | input_size = self.num_bins 90 | encoders.append( 91 | self.ff_type( 92 | input_size, 93 | output_size, 94 | kernel_size=self.kernel_size, 95 | stride=self.encoder_stride, 96 | activation=self.ff_act, 97 | norm=self.norm, 98 | ) 99 | ) 100 | return encoders 101 | 102 | def build_recurrent_encoders(self): 103 | encoders = nn.ModuleList() 104 | for i, (input_size, output_size) in enumerate(zip(self.encoder_input_sizes, self.encoder_output_sizes)): 105 | if i == 0: 106 | input_size = self.num_bins 107 | encoders.append( 108 | self.rec_type( 109 | input_size, 110 | output_size, 111 | kernel_size=self.kernel_size, 112 | stride=self.encoder_stride, 113 | recurrent_block_type=self.recurrent_block_type, 114 | activation_ff=self.ff_act, 115 | activation_rec=self.rec_act, 116 | norm=self.norm, 117 | ) 118 | ) 119 | return encoders 120 | 121 | def build_resblocks(self): 122 | resblocks = nn.ModuleList() 123 | for i in range(self.num_residual_blocks): 124 | resblocks.append( 125 | self.res_type( 126 | self.max_num_channels, 127 | self.max_num_channels, 128 | 
activation=self.ff_act, 129 | norm=self.norm, 130 | ) 131 | ) 132 | return resblocks 133 | 134 | def build_decoders(self): 135 | decoder_input_sizes = reversed(self.encoder_output_sizes) 136 | decoder_output_sizes = reversed(self.encoder_input_sizes) 137 | decoders = nn.ModuleList() 138 | for input_size, output_size in zip(decoder_input_sizes, decoder_output_sizes): 139 | decoders.append( 140 | self.up_type( 141 | input_size if self.skip_type == "sum" else 2 * input_size, 142 | output_size, 143 | kernel_size=self.kernel_size, 144 | activation=self.ff_act, 145 | norm=self.norm, 146 | ) 147 | ) 148 | return decoders 149 | 150 | def build_multires_prediction_decoders(self): 151 | decoder_input_sizes = reversed(self.encoder_output_sizes) 152 | decoder_output_sizes = reversed(self.encoder_input_sizes) 153 | decoders = nn.ModuleList() 154 | for i, (input_size, output_size) in enumerate(zip(decoder_input_sizes, decoder_output_sizes)): 155 | input_size = 2 * input_size if self.skip_type == "concat" else input_size 156 | prediction_channels = 0 if i == 0 else self.num_output_channels 157 | decoders.append( 158 | self.up_type( 159 | input_size + prediction_channels, 160 | output_size, 161 | kernel_size=self.kernel_size, 162 | activation=self.ff_act, 163 | norm=self.norm, 164 | ) 165 | ) 166 | return decoders 167 | 168 | def build_prediction_layer(self): 169 | return self.ff_type( 170 | 2 * self.base_channels if self.skip_type == "concat" else self.base_channels, 171 | self.num_output_channels, 172 | kernel_size=1, 173 | activation=self.final_activation, 174 | norm=self.norm, 175 | w_scale=self.final_w_scale, 176 | bias=self.final_bias, 177 | ) 178 | 179 | def build_multires_prediction_layer(self): 180 | preds = nn.ModuleList() 181 | decoder_output_sizes = reversed(self.encoder_input_sizes) 182 | for output_size in decoder_output_sizes: 183 | preds.append( 184 | self.ff_type( 185 | output_size, 186 | self.num_output_channels, 187 | 1, 188 | activation=self.final_activation, 189 | norm=self.norm, 190 | w_scale=self.final_w_scale, 191 | bias=self.final_bias, 192 | ) 193 | ) 194 | return preds 195 | 196 | 197 | class MultiResUNetRecurrent(BaseUNet): 198 | """ 199 | Recurrent UNet architecture where every encoder is followed by a recurrent convolutional block. 200 | Symmetric, skip connections on every encoding layer (concat/sum). 201 | Predictions at each decoding layer. 202 | Predictions are added as skip connection (concat) to the input of the subsequent layer. 
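The prediction chaining described above widens the decoder inputs; here is a sketch of the arithmetic in build_multires_prediction_decoders() under assumed default values (sum skips, 2 output channels):

```python
# Decoder input widths when each prediction is concatenated to the next decoder's input.
encoder_output_sizes = [64, 128, 256, 512]  # assumed encoder outputs
num_output_channels = 2
decoder_inputs = []
for i, in_size in enumerate(reversed(encoder_output_sizes)):
    pred_ch = 0 if i == 0 else num_output_channels  # the previous prediction is concatenated
    decoder_inputs.append(in_size + pred_ch)
print(decoder_inputs)  # [512, 258, 130, 66]
```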
203 | """ 204 | 205 | rec_type = RecurrentConvLayer 206 | 207 | def __init__(self, kwargs): 208 | super().__init__(**kwargs) 209 | 210 | self.encoders = self.build_recurrent_encoders() 211 | self.resblocks = self.build_resblocks() 212 | self.decoders = self.build_multires_prediction_decoders() 213 | self.preds = self.build_multires_prediction_layer() 214 | self.num_states = self.num_encoders 215 | self.states = [None] * self.num_states 216 | 217 | def forward(self, x): 218 | """ 219 | :param x: N x num_input_channels x H x W 220 | :return: [N x num_output_channels x H x W for i in range(self.num_encoders)] 221 | """ 222 | 223 | # encoder 224 | blocks = [] 225 | for i, encoder in enumerate(self.encoders): 226 | x, self.states[i] = encoder(x, self.states[i]) 227 | blocks.append(x) 228 | 229 | # residual blocks 230 | for resblock in self.resblocks: 231 | x, _ = resblock(x) 232 | 233 | # decoder and multires predictions 234 | predictions = [] 235 | for i, (decoder, pred) in enumerate(zip(self.decoders, self.preds)): 236 | x = self.skip_fn(x, blocks[self.num_encoders - i - 1], mode=self.skip_type) 237 | if i > 0: 238 | x = self.skip_fn(predictions[-1], x, mode="concat") 239 | x = decoder(x) 240 | predictions.append(pred(x)) 241 | 242 | return predictions 243 | -------------------------------------------------------------------------------- /models/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from UZH-RPG https://github.com/uzh-rpg/rpg_e2vid 3 | """ 4 | 5 | from abc import abstractmethod 6 | 7 | import numpy as np 8 | import torch 9 | 10 | 11 | class BaseModel(torch.nn.Module): 12 | """ 13 | Base class for all models 14 | """ 15 | 16 | @abstractmethod 17 | def forward(self, *inputs): 18 | """ 19 | Forward pass logic 20 | 21 | :return: Model output 22 | """ 23 | raise NotImplementedError 24 | 25 | def __str__(self): 26 | """ 27 | Model prints with number of trainable parameters 28 | """ 29 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 30 | params = sum([np.prod(p.size()) for p in model_parameters]) 31 | return super().__str__() + "\nTrainable parameters: {}".format(params) 32 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | from .arch import * 2 | from .base import BaseModel 3 | from .model_util import copy_states, ImagePadder 4 | 5 | 6 | class RecEVFlowNet(BaseModel): 7 | """ 8 | Recurrent version of the EV-FlowNet model, as described in the paper "Self-Supervised Learning of 9 | Event-based Optical Flow with Spiking Neural Networks", Hagenaars and Paredes-Vallés et al., NeurIPS 2021. 
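A hedged usage sketch for the model defined below: the empty kwargs dict, zero input, and shapes are assumptions, and the snippet presumes the classes in this module are importable.

```python
import torch

# Instantiate with the default architecture (kwargs dict left empty on purpose).
model = RecEVFlowNet({}, num_bins=2, key="flow")
x = torch.zeros(1, 2, 128, 128)  # [batch x num_bins x H x W] event representation
out = model(x)
print([f.shape for f in out["flow"]])  # four [1, 2, 128, 128] maps (multi-scale, upsampled to input resolution)
model.reset_states()  # clear the ConvGRU memories between sequences
```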
10 | """ 11 | 12 | net_type = MultiResUNetRecurrent 13 | recurrent_block_type = "convgru" 14 | activations = ["relu", None] 15 | 16 | def __init__(self, kwargs, num_bins=2, key="flow", min_size=16): 17 | super().__init__() 18 | self.image_padder = ImagePadder(min_size=min_size) 19 | 20 | self.key = key 21 | arch_kwargs = { 22 | "num_bins": num_bins, 23 | "base_channels": 64, 24 | "num_encoders": 4, 25 | "num_residual_blocks": 2, 26 | "num_output_channels": 2, 27 | "skip_type": "sum", 28 | "norm": None, 29 | "use_upsample_conv": True, 30 | "kernel_size": 3, 31 | "encoder_stride": 2, 32 | "channel_multiplier": 2, 33 | "final_activation": "tanh", 34 | "activations": self.activations, 35 | "recurrent_block_type": self.recurrent_block_type, 36 | } 37 | arch_kwargs.update(kwargs) # update params with config 38 | arch_kwargs.pop("name", None) 39 | self.arch = self.net_type(arch_kwargs) 40 | self.num_encoders = arch_kwargs["num_encoders"] 41 | 42 | @property 43 | def states(self): 44 | return copy_states(self.arch.states) 45 | 46 | @states.setter 47 | def states(self, states): 48 | self.arch.states = states 49 | 50 | def detach_states(self): 51 | detached_states = [] 52 | for state in self.arch.states: 53 | if type(state) is tuple: 54 | tmp = [] 55 | for hidden in state: 56 | tmp.append(hidden.detach()) 57 | detached_states.append(tuple(tmp)) 58 | else: 59 | detached_states.append(state.detach()) 60 | self.arch.states = detached_states 61 | 62 | def reset_states(self): 63 | self.arch.states = [None] * self.arch.num_states 64 | 65 | def forward(self, x): 66 | 67 | # image padding 68 | x = self.image_padder.pad(x).contiguous() 69 | 70 | # forward pass 71 | multires_flow = self.arch.forward(x) 72 | 73 | # upsample flow estimates to the original input resolution 74 | flow_list = [] 75 | for i, flow in enumerate(multires_flow): 76 | scaling_h = x.shape[2] / flow.shape[2] 77 | scaling_w = x.shape[3] / flow.shape[3] 78 | scaling_flow = 2 ** (self.num_encoders - i - 1) 79 | upflow = scaling_flow * torch.nn.functional.interpolate( 80 | flow, scale_factor=(scaling_h, scaling_w), mode="bilinear", align_corners=False 81 | ) 82 | upflow = self.image_padder.unpad(upflow) 83 | flow_list.append(upflow) 84 | 85 | return {self.key: flow_list} 86 | -------------------------------------------------------------------------------- /models/model_util.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | 5 | 6 | def recursive_clone(tensor): 7 | """ 8 | Assumes tensor is a torch.tensor with 'clone()' method, possibly 9 | inside nested iterable. 10 | E.g., tensor = [(pytorch_tensor, pytorch_tensor), ...] 11 | """ 12 | if hasattr(tensor, "clone"): 13 | return tensor.clone() 14 | try: 15 | return type(tensor)(recursive_clone(t) for t in tensor) 16 | except TypeError: 17 | print("{} is not iterable and has no clone() method.".format(tensor)) 18 | 19 | 20 | def copy_states(states): 21 | """ 22 | Simple deepcopy if list of Nones, else clone. 23 | """ 24 | if states[0] is None: 25 | return copy.deepcopy(states) 26 | return recursive_clone(states) 27 | 28 | 29 | class ImagePadder(object): 30 | """ 31 | From E-RAFT: https://github.com/uzh-rpg/E-RAFT 32 | """ 33 | 34 | # =================================================================== # 35 | # In some networks, the image gets downsized. This is a problem, if # 36 | # the to-be-downsized image has odd dimensions ([15x20]->[7.5x10]). # 37 | # To prevent this, the input image of the network needs to be a # 38 | # multiple of a minimum size (min_size) # 39 | # The ImagePadder makes sure that the input image is of such a size, # 40 | # and if not, it pads the image accordingly. # 41 | # =================================================================== # 42 | 43 | def __init__(self, min_size=64): 44 | # --------------------------------------------------------------- # 45 | # The min_size additionally ensures that the smallest image # 46 | # does not get too small # 47 | # --------------------------------------------------------------- # 48 | self.min_size = min_size 49 | self.pad_height = None 50 | self.pad_width = None 51 | 52 | def pad(self, image): 53 | # --------------------------------------------------------------- # 54 | # If necessary, this function pads the image on the left & top # 55 | # --------------------------------------------------------------- # 56 | height, width = image.shape[-2:] 57 | if self.pad_width is None: 58 | self.pad_height = (self.min_size - height % self.min_size) % self.min_size 59 | self.pad_width = (self.min_size - width % self.min_size) % self.min_size 60 | else: 61 | pad_height = (self.min_size - height % self.min_size) % self.min_size 62 | pad_width = (self.min_size - width % self.min_size) % self.min_size 63 | if pad_height != self.pad_height or pad_width != self.pad_width: 64 | raise ValueError("ImagePadder: input resolution changed between calls to pad()") 65 | return torch.nn.ZeroPad2d((self.pad_width, 0, self.pad_height, 0))(image) 66 | 67 | def unpad(self, image): 68 | # --------------------------------------------------------------- # 69 | # Removes the padded rows & columns # 70 | # --------------------------------------------------------------- # 71 | return image[..., self.pad_height :, self.pad_width :] 72 | -------------------------------------------------------------------------------- /models/submodules.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as f 6 | 7 | 8 | class ConvLayer(nn.Module): 9 | """ 10 | Convolutional layer. 11 | Default: bias, ReLU, no downsampling, no batch norm.
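A worked example of the padding arithmetic in ImagePadder above, using the 260x346 MVSEC resolution as an assumed input:

```python
import torch

# Pad a 260x346 frame up to multiples of min_size=16 on the left and top.
min_size, height, width = 16, 260, 346
pad_height = (min_size - height % min_size) % min_size  # (16 - 4) % 16 = 12
pad_width = (min_size - width % min_size) % min_size    # (16 - 10) % 16 = 6
image = torch.zeros(1, 2, height, width)
padded = torch.nn.ZeroPad2d((pad_width, 0, pad_height, 0))(image)  # (left, right, top, bottom)
print(padded.shape)  # torch.Size([1, 2, 272, 352])
```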
12 | """ 13 | 14 | def __init__( 15 | self, 16 | in_channels, 17 | out_channels, 18 | kernel_size, 19 | stride=1, 20 | activation="relu", 21 | norm=None, 22 | BN_momentum=0.1, 23 | w_scale=None, 24 | padding=None, 25 | bias=None, 26 | ): 27 | super(ConvLayer, self).__init__() 28 | 29 | if padding is None: 30 | padding = kernel_size // 2 31 | if bias is None: 32 | bias = False if norm == "BN" else True 33 | if w_scale is None: 34 | w_scale = math.sqrt(1 / in_channels) 35 | 36 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) 37 | nn.init.uniform_(self.conv2d.weight, -w_scale, w_scale) 38 | if bias: 39 | nn.init.zeros_(self.conv2d.bias) 40 | 41 | if activation is not None: 42 | if hasattr(torch, activation): 43 | self.activation = getattr(torch, activation) 44 | else: 45 | self.activation = None 46 | 47 | self.norm = norm 48 | if norm == "BN": 49 | self.norm_layer = nn.BatchNorm2d(out_channels, momentum=BN_momentum) 50 | elif norm == "IN": 51 | self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True) 52 | 53 | def forward(self, x): 54 | out = self.conv2d(x) 55 | 56 | if self.norm in ["BN", "IN"]: 57 | out = self.norm_layer(out) 58 | 59 | if self.activation is not None: 60 | out = self.activation(out) 61 | 62 | return out 63 | 64 | 65 | class RecurrentConvLayer(nn.Module): 66 | """ 67 | Layer comprised of a convolution followed by a recurrent convolutional block. 68 | Default: bias, ReLU, no downsampling, no batch norm, ConvGRU. 69 | """ 70 | 71 | def __init__( 72 | self, 73 | in_channels, 74 | out_channels, 75 | kernel_size=3, 76 | stride=1, 77 | recurrent_block_type="convgru", 78 | activation_ff="relu", 79 | activation_rec=None, 80 | norm=None, 81 | BN_momentum=0.1, 82 | ): 83 | super(RecurrentConvLayer, self).__init__() 84 | 85 | assert recurrent_block_type in ["convgru"] 86 | self.recurrent_block_type = recurrent_block_type 87 | if recurrent_block_type == "convgru": 88 | RecurrentBlock = ConvGRU 89 | else: 90 | raise NotImplementedError 91 | 92 | self.conv = ConvLayer( 93 | in_channels, 94 | out_channels, 95 | kernel_size, 96 | stride, 97 | activation_ff, 98 | norm, 99 | BN_momentum=BN_momentum, 100 | ) 101 | self.recurrent_block = RecurrentBlock( 102 | input_size=out_channels, hidden_size=out_channels, kernel_size=3, activation=activation_rec 103 | ) 104 | 105 | def forward(self, x, prev_state): 106 | x = self.conv(x) 107 | x, state = self.recurrent_block(x, prev_state) 108 | return x, state 109 | 110 | 111 | class ConvGRU(nn.Module): 112 | """ 113 | Convolutional GRU cell. 
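For reference, the gating implemented by this cell's forward() below corresponds to the standard GRU equations with convolutions (notation: \(\ast\) 2D convolution, \(\odot\) elementwise product, \([\cdot,\cdot]\) channel concatenation):

```latex
\begin{aligned}
z_t &= \sigma\!\left(W_z \ast [x_t,\, h_{t-1}]\right)            && \text{update gate} \\
r_t &= \sigma\!\left(W_r \ast [x_t,\, h_{t-1}]\right)            && \text{reset gate} \\
q_t &= \tanh\!\left(W_q \ast [x_t,\, r_t \odot h_{t-1}]\right)   && \text{candidate state} \\
h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot q_t                   && \text{new state}
\end{aligned}
```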
114 | Adapted from https://github.com/jacobkimmel/pytorch_convgru/blob/master/convgru.py 115 | """ 116 | 117 | def __init__(self, input_size, hidden_size, kernel_size, activation=None): 118 | super().__init__() 119 | padding = kernel_size // 2 120 | self.input_size = input_size 121 | self.hidden_size = hidden_size 122 | self.reset_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding) 123 | self.update_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding) 124 | self.out_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding) 125 | assert activation is None, "ConvGRU activation cannot be set (just for compatibility)" 126 | 127 | nn.init.orthogonal_(self.reset_gate.weight) 128 | nn.init.orthogonal_(self.update_gate.weight) 129 | nn.init.orthogonal_(self.out_gate.weight) 130 | nn.init.constant_(self.reset_gate.bias, 0.0) 131 | nn.init.constant_(self.update_gate.bias, 0.0) 132 | nn.init.constant_(self.out_gate.bias, 0.0) 133 | 134 | def forward(self, input_, prev_state): 135 | 136 | # get batch and spatial sizes 137 | batch_size = input_.data.size()[0] 138 | spatial_size = input_.data.size()[2:] 139 | 140 | # generate empty prev_state, if None is provided 141 | if prev_state is None: 142 | state_size = [batch_size, self.hidden_size] + list(spatial_size) 143 | prev_state = torch.zeros(state_size, dtype=input_.dtype, device=input_.device) 144 | 145 | # data size is [batch, channel, height, width] 146 | stacked_inputs = torch.cat([input_, prev_state], dim=1) 147 | update = torch.sigmoid(self.update_gate(stacked_inputs)) 148 | reset = torch.sigmoid(self.reset_gate(stacked_inputs)) 149 | out_inputs = torch.tanh(self.out_gate(torch.cat([input_, prev_state * reset], dim=1))) 150 | new_state = prev_state * (1 - update) + out_inputs * update 151 | 152 | return new_state, new_state 153 | 154 | 155 | class ResidualBlock(nn.Module): 156 | """ 157 | Residual block as in "Deep residual learning for image recognition", He et al. 2016. 158 | Default: bias, ReLU, no downsampling, no batch norm. 
159 | """ 160 | 161 | def __init__( 162 | self, 163 | in_channels, 164 | out_channels, 165 | kernel_size=3, 166 | stride=1, 167 | activation="relu", 168 | downsample=None, 169 | norm=None, 170 | BN_momentum=0.1, 171 | ): 172 | super(ResidualBlock, self).__init__() 173 | bias = False if norm == "BN" else True 174 | self.conv1 = nn.Conv2d( 175 | in_channels, 176 | out_channels, 177 | kernel_size=kernel_size, 178 | stride=stride, 179 | padding=kernel_size // 2, 180 | bias=bias, 181 | ) 182 | 183 | if activation is not None: 184 | if hasattr(torch, activation): 185 | self.activation = getattr(torch, activation) 186 | else: 187 | self.activation = None 188 | 189 | self.norm = norm 190 | if norm == "BN": 191 | self.bn1 = nn.BatchNorm2d(out_channels, momentum=BN_momentum) 192 | self.bn2 = nn.BatchNorm2d(out_channels, momentum=BN_momentum) 193 | elif norm == "IN": 194 | self.bn1 = nn.InstanceNorm2d(out_channels, track_running_stats=True) 195 | self.bn2 = nn.InstanceNorm2d(out_channels, track_running_stats=True) 196 | 197 | self.conv2 = nn.Conv2d( 198 | out_channels, 199 | out_channels, 200 | kernel_size=kernel_size, 201 | stride=stride, 202 | padding=kernel_size // 2, 203 | bias=bias, 204 | ) 205 | self.downsample = downsample 206 | 207 | def forward(self, x): 208 | residual = x 209 | out1 = self.conv1(x) 210 | if self.norm in ["BN", "IN"]: 211 | out1 = self.bn1(out1) 212 | 213 | if self.activation is not None: 214 | out1 = self.activation(out1) 215 | 216 | out2 = self.conv2(out1) 217 | if self.norm in ["BN", "IN"]: 218 | out2 = self.bn2(out2) 219 | 220 | if self.downsample: 221 | residual = self.downsample(x) 222 | 223 | out2 += residual 224 | if self.activation is not None: 225 | out2 = self.activation(out2) 226 | 227 | return out2, out1 228 | 229 | 230 | class UpsampleConvLayer(nn.Module): 231 | """ 232 | Upsampling layer (bilinear interpolation + Conv2d) to increase spatial resolution (x2) in a decoder. 233 | Default: bias, ReLU, no downsampling, no batch norm. 234 | """ 235 | 236 | def __init__( 237 | self, 238 | in_channels, 239 | out_channels, 240 | kernel_size, 241 | stride=1, 242 | activation="relu", 243 | norm=None, 244 | ): 245 | super(UpsampleConvLayer, self).__init__() 246 | 247 | bias = False if norm == "BN" else True 248 | padding = kernel_size // 2 249 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) 250 | 251 | if activation is not None: 252 | if hasattr(torch, activation): 253 | self.activation = getattr(torch, activation) 254 | else: 255 | self.activation = None 256 | 257 | self.norm = norm 258 | if norm == "BN": 259 | self.norm_layer = nn.BatchNorm2d(out_channels) 260 | elif norm == "IN": 261 | self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True) 262 | 263 | def forward(self, x): 264 | x_upsampled = f.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False) 265 | out = self.conv2d(x_upsampled) 266 | 267 | if self.norm in ["BN", "IN"]: 268 | out = self.norm_layer(out) 269 | 270 | if self.activation is not None: 271 | out = self.activation(out) 272 | 273 | return out 274 | 275 | 276 | class TransposedConvLayer(nn.Module): 277 | """ 278 | Transposed convolutional layer to increase spatial resolution (x2) in a decoder. 279 | Default: bias, ReLU, no downsampling, no batch norm. 
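A quick sanity check of the x2 upsampling produced by the layer below, using the ConvTranspose2d output-size formula (a sketch; the example sizes are assumptions):

```python
# H_out = (H_in - 1) * stride - 2 * padding + kernel_size + output_padding (dilation = 1)
kernel_size, stride, padding, output_padding = 3, 2, 1, 1
for h_in in (15, 64):
    h_out = (h_in - 1) * stride - 2 * padding + kernel_size + output_padding
    print(h_in, "->", h_out)  # 15 -> 30, 64 -> 128: exactly doubled for any input size
```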
280 | """ 281 | 282 | def __init__( 283 | self, 284 | in_channels, 285 | out_channels, 286 | kernel_size, 287 | activation="relu", 288 | norm=None, 289 | ): 290 | super(TransposedConvLayer, self).__init__() 291 | 292 | bias = False if norm == "BN" else True 293 | padding = kernel_size // 2 294 | self.transposed_conv2d = nn.ConvTranspose2d( 295 | in_channels, 296 | out_channels, 297 | kernel_size, 298 | stride=2, 299 | padding=padding, 300 | output_padding=1, 301 | bias=bias, 302 | ) 303 | 304 | if activation is not None: 305 | if hasattr(torch, activation): 306 | self.activation = getattr(torch, activation) 307 | else: 308 | self.activation = None 309 | 310 | self.norm = norm 311 | if norm == "BN": 312 | self.norm_layer = nn.BatchNorm2d(out_channels) 313 | elif norm == "IN": 314 | self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True) 315 | 316 | def forward(self, x): 317 | out = self.transposed_conv2d(x) 318 | 319 | if self.norm in ["BN", "IN"]: 320 | out = self.norm_layer(out) 321 | 322 | if self.activation is not None: 323 | out = self.activation(out) 324 | 325 | return out 326 | -------------------------------------------------------------------------------- /prepare_dsec_submission.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | 5 | 6 | def retrieve_eval(args): 7 | eval_id = args.eval_id 8 | if args.eval_id < 0: 9 | eval_id = 0 10 | for file in os.listdir(args.path + args.runid + "/"): 11 | if file == ".DS_Store": 12 | continue 13 | tmp = int(file.split(".")[0].split("_")[-1]) 14 | eval_id = tmp + 1 if tmp + 1 > eval_id else eval_id 15 | eval_id -= 1 16 | path_from = args.path + args.runid + "/" + "eval_" + str(eval_id) + "/" 17 | print("Preparing submission for eval_{0}".format(eval_id)) 18 | 19 | return path_from 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("runid") 25 | parser.add_argument("--path", default="dsec_submissions/") 26 | parser.add_argument("--eval_id", default=-1, type=int) 27 | args = parser.parse_args() 28 | 29 | # retrieve last eval run unless specified 30 | path_from = retrieve_eval(args) 31 | 32 | # retrieve folders in directory 33 | entry = "/flow_bw/" 34 | folders = os.listdir(path_from) 35 | for folder in folders: 36 | if folder in [".DS_Store", "submission"]: 37 | continue 38 | 39 | # retrieve files in folder with png extension 40 | files = os.listdir(path_from + folder + entry) 41 | indices = [] 42 | for file in files: 43 | indices.append(int(file.split(".")[0])) 44 | indices.sort() 45 | 46 | # fixing pred-gt alignment 47 | flags = np.load(args.path + folder + "_flag.npy") 48 | flags = np.roll(flags, -1) 49 | 50 | # select gt maps to be submitted 51 | flow_timestamp = np.genfromtxt(args.path + folder + ".txt", skip_header=1, delimiter=",") 52 | flow_filenames = flow_timestamp[:, -1] 53 | 54 | selected_indices = [] 55 | for i in range(len(indices)): 56 | if flags[i] == 1: 57 | selected_indices.append(indices[i]) 58 | 59 | # create new folder 60 | if not os.path.exists(path_from + "submission/"): 61 | os.makedirs(path_from + "submission/") 62 | if not os.path.exists(path_from + "submission/" + folder + "/"): 63 | os.makedirs(path_from + "submission/" + folder + "/") 64 | 65 | # copy files to new folder with the right name 66 | for i in range(len(selected_indices)): 67 | filename = path_from + "submission/" + folder + "/" + str(int(flow_filenames[i])).zfill(6) + ".png" 68 | 
os.system("cp " + path_from + folder + entry + str(selected_indices[i]).zfill(9) + ".png " + filename) 69 | 70 | print(folder) 71 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -f https://download.pytorch.org/whl/torch_stable.html 2 | torch>=1.12.0 3 | torchvision>=0.13.0 4 | tensorboard==2.9.1 5 | torch-tb-profiler==0.4.0 6 | PyYAML==6.0 7 | numpy==1.21.6 8 | pandas 9 | hdf5plugin==3.2.0 10 | h5py==3.6.0 11 | opencv-python==4.5.5.64 12 | matplotlib==3.5.1 13 | progress==1.6 14 | mlflow==1.24.0 15 | -------------------------------------------------------------------------------- /train_flow.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import mlflow 4 | import torch 5 | from torch.optim import * 6 | from torch.utils.tensorboard import SummaryWriter 7 | 8 | from configs.parser import YAMLParser 9 | from dataloader.h5 import H5Loader 10 | from loss.flow import * 11 | from models.model import * 12 | from utils.utils import load_model, save_diff, save_model 13 | from utils.visualization import Visualization 14 | 15 | 16 | def train(args, config_parser): 17 | """ 18 | Main function of the training pipeline for event-based optical flow estimation. 19 | :param args: arguments of the script 20 | :param config_parser: YAMLParser object with config data 21 | """ 22 | 23 | mlflow.set_tracking_uri(args.path_mlflow) 24 | 25 | # configs 26 | config = config_parser.config 27 | mlflow.set_experiment(config["experiment"]) 28 | run = mlflow.start_run() 29 | runid = run.to_dictionary()["info"]["run_id"] 30 | mlflow.log_params(config) 31 | mlflow.log_param("prev_runid", args.prev_runid) 32 | config = config_parser.combine_entries(config) 33 | print("MLflow dir:", mlflow.active_run().info.artifact_uri[:-9]) 34 | 35 | # log git diff 36 | save_diff("train_diff.txt") 37 | tb_writer = SummaryWriter(log_dir=args.path_mlflow + "mlruns/0/" + runid + "/") 38 | 39 | # initialize settings 40 | device = config_parser.device 41 | kwargs = config_parser.loader_kwargs 42 | config["loader"]["device"] = device 43 | 44 | # visualization tool 45 | if config["vis"]["enabled"]: 46 | vis = Visualization(config) 47 | 48 | # data loader 49 | data = H5Loader(config, shuffle=True, path_cache=args.path_cache) 50 | dataloader = torch.utils.data.DataLoader( 51 | data, 52 | drop_last=True, 53 | batch_size=config["loader"]["batch_size"], 54 | collate_fn=data.custom_collate, 55 | worker_init_fn=config_parser.worker_init_fn, 56 | **kwargs, 57 | ) 58 | 59 | # model initialization and settings 60 | num_bins = 2 if config["data"]["voxel"] is None else config["data"]["voxel"] 61 | model = eval(config["model"]["name"])(config["model"].copy(), num_bins, key="flow") 62 | model = model.to(device) 63 | model, epoch = load_model(args.prev_runid, model, device, curr_run=run, tb_writer=tb_writer) 64 | model.train() 65 | 66 | # loss functions 67 | loss_function = eval(config["loss"]["warping"])(config, device) 68 | 69 | # optimizers 70 | optimizer = eval(config["optimizer"]["name"])(model.parameters(), lr=config["optimizer"]["lr"]) 71 | optimizer.zero_grad() 72 | 73 | # simulation variables 74 | train_loss = 0 75 | best_loss = 1.0e6 76 | end_train = False 77 | data.epoch = epoch 78 | 79 | # dataloader loop 80 | while True: 81 | for inputs in dataloader: 82 | 83 | if data.new_seq: 84 | data.new_seq = False 85 | loss_function.reset() 86 | 
model.reset_states() 87 | optimizer.zero_grad() 88 | 89 | if data.seq_num >= len(data.files): 90 | tb_writer.add_scalar("loss", train_loss / data.samples, data.epoch) 91 | mlflow.log_metric("loss", train_loss / data.samples, step=data.epoch) 92 | with torch.no_grad(): 93 | if train_loss / data.samples < best_loss: 94 | save_model(model) 95 | best_loss = train_loss / data.samples 96 | 97 | data.epoch += 1 98 | data.samples = 0 99 | train_loss = 0 100 | data.seq_num = data.seq_num % len(data.files) 101 | if data.epoch == config["loader"]["n_epochs"]: 102 | end_train = True 103 | break 104 | 105 | # forward pass (flow in px/input_time) 106 | x = model(inputs["net_input"].to(device)) 107 | for i in range(len(x["flow"])): 108 | x["flow"][i] = x["flow"][i] * config["loss"]["flow_scaling"] 109 | 110 | # event-flow association 111 | loss_function.update( 112 | x["flow"], 113 | inputs["event_list"].to(device), 114 | inputs["event_list_pol_mask"].to(device), 115 | inputs["d_event_list"].to(device), 116 | inputs["d_event_list_pol_mask"].to(device), 117 | ) 118 | 119 | # loss computation 120 | if loss_function.num_passes >= config["data"]["passes_loss"]: 121 | data.samples += config["loader"]["batch_size"] 122 | 123 | loss = loss_function() 124 | train_loss += loss.item() 125 | loss.backward() 126 | 127 | if config["loss"]["clip_grad"] is not None: 128 | torch.nn.utils.clip_grad.clip_grad_norm_(model.parameters(), config["loss"]["clip_grad"]) 129 | 130 | optimizer.step() 131 | optimizer.zero_grad() 132 | 133 | if config["vis"]["enabled"] and config["loader"]["batch_size"] == 1: 134 | vis.data["flow"] = x["flow"][-1].clone() 135 | 136 | model.detach_states() 137 | loss_function.reset() 138 | 139 | with torch.no_grad(): 140 | if config["vis"]["enabled"] and config["loader"]["batch_size"] == 1: 141 | vis.step(inputs) 142 | 143 | if config["vis"]["verbose"]: 144 | print( 145 | "Train Epoch: {:04d} [{:03d}/{:03d} ({:03d}%)] Loss: {:.6f}".format( 146 | data.epoch, 147 | data.seq_num, 148 | len(data.files), 149 | int(100 * data.seq_num / len(data.files)), 150 | train_loss / data.samples, 151 | ), 152 | end="\r", 153 | ) 154 | 155 | if end_train: 156 | break 157 | 158 | mlflow.end_run() 159 | 160 | 161 | if __name__ == "__main__": 162 | parser = argparse.ArgumentParser() 163 | parser.add_argument( 164 | "--config", 165 | default="configs/train_flow.yml", 166 | help="training configuration", 167 | ) 168 | parser.add_argument( 169 | "--path_mlflow", 170 | default="", 171 | help="location of the mlflow ui", 172 | ) 173 | parser.add_argument( 174 | "--path_cache", 175 | default="", 176 | help="location of the cache version of the formatted dataset", 177 | ) 178 | parser.add_argument( 179 | "--prev_runid", 180 | default="", 181 | help="pre-trained model to use as starting point", 182 | ) 183 | args = parser.parse_args() 184 | 185 | # launch training 186 | train(args, YAMLParser(args.config)) 187 | -------------------------------------------------------------------------------- /utils/iwe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as f 3 | 4 | 5 | def event_propagation(events_ts, events_idx, flow, tref): 6 | """ 7 | Warp the input events according to the provided optical flow map. 
:param events_ts: [batch_size x N x 1] event timestamps 9 | :param events_idx: [batch_size x N x 2] event locations (y, x) 10 | :param flow: [batch_size x N x 2] per-event optical flow (y, x) 11 | :return warped event locations at the reference time tref 12 | """ 13 | 14 | return events_idx + (tref - events_ts) * flow 15 | 16 | 17 | def get_event_flow(flow_map_x, flow_map_y, event_loc): 18 | """ 19 | Sample the optical flow maps at the event locations. 20 | :param flow_map_x: [batch_size x H x W] horizontal optical flow map 21 | :param flow_map_y: [batch_size x H x W] vertical optical flow map 22 | :param event_loc: [batch_size x N x 2] event locations 23 | :return event_flow: [batch_size x N x 2] per-event optical flow (y, x) 24 | """ 25 | 26 | _, h, w = flow_map_x.shape 27 | 28 | # flow vector per input event 29 | event_idx = event_loc.clone() 30 | event_idx[..., 0] = 2 * event_idx[..., 0] / (h - 1) - 1 31 | event_idx[..., 1] = 2 * event_idx[..., 1] / (w - 1) - 1 32 | event_idx = torch.roll(event_idx, 1, dims=-1).unsqueeze(2) # needs to be (x, y) and not (y, x) 33 | 34 | event_flow_x = f.grid_sample(flow_map_x.unsqueeze(1), event_idx, mode="bilinear", align_corners=True) 35 | event_flow_y = f.grid_sample(flow_map_y.unsqueeze(1), event_idx, mode="bilinear", align_corners=True) 36 | event_flow_x = event_flow_x.squeeze(1) 37 | event_flow_y = event_flow_y.squeeze(1) 38 | event_flow = torch.cat([event_flow_y, event_flow_x], dim=2) 39 | 40 | return event_flow 41 | 42 | 43 | def purge_unfeasible(event_loc, event_pol_mask, res): 44 | """ 45 | Purge events that are warped outside the image space. 46 | :param event_loc: [batch_size x N x 2] warped event locations 47 | :param event_pol_mask: [batch_size x N x 2] polarity mask of warped events 48 | :param res: resolution of the image space 49 | :return masked warped event locations and polarity mask 50 | """ 51 | 52 | mask = ( 53 | (event_loc[:, :, 0:1] >= 0) 54 | * (event_loc[:, :, 0:1] <= res[0] - 1.0) 55 | * (event_loc[:, :, 1:2] >= 0) 56 | * (event_loc[:, :, 1:2] <= res[1] - 1.0) 57 | ) 58 | event_loc = event_loc * mask 59 | event_pol_mask = event_pol_mask * mask 60 | return event_loc, event_pol_mask 61 | 62 | 63 | def get_interpolation(warped_events, res, round_idx=False, zeros=None): 64 | """ 65 | Compute the bilinear interpolation (or rounding) weights that distribute the warped events to the 66 | closest (integer) pixel locations in the image space. 67 | :param warped_events: [batch_size x N x 2] warped event locations (y, x) 68 | :param res: resolution of the image space 69 | :param round_idx: whether or not to round the event locations instead of doing bilinear interp. 70 | (default = False) 71 | :param zeros: [batch_size x 4N x 2] optional preallocated tensor of zeros for the weights (default = None)
72 | :return interpolated event indices 73 | :return interpolation weights 74 | """ 75 | 76 | if round_idx: 77 | 78 | # no bilinear interpolation 79 | idx = torch.round(warped_events) 80 | weights = torch.ones(idx.shape, device=warped_events.device) 81 | 82 | else: 83 | 84 | # get scattering indices 85 | top_y = torch.floor(warped_events[:, :, 0:1]) 86 | bot_y = torch.floor(warped_events[:, :, 0:1] + 1) 87 | left_x = torch.floor(warped_events[:, :, 1:2]) 88 | right_x = torch.floor(warped_events[:, :, 1:2] + 1) 89 | 90 | top_left = torch.cat([top_y, left_x], dim=2) 91 | top_right = torch.cat([top_y, right_x], dim=2) 92 | bottom_left = torch.cat([bot_y, left_x], dim=2) 93 | bottom_right = torch.cat([bot_y, right_x], dim=2) 94 | idx = torch.cat([top_left, top_right, bottom_left, bottom_right], dim=1) 95 | 96 | # get scattering interpolation weights 97 | warped_events = torch.cat([warped_events for i in range(4)], dim=1) 98 | if zeros is None: 99 | zeros = torch.zeros(warped_events.shape, device=warped_events.device) 100 | weights = torch.max(zeros, 1 - torch.abs(warped_events - idx)) 101 | 102 | # purge unfeasible indices 103 | mask = (idx[:, :, 0:1] >= 0) * (idx[:, :, 0:1] < res[0]) * (idx[:, :, 1:2] >= 0) * (idx[:, :, 1:2] < res[1]) 104 | idx *= mask 105 | 106 | # make unfeasible weights zero 107 | weights = torch.prod(weights, dim=-1, keepdim=True) * mask # bilinear interpolation 108 | 109 | # prepare indices 110 | idx[:, :, 0] *= res[1] # torch.view is row-major 111 | idx = torch.sum(idx, dim=2, keepdim=True) 112 | 113 | return idx, weights 114 | 115 | 116 | def interpolate(idx, weights, res, polarity_mask=None, zeros=None): 117 | """ 118 | Create an image-like representation of the warped events. 119 | :param idx: [batch_size x N x 1] flattened indices of the warped events 120 | :param weights: [batch_size x N x 1] interpolation weights for the warped events 121 | :param res: resolution of the image space 122 | :param polarity_mask: [batch_size x N x 1] polarity mask for the warped events (default = None) 123 | :return image of warped events 124 | """ 125 | 126 | if polarity_mask is not None: 127 | weights = weights * polarity_mask 128 | 129 | if zeros is None: 130 | iwe = torch.zeros((idx.shape[0], res[0] * res[1], 1), device=idx.device) 131 | else: 132 | iwe = zeros.clone() 133 | 134 | iwe = iwe.scatter_add_(1, idx.long(), weights) 135 | iwe = iwe.view((idx.shape[0], 1, res[0], res[1])) 136 | return iwe 137 | 138 | 139 | def deblur_events(flow, event_list, res, round_idx=True, polarity_mask=None, round_flow=True): 140 | """ 141 | Deblur the input events given an optical flow map. 142 | Event timestamps need to be normalized between 0 and 1. 143 | :param flow: [batch_size x 2 x H x W] optical flow map 144 | :param event_list: [batch_size x N x 4] input events (ts, y, x, p) 145 | :param res: resolution of the image space 146 | :param round_idx: whether or not to round the event locations instead of doing bilinear interp.
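Before the definition of `deblur_events` continues below, a worked example of the two primitives just defined: a single event warped to (y, x) = (1.3, 2.6) on a 4x4 grid is splatted onto its four integer neighbours with bilinear weights that sum to one. The import assumes the repository root is on `PYTHONPATH`:

```
# One warped event, splatted with get_interpolation() + interpolate().
import torch
from utils.iwe import get_interpolation, interpolate  # assumes repo root on PYTHONPATH

res = (4, 4)
warped = torch.tensor([[[1.3, 2.6]]])          # [batch=1 x N=1 x 2], (y, x)

idx, weights = get_interpolation(warped, res)  # four corners per event (round_idx=False)
iwe = interpolate(idx, weights, res)           # [1 x 1 x 4 x 4] image of warped events

# corner weights: (1,2)->0.28, (1,3)->0.42, (2,2)->0.12, (2,3)->0.18, summing to 1
print(iwe[0, 0])
```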
147 | :param polarity_mask: [batch_size x N x 1] polarity mask for the warped events 148 | :param round_flow: whether or not to associate events with the closest flow vector 149 | :return iwe: [batch_size x 1 x H x W] image of warped events 150 | """ 151 | 152 | # flow vector per input event 153 | flow_idx = event_list[:, :, 1:3].clone() 154 | mask_unfeasible = ( 155 | (flow_idx[:, :, 0:1] >= 0) 156 | * (flow_idx[:, :, 0:1] < res[0]) 157 | * (flow_idx[:, :, 1:2] >= 0) 158 | * (flow_idx[:, :, 1:2] < res[1]) 159 | ) 160 | flow_idx *= mask_unfeasible 161 | 162 | if not round_flow: 163 | 164 | top_y = torch.floor(flow_idx[:, :, 0:1]) 165 | bot_y = torch.floor(flow_idx[:, :, 0:1] + 1) 166 | left_x = torch.floor(flow_idx[:, :, 1:2]) 167 | right_x = torch.floor(flow_idx[:, :, 1:2] + 1) 168 | 169 | top_left = torch.cat([top_y, left_x], dim=2) 170 | top_right = torch.cat([top_y, right_x], dim=2) 171 | bottom_left = torch.cat([bot_y, left_x], dim=2) 172 | bottom_right = torch.cat([bot_y, right_x], dim=2) 173 | idx = torch.cat([top_left, top_right, bottom_left, bottom_right], dim=1) 174 | 175 | og_idx = torch.cat([flow_idx for i in range(4)], dim=1) 176 | zeros = torch.zeros(idx.shape, device=idx.device) 177 | interp_weights = torch.max(zeros, 1 - torch.abs(og_idx - idx)) 178 | 179 | mask_y = (idx[:, :, 0:1] >= 0) * (idx[:, :, 0:1] < res[0]) 180 | mask_x = (idx[:, :, 1:2] >= 0) * (idx[:, :, 1:2] < res[1]) 181 | mask = mask_y * mask_x 182 | flow_idx = idx * mask 183 | interp_weights = torch.prod(interp_weights, dim=-1, keepdim=True) * mask # bilinear interpolation 184 | 185 | flow_idx[:, :, 0] *= res[1] # torch.view is row-major 186 | flow_idx = torch.sum(flow_idx, dim=2) 187 | 188 | # get flow for every event in the list 189 | flow = flow.view(flow.shape[0], 2, -1) 190 | event_flowy = torch.gather(flow[:, 1, :], 1, flow_idx.long()) # vertical component 191 | event_flowx = torch.gather(flow[:, 0, :], 1, flow_idx.long()) # horizontal component 192 | event_flowy = event_flowy.view(event_flowy.shape[0], event_flowy.shape[1], 1) 193 | event_flowx = event_flowx.view(event_flowx.shape[0], event_flowx.shape[1], 1) 194 | 195 | # bilinear interpolation of the optical flow vectors 196 | if not round_flow: 197 | N = event_list.shape[1] # number of events 198 | event_flowy = ( 199 | interp_weights[:, 0 * N : 1 * N, :] * event_flowy[:, 0 * N : 1 * N, :] 200 | + interp_weights[:, 1 * N : 2 * N, :] * event_flowy[:, 1 * N : 2 * N, :] 201 | + interp_weights[:, 2 * N : 3 * N, :] * event_flowy[:, 2 * N : 3 * N, :] 202 | + interp_weights[:, 3 * N : 4 * N, :] * event_flowy[:, 3 * N : 4 * N, :] 203 | ) 204 | event_flowx = ( 205 | interp_weights[:, 0 * N : 1 * N, :] * event_flowx[:, 0 * N : 1 * N, :] 206 | + interp_weights[:, 1 * N : 2 * N, :] * event_flowx[:, 1 * N : 2 * N, :] 207 | + interp_weights[:, 2 * N : 3 * N, :] * event_flowx[:, 2 * N : 3 * N, :] 208 | + interp_weights[:, 3 * N : 4 * N, :] * event_flowx[:, 3 * N : 4 * N, :] 209 | ) 210 | 211 | event_flow = torch.cat([event_flowy, event_flowx], dim=2) 212 | 213 | # interpolate forward 214 | fw_events = event_propagation(event_list[:, :, 0:1], event_list[:, :, 1:3], event_flow, 1) 215 | fw_idx, fw_weights = get_interpolation(fw_events, res, round_idx=round_idx) 216 | if not round_idx and polarity_mask is not None: 217 | polarity_mask = torch.cat([polarity_mask for i in range(4)], dim=1) 218 | mask_unfeasible = torch.cat([mask_unfeasible for i in range(4)], dim=1) 219 | fw_weights *= mask_unfeasible 220 | 221 | # image of (forward) warped events 222 | iwe = 
interpolate(fw_idx, fw_weights, res, polarity_mask=polarity_mask) 223 | 224 | return iwe 225 | 226 | 227 | def compute_pol_iwe(flow, event_list, res, pol_mask, round_idx=True, round_flow=True): 228 | """ 229 | Create a per-polarity image of warped events given an optical flow map. 230 | :param flow: [batch_size x 2 x H x W] optical flow map 231 | :param event_list: [batch_size x N x 4] input events (ts, y, x, p) 232 | :param res: resolution of the image space 233 | :param pol_mask: [batch_size x N x 2] polarity mask 234 | :param round_idx: whether or not to round the event locations instead of doing bilinear interp. 235 | :param round_flow: whether or not to associate events with the closest flow vector 236 | :return iwe: [batch_size x 2 x H x W] image of warped events 237 | """ 238 | 239 | iwe_pos = deblur_events( 240 | flow, 241 | event_list, 242 | res, 243 | round_idx=round_idx, 244 | polarity_mask=pol_mask[:, :, 0:1], 245 | round_flow=round_flow, 246 | ) 247 | iwe_neg = deblur_events( 248 | flow, 249 | event_list, 250 | res, 251 | round_idx=round_idx, 252 | polarity_mask=pol_mask[:, :, 1:2], 253 | round_flow=round_flow, 254 | ) 255 | iwe = torch.cat([iwe_pos, iwe_neg], dim=1) 256 | 257 | return iwe 258 | -------------------------------------------------------------------------------- /utils/mlflow.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | 4 | import torch 5 | import mlflow 6 | 7 | 8 | def log_config(path_results, runid, config): 9 | """ 10 | Log configuration file to MlFlow run. 11 | """ 12 | 13 | eval_id = 0 14 | for file in os.listdir(path_results): 15 | if file.endswith(".yml"): 16 | tmp = int(file.split(".")[0].split("_")[-1]) 17 | eval_id = tmp + 1 if tmp + 1 > eval_id else eval_id 18 | yaml_filename = path_results + "eval_" + str(eval_id) + ".yml" 19 | with open(yaml_filename, "w") as outfile: 20 | yaml.dump(config, outfile, default_flow_style=False) 21 | 22 | mlflow.start_run(runid) 23 | mlflow.log_artifact(yaml_filename) 24 | mlflow.end_run() 25 | 26 | return eval_id 27 | 28 | 29 | def log_results(runid, results, path, eval_id): 30 | """ 31 | Log validation results as artifacts to MlFlow run. 
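The `compute_pol_iwe` entry point above can be exercised standalone. A hedged sketch with random inputs: shapes follow the docstrings above, timestamps are normalized to [0, 1] as `deblur_events` requires, and the constant flow value is an arbitrary choice:

```
# Build a per-polarity image of warped events from random inputs.
import torch
from utils.iwe import compute_pol_iwe  # assumes repo root on PYTHONPATH

B, N, H, W = 1, 512, 32, 32
ts = torch.rand(B, N, 1)                             # normalized timestamps in [0, 1]
y = torch.randint(0, H, (B, N, 1)).float()
x = torch.randint(0, W, (B, N, 1)).float()
p = torch.randint(0, 2, (B, N, 1)).float() * 2 - 1   # polarity in {-1, +1}
event_list = torch.cat([ts, y, x, p], dim=2)         # [B x N x 4], (ts, y, x, p)

pol_mask = torch.cat([p > 0, p < 0], dim=2).float()  # [B x N x 2]
flow = 0.5 * torch.ones(B, 2, H, W)                  # constant flow, 0.5 px per unit time

iwe = compute_pol_iwe(flow, event_list, (H, W), pol_mask)
print(iwe.shape)  # torch.Size([1, 2, 32, 32])
```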
32 | """ 33 | 34 | yaml_filename = path + "metrics_" + str(eval_id) + ".yml" 35 | with open(yaml_filename, "w") as outfile: 36 | yaml.dump(results, outfile, default_flow_style=False) 37 | 38 | mlflow.start_run(runid) 39 | mlflow.log_artifact(yaml_filename) 40 | mlflow.end_run() 41 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import mlflow 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | 8 | 9 | def load_model(prev_runid, model, device, curr_run=None, tb_writer=None): 10 | try: 11 | run = mlflow.get_run(prev_runid) 12 | except: 13 | return model, 0 14 | 15 | model_dir = run.info.artifact_uri + "/model/data/model.pth" 16 | if model_dir[:7] == "file://": 17 | model_dir = model_dir[7:] 18 | 19 | starting_epoch = 0 20 | if os.path.isfile(model_dir): 21 | model_loaded = torch.load(model_dir, map_location=device).state_dict() 22 | 23 | # check for input-dependent layers 24 | for key in model_loaded.keys(): 25 | if key.split(".")[1] == "pooling" and key.split(".")[-1] in ["weight", "weight_f"]: 26 | model.encoder_unet.pooling = model.encoder_unet.build_pooling(model_loaded[key].shape).to(device) 27 | model.encoder_unet.get_axonal_delays() 28 | 29 | new_params = model.state_dict() 30 | new_params.update(model_loaded) 31 | model.load_state_dict(new_params) 32 | 33 | loss_file = run.info.artifact_uri[:-9] + "metrics/loss" 34 | if os.path.isfile(run.info.artifact_uri[:-9] + "metrics/loss"): 35 | loss = np.genfromtxt(loss_file) 36 | if curr_run is not None: 37 | if not os.path.exists(curr_run.info.artifact_uri[:-9] + "metrics/"): 38 | os.makedirs(curr_run.info.artifact_uri[:-9] + "metrics/") 39 | for i in range(loss.shape[0]): 40 | mlflow.log_metric("loss", loss[i, 1], step=int(loss[i, 2])) 41 | if tb_writer is not None: 42 | tb_writer.add_scalar("loss", loss[i, 1], int(loss[i, 2])) 43 | starting_epoch = int(loss[-1][-1]) 44 | 45 | print("Model restored from " + prev_runid + "\n") 46 | else: 47 | print("No model found at " + prev_runid + "\n") 48 | 49 | return model, starting_epoch 50 | 51 | 52 | def create_model_dir(path_results, runid): 53 | path_results += runid + "/" 54 | if not os.path.exists(path_results): 55 | os.makedirs(path_results) 56 | print("Results stored at " + path_results + "\n") 57 | return path_results 58 | 59 | 60 | def save_model(model): 61 | mlflow.pytorch.log_model(model, "model", conda_env={"dependencies": []}) 62 | 63 | 64 | def save_state_dict(runid, state_dict): 65 | mlflow.start_run(runid) 66 | mlflow.pytorch.log_state_dict(state_dict, "state_dict") 67 | mlflow.end_run() 68 | 69 | 70 | def load_state_dict(runid, dir="state_dict/", filename="state_dict.pth"): 71 | run = mlflow.get_run(runid) 72 | model_dir = run.info.artifact_uri + "/" + dir 73 | if model_dir[:7] == "file://": 74 | model_dir = model_dir[7:] 75 | 76 | model_dict = None 77 | if os.path.isfile(model_dir + filename): 78 | model_dict = torch.load(model_dir + filename, map_location=torch.device("cpu")) 79 | print("Model restored from " + runid) 80 | else: 81 | print("No model found at " + runid) 82 | 83 | return model_dict 84 | 85 | 86 | def save_csv(data, fname): 87 | # create file if not there 88 | path = mlflow.get_artifact_uri(artifact_path=fname) 89 | if path[:7] == "file://": # to_csv() doesn't work with 'file://' 90 | path = path[7:] 91 | if not os.path.isfile(path): 92 | mlflow.log_text("", fname) 93 | pd.DataFrame(data).to_csv(path) 
94 | # else append 95 | else: 96 | pd.DataFrame(data).to_csv(path, mode="a", header=False) 97 | 98 | 99 | def save_diff(fname="git_diff.txt"): 100 | # .txt to allow showing in mlflow 101 | path = mlflow.get_artifact_uri(artifact_path=fname) 102 | if path[:7] == "file://": 103 | path = path[7:] 104 | mlflow.log_text("", fname) 105 | os.system(f"git diff > {path}") 106 | 107 | # recursive bisect: returns the index of x in the sorted array, or its left/right insertion boundary if absent 108 | def binary_search_array(array, x, left=None, right=None, side="left"): 109 | left = 0 if left is None else left 110 | right = len(array) - 1 if right is None else right 111 | mid = left + (right - left) // 2 112 | 113 | if left > right: 114 | return left if side == "left" else right 115 | 116 | if array[mid] == x: 117 | return mid 118 | 119 | if x < array[mid]: 120 | return binary_search_array(array, x, left=left, right=mid - 1, side=side) 121 | 122 | return binary_search_array(array, x, left=mid + 1, right=right, side=side) 123 | 124 | 125 | def initialize_quant_results(results, filename, metrics): 126 | if filename not in results.keys(): 127 | results[filename] = {} 128 | for metric in metrics: 129 | if metric not in results[filename].keys(): 130 | results[filename][metric] = {} 131 | results[filename][metric]["metric"] = 0 132 | results[filename][metric]["it"] = 0 133 | 134 | return results 135 | -------------------------------------------------------------------------------- /utils/visualization.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import matplotlib 5 | import numpy as np 6 | 7 | 8 | class Visualization: 9 | """ 10 | Utility class for the visualization and storage of rendered image-like representations of multiple elements of the pipeline. 11 | """ 12 | 13 | def __init__(self, kwargs, eval_id=-1, path_results=None): 14 | self.img_idx = 0 15 | self.px = kwargs["vis"]["px"] 16 | self.res = kwargs["loader"]["resolution"] 17 | self.color_scheme = "green_red" # gray / green_red / rpg / prophesee (see events_to_image) 18 | self.show_rendered = kwargs["vis"]["enabled"] 19 | self.store_rendered = kwargs["vis"]["store"] 20 | 21 | if eval_id >= 0 and path_results is not None: 22 | self.store_dir = path_results + "results/" 23 | self.store_dir = self.store_dir + "eval_" + str(eval_id) + "/" 24 | if not os.path.exists(self.store_dir): 25 | os.makedirs(self.store_dir) 26 | self.store_file = None 27 | 28 | self.data = {} 29 | self.keys = [ 30 | "events", 31 | "events_window", 32 | "events_dynamic", 33 | "flow", 34 | "flow_window", 35 | "flow_dynamic", 36 | "flow_bw", 37 | "iwe", 38 | "iwe_fw_window", 39 | "iwe_bw_window", 40 | "iwe_fw_dynamic", 41 | "iwe_bw_dynamic", 42 | "flow_gt", 43 | "depth_gt", 44 | ] 45 | self.title = [ 46 | "Input events", 47 | "Input events - Eval window", 48 | "Input events - Dynamic window", 49 | "Estimated flow", 50 | "Estimated flow - Eval window", 51 | "Estimated flow - Dynamic window", 52 | "Estimated flow - Backward", 53 | "IWE", 54 | "Forward IWE - Eval window", 55 | "Backward IWE - Eval window", 56 | "Forward IWE - Dynamic window", 57 | "Backward IWE - Dynamic window", 58 | "Ground truth flow", 59 | "Ground truth depth", 60 | ] 61 | 62 | self.reset_image_ph() 63 | 64 | def step(self, inputs, sequence=None, ts=None, show=None): 65 | """ 66 | Main function of the visualization workflow.
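A quick sanity check of `binary_search_array` defined above: it behaves like a recursive bisect, returning the index of `x` when present and the left/right boundary around `x` when absent (useful, e.g., for locating the events closest to a query timestamp). The array below is illustrative:

```
# binary_search_array() behaves like a recursive bisect on a sorted array.
from utils.utils import binary_search_array  # assumes repo root on PYTHONPATH

ts = [0.0, 2.0, 4.0, 6.0]
assert binary_search_array(ts, 4.0) == 2                 # exact match
assert binary_search_array(ts, 3.0) == 2                 # first index with ts[i] > 3.0
assert binary_search_array(ts, 3.0, side="right") == 1   # last index with ts[i] < 3.0
```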
67 | :param inputs: input data (output of the dataloader) 68 | """ 69 | 70 | # render images 71 | self.render(inputs, show) 72 | 73 | # live display 74 | if self.show_rendered: 75 | self.update(show) 76 | 77 | # store rendered images 78 | if self.store_rendered and sequence is not None: 79 | self.store(sequence, ts, show) 80 | 81 | # reset image placeholders 82 | self.reset_image_ph() 83 | 84 | def reset_image_ph(self): 85 | """ 86 | Initialize/Reset image placeholders. 87 | """ 88 | for key in self.keys: 89 | self.data[key] = None 90 | 91 | def render(self, inputs, show=None): 92 | """ 93 | Rendering tool. 94 | :param inputs: input data (output of the dataloader) 95 | """ 96 | 97 | self.data["events"] = inputs["event_cnt"] if "event_cnt" in inputs.keys() else None 98 | if self.data["events"] is None: 99 | self.data["events"] = inputs["net_input"] if "net_input" in inputs.keys() else None 100 | 101 | self.data["flow_gt"] = inputs["gtflow"] if "gtflow" in inputs.keys() else None 102 | self.data["depth_gt"] = inputs["gtdepth"] if "gtdepth" in inputs.keys() else None 103 | 104 | # optical flow error 105 | if self.data["flow_bw"] is not None and self.data["flow_gt"] is not None: 106 | self.data["error_flow"] = ( 107 | (self.data["flow_bw"].cpu() - self.data["flow_gt"]).pow(2).sum(1).sqrt().unsqueeze(1) 108 | ) 109 | gtflow_mask = (self.data["flow_gt"][:, 0:1, :, :] == 0.0) * (self.data["flow_gt"][:, 1:2, :, :] == 0.0) 110 | self.data["error_flow"] *= ~gtflow_mask 111 | if "error_flow" not in self.keys: 112 | self.keys.append("error_flow") 113 | self.title.append("AEE (capped at 30px)") 114 | 115 | for key in self.keys: 116 | if show is not None: 117 | if key not in show: 118 | continue 119 | 120 | if self.data[key] is not None: 121 | self.data[key] = self.data[key].detach() 122 | 123 | # input events 124 | if key.split("_")[0] == "events" and self.data[key] is not None: 125 | self.data[key] = ( 126 | self.data[key] 127 | .cpu() 128 | .numpy() 129 | .transpose(0, 2, 3, 1) 130 | .reshape((self.data[key].shape[2], self.data[key].shape[3], 2)) 131 | ) 132 | self.data[key] = self.events_to_image(self.data[key]) 133 | 134 | # optical flow 135 | elif key.split("_")[0] == "flow" and self.data[key] is not None: 136 | self.data[key] = ( 137 | self.data[key] 138 | .cpu() 139 | .numpy() 140 | .transpose(0, 2, 3, 1) 141 | .reshape((self.data[key].shape[2], self.data[key].shape[3], 2)) 142 | ) 143 | if key != "flow_bw": 144 | self.data[key] = self.flow_to_image(self.data[key]) 145 | else: 146 | self.data[key] = self.data[key] * 128 + 2**15 147 | self.data[key] = self.data[key].astype(np.uint16) 148 | self.data[key] = np.pad(self.data[key], ((0, 0), (0, 0), (0, 1)), constant_values=0) 149 | self.data[key] = np.flip(self.data[key], axis=-1) 150 | 151 | # optical flow error 152 | elif key == "error_flow" and self.data[key] is not None: 153 | self.data[key] = ( 154 | self.data[key] 155 | .cpu() 156 | .numpy() 157 | .transpose(0, 2, 3, 1) 158 | .reshape((self.data[key].shape[2], self.data[key].shape[3], 1)) 159 | ) 160 | self.data[key] = self.minmax_norm(self.data[key], max=30, min=0) 161 | self.data[key] *= 255 162 | self.data[key] = self.data[key].astype(np.uint8) 163 | self.data[key] = cv2.applyColorMap(self.data[key], cv2.COLORMAP_VIRIDIS) 164 | 165 | # image of warped events 166 | elif key.split("_")[0] == "iwe" and self.data[key] is not None: 167 | self.data[key] = ( 168 | self.data[key] 169 | .cpu() 170 | .numpy() 171 | .transpose(0, 2, 3, 1) 172 | .reshape((self.data[key].shape[2], 
self.data[key].shape[3], 2)) 173 | ) 174 | self.data[key] = self.events_to_image(self.data[key]) 175 | 176 | def update(self, show=None): 177 | """ 178 | Live visualization of the previously-rendered images. 179 | """ 180 | 181 | for i, key in enumerate(self.keys): 182 | if show is not None: 183 | if key not in show: 184 | continue 185 | 186 | if key not in ["flow_bw"] and self.data[key] is not None: 187 | cv2.namedWindow(self.title[i], cv2.WINDOW_NORMAL) 188 | cv2.resizeWindow(self.title[i], int(self.px), int(self.px)) 189 | cv2.imshow(self.title[i], self.data[key]) 190 | 191 | cv2.waitKey(1) 192 | 193 | def store(self, sequence, ts=None, show=None): 194 | """ 195 | Store previously-rendered images. 196 | :param sequence: name of the sequence 197 | :param ts: timestamp of the images to be stored 198 | """ 199 | 200 | # check if new sequence 201 | path_to = self.store_dir + sequence + "/" 202 | if not os.path.exists(path_to): 203 | os.makedirs(path_to) 204 | for key in self.keys: 205 | os.makedirs(path_to + key + "/") 206 | if self.store_file is not None: 207 | self.store_file.close() 208 | self.store_file = open(path_to + "timestamps.txt", "w") 209 | self.img_idx = 0 210 | 211 | # store images 212 | for key in self.keys: 213 | if show is not None: 214 | if key not in show: 215 | continue 216 | 217 | if not os.path.exists(path_to + key + "/"): 218 | os.makedirs(path_to + key + "/") 219 | if self.data[key] is not None: 220 | filename = path_to + key + "/%09d.png" % self.img_idx 221 | cv2.imwrite(filename, self.data[key]) 222 | 223 | # store timestamps 224 | if ts is not None: 225 | self.store_file.write(str(ts) + "\n") 226 | self.store_file.flush() 227 | 228 | self.img_idx += 1 229 | cv2.waitKey(1) 230 | 231 | @staticmethod 232 | def flow_to_image(flow): 233 | """ 234 | Use the optical flow color scheme from the supplementary materials of the paper 'Back to Event 235 | Basics: Self-Supervised Image Reconstruction for Event Cameras via Photometric Constancy', 236 | Paredes-Valles et al., CVPR'21. 237 | :param flow: [H x W x 2] optical flow map (horizontal, vertical in dim=2) 238 | :return: [H x W x 3] color-encoded optical flow in BGR format 239 | """ 240 | mag = np.linalg.norm(flow, axis=2) 241 | min_mag = np.min(mag) 242 | mag_range = np.max(mag) - min_mag 243 | 244 | ang = np.arctan2(flow[:, :, 1], flow[:, :, 0]) + np.pi 245 | ang *= 1.0 / np.pi / 2.0 246 | 247 | hsv = np.zeros([flow.shape[0], flow.shape[1], 3]) 248 | hsv[:, :, 0] = ang 249 | hsv[:, :, 1] = 1.0 250 | hsv[:, :, 2] = mag - min_mag 251 | if mag_range != 0.0: 252 | hsv[:, :, 2] /= mag_range 253 | 254 | flow_rgb = matplotlib.colors.hsv_to_rgb(hsv) 255 | flow_rgb = (255 * flow_rgb).astype(np.uint8) 256 | return cv2.cvtColor(flow_rgb, cv2.COLOR_RGB2BGR) 257 | 258 | @staticmethod 259 | def events_to_image(event_cnt, color_scheme="green_red"): 260 | """ 261 | Format events into an image. 
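The static `flow_to_image` above can also be used standalone to color-code any flow field (direction as hue, magnitude as value). A hedged sketch on a synthetic radial field; the resolution and output file name are arbitrary:

```
# Render a synthetic radial flow field with Visualization.flow_to_image().
import cv2
import numpy as np
from utils.visualization import Visualization  # assumes repo root on PYTHONPATH

H, W = 64, 64
ys, xs = np.mgrid[0:H, 0:W].astype(np.float32)
flow = np.stack([xs - W / 2, ys - H / 2], axis=2)  # [H x W x 2], (horizontal, vertical)

bgr = Visualization.flow_to_image(flow)            # [H x W x 3] uint8, BGR
cv2.imwrite("flow_colorwheel.png", bgr)
```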
262 | :param event_cnt: [H x W x 2] event count map 263 | :param color_scheme: gray / blue_red / green_red 264 | """ 265 | pos = event_cnt[:, :, 0] 266 | neg = event_cnt[:, :, 1] 267 | pos_max = np.percentile(pos, 99) 268 | pos_min = np.percentile(pos, 1) 269 | neg_max = np.percentile(neg, 99) 270 | neg_min = np.percentile(neg, 1) 271 | max = pos_max if pos_max > neg_max else neg_max 272 | 273 | if pos_min != max: 274 | pos = (pos - pos_min) / (max - pos_min) 275 | if neg_min != max: 276 | neg = (neg - neg_min) / (max - neg_min) 277 | 278 | pos = np.clip(pos, 0, 1) 279 | neg = np.clip(neg, 0, 1) 280 | 281 | event_image = np.ones((event_cnt.shape[0], event_cnt.shape[1])) 282 | if color_scheme == "gray": 283 | event_image *= 0.5 284 | pos *= 0.5 285 | neg *= -0.5 286 | event_image += pos + neg 287 | 288 | elif color_scheme == "green_red": 289 | event_image = np.repeat(event_image[:, :, np.newaxis], 3, axis=2) 290 | event_image *= 0 291 | mask_pos = pos > 0 292 | mask_neg = neg > 0 293 | mask_not_pos = pos == 0 294 | mask_not_neg = neg == 0 295 | 296 | event_image[:, :, 0][mask_pos] = 0 297 | event_image[:, :, 1][mask_pos] = pos[mask_pos] 298 | event_image[:, :, 2][mask_pos * mask_not_neg] = 0 299 | event_image[:, :, 2][mask_neg] = neg[mask_neg] 300 | event_image[:, :, 0][mask_neg] = 0 301 | event_image[:, :, 1][mask_neg * mask_not_pos] = 0 302 | 303 | elif color_scheme == "rpg": 304 | event_image = np.repeat(event_image[:, :, np.newaxis], 3, axis=2) 305 | mask_pos = pos > 0 306 | mask_neg = neg > 0 307 | 308 | event_image[:, :, 0][mask_neg] = 1 309 | event_image[:, :, 1][mask_neg] = 0 310 | event_image[:, :, 2][mask_neg] = 0 311 | event_image[:, :, 0][mask_pos] = 0 312 | event_image[:, :, 1][mask_pos] = 0 313 | event_image[:, :, 2][mask_pos] = 1 314 | 315 | elif color_scheme == "prophesee": 316 | event_image = np.repeat(event_image[:, :, np.newaxis], 3, axis=2) 317 | mask_pos = pos > 0 318 | mask_neg = neg > 0 319 | 320 | event_image[:, :, 0][mask_neg] = 0.24313725490196078 321 | event_image[:, :, 1][mask_neg] = 0.11764705882352941 322 | event_image[:, :, 2][mask_neg] = 0.047058823529411764 323 | event_image[:, :, 0][mask_pos] = 0.6352941176470588 324 | event_image[:, :, 1][mask_pos] = 0.4235294117647059 325 | event_image[:, :, 2][mask_pos] = 0.23529411764705882 326 | 327 | else: 328 | print("Visualization error: Unknown color scheme for event images.") 329 | raise AttributeError 330 | 331 | event_image = (255 * event_image).astype(np.uint8) 332 | return event_image 333 | 334 | @staticmethod 335 | def minmax_norm(x, max=None, min=None): 336 | """ 337 | Robust min-max normalization. 338 | :param x: [H x W x 1] 339 | :return x: [H x W x 1] normalized x 340 | """ 341 | 342 | if max is None: 343 | max = np.percentile(x, 99) 344 | if min is None: 345 | min = np.percentile(x, 1) 346 | 347 | den = max - min 348 | if den != 0: 349 | x = (x - min) / den 350 | return np.clip(x, 0, 1) 351 | --------------------------------------------------------------------------------
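As a closing example, the static `events_to_image` above renders a two-channel event count map without instantiating the class. A hedged sketch with random counts; the Poisson rate and output file name are arbitrary:

```
# Render a random event count map with the default "green_red" color scheme.
import cv2
import numpy as np
from utils.visualization import Visualization  # assumes repo root on PYTHONPATH

event_cnt = np.random.poisson(0.2, size=(64, 64, 2)).astype(np.float32)  # [H x W x 2]
img = Visualization.events_to_image(event_cnt)  # positive events green, negative red (BGR)
cv2.imwrite("events.png", img)
```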