├── .gitignore ├── .tmux.conf ├── README.md ├── checkpoints └── dsec_best.tar ├── config ├── dsec_contrast.yaml ├── dsec_infer.yaml └── mvsec_contrast.yaml ├── dataloader ├── common.py ├── contrast │ ├── base.py │ ├── encodings.py │ ├── h5_mvsec.py │ ├── mvsecDataset_no_ann.py │ └── raw_event_utils.py ├── dsceDataset.py ├── dsecProvider_test.py ├── dsecSequence.py ├── eventslicer.py └── representations.py ├── docker ├── Dockerfile ├── README.md ├── docker_build.sh └── docker_run_multi.sh ├── infer_dsec.py ├── infer_dsec.sh ├── models ├── __init__.py ├── __init__.pyc ├── model_large.py ├── submodule.py └── submodule.pyc ├── pre-processing ├── README.md ├── eventslicer.py ├── raw_event_parsing.py ├── voxel_generate_all.py └── voxel_generate_test_all.py ├── requirements.txt ├── resource ├── teaser.png └── temporal_stereo_demo.gif ├── run_tmux.sh └── utils ├── __init__.py ├── evaluation ├── __init__.py ├── eval.py ├── flow_eval.py ├── flow_pixel_error.py ├── inverse_warp.py └── pixel_error.py ├── evaluation_old ├── __init__.py ├── eval.py ├── flow_eval.py ├── flow_pixel_error.py ├── inverse_warp.py └── pixel_error.py ├── flow.py ├── flow_vis.py ├── iwe.py ├── logger.py ├── preprocess.py ├── readpfm.py ├── softsplat.py ├── softsplat_prev.py ├── visualization.py ├── viz.py └── warp.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 111 | .pdm.toml 112 | .pdm-python 113 | .pdm-build/ 114 | 115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 163 | #.idea/ 164 | -------------------------------------------------------------------------------- /.tmux.conf: -------------------------------------------------------------------------------- 1 | set -g history-limit 999999999 2 | set -g mouse on 3 | 4 | # Shift arrow to switch windows 5 | bind -n S-Left previous-window 6 | bind -n S-Right next-window 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Temporal Event Stereo via Joint Learning with Stereoscopic Flow (TESNet) (ECCV2024) 2 |
3 | <img src="resource/teaser.png"> 4 | 5 | <img src="resource/temporal_stereo_demo.gif"> 6 | 7 |
8 | 9 | 10 | Official code for "Temporal Event Stereo via Joint Learning with Stereoscopic Flow" (ECCV2024) 11 | ([Paper](https://arxiv.org/abs/2407.10831)) 12 | 13 | 14 | 15 | ```bibtex 16 | @InProceedings{tes24eccv, 17 | author = {Hoonhee Cho* and Jae-Young Kang* and Kuk-Jin Yoon}, 18 | title = {Temporal Event Stereo via Joint Learning with Stereoscopic Flow}, 19 | booktitle = {European Conference on Computer Vision (ECCV)}, 20 | year = {2024}, 21 | } 22 | ``` 23 | 24 | 25 | 26 | ## Abstract 27 | Event cameras are dynamic vision sensors inspired by the biological retina, characterized by their high dynamic range, high temporal resolution, and low power consumption. These features make them capable of perceiving 3D environments even in extreme conditions. Event data is continuous across the time dimension, which allows a detailed description of each pixel's movements. To fully utilize the temporally dense and continuous nature of event cameras, we propose temporal event stereo, a novel framework that continuously uses information from previous time steps. This is accomplished through the simultaneous training of an event stereo matching network alongside stereoscopic flow, a new concept that captures all pixel movements from stereo cameras. Since obtaining ground truth for optical flow during training is challenging, we propose a method that uses only disparity maps to train the stereoscopic flow. The performance of event-based stereo matching is enhanced by temporally aggregating information using the flows. We have achieved state-of-the-art performance on the MVSEC and the DSEC datasets. The method is computationally efficient, as it stacks previous information in a cascading manner. 28 | 29 | 30 | ## Datasets 31 | Please refer to the pre-processing directory ([pre-process](https://github.com/mickeykang16/TemporalEventStereo/tree/master/pre-processing)) for the dataset format and details. 32 | 33 | ## Installation 34 | ### Docker Environment 35 | This project is based on CUDA 11.1, Python 3.8, and PyTorch 1.10.1. Please refer to [docker setup](https://github.com/mickeykang16/TemporalEventStereo/tree/master/docker) for more details. 36 | ### Conda Environment 37 | Coming Soon 38 | 39 | 40 | ## Training 41 | Coming Soon 42 | 43 | ## Testing 44 | We provide a checkpoint and test code for the DSEC dataset. 
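Before running the script, point the dataset paths in `config/dsec_infer.yaml` to your local DSEC copy. The keys below are the ones defined in that config (under `model:`); the concrete paths are placeholders for your own layout, not required locations:

```yaml
model:
  load_model: 'checkpoints/dsec_best.tar'      # provided checkpoint
  dataset_raw_root_path: /home/data/dsec       # raw DSEC event recordings
  dataset_root_path: /home/data/dsec           # pre-processed voxel data
  pseudo_root_path: /home/data/DSEC_pseudo_GT  # pseudo ground-truth disparity
```

With the paths set, run: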
45 | ```bash 46 | source infer_dsec.sh 47 | ``` -------------------------------------------------------------------------------- /checkpoints/dsec_best.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mickeykang16/TemporalEventStereo/d9dc74677f568dbdf2d1d06da20480c50410e274/checkpoints/dsec_best.tar -------------------------------------------------------------------------------- /config/dsec_contrast.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | epoch: 60 # to run test only, just set epoch to zero 3 | # type: minihourglass_small_v2_3d_of_backwarp_rawcost_small_of_channel 4 | # type: stackhourglass_small_v2_3d_of_backwarp_rawcost_small_of_channel 5 | # type: stackhourglass_3d_of_backwarp_rawcost_small_of_channel 6 | type: stackhourglass_v2 7 | # load_model: '' 8 | load_model: 'exps/new_flow_dsec/dsec_stackhourglass_v2_single/checkpoint_24.tar' 9 | # load_model: 'exps/stackhourglass_3d_of_backwarp_rawcost_small_of_channel_wo_cost_batch2/checkpoint_5.tar' 10 | # do_not_load_layer: ['dres4', 'classif3', 'of_block', 'fusion'] 11 | # do_not_load_layer: ['of_block', 'fusion'] 12 | load_strict: True 13 | remove_feat_first_weight : False 14 | load_optim: False 15 | pretrain_freeze: False 16 | 17 | maxdisp: 192 # depth for model costvolume 18 | eval_maxdisp: 192 # 255 / 7.0 # depth for evaluation mask 19 | height: 480 20 | # width: 352 21 | width: 640 # use for stackhourglass... 22 | # width: 640 # use for minihourglass_... 23 | # width: 384 # use for unet arch 24 | # frame_idxs: [0, ] 25 | frame_idxs: range(0, 1) # <- this type of expression is allowed 26 | use_prev_gradient: False 27 | dataset: DSEC 28 | split: 1 29 | # dataset_raw_root_path: /mnt2/DSEC_data 30 | dataset_raw_root_path: /home/jaeyoung/data3/dsec 31 | dataset_root_path: /home/jaeyoung/data3/dsec 32 | pseudo_root_path: /home/jaeyoung/data3/DSEC_pseudo_GT_all 33 | random_crop: False 34 | crop_size: [336, 480] 35 | 36 | use_pseudo_gt: True 37 | use_disp_gt_mask: False 38 | use_disp_flow_warp_mask: False 39 | use_featuremetric_loss: False 40 | use_disp_loss: True #default our loss 41 | use_contrast_loss: False 42 | use_stereo_loss: True #default for stereo True 43 | 44 | use_mini_data: False 45 | use_super_mini_data: False 46 | 47 | flow_smooth_weight: 0.1 48 | flow_scale: 1e-8 49 | val_of_viz: False 50 | 51 | 52 | orig_height: 480 53 | orig_width: 640 54 | 55 | use_raw_provider: True 56 | in_ch: 15 57 | delta_t_ms: 100 58 | 59 | 60 | log: 61 | log_train_every_n_batch: 40 62 | save_test_every_n_batch: 1 63 | lr: 0.001 64 | train: 65 | ann_path: view_4_train_v5_split1.json 66 | shuffle: True 67 | gradient_accumulation: 1 68 | batch_size: 4 69 | num_worker: 2 70 | validation: 71 | ann_path: view_4_test_v5_split1_all.json 72 | shuffle: False 73 | batch_size: 4 74 | num_worker: 2 75 | test: 76 | ann_path: view_4_test_v5_split1_all.json 77 | shuffle: False 78 | batch_size: 4 79 | num_worker: 4 -------------------------------------------------------------------------------- /config/dsec_infer.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | epoch: 0 # to run test only, just set epoch to zero 3 | type: ours_large 4 | load_model: 'checkpoints/dsec_best.tar' 5 | load_optim: False 6 | load_strict: True 7 | maxdisp: 192 # depth for model costvolume 8 | eval_maxdisp: 192 # evaluation 9 | height: 480 10 | width: 640 11 | frame_idxs: range(-3, 1) # use four frames 12 | 
use_prev_gradient: False 13 | dataset: DSEC 14 | dataset_raw_root_path: /home/data/dsec 15 | dataset_root_path: /home/data/dsec 16 | pseudo_root_path: /home/data/DSEC_pseudo_GT 17 | 18 | orig_height: 480 19 | orig_width: 640 20 | 21 | # Load both voxel and raw event for contrast maximization loss 22 | use_raw_provider: True 23 | in_ch: 15 24 | delta_t_ms: 50 25 | 26 | log: 27 | log_train_every_n_batch: 40 28 | save_test_every_n_batch: 100 29 | lr: 0.001 30 | train: 31 | shuffle: True 32 | gradient_accumulation: 0 33 | batch_size: 4 34 | num_worker: 16 35 | validation: 36 | shuffle: False 37 | batch_size: 4 38 | num_worker: 4 39 | test: 40 | shuffle: False 41 | batch_size: 1 42 | num_worker: 4 -------------------------------------------------------------------------------- /config/mvsec_contrast.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | epoch: 60 # to run test only, just set epoch to 0 3 | type: minihourglass_small_v6_3d_of_backwarp_rawcost_small_of_entropy_filter_new_flow 4 | load_model: 'exps/mvsec_multiframe_ablation2/split1_minihourglass_small_v6_3d_of_backwarp_rawcost_small_of_entropy_filter_new_flow_8frame3/continue/checkpoint_61.tar' 5 | # do_not_load_layer: ['dres4', 'classif3', 'of_block', 'fusion'] 6 | do_not_load_layer: [] 7 | load_strict: True 8 | load_optim: False # load optimizer state from checkpoint 9 | 10 | maxdisp: 48 # depth for model costvolume 11 | eval_maxdisp: 36 # depth for evaluation mask 12 | # Size after padding 13 | height: 288 14 | width: 348 15 | 16 | # frame_idxs: [0] 17 | frame_idxs: range(-7, 1) # <- this type of expression is allowed 18 | skip_num: 1 19 | 20 | use_prev_gradient: False 21 | dataset: MVSEC 22 | split: 1 23 | dataset_root_path: /home/jaeyoung/ws/event_stereo_ICCV2019/dataset 24 | use_pseudo_gt: True 25 | 26 | use_disp_loss: True #default our loss 27 | 28 | use_contrast_loss: True 29 | flow_regul_weight: 0.01 30 | contrast_flow_scale: 0.0 31 | 32 | use_stereo_loss: True #default for stereo True 33 | 34 | orig_height: 260 35 | orig_width: 346 36 | log: 37 | log_train_every_n_batch: 100 38 | save_test_every_n_batch: 1 39 | lr: 0.0008 40 | train: 41 | shuffle: True 42 | batch_size: 2 43 | num_worker: 4 44 | validation: 45 | shuffle: False 46 | batch_size: 4 47 | num_worker: 4 48 | test: 49 | shuffle: False 50 | batch_size: 4 51 | num_worker: 4 52 | -------------------------------------------------------------------------------- /dataloader/common.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from skimage.color import rgb2hsv, hsv2rgb 4 | from skimage.transform import pyramid_gaussian 5 | from skimage.measure import block_reduce 6 | 7 | import torch 8 | 9 | def _apply(func, x): 10 | 11 | if isinstance(x, (list, tuple)): 12 | return [_apply(func, x_i) for x_i in x] 13 | elif isinstance(x, dict): 14 | y = {} 15 | for key, value in x.items(): 16 | y[key] = _apply(func, value) 17 | return y 18 | else: 19 | return func(x) 20 | 21 | def crop(args, ps, py, px): # patch_size 22 | # args = [input, target] 23 | def _get_shape(*args): 24 | if isinstance(args[0], (list, tuple)): 25 | return _get_shape(args[0][0]) 26 | elif isinstance(args[0], dict): 27 | return _get_shape(list(args[0].values())[0]) 28 | else: 29 | return args[0].shape 30 | 31 | h, w, _ = _get_shape(args) 32 | # print(_get_shape(args)) 33 | # print(ps[1]) 34 | # import pdb; pdb.set_trace() 35 | 36 | # py = random.randrange(0, h-ps+1) 37 | # px = 
random.randrange(0, w-ps+1) 38 | 39 | def _crop(img): 40 | if img.ndim == 2: 41 | return img[py:py+ps[1], px:px+ps[0], np.newaxis] 42 | else: 43 | return img[py:py+ps[1], px:px+ps[0], :] 44 | 45 | return _apply(_crop, args) 46 | 47 | def crop_with_event(*args, left_event, right_event, ps=256): # patch_size 48 | # args = [input, target] 49 | def _get_shape(*args): 50 | if isinstance(args[0], (list, tuple)): 51 | return _get_shape(args[0][0]) 52 | elif isinstance(args[0], dict): 53 | return _get_shape(list(args[0].values())[0]) 54 | else: 55 | return args[0].shape 56 | 57 | h, w, _ = _get_shape(args) 58 | 59 | py = random.randrange(0, h-ps+1) 60 | px = random.randrange(0, w-ps+1) 61 | 62 | def _crop(img): 63 | if img.ndim == 2: 64 | return img[py:py+ps, px:px+ps, np.newaxis] 65 | else: 66 | return img[py:py+ps, px:px+ps, :] 67 | 68 | def _event_crop(event): 69 | return event[:, py:py+ps, px:px+ps] 70 | 71 | return _apply(_crop, args), _apply(_event_crop, left_event), _apply(_event_crop, right_event) 72 | 73 | def crop_event(args, ps, py, px): # patch_size 74 | # args = [input, target] 75 | def _get_shape(*args): 76 | if isinstance(args[0], (list, tuple)): 77 | return _get_shape(args[0][0]) 78 | elif isinstance(args[0], dict): 79 | return _get_shape(list(args[0].values())[0]) 80 | else: 81 | return args[0].shape 82 | 83 | _, h, w = _get_shape(args) 84 | 85 | # py = random.randrange(0, h-ps+1) 86 | # px = random.randrange(0, w-ps+1) 87 | 88 | def _crop(img): 89 | if img.ndim == 2: 90 | return img[py:py+ps, px:px+ps, np.newaxis] 91 | else: 92 | return img[py:py+ps, px:px+ps, :] 93 | 94 | def _event_crop(event): 95 | return event[:, py:py+ps[1], px:px+ps[0]] 96 | 97 | return _apply(_event_crop, args) 98 | 99 | 100 | def crop_disp(args, ps, py, px): # patch_size 101 | # args = [input, target] 102 | def _get_shape(*args): 103 | if isinstance(args[0], (list, tuple)): 104 | return _get_shape(args[0][0]) 105 | elif isinstance(args[0], dict): 106 | return _get_shape(list(args[0].values())[0]) 107 | else: 108 | return args[0].shape 109 | 110 | h, w = _get_shape(args) 111 | 112 | # py = random.randrange(0, h-ps+1) 113 | # px = random.randrange(0, w-ps+1) 114 | 115 | 116 | def _crop(img): 117 | if img.ndim == 2: 118 | return img[py:py+ps[1], px:px+ps[0], np.newaxis] 119 | else: 120 | return img[py:py+ps[1], px:px+ps[0], :] 121 | 122 | def _disp_crop(event): 123 | return event[py:py+ps[1], px:px+ps[0]] 124 | 125 | return _apply(_disp_crop, args) 126 | 127 | def add_noise(*args, sigma_sigma=2, rgb_range=255): 128 | 129 | if len(args) == 1: # usually there is only a single input 130 | args = args[0] 131 | 132 | sigma = np.random.normal() * sigma_sigma * rgb_range/255 133 | 134 | def _add_noise(img): 135 | noise = np.random.randn(*img.shape).astype(np.float32) * sigma 136 | return (img + noise).clip(0, rgb_range) 137 | 138 | return _apply(_add_noise, args) 139 | 140 | def augment(*args, hflip=True, rot=True, shuffle=True, change_saturation=True, rgb_range=255): 141 | """augmentation consistent to input and target""" 142 | 143 | choices = (False, True) 144 | 145 | hflip = hflip and random.choice(choices) 146 | vflip = rot and random.choice(choices) 147 | rot90 = rot and random.choice(choices) 148 | # shuffle = shuffle 149 | 150 | if shuffle: 151 | rgb_order = list(range(3)) 152 | random.shuffle(rgb_order) 153 | if rgb_order == list(range(3)): 154 | shuffle = False 155 | 156 | if change_saturation: 157 | amp_factor = np.random.uniform(0.5, 1.5) 158 | 159 | def _augment(img): 160 | if hflip: img = img[:, ::-1, 
:] 161 | if vflip: img = img[::-1, :, :] 162 | if rot90: img = img.transpose(1, 0, 2) 163 | if shuffle and img.ndim > 2: 164 | if img.shape[-1] == 3: # RGB image only 165 | img = img[..., rgb_order] 166 | 167 | if change_saturation: 168 | hsv_img = rgb2hsv(img) 169 | hsv_img[..., 1] *= amp_factor 170 | 171 | img = hsv2rgb(hsv_img).clip(0, 1) * rgb_range 172 | 173 | return img.astype(np.float32) 174 | 175 | return _apply(_augment, args) 176 | 177 | def pad(img, divisor=4, pad_width=None, negative=False): 178 | 179 | def _pad_numpy(img, divisor=4, pad_width=None, negative=False): 180 | if pad_width is None: 181 | (h, w, _) = img.shape 182 | pad_h = -h % divisor 183 | pad_w = -w % divisor 184 | pad_width = ((0, pad_h), (0, pad_w), (0, 0)) 185 | 186 | img = np.pad(img, pad_width, mode='edge') 187 | 188 | return img, pad_width 189 | 190 | def _pad_tensor(img, divisor=4, pad_width=None, negative=False): 191 | 192 | n, c, h, w = img.shape 193 | if pad_width is None: 194 | pad_h = -h % divisor 195 | pad_w = -w % divisor 196 | pad_width = (0, pad_w, 0, pad_h) 197 | else: 198 | try: 199 | pad_h = pad_width[0][1] 200 | pad_w = pad_width[1][1] 201 | if isinstance(pad_h, torch.Tensor): 202 | pad_h = pad_h.item() 203 | if isinstance(pad_w, torch.Tensor): 204 | pad_w = pad_w.item() 205 | 206 | pad_width = (0, pad_w, 0, pad_h) 207 | except: 208 | pass 209 | 210 | if negative: 211 | pad_width = [-val for val in pad_width] 212 | 213 | img = torch.nn.functional.pad(img, pad_width, 'reflect') 214 | 215 | return img, pad_width 216 | 217 | if isinstance(img, np.ndarray): 218 | return _pad_numpy(img, divisor, pad_width, negative) 219 | else: # torch.Tensor 220 | return _pad_tensor(img, divisor, pad_width, negative) 221 | 222 | 223 | def disp_pad(img, divisor=4, pad_width=None, negative=False): 224 | 225 | def _pad_numpy(img, divisor=4, pad_width=None, negative=False): 226 | if pad_width is None: 227 | (h, w, _) = img.shape 228 | pad_h = -h % divisor 229 | pad_w = -w % divisor 230 | pad_width = ((0, pad_h), (0, pad_w), (0, 0)) 231 | 232 | img = np.pad(img, pad_width, mode='edge') 233 | 234 | return img, pad_width 235 | 236 | def _pad_tensor(img, divisor=4, pad_width=None, negative=False): 237 | 238 | if isinstance(img, list): 239 | img = img[0].unsqueeze(1) 240 | n, c, h, w = img.shape 241 | else: 242 | img = img.unsqueeze(1) 243 | n, c, h, w = img.shape 244 | if pad_width is None: 245 | pad_h = -h % divisor 246 | pad_w = -w % divisor 247 | pad_width = (0, pad_w, 0, pad_h) 248 | 249 | 250 | else: 251 | try: 252 | # import pdb; pdb.set_trace() 253 | pad_h = pad_width[0][1][0] 254 | pad_w = pad_width[1][1][0] 255 | if isinstance(pad_h, torch.Tensor): 256 | pad_h = pad_h.item() 257 | if isinstance(pad_w, torch.Tensor): 258 | pad_w = pad_w.item() 259 | 260 | pad_width = (0, pad_w, 0, pad_h) 261 | except: 262 | pass 263 | 264 | if negative: 265 | pad_width = [-val for val in pad_width] 266 | 267 | 268 | img = torch.nn.functional.pad(img, pad_width, 'reflect') 269 | 270 | 271 | return img 272 | 273 | if isinstance(img, np.ndarray): 274 | return _pad_numpy(img, divisor, pad_width, negative) 275 | else: # torch.Tensor 276 | return _pad_tensor(img, divisor, pad_width, negative) 277 | 278 | def event_pad(img, divisor=4, pad_width=None, negative=False): 279 | 280 | def _pad_numpy(img, divisor=4, pad_width=None, negative=False): 281 | 282 | if pad_width is None: 283 | (_, h, w) = img.shape 284 | pad_h = -h % divisor 285 | pad_w = -w % divisor 286 | pad_width = ((0, 0), (0, pad_h), (0, pad_w)) 287 | 288 | img = np.pad(img, 
pad_width, mode='edge') 289 | 290 | return img, pad_width 291 | 292 | def _pad_tensor(img, divisor=4, pad_width=None, negative=False): 293 | 294 | n, c, h, w = img.shape 295 | if pad_width is None: 296 | pad_h = -h % divisor 297 | pad_w = -w % divisor 298 | pad_width = (0, pad_w, 0, pad_h) 299 | else: 300 | try: 301 | pad_h = pad_width[0][1] 302 | pad_w = pad_width[1][1] 303 | if isinstance(pad_h, torch.Tensor): 304 | pad_h = pad_h.item() 305 | if isinstance(pad_w, torch.Tensor): 306 | pad_w = pad_w.item() 307 | 308 | pad_width = (0, pad_w, 0, pad_h) 309 | except: 310 | pass 311 | 312 | if negative: 313 | pad_width = [-val for val in pad_width] 314 | 315 | img = torch.nn.functional.pad(img, pad_width, 'reflect') 316 | 317 | return img, pad_width 318 | 319 | if isinstance(img, np.ndarray): 320 | return _pad_numpy(img, divisor, pad_width, negative) 321 | else: # torch.Tensor 322 | return _pad_tensor(img, divisor, pad_width, negative) 323 | 324 | def generate_pyramid(*args, n_scales): 325 | 326 | def _generate_pyramid(img): 327 | if img.dtype != np.float32: 328 | img = img.astype(np.float32) 329 | pyramid = list(pyramid_gaussian(img, n_scales-1, multichannel=True)) 330 | 331 | return pyramid 332 | 333 | return _apply(_generate_pyramid, args) 334 | 335 | def generate_event_pyramid(*args, n_scales): 336 | def _generate_event_pyramid(event): 337 | event = np.array(event) 338 | 339 | if event.dtype != np.float32: 340 | event = event.astype(np.float32) 341 | 342 | # import pdb 343 | # pdb.set_trace() 344 | # print("len event") 345 | # print(event) 346 | # print([len(a) for a in event]) 347 | 348 | # print(event.shape) 349 | pyramid = [] 350 | for i in range(n_scales): 351 | w, h, c = event.shape 352 | # scale_event = block_reduce(event, (w//(2**i) , h//(2**i), c), np.max) 353 | scale_event = block_reduce(event, (1, 2**i, 2**i), np.max) 354 | # print(scale_event) 355 | # print(event) 356 | pyramid.append(scale_event) 357 | # if i == 2: 358 | # print(scale_event.shape) 359 | # import pdb 360 | # pdb.set_trace() 361 | 362 | 363 | # pyramid = list(pyramid_gaussian(img, n_scales-1, multichannel=True)) 364 | return pyramid 365 | return _generate_event_pyramid(args) 366 | 367 | 368 | def np2tensor(*args): 369 | def _np2tensor(x): 370 | np_transpose = np.ascontiguousarray(x.transpose(2, 0, 1)) 371 | tensor = torch.from_numpy(np_transpose) 372 | 373 | return tensor 374 | 375 | return _apply(_np2tensor, args) 376 | 377 | def image2tensor(x): 378 | np_transpose = np.ascontiguousarray(x.transpose(2, 0, 1)) 379 | tensor = torch.from_numpy(np_transpose) 380 | 381 | return tensor 382 | 383 | def event2tensor(*args): 384 | def _np2tensor(x): 385 | # np_transpose = np.ascontiguousarray(x.transpose(2, 0, 1)) 386 | tensor = torch.from_numpy(x) 387 | return tensor 388 | 389 | 390 | return _apply(_np2tensor, args) 391 | 392 | def to(*args, device=None, dtype=torch.float): 393 | 394 | def _to(x): 395 | return x.to(device=device, dtype=dtype, non_blocking=True, copy=False) 396 | 397 | return _apply(_to, args) -------------------------------------------------------------------------------- /dataloader/contrast/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import numpy as np 4 | import random 5 | import torch 6 | 7 | from .encodings import events_to_voxel, events_to_channels, events_to_mask, get_hot_event_mask 8 | 9 | 10 | class BaseDataLoader(torch.utils.data.Dataset): 11 | """ 12 | Base class for dataloader. 
13 | """ 14 | 15 | def __init__(self, config, num_bins): 16 | self.config = config 17 | self.epoch = 0 18 | self.seq_num = 0 19 | self.samples = 0 20 | self.new_seq = False 21 | self.tc_idx = 0 22 | self.num_bins = num_bins 23 | 24 | # batch-specific data augmentation mechanisms 25 | self.batch_augmentation = {} 26 | for mechanism in self.config["loader"]["augment"]: 27 | if mechanism != "Pause": 28 | self.batch_augmentation[mechanism] = [False for i in range(self.config["loader"]["batch_size"])] 29 | else: 30 | self.batch_augmentation[mechanism] = False # shared among batch elements 31 | 32 | for i, mechanism in enumerate(self.config["loader"]["augment"]): 33 | if mechanism != "Pause": 34 | for batch in range(self.config["loader"]["batch_size"]): 35 | if np.random.random() < self.config["loader"]["augment_prob"][i]: 36 | self.batch_augmentation[mechanism][batch] = True 37 | 38 | # hot pixels 39 | if self.config["hot_filter"]["enabled"]: 40 | self.hot_idx = [0 for i in range(self.config["loader"]["batch_size"])] 41 | self.hot_events = [ 42 | torch.zeros(self.config["loader"]["resolution"]) for i in range(self.config["loader"]["batch_size"]) 43 | ] 44 | 45 | @abstractmethod 46 | def __getitem__(self, index): 47 | raise NotImplementedError 48 | 49 | @abstractmethod 50 | def get_events(self, history): 51 | raise NotImplementedError 52 | 53 | def reset_sequence(self, batch): 54 | """ 55 | Reset sequence-specific variables. 56 | :param batch: batch index 57 | """ 58 | 59 | self.tc_idx = 0 60 | self.seq_num += 1 61 | if self.config["hot_filter"]["enabled"]: 62 | self.hot_idx[batch] = 0 63 | self.hot_events[batch] = torch.zeros(self.config["loader"]["resolution"]) 64 | 65 | for i, mechanism in enumerate(self.config["loader"]["augment"]): 66 | if mechanism != "Pause": 67 | if np.random.random() < self.config["loader"]["augment_prob"][i]: 68 | self.batch_augmentation[mechanism][batch] = True 69 | else: 70 | self.batch_augmentation[mechanism][batch] = False 71 | else: 72 | self.batch_augmentation[mechanism] = False 73 | 74 | @staticmethod 75 | def event_formatting(xs, ys, ts, ps): 76 | """ 77 | Convert raw event arrays to tensors and normalize the event timestamps to [0, 1]. 78 | :param xs: [N] numpy array with event x location 79 | :param ys: [N] numpy array with event y location 80 | :param ts: [N] numpy array with event timestamp 81 | :param ps: [N] numpy array with event polarity ([-1, 1]) 82 | :return xs: [N] tensor with event x location 83 | :return ys: [N] tensor with event y location 84 | :return ts: [N] tensor with normalized event timestamp 85 | :return ps: [N] tensor with event polarity ([-1, 1]) 86 | """ 87 | 88 | xs = torch.from_numpy(xs.astype(np.float32)) 89 | ys = torch.from_numpy(ys.astype(np.float32)) 90 | ts = torch.from_numpy(ts.astype(np.float32)) 91 | ps = torch.from_numpy(ps.astype(np.float32)) 92 | ts = (ts - ts[0]) / (ts[-1] - ts[0]) 93 | return xs, ys, ts, ps 94 | 95 | def augment_events(self, xs, ys, ps, batch): 96 | """ 97 | Augment event sequence with horizontal, vertical, and polarity flips, and 98 | artificial event pauses. 
:param xs: [N] tensor with event x location 100 | :param ys: [N] tensor with event y location 101 | :param ps: [N] tensor with event polarity ([-1, 1]) 102 | :param batch: batch index 103 | :return xs: [N] tensor with augmented event x location 104 | :return ys: [N] tensor with augmented event y location 105 | :return ps: [N] tensor with augmented event polarity ([-1, 1]) 106 | """ 107 | 108 | for i, mechanism in enumerate(self.config["loader"]["augment"]): 109 | 110 | if mechanism == "Horizontal": 111 | if self.batch_augmentation["Horizontal"][batch]: 112 | xs = self.config["loader"]["resolution"][1] - 1 - xs 113 | 114 | elif mechanism == "Vertical": 115 | if self.batch_augmentation["Vertical"][batch]: 116 | ys = self.config["loader"]["resolution"][0] - 1 - ys 117 | 118 | elif mechanism == "Polarity": 119 | if self.batch_augmentation["Polarity"][batch]: 120 | ps *= -1 121 | 122 | # shared among batch elements 123 | elif ( 124 | batch == 0 125 | and mechanism == "Pause" 126 | and self.tc_idx > self.config["loss"]["reconstruction_tc_idx_threshold"] 127 | ): 128 | if not self.batch_augmentation["Pause"]: 129 | if np.random.random() < self.config["loader"]["augment_prob"][i][0]: 130 | self.batch_augmentation["Pause"] = True 131 | else: 132 | if np.random.random() < self.config["loader"]["augment_prob"][i][1]: 133 | self.batch_augmentation["Pause"] = False 134 | 135 | return xs, ys, ps 136 | 137 | def augment_frames(self, img, batch): 138 | """ 139 | Augment APS frame with horizontal and vertical flips. 140 | :param img: [H x W] numpy array with APS intensity 141 | :param batch: batch index 142 | :return img: [H x W] augmented numpy array with APS intensity 143 | """ 144 | if "Horizontal" in self.batch_augmentation: 145 | if self.batch_augmentation["Horizontal"][batch]: 146 | img = np.flip(img, 1) 147 | if "Vertical" in self.batch_augmentation: 148 | if self.batch_augmentation["Vertical"][batch]: 149 | img = np.flip(img, 0) 150 | return img 151 | 152 | def create_cnt_encoding(self, xs, ys, ts, ps): 153 | """ 154 | Creates a per-pixel and per-polarity event count representation. 155 | :param xs: [N] tensor with event x location 156 | :param ys: [N] tensor with event y location 157 | :param ts: [N] tensor with normalized event timestamp 158 | :param ps: [N] tensor with event polarity ([-1, 1]) 159 | :return [2 x H x W] event representation 160 | """ 161 | 162 | return events_to_channels(xs, ys, ps, sensor_size=self.config["loader"]["resolution"]) 163 | 164 | def create_voxel_encoding(self, xs, ys, ts, ps): 165 | """ 166 | Creates a spatiotemporal voxel grid tensor representation with a certain number of bins, 167 | as described in Section 3.1 of the paper 'Unsupervised Event-based Learning of Optical Flow, 168 | Depth, and Egomotion', Zhu et al., CVPR'19. 169 | Events are distributed to the spatiotemporal closest bins through bilinear interpolation. 170 | Positive events are added as +1, while negative as -1. 
171 | :param xs: [N] tensor with event x location 172 | :param ys: [N] tensor with event y location 173 | :param ts: [N] tensor with normalized event timestamp 174 | :param ps: [N] tensor with event polarity ([-1, 1]) 175 | :return [B x H x W] event representation 176 | """ 177 | 178 | return events_to_voxel( 179 | xs, 180 | ys, 181 | ts, 182 | ps, 183 | self.num_bins, 184 | sensor_size=self.config["loader"]["resolution"], 185 | ) 186 | 187 | @staticmethod 188 | def create_list_encoding(xs, ys, ts, ps): 189 | """ 190 | Creates a four channel tensor with all the events in the input partition. 191 | :param xs: [N] tensor with event x location 192 | :param ys: [N] tensor with event y location 193 | :param ts: [N] tensor with normalized event timestamp 194 | :param ps: [N] tensor with event polarity ([-1, 1]) 195 | :return [N x 4] event representation 196 | """ 197 | 198 | return torch.stack([ts, ys, xs, ps]) 199 | 200 | @staticmethod 201 | def create_polarity_mask(ps): 202 | """ 203 | Creates a two channel tensor that acts as a mask for the input event list. 204 | :param ps: [N] tensor with event polarity ([-1, 1]) 205 | :return [N x 2] event representation 206 | """ 207 | 208 | inp_pol_mask = torch.stack([ps, ps]) 209 | inp_pol_mask[0, :][inp_pol_mask[0, :] < 0] = 0 210 | inp_pol_mask[1, :][inp_pol_mask[1, :] > 0] = 0 211 | inp_pol_mask[1, :] *= -1 212 | return inp_pol_mask 213 | 214 | def create_hot_mask(self, xs, ys, ps, batch): 215 | """ 216 | Creates a one channel tensor that can act as mask to remove pixel with high event rate. 217 | :param xs: [N] tensor with event x location 218 | :param ys: [N] tensor with event y location 219 | :param ps: [N] tensor with event polarity ([-1, 1]) 220 | :return [H x W] binary mask 221 | """ 222 | 223 | hot_update = events_to_mask(xs, ys, ps, sensor_size=self.hot_events[batch].shape) 224 | self.hot_events[batch] += hot_update 225 | self.hot_idx[batch] += 1 226 | event_rate = self.hot_events[batch] / self.hot_idx[batch] 227 | return get_hot_event_mask( 228 | event_rate, 229 | self.hot_idx[batch], 230 | max_px=self.config["hot_filter"]["max_px"], 231 | min_obvs=self.config["hot_filter"]["min_obvs"], 232 | max_rate=self.config["hot_filter"]["max_rate"], 233 | ) 234 | 235 | # def __len__(self): 236 | # return 1000 # not used 237 | 238 | @staticmethod 239 | def custom_collate(batch): 240 | """ 241 | Collects the different event representations and stores them together in a dictionary. 242 | """ 243 | # breakpoint() 244 | batch_dict = {} 245 | for key in batch[0].keys(): 246 | batch_dict[key] = [] 247 | for entry in batch: 248 | for key in entry.keys(): 249 | batch_dict[key].append(entry[key]) 250 | for key in batch_dict.keys(): 251 | if len(batch_dict[key][0].shape) > 2: 252 | item = torch.stack(batch_dict[key]) 253 | # if len(item.shape) == 3: 254 | # item = item.transpose(2, 1) 255 | batch_dict[key] = item 256 | else: 257 | batch_dict[key] = [e_list.transpose(0, 1) for e_list in batch_dict[key]] 258 | return batch_dict 259 | 260 | # def shuffle(self, flag=True): 261 | # """ 262 | # Shuffles the training data. 
263 | # """ 264 | 265 | # if flag: 266 | # random.shuffle(self.files) 267 | -------------------------------------------------------------------------------- /dataloader/contrast/encodings.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from Monash University https://github.com/TimoStoff/events_contrast_maximization 3 | """ 4 | 5 | import numpy as np 6 | import torch 7 | 8 | 9 | def binary_search_array(array, x, left=None, right=None, side="left"): 10 | """ 11 | Binary search through a sorted array. 12 | """ 13 | 14 | left = 0 if left is None else left 15 | right = len(array) - 1 if right is None else right 16 | mid = left + (right - left) // 2 17 | 18 | if left > right: 19 | return left if side == "left" else right 20 | 21 | if array[mid] == x: 22 | return mid 23 | 24 | if x < array[mid]: 25 | return binary_search_array(array, x, left=left, right=mid - 1) 26 | 27 | return binary_search_array(array, x, left=mid + 1, right=right) 28 | 29 | 30 | def events_to_mask(xs, ys, ps, sensor_size=(180, 240)): 31 | """ 32 | Accumulate events into a binary mask. 33 | """ 34 | 35 | device = xs.device 36 | img_size = list(sensor_size) 37 | mask = torch.zeros(img_size).to(device) 38 | 39 | if xs.dtype is not torch.long: 40 | xs = xs.long().to(device) 41 | if ys.dtype is not torch.long: 42 | ys = ys.long().to(device) 43 | mask.index_put_((ys, xs), ps.abs(), accumulate=False) 44 | 45 | return mask 46 | 47 | 48 | def events_to_image(xs, ys, ps, sensor_size=(180, 240)): 49 | """ 50 | Accumulate events into an image. 51 | """ 52 | 53 | device = xs.device 54 | img_size = list(sensor_size) 55 | img = torch.zeros(img_size).to(device) 56 | 57 | if xs.dtype is not torch.long: 58 | xs = xs.long().to(device) 59 | if ys.dtype is not torch.long: 60 | ys = ys.long().to(device) 61 | img.index_put_((ys, xs), ps, accumulate=True) 62 | 63 | return img 64 | 65 | 66 | def events_to_voxel(xs, ys, ts, ps, num_bins, sensor_size=(180, 240)): 67 | """ 68 | Generate a voxel grid from input events using temporal bilinear interpolation. 69 | """ 70 | 71 | assert len(xs) == len(ys) and len(ys) == len(ts) and len(ts) == len(ps) 72 | 73 | voxel = [] 74 | ts = ts * (num_bins - 1) 75 | zeros = torch.zeros(ts.size()) 76 | for b_idx in range(num_bins): 77 | weights = torch.max(zeros, 1.0 - torch.abs(ts - b_idx)) 78 | voxel_bin = events_to_image(xs, ys, ps * weights, sensor_size=sensor_size) 79 | voxel.append(voxel_bin) 80 | 81 | return torch.stack(voxel) 82 | 83 | 84 | def events_to_channels(xs, ys, ps, sensor_size=(180, 240)): 85 | """ 86 | Generate a two-channel event image containing event counters. 87 | """ 88 | 89 | assert len(xs) == len(ys) and len(ys) == len(ps) 90 | 91 | mask_pos = ps.clone() 92 | mask_neg = ps.clone() 93 | mask_pos[ps < 0] = 0 94 | mask_neg[ps > 0] = 0 95 | 96 | pos_cnt = events_to_image(xs, ys, ps * mask_pos, sensor_size=sensor_size) 97 | neg_cnt = events_to_image(xs, ys, ps * mask_neg, sensor_size=sensor_size) 98 | 99 | return torch.stack([pos_cnt, neg_cnt]) 100 | 101 | 102 | def get_hot_event_mask(event_rate, idx, max_px=100, min_obvs=5, max_rate=0.8): 103 | """ 104 | Returns binary mask to remove events from hot pixels. 
105 | """ 106 | 107 | mask = torch.ones(event_rate.shape).to(event_rate.device) 108 | if idx > min_obvs: 109 | for i in range(max_px): 110 | argmax = torch.argmax(event_rate) 111 | index = (argmax // event_rate.shape[1], argmax % event_rate.shape[1]) 112 | if event_rate[index] > max_rate: 113 | event_rate[index] = 0 114 | mask[index] = 0 115 | else: 116 | break 117 | return mask 118 | -------------------------------------------------------------------------------- /dataloader/contrast/raw_event_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | import numpy as np 4 | 5 | import torch 6 | import torch.utils.data as data 7 | 8 | # from .utils import ProgressBar 9 | 10 | # from .encodings import binary_search_array 11 | from .encodings import events_to_voxel, events_to_channels, events_to_mask, get_hot_event_mask 12 | from torch.utils.data.dataloader import default_collate 13 | 14 | 15 | class EventSequenceLoader(): 16 | def __init__(self, config): 17 | 18 | self.root_path = config.get("root", '') 19 | self.resolution = config['resolution'] 20 | self.num_bins = 5 21 | 22 | def get_events(self, path): 23 | """ 24 | Load all the events stored in a .npy event file. 25 | :param path: path to the event file; the array holds 26 | one event per row as (t, x, y, p), with 27 | timestamps in seconds 28 | :return xs: [N] numpy array with event x location 29 | :return ys: [N] numpy array with event y location 30 | :return ts: [N] numpy array with event timestamp 31 | :return ps: [N] numpy array with event polarity ([-1, 1]) 32 | """ 33 | file = np.load(path) 34 | xs = file[:, 1] 35 | ys = file[:, 2] 36 | ts = file[:, 0] 37 | ps = file[:, 3] 38 | ts -= file[0][0] # sequence starting at t0 = 0 39 | ts *= 1.0e6 # us 40 | return xs, ys, ts, ps 41 | 42 | @staticmethod 43 | def event_formatting(xs, ys, ts, ps): 44 | """ 45 | Convert raw event arrays to tensors and normalize the event timestamps to [0, 1]. 46 | :param xs: [N] numpy array with event x location 47 | :param ys: [N] numpy array with event y location 48 | :param ts: [N] numpy array with event timestamp 49 | :param ps: [N] numpy array with event polarity ([-1, 1]) 50 | :return xs: [N] tensor with event x location 51 | :return ys: [N] tensor with event y location 52 | :return ts: [N] tensor with normalized event timestamp 53 | :return ps: [N] tensor with event polarity ([-1, 1]) 54 | """ 55 | 56 | xs = torch.from_numpy(xs.astype(np.float32)) 57 | ys = torch.from_numpy(ys.astype(np.float32)) 58 | ts = torch.from_numpy(ts.astype(np.float32)) 59 | ps = torch.from_numpy(ps.astype(np.float32)) 60 | if int(ps.min()) == 0: 61 | ps = 2*ps-1 62 | ts = (ts - ts[0]) / (ts[-1] - ts[0]) 63 | return xs, ys, ts, ps 64 | 65 | 66 | def create_cnt_encoding(self, xs, ys, ts, ps): 67 | """ 68 | Creates a per-pixel and per-polarity event count representation. 69 | :param xs: [N] tensor with event x location 70 | :param ys: [N] tensor with event y location 71 | :param ts: [N] tensor with normalized event timestamp 72 | :param ps: [N] tensor with event polarity ([-1, 1]) 73 | :return [2 x H x W] event representation 74 | """ 75 | 76 | return events_to_channels(xs, ys, ps, sensor_size=self.resolution) 77 | 78 | def create_voxel_encoding(self, xs, ys, ts, ps): 79 | """ 80 | Creates a spatiotemporal voxel grid tensor representation with a certain number of bins, 81 | as described in Section 3.1 of the paper 'Unsupervised Event-based Learning of Optical Flow, 82 | Depth, and Egomotion', Zhu et al., CVPR'19. 
83 | Events are distributed to the spatiotemporal closest bins through bilinear interpolation. 84 | Positive events are added as +1, while negative as -1. 85 | :param xs: [N] tensor with event x location 86 | :param ys: [N] tensor with event y location 87 | :param ts: [N] tensor with normalized event timestamp 88 | :param ps: [N] tensor with event polarity ([-1, 1]) 89 | :return [B x H x W] event representation 90 | """ 91 | 92 | return events_to_voxel( 93 | xs, 94 | ys, 95 | ts, 96 | ps, 97 | self.num_bins, 98 | sensor_size=self.resolution, 99 | ) 100 | 101 | @staticmethod 102 | def create_list_encoding(xs, ys, ts, ps): 103 | """ 104 | Creates a four channel tensor with all the events in the input partition. 105 | :param xs: [N] tensor with event x location 106 | :param ys: [N] tensor with event y location 107 | :param ts: [N] tensor with normalized event timestamp 108 | :param ps: [N] tensor with event polarity ([-1, 1]) 109 | :return [N x 4] event representation 110 | """ 111 | 112 | return torch.stack([ts, ys, xs, ps]) 113 | 114 | @staticmethod 115 | def create_polarity_mask(ps): 116 | """ 117 | Creates a two channel tensor that acts as a mask for the input event list. 118 | :param ps: [N] tensor with event polarity ([-1, 1]) 119 | :return [N x 2] event representation 120 | """ 121 | 122 | inp_pol_mask = torch.stack([ps, ps]) 123 | inp_pol_mask[0, :][inp_pol_mask[0, :] < 0] = 0 124 | inp_pol_mask[1, :][inp_pol_mask[1, :] > 0] = 0 125 | inp_pol_mask[1, :] *= -1 126 | return inp_pol_mask 127 | 128 | @staticmethod 129 | def event_padding(xs, ys, ts, ps, pad): 130 | (pad_left, pad_right, pad_top, pad_bottom) = pad 131 | return xs+pad_left, ys+pad_top, ts, ps 132 | 133 | def get_from_data(self, xs, ys, ts, ps, pad): 134 | 135 | # xs, ys, ts, ps = self.get_events(path) 136 | 137 | # timestamp normalization 138 | xs, ys, ts, ps = self.event_formatting(xs, ys, ts, ps) 139 | # padding 140 | # xs, ys, ts, ps = self.event_padding(xs, ys, ts, ps, pad) 141 | 142 | # data augmentation 143 | # xs, ys, ps = self.augment_events(xs, ys, ps, batch) 144 | 145 | # artificial pauses to the event stream 146 | # if "Pause" in self.config["loader"]["augment"]: 147 | # if self.batch_augmentation["Pause"]: 148 | # xs = torch.from_numpy(np.empty([0]).astype(np.float32)) 149 | # ys = torch.from_numpy(np.empty([0]).astype(np.float32)) 150 | # ts = torch.from_numpy(np.empty([0]).astype(np.float32)) 151 | # ps = torch.from_numpy(np.empty([0]).astype(np.float32)) 152 | 153 | # events to tensors 154 | # inp_cnt = self.create_cnt_encoding(xs, ys, ts, ps) 155 | # inp_voxel = self.create_voxel_encoding(xs, ys, ts, ps) 156 | inp_list = self.create_list_encoding(xs, ys, ts, ps) 157 | inp_pol_mask = self.create_polarity_mask(ps) 158 | 159 | # hot pixel removal 160 | # if self.config["hot_filter"]["enabled"]: 161 | # hot_mask = self.create_hot_mask(xs, ys, ps, batch) 162 | # hot_mask_voxel = torch.stack([hot_mask] * self.num_bins, axis=2).permute(2, 0, 1) 163 | # hot_mask_cnt = torch.stack([hot_mask] * 2, axis=2).permute(2, 0, 1) 164 | # inp_voxel = inp_voxel * hot_mask_voxel 165 | # inp_cnt = inp_cnt * hot_mask_cnt 166 | 167 | 168 | # prepare output 169 | output = {} 170 | # output["inp_cnt"] = inp_cnt 171 | # output["inp_voxel"] = inp_voxel 172 | 173 | # 4 x N to N x 4 174 | output["inp_list"] = inp_list.permute(1, 0) 175 | output["inp_pol_mask"] = inp_pol_mask.permute(1, 0) 176 | 177 | return output 178 | 179 | 180 | def get_from_path(self, path, pad): 181 | 182 | # load events 183 | xs = np.zeros((0)) 184 | ys = 
np.zeros((0)) 185 | ts = np.zeros((0)) 186 | ps = np.zeros((0)) 187 | 188 | xs, ys, ts, ps = self.get_events(path) 189 | 190 | # timestamp normalization 191 | xs, ys, ts, ps = self.event_formatting(xs, ys, ts, ps) 192 | # padding 193 | # xs, ys, ts, ps = self.event_padding(xs, ys, ts, ps, pad) 194 | 195 | # data augmentation 196 | # xs, ys, ps = self.augment_events(xs, ys, ps, batch) 197 | 198 | # artificial pauses to the event stream 199 | # if "Pause" in self.config["loader"]["augment"]: 200 | # if self.batch_augmentation["Pause"]: 201 | # xs = torch.from_numpy(np.empty([0]).astype(np.float32)) 202 | # ys = torch.from_numpy(np.empty([0]).astype(np.float32)) 203 | # ts = torch.from_numpy(np.empty([0]).astype(np.float32)) 204 | # ps = torch.from_numpy(np.empty([0]).astype(np.float32)) 205 | 206 | # events to tensors 207 | # inp_cnt = self.create_cnt_encoding(xs, ys, ts, ps) 208 | # inp_voxel = self.create_voxel_encoding(xs, ys, ts, ps) 209 | inp_list = self.create_list_encoding(xs, ys, ts, ps) 210 | inp_pol_mask = self.create_polarity_mask(ps) 211 | 212 | # hot pixel removal 213 | # if self.config["hot_filter"]["enabled"]: 214 | # hot_mask = self.create_hot_mask(xs, ys, ps, batch) 215 | # hot_mask_voxel = torch.stack([hot_mask] * self.num_bins, axis=2).permute(2, 0, 1) 216 | # hot_mask_cnt = torch.stack([hot_mask] * 2, axis=2).permute(2, 0, 1) 217 | # inp_voxel = inp_voxel * hot_mask_voxel 218 | # inp_cnt = inp_cnt * hot_mask_cnt 219 | 220 | 221 | # prepare output 222 | output = {} 223 | # output["inp_cnt"] = inp_cnt 224 | # output["inp_voxel"] = inp_voxel 225 | 226 | # 4 x N to N x 4 227 | output["inp_list"] = inp_list.permute(1, 0) 228 | output["inp_pol_mask"] = inp_pol_mask.permute(1, 0) 229 | 230 | return output 231 | 232 | # B x frame x (datas) -> frame x B*(datas) 233 | def custom_collate(batch): # batch = ((left_padded, right_padded, disp_padded, pad, debug, "left_inp_list", "left_inp_pol_mask", "right_inp_list", "right_inp_pol_mask)) 234 | """ 235 | Collects the different event representations and stores them together in a dictionary. 
236 | """ 237 | B = len(batch) 238 | num_frame = len(batch[0]) 239 | 240 | datasets = [([[] for i in range(4)] + [{}]) for i in range(num_frame)] 241 | 242 | for bn, data in enumerate(batch): 243 | for f, d in enumerate(data): 244 | datasets[f][0].append(d[0]) 245 | datasets[f][1].append(d[1]) 246 | datasets[f][2].append(d[2]) 247 | datasets[f][3].append(d[3]) 248 | if bn == 0: 249 | for key in d[4].keys(): 250 | datasets[f][4][key] = [] 251 | for key in d[4].keys(): 252 | datasets[f][4][key].append(d[4][key]) 253 | 254 | for fi, frame_data in enumerate(datasets): 255 | for i, data in enumerate(frame_data): 256 | if isinstance(data, dict): 257 | for key in data.keys(): 258 | try: 259 | frame_data[i][key] = default_collate(data[key]) 260 | except: 261 | continue 262 | else: 263 | frame_data[i] = default_collate(data) 264 | 265 | return datasets 266 | -------------------------------------------------------------------------------- /dataloader/dsceDataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional 3 | from torch.nn.functional import pad 4 | from torch.utils.data import Dataset 5 | from torchvision import transforms 6 | import torch.nn.functional as F 7 | import numpy as np 8 | import os 9 | import json 10 | import cv2 11 | import random 12 | import pdb 13 | 14 | 15 | class DSECdataset(Dataset): 16 | def __init__(self, root: str, ann: list, height: int, width: int, frame_idxs: list = [0, ], eval_maxdisp=255.0): 17 | self.root = root 18 | self.data_list = ann 19 | self.height = height 20 | self.width = width 21 | self.frame_idxs = frame_idxs 22 | self.eval_maxdisp = eval_maxdisp 23 | 24 | 25 | def add_padding(self, image, pad=0): 26 | H, W = image.shape[-2:] 27 | pad_left = 0 28 | pad_right = self.width - W 29 | pad_top = self.height - H 30 | pad_bottom = 0 31 | assert pad_left >= 0 and pad_right >= 0 and pad_top >= 0 and pad_bottom >= 0, "Require image crop, please check the image size" 32 | padded = torch.nn.functional.pad(image, [pad_left, pad_right, pad_top, pad_bottom], mode='constant', value=pad) 33 | return padded, (pad_left, pad_right, pad_top, pad_bottom) 34 | 35 | def __len__(self): 36 | return len(self.data_list) 37 | 38 | def __getitem__(self, i): 39 | data = [] 40 | # get dictionary of frame idxs 41 | metadata = self.data_list[i] 42 | for idx in self.frame_idxs: 43 | frame_info = metadata.get(str(idx), None) 44 | # get relative file path respect to self.root 45 | assert frame_info is not None, "Cannot get frame info of frame {:d}:{:d}, check json file".format(i, idx) 46 | left_event_path = frame_info.get('left_image_path', None) 47 | assert left_event_path is not None 48 | right_event_path = frame_info.get('right_image_path', None) 49 | assert right_event_path is not None 50 | disp_path = frame_info.get('left_disp_path', None) 51 | assert disp_path is not None 52 | flow_path = frame_info.get('left_flow_path', None) 53 | assert flow_path is not None 54 | 55 | left_event_path = os.path.join(self.root, left_event_path) 56 | right_event_path = os.path.join(self.root, right_event_path) 57 | disp_path = os.path.join(self.root, disp_path) 58 | flow_path = os.path.join(self.root, flow_path) 59 | left_img_path = left_event_path.replace('num5voxel0', 'image0').replace('npy', 'png') 60 | right_img_path = right_event_path.replace('num5voxel1', 'image1').replace('npy', 'png') 61 | 62 | left_event = np.load(left_event_path) 63 | right_event = np.load(right_event_path) 64 | # need to be (C, H, W) 65 | assert 
left_event.ndim==3 66 | assert right_event.ndim==3 67 | flow_gt = np.load(flow_path) 68 | # pdb.set_trace() 69 | assert flow_gt.ndim==3 70 | 71 | # (H, W) 72 | disp = cv2.imread(disp_path, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_ANYCOLOR) 73 | # MVSEC quirk: stored disparity values are scaled by 7, so undo it here 74 | disp = disp / 7.0 75 | 76 | left_event = torch.tensor(left_event) 77 | right_event = torch.tensor(right_event) 78 | disp = torch.tensor(disp) 79 | flow_gt = torch.tensor(flow_gt) 80 | left_img = cv2.imread(left_img_path, cv2.IMREAD_GRAYSCALE) 81 | right_img = cv2.imread(right_img_path, cv2.IMREAD_GRAYSCALE) 82 | left_img = torch.tensor(left_img).type(torch.float32)/255.0 83 | right_img = torch.tensor(right_img).type(torch.float32)/255.0 84 | 85 | left_padded, pad = self.add_padding(left_event) 86 | right_padded, _ = self.add_padding(right_event) 87 | disp_padded , _ = self.add_padding(disp, 255.0) 88 | flow_gt, _ = self.add_padding(flow_gt) 89 | left_img_padded, _ = self.add_padding(left_img) 90 | right_img_padded, _ = self.add_padding(right_img) 91 | 92 | disp_mask = disp_padded > self.eval_maxdisp 93 | disp_padded[disp_mask] = float('Nan') 94 | # disp_padded[disp_mask] = float(1e5) 95 | 96 | debug = {"left_event_path": left_event_path, 97 | "right_event_path": right_event_path, 98 | "disp_path": disp_path, 99 | "flow_gt": flow_gt, 100 | "left_img": left_img, 101 | "right_img": right_img} 102 | data.append((left_padded, right_padded, disp_padded, pad, debug)) 103 | return data 104 | 105 | def get_datasets(root, split, height, width, frame_idxs, num_validation, eval_maxdisp): # NOTE: relies on train_seqs, test_seqs, FRAMES_FILTER_FOR_TRAINING/TEST and MVSECdataset from the MVSEC loader; it mirrors the MVSEC split logic 106 | assert max(frame_idxs) == 0 107 | # frame_idxs.sort() 108 | assert os.path.isdir(root) 109 | 110 | 111 | train_test_file_dict = [] 112 | for train_test in ['train', 'test']: 113 | if train_test == 'train': 114 | seqs = train_seqs 115 | frame_filter = FRAMES_FILTER_FOR_TRAINING['indoor_flying'] 116 | elif train_test == 'test': 117 | seqs = test_seqs 118 | frame_filter = FRAMES_FILTER_FOR_TEST['indoor_flying'] 119 | else: 120 | raise NotImplementedError 121 | 122 | file_dict = [] 123 | # make file sub-path for train 124 | for seq in seqs: 125 | seq_dir = f'indoor_flying_{seq}' 126 | left_dir_path = os.path.join(seq_dir, 'num5voxel0') 127 | right_dir_path = os.path.join(seq_dir, 'num5voxel1') 128 | disp_dir_path = os.path.join(seq_dir, 'disparity_image') 129 | flow_dir_path = os.path.join(seq_dir, 'flow0') 130 | for frame in frame_filter[seq][-min(frame_idxs):]: 131 | left_right_disp = {} 132 | for fi in frame_idxs: 133 | num_str = "{:06d}".format(frame + fi) 134 | event_name = num_str + ".npy" 135 | disp_name = num_str + ".png" 136 | assert os.path.isfile(os.path.join(root, left_dir_path, event_name)) 137 | assert os.path.isfile(os.path.join(root, right_dir_path, event_name)) 138 | assert os.path.isfile(os.path.join(root, disp_dir_path, disp_name)) 139 | left_right_disp[str(fi)] = { 140 | 'left_image_path': os.path.join(left_dir_path, event_name), 141 | 'right_image_path': os.path.join(right_dir_path, event_name), 142 | 'left_disp_path': os.path.join(disp_dir_path, disp_name), 143 | 'left_flow_path': os.path.join(flow_dir_path, event_name) 144 | } 145 | file_dict.append(left_right_disp) 146 | train_test_file_dict.append(file_dict) 147 | 148 | train_dataset = MVSECdataset( 149 | root, 150 | train_test_file_dict[0], 151 | height, 152 | width, 153 | frame_idxs, 154 | eval_maxdisp) 155 | 156 | if num_validation > 0: 157 | import random 158 | random.shuffle(train_test_file_dict[1]) 159 | validation_dataset = MVSECdataset( 160 | root, 161 | 
train_test_file_dict[1][:200], 162 | height, 163 | width, 164 | frame_idxs, 165 | eval_maxdisp) 166 | test_dataset = MVSECdataset( 167 | root, 168 | train_test_file_dict[1][200:], 169 | height, 170 | width, 171 | frame_idxs, 172 | eval_maxdisp) 173 | else: 174 | validation_dataset = DSECdataset( 175 | root, 176 | train_test_file_dict[1], 177 | height, 178 | width, 179 | frame_idxs, 180 | eval_maxdisp) 181 | test_dataset = DSECdataset( 182 | root, 183 | train_test_file_dict[1], 184 | height, 185 | width, 186 | frame_idxs, 187 | eval_maxdisp) 188 | 189 | return train_dataset, validation_dataset, test_dataset 190 | 191 | 192 | # for debugging 193 | if __name__ == '__main__': 194 | # mvsecdataset = MVSECdataset( 195 | # '/home/jaeyoung/data/ws/event_stereo_ICCV2019/dataset', 196 | # '/home/jaeyoung/data/ws/event_stereo_ICCV2019/dataset/view_4_train_v5_split1.json', 197 | # 288, 198 | # 352 199 | # ) 200 | # test = mvsecdataset[0] 201 | # breakpoint() 202 | new_train, new_val, new_test = get_datasets('/home/jaeyoung/data/ws/event_stereo_ICCV2019/dataset', 1, 288, 352, [-3, -2, -1, 0], 0, 255.0) # last argument is eval_maxdisp 203 | # min_disp = [] 204 | # for i in new_train: 205 | # min_disp.append(torch.min(i[-1][2]).item()) 206 | new_train[0] 207 | breakpoint() 208 | # getDataset(root, 'event_left', 'event_right', 'disparity_image', 80, 1260) 209 | -------------------------------------------------------------------------------- /dataloader/dsecProvider_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | import pdb 5 | from .dsecSequence import Sequence 6 | 7 | DATA_SPLIT = { 8 | 'train': ['interlaken_00_c', 'interlaken_00_d', 'interlaken_00_e', 'zurich_city_00_a', 9 | 'zurich_city_00_b', 'zurich_city_01_a', 'zurich_city_01_b', 'zurich_city_01_c', 10 | 'zurich_city_01_d', 'zurich_city_01_e', 'zurich_city_02_a', 11 | 'zurich_city_02_b', 'zurich_city_02_c', 'zurich_city_02_d', 'zurich_city_02_e', 12 | 'zurich_city_03_a', 'zurich_city_04_a', 'zurich_city_04_b', 'zurich_city_04_c', 13 | 'zurich_city_04_d', 'zurich_city_04_e', 'zurich_city_04_f', 'zurich_city_09_a', 14 | 'zurich_city_09_b', 'zurich_city_09_e', 'zurich_city_10_a', 15 | 'zurich_city_11_a', 'zurich_city_11_c', 16 | 'interlaken_00_f', 'interlaken_00_g', 'thun_00_a', 'zurich_city_05_a', 17 | 'zurich_city_05_b', 'zurich_city_07_a', 18 | 'zurich_city_09_d', 'zurich_city_10_b'], 19 | 'validation': ['zurich_city_01_f', 'zurich_city_06_a', 'zurich_city_11_b','zurich_city_09_c', 20 | 'zurich_city_08_a'] 21 | } 22 | 23 | DATA_SPLIT_MINI = { 24 | 'train': ['zurich_city_03_a', 'zurich_city_04_a', 'zurich_city_04_b', 'zurich_city_04_c', 25 | 'zurich_city_04_d', 'zurich_city_04_e', 'zurich_city_04_f', 'zurich_city_09_a', 26 | ], 27 | 'validation': ['zurich_city_01_f', 'zurich_city_06_a'] 28 | } 29 | 30 | DATA_SPLIT_SUPER_MINI = { 31 | 'train': ['zurich_city_04_a' ], 32 | 'validation': ['zurich_city_04_d'] 33 | } 34 | 35 | class DatasetProvider: 36 | def __init__(self, dataset_path: Path, raw_dataset_path: Path, delta_t_ms: int=100, num_bins=15, frame_idxs = range(-3, 1), 37 | eval_maxdisp=192, pseudo_path=None, pad_width=648, pad_height=480, use_mini = False, use_super_mini = False, img_load = False): 38 | 39 | train_path = dataset_path / 'train' 40 | train_raw_path = raw_dataset_path / 'train' 41 | 42 | assert dataset_path.is_dir(), str(dataset_path) 43 | assert train_raw_path.is_dir(), str(train_raw_path) 44 | assert train_path.is_dir(), str(train_path) 45 | 46 | if use_super_mini: 47 | 
data_split = DATA_SPLIT_SUPER_MINI 48 | elif use_mini: 49 | data_split = DATA_SPLIT_MINI 50 | else: 51 | data_split = DATA_SPLIT 52 | 53 | test_path = dataset_path / 'test' 54 | test_raw_path = raw_dataset_path / 'test' 55 | 56 | assert test_path.is_dir(), str(test_path) 57 | 58 | test_sequences = list() 59 | for child in test_path.iterdir(): 60 | if str(child).split('/')[-1] != 'thun_02_a': # skip the excluded test sequence 61 | raw_child = test_raw_path / str(child).split('/')[-1] 62 | test_sequences.append(Sequence(child, 'test', delta_t_ms, num_bins, frame_idxs, eval_maxdisp, pseudo_path=pseudo_path, pad_width=pad_width, pad_height=pad_height, 63 | raw_seq_path = raw_child, img_load = img_load)) 64 | 65 | 66 | self.test_dataset = torch.utils.data.ConcatDataset(test_sequences) 67 | 68 | def get_train_dataset(self): 69 | return self.train_dataset # never populated: this provider builds only the test split 70 | 71 | def get_val_dataset(self): 72 | return self.val_dataset # never populated: this provider builds only the test split 73 | 74 | def get_test_dataset(self): 75 | return self.test_dataset -------------------------------------------------------------------------------- /dataloader/eventslicer.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Dict, Tuple 3 | 4 | import h5py 5 | from numba import jit 6 | import numpy as np 7 | 8 | 9 | class EventSlicer: 10 | def __init__(self, h5f: h5py.File): 11 | self.h5f = h5f 12 | 13 | self.events = dict() 14 | for dset_str in ['p', 'x', 'y', 't']: 15 | self.events[dset_str] = self.h5f['events/{}'.format(dset_str)] 16 | 17 | # This is the mapping from milliseconds to event index: 18 | # It is defined such that 19 | # (1) t[ms_to_idx[ms]] >= ms*1000 20 | # (2) t[ms_to_idx[ms] - 1] < ms*1000 21 | # where 'ms' is the time in milliseconds and 't' the event timestamps in microseconds.
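# For instance, in the example below ms_to_idx[2] == 2: t[2] = 2100 >= 2*1000, while t[1] = 500 < 2*1000.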
22 | # 23 | # As an example, given 't' and 'ms': 24 | # t: 0 500 2100 5000 5000 7100 7200 7200 8100 9000 25 | # ms: 0 1 2 3 4 5 6 7 8 9 26 | # 27 | # we get 28 | # 29 | # ms_to_idx: 30 | # 0 2 2 3 3 3 5 5 8 9 31 | self.ms_to_idx = np.asarray(self.h5f['ms_to_idx'], dtype='int64') 32 | 33 | if "t_offset" in list(h5f.keys()): 34 | self.t_offset = int(h5f['t_offset'][()]) 35 | else: 36 | self.t_offset = 0 37 | self.t_final = int(self.events['t'][-1]) + self.t_offset 38 | 39 | def get_start_time_us(self): 40 | return self.t_offset 41 | 42 | def get_final_time_us(self): 43 | return self.t_final 44 | 45 | def get_events(self, t_start_us: int, t_end_us: int) -> Dict[str, np.ndarray]: 46 | """Get events (p, x, y, t) within the specified time window 47 | Parameters 48 | ---------- 49 | t_start_us: start time in microseconds 50 | t_end_us: end time in microseconds 51 | Returns 52 | ------- 53 | events: dictionary of (p, x, y, t) or None if the time window cannot be retrieved 54 | """ 55 | assert t_start_us < t_end_us 56 | 57 | # The query times include the global time offset, hence subtract it: 58 | t_start_us -= self.t_offset 59 | t_end_us -= self.t_offset 60 | 61 | t_start_ms, t_end_ms = self.get_conservative_window_ms(t_start_us, t_end_us) 62 | t_start_ms_idx = self.ms2idx(t_start_ms) 63 | t_end_ms_idx = self.ms2idx(t_end_ms) 64 | 65 | if t_start_ms_idx is None or t_end_ms_idx is None: 66 | # Cannot guarantee window size anymore 67 | return None 68 | 69 | events = dict() 70 | time_array_conservative = np.asarray(self.events['t'][t_start_ms_idx:t_end_ms_idx]) 71 | idx_start_offset, idx_end_offset = self.get_time_indices_offsets(time_array_conservative, t_start_us, t_end_us) 72 | t_start_us_idx = t_start_ms_idx + idx_start_offset 73 | t_end_us_idx = t_start_ms_idx + idx_end_offset 74 | # Again add t_offset to get gps time 75 | events['t'] = time_array_conservative[idx_start_offset:idx_end_offset] + self.t_offset 76 | for dset_str in ['p', 'x', 'y']: 77 | events[dset_str] = np.asarray(self.events[dset_str][t_start_us_idx:t_end_us_idx]) 78 | assert events[dset_str].size == events['t'].size 79 | return events 80 | 81 | 82 | @staticmethod 83 | def get_conservative_window_ms(ts_start_us: int, ts_end_us) -> Tuple[int, int]: 84 | """Compute a conservative time window with millisecond resolution. 85 | We have a time to index mapping for each millisecond. Hence, we need 86 | to compute the lower and upper millisecond to retrieve events.
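For example, ts_start_us = 1500 and ts_end_us = 2100 give the conservative window (1, 3), since floor(1500/1000) = 1 and ceil(2100/1000) = 3.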
87 | Parameters 88 | ---------- 89 | ts_start_us: start time in microseconds 90 | ts_end_us: end time in microseconds 91 | Returns 92 | ------- 93 | window_start_ms: conservative start time in milliseconds 94 | window_end_ms: conservative end time in milliseconds 95 | """ 96 | assert ts_end_us > ts_start_us 97 | window_start_ms = math.floor(ts_start_us/1000) 98 | window_end_ms = math.ceil(ts_end_us/1000) 99 | return window_start_ms, window_end_ms 100 | 101 | @staticmethod 102 | @jit(nopython=True) 103 | def get_time_indices_offsets( 104 | time_array: np.ndarray, 105 | time_start_us: int, 106 | time_end_us: int) -> Tuple[int, int]: 107 | """Compute index offset of start and end timestamps in microseconds 108 | Parameters 109 | ---------- 110 | time_array: timestamps (in us) of the events 111 | time_start_us: start timestamp (in us) 112 | time_end_us: end timestamp (in us) 113 | Returns 114 | ------- 115 | idx_start: Index within this array corresponding to time_start_us 116 | idx_end: Index within this array corresponding to time_end_us 117 | such that (in non-edge cases) 118 | time_array[idx_start] >= time_start_us 119 | time_array[idx_end] >= time_end_us 120 | time_array[idx_start - 1] < time_start_us 121 | time_array[idx_end - 1] < time_end_us 122 | this means that 123 | time_start_us <= time_array[idx_start:idx_end] < time_end_us 124 | """ 125 | 126 | assert time_array.ndim == 1 127 | 128 | idx_start = -1 129 | if time_array[-1] < time_start_us: 130 | # This can happen in extreme corner cases. E.g. 131 | # time_array[0] = 1016 132 | # time_array[-1] = 1984 133 | # time_start_us = 1990 134 | # time_end_us = 2000 135 | 136 | # Return same index twice: array[x:x] is empty. 137 | return time_array.size, time_array.size 138 | else: 139 | for idx_from_start in range(0, time_array.size, 1): 140 | if time_array[idx_from_start] >= time_start_us: 141 | idx_start = idx_from_start 142 | break 143 | assert idx_start >= 0 144 | 145 | idx_end = time_array.size 146 | for idx_from_end in range(time_array.size - 1, -1, -1): 147 | if time_array[idx_from_end] >= time_end_us: 148 | idx_end = idx_from_end 149 | else: 150 | break 151 | 152 | assert time_array[idx_start] >= time_start_us 153 | if idx_end < time_array.size: 154 | assert time_array[idx_end] >= time_end_us 155 | if idx_start > 0: 156 | assert time_array[idx_start - 1] < time_start_us 157 | if idx_end > 0: 158 | assert time_array[idx_end - 1] < time_end_us 159 | return idx_start, idx_end 160 | 161 | def ms2idx(self, time_ms: int) -> int: 162 | assert time_ms >= 0 163 | if time_ms >= self.ms_to_idx.size: 164 | return None 165 | return self.ms_to_idx[time_ms] -------------------------------------------------------------------------------- /dataloader/representations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class EventRepresentation: 5 | def convert(self, x: torch.Tensor, y: torch.Tensor, pol: torch.Tensor, time: torch.Tensor): 6 | raise NotImplementedError 7 | 8 | 9 | class VoxelGrid(EventRepresentation): 10 | def __init__(self, channels: int, height: int, width: int, normalize: bool): 11 | self.voxel_grid = torch.zeros((channels, height, width), dtype=torch.float, requires_grad=False) 12 | self.nb_channels = channels 13 | self.normalize = normalize 14 | 15 | def convert(self, x: torch.Tensor, y: torch.Tensor, pol: torch.Tensor, time: torch.Tensor): 16 | assert x.shape == y.shape == pol.shape == time.shape 17 | assert x.ndim == 1 18 | 19 | C, H, W = self.voxel_grid.shape 20 | 
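# The nested loops below trilinearly splat each event into the grid: every event contributes to its two nearest voxels along x, y, and the normalized time axis, weighted by (1 - |distance|) in each dimension, with polarity mapped to {-1, +1}.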
with torch.no_grad(): 21 | 22 | self.voxel_grid = self.voxel_grid.to(pol.device) 23 | voxel_grid = self.voxel_grid.clone() 24 | 25 | t_norm = time 26 | t_norm = (C - 1) * (t_norm-t_norm[0]) / (t_norm[-1]-t_norm[0]) 27 | 28 | x0 = x.int() 29 | y0 = y.int() 30 | t0 = t_norm.int() 31 | 32 | value = 2*pol-1 33 | 34 | for xlim in [x0,x0+1]: 35 | for ylim in [y0,y0+1]: 36 | for tlim in [t0,t0+1]: 37 | 38 | mask = (xlim < W) & (xlim >= 0) & (ylim < H) & (ylim >= 0) & (tlim >= 0) & (tlim < self.nb_channels) 39 | interp_weights = value * (1 - (xlim-x).abs()) * (1 - (ylim-y).abs()) * (1 - (tlim - t_norm).abs()) 40 | 41 | index = H * W * tlim.long() + \ 42 | W * ylim.long() + \ 43 | xlim.long() 44 | 45 | voxel_grid.put_(index[mask], interp_weights[mask], accumulate=True) 46 | 47 | if self.normalize: 48 | mask = torch.nonzero(voxel_grid, as_tuple=True) 49 | if mask[0].size()[0] > 0: 50 | mean = voxel_grid[mask].mean() 51 | std = voxel_grid[mask].std() 52 | if std > 0: 53 | voxel_grid[mask] = (voxel_grid[mask] - mean) / std 54 | else: 55 | voxel_grid[mask] = voxel_grid[mask] - mean 56 | 57 | return voxel_grid -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM zombbie/cuda11.1-cudnn8-ubuntu20.04:v1.0 2 | ARG DEBIAN_FRONTEND=noninteractive 3 | 4 | RUN apt-get update && \ 5 | apt-get install -y software-properties-common && \ 6 | add-apt-repository -y ppa:deadsnakes/ppa && \ 7 | apt install -y python3.8 python3-pip 8 | RUN apt-get -y install git 9 | 10 | RUN pip --default-timeout=2000 install torch==1.10.1+cu111 torchvision==0.11.2+cu111 torchaudio==0.10.1 -f https://download.pytorch.org/whl/cu111/torch_stable.html 11 | 12 | RUN pip install cupy-cuda111 13 | 14 | RUN apt-get update 15 | RUN apt-get -y install qt5-default python3-pyqt5 16 | RUN apt-get -y install net-tools 17 | 18 | RUN mkdir -p /home/data 19 | RUN mkdir -p /home/ws 20 | ADD requirements.txt /home/requirements.txt 21 | ADD .tmux.conf /home/.tmux.conf 22 | RUN apt-get -y install tmux 23 | 24 | WORKDIR /home/ws/TESNet 25 | RUN pip install -r ../../requirements.txt 26 | 27 | RUN apt-get install -y htop 28 | RUN apt-get install -y xclip 29 | RUN apt-get install -y zip 30 | 31 | ARG UID 32 | ARG GID 33 | 34 | # Update the package list, install sudo, create a non-root user, and grant password-less sudo permissions 35 | RUN apt update && \ 36 | apt install -y sudo && \ 37 | addgroup --gid $GID nonroot && \ 38 | adduser --uid $UID --gid $GID --disabled-password --gecos "" nonroot && \ 39 | echo 'nonroot ALL=(ALL) NOPASSWD: ALL' >> /etc/sudoers 40 | 41 | # Set the non-root user as the default user 42 | USER nonroot 43 | RUN echo "alias tn='. run_tmux.sh'" >> ~/.bashrc 44 | -------------------------------------------------------------------------------- /docker/README.md: -------------------------------------------------------------------------------- 1 | ## 1. Install Docker 2 | Follow the instructions here: [Docker-install](https://docs.docker.com/desktop/install/linux/ubuntu/) 3 | ## 2. Install NVIDIA Container Toolkit 4 | Follow the instructions here: [NVIDIA-install](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) 5 | 6 | ## 3. Go to the project path 7 | ```bash 8 | cd {PATH_TO_THIS_PROJECT}/TemporalEventStereo 9 | ``` 10 | ## 4. Build Docker Image 11 | Build the docker image using the following script.
12 | This will build the docker image "tes:v0.1" 13 | ```bash 14 | source docker/docker_build.sh 15 | ``` 16 | ## 5. Start Docker Container 17 | 18 | ```bash 19 | source docker/docker_run_multi.sh {ABSOLUTE_DATA_DIR_PATH} 20 | ``` -------------------------------------------------------------------------------- /docker/docker_build.sh: -------------------------------------------------------------------------------- 1 | export HOST_UID=$(id -u) 2 | export HOST_GID=$(id -g) 3 | 4 | docker build --build-arg UID=$HOST_UID --build-arg GID=$HOST_GID \ 5 | -t tes:v0.1 \ 6 | -f docker/Dockerfile . 7 | -------------------------------------------------------------------------------- /docker/docker_run_multi.sh: -------------------------------------------------------------------------------- 1 | data_path=$1 2 | 3 | docker run -it --gpus all \ 4 | --name TES \ 5 | --shm-size=128G \ 6 | -v $data_path:/home/data\ 7 | -v $PWD:/home/ws/TESNet\ 8 | tes:v0.1 bash -------------------------------------------------------------------------------- /infer_dsec.sh: -------------------------------------------------------------------------------- 1 | logdir=exps/debug 2 | 3 | python3 infer_dsec.py --config config/dsec_infer.yaml\ 4 | --savemodel ${logdir} 5 | 6 | curr_dir=$PWD 7 | cd $logdir 8 | zip -r test.zip test 9 | cd $curr_dir 10 | echo Gpu number is ${GPUNUM}! 11 | echo Done! -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .model_large import TESNet as ours_large 2 | -------------------------------------------------------------------------------- /models/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mickeykang16/TemporalEventStereo/d9dc74677f568dbdf2d1d06da20480c50410e274/models/__init__.pyc -------------------------------------------------------------------------------- /models/submodule.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.nn as nn 4 | import torch.utils.data 5 | from torch.autograd import Variable 6 | import torch.nn.functional as F 7 | import math 8 | import numpy as np 9 | 10 | def convbn(in_planes, out_planes, kernel_size, stride, pad, dilation): 11 | 12 | return nn.Sequential(nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=dilation if dilation > 1 else pad, dilation = dilation, bias=False), 13 | nn.BatchNorm2d(out_planes)) 14 | 15 | 16 | def convbn_3d(in_planes, out_planes, kernel_size, stride, pad): 17 | 18 | return nn.Sequential(nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, padding=pad, stride=stride,bias=False), 19 | nn.BatchNorm3d(out_planes)) 20 | 21 | class BasicBlock(nn.Module): 22 | expansion = 1 23 | def __init__(self, inplanes, planes, stride, downsample, pad, dilation): 24 | super(BasicBlock, self).__init__() 25 | 26 | self.conv1 = nn.Sequential(convbn(inplanes, planes, 3, stride, pad, dilation), 27 | nn.ReLU(inplace=True)) 28 | 29 | self.conv2 = convbn(planes, planes, 3, 1, pad, dilation) 30 | 31 | self.downsample = downsample 32 | self.stride = stride 33 | 34 | def forward(self, x): 35 | out = self.conv1(x) 36 | out = self.conv2(out) 37 | 38 | if self.downsample is not None: 39 | x = self.downsample(x) 40 | 41 | out += x 42 | 43 | return out 44 | 45 | class disparityregression(nn.Module): 46 | def 
__init__(self, maxdisp): 47 | super(disparityregression, self).__init__() 48 | self.disp = torch.Tensor(np.reshape(np.array(range(maxdisp)),[1, maxdisp,1,1])).cuda() 49 | 50 | def forward(self, x): 51 | out = torch.sum(x*self.disp.data,1, keepdim=True) 52 | return out 53 | 54 | class feature_extraction(nn.Module): 55 | def __init__(self, in_ch=5): 56 | super(feature_extraction, self).__init__() 57 | self.inplanes = 32 58 | self.firstconv = nn.Sequential(convbn(in_ch, 32, 3, 2, 1, 1), 59 | nn.ReLU(inplace=True), 60 | convbn(32, 32, 3, 1, 1, 1), 61 | nn.ReLU(inplace=True), 62 | convbn(32, 32, 3, 1, 1, 1), 63 | nn.ReLU(inplace=True)) 64 | 65 | self.layer1 = self._make_layer(BasicBlock, 32, 3, 1,1,1) 66 | self.layer2 = self._make_layer(BasicBlock, 64, 16, 2,1,1) 67 | self.layer3 = self._make_layer(BasicBlock, 128, 3, 1,1,1) 68 | self.layer4 = self._make_layer(BasicBlock, 128, 3, 1,1,2) 69 | 70 | self.branch1 = nn.Sequential(nn.AvgPool2d((64, 64), stride=(64,64)), 71 | convbn(128, 32, 1, 1, 0, 1), 72 | nn.ReLU(inplace=True)) 73 | 74 | self.branch2 = nn.Sequential(nn.AvgPool2d((32, 32), stride=(32,32)), 75 | convbn(128, 32, 1, 1, 0, 1), 76 | nn.ReLU(inplace=True)) 77 | 78 | self.branch3 = nn.Sequential(nn.AvgPool2d((16, 16), stride=(16,16)), 79 | convbn(128, 32, 1, 1, 0, 1), 80 | nn.ReLU(inplace=True)) 81 | 82 | self.branch4 = nn.Sequential(nn.AvgPool2d((8, 8), stride=(8,8)), 83 | convbn(128, 32, 1, 1, 0, 1), 84 | nn.ReLU(inplace=True)) 85 | 86 | self.lastconv = nn.Sequential(convbn(320, 128, 3, 1, 1, 1), 87 | nn.ReLU(inplace=True), 88 | nn.Conv2d(128, 32, kernel_size=1, padding=0, stride = 1, bias=False)) 89 | 90 | def _make_layer(self, block, planes, blocks, stride, pad, dilation): 91 | downsample = None 92 | if stride != 1 or self.inplanes != planes * block.expansion: 93 | downsample = nn.Sequential( 94 | nn.Conv2d(self.inplanes, planes * block.expansion, 95 | kernel_size=1, stride=stride, bias=False), 96 | nn.BatchNorm2d(planes * block.expansion),) 97 | 98 | layers = [] 99 | layers.append(block(self.inplanes, planes, stride, downsample, pad, dilation)) 100 | self.inplanes = planes * block.expansion 101 | for i in range(1, blocks): 102 | layers.append(block(self.inplanes, planes,1,None,pad,dilation)) 103 | 104 | return nn.Sequential(*layers) 105 | 106 | def forward(self, x): 107 | output = self.firstconv(x) 108 | output = self.layer1(output) 109 | output_raw = self.layer2(output) 110 | output = self.layer3(output_raw) 111 | output_skip = self.layer4(output) 112 | 113 | 114 | output_branch1 = self.branch1(output_skip) 115 | output_branch1 = F.interpolate(output_branch1, (output_skip.size()[2],output_skip.size()[3]),mode='bilinear') 116 | 117 | output_branch2 = self.branch2(output_skip) 118 | output_branch2 = F.interpolate(output_branch2, (output_skip.size()[2],output_skip.size()[3]),mode='bilinear') 119 | 120 | output_branch3 = self.branch3(output_skip) 121 | output_branch3 = F.interpolate(output_branch3, (output_skip.size()[2],output_skip.size()[3]),mode='bilinear') 122 | 123 | output_branch4 = self.branch4(output_skip) 124 | output_branch4 = F.interpolate(output_branch4, (output_skip.size()[2],output_skip.size()[3]),mode='bilinear') 125 | 126 | output_feature = torch.cat((output_raw, output_skip, output_branch4, output_branch3, output_branch2, output_branch1), 1) 127 | output_feature = self.lastconv(output_feature) 128 | 129 | return output_feature 130 | 131 | 132 | 133 | -------------------------------------------------------------------------------- /models/submodule.pyc:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mickeykang16/TemporalEventStereo/d9dc74677f568dbdf2d1d06da20480c50410e274/models/submodule.pyc -------------------------------------------------------------------------------- /pre-processing/README.md: -------------------------------------------------------------------------------- 1 | 2 | We use pseudo dense labels for stable training of the temporal stereo network. To generate the pseudo labels, we first train a single event stereo model, which has the same backbone as the temporal event stereo network (TESNet) but does not use temporal aggregation. 3 | 4 | ### MVSEC 5 | Please do not forget to cite [MVSEC](https://daniilidis-group.github.io/mvsec/) if you are using the MVSEC dataset. 6 | We provide the pre-processed MVSEC dataset at the following [Link](https://drive.google.com/drive/folders/1ANrz99Z3UwAcTMBlQyz_PYnoRblySh8H?usp=sharing). 7 | 8 | ``` 9 | The MVSEC dataset should have the following format: 10 | ├── MVSEC_dataset 11 | │ ├── indoor_flying_1 12 | │ │ ├── event0 13 | │ │ ├── event1 14 | │ │ ├── pseudo_disp 15 | │ │ ├── disparity_image 16 | │ │ ├── voxel0skip1bin5 17 | │ │ ├── voxel1skip1bin5 18 | │ └── indoor_flying_2 19 | │ ├── event0 20 | │ └── ... 21 | ``` 22 | 23 | ### DSEC 24 | The DSEC dataset can be downloaded [here](https://dsec.ifi.uzh.ch/dsec-datasets/download/). Please do not forget to cite [DSEC](https://github.com/uzh-rpg/DSEC) if you are using the DSEC dataset. 25 | The pseudo labels generated by our work can be downloaded from this [Link](https://drive.google.com/drive/folders/1Hcoarpo40lVnlXQx4CERtW5loyfv6zNf?usp=sharing). 26 | 27 | ``` 28 | The DSEC dataset should have the following format: 29 | ├── DSEC 30 | │ ├── train 31 | │ │ ├── zurich_city_00_a 32 | │ │ │ │ ├── events 33 | │ │ │ │ │ ├── left 34 | │ │ │ │ │ └── right 35 | │ │ │ │ ├── images 36 | │ │ │ │ │ ├── left 37 | │ │ │ │ │ ├── right 38 | │ │ │ │ │ └── timestamps.txt 39 | │ │ │ │ ├── raw_events 40 | │ │ │ │ │ ├── left 41 | │ │ │ │ │ └── right 42 | │ │ │ │ ├── disparity 43 | │ │ │ │ │ ├── event 44 | │ │ │ │ │ └── timestamps.txt 45 | │ │ │ │ └── voxel_50ms_15bin 46 | │ │ │ │ ├── left 47 | │ │ │ │ └── right 48 | │ │ └── ... 49 | │ └── test 50 | │ ├── zurich_city_12_a 51 | │ │ │ │ ├── events 52 | │ │ │ │ │ ├── left 53 | │ │ │ │ │ └── right 54 | │ │ │ │ ├── images 55 | │ │ │ │ │ ├── left 56 | │ │ │ │ │ ├── right 57 | │ │ │ │ │ └── timestamps.txt 58 | │ │ │ │ └── voxel_50ms_15bin 59 | │ │ │ │ ├── left 60 | │ │ │ │ └── right 61 | │ └── ... 62 | └── DSEC_pseudo_GT_all 63 | └── train 64 | ├── zurich_city_00_a 65 | │ ├── 000001.png 66 | │ └── ... 67 | ├── zurich_city_00_b 68 | └── ... 69 | ``` 70 | 71 | We also pre-generate the voxel grids and raw events to avoid repetitive computation in the dataloader. 72 | This can be done using the scripts in the pre-processing directory with the following commands. 73 | The raw events are only used in the training phase, for the contrast maximization loss.
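For reference, after running the commands below each voxel file holds a (15, 480, 640) float32 array (15 temporal bins over a 50 ms window) and each raw-event file an (N, 4) uint32 array with (t, x, y, p) rows. A minimal sketch for sanity-checking the outputs (the sequence and file names are placeholders; substitute any generated sample):

```python
import numpy as np

# Placeholder paths, relative to the DSEC train directory.
voxel = np.load('zurich_city_00_a/voxel_50ms_15bin/left/000002.npy')
raw_events = np.load('zurich_city_00_a/raw_events/left/000002.npy')

print(voxel.shape)       # expected: (15, 480, 640)
print(raw_events.shape)  # expected: (N, 4), columns are t, x, y, p
```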
74 | ```bash 75 | python raw_event_parsing.py --dataset_path $DSEC_TRAIN_DATASET_PATH$ 76 | python voxel_generate_all.py --dataset_path $DSEC_TRAIN_DATASET_PATH$ 77 | python voxel_generate_test_all.py --dataset_path $DSEC_TEST_DATASET_PATH$ 78 | ``` 79 | -------------------------------------------------------------------------------- /pre-processing/eventslicer.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Dict, Tuple 3 | 4 | import h5py 5 | from numba import jit 6 | import numpy as np 7 | 8 | 9 | class EventSlicer: 10 | def __init__(self, h5f: h5py.File): 11 | self.h5f = h5f 12 | 13 | self.events = dict() 14 | for dset_str in ['p', 'x', 'y', 't']: 15 | self.events[dset_str] = self.h5f['events/{}'.format(dset_str)] 16 | 17 | # This is the mapping from milliseconds to event index: 18 | # It is defined such that 19 | # (1) t[ms_to_idx[ms]] >= ms*1000 20 | # (2) t[ms_to_idx[ms] - 1] < ms*1000 21 | # where 'ms' is the time in milliseconds and 't' the event timestamps in microseconds. 22 | # 23 | # As an example, given 't' and 'ms': 24 | # t: 0 500 2100 5000 5000 7100 7200 7200 8100 9000 25 | # ms: 0 1 2 3 4 5 6 7 8 9 26 | # 27 | # we get 28 | # 29 | # ms_to_idx: 30 | # 0 2 2 3 3 3 5 5 8 9 31 | self.ms_to_idx = np.asarray(self.h5f['ms_to_idx'], dtype='int64') 32 | 33 | if "t_offset" in list(h5f.keys()): 34 | self.t_offset = int(h5f['t_offset'][()]) 35 | else: 36 | self.t_offset = 0 37 | self.t_final = int(self.events['t'][-1]) + self.t_offset 38 | 39 | def get_start_time_us(self): 40 | return self.t_offset 41 | 42 | def get_final_time_us(self): 43 | return self.t_final 44 | 45 | def get_events(self, t_start_us: int, t_end_us: int) -> Dict[str, np.ndarray]: 46 | """Get events (p, x, y, t) within the specified time window 47 | Parameters 48 | ---------- 49 | t_start_us: start time in microseconds 50 | t_end_us: end time in microseconds 51 | Returns 52 | ------- 53 | events: dictionary of (p, x, y, t) or None if the time window cannot be retrieved 54 | """ 55 | assert t_start_us < t_end_us 56 | 57 | # The query times include the global time offset, hence subtract it: 58 | t_start_us -= self.t_offset 59 | t_end_us -= self.t_offset 60 | 61 | t_start_ms, t_end_ms = self.get_conservative_window_ms(t_start_us, t_end_us) 62 | t_start_ms_idx = self.ms2idx(t_start_ms) 63 | t_end_ms_idx = self.ms2idx(t_end_ms) 64 | 65 | if t_start_ms_idx is None or t_end_ms_idx is None: 66 | # Cannot guarantee window size anymore 67 | return None 68 | 69 | events = dict() 70 | time_array_conservative = np.asarray(self.events['t'][t_start_ms_idx:t_end_ms_idx]) 71 | idx_start_offset, idx_end_offset = self.get_time_indices_offsets(time_array_conservative, t_start_us, t_end_us) 72 | t_start_us_idx = t_start_ms_idx + idx_start_offset 73 | t_end_us_idx = t_start_ms_idx + idx_end_offset 74 | # Again add t_offset to get gps time 75 | events['t'] = time_array_conservative[idx_start_offset:idx_end_offset] + self.t_offset 76 | for dset_str in ['p', 'x', 'y']: 77 | events[dset_str] = np.asarray(self.events[dset_str][t_start_us_idx:t_end_us_idx]) 78 | assert events[dset_str].size == events['t'].size 79 | return events 80 | 81 | 82 | @staticmethod 83 | def get_conservative_window_ms(ts_start_us: int, ts_end_us) -> Tuple[int, int]: 84 | """Compute a conservative time window with millisecond resolution. 85 | We have a time to index mapping for each millisecond. Hence, we need 86 | to compute the lower and upper millisecond to retrieve events.
87 | Parameters 88 | ---------- 89 | ts_start_us: start time in microseconds 90 | ts_end_us: end time in microseconds 91 | Returns 92 | ------- 93 | window_start_ms: conservative start time in milliseconds 94 | window_end_ms: conservative end time in milliseconds 95 | """ 96 | assert ts_end_us > ts_start_us 97 | window_start_ms = math.floor(ts_start_us/1000) 98 | window_end_ms = math.ceil(ts_end_us/1000) 99 | return window_start_ms, window_end_ms 100 | 101 | @staticmethod 102 | @jit(nopython=True) 103 | def get_time_indices_offsets( 104 | time_array: np.ndarray, 105 | time_start_us: int, 106 | time_end_us: int) -> Tuple[int, int]: 107 | """Compute index offset of start and end timestamps in microseconds 108 | Parameters 109 | ---------- 110 | time_array: timestamps (in us) of the events 111 | time_start_us: start timestamp (in us) 112 | time_end_us: end timestamp (in us) 113 | Returns 114 | ------- 115 | idx_start: Index within this array corresponding to time_start_us 116 | idx_end: Index within this array corresponding to time_end_us 117 | such that (in non-edge cases) 118 | time_array[idx_start] >= time_start_us 119 | time_array[idx_end] >= time_end_us 120 | time_array[idx_start - 1] < time_start_us 121 | time_array[idx_end - 1] < time_end_us 122 | this means that 123 | time_start_us <= time_array[idx_start:idx_end] < time_end_us 124 | """ 125 | 126 | assert time_array.ndim == 1 127 | 128 | idx_start = -1 129 | if time_array[-1] < time_start_us: 130 | # This can happen in extreme corner cases. E.g. 131 | # time_array[0] = 1016 132 | # time_array[-1] = 1984 133 | # time_start_us = 1990 134 | # time_end_us = 2000 135 | 136 | # Return same index twice: array[x:x] is empty. 137 | return time_array.size, time_array.size 138 | else: 139 | for idx_from_start in range(0, time_array.size, 1): 140 | if time_array[idx_from_start] >= time_start_us: 141 | idx_start = idx_from_start 142 | break 143 | assert idx_start >= 0 144 | 145 | idx_end = time_array.size 146 | for idx_from_end in range(time_array.size - 1, -1, -1): 147 | if time_array[idx_from_end] >= time_end_us: 148 | idx_end = idx_from_end 149 | else: 150 | break 151 | 152 | assert time_array[idx_start] >= time_start_us 153 | if idx_end < time_array.size: 154 | assert time_array[idx_end] >= time_end_us 155 | if idx_start > 0: 156 | assert time_array[idx_start - 1] < time_start_us 157 | if idx_end > 0: 158 | assert time_array[idx_end - 1] < time_end_us 159 | return idx_start, idx_end 160 | 161 | def ms2idx(self, time_ms: int) -> int: 162 | assert time_ms >= 0 163 | if time_ms >= self.ms_to_idx.size: 164 | return None 165 | return self.ms_to_idx[time_ms] -------------------------------------------------------------------------------- /pre-processing/raw_event_parsing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pdb 4 | import torch 5 | from pathlib import Path 6 | import hdf5plugin 7 | import h5py 8 | from eventslicer import EventSlicer 9 | import torch 10 | import argparse 11 | 12 | height = 480 13 | width = 640 14 | 15 | def rectify_events(x: np.ndarray, y: np.ndarray, location: str, rectify_ev_maps): 16 | rectify_map = rectify_ev_maps[location] 17 | assert rectify_map.shape == (height, width, 2), rectify_map.shape 18 | assert x.max() < width 19 | assert y.max() < height 20 | return rectify_map[y, x] 21 | 22 | 23 | if __name__=='__main__': 24 | parser = argparse.ArgumentParser(description='Dataset Processing.') 25 | parser.add_argument('--dataset_path', 
help='Path to dataset', required=True) 26 | args = parser.parse_args() 27 | 28 | dataset_dir = args.dataset_path 29 | 30 | folder_list_all = os.listdir(dataset_dir) 31 | folder_list_all.sort() 32 | 33 | event_prefix = 'events' 34 | 35 | locations = ['left', 'right'] 36 | delta_t_ms = 50 37 | 38 | delta_t_us = delta_t_ms * 1000 39 | 40 | save_dir = dataset_dir 41 | 42 | for folder_name in folder_list_all: 43 | print(folder_name) 44 | seq_path = Path(os.path.join(dataset_dir, folder_name)) 45 | 46 | disp_dir = seq_path / 'disparity' 47 | assert disp_dir.is_dir() 48 | 49 | timestamps = np.loadtxt(disp_dir / 'timestamps.txt', dtype='int64') 50 | 51 | # load disparity paths 52 | ev_disp_dir = disp_dir / 'event' 53 | assert ev_disp_dir.is_dir() 54 | disp_gt_pathstrings = list() 55 | for entry in ev_disp_dir.iterdir(): 56 | assert str(entry.name).endswith('.png') 57 | disp_gt_pathstrings.append(str(entry)) 58 | disp_gt_pathstrings.sort() 59 | disp_gt_pathstrings = disp_gt_pathstrings 60 | assert len(disp_gt_pathstrings) == timestamps.size 61 | 62 | event_dir = seq_path / event_prefix 63 | assert event_dir.is_dir() 64 | 65 | 66 | assert int(Path(disp_gt_pathstrings[0]).stem) == 0 67 | disp_gt_pathstrings.pop(0) 68 | timestamps = timestamps[1:] 69 | 70 | h5f = dict() 71 | rectify_ev_maps = dict() 72 | event_slicers = dict() 73 | 74 | ev_dir = seq_path / 'events' 75 | 76 | event_vox_save_dir = os.path.join(save_dir, folder_name, 'raw_events') 77 | if not os.path.exists(event_vox_save_dir): 78 | os.makedirs(event_vox_save_dir) 79 | else: 80 | continue 81 | if not os.path.exists(os.path.join(event_vox_save_dir, 'left')): 82 | os.makedirs(os.path.join(event_vox_save_dir, 'left')) 83 | if not os.path.exists(os.path.join(event_vox_save_dir, 'right')): 84 | os.makedirs(os.path.join(event_vox_save_dir, 'right')) 85 | 86 | 87 | for location in locations: 88 | ev_dir_location = ev_dir / location 89 | ev_data_file = ev_dir_location / 'events.h5' 90 | ev_rect_file = ev_dir_location / 'rectify_map.h5' 91 | 92 | h5f_location = h5py.File(str(ev_data_file), 'r') 93 | h5f[location] = h5f_location 94 | event_slicers[location] = EventSlicer(h5f_location) 95 | with h5py.File(str(ev_rect_file), 'r') as h5_rect: 96 | rectify_ev_maps[location] = h5_rect['rectify_map'][()] 97 | 98 | 99 | 100 | 101 | 102 | 103 | for index in range(len(timestamps)): 104 | ts_end = timestamps[index] 105 | # ts_start should be fine (within the window as we removed the first disparity map) 106 | ts_start = ts_end - delta_t_us 107 | 108 | for location in locations: 109 | event_data = event_slicers[location].get_events(ts_start, ts_end) 110 | 111 | p = event_data['p'] 112 | t = event_data['t'] 113 | 114 | t = (t - t[0]).astype('uint32') 115 | 116 | x = event_data['x'] 117 | y = event_data['y'] 118 | 119 | xy_rect = rectify_events(x, y, location, rectify_ev_maps) 120 | x_rect = xy_rect[:, 0] 121 | y_rect = xy_rect[:, 1] 122 | 123 | event_index = (x_rect >= 0) & (x_rect < width) & (y_rect >= 0) & (y_rect < height) 124 | 125 | x_rect = x_rect[event_index] 126 | y_rect = y_rect[event_index] 127 | t = t[event_index] 128 | p = p[event_index] 129 | 130 | 131 | event_representation = np.stack([t, np.round(x_rect).astype('int'), np.round(y_rect).astype('int'), p], 1).astype('uint32') 132 | 133 | 134 | np.save(event_vox_save_dir + '/' + str(location) + '/' + disp_gt_pathstrings[index].split('/')[-1].replace('.png', '.npy'), event_representation) 135 | 136 | 137 | -------------------------------------------------------------------------------- 
/pre-processing/voxel_generate_all.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pdb 4 | import torch 5 | from pathlib import Path 6 | import hdf5plugin 7 | import h5py 8 | from eventslicer import EventSlicer 9 | import argparse 10 | 11 | num_bins = 15 12 | height = 480 13 | width = 640 14 | 15 | 16 | def events_to_voxel_grid(x, y, p, t, num_bins=num_bins, width=width, height=height): 17 | t = (t - t[0]).astype('float32') 18 | t = (t/t[-1]) 19 | x = x.astype('float32') 20 | y = y.astype('float32') 21 | pol = p.astype('float32') 22 | 23 | x = torch.from_numpy(x) 24 | y = torch.from_numpy(y) 25 | pol = torch.from_numpy(pol) 26 | time = torch.from_numpy(t) 27 | 28 | with torch.no_grad(): 29 | voxel_grid = torch.zeros((num_bins, height, width), dtype=torch.float, requires_grad=False) 30 | C, H, W = voxel_grid.shape 31 | t_norm = time 32 | t_norm = (C - 1) * (t_norm-t_norm[0]) / (t_norm[-1]-t_norm[0]) 33 | 34 | x0 = x.int() 35 | y0 = y.int() 36 | t0 = t_norm.int() 37 | if int(pol.min()) == -1: 38 | value = pol 39 | else: 40 | value = 2*pol-1 41 | # import pdb; pdb.set_trace() 42 | for xlim in [x0,x0+1]: 43 | for ylim in [y0,y0+1]: 44 | for tlim in [t0,t0+1]: 45 | mask = (xlim < W) & (xlim >= 0) & (ylim < H) & (ylim >= 0) & (tlim >= 0) & (tlim < num_bins) 46 | interp_weights = value * (1 - (xlim-x).abs()) * (1 - (ylim-y).abs()) * (1 - (tlim - t_norm).abs()) 47 | index = H * W * tlim.long() + \ 48 | W * ylim.long() + \ 49 | xlim.long() 50 | 51 | voxel_grid.put_(index[mask], interp_weights[mask], accumulate=True) 52 | 53 | mask = torch.nonzero(voxel_grid, as_tuple=True) 54 | if mask[0].size()[0] > 0: 55 | mean = voxel_grid[mask].mean() 56 | std = voxel_grid[mask].std() 57 | if std > 0: 58 | voxel_grid[mask] = (voxel_grid[mask] - mean) / std 59 | else: 60 | voxel_grid[mask] = voxel_grid[mask] - mean 61 | 62 | return voxel_grid 63 | 64 | def rectify_events(x: np.ndarray, y: np.ndarray, location: str, rectify_ev_maps): 65 | rectify_map = rectify_ev_maps[location] 66 | assert rectify_map.shape == (height, width, 2), rectify_map.shape 67 | assert x.max() < width 68 | assert y.max() < height 69 | return rectify_map[y, x] 70 | 71 | 72 | if __name__=='__main__': 73 | parser = argparse.ArgumentParser(description='Dataset Processing.') 74 | parser.add_argument('--dataset_path', help='Path to dataset', required=True) 75 | args = parser.parse_args() 76 | 77 | dataset_dir = args.dataset_path 78 | save_dir = dataset_dir 79 | 80 | folder_list_all = os.listdir(dataset_dir) 81 | folder_list_all.sort() 82 | 83 | event_prefix = 'events' 84 | 85 | locations = ['left', 'right'] 86 | delta_t_ms = 50 87 | delta_t_us = delta_t_ms * 1000 88 | 89 | for folder_name in folder_list_all: 90 | 91 | print(folder_name) 92 | 93 | seq_path = Path(os.path.join(dataset_dir, folder_name)) 94 | 95 | 96 | disp_dir = seq_path / 'disparity' 97 | assert disp_dir.is_dir() 98 | 99 | # timestamps = np.loadtxt(disp_dir / 'timestamps.txt', dtype='int64') 100 | 101 | img_dir = seq_path / 'images' 102 | timestamps = np.loadtxt(img_dir / 'timestamps.txt', dtype='int64') 103 | assert img_dir.is_dir() 104 | 105 | img_left_dir = img_dir / 'left/rectified' 106 | img_pathstrings = list() 107 | for entry in img_left_dir.iterdir(): 108 | assert str(entry.name).endswith('.png') 109 | img_pathstrings.append(str(entry)) 110 | img_pathstrings.sort() 111 | 112 | 113 | # load disparity paths 114 | ev_disp_dir = disp_dir / 'event' 115 | assert ev_disp_dir.is_dir() 116 | 
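# File names are zero-padded, so the lexicographic sort below also puts them in temporal order.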
disp_gt_pathstrings = list() 117 | for entry in ev_disp_dir.iterdir(): 118 | assert str(entry.name).endswith('.png') 119 | disp_gt_pathstrings.append(str(entry)) 120 | disp_gt_pathstrings.sort() 121 | disp_gt_pathstrings = disp_gt_pathstrings 122 | 123 | # assert len(disp_gt_pathstrings) == timestamps.size 124 | 125 | 126 | event_dir = seq_path / event_prefix 127 | assert event_dir.is_dir() 128 | 129 | 130 | assert int(Path(disp_gt_pathstrings[0]).stem) == 0 131 | disp_gt_pathstrings.pop(0) 132 | timestamps = timestamps[1:] 133 | img_pathstrings = img_pathstrings[1:] 134 | 135 | assert len(disp_gt_pathstrings) == (timestamps.size //2) 136 | assert len(img_pathstrings) == timestamps.size 137 | 138 | h5f = dict() 139 | rectify_ev_maps = dict() 140 | event_slicers = dict() 141 | 142 | ev_dir = seq_path / 'events' 143 | for location in locations: 144 | ev_dir_location = ev_dir / location 145 | ev_data_file = ev_dir_location / 'events.h5' 146 | ev_rect_file = ev_dir_location / 'rectify_map.h5' 147 | 148 | h5f_location = h5py.File(str(ev_data_file), 'r') 149 | h5f[location] = h5f_location 150 | event_slicers[location] = EventSlicer(h5f_location) 151 | with h5py.File(str(ev_rect_file), 'r') as h5_rect: 152 | rectify_ev_maps[location] = h5_rect['rectify_map'][()] 153 | 154 | 155 | 156 | event_vox_save_dir = os.path.join(save_dir, folder_name, 'voxel_50ms_15bin') 157 | if not os.path.exists(event_vox_save_dir): 158 | os.makedirs(event_vox_save_dir) 159 | if not os.path.exists(os.path.join(event_vox_save_dir, 'left')): 160 | os.makedirs(os.path.join(event_vox_save_dir, 'left')) 161 | if not os.path.exists(os.path.join(event_vox_save_dir, 'right')): 162 | os.makedirs(os.path.join(event_vox_save_dir, 'right')) 163 | 164 | 165 | for index in range(len(timestamps)): 166 | ts_end = timestamps[index] 167 | # ts_start should be fine (within the window as we removed the first disparity map) 168 | ts_start = ts_end - delta_t_us 169 | 170 | print(ts_end, event_vox_save_dir + '/' + str(location) + '/' + img_pathstrings[index].split('/')[-1].replace('.png', '.npy')) 171 | 172 | for location in locations: 173 | event_data = event_slicers[location].get_events(ts_start, ts_end) 174 | 175 | p = event_data['p'] 176 | t = event_data['t'] 177 | x = event_data['x'] 178 | y = event_data['y'] 179 | 180 | xy_rect = rectify_events(x, y, location, rectify_ev_maps) 181 | x_rect = xy_rect[:, 0] 182 | y_rect = xy_rect[:, 1] 183 | 184 | event_representation = events_to_voxel_grid(x_rect, y_rect, p, t) 185 | np.save(event_vox_save_dir + '/' + str(location) + '/' + img_pathstrings[index].split('/')[-1].replace('.png', '.npy') 186 | ,event_representation) 187 | 188 | -------------------------------------------------------------------------------- /pre-processing/voxel_generate_test_all.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pdb 4 | import torch 5 | from pathlib import Path 6 | import hdf5plugin 7 | import h5py 8 | from eventslicer import EventSlicer 9 | import argparse 10 | 11 | num_bins = 15 12 | height = 480 13 | width = 640 14 | 15 | 16 | def events_to_voxel_grid(x, y, p, t, num_bins=num_bins, width=width, height=height): 17 | t = (t - t[0]).astype('float32') 18 | t = (t/t[-1]) 19 | x = x.astype('float32') 20 | y = y.astype('float32') 21 | pol = p.astype('float32') 22 | 23 | x = torch.from_numpy(x) 24 | y = torch.from_numpy(y) 25 | pol = torch.from_numpy(pol) 26 | time = torch.from_numpy(t) 27 | 28 | with torch.no_grad(): 29 | 
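# Same trilinear splatting as VoxelGrid in dataloader/representations.py: each event is spread over its two nearest voxels in x, y, and normalized time, and the non-zero voxels are standardized (zero mean, unit std) at the end.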
voxel_grid = torch.zeros((num_bins, height, width), dtype=torch.float, requires_grad=False) 30 | C, H, W = voxel_grid.shape 31 | t_norm = time 32 | t_norm = (C - 1) * (t_norm-t_norm[0]) / (t_norm[-1]-t_norm[0]) 33 | 34 | x0 = x.int() 35 | y0 = y.int() 36 | t0 = t_norm.int() 37 | if int(pol.min()) == -1: 38 | value = pol 39 | else: 40 | value = 2*pol-1 41 | # import pdb; pdb.set_trace() 42 | for xlim in [x0,x0+1]: 43 | for ylim in [y0,y0+1]: 44 | for tlim in [t0,t0+1]: 45 | mask = (xlim < W) & (xlim >= 0) & (ylim < H) & (ylim >= 0) & (tlim >= 0) & (tlim < num_bins) 46 | interp_weights = value * (1 - (xlim-x).abs()) * (1 - (ylim-y).abs()) * (1 - (tlim - t_norm).abs()) 47 | index = H * W * tlim.long() + \ 48 | W * ylim.long() + \ 49 | xlim.long() 50 | 51 | voxel_grid.put_(index[mask], interp_weights[mask], accumulate=True) 52 | 53 | mask = torch.nonzero(voxel_grid, as_tuple=True) 54 | if mask[0].size()[0] > 0: 55 | mean = voxel_grid[mask].mean() 56 | std = voxel_grid[mask].std() 57 | if std > 0: 58 | voxel_grid[mask] = (voxel_grid[mask] - mean) / std 59 | else: 60 | voxel_grid[mask] = voxel_grid[mask] - mean 61 | 62 | return voxel_grid 63 | 64 | def rectify_events(x: np.ndarray, y: np.ndarray, location: str, rectify_ev_maps): 65 | rectify_map = rectify_ev_maps[location] 66 | assert rectify_map.shape == (height, width, 2), rectify_map.shape 67 | assert x.max() < width 68 | assert y.max() < height 69 | return rectify_map[y, x] 70 | 71 | 72 | if __name__=='__main__': 73 | parser = argparse.ArgumentParser(description='Dataset Processing.') 74 | parser.add_argument('--dataset_path', help='Path to dataset', required=True) 75 | args = parser.parse_args() 76 | 77 | dataset_dir = args.dataset_path 78 | save_dir = dataset_dir 79 | 80 | folder_list_all = os.listdir(dataset_dir) 81 | folder_list_all.sort() 82 | 83 | event_prefix = 'events' 84 | 85 | locations = ['left', 'right'] 86 | delta_t_ms = 50 87 | delta_t_us = delta_t_ms * 1000 88 | 89 | for folder_name in folder_list_all: 90 | print(folder_name) 91 | 92 | seq_path = Path(os.path.join(dataset_dir, folder_name)) 93 | 94 | 95 | img_dir = seq_path / 'images' 96 | assert img_dir.is_dir() 97 | 98 | timestamps = np.loadtxt(img_dir / 'timestamps.txt', dtype='int64') 99 | 100 | 101 | img_left_dir = img_dir / 'left/rectified' 102 | img_pathstrings = list() 103 | for entry in img_left_dir.iterdir(): 104 | assert str(entry.name).endswith('.png') 105 | img_pathstrings.append(str(entry)) 106 | img_pathstrings.sort() 107 | 108 | 109 | # timestamps = timestamps[::2] 110 | timestamps = timestamps[1:] 111 | img_pathstrings = img_pathstrings[1:] 112 | assert len(img_pathstrings) == timestamps.size 113 | 114 | event_dir = seq_path / event_prefix 115 | assert event_dir.is_dir() 116 | 117 | h5f = dict() 118 | rectify_ev_maps = dict() 119 | event_slicers = dict() 120 | 121 | ev_dir = seq_path / 'events' 122 | for location in locations: 123 | ev_dir_location = ev_dir / location 124 | ev_data_file = ev_dir_location / 'events.h5' 125 | ev_rect_file = ev_dir_location / 'rectify_map.h5' 126 | 127 | h5f_location = h5py.File(str(ev_data_file), 'r') 128 | h5f[location] = h5f_location 129 | event_slicers[location] = EventSlicer(h5f_location) 130 | with h5py.File(str(ev_rect_file), 'r') as h5_rect: 131 | rectify_ev_maps[location] = h5_rect['rectify_map'][()] 132 | 133 | 134 | event_vox_save_dir = os.path.join(save_dir, folder_name, 'voxel_50ms_15bin') 135 | if not os.path.exists(event_vox_save_dir): 136 | os.makedirs(event_vox_save_dir) 137 | if not 
os.path.exists(os.path.join(event_vox_save_dir, 'left')): 138 | os.makedirs(os.path.join(event_vox_save_dir, 'left')) 139 | if not os.path.exists(os.path.join(event_vox_save_dir, 'right')): 140 | os.makedirs(os.path.join(event_vox_save_dir, 'right')) 141 | 142 | 143 | for index in range(len(timestamps)): 144 | # print(str(2 * index + 2).zfill(6)) 145 | ts_end = timestamps[index] 146 | # ts_start should be fine (within the window as we removed the first disparity map) 147 | ts_start = ts_end - delta_t_us 148 | 149 | print(ts_end, event_vox_save_dir + '/' + str(location) + '/' + 150 | img_pathstrings[index].split('/')[-1].replace('.png', '.npy')) 151 | 152 | 153 | for location in locations: 154 | event_data = event_slicers[location].get_events(ts_start, ts_end) 155 | 156 | p = event_data['p'] 157 | t = event_data['t'] 158 | x = event_data['x'] 159 | y = event_data['y'] 160 | 161 | xy_rect = rectify_events(x, y, location, rectify_ev_maps) 162 | x_rect = xy_rect[:, 0] 163 | y_rect = xy_rect[:, 1] 164 | 165 | event_representation = events_to_voxel_grid(x_rect, y_rect, p, t) 166 | np.save(event_vox_save_dir + '/' + str(location) + '/' + img_pathstrings[index].split('/')[-1].replace('.png', '.npy') 167 | , event_representation) 168 | 169 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | timm==0.4.12 2 | omegaconf 3 | opencv-python 4 | opt-einsum 5 | pafy 6 | pypng 7 | PyQt5 8 | pytorch-lightning==1.5.2 9 | PyYAML 10 | seaborn 11 | scikit-image 12 | scikit-learn 13 | tensorboard 14 | thop 15 | tqdm 16 | typing-extensions 17 | yacs 18 | debugpy 19 | pillow==9.0.1 20 | einops 21 | h5py 22 | numba 23 | hdf5plugin 24 | setuptools==59.5.0 25 | flow_vis -------------------------------------------------------------------------------- /resource/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mickeykang16/TemporalEventStereo/d9dc74677f568dbdf2d1d06da20480c50410e274/resource/teaser.png -------------------------------------------------------------------------------- /resource/temporal_stereo_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mickeykang16/TemporalEventStereo/d9dc74677f568dbdf2d1d06da20480c50410e274/resource/temporal_stereo_demo.gif -------------------------------------------------------------------------------- /run_tmux.sh: -------------------------------------------------------------------------------- 1 | tmux new-session -s gpu$GPUNUM 'tmux source-file ./.tmux.conf; $SHELL' 2 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import Logger 2 | from .visualization import disp_err_to_color 3 | from .viz import save_matrix 4 | from .softsplat import softsplat 5 | from .warp import flow_warp, disp_warp 6 | from .flow_vis import flow_to_color -------------------------------------------------------------------------------- /utils/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .pixel_error import calc_error 2 | from .eval import do_evaluation, do_occlusion_evaluation, do_evaluation_test 3 | from .flow_pixel_error import flow_calc_error 4 | from .flow_eval import do_flow_evaluation 5 | from 
.inverse_warp import inverse_warp -------------------------------------------------------------------------------- /utils/evaluation/eval.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | 5 | from .inverse_warp import inverse_warp 6 | from .pixel_error import calc_error, calc_error_test 7 | 8 | 9 | def do_evaluation(est_disp, gt_disp, lb, ub): 10 | """ 11 | Do pixel error evaluation. (See KITTI evaluation protocols for details.) 12 | Args: 13 | est_disp: (Tensor), estimated disparity map, 14 | [..., Height, Width] layout 15 | gt_disp: (Tensor), ground truth disparity map 16 | [..., Height, Width] layout 17 | lb: (scalar), the lower bound of disparity you want to mask out 18 | ub: (scalar), the upper bound of disparity you want to mask out 19 | Returns: 20 | error_dict: (dict), the error of 1px, 2px, 3px, 5px, in percent, 21 | range [0,100] and average error epe 22 | """ 23 | error_dict = {} 24 | if est_disp is None: 25 | warnings.warn('Estimated disparity map is None') 26 | return error_dict 27 | if gt_disp is None: 28 | warnings.warn('Reference ground truth disparity map is None') 29 | return error_dict 30 | 31 | if torch.is_tensor(est_disp): 32 | est_disp = est_disp.clone().cpu() 33 | 34 | if torch.is_tensor(gt_disp): 35 | gt_disp = gt_disp.clone().cpu() 36 | 37 | assert est_disp.shape == gt_disp.shape, "Estimated Disparity map with shape: {}, but GroundTruth Disparity map" \ 38 | " with shape: {}".format(est_disp.shape, gt_disp.shape) 39 | 40 | error_dict = calc_error(est_disp, gt_disp, lb=lb, ub=ub) 41 | 42 | return error_dict 43 | 44 | 45 | 46 | def do_evaluation_test(est_disp, gt_disp, lb, ub): 47 | """ 48 | Do pixel error evaluation. (See KITTI evaluation protocols for details.) 49 | Args: 50 | est_disp: (Tensor), estimated disparity map, 51 | [..., Height, Width] layout 52 | gt_disp: (Tensor), ground truth disparity map 53 | [..., Height, Width] layout 54 | lb: (scalar), the lower bound of disparity you want to mask out 55 | ub: (scalar), the upper bound of disparity you want to mask out 56 | Returns: 57 | error_dict: (dict), the error of 1px, 2px, 3px, 5px, in percent, 58 | range [0,100] and average error epe 59 | """ 60 | error_dict = {} 61 | if est_disp is None: 62 | warnings.warn('Estimated disparity map is None') 63 | return error_dict 64 | if gt_disp is None: 65 | warnings.warn('Reference ground truth disparity map is None') 66 | return error_dict 67 | 68 | if torch.is_tensor(est_disp): 69 | est_disp = est_disp.clone().cpu() 70 | 71 | if torch.is_tensor(gt_disp): 72 | gt_disp = gt_disp.clone().cpu() 73 | 74 | assert est_disp.shape == gt_disp.shape, "Estimated Disparity map with shape: {}, but GroundTruth Disparity map" \ 75 | " with shape: {}".format(est_disp.shape, gt_disp.shape) 76 | 77 | error_dict = calc_error_test(est_disp, gt_disp, lb=lb, ub=ub) 78 | 79 | return error_dict 80 | 81 | 82 | def do_occlusion_evaluation(est_disp, ref_gt_disp, target_gt_disp, lb, ub): 83 | """ 84 | Do occlusion evaluation.
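A pixel is counted as occluded when the target ground-truth disparity, warped into the reference view, differs from the reference ground truth by more than 1 px or warps to an (almost) zero value; errors are then reported separately with 'occ_' (occluded) and 'noc_' (non-occluded) key prefixes.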
85 | Args: 86 | est_disp: (Tensor), estimated disparity map 87 | [BatchSize, 1, Height, Width] layout 88 | ref_gt_disp: (Tensor), reference(left) ground truth disparity map 89 | [BatchSize, 1, Height, Width] layout 90 | target_gt_disp: (Tensor), target(right) ground truth disparity map, 91 | [BatchSize, 1, Height, Width] layout 92 | lb: (scalar): the lower bound of disparity you want to mask out 93 | ub: (scalar): the upper bound of disparity you want to mask out 94 | Returns: 95 | """ 96 | error_dict = {} 97 | if est_disp is None: 98 | warnings.warn('Estimated disparity map is None, expected given') 99 | return error_dict 100 | if ref_gt_disp is None: 101 | warnings.warn('Reference ground truth disparity map is None, expected given') 102 | return error_dict 103 | if target_gt_disp is None: 104 | warnings.warn('Target ground truth disparity map is None, expected given') 105 | return error_dict 106 | 107 | if torch.is_tensor(est_disp): 108 | est_disp = est_disp.clone().cpu() 109 | if torch.is_tensor(ref_gt_disp): 110 | ref_gt_disp = ref_gt_disp.clone().cpu() 111 | if torch.is_tensor(target_gt_disp): 112 | target_gt_disp = target_gt_disp.clone().cpu() 113 | 114 | assert est_disp.shape == ref_gt_disp.shape and target_gt_disp.shape == ref_gt_disp.shape, "{}, {}, {}".format( 115 | est_disp.shape, ref_gt_disp.shape, target_gt_disp.shape) 116 | 117 | warp_ref_gt_disp = inverse_warp(target_gt_disp.clone(), -ref_gt_disp.clone(), mode='disparity') 118 | theta = 1.0 119 | eps = 1e-6 120 | occlusion = ( 121 | (torch.abs(warp_ref_gt_disp.clone() - ref_gt_disp.clone()) > theta) | 122 | (torch.abs(warp_ref_gt_disp.clone()) < eps) 123 | ).prod(dim=1, keepdim=True).type_as(ref_gt_disp) 124 | occlusion = occlusion.clamp(0, 1) 125 | 126 | occlusion_error_dict = calc_error( 127 | est_disp.clone() * occlusion.clone(), 128 | ref_gt_disp.clone() * occlusion.clone(), 129 | lb=lb, ub=ub 130 | ) 131 | for key in occlusion_error_dict.keys(): 132 | error_dict['occ_' + key] = occlusion_error_dict[key] 133 | 134 | not_occlusion = 1.0 - occlusion 135 | not_occlusion_error_dict = calc_error( 136 | est_disp.clone() * not_occlusion.clone(), 137 | ref_gt_disp.clone() * not_occlusion.clone(), 138 | lb=lb, ub=ub 139 | ) 140 | for key in not_occlusion_error_dict.keys(): 141 | error_dict['noc_' + key] = not_occlusion_error_dict[key] 142 | 143 | return error_dict -------------------------------------------------------------------------------- /utils/evaluation/flow_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import warnings 3 | 4 | from .flow_pixel_error import flow_calc_error 5 | 6 | def do_flow_evaluation(est_flow, gt_flow, lb=0.0, ub=400, sparse=False): 7 | """ 8 | Do pixel error evaluation. (See KITTI evaluation protocols for details.) 
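Here 'epe' is the mean Euclidean end-point error between estimated and ground-truth flow vectors over valid pixels, and each 'Npx' entry is the percentage of valid pixels whose end-point error exceeds N pixels.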
9 | Args: 10 | est_flow: (Tensor), estimated flow map 11 | [..., 2, Height, Width] layout 12 | gt_flow: (Tensor), ground truth flow map 13 | [..., 2, Height, Width] layout 14 | lb: (scalar), the lower bound of flow magnitude you want to mask out 15 | ub: (scalar), the upper bound of flow magnitude you want to mask out 16 | sparse: (bool), whether the given flow is sparse, default False 17 | Returns: 18 | error_dict (dict): the error of 1px, 2px, 3px, 5px, in percent, 19 | range [0,100] and average error epe 20 | """ 21 | error_dict = {} 22 | if est_flow is None: 23 | warnings.warn('Estimated flow map is None') 24 | return error_dict 25 | if gt_flow is None: 26 | warnings.warn('Reference ground truth flow map is None') 27 | return error_dict 28 | 29 | if torch.is_tensor(est_flow): 30 | est_flow = est_flow.clone().cpu() 31 | 32 | if torch.is_tensor(gt_flow): 33 | gt_flow = gt_flow.clone().cpu() 34 | 35 | error_dict = flow_calc_error(est_flow, gt_flow, sparse=sparse) 36 | 37 | return error_dict 38 | -------------------------------------------------------------------------------- /utils/evaluation/flow_pixel_error.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | 5 | def zero_mask(input, eps=1e-12): 6 | mask = abs(input) < eps 7 | return mask 8 | 9 | def flow_calc_error(est_flow=None, gt_flow=None, lb=0.0, ub=400, sparse=False): 10 | """ 11 | Args: 12 | est_flow: (Tensor), estimated flow map 13 | [..., 2, Height, Width] layout 14 | gt_flow: (Tensor), ground truth flow map 15 | [..., 2, Height, Width] layout 16 | lb: (scalar), the lower bound of flow magnitude you want to mask out 17 | ub: (scalar), the upper bound of flow magnitude you want to mask out 18 | sparse: (bool), whether the given flow is sparse, default False 19 | Output: 20 | dict: the error of 1px, 2px, 3px, 5px, in percent, 21 | range [0,100] and average error epe 22 | """ 23 | error1 = torch.Tensor([0.]) 24 | error2 = torch.Tensor([0.]) 25 | error3 = torch.Tensor([0.]) 26 | error5 = torch.Tensor([0.]) 27 | epe = torch.Tensor([0.]) 28 | 29 | if (not torch.is_tensor(est_flow)) or (not torch.is_tensor(gt_flow)): 30 | return { 31 | '1px': error1 * 100, 32 | '2px': error2 * 100, 33 | '3px': error3 * 100, 34 | '5px': error5 * 100, 35 | 'epe': epe 36 | } 37 | 38 | assert torch.is_tensor(est_flow) and torch.is_tensor(gt_flow) 39 | assert est_flow.shape == gt_flow.shape 40 | 41 | est_flow = est_flow.clone().cpu() 42 | gt_flow = gt_flow.clone().cpu() 43 | if len(gt_flow.shape) == 3: 44 | gt_flow = gt_flow.unsqueeze(0) 45 | est_flow = est_flow.unsqueeze(0) 46 | 47 | assert gt_flow.shape[1] == 2, "flow should have horizontal and vertical dimension, " \ 48 | "but got {}".format(gt_flow.shape[1]) 49 | 50 | # [B, 1, H, W] 51 | gt_u, gt_v = gt_flow[:, 0:1, :, :], gt_flow[:, 1:2, :, :] 52 | est_u, est_v = est_flow[:, 0:1, :, :], est_flow[:, 1:2, :, :] 53 | 54 | # get valid mask 55 | # [B, 1, H, W] 56 | mask = torch.ones(gt_u.shape, dtype=torch.bool, device=gt_u.device) 57 | if sparse: 58 | mask = mask & (~(zero_mask(gt_u) & zero_mask(gt_v))) 59 | mask = mask & (~(torch.isnan(gt_u) | torch.isnan(gt_v))) 60 | 61 | rad = torch.sqrt(gt_u**2 + gt_v**2) 62 | mask = mask & (rad > lb) & (rad < ub) 63 | 64 | mask.detach_() 65 | if abs(mask.float().sum()) < 1.0: 66 | return { 67 | '1px': error1 * 100, 68 | '2px': error2 * 100, 69 | '3px': error3 * 100, 70 | '5px': error5 * 100, 71 | 'epe': epe 72 | } 73 | 74 | 75 | 76 | gt_u = gt_u[mask] 77 | gt_v = gt_v[mask] 78 | est_u = est_u[mask] 79 |
est_v = est_v[mask] 80 | 81 | abs_error = torch.sqrt((gt_u - est_u)**2 + (gt_v - est_v)**2) 82 | total_num = mask.float().sum() 83 | 84 | error1 = torch.sum(torch.gt(abs_error, 1).float()) / total_num 85 | error2 = torch.sum(torch.gt(abs_error, 2).float()) / total_num 86 | error3 = torch.sum(torch.gt(abs_error, 3).float()) / total_num 87 | error5 = torch.sum(torch.gt(abs_error, 5).float()) / total_num 88 | epe = abs_error.float().mean() 89 | 90 | return { 91 | '1px': error1 * 100, 92 | '2px': error2 * 100, 93 | '3px': error3 * 100, 94 | '5px': error5 * 100, 95 | 'epe': epe 96 | } -------------------------------------------------------------------------------- /utils/evaluation/inverse_warp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def inverse_warp(img: torch.Tensor, 7 | motion: torch.Tensor, 8 | mode: str ='disparity', 9 | K: torch.Tensor = None, 10 | inv_K: torch.Tensor = None, 11 | T_target_to_source: torch.Tensor = None, 12 | interpolate_mode: str = 'bilinear', 13 | padding_mode: str = 'zeros', 14 | eps: float = 1e-7, 15 | output_all: bool = False): 16 | """ 17 | sample pixel values from the source image and project them into the target image space 18 | Args: 19 | img: (Tensor): the source image (where to sample pixels) 20 | [BatchSize, C, Height, Width] 21 | motion: (Tensor): disparity/depth/flow map of the target image 22 | [BatchSize, 1 or 2, Height, Width] 23 | mode: (str): which kind of warp to perform, one of ['disparity', 'depth', 'flow'] 24 | K: (Optional, Tensor): intrinsics of the camera 25 | [BatchSize, 3, 3] or [BatchSize, 4, 4] 26 | inv_K: (Optional, Tensor): inverse intrinsics of the camera 27 | [BatchSize, 3, 3] or [BatchSize, 4, 4] 28 | T_target_to_source: (Optional, Tensor): predicted transformation matrix from target image to source image frames 29 | [BatchSize, 4, 4] 30 | interpolate_mode: (str): interpolation mode for grid sampling, default is bilinear 31 | padding_mode: (str): padding mode for grid sampling, default is zero padding 32 | eps: (float): epsilon value to avoid division by zero, default 1e-7 33 | output_all: (bool): whether to output all intermediate results of the warp, default False 34 | Returns: 35 | projected_img: (Tensor): source image warped to the target image 36 | [BatchSize, C, Height, Width] 37 | output: (Optional, Dict): such as optical flow, flow mask, triangular depth, src_pixel_coord and so on 38 | """ 39 | B, C, H, W = motion.shape 40 | device = motion.device 41 | dtype = motion.dtype 42 | output = {} 43 | 44 | if mode == 'disparity': 45 | assert C == 1, "Disparity map must be 1 channel, but {} got!".format(C) 46 | # [B, 2, H, W] 47 | pixel_coord = mesh_grid(B, H, W, device, dtype) 48 | X = pixel_coord[:, 0, :, :] + motion[:, 0] 49 | Y = pixel_coord[:, 1, :, :] 50 | elif mode == 'flow': 51 | assert C == 2, "Optical flow map must be 2 channel, but {} got!".format(C) 52 | # [B, 2, H, W] 53 | pixel_coord = mesh_grid(B, H, W, device, dtype) 54 | X = pixel_coord[:, 0, :, :] + motion[:, 0] 55 | Y = pixel_coord[:, 1, :, :] + motion[:, 1] 56 | elif mode == 'depth': 57 | assert C == 1, "Depth map must be 1 channel, but {} got!".format(C) 58 | outs = project_to_3d(motion, K, inv_K, T_target_to_source, eps) 59 | output.update(outs) 60 | src_pixel_coord = outs['src_pixel_coord'] 61 | X = src_pixel_coord[:, 0, :, :] 62 | Y = src_pixel_coord[:, 1, :, :] 63 | 64 | else: 65 | raise TypeError("Inverse warp only supports [disparity, flow, depth] mode, but {}
got".format(mode)) 66 | 67 | X_norm = 2 * X / (W-1) - 1 68 | Y_norm = 2 * Y / (H-1) - 1 69 | # [B, H, W, 2] 70 | pixel_coord_norm = torch.stack((X_norm, Y_norm), dim=3) 71 | 72 | projected_img = F.grid_sample(img, pixel_coord_norm, mode=interpolate_mode, padding_mode=padding_mode, align_corners=True) 73 | 74 | if output_all: 75 | return projected_img, output 76 | else: 77 | return projected_img 78 | 79 | 80 | def mesh_grid(b, h, w, device, dtype=torch.float): 81 | """ construct pixel coordination in an image""" 82 | # [1, H, W] copy 0-width for h times : x coord 83 | x_range = torch.arange(0, w, device=device, dtype=dtype).view(1, 1, 1, w).expand(b, 1, h, w) 84 | # [1, H, W] copy 0-height for w times : y coord 85 | y_range = torch.arange(0, h, device=device, dtype=dtype).view(1, 1, h, 1).expand(b, 1, h, w) 86 | 87 | # [b, 2, h, w] 88 | pixel_coord = torch.cat((x_range, y_range), dim=1) 89 | 90 | return pixel_coord 91 | 92 | def project_to_3d(depth, K, inv_K=None, T_target_to_source:torch.Tensor=None, eps=1e-7): 93 | """ 94 | project depth map to 3D space 95 | Args: 96 | depth: (Tensor): depth map(s), can be several depth maps concatenated at channel dimension 97 | [BatchSize, Channel, Height, Width] 98 | K: (Tensor): instrincs of camera 99 | [BatchSize, 3, 3] or [BatchSize, 4, 4] 100 | inv_K: (Optional, Tensor): invserse instrincs of camera 101 | [BatchSize, 3, 3] or [BatchSize, 4, 4] 102 | T_target_to_source: (Optional, Tensor): predicted transformation matrix from target image to source image frames 103 | [BatchSize, 4, 4] 104 | eps: (float): eplison value to avoid divide 0, default 1e-7 105 | 106 | Returns: Dict including 107 | homo_points_3d: (Tensor): the homogeneous points after depth project to 3D space, [x, y, z, 1] 108 | [BatchSize, 4, Channel*Height*Width] 109 | if T_target_to_source provided: 110 | 111 | triangular_depth: (Tensor): the depth map after the 3D points project to source camera 112 | [BatchSize, Channel, Height, Width] 113 | optical_flow: (Tensor): by 3D projection, the rigid flow can be got 114 | [BatchSize, Channel*2, Height, Width], to get the 2nd flow, index like [:, 2:4, :, :] 115 | flow_mask: (Tensor): the mask indicates which pixel's optical flow is valid 116 | [BatchSize, Channel, Height, Width] 117 | """ 118 | 119 | # support C >=1, for C > 1, it means several depth maps are concatenated at channel dimension 120 | B, C, H, W = depth.size() 121 | device = depth.device 122 | dtype = depth.dtype 123 | output = {} 124 | 125 | # [B, 2, H, W] 126 | pixel_coord = mesh_grid(B, H, W, device, dtype) 127 | ones = torch.ones(B, 1, H, W, device=device, dtype=dtype) 128 | # [B, 3, H, W], homogeneous coordination of image pixel, [x, y, 1] 129 | homo_pixel_coord = torch.cat((pixel_coord, ones), dim=1).contiguous() 130 | 131 | # [B, 3, H*W] -> [B, 3, C*H*W] 132 | homo_pixel_coord = homo_pixel_coord.view(B, 3, -1).repeat(1, 1, C).contiguous() 133 | # [B, C*H*W] -> [B, 1, C*H*W] 134 | depth = depth.view(B, -1).unsqueeze(dim=1).contiguous() 135 | if inv_K is None: 136 | inv_K = torch.inverse(K[:, :3, :3]) 137 | # [B, 3, C*H*W] 138 | points_3d = torch.matmul(inv_K[:, :3, :3], homo_pixel_coord) * depth 139 | ones = torch.ones(B, 1, C*H*W, device=device, dtype=dtype) 140 | # [B, 4, C*H*W], homogeneous coordiate, [x, y, z, 1] 141 | homo_points_3d = torch.cat((points_3d, ones), dim=1) 142 | output['homo_points_3d'] = homo_points_3d 143 | 144 | if T_target_to_source is not None: 145 | if K.shape[-1] == 3: 146 | new_K = torch.eye(4, device=device, 
-------------------------------------------------------------------------------- /utils/evaluation/pixel_error.py: --------------------------------------------------------------------------------
import numpy as np

import torch

FOCAL_LENGTH_X_BASELINE = {
    'indoor_flying': 19.941772,
    'outdoor_night': 19.651191,
    'outdoor_day': 19.635287
}


def disparity_to_depth(disparity_image):
    # NOTE: the focal_length_x * baseline constant is hard-coded to the
    # 'indoor_flying' MVSEC sequence here; switch the key for other sequences.
    unknown_disparity = disparity_image == float('inf')
    depth_image = \
        FOCAL_LENGTH_X_BASELINE['indoor_flying'] / (
            disparity_image + 1e-7)
    depth_image[unknown_disparity] = float('inf')
    return depth_image
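
# Worked example (illustrative, not part of the repository): with
# focal_length_x * baseline = 19.941772, depth = 19.941772 / disparity, so a
# 2-pixel disparity maps to roughly 9.97 m, and unknown (inf) disparities stay
# unknown in depth.
if __name__ == "__main__":
    disp = torch.tensor([2.0, 20.0, float('inf')])
    print(disparity_to_depth(disp))  # ~tensor([9.9709, 0.9971, inf])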
def calc_error(est_disp=None, gt_disp=None, lb=None, ub=None):
    """
    Args:
        est_disp (Tensor): in [..., Height, Width] layout
        gt_disp (Tensor): in [..., Height, Width] layout
        lb (scalar): the lower bound of disparity you want to mask out
        ub (scalar): the upper bound of disparity you want to mask out
    Output:
        dict: the error of 1px, 2px, 3px, 5px, in percent,
            range [0,100] and average error epe
    """
    error1 = torch.Tensor([0.])
    error2 = torch.Tensor([0.])
    error3 = torch.Tensor([0.])
    error5 = torch.Tensor([0.])
    epe = torch.Tensor([0.])

    if (not torch.is_tensor(est_disp)) or (not torch.is_tensor(gt_disp)):
        return {
            '1px': error1 * 100,
            '2px': error2 * 100,
            '3px': error3 * 100,
            '5px': error5 * 100,
            'epe': epe
        }

    assert est_disp.shape == gt_disp.shape

    est_disp = est_disp.clone().cpu()
    gt_disp = gt_disp.clone().cpu()

    mask = torch.ones(gt_disp.shape, dtype=torch.bool, device=gt_disp.device)
    if lb is not None:
        mask = mask & (gt_disp >= lb)
    if ub is not None:
        mask = mask & (gt_disp <= ub)
        # mask = mask & (est_disp <= ub)

    mask.detach_()
    if abs(mask.float().sum()) < 1.0:
        return {
            '1px': error1 * 100,
            '2px': error2 * 100,
            '3px': error3 * 100,
            '5px': error5 * 100,
            'epe': epe
        }
    ## original error compute
    gt_disp = gt_disp[mask]
    est_disp = est_disp[mask]
    abs_error = torch.abs(gt_disp - est_disp)

    ## hh error compute for top k ###################
    # abs_error = torch.abs(gt_disp - est_disp) * mask
    # pix_sum = torch.gt(abs_error, 1)
    # # error_sum = abs_error.sum(1).sum(1)
    # error_sum = pix_sum.sum(1).sum(1)
    # mask_sum = mask.sum(1).sum(1)

    # per_img_error = error_sum / mask_sum

    # values, indexs = torch.topk(-per_img_error, 1343)
    # gt_disp = gt_disp[indexs, :, :]
    # est_disp = est_disp[indexs, :, :]
    # mask = mask[indexs, :, :]
    # abs_error = torch.abs(gt_disp - est_disp) * mask
    ###############################

    total_num = mask.float().sum()

    error1 = torch.sum(torch.gt(abs_error, 1).float()) / total_num
    error2 = torch.sum(torch.gt(abs_error, 2).float()) / total_num
    error3 = torch.sum(torch.gt(abs_error, 3).float()) / total_num
    error5 = torch.sum(torch.gt(abs_error, 5).float()) / total_num
    epe = abs_error.float().mean()

    # .mean() gives a tensor of size torch.Size([]); wrapping in torch.Tensor makes it torch.Size([1])
    return {
        '1px': torch.Tensor([error1 * 100]),
        '2px': torch.Tensor([error2 * 100]),
        '3px': torch.Tensor([error3 * 100]),
        '5px': torch.Tensor([error5 * 100]),
        'epe': torch.Tensor([epe]),
    }


def calc_error_test(est_disp=None, gt_disp=None, lb=None, ub=None):
    """
    Args:
        est_disp (Tensor): in [..., Height, Width] layout
        gt_disp (Tensor): in [..., Height, Width] layout
        lb (scalar): the lower bound of disparity you want to mask out
        ub (scalar): the upper bound of disparity you want to mask out
    Output:
        dict: the 1-pixel error in percent (range [0,100]), the mean and
            median absolute depth errors, and the mean disparity error epe
    """
    error1 = torch.Tensor([0.])
    epe = torch.Tensor([0.])

    if (not torch.is_tensor(est_disp)) or (not torch.is_tensor(gt_disp)):
        return {
            '1px': error1 * 100,
            'epe': epe
        }

    assert est_disp.shape == gt_disp.shape

    est_disp = est_disp.clone().cpu()
    gt_disp = gt_disp.clone().cpu()

    # ground truth disparities above the upper bound or NaN are treated as unknown
    gt_mask = gt_disp > ub
    nan_gt_mask = torch.isnan(gt_disp)
    gt_disp[gt_mask] = float('inf')
    gt_disp[nan_gt_mask] = float('inf')

    estimated_depth = disparity_to_depth(est_disp)
    ground_truth_depth = disparity_to_depth(gt_disp)

    mean_depth = compute_absolute_error(estimated_depth, ground_truth_depth)[1]
    median_depth = compute_absolute_error(estimated_depth, ground_truth_depth, use_mean=False)[1]

    binary_error_map, one_pixel_error = compute_n_pixels_error(est_disp, gt_disp, n=1.0)

    mean_disparity_error = compute_absolute_error(est_disp, gt_disp)[1]

    mask = torch.ones(gt_disp.shape, dtype=torch.bool, device=gt_disp.device)
    if lb is not None:
        mask = mask & (gt_disp >= lb)
    if ub is not None:
        mask = mask & (gt_disp <= ub)

    mask.detach_()
    if abs(mask.float().sum()) < 1.0:
        return {
            '1px': one_pixel_error,
            'mean_depth': mean_depth * 100,
            'median_depth': median_depth * 100,
            'epe': mean_disparity_error
        }

    gt_disp = gt_disp[mask]
    est_disp = est_disp[mask]
    abs_error = torch.abs(gt_disp - est_disp)

    total_num = mask.float().sum()

    # error1/epe over the bounded mask are computed for reference; the returned
    # values come from the inf-masked helpers above
    error1 = torch.sum(torch.gt(abs_error, 1).float()) / total_num

    epe = abs_error.float().mean()

    return {
        '1px': one_pixel_error,
        'mean_depth': mean_depth * 100,
        'median_depth': median_depth * 100,
        'epe': mean_disparity_error,
    }
def compute_absolute_error(estimated_disparity,
                           ground_truth_disparity,
                           use_mean=True):
    """Returns pixel-wise and mean absolute error.

    Locations where ground truth is not available do not contribute to the mean
    absolute error. In such locations the pixel-wise error is shown as zero.
    If ground truth is not available in any location, the function returns 0.

    Args:
        ground_truth_disparity: ground truth disparity where locations with
            unknown disparity are set to inf's.
        estimated_disparity: estimated disparity.
        use_mean: if True then use the mean to average the pixel-wise errors,
            otherwise use the median.
    """
    absolute_difference = (estimated_disparity - ground_truth_disparity).abs()
    locations_without_ground_truth = torch.isinf(ground_truth_disparity)
    pixelwise_absolute_error = absolute_difference.clone()
    pixelwise_absolute_error[locations_without_ground_truth] = 0
    absolute_difference_with_ground_truth = absolute_difference[
        ~locations_without_ground_truth]
    if absolute_difference_with_ground_truth.numel() == 0:
        average_absolute_error = 0.0
    else:
        if use_mean:
            average_absolute_error = absolute_difference_with_ground_truth.mean().item()
        else:
            average_absolute_error = absolute_difference_with_ground_truth.median().item()
    return pixelwise_absolute_error, average_absolute_error


def compute_n_pixels_error(estimated_disparity, ground_truth_disparity, n=3.0):
    """Return pixel-wise n-pixels error and % of pixels with n-pixels error.

    Locations where ground truth is not available do not contribute to the mean
    n-pixel error. In such locations the pixel-wise error is shown as zero.

    Note that the n-pixel error is equal to one if
    |estimated_disparity - ground_truth_disparity| > n and zero otherwise.

    If ground truth is not available in any location, the function returns 0.

    Args:
        ground_truth_disparity: ground truth disparity where locations with
            unknown disparity are set to inf's.
        estimated_disparity: estimated disparity.
        n: maximum absolute disparity difference that does not trigger
            the n-pixel error.
    """
    locations_without_ground_truth = torch.isinf(ground_truth_disparity)
    more_than_n_pixels_absolute_difference = (
        estimated_disparity - ground_truth_disparity).abs().gt(n).float()
    pixelwise_n_pixels_error = more_than_n_pixels_absolute_difference.clone()
    pixelwise_n_pixels_error[locations_without_ground_truth] = 0.0
    more_than_n_pixels_absolute_difference_with_ground_truth = \
        more_than_n_pixels_absolute_difference[~locations_without_ground_truth]
    if more_than_n_pixels_absolute_difference_with_ground_truth.numel() == 0:
        percentage_of_pixels_with_error = 0.0
    else:
        percentage_of_pixels_with_error = \
            more_than_n_pixels_absolute_difference_with_ground_truth.mean().item() * 100
    return pixelwise_n_pixels_error, percentage_of_pixels_with_error
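
# Usage sketch (illustrative, not part of the repository): disparity metrics on
# a toy map. Half the pixels are off by exactly 2 px, so the 1px/2px outlier
# rates are 50%/0% and the EPE is 1.0.
if __name__ == "__main__":
    gt = torch.full((1, 4, 4), 10.0)
    est = gt.clone()
    est[:, :2, :] += 2.0  # top half wrong by exactly 2 px
    print(calc_error(est, gt, lb=0.0, ub=192.0))
    _, pct = compute_n_pixels_error(est, gt, n=1.0)
    print(pct)  # 50.0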
got".format(mode)) 66 | 67 | X_norm = 2 * X / (W-1) - 1 68 | Y_norm = 2 * Y / (H-1) - 1 69 | # [B, H, W, 2] 70 | pixel_coord_norm = torch.stack((X_norm, Y_norm), dim=3) 71 | 72 | projected_img = F.grid_sample(img, pixel_coord_norm, mode=interpolate_mode, padding_mode=padding_mode, align_corners=True) 73 | 74 | if output_all: 75 | return projected_img, output 76 | else: 77 | return projected_img 78 | 79 | 80 | def mesh_grid(b, h, w, device, dtype=torch.float): 81 | """ construct pixel coordination in an image""" 82 | # [1, H, W] copy 0-width for h times : x coord 83 | x_range = torch.arange(0, w, device=device, dtype=dtype).view(1, 1, 1, w).expand(b, 1, h, w) 84 | # [1, H, W] copy 0-height for w times : y coord 85 | y_range = torch.arange(0, h, device=device, dtype=dtype).view(1, 1, h, 1).expand(b, 1, h, w) 86 | 87 | # [b, 2, h, w] 88 | pixel_coord = torch.cat((x_range, y_range), dim=1) 89 | 90 | return pixel_coord 91 | 92 | def project_to_3d(depth, K, inv_K=None, T_target_to_source:torch.Tensor=None, eps=1e-7): 93 | """ 94 | project depth map to 3D space 95 | Args: 96 | depth: (Tensor): depth map(s), can be several depth maps concatenated at channel dimension 97 | [BatchSize, Channel, Height, Width] 98 | K: (Tensor): instrincs of camera 99 | [BatchSize, 3, 3] or [BatchSize, 4, 4] 100 | inv_K: (Optional, Tensor): invserse instrincs of camera 101 | [BatchSize, 3, 3] or [BatchSize, 4, 4] 102 | T_target_to_source: (Optional, Tensor): predicted transformation matrix from target image to source image frames 103 | [BatchSize, 4, 4] 104 | eps: (float): eplison value to avoid divide 0, default 1e-7 105 | 106 | Returns: Dict including 107 | homo_points_3d: (Tensor): the homogeneous points after depth project to 3D space, [x, y, z, 1] 108 | [BatchSize, 4, Channel*Height*Width] 109 | if T_target_to_source provided: 110 | 111 | triangular_depth: (Tensor): the depth map after the 3D points project to source camera 112 | [BatchSize, Channel, Height, Width] 113 | optical_flow: (Tensor): by 3D projection, the rigid flow can be got 114 | [BatchSize, Channel*2, Height, Width], to get the 2nd flow, index like [:, 2:4, :, :] 115 | flow_mask: (Tensor): the mask indicates which pixel's optical flow is valid 116 | [BatchSize, Channel, Height, Width] 117 | """ 118 | 119 | # support C >=1, for C > 1, it means several depth maps are concatenated at channel dimension 120 | B, C, H, W = depth.size() 121 | device = depth.device 122 | dtype = depth.dtype 123 | output = {} 124 | 125 | # [B, 2, H, W] 126 | pixel_coord = mesh_grid(B, H, W, device, dtype) 127 | ones = torch.ones(B, 1, H, W, device=device, dtype=dtype) 128 | # [B, 3, H, W], homogeneous coordination of image pixel, [x, y, 1] 129 | homo_pixel_coord = torch.cat((pixel_coord, ones), dim=1).contiguous() 130 | 131 | # [B, 3, H*W] -> [B, 3, C*H*W] 132 | homo_pixel_coord = homo_pixel_coord.view(B, 3, -1).repeat(1, 1, C).contiguous() 133 | # [B, C*H*W] -> [B, 1, C*H*W] 134 | depth = depth.view(B, -1).unsqueeze(dim=1).contiguous() 135 | if inv_K is None: 136 | inv_K = torch.inverse(K[:, :3, :3]) 137 | # [B, 3, C*H*W] 138 | points_3d = torch.matmul(inv_K[:, :3, :3], homo_pixel_coord) * depth 139 | ones = torch.ones(B, 1, C*H*W, device=device, dtype=dtype) 140 | # [B, 4, C*H*W], homogeneous coordiate, [x, y, z, 1] 141 | homo_points_3d = torch.cat((points_3d, ones), dim=1) 142 | output['homo_points_3d'] = homo_points_3d 143 | 144 | if T_target_to_source is not None: 145 | if K.shape[-1] == 3: 146 | new_K = torch.eye(4, device=device, 
-------------------------------------------------------------------------------- /utils/flow.py: --------------------------------------------------------------------------------
import os
import sys

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

parent_dir_name = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(parent_dir_name)

from utils.iwe import purge_unfeasible, get_interpolation, interpolate


class EventWarping(nn.Module):
    """
    Contrast maximization loss, as described in Section 3.2 of the paper 'Unsupervised Event-based Learning
    of Optical Flow, Depth, and Egomotion', Zhu et al., CVPR'19.
    The contrast maximization loss is the minimization of the per-pixel and per-polarity image of averaged
    timestamps of the input events after they have been compensated for their motion using the estimated
    optical flow. This minimization is performed in a forward and in a backward fashion to prevent scaling
    issues during backpropagation.
    """
23 | """ 24 | 25 | def __init__(self, res, config): 26 | super(EventWarping, self).__init__() 27 | self.res = res 28 | self.flow_scaling = 1.0 29 | # self.flow_scaling = max(res) 30 | self.weight = config.get("flow_regul_weight", 1.0) 31 | # self.weight = 0.0 32 | 33 | 34 | def forward(self, flow_list, event_list, pol_mask): 35 | """ 36 | :param flow_list: [[batch_size x 2 x H x W]] list of optical flow maps 37 | :param event_list: [batch_size x N x 4] input events (y, x, ts, p) 38 | :param pol_mask: [batch_size x N x 2] per-polarity binary mask of the input events 39 | """ 40 | # |-------------->x 41 | # | 42 | # | 43 | # | 44 | # V 45 | # y 46 | event_list[: ,:, 0:1] = (1 - event_list[: ,:, 0:1]) 47 | 48 | 49 | 50 | 51 | # import pdb; pdb.set_trace() 52 | 53 | # split input 54 | pol_mask = torch.cat([pol_mask for i in range(4)], dim=1) 55 | ts_list = torch.cat([event_list[:, :, 0:1] for i in range(4)], dim=1) 56 | 57 | # flow vector per input event 58 | flow_idx = event_list[:, :, 1:3].clone() 59 | flow_idx[:, :, 0] *= self.res[1] # torch.view is row-major 60 | flow_idx = torch.sum(flow_idx, dim=2) 61 | 62 | 63 | loss = 0 64 | for flow in flow_list: 65 | flow = flow.contiguous() 66 | # get flow for every event in the list 67 | flow = flow.view(flow.shape[0], 2, -1) 68 | event_flowy = torch.gather(flow[:, 1, :], 1, flow_idx.long()) # vertical component 69 | event_flowx = torch.gather(flow[:, 0, :], 1, flow_idx.long()) # horizontal component 70 | # event_flowy = torch.gather(flow[:, 0, :], 1, flow_idx.long()) # vertical component 71 | # event_flowx = torch.gather(flow[:, 1, :], 1, flow_idx.long()) # horizontal component 72 | 73 | event_flowy = event_flowy.view(event_flowy.shape[0], event_flowy.shape[1], 1) 74 | event_flowx = event_flowx.view(event_flowx.shape[0], event_flowx.shape[1], 1) 75 | event_flow = torch.cat([event_flowy, event_flowx], dim=2) 76 | # B X N x 2 (x, ) 77 | # interpolate forward 78 | tref = 1 79 | fw_idx, fw_weights = get_interpolation(event_list, event_flow, tref, self.res, self.flow_scaling) 80 | 81 | # per-polarity image of (forward) warped events 82 | fw_iwe_pos = interpolate(fw_idx.long(), fw_weights, self.res, polarity_mask=pol_mask[:, :, 0:1]) 83 | fw_iwe_neg = interpolate(fw_idx.long(), fw_weights, self.res, polarity_mask=pol_mask[:, :, 1:2]) 84 | 85 | # image of (forward) warped averaged timestamps 86 | fw_iwe_pos_ts = interpolate( 87 | fw_idx.long(), fw_weights * ts_list, self.res, polarity_mask=pol_mask[:, :, 0:1] 88 | ) 89 | fw_iwe_neg_ts = interpolate( 90 | fw_idx.long(), fw_weights * ts_list, self.res, polarity_mask=pol_mask[:, :, 1:2] 91 | ) 92 | fw_iwe_pos_ts /= fw_iwe_pos + 1e-9 93 | fw_iwe_neg_ts /= fw_iwe_neg + 1e-9 94 | 95 | # interpolate backward 96 | tref = 0 97 | bw_idx, bw_weights = get_interpolation(event_list, event_flow, tref, self.res, self.flow_scaling) 98 | 99 | # per-polarity image of (backward) warped events 100 | bw_iwe_pos = interpolate(bw_idx.long(), bw_weights, self.res, polarity_mask=pol_mask[:, :, 0:1]) 101 | bw_iwe_neg = interpolate(bw_idx.long(), bw_weights, self.res, polarity_mask=pol_mask[:, :, 1:2]) 102 | 103 | # image of (backward) warped averaged timestamps 104 | bw_iwe_pos_ts = interpolate( 105 | bw_idx.long(), bw_weights * (1 - ts_list), self.res, polarity_mask=pol_mask[:, :, 0:1] 106 | ) 107 | bw_iwe_neg_ts = interpolate( 108 | bw_idx.long(), bw_weights * (1 - ts_list), self.res, polarity_mask=pol_mask[:, :, 1:2] 109 | ) 110 | bw_iwe_pos_ts /= bw_iwe_pos + 1e-9 111 | bw_iwe_neg_ts /= bw_iwe_neg + 1e-9 112 | 113 | # flow 
-------------------------------------------------------------------------------- /utils/flow_vis.py: --------------------------------------------------------------------------------
# MIT License
#
# Copyright (c) 2018 Tom Runia
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to conditions.
#
# Author: Tom Runia
# Date Created: 2018-08-03

import numpy as np


def make_colorwheel():
    """
    Generates a color wheel for optical flow visualization as presented in:
    Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
    URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf

    Code follows the original C++ source code of Daniel Scharstein
    and the Matlab source code of Deqing Sun.

    Returns:
        np.ndarray: Color wheel
    """

    RY = 15
    YG = 6
    GC = 4
    CB = 11
    BM = 13
    MR = 6

    ncols = RY + YG + GC + CB + BM + MR
    colorwheel = np.zeros((ncols, 3))
    col = 0

    # RY
    colorwheel[0:RY, 0] = 255
    colorwheel[0:RY, 1] = np.floor(255*np.arange(0, RY)/RY)
    col = col + RY
    # YG
    colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0, YG)/YG)
    colorwheel[col:col+YG, 1] = 255
    col = col + YG
    # GC
    colorwheel[col:col+GC, 1] = 255
    colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0, GC)/GC)
    col = col + GC
    # CB
    colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
    colorwheel[col:col+CB, 2] = 255
    col = col + CB
    # BM
    colorwheel[col:col+BM, 2] = 255
    colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0, BM)/BM)
    col = col + BM
    # MR
    colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
    colorwheel[col:col+MR, 0] = 255
    return colorwheel
files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions. 11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-08-03 14 | 15 | import numpy as np 16 | 17 | def make_colorwheel(): 18 | """ 19 | Generates a color wheel for optical flow visualization as presented in: 20 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 21 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 22 | 23 | Code follows the original C++ source code of Daniel Scharstein. 24 | Code follows the the Matlab source code of Deqing Sun. 25 | 26 | Returns: 27 | np.ndarray: Color wheel 28 | """ 29 | 30 | RY = 15 31 | YG = 6 32 | GC = 4 33 | CB = 11 34 | BM = 13 35 | MR = 6 36 | 37 | ncols = RY + YG + GC + CB + BM + MR 38 | colorwheel = np.zeros((ncols, 3)) 39 | col = 0 40 | 41 | # RY 42 | colorwheel[0:RY, 0] = 255 43 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 44 | col = col+RY 45 | # YG 46 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 47 | colorwheel[col:col+YG, 1] = 255 48 | col = col+YG 49 | # GC 50 | colorwheel[col:col+GC, 1] = 255 51 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 52 | col = col+GC 53 | # CB 54 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 55 | colorwheel[col:col+CB, 2] = 255 56 | col = col+CB 57 | # BM 58 | colorwheel[col:col+BM, 2] = 255 59 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 60 | col = col+BM 61 | # MR 62 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 63 | colorwheel[col:col+MR, 0] = 255 64 | return colorwheel 65 | 66 | 67 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 68 | """ 69 | Applies the flow color wheel to (possibly clipped) flow components u and v. 70 | 71 | According to the C++ source code of Daniel Scharstein 72 | According to the Matlab source code of Deqing Sun 73 | 74 | Args: 75 | u (np.ndarray): Input horizontal flow of shape [H,W] 76 | v (np.ndarray): Input vertical flow of shape [H,W] 77 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 78 | 79 | Returns: 80 | np.ndarray: Flow visualization image of shape [H,W,3] 81 | """ 82 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 83 | colorwheel = make_colorwheel() # shape [55x3] 84 | ncols = colorwheel.shape[0] 85 | rad = np.sqrt(np.square(u) + np.square(v)) 86 | a = np.arctan2(-v, -u)/np.pi 87 | fk = (a+1) / 2*(ncols-1) 88 | k0 = np.floor(fk).astype(np.int32) 89 | k1 = k0 + 1 90 | k1[k1 == ncols] = 0 91 | f = fk - k0 92 | for i in range(colorwheel.shape[1]): 93 | tmp = colorwheel[:,i] 94 | col0 = tmp[k0] / 255.0 95 | col1 = tmp[k1] / 255.0 96 | col = (1-f)*col0 + f*col1 97 | idx = (rad <= 1) 98 | col[idx] = 1 - rad[idx] * (1-col[idx]) 99 | col[~idx] = col[~idx] * 0.75 # out of range 100 | # Note the 2-i => BGR instead of RGB 101 | ch_idx = 2-i if convert_to_bgr else i 102 | flow_image[:,:,ch_idx] = np.floor(255 * col) 103 | return flow_image 104 | 105 | 106 | def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False): 107 | """ 108 | Expects a two dimensional flow image of shape. 109 | 110 | Args: 111 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 112 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 
-------------------------------------------------------------------------------- /utils/iwe.py: --------------------------------------------------------------------------------
import torch


def purge_unfeasible(x, res):
    """
    Purge unfeasible event locations by setting their interpolation weights to zero.
    :param x: location of motion compensated events
    :param res: resolution of the image space
    :return masked indices
    :return mask for interpolation weights
    """

    mask = torch.ones((x.shape[0], x.shape[1], 1)).to(x.device)
    mask_y = (x[:, :, 0:1] < 0) + (x[:, :, 0:1] >= res[0])
    mask_x = (x[:, :, 1:2] < 0) + (x[:, :, 1:2] >= res[1])
    mask[mask_y + mask_x] = 0
    return x * mask, mask
def get_interpolation(events, flow, tref, res, flow_scaling, round_idx=False):
    """
    Warp the input events according to the provided optical flow and compute the bilinear interpolation
    (or rounding) weights to distribute the events to the closest (integer) locations in the image space.
    :param events: [batch_size x N x 4] input events (ts, y, x, p); the timestamp comes first,
        matching how the tensor is indexed below
    :param flow: [batch_size x N x 2] optical flow vector per event, (y, x) order
    :param tref: reference time toward which events are warped
    :param res: resolution of the image space
    :param flow_scaling: scalar that multiplies the optical flow map
    :param round_idx: whether or not to round the event locations instead of doing bilinear interp. (default = False)
    :return interpolated event indices
    :return interpolation weights
    """

    # event propagation
    warped_events = events[:, :, 1:3] + (tref - events[:, :, 0:1]) * flow * flow_scaling
    # warped_events = events[:, :, 1:3] - (tref - events[:, :, 0:1]) * flow * flow_scaling

    if round_idx:

        # no bilinear interpolation
        idx = torch.round(warped_events)
        weights = torch.ones(idx.shape).to(events.device)

    else:

        # get scattering indices
        top_y = torch.floor(warped_events[:, :, 0:1])
        bot_y = torch.floor(warped_events[:, :, 0:1] + 1)
        left_x = torch.floor(warped_events[:, :, 1:2])
        right_x = torch.floor(warped_events[:, :, 1:2] + 1)

        top_left = torch.cat([top_y, left_x], dim=2)
        top_right = torch.cat([top_y, right_x], dim=2)
        bottom_left = torch.cat([bot_y, left_x], dim=2)
        bottom_right = torch.cat([bot_y, right_x], dim=2)
        idx = torch.cat([top_left, top_right, bottom_left, bottom_right], dim=1)

        # get scattering interpolation weights
        warped_events = torch.cat([warped_events for i in range(4)], dim=1)
        zeros = torch.zeros(warped_events.shape).to(events.device)
        weights = torch.max(zeros, 1 - torch.abs(warped_events - idx))

    # purge unfeasible indices
    idx, mask = purge_unfeasible(idx, res)

    # make unfeasible weights zero
    weights = torch.prod(weights, dim=-1, keepdim=True) * mask  # bilinear interpolation

    # prepare indices
    idx[:, :, 0] *= res[1]  # torch.view is row-major
    idx = torch.sum(idx, dim=2, keepdim=True)

    return idx, weights


def interpolate(idx, weights, res, polarity_mask=None):
    """
    Create an image-like representation of the warped events.
    :param idx: [batch_size x N x 1] warped event locations
    :param weights: [batch_size x N x 1] interpolation weights for the warped events
    :param res: resolution of the image space
    :param polarity_mask: [batch_size x N x 2] polarity mask for the warped events (default = None)
    :return image of warped events
    """

    if polarity_mask is not None:
        weights = weights * polarity_mask
    iwe = torch.zeros((idx.shape[0], res[0] * res[1], 1)).to(idx.device)
    iwe = iwe.scatter_add_(1, idx.long(), weights)
    iwe = iwe.view((idx.shape[0], 1, res[0], res[1]))
    return iwe
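
# Usage sketch (illustrative, not part of the repository): splat two events one
# pixel to the right with a constant unit flow and build the image of warped
# events. Events are (ts, y, x, p); with ts = 0 and tref = 1, each event moves
# by exactly one flow vector.
if __name__ == "__main__":
    res = (4, 4)
    events = torch.tensor([[[0.0, 1.0, 1.0, 1.0],
                            [0.0, 2.0, 2.0, 1.0]]])        # [1, N, 4]
    per_event_flow = torch.tensor([[[0.0, 1.0],
                                    [0.0, 1.0]]])          # [1, N, 2], (y, x)
    idx, weights = get_interpolation(events, per_event_flow, 1, res, flow_scaling=1.0)
    iwe = interpolate(idx.long(), weights, res)
    print(iwe[0, 0])  # ones at (1, 2) and (2, 3)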
def deblur_events(flow, event_list, res, flow_scaling=128, round_idx=True, polarity_mask=None):
    """
    Deblur the input events given an optical flow map.
    Event timestamps need to be normalized between 0 and 1.
    :param flow: [batch_size x 2 x H x W] optical flow map
    :param event_list: [batch_size x N x 4] input events (ts, y, x, p)
    :param res: resolution of the image space
    :param flow_scaling: scalar that multiplies the optical flow map
    :param round_idx: whether or not to round the event locations instead of doing bilinear interp. (default = True)
    :param polarity_mask: [batch_size x N x 2] polarity mask for the warped events (default = None)
    :return iwe: [batch_size x 1 x H x W] image of warped events
    """

    # flow vector per input event
    flow_idx = event_list[:, :, 1:3].clone()
    flow_idx[:, :, 0] *= res[1]  # torch.view is row-major
    flow_idx = torch.sum(flow_idx, dim=2)

    # get flow for every event in the list
    flow = flow.view(flow.shape[0], 2, -1)
    event_flowy = torch.gather(flow[:, 1, :], 1, flow_idx.long())  # vertical component
    event_flowx = torch.gather(flow[:, 0, :], 1, flow_idx.long())  # horizontal component
    event_flowy = event_flowy.view(event_flowy.shape[0], event_flowy.shape[1], 1)
    event_flowx = event_flowx.view(event_flowx.shape[0], event_flowx.shape[1], 1)
    event_flow = torch.cat([event_flowy, event_flowx], dim=2)

    # interpolate forward
    fw_idx, fw_weights = get_interpolation(event_list, event_flow, 1, res, flow_scaling, round_idx=round_idx)
    if not round_idx:
        polarity_mask = torch.cat([polarity_mask for i in range(4)], dim=1)

    # image of (forward) warped events
    iwe = interpolate(fw_idx.long(), fw_weights, res, polarity_mask=polarity_mask)

    return iwe


def compute_pol_iwe(flow, event_list, res, pos_mask, neg_mask, flow_scaling=128, round_idx=True):
    """
    Create a per-polarity image of warped events given an optical flow map.
    :param flow: [batch_size x 2 x H x W] optical flow map
    :param event_list: [batch_size x N x 4] input events (ts, y, x, p)
    :param res: resolution of the image space
    :param pos_mask: [batch_size x N x 1] polarity mask for positive events
    :param neg_mask: [batch_size x N x 1] polarity mask for negative events
    :param flow_scaling: scalar that multiplies the optical flow map
    :param round_idx: whether or not to round the event locations instead of doing bilinear interp. (default = True)
    :return iwe: [batch_size x 2 x H x W] image of warped events
    """

    iwe_pos = deblur_events(
        flow, event_list, res, flow_scaling=flow_scaling, round_idx=round_idx, polarity_mask=pos_mask
    )
    iwe_neg = deblur_events(
        flow, event_list, res, flow_scaling=flow_scaling, round_idx=round_idx, polarity_mask=neg_mask
    )
    iwe = torch.cat([iwe_pos, iwe_neg], dim=1)

    return iwe
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | class Logger():
4 |     def __init__(self, savemodel_path, accelerator=None):
5 |         self.accelerator = accelerator
6 |         # only the (local) main process opens and writes the log file
7 |         if self.accelerator is None or self.accelerator.is_local_main_process:
8 |             file_path = os.path.join(savemodel_path, "log.txt")
9 |             self.file_ = open(file_path, "a+")
10 |             self.counter = 0
11 | 
12 |     def __del__(self):
13 |         if self.accelerator is None or self.accelerator.is_local_main_process:
14 |             self.file_.close()
15 | 
16 |     def log_and_print(self, contents):
17 |         if self.accelerator is None or self.accelerator.is_local_main_process:
18 |             self.counter += 1
19 |             self.file_.write(str(contents) + "\n")
20 |             print(contents)
21 |             # flush every five messages so the log survives a crash
22 |             if self.counter == 5:
23 |                 self.file_.flush()
24 |                 self.counter = 0
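25 | 
26 | 
27 | if __name__ == "__main__":
28 |     # Usage sketch (not part of the original file): append to ./log.txt in
29 |     # the current directory and mirror the message to stdout.
30 |     logger = Logger(".")
31 |     logger.log_and_print("hello from Logger")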
--------------------------------------------------------------------------------
/utils/preprocess.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms as transforms
3 | import random
4 | 
5 | __imagenet_stats = {'mean': [0.485, 0.456, 0.406],
6 |                     'std': [0.229, 0.224, 0.225]}
7 | 
8 | #__imagenet_stats = {'mean': [0.5, 0.5, 0.5],
9 | #                    'std': [0.5, 0.5, 0.5]}
10 | 
11 | __imagenet_pca = {
12 |     'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]),
13 |     'eigvec': torch.Tensor([
14 |         [-0.5675, 0.7192, 0.4009],
15 |         [-0.5808, -0.0045, -0.8140],
16 |         [-0.5836, -0.6948, 0.4203],
17 |     ])
18 | }
19 | 
20 | 
21 | def scale_crop(input_size, scale_size=None, normalize=__imagenet_stats):
22 |     t_list = [
23 |         transforms.ToTensor(),
24 |         transforms.Normalize(**normalize),
25 |     ]
26 |     #if scale_size != input_size:
27 |     #t_list = [transforms.Resize((960, 540))] + t_list
28 | 
29 |     return transforms.Compose(t_list)
30 | 
31 | 
32 | def scale_random_crop(input_size, scale_size=None, normalize=__imagenet_stats):
33 |     t_list = [
34 |         transforms.RandomCrop(input_size),
35 |         transforms.ToTensor(),
36 |         transforms.Normalize(**normalize),
37 |     ]
38 |     if scale_size != input_size:
39 |         t_list = [transforms.Resize(scale_size)] + t_list
40 | 
41 |     return transforms.Compose(t_list)
42 | 
43 | 
44 | def pad_random_crop(input_size, scale_size=None, normalize=__imagenet_stats):
45 |     padding = int((scale_size - input_size) / 2)
46 |     return transforms.Compose([
47 |         transforms.RandomCrop(input_size, padding=padding),
48 |         transforms.RandomHorizontalFlip(),
49 |         transforms.ToTensor(),
50 |         transforms.Normalize(**normalize),
51 |     ])
52 | 
53 | 
54 | def inception_preproccess(input_size, normalize=__imagenet_stats):
55 |     return transforms.Compose([
56 |         transforms.RandomResizedCrop(input_size),
57 |         transforms.RandomHorizontalFlip(),
58 |         transforms.ToTensor(),
59 |         transforms.Normalize(**normalize)
60 |     ])
61 | def inception_color_preproccess(input_size, normalize=__imagenet_stats):
62 |     return transforms.Compose([
63 |         #transforms.RandomResizedCrop(input_size),
64 |         #transforms.RandomHorizontalFlip(),
65 |         transforms.ToTensor(),
66 |         ColorJitter(
67 |             brightness=0.4,
68 |             contrast=0.4,
69 |             saturation=0.4,
70 |         ),
71 |         Lighting(0.1, __imagenet_pca['eigval'], __imagenet_pca['eigvec']),
72 |         transforms.Normalize(**normalize)
73 |     ])
74 | 
75 | 
76 | def get_transform(name='imagenet', input_size=None,
77 |                   scale_size=None, normalize=None, augment=True):
78 |     normalize = normalize or __imagenet_stats
79 |     input_size = input_size or 256
80 |     if augment:
81 |         return inception_color_preproccess(input_size, normalize=normalize)
82 |     else:
83 |         return scale_crop(input_size=input_size,
84 |                           scale_size=scale_size, normalize=normalize)
85 | 
86 | 
87 | 
88 | 
89 | class Lighting(object):
90 |     """Lighting noise (AlexNet-style PCA-based noise)"""
91 | 
92 |     def __init__(self, alphastd, eigval, eigvec):
93 |         self.alphastd = alphastd
94 |         self.eigval = eigval
95 |         self.eigvec = eigvec
96 | 
97 |     def __call__(self, img):
98 |         if self.alphastd == 0:
99 |             return img
100 | 
101 |         alpha = img.new_empty(3).normal_(0, self.alphastd)
102 |         rgb = self.eigvec.type_as(img).clone()\
103 |             .mul(alpha.view(1, 3).expand(3, 3))\
104 |             .mul(self.eigval.view(1, 3).expand(3, 3))\
105 |             .sum(1).squeeze()
106 | 
107 |         return img.add(rgb.view(3, 1, 1).expand_as(img))
108 | 
109 | 
110 | class Grayscale(object):
111 | 
112 |     def __call__(self, img):
113 |         gs = img.clone()
114 |         gs[0].mul_(0.299).add_(gs[1], alpha=0.587).add_(gs[2], alpha=0.114)
115 |         gs[1].copy_(gs[0])
116 |         gs[2].copy_(gs[0])
117 |         return gs
118 | 
119 | 
120 | class Saturation(object):
121 | 
122 |     def __init__(self, var):
123 |         self.var = var
124 | 
125 |     def __call__(self, img):
126 |         gs = Grayscale()(img)
127 |         alpha = random.uniform(0, self.var)
128 |         return img.lerp(gs, alpha)
129 | 
130 | 
131 | class Brightness(object):
132 | 
133 |     def __init__(self, var):
134 |         self.var = var
135 | 
136 |     def __call__(self, img):
137 |         gs = torch.zeros_like(img)
138 |         alpha = random.uniform(0, self.var)
139 |         return img.lerp(gs, alpha)
140 | 
141 | 
142 | class Contrast(object):
143 | 
144 |     def __init__(self, var):
145 |         self.var = var
146 | 
147 |     def __call__(self, img):
148 |         gs = Grayscale()(img)
149 |         gs.fill_(gs.mean())
150 |         alpha = random.uniform(0, self.var)
151 |         return img.lerp(gs, alpha)
152 | 
153 | 
154 | class RandomOrder(object):
155 |     """ Composes several transforms together in random order.
156 |     """
157 | 
158 |     def __init__(self, transforms):
159 |         self.transforms = transforms
160 | 
161 |     def __call__(self, img):
162 |         if self.transforms is None:
163 |             return img
164 |         order = torch.randperm(len(self.transforms))
165 |         for i in order:
166 |             img = self.transforms[i](img)
167 |         return img
168 | 
169 | 
170 | class ColorJitter(RandomOrder):
171 | 
172 |     def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4):
173 |         self.transforms = []
174 |         if brightness != 0:
175 |             self.transforms.append(Brightness(brightness))
176 |         if contrast != 0:
177 |             self.transforms.append(Contrast(contrast))
178 |         if saturation != 0:
179 |             self.transforms.append(Saturation(saturation))
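180 | 
181 | 
182 | if __name__ == "__main__":
183 |     # Usage sketch (not part of the original file): run the augmentation
184 |     # pipeline returned by get_transform on a random PIL image.
185 |     import numpy as np
186 |     from PIL import Image
187 |     img = Image.fromarray(np.uint8(np.random.rand(256, 256, 3) * 255))
188 |     out = get_transform(augment=True)(img)
189 |     print(out.shape)  # torch.Size([3, 256, 256])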
--------------------------------------------------------------------------------
/utils/readpfm.py:
--------------------------------------------------------------------------------
1 | import re
2 | import numpy as np
3 | 
4 | 
5 | def readPFM(file):
6 |     file = open(file, 'rb')
7 | 
8 |     color = None
9 |     width = None
10 |     height = None
11 |     scale = None
12 |     endian = None
13 | 
14 |     # header lines are ASCII, but the file is opened in binary mode, so decode
15 |     header = file.readline().rstrip().decode('ascii')
16 |     if header == 'PF':
17 |         color = True
18 |     elif header == 'Pf':
19 |         color = False
20 |     else:
21 |         raise Exception('Not a PFM file.')
22 | 
23 |     dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('ascii'))
24 |     if dim_match:
25 |         width, height = map(int, dim_match.groups())
26 |     else:
27 |         raise Exception('Malformed PFM header.')
28 | 
29 |     scale = float(file.readline().rstrip().decode('ascii'))
30 |     if scale < 0:  # little-endian
31 |         endian = '<'
32 |         scale = -scale
33 |     else:
34 |         endian = '>'  # big-endian
35 | 
36 |     data = np.fromfile(file, endian + 'f')
37 |     shape = (height, width, 3) if color else (height, width)
38 | 
39 |     data = np.reshape(data, shape)
40 |     data = np.flipud(data)
41 |     file.close()
42 |     return data, scale
43 | 
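44 | 
45 | if __name__ == "__main__":
46 |     # Usage sketch (not part of the original file): write a tiny grayscale
47 |     # PFM to a temporary file, then read it back with readPFM.
48 |     import tempfile, os
49 |     data = np.arange(6, dtype='<f4').reshape(2, 3)
50 |     with tempfile.NamedTemporaryFile(suffix='.pfm', delete=False) as f:
51 |         f.write(b'Pf\n3 2\n-1.0\n')           # header: grayscale, 3x2, little-endian
52 |         f.write(np.flipud(data).tobytes())    # PFM stores rows bottom-to-top
53 |         path = f.name
54 |     out, scale = readPFM(path)
55 |     os.remove(path)
56 |     print(out.shape, scale)  # (2, 3) 1.0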
--------------------------------------------------------------------------------
/utils/visualization.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | def disp_err_to_color(disp_est, disp_gt):
4 |     """
5 |     Calculate the error map between the estimated and ground-truth disparity maps.
6 |     Hot colors -> large error, cold colors -> small error.
7 |     Args:
8 |         disp_est (numpy.array): estimated disparity map
9 |             in (Height, Width) layout, range [0,inf]
10 |         disp_gt (numpy.array): ground truth disparity map
11 |             in (Height, Width) layout, range [0,inf]
12 |     Returns:
13 |         disp_err (numpy.array): disparity error map
14 |             in (Height, Width, 3) layout, range [0,1]
15 |     """
16 |     """ matlab
17 |     function D_err = disp_error_image (D_gt,D_est,tau,dilate_radius)
18 |     if nargin==3
19 |         dilate_radius = 1;
20 |     end
21 |     [E,D_val] = disp_error_map (D_gt,D_est);
22 |     E = min(E/tau(1),(E./abs(D_gt))/tau(2));
23 |     cols = error_colormap();
24 |     D_err = zeros([size(D_gt) 3]);
25 |     for i=1:size(cols,1)
26 |         [v,u] = find(D_val > 0 & E >= cols(i,1) & E <= cols(i,2));
27 |         D_err(sub2ind(size(D_err),v,u,1*ones(length(v),1))) = cols(i,3);
28 |         D_err(sub2ind(size(D_err),v,u,2*ones(length(v),1))) = cols(i,4);
29 |         D_err(sub2ind(size(D_err),v,u,3*ones(length(v),1))) = cols(i,5);
30 |     end
31 |     D_err = imdilate(D_err,strel('disk',dilate_radius));
32 |     """
33 |     # error color map with intervals (0, 0.1875, 0.375, 0.75, 1.5, 3, 6, 12, 24, 48, inf)/3.0
34 |     # each interval maps to a different RGB color
35 |     cols = np.array(
36 |         [
37 |             [0 / 3.0, 0.1875 / 3.0, 49, 54, 149],
38 |             [0.1875 / 3.0, 0.375 / 3.0, 69, 117, 180],
39 |             [0.375 / 3.0, 0.75 / 3.0, 116, 173, 209],
40 |             [0.75 / 3.0, 1.5 / 3.0, 171, 217, 233],
41 |             [1.5 / 3.0, 3 / 3.0, 224, 243, 248],
42 |             [3 / 3.0, 6 / 3.0, 254, 224, 144],
43 |             [6 / 3.0, 12 / 3.0, 253, 174, 97],
44 |             [12 / 3.0, 24 / 3.0, 244, 109, 67],
45 |             [24 / 3.0, 48 / 3.0, 215, 48, 39],
46 |             [48 / 3.0, float("inf"), 165, 0, 38]
47 |         ]
48 |     )
49 | 
50 |     # [0, 1] -> [0, 255.0]
51 |     disp_est = disp_est.copy() * 255.0
52 |     disp_gt = disp_gt.copy() * 255.0
53 |     # get the error (<3px or <5%) map
54 |     tau = [3.0, 0.05]
55 |     E = np.abs(disp_est - disp_gt)
56 | 
57 |     not_empty = disp_gt > 0.0
58 |     tmp = np.zeros_like(disp_gt)
59 |     tmp[not_empty] = E[not_empty] / disp_gt[not_empty] / tau[1]
60 |     E = np.minimum(E / tau[0], tmp)
61 | 
62 |     h, w = disp_gt.shape
63 |     err_im = np.zeros(shape=(h, w, 3)).astype(np.uint8)
64 |     for col in cols:
65 |         y_x = not_empty & (E >= col[0]) & (E <= col[1])
66 |         err_im[y_x] = col[2:]
67 | 
68 |     # value range [0, 1], shape [H, W, 3]
69 |     err_im = err_im.astype(np.float64) / 255.0
70 | 
71 |     return err_im
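72 | 
73 | 
74 | if __name__ == "__main__":
75 |     # Usage sketch (not part of the original file): color-code the error of a
76 |     # noisy disparity estimate against a synthetic ground truth in [0, 1].
77 |     gt = np.random.rand(32, 32)
78 |     est = gt + np.random.randn(32, 32) * 0.01
79 |     err_im = disp_err_to_color(est, gt)
80 |     print(err_im.shape, err_im.min(), err_im.max())  # (32, 32, 3), values in [0, 1]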
--------------------------------------------------------------------------------
/utils/viz.py:
--------------------------------------------------------------------------------
1 | # Copyrights. All rights reserved.
2 | # ECOLE POLYTECHNIQUE FEDERALE DE LAUSANNE, Switzerland,
3 | # Space Center (eSpace), 2018
4 | # See the LICENSE.TXT file for more details.
5 | 
6 | import os
7 | 
8 | import torch as th
9 | import numpy as np
10 | 
11 | # Matplotlib backend should be chosen before pyplot is imported.
12 | import matplotlib
13 | matplotlib.use('Agg')
14 | from matplotlib import pyplot as plt
15 | from mpl_toolkits import axes_grid1
16 | 
17 | 
18 | def gray_to_color(array, colormap_name='jet', vmin=None, vmax=None):
19 |     cmap = plt.get_cmap(colormap_name)
20 |     norm = plt.Normalize(vmin, vmax)
21 |     return cmap(norm(array))
22 | 
23 | 
24 | def _add_scaled_colorbar(plot, aspect=20, pad_fraction=0.5, **kwargs):
25 |     """Adds scaled colorbar to existing plot."""
26 |     divider = axes_grid1.make_axes_locatable(plot.axes)
27 |     width = axes_grid1.axes_size.AxesY(plot.axes, aspect=1. / aspect)
28 |     pad = axes_grid1.axes_size.Fraction(pad_fraction, width)
29 |     current_axis = plt.gca()
30 |     cax = divider.append_axes("right", size=width, pad=pad)
31 |     plt.sca(current_axis)
32 |     return plot.axes.figure.colorbar(plot, cax=cax, **kwargs)
33 | 
34 | 
35 | def save_image(filename, image, color_first=True):
36 |     """Save color image to file.
37 | 
38 |     Args:
39 |         filename: image file where the image will be saved.
40 |         image: 3d image tensor.
41 |         color_first: if True, the color dimension is the first
42 |                      dimension of the "image", otherwise the
43 |                      color dimension is the last dimension.
44 |     """
45 |     figure = plt.figure()
46 |     if color_first:
47 |         numpy_image = image.permute(1, 2, 0).numpy()
48 |     else:
49 |         numpy_image = image.numpy()
50 |     plot = plt.imshow(numpy_image.astype(np.uint8))
51 |     plot.axes.get_xaxis().set_visible(False)
52 |     plot.axes.get_yaxis().set_visible(False)
53 |     figure.savefig(filename, bbox_inches='tight', dpi=200)
54 |     plt.close()
55 | 
56 | 
57 | def save_matrix(filename,
58 |                 matrix,
59 |                 minimum_value=None,
60 |                 maximum_value=None,
61 |                 colormap='magma',
62 |                 is_colorbar=True):
63 |     """Saves the matrix to the image file.
64 | 
65 |     Args:
66 |         filename: image file where the matrix will be saved.
67 |         matrix: tensor of size (height x width). Some values might be
68 |                 equal to inf.
69 |         minimum_value, maximum_value: boundaries of the range.
70 |                 Values outside of the range are
71 |                 shown in white. The colors of other
72 |                 values are determined by the colormap.
73 |                 If the maximum and minimum values are
74 |                 not given, they are calculated as the
75 |                 0.001 and 0.999 quantiles.
76 |         colormap: map that determines color coding of matrix values.
77 |     """
78 |     figure = plt.figure()
79 |     noninf_mask = matrix != float('inf')
80 |     if minimum_value is None:
81 |         minimum_value = np.quantile(matrix[noninf_mask], 0.001)
82 |     if maximum_value is None:
83 |         maximum_value = np.quantile(matrix[noninf_mask], 0.999)
84 |     plot = plt.imshow(
85 |         matrix.numpy(), colormap, vmin=minimum_value, vmax=maximum_value)
86 |     if is_colorbar:
87 |         _add_scaled_colorbar(plot)
88 |     plot.axes.get_xaxis().set_visible(False)
89 |     plot.axes.get_yaxis().set_visible(False)
90 |     figure.savefig(filename, bbox_inches='tight', dpi=200)
91 |     plt.close()
92 | 
93 | def plot_points_on_background(points_coordinates,
94 |                               background,
95 |                               points_color=[0, 0, 255]):
96 |     """
97 |     Args:
98 |         points_coordinates: array of (y, x) points coordinates
99 |                             of size (number_of_points x 2).
100 |         background: (3 x height x width)
101 |                     gray or color image uint8.
102 |         points_color: color of points [red, green, blue] uint8.
103 |     """
104 |     if not (len(background.size()) == 3 and background.size(0) == 3):
105 |         raise ValueError('background should be (color x height x width).')
106 |     _, height, width = background.size()
107 |     background_with_points = background.clone()
108 |     y, x = points_coordinates.transpose(0, 1)
109 |     x_min, x_max = x.min(), x.max()
110 |     y_min, y_max = y.min(), y.max()
111 |     if not (x_min >= 0 and y_min >= 0 and x_max < width and y_max < height):
112 |         raise ValueError('points coordinates are outside of "background" '
113 |                          'boundaries.')
114 |     background_with_points[:, y, x] = th.Tensor(points_color).type_as(
115 |         background).unsqueeze(-1)
116 |     return background_with_points
117 | 
118 | 
119 | def overlay_image_with_binary_error(color_image, binary_error):
120 |     """Returns byte image overlaid with the binary error.
121 | 
122 |     Contrast of the image is reduced, brightness is increased,
123 |     and locations with the errors are shown in blue.
124 | 
125 |     Args:
126 |         color_image: byte image tensor of size
127 |                      (color_index, height, width);
128 |         binary_error: byte tensor of size (height x width),
129 |                       where "True"s correspond to error,
130 |                       and "False"s correspond to absence of error.
131 |     """
132 |     points_coordinates = th.nonzero(binary_error)
133 |     washed_out_image = color_image // 2 + 128
134 |     return plot_points_on_background(points_coordinates, washed_out_image)
135 | 
136 | 
137 | class Logger(object):
138 |     def __init__(self, filename):
139 |         self._filename = filename
140 | 
141 |     def log(self, text):
142 |         """Appends text line to the file."""
143 |         if os.path.isfile(self._filename):
144 |             handler = open(self._filename, 'r')
145 |             lines = handler.readlines()
146 |             handler.close()
147 |         else:
148 |             lines = []
149 |         lines.append(text + '\n')
150 |         handler = open(self._filename, 'w')
151 |         handler.writelines(lines)
152 |         handler.close()
153 | 
154 | 
155 | def plot_losses_and_errors(filename,
156 |                            losses,
157 |                            errors,
158 |                            righ_y_axis_label='Validation error, [%]'):
159 |     """Plots the loss and the error.
160 | 
161 |     The plot has two y-axes: the left is reserved for the loss
162 |     and the right is reserved for the error. The axes have
163 |     different scales. The axis and the curve of the loss are shown
164 |     in blue, and the axis and the curve of the error are shown
165 |     in red.
166 | 
167 |     Args:
168 |         filename: image file where the plot is saved;
169 |         losses, errors: lists with loss and error values
170 |                         respectively. Every element of the
171 |                         list corresponds to an epoch.
172 |     """
173 |     epochs = range(1, len(losses) + 1)
174 |     figure, loss_axis = plt.subplots()
175 |     smallest_loss = min(losses)
176 |     loss_label = 'Training loss (smallest {0:.3f})'.format(smallest_loss)
177 |     loss_plot = loss_axis.plot(epochs, losses, 'bs-', label=loss_label)[0]
178 |     loss_axis.set_ylabel('Training loss', color='blue')
179 |     loss_axis.set_xlabel('Epoch')
180 |     error_axis = loss_axis.twinx()
181 |     smallest_error = min(errors)
182 |     error_label = 'Validation error (smallest {0:.3f})'.format(smallest_error)
183 |     error_plot = error_axis.plot(epochs, errors, 'ro--', label=error_label)[0]
184 |     error_axis.set_ylabel(righ_y_axis_label, color='red')
185 |     error_axis.legend(handles=[loss_plot, error_plot])
186 |     figure.savefig(filename, bbox_inches='tight')
187 |     plt.close()
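188 | 
189 | 
190 | if __name__ == "__main__":
191 |     # Usage sketch (not part of the original file): save a random color image
192 |     # and a loss/error curve plot into ./tmp_viz/.
193 |     os.makedirs('tmp_viz', exist_ok=True)
194 |     save_image('tmp_viz/image.png', th.rand(3, 64, 64) * 255)
195 |     plot_losses_and_errors('tmp_viz/curves.png', [1.0, 0.5, 0.3], [10.0, 7.0, 6.5])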
--------------------------------------------------------------------------------
/utils/warp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.utils.data
4 | from torch.autograd import Variable
5 | import torch.nn.functional as F
6 | import math
7 | 
8 | def normalize_coords(grid):
9 |     """Normalize coordinates of image scale to [-1, 1]
10 |     Args:
11 |         grid: [B, 2, H, W]
12 |     """
13 |     assert grid.size(1) == 2
14 |     h, w = grid.size()[2:]
15 |     grid[:, 0, :, :] = 2 * (grid[:, 0, :, :].clone() / (w - 1)) - 1  # x: [-1, 1]
16 |     grid[:, 1, :, :] = 2 * (grid[:, 1, :, :].clone() / (h - 1)) - 1  # y: [-1, 1]
17 |     grid = grid.permute((0, 2, 3, 1))  # [B, H, W, 2]
18 |     return grid
19 | 
20 | def meshgrid(img, homogeneous=False):
21 |     """Generate meshgrid in image scale
22 |     Args:
23 |         img: [B, _, H, W]
24 |         homogeneous: whether to return homogeneous coordinates
25 |     Return:
26 |         grid: [B, 2, H, W]
27 |     """
28 |     b, _, h, w = img.size()
29 | 
30 |     x_range = torch.arange(0, w).view(1, 1, w).expand(1, h, w).type_as(img)  # [1, H, W]
31 |     y_range = torch.arange(0, h).view(1, h, 1).expand(1, h, w).type_as(img)
32 | 
33 |     grid = torch.cat((x_range, y_range), dim=0)  # [2, H, W], grid[:, i, j] = [j, i]
34 |     grid = grid.unsqueeze(0).expand(b, 2, h, w)  # [B, 2, H, W]
35 | 
36 |     if homogeneous:
37 |         ones = torch.ones_like(x_range).unsqueeze(0).expand(b, 1, h, w)  # [B, 1, H, W]
38 |         grid = torch.cat((grid, ones), dim=1)  # [B, 3, H, W]
39 |         assert grid.size(1) == 3
40 |     return grid
41 | 
42 | def disp_warp(img, disp, padding_mode='border', interpolate_mode='bilinear'):
43 |     """Warping by disparity
44 |     Args:
45 |         img: [B, 3, H, W]
46 |         disp: [B, 1, H, W], positive
47 |         padding_mode: 'zeros' or 'border'
48 |     Returns:
49 |         warped_img: [B, 3, H, W]
50 |         valid_mask: [B, 3, H, W]
51 |     """
52 |     # assert disp.min() >= 0
53 | 
54 |     grid = meshgrid(img)  # [B, 2, H, W] in image scale
55 |     # note the -disp here: a positive disparity shifts the sampling location left
56 |     offset = torch.cat((-disp, torch.zeros_like(disp)), dim=1)  # [B, 2, H, W]
57 |     sample_grid = grid + offset
58 |     sample_grid = normalize_coords(sample_grid)  # [B, H, W, 2] in [-1, 1]
59 |     warped_img = F.grid_sample(img, sample_grid, mode=interpolate_mode, padding_mode=padding_mode, align_corners=True)
60 | 
61 |     mask = torch.ones_like(img)
62 |     valid_mask = F.grid_sample(mask, sample_grid, mode=interpolate_mode, padding_mode='zeros', align_corners=True)
63 |     valid_mask[valid_mask < 0.9999] = 0
64 |     valid_mask[valid_mask > 0] = 1
65 |     return warped_img, valid_mask
66 | 
67 | 
68 | def flow_warp(feature, flow, mask=False, padding_mode='zeros', interpolate_mode='bilinear'):
69 |     b, c, h, w = feature.size()
70 |     assert flow.size(1) == 2
71 | 
72 |     grid = coords_grid(b, h, w).to(flow.device) + flow  # [B, 2, H, W]
73 | 
74 |     return bilinear_sampler(feature, grid, mode=interpolate_mode, mask=mask, padding_mode=padding_mode)
75 | 
76 | 
77 | def coords_grid(batch, ht, wd, normalize=False):
78 |     if normalize:  # [-1, 1]
79 |         coords = torch.meshgrid(2 * torch.arange(ht) / (ht - 1) - 1,
80 |                                 2 * torch.arange(wd) / (wd - 1) - 1)
81 |     else:
82 |         coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
83 |     coords = torch.stack(coords[::-1], dim=0).float()
84 | 
85 |     return coords[None].repeat(batch, 1, 1, 1)  # [B, 2, H, W]
86 | 
87 | def bilinear_sampler(img, coords, mode='bilinear', mask=False, padding_mode='zeros'):
88 |     """ Wrapper for grid_sample, uses pixel coordinates """
89 |     if coords.size(-1) != 2:  # [B, 2, H, W] -> [B, H, W, 2]
90 |         coords = coords.permute(0, 2, 3, 1)
91 | 
92 |     H, W = img.shape[-2:]
93 |     # H = height if height is not None else img.shape[-2]
94 |     # W = width if width is not None else img.shape[-1]
95 | 
96 |     xgrid, ygrid = coords.split([1, 1], dim=-1)
97 | 
98 |     # handle H or W equal to 1 by explicitly redefining the height and width
99 |     if H == 1:
100 |         assert ygrid.abs().max() < 1e-8
101 |         H = 10
102 |     if W == 1:
103 |         assert xgrid.abs().max() < 1e-8
104 |         W = 10
105 | 
106 |     xgrid = 2 * xgrid / (W - 1) - 1
107 |     ygrid = 2 * ygrid / (H - 1) - 1
108 | 
109 |     grid = torch.cat([xgrid, ygrid], dim=-1)
110 |     img = F.grid_sample(img, grid, mode=mode,
111 |                         padding_mode=padding_mode,
112 |                         align_corners=True)
113 | 
114 |     if mask:
115 |         mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
116 |         return img, mask.squeeze(-1).float()
117 | 
118 |     return img
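119 | 
120 | 
121 | if __name__ == "__main__":
122 |     # Usage sketch (not part of the original file): warp an image with a
123 |     # constant disparity, then with a constant horizontal flow.
124 |     img = torch.rand(1, 3, 32, 32)
125 |     disp = torch.full((1, 1, 32, 32), 4.0)
126 |     warped, valid = disp_warp(img, disp)  # samples img at (x - 4, y)
127 |     flow = torch.zeros(1, 2, 32, 32)
128 |     flow[:, 0] = 4.0  # +4 px horizontal flow
129 |     rewarped, in_bounds = flow_warp(img, flow, mask=True)
130 |     print(warped.shape, valid.mean().item(), in_bounds.shape)
--------------------------------------------------------------------------------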