├── .gitattributes ├── .gitignore ├── README.md ├── data └── .gitkeep ├── download_dsec_test.py ├── environment.yml └── idn ├── check_submission.py ├── checkpoint ├── id-4x.pt ├── id-8x.pt └── tid.pt ├── config ├── data_loader_base.yaml ├── dataset │ ├── dsec.yaml │ ├── dsec_augmentation.yaml │ ├── dsec_base.yaml │ └── dsec_rec.yaml ├── hydra │ └── custom_hydra.yaml ├── id_eval.yaml ├── id_train.yaml ├── logger_base.yaml ├── model │ ├── id-4x.yaml │ ├── id-8x.yaml │ └── idedeqid.yaml ├── mvsec_train.yaml ├── tid_eval.yaml ├── tid_train.yaml ├── torch_environ_base.yaml └── validation │ ├── co.yaml │ ├── data_loader_val_base.yaml │ ├── dsec_co.yaml │ ├── dsec_test.yaml │ ├── mvsec_day1.yaml │ └── nonrec.yaml ├── eval.py ├── loader ├── loader_dsec.py └── loader_mvsec.py ├── model ├── extractor.py ├── idedeq.py ├── loss.py └── update.py ├── scripts └── format_mvsec │ ├── eval_utils.py │ ├── format_mvsec.py │ └── h5_packager.py ├── tests ├── dsec.py ├── eval.py └── test.py ├── train.py └── utils ├── callbacks.py ├── cb ├── logger.py └── validator.py ├── dsec_utils.py ├── exp_tracker.py ├── helper_functions.py ├── logger.py ├── loss_utils.py ├── model_utils.py ├── mvsec_utils.py ├── retrieval_fn.py ├── torch_environ.py ├── trainer.py ├── transformers.py └── validation.py /.gitattributes: -------------------------------------------------------------------------------- 1 | idn/checkpoint/*.pt filter=lfs diff=lfs merge=lfs -text -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .vscode/ 3 | data/* 4 | !data/.gitkeep -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Lightweight Event-based Optical Flow Estimation via Iterative Deblurring 2 | 3 | Work accepted to 2024 IEEE International Conference on 
Robotics and Automation (ICRA'24) [[paper](https://arxiv.org/abs/2211.13726), [video](https://www.youtube.com/watch?v=1qA1hONS4Sw)]. 4 | 5 | 6 | [Figure: IDNet graphical abstract] 7 | 8 | 9 | ![id-viz](https://github.com/tudelft/idnet/assets/10345786/f6314f3a-7e24-444a-bd28-695267ede7b4) 10 | 11 | If you use this code in an academic context, please cite our work: 12 | 13 | ```bibtex 14 | @INPROCEEDINGS{10610353, 15 | author={Wu, Yilun and Paredes-Vallés, Federico and de Croon, Guido C. H. E.}, 16 | booktitle={2024 IEEE International Conference on Robotics and Automation (ICRA)}, 17 | title={Lightweight Event-based Optical Flow Estimation via Iterative Deblurring}, 18 | year={2024}, 19 | volume={}, 20 | number={}, 21 | pages={14708-14715}, 22 | keywords={Image motion analysis;Correlation;Memory management;Estimation;Rendering (computer graphics);Real-time systems;Iterative algorithms}, 23 | doi={10.1109/ICRA57147.2024.10610353}} 24 | ``` 25 | 26 | ## Dependencies 27 | Create the conda environment and install the dependencies by running: 28 | ``` 29 | conda env create --file environment.yml 30 | ``` 31 | 32 | ## Download (For Evaluation) 33 | The DSEC dataset for optical flow can be downloaded [here](https://dsec.ifi.uzh.ch/dsec-datasets/download/). 34 | For convenience, use the script [download_dsec_test.py](download_dsec_test.py). 35 | It downloads the test split into `DATA_DIRECTORY/test`, using the directory structure expected by the data loader.
36 | ``` 37 | python download_dsec_test.py DATA_DIRECTORY 38 | ``` 39 | Once downloaded, make the test split available at `data/test` (the `test_root` set in `idn/config/dataset/dsec_base.yaml`), either by passing `data` as `DATA_DIRECTORY` or by creating a symbolic link: 40 | ``` 41 | ln -s /path/to/DATA_DIRECTORY/test data/test 42 | ``` 43 | 44 | ## Download (For Training) 45 | For training on DSEC, two more folders need to be downloaded: 46 | 47 | - Unzip [train_events.zip](https://download.ifi.uzh.ch/rpg/DSEC/train_coarse/train_events.zip) to `data/train_events` 48 | 49 | - Unzip [train_optical_flow.zip](https://download.ifi.uzh.ch/rpg/DSEC/train_coarse/train_optical_flow.zip) to `data/train_optical_flow` 50 | 51 | Alternatively, create symbolic links under `data/` pointing to these folders. 52 | 53 | ## Download (MVSEC) 54 | To run experiments on MVSEC, additionally download the outdoor day sequence `.h5` files from https://drive.google.com/open?id=1rwyRk26wtWeRgrAx_fgPc-ubUzTFThkV 55 | and place them under `data/`, or create symbolic links under `data/` pointing to them. 56 | 57 | ## Run Evaluation 58 | 59 | To run evaluation: 60 | ``` 61 | cd idnet 62 | conda activate IDNet 63 | python -m idn.eval 64 | ``` 65 | 66 | To change where evaluation results are saved, edit `logger.save_dir` in `idn/config/validation/dsec_test.yaml` (default: `/tmp/collect/XX`). 67 | 68 | To switch between the ID models operating at 1/4 and 1/8 resolution, set the `model` option in `idn/config/id_eval.yaml` to `id-4x` or `id-8x` (default: `id-8x`). 69 | 70 | To evaluate the TID model, change the `@hydra.main` decorator above `main()` in `idn/eval.py` from `config_name="id_eval"` to `config_name="tid_eval"` (the TID variant is already present as a comment). 71 | 72 | At the end of evaluation, a zip file containing the results is created in the save directory; upload it to the DSEC benchmark website to reproduce our results. 73 | 74 | ## Run Training 75 | To train IDNet, run: 76 | ``` 77 | cd idnet 78 | conda activate IDNet 79 | python -m idn.train 80 | ``` 81 | 82 | Similarly, to switch between the id-4x, id-8x and TID models or to train on MVSEC, point the `hydra.main()` decorator in `idn/train.py` to the corresponding config (`id_train.yaml`, `tid_train.yaml` or `mvsec_train.yaml`) and adjust the settings in that `.yaml` file.
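The evaluation output and `idn/check_submission.py` use DSEC's 16-bit PNG flow format, which `flow_16bit_to_float` decodes per flow channel as `(value - 2**15) / 128`. The sketch below shows the inverse packing in NumPy, e.g. for inspecting or producing files in that format yourself. It is an illustration under stated assumptions, not code from this repository: `encode_flow_16bit` and `decode_flow_16bit` are hypothetical helpers, and the third channel is assumed to carry a 0/1 validity mask (`check_submission.py` itself only requires its values to be at most 1). Writing the array to PNG is left out.

```python
from typing import Optional, Tuple

import numpy as np


def encode_flow_16bit(flow: np.ndarray, valid: Optional[np.ndarray] = None) -> np.ndarray:
    """Pack an (H, W, 2) float flow field into an (H, W, 3) uint16 array (hypothetical helper)."""
    assert flow.ndim == 3 and flow.shape[-1] == 2
    h, w, _ = flow.shape
    if valid is None:
        valid = np.ones((h, w), dtype=bool)
    out = np.zeros((h, w, 3), dtype=np.uint16)
    # Fixed point: 1/128 px steps, offset by 2**15 so negative flow fits into uint16.
    packed = np.rint(flow * 128.0 + 2**15)
    out[..., :2] = np.clip(packed, 0, 2**16 - 1).astype(np.uint16)
    # Third channel (assumption): validity mask, 1 = valid, 0 = invalid.
    out[..., 2] = valid.astype(np.uint16)
    return out


def decode_flow_16bit(flow_16bit: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Vectorised form of the decoding in flow_16bit_to_float, with the mask in channel 3."""
    assert flow_16bit.dtype == np.uint16 and flow_16bit.shape[-1] == 3
    valid = flow_16bit[..., 2] == 1
    flow = (flow_16bit[..., :2].astype(np.float64) - 2**15) / 128.0
    # Invalid pixels are reported as zero flow, as flow_16bit_to_float does.
    flow[~valid] = 0.0
    return flow, valid


if __name__ == "__main__":
    flow = np.array([[[1.5, -2.25], [0.0, 100.0]]])  # shape (1, 2, 2)
    packed = encode_flow_16bit(flow)
    print(packed[0, 0].tolist())  # [32960, 32480, 1]
    restored, valid = decode_flow_16bit(packed)
    print(np.allclose(restored, flow), valid.all())  # True True
```

Because of the fixed-point encoding, only flow from -256 px up to just under 256 px is representable, at 1/128-pixel resolution; larger magnitudes are clipped.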
83 | 84 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tudelft/idnet/de142df364cbbbc8b81ed72a4a94cca9fd8adab6/data/.gitkeep -------------------------------------------------------------------------------- /download_dsec_test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | import os 4 | import urllib 5 | import shutil 6 | from typing import Union 7 | 8 | from requests import get 9 | 10 | TEST_SEQUENCES = ['interlaken_00_b', 'interlaken_01_a', 'thun_01_a', 11 | 'thun_01_b', 'zurich_city_12_a', 'zurich_city_14_c', 'zurich_city_15_a'] 12 | BASE_TEST_URL = 'https://download.ifi.uzh.ch/rpg/DSEC/test/' 13 | TEST_FLOW_TIMESTAMPS_URL = 'https://download.ifi.uzh.ch/rpg/DSEC/test_forward_optical_flow_timestamps.zip' 14 | 15 | 16 | def download(url: str, filepath: Path, skip: bool = True) -> bool: 17 | if skip and filepath.exists(): 18 | print(f'{str(filepath)} already exists. Skipping download.') 19 | return True 20 | response = get(url, stream=True)  # stream to disk instead of holding large archives in memory 21 | with open(str(filepath), 'wb') as fl: 22 | fl.writelines(response.iter_content(chunk_size=1 << 20)) 23 | return response.ok 24 | 25 | 26 | def unzip(file_: Path, delete_zip: bool = True, skip: bool = True) -> Path: 27 | assert file_.exists() 28 | assert file_.suffix == '.zip' 29 | output_dir = file_.parent / file_.stem 30 | if skip and output_dir.exists(): 31 | print(f'{str(output_dir)} already exists.
Skipping unzipping operation.') 32 | else: 33 | shutil.unpack_archive(file_, output_dir) 34 | if delete_zip: 35 | os.remove(file_) 36 | return output_dir 37 | 38 | 39 | if __name__ == '__main__': 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument('output_directory') 42 | 43 | args = parser.parse_args() 44 | 45 | output_dir = Path(args.output_directory) 46 | output_dir = output_dir / 'test' 47 | os.makedirs(output_dir, exist_ok=True) 48 | 49 | test_timestamps_file = output_dir / 'test_forward_flow_timestamps.zip' 50 | 51 | assert download(TEST_FLOW_TIMESTAMPS_URL, 52 | test_timestamps_file), TEST_FLOW_TIMESTAMPS_URL 53 | test_timestamps_dir = unzip(test_timestamps_file) 54 | 55 | for seq_name in TEST_SEQUENCES: 56 | seq_path = output_dir / seq_name 57 | os.makedirs(seq_path, exist_ok=True) 58 | 59 | # image timestamps 60 | img_timestamps_url = BASE_TEST_URL + seq_name + \ 61 | '/' + seq_name + '_image_timestamps.txt' 62 | img_timestamps_file = seq_path / 'image_timestamps.txt' 63 | if not img_timestamps_file.exists(): 64 | assert download(img_timestamps_url, 65 | img_timestamps_file), img_timestamps_url 66 | 67 | # test timestamps 68 | test_timestamps_file_destination = seq_path / 'test_forward_flow_timestamps.csv' 69 | if not test_timestamps_file_destination.exists(): 70 | shutil.move(test_timestamps_dir / (seq_name + '.csv'), 71 | test_timestamps_file_destination) 72 | 73 | # event data 74 | events_left_url = BASE_TEST_URL + seq_name + '/' + seq_name + '_events_left.zip' 75 | events_left_file = seq_path / 'events_left.zip' 76 | if not (events_left_file.parent / events_left_file.stem).exists(): 77 | assert download(events_left_url, events_left_file), events_left_url 78 | unzip(events_left_file) 79 | 80 | shutil.rmtree(test_timestamps_dir) 81 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: IDNet 2 | channels: 3 
| - pytorch 4 | - nvidia 5 | dependencies: 6 | - python=3.8 7 | - pytorch=1.13.1 8 | - pytorch-cuda=11.7 9 | - torchvision 10 | - scikit-image 11 | - numba 12 | - pandas 13 | - termcolor 14 | - h5py 15 | - tqdm 16 | - matplotlib 17 | - imageio 18 | - ipykernel 19 | - pip 20 | - pip: 21 | - torchinfo 22 | - hydra-core 23 | - wandb 24 | - gitignore-parser 25 | - opencv-python-headless 26 | - hdf5plugin 27 | - nbformat 28 | - nbstripout 29 | - rerun-sdk 30 | 31 | -------------------------------------------------------------------------------- /idn/check_submission.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import imageio 3 | from typing import Dict 4 | from pathlib import Path 5 | import os 6 | from enum import Enum, auto 7 | import argparse 8 | import sys 9 | version = sys.version_info 10 | assert version[0] >= 3, 'Python 2 is not supported' 11 | assert version[1] >= 6, 'Requires Python 3.6 or higher' 12 | 13 | os.environ['IMAGEIO_USERDIR'] = '/var/tmp' 14 | 15 | imageio.plugins.freeimage.download() 16 | 17 | 18 | class WriteFormat(Enum): 19 | OPENCV = auto() 20 | IMAGEIO = auto() 21 | 22 | 23 | def is_string_swiss(input_str: str) -> bool: 24 | is_swiss = False 25 | is_swiss |= 'thun_' in input_str 26 | is_swiss |= 'interlaken_' in input_str 27 | is_swiss |= 'zurich_city_' in input_str 28 | return is_swiss 29 | 30 | 31 | def flow_16bit_to_float(flow_16bit: np.ndarray, valid_in_3rd_channel: bool): 32 | assert flow_16bit.dtype == np.uint16 33 | assert flow_16bit.ndim == 3 34 | h, w, c = flow_16bit.shape 35 | assert c == 3 36 | 37 | if valid_in_3rd_channel: 38 | valid2D = flow_16bit[..., 2] == 1 39 | assert valid2D.shape == (h, w) 40 | assert np.all(flow_16bit[~valid2D, -1] == 0) 41 | else: 42 | valid2D = np.ones_like(flow_16bit[..., 2], dtype=bool)  # np.bool was removed in NumPy 1.24 43 | valid_map = np.where(valid2D) 44 | 45 | # to actually compute something useful: 46 | flow_16bit = flow_16bit.astype('float') 47 | 48 | flow_map =
np.zeros((h, w, 2)) 49 | flow_map[valid_map[0], valid_map[1], 0] = ( 50 | flow_16bit[valid_map[0], valid_map[1], 0] - 2**15) / 128 51 | flow_map[valid_map[0], valid_map[1], 1] = ( 52 | flow_16bit[valid_map[0], valid_map[1], 1] - 2**15) / 128 53 | return flow_map, valid2D 54 | 55 | 56 | def load_flow(flowfile: Path, valid_in_3rd_channel: bool, write_format: WriteFormat = WriteFormat.IMAGEIO): 57 | assert flowfile.exists() 58 | assert flowfile.suffix == '.png' 59 | 60 | # imageio reading assumes write format was rgb 61 | flow_16bit = imageio.imread(str(flowfile), format='PNG-FI') 62 | if write_format == WriteFormat.OPENCV: 63 | # opencv writes as bgr -> flip last axis to get rgb 64 | flow_16bit = np.flip(flow_16bit, axis=-1) 65 | else: 66 | assert write_format == WriteFormat.IMAGEIO 67 | 68 | channel3 = flow_16bit[..., -1] 69 | assert channel3.max( 70 | ) <= 1, f'Maximum value in last channel should be at most 1: {flowfile}' 71 | flow, valid2D = flow_16bit_to_float(flow_16bit, valid_in_3rd_channel) 72 | return flow, valid2D 73 | 74 | 75 | def list_of_dirs(dirpath: Path): 76 | return next(os.walk(dirpath))[1] 77 | 78 | 79 | def files_per_sequence(flow_timestamps_dir: Path) -> Dict[str, int]: 80 | out_dict = dict() 81 | for entry in flow_timestamps_dir.iterdir(): 82 | assert entry.is_file() 83 | assert entry.suffix == '.csv', entry.suffix 84 | assert is_string_swiss(entry.stem), entry.stem 85 | data = np.loadtxt(entry, dtype=np.int64, delimiter=', ', comments='#') 86 | assert data.ndim == 2, data.ndim 87 | num_files = data.shape[0] 88 | out_dict[entry.stem] = num_files 89 | return out_dict 90 | 91 | 92 | def check_submission(submission_dir: Path, flow_timestamps_dir: Path): 93 | assert flow_timestamps_dir.is_dir() 94 | assert submission_dir.is_dir() 95 | 96 | name2num = files_per_sequence(flow_timestamps_dir) 97 | 98 | expected_flow_shape = (480, 640, 2) 99 | expected_valid_shape = (480, 640) 100 | expected_dir_names = set([*name2num]) 101 | actual_dir_names = set(list_of_dirs(submission_dir)) 102 |
assert expected_dir_names == actual_dir_names, f'Expected directories in your submission: {expected_dir_names}.\nMissing directories: {expected_dir_names.difference(actual_dir_names)}' 103 | 104 | for seq in submission_dir.iterdir(): 105 | if not seq.is_dir(): 106 | continue 107 | assert seq.is_dir() 108 | assert is_string_swiss(seq.name), seq.name 109 | num_files = 0 110 | for prediction in seq.iterdir(): 111 | flow, valid_map = load_flow( 112 | prediction, valid_in_3rd_channel=False, write_format=WriteFormat.IMAGEIO) 113 | assert flow.shape == expected_flow_shape, f'Expected shape: {expected_flow_shape}, actual shape: {flow.shape}' 114 | assert valid_map.shape == expected_valid_shape, f'Expected shape: {expected_valid_shape}, actual shape: {valid_map.shape}' 115 | num_files += 1 116 | assert seq.name in [*name2num], f'{seq.name} not in {[*name2num]}' 117 | assert num_files == name2num[ 118 | seq.name], f'expected {name2num[seq.name]} files in {str(seq)} but found {num_files} files' 119 | 120 | return True 121 | 122 | 123 | if __name__ == '__main__': 124 | parser = argparse.ArgumentParser() 125 | parser.add_argument('submission_dir', help='Path to submission directory') 126 | parser.add_argument('flow_timestamps_dir', 127 | help='Path to directory containing the flow timestamps for evaluation.') 128 | 129 | args = parser.parse_args() 130 | 131 | print('start checking submission') 132 | check_submission(Path(args.submission_dir), Path(args.flow_timestamps_dir)) 133 | print('Your submission directory has the correct structure: Ready to submit!\n') 134 | print('Note that we will sort the files according to their names in each directory and evaluate them sequentially.
Follow the exact naming instructions if you are unsure.') 135 | -------------------------------------------------------------------------------- /idn/checkpoint/id-4x.pt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:e28541f311cc4f52282a1d5321194f2baf1ce91ac5a4aaf4be9e28d525411ff5 3 | size 10219203 4 | -------------------------------------------------------------------------------- /idn/checkpoint/id-8x.pt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:044d107030c4d2f7fdaa3d41bf48d84960384bb6864af9e5639e6577651f2095 3 | size 5731319 4 | -------------------------------------------------------------------------------- /idn/checkpoint/tid.pt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:f1134a4c61c8556a8f53981243895fffe819c64e3ca4c64fdc41943c0ffa617f 3 | size 7551263 4 | -------------------------------------------------------------------------------- /idn/config/data_loader_base.yaml: -------------------------------------------------------------------------------- 1 | data_loader: 2 | train: 3 | gpu: ??? 4 | args: 5 | batch_size: ??? 6 | shuffle: true 7 | num_workers: 2 8 | pin_memory: false 9 | prefetch_factor: 2 10 | persistent_workers: false 11 | val: 12 | gpu: ??? 
13 | mp: false 14 | batch_freq: 1500 15 | args: 16 | batch_size: 1 17 | shuffle: false 18 | num_workers: 1 19 | pin_memory: true 20 | prefetch_factor: 2 -------------------------------------------------------------------------------- /idn/config/dataset/dsec.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dsec_base 3 | 4 | num_voxel_bins: 15 5 | 6 | train: 7 | use_all_seqs: true 8 | seq: [] 9 | load_gt: true 10 | downsample_ratio: 1 11 | val: 12 | seq: [zurich_city_01_a] 13 | downsample_ratio: 1 14 | concat_seq: false 15 | -------------------------------------------------------------------------------- /idn/config/dataset/dsec_augmentation.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dsec 3 | 4 | train: 5 | vertical_flip: 0.1 6 | horizontal_flip: 0.5 7 | random_crop: [288, 384] 8 | 9 | -------------------------------------------------------------------------------- /idn/config/dataset/dsec_base.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: dsec 2 | common: 3 | data_root: data/ 4 | test_root: data/test/ -------------------------------------------------------------------------------- /idn/config/dataset/dsec_rec.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dsec 3 | 4 | train: 5 | recurrent: true 6 | sequence_length: 4 7 | vertical_flip: 0.1 8 | horizontal_flip: 0.5 9 | random_crop: [384, 512] 10 | 11 | val: 12 | recurrent: true 13 | sequence_length: 4 14 | concat_seq: false -------------------------------------------------------------------------------- /idn/config/hydra/custom_hydra.yaml: -------------------------------------------------------------------------------- 1 | run: 2 | dir: . 
3 | output_subdir: null 4 | 5 | defaults: 6 | - override job_logging: none 7 | - override hydra_logging: none 8 | - _self_ -------------------------------------------------------------------------------- /idn/config/id_eval.yaml: -------------------------------------------------------------------------------- 1 | eval_only: True 2 | callbacks: 3 | 4 | defaults: 5 | - hydra: custom_hydra 6 | - data_loader_base 7 | - model: id-8x # id-4x 8 | - dataset: dsec 9 | - torch_environ_base 10 | - _self_ 11 | 12 | -------------------------------------------------------------------------------- /idn/config/id_train.yaml: -------------------------------------------------------------------------------- 1 | deterministic: false 2 | track: false 3 | num_epoch: 100 4 | data_loader: 5 | common: 6 | num_voxel_bins: 15 7 | train: 8 | gpu: 0 9 | args: 10 | batch_size: 3 11 | val: 12 | gpu: 0 13 | batch_freq: 1500 14 | loss: 15 | final_prediction_nonseq: 16 | loss_type: sparse_l1 17 | weight: 1.0 18 | seq_weight: avg 19 | seq_norm: false 20 | 21 | 22 | dataset: 23 | representation_type: voxel 24 | 25 | 26 | optim: 27 | optimizer: adam 28 | scheduler: onecycle 29 | lr: 1e-4 30 | 31 | callbacks: 32 | logger: 33 | enable: 34 | log_keys: 35 | batch_end: 36 | - loss 37 | - lr 38 | 39 | 40 | validator: 41 | enable: 42 | frequency_type: step 43 | frequency: 500 44 | sanity_run_step: 3 45 | 46 | 47 | validation: 48 | nonrec: 49 | dataset: 50 | representation_type: ${dataset.representation_type} 51 | 52 | model: 53 | pretrain_ckpt: null 54 | 55 | defaults: 56 | - validation@_group_.nonrec: nonrec 57 | - hydra: custom_hydra 58 | - data_loader_base 59 | - torch_environ_base 60 | - model: id-8x # id-4x 61 | - dataset: dsec_augmentation 62 | - _self_ 63 | -------------------------------------------------------------------------------- /idn/config/logger_base.yaml: -------------------------------------------------------------------------------- 1 | logger: 2 | saved_tensors: 3 | 4 | statistics: 
['avg'] -------------------------------------------------------------------------------- /idn/config/model/id-4x.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - idedeqid 3 | 4 | hidden_dim: 128 5 | downsample: 4 6 | pretrain_ckpt: idn/checkpoint/id-4x.pt -------------------------------------------------------------------------------- /idn/config/model/id-8x.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - idedeqid 3 | 4 | hidden_dim: 96 5 | downsample: 8 6 | pretrain_ckpt: idn/checkpoint/id-8x.pt -------------------------------------------------------------------------------- /idn/config/model/idedeqid.yaml: -------------------------------------------------------------------------------- 1 | name: IDEDEQIDO 2 | deblur: true 3 | add_delta: true 4 | update_iters: 4 5 | zero_init: true 6 | deq_mode: false 7 | input_flowmap: true 8 | -------------------------------------------------------------------------------- /idn/config/mvsec_train.yaml: -------------------------------------------------------------------------------- 1 | deterministic: false 2 | track: false 3 | num_epoch: 150 4 | run_val: false 5 | 6 | data_loader: 7 | common: 8 | num_voxel_bins: 15 9 | train: 10 | gpu: 0 11 | args: 12 | batch_size: 3 13 | pin_memory: false 14 | val: 15 | gpu: 0 16 | batch_freq: 1500 17 | loss: 18 | final_prediction_nonseq: 19 | loss_type: sparse_l1 20 | weight: 1.0 21 | seq_weight: avg 22 | seq_norm: false 23 | 24 | dataset: 25 | dataset_name: mvsec 26 | num_voxel_bins: 15 27 | train: 28 | 29 | 30 | optim: 31 | optimizer: adam 32 | scheduler: onecycle 33 | lr: 1e-4 34 | 35 | callbacks: 36 | logger: 37 | enable: 38 | log_keys: 39 | batch_end: 40 | - loss 41 | 42 | 43 | validator: 44 | enable: 45 | frequency_type: step 46 | frequency: 500 47 | sanity_run_step: 3 48 | 49 | 50 | model: 51 | pretrain_ckpt: null 52 | 53 | defaults: 54 | - validation@_group_.mvsec_day1: 
mvsec_day1 55 | - hydra: custom_hydra 56 | - data_loader_base 57 | - torch_environ_base 58 | - model: id-8x 59 | - _self_ 60 | -------------------------------------------------------------------------------- /idn/config/tid_eval.yaml: -------------------------------------------------------------------------------- 1 | eval_only: True 2 | callbacks: 3 | 4 | defaults: 5 | - validation@_group_.co: co 6 | - hydra: custom_hydra 7 | - data_loader_base 8 | - torch_environ_base 9 | - model: idedeqid 10 | - dataset: dsec_rec 11 | - _self_ 12 | 13 | model: 14 | name: RecIDE 15 | update_iters: 1 16 | pred_next_flow: true 17 | pretrain_ckpt: idn/checkpoint/tid.pt -------------------------------------------------------------------------------- /idn/config/tid_train.yaml: -------------------------------------------------------------------------------- 1 | deterministic: false 2 | num_epoch: 100 3 | run_val: false 4 | data_loader: 5 | common: 6 | num_voxel_bins: 15 7 | train: 8 | gpu: 0 9 | args: 10 | batch_size: 3 11 | val: 12 | gpu: 0 13 | batch_freq: 1500 14 | loss: 15 | pred_flow_seq: 16 | loss_type: sparse_l1 17 | weight: 1.0 18 | seq_weight: [0.17, 0.21, 0.27, 0.35] 19 | seq_norm: false 20 | pred_flow_next_seq: 21 | loss_type: sparse_l1 22 | weight: 1.0 23 | seq_weight: [0.17, 0.21, 0.27, 0.35] 24 | seq_norm: false 25 | 26 | dataset: 27 | train: 28 | sequence_length: 4 29 | 30 | optim: 31 | optimizer: adam 32 | scheduler: onecycle 33 | lr: 1e-4 34 | 35 | callbacks: 36 | logger: 37 | enable: 38 | log_keys: 39 | batch_end: 40 | - loss 41 | - loss_pred_flow_seq 42 | - loss_pred_flow_next_seq 43 | 44 | 45 | validator: 46 | enable: 47 | frequency_type: step 48 | frequency: 1000 49 | sanity_run_step: 3 50 | 51 | 52 | 53 | 54 | defaults: 55 | - validation@_group_.co: co 56 | - hydra: custom_hydra 57 | - data_loader_base 58 | - torch_environ_base 59 | - model: idedeqid 60 | - dataset: dsec_rec 61 | - _self_ 62 | 63 | model: 64 | name: RecIDE 65 | update_iters: 1 66 | 
pred_next_flow: true 67 | -------------------------------------------------------------------------------- /idn/config/torch_environ_base.yaml: -------------------------------------------------------------------------------- 1 | torch: 2 | debug_grad: false 3 | deterministic: false -------------------------------------------------------------------------------- /idn/config/validation/co.yaml: -------------------------------------------------------------------------------- 1 | name: continuous-operation 2 | val_batch_freq: 1500 3 | data_loader: 4 | gpu: 0 5 | 6 | dataset: 7 | val: 8 | sequence_length: 1 9 | 10 | logger: 11 | saved_tensors: 12 | final_prediction: 13 | flow_gt_event_volume_new: 14 | 15 | metrics: 16 | final_prediction: ["L1", "L2"] 17 | next_flow: ["L1", "L2"] 18 | 19 | postprocess: 20 | 21 | defaults: 22 | - /validation/data_loader_val_base@_here_ 23 | - /dataset/dsec_rec@dataset 24 | - /logger_base@_here_ -------------------------------------------------------------------------------- /idn/config/validation/data_loader_val_base.yaml: -------------------------------------------------------------------------------- 1 | data_loader: 2 | gpu: ??? 
3 | mp: false 4 | batch_freq: 1500 5 | args: 6 | batch_size: 1 7 | shuffle: false 8 | num_workers: 1 9 | pin_memory: true 10 | prefetch_factor: 2 -------------------------------------------------------------------------------- /idn/config/validation/dsec_co.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | val: 3 | sequence_length: 1 4 | recurrent: true 5 | 6 | defaults: 7 | - dsec_test 8 | - _self_ 9 | -------------------------------------------------------------------------------- /idn/config/validation/dsec_test.yaml: -------------------------------------------------------------------------------- 1 | name: dsec-test 2 | data_loader: 3 | gpu: 0 4 | mp: false 5 | batch_freq: 1500 6 | args: 7 | batch_size: 1 8 | shuffle: false 9 | num_workers: 1 10 | pin_memory: true 11 | prefetch_factor: 2 12 | 13 | 14 | logger: 15 | save_dir: /tmp/collect/XX 16 | saved_tensors: 17 | 18 | postprocess: 19 | 20 | metrics: 21 | 22 | hydra: 23 | output_subdir: null 24 | 25 | defaults: 26 | - /dataset/dsec@dataset 27 | - /logger_base@_here_ 28 | - _self_ 29 | -------------------------------------------------------------------------------- /idn/config/validation/mvsec_day1.yaml: -------------------------------------------------------------------------------- 1 | name: mvsec_day1 2 | val_batch_freq: 1500 3 | data_loader: 4 | gpu: 0 5 | args: 6 | batch_size: 1 7 | pin_memory: false 8 | shuffle: false 9 | num_workers: 0 10 | prefetch_factor: 2 11 | 12 | 13 | logger: 14 | saved_tensors: 15 | final_prediction: 16 | flow_gt_event_volume_new: 17 | 18 | dataset: 19 | val: 20 | recurrent: false 21 | 22 | 23 | postprocess: 24 | 25 | metrics: 26 | final_prediction_nonseq: ["L1", "L2", "1PE", "3PE"] 27 | 28 | hydra: 29 | output_subdir: null 30 | 31 | defaults: 32 | - /validation/data_loader_val_base@_here_ 33 | - /logger_base@_here_ 34 | -------------------------------------------------------------------------------- 
/idn/config/validation/nonrec.yaml: -------------------------------------------------------------------------------- 1 | name: non-rec 2 | val_batch_freq: 1500 3 | data_loader: 4 | gpu: 0 5 | 6 | 7 | logger: 8 | saved_tensors: 9 | final_prediction: 10 | flow_gt_event_volume_new: 11 | 12 | postprocess: 13 | 14 | metrics: 15 | final_prediction_nonseq: ["L1", "L2"] 16 | 17 | hydra: 18 | output_subdir: null 19 | 20 | defaults: 21 | - /validation/data_loader_val_base@_here_ 22 | - /dataset/dsec@dataset 23 | - /logger_base@_here_ 24 | -------------------------------------------------------------------------------- /idn/eval.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | import hydra 3 | from hydra import initialize, compose 4 | from .utils.trainer import Trainer 5 | from .utils.validation import Validator 6 | 7 | 8 | 9 | def test(trainer): 10 | test_cfg = compose(config_name="validation/dsec_test", 11 | overrides=[]).validation 12 | Validator.get_test_type("dsec")(test_cfg).execute_test( 13 | trainer.model, save_all=False) 14 | 15 | 16 | def test_co(trainer): 17 | test_cfg = compose(config_name="validation/dsec_co", 18 | overrides=[]).validation 19 | Validator.get_test_type("dsec", "co")( 20 | test_cfg).execute_test(trainer.model, save_all=False) 21 | 22 | 23 | 24 | 25 | # @hydra.main(config_path="config", config_name="tid_eval") 26 | @hydra.main(config_path="config", config_name="id_eval") 27 | 28 | def main(config): 29 | print(OmegaConf.to_yaml(config)) 30 | 31 | trainer = Trainer(config) 32 | 33 | print("Number of parameters: ", sum(p.numel() 34 | for p in trainer.model.parameters() if p.requires_grad)) 35 | 36 | 37 | if config.model.name == "RecIDE": 38 | test_co(trainer) 39 | elif config.model.name == "IDEDEQIDO": 40 | test(trainer) 41 | 42 | 43 | 44 | if __name__ == '__main__': 45 | main() 46 | -------------------------------------------------------------------------------- 
/idn/loader/loader_dsec.py: -------------------------------------------------------------------------------- 1 | import math 2 | from pathlib import Path, PurePath 3 | from random import sample 4 | import random 5 | from typing import Dict, Tuple 6 | import weakref 7 | from time import time 8 | import cv2 9 | # import h5pickle as h5py 10 | import h5py 11 | from numba import jit 12 | import numpy as np 13 | import os 14 | import imageio 15 | import hashlib 16 | import mkl 17 | import torch 18 | from torchvision.transforms import ToTensor, RandomCrop 19 | from torchvision import transforms as tf 20 | from torch.utils.data import Dataset, DataLoader 21 | from matplotlib import pyplot as plt, transforms 22 | from ..utils import transformers 23 | 24 | 25 | from ..utils.dsec_utils import RepresentationType, VoxelGrid, PolarityCount, flow_16bit_to_float 26 | from ..utils.transformers import ( 27 | downsample_spatial, 28 | downsample_spatial_mask, 29 | apply_transform_to_field, 30 | apply_randomcrop_to_sample) 31 | 32 | VISU_INDEX = 1 33 | 34 | 35 | class EventSlicer: 36 | def __init__(self, h5f: h5py.File): 37 | self.h5f = h5f 38 | 39 | self.events = dict() 40 | for dset_str in ['p', 'x', 'y', 't']: 41 | self.events[dset_str] = self.h5f['events/{}'.format(dset_str)] 42 | 43 | # This is the mapping from milliseconds to event index: 44 | # It is defined such that 45 | # (1) t[ms_to_idx[ms]] >= ms*1000 46 | # (2) t[ms_to_idx[ms] - 1] < ms*1000 47 | # where 'ms' is the time in milliseconds and 't' the event timestamps in microseconds.
48 | # 49 | # As an example, given 't' and 'ms': 50 | # t: 0 500 2100 5000 5000 7100 7200 7200 8100 9000 51 | # ms: 0 1 2 3 4 5 6 7 8 9 52 | # 53 | # we get 54 | # 55 | # ms_to_idx: 56 | # 0 2 2 3 3 3 5 5 8 9 57 | self.ms_to_idx = np.asarray(self.h5f['ms_to_idx'], dtype='int64') 58 | 59 | self.t_offset = int(h5f['t_offset'][()]) 60 | self.t_final = int(self.events['t'][-1]) + self.t_offset 61 | 62 | def get_final_time_us(self): 63 | return self.t_final 64 | 65 | def get_events(self, t_start_us: int, t_end_us: int) -> Dict[str, np.ndarray]: 66 | """Get events (p, x, y, t) within the specified time window 67 | Parameters 68 | ---------- 69 | t_start_us: start time in microseconds 70 | t_end_us: end time in microseconds 71 | Returns 72 | ------- 73 | events: dictionary of (p, x, y, t) or None if the time window cannot be retrieved 74 | """ 75 | assert t_start_us < t_end_us 76 | 77 | # We assume that the times are top-off-day, hence subtract offset: 78 | t_start_us -= self.t_offset 79 | t_end_us -= self.t_offset 80 | 81 | t_start_ms, t_end_ms = self.get_conservative_window_ms( 82 | t_start_us, t_end_us) 83 | t_start_ms_idx = self.ms2idx(t_start_ms) 84 | t_end_ms_idx = self.ms2idx(t_end_ms) 85 | 86 | if t_start_ms_idx is None or t_end_ms_idx is None: 87 | # Cannot guarantee window size anymore 88 | return None 89 | 90 | events = dict() 91 | time_array_conservative = np.asarray( 92 | self.events['t'][t_start_ms_idx:t_end_ms_idx]) 93 | idx_start_offset, idx_end_offset = self.get_time_indices_offsets( 94 | time_array_conservative, t_start_us, t_end_us) 95 | t_start_us_idx = t_start_ms_idx + idx_start_offset 96 | t_end_us_idx = t_start_ms_idx + idx_end_offset 97 | # Again add t_offset to get gps time 98 | events['t'] = time_array_conservative[idx_start_offset:idx_end_offset] + self.t_offset 99 | for dset_str in ['p', 'x', 'y']: 100 | events[dset_str] = np.asarray( 101 | self.events[dset_str][t_start_us_idx:t_end_us_idx]) 102 | assert events[dset_str].size == 
events['t'].size 103 | return events 104 | 105 | @staticmethod 106 | def get_conservative_window_ms(ts_start_us: int, ts_end_us) -> Tuple[int, int]: 107 | """Compute a conservative time window of time with millisecond resolution. 108 | We have a time to index mapping for each millisecond. Hence, we need 109 | to compute the lower and upper millisecond to retrieve events. 110 | Parameters 111 | ---------- 112 | ts_start_us: start time in microseconds 113 | ts_end_us: end time in microseconds 114 | Returns 115 | ------- 116 | window_start_ms: conservative start time in milliseconds 117 | window_end_ms: conservative end time in milliseconds 118 | """ 119 | assert ts_end_us > ts_start_us 120 | window_start_ms = math.floor(ts_start_us/1000) 121 | window_end_ms = math.ceil(ts_end_us/1000) 122 | return window_start_ms, window_end_ms 123 | 124 | @staticmethod 125 | @jit(nopython=True) 126 | def get_time_indices_offsets( 127 | time_array: np.ndarray, 128 | time_start_us: int, 129 | time_end_us: int) -> Tuple[int, int]: 130 | """Compute index offset of start and end timestamps in microseconds 131 | Parameters 132 | ---------- 133 | time_array: timestamps (in us) of the events 134 | time_start_us: start timestamp (in us) 135 | time_end_us: end timestamp (in us) 136 | Returns 137 | ------- 138 | idx_start: Index within this array corresponding to time_start_us 139 | idx_end: Index within this array corresponding to time_end_us 140 | such that (in non-edge cases) 141 | time_array[idx_start] >= time_start_us 142 | time_array[idx_end] >= time_end_us 143 | time_array[idx_start - 1] < time_start_us 144 | time_array[idx_end - 1] < time_end_us 145 | this means that 146 | time_start_us <= time_array[idx_start:idx_end] < time_end_us 147 | """ 148 | 149 | assert time_array.ndim == 1 150 | 151 | idx_start = -1 152 | if time_array[-1] < time_start_us: 153 | # This can happen in extreme corner cases. E.g. 
154 | # time_array[0] = 1016 155 | # time_array[-1] = 1984 156 | # time_start_us = 1990 157 | # time_end_us = 2000 158 | 159 | # Return same index twice: array[x:x] is empty. 160 | return time_array.size, time_array.size 161 | else: 162 | for idx_from_start in range(0, time_array.size, 1): 163 | if time_array[idx_from_start] >= time_start_us: 164 | idx_start = idx_from_start 165 | break 166 | assert idx_start >= 0 167 | 168 | idx_end = time_array.size 169 | for idx_from_end in range(time_array.size - 1, -1, -1): 170 | if time_array[idx_from_end] >= time_end_us: 171 | idx_end = idx_from_end 172 | else: 173 | break 174 | 175 | assert time_array[idx_start] >= time_start_us 176 | if idx_end < time_array.size: 177 | assert time_array[idx_end] >= time_end_us 178 | if idx_start > 0: 179 | assert time_array[idx_start - 1] < time_start_us 180 | if idx_end > 0: 181 | assert time_array[idx_end - 1] < time_end_us 182 | return idx_start, idx_end 183 | 184 | def ms2idx(self, time_ms: int) -> int: 185 | assert time_ms >= 0 186 | if time_ms >= self.ms_to_idx.size: 187 | return None 188 | return self.ms_to_idx[time_ms] 189 | 190 | 191 | class Sequence(Dataset): 192 | def __init__(self, seq_path: Path, representation_type: RepresentationType, mode: str = 'test', delta_t_ms: int = 100, 193 | num_bins: int = 15, transforms=None, name_idx=0, visualize=False, load_gt=False): 194 | assert num_bins >= 1 195 | assert delta_t_ms == 100 196 | assert seq_path.is_dir() 197 | assert mode in {'train', 'test'} 198 | ''' 199 | Directory Structure: 200 | 201 | Dataset 202 | └── test 203 | ├── interlaken_00_b 204 | │   ├── events_left 205 | │   │   ├── events.h5 206 | │   │   └── rectify_map.h5 207 | │   ├── image_timestamps.txt 208 | │   └── test_forward_flow_timestamps.csv 209 | 210 | ''' 211 | self.seq_name = PurePath(seq_path).name 212 | self.mode = mode 213 | self.name_idx = name_idx 214 | self.visualize_samples = visualize 215 | self.load_gt = load_gt 216 | self.transforms = transforms if transforms else {} 217 | if
self.mode == "test": 218 | # Get Test Timestamp File 219 | ev_dir_location = seq_path / 'events_left' 220 | timestamp_file = seq_path / 'test_forward_flow_timestamps.csv' 221 | flow_path = seq_path 222 | timestamps_images = np.loadtxt( 223 | flow_path / 'image_timestamps.txt', dtype='int64') 224 | self.indices = np.arange(len(timestamps_images))[::2][1:-1] 225 | self.timestamps_flow = timestamps_images[::2][1:-1] 226 | 227 | elif self.mode == "train": 228 | ev_dir_location = seq_path / 'events' / 'left' 229 | seq_name = seq_path.parts[-1] 230 | flow_path = seq_path.parents[1] / \ 231 | "train_optical_flow"/seq_name/'flow' 232 | timestamp_file = flow_path/'forward_timestamps.txt' 233 | self.flow_png = [Path(os.path.join(flow_path / 'forward', img)) for img in sorted( 234 | os.listdir(flow_path / 'forward'))] 235 | timestamps_images = np.loadtxt( 236 | flow_path / 'forward_timestamps.txt', delimiter=',', dtype='int64') 237 | self.indices = np.arange(len(timestamps_images) - 1) 238 | self.timestamps_flow = timestamps_images[1:, 0] 239 | else: 240 | pass 241 | assert timestamp_file.is_file() 242 | 243 | file = np.genfromtxt( 244 | timestamp_file, 245 | delimiter=',' 246 | ) 247 | 248 | self.idx_to_visualize = file[:, 2] if file.shape[1] == 3 else [] 249 | 250 | # Save output dimensions 251 | self.height = 480 252 | self.width = 640 253 | self.num_bins = num_bins 254 | 255 | # Just for now, we always train with num_bins=15 256 | assert self.num_bins == 15 257 | 258 | # Set event representation 259 | self.voxel_grid = None 260 | if representation_type == RepresentationType.VOXEL: 261 | self.voxel_grid = VoxelGrid( 262 | (self.num_bins, self.height, self.width), normalize=True) 263 | if representation_type == "count": 264 | self.voxel_grid = "count" 265 | if representation_type == "pcount": 266 | self.voxel_grid = PolarityCount((2, self.height, self.width)) 267 | 268 | # Save delta timestamp in us 269 | self.delta_t_us = delta_t_ms * 1000 270 | 271 | # Left events only
272 | ev_data_file = ev_dir_location / 'events.h5' 273 | ev_rect_file = ev_dir_location / 'rectify_map.h5' 274 | 275 | h5f_location = h5py.File(str(ev_data_file), 'r') 276 | self.h5f = h5f_location 277 | self.event_slicer = EventSlicer(h5f_location) 278 | 279 | self.h5rect = h5py.File(str(ev_rect_file), 'r') 280 | self.rectify_ev_map = self.h5rect['rectify_map'][()] 281 | 282 | 283 | def events_to_voxel_grid(self, p, t, x, y, device: str = 'cpu'): 284 | t = (t - t[0]).astype('float32') 285 | t = (t/t[-1]) 286 | x = x.astype('float32') 287 | y = y.astype('float32') 288 | pol = p.astype('float32') 289 | event_data_torch = { 290 | 'p': torch.from_numpy(pol), 291 | 't': torch.from_numpy(t), 292 | 'x': torch.from_numpy(x), 293 | 'y': torch.from_numpy(y), 294 | } 295 | return self.voxel_grid.convert(event_data_torch) 296 | 297 | def getHeightAndWidth(self): 298 | return self.height, self.width 299 | 300 | @staticmethod 301 | def get_disparity_map(filepath: Path): 302 | assert filepath.is_file() 303 | disp_16bit = cv2.imread(str(filepath), cv2.IMREAD_ANYDEPTH) 304 | return disp_16bit.astype('float32')/256 305 | 306 | @staticmethod 307 | def load_flow(flowfile: Path): 308 | assert flowfile.exists() 309 | assert flowfile.suffix == '.png' 310 | flow_16bit = imageio.imread(str(flowfile), format='PNG-FI') 311 | flow, valid2D = flow_16bit_to_float(flow_16bit) 312 | return flow, valid2D 313 | 314 | @staticmethod 315 | def close_callback(h5f): 316 | h5f.close() 317 | 318 | def get_image_width_height(self): 319 | return self.height, self.width 320 | 321 | def __len__(self): 322 | return len(self.timestamps_flow) - 1 323 | 324 | def rectify_events(self, x: np.ndarray, y: np.ndarray): 325 | # assert location in self.locations 326 | # From distorted to undistorted 327 | rectify_map = self.rectify_ev_map 328 | assert rectify_map.shape == ( 329 | self.height, self.width, 2), rectify_map.shape 330 | assert x.max() < self.width 331 | assert y.max() < self.height 332 | return 
rectify_map[y, x] 333 | 334 | def get_data_sample(self, index, crop_window=None, flip=None): 335 | # First entry corresponds to all events BEFORE the flow map 336 | # Second entry corresponds to all events AFTER the flow map (corresponding to the actual fwd flow) 337 | names = ['event_volume_old', 'event_volume_new'] 338 | ts_start = [self.timestamps_flow[index] - 339 | self.delta_t_us, self.timestamps_flow[index]] 340 | ts_end = [self.timestamps_flow[index], 341 | self.timestamps_flow[index] + self.delta_t_us] 342 | 343 | file_index = self.indices[index] 344 | 345 | output = { 346 | 'file_index': file_index, 347 | 'timestamp': self.timestamps_flow[index], 348 | 'seq_name': self.seq_name 349 | } 350 | # Save sample for benchmark submission 351 | output['save_submission'] = file_index in self.idx_to_visualize 352 | output['visualize'] = self.visualize_samples 353 | 354 | for i in range(len(names)): 355 | event_data = self.event_slicer.get_events( 356 | ts_start[i], ts_end[i]) 357 | 358 | p = event_data['p'] 359 | t = event_data['t'] 360 | x = event_data['x'] 361 | y = event_data['y'] 362 | 363 | xy_rect = self.rectify_events(x, y) 364 | x_rect = xy_rect[:, 0] 365 | y_rect = xy_rect[:, 1] 366 | 367 | if crop_window is not None: 368 | # Cropping (+- 2 for safety reasons) 369 | x_mask = (x_rect >= crop_window['start_x']-2) & ( 370 | x_rect < crop_window['start_x']+crop_window['crop_width']+2) 371 | y_mask = (y_rect >= crop_window['start_y']-2) & ( 372 | y_rect < crop_window['start_y']+crop_window['crop_height']+2) 373 | mask_combined = x_mask & y_mask 374 | p = p[mask_combined] 375 | t = t[mask_combined] 376 | x_rect = x_rect[mask_combined] 377 | y_rect = y_rect[mask_combined] 378 | 379 | if self.voxel_grid is None: 380 | raise NotImplementedError 381 | else: 382 | event_representation = self.events_to_voxel_grid( 383 | p, t, x_rect, y_rect) 384 | output[names[i]] = event_representation 385 | output['name_map'] = self.name_idx 386 | 387 | if self.load_gt: 388 | 
output['flow_gt_' + names[i] 389 | ] = [torch.tensor(x) for x in self.load_flow(self.flow_png[index + i])] 390 | 391 | output['flow_gt_' + names[i] 392 | ][0] = torch.moveaxis(output['flow_gt_' + names[i]][0], -1, 0) 393 | output['flow_gt_' + names[i] 394 | ][1] = torch.unsqueeze(output['flow_gt_' + names[i]][1], 0) 395 | 396 | if self.load_gt: 397 | if index + 2 < len(self.flow_png): 398 | output['flow_gt_next'] = [torch.tensor( 399 | x) for x in self.load_flow(self.flow_png[index + 2])] 400 | output['flow_gt_next'][0] = torch.moveaxis( 401 | output['flow_gt_next'][0], -1, 0) 402 | output['flow_gt_next'][1] = torch.unsqueeze( 403 | output['flow_gt_next'][1], 0) 404 | return output 405 | 406 | def __getitem__(self, idx): 407 | sample = self.get_data_sample(idx) 408 | for key_t, transform in self.transforms.items(): 409 | if key_t == "hflip": 410 | if random.random() > 0.5: 411 | for key in sample: 412 | if isinstance(sample[key], torch.Tensor): 413 | sample[key] = tf.functional.hflip(sample[key]) 414 | if key.startswith("flow_gt"): 415 | sample[key] = [tf.functional.hflip( 416 | mask) for mask in sample[key]] 417 | sample[key][0][0, :] = -sample[key][0][0, :] 418 | elif key_t == "vflip": 419 | if random.random() < transform: 420 | for key in sample: 421 | if isinstance(sample[key], torch.Tensor): 422 | sample[key] = tf.functional.vflip(sample[key]) 423 | if key.startswith("flow_gt"): 424 | sample[key] = [tf.functional.vflip( 425 | mask) for mask in sample[key]] 426 | sample[key][0][1, :] = -sample[key][0][1, :] 427 | elif key_t == "randomcrop": 428 | apply_randomcrop_to_sample(sample, crop_size=transform) 429 | else: 430 | apply_transform_to_field(sample, transform, key_t) 431 | 432 | 433 | return sample 434 | 435 | def get_voxel_grid(self, idx): 436 | 437 | if idx == 0: 438 | event_data = self.event_slicer.get_events( 439 | self.timestamps_flow[0] - self.delta_t_us, self.timestamps_flow[0]) 440 | elif idx > 0 and idx <= self.__len__(): 441 | event_data = 
self.event_slicer.get_events( 442 | self.timestamps_flow[idx-1], self.timestamps_flow[idx-1] + self.delta_t_us) 443 | else: 444 | raise IndexError 445 | 446 | p = event_data['p'] 447 | t = event_data['t'] 448 | x = event_data['x'] 449 | y = event_data['y'] 450 | 451 | xy_rect = self.rectify_events(x, y) 452 | x_rect = xy_rect[:, 0] 453 | y_rect = xy_rect[:, 1] 454 | return self.events_to_voxel_grid(p, t, x_rect, y_rect) 455 | 456 | def get_event_count_image(self, ts_start, ts_end, num_bins=15, normalize=True): 457 | assert ts_end > ts_start 458 | delta_t_bin = (ts_end - ts_start) / num_bins 459 | ts_start_bin = np.linspace( 460 | ts_start, ts_end, num=num_bins, endpoint=False) 461 | ts_end_bin = ts_start_bin + delta_t_bin 462 | assert abs(ts_end_bin[-1] - ts_end) < 10. 463 | ts_end_bin[-1] = ts_end 464 | 465 | event_count = torch.zeros( 466 | (num_bins, self.height, self.width), dtype=torch.float, requires_grad=False) 467 | 468 | for i in range(num_bins): 469 | event_data = self.event_slicer.get_events( 470 | ts_start_bin[i], ts_end_bin[i]) 471 | p = event_data['p'] 472 | t = event_data['t'] 473 | x = event_data['x'] 474 | y = event_data['y'] 475 | 476 | t = (t - t[0]).astype('float32') 477 | t = (t/t[-1]) 478 | x = x.astype('float32') 479 | y = y.astype('float32') 480 | pol = p.astype('float32') 481 | event_data_torch = { 482 | 'p': torch.from_numpy(pol), 483 | 't': torch.from_numpy(t), 484 | 'x': torch.from_numpy(x), 485 | 'y': torch.from_numpy(y), 486 | } 487 | x = event_data_torch['x'] 488 | y = event_data_torch['y'] 489 | xy_rect = self.rectify_events(x.int().numpy(), y.int().numpy()) 490 | x_rect = torch.from_numpy(xy_rect[:, 0]).long() 491 | y_rect = torch.from_numpy(xy_rect[:, 1]).long() 492 | value = 2*event_data_torch['p']-1 493 | index = self.width*y_rect + x_rect 494 | mask = (x_rect >= 0) & (x_rect < self.width) & (y_rect >= 0) & (y_rect < self.height) 495 | event_count[i].put_(index[mask], value[mask], accumulate=True) 496 | 497 | return event_count 498 | 499 | @staticmethod 500 | def
normalize_tensor(event_count): 501 | mask = torch.nonzero(event_count, as_tuple=True) 502 | if mask[0].size()[0] > 0: 503 | mean = event_count[mask].mean() 504 | std = event_count[mask].std() 505 | if std > 0: 506 | event_count[mask] = (event_count[mask] - mean) / std 507 | else: 508 | event_count[mask] = event_count[mask] - mean 509 | return event_count 510 | 511 | 512 | class SequenceRecurrent(Sequence): 513 | def __init__(self, seq_path: Path, representation_type: RepresentationType, mode: str = 'test', delta_t_ms: int = 100, 514 | num_bins: int = 15, transforms=None, sequence_length=1, name_idx=0, visualize=False, load_gt=False): 515 | super(SequenceRecurrent, self).__init__(seq_path, representation_type, mode, delta_t_ms, num_bins, transforms=transforms, 516 | name_idx=name_idx, visualize=visualize, load_gt=load_gt) 517 | self.crop_size = self.transforms.get('randomcrop') if isinstance(self.transforms, dict) else None 518 | self.sequence_length = sequence_length 519 | self.valid_indices = self.get_continuous_sequences() 520 | 521 | def get_continuous_sequences(self): 522 | continuous_seq_idcs = [] 523 | if self.sequence_length > 1: 524 | for i in range(len(self.timestamps_flow)-self.sequence_length+1): 525 | diff = self.timestamps_flow[i + 526 | self.sequence_length-1] - self.timestamps_flow[i] 527 | if diff < np.max([100000 * (self.sequence_length-1) + 1000, 101000]): 528 | continuous_seq_idcs.append(i) 529 | else: 530 | for i in range(len(self.timestamps_flow)-1): 531 | diff = self.timestamps_flow[i+1] - self.timestamps_flow[i] 532 | if diff < np.max([100000 * (self.sequence_length-1) + 1000, 101000]): 533 | continuous_seq_idcs.append(i) 534 | return continuous_seq_idcs 535 | 536 | def __len__(self): 537 | return len(self.valid_indices) 538 | 539 | def __getitem__(self, idx): 540 | assert idx >= 0 541 | assert idx < len(self) 542 | 543 | # Valid index is the actual index we want to load, which guarantees a continuous sequence length 544 | valid_idx = self.valid_indices[idx]
545 | 546 | sequence = [] 547 | j = valid_idx 548 | 549 | ts_cur = self.timestamps_flow[j] 550 | # Add first sample 551 | sample = self.get_data_sample(j) 552 | sequence.append(sample) 553 | 554 | # Data augmentation according to first sample 555 | crop_window = None 556 | flip = None 557 | if 'crop_window' in sample.keys(): 558 | crop_window = sample['crop_window'] 559 | if 'flipped' in sample.keys(): 560 | flip = sample['flipped'] 561 | 562 | for i in range(self.sequence_length-1): 563 | j += 1 564 | ts_old = ts_cur 565 | ts_cur = self.timestamps_flow[j] 566 | assert(ts_cur-ts_old < 100000 + 1000) 567 | sample = self.get_data_sample( 568 | j, crop_window=crop_window, flip=flip) 569 | sequence.append(sample) 570 | 571 | # Check if the current sample is the first sample of a continuous sequence 572 | if idx == 0 or self.valid_indices[idx]-self.valid_indices[idx-1] != 1: 573 | sequence[0]['new_sequence'] = 1 574 | print("Timestamp {} is the first one of the next seq!".format( 575 | self.timestamps_flow[self.valid_indices[idx]])) 576 | else: 577 | sequence[0]['new_sequence'] = 0 578 | 579 | # random crop 580 | if self.crop_size is not None: 581 | i, j, h, w = RandomCrop.get_params( 582 | sample["event_volume_old"], output_size=self.crop_size) 583 | keys_to_crop = ["event_volume_old", "event_volume_new", 584 | "flow_gt_event_volume_old", "flow_gt_event_volume_new", 585 | "flow_gt_next",] 586 | 587 | for sample in sequence: 588 | for key, value in sample.items(): 589 | if key in keys_to_crop: 590 | if isinstance(value, torch.Tensor): 591 | sample[key] = tf.functional.crop(value, i, j, h, w) 592 | elif isinstance(value, list) or isinstance(value, tuple): 593 | sample[key] = [tf.functional.crop( 594 | v, i, j, h, w) for v in value] 595 | return sequence 596 | 597 | 598 | class DatasetProvider: 599 | def __init__(self, dataset_path: Path, representation_type: RepresentationType, delta_t_ms: int = 100, num_bins=15, 600 | type='standard', config=None, visualize=False): 601 
| test_path = dataset_path / 'test' 602 | assert dataset_path.is_dir(), str(dataset_path) 603 | assert test_path.is_dir(), str(test_path) 604 | assert delta_t_ms == 100 605 | self.config = config 606 | self.name_mapper_test = [] 607 | 608 | test_sequences = list() 609 | for child in test_path.iterdir(): 610 | self.name_mapper_test.append(child.name) 611 | if type == 'standard': 612 | test_sequences.append(Sequence(child, representation_type, 'test', delta_t_ms, num_bins, 613 | transforms={}, 614 | name_idx=len( 615 | self.name_mapper_test)-1, 616 | visualize=visualize)) 617 | elif type == 'warm_start': 618 | test_sequences.append(SequenceRecurrent(child, representation_type, 'test', delta_t_ms, num_bins, 619 | transforms={}, sequence_length=1, 620 | name_idx=len( 621 | self.name_mapper_test)-1, 622 | visualize=visualize)) 623 | else: 624 | raise Exception( 625 | 'Please provide a valid subtype [standard/warm_start] in config file!') 626 | 627 | self.test_dataset = torch.utils.data.ConcatDataset(test_sequences) 628 | 629 | def get_test_dataset(self): 630 | return self.test_dataset 631 | 632 | def get_name_mapping_test(self): 633 | return self.name_mapper_test 634 | 635 | def summary(self, logger): 636 | logger.write_line( 637 | "================================== Dataloader Summary ====================================", True) 638 | logger.write_line("Loader Type:\t\t" + self.__class__.__name__, True) 639 | logger.write_line("Number of Voxel Bins: {}".format( 640 | self.test_dataset.datasets[0].num_bins), True) 641 | 642 | 643 | def assemble_dsec_sequences(dataset_root, include_seq=None, exclude_seq=None, require_gt=True, config=None, representation_type="voxel", num_bins=None): 644 | if representation_type is None: 645 | representation_type = "voxel" 646 | representation_type = RepresentationType.VOXEL if representation_type == "voxel" else representation_type 647 | event_root = os.path.join(dataset_root, "train_events") 648 | flow_gt_root =
os.path.join(dataset_root, "train_optical_flow") 649 | available_seqs = os.listdir( 650 | flow_gt_root) if require_gt else os.listdir(event_root) 651 | 652 | seqs = available_seqs 653 | if include_seq: 654 | seqs = [seq for seq in seqs if seq in include_seq] 655 | if exclude_seq: 656 | seqs = [seq for seq in seqs if seq not in exclude_seq] 657 | 658 | # Prepare transform list 659 | transforms = dict() 660 | if config.downsample_ratio > 1: 661 | transforms['(?= 0 and p_hflip <= 1 669 | #from torchvision.transforms import RandomHorizontalFlip 670 | # ignore probability of hflip for now, perform when 'hflip' key exists in transforms dict 671 | transforms['hflip'] = None 672 | if config.get("vertical_flip", None): 673 | p_vflip = config.vertical_flip 674 | assert p_vflip >= 0 and p_vflip <= 1 675 | transforms['vflip'] = p_vflip 676 | if config.get("random_crop", None): 677 | crop_size = config.random_crop 678 | transforms['randomcrop'] = crop_size 679 | 680 | seq_dataset = [] 681 | for seq in seqs: 682 | dataset_cls = SequenceRecurrent if hasattr( 683 | config, "recurrent") and config.recurrent else Sequence 684 | extra_arg = dict( 685 | sequence_length=config.sequence_length) if dataset_cls == SequenceRecurrent else dict() 686 | 687 | seq_dataset.append(dataset_cls(Path(event_root) / seq, 688 | representation_type=representation_type, mode="train", 689 | load_gt=require_gt, transforms=transforms, **extra_arg)) 690 | if config.get("concat_seq", True): 691 | return torch.utils.data.ConcatDataset(seq_dataset) 692 | else: 693 | return seq_dataset 694 | 695 | 696 | def assemble_dsec_test_set(test_set_root, seq_len=None, concat_seq=False, representation_type=None): 697 | if representation_type is None: 698 | representation_type = RepresentationType.VOXEL 699 | print("dsec test uses representation: voxel") 700 | elif representation_type == "voxel": 701 | representation_type = RepresentationType.VOXEL 702 | print("dsec test uses representation: voxel") 703 | else: 704 | 
print("dsec test uses representation: {}".format(representation_type)) 705 | # custom representations (e.g. "count", "pcount") are passed through unchanged 706 | test_seqs = os.listdir(test_set_root) 707 | seqs = [] 708 | 709 | transforms = dict() 710 | for seq in test_seqs: 711 | dataset_cls = SequenceRecurrent if seq_len else Sequence 712 | extra_arg = dict( 713 | sequence_length=seq_len) if dataset_cls == SequenceRecurrent else dict() 714 | seqs.append(dataset_cls(Path(test_set_root) / seq, 715 | representation_type, mode='test', 716 | load_gt=False, transforms=transforms, **extra_arg)) 717 | if concat_seq: 718 | return torch.utils.data.ConcatDataset(seqs) 719 | else: 720 | return seqs 721 | 722 | 723 | def assemble_dsec_train_set(train_set_root, flow_gt_root=None, exclude_seq=None, args=None): 724 | train_seqs = os.listdir( 725 | flow_gt_root) if flow_gt_root is not None else os.listdir(train_set_root) 726 | if exclude_seq is not None: 727 | train_seqs = [seq for seq in train_seqs if seq not in exclude_seq] 728 | seq_dataset = [] 729 | for seq in train_seqs: 730 | seq_dataset.append(Sequence(Path(train_set_root) / seq, 731 | RepresentationType.VOXEL, mode='train', 732 | transforms=dict())) 733 | return torch.utils.data.ConcatDataset(seq_dataset) 734 | 735 | 736 | def train_collate(sample_list): 737 | batch = dict() 738 | for field_name in sample_list[0]: 739 | if field_name == 'seq_name': 740 | batch['seq_name'] = [sample[field_name] for sample in sample_list] 741 | if field_name == 'new_sequence': 742 | batch['new_sequence'] = [sample[field_name] 743 | for sample in sample_list] 744 | if field_name.startswith("event_volume"): 745 | batch[field_name] = torch.stack( 746 | [sample[field_name] for sample in sample_list]) 747 | if field_name.startswith("flow_gt"): 748 | if all(field_name in x for x in sample_list): 749 | batch[field_name] = torch.stack( 750 | [sample[field_name][0] for sample in sample_list]) 751 | batch[field_name + '_valid_mask'] = torch.stack( 752 |
sample in sample_list]) 753 | 754 | return batch 755 | 756 | 757 | def rec_train_collate(sample_list): 758 | seq_length = len(sample_list[0]) 759 | seq_of_batch = [] 760 | for i in range(seq_length): 761 | seq_of_batch.append(train_collate( 762 | [sample[i] for sample in sample_list])) 763 | return seq_of_batch 764 | -------------------------------------------------------------------------------- /idn/loader/loader_mvsec.py: -------------------------------------------------------------------------------- 1 | # import h5pickle as h5py 2 | import h5py 3 | import os 4 | import torch 5 | import random 6 | import numpy as np 7 | from torch.utils.data import Dataset 8 | from torchvision import transforms as T 9 | from idn.utils.mvsec_utils import EventSequence 10 | from idn.utils.dsec_utils import RepresentationType, VoxelGrid 11 | from idn.utils.transformers import EventSequenceToVoxelGrid_Pytorch, apply_randomcrop_to_sample 12 | 13 | class MVSEC(Dataset): 14 | def __init__(self, seq_name, seq_path="data/", representation_type=None, \ 15 | rate=20, num_bins=5, transforms=[], filter=None, augment=True, dt=None): 16 | self.seq_name = seq_name 17 | self.dt = dt 18 | if self.dt is None: 19 | self.event_h5 = h5py.File(os.path.join(seq_path, f"{seq_name}_data.hdf5"), "r") 20 | self.event = self.event_h5['davis']['left']['events'] 21 | self.gt_h5 = h5py.File(os.path.join(seq_path, f"{seq_name}_gt.hdf5"), "r") 22 | self.gt_flow = self.gt_h5['davis']['left']['flow_dist'] 23 | self.timestamps = self.gt_h5['davis']['left']['flow_dist_ts'] 24 | else: 25 | assert self.dt == 1 or self.dt == 4 26 | self.h5 = h5py.File(os.path.join(seq_path, f"{seq_name}.h5"), "r") 27 | self.event = self.h5['events'] 28 | self.timestamps = self.h5['flow']['dt={}'.format(self.dt)]['timestamps'][:, 0] 29 | self.gt_flow = list(self.h5['flow']['dt={}'.format(self.dt)].keys()) 30 | self.gt_flow.remove('timestamps') 31 | assert sorted(self.gt_flow) == self.gt_flow 32 | 33 | if representation_type is None: 34 
| self.representation_type = VoxelGrid 35 | else: 36 | self.representation_type = representation_type 37 | 38 | if filter is not None: 39 | assert isinstance(filter, tuple) and isinstance(filter[0], int)\ 40 | and isinstance(filter[1], int) 41 | self.timestamps = self.timestamps[slice(*filter)] 42 | self.gt_flow = self.gt_flow[slice(*filter)] 43 | 44 | self.raw_gt_len = self.timestamps.shape[0] 45 | self.event_ts_to_idx = self.build_event_idx() 46 | self.voxel = EventSequenceToVoxelGrid_Pytorch( 47 | num_bins=num_bins, 48 | normalize=True, 49 | gpu=False, 50 | ) 51 | self.image_width, self.image_height = 346, 260 52 | self.cropper = T.CenterCrop((256, 256)) 53 | self.augment = augment 54 | pass 55 | 56 | def __len__(self): 57 | return self.raw_gt_len - 2 58 | 59 | def __getitem__(self, idx): 60 | idx += 1 61 | sample = {} 62 | if self.dt is None: 63 | # get events 64 | events = self.event[self.event_ts_to_idx[idx-1]:self.event_ts_to_idx[idx]] 65 | events = events[:, [2, 0, 1, 3]] # make it (t, x, y, p) 66 | sample["event_volume_old"] = \ 67 | self.voxel(EventSequence(events, 68 | params={'width': self.image_width, 69 | 'height': self.image_height}, 70 | timestamp_multiplier=1e6, 71 | convert_to_relative=True, 72 | features=events)) 73 | 74 | # get events 75 | events = self.event[self.event_ts_to_idx[idx]:self.event_ts_to_idx[idx+1]] 76 | events = events[:, [2, 0, 1, 3]] # make it (t, x, y, p) 77 | 78 | sample["event_volume_new"] = \ 79 | self.voxel(EventSequence(events, 80 | params={'width': self.image_width, 81 | 'height': self.image_height}, 82 | timestamp_multiplier=1e6, 83 | convert_to_relative=True, 84 | features = events)) 85 | # get flow 86 | flow = self.gt_flow[idx] # -1 yields the same gt flow as E-RAFT, but likely incorrect 87 | flow_next = self.gt_flow[idx+1] 88 | else: 89 | old_p = self.event['ps'][self.event_ts_to_idx[idx-1]:self.event_ts_to_idx[idx]] 90 | old_t = self.event['ts'][self.event_ts_to_idx[idx-1]:self.event_ts_to_idx[idx]] 91 | old_x = 
self.event['xs'][self.event_ts_to_idx[idx-1]:self.event_ts_to_idx[idx]] 92 | old_y = self.event['ys'][self.event_ts_to_idx[idx-1]:self.event_ts_to_idx[idx]] 93 | 94 | old_events = np.column_stack((old_t, old_x, old_y, old_p)) 95 | sample["event_volume_old"] = \ 96 | self.voxel(EventSequence(old_events, 97 | params={'width': self.image_width, 98 | 'height': self.image_height}, 99 | timestamp_multiplier=1e6, 100 | convert_to_relative=True, 101 | features=old_events)) 102 | 103 | new_p = self.event['ps'][self.event_ts_to_idx[idx]:self.event_ts_to_idx[idx+1]] 104 | new_t = self.event['ts'][self.event_ts_to_idx[idx]:self.event_ts_to_idx[idx+1]] 105 | new_x = self.event['xs'][self.event_ts_to_idx[idx]:self.event_ts_to_idx[idx+1]] 106 | new_y = self.event['ys'][self.event_ts_to_idx[idx]:self.event_ts_to_idx[idx+1]] 107 | 108 | new_events = np.column_stack((new_t, new_x, new_y, new_p)) 109 | sample["event_volume_new"] = \ 110 | self.voxel(EventSequence(new_events, 111 | params={'width': self.image_width, 112 | 'height': self.image_height}, 113 | timestamp_multiplier=1e6, 114 | convert_to_relative=True, 115 | features=new_events)) 116 | 117 | # get flow 118 | flow = np.transpose(self.h5['flow']['dt={}'.format(self.dt)][self.gt_flow[idx]][:], (2, 0, 1)) 119 | flow_next = np.transpose(self.h5['flow']['dt={}'.format(self.dt)][self.gt_flow[idx+1]][:], (2, 0, 1)) 120 | 121 | 122 | sample["flow_gt_event_volume_new"] = self.process_flow_gt(flow) 123 | sample["flow_gt_next"] = self.process_flow_gt(flow_next) 124 | 125 | sample["event_volume_old"] = self.cropper(sample["event_volume_old"]) 126 | sample["event_volume_new"] = self.cropper(sample["event_volume_new"]) 127 | 128 | if self.augment: 129 | # augmentation 130 | if random.random() > 0.5: 131 | for key in sample: 132 | if isinstance(sample[key], torch.Tensor): 133 | sample[key] = T.functional.hflip(sample[key]) 134 | if key.startswith("flow_gt"): 135 | sample[key] = [T.functional.hflip( 136 | mask) for mask in sample[key]] 137 
| sample[key][0][0, :] = -sample[key][0][0, :] 138 | 139 | 140 | return sample 141 | 142 | def process_flow_gt(self, flow): 143 | flow_valid = (flow[0] != 0) | (flow[1] != 0) 144 | flow_valid[193:, :] = False 145 | flow = torch.from_numpy(flow) 146 | valid_mask = torch.from_numpy( 147 | np.stack([flow_valid]*1, axis=0)) 148 | 149 | return (self.cropper(flow), self.cropper(valid_mask)) 150 | 151 | def build_event_idx(self): 152 | if self.dt is None: 153 | events_ts = self.event_h5['davis']['left']['events'][:, 2] 154 | else: 155 | events_ts = self.h5['events']['ts'] 156 | return np.searchsorted(events_ts, self.timestamps, side='left') 157 | 158 | 159 | class MVSECRecurrent(MVSEC): 160 | def __init__(self, seq_name, seq_path="/scratch", representation_type=None, 161 | rate=20, num_bins=15, transforms=[], filter=None, augment=True, sequence_length=1, dt=None): 162 | super(MVSECRecurrent, self).__init__(seq_name, seq_path, representation_type, 163 | rate, num_bins, transforms, filter, augment, dt) 164 | self.sequence_length = sequence_length 165 | self.valid_indices = self.get_continuous_sequences() 166 | 167 | def get_continuous_sequences(self): 168 | # MVSEC is continuous without breaks 169 | continuous_seq_idcs = list( 170 | (range(self.raw_gt_len - 2 - self.sequence_length))) 171 | return continuous_seq_idcs 172 | 173 | def __len__(self): 174 | return len(self.valid_indices) 175 | 176 | def __getitem__(self, idx): 177 | assert idx >= 0 178 | assert idx < len(self) 179 | 180 | valid_idx = self.valid_indices[idx] 181 | sequence = [] 182 | j = valid_idx 183 | 184 | for i in range(self.sequence_length): 185 | sample = super(MVSECRecurrent, self).__getitem__(j) 186 | sequence.append(sample) 187 | j += 1 188 | 189 | 190 | # Check if the current sample is the first sample of a continuous sequence 191 | if idx == 0 or self.valid_indices[idx]-self.valid_indices[idx-1] != 1: 192 | sequence[0]['new_sequence'] = 1 193 | else: 194 | sequence[0]['new_sequence'] = 0 195 | 196 | return
sequence 197 | 198 | -------------------------------------------------------------------------------- /idn/model/extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ResidualBlock(nn.Module): 6 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 7 | super(ResidualBlock, self).__init__() 8 | 9 | self.conv1 = nn.Conv2d( 10 | in_planes, planes, kernel_size=3, padding=1, stride=stride) 11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 12 | self.relu = nn.ReLU(inplace=True) 13 | 14 | num_groups = planes // 8 15 | 16 | if norm_fn == 'group': 17 | self.norm1 = nn.GroupNorm( 18 | num_groups=num_groups, num_channels=planes) 19 | self.norm2 = nn.GroupNorm( 20 | num_groups=num_groups, num_channels=planes) 21 | if not stride == 1: 22 | self.norm3 = nn.GroupNorm( 23 | num_groups=num_groups, num_channels=planes) 24 | 25 | elif norm_fn == 'batch': 26 | self.norm1 = nn.BatchNorm2d(planes) 27 | self.norm2 = nn.BatchNorm2d(planes) 28 | if not stride == 1: 29 | self.norm3 = nn.BatchNorm2d(planes) 30 | 31 | elif norm_fn == 'instance': 32 | self.norm1 = nn.InstanceNorm2d(planes) 33 | self.norm2 = nn.InstanceNorm2d(planes) 34 | if not stride == 1: 35 | self.norm3 = nn.InstanceNorm2d(planes) 36 | 37 | elif norm_fn == 'none': 38 | self.norm1 = nn.Sequential() 39 | self.norm2 = nn.Sequential() 40 | if not stride == 1: 41 | self.norm3 = nn.Sequential() 42 | 43 | if stride == 1: 44 | self.downsample = None 45 | 46 | else: 47 | self.downsample = nn.Sequential( 48 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 49 | 50 | def forward(self, x): 51 | y = x 52 | y = self.relu(self.norm1(self.conv1(y))) 53 | y = self.relu(self.norm2(self.conv2(y))) 54 | 55 | if self.downsample is not None: 56 | x = self.downsample(x) 57 | 58 | return self.relu(x+y) 59 | 60 | 61 | 62 | 63 | class LiteEncoder(nn.Module): 64 | def __init__(self, 
output_dim=32, stride=2, dropout=0.0, n_first_channels=1): 65 | super(LiteEncoder, self).__init__() 66 | 67 | self.conv1 = nn.Conv2d(n_first_channels, output_dim, 68 | kernel_size=7, stride=2, padding=3) 69 | self.relu1 = nn.ReLU(inplace=True) 70 | 71 | self.in_planes = output_dim 72 | if stride == 2: 73 | self.layer1 = self._make_layer(output_dim, stride=2) 74 | self.layer2 = self._make_layer(output_dim*2, stride=2) 75 | 76 | elif stride == 1: 77 | self.layer1 = self._make_layer(output_dim*2, stride=2) 78 | self.layer2 = self._make_layer(output_dim*2, stride=1) 79 | 80 | else: 81 | raise ValueError('stride must be 1 or 2') 82 | 83 | 84 | self.dropout = None 85 | if dropout > 0: 86 | self.dropout = nn.Dropout2d(p=dropout) 87 | 88 | for m in self.modules(): 89 | if isinstance(m, nn.Conv2d): 90 | nn.init.kaiming_normal_( 91 | m.weight, mode='fan_out', nonlinearity='relu') 92 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 93 | if m.weight is not None: 94 | nn.init.constant_(m.weight, 1) 95 | if m.bias is not None: 96 | nn.init.constant_(m.bias, 0) 97 | 98 | def _make_layer(self, dim, stride=1): 99 | layer1 = ResidualBlock(self.in_planes, dim, 'none', stride=stride) 100 | layer2 = ResidualBlock(dim, dim, 'none', stride=1) 101 | layers = (layer1, layer2) 102 | 103 | self.in_planes = dim 104 | return nn.Sequential(*layers) 105 | 106 | def forward(self, x): 107 | # if input is list, combine batch dimension 108 | is_list = isinstance(x, tuple) or isinstance(x, list) 109 | if is_list: 110 | batch_dim = x[0].shape[0] 111 | x = torch.cat(x, dim=0) 112 | 113 | x = self.conv1(x) 114 | x = self.relu1(x) 115 | x = self.layer1(x) 116 | x = self.layer2(x) 117 | 118 | # x = self.conv2(x) 119 | 120 | if self.training and self.dropout is not None: 121 | x = self.dropout(x) 122 | 123 | if is_list: 124 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 125 | return x 126 | 127 | -------------------------------------------------------------------------------- 
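The feature map that `LiteEncoder` returns has to line up with the `H//downsample, W//downsample` hidden state that `IDEDEQIDO.forward` allocates. The sketch below traces the strided convolutions in `LiteEncoder` to show how that works out, without needing PyTorch. It assumes only the standard `nn.Conv2d` output-size formula with dilation 1. `conv_out` and `lite_encoder_shape` are hypothetical helpers and are not part of the repository.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of nn.Conv2d along one spatial axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def lite_encoder_shape(h: int, w: int, output_dim: int = 32, stride: int = 2):
    """Hypothetical helper: (channels, H, W) produced by LiteEncoder.forward.

    Mirrors extractor.py: a 7x7 / stride-2 / padding-3 stem (conv1), then
    layer1 and layer2, each built by _make_layer from two ResidualBlocks.
    Only the first block of a layer is strided; the second keeps the size.
    """
    if stride == 2:
        layers = [(output_dim, 2), (output_dim * 2, 2)]  # overall /8
    elif stride == 1:
        layers = [(output_dim * 2, 2), (output_dim * 2, 1)]  # overall /4
    else:
        raise ValueError("stride must be 1 or 2")

    h, w = conv_out(h, 7, 2, 3), conv_out(w, 7, 2, 3)  # stem conv1
    for _, s in layers:
        # 3x3 conv with padding 1 and stride s; the 1x1 stride-s shortcut
        # (self.downsample) yields the same size, so x + y lines up.
        h, w = conv_out(h, 3, s, 1), conv_out(w, 3, s, 1)
    return layers[-1][0], h, w


if __name__ == "__main__":
    # fnet in IDEDEQIDO is built with output_dim = input_dim // 2 = 32.
    for s in (2, 1):
        print(s, lite_encoder_shape(480, 640, output_dim=32, stride=s))
    # 2 (64, 60, 80)
    # 1 (64, 120, 160)
```

With `output_dim=32` (the value `IDEDEQIDO` uses for `fnet`), a 480x640 input (the DSEC event-camera resolution) maps to `(64, 60, 80)` for `stride=2` and to `(64, 120, 160)` for `stride=1`. These are exactly `H//8, W//8` and `H//4, W//4`, and the 64 channels match `input_dim` of the `ConvGRU` in `LiteUpdateBlock`. Each strided convolution here computes `ceil(size / stride)`. For heights or widths that are not multiples of the downsample factor, the encoder output and the zero hidden state allocated with `H//self.downsample` therefore stop matching. For example, 481 rows give 61 rather than `481 // 8 = 60`.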
/idn/model/idedeq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.functional import unfold, grid_sample, interpolate 4 | 5 | from .extractor import LiteEncoder 6 | from .update import LiteUpdateBlock 7 | 8 | from math import sqrt 9 | 10 | 11 | class IDEDEQIDO(nn.Module): 12 | def __init__(self, config): 13 | super(IDEDEQIDO, self).__init__() 14 | self.hidden_dim = getattr(config, 'hidden_dim', 96) 15 | self.input_dim = 64 16 | self.downsample = getattr(config, 'downsample', 8) 17 | self.input_flowmap = getattr(config, 'input_flowmap', False) 18 | self.pred_next_flow = getattr(config, 'pred_next_flow', False) 19 | self.fnet = LiteEncoder( 20 | output_dim=self.input_dim//2, dropout=0, n_first_channels=2, stride=2 if self.downsample == 8 else 1) 21 | self.update_net = LiteUpdateBlock( 22 | hidden_dim=self.hidden_dim, input_dim=self.input_dim, 23 | num_outputs=2 if self.pred_next_flow else 1, 24 | downsample=self.downsample) 25 | self.deblur_iters = config.update_iters 26 | self.zero_init = config.zero_init 27 | self._deq = getattr(config, "deq_mode", False) 28 | self.hook = None 29 | self.co_mode = getattr(config, "co_mode", False) 30 | self.conr_mode = getattr(config, "conr_mode", False) 31 | self.deblur = getattr(config, "deblur", True) 32 | self.add_delta = getattr(config, "add_delta", False) 33 | self.deblur_mode = getattr(config, "deblur_mode", "voxel") 34 | self.reset_continuous_flow() 35 | if self.input_flowmap: 36 | self.cnet = LiteEncoder( 37 | output_dim=self.hidden_dim // 2, dropout=0, n_first_channels=2, stride=2 if self.downsample == 8 else 1) 38 | else: 39 | self.cnet = None 40 | 41 | def upsample_flow(self, flow, mask): 42 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 43 | N, _, H, W = flow.shape 44 | _, D, H, W = mask.shape 45 | upsample_ratio = int(sqrt(D/9)) 46 | mask = mask.view(N, 1, 9, upsample_ratio, 
upsample_ratio, H, W) 47 | mask = torch.softmax(mask, dim=2) 48 | 49 | up_flow = unfold(8 * flow, [3, 3], padding=1) 50 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 51 | 52 | up_flow = torch.sum(mask * up_flow, dim=2) 53 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 54 | return up_flow.reshape(N, 2, upsample_ratio*H, upsample_ratio*W) 55 | 56 | @staticmethod 57 | def upflow8(flow, mode='bilinear'): 58 | new_size = (8 * flow.shape[-2], 8 * flow.shape[-1]) 59 | return 8 * interpolate(flow, size=new_size, mode=mode, align_corners=True) 60 | 61 | @staticmethod 62 | def create_identity_grid(H, W, device): 63 | i, j = map(lambda x: x.float(), torch.meshgrid( 64 | [torch.arange(0, H), torch.arange(0, W)], indexing='ij')) 65 | return torch.stack([j, i], dim=-1).to(device) 66 | 67 | def deblur_tensor(self, raw_input, flow, mask=None): 68 | # raw: [N, T, C, H, W] 69 | raw = raw_input.unsqueeze(2) if raw_input.ndim == 4 else raw_input 70 | N, T, C, H, W = raw.shape 71 | deblurred_tensor = torch.zeros_like(raw) 72 | identity_grid = self.create_identity_grid(H, W, raw.device) 73 | for t in range(T): 74 | if self.deblur_mode == "voxel": 75 | delta_p = flow*t/(T-1) 76 | else: 77 | delta_p = flow*((t+0.5)/T) 78 | sampling_grid = identity_grid + torch.movedim(delta_p, 1, -1) 79 | sampling_grid[..., 0] = sampling_grid[..., 0] / (W-1) * 2 - 1 80 | sampling_grid[..., 1] = sampling_grid[..., 1] / (H-1) * 2 - 1 81 | deblurred_tensor[:, t, 82 | ] = grid_sample(raw[:, t, ], sampling_grid, align_corners=False) 83 | if raw_input.ndim == 4: 84 | deblurred_tensor = deblurred_tensor.squeeze(2) 85 | return deblurred_tensor 86 | 87 | def reset_continuous_flow(self, reset=True): 88 | if reset: 89 | self.flow_init = None 90 | self.last_net_co = None 91 | 92 | def forward(self, event_bins, flow_init=None, deblur_iters=None, net_co=None): 93 | if self.co_mode: 94 | # continuous mode 95 | # take flow_init from state and forward propagate it 96 | assert flow_init is None, "flow_init should be None in 
continuous mode" 97 | if self.flow_init is None: 98 | print("No last flow, using zero flow") 99 | elif 'new_sequence' in event_bins and event_bins['new_sequence'][0] == 1: 100 | print("Got new sequence, resetting flow") 101 | flow_init = None 102 | else: 103 | flow_init = self.flow_init 104 | if self.conr_mode: 105 | if self.last_net_co is None: 106 | print("No last net_co, using default hidden state") 107 | elif 'new_sequence' in event_bins and event_bins['new_sequence'][0] == 1: 108 | print("Got new sequence, resetting hidden state") 109 | net_co = None 110 | else: 111 | net_co = self.last_net_co 112 | 113 | deblur_iters = self.deblur_iters if deblur_iters is None else deblur_iters 114 | # x_old, x_new = event_bins["event_volume_old"], event_bins["event_volume_new"] 115 | x_raw = event_bins["event_volume_new"] 116 | 117 | B, V, H, W = x_raw.shape 118 | flow_total = torch.zeros(B, 2, H, W).to( 119 | x_raw.device) if flow_init is None else flow_init.clone() 120 | 121 | delta_flow = flow_total 122 | flow_history = torch.zeros(B, 0, 2, H, W).to(x_raw.device) 123 | x_deblur_history = x_raw.clone().unsqueeze(1) 124 | delta_flow_history = delta_flow.clone().unsqueeze(1) 125 | 126 | 127 | x_deblur = x_raw.clone() 128 | for it in range(deblur_iters): 129 | if self.deblur: 130 | x_deblur = self.deblur_tensor(x_deblur, delta_flow) 131 | x = torch.stack([x_deblur, x_deblur], dim=1) 132 | x_deblur_history = torch.cat( 133 | [x_deblur_history, x_deblur.unsqueeze(1)], dim=1) 134 | else: 135 | x = torch.stack([x_raw, x_raw], dim=1) 136 | 137 | if net_co is not None: 138 | net = net_co 139 | else: 140 | if self.input_flowmap: 141 | assert self.cnet is not None, "cnet not initialized in flowmap mode" 142 | if flow_init is not None or it >= 1: 143 | net = self.cnet(flow_total) 144 | else: 145 | net = torch.zeros( 146 | (B, self.hidden_dim, 147 | H//self.downsample, W//self.downsample)).to(x.device) 148 | else: 149 | if self.cnet is not None: 150 | net = self.cnet(x) 151 | else: 152 | net = 
torch.zeros( 153 | (B, self.hidden_dim,
154 | H//self.downsample, W//self.downsample)).to(x.device) 155 | for i, slice in enumerate(x.permute(2, 0, 1, 3, 4)): 156 | f = self.fnet(slice) 157 | net = self.update_net(net, f) 158 | 159 | dflow = self.update_net.compute_deltaflow(net) 160 | up_mask = self.update_net.compute_up_mask(net) 161 | delta_flow = self.upsample_flow(dflow, up_mask) 162 | delta_flow_history = torch.cat( 163 | [delta_flow_history, delta_flow.unsqueeze(1)], dim=1) 164 | if self.pred_next_flow: 165 | nflow = self.update_net.compute_nextflow(net) 166 | up_mask_next_flow = self.update_net.compute_up_mask2(net) 167 | next_flow = self.upsample_flow(nflow, up_mask_next_flow) 168 | else: 169 | next_flow = None 170 | 171 | if self.deblur or self.add_delta: 172 | flow_total = flow_total + delta_flow 173 | else: 174 | flow_total = delta_flow 175 | flow_history = torch.cat( 176 | [flow_history, flow_total.unsqueeze(1)], dim=1) 177 | 178 | 179 | if self.co_mode: 180 | if self.pred_next_flow: 181 | assert 'next_flow' in locals() 182 | self.flow_init = next_flow 183 | else: 184 | self.flow_init = self.forward_flow(flow_total) 185 | if self.conr_mode: 186 | self.last_net_co = net 187 | return {'final_prediction': flow_total, 188 | 'next_flow': next_flow, 189 | 'delta_flow': delta_flow_history, 190 | 'deblurred_event_volume_new': x_deblur_history, 191 | 'flow_history': flow_history, 192 | 'net': net} 193 | 194 | def forward_flowmap(self, event_bins, flow_init=None, deblur_iters=None): 195 | deblur_iters = self.deblur_iters if deblur_iters is None else deblur_iters 196 | 197 | x = event_bins["event_volume_new"] 198 | 199 | B, V, H, W = x.shape 200 | flow_total = torch.zeros(B, 2, H, W).to( 201 | x.device) if flow_init is None else flow_init.clone() 202 | 203 | delta_flow = flow_total 204 | flow_history = torch.zeros(B, 0, 2, H, W).to(x.device) 205 | x_deblur = x.clone() 206 | for _ in range(deblur_iters): 207 | x_deblur = self.deblur_tensor(x_deblur, delta_flow) 208 | x = torch.stack([x_deblur, 
x_deblur], dim=1) 209 | 210 | if
flow_init is not None and self.cnet is not None: 211 | net = self.cnet(flow_total) 212 | else: 213 | net = torch.zeros( 214 | (B, self.hidden_dim, H//self.downsample, W//self.downsample)).to(x.device) 215 | for i, slice in enumerate(x.permute(2, 0, 1, 3, 4)): 216 | f = self.fnet(slice) 217 | net = self.update_net(net, f) 218 | 219 | dflow = self.update_net.compute_deltaflow(net) 220 | up_mask = self.update_net.compute_up_mask(net) 221 | delta_flow = self.upsample_flow(dflow, up_mask) 222 | flow_total = flow_total + delta_flow 223 | 224 | flow_history = torch.cat( 225 | [flow_history, flow_total.unsqueeze(1)], dim=1) 226 | 227 | 228 | return {'final_prediction': flow_total, 229 | 'flow_history': flow_history} 230 | 231 | 232 | class RecIDE(IDEDEQIDO): 233 | def __init__(self, *args, **kwargs): 234 | super().__init__(*args, **kwargs) 235 | 236 | def forward(self, batch, flow_init=None, deblur_iters=None): 237 | deblur_iters = self.deblur_iters if deblur_iters is None else deblur_iters 238 | 239 | flow_trajectory = [] 240 | flow_next_trajectory = [] 241 | 242 | for t, x in enumerate(batch): 243 | out = super().forward(x, flow_init=flow_init) 244 | flow_pred = out['final_prediction'] 245 | if out.get('next_flow') is not None: 246 | flow_next = out['next_flow'] 247 | flow_init = flow_next 248 | flow_next_trajectory.append(flow_next) 249 | else: 250 | flow_init = self.forward_flow(flow_pred) 251 | 252 | flow_trajectory.append(flow_pred) 253 | 254 | if (t+1) % 4 == 0: 255 | flow_init = flow_init.detach() 256 | yield {'final_prediction': flow_pred, 257 | 'flow_trajectory': flow_trajectory, 258 | 'flow_next_trajectory': flow_next_trajectory, } 259 | flow_trajectory = [] 260 | flow_next_trajectory = [] 261 | 262 | def forward_inference(self, batch, flow_init=None, deblur_iters=None): 263 | deblur_iters = self.deblur_iters if deblur_iters is None else deblur_iters 264 | 265 | 266 | flow_trajectory = [] 267 | 268 | for t, x in enumerate(batch): 269 | out = 
super().forward(x, flow_init=flow_init) 270 | flow_pred =
out['final_prediction'] 271 | flow_init = self.forward_flow(flow_pred) 272 | flow_trajectory.append(flow_pred) 273 | 274 | return {'final_prediction': flow_pred, 275 | 'flow_trajectory': flow_trajectory} 276 | 277 | def backward_neg_flow(self, x): 278 | x["event_volume_new"] = -torch.flip(x["event_volume_new"], [1]) 279 | back_flow = -super().forward(x)['final_prediction'] 280 | return back_flow 281 | -------------------------------------------------------------------------------- /idn/model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def sparse_l1_seq(estimated, ground_truth, valid_mask=None): 5 | assert isinstance(estimated, list) 6 | assert isinstance(ground_truth, list) 7 | if valid_mask is not None: 8 | assert isinstance(valid_mask, list) 9 | assert len(estimated) == len(ground_truth) == len(valid_mask) 10 | loss = 0. 11 | for i in range(len(estimated)): 12 | loss += sparse_l1(estimated[i], ground_truth[i], None if valid_mask is None else valid_mask[i]) 13 | return loss/len(estimated) 14 | 15 | 16 | def sparse_l1(estimated, ground_truth, valid_mask=None): 17 | """Return L1 loss. 18 | This loss ignores differences in pixels where valid_mask is False. 19 | If all pixels are marked as False, the loss is equal to zero. 20 | Args: 21 | estimated: tensor with predicted values of size 22 | batch_size x height x width. 23 | ground_truth: tensor with ground truth values of 24 | size batch_size x height x width. 25 | valid_mask: mask of size batch_size x height x width. Only 26 | pixels with True values are used. If "valid_mask" 27 | is None, then all pixels are used.
28 | """ 29 | if valid_mask is not None: 30 | valid_mask = valid_mask.bool() 31 | pixelwise_diff = (estimated - ground_truth).abs() 32 | if valid_mask is not None: 33 | if valid_mask.size() == pixelwise_diff.size(): 34 | pixelwise_diff = pixelwise_diff[valid_mask] 35 | else: 36 | try: 37 | pixelwise_diff = pixelwise_diff[valid_mask.expand( 38 | pixelwise_diff.size())] 39 | except Exception as err: 40 | raise RuntimeError("Mask auto expand failed.") from err 41 | if pixelwise_diff.numel() == 0: 42 | return torch.Tensor([0]).type(estimated.type()) 43 | return pixelwise_diff.mean() 44 | 45 | 46 | def sparse_lnorm(order, estimated, ground_truth, valid_mask=None, per_frame=False): 47 | """Return the mean per-pixel L-`order` norm of (estimated - ground_truth), taken over dim 1. 48 | This loss ignores pixels where valid_mask is False. 49 | If all pixels are marked as False, the loss is equal to zero. With per_frame=True, a dict with per-sample errors ("metric") and the masked error map ("t_emap") is returned instead. 50 | Args: 51 | estimated: tensor with predicted values of size 52 | batch_size x channels x height x width. 53 | ground_truth: tensor with ground truth values of 54 | size batch_size x channels x height x width. 55 | valid_mask: mask expandable to batch_size x 1 x height x width. Only 56 | pixels with True values are used. If "valid_mask" 57 | is None, then all pixels are used.
58 | """ 59 | if valid_mask is not None: 60 | valid_mask = valid_mask.bool() 61 | pixelwise_diff = torch.norm( 62 | estimated - ground_truth, p=order, keepdim=True, dim=(1)) 63 | # make sure valid_mask is the same shape as diff 64 | if valid_mask is not None: 65 | if valid_mask.size() != pixelwise_diff.size(): 66 | try: 67 | valid_mask = valid_mask.expand(pixelwise_diff.size()) 68 | except Exception as err: 69 | raise RuntimeError("Mask auto expand failed.") from err 70 | if per_frame: 71 | error = [] 72 | if valid_mask is not None: 73 | for diff, mask in zip(pixelwise_diff, valid_mask): 74 | error.append(diff[mask].mean().item()) 75 | else: 76 | error = [e.mean().item() for e in pixelwise_diff] 77 | emap = estimated - ground_truth 78 | if valid_mask is not None: 79 | emap[~valid_mask.expand(emap.size())] = 0 80 | return { 81 | "metric": error, 82 | "t_emap": emap 83 | } 84 | else: 85 | if valid_mask is not None: 86 | pixelwise_diff = pixelwise_diff[valid_mask] 87 | if pixelwise_diff.numel() == 0: 88 | return torch.Tensor([0]).type(estimated.type()) 89 | return pixelwise_diff.mean() 90 | 91 | 92 | def charbonnier_loss(delta, alpha=0.45, epsilon=1e-3): 93 | """ 94 | Robust Charbonnier loss, as defined in equation (4) of the paper. 95 | """ 96 | loss = torch.mean(torch.pow((delta ** 2 + epsilon ** 2), alpha)) 97 | return loss 98 | 99 | 100 | def compute_smoothness_loss(flow): 101 | """ 102 | Local smoothness loss, as defined in equation (5) of the paper. 103 | The neighborhood here is defined as the 8-connected region around each pixel.
104 | """ 105 | flow_ucrop = flow[..., 1:] 106 | flow_dcrop = flow[..., :-1] 107 | flow_lcrop = flow[..., 1:, :] 108 | flow_rcrop = flow[..., :-1, :] 109 | # ucrop/dcrop pair neighbors along W, lcrop/rcrop along H; the pairs below cover both diagonals 110 | flow_ulcrop = flow[..., 1:, 1:] 111 | flow_drcrop = flow[..., :-1, :-1] 112 | flow_dlcrop = flow[..., :-1, 1:] 113 | flow_urcrop = flow[..., 1:, :-1] 114 | 115 | smoothness_loss = charbonnier_loss(flow_lcrop - flow_rcrop) + \ 116 | charbonnier_loss(flow_ucrop - flow_dcrop) + \ 117 | charbonnier_loss(flow_ulcrop - flow_drcrop) + \ 118 | charbonnier_loss(flow_dlcrop - flow_urcrop) 119 | smoothness_loss /= 4. 120 | 121 | return smoothness_loss 122 | 123 | 124 | def compute_npe(n, estimated, ground_truth, valid_mask=None): 125 | if valid_mask is not None: 126 | valid_mask = valid_mask.bool() 127 | pixelwise_diff = torch.norm( 128 | estimated - ground_truth, p=2, keepdim=True, dim=(1)) 129 | # nPE: fraction of valid pixels whose end-point error is >= n pixels 130 | # make sure valid_mask is the same shape as diff 131 | if valid_mask is not None: 132 | if valid_mask.size() != pixelwise_diff.size(): 133 | try: 134 | valid_mask = valid_mask.expand(pixelwise_diff.size()) 135 | except Exception as err: 136 | raise RuntimeError("Mask auto expand failed.") from err 137 | 138 | if valid_mask is not None: 139 | pixelwise_diff = pixelwise_diff[valid_mask] 140 | if pixelwise_diff.numel() == 0: 141 | return {"metric": 0.0} 142 | return { 143 | "metric": torch.numel(pixelwise_diff[pixelwise_diff >= n]) / 144 | torch.numel(pixelwise_diff) 145 | } 146 | -------------------------------------------------------------------------------- /idn/model/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FlowHead2(nn.Module): 7 | def __init__(self, input_dim=128, hidden_dim=256): 8 | super(FlowHead2, self).__init__() 9 | 
self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 11 | self.relu =
nn.ReLU(inplace=True) 12 | self.tanh = nn.Tanh() 13 | 14 | def forward(self, x): 15 | return 10*self.tanh(self.conv2(self.relu(self.conv1(x)))) 16 | 17 | 18 | class FlowHead(nn.Module): 19 | def __init__(self, input_dim=128, hidden_dim=256): 20 | super(FlowHead, self).__init__() 21 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 22 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 23 | self.relu = nn.ReLU(inplace=True) 24 | 25 | def forward(self, x): 26 | return self.conv2(self.relu(self.conv1(x))) 27 | 28 | 29 | class ConvGRU(nn.Module): 30 | def __init__(self, hidden_dim=128, input_dim=192+128): 31 | super(ConvGRU, self).__init__() 32 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 33 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 34 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 35 | 36 | def forward(self, h, x): 37 | hx = torch.cat([h, x], dim=1) 38 | 39 | z = torch.sigmoid(self.convz(hx)) 40 | r = torch.sigmoid(self.convr(hx)) 41 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) 42 | 43 | h = (1-z) * h + z * q 44 | return h 45 | 46 | 47 | class LiteUpdateBlock(nn.Module): 48 | def __init__(self, hidden_dim=32, input_dim=16, num_outputs=1, downsample=8): 49 | super(LiteUpdateBlock, self).__init__() 50 | self.upsample_mask_dim = downsample * downsample 51 | self.num_outputs = num_outputs 52 | assert self.num_outputs in [1, 2] 53 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=input_dim) 54 | self.flow_head = FlowHead(hidden_dim, hidden_dim=hidden_dim) 55 | self.mask = nn.Sequential( 56 | nn.Conv2d(hidden_dim, 256, 3, padding=1), 57 | nn.ReLU(inplace=True), 58 | nn.Conv2d(256, self.upsample_mask_dim*9, 1, padding=0)) 59 | if self.num_outputs == 2: 60 | self.flow_head2 = FlowHead(hidden_dim, hidden_dim=hidden_dim) 61 | self.mask2 = nn.Sequential( 62 | nn.Conv2d(hidden_dim, 256, 3, padding=1), 63 | nn.ReLU(inplace=True), 64 | nn.Conv2d(256, 
self.upsample_mask_dim*9, 1, padding=0)) 65 | 66 | def forward(self, net, inp): 67 | return self.gru(net, inp) 68 | 69 | def compute_deltaflow(self, net): 70 | return self.flow_head(net) 71 | 72 | def compute_nextflow(self, net): 73 | if self.num_outputs == 2: 74 | return self.flow_head2(net) 75 | else: 76 | raise NotImplementedError 77 | 78 | def compute_up_mask(self, net): 79 | return self.mask(net) 80 | 81 | def compute_up_mask2(self, net): 82 | if self.num_outputs == 2: 83 | return self.mask2(net) 84 | else: 85 | raise NotImplementedError 86 | -------------------------------------------------------------------------------- /idn/scripts/format_mvsec/eval_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # import tensorflow as tf 3 | import numpy as np 4 | import math 5 | import cv2 6 | 7 | 8 | """ 9 | Propagates x_indices and y_indices by their flow, as defined in x_flow, y_flow. 10 | x_mask and y_mask are zeroed out at each pixel where the indices leave the image. 11 | The optional scale_factor will scale the final displacement. 12 | """ 13 | def prop_flow(x_flow, y_flow, x_indices, y_indices, x_mask, y_mask, scale_factor=1.0): 14 | flow_x_interp = cv2.remap(x_flow, 15 | x_indices, 16 | y_indices, 17 | cv2.INTER_NEAREST) 18 | 19 | flow_y_interp = cv2.remap(y_flow, 20 | x_indices, 21 | y_indices, 22 | cv2.INTER_NEAREST) 23 | 24 | x_mask[flow_x_interp == 0] = False 25 | y_mask[flow_y_interp == 0] = False 26 | 27 | x_indices += flow_x_interp * scale_factor 28 | y_indices += flow_y_interp * scale_factor 29 | 30 | return 31 | 32 | """ 33 | The ground truth flow maps are not time synchronized with the grayscale images. Therefore, we 34 | need to propagate the ground truth flow over the time between two images. 35 | This function assumes that the ground truth flow is in terms of pixel displacement, not velocity. 
36 | Pseudo code for this process is as follows: 37 | x_orig = range(cols) 38 | y_orig = range(rows) 39 | x_prop = x_orig 40 | y_prop = y_orig 41 | Find all GT flows that fit in [image_timestamp, image_timestamp+image_dt]. 42 | for all of these flows: 43 | x_prop = x_prop + gt_flow_x(x_prop, y_prop) 44 | y_prop = y_prop + gt_flow_y(x_prop, y_prop) 45 | The final flow, then, is x_prop - x-orig, y_prop - y_orig. 46 | Note that this is flow in terms of pixel displacement, with units of pixels, not pixel velocity. 47 | Inputs: 48 | x_flow_in, y_flow_in - list of numpy arrays, each array corresponds to per pixel flow at 49 | each timestamp. 50 | gt_timestamps - timestamp for each flow array. 51 | start_time, end_time - gt flow will be estimated between start_time and end time. 52 | """ 53 | def estimate_corresponding_gt_flow(x_flow_in, 54 | y_flow_in, 55 | gt_timestamps, 56 | start_time, 57 | end_time): 58 | # Each gt flow at timestamp gt_timestamps[gt_iter] represents the displacement between 59 | # gt_iter and gt_iter+1. 60 | gt_iter = np.searchsorted(gt_timestamps, start_time, side='left') 61 | try: 62 | gt_timestamps[gt_iter+1] 63 | except IndexError: 64 | return None, None 65 | gt_dt = gt_timestamps[gt_iter+1] - gt_timestamps[gt_iter] 66 | x_flow = np.squeeze(x_flow_in[gt_iter, ...]) 67 | y_flow = np.squeeze(y_flow_in[gt_iter, ...]) 68 | 69 | dt = end_time - start_time 70 | 71 | # No need to propagate if the desired dt is shorter than the time between gt timestamps. 72 | if gt_dt > dt: 73 | return x_flow * dt / gt_dt, y_flow * dt / gt_dt 74 | 75 | x_indices, y_indices = np.meshgrid(np.arange(x_flow.shape[1]), 76 | np.arange(x_flow.shape[0])) 77 | x_indices = x_indices.astype(np.float32) 78 | y_indices = y_indices.astype(np.float32) 79 | 80 | orig_x_indices = np.copy(x_indices) 81 | orig_y_indices = np.copy(y_indices) 82 | 83 | # Mask keeps track of the points that leave the image, and zeros out the flow afterwards. 
84 | x_mask = np.ones(x_indices.shape, dtype=bool) 85 | y_mask = np.ones(y_indices.shape, dtype=bool) 86 | 87 | scale_factor = (gt_timestamps[gt_iter+1] - start_time) / gt_dt 88 | total_dt = gt_timestamps[gt_iter+1] - start_time 89 | 90 | prop_flow(x_flow, y_flow, 91 | x_indices, y_indices, 92 | x_mask, y_mask, 93 | scale_factor=scale_factor) 94 | 95 | gt_iter += 1 96 | 97 | while gt_timestamps[gt_iter+1] < end_time: 98 | x_flow = np.squeeze(x_flow_in[gt_iter, ...]) 99 | y_flow = np.squeeze(y_flow_in[gt_iter, ...]) 100 | 101 | prop_flow(x_flow, y_flow, 102 | x_indices, y_indices, 103 | x_mask, y_mask) 104 | total_dt += gt_timestamps[gt_iter+1] - gt_timestamps[gt_iter] 105 | 106 | gt_iter += 1 107 | 108 | final_dt = end_time - gt_timestamps[gt_iter] 109 | total_dt += final_dt 110 | 111 | final_gt_dt = gt_timestamps[gt_iter+1] - gt_timestamps[gt_iter] 112 | 113 | x_flow = np.squeeze(x_flow_in[gt_iter, ...]) 114 | y_flow = np.squeeze(y_flow_in[gt_iter, ...]) 115 | 116 | scale_factor = final_dt / final_gt_dt 117 | 118 | prop_flow(x_flow, y_flow, 119 | x_indices, y_indices, 120 | x_mask, y_mask, 121 | scale_factor) 122 | 123 | x_shift = x_indices - orig_x_indices 124 | y_shift = y_indices - orig_y_indices 125 | x_shift[~x_mask] = 0 126 | y_shift[~y_mask] = 0 127 | 128 | return x_shift, y_shift -------------------------------------------------------------------------------- /idn/scripts/format_mvsec/format_mvsec.py: -------------------------------------------------------------------------------- 1 | import hdf5plugin 2 | import h5py 3 | import numpy as np 4 | 5 | from eval_utils import * 6 | from h5_packager import * 7 | 8 | 9 | def process_events(h5_file, event_file, delta=100000): 10 | print("Processing events...") 11 | 12 | cnt = 0 13 | t0 = None 14 | events = event_file["davis"]["left"]["events"] 15 | while True: 16 | x = events[cnt : cnt + delta, 0].astype(np.int16) 17 | y = events[cnt : cnt + delta, 1].astype(np.int16) 18 | t = events[cnt : cnt + delta, 
2].astype(np.float64) 19 | p = events[cnt : cnt + delta, 3] 20 | p[p < 0] = 0 21 | p = p.astype(np.bool_) 22 | if x.shape[0] <= 0: 23 | break 24 | else: 25 | tlast = t[-1] 26 | 27 | if t0 is None: 28 | t0 = t[0] 29 | 30 | h5_file.package_events(x, y, t, p) 31 | cnt += delta 32 | 33 | return t0, tlast 34 | 35 | 36 | def process_flow(h5_file, gt_flow, event_file, t0, dt=1): 37 | print("Processing flow...") 38 | 39 | group = h5_file.file.create_group("flow/dt=" + str(dt)) 40 | ts_table = group.create_dataset("timestamps", (0, 2), dtype=np.float64, maxshape=(None, 2), chunks=True) 41 | flow_cnt = 0 42 | cur_cnt, prev_cnt = 0, 0 43 | cur_t, prev_t = None, None 44 | flow_x, flow_y, ts = gt_flow["x_flow_dist"], gt_flow["y_flow_dist"], gt_flow["timestamps"] 45 | 46 | image_raw_ts = event_file["davis"]["left"]["image_raw_ts"] 47 | for t in range(image_raw_ts.shape[0]): 48 | cur_t = image_raw_ts[t] 49 | 50 | # upsample flow only at the frame timestamps in between gt samples 51 | if cur_t < gt_flow["timestamps"].min(): 52 | continue 53 | elif cur_t > gt_flow["timestamps"].max(): 54 | break 55 | 56 | # skip dt frames between each gt sample 57 | if cur_cnt - prev_cnt >= dt: 58 | 59 | # interpolate flow 60 | disp_x, disp_y = estimate_corresponding_gt_flow( 61 | flow_x, 62 | flow_y, 63 | ts, 64 | prev_t, 65 | cur_t, 66 | ) 67 | if disp_x is None: 68 | return 69 | disp = np.stack([disp_x, disp_y], axis=2) 70 | print(cur_t - t0) 71 | h5_file.package_flow(disp, (prev_t, cur_t), flow_cnt, dt=dt) 72 | h5_file.append(ts_table, np.array([[prev_t, cur_t]])) 73 | 74 | # update counters 75 | flow_cnt += 1 76 | prev_t = cur_t 77 | prev_cnt = cur_cnt 78 | 79 | cur_cnt += 1 80 | if prev_t is None: 81 | prev_t = image_raw_ts[t] 82 | 83 | 84 | if __name__ == "__main__": 85 | 86 | # load data 87 | gt = np.load("/scratch/indoor_flying3_gt_flow_dist.npz") 88 | event_data = h5py.File("/scratch/indoor_flying3_data.hdf5", "r") 89 | 90 | 91 | # initialize h5 file 92 | ep = 
H5Packager("/scratch/indoor_flying3.h5") 93 | 94 | # process events 95 | t0, tlast = process_events(ep, event_data) 96 | # process flow 97 | # t0 = 1506119776.3833518 98 | # tlast = 1506120429.7728229 99 | process_flow(ep, gt, event_data, t0, dt=1) 100 | process_flow(ep, gt, event_data, t0, dt=4) 101 | 102 | ep.add_metadata(t0, tlast) 103 | -------------------------------------------------------------------------------- /idn/scripts/format_mvsec/h5_packager.py: -------------------------------------------------------------------------------- 1 | import hdf5plugin 2 | import h5py 3 | import numpy as np 4 | 5 | 6 | class H5Packager: 7 | def __init__(self, output_path): 8 | print("Creating file in {}".format(output_path)) 9 | self.output_path = output_path 10 | 11 | self.file = h5py.File(output_path, "w") 12 | self.event_xs = self.file.create_dataset( 13 | "events/xs", (0,), dtype=np.dtype(np.int16), maxshape=(None,), chunks=True, 14 | ) 15 | self.event_ys = self.file.create_dataset( 16 | "events/ys", (0,), dtype=np.dtype(np.int16), maxshape=(None,), chunks=True, 17 | ) 18 | self.event_ts = self.file.create_dataset( 19 | "events/ts", (0,), dtype=np.dtype(np.float64), maxshape=(None,), chunks=True, 20 | ) 21 | self.event_ps = self.file.create_dataset( 22 | "events/ps", (0,), dtype=np.dtype(np.bool_), maxshape=(None,), chunks=True, 23 | ) 24 | 25 | def append(self, dataset, data): 26 | dataset.resize(dataset.shape[0] + len(data), axis=0) 27 | if len(data) == 0: 28 | return 29 | dataset[-len(data) :] = data[:] 30 | 31 | def package_events(self, xs, ys, ts, ps): 32 | self.append(self.event_xs, xs) 33 | self.append(self.event_ys, ys) 34 | self.append(self.event_ts, ts) 35 | self.append(self.event_ps, ps) 36 | 37 | def package_flow(self, flowmap, timestamp, flow_idx, dt=1): 38 | flowmap_dset = self.file.create_dataset( 39 | "flow/dt=" + str(dt) + "/" + "{:09d}".format(flow_idx), data=flowmap, dtype=np.dtype(np.float64), 40 | ) 41 | flowmap_dset.attrs["size"] = flowmap.shape 
42 | flowmap_dset.attrs["timestamp_from"] = timestamp[0] 43 | flowmap_dset.attrs["timestamp_to"] = timestamp[1] 44 | 45 | def add_metadata(self, t0, tlast): 46 | self.file.attrs["t0"] = t0 47 | self.file.attrs["tk"] = tlast 48 | self.file.attrs["duration"] = tlast - t0 49 | -------------------------------------------------------------------------------- /idn/tests/dsec.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.utils.data import DataLoader 3 | from ..loader.loader_dsec import assemble_dsec_sequences, assemble_dsec_test_set, rec_train_collate, \ 4 | train_collate 5 | from ..loader.loader_mvsec import MVSEC 6 | from .test import Test 7 | import torch 8 | from tqdm import tqdm 9 | from ..utils.helper_functions import move_batch_to_cuda 10 | import time 11 | 12 | 13 | class TestCO(Test): 14 | def __init__(self, test_spec): 15 | super().__init__(test_spec) 16 | 17 | def configure_dataloader(self): 18 | self.spec.data_loader.args.batch_size = 1 19 | return super().configure_dataloader() 20 | 21 | @staticmethod 22 | def configure_model_forward_fn(model): 23 | return super(type(model), model).forward 24 | 25 | @staticmethod 26 | def configure_forward_pass_fn(model_forward_fn): 27 | return lambda list_batch: model_forward_fn(list_batch[0]) 28 | 29 | def configure_model(self, model): 30 | self.original_co_mode = model.co_mode 31 | model.co_mode = True 32 | model.reset_continuous_flow() 33 | 34 | def cleanup_model(self, model): 35 | model.co_mode = self.original_co_mode 36 | model.reset_continuous_flow() 37 | 38 | 39 | class TestRESET(Test): 40 | def __init__(self, test_spec): 41 | super().__init__(test_spec) 42 | 43 | @staticmethod 44 | def configure_model_forward_fn(model): 45 | return model.forward_inference 46 | 47 | def configure_dataloader(self): 48 | return super().configure_dataloader() 49 | 50 | def configure_model(self, model): 51 | self.original_co_mode = model.co_mode 52 | model.co_mode = False 53 | 54 | 
def cleanup_model(self, model): 55 | model.co_mode = self.original_co_mode 56 | 57 | 58 | class TestNONREC(Test): 59 | def __init__(self, test_spec): 60 | super().__init__(test_spec) 61 | 62 | def configure_dataloader(self): 63 | return super().configure_dataloader() 64 | 65 | 66 | def assemble_dsec_test_cls(test_type_name=None): 67 | test_cls = { 68 | "co": TestCO, 69 | "reset": TestRESET, 70 | } 71 | test_type = test_cls.get(test_type_name, Test) 72 | 73 | class TestDSEC(test_type): 74 | 75 | # DSEC Test set 76 | def configure_dataloader(self): 77 | test_set = assemble_dsec_test_set(self.spec.dataset.common.test_root, 78 | seq_len=self.spec.dataset.val.get( 79 | "sequence_length", None), 80 | representation_type=self.spec.dataset.get("representation_type", None)) 81 | if isinstance(test_set, list): 82 | val_dataloader = [DataLoader( 83 | seq, batch_size=1, shuffle=False, num_workers=0, 84 | ) for seq in test_set] 85 | else: 86 | val_dataloader = DataLoader( 87 | test_set, batch_size=1, shuffle=False, num_workers=0, 88 | ) 89 | return val_dataloader 90 | 91 | @torch.no_grad() 92 | def execute_test(self, model_eval, save_all=False): 93 | try: 94 | with self.evaluate_model(model_eval) as model: 95 | model_forward_fn = self.configure_model_forward_fn(model) 96 | forward_pass_fn = self.configure_forward_pass_fn( 97 | model_forward_fn) 98 | with self.logger.log_test(model) as log_path: 99 | if not isinstance(self.data_loader, list): 100 | self.data_loader = [self.data_loader] 101 | for seq_loader in self.data_loader: 102 | with self.logger.record_sequence(seq_loader.dataset.seq_name, 103 | log_path) as rec: 104 | for idx, batch in enumerate(tqdm(seq_loader, position=1)): 105 | if isinstance(batch, list): 106 | assert 'save_submission' in batch[-1] 107 | else: 108 | assert 'save_submission' in batch 109 | # for non-recurrent loading, we skip samples 110 | # not for eval 111 | if not batch['save_submission'].cpu().item() and not save_all: 112 | continue 113 | 114 | batch 
= move_batch_to_cuda( 115 | batch, self.device) 116 | start_time = time.perf_counter() 117 | 118 | out = forward_pass_fn(batch) 119 | end_time = time.perf_counter() 120 | #print(f"Forward pass time: {(end_time - start_time):.4f}") 121 | self.evaluator.evaluate(batch, out, idx) 122 | rec.log_tensors(batch, out, idx, save_all) 123 | rec.log_metrics(self.evaluator.results) 124 | 125 | self.pack_submission_to_zip(log_path) 126 | 127 | return 0, self.logger.summary() 128 | except Exception as e: 129 | return e, None 130 | 131 | @staticmethod 132 | def pack_submission_to_zip(log_path): 133 | import zipfile 134 | import os 135 | import tempfile 136 | from pathlib import Path 137 | from ..check_submission import check_submission 138 | with zipfile.ZipFile(os.path.join(log_path, 'submission.zip'), 'w') as zip: 139 | for seq in os.listdir(log_path): 140 | if os.path.isdir(os.path.join(log_path, seq, "submission")): 141 | for file in os.listdir(os.path.join(log_path, seq, "submission")): 142 | zip.write(os.path.join( 143 | log_path, seq, "submission", file), os.path.join(seq, file)) 144 | with tempfile.TemporaryDirectory() as tempdir: 145 | zipfile.ZipFile(os.path.join( 146 | log_path, 'submission.zip')).extractall(tempdir) 147 | assert check_submission( 148 | Path(tempdir), Path("data/test_forward_optical_flow_timestamps")), \ 149 | "submission did not pass check" 150 | return TestDSEC 151 | 152 | class TestMVSEC(Test): 153 | def __init__(self, test_spec): 154 | super().__init__(test_spec) 155 | 156 | def configure_dataloader(self): 157 | valid_set = MVSEC("outdoor_day1", filter=(4356, 4706), num_bins=15, augment=False) 158 | 159 | collate_fn = rec_train_collate if self.spec.dataset.val.get("recurrent", False) \ 160 | else train_collate 161 | assert self.spec.data_loader.args.shuffle is False, \ 162 | "shuffle must be false for val run." 
163 | if isinstance(valid_set, list): 164 | val_dataloader = [DataLoader( 165 | seq, collate_fn=collate_fn, **self.spec.data_loader.args 166 | ) for seq in valid_set] 167 | else: 168 | val_dataloader = DataLoader( 169 | valid_set, collate_fn=collate_fn, **self.spec.data_loader.args 170 | ) 171 | return val_dataloader 172 | -------------------------------------------------------------------------------- /idn/tests/eval.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from ..utils.retrieval_fn import get_retreival_fn 3 | from ..model.loss import sparse_lnorm, compute_npe 4 | 5 | fm = namedtuple("frame_metric", ["n_frame", "value"]) 6 | 7 | 8 | class Evaluator: 9 | def __init__(self, spec): 10 | spec = dict() if spec is None else spec 11 | self.spec = spec 12 | self.assemble_eval_fn() 13 | 14 | def evaluate(self, batch, out, idx): 15 | for quantity, metrics in self.spec.items(): 16 | for metric in metrics: 17 | try: 18 | q = self.quantity_retrieval_fn[quantity](out, batch) 19 | except KeyError: 20 | continue 21 | eval_result = self.eval_fn[metric](*q) 22 | eval_metric = self.extract_metric(eval_result) 23 | self.results[quantity][metric].append(fm(idx, eval_metric)) 24 | 25 | def assemble_eval_fn(self): 26 | self.results = self.initialize_metrics_dict(self.spec) 27 | self.eval_fn = dict() 28 | self.quantity_retrieval_fn = dict() 29 | for quantity, metrics in self.spec.items(): 30 | if quantity not in self.quantity_retrieval_fn: 31 | self.quantity_retrieval_fn[quantity] = get_retreival_fn( 32 | quantity) 33 | for metric in metrics: 34 | if metric not in self.eval_fn: 35 | self.eval_fn[metric] = self.get_eval_fn(metric) 36 | 37 | 38 | @staticmethod 39 | def initialize_metrics_dict(spec): 40 | results = dict() 41 | for quantity, metrics in spec.items(): 42 | results[quantity] = {metric: [] for metric in metrics} 43 | return results 44 | 45 | @staticmethod 46 | def get_eval_fn(metric): 47 | if 
metric == "L1": 48 | return lambda estimate, ground_truth: \ 49 | sparse_lnorm(1, estimate, ground_truth.frame, ground_truth.mask, 50 | per_frame=True) 51 | if metric == "L2": 52 | return lambda estimate, ground_truth: \ 53 | sparse_lnorm(2, estimate, ground_truth.frame, ground_truth.mask, 54 | per_frame=True) 55 | if metric == "1PE": 56 | return lambda estimate, ground_truth: \ 57 | compute_npe(1, estimate, ground_truth.frame, ground_truth.mask) 58 | if metric == "3PE": 59 | return lambda estimate, ground_truth: \ 60 | compute_npe(3, estimate, ground_truth.frame, ground_truth.mask) 61 | 62 | @staticmethod 63 | def extract_metric(eval_result): 64 | assert "metric" in eval_result, "metric not found in eval result" 65 | if isinstance(eval_result["metric"], list): 66 | assert len(eval_result["metric"]) == 1, "multiple metrics found" 67 | return eval_result["metric"][0] 68 | else: 69 | return eval_result["metric"] 70 | -------------------------------------------------------------------------------- /idn/tests/test.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from ..utils.logger import Logger 3 | from ..loader.loader_dsec import assemble_dsec_sequences, rec_train_collate, \ 4 | train_collate 5 | from ..utils.helper_functions import move_batch_to_cuda 6 | from tqdm import tqdm 7 | import torch 8 | from torch.utils.data import DataLoader 9 | from contextlib import contextmanager 10 | from .eval import Evaluator 11 | 12 | 13 | class Test: 14 | fm = namedtuple("frame_metric", ["n_frame", "value"]) 15 | 16 | def __init__(self, test_spec): 17 | self.spec = test_spec 18 | self.device = self.spec.data_loader.gpu 19 | self.data_loader = self.configure_dataloader() 20 | self.evaluator = Evaluator(test_spec.metrics) 21 | self.logger = Logger(self.spec.logger, self.spec.name) 22 | 23 | 24 | def configure_dataloader(self): 25 | valid_set = assemble_dsec_sequences( 26 | 
self.spec.dataset.common.data_root, 27 | include_seq=self.spec.dataset.val.seq, 28 | require_gt=True, 29 | config=self.spec.dataset.val, 30 | representation_type=self.spec.dataset.get( 31 | "representation_type", None), 32 | ) 33 | collate_fn = rec_train_collate if self.spec.dataset.val.get("recurrent", False) \ 34 | else train_collate 35 | assert self.spec.data_loader.args.shuffle is False, \ 36 | "shuffle must be false for val run." 37 | if isinstance(valid_set, list): 38 | val_dataloader = [DataLoader( 39 | seq, collate_fn=collate_fn, **self.spec.data_loader.args 40 | ) for seq in valid_set] 41 | else: 42 | val_dataloader = DataLoader( 43 | valid_set, collate_fn=collate_fn, **self.spec.data_loader.args 44 | ) 45 | return val_dataloader 46 | 47 | def assemble_postprocess_fn(self, postprocess): 48 | pass 49 | return None 50 | 51 | @staticmethod 52 | def configure_model_forward_fn(model): 53 | return model.forward 54 | 55 | @staticmethod 56 | def configure_forward_pass_fn(forward_fn): 57 | return forward_fn 58 | 59 | def configure_model(self, model): 60 | pass 61 | 62 | def cleanup_model(self, model): 63 | pass 64 | 65 | @contextmanager 66 | def evaluate_model(self, model): 67 | istrain = model.training 68 | original_device = next(model.parameters()).device 69 | try: 70 | model.cuda(self.device) 71 | self.configure_model(model) 72 | model.eval() 73 | yield model 74 | finally: 75 | self.cleanup_model(model) 76 | model.to(original_device) 77 | model.train(mode=istrain) 78 | 79 | @torch.no_grad() 80 | def execute_test(self, model_eval): 81 | try: 82 | with self.evaluate_model(model_eval) as model: 83 | model_forward_fn = self.configure_model_forward_fn(model) 84 | forward_pass_fn = self.configure_forward_pass_fn( 85 | model_forward_fn) 86 | with self.logger.log_test(model) as log_path: 87 | if not isinstance(self.data_loader, list): 88 | self.data_loader = [self.data_loader] 89 | for seq_loader in self.data_loader: 90 | with 
self.logger.record_sequence(seq_loader.dataset.seq_name, 91 | log_path) as rec: 92 | for idx, batch in enumerate(tqdm(seq_loader, position=1)): 93 | batch = move_batch_to_cuda(batch, self.device) 94 | out = forward_pass_fn(batch) 95 | self.evaluator.evaluate(batch, out, idx) 96 | rec.log_tensors(batch, out, idx) 97 | rec.log_metrics(self.evaluator.results) 98 | 99 | return 0, self.logger.summary() 100 | except Exception as e: 101 | return e, None 102 | -------------------------------------------------------------------------------- /idn/train.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | import hydra 3 | from .utils.trainer import Trainer 4 | 5 | # @hydra.main(config_path="config", config_name="mvsec_train") 6 | # @hydra.main(config_path="config", config_name="tid_train") 7 | @hydra.main(config_path="config", config_name="id_train") 8 | 9 | def main(config): 10 | print(OmegaConf.to_yaml(config)) 11 | 12 | trainer = Trainer(config) 13 | 14 | print("Number of parameters: ", sum(p.numel() 15 | for p in trainer.model.parameters() if p.requires_grad)) 16 | 17 | trainer.fit() 18 | 19 | 20 | if __name__ == '__main__': 21 | main() 22 | -------------------------------------------------------------------------------- /idn/utils/callbacks.py: -------------------------------------------------------------------------------- 1 | class CallbackBridge: 2 | def __init__(self): 3 | self.callbacks = list() 4 | 5 | def configure_callbacks(self, config): 6 | def get_cb_from_name(callback): 7 | if callback == "logger": 8 | from .cb.logger import CBLogger 9 | return CBLogger 10 | elif callback == "validator": 11 | from .cb.validator import CBValidator 12 | return CBValidator 13 | if config is not None: 14 | for callback, callback_config in config.items(): 15 | if "enable" in callback_config.keys(): 16 | self.callbacks.append( 17 | get_cb_from_name(callback)(callback_config)) 18 | 19 | def 
execute_callbacks(self, callback_type): 20 | for callback in sorted(self.callbacks, 21 | key=lambda x: x.call_order[callback_type]): 22 | getattr(callback, callback_type)(self) 23 | 24 | 25 | class Callback: 26 | callback_types = [ 27 | "on_init_end", 28 | "on_train_begin", 29 | "on_train_end", 30 | "on_epoch_begin", 31 | "on_epoch_end", 32 | "on_batch_begin", 33 | "on_batch_end", 34 | "on_step_begin", 35 | "on_step_end", 36 | ] 37 | 38 | def __init__(self): 39 | self.call_order = dict.fromkeys(self.callback_types, 0) 40 | 41 | def on_init_end(self, caller): 42 | pass 43 | 44 | def on_train_begin(self, caller): 45 | pass 46 | 47 | def on_train_end(self, caller): 48 | pass 49 | 50 | def on_epoch_begin(self, caller): 51 | pass 52 | 53 | def on_epoch_end(self, caller): 54 | pass 55 | 56 | def on_batch_begin(self, caller): 57 | pass 58 | 59 | def on_batch_end(self, caller): 60 | pass 61 | 62 | def on_step_begin(self, caller): 63 | pass 64 | 65 | def on_step_end(self, caller): 66 | pass 67 | -------------------------------------------------------------------------------- /idn/utils/cb/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from ..callbacks import Callback 4 | 5 | class CBLogger(Callback): 6 | def __init__(self, logger_config, logger_type="exp_tracker"): 7 | super().__init__() 8 | self.logger_config = logger_config 9 | self.logger = None 10 | 11 | def on_init_end(self, caller): 12 | self.logger = caller.logger 13 | try: 14 | run_id = caller.logged_tracker["run_id"] 15 | except (AttributeError, KeyError, TypeError): 16 | run_id = None 17 | self.logger.on_init_end(caller.config, run_id=run_id) 18 | 19 | def on_train_begin(self, caller): 20 | self.logger.on_exp_begin(caller.model) 21 | 22 | def on_batch_end(self, caller): 23 | log_dict = dict() 24 | for key in self.logger_config.log_keys["batch_end"]: 25 | log_dict[key] = getattr(caller, key, None) 26 | self.logger.log_dict_at_step(log_dict) 27 | 28 | def on_epoch_end(self,
caller): 29 | if self.logger.log_dir is not None: 30 | torch.save({ 31 | 'epoch': caller.epoch, 32 | 'model_state_dict': caller.model.state_dict(), 33 | 'optimizer_state_dict': caller.optimizer.state_dict(), 34 | 'scheduler_state_dict': caller.scheduler.state_dict() \ 35 | if caller.scheduler is not None else None, 36 | 'loss': caller.loss, 37 | 'tracker': self.logger.summary(), 38 | }, os.path.join(self.logger.log_dir, "model.ckpt")) 39 | -------------------------------------------------------------------------------- /idn/utils/cb/validator.py: -------------------------------------------------------------------------------- 1 | from warnings import warn 2 | from ..callbacks import Callback 3 | from ..validation import Validator 4 | 5 | 6 | class CBValidator(Callback): 7 | def __init__(self, config): 8 | super().__init__() 9 | self.validator = None 10 | self.config = config 11 | self.logger = None 12 | 13 | def run_validation(self, caller, sanity_check_run=False): 14 | validator = Validator(caller.config.validation) 15 | results = validator(caller.model) 16 | if self.logger is not None and not sanity_check_run: 17 | self.logger.log_dict_at_step(results) 18 | else: 19 | print(results) 20 | 21 | def on_train_begin(self, caller): 22 | self.logger = caller.logger 23 | 24 | def on_batch_end(self, caller): 25 | if caller.step == self.config.get("sanity_run_step", None): 26 | self.run_validation(caller, sanity_check_run=True) 27 | if self.time_to_validate(step=caller.step): 28 | self.run_validation(caller) 29 | 30 | def on_epoch_end(self, caller): 31 | if self.time_to_validate(epoch=caller.epoch): 32 | self.run_validation(caller) 33 | 34 | def time_to_validate(self, step=None, epoch=None): 35 | if self.config.frequency_type == "epoch": 36 | return epoch != 0 and epoch % self.config.frequency == 0 \ 37 | if epoch is not None else False 38 | elif self.config.frequency_type == "step": 39 | return step != 0 and step % self.config.frequency == 0 \ 40 |
if step is not None else False 41 | else: 42 | warn(f"Frequency type {self.config.frequency_type} not recognized.") 43 | return False 44 | -------------------------------------------------------------------------------- /idn/utils/dsec_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from enum import Enum, auto 4 | from time import time 5 | 6 | 7 | class RepresentationType(Enum): 8 | VOXEL = auto() 9 | STEPAN = auto() 10 | 11 | 12 | class EventRepresentation: 13 | def __init__(self): 14 | pass 15 | 16 | def convert(self, events): 17 | raise NotImplementedError 18 | 19 | 20 | class VoxelGrid(EventRepresentation): 21 | def __init__(self, input_size: tuple, normalize: bool): 22 | assert len(input_size) == 3 23 | self.voxel_grid = torch.zeros( 24 | (input_size), dtype=torch.float, requires_grad=False) 25 | self.nb_channels = input_size[0] 26 | self.normalize = normalize 27 | 28 | def convert(self, events): 29 | C, H, W = self.voxel_grid.shape 30 | with torch.no_grad(): 31 | self.voxel_grid = self.voxel_grid.to(events['p'].device) 32 | voxel_grid = self.voxel_grid.clone() 33 | 34 | t_norm = events['t'] 35 | t_norm = (C - 1) * (t_norm-t_norm[0]) / (t_norm[-1]-t_norm[0]) 36 | 37 | x0 = events['x'].int() 38 | y0 = events['y'].int() 39 | t0 = t_norm.int() 40 | 41 | value = 2*events['p']-1 42 | #start_t = time() 43 | for xlim in [x0, x0+1]: 44 | for ylim in [y0, y0+1]: 45 | for tlim in [t0, t0+1]: 46 | 47 | mask = (xlim < W) & (xlim >= 0) & (ylim < H) & ( 48 | ylim >= 0) & (tlim >= 0) & (tlim < self.nb_channels) 49 | interp_weights = value * (1 - (xlim-events['x']).abs()) * ( 50 | 1 - (ylim-events['y']).abs()) * (1 - (tlim - t_norm).abs()) 51 | index = H * W * tlim.long() + \ 52 | W * ylim.long() + \ 53 | xlim.long() 54 | 55 | voxel_grid.put_( 56 | index[mask], interp_weights[mask], accumulate=True) 57 | 58 | if self.normalize: 59 | mask = torch.nonzero(voxel_grid, as_tuple=True) 60 | if 
mask[0].size()[0] > 0: 61 | mean = voxel_grid[mask].mean() 62 | std = voxel_grid[mask].std() 63 | if std > 0: 64 | voxel_grid[mask] = (voxel_grid[mask] - mean) / std 65 | else: 66 | voxel_grid[mask] = voxel_grid[mask] - mean 67 | 68 | return voxel_grid 69 | 70 | 71 | class PolarityCount(EventRepresentation): 72 | def __init__(self, input_size: tuple): 73 | assert len(input_size) == 3 74 | self.voxel_grid = torch.zeros( 75 | (input_size), dtype=torch.float, requires_grad=False) 76 | self.nb_channels = input_size[0] 77 | 78 | def convert(self, events): 79 | C, H, W = self.voxel_grid.shape 80 | with torch.no_grad(): 81 | self.voxel_grid = self.voxel_grid.to(events['p'].device) 82 | voxel_grid = self.voxel_grid.clone() 83 | 84 | x0 = events['x'].int() 85 | y0 = events['y'].int() 86 | 87 | #start_t = time() 88 | for xlim in [x0, x0+1]: 89 | for ylim in [y0, y0+1]: 90 | mask = (xlim < W) & (xlim >= 0) & (ylim < H) & ( 91 | ylim >= 0) 92 | interp_weights = (1 - (xlim-events['x']).abs()) * ( 93 | 1 - (ylim-events['y']).abs()) 94 | index = H * W * events['p'].long() + \ 95 | W * ylim.long() + \ 96 | xlim.long() 97 | 98 | voxel_grid.put_( 99 | index[mask], interp_weights[mask], accumulate=True) 100 | 101 | return voxel_grid 102 | 103 | 104 | def flow_16bit_to_float(flow_16bit: np.ndarray): 105 | assert flow_16bit.dtype == np.uint16 106 | assert flow_16bit.ndim == 3 107 | h, w, c = flow_16bit.shape 108 | assert c == 3 109 | 110 | valid2D = flow_16bit[..., 2] == 1 111 | assert valid2D.shape == (h, w) 112 | assert np.all(flow_16bit[~valid2D, -1] == 0) 113 | valid_map = np.where(valid2D) 114 | 115 | # to actually compute something useful: 116 | flow_16bit = flow_16bit.astype('float') 117 | 118 | flow_map = np.zeros((h, w, 2)) 119 | flow_map[valid_map[0], valid_map[1], 0] = ( 120 | flow_16bit[valid_map[0], valid_map[1], 0] - 2 ** 15) / 128 121 | flow_map[valid_map[0], valid_map[1], 1] = ( 122 | flow_16bit[valid_map[0], valid_map[1], 1] - 2 ** 15) / 128 123 | return flow_map, 
valid2D 124 | -------------------------------------------------------------------------------- /idn/utils/exp_tracker.py: -------------------------------------------------------------------------------- 1 | class ExpTracker: 2 | def __init__(self) -> None: 3 | self.log_dir = None 4 | 5 | def on_init_end(self, *args, **kwargs): 6 | pass 7 | 8 | def on_exp_begin(self, *args, **kwargs): 9 | pass 10 | 11 | def log_dict_at_step(self, dict, step=None): 12 | pass 13 | 14 | def summary(self): 15 | return { 16 | "id": "exp_tracker", 17 | } 18 | -------------------------------------------------------------------------------- /idn/utils/helper_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | 4 | 5 | def log_metrics(flag, metrics): 6 | log = dict() 7 | for seq in metrics.keys(): 8 | log[f"{seq}-l1-avg"] = metrics[seq]["l1_avg"] 9 | log[f"{seq}-l2-avg"] = metrics[seq]["l2_avg"] 10 | return log 11 | 12 | def move_tensor_to_cuda(dict_tensors, gpu): 13 | assert isinstance(dict_tensors, dict) 14 | for key, value in dict_tensors.items(): 15 | if isinstance(value, torch.Tensor): 16 | dict_tensors[key] = value.to(gpu, non_blocking=True) 17 | return dict_tensors 18 | 19 | 20 | def move_dict_to_cuda(dictionary_of_tensors, gpu): 21 | if isinstance(dictionary_of_tensors, dict): 22 | return { 23 | key: move_dict_to_cuda(value, gpu) 24 | for key, value in dictionary_of_tensors.items() if isinstance(value, (dict, torch.Tensor)) 25 | } 26 | return dictionary_of_tensors.to(gpu, dtype=torch.float) 27 | 28 | 29 | def move_list_to_cuda(list_of_dicts, gpu): 30 | for i in range(len(list_of_dicts)): 31 | list_of_dicts[i] = move_tensor_to_cuda(list_of_dicts[i], gpu) 32 | return list_of_dicts 33 | 34 | 35 | def move_batch_to_cuda(batch, gpu): 36 | if isinstance(batch, dict): 37 | return move_tensor_to_cuda(batch, gpu) 38 | elif isinstance(batch, list): 39 | return move_list_to_cuda(batch, gpu) 40 | else: 41 | raise Exception("Batch is not a
list or dict") 42 | 43 | 44 | def get_values_from_key(input_list, key): 45 | # Returns all the values with the same key from 46 | # a list filled with dicts of the same kind 47 | out = [] 48 | for i in input_list: 49 | out.append(i[key]) 50 | return out 51 | 52 | 53 | def create_save_path(subdir, name): 54 | # Check if sub-folder exists, and create if necessary 55 | if not os.path.exists(subdir): 56 | os.mkdir(subdir) 57 | # Create a new folder (named after the name defined in the config file) 58 | path = os.path.join(subdir, name) 59 | # Check if path already exists. If so, append a number 60 | if os.path.exists(path): 61 | i = 1 62 | while os.path.exists(path + "_" + str(i)): 63 | i += 1 64 | path = path + '_' + str(i) 65 | os.mkdir(path) 66 | return path 67 | 68 | 69 | def get_nth_element_of_all_dict_keys(dict, idx): 70 | out_dict = {} 71 | for k in dict.keys(): 72 | d = dict[k][idx] 73 | if isinstance(d, torch.Tensor): 74 | out_dict[k] = d.detach().cpu().item() 75 | else: 76 | out_dict[k] = d 77 | return out_dict 78 | 79 | 80 | def get_number_of_saved_elements(path, template, first=1): 81 | i = first 82 | while True: 83 | if os.path.exists(os.path.join(path, template.format(i))): 84 | i += 1 85 | else: 86 | break 87 | return range(first, i) 88 | 89 | 90 | def create_file_path(subdir, name): 91 | # Check if sub-folder exists, else raise exception 92 | if not os.path.exists(subdir): 93 | raise Exception("Path {} does not exist!".format(subdir)) 94 | # Check if file already exists, else create path 95 | if not os.path.exists(os.path.join(subdir, name)): 96 | return os.path.join(subdir, name) 97 | else: 98 | path = os.path.join(subdir, name) 99 | prefix, suffix = os.path.splitext(path) 100 | i = 1 101 | while os.path.exists("{}_{}{}".format(prefix, i, suffix)): 102 | i += 1 103 | return "{}_{}{}".format(prefix, i, suffix) 104 | 105 | 106 | def update_dict(dict_old, dict_new): 107 | # Update all the entries of dict_old with the new values (that have the identical keys)
of dict_new 108 | for k in dict_new.keys(): 109 | if k in dict_old.keys(): 110 | # Replace the entry 111 | if isinstance(dict_new[k], dict): 112 | update_dict(dict_old[k], dict_new[k]) 113 | else: 114 | dict_old[k] = dict_new[k] 115 | return dict_old 116 | -------------------------------------------------------------------------------- /idn/utils/logger.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | import itertools 5 | import imageio 6 | import tempfile 7 | import pathlib 8 | import torch 9 | from contextlib import contextmanager, nullcontext 10 | from ..tests.eval import fm 11 | 12 | 13 | class Logger: 14 | def __init__(self, config, test_name): 15 | self.config = config 16 | self.test_name = test_name 17 | self.metrics = dict() 18 | self.parse_log_fields() 19 | 20 | def parse_log_fields(self): 21 | def parse_index(index): 22 | for i, idx in enumerate(index): 23 | if isinstance(idx, str): 24 | assert '-' in idx 25 | a, b = idx.split('-') 26 | index[i] = list(range(int(a), int(b)+1)) 27 | parsed_list = [] 28 | for i in index: 29 | # this is because i can be type of omegaconf.listconfig 30 | if hasattr(i, '__len__'): 31 | parsed_list.extend(i) 32 | else: 33 | parsed_list.append(i) 34 | return parsed_list 35 | 36 | if not hasattr(self.config, "saved_tensors"): 37 | self.config.saved_tensors = None 38 | if self.config.saved_tensors is None: 39 | return 40 | for field, seqs in self.config.saved_tensors.items(): 41 | if seqs is None: 42 | continue 43 | else: 44 | for seq, idx in seqs.items(): 45 | if idx is None: 46 | continue 47 | else: 48 | self.config.saved_tensors[field][seq] = parse_index( 49 | idx) 50 | 51 | class SeqLogger: 52 | def __init__(self, config, log_path, seq_name): 53 | self.config = config 54 | self.path = log_path 55 | self.seq_name = seq_name 56 | 57 | def log_tensors(self, batch, out, idx, save_all=False): 58 | if isinstance(batch, list): 59 | batch = 
batch[-1] 60 | assert isinstance(out, dict) 61 | if 'save_submission' in batch: 62 | if isinstance(batch['save_submission'], torch.Tensor): 63 | batch['save_submission'] = batch['save_submission'].cpu().item() 64 | if batch['save_submission'] or save_all: 65 | self.save_submission(out['final_prediction'], 66 | batch['file_index'].cpu().item()) 67 | if getattr(self.config, "saved_tensors", None) is None: 68 | return 69 | for key, value in itertools.chain(batch.items(), out.items()): 70 | if key in self.config.saved_tensors: 71 | if self.tblogged(self.config.saved_tensors[key], idx): 72 | self.save_tensor(key, value, idx) 73 | 74 | def tblogged(self, logging_index, idx): 75 | if logging_index is None: 76 | return True 77 | elif self.seq_name not in logging_index: 78 | return False 79 | elif logging_index[self.seq_name] is None: 80 | return True 81 | else: 82 | return idx in logging_index[self.seq_name] 83 | 84 | def save_submission(self, flow, file_idx): 85 | os.makedirs(os.path.join(self.path, "submission"), exist_ok=True) 86 | if isinstance(flow, torch.Tensor): 87 | flow = flow.cpu().numpy() 88 | assert flow.shape == (1, 2, 480, 640) 89 | flow = flow.squeeze() 90 | _, h, w = flow.shape 91 | scaled_flow = np.rint( 92 | flow*128 + 2**15).astype(np.uint16).transpose(1, 2, 0) 93 | flow_image = np.concatenate((scaled_flow, np.zeros((h, w, 1), 94 | dtype=np.uint16)), axis=-1) 95 | imageio.imwrite(os.path.join(self.path, "submission", f"{file_idx:06d}.png"), 96 | flow_image, format='PNG-FI') 97 | 98 | def save_tensor(self, key, value, idx): 99 | if isinstance(value, torch.Tensor): 100 | value = value.cpu().numpy() 101 | np.save(os.path.join(self.path, f"{key}_{idx:05d}.npy"), value) 102 | return 103 | if isinstance(value, np.ndarray): 104 | np.save(os.path.join(self.path, f"{key}_{idx:05d}.npy"), value) 105 | return 106 | 107 | def log_metrics(self, results): 108 | self.results = results 109 | 110 | def compute_statistics(self): 111 | self.metric_stats = dict() 112 | for 
quantity, metrics in self.results.items(): 113 | self.metric_stats[quantity] = dict() 114 | for metric, list_fm in metrics.items(): 115 | assert isinstance(list_fm[0], fm) 116 | self.metric_stats[quantity][metric] = dict() 117 | metric_list = list(map(lambda x: x.value, list_fm)) 118 | self.metric_stats[quantity][metric]['avg'] = sum( 119 | metric_list) / len(metric_list) 120 | 121 | @contextmanager 122 | def record_sequence(self, seq_name, log_path): 123 | try: 124 | os.makedirs(os.path.join(log_path, seq_name), exist_ok=True) 125 | self.seq_logger = self.SeqLogger( 126 | self.config, os.path.join(log_path, seq_name), seq_name) 127 | yield self.seq_logger 128 | finally: 129 | self.seq_logger.compute_statistics() 130 | self.copy_metrics(self.seq_logger.results, self.seq_logger.metric_stats, 131 | seq_name) 132 | 133 | @contextmanager 134 | def log_test(self, model): 135 | log_dir = self.config.get('save_dir', None) 136 | with tempfile.TemporaryDirectory(prefix='id2log', dir='/tmp') \ 137 | if log_dir is None else nullcontext(pathlib.Path(log_dir)) as wd: 138 | try: 139 | yield wd 140 | finally: 141 | self.dump(model, wd) 142 | 143 | def copy_metrics(self, results, stats, seq_name): 144 | self.metrics[seq_name] = dict() 145 | self.metrics[seq_name]['results'] = results 146 | self.metrics[seq_name]['stats'] = stats 147 | 148 | def dump(self, model, wd): 149 | 150 | dict_to_dump = {"meta": { 151 | "test_name": self.test_name, 152 | }, **self.metrics} 153 | json.dump(dict_to_dump, open(os.path.join( 154 | wd, 'metrics.json'), 'w')) 155 | torch.save(model.state_dict(), os.path.join(wd, 'model.pt')) 156 | 157 | def summary(self): 158 | summary_flat = dict() 159 | summary = dict() 160 | for seq_name, results in self.metrics.items(): 161 | summary[seq_name] = dict() 162 | for quantity, metrics in results["stats"].items(): 163 | summary[seq_name][quantity] = metrics 164 | for metric, stats in metrics.items(): 165 | for stat, value in stats.items(): 166 | assert 
isinstance(value, (int, float)) 167 | summary_flat['-'.join([seq_name, 168 | quantity, metric, stat])] = value 169 | return summary 170 | 171 | -------------------------------------------------------------------------------- /idn/utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | from ..model.loss import sparse_l1 2 | 3 | def get_loss_fn_by_name(loss_name): 4 | if loss_name == 'sparse_l1': 5 | return lambda estimate, ground_truth: sparse_l1(estimate, \ 6 | ground_truth.frame, valid_mask=ground_truth.mask) 7 | else: 8 | assert False, f"loss {loss_name} not implemented" 9 | 10 | 11 | def get_valid_loss_fn_by_name(config): 12 | valid_loss_fn = dict() 13 | # 14 | if config.name == 'sparse_l1': 15 | return sparse_l1 16 | 17 | def compute_seq_loss(weight, loss_fn, estimate, ground_truth): 18 | assert isinstance(estimate, list) 19 | if weight == "last": 20 | return loss_fn(estimate[-1], ground_truth[-1]) 21 | else: 22 | seq_loss = list(map(loss_fn, estimate, ground_truth)) 23 | if weight == "sum": 24 | return sum(seq_loss) 25 | 26 | elif weight == "avg": 27 | return sum(seq_loss)/len(seq_loss) 28 | 29 | elif hasattr(weight, '__getitem__') and isinstance(weight[0], float): 30 | assert len(weight) == len(estimate) 31 | return sum(map(lambda x, y: x*y, seq_loss, weight)) 32 | else: 33 | assert False, f"weight {weight} for seq loss not supported" -------------------------------------------------------------------------------- /idn/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | def get_model_by_name(name, model_config): 2 | if name == "RecIDE": 3 | from ..model.idedeq import RecIDE 4 | return RecIDE(model_config) 5 | elif name == "IDEDEQIDO": 6 | from ..model.idedeq import IDEDEQIDO 7 | return IDEDEQIDO(model_config) 8 | else: 9 | raise ValueError("Unknown model name: {}".format(name)) 10 | 
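`compute_seq_loss` in `idn/utils/loss_utils.py` reduces a list of per-step losses according to `weight`: `"last"` scores only the final estimate, `"sum"` and `"avg"` reduce over every step, and a list of floats gives a weighted sum. The sketch below mirrors that dispatch on plain Python floats so the modes can be compared side by side. The names `seq_loss` and `abs_error` are hypothetical stand-ins (the latter for `sparse_l1`), and raising `ValueError` instead of `assert False` is a deliberate deviation; treat it as an illustration, not the repository's code.

```python
from typing import Callable, Sequence, Union

# A weight is either a mode name or one weight per refinement step.
Weight = Union[str, Sequence[float]]


def seq_loss(weight: Weight, loss_fn: Callable[[float, float], float],
             estimate: Sequence[float], ground_truth: Sequence[float]) -> float:
    """Reduce per-step losses like compute_seq_loss (hypothetical sketch)."""
    if weight == "last":
        # Only the final refinement step contributes.
        return loss_fn(estimate[-1], ground_truth[-1])
    per_step = [loss_fn(e, g) for e, g in zip(estimate, ground_truth)]
    if weight == "sum":
        return sum(per_step)
    if weight == "avg":
        return sum(per_step) / len(per_step)
    if not isinstance(weight, str) and len(weight) == len(estimate):
        # Explicit per-step weights, e.g. to emphasize later iterations.
        return sum(w * l for w, l in zip(weight, per_step))
    raise ValueError(f"weight {weight!r} for seq loss not supported")


def abs_error(estimate: float, target: float) -> float:
    # Hypothetical scalar stand-in for sparse_l1.
    return abs(estimate - target)


estimates = [3.0, 2.0, 1.5]  # successive refined predictions
targets = [1.0, 1.0, 1.0]    # same target at every step

print(seq_loss("last", abs_error, estimates, targets))  # 0.5
print(seq_loss("sum", abs_error, estimates, targets))   # 3.5
print(seq_loss("avg", abs_error, estimates, targets))
print(seq_loss([0.2, 0.3, 0.5], abs_error, estimates, targets))
```

Increasing weights over the steps, as in the last call, is one way to supervise every intermediate iteration while still emphasizing the final estimate.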
-------------------------------------------------------------------------------- /idn/utils/mvsec_utils.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import os 3 | import pandas 4 | from PIL import Image 5 | import random 6 | import h5py 7 | import json 8 | 9 | 10 | class EventSequence(object): 11 | def __init__(self, dataframe, params, features=None, timestamp_multiplier=None, convert_to_relative=False): 12 | if isinstance(dataframe, pandas.DataFrame): 13 | self.feature_names = dataframe.columns.values 14 | self.features = dataframe.to_numpy() 15 | else: 16 | self.feature_names = numpy.array(['ts', 'x', 'y', 'p'], dtype=object) 17 | if features is None: 18 | self.features = numpy.zeros([1, 4]) 19 | else: 20 | self.features = features 21 | self.image_height = params['height'] 22 | self.image_width = params['width'] 23 | if not self.is_sorted(): 24 | self.sort_by_timestamp() 25 | if timestamp_multiplier is not None: 26 | self.features[:,0] *= timestamp_multiplier 27 | if convert_to_relative: 28 | self.absolute_time_to_relative() 29 | 30 | def get_sequence_only(self): 31 | return self.features 32 | 33 | def __len__(self): 34 | return len(self.features) 35 | 36 | def __add__(self, sequence): 37 | event_sequence = EventSequence(dataframe=None, 38 | features=numpy.concatenate([self.features, sequence.features]), 39 | params={'height': self.image_height, 40 | 'width': self.image_width}) 41 | return event_sequence 42 | 43 | def is_sorted(self): 44 | return numpy.all(self.features[:-1, 0] <= self.features[1:, 0]) 45 | 46 | def sort_by_timestamp(self): 47 | if len(self.features[:, 0]) > 0: 48 | sort_indices = numpy.argsort(self.features[:, 0]) 49 | self.features = self.features[sort_indices] 50 | 51 | def absolute_time_to_relative(self): 52 | """Transforms absolute time to time relative to the first event.""" 53 | start_ts = self.features[:,0].min() 54 | assert(start_ts == self.features[0,0]) 55 | 
self.features[:,0] -= start_ts 56 | 57 | 58 | def get_image(image_path): 59 | try: 60 | im = Image.open(image_path) 61 | # print(image_path) 62 | return numpy.array(im) 63 | except OSError: 64 | raise 65 | 66 | 67 | def get_events(event_path): 68 | # It's possible that there is no event file! (camera standing still) 69 | try: 70 | f = pandas.read_hdf(event_path, "myDataset") 71 | return f[['ts', 'x', 'y', 'p']] 72 | except OSError: 73 | print("No file " + event_path) 74 | print("Creating an array of zeros!") 75 | return 0 76 | 77 | 78 | def get_ts(path, i, type='int'): 79 | with open(path, "r") as f: # the file is closed even if reading fails 80 | line = f.readlines()[i] 81 | if type == 'int': 82 | return int(line) 83 | elif type == 'double' or type == 'float': 84 | return float(line) 85 | else: 86 | raise ValueError("Unsupported timestamp type: {}".format(type)) 87 | 88 | 89 | def get_batchsize(path_dataset): 90 | filepath = os.path.join(path_dataset, "cam0", "timestamps.txt") 91 | try: 92 | with open(filepath, "r") as f: 93 | return len(f.readlines()) 94 | except OSError: 95 | raise 96 | 97 | 98 | def get_batch(path_dataset, i): 99 | return 0 100 | 101 | 102 | def dataset_paths(dataset_name, path_dataset, subset_number=None): 103 | cameras = {'cam0': {}, 'cam1': {}, 'cam2': {}, 'cam3': {}} 104 | if subset_number is not None: 105 | dataset_name = dataset_name + "_" + str(subset_number) 106 | paths = {'dataset_folder': os.path.join(path_dataset, dataset_name)} 107 | 108 | # For every camera, define its path 109 | for camera in cameras: 110 | cameras[camera]['image_folder'] = os.path.join(paths['dataset_folder'], camera, 'image_raw') 111 | cameras[camera]['event_folder'] = os.path.join(paths['dataset_folder'], camera, 'events') 112 | cameras[camera]['disparity_folder'] = os.path.join(paths['dataset_folder'], camera, 'disparity_image') 113 | cameras[camera]['depth_folder'] = os.path.join(paths['dataset_folder'], camera, 'depthmap') 114 | cameras["timestamp_file"] = os.path.join(paths['dataset_folder'], 'cam0', 'timestamps.txt') 115 | cameras["image_type"] = ".png"
116 | cameras["event_type"] = ".h5" 117 | cameras["disparity_type"] = ".png" 118 | cameras["depth_type"] = ".tiff" 119 | cameras["indexing_type"] = "%0.6i" 120 | paths.update(cameras) 121 | return paths 122 | 123 | 124 | def get_indices(path_dataset, dataset, filter, shuffle=False): 125 | samples = [] 126 | for dataset_name in dataset: 127 | for subset in dataset[dataset_name]: 128 | # Get all the dataframe paths 129 | paths = dataset_paths(dataset_name, path_dataset, subset) 130 | 131 | # import timestamps 132 | ts = numpy.loadtxt(paths["timestamp_file"]) 133 | 134 | # frames = [] 135 | # For every timestamp, import the corresponding data 136 | for idx in eval(filter[dataset_name][str(subset)]): 137 | frame = {} 138 | frame['dataset_name'] = dataset_name 139 | frame['subset_number'] = subset 140 | frame['index'] = idx 141 | frame['timestamp'] = ts[idx] 142 | samples.append(frame) 143 | # shuffle dataset 144 | if shuffle: 145 | random.shuffle(samples) 146 | return samples 147 | 148 | 149 | def get_flow_h5(flow_path): 150 | scaling_factor = 0.05 # seconds/frame 151 | with h5py.File(flow_path, 'r') as f: 152 | height, width = int(f['header']['height']), int(f['header']['width']) 153 | assert(len(f['x']) == height*width) 154 | assert(len(f['y']) == height*width) 155 | x = numpy.array(f['x']).reshape([height,width])*scaling_factor 156 | y = numpy.array(f['y']).reshape([height,width])*scaling_factor 157 | return numpy.stack([x,y]) 158 | 159 | 160 | def get_flow_npy(flow_path): 161 | # Array 2,height, width 162 | # No scaling needed.
163 | return numpy.load(flow_path, allow_pickle=True) 164 | 165 | 166 | def get_pose(pose_path, index): 167 | pose = pandas.read_csv(pose_path, delimiter=',').loc[index].to_numpy() 168 | # Convert Timestamp to int (as all the other timestamps) 169 | pose[0] = int(pose[0]) 170 | return pose 171 | 172 | 173 | def load_config(path, datasets): 174 | config = {} 175 | for dataset_name in datasets: 176 | config[dataset_name] = {} 177 | for subset in datasets[dataset_name]: 178 | name = "{}_{}".format(dataset_name, subset) 179 | try: 180 | with open(os.path.join(path, name, "config.json")) as f: 181 | config[dataset_name][subset] = json.load(f) 182 | except OSError: 183 | print("Could not find config file for dataset " + name + ". Please check that the file 'config.json' exists in the dataset-scene directory") 184 | raise 185 | return config 186 | -------------------------------------------------------------------------------- /idn/utils/retrieval_fn.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch.nn.functional as F 3 | 4 | 5 | def upflow8(flow, mode='bilinear'): 6 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 7 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 8 | 9 | def get_retreival_fn(quantity): 10 | # Each quantity maps to a function returning (estimate, masked ground truth), or lists of them for sequences 11 | if quantity == "final_prediction": 12 | return retreival_pred_seq_1 13 | elif quantity == "pred_flow_seq": 14 | return retreival_pred_seq 15 | elif quantity == "final_prediction_nonseq": 16 | return retreival_pred_nonseq 17 | elif quantity == "pred_flow_next_seq": 18 | return retreival_pred_nextflow_seq 19 | elif quantity == "next_flow": 20 | return retreival_next_flow 21 | elif quantity == "pred_lowres_flow_next_seq": 22 | return retreival_pred_lowres_nextflow_seq 23 | 24 | raise NotImplementedError(f"quantity {quantity} not implemented") 25 | 26 | 27 | def retreival_pred_nonseq(out,
batch): 28 | fmask = namedtuple("masked_frame", ["frame", "mask"]) 29 | return (out["final_prediction"], fmask(batch["flow_gt_event_volume_new"], 30 | batch["flow_gt_event_volume_new_valid_mask"])) 31 | 32 | 33 | def retreival_pred_seq_1(out, batch): 34 | fmask = namedtuple("masked_frame", ["frame", "mask"]) 35 | return (out["final_prediction"], fmask(batch[-1]["flow_gt_event_volume_new"], 36 | batch[-1]["flow_gt_event_volume_new_valid_mask"])) 37 | 38 | 39 | def retreival_pred_seq(out, batch): 40 | fmask = namedtuple("masked_frame", ["frame", "mask"]) 41 | return (out["flow_trajectory"], [fmask(x["flow_gt_event_volume_new"], 42 | x["flow_gt_event_volume_new_valid_mask"]) for x in batch]) 43 | 44 | 45 | def retreival_pred_nextflow_seq(out, batch): 46 | fmask = namedtuple("masked_frame", ["frame", "mask"]) 47 | return (out["flow_next_trajectory"], [fmask(x["flow_gt_next"], 48 | x["flow_gt_next_valid_mask"]) for x in batch]) 49 | 50 | 51 | def retreival_next_flow(out, batch): 52 | fmask = namedtuple("masked_frame", ["frame", "mask"]) 53 | return (out["next_flow"], fmask(batch[-1]["flow_gt_next"], 54 | batch[-1]["flow_gt_next_valid_mask"])) 55 | 56 | 57 | def retreival_pred_lowres_nextflow_seq(out, batch): 58 | fmask = namedtuple("masked_frame", ["frame", "mask"]) 59 | return ([upflow8(x) for x in out["flow_next_trajectory"][:-1]], [fmask(x["flow_gt_event_volume_new"], 60 | x["flow_gt_event_volume_new_valid_mask"]) for x in batch[1:]]) 61 | -------------------------------------------------------------------------------- /idn/utils/torch_environ.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | 5 | 6 | def config_torch(config): 7 | torch.backends.cudnn.benchmark = True 8 | if config.get("debug_grad", False): 9 | torch_set_gradient_debug() 10 | if config.get("deterministic", True): 11 | torch_set_deterministic() 12 | 13 | 14 | def torch_set_deterministic(seed=141021): 15 | 
torch.manual_seed(seed) 16 | np.random.seed(seed) 17 | random.seed(seed) 18 | torch.backends.cudnn.deterministic = True 19 | torch.backends.cudnn.benchmark = False 20 | # because of .put_ in forward_interpolate_pytorch 21 | torch.use_deterministic_algorithms(False) 22 | 23 | 24 | def torch_set_gradient_debug(): 25 | torch.autograd.set_detect_anomaly(True) 26 | -------------------------------------------------------------------------------- /idn/utils/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import DataLoader, ConcatDataset 4 | from tqdm import tqdm 5 | from types import GeneratorType 6 | from collections import namedtuple 7 | from torchinfo import summary 8 | from torch.optim.lr_scheduler import OneCycleLR 9 | 10 | from .torch_environ import config_torch 11 | from .helper_functions import move_batch_to_cuda 12 | from .model_utils import get_model_by_name 13 | from .loss_utils import compute_seq_loss, get_loss_fn_by_name 14 | from .validation import Validator 15 | from .callbacks import CallbackBridge 16 | from .exp_tracker import ExpTracker 17 | from .retrieval_fn import get_retreival_fn 18 | from ..loader.loader_dsec import ( 19 | Sequence, 20 | RepresentationType, 21 | DatasetProvider, 22 | assemble_dsec_sequences, 23 | assemble_dsec_test_set, 24 | train_collate, 25 | rec_train_collate 26 | ) 27 | from ..loader.loader_mvsec import ( 28 | MVSEC, 29 | MVSECRecurrent 30 | ) 31 | 32 | 33 | class Trainer(CallbackBridge): 34 | def __init__(self, config, model=None): 35 | super().__init__() 36 | self.config = config 37 | config_torch(config.torch) 38 | self.model = model if model is not None else \ 39 | get_model_by_name(config.model.name, config.model) 40 | if config.model.get("pretrain_ckpt", None): 41 | self.resume_model_from_ckpt(config.model.pretrain_ckpt) 42 | 43 | if not self.config.get("eval_only", False): 44 | self.train_dataloader = 
self.configure_train_dataloader() 45 | self.configure_loss() 46 | self.optimizer = self.configure_optimizer() 47 | if config.optim.get("scheduler", None): 48 | self.scheduler = OneCycleLR( 49 | self.optimizer, max_lr=self.config.optim.lr, 50 | steps_per_epoch=len(self.train_dataloader), 51 | epochs=self.config.num_epoch, 52 | pct_start=0.05, cycle_momentum=False, anneal_strategy='linear', 53 | ) 54 | else: 55 | self.scheduler = None 56 | 57 | self.epoch = 0 58 | self.step = 0 59 | self.batches_seen = 0 60 | self.samples_seen = 0 61 | if config.get("resume_ckpt", None): 62 | resume_only_model = config.get("finetune", False) 63 | self.resume_from_ckpt(config.resume_ckpt, resume_only_model) 64 | self.logger = self.configure_tracker() 65 | 66 | self.configure_callbacks(config.callbacks) 67 | 68 | self.execute_callbacks("on_init_end") 69 | 70 | def configure_train_dataloader(self): 71 | if self.config.dataset.dataset_name == "dsec": 72 | train_set = assemble_dsec_sequences( 73 | self.config.dataset.common.data_root, 74 | exclude_seq=set( 75 | [val_seq for x in self.config.get("validation", dict()).values() for val_seq in x.dataset.val.seq]), 76 | require_gt=True, 77 | config=self.config.dataset.train, 78 | representation_type=self.config.dataset.get("representation_type", None), 79 | num_bins=self.config.dataset.get("num_voxel_bins", None) 80 | ) 81 | 82 | elif self.config.dataset.dataset_name == "mvsec": 83 | train_set = MVSEC("outdoor_day2", num_bins=self.config.dataset.get("num_voxel_bins", None), dt=None) #20Hz 84 | elif self.config.dataset.dataset_name == "mvsec_recurrent": 85 | train_set = MVSECRecurrent("outdoor_day2", augment=False, 86 | sequence_length=self.config.dataset.train.sequence_length) 87 | else: 88 | raise NotImplementedError 89 | collate_fn = rec_train_collate \ 90 | if hasattr(self.config.dataset.train, "recurrent") \ 91 | and self.config.dataset.train.recurrent else train_collate 92 | return DataLoader( 93 | train_set, collate_fn=collate_fn, 
**self.config.data_loader.train.args) 94 | 95 | def configure_optimizer(self): 96 | if self.config.optim.optimizer == "adam": 97 | return torch.optim.Adam(self.model.parameters(), lr=self.config.optim.lr) 98 | elif self.config.optim.optimizer == "adamw": 99 | return torch.optim.AdamW(self.model.parameters(), lr=self.config.optim.lr) 100 | else: 101 | raise NotImplementedError 102 | 103 | def resume_model_from_ckpt(self, ckpt): 104 | ckpt = torch.load(ckpt, map_location='cpu') 105 | if "model" in ckpt: 106 | self.model.load_state_dict(ckpt["model"]) 107 | elif "model_state_dict" in ckpt: 108 | self.model.load_state_dict(ckpt["model_state_dict"]) 109 | else: 110 | try: 111 | self.model.load_state_dict(ckpt) 112 | except Exception as e: 113 | raise ValueError("Invalid checkpoint") from e 114 | 115 | def resume_from_ckpt(self, ckpt, resume_only_model=False): 116 | ckpt = torch.load(ckpt, map_location='cpu') 117 | self.model.load_state_dict(ckpt['model_state_dict']) 118 | if not resume_only_model: 119 | self.optimizer.load_state_dict(ckpt['optimizer_state_dict']) 120 | self.epoch = ckpt['epoch'] 121 | if self.scheduler and 'scheduler_state_dict' in ckpt: 122 | self.scheduler.load_state_dict(ckpt['scheduler_state_dict']) 123 | if "tracker" in ckpt: 124 | self.logged_tracker = ckpt['tracker'] 125 | 126 | def configure_tracker(self): 127 | return ExpTracker() 128 | 129 | def configure_loss(self): 130 | lc = namedtuple("loss_config", 131 | ["retrieval_fn", "loss_fn", "weight", "seq_weight", "seq_norm"]) 132 | self.loss_config = dict() 133 | for quantity, config in self.config.loss.items(): 134 | self.loss_config[quantity] = lc( 135 | get_retreival_fn(quantity), 136 | get_loss_fn_by_name(config.loss_type), 137 | config.get("weight", 1.0), 138 | config.get("seq_weight", None), 139 | config.get("seq_norm", False)) 140 | 141 | def train_epoch(self): 142 | self.execute_callbacks("on_epoch_begin") 143 | self.model.train() 144 | self.model.cuda(self.config.data_loader.train.gpu) 145 | # This is necessary to sync optimizer
parameter device with model device 146 | self.optimizer.load_state_dict(self.optimizer.state_dict()) 147 | for batch in tqdm(self.train_dataloader): 148 | self.execute_callbacks("on_batch_begin") 149 | self.optimizer.zero_grad() 150 | batch = move_batch_to_cuda( 151 | batch, self.config.data_loader.train.gpu) 152 | out = self.model(batch) 153 | if isinstance(out, GeneratorType): 154 | loss_item = [] 155 | for i, ret in enumerate(out): 156 | seq_len = len(ret["flow_trajectory"]) 157 | loss, loss_breakdown = self.compute_loss( 158 | ret, batch[i*seq_len:(i+1)*seq_len]) 159 | loss.backward() 160 | loss_item.append(loss.detach().item()) 161 | self.loss = sum(loss_item)/len(loss_item) 162 | self.loss_1 = \ 163 | loss_item[0] 164 | for loss_type, l in loss_breakdown.items(): 165 | setattr(self, 'loss_'+loss_type, l.item()) 166 | 167 | else: 168 | self.loss, _ = self.compute_loss(out, batch) 169 | self.loss.backward() 170 | self.execute_callbacks("on_step_begin") 171 | self.optimizer.step() 172 | self.execute_callbacks("on_step_end") 173 | self.execute_callbacks("on_batch_end") 174 | self.step += 1 175 | if self.scheduler: 176 | self.scheduler.step() 177 | self.lr = self.scheduler.get_last_lr()[0] 178 | self.execute_callbacks("on_epoch_end") 179 | self.epoch += 1 180 | 181 | def compute_loss(self, ret, batch): 182 | loss = dict() 183 | total_loss = 0 184 | for quantity, config in self.loss_config.items(): 185 | estimate, ground_truth = config.retrieval_fn(ret, batch) 186 | if isinstance(estimate, list): 187 | loss_fn = lambda estimate, ground_truth: \ 188 | compute_seq_loss(config.seq_weight, config.loss_fn, 189 | estimate, ground_truth) 190 | else: 191 | loss_fn = config.loss_fn 192 | 193 | loss[quantity] = loss_fn(estimate, ground_truth) 194 | total_loss += config.weight * loss[quantity] 195 | return total_loss, loss 196 | 197 | def fit(self, epochs=None): 198 | num_epochs = epochs if epochs is not None else self.config.num_epoch 199 | 
self.execute_callbacks("on_train_begin") 200 | try: 201 | while self.epoch < num_epochs: 202 | self.train_epoch() 203 | except Exception as e: 204 | raise RuntimeError("Training failed") from e 205 | finally: 206 | self.execute_callbacks("on_train_end") 207 | 208 | -------------------------------------------------------------------------------- /idn/utils/transformers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import re 5 | from torchvision.transforms import RandomCrop 6 | import torchvision.transforms.functional as TF 7 | 8 | 9 | def dictionary_of_numpy_arrays_to_tensors(sample): 10 | """Transforms dictionary of numpy arrays to dictionary of tensors.""" 11 | if isinstance(sample, dict): 12 | return { 13 | key: dictionary_of_numpy_arrays_to_tensors(value) 14 | for key, value in sample.items() 15 | } 16 | if isinstance(sample, np.ndarray): 17 | if len(sample.shape) == 2: 18 | return torch.from_numpy(sample).float().unsqueeze(0) 19 | else: 20 | return torch.from_numpy(sample).float() 21 | return sample 22 | 23 | 24 | class EventSequenceToVoxelGrid_Pytorch(object): 25 | # Source: https://github.com/uzh-rpg/rpg_e2vid/blob/master/utils/inference_utils.py#L480 26 | def __init__(self, num_bins, gpu=False, gpu_nr=1, normalize=True, forkserver=True): 27 | if forkserver: 28 | try: 29 | torch.multiprocessing.set_start_method('forkserver') 30 | except RuntimeError: 31 | pass 32 | self.num_bins = num_bins 33 | self.normalize = normalize 34 | if gpu and not torch.cuda.is_available(): 35 | print('Warning: There\'s no CUDA support on this machine! Falling back to CPU.') 36 | gpu = False 37 | if gpu: 38 | self.device = torch.device('cuda:' + str(gpu_nr)) 39 | else: 40 | self.device = torch.device('cpu') 41 | 42 | def __call__(self, event_sequence): 43 | """ 44 | Build a voxel grid with bilinear interpolation in the time domain from a set of events.
45 | :param event_sequence: an EventSequence whose .features is an [N x 4] NumPy array with one event per row in the form: [timestamp, x, y, polarity] 46 | The number of temporal bins is taken from self.num_bins, 47 | the grid width and height from event_sequence.image_width and event_sequence.image_height, 48 | and all computations run on self.device. 49 | :return voxel_grid: PyTorch tensor of shape [num_bins, height, width] on self.device 50 | """ 51 | 52 | events = event_sequence.features.astype('float') 53 | 54 | width = event_sequence.image_width 55 | height = event_sequence.image_height 56 | 57 | assert (events.shape[1] == 4) 58 | assert (self.num_bins > 0) 59 | assert (width > 0) 60 | assert (height > 0) 61 | 62 | with torch.no_grad(): 63 | 64 | events_torch = torch.from_numpy(events) 65 | # with DeviceTimer('Events -> Device (voxel grid)'): 66 | events_torch = events_torch.to(self.device) 67 | 68 | # with DeviceTimer('Voxel grid voting'): 69 | voxel_grid = torch.zeros( 70 | self.num_bins, height, width, dtype=torch.float32, device=self.device).flatten() 71 | 72 | # normalize the event timestamps so that they lie between 0 and num_bins - 1 73 | last_stamp = events_torch[-1, 0] 74 | first_stamp = events_torch[0, 0] 75 | 76 | assert last_stamp.dtype == torch.float64, 'Timestamps must be float64!'
77 | # assert last_stamp.item()%1 == 0, 'Timestamps should not have decimals' 78 | 79 | deltaT = last_stamp - first_stamp 80 | 81 | if deltaT == 0: 82 | deltaT = 1.0 83 | 84 | events_torch[:, 0] = (self.num_bins - 1) * \ 85 | (events_torch[:, 0] - first_stamp) / deltaT 86 | ts = events_torch[:, 0] 87 | xs = events_torch[:, 1].long() 88 | ys = events_torch[:, 2].long() 89 | pols = events_torch[:, 3].float() 90 | pols[pols == 0] = -1 # polarity should be +1 / -1 91 | 92 | tis = torch.floor(ts) 93 | tis_long = tis.long() 94 | dts = ts - tis 95 | vals_left = pols * (1.0 - dts.float()) 96 | vals_right = pols * dts.float() 97 | 98 | valid_indices = tis < self.num_bins 99 | valid_indices &= tis >= 0 100 | 101 | if events_torch.is_cuda: 102 | datatype = torch.cuda.LongTensor 103 | else: 104 | datatype = torch.LongTensor 105 | 106 | voxel_grid.index_add_(dim=0, 107 | index=(xs[valid_indices] + ys[valid_indices] 108 | * width + tis_long[valid_indices] * width * height).type( 109 | datatype), 110 | source=vals_left[valid_indices]) 111 | 112 | valid_indices = (tis + 1) < self.num_bins 113 | valid_indices &= tis >= 0 114 | 115 | voxel_grid.index_add_(dim=0, 116 | index=(xs[valid_indices] + ys[valid_indices] * width 117 | + (tis_long[valid_indices] + 1) * width * height).type(datatype), 118 | source=vals_right[valid_indices]) 119 | 120 | voxel_grid = voxel_grid.view(self.num_bins, height, width) 121 | 122 | if self.normalize: 123 | mask = torch.nonzero(voxel_grid, as_tuple=True) 124 | if mask[0].size()[0] > 0: 125 | mean = voxel_grid[mask].mean() 126 | std = voxel_grid[mask].std() 127 | if std > 0: 128 | voxel_grid[mask] = (voxel_grid[mask] - mean) / std 129 | else: 130 | voxel_grid[mask] = voxel_grid[mask] - mean 131 | 132 | return voxel_grid 133 | 134 | 135 | def apply_transform_to_field(sample, func, field_name): 136 | """ 137 | Applies a function to a field of a sample. 
138 | :param sample: a sample 139 | :param func: a function that takes a numpy array and returns a numpy array 140 | :param field_name: the name of the field to transform 141 | :return: the transformed sample 142 | """ 143 | if isinstance(sample, dict): 144 | for key, value in sample.items(): 145 | if bool(re.search(field_name, key)): 146 | sample[key] = func(value) 147 | 148 | if isinstance(sample, np.ndarray) or isinstance(sample, torch.Tensor): 149 | return func(sample) 150 | return sample 151 | 152 | 153 | def apply_randomcrop_to_sample(sample, crop_size): 154 | """ 155 | Applies a random crop to a sample. 156 | :param sample: a sample 157 | :param crop_size: the size of the crop 158 | :return: the cropped sample 159 | """ 160 | i, j, h, w = RandomCrop.get_params( 161 | sample["event_volume_old"], output_size=crop_size) 162 | keys_to_crop = ["event_volume_old", "event_volume_new", 163 | "flow_gt_event_volume_old", "flow_gt_event_volume_new", "reverse_flow_gt_event_volume_old", "reverse_flow_gt_event_volume_new"] 164 | 165 | for key, value in sample.items(): 166 | if key in keys_to_crop: 167 | if isinstance(value, torch.Tensor): 168 | sample[key] = TF.crop(value, i, j, h, w) 169 | elif isinstance(value, list) or isinstance(value, tuple): 170 | sample[key] = [TF.crop(v, i, j, h, w) for v in value] 171 | return sample 172 | 173 | 174 | def downsample_spatial(x, factor): 175 | """ 176 | Downsample a given tensor spatially by a factor. 177 | :param x: PyTorch tensor of shape [batch, num_bins, height, width] 178 | :param factor: downsampling factor 179 | :return: PyTorch tensor of shape [batch, num_bins, height/factor, width/factor] 180 | """ 181 | assert (factor > 0), 'Factor must be positive!' 182 | 183 | assert (x.shape[-1] % 184 | factor == 0), 'Width of x must be divisible by factor!' 185 | assert (x.shape[-2] % 186 | factor == 0), 'Height of x must be divisible by factor!' 
187 | 188 | return nn.AvgPool2d(kernel_size=factor, stride=factor)(x) 189 | 190 | 191 | def downsample_spatial_mask(x, factor): 192 | """ 193 | Downsample a given mask (boolean) spatially by a factor. 194 | :param x: PyTorch tensor of shape [batch, num_bins, height, width] 195 | :param factor: downsampling factor 196 | :return: PyTorch tensor of shape [batch, num_bins, height/factor, width/factor] 197 | """ 198 | assert (factor > 0), 'Factor must be positive!' 199 | 200 | assert (x.shape[-1] % 201 | factor == 0), 'Width of x must be divisible by factor!' 202 | assert (x.shape[-2] % 203 | factor == 0), 'Height of x must be divisible by factor!' 204 | 205 | return nn.AvgPool2d(kernel_size=factor, stride=factor)(x.float()) >= 0.5 206 | -------------------------------------------------------------------------------- /idn/utils/validation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | import pickle 5 | from tqdm import tqdm 6 | import importlib 7 | 8 | from .helper_functions import move_batch_to_cuda, move_list_to_cuda, move_tensor_to_cuda 9 | from ..model.loss import sparse_lnorm 10 | 11 | 12 | def log_tensors(idx, dict_of_tensors, tmp_folder): 13 | for key, value in dict_of_tensors.items(): 14 | if isinstance(value, torch.Tensor): 15 | tensor = value.cpu().numpy() 16 | np.save(os.path.join(tmp_folder, f"{key}_{idx}.npy"), tensor) 17 | continue 18 | if isinstance(value, np.ndarray): 19 | np.save(os.path.join(tmp_folder, f"{key}_{idx}.npy"), value) 20 | continue 21 | 22 | 23 | 24 | class Validator: 25 | def __init__(self, config): 26 | self.test = dict() 27 | # self.logger = Logger.from_config(config.logger, config.name) 28 | for test_name, test_spec in config.items(): 29 | # create a test object 30 | self.test[test_name] = self.get_test_type(test_name)(test_spec) 31 | 32 | @staticmethod 33 | def get_test_type(test_name, test_type=None): 34 | from ..tests import dsec as 
dsec_tests 35 | if test_name == "dsec": 36 | return dsec_tests.assemble_dsec_test_cls(test_type) 37 | if test_name == "mvsec_day1": 38 | return dsec_tests.TestMVSEC 39 | if test_name == "mvsec_day1_rec": 40 | return dsec_tests.TestMVSECCO 41 | try: 42 | return getattr(dsec_tests, f"Test{test_name.upper()}") 43 | except: 44 | raise ValueError(f"Test {test_name} not found.") 45 | 46 | def __call__(self, model, run_all=True): 47 | # run the corresponding validator 48 | results = dict() 49 | for name, test in self.test.items(): 50 | state, results[name] = test.execute_test(model) 51 | return results 52 | 53 | 54 | def validate_model_warm(data_loader, model, config, return_dict, gpu_id, test_mode=False, 55 | log_field_dict=None, log_dir=None): 56 | model.cuda(gpu_id) 57 | model.eval() 58 | 59 | metrics = dict() 60 | 61 | with torch.no_grad(): 62 | if not isinstance(data_loader, list): 63 | data_loader = [data_loader] 64 | data_loader = data_loader[0:1] 65 | for seq_loader in data_loader: 66 | for idx, batch in enumerate(tqdm(seq_loader, position=1)): 67 | # if idx == 730: 68 | # break 69 | if isinstance(batch, list): 70 | batch = move_list_to_cuda(batch, gpu_id) 71 | loss_item = batch[-1] 72 | else: 73 | move_tensor_to_cuda(batch, gpu_id) 74 | loss_item = batch 75 | if idx == 0: 76 | out = model(batch) 77 | else: 78 | pre_flow = model.forward_flow(pred) 79 | if isinstance(batch, list): 80 | batch[0]["pre_flow"] = pre_flow 81 | else: 82 | batch["pre_flow"] = pre_flow 83 | out = model(batch, pre_flow, 1) 84 | if isinstance(batch, list): 85 | batch[-1] = {**batch[-1], **out} 86 | pred = batch[-1]["final_prediction"] 87 | else: 88 | batch = {**batch, **out} 89 | pred = batch["final_prediction"] 90 | 91 | if not test_mode and idx != 0: 92 | loss_l1, emap = sparse_lnorm(1, pred, loss_item['flow_gt_event_volume_new'], 93 | loss_item['flow_gt_event_volume_new_valid_mask'], 94 | per_frame=True) 95 | loss_l2, _ = sparse_lnorm(2, pred, loss_item['flow_gt_event_volume_new'], 96 | 
loss_item['flow_gt_event_volume_new_valid_mask'], 97 | per_frame=True) 98 | loss_l1_pre, pre_emap = sparse_lnorm(1, pre_flow, loss_item['flow_gt_event_volume_new'], 99 | loss_item['flow_gt_event_volume_new_valid_mask'], 100 | per_frame=True) 101 | loss_l2_pre, _ = sparse_lnorm(2, pre_flow, loss_item['flow_gt_event_volume_new'], 102 | loss_item['flow_gt_event_volume_new_valid_mask'], 103 | per_frame=True) 104 | if isinstance(batch, list): 105 | batch = batch[-1] 106 | batch['emap'] = emap 107 | batch['pre_emap'] = pre_emap 108 | for i, sample_seq in enumerate(batch['seq_name']): 109 | if sample_seq not in metrics.keys(): 110 | metrics[sample_seq] = dict() 111 | metrics[sample_seq]["l2"] = [] 112 | metrics[sample_seq]["l1"] = [] 113 | metrics[sample_seq]["l2_pre"] = [] 114 | metrics[sample_seq]["l1_pre"] = [] 115 | metrics[sample_seq]["l2"].append(loss_l2[i]) 116 | metrics[sample_seq]["l1"].append(loss_l1[i]) 117 | metrics[sample_seq]["l2_pre"].append(loss_l2_pre[i]) 118 | metrics[sample_seq]["l1_pre"].append(loss_l1_pre[i]) 119 | 120 | if log_dir: 121 | log_tensors_dict = dict() 122 | for key, value in log_field_dict.items(): 123 | if isinstance(batch, list): 124 | batch = batch[-1] 125 | if value in batch.keys(): 126 | log_tensors_dict[key] = batch[value] 127 | log_tensors(idx, log_tensors_dict, log_dir) 128 | 129 | for seqs in metrics.keys(): 130 | metrics[seqs]["l2_avg"] = np.array(metrics[seqs]["l2"]).mean() 131 | metrics[seqs]["l1_avg"] = np.array(metrics[seqs]["l1"]).mean() 132 | return_dict[seqs] = metrics[seqs] 133 | 134 | if log_dir: 135 | with open(os.path.join(log_dir, "metrics.pkl"), "wb") as f: 136 | pickle.dump(metrics, f) 137 | 138 | 139 | def validate_model(data_loader, model, config, return_dict, gpu_id, test_mode=False, 140 | log_field_dict=None, log_dir=None): 141 | model.cuda(gpu_id) 142 | model.eval() 143 | 144 | metrics = dict() 145 | 146 | with torch.no_grad(): 147 | if not isinstance(data_loader, list): 148 | data_loader = [data_loader] 149 | 
data_loader = data_loader[0:1] # TODO: remove hard coded first seq 150 | for seq_loader in data_loader: 151 | for idx, batch in enumerate(tqdm(seq_loader, position=1)): 152 | if isinstance(batch, list): 153 | batch = move_list_to_cuda(batch, gpu_id) 154 | loss_item = batch[-1] 155 | else: 156 | move_tensor_to_cuda(batch, gpu_id) 157 | loss_item = batch 158 | out = model(batch) 159 | if not isinstance(batch, list): 160 | batch = {**batch, **out} 161 | pred = batch["final_prediction"] 162 | else: 163 | pred = out["final_prediction"] 164 | batch = batch[-1] 165 | if not test_mode: 166 | loss_l1, _ = sparse_lnorm(1, pred, loss_item['flow_gt_event_volume_new'], 167 | loss_item['flow_gt_event_volume_new_valid_mask'], 168 | per_frame=True) 169 | loss_l2, _ = sparse_lnorm(2, pred, loss_item['flow_gt_event_volume_new'], 170 | loss_item['flow_gt_event_volume_new_valid_mask'], 171 | per_frame=True) 172 | for i, sample_seq in enumerate(batch['seq_name']): 173 | if sample_seq not in metrics.keys(): 174 | metrics[sample_seq] = dict() 175 | metrics[sample_seq]["l2"] = [] 176 | metrics[sample_seq]["l1"] = [] 177 | metrics[sample_seq]["l2"].append(loss_l2[i]) 178 | metrics[sample_seq]["l1"].append(loss_l1[i]) 179 | 180 | if log_dir: 181 | log_tensors_dict = dict() 182 | for key, value in log_field_dict.items(): 183 | if value in batch.keys(): 184 | log_tensors_dict[key] = batch[value] 185 | log_tensors(idx, log_tensors_dict, log_dir) 186 | 187 | for seqs in metrics.keys(): 188 | metrics[seqs]["l2_avg"] = np.array(metrics[seqs]["l2"]).mean() 189 | metrics[seqs]["l1_avg"] = np.array(metrics[seqs]["l1"]).mean() 190 | return_dict[seqs] = metrics[seqs] 191 | 192 | if log_dir: 193 | with open(os.path.join(log_dir, "metrics.pkl"), "wb") as f: 194 | pickle.dump(metrics, f) 195 | --------------------------------------------------------------------------------
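`EventSequenceToVoxelGrid_Pytorch.__call__` in `idn/utils/transformers.py` spreads each event's polarity over its two nearest temporal bins with linear weights. The NumPy sketch below mirrors that logic under stated assumptions: `events_to_voxel_grid` is a hypothetical name, events arrive as an `[N, 4]` array of `(t, x, y, p)` sorted by time with in-range pixel coordinates, `np.add.at` takes the place of `index_add_`, and the optional normalization uses the unbiased standard deviation to match the PyTorch `std` default.

```python
import numpy as np


def events_to_voxel_grid(events, num_bins, height, width, normalize=True):
    """Bilinear-in-time voxel grid (NumPy sketch of EventSequenceToVoxelGrid_Pytorch).

    Assumes events is an [N, 4] array of (t, x, y, p), sorted by t,
    with 0 <= x < width and 0 <= y < height.
    """
    assert events.ndim == 2 and events.shape[1] == 4 and num_bins > 0
    voxel = np.zeros(num_bins * height * width, dtype=np.float32)
    if len(events) == 0:
        return voxel.reshape(num_bins, height, width)

    t = events[:, 0].astype(np.float64)
    x = events[:, 1].astype(np.int64)
    y = events[:, 2].astype(np.int64)
    p = events[:, 3].astype(np.float32)  # copy, so the input is untouched
    p[p == 0] = -1  # polarity {0, 1} -> {-1, +1}

    # Rescale timestamps to [0, num_bins - 1].
    dt = t[-1] - t[0]
    if dt == 0:
        dt = 1.0
    ts = (num_bins - 1) * (t - t[0]) / dt
    ti = np.floor(ts).astype(np.int64)
    frac = (ts - ti).astype(np.float32)
    pixel = x + y * width

    # Left bin receives (1 - frac) of the polarity.
    left = (ti >= 0) & (ti < num_bins)
    np.add.at(voxel, pixel[left] + ti[left] * width * height,
              p[left] * (1.0 - frac[left]))
    # Right bin receives frac of the polarity, if that bin exists.
    right = (ti >= 0) & (ti + 1 < num_bins)
    np.add.at(voxel, pixel[right] + (ti[right] + 1) * width * height,
              p[right] * frac[right])
    voxel = voxel.reshape(num_bins, height, width)

    if normalize:
        # Standardize only the non-zero cells, as the PyTorch version does.
        nz = voxel != 0
        if nz.any():
            vals = voxel[nz]
            mean = vals.mean()
            std = vals.std(ddof=1) if vals.size > 1 else 0.0
            if std > 0:
                voxel[nz] = (vals - mean) / std
            else:
                voxel[nz] = vals - mean
    return voxel
```

Because timestamps are rescaled to `[0, num_bins - 1]`, only an event landing exactly on the last bin lacks a right-hand neighbour, and its fractional weight is zero there, so the summed polarity of the events is preserved before normalization.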