├── .flake8 ├── .gitignore ├── .pre-commit-config.yaml ├── .readme ├── flow.gif └── reconstruction.gif ├── LICENSE ├── README.md ├── configs ├── __init__.py ├── eval_flow.yml ├── eval_reconstruction.yml ├── parser.py ├── train_flow.yml └── train_reconstruction.yml ├── dataloader ├── __init__.py ├── base.py ├── encodings.py ├── h5.py └── utils.py ├── datasets ├── .gitignore └── tools │ ├── h5_packager.py │ ├── messageTypes │ ├── __init__.py │ ├── common.py │ ├── dvs_msgs_EventArray.py │ ├── esim_msgs_OpticFlow.py │ ├── geometry_msgs_PoseStamped.py │ ├── geometry_msgs_Transform.py │ ├── geometry_msgs_TransformStamped.py │ ├── geometry_msgs_TwistStamped.py │ ├── sensor_msgs_CameraInfo.py │ ├── sensor_msgs_Image.py │ ├── sensor_msgs_Imu.py │ ├── sensor_msgs_PointCloud2.py │ └── tf_tfMessage.py │ ├── random_crop.py │ └── rosbag_to_h5.py ├── eval_flow.py ├── eval_reconstruction.py ├── loss ├── __init__.py ├── flow.py └── reconstruction.py ├── models ├── __init__.py ├── base.py ├── model.py ├── model_util.py ├── submodules.py └── unet.py ├── pyproject.toml ├── requirements.txt ├── train_flow.py ├── train_reconstruction.py └── utils ├── __init__.py ├── gradients.py ├── iwe.py ├── utils.py └── visualization.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E203, E266, E501, W503, F403, F401, E741, W291, C901, E722, E402 3 | max-line-length = 120 4 | max-complexity = 18 5 | select = B,C,E,F,W,T4,B9 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | trained_models/ 2 | __pycache__/ 3 | mlruns/ 4 | .vscode/ 5 | *.pyc 6 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/ambv/black 3 | rev: stable 4 | hooks: 5 | - id: black 6 | language_version: python3 7 | - repo: https://gitlab.com/pycqa/flake8 8 | rev: 3.7.9 9 | hooks: 10 | - id: flake8 -------------------------------------------------------------------------------- /.readme/flow.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tudelft/ssl_e2vid/ef9d61b476576f60f92029cadde4286c6cf0f734/.readme/flow.gif -------------------------------------------------------------------------------- /.readme/reconstruction.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tudelft/ssl_e2vid/ef9d61b476576f60f92029cadde4286c6cf0f734/.readme/reconstruction.gif -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 TU Delft 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the 
Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Back to Event Basics: SSL of Image Reconstruction for Event Cameras 2 | 3 | Minimal code for Back to Event Basics: Self-Supervised Learning of Image Reconstruction for Event Cameras via Photometric Constancy, CVPR'21. 4 | 5 | ## Usage 6 | 7 | This project uses Python >= 3.7.3. After setting up your virtual environment, please install the required python libraries through: 8 | 9 | ``` 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | Code is formatted with Black (PEP8) using a pre-commit hook. To configure it, run: 14 | 15 | ``` 16 | pre-commit install 17 | ``` 18 | 19 | ### Data format 20 | 21 | Similarly to researchers from [Monash University](https://github.com/TimoStoff/events_contrast_maximization/tree/d6241dc90ec4dc2b4cffbb331a2389ff179bf7ab), this project processes events through the HDF5 data format. Details about the structure of these files can be found in `datasets/tools/`. 22 | 23 | ## Inference 24 | 25 | Download our pre-trained models from [here](https://1drv.ms/u/s!Ah0kx0CRKrAZjyw3nbvTo-lmXPvO?e=1r6SKD). 26 | 27 | Our HDF5 version of sequences from the Event Camera Dataset can also be downloaded from [here](https://1drv.ms/u/s!Ah0kx0CRKrAZjysmUU3tB7VkN2z3?e=S9CGut) for evaluation purposes. 28 | 29 | To estimate optical flow from the input events: 30 | 31 | ``` 32 | python eval_flow.py 33 | ``` 34 | 35 | 36 | 37 |   38 | 39 | To perform image reconstruction from the input events: 40 | 41 | ``` 42 | python eval_reconstruction.py 43 | ``` 44 | 45 | 46 | 47 |   48 | 49 | In `configs/`, you can find the configuration files associated to these scripts and vary the inference settings (e.g., number of input events, dataset). 50 | 51 | ## Training 52 | 53 | Our framework can be trained using any event camera dataset. However, if you are interested in using our training data, you can download it from [here](https://1drv.ms/u/s!Ah0kx0CRKrAZjysmUU3tB7VkN2z3?e=S9CGut). The datasets are expected at `datasets/data/`, but this location can be modified in the configuration files. 54 | 55 | To train an image reconstruction and optical flow model, you need to adapt the training settings in `configs/train_reconstruction.yml`. Here, you can choose the training dataset, the number of input events, the neural networks to be used (EV-FlowNet or FireFlowNet for optical flow; E2VID or FireNet for image reconstruction), the number of epochs, the optimizer and learning rate, etc. To start the training from scratch, run: 56 | 57 | ``` 58 | python train_reconstruction.py 59 | ``` 60 | 61 | Alternatively, if you have a model that you would like to keep training from, you can use 62 | 63 | ``` 64 | python train_reconstruction.py --prev_model 65 | ``` 66 | 67 | This is handy if, for instance, you just want to train the image reconstruction model and use a pre-trained optical flow network. 
For this, you can set `train_flow: False` in `configs/train_reconstruction.yml`, and run: 68 | 69 | ``` 70 | python train_reconstruction.py --prev_model 71 | ``` 72 | 73 | If you just want to train an optical flow network, adapt `configs/train_flow.yml`, and run: 74 | 75 | ``` 76 | python train_flow.py 77 | ``` 78 | 79 | Note that we use [MLflow](https://mlflow.org/) to keep track of all the experiments. 80 | 81 | ## Citations 82 | 83 | If you use this library in an academic context, please cite the following: 84 | 85 | ``` 86 | @article{paredes2020back, 87 | title={Back to Event Basics: Self-Supervised Learning of Image Reconstruction for Event Cameras via Photometric Constancy}, 88 | author={Paredes-Vall{\'e}s, Federico and de Croon, Guido C. H. E.}, 89 | journal={arXiv preprint arXiv:2009.08283}, 90 | year={2020} 91 | } 92 | ``` 93 | 94 | ## Acknowledgements 95 | 96 | This code borrows from the following open source projects, whom we would like to thank: 97 | 98 | - [E2VID](https://github.com/uzh-rpg/rpg_e2vid) 99 | - [Event Contrast Maximization Library](https://github.com/TimoStoff/events_contrast_maximization) 100 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tudelft/ssl_e2vid/ef9d61b476576f60f92029cadde4286c6cf0f734/configs/__init__.py -------------------------------------------------------------------------------- /configs/eval_flow.yml: -------------------------------------------------------------------------------- 1 | data: 2 | path: datasets/data/ECD/ 3 | mode: events # events/time/frames 4 | window: 50000 # events/time/frames 5 | 6 | model_flow: 7 | flow_scaling: 128 8 | mask_output: True 9 | 10 | loader: 11 | resolution: [180, 240] # H x W 12 | augment: [] 13 | gpu: 0 14 | 15 | vis: 16 | enabled: True 17 | px: 400 18 | store: False 19 | 20 | hot_filter: 21 | enabled: True 22 | max_px: 100 23 | min_obvs: 5 24 | max_rate: 0.8 25 | -------------------------------------------------------------------------------- /configs/eval_reconstruction.yml: -------------------------------------------------------------------------------- 1 | data: 2 | path: datasets/data/ECD/ 3 | mode: events # events/time/frames 4 | window: 50000 # events/time/frames 5 | 6 | model_flow: 7 | mask_output: True 8 | eval: True 9 | 10 | loader: 11 | resolution: [180, 240] # H x W 12 | augment: [] 13 | gpu: 0 14 | 15 | vis: 16 | enabled: True 17 | px: 400 18 | store: False 19 | 20 | hot_filter: 21 | enabled: True 22 | max_px: 100 23 | min_obvs: 5 24 | max_rate: 0.8 25 | -------------------------------------------------------------------------------- /configs/parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import yaml 6 | 7 | 8 | class YAMLParser: 9 | """ YAML parser for optical flow and image reconstruction config files """ 10 | 11 | def __init__(self, config): 12 | self.reset_config() 13 | self.parse_config(config) 14 | self.get_device() 15 | self.init_seeds() 16 | 17 | def parse_config(self, file): 18 | with open(file) as fid: 19 | yaml_config = yaml.load(fid, Loader=yaml.FullLoader) 20 | self.parse_dict(yaml_config) 21 | 22 | def log_config(self, path_models): 23 | with open(path_models + "train_config.yml", "w") as fid: 24 | yaml.dump(self._config, fid) 25 | 26 | @property 27 | def config(self): 28 | return self._config 29 | 30 | 
@config.setter 31 | def config(self, config): 32 | self._config = config 33 | 34 | @property 35 | def device(self): 36 | return self._device 37 | 38 | @property 39 | def loader_kwargs(self): 40 | return self._loader_kwargs 41 | 42 | def reset_config(self): 43 | self._config = {} 44 | 45 | # MLFlow experiment name 46 | self._config["experiment"] = "Default" 47 | 48 | # input data mode 49 | self._config["data"] = {} 50 | self._config["data"]["mode"] = "events" 51 | self._config["data"]["window"] = 5000 52 | 53 | # data loader 54 | self._config["loader"] = {} 55 | self._config["loader"]["resolution"] = [180, 240] 56 | self._config["loader"]["batch_size"] = 1 57 | self._config["loader"]["augment"] = [] 58 | self._config["loader"]["gpu"] = 0 59 | self._config["loader"]["seed"] = 0 60 | 61 | # hot pixel 62 | self._config["hot_filter"] = {} 63 | self._config["hot_filter"]["enabled"] = True 64 | self._config["hot_filter"]["max_px"] = 100 65 | self._config["hot_filter"]["min_obvs"] = 5 66 | self._config["hot_filter"]["max_rate"] = 0.8 67 | 68 | # model 69 | self._config["model"] = {} 70 | 71 | def update(self, config): 72 | self.reset_config() 73 | self.parse_config(config) 74 | 75 | def parse_dict(self, input_dict, parent=None): 76 | if parent is None: 77 | parent = self._config 78 | for key, val in input_dict.items(): 79 | if isinstance(val, dict): 80 | if key not in parent.keys(): 81 | parent[key] = {} 82 | self.parse_dict(val, parent[key]) 83 | else: 84 | parent[key] = val 85 | 86 | def log_eval_config(self, config): 87 | eval_id = 0 88 | for file in os.listdir(config["trained_model"]): 89 | if file.endswith(".yml"): 90 | if file.split(".")[0].split("_")[0] == "eval": 91 | tmp = int(file.split(".")[0].split("_")[-1]) 92 | eval_id = tmp + 1 if tmp + 1 > eval_id else eval_id 93 | yaml_filename = config["trained_model"] + "eval_" + str(eval_id) + ".yml" 94 | with open(yaml_filename, "w") as outfile: 95 | yaml.dump(config, outfile, default_flow_style=False) 96 | 97 | return eval_id 98 | 99 | def get_device(self): 100 | cuda = torch.cuda.is_available() 101 | self._device = torch.device("cuda:" + str(self._config["loader"]["gpu"]) if cuda else "cpu") 102 | self._loader_kwargs = {"num_workers": 0, "pin_memory": True} if cuda else {} 103 | 104 | @staticmethod 105 | def worker_init_fn(worker_id): 106 | np.random.seed(np.random.get_state()[1][0] + worker_id) 107 | 108 | def init_seeds(self): 109 | torch.manual_seed(self._config["loader"]["seed"]) 110 | if torch.cuda.is_available(): 111 | torch.cuda.manual_seed(self._config["loader"]["seed"]) 112 | torch.cuda.manual_seed_all(self._config["loader"]["seed"]) 113 | 114 | def merge_configs(self, path_models): 115 | 116 | # parse training config 117 | with open(path_models + "train_config.yml") as fid: 118 | train_config = yaml.load(fid, Loader=yaml.FullLoader) 119 | 120 | # overwrite with config settings 121 | self.parse_dict(self._config, train_config) 122 | 123 | return train_config 124 | -------------------------------------------------------------------------------- /configs/train_flow.yml: -------------------------------------------------------------------------------- 1 | experiment: Default 2 | 3 | data: 4 | path: datasets/data/training/ 5 | mode: events # events/time/frames 6 | window: 5000 # events/time/frames 7 | num_bins: 5 8 | 9 | model_flow: 10 | name: EVFlowNet # FireFlowNet/EVFlowNet 11 | base_num_channels: 32 12 | kernel_size: 3 13 | mask_output: False 14 | 15 | loss: 16 | flow_regul_weight: 1 17 | 18 | optimizer: 19 | name: Adam 20 | lr: 
0.0001 21 | 22 | loader: 23 | n_epochs: 120 24 | batch_size: 1 25 | resolution: [128, 128] # H x W 26 | augment: ["Horizontal", "Vertical", "Polarity"] 27 | augment_prob: [0.5, 0.5, 0.5] 28 | gpu: 0 29 | 30 | vis: 31 | verbose: True 32 | enabled: False 33 | px: 400 34 | 35 | hot_filter: 36 | enabled: True 37 | max_px: 100 38 | min_obvs: 5 39 | max_rate: 0.8 40 | -------------------------------------------------------------------------------- /configs/train_reconstruction.yml: -------------------------------------------------------------------------------- 1 | experiment: Default 2 | 3 | data: 4 | path: datasets/data/training/ 5 | mode: events # events/time/frames 6 | window: 5000 # events/time/frames 7 | num_bins: 5 8 | 9 | model_reconstruction: 10 | name: E2VID # E2VID/FireNet 11 | base_num_channels: 32 12 | kernel_size: 5 13 | 14 | model_flow: 15 | name: EVFlowNet # EVFlowNet/FireFlowNet 16 | base_num_channels: 32 17 | kernel_size: 3 18 | mask_output: False 19 | 20 | loss: 21 | train_flow: True 22 | flow_regul_weight: 1.0 23 | reconstruction_regul_weight: [0.1, 0.05] # TotalVariation/TemporalConsistency 24 | reconstruction_tc_idx_threshold: 10 25 | reconstruction_unroll: 20 26 | 27 | optimizer: 28 | name: Adam 29 | lr: 0.0001 30 | 31 | loader: 32 | n_epochs: 120 33 | batch_size: 1 34 | resolution: [128, 128] # H x W 35 | augment: ["Horizontal", "Vertical", "Polarity", "Pause"] 36 | augment_prob: [0.5, 0.5, 0.5, [0.05, 0.1]] 37 | gpu: 0 38 | 39 | vis: 40 | verbose: True 41 | enabled: False 42 | px: 400 43 | 44 | hot_filter: 45 | enabled: True 46 | max_px: 100 47 | min_obvs: 5 48 | max_rate: 0.8 49 | -------------------------------------------------------------------------------- /dataloader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tudelft/ssl_e2vid/ef9d61b476576f60f92029cadde4286c6cf0f734/dataloader/__init__.py -------------------------------------------------------------------------------- /dataloader/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import numpy as np 4 | import random 5 | import torch 6 | 7 | from .encodings import events_to_voxel, events_to_channels, events_to_mask, get_hot_event_mask 8 | 9 | 10 | class BaseDataLoader(torch.utils.data.Dataset): 11 | """ 12 | Base class for dataloader. 
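Keeps per-batch-element state for data augmentation (horizontal/vertical flips, polarity inversion) and for hot-pixel filtering across consecutive forward passes, with the 'Pause' augmentation flag shared by all batch elements.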
13 | """ 14 | 15 | def __init__(self, config, num_bins): 16 | self.config = config 17 | self.epoch = 0 18 | self.seq_num = 0 19 | self.samples = 0 20 | self.new_seq = False 21 | self.tc_idx = 0 22 | self.num_bins = num_bins 23 | 24 | # batch-specific data augmentation mechanisms 25 | self.batch_augmentation = {} 26 | for mechanism in self.config["loader"]["augment"]: 27 | if mechanism != "Pause": 28 | self.batch_augmentation[mechanism] = [False for i in range(self.config["loader"]["batch_size"])] 29 | else: 30 | self.batch_augmentation[mechanism] = False # shared among batch elements 31 | 32 | for i, mechanism in enumerate(self.config["loader"]["augment"]): 33 | if mechanism != "Pause": 34 | for batch in range(self.config["loader"]["batch_size"]): 35 | if np.random.random() < self.config["loader"]["augment_prob"][i]: 36 | self.batch_augmentation[mechanism][batch] = True 37 | 38 | # hot pixels 39 | if self.config["hot_filter"]["enabled"]: 40 | self.hot_idx = [0 for i in range(self.config["loader"]["batch_size"])] 41 | self.hot_events = [ 42 | torch.zeros(self.config["loader"]["resolution"]) for i in range(self.config["loader"]["batch_size"]) 43 | ] 44 | 45 | @abstractmethod 46 | def __getitem__(self, index): 47 | raise NotImplementedError 48 | 49 | @abstractmethod 50 | def get_events(self, history): 51 | raise NotImplementedError 52 | 53 | def reset_sequence(self, batch): 54 | """ 55 | Reset sequence-specific variables. 56 | :param batch: batch index 57 | """ 58 | 59 | self.tc_idx = 0 60 | self.seq_num += 1 61 | if self.config["hot_filter"]["enabled"]: 62 | self.hot_idx[batch] = 0 63 | self.hot_events[batch] = torch.zeros(self.config["loader"]["resolution"]) 64 | 65 | for i, mechanism in enumerate(self.config["loader"]["augment"]): 66 | if mechanism != "Pause": 67 | if np.random.random() < self.config["loader"]["augment_prob"][i]: 68 | self.batch_augmentation[mechanism][batch] = True 69 | else: 70 | self.batch_augmentation[mechanism][batch] = False 71 | else: 72 | self.batch_augmentation[mechanism] = False 73 | 74 | @staticmethod 75 | def event_formatting(xs, ys, ts, ps): 76 | """ 77 | Reset sequence-specific variables. 78 | :param xs: [N] numpy array with event x location 79 | :param ys: [N] numpy array with event y location 80 | :param ts: [N] numpy array with event timestamp 81 | :param ps: [N] numpy array with event polarity ([-1, 1]) 82 | :return xs: [N] tensor with event x location 83 | :return ys: [N] tensor with event y location 84 | :return ts: [N] tensor with normalized event timestamp 85 | :return ps: [N] tensor with event polarity ([-1, 1]) 86 | """ 87 | 88 | xs = torch.from_numpy(xs.astype(np.float32)) 89 | ys = torch.from_numpy(ys.astype(np.float32)) 90 | ts = torch.from_numpy(ts.astype(np.float32)) 91 | ps = torch.from_numpy(ps.astype(np.float32)) * 2 - 1 92 | ts = (ts - ts[0]) / (ts[-1] - ts[0]) 93 | return xs, ys, ts, ps 94 | 95 | def augment_events(self, xs, ys, ps, batch): 96 | """ 97 | Augment event sequence with horizontal, vertical, and polarity flips, and 98 | artificial event pauses. 
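Horizontal, Vertical, and Polarity flips are applied per batch element according to `self.batch_augmentation`; the 'Pause' mechanism is shared across the batch and, once the temporal-consistency index exceeds `reconstruction_tc_idx_threshold`, is stochastically switched on/off using the two probabilities configured for it in `augment_prob`.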
99 | :return xs: [N] tensor with event x location 100 | :return ys: [N] tensor with event y location 101 | :return ps: [N] tensor with event polarity ([-1, 1]) 102 | :param batch: batch index 103 | :return xs: [N] tensor with augmented event x location 104 | :return ys: [N] tensor with augmented event y location 105 | :return ps: [N] tensor with augmented event polarity ([-1, 1]) 106 | """ 107 | 108 | for i, mechanism in enumerate(self.config["loader"]["augment"]): 109 | 110 | if mechanism == "Horizontal": 111 | if self.batch_augmentation["Horizontal"][batch]: 112 | xs = self.config["loader"]["resolution"][1] - 1 - xs 113 | 114 | elif mechanism == "Vertical": 115 | if self.batch_augmentation["Vertical"][batch]: 116 | ys = self.config["loader"]["resolution"][0] - 1 - ys 117 | 118 | elif mechanism == "Polarity": 119 | if self.batch_augmentation["Polarity"][batch]: 120 | ps *= -1 121 | 122 | # shared among batch elements 123 | elif ( 124 | batch == 0 125 | and mechanism == "Pause" 126 | and self.tc_idx > self.config["loss"]["reconstruction_tc_idx_threshold"] 127 | ): 128 | if not self.batch_augmentation["Pause"]: 129 | if np.random.random() < self.config["loader"]["augment_prob"][i][0]: 130 | self.batch_augmentation["Pause"] = True 131 | else: 132 | if np.random.random() < self.config["loader"]["augment_prob"][i][1]: 133 | self.batch_augmentation["Pause"] = False 134 | 135 | return xs, ys, ps 136 | 137 | def augment_frames(self, img, batch): 138 | """ 139 | Augment APS frame with horizontal and vertical flips. 140 | :param img: [H x W] numpy array with APS intensity 141 | :param batch: batch index 142 | :return img: [H x W] augmented numpy array with APS intensity 143 | """ 144 | if "Horizontal" in self.batch_augmentation: 145 | if self.batch_augmentation["Horizontal"][batch]: 146 | img = np.flip(img, 1) 147 | if "Vertical" in self.batch_augmentation: 148 | if self.batch_augmentation["Vertical"][batch]: 149 | img = np.flip(img, 0) 150 | return img 151 | 152 | def create_cnt_encoding(self, xs, ys, ts, ps): 153 | """ 154 | Creates a per-pixel and per-polarity event count representation. 155 | :param xs: [N] tensor with event x location 156 | :param ys: [N] tensor with event y location 157 | :param ts: [N] tensor with normalized event timestamp 158 | :param ps: [N] tensor with event polarity ([-1, 1]) 159 | :return [2 x H x W] event representation 160 | """ 161 | 162 | return events_to_channels(xs, ys, ps, sensor_size=self.config["loader"]["resolution"]) 163 | 164 | def create_voxel_encoding(self, xs, ys, ts, ps): 165 | """ 166 | Creates a spatiotemporal voxel grid tensor representation with a certain number of bins, 167 | as described in Section 3.1 of the paper 'Unsupervised Event-based Learning of Optical Flow, 168 | Depth, and Egomotion', Zhu et al., CVPR'19.. 169 | Events are distributed to the spatiotemporal closest bins through bilinear interpolation. 170 | Positive events are added as +1, while negative as -1. 
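A minimal usage sketch (illustrative only; `loader` is assumed to be an H5Loader built with the default training settings of num_bins: 5 and resolution: [128, 128], and xs, ys, ts, ps are the tensors returned by `event_formatting`):

    inp_voxel = loader.create_voxel_encoding(xs, ys, ts, ps)  # torch.Size([5, 128, 128])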
171 | :param xs: [N] tensor with event x location 172 | :param ys: [N] tensor with event y location 173 | :param ts: [N] tensor with normalized event timestamp 174 | :param ps: [N] tensor with event polarity ([-1, 1]) 175 | :return [B x H x W] event representation 176 | """ 177 | 178 | return events_to_voxel( 179 | xs, 180 | ys, 181 | ts, 182 | ps, 183 | self.num_bins, 184 | sensor_size=self.config["loader"]["resolution"], 185 | ) 186 | 187 | @staticmethod 188 | def create_list_encoding(xs, ys, ts, ps): 189 | """ 190 | Creates a four channel tensor with all the events in the input partition. 191 | :param xs: [N] tensor with event x location 192 | :param ys: [N] tensor with event y location 193 | :param ts: [N] tensor with normalized event timestamp 194 | :param ps: [N] tensor with event polarity ([-1, 1]) 195 | :return [N x 4] event representation 196 | """ 197 | 198 | return torch.stack([ts, ys, xs, ps]) 199 | 200 | @staticmethod 201 | def create_polarity_mask(ps): 202 | """ 203 | Creates a two channel tensor that acts as a mask for the input event list. 204 | :param ps: [N] tensor with event polarity ([-1, 1]) 205 | :return [N x 2] event representation 206 | """ 207 | 208 | inp_pol_mask = torch.stack([ps, ps]) 209 | inp_pol_mask[0, :][inp_pol_mask[0, :] < 0] = 0 210 | inp_pol_mask[1, :][inp_pol_mask[1, :] > 0] = 0 211 | inp_pol_mask[1, :] *= -1 212 | return inp_pol_mask 213 | 214 | def create_hot_mask(self, xs, ys, ps, batch): 215 | """ 216 | Creates a one channel tensor that can act as mask to remove pixel with high event rate. 217 | :param xs: [N] tensor with event x location 218 | :param ys: [N] tensor with event y location 219 | :param ps: [N] tensor with event polarity ([-1, 1]) 220 | :return [H x W] binary mask 221 | """ 222 | 223 | hot_update = events_to_mask(xs, ys, ps, sensor_size=self.hot_events[batch].shape) 224 | self.hot_events[batch] += hot_update 225 | self.hot_idx[batch] += 1 226 | event_rate = self.hot_events[batch] / self.hot_idx[batch] 227 | return get_hot_event_mask( 228 | event_rate, 229 | self.hot_idx[batch], 230 | max_px=self.config["hot_filter"]["max_px"], 231 | min_obvs=self.config["hot_filter"]["min_obvs"], 232 | max_rate=self.config["hot_filter"]["max_rate"], 233 | ) 234 | 235 | def __len__(self): 236 | return 1000 # not used 237 | 238 | @staticmethod 239 | def custom_collate(batch): 240 | """ 241 | Collects the different event representations and stores them together in a dictionary. 242 | """ 243 | 244 | batch_dict = {} 245 | for key in batch[0].keys(): 246 | batch_dict[key] = [] 247 | for entry in batch: 248 | for key in entry.keys(): 249 | batch_dict[key].append(entry[key]) 250 | for key in batch_dict.keys(): 251 | item = torch.stack(batch_dict[key]) 252 | if len(item.shape) == 3: 253 | item = item.transpose(2, 1) 254 | batch_dict[key] = item 255 | return batch_dict 256 | 257 | def shuffle(self, flag=True): 258 | """ 259 | Shuffles the training data. 260 | """ 261 | 262 | if flag: 263 | random.shuffle(self.files) 264 | -------------------------------------------------------------------------------- /dataloader/encodings.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from Monash University https://github.com/TimoStoff/events_contrast_maximization 3 | """ 4 | 5 | import numpy as np 6 | import torch 7 | 8 | 9 | def binary_search_array(array, x, left=None, right=None, side="left"): 10 | """ 11 | Binary search through a sorted array. 
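Returns the index of `x` if it is present; otherwise (with the default side="left") the index at which `x` would have to be inserted to keep the array sorted. The dataloaders use it to map timestamps to event and frame indices.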
12 | """ 13 | 14 | left = 0 if left is None else left 15 | right = len(array) - 1 if right is None else right 16 | mid = left + (right - left) // 2 17 | 18 | if left > right: 19 | return left if side == "left" else right 20 | 21 | if array[mid] == x: 22 | return mid 23 | 24 | if x < array[mid]: 25 | return binary_search_array(array, x, left=left, right=mid - 1) 26 | 27 | return binary_search_array(array, x, left=mid + 1, right=right) 28 | 29 | 30 | def events_to_mask(xs, ys, ps, sensor_size=(180, 240)): 31 | """ 32 | Accumulate events into a binary mask. 33 | """ 34 | 35 | device = xs.device 36 | img_size = list(sensor_size) 37 | mask = torch.zeros(img_size).to(device) 38 | 39 | if xs.dtype is not torch.long: 40 | xs = xs.long().to(device) 41 | if ys.dtype is not torch.long: 42 | ys = ys.long().to(device) 43 | mask.index_put_((ys, xs), ps.abs(), accumulate=False) 44 | 45 | return mask 46 | 47 | 48 | def events_to_image(xs, ys, ps, sensor_size=(180, 240)): 49 | """ 50 | Accumulate events into an image. 51 | """ 52 | 53 | device = xs.device 54 | img_size = list(sensor_size) 55 | img = torch.zeros(img_size).to(device) 56 | 57 | if xs.dtype is not torch.long: 58 | xs = xs.long().to(device) 59 | if ys.dtype is not torch.long: 60 | ys = ys.long().to(device) 61 | img.index_put_((ys, xs), ps, accumulate=True) 62 | 63 | return img 64 | 65 | 66 | def events_to_voxel(xs, ys, ts, ps, num_bins, sensor_size=(180, 240)): 67 | """ 68 | Generate a voxel grid from input events using temporal bilinear interpolation. 69 | """ 70 | 71 | assert len(xs) == len(ys) and len(ys) == len(ts) and len(ts) == len(ps) 72 | 73 | voxel = [] 74 | ts = ts * (num_bins - 1) 75 | zeros = torch.zeros(ts.size()) 76 | for b_idx in range(num_bins): 77 | weights = torch.max(zeros, 1.0 - torch.abs(ts - b_idx)) 78 | voxel_bin = events_to_image(xs, ys, ps * weights, sensor_size=sensor_size) 79 | voxel.append(voxel_bin) 80 | 81 | return torch.stack(voxel) 82 | 83 | 84 | def events_to_channels(xs, ys, ps, sensor_size=(180, 240)): 85 | """ 86 | Generate a two-channel event image containing event counters. 87 | """ 88 | 89 | assert len(xs) == len(ys) and len(ys) == len(ps) 90 | 91 | mask_pos = ps.clone() 92 | mask_neg = ps.clone() 93 | mask_pos[ps < 0] = 0 94 | mask_neg[ps > 0] = 0 95 | 96 | pos_cnt = events_to_image(xs, ys, ps * mask_pos, sensor_size=sensor_size) 97 | neg_cnt = events_to_image(xs, ys, ps * mask_neg, sensor_size=sensor_size) 98 | 99 | return torch.stack([pos_cnt, neg_cnt]) 100 | 101 | 102 | def get_hot_event_mask(event_rate, idx, max_px=100, min_obvs=5, max_rate=0.8): 103 | """ 104 | Returns binary mask to remove events from hot pixels. 
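Filtering only starts once `idx` exceeds `min_obvs` observed event windows; the `max_px` pixels with the highest event rate are then checked, and each is masked out (set to 0) as long as its per-window rate exceeds `max_rate`.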
105 | """ 106 | 107 | mask = torch.ones(event_rate.shape).to(event_rate.device) 108 | if idx > min_obvs: 109 | for i in range(max_px): 110 | argmax = torch.argmax(event_rate) 111 | index = (argmax // event_rate.shape[1], argmax % event_rate.shape[1]) 112 | if event_rate[index] > max_rate: 113 | event_rate[index] = 0 114 | mask[index] = 0 115 | else: 116 | break 117 | return mask 118 | -------------------------------------------------------------------------------- /dataloader/h5.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | import numpy as np 4 | 5 | import torch 6 | import torch.utils.data as data 7 | 8 | from .base import BaseDataLoader 9 | from .utils import ProgressBar 10 | 11 | from .encodings import binary_search_array 12 | 13 | 14 | class Frames: 15 | """ 16 | Utility class for reading the APS frames encoded in the HDF5 files. 17 | """ 18 | 19 | def __init__(self): 20 | self.ts = [] 21 | self.names = [] 22 | 23 | def __call__(self, name, h5obj): 24 | if hasattr(h5obj, "dtype") and name not in self.names: 25 | self.names += [name] 26 | self.ts += [h5obj.attrs["timestamp"]] 27 | 28 | def get_frames(self, file, t0, t1, crop, res): 29 | """ 30 | Get all the APS frames in between two timestamps. 31 | :param file: file to read from 32 | :param t0: start time 33 | :param t1: end time 34 | :param crop: top-left corner of the patch to be cropped 35 | :param res: resolution of the patch to be cropped 36 | :return imgs: list of [H x W] APS images 37 | :return idx0: index of the first frame 38 | :return idx1: index of the last frame 39 | """ 40 | 41 | idx0 = binary_search_array(self.ts, t0) 42 | idx1 = binary_search_array(self.ts, t1) 43 | 44 | imgs = [] 45 | for i in range(idx0, idx1): 46 | imgs.append(file["images"]["image{:09d}".format(i)][crop[0] : crop[0] + res[0], crop[1] : crop[1] + res[1]]) 47 | 48 | return imgs, idx0, idx1 49 | 50 | 51 | class H5Loader(BaseDataLoader): 52 | def __init__(self, config, num_bins): 53 | super().__init__(config, num_bins) 54 | self.last_proc_timestamp = 0 55 | 56 | # "memory" that goes from forward pass to the next 57 | self.batch_idx = [i for i in range(self.config["loader"]["batch_size"])] # event sequence 58 | self.batch_row = [0 for i in range(self.config["loader"]["batch_size"])] # event_idx / time_idx / frame_idx 59 | self.batch_t0 = [None for i in range(self.config["loader"]["batch_size"])] 60 | 61 | # input event sequences 62 | self.files = [] 63 | for root, dirs, files in os.walk(config["data"]["path"]): 64 | for file in files: 65 | if file.endswith(".h5"): 66 | self.files.append(os.path.join(root, file)) 67 | 68 | # open first files 69 | self.open_files = [] 70 | for batch in range(self.config["loader"]["batch_size"]): 71 | self.open_files.append(h5py.File(self.files[batch], "r")) 72 | 73 | # load frames from open files 74 | self.open_files_frames = [] 75 | if self.config["data"]["mode"] == "frames": 76 | for batch in range(self.config["loader"]["batch_size"]): 77 | frames = Frames() 78 | self.open_files[batch]["images"].visititems(frames) 79 | self.open_files_frames.append(frames) 80 | 81 | # progress bars 82 | if self.config["vis"]["bars"]: 83 | self.open_files_bar = [] 84 | for batch in range(self.config["loader"]["batch_size"]): 85 | max_iters = self.get_iters(batch) 86 | self.open_files_bar.append(ProgressBar(self.files[batch].split("/")[-1], max=max_iters)) 87 | 88 | def get_iters(self, batch): 89 | """ 90 | Compute the number of forward passes given a sequence and an 
input mode and window. 91 | """ 92 | 93 | if self.config["data"]["mode"] == "events": 94 | max_iters = len(self.open_files[batch]["events/xs"]) 95 | elif self.config["data"]["mode"] == "time": 96 | max_iters = self.open_files[batch].attrs["duration"] 97 | elif self.config["data"]["mode"] == "frames": 98 | max_iters = len(self.open_files_frames[batch].ts) - 1 99 | else: 100 | print("DataLoader error: Unknown mode.") 101 | raise AttributeError 102 | 103 | return max_iters // self.config["data"]["window"] 104 | 105 | def get_events(self, file, idx0, idx1): 106 | """ 107 | Get all the events in between two indices. 108 | :param file: file to read from 109 | :param idx0: start index 110 | :param idx1: end index 111 | :return xs: [N] numpy array with event x location 112 | :return ys: [N] numpy array with event y location 113 | :return ts: [N] numpy array with event timestamp 114 | :return ps: [N] numpy array with event polarity ([-1, 1]) 115 | """ 116 | 117 | xs = file["events/xs"][idx0:idx1] 118 | ys = file["events/ys"][idx0:idx1] 119 | ts = file["events/ts"][idx0:idx1] 120 | ps = file["events/ps"][idx0:idx1] 121 | ts -= file.attrs["t0"] # sequence starting at t0 = 0 122 | if ts.shape[0] > 0: 123 | self.last_proc_timestamp = ts[-1] 124 | ts *= 1.0e6 # us 125 | return xs, ys, ts, ps 126 | 127 | def get_event_index(self, batch, window=0): 128 | """ 129 | Get all the event indices to be used for reading. 130 | :param batch: batch index 131 | :param window: input window 132 | :return event_idx: event index 133 | """ 134 | 135 | event_idx = None 136 | if self.config["data"]["mode"] == "events": 137 | event_idx = self.batch_row[batch] + window 138 | elif self.config["data"]["mode"] == "time": 139 | event_idx = self.find_ts_index( 140 | self.open_files[batch], self.batch_row[batch] + self.open_files[batch].attrs["t0"] + window 141 | ) 142 | elif self.config["data"]["mode"] == "frames": 143 | event_idx = self.find_ts_index( 144 | self.open_files[batch], self.open_files_frames[batch].ts[self.batch_row[batch] + window] 145 | ) 146 | else: 147 | print("DataLoader error: Unknown mode.") 148 | raise AttributeError 149 | return event_idx 150 | 151 | def find_ts_index(self, file, timestamp): 152 | """ 153 | Find closest event index for a given timestamp through binary search. 
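Assumes the `events/ts` dataset is sorted in time; used by `get_event_index` for the 'time' and 'frames' input modes.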
154 | """ 155 | 156 | return binary_search_array(file["events/ts"], timestamp) 157 | 158 | def __getitem__(self, index): 159 | while True: 160 | batch = index % self.config["loader"]["batch_size"] 161 | 162 | # trigger sequence change 163 | restart = False 164 | if self.config["data"]["mode"] == "frames" and self.batch_row[batch] + self.config["data"]["window"] >= len( 165 | self.open_files_frames[batch].ts 166 | ): 167 | restart = True 168 | 169 | # load events 170 | xs = np.zeros((0)) 171 | ys = np.zeros((0)) 172 | ts = np.zeros((0)) 173 | ps = np.zeros((0)) 174 | if not restart: 175 | idx0 = self.get_event_index(batch) 176 | idx1 = self.get_event_index(batch, window=self.config["data"]["window"]) 177 | xs, ys, ts, ps = self.get_events(self.open_files[batch], idx0, idx1) 178 | 179 | # trigger sequence change 180 | if (self.config["data"]["mode"] == "events" and xs.shape[0] < self.config["data"]["window"]) or xs.shape[ 181 | 0 182 | ] <= 10: 183 | restart = True 184 | 185 | # reset sequence if not enough input events 186 | if restart: 187 | self.new_seq = True 188 | self.reset_sequence(batch) 189 | self.batch_row[batch] = 0 190 | self.batch_idx[batch] = max(self.batch_idx) + 1 191 | self.batch_t0[batch] = None 192 | 193 | self.open_files[batch].close() 194 | self.open_files[batch] = h5py.File(self.files[self.batch_idx[batch] % len(self.files)], "r") 195 | 196 | if self.config["data"]["mode"] == "frames": 197 | frames = Frames() 198 | self.open_files[batch]["images"].visititems(frames) 199 | self.open_files_frames[batch] = frames 200 | if self.config["vis"]["bars"]: 201 | self.open_files_bar[batch].finish() 202 | max_iters = self.get_iters(batch) 203 | self.open_files_bar[batch] = ProgressBar( 204 | self.files[self.batch_idx[batch] % len(self.files)].split("/")[-1], max=max_iters 205 | ) 206 | 207 | continue 208 | 209 | # timestamp normalization 210 | if self.batch_t0[batch] is None: 211 | self.batch_t0[batch] = ts[0] 212 | ts -= self.batch_t0[batch] 213 | xs, ys, ts, ps = self.event_formatting(xs, ys, ts, ps) 214 | 215 | # data augmentation 216 | xs, ys, ps = self.augment_events(xs, ys, ps, batch) 217 | 218 | # artificial pauses to the event stream 219 | if "Pause" in self.config["loader"]["augment"]: 220 | if self.batch_augmentation["Pause"]: 221 | xs = torch.from_numpy(np.empty([0]).astype(np.float32)) 222 | ys = torch.from_numpy(np.empty([0]).astype(np.float32)) 223 | ts = torch.from_numpy(np.empty([0]).astype(np.float32)) 224 | ps = torch.from_numpy(np.empty([0]).astype(np.float32)) 225 | 226 | # events to tensors 227 | inp_cnt = self.create_cnt_encoding(xs, ys, ts, ps) 228 | inp_voxel = self.create_voxel_encoding(xs, ys, ts, ps) 229 | inp_list = self.create_list_encoding(xs, ys, ts, ps) 230 | inp_pol_mask = self.create_polarity_mask(ps) 231 | 232 | # hot pixel removal 233 | if self.config["hot_filter"]["enabled"]: 234 | hot_mask = self.create_hot_mask(xs, ys, ps, batch) 235 | hot_mask_voxel = torch.stack([hot_mask] * self.num_bins, axis=2).permute(2, 0, 1) 236 | hot_mask_cnt = torch.stack([hot_mask] * 2, axis=2).permute(2, 0, 1) 237 | inp_voxel = inp_voxel * hot_mask_voxel 238 | inp_cnt = inp_cnt * hot_mask_cnt 239 | 240 | # load frames when required 241 | if self.config["data"]["mode"] == "frames": 242 | inp_frames = np.zeros( 243 | ( 244 | 2, 245 | self.config["loader"]["resolution"][0], 246 | self.config["loader"]["resolution"][1], 247 | ) 248 | ) 249 | img0 = self.open_files[batch]["images"][self.open_files_frames[batch].names[self.batch_row[batch]]][:] 250 | img1 = 
self.open_files[batch]["images"][ 251 | self.open_files_frames[batch].names[self.batch_row[batch] + self.config["data"]["window"]] 252 | ][:] 253 | inp_frames[0, :, :] = self.augment_frames(img0, batch) 254 | inp_frames[1, :, :] = self.augment_frames(img1, batch) 255 | inp_frames = torch.from_numpy(inp_frames.astype(np.uint8)) 256 | 257 | # update window if not in pause mode 258 | if "Pause" in self.config["loader"]["augment"]: 259 | if not self.batch_augmentation["Pause"]: 260 | self.batch_row[batch] += self.config["data"]["window"] 261 | else: 262 | self.batch_row[batch] += self.config["data"]["window"] 263 | 264 | # break while loop if everything went well 265 | break 266 | 267 | # prepare output 268 | output = {} 269 | output["inp_cnt"] = inp_cnt 270 | output["inp_voxel"] = inp_voxel 271 | output["inp_list"] = inp_list 272 | output["inp_pol_mask"] = inp_pol_mask 273 | if self.config["data"]["mode"] == "frames": 274 | output["inp_frames"] = inp_frames 275 | 276 | return output 277 | -------------------------------------------------------------------------------- /dataloader/utils.py: -------------------------------------------------------------------------------- 1 | from progress.bar import Bar 2 | 3 | 4 | class ProgressBar(Bar): 5 | suffix = "%(percent).1f%%, ETA: %(eta)ds, %(frequency)fHz" 6 | 7 | @property 8 | def frequency(self): 9 | return 1 / self.avg 10 | -------------------------------------------------------------------------------- /datasets/.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | -------------------------------------------------------------------------------- /datasets/tools/h5_packager.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from Monash University https://github.com/TimoStoff/events_contrast_maximization 3 | """ 4 | 5 | import h5py 6 | import numpy as np 7 | 8 | 9 | class H5Packager: 10 | def __init__(self, output_path): 11 | print("Creating file in {}".format(output_path)) 12 | self.output_path = output_path 13 | 14 | self.file = h5py.File(output_path, "w") 15 | self.event_xs = self.file.create_dataset( 16 | "events/xs", 17 | (0,), 18 | dtype=np.dtype(np.int16), 19 | maxshape=(None,), 20 | chunks=True, 21 | ) 22 | self.event_ys = self.file.create_dataset( 23 | "events/ys", 24 | (0,), 25 | dtype=np.dtype(np.int16), 26 | maxshape=(None,), 27 | chunks=True, 28 | ) 29 | self.event_ts = self.file.create_dataset( 30 | "events/ts", 31 | (0,), 32 | dtype=np.dtype(np.float64), 33 | maxshape=(None,), 34 | chunks=True, 35 | ) 36 | self.event_ps = self.file.create_dataset( 37 | "events/ps", 38 | (0,), 39 | dtype=np.dtype(np.bool_), 40 | maxshape=(None,), 41 | chunks=True, 42 | ) 43 | 44 | def append(self, dataset, data): 45 | dataset.resize(dataset.shape[0] + len(data), axis=0) 46 | if len(data) == 0: 47 | return 48 | dataset[-len(data) :] = data[:] 49 | 50 | def package_events(self, xs, ys, ts, ps): 51 | self.append(self.event_xs, xs) 52 | self.append(self.event_ys, ys) 53 | self.append(self.event_ts, ts) 54 | self.append(self.event_ps, ps) 55 | 56 | def package_image(self, image, timestamp, img_idx): 57 | image_dset = self.file.create_dataset( 58 | "images/image{:09d}".format(img_idx), 59 | data=image, 60 | dtype=np.dtype(np.uint8), 61 | ) 62 | image_dset.attrs["size"] = image.shape 63 | image_dset.attrs["timestamp"] = timestamp 64 | image_dset.attrs["type"] = "greyscale" if image.shape[-1] == 1 or len(image.shape) == 2 else "color_bgr" 65 | 66 | def 
add_metadata( 67 | self, 68 | num_pos, 69 | num_neg, 70 | duration, 71 | t0, 72 | tk, 73 | num_imgs, 74 | sensor_size, 75 | ): 76 | self.file.attrs["num_events"] = num_pos + num_neg 77 | self.file.attrs["num_pos"] = num_pos 78 | self.file.attrs["num_neg"] = num_neg 79 | self.file.attrs["duration"] = tk - t0 80 | self.file.attrs["t0"] = t0 81 | self.file.attrs["tk"] = tk 82 | self.file.attrs["num_imgs"] = num_imgs 83 | self.file.attrs["sensor_resolution"] = sensor_size 84 | -------------------------------------------------------------------------------- /datasets/tools/messageTypes/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tudelft/ssl_e2vid/ef9d61b476576f60f92029cadde4286c6cf0f734/datasets/tools/messageTypes/__init__.py -------------------------------------------------------------------------------- /datasets/tools/messageTypes/common.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from Event-driven Perception for Robotics https://github.com/event-driven-robotics/importRosbag under GNU General Public License. 3 | """ 4 | 5 | from struct import unpack 6 | import numpy as np 7 | 8 | 9 | def unpack_header(headerLen, headerBytes): 10 | fields = {} 11 | ptr = 0 12 | while ptr < headerLen: 13 | fieldLen = unpack("=l", headerBytes[ptr : ptr + 4])[0] 14 | ptr += 4 15 | # print(fieldLen) 16 | field = headerBytes[ptr : ptr + fieldLen] 17 | ptr += fieldLen 18 | # print(field) 19 | fieldSplit = field.find(b"\x3d") 20 | fieldName = field[:fieldSplit].decode("utf-8") 21 | fieldValue = field[fieldSplit + 1 :] 22 | fields[fieldName] = fieldValue 23 | return fields 24 | 25 | 26 | def unpackRosUint32(data, ptr): 27 | return unpack("=L", data[ptr : ptr + 4])[0], ptr + 4 28 | 29 | 30 | def unpackRosUint8(data, ptr): 31 | return unpack("=B", data[ptr : ptr + 1])[0], ptr + 1 32 | 33 | 34 | def unpackRosString(data, ptr): 35 | stringLen = unpack("=L", data[ptr : ptr + 4])[0] 36 | ptr += 4 37 | try: 38 | outStr = data[ptr : ptr + stringLen].decode("utf-8") 39 | except UnicodeDecodeError: 40 | outStr = "UnicodeDecodeError" 41 | ptr += stringLen 42 | return outStr, ptr 43 | 44 | 45 | def unpackRosFloat64Array(data, num, ptr): 46 | return ( 47 | np.frombuffer(data[ptr : ptr + num * 8], dtype=np.float64), 48 | ptr + num * 8, 49 | ) 50 | 51 | 52 | def unpackRosFloat32Array(data, num, ptr): 53 | return ( 54 | np.frombuffer(data[ptr : ptr + num * 4], dtype=np.float32), 55 | ptr + num * 4, 56 | ) 57 | 58 | 59 | def unpackRosFloat32(data, ptr): 60 | return unpack("=f", data[ptr : ptr + 4])[0], ptr + 4 61 | 62 | 63 | def unpackRosTimestamp(data, ptr): 64 | timeS, timeNs = unpack("=LL", data[ptr : ptr + 8]) 65 | timeFloat = np.float64(timeS) + np.float64(timeNs) * 0.000000001 66 | return timeFloat, ptr + 8 67 | -------------------------------------------------------------------------------- /datasets/tools/messageTypes/dvs_msgs_EventArray.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from Event-driven Perception for Robotics https://github.com/event-driven-robotics/importRosbag under GNU General Public License. 
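Parses dvs_msgs/EventArray messages into a dict with per-event 'x', 'y', 'ts' (float64, seconds) and 'pol' arrays, plus the sensor dimensions 'dimX' and 'dimY'.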
3 | """ 4 | 5 | from tqdm import tqdm 6 | import numpy as np 7 | 8 | # local imports 9 | 10 | from .common import unpackRosString, unpackRosUint32 11 | 12 | 13 | def importTopic(msgs, **kwargs): 14 | 15 | tsByMessage = [] 16 | xByMessage = [] 17 | yByMessage = [] 18 | polByMessage = [] 19 | for msg in tqdm(msgs, position=0, leave=True): 20 | # TODO: maybe implement kwargs['useRosMsgTimestamps'] 21 | data = msg["data"] 22 | # seq = unpack('=L', data[0:4])[0] 23 | # timeS, timeNs = unpack('=LL', data[4:12]) 24 | frame_id, ptr = unpackRosString(data, 12) 25 | height, ptr = unpackRosUint32(data, ptr) 26 | width, ptr = unpackRosUint32(data, ptr) 27 | numEventsInMsg, ptr = unpackRosUint32(data, ptr) 28 | # The format of the event is x=Uint16, y=Uint16, ts = Uint32, tns (nano seconds) = Uint32, pol=Bool 29 | # Unpack in batch into uint8 and then compose 30 | dataAsArray = np.frombuffer(data[ptr : ptr + numEventsInMsg * 13], dtype=np.uint8) 31 | dataAsArray = dataAsArray.reshape((-1, 13), order="C") 32 | # Assuming big-endian 33 | xByMessage.append((dataAsArray[:, 0] + dataAsArray[:, 1] * 2 ** 8).astype(np.uint16)) 34 | yByMessage.append((dataAsArray[:, 2] + dataAsArray[:, 3] * 2 ** 8).astype(np.uint16)) 35 | ts = ( 36 | dataAsArray[:, 4] + dataAsArray[:, 5] * 2 ** 8 + dataAsArray[:, 6] * 2 ** 16 + dataAsArray[:, 7] * 2 ** 24 37 | ).astype(np.float64) 38 | tns = ( 39 | dataAsArray[:, 8] + dataAsArray[:, 9] * 2 ** 8 + dataAsArray[:, 10] * 2 ** 16 + dataAsArray[:, 11] * 2 ** 24 40 | ).astype(np.float64) 41 | tsByMessage.append(ts + tns / 1000000000) # Combine timestamp parts, result is in seconds 42 | polByMessage.append(dataAsArray[:, 12].astype(np.bool)) 43 | outDict = { 44 | "x": np.concatenate(xByMessage), 45 | "y": np.concatenate(yByMessage), 46 | "ts": np.concatenate(tsByMessage), 47 | "pol": np.concatenate(polByMessage), 48 | "dimX": width, 49 | "dimY": height, 50 | } 51 | return outDict 52 | -------------------------------------------------------------------------------- /datasets/tools/messageTypes/esim_msgs_OpticFlow.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Copyright (C) 2019 Event-driven Perception for Robotics 5 | Authors: Sim Bamford 6 | This program is free software: you can redistribute it and/or modify it under 7 | the terms of the GNU General Public License as published by the Free Software 8 | Foundation, either version 3 of the License, or (at your option) any later version. 9 | This program is distributed in the hope that it will be useful, but WITHOUT ANY 10 | WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A 11 | PARTICULAR PURPOSE. See the GNU General Public License for more details. 12 | You should have received a copy of the GNU General Public License along with 13 | this program. If not, see . 14 | 15 | Intended as part of importRosbag. 16 | 17 | The importRosbag function receives a list of messages and returns 18 | a dict with one field for each data field in the message, where the field 19 | will contain an appropriate iterable to contain the interpretted contents of each message. 
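Here, each OpticFlow message yields a timestamp and an [height x width x 2] flow map (x and y components stacked along the last axis); the output dict contains the 'ts' array and the list of 'flowMaps'.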
20 | 21 | This function imports the ros message type defined at: 22 | 23 | https://github.com/uzh-rpg/rpg_esim/blob/master/event_camera_simulator/esim_msgs/msg/OpticFlow.msg 24 | 25 | """ 26 | 27 | from tqdm import tqdm 28 | import numpy as np 29 | 30 | from .common import ( 31 | unpackRosFloat32Array, 32 | unpackRosUint32, 33 | unpackRosTimestamp, 34 | unpackRosString, 35 | ) 36 | 37 | 38 | def importTopic(msgs, **kwargs): 39 | 40 | tsAll = [] 41 | flowMaps = [] 42 | for msg in tqdm(msgs): 43 | 44 | data = msg["data"] 45 | ptr = 0 46 | seq, ptr = unpackRosUint32(data, ptr) # Not used 47 | ts, ptr = unpackRosTimestamp(data, ptr) 48 | frame_id, ptr = unpackRosString(data, ptr) # Not used 49 | height, ptr = unpackRosUint32(data, ptr) 50 | width, ptr = unpackRosUint32(data, ptr) 51 | 52 | if width > 0 and height > 0: 53 | arraySize, ptr = unpackRosUint32(data, ptr) 54 | # assert arraySize == width*height 55 | flowMapX, ptr = unpackRosFloat32Array(data, width * height, ptr) 56 | arraySize, ptr = unpackRosUint32(data, ptr) 57 | # assert arraySize == width*height 58 | flowMapY, ptr = unpackRosFloat32Array(data, width * height, ptr) 59 | flowMap = np.concatenate( 60 | ( 61 | flowMapX.reshape(height, width, 1), 62 | flowMapY.reshape(height, width, 1), 63 | ), 64 | axis=2, 65 | ) 66 | flowMaps.append(flowMap) 67 | tsAll.append(ts) 68 | if not tsAll: 69 | return None 70 | outDict = { 71 | "ts": np.array(tsAll, dtype=np.float64), 72 | "flowMaps": flowMaps, 73 | } 74 | return outDict 75 | -------------------------------------------------------------------------------- /datasets/tools/messageTypes/geometry_msgs_PoseStamped.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from Event-driven Perception for Robotics https://github.com/event-driven-robotics/importRosbag under GNU General Public License. 
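Each PoseStamped message is unpacked into a timestamp and a 7-element pose; the output dict contains 'ts', 'point' (x, y, z) and 'rotation' (quaternion, reordered from xyzw to wxyz).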
3 | """ 4 | 5 | from tqdm import tqdm 6 | import numpy as np 7 | 8 | from .common import unpackRosString, unpackRosTimestamp, unpackRosFloat64Array 9 | 10 | 11 | def importTopic(msgs, **kwargs): 12 | # if 'Stamped' not in kwargs.get('messageType', 'Stamped'): 13 | # return interpretMsgsAsPose6qAlt(msgs, **kwargs) 14 | sizeOfArray = 1024 15 | tsAll = np.zeros((sizeOfArray), dtype=np.float64) 16 | poseAll = np.zeros((sizeOfArray, 7), dtype=np.float64) 17 | for idx, msg in enumerate(tqdm(msgs, position=0, leave=True)): 18 | if sizeOfArray <= idx: 19 | tsAll = np.append(tsAll, np.zeros((sizeOfArray), dtype=np.float64)) 20 | poseAll = np.concatenate((poseAll, np.zeros((sizeOfArray, 7), dtype=np.float64))) 21 | sizeOfArray *= 2 22 | # TODO: maybe implement kwargs['useRosMsgTimestamps'] 23 | data = msg["data"] 24 | # seq = unpack('=L', data[0:4])[0] 25 | tsAll[idx], ptr = unpackRosTimestamp(data, 4) 26 | frame_id, ptr = unpackRosString(data, ptr) 27 | poseAll[idx, :], _ = unpackRosFloat64Array(data, 7, ptr) 28 | # Crop arrays to number of events 29 | numEvents = idx + 1 30 | tsAll = tsAll[:numEvents] 31 | poseAll = poseAll[:numEvents] 32 | point = poseAll[:, 0:3] 33 | rotation = poseAll[:, [6, 3, 4, 5]] # Switch quaternion form from xyzw to wxyz 34 | outDict = {"ts": tsAll, "point": point, "rotation": rotation} 35 | return outDict 36 | -------------------------------------------------------------------------------- /datasets/tools/messageTypes/geometry_msgs_Transform.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Copyright (C) 2019 Event-driven Perception for Robotics 5 | Authors: Sim Bamford 6 | This program is free software: you can redistribute it and/or modify it under 7 | the terms of the GNU General Public License as published by the Free Software 8 | Foundation, either version 3 of the License, or (at your option) any later version. 9 | This program is distributed in the hope that it will be useful, but WITHOUT ANY 10 | WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A 11 | PARTICULAR PURPOSE. See the GNU General Public License for more details. 12 | You should have received a copy of the GNU General Public License along with 13 | this program. If not, see . 14 | 15 | Intended as part of importRosbag. 16 | 17 | The importTopic function receives a list of messages and returns 18 | a dict with one field for each data field in the message, where the field 19 | will contain an appropriate iterable to contain the interpretted contents of each message. 20 | In some cases, static info is repeated in each message; in which case a field may not contain an iterable. 
21 | 22 | This function imports the ros message type defined at: 23 | http://docs.ros.org/melodic/api/geometry_msgs/html/msg/Transform.html 24 | """ 25 | 26 | from tqdm import tqdm 27 | import numpy as np 28 | 29 | # Local imports 30 | 31 | from .common import unpackRosTimestamp, unpackRosFloat64Array 32 | 33 | 34 | def importTopic(msgs, **kwargs): 35 | numEvents = 0 36 | sizeOfArray = 1024 37 | tsAll = np.zeros((sizeOfArray), dtype=np.float64) 38 | poseAll = np.zeros((sizeOfArray, 7), dtype=np.float64) 39 | for idx, msg in enumerate(tqdm(msgs, position=0, leave=True)): 40 | if sizeOfArray <= idx: 41 | tsAll = np.append(tsAll, np.zeros((sizeOfArray), dtype=np.float64)) 42 | poseAll = np.concatenate((poseAll, np.zeros((sizeOfArray, 7), dtype=np.float64))) 43 | sizeOfArray *= 2 44 | # Note - ignoring kwargs['useRosMsgTimestamps'] as there is no choice 45 | tsAll[idx], _ = unpackRosTimestamp(msg["time"], 0) 46 | poseAll[idx, :], _ = unpackRosFloat64Array(msg["data"], 7, 0) 47 | numEvents = idx + 1 48 | # Crop arrays to number of events 49 | tsAll = tsAll[:numEvents] 50 | 51 | """ Needed? 52 | from timestamps import zeroTimestampsForAChannel, rezeroTimestampsForImportedDicts, unwrapTimestamps 53 | tsAll = unwrapTimestamps(tsAll) # here we could pass in wrapTime=2**62, but actually it handles this internally 54 | """ 55 | poseAll = poseAll[:numEvents] 56 | point = poseAll[:, 0:3] 57 | rotation = poseAll[:, [6, 3, 4, 5]] # Switch quaternion form from xyzw to wxyz 58 | outDict = {"ts": tsAll, "point": point, "rotation": rotation} 59 | return outDict 60 | -------------------------------------------------------------------------------- /datasets/tools/messageTypes/geometry_msgs_TransformStamped.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Copyright (C) 2019 Event-driven Perception for Robotics 5 | Authors: Sim Bamford 6 | This program is free software: you can redistribute it and/or modify it under 7 | the terms of the GNU General Public License as published by the Free Software 8 | Foundation, either version 3 of the License, or (at your option) any later version. 9 | This program is distributed in the hope that it will be useful, but WITHOUT ANY 10 | WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A 11 | PARTICULAR PURPOSE. See the GNU General Public License for more details. 12 | You should have received a copy of the GNU General Public License along with 13 | this program. If not, see . 14 | 15 | Intended as part of importRosbag. 16 | 17 | The importTopic function receives a list of messages and returns 18 | a dict with one field for each data field in the message, where the field 19 | will contain an appropriate iterable to contain the interpretted contents of each message. 20 | In some cases, static info is repeated in each message; in which case a field may not contain an iterable. 21 | 22 | This function imports the ros message type defined at: 23 | http://docs.ros.org/api/geometry_msgs/html/msg/TransformStamped.html 24 | 25 | The result is a ts plus a 7-column np array of np.float64, 26 | where the cols are x, y, z, q-w, q-x, q-y, q-z, (i.e. 
quaternion orientation) 27 | 28 | NOTE: QUATERNION ORDER GETS MODIFIED from xyzw to wxyz 29 | 30 | NOTE - this code is identical to geometry_msgs_PoseStamped 31 | """ 32 | 33 | from tqdm import tqdm 34 | import numpy as np 35 | 36 | from .common import unpackRosString, unpackRosTimestamp, unpackRosFloat64Array 37 | 38 | 39 | def importTopic(msgs, **kwargs): 40 | # if 'Stamped' not in kwargs.get('messageType', 'Stamped'): 41 | # return interpretMsgsAsPose6qAlt(msgs, **kwargs) 42 | sizeOfArray = 1024 43 | tsAll = np.zeros((sizeOfArray), dtype=np.float64) 44 | poseAll = np.zeros((sizeOfArray, 7), dtype=np.float64) 45 | for idx, msg in enumerate(tqdm(msgs, position=0, leave=True)): 46 | if sizeOfArray <= idx: 47 | tsAll = np.append(tsAll, np.zeros((sizeOfArray), dtype=np.float64)) 48 | poseAll = np.concatenate((poseAll, np.zeros((sizeOfArray, 7), dtype=np.float64))) 49 | sizeOfArray *= 2 50 | # TODO: maybe implement kwargs['useRosMsgTimestamps'] 51 | data = msg["data"] 52 | # seq = unpack('=L', data[0:4])[0] 53 | tsAll[idx], ptr = unpackRosTimestamp(data, 4) 54 | frame_id, ptr = unpackRosString(data, ptr) 55 | poseAll[idx, :], _ = unpackRosFloat64Array(data, 7, ptr) 56 | # Crop arrays to number of events 57 | numEvents = idx + 1 58 | tsAll = tsAll[:numEvents] 59 | poseAll = poseAll[:numEvents] 60 | point = poseAll[:, 0:3] 61 | rotation = poseAll[:, [6, 3, 4, 5]] # Switch quaternion form from xyzw to wxyz 62 | outDict = {"ts": tsAll, "point": point, "rotation": rotation} 63 | return outDict 64 | -------------------------------------------------------------------------------- /datasets/tools/messageTypes/geometry_msgs_TwistStamped.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Copyright (C) 2019 Event-driven Perception for Robotics 5 | Authors: Sim Bamford 6 | This program is free software: you can redistribute it and/or modify it under 7 | the terms of the GNU General Public License as published by the Free Software 8 | Foundation, either version 3 of the License, or (at your option) any later version. 9 | This program is distributed in the hope that it will be useful, but WITHOUT ANY 10 | WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A 11 | PARTICULAR PURPOSE. See the GNU General Public License for more details. 12 | You should have received a copy of the GNU General Public License along with 13 | this program. If not, see . 14 | 15 | Intended as part of importRosbag. 16 | 17 | The interpretMessages function receives a list of messages and returns 18 | a dict with one field for each data field in the message, where the field 19 | will contain an appropriate iterable to contain the interpretted contents of each message. 
20 | 21 | This function imports the ros message type defined at: 22 | http://docs.ros.org/melodic/api/geometry_msgs/html/msg/TwistStamped.html 23 | """ 24 | 25 | from tqdm import tqdm 26 | import numpy as np 27 | 28 | from .common import unpackRosString, unpackRosTimestamp, unpackRosFloat64Array 29 | 30 | 31 | def importTopic(msgs, **kwargs): 32 | sizeOfArray = 1024 33 | ts = np.zeros((sizeOfArray), dtype=np.float64) 34 | linV = np.zeros((sizeOfArray, 3), dtype=np.float64) 35 | angV = np.zeros((sizeOfArray, 3), dtype=np.float64) 36 | for idx, msg in enumerate(tqdm(msgs, position=0, leave=True)): 37 | if sizeOfArray <= idx: 38 | ts = np.append(ts, np.zeros((sizeOfArray), dtype=np.float64)) 39 | linV = np.concatenate((linV, np.zeros((sizeOfArray, 3), dtype=np.float64))) 40 | angV = np.concatenate((angV, np.zeros((sizeOfArray, 3), dtype=np.float64))) 41 | sizeOfArray *= 2 42 | data = msg["data"] 43 | # seq = unpack('=L', data[0:4])[0] 44 | ts[idx], ptr = unpackRosTimestamp(data, 4) 45 | frame_id, ptr = unpackRosString(data, ptr) 46 | linV[idx, :], ptr = unpackRosFloat64Array(data, 3, ptr) 47 | angV[idx, :], _ = unpackRosFloat64Array(data, 3, ptr) 48 | # Crop arrays to number of events 49 | numEvents = idx + 1 50 | ts = ts[:numEvents] 51 | linV = linV[:numEvents] 52 | angV = angV[:numEvents] 53 | outDict = {"ts": ts, "linV": linV, "angV": angV} 54 | return outDict 55 | -------------------------------------------------------------------------------- /datasets/tools/messageTypes/sensor_msgs_CameraInfo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Copyright (C) 2019 Event-driven Perception for Robotics 5 | Authors: Sim Bamford 6 | This program is free software: you can redistribute it and/or modify it under 7 | the terms of the GNU General Public License as published by the Free Software 8 | Foundation, either version 3 of the License, or (at your option) any later version. 9 | This program is distributed in the hope that it will be useful, but WITHOUT ANY 10 | WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A 11 | PARTICULAR PURPOSE. See the GNU General Public License for more details. 12 | You should have received a copy of the GNU General Public License along with 13 | this program. If not, see . 14 | 15 | Intended as part of importRosbag. 16 | 17 | The importTopic function receives a list of messages and returns 18 | a dict with one field for each data field in the message, where the field 19 | will contain an appropriate iterable to contain the interpretted contents of each message. 20 | In some cases, static info is repeated in each message; in which case a field may not contain an iterable. 
21 | 22 | This function imports the ros message type defined at: 23 | 24 | http://docs.ros.org/api/sensor_msgs/html/msg/CameraInfo.html 25 | We assume that there will only be one camera_info msg per channel, 26 | so the resulting dict is populated by the following fields: 27 | std_msgs/Header header 28 | uint32 height 29 | uint32 width 30 | string distortion_model 31 | 32 | float64[] D (distortion params) 33 | float64[9] K (Intrinsic camera matrix) 34 | float64[9] R (rectification matrix - only for stereo setup) 35 | float64[12] P (projection matrix) 36 | 37 | uint32 binning_x 38 | uint32 binning_y 39 | sensor_msgs/RegionOfInterest roi 40 | 41 | """ 42 | 43 | from .common import unpackRosString, unpackRosUint32, unpackRosFloat64Array 44 | 45 | 46 | def importTopic(msgs, **kwargs): 47 | 48 | outDict = {} 49 | data = msgs[0]["data"] # There is one calibration msg per frame. 50 | # Just use the first one 51 | # seq = unpack('=L', data[0:4])[0] 52 | # timeS, timeNs = unpack('=LL', data[4:12]) 53 | frame_id, ptr = unpackRosString(data, 12) 54 | outDict["height"], ptr = unpackRosUint32(data, ptr) 55 | outDict["width"], ptr = unpackRosUint32(data, ptr) 56 | outDict["distortionModel"], ptr = unpackRosString(data, ptr) 57 | numElementsInD, ptr = unpackRosUint32(data, ptr) 58 | outDict["D"], ptr = unpackRosFloat64Array(data, numElementsInD, ptr) 59 | outDict["K"], ptr = unpackRosFloat64Array(data, 9, ptr) 60 | outDict["K"] = outDict["K"].reshape(3, 3) 61 | outDict["R"], ptr = unpackRosFloat64Array(data, 9, ptr) 62 | outDict["R"] = outDict["R"].reshape(3, 3) 63 | outDict["P"], ptr = unpackRosFloat64Array(data, 12, ptr) 64 | outDict["P"] = outDict["P"].reshape(3, 4) 65 | # Ignore binning and ROI 66 | return outDict 67 | -------------------------------------------------------------------------------- /datasets/tools/messageTypes/sensor_msgs_Image.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Copyright (C) 2019 Event-driven Perception for Robotics 5 | Authors: Sim Bamford 6 | This program is free software: you can redistribute it and/or modify it under 7 | the terms of the GNU General Public License as published by the Free Software 8 | Foundation, either version 3 of the License, or (at your option) any later version. 9 | This program is distributed in the hope that it will be useful, but WITHOUT ANY 10 | WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A 11 | PARTICULAR PURPOSE. See the GNU General Public License for more details. 12 | You should have received a copy of the GNU General Public License along with 13 | this program. If not, see . 14 | 15 | Intended as part of importRosbag. 16 | 17 | The importTopic function receives a list of messages and returns 18 | a dict with one field for each data field in the message, where the field 19 | will contain an appropriate iterable to contain the interpretted contents of each message. 20 | In some cases, static info is repeated in each message; in which case a field may not contain an iterable. 
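For Image messages, the returned dict holds ts (np.float64 timestamps) and frames (a list of numpy arrays of shape height x width, or height x width x depth for multi-channel encodings).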
21 | 22 | This function imports the ros message type defined at: 23 | http://docs.ros.org/api/sensor_msgs/html/msg/Image.html 24 | """ 25 | 26 | 27 | from tqdm import tqdm 28 | import numpy as np 29 | 30 | from .common import ( 31 | unpackRosString, 32 | unpackRosUint32, 33 | unpackRosUint8, 34 | unpackRosTimestamp, 35 | ) 36 | 37 | 38 | def importTopic(msgs, **kwargs): 39 | """ 40 | ros message is defined here: 41 | http://docs.ros.org/api/sensor_msgs/html/msg/Image.html 42 | the result is a ts plus a 2d array of samples () 43 | """ 44 | sizeOfArray = 1024 45 | tsAll = np.zeros((sizeOfArray), dtype=np.float64) 46 | framesAll = [] 47 | for idx, msg in enumerate(tqdm(msgs, position=0, leave=True)): 48 | if sizeOfArray <= idx: 49 | tsAll = np.append(tsAll, np.zeros((sizeOfArray), dtype=np.float64)) 50 | sizeOfArray *= 2 51 | data = msg["data"] 52 | # seq = unpack('=L', data[0:4])[0] 53 | if kwargs.get("useRosMsgTimestamps", False): 54 | tsAll[idx], _ = unpackRosTimestamp(msg["time"], 0) 55 | else: 56 | tsAll[idx], _ = unpackRosTimestamp(data, 4) 57 | frame_id, ptr = unpackRosString(data, 12) 58 | height, ptr = unpackRosUint32(data, ptr) 59 | width, ptr = unpackRosUint32(data, ptr) 60 | fmtString, ptr = unpackRosString(data, ptr) 61 | isBigendian, ptr = unpackRosUint8(data, ptr) 62 | if isBigendian: 63 | print("data is bigendian, but it doesn" "t matter") 64 | step, ptr = unpackRosUint32(data, ptr) # not used 65 | arraySize, ptr = unpackRosUint32(data, ptr) 66 | # assert arraySize == height*width 67 | 68 | # The pain of writing this scetion will continue to increase until it 69 | # matches this reference implementation: 70 | # http://docs.ros.org/jade/api/sensor_msgs/html/image__encodings_8h_source.html 71 | if fmtString in ["mono8", "8UC1"]: 72 | frameData = np.frombuffer(data[ptr : ptr + height * width], np.uint8) 73 | depth = 1 74 | elif fmtString in ["mono16", "16UC1"]: 75 | frameData = np.frombuffer(data[ptr : ptr + height * width * 2], np.uint16) 76 | depth = 1 77 | elif fmtString in ["bgr8", "rgb8"]: 78 | frameData = np.frombuffer(data[ptr : ptr + height * width * 3], np.uint8) 79 | depth = 3 80 | elif fmtString in ["bgra8", "rgba8"]: 81 | frameData = np.frombuffer(data[ptr : ptr + height * width * 4], np.uint8) 82 | depth = 4 83 | elif fmtString == "16SC1": 84 | frameData = np.frombuffer(data[ptr : ptr + height * width * 2], np.int16) 85 | depth = 1 86 | elif fmtString == "32FC1": 87 | frameData = np.frombuffer(data[ptr : ptr + height * width * 4], np.float32) 88 | depth = 1 89 | else: 90 | print("image format not supported:" + fmtString) 91 | return None 92 | if depth > 1: 93 | frameData = frameData.reshape(height, width, depth) 94 | else: 95 | frameData = frameData.reshape(height, width) 96 | 97 | framesAll.append(frameData) 98 | numEvents = idx + 1 99 | # Crop arrays to number of events 100 | tsAll = tsAll[:numEvents] 101 | outDict = {"ts": tsAll, "frames": framesAll} 102 | return outDict 103 | -------------------------------------------------------------------------------- /datasets/tools/messageTypes/sensor_msgs_Imu.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Copyright (C) 2019 Event-driven Perception for Robotics 5 | Authors: Sim Bamford 6 | This program is free software: you can redistribute it and/or modify it under 7 | the terms of the GNU General Public License as published by the Free Software 8 | Foundation, either version 3 of the License, or (at your option) any later version. 
9 | This program is distributed in the hope that it will be useful, but WITHOUT ANY 10 | WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A 11 | PARTICULAR PURPOSE. See the GNU General Public License for more details. 12 | You should have received a copy of the GNU General Public License along with 13 | this program. If not, see . 14 | 15 | Intended as part of importRosbag. 16 | 17 | The importTopic function receives a list of messages and returns 18 | a dict with one field for each data field in the message, where the field 19 | will contain an appropriate iterable to contain the interpreted contents of each message. 20 | In some cases, static info is repeated in each message; in which case a field may not contain an iterable. 21 | 22 | This function imports the ros message type defined at: 23 | http://docs.ros.org/api/sensor_msgs/html/msg/Imu.html 24 | """ 25 | 26 | from tqdm import tqdm 27 | import numpy as np 28 | 29 | from .common import unpackRosString, unpackRosTimestamp, unpackRosFloat64Array 30 | 31 | 32 | def importTopic(msgs, **kwargs): 33 | """ 34 | ros message is defined here: 35 | http://docs.ros.org/api/sensor_msgs/html/msg/Imu.html 36 | the results are np arrays of float64 for: 37 | rotQ (4 cols, quaternion) 38 | angV (3 cols) 39 | acc (3 cols) 40 | mag (3 cols) 41 | temp (1 col) - but I'll probably ignore this to start with 42 | """ 43 | sizeOfArray = 1024 44 | tsAll = np.zeros((sizeOfArray), dtype=np.float64) 45 | rotQAll = np.zeros((sizeOfArray, 4), dtype=np.float64) 46 | angVAll = np.zeros((sizeOfArray, 3), dtype=np.float64) 47 | accAll = np.zeros((sizeOfArray, 3), dtype=np.float64) 48 | magAll = np.zeros((sizeOfArray, 3), dtype=np.float64) 49 | # tempAll = np.zeros((sizeOfArray, 1), dtype=np.float64) 50 | for idx, msg in enumerate(tqdm(msgs, position=0, leave=True)): 51 | if sizeOfArray <= idx: 52 | tsAll = np.append(tsAll, np.zeros((sizeOfArray), dtype=np.float64)) 53 | rotQAll = np.concatenate((rotQAll, np.zeros((sizeOfArray, 4), dtype=np.float64))) 54 | angVAll = np.concatenate((angVAll, np.zeros((sizeOfArray, 3), dtype=np.float64))) 55 | accAll = np.concatenate((accAll, np.zeros((sizeOfArray, 3), dtype=np.float64))) 56 | magAll = np.concatenate((magAll, np.zeros((sizeOfArray, 3), dtype=np.float64))) 57 | sizeOfArray *= 2 58 | # TODO: maybe implement kwargs['useRosMsgTimestamps'] 59 | data = msg["data"] 60 | # seq = unpack('=L', data[0:4])[0] 61 | tsAll[idx], ptr = unpackRosTimestamp(data, 4) 62 | frame_id, ptr = unpackRosString(data, ptr) 63 | rotQAll[idx, :], ptr = unpackRosFloat64Array(data, 4, ptr) 64 | ptr += 72 # Skip the covariance matrix 65 | angVAll[idx, :], ptr = unpackRosFloat64Array(data, 3, ptr) 66 | ptr += 72 # Skip the covariance matrix 67 | accAll[idx, :], ptr = unpackRosFloat64Array(data, 3, ptr) 68 | # ptr += 24 69 | # ptr += 72 # Skip the covariance matrix 70 | numEvents = idx + 1 71 | # Crop arrays to number of events 72 | tsAll = tsAll[:numEvents] 73 | rotQAll = rotQAll[:numEvents] 74 | angVAll = angVAll[:numEvents] 75 | accAll = accAll[:numEvents] 76 | magAll = magAll[:numEvents] 77 | outDict = { 78 | "ts": tsAll, 79 | "rotQ": rotQAll, 80 | "angV": angVAll, 81 | "acc": accAll, 82 | "mag": magAll, 83 | } 84 | return outDict 85 | -------------------------------------------------------------------------------- /datasets/tools/messageTypes/sensor_msgs_PointCloud2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Copyright (C) 
2019 Event-driven Perception for Robotics 5 | Authors: Sim Bamford 6 | This program is free software: you can redistribute it and/or modify it under 7 | the terms of the GNU General Public License as published by the Free Software 8 | Foundation, either version 3 of the License, or (at your option) any later version. 9 | This program is distributed in the hope that it will be useful, but WITHOUT ANY 10 | WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A 11 | PARTICULAR PURPOSE. See the GNU General Public License for more details. 12 | You should have received a copy of the GNU General Public License along with 13 | this program. If not, see . 14 | 15 | Intended as part of importRosbag. 16 | 17 | The importTopic function receives a list of messages and returns 18 | a dict with one field for each data field in the message, where the field 19 | will contain an appropriate iterable to contain the interpreted contents of each message. 20 | In some cases, static info is repeated in each message; in which case a field may not contain an iterable. 21 | 22 | This function imports the ros message type defined at: 23 | http://docs.ros.org/api/sensor_msgs/html/msg/PointCloud2.html 24 | 25 | For simplicity, we're currently directly unpacking the format that we are 26 | encountering in the data, which is x,y,z,_,rgb,_,_,_ 27 | each as 32-bit little-endian floats 28 | """ 29 | 30 | from tqdm import tqdm 31 | import numpy as np 32 | 33 | from .common import ( 34 | unpackRosString, 35 | unpackRosUint8, 36 | unpackRosUint32, 37 | unpackRosTimestamp, 38 | ) 39 | 40 | 41 | def importTopic(msgs, **kwargs): 42 | """ 43 | ros message is defined here: 44 | http://docs.ros.org/api/sensor_msgs/html/msg/PointCloud2.html 45 | the results are np arrays for: 46 | ts (float64, one timestamp per point, taken from the message header) 47 | point (3 cols of float32: the x, y, z coordinates of each point) 48 | only x, y, z are unpacked here; the remaining per-point fields 49 | (e.g. rgb and padding) are skipped by advancing the read pointer 50 | by pointStep 51 | """ 52 | # tempAll = np.zeros((sizeOfArray, 1), dtype=np.float64) 53 | # for msg in tqdm(msgs, position=0, leave=True): 54 | tsByMessage = [] 55 | pointsByMessage = [] 56 | for msg in tqdm(msgs): 57 | 58 | data = msg["data"] 59 | ptr = 0 60 | seq, ptr = unpackRosUint32(data, ptr) 61 | ts, ptr = unpackRosTimestamp(data, ptr) 62 | frame_id, ptr = unpackRosString(data, ptr) 63 | height, ptr = unpackRosUint32(data, ptr) 64 | width, ptr = unpackRosUint32(data, ptr) 65 | 66 | if width > 0 and height > 0: 67 | 68 | arraySize, ptr = unpackRosUint32(data, ptr) 69 | for element in range(arraySize): 70 | # Move through the field definitions - we'll ignore these 71 | # until we encounter a file that uses a different set 72 | name, ptr = unpackRosString(data, ptr) 73 | offset, ptr = unpackRosUint32(data, ptr) 74 | datatype, ptr = unpackRosUint8(data, ptr) 75 | count, ptr = unpackRosUint32(data, ptr) 76 | 77 | isBigendian, ptr = unpackRosUint8(data, ptr) 78 | pointStep, ptr = unpackRosUint32(data, ptr) 79 | rowStep, ptr = unpackRosUint32(data, ptr) 80 | 81 | numPoints = width * height 82 | points = np.empty((numPoints, 3), dtype=np.float32) 83 | arraySize, ptr = unpackRosUint32(data, ptr) 84 | # assert arraySize == width*height 85 | for x in range(width): 86 | for y in range(height): 87 | points[x * height + y, :] = np.frombuffer(data[ptr : ptr + 12], dtype=np.float32) 88 | ptr += pointStep 89 | pointsByMessage.append(points) 90 | tsByMessage.append(np.ones((numPoints), dtype=np.float64) * ts) 91 | if not pointsByMessage: # None of the messages contained any points 92 | return None 
93 | points = np.concatenate(pointsByMessage) 94 | ts = np.concatenate(tsByMessage) 95 | 96 | # Crop arrays to number of events 97 | outDict = { 98 | "ts": ts, 99 | "point": points, 100 | } 101 | return outDict 102 | -------------------------------------------------------------------------------- /datasets/tools/messageTypes/tf_tfMessage.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Copyright (C) 2019 Event-driven Perception for Robotics 5 | Authors: Sim Bamford 6 | This program is free software: you can redistribute it and/or modify it under 7 | the terms of the GNU General Public License as published by the Free Software 8 | Foundation, either version 3 of the License, or (at your option) any later version. 9 | This program is distributed in the hope that it will be useful, but WITHOUT ANY 10 | WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A 11 | PARTICULAR PURPOSE. See the GNU General Public License for more details. 12 | You should have received a copy of the GNU General Public License along with 13 | this program. If not, see . 14 | 15 | Intended as part of importRosbag. 16 | 17 | The importTopic function receives a list of messages and returns 18 | a dict with one field for each data field in the message, where the field 19 | will contain an appropriate iterable to contain the interpretted contents of each message. 20 | In some cases, static info is repeated in each message; in which case a field may not contain an iterable. 21 | 22 | This function imports the ros message type defined at: 23 | http://docs.ros.org/api/tf/html/msg/tfMessage.html 24 | 25 | Each message contains an array of transform_stamped messages 26 | 27 | The result is a ts plus a 7-column np array of np.float64, 28 | where the cols are x, y, z, q-w, q-x, q-y, q-z, (i.e. 
quaternion orientation) 29 | 30 | NOTE: QUATERNION ORDER GETS MODIFIED from xyzw to wxyz 31 | 32 | NOTE - this code is similar to geometry_msgs_TransformStamped 33 | """ 34 | 35 | from tqdm import tqdm 36 | import numpy as np 37 | 38 | from .common import ( 39 | unpackRosString, 40 | unpackRosTimestamp, 41 | unpackRosFloat64Array, 42 | unpackRosUint32, 43 | ) 44 | 45 | 46 | def importTopic(msgs, **kwargs): 47 | # if 'Stamped' not in kwargs.get('messageType', 'Stamped'): 48 | # return interpretMsgsAsPose6qAlt(msgs, **kwargs) 49 | sizeOfArray = 1024 50 | tsAll = np.zeros((sizeOfArray), dtype=np.float64) 51 | poseAll = np.zeros((sizeOfArray, 7), dtype=np.float64) 52 | frameIdAll = [] 53 | childFrameIdAll = [] 54 | idx = 0 55 | for msg in tqdm(msgs, position=0, leave=True): 56 | data = msg["data"] 57 | numTfInMsg, ptr = unpackRosUint32(data, 0) 58 | for tfIdx in range(numTfInMsg): 59 | while sizeOfArray <= idx + numTfInMsg: 60 | tsAll = np.append(tsAll, np.zeros((sizeOfArray), dtype=np.float64)) 61 | poseAll = np.concatenate((poseAll, np.zeros((sizeOfArray, 7), dtype=np.float64))) 62 | sizeOfArray *= 2 63 | seq, ptr = unpackRosUint32(data, ptr) 64 | tsAll[idx], ptr = unpackRosTimestamp(data, ptr) 65 | frame_id, ptr = unpackRosString(data, ptr) 66 | frameIdAll.append(frame_id) 67 | child_frame_id, ptr = unpackRosString(data, ptr) 68 | childFrameIdAll.append(child_frame_id) 69 | poseAll[idx, :], ptr = unpackRosFloat64Array(data, 7, ptr) 70 | idx += 1 71 | # Crop arrays to number of events 72 | numEvents = idx 73 | tsAll = tsAll[:numEvents] 74 | poseAll = poseAll[:numEvents] 75 | point = poseAll[:, 0:3] 76 | rotation = poseAll[:, [6, 3, 4, 5]] # Switch quaternion form from xyzw to wxyz 77 | outDict = { 78 | "ts": tsAll, 79 | "point": point, 80 | "rotation": rotation, 81 | "frameId": np.array(frameIdAll, dtype="object"), 82 | "childFrameId": np.array(childFrameIdAll, dtype="object"), 83 | } 84 | return outDict 85 | -------------------------------------------------------------------------------- /datasets/tools/random_crop.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import h5py 4 | import argparse 5 | import numpy as np 6 | from ast import literal_eval 7 | 8 | from h5_packager import H5Packager 9 | 10 | 11 | class Frames: 12 | def __init__(self): 13 | self.ts = [] 14 | self.names = [] 15 | 16 | def __call__(self, name, h5obj): 17 | if hasattr(h5obj, "dtype") and name not in self.names: 18 | self.names += [name] 19 | self.ts += [h5obj.attrs["timestamp"]] 20 | 21 | def get_frames(self, file, t0, t1, crop, res): 22 | idx0 = binary_search_array(self.ts, t0) 23 | idx1 = binary_search_array(self.ts, t1) 24 | 25 | imgs = [] 26 | for i in range(idx0, idx1): 27 | imgs.append(file["images"]["image{:09d}".format(i)][crop[0] : crop[0] + res[0], crop[1] : crop[1] + res[1]]) 28 | 29 | return imgs, idx0, idx1 30 | 31 | 32 | def binary_search_array(array, x, l=None, r=None, side="left"): 33 | """ 34 | Binary search through a sorted array. 
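Returns the index of x if it is present, otherwise the index at which x would have to be inserted to keep the array sorted; l and r bound the recursive search (used here to locate timestamps in the event stream).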
35 | """ 36 | 37 | l = 0 if l is None else l 38 | r = len(array) - 1 if r is None else r 39 | mid = l + (r - l) // 2 40 | 41 | if l > r: 42 | return l if side == "left" else r 43 | 44 | if array[mid] == x: 45 | return mid 46 | elif x < array[mid]: 47 | return binary_search_array(array, x, l=l, r=mid - 1) 48 | else: 49 | return binary_search_array(array, x, l=mid + 1, r=r) 50 | 51 | 52 | def find_ts_index(file, timestamp): 53 | idx = binary_search_array(file["events/ts"], timestamp) 54 | return idx 55 | 56 | 57 | def get_events(file, idx0, idx1): 58 | xs = file["events/xs"][idx0:idx1] 59 | ys = file["events/ys"][idx0:idx1] 60 | ts = file["events/ts"][idx0:idx1] 61 | ps = file["events/ps"][idx0:idx1] * 2 - 1 62 | return xs, ys, ts, ps 63 | 64 | 65 | def random_crop(original_res, output_res): 66 | h = np.random.randint(0, high=original_res[0] - 1 - output_res[0]) 67 | w = np.random.randint(0, high=original_res[1] - 1 - output_res[1]) 68 | return h, w 69 | 70 | 71 | if __name__ == "__main__": 72 | """ 73 | Tool for generating a training dataset out of a set of specified group of H5 datasets. 74 | The original sequences are cropped both in time and space. 75 | The resulting training sequences contain the raw images if available. 76 | """ 77 | parser = argparse.ArgumentParser() 78 | parser.add_argument("path", help="directory of datasets to be used") 79 | parser.add_argument( 80 | "--output_dir", 81 | default="/tmp/training", 82 | help="output directory containing the resulting training sequences", 83 | ) 84 | parser.add_argument( 85 | "--time_length", 86 | default=2, 87 | help="maximum duration, in seconds, of the training sequences", 88 | type=float, 89 | ) 90 | parser.add_argument( 91 | "--original_res", 92 | default="(180, 240)", 93 | help="resolution of the original sequences, HxW", 94 | ) 95 | parser.add_argument( 96 | "--output_res", 97 | default="(128,128)", 98 | help="resolution of the resulting sequences, HxW", 99 | ) 100 | parser.add_argument( 101 | "--with_images", 102 | default=False, 103 | help="whether or not the resulting dataset should contain images", 104 | ) 105 | args = parser.parse_args() 106 | original_res = literal_eval(args.original_res) 107 | output_res = literal_eval(args.output_res) 108 | 109 | print("Data will be extracted in folder: {}".format(args.output_dir)) 110 | if not os.path.exists(args.output_dir): 111 | os.makedirs(args.output_dir) 112 | 113 | path_from = [] 114 | for root, dirs, files in os.walk(args.path): 115 | for file in files: 116 | if file.endswith(".h5"): 117 | path_from.append(os.path.join(root, file)) 118 | 119 | # process dataset 120 | for path in path_from: 121 | hf = h5py.File(path, "r") 122 | print("Processing:", path) 123 | filename = path.split("/")[-1].split(".")[0] 124 | 125 | # load image data 126 | if args.with_images: 127 | frames = Frames() 128 | hf["images"].visititems(frames) 129 | 130 | # get subsequence random crop params 131 | crop = random_crop(original_res, output_res) 132 | 133 | # start reading sequence 134 | t = hf["events/ts"][0] # s 135 | idx0 = find_ts_index(hf, t) 136 | 137 | sequence_id = 0 138 | while True: 139 | idx1 = find_ts_index(hf, t + args.time_length) 140 | 141 | # events in temporal window 142 | xs, ys, ts, ps = get_events(hf, idx0, idx1) 143 | if len(xs) == 0: 144 | break 145 | 146 | # events in spatial window 147 | x_out = np.argwhere(xs < crop[1]) 148 | x_out = np.concatenate((x_out, np.argwhere(xs >= crop[1] + output_res[1])), axis=0) 149 | xs = np.delete(xs, x_out) 150 | ys = np.delete(ys, x_out) 151 | ts = 
np.delete(ts, x_out) 152 | ps = np.delete(ps, x_out) 153 | 154 | y_out = np.argwhere(ys < crop[0]) 155 | y_out = np.concatenate((y_out, np.argwhere(ys >= crop[0] + output_res[0])), axis=0) 156 | xs = np.delete(xs, y_out) 157 | ys = np.delete(ys, y_out) 158 | ts = np.delete(ts, y_out) 159 | ps = np.delete(ps, y_out) 160 | ps[ps < 0] = 0 161 | 162 | xs -= crop[1] 163 | ys -= crop[0] 164 | 165 | # images in temporal window 166 | cropped_frames = [] 167 | if args.with_images: 168 | cropped_frames, frame_idx0, frame_idx1 = frames.get_frames( 169 | hf, t, t + args.time_length, crop, output_res 170 | ) 171 | 172 | # store subsequence 173 | ep = H5Packager(args.output_dir + filename + "_" + str(sequence_id) + ".h5") 174 | ep.package_events(xs.tolist(), ys.tolist(), ts.tolist(), ps.tolist()) 175 | 176 | img_cnt = 0 177 | if args.with_images: 178 | for i in range(frame_idx0, frame_idx1): 179 | ep.package_image(cropped_frames[img_cnt], frames.ts[i], img_cnt) 180 | img_cnt += 1 181 | 182 | ep.add_metadata( 183 | len(ps[ps > 0]), 184 | len(ps[ps < 0]), 185 | ts[-1] - ts[0], 186 | ts[0], 187 | ts[-1], 188 | img_cnt, 189 | output_res, 190 | ) 191 | 192 | t += args.time_length 193 | idx0 = idx1 194 | sequence_id += 1 195 | 196 | hf.close() 197 | print("") 198 | -------------------------------------------------------------------------------- /datasets/tools/rosbag_to_h5.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from Event-driven Perception for Robotics https://github.com/event-driven-robotics/importRosbag 3 | """ 4 | 5 | from struct import unpack 6 | from struct import error as structError 7 | from tqdm import tqdm 8 | 9 | import glob 10 | import argparse 11 | import os 12 | import h5py 13 | import numpy as np 14 | from h5_packager import H5Packager 15 | 16 | from messageTypes.common import unpack_header 17 | from messageTypes.dvs_msgs_EventArray import ( 18 | importTopic as import_dvs_msgs_EventArray, 19 | ) 20 | from messageTypes.esim_msgs_OpticFlow import ( 21 | importTopic as import_esim_msgs_OpticFlow, 22 | ) 23 | from messageTypes.geometry_msgs_PoseStamped import ( 24 | importTopic as import_geometry_msgs_PoseStamped, 25 | ) 26 | from messageTypes.geometry_msgs_Transform import ( 27 | importTopic as import_geometry_msgs_Transform, 28 | ) 29 | from messageTypes.geometry_msgs_TransformStamped import ( 30 | importTopic as import_geometry_msgs_TransformStamped, 31 | ) 32 | from messageTypes.geometry_msgs_TwistStamped import ( 33 | importTopic as import_geometry_msgs_TwistStamped, 34 | ) 35 | from messageTypes.sensor_msgs_CameraInfo import ( 36 | importTopic as import_sensor_msgs_CameraInfo, 37 | ) 38 | from messageTypes.sensor_msgs_Image import ( 39 | importTopic as import_sensor_msgs_Image, 40 | ) 41 | from messageTypes.sensor_msgs_Imu import importTopic as import_sensor_msgs_Imu 42 | from messageTypes.sensor_msgs_PointCloud2 import ( 43 | importTopic as import_sensor_msgs_PointCloud2, 44 | ) 45 | from messageTypes.tf_tfMessage import importTopic as import_tf_tfMessage 46 | 47 | 48 | def import_topic(topic, **kwargs): 49 | msgs = topic["msgs"] 50 | topic_type = topic["type"].replace("/", "_") 51 | if topic_type == "dvs_msgs_EventArray": 52 | topic_dict = import_dvs_msgs_EventArray(msgs, **kwargs) 53 | elif topic_type == "esim_msgs_OpticFlow": 54 | topic_dict = import_esim_msgs_OpticFlow(msgs, **kwargs) 55 | elif topic_type == "geometry_msgs_PoseStamped": 56 | topic_dict = import_geometry_msgs_PoseStamped(msgs, **kwargs) 57 | elif topic_type 
== "geometry_msgs_Transform": 58 | topic_dict = import_geometry_msgs_Transform(msgs, **kwargs) 59 | elif topic_type == "geometry_msgs_TransformStamped": 60 | topic_dict = import_geometry_msgs_TransformStamped(msgs, **kwargs) 61 | elif topic_type == "geometry_msgs_TwistStamped": 62 | topic_dict = import_geometry_msgs_TwistStamped(msgs, **kwargs) 63 | elif topic_type == "sensor_msgs_CameraInfo": 64 | topic_dict = import_sensor_msgs_CameraInfo(msgs, **kwargs) 65 | elif topic_type == "sensor_msgs_Image": 66 | topic_dict = import_sensor_msgs_Image(msgs, **kwargs) 67 | elif topic_type == "sensor_msgs_Imu": 68 | topic_dict = import_sensor_msgs_Imu(msgs, **kwargs) 69 | elif topic_type == "sensor_msgs_PointCloud2": 70 | topic_dict = import_sensor_msgs_PointCloud2(msgs, **kwargs) 71 | elif topic_type == "tf_tfMessage": 72 | topic_dict = import_tf_tfMessage(msgs, **kwargs) 73 | else: 74 | return None 75 | if topic_dict: 76 | topic_dict["rosbagType"] = topic["type"] 77 | return topic_dict 78 | 79 | 80 | def read_file(filename): 81 | print("Attempting to import " + filename + " as a rosbag 2.0 file.") 82 | with open(filename, "rb") as file: 83 | # File format string 84 | file_format = file.readline().decode("utf-8") 85 | print("ROSBAG file format: " + file_format) 86 | if file_format != "#ROSBAG V2.0\n": 87 | print("This file format might not be supported") 88 | eof = False 89 | conns = [] 90 | chunks = [] 91 | while not eof: 92 | # Read a record header 93 | try: 94 | header_len = unpack("=l", file.read(4))[0] 95 | except structError: 96 | if len(file.read(1)) == 0: # Distinguish EOF from other struct errors 97 | # a struct error could also occur if the data is downloaded by one os and read by another. 98 | eof = True 99 | continue 100 | # unpack the header into fields 101 | header_bytes = file.read(header_len) 102 | fields = unpack_header(header_len, header_bytes) 103 | # Read the record data 104 | data_len = unpack("=l", file.read(4))[0] 105 | data = file.read(data_len) 106 | # The op code tells us what to do with the record 107 | op = unpack("=b", fields["op"])[0] 108 | fields["op"] = op 109 | if op == 2: 110 | # It's a message 111 | # AFAIK these are not found unpacked in the file 112 | # fields['data'] = data 113 | # msgs.append(fields) 114 | pass 115 | elif op == 3: 116 | # It's a bag header - use this to do progress bar for the read 117 | chunk_count = unpack("=l", fields["chunk_count"])[0] 118 | pbar = tqdm(total=chunk_count, position=0, leave=True) 119 | elif op == 4: 120 | # It's an index - this is used to index the previous chunk 121 | conn = unpack("=l", fields["conn"])[0] 122 | count = unpack("=l", fields["count"])[0] 123 | for idx in range(count): 124 | time, offset = unpack("=ql", data[idx * 12 : idx * 12 + 12]) 125 | chunks[-1]["ids"].append((conn, time, offset)) 126 | elif op == 5: 127 | # It's a chunk 128 | fields["data"] = data 129 | fields["ids"] = [] 130 | chunks.append(fields) 131 | pbar.update(len(chunks)) 132 | elif op == 6: 133 | # It's a chunk-info - seems to be redundant 134 | pass 135 | elif op == 7: 136 | # It's a conn 137 | # interpret data as a string containing the connection header 138 | conn_fields = unpack_header(data_len, data) 139 | conn_fields.update(fields) 140 | conn_fields["conn"] = unpack("=l", conn_fields["conn"])[0] 141 | conn_fields["topic"] = conn_fields["topic"].decode("utf-8") 142 | conn_fields["type"] = conn_fields["type"].decode("utf-8").replace("/", "_") 143 | conns.append(conn_fields) 144 | return conns, chunks 145 | 146 | 147 | def 
break_chunks_into_msgs(chunks): 148 | msgs = [] 149 | for chunk in tqdm(chunks, position=0, leave=True): 150 | for idx in chunk["ids"]: 151 | ptr = idx[2] 152 | header_len = unpack("=l", chunk["data"][ptr : ptr + 4])[0] 153 | ptr += 4 154 | # unpack the header into fields 155 | header_bytes = chunk["data"][ptr : ptr + header_len] 156 | ptr += header_len 157 | fields = unpack_header(header_len, header_bytes) 158 | # Read the record data 159 | data_len = unpack("=l", chunk["data"][ptr : ptr + 4])[0] 160 | ptr += 4 161 | fields["data"] = chunk["data"][ptr : ptr + data_len] 162 | fields["conn"] = unpack("=l", fields["conn"])[0] 163 | msgs.append(fields) 164 | return msgs 165 | 166 | 167 | def rekey_conns_by_topic(conn_dict): 168 | topics = {} 169 | for conn in conn_dict: 170 | topics[conn_dict[conn]["topic"]] = conn_dict[conn] 171 | return topics 172 | 173 | 174 | def import_rosbag(filename, **kwargs): 175 | print("Importing file: ", filename) 176 | conns, chunks = read_file(filename) 177 | # Restructure conns as a dictionary keyed by conn number 178 | conn_dict = {} 179 | for conn in conns: 180 | conn_dict[conn["conn"]] = conn 181 | conn["msgs"] = [] 182 | if kwargs.get("listTopics", False): 183 | topics = rekey_conns_by_topic(conn_dict) 184 | print("Topics in the file are (with types):") 185 | for topicKey, topic in topics.items(): 186 | del topic["conn"] 187 | del topic["md5sum"] 188 | del topic["msgs"] 189 | del topic["op"] 190 | del topic["topic"] 191 | topic["message_definition"] = topic["message_definition"].decode("utf-8") 192 | print(" " + topicKey + " --- " + topic["type"]) 193 | return topics 194 | msgs = break_chunks_into_msgs(chunks) 195 | for msg in msgs: 196 | conn_dict[msg["conn"]]["msgs"].append(msg) 197 | topics = rekey_conns_by_topic(conn_dict) 198 | 199 | imported_topics = {} 200 | import_topics = kwargs.get("import_topics") 201 | import_types = kwargs.get("import_types") 202 | if import_topics is not None: 203 | for topic_to_import in import_topics: 204 | for topic_in_file in topics.keys(): 205 | if topic_in_file == topic_to_import: 206 | imported_topic = import_topic(topics[topic_in_file], **kwargs) 207 | if imported_topic is not None: 208 | imported_topics[topic_to_import] = imported_topic 209 | del topics[topic_in_file] 210 | elif import_types is not None: 211 | for type_to_import in import_types: 212 | type_to_import = type_to_import.replace("/", "_") 213 | for topic_in_file in list(topics.keys()): 214 | if topics[topic_in_file]["type"].replace("/", "_") == type_to_import: 215 | imported_topic = import_topic(topics[topic_in_file], **kwargs) 216 | if imported_topic is not None: 217 | imported_topics[topic_in_file] = imported_topic 218 | del topics[topic_in_file] 219 | else: # import everything 220 | for topic_in_file in list(topics.keys()): 221 | imported_topic = import_topic(topics[topic_in_file], **kwargs) 222 | if imported_topic is not None: 223 | imported_topics[topic_in_file] = imported_topic 224 | del topics[topic_in_file] 225 | 226 | print() 227 | if imported_topics: 228 | print("Topics imported are:") 229 | for topic in imported_topics.keys(): 230 | print(topic + " --- " + imported_topics[topic]["rosbagType"]) 231 | # del imported_topics[topic]['rosbagType'] 232 | print() 233 | 234 | if topics: 235 | print("Topics not imported are:") 236 | for topic in topics.keys(): 237 | print(topic + " --- " + topics[topic]["type"]) 238 | print() 239 | 240 | return imported_topics 241 | 242 | 243 | def extract_rosbag( 244 | rosbag_path, 245 | output_path, 246 | event_topic, 
247 | image_topic=None, 248 | start_time=None, 249 | end_time=None, 250 | packager=H5Packager, 251 | ): 252 | ep = packager(output_path) 253 | t0 = -1 254 | sensor_size = None 255 | if not os.path.exists(rosbag_path): 256 | print("{} does not exist!".format(rosbag_path)) 257 | return 258 | 259 | # import rosbag 260 | bag = import_rosbag(rosbag_path) 261 | 262 | max_events = 10000000 263 | xs, ys, ts, ps = [], [], [], [] 264 | num_pos, num_neg, last_ts, img_cnt = 0, 0, 0, 0 265 | 266 | # event topic 267 | print("Processing events...") 268 | for i in range(0, len(bag[event_topic]["ts"])): 269 | timestamp = bag[event_topic]["ts"][i] 270 | if i == 0: 271 | t0 = timestamp 272 | last_ts = timestamp 273 | 274 | xs.append(bag[event_topic]["x"][i]) 275 | ys.append(bag[event_topic]["y"][i]) 276 | ts.append(timestamp) 277 | ps.append(1 if bag[event_topic]["pol"][i] else 0) 278 | 279 | if len(xs) == max_events: 280 | ep.package_events(xs, ys, ts, ps) 281 | del xs[:] 282 | del ys[:] 283 | del ts[:] 284 | del ps[:] 285 | print(timestamp - t0) 286 | 287 | if bag[event_topic]["pol"][i]: 288 | num_pos += 1 289 | else: 290 | num_neg += 1 291 | last_ts = timestamp 292 | 293 | if sensor_size is None: 294 | sensor_size = [max(xs) + 1, max(ys) + 1] 295 | print("Sensor size inferred from events as {}".format(sensor_size)) 296 | 297 | ep.package_events(xs, ys, ts, ps) 298 | del xs[:] 299 | del ys[:] 300 | del ts[:] 301 | del ps[:] 302 | 303 | # image topic 304 | if image_topic is not None: 305 | print("Processing images...") 306 | for i in range(0, len(bag[image_topic]["ts"])): 307 | timestamp = bag[image_topic]["ts"][i] 308 | t0 = timestamp if timestamp < t0 else t0 309 | last_ts = timestamp if timestamp > last_ts else last_ts 310 | image = bag[image_topic]["frames"][i] 311 | ep.package_image(image, timestamp, img_cnt) 312 | sensor_size = image.shape 313 | img_cnt += 1 314 | 315 | ep.add_metadata( 316 | num_pos, 317 | num_neg, 318 | last_ts - t0, 319 | t0, 320 | last_ts, 321 | img_cnt, 322 | sensor_size, 323 | ) 324 | 325 | 326 | def extract_rosbags(rosbag_paths, output_dir, event_topic, image_topic): 327 | for path in rosbag_paths: 328 | bagname = os.path.splitext(os.path.basename(path))[0] 329 | out_path = os.path.join(output_dir, "{}.h5".format(bagname)) 330 | print("Extracting {} to {}".format(path, out_path)) 331 | extract_rosbag(path, out_path, event_topic, image_topic=image_topic) 332 | 333 | 334 | if __name__ == "__main__": 335 | """ 336 | Tool for converting rosbag events to an efficient HDF5 format that can be speedily 337 | accessed by python code. 
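Example invocation (paths and topic names are illustrative): python rosbag_to_h5.py /data/bags --output_dir /data/h5 --event_topic /dvs/events --image_topic /dvs/image_raw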
338 | """ 339 | parser = argparse.ArgumentParser() 340 | parser.add_argument("path", help="ROS bag file to extract or directory containing bags") 341 | parser.add_argument( 342 | "--output_dir", 343 | default="/tmp/extracted_data", 344 | help="Folder where to extract the data", 345 | ) 346 | parser.add_argument("--event_topic", default="/dvs/events", help="Event topic") 347 | parser.add_argument( 348 | "--image_topic", 349 | default=None, 350 | help="Image topic (if left empty, no images will be collected)", 351 | ) 352 | args = parser.parse_args() 353 | 354 | print("Data will be extracted in folder: {}".format(args.output_dir)) 355 | if not os.path.exists(args.output_dir): 356 | os.makedirs(args.output_dir) 357 | if os.path.isdir(args.path): 358 | rosbag_paths = sorted(glob.glob(os.path.join(args.path, "*.bag"))) 359 | else: 360 | rosbag_paths = [args.path] 361 | extract_rosbags(rosbag_paths, args.output_dir, args.event_topic, args.image_topic) 362 | -------------------------------------------------------------------------------- /eval_flow.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | import torch 5 | from torch.optim import * 6 | 7 | from configs.parser import YAMLParser 8 | from dataloader.h5 import H5Loader 9 | from models.model import FireFlowNet, EVFlowNet 10 | from utils.iwe import deblur_events, compute_pol_iwe 11 | from utils.utils import load_model 12 | from utils.visualization import Visualization 13 | 14 | 15 | def test(args, config_parser): 16 | config = config_parser.merge_configs(args.trained_model) 17 | config["loader"]["batch_size"] = 1 18 | config["vis"]["bars"] = True 19 | 20 | # store validation settings 21 | eval_id = config_parser.log_eval_config(config) 22 | 23 | # initialize settings 24 | device = config_parser.device 25 | kwargs = config_parser.loader_kwargs 26 | 27 | # visualization tool 28 | if config["vis"]["enabled"] or config["vis"]["store"]: 29 | vis = Visualization(config, eval_id=eval_id) 30 | 31 | # optical flow settings 32 | num_bins = config["data"]["num_bins"] 33 | flow_scaling = config["model_flow"]["flow_scaling"] 34 | model = eval(config["model_flow"]["name"])(config["model_flow"], num_bins).to(device) 35 | model = load_model(config["trained_model"], model, device) 36 | model.eval() 37 | 38 | # data loader 39 | data = H5Loader(config, num_bins) 40 | dataloader = torch.utils.data.DataLoader( 41 | data, 42 | drop_last=True, 43 | batch_size=config["loader"]["batch_size"], 44 | collate_fn=data.custom_collate, 45 | worker_init_fn=config_parser.worker_init_fn, 46 | **kwargs 47 | ) 48 | 49 | # inference loop 50 | end_test = False 51 | with torch.no_grad(): 52 | while True: 53 | for inputs in dataloader: 54 | 55 | # finish inference loop 56 | if data.seq_num >= len(data.files): 57 | end_test = True 58 | break 59 | 60 | # forward pass 61 | x = model(inputs["inp_voxel"].to(device), inputs["inp_cnt"].to(device)) 62 | 63 | # image of warped events 64 | iwe = compute_pol_iwe( 65 | x["flow"][-1], 66 | inputs["inp_list"].to(device), 67 | config["loader"]["resolution"], 68 | inputs["inp_pol_mask"][:, :, 0:1].to(device), 69 | inputs["inp_pol_mask"][:, :, 1:2].to(device), 70 | flow_scaling=flow_scaling, 71 | round_idx=False, 72 | ) 73 | 74 | # visualize 75 | for bar in data.open_files_bar: 76 | bar.next() 77 | if config["vis"]["enabled"]: 78 | vis.update(inputs, x["flow"][-1], iwe, None) 79 | if config["vis"]["store"]: 80 | sequence = data.files[data.batch_idx[0] % 
len(data.files)].split("/")[-1].split(".")[0] 81 | vis.store(inputs, x["flow"][-1], iwe, None, sequence, ts=data.last_proc_timestamp) 82 | 83 | if end_test: 84 | break 85 | 86 | for bar in data.open_files_bar: 87 | bar.finish() 88 | 89 | 90 | if __name__ == "__main__": 91 | parser = argparse.ArgumentParser() 92 | parser.add_argument("trained_model", help="model to be evaluated") 93 | parser.add_argument( 94 | "--config", 95 | default="configs/eval_flow.yml", 96 | help="config file, overwrites model settings", 97 | ) 98 | args = parser.parse_args() 99 | 100 | # launch testing 101 | test(args, YAMLParser(args.config)) 102 | -------------------------------------------------------------------------------- /eval_reconstruction.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import numpy as np 5 | 6 | import torch 7 | from torch.utils.data import DataLoader 8 | 9 | from configs.parser import YAMLParser 10 | from dataloader.h5 import H5Loader 11 | from models.model import FireFlowNet, EVFlowNet 12 | from models.model import FireNet, E2VID 13 | from utils.utils import load_model 14 | from utils.visualization import Visualization 15 | 16 | 17 | def test(args, config_parser): 18 | config = config_parser.merge_configs(args.trained_model) 19 | config["loader"]["batch_size"] = 1 20 | config["vis"]["bars"] = True 21 | 22 | # store validation settings 23 | eval_id = config_parser.log_eval_config(config) 24 | 25 | # initialize settings 26 | device = config_parser.device 27 | kwargs = config_parser.loader_kwargs 28 | num_bins = config["data"]["num_bins"] 29 | 30 | # visualization tool 31 | if config["vis"]["enabled"] or config["vis"]["store"]: 32 | vis = Visualization(config, eval_id=eval_id) 33 | 34 | # data loader 35 | data = H5Loader(config, num_bins) 36 | dataloader = torch.utils.data.DataLoader( 37 | data, 38 | drop_last=True, 39 | batch_size=config["loader"]["batch_size"], 40 | collate_fn=data.custom_collate, 41 | worker_init_fn=config_parser.worker_init_fn, 42 | **kwargs 43 | ) 44 | 45 | # reconstruction settings 46 | model_reconstruction = eval(config["model_reconstruction"]["name"])(config["model_reconstruction"], num_bins).to( 47 | device 48 | ) 49 | model_reconstruction = load_model(config["trained_model"], model_reconstruction, device) 50 | model_reconstruction.eval() 51 | 52 | # optical flow settings 53 | flow_eval = config["model_flow"]["eval"] 54 | if flow_eval: 55 | model_flow = eval(config["model_flow"]["name"])(config["model_flow"], num_bins).to(device) 56 | model_flow = load_model(config["trained_model"], model_flow, device) 57 | model_flow.eval() 58 | 59 | # inference loop 60 | x_flow = {} 61 | x_flow["flow"] = [None] 62 | end_test = False 63 | with torch.no_grad(): 64 | while True: 65 | for inputs in dataloader: 66 | 67 | # reset states 68 | if data.new_seq: 69 | data.new_seq = False 70 | model_reconstruction.reset_states() 71 | 72 | # finish inference loop 73 | if data.seq_num >= len(data.files): 74 | end_test = True 75 | break 76 | 77 | # flow - forward pass 78 | if flow_eval: 79 | x_flow = model_flow(inputs["inp_voxel"].to(device), inputs["inp_cnt"].to(device)) 80 | 81 | # reconstruction - forward pass 82 | x_reconstruction = model_reconstruction(inputs["inp_voxel"].to(device)) 83 | 84 | # visualize 85 | if config["vis"]["bars"]: 86 | for bar in data.open_files_bar: 87 | bar.next() 88 | if config["vis"]["enabled"]: 89 | vis.update(inputs, x_flow["flow"][-1], None, x_reconstruction["image"]) 90 
| if config["vis"]["store"]: 91 | sequence = data.files[data.batch_idx[0] % len(data.files)].split("/")[-1].split(".")[0] 92 | vis.store( 93 | inputs, 94 | x_flow["flow"][-1], 95 | None, 96 | x_reconstruction["image"], 97 | sequence, 98 | ts=data.last_proc_timestamp, 99 | ) 100 | 101 | if end_test: 102 | break 103 | 104 | if config["vis"]["bars"]: 105 | for bar in data.open_files_bar: 106 | bar.finish() 107 | 108 | 109 | if __name__ == "__main__": 110 | parser = argparse.ArgumentParser() 111 | parser.add_argument("trained_model", help="model to be evaluated") 112 | parser.add_argument( 113 | "--config", 114 | default="configs/eval_reconstruction.yml", 115 | help="config file, overwrites model settings", 116 | ) 117 | args = parser.parse_args() 118 | 119 | # launch testing 120 | test(args, YAMLParser(args.config)) 121 | -------------------------------------------------------------------------------- /loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tudelft/ssl_e2vid/ef9d61b476576f60f92029cadde4286c6cf0f734/loss/__init__.py -------------------------------------------------------------------------------- /loss/flow.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import torch 5 | import numpy as np 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | parent_dir_name = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 10 | sys.path.append(parent_dir_name) 11 | 12 | from utils.iwe import purge_unfeasible, get_interpolation, interpolate 13 | 14 | 15 | class EventWarping(nn.Module): 16 | """ 17 | Contrast maximization loss, as described in Section 3.2 of the paper 'Unsupervised Event-based Learning 18 | of Optical Flow, Depth, and Egomotion', Zhu et al., CVPR'19. 19 | The contrast maximization loss is the minimization of the per-pixel and per-polarity image of averaged 20 | timestamps of the input events after they have been compensated for their motion using the estimated 21 | optical flow. This minimization is performed in a forward and in a backward fashion to prevent scaling 22 | issues during backpropagation. 
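On top of the contrast term, the loss adds a Charbonnier smoothness prior on the estimated optical flow, weighted by the flow_regul_weight parameter of the config.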
23 | """ 24 | 25 | def __init__(self, config, device): 26 | super(EventWarping, self).__init__() 27 | self.res = config["loader"]["resolution"] 28 | self.flow_scaling = max(config["loader"]["resolution"]) 29 | self.weight = config["loss"]["flow_regul_weight"] 30 | self.device = device 31 | 32 | def forward(self, flow_list, event_list, pol_mask): 33 | """ 34 | :param flow_list: [[batch_size x 2 x H x W]] list of optical flow maps 35 | :param event_list: [batch_size x N x 4] input events (y, x, ts, p) 36 | :param pol_mask: [batch_size x N x 2] per-polarity binary mask of the input events 37 | """ 38 | 39 | # split input 40 | pol_mask = torch.cat([pol_mask for i in range(4)], dim=1) 41 | ts_list = torch.cat([event_list[:, :, 0:1] for i in range(4)], dim=1) 42 | 43 | # flow vector per input event 44 | flow_idx = event_list[:, :, 1:3].clone() 45 | flow_idx[:, :, 0] *= self.res[1] # torch.view is row-major 46 | flow_idx = torch.sum(flow_idx, dim=2) 47 | 48 | loss = 0 49 | for flow in flow_list: 50 | 51 | # get flow for every event in the list 52 | flow = flow.view(flow.shape[0], 2, -1) 53 | event_flowy = torch.gather(flow[:, 1, :], 1, flow_idx.long()) # vertical component 54 | event_flowx = torch.gather(flow[:, 0, :], 1, flow_idx.long()) # horizontal component 55 | event_flowy = event_flowy.view(event_flowy.shape[0], event_flowy.shape[1], 1) 56 | event_flowx = event_flowx.view(event_flowx.shape[0], event_flowx.shape[1], 1) 57 | event_flow = torch.cat([event_flowy, event_flowx], dim=2) 58 | 59 | # interpolate forward 60 | tref = 1 61 | fw_idx, fw_weights = get_interpolation(event_list, event_flow, tref, self.res, self.flow_scaling) 62 | 63 | # per-polarity image of (forward) warped events 64 | fw_iwe_pos = interpolate(fw_idx.long(), fw_weights, self.res, polarity_mask=pol_mask[:, :, 0:1]) 65 | fw_iwe_neg = interpolate(fw_idx.long(), fw_weights, self.res, polarity_mask=pol_mask[:, :, 1:2]) 66 | 67 | # image of (forward) warped averaged timestamps 68 | fw_iwe_pos_ts = interpolate( 69 | fw_idx.long(), fw_weights * ts_list, self.res, polarity_mask=pol_mask[:, :, 0:1] 70 | ) 71 | fw_iwe_neg_ts = interpolate( 72 | fw_idx.long(), fw_weights * ts_list, self.res, polarity_mask=pol_mask[:, :, 1:2] 73 | ) 74 | fw_iwe_pos_ts /= fw_iwe_pos + 1e-9 75 | fw_iwe_neg_ts /= fw_iwe_neg + 1e-9 76 | 77 | # interpolate backward 78 | tref = 0 79 | bw_idx, bw_weights = get_interpolation(event_list, event_flow, tref, self.res, self.flow_scaling) 80 | 81 | # per-polarity image of (backward) warped events 82 | bw_iwe_pos = interpolate(bw_idx.long(), bw_weights, self.res, polarity_mask=pol_mask[:, :, 0:1]) 83 | bw_iwe_neg = interpolate(bw_idx.long(), bw_weights, self.res, polarity_mask=pol_mask[:, :, 1:2]) 84 | 85 | # image of (backward) warped averaged timestamps 86 | bw_iwe_pos_ts = interpolate( 87 | bw_idx.long(), bw_weights * (1 - ts_list), self.res, polarity_mask=pol_mask[:, :, 0:1] 88 | ) 89 | bw_iwe_neg_ts = interpolate( 90 | bw_idx.long(), bw_weights * (1 - ts_list), self.res, polarity_mask=pol_mask[:, :, 1:2] 91 | ) 92 | bw_iwe_pos_ts /= bw_iwe_pos + 1e-9 93 | bw_iwe_neg_ts /= bw_iwe_neg + 1e-9 94 | 95 | # flow smoothing 96 | flow = flow.view(flow.shape[0], 2, self.res[0], self.res[1]) 97 | flow_dx = flow[:, :, :-1, :] - flow[:, :, 1:, :] 98 | flow_dy = flow[:, :, :, :-1] - flow[:, :, :, 1:] 99 | flow_dx = torch.sqrt(flow_dx ** 2 + 1e-6) # charbonnier 100 | flow_dy = torch.sqrt(flow_dy ** 2 + 1e-6) # charbonnier 101 | 102 | loss += ( 103 | torch.sum(fw_iwe_pos_ts ** 2) 104 | + torch.sum(fw_iwe_neg_ts ** 2) 105 | + 
torch.sum(bw_iwe_pos_ts ** 2) 106 | + torch.sum(bw_iwe_neg_ts ** 2) 107 | + self.weight * (flow_dx.sum() + flow_dy.sum()) 108 | ) 109 | 110 | return loss 111 | 112 | 113 | class AveragedIWE(nn.Module): 114 | """ 115 | Returns an image of the per-pixel and per-polarity average number of warped events given 116 | an optical flow map. 117 | """ 118 | 119 | def __init__(self, config, device): 120 | super(AveragedIWE, self).__init__() 121 | self.res = config["loader"]["resolution"] 122 | self.flow_scaling = max(config["loader"]["resolution"]) 123 | self.batch_size = config["loader"]["batch_size"] 124 | self.device = device 125 | 126 | def forward(self, flow, event_list, pol_mask): 127 | """ 128 | :param flow: [batch_size x 2 x H x W] optical flow maps 129 | :param event_list: [batch_size x N x 4] input events (y, x, ts, p) 130 | :param pol_mask: [batch_size x N x 2] per-polarity binary mask of the input events 131 | """ 132 | 133 | # original location of events 134 | idx = event_list[:, :, 1:3].clone() 135 | idx[:, :, 0] *= self.res[1] # torch.view is row-major 136 | idx = torch.sum(idx, dim=2, keepdim=True) 137 | 138 | # flow vector per input event 139 | flow_idx = event_list[:, :, 1:3].clone() 140 | flow_idx[:, :, 0] *= self.res[1] # torch.view is row-major 141 | flow_idx = torch.sum(flow_idx, dim=2) 142 | 143 | # get flow for every event in the list 144 | flow = flow.view(flow.shape[0], 2, -1) 145 | event_flowy = torch.gather(flow[:, 1, :], 1, flow_idx.long()) # vertical component 146 | event_flowx = torch.gather(flow[:, 0, :], 1, flow_idx.long()) # horizontal component 147 | event_flowy = event_flowy.view(event_flowy.shape[0], event_flowy.shape[1], 1) 148 | event_flowx = event_flowx.view(event_flowx.shape[0], event_flowx.shape[1], 1) 149 | event_flow = torch.cat([event_flowy, event_flowx], dim=2) 150 | 151 | # interpolate forward 152 | fw_idx, fw_weights = get_interpolation(event_list, event_flow, 1, self.res, self.flow_scaling, round_idx=True) 153 | 154 | # per-polarity image of (forward) warped events 155 | fw_iwe_pos = interpolate(fw_idx.long(), fw_weights, self.res, polarity_mask=pol_mask[:, :, 0:1]) 156 | fw_iwe_neg = interpolate(fw_idx.long(), fw_weights, self.res, polarity_mask=pol_mask[:, :, 1:2]) 157 | if fw_idx.shape[1] == 0: 158 | return torch.cat([fw_iwe_pos, fw_iwe_neg], dim=1) 159 | 160 | # make sure unfeasible mappings are not considered 161 | pol_list = event_list[:, :, 3:4].clone() 162 | pol_list[pol_list < 1] = 0 # negative polarity set to 0 163 | pol_list[fw_weights == 0] = 2 # fake polarity to detect unfeasible mappings 164 | 165 | # encode unique ID for pixel location mapping (idx <-> fw_idx = m_idx) 166 | m_idx = torch.cat([idx.long(), fw_idx.long()], dim=2) 167 | m_idx[:, :, 0] *= self.res[0] * self.res[1] 168 | m_idx = torch.sum(m_idx, dim=2, keepdim=True) 169 | 170 | # encode unique ID for per-polarity pixel location mapping (pol_list <-> m_idx = pm_idx) 171 | pm_idx = torch.cat([pol_list.long(), m_idx.long()], dim=2) 172 | pm_idx[:, :, 0] *= (self.res[0] * self.res[1]) ** 2 173 | pm_idx = torch.sum(pm_idx, dim=2, keepdim=True) 174 | 175 | # number of different pixels locations from where pixels originate during warping 176 | # this needs to be done per batch as the number of unique indices differs 177 | fw_iwe_pos_contrib = torch.zeros((flow.shape[0], self.res[0] * self.res[1], 1)).to(self.device) 178 | fw_iwe_neg_contrib = torch.zeros((flow.shape[0], self.res[0] * self.res[1], 1)).to(self.device) 179 | for b in range(0, self.batch_size): 180 | 181 | # per-polarity 
unique mapping combinations 182 | unique_pm_idx = torch.unique(pm_idx[b, :, :], dim=0) 183 | unique_pm_idx = torch.cat( 184 | [ 185 | unique_pm_idx // ((self.res[0] * self.res[1]) ** 2), 186 | unique_pm_idx % ((self.res[0] * self.res[1]) ** 2), 187 | ], 188 | dim=1, 189 | ) # (pol_idx, mapping_idx) 190 | unique_pm_idx = torch.cat( 191 | [unique_pm_idx[:, 0:1], unique_pm_idx[:, 1:2] % (self.res[0] * self.res[1])], dim=1 192 | ) # (pol_idx, fw_idx) 193 | unique_pm_idx[:, 0] *= self.res[0] * self.res[1] 194 | unique_pm_idx = torch.sum(unique_pm_idx, dim=1, keepdim=True) 195 | 196 | # per-polarity unique receiving pixels 197 | unique_pfw_idx, contrib_pfw = torch.unique(unique_pm_idx[:, 0], dim=0, return_counts=True) 198 | unique_pfw_idx = unique_pfw_idx.view((unique_pfw_idx.shape[0], 1)) 199 | contrib_pfw = contrib_pfw.view((contrib_pfw.shape[0], 1)) 200 | unique_pfw_idx = torch.cat( 201 | [unique_pfw_idx // (self.res[0] * self.res[1]), unique_pfw_idx % (self.res[0] * self.res[1])], 202 | dim=1, 203 | ) # (polarity mask, fw_idx) 204 | 205 | # positive scatter pixel contribution 206 | mask_pos = unique_pfw_idx[:, 0:1].clone() 207 | mask_pos[mask_pos == 2] = 0 # remove unfeasible mappings 208 | b_fw_iwe_pos_contrib = torch.zeros((self.res[0] * self.res[1], 1)).to(self.device) 209 | b_fw_iwe_pos_contrib = b_fw_iwe_pos_contrib.scatter_add_( 210 | 0, unique_pfw_idx[:, 1:2], mask_pos.float() * contrib_pfw.float() 211 | ) 212 | 213 | # negative scatter pixel contribution 214 | mask_neg = unique_pfw_idx[:, 0:1].clone() 215 | mask_neg[mask_neg == 2] = 1 # remove unfeasible mappings 216 | mask_neg = 1 - mask_neg # invert polarities 217 | b_fw_iwe_neg_contrib = torch.zeros((self.res[0] * self.res[1], 1)).to(self.device) 218 | b_fw_iwe_neg_contrib = b_fw_iwe_neg_contrib.scatter_add_( 219 | 0, unique_pfw_idx[:, 1:2], mask_neg.float() * contrib_pfw.float() 220 | ) 221 | 222 | # store info 223 | fw_iwe_pos_contrib[b, :, :] = b_fw_iwe_pos_contrib 224 | fw_iwe_neg_contrib[b, :, :] = b_fw_iwe_neg_contrib 225 | 226 | # average number of warped events per pixel 227 | fw_iwe_pos_contrib = fw_iwe_pos_contrib.view((flow.shape[0], 1, self.res[0], self.res[1])) 228 | fw_iwe_neg_contrib = fw_iwe_neg_contrib.view((flow.shape[0], 1, self.res[0], self.res[1])) 229 | fw_iwe_pos[fw_iwe_pos_contrib > 0] /= fw_iwe_pos_contrib[fw_iwe_pos_contrib > 0] 230 | fw_iwe_neg[fw_iwe_neg_contrib > 0] /= fw_iwe_neg_contrib[fw_iwe_neg_contrib > 0] 231 | 232 | return torch.cat([fw_iwe_pos, fw_iwe_neg], dim=1) 233 | -------------------------------------------------------------------------------- /loss/reconstruction.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import torch 5 | import numpy as np 6 | import torch.nn.functional as F 7 | 8 | from .flow import AveragedIWE 9 | 10 | parent_dir_name = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 11 | sys.path.append(parent_dir_name) 12 | 13 | from utils.iwe import deblur_events 14 | from utils.gradients import Sobel 15 | 16 | 17 | class BrightnessConstancy(torch.nn.Module): 18 | """ 19 | Self-supervised image reconstruction loss, as described in Section 3.4 of the paper 'Back to Event Basics: 20 | Self-Supervised Image Reconstruction for Event Cameras via Photometric Constancy', Paredes-Valles et al., CVPR'21. 21 | The reconstruction loss is the combination of three components. 22 | 1) Image reconstruction through the generative model of event cameras. 
The reconstruction error propagates back 23 | through the spatial gradients of the reconstructed images. The loss consists in an L2-norm of the difference of the 24 | brightness increment images that can be obtained through the generative model and by means of event integration. 25 | 2) Temporal consistency. Simple L1-norm of the warping error between two consecutive reconstructed frames. 26 | 3) Image regularization. Conventional total variation formulation. 27 | """ 28 | 29 | def __init__(self, config, device): 30 | super(BrightnessConstancy, self).__init__() 31 | self.sobel = Sobel(device) 32 | self.res = config["loader"]["resolution"] 33 | self.flow_scaling = max(config["loader"]["resolution"]) 34 | self.weights = config["loss"]["reconstruction_regul_weight"] 35 | 36 | col_idx = np.linspace(0, self.res[1] - 1, num=self.res[1]) 37 | row_idx = np.linspace(0, self.res[0] - 1, num=self.res[0]) 38 | mx, my = np.meshgrid(col_idx, row_idx) 39 | indices = np.zeros((1, 2, self.res[0], self.res[1])) 40 | indices[:, 0, :, :] = my 41 | indices[:, 1, :, :] = mx 42 | self.indices = torch.from_numpy(indices).float().to(device) 43 | 44 | self.averaged_iwe = AveragedIWE(config, device) 45 | 46 | def generative_model(self, flow, img, inputs): 47 | """ 48 | :param flow: [batch_size x 2 x H x W] optical flow map 49 | :param img: [batch_size x 1 x H x W] last reconstructed image 50 | :param inputs: dataloader dictionary 51 | :return generative model loss 52 | """ 53 | 54 | event_cnt = inputs["inp_cnt"].to(flow.device) 55 | event_list = inputs["inp_list"].to(flow.device) 56 | pol_mask = inputs["inp_pol_mask"].to(flow.device) 57 | 58 | # mask optical flow with input events 59 | flow_mask = torch.sum(event_cnt, dim=1, keepdim=True) 60 | flow_mask[flow_mask > 0] = 1 61 | flow = flow * flow_mask 62 | 63 | # foward warping metrics 64 | warped_y = self.indices[:, 0:1, :, :] - flow[:, 1:2, :, :] * self.flow_scaling 65 | warped_x = self.indices[:, 1:2, :, :] - flow[:, 0:1, :, :] * self.flow_scaling 66 | warped_y = 2 * warped_y / (self.res[0] - 1) - 1 67 | warped_x = 2 * warped_x / (self.res[1] - 1) - 1 68 | grid_pos = torch.cat([warped_x, warped_y], dim=1).permute(0, 2, 3, 1) 69 | 70 | # warped predicted brightness increment (previous image) 71 | img_gradx, img_grady = self.sobel(img) 72 | warped_img_grady = F.grid_sample(img_grady, grid_pos, mode="bilinear", padding_mode="zeros") 73 | warped_img_gradx = F.grid_sample(img_gradx, grid_pos, mode="bilinear", padding_mode="zeros") 74 | pred_deltaL = warped_img_gradx * flow[:, 0:1, :, :] + warped_img_grady * flow[:, 1:2, :, :] 75 | pred_deltaL = pred_deltaL * self.flow_scaling 76 | 77 | # warped brightness increment from the averaged image of warped events 78 | avg_iwe = self.averaged_iwe(flow, event_list, pol_mask) 79 | event_deltaL = avg_iwe[:, 0:1, :, :] - avg_iwe[:, 1:2, :, :] # C == 1 80 | 81 | # squared L2 norm - brightness constancy error 82 | bc_error = event_deltaL + pred_deltaL 83 | bc_error = ( 84 | torch.norm( 85 | bc_error.view( 86 | bc_error.shape[0], 87 | bc_error.shape[1], 88 | 1, 89 | -1, 90 | ), 91 | p=2, 92 | dim=3, 93 | ) 94 | ** 2 95 | ) # norm in the spatial dimension 96 | 97 | return bc_error.sum() 98 | 99 | def temporal_consistency(self, flow, prev_img, img): 100 | """ 101 | :param flow: [batch_size x 2 x H x W] optical flow map 102 | :param prev_img: [batch_size x 1 x H x W] previous reconstructed image 103 | :param img: [batch_size x 1 x H x W] last reconstructed image 104 | :return weighted temporal consistency loss 105 | """ 106 | 107 | # 
foward warping metrics 108 | warped_y = self.indices[:, 0:1, :, :] - flow[:, 1:2, :, :] * self.flow_scaling 109 | warped_x = self.indices[:, 1:2, :, :] - flow[:, 0:1, :, :] * self.flow_scaling 110 | warped_y = 2 * warped_y / (self.res[0] - 1) - 1 111 | warped_x = 2 * warped_x / (self.res[1] - 1) - 1 112 | grid_pos = torch.cat([warped_x, warped_y], dim=1).permute(0, 2, 3, 1) 113 | 114 | # temporal consistency 115 | warped_prev_img = F.grid_sample(prev_img, grid_pos, mode="bilinear", padding_mode="zeros") 116 | tc_error = img - warped_prev_img 117 | tc_error = ( 118 | torch.norm( 119 | tc_error.view( 120 | tc_error.shape[0], 121 | tc_error.shape[1], 122 | 1, 123 | -1, 124 | ), 125 | p=1, 126 | dim=3, 127 | ) 128 | ** 1 129 | ) # norm in the spatial dimension 130 | tc_error = tc_error.sum() 131 | 132 | return self.weights[1] * tc_error 133 | 134 | def regularization(self, img): 135 | """ 136 | :param img: [batch_size x 1 x H x W] last reconstructed image 137 | :return weighted image regularization loss 138 | """ 139 | 140 | # conventional total variation with forward differences 141 | img_dx = torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]) 142 | img_dy = torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]) 143 | tv_error = img_dx.sum() + img_dy.sum() 144 | 145 | return self.weights[0] * tv_error 146 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tudelft/ssl_e2vid/ef9d61b476576f60f92029cadde4286c6cf0f734/models/__init__.py -------------------------------------------------------------------------------- /models/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from UZH-RPG https://github.com/uzh-rpg/rpg_e2vid 3 | """ 4 | 5 | from abc import abstractmethod 6 | 7 | import numpy as np 8 | import torch.nn as nn 9 | 10 | 11 | class BaseModel(nn.Module): 12 | """ 13 | Base class for all models 14 | """ 15 | 16 | @abstractmethod 17 | def forward(self, *inputs): 18 | """ 19 | Forward pass logic 20 | 21 | :return: Model output 22 | """ 23 | raise NotImplementedError 24 | 25 | def __str__(self): 26 | """ 27 | Model prints with number of trainable parameters 28 | """ 29 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 30 | params = sum([np.prod(p.size()) for p in model_parameters]) 31 | return super().__str__() + "\nTrainable parameters: {}".format(params) 32 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from UZH-RPG https://github.com/uzh-rpg/rpg_e2vid 3 | """ 4 | 5 | import copy 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from .base import BaseModel 11 | from .model_util import copy_states, CropParameters 12 | from .submodules import ResidualBlock, ConvGRU, ConvLayer 13 | from .unet import UNetRecurrent, MultiResUNet 14 | 15 | 16 | class E2VID(BaseModel): 17 | """ 18 | E2VID architecture for image reconstruction from event-data. 19 | "High speed and high dynamic range video with an event camera", Rebecq et al. 2019. 
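    Example (a minimal sketch; parameter values are illustrative, in practice the kwargs dict comes from
    config["model_reconstruction"] in the YAML configs):

        model = E2VID({"base_num_channels": 32, "kernel_size": 5}, num_bins=5)
        model.reset_states()                               # call at the start of a new sequence
        out = model(torch.zeros(1, 5, 128, 128))["image"]  # N x 1 x H x W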
20 | """ 21 | 22 | def __init__(self, unet_kwargs, num_bins): 23 | super().__init__() 24 | self.crop = None 25 | 26 | norm = None 27 | use_upsample_conv = True 28 | final_activation = "none" 29 | if "norm" in unet_kwargs.keys(): 30 | norm = unet_kwargs["norm"] 31 | if "use_upsample_conv" in unet_kwargs.keys(): 32 | use_upsample_conv = unet_kwargs["use_upsample_conv"] 33 | if "final_activation" in unet_kwargs.keys(): 34 | final_activation = unet_kwargs["final_activation"] 35 | 36 | E2VID_kwargs = { 37 | "base_num_channels": unet_kwargs["base_num_channels"], 38 | "num_encoders": 3, 39 | "num_residual_blocks": 2, 40 | "num_output_channels": 1, 41 | "skip_type": "sum", 42 | "norm": norm, 43 | "num_bins": num_bins, 44 | "use_upsample_conv": use_upsample_conv, 45 | "kernel_size": unet_kwargs["kernel_size"], 46 | "channel_multiplier": 2, 47 | "recurrent_block_type": "convlstm", 48 | "final_activation": final_activation, 49 | } 50 | 51 | self.num_encoders = E2VID_kwargs["num_encoders"] 52 | unet_kwargs.update(E2VID_kwargs) 53 | unet_kwargs.pop("name", None) 54 | unet_kwargs.pop("encoding", None) # TODO: remove 55 | self.unetrecurrent = UNetRecurrent(unet_kwargs) 56 | 57 | @property 58 | def states(self): 59 | return copy_states(self.unetrecurrent.states) 60 | 61 | @states.setter 62 | def states(self, states): 63 | self.unetrecurrent.states = states 64 | 65 | def detach_states(self): 66 | detached_states = [] 67 | for state in self.unetrecurrent.states: 68 | if type(state) is tuple: 69 | tmp = [] 70 | for hidden in state: 71 | tmp.append(hidden.detach()) 72 | detached_states.append(tuple(tmp)) 73 | else: 74 | detached_states.append(state.detach()) 75 | self.unetrecurrent.states = detached_states 76 | 77 | def reset_states(self): 78 | self.unetrecurrent.states = [None] * self.unetrecurrent.num_encoders 79 | 80 | def init_cropping(self, width, height, safety_margin=0): 81 | self.crop = CropParameters(width, height, self.num_encoders, safety_margin) 82 | 83 | def forward(self, inp_voxel): 84 | """ 85 | :param inp_voxel: N x num_bins x H x W 86 | :return: [N x 1 X H X W] reconstructed brightness signal. 87 | """ 88 | 89 | # pad input 90 | x = inp_voxel 91 | if self.crop is not None: 92 | x = self.crop.pad(x) 93 | 94 | # forward pass 95 | img = self.unetrecurrent.forward(x) 96 | 97 | # crop output 98 | if self.crop is not None: 99 | img = img[:, :, self.crop.iy0 : self.crop.iy1, self.crop.ix0 : self.crop.ix1] 100 | img = img.contiguous() 101 | 102 | return {"image": img} 103 | 104 | 105 | class FireNet(BaseModel): 106 | """ 107 | FireNet architecture for image reconstruction from event-data. 
108 | "Fast image reconstruction with an event camera", Scheerlinck et al., 2019 109 | """ 110 | 111 | def __init__(self, unet_kwargs, num_bins): 112 | super().__init__() 113 | base_num_channels = unet_kwargs["base_num_channels"] 114 | kernel_size = unet_kwargs["kernel_size"] 115 | 116 | padding = kernel_size // 2 117 | self.head = ConvLayer(num_bins, base_num_channels, kernel_size, padding=padding) 118 | self.G1 = ConvGRU(base_num_channels, base_num_channels, kernel_size) 119 | self.R1 = ResidualBlock(base_num_channels, base_num_channels) 120 | self.G2 = ConvGRU(base_num_channels, base_num_channels, kernel_size) 121 | self.R2 = ResidualBlock(base_num_channels, base_num_channels) 122 | self.pred = ConvLayer(base_num_channels, out_channels=1, kernel_size=1, activation=None) 123 | self.num_encoders = 0 # needed by image_reconstructor.py 124 | self.num_recurrent_units = 2 125 | self.reset_states() 126 | 127 | @property 128 | def states(self): 129 | return copy_states(self._states) 130 | 131 | @states.setter 132 | def states(self, states): 133 | self._states = states 134 | 135 | def detach_states(self): 136 | detached_states = [] 137 | for state in self.states: 138 | if type(state) is tuple: 139 | tmp = [] 140 | for hidden in state: 141 | tmp.append(hidden.detach()) 142 | detached_states.append(tuple(tmp)) 143 | else: 144 | detached_states.append(state.detach()) 145 | self.states = detached_states 146 | 147 | def reset_states(self): 148 | self._states = [None] * self.num_recurrent_units 149 | 150 | def init_cropping(self, width, height): 151 | pass 152 | 153 | def forward(self, inp_voxel): 154 | """ 155 | :param inp_voxel: N x num_bins x H x W 156 | :return: [N x 1 X H X W] reconstructed brightness signal. 157 | """ 158 | 159 | # forward pass 160 | x = inp_voxel 161 | x = self.head(x) 162 | x = self.G1(x, self._states[0]) 163 | self._states[0] = x 164 | x = self.R1(x) 165 | x = self.G2(x, self._states[1]) 166 | self._states[1] = x 167 | x = self.R2(x) 168 | return {"image": self.pred(x)} 169 | 170 | 171 | class EVFlowNet(BaseModel): 172 | """ 173 | EV-FlowNet architecture for (dense/sparse) optical flow estimation from event-data. 174 | "EV-FlowNet: Self-Supervised Optical Flow for Event-based Cameras", Zhu et al. 2018. 
175 | """ 176 | 177 | def __init__(self, unet_kwargs, num_bins): 178 | super().__init__() 179 | self.crop = None 180 | self.mask = unet_kwargs["mask_output"] 181 | EVFlowNet_kwargs = { 182 | "base_num_channels": unet_kwargs["base_num_channels"], 183 | "num_encoders": 4, 184 | "num_residual_blocks": 2, 185 | "num_output_channels": 2, 186 | "skip_type": "concat", 187 | "norm": None, 188 | "num_bins": num_bins, 189 | "use_upsample_conv": True, 190 | "kernel_size": unet_kwargs["kernel_size"], 191 | "channel_multiplier": 2, 192 | "final_activation": "tanh", 193 | } 194 | self.num_encoders = EVFlowNet_kwargs["num_encoders"] 195 | unet_kwargs.update(EVFlowNet_kwargs) 196 | unet_kwargs.pop("name", None) 197 | unet_kwargs.pop("eval", None) 198 | unet_kwargs.pop("encoding", None) # TODO: remove 199 | unet_kwargs.pop("mask_output", None) 200 | unet_kwargs.pop("mask_smoothing", None) # TODO: remove 201 | if "flow_scaling" in unet_kwargs.keys(): 202 | unet_kwargs.pop("flow_scaling", None) 203 | 204 | self.multires_unet = MultiResUNet(unet_kwargs) 205 | 206 | def reset_states(self): 207 | pass 208 | 209 | def init_cropping(self, width, height, safety_margin=0): 210 | self.crop = CropParameters(width, height, self.num_encoders, safety_margin) 211 | 212 | def forward(self, inp_voxel, inp_cnt): 213 | """ 214 | :param inp_voxel: N x num_bins x H x W 215 | :return: output dict with list of [N x 2 X H X W] (x, y) displacement within event_tensor. 216 | """ 217 | 218 | # pad input 219 | x = inp_voxel 220 | if self.crop is not None: 221 | x = self.crop.pad(x) 222 | 223 | # forward pass 224 | multires_flow = self.multires_unet.forward(x) 225 | 226 | # upsample flow estimates to the original input resolution 227 | flow_list = [] 228 | for flow in multires_flow: 229 | flow_list.append( 230 | torch.nn.functional.interpolate( 231 | flow, 232 | scale_factor=( 233 | multires_flow[-1].shape[2] / flow.shape[2], 234 | multires_flow[-1].shape[3] / flow.shape[3], 235 | ), 236 | ) 237 | ) 238 | 239 | # crop output 240 | if self.crop is not None: 241 | for i, flow in enumerate(flow_list): 242 | flow_list[i] = flow[:, :, self.crop.iy0 : self.crop.iy1, self.crop.ix0 : self.crop.ix1] 243 | flow_list[i] = flow_list[i].contiguous() 244 | 245 | # mask flow 246 | if self.mask: 247 | mask = torch.sum(inp_cnt, dim=1, keepdim=True) 248 | mask[mask > 0] = 1 249 | for i, flow in enumerate(flow_list): 250 | flow_list[i] = flow * mask 251 | 252 | return {"flow": flow_list} 253 | 254 | 255 | class FireFlowNet(BaseModel): 256 | """ 257 | FireFlowNet architecture for (dense/sparse) optical flow estimation from event-data. 
258 | "Back to Event Basics: Self Supervised Learning of Image Reconstruction from Event Data via Photometric Constancy", Paredes-Valles et al., 2020 259 | """ 260 | 261 | def __init__(self, unet_kwargs, num_bins): 262 | super().__init__() 263 | base_num_channels = unet_kwargs["base_num_channels"] 264 | kernel_size = unet_kwargs["kernel_size"] 265 | self.mask = unet_kwargs["mask_output"] 266 | 267 | padding = kernel_size // 2 268 | self.E1 = ConvLayer(num_bins, base_num_channels, kernel_size, padding=padding) 269 | self.E2 = ConvLayer(base_num_channels, base_num_channels, kernel_size, padding=padding) 270 | self.R1 = ResidualBlock(base_num_channels, base_num_channels) 271 | self.E3 = ConvLayer(base_num_channels, base_num_channels, kernel_size, padding=padding) 272 | self.R2 = ResidualBlock(base_num_channels, base_num_channels) 273 | self.pred = ConvLayer(base_num_channels, out_channels=2, kernel_size=1, activation="tanh") 274 | 275 | def reset_states(self): 276 | pass 277 | 278 | def init_cropping(self, width, height): 279 | pass 280 | 281 | def forward(self, inp_voxel, inp_cnt): 282 | """ 283 | :param inp_voxel: N x num_bins x H x W 284 | :return: output dict with list of [N x 2 X H X W] (x, y) displacement within event_tensor. 285 | """ 286 | 287 | # forward pass 288 | x = inp_voxel 289 | x = self.E1(x) 290 | x = self.E2(x) 291 | x = self.R1(x) 292 | x = self.E3(x) 293 | x = self.R2(x) 294 | flow = self.pred(x) 295 | 296 | # mask flow 297 | if self.mask: 298 | mask = torch.sum(inp_cnt, dim=1, keepdim=True) 299 | mask[mask > 0] = 1 300 | flow = flow * mask 301 | 302 | return {"flow": [flow]} 303 | -------------------------------------------------------------------------------- /models/model_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from UZH-RPG https://github.com/uzh-rpg/rpg_e2vid 3 | """ 4 | 5 | import os 6 | import copy 7 | from math import fabs, ceil, floor 8 | 9 | import numpy as np 10 | import torch 11 | from torch.nn import ZeroPad2d 12 | 13 | 14 | def skip_concat(x1, x2): 15 | diffY = x2.size()[2] - x1.size()[2] 16 | diffX = x2.size()[3] - x1.size()[3] 17 | padding = ZeroPad2d((diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2)) 18 | x1 = padding(x1) 19 | return torch.cat([x1, x2], dim=1) 20 | 21 | 22 | def skip_sum(x1, x2): 23 | diffY = x2.size()[2] - x1.size()[2] 24 | diffX = x2.size()[3] - x1.size()[3] 25 | padding = ZeroPad2d((diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2)) 26 | x1 = padding(x1) 27 | return x1 + x2 28 | 29 | 30 | def optimal_crop_size(max_size, max_subsample_factor, safety_margin=0): 31 | """ 32 | Find the optimal crop size for a given max_size and subsample_factor. 33 | The optimal crop size is the smallest integer which is greater or equal than max_size, 34 | while being divisible by 2^max_subsample_factor. 35 | """ 36 | crop_size = int(pow(2, max_subsample_factor) * ceil(max_size / pow(2, max_subsample_factor))) 37 | crop_size += safety_margin * pow(2, max_subsample_factor) 38 | return crop_size 39 | 40 | 41 | class CropParameters: 42 | """ 43 | Helper class to compute and store useful parameters for pre-processing and post-processing 44 | of images in and out of E2VID. 
45 | Pre-processing: finding the best image size for the network, and padding the input image with zeros 46 | Post-processing: Crop the output image back to the original image size 47 | """ 48 | 49 | def __init__(self, width, height, num_encoders, safety_margin=0): 50 | 51 | self.height = height 52 | self.width = width 53 | self.num_encoders = num_encoders 54 | self.width_crop_size = optimal_crop_size(self.width, num_encoders, safety_margin) 55 | self.height_crop_size = optimal_crop_size(self.height, num_encoders, safety_margin) 56 | 57 | self.padding_top = ceil(0.5 * (self.height_crop_size - self.height)) 58 | self.padding_bottom = floor(0.5 * (self.height_crop_size - self.height)) 59 | self.padding_left = ceil(0.5 * (self.width_crop_size - self.width)) 60 | self.padding_right = floor(0.5 * (self.width_crop_size - self.width)) 61 | self.pad = ZeroPad2d( 62 | ( 63 | self.padding_left, 64 | self.padding_right, 65 | self.padding_top, 66 | self.padding_bottom, 67 | ) 68 | ) 69 | 70 | self.cx = floor(self.width_crop_size / 2) 71 | self.cy = floor(self.height_crop_size / 2) 72 | 73 | self.ix0 = self.cx - floor(self.width / 2) 74 | self.ix1 = self.cx + ceil(self.width / 2) 75 | self.iy0 = self.cy - floor(self.height / 2) 76 | self.iy1 = self.cy + ceil(self.height / 2) 77 | 78 | def crop(self, img): 79 | return img[..., self.iy0 : self.iy1, self.ix0 : self.ix1] 80 | 81 | 82 | def recursive_clone(tensor): 83 | """ 84 | Assumes tensor is a torch.tensor with 'clone()' method, possibly 85 | inside nested iterable. 86 | E.g., tensor = [(pytorch_tensor, pytorch_tensor), ...] 87 | """ 88 | if hasattr(tensor, "clone"): 89 | return tensor.clone() 90 | try: 91 | return type(tensor)(recursive_clone(t) for t in tensor) 92 | except TypeError: 93 | print("{} is not iterable and has no clone() method.".format(tensor)) 94 | 95 | 96 | def copy_states(states): 97 | """ 98 | LSTM states: [(torch.tensor, torch.tensor), ...] 99 | GRU states: [torch.tensor, ...] 100 | """ 101 | if states[0] is None: 102 | return copy.deepcopy(states) 103 | return recursive_clone(states) 104 | -------------------------------------------------------------------------------- /models/submodules.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from UZH-RPG https://github.com/uzh-rpg/rpg_e2vid 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as f 8 | 9 | 10 | class ConvLayer(nn.Module): 11 | """ 12 | Convolutional layer. 13 | Default: bias, ReLU, no downsampling, no batch norm. 
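    Example (illustrative values; `activation` must be the name of a function in the torch namespace,
    e.g. "relu" or "tanh", or None):

        conv = ConvLayer(4, 16, kernel_size=3, stride=2, padding=1, norm="BN")
        y = conv(torch.randn(1, 4, 64, 64))   # 1 x 16 x 32 x 32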
14 | """ 15 | 16 | def __init__( 17 | self, 18 | in_channels, 19 | out_channels, 20 | kernel_size, 21 | stride=1, 22 | padding=0, 23 | activation="relu", 24 | norm=None, 25 | BN_momentum=0.1, 26 | ): 27 | super(ConvLayer, self).__init__() 28 | 29 | bias = False if norm == "BN" else True 30 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) 31 | if activation is not None: 32 | self.activation = getattr(torch, activation) 33 | else: 34 | self.activation = None 35 | 36 | self.norm = norm 37 | if norm == "BN": 38 | self.norm_layer = nn.BatchNorm2d(out_channels, momentum=BN_momentum) 39 | elif norm == "IN": 40 | self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True) 41 | 42 | def forward(self, x): 43 | out = self.conv2d(x) 44 | 45 | if self.norm in ["BN", "IN"]: 46 | out = self.norm_layer(out) 47 | 48 | if self.activation is not None: 49 | out = self.activation(out) 50 | 51 | return out 52 | 53 | 54 | class TransposedConvLayer(nn.Module): 55 | """ 56 | Transposed convolutional layer to increase spatial resolution (x2) in a decoder. 57 | Default: bias, ReLU, no downsampling, no batch norm. 58 | """ 59 | 60 | def __init__( 61 | self, 62 | in_channels, 63 | out_channels, 64 | kernel_size, 65 | padding=0, 66 | activation="relu", 67 | norm=None, 68 | ): 69 | super(TransposedConvLayer, self).__init__() 70 | 71 | bias = False if norm == "BN" else True 72 | self.transposed_conv2d = nn.ConvTranspose2d( 73 | in_channels, 74 | out_channels, 75 | kernel_size, 76 | stride=2, 77 | padding=padding, 78 | output_padding=1, 79 | bias=bias, 80 | ) 81 | 82 | if activation is not None: 83 | self.activation = getattr(torch, activation) 84 | else: 85 | self.activation = None 86 | 87 | self.norm = norm 88 | if norm == "BN": 89 | self.norm_layer = nn.BatchNorm2d(out_channels) 90 | elif norm == "IN": 91 | self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True) 92 | 93 | def forward(self, x): 94 | out = self.transposed_conv2d(x) 95 | 96 | if self.norm in ["BN", "IN"]: 97 | out = self.norm_layer(out) 98 | 99 | if self.activation is not None: 100 | out = self.activation(out) 101 | 102 | return out 103 | 104 | 105 | class UpsampleConvLayer(nn.Module): 106 | """ 107 | Upsampling layer (bilinear interpolation + Conv2d) to increase spatial resolution (x2) in a decoder. 108 | Default: bias, ReLU, no downsampling, no batch norm. 
109 | """ 110 | 111 | def __init__( 112 | self, 113 | in_channels, 114 | out_channels, 115 | kernel_size, 116 | stride=1, 117 | padding=0, 118 | activation="relu", 119 | norm=None, 120 | ): 121 | super(UpsampleConvLayer, self).__init__() 122 | 123 | bias = False if norm == "BN" else True 124 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) 125 | 126 | if activation is not None: 127 | self.activation = getattr(torch, activation) 128 | else: 129 | self.activation = None 130 | 131 | self.norm = norm 132 | if norm == "BN": 133 | self.norm_layer = nn.BatchNorm2d(out_channels) 134 | elif norm == "IN": 135 | self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True) 136 | 137 | def forward(self, x): 138 | x_upsampled = f.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False) 139 | out = self.conv2d(x_upsampled) 140 | 141 | if self.norm in ["BN", "IN"]: 142 | out = self.norm_layer(out) 143 | 144 | if self.activation is not None: 145 | out = self.activation(out) 146 | 147 | return out 148 | 149 | 150 | class RecurrentConvLayer(nn.Module): 151 | """ 152 | Layer comprised of a convolution followed by a recurrent convolutional block. 153 | Default: bias, ReLU, no downsampling, no batch norm, ConvLSTM. 154 | """ 155 | 156 | def __init__( 157 | self, 158 | in_channels, 159 | out_channels, 160 | kernel_size=3, 161 | stride=1, 162 | padding=0, 163 | recurrent_block_type="convlstm", 164 | activation="relu", 165 | norm=None, 166 | BN_momentum=0.1, 167 | ): 168 | super(RecurrentConvLayer, self).__init__() 169 | 170 | assert recurrent_block_type in ["convlstm", "convgru"] 171 | self.recurrent_block_type = recurrent_block_type 172 | if self.recurrent_block_type == "convlstm": 173 | RecurrentBlock = ConvLSTM 174 | else: 175 | RecurrentBlock = ConvGRU 176 | self.conv = ConvLayer( 177 | in_channels, 178 | out_channels, 179 | kernel_size, 180 | stride, 181 | padding, 182 | activation, 183 | norm, 184 | BN_momentum=BN_momentum, 185 | ) 186 | self.recurrent_block = RecurrentBlock(input_size=out_channels, hidden_size=out_channels, kernel_size=3) 187 | 188 | def forward(self, x, prev_state): 189 | x = self.conv(x) 190 | state = self.recurrent_block(x, prev_state) 191 | x = state[0] if self.recurrent_block_type == "convlstm" else state 192 | return x, state 193 | 194 | 195 | class ResidualBlock(nn.Module): 196 | """ 197 | Residual block as in "Deep residual learning for image recognition", He et al. 2016. 198 | Default: bias, ReLU, no downsampling, no batch norm, ConvLSTM. 
199 | """ 200 | 201 | def __init__( 202 | self, 203 | in_channels, 204 | out_channels, 205 | stride=1, 206 | downsample=None, 207 | norm=None, 208 | BN_momentum=0.1, 209 | ): 210 | super(ResidualBlock, self).__init__() 211 | bias = False if norm == "BN" else True 212 | self.conv1 = nn.Conv2d( 213 | in_channels, 214 | out_channels, 215 | kernel_size=3, 216 | stride=stride, 217 | padding=1, 218 | bias=bias, 219 | ) 220 | self.norm = norm 221 | if norm == "BN": 222 | self.bn1 = nn.BatchNorm2d(out_channels, momentum=BN_momentum) 223 | self.bn2 = nn.BatchNorm2d(out_channels, momentum=BN_momentum) 224 | elif norm == "IN": 225 | self.bn1 = nn.InstanceNorm2d(out_channels) 226 | self.bn2 = nn.InstanceNorm2d(out_channels) 227 | 228 | self.relu = nn.ReLU(inplace=True) 229 | self.conv2 = nn.Conv2d( 230 | out_channels, 231 | out_channels, 232 | kernel_size=3, 233 | stride=1, 234 | padding=1, 235 | bias=bias, 236 | ) 237 | self.downsample = downsample 238 | 239 | def forward(self, x): 240 | residual = x 241 | out = self.conv1(x) 242 | if self.norm in ["BN", "IN"]: 243 | out = self.bn1(out) 244 | out = self.relu(out) 245 | out = self.conv2(out) 246 | if self.norm in ["BN", "IN"]: 247 | out = self.bn2(out) 248 | 249 | if self.downsample: 250 | residual = self.downsample(x) 251 | 252 | out += residual 253 | out = self.relu(out) 254 | return out 255 | 256 | 257 | class ConvLSTM(nn.Module): 258 | """ 259 | Convolutional LSTM module. 260 | Adapted from https://github.com/Atcold/pytorch-CortexNet/blob/master/model/ConvLSTMCell.py 261 | """ 262 | 263 | def __init__(self, input_size, hidden_size, kernel_size): 264 | super(ConvLSTM, self).__init__() 265 | 266 | self.input_size = input_size 267 | self.hidden_size = hidden_size 268 | pad = kernel_size // 2 269 | 270 | # cache a tensor filled with zeros to avoid reallocating memory at each inference step if --no-recurrent is enabled 271 | self.zero_tensors = {} 272 | 273 | self.Gates = nn.Conv2d(input_size + hidden_size, 4 * hidden_size, kernel_size, padding=pad) 274 | 275 | def forward(self, input_, prev_state=None): 276 | 277 | # get batch and spatial sizes 278 | batch_size = input_.data.size()[0] 279 | spatial_size = input_.data.size()[2:] 280 | 281 | # generate empty prev_state, if None is provided 282 | if prev_state is None: 283 | 284 | # create the zero tensor if it has not been created already 285 | state_size = tuple([batch_size, self.hidden_size] + list(spatial_size)) 286 | if state_size not in self.zero_tensors: 287 | # allocate a tensor with size `spatial_size`, filled with zero (if it has not been allocated already) 288 | self.zero_tensors[state_size] = ( 289 | torch.zeros(state_size, dtype=input_.dtype).to(input_.device), 290 | torch.zeros(state_size, dtype=input_.dtype).to(input_.device), 291 | ) 292 | 293 | prev_state = self.zero_tensors[tuple(state_size)] 294 | 295 | prev_hidden, prev_cell = prev_state 296 | 297 | # data size is [batch, channel, height, width] 298 | stacked_inputs = torch.cat((input_, prev_hidden), 1) 299 | gates = self.Gates(stacked_inputs) 300 | 301 | # chunk across channel dimension 302 | in_gate, remember_gate, out_gate, cell_gate = gates.chunk(4, 1) 303 | 304 | # apply sigmoid non linearity 305 | in_gate = torch.sigmoid(in_gate) 306 | remember_gate = torch.sigmoid(remember_gate) 307 | out_gate = torch.sigmoid(out_gate) 308 | 309 | # apply tanh non linearity 310 | cell_gate = torch.tanh(cell_gate) 311 | 312 | # compute current cell and hidden state 313 | cell = (remember_gate * prev_cell) + (in_gate * cell_gate) 314 | hidden = 
out_gate * torch.tanh(cell) 315 | 316 | return hidden, cell 317 | 318 | 319 | class ConvGRU(nn.Module): 320 | """ 321 | Convolutional GRU cell. 322 | Adapted from https://github.com/jacobkimmel/pytorch_convgru/blob/master/convgru.py 323 | """ 324 | 325 | def __init__(self, input_size, hidden_size, kernel_size): 326 | super().__init__() 327 | padding = kernel_size // 2 328 | self.input_size = input_size 329 | self.hidden_size = hidden_size 330 | self.reset_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding) 331 | self.update_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding) 332 | self.out_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding) 333 | 334 | nn.init.orthogonal_(self.reset_gate.weight) 335 | nn.init.orthogonal_(self.update_gate.weight) 336 | nn.init.orthogonal_(self.out_gate.weight) 337 | nn.init.constant_(self.reset_gate.bias, 0.0) 338 | nn.init.constant_(self.update_gate.bias, 0.0) 339 | nn.init.constant_(self.out_gate.bias, 0.0) 340 | 341 | def forward(self, input_, prev_state): 342 | 343 | # get batch and spatial sizes 344 | batch_size = input_.data.size()[0] 345 | spatial_size = input_.data.size()[2:] 346 | 347 | # generate empty prev_state, if None is provided 348 | if prev_state is None: 349 | state_size = [batch_size, self.hidden_size] + list(spatial_size) 350 | prev_state = torch.zeros(state_size, dtype=input_.dtype).to(input_.device) 351 | 352 | # data size is [batch, channel, height, width] 353 | stacked_inputs = torch.cat([input_, prev_state], dim=1) 354 | update = torch.sigmoid(self.update_gate(stacked_inputs)) 355 | reset = torch.sigmoid(self.reset_gate(stacked_inputs)) 356 | out_inputs = torch.tanh(self.out_gate(torch.cat([input_, prev_state * reset], dim=1))) 357 | new_state = prev_state * (1 - update) + out_inputs * update 358 | 359 | return new_state 360 | -------------------------------------------------------------------------------- /models/unet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from UZH-RPG https://github.com/uzh-rpg/rpg_e2vid 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .model_util import * 9 | from .submodules import ( 10 | ConvGRU, 11 | ConvLayer, 12 | RecurrentConvLayer, 13 | ResidualBlock, 14 | TransposedConvLayer, 15 | UpsampleConvLayer, 16 | ) 17 | 18 | 19 | class BaseUNet(nn.Module): 20 | """ 21 | Base class for conventional UNet architecture. 22 | Symmetric, skip connections on every encoding layer. 
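    For example, with base_num_channels=32, channel_multiplier=2 and num_encoders=3, the encoder
    input sizes are [32, 64, 128] and the encoder output sizes are [64, 128, 256].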
23 | """ 24 | 25 | def __init__( 26 | self, 27 | base_num_channels, 28 | num_encoders, 29 | num_residual_blocks, 30 | num_output_channels, 31 | skip_type, 32 | norm, 33 | use_upsample_conv, 34 | num_bins, 35 | recurrent_block_type=None, 36 | kernel_size=5, 37 | channel_multiplier=2, 38 | ): 39 | super(BaseUNet, self).__init__() 40 | self.base_num_channels = base_num_channels 41 | self.num_encoders = num_encoders 42 | self.num_residual_blocks = num_residual_blocks 43 | self.num_output_channels = num_output_channels 44 | self.kernel_size = kernel_size 45 | self.skip_type = skip_type 46 | self.norm = norm 47 | self.num_bins = num_bins 48 | self.recurrent_block_type = recurrent_block_type 49 | self.channel_multiplier = channel_multiplier 50 | 51 | self.skip_ftn = eval("skip_" + skip_type) 52 | if use_upsample_conv: 53 | self.UpsampleLayer = UpsampleConvLayer 54 | else: 55 | self.UpsampleLayer = TransposedConvLayer 56 | assert self.num_output_channels > 0 57 | 58 | self.encoder_input_sizes = [ 59 | int(self.base_num_channels * pow(self.channel_multiplier, i)) for i in range(self.num_encoders) 60 | ] 61 | self.encoder_output_sizes = [ 62 | int(self.base_num_channels * pow(self.channel_multiplier, i + 1)) for i in range(self.num_encoders) 63 | ] 64 | self.max_num_channels = self.encoder_output_sizes[-1] 65 | 66 | def build_encoders(self): 67 | encoders = nn.ModuleList() 68 | for (input_size, output_size) in zip(self.encoder_input_sizes, self.encoder_output_sizes): 69 | encoders.append( 70 | ConvLayer( 71 | input_size, 72 | output_size, 73 | kernel_size=self.kernel_size, 74 | stride=2, 75 | padding=self.kernel_size // 2, 76 | activation=self.activation, 77 | norm=self.norm, 78 | ) 79 | ) 80 | return encoders 81 | 82 | def build_resblocks(self): 83 | resblocks = nn.ModuleList() 84 | for i in range(self.num_residual_blocks): 85 | resblocks.append(ResidualBlock(self.max_num_channels, self.max_num_channels, norm=self.norm)) 86 | return resblocks 87 | 88 | def build_decoders(self): 89 | decoder_input_sizes = reversed(self.encoder_output_sizes) 90 | decoder_output_sizes = reversed(self.encoder_input_sizes) 91 | decoders = nn.ModuleList() 92 | for input_size, output_size in zip(decoder_input_sizes, decoder_output_sizes): 93 | decoders.append( 94 | self.UpsampleLayer( 95 | input_size if self.skip_type == "sum" else 2 * input_size, 96 | output_size, 97 | kernel_size=self.kernel_size, 98 | padding=self.kernel_size // 2, 99 | norm=self.norm, 100 | ) 101 | ) 102 | return decoders 103 | 104 | def build_prediction_layer(self, num_output_channels, norm=None): 105 | return ConvLayer( 106 | self.base_num_channels if self.skip_type == "sum" else 2 * self.base_num_channels, 107 | num_output_channels, 108 | 1, 109 | activation=None, 110 | norm=norm, 111 | ) 112 | 113 | 114 | class UNetRecurrent(BaseUNet): 115 | """ 116 | Recurrent UNet architecture where every encoder is followed by a recurrent convolutional block. 117 | Symmetric, skip connections on every encoding layer. 
118 | """ 119 | 120 | def __init__(self, unet_kwargs): 121 | final_activation = unet_kwargs.pop("final_activation", "none") 122 | self.final_activation = getattr(torch, final_activation, None) 123 | unet_kwargs["num_output_channels"] = 1 124 | super().__init__(**unet_kwargs) 125 | 126 | self.head = ConvLayer( 127 | self.num_bins, 128 | self.base_num_channels, 129 | kernel_size=self.kernel_size, 130 | stride=1, 131 | padding=self.kernel_size // 2, 132 | ) 133 | 134 | self.encoders = self.build_recurrent_encoders() 135 | self.resblocks = self.build_resblocks() 136 | self.decoders = self.build_decoders() 137 | self.pred = self.build_prediction_layer(self.num_output_channels, self.norm) 138 | self.states = [None] * self.num_encoders 139 | 140 | def build_recurrent_encoders(self): 141 | encoders = nn.ModuleList() 142 | for input_size, output_size in zip(self.encoder_input_sizes, self.encoder_output_sizes): 143 | encoders.append( 144 | RecurrentConvLayer( 145 | input_size, 146 | output_size, 147 | kernel_size=self.kernel_size, 148 | stride=2, 149 | padding=self.kernel_size // 2, 150 | recurrent_block_type=self.recurrent_block_type, 151 | norm=self.norm, 152 | ) 153 | ) 154 | return encoders 155 | 156 | def forward(self, x): 157 | """ 158 | :param x: N x num_input_channels x H x W 159 | :return: N x num_output_channels x H x W 160 | """ 161 | 162 | # head 163 | x = self.head(x) 164 | head = x 165 | 166 | # encoder 167 | blocks = [] 168 | for i, encoder in enumerate(self.encoders): 169 | x, state = encoder(x, self.states[i]) 170 | blocks.append(x) 171 | self.states[i] = state 172 | 173 | # residual blocks 174 | for resblock in self.resblocks: 175 | x = resblock(x) 176 | 177 | # decoder 178 | for i, decoder in enumerate(self.decoders): 179 | x = decoder(self.skip_ftn(x, blocks[self.num_encoders - i - 1])) 180 | 181 | # tail 182 | img = self.pred(self.skip_ftn(x, head)) 183 | if self.final_activation is not None: 184 | img = self.final_activation(img) 185 | return img 186 | 187 | 188 | class MultiResUNet(BaseUNet): 189 | """ 190 | Conventional UNet architecture. 191 | Symmetric, skip connections on every encoding layer. 192 | Predictions at each decoding layer. 193 | Predictions are added as skip connection (concat) to the input of the subsequent layer. 
194 | """ 195 | 196 | def __init__(self, unet_kwargs): 197 | self.final_activation = unet_kwargs.pop("final_activation", "none") 198 | self.skip_type = "concat" 199 | super().__init__(**unet_kwargs) 200 | 201 | self.encoders = self.build_encoders() 202 | self.resblocks = self.build_resblocks() 203 | self.decoders = self.build_multires_prediction_decoders() 204 | self.preds = self.build_multires_prediction_layer() 205 | 206 | def build_encoders(self): 207 | encoders = nn.ModuleList() 208 | for i, (input_size, output_size) in enumerate(zip(self.encoder_input_sizes, self.encoder_output_sizes)): 209 | if i == 0: 210 | input_size = self.num_bins 211 | encoders.append( 212 | ConvLayer( 213 | input_size, 214 | output_size, 215 | kernel_size=self.kernel_size, 216 | stride=2, 217 | padding=self.kernel_size // 2, 218 | norm=self.norm, 219 | ) 220 | ) 221 | return encoders 222 | 223 | def build_multires_prediction_layer(self): 224 | preds = nn.ModuleList() 225 | decoder_output_sizes = reversed(self.encoder_input_sizes) 226 | for output_size in decoder_output_sizes: 227 | preds.append( 228 | ConvLayer(output_size, self.num_output_channels, 1, activation=self.final_activation, norm=self.norm) 229 | ) 230 | return preds 231 | 232 | def build_multires_prediction_decoders(self): 233 | decoder_input_sizes = reversed(self.encoder_output_sizes) 234 | decoder_output_sizes = reversed(self.encoder_input_sizes) 235 | decoders = nn.ModuleList() 236 | for i, (input_size, output_size) in enumerate(zip(decoder_input_sizes, decoder_output_sizes)): 237 | prediction_channels = 0 if i == 0 else self.num_output_channels 238 | decoders.append( 239 | self.UpsampleLayer( 240 | 2 * input_size + prediction_channels, 241 | output_size, 242 | kernel_size=self.kernel_size, 243 | padding=self.kernel_size // 2, 244 | norm=self.norm, 245 | ) 246 | ) 247 | return decoders 248 | 249 | def forward(self, x): 250 | """ 251 | :param x: N x num_input_channels x H x W 252 | :return: N x num_output_channels x H x W 253 | """ 254 | 255 | # encoder 256 | blocks = [] 257 | for i, encoder in enumerate(self.encoders): 258 | x = encoder(x) 259 | blocks.append(x) 260 | 261 | # residual blocks 262 | for resblock in self.resblocks: 263 | x = resblock(x) 264 | 265 | # decoder and multires predictions 266 | predictions = [] 267 | for i, (decoder, pred) in enumerate(zip(self.decoders, self.preds)): 268 | x = self.skip_ftn(x, blocks[self.num_encoders - i - 1]) 269 | if i > 0: 270 | x = self.skip_ftn(predictions[-1], x) 271 | x = decoder(x) 272 | predictions.append(pred(x)) 273 | 274 | return predictions 275 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | include = '\.pyi?$' 4 | exclude = ''' 5 | /( 6 | \.git 7 | | \.hg 8 | | \.mypy_cache 9 | | \.tox 10 | | \.venv 11 | | _build 12 | | buck-out 13 | | build 14 | | dist 15 | )/ 16 | ''' -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.7.0 2 | PyYAML==5.3.1 3 | numpy==1.19.2 4 | h5py==2.10.0 5 | opencv-python==4.4.0.44 6 | pre-commit 7 | matplotlib==3.3.2 8 | progress 9 | mlflow -------------------------------------------------------------------------------- /train_flow.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 
| import mlflow 5 | import numpy as np 6 | import torch 7 | from torch.optim import * 8 | 9 | from configs.parser import YAMLParser 10 | from dataloader.h5 import H5Loader 11 | from loss.flow import EventWarping 12 | from models.model import FireFlowNet, EVFlowNet 13 | from utils.utils import load_model, create_model_dir, save_model 14 | from utils.visualization import Visualization 15 | 16 | 17 | def train(args, config_parser): 18 | if not os.path.exists(args.path_models): 19 | os.makedirs(args.path_models) 20 | 21 | # configs 22 | config = config_parser.config 23 | config["vis"]["bars"] = False 24 | 25 | # log config 26 | mlflow.set_experiment(config["experiment"]) 27 | mlflow.start_run() 28 | mlflow.log_params(config) 29 | mlflow.log_param("prev_model", args.prev_model) 30 | config["prev_model"] = args.prev_model 31 | 32 | # initialize settings 33 | device = config_parser.device 34 | kwargs = config_parser.loader_kwargs 35 | 36 | # visualization tool 37 | if config["vis"]["enabled"]: 38 | vis = Visualization(config) 39 | 40 | # loss functions 41 | loss_function = EventWarping(config, device) 42 | 43 | # optical flow settings 44 | num_bins = config["data"]["num_bins"] 45 | model = eval(config["model_flow"]["name"])(config["model_flow"].copy(), num_bins).to(device) 46 | model = load_model(args.prev_model, model, device) 47 | model.train() 48 | 49 | # model directory 50 | path_models = create_model_dir(args.path_models, mlflow.active_run().info.run_id) 51 | mlflow.log_param("trained_model", path_models) 52 | config["trained_model"] = path_models 53 | config_parser.config = config 54 | config_parser.log_config(path_models) 55 | 56 | # data loader 57 | data = H5Loader(config, num_bins) 58 | dataloader = torch.utils.data.DataLoader( 59 | data, 60 | drop_last=True, 61 | batch_size=config["loader"]["batch_size"], 62 | collate_fn=data.custom_collate, 63 | worker_init_fn=config_parser.worker_init_fn, 64 | **kwargs 65 | ) 66 | 67 | # optimizers 68 | optimizer = eval(config["optimizer"]["name"])(model.parameters(), lr=config["optimizer"]["lr"]) 69 | optimizer.zero_grad() 70 | 71 | # simulation variables 72 | loss = 0 73 | train_loss = 0 74 | best_loss = 1.0e6 75 | end_train = False 76 | 77 | # training loop 78 | data.shuffle() 79 | while True: 80 | for inputs in dataloader: 81 | 82 | # check new epoch 83 | if data.seq_num >= len(data.files): 84 | mlflow.log_metric("loss_flow", train_loss / (data.samples + 1), step=data.epoch) 85 | 86 | with torch.no_grad(): 87 | if train_loss / (data.samples + 1) < best_loss: 88 | save_model(path_models, model) 89 | best_loss = train_loss / (data.samples + 1) 90 | 91 | data.epoch += 1 92 | data.samples = 0 93 | train_loss = 0 94 | data.seq_num = data.seq_num % len(data.files) 95 | 96 | # finish training loop 97 | if data.epoch == config["loader"]["n_epochs"]: 98 | end_train = True 99 | 100 | # forward pass 101 | x = model(inputs["inp_voxel"].to(device), inputs["inp_cnt"].to(device)) 102 | 103 | # loss and backward pass 104 | loss = loss_function(x["flow"], inputs["inp_list"].to(device), inputs["inp_pol_mask"].to(device)) 105 | train_loss += loss.item() 106 | loss.backward() 107 | optimizer.step() 108 | optimizer.zero_grad() 109 | 110 | # print training info 111 | if config["vis"]["verbose"]: 112 | print( 113 | "Train Epoch: {:04d} [{:03d}/{:03d} ({:03d}%)] loss: {:.6f}".format( 114 | data.epoch, 115 | data.seq_num, 116 | len(data.files), 117 | int(100 * data.seq_num / len(data.files)), 118 | train_loss / (data.samples + 1), 119 | ), 120 | end="\r", 121 | ) 122 | 
123 | # visualize 124 | with torch.no_grad(): 125 | if config["vis"]["enabled"] and config["loader"]["batch_size"] == 1: 126 | vis.update(inputs, x["flow"][-1], None, None) 127 | 128 | # update number of samples seen by the network 129 | data.samples += config["loader"]["batch_size"] 130 | 131 | if end_train: 132 | break 133 | 134 | mlflow.end_run() 135 | 136 | 137 | if __name__ == "__main__": 138 | parser = argparse.ArgumentParser() 139 | parser.add_argument( 140 | "--config", 141 | default="configs/train_flow.yml", 142 | help="training configuration", 143 | ) 144 | parser.add_argument( 145 | "--path_models", 146 | default="trained_models/", 147 | help="location of trained models", 148 | ) 149 | parser.add_argument( 150 | "--prev_model", 151 | default="", 152 | help="pre-trained model to use as starting point", 153 | ) 154 | args = parser.parse_args() 155 | 156 | # launch training 157 | train(args, YAMLParser(args.config)) 158 | -------------------------------------------------------------------------------- /train_reconstruction.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import mlflow 5 | import numpy as np 6 | import torch 7 | from torch.optim import * 8 | 9 | from configs.parser import YAMLParser 10 | from dataloader.h5 import H5Loader 11 | from loss.flow import EventWarping 12 | from loss.reconstruction import BrightnessConstancy 13 | from models.model import FireFlowNet, EVFlowNet 14 | from models.model import FireNet, E2VID 15 | from utils.utils import load_model, create_model_dir, save_model 16 | from utils.visualization import Visualization 17 | 18 | 19 | def train(args, config_parser): 20 | if not os.path.exists(args.path_models): 21 | os.makedirs(args.path_models) 22 | 23 | # configs 24 | config = config_parser.config 25 | config["vis"]["bars"] = False 26 | 27 | # log config 28 | mlflow.set_experiment(config["experiment"]) 29 | mlflow.start_run() 30 | mlflow.log_params(config) 31 | mlflow.log_param("prev_model", args.prev_model) 32 | config["prev_model"] = args.prev_model 33 | 34 | # initialize settings 35 | device = config_parser.device 36 | kwargs = config_parser.loader_kwargs 37 | num_bins = config["data"]["num_bins"] 38 | 39 | # visualization tool 40 | if config["vis"]["enabled"]: 41 | vis = Visualization(config) 42 | 43 | # data loader 44 | data = H5Loader(config, num_bins) 45 | dataloader = torch.utils.data.DataLoader( 46 | data, 47 | drop_last=True, 48 | batch_size=config["loader"]["batch_size"], 49 | collate_fn=data.custom_collate, 50 | worker_init_fn=config_parser.worker_init_fn, 51 | **kwargs 52 | ) 53 | 54 | # loss functions 55 | loss_function_flow = EventWarping(config, device) 56 | loss_function_reconstruction = BrightnessConstancy(config, device) 57 | 58 | # reconstruction settings 59 | model_reconstruction = eval(config["model_reconstruction"]["name"])( 60 | config["model_reconstruction"].copy(), num_bins 61 | ).to(device) 62 | model_reconstruction = load_model(args.prev_model, model_reconstruction, device) 63 | model_reconstruction.train() 64 | 65 | # optical flow settings 66 | model_flow = eval(config["model_flow"]["name"])(config["model_flow"].copy(), num_bins).to(device) 67 | model_flow = load_model(args.prev_model, model_flow, device) 68 | if config["loss"]["train_flow"]: 69 | model_flow.train() 70 | else: 71 | model_flow.eval() 72 | 73 | # model directory 74 | path_models = create_model_dir(args.path_models, mlflow.active_run().info.run_id) 75 | 
mlflow.log_param("trained_model", path_models) 76 | config_parser.log_config(path_models) 77 | config["trained_model"] = path_models 78 | config_parser.config = config 79 | config_parser.log_config(path_models) 80 | 81 | # optimizers 82 | optimizer_reconstruction = eval(config["optimizer"]["name"])( 83 | model_reconstruction.parameters(), lr=config["optimizer"]["lr"] 84 | ) 85 | optimizer_flow = eval(config["optimizer"]["name"])(model_flow.parameters(), lr=config["optimizer"]["lr"]) 86 | optimizer_reconstruction.zero_grad() 87 | optimizer_flow.zero_grad() 88 | 89 | # simulation variables 90 | seq_length = 0 91 | loss_reconstruction = 0 92 | loss_flow = 0 93 | train_loss_reconstruction = 0 94 | train_loss_flow = 0 95 | best_loss_reconstruction = 1.0e6 96 | best_loss_flow = 1.0e6 97 | end_train = False 98 | 99 | prev_img = None 100 | x_reconstruction = None 101 | 102 | # training loop 103 | data.shuffle() 104 | while True: 105 | for inputs in dataloader: 106 | 107 | if data.new_seq: 108 | seq_length = 0 109 | data.new_seq = False 110 | loss_reconstruction = 0 111 | model_reconstruction.reset_states() 112 | optimizer_reconstruction.zero_grad() 113 | 114 | prev_img = None 115 | x_reconstruction = None 116 | 117 | if data.seq_num >= len(data.files): 118 | mlflow.log_metric( 119 | "loss_reconstruction", train_loss_reconstruction / (data.samples + 1), step=data.epoch 120 | ) 121 | mlflow.log_metric("loss_flow", train_loss_flow / (data.samples + 1), step=data.epoch) 122 | 123 | with torch.no_grad(): 124 | if train_loss_reconstruction / (data.samples + 1) < best_loss_reconstruction: 125 | save_model(path_models, model_reconstruction) 126 | best_loss_reconstruction = train_loss_reconstruction / (data.samples + 1) 127 | if train_loss_flow / (data.samples + 1) < best_loss_flow: 128 | save_model(path_models, model_flow) 129 | best_loss_flow = train_loss_flow / (data.samples + 1) 130 | 131 | data.epoch += 1 132 | data.samples = 0 133 | train_loss_flow = 0 134 | train_loss_reconstruction = 0 135 | data.seq_num = data.seq_num % len(data.files) 136 | 137 | # finish training loop 138 | if data.epoch == config["loader"]["n_epochs"]: 139 | end_train = True 140 | 141 | # forward pass - flow network 142 | x_flow = model_flow(inputs["inp_voxel"].to(device), inputs["inp_cnt"].to(device)) 143 | 144 | # loss and backward pass 145 | if config["loss"]["train_flow"]: 146 | loss_flow = loss_function_flow( 147 | x_flow["flow"], inputs["inp_list"].to(device), inputs["inp_pol_mask"].to(device) 148 | ) 149 | 150 | train_loss_flow += loss_flow.item() 151 | loss_flow.backward() 152 | optimizer_flow.step() 153 | optimizer_flow.zero_grad() 154 | 155 | if x_reconstruction is not None: 156 | 157 | # reconstruction loss - generative model 158 | delta_loss_model = loss_function_reconstruction.generative_model( 159 | x_flow["flow"][0].detach(), x_reconstruction["image"], inputs 160 | ) 161 | loss_reconstruction += delta_loss_model 162 | train_loss_reconstruction += delta_loss_model.item() 163 | 164 | if prev_img is None or "Pause" not in data.batch_augmentation or not data.batch_augmentation["Pause"]: 165 | 166 | # reconstruction loss - regularization 167 | delta_loss_reg = loss_function_reconstruction.regularization(x_reconstruction["image"]) 168 | loss_reconstruction += delta_loss_reg 169 | train_loss_reconstruction += delta_loss_reg.item() 170 | 171 | # update previous image 172 | prev_img = x_reconstruction["image"].detach().clone() 173 | 174 | # forward pass - reconstruction network 175 | x_reconstruction = 
model_reconstruction(inputs["inp_voxel"].to(device)) 176 | data.tc_idx += 1 177 | 178 | # reconstruction loss - temporal constancy 179 | if data.tc_idx >= config["loss"]["reconstruction_tc_idx_threshold"]: 180 | delta_loss_tc = loss_function_reconstruction.temporal_consistency( 181 | x_flow["flow"][0].detach(), prev_img, x_reconstruction["image"] 182 | ) 183 | loss_reconstruction += delta_loss_tc 184 | train_loss_reconstruction += delta_loss_tc.item() 185 | 186 | # update sequence length 187 | seq_length += 1 188 | 189 | # visualize 190 | with torch.no_grad(): 191 | if config["vis"]["enabled"] and config["loader"]["batch_size"] == 1: 192 | vis.update(inputs, x_flow["flow"][-1], None, x_reconstruction["image"]) 193 | 194 | # reconstruction backward pass 195 | if seq_length == config["loss"]["reconstruction_unroll"]: 196 | 197 | if loss_reconstruction != 0: 198 | loss_reconstruction.backward() 199 | optimizer_reconstruction.step() 200 | optimizer_reconstruction.zero_grad() 201 | 202 | seq_length = 0 203 | x_reconstruction = None 204 | loss_reconstruction = 0 205 | 206 | # detach states 207 | model_reconstruction.detach_states() 208 | 209 | # print training info 210 | if config["vis"]["verbose"]: 211 | print( 212 | "Train Epoch: {:04d} [{:03d}/{:03d} ({:03d}%)] Flow loss: {:.6f}, Brightness loss: {:.6f}".format( 213 | data.epoch, 214 | data.seq_num, 215 | len(data.files), 216 | int(100 * data.seq_num / len(data.files)), 217 | train_loss_flow / (data.samples + 1), 218 | train_loss_reconstruction / (data.samples + 1), 219 | ), 220 | end="\r", 221 | ) 222 | 223 | # update number of samples seen by the network 224 | data.samples += config["loader"]["batch_size"] 225 | 226 | if end_train: 227 | break 228 | 229 | mlflow.end_run() 230 | 231 | 232 | if __name__ == "__main__": 233 | parser = argparse.ArgumentParser() 234 | parser.add_argument( 235 | "--config", 236 | default="configs/train_reconstruction.yml", 237 | help="training configuration", 238 | ) 239 | parser.add_argument( 240 | "--path_models", 241 | default="trained_models/", 242 | help="location of trained models", 243 | ) 244 | parser.add_argument( 245 | "--prev_model", 246 | default="", 247 | help="pre-trained model to use as starting point", 248 | ) 249 | args = parser.parse_args() 250 | 251 | # launch training 252 | train(args, YAMLParser(args.config)) 253 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tudelft/ssl_e2vid/ef9d61b476576f60f92029cadde4286c6cf0f734/utils/__init__.py -------------------------------------------------------------------------------- /utils/gradients.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class Sobel(nn.Module): 8 | """ 9 | Computes the spatial gradients of 3D data using Sobel filters. 
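    Example (illustrative shapes; the input is reshaped to (batch * channels, 1, H, W) and the returned
    gradients keep that shape):

        sobel = Sobel(device=torch.device("cpu"))
        gradx, grady = sobel(torch.randn(4, 1, 64, 64))   # each 4 x 1 x 64 x 64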
10 | """ 11 | 12 | def __init__(self, device): 13 | super().__init__() 14 | self.pad = nn.ReplicationPad2d(1) 15 | a = np.zeros((1, 1, 3, 3)) 16 | b = np.zeros((1, 1, 3, 3)) 17 | a[0, :, :, :] = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]) 18 | b[0, :, :, :] = np.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]) 19 | self.a = torch.from_numpy(a).float().to(device) 20 | self.b = torch.from_numpy(b).float().to(device) 21 | 22 | def forward(self, x): 23 | """ 24 | :param x: [batch_size x 1 x H x W] input tensor 25 | :return gradx: [batch_size x 2 x H x W-1] spatial gradient in the x direction 26 | :return grady: [batch_size x 2 x H-1 x W] spatial gradient in the y direction 27 | """ 28 | 29 | x = x.view(-1, 1, x.shape[2], x.shape[3]) # (batch * channels, 1, height, width) 30 | x = self.pad(x) 31 | gradx = F.conv2d(x, self.a, groups=1) / 8 # normalized gradients 32 | grady = F.conv2d(x, self.b, groups=1) / 8 # normalized gradients 33 | return gradx, grady 34 | -------------------------------------------------------------------------------- /utils/iwe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def purge_unfeasible(x, res): 5 | """ 6 | Purge unfeasible event locations by setting their interpolation weights to zero. 7 | :param x: location of motion compensated events 8 | :param res: resolution of the image space 9 | :return masked indices 10 | :return mask for interpolation weights 11 | """ 12 | 13 | mask = torch.ones((x.shape[0], x.shape[1], 1)).to(x.device) 14 | mask_y = (x[:, :, 0:1] < 0) + (x[:, :, 0:1] >= res[0]) 15 | mask_x = (x[:, :, 1:2] < 0) + (x[:, :, 1:2] >= res[1]) 16 | mask[mask_y + mask_x] = 0 17 | return x * mask, mask 18 | 19 | 20 | def get_interpolation(events, flow, tref, res, flow_scaling, round_idx=False): 21 | """ 22 | Warp the input events according to the provided optical flow map and compute the bilinar interpolation 23 | (or rounding) weights to distribute the events to the closes (integer) locations in the image space. 24 | :param events: [batch_size x N x 4] input events (y, x, ts, p) 25 | :param flow: [batch_size x 2 x H x W] optical flow map 26 | :param tref: reference time toward which events are warped 27 | :param res: resolution of the image space 28 | :param flow_scaling: scalar that multiplies the optical flow map 29 | :param round_idx: whether or not to round the event locations instead of doing bilinear interp. 
(default = False) 30 | :return interpolated event indices 31 | :return interpolation weights 32 | """ 33 | 34 | # event propagation 35 | warped_events = events[:, :, 1:3] + (tref - events[:, :, 0:1]) * flow * flow_scaling 36 | 37 | if round_idx: 38 | 39 | # no bilinear interpolation 40 | idx = torch.round(warped_events) 41 | weights = torch.ones(idx.shape).to(events.device) 42 | 43 | else: 44 | 45 | # get scattering indices 46 | top_y = torch.floor(warped_events[:, :, 0:1]) 47 | bot_y = torch.floor(warped_events[:, :, 0:1] + 1) 48 | left_x = torch.floor(warped_events[:, :, 1:2]) 49 | right_x = torch.floor(warped_events[:, :, 1:2] + 1) 50 | 51 | top_left = torch.cat([top_y, left_x], dim=2) 52 | top_right = torch.cat([top_y, right_x], dim=2) 53 | bottom_left = torch.cat([bot_y, left_x], dim=2) 54 | bottom_right = torch.cat([bot_y, right_x], dim=2) 55 | idx = torch.cat([top_left, top_right, bottom_left, bottom_right], dim=1) 56 | 57 | # get scattering interpolation weights 58 | warped_events = torch.cat([warped_events for i in range(4)], dim=1) 59 | zeros = torch.zeros(warped_events.shape).to(events.device) 60 | weights = torch.max(zeros, 1 - torch.abs(warped_events - idx)) 61 | 62 | # purge unfeasible indices 63 | idx, mask = purge_unfeasible(idx, res) 64 | 65 | # make unfeasible weights zero 66 | weights = torch.prod(weights, dim=-1, keepdim=True) * mask # bilinear interpolation 67 | 68 | # prepare indices 69 | idx[:, :, 0] *= res[1] # torch.view is row-major 70 | idx = torch.sum(idx, dim=2, keepdim=True) 71 | 72 | return idx, weights 73 | 74 | 75 | def interpolate(idx, weights, res, polarity_mask=None): 76 | """ 77 | Create an image-like representation of the warped events. 78 | :param idx: [batch_size x N x 1] warped event locations 79 | :param weights: [batch_size x N x 1] interpolation weights for the warped events 80 | :param res: resolution of the image space 81 | :param polarity_mask: [batch_size x N x 2] polarity mask for the warped events (default = None) 82 | :return image of warped events 83 | """ 84 | 85 | if polarity_mask is not None: 86 | weights = weights * polarity_mask 87 | iwe = torch.zeros((idx.shape[0], res[0] * res[1], 1)).to(idx.device) 88 | iwe = iwe.scatter_add_(1, idx.long(), weights) 89 | iwe = iwe.view((idx.shape[0], 1, res[0], res[1])) 90 | return iwe 91 | 92 | 93 | def deblur_events(flow, event_list, res, flow_scaling=128, round_idx=True, polarity_mask=None): 94 | """ 95 | Deblur the input events given an optical flow map. 96 | Event timestamp needs to be normalized between 0 and 1. 97 | :param flow: [batch_size x 2 x H x W] optical flow map 98 | :param events: [batch_size x N x 4] input events (y, x, ts, p) 99 | :param res: resolution of the image space 100 | :param flow_scaling: scalar that multiplies the optical flow map 101 | :param round_idx: whether or not to round the event locations instead of doing bilinear interp. 
91 | 
92 | 
93 | def deblur_events(flow, event_list, res, flow_scaling=128, round_idx=True, polarity_mask=None):
94 |     """
95 |     Deblur the input events given an optical flow map.
96 |     Event timestamps need to be normalized between 0 and 1.
97 |     :param flow: [batch_size x 2 x H x W] optical flow map
98 |     :param event_list: [batch_size x N x 4] input events (ts, y, x, p)
99 |     :param res: resolution of the image space
100 |     :param flow_scaling: scalar that multiplies the optical flow map
101 |     :param round_idx: whether or not to round the event locations instead of doing bilinear interp. (default = True)
102 |     :param polarity_mask: [batch_size x N x 1] polarity mask for the warped events (default = None)
103 |     :return iwe: [batch_size x 1 x H x W] image of warped events
104 |     """
105 | 
106 |     # flow vector per input event
107 |     flow_idx = event_list[:, :, 1:3].clone()
108 |     flow_idx[:, :, 0] *= res[1]  # torch.view is row-major
109 |     flow_idx = torch.sum(flow_idx, dim=2)
110 | 
111 |     # get flow for every event in the list
112 |     flow = flow.view(flow.shape[0], 2, -1)
113 |     event_flowy = torch.gather(flow[:, 1, :], 1, flow_idx.long())  # vertical component
114 |     event_flowx = torch.gather(flow[:, 0, :], 1, flow_idx.long())  # horizontal component
115 |     event_flowy = event_flowy.view(event_flowy.shape[0], event_flowy.shape[1], 1)
116 |     event_flowx = event_flowx.view(event_flowx.shape[0], event_flowx.shape[1], 1)
117 |     event_flow = torch.cat([event_flowy, event_flowx], dim=2)
118 | 
119 |     # interpolate forward
120 |     fw_idx, fw_weights = get_interpolation(event_list, event_flow, 1, res, flow_scaling, round_idx=round_idx)
121 |     if not round_idx and polarity_mask is not None:  # bilinear interp. uses four indices per event
122 |         polarity_mask = torch.cat([polarity_mask for i in range(4)], dim=1)
123 | 
124 |     # image of (forward) warped events
125 |     iwe = interpolate(fw_idx.long(), fw_weights, res, polarity_mask=polarity_mask)
126 | 
127 |     return iwe
128 | 
129 | 
130 | def compute_pol_iwe(flow, event_list, res, pos_mask, neg_mask, flow_scaling=128, round_idx=True):
131 |     """
132 |     Create a per-polarity image of warped events given an optical flow map.
133 |     :param flow: [batch_size x 2 x H x W] optical flow map
134 |     :param event_list: [batch_size x N x 4] input events (ts, y, x, p)
135 |     :param res: resolution of the image space
136 |     :param pos_mask: [batch_size x N x 1] polarity mask for positive events
137 |     :param neg_mask: [batch_size x N x 1] polarity mask for negative events
138 |     :param flow_scaling: scalar that multiplies the optical flow map
139 |     :param round_idx: whether or not to round the event locations instead of doing bilinear interp. (default = True)
140 |     :return iwe: [batch_size x 2 x H x W] image of warped events
141 |     """
142 | 
143 |     iwe_pos = deblur_events(
144 |         flow, event_list, res, flow_scaling=flow_scaling, round_idx=round_idx, polarity_mask=pos_mask
145 |     )
146 |     iwe_neg = deblur_events(
147 |         flow, event_list, res, flow_scaling=flow_scaling, round_idx=round_idx, polarity_mask=neg_mask
148 |     )
149 |     iwe = torch.cat([iwe_pos, iwe_neg], dim=1)
150 | 
151 |     return iwe
152 |
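For a quick sanity check, compute_pol_iwe can be driven with dummy tensors; the polarity masks are simply indicator vectors derived from the event list. A sketch with invented shapes and values (the real inputs come from the dataloader):

import torch

batch_size, n_events, height, width = 1, 500, 180, 240
res = (height, width)

flow_map = torch.zeros(batch_size, 2, height, width)  # dense flow map (x and y channels)
event_list = torch.zeros(batch_size, n_events, 4)     # (ts, y, x, p), ts normalized to [0, 1]
event_list[:, :, 0] = torch.rand(batch_size, n_events)
event_list[:, :, 1] = torch.randint(0, height, (batch_size, n_events))
event_list[:, :, 2] = torch.randint(0, width, (batch_size, n_events))
event_list[:, :, 3] = torch.sign(torch.randn(batch_size, n_events))

pos_mask = (event_list[:, :, 3:4] > 0).float()
neg_mask = (event_list[:, :, 3:4] < 0).float()

iwe = compute_pol_iwe(flow_map, event_list, res, pos_mask, neg_mask, flow_scaling=128, round_idx=True)
# iwe: [1 x 2 x 180 x 240], channel 0 = warped positive events, channel 1 = warped negative events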
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import datetime
3 | 
4 | import torch
5 | 
6 | 
7 | def load_model(model_dir, model, device):
8 |     """
9 |     Load model from file.
10 |     :param model_dir: model directory
11 |     :param model: instance of the model class to be loaded
12 |     :param device: model device
13 |     :return loaded model
14 |     """
15 | 
16 |     if os.path.isfile(model_dir):
17 |         model_loaded = torch.load(model_dir, map_location=device)
18 |         if "state_dict" in model_loaded.keys():
19 |             model_loaded = model_loaded["state_dict"]
20 |         model.load_state_dict(model_loaded)
21 |         print("Model restored from " + model_dir + "\n")
22 | 
23 |     elif os.path.isdir(model_dir):
24 |         model_name = model_dir + model.__class__.__name__
25 | 
26 |         extensions = [".pt", ".pth.tar", ".pwf", "_weights_min.pwf"]  # backwards compatibility
27 |         for ext in extensions:
28 |             if os.path.isfile(model_name + ext):
29 |                 model_name += ext
30 |                 break
31 | 
32 |         if os.path.isfile(model_name):
33 |             model_loaded = torch.load(model_name, map_location=device)
34 |             if "state_dict" in model_loaded.keys():
35 |                 model_loaded = model_loaded["state_dict"]
36 |             model.load_state_dict(model_loaded)
37 |             print("Model restored from " + model_name + "\n")
38 |         else:
39 |             print("No model found at " + model_name + "\n")
40 | 
41 |     return model
42 | 
43 | 
44 | def create_model_dir(path_models, runid):
45 |     """
46 |     Create directory for storing model parameters.
47 |     :param path_models: path in which the model should be stored
48 |     :param runid: MLflow's unique ID of the run
49 |     :return path to generated model directory
50 |     """
51 | 
52 |     now = datetime.datetime.now()
53 | 
54 |     path_models += "model_"
55 |     path_models += "%02d%02d%04d" % (now.day, now.month, now.year)
56 |     path_models += "_%02d%02d%02d_" % (now.hour, now.minute, now.second)
57 |     path_models += runid  # mlflow run ID
58 |     path_models += "/"
59 |     if not os.path.exists(path_models):
60 |         os.makedirs(path_models)
61 |     print("Weights stored at " + path_models + "\n")
62 |     return path_models
63 | 
64 | 
65 | def save_model(path_models, model):
66 |     """
67 |     Overwrite previously saved model with new parameters.
68 |     :param path_models: model directory
69 |     :param model: instance of the model class to be saved
70 |     """
71 | 
72 |     os.system("rm -rf " + path_models + model.__class__.__name__ + ".pt")
73 |     model_name = path_models + model.__class__.__name__ + ".pt"
74 |     torch.save(model.state_dict(), model_name)
75 |
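Together these helpers cover the checkpoint lifecycle used during training: create a run-specific directory, overwrite the single checkpoint stored in it, and later restore it by class name. A hedged sketch with a made-up module, run ID, and paths (the repository's actual models live in models/model.py):

import torch
import torch.nn as nn


class TinyNet(nn.Module):  # hypothetical stand-in for one of the repository's models
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, 3, padding=1)


device = torch.device("cpu")
model = TinyNet().to(device)

path_models = create_model_dir("trained_models/", "0123456789abcdef")  # run ID is illustrative
save_model(path_models, model)                  # writes trained_models/model_<date>_<time>_<runid>/TinyNet.pt
model = load_model(path_models, model, device)  # finds TinyNet.pt by class name and known extensions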
--------------------------------------------------------------------------------
/utils/visualization.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import matplotlib
4 | import numpy as np
5 | 
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | 
10 | 
11 | class Visualization:
12 |     """
13 |     Utility class for the visualization and storage of rendered image-like representations
14 |     of multiple elements of the optical flow estimation and image reconstruction pipeline.
15 |     """
16 | 
17 |     def __init__(self, kwargs, eval_id=-1):
18 |         self.img_idx = 0
19 |         self.px = kwargs["vis"]["px"]
20 |         self.color_scheme = "green_red"  # gray / blue_red / green_red
21 | 
22 |         if eval_id >= 0:
23 |             self.store_dir = kwargs["trained_model"] + "results/"
24 |             self.store_dir = self.store_dir + "eval_" + str(eval_id) + "/"
25 |             if not os.path.exists(self.store_dir):
26 |                 os.makedirs(self.store_dir)
27 |             self.store_file = None
28 | 
29 |     def update(self, inputs, flow, iwe, brightness):
30 |         """
31 |         Live visualization.
32 |         :param inputs: dataloader dictionary
33 |         :param flow: [batch_size x 2 x H x W] optical flow map
34 |         :param iwe: [batch_size x 2 x H x W] image of warped events
35 |         :param brightness: [batch_size x 1 x H x W] reconstructed image
36 |         """
37 | 
38 |         inp_events = inputs["inp_cnt"] if "inp_cnt" in inputs.keys() else None
39 |         inp_frames = inputs["inp_frames"] if "inp_frames" in inputs.keys() else None
40 |         height = inp_events.shape[2]
41 |         width = inp_events.shape[3]
42 | 
43 |         # input events
44 |         inp_events = inp_events.detach()
45 |         inp_events_npy = inp_events.cpu().numpy().transpose(0, 2, 3, 1).reshape((height, width, -1))
46 |         cv2.namedWindow("Input Events", cv2.WINDOW_NORMAL)
47 |         cv2.resizeWindow("Input Events", int(self.px), int(self.px))
48 |         cv2.imshow("Input Events", self.events_to_image(inp_events_npy))
49 | 
50 |         # input frames
51 |         if inp_frames is not None:
52 |             frame_image = np.zeros((height, 2 * width))
53 |             inp_frames = inp_frames.detach()
54 |             inp_frames_npy = inp_frames.cpu().numpy().transpose(0, 2, 3, 1).reshape((height, width, 2))
55 |             frame_image[:height, 0:width] = inp_frames_npy[:, :, 0] / 255.0
56 |             frame_image[:height, width : 2 * width] = inp_frames_npy[:, :, 1] / 255.0
57 |             cv2.namedWindow("Input Frames (Prev/Curr)", cv2.WINDOW_NORMAL)
58 |             cv2.resizeWindow("Input Frames (Prev/Curr)", int(2 * self.px), int(self.px))
59 |             cv2.imshow("Input Frames (Prev/Curr)", frame_image)
60 | 
61 |         # optical flow
62 |         if flow is not None:
63 |             flow = flow.detach()
64 |             flow_npy = flow.cpu().numpy().transpose(0, 2, 3, 1).reshape((height, width, 2))
65 |             flow_npy = self.flow_to_image(flow_npy[:, :, 0], flow_npy[:, :, 1])
66 |             flow_npy = cv2.cvtColor(flow_npy, cv2.COLOR_RGB2BGR)
67 |             cv2.namedWindow("Estimated Flow", cv2.WINDOW_NORMAL)
68 |             cv2.resizeWindow("Estimated Flow", int(self.px), int(self.px))
69 |             cv2.imshow("Estimated Flow", flow_npy)
70 | 
71 |         # image of warped events
72 |         if iwe is not None:
73 |             iwe = iwe.detach()
74 |             iwe_npy = iwe.cpu().numpy().transpose(0, 2, 3, 1).reshape((height, width, 2))
75 |             iwe_npy = self.events_to_image(iwe_npy)
76 |             cv2.namedWindow("Image of Warped Events", cv2.WINDOW_NORMAL)
77 |             cv2.resizeWindow("Image of Warped Events", int(self.px), int(self.px))
78 |             cv2.imshow("Image of Warped Events", iwe_npy)
79 | 
80 |         # reconstructed brightness
81 |         if brightness is not None:
82 |             brightness = brightness.detach()
83 |             brightness_npy = brightness.cpu().numpy().transpose(0, 2, 3, 1).reshape((height, width, 1))
84 |             intensity_npy = brightness_npy.reshape((height, width, 1))
85 |             intensity_image = self.minmax_norm(intensity_npy)
86 |             cv2.namedWindow("Reconstructed Intensity", cv2.WINDOW_NORMAL)
87 |             cv2.resizeWindow("Reconstructed Intensity", int(self.px), int(self.px))
88 |             cv2.imshow("Reconstructed Intensity", intensity_image)
89 | 
90 |         cv2.waitKey(1)
91 |
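update() expects the dataloader dictionary plus whichever network outputs are available (flow, iwe, and brightness may each be None). A hedged sketch with an invented config and random tensors; the real "vis"/"px" setting is presumably defined in the YAML files under configs/:

import torch

kwargs = {"vis": {"px": 400}}  # invented value
vis = Visualization(kwargs)    # eval_id defaults to -1: live visualization only, nothing is stored

inputs = {"inp_cnt": torch.rand(1, 2, 180, 240)}  # per-pixel, per-polarity event counts
flow = torch.randn(1, 2, 180, 240)
iwe = torch.rand(1, 2, 180, 240)
brightness = torch.rand(1, 1, 180, 240)

vis.update(inputs, flow, iwe, brightness)  # opens/refreshes the OpenCV windows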
92 |     def store(self, inputs, flow, iwe, brightness, sequence, ts=None):
93 |         """
94 |         Store rendered images.
95 |         :param inputs: dataloader dictionary
96 |         :param flow: [batch_size x 2 x H x W] optical flow map
97 |         :param iwe: [batch_size x 2 x H x W] image of warped events
98 |         :param brightness: [batch_size x 1 x H x W] reconstructed image
99 |         :param sequence: filename of the event sequence under analysis
100 |         :param ts: timestamp associated with rendered files (default = None)
101 |         """
102 | 
103 |         inp_events = inputs["inp_cnt"] if "inp_cnt" in inputs.keys() else None
104 |         inp_frames = inputs["inp_frames"] if "inp_frames" in inputs.keys() else None
105 |         height = inp_events.shape[2]
106 |         width = inp_events.shape[3]
107 | 
108 |         # check if new sequence
109 |         path_to = self.store_dir + sequence + "/"
110 |         if not os.path.exists(path_to):
111 |             os.makedirs(path_to)
112 |             os.makedirs(path_to + "events/")
113 |             os.makedirs(path_to + "flow/")
114 |             os.makedirs(path_to + "frames/")
115 |             os.makedirs(path_to + "iwe/")
116 |             os.makedirs(path_to + "brightness/")
117 |             if self.store_file is not None:
118 |                 self.store_file.close()
119 |             self.store_file = open(path_to + "timestamps.txt", "w")
120 |             self.img_idx = 0
121 | 
122 |         # input events
123 |         event_image = np.zeros((height, width))
124 |         inp_events = inp_events.detach()
125 |         inp_events_npy = inp_events.cpu().numpy().transpose(0, 2, 3, 1).reshape((height, width, -1))
126 |         event_image = self.events_to_image(inp_events_npy)
127 |         filename = path_to + "events/%09d.png" % self.img_idx
128 |         cv2.imwrite(filename, event_image * 255)
129 | 
130 |         # input frames
131 |         if inp_frames is not None:
132 |             inp_frames = inp_frames.detach()
133 |             inp_frames_npy = inp_frames.cpu().numpy().transpose(0, 2, 3, 1).reshape((height, width, 2))
134 |             filename = path_to + "frames/%09d.png" % self.img_idx
135 |             cv2.imwrite(filename, inp_frames_npy[:, :, 1])
136 | 
137 |         # optical flow
138 |         if flow is not None:
139 |             flow = flow.detach()
140 |             flow_npy = flow.cpu().numpy().transpose(0, 2, 3, 1).reshape((height, width, 2))
141 |             flow_npy = self.flow_to_image(flow_npy[:, :, 0], flow_npy[:, :, 1])
142 |             flow_npy = cv2.cvtColor(flow_npy, cv2.COLOR_RGB2BGR)
143 |             filename = path_to + "flow/%09d.png" % self.img_idx
144 |             cv2.imwrite(filename, flow_npy)
145 | 
146 |         # image of warped events
147 |         if iwe is not None:
148 |             iwe = iwe.detach()
149 |             iwe_npy = iwe.cpu().numpy().transpose(0, 2, 3, 1).reshape((height, width, 2))
150 |             iwe_npy = self.events_to_image(iwe_npy)
151 |             filename = path_to + "iwe/%09d.png" % self.img_idx
152 |             cv2.imwrite(filename, iwe_npy * 255)
153 | 
154 |         # reconstructed brightness
155 |         if brightness is not None:
156 |             brightness = brightness.detach()
157 |             brightness_npy = brightness.cpu().numpy().transpose(0, 2, 3, 1).reshape((height, width, 1))
158 |             intensity_npy = brightness_npy.reshape((height, width, 1))
159 |             intensity_image = self.minmax_norm(intensity_npy)
160 |             filename = path_to + "brightness/%09d.png" % self.img_idx
161 |             cv2.imwrite(filename, intensity_image * 255)
162 | 
163 |         # store timestamps
164 |         if ts is not None:
165 |             self.store_file.write(str(ts) + "\n")
166 |             self.store_file.flush()
167 | 
168 |         self.img_idx += 1
169 |         cv2.waitKey(1)
170 |
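store() mirrors update() but writes each rendered view to a per-sequence folder and keeps a timestamps.txt log. A hedged sketch of how it might be driven during evaluation (the trained-model path, eval_id, and sequence name are all invented):

import torch

kwargs = {"vis": {"px": 400}, "trained_model": "trained_models/model_example/"}
vis = Visualization(kwargs, eval_id=0)  # storage goes to trained_models/model_example/results/eval_0/

inputs = {"inp_cnt": torch.rand(1, 2, 180, 240)}
flow = torch.randn(1, 2, 180, 240)
iwe = torch.rand(1, 2, 180, 240)
brightness = torch.rand(1, 1, 180, 240)

# creates events/, flow/, frames/, iwe/, brightness/ and timestamps.txt under .../eval_0/seq_0/
vis.store(inputs, flow, iwe, brightness, sequence="seq_0", ts=0.01)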
171 |     @staticmethod
172 |     def flow_to_image(flow_x, flow_y):
173 |         """
174 |         Use the optical flow color scheme from the supplementary materials of the paper 'Back to Event
175 |         Basics: Self-Supervised Image Reconstruction for Event Cameras via Photometric Constancy',
176 |         Paredes-Valles et al., CVPR'21.
177 |         :param flow_x: [H x W] horizontal optical flow component
178 |         :param flow_y: [H x W] vertical optical flow component
179 |         :return flow_rgb: [H x W x 3] color-encoded optical flow
180 |         """
181 |         flows = np.stack((flow_x, flow_y), axis=2)
182 |         mag = np.linalg.norm(flows, axis=2)
183 |         min_mag = np.min(mag)
184 |         mag_range = np.max(mag) - min_mag
185 | 
186 |         ang = np.arctan2(flow_y, flow_x) + np.pi
187 |         ang *= 1.0 / np.pi / 2.0
188 | 
189 |         hsv = np.zeros([flow_x.shape[0], flow_x.shape[1], 3])
190 |         hsv[:, :, 0] = ang
191 |         hsv[:, :, 1] = 1.0
192 |         hsv[:, :, 2] = mag - min_mag
193 |         if mag_range != 0.0:
194 |             hsv[:, :, 2] /= mag_range
195 | 
196 |         flow_rgb = matplotlib.colors.hsv_to_rgb(hsv)
197 |         return (255 * flow_rgb).astype(np.uint8)
198 | 
199 |     @staticmethod
200 |     def minmax_norm(x):
201 |         """
202 |         Robust min-max normalization.
203 |         :param x: [H x W x 1]
204 |         :return x: [H x W x 1] normalized x
205 |         """
206 |         den = np.percentile(x, 99) - np.percentile(x, 1)
207 |         if den != 0:
208 |             x = (x - np.percentile(x, 1)) / den
209 |         return np.clip(x, 0, 1)
210 | 
211 |     @staticmethod
212 |     def events_to_image(inp_events, color_scheme="green_red"):
213 |         """
214 |         Visualize the input events.
215 |         :param inp_events: [H x W x 2] per-pixel and per-polarity event count
216 |         :param color_scheme: green_red/gray
217 |         :return event_image: [H x W x 3] color-coded event image
218 |         """
219 |         pos = inp_events[:, :, 0]
220 |         neg = inp_events[:, :, 1]
221 |         pos_max = np.percentile(pos, 99)
222 |         pos_min = np.percentile(pos, 1)
223 |         neg_max = np.percentile(neg, 99)
224 |         neg_min = np.percentile(neg, 1)
225 |         max_count = pos_max if pos_max > neg_max else neg_max
226 | 
227 |         if pos_min != max_count:
228 |             pos = (pos - pos_min) / (max_count - pos_min)
229 |         if neg_min != max_count:
230 |             neg = (neg - neg_min) / (max_count - neg_min)
231 | 
232 |         pos = np.clip(pos, 0, 1)
233 |         neg = np.clip(neg, 0, 1)
234 | 
235 |         event_image = np.ones((inp_events.shape[0], inp_events.shape[1]))
236 |         if color_scheme == "gray":
237 |             event_image *= 0.5
238 |             pos *= 0.5
239 |             neg *= -0.5
240 |             event_image += pos + neg
241 | 
242 |         elif color_scheme == "green_red":
243 |             event_image = np.repeat(event_image[:, :, np.newaxis], 3, axis=2)
244 |             event_image *= 0
245 |             mask_pos = pos > 0
246 |             mask_neg = neg > 0
247 |             mask_not_pos = pos == 0
248 |             mask_not_neg = neg == 0
249 | 
250 |             event_image[:, :, 0][mask_pos] = 0
251 |             event_image[:, :, 1][mask_pos] = pos[mask_pos]
252 |             event_image[:, :, 2][mask_pos * mask_not_neg] = 0
253 |             event_image[:, :, 2][mask_neg] = neg[mask_neg]
254 |             event_image[:, :, 0][mask_neg] = 0
255 |             event_image[:, :, 1][mask_neg * mask_not_pos] = 0
256 | 
257 |         return event_image
258 |
--------------------------------------------------------------------------------
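Finally, a tiny sketch of the two static color-coding helpers on synthetic numpy arrays (values invented). flow_to_image encodes direction as hue and relative magnitude as brightness; events_to_image writes positive counts to the green channel and negative counts to the channel that OpenCV displays as red:

import numpy as np

flow_x = np.array([[1.0, 0.0], [-2.0, 0.0]])  # 2 x 2 horizontal flow
flow_y = np.array([[0.0, -1.0], [0.0, 2.0]])  # 2 x 2 vertical flow
rgb = Visualization.flow_to_image(flow_x, flow_y)  # [2 x 2 x 3] uint8

counts = np.zeros((2, 2, 2))  # per-pixel (positive, negative) event counts
counts[0, 0, 0] = 1.0
counts[1, 1, 1] = 1.0
img = Visualization.events_to_image(counts)  # green pixel at (0, 0), red (in BGR display) at (1, 1)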