├── .gitignore ├── LICENSE ├── README.md ├── checkpoints └── .keep ├── config ├── dsec_standard.json ├── dsec_warm_start.json ├── mvsec_20.json └── mvsec_45.json ├── download_dsec_test.py ├── environment.yml ├── loader ├── loader_dsec.py ├── loader_mvsec_flow.py └── utils.py ├── main.py ├── model ├── corr.py ├── eraft.py ├── extractor.py ├── update.py └── utils.py ├── test.py └── utils ├── dsec_utils.py ├── filename_templates.py ├── helper_functions.py ├── image_utils.py ├── logger.py ├── mvsec_utils.py ├── transformers.py └── visualization.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | .vscode/ 3 | __pycache__/ 4 | venv/ 5 | datasets/ 6 | saved/* 7 | statistics* 8 | utils/login_credentials.json 9 | build/ 10 | checkpoints/*.tar 11 | !config/examples 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Robotics and Perception Group 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # E-RAFT: Dense Optical Flow from Event Cameras 2 | 3 |


8 | 9 | This is the code for the paper **E-RAFT: Dense Optical Flow from Event Cameras** by [Mathias Gehrig](https://magehrig.github.io/), Mario Millhäusler, [Daniel Gehrig](https://danielgehrig18.github.io/) and [Davide Scaramuzza](http://rpg.ifi.uzh.ch/people_scaramuzza.html). 10 | 11 | We also introduce DSEC-Flow ([download here](https://dsec.ifi.uzh.ch/dsec-datasets/download/)), the optical flow extension of the [DSEC](https://dsec.ifi.uzh.ch/) dataset. We are also hosting an automatic evaluation server and a [public benchmark](https://dsec.ifi.uzh.ch/uzh/dsec-flow-optical-flow-benchmark/)! 12 | 13 | Visit our [project webpage](http://rpg.ifi.uzh.ch/ERAFT.html) or download the paper directly [here](https://dsec.ifi.uzh.ch/wp-content/uploads/2021/10/eraft_3dv.pdf) for more details. 14 | If you use any of this code, please cite the following publication: 15 | 16 | ```bibtex 17 | @InProceedings{Gehrig3dv2021, 18 | author = {Mathias Gehrig and Mario Millh\"ausler and Daniel Gehrig and Davide Scaramuzza}, 19 | title = {E-RAFT: Dense Optical Flow from Event Cameras}, 20 | booktitle = {International Conference on 3D Vision (3DV)}, 21 | year = {2021} 22 | } 23 | ``` 24 | 25 | ## Download 26 | 27 | Download the network checkpoints and place them in the folder ```checkpoints/```: 28 | 29 | 30 | [Checkpoint trained on DSEC](https://download.ifi.uzh.ch/rpg/ERAFT/checkpoints/dsec.tar) 31 | 32 | [Checkpoint trained on MVSEC 20 Hz](https://download.ifi.uzh.ch/rpg/ERAFT/checkpoints/mvsec_20.tar) 33 | 34 | [Checkpoint trained on MVSEC 45 Hz](https://download.ifi.uzh.ch/rpg/ERAFT/checkpoints/mvsec_45.tar) 35 | 36 | 37 | ## Installation 38 | Please install [conda](https://www.anaconda.com/download). 
39 | Then, create a new conda environment with Python 3.7 and all dependencies by running 40 | ``` 41 | conda env create --file environment.yml 42 | ``` 43 | 44 | ## Datasets 45 | ### DSEC 46 | The DSEC dataset for optical flow can be downloaded [here](https://dsec.ifi.uzh.ch/dsec-datasets/download/). 47 | We prepared a script [download_dsec_test.py](download_dsec_test.py) for your convenience. 48 | It downloads the test split directly into the `OUTPUT_DIRECTORY` with the expected directory structure. 49 | ``` 50 | python3 download_dsec_test.py OUTPUT_DIRECTORY 51 | ``` 52 | 53 | ### MVSEC 54 | To use the MVSEC dataset with our approach, it needs to be pre-processed into the right format. For your convenience, we provide the pre-processed dataset here: 55 | 56 | [MVSEC Outdoor Day 1 for 20 Hz evaluation](https://download.ifi.uzh.ch/rpg/ERAFT/datasets/mvsec_outdoor_day_1_20Hz.tar) 57 | 58 | [MVSEC Outdoor Day 1 for 45 Hz evaluation](https://download.ifi.uzh.ch/rpg/ERAFT/datasets/mvsec_outdoor_day_1_45Hz.tar) 59 | 60 | ## Experiments 61 | ### DSEC Dataset 62 | For the evaluation of our method with warm-starting, execute the following command: 63 | ``` 64 | python3 main.py --path 65 | ``` 66 | For the evaluation of our method **without** warm-starting, execute the following command: 67 | ``` 68 | python3 main.py --path --type standard 69 | ``` 70 | ### MVSEC Dataset 71 | For the evaluation of our method with warm-starting, trained on 20 Hz MVSEC data, execute the following command: 72 | ``` 73 | python3 main.py --path --dataset mvsec --frequency 20 74 | ``` 75 | For the evaluation of our method with warm-starting, trained on 45 Hz MVSEC data, execute the following command: 76 | ``` 77 | python3 main.py --path --dataset mvsec --frequency 45 78 | ``` 79 | 80 | ### Arguments 81 | ```--path``` : Path where you stored the dataset 82 | 83 | ```--dataset``` : Which dataset to use: ([dsec]/mvsec) 84 | 85 | ```--type``` : Evaluation type ([warm_start]/standard) 86 | 87 | 
```--frequency``` : Evaluation frequency of MVSEC dataset ([20]/45) Hz 88 | 89 | ```--visualize``` : Provide this argument so that DSEC results are visualized. MVSEC experiments are always visualized. 90 | 91 | ```--num_workers``` : How many sub-processes to use for data loading (default=0) 92 | -------------------------------------------------------------------------------- /checkpoints/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uzh-rpg/E-RAFT/c58ce0524ea0ebfa9849991caafb547f44fe9bfd/checkpoints/.keep -------------------------------------------------------------------------------- /config/dsec_standard.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "dsec_standard", 3 | "cuda": true, 4 | "gpu": 0, 5 | "subtype": "standard", 6 | "save_dir": "saved", 7 | "data_loader": { 8 | "test": { 9 | "args": { 10 | "batch_size": 1, 11 | "shuffle": false, 12 | "num_voxel_bins": 15 13 | } 14 | } 15 | }, 16 | "test": { 17 | "checkpoint": "checkpoints/dsec.tar" 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /config/dsec_warm_start.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "dsec_warm_start", 3 | "cuda": true, 4 | "gpu": 0, 5 | "subtype": "warm_start", 6 | "save_dir": "saved", 7 | "data_loader": { 8 | "test": { 9 | "args": { 10 | "batch_size": 1, 11 | "shuffle": false, 12 | "sequence_length": 1, 13 | "num_voxel_bins": 15 14 | } 15 | } 16 | }, 17 | "test": { 18 | "checkpoint": "checkpoints/dsec.tar" 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /config/mvsec_20.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "mvsec_20Hz", 3 | "cuda": true, 4 | "gpu": 0, 5 | "subtype": "warm_start", 6 | "save_dir": "saved", 7 | "data_loader": { 8 | 
"test": { 9 | "args": { 10 | "batch_size": 1, 11 | "shuffle": false, 12 | "sequence_length": 1, 13 | "num_voxel_bins": 15, 14 | "align_to": "depth", 15 | "datasets": { 16 | "outdoor_day": [ 17 | 1 18 | ] 19 | }, 20 | "filter": { 21 | "outdoor_day": { 22 | "1": "range(4356, 4706)" 23 | } 24 | }, 25 | "transforms": [ 26 | "EventSequenceToVoxelGrid_Pytorch(num_bins=15, normalize=True, gpu=True)", 27 | "RandomCropping(crop_height=256, crop_width=256, fixed=True)" 28 | ] 29 | } 30 | } 31 | }, 32 | "test": { 33 | "checkpoint": "checkpoints/mvsec_20.tar" 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /config/mvsec_45.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "mvsec_45Hz", 3 | "cuda": true, 4 | "gpu": 0, 5 | "subtype": "warm_start", 6 | "save_dir": "saved", 7 | "data_loader": { 8 | "test": { 9 | "args": { 10 | "batch_size": 1, 11 | "shuffle": false, 12 | "sequence_length": 1, 13 | "num_voxel_bins": 5, 14 | "align_to": "images", 15 | "datasets": { 16 | "outdoor_day": [ 17 | 1 18 | ] 19 | }, 20 | "filter": { 21 | "outdoor_day": { 22 | "1": "range(10167,10954)" 23 | } 24 | }, 25 | "transforms": [ 26 | "EventSequenceToVoxelGrid_Pytorch(num_bins=5, normalize=True, gpu=True)", 27 | "RandomCropping(crop_height=256, crop_width=256, fixed=True)" 28 | ] 29 | } 30 | } 31 | }, 32 | "test": { 33 | "checkpoint": "checkpoints/mvsec_45.tar" 34 | } 35 | } -------------------------------------------------------------------------------- /download_dsec_test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | import os 4 | import urllib 5 | import shutil 6 | from typing import Union 7 | 8 | from requests import get 9 | 10 | TEST_SEQUENCES = ['interlaken_00_b', 'interlaken_01_a', 'thun_01_a', 'thun_01_b', 'zurich_city_12_a', 'zurich_city_14_c', 'zurich_city_15_a'] 11 | BASE_TEST_URL = 
'https://download.ifi.uzh.ch/rpg/DSEC/test/' 12 | TEST_FLOW_TIMESTAMPS_URL = 'https://download.ifi.uzh.ch/rpg/DSEC/test_forward_optical_flow_timestamps.zip' 13 | 14 | def download(url: str, filepath: Path, skip: bool=True) -> bool: 15 | if skip and filepath.exists(): 16 | print(f'{str(filepath)} already exists. Skipping download.') 17 | return True 18 | response = get(url) 19 | if response.ok: # write only on success, so a failed download is not later skipped as "already exists" 20 | filepath.write_bytes(response.content) 21 | return response.ok 22 | 23 | def unzip(file_: Path, delete_zip: bool=True, skip: bool=True) -> Path: 24 | assert file_.exists() 25 | assert file_.suffix == '.zip' 26 | output_dir = file_.parent / file_.stem 27 | if skip and output_dir.exists(): 28 | print(f'{str(output_dir)} already exists. Skipping unzipping operation.') 29 | else: 30 | shutil.unpack_archive(file_, output_dir) 31 | if delete_zip: 32 | os.remove(file_) 33 | return output_dir 34 | 35 | if __name__ == '__main__': 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('output_directory') 38 | 39 | args = parser.parse_args() 40 | 41 | output_dir = Path(args.output_directory) 42 | output_dir = output_dir / 'test' 43 | os.makedirs(output_dir, exist_ok=True) 44 | 45 | test_timestamps_file = output_dir / 'test_forward_flow_timestamps.zip' 46 | 47 | assert download(TEST_FLOW_TIMESTAMPS_URL, test_timestamps_file), TEST_FLOW_TIMESTAMPS_URL 48 | test_timestamps_dir = unzip(test_timestamps_file) 49 | 50 | for seq_name in TEST_SEQUENCES: 51 | seq_path = output_dir / seq_name 52 | os.makedirs(seq_path, exist_ok=True) 53 | 54 | # image timestamps 55 | img_timestamps_url = BASE_TEST_URL + seq_name + '/' + seq_name + '_image_timestamps.txt' 56 | img_timestamps_file = seq_path / 'image_timestamps.txt' 57 | if not img_timestamps_file.exists(): 58 | assert download(img_timestamps_url, img_timestamps_file), img_timestamps_url 59 | 60 | # test timestamps 61 | test_timestamps_file_destination = seq_path / 'test_forward_flow_timestamps.csv' 62 | if not 
test_timestamps_file_destination.exists(): 63 | shutil.move(test_timestamps_dir / (seq_name + '.csv'), test_timestamps_file_destination) 64 | 65 | # event data 66 | events_left_url = BASE_TEST_URL + seq_name + '/' + seq_name + '_events_left.zip' 67 | events_left_file = seq_path / 'events_left.zip' 68 | if not (events_left_file.parent / events_left_file.stem).exists(): 69 | assert download(events_left_url, events_left_file), events_left_url 70 | unzip(events_left_file) 71 | 72 | shutil.rmtree(test_timestamps_dir) 73 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: eRaft 2 | channels: 3 | - pytorch 4 | - numba 5 | - anaconda 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - _libgcc_mutex=0.1=conda_forge 10 | - _openmp_mutex=4.5=1_gnu 11 | - astroid=2.5=py37h06a4308_1 12 | - blas=1.0=mkl 13 | - blosc=1.21.0=h9c3ff4c_0 14 | - blosc-hdf5-plugin=1.0.0=h646ab9b_2 15 | - brotlipy=0.7.0=py37h7b6447c_1000 16 | - bzip2=1.0.8=h7b6447c_0 17 | - c-ares=1.17.1=h7f98852_1 18 | - ca-certificates=2020.10.14=0 19 | - cached-property=1.5.2=hd8ed1ab_1 20 | - cached_property=1.5.2=pyha770c72_1 21 | - cairo=1.16.0=h6cf1ce9_1008 22 | - certifi=2020.6.20=py37_0 23 | - cffi=1.14.3=py37he30daa8_0 24 | - chardet=3.0.4=py37_1003 25 | - cloudpickle=1.6.0=py_0 26 | - cryptography=3.1.1=py37h1ba5d50_0 27 | - cudatoolkit=10.2.89=hfd86e86_1 28 | - cycler=0.10.0=py37_0 29 | - cytoolz=0.11.0=py37h7b6447c_0 30 | - dask-core=2.30.0=py_0 31 | - dbus=1.13.6=h48d8840_2 32 | - decorator=4.4.2=py_0 33 | - expat=2.3.0=h9c3ff4c_0 34 | - ffmpeg=4.3.1=hca11adc_2 35 | - fontconfig=2.13.1=hba837de_1004 36 | - freetype=2.10.4=h5ab3b9f_0 37 | - gettext=0.19.8.1=h0b5b191_1005 38 | - gitdb=4.0.7=pyhd3eb1b0_0 39 | - gitpython=3.1.14=pyhd3eb1b0_1 40 | - glib=2.68.0=h9c3ff4c_2 41 | - glib-tools=2.68.0=h9c3ff4c_2 42 | - gmp=6.2.1=h58526e2_0 43 | - gnutls=3.6.13=h85f3911_1 
44 | - graphite2=1.3.13=h58526e2_1001 45 | - gst-plugins-base=1.18.4=h29181c9_0 46 | - gstreamer=1.18.4=h76c114f_0 47 | - h5py=3.1.0=nompi_py37h1e651dc_100 48 | - harfbuzz=2.8.0=h83ec7ef_1 49 | - hdf5=1.10.6=nompi_h7c3c948_1111 50 | - icu=68.1=h58526e2_0 51 | - idna=2.10=py_0 52 | - imageio=2.9.0=py_0 53 | - intel-openmp=2020.2=254 54 | - isort=5.8.0=pyhd3eb1b0_0 55 | - jasper=1.900.1=h07fcdf6_1006 56 | - jpeg=9d=h36c2ea0_0 57 | - kiwisolver=1.2.0=py37hfd86e86_0 58 | - krb5=1.17.2=h926e7f8_0 59 | - lame=3.100=h7f98852_1001 60 | - lazy-object-proxy=1.6.0=py37h27cfd23_0 61 | - lcms2=2.11=h396b838_0 62 | - ld_impl_linux-64=2.33.1=h53a641e_7 63 | - libblas=3.9.0=1_h6e990d7_netlib 64 | - libcblas=3.9.0=3_h893e4fe_netlib 65 | - libclang=11.1.0=default_ha53f305_0 66 | - libcurl=7.76.0=hc4aaa36_0 67 | - libedit=3.1.20191231=he28a2e2_2 68 | - libev=4.33=h516909a_1 69 | - libevent=2.1.10=hcdb4288_3 70 | - libffi=3.3=he6710b0_2 71 | - libgcc-ng=9.3.0=h2828fa1_18 72 | - libgfortran-ng=7.5.0=h14aa051_18 73 | - libgfortran4=7.5.0=h14aa051_18 74 | - libglib=2.68.0=h3e27bee_2 75 | - libgomp=9.3.0=h2828fa1_18 76 | - libiconv=1.16=h516909a_0 77 | - liblapack=3.9.0=3_h893e4fe_netlib 78 | - liblapacke=3.9.0=3_h893e4fe_netlib 79 | - libllvm11=11.1.0=hf817b99_1 80 | - libnghttp2=1.43.0=h812cca2_0 81 | - libopencv=4.5.1=py37h5fff631_1 82 | - libpng=1.6.37=hbc83047_0 83 | - libpq=13.1=hfd2b0eb_2 84 | - libprotobuf=3.15.6=h780b84a_0 85 | - libssh2=1.9.0=ha56f1ee_6 86 | - libstdcxx-ng=9.3.0=h6de172a_18 87 | - libtiff=4.2.0=hdc55705_0 88 | - libuuid=2.32.1=h7f98852_1000 89 | - libuv=1.40.0=h7b6447c_0 90 | - libwebp-base=1.2.0=h7f98852_2 91 | - libxcb=1.13=h7f98852_1003 92 | - libxkbcommon=1.0.3=he3ba5ed_0 93 | - libxml2=2.9.10=h72842e0_3 94 | - llvmlite=0.36.0=py37hf484d3e_0 95 | - lz4-c=1.9.3=h9c3ff4c_0 96 | - lzo=2.10=h7b6447c_2 97 | - matplotlib=3.4.1=py37h89c1867_0 98 | - matplotlib-base=3.4.1=py37hdd32ed1_0 99 | - mccabe=0.6.1=py37_1 100 | - mkl=2019.4=243 101 | - 
mkl-service=2.3.0=py37he904b0f_0 102 | - mkl_fft=1.2.0=py37h23d657b_0 103 | - mkl_random=1.0.4=py37hd81dba3_0 104 | - mock=4.0.2=py_0 105 | - mysql-common=8.0.23=ha770c72_1 106 | - mysql-libs=8.0.23=h935591d_1 107 | - ncurses=6.2=he6710b0_1 108 | - nettle=3.6=he412f7d_0 109 | - networkx=2.5=py_0 110 | - ninja=1.10.2=py37hff7bd54_0 111 | - nspr=4.30=h9c3ff4c_0 112 | - nss=3.63=hb5efdd6_0 113 | - numba=0.53.1=np1.11py3.7h04863e7_g97fe221b3_0 114 | - numexpr=2.7.1=py37h423224d_0 115 | - numpy=1.19.1=py37hbc911f0_0 116 | - numpy-base=1.19.1=py37hfa32c7d_0 117 | - olefile=0.46=py37_0 118 | - opencv=4.5.1=py37h89c1867_1 119 | - openh264=2.1.1=h780b84a_0 120 | - openssl=1.1.1l=h7f8727e_0 121 | - pandas=1.2.3=py37ha9443f7_0 122 | - pcre=8.44=he1b5a44_0 123 | - pillow=8.0.0=py37h9a89aac_0 124 | - pip=21.0.1=py37h06a4308_0 125 | - pixman=0.40.0=h36c2ea0_0 126 | - pthread-stubs=0.4=h36c2ea0_1001 127 | - py-opencv=4.5.1=py37h085eea5_1 128 | - pycparser=2.20=py_2 129 | - pylint=2.7.4=py37h06a4308_1 130 | - pyopenssl=19.1.0=py_1 131 | - pyparsing=2.4.7=py_0 132 | - pyqt=5.12.3=py37h89c1867_7 133 | - pyqt-impl=5.12.3=py37he336c9b_7 134 | - pyqt5-sip=4.19.18=py37hcd2ae1e_7 135 | - pyqtchart=5.12=py37he336c9b_7 136 | - pyqtwebengine=5.12.1=py37he336c9b_7 137 | - pysocks=1.7.1=py37_1 138 | - pytables=3.6.1=py37h0c4f3e0_3 139 | - python=3.7.10=hdb3f193_0 140 | - python-dateutil=2.8.1=py_0 141 | - python_abi=3.7=1_cp37m 142 | - pytorch=1.7.0=py3.7_cuda10.2.89_cudnn7.6.5_0 143 | - pytz=2021.1=pyhd3eb1b0_0 144 | - pywavelets=1.1.1=py37h7b6447c_2 145 | - pyyaml=5.3.1=py37h7b6447c_1 146 | - qt=5.12.9=hda022c4_4 147 | - readline=8.1=h27cfd23_0 148 | - requests=2.24.0=py_0 149 | - ruamel.yaml=0.16.12=py37h5e8e339_2 150 | - ruamel.yaml.clib=0.2.2=py37h5e8e339_2 151 | - scikit-image=0.17.2=py37hdf5156a_0 152 | - scipy=1.5.2=py37h0b6359f_0 153 | - setuptools=52.0.0=py37h06a4308_0 154 | - six=1.15.0=py_0 155 | - smmap=3.0.5=pyhd3eb1b0_0 156 | - snappy=1.1.8=he6710b0_0 157 | - 
sqlite=3.35.3=hdfb4753_0 158 | - tifffile=2020.10.1=py37hdd07704_2 159 | - tk=8.6.10=hbc83047_0 160 | - toml=0.10.2=pyhd3eb1b0_0 161 | - toolz=0.11.1=py_0 162 | - torchvision=0.8.1=py37_cu102 163 | - tornado=6.0.4=py37h7b6447c_1 164 | - tqdm=4.59.0=pyhd8ed1ab_0 165 | - typed-ast=1.4.2=py37h27cfd23_1 166 | - typing_extensions=3.7.4.3=pyha847dfd_0 167 | - urllib3=1.25.11=py_0 168 | - wheel=0.36.2=pyhd3eb1b0_0 169 | - wrapt=1.12.1=py37h7b6447c_1 170 | - x264=1!161.3030=h7f98852_0 171 | - xorg-kbproto=1.0.7=h7f98852_1002 172 | - xorg-libice=1.0.10=h7f98852_0 173 | - xorg-libsm=1.2.3=hd9c2040_1000 174 | - xorg-libx11=1.7.0=h7f98852_0 175 | - xorg-libxau=1.0.9=h7f98852_0 176 | - xorg-libxdmcp=1.1.3=h7f98852_0 177 | - xorg-libxext=1.3.4=h7f98852_1 178 | - xorg-libxrender=0.9.10=h7f98852_1003 179 | - xorg-renderproto=0.11.1=h7f98852_1002 180 | - xorg-xextproto=7.3.0=h7f98852_1002 181 | - xorg-xproto=7.0.31=h7f98852_1007 182 | - xz=5.2.5=h7b6447c_0 183 | - yaml=0.2.5=h7b6447c_0 184 | - zlib=1.2.11=h7b6447c_3 185 | - zstd=1.4.9=ha95c52a_0 186 | -------------------------------------------------------------------------------- /loader/loader_dsec.py: -------------------------------------------------------------------------------- 1 | import math 2 | from pathlib import Path 3 | from typing import Dict, Tuple 4 | import weakref 5 | 6 | import cv2 7 | import h5py 8 | from numba import jit 9 | import numpy as np 10 | import torch 11 | from torch.utils.data import Dataset, DataLoader 12 | from utils import visualization as visu 13 | from matplotlib import pyplot as plt 14 | from utils import transformers 15 | import os 16 | import imageio 17 | 18 | from utils.dsec_utils import RepresentationType, VoxelGrid, flow_16bit_to_float 19 | 20 | VISU_INDEX = 1 21 | 22 | class EventSlicer: 23 | def __init__(self, h5f: h5py.File): 24 | self.h5f = h5f 25 | 26 | self.events = dict() 27 | for dset_str in ['p', 'x', 'y', 't']: 28 | self.events[dset_str] = self.h5f['events/{}'.format(dset_str)] 29 
| 30 | # This is the mapping from milliseconds to event index: 31 | # It is defined such that 32 | # (1) t[ms_to_idx[ms]] >= ms*1000 33 | # (2) t[ms_to_idx[ms] - 1] < ms*1000 34 | # ,where 'ms' is the time in milliseconds and 't' the event timestamps in microseconds. 35 | # 36 | # As an example, given 't' and 'ms': 37 | # t: 0 500 2100 5000 5000 7100 7200 7200 8100 9000 38 | # ms: 0 1 2 3 4 5 6 7 8 9 39 | # 40 | # we get 41 | # 42 | # ms_to_idx: 43 | # 0 2 2 3 3 3 5 5 8 9 44 | self.ms_to_idx = np.asarray(self.h5f['ms_to_idx'], dtype='int64') 45 | 46 | self.t_offset = int(h5f['t_offset'][()]) 47 | self.t_final = int(self.events['t'][-1]) + self.t_offset 48 | 49 | def get_final_time_us(self): 50 | return self.t_final 51 | 52 | def get_events(self, t_start_us: int, t_end_us: int) -> Dict[str, np.ndarray]: 53 | """Get events (p, x, y, t) within the specified time window 54 | Parameters 55 | ---------- 56 | t_start_us: start time in microseconds 57 | t_end_us: end time in microseconds 58 | Returns 59 | ------- 60 | events: dictionary of (p, x, y, t) or None if the time window cannot be retrieved 61 | """ 62 | assert t_start_us < t_end_us 63 | 64 | # We assume that the times are top-off-day, hence subtract offset: 65 | t_start_us -= self.t_offset 66 | t_end_us -= self.t_offset 67 | 68 | t_start_ms, t_end_ms = self.get_conservative_window_ms(t_start_us, t_end_us) 69 | t_start_ms_idx = self.ms2idx(t_start_ms) 70 | t_end_ms_idx = self.ms2idx(t_end_ms) 71 | 72 | if t_start_ms_idx is None or t_end_ms_idx is None: 73 | # Cannot guarantee window size anymore 74 | return None 75 | 76 | events = dict() 77 | time_array_conservative = np.asarray(self.events['t'][t_start_ms_idx:t_end_ms_idx]) 78 | idx_start_offset, idx_end_offset = self.get_time_indices_offsets(time_array_conservative, t_start_us, t_end_us) 79 | t_start_us_idx = t_start_ms_idx + idx_start_offset 80 | t_end_us_idx = t_start_ms_idx + idx_end_offset 81 | # Again add t_offset to get gps time 82 | events['t'] = 
time_array_conservative[idx_start_offset:idx_end_offset] + self.t_offset 83 | for dset_str in ['p', 'x', 'y']: 84 | events[dset_str] = np.asarray(self.events[dset_str][t_start_us_idx:t_end_us_idx]) 85 | assert events[dset_str].size == events['t'].size 86 | return events 87 | 88 | 89 | @staticmethod 90 | def get_conservative_window_ms(ts_start_us: int, ts_end_us) -> Tuple[int, int]: 91 | """Compute a conservative time window of time with millisecond resolution. 92 | We have a time to index mapping for each millisecond. Hence, we need 93 | to compute the lower and upper millisecond to retrieve events. 94 | Parameters 95 | ---------- 96 | ts_start_us: start time in microseconds 97 | ts_end_us: end time in microseconds 98 | Returns 99 | ------- 100 | window_start_ms: conservative start time in milliseconds 101 | window_end_ms: conservative end time in milliseconds 102 | """ 103 | assert ts_end_us > ts_start_us 104 | window_start_ms = math.floor(ts_start_us/1000) 105 | window_end_ms = math.ceil(ts_end_us/1000) 106 | return window_start_ms, window_end_ms 107 | 108 | @staticmethod 109 | @jit(nopython=True) 110 | def get_time_indices_offsets( 111 | time_array: np.ndarray, 112 | time_start_us: int, 113 | time_end_us: int) -> Tuple[int, int]: 114 | """Compute index offset of start and end timestamps in microseconds 115 | Parameters 116 | ---------- 117 | time_array: timestamps (in us) of the events 118 | time_start_us: start timestamp (in us) 119 | time_end_us: end timestamp (in us) 120 | Returns 121 | ------- 122 | idx_start: Index within this array corresponding to time_start_us 123 | idx_end: Index within this array corresponding to time_end_us 124 | such that (in non-edge cases) 125 | time_array[idx_start] >= time_start_us 126 | time_array[idx_end] >= time_end_us 127 | time_array[idx_start - 1] < time_start_us 128 | time_array[idx_end - 1] < time_end_us 129 | this means that 130 | time_start_us <= time_array[idx_start:idx_end] < time_end_us 131 | """ 132 | 133 | assert 
time_array.ndim == 1 134 | 135 | idx_start = -1 136 | if time_array[-1] < time_start_us: 137 | # This can happen in extreme corner cases. E.g. 138 | # time_array[0] = 1016 139 | # time_array[-1] = 1984 140 | # time_start_us = 1990 141 | # time_end_us = 2000 142 | 143 | # Return same index twice: array[x:x] is empty. 144 | return time_array.size, time_array.size 145 | else: 146 | for idx_from_start in range(0, time_array.size, 1): 147 | if time_array[idx_from_start] >= time_start_us: 148 | idx_start = idx_from_start 149 | break 150 | assert idx_start >= 0 151 | 152 | idx_end = time_array.size 153 | for idx_from_end in range(time_array.size - 1, -1, -1): 154 | if time_array[idx_from_end] >= time_end_us: 155 | idx_end = idx_from_end 156 | else: 157 | break 158 | 159 | assert time_array[idx_start] >= time_start_us 160 | if idx_end < time_array.size: 161 | assert time_array[idx_end] >= time_end_us 162 | if idx_start > 0: 163 | assert time_array[idx_start - 1] < time_start_us 164 | if idx_end > 0: 165 | assert time_array[idx_end - 1] < time_end_us 166 | return idx_start, idx_end 167 | 168 | def ms2idx(self, time_ms: int) -> int: 169 | assert time_ms >= 0 170 | if time_ms >= self.ms_to_idx.size: 171 | return None 172 | return self.ms_to_idx[time_ms] 173 | 174 | 175 | class Sequence(Dataset): 176 | def __init__(self, seq_path: Path, representation_type: RepresentationType, mode: str='test', delta_t_ms: int=100, 177 | num_bins: int=15, transforms=None, name_idx=0, visualize=False): 178 | assert num_bins >= 1 179 | assert delta_t_ms == 100 180 | assert seq_path.is_dir() 181 | assert mode in {'train', 'test'} 182 | ''' 183 | Directory Structure: 184 | 185 | Dataset 186 | └── test 187 | ├── interlaken_00_b 188 | │   ├── events_left 189 | │   │   ├── events.h5 190 | │   │   └── rectify_map.h5 191 | │   ├── image_timestamps.txt 192 | │   └── test_forward_flow_timestamps.csv 193 | 194 | ''' 195 | 196 | self.mode = mode 197 | self.name_idx = name_idx 198 | self.visualize_samples = 
visualize 199 | # Get Test Timestamp File 200 | test_timestamp_file = seq_path / 'test_forward_flow_timestamps.csv' 201 | assert test_timestamp_file.is_file() 202 | file = np.genfromtxt( 203 | test_timestamp_file, 204 | delimiter=',' 205 | ) 206 | self.idx_to_visualize = file[:,2] 207 | 208 | # Save output dimensions 209 | self.height = 480 210 | self.width = 640 211 | self.num_bins = num_bins 212 | 213 | # Just for now, we always train with num_bins=15 214 | assert self.num_bins==15 215 | 216 | # Set event representation 217 | self.voxel_grid = None 218 | if representation_type == RepresentationType.VOXEL: 219 | self.voxel_grid = VoxelGrid((self.num_bins, self.height, self.width), normalize=True) 220 | 221 | 222 | # Save delta timestamp in ms 223 | self.delta_t_us = delta_t_ms * 1000 224 | 225 | #Load and compute timestamps and indices 226 | timestamps_images = np.loadtxt(seq_path / 'image_timestamps.txt', dtype='int64') 227 | image_indices = np.arange(len(timestamps_images)) 228 | # But only use every second one because we train at 10 Hz, and we leave away the 1st & last one 229 | self.timestamps_flow = timestamps_images[::2][1:-1] 230 | self.indices = image_indices[::2][1:-1] 231 | 232 | # Left events only 233 | ev_dir_location = seq_path / 'events_left' 234 | ev_data_file = ev_dir_location / 'events.h5' 235 | ev_rect_file = ev_dir_location / 'rectify_map.h5' 236 | 237 | h5f_location = h5py.File(str(ev_data_file), 'r') 238 | self.h5f = h5f_location 239 | self.event_slicer = EventSlicer(h5f_location) 240 | with h5py.File(str(ev_rect_file), 'r') as h5_rect: 241 | self.rectify_ev_map = h5_rect['rectify_map'][()] 242 | 243 | self._finalizer = weakref.finalize(self, self.close_callback, self.h5f) 244 | 245 | def events_to_voxel_grid(self, p, t, x, y, device: str='cpu'): 246 | t = (t - t[0]).astype('float32') 247 | t = (t/t[-1]) 248 | x = x.astype('float32') 249 | y = y.astype('float32') 250 | pol = p.astype('float32') 251 | event_data_torch = { 252 | 'p': 
torch.from_numpy(pol), 253 | 't': torch.from_numpy(t), 254 | 'x': torch.from_numpy(x), 255 | 'y': torch.from_numpy(y), 256 | } 257 | return self.voxel_grid.convert(event_data_torch) 258 | 259 | def getHeightAndWidth(self): 260 | return self.height, self.width 261 | 262 | @staticmethod 263 | def get_disparity_map(filepath: Path): 264 | assert filepath.is_file() 265 | disp_16bit = cv2.imread(str(filepath), cv2.IMREAD_ANYDEPTH) 266 | return disp_16bit.astype('float32')/256 267 | 268 | @staticmethod 269 | def load_flow(flowfile: Path): 270 | assert flowfile.exists() 271 | assert flowfile.suffix == '.png' 272 | flow_16bit = imageio.imread(str(flowfile), format='PNG-FI') 273 | flow, valid2D = flow_16bit_to_float(flow_16bit) 274 | return flow, valid2D 275 | 276 | @staticmethod 277 | def close_callback(h5f): 278 | h5f.close() 279 | 280 | def get_image_width_height(self): 281 | return self.height, self.width 282 | 283 | def __len__(self): 284 | return len(self.timestamps_flow) 285 | 286 | def rectify_events(self, x: np.ndarray, y: np.ndarray): 287 | # assert location in self.locations 288 | # From distorted to undistorted 289 | rectify_map = self.rectify_ev_map 290 | assert rectify_map.shape == (self.height, self.width, 2), rectify_map.shape 291 | assert x.max() < self.width 292 | assert y.max() < self.height 293 | return rectify_map[y, x] 294 | 295 | def get_data_sample(self, index, crop_window=None, flip=None): 296 | # First entry corresponds to all events BEFORE the flow map 297 | # Second entry corresponds to all events AFTER the flow map (corresponding to the actual fwd flow) 298 | names = ['event_volume_old', 'event_volume_new'] 299 | ts_start = [self.timestamps_flow[index] - self.delta_t_us, self.timestamps_flow[index]] 300 | ts_end = [self.timestamps_flow[index], self.timestamps_flow[index] + self.delta_t_us] 301 | 302 | file_index = self.indices[index] 303 | 304 | output = { 305 | 'file_index': file_index, 306 | 'timestamp': self.timestamps_flow[index] 307 | } 308 
| # Save sample for benchmark submission 309 | output['save_submission'] = file_index in self.idx_to_visualize 310 | output['visualize'] = self.visualize_samples 311 | 312 | for i in range(len(names)): 313 | event_data = self.event_slicer.get_events(ts_start[i], ts_end[i]) 314 | 315 | p = event_data['p'] 316 | t = event_data['t'] 317 | x = event_data['x'] 318 | y = event_data['y'] 319 | 320 | xy_rect = self.rectify_events(x, y) 321 | x_rect = xy_rect[:, 0] 322 | y_rect = xy_rect[:, 1] 323 | 324 | if crop_window is not None: 325 | # Cropping (+- 2 for safety reasons) 326 | x_mask = (x_rect >= crop_window['start_x']-2) & (x_rect < crop_window['start_x']+crop_window['crop_width']+2) 327 | y_mask = (y_rect >= crop_window['start_y']-2) & (y_rect < crop_window['start_y']+crop_window['crop_height']+2) 328 | mask_combined = x_mask & y_mask 329 | p = p[mask_combined] 330 | t = t[mask_combined] 331 | x_rect = x_rect[mask_combined] 332 | y_rect = y_rect[mask_combined] 333 | 334 | if self.voxel_grid is None: 335 | raise NotImplementedError 336 | else: 337 | event_representation = self.events_to_voxel_grid(p, t, x_rect, y_rect) 338 | output[names[i]] = event_representation 339 | output['name_map']=self.name_idx 340 | return output 341 | 342 | def __getitem__(self, idx): 343 | sample = self.get_data_sample(idx) 344 | return sample 345 | 346 | 347 | class SequenceRecurrent(Sequence): 348 | def __init__(self, seq_path: Path, representation_type: RepresentationType, mode: str='test', delta_t_ms: int=100, 349 | num_bins: int=15, transforms=None, sequence_length=1, name_idx=0, visualize=False): 350 | super(SequenceRecurrent, self).__init__(seq_path, representation_type, mode, delta_t_ms, num_bins=num_bins, # forward num_bins instead of silently falling back to the default 351 | transforms=transforms, name_idx=name_idx, visualize=visualize) 352 | self.sequence_length = sequence_length 353 | self.valid_indices = self.get_continuous_sequences() 354 | 355 | def get_continuous_sequences(self): 356 | continuous_seq_idcs = [] 357 | if self.sequence_length > 1: 358 | for i in 
range(len(self.timestamps_flow)-self.sequence_length+1): 359 | diff = self.timestamps_flow[i+self.sequence_length-1] - self.timestamps_flow[i] 360 | if diff < np.max([100000 * (self.sequence_length-1) + 1000, 101000]): 361 | continuous_seq_idcs.append(i) 362 | else: 363 | for i in range(len(self.timestamps_flow)-1): 364 | diff = self.timestamps_flow[i+1] - self.timestamps_flow[i] 365 | if diff < np.max([100000 * (self.sequence_length-1) + 1000, 101000]): 366 | continuous_seq_idcs.append(i) 367 | return continuous_seq_idcs 368 | 369 | def __len__(self): 370 | return len(self.valid_indices) 371 | 372 | def __getitem__(self, idx): 373 | assert idx >= 0 374 | assert idx < len(self) 375 | 376 | # Valid index is the actual index we want to load, which guarantees a continuous sequence length 377 | valid_idx = self.valid_indices[idx] 378 | 379 | sequence = [] 380 | j = valid_idx 381 | 382 | ts_cur = self.timestamps_flow[j] 383 | # Add first sample 384 | sample = self.get_data_sample(j) 385 | sequence.append(sample) 386 | 387 | # Data augmentation according to first sample 388 | crop_window = None 389 | flip = None 390 | if 'crop_window' in sample.keys(): 391 | crop_window = sample['crop_window'] 392 | if 'flipped' in sample.keys(): 393 | flip = sample['flipped'] 394 | 395 | for i in range(self.sequence_length-1): 396 | j += 1 397 | ts_old = ts_cur 398 | ts_cur = self.timestamps_flow[j] 399 | assert(ts_cur-ts_old < 100000 + 1000) 400 | sample = self.get_data_sample(j, crop_window=crop_window, flip=flip) 401 | sequence.append(sample) 402 | 403 | # Check if the current sample is the first sample of a continuous sequence 404 | if idx==0 or self.valid_indices[idx]-self.valid_indices[idx-1] != 1: 405 | sequence[0]['new_sequence'] = 1 406 | print("Timestamp {} is the first one of the next seq!".format(self.timestamps_flow[self.valid_indices[idx]])) 407 | else: 408 | sequence[0]['new_sequence'] = 0 409 | return sequence 410 | 411 | class DatasetProvider: 412 | def __init__(self, 
dataset_path: Path, representation_type: RepresentationType, delta_t_ms: int=100, num_bins=15, 413 | type='standard', config=None, visualize=False): 414 | test_path = dataset_path / 'test' 415 | assert dataset_path.is_dir(), str(dataset_path) 416 | assert test_path.is_dir(), str(test_path) 417 | assert delta_t_ms == 100 418 | self.config=config 419 | self.name_mapper_test = [] 420 | 421 | test_sequences = list() 422 | for child in test_path.iterdir(): 423 | self.name_mapper_test.append(str(child).split("/")[-1]) 424 | if type == 'standard': 425 | test_sequences.append(Sequence(child, representation_type, 'test', delta_t_ms, num_bins, 426 | transforms=[], 427 | name_idx=len(self.name_mapper_test)-1, 428 | visualize=visualize)) 429 | elif type == 'warm_start': 430 | test_sequences.append(SequenceRecurrent(child, representation_type, 'test', delta_t_ms, num_bins, 431 | transforms=[], sequence_length=1, 432 | name_idx=len(self.name_mapper_test)-1, 433 | visualize=visualize)) 434 | else: 435 | raise Exception('Please provide a valid subtype [standard/warm_start] in config file!') 436 | 437 | self.test_dataset = torch.utils.data.ConcatDataset(test_sequences) 438 | 439 | def get_test_dataset(self): 440 | return self.test_dataset 441 | 442 | 443 | def get_name_mapping_test(self): 444 | return self.name_mapper_test 445 | 446 | def summary(self, logger): 447 | logger.write_line("================================== Dataloader Summary ====================================", True) 448 | logger.write_line("Loader Type:\t\t" + self.__class__.__name__, True) 449 | logger.write_line("Number of Voxel Bins: {}".format(self.test_dataset.datasets[0].num_bins), True) 450 | -------------------------------------------------------------------------------- /loader/loader_mvsec_flow.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import torch 3 | from matplotlib import pyplot as plt 4 | from torch.utils.data import Dataset 5 | import os 
6 | import numpy 7 | from utils import filename_templates as TEMPLATES 8 | from loader.utils import * 9 | from utils.transformers import * 10 | from utils import mvsec_utils 11 | from torchvision import transforms 12 | 13 | class MvsecFlow(Dataset): 14 | def __init__(self, args, type, path): 15 | super(MvsecFlow, self).__init__() 16 | #self.data_files = self.get_files(args['path'], args['datasets']) 17 | self.path_dataset = path 18 | self.timestamp_files = {} 19 | self.timestamp_files_flow = {} 20 | # If we load the image timestamps, we consider the framerate to be 45Hz. 21 | # Else if we load the depth/flow timestamps, the framerate is 20Hz. 22 | # The update rate gets set to 20 or 45 in the get_indices method 23 | self.update_rate = None 24 | self.dataset = self.get_indices(path, args['datasets'], args['filter'], args['align_to']) 25 | self.input_type = 'events' 26 | self.type = type # Train/Val/Test 27 | 28 | # Evaluation Type. Dense -> Valid where GT exists 29 | # Sparse -> Valid where GT & Events exist 30 | self.evaluation_type = 'dense' 31 | 32 | self.image_width = 346 33 | self.image_height = 260 34 | 35 | self.voxel = EventSequenceToVoxelGrid_Pytorch( 36 | num_bins=args['num_voxel_bins'], 37 | normalize=True, 38 | gpu=True 39 | ) 40 | self.cropper = transforms.CenterCrop((256,256)) 41 | 42 | def summary(self, logger): 43 | logger.write_line("================================== Dataloader Summary ====================================", True) 44 | logger.write_line("Loader Type:\t\t" + self.__class__.__name__ + " for {}".format(self.type), True) 45 | logger.write_line("Framerate:\t\t{}".format(self.update_rate), True) 46 | logger.write_line("Evaluation Type:\t{}".format(self.evaluation_type), True) 47 | 48 | def get_indices(self, path, dataset, filter, align_to): 49 | # Returns a list of dicts. Each dict contains the following items: 50 | # ['dataset_name'] (e.g. outdoor_day) 51 | # ['subset_number'] (e.g. 1) 52 | # ['index'] (e.g.
1), Frame Index in the dataset 53 | # ['timestamp'] Timestamp of the frame with index i 54 | samples = [] 55 | for dataset_name in dataset: 56 | self.timestamp_files[dataset_name] = {} 57 | self.timestamp_files_flow[dataset_name] = {} 58 | for subset in dataset[dataset_name]: 59 | dataset_path = TEMPLATES.MVSEC_DATASET_FOLDER.format(dataset_name, subset) 60 | 61 | # Timestamps of DEPTH image 62 | if align_to.lower() == 'images' or align_to.lower() == 'image': 63 | print("Aligning everything to the image timestamps!") 64 | ts_path = TEMPLATES.MVSEC_TIMESTAMPS_PATH_IMAGES 65 | if self.update_rate is not None and self.update_rate != 45: 66 | raise Exception('Something wrong with the update rate!') 67 | self.update_rate = 45 68 | self.timestamp_files_flow[dataset_name][subset] = numpy.loadtxt(os.path.join(path, 69 | dataset_path, 70 | TEMPLATES.MVSEC_TIMESTAMPS_PATH_FLOW)) 71 | elif align_to.lower() == 'depth': 72 | print("Aligning everything to the depth timestamps!") 73 | ts_path = TEMPLATES.MVSEC_TIMESTAMPS_PATH_DEPTH 74 | if self.update_rate is not None and self.update_rate != 20: 75 | raise Exception('Something wrong with the update rate!') 76 | self.update_rate = 20 77 | elif align_to.lower() == 'flow': 78 | print("Aligning everything to the flow timestamps!") 79 | ts_path = TEMPLATES.MVSEC_TIMESTAMPS_PATH_FLOW 80 | if self.update_rate is not None and self.update_rate != 20: 81 | raise Exception('Something wrong with the update rate!') 82 | self.update_rate = 20 83 | else: 84 | raise ValueError("Please define the variable 'align_to' in the dataset [image/depth/flow]") 85 | ts = numpy.loadtxt(os.path.join(path, dataset_path, ts_path)) 86 | self.timestamp_files[dataset_name][subset] = ts 87 | for idx in eval(filter[dataset_name][str(subset)]): 88 | sample = {} 89 | sample['dataset_name'] = dataset_name 90 | sample['subset_number'] = subset 91 | sample['index'] = idx 92 | sample['timestamp'] = ts[idx] 93 | samples.append(sample) 94 | 95 | return samples 96 | 97 | 
def get_data_sample(self, loader_idx): 98 | # ================================= Get Data Sample =============================== # 99 | # Returns dict with the following content: # 100 | # - Event Sequence New (if type == 'events') # 101 | # - Event Sequence Old (if type == 'events') # 102 | # - Optical Flow (forward) between two timesteps as defined below # 103 | # - Timestamp and some other params # 104 | # # 105 | # Nomenclature Definition # 106 | # # 107 | # NOTE THAT THIS IS A DIFFERENT NAMING SCHEME THAN IN THE OTHER DATASETS! # 108 | # # 109 | # Flow[i-1] Flow[i] Flow[i+1] Flow[i+2] # 110 | # Depth[i-1] Depth[i] Depth[i+1] Depth[i+2] 111 | # | . | . | . . | # 112 | # | . . | . | . . | # 113 | # | . | .. . | ... | # 114 | # | . | . | . . | # 115 | # Events[i] Events[i+1] Events[i+2] # 116 | # 117 | # Flow[i] tells us the flow between Depth[i] and Depth[i+1] # 118 | # This can be seen because the pixels of flow[i] are the same as depth[i] # 119 | # We are for now using the events aligned to the depth-timestamps. # 120 | # This means, to get the flow between Depth[i] and Depth[i+1], we need to load # 121 | # - Flow[i] 122 | # - Events[i+1] 123 | # - Events[i] (if using volumetric cost volumes) 124 | # - Timestamps (from depth) [i] 125 | # - Timestamps (from depth) [i+1] 126 | set = self.dataset[loader_idx]['dataset_name'] 127 | subset = self.dataset[loader_idx]['subset_number'] 128 | path_subset = TEMPLATES.MVSEC_DATASET_FOLDER.format(set, subset) 129 | path_dataset = os.path.join(self.path_dataset, path_subset) 130 | # params = self.config[self.dataset[loader_idx]['dataset_name']][self.dataset[loader_idx]['subset_number']] 131 | idx = self.dataset[loader_idx]['index'] 132 | type = self.input_type 133 | 134 | # If the update rate is 20 Hz (i.e. 
we're aligned to the depth/flow maps), we can directly take the flow gt 135 | ts_old = self.timestamp_files[set][subset][idx] 136 | ts_new = self.timestamp_files[set][subset][idx + 1] 137 | 138 | if self.update_rate == 20: 139 | flow = get_flow_npy(os.path.join(path_dataset,TEMPLATES.MVSEC_FLOW_GT_FILE.format(idx))) 140 | # Else, we need to interpolate the flow 141 | elif self.update_rate == 45: 142 | flow = self.estimate_gt_flow(loader_idx, ts_old, ts_new) 143 | else: 144 | raise NotImplementedError 145 | 146 | 147 | # Either flow_x or flow_y has to be != 0 s.t. the flow is valid 148 | flow_valid = (flow[0]!=0) | (flow[1] != 0) 149 | # Additionally, the car hood (from row 193 to the bottom of the 260-row image) is not included in the GT, so it is invalid too. 150 | flow_valid[193:,:]=False 151 | 152 | return_dict = {'idx': idx, 153 | 'loader_idx': loader_idx, 154 | 'flow': torch.from_numpy(flow), 155 | 'gt_valid_mask': torch.from_numpy(numpy.stack([flow_valid]*2, axis=0)), 156 | "param_evc": {'height': self.image_height, 157 | 'width': self.image_width} 158 | } 159 | 160 | 161 | # Load Events 162 | if type == 'events': 163 | event_path_old = os.path.join(path_dataset, TEMPLATES.MVSEC_EVENTS_FILE.format('left', idx)) 164 | event_path_new = os.path.join(path_dataset, TEMPLATES.MVSEC_EVENTS_FILE.format('left', idx+1)) 165 | params = {'height': self.image_height, 'width': self.image_width} 166 | 167 | events_old = get_events(event_path_old) 168 | events_new = get_events(event_path_new) 169 | 170 | # Timestamp multiplier of 1e6 because the timestamps are saved as seconds and we're used to microseconds 171 | # This can be relevant for the voxel grid!
172 | ev_seq_old = EventSequence(events_old, params, timestamp_multiplier=1e6, convert_to_relative=True) 173 | ev_seq_new = EventSequence(events_new, params, timestamp_multiplier=1e6, convert_to_relative=True) 174 | return_dict['event_volume_new'] = self.voxel(ev_seq_new) 175 | return_dict['event_volume_old'] = self.voxel(ev_seq_old) 176 | if self.evaluation_type == 'sparse': 177 | seq = ev_seq_new.get_sequence_only() 178 | h = self.image_height 179 | w = self.image_width 180 | hist, _, _ = numpy.histogram2d(x=seq[:,1], y=seq[:,2], 181 | bins=(w,h), 182 | range=[[0,w], [0,h]]) 183 | hist = hist.transpose() 184 | ev_mask = hist > 0 185 | return_dict['gt_valid_mask'] = torch.from_numpy(numpy.stack([flow_valid & ev_mask]*2, axis=0)) 186 | elif type == 'frames': 187 | raise NotImplementedError 188 | else: 189 | raise Exception("Input Type not defined properly! Check config file.") 190 | 191 | # Check Timestamps 192 | ev = get_events(event_path_new).to_numpy() 193 | ts_ev_min = numpy.min(ev[:,0]) 194 | ts_ev_max = numpy.max(ev[:,0]) 195 | assert(ts_ev_min > ts_old and ts_ev_max <= ts_new) 196 | 197 | # plot images 198 | ''' 199 | from utils import visualization as visu 200 | from matplotlib import pyplot as plt 201 | 202 | # Justifying my choice of alignment: 203 | # 1) Flow[i] corresponds to Depth[i] 204 | depth_i = torch.tensor(numpy.load(os.path.join(path_dataset, TEMPLATES.MVSEC_DEPTH_GT_FILE.format(idx)))) 205 | plt.figure("depth i") 206 | plt.imshow(depth_i.numpy()) 207 | 208 | flow_visu = visu.visualize_optical_flow(flow, return_image=True)[0] 209 | plt.figure('Flow i') 210 | plt.imshow(flow_visu) 211 | 212 | # 2) The events are aligned to the depth 213 | # -> events[i] correspond to all events BEFORE depth i 214 | # -> events[i+1] correspond to all events AFTER depth i 215 | # This can be proven by the timestamps timestamp[i] corresponding to depth[i] 216 | ts_old = self.timestamp_files[set][subset][idx] 217 | ts_new = self.timestamp_files[set][subset][idx + 1] 
218 | 219 | 220 | ev = get_events(event_path_new).to_numpy() 221 | ts_ev_min = numpy.min(ev[:,0]) 222 | ts_ev_max = numpy.max(ev[:,0]) 223 | assert(ts_ev_min > ts_old and ts_ev_max <= ts_new) 224 | 225 | # -> Additionally, we can show this, if we plot the events of the first 5ms of events before the depth map 226 | # Remember: events[i] are all the events BEFORE the depth[i] 227 | 228 | event_path_i = os.path.join(path_dataset, TEMPLATES.MVSEC_EVENTS_FILE.format('left', idx)) 229 | ev_i = get_events(event_path_i).to_numpy() 230 | ts_i = self.timestamp_files[set][subset][idx] 231 | ev_inst_idx = ev_i[:,0] > ts_i - 0.005 232 | ev_inst = ev_i[ev_inst_idx] 233 | evv = visu.events_to_event_image(ev_inst, self.image_height, self.image_width) 234 | plt.figure("events_instantaneous") 235 | plt.imshow(evv.numpy().transpose(1,2,0)) 236 | # This should now match the depth_i 237 | 238 | # Hence, all misalignments are coming from the ground-truth itself. 239 | ''' 240 | 241 | return return_dict 242 | 243 | def estimate_gt_flow(self, loader_idx, ts_old, ts_new): 244 | # We need to estimate the flow between two timestamps. 245 | 246 | # First, get the dataset & subset 247 | set = self.dataset[loader_idx]['dataset_name'] 248 | subset = self.dataset[loader_idx]['subset_number'] 249 | path_flow = os.path.join(self.path_dataset, 250 | TEMPLATES.MVSEC_DATASET_FOLDER.format(set, subset)) 251 | 252 | assert ts_old >= self.timestamp_files_flow[set][subset].min(), \ 253 | 'Timestamp is smaller than the first flow timestamp' 254 | 255 | # Now, estimate the corresponding GT 256 | flow = mvsec_utils.estimate_corresponding_gt_flow(path_flow=path_flow, 257 | gt_timestamps=self.timestamp_files_flow[set][subset], 258 | start_time=ts_old, 259 | end_time=ts_new) 260 | # flow is a tuple of [H,W]. 
Stack it 261 | return numpy.stack(flow) 262 | 263 | @staticmethod 264 | def mvsec_time_conversion(timestamps): 265 | raise NotImplementedError 266 | 267 | def get_ts(self, path, i): 268 | try: 269 | f = open(path, "r") 270 | return float(f.readlines()[i]) 271 | except OSError: 272 | raise 273 | 274 | def get_image_width_height(self, type='event_camera'): 275 | if hasattr(self, 'cropper'): 276 | h = self.cropper.size[0] 277 | w = self.cropper.size[1] 278 | return h, w 279 | return self.image_height, self.image_width 280 | 281 | def get_events(self, loader_idx): 282 | # Get Events For Visualization Only!!! 283 | path_dataset = os.path.join(self.path_dataset,self.dataset[loader_idx]['dataset_name'] + "_" + str(self.dataset[loader_idx]['subset_number'])) 284 | params = {'height': self.image_height, 'width': self.image_width} 285 | i = self.dataset[loader_idx]['index'] 286 | path = os.path.join(path_dataset, TEMPLATES.MVSEC_EVENTS_FILE.format('left', i+1)) 287 | events = EventSequence(get_events(path), params).get_sequence_only() 288 | return events 289 | 290 | def __len__(self): 291 | return len(self.dataset) 292 | 293 | def __getitem__(self, idx, force_crop_window=None, force_flipping=None): 294 | if idx >= len(self): 295 | raise IndexError 296 | sample = self.get_data_sample(idx) 297 | # Center Crop Everything 298 | sample['flow'] = self.cropper(sample['flow']) 299 | sample['gt_valid_mask'] = self.cropper(sample['gt_valid_mask']) 300 | sample['event_volume_new'] = self.cropper(sample['event_volume_new']) 301 | sample['event_volume_old'] = self.cropper(sample['event_volume_old']) 302 | 303 | return sample 304 | 305 | class MvsecFlowRecurrent(Dataset): 306 | def __init__(self, args, type, path): 307 | super(MvsecFlowRecurrent, self).__init__() 308 | if type.lower() != 'test': 309 | self.sequence_length = args['sequence_length'] 310 | else: 311 | self.sequence_length = 1 312 | self.step_size = 1 313 | self.dataset = MvsecFlow(args, type, path=path) 314 | 315 | def 
__len__(self): 316 | return (len(self.dataset) - self.sequence_length) // self.step_size + 1 317 | 318 | def __getitem__(self, idx): 319 | # ----------------------------------------------------------------------------- # 320 | # Returns a list of event/frame/flow samples # 321 | # [ e_(j), ..., e_(j+sequence_length-1) ], with j = idx * step_size # 322 | # ----------------------------------------------------------------------------- # 323 | assert(idx >= 0) 324 | assert(idx < len(self)) 325 | sequence = [] 326 | j = idx * self.step_size 327 | 328 | flip = None 329 | crop_window = None 330 | 331 | for i in range(self.sequence_length): 332 | sequence.append(self.dataset.__getitem__(j + i, force_crop_window=crop_window, force_flipping=flip)) 333 | # Just Making Sure 334 | assert sequence[-1]['idx']-sequence[0]['idx'] == self.sequence_length-1 335 | return sequence 336 | 337 | def summary(self, logger): 338 | logger.write_line("================================== Dataloader Summary ====================================", True) 339 | logger.write_line("Loader Type:\t\t" + self.__class__.__name__ + " for {}".format(self.dataset.type), True) 340 | logger.write_line("Sequence Length:\t{}".format(self.sequence_length), True) 341 | logger.write_line("Step Size:\t\t{}".format(self.step_size), True) 342 | logger.write_line("Framerate:\t\t{}".format(self.dataset.update_rate), True) 343 | 344 | def get_image_width_height(self, type='event_camera'): 345 | return self.dataset.get_image_width_height(type) 346 | 347 | def get_events(self, loader_idx): 348 | return self.dataset.get_events(loader_idx) 349 | -------------------------------------------------------------------------------- /loader/utils.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import os 3 | import pandas 4 | from PIL import Image 5 | import random 6 | import torch 7 | from itertools import chain 8 | import h5py 9 | import json 10 | 11 | 12 | class EventSequence(object): 13 | def
__init__(self, dataframe, params, features=None, timestamp_multiplier=None, convert_to_relative=False): 14 | if isinstance(dataframe, pandas.DataFrame): 15 | self.feature_names = dataframe.columns.values 16 | self.features = dataframe.to_numpy() 17 | else: 18 | self.feature_names = numpy.array(['ts', 'x', 'y', 'p'], dtype=object) 19 | if features is None: 20 | self.features = numpy.zeros([1, 4]) 21 | else: 22 | self.features = features 23 | self.image_height = params['height'] 24 | self.image_width = params['width'] 25 | if not self.is_sorted(): 26 | self.sort_by_timestamp() 27 | if timestamp_multiplier is not None: 28 | self.features[:,0] *= timestamp_multiplier 29 | if convert_to_relative: 30 | self.absolute_time_to_relative() 31 | 32 | def get_sequence_only(self): 33 | return self.features 34 | 35 | def __len__(self): 36 | return len(self.features) 37 | 38 | def __add__(self, sequence): 39 | event_sequence = EventSequence(dataframe=None, 40 | features=numpy.concatenate([self.features, sequence.features]), 41 | params={'height': self.image_height, 42 | 'width': self.image_width}) 43 | return event_sequence 44 | 45 | def is_sorted(self): 46 | return numpy.all(self.features[:-1, 0] <= self.features[1:, 0]) 47 | 48 | def sort_by_timestamp(self): 49 | if len(self.features[:, 0]) > 0: 50 | sort_indices = numpy.argsort(self.features[:, 0]) 51 | self.features = self.features[sort_indices] 52 | 53 | def absolute_time_to_relative(self): 54 | """Transforms absolute time to time relative to the first event.""" 55 | start_ts = self.features[:,0].min() 56 | assert(start_ts == self.features[0,0]) 57 | self.features[:,0] -= start_ts 58 | 59 | 60 | def get_image(image_path): 61 | try: 62 | im = Image.open(image_path) 63 | # print(image_path) 64 | return numpy.array(im) 65 | except OSError: 66 | raise 67 | 68 | 69 | def get_events(event_path): 70 | # It's possible that there is no event file! 
(camera standing still) 71 | try: 72 | f = pandas.read_hdf(event_path, "myDataset") 73 | return f[['ts', 'x', 'y', 'p']] 74 | except OSError: 75 | print("No file " + event_path) 76 | print("Creating an array of zeros!") 77 | return 0 78 | 79 | 80 | def get_ts(path, i, type='int'): 81 | try: 82 | f = open(path, "r") 83 | if type == 'int': 84 | return int(f.readlines()[i]) 85 | elif type == 'double' or type == 'float': 86 | return float(f.readlines()[i]) 87 | except OSError: 88 | raise 89 | 90 | 91 | def get_batchsize(path_dataset): 92 | filepath = os.path.join(path_dataset, "cam0", "timestamps.txt") 93 | try: 94 | f = open(filepath, "r") 95 | return len(f.readlines()) 96 | except OSError: 97 | raise 98 | 99 | 100 | def get_batch(path_dataset, i): 101 | return 0 102 | 103 | 104 | def dataset_paths(dataset_name, path_dataset, subset_number=None): 105 | cameras = {'cam0': {}, 'cam1': {}, 'cam2': {}, 'cam3': {}} 106 | if subset_number is not None: 107 | dataset_name = dataset_name + "_" + str(subset_number) 108 | paths = {'dataset_folder': os.path.join(path_dataset, dataset_name)} 109 | 110 | # For every camera, define its path 111 | for camera in cameras: 112 | cameras[camera]['image_folder'] = os.path.join(paths['dataset_folder'], camera, 'image_raw') 113 | cameras[camera]['event_folder'] = os.path.join(paths['dataset_folder'], camera, 'events') 114 | cameras[camera]['disparity_folder'] = os.path.join(paths['dataset_folder'], camera, 'disparity_image') 115 | cameras[camera]['depth_folder'] = os.path.join(paths['dataset_folder'], camera, 'depthmap') 116 | cameras["timestamp_file"] = os.path.join(paths['dataset_folder'], 'cam0', 'timestamps.txt') 117 | cameras["image_type"] = ".png" 118 | cameras["event_type"] = ".h5" 119 | cameras["disparity_type"] = ".png" 120 | cameras["depth_type"] = ".tiff" 121 | cameras["indexing_type"] = "%0.6i" 122 | paths.update(cameras) 123 | return paths 124 | 125 | 126 | def get_indices(path_dataset, dataset, filter, shuffle=False): 127 | 
samples = [] 128 | for dataset_name in dataset: 129 | for subset in dataset[dataset_name]: 130 | # Get all the dataframe paths 131 | paths = dataset_paths(dataset_name, path_dataset, subset) 132 | 133 | # import timestamps 134 | ts = numpy.loadtxt(paths["timestamp_file"]) 135 | 136 | # frames = [] 137 | # For every timestamp, import according data 138 | for idx in eval(filter[dataset_name][str(subset)]): 139 | frame = {} 140 | frame['dataset_name'] = dataset_name 141 | frame['subset_number'] = subset 142 | frame['index'] = idx 143 | frame['timestamp'] = ts[idx] 144 | samples.append(frame) 145 | # shuffle dataset 146 | if shuffle: 147 | random.shuffle(samples) 148 | return samples 149 | 150 | 151 | def get_flow_h5(flow_path): 152 | scaling_factor = 0.05 # seconds/frame 153 | f = h5py.File(flow_path, 'r') 154 | height, width = int(f['header']['height']), int(f['header']['width']) 155 | assert(len(f['x']) == height*width) 156 | assert(len(f['y']) == height*width) 157 | x = numpy.array(f['x']).reshape([height,width])*scaling_factor 158 | y = numpy.array(f['y']).reshape([height,width])*scaling_factor 159 | return numpy.stack([x,y]) 160 | 161 | 162 | def get_flow_npy(flow_path): 163 | # Array 2,height, width 164 | # No scaling needed. 165 | return numpy.load(flow_path, allow_pickle=True) 166 | 167 | 168 | def get_pose(pose_path, index): 169 | pose = pandas.read_csv(pose_path, delimiter=',').loc[index].to_numpy() 170 | # Convert Timestamp to int (as all the other timestamps) 171 | pose[0] = int(pose[0]) 172 | return pose 173 | 174 | 175 | def load_config(path, datasets): 176 | config = {} 177 | for dataset_name in datasets: 178 | config[dataset_name] = {} 179 | for subset in datasets[dataset_name]: 180 | name = "{}_{}".format(dataset_name, subset) 181 | try: 182 | config[dataset_name][subset] = json.load(open(os.path.join(path, name, "config.json"))) 183 | except: 184 | print("Could not find config file for dataset" + dataset_name + "_" + str(subset) + 185 | ". 
Please check if the file 'config.json' is existing in the dataset-scene directory") 186 | raise 187 | return config 188 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["MKL_NUM_THREADS"] = "1" 3 | os.environ["NUMEXPR_NUM_THREADS"] = "1" 4 | os.environ["OMP_NUM_THREADS"] = "1" 5 | os.environ["NUMEXPR_MAX_THREADS"]="1" 6 | from loader.loader_mvsec_flow import * 7 | from loader.loader_dsec import * 8 | from utils.logger import * 9 | import utils.helper_functions as helper 10 | import json 11 | from torch.utils.data.dataloader import DataLoader 12 | from utils import visualization as visu 13 | import argparse 14 | from test import * 15 | import git 16 | import torch.nn 17 | from model import eraft 18 | 19 | def initialize_tester(config): 20 | # Warm Start 21 | if config['subtype'].lower() == 'warm_start': 22 | return TestRaftEventsWarm 23 | # Classic 24 | else: 25 | return TestRaftEvents 26 | 27 | def get_visualizer(args): 28 | # DSEC dataset 29 | if args.dataset.lower() == 'dsec': 30 | return visualization.DsecFlowVisualizer 31 | # MVSEC dataset 32 | else: 33 | return visualization.FlowVisualizerEvents 34 | 35 | def test(args): 36 | # Choose correct config file 37 | if args.dataset.lower()=='dsec': 38 | if args.type.lower()=='warm_start': 39 | config_path = 'config/dsec_warm_start.json' 40 | elif args.type.lower()=='standard': 41 | config_path = 'config/dsec_standard.json' 42 | else: 43 | raise Exception('Please provide a valid argument for --type. [warm_start/standard]') 44 | elif args.dataset.lower()=='mvsec': 45 | if args.frequency==20: 46 | config_path = 'config/mvsec_20.json' 47 | elif args.frequency==45: 48 | config_path = 'config/mvsec_45.json' 49 | else: 50 | raise Exception('Please provide a valid argument for --frequency. 
[20/45]') 51 | if args.type=='standard': 52 | raise NotImplementedError('Sorry, this is not implemented yet, please choose --type warm_start') 53 | else: 54 | raise Exception('Please provide a valid argument for --dataset. [dsec/mvsec]') 55 | 56 | 57 | # Load config file 58 | config = json.load(open(config_path)) 59 | # Create Save Folder 60 | save_path = helper.create_save_path(config['save_dir'].lower(), config['name'].lower()) 61 | print('Storing output in folder {}'.format(save_path)) 62 | # Copy config file to save dir 63 | json.dump(config, open(os.path.join(save_path, 'config.json'), 'w'), 64 | indent=4, sort_keys=False) 65 | # Logger 66 | logger = Logger(save_path) 67 | logger.initialize_file("test") 68 | 69 | # Instantiate Dataset 70 | # Case: DSEC Dataset 71 | additional_loader_returns = None 72 | if args.dataset.lower() == 'dsec': 73 | # Dsec Dataloading 74 | loader = DatasetProvider( 75 | dataset_path=Path(args.path), 76 | representation_type=RepresentationType.VOXEL, 77 | delta_t_ms=100, 78 | config=config, 79 | type=config['subtype'].lower(), 80 | visualize=args.visualize) 81 | loader.summary(logger) 82 | test_set = loader.get_test_dataset() 83 | additional_loader_returns = {'name_mapping_test': loader.get_name_mapping_test()} 84 | 85 | # Case: MVSEC Dataset 86 | else: 87 | if config['subtype'].lower() == 'standard': 88 | test_set = MvsecFlow( 89 | args = config["data_loader"]["test"]["args"], 90 | type='test', 91 | path=args.path 92 | ) 93 | elif config['subtype'].lower() == 'warm_start': 94 | test_set = MvsecFlowRecurrent( 95 | args = config["data_loader"]["test"]["args"], 96 | type='test', 97 | path=args.path 98 | ) 99 | else: 100 | raise NotImplementedError 101 | test_set.summary(logger) 102 | 103 | # Instantiate Dataloader 104 | test_set_loader = DataLoader(test_set, 105 | batch_size=config['data_loader']['test']['args']['batch_size'], 106 | shuffle=config['data_loader']['test']['args']['shuffle'], 107 | num_workers=args.num_workers, 108 | 
drop_last=True) 109 | 110 | # Load Model 111 | model = eraft.ERAFT( 112 | config=config, 113 | n_first_channels=config['data_loader']['test']['args']['num_voxel_bins'] 114 | ) 115 | # Load Checkpoint 116 | checkpoint = torch.load(config['test']['checkpoint']) 117 | model.load_state_dict(checkpoint['model']) 118 | 119 | # Get Visualizer 120 | visualizer = get_visualizer(args) 121 | 122 | # Initialize Tester 123 | test = initialize_tester(config) 124 | 125 | test = test( 126 | model=model, 127 | config=config, 128 | data_loader=test_set_loader, 129 | test_logger=logger, 130 | save_path=save_path, 131 | visualizer=visualizer, 132 | additional_args=additional_loader_returns 133 | ) 134 | 135 | test.summary() 136 | test._test() 137 | 138 | if __name__ == '__main__': 139 | config_path = "config/config_test.json" 140 | # Argument Parser 141 | parser = argparse.ArgumentParser() 142 | parser.add_argument('-p', '--path', type=str, help="Dataset path", required=True) 143 | parser.add_argument('-d', '--dataset', default="dsec", type=str, help="Which dataset to use: ([dsec]/mvsec)") 144 | parser.add_argument('-f', '--frequency', default=20, type=int, help="Evaluation frequency of MVSEC dataset ([20]/45) Hz") 145 | parser.add_argument('-t', '--type', default='warm_start', type=str, help="Evaluation type ([warm_start]/standard)") 146 | parser.add_argument('-v', '--visualize', action='store_true', help='Provide this argument s.t. DSEC results are visualized. 
MVSEC experiments are always visualized.') 147 | parser.add_argument('-n', '--num_workers', default=0, type=int, help='How many sub-processes to use for data loading') 148 | args = parser.parse_args() 149 | 150 | # Run Test Script 151 | test(args) 152 | -------------------------------------------------------------------------------- /model/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from model.utils import bilinear_sampler, coords_grid 4 | 5 | try: 6 | import alt_cuda_corr 7 | except ImportError: 8 | # alt_cuda_corr is not compiled 9 | pass 10 | 11 | 12 | class CorrBlock: 13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 14 | self.num_levels = num_levels 15 | self.radius = radius 16 | self.corr_pyramid = [] 17 | 18 | # all pairs correlation 19 | corr = CorrBlock.corr(fmap1, fmap2) 20 | 21 | batch, h1, w1, dim, h2, w2 = corr.shape 22 | corr = corr.reshape(batch*h1*w1, dim, h2, w2) 23 | 24 | self.corr_pyramid.append(corr) 25 | for i in range(self.num_levels-1): 26 | corr = F.avg_pool2d(corr, 2, stride=2) 27 | self.corr_pyramid.append(corr) 28 | 29 | def __call__(self, coords): 30 | r = self.radius 31 | coords = coords.permute(0, 2, 3, 1) 32 | batch, h1, w1, _ = coords.shape 33 | 34 | out_pyramid = [] 35 | for i in range(self.num_levels): 36 | corr = self.corr_pyramid[i] 37 | dx = torch.linspace(-r, r, 2*r+1) 38 | dy = torch.linspace(-r, r, 2*r+1) 39 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) 40 | 41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i 42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) 43 | coords_lvl = centroid_lvl + delta_lvl 44 | 45 | corr = bilinear_sampler(corr, coords_lvl) 46 | corr = corr.view(batch, h1, w1, -1) 47 | out_pyramid.append(corr) 48 | 49 | out = torch.cat(out_pyramid, dim=-1) 50 | return out.permute(0, 3, 1, 2).contiguous().float() 51 | 52 | @staticmethod 53 | def corr(fmap1, fmap2): 54 | batch,
dim, ht, wd = fmap1.shape 55 | fmap1 = fmap1.view(batch, dim, ht*wd) 56 | fmap2 = fmap2.view(batch, dim, ht*wd) 57 | 58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2) 59 | corr = corr.view(batch, ht, wd, 1, ht, wd) 60 | return corr / torch.sqrt(torch.tensor(dim).float()) 61 | -------------------------------------------------------------------------------- /model/eraft.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .update import BasicUpdateBlock 7 | from .extractor import BasicEncoder 8 | from .corr import CorrBlock 9 | from model.utils import coords_grid, upflow8 10 | from argparse import Namespace 11 | from utils.image_utils import ImagePadder 12 | 13 | try: 14 | autocast = torch.cuda.amp.autocast 15 | except: 16 | # dummy autocast for PyTorch < 1.6 17 | class autocast: 18 | def __init__(self, enabled): 19 | pass 20 | def __enter__(self): 21 | pass 22 | def __exit__(self, *args): 23 | pass 24 | 25 | 26 | def get_args(): 27 | # This is an adapter function that converts the arguments given in our config file to the format that ERAFT 28 | # expects. 
29 | args = Namespace(small=False, 30 | dropout=False, 31 | mixed_precision=False, 32 | clip=1.0) 33 | return args 34 | 35 | 36 | 37 | class ERAFT(nn.Module): 38 | def __init__(self, config, n_first_channels): 39 | # args: 40 | super(ERAFT, self).__init__() 41 | args = get_args() 42 | self.args = args 43 | self.image_padder = ImagePadder(min_size=32) 44 | self.subtype = config['subtype'].lower() 45 | 46 | assert (self.subtype == 'standard' or self.subtype == 'warm_start') 47 | 48 | self.hidden_dim = hdim = 128 49 | self.context_dim = cdim = 128 50 | args.corr_levels = 4 51 | args.corr_radius = 4 52 | 53 | # feature network, context network, and update block 54 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=0, 55 | n_first_channels=n_first_channels) 56 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=0, 57 | n_first_channels=n_first_channels) 58 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) 59 | 60 | def freeze_bn(self): 61 | for m in self.modules(): 62 | if isinstance(m, nn.BatchNorm2d): 63 | m.eval() 64 | 65 | def initialize_flow(self, img): 66 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 67 | N, C, H, W = img.shape 68 | coords0 = coords_grid(N, H//8, W//8).to(img.device) 69 | coords1 = coords_grid(N, H//8, W//8).to(img.device) 70 | 71 | # optical flow computed as difference: flow = coords1 - coords0 72 | return coords0, coords1 73 | 74 | def upsample_flow(self, flow, mask): 75 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 76 | N, _, H, W = flow.shape 77 | mask = mask.view(N, 1, 9, 8, 8, H, W) 78 | mask = torch.softmax(mask, dim=2) 79 | 80 | up_flow = F.unfold(8 * flow, [3,3], padding=1) 81 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 82 | 83 | up_flow = torch.sum(mask * up_flow, dim=2) 84 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 85 | return up_flow.reshape(N, 2, 8*H, 8*W) 86 | 87 | 88 | def 
forward(self, image1, image2, iters=12, flow_init=None, upsample=True): 89 | """ Estimate optical flow between pair of frames """ 90 | # Pad Image (for flawless up&downsampling) 91 | image1 = self.image_padder.pad(image1) 92 | image2 = self.image_padder.pad(image2) 93 | 94 | image1 = image1.contiguous() 95 | image2 = image2.contiguous() 96 | 97 | hdim = self.hidden_dim 98 | cdim = self.context_dim 99 | 100 | # run the feature network 101 | with autocast(enabled=self.args.mixed_precision): 102 | fmap1, fmap2 = self.fnet([image1, image2]) 103 | 104 | fmap1 = fmap1.float() 105 | fmap2 = fmap2.float() 106 | 107 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 108 | 109 | # run the context network 110 | with autocast(enabled=self.args.mixed_precision): 111 | if (self.subtype == 'standard' or self.subtype == 'warm_start'): 112 | cnet = self.cnet(image2) 113 | else: 114 | raise Exception 115 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 116 | net = torch.tanh(net) 117 | inp = torch.relu(inp) 118 | 119 | # Initialize Grids. First channel: x, 2nd channel: y. 
Image is just used to get the shape 120 | coords0, coords1 = self.initialize_flow(image1) 121 | 122 | if flow_init is not None: 123 | coords1 = coords1 + flow_init 124 | 125 | flow_predictions = [] 126 | for itr in range(iters): 127 | coords1 = coords1.detach() 128 | corr = corr_fn(coords1) # index correlation volume 129 | 130 | flow = coords1 - coords0 131 | with autocast(enabled=self.args.mixed_precision): 132 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) 133 | 134 | # F(t+1) = F(t) + \Delta(t) 135 | coords1 = coords1 + delta_flow 136 | 137 | # upsample predictions 138 | if up_mask is None: 139 | flow_up = upflow8(coords1 - coords0) 140 | else: 141 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 142 | 143 | flow_predictions.append(self.image_padder.unpad(flow_up)) 144 | 145 | return coords1 - coords0, flow_predictions 146 | -------------------------------------------------------------------------------- /model/extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | # from model.network_blocks import RecurrentResidualBlock 5 | from torch.nn import init 6 | 7 | class ResidualBlock(nn.Module): 8 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 9 | super(ResidualBlock, self).__init__() 10 | 11 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) 12 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 13 | self.relu = nn.ReLU(inplace=True) 14 | 15 | num_groups = planes // 8 16 | 17 | if norm_fn == 'group': 18 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 19 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 20 | if not stride == 1: 21 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 22 | 23 | elif norm_fn == 'batch': 24 | self.norm1 = nn.BatchNorm2d(planes) 25 | self.norm2 = 
nn.BatchNorm2d(planes) 26 | if not stride == 1: 27 | self.norm3 = nn.BatchNorm2d(planes) 28 | 29 | elif norm_fn == 'instance': 30 | self.norm1 = nn.InstanceNorm2d(planes) 31 | self.norm2 = nn.InstanceNorm2d(planes) 32 | if not stride == 1: 33 | self.norm3 = nn.InstanceNorm2d(planes) 34 | 35 | elif norm_fn == 'none': 36 | self.norm1 = nn.Sequential() 37 | self.norm2 = nn.Sequential() 38 | if not stride == 1: 39 | self.norm3 = nn.Sequential() 40 | 41 | if stride == 1: 42 | self.downsample = None 43 | 44 | else: 45 | self.downsample = nn.Sequential( 46 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 47 | 48 | 49 | def forward(self, x): 50 | y = x 51 | y = self.relu(self.norm1(self.conv1(y))) 52 | y = self.relu(self.norm2(self.conv2(y))) 53 | 54 | if self.downsample is not None: 55 | x = self.downsample(x) 56 | 57 | return self.relu(x+y) 58 | 59 | 60 | class BottleneckBlock(nn.Module): 61 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 62 | super(BottleneckBlock, self).__init__() 63 | 64 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) 65 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) 66 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) 67 | self.relu = nn.ReLU(inplace=True) 68 | 69 | num_groups = planes // 8 70 | 71 | if norm_fn == 'group': 72 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 73 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 74 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 75 | if not stride == 1: 76 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 77 | 78 | elif norm_fn == 'batch': 79 | self.norm1 = nn.BatchNorm2d(planes//4) 80 | self.norm2 = nn.BatchNorm2d(planes//4) 81 | self.norm3 = nn.BatchNorm2d(planes) 82 | if not stride == 1: 83 | self.norm4 = nn.BatchNorm2d(planes) 84 | 85 | elif norm_fn == 'instance': 86 | 
self.norm1 = nn.InstanceNorm2d(planes//4) 87 | self.norm2 = nn.InstanceNorm2d(planes//4) 88 | self.norm3 = nn.InstanceNorm2d(planes) 89 | if not stride == 1: 90 | self.norm4 = nn.InstanceNorm2d(planes) 91 | 92 | elif norm_fn == 'none': 93 | self.norm1 = nn.Sequential() 94 | self.norm2 = nn.Sequential() 95 | self.norm3 = nn.Sequential() 96 | if not stride == 1: 97 | self.norm4 = nn.Sequential() 98 | 99 | if stride == 1: 100 | self.downsample = None 101 | 102 | else: 103 | self.downsample = nn.Sequential( 104 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) 105 | 106 | 107 | def forward(self, x): 108 | y = x 109 | y = self.relu(self.norm1(self.conv1(y))) 110 | y = self.relu(self.norm2(self.conv2(y))) 111 | y = self.relu(self.norm3(self.conv3(y))) 112 | 113 | if self.downsample is not None: 114 | x = self.downsample(x) 115 | 116 | return self.relu(x+y) 117 | 118 | 119 | class BasicEncoder(nn.Module): 120 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0, n_first_channels=1): 121 | super(BasicEncoder, self).__init__() 122 | self.norm_fn = norm_fn 123 | 124 | if self.norm_fn == 'group': 125 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 126 | 127 | elif self.norm_fn == 'batch': 128 | self.norm1 = nn.BatchNorm2d(64) 129 | 130 | elif self.norm_fn == 'instance': 131 | self.norm1 = nn.InstanceNorm2d(64) 132 | 133 | elif self.norm_fn == 'none': 134 | self.norm1 = nn.Sequential() 135 | 136 | self.conv1 = nn.Conv2d(n_first_channels, 64, kernel_size=7, stride=2, padding=3) 137 | self.relu1 = nn.ReLU(inplace=True) 138 | 139 | self.in_planes = 64 140 | self.layer1 = self._make_layer(64, stride=1) 141 | self.layer2 = self._make_layer(96, stride=2) 142 | self.layer3 = self._make_layer(128, stride=2) 143 | 144 | # output convolution 145 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) 146 | 147 | self.dropout = None 148 | if dropout > 0: 149 | self.dropout = nn.Dropout2d(p=dropout) 150 | 151 | for m in self.modules(): 152 
| if isinstance(m, nn.Conv2d): 153 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 154 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 155 | if m.weight is not None: 156 | nn.init.constant_(m.weight, 1) 157 | if m.bias is not None: 158 | nn.init.constant_(m.bias, 0) 159 | 160 | def _make_layer(self, dim, stride=1): 161 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) 162 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 163 | layers = (layer1, layer2) 164 | 165 | self.in_planes = dim 166 | return nn.Sequential(*layers) 167 | 168 | def forward(self, x): 169 | # if input is list, combine batch dimension 170 | is_list = isinstance(x, tuple) or isinstance(x, list) 171 | if is_list: 172 | batch_dim = x[0].shape[0] 173 | x = torch.cat(x, dim=0) 174 | 175 | x = self.conv1(x) 176 | x = self.norm1(x) 177 | x = self.relu1(x) 178 | x = self.layer1(x) 179 | x = self.layer2(x) 180 | x = self.layer3(x) 181 | 182 | x = self.conv2(x) 183 | 184 | if self.training and self.dropout is not None: 185 | x = self.dropout(x) 186 | 187 | if is_list: 188 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 189 | return x 190 | -------------------------------------------------------------------------------- /model/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FlowHead(nn.Module): 7 | def __init__(self, input_dim=128, hidden_dim=256): 8 | super(FlowHead, self).__init__() 9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 11 | self.relu = nn.ReLU(inplace=True) 12 | 13 | def forward(self, x): 14 | return self.conv2(self.relu(self.conv1(x))) 15 | 16 | class ConvGRU(nn.Module): 17 | def __init__(self, hidden_dim=128, input_dim=192+128): 18 | super(ConvGRU, self).__init__() 19 | self.convz = 
nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 22 | 23 | def forward(self, h, x): 24 | hx = torch.cat([h, x], dim=1) 25 | 26 | z = torch.sigmoid(self.convz(hx)) 27 | r = torch.sigmoid(self.convr(hx)) 28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) 29 | 30 | h = (1-z) * h + z * q 31 | return h 32 | 33 | class SepConvGRU(nn.Module): 34 | def __init__(self, hidden_dim=128, input_dim=192+128): 35 | super(SepConvGRU, self).__init__() 36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 39 | 40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 43 | 44 | 45 | def forward(self, h, x): 46 | # horizontal 47 | hx = torch.cat([h, x], dim=1) 48 | z = torch.sigmoid(self.convz1(hx)) 49 | r = torch.sigmoid(self.convr1(hx)) 50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 51 | h = (1-z) * h + z * q 52 | 53 | # vertical 54 | hx = torch.cat([h, x], dim=1) 55 | z = torch.sigmoid(self.convz2(hx)) 56 | r = torch.sigmoid(self.convr2(hx)) 57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 58 | h = (1-z) * h + z * q 59 | 60 | return h 61 | 62 | 63 | class BasicMotionEncoder(nn.Module): 64 | def __init__(self, args): 65 | super(BasicMotionEncoder, self).__init__() 66 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 67 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 68 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 69 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 70 | 
self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 71 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) 72 | 73 | def forward(self, flow, corr): 74 | cor = F.relu(self.convc1(corr)) 75 | cor = F.relu(self.convc2(cor)) 76 | flo = F.relu(self.convf1(flow)) 77 | flo = F.relu(self.convf2(flo)) 78 | 79 | cor_flo = torch.cat([cor, flo], dim=1) 80 | out = F.relu(self.conv(cor_flo)) 81 | return torch.cat([out, flow], dim=1) 82 | 83 | 84 | class BasicUpdateBlock(nn.Module): 85 | def __init__(self, args, hidden_dim=128, input_dim=128): 86 | super(BasicUpdateBlock, self).__init__() 87 | self.args = args 88 | self.encoder = BasicMotionEncoder(args) 89 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) 90 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 91 | 92 | self.mask = nn.Sequential( 93 | nn.Conv2d(128, 256, 3, padding=1), 94 | nn.ReLU(inplace=True), 95 | nn.Conv2d(256, 64*9, 1, padding=0)) 96 | 97 | def forward(self, net, inp, corr, flow, upsample=True): 98 | motion_features = self.encoder(flow, corr) 99 | inp = torch.cat([inp, motion_features], dim=1) 100 | 101 | net = self.gru(net, inp) 102 | delta_flow = self.flow_head(net) 103 | 104 | # scale mask to balance gradients 105 | mask = .25 * self.mask(net) 106 | return net, mask, delta_flow 107 | 108 | 109 | 110 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | 7 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 8 | """ Wrapper for grid_sample, uses pixel coordinates """ 9 | H, W = img.shape[-2:] 10 | xgrid, ygrid = coords.split([1,1], dim=-1) 11 | xgrid = 2*xgrid/(W-1) - 1 12 | ygrid = 2*ygrid/(H-1) - 1 13 | 14 | grid = torch.cat([xgrid, ygrid], dim=-1) 15 | img = F.grid_sample(img, grid, align_corners=True) 16 | 17 | if mask: 18 | 
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 19 | return img, mask.float() 20 | 21 | return img 22 | 23 | 24 | def coords_grid(batch, ht, wd): 25 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 26 | coords = torch.stack(coords[::-1], dim=0).float() 27 | return coords[None].repeat(batch, 1, 1, 1) 28 | 29 | 30 | def upflow8(flow, mode='bilinear'): 31 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 32 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 33 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torchvision import utils 4 | from utils.helper_functions import * 5 | import utils.visualization as visualization 6 | import utils.filename_templates as TEMPLATES 7 | import utils.helper_functions as helper 8 | import utils.logger as logger 9 | from utils import image_utils 10 | 11 | class Test(object): 12 | """ 13 | Test class 14 | 15 | """ 16 | 17 | def __init__(self, model, config, 18 | data_loader, visualizer, test_logger=None, save_path=None, additional_args=None): 19 | self.downsample = False # Downsampling for Rebuttal 20 | self.model = model 21 | self.config = config 22 | self.data_loader = data_loader 23 | self.additional_args = additional_args 24 | if config['cuda'] and not torch.cuda.is_available(): 25 | print('Warning: There\'s no CUDA support on this machine, ' 26 | 'training is performed on CPU.') 27 | else: 28 | self.gpu = torch.device('cuda:' + str(config['gpu'])) 29 | self.model = self.model.to(self.gpu) 30 | if save_path is None: 31 | self.save_path = helper.create_save_path(config['save_dir'].lower(), 32 | config['name'].lower()) 33 | else: 34 | self.save_path=save_path 35 | if test_logger is None: 36 | self.logger = logger.Logger(self.save_path) 37 | else: 38 | self.logger = test_logger 39 | if 
isinstance(self.additional_args, dict) and 'name_mapping_test' in self.additional_args.keys(): 40 | visu_add_args = {'name_mapping' : self.additional_args['name_mapping_test']} 41 | else: 42 | visu_add_args = None 43 | self.visualizer = visualizer(data_loader, self.save_path, additional_args=visu_add_args) 44 | 45 | def summary(self): 46 | self.logger.write_line("====================================== TEST SUMMARY ======================================", True) 47 | self.logger.write_line("Model:\t\t\t" + self.model.__class__.__name__, True) 48 | self.logger.write_line("Tester:\t\t" + self.__class__.__name__, True) 49 | self.logger.write_line("Test Set:\t" + self.data_loader.dataset.__class__.__name__, True) 50 | self.logger.write_line("\t-Dataset length:\t"+str(len(self.data_loader)), True) 51 | self.logger.write_line("\t-Batch size:\t\t" + str(self.data_loader.batch_size), True) 52 | self.logger.write_line("==========================================================================================", True) 53 | 54 | def run_network(self, epoch): 55 | raise NotImplementedError 56 | 57 | def move_batch_to_cuda(self, batch): 58 | raise NotImplementedError 59 | 60 | def visualize_sample(self, batch): 61 | self.visualizer(batch) 62 | 63 | def visualize_sample_dsec(self, batch, batch_idx): 64 | self.visualizer(batch, batch_idx, None) 65 | 66 | def get_estimation_and_target(self, batch): 67 | # Returns the estimation and target of the current batch 68 | raise NotImplementedError 69 | 70 | def _test(self): 71 | """ 72 | Validate after training an epoch 73 | 74 | :return: A log that contains information about validation 75 | 76 | Note: 77 | The validation metrics in log must have the key 'val_metrics'. 
78 | """ 79 | self.model.eval() 80 | with torch.no_grad(): 81 | for batch_idx, batch in enumerate(self.data_loader): 82 | # Move Data to GPU 83 | if next(self.model.parameters()).is_cuda: 84 | batch = self.move_batch_to_cuda(batch) 85 | # Network Forward Pass 86 | self.run_network(batch) 87 | print("Sample {}/{}".format(batch_idx + 1, len(self.data_loader))) 88 | 89 | # Visualize 90 | if hasattr(batch, 'keys') and 'loader_idx' in batch.keys() \ 91 | or (isinstance(batch,list) and hasattr(batch[0], 'keys') and 'loader_idx' in batch[0].keys()): 92 | self.visualize_sample(batch) 93 | else: 94 | # DSEC Special Snowflake 95 | self.visualize_sample_dsec(batch, batch_idx) 96 | #print('Not Visualizing') 97 | 98 | # Log Generation 99 | log = {} 100 | 101 | return log 102 | 103 | class TestRaftEvents(Test): 104 | def move_batch_to_cuda(self, batch): 105 | return move_dict_to_cuda(batch, self.gpu) 106 | 107 | def get_estimation_and_target(self, batch): 108 | if not self.downsample: 109 | if 'gt_valid_mask' in batch.keys(): 110 | return batch['flow_est'].cpu().data, (batch['flow'].cpu().data, batch['gt_valid_mask'].cpu().data) 111 | return batch['flow_est'].cpu().data, batch['flow'].cpu().data 112 | else: 113 | f_est = batch['flow_est'].cpu().data 114 | f_gt = torch.nn.functional.interpolate(batch['flow'].cpu().data, scale_factor=0.5) 115 | if 'gt_valid_mask' in batch.keys(): 116 | f_mask = torch.nn.functional.interpolate(batch['gt_valid_mask'].cpu().data, scale_factor=0.5) 117 | return f_est, (f_gt, f_mask) 118 | return f_est, f_gt 119 | 120 | def run_network(self, batch): 121 | # RAFT just expects two images as input. cleanest. code. ever. 
122 | if not self.downsample: 123 | im1 = batch['event_volume_old'] 124 | im2 = batch['event_volume_new'] 125 | else: 126 | im1 = torch.nn.functional.interpolate(batch['event_volume_old'], scale_factor=0.5) 127 | im2 = torch.nn.functional.interpolate(batch['event_volume_new'], scale_factor=0.5) 128 | _, batch['flow_list'] = self.model(image1=im1, 129 | image2=im2) 130 | batch['flow_est'] = batch['flow_list'][-1] 131 | 132 | class TestRaftEventsWarm(Test): 133 | def __init__(self, model, config, 134 | data_loader, visualizer, test_logger=None, save_path=None, additional_args=None): 135 | super(TestRaftEventsWarm, self).__init__(model, config, 136 | data_loader, visualizer, test_logger, save_path, 137 | additional_args=additional_args) 138 | self.subtype = config['subtype'].lower() 139 | print('Tester Subtype: {}'.format(self.subtype)) 140 | self.net_init = None # Hidden state of the refinement GRU 141 | self.flow_init = None 142 | self.idx_prev = None 143 | self.init_print=False 144 | assert self.data_loader.batch_size == 1, 'Batch size for recurrent testing must be 1' 145 | 146 | def move_batch_to_cuda(self, batch): 147 | return move_list_to_cuda(batch, self.gpu) 148 | 149 | def get_estimation_and_target(self, batch): 150 | if not self.downsample: 151 | if 'gt_valid_mask' in batch[-1].keys(): 152 | return batch[-1]['flow_est'].cpu().data, (batch[-1]['flow'].cpu().data, batch[-1]['gt_valid_mask'].cpu().data) 153 | return batch[-1]['flow_est'].cpu().data, batch[-1]['flow'].cpu().data 154 | else: 155 | f_est = batch[-1]['flow_est'].cpu().data 156 | f_gt = torch.nn.functional.interpolate(batch[-1]['flow'].cpu().data, scale_factor=0.5) 157 | if 'gt_valid_mask' in batch[-1].keys(): 158 | f_mask = torch.nn.functional.interpolate(batch[-1]['gt_valid_mask'].cpu().data, scale_factor=0.5) 159 | return f_est, (f_gt, f_mask) 160 | return f_est, f_gt 161 | 162 | def visualize_sample(self, batch): 163 | self.visualizer(batch[-1]) 164 | 165 | def visualize_sample_dsec(self, batch, 
batch_idx): 166 | self.visualizer(batch[-1], batch_idx, None) 167 | 168 | def check_states(self, batch): 169 | # 0th case: there is a flag in the batch that tells us to reset the state (DSEC) 170 | if 'new_sequence' in batch[0].keys(): 171 | if batch[0]['new_sequence'].item() == 1: 172 | self.flow_init = None 173 | self.net_init = None 174 | self.logger.write_line("Resetting States!", True) 175 | else: 176 | # During Validation, reset state if a new scene starts (index jump) 177 | if self.idx_prev is not None and batch[0]['idx'].item() - self.idx_prev != 1: 178 | self.flow_init = None 179 | self.net_init = None 180 | self.logger.write_line("Resetting States!", True) 181 | self.idx_prev = batch[0]['idx'].item() 182 | 183 | def run_network(self, batch): 184 | self.check_states(batch) 185 | for l in range(len(batch)): 186 | # Run Recurrent Network for this sample 187 | 188 | if not self.downsample: 189 | im1 = batch[l]['event_volume_old'] 190 | im2 = batch[l]['event_volume_new'] 191 | else: 192 | im1 = torch.nn.functional.interpolate(batch[l]['event_volume_old'], scale_factor=0.5) 193 | im2 = torch.nn.functional.interpolate(batch[l]['event_volume_new'], scale_factor=0.5) 194 | flow_low_res, batch[l]['flow_list'] = self.model(image1=im1, 195 | image2=im2, 196 | flow_init=self.flow_init) 197 | 198 | batch[l]['flow_est'] = batch[l]['flow_list'][-1] 199 | self.flow_init = image_utils.forward_interpolate_pytorch(flow_low_res) 200 | batch[l]['flow_init'] = self.flow_init 201 | -------------------------------------------------------------------------------- /utils/dsec_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from enum import Enum, auto 4 | 5 | 6 | class RepresentationType(Enum): 7 | VOXEL = auto() 8 | STEPAN = auto() 9 | 10 | 11 | class EventRepresentation: 12 | def __init__(self): 13 | pass 14 | 15 | def convert(self, events): 16 | raise NotImplementedError 17 | 18 | 19 | class 
VoxelGrid(EventRepresentation): 20 | def __init__(self, input_size: tuple, normalize: bool): 21 | assert len(input_size) == 3 22 | self.voxel_grid = torch.zeros((input_size), dtype=torch.float, requires_grad=False) 23 | self.nb_channels = input_size[0] 24 | self.normalize = normalize 25 | 26 | def convert(self, events): 27 | C, H, W = self.voxel_grid.shape 28 | with torch.no_grad(): 29 | self.voxel_grid = self.voxel_grid.to(events['p'].device) 30 | voxel_grid = self.voxel_grid.clone() 31 | 32 | t_norm = events['t'] 33 | t_norm = (C - 1) * (t_norm-t_norm[0]) / (t_norm[-1]-t_norm[0]) 34 | 35 | x0 = events['x'].int() 36 | y0 = events['y'].int() 37 | t0 = t_norm.int() 38 | 39 | value = 2*events['p']-1 40 | 41 | for xlim in [x0,x0+1]: 42 | for ylim in [y0,y0+1]: 43 | for tlim in [t0,t0+1]: 44 | 45 | mask = (xlim < W) & (xlim >= 0) & (ylim < H) & (ylim >= 0) & (tlim >= 0) & (tlim < self.nb_channels) 46 | interp_weights = value * (1 - (xlim-events['x']).abs()) * (1 - (ylim-events['y']).abs()) * (1 - (tlim - t_norm).abs()) 47 | 48 | index = H * W * tlim.long() + \ 49 | W * ylim.long() + \ 50 | xlim.long() 51 | 52 | voxel_grid.put_(index[mask], interp_weights[mask], accumulate=True) 53 | 54 | if self.normalize: 55 | mask = torch.nonzero(voxel_grid, as_tuple=True) 56 | if mask[0].size()[0] > 0: 57 | mean = voxel_grid[mask].mean() 58 | std = voxel_grid[mask].std() 59 | if std > 0: 60 | voxel_grid[mask] = (voxel_grid[mask] - mean) / std 61 | else: 62 | voxel_grid[mask] = voxel_grid[mask] - mean 63 | 64 | return voxel_grid 65 | 66 | def flow_16bit_to_float(flow_16bit: np.ndarray): 67 | assert flow_16bit.dtype == np.uint16 68 | assert flow_16bit.ndim == 3 69 | h, w, c = flow_16bit.shape 70 | assert c == 3 71 | 72 | valid2D = flow_16bit[..., 2] == 1 73 | assert valid2D.shape == (h, w) 74 | assert np.all(flow_16bit[~valid2D, -1] == 0) 75 | valid_map = np.where(valid2D) 76 | 77 | # to actually compute something useful: 78 | flow_16bit = flow_16bit.astype('float') 79 | 80 | flow_map 
= np.zeros((h, w, 2)) 81 | flow_map[valid_map[0], valid_map[1], 0] = (flow_16bit[valid_map[0], valid_map[1], 0] - 2 ** 15) / 128 82 | flow_map[valid_map[0], valid_map[1], 1] = (flow_16bit[valid_map[0], valid_map[1], 1] - 2 ** 15) / 128 83 | return flow_map, valid2D 84 | -------------------------------------------------------------------------------- /utils/filename_templates.py: -------------------------------------------------------------------------------- 1 | # =========================== Templates for saving Images ========================= # 2 | GT_FLOW = '{}_{}_flow_gt.png' 3 | FLOW_TEST = '{}_{}_flow.png' 4 | IMG = '{}_{}_image.png' 5 | EVENTS = '{}_{}_events.png' 6 | 7 | # ========================= Templates for saving Checkpoints ====================== # 8 | CHECKPOINT = '{:03d}_checkpoint.tar' 9 | 10 | # ========================= MVSEC DATALOADING ====================== # 11 | MVSEC_DATASET_FOLDER = '{}_{}' 12 | MVSEC_TIMESTAMPS_PATH_DEPTH = 'timestamps_depth.txt' 13 | MVSEC_TIMESTAMPS_PATH_FLOW = 'timestamps_flow.txt' 14 | MVSEC_TIMESTAMPS_PATH_IMAGES = 'timestamps_images.txt' 15 | MVSEC_EVENTS_FILE = "davis/{}/events/{:06d}.h5" 16 | MVSEC_FLOW_GT_FILE = "optical_flow/{:06d}.npy" 17 | -------------------------------------------------------------------------------- /utils/helper_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import smtplib 4 | import json 5 | 6 | def move_dict_to_cuda(dictionary_of_tensors, gpu): 7 | if isinstance(dictionary_of_tensors, dict): 8 | return { 9 | key: move_dict_to_cuda(value, gpu) 10 | for key, value in dictionary_of_tensors.items() 11 | } 12 | return dictionary_of_tensors.to(gpu, dtype=torch.float) 13 | 14 | def move_list_to_cuda(list_of_dicts, gpu): 15 | for i in range(len(list_of_dicts)): 16 | list_of_dicts[i] = move_dict_to_cuda(list_of_dicts[i], gpu) 17 | return list_of_dicts 18 | 19 | def get_values_from_key(input_list, key): 20 | # 
Returns all the values with the same key from 21 | # a list filled with dicts of the same kind 22 | out = [] 23 | for i in input_list: 24 | out.append(i[key]) 25 | return out 26 | 27 | def create_save_path(subdir, name): 28 | # Check if sub-folder exists, and create if necessary 29 | if not os.path.exists(subdir): 30 | os.mkdir(subdir) 31 | # Create a new folder (named after the name defined in the config file) 32 | path = os.path.join(subdir, name) 33 | # Check if path already exists. if yes -> append a number 34 | if os.path.exists(path): 35 | i = 1 36 | while os.path.exists(path + "_" + str(i)): 37 | i += 1 38 | path = path + '_' + str(i) 39 | os.mkdir(path) 40 | return path 41 | 42 | def get_nth_element_of_all_dict_keys(dict, idx): 43 | out_dict = {} 44 | for k in dict.keys(): 45 | d = dict[k][idx] 46 | if isinstance(d,torch.Tensor): 47 | out_dict[k]=d.detach().cpu().item() 48 | else: 49 | out_dict[k]=d 50 | return out_dict 51 | 52 | def get_number_of_saved_elements(path, template, first=1): 53 | i = first 54 | while True: 55 | if os.path.exists(os.path.join(path,template.format(i))): 56 | i+=1 57 | else: 58 | break 59 | return range(first, i) 60 | 61 | def create_file_path(subdir, name): 62 | # Check if sub-folder exists, else raise exception 63 | if not os.path.exists(subdir): 64 | raise Exception("Path {} does not exist!".format(subdir)) 65 | # Check if file already exists, else create path 66 | if not os.path.exists(os.path.join(subdir,name)): 67 | return os.path.join(subdir,name) 68 | else: 69 | path = os.path.join(subdir,name) 70 | prefix,suffix = path.split('.') 71 | i = 1 72 | while os.path.exists("{}_{}.{}".format(prefix,i,suffix)): 73 | i += 1 74 | return "{}_{}.{}".format(prefix,i,suffix) 75 | 76 | def update_dict(dict_old, dict_new): 77 | # Update all the entries of dict_old with the new values(that have the identical keys) of dict_new 78 | for k in dict_new.keys(): 79 | if k in dict_old.keys(): 80 | # Replace the entry 81 | if 
isinstance(dict_new[k], dict): 82 | update_dict(dict_old[k], dict_new[k]) 83 | else: 84 | dict_old[k] = dict_new[k] 85 | return dict_old 86 | -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import torch 3 | from torch import nn 4 | from torch.nn.functional import grid_sample 5 | from scipy.spatial import transform 6 | from scipy import interpolate 7 | from matplotlib import pyplot as plt 8 | 9 | 10 | def grid_sample_values(input, height, width): 11 | # ================================ Grid Sample Values ============================= # 12 | # Input: Torch Tensor [3,H*W], where the 3 Dimensions mean [x,y,z] # 13 | # Height: Image Height # 14 | # Width: Image Width # 15 | # --------------------------------------------------------------------------------- # 16 | # Output: tuple(value_ipl, valid_mask) # 17 | # value_ipl -> [H,W]: Interpolated values # 18 | # valid_mask -> [H,W]: 1: Point is valid, 0: Point is invalid # 19 | # ================================================================================= # 20 | device = input.device 21 | ceil = torch.stack([torch.ceil(input[0,:]), torch.ceil(input[1,:]), input[2,:]]) 22 | floor = torch.stack([torch.floor(input[0,:]), torch.floor(input[1,:]), input[2,:]]) 23 | z = input[2,:].clone() 24 | 25 | values_ipl = torch.zeros(height*width, device=device) 26 | weights_acc = torch.zeros(height*width, device=device) 27 | # Iterate over all ceil/floor points 28 | for x_vals in [floor[0], ceil[0]]: 29 | for y_vals in [floor[1], ceil[1]]: 30 | # Mask Points that are in the image 31 | in_bounds_mask = (x_vals < width) & (x_vals >=0) & (y_vals < height) & (y_vals >= 0) 32 | 33 | # Calculate weights, according to their real distance to the floored/ceiled value 34 | weights = (1 - (input[0]-x_vals).abs()) * (1 - (input[1]-y_vals).abs()) 35 | 36 | # Put them into the right grid 37 | 
indices = (x_vals + width * y_vals).long() 38 | values_ipl.put_(indices[in_bounds_mask], (z * weights)[in_bounds_mask], accumulate=True) 39 | weights_acc.put_(indices[in_bounds_mask], weights[in_bounds_mask], accumulate=True) 40 | 41 | # Mask of valid pixels -> Everywhere where we have an interpolated value 42 | valid_mask = weights_acc.clone() 43 | valid_mask[valid_mask > 0] = 1 44 | valid_mask= valid_mask.bool().reshape([height,width]) 45 | 46 | # Divide by weights to get interpolated values 47 | values_ipl = values_ipl / (weights_acc + 1e-15) 48 | values_rs = values_ipl.reshape([height,width]) 49 | 50 | return values_rs.unsqueeze(0).clone(), valid_mask.unsqueeze(0).clone() 51 | 52 | def forward_interpolate_pytorch(flow_in): 53 | # Same as the numpy implementation, but differentiable :) 54 | # Flow: [B,2,H,W] 55 | flow = flow_in.clone() 56 | if len(flow.shape) < 4: 57 | flow = flow.unsqueeze(0) 58 | 59 | b, _, h, w = flow.shape 60 | device = flow.device 61 | 62 | dx ,dy = flow[:,0], flow[:,1] 63 | y0, x0 = torch.meshgrid(torch.arange(0, h, 1), torch.arange(0, w, 1)) 64 | x0 = torch.stack([x0]*b).to(device) 65 | y0 = torch.stack([y0]*b).to(device) 66 | 67 | x1 = x0 + dx 68 | y1 = y0 + dy 69 | 70 | x1 = x1.flatten(start_dim=1) 71 | y1 = y1.flatten(start_dim=1) 72 | dx = dx.flatten(start_dim=1) 73 | dy = dy.flatten(start_dim=1) 74 | 75 | # Interpolate Griddata... 76 | # Note that a Nearest Neighbor Interpolation would be better. But there does not exist a pytorch fcn yet. 77 | # See issue: https://github.com/pytorch/pytorch/issues/50339 78 | flow_new = torch.zeros(flow.shape, device=device) 79 | for i in range(b): 80 | flow_new[i,0] = grid_sample_values(torch.stack([x1[i],y1[i],dx[i]]), h, w)[0] 81 | flow_new[i,1] = grid_sample_values(torch.stack([x1[i],y1[i],dy[i]]), h, w)[0] 82 | 83 | return flow_new 84 | 85 | class ImagePadder(object): 86 | # =================================================================== # 87 | # In some networks, the image gets downsized. 
This is a problem, if # 88 | # the to-be-downsized image has odd dimensions ([15x20]->[7.5x10]). # 89 | # To prevent this, the input image of the network needs to be a # 90 | # multiple of a minimum size (min_size) # 91 | # The ImagePadder makes sure, that the input image is of such a size, # 92 | # and if not, it pads the image accordingly. # 93 | # =================================================================== # 94 | 95 | def __init__(self, min_size=64): 96 | # --------------------------------------------------------------- # 97 | # The min_size additionally ensures, that the smallest image # 98 | # does not get too small # 99 | # --------------------------------------------------------------- # 100 | self.min_size = min_size 101 | self.pad_height = None 102 | self.pad_width = None 103 | 104 | def pad(self, image): 105 | # --------------------------------------------------------------- # 106 | # If necessary, this function pads the image on the left & top # 107 | # --------------------------------------------------------------- # 108 | height, width = image.shape[-2:] 109 | if self.pad_width is None: 110 | self.pad_height = (self.min_size - height % self.min_size)%self.min_size 111 | self.pad_width = (self.min_size - width % self.min_size)%self.min_size 112 | else: 113 | pad_height = (self.min_size - height % self.min_size)%self.min_size 114 | pad_width = (self.min_size - width % self.min_size)%self.min_size 115 | if pad_height != self.pad_height or pad_width != self.pad_width: 116 | raise ValueError('Image size requires a different padding than the one computed on the first call.') 117 | return nn.ZeroPad2d((self.pad_width, 0, self.pad_height, 0))(image) 118 | 119 | def unpad(self, image): 120 | # --------------------------------------------------------------- # 121 | # Removes the padded rows & columns # 122 | # --------------------------------------------------------------- # 123 | return image[..., self.pad_height:, self.pad_width:] 124 | -------------------------------------------------------------------------------- /utils/logger.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy 4 | import shutil 5 | 6 | class Logger: 7 | # Logger of the Training/Testing Process 8 | def __init__(self, save_path, custom_name='log.txt'): 9 | self.toWrite = {} 10 | self.signalization = "========================================" 11 | self.path = os.path.join(save_path,custom_name) 12 | 13 | def initialize_file(self, mode): 14 | # Mode : "Training" or "Testing" 15 | with open(self.path, 'a') as file: 16 | file.write(self.signalization + " " + mode + " " + self.signalization + "\n") 17 | 18 | def write_as_list(self, dict_to_write, overwrite=False): 19 | if overwrite: 20 | if os.path.exists(self.path): 21 | os.remove(self.path) 22 | with open(self.path, 'a') as file: 23 | for entry in dict_to_write.keys(): 24 | file.write(entry+"="+json.dumps(dict_to_write[entry])+"\n") 25 | 26 | def write_dict(self, dict_to_write, array_names=None, overwrite=False, as_list=False): 27 | if overwrite: 28 | open_type = 'w' 29 | else: 30 | open_type = 'a' 31 | dict_to_write = self.check_for_arrays(dict_to_write, array_names) 32 | if as_list: 33 | self.write_as_list(dict_to_write, overwrite) 34 | else: 35 | with open(self.path, open_type) as file: 36 | #if "epoch" in dict_to_write: 37 | # file.write("Epoch") 38 | file.write(json.dumps(dict_to_write) + "\n") 39 | 40 | def write_line(self,line, verbose=False): 41 | with open(self.path, 'a') as file: 42 | file.write(line + "\n") 43 | if verbose: 44 | print(line) 45 | 46 | def arrays_to_dicts(self, list_of_arrays, array_name, entry_name): 47 | list_of_arrays = numpy.array(list_of_arrays).T 48 | out = {} 49 | for i in range(list_of_arrays.shape[0]): 50 | out[array_name+'_'+entry_name[i]] = list(list_of_arrays[i]) 51 | return out 52 | 53 | 54 | def check_for_arrays(self, dict_to_write, array_names): 55 | if array_names is not None: 56 | names = [] 57 | for n in range(len(array_names)): 58 | if 
hasattr(array_names[n], 'name'): 59 | names.append(array_names[n].name) 60 | elif hasattr(array_names[n],'__name__'): 61 | names.append(array_names[n].__name__) 62 | elif hasattr(array_names[n],'__class__'): 63 | names.append(array_names[n].__class__.__name__) 64 | else: 65 | names.append(array_names[n]) 66 | 67 | keys = dict_to_write.keys() 68 | out = {} 69 | for entry in keys: 70 | if hasattr(dict_to_write[entry], '__len__') and len(dict_to_write[entry])>0: 71 | if isinstance(dict_to_write[entry][0], numpy.ndarray) or isinstance(dict_to_write[entry][0], list): 72 | out.update(self.arrays_to_dicts(dict_to_write[entry], entry, names)) 73 | else: 74 | out.update({entry:dict_to_write[entry]}) 75 | else: 76 | out.update({entry: dict_to_write[entry]}) 77 | return out 78 | -------------------------------------------------------------------------------- /utils/mvsec_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 # required by prop_flow (cv2.remap) 3 | from matplotlib import pyplot as plt 4 | import os 5 | from utils import filename_templates as TEMPLATES 6 | 7 | def prop_flow(x_flow, y_flow, x_indices, y_indices, x_mask, y_mask, scale_factor=1.0): 8 | flow_x_interp = cv2.remap(x_flow, 9 | x_indices, 10 | y_indices, 11 | cv2.INTER_NEAREST) 12 | 13 | flow_y_interp = cv2.remap(y_flow, 14 | x_indices, 15 | y_indices, 16 | cv2.INTER_NEAREST) 17 | 18 | x_mask[flow_x_interp == 0] = False 19 | y_mask[flow_y_interp == 0] = False 20 | 21 | x_indices += flow_x_interp * scale_factor 22 | y_indices += flow_y_interp * scale_factor 23 | 24 | return 25 | 26 | def estimate_corresponding_gt_flow(path_flow, 27 | gt_timestamps, 28 | start_time, 29 | end_time): 30 | # Each gt flow at timestamp gt_timestamps[gt_iter] represents the displacement between 31 | # gt_iter and gt_iter+1. 
32 | # gt_timestamps[gt_iter] -> Timestamp just before start_time 33 | 34 | gt_iter = np.searchsorted(gt_timestamps, start_time, side='right') - 1 35 | gt_dt = gt_timestamps[gt_iter + 1] - gt_timestamps[gt_iter] 36 | 37 | # Load Flow just before start_time 38 | flow_file = os.path.join(path_flow, TEMPLATES.MVSEC_FLOW_GT_FILE.format(gt_iter)) 39 | flow = np.load(flow_file) 40 | 41 | x_flow = flow[0] 42 | y_flow = flow[1] 43 | #x_flow = np.squeeze(x_flow_in[gt_iter, ...]) 44 | #y_flow = np.squeeze(y_flow_in[gt_iter, ...]) 45 | 46 | dt = end_time - start_time 47 | 48 | # No need to propagate if the desired dt is shorter than the time between gt timestamps. 49 | if gt_dt > dt: 50 | return x_flow * dt / gt_dt, y_flow * dt / gt_dt 51 | else: 52 | raise NotImplementedError('Requested dt is not shorter than the gt interval; flow propagation is not supported here.') 53 | -------------------------------------------------------------------------------- /utils/transformers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | def dictionary_of_numpy_arrays_to_tensors(sample): 5 | """Transforms dictionary of numpy arrays to dictionary of tensors.""" 6 | if isinstance(sample, dict): 7 | return { 8 | key: dictionary_of_numpy_arrays_to_tensors(value) 9 | for key, value in sample.items() 10 | } 11 | if isinstance(sample, np.ndarray): 12 | if len(sample.shape) == 2: 13 | return th.from_numpy(sample).float().unsqueeze(0) 14 | else: 15 | return th.from_numpy(sample).float() 16 | return sample 17 | 18 | class EventSequenceToVoxelGrid_Pytorch(object): 19 | # Source: https://github.com/uzh-rpg/rpg_e2vid/blob/master/utils/inference_utils.py#L480 20 | def __init__(self, num_bins, gpu=False, gpu_nr=0, normalize=True, forkserver=True): 21 | if forkserver: 22 | try: 23 | th.multiprocessing.set_start_method('forkserver') 24 | except RuntimeError: 25 | pass 26 | self.num_bins = num_bins 27 | self.normalize = normalize 28 | if gpu: 29 | if not th.cuda.is_available(): 30 | print('Warning: There\'s no CUDA support 
on this machine!') 31 | else: 32 | self.device = th.device('cuda:' + str(gpu_nr)) 33 | else: 34 | self.device = th.device('cpu') 35 | 36 | def __call__(self, event_sequence): 37 | """ 38 | Build a voxel grid with bilinear interpolation in the time domain from a set of events. 39 | :param events: a [N x 4] NumPy array containing one event per row in the form: [timestamp, x, y, polarity] 40 | :param num_bins: number of bins in the temporal axis of the voxel grid 41 | :param width, height: dimensions of the voxel grid 42 | :param device: device to use to perform computations 43 | :return voxel_grid: PyTorch event tensor (on the device specified) 44 | """ 45 | 46 | events = event_sequence.features.astype('float') 47 | 48 | width = event_sequence.image_width 49 | height = event_sequence.image_height 50 | 51 | assert (events.shape[1] == 4) 52 | assert (self.num_bins > 0) 53 | assert (width > 0) 54 | assert (height > 0) 55 | 56 | with th.no_grad(): 57 | 58 | events_torch = th.from_numpy(events) 59 | # with DeviceTimer('Events -> Device (voxel grid)'): 60 | events_torch = events_torch.to(self.device) 61 | 62 | # with DeviceTimer('Voxel grid voting'): 63 | voxel_grid = th.zeros(self.num_bins, height, width, dtype=th.float32, device=self.device).flatten() 64 | 65 | # normalize the event timestamps so that they lie between 0 and num_bins 66 | last_stamp = events_torch[-1, 0] 67 | first_stamp = events_torch[0, 0] 68 | 69 | assert last_stamp.dtype == th.float64, 'Timestamps must be float64!' 
70 | # assert last_stamp.item()%1 == 0, 'Timestamps should not have decimals' 71 | 72 | deltaT = last_stamp - first_stamp 73 | 74 | if deltaT == 0: 75 | deltaT = 1.0 76 | 77 | events_torch[:, 0] = (self.num_bins - 1) * (events_torch[:, 0] - first_stamp) / deltaT 78 | ts = events_torch[:, 0] 79 | xs = events_torch[:, 1].long() 80 | ys = events_torch[:, 2].long() 81 | pols = events_torch[:, 3].float() 82 | pols[pols == 0] = -1 # polarity should be +1 / -1 83 | 84 | 85 | tis = th.floor(ts) 86 | tis_long = tis.long() 87 | dts = ts - tis 88 | vals_left = pols * (1.0 - dts.float()) 89 | vals_right = pols * dts.float() 90 | 91 | valid_indices = tis < self.num_bins 92 | valid_indices &= tis >= 0 93 | 94 | if events_torch.is_cuda: 95 | datatype = th.cuda.LongTensor 96 | else: 97 | datatype = th.LongTensor 98 | 99 | voxel_grid.index_add_(dim=0, 100 | index=(xs[valid_indices] + ys[valid_indices] 101 | * width + tis_long[valid_indices] * width * height).type( 102 | datatype), 103 | source=vals_left[valid_indices]) 104 | 105 | 106 | valid_indices = (tis + 1) < self.num_bins 107 | valid_indices &= tis >= 0 108 | 109 | voxel_grid.index_add_(dim=0, 110 | index=(xs[valid_indices] + ys[valid_indices] * width 111 | + (tis_long[valid_indices] + 1) * width * height).type(datatype), 112 | source=vals_right[valid_indices]) 113 | 114 | voxel_grid = voxel_grid.view(self.num_bins, height, width) 115 | 116 | if self.normalize: 117 | mask = th.nonzero(voxel_grid, as_tuple=True) 118 | if mask[0].size()[0] > 0: 119 | mean = voxel_grid[mask].mean() 120 | std = voxel_grid[mask].std() 121 | if std > 0: 122 | voxel_grid[mask] = (voxel_grid[mask] - mean) / std 123 | else: 124 | voxel_grid[mask] = voxel_grid[mask] - mean 125 | 126 | return voxel_grid 127 | -------------------------------------------------------------------------------- /utils/visualization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from matplotlib import pyplot as plt 3 | 
from matplotlib import colors 4 | import numpy 5 | import os 6 | import loader.utils as loader 7 | import utils.transformers as transformers 8 | import utils.filename_templates as TEMPLATES 9 | import utils.helper_functions as helper 10 | from matplotlib.lines import Line2D 11 | from skimage.transform import rotate, warp 12 | from skimage import io 13 | import cv2 14 | import imageio 15 | from torchvision.transforms import CenterCrop 16 | 17 | class BaseVisualizer(object): 18 | def __init__(self, dataloader, save_path, additional_args=None): 19 | super(BaseVisualizer, self).__init__() 20 | self.dataloader = dataloader 21 | self.visu_path = helper.create_save_path(save_path, 'visualizations') 22 | self.submission_path = os.path.join(save_path, 'submission') 23 | os.mkdir(self.submission_path) 24 | self.mode = 'test' 25 | self.additional_args = additional_args 26 | 27 | def __call__(self, batch, epoch=None): 28 | for j in batch['loader_idx'].cpu().numpy().astype(int): 29 | # Get Batch Index 30 | batch_idx = torch.nonzero(batch['loader_idx'] == j).item() 31 | 32 | # Visualize Ground Truths, but only in first epoch or if we're cropping 33 | if epoch == 1 or epoch is None or 'crop_window' in batch.keys(): 34 | self.visualize_ground_truths(batch, batch_idx=batch_idx, epoch=epoch, 35 | data_aug='crop_window' in batch.keys()) 36 | 37 | # Visualize Estimations 38 | self.visualize_estimations(batch, batch_idx=batch_idx, epoch=epoch) 39 | 40 | def visualize_ground_truths(self, batch, batch_idx, epoch=None, data_aug=False): 41 | raise NotImplementedError 42 | 43 | def visualize_estimations(self, batch, batch_idx, epoch=None): 44 | raise NotImplementedError 45 | 46 | def visualize_image(self, image, true_idx, epoch=None, data_aug=False): 47 | true_idx = int(true_idx) 48 | name = TEMPLATES.IMG.format('inference', true_idx) 49 | save_image(os.path.join(self.visu_path, name), image.detach().cpu()) 50 | 51 | def visualize_events(self, image, batch, batch_idx, epoch=None, 
flip_before_crop=True, crop_window=None): 52 | raise NotImplementedError 53 | 54 | def visualize_flow_colours(self, flow, true_idx, epoch=None, data_aug=False, is_gt=False, fix_scaling=10, 55 | custom_name=None, prefix=None, suffix=None, sub_folder=None): 56 | true_idx = int(true_idx) 57 | 58 | if custom_name is None: 59 | name = TEMPLATES.FLOW_TEST.format('inference', true_idx) 60 | else: 61 | name = custom_name 62 | if prefix is not None: 63 | name = prefix + name 64 | if suffix is not None: 65 | split = name.split('.') 66 | name = split[0] + suffix + "." +split[1] 67 | if sub_folder is not None: 68 | name = os.path.join(sub_folder, name) 69 | # Visualize 70 | _, scaling = visualize_optical_flow(flow.detach().cpu().numpy(), 71 | os.path.join(self.visu_path, name), 72 | scaling=fix_scaling) 73 | return scaling 74 | 75 | def visualize_flow_submission(self, seq_name: str, flow: numpy.ndarray, file_index: int): 76 | # flow_u(u,v) = ((float)I(u,v,1)-2^15)/128.0; 77 | # flow_v(u,v) = ((float)I(u,v,2)-2^15)/128.0; 78 | # valid(u,v) = (bool)I(u,v,3); 79 | # [-2**15/128, 2**15/128] = [-256, 256] 80 | #flow_map_16bit = np.rint(flow_map*128 + 2**15).astype(np.uint16) 81 | _, h,w = flow.shape 82 | flow_map = numpy.rint(flow*128 + 2**15) 83 | flow_map = flow_map.astype(numpy.uint16).transpose(1,2,0) 84 | flow_map = numpy.concatenate((flow_map, numpy.zeros((h,w,1), dtype=numpy.uint16)), axis=-1) 85 | parent_path = os.path.join( 86 | self.submission_path, 87 | seq_name 88 | ) 89 | if not os.path.exists(parent_path): 90 | os.mkdir(parent_path) 91 | file_name = '{:06d}.png'.format(file_index) 92 | 93 | imageio.imwrite(os.path.join(parent_path, file_name), flow_map, format='PNG-FI') 94 | 95 | class FlowVisualizerEvents(BaseVisualizer): 96 | def __init__(self, dataloader, save_path, clamp_flow=True, additional_args=None): 97 | super(FlowVisualizerEvents, self).__init__(dataloader, save_path, additional_args=additional_args) 98 | self.flow_scaling = 0 99 | self.clamp_flow = 
clamp_flow 100 | 101 | def visualize_events(self, image, batch, batch_idx, epoch=None, flip_before_crop=True, crop_window=None): 102 | # Plots Events on top of an Image. 103 | if image is not None: 104 | im = image.detach().cpu() 105 | else: 106 | im = None 107 | 108 | # Load Raw events 109 | events = self.dataloader.dataset.get_events(loader_idx=int(batch['loader_idx'][batch_idx].item())) 110 | name_events = TEMPLATES.EVENTS.format('inference', int(batch['idx'][batch_idx].item())) 111 | 112 | # Event Sequence to Event Image 113 | events = events_to_event_image(events, 114 | int(batch['param_evc']['height'][batch_idx].item()), 115 | int(batch['param_evc']['width'][batch_idx].item()), 116 | im, 117 | crop_window=crop_window, 118 | rotation_angle=False, 119 | horizontal_flip=False, 120 | flip_before_crop=False) 121 | # center-crop 256x256 122 | crop = CenterCrop(256) 123 | events = crop(events) 124 | # Save 125 | save_image(os.path.join(self.visu_path, name_events), events) 126 | 127 | def visualize_ground_truths(self, batch, batch_idx, epoch=None, data_aug=False): 128 | # Visualize Events 129 | if 'image_old' in batch.keys(): 130 | image_old = batch['image_old'][batch_idx] 131 | else: 132 | image_old = None 133 | self.visualize_events(image_old, batch, batch_idx, epoch) 134 | 135 | # Visualize Image 136 | ''' 137 | if 'image_old' in batch.keys(): 138 | self.visualize_image(batch['image_old'][batch_idx], batch['idx'][batch_idx],epoch, data_aug) 139 | ''' 140 | # Visualize Flow GT 141 | flow_gt = batch['flow'][batch_idx].clone() 142 | flow_gt[~batch['gt_valid_mask'][batch_idx].bool()] = 0.0 143 | self.flow_scaling = self.visualize_flow_colours(flow_gt, batch['idx'][batch_idx], epoch=epoch, 144 | data_aug=data_aug, is_gt=True, fix_scaling=None, suffix='_gt') 145 | 146 | def visualize_estimations(self, batch, batch_idx, epoch=None): 147 | # Visualize Flow Estimation 148 | if self.clamp_flow: 149 | scaling = self.flow_scaling[1] 150 | else: 151 | scaling = None 152 | 
self.visualize_flow_colours(batch['flow_est'][batch_idx], batch['idx'][batch_idx], epoch=epoch, 153 | is_gt=False, fix_scaling=scaling) 154 | 155 | # Visualize Masked Flow 156 | flow_est = batch['flow_est'][batch_idx].clone() 157 | flow_est[~batch['gt_valid_mask'][batch_idx].bool()] = 0.0 158 | self.visualize_flow_colours(flow_est, batch['idx'][batch_idx], epoch=epoch, 159 | is_gt=False, fix_scaling=scaling, suffix='_masked') 160 | 161 | class DsecFlowVisualizer(BaseVisualizer): 162 | def __init__(self, dataloader, save_path, additional_args=None): 163 | super(DsecFlowVisualizer, self).__init__(dataloader, save_path, additional_args=additional_args) 164 | # Create Visu folders for every sequence 165 | for name in self.additional_args['name_mapping']: 166 | os.mkdir(os.path.join(self.visu_path, name)) 167 | os.mkdir(os.path.join(self.submission_path, name)) 168 | 169 | def visualize_events(self, image, batch, batch_idx, sequence_name): 170 | sequence_idx = [i for i, e in enumerate(self.additional_args['name_mapping']) if e == sequence_name][0] 171 | delta_t_us = self.dataloader.dataset.datasets[sequence_idx].delta_t_us 172 | loader_instance = self.dataloader.dataset.datasets[sequence_idx] 173 | h, w = loader_instance.get_image_width_height() 174 | events = loader_instance.event_slicer.get_events( 175 | t_start_us=batch['timestamp'][batch_idx].item(), 176 | t_end_us=batch['timestamp'][batch_idx].item()+delta_t_us 177 | ) 178 | p = events['p'].astype(numpy.int8) 179 | t = events['t'].astype(numpy.float64) 180 | x = events['x'] 181 | y = events['y'] 182 | p = 2*p - 1 183 | xy_rect = loader_instance.rectify_events(x, y) 184 | x_rect = numpy.rint(xy_rect[:, 0]) 185 | y_rect = numpy.rint(xy_rect[:, 1]) 186 | 187 | events_rectified = numpy.stack([t, x_rect, y_rect, p], axis=-1) 188 | event_image = events_to_event_image( 189 | event_sequence=events_rectified, 190 | height=h, 191 | width=w 192 | ).numpy() 193 | name_events = TEMPLATES.EVENTS.format('inference', 
int(batch['file_index'][batch_idx].item())) 194 | out_path = os.path.join(self.visu_path, sequence_name, name_events) 195 | imageio.imsave(out_path, event_image.transpose(1,2,0)) 196 | 197 | def __call__(self, batch, batch_idx, epoch=None): 198 | for batch_idx in range(len(batch['file_index'])): 199 | if batch['save_submission'][batch_idx]: 200 | sequence_name = self.additional_args['name_mapping'][int(batch['name_map'][batch_idx].item())] 201 | # Save for Benchmark Submission 202 | self.visualize_flow_submission( 203 | seq_name=sequence_name, 204 | flow=batch['flow_est'][batch_idx].clone().cpu().numpy(), 205 | file_index=int(batch['file_index'][batch_idx].item()), 206 | ) 207 | if batch['visualize'][batch_idx]: 208 | sequence_name = self.additional_args['name_mapping'][int(batch['name_map'][batch_idx].item())] 209 | # Visualize Flow 210 | self.visualize_flow_colours( 211 | batch['flow_est'][batch_idx], 212 | batch['file_index'][batch_idx], 213 | epoch=epoch, 214 | is_gt=False, 215 | fix_scaling=None, 216 | sub_folder=sequence_name 217 | ) 218 | # Visualize Events 219 | self.visualize_events( 220 | image=None, 221 | batch=batch, 222 | batch_idx=batch_idx, 223 | sequence_name=sequence_name 224 | ) 225 | 226 | def save_tensor(filepath, tensor): 227 | map = plt.get_cmap('plasma') 228 | t = tensor[0].numpy() / tensor[0].numpy().max() 229 | image = map(t) * 255 230 | io.imsave(filepath, image.astype(numpy.uint8)) 231 | 232 | 233 | def grayscale_to_rgb(tensor, permute=False): 234 | # Tensor [height, width, 3], or 235 | # Tensor [height, width, 1], or 236 | # Tensor [1, height, width], or 237 | # Tensor [3, height, width] 238 | 239 | # if permute -> Convert to [height, width, 3] 240 | if permute: 241 | if tensor.size()[0] < 4: 242 | tensor = tensor.permute(1, 2, 0) 243 | if tensor.size()[2] == 1: 244 | return torch.stack([tensor[:, :, 0]] * 3, dim=2) 245 | else: 246 | return tensor 247 | else: 248 | if tensor.size()[0] == 1: 249 | return torch.stack([tensor[0, :, :]] * 3, 
dim=0) 250 | else: 251 | return tensor 252 | 253 | 254 | def save_image(filepath, tensor): 255 | # Tensor [height, width, 3], or 256 | # Tensor [height, width, 1], or 257 | # Tensor [1, height, width], or 258 | # Tensor [3, height, width] 259 | 260 | # Convert to [height, width, 3] 261 | tensor = grayscale_to_rgb(tensor, True).numpy() 262 | use_pyplot=False 263 | if use_pyplot: 264 | fig = plt.figure() 265 | # Change Dimensions of Tensor 266 | plot = plt.imshow(tensor.astype(numpy.uint8)) 267 | plot.axes.get_xaxis().set_visible(False) 268 | plot.axes.get_yaxis().set_visible(False) 269 | fig.savefig(filepath, bbox_inches='tight', dpi=200) 270 | plt.close() 271 | else: 272 | io.imsave(filepath, tensor.astype(numpy.uint8)) 273 | 274 | 275 | def events_to_event_image(event_sequence, height, width, background=None, rotation_angle=None, crop_window=None, 276 | horizontal_flip=False, flip_before_crop=True): 277 | polarity = event_sequence[:, 3] == -1.0 278 | x_negative = event_sequence[~polarity, 1].astype(int) 279 | y_negative = event_sequence[~polarity, 2].astype(int) 280 | x_positive = event_sequence[polarity, 1].astype(int) 281 | y_positive = event_sequence[polarity, 2].astype(int) 282 | 283 | positive_histogram, _, _ = numpy.histogram2d( 284 | x_positive, 285 | y_positive, 286 | bins=(width, height), 287 | range=[[0, width], [0, height]]) 288 | negative_histogram, _, _ = numpy.histogram2d( 289 | x_negative, 290 | y_negative, 291 | bins=(width, height), 292 | range=[[0, width], [0, height]]) 293 | 294 | # Red -> Negative Events 295 | red = numpy.transpose((negative_histogram >= positive_histogram) & (negative_histogram != 0)) 296 | # Blue -> Positive Events 297 | blue = numpy.transpose(positive_histogram > negative_histogram) 298 | # Normally, we flip first, before we apply the other data augmentations 299 | if flip_before_crop: 300 | if horizontal_flip: 301 | red = numpy.flip(red, axis=1) 302 | blue = numpy.flip(blue, axis=1) 303 | # Rotate, 
if necessary 304 | if rotation_angle is not None: 305 | red = rotate(red, angle=rotation_angle, preserve_range=True).astype(bool) 306 | blue = rotate(blue, angle=rotation_angle, preserve_range=True).astype(bool) 307 | # Crop, if necessary 308 | if crop_window is not None: 309 | tf = transformers.RandomCropping(crop_height=crop_window['crop_height'], 310 | crop_width=crop_window['crop_width'], 311 | left_right=crop_window['left_right'], 312 | shift=crop_window['shift']) 313 | red = tf.crop_image(red, None, window=crop_window) 314 | blue = tf.crop_image(blue, None, window=crop_window) 315 | else: 316 | # Rotate, if necessary 317 | if rotation_angle is not None: 318 | red = rotate(red, angle=rotation_angle, preserve_range=True).astype(bool) 319 | blue = rotate(blue, angle=rotation_angle, preserve_range=True).astype(bool) 320 | # Crop, if necessary 321 | if crop_window is not None: 322 | tf = transformers.RandomCropping(crop_height=crop_window['crop_height'], 323 | crop_width=crop_window['crop_width'], 324 | left_right=crop_window['left_right'], 325 | shift=crop_window['shift']) 326 | red = tf.crop_image(red, None, window=crop_window) 327 | blue = tf.crop_image(blue, None, window=crop_window) 328 | if horizontal_flip: 329 | red = numpy.flip(red, axis=1) 330 | blue = numpy.flip(blue, axis=1) 331 | 332 | if background is None: 333 | height, width = red.shape 334 | background = torch.full((3, height, width), 255).byte() 335 | if len(background.shape) == 2: 336 | background = background.unsqueeze(0) 337 | else: 338 | if min(background.size()) == 1: 339 | background = grayscale_to_rgb(background) 340 | else: 341 | if not isinstance(background, torch.Tensor): 342 | background = torch.from_numpy(background) 343 | points_on_background = plot_points_on_background( 344 | torch.nonzero(torch.from_numpy(red.astype(numpy.uint8))), background, 345 | [255, 0, 0]) 346 | points_on_background = plot_points_on_background( 347 | torch.nonzero(torch.from_numpy(blue.astype(numpy.uint8))), 
348 | points_on_background, [0, 0, 255]) 349 | return points_on_background 350 | 351 | 352 | def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100): 353 | new_cmap = colors.LinearSegmentedColormap.from_list( 354 | 'trunc({n},{a:.2f},{b:.2f})'.format(n=cmap.name, a=minval, b=maxval), 355 | cmap(numpy.linspace(minval, maxval, n))) 356 | return new_cmap 357 | 358 | 359 | def plot_points_on_background(points_coordinates, 360 | background, 361 | points_color=[0, 0, 255]): 362 | """ 363 | Args: 364 | points_coordinates: array of (y, x) points coordinates 365 | of size (number_of_points x 2). 366 | background: (3 x height x width) 367 | gray or color image uint8. 368 | points_color: color of points [red, green, blue] uint8. 369 | """ 370 | if not (len(background.size()) == 3 and background.size(0) == 3): 371 | raise ValueError('background should be (color x height x width).') 372 | _, height, width = background.size() 373 | background_with_points = background.clone() 374 | y, x = points_coordinates.transpose(0, 1) 375 | if len(x) > 0 and len(y) > 0: # There can be empty arrays! 
376 | x_min, x_max = x.min(), x.max() 377 | y_min, y_max = y.min(), y.max() 378 | if not (x_min >= 0 and y_min >= 0 and x_max < width and y_max < height): 379 | raise ValueError('points coordinates are outside of "background" ' 380 | 'boundaries.') 381 | background_with_points[:, y, x] = torch.Tensor(points_color).type_as( 382 | background).unsqueeze(-1) 383 | return background_with_points 384 | 385 | 386 | def visualize_optical_flow(flow, savepath=None, return_image=False, text=None, scaling=None): 387 | # flow -> numpy array 2 x height x width 388 | # 2,h,w -> h,w,2 389 | flow = flow.transpose(1,2,0) 390 | flow[numpy.isinf(flow)]=0 391 | # Use Hue, Saturation, Value colour model 392 | hsv = numpy.zeros((flow.shape[0], flow.shape[1], 3), dtype=float) 393 | 394 | # The additional **0.5 is a scaling factor 395 | mag = numpy.sqrt(flow[...,0]**2+flow[...,1]**2)**0.5 396 | 397 | ang = numpy.arctan2(flow[...,1], flow[...,0]) 398 | ang[ang<0]+=numpy.pi*2 399 | hsv[..., 0] = ang/numpy.pi/2.0 # Scale from 0..1 400 | hsv[..., 1] = 1 401 | if scaling is None: 402 | hsv[..., 2] = (mag-mag.min())/(mag-mag.min()).max() # Scale from 0..1 403 | else: 404 | mag[mag>scaling]=scaling 405 | hsv[...,2] = mag/scaling 406 | rgb = colors.hsv_to_rgb(hsv) 407 | # This all seems like an overkill, but it's just to exactly match the cv2 implementation 408 | bgr = numpy.stack([rgb[...,2],rgb[...,1],rgb[...,0]], axis=2) 409 | plot_with_pyplot = False 410 | if plot_with_pyplot: 411 | fig = plt.figure(frameon=False) 412 | plot = plt.imshow(bgr) 413 | plot.axes.get_xaxis().set_visible(False) 414 | plot.axes.get_yaxis().set_visible(False) 415 | if text is not None: 416 | plt.text(0, -5, text) 417 | 418 | if savepath is not None: 419 | if plot_with_pyplot: 420 | fig.savefig(savepath, bbox_inches='tight', dpi=200) 421 | plt.close() 422 | else: #Plot with skimage 423 | out = bgr*255 424 | io.imsave(savepath, out.astype('uint8')) 425 | return bgr, (mag.min(), mag.max()) 426 | 
--------------------------------------------------------------------------------
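Note on the submission encoding used by `visualize_flow_submission` in /utils/visualization.py: the comments there state that a flow component is recovered from the 16-bit PNG value as `(I - 2^15) / 128`. The following is a minimal scalar sketch of that round trip, not code from the repository. The names `encode_flow_value` and `decode_flow_value` are hypothetical, and the explicit range check is an added assumption: the repository code casts to `uint16` without checking.

```python
# Illustrative sketch only; helper names are hypothetical, not part of the repo.
FLOW_OFFSET = 2 ** 15   # zero flow maps to the middle of the uint16 range
FLOW_SCALE = 128.0      # one encoded step corresponds to 1/128 pixel


def encode_flow_value(flow_px):
    """Map one flow component (in pixels) to its 16-bit PNG value."""
    encoded = round(flow_px * FLOW_SCALE + FLOW_OFFSET)
    # Assumption: reject values that would not fit into uint16 instead of
    # letting them wrap around silently.
    if not 0 <= encoded <= 2 ** 16 - 1:
        raise ValueError("flow value outside the representable range")
    return encoded


def decode_flow_value(encoded):
    """Invert encode_flow_value: 16-bit PNG value back to pixels."""
    return (encoded - FLOW_OFFSET) / FLOW_SCALE


if __name__ == "__main__":
    for value in (0.0, 1.5, -256.0):
        code = encode_flow_value(value)
        print(value, "->", code, "->", decode_flow_value(code))
```

With this scaling the largest encodable value is `(2**16 - 1 - 2**15) / 128`, slightly below 256 pixels, so a flow of exactly +256 does not fit into `uint16`.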