├── .gitignore
├── LICENSE
├── README.md
├── checkpoints
│   └── .keep
├── config
│   ├── dsec_standard.json
│   ├── dsec_warm_start.json
│   ├── mvsec_20.json
│   └── mvsec_45.json
├── download_dsec_test.py
├── environment.yml
├── loader
│   ├── loader_dsec.py
│   ├── loader_mvsec_flow.py
│   └── utils.py
├── main.py
├── model
│   ├── corr.py
│   ├── eraft.py
│   ├── extractor.py
│   ├── update.py
│   └── utils.py
├── test.py
└── utils
    ├── dsec_utils.py
    ├── filename_templates.py
    ├── helper_functions.py
    ├── image_utils.py
    ├── logger.py
    ├── mvsec_utils.py
    ├── transformers.py
    └── visualization.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | .vscode/
3 | __pycache__/
4 | venv/
5 | datasets/
6 | saved/*
7 | statistics*
8 | utils/login_credentials.json
9 | build/
10 | checkpoints/*.tar
11 | !config/examples
12 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Robotics and Perception Group
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # E-RAFT: Dense Optical Flow from Event Cameras
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | This is the code for the paper **E-RAFT: Dense Optical Flow from Event Cameras** by [Mathias Gehrig](https://magehrig.github.io/), Mario Millhäusler, [Daniel Gehrig](https://danielgehrig18.github.io/) and [Davide Scaramuzza](http://rpg.ifi.uzh.ch/people_scaramuzza.html).
10 |
11 | We also introduce DSEC-Flow ([download here](https://dsec.ifi.uzh.ch/dsec-datasets/download/)), the optical flow extension of the [DSEC](https://dsec.ifi.uzh.ch/) dataset. In addition, we host an automatic evaluation server and a [public benchmark](https://dsec.ifi.uzh.ch/uzh/dsec-flow-optical-flow-benchmark/)!
12 |
13 | Visit our [project webpage](http://rpg.ifi.uzh.ch/ERAFT.html) or download the paper directly [here](https://dsec.ifi.uzh.ch/wp-content/uploads/2021/10/eraft_3dv.pdf) for more details.
14 | If you use any of this code, please cite the following publication:
15 |
16 | ```bibtex
17 | @InProceedings{Gehrig3dv2021,
18 | author = {Mathias Gehrig and Mario Millh\"ausler and Daniel Gehrig and Davide Scaramuzza},
19 | title = {E-RAFT: Dense Optical Flow from Event Cameras},
20 | booktitle = {International Conference on 3D Vision (3DV)},
21 | year = {2021}
22 | }
23 | ```
24 |
25 | ## Download
26 |
27 | Download the network checkpoints and place them in the folder ```checkpoints/```:
28 |
29 |
30 | [Checkpoint trained on DSEC](https://download.ifi.uzh.ch/rpg/ERAFT/checkpoints/dsec.tar)
31 |
32 | [Checkpoint trained on MVSEC 20 Hz](https://download.ifi.uzh.ch/rpg/ERAFT/checkpoints/mvsec_20.tar)
33 |
34 | [Checkpoint trained on MVSEC 45 Hz](https://download.ifi.uzh.ch/rpg/ERAFT/checkpoints/mvsec_45.tar)
35 |
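36 | For example, you can fetch a checkpoint from the command line (wget here is just one option; any download tool works):
37 | ```
38 | wget -P checkpoints/ https://download.ifi.uzh.ch/rpg/ERAFT/checkpoints/dsec.tar
39 | ```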
36 |
37 | ## Installation
38 | Please install [conda](https://www.anaconda.com/download).
39 | Then, create a new conda environment with Python 3.7 and all dependencies by running
40 | ```
41 | conda env create --file environment.yml
42 | ```
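43 | Afterwards, activate the environment (it is named ```eRaft``` in [environment.yml](environment.yml)):
44 | ```
45 | conda activate eRaft
46 | ```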
43 |
44 | ## Datasets
45 | ### DSEC
46 | The DSEC dataset for optical flow can be downloaded [here](https://dsec.ifi.uzh.ch/dsec-datasets/download/).
47 | We prepared a script [download_dsec_test.py](download_dsec_test.py) for your convenience.
48 | It downloads the dataset directly into the `OUTPUT_DIRECTORY` with the expected directory structure.
49 | ```
50 | python3 download_dsec_test.py OUTPUT_DIRECTORY
51 | ```
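52 | 
53 | The script fetches each test sequence into a layout like the following (matching the structure documented in [loader/loader_dsec.py](loader/loader_dsec.py)):
54 | ```
55 | OUTPUT_DIRECTORY
56 | └── test
57 |     ├── interlaken_00_b
58 |     │   ├── events_left
59 |     │   │   ├── events.h5
60 |     │   │   └── rectify_map.h5
61 |     │   ├── image_timestamps.txt
62 |     │   └── test_forward_flow_timestamps.csv
63 |     └── ...
64 | ```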
52 |
53 | ### MVSEC
54 | To use the MVSEC dataset with our approach, it needs to be pre-processed into the right format. For your convenience, we provide the pre-processed dataset here:
55 |
56 | [MVSEC Outdoor Day 1 for 20 Hz evaluation](https://download.ifi.uzh.ch/rpg/ERAFT/datasets/mvsec_outdoor_day_1_20Hz.tar)
57 |
58 | [MVSEC Outdoor Day 1 for 45 Hz evaluation](https://download.ifi.uzh.ch/rpg/ERAFT/datasets/mvsec_outdoor_day_1_45Hz.tar)
59 |
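60 | Extract the archives into a dataset directory of your choice, e.g. (paths here are illustrative):
61 | ```
62 | tar -xf mvsec_outdoor_day_1_20Hz.tar -C DATASET_DIRECTORY
63 | ```
64 | 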
60 | ## Experiments
61 | ### DSEC Dataset
62 | For the evaluation of our method with warm-starting, execute the following command:
63 | ```
64 | python3 main.py --path <path_to_dataset>
65 | ```
66 | For the evaluation of our method **without** warm-starting, execute the following command:
67 | ```
68 | python3 main.py --path <path_to_dataset> --type standard
69 | ```
70 | ### MVSEC Dataset
71 | For the evaluation of our method with warm-starting, trained on 20Hz MVSEC data, execute the following command:
72 | ```
73 | python3 main.py --path <path_to_dataset> --dataset mvsec --frequency 20
74 | ```
75 | For the evaluation of our method with warm-starting, trained on 45Hz MVSEC data, execute the following command:
76 | ```
77 | python3 main.py --path <path_to_dataset> --dataset mvsec --frequency 45
78 | ```
79 |
80 | ### Arguments
81 | ```--path``` : Path where you stored the dataset
82 |
83 | ```--dataset``` : Which dataset to use ([dsec]/mvsec)
84 |
85 | ```--type``` : Evaluation type ([warm_start]/standard)
86 |
87 | ```--frequency``` : Evaluation frequency of MVSEC dataset ([20]/45) Hz
88 |
89 | ```--visualize``` : Provide this flag to visualize the DSEC results. MVSEC experiments are always visualized.
90 |
91 | ```--num_workers``` : How many sub-processes to use for data loading (default=0)
92 |
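93 | A full invocation might look like this (the dataset path is illustrative):
94 | ```
95 | python3 main.py --path datasets/dsec --dataset dsec --type standard --visualize --num_workers 4
96 | ```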
--------------------------------------------------------------------------------
/checkpoints/.keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uzh-rpg/E-RAFT/c58ce0524ea0ebfa9849991caafb547f44fe9bfd/checkpoints/.keep
--------------------------------------------------------------------------------
/config/dsec_standard.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "dsec_standard",
3 | "cuda": true,
4 | "gpu": 0,
5 | "subtype": "standard",
6 | "save_dir": "saved",
7 | "data_loader": {
8 | "test": {
9 | "args": {
10 | "batch_size": 1,
11 | "shuffle": false,
12 | "num_voxel_bins": 15
13 | }
14 | }
15 | },
16 | "test": {
17 | "checkpoint": "checkpoints/dsec.tar"
18 | }
19 | }
20 |
--------------------------------------------------------------------------------
/config/dsec_warm_start.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "dsec_warm_start",
3 | "cuda": true,
4 | "gpu": 0,
5 | "subtype": "warm_start",
6 | "save_dir": "saved",
7 | "data_loader": {
8 | "test": {
9 | "args": {
10 | "batch_size": 1,
11 | "shuffle": false,
12 | "sequence_length": 1,
13 | "num_voxel_bins": 15
14 | }
15 | }
16 | },
17 | "test": {
18 | "checkpoint": "checkpoints/dsec.tar"
19 | }
20 | }
21 |
--------------------------------------------------------------------------------
/config/mvsec_20.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "mvsec_20Hz",
3 | "cuda": true,
4 | "gpu": 0,
5 | "subtype": "warm_start",
6 | "save_dir": "saved",
7 | "data_loader": {
8 | "test": {
9 | "args": {
10 | "batch_size": 1,
11 | "shuffle": false,
12 | "sequence_length": 1,
13 | "num_voxel_bins": 15,
14 | "align_to": "depth",
15 | "datasets": {
16 | "outdoor_day": [
17 | 1
18 | ]
19 | },
20 | "filter": {
21 | "outdoor_day": {
22 | "1": "range(4356, 4706)"
23 | }
24 | },
25 | "transforms": [
26 | "EventSequenceToVoxelGrid_Pytorch(num_bins=15, normalize=True, gpu=True)",
27 | "RandomCropping(crop_height=256, crop_width=256, fixed=True)"
28 | ]
29 | }
30 | }
31 | },
32 | "test": {
33 | "checkpoint": "checkpoints/mvsec_20.tar"
34 | }
35 | }
36 |
--------------------------------------------------------------------------------
/config/mvsec_45.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "mvsec_45Hz",
3 | "cuda": true,
4 | "gpu": 0,
5 | "subtype": "warm_start",
6 | "save_dir": "saved",
7 | "data_loader": {
8 | "test": {
9 | "args": {
10 | "batch_size": 1,
11 | "shuffle": false,
12 | "sequence_length": 1,
13 | "num_voxel_bins": 5,
14 | "align_to": "images",
15 | "datasets": {
16 | "outdoor_day": [
17 | 1
18 | ]
19 | },
20 | "filter": {
21 | "outdoor_day": {
22 | "1": "range(10167,10954)"
23 | }
24 | },
25 | "transforms": [
26 | "EventSequenceToVoxelGrid_Pytorch(num_bins=5, normalize=True, gpu=True)",
27 | "RandomCropping(crop_height=256, crop_width=256, fixed=True)"
28 | ]
29 | }
30 | }
31 | },
32 | "test": {
33 | "checkpoint": "checkpoints/mvsec_45.tar"
34 | }
35 | }
--------------------------------------------------------------------------------
/download_dsec_test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | import os
4 | import urllib
5 | import shutil
6 | from typing import Union
7 |
8 | from requests import get
9 |
10 | TEST_SEQUENCES = ['interlaken_00_b', 'interlaken_01_a', 'thun_01_a', 'thun_01_b', 'zurich_city_12_a', 'zurich_city_14_c', 'zurich_city_15_a']
11 | BASE_TEST_URL = 'https://download.ifi.uzh.ch/rpg/DSEC/test/'
12 | TEST_FLOW_TIMESTAMPS_URL = 'https://download.ifi.uzh.ch/rpg/DSEC/test_forward_optical_flow_timestamps.zip'
13 |
14 | def download(url: str, filepath: Path, skip: bool=True) -> bool:
15 | if skip and filepath.exists():
16 | print(f'{str(filepath)} already exists. Skipping download.')
17 | return True
18 | with open(str(filepath), 'wb') as fl:
19 | response = get(url)
20 | fl.write(response.content)
21 | return response.ok
22 |
23 | def unzip(file_: Path, delete_zip: bool=True, skip: bool=True) -> Path:
24 | assert file_.exists()
25 | assert file_.suffix == '.zip'
26 | output_dir = file_.parent / file_.stem
27 | if skip and output_dir.exists():
28 | print(f'{str(output_dir)} already exists. Skipping unzipping operation.')
29 | else:
30 | shutil.unpack_archive(file_, output_dir)
31 | if delete_zip:
32 | os.remove(file_)
33 | return output_dir
34 |
35 | if __name__ == '__main__':
36 | parser = argparse.ArgumentParser()
37 | parser.add_argument('output_directory')
38 |
39 | args = parser.parse_args()
40 |
41 | output_dir = Path(args.output_directory)
42 | output_dir = output_dir / 'test'
43 | os.makedirs(output_dir, exist_ok=True)
44 |
45 | test_timestamps_file = output_dir / 'test_forward_flow_timestamps.zip'
46 |
47 | assert download(TEST_FLOW_TIMESTAMPS_URL, test_timestamps_file), TEST_FLOW_TIMESTAMPS_URL
48 | test_timestamps_dir = unzip(test_timestamps_file)
49 |
50 | for seq_name in TEST_SEQUENCES:
51 | seq_path = output_dir / seq_name
52 | os.makedirs(seq_path, exist_ok=True)
53 |
54 | # image timestamps
55 | img_timestamps_url = BASE_TEST_URL + seq_name + '/' + seq_name + '_image_timestamps.txt'
56 | img_timestamps_file = seq_path / 'image_timestamps.txt'
57 | if not img_timestamps_file.exists():
58 | assert download(img_timestamps_url, img_timestamps_file), img_timestamps_url
59 |
60 | # test timestamps
61 | test_timestamps_file_destination = seq_path / 'test_forward_flow_timestamps.csv'
62 | if not test_timestamps_file_destination.exists():
63 | shutil.move(test_timestamps_dir / (seq_name + '.csv'), test_timestamps_file_destination)
64 |
65 | # event data
66 | events_left_url = BASE_TEST_URL + seq_name + '/' + seq_name + '_events_left.zip'
67 | events_left_file = seq_path / 'events_left.zip'
68 | if not (events_left_file.parent / events_left_file.stem).exists():
69 | assert download(events_left_url, events_left_file), events_left_url
70 | unzip(events_left_file)
71 |
72 | shutil.rmtree(test_timestamps_dir)
73 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: eRaft
2 | channels:
3 | - pytorch
4 | - numba
5 | - anaconda
6 | - conda-forge
7 | - defaults
8 | dependencies:
9 | - _libgcc_mutex=0.1=conda_forge
10 | - _openmp_mutex=4.5=1_gnu
11 | - astroid=2.5=py37h06a4308_1
12 | - blas=1.0=mkl
13 | - blosc=1.21.0=h9c3ff4c_0
14 | - blosc-hdf5-plugin=1.0.0=h646ab9b_2
15 | - brotlipy=0.7.0=py37h7b6447c_1000
16 | - bzip2=1.0.8=h7b6447c_0
17 | - c-ares=1.17.1=h7f98852_1
18 | - ca-certificates=2020.10.14=0
19 | - cached-property=1.5.2=hd8ed1ab_1
20 | - cached_property=1.5.2=pyha770c72_1
21 | - cairo=1.16.0=h6cf1ce9_1008
22 | - certifi=2020.6.20=py37_0
23 | - cffi=1.14.3=py37he30daa8_0
24 | - chardet=3.0.4=py37_1003
25 | - cloudpickle=1.6.0=py_0
26 | - cryptography=3.1.1=py37h1ba5d50_0
27 | - cudatoolkit=10.2.89=hfd86e86_1
28 | - cycler=0.10.0=py37_0
29 | - cytoolz=0.11.0=py37h7b6447c_0
30 | - dask-core=2.30.0=py_0
31 | - dbus=1.13.6=h48d8840_2
32 | - decorator=4.4.2=py_0
33 | - expat=2.3.0=h9c3ff4c_0
34 | - ffmpeg=4.3.1=hca11adc_2
35 | - fontconfig=2.13.1=hba837de_1004
36 | - freetype=2.10.4=h5ab3b9f_0
37 | - gettext=0.19.8.1=h0b5b191_1005
38 | - gitdb=4.0.7=pyhd3eb1b0_0
39 | - gitpython=3.1.14=pyhd3eb1b0_1
40 | - glib=2.68.0=h9c3ff4c_2
41 | - glib-tools=2.68.0=h9c3ff4c_2
42 | - gmp=6.2.1=h58526e2_0
43 | - gnutls=3.6.13=h85f3911_1
44 | - graphite2=1.3.13=h58526e2_1001
45 | - gst-plugins-base=1.18.4=h29181c9_0
46 | - gstreamer=1.18.4=h76c114f_0
47 | - h5py=3.1.0=nompi_py37h1e651dc_100
48 | - harfbuzz=2.8.0=h83ec7ef_1
49 | - hdf5=1.10.6=nompi_h7c3c948_1111
50 | - icu=68.1=h58526e2_0
51 | - idna=2.10=py_0
52 | - imageio=2.9.0=py_0
53 | - intel-openmp=2020.2=254
54 | - isort=5.8.0=pyhd3eb1b0_0
55 | - jasper=1.900.1=h07fcdf6_1006
56 | - jpeg=9d=h36c2ea0_0
57 | - kiwisolver=1.2.0=py37hfd86e86_0
58 | - krb5=1.17.2=h926e7f8_0
59 | - lame=3.100=h7f98852_1001
60 | - lazy-object-proxy=1.6.0=py37h27cfd23_0
61 | - lcms2=2.11=h396b838_0
62 | - ld_impl_linux-64=2.33.1=h53a641e_7
63 | - libblas=3.9.0=1_h6e990d7_netlib
64 | - libcblas=3.9.0=3_h893e4fe_netlib
65 | - libclang=11.1.0=default_ha53f305_0
66 | - libcurl=7.76.0=hc4aaa36_0
67 | - libedit=3.1.20191231=he28a2e2_2
68 | - libev=4.33=h516909a_1
69 | - libevent=2.1.10=hcdb4288_3
70 | - libffi=3.3=he6710b0_2
71 | - libgcc-ng=9.3.0=h2828fa1_18
72 | - libgfortran-ng=7.5.0=h14aa051_18
73 | - libgfortran4=7.5.0=h14aa051_18
74 | - libglib=2.68.0=h3e27bee_2
75 | - libgomp=9.3.0=h2828fa1_18
76 | - libiconv=1.16=h516909a_0
77 | - liblapack=3.9.0=3_h893e4fe_netlib
78 | - liblapacke=3.9.0=3_h893e4fe_netlib
79 | - libllvm11=11.1.0=hf817b99_1
80 | - libnghttp2=1.43.0=h812cca2_0
81 | - libopencv=4.5.1=py37h5fff631_1
82 | - libpng=1.6.37=hbc83047_0
83 | - libpq=13.1=hfd2b0eb_2
84 | - libprotobuf=3.15.6=h780b84a_0
85 | - libssh2=1.9.0=ha56f1ee_6
86 | - libstdcxx-ng=9.3.0=h6de172a_18
87 | - libtiff=4.2.0=hdc55705_0
88 | - libuuid=2.32.1=h7f98852_1000
89 | - libuv=1.40.0=h7b6447c_0
90 | - libwebp-base=1.2.0=h7f98852_2
91 | - libxcb=1.13=h7f98852_1003
92 | - libxkbcommon=1.0.3=he3ba5ed_0
93 | - libxml2=2.9.10=h72842e0_3
94 | - llvmlite=0.36.0=py37hf484d3e_0
95 | - lz4-c=1.9.3=h9c3ff4c_0
96 | - lzo=2.10=h7b6447c_2
97 | - matplotlib=3.4.1=py37h89c1867_0
98 | - matplotlib-base=3.4.1=py37hdd32ed1_0
99 | - mccabe=0.6.1=py37_1
100 | - mkl=2019.4=243
101 | - mkl-service=2.3.0=py37he904b0f_0
102 | - mkl_fft=1.2.0=py37h23d657b_0
103 | - mkl_random=1.0.4=py37hd81dba3_0
104 | - mock=4.0.2=py_0
105 | - mysql-common=8.0.23=ha770c72_1
106 | - mysql-libs=8.0.23=h935591d_1
107 | - ncurses=6.2=he6710b0_1
108 | - nettle=3.6=he412f7d_0
109 | - networkx=2.5=py_0
110 | - ninja=1.10.2=py37hff7bd54_0
111 | - nspr=4.30=h9c3ff4c_0
112 | - nss=3.63=hb5efdd6_0
113 | - numba=0.53.1=np1.11py3.7h04863e7_g97fe221b3_0
114 | - numexpr=2.7.1=py37h423224d_0
115 | - numpy=1.19.1=py37hbc911f0_0
116 | - numpy-base=1.19.1=py37hfa32c7d_0
117 | - olefile=0.46=py37_0
118 | - opencv=4.5.1=py37h89c1867_1
119 | - openh264=2.1.1=h780b84a_0
120 | - openssl=1.1.1l=h7f8727e_0
121 | - pandas=1.2.3=py37ha9443f7_0
122 | - pcre=8.44=he1b5a44_0
123 | - pillow=8.0.0=py37h9a89aac_0
124 | - pip=21.0.1=py37h06a4308_0
125 | - pixman=0.40.0=h36c2ea0_0
126 | - pthread-stubs=0.4=h36c2ea0_1001
127 | - py-opencv=4.5.1=py37h085eea5_1
128 | - pycparser=2.20=py_2
129 | - pylint=2.7.4=py37h06a4308_1
130 | - pyopenssl=19.1.0=py_1
131 | - pyparsing=2.4.7=py_0
132 | - pyqt=5.12.3=py37h89c1867_7
133 | - pyqt-impl=5.12.3=py37he336c9b_7
134 | - pyqt5-sip=4.19.18=py37hcd2ae1e_7
135 | - pyqtchart=5.12=py37he336c9b_7
136 | - pyqtwebengine=5.12.1=py37he336c9b_7
137 | - pysocks=1.7.1=py37_1
138 | - pytables=3.6.1=py37h0c4f3e0_3
139 | - python=3.7.10=hdb3f193_0
140 | - python-dateutil=2.8.1=py_0
141 | - python_abi=3.7=1_cp37m
142 | - pytorch=1.7.0=py3.7_cuda10.2.89_cudnn7.6.5_0
143 | - pytz=2021.1=pyhd3eb1b0_0
144 | - pywavelets=1.1.1=py37h7b6447c_2
145 | - pyyaml=5.3.1=py37h7b6447c_1
146 | - qt=5.12.9=hda022c4_4
147 | - readline=8.1=h27cfd23_0
148 | - requests=2.24.0=py_0
149 | - ruamel.yaml=0.16.12=py37h5e8e339_2
150 | - ruamel.yaml.clib=0.2.2=py37h5e8e339_2
151 | - scikit-image=0.17.2=py37hdf5156a_0
152 | - scipy=1.5.2=py37h0b6359f_0
153 | - setuptools=52.0.0=py37h06a4308_0
154 | - six=1.15.0=py_0
155 | - smmap=3.0.5=pyhd3eb1b0_0
156 | - snappy=1.1.8=he6710b0_0
157 | - sqlite=3.35.3=hdfb4753_0
158 | - tifffile=2020.10.1=py37hdd07704_2
159 | - tk=8.6.10=hbc83047_0
160 | - toml=0.10.2=pyhd3eb1b0_0
161 | - toolz=0.11.1=py_0
162 | - torchvision=0.8.1=py37_cu102
163 | - tornado=6.0.4=py37h7b6447c_1
164 | - tqdm=4.59.0=pyhd8ed1ab_0
165 | - typed-ast=1.4.2=py37h27cfd23_1
166 | - typing_extensions=3.7.4.3=pyha847dfd_0
167 | - urllib3=1.25.11=py_0
168 | - wheel=0.36.2=pyhd3eb1b0_0
169 | - wrapt=1.12.1=py37h7b6447c_1
170 | - x264=1!161.3030=h7f98852_0
171 | - xorg-kbproto=1.0.7=h7f98852_1002
172 | - xorg-libice=1.0.10=h7f98852_0
173 | - xorg-libsm=1.2.3=hd9c2040_1000
174 | - xorg-libx11=1.7.0=h7f98852_0
175 | - xorg-libxau=1.0.9=h7f98852_0
176 | - xorg-libxdmcp=1.1.3=h7f98852_0
177 | - xorg-libxext=1.3.4=h7f98852_1
178 | - xorg-libxrender=0.9.10=h7f98852_1003
179 | - xorg-renderproto=0.11.1=h7f98852_1002
180 | - xorg-xextproto=7.3.0=h7f98852_1002
181 | - xorg-xproto=7.0.31=h7f98852_1007
182 | - xz=5.2.5=h7b6447c_0
183 | - yaml=0.2.5=h7b6447c_0
184 | - zlib=1.2.11=h7b6447c_3
185 | - zstd=1.4.9=ha95c52a_0
186 |
--------------------------------------------------------------------------------
/loader/loader_dsec.py:
--------------------------------------------------------------------------------
1 | import math
2 | from pathlib import Path
3 | from typing import Dict, Tuple
4 | import weakref
5 |
6 | import cv2
7 | import h5py
8 | from numba import jit
9 | import numpy as np
10 | import torch
11 | from torch.utils.data import Dataset, DataLoader
12 | from utils import visualization as visu
13 | from matplotlib import pyplot as plt
14 | from utils import transformers
15 | import os
16 | import imageio
17 |
18 | from utils.dsec_utils import RepresentationType, VoxelGrid, flow_16bit_to_float
19 |
20 | VISU_INDEX = 1
21 |
22 | class EventSlicer:
23 | def __init__(self, h5f: h5py.File):
24 | self.h5f = h5f
25 |
26 | self.events = dict()
27 | for dset_str in ['p', 'x', 'y', 't']:
28 | self.events[dset_str] = self.h5f['events/{}'.format(dset_str)]
29 |
30 | # This is the mapping from milliseconds to event index:
31 | # It is defined such that
32 | # (1) t[ms_to_idx[ms]] >= ms*1000
33 | # (2) t[ms_to_idx[ms] - 1] < ms*1000
34 | # ,where 'ms' is the time in milliseconds and 't' the event timestamps in microseconds.
35 | #
36 | # As an example, given 't' and 'ms':
37 | # t: 0 500 2100 5000 5000 7100 7200 7200 8100 9000
38 | # ms: 0 1 2 3 4 5 6 7 8 9
39 | #
40 | # we get
41 | #
42 | # ms_to_idx:
43 | # 0 2 2 3 3 3 5 5 8 9
44 | self.ms_to_idx = np.asarray(self.h5f['ms_to_idx'], dtype='int64')
45 |
46 | self.t_offset = int(h5f['t_offset'][()])
47 | self.t_final = int(self.events['t'][-1]) + self.t_offset
48 |
49 | def get_final_time_us(self):
50 | return self.t_final
51 |
52 | def get_events(self, t_start_us: int, t_end_us: int) -> Dict[str, np.ndarray]:
53 | """Get events (p, x, y, t) within the specified time window
54 | Parameters
55 | ----------
56 | t_start_us: start time in microseconds
57 | t_end_us: end time in microseconds
58 | Returns
59 | -------
60 | events: dictionary of (p, x, y, t) or None if the time window cannot be retrieved
61 | """
62 | assert t_start_us < t_end_us
63 |
 64 |         # We assume that the times are time-of-day, hence subtract the offset:
65 | t_start_us -= self.t_offset
66 | t_end_us -= self.t_offset
67 |
68 | t_start_ms, t_end_ms = self.get_conservative_window_ms(t_start_us, t_end_us)
69 | t_start_ms_idx = self.ms2idx(t_start_ms)
70 | t_end_ms_idx = self.ms2idx(t_end_ms)
71 |
72 | if t_start_ms_idx is None or t_end_ms_idx is None:
73 | # Cannot guarantee window size anymore
74 | return None
75 |
76 | events = dict()
77 | time_array_conservative = np.asarray(self.events['t'][t_start_ms_idx:t_end_ms_idx])
78 | idx_start_offset, idx_end_offset = self.get_time_indices_offsets(time_array_conservative, t_start_us, t_end_us)
79 | t_start_us_idx = t_start_ms_idx + idx_start_offset
80 | t_end_us_idx = t_start_ms_idx + idx_end_offset
81 | # Again add t_offset to get gps time
82 | events['t'] = time_array_conservative[idx_start_offset:idx_end_offset] + self.t_offset
83 | for dset_str in ['p', 'x', 'y']:
84 | events[dset_str] = np.asarray(self.events[dset_str][t_start_us_idx:t_end_us_idx])
85 | assert events[dset_str].size == events['t'].size
86 | return events
87 |
88 |
89 | @staticmethod
90 | def get_conservative_window_ms(ts_start_us: int, ts_end_us) -> Tuple[int, int]:
91 | """Compute a conservative time window of time with millisecond resolution.
92 | We have a time to index mapping for each millisecond. Hence, we need
93 | to compute the lower and upper millisecond to retrieve events.
94 | Parameters
95 | ----------
96 | ts_start_us: start time in microseconds
97 | ts_end_us: end time in microseconds
98 | Returns
99 | -------
100 | window_start_ms: conservative start time in milliseconds
101 | window_end_ms: conservative end time in milliseconds
102 | """
103 | assert ts_end_us > ts_start_us
104 | window_start_ms = math.floor(ts_start_us/1000)
105 | window_end_ms = math.ceil(ts_end_us/1000)
106 | return window_start_ms, window_end_ms
107 |
108 | @staticmethod
109 | @jit(nopython=True)
110 | def get_time_indices_offsets(
111 | time_array: np.ndarray,
112 | time_start_us: int,
113 | time_end_us: int) -> Tuple[int, int]:
114 | """Compute index offset of start and end timestamps in microseconds
115 | Parameters
116 | ----------
117 | time_array: timestamps (in us) of the events
118 | time_start_us: start timestamp (in us)
119 | time_end_us: end timestamp (in us)
120 | Returns
121 | -------
122 | idx_start: Index within this array corresponding to time_start_us
123 | idx_end: Index within this array corresponding to time_end_us
124 | such that (in non-edge cases)
125 | time_array[idx_start] >= time_start_us
126 | time_array[idx_end] >= time_end_us
127 | time_array[idx_start - 1] < time_start_us
128 | time_array[idx_end - 1] < time_end_us
129 | this means that
130 | time_start_us <= time_array[idx_start:idx_end] < time_end_us
131 | """
132 |
133 | assert time_array.ndim == 1
134 |
135 | idx_start = -1
136 | if time_array[-1] < time_start_us:
137 | # This can happen in extreme corner cases. E.g.
138 | # time_array[0] = 1016
139 | # time_array[-1] = 1984
140 | # time_start_us = 1990
141 | # time_end_us = 2000
142 |
143 | # Return same index twice: array[x:x] is empty.
144 | return time_array.size, time_array.size
145 | else:
146 | for idx_from_start in range(0, time_array.size, 1):
147 | if time_array[idx_from_start] >= time_start_us:
148 | idx_start = idx_from_start
149 | break
150 | assert idx_start >= 0
151 |
152 | idx_end = time_array.size
153 | for idx_from_end in range(time_array.size - 1, -1, -1):
154 | if time_array[idx_from_end] >= time_end_us:
155 | idx_end = idx_from_end
156 | else:
157 | break
158 |
159 | assert time_array[idx_start] >= time_start_us
160 | if idx_end < time_array.size:
161 | assert time_array[idx_end] >= time_end_us
162 | if idx_start > 0:
163 | assert time_array[idx_start - 1] < time_start_us
164 | if idx_end > 0:
165 | assert time_array[idx_end - 1] < time_end_us
166 | return idx_start, idx_end
167 |
168 | def ms2idx(self, time_ms: int) -> int:
169 | assert time_ms >= 0
170 | if time_ms >= self.ms_to_idx.size:
171 | return None
172 | return self.ms_to_idx[time_ms]
173 |
174 |
175 | class Sequence(Dataset):
176 | def __init__(self, seq_path: Path, representation_type: RepresentationType, mode: str='test', delta_t_ms: int=100,
177 | num_bins: int=15, transforms=None, name_idx=0, visualize=False):
178 | assert num_bins >= 1
179 | assert delta_t_ms == 100
180 | assert seq_path.is_dir()
181 | assert mode in {'train', 'test'}
182 | '''
183 | Directory Structure:
184 |
185 | Dataset
186 | └── test
187 | ├── interlaken_00_b
188 | │ ├── events_left
189 | │ │ ├── events.h5
190 | │ │ └── rectify_map.h5
191 | │ ├── image_timestamps.txt
192 | │ └── test_forward_flow_timestamps.csv
193 |
194 | '''
195 |
196 | self.mode = mode
197 | self.name_idx = name_idx
198 | self.visualize_samples = visualize
199 | # Get Test Timestamp File
200 | test_timestamp_file = seq_path / 'test_forward_flow_timestamps.csv'
201 | assert test_timestamp_file.is_file()
202 | file = np.genfromtxt(
203 | test_timestamp_file,
204 | delimiter=','
205 | )
206 | self.idx_to_visualize = file[:,2]
207 |
208 | # Save output dimensions
209 | self.height = 480
210 | self.width = 640
211 | self.num_bins = num_bins
212 |
213 | # Just for now, we always train with num_bins=15
214 | assert self.num_bins==15
215 |
216 | # Set event representation
217 | self.voxel_grid = None
218 | if representation_type == RepresentationType.VOXEL:
219 | self.voxel_grid = VoxelGrid((self.num_bins, self.height, self.width), normalize=True)
220 |
221 |
222 |         # Save delta timestamp in microseconds
223 | self.delta_t_us = delta_t_ms * 1000
224 |
225 |         # Load and compute timestamps and indices
226 | timestamps_images = np.loadtxt(seq_path / 'image_timestamps.txt', dtype='int64')
227 | image_indices = np.arange(len(timestamps_images))
228 |         # Only use every second image timestamp (flow GT is at 10 Hz), and leave out the first & last one
229 | self.timestamps_flow = timestamps_images[::2][1:-1]
230 | self.indices = image_indices[::2][1:-1]
231 |
232 | # Left events only
233 | ev_dir_location = seq_path / 'events_left'
234 | ev_data_file = ev_dir_location / 'events.h5'
235 | ev_rect_file = ev_dir_location / 'rectify_map.h5'
236 |
237 | h5f_location = h5py.File(str(ev_data_file), 'r')
238 | self.h5f = h5f_location
239 | self.event_slicer = EventSlicer(h5f_location)
240 | with h5py.File(str(ev_rect_file), 'r') as h5_rect:
241 | self.rectify_ev_map = h5_rect['rectify_map'][()]
242 |
243 | self._finalizer = weakref.finalize(self, self.close_callback, self.h5f)
244 |
245 | def events_to_voxel_grid(self, p, t, x, y, device: str='cpu'):
246 | t = (t - t[0]).astype('float32')
247 | t = (t/t[-1])
248 | x = x.astype('float32')
249 | y = y.astype('float32')
250 | pol = p.astype('float32')
251 | event_data_torch = {
252 | 'p': torch.from_numpy(pol),
253 | 't': torch.from_numpy(t),
254 | 'x': torch.from_numpy(x),
255 | 'y': torch.from_numpy(y),
256 | }
257 | return self.voxel_grid.convert(event_data_torch)
258 |
259 | def getHeightAndWidth(self):
260 | return self.height, self.width
261 |
262 | @staticmethod
263 | def get_disparity_map(filepath: Path):
264 | assert filepath.is_file()
265 | disp_16bit = cv2.imread(str(filepath), cv2.IMREAD_ANYDEPTH)
266 | return disp_16bit.astype('float32')/256
267 |
268 | @staticmethod
269 | def load_flow(flowfile: Path):
270 | assert flowfile.exists()
271 | assert flowfile.suffix == '.png'
272 | flow_16bit = imageio.imread(str(flowfile), format='PNG-FI')
273 | flow, valid2D = flow_16bit_to_float(flow_16bit)
274 | return flow, valid2D
275 |
276 | @staticmethod
277 | def close_callback(h5f):
278 | h5f.close()
279 |
280 | def get_image_width_height(self):
281 | return self.height, self.width
282 |
283 | def __len__(self):
284 | return len(self.timestamps_flow)
285 |
286 | def rectify_events(self, x: np.ndarray, y: np.ndarray):
287 | # assert location in self.locations
288 | # From distorted to undistorted
289 | rectify_map = self.rectify_ev_map
290 | assert rectify_map.shape == (self.height, self.width, 2), rectify_map.shape
291 | assert x.max() < self.width
292 | assert y.max() < self.height
293 | return rectify_map[y, x]
294 |
295 | def get_data_sample(self, index, crop_window=None, flip=None):
296 | # First entry corresponds to all events BEFORE the flow map
297 | # Second entry corresponds to all events AFTER the flow map (corresponding to the actual fwd flow)
298 | names = ['event_volume_old', 'event_volume_new']
299 | ts_start = [self.timestamps_flow[index] - self.delta_t_us, self.timestamps_flow[index]]
300 | ts_end = [self.timestamps_flow[index], self.timestamps_flow[index] + self.delta_t_us]
301 |
302 | file_index = self.indices[index]
303 |
304 | output = {
305 | 'file_index': file_index,
306 | 'timestamp': self.timestamps_flow[index]
307 | }
308 | # Save sample for benchmark submission
309 | output['save_submission'] = file_index in self.idx_to_visualize
310 | output['visualize'] = self.visualize_samples
311 |
312 | for i in range(len(names)):
313 | event_data = self.event_slicer.get_events(ts_start[i], ts_end[i])
314 |
315 | p = event_data['p']
316 | t = event_data['t']
317 | x = event_data['x']
318 | y = event_data['y']
319 |
320 | xy_rect = self.rectify_events(x, y)
321 | x_rect = xy_rect[:, 0]
322 | y_rect = xy_rect[:, 1]
323 |
324 | if crop_window is not None:
325 | # Cropping (+- 2 for safety reasons)
326 | x_mask = (x_rect >= crop_window['start_x']-2) & (x_rect < crop_window['start_x']+crop_window['crop_width']+2)
327 | y_mask = (y_rect >= crop_window['start_y']-2) & (y_rect < crop_window['start_y']+crop_window['crop_height']+2)
328 | mask_combined = x_mask & y_mask
329 | p = p[mask_combined]
330 | t = t[mask_combined]
331 | x_rect = x_rect[mask_combined]
332 | y_rect = y_rect[mask_combined]
333 |
334 | if self.voxel_grid is None:
335 | raise NotImplementedError
336 | else:
337 | event_representation = self.events_to_voxel_grid(p, t, x_rect, y_rect)
338 | output[names[i]] = event_representation
339 | output['name_map']=self.name_idx
340 | return output
341 |
342 | def __getitem__(self, idx):
343 | sample = self.get_data_sample(idx)
344 | return sample
345 |
346 |
347 | class SequenceRecurrent(Sequence):
348 | def __init__(self, seq_path: Path, representation_type: RepresentationType, mode: str='test', delta_t_ms: int=100,
349 | num_bins: int=15, transforms=None, sequence_length=1, name_idx=0, visualize=False):
350 | super(SequenceRecurrent, self).__init__(seq_path, representation_type, mode, delta_t_ms, transforms=transforms,
351 | name_idx=name_idx, visualize=visualize)
352 | self.sequence_length = sequence_length
353 | self.valid_indices = self.get_continuous_sequences()
354 |
355 | def get_continuous_sequences(self):
356 | continuous_seq_idcs = []
357 | if self.sequence_length > 1:
358 | for i in range(len(self.timestamps_flow)-self.sequence_length+1):
359 | diff = self.timestamps_flow[i+self.sequence_length-1] - self.timestamps_flow[i]
360 | if diff < np.max([100000 * (self.sequence_length-1) + 1000, 101000]):
361 | continuous_seq_idcs.append(i)
362 | else:
363 | for i in range(len(self.timestamps_flow)-1):
364 | diff = self.timestamps_flow[i+1] - self.timestamps_flow[i]
365 | if diff < np.max([100000 * (self.sequence_length-1) + 1000, 101000]):
366 | continuous_seq_idcs.append(i)
367 | return continuous_seq_idcs
368 |
369 | def __len__(self):
370 | return len(self.valid_indices)
371 |
372 | def __getitem__(self, idx):
373 | assert idx >= 0
374 | assert idx < len(self)
375 |
376 | # Valid index is the actual index we want to load, which guarantees a continuous sequence length
377 | valid_idx = self.valid_indices[idx]
378 |
379 | sequence = []
380 | j = valid_idx
381 |
382 | ts_cur = self.timestamps_flow[j]
383 | # Add first sample
384 | sample = self.get_data_sample(j)
385 | sequence.append(sample)
386 |
387 | # Data augmentation according to first sample
388 | crop_window = None
389 | flip = None
390 | if 'crop_window' in sample.keys():
391 | crop_window = sample['crop_window']
392 | if 'flipped' in sample.keys():
393 | flip = sample['flipped']
394 |
395 | for i in range(self.sequence_length-1):
396 | j += 1
397 | ts_old = ts_cur
398 | ts_cur = self.timestamps_flow[j]
399 | assert(ts_cur-ts_old < 100000 + 1000)
400 | sample = self.get_data_sample(j, crop_window=crop_window, flip=flip)
401 | sequence.append(sample)
402 |
403 | # Check if the current sample is the first sample of a continuous sequence
404 | if idx==0 or self.valid_indices[idx]-self.valid_indices[idx-1] != 1:
405 | sequence[0]['new_sequence'] = 1
406 | print("Timestamp {} is the first one of the next seq!".format(self.timestamps_flow[self.valid_indices[idx]]))
407 | else:
408 | sequence[0]['new_sequence'] = 0
409 | return sequence
410 |
411 | class DatasetProvider:
412 | def __init__(self, dataset_path: Path, representation_type: RepresentationType, delta_t_ms: int=100, num_bins=15,
413 | type='standard', config=None, visualize=False):
414 | test_path = dataset_path / 'test'
415 | assert dataset_path.is_dir(), str(dataset_path)
416 | assert test_path.is_dir(), str(test_path)
417 | assert delta_t_ms == 100
418 | self.config=config
419 | self.name_mapper_test = []
420 |
421 | test_sequences = list()
422 | for child in test_path.iterdir():
423 | self.name_mapper_test.append(str(child).split("/")[-1])
424 | if type == 'standard':
425 | test_sequences.append(Sequence(child, representation_type, 'test', delta_t_ms, num_bins,
426 | transforms=[],
427 | name_idx=len(self.name_mapper_test)-1,
428 | visualize=visualize))
429 | elif type == 'warm_start':
430 | test_sequences.append(SequenceRecurrent(child, representation_type, 'test', delta_t_ms, num_bins,
431 | transforms=[], sequence_length=1,
432 | name_idx=len(self.name_mapper_test)-1,
433 | visualize=visualize))
434 | else:
435 | raise Exception('Please provide a valid subtype [standard/warm_start] in config file!')
436 |
437 | self.test_dataset = torch.utils.data.ConcatDataset(test_sequences)
438 |
439 | def get_test_dataset(self):
440 | return self.test_dataset
441 |
442 |
443 | def get_name_mapping_test(self):
444 | return self.name_mapper_test
445 |
446 | def summary(self, logger):
447 | logger.write_line("================================== Dataloader Summary ====================================", True)
448 | logger.write_line("Loader Type:\t\t" + self.__class__.__name__, True)
449 | logger.write_line("Number of Voxel Bins: {}".format(self.test_dataset.datasets[0].num_bins), True)
450 |
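451 | 
452 | # Minimal usage sketch (illustrative; assumes a DSEC download at datasets/dsec
453 | # with the layout documented in the Sequence docstring above):
454 | #
455 | # provider = DatasetProvider(Path('datasets/dsec'), RepresentationType.VOXEL)
456 | # test_loader = DataLoader(provider.get_test_dataset(), batch_size=1, shuffle=False)
457 | # for sample in test_loader:
458 | #     volumes = sample['event_volume_old'], sample['event_volume_new']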
--------------------------------------------------------------------------------
/loader/loader_mvsec_flow.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import torch
3 | from matplotlib import pyplot as plt
4 | from torch.utils.data import Dataset
5 | import os
6 | import numpy
7 | from utils import filename_templates as TEMPLATES
8 | from loader.utils import *
9 | from utils.transformers import *
10 | from utils import mvsec_utils
11 | from torchvision import transforms
12 |
13 | class MvsecFlow(Dataset):
14 | def __init__(self, args, type, path):
15 | super(MvsecFlow, self).__init__()
16 | #self.data_files = self.get_files(args['path'], args['datasets'])
17 | self.path_dataset = path
18 | self.timestamp_files = {}
19 | self.timestamp_files_flow = {}
20 | # If we load the image timestamps, we consider the framerate to be 45Hz.
21 | # Else if we load the depth/flow timestamps, the framerate is 20Hz.
22 |         # The update rate gets set to 20 or 45 in the get_indices() method
23 | self.update_rate = None
24 | self.dataset = self.get_indices(path, args['datasets'], args['filter'], args['align_to'])
25 | self.input_type = 'events'
26 | self.type = type # Train/Val/Test
27 |
28 | # Evaluation Type. Dense -> Valid where GT exists
29 | # Sparse -> Valid where GT & Events exist
30 | self.evaluation_type = 'dense'
31 |
32 | self.image_width = 346
33 | self.image_height = 260
34 |
35 | self.voxel = EventSequenceToVoxelGrid_Pytorch(
36 | num_bins=args['num_voxel_bins'],
37 | normalize=True,
38 | gpu=True
39 | )
40 | self.cropper = transforms.CenterCrop((256,256))
41 |
42 | def summary(self, logger):
43 | logger.write_line("================================== Dataloader Summary ====================================", True)
44 | logger.write_line("Loader Type:\t\t" + self.__class__.__name__ + " for {}".format(self.type), True)
45 | logger.write_line("Framerate:\t\t{}".format(self.update_rate), True)
46 | logger.write_line("Evaluation Type:\t{}".format(self.evaluation_type), True)
47 |
48 | def get_indices(self, path, dataset, filter, align_to):
49 | # Returns a list of dicts. Each dict contains the following items:
50 | # ['dataset_name'] (e.g. outdoor_day)
51 | # ['subset_number'] (e.g. 1)
52 | # ['index'] (e.g. 1), Frame Index in the dataset
53 | # ['timestamp'] Timestamp of the frame with index i
54 | samples = []
55 | for dataset_name in dataset:
56 | self.timestamp_files[dataset_name] = {}
57 | self.timestamp_files_flow[dataset_name] = {}
58 | for subset in dataset[dataset_name]:
59 | dataset_path = TEMPLATES.MVSEC_DATASET_FOLDER.format(dataset_name, subset)
60 |
61 | # Timestamps of DEPTH image
62 | if align_to.lower() == 'images' or align_to.lower() == 'image':
63 | print("Aligning everything to the image timestamps!")
64 | ts_path = TEMPLATES.MVSEC_TIMESTAMPS_PATH_IMAGES
65 | if self.update_rate is not None and self.update_rate != 45:
66 | raise Exception('Something wrong with the update rate!')
67 | self.update_rate = 45
68 | self.timestamp_files_flow[dataset_name][subset] = numpy.loadtxt(os.path.join(path,
69 | dataset_path,
70 | TEMPLATES.MVSEC_TIMESTAMPS_PATH_FLOW))
71 | elif align_to.lower() == 'depth':
72 | print("Aligning everything to the depth timestamps!")
73 | ts_path = TEMPLATES.MVSEC_TIMESTAMPS_PATH_DEPTH
74 | if self.update_rate is not None and self.update_rate != 20:
75 | raise Exception('Something wrong with the update rate!')
76 | self.update_rate = 20
77 | elif align_to.lower() == 'flow':
78 | print("Aligning everything to the flow timestamps!")
79 | ts_path = TEMPLATES.MVSEC_TIMESTAMPS_PATH_FLOW
80 | if self.update_rate is not None and self.update_rate != 20:
81 | raise Exception('Something wrong with the update rate!')
82 | self.update_rate = 20
83 | else:
84 | raise ValueError("Please define the variable 'align_to' in the dataset [image/depth/flow]")
85 | ts = numpy.loadtxt(os.path.join(path, dataset_path, ts_path))
86 | self.timestamp_files[dataset_name][subset] = ts
87 | for idx in eval(filter[dataset_name][str(subset)]):
88 | sample = {}
89 | sample['dataset_name'] = dataset_name
90 | sample['subset_number'] = subset
91 | sample['index'] = idx
92 | sample['timestamp'] = ts[idx]
93 | samples.append(sample)
94 |
95 | return samples
96 |
97 | def get_data_sample(self, loader_idx):
98 | # ================================= Get Data Sample =============================== #
99 | # Returns dict with the following content: #
100 | # - Event Sequence New (if type == 'events') #
101 | # - Event Sequence Old (if type == 'events') #
102 | # - Optical Flow (forward) between two timesteps as defined below #
103 | # - Timestamp and some other params #
104 | # #
105 | # Nomenclature Definition #
106 | # #
107 | # NOTE THAT THIS IS A DIFFERENT NAMING SCHEME THAN IN THE OTHER DATASETS! #
108 | # #
109 | # Flow[i-1] Flow[i] Flow[i+1] Flow[i+2] #
110 | # Depth[i-1] Depth[i] Depth[i+1] Depth[i+2]
111 | # | . | . | . . | #
112 | # | . . | . | . . | #
113 | # | . | .. . | ... | #
114 | # | . | . | . . | #
115 | # Events[i] Events[i+1] Events[i+2] #
116 | #
117 | # Flow[i] tells us the flow between Depth[i] and Depth[i+1] #
118 | # This can be seen because the pixels of flow[i] are the same as depth[i] #
119 | # We are for now using the events aligned to the depth-timestamps. #
120 | # This means, to get the flow between Depth[i] and Depth[i+1], we need to load #
121 | # - Flow[i]
122 | # - Events[i+1]
123 | # - Events[i] (if using volumetric cost volumes)
124 | # - Timestamps (from depth) [i]
125 | # - Timestamps (from depth) [i+1]
126 | set = self.dataset[loader_idx]['dataset_name']
127 | subset = self.dataset[loader_idx]['subset_number']
128 | path_subset = TEMPLATES.MVSEC_DATASET_FOLDER.format(set, subset)
129 | path_dataset = os.path.join(self.path_dataset, path_subset)
130 | # params = self.config[self.dataset[loader_idx]['dataset_name']][self.dataset[loader_idx]['subset_number']]
131 | idx = self.dataset[loader_idx]['index']
132 | type = self.input_type
133 |
134 | # If the update rate is 20 Hz (i.e. we're aligned to the depth/flow maps), we can directly take the flow gt
135 | ts_old = self.timestamp_files[set][subset][idx]
136 | ts_new = self.timestamp_files[set][subset][idx + 1]
137 |
138 | if self.update_rate == 20:
139 | flow = get_flow_npy(os.path.join(path_dataset,TEMPLATES.MVSEC_FLOW_GT_FILE.format(idx)))
140 | # Else, we need to interpolate the flow
141 | elif self.update_rate == 45:
142 | flow = self.estimate_gt_flow(loader_idx, ts_old, ts_new)
143 | else:
144 | raise NotImplementedError
145 |
146 |
147 | # Either flow_x or flow_y has to be != 0 s.t. the flow is valid
148 | flow_valid = (flow[0]!=0) | (flow[1] != 0)
149 |         # Additionally, the car hood (rows 193..260) is not included in the GT, so this region is invalid too.
150 | flow_valid[193:,:]=False
151 |
152 | return_dict = {'idx': idx,
153 | 'loader_idx': loader_idx,
154 | 'flow': torch.from_numpy(flow),
155 | 'gt_valid_mask': torch.from_numpy(numpy.stack([flow_valid]*2, axis=0)),
156 | "param_evc": {'height': self.image_height,
157 | 'width': self.image_width}
158 | }
159 |
160 |
161 | # Load Events
162 | if type == 'events':
163 | event_path_old = os.path.join(path_dataset, TEMPLATES.MVSEC_EVENTS_FILE.format('left', idx))
164 | event_path_new = os.path.join(path_dataset, TEMPLATES.MVSEC_EVENTS_FILE.format('left', idx+1))
165 | params = {'height': self.image_height, 'width': self.image_width}
166 |
167 | events_old = get_events(event_path_old)
168 | events_new = get_events(event_path_new)
169 |
170 |             # Timestamp multiplier of 1e6 because the timestamps are saved in seconds and we work in microseconds
171 | # This can be relevant for the voxel grid!
172 | ev_seq_old = EventSequence(events_old, params, timestamp_multiplier=1e6, convert_to_relative=True)
173 | ev_seq_new = EventSequence(events_new, params, timestamp_multiplier=1e6, convert_to_relative=True)
174 | return_dict['event_volume_new'] = self.voxel(ev_seq_new)
175 | return_dict['event_volume_old'] = self.voxel(ev_seq_old)
176 | if self.evaluation_type == 'sparse':
177 | seq = ev_seq_new.get_sequence_only()
178 | h = self.image_height
179 | w = self.image_width
180 | hist, _, _ = numpy.histogram2d(x=seq[:,1], y=seq[:,2],
181 | bins=(w,h),
182 | range=[[0,w], [0,h]])
183 | hist = hist.transpose()
184 | ev_mask = hist > 0
185 | return_dict['gt_valid_mask'] = torch.from_numpy(numpy.stack([flow_valid & ev_mask]*2, axis=0))
186 | elif type == 'frames':
187 | raise NotImplementedError
188 | else:
189 | raise Exception("Input Type not defined properly! Check config file.")
190 |
191 | # Check Timestamps
192 | ev = get_events(event_path_new).to_numpy()
193 | ts_ev_min = numpy.min(ev[:,0])
194 | ts_ev_max = numpy.max(ev[:,0])
195 | assert(ts_ev_min > ts_old and ts_ev_max <= ts_new)
196 |
197 | # plot images
198 | '''
199 | from utils import visualization as visu
200 | from matplotlib import pyplot as plt
201 |
202 | # Justifying my choice of alignment:
203 | # 1) Flow[i] corresponds to Depth[i]
204 | depth_i = torch.tensor(numpy.load(os.path.join(path_dataset, TEMPLATES.MVSEC_DEPTH_GT_FILE.format(idx))))
205 | plt.figure("depth i")
206 | plt.imshow(depth_i.numpy())
207 |
208 | flow_visu = visu.visualize_optical_flow(flow, return_image=True)[0]
209 | plt.figure('Flow i')
210 | plt.imshow(flow_visu)
211 |
212 | # 2) The events are aligned to the depth
213 | # -> events[i] correspond to all events BEFORE depth i
214 | # -> events[i+1] correspond to all events AFTER depth i
215 | # This can be proven by the timestamps timestamp[i] corresponding to depth[i]
216 | ts_old = self.timestamp_files[set][subset][idx]
217 | ts_new = self.timestamp_files[set][subset][idx + 1]
218 |
219 |
220 | ev = get_events(event_path_new).to_numpy()
221 | ts_ev_min = numpy.min(ev[:,0])
222 | ts_ev_max = numpy.max(ev[:,0])
223 | assert(ts_ev_min > ts_old and ts_ev_max <= ts_new)
224 |
225 | # -> Additionally, we can show this, if we plot the events of the first 5ms of events before the depth map
226 | # Remember: events[i] are all the events BEFORE the depth[i]
227 |
228 | event_path_i = os.path.join(path_dataset, TEMPLATES.MVSEC_EVENTS_FILE.format('left', idx))
229 | ev_i = get_events(event_path_i).to_numpy()
230 | ts_i = self.timestamp_files[set][subset][idx]
231 | ev_inst_idx = ev_i[:,0] > ts_i - 0.005
232 | ev_inst = ev_i[ev_inst_idx]
233 | evv = visu.events_to_event_image(ev_inst, self.image_height, self.image_width)
234 | plt.figure("events_instantaneous")
235 | plt.imshow(evv.numpy().transpose(1,2,0))
236 | # This should now match the depth_i
237 |
238 | # Hence, all misalignments are coming from the ground-truth itself.
239 | '''
240 |
241 | return return_dict
242 |
243 | def estimate_gt_flow(self, loader_idx, ts_old, ts_new):
244 | # We need to estimate the flow between two timestamps.
245 |
246 | # First, get the dataset & subset
247 | set = self.dataset[loader_idx]['dataset_name']
248 | subset = self.dataset[loader_idx]['subset_number']
249 | path_flow = os.path.join(self.path_dataset,
250 | TEMPLATES.MVSEC_DATASET_FOLDER.format(set, subset))
251 |
252 | assert ts_old >= self.timestamp_files_flow[set][subset].min(), \
253 | 'Timestamp is smaller than the first flow timestamp'
254 |
255 | # Now, estimate the corresponding GT
256 | flow = mvsec_utils.estimate_corresponding_gt_flow(path_flow=path_flow,
257 | gt_timestamps=self.timestamp_files_flow[set][subset],
258 | start_time=ts_old,
259 | end_time=ts_new)
260 | # flow is a tuple of [H,W]. Stack it
261 | return numpy.stack(flow)
262 |
263 | @staticmethod
264 | def mvsec_time_conversion(timestamps):
265 | raise NotImplementedError
266 |
267 | def get_ts(self, path, i):
268 | try:
269 | f = open(path, "r")
270 | return float(f.readlines()[i])
271 | except OSError:
272 | raise
273 |
274 | def get_image_width_height(self, type='event_camera'):
275 | if hasattr(self, 'cropper'):
276 | h = self.cropper.size[0]
277 | w = self.cropper.size[1]
278 | return h, w
279 | return self.image_height, self.image_width
280 |
281 | def get_events(self, loader_idx):
282 | # Get Events For Visualization Only!!!
283 | path_dataset = os.path.join(self.path_dataset,self.dataset[loader_idx]['dataset_name'] + "_" + str(self.dataset[loader_idx]['subset_number']))
284 | params = {'height': self.image_height, 'width': self.image_width}
285 | i = self.dataset[loader_idx]['index']
286 | path = os.path.join(path_dataset, TEMPLATES.MVSEC_EVENTS_FILE.format('left', i+1))
287 | events = EventSequence(get_events(path), params).get_sequence_only()
288 | return events
289 |
290 | def __len__(self):
291 | return len(self.dataset)
292 |
293 | def __getitem__(self, idx, force_crop_window=None, force_flipping=None):
294 | if idx >= len(self):
295 | raise IndexError
296 | sample = self.get_data_sample(idx)
297 | # Center Crop Everything
298 | sample['flow'] = self.cropper(sample['flow'])
299 | sample['gt_valid_mask'] = self.cropper(sample['gt_valid_mask'])
300 | sample['event_volume_new'] = self.cropper(sample['event_volume_new'])
301 | sample['event_volume_old'] = self.cropper(sample['event_volume_old'])
302 |
303 | return sample
304 |
305 | class MvsecFlowRecurrent(Dataset):
306 | def __init__(self, args, type, path):
307 | super(MvsecFlowRecurrent, self).__init__()
308 | if type.lower() != 'test':
309 | self.sequence_length = args['sequence_length']
310 | else:
311 | self.sequence_length = 1
312 | self.step_size = 1
313 | self.dataset = MvsecFlow(args, type, path=path)
314 |
315 | def __len__(self):
316 | return (len(self.dataset) - self.sequence_length) // self.step_size + 1
317 |
318 | def __getitem__(self, idx):
319 | # ----------------------------------------------------------------------------- #
320 | # Returns a list, containing of event/frame/flow #
321 | # [ e_(i-sequence_length), ..., e_(i) ] #
322 | # ----------------------------------------------------------------------------- #
323 | assert(idx >= 0)
324 | assert(idx < len(self))
325 | sequence = []
326 | j = idx * self.step_size
327 |
328 | flip = None
329 | crop_window = None
330 |
331 | for i in range(self.sequence_length):
332 | sequence.append(self.dataset.__getitem__(j + i, force_crop_window=crop_window, force_flipping=flip))
333 | # Just Making Sure
334 | assert sequence[-1]['idx']-sequence[0]['idx'] == self.sequence_length-1
335 | return sequence
336 |
337 | def summary(self, logger):
338 | logger.write_line("================================== Dataloader Summary ====================================", True)
339 | logger.write_line("Loader Type:\t\t" + self.__class__.__name__ + " for {}".format(self.dataset.type), True)
340 | logger.write_line("Sequence Length:\t{}".format(self.sequence_length), True)
341 | logger.write_line("Step Size:\t\t{}".format(self.step_size), True)
342 | logger.write_line("Framerate:\t\t{}".format(self.dataset.update_rate), True)
343 |
344 | def get_image_width_height(self, type='event_camera'):
345 | return self.dataset.get_image_width_height(type)
346 |
347 | def get_events(self, loader_idx):
348 | return self.dataset.get_events(loader_idx)
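349 | 
350 | # Minimal usage sketch (illustrative; assumes the pre-processed MVSEC data from
351 | # the README extracted under datasets/mvsec, with args as in config/mvsec_20.json):
352 | #
353 | # import json
354 | # args = json.load(open('config/mvsec_20.json'))['data_loader']['test']['args']
355 | # dataset = MvsecFlow(args, type='test', path='datasets/mvsec')
356 | # sample = dataset[0]  # dict with 'flow', 'gt_valid_mask', 'event_volume_old'/'_new'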
349 |
--------------------------------------------------------------------------------
/loader/utils.py:
--------------------------------------------------------------------------------
1 | import numpy
2 | import os
3 | import pandas
4 | from PIL import Image
5 | import random
6 | import torch
7 | from itertools import chain
8 | import h5py
9 | import json
10 |
11 |
12 | class EventSequence(object):
13 | def __init__(self, dataframe, params, features=None, timestamp_multiplier=None, convert_to_relative=False):
14 | if isinstance(dataframe, pandas.DataFrame):
15 | self.feature_names = dataframe.columns.values
16 | self.features = dataframe.to_numpy()
17 | else:
18 | self.feature_names = numpy.array(['ts', 'x', 'y', 'p'], dtype=object)
19 | if features is None:
20 | self.features = numpy.zeros([1, 4])
21 | else:
22 | self.features = features
23 | self.image_height = params['height']
24 | self.image_width = params['width']
25 | if not self.is_sorted():
26 | self.sort_by_timestamp()
27 | if timestamp_multiplier is not None:
28 | self.features[:,0] *= timestamp_multiplier
29 | if convert_to_relative:
30 | self.absolute_time_to_relative()
31 |
32 | def get_sequence_only(self):
33 | return self.features
34 |
35 | def __len__(self):
36 | return len(self.features)
37 |
38 | def __add__(self, sequence):
39 | event_sequence = EventSequence(dataframe=None,
40 | features=numpy.concatenate([self.features, sequence.features]),
41 | params={'height': self.image_height,
42 | 'width': self.image_width})
43 | return event_sequence
44 |
45 | def is_sorted(self):
46 | return numpy.all(self.features[:-1, 0] <= self.features[1:, 0])
47 |
48 | def sort_by_timestamp(self):
49 | if len(self.features[:, 0]) > 0:
50 | sort_indices = numpy.argsort(self.features[:, 0])
51 | self.features = self.features[sort_indices]
52 |
53 | def absolute_time_to_relative(self):
54 | """Transforms absolute time to time relative to the first event."""
55 | start_ts = self.features[:,0].min()
56 | assert(start_ts == self.features[0,0])
57 | self.features[:,0] -= start_ts
58 |
59 |
60 | def get_image(image_path):
61 | try:
62 | im = Image.open(image_path)
63 | # print(image_path)
64 | return numpy.array(im)
65 | except OSError:
66 | raise
67 |
68 |
69 | def get_events(event_path):
70 | # It's possible that there is no event file! (camera standing still)
71 | try:
72 | f = pandas.read_hdf(event_path, "myDataset")
73 | return f[['ts', 'x', 'y', 'p']]
74 | except OSError:
75 | print("No file " + event_path)
76 | print("Creating an array of zeros!")
77 | return 0
78 |
79 |
80 | def get_ts(path, i, type='int'):
81 | try:
82 | f = open(path, "r")
83 | if type == 'int':
84 | return int(f.readlines()[i])
85 | elif type == 'double' or type == 'float':
86 | return float(f.readlines()[i])
87 | except OSError:
88 | raise
89 |
90 |
91 | def get_batchsize(path_dataset):
92 | filepath = os.path.join(path_dataset, "cam0", "timestamps.txt")
93 | try:
94 | f = open(filepath, "r")
95 | return len(f.readlines())
96 | except OSError:
97 | raise
98 |
99 |
100 | def get_batch(path_dataset, i):
101 | return 0
102 |
103 |
104 | def dataset_paths(dataset_name, path_dataset, subset_number=None):
105 | cameras = {'cam0': {}, 'cam1': {}, 'cam2': {}, 'cam3': {}}
106 | if subset_number is not None:
107 | dataset_name = dataset_name + "_" + str(subset_number)
108 | paths = {'dataset_folder': os.path.join(path_dataset, dataset_name)}
109 |
110 | # For every camera, define its path
111 | for camera in cameras:
112 | cameras[camera]['image_folder'] = os.path.join(paths['dataset_folder'], camera, 'image_raw')
113 | cameras[camera]['event_folder'] = os.path.join(paths['dataset_folder'], camera, 'events')
114 | cameras[camera]['disparity_folder'] = os.path.join(paths['dataset_folder'], camera, 'disparity_image')
115 | cameras[camera]['depth_folder'] = os.path.join(paths['dataset_folder'], camera, 'depthmap')
116 | cameras["timestamp_file"] = os.path.join(paths['dataset_folder'], 'cam0', 'timestamps.txt')
117 | cameras["image_type"] = ".png"
118 | cameras["event_type"] = ".h5"
119 | cameras["disparity_type"] = ".png"
120 | cameras["depth_type"] = ".tiff"
121 | cameras["indexing_type"] = "%0.6i"
122 | paths.update(cameras)
123 | return paths
124 |
125 |
126 | def get_indices(path_dataset, dataset, filter, shuffle=False):
127 | samples = []
128 | for dataset_name in dataset:
129 | for subset in dataset[dataset_name]:
130 | # Get all the dataframe paths
131 | paths = dataset_paths(dataset_name, path_dataset, subset)
132 |
133 | # import timestamps
134 | ts = numpy.loadtxt(paths["timestamp_file"])
135 |
136 | # frames = []
137 | # For every timestamp, import according data
138 | for idx in eval(filter[dataset_name][str(subset)]):
139 | frame = {}
140 | frame['dataset_name'] = dataset_name
141 | frame['subset_number'] = subset
142 | frame['index'] = idx
143 | frame['timestamp'] = ts[idx]
144 | samples.append(frame)
145 | # shuffle dataset
146 | if shuffle:
147 | random.shuffle(samples)
148 | return samples
149 |
150 |
151 | def get_flow_h5(flow_path):
152 | scaling_factor = 0.05 # seconds/frame
153 | f = h5py.File(flow_path, 'r')
154 | height, width = int(f['header']['height']), int(f['header']['width'])
155 | assert(len(f['x']) == height*width)
156 | assert(len(f['y']) == height*width)
157 | x = numpy.array(f['x']).reshape([height,width])*scaling_factor
158 | y = numpy.array(f['y']).reshape([height,width])*scaling_factor
159 | return numpy.stack([x,y])
160 |
161 |
162 | def get_flow_npy(flow_path):
163 | # Array 2,height, width
164 | # No scaling needed.
165 | return numpy.load(flow_path, allow_pickle=True)
166 |
167 |
168 | def get_pose(pose_path, index):
169 | pose = pandas.read_csv(pose_path, delimiter=',').loc[index].to_numpy()
170 | # Convert Timestamp to int (as all the other timestamps)
171 | pose[0] = int(pose[0])
172 | return pose
173 |
174 |
175 | def load_config(path, datasets):
176 | config = {}
177 | for dataset_name in datasets:
178 | config[dataset_name] = {}
179 | for subset in datasets[dataset_name]:
180 | name = "{}_{}".format(dataset_name, subset)
181 |             try:
182 |                 with open(os.path.join(path, name, "config.json")) as f:
183 |                     config[dataset_name][subset] = json.load(f)
184 |             except OSError:
185 |                 print("Could not find config file for dataset {}_{}. Please check that "
186 |                       "'config.json' exists in the dataset-scene directory.".format(dataset_name, subset))
187 |                 raise
187 | return config
188 |
--------------------------------------------------------------------------------
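A quick usage sketch for the path helpers above. The scene name, root directory, and subset number are illustrative placeholders, not values taken from the repository's configs:

```python
from loader.utils import dataset_paths

paths = dataset_paths('indoor_flying', '/data/mvsec', subset_number=1)
print(paths['dataset_folder'])        # /data/mvsec/indoor_flying_1
print(paths['cam0']['event_folder'])  # /data/mvsec/indoor_flying_1/cam0/events
print(paths['timestamp_file'])        # /data/mvsec/indoor_flying_1/cam0/timestamps.txt
```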
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["MKL_NUM_THREADS"] = "1"
3 | os.environ["NUMEXPR_NUM_THREADS"] = "1"
4 | os.environ["OMP_NUM_THREADS"] = "1"
5 | os.environ["NUMEXPR_MAX_THREADS"] = "1"
6 | from loader.loader_mvsec_flow import *
7 | from loader.loader_dsec import *
8 | from utils.logger import *
9 | import utils.helper_functions as helper
10 | import json
11 | from torch.utils.data.dataloader import DataLoader
12 | from utils import visualization as visu
13 | import argparse
14 | from test import *
15 | import git
16 | import torch.nn
17 | from model import eraft
18 |
19 | def initialize_tester(config):
20 | # Warm Start
21 | if config['subtype'].lower() == 'warm_start':
22 | return TestRaftEventsWarm
23 | # Classic
24 | else:
25 | return TestRaftEvents
26 |
27 | def get_visualizer(args):
28 | # DSEC dataset
29 | if args.dataset.lower() == 'dsec':
30 | return visualization.DsecFlowVisualizer
31 | # MVSEC dataset
32 | else:
33 | return visualization.FlowVisualizerEvents
34 |
35 | def test(args):
36 | # Choose correct config file
37 | if args.dataset.lower()=='dsec':
38 | if args.type.lower()=='warm_start':
39 | config_path = 'config/dsec_warm_start.json'
40 | elif args.type.lower()=='standard':
41 | config_path = 'config/dsec_standard.json'
42 | else:
43 | raise Exception('Please provide a valid argument for --type. [warm_start/standard]')
44 | elif args.dataset.lower()=='mvsec':
45 | if args.frequency==20:
46 | config_path = 'config/mvsec_20.json'
47 | elif args.frequency==45:
48 | config_path = 'config/mvsec_45.json'
49 | else:
50 | raise Exception('Please provide a valid argument for --frequency. [20/45]')
51 | if args.type=='standard':
52 | raise NotImplementedError('Sorry, this is not implemented yet, please choose --type warm_start')
53 | else:
54 | raise Exception('Please provide a valid argument for --dataset. [dsec/mvsec]')
55 |
56 |
57 | # Load config file
58 | config = json.load(open(config_path))
59 | # Create Save Folder
60 | save_path = helper.create_save_path(config['save_dir'].lower(), config['name'].lower())
61 | print('Storing output in folder {}'.format(save_path))
62 | # Copy config file to save dir
63 | json.dump(config, open(os.path.join(save_path, 'config.json'), 'w'),
64 | indent=4, sort_keys=False)
65 | # Logger
66 | logger = Logger(save_path)
67 | logger.initialize_file("test")
68 |
69 | # Instantiate Dataset
70 | # Case: DSEC Dataset
71 | additional_loader_returns = None
72 | if args.dataset.lower() == 'dsec':
73 | # Dsec Dataloading
74 | loader = DatasetProvider(
75 | dataset_path=Path(args.path),
76 | representation_type=RepresentationType.VOXEL,
77 | delta_t_ms=100,
78 | config=config,
79 | type=config['subtype'].lower(),
80 | visualize=args.visualize)
81 | loader.summary(logger)
82 | test_set = loader.get_test_dataset()
83 | additional_loader_returns = {'name_mapping_test': loader.get_name_mapping_test()}
84 |
85 | # Case: MVSEC Dataset
86 | else:
87 | if config['subtype'].lower() == 'standard':
88 | test_set = MvsecFlow(
89 | args = config["data_loader"]["test"]["args"],
90 | type='test',
91 | path=args.path
92 | )
93 | elif config['subtype'].lower() == 'warm_start':
94 | test_set = MvsecFlowRecurrent(
95 | args = config["data_loader"]["test"]["args"],
96 | type='test',
97 | path=args.path
98 | )
99 | else:
100 | raise NotImplementedError
101 | test_set.summary(logger)
102 |
103 | # Instantiate Dataloader
104 | test_set_loader = DataLoader(test_set,
105 | batch_size=config['data_loader']['test']['args']['batch_size'],
106 | shuffle=config['data_loader']['test']['args']['shuffle'],
107 | num_workers=args.num_workers,
108 | drop_last=True)
109 |
110 | # Load Model
111 | model = eraft.ERAFT(
112 | config=config,
113 | n_first_channels=config['data_loader']['test']['args']['num_voxel_bins']
114 | )
115 | # Load Checkpoint
116 | checkpoint = torch.load(config['test']['checkpoint'])
117 | model.load_state_dict(checkpoint['model'])
118 |
119 | # Get Visualizer
120 | visualizer = get_visualizer(args)
121 |
122 | # Initialize Tester
123 | test = initialize_tester(config)
124 |
125 | test = test(
126 | model=model,
127 | config=config,
128 | data_loader=test_set_loader,
129 | test_logger=logger,
130 | save_path=save_path,
131 | visualizer=visualizer,
132 | additional_args=additional_loader_returns
133 | )
134 |
135 | test.summary()
136 | test._test()
137 |
138 | if __name__ == '__main__':
139 |     # Argument Parser
141 | parser = argparse.ArgumentParser()
142 | parser.add_argument('-p', '--path', type=str, help="Dataset path", required=True)
143 | parser.add_argument('-d', '--dataset', default="dsec", type=str, help="Which dataset to use: ([dsec]/mvsec)")
144 | parser.add_argument('-f', '--frequency', default=20, type=int, help="Evaluation frequency of MVSEC dataset ([20]/45) Hz")
145 | parser.add_argument('-t', '--type', default='warm_start', type=str, help="Evaluation type ([warm_start]/standard)")
146 | parser.add_argument('-v', '--visualize', action='store_true', help='Provide this argument s.t. DSEC results are visualized. MVSEC experiments are always visualized.')
147 | parser.add_argument('-n', '--num_workers', default=0, type=int, help='How many sub-processes to use for data loading')
148 | args = parser.parse_args()
149 |
150 | # Run Test Script
151 | test(args)
152 |
--------------------------------------------------------------------------------
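For reference, the evaluation entry point can also be driven programmatically instead of via the CLI. A minimal sketch, assuming the same argument fields the parser above defines (the dataset path is a placeholder):

```python
from argparse import Namespace

# Mirrors `python main.py -p /data/dsec -d dsec -t warm_start`
args = Namespace(path='/data/dsec', dataset='dsec', frequency=20,
                 type='warm_start', visualize=False, num_workers=0)
# test(args)  # would load config/dsec_warm_start.json and run the evaluation
```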
/model/corr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from model.utils import bilinear_sampler, coords_grid
4 |
5 | try:
6 |     import alt_cuda_corr
7 | except ImportError:
8 |     # alt_cuda_corr is an optional compiled extension; fall back to pure PyTorch
9 |     pass
10 |
11 |
12 | class CorrBlock:
13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
14 | self.num_levels = num_levels
15 | self.radius = radius
16 | self.corr_pyramid = []
17 |
18 | # all pairs correlation
19 | corr = CorrBlock.corr(fmap1, fmap2)
20 |
21 | batch, h1, w1, dim, h2, w2 = corr.shape
22 | corr = corr.reshape(batch*h1*w1, dim, h2, w2)
23 |
24 | self.corr_pyramid.append(corr)
25 | for i in range(self.num_levels-1):
26 | corr = F.avg_pool2d(corr, 2, stride=2)
27 | self.corr_pyramid.append(corr)
28 |
29 | def __call__(self, coords):
30 | r = self.radius
31 | coords = coords.permute(0, 2, 3, 1)
32 | batch, h1, w1, _ = coords.shape
33 |
34 | out_pyramid = []
35 | for i in range(self.num_levels):
36 | corr = self.corr_pyramid[i]
37 | dx = torch.linspace(-r, r, 2*r+1)
38 | dy = torch.linspace(-r, r, 2*r+1)
39 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)
40 |
41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
43 | coords_lvl = centroid_lvl + delta_lvl
44 |
45 | corr = bilinear_sampler(corr, coords_lvl)
46 | corr = corr.view(batch, h1, w1, -1)
47 | out_pyramid.append(corr)
48 |
49 | out = torch.cat(out_pyramid, dim=-1)
50 | return out.permute(0, 3, 1, 2).contiguous().float()
51 |
52 | @staticmethod
53 | def corr(fmap1, fmap2):
54 | batch, dim, ht, wd = fmap1.shape
55 | fmap1 = fmap1.view(batch, dim, ht*wd)
56 | fmap2 = fmap2.view(batch, dim, ht*wd)
57 |
58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2)
59 | corr = corr.view(batch, ht, wd, 1, ht, wd)
60 | return corr / torch.sqrt(torch.tensor(dim).float())
61 |
--------------------------------------------------------------------------------
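A shape-check sketch for the correlation block above. Sizes are illustrative (roughly a 240x320 input at 1/8 resolution); each pyramid level contributes a (2r+1)^2 = 81-channel lookup window, concatenated over all levels:

```python
import torch
from model.corr import CorrBlock
from model.utils import coords_grid

B, D, H, W = 1, 256, 30, 40            # illustrative feature-map sizes
fmap1 = torch.randn(B, D, H, W)
fmap2 = torch.randn(B, D, H, W)

corr_fn = CorrBlock(fmap1, fmap2, num_levels=4, radius=4)
coords = coords_grid(B, H, W)          # [B, 2, H, W] pixel coordinates
out = corr_fn(coords)
print(out.shape)                       # torch.Size([1, 324, 30, 40]) = 4 levels * 81 channels
```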
/model/eraft.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from .update import BasicUpdateBlock
7 | from .extractor import BasicEncoder
8 | from .corr import CorrBlock
9 | from model.utils import coords_grid, upflow8
10 | from argparse import Namespace
11 | from utils.image_utils import ImagePadder
12 |
13 | try:
14 | autocast = torch.cuda.amp.autocast
15 | except:
16 | # dummy autocast for PyTorch < 1.6
17 | class autocast:
18 | def __init__(self, enabled):
19 | pass
20 | def __enter__(self):
21 | pass
22 | def __exit__(self, *args):
23 | pass
24 |
25 |
26 | def get_args():
27 |     # Adapter that converts the arguments given in our config file to the
28 |     # Namespace format that ERAFT expects.
29 | args = Namespace(small=False,
30 | dropout=False,
31 | mixed_precision=False,
32 | clip=1.0)
33 | return args
34 |
35 |
36 |
37 | class ERAFT(nn.Module):
38 | def __init__(self, config, n_first_channels):
39 | # args:
40 | super(ERAFT, self).__init__()
41 | args = get_args()
42 | self.args = args
43 | self.image_padder = ImagePadder(min_size=32)
44 | self.subtype = config['subtype'].lower()
45 |
46 | assert (self.subtype == 'standard' or self.subtype == 'warm_start')
47 |
48 | self.hidden_dim = hdim = 128
49 | self.context_dim = cdim = 128
50 | args.corr_levels = 4
51 | args.corr_radius = 4
52 |
53 | # feature network, context network, and update block
54 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=0,
55 | n_first_channels=n_first_channels)
56 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=0,
57 | n_first_channels=n_first_channels)
58 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
59 |
60 | def freeze_bn(self):
61 | for m in self.modules():
62 | if isinstance(m, nn.BatchNorm2d):
63 | m.eval()
64 |
65 | def initialize_flow(self, img):
66 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
67 | N, C, H, W = img.shape
68 | coords0 = coords_grid(N, H//8, W//8).to(img.device)
69 | coords1 = coords_grid(N, H//8, W//8).to(img.device)
70 |
71 | # optical flow computed as difference: flow = coords1 - coords0
72 | return coords0, coords1
73 |
74 | def upsample_flow(self, flow, mask):
75 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
76 | N, _, H, W = flow.shape
77 | mask = mask.view(N, 1, 9, 8, 8, H, W)
78 | mask = torch.softmax(mask, dim=2)
79 |
80 | up_flow = F.unfold(8 * flow, [3,3], padding=1)
81 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
82 |
83 | up_flow = torch.sum(mask * up_flow, dim=2)
84 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
85 | return up_flow.reshape(N, 2, 8*H, 8*W)
86 |
87 |
88 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True):
89 | """ Estimate optical flow between pair of frames """
90 | # Pad Image (for flawless up&downsampling)
91 | image1 = self.image_padder.pad(image1)
92 | image2 = self.image_padder.pad(image2)
93 |
94 | image1 = image1.contiguous()
95 | image2 = image2.contiguous()
96 |
97 | hdim = self.hidden_dim
98 | cdim = self.context_dim
99 |
100 | # run the feature network
101 | with autocast(enabled=self.args.mixed_precision):
102 | fmap1, fmap2 = self.fnet([image1, image2])
103 |
104 | fmap1 = fmap1.float()
105 | fmap2 = fmap2.float()
106 |
107 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
108 |
109 | # run the context network
110 | with autocast(enabled=self.args.mixed_precision):
111 | if (self.subtype == 'standard' or self.subtype == 'warm_start'):
112 | cnet = self.cnet(image2)
113 | else:
114 | raise Exception
115 | net, inp = torch.split(cnet, [hdim, cdim], dim=1)
116 | net = torch.tanh(net)
117 | inp = torch.relu(inp)
118 |
119 | # Initialize Grids. First channel: x, 2nd channel: y. Image is just used to get the shape
120 | coords0, coords1 = self.initialize_flow(image1)
121 |
122 | if flow_init is not None:
123 | coords1 = coords1 + flow_init
124 |
125 | flow_predictions = []
126 | for itr in range(iters):
127 | coords1 = coords1.detach()
128 | corr = corr_fn(coords1) # index correlation volume
129 |
130 | flow = coords1 - coords0
131 | with autocast(enabled=self.args.mixed_precision):
132 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
133 |
134 | # F(t+1) = F(t) + \Delta(t)
135 | coords1 = coords1 + delta_flow
136 |
137 | # upsample predictions
138 | if up_mask is None:
139 | flow_up = upflow8(coords1 - coords0)
140 | else:
141 | flow_up = self.upsample_flow(coords1 - coords0, up_mask)
142 |
143 | flow_predictions.append(self.image_padder.unpad(flow_up))
144 |
145 | return coords1 - coords0, flow_predictions
146 |
--------------------------------------------------------------------------------
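A minimal warm-start inference sketch for the model above. The bin count, spatial size, and iteration count are illustrative (the real values come from the config's `num_voxel_bins` and the dataset resolution); the state propagation mirrors what `TestRaftEventsWarm` does in test.py:

```python
import torch
from model.eraft import ERAFT
from utils.image_utils import forward_interpolate_pytorch

config = {'subtype': 'warm_start'}                 # minimal config; real configs add more keys
model = ERAFT(config, n_first_channels=15).eval()

flow_init = None
with torch.no_grad():
    for _ in range(3):                             # three consecutive event-volume pairs
        vol_old = torch.randn(1, 15, 96, 128)      # illustrative spatial size
        vol_new = torch.randn(1, 15, 96, 128)
        flow_low, flow_preds = model(vol_old, vol_new, iters=6, flow_init=flow_init)
        # Forward-propagate the low-resolution flow as initialization for the next pair
        flow_init = forward_interpolate_pytorch(flow_low)
print(flow_preds[-1].shape)                        # torch.Size([1, 2, 96, 128])
```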
/model/extractor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | # from model.network_blocks import RecurrentResidualBlock
5 | from torch.nn import init
6 |
7 | class ResidualBlock(nn.Module):
8 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
9 | super(ResidualBlock, self).__init__()
10 |
11 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
12 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
13 | self.relu = nn.ReLU(inplace=True)
14 |
15 | num_groups = planes // 8
16 |
17 | if norm_fn == 'group':
18 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
19 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
20 | if not stride == 1:
21 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
22 |
23 | elif norm_fn == 'batch':
24 | self.norm1 = nn.BatchNorm2d(planes)
25 | self.norm2 = nn.BatchNorm2d(planes)
26 | if not stride == 1:
27 | self.norm3 = nn.BatchNorm2d(planes)
28 |
29 | elif norm_fn == 'instance':
30 | self.norm1 = nn.InstanceNorm2d(planes)
31 | self.norm2 = nn.InstanceNorm2d(planes)
32 | if not stride == 1:
33 | self.norm3 = nn.InstanceNorm2d(planes)
34 |
35 | elif norm_fn == 'none':
36 | self.norm1 = nn.Sequential()
37 | self.norm2 = nn.Sequential()
38 | if not stride == 1:
39 | self.norm3 = nn.Sequential()
40 |
41 | if stride == 1:
42 | self.downsample = None
43 |
44 | else:
45 | self.downsample = nn.Sequential(
46 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
47 |
48 |
49 | def forward(self, x):
50 | y = x
51 | y = self.relu(self.norm1(self.conv1(y)))
52 | y = self.relu(self.norm2(self.conv2(y)))
53 |
54 | if self.downsample is not None:
55 | x = self.downsample(x)
56 |
57 | return self.relu(x+y)
58 |
59 |
60 | class BottleneckBlock(nn.Module):
61 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
62 | super(BottleneckBlock, self).__init__()
63 |
64 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
65 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
66 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
67 | self.relu = nn.ReLU(inplace=True)
68 |
69 | num_groups = planes // 8
70 |
71 | if norm_fn == 'group':
72 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
73 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
74 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
75 | if not stride == 1:
76 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
77 |
78 | elif norm_fn == 'batch':
79 | self.norm1 = nn.BatchNorm2d(planes//4)
80 | self.norm2 = nn.BatchNorm2d(planes//4)
81 | self.norm3 = nn.BatchNorm2d(planes)
82 | if not stride == 1:
83 | self.norm4 = nn.BatchNorm2d(planes)
84 |
85 | elif norm_fn == 'instance':
86 | self.norm1 = nn.InstanceNorm2d(planes//4)
87 | self.norm2 = nn.InstanceNorm2d(planes//4)
88 | self.norm3 = nn.InstanceNorm2d(planes)
89 | if not stride == 1:
90 | self.norm4 = nn.InstanceNorm2d(planes)
91 |
92 | elif norm_fn == 'none':
93 | self.norm1 = nn.Sequential()
94 | self.norm2 = nn.Sequential()
95 | self.norm3 = nn.Sequential()
96 | if not stride == 1:
97 | self.norm4 = nn.Sequential()
98 |
99 | if stride == 1:
100 | self.downsample = None
101 |
102 | else:
103 | self.downsample = nn.Sequential(
104 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
105 |
106 |
107 | def forward(self, x):
108 | y = x
109 | y = self.relu(self.norm1(self.conv1(y)))
110 | y = self.relu(self.norm2(self.conv2(y)))
111 | y = self.relu(self.norm3(self.conv3(y)))
112 |
113 | if self.downsample is not None:
114 | x = self.downsample(x)
115 |
116 | return self.relu(x+y)
117 |
118 |
119 | class BasicEncoder(nn.Module):
120 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0, n_first_channels=1):
121 | super(BasicEncoder, self).__init__()
122 | self.norm_fn = norm_fn
123 |
124 | if self.norm_fn == 'group':
125 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
126 |
127 | elif self.norm_fn == 'batch':
128 | self.norm1 = nn.BatchNorm2d(64)
129 |
130 | elif self.norm_fn == 'instance':
131 | self.norm1 = nn.InstanceNorm2d(64)
132 |
133 | elif self.norm_fn == 'none':
134 | self.norm1 = nn.Sequential()
135 |
136 | self.conv1 = nn.Conv2d(n_first_channels, 64, kernel_size=7, stride=2, padding=3)
137 | self.relu1 = nn.ReLU(inplace=True)
138 |
139 | self.in_planes = 64
140 | self.layer1 = self._make_layer(64, stride=1)
141 | self.layer2 = self._make_layer(96, stride=2)
142 | self.layer3 = self._make_layer(128, stride=2)
143 |
144 | # output convolution
145 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
146 |
147 | self.dropout = None
148 | if dropout > 0:
149 | self.dropout = nn.Dropout2d(p=dropout)
150 |
151 | for m in self.modules():
152 | if isinstance(m, nn.Conv2d):
153 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
154 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
155 | if m.weight is not None:
156 | nn.init.constant_(m.weight, 1)
157 | if m.bias is not None:
158 | nn.init.constant_(m.bias, 0)
159 |
160 | def _make_layer(self, dim, stride=1):
161 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
162 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
163 | layers = (layer1, layer2)
164 |
165 | self.in_planes = dim
166 | return nn.Sequential(*layers)
167 |
168 | def forward(self, x):
169 | # if input is list, combine batch dimension
170 | is_list = isinstance(x, tuple) or isinstance(x, list)
171 | if is_list:
172 | batch_dim = x[0].shape[0]
173 | x = torch.cat(x, dim=0)
174 |
175 | x = self.conv1(x)
176 | x = self.norm1(x)
177 | x = self.relu1(x)
178 | x = self.layer1(x)
179 | x = self.layer2(x)
180 | x = self.layer3(x)
181 |
182 | x = self.conv2(x)
183 |
184 | if self.training and self.dropout is not None:
185 | x = self.dropout(x)
186 |
187 | if is_list:
188 | x = torch.split(x, [batch_dim, batch_dim], dim=0)
189 | return x
190 |
--------------------------------------------------------------------------------
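A small sanity-check sketch for the encoder above: three stride-2 stages (conv1, layer2, layer3) reduce the spatial resolution by a factor of 8. Input sizes and channel count are illustrative:

```python
import torch
from model.extractor import BasicEncoder

enc = BasicEncoder(output_dim=256, norm_fn='instance', n_first_channels=15)
x = torch.randn(1, 15, 96, 128)        # illustrative input
y = enc(x)
print(y.shape)                         # torch.Size([1, 256, 12, 16]) -> H/8, W/8
```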
/model/update.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class FlowHead(nn.Module):
7 | def __init__(self, input_dim=128, hidden_dim=256):
8 | super(FlowHead, self).__init__()
9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
11 | self.relu = nn.ReLU(inplace=True)
12 |
13 | def forward(self, x):
14 | return self.conv2(self.relu(self.conv1(x)))
15 |
16 | class ConvGRU(nn.Module):
17 | def __init__(self, hidden_dim=128, input_dim=192+128):
18 | super(ConvGRU, self).__init__()
19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
22 |
23 | def forward(self, h, x):
24 | hx = torch.cat([h, x], dim=1)
25 |
26 | z = torch.sigmoid(self.convz(hx))
27 | r = torch.sigmoid(self.convr(hx))
28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
29 |
30 | h = (1-z) * h + z * q
31 | return h
32 |
33 | class SepConvGRU(nn.Module):
34 | def __init__(self, hidden_dim=128, input_dim=192+128):
35 | super(SepConvGRU, self).__init__()
36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
39 |
40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
43 |
44 |
45 | def forward(self, h, x):
46 | # horizontal
47 | hx = torch.cat([h, x], dim=1)
48 | z = torch.sigmoid(self.convz1(hx))
49 | r = torch.sigmoid(self.convr1(hx))
50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
51 | h = (1-z) * h + z * q
52 |
53 | # vertical
54 | hx = torch.cat([h, x], dim=1)
55 | z = torch.sigmoid(self.convz2(hx))
56 | r = torch.sigmoid(self.convr2(hx))
57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
58 | h = (1-z) * h + z * q
59 |
60 | return h
61 |
62 |
63 | class BasicMotionEncoder(nn.Module):
64 | def __init__(self, args):
65 | super(BasicMotionEncoder, self).__init__()
66 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
67 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
68 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
69 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
70 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
71 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
72 |
73 | def forward(self, flow, corr):
74 | cor = F.relu(self.convc1(corr))
75 | cor = F.relu(self.convc2(cor))
76 | flo = F.relu(self.convf1(flow))
77 | flo = F.relu(self.convf2(flo))
78 |
79 | cor_flo = torch.cat([cor, flo], dim=1)
80 | out = F.relu(self.conv(cor_flo))
81 | return torch.cat([out, flow], dim=1)
82 |
83 |
84 | class BasicUpdateBlock(nn.Module):
85 | def __init__(self, args, hidden_dim=128, input_dim=128):
86 | super(BasicUpdateBlock, self).__init__()
87 | self.args = args
88 | self.encoder = BasicMotionEncoder(args)
89 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
90 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
91 |
92 | self.mask = nn.Sequential(
93 | nn.Conv2d(128, 256, 3, padding=1),
94 | nn.ReLU(inplace=True),
95 | nn.Conv2d(256, 64*9, 1, padding=0))
96 |
97 | def forward(self, net, inp, corr, flow, upsample=True):
98 | motion_features = self.encoder(flow, corr)
99 | inp = torch.cat([inp, motion_features], dim=1)
100 |
101 | net = self.gru(net, inp)
102 | delta_flow = self.flow_head(net)
103 |
104 | # scale mask to balance gradients
105 | mask = .25 * self.mask(net)
106 | return net, mask, delta_flow
107 |
108 |
109 |
110 |
--------------------------------------------------------------------------------
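A shape sketch for one refinement step through `BasicUpdateBlock`, assuming the RAFT hyper-parameters set in eraft.py (corr_levels=4, corr_radius=4, hidden_dim=128); all tensors are random placeholders:

```python
import torch
from argparse import Namespace
from model.update import BasicUpdateBlock

args = Namespace(corr_levels=4, corr_radius=4)
block = BasicUpdateBlock(args, hidden_dim=128)

B, H, W = 1, 12, 16                     # 1/8-resolution grid (illustrative)
net = torch.randn(B, 128, H, W)         # GRU hidden state
inp = torch.randn(B, 128, H, W)         # context features
corr = torch.randn(B, 4 * 81, H, W)     # correlation lookup: 4 levels x (2*4+1)^2 channels
flow = torch.zeros(B, 2, H, W)

net, up_mask, delta_flow = block(net, inp, corr, flow)
print(up_mask.shape, delta_flow.shape)  # torch.Size([1, 576, 12, 16]) torch.Size([1, 2, 12, 16])
```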
/model/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from scipy import interpolate
5 |
6 |
7 | def bilinear_sampler(img, coords, mode='bilinear', mask=False):
8 | """ Wrapper for grid_sample, uses pixel coordinates """
9 | H, W = img.shape[-2:]
10 | xgrid, ygrid = coords.split([1,1], dim=-1)
11 | xgrid = 2*xgrid/(W-1) - 1
12 | ygrid = 2*ygrid/(H-1) - 1
13 |
14 | grid = torch.cat([xgrid, ygrid], dim=-1)
15 | img = F.grid_sample(img, grid, align_corners=True)
16 |
17 | if mask:
18 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
19 | return img, mask.float()
20 |
21 | return img
22 |
23 |
24 | def coords_grid(batch, ht, wd):
25 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
26 | coords = torch.stack(coords[::-1], dim=0).float()
27 | return coords[None].repeat(batch, 1, 1, 1)
28 |
29 |
30 | def upflow8(flow, mode='bilinear'):
31 | new_size = (8 * flow.shape[2], 8 * flow.shape[3])
32 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
33 |
--------------------------------------------------------------------------------
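A minimal sketch showing that `bilinear_sampler` consumes pixel coordinates (x, y in the last dimension) and rescales them internally to grid_sample's [-1, 1] range; sampling at the identity grid returns the input unchanged:

```python
import torch
from model.utils import bilinear_sampler, coords_grid

img = torch.arange(12, dtype=torch.float32).view(1, 1, 3, 4)  # [B, C, H, W]
coords = coords_grid(1, 3, 4).permute(0, 2, 3, 1)             # [B, H, W, 2] with (x, y) order
out = bilinear_sampler(img, coords)
print(torch.allclose(out, img))                               # True: identity grid returns the image
```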
/test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torchvision import utils
4 | from utils.helper_functions import *
5 | import utils.visualization as visualization
6 | import utils.filename_templates as TEMPLATES
7 | import utils.helper_functions as helper
8 | import utils.logger as logger
9 | from utils import image_utils
10 |
11 | class Test(object):
12 | """
13 | Test class
14 |
15 | """
16 |
17 | def __init__(self, model, config,
18 | data_loader, visualizer, test_logger=None, save_path=None, additional_args=None):
19 | self.downsample = False # Downsampling for Rebuttal
20 | self.model = model
21 | self.config = config
22 | self.data_loader = data_loader
23 | self.additional_args = additional_args
24 |         if config['cuda'] and not torch.cuda.is_available():
25 |             print('Warning: There\'s no CUDA support on this machine, '
26 |                   'testing is performed on CPU.')
27 | else:
28 | self.gpu = torch.device('cuda:' + str(config['gpu']))
29 | self.model = self.model.to(self.gpu)
30 | if save_path is None:
31 | self.save_path = helper.create_save_path(config['save_dir'].lower(),
32 | config['name'].lower())
33 | else:
34 | self.save_path=save_path
35 |         if test_logger is None:
36 | self.logger = logger.Logger(self.save_path)
37 | else:
38 | self.logger = test_logger
39 | if isinstance(self.additional_args, dict) and 'name_mapping_test' in self.additional_args.keys():
40 | visu_add_args = {'name_mapping' : self.additional_args['name_mapping_test']}
41 | else:
42 | visu_add_args = None
43 | self.visualizer = visualizer(data_loader, self.save_path, additional_args=visu_add_args)
44 |
45 | def summary(self):
46 | self.logger.write_line("====================================== TEST SUMMARY ======================================", True)
47 | self.logger.write_line("Model:\t\t\t" + self.model.__class__.__name__, True)
48 | self.logger.write_line("Tester:\t\t" + self.__class__.__name__, True)
49 | self.logger.write_line("Test Set:\t" + self.data_loader.dataset.__class__.__name__, True)
50 | self.logger.write_line("\t-Dataset length:\t"+str(len(self.data_loader)), True)
51 | self.logger.write_line("\t-Batch size:\t\t" + str(self.data_loader.batch_size), True)
52 | self.logger.write_line("==========================================================================================", True)
53 |
54 |     def run_network(self, batch):
55 | raise NotImplementedError
56 |
57 | def move_batch_to_cuda(self, batch):
58 | raise NotImplementedError
59 |
60 | def visualize_sample(self, batch):
61 | self.visualizer(batch)
62 |
63 | def visualize_sample_dsec(self, batch, batch_idx):
64 | self.visualizer(batch, batch_idx, None)
65 |
66 | def get_estimation_and_target(self, batch):
67 | # Returns the estimation and target of the current batch
68 | raise NotImplementedError
69 |
70 | def _test(self):
71 | """
72 | Validate after training an epoch
73 |
74 | :return: A log that contains information about validation
75 |
76 | Note:
77 | The validation metrics in log must have the key 'val_metrics'.
78 | """
79 | self.model.eval()
80 | with torch.no_grad():
81 | for batch_idx, batch in enumerate(self.data_loader):
82 | # Move Data to GPU
83 | if next(self.model.parameters()).is_cuda:
84 | batch = self.move_batch_to_cuda(batch)
85 | # Network Forward Pass
86 | self.run_network(batch)
87 | print("Sample {}/{}".format(batch_idx + 1, len(self.data_loader)))
88 |
89 | # Visualize
90 |                 if (hasattr(batch, 'keys') and 'loader_idx' in batch.keys()) \
91 |                         or (isinstance(batch, list) and hasattr(batch[0], 'keys') and 'loader_idx' in batch[0].keys()):
92 | self.visualize_sample(batch)
93 | else:
94 |                     # DSEC batches carry no 'loader_idx', so index by batch position
95 |                     self.visualize_sample_dsec(batch, batch_idx)
97 |
98 | # Log Generation
99 | log = {}
100 |
101 | return log
102 |
103 | class TestRaftEvents(Test):
104 | def move_batch_to_cuda(self, batch):
105 | return move_dict_to_cuda(batch, self.gpu)
106 |
107 | def get_estimation_and_target(self, batch):
108 | if not self.downsample:
109 | if 'gt_valid_mask' in batch.keys():
110 | return batch['flow_est'].cpu().data, (batch['flow'].cpu().data, batch['gt_valid_mask'].cpu().data)
111 | return batch['flow_est'].cpu().data, batch['flow'].cpu().data
112 | else:
113 | f_est = batch['flow_est'].cpu().data
114 | f_gt = torch.nn.functional.interpolate(batch['flow'].cpu().data, scale_factor=0.5)
115 | if 'gt_valid_mask' in batch.keys():
116 | f_mask = torch.nn.functional.interpolate(batch['gt_valid_mask'].cpu().data, scale_factor=0.5)
117 | return f_est, (f_gt, f_mask)
118 | return f_est, f_gt
119 |
120 | def run_network(self, batch):
121 | # RAFT just expects two images as input. cleanest. code. ever.
122 | if not self.downsample:
123 | im1 = batch['event_volume_old']
124 | im2 = batch['event_volume_new']
125 | else:
126 | im1 = torch.nn.functional.interpolate(batch['event_volume_old'], scale_factor=0.5)
127 | im2 = torch.nn.functional.interpolate(batch['event_volume_new'], scale_factor=0.5)
128 | _, batch['flow_list'] = self.model(image1=im1,
129 | image2=im2)
130 | batch['flow_est'] = batch['flow_list'][-1]
131 |
132 | class TestRaftEventsWarm(Test):
133 | def __init__(self, model, config,
134 | data_loader, visualizer, test_logger=None, save_path=None, additional_args=None):
135 | super(TestRaftEventsWarm, self).__init__(model, config,
136 | data_loader, visualizer, test_logger, save_path,
137 | additional_args=additional_args)
138 | self.subtype = config['subtype'].lower()
139 | print('Tester Subtype: {}'.format(self.subtype))
140 | self.net_init = None # Hidden state of the refinement GRU
141 | self.flow_init = None
142 | self.idx_prev = None
143 | self.init_print=False
144 | assert self.data_loader.batch_size == 1, 'Batch size for recurrent testing must be 1'
145 |
146 | def move_batch_to_cuda(self, batch):
147 | return move_list_to_cuda(batch, self.gpu)
148 |
149 | def get_estimation_and_target(self, batch):
150 | if not self.downsample:
151 | if 'gt_valid_mask' in batch[-1].keys():
152 | return batch[-1]['flow_est'].cpu().data, (batch[-1]['flow'].cpu().data, batch[-1]['gt_valid_mask'].cpu().data)
153 | return batch[-1]['flow_est'].cpu().data, batch[-1]['flow'].cpu().data
154 | else:
155 | f_est = batch[-1]['flow_est'].cpu().data
156 | f_gt = torch.nn.functional.interpolate(batch[-1]['flow'].cpu().data, scale_factor=0.5)
157 | if 'gt_valid_mask' in batch[-1].keys():
158 | f_mask = torch.nn.functional.interpolate(batch[-1]['gt_valid_mask'].cpu().data, scale_factor=0.5)
159 | return f_est, (f_gt, f_mask)
160 | return f_est, f_gt
161 |
162 | def visualize_sample(self, batch):
163 | self.visualizer(batch[-1])
164 |
165 | def visualize_sample_dsec(self, batch, batch_idx):
166 | self.visualizer(batch[-1], batch_idx, None)
167 |
168 | def check_states(self, batch):
169 | # 0th case: there is a flag in the batch that tells us to reset the state (DSEC)
170 | if 'new_sequence' in batch[0].keys():
171 | if batch[0]['new_sequence'].item() == 1:
172 | self.flow_init = None
173 | self.net_init = None
174 | self.logger.write_line("Resetting States!", True)
175 | else:
176 | # During Validation, reset state if a new scene starts (index jump)
177 | if self.idx_prev is not None and batch[0]['idx'].item() - self.idx_prev != 1:
178 | self.flow_init = None
179 | self.net_init = None
180 | self.logger.write_line("Resetting States!", True)
181 | self.idx_prev = batch[0]['idx'].item()
182 |
183 | def run_network(self, batch):
184 | self.check_states(batch)
185 | for l in range(len(batch)):
186 | # Run Recurrent Network for this sample
187 |
188 | if not self.downsample:
189 | im1 = batch[l]['event_volume_old']
190 | im2 = batch[l]['event_volume_new']
191 | else:
192 | im1 = torch.nn.functional.interpolate(batch[l]['event_volume_old'], scale_factor=0.5)
193 | im2 = torch.nn.functional.interpolate(batch[l]['event_volume_new'], scale_factor=0.5)
194 | flow_low_res, batch[l]['flow_list'] = self.model(image1=im1,
195 | image2=im2,
196 | flow_init=self.flow_init)
197 |
198 | batch[l]['flow_est'] = batch[l]['flow_list'][-1]
199 | self.flow_init = image_utils.forward_interpolate_pytorch(flow_low_res)
200 | batch[l]['flow_init'] = self.flow_init
201 |
--------------------------------------------------------------------------------
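A toy sketch of the (estimation, target) contract returned by `get_estimation_and_target` above: when a validity mask is present, the target is a tuple, and a downstream endpoint-error metric (not part of this file) would consume it roughly as follows:

```python
import torch

flow_est = torch.zeros(1, 2, 4, 4)                # placeholder network output
flow_gt = torch.ones(1, 2, 4, 4)                  # placeholder ground truth
mask = torch.ones(1, 4, 4).bool()                 # gt_valid_mask analogue

epe = torch.norm(flow_est - flow_gt, p=2, dim=1)  # per-pixel endpoint error, [1, 4, 4]
print(epe[mask].mean().item())                    # ~1.4142 on this toy data
```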
/utils/dsec_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from enum import Enum, auto
4 |
5 |
6 | class RepresentationType(Enum):
7 | VOXEL = auto()
8 | STEPAN = auto()
9 |
10 |
11 | class EventRepresentation:
12 | def __init__(self):
13 | pass
14 |
15 | def convert(self, events):
16 | raise NotImplementedError
17 |
18 |
19 | class VoxelGrid(EventRepresentation):
20 | def __init__(self, input_size: tuple, normalize: bool):
21 | assert len(input_size) == 3
22 | self.voxel_grid = torch.zeros((input_size), dtype=torch.float, requires_grad=False)
23 | self.nb_channels = input_size[0]
24 | self.normalize = normalize
25 |
26 | def convert(self, events):
27 | C, H, W = self.voxel_grid.shape
28 | with torch.no_grad():
29 | self.voxel_grid = self.voxel_grid.to(events['p'].device)
30 | voxel_grid = self.voxel_grid.clone()
31 |
32 | t_norm = events['t']
33 | t_norm = (C - 1) * (t_norm-t_norm[0]) / (t_norm[-1]-t_norm[0])
34 |
35 | x0 = events['x'].int()
36 | y0 = events['y'].int()
37 | t0 = t_norm.int()
38 |
39 | value = 2*events['p']-1
40 |
41 | for xlim in [x0,x0+1]:
42 | for ylim in [y0,y0+1]:
43 | for tlim in [t0,t0+1]:
44 |
45 | mask = (xlim < W) & (xlim >= 0) & (ylim < H) & (ylim >= 0) & (tlim >= 0) & (tlim < self.nb_channels)
46 | interp_weights = value * (1 - (xlim-events['x']).abs()) * (1 - (ylim-events['y']).abs()) * (1 - (tlim - t_norm).abs())
47 |
48 | index = H * W * tlim.long() + \
49 | W * ylim.long() + \
50 | xlim.long()
51 |
52 | voxel_grid.put_(index[mask], interp_weights[mask], accumulate=True)
53 |
54 | if self.normalize:
55 | mask = torch.nonzero(voxel_grid, as_tuple=True)
56 | if mask[0].size()[0] > 0:
57 | mean = voxel_grid[mask].mean()
58 | std = voxel_grid[mask].std()
59 | if std > 0:
60 | voxel_grid[mask] = (voxel_grid[mask] - mean) / std
61 | else:
62 | voxel_grid[mask] = voxel_grid[mask] - mean
63 |
64 | return voxel_grid
65 |
66 | def flow_16bit_to_float(flow_16bit: np.ndarray):
67 | assert flow_16bit.dtype == np.uint16
68 | assert flow_16bit.ndim == 3
69 | h, w, c = flow_16bit.shape
70 | assert c == 3
71 |
72 | valid2D = flow_16bit[..., 2] == 1
73 | assert valid2D.shape == (h, w)
74 | assert np.all(flow_16bit[~valid2D, -1] == 0)
75 | valid_map = np.where(valid2D)
76 |
77 | # to actually compute something useful:
78 | flow_16bit = flow_16bit.astype('float')
79 |
80 | flow_map = np.zeros((h, w, 2))
81 | flow_map[valid_map[0], valid_map[1], 0] = (flow_16bit[valid_map[0], valid_map[1], 0] - 2 ** 15) / 128
82 | flow_map[valid_map[0], valid_map[1], 1] = (flow_16bit[valid_map[0], valid_map[1], 1] - 2 ** 15) / 128
83 | return flow_map, valid2D
84 |
--------------------------------------------------------------------------------
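A round-trip sketch for the 16-bit DSEC flow encoding decoded by `flow_16bit_to_float` above; the encoding step mirrors the one used for benchmark submissions in utils/visualization.py. Values are random:

```python
import numpy as np
from utils.dsec_utils import flow_16bit_to_float

flow = np.random.uniform(-50, 50, size=(2, 4, 5))          # random float flow, [2, H, W]

# Encode as DSEC does: I(u,v,0:2) = flow * 128 + 2^15, I(u,v,2) = validity flag
encoded = np.rint(flow * 128 + 2 ** 15).astype(np.uint16).transpose(1, 2, 0)
valid = np.ones((4, 5, 1), dtype=np.uint16)
flow_16bit = np.concatenate([encoded, valid], axis=-1)

decoded, valid2D = flow_16bit_to_float(flow_16bit)
print(np.abs(decoded - flow.transpose(1, 2, 0)).max() <= 1 / 256)  # True: quantization bound
```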
/utils/filename_templates.py:
--------------------------------------------------------------------------------
1 | # =========================== Templates for saving Images ========================= #
2 | GT_FLOW = '{}_{}_flow_gt.png'
3 | FLOW_TEST = '{}_{}_flow.png'
4 | IMG = '{}_{}_image.png'
5 | EVENTS = '{}_{}_events.png'
6 |
7 | # ========================= Templates for saving Checkpoints ====================== #
8 | CHECKPOINT = '{:03d}_checkpoint.tar'
9 |
10 | # ========================= MVSEC DATALOADING ====================== #
11 | MVSEC_DATASET_FOLDER = '{}_{}'
12 | MVSEC_TIMESTAMPS_PATH_DEPTH = 'timestamps_depth.txt'
13 | MVSEC_TIMESTAMPS_PATH_FLOW = 'timestamps_flow.txt'
14 | MVSEC_TIMESTAMPS_PATH_IMAGES = 'timestamps_images.txt'
15 | MVSEC_EVENTS_FILE = "davis/{}/events/{:06d}.h5"
16 | MVSEC_FLOW_GT_FILE = "optical_flow/{:06d}.npy"
17 |
--------------------------------------------------------------------------------
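A quick sketch of how these templates expand; the camera name and indices are illustrative:

```python
import utils.filename_templates as TEMPLATES

print(TEMPLATES.CHECKPOINT.format(7))                   # 007_checkpoint.tar
print(TEMPLATES.MVSEC_EVENTS_FILE.format('left', 42))   # davis/left/events/000042.h5
```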
/utils/helper_functions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import smtplib
4 | import json
5 |
6 | def move_dict_to_cuda(dictionary_of_tensors, gpu):
7 | if isinstance(dictionary_of_tensors, dict):
8 | return {
9 | key: move_dict_to_cuda(value, gpu)
10 | for key, value in dictionary_of_tensors.items()
11 | }
12 | return dictionary_of_tensors.to(gpu, dtype=torch.float)
13 |
14 | def move_list_to_cuda(list_of_dicts, gpu):
15 | for i in range(len(list_of_dicts)):
16 | list_of_dicts[i] = move_dict_to_cuda(list_of_dicts[i], gpu)
17 | return list_of_dicts
18 |
19 | def get_values_from_key(input_list, key):
20 | # Returns all the values with the same key from
21 | # a list filled with dicts of the same kind
22 | out = []
23 | for i in input_list:
24 | out.append(i[key])
25 | return out
26 |
27 | def create_save_path(subdir, name):
28 | # Check if sub-folder exists, and create if necessary
29 | if not os.path.exists(subdir):
30 | os.mkdir(subdir)
31 | # Create a new folder (named after the name defined in the config file)
32 | path = os.path.join(subdir, name)
33 | # Check if path already exists. if yes -> append a number
34 | if os.path.exists(path):
35 | i = 1
36 | while os.path.exists(path + "_" + str(i)):
37 | i += 1
38 | path = path + '_' + str(i)
39 | os.mkdir(path)
40 | return path
41 |
42 | def get_nth_element_of_all_dict_keys(dict, idx):
43 | out_dict = {}
44 | for k in dict.keys():
45 | d = dict[k][idx]
46 | if isinstance(d,torch.Tensor):
47 | out_dict[k]=d.detach().cpu().item()
48 | else:
49 | out_dict[k]=d
50 | return out_dict
51 |
52 | def get_number_of_saved_elements(path, template, first=1):
53 | i = first
54 | while True:
55 | if os.path.exists(os.path.join(path,template.format(i))):
56 | i+=1
57 | else:
58 | break
59 | return range(first, i)
60 |
61 | def create_file_path(subdir, name):
62 | # Check if sub-folder exists, else raise exception
63 | if not os.path.exists(subdir):
64 | raise Exception("Path {} does not exist!".format(subdir))
65 | # Check if file already exists, else create path
66 | if not os.path.exists(os.path.join(subdir,name)):
67 | return os.path.join(subdir,name)
68 | else:
69 | path = os.path.join(subdir,name)
70 | prefix,suffix = path.split('.')
71 | i = 1
72 | while os.path.exists("{}_{}.{}".format(prefix,i,suffix)):
73 | i += 1
74 | return "{}_{}.{}".format(prefix,i,suffix)
75 |
76 | def update_dict(dict_old, dict_new):
77 | # Update all the entries of dict_old with the new values(that have the identical keys) of dict_new
78 | for k in dict_new.keys():
79 | if k in dict_old.keys():
80 | # Replace the entry
81 | if isinstance(dict_new[k], dict):
82 | update_dict(dict_old[k], dict_new[k])
83 | else:
84 | dict_old[k] = dict_new[k]
85 | return dict_old
86 |
--------------------------------------------------------------------------------
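A small sketch of the collision handling in `create_save_path` above: when the target folder already exists, a numeric suffix is appended. Folder names are placeholders:

```python
import utils.helper_functions as helper

p1 = helper.create_save_path('saved', 'example')   # -> saved/example
p2 = helper.create_save_path('saved', 'example')   # -> saved/example_1
print(p1, p2)
```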
/utils/image_utils.py:
--------------------------------------------------------------------------------
1 | import numpy
2 | import torch
3 | from torch import nn
4 | from torch.nn.functional import grid_sample
5 | from scipy.spatial import transform
6 | from scipy import interpolate
7 | from matplotlib import pyplot as plt
8 |
9 |
10 | def grid_sample_values(input, height, width):
11 | # ================================ Grid Sample Values ============================= #
12 |     # Input: Torch Tensor [3,H*W], where the 3 dimensions mean [x,y,z]       #
13 | # Height: Image Height #
14 | # Width: Image Width #
15 | # --------------------------------------------------------------------------------- #
16 | # Output: tuple(value_ipl, valid_mask) #
17 | # value_ipl -> [H,W]: Interpolated values #
18 | # valid_mask -> [H,W]: 1: Point is valid, 0: Point is invalid #
19 | # ================================================================================= #
20 | device = input.device
21 | ceil = torch.stack([torch.ceil(input[0,:]), torch.ceil(input[1,:]), input[2,:]])
22 | floor = torch.stack([torch.floor(input[0,:]), torch.floor(input[1,:]), input[2,:]])
23 | z = input[2,:].clone()
24 |
25 | values_ipl = torch.zeros(height*width, device=device)
26 | weights_acc = torch.zeros(height*width, device=device)
27 | # Iterate over all ceil/floor points
28 | for x_vals in [floor[0], ceil[0]]:
29 | for y_vals in [floor[1], ceil[1]]:
30 | # Mask Points that are in the image
31 | in_bounds_mask = (x_vals < width) & (x_vals >=0) & (y_vals < height) & (y_vals >= 0)
32 |
33 | # Calculate weights, according to their real distance to the floored/ceiled value
34 | weights = (1 - (input[0]-x_vals).abs()) * (1 - (input[1]-y_vals).abs())
35 |
36 | # Put them into the right grid
37 | indices = (x_vals + width * y_vals).long()
38 | values_ipl.put_(indices[in_bounds_mask], (z * weights)[in_bounds_mask], accumulate=True)
39 | weights_acc.put_(indices[in_bounds_mask], weights[in_bounds_mask], accumulate=True)
40 |
41 | # Mask of valid pixels -> Everywhere where we have an interpolated value
42 | valid_mask = weights_acc.clone()
43 | valid_mask[valid_mask > 0] = 1
44 | valid_mask= valid_mask.bool().reshape([height,width])
45 |
46 | # Divide by weights to get interpolated values
47 | values_ipl = values_ipl / (weights_acc + 1e-15)
48 | values_rs = values_ipl.reshape([height,width])
49 |
50 | return values_rs.unsqueeze(0).clone(), valid_mask.unsqueeze(0).clone()
51 |
52 | def forward_interpolate_pytorch(flow_in):
53 | # Same as the numpy implementation, but differentiable :)
54 | # Flow: [B,2,H,W]
55 | flow = flow_in.clone()
56 | if len(flow.shape) < 4:
57 | flow = flow.unsqueeze(0)
58 |
59 | b, _, h, w = flow.shape
60 | device = flow.device
61 |
62 |         dx, dy = flow[:,0], flow[:,1]
63 | y0, x0 = torch.meshgrid(torch.arange(0, h, 1), torch.arange(0, w, 1))
64 | x0 = torch.stack([x0]*b).to(device)
65 | y0 = torch.stack([y0]*b).to(device)
66 |
67 | x1 = x0 + dx
68 | y1 = y0 + dy
69 |
70 | x1 = x1.flatten(start_dim=1)
71 | y1 = y1.flatten(start_dim=1)
72 | dx = dx.flatten(start_dim=1)
73 | dy = dy.flatten(start_dim=1)
74 |
75 |         # Interpolate the scattered flow values onto the regular grid.
76 |         # Nearest-neighbour interpolation would be preferable, but no native PyTorch
77 |         # function exists yet. See issue: https://github.com/pytorch/pytorch/issues/50339
78 | flow_new = torch.zeros(flow.shape, device=device)
79 | for i in range(b):
80 | flow_new[i,0] = grid_sample_values(torch.stack([x1[i],y1[i],dx[i]]), h, w)[0]
81 | flow_new[i,1] = grid_sample_values(torch.stack([x1[i],y1[i],dy[i]]), h, w)[0]
82 |
83 | return flow_new
84 |
85 | class ImagePadder(object):
86 | # =================================================================== #
87 | # In some networks, the image gets downsized. This is a problem, if #
88 | # the to-be-downsized image has odd dimensions ([15x20]->[7.5x10]). #
89 | # To prevent this, the input image of the network needs to be a #
90 | # multiple of a minimum size (min_size) #
91 | # The ImagePadder makes sure, that the input image is of such a size, #
92 | # and if not, it pads the image accordingly. #
93 | # =================================================================== #
94 |
95 | def __init__(self, min_size=64):
96 | # --------------------------------------------------------------- #
97 | # The min_size additionally ensures, that the smallest image #
98 | # does not get too small #
99 | # --------------------------------------------------------------- #
100 | self.min_size = min_size
101 | self.pad_height = None
102 | self.pad_width = None
103 |
104 | def pad(self, image):
105 | # --------------------------------------------------------------- #
106 | # If necessary, this function pads the image on the left & top #
107 | # --------------------------------------------------------------- #
108 | height, width = image.shape[-2:]
109 | if self.pad_width is None:
110 | self.pad_height = (self.min_size - height % self.min_size)%self.min_size
111 | self.pad_width = (self.min_size - width % self.min_size)%self.min_size
112 | else:
113 | pad_height = (self.min_size - height % self.min_size)%self.min_size
114 | pad_width = (self.min_size - width % self.min_size)%self.min_size
115 |             if pad_height != self.pad_height or pad_width != self.pad_width:
116 |                 raise ValueError('Image size changed between calls of ImagePadder.pad')
117 | return nn.ZeroPad2d((self.pad_width, 0, self.pad_height, 0))(image)
118 |
119 | def unpad(self, image):
120 | # --------------------------------------------------------------- #
121 | # Removes the padded rows & columns #
122 | # --------------------------------------------------------------- #
123 | return image[..., self.pad_height:, self.pad_width:]
124 |
--------------------------------------------------------------------------------
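A usage sketch for `ImagePadder` above, assuming min_size=32 and the 260x346 DAVIS resolution used by MVSEC (a 480x640 DSEC frame would need no padding):

```python
import torch
from utils.image_utils import ImagePadder

padder = ImagePadder(min_size=32)
x = torch.randn(1, 15, 260, 346)          # MVSEC/DAVIS resolution, illustrative channels
padded = padder.pad(x)
print(padded.shape)                       # torch.Size([1, 15, 288, 352])
print(padder.unpad(padded).shape)         # torch.Size([1, 15, 260, 346])
```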
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import numpy
4 | import shutil
5 |
6 | class Logger:
7 | # Logger of the Training/Testing Process
8 | def __init__(self, save_path, custom_name='log.txt'):
9 | self.toWrite = {}
10 | self.signalization = "========================================"
11 | self.path = os.path.join(save_path,custom_name)
12 |
13 | def initialize_file(self, mode):
14 | # Mode : "Training" or "Testing"
15 | with open(self.path, 'a') as file:
16 | file.write(self.signalization + " " + mode + " " + self.signalization + "\n")
17 |
18 | def write_as_list(self, dict_to_write, overwrite=False):
19 | if overwrite:
20 | if os.path.exists(self.path):
21 | os.remove(self.path)
22 | with open(self.path, 'a') as file:
23 | for entry in dict_to_write.keys():
24 | file.write(entry+"="+json.dumps(dict_to_write[entry])+"\n")
25 |
26 | def write_dict(self, dict_to_write, array_names=None, overwrite=False, as_list=False):
27 | if overwrite:
28 | open_type = 'w'
29 | else:
30 | open_type = 'a'
31 | dict_to_write = self.check_for_arrays(dict_to_write, array_names)
32 | if as_list:
33 | self.write_as_list(dict_to_write, overwrite)
34 | else:
35 | with open(self.path, open_type) as file:
36 | #if "epoch" in dict_to_write:
37 | # file.write("Epoch")
38 | file.write(json.dumps(dict_to_write) + "\n")
39 |
40 | def write_line(self,line, verbose=False):
41 | with open(self.path, 'a') as file:
42 | file.write(line + "\n")
43 | if verbose:
44 | print(line)
45 |
46 | def arrays_to_dicts(self, list_of_arrays, array_name, entry_name):
47 | list_of_arrays = numpy.array(list_of_arrays).T
48 | out = {}
49 | for i in range(list_of_arrays.shape[0]):
50 | out[array_name+'_'+entry_name[i]] = list(list_of_arrays[i])
51 | return out
52 |
53 |
54 | def check_for_arrays(self, dict_to_write, array_names):
55 | if array_names is not None:
56 | names = []
57 | for n in range(len(array_names)):
58 | if hasattr(array_names[n], 'name'):
59 | names.append(array_names[n].name)
60 | elif hasattr(array_names[n],'__name__'):
61 | names.append(array_names[n].__name__)
62 | elif hasattr(array_names[n],'__class__'):
63 | names.append(array_names[n].__class__.__name__)
64 | else:
65 | names.append(array_names[n])
66 |
67 | keys = dict_to_write.keys()
68 | out = {}
69 | for entry in keys:
70 | if hasattr(dict_to_write[entry], '__len__') and len(dict_to_write[entry])>0:
71 | if isinstance(dict_to_write[entry][0], numpy.ndarray) or isinstance(dict_to_write[entry][0], list):
72 | out.update(self.arrays_to_dicts(dict_to_write[entry], entry, names))
73 | else:
74 | out.update({entry:dict_to_write[entry]})
75 | else:
76 | out.update({entry: dict_to_write[entry]})
77 | return out
78 |
--------------------------------------------------------------------------------
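A minimal usage sketch for the `Logger` above; the save path is a placeholder and is created up front so the log file can be appended to:

```python
import os
from utils.logger import Logger

save_path = 'saved/example_run'           # placeholder
os.makedirs(save_path, exist_ok=True)

logger = Logger(save_path)                # appends to saved/example_run/log.txt
logger.initialize_file('test')
logger.write_line('Resetting States!', verbose=True)
logger.write_dict({'epe': 0.79, 'epoch': 1})
```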
/utils/mvsec_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2  # required by prop_flow (cv2.remap)
3 | from matplotlib import pyplot as plt
4 | import os
5 | from utils import filename_templates as TEMPLATES
6 |
7 | def prop_flow(x_flow, y_flow, x_indices, y_indices, x_mask, y_mask, scale_factor=1.0):
8 | flow_x_interp = cv2.remap(x_flow,
9 | x_indices,
10 | y_indices,
11 | cv2.INTER_NEAREST)
12 |
13 | flow_y_interp = cv2.remap(y_flow,
14 | x_indices,
15 | y_indices,
16 | cv2.INTER_NEAREST)
17 |
18 | x_mask[flow_x_interp == 0] = False
19 | y_mask[flow_y_interp == 0] = False
20 |
21 | x_indices += flow_x_interp * scale_factor
22 | y_indices += flow_y_interp * scale_factor
23 |
24 | return
25 |
26 | def estimate_corresponding_gt_flow(path_flow,
27 | gt_timestamps,
28 | start_time,
29 | end_time):
30 | # Each gt flow at timestamp gt_timestamps[gt_iter] represents the displacement between
31 | # gt_iter and gt_iter+1.
32 | # gt_timestamps[gt_iter] -> Timestamp just before start_time
33 |
34 | gt_iter = np.searchsorted(gt_timestamps, start_time, side='right') - 1
35 | gt_dt = gt_timestamps[gt_iter + 1] - gt_timestamps[gt_iter]
36 |
37 | # Load Flow just before start_time
38 | flow_file = os.path.join(path_flow, TEMPLATES.MVSEC_FLOW_GT_FILE.format(gt_iter))
39 | flow = np.load(flow_file)
40 |
41 | x_flow = flow[0]
42 | y_flow = flow[1]
43 | #x_flow = np.squeeze(x_flow_in[gt_iter, ...])
44 | #y_flow = np.squeeze(y_flow_in[gt_iter, ...])
45 |
46 | dt = end_time - start_time
47 |
48 | # No need to propagate if the desired dt is shorter than the time between gt timestamps.
49 | if gt_dt > dt:
50 | return x_flow * dt / gt_dt, y_flow * dt / gt_dt
51 |     else:
52 |         raise Exception('Propagating GT flow across multiple frames is not supported here')
53 |
--------------------------------------------------------------------------------
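A worked example of the GT-flow scaling in `estimate_corresponding_gt_flow` above, assuming the requested window falls inside a single ground-truth interval (timestamps are illustrative):

```python
import numpy as np

gt_timestamps = np.array([0.00, 0.05, 0.10])   # GT flow stored at 20 Hz (illustrative)
start_time, end_time = 0.01, 0.03              # requested window, dt = 0.02 s

gt_iter = np.searchsorted(gt_timestamps, start_time, side='right') - 1   # -> 0
gt_dt = gt_timestamps[gt_iter + 1] - gt_timestamps[gt_iter]              # -> 0.05
dt = end_time - start_time                                               # -> 0.02
print(gt_iter, dt / gt_dt)   # 0 ~0.4 -> the stored flow gets scaled by dt/gt_dt
```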
/utils/transformers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch as th
3 |
4 | def dictionary_of_numpy_arrays_to_tensors(sample):
5 | """Transforms dictionary of numpy arrays to dictionary of tensors."""
6 | if isinstance(sample, dict):
7 | return {
8 | key: dictionary_of_numpy_arrays_to_tensors(value)
9 | for key, value in sample.items()
10 | }
11 | if isinstance(sample, np.ndarray):
12 | if len(sample.shape) == 2:
13 | return th.from_numpy(sample).float().unsqueeze(0)
14 | else:
15 | return th.from_numpy(sample).float()
16 | return sample
17 |
18 | class EventSequenceToVoxelGrid_Pytorch(object):
19 | # Source: https://github.com/uzh-rpg/rpg_e2vid/blob/master/utils/inference_utils.py#L480
20 | def __init__(self, num_bins, gpu=False, gpu_nr=0, normalize=True, forkserver=True):
21 | if forkserver:
22 | try:
23 | th.multiprocessing.set_start_method('forkserver')
24 | except RuntimeError:
25 | pass
26 | self.num_bins = num_bins
27 | self.normalize = normalize
28 | if gpu:
29 | if not th.cuda.is_available():
30 | print('Warning: There\'s no CUDA support on this machine!')
31 | else:
32 | self.device = th.device('cuda:' + str(gpu_nr))
33 | else:
34 | self.device = th.device('cpu')
35 |
36 |     def __call__(self, event_sequence):
37 |         """
38 |         Build a voxel grid with bilinear interpolation in the time domain from a set of events.
39 |         :param event_sequence: EventSequence whose .features is an [N x 4] NumPy array with
40 |                                one event per row in the form: [timestamp, x, y, polarity]
41 |         :return voxel_grid: PyTorch event tensor of shape [num_bins, height, width]
42 |                             (on the device specified in the constructor)
43 |         """
45 |
46 | events = event_sequence.features.astype('float')
47 |
48 | width = event_sequence.image_width
49 | height = event_sequence.image_height
50 |
51 | assert (events.shape[1] == 4)
52 | assert (self.num_bins > 0)
53 | assert (width > 0)
54 | assert (height > 0)
55 |
56 | with th.no_grad():
57 |
58 | events_torch = th.from_numpy(events)
59 | # with DeviceTimer('Events -> Device (voxel grid)'):
60 | events_torch = events_torch.to(self.device)
61 |
62 | # with DeviceTimer('Voxel grid voting'):
63 | voxel_grid = th.zeros(self.num_bins, height, width, dtype=th.float32, device=self.device).flatten()
64 |
65 | # normalize the event timestamps so that they lie between 0 and num_bins
66 | last_stamp = events_torch[-1, 0]
67 | first_stamp = events_torch[0, 0]
68 |
69 | assert last_stamp.dtype == th.float64, 'Timestamps must be float64!'
70 | # assert last_stamp.item()%1 == 0, 'Timestamps should not have decimals'
71 |
72 | deltaT = last_stamp - first_stamp
73 |
74 | if deltaT == 0:
75 | deltaT = 1.0
76 |
77 | events_torch[:, 0] = (self.num_bins - 1) * (events_torch[:, 0] - first_stamp) / deltaT
78 | ts = events_torch[:, 0]
79 | xs = events_torch[:, 1].long()
80 | ys = events_torch[:, 2].long()
81 | pols = events_torch[:, 3].float()
82 | pols[pols == 0] = -1 # polarity should be +1 / -1
83 |
84 |
85 | tis = th.floor(ts)
86 | tis_long = tis.long()
87 | dts = ts - tis
88 | vals_left = pols * (1.0 - dts.float())
89 | vals_right = pols * dts.float()
90 |
91 | valid_indices = tis < self.num_bins
92 | valid_indices &= tis >= 0
93 |
94 | if events_torch.is_cuda:
95 | datatype = th.cuda.LongTensor
96 | else:
97 | datatype = th.LongTensor
98 |
99 | voxel_grid.index_add_(dim=0,
100 | index=(xs[valid_indices] + ys[valid_indices]
101 | * width + tis_long[valid_indices] * width * height).type(
102 | datatype),
103 | source=vals_left[valid_indices])
104 |
105 |
106 | valid_indices = (tis + 1) < self.num_bins
107 | valid_indices &= tis >= 0
108 |
109 | voxel_grid.index_add_(dim=0,
110 | index=(xs[valid_indices] + ys[valid_indices] * width
111 | + (tis_long[valid_indices] + 1) * width * height).type(datatype),
112 | source=vals_right[valid_indices])
113 |
114 | voxel_grid = voxel_grid.view(self.num_bins, height, width)
115 |
116 | if self.normalize:
117 | mask = th.nonzero(voxel_grid, as_tuple=True)
118 | if mask[0].size()[0] > 0:
119 | mean = voxel_grid[mask].mean()
120 | std = voxel_grid[mask].std()
121 | if std > 0:
122 | voxel_grid[mask] = (voxel_grid[mask] - mean) / std
123 | else:
124 | voxel_grid[mask] = voxel_grid[mask] - mean
125 |
126 | return voxel_grid
127 |
--------------------------------------------------------------------------------
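A self-contained sketch of the voxel-grid transformer above. Since only `.features`, `.image_width`, and `.image_height` are accessed, the EventSequence input is mimicked here by a tiny stand-in class; the event values are synthetic:

```python
import numpy as np
from utils.transformers import EventSequenceToVoxelGrid_Pytorch

class _FakeSequence:
    def __init__(self, events, width, height):
        self.features = events            # [N x 4]: timestamp, x, y, polarity
        self.image_width = width
        self.image_height = height

events = np.stack([
    np.linspace(0.0, 1.0, 100),                       # timestamps (float64)
    np.random.randint(0, 346, 100).astype(float),     # x
    np.random.randint(0, 260, 100).astype(float),     # y
    np.random.randint(0, 2, 100).astype(float),       # polarity in {0, 1}
], axis=1)

to_voxel = EventSequenceToVoxelGrid_Pytorch(num_bins=5, normalize=True, forkserver=False)
grid = to_voxel(_FakeSequence(events, width=346, height=260))
print(grid.shape)   # torch.Size([5, 260, 346])
```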
/utils/visualization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from matplotlib import pyplot as plt
3 | from matplotlib import colors
4 | import numpy
5 | import os
6 | import loader.utils as loader
7 | import utils.transformers as transformers
8 | import utils.filename_templates as TEMPLATES
9 | import utils.helper_functions as helper
10 | from matplotlib.lines import Line2D
11 | from skimage.transform import rotate, warp
12 | from skimage import io
13 | import cv2
14 | import imageio
15 | from torchvision.transforms import CenterCrop
16 |
17 | class BaseVisualizer(object):
18 | def __init__(self, dataloader, save_path, additional_args=None):
19 | super(BaseVisualizer, self).__init__()
20 | self.dataloader = dataloader
21 | self.visu_path = helper.create_save_path(save_path, 'visualizations')
22 | self.submission_path = os.path.join(save_path, 'submission')
23 | os.mkdir(self.submission_path)
24 | self.mode = 'test'
25 | self.additional_args = additional_args
26 |
27 | def __call__(self, batch, epoch=None):
28 | for j in batch['loader_idx'].cpu().numpy().astype(int):
29 | # Get Batch Index
30 | batch_idx = torch.nonzero(batch['loader_idx'] == j).item()
31 |
32 | # Visualize Ground Truths, but only in first epoch or if we're cropping
33 | if epoch == 1 or epoch is None or 'crop_window' in batch.keys():
34 | self.visualize_ground_truths(batch, batch_idx=batch_idx, epoch=epoch,
35 | data_aug='crop_window' in batch.keys())
36 |
37 | # Visualize Estimations
38 | self.visualize_estimations(batch, batch_idx=batch_idx, epoch=epoch)
39 |
40 | def visualize_ground_truths(self, batch, batch_idx, epoch=None, data_aug=False):
41 | raise NotImplementedError
42 |
43 | def visualize_estimations(self, batch, batch_idx, epoch=None):
44 | raise NotImplementedError
45 |
46 | def visualize_image(self, image, true_idx, epoch=None, data_aug=False):
47 | true_idx = int(true_idx)
48 | name = TEMPLATES.IMG.format('inference', true_idx)
49 | save_image(os.path.join(self.visu_path, name), image.detach().cpu())
50 |
51 | def visualize_events(self, image, batch, batch_idx, epoch=None, flip_before_crop=True, crop_window=None):
52 | raise NotImplementedError
53 |
54 | def visualize_flow_colours(self, flow, true_idx, epoch=None, data_aug=False, is_gt=False, fix_scaling=10,
55 | custom_name=None, prefix=None, suffix=None, sub_folder=None):
56 | true_idx = int(true_idx)
57 |
58 | if custom_name is None:
59 | name = TEMPLATES.FLOW_TEST.format('inference', true_idx)
60 | else:
61 | name = custom_name
62 | if prefix is not None:
63 | name = prefix + name
64 | if suffix is not None:
65 |                 base, ext = os.path.splitext(name)
66 |                 name = base + suffix + ext
67 | if sub_folder is not None:
68 | name = os.path.join(sub_folder, name)
69 | # Visualize
70 | _, scaling = visualize_optical_flow(flow.detach().cpu().numpy(),
71 | os.path.join(self.visu_path, name),
72 | scaling=fix_scaling)
73 | return scaling
74 |
75 | def visualize_flow_submission(self, seq_name: str, flow: numpy.ndarray, file_index: int):
76 | # flow_u(u,v) = ((float)I(u,v,1)-2^15)/128.0;
77 | # flow_v(u,v) = ((float)I(u,v,2)-2^15)/128.0;
78 | # valid(u,v) = (bool)I(u,v,3);
79 | # [-2**15/128, 2**15/128] = [-256, 256]
80 |         # flow_map_16bit = numpy.rint(flow_map * 128 + 2**15).astype(numpy.uint16)
81 |         _, h, w = flow.shape
82 |         flow_map = numpy.rint(flow * 128 + 2**15)
83 |         flow_map = flow_map.astype(numpy.uint16).transpose(1, 2, 0)
84 |         flow_map = numpy.concatenate((flow_map, numpy.zeros((h, w, 1), dtype=numpy.uint16)), axis=-1)
85 | parent_path = os.path.join(
86 | self.submission_path,
87 | seq_name
88 | )
89 | if not os.path.exists(parent_path):
90 | os.mkdir(parent_path)
91 | file_name = '{:06d}.png'.format(file_index)
92 |
93 |         imageio.imwrite(os.path.join(parent_path, file_name), flow_map, format='PNG-FI')  # 'PNG-FI' requires imageio's freeimage plugin
94 |
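For reference, the inverse of the 16-bit encoding written by `visualize_flow_submission`, following the KITTI-style comments inside the method (the helper name is illustrative and assumes a PNG reader that preserves 16-bit depth):

```python
import imageio
import numpy

def read_flow_submission_sketch(png_path):
    data = imageio.imread(png_path).astype(numpy.float64)
    u = (data[..., 0] - 2**15) / 128.0   # horizontal flow component
    v = (data[..., 1] - 2**15) / 128.0   # vertical flow component
    valid = data[..., 2].astype(bool)    # third channel marks valid pixels
    return u, v, valid
```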
95 | class FlowVisualizerEvents(BaseVisualizer):
96 | def __init__(self, dataloader, save_path, clamp_flow=True, additional_args=None):
97 | super(FlowVisualizerEvents, self).__init__(dataloader, save_path, additional_args=additional_args)
98 | self.flow_scaling = 0
99 | self.clamp_flow = clamp_flow
100 |
101 | def visualize_events(self, image, batch, batch_idx, epoch=None, flip_before_crop=True, crop_window=None):
102 | # Plots Events on top of an Image.
103 | if image is not None:
104 | im = image.detach().cpu()
105 | else:
106 | im = None
107 |
108 | # Load Raw events
109 | events = self.dataloader.dataset.get_events(loader_idx=int(batch['loader_idx'][batch_idx].item()))
110 | name_events = TEMPLATES.EVENTS.format('inference', int(batch['idx'][batch_idx].item()))
111 |
112 | # Event Sequence to Event Image
113 | events = events_to_event_image(events,
114 | int(batch['param_evc']['height'][batch_idx].item()),
115 | int(batch['param_evc']['width'][batch_idx].item()),
116 | im,
117 | crop_window=crop_window,
118 |                                         rotation_angle=None,
119 | horizontal_flip=False,
120 | flip_before_crop=False)
121 | # center-crop 256x256
122 | crop = CenterCrop(256)
123 | events = crop(events)
124 | # Save
125 | save_image(os.path.join(self.visu_path, name_events), events)
126 |
127 | def visualize_ground_truths(self, batch, batch_idx, epoch=None, data_aug=False):
128 | # Visualize Events
129 | if 'image_old' in batch.keys():
130 | image_old = batch['image_old'][batch_idx]
131 | else:
132 | image_old = None
133 | self.visualize_events(image_old, batch, batch_idx, epoch)
134 |
135 | # Visualize Image
136 | '''
137 | if 'image_old' in batch.keys():
138 | self.visualize_image(batch['image_old'][batch_idx], batch['idx'][batch_idx],epoch, data_aug)
139 | '''
140 | # Visualize Flow GT
141 | flow_gt = batch['flow'][batch_idx].clone()
142 | flow_gt[~batch['gt_valid_mask'][batch_idx].bool()] = 0.0
143 | self.flow_scaling = self.visualize_flow_colours(flow_gt, batch['idx'][batch_idx], epoch=epoch,
144 | data_aug=data_aug, is_gt=True, fix_scaling=None, suffix='_gt')
145 |
146 | def visualize_estimations(self, batch, batch_idx, epoch=None):
147 | # Visualize Flow Estimation
148 | if self.clamp_flow:
149 | scaling = self.flow_scaling[1]
150 | else:
151 | scaling = None
152 | self.visualize_flow_colours(batch['flow_est'][batch_idx], batch['idx'][batch_idx], epoch=epoch,
153 | is_gt=False, fix_scaling=scaling)
154 |
155 | # Visualize Masked Flow
156 | flow_est = batch['flow_est'][batch_idx].clone()
157 | flow_est[~batch['gt_valid_mask'][batch_idx].bool()] = 0.0
158 | self.visualize_flow_colours(flow_est, batch['idx'][batch_idx], epoch=epoch,
159 | is_gt=False, fix_scaling=scaling, suffix='_masked')
160 |
161 | class DsecFlowVisualizer(BaseVisualizer):
162 | def __init__(self, dataloader, save_path, additional_args=None):
163 | super(DsecFlowVisualizer, self).__init__(dataloader, save_path, additional_args=additional_args)
164 | # Create Visu folders for every sequence
165 | for name in self.additional_args['name_mapping']:
166 | os.mkdir(os.path.join(self.visu_path, name))
167 | os.mkdir(os.path.join(self.submission_path, name))
168 |
169 | def visualize_events(self, image, batch, batch_idx, sequence_name):
170 |         sequence_idx = self.additional_args['name_mapping'].index(sequence_name)
171 | delta_t_us = self.dataloader.dataset.datasets[sequence_idx].delta_t_us
172 | loader_instance = self.dataloader.dataset.datasets[sequence_idx]
173 | h, w = loader_instance.get_image_width_height()
174 | events = loader_instance.event_slicer.get_events(
175 | t_start_us=batch['timestamp'][batch_idx].item(),
176 | t_end_us=batch['timestamp'][batch_idx].item()+delta_t_us
177 | )
178 | p = events['p'].astype(numpy.int8)
179 | t = events['t'].astype(numpy.float64)
180 | x = events['x']
181 | y = events['y']
182 | p = 2*p - 1
183 | xy_rect = loader_instance.rectify_events(x, y)
184 | x_rect = numpy.rint(xy_rect[:, 0])
185 | y_rect = numpy.rint(xy_rect[:, 1])
186 |
187 | events_rectified = numpy.stack([t, x_rect, y_rect, p], axis=-1)
188 | event_image = events_to_event_image(
189 | event_sequence=events_rectified,
190 | height=h,
191 | width=w
192 | ).numpy()
193 | name_events = TEMPLATES.EVENTS.format('inference', int(batch['file_index'][batch_idx].item()))
194 | out_path = os.path.join(self.visu_path, sequence_name, name_events)
195 | imageio.imsave(out_path, event_image.transpose(1,2,0))
196 |
197 |     def __call__(self, batch, batch_idx, epoch=None):  # the batch_idx argument is ignored
198 |         for batch_idx in range(len(batch['file_index'])):  # iterate over every sample in the batch
199 | if batch['save_submission'][batch_idx]:
200 | sequence_name = self.additional_args['name_mapping'][int(batch['name_map'][batch_idx].item())]
201 | # Save for Benchmark Submission
202 | self.visualize_flow_submission(
203 | seq_name=sequence_name,
204 | flow=batch['flow_est'][batch_idx].clone().cpu().numpy(),
205 | file_index=int(batch['file_index'][batch_idx].item()),
206 | )
207 | if batch['visualize'][batch_idx]:
208 | sequence_name = self.additional_args['name_mapping'][int(batch['name_map'][batch_idx].item())]
209 | # Visualize Flow
210 | self.visualize_flow_colours(
211 | batch['flow_est'][batch_idx],
212 | batch['file_index'][batch_idx],
213 | epoch=epoch,
214 | is_gt=False,
215 | fix_scaling=None,
216 | sub_folder=sequence_name
217 | )
218 | # Visualize Events
219 | self.visualize_events(
220 | image=None,
221 | batch=batch,
222 | batch_idx=batch_idx,
223 | sequence_name=sequence_name
224 | )
225 |
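`DsecFlowVisualizer.__call__` reads a fixed set of keys from each batch. A hypothetical single-sample batch illustrating the expected schema (shapes and values are made up for illustration; the sequence name is a placeholder):

```python
import torch

batch = {
    'flow_est': torch.zeros(1, 2, 480, 640),      # estimated flow, B x 2 x H x W
    'file_index': torch.tensor([42]),             # used for output file naming
    'timestamp': torch.tensor([1000000]),         # start of the event window (us)
    'name_map': torch.tensor([0]),                # index into additional_args['name_mapping']
    'save_submission': torch.tensor([True]),      # write the 16-bit submission PNG?
    'visualize': torch.tensor([True]),            # write the colour visualisation?
}
# visualizer = DsecFlowVisualizer(dataloader, save_path,
#                                 additional_args={'name_mapping': ['zurich_city_12_a']})
# visualizer(batch, batch_idx=0)
```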
226 | def save_tensor(filepath, tensor):
227 |     cmap = plt.get_cmap('plasma')  # avoid shadowing the builtin 'map'
228 |     arr = tensor[0].numpy()
229 |     image = cmap(arr / arr.max()) * 255
230 | io.imsave(filepath, image.astype(numpy.uint8))
231 |
232 |
233 | def grayscale_to_rgb(tensor, permute=False):
234 | # Tensor [height, width, 3], or
235 | # Tensor [height, width, 1], or
236 | # Tensor [1, height, width], or
237 | # Tensor [3, height, width]
238 |
239 | # if permute -> Convert to [height, width, 3]
240 | if permute:
241 | if tensor.size()[0] < 4:
242 | tensor = tensor.permute(1, 2, 0)
243 | if tensor.size()[2] == 1:
244 | return torch.stack([tensor[:, :, 0]] * 3, dim=2)
245 | else:
246 | return tensor
247 | else:
248 | if tensor.size()[0] == 1:
249 | return torch.stack([tensor[0, :, :]] * 3, dim=0)
250 | else:
251 | return tensor
252 |
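A quick shape check for `grayscale_to_rgb` (illustrative sizes):

```python
import torch

g = torch.zeros(1, 480, 640)                  # single-channel, CHW layout
rgb_chw = grayscale_to_rgb(g)                 # -> (3, 480, 640)
rgb_hwc = grayscale_to_rgb(g, permute=True)   # -> (480, 640, 3)
```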
253 |
254 | def save_image(filepath, tensor):
255 | # Tensor [height, width, 3], or
256 | # Tensor [height, width, 1], or
257 | # Tensor [1, height, width], or
258 | # Tensor [3, height, width]
259 |
260 | # Convert to [height, width, 3]
261 | tensor = grayscale_to_rgb(tensor, True).numpy()
262 |     use_pyplot = False
263 | if use_pyplot:
264 | fig = plt.figure()
265 | # Change Dimensions of Tensor
266 | plot = plt.imshow(tensor.astype(numpy.uint8))
267 | plot.axes.get_xaxis().set_visible(False)
268 | plot.axes.get_yaxis().set_visible(False)
269 | fig.savefig(filepath, bbox_inches='tight', dpi=200)
270 | plt.close()
271 | else:
272 | io.imsave(filepath, tensor.astype(numpy.uint8))
273 |
274 |
275 | def events_to_event_image(event_sequence, height, width, background=None, rotation_angle=None, crop_window=None,
276 | horizontal_flip=False, flip_before_crop=True):
277 | polarity = event_sequence[:, 3] == -1.0
278 |     x_negative = event_sequence[~polarity, 1].astype(int)
279 |     y_negative = event_sequence[~polarity, 2].astype(int)
280 |     x_positive = event_sequence[polarity, 1].astype(int)
281 |     y_positive = event_sequence[polarity, 2].astype(int)
282 |
283 | positive_histogram, _, _ = numpy.histogram2d(
284 | x_positive,
285 | y_positive,
286 | bins=(width, height),
287 | range=[[0, width], [0, height]])
288 | negative_histogram, _, _ = numpy.histogram2d(
289 | x_negative,
290 | y_negative,
291 | bins=(width, height),
292 | range=[[0, width], [0, height]])
293 |
294 | # Red -> Negative Events
295 | red = numpy.transpose((negative_histogram >= positive_histogram) & (negative_histogram != 0))
296 | # Blue -> Positive Events
297 | blue = numpy.transpose(positive_histogram > negative_histogram)
298 | # Normally, we flip first, before we apply the other data augmentations
299 | if flip_before_crop:
300 | if horizontal_flip:
301 | red = numpy.flip(red, axis=1)
302 | blue = numpy.flip(blue, axis=1)
303 | # Rotate, if necessary
304 | if rotation_angle is not None:
305 | red = rotate(red, angle=rotation_angle, preserve_range=True).astype(bool)
306 | blue = rotate(blue, angle=rotation_angle, preserve_range=True).astype(bool)
307 | # Crop, if necessary
308 | if crop_window is not None:
309 | tf = transformers.RandomCropping(crop_height=crop_window['crop_height'],
310 | crop_width=crop_window['crop_width'],
311 | left_right=crop_window['left_right'],
312 | shift=crop_window['shift'])
313 | red = tf.crop_image(red, None, window=crop_window)
314 | blue = tf.crop_image(blue, None, window=crop_window)
315 | else:
316 | # Rotate, if necessary
317 | if rotation_angle is not None:
318 | red = rotate(red, angle=rotation_angle, preserve_range=True).astype(bool)
319 | blue = rotate(blue, angle=rotation_angle, preserve_range=True).astype(bool)
320 | # Crop, if necessary
321 | if crop_window is not None:
322 | tf = transformers.RandomCropping(crop_height=crop_window['crop_height'],
323 | crop_width=crop_window['crop_width'],
324 | left_right=crop_window['left_right'],
325 | shift=crop_window['shift'])
326 | red = tf.crop_image(red, None, window=crop_window)
327 | blue = tf.crop_image(blue, None, window=crop_window)
328 | if horizontal_flip:
329 | red = numpy.flip(red, axis=1)
330 | blue = numpy.flip(blue, axis=1)
331 |
332 | if background is None:
333 | height, width = red.shape
334 | background = torch.full((3, height, width), 255).byte()
335 | if len(background.shape) == 2:
336 | background = background.unsqueeze(0)
337 | else:
338 | if min(background.size()) == 1:
339 | background = grayscale_to_rgb(background)
340 | else:
341 | if not isinstance(background, torch.Tensor):
342 | background = torch.from_numpy(background)
343 | points_on_background = plot_points_on_background(
344 | torch.nonzero(torch.from_numpy(red.astype(numpy.uint8))), background,
345 | [255, 0, 0])
346 | points_on_background = plot_points_on_background(
347 | torch.nonzero(torch.from_numpy(blue.astype(numpy.uint8))),
348 | points_on_background, [0, 0, 255])
349 | return points_on_background
350 |
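A minimal usage sketch for `events_to_event_image` (synthetic events; rows are `(t, x, y, p)` with `p` in {-1, 1}):

```python
import numpy

events = numpy.array([[0.0, 10., 20., 1.0],     # one event of each polarity
                      [1.0, 11., 20., -1.0]])
event_image = events_to_event_image(events, height=480, width=640)
# -> uint8 torch tensor of shape (3, 480, 640): white background,
#    events drawn in red or blue depending on polarity
```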
351 |
352 | def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
353 | new_cmap = colors.LinearSegmentedColormap.from_list(
354 | 'trunc({n},{a:.2f},{b:.2f})'.format(n=cmap.name, a=minval, b=maxval),
355 | cmap(numpy.linspace(minval, maxval, n)))
356 | return new_cmap
357 |
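`truncate_colormap` resamples a slice of an existing colormap, for instance (illustrative values):

```python
from matplotlib import pyplot as plt

# keep only the middle 60% of 'viridis'
trunc = truncate_colormap(plt.get_cmap('viridis'), minval=0.2, maxval=0.8)
```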
358 |
359 | def plot_points_on_background(points_coordinates,
360 | background,
361 | points_color=[0, 0, 255]):
362 | """
363 | Args:
364 | points_coordinates: array of (y, x) points coordinates
365 | of size (number_of_points x 2).
366 | background: (3 x height x width)
367 | gray or color image uint8.
368 |         points_color: color of points [red, green, blue] uint8.
369 | """
370 | if not (len(background.size()) == 3 and background.size(0) == 3):
371 | raise ValueError('background should be (color x height x width).')
372 | _, height, width = background.size()
373 | background_with_points = background.clone()
374 | y, x = points_coordinates.transpose(0, 1)
375 | if len(x) > 0 and len(y) > 0: # There can be empty arrays!
376 | x_min, x_max = x.min(), x.max()
377 | y_min, y_max = y.min(), y.max()
378 | if not (x_min >= 0 and y_min >= 0 and x_max < width and y_max < height):
379 |             raise ValueError('points coordinates are outside of "background" '
380 |                              'boundaries.')
381 | background_with_points[:, y, x] = torch.Tensor(points_color).type_as(
382 | background).unsqueeze(-1)
383 | return background_with_points
384 |
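A small sanity check for `plot_points_on_background` (synthetic coordinates):

```python
import torch

background = torch.full((3, 4, 4), 255, dtype=torch.uint8)  # white 4x4 image
points = torch.tensor([[1, 2],
                       [3, 0]])                             # rows of (y, x)
out = plot_points_on_background(points, background, [255, 0, 0])
# out[:, 1, 2] and out[:, 3, 0] are now red; all other pixels are unchanged
```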
385 |
386 | def visualize_optical_flow(flow, savepath=None, return_image=False, text=None, scaling=None):
387 | # flow -> numpy array 2 x height x width
388 | # 2,h,w -> h,w,2
389 |     flow = flow.transpose(1, 2, 0)
390 |     flow[numpy.isinf(flow)] = 0
391 | # Use Hue, Saturation, Value colour model
392 | hsv = numpy.zeros((flow.shape[0], flow.shape[1], 3), dtype=float)
393 |
394 |     # The additional **0.5 compresses large magnitudes so that small flows remain visible
395 |     mag = numpy.sqrt(flow[..., 0]**2 + flow[..., 1]**2)**0.5
396 |
397 |     ang = numpy.arctan2(flow[..., 1], flow[..., 0])
398 |     ang[ang < 0] += numpy.pi * 2
399 |     hsv[..., 0] = ang / numpy.pi / 2.0  # Scale from 0..1
400 | hsv[..., 1] = 1
401 | if scaling is None:
402 |         hsv[..., 2] = (mag - mag.min()) / max((mag - mag.min()).max(), 1e-9)  # Scale from 0..1, guarding against constant flow
403 | else:
404 |         mag[mag > scaling] = scaling
405 |         hsv[..., 2] = mag / scaling
406 | rgb = colors.hsv_to_rgb(hsv)
407 |     # This may seem like overkill, but it is done to exactly match the cv2 implementation
408 |     bgr = numpy.stack([rgb[..., 2], rgb[..., 1], rgb[..., 0]], axis=2)
409 | plot_with_pyplot = False
410 | if plot_with_pyplot:
411 | fig = plt.figure(frameon=False)
412 | plot = plt.imshow(bgr)
413 | plot.axes.get_xaxis().set_visible(False)
414 | plot.axes.get_yaxis().set_visible(False)
415 | if text is not None:
416 | plt.text(0, -5, text)
417 |
418 | if savepath is not None:
419 | if plot_with_pyplot:
420 | fig.savefig(savepath, bbox_inches='tight', dpi=200)
421 | plt.close()
422 | else: #Plot with skimage
423 | out = bgr*255
424 | io.imsave(savepath, out.astype('uint8'))
425 | return bgr, (mag.min(), mag.max())
426 |
--------------------------------------------------------------------------------
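Finally, a usage sketch for `visualize_optical_flow` (synthetic input; the colour convention follows the HSV mapping above, direction mapped to hue and compressed magnitude to value):

```python
import numpy

flow = numpy.zeros((2, 64, 64), dtype=numpy.float32)
flow[0] = numpy.linspace(0, 5, 64)[None, :]   # u grows left to right, v = 0
bgr, (mag_min, mag_max) = visualize_optical_flow(flow, savepath='flow_vis.png')
# bgr is a float array in [0, 1]; (mag_min, mag_max) are the compressed
# magnitude bounds, which visualize_flow_colours reuses to clamp later frames
```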