├── .gitattributes
├── .gitignore
├── README.md
├── data
│   └── .gitkeep
├── download_dsec_test.py
├── environment.yml
└── idn
    ├── check_submission.py
    ├── checkpoint
    │   ├── id-4x.pt
    │   ├── id-8x.pt
    │   └── tid.pt
    ├── config
    │   ├── data_loader_base.yaml
    │   ├── dataset
    │   │   ├── dsec.yaml
    │   │   ├── dsec_augmentation.yaml
    │   │   ├── dsec_base.yaml
    │   │   └── dsec_rec.yaml
    │   ├── hydra
    │   │   └── custom_hydra.yaml
    │   ├── id_eval.yaml
    │   ├── id_train.yaml
    │   ├── logger_base.yaml
    │   ├── model
    │   │   ├── id-4x.yaml
    │   │   ├── id-8x.yaml
    │   │   └── idedeqid.yaml
    │   ├── mvsec_train.yaml
    │   ├── tid_eval.yaml
    │   ├── tid_train.yaml
    │   ├── torch_environ_base.yaml
    │   └── validation
    │       ├── co.yaml
    │       ├── data_loader_val_base.yaml
    │       ├── dsec_co.yaml
    │       ├── dsec_test.yaml
    │       ├── mvsec_day1.yaml
    │       └── nonrec.yaml
    ├── eval.py
    ├── loader
    │   ├── loader_dsec.py
    │   └── loader_mvsec.py
    ├── model
    │   ├── extractor.py
    │   ├── idedeq.py
    │   ├── loss.py
    │   └── update.py
    ├── scripts
    │   └── format_mvsec
    │       ├── eval_utils.py
    │       ├── format_mvsec.py
    │       └── h5_packager.py
    ├── tests
    │   ├── dsec.py
    │   ├── eval.py
    │   └── test.py
    ├── train.py
    └── utils
        ├── callbacks.py
        ├── cb
        │   ├── logger.py
        │   └── validator.py
        ├── dsec_utils.py
        ├── exp_tracker.py
        ├── helper_functions.py
        ├── logger.py
        ├── loss_utils.py
        ├── model_utils.py
        ├── mvsec_utils.py
        ├── retrieval_fn.py
        ├── torch_environ.py
        ├── trainer.py
        ├── transformers.py
        └── validation.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | idn/checkpoint/*.pt filter=lfs diff=lfs merge=lfs -text
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | .vscode/
3 | data/*
4 | !data/.gitkeep
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Lightweight Event-based Optical Flow Estimation via Iterative Deblurring
2 |
3 | Work accepted to 2024 IEEE International Conference on Robotics and Automation (ICRA'24) [[paper](https://arxiv.org/abs/2211.13726), [video](https://www.youtube.com/watch?v=1qA1hONS4Sw)].
4 |
5 |
6 |
7 |
8 |
9 | 
10 |
11 | If you use this code in an academic context, please cite our work:
12 |
13 | ```bibtex
14 | @INPROCEEDINGS{10610353,
15 | author={Wu, Yilun and Paredes-Vallés, Federico and de Croon, Guido C. H. E.},
16 | booktitle={2024 IEEE International Conference on Robotics and Automation (ICRA)},
17 | title={Lightweight Event-based Optical Flow Estimation via Iterative Deblurring},
18 | year={2024},
19 | volume={},
20 | number={},
21 | pages={14708-14715},
22 | keywords={Image motion analysis;Correlation;Memory management;Estimation;Rendering (computer graphics);Real-time systems;Iterative algorithms},
23 | doi={10.1109/ICRA57147.2024.10610353}}
24 | ```
25 |
26 | ## Dependency
27 | Create a conda env and install dependencies by running
28 | ```
29 | conda env create --file environment.yml
30 | ```
31 |
32 | ## Download (For Evaluation)
33 | The DSEC dataset for optical flow can be downloaded [here](https://dsec.ifi.uzh.ch/dsec-datasets/download/).
34 | Use the script [download_dsec_test.py](download_dsec_test.py) for convenience.
35 | It downloads the dataset directly into `DATA_DIRECTORY` with the expected directory structure:
36 | ```
37 | python download_dsec_test.py DATA_DIRECTORY
38 | ```
39 | Once downloaded, create a symbolic link called `data` pointing to the data directory, so that the test sequences are available under `data/test`:
40 | ```
41 | ln -s DATA_DIRECTORY data
42 | ```
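43 |
44 | The loader then expects a layout like the following (a sketch based on the directory structure documented in `idn/loader/loader_dsec.py`):
45 | ```
46 | data
47 | └── test
48 |     ├── interlaken_00_b
49 |     │   ├── events_left
50 |     │   │   ├── events.h5
51 |     │   │   └── rectify_map.h5
52 |     │   ├── image_timestamps.txt
53 |     │   └── test_forward_flow_timestamps.csv
54 |     └── ...
55 | ```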
56 |
57 | ## Download (For Training)
58 | For training on DSEC, two more folders need to be downloaded:
59 |
60 | - Unzip [train_events.zip](https://download.ifi.uzh.ch/rpg/DSEC/train_coarse/train_events.zip) to data/train_events
61 |
62 | - Unzip [train_optical_flow.zip](https://download.ifi.uzh.ch/rpg/DSEC/train_coarse/train_optical_flow.zip) to data/train_optical_flow
63 |
64 | or create symbolic links to these folders under data/.
65 |
66 | ## Download (MVSEC)
67 | To run experiments on MVSEC, additionally download the outdoor day sequence .h5 files from https://drive.google.com/open?id=1rwyRk26wtWeRgrAx_fgPc-ubUzTFThkV
68 | and place the files under data/, or create symbolic links to them under data/.
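69 |
70 | For example, the outdoor day 1 sequence should end up as (filenames following the pattern read by `idn/loader/loader_mvsec.py`; the exact sequence name is assumed here):
71 | ```
72 | data/outdoor_day1_data.hdf5
73 | data/outdoor_day1_gt.hdf5
74 | ```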
75 |
76 | ## Run Evaluation
77 |
78 | To run eval:
79 | ```
80 | cd idnet
81 | conda activate IDNet
82 | python -m idn.eval
83 | ```
84 |
85 | Change the save directory for eval results in `idn/config/validation/dsec_test.yaml` if you prefer; the default is `/tmp/collect/XX`.
86 |
87 | To switch between the ID models at 1/4 and 1/8 resolution, change the model option in `idn/config/id_eval.yaml`.
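88 |
89 | For example, to select the 1/4-resolution model, change the `model` entry under `defaults`:
90 | ```
91 | defaults:
92 |   - model: id-4x # id-8x
93 | ```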
94 |
95 | To evaluate the TID model, change the function decorator above the main function in `eval.py`.
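96 |
97 | Concretely, swap which `hydra.main` decorator is active in `idn/eval.py`:
98 | ```
99 | # @hydra.main(config_path="config", config_name="id_eval")
100 | @hydra.main(config_path="config", config_name="tid_eval")
101 | def main(config):
102 | ```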
103 |
104 | At the end of evaluation, a zip file containing the results will be created in the save directory, which you can upload to the DSEC benchmark website to reproduce our results.
105 |
106 | ## Run Training
107 | To train IDNet, run:
108 | ```
109 | cd idnet
110 | conda activate IDNet
111 | python -m idn.train
112 | ```
113 |
114 | Similarly, switch between the id-4x, id-8x and tid models, or MVSEC training, by changing the hydra.main() decorator in `train.py` and the settings in the corresponding .yaml file.
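115 |
116 | For example, to train the TID model (a sketch, assuming `train.py` uses the same decorator pattern as `eval.py`):
117 | ```
118 | @hydra.main(config_path="config", config_name="tid_train")
119 | def main(config):
120 | ```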
121 |
122 |
--------------------------------------------------------------------------------
/data/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tudelft/idnet/de142df364cbbbc8b81ed72a4a94cca9fd8adab6/data/.gitkeep
--------------------------------------------------------------------------------
/download_dsec_test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | import os
4 | import urllib
5 | import shutil
6 | from typing import Union
7 |
8 | from requests import get
9 |
10 | TEST_SEQUENCES = ['interlaken_00_b', 'interlaken_01_a', 'thun_01_a',
11 | 'thun_01_b', 'zurich_city_12_a', 'zurich_city_14_c', 'zurich_city_15_a']
12 | BASE_TEST_URL = 'https://download.ifi.uzh.ch/rpg/DSEC/test/'
13 | TEST_FLOW_TIMESTAMPS_URL = 'https://download.ifi.uzh.ch/rpg/DSEC/test_forward_optical_flow_timestamps.zip'
14 |
15 |
16 | def download(url: str, filepath: Path, skip: bool = True) -> bool:
17 | if skip and filepath.exists():
18 | print(f'{str(filepath)} already exists. Skipping download.')
19 | return True
20 | with open(str(filepath), 'wb') as fl:
21 | response = get(url)
22 | fl.write(response.content)
23 | return response.ok
24 |
25 |
26 | def unzip(file_: Path, delete_zip: bool = True, skip: bool = True) -> Path:
27 | assert file_.exists()
28 | assert file_.suffix == '.zip'
29 | output_dir = file_.parent / file_.stem
30 | if skip and output_dir.exists():
31 | print(f'{str(output_dir)} already exists. Skipping unzipping operation.')
32 | else:
33 | shutil.unpack_archive(file_, output_dir)
34 | if delete_zip:
35 | os.remove(file_)
36 | return output_dir
37 |
38 |
39 | if __name__ == '__main__':
40 | parser = argparse.ArgumentParser()
41 | parser.add_argument('output_directory')
42 |
43 | args = parser.parse_args()
44 |
45 | output_dir = Path(args.output_directory)
46 | output_dir = output_dir / 'test'
47 | os.makedirs(output_dir, exist_ok=True)
48 |
49 | test_timestamps_file = output_dir / 'test_forward_flow_timestamps.zip'
50 |
51 | assert download(TEST_FLOW_TIMESTAMPS_URL,
52 | test_timestamps_file), TEST_FLOW_TIMESTAMPS_URL
53 | test_timestamps_dir = unzip(test_timestamps_file)
54 |
55 | for seq_name in TEST_SEQUENCES:
56 | seq_path = output_dir / seq_name
57 | os.makedirs(seq_path, exist_ok=True)
58 |
59 | # image timestamps
60 | img_timestamps_url = BASE_TEST_URL + seq_name + \
61 | '/' + seq_name + '_image_timestamps.txt'
62 | img_timestamps_file = seq_path / 'image_timestamps.txt'
63 | if not img_timestamps_file.exists():
64 | assert download(img_timestamps_url,
65 | img_timestamps_file), img_timestamps_url
66 |
67 | # test timestamps
68 | test_timestamps_file_destination = seq_path / 'test_forward_flow_timestamps.csv'
69 | if not test_timestamps_file_destination.exists():
70 | shutil.move(test_timestamps_dir / (seq_name + '.csv'),
71 | test_timestamps_file_destination)
72 |
73 | # event data
74 | events_left_url = BASE_TEST_URL + seq_name + '/' + seq_name + '_events_left.zip'
75 | events_left_file = seq_path / 'events_left.zip'
76 | if not (events_left_file.parent / events_left_file.stem).exists():
77 | assert download(events_left_url, events_left_file), events_left_url
78 | unzip(events_left_file)
79 |
80 | shutil.rmtree(test_timestamps_dir)
81 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: IDNet
2 | channels:
3 | - pytorch
4 | - nvidia
5 | dependencies:
6 | - python=3.8
7 | - pytorch=1.13.1
8 | - pytorch-cuda=11.7
9 | - torchvision
10 | - scikit-image
11 | - numba
12 | - pandas
13 | - termcolor
14 | - h5py
15 | - tqdm
16 | - matplotlib
17 | - imageio
18 | - ipykernel
19 | - pip
20 | - pip:
21 | - torchinfo
22 | - hydra-core
23 | - wandb
24 | - gitignore-parser
25 | - opencv-python-headless
26 | - hdf5plugin
27 | - nbformat
28 | - nbstripout
29 | - rerun-sdk
30 |
31 |
--------------------------------------------------------------------------------
/idn/check_submission.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import imageio
3 | from typing import Dict
4 | from pathlib import Path
5 | import os
6 | from enum import Enum, auto
7 | import argparse
8 | import sys
9 | version = sys.version_info
10 | assert version[0] >= 3, 'Python 2 is not supported'
11 | assert version[1] >= 6, 'Requires Python 3.6 or higher'
12 |
13 | os.environ['IMAGEIO_USERDIR'] = '/var/tmp'
14 |
15 | imageio.plugins.freeimage.download()
16 |
17 |
18 | class WriteFormat(Enum):
19 | OPENCV = auto()
20 | IMAGEIO = auto()
21 |
22 |
23 | def is_string_swiss(input_str: str) -> bool:
24 | is_swiss = False
25 | is_swiss |= 'thun_' in input_str
26 | is_swiss |= 'interlaken_' in input_str
27 | is_swiss |= 'zurich_city_' in input_str
28 | return is_swiss
29 |
30 |
31 | def flow_16bit_to_float(flow_16bit: np.ndarray, valid_in_3rd_channel: bool):
32 | assert flow_16bit.dtype == np.uint16
33 | assert flow_16bit.ndim == 3
34 | h, w, c = flow_16bit.shape
35 | assert c == 3
36 |
37 | if valid_in_3rd_channel:
38 | valid2D = flow_16bit[..., 2] == 1
39 | assert valid2D.shape == (h, w)
40 | assert np.all(flow_16bit[~valid2D, -1] == 0)
41 | else:
42 |         valid2D = np.ones_like(flow_16bit[..., 2], dtype=bool)
43 | valid_map = np.where(valid2D)
44 |
45 | # to actually compute something useful:
46 | flow_16bit = flow_16bit.astype('float')
47 |
48 | flow_map = np.zeros((h, w, 2))
49 | flow_map[valid_map[0], valid_map[1], 0] = (
50 | flow_16bit[valid_map[0], valid_map[1], 0] - 2**15) / 128
51 | flow_map[valid_map[0], valid_map[1], 1] = (
52 | flow_16bit[valid_map[0], valid_map[1], 1] - 2**15) / 128
53 | return flow_map, valid2D
54 |
55 |
56 | def load_flow(flowfile: Path, valid_in_3rd_channel: bool, write_format: WriteFormat = WriteFormat.IMAGEIO):
57 | assert flowfile.exists()
58 | assert flowfile.suffix == '.png'
59 |
60 | # imageio reading assumes write format was rgb
61 | flow_16bit = imageio.imread(str(flowfile), format='PNG-FI')
62 | if write_format == WriteFormat.OPENCV:
63 | # opencv writes as bgr -> flip last axis to get rgb
64 | flow_16bit = np.flip(flow_16bit, axis=-1)
65 | else:
66 | assert write_format == WriteFormat.IMAGEIO
67 |
68 | channel3 = flow_16bit[..., -1]
69 | assert channel3.max(
70 | ) <= 1, f'Maximum value in last channel should be 1: {flowfile}'
71 | flow, valid2D = flow_16bit_to_float(flow_16bit, valid_in_3rd_channel)
72 | return flow, valid2D
73 |
74 |
75 | def list_of_dirs(dirpath: Path):
76 | return next(os.walk(dirpath))[1]
77 |
78 |
79 | def files_per_sequence(flow_timestamps_dir: Path) -> Dict[str, int]:
80 | out_dict = dict()
81 | for entry in flow_timestamps_dir.iterdir():
82 | assert entry.is_file()
83 | assert entry.suffix == '.csv', entry.suffix
84 | assert is_string_swiss(entry.stem), entry.stem
85 | data = np.loadtxt(entry, dtype=np.int64, delimiter=', ', comments='#')
86 | assert data.ndim == 2, data.ndim
87 | num_files = data.shape[0]
88 | out_dict[entry.stem] = num_files
89 | return out_dict
90 |
91 |
92 | def check_submission(submission_dir: Path, flow_timestamps_dir: Path):
93 | assert flow_timestamps_dir.is_dir()
94 | assert submission_dir.is_dir()
95 |
96 | name2num = files_per_sequence(flow_timestamps_dir)
97 |
98 | expected_flow_shape = (480, 640, 2)
99 | expected_valid_shape = (480, 640)
100 | expected_dir_names = set([*name2num])
101 | actual_dir_names = set(list_of_dirs(submission_dir))
102 | assert expected_dir_names == actual_dir_names, f'Expected directories in your submission: {expected_dir_names}.\nMissing directories: {expected_dir_names.difference(actual_dir_names)}'
103 |
104 | for seq in submission_dir.iterdir():
105 | if not seq.is_dir():
106 | continue
107 | assert seq.is_dir()
108 | assert is_string_swiss(seq.name), seq.name
109 | num_files = 0
110 | for prediction in seq.iterdir():
111 | flow, valid_map = load_flow(
112 | prediction, valid_in_3rd_channel=False, write_format=WriteFormat.IMAGEIO)
113 | assert flow.shape == expected_flow_shape, f'Expected shape: {expected_flow_shape}, actual shape: {flow.shape}'
114 | assert valid_map.shape == expected_valid_shape, f'Expected shape: {expected_valid_shape}, actual shape: {valid_map.shape}'
115 | num_files += 1
116 | assert seq.name in [*name2num], f'{seq.name} not in {[*name2num]}'
117 | assert num_files == name2num[
118 | seq.name], f'expected {name2num[seq.name]} files in {str(seq)} but only found {num_files} files'
119 |
120 | return True
121 |
122 |
123 | if __name__ == '__main__':
124 | parser = argparse.ArgumentParser()
125 | parser.add_argument('submission_dir', help='Path to submission directory')
126 | parser.add_argument('flow_timestamps_dir',
127 | help='Path to directory containing the flow timestamps for evaluation.')
128 |
129 | args = parser.parse_args()
130 |
131 | print('start checking submission')
132 | check_submission(Path(args.submission_dir), Path(args.flow_timestamps_dir))
133 | print('Your submission directory has the correct structure: Ready to submit!\n')
134 |     print('Note that we will sort the files according to their names in each directory and evaluate them sequentially. Follow the exact naming instructions if you are unsure.')
135 |
--------------------------------------------------------------------------------
/idn/checkpoint/id-4x.pt:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:e28541f311cc4f52282a1d5321194f2baf1ce91ac5a4aaf4be9e28d525411ff5
3 | size 10219203
4 |
--------------------------------------------------------------------------------
/idn/checkpoint/id-8x.pt:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:044d107030c4d2f7fdaa3d41bf48d84960384bb6864af9e5639e6577651f2095
3 | size 5731319
4 |
--------------------------------------------------------------------------------
/idn/checkpoint/tid.pt:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:f1134a4c61c8556a8f53981243895fffe819c64e3ca4c64fdc41943c0ffa617f
3 | size 7551263
4 |
--------------------------------------------------------------------------------
/idn/config/data_loader_base.yaml:
--------------------------------------------------------------------------------
1 | data_loader:
2 | train:
3 | gpu: ???
4 | args:
5 | batch_size: ???
6 | shuffle: true
7 | num_workers: 2
8 | pin_memory: false
9 | prefetch_factor: 2
10 | persistent_workers: false
11 | val:
12 | gpu: ???
13 | mp: false
14 | batch_freq: 1500
15 | args:
16 | batch_size: 1
17 | shuffle: false
18 | num_workers: 1
19 | pin_memory: true
20 | prefetch_factor: 2
--------------------------------------------------------------------------------
/idn/config/dataset/dsec.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dsec_base
3 |
4 | num_voxel_bins: 15
5 |
6 | train:
7 | use_all_seqs: true
8 | seq: []
9 | load_gt: true
10 | downsample_ratio: 1
11 | val:
12 | seq: [zurich_city_01_a]
13 | downsample_ratio: 1
14 | concat_seq: false
15 |
--------------------------------------------------------------------------------
/idn/config/dataset/dsec_augmentation.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dsec
3 |
4 | train:
5 | vertical_flip: 0.1
6 | horizontal_flip: 0.5
7 | random_crop: [288, 384]
8 |
9 |
--------------------------------------------------------------------------------
/idn/config/dataset/dsec_base.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: dsec
2 | common:
3 | data_root: data/
4 | test_root: data/test/
--------------------------------------------------------------------------------
/idn/config/dataset/dsec_rec.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dsec
3 |
4 | train:
5 | recurrent: true
6 | sequence_length: 4
7 | vertical_flip: 0.1
8 | horizontal_flip: 0.5
9 | random_crop: [384, 512]
10 |
11 | val:
12 | recurrent: true
13 | sequence_length: 4
14 | concat_seq: false
--------------------------------------------------------------------------------
/idn/config/hydra/custom_hydra.yaml:
--------------------------------------------------------------------------------
1 | run:
2 | dir: .
3 | output_subdir: null
4 |
5 | defaults:
6 | - override job_logging: none
7 | - override hydra_logging: none
8 | - _self_
--------------------------------------------------------------------------------
/idn/config/id_eval.yaml:
--------------------------------------------------------------------------------
1 | eval_only: True
2 | callbacks:
3 |
4 | defaults:
5 | - hydra: custom_hydra
6 | - data_loader_base
7 | - model: id-8x # id-4x
8 | - dataset: dsec
9 | - torch_environ_base
10 | - _self_
11 |
12 |
--------------------------------------------------------------------------------
/idn/config/id_train.yaml:
--------------------------------------------------------------------------------
1 | deterministic: false
2 | track: false
3 | num_epoch: 100
4 | data_loader:
5 | common:
6 | num_voxel_bins: 15
7 | train:
8 | gpu: 0
9 | args:
10 | batch_size: 3
11 | val:
12 | gpu: 0
13 | batch_freq: 1500
14 | loss:
15 | final_prediction_nonseq:
16 | loss_type: sparse_l1
17 | weight: 1.0
18 | seq_weight: avg
19 | seq_norm: false
20 |
21 |
22 | dataset:
23 | representation_type: voxel
24 |
25 |
26 | optim:
27 | optimizer: adam
28 | scheduler: onecycle
29 | lr: 1e-4
30 |
31 | callbacks:
32 | logger:
33 | enable:
34 | log_keys:
35 | batch_end:
36 | - loss
37 | - lr
38 |
39 |
40 | validator:
41 | enable:
42 | frequency_type: step
43 | frequency: 500
44 | sanity_run_step: 3
45 |
46 |
47 | validation:
48 | nonrec:
49 | dataset:
50 | representation_type: ${dataset.representation_type}
51 |
52 | model:
53 | pretrain_ckpt: null
54 |
55 | defaults:
56 | - validation@_group_.nonrec: nonrec
57 | - hydra: custom_hydra
58 | - data_loader_base
59 | - torch_environ_base
60 | - model: id-8x # id-4x
61 | - dataset: dsec_augmentation
62 | - _self_
63 |
--------------------------------------------------------------------------------
/idn/config/logger_base.yaml:
--------------------------------------------------------------------------------
1 | logger:
2 | saved_tensors:
3 |
4 | statistics: ['avg']
--------------------------------------------------------------------------------
/idn/config/model/id-4x.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - idedeqid
3 |
4 | hidden_dim: 128
5 | downsample: 4
6 | pretrain_ckpt: idn/checkpoint/id-4x.pt
--------------------------------------------------------------------------------
/idn/config/model/id-8x.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - idedeqid
3 |
4 | hidden_dim: 96
5 | downsample: 8
6 | pretrain_ckpt: idn/checkpoint/id-8x.pt
--------------------------------------------------------------------------------
/idn/config/model/idedeqid.yaml:
--------------------------------------------------------------------------------
1 | name: IDEDEQIDO
2 | deblur: true
3 | add_delta: true
4 | update_iters: 4
5 | zero_init: true
6 | deq_mode: false
7 | input_flowmap: true
8 |
--------------------------------------------------------------------------------
/idn/config/mvsec_train.yaml:
--------------------------------------------------------------------------------
1 | deterministic: false
2 | track: false
3 | num_epoch: 150
4 | run_val: false
5 |
6 | data_loader:
7 | common:
8 | num_voxel_bins: 15
9 | train:
10 | gpu: 0
11 | args:
12 | batch_size: 3
13 | pin_memory: false
14 | val:
15 | gpu: 0
16 | batch_freq: 1500
17 | loss:
18 | final_prediction_nonseq:
19 | loss_type: sparse_l1
20 | weight: 1.0
21 | seq_weight: avg
22 | seq_norm: false
23 |
24 | dataset:
25 | dataset_name: mvsec
26 | num_voxel_bins: 15
27 | train:
28 |
29 |
30 | optim:
31 | optimizer: adam
32 | scheduler: onecycle
33 | lr: 1e-4
34 |
35 | callbacks:
36 | logger:
37 | enable:
38 | log_keys:
39 | batch_end:
40 | - loss
41 |
42 |
43 | validator:
44 | enable:
45 | frequency_type: step
46 | frequency: 500
47 | sanity_run_step: 3
48 |
49 |
50 | model:
51 | pretrain_ckpt: null
52 |
53 | defaults:
54 | - validation@_group_.mvsec_day1: mvsec_day1
55 | - hydra: custom_hydra
56 | - data_loader_base
57 | - torch_environ_base
58 | - model: id-8x
59 | - _self_
60 |
--------------------------------------------------------------------------------
/idn/config/tid_eval.yaml:
--------------------------------------------------------------------------------
1 | eval_only: True
2 | callbacks:
3 |
4 | defaults:
5 | - validation@_group_.co: co
6 | - hydra: custom_hydra
7 | - data_loader_base
8 | - torch_environ_base
9 | - model: idedeqid
10 | - dataset: dsec_rec
11 | - _self_
12 |
13 | model:
14 | name: RecIDE
15 | update_iters: 1
16 | pred_next_flow: true
17 | pretrain_ckpt: idn/checkpoint/tid.pt
--------------------------------------------------------------------------------
/idn/config/tid_train.yaml:
--------------------------------------------------------------------------------
1 | deterministic: false
2 | num_epoch: 100
3 | run_val: false
4 | data_loader:
5 | common:
6 | num_voxel_bins: 15
7 | train:
8 | gpu: 0
9 | args:
10 | batch_size: 3
11 | val:
12 | gpu: 0
13 | batch_freq: 1500
14 | loss:
15 | pred_flow_seq:
16 | loss_type: sparse_l1
17 | weight: 1.0
18 | seq_weight: [0.17, 0.21, 0.27, 0.35]
19 | seq_norm: false
20 | pred_flow_next_seq:
21 | loss_type: sparse_l1
22 | weight: 1.0
23 | seq_weight: [0.17, 0.21, 0.27, 0.35]
24 | seq_norm: false
25 |
26 | dataset:
27 | train:
28 | sequence_length: 4
29 |
30 | optim:
31 | optimizer: adam
32 | scheduler: onecycle
33 | lr: 1e-4
34 |
35 | callbacks:
36 | logger:
37 | enable:
38 | log_keys:
39 | batch_end:
40 | - loss
41 | - loss_pred_flow_seq
42 | - loss_pred_flow_next_seq
43 |
44 |
45 | validator:
46 | enable:
47 | frequency_type: step
48 | frequency: 1000
49 | sanity_run_step: 3
50 |
51 |
52 |
53 |
54 | defaults:
55 | - validation@_group_.co: co
56 | - hydra: custom_hydra
57 | - data_loader_base
58 | - torch_environ_base
59 | - model: idedeqid
60 | - dataset: dsec_rec
61 | - _self_
62 |
63 | model:
64 | name: RecIDE
65 | update_iters: 1
66 | pred_next_flow: true
67 |
--------------------------------------------------------------------------------
/idn/config/torch_environ_base.yaml:
--------------------------------------------------------------------------------
1 | torch:
2 | debug_grad: false
3 | deterministic: false
--------------------------------------------------------------------------------
/idn/config/validation/co.yaml:
--------------------------------------------------------------------------------
1 | name: continuous-operation
2 | val_batch_freq: 1500
3 | data_loader:
4 | gpu: 0
5 |
6 | dataset:
7 | val:
8 | sequence_length: 1
9 |
10 | logger:
11 | saved_tensors:
12 | final_prediction:
13 | flow_gt_event_volume_new:
14 |
15 | metrics:
16 | final_prediction: ["L1", "L2"]
17 | next_flow: ["L1", "L2"]
18 |
19 | postprocess:
20 |
21 | defaults:
22 | - /validation/data_loader_val_base@_here_
23 | - /dataset/dsec_rec@dataset
24 | - /logger_base@_here_
--------------------------------------------------------------------------------
/idn/config/validation/data_loader_val_base.yaml:
--------------------------------------------------------------------------------
1 | data_loader:
2 | gpu: ???
3 | mp: false
4 | batch_freq: 1500
5 | args:
6 | batch_size: 1
7 | shuffle: false
8 | num_workers: 1
9 | pin_memory: true
10 | prefetch_factor: 2
--------------------------------------------------------------------------------
/idn/config/validation/dsec_co.yaml:
--------------------------------------------------------------------------------
1 | dataset:
2 | val:
3 | sequence_length: 1
4 | recurrent: true
5 |
6 | defaults:
7 | - dsec_test
8 | - _self_
9 |
--------------------------------------------------------------------------------
/idn/config/validation/dsec_test.yaml:
--------------------------------------------------------------------------------
1 | name: dsec-test
2 | data_loader:
3 | gpu: 0
4 | mp: false
5 | batch_freq: 1500
6 | args:
7 | batch_size: 1
8 | shuffle: false
9 | num_workers: 1
10 | pin_memory: true
11 | prefetch_factor: 2
12 |
13 |
14 | logger:
15 | save_dir: /tmp/collect/XX
16 | saved_tensors:
17 |
18 | postprocess:
19 |
20 | metrics:
21 |
22 | hydra:
23 | output_subdir: null
24 |
25 | defaults:
26 | - /dataset/dsec@dataset
27 | - /logger_base@_here_
28 | - _self_
29 |
--------------------------------------------------------------------------------
/idn/config/validation/mvsec_day1.yaml:
--------------------------------------------------------------------------------
1 | name: mvsec_day1
2 | val_batch_freq: 1500
3 | data_loader:
4 | gpu: 0
5 | args:
6 | batch_size: 1
7 | pin_memory: false
8 | shuffle: false
9 | num_workers: 0
10 | prefetch_factor: 2
11 |
12 |
13 | logger:
14 | saved_tensors:
15 | final_prediction:
16 | flow_gt_event_volume_new:
17 |
18 | dataset:
19 | val:
20 | recurrent: false
21 |
22 |
23 | postprocess:
24 |
25 | metrics:
26 | final_prediction_nonseq: ["L1", "L2", "1PE", "3PE"]
27 |
28 | hydra:
29 | output_subdir: null
30 |
31 | defaults:
32 | - /validation/data_loader_val_base@_here_
33 | - /logger_base@_here_
34 |
--------------------------------------------------------------------------------
/idn/config/validation/nonrec.yaml:
--------------------------------------------------------------------------------
1 | name: non-rec
2 | val_batch_freq: 1500
3 | data_loader:
4 | gpu: 0
5 |
6 |
7 | logger:
8 | saved_tensors:
9 | final_prediction:
10 | flow_gt_event_volume_new:
11 |
12 | postprocess:
13 |
14 | metrics:
15 | final_prediction_nonseq: ["L1", "L2"]
16 |
17 | hydra:
18 | output_subdir: null
19 |
20 | defaults:
21 | - /validation/data_loader_val_base@_here_
22 | - /dataset/dsec@dataset
23 | - /logger_base@_here_
24 |
--------------------------------------------------------------------------------
/idn/eval.py:
--------------------------------------------------------------------------------
1 | from omegaconf import OmegaConf
2 | import hydra
3 | from hydra import initialize, compose
4 | from .utils.trainer import Trainer
5 | from .utils.validation import Validator
6 |
7 |
8 |
9 | def test(trainer):
10 | test_cfg = compose(config_name="validation/dsec_test",
11 | overrides=[]).validation
12 | Validator.get_test_type("dsec")(test_cfg).execute_test(
13 | trainer.model, save_all=False)
14 |
15 |
16 | def test_co(trainer):
17 | test_cfg = compose(config_name="validation/dsec_co",
18 | overrides=[]).validation
19 | Validator.get_test_type("dsec", "co")(
20 | test_cfg).execute_test(trainer.model, save_all=False)
21 |
22 |
23 |
24 |
25 | # To evaluate the TID model, swap which decorator is active:
26 | # @hydra.main(config_path="config", config_name="tid_eval")
27 | @hydra.main(config_path="config", config_name="id_eval")
28 | def main(config):
29 | print(OmegaConf.to_yaml(config))
30 |
31 | trainer = Trainer(config)
32 |
33 | print("Number of parameters: ", sum(p.numel()
34 | for p in trainer.model.parameters() if p.requires_grad))
35 |
36 |
37 | if config.model.name == "RecIDE":
38 | test_co(trainer)
39 | elif config.model.name == "IDEDEQIDO":
40 | test(trainer)
41 |
42 |
43 |
44 | if __name__ == '__main__':
45 | main()
46 |
--------------------------------------------------------------------------------
/idn/loader/loader_dsec.py:
--------------------------------------------------------------------------------
1 | import math
2 | from pathlib import Path, PurePath
3 | from random import sample
4 | import random
5 | from typing import Dict, Tuple
6 | import weakref
7 | from time import time
8 | import cv2
9 | # import h5pickle as h5py
10 | import h5py
11 | from numba import jit
12 | import numpy as np
13 | import os
14 | import imageio
15 | import hashlib
16 | import mkl
17 | import torch
18 | from torchvision.transforms import ToTensor, RandomCrop
19 | from torchvision import transforms as tf
20 | from torch.utils.data import Dataset, DataLoader
21 | from matplotlib import pyplot as plt, transforms
22 | from ..utils import transformers
23 |
24 |
25 | from ..utils.dsec_utils import RepresentationType, VoxelGrid, PolarityCount, flow_16bit_to_float
26 | from ..utils.transformers import (
27 | downsample_spatial,
28 | downsample_spatial_mask,
29 | apply_transform_to_field,
30 | apply_randomcrop_to_sample)
31 |
32 | VISU_INDEX = 1
33 |
34 |
35 | class EventSlicer:
36 | def __init__(self, h5f: h5py.File):
37 | self.h5f = h5f
38 |
39 | self.events = dict()
40 | for dset_str in ['p', 'x', 'y', 't']:
41 | self.events[dset_str] = self.h5f['events/{}'.format(dset_str)]
42 |
43 | # This is the mapping from milliseconds to event index:
44 | # It is defined such that
45 | # (1) t[ms_to_idx[ms]] >= ms*1000
46 | # (2) t[ms_to_idx[ms] - 1] < ms*1000
47 | # ,where 'ms' is the time in milliseconds and 't' the event timestamps in microseconds.
48 | #
49 | # As an example, given 't' and 'ms':
50 | # t: 0 500 2100 5000 5000 7100 7200 7200 8100 9000
51 | # ms: 0 1 2 3 4 5 6 7 8 9
52 | #
53 | # we get
54 | #
55 | # ms_to_idx:
56 | # 0 2 2 3 3 3 5 5 8 9
57 | self.ms_to_idx = np.asarray(self.h5f['ms_to_idx'], dtype='int64')
58 |
59 | self.t_offset = int(h5f['t_offset'][()])
60 | self.t_final = int(self.events['t'][-1]) + self.t_offset
61 |
62 | def get_final_time_us(self):
63 | return self.t_final
64 |
65 | def get_events(self, t_start_us: int, t_end_us: int) -> Dict[str, np.ndarray]:
66 | """Get events (p, x, y, t) within the specified time window
67 | Parameters
68 | ----------
69 | t_start_us: start time in microseconds
70 | t_end_us: end time in microseconds
71 | Returns
72 | -------
73 | events: dictionary of (p, x, y, t) or None if the time window cannot be retrieved
74 | """
75 | assert t_start_us < t_end_us
76 |
77 |         # We assume that the times are time-of-day, hence subtract the offset:
78 | t_start_us -= self.t_offset
79 | t_end_us -= self.t_offset
80 |
81 | t_start_ms, t_end_ms = self.get_conservative_window_ms(
82 | t_start_us, t_end_us)
83 | t_start_ms_idx = self.ms2idx(t_start_ms)
84 | t_end_ms_idx = self.ms2idx(t_end_ms)
85 |
86 | if t_start_ms_idx is None or t_end_ms_idx is None:
87 | # Cannot guarantee window size anymore
88 | return None
89 |
90 | events = dict()
91 | time_array_conservative = np.asarray(
92 | self.events['t'][t_start_ms_idx:t_end_ms_idx])
93 | idx_start_offset, idx_end_offset = self.get_time_indices_offsets(
94 | time_array_conservative, t_start_us, t_end_us)
95 | t_start_us_idx = t_start_ms_idx + idx_start_offset
96 | t_end_us_idx = t_start_ms_idx + idx_end_offset
97 | # Again add t_offset to get gps time
98 | events['t'] = time_array_conservative[idx_start_offset:idx_end_offset] + self.t_offset
99 | for dset_str in ['p', 'x', 'y']:
100 | events[dset_str] = np.asarray(
101 | self.events[dset_str][t_start_us_idx:t_end_us_idx])
102 | assert events[dset_str].size == events['t'].size
103 | return events
104 |
105 | @staticmethod
106 | def get_conservative_window_ms(ts_start_us: int, ts_end_us) -> Tuple[int, int]:
107 | """Compute a conservative time window of time with millisecond resolution.
108 | We have a time to index mapping for each millisecond. Hence, we need
109 | to compute the lower and upper millisecond to retrieve events.
110 | Parameters
111 | ----------
112 | ts_start_us: start time in microseconds
113 | ts_end_us: end time in microseconds
114 | Returns
115 | -------
116 | window_start_ms: conservative start time in milliseconds
117 | window_end_ms: conservative end time in milliseconds
118 | """
119 | assert ts_end_us > ts_start_us
120 | window_start_ms = math.floor(ts_start_us/1000)
121 | window_end_ms = math.ceil(ts_end_us/1000)
122 | return window_start_ms, window_end_ms
123 |
124 | @staticmethod
125 | @jit(nopython=True)
126 | def get_time_indices_offsets(
127 | time_array: np.ndarray,
128 | time_start_us: int,
129 | time_end_us: int) -> Tuple[int, int]:
130 | """Compute index offset of start and end timestamps in microseconds
131 | Parameters
132 | ----------
133 | time_array: timestamps (in us) of the events
134 | time_start_us: start timestamp (in us)
135 | time_end_us: end timestamp (in us)
136 | Returns
137 | -------
138 | idx_start: Index within this array corresponding to time_start_us
139 | idx_end: Index within this array corresponding to time_end_us
140 | such that (in non-edge cases)
141 | time_array[idx_start] >= time_start_us
142 | time_array[idx_end] >= time_end_us
143 | time_array[idx_start - 1] < time_start_us
144 | time_array[idx_end - 1] < time_end_us
145 | this means that
146 | time_start_us <= time_array[idx_start:idx_end] < time_end_us
147 | """
148 |
149 | assert time_array.ndim == 1
150 |
151 | idx_start = -1
152 | if time_array[-1] < time_start_us:
153 | # This can happen in extreme corner cases. E.g.
154 | # time_array[0] = 1016
155 | # time_array[-1] = 1984
156 | # time_start_us = 1990
157 | # time_end_us = 2000
158 |
159 | # Return same index twice: array[x:x] is empty.
160 | return time_array.size, time_array.size
161 | else:
162 | for idx_from_start in range(0, time_array.size, 1):
163 | if time_array[idx_from_start] >= time_start_us:
164 | idx_start = idx_from_start
165 | break
166 | assert idx_start >= 0
167 |
168 | idx_end = time_array.size
169 | for idx_from_end in range(time_array.size - 1, -1, -1):
170 | if time_array[idx_from_end] >= time_end_us:
171 | idx_end = idx_from_end
172 | else:
173 | break
174 |
175 | assert time_array[idx_start] >= time_start_us
176 | if idx_end < time_array.size:
177 | assert time_array[idx_end] >= time_end_us
178 | if idx_start > 0:
179 | assert time_array[idx_start - 1] < time_start_us
180 | if idx_end > 0:
181 | assert time_array[idx_end - 1] < time_end_us
182 | return idx_start, idx_end
183 |
184 | def ms2idx(self, time_ms: int) -> int:
185 | assert time_ms >= 0
186 | if time_ms >= self.ms_to_idx.size:
187 | return None
188 | return self.ms_to_idx[time_ms]
189 |
190 |
191 | class Sequence(Dataset):
192 | def __init__(self, seq_path: Path, representation_type: RepresentationType, mode: str = 'test', delta_t_ms: int = 100,
193 |                  num_bins: int = 15, transforms={}, name_idx=0, visualize=False, load_gt=False):
194 | assert num_bins >= 1
195 | assert delta_t_ms == 100
196 | assert seq_path.is_dir()
197 | assert mode in {'train', 'test'}
198 | '''
199 | Directory Structure:
200 |
201 | Dataset
202 | └── test
203 | ├── interlaken_00_b
204 | │ ├── events_left
205 | │ │ ├── events.h5
206 | │ │ └── rectify_map.h5
207 | │ ├── image_timestamps.txt
208 | │ └── test_forward_flow_timestamps.csv
209 |
210 | '''
211 | self.seq_name = PurePath(seq_path).name
212 | self.mode = mode
213 | self.name_idx = name_idx
214 | self.visualize_samples = visualize
215 | self.load_gt = load_gt
216 | self.transforms = transforms
217 |         if self.mode == "test":
218 | # Get Test Timestamp File
219 | ev_dir_location = seq_path / 'events_left'
220 | timestamp_file = seq_path / 'test_forward_flow_timestamps.csv'
221 | flow_path = seq_path
222 | timestamps_images = np.loadtxt(
223 | flow_path / 'image_timestamps.txt', dtype='int64')
224 | self.indices = np.arange(len(timestamps_images))[::2][1:-1]
225 | self.timestamps_flow = timestamps_images[::2][1:-1]
226 |
227 |         elif self.mode == "train":
228 | ev_dir_location = seq_path / 'events' / 'left'
229 | seq_name = seq_path.parts[-1]
230 | flow_path = seq_path.parents[1] / \
231 | "train_optical_flow"/seq_name/'flow'
232 | timestamp_file = flow_path/'forward_timestamps.txt'
233 | self.flow_png = [Path(os.path.join(flow_path / 'forward', img)) for img in sorted(
234 | os.listdir(flow_path / 'forward'))]
235 | timestamps_images = np.loadtxt(
236 | flow_path / 'forward_timestamps.txt', delimiter=',', dtype='int64')
237 | self.indices = np.arange(len(timestamps_images) - 1)
238 | self.timestamps_flow = timestamps_images[1:, 0]
239 | else:
240 | pass
241 | assert timestamp_file.is_file()
242 |
243 | file = np.genfromtxt(
244 | timestamp_file,
245 | delimiter=','
246 | )
247 |
248 | self.idx_to_visualize = file[:, 2] if file.shape[1] == 3 else []
249 |
250 | # Save output dimensions
251 | self.height = 480
252 | self.width = 640
253 | self.num_bins = num_bins
254 |
255 | # Just for now, we always train with num_bins=15
256 | assert self.num_bins == 15
257 |
258 | # Set event representation
259 | self.voxel_grid = None
260 | if representation_type == RepresentationType.VOXEL:
261 | self.voxel_grid = VoxelGrid(
262 | (self.num_bins, self.height, self.width), normalize=True)
263 | if representation_type == "count":
264 | self.voxel_grid = "count"
265 | if representation_type == "pcount":
266 | self.voxel_grid = PolarityCount((2, self.height, self.width))
267 |
268 |         # Save delta timestamp in microseconds
269 | self.delta_t_us = delta_t_ms * 1000
270 |
271 | # Left events only
272 | ev_data_file = ev_dir_location / 'events.h5'
273 | ev_rect_file = ev_dir_location / 'rectify_map.h5'
274 |
275 | h5f_location = h5py.File(str(ev_data_file), 'r')
276 | self.h5f = h5f_location
277 | self.event_slicer = EventSlicer(h5f_location)
278 |
279 | self.h5rect = h5py.File(str(ev_rect_file), 'r')
280 | self.rectify_ev_map = self.h5rect['rectify_map'][()]
281 |
282 |
283 | def events_to_voxel_grid(self, p, t, x, y, device: str = 'cpu'):
284 | t = (t - t[0]).astype('float32')
285 | t = (t/t[-1])
286 | x = x.astype('float32')
287 | y = y.astype('float32')
288 | pol = p.astype('float32')
289 | event_data_torch = {
290 | 'p': torch.from_numpy(pol),
291 | 't': torch.from_numpy(t),
292 | 'x': torch.from_numpy(x),
293 | 'y': torch.from_numpy(y),
294 | }
295 | return self.voxel_grid.convert(event_data_torch)
296 |
297 | def getHeightAndWidth(self):
298 | return self.height, self.width
299 |
300 | @staticmethod
301 | def get_disparity_map(filepath: Path):
302 | assert filepath.is_file()
303 | disp_16bit = cv2.imread(str(filepath), cv2.IMREAD_ANYDEPTH)
304 | return disp_16bit.astype('float32')/256
305 |
306 | @staticmethod
307 | def load_flow(flowfile: Path):
308 | assert flowfile.exists()
309 | assert flowfile.suffix == '.png'
310 | flow_16bit = imageio.imread(str(flowfile), format='PNG-FI')
311 | flow, valid2D = flow_16bit_to_float(flow_16bit)
312 | return flow, valid2D
313 |
314 | @staticmethod
315 | def close_callback(h5f):
316 | h5f.close()
317 |
318 | def get_image_width_height(self):
319 | return self.height, self.width
320 |
321 | def __len__(self):
322 | return len(self.timestamps_flow) - 1
323 |
324 | def rectify_events(self, x: np.ndarray, y: np.ndarray):
325 | # assert location in self.locations
326 | # From distorted to undistorted
327 | rectify_map = self.rectify_ev_map
328 | assert rectify_map.shape == (
329 | self.height, self.width, 2), rectify_map.shape
330 | assert x.max() < self.width
331 | assert y.max() < self.height
332 | return rectify_map[y, x]
333 |
334 | def get_data_sample(self, index, crop_window=None, flip=None):
335 | # First entry corresponds to all events BEFORE the flow map
336 | # Second entry corresponds to all events AFTER the flow map (corresponding to the actual fwd flow)
337 | names = ['event_volume_old', 'event_volume_new']
338 | ts_start = [self.timestamps_flow[index] -
339 | self.delta_t_us, self.timestamps_flow[index]]
340 | ts_end = [self.timestamps_flow[index],
341 | self.timestamps_flow[index] + self.delta_t_us]
342 |
343 | file_index = self.indices[index]
344 |
345 | output = {
346 | 'file_index': file_index,
347 | 'timestamp': self.timestamps_flow[index],
348 | 'seq_name': self.seq_name
349 | }
350 | # Save sample for benchmark submission
351 | output['save_submission'] = file_index in self.idx_to_visualize
352 | output['visualize'] = self.visualize_samples
353 |
354 | for i in range(len(names)):
355 | event_data = self.event_slicer.get_events(
356 | ts_start[i], ts_end[i])
357 |
358 | p = event_data['p']
359 | t = event_data['t']
360 | x = event_data['x']
361 | y = event_data['y']
362 |
363 | xy_rect = self.rectify_events(x, y)
364 | x_rect = xy_rect[:, 0]
365 | y_rect = xy_rect[:, 1]
366 |
367 | if crop_window is not None:
368 | # Cropping (+- 2 for safety reasons)
369 | x_mask = (x_rect >= crop_window['start_x']-2) & (
370 | x_rect < crop_window['start_x']+crop_window['crop_width']+2)
371 | y_mask = (y_rect >= crop_window['start_y']-2) & (
372 | y_rect < crop_window['start_y']+crop_window['crop_height']+2)
373 | mask_combined = x_mask & y_mask
374 | p = p[mask_combined]
375 | t = t[mask_combined]
376 | x_rect = x_rect[mask_combined]
377 | y_rect = y_rect[mask_combined]
378 |
379 | if self.voxel_grid is None:
380 | raise NotImplementedError
381 | else:
382 | event_representation = self.events_to_voxel_grid(
383 | p, t, x_rect, y_rect)
384 | output[names[i]] = event_representation
385 | output['name_map'] = self.name_idx
386 |
387 | if self.load_gt:
388 | output['flow_gt_' + names[i]
389 | ] = [torch.tensor(x) for x in self.load_flow(self.flow_png[index + i])]
390 |
391 | output['flow_gt_' + names[i]
392 | ][0] = torch.moveaxis(output['flow_gt_' + names[i]][0], -1, 0)
393 | output['flow_gt_' + names[i]
394 | ][1] = torch.unsqueeze(output['flow_gt_' + names[i]][1], 0)
395 |
396 | if self.load_gt:
397 | if index + 2 < len(self.flow_png):
398 | output['flow_gt_next'] = [torch.tensor(
399 | x) for x in self.load_flow(self.flow_png[index + 2])]
400 | output['flow_gt_next'][0] = torch.moveaxis(
401 | output['flow_gt_next'][0], -1, 0)
402 | output['flow_gt_next'][1] = torch.unsqueeze(
403 | output['flow_gt_next'][1], 0)
404 | return output
405 |
406 | def __getitem__(self, idx):
407 | sample = self.get_data_sample(idx)
408 | for key_t, transform in self.transforms.items():
409 | if key_t == "hflip":
410 | if random.random() > 0.5:
411 | for key in sample:
412 | if isinstance(sample[key], torch.Tensor):
413 | sample[key] = tf.functional.hflip(sample[key])
414 | if key.startswith("flow_gt"):
415 | sample[key] = [tf.functional.hflip(
416 | mask) for mask in sample[key]]
417 | sample[key][0][0, :] = -sample[key][0][0, :]
418 | elif key_t == "vflip":
419 | if random.random() < transform:
420 | for key in sample:
421 | if isinstance(sample[key], torch.Tensor):
422 | sample[key] = tf.functional.vflip(sample[key])
423 | if key.startswith("flow_gt"):
424 | sample[key] = [tf.functional.vflip(
425 | mask) for mask in sample[key]]
426 | sample[key][0][1, :] = -sample[key][0][1, :]
427 | elif key_t == "randomcrop":
428 | apply_randomcrop_to_sample(sample, crop_size=transform)
429 | else:
430 | apply_transform_to_field(sample, transform, key_t)
431 |
432 |
433 | return sample
434 |
435 | def get_voxel_grid(self, idx):
436 |
437 | if idx == 0:
438 | event_data = self.event_slicer.get_events(
439 | self.timestamps_flow[0] - self.delta_t_us, self.timestamps_flow[0])
440 | elif idx > 0 and idx <= self.__len__():
441 | event_data = self.event_slicer.get_events(
442 | self.timestamps_flow[idx-1], self.timestamps_flow[idx-1] + self.delta_t_us)
443 | else:
444 | raise IndexError
445 |
446 | p = event_data['p']
447 | t = event_data['t']
448 | x = event_data['x']
449 | y = event_data['y']
450 |
451 | xy_rect = self.rectify_events(x, y)
452 | x_rect = xy_rect[:, 0]
453 | y_rect = xy_rect[:, 1]
454 | return self.events_to_voxel_grid(p, t, x_rect, y_rect)
455 |
456 | def get_event_count_image(self, ts_start, ts_end, num_bins=15, normalize=True):
457 | assert ts_end > ts_start
458 | delta_t_bin = (ts_end - ts_start) / num_bins
459 | ts_start_bin = np.linspace(
460 | ts_start, ts_end, num=num_bins, endpoint=False)
461 | ts_end_bin = ts_start_bin + delta_t_bin
462 | assert abs(ts_end_bin[-1] - ts_end) < 10.
463 | ts_end_bin[-1] = ts_end
464 |
465 | event_count = torch.zeros(
466 | (num_bins, self.height, self.width), dtype=torch.float, requires_grad=False)
467 |
468 | for i in range(num_bins):
469 | event_data = self.event_slicer.get_events(
470 | ts_start_bin[i], ts_end_bin[i])
471 | p = event_data['p']
472 | t = event_data['t']
473 | x = event_data['x']
474 | y = event_data['y']
475 |
476 | t = (t - t[0]).astype('float32')
477 | t = (t/t[-1])
478 | x = x.astype('float32')
479 | y = y.astype('float32')
480 | pol = p.astype('float32')
481 | event_data_torch = {
482 | 'p': torch.from_numpy(pol),
483 | 't': torch.from_numpy(t),
484 | 'x': torch.from_numpy(x),
485 | 'y': torch.from_numpy(y),
486 | }
487 | x = event_data_torch['x']
488 | y = event_data_torch['y']
489 | xy_rect = self.rectify_events(x.int(), y.int())
490 | x_rect = torch.from_numpy(xy_rect[:, 0]).long()
491 | y_rect = torch.from_numpy(xy_rect[:, 1]).long()
492 | value = 2*event_data_torch['p']-1
493 | index = self.width*y_rect + x_rect
494 | mask = (x_rect < self.width) & (y_rect < self.height)
495 | event_count[i].put_(index[mask], value[mask], accumulate=True)
496 |
497 | return event_count
498 |
499 | @staticmethod
500 | def normalize_tensor(event_count):
501 | mask = torch.nonzero(event_count, as_tuple=True)
502 | if mask[0].size()[0] > 0:
503 | mean = event_count[mask].mean()
504 | std = event_count[mask].std()
505 | if std > 0:
506 | event_count[mask] = (event_count[mask] - mean) / std
507 | else:
508 | event_count[mask] = event_count[mask] - mean
509 | return event_count
510 |
511 |
512 | class SequenceRecurrent(Sequence):
513 | def __init__(self, seq_path: Path, representation_type: RepresentationType, mode: str = 'test', delta_t_ms: int = 100,
514 | num_bins: int = 15, transforms=None, sequence_length=1, name_idx=0, visualize=False, load_gt=False):
515 |         super(SequenceRecurrent, self).__init__(seq_path, representation_type, mode, delta_t_ms, num_bins, transforms=transforms,
516 |                                                 name_idx=name_idx, visualize=visualize, load_gt=load_gt)
517 | self.crop_size = self.transforms['randomcrop'] if 'randomcrop' in self.transforms else None
518 | self.sequence_length = sequence_length
519 | self.valid_indices = self.get_continuous_sequences()
520 |
521 | def get_continuous_sequences(self):
522 | continuous_seq_idcs = []
523 | if self.sequence_length > 1:
524 | for i in range(len(self.timestamps_flow)-self.sequence_length+1):
525 | diff = self.timestamps_flow[i +
526 | self.sequence_length-1] - self.timestamps_flow[i]
527 | if diff < np.max([100000 * (self.sequence_length-1) + 1000, 101000]):
528 | continuous_seq_idcs.append(i)
529 | else:
530 | for i in range(len(self.timestamps_flow)-1):
531 | diff = self.timestamps_flow[i+1] - self.timestamps_flow[i]
532 | if diff < np.max([100000 * (self.sequence_length-1) + 1000, 101000]):
533 | continuous_seq_idcs.append(i)
534 | return continuous_seq_idcs
535 |
536 | def __len__(self):
537 | return len(self.valid_indices)
538 |
539 | def __getitem__(self, idx):
540 | assert idx >= 0
541 | assert idx < len(self)
542 |
543 | # Valid index is the actual index we want to load, which guarantees a continuous sequence length
544 | valid_idx = self.valid_indices[idx]
545 |
546 | sequence = []
547 | j = valid_idx
548 |
549 | ts_cur = self.timestamps_flow[j]
550 | # Add first sample
551 | sample = self.get_data_sample(j)
552 | sequence.append(sample)
553 |
554 | # Data augmentation according to first sample
555 | crop_window = None
556 | flip = None
557 | if 'crop_window' in sample.keys():
558 | crop_window = sample['crop_window']
559 | if 'flipped' in sample.keys():
560 | flip = sample['flipped']
561 |
562 | for i in range(self.sequence_length-1):
563 | j += 1
564 | ts_old = ts_cur
565 | ts_cur = self.timestamps_flow[j]
566 |             assert ts_cur - ts_old < 100000 + 1000
567 | sample = self.get_data_sample(
568 | j, crop_window=crop_window, flip=flip)
569 | sequence.append(sample)
570 |
571 | # Check if the current sample is the first sample of a continuous sequence
572 | if idx == 0 or self.valid_indices[idx]-self.valid_indices[idx-1] != 1:
573 | sequence[0]['new_sequence'] = 1
574 | print("Timestamp {} is the first one of the next seq!".format(
575 | self.timestamps_flow[self.valid_indices[idx]]))
576 | else:
577 | sequence[0]['new_sequence'] = 0
578 |
579 | # random crop
580 | if self.crop_size is not None:
581 | i, j, h, w = RandomCrop.get_params(
582 | sample["event_volume_old"], output_size=self.crop_size)
583 | keys_to_crop = ["event_volume_old", "event_volume_new",
584 | "flow_gt_event_volume_old", "flow_gt_event_volume_new",
585 | "flow_gt_next",]
586 |
587 | for sample in sequence:
588 | for key, value in sample.items():
589 | if key in keys_to_crop:
590 | if isinstance(value, torch.Tensor):
591 | sample[key] = tf.functional.crop(value, i, j, h, w)
592 | elif isinstance(value, list) or isinstance(value, tuple):
593 | sample[key] = [tf.functional.crop(
594 | v, i, j, h, w) for v in value]
595 | return sequence
596 |
597 |
598 | class DatasetProvider:
599 | def __init__(self, dataset_path: Path, representation_type: RepresentationType, delta_t_ms: int = 100, num_bins=15,
600 | type='standard', config=None, visualize=False):
601 | test_path = dataset_path / 'test'
602 | assert dataset_path.is_dir(), str(dataset_path)
603 | assert test_path.is_dir(), str(test_path)
604 | assert delta_t_ms == 100
605 | self.config = config
606 | self.name_mapper_test = []
607 |
608 | test_sequences = list()
609 | for child in test_path.iterdir():
610 | self.name_mapper_test.append(str(child).split("/")[-1])
611 | if type == 'standard':
612 | test_sequences.append(Sequence(child, representation_type, 'test', delta_t_ms, num_bins,
613 |                                                transforms={},
614 | name_idx=len(
615 | self.name_mapper_test)-1,
616 | visualize=visualize))
617 | elif type == 'warm_start':
618 | test_sequences.append(SequenceRecurrent(child, representation_type, 'test', delta_t_ms, num_bins,
619 |                                                         transforms={}, sequence_length=1,
620 | name_idx=len(
621 | self.name_mapper_test)-1,
622 | visualize=visualize))
623 | else:
624 | raise Exception(
625 | 'Please provide a valid subtype [standard/warm_start] in config file!')
626 |
627 | self.test_dataset = torch.utils.data.ConcatDataset(test_sequences)
628 |
629 | def get_test_dataset(self):
630 | return self.test_dataset
631 |
632 | def get_name_mapping_test(self):
633 | return self.name_mapper_test
634 |
635 | def summary(self, logger):
636 | logger.write_line(
637 | "================================== Dataloader Summary ====================================", True)
638 | logger.write_line("Loader Type:\t\t" + self.__class__.__name__, True)
639 | logger.write_line("Number of Voxel Bins: {}".format(
640 | self.test_dataset.datasets[0].num_bins), True)
641 |
642 |
643 | def assemble_dsec_sequences(dataset_root, include_seq=None, exclude_seq=None, require_gt=True, config=None, representation_type="voxel", num_bins=None):
644 | if representation_type is None:
645 | representation_type = "voxel"
646 | representation_type = RepresentationType.VOXEL if representation_type == "voxel" else representation_type
647 | event_root = os.path.join(dataset_root, "train_events")
648 | flow_gt_root = os.path.join(dataset_root, "train_optical_flow")
649 | available_seqs = os.listdir(
650 | flow_gt_root) if require_gt else os.listdir(event_root)
651 |
652 | seqs = available_seqs
653 | if include_seq:
654 | seqs = [seq for seq in seqs if seq in include_seq]
655 | if exclude_seq:
656 | seqs = [seq for seq in seqs if seq not in exclude_seq]
657 |
658 | # Prepare transform list
659 | transforms = dict()
660 | if config.downsample_ratio > 1:
661 |         transforms['(?=event_volume)'] = lambda x: downsample_spatial(
662 |             x, config.downsample_ratio)
663 |         transforms['(?=flow_gt)'] = lambda x: downsample_spatial_mask(
664 |             x, config.downsample_ratio)
665 |
666 |     if config.get("horizontal_flip", None):
667 |         p_hflip = config.horizontal_flip
668 |         assert p_hflip >= 0 and p_hflip <= 1
669 | #from torchvision.transforms import RandomHorizontalFlip
670 | # ignore probability of hflip for now, perform when 'hflip' key exists in transforms dict
671 | transforms['hflip'] = None
672 | if config.get("vertical_flip", None):
673 | p_vflip = config.vertical_flip
674 | assert p_vflip >= 0 and p_vflip <= 1
675 | transforms['vflip'] = p_vflip
676 | if config.get("random_crop", None):
677 | crop_size = config.random_crop
678 | transforms['randomcrop'] = crop_size
679 |
680 | seq_dataset = []
681 | for seq in seqs:
682 | dataset_cls = SequenceRecurrent if hasattr(
683 | config, "recurrent") and config.recurrent else Sequence
684 | extra_arg = dict(
685 | sequence_length=config.sequence_length) if dataset_cls == SequenceRecurrent else dict()
686 |
687 | seq_dataset.append(dataset_cls(Path(event_root) / seq,
688 | representation_type=representation_type, mode="train",
689 | load_gt=require_gt, transforms=transforms, **extra_arg))
690 | if config.get("concat_seq", True):
691 | return torch.utils.data.ConcatDataset(seq_dataset)
692 | else:
693 | return seq_dataset
694 |
695 |
696 | def assemble_dsec_test_set(test_set_root, seq_len=None, concat_seq=False, representation_type=None):
697 | if representation_type is None:
698 | representation_type = RepresentationType.VOXEL
699 | print("dsec test uses representation: voxel")
700 | elif representation_type == "voxel":
701 | representation_type = RepresentationType.VOXEL
702 | print("dsec test uses representation: voxel")
703 | else:
704 | print("dsec test uses representation: {}".format(representation_type))
705 | representation_type = representation_type
706 | test_seqs = os.listdir(test_set_root)
707 | seqs = []
708 |
709 | transforms = dict()
710 | for seq in test_seqs:
711 | dataset_cls = SequenceRecurrent if seq_len else Sequence
712 | extra_arg = dict(
713 | sequence_length=seq_len) if dataset_cls == SequenceRecurrent else dict()
714 | seqs.append(dataset_cls(Path(test_set_root) / seq,
715 | representation_type, mode='test',
716 | load_gt=False, transforms=transforms, **extra_arg))
717 | if concat_seq:
718 | return torch.utils.data.ConcatDataset(seqs)
719 | else:
720 | return seqs
721 |
722 |
723 | def assemble_dsec_train_set(train_set_root, flow_gt_root=None, exclude_seq=None, args=None):
724 | train_seqs = os.listdir(
725 | flow_gt_root) if flow_gt_root is not None else os.listdir(train_set_root)
726 | if exclude_seq is not None:
727 | train_seqs = [seq for seq in train_seqs if seq not in exclude_seq]
728 | seq_dataset = []
729 | for seq in train_seqs:
730 | seq_dataset.append(Sequence(Path(train_set_root) / seq,
731 | RepresentationType.VOXEL, mode='train',
732 |                                    transforms={}))
733 | return torch.utils.data.ConcatDataset(seq_dataset)
734 |
735 |
736 | def train_collate(sample_list):
737 | batch = dict()
738 | for field_name in sample_list[0]:
739 | if field_name == 'seq_name':
740 | batch['seq_name'] = [sample[field_name] for sample in sample_list]
741 | if field_name == 'new_sequence':
742 | batch['new_sequence'] = [sample[field_name]
743 | for sample in sample_list]
744 | if field_name.startswith("event_volume"):
745 | batch[field_name] = torch.stack(
746 | [sample[field_name] for sample in sample_list])
747 | if field_name.startswith("flow_gt"):
748 | if all(field_name in x for x in sample_list):
749 | batch[field_name] = torch.stack(
750 | [sample[field_name][0] for sample in sample_list])
751 | batch[field_name + '_valid_mask'] = torch.stack(
752 | [sample[field_name][1] for sample in sample_list])
753 |
754 | return batch
755 |
756 |
757 | def rec_train_collate(sample_list):
758 | seq_length = len(sample_list[0])
759 | seq_of_batch = []
760 | for i in range(seq_length):
761 | seq_of_batch.append(train_collate(
762 | [sample[i] for sample in sample_list]))
763 | return seq_of_batch
764 |
--------------------------------------------------------------------------------
/idn/loader/loader_mvsec.py:
--------------------------------------------------------------------------------
1 | # import h5pickle as h5py
2 | import h5py
3 | import os
4 | import torch
5 | import random
6 | import numpy as np
7 | from torch.utils.data import Dataset
8 | from torchvision import transforms as T
9 | from idn.utils.mvsec_utils import EventSequence
10 | from idn.utils.dsec_utils import RepresentationType, VoxelGrid
11 | from idn.utils.transformers import EventSequenceToVoxelGrid_Pytorch, apply_randomcrop_to_sample
12 |
13 | class MVSEC(Dataset):
14 | def __init__(self, seq_name, seq_path="data/", representation_type=None, \
15 | rate=20, num_bins=5, transforms=[], filter=None, augment=True, dt=None):
16 | self.seq_name = seq_name
17 | self.dt = dt
18 | if self.dt is None:
19 | self.event_h5 = h5py.File(os.path.join(seq_path, f"{seq_name}_data.hdf5"), "r")
20 | self.event = self.event_h5['davis']['left']['events']
21 | self.gt_h5 = h5py.File(os.path.join(seq_path, f"{seq_name}_gt.hdf5"), "r")
22 | self.gt_flow = self.gt_h5['davis']['left']['flow_dist']
23 | self.timestamps = self.gt_h5['davis']['left']['flow_dist_ts']
24 | else:
25 | assert self.dt == 1 or self.dt == 4
26 | self.h5 = h5py.File(os.path.join(seq_path, f"{seq_name}.h5"), "r")
27 | self.event = self.h5['events']
28 | self.timestamps = self.h5['flow']['dt={}'.format(self.dt)]['timestamps'][:, 0]
29 | self.gt_flow = list(self.h5['flow']['dt={}'.format(self.dt)].keys())
30 | self.gt_flow.remove('timestamps')
31 | assert sorted(self.gt_flow) == self.gt_flow
32 |
33 | if representation_type is None:
34 | self.representation_type = VoxelGrid
35 | else:
36 | self.representation_type = representation_type
37 |
38 | if filter is not None:
39 | assert isinstance(filter, tuple) and isinstance(filter[0], int)\
40 | and isinstance(filter[1], int)
41 | self.timestamps = self.timestamps[slice(*filter)]
42 | self.gt_flow = self.gt_flow[slice(*filter)]
43 |
44 | self.raw_gt_len = self.timestamps.shape[0]
45 | self.event_ts_to_idx = self.build_event_idx()
46 | self.voxel = EventSequenceToVoxelGrid_Pytorch(
47 | num_bins=num_bins,
48 | normalize=True,
49 | gpu=False,
50 | )
51 | self.image_width, self.image_height = 346, 260
52 | self.cropper = T.CenterCrop((256, 256))
53 | self.augment = augment
54 |
55 |
56 | def __len__(self):
57 | return self.raw_gt_len - 2
58 |
59 | def __getitem__(self, idx):
60 | idx += 1
61 | sample = {}
62 | if self.dt is None:
63 | # get events
64 | events = self.event[self.event_ts_to_idx[idx-1]:self.event_ts_to_idx[idx]]
65 | events = events[:, [2, 0, 1, 3]] # make it (t, x, y, p)
66 | sample["event_volume_old"] = \
67 | self.voxel(EventSequence(events,
68 | params={'width': self.image_width,
69 | 'height': self.image_height},
70 | timestamp_multiplier=1e6,
71 | convert_to_relative=True,
72 | features=events))
73 |
74 | # get events
75 | events = self.event[self.event_ts_to_idx[idx]:self.event_ts_to_idx[idx+1]]
76 | events = events[:, [2, 0, 1, 3]] # make it (t, x, y, p)
77 |
78 | sample["event_volume_new"] = \
79 | self.voxel(EventSequence(events,
80 | params={'width': self.image_width,
81 | 'height': self.image_height},
82 | timestamp_multiplier=1e6,
83 | convert_to_relative=True,
84 | features = events))
85 | # get flow
86 | flow = self.gt_flow[idx] # -1 yields the same gt flow as E-RAFT, but likely incorrect
87 | flow_next = self.gt_flow[idx+1]
88 | else:
89 | old_p = self.event['ps'][self.event_ts_to_idx[idx-1]:self.event_ts_to_idx[idx]]
90 | old_t = self.event['ts'][self.event_ts_to_idx[idx-1]:self.event_ts_to_idx[idx]]
91 | old_x = self.event['xs'][self.event_ts_to_idx[idx-1]:self.event_ts_to_idx[idx]]
92 | old_y = self.event['ys'][self.event_ts_to_idx[idx-1]:self.event_ts_to_idx[idx]]
93 |
94 | old_events = np.column_stack((old_t, old_x, old_y, old_p))
95 | sample["event_volume_old"] = \
96 | self.voxel(EventSequence(old_events,
97 | params={'width': self.image_width,
98 | 'height': self.image_height},
99 | timestamp_multiplier=1e6,
100 | convert_to_relative=True,
101 | features=old_events))
102 |
103 | new_p = self.event['ps'][self.event_ts_to_idx[idx]:self.event_ts_to_idx[idx+1]]
104 | new_t = self.event['ts'][self.event_ts_to_idx[idx]:self.event_ts_to_idx[idx+1]]
105 | new_x = self.event['xs'][self.event_ts_to_idx[idx]:self.event_ts_to_idx[idx+1]]
106 | new_y = self.event['ys'][self.event_ts_to_idx[idx]:self.event_ts_to_idx[idx+1]]
107 |
108 | new_events = np.column_stack((new_t, new_x, new_y, new_p))
109 | sample["event_volume_new"] = \
110 | self.voxel(EventSequence(new_events,
111 | params={'width': self.image_width,
112 | 'height': self.image_height},
113 | timestamp_multiplier=1e6,
114 | convert_to_relative=True,
115 | features=new_events))
116 |
117 | # get flow
118 | flow = np.transpose(self.h5['flow']['dt={}'.format(self.dt)][self.gt_flow[idx]][:], (2, 0, 1))
119 | flow_next = np.transpose(self.h5['flow']['dt={}'.format(self.dt)][self.gt_flow[idx+1]][:], (2, 0, 1))
120 |
121 |
122 | sample["flow_gt_event_volume_new"] = self.process_flow_gt(flow)
123 | sample["flow_gt_next"] = self.process_flow_gt(flow_next)
124 |
125 | sample["event_volume_old"] = self.cropper(sample["event_volume_old"])
126 | sample["event_volume_new"] = self.cropper(sample["event_volume_new"])
127 |
128 | if self.augment:
129 | # augmentation
130 | if random.random() > 0.5:
131 | for key in sample:
132 | if isinstance(sample[key], torch.Tensor):
133 | sample[key] = T.functional.hflip(sample[key])
134 | if key.startswith("flow_gt"):
135 | sample[key] = [T.functional.hflip(
136 | mask) for mask in sample[key]]
137 | sample[key][0][0, :] = -sample[key][0][0, :]
138 |
139 |
140 | return sample
141 |
142 | def process_flow_gt(self, flow):
143 | flow_valid = (flow[0] != 0) | (flow[1] != 0)
144 | flow_valid[193:, :] = False
145 | flow = torch.from_numpy(flow)
146 | valid_mask = torch.from_numpy(
147 | np.stack([flow_valid]*1, axis=0))
148 |
149 | return (self.cropper(flow), self.cropper(valid_mask))
150 |
151 | def build_event_idx(self):
152 | if self.dt is None:
153 | events_ts = self.event_h5['davis']['left']['events'][:, 2]
154 | else:
155 | events_ts = self.h5['events']['ts']
156 | return np.searchsorted(events_ts, self.timestamps, side='left')
157 |
158 |
159 | class MVSECRecurrent(MVSEC):
160 | def __init__(self, seq_name, seq_path="/scratch", representation_type=None,
161 | rate=20, num_bins=15, transforms=[], filter=None, augment=True, sequence_length=1):
162 | super(MVSECRecurrent, self).__init__(seq_name, seq_path, representation_type,
163 | rate, num_bins, transforms, filter, augment)
164 | self.sequence_length = sequence_length
165 | self.valid_indices = self.get_continuous_sequences()
166 |
167 | def get_continuous_sequences(self):
168 | # MVSEC is continuous without breaks
169 | continuous_seq_idcs = list(
170 | (range(self.raw_gt_len - 2 - self.sequence_length)))
171 | return continuous_seq_idcs
172 |
173 | def __len__(self):
174 | return len(self.valid_indices)
175 |
176 | def __getitem__(self, idx):
177 | assert idx >= 0
178 | assert idx < len(self)
179 |
180 | valid_idx = self.valid_indices[idx]
181 | sequence = []
182 | j = valid_idx
183 |
184 | for i in range(self.sequence_length):
185 | sample = super(MVSECRecurrent, self).__getitem__(j)
186 | sequence.append(sample)
187 | j += 1
188 |
189 |
190 | # Check if the current sample is the first sample of a continuous sequence
191 | if idx == 0 or self.valid_indices[idx]-self.valid_indices[idx-1] != 1:
192 | sequence[0]['new_sequence'] = 1
193 | else:
194 | sequence[0]['new_sequence'] = 0
195 |
196 | return sequence
197 |
198 |
--------------------------------------------------------------------------------
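
A minimal sketch of loading one MVSEC sample with the class above; the sequence name is an assumption, while the `data/` layout (`<seq>_data.hdf5`, `<seq>_gt.hdf5`) follows the constructor's defaults:

```python
from idn.loader.loader_mvsec import MVSEC

# augment=False keeps samples deterministic; num_bins=15 matches MVSECRecurrent's default
ds = MVSEC("outdoor_day2", seq_path="data/", num_bins=15, augment=False)
sample = ds[0]
print(sample["event_volume_new"].shape)            # [15, 256, 256] after the center crop
flow, valid = sample["flow_gt_event_volume_new"]   # flow [2, 256, 256], mask [1, 256, 256]
```
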
/idn/model/extractor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class ResidualBlock(nn.Module):
6 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
7 | super(ResidualBlock, self).__init__()
8 |
9 | self.conv1 = nn.Conv2d(
10 | in_planes, planes, kernel_size=3, padding=1, stride=stride)
11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
12 | self.relu = nn.ReLU(inplace=True)
13 |
14 | num_groups = planes // 8
15 |
16 | if norm_fn == 'group':
17 | self.norm1 = nn.GroupNorm(
18 | num_groups=num_groups, num_channels=planes)
19 | self.norm2 = nn.GroupNorm(
20 | num_groups=num_groups, num_channels=planes)
21 | if not stride == 1:
22 | self.norm3 = nn.GroupNorm(
23 | num_groups=num_groups, num_channels=planes)
24 |
25 | elif norm_fn == 'batch':
26 | self.norm1 = nn.BatchNorm2d(planes)
27 | self.norm2 = nn.BatchNorm2d(planes)
28 | if not stride == 1:
29 | self.norm3 = nn.BatchNorm2d(planes)
30 |
31 | elif norm_fn == 'instance':
32 | self.norm1 = nn.InstanceNorm2d(planes)
33 | self.norm2 = nn.InstanceNorm2d(planes)
34 | if not stride == 1:
35 | self.norm3 = nn.InstanceNorm2d(planes)
36 |
37 | elif norm_fn == 'none':
38 | self.norm1 = nn.Sequential()
39 | self.norm2 = nn.Sequential()
40 | if not stride == 1:
41 | self.norm3 = nn.Sequential()
42 |
43 | if stride == 1:
44 | self.downsample = None
45 |
46 | else:
47 | self.downsample = nn.Sequential(
48 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
49 |
50 | def forward(self, x):
51 | y = x
52 | y = self.relu(self.norm1(self.conv1(y)))
53 | y = self.relu(self.norm2(self.conv2(y)))
54 |
55 | if self.downsample is not None:
56 | x = self.downsample(x)
57 |
58 | return self.relu(x+y)
59 |
60 |
61 |
62 |
63 | class LiteEncoder(nn.Module):
64 | def __init__(self, output_dim=32, stride=2, dropout=0.0, n_first_channels=1):
65 | super(LiteEncoder, self).__init__()
66 |
67 | self.conv1 = nn.Conv2d(n_first_channels, output_dim,
68 | kernel_size=7, stride=2, padding=3)
69 | self.relu1 = nn.ReLU(inplace=True)
70 |
71 | self.in_planes = output_dim
72 | if stride == 2:
73 | self.layer1 = self._make_layer(output_dim, stride=2)
74 | self.layer2 = self._make_layer(output_dim*2, stride=2)
75 |
76 | elif stride == 1:
77 | self.layer1 = self._make_layer(output_dim*2, stride=2)
78 | self.layer2 = self._make_layer(output_dim*2, stride=1)
79 |
80 | else:
81 | raise ValueError('stride must be 1 or 2')
82 |
83 |
84 | self.dropout = None
85 | if dropout > 0:
86 | self.dropout = nn.Dropout2d(p=dropout)
87 |
88 | for m in self.modules():
89 | if isinstance(m, nn.Conv2d):
90 | nn.init.kaiming_normal_(
91 | m.weight, mode='fan_out', nonlinearity='relu')
92 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
93 | if m.weight is not None:
94 | nn.init.constant_(m.weight, 1)
95 | if m.bias is not None:
96 | nn.init.constant_(m.bias, 0)
97 |
98 | def _make_layer(self, dim, stride=1):
99 | layer1 = ResidualBlock(self.in_planes, dim, 'none', stride=stride)
100 | layer2 = ResidualBlock(dim, dim, 'none', stride=1)
101 | layers = (layer1, layer2)
102 |
103 | self.in_planes = dim
104 | return nn.Sequential(*layers)
105 |
106 | def forward(self, x):
107 | # if input is list, combine batch dimension
108 | is_list = isinstance(x, tuple) or isinstance(x, list)
109 | if is_list:
110 | batch_dim = x[0].shape[0]
111 | x = torch.cat(x, dim=0)
112 |
113 | x = self.conv1(x)
114 | x = self.relu1(x)
115 | x = self.layer1(x)
116 | x = self.layer2(x)
117 |
118 | # x = self.conv2(x)
119 |
120 | if self.training and self.dropout is not None:
121 | x = self.dropout(x)
122 |
123 | if is_list:
124 | x = torch.split(x, [batch_dim, batch_dim], dim=0)
125 | return x
126 |
127 |
--------------------------------------------------------------------------------
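
A quick shape check for `LiteEncoder` (a sketch, not repo code): with `stride=2` the input passes three stride-2 stages (conv1, layer1, layer2) for an overall 8x reduction; with `stride=1` only two stages downsample, giving 4x. Both settings end with `output_dim*2` channels:

```python
import torch
from idn.model.extractor import LiteEncoder

enc8 = LiteEncoder(output_dim=32, stride=2, n_first_channels=2)
enc4 = LiteEncoder(output_dim=32, stride=1, n_first_channels=2)
x = torch.randn(1, 2, 256, 256)
print(enc8(x).shape)   # torch.Size([1, 64, 32, 32])  -> 1/8 resolution
print(enc4(x).shape)   # torch.Size([1, 64, 64, 64])  -> 1/4 resolution
```
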
/idn/model/idedeq.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.functional import unfold, grid_sample, interpolate
4 |
5 | from .extractor import LiteEncoder
6 | from .update import LiteUpdateBlock
7 |
8 | from math import sqrt
9 |
10 |
11 | class IDEDEQIDO(nn.Module):
12 | def __init__(self, config):
13 | super(IDEDEQIDO, self).__init__()
14 | self.hidden_dim = getattr(config, 'hidden_dim', 96)
15 | self.input_dim = 64
16 | self.downsample = getattr(config, 'downsample', 8)
17 | self.input_flowmap = getattr(config, 'input_flowmap', False)
18 | self.pred_next_flow = getattr(config, 'pred_next_flow', False)
19 | self.fnet = LiteEncoder(
20 | output_dim=self.input_dim//2, dropout=0, n_first_channels=2, stride=2 if self.downsample == 8 else 1)
21 | self.update_net = LiteUpdateBlock(
22 | hidden_dim=self.hidden_dim, input_dim=self.input_dim,
23 | num_outputs=2 if self.pred_next_flow else 1,
24 | downsample=self.downsample)
25 | self.deblur_iters = config.update_iters
26 | self.zero_init = config.zero_init
27 | self._deq = getattr(config, "deq_mode", False)
28 | self.hook = None
29 | self.co_mode = getattr(config, "co_mode", False)
30 | self.conr_mode = getattr(config, "conr_mode", False)
31 | self.deblur = getattr(config, "deblur", True)
32 | self.add_delta = getattr(config, "add_delta", False)
33 | self.deblur_mode = getattr(config, "deblur_mode", "voxel")
34 | self.reset_continuous_flow()
35 | if self.input_flowmap:
36 | self.cnet = LiteEncoder(
37 | output_dim=self.hidden_dim // 2, dropout=0, n_first_channels=2, stride=2 if self.downsample == 8 else 1)
38 | else:
39 | self.cnet = None
40 |
41 | def upsample_flow(self, flow, mask):
42 |         """ Upsample flow field [H/r, W/r, 2] -> [H, W, 2] using convex combination; r is inferred from the mask shape """
43 | N, _, H, W = flow.shape
44 | _, D, H, W = mask.shape
45 | upsample_ratio = int(sqrt(D/9))
46 | mask = mask.view(N, 1, 9, upsample_ratio, upsample_ratio, H, W)
47 | mask = torch.softmax(mask, dim=2)
48 |
49 | up_flow = unfold(8 * flow, [3, 3], padding=1)
50 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
51 |
52 | up_flow = torch.sum(mask * up_flow, dim=2)
53 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
54 | return up_flow.reshape(N, 2, upsample_ratio*H, upsample_ratio*W)
55 |
56 | @staticmethod
57 | def upflow8(flow, mode='bilinear'):
58 | new_size = (8 * flow.shape[-2], 8 * flow.shape[-1])
59 | return 8 * interpolate(flow, size=new_size, mode=mode, align_corners=True)
60 |
61 | @staticmethod
62 | def create_identity_grid(H, W, device):
63 | i, j = map(lambda x: x.float(), torch.meshgrid(
64 | [torch.arange(0, H), torch.arange(0, W)], indexing='ij'))
65 | return torch.stack([j, i], dim=-1).to(device)
66 |
67 | def deblur_tensor(self, raw_input, flow, mask=None):
68 | # raw: [N, T, C, H, W]
69 | raw = raw_input.unsqueeze(2) if raw_input.ndim == 4 else raw_input
70 | N, T, C, H, W = raw.shape
71 | deblurred_tensor = torch.zeros_like(raw)
72 | identity_grid = self.create_identity_grid(H, W, raw.device)
73 | for t in range(T):
74 | if self.deblur_mode == "voxel":
75 | delta_p = flow*t/(T-1)
76 | else:
77 | delta_p = flow*((t+0.5)/T)
78 | sampling_grid = identity_grid + torch.movedim(delta_p, 1, -1)
79 | sampling_grid[..., 0] = sampling_grid[..., 0] / (W-1) * 2 - 1
80 | sampling_grid[..., 1] = sampling_grid[..., 1] / (H-1) * 2 - 1
81 | deblurred_tensor[:, t,
82 | ] = grid_sample(raw[:, t, ], sampling_grid, align_corners=False)
83 | if raw_input.ndim == 4:
84 | deblurred_tensor = deblurred_tensor.squeeze(2)
85 | return deblurred_tensor
86 |
87 | def reset_continuous_flow(self, reset=True):
88 | if reset:
89 | self.flow_init = None
90 | self.last_net_co = None
91 |
92 | def forward(self, event_bins, flow_init=None, deblur_iters=None, net_co=None):
93 | if self.co_mode:
94 | # continuous mode
95 | # take flow_init from state and forward propagate it
96 | assert flow_init is None, "flow_init should be None in continuous mode"
97 | if self.flow_init is None:
98 | print("No last flow, using zero flow")
99 | elif 'new_sequence' in event_bins and event_bins['new_sequence'][0] == 1:
100 | print("Got new sequence, resetting flow")
101 | flow_init = None
102 | else:
103 | flow_init = self.flow_init
104 | if self.conr_mode:
105 | if self.last_net_co is None:
106 | print("No last net_co, using zero flow")
107 |             elif 'new_sequence' in event_bins and event_bins['new_sequence'][0] == 1:
108 | print("Got new sequence, resetting flow")
109 | net_co = None
110 | else:
111 | net_co = self.last_net_co
112 |
113 | deblur_iters = self.deblur_iters if deblur_iters is None else deblur_iters
114 | # x_old, x_new = event_bins["event_volume_old"], event_bins["event_volume_new"]
115 | x_raw = event_bins["event_volume_new"]
116 |
117 | B, V, H, W = x_raw.shape
118 | flow_total = torch.zeros(B, 2, H, W).to(
119 | x_raw.device) if flow_init is None else flow_init.clone()
120 |
121 | delta_flow = flow_total
122 | flow_history = torch.zeros(B, 0, 2, H, W).to(x_raw.device)
123 | x_deblur_history = x_raw.clone().unsqueeze(1)
124 | delta_flow_history = delta_flow.clone().unsqueeze(1)
125 |
126 |
127 | x_deblur = x_raw.clone()
128 |         for it in range(deblur_iters):
129 | if self.deblur:
130 | x_deblur = self.deblur_tensor(x_deblur, delta_flow)
131 | x = torch.stack([x_deblur, x_deblur], dim=1)
132 | x_deblur_history = torch.cat(
133 | [x_deblur_history, x_deblur.unsqueeze(1)], dim=1)
134 | else:
135 | x = torch.stack([x_raw, x_raw], dim=1)
136 |
137 | if net_co is not None:
138 | net = net_co
139 | else:
140 | if self.input_flowmap:
141 | assert self.cnet is not None, "cnet non initialized in flowmap mode"
142 |                     if flow_init is not None or it >= 1:
143 | net = self.cnet(flow_total)
144 | else:
145 | net = torch.zeros(
146 | (B, self.hidden_dim,
147 | H//self.downsample, W//self.downsample)).to(x.device)
148 | else:
149 | if self.cnet is not None:
150 | net = self.cnet(x)
151 | else:
152 | net = torch.zeros(
153 | (B, self.hidden_dim,
154 | H//self.downsample, W//self.downsample)).to(x.device)
155 |             for bin_slice in x.permute(2, 0, 1, 3, 4):
156 |                 f = self.fnet(bin_slice)
157 | net = self.update_net(net, f)
158 |
159 | dflow = self.update_net.compute_deltaflow(net)
160 | up_mask = self.update_net.compute_up_mask(net)
161 | delta_flow = self.upsample_flow(dflow, up_mask)
162 | delta_flow_history = torch.cat(
163 | [delta_flow_history, delta_flow.unsqueeze(1)], dim=1)
164 | if self.pred_next_flow:
165 | nflow = self.update_net.compute_nextflow(net)
166 | up_mask_next_flow = self.update_net.compute_up_mask2(net)
167 | next_flow = self.upsample_flow(nflow, up_mask_next_flow)
168 | else:
169 | next_flow = None
170 |
171 | if self.deblur or self.add_delta:
172 | flow_total = flow_total + delta_flow
173 | else:
174 | flow_total = delta_flow
175 | flow_history = torch.cat(
176 | [flow_history, flow_total.unsqueeze(1)], dim=1)
177 |
178 |
179 | if self.co_mode:
180 | if self.pred_next_flow:
181 | assert 'next_flow' in locals()
182 | self.flow_init = next_flow
183 | else:
184 | self.flow_init = self.forward_flow(flow_total)
185 | if self.conr_mode:
186 | self.last_net_co = net
187 | return {'final_prediction': flow_total,
188 | 'next_flow': next_flow,
189 | 'delta_flow': delta_flow_history,
190 | 'deblurred_event_volume_new': x_deblur_history,
191 | 'flow_history': flow_history,
192 | 'net': net}
193 |
194 | def forward_flowmap(self, event_bins, flow_init=None, deblur_iters=None):
195 | deblur_iters = self.deblur_iters if deblur_iters is None else deblur_iters
196 |
197 | x = event_bins["event_volume_new"]
198 |
199 | B, V, H, W = x.shape
200 | flow_total = torch.zeros(B, 2, H, W).to(
201 | x.device) if flow_init is None else flow_init.clone()
202 |
203 | delta_flow = flow_total
204 | flow_history = torch.zeros(B, 0, 2, H, W).to(x.device)
205 |
206 | for _ in range(deblur_iters):
207 | x_deblur = self.deblur_tensor(x, delta_flow)
208 | x = torch.stack([x_deblur, x_deblur], dim=1)
209 |
210 | if flow_init is not None and self.cnet is not None:
211 | net = self.cnet(flow_total)
212 | else:
213 | net = torch.zeros(
214 | (B, self.hidden_dim, H//8, W//8)).to(x.device)
215 |             for bin_slice in x.permute(2, 0, 1, 3, 4):
216 |                 f = self.fnet(bin_slice)
217 | net = self.update_net(net, f)
218 |
219 | dflow = self.update_net.compute_deltaflow(net)
220 | up_mask = self.update_net.compute_up_mask(net)
221 | delta_flow = self.upsample_flow(dflow, up_mask)
222 | flow_total = flow_total + delta_flow
223 |
224 | flow_history = torch.cat(
225 | [flow_history, flow_total.unsqueeze(1)], dim=1)
226 |
227 |
228 | return {'final_prediction': flow_total,
229 | 'flow_history': flow_history}
230 |
231 |
232 | class RecIDE(IDEDEQIDO):
233 | def __init__(self, *args, **kwargs):
234 | super().__init__(*args, **kwargs)
235 |
236 | def forward(self, batch, flow_init=None, deblur_iters=None):
237 | deblur_iters = self.deblur_iters if deblur_iters is None else deblur_iters
238 |
239 | flow_trajectory = []
240 | flow_next_trajectory = []
241 |
242 | for t, x in enumerate(batch):
243 | out = super().forward(x, flow_init=flow_init)
244 | flow_pred = out['final_prediction']
245 | if 'next_flow' in out:
246 | flow_next = out['next_flow']
247 | flow_init = flow_next
248 | flow_next_trajectory.append(flow_next)
249 | else:
250 | flow_init = self.forward_flow(flow_pred)
251 |
252 | flow_trajectory.append(flow_pred)
253 |
254 | if (t+1) % 4 == 0:
255 | flow_init = flow_init.detach()
256 | yield {'final_prediction': flow_pred,
257 | 'flow_trajectory': flow_trajectory,
258 | 'flow_next_trajectory': flow_next_trajectory, }
259 | flow_trajectory = []
260 | flow_next_trajectory = []
261 |
262 | def forward_inference(self, batch, flow_init=None, deblur_iters=None):
263 | deblur_iters = self.deblur_iters if deblur_iters is None else deblur_iters
264 |
265 |
266 | flow_trajectory = []
267 |
268 | for t, x in enumerate(batch):
269 | out = super().forward(x, flow_init=flow_init)
270 | flow_pred = out['final_prediction']
271 | flow_init = self.forward_flow(flow_pred)
272 | flow_trajectory.append(flow_pred)
273 |
274 | return {'final_prediction': flow_pred,
275 | 'flow_trajectory': flow_trajectory}
276 |
277 | def backward_neg_flow(self, x):
278 | x["event_volume_new"] = -torch.flip(x["event_volume_new"], [1])
279 | back_flow = -super().forward(x)['final_prediction']
280 | return back_flow
281 |
--------------------------------------------------------------------------------
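
A minimal forward-pass sketch for `IDEDEQIDO` with an ad-hoc config (the real configs live under `idn/config/model/`; only `update_iters` and `zero_init` are read unconditionally, the rest fall back to `getattr` defaults). Note that `forward_flow`, called in co-mode and by `RecIDE`, is defined elsewhere in the repo and is not needed for a plain forward:

```python
import torch
from types import SimpleNamespace
from idn.model.idedeq import IDEDEQIDO

cfg = SimpleNamespace(update_iters=4, zero_init=True)       # assumed minimal config
model = IDEDEQIDO(cfg).eval()
batch = {"event_volume_new": torch.randn(1, 15, 256, 256)}  # H, W divisible by 8
with torch.no_grad():
    out = model(batch)
print(out["final_prediction"].shape)   # torch.Size([1, 2, 256, 256])
```
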
/idn/model/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def sparse_l1_seq(estimated, ground_truth, valid_mask=None):
5 | assert isinstance(estimated, list)
6 | assert isinstance(ground_truth, list)
7 | if valid_mask is not None:
8 | assert isinstance(valid_mask, list)
9 | assert len(estimated) == len(ground_truth) == len(valid_mask)
10 | loss = 0.
11 | for i in range(len(estimated)):
12 | loss += sparse_l1(estimated[i], ground_truth[i], valid_mask[i])
13 | return loss/len(estimated)
14 |
15 |
16 | def sparse_l1(estimated, ground_truth, valid_mask=None):
17 | """Return L1 loss.
18 | This loss ignores difference in pixels where mask is False.
19 | If all pixels are marked as False, the loss is equal to zero.
20 | Args:
21 | estimated: is tensor with predicted values of size
22 | batch_size x height x width.
23 | ground_truth: is tensor with ground truth values of
24 | size batch_size x height x width.
25 |         valid_mask: mask of size batch_size x height x width. Only
26 |             pixels with True values are used. If "valid_mask"
27 |             is None, then all pixels are used.
28 | """
29 | if valid_mask is not None:
30 | valid_mask = valid_mask.bool()
31 | pixelwise_diff = (estimated - ground_truth).abs()
32 | if valid_mask is not None:
33 | if valid_mask.size() == pixelwise_diff.size():
34 | pixelwise_diff = pixelwise_diff[valid_mask]
35 | else:
36 | try:
37 | pixelwise_diff = pixelwise_diff[valid_mask.expand(
38 | pixelwise_diff.size())]
39 |             except Exception:
40 | raise Exception("Mask auto expand failed.")
41 | if pixelwise_diff.numel() == 0:
42 | return torch.Tensor([0]).type(estimated.type())
43 | return pixelwise_diff.mean()
44 |
45 |
46 | def sparse_lnorm(order, estimated, ground_truth, valid_mask=None, per_frame=False):
47 |     """Return the mean Lp-norm error, with p given by "order".
48 | This loss ignores difference in pixels where mask is False.
49 | If all pixels are marked as False, the loss is equal to zero.
50 | Args:
51 | estimated: is tensor with predicted values of size
52 | batch_size x height x width.
53 | ground_truth: is tensor with ground truth values of
54 | size batch_size x height x width.
55 |         valid_mask: mask of size batch_size x height x width. Only
56 |             pixels with True values are used. If "valid_mask"
57 |             is None, then all pixels are used.
58 | """
59 | if valid_mask is not None:
60 | valid_mask = valid_mask.bool()
61 | pixelwise_diff = torch.norm(
62 | estimated - ground_truth, p=order, keepdim=True, dim=(1))
63 | # make sure valid_mask is the same shape as diff
64 | if valid_mask is not None:
65 | if valid_mask.size() != pixelwise_diff.size():
66 | try:
67 | valid_mask = valid_mask.expand(pixelwise_diff.size())
68 |             except Exception:
69 | raise Exception("Mask auto expand failed.")
70 | if per_frame:
71 | error = []
72 | if valid_mask is not None:
73 | for diff, mask in zip(pixelwise_diff, valid_mask):
74 | error.append(diff[mask].mean().item())
75 | else:
76 | error = [e.mean().item() for e in pixelwise_diff]
77 | emap = estimated - ground_truth
78 | emask = valid_mask.expand(emap.size())
79 | emap[~emask] = 0
80 | return {
81 | "metric": error,
82 | "t_emap": emap
83 | }
84 | else:
85 | if valid_mask is not None:
86 | pixelwise_diff = pixelwise_diff[valid_mask]
87 | if pixelwise_diff.numel() == 0:
88 | return torch.Tensor([0]).type(estimated.type())
89 | return pixelwise_diff.mean()
90 |
91 |
92 | def charbonnier_loss(delta, alpha=0.45, epsilon=1e-3):
93 | """
94 | Robust Charbonnier loss, as defined in equation (4) of the paper.
95 | """
96 | loss = torch.mean(torch.pow((delta ** 2 + epsilon ** 2), alpha))
97 | return loss
98 |
99 |
100 | def compute_smoothness_loss(flow):
101 | """
102 | Local smoothness loss, as defined in equation (5) of the paper.
103 | The neighborhood here is defined as the 8-connected region around each pixel.
104 | """
105 | flow_ucrop = flow[..., 1:]
106 | flow_dcrop = flow[..., :-1]
107 | flow_lcrop = flow[..., 1:, :]
108 | flow_rcrop = flow[..., :-1, :]
109 |
110 | flow_ulcrop = flow[..., 1:, 1:]
111 | flow_drcrop = flow[..., :-1, :-1]
112 | flow_dlcrop = flow[..., :-1, 1:]
113 | flow_urcrop = flow[..., 1:, :-1]
114 |
115 | smoothness_loss = charbonnier_loss(flow_lcrop - flow_rcrop) + \
116 | charbonnier_loss(flow_ucrop - flow_dcrop) + \
117 | charbonnier_loss(flow_ulcrop - flow_drcrop) + \
118 | charbonnier_loss(flow_dlcrop - flow_urcrop)
119 | smoothness_loss /= 4.
120 |
121 | return smoothness_loss
122 |
123 |
124 | def compute_npe(n, estimated, ground_truth, valid_mask=None):
125 | if valid_mask is not None:
126 | valid_mask = valid_mask.bool()
127 | pixelwise_diff = torch.norm(
128 | estimated - ground_truth, p=2, keepdim=True, dim=(1))
129 |
130 | # make sure valid_mask is the same shape as diff
131 | if valid_mask is not None:
132 | if valid_mask.size() != pixelwise_diff.size():
133 | try:
134 | valid_mask = valid_mask.expand(pixelwise_diff.size())
135 |             except Exception:
136 | raise Exception("Mask auto expand failed.")
137 |
138 | if valid_mask is not None:
139 | pixelwise_diff = pixelwise_diff[valid_mask]
140 | if pixelwise_diff.numel() == 0:
141 | return torch.Tensor([0]).type(estimated.type())
142 | return {
143 | "metric": torch.numel(pixelwise_diff[pixelwise_diff >= n]) /
144 | torch.numel(pixelwise_diff)
145 | }
146 |
--------------------------------------------------------------------------------
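
A tiny worked example for `sparse_l1` above: only the two masked-in pixels contribute, so the loss is (|1-0| + |3-0|) / 2 = 2.

```python
import torch
from idn.model.loss import sparse_l1

est = torch.tensor([[[1., 2.], [3., 4.]]])
gt = torch.zeros_like(est)
mask = torch.tensor([[[True, False], [True, False]]])
print(sparse_l1(est, gt, mask))   # tensor(2.)
```
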
/idn/model/update.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class FlowHead2(nn.Module):
7 | def __init__(self, input_dim=128, hidden_dim=256):
8 | super(FlowHead2, self).__init__()
9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
11 | self.relu = nn.ReLU(inplace=True)
12 | self.tanh = nn.Tanh()
13 |
14 | def forward(self, x):
15 | return 10*self.tanh(self.conv2(self.relu(self.conv1(x))))
16 |
17 |
18 | class FlowHead(nn.Module):
19 | def __init__(self, input_dim=128, hidden_dim=256):
20 | super(FlowHead, self).__init__()
21 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
22 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
23 | self.relu = nn.ReLU(inplace=True)
24 |
25 | def forward(self, x):
26 | return self.conv2(self.relu(self.conv1(x)))
27 |
28 |
29 | class ConvGRU(nn.Module):
30 | def __init__(self, hidden_dim=128, input_dim=192+128):
31 | super(ConvGRU, self).__init__()
32 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
33 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
34 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
35 |
36 | def forward(self, h, x):
37 | hx = torch.cat([h, x], dim=1)
38 |
39 | z = torch.sigmoid(self.convz(hx))
40 | r = torch.sigmoid(self.convr(hx))
41 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
42 |
43 | h = (1-z) * h + z * q
44 | return h
45 |
46 |
47 | class LiteUpdateBlock(nn.Module):
48 | def __init__(self, hidden_dim=32, input_dim=16, num_outputs=1, downsample=8):
49 | super(LiteUpdateBlock, self).__init__()
50 | self.upsample_mask_dim = downsample * downsample
51 | self.num_outputs = num_outputs
52 | assert self.num_outputs in [1, 2]
53 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=input_dim)
54 | self.flow_head = FlowHead(hidden_dim, hidden_dim=hidden_dim)
55 | self.mask = nn.Sequential(
56 | nn.Conv2d(hidden_dim, 256, 3, padding=1),
57 | nn.ReLU(inplace=True),
58 | nn.Conv2d(256, self.upsample_mask_dim*9, 1, padding=0))
59 | if self.num_outputs == 2:
60 | self.flow_head2 = FlowHead(hidden_dim, hidden_dim=hidden_dim)
61 | self.mask2 = nn.Sequential(
62 | nn.Conv2d(hidden_dim, 256, 3, padding=1),
63 | nn.ReLU(inplace=True),
64 | nn.Conv2d(256, self.upsample_mask_dim*9, 1, padding=0))
65 |
66 | def forward(self, net, inp):
67 | return self.gru(net, inp)
68 |
69 | def compute_deltaflow(self, net):
70 | return self.flow_head(net)
71 |
72 | def compute_nextflow(self, net):
73 | if self.num_outputs == 2:
74 | return self.flow_head2(net)
75 | else:
76 | raise NotImplementedError
77 |
78 | def compute_up_mask(self, net):
79 | return self.mask(net)
80 |
81 | def compute_up_mask2(self, net):
82 | if self.num_outputs == 2:
83 | return self.mask2(net)
84 | else:
85 | raise NotImplementedError
86 |
--------------------------------------------------------------------------------
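
A sketch of one update step with the block above: the ConvGRU folds encoder features into the hidden state, and the two heads read the delta flow and the 9-way convex-upsampling weights from it:

```python
import torch
from idn.model.update import LiteUpdateBlock

block = LiteUpdateBlock(hidden_dim=96, input_dim=64, downsample=8)
net = torch.zeros(1, 96, 32, 32)       # hidden state at 1/8 resolution
feat = torch.randn(1, 64, 32, 32)      # encoder features
net = block(net, feat)                 # ConvGRU update
dflow = block.compute_deltaflow(net)   # [1, 2, 32, 32]
mask = block.compute_up_mask(net)      # [1, 576, 32, 32] = 8*8*9 weights
```
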
/idn/scripts/format_mvsec/eval_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # import tensorflow as tf
3 | import numpy as np
4 | import math
5 | import cv2
6 |
7 |
8 | """
9 | Propagates x_indices and y_indices by their flow, as defined in x_flow, y_flow.
10 | x_mask and y_mask are zeroed out at each pixel where the indices leave the image.
11 | The optional scale_factor will scale the final displacement.
12 | """
13 | def prop_flow(x_flow, y_flow, x_indices, y_indices, x_mask, y_mask, scale_factor=1.0):
14 | flow_x_interp = cv2.remap(x_flow,
15 | x_indices,
16 | y_indices,
17 | cv2.INTER_NEAREST)
18 |
19 | flow_y_interp = cv2.remap(y_flow,
20 | x_indices,
21 | y_indices,
22 | cv2.INTER_NEAREST)
23 |
24 | x_mask[flow_x_interp == 0] = False
25 | y_mask[flow_y_interp == 0] = False
26 |
27 | x_indices += flow_x_interp * scale_factor
28 | y_indices += flow_y_interp * scale_factor
29 |
30 | return
31 |
32 | """
33 | The ground truth flow maps are not time synchronized with the grayscale images. Therefore, we
34 | need to propagate the ground truth flow over the time between two images.
35 | This function assumes that the ground truth flow is in terms of pixel displacement, not velocity.
36 | Pseudo code for this process is as follows:
37 | x_orig = range(cols)
38 | y_orig = range(rows)
39 | x_prop = x_orig
40 | y_prop = y_orig
41 | Find all GT flows that fit in [image_timestamp, image_timestamp+image_dt].
42 | for all of these flows:
43 | x_prop = x_prop + gt_flow_x(x_prop, y_prop)
44 | y_prop = y_prop + gt_flow_y(x_prop, y_prop)
45 | The final flow, then, is x_prop - x_orig, y_prop - y_orig.
46 | Note that this is flow in terms of pixel displacement, with units of pixels, not pixel velocity.
47 | Inputs:
48 | x_flow_in, y_flow_in - list of numpy arrays, each array corresponds to per pixel flow at
49 | each timestamp.
50 | gt_timestamps - timestamp for each flow array.
51 | start_time, end_time - gt flow will be estimated between start_time and end time.
52 | """
53 | def estimate_corresponding_gt_flow(x_flow_in,
54 | y_flow_in,
55 | gt_timestamps,
56 | start_time,
57 | end_time):
58 | # Each gt flow at timestamp gt_timestamps[gt_iter] represents the displacement between
59 | # gt_iter and gt_iter+1.
60 | gt_iter = np.searchsorted(gt_timestamps, start_time, side='left')
61 | try:
62 | gt_timestamps[gt_iter+1]
63 | except IndexError:
64 | return None, None
65 | gt_dt = gt_timestamps[gt_iter+1] - gt_timestamps[gt_iter]
66 | x_flow = np.squeeze(x_flow_in[gt_iter, ...])
67 | y_flow = np.squeeze(y_flow_in[gt_iter, ...])
68 |
69 | dt = end_time - start_time
70 |
71 | # No need to propagate if the desired dt is shorter than the time between gt timestamps.
72 | if gt_dt > dt:
73 | return x_flow * dt / gt_dt, y_flow * dt / gt_dt
74 |
75 | x_indices, y_indices = np.meshgrid(np.arange(x_flow.shape[1]),
76 | np.arange(x_flow.shape[0]))
77 | x_indices = x_indices.astype(np.float32)
78 | y_indices = y_indices.astype(np.float32)
79 |
80 | orig_x_indices = np.copy(x_indices)
81 | orig_y_indices = np.copy(y_indices)
82 |
83 | # Mask keeps track of the points that leave the image, and zeros out the flow afterwards.
84 | x_mask = np.ones(x_indices.shape, dtype=bool)
85 | y_mask = np.ones(y_indices.shape, dtype=bool)
86 |
87 | scale_factor = (gt_timestamps[gt_iter+1] - start_time) / gt_dt
88 | total_dt = gt_timestamps[gt_iter+1] - start_time
89 |
90 | prop_flow(x_flow, y_flow,
91 | x_indices, y_indices,
92 | x_mask, y_mask,
93 | scale_factor=scale_factor)
94 |
95 | gt_iter += 1
96 |
97 | while gt_timestamps[gt_iter+1] < end_time:
98 | x_flow = np.squeeze(x_flow_in[gt_iter, ...])
99 | y_flow = np.squeeze(y_flow_in[gt_iter, ...])
100 |
101 | prop_flow(x_flow, y_flow,
102 | x_indices, y_indices,
103 | x_mask, y_mask)
104 | total_dt += gt_timestamps[gt_iter+1] - gt_timestamps[gt_iter]
105 |
106 | gt_iter += 1
107 |
108 | final_dt = end_time - gt_timestamps[gt_iter]
109 | total_dt += final_dt
110 |
111 | final_gt_dt = gt_timestamps[gt_iter+1] - gt_timestamps[gt_iter]
112 |
113 | x_flow = np.squeeze(x_flow_in[gt_iter, ...])
114 | y_flow = np.squeeze(y_flow_in[gt_iter, ...])
115 |
116 | scale_factor = final_dt / final_gt_dt
117 |
118 | prop_flow(x_flow, y_flow,
119 | x_indices, y_indices,
120 | x_mask, y_mask,
121 | scale_factor)
122 |
123 | x_shift = x_indices - orig_x_indices
124 | y_shift = y_indices - orig_y_indices
125 | x_shift[~x_mask] = 0
126 | y_shift[~y_mask] = 0
127 |
128 | return x_shift, y_shift
--------------------------------------------------------------------------------
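
A quick synthetic check of `estimate_corresponding_gt_flow` (run from `idn/scripts/format_mvsec/`): with a single constant-flow GT frame spanning [0, 1] s and a requested window of 0.5 s, gt_dt > dt, so the flow is simply rescaled by dt/gt_dt:

```python
import numpy as np
from eval_utils import estimate_corresponding_gt_flow

x_flow = 2.0 * np.ones((1, 4, 4), dtype=np.float32)
y_flow = np.zeros((1, 4, 4), dtype=np.float32)
gt_ts = np.array([0.0, 1.0])
fx, fy = estimate_corresponding_gt_flow(x_flow, y_flow, gt_ts, 0.0, 0.5)
print(fx[0, 0], fy[0, 0])   # 1.0 0.0  (2.0 scaled by 0.5/1.0)
```
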
/idn/scripts/format_mvsec/format_mvsec.py:
--------------------------------------------------------------------------------
1 | import hdf5plugin
2 | import h5py
3 | import numpy as np
4 |
5 | from eval_utils import *
6 | from h5_packager import *
7 |
8 |
9 | def process_events(h5_file, event_file, delta=100000):
10 | print("Processing events...")
11 |
12 | cnt = 0
13 | t0 = None
14 | events = event_file["davis"]["left"]["events"]
15 | while True:
16 | x = events[cnt : cnt + delta, 0].astype(np.int16)
17 | y = events[cnt : cnt + delta, 1].astype(np.int16)
18 | t = events[cnt : cnt + delta, 2].astype(np.float64)
19 | p = events[cnt : cnt + delta, 3]
20 | p[p < 0] = 0
21 | p = p.astype(np.bool_)
22 | if x.shape[0] <= 0:
23 | break
24 | else:
25 | tlast = t[-1]
26 |
27 | if t0 is None:
28 | t0 = t[0]
29 |
30 | h5_file.package_events(x, y, t, p)
31 | cnt += delta
32 |
33 | return t0, tlast
34 |
35 |
36 | def process_flow(h5_file, gt_flow, event_file, t0, dt=1):
37 | print("Processing flow...")
38 |
39 | group = h5_file.file.create_group("flow/dt=" + str(dt))
40 | ts_table = group.create_dataset("timestamps", (0, 2), dtype=np.float64, maxshape=(None, 2), chunks=True)
41 | flow_cnt = 0
42 | cur_cnt, prev_cnt = 0, 0
43 | cur_t, prev_t = None, None
44 | flow_x, flow_y, ts = gt_flow["x_flow_dist"], gt_flow["y_flow_dist"], gt_flow["timestamps"]
45 |
46 | image_raw_ts = event_file["davis"]["left"]["image_raw_ts"]
47 | for t in range(image_raw_ts.shape[0]):
48 | cur_t = image_raw_ts[t]
49 |
50 | # upsample flow only at the frame timestamps in between gt samples
51 | if cur_t < gt_flow["timestamps"].min():
52 | continue
53 | elif cur_t > gt_flow["timestamps"].max():
54 | break
55 |
56 | # skip dt frames between each gt sample
57 | if cur_cnt - prev_cnt >= dt:
58 |
59 | # interpolate flow
60 | disp_x, disp_y = estimate_corresponding_gt_flow(
61 | flow_x,
62 | flow_y,
63 | ts,
64 | prev_t,
65 | cur_t,
66 | )
67 | if disp_x is None:
68 | return
69 | disp = np.stack([disp_x, disp_y], axis=2)
70 | print(cur_t - t0)
71 | h5_file.package_flow(disp, (prev_t, cur_t), flow_cnt, dt=dt)
72 | h5_file.append(ts_table, np.array([[prev_t, cur_t]]))
73 |
74 | # update counters
75 | flow_cnt += 1
76 | prev_t = cur_t
77 | prev_cnt = cur_cnt
78 |
79 | cur_cnt += 1
80 | if prev_t is None:
81 | prev_t = image_raw_ts[t]
82 |
83 |
84 | if __name__ == "__main__":
85 |
86 | # load data
87 | gt = np.load("/scratch/indoor_flying3_gt_flow_dist.npz")
88 | event_data = h5py.File("/scratch/indoor_flying3_data.hdf5", "r")
89 |
90 |
91 | # initialize h5 file
92 | ep = H5Packager("/scratch/indoor_flying3.h5")
93 |
94 | # process events
95 | t0, tlast = process_events(ep, event_data)
96 | # process flow
97 | # t0 = 1506119776.3833518
98 | # tlast = 1506120429.7728229
99 | process_flow(ep, gt, event_data, t0, dt=1)
100 | process_flow(ep, gt, event_data, t0, dt=4)
101 |
102 | ep.add_metadata(t0, tlast)
103 |
--------------------------------------------------------------------------------
/idn/scripts/format_mvsec/h5_packager.py:
--------------------------------------------------------------------------------
1 | import hdf5plugin
2 | import h5py
3 | import numpy as np
4 |
5 |
6 | class H5Packager:
7 | def __init__(self, output_path):
8 | print("Creating file in {}".format(output_path))
9 | self.output_path = output_path
10 |
11 | self.file = h5py.File(output_path, "w")
12 | self.event_xs = self.file.create_dataset(
13 | "events/xs", (0,), dtype=np.dtype(np.int16), maxshape=(None,), chunks=True,
14 | )
15 | self.event_ys = self.file.create_dataset(
16 | "events/ys", (0,), dtype=np.dtype(np.int16), maxshape=(None,), chunks=True,
17 | )
18 | self.event_ts = self.file.create_dataset(
19 | "events/ts", (0,), dtype=np.dtype(np.float64), maxshape=(None,), chunks=True,
20 | )
21 | self.event_ps = self.file.create_dataset(
22 | "events/ps", (0,), dtype=np.dtype(np.bool_), maxshape=(None,), chunks=True,
23 | )
24 |
25 | def append(self, dataset, data):
26 | dataset.resize(dataset.shape[0] + len(data), axis=0)
27 | if len(data) == 0:
28 | return
29 | dataset[-len(data) :] = data[:]
30 |
31 | def package_events(self, xs, ys, ts, ps):
32 | self.append(self.event_xs, xs)
33 | self.append(self.event_ys, ys)
34 | self.append(self.event_ts, ts)
35 | self.append(self.event_ps, ps)
36 |
37 | def package_flow(self, flowmap, timestamp, flow_idx, dt=1):
38 | flowmap_dset = self.file.create_dataset(
39 | "flow/dt=" + str(dt) + "/" + "{:09d}".format(flow_idx), data=flowmap, dtype=np.dtype(np.float64),
40 | )
41 | flowmap_dset.attrs["size"] = flowmap.shape
42 | flowmap_dset.attrs["timestamp_from"] = timestamp[0]
43 | flowmap_dset.attrs["timestamp_to"] = timestamp[1]
44 |
45 | def add_metadata(self, t0, tlast):
46 | self.file.attrs["t0"] = t0
47 | self.file.attrs["tk"] = tlast
48 | self.file.attrs["duration"] = tlast - t0
49 |
--------------------------------------------------------------------------------
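
A minimal sketch of packaging synthetic events and flow into the layout consumed by `loader_mvsec.MVSEC` (`events/{xs,ys,ts,ps}` plus `flow/dt=.../<index>`); the output path is an assumption:

```python
import numpy as np
from h5_packager import H5Packager

ep = H5Packager("/tmp/mini.h5")
ep.package_events(np.array([1, 2], dtype=np.int16),
                  np.array([3, 4], dtype=np.int16),
                  np.array([0.1, 0.2], dtype=np.float64),
                  np.array([True, False]))
ep.package_flow(np.zeros((260, 346, 2)), (0.1, 0.2), flow_idx=0, dt=1)
ep.add_metadata(t0=0.1, tlast=0.2)
ep.file.close()
```
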
/idn/tests/dsec.py:
--------------------------------------------------------------------------------
1 |
2 | from torch.utils.data import DataLoader
3 | from ..loader.loader_dsec import assemble_dsec_sequences, assemble_dsec_test_set, rec_train_collate, \
4 | train_collate
5 | from ..loader.loader_mvsec import MVSEC
6 | from .test import Test
7 | import torch
8 | from tqdm import tqdm
9 | from ..utils.helper_functions import move_batch_to_cuda
10 | import time
11 |
12 |
13 | class TestCO(Test):
14 | def __init__(self, test_spec):
15 | super().__init__(test_spec)
16 |
17 | def configure_dataloader(self):
18 | self.spec.data_loader.args.batch_size = 1
19 | return super().configure_dataloader()
20 |
21 | @staticmethod
22 | def configure_model_forward_fn(model):
23 |         return super(type(model), model).forward  # base-class (single-step) forward, bypassing RecIDE's generator
24 |
25 | @staticmethod
26 | def configure_forward_pass_fn(model_forward_fn):
27 | return lambda list_batch: model_forward_fn(list_batch[0])
28 |
29 | def configure_model(self, model):
30 | self.original_co_mode = model.co_mode
31 | model.co_mode = True
32 | model.reset_continuous_flow()
33 |
34 | def cleanup_model(self, model):
35 | model.co_mode = self.original_co_mode
36 | model.reset_continuous_flow()
37 |
38 |
39 | class TestRESET(Test):
40 | def __init__(self, test_spec):
41 | super().__init__(test_spec)
42 |
43 | @staticmethod
44 | def configure_model_forward_fn(model):
45 | return model.forward_inference
46 |
47 | def configure_dataloader(self):
48 | return super().configure_dataloader()
49 |
50 | def configure_model(self, model):
51 | self.original_co_mode = model.co_mode
52 | model.co_mode = False
53 |
54 | def cleanup_model(self, model):
55 | model.co_mode = self.original_co_mode
56 |
57 |
58 | class TestNONREC(Test):
59 | def __init__(self, test_spec):
60 | super().__init__(test_spec)
61 |
62 | def configure_dataloader(self):
63 | return super().configure_dataloader()
64 |
65 |
66 | def assemble_dsec_test_cls(test_type_name=None):
67 | test_cls = {
68 | "co": TestCO,
69 | "reset": TestRESET,
70 | }
71 | test_type = test_cls.get(test_type_name, Test)
72 |
73 | class TestDSEC(test_type):
74 |
75 | # DSEC Test set
76 | def configure_dataloader(self):
77 | test_set = assemble_dsec_test_set(self.spec.dataset.common.test_root,
78 | seq_len=self.spec.dataset.val.get(
79 | "sequence_length", None),
80 | representation_type=self.spec.dataset.get("representation_type", None))
81 | if isinstance(test_set, list):
82 | val_dataloader = [DataLoader(
83 | seq, batch_size=1, shuffle=False, num_workers=0,
84 | ) for seq in test_set]
85 | else:
86 | val_dataloader = DataLoader(
87 | test_set, batch_size=1, shuffle=False, num_workers=0,
88 | )
89 | return val_dataloader
90 |
91 | @torch.no_grad()
92 | def execute_test(self, model_eval, save_all=False):
93 | try:
94 | with self.evaluate_model(model_eval) as model:
95 | model_forward_fn = self.configure_model_forward_fn(model)
96 | forward_pass_fn = self.configure_forward_pass_fn(
97 | model_forward_fn)
98 | with self.logger.log_test(model) as log_path:
99 | if not isinstance(self.data_loader, list):
100 | self.data_loader = [self.data_loader]
101 | for seq_loader in self.data_loader:
102 | with self.logger.record_sequence(seq_loader.dataset.seq_name,
103 | log_path) as rec:
104 | for idx, batch in enumerate(tqdm(seq_loader, position=1)):
105 | if isinstance(batch, list):
106 | assert 'save_submission' in batch[-1]
107 | else:
108 | assert 'save_submission' in batch
109 | # for non-recurrent loading, we skip samples
110 | # not for eval
111 | if not batch['save_submission'].cpu().item() and not save_all:
112 | continue
113 |
114 | batch = move_batch_to_cuda(
115 | batch, self.device)
116 | start_time = time.perf_counter()
117 |
118 | out = forward_pass_fn(batch)
119 | end_time = time.perf_counter()
120 | #print(f"Forward pass time: {(end_time - start_time):.4f}")
121 | self.evaluator.evaluate(batch, out, idx)
122 | rec.log_tensors(batch, out, idx, save_all)
123 | rec.log_metrics(self.evaluator.results)
124 |
125 | self.pack_submission_to_zip(log_path)
126 |
127 | return 0, self.logger.summary()
128 | except Exception as e:
129 | return e, None
130 |
131 | @staticmethod
132 | def pack_submission_to_zip(log_path):
133 | import zipfile
134 | import os
135 | import tempfile
136 | from pathlib import Path
137 | from ..check_submission import check_submission
138 | with zipfile.ZipFile(os.path.join(log_path, 'submission.zip'), 'w') as zip:
139 | for seq in os.listdir(log_path):
140 | if os.path.isdir(os.path.join(log_path, seq, "submission")):
141 | for file in os.listdir(os.path.join(log_path, seq, "submission")):
142 | zip.write(os.path.join(
143 | log_path, seq, "submission", file), os.path.join(seq, file))
144 | with tempfile.TemporaryDirectory() as tempdir:
145 | zipfile.ZipFile(os.path.join(
146 | log_path, 'submission.zip')).extractall(tempdir)
147 | assert check_submission(
148 | Path(tempdir), Path("data/test_forward_optical_flow_timestamps")), \
149 | "submission did not pass check"
150 | return TestDSEC
151 |
152 | class TestMVSEC(Test):
153 | def __init__(self, test_spec):
154 | super().__init__(test_spec)
155 |
156 | def configure_dataloader(self):
157 | valid_set = MVSEC("outdoor_day1", filter=(4356, 4706), num_bins=15, augment=False)
158 |
159 | collate_fn = rec_train_collate if self.spec.dataset.val.get("recurrent", False) \
160 | else train_collate
161 | assert self.spec.data_loader.args.shuffle is False, \
162 | "shuffle must be false for val run."
163 | if isinstance(valid_set, list):
164 | val_dataloader = [DataLoader(
165 | seq, collate_fn=collate_fn, **self.spec.data_loader.args
166 | ) for seq in valid_set]
167 | else:
168 | val_dataloader = DataLoader(
169 | valid_set, collate_fn=collate_fn, **self.spec.data_loader.args
170 | )
171 | return val_dataloader
172 |
--------------------------------------------------------------------------------
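
A sketch of how the factory above is typically used: pick the test behavior by name, build the composed class, and run it. Here `test_spec` and `model` are assumed to come from the hydra config (e.g. `validation/dsec_co.yaml`) and a loaded checkpoint:

```python
# TestDSEC inherits TestCO's continuous-mode setup plus the DSEC-specific
# dataloader and submission packing; execute_test returns (0, summary) on
# success or (exception, None) on failure.
TestDSEC = assemble_dsec_test_cls("co")
test = TestDSEC(test_spec)
err, summary = test.execute_test(model)
```
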
/idn/tests/eval.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from ..utils.retrieval_fn import get_retreival_fn
3 | from ..model.loss import sparse_lnorm, compute_npe
4 |
5 | fm = namedtuple("frame_metric", ["n_frame", "value"])
6 |
7 |
8 | class Evaluator:
9 | def __init__(self, spec):
10 | spec = dict() if spec is None else spec
11 | self.spec = spec
12 | self.assemble_eval_fn()
13 |
14 | def evaluate(self, batch, out, idx):
15 | for quantity, metrics in self.spec.items():
16 | for metric in metrics:
17 | try:
18 | q = self.quantity_retrieval_fn[quantity](out, batch)
19 | except KeyError:
20 | continue
21 | eval_result = self.eval_fn[metric](*q)
22 | eval_metric = self.extract_metric(eval_result)
23 | self.results[quantity][metric].append(fm(idx, eval_metric))
24 |
25 | def assemble_eval_fn(self):
26 | self.results = self.initialize_metrics_dict(self.spec)
27 | self.eval_fn = dict()
28 | self.quantity_retrieval_fn = dict()
29 | for quantity, metrics in self.spec.items():
30 | if quantity not in self.quantity_retrieval_fn:
31 | self.quantity_retrieval_fn[quantity] = get_retreival_fn(
32 | quantity)
33 | for metric in metrics:
34 | if metric not in self.eval_fn:
35 | self.eval_fn[metric] = self.get_eval_fn(metric)
36 |
37 |
38 | @staticmethod
39 | def initialize_metrics_dict(spec):
40 | results = dict()
41 | for quantity, metrics in spec.items():
42 | results[quantity] = {metric: [] for metric in metrics}
43 | return results
44 |
45 | @staticmethod
46 | def get_eval_fn(metric):
47 | if metric == "L1":
48 | return lambda estimate, ground_truth: \
49 | sparse_lnorm(1, estimate, ground_truth.frame, ground_truth.mask,
50 | per_frame=True)
51 | if metric == "L2":
52 | return lambda estimate, ground_truth: \
53 | sparse_lnorm(2, estimate, ground_truth.frame, ground_truth.mask,
54 | per_frame=True)
55 | if metric == "1PE":
56 | return lambda estimate, ground_truth: \
57 | compute_npe(1, estimate, ground_truth.frame, ground_truth.mask)
58 | if metric == "3PE":
59 | return lambda estimate, ground_truth: \
60 | compute_npe(3, estimate, ground_truth.frame, ground_truth.mask)
61 |
62 | @staticmethod
63 | def extract_metric(eval_result):
64 | assert "metric" in eval_result, "metric not found in eval result"
65 | if isinstance(eval_result["metric"], list):
66 | assert len(eval_result["metric"]) == 1, "multiple metrics found"
67 | return eval_result["metric"][0]
68 | else:
69 | return eval_result["metric"]
70 |
--------------------------------------------------------------------------------
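
A sketch of the metrics spec the `Evaluator` consumes: quantity names are resolved through `get_retreival_fn` and metric names through `get_eval_fn`; the quantity key below is an assumption for illustration:

```python
from idn.tests.eval import Evaluator

spec = {"flow": ["L2", "3PE"]}     # hypothetical quantity name
evaluator = Evaluator(spec)
# evaluator.evaluate(batch, out, idx) then appends fm(idx, value) entries
# into evaluator.results["flow"]["L2"] and evaluator.results["flow"]["3PE"]
```
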
/idn/tests/test.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from ..utils.logger import Logger
3 | from ..loader.loader_dsec import assemble_dsec_sequences, rec_train_collate, \
4 | train_collate
5 | from ..utils.helper_functions import move_batch_to_cuda
6 | from tqdm import tqdm
7 | import torch
8 | from torch.utils.data import DataLoader
9 | from contextlib import contextmanager
10 | from .eval import Evaluator
11 |
12 |
13 | class Test:
14 | fm = namedtuple("frame_metric", ["n_frame", "value"])
15 |
16 | def __init__(self, test_spec):
17 | self.spec = test_spec
18 | self.device = self.spec.data_loader.gpu
19 | self.data_loader = self.configure_dataloader()
20 | self.evaluator = Evaluator(test_spec.metrics)
21 | self.logger = Logger(self.spec.logger, self.spec.name)
22 |
23 |
24 | def configure_dataloader(self):
25 | valid_set = assemble_dsec_sequences(
26 | self.spec.dataset.common.data_root,
27 | include_seq=self.spec.dataset.val.seq,
28 | require_gt=True,
29 | config=self.spec.dataset.val,
30 | representation_type=self.spec.dataset.get(
31 | "representation_type", None),
32 | )
33 | collate_fn = rec_train_collate if self.spec.dataset.val.get("recurrent", False) \
34 | else train_collate
35 | assert self.spec.data_loader.args.shuffle is False, \
36 | "shuffle must be false for val run."
37 | if isinstance(valid_set, list):
38 | val_dataloader = [DataLoader(
39 | seq, collate_fn=collate_fn, **self.spec.data_loader.args
40 | ) for seq in valid_set]
41 | else:
42 | val_dataloader = DataLoader(
43 | valid_set, collate_fn=collate_fn, **self.spec.data_loader.args
44 | )
45 | return val_dataloader
46 |
47 | def assemble_postprocess_fn(self, postprocess):
48 |         # no post-processing configured by default
49 |         return None
50 |
51 | @staticmethod
52 | def configure_model_forward_fn(model):
53 | return model.forward
54 |
55 | @staticmethod
56 | def configure_forward_pass_fn(forward_fn):
57 | return forward_fn
58 |
59 | def configure_model(self, model):
60 | pass
61 |
62 | def cleanup_model(self, model):
63 | pass
64 |
65 | @contextmanager
66 | def evaluate_model(self, model):
67 | istrain = model.training
68 | original_device = next(model.parameters()).device
69 | try:
70 | model.cuda(self.device)
71 | self.configure_model(model)
72 | model.eval()
73 | yield model
74 | finally:
75 | self.cleanup_model(model)
76 | model.to(original_device)
77 | model.train(mode=istrain)
78 |
79 | @torch.no_grad()
80 | def execute_test(self, model_eval):
81 | try:
82 | with self.evaluate_model(model_eval) as model:
83 | model_forward_fn = self.configure_model_forward_fn(model)
84 | forward_pass_fn = self.configure_forward_pass_fn(
85 | model_forward_fn)
86 | with self.logger.log_test(model) as log_path:
87 | if not isinstance(self.data_loader, list):
88 | self.data_loader = [self.data_loader]
89 | for seq_loader in self.data_loader:
90 | with self.logger.record_sequence(seq_loader.dataset.seq_name,
91 | log_path) as rec:
92 | for idx, batch in enumerate(tqdm(seq_loader, position=1)):
93 | batch = move_batch_to_cuda(batch, self.device)
94 | out = forward_pass_fn(batch)
95 | self.evaluator.evaluate(batch, out, idx)
96 | rec.log_tensors(batch, out, idx)
97 | rec.log_metrics(self.evaluator.results)
98 |
99 | return 0, self.logger.summary()
100 | except Exception as e:
101 | return e, None
102 |
--------------------------------------------------------------------------------
/idn/train.py:
--------------------------------------------------------------------------------
1 | from omegaconf import OmegaConf
2 | import hydra
3 | from .utils.trainer import Trainer
4 |
5 | # @hydra.main(config_path="config", config_name="mvsec_train")
6 | # @hydra.main(config_path="config", config_name="tid_train")
7 | @hydra.main(config_path="config", config_name="id_train")
8 |
9 | def main(config):
10 | print(OmegaConf.to_yaml(config))
11 |
12 | trainer = Trainer(config)
13 |
14 | print("Number of parameters: ", sum(p.numel()
15 | for p in trainer.model.parameters() if p.requires_grad))
16 |
17 | trainer.fit()
18 |
19 |
20 | if __name__ == '__main__':
21 | main()
22 |
--------------------------------------------------------------------------------
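
Training is launched through Hydra with `id_train` as the default config; the commented-out decorators above show how to switch to the MVSEC or TID recipes. A typical invocation from the repo root (the `model=id-8x` override is an assumption, valid if the defaults list exposes the `model` group under `idn/config/`):

```
python -m idn.train
python -m idn.train model=id-8x
```
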
/idn/utils/callbacks.py:
--------------------------------------------------------------------------------
1 | class CallbackBridge:
2 | def __init__(self):
3 | self.callbacks = list()
4 |
5 | def configure_callbacks(self, config):
6 | def get_cb_from_name(callback):
7 | if callback == "logger":
8 | from .cb.logger import CBLogger
9 | return CBLogger
10 | elif callback == "validator":
11 | from .cb.validator import CBValidator
12 | return CBValidator
13 | if config is not None:
14 | for callback, callback_config in config.items():
15 | if "enable" in callback_config.keys():
16 | self.callbacks.append(
17 | get_cb_from_name(callback)(callback_config))
18 |
19 | def execute_callbacks(self, callback_type):
20 | for callback in sorted(self.callbacks,
21 | key=lambda x: x.call_order[callback_type]):
22 | getattr(callback, callback_type)(self)
23 |
24 |
25 | class Callback:
26 | callback_types = [
27 | "on_init_end",
28 | "on_train_begin",
29 | "on_train_end",
30 | "on_epoch_begin",
31 | "on_epoch_end",
32 | "on_batch_begin",
33 | "on_batch_end",
34 | "on_step_begin",
35 | "on_step_end",
36 | ]
37 |
38 | def __init__(self):
39 | self.call_order = dict.fromkeys(self.callback_types, 0)
40 |
41 | def on_init_end(self, caller):
42 | pass
43 |
44 | def on_train_begin(self, caller):
45 | pass
46 |
47 | def on_train_end(self, caller):
48 | pass
49 |
50 | def on_epoch_begin(self, caller):
51 | pass
52 |
53 | def on_epoch_end(self, caller):
54 | pass
55 |
56 | def on_batch_begin(self, caller):
57 | pass
58 |
59 | def on_batch_end(self, caller):
60 | pass
61 |
62 | def on_step_begin(self, caller):
63 | pass
64 |
65 | def on_step_end(self, caller):
66 | pass
67 |
--------------------------------------------------------------------------------
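
A minimal custom callback sketch against the base class above. Note that `configure_callbacks` only resolves the built-in names `"logger"` and `"validator"`, so a custom callback would be appended to `self.callbacks` directly; since `call_order` defaults every hook to 0, ties keep list order under `sorted()`:

```python
from idn.utils.callbacks import Callback

class PrintEpoch(Callback):
    # caller is the trainer; it exposes .epoch (see CBLogger.on_epoch_end)
    def on_epoch_end(self, caller):
        print(f"finished epoch {caller.epoch}")
```
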
/idn/utils/cb/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from ..callbacks import Callback
4 |
5 | class CBLogger(Callback):
6 | def __init__(self, logger_config, logger_type="exp_tracker"):
7 | super().__init__()
8 | self.logger_config = logger_config
9 | self.logger = None
10 |
11 | def on_init_end(self, caller):
12 | self.logger = caller.logger
13 | try:
14 | run_id = caller.logged_tracker["run_id"]
15 |         except Exception:
16 | run_id = None
17 | self.logger.on_init_end(caller.config, run_id=run_id)
18 |
19 | def on_train_begin(self, caller):
20 | self.logger.on_exp_begin(caller.model)
21 |
22 | def on_batch_end(self, caller):
23 | log_dict = dict()
24 | for key in self.logger_config.log_keys["batch_end"]:
25 | log_dict[key] = getattr(caller, key, None)
26 | self.logger.log_dict_at_step(log_dict)
27 |
28 | def on_epoch_end(self, caller):
29 | if self.logger.log_dir is not None:
30 | torch.save({
31 | 'epoch': caller.epoch,
32 | 'model_state_dict': caller.model.state_dict(),
33 | 'optimizer_state_dict': caller.optimizer.state_dict(),
34 | 'scheduler_state_dict': caller.scheduler.state_dict() \
35 | if caller.scheduler is not None else None,
36 | 'loss': caller.loss,
37 | 'tracker': self.logger.summary(),
38 | }, os.path.join(self.logger.log_dir, "model.ckpt"))
39 |
--------------------------------------------------------------------------------
/idn/utils/cb/validator.py:
--------------------------------------------------------------------------------
1 |
2 | from ..callbacks import Callback
3 | from ..validation import Validator
4 | from warnings import warn
5 |
6 | class CBValidator(Callback):
7 | def __init__(self, config):
8 | super().__init__()
9 | self.validator = None
10 | self.config = config
11 | self.logger = None
12 |
13 | def run_validation(self, caller, sanity_check_run=False):
14 | validator = Validator(caller.config.validation)
15 | results = validator(caller.model)
16 | if self.logger is not None and not sanity_check_run:
17 | self.logger.log_dict_at_step(results)
18 | else:
19 | print(results)
20 |
21 | def on_train_begin(self, caller):
22 | self.logger = caller.logger
23 |
24 | def on_batch_end(self, caller):
25 | if caller.step == self.config.get("sanity_run_step", None):
26 | self.run_validation(caller, sanity_check_run=True)
27 | if self.time_to_validate(step=caller.step):
28 | self.run_validation(caller)
29 |
30 | def on_epoch_end(self, caller):
31 | if self.time_to_validate(epoch=caller.epoch):
32 | self.run_validation(caller)
33 |
34 | def time_to_validate(self, step=None, epoch=None):
35 | if self.config.frequency_type == "epoch":
36 |             return (epoch != 0 and epoch % self.config.frequency == 0) \
37 |                 if epoch is not None else False
38 |         elif self.config.frequency_type == "step":
39 |             return (step != 0 and step % self.config.frequency == 0) \
40 |                 if step is not None else False
41 | else:
42 | warn(f"Frequency type {self.config.frequency_type} not recognized.")
43 | return False
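
    # Config sketch (hypothetical values): validate every 5000 training steps,
    # with a one-off sanity-check run after step 50:
    #
    #   frequency_type: step
    #   frequency: 5000
    #   sanity_run_step: 50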
44 |
--------------------------------------------------------------------------------
/idn/utils/dsec_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from enum import Enum, auto
5 |
6 |
7 | class RepresentationType(Enum):
8 | VOXEL = auto()
9 | STEPAN = auto()
10 |
11 |
12 | class EventRepresentation:
13 | def __init__(self):
14 | pass
15 |
16 | def convert(self, events):
17 | raise NotImplementedError
18 |
19 |
20 | class VoxelGrid(EventRepresentation):
21 | def __init__(self, input_size: tuple, normalize: bool):
22 | assert len(input_size) == 3
23 | self.voxel_grid = torch.zeros(
24 | (input_size), dtype=torch.float, requires_grad=False)
25 | self.nb_channels = input_size[0]
26 | self.normalize = normalize
27 |
28 | def convert(self, events):
29 | C, H, W = self.voxel_grid.shape
30 | with torch.no_grad():
31 | self.voxel_grid = self.voxel_grid.to(events['p'].device)
32 | voxel_grid = self.voxel_grid.clone()
33 |
34 | t_norm = events['t']
35 | t_norm = (C - 1) * (t_norm-t_norm[0]) / (t_norm[-1]-t_norm[0])
36 |
37 | x0 = events['x'].int()
38 | y0 = events['y'].int()
39 | t0 = t_norm.int()
40 |
41 | value = 2*events['p']-1
43 | for xlim in [x0, x0+1]:
44 | for ylim in [y0, y0+1]:
45 | for tlim in [t0, t0+1]:
46 |
47 | mask = (xlim < W) & (xlim >= 0) & (ylim < H) & (
48 | ylim >= 0) & (tlim >= 0) & (tlim < self.nb_channels)
49 | interp_weights = value * (1 - (xlim-events['x']).abs()) * (
50 | 1 - (ylim-events['y']).abs()) * (1 - (tlim - t_norm).abs())
51 | index = H * W * tlim.long() + \
52 | W * ylim.long() + \
53 | xlim.long()
54 |
55 | voxel_grid.put_(
56 | index[mask], interp_weights[mask], accumulate=True)
57 |
58 | if self.normalize:
59 | mask = torch.nonzero(voxel_grid, as_tuple=True)
60 | if mask[0].size()[0] > 0:
61 | mean = voxel_grid[mask].mean()
62 | std = voxel_grid[mask].std()
63 | if std > 0:
64 | voxel_grid[mask] = (voxel_grid[mask] - mean) / std
65 | else:
66 | voxel_grid[mask] = voxel_grid[mask] - mean
67 |
68 | return voxel_grid
69 |
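# Minimal usage sketch (illustrative, with made-up event values): events are
# 1-D float tensors sharing one index per event; polarity is in {0, 1}.
def _demo_voxel_grid():
    grid = VoxelGrid((5, 480, 640), normalize=True)
    events = {
        'x': torch.tensor([10.5, 20.0]),
        'y': torch.tensor([5.0, 7.25]),
        't': torch.tensor([0.0, 1000.0]),  # timestamps in any monotonic unit
        'p': torch.tensor([1.0, 0.0]),
    }
    return grid.convert(events)  # float tensor of shape (5, 480, 640)
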
70 |
71 | class PolarityCount(EventRepresentation):
72 | def __init__(self, input_size: tuple):
73 | assert len(input_size) == 3
74 | self.voxel_grid = torch.zeros(
75 | (input_size), dtype=torch.float, requires_grad=False)
76 | self.nb_channels = input_size[0]
77 |
78 | def convert(self, events):
79 | C, H, W = self.voxel_grid.shape
80 | with torch.no_grad():
81 | self.voxel_grid = self.voxel_grid.to(events['p'].device)
82 | voxel_grid = self.voxel_grid.clone()
83 |
84 | x0 = events['x'].int()
85 | y0 = events['y'].int()
86 |
88 | for xlim in [x0, x0+1]:
89 | for ylim in [y0, y0+1]:
90 | mask = (xlim < W) & (xlim >= 0) & (ylim < H) & (
91 | ylim >= 0)
92 | interp_weights = (1 - (xlim-events['x']).abs()) * (
93 | 1 - (ylim-events['y']).abs())
94 | index = H * W * events['p'].long() + \
95 | W * ylim.long() + \
96 | xlim.long()
97 |
98 | voxel_grid.put_(
99 | index[mask], interp_weights[mask], accumulate=True)
100 |
101 | return voxel_grid
102 |
103 |
104 | def flow_16bit_to_float(flow_16bit: np.ndarray):
105 | assert flow_16bit.dtype == np.uint16
106 | assert flow_16bit.ndim == 3
107 | h, w, c = flow_16bit.shape
108 | assert c == 3
109 |
110 | valid2D = flow_16bit[..., 2] == 1
111 | assert valid2D.shape == (h, w)
112 | assert np.all(flow_16bit[~valid2D, -1] == 0)
113 | valid_map = np.where(valid2D)
114 |
115 | # to actually compute something useful:
116 | flow_16bit = flow_16bit.astype('float')
117 |
118 | flow_map = np.zeros((h, w, 2))
119 | flow_map[valid_map[0], valid_map[1], 0] = (
120 | flow_16bit[valid_map[0], valid_map[1], 0] - 2 ** 15) / 128
121 | flow_map[valid_map[0], valid_map[1], 1] = (
122 | flow_16bit[valid_map[0], valid_map[1], 1] - 2 ** 15) / 128
123 | return flow_map, valid2D
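

# Round-trip sketch: this decoder inverts the 16-bit DSEC flow encoding
# written by SeqLogger.save_submission in utils/logger.py
# (u16 = rint(flow * 128 + 2**15), third channel = valid flag).
def _demo_flow_roundtrip():
    flow = np.random.uniform(-100, 100, (480, 640, 2))
    u16 = np.rint(flow * 128 + 2 ** 15).astype(np.uint16)
    img = np.concatenate([u16, np.ones((480, 640, 1), np.uint16)], axis=-1)
    decoded, valid = flow_16bit_to_float(img)
    assert valid.all() and np.allclose(decoded, flow, atol=1 / 128)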
124 |
--------------------------------------------------------------------------------
/idn/utils/exp_tracker.py:
--------------------------------------------------------------------------------
1 | class ExpTracker:
2 | def __init__(self) -> None:
3 | self.log_dir = None
4 |
5 | def on_init_end(self, *args, **kwargs):
6 | pass
7 |
8 | def on_exp_begin(self, *args, **kwargs):
9 | pass
10 |
11 |     def log_dict_at_step(self, metrics, step=None):
12 | pass
13 |
14 | def summary(self):
15 | return {
16 | "id": "exp_tracker",
17 | }
18 |
--------------------------------------------------------------------------------
/idn/utils/helper_functions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 |
4 |
5 | def log_metrics(flag, metrics):
6 | log = dict()
7 | for seq in metrics.keys():
8 | log[f"{seq}-l1-avg"] = metrics[seq]["l1_avg"]
9 | log[f"{seq}-l2-avg"] = metrics[seq]["l2_avg"]
10 |
11 |
12 | def move_tensor_to_cuda(dict_tensors, gpu):
13 | assert isinstance(dict_tensors, dict)
14 | for key, value in dict_tensors.items():
15 | if isinstance(value, torch.Tensor):
16 | dict_tensors[key] = value.to(gpu, non_blocking=True)
17 | return dict_tensors
18 |
19 |
20 | def move_dict_to_cuda(dictionary_of_tensors, gpu):
21 |     if isinstance(dictionary_of_tensors, dict):
22 |         return {
23 |             key: move_dict_to_cuda(value, gpu)
24 |             # non-tensor entries (including nested dicts) are dropped here
25 |             for key, value in dictionary_of_tensors.items() if isinstance(value, torch.Tensor)
26 |         }
27 |     return dictionary_of_tensors.to(gpu, dtype=torch.float)  # leaves are cast to float
27 |
28 |
29 | def move_list_to_cuda(list_of_dicts, gpu):
30 | for i in range(len(list_of_dicts)):
31 | list_of_dicts[i] = move_tensor_to_cuda(list_of_dicts[i], gpu)
32 | return list_of_dicts
33 |
34 |
35 | def move_batch_to_cuda(batch, gpu):
36 | if isinstance(batch, dict):
37 | return move_tensor_to_cuda(batch, gpu)
38 | elif isinstance(batch, list):
39 | return move_list_to_cuda(batch, gpu)
40 | else:
41 | raise Exception("Batch is not a list or dict")
42 |
43 |
44 | def get_values_from_key(input_list, key):
45 | # Returns all the values with the same key from
46 | # a list filled with dicts of the same kind
47 | out = []
48 | for i in input_list:
49 | out.append(i[key])
50 | return out
51 |
52 |
53 | def create_save_path(subdir, name):
54 | # Check if sub-folder exists, and create if necessary
55 | if not os.path.exists(subdir):
56 | os.mkdir(subdir)
57 | # Create a new folder (named after the name defined in the config file)
58 | path = os.path.join(subdir, name)
59 | # Check if path already exists. if yes -> append a number
60 | if os.path.exists(path):
61 | i = 1
62 | while os.path.exists(path + "_" + str(i)):
63 | i += 1
64 | path = path + '_' + str(i)
65 | os.mkdir(path)
66 | return path
67 |
68 |
69 | def get_nth_element_of_all_dict_keys(d, idx):
70 |     out_dict = {}
71 |     for k in d.keys():
72 |         elem = d[k][idx]
73 |         if isinstance(elem, torch.Tensor):
74 |             out_dict[k] = elem.detach().cpu().item()
75 |         else:
76 |             out_dict[k] = elem
77 |     return out_dict
78 |
79 |
80 | def get_number_of_saved_elements(path, template, first=1):
81 | i = first
82 | while True:
83 | if os.path.exists(os.path.join(path, template.format(i))):
84 | i += 1
85 | else:
86 | break
87 | return range(first, i)
88 |
89 |
90 | def create_file_path(subdir, name):
91 | # Check if sub-folder exists, else raise exception
92 | if not os.path.exists(subdir):
93 | raise Exception("Path {} does not exist!".format(subdir))
94 | # Check if file already exists, else create path
95 | if not os.path.exists(os.path.join(subdir, name)):
96 | return os.path.join(subdir, name)
97 | else:
98 | path = os.path.join(subdir, name)
99 | prefix, suffix = path.split('.')
100 | i = 1
101 | while os.path.exists("{}_{}.{}".format(prefix, i, suffix)):
102 | i += 1
103 | return "{}_{}.{}".format(prefix, i, suffix)
104 |
105 |
106 | def update_dict(dict_old, dict_new):
107 | # Update all the entries of dict_old with the new values(that have the identical keys) of dict_new
108 | for k in dict_new.keys():
109 | if k in dict_old.keys():
110 | # Replace the entry
111 | if isinstance(dict_new[k], dict):
112 | update_dict(dict_old[k], dict_new[k])
113 | else:
114 | dict_old[k] = dict_new[k]
115 | return dict_old
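

# Behavior sketch: only keys already present in dict_old are updated, and
# nested dicts are merged recursively; keys unique to dict_new are ignored.
def _demo_update_dict():
    old = {"lr": 1e-4, "optim": {"name": "adam", "eps": 1e-8}}
    new = {"lr": 3e-4, "optim": {"name": "adamw"}, "extra": 42}
    return update_dict(old, new)
    # -> {"lr": 3e-4, "optim": {"name": "adamw", "eps": 1e-8}}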
116 |
--------------------------------------------------------------------------------
/idn/utils/logger.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import numpy as np
4 | import itertools
5 | import imageio
6 | import tempfile
7 | import pathlib
8 | import torch
9 | from contextlib import contextmanager, nullcontext
10 | from ..tests.eval import fm
11 |
12 |
13 | class Logger:
14 | def __init__(self, config, test_name):
15 | self.config = config
16 | self.test_name = test_name
17 | self.metrics = dict()
18 | self.parse_log_fields()
19 |
20 | def parse_log_fields(self):
21 | def parse_index(index):
22 | for i, idx in enumerate(index):
23 | if isinstance(idx, str):
24 | assert '-' in idx
25 | a, b = idx.split('-')
26 | index[i] = list(range(int(a), int(b)+1))
27 | parsed_list = []
28 | for i in index:
29 | # this is because i can be type of omegaconf.listconfig
30 | if hasattr(i, '__len__'):
31 | parsed_list.extend(i)
32 | else:
33 | parsed_list.append(i)
34 | return parsed_list
35 |
36 | if not hasattr(self.config, "saved_tensors"):
37 | self.config.saved_tensors = None
38 | if self.config.saved_tensors is None:
39 | return
40 | for field, seqs in self.config.saved_tensors.items():
41 | if seqs is None:
42 | continue
43 | else:
44 | for seq, idx in seqs.items():
45 | if idx is None:
46 | continue
47 | else:
48 | self.config.saved_tensors[field][seq] = parse_index(
49 | idx)
50 |
51 | class SeqLogger:
52 | def __init__(self, config, log_path, seq_name):
53 | self.config = config
54 | self.path = log_path
55 | self.seq_name = seq_name
56 |
57 | def log_tensors(self, batch, out, idx, save_all=False):
58 | if isinstance(batch, list):
59 | batch = batch[-1]
60 | assert isinstance(out, dict)
61 | if 'save_submission' in batch:
62 | if isinstance(batch['save_submission'], torch.Tensor):
63 | batch['save_submission'] = batch['save_submission'].cpu().item()
64 | if batch['save_submission'] or save_all:
65 | self.save_submission(out['final_prediction'],
66 | batch['file_index'].cpu().item())
67 | if getattr(self.config, "saved_tensors", None) is None:
68 | return
69 | for key, value in itertools.chain(batch.items(), out.items()):
70 | if key in self.config.saved_tensors:
71 | if self.tblogged(self.config.saved_tensors[key], idx):
72 | self.save_tensor(key, value, idx)
73 |
74 | def tblogged(self, logging_index, idx):
75 | if logging_index is None:
76 | return True
77 | elif self.seq_name not in logging_index:
78 | return False
79 | elif logging_index[self.seq_name] is None:
80 | return True
81 | else:
82 | return idx in logging_index[self.seq_name]
83 |
84 | def save_submission(self, flow, file_idx):
85 | os.makedirs(os.path.join(self.path, "submission"), exist_ok=True)
86 | if isinstance(flow, torch.Tensor):
87 | flow = flow.cpu().numpy()
88 | assert flow.shape == (1, 2, 480, 640)
89 | flow = flow.squeeze()
90 | _, h, w = flow.shape
91 | scaled_flow = np.rint(
92 | flow*128 + 2**15).astype(np.uint16).transpose(1, 2, 0)
93 | flow_image = np.concatenate((scaled_flow, np.zeros((h, w, 1),
94 | dtype=np.uint16)), axis=-1)
95 | imageio.imwrite(os.path.join(self.path, "submission", f"{file_idx:06d}.png"),
96 | flow_image, format='PNG-FI')
97 |
98 |         def save_tensor(self, key, value, idx):
99 |             if isinstance(value, torch.Tensor):
100 |                 value = value.cpu().numpy()
101 |             if isinstance(value, np.ndarray):
102 |                 np.save(os.path.join(self.path, f"{key}_{idx:05d}.npy"), value)
106 |
107 | def log_metrics(self, results):
108 | self.results = results
109 |
110 | def compute_statistics(self):
111 | self.metric_stats = dict()
112 | for quantity, metrics in self.results.items():
113 | self.metric_stats[quantity] = dict()
114 | for metric, list_fm in metrics.items():
115 | assert isinstance(list_fm[0], fm)
116 | self.metric_stats[quantity][metric] = dict()
117 | metric_list = list(map(lambda x: x.value, list_fm))
118 | self.metric_stats[quantity][metric]['avg'] = sum(
119 | metric_list) / len(metric_list)
120 |
121 | @contextmanager
122 | def record_sequence(self, seq_name, log_path):
123 | try:
124 | os.makedirs(os.path.join(log_path, seq_name), exist_ok=True)
125 | self.seq_logger = self.SeqLogger(
126 | self.config, os.path.join(log_path, seq_name), seq_name)
127 | yield self.seq_logger
128 | finally:
129 | self.seq_logger.compute_statistics()
130 | self.copy_metrics(self.seq_logger.results, self.seq_logger.metric_stats,
131 | seq_name)
132 |
133 | @contextmanager
134 | def log_test(self, model):
135 | log_dir = self.config.get('save_dir', None)
136 | with tempfile.TemporaryDirectory(prefix='id2log', dir='/tmp') \
137 | if log_dir is None else nullcontext(pathlib.Path(log_dir)) as wd:
138 | try:
139 | yield wd
140 | finally:
141 | self.dump(model, wd)
142 |
143 | def copy_metrics(self, results, stats, seq_name):
144 | self.metrics[seq_name] = dict()
145 | self.metrics[seq_name]['results'] = results
146 | self.metrics[seq_name]['stats'] = stats
147 |
148 | def dump(self, model, wd):
149 |
150 | dict_to_dump = {"meta": {
151 | "test_name": self.test_name,
152 | }, **self.metrics}
153 |         with open(os.path.join(wd, 'metrics.json'), 'w') as f:
154 |             json.dump(dict_to_dump, f)
155 | torch.save(model.state_dict(), os.path.join(wd, 'model.pt'))
156 |
157 | def summary(self):
158 | summary_flat = dict()
159 | summary = dict()
160 | for seq_name, results in self.metrics.items():
161 | summary[seq_name] = dict()
162 | for quantity, metrics in results["stats"].items():
163 | summary[seq_name][quantity] = metrics
164 | for metric, stats in metrics.items():
165 | for stat, value in stats.items():
166 | assert isinstance(value, (int, float))
167 | summary_flat['-'.join([seq_name,
168 | quantity, metric, stat])] = value
169 | return summary
170 |
171 |
--------------------------------------------------------------------------------
/idn/utils/loss_utils.py:
--------------------------------------------------------------------------------
1 | from ..model.loss import sparse_l1
2 |
3 | def get_loss_fn_by_name(loss_name):
4 | if loss_name == 'sparse_l1':
5 | return lambda estimate, ground_truth: sparse_l1(estimate, \
6 | ground_truth.frame, valid_mask=ground_truth.mask)
7 | else:
8 | assert False, f"loss {loss_name} not implemented"
9 |
10 |
11 | def get_valid_loss_fn_by_name(config):
12 |     if config.name == 'sparse_l1':
13 |         return sparse_l1
14 |     assert False, f"validation loss {config.name} not implemented"
16 |
17 | def compute_seq_loss(weight, loss_fn, estimate, ground_truth):
18 | assert isinstance(estimate, list)
19 | if weight == "last":
20 | return loss_fn(estimate[-1], ground_truth[-1])
21 | else:
22 | seq_loss = list(map(loss_fn, estimate, ground_truth))
23 | if weight == "sum":
24 | return sum(seq_loss)
25 |
26 | elif weight == "avg":
27 | return sum(seq_loss)/len(seq_loss)
28 |
29 | elif hasattr(weight, '__getitem__') and isinstance(weight[0], float):
30 | assert len(weight) == len(estimate)
31 | return sum(map(lambda x, y: x*y, seq_loss, weight))
32 | else:
33 | assert False, f"weight {weight} for seq loss not supported"
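

# Weighting sketch: `weight` may be "last", "sum", "avg", or a list of
# per-step floats matching the sequence length (toy scalar loss below).
def _demo_compute_seq_loss():
    def loss_fn(est, gt):
        return abs(est - gt)
    est, gt = [1.0, 2.0, 4.0], [1.0, 1.0, 1.0]
    assert compute_seq_loss("last", loss_fn, est, gt) == 3.0
    assert compute_seq_loss("sum", loss_fn, est, gt) == 4.0
    assert compute_seq_loss([0.5, 0.5, 1.0], loss_fn, est, gt) == 3.5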
--------------------------------------------------------------------------------
/idn/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | def get_model_by_name(name, model_config):
2 | if name == "RecIDE":
3 | from ..model.idedeq import RecIDE
4 | return RecIDE(model_config)
5 | elif name == "IDEDEQIDO":
6 | from ..model.idedeq import IDEDEQIDO
7 | return IDEDEQIDO(model_config)
8 | else:
9 | raise ValueError("Unknown model name: {}".format(name))
10 |
--------------------------------------------------------------------------------
/idn/utils/mvsec_utils.py:
--------------------------------------------------------------------------------
1 | import numpy
2 | import os
3 | import pandas
4 | from PIL import Image
5 | import random
6 | import h5py
7 | import json
8 |
9 |
10 | class EventSequence(object):
11 | def __init__(self, dataframe, params, features=None, timestamp_multiplier=None, convert_to_relative=False):
12 | if isinstance(dataframe, pandas.DataFrame):
13 | self.feature_names = dataframe.columns.values
14 | self.features = dataframe.to_numpy()
15 | else:
16 | self.feature_names = numpy.array(['ts', 'x', 'y', 'p'], dtype=object)
17 | if features is None:
18 | self.features = numpy.zeros([1, 4])
19 | else:
20 | self.features = features
21 | self.image_height = params['height']
22 | self.image_width = params['width']
23 | if not self.is_sorted():
24 | self.sort_by_timestamp()
25 | if timestamp_multiplier is not None:
26 | self.features[:,0] *= timestamp_multiplier
27 | if convert_to_relative:
28 | self.absolute_time_to_relative()
29 |
30 | def get_sequence_only(self):
31 | return self.features
32 |
33 | def __len__(self):
34 | return len(self.features)
35 |
36 | def __add__(self, sequence):
37 | event_sequence = EventSequence(dataframe=None,
38 | features=numpy.concatenate([self.features, sequence.features]),
39 | params={'height': self.image_height,
40 | 'width': self.image_width})
41 | return event_sequence
42 |
43 | def is_sorted(self):
44 | return numpy.all(self.features[:-1, 0] <= self.features[1:, 0])
45 |
46 | def sort_by_timestamp(self):
47 | if len(self.features[:, 0]) > 0:
48 | sort_indices = numpy.argsort(self.features[:, 0])
49 | self.features = self.features[sort_indices]
50 |
51 | def absolute_time_to_relative(self):
52 | """Transforms absolute time to time relative to the first event."""
53 | start_ts = self.features[:,0].min()
54 | assert(start_ts == self.features[0,0])
55 | self.features[:,0] -= start_ts
56 |
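# Usage sketch (toy values): build a sequence from a DataFrame with
# microsecond timestamps, sorted and converted to relative seconds.
def _demo_event_sequence():
    df = pandas.DataFrame({'ts': [2e6, 1e6], 'x': [3, 1], 'y': [4, 2], 'p': [1, 0]})
    seq = EventSequence(df, {'height': 260, 'width': 346},
                        timestamp_multiplier=1e-6, convert_to_relative=True)
    return seq.get_sequence_only()  # rows sorted by ts, ts starting at 0.0
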
57 |
58 | def get_image(image_path):
59 | try:
60 | im = Image.open(image_path)
61 | # print(image_path)
62 | return numpy.array(im)
63 | except OSError:
64 | raise
65 |
66 |
67 | def get_events(event_path):
68 | # It's possible that there is no event file! (camera standing still)
69 | try:
70 | f = pandas.read_hdf(event_path, "myDataset")
71 | return f[['ts', 'x', 'y', 'p']]
72 | except OSError:
73 | print("No file " + event_path)
74 |         print("Returning 0 (no events) instead!")
75 | return 0
76 |
77 |
78 | def get_ts(path, i, type='int'):
79 | try:
80 |         with open(path, "r") as f:
81 |             if type == 'int':
82 |                 return int(f.readlines()[i])
83 |             elif type == 'double' or type == 'float':
84 |                 return float(f.readlines()[i])
85 | except OSError:
86 | raise
87 |
88 |
89 | def get_batchsize(path_dataset):
90 | filepath = os.path.join(path_dataset, "cam0", "timestamps.txt")
91 | try:
92 |         with open(filepath, "r") as f:
93 |             return len(f.readlines())
94 | except OSError:
95 | raise
96 |
97 |
98 | def get_batch(path_dataset, i):
99 |     return 0  # placeholder; batch loading is not implemented
100 |
101 |
102 | def dataset_paths(dataset_name, path_dataset, subset_number=None):
103 | cameras = {'cam0': {}, 'cam1': {}, 'cam2': {}, 'cam3': {}}
104 | if subset_number is not None:
105 | dataset_name = dataset_name + "_" + str(subset_number)
106 | paths = {'dataset_folder': os.path.join(path_dataset, dataset_name)}
107 |
108 | # For every camera, define its path
109 | for camera in cameras:
110 | cameras[camera]['image_folder'] = os.path.join(paths['dataset_folder'], camera, 'image_raw')
111 | cameras[camera]['event_folder'] = os.path.join(paths['dataset_folder'], camera, 'events')
112 | cameras[camera]['disparity_folder'] = os.path.join(paths['dataset_folder'], camera, 'disparity_image')
113 | cameras[camera]['depth_folder'] = os.path.join(paths['dataset_folder'], camera, 'depthmap')
114 | cameras["timestamp_file"] = os.path.join(paths['dataset_folder'], 'cam0', 'timestamps.txt')
115 | cameras["image_type"] = ".png"
116 | cameras["event_type"] = ".h5"
117 | cameras["disparity_type"] = ".png"
118 | cameras["depth_type"] = ".tiff"
119 | cameras["indexing_type"] = "%0.6i"
120 | paths.update(cameras)
121 | return paths
122 |
123 |
124 | def get_indices(path_dataset, dataset, filter, shuffle=False):
125 | samples = []
126 | for dataset_name in dataset:
127 | for subset in dataset[dataset_name]:
128 | # Get all the dataframe paths
129 | paths = dataset_paths(dataset_name, path_dataset, subset)
130 |
131 | # import timestamps
132 | ts = numpy.loadtxt(paths["timestamp_file"])
133 |
134 | # frames = []
135 | # For every timestamp, import according data
136 |             for idx in eval(filter[dataset_name][str(subset)]):  # index expression from the config is eval'd here
137 | frame = {}
138 | frame['dataset_name'] = dataset_name
139 | frame['subset_number'] = subset
140 | frame['index'] = idx
141 | frame['timestamp'] = ts[idx]
142 | samples.append(frame)
143 | # shuffle dataset
144 | if shuffle:
145 | random.shuffle(samples)
146 | return samples
147 |
148 |
149 | def get_flow_h5(flow_path):
150 | scaling_factor = 0.05 # seconds/frame
151 |     with h5py.File(flow_path, 'r') as f:
152 |         height, width = int(f['header']['height']), int(f['header']['width'])
153 |         assert(len(f['x']) == height*width)
154 |         assert(len(f['y']) == height*width)
155 |         x = numpy.array(f['x']).reshape([height, width])*scaling_factor
156 |         y = numpy.array(f['y']).reshape([height, width])*scaling_factor
157 |     return numpy.stack([x, y])
158 |
159 |
160 | def get_flow_npy(flow_path):
161 | # Array 2,height, width
162 | # No scaling needed.
163 | return numpy.load(flow_path, allow_pickle=True)
164 |
165 |
166 | def get_pose(pose_path, index):
167 | pose = pandas.read_csv(pose_path, delimiter=',').loc[index].to_numpy()
168 | # Convert Timestamp to int (as all the other timestamps)
169 | pose[0] = int(pose[0])
170 | return pose
171 |
172 |
173 | def load_config(path, datasets):
174 | config = {}
175 | for dataset_name in datasets:
176 | config[dataset_name] = {}
177 | for subset in datasets[dataset_name]:
178 | name = "{}_{}".format(dataset_name, subset)
179 |         try:
180 |             with open(os.path.join(path, name, "config.json")) as f:
181 |                 config[dataset_name][subset] = json.load(f)
182 |         except OSError:
183 |             print("Could not find config file for dataset " + dataset_name + "_" + str(subset) +
184 |                   ". Please check that 'config.json' exists in the dataset-scene directory")
185 |             raise
185 | return config
186 |
--------------------------------------------------------------------------------
/idn/utils/retrieval_fn.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import torch.nn.functional as F
3 |
4 |
5 | def upflow8(flow, mode='bilinear'):
6 | new_size = (8 * flow.shape[2], 8 * flow.shape[3])
7 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
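
# Shape/scale sketch: flow predicted at 1/8 resolution is upsampled 8x and
# its values are multiplied by 8 so displacements stay in full-resolution
# pixels.
def _demo_upflow8():
    import torch
    low = torch.ones(1, 2, 60, 80)  # constant 1 px/step flow at 1/8 scale
    full = upflow8(low)
    assert full.shape == (1, 2, 480, 640)
    assert torch.allclose(full, 8 * torch.ones_like(full))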
8 |
9 | def get_retreival_fn(quantity):
11 | if quantity == "final_prediction":
12 | return retreival_pred_seq_1
13 | elif quantity == "pred_flow_seq":
14 | return retreival_pred_seq
15 | elif quantity == "final_prediction_nonseq":
16 | return retreival_pred_nonseq
17 | elif quantity == "pred_flow_next_seq":
18 | return retreival_pred_nextflow_seq
19 | elif quantity == "next_flow":
20 | return retreival_next_flow
21 | elif quantity == "pred_lowres_flow_next_seq":
22 | return retreival_pred_lowres_nextflow_seq
23 |
24 | assert False, f"quantity {quantity} not implemented"
25 |
26 |
27 | def retreival_pred_nonseq(out, batch):
28 | fmask = namedtuple("masked_frame", ["frame", "mask"])
29 | return (out["final_prediction"], fmask(batch["flow_gt_event_volume_new"],
30 | batch["flow_gt_event_volume_new_valid_mask"]))
31 |
32 |
33 | def retreival_pred_seq_1(out, batch):
34 | fmask = namedtuple("masked_frame", ["frame", "mask"])
35 | return (out["final_prediction"], fmask(batch[-1]["flow_gt_event_volume_new"],
36 | batch[-1]["flow_gt_event_volume_new_valid_mask"]))
37 |
38 |
39 | def retreival_pred_seq(out, batch):
40 | fmask = namedtuple("masked_frame", ["frame", "mask"])
41 | return (out["flow_trajectory"], [fmask(x["flow_gt_event_volume_new"],
42 | x["flow_gt_event_volume_new_valid_mask"]) for x in batch])
43 |
44 |
45 | def retreival_pred_nextflow_seq(out, batch):
46 | fmask = namedtuple("masked_frame", ["frame", "mask"])
47 | return (out["flow_next_trajectory"], [fmask(x["flow_gt_next"],
48 | x["flow_gt_next_valid_mask"]) for x in batch])
49 |
50 |
51 | def retreival_next_flow(out, batch):
52 | fmask = namedtuple("masked_frame", ["frame", "mask"])
53 | return (out["next_flow"], fmask(batch[-1]["flow_gt_next"],
54 | batch[-1]["flow_gt_next_valid_mask"]))
55 |
56 |
57 | def retreival_pred_lowres_nextflow_seq(out, batch):
58 | fmask = namedtuple("masked_frame", ["frame", "mask"])
59 | return ([upflow8(x) for x in out["flow_next_trajectory"][:-1]], [fmask(x["flow_gt_event_volume_new"],
60 | x["flow_gt_event_volume_new_valid_mask"]) for x in batch[1:]])
61 |
--------------------------------------------------------------------------------
/idn/utils/torch_environ.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import random
4 |
5 |
6 | def config_torch(config):
7 | torch.backends.cudnn.benchmark = True
8 | if config.get("debug_grad", False):
9 | torch_set_gradient_debug()
10 | if config.get("deterministic", True):
11 | torch_set_deterministic()
12 |
13 |
14 | def torch_set_deterministic(seed=141021):
15 | torch.manual_seed(seed)
16 | np.random.seed(seed)
17 | random.seed(seed)
18 | torch.backends.cudnn.deterministic = True
19 | torch.backends.cudnn.benchmark = False
20 | # because of .put_ in forward_interpolate_pytorch
21 | torch.use_deterministic_algorithms(False)
22 |
23 |
24 | def torch_set_gradient_debug():
25 | torch.autograd.set_detect_anomaly(True)
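

# Usage sketch (assumes a mapping with .get, e.g. a plain dict or an
# OmegaConf DictConfig as produced by the Hydra configs in idn/config):
def _demo_config_torch():
    config_torch({"deterministic": True, "debug_grad": False})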
26 |
--------------------------------------------------------------------------------
/idn/utils/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.utils.data import DataLoader, ConcatDataset
4 | from tqdm import tqdm
5 | from types import GeneratorType
6 | from collections import namedtuple
7 | from torchinfo import summary
8 | from torch.optim.lr_scheduler import OneCycleLR
9 |
10 | from .torch_environ import config_torch
11 | from .helper_functions import move_batch_to_cuda
12 | from .model_utils import get_model_by_name
13 | from .loss_utils import compute_seq_loss, get_loss_fn_by_name
14 | from .validation import Validator
15 | from .callbacks import CallbackBridge
16 | from .exp_tracker import ExpTracker
17 | from .retrieval_fn import get_retreival_fn
18 | from ..loader.loader_dsec import (
19 | Sequence,
20 | RepresentationType,
21 | DatasetProvider,
22 | assemble_dsec_sequences,
23 | assemble_dsec_test_set,
24 | train_collate,
25 | rec_train_collate
26 | )
27 | from ..loader.loader_mvsec import (
28 | MVSEC,
29 | MVSECRecurrent
30 | )
31 |
32 |
33 | class Trainer(CallbackBridge):
34 | def __init__(self, config, model=None):
35 | super().__init__()
36 | self.config = config
37 | config_torch(config.torch)
38 | self.model = model if model is not None else \
39 | get_model_by_name(config.model.name, config.model)
40 | if config.model.get("pretrain_ckpt", None):
41 | self.resume_model_from_ckpt(config.model.pretrain_ckpt)
42 |
43 | if not self.config.get("eval_only", False):
44 | self.train_dataloader = self.configure_train_dataloader()
45 | self.configure_loss()
46 | self.optimizer = self.configure_optimizer()
47 | if config.optim.get("scheduler", None):
48 | self.scheduler = OneCycleLR(
49 | self.optimizer, max_lr=self.config.optim.lr,
50 | steps_per_epoch=len(self.train_dataloader),
51 | epochs=self.config.num_epoch,
52 | pct_start=0.05, cycle_momentum=False, anneal_strategy='linear',
53 | )
54 | else:
55 | self.scheduler = None
56 |
57 | self.epoch = 0
58 | self.step = 0
59 | self.batches_seen = 0
60 | self.samples_seen = 0
61 | if config.get("resume_ckpt", None):
62 | resume_only_model = config.get("finetune", False)
63 | self.resume_from_ckpt(config.resume_ckpt, resume_only_model)
64 | self.logger = self.configure_tracker()
65 |
66 | self.configure_callbacks(config.callbacks)
67 |
68 | self.execute_callbacks("on_init_end")
69 |
70 | def configure_train_dataloader(self):
71 | if self.config.dataset.dataset_name == "dsec":
72 | train_set = assemble_dsec_sequences(
73 | self.config.dataset.common.data_root,
74 | exclude_seq=set(
75 | [val_seq for x in self.config.get("validation", dict()).values() for val_seq in x.dataset.val.seq]),
76 | require_gt=True,
77 | config=self.config.dataset.train,
78 | representation_type=self.config.dataset.get("representation_type", None),
79 | num_bins=self.config.dataset.get("num_voxel_bins", None)
80 | )
81 |
82 | elif self.config.dataset.dataset_name == "mvsec":
83 | train_set = MVSEC("outdoor_day2", num_bins=self.config.dataset.get("num_voxel_bins", None), dt=None) #20Hz
84 | elif self.config.dataset.dataset_name == "mvsec_recurrent":
85 | train_set = MVSECRecurrent("outdoor_day2", augment=False,
86 | sequence_length=self.config.dataset.train.sequence_length)
87 | else:
88 | raise NotImplementedError
89 | collate_fn = rec_train_collate \
90 | if hasattr(self.config.dataset.train, "recurrent") \
91 | and self.config.dataset.train.recurrent else train_collate
92 | return DataLoader(
93 | train_set, collate_fn=collate_fn, **self.config.data_loader.train.args)
94 |
95 | def configure_optimizer(self):
96 | if self.config.optim.optimizer == "adam":
97 | return torch.optim.Adam(self.model.parameters(), lr=self.config.optim.lr)
98 | elif self.config.optim.optimizer == "adamw":
99 | return torch.optim.AdamW(self.model.parameters(), lr=self.config.optim.lr)
100 | else:
101 | raise NotImplementedError
102 |
103 | def resume_model_from_ckpt(self, ckpt):
104 | ckpt = torch.load(ckpt)
105 | if "model" in ckpt:
106 | self.model.load_state_dict(ckpt["model"])
107 | elif "model_state_dict" in ckpt:
108 | self.model.load_state_dict(ckpt["model_state_dict"])
109 | else:
110 | try:
111 | self.model.load_state_dict(ckpt)
112 |             except Exception as e:
113 |                 raise ValueError("Invalid checkpoint") from e
114 |
115 | def resume_from_ckpt(self, ckpt, resume_only_model=False):
116 | ckpt = torch.load(ckpt, map_location='cpu')
117 | self.model.load_state_dict(ckpt['model_state_dict'])
118 | if not resume_only_model:
119 | self.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
120 | self.epoch = ckpt['epoch']
121 | if self.scheduler and 'scheduler_state_dict' in ckpt:
122 | self.scheduler.load_state_dict(ckpt['scheduler_state_dict'])
123 | if "tracker" in ckpt:
124 | self.logged_tracker = ckpt['tracker']
125 |
126 | def configure_tracker(self):
127 | return ExpTracker()
128 |
129 | def configure_loss(self):
130 | lc = namedtuple("loss_config",
131 | ["retrieval_fn", "loss_fn", "weight", "seq_weight", "seq_norm"])
132 | self.loss_config = dict()
133 | for quantity, config in self.config.loss.items():
134 | self.loss_config[quantity] = lc(
135 | get_retreival_fn(quantity),
136 | get_loss_fn_by_name(config.loss_type),
137 | config.get("weight", 1.0),
138 | config.get("seq_weight", None),
139 | config.get("seq_norm", False))
140 |
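    # Config sketch: config.loss maps a retrievable quantity (see
    # utils/retrieval_fn.py) to its loss spec; hypothetical YAML-equivalent:
    #
    #   loss:
    #     final_prediction:
    #       loss_type: sparse_l1
    #       weight: 1.0
    #     pred_flow_seq:
    #       loss_type: sparse_l1
    #       weight: 0.5
    #       seq_weight: avg
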
141 | def train_epoch(self):
142 | self.execute_callbacks("on_epoch_begin")
143 | self.model.train()
144 | self.model.cuda(self.config.data_loader.train.gpu)
145 | # This is necessary to sync optimizer parameter device with model device
146 | self.optimizer.load_state_dict(self.optimizer.state_dict())
147 | for batch in tqdm(self.train_dataloader):
148 | self.execute_callbacks("on_batch_begin")
149 | self.optimizer.zero_grad()
150 | batch = move_batch_to_cuda(
151 | batch, self.config.data_loader.train.gpu)
152 | out = self.model(batch)
153 | if isinstance(out, GeneratorType):
154 | loss_item = []
155 | for i, ret in enumerate(out):
156 | seq_len = len(ret["flow_trajectory"])
157 | loss, loss_breakdown = self.compute_loss(
158 | ret, batch[i*seq_len:(i+1)*seq_len])
159 | loss.backward()
160 | loss_item.append(loss.detach().item())
161 | self.loss = sum(loss_item)/len(loss_item)
162 |                 self.loss_1 = loss_item[0]
164 | for loss_type, l in loss_breakdown.items():
165 | setattr(self, 'loss_'+loss_type, l.item())
166 |
167 | else:
168 | self.loss, _ = self.compute_loss(out, batch)
169 | self.loss.backward()
170 | self.execute_callbacks("on_step_begin")
171 | self.optimizer.step()
172 | self.execute_callbacks("on_step_end")
173 | self.execute_callbacks("on_batch_end")
174 | self.step += 1
175 | if self.scheduler:
176 | self.scheduler.step()
177 | self.lr = self.scheduler.get_last_lr()[0]
178 | self.execute_callbacks("on_epoch_end")
179 | self.epoch += 1
180 |
181 | def compute_loss(self, ret, batch):
182 | loss = dict()
183 | total_loss = 0
184 | for quantity, config in self.loss_config.items():
185 | estimate, ground_truth = config.retrieval_fn(ret, batch)
186 | if isinstance(estimate, list):
187 | loss_fn = lambda estimate, ground_truth: \
188 | compute_seq_loss(config.seq_weight, config.loss_fn,
189 | estimate, ground_truth)
190 | else:
191 | loss_fn = config.loss_fn
192 |
193 | loss[quantity] = loss_fn(estimate, ground_truth)
194 | total_loss += config.weight * loss[quantity]
195 | return total_loss, loss
196 |
197 | def fit(self, epochs=None):
198 | num_epochs = epochs if epochs is not None else self.config.num_epoch
199 | self.execute_callbacks("on_train_begin")
200 | try:
201 | while self.epoch < num_epochs:
202 | self.train_epoch()
203 |         except Exception as e:
204 |             raise RuntimeError("Training failed") from e
205 | finally:
206 | self.execute_callbacks("on_train_end")
207 |
208 |
--------------------------------------------------------------------------------
/idn/utils/transformers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import re
5 | from torchvision.transforms import RandomCrop
6 | import torchvision.transforms.functional as TF
7 |
8 |
9 | def dictionary_of_numpy_arrays_to_tensors(sample):
10 | """Transforms dictionary of numpy arrays to dictionary of tensors."""
11 | if isinstance(sample, dict):
12 | return {
13 | key: dictionary_of_numpy_arrays_to_tensors(value)
14 | for key, value in sample.items()
15 | }
16 | if isinstance(sample, np.ndarray):
17 | if len(sample.shape) == 2:
18 | return torch.from_numpy(sample).float().unsqueeze(0)
19 | else:
20 | return torch.from_numpy(sample).float()
21 | return sample
22 |
23 |
24 | class EventSequenceToVoxelGrid_Pytorch(object):
25 | # Source: https://github.com/uzh-rpg/rpg_e2vid/blob/master/utils/inference_utils.py#L480
26 | def __init__(self, num_bins, gpu=False, gpu_nr=1, normalize=True, forkserver=True):
27 | if forkserver:
28 | try:
29 | torch.multiprocessing.set_start_method('forkserver')
30 | except RuntimeError:
31 | pass
32 | self.num_bins = num_bins
33 | self.normalize = normalize
34 | if gpu:
35 | if not torch.cuda.is_available():
36 | print('Warning: There\'s no CUDA support on this machine!')
37 | else:
38 | self.device = torch.device('cuda:' + str(gpu_nr))
39 | else:
40 | self.device = torch.device('cpu')
41 |
42 | def __call__(self, event_sequence):
43 |         """
44 |         Build a voxel grid with bilinear interpolation in the time domain from a set of events.
45 |         :param event_sequence: object with a [N x 4] `features` array holding one event per row
46 |                                in the form [timestamp, x, y, polarity], plus image_width / image_height
47 |         :return voxel_grid: PyTorch event tensor of shape [num_bins, height, width]
48 |                             (on the device specified at construction)
49 |         """
51 |
52 | events = event_sequence.features.astype('float')
53 |
54 | width = event_sequence.image_width
55 | height = event_sequence.image_height
56 |
57 | assert (events.shape[1] == 4)
58 | assert (self.num_bins > 0)
59 | assert (width > 0)
60 | assert (height > 0)
61 |
62 | with torch.no_grad():
63 |
64 | events_torch = torch.from_numpy(events)
65 | # with DeviceTimer('Events -> Device (voxel grid)'):
66 | events_torch = events_torch.to(self.device)
67 |
68 | # with DeviceTimer('Voxel grid voting'):
69 | voxel_grid = torch.zeros(
70 | self.num_bins, height, width, dtype=torch.float32, device=self.device).flatten()
71 |
72 | # normalize the event timestamps so that they lie between 0 and num_bins
73 | last_stamp = events_torch[-1, 0]
74 | first_stamp = events_torch[0, 0]
75 |
76 | assert last_stamp.dtype == torch.float64, 'Timestamps must be float64!'
77 | # assert last_stamp.item()%1 == 0, 'Timestamps should not have decimals'
78 |
79 | deltaT = last_stamp - first_stamp
80 |
81 | if deltaT == 0:
82 | deltaT = 1.0
83 |
84 | events_torch[:, 0] = (self.num_bins - 1) * \
85 | (events_torch[:, 0] - first_stamp) / deltaT
86 | ts = events_torch[:, 0]
87 | xs = events_torch[:, 1].long()
88 | ys = events_torch[:, 2].long()
89 | pols = events_torch[:, 3].float()
90 | pols[pols == 0] = -1 # polarity should be +1 / -1
91 |
92 | tis = torch.floor(ts)
93 | tis_long = tis.long()
94 | dts = ts - tis
95 | vals_left = pols * (1.0 - dts.float())
96 | vals_right = pols * dts.float()
97 |
98 | valid_indices = tis < self.num_bins
99 | valid_indices &= tis >= 0
100 |
101 | if events_torch.is_cuda:
102 | datatype = torch.cuda.LongTensor
103 | else:
104 | datatype = torch.LongTensor
105 |
106 | voxel_grid.index_add_(dim=0,
107 | index=(xs[valid_indices] + ys[valid_indices]
108 | * width + tis_long[valid_indices] * width * height).type(
109 | datatype),
110 | source=vals_left[valid_indices])
111 |
112 | valid_indices = (tis + 1) < self.num_bins
113 | valid_indices &= tis >= 0
114 |
115 | voxel_grid.index_add_(dim=0,
116 | index=(xs[valid_indices] + ys[valid_indices] * width
117 | + (tis_long[valid_indices] + 1) * width * height).type(datatype),
118 | source=vals_right[valid_indices])
119 |
120 | voxel_grid = voxel_grid.view(self.num_bins, height, width)
121 |
122 | if self.normalize:
123 | mask = torch.nonzero(voxel_grid, as_tuple=True)
124 | if mask[0].size()[0] > 0:
125 | mean = voxel_grid[mask].mean()
126 | std = voxel_grid[mask].std()
127 | if std > 0:
128 | voxel_grid[mask] = (voxel_grid[mask] - mean) / std
129 | else:
130 | voxel_grid[mask] = voxel_grid[mask] - mean
131 |
132 | return voxel_grid
133 |
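# Usage sketch: the transform expects an object exposing `features`
# ([N x 4] array of [timestamp, x, y, polarity]), image_width and
# image_height, as EventSequence in utils/mvsec_utils.py does; `_Seq` below
# is a hypothetical stand-in (forkserver disabled for portability).
def _demo_voxelizer():
    class _Seq:
        features = np.array([[0.0, 1.0, 2.0, 1.0],
                             [1.0, 3.0, 4.0, 0.0]])
        image_width = 640
        image_height = 480
    to_voxel = EventSequenceToVoxelGrid_Pytorch(num_bins=5, normalize=True,
                                                forkserver=False)
    return to_voxel(_Seq())  # float tensor of shape (5, 480, 640)
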
134 |
135 | def apply_transform_to_field(sample, func, field_name):
136 | """
137 | Applies a function to a field of a sample.
138 | :param sample: a sample
139 | :param func: a function that takes a numpy array and returns a numpy array
140 | :param field_name: the name of the field to transform
141 | :return: the transformed sample
142 | """
143 | if isinstance(sample, dict):
144 | for key, value in sample.items():
145 | if bool(re.search(field_name, key)):
146 | sample[key] = func(value)
147 |
148 | if isinstance(sample, np.ndarray) or isinstance(sample, torch.Tensor):
149 | return func(sample)
150 | return sample
151 |
152 |
153 | def apply_randomcrop_to_sample(sample, crop_size):
154 | """
155 | Applies a random crop to a sample.
156 | :param sample: a sample
157 | :param crop_size: the size of the crop
158 | :return: the cropped sample
159 | """
160 | i, j, h, w = RandomCrop.get_params(
161 | sample["event_volume_old"], output_size=crop_size)
162 | keys_to_crop = ["event_volume_old", "event_volume_new",
163 | "flow_gt_event_volume_old", "flow_gt_event_volume_new", "reverse_flow_gt_event_volume_old", "reverse_flow_gt_event_volume_new"]
164 |
165 | for key, value in sample.items():
166 | if key in keys_to_crop:
167 | if isinstance(value, torch.Tensor):
168 | sample[key] = TF.crop(value, i, j, h, w)
169 | elif isinstance(value, list) or isinstance(value, tuple):
170 | sample[key] = [TF.crop(v, i, j, h, w) for v in value]
171 | return sample
172 |
173 |
174 | def downsample_spatial(x, factor):
175 | """
176 | Downsample a given tensor spatially by a factor.
177 | :param x: PyTorch tensor of shape [batch, num_bins, height, width]
178 | :param factor: downsampling factor
179 | :return: PyTorch tensor of shape [batch, num_bins, height/factor, width/factor]
180 | """
181 | assert (factor > 0), 'Factor must be positive!'
182 |
183 | assert (x.shape[-1] %
184 | factor == 0), 'Width of x must be divisible by factor!'
185 | assert (x.shape[-2] %
186 | factor == 0), 'Height of x must be divisible by factor!'
187 |
188 | return nn.AvgPool2d(kernel_size=factor, stride=factor)(x)
189 |
190 |
191 | def downsample_spatial_mask(x, factor):
192 | """
193 | Downsample a given mask (boolean) spatially by a factor.
194 | :param x: PyTorch tensor of shape [batch, num_bins, height, width]
195 | :param factor: downsampling factor
196 | :return: PyTorch tensor of shape [batch, num_bins, height/factor, width/factor]
197 | """
198 | assert (factor > 0), 'Factor must be positive!'
199 |
200 | assert (x.shape[-1] %
201 | factor == 0), 'Width of x must be divisible by factor!'
202 | assert (x.shape[-2] %
203 | factor == 0), 'Height of x must be divisible by factor!'
204 |
205 | return nn.AvgPool2d(kernel_size=factor, stride=factor)(x.float()) >= 0.5
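

# Semantics sketch: a downsampled mask cell is True iff at least half of the
# input cells it covers are True.
def _demo_downsample_spatial_mask():
    m = torch.tensor([[[[True, False],
                        [True, True]]]])  # shape [1, 1, 2, 2]
    assert downsample_spatial_mask(m, 2).item() is True  # 3/4 >= 0.5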
206 |
--------------------------------------------------------------------------------
/idn/utils/validation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 | import pickle
5 | from tqdm import tqdm
6 | import importlib
7 |
8 | from .helper_functions import move_batch_to_cuda, move_list_to_cuda, move_tensor_to_cuda
9 | from ..model.loss import sparse_lnorm
10 |
11 |
12 | def log_tensors(idx, dict_of_tensors, tmp_folder):
13 | for key, value in dict_of_tensors.items():
14 | if isinstance(value, torch.Tensor):
15 | tensor = value.cpu().numpy()
16 | np.save(os.path.join(tmp_folder, f"{key}_{idx}.npy"), tensor)
17 | continue
18 | if isinstance(value, np.ndarray):
19 | np.save(os.path.join(tmp_folder, f"{key}_{idx}.npy"), value)
20 | continue
21 |
22 |
23 |
24 | class Validator:
25 | def __init__(self, config):
26 | self.test = dict()
27 | # self.logger = Logger.from_config(config.logger, config.name)
28 | for test_name, test_spec in config.items():
29 | # create a test object
30 | self.test[test_name] = self.get_test_type(test_name)(test_spec)
31 |
32 | @staticmethod
33 | def get_test_type(test_name, test_type=None):
34 | from ..tests import dsec as dsec_tests
35 | if test_name == "dsec":
36 | return dsec_tests.assemble_dsec_test_cls(test_type)
37 | if test_name == "mvsec_day1":
38 | return dsec_tests.TestMVSEC
39 | if test_name == "mvsec_day1_rec":
40 | return dsec_tests.TestMVSECCO
41 | try:
42 | return getattr(dsec_tests, f"Test{test_name.upper()}")
43 |         except AttributeError:
44 | raise ValueError(f"Test {test_name} not found.")
45 |
46 | def __call__(self, model, run_all=True):
47 | # run the corresponding validator
48 | results = dict()
49 | for name, test in self.test.items():
50 | state, results[name] = test.execute_test(model)
51 | return results
52 |
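# Config sketch: Validator expects a mapping from test name to test spec,
# mirroring the files under idn/config/validation (names as resolved by
# get_test_type above, e.g. "dsec" or "mvsec_day1"); usage:
#
#   validator = Validator(config.validation)
#   results = validator(model)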
53 |
54 | def validate_model_warm(data_loader, model, config, return_dict, gpu_id, test_mode=False,
55 | log_field_dict=None, log_dir=None):
56 | model.cuda(gpu_id)
57 | model.eval()
58 |
59 | metrics = dict()
60 |
61 | with torch.no_grad():
62 | if not isinstance(data_loader, list):
63 | data_loader = [data_loader]
64 |         data_loader = data_loader[0:1]  # TODO: remove hard coded first seq
65 | for seq_loader in data_loader:
66 | for idx, batch in enumerate(tqdm(seq_loader, position=1)):
69 | if isinstance(batch, list):
70 | batch = move_list_to_cuda(batch, gpu_id)
71 | loss_item = batch[-1]
72 | else:
73 | move_tensor_to_cuda(batch, gpu_id)
74 | loss_item = batch
75 | if idx == 0:
76 | out = model(batch)
77 | else:
78 | pre_flow = model.forward_flow(pred)
79 | if isinstance(batch, list):
80 | batch[0]["pre_flow"] = pre_flow
81 | else:
82 | batch["pre_flow"] = pre_flow
83 | out = model(batch, pre_flow, 1)
84 | if isinstance(batch, list):
85 | batch[-1] = {**batch[-1], **out}
86 | pred = batch[-1]["final_prediction"]
87 | else:
88 | batch = {**batch, **out}
89 | pred = batch["final_prediction"]
90 |
91 | if not test_mode and idx != 0:
92 | loss_l1, emap = sparse_lnorm(1, pred, loss_item['flow_gt_event_volume_new'],
93 | loss_item['flow_gt_event_volume_new_valid_mask'],
94 | per_frame=True)
95 | loss_l2, _ = sparse_lnorm(2, pred, loss_item['flow_gt_event_volume_new'],
96 | loss_item['flow_gt_event_volume_new_valid_mask'],
97 | per_frame=True)
98 | loss_l1_pre, pre_emap = sparse_lnorm(1, pre_flow, loss_item['flow_gt_event_volume_new'],
99 | loss_item['flow_gt_event_volume_new_valid_mask'],
100 | per_frame=True)
101 | loss_l2_pre, _ = sparse_lnorm(2, pre_flow, loss_item['flow_gt_event_volume_new'],
102 | loss_item['flow_gt_event_volume_new_valid_mask'],
103 | per_frame=True)
104 | if isinstance(batch, list):
105 | batch = batch[-1]
106 | batch['emap'] = emap
107 | batch['pre_emap'] = pre_emap
108 | for i, sample_seq in enumerate(batch['seq_name']):
109 | if sample_seq not in metrics.keys():
110 | metrics[sample_seq] = dict()
111 | metrics[sample_seq]["l2"] = []
112 | metrics[sample_seq]["l1"] = []
113 | metrics[sample_seq]["l2_pre"] = []
114 | metrics[sample_seq]["l1_pre"] = []
115 | metrics[sample_seq]["l2"].append(loss_l2[i])
116 | metrics[sample_seq]["l1"].append(loss_l1[i])
117 | metrics[sample_seq]["l2_pre"].append(loss_l2_pre[i])
118 | metrics[sample_seq]["l1_pre"].append(loss_l1_pre[i])
119 |
120 | if log_dir:
121 | log_tensors_dict = dict()
122 | for key, value in log_field_dict.items():
123 | if isinstance(batch, list):
124 | batch = batch[-1]
125 | if value in batch.keys():
126 | log_tensors_dict[key] = batch[value]
127 | log_tensors(idx, log_tensors_dict, log_dir)
128 |
129 | for seqs in metrics.keys():
130 | metrics[seqs]["l2_avg"] = np.array(metrics[seqs]["l2"]).mean()
131 | metrics[seqs]["l1_avg"] = np.array(metrics[seqs]["l1"]).mean()
132 | return_dict[seqs] = metrics[seqs]
133 |
134 | if log_dir:
135 | with open(os.path.join(log_dir, "metrics.pkl"), "wb") as f:
136 | pickle.dump(metrics, f)
137 |
138 |
139 | def validate_model(data_loader, model, config, return_dict, gpu_id, test_mode=False,
140 | log_field_dict=None, log_dir=None):
141 | model.cuda(gpu_id)
142 | model.eval()
143 |
144 | metrics = dict()
145 |
146 | with torch.no_grad():
147 | if not isinstance(data_loader, list):
148 | data_loader = [data_loader]
149 | data_loader = data_loader[0:1] # TODO: remove hard coded first seq
150 | for seq_loader in data_loader:
151 | for idx, batch in enumerate(tqdm(seq_loader, position=1)):
152 | if isinstance(batch, list):
153 | batch = move_list_to_cuda(batch, gpu_id)
154 | loss_item = batch[-1]
155 | else:
156 | move_tensor_to_cuda(batch, gpu_id)
157 | loss_item = batch
158 | out = model(batch)
159 | if not isinstance(batch, list):
160 | batch = {**batch, **out}
161 | pred = batch["final_prediction"]
162 | else:
163 | pred = out["final_prediction"]
164 | batch = batch[-1]
165 | if not test_mode:
166 | loss_l1, _ = sparse_lnorm(1, pred, loss_item['flow_gt_event_volume_new'],
167 | loss_item['flow_gt_event_volume_new_valid_mask'],
168 | per_frame=True)
169 | loss_l2, _ = sparse_lnorm(2, pred, loss_item['flow_gt_event_volume_new'],
170 | loss_item['flow_gt_event_volume_new_valid_mask'],
171 | per_frame=True)
172 | for i, sample_seq in enumerate(batch['seq_name']):
173 | if sample_seq not in metrics.keys():
174 | metrics[sample_seq] = dict()
175 | metrics[sample_seq]["l2"] = []
176 | metrics[sample_seq]["l1"] = []
177 | metrics[sample_seq]["l2"].append(loss_l2[i])
178 | metrics[sample_seq]["l1"].append(loss_l1[i])
179 |
180 | if log_dir:
181 | log_tensors_dict = dict()
182 | for key, value in log_field_dict.items():
183 | if value in batch.keys():
184 | log_tensors_dict[key] = batch[value]
185 | log_tensors(idx, log_tensors_dict, log_dir)
186 |
187 | for seqs in metrics.keys():
188 | metrics[seqs]["l2_avg"] = np.array(metrics[seqs]["l2"]).mean()
189 | metrics[seqs]["l1_avg"] = np.array(metrics[seqs]["l1"]).mean()
190 | return_dict[seqs] = metrics[seqs]
191 |
192 | if log_dir:
193 | with open(os.path.join(log_dir, "metrics.pkl"), "wb") as f:
194 | pickle.dump(metrics, f)
195 |
--------------------------------------------------------------------------------