├── .gitignore
├── .tmux.conf
├── README.md
├── checkpoints
│   └── dsec_best.tar
├── config
│   ├── dsec_contrast.yaml
│   ├── dsec_infer.yaml
│   └── mvsec_contrast.yaml
├── dataloader
│   ├── common.py
│   ├── contrast
│   │   ├── base.py
│   │   ├── encodings.py
│   │   ├── h5_mvsec.py
│   │   ├── mvsecDataset_no_ann.py
│   │   └── raw_event_utils.py
│   ├── dsceDataset.py
│   ├── dsecProvider_test.py
│   ├── dsecSequence.py
│   ├── eventslicer.py
│   └── representations.py
├── docker
│   ├── Dockerfile
│   ├── README.md
│   ├── docker_build.sh
│   └── docker_run_multi.sh
├── infer_dsec.py
├── infer_dsec.sh
├── models
│   ├── __init__.py
│   ├── __init__.pyc
│   ├── model_large.py
│   ├── submodule.py
│   └── submodule.pyc
├── pre-processing
│   ├── README.md
│   ├── eventslicer.py
│   ├── raw_event_parsing.py
│   ├── voxel_generate_all.py
│   └── voxel_generate_test_all.py
├── requirements.txt
├── resource
│   ├── teaser.png
│   └── temporal_stereo_demo.gif
├── run_tmux.sh
└── utils
    ├── __init__.py
    ├── evaluation
    │   ├── __init__.py
    │   ├── eval.py
    │   ├── flow_eval.py
    │   ├── flow_pixel_error.py
    │   ├── inverse_warp.py
    │   └── pixel_error.py
    ├── evaluation_old
    │   ├── __init__.py
    │   ├── eval.py
    │   ├── flow_eval.py
    │   ├── flow_pixel_error.py
    │   ├── inverse_warp.py
    │   └── pixel_error.py
    ├── flow.py
    ├── flow_vis.py
    ├── iwe.py
    ├── logger.py
    ├── preprocess.py
    ├── readpfm.py
    ├── softsplat.py
    ├── softsplat_prev.py
    ├── visualization.py
    ├── viz.py
    └── warp.py
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # Byte-compiled / optimized / DLL files
3 | __pycache__/
4 | *.py[cod]
5 | *$py.class
6 |
7 | # C extensions
8 | *.so
9 |
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | cover/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | .pybuilder/
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | # For a library or package, you might want to ignore these files since the code is
88 | # intended to run in multiple environments; otherwise, check them in:
89 | # .python-version
90 |
91 | # pipenv
92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
95 | # install all needed dependencies.
96 | #Pipfile.lock
97 |
98 | # poetry
99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
100 | # This is especially recommended for binary packages to ensure reproducibility, and is more
101 | # commonly ignored for libraries.
102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
103 | #poetry.lock
104 |
105 | # pdm
106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
107 | #pdm.lock
108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
109 | # in version control.
110 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
111 | .pdm.toml
112 | .pdm-python
113 | .pdm-build/
114 |
115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116 | __pypackages__/
117 |
118 | # Celery stuff
119 | celerybeat-schedule
120 | celerybeat.pid
121 |
122 | # SageMath parsed files
123 | *.sage.py
124 |
125 | # Environments
126 | .env
127 | .venv
128 | env/
129 | venv/
130 | ENV/
131 | env.bak/
132 | venv.bak/
133 |
134 | # Spyder project settings
135 | .spyderproject
136 | .spyproject
137 |
138 | # Rope project settings
139 | .ropeproject
140 |
141 | # mkdocs documentation
142 | /site
143 |
144 | # mypy
145 | .mypy_cache/
146 | .dmypy.json
147 | dmypy.json
148 |
149 | # Pyre type checker
150 | .pyre/
151 |
152 | # pytype static type analyzer
153 | .pytype/
154 |
155 | # Cython debug symbols
156 | cython_debug/
157 |
158 | # PyCharm
159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161 | # and can be added to the global gitignore or merged into this file. For a more nuclear
162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163 | #.idea/
164 |
--------------------------------------------------------------------------------
/.tmux.conf:
--------------------------------------------------------------------------------
1 | set -g history-limit 999999999
2 | set -g mouse on
3 |
4 | # Shift arrow to switch windows
5 | bind -n S-Left previous-window
6 | bind -n S-Right next-window
7 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Temporal Event Stereo via Joint Learning with Stereoscopic Flow (TESNet) (ECCV2024)
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 | Official code for "Temporal Event Stereo via Joint Learning with Stereoscopic Flow" (ECCV2024)
11 | ([Paper](https://arxiv.org/abs/2407.10831))
12 |
13 |
14 |
15 | ```bibtex
16 | @Article{tes24eccv,
17 | author = {Hoonhee Cho* and Jae-Young Kang* and Kuk-Jin Yoon},
18 | title = {Temporal Event Stereo via Joint Learning with Stereoscopic Flow},
19 |     journal   = {European Conference on Computer Vision (ECCV)},
20 | year = {2024},
21 | }
22 | ```
23 |
24 |
25 |
26 | ## Abstract
27 | Event cameras are dynamic vision sensors inspired by the biological retina, characterized by their high dynamic range, high temporal resolution, and low power consumption. These features make them capable of perceiving 3D environments even in extreme conditions. Event data is continuous across the time dimension, which allows a detailed description of each pixel's movements. To fully utilize the temporally dense and continuous nature of event cameras, we propose a novel temporal event stereo, a framework that continuously uses information from previous time steps. This is accomplished through the simultaneous training of an event stereo matching network alongside stereoscopic flow, a new concept that captures all pixel movements from stereo cameras. Since obtaining ground truth for optical flow during training is challenging, we propose a method that uses only disparity maps to train the stereoscopic flow. The performance of event-based stereo matching is enhanced by temporally aggregating information using the flows. We have achieved state-of-the-art performance on the MVSEC and the DSEC datasets. The method is computationally efficient, as it stacks previous information in a cascading manner.
28 |
29 |
30 | ## Datasets
31 | Please refer to the pre-processing directory ([pre-process](https://github.com/mickeykang16/TemporalEventStereo/tree/master/pre-processing)) for the dataset format and preparation details.
32 |
33 | ## Installation
34 | ### Docker Environment
35 | This project is based on CUDA 11.1, Python 3.8, and PyTorch 1.10.1. Please refer to [docker setup](https://github.com/mickeykang16/TemporalEventStereo/tree/master/docker) for more details.
36 | ### Conda Environment
37 | Coming Soon
38 |
39 |
40 | ## Training
41 | Coming Soon
42 |
43 | ## Testing
44 | We provide a checkpoint and test code for the DSEC dataset.
45 | ```bash
46 | source infer_dsec.sh
47 | ```
--------------------------------------------------------------------------------
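
Note: `infer_dsec.sh` is expected to drive `infer_dsec.py` with `config/dsec_infer.yaml`, which points `load_model` at the provided `checkpoints/dsec_best.tar`. A minimal sketch of reading that config (an illustration only, not the script's actual interface):

```python
import yaml  # PyYAML

# read the inference config shipped with the repo
with open('config/dsec_infer.yaml') as f:
    cfg = yaml.safe_load(f)

print(cfg['model']['load_model'])  # checkpoints/dsec_best.tar
print(cfg['model']['frame_idxs'])  # 'range(-3, 1)' -- YAML keeps this as a string;
                                   # see the note after config/dsec_contrast.yaml
```
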
/checkpoints/dsec_best.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mickeykang16/TemporalEventStereo/d9dc74677f568dbdf2d1d06da20480c50410e274/checkpoints/dsec_best.tar
--------------------------------------------------------------------------------
/config/dsec_contrast.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | epoch: 60 # to run test only, just set epoch to zero
3 | # type: minihourglass_small_v2_3d_of_backwarp_rawcost_small_of_channel
4 | # type: stackhourglass_small_v2_3d_of_backwarp_rawcost_small_of_channel
5 | # type: stackhourglass_3d_of_backwarp_rawcost_small_of_channel
6 | type: stackhourglass_v2
7 | # load_model: ''
8 | load_model: 'exps/new_flow_dsec/dsec_stackhourglass_v2_single/checkpoint_24.tar'
9 | # load_model: 'exps/stackhourglass_3d_of_backwarp_rawcost_small_of_channel_wo_cost_batch2/checkpoint_5.tar'
10 | # do_not_load_layer: ['dres4', 'classif3', 'of_block', 'fusion']
11 | # do_not_load_layer: ['of_block', 'fusion']
12 | load_strict: True
13 | remove_feat_first_weight : False
14 | load_optim: False
15 | pretrain_freeze: False
16 |
17 | maxdisp: 192 # depth for model costvolume
18 | eval_maxdisp: 192 # 255 / 7.0 # depth for evaluation mask
19 | height: 480
20 | # width: 352
21 | width: 640 # use for stackhourglass...
22 | # width: 640 # use for minihourglass_...
23 | # width: 384 # use for unet arch
24 | # frame_idxs: [0, ]
25 | frame_idxs: range(0, 1) # <- this type of expression is allowed
26 | use_prev_gradient: False
27 | dataset: DSEC
28 | split: 1
29 | # dataset_raw_root_path: /mnt2/DSEC_data
30 | dataset_raw_root_path: /home/jaeyoung/data3/dsec
31 | dataset_root_path: /home/jaeyoung/data3/dsec
32 | pseudo_root_path: /home/jaeyoung/data3/DSEC_pseudo_GT_all
33 | random_crop: False
34 | crop_size: [336, 480]
35 |
36 | use_pseudo_gt: True
37 | use_disp_gt_mask: False
38 | use_disp_flow_warp_mask: False
39 | use_featuremetric_loss: False
40 | use_disp_loss: True #default our loss
41 | use_contrast_loss: False
42 | use_stereo_loss: True #default for stereo True
43 |
44 | use_mini_data: False
45 | use_super_mini_data: False
46 |
47 | flow_smooth_weight: 0.1
48 | flow_scale: 1e-8
49 | val_of_viz: False
50 |
51 |
52 | orig_height: 480
53 | orig_width: 640
54 |
55 | use_raw_provider: True
56 | in_ch: 15
57 | delta_t_ms: 100
58 |
59 |
60 | log:
61 | log_train_every_n_batch: 40
62 | save_test_every_n_batch: 1
63 | lr: 0.001
64 | train:
65 | ann_path: view_4_train_v5_split1.json
66 | shuffle: True
67 | gradient_accumulation: 1
68 | batch_size: 4
69 | num_worker: 2
70 | validation:
71 | ann_path: view_4_test_v5_split1_all.json
72 | shuffle: False
73 | batch_size: 4
74 | num_worker: 2
75 | test:
76 | ann_path: view_4_test_v5_split1_all.json
77 | shuffle: False
78 | batch_size: 4
79 | num_worker: 4
--------------------------------------------------------------------------------
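
Note: the `frame_idxs: range(0, 1)` entry above is annotated "this type of expression is allowed"; YAML loads such values as plain strings, so the loader has to evaluate them. A minimal sketch of that parsing (the helper name `parse_frame_idxs` is an assumption, not the repo's actual code):

```python
def parse_frame_idxs(value):
    """Accept an explicit list (e.g. [0]) or a 'range(a, b)' string from YAML."""
    if isinstance(value, str) and value.startswith('range'):
        # evaluate with nothing but `range` in scope (hypothetical safety measure)
        return list(eval(value, {'__builtins__': {}}, {'range': range}))
    return list(value)

assert parse_frame_idxs('range(-3, 1)') == [-3, -2, -1, 0]  # "use four frames"
assert parse_frame_idxs([0]) == [0]
```
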
/config/dsec_infer.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | epoch: 0 # to run test only, just set epoch to zero
3 | type: ours_large
4 | load_model: 'checkpoints/dsec_best.tar'
5 | load_optim: False
6 | load_strict: True
7 | maxdisp: 192 # depth for model costvolume
8 | eval_maxdisp: 192 # evaluation
9 | height: 480
10 | width: 640
11 | frame_idxs: range(-3, 1) # use four frames
12 | use_prev_gradient: False
13 | dataset: DSEC
14 | dataset_raw_root_path: /home/data/dsec
15 | dataset_root_path: /home/data/dsec
16 | pseudo_root_path: /home/data/DSEC_pseudo_GT
17 |
18 | orig_height: 480
19 | orig_width: 640
20 |
21 | # Load both voxel and raw event for contrast maximization loss
22 | use_raw_provider: True
23 | in_ch: 15
24 | delta_t_ms: 50
25 |
26 | log:
27 | log_train_every_n_batch: 40
28 | save_test_every_n_batch: 100
29 | lr: 0.001
30 | train:
31 | shuffle: True
32 | gradient_accumulation: 0
33 | batch_size: 4
34 | num_worker: 16
35 | validation:
36 | shuffle: False
37 | batch_size: 4
38 | num_worker: 4
39 | test:
40 | shuffle: False
41 | batch_size: 1
42 | num_worker: 4
--------------------------------------------------------------------------------
/config/mvsec_contrast.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | epoch: 60 # to run test only, just set epoch to 0
3 | type: minihourglass_small_v6_3d_of_backwarp_rawcost_small_of_entropy_filter_new_flow
4 | load_model: 'exps/mvsec_multiframe_ablation2/split1_minihourglass_small_v6_3d_of_backwarp_rawcost_small_of_entropy_filter_new_flow_8frame3/continue/checkpoint_61.tar'
5 | # do_not_load_layer: ['dres4', 'classif3', 'of_block', 'fusion']
6 | do_not_load_layer: []
7 | load_strict: True
8 | load_optim: False # load optimizer state from checkpoint
9 |
10 | maxdisp: 48 # depth for model costvolume
11 | eval_maxdisp: 36 # depth for evaluation mask
12 | # Size after padding
13 | height: 288
14 | width: 348
15 |
16 | # frame_idxs: [0]
17 | frame_idxs: range(-7, 1) # <- this type of expression is allowed
18 | skip_num: 1
19 |
20 | use_prev_gradient: False
21 | dataset: MVSEC
22 | split: 1
23 | dataset_root_path: /home/jaeyoung/ws/event_stereo_ICCV2019/dataset
24 | use_pseudo_gt: True
25 |
26 | use_disp_loss: True #default our loss
27 |
28 | use_contrast_loss: True
29 | flow_regul_weight: 0.01
30 | contrast_flow_scale: 0.0
31 |
32 | use_stereo_loss: True #default for stereo True
33 |
34 | orig_height: 260
35 | orig_width: 346
36 | log:
37 | log_train_every_n_batch: 100
38 | save_test_every_n_batch: 1
39 | lr: 0.0008
40 | train:
41 | shuffle: True
42 | batch_size: 2
43 | num_worker: 4
44 | validation:
45 | shuffle: False
46 | batch_size: 4
47 | num_worker: 4
48 | test:
49 | shuffle: False
50 | batch_size: 4
51 | num_worker: 4
52 |
--------------------------------------------------------------------------------
/dataloader/common.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | from skimage.color import rgb2hsv, hsv2rgb
4 | from skimage.transform import pyramid_gaussian
5 | from skimage.measure import block_reduce
6 |
7 | import torch
8 |
9 | def _apply(func, x):
10 |
11 | if isinstance(x, (list, tuple)):
12 | return [_apply(func, x_i) for x_i in x]
13 | elif isinstance(x, dict):
14 | y = {}
15 | for key, value in x.items():
16 | y[key] = _apply(func, value)
17 | return y
18 | else:
19 | return func(x)
20 |
21 | def crop(args, ps, py, px): # patch_size
22 | # args = [input, target]
23 | def _get_shape(*args):
24 | if isinstance(args[0], (list, tuple)):
25 | return _get_shape(args[0][0])
26 | elif isinstance(args[0], dict):
27 | return _get_shape(list(args[0].values())[0])
28 | else:
29 | return args[0].shape
30 |
31 | h, w, _ = _get_shape(args)
32 | # print(_get_shape(args))
33 | # print(ps[1])
34 | # import pdb; pdb.set_trace()
35 |
36 | # py = random.randrange(0, h-ps+1)
37 | # px = random.randrange(0, w-ps+1)
38 |
39 | def _crop(img):
40 | if img.ndim == 2:
41 | return img[py:py+ps[1], px:px+ps[0], np.newaxis]
42 | else:
43 | return img[py:py+ps[1], px:px+ps[0], :]
44 |
45 | return _apply(_crop, args)
46 |
47 | def crop_with_event(*args, left_event, right_event, ps=256): # patch_size
48 | # args = [input, target]
49 | def _get_shape(*args):
50 | if isinstance(args[0], (list, tuple)):
51 | return _get_shape(args[0][0])
52 | elif isinstance(args[0], dict):
53 | return _get_shape(list(args[0].values())[0])
54 | else:
55 | return args[0].shape
56 |
57 | h, w, _ = _get_shape(args)
58 |
59 | py = random.randrange(0, h-ps+1)
60 | px = random.randrange(0, w-ps+1)
61 |
62 | def _crop(img):
63 | if img.ndim == 2:
64 | return img[py:py+ps, px:px+ps, np.newaxis]
65 | else:
66 | return img[py:py+ps, px:px+ps, :]
67 |
68 | def _event_crop(event):
69 | return event[:, py:py+ps, px:px+ps]
70 |
71 | return _apply(_crop, args), _apply(_event_crop, left_event), _apply(_event_crop, right_event)
72 |
73 | def crop_event(args, ps, py, px): # patch_size
74 | # args = [input, target]
75 | def _get_shape(*args):
76 | if isinstance(args[0], (list, tuple)):
77 | return _get_shape(args[0][0])
78 | elif isinstance(args[0], dict):
79 | return _get_shape(list(args[0].values())[0])
80 | else:
81 | return args[0].shape
82 |
83 | _, h, w = _get_shape(args)
84 |
85 | # py = random.randrange(0, h-ps+1)
86 | # px = random.randrange(0, w-ps+1)
87 |
88 | def _crop(img):
89 | if img.ndim == 2:
90 | return img[py:py+ps, px:px+ps, np.newaxis]
91 | else:
92 | return img[py:py+ps, px:px+ps, :]
93 |
94 | def _event_crop(event):
95 | return event[:, py:py+ps[1], px:px+ps[0]]
96 |
97 | return _apply(_event_crop, args)
98 |
99 |
100 | def crop_disp(args, ps, py, px): # patch_size
101 | # args = [input, target]
102 | def _get_shape(*args):
103 | if isinstance(args[0], (list, tuple)):
104 | return _get_shape(args[0][0])
105 | elif isinstance(args[0], dict):
106 | return _get_shape(list(args[0].values())[0])
107 | else:
108 | return args[0].shape
109 |
110 | h, w = _get_shape(args)
111 |
112 | # py = random.randrange(0, h-ps+1)
113 | # px = random.randrange(0, w-ps+1)
114 |
115 |
116 | def _crop(img):
117 | if img.ndim == 2:
118 | return img[py:py+ps[1], px:px+ps[0], np.newaxis]
119 | else:
120 | return img[py:py+ps[1], px:px+ps[0], :]
121 |
122 | def _disp_crop(event):
123 | return event[py:py+ps[1], px:px+ps[0]]
124 |
125 | return _apply(_disp_crop, args)
126 |
127 | def add_noise(*args, sigma_sigma=2, rgb_range=255):
128 |
129 | if len(args) == 1: # usually there is only a single input
130 | args = args[0]
131 |
132 | sigma = np.random.normal() * sigma_sigma * rgb_range/255
133 |
134 | def _add_noise(img):
135 | noise = np.random.randn(*img.shape).astype(np.float32) * sigma
136 | return (img + noise).clip(0, rgb_range)
137 |
138 | return _apply(_add_noise, args)
139 |
140 | def augment(*args, hflip=True, rot=True, shuffle=True, change_saturation=True, rgb_range=255):
141 | """augmentation consistent to input and target"""
142 |
143 | choices = (False, True)
144 |
145 | hflip = hflip and random.choice(choices)
146 | vflip = rot and random.choice(choices)
147 | rot90 = rot and random.choice(choices)
148 | # shuffle = shuffle
149 |
150 | if shuffle:
151 | rgb_order = list(range(3))
152 | random.shuffle(rgb_order)
153 | if rgb_order == list(range(3)):
154 | shuffle = False
155 |
156 | if change_saturation:
157 | amp_factor = np.random.uniform(0.5, 1.5)
158 |
159 | def _augment(img):
160 | if hflip: img = img[:, ::-1, :]
161 | if vflip: img = img[::-1, :, :]
162 | if rot90: img = img.transpose(1, 0, 2)
163 | if shuffle and img.ndim > 2:
164 | if img.shape[-1] == 3: # RGB image only
165 | img = img[..., rgb_order]
166 |
167 | if change_saturation:
168 | hsv_img = rgb2hsv(img)
169 | hsv_img[..., 1] *= amp_factor
170 |
171 | img = hsv2rgb(hsv_img).clip(0, 1) * rgb_range
172 |
173 | return img.astype(np.float32)
174 |
175 | return _apply(_augment, args)
176 |
177 | def pad(img, divisor=4, pad_width=None, negative=False):
178 |
179 | def _pad_numpy(img, divisor=4, pad_width=None, negative=False):
180 | if pad_width is None:
181 | (h, w, _) = img.shape
182 | pad_h = -h % divisor
183 | pad_w = -w % divisor
184 | pad_width = ((0, pad_h), (0, pad_w), (0, 0))
185 |
186 | img = np.pad(img, pad_width, mode='edge')
187 |
188 | return img, pad_width
189 |
190 | def _pad_tensor(img, divisor=4, pad_width=None, negative=False):
191 |
192 | n, c, h, w = img.shape
193 | if pad_width is None:
194 | pad_h = -h % divisor
195 | pad_w = -w % divisor
196 | pad_width = (0, pad_w, 0, pad_h)
197 | else:
198 | try:
199 | pad_h = pad_width[0][1]
200 | pad_w = pad_width[1][1]
201 | if isinstance(pad_h, torch.Tensor):
202 | pad_h = pad_h.item()
203 | if isinstance(pad_w, torch.Tensor):
204 | pad_w = pad_w.item()
205 |
206 | pad_width = (0, pad_w, 0, pad_h)
207 |             except Exception:
208 | pass
209 |
210 | if negative:
211 | pad_width = [-val for val in pad_width]
212 |
213 | img = torch.nn.functional.pad(img, pad_width, 'reflect')
214 |
215 | return img, pad_width
216 |
217 | if isinstance(img, np.ndarray):
218 | return _pad_numpy(img, divisor, pad_width, negative)
219 | else: # torch.Tensor
220 | return _pad_tensor(img, divisor, pad_width, negative)
221 |
222 |
223 | def disp_pad(img, divisor=4, pad_width=None, negative=False):
224 |
225 | def _pad_numpy(img, divisor=4, pad_width=None, negative=False):
226 | if pad_width is None:
227 | (h, w, _) = img.shape
228 | pad_h = -h % divisor
229 | pad_w = -w % divisor
230 | pad_width = ((0, pad_h), (0, pad_w), (0, 0))
231 |
232 | img = np.pad(img, pad_width, mode='edge')
233 |
234 | return img, pad_width
235 |
236 | def _pad_tensor(img, divisor=4, pad_width=None, negative=False):
237 |
238 | if isinstance(img, list):
239 | img = img[0].unsqueeze(1)
240 | n, c, h, w = img.shape
241 | else:
242 | img = img.unsqueeze(1)
243 | n, c, h, w = img.shape
244 | if pad_width is None:
245 | pad_h = -h % divisor
246 | pad_w = -w % divisor
247 | pad_width = (0, pad_w, 0, pad_h)
248 |
249 |
250 | else:
251 | try:
252 | # import pdb; pdb.set_trace()
253 | pad_h = pad_width[0][1][0]
254 | pad_w = pad_width[1][1][0]
255 | if isinstance(pad_h, torch.Tensor):
256 | pad_h = pad_h.item()
257 | if isinstance(pad_w, torch.Tensor):
258 | pad_w = pad_w.item()
259 |
260 | pad_width = (0, pad_w, 0, pad_h)
261 |             except Exception:
262 | pass
263 |
264 | if negative:
265 | pad_width = [-val for val in pad_width]
266 |
267 |
268 | img = torch.nn.functional.pad(img, pad_width, 'reflect')
269 |
270 |
271 | return img
272 |
273 | if isinstance(img, np.ndarray):
274 | return _pad_numpy(img, divisor, pad_width, negative)
275 | else: # torch.Tensor
276 | return _pad_tensor(img, divisor, pad_width, negative)
277 |
278 | def event_pad(img, divisor=4, pad_width=None, negative=False):
279 |
280 | def _pad_numpy(img, divisor=4, pad_width=None, negative=False):
281 |
282 | if pad_width is None:
283 | (_, h, w) = img.shape
284 | pad_h = -h % divisor
285 | pad_w = -w % divisor
286 | pad_width = ((0, 0), (0, pad_h), (0, pad_w))
287 |
288 | img = np.pad(img, pad_width, mode='edge')
289 |
290 | return img, pad_width
291 |
292 | def _pad_tensor(img, divisor=4, pad_width=None, negative=False):
293 |
294 | n, c, h, w = img.shape
295 | if pad_width is None:
296 | pad_h = -h % divisor
297 | pad_w = -w % divisor
298 | pad_width = (0, pad_w, 0, pad_h)
299 | else:
300 | try:
301 | pad_h = pad_width[0][1]
302 | pad_w = pad_width[1][1]
303 | if isinstance(pad_h, torch.Tensor):
304 | pad_h = pad_h.item()
305 | if isinstance(pad_w, torch.Tensor):
306 | pad_w = pad_w.item()
307 |
308 | pad_width = (0, pad_w, 0, pad_h)
309 |             except Exception:
310 | pass
311 |
312 | if negative:
313 | pad_width = [-val for val in pad_width]
314 |
315 | img = torch.nn.functional.pad(img, pad_width, 'reflect')
316 |
317 | return img, pad_width
318 |
319 | if isinstance(img, np.ndarray):
320 | return _pad_numpy(img, divisor, pad_width, negative)
321 | else: # torch.Tensor
322 | return _pad_tensor(img, divisor, pad_width, negative)
323 |
324 | def generate_pyramid(*args, n_scales):
325 |
326 | def _generate_pyramid(img):
327 | if img.dtype != np.float32:
328 | img = img.astype(np.float32)
329 | pyramid = list(pyramid_gaussian(img, n_scales-1, multichannel=True))
330 |
331 | return pyramid
332 |
333 | return _apply(_generate_pyramid, args)
334 |
335 | def generate_event_pyramid(*args, n_scales):
336 | def _generate_event_pyramid(event):
337 | event = np.array(event)
338 |
339 | if event.dtype != np.float32:
340 | event = event.astype(np.float32)
341 |
342 | # import pdb
343 | # pdb.set_trace()
344 | # print("len event")
345 | # print(event)
346 | # print([len(a) for a in event])
347 |
348 | # print(event.shape)
349 | pyramid = []
350 | for i in range(n_scales):
351 | w, h, c = event.shape
352 | # scale_event = block_reduce(event, (w//(2**i) , h//(2**i), c), np.max)
353 | scale_event = block_reduce(event, (1, 2**i, 2**i), np.max)
354 | # print(scale_event)
355 | # print(event)
356 | pyramid.append(scale_event)
357 | # if i == 2:
358 | # print(scale_event.shape)
359 | # import pdb
360 | # pdb.set_trace()
361 |
362 |
363 | # pyramid = list(pyramid_gaussian(img, n_scales-1, multichannel=True))
364 | return pyramid
365 | return _generate_event_pyramid(args)
366 |
367 |
368 | def np2tensor(*args):
369 | def _np2tensor(x):
370 | np_transpose = np.ascontiguousarray(x.transpose(2, 0, 1))
371 | tensor = torch.from_numpy(np_transpose)
372 |
373 | return tensor
374 |
375 | return _apply(_np2tensor, args)
376 |
377 | def image2tensor(x):
378 | np_transpose = np.ascontiguousarray(x.transpose(2, 0, 1))
379 | tensor = torch.from_numpy(np_transpose)
380 |
381 | return tensor
382 |
383 | def event2tensor(*args):
384 | def _np2tensor(x):
385 | # np_transpose = np.ascontiguousarray(x.transpose(2, 0, 1))
386 | tensor = torch.from_numpy(x)
387 | return tensor
388 |
389 |
390 | return _apply(_np2tensor, args)
391 |
392 | def to(*args, device=None, dtype=torch.float):
393 |
394 | def _to(x):
395 | return x.to(device=device, dtype=dtype, non_blocking=True, copy=False)
396 |
397 | return _apply(_to, args)
--------------------------------------------------------------------------------
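
Note: every `pad` helper in `common.py` relies on the `-h % divisor` idiom, which yields the smallest non-negative padding that rounds a size up to a multiple of `divisor`. A quick self-contained check using the MVSEC sensor resolution (260x346) from `config/mvsec_contrast.yaml`:

```python
import numpy as np

h, w, divisor = 260, 346, 4
pad_h, pad_w = -h % divisor, -w % divisor   # 0 and 2
assert (h + pad_h) % divisor == 0 and (w + pad_w) % divisor == 0

img = np.zeros((h, w, 1), dtype=np.float32)
padded = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), mode='edge')
assert padded.shape == (260, 348, 1)        # 348 matches the config's padded width
```
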
/dataloader/contrast/base.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 |
3 | import numpy as np
4 | import random
5 | import torch
6 |
7 | from .encodings import events_to_voxel, events_to_channels, events_to_mask, get_hot_event_mask
8 |
9 |
10 | class BaseDataLoader(torch.utils.data.Dataset):
11 | """
12 | Base class for dataloader.
13 | """
14 |
15 | def __init__(self, config, num_bins):
16 | self.config = config
17 | self.epoch = 0
18 | self.seq_num = 0
19 | self.samples = 0
20 | self.new_seq = False
21 | self.tc_idx = 0
22 | self.num_bins = num_bins
23 |
24 | # batch-specific data augmentation mechanisms
25 | self.batch_augmentation = {}
26 | for mechanism in self.config["loader"]["augment"]:
27 | if mechanism != "Pause":
28 | self.batch_augmentation[mechanism] = [False for i in range(self.config["loader"]["batch_size"])]
29 | else:
30 | self.batch_augmentation[mechanism] = False # shared among batch elements
31 |
32 | for i, mechanism in enumerate(self.config["loader"]["augment"]):
33 | if mechanism != "Pause":
34 | for batch in range(self.config["loader"]["batch_size"]):
35 | if np.random.random() < self.config["loader"]["augment_prob"][i]:
36 | self.batch_augmentation[mechanism][batch] = True
37 |
38 | # hot pixels
39 | if self.config["hot_filter"]["enabled"]:
40 | self.hot_idx = [0 for i in range(self.config["loader"]["batch_size"])]
41 | self.hot_events = [
42 | torch.zeros(self.config["loader"]["resolution"]) for i in range(self.config["loader"]["batch_size"])
43 | ]
44 |
45 | @abstractmethod
46 | def __getitem__(self, index):
47 | raise NotImplementedError
48 |
49 | @abstractmethod
50 | def get_events(self, history):
51 | raise NotImplementedError
52 |
53 | def reset_sequence(self, batch):
54 | """
55 | Reset sequence-specific variables.
56 | :param batch: batch index
57 | """
58 |
59 | self.tc_idx = 0
60 | self.seq_num += 1
61 | if self.config["hot_filter"]["enabled"]:
62 | self.hot_idx[batch] = 0
63 | self.hot_events[batch] = torch.zeros(self.config["loader"]["resolution"])
64 |
65 | for i, mechanism in enumerate(self.config["loader"]["augment"]):
66 | if mechanism != "Pause":
67 | if np.random.random() < self.config["loader"]["augment_prob"][i]:
68 | self.batch_augmentation[mechanism][batch] = True
69 | else:
70 | self.batch_augmentation[mechanism][batch] = False
71 | else:
72 | self.batch_augmentation[mechanism] = False
73 |
74 | @staticmethod
75 | def event_formatting(xs, ys, ts, ps):
76 | """
77 |         Convert raw event arrays to tensors and normalize event timestamps to [0, 1].
78 | :param xs: [N] numpy array with event x location
79 | :param ys: [N] numpy array with event y location
80 | :param ts: [N] numpy array with event timestamp
81 | :param ps: [N] numpy array with event polarity ([-1, 1])
82 | :return xs: [N] tensor with event x location
83 | :return ys: [N] tensor with event y location
84 | :return ts: [N] tensor with normalized event timestamp
85 | :return ps: [N] tensor with event polarity ([-1, 1])
86 | """
87 |
88 | xs = torch.from_numpy(xs.astype(np.float32))
89 | ys = torch.from_numpy(ys.astype(np.float32))
90 | ts = torch.from_numpy(ts.astype(np.float32))
91 | ps = torch.from_numpy(ps.astype(np.float32))
92 | ts = (ts - ts[0]) / (ts[-1] - ts[0])
93 | return xs, ys, ts, ps
94 |
95 | def augment_events(self, xs, ys, ps, batch):
96 | """
97 | Augment event sequence with horizontal, vertical, and polarity flips, and
98 | artificial event pauses.
99 |         :param xs: [N] tensor with event x location
100 |         :param ys: [N] tensor with event y location
101 |         :param ps: [N] tensor with event polarity ([-1, 1])
102 | :param batch: batch index
103 | :return xs: [N] tensor with augmented event x location
104 | :return ys: [N] tensor with augmented event y location
105 | :return ps: [N] tensor with augmented event polarity ([-1, 1])
106 | """
107 |
108 | for i, mechanism in enumerate(self.config["loader"]["augment"]):
109 |
110 | if mechanism == "Horizontal":
111 | if self.batch_augmentation["Horizontal"][batch]:
112 | xs = self.config["loader"]["resolution"][1] - 1 - xs
113 |
114 | elif mechanism == "Vertical":
115 | if self.batch_augmentation["Vertical"][batch]:
116 | ys = self.config["loader"]["resolution"][0] - 1 - ys
117 |
118 | elif mechanism == "Polarity":
119 | if self.batch_augmentation["Polarity"][batch]:
120 | ps *= -1
121 |
122 | # shared among batch elements
123 | elif (
124 | batch == 0
125 | and mechanism == "Pause"
126 | and self.tc_idx > self.config["loss"]["reconstruction_tc_idx_threshold"]
127 | ):
128 | if not self.batch_augmentation["Pause"]:
129 | if np.random.random() < self.config["loader"]["augment_prob"][i][0]:
130 | self.batch_augmentation["Pause"] = True
131 | else:
132 | if np.random.random() < self.config["loader"]["augment_prob"][i][1]:
133 | self.batch_augmentation["Pause"] = False
134 |
135 | return xs, ys, ps
136 |
137 | def augment_frames(self, img, batch):
138 | """
139 | Augment APS frame with horizontal and vertical flips.
140 | :param img: [H x W] numpy array with APS intensity
141 | :param batch: batch index
142 | :return img: [H x W] augmented numpy array with APS intensity
143 | """
144 | if "Horizontal" in self.batch_augmentation:
145 | if self.batch_augmentation["Horizontal"][batch]:
146 | img = np.flip(img, 1)
147 | if "Vertical" in self.batch_augmentation:
148 | if self.batch_augmentation["Vertical"][batch]:
149 | img = np.flip(img, 0)
150 | return img
151 |
152 | def create_cnt_encoding(self, xs, ys, ts, ps):
153 | """
154 | Creates a per-pixel and per-polarity event count representation.
155 | :param xs: [N] tensor with event x location
156 | :param ys: [N] tensor with event y location
157 | :param ts: [N] tensor with normalized event timestamp
158 | :param ps: [N] tensor with event polarity ([-1, 1])
159 | :return [2 x H x W] event representation
160 | """
161 |
162 | return events_to_channels(xs, ys, ps, sensor_size=self.config["loader"]["resolution"])
163 |
164 | def create_voxel_encoding(self, xs, ys, ts, ps):
165 | """
166 | Creates a spatiotemporal voxel grid tensor representation with a certain number of bins,
167 | as described in Section 3.1 of the paper 'Unsupervised Event-based Learning of Optical Flow,
168 | Depth, and Egomotion', Zhu et al., CVPR'19..
169 | Events are distributed to the spatiotemporal closest bins through bilinear interpolation.
170 | Positive events are added as +1, while negative as -1.
171 | :param xs: [N] tensor with event x location
172 | :param ys: [N] tensor with event y location
173 | :param ts: [N] tensor with normalized event timestamp
174 | :param ps: [N] tensor with event polarity ([-1, 1])
175 | :return [B x H x W] event representation
176 | """
177 |
178 | return events_to_voxel(
179 | xs,
180 | ys,
181 | ts,
182 | ps,
183 | self.num_bins,
184 | sensor_size=self.config["loader"]["resolution"],
185 | )
186 |
187 | @staticmethod
188 | def create_list_encoding(xs, ys, ts, ps):
189 | """
190 | Creates a four channel tensor with all the events in the input partition.
191 | :param xs: [N] tensor with event x location
192 | :param ys: [N] tensor with event y location
193 | :param ts: [N] tensor with normalized event timestamp
194 | :param ps: [N] tensor with event polarity ([-1, 1])
195 | :return [N x 4] event representation
196 | """
197 |
198 | return torch.stack([ts, ys, xs, ps])
199 |
200 | @staticmethod
201 | def create_polarity_mask(ps):
202 | """
203 | Creates a two channel tensor that acts as a mask for the input event list.
204 | :param ps: [N] tensor with event polarity ([-1, 1])
205 | :return [N x 2] event representation
206 | """
207 |
208 | inp_pol_mask = torch.stack([ps, ps])
209 | inp_pol_mask[0, :][inp_pol_mask[0, :] < 0] = 0
210 | inp_pol_mask[1, :][inp_pol_mask[1, :] > 0] = 0
211 | inp_pol_mask[1, :] *= -1
212 | return inp_pol_mask
213 |
214 | def create_hot_mask(self, xs, ys, ps, batch):
215 | """
216 | Creates a one channel tensor that can act as mask to remove pixel with high event rate.
217 | :param xs: [N] tensor with event x location
218 | :param ys: [N] tensor with event y location
219 | :param ps: [N] tensor with event polarity ([-1, 1])
220 | :return [H x W] binary mask
221 | """
222 |
223 | hot_update = events_to_mask(xs, ys, ps, sensor_size=self.hot_events[batch].shape)
224 | self.hot_events[batch] += hot_update
225 | self.hot_idx[batch] += 1
226 | event_rate = self.hot_events[batch] / self.hot_idx[batch]
227 | return get_hot_event_mask(
228 | event_rate,
229 | self.hot_idx[batch],
230 | max_px=self.config["hot_filter"]["max_px"],
231 | min_obvs=self.config["hot_filter"]["min_obvs"],
232 | max_rate=self.config["hot_filter"]["max_rate"],
233 | )
234 |
235 | # def __len__(self):
236 | # return 1000 # not used
237 |
238 | @staticmethod
239 | def custom_collate(batch):
240 | """
241 | Collects the different event representations and stores them together in a dictionary.
242 | """
243 | # breakpoint()
244 | batch_dict = {}
245 | for key in batch[0].keys():
246 | batch_dict[key] = []
247 | for entry in batch:
248 | for key in entry.keys():
249 | batch_dict[key].append(entry[key])
250 | for key in batch_dict.keys():
251 | if len(batch_dict[key][0].shape) > 2:
252 | item = torch.stack(batch_dict[key])
253 | # if len(item.shape) == 3:
254 | # item = item.transpose(2, 1)
255 | batch_dict[key] = item
256 | else:
257 | batch_dict[key] = [e_list.transpose(0, 1) for e_list in batch_dict[key]]
258 | return batch_dict
259 |
260 | # def shuffle(self, flag=True):
261 | # """
262 | # Shuffles the training data.
263 | # """
264 |
265 | # if flag:
266 | # random.shuffle(self.files)
267 |
--------------------------------------------------------------------------------
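
Note: to make `create_polarity_mask` concrete, here is a toy run of the same operations on an invented polarity vector; channel 0 flags positive events and channel 1 flags negative events:

```python
import torch

ps = torch.tensor([1., -1., 1., -1.])            # signed polarities
inp_pol_mask = torch.stack([ps, ps])
inp_pol_mask[0, :][inp_pol_mask[0, :] < 0] = 0   # keep positives in channel 0
inp_pol_mask[1, :][inp_pol_mask[1, :] > 0] = 0   # keep negatives in channel 1...
inp_pol_mask[1, :] *= -1                         # ...as +1 indicators
print(inp_pol_mask)
# tensor([[1., 0., 1., 0.],
#         [0., 1., 0., 1.]])
```
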
/dataloader/contrast/encodings.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from Monash University https://github.com/TimoStoff/events_contrast_maximization
3 | """
4 |
5 | import numpy as np
6 | import torch
7 |
8 |
9 | def binary_search_array(array, x, left=None, right=None, side="left"):
10 | """
11 | Binary search through a sorted array.
12 | """
13 |
14 | left = 0 if left is None else left
15 | right = len(array) - 1 if right is None else right
16 | mid = left + (right - left) // 2
17 |
18 | if left > right:
19 | return left if side == "left" else right
20 |
21 | if array[mid] == x:
22 | return mid
23 |
24 | if x < array[mid]:
25 |         return binary_search_array(array, x, left=left, right=mid - 1, side=side)
26 |
27 |     return binary_search_array(array, x, left=mid + 1, right=right, side=side)
28 |
29 |
30 | def events_to_mask(xs, ys, ps, sensor_size=(180, 240)):
31 | """
32 | Accumulate events into a binary mask.
33 | """
34 |
35 | device = xs.device
36 | img_size = list(sensor_size)
37 | mask = torch.zeros(img_size).to(device)
38 |
39 | if xs.dtype is not torch.long:
40 | xs = xs.long().to(device)
41 | if ys.dtype is not torch.long:
42 | ys = ys.long().to(device)
43 | mask.index_put_((ys, xs), ps.abs(), accumulate=False)
44 |
45 | return mask
46 |
47 |
48 | def events_to_image(xs, ys, ps, sensor_size=(180, 240)):
49 | """
50 | Accumulate events into an image.
51 | """
52 |
53 | device = xs.device
54 | img_size = list(sensor_size)
55 | img = torch.zeros(img_size).to(device)
56 |
57 | if xs.dtype is not torch.long:
58 | xs = xs.long().to(device)
59 | if ys.dtype is not torch.long:
60 | ys = ys.long().to(device)
61 | img.index_put_((ys, xs), ps, accumulate=True)
62 |
63 | return img
64 |
65 |
66 | def events_to_voxel(xs, ys, ts, ps, num_bins, sensor_size=(180, 240)):
67 | """
68 | Generate a voxel grid from input events using temporal bilinear interpolation.
69 | """
70 |
71 | assert len(xs) == len(ys) and len(ys) == len(ts) and len(ts) == len(ps)
72 |
73 | voxel = []
74 | ts = ts * (num_bins - 1)
75 | zeros = torch.zeros(ts.size())
76 | for b_idx in range(num_bins):
77 | weights = torch.max(zeros, 1.0 - torch.abs(ts - b_idx))
78 | voxel_bin = events_to_image(xs, ys, ps * weights, sensor_size=sensor_size)
79 | voxel.append(voxel_bin)
80 |
81 | return torch.stack(voxel)
82 |
83 |
84 | def events_to_channels(xs, ys, ps, sensor_size=(180, 240)):
85 | """
86 | Generate a two-channel event image containing event counters.
87 | """
88 |
89 | assert len(xs) == len(ys) and len(ys) == len(ps)
90 |
91 | mask_pos = ps.clone()
92 | mask_neg = ps.clone()
93 | mask_pos[ps < 0] = 0
94 | mask_neg[ps > 0] = 0
95 |
96 | pos_cnt = events_to_image(xs, ys, ps * mask_pos, sensor_size=sensor_size)
97 | neg_cnt = events_to_image(xs, ys, ps * mask_neg, sensor_size=sensor_size)
98 |
99 | return torch.stack([pos_cnt, neg_cnt])
100 |
101 |
102 | def get_hot_event_mask(event_rate, idx, max_px=100, min_obvs=5, max_rate=0.8):
103 | """
104 | Returns binary mask to remove events from hot pixels.
105 | """
106 |
107 | mask = torch.ones(event_rate.shape).to(event_rate.device)
108 | if idx > min_obvs:
109 | for i in range(max_px):
110 | argmax = torch.argmax(event_rate)
111 | index = (argmax // event_rate.shape[1], argmax % event_rate.shape[1])
112 | if event_rate[index] > max_rate:
113 | event_rate[index] = 0
114 | mask[index] = 0
115 | else:
116 | break
117 | return mask
118 |
--------------------------------------------------------------------------------
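
Note: a tiny usage sketch of `events_to_voxel` to make the temporal bilinear binning concrete (event values invented):

```python
import torch
# from dataloader.contrast.encodings import events_to_voxel

xs = torch.tensor([0., 1., 1.])
ys = torch.tensor([0., 0., 1.])
ts = torch.tensor([0.0, 0.5, 1.0])   # timestamps already normalized to [0, 1]
ps = torch.tensor([1., -1., 1.])

voxel = events_to_voxel(xs, ys, ts, ps, num_bins=5, sensor_size=(2, 2))
print(voxel.shape)                   # torch.Size([5, 2, 2])
# ts * (num_bins - 1) = [0, 2, 4], so each event lands exactly on one bin;
# e.g. voxel[2, 0, 1] == -1 for the middle (negative) event at (x=1, y=0)
```
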
/dataloader/contrast/raw_event_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import h5py
3 | import numpy as np
4 |
5 | import torch
6 | import torch.utils.data as data
7 |
8 | # from .utils import ProgressBar
9 |
10 | # from .encodings import binary_search_array
11 | from .encodings import events_to_voxel, events_to_channels, events_to_mask, get_hot_event_mask
12 | from torch.utils.data.dataloader import default_collate
13 |
14 |
15 | class EventSequenceLoader():
16 | def __init__(self, config):
17 |
18 | self.root_path = config.get("root", '')
19 | self.resolution = config['resolution']
20 | self.num_bins = 5
21 |
22 | def get_events(self, path):
23 | """
24 |         Load all the events stored in a .npy event file.
25 |         :param path: path to the .npy file to read from; each row holds one
26 |                      event ordered as (t, x, y, p), matching the column
27 |                      indexing below
28 | :return xs: [N] numpy array with event x location
29 | :return ys: [N] numpy array with event y location
30 | :return ts: [N] numpy array with event timestamp
31 | :return ps: [N] numpy array with event polarity ([-1, 1])
32 | """
33 | file = np.load(path)
34 | xs = file[:, 1]
35 | ys = file[:, 2]
36 | ts = file[:, 0]
37 | ps = file[:, 3]
38 | ts -= file[0][0] # sequence starting at t0 = 0
39 | ts *= 1.0e6 # us
40 | return xs, ys, ts, ps
41 |
42 | @staticmethod
43 | def event_formatting(xs, ys, ts, ps):
44 | """
45 |         Convert raw event arrays to tensors, remap {0, 1} polarity to {-1, 1}, and normalize timestamps to [0, 1].
46 | :param xs: [N] numpy array with event x location
47 | :param ys: [N] numpy array with event y location
48 | :param ts: [N] numpy array with event timestamp
49 | :param ps: [N] numpy array with event polarity ([-1, 1])
50 | :return xs: [N] tensor with event x location
51 | :return ys: [N] tensor with event y location
52 | :return ts: [N] tensor with normalized event timestamp
53 | :return ps: [N] tensor with event polarity ([-1, 1])
54 | """
55 |
56 | xs = torch.from_numpy(xs.astype(np.float32))
57 | ys = torch.from_numpy(ys.astype(np.float32))
58 | ts = torch.from_numpy(ts.astype(np.float32))
59 | ps = torch.from_numpy(ps.astype(np.float32))
60 | if int(ps.min()) == 0:
61 | ps = 2*ps-1
62 | ts = (ts - ts[0]) / (ts[-1] - ts[0])
63 | return xs, ys, ts, ps
64 |
65 |
66 | def create_cnt_encoding(self, xs, ys, ts, ps):
67 | """
68 | Creates a per-pixel and per-polarity event count representation.
69 | :param xs: [N] tensor with event x location
70 | :param ys: [N] tensor with event y location
71 | :param ts: [N] tensor with normalized event timestamp
72 | :param ps: [N] tensor with event polarity ([-1, 1])
73 | :return [2 x H x W] event representation
74 | """
75 |
76 | return events_to_channels(xs, ys, ps, sensor_size=self.resolution)
77 |
78 | def create_voxel_encoding(self, xs, ys, ts, ps):
79 | """
80 | Creates a spatiotemporal voxel grid tensor representation with a certain number of bins,
81 | as described in Section 3.1 of the paper 'Unsupervised Event-based Learning of Optical Flow,
82 | Depth, and Egomotion', Zhu et al., CVPR'19..
83 | Events are distributed to the spatiotemporal closest bins through bilinear interpolation.
84 | Positive events are added as +1, while negative as -1.
85 | :param xs: [N] tensor with event x location
86 | :param ys: [N] tensor with event y location
87 | :param ts: [N] tensor with normalized event timestamp
88 | :param ps: [N] tensor with event polarity ([-1, 1])
89 | :return [B x H x W] event representation
90 | """
91 |
92 | return events_to_voxel(
93 | xs,
94 | ys,
95 | ts,
96 | ps,
97 | self.num_bins,
98 | sensor_size=self.resolution,
99 | )
100 |
101 | @staticmethod
102 | def create_list_encoding(xs, ys, ts, ps):
103 | """
104 | Creates a four channel tensor with all the events in the input partition.
105 | :param xs: [N] tensor with event x location
106 | :param ys: [N] tensor with event y location
107 | :param ts: [N] tensor with normalized event timestamp
108 | :param ps: [N] tensor with event polarity ([-1, 1])
109 | :return [N x 4] event representation
110 | """
111 |
112 | return torch.stack([ts, ys, xs, ps])
113 |
114 | @staticmethod
115 | def create_polarity_mask(ps):
116 | """
117 | Creates a two channel tensor that acts as a mask for the input event list.
118 | :param ps: [N] tensor with event polarity ([-1, 1])
119 | :return [N x 2] event representation
120 | """
121 |
122 | inp_pol_mask = torch.stack([ps, ps])
123 | inp_pol_mask[0, :][inp_pol_mask[0, :] < 0] = 0
124 | inp_pol_mask[1, :][inp_pol_mask[1, :] > 0] = 0
125 | inp_pol_mask[1, :] *= -1
126 | return inp_pol_mask
127 |
128 | @staticmethod
129 | def event_padding(xs, ys, ts, ps, pad):
130 | (pad_left, pad_right, pad_top, pad_bottom) = pad
131 | return xs+pad_left, ys+pad_top, ts, ps
132 |
133 | def get_from_data(self, xs, ys, ts, ps, pad):
134 |
135 | # xs, ys, ts, ps = self.get_events(path)
136 |
137 | # timestamp normalization
138 | xs, ys, ts, ps = self.event_formatting(xs, ys, ts, ps)
139 | # padding
140 | # xs, ys, ts, ps = self.event_padding(xs, ys, ts, ps, pad)
141 |
142 | # data augmentation
143 | # xs, ys, ps = self.augment_events(xs, ys, ps, batch)
144 |
145 | # artificial pauses to the event stream
146 | # if "Pause" in self.config["loader"]["augment"]:
147 | # if self.batch_augmentation["Pause"]:
148 | # xs = torch.from_numpy(np.empty([0]).astype(np.float32))
149 | # ys = torch.from_numpy(np.empty([0]).astype(np.float32))
150 | # ts = torch.from_numpy(np.empty([0]).astype(np.float32))
151 | # ps = torch.from_numpy(np.empty([0]).astype(np.float32))
152 |
153 | # events to tensors
154 | # inp_cnt = self.create_cnt_encoding(xs, ys, ts, ps)
155 | # inp_voxel = self.create_voxel_encoding(xs, ys, ts, ps)
156 | inp_list = self.create_list_encoding(xs, ys, ts, ps)
157 | inp_pol_mask = self.create_polarity_mask(ps)
158 |
159 | # hot pixel removal
160 | # if self.config["hot_filter"]["enabled"]:
161 | # hot_mask = self.create_hot_mask(xs, ys, ps, batch)
162 | # hot_mask_voxel = torch.stack([hot_mask] * self.num_bins, axis=2).permute(2, 0, 1)
163 | # hot_mask_cnt = torch.stack([hot_mask] * 2, axis=2).permute(2, 0, 1)
164 | # inp_voxel = inp_voxel * hot_mask_voxel
165 | # inp_cnt = inp_cnt * hot_mask_cnt
166 |
167 |
168 | # prepare output
169 | output = {}
170 | # output["inp_cnt"] = inp_cnt
171 | # output["inp_voxel"] = inp_voxel
172 |
173 | # 4 x N to N x 4
174 | output["inp_list"] = inp_list.permute(1, 0)
175 | output["inp_pol_mask"] = inp_pol_mask.permute(1, 0)
176 |
177 | return output
178 |
179 |
180 | def get_from_path(self, path, pad):
181 |
182 | # load events
183 | xs = np.zeros((0))
184 | ys = np.zeros((0))
185 | ts = np.zeros((0))
186 | ps = np.zeros((0))
187 |
188 | xs, ys, ts, ps = self.get_events(path)
189 |
190 | # timestamp normalization
191 | xs, ys, ts, ps = self.event_formatting(xs, ys, ts, ps)
192 | # padding
193 | # xs, ys, ts, ps = self.event_padding(xs, ys, ts, ps, pad)
194 |
195 | # data augmentation
196 | # xs, ys, ps = self.augment_events(xs, ys, ps, batch)
197 |
198 | # artificial pauses to the event stream
199 | # if "Pause" in self.config["loader"]["augment"]:
200 | # if self.batch_augmentation["Pause"]:
201 | # xs = torch.from_numpy(np.empty([0]).astype(np.float32))
202 | # ys = torch.from_numpy(np.empty([0]).astype(np.float32))
203 | # ts = torch.from_numpy(np.empty([0]).astype(np.float32))
204 | # ps = torch.from_numpy(np.empty([0]).astype(np.float32))
205 |
206 | # events to tensors
207 | # inp_cnt = self.create_cnt_encoding(xs, ys, ts, ps)
208 | # inp_voxel = self.create_voxel_encoding(xs, ys, ts, ps)
209 | inp_list = self.create_list_encoding(xs, ys, ts, ps)
210 | inp_pol_mask = self.create_polarity_mask(ps)
211 |
212 | # hot pixel removal
213 | # if self.config["hot_filter"]["enabled"]:
214 | # hot_mask = self.create_hot_mask(xs, ys, ps, batch)
215 | # hot_mask_voxel = torch.stack([hot_mask] * self.num_bins, axis=2).permute(2, 0, 1)
216 | # hot_mask_cnt = torch.stack([hot_mask] * 2, axis=2).permute(2, 0, 1)
217 | # inp_voxel = inp_voxel * hot_mask_voxel
218 | # inp_cnt = inp_cnt * hot_mask_cnt
219 |
220 |
221 | # prepare output
222 | output = {}
223 | # output["inp_cnt"] = inp_cnt
224 | # output["inp_voxel"] = inp_voxel
225 |
226 | # 4 x N to N x 4
227 | output["inp_list"] = inp_list.permute(1, 0)
228 | output["inp_pol_mask"] = inp_pol_mask.permute(1, 0)
229 |
230 | return output
231 |
232 | # B x frame x (datas) -> frame x B*(datas)
233 | def custom_collate(batch): # batch = ((left_padded, right_padded, disp_padded, pad, debug, "left_inp_list", "left_inp_pol_mask", "right_inp_list", "right_inp_pol_mask))
234 | """
235 | Collects the different event representations and stores them together in a dictionary.
236 | """
237 | B = len(batch)
238 | num_frame = len(batch[0])
239 |
240 | datasets = [([[] for i in range(4)] + [{}]) for i in range(num_frame)]
241 |
242 | for bn, data in enumerate(batch):
243 | for f, d in enumerate(data):
244 | datasets[f][0].append(d[0])
245 | datasets[f][1].append(d[1])
246 | datasets[f][2].append(d[2])
247 | datasets[f][3].append(d[3])
248 | if bn == 0:
249 | for key in d[4].keys():
250 | datasets[f][4][key] = []
251 | for key in d[4].keys():
252 | datasets[f][4][key].append(d[4][key])
253 |
254 | for fi, frame_data in enumerate(datasets):
255 | for i, data in enumerate(frame_data):
256 | if isinstance(data, dict):
257 | for key in data.keys():
258 | try:
259 | frame_data[i][key] = default_collate(data[key])
260 | except:
261 | continue
262 | else:
263 | frame_data[i] = default_collate(data)
264 |
265 | return datasets
266 |
--------------------------------------------------------------------------------
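
Note: `event_formatting` here differs from the base-class version in one detail: polarities stored as {0, 1} are remapped to {-1, 1} via `ps = 2*ps - 1`. A one-line check with toy values:

```python
import torch

ps = torch.tensor([0., 1., 1., 0.])
assert torch.equal(2 * ps - 1, torch.tensor([-1., 1., 1., -1.]))
```
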
/dataloader/dsceDataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional
3 | from torch.nn.functional import pad
4 | from torch.utils.data import Dataset
5 | from torchvision import transforms
6 | import torch.nn.functional as F
7 | import numpy as np
8 | import os
9 | import json
10 | import cv2
11 | import random
12 | import pdb
13 |
14 |
15 | class DSECdataset(Dataset):
16 | def __init__(self, root: str, ann: list, height: int, width: int, frame_idxs: list = [0, ], eval_maxdisp=255.0):
17 | self.root = root
18 | self.data_list = ann
19 | self.height = height
20 | self.width = width
21 | self.frame_idxs = frame_idxs
22 | self.eval_maxdisp = eval_maxdisp
23 |
24 |
25 | def add_padding(self, image, pad=0):
26 | H, W = image.shape[-2:]
27 | pad_left = 0
28 | pad_right = self.width - W
29 | pad_top = self.height - H
30 | pad_bottom = 0
31 | assert pad_left >= 0 and pad_right >= 0 and pad_top >= 0 and pad_bottom >= 0, "Require image crop, please check the image size"
32 | padded = torch.nn.functional.pad(image, [pad_left, pad_right, pad_top, pad_bottom], mode='constant', value=pad)
33 | return padded, (pad_left, pad_right, pad_top, pad_bottom)
34 |
35 | def __len__(self):
36 | return len(self.data_list)
37 |
38 | def __getitem__(self, i):
39 | data = []
40 | # get dictionary of frame idxs
41 | metadata = self.data_list[i]
42 | for idx in self.frame_idxs:
43 | frame_info = metadata.get(str(idx), None)
44 | # get relative file path respect to self.root
45 | assert frame_info is not None, "Cannot get frame info of frame {:d}:{:d}, check json file".format(i, idx)
46 | left_event_path = frame_info.get('left_image_path', None)
47 | assert left_event_path is not None
48 | right_event_path = frame_info.get('right_image_path', None)
49 | assert right_event_path is not None
50 | disp_path = frame_info.get('left_disp_path', None)
51 | assert disp_path is not None
52 | flow_path = frame_info.get('left_flow_path', None)
53 | assert flow_path is not None
54 |
55 | left_event_path = os.path.join(self.root, left_event_path)
56 | right_event_path = os.path.join(self.root, right_event_path)
57 | disp_path = os.path.join(self.root, disp_path)
58 | flow_path = os.path.join(self.root, flow_path)
59 | left_img_path = left_event_path.replace('num5voxel0', 'image0').replace('npy', 'png')
60 | right_img_path = right_event_path.replace('num5voxel1', 'image1').replace('npy', 'png')
61 |
62 | left_event = np.load(left_event_path)
63 | right_event = np.load(right_event_path)
64 | # need to be (C, H, W)
65 | assert left_event.ndim==3
66 | assert right_event.ndim==3
67 | flow_gt = np.load(flow_path)
68 | # pdb.set_trace()
69 | assert flow_gt.ndim==3
70 |
71 | # (H, W)
72 | disp = cv2.imread(disp_path, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_ANYCOLOR)
73 |         # weird quirk of the MVSEC disparity encoding: stored values must be divided by 7.0
74 | disp = disp / 7.0
75 |
76 | left_event = torch.tensor(left_event)
77 | right_event = torch.tensor(right_event)
78 | disp = torch.tensor(disp)
79 | flow_gt = torch.tensor(flow_gt)
80 | left_img = cv2.imread(left_img_path, cv2.IMREAD_GRAYSCALE)
81 | right_img = cv2.imread(right_img_path, cv2.IMREAD_GRAYSCALE)
82 | left_img = torch.tensor(left_img).type(torch.float32)/255.0
83 | right_img = torch.tensor(right_img).type(torch.float32)/255.0
84 |
85 | left_padded, pad = self.add_padding(left_event)
86 | right_padded, _ = self.add_padding(right_event)
87 | disp_padded , _ = self.add_padding(disp, 255.0)
88 | flow_gt, _ = self.add_padding(flow_gt)
89 | left_img_padded, _ = self.add_padding(left_img)
90 | right_img_padded, _ = self.add_padding(right_img)
91 |
92 | disp_mask = disp_padded > self.eval_maxdisp
93 |         disp_padded[disp_mask] = float('nan')
94 | # disp_padded[disp_mask] = float(1e5)
95 |
96 | debug = {"left_event_path": left_event_path,
97 | "right_event_path": right_event_path,
98 | "disp_path": disp_path,
99 | "flow_gt": flow_gt,
100 | "left_img": left_img,
101 | "right_img": right_img}
102 | data.append((left_padded, right_padded, disp_padded, pad, debug))
103 | return data
104 |
105 | def get_datasets(root, split, height, width, frame_idxs, num_validation, eval_maxdisp):
106 |     assert max(frame_idxs) == 0  # the current frame (idx 0) must be the last one requested
107 | # frame_idxs.sort()
108 | assert os.path.isdir(root)
109 |
110 |
111 | train_test_file_dict = []
112 | for train_test in ['train', 'test']:
113 | if train_test == 'train':
114 |             seqs = train_seqs  # NOTE: train_seqs/test_seqs and FRAMES_FILTER_* are not defined in this file (carried over from the MVSEC loader)
115 | frame_filter = FRAMES_FILTER_FOR_TRAINING['indoor_flying']
116 | elif train_test == 'test':
117 | seqs = test_seqs
118 | frame_filter = FRAMES_FILTER_FOR_TEST['indoor_flying']
119 | else:
120 | raise NotImplementedError
121 |
122 | file_dict = []
123 | # make file sub-path for train
124 | for seq in seqs:
125 | seq_dir = f'indoor_flying_{seq}'
126 | left_dir_path = os.path.join(seq_dir, 'num5voxel0')
127 | right_dir_path = os.path.join(seq_dir, 'num5voxel1')
128 | disp_dir_path = os.path.join(seq_dir, 'disparity_image')
129 | flow_dir_path = os.path.join(seq_dir, 'flow0')
130 | for frame in frame_filter[seq][-min(frame_idxs):]:
131 | left_right_disp = {}
132 | for fi in frame_idxs:
133 | num_str = "{:06d}".format(frame + fi)
134 | event_name = num_str + ".npy"
135 | disp_name = num_str + ".png"
136 | assert os.path.isfile(os.path.join(root, left_dir_path, event_name))
137 | assert os.path.isfile(os.path.join(root, right_dir_path, event_name))
138 | assert os.path.isfile(os.path.join(root, disp_dir_path, disp_name))
139 | left_right_disp[str(fi)] = {
140 | 'left_image_path': os.path.join(left_dir_path, event_name),
141 | 'right_image_path': os.path.join(right_dir_path, event_name),
142 | 'left_disp_path': os.path.join(disp_dir_path, disp_name),
143 | 'left_flow_path': os.path.join(flow_dir_path, event_name)
144 | }
145 | file_dict.append(left_right_disp)
146 | train_test_file_dict.append(file_dict)
147 |
148 | train_dataset = MVSECdataset(
149 | root,
150 | train_test_file_dict[0],
151 | height,
152 | width,
153 | frame_idxs,
154 | eval_maxdisp)
155 |
156 | if num_validation > 0:
157 | import random
158 | random.shuffle(train_test_file_dict[1])
159 | validation_dataset = MVSECdataset(
160 | root,
161 | train_test_file_dict[1][:200],
162 | height,
163 | width,
164 | frame_idxs,
165 | eval_maxdisp)
166 | test_dataset = MVSECdataset(
167 | root,
168 | train_test_file_dict[1][200:],
169 | height,
170 | width,
171 | frame_idxs,
172 | eval_maxdisp)
173 | else:
174 | validation_dataset = DSECdataset(
175 | root,
176 | train_test_file_dict[1],
177 | height,
178 | width,
179 | frame_idxs,
180 | eval_maxdisp)
181 | test_dataset = DSECdataset(
182 | root,
183 | train_test_file_dict[1],
184 | height,
185 | width,
186 | frame_idxs,
187 | eval_maxdisp)
188 |
189 | return train_dataset, validation_dataset, test_dataset
190 |
191 |
192 | # for debugging
193 | if __name__ == '__main__':
194 | # mvsecdataset = MVSECdataset(
195 | # '/home/jaeyoung/data/ws/event_stereo_ICCV2019/dataset',
196 | # '/home/jaeyoung/data/ws/event_stereo_ICCV2019/dataset/view_4_train_v5_split1.json',
197 | # 288,
198 | # 352
199 | # )
200 | # test = mvsecdataset[0]
201 | # breakpoint()
202 |     new_train, new_val, new_test = get_datasets('/home/jaeyoung/data/ws/event_stereo_ICCV2019/dataset', 1, 288, 352, [-3, -2, -1, 0], 0)  # NOTE: get_datasets also expects an eval_maxdisp argument
203 | # min_disp = []
204 | # for i in new_train:
205 | # min_disp.append(torch.min(i[-1][2]).item())
206 | new_train[0]
207 | breakpoint()
208 | # getDataset(root, 'event_left', 'event_right', 'disparity_image', 80, 1260)
209 |
--------------------------------------------------------------------------------
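
Note: `DSECdataset.add_padding` pads only on the right and the top, so the original pixels stay anchored at the bottom-left corner. A toy check of the same `F.pad` call (sizes invented):

```python
import torch
import torch.nn.functional as F

img = torch.ones(1, 3, 4)                 # (C, H, W)
height, width = 5, 6                      # target padded size
pad = (0, width - 4, height - 3, 0)       # (left, right, top, bottom)
padded = F.pad(img, pad, mode='constant', value=0)
assert padded.shape == (1, 5, 6)
assert padded[:, -3:, :4].eq(1).all()     # original content sits bottom-left
```
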
/dataloader/dsecProvider_test.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import torch
4 | import pdb
5 | from .dsecSequence import Sequence
6 |
7 | DATA_SPLIT = {
8 | 'train': ['interlaken_00_c', 'interlaken_00_d', 'interlaken_00_e', 'zurich_city_00_a',
9 | 'zurich_city_00_b', 'zurich_city_01_a', 'zurich_city_01_b', 'zurich_city_01_c',
10 | 'zurich_city_01_d', 'zurich_city_01_e', 'zurich_city_02_a',
11 | 'zurich_city_02_b', 'zurich_city_02_c', 'zurich_city_02_d', 'zurich_city_02_e',
12 | 'zurich_city_03_a', 'zurich_city_04_a', 'zurich_city_04_b', 'zurich_city_04_c',
13 | 'zurich_city_04_d', 'zurich_city_04_e', 'zurich_city_04_f', 'zurich_city_09_a',
14 | 'zurich_city_09_b', 'zurich_city_09_e', 'zurich_city_10_a',
15 | 'zurich_city_11_a', 'zurich_city_11_c',
16 | 'interlaken_00_f', 'interlaken_00_g', 'thun_00_a', 'zurich_city_05_a',
17 | 'zurich_city_05_b', 'zurich_city_07_a',
18 | 'zurich_city_09_d', 'zurich_city_10_b'],
19 | 'validation': ['zurich_city_01_f', 'zurich_city_06_a', 'zurich_city_11_b','zurich_city_09_c',
20 | 'zurich_city_08_a']
21 | }
22 |
23 | DATA_SPLIT_MINI = {
24 | 'train': ['zurich_city_03_a', 'zurich_city_04_a', 'zurich_city_04_b', 'zurich_city_04_c',
25 | 'zurich_city_04_d', 'zurich_city_04_e', 'zurich_city_04_f', 'zurich_city_09_a',
26 | ],
27 | 'validation': ['zurich_city_01_f', 'zurich_city_06_a']
28 | }
29 |
30 | DATA_SPLIT_SUPER_MINI = {
31 | 'train': ['zurich_city_04_a' ],
32 | 'validation': ['zurich_city_04_d']
33 | }
34 |
35 | class DatasetProvider:
36 | def __init__(self, dataset_path: Path, raw_dataset_path: Path, delta_t_ms: int=100, num_bins=15, frame_idxs = range(-3, 1),
37 | eval_maxdisp=192, pseudo_path=None, pad_width=648, pad_height=480, use_mini = False, use_super_mini = False, img_load = False):
38 |
39 | train_path = dataset_path / 'train'
40 | train_raw_path = raw_dataset_path / 'train'
41 |
42 | assert dataset_path.is_dir(), str(dataset_path)
43 | assert train_raw_path.is_dir(), str(train_raw_path)
44 | assert train_path.is_dir(), str(train_path)
45 |
46 | if use_super_mini:
47 | data_split = DATA_SPLIT_SUPER_MINI
48 | elif use_mini:
49 | data_split = DATA_SPLIT_MINI
50 | else:
51 | data_split = DATA_SPLIT
52 |
53 | test_path = dataset_path / 'test'
54 | test_raw_path = raw_dataset_path / 'test'
55 |
56 | assert test_path.is_dir(), str(test_path)
57 |
58 | test_sequences = list()
59 | for child in test_path.iterdir():
60 |             if child.name != 'thun_02_a':  # exclude this sequence; explicit name check instead of a fragile substring test
61 |                 raw_child = test_raw_path / child.name
62 | test_sequences.append(Sequence(child, 'test', delta_t_ms, num_bins, frame_idxs, eval_maxdisp, pseudo_path=pseudo_path, pad_width=pad_width, pad_height=pad_height,
63 | raw_seq_path = raw_child, img_load = img_load))
64 |
65 |
66 | self.test_dataset = torch.utils.data.ConcatDataset(test_sequences)
67 |
68 |     def get_train_dataset(self):
69 |         return self.train_dataset  # never assigned in this provider: only the test split is built
70 |
71 |     def get_val_dataset(self):
72 |         return self.val_dataset  # never assigned in this provider (see get_train_dataset)
73 |
74 | def get_test_dataset(self):
75 | return self.test_dataset
--------------------------------------------------------------------------------
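A minimal usage sketch for the provider above; the dataset roots are placeholders and must follow the DSEC layout with `train/` and `test/` subdirectories under both roots (asserted in `__init__`). Note that only the test split is constructed by this provider:

```python
from pathlib import Path
from dataloader.dsecProvider_test import DatasetProvider

# Placeholder paths: preprocessed DSEC root and raw DSEC root.
provider = DatasetProvider(Path('/home/data/DSEC'), Path('/home/data/DSEC_raw'))
test_set = provider.get_test_dataset()  # a ConcatDataset over the test sequences
print(len(test_set))
```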
/dataloader/eventslicer.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Dict, Tuple
3 |
4 | import h5py
5 | from numba import jit
6 | import numpy as np
7 |
8 |
9 | class EventSlicer:
10 | def __init__(self, h5f: h5py.File):
11 | self.h5f = h5f
12 |
13 | self.events = dict()
14 | for dset_str in ['p', 'x', 'y', 't']:
15 | self.events[dset_str] = self.h5f['events/{}'.format(dset_str)]
16 |
17 | # This is the mapping from milliseconds to event index:
18 | # It is defined such that
19 | # (1) t[ms_to_idx[ms]] >= ms*1000
20 | # (2) t[ms_to_idx[ms] - 1] < ms*1000
21 |         # where 'ms' is the time in milliseconds and 't' the event timestamps in microseconds.
22 | #
23 | # As an example, given 't' and 'ms':
24 | # t: 0 500 2100 5000 5000 7100 7200 7200 8100 9000
25 | # ms: 0 1 2 3 4 5 6 7 8 9
26 | #
27 | # we get
28 | #
29 | # ms_to_idx:
30 | # 0 2 2 3 3 3 5 5 8 9
31 | self.ms_to_idx = np.asarray(self.h5f['ms_to_idx'], dtype='int64')
32 |
33 | if "t_offset" in list(h5f.keys()):
34 | self.t_offset = int(h5f['t_offset'][()])
35 | else:
36 | self.t_offset = 0
37 | self.t_final = int(self.events['t'][-1]) + self.t_offset
38 |
39 | def get_start_time_us(self):
40 | return self.t_offset
41 |
42 | def get_final_time_us(self):
43 | return self.t_final
44 |
45 | def get_events(self, t_start_us: int, t_end_us: int) -> Dict[str, np.ndarray]:
46 | """Get events (p, x, y, t) within the specified time window
47 | Parameters
48 | ----------
49 | t_start_us: start time in microseconds
50 | t_end_us: end time in microseconds
51 | Returns
52 | -------
53 | events: dictionary of (p, x, y, t) or None if the time window cannot be retrieved
54 | """
55 | assert t_start_us < t_end_us
56 |
57 |         # Input times are absolute (they include t_offset), while stored timestamps are relative; hence subtract the offset:
58 | t_start_us -= self.t_offset
59 | t_end_us -= self.t_offset
60 |
61 | t_start_ms, t_end_ms = self.get_conservative_window_ms(t_start_us, t_end_us)
62 | t_start_ms_idx = self.ms2idx(t_start_ms)
63 | t_end_ms_idx = self.ms2idx(t_end_ms)
64 |
65 | if t_start_ms_idx is None or t_end_ms_idx is None:
66 | # Cannot guarantee window size anymore
67 | return None
68 |
69 | events = dict()
70 | time_array_conservative = np.asarray(self.events['t'][t_start_ms_idx:t_end_ms_idx])
71 | idx_start_offset, idx_end_offset = self.get_time_indices_offsets(time_array_conservative, t_start_us, t_end_us)
72 | t_start_us_idx = t_start_ms_idx + idx_start_offset
73 | t_end_us_idx = t_start_ms_idx + idx_end_offset
74 | # Again add t_offset to get gps time
75 | events['t'] = time_array_conservative[idx_start_offset:idx_end_offset] + self.t_offset
76 | for dset_str in ['p', 'x', 'y']:
77 | events[dset_str] = np.asarray(self.events[dset_str][t_start_us_idx:t_end_us_idx])
78 | assert events[dset_str].size == events['t'].size
79 | return events
80 |
81 |
82 | @staticmethod
83 | def get_conservative_window_ms(ts_start_us: int, ts_end_us) -> Tuple[int, int]:
84 |         """Compute a conservative time window with millisecond resolution.
85 | We have a time to index mapping for each millisecond. Hence, we need
86 | to compute the lower and upper millisecond to retrieve events.
87 | Parameters
88 | ----------
89 | ts_start_us: start time in microseconds
90 | ts_end_us: end time in microseconds
91 | Returns
92 | -------
93 | window_start_ms: conservative start time in milliseconds
94 | window_end_ms: conservative end time in milliseconds
95 | """
96 | assert ts_end_us > ts_start_us
97 | window_start_ms = math.floor(ts_start_us/1000)
98 | window_end_ms = math.ceil(ts_end_us/1000)
99 | return window_start_ms, window_end_ms
100 |
101 | @staticmethod
102 | @jit(nopython=True)
103 | def get_time_indices_offsets(
104 | time_array: np.ndarray,
105 | time_start_us: int,
106 | time_end_us: int) -> Tuple[int, int]:
107 | """Compute index offset of start and end timestamps in microseconds
108 | Parameters
109 | ----------
110 | time_array: timestamps (in us) of the events
111 | time_start_us: start timestamp (in us)
112 | time_end_us: end timestamp (in us)
113 | Returns
114 | -------
115 | idx_start: Index within this array corresponding to time_start_us
116 | idx_end: Index within this array corresponding to time_end_us
117 | such that (in non-edge cases)
118 | time_array[idx_start] >= time_start_us
119 | time_array[idx_end] >= time_end_us
120 | time_array[idx_start - 1] < time_start_us
121 | time_array[idx_end - 1] < time_end_us
122 | this means that
123 | time_start_us <= time_array[idx_start:idx_end] < time_end_us
124 | """
125 |
126 | assert time_array.ndim == 1
127 |
128 | idx_start = -1
129 | if time_array[-1] < time_start_us:
130 | # This can happen in extreme corner cases. E.g.
131 | # time_array[0] = 1016
132 | # time_array[-1] = 1984
133 | # time_start_us = 1990
134 | # time_end_us = 2000
135 |
136 | # Return same index twice: array[x:x] is empty.
137 | return time_array.size, time_array.size
138 | else:
139 | for idx_from_start in range(0, time_array.size, 1):
140 | if time_array[idx_from_start] >= time_start_us:
141 | idx_start = idx_from_start
142 | break
143 | assert idx_start >= 0
144 |
145 | idx_end = time_array.size
146 | for idx_from_end in range(time_array.size - 1, -1, -1):
147 | if time_array[idx_from_end] >= time_end_us:
148 | idx_end = idx_from_end
149 | else:
150 | break
151 |
152 | assert time_array[idx_start] >= time_start_us
153 | if idx_end < time_array.size:
154 | assert time_array[idx_end] >= time_end_us
155 | if idx_start > 0:
156 | assert time_array[idx_start - 1] < time_start_us
157 | if idx_end > 0:
158 | assert time_array[idx_end - 1] < time_end_us
159 | return idx_start, idx_end
160 |
161 | def ms2idx(self, time_ms: int) -> int:
162 | assert time_ms >= 0
163 | if time_ms >= self.ms_to_idx.size:
164 | return None
165 | return self.ms_to_idx[time_ms]
--------------------------------------------------------------------------------
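A usage sketch for `EventSlicer`, assuming a DSEC-format `events.h5` (the path below is a placeholder). `get_events` may return `None` when the requested window cannot be mapped through `ms_to_idx`, so the result should be checked:

```python
import h5py
import hdf5plugin  # DSEC event files are compressed; the plugin must be imported
from dataloader.eventslicer import EventSlicer

with h5py.File('/home/data/DSEC/train/zurich_city_00_a/events/left/events.h5', 'r') as h5f:
    slicer = EventSlicer(h5f)
    t_end = slicer.get_final_time_us()
    events = slicer.get_events(t_end - 50_000, t_end)  # 50 ms window, in microseconds
    if events is not None:
        print(events['x'].size, events['t'][0], events['t'][-1])
```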
/dataloader/representations.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class EventRepresentation:
5 | def convert(self, x: torch.Tensor, y: torch.Tensor, pol: torch.Tensor, time: torch.Tensor):
6 | raise NotImplementedError
7 |
8 |
9 | class VoxelGrid(EventRepresentation):
10 | def __init__(self, channels: int, height: int, width: int, normalize: bool):
11 | self.voxel_grid = torch.zeros((channels, height, width), dtype=torch.float, requires_grad=False)
12 | self.nb_channels = channels
13 | self.normalize = normalize
14 |
15 | def convert(self, x: torch.Tensor, y: torch.Tensor, pol: torch.Tensor, time: torch.Tensor):
16 | assert x.shape == y.shape == pol.shape == time.shape
17 | assert x.ndim == 1
18 |
19 | C, H, W = self.voxel_grid.shape
20 | with torch.no_grad():
21 |
22 | self.voxel_grid = self.voxel_grid.to(pol.device)
23 | voxel_grid = self.voxel_grid.clone()
24 |
25 | t_norm = time
26 | t_norm = (C - 1) * (t_norm-t_norm[0]) / (t_norm[-1]-t_norm[0])
27 |
28 | x0 = x.int()
29 | y0 = y.int()
30 | t0 = t_norm.int()
31 |
32 | value = 2*pol-1
33 |
34 | for xlim in [x0,x0+1]:
35 | for ylim in [y0,y0+1]:
36 | for tlim in [t0,t0+1]:
37 |
38 | mask = (xlim < W) & (xlim >= 0) & (ylim < H) & (ylim >= 0) & (tlim >= 0) & (tlim < self.nb_channels)
39 | interp_weights = value * (1 - (xlim-x).abs()) * (1 - (ylim-y).abs()) * (1 - (tlim - t_norm).abs())
40 |
41 | index = H * W * tlim.long() + \
42 | W * ylim.long() + \
43 | xlim.long()
44 |
45 | voxel_grid.put_(index[mask], interp_weights[mask], accumulate=True)
46 |
47 | if self.normalize:
48 | mask = torch.nonzero(voxel_grid, as_tuple=True)
49 | if mask[0].size()[0] > 0:
50 | mean = voxel_grid[mask].mean()
51 | std = voxel_grid[mask].std()
52 | if std > 0:
53 | voxel_grid[mask] = (voxel_grid[mask] - mean) / std
54 | else:
55 | voxel_grid[mask] = voxel_grid[mask] - mean
56 |
57 | return voxel_grid
--------------------------------------------------------------------------------
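A minimal sketch of converting synthetic events with `VoxelGrid` (shapes and value ranges follow the asserts in `convert`). The timestamps must span a non-zero interval, otherwise the per-channel time normalization divides by zero:

```python
import torch
from dataloader.representations import VoxelGrid

grid = VoxelGrid(channels=15, height=480, width=640, normalize=True)
n = 10_000
x = torch.rand(n) * 639                  # float pixel coordinates
y = torch.rand(n) * 479
pol = torch.randint(0, 2, (n,)).float()  # polarities in {0, 1}; mapped to {-1, +1} internally
t = torch.sort(torch.rand(n)).values     # monotonically increasing timestamps
voxel = grid.convert(x, y, pol, t)
print(voxel.shape)                       # torch.Size([15, 480, 640])
```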
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM zombbie/cuda11.1-cudnn8-ubuntu20.04:v1.0
2 | ARG DEBIAN_FRONTEND=noninteractive
3 |
4 | RUN apt-get update && \
5 | apt-get install -y software-properties-common && \
6 | add-apt-repository -y ppa:deadsnakes/ppa && \
7 | apt install -y python3.8 python3-pip
8 | RUN apt-get -y install git
9 |
10 | RUN pip --default-timeout=2000 install torch==1.10.1+cu111 torchvision==0.11.2+cu111 torchaudio==0.10.1 -f https://download.pytorch.org/whl/cu111/torch_stable.html
11 |
12 | RUN pip install cupy-cuda111
13 |
14 | RUN apt-get update
15 | RUN apt-get -y install qt5-default python3-pyqt5
16 | RUN apt-get -y install net-tools
17 |
18 | RUN mkdir -p /home/data
19 | RUN mkdir -p /home/ws
20 | ADD requirements.txt /home/requirements.txt
21 | ADD .tmux.conf /home/.tmux.conf
22 | RUN apt-get -y install tmux
23 |
24 | WORKDIR /home/ws/TESNet
25 | RUN pip install -r ../../requirements.txt
26 |
27 | RUN apt-get install -y htop
28 | RUN apt-get install -y xclip
29 | RUN apt-get install -y zip
30 |
31 | ARG UID
32 | ARG GID
33 |
34 | # Update the package list, install sudo, create a non-root user, and grant password-less sudo permissions
35 | RUN apt update && \
36 | apt install -y sudo && \
37 | addgroup --gid $GID nonroot && \
38 | adduser --uid $UID --gid $GID --disabled-password --gecos "" nonroot && \
39 | echo 'nonroot ALL=(ALL) NOPASSWD: ALL' >> /etc/sudoers
40 |
41 | # Set the non-root user as the default user
42 | USER nonroot
43 | RUN echo "alias tn='. run_tmux.sh'" >> ~/.bashrc
44 |
--------------------------------------------------------------------------------
/docker/README.md:
--------------------------------------------------------------------------------
1 | ## 1. Install docker
2 | Follow the instructions here [Docker-install](https://docs.docker.com/desktop/install/linux/ubuntu/)
3 | ## 2. Install NVIDIA Container Toolkit
4 | Follow the instructions here [NVIDIA-install](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)
5 |
6 | ## 3. Go to the project path
7 | ```bash
8 | cd {PATH_TO_THIS_PROJECT}/TemporalEventStereo
9 | ```
10 | ## 4. Build Docker Image
11 | Build the Docker image using the following script.
12 | This builds the docker image `tes:v0.1` (the build context is the repository root, so the `ADD` paths in the Dockerfile are relative to it).
13 | ```bash
14 | source docker/docker_build.sh
15 | ```
16 | ## 5. Start Docker Container
17 |
18 | ```bash
19 | source docker/docker_run_multi.sh {ABSOLUTE_DATA_DIR_PATH}
20 | ```
--------------------------------------------------------------------------------
/docker/docker_build.sh:
--------------------------------------------------------------------------------
1 | export HOST_UID=$(id -u)
2 | export HOST_GID=$(id -g)
3 |
4 | docker build --build-arg UID=$HOST_UID --build-arg GID=$HOST_GID \
5 | -t tes:v0.1 \
6 | -f docker/Dockerfile .
7 |
--------------------------------------------------------------------------------
/docker/docker_run_multi.sh:
--------------------------------------------------------------------------------
1 | data_path=$1
2 |
3 | docker run -it --gpus all \
4 | --name TES \
5 | --shm-size=128G \
6 | -v $data_path:/home/data\
7 | -v $PWD:/home/ws/TESNet\
8 | tes:v0.1 bash
--------------------------------------------------------------------------------
/infer_dsec.sh:
--------------------------------------------------------------------------------
1 | logdir=exps/debug
2 |
3 | python3 infer_dsec.py --config config/dsec_infer.yaml\
4 | --savemodel ${logdir}
5 |
6 | curr_dir=$PWD
7 | cd $logdir
8 | zip -r test.zip test
9 | cd $curr_dir
10 | echo "GPU number is ${GPUNUM}!"
11 | echo Done!
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .model_large import TESNet as ours_large
2 |
--------------------------------------------------------------------------------
/models/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mickeykang16/TemporalEventStereo/d9dc74677f568dbdf2d1d06da20480c50410e274/models/__init__.pyc
--------------------------------------------------------------------------------
/models/submodule.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import torch.nn as nn
4 | import torch.utils.data
5 | from torch.autograd import Variable
6 | import torch.nn.functional as F
7 | import math
8 | import numpy as np
9 |
10 | def convbn(in_planes, out_planes, kernel_size, stride, pad, dilation):
11 |
12 | return nn.Sequential(nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=dilation if dilation > 1 else pad, dilation = dilation, bias=False),
13 | nn.BatchNorm2d(out_planes))
14 |
15 |
16 | def convbn_3d(in_planes, out_planes, kernel_size, stride, pad):
17 |
18 | return nn.Sequential(nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, padding=pad, stride=stride,bias=False),
19 | nn.BatchNorm3d(out_planes))
20 |
21 | class BasicBlock(nn.Module):
22 | expansion = 1
23 | def __init__(self, inplanes, planes, stride, downsample, pad, dilation):
24 | super(BasicBlock, self).__init__()
25 |
26 | self.conv1 = nn.Sequential(convbn(inplanes, planes, 3, stride, pad, dilation),
27 | nn.ReLU(inplace=True))
28 |
29 | self.conv2 = convbn(planes, planes, 3, 1, pad, dilation)
30 |
31 | self.downsample = downsample
32 | self.stride = stride
33 |
34 | def forward(self, x):
35 | out = self.conv1(x)
36 | out = self.conv2(out)
37 |
38 | if self.downsample is not None:
39 | x = self.downsample(x)
40 |
41 | out += x
42 |
43 | return out
44 |
45 | class disparityregression(nn.Module):
46 | def __init__(self, maxdisp):
47 | super(disparityregression, self).__init__()
48 |         self.disp = torch.Tensor(np.reshape(np.array(range(maxdisp)),[1, maxdisp,1,1])).cuda()  # note: assumes a CUDA device is available
49 |
50 | def forward(self, x):
51 | out = torch.sum(x*self.disp.data,1, keepdim=True)
52 | return out
53 |
54 | class feature_extraction(nn.Module):
55 | def __init__(self, in_ch=5):
56 | super(feature_extraction, self).__init__()
57 | self.inplanes = 32
58 | self.firstconv = nn.Sequential(convbn(in_ch, 32, 3, 2, 1, 1),
59 | nn.ReLU(inplace=True),
60 | convbn(32, 32, 3, 1, 1, 1),
61 | nn.ReLU(inplace=True),
62 | convbn(32, 32, 3, 1, 1, 1),
63 | nn.ReLU(inplace=True))
64 |
65 | self.layer1 = self._make_layer(BasicBlock, 32, 3, 1,1,1)
66 | self.layer2 = self._make_layer(BasicBlock, 64, 16, 2,1,1)
67 | self.layer3 = self._make_layer(BasicBlock, 128, 3, 1,1,1)
68 | self.layer4 = self._make_layer(BasicBlock, 128, 3, 1,1,2)
69 |
70 | self.branch1 = nn.Sequential(nn.AvgPool2d((64, 64), stride=(64,64)),
71 | convbn(128, 32, 1, 1, 0, 1),
72 | nn.ReLU(inplace=True))
73 |
74 | self.branch2 = nn.Sequential(nn.AvgPool2d((32, 32), stride=(32,32)),
75 | convbn(128, 32, 1, 1, 0, 1),
76 | nn.ReLU(inplace=True))
77 |
78 | self.branch3 = nn.Sequential(nn.AvgPool2d((16, 16), stride=(16,16)),
79 | convbn(128, 32, 1, 1, 0, 1),
80 | nn.ReLU(inplace=True))
81 |
82 | self.branch4 = nn.Sequential(nn.AvgPool2d((8, 8), stride=(8,8)),
83 | convbn(128, 32, 1, 1, 0, 1),
84 | nn.ReLU(inplace=True))
85 |
86 | self.lastconv = nn.Sequential(convbn(320, 128, 3, 1, 1, 1),
87 | nn.ReLU(inplace=True),
88 | nn.Conv2d(128, 32, kernel_size=1, padding=0, stride = 1, bias=False))
89 |
90 | def _make_layer(self, block, planes, blocks, stride, pad, dilation):
91 | downsample = None
92 | if stride != 1 or self.inplanes != planes * block.expansion:
93 | downsample = nn.Sequential(
94 | nn.Conv2d(self.inplanes, planes * block.expansion,
95 | kernel_size=1, stride=stride, bias=False),
96 | nn.BatchNorm2d(planes * block.expansion),)
97 |
98 | layers = []
99 | layers.append(block(self.inplanes, planes, stride, downsample, pad, dilation))
100 | self.inplanes = planes * block.expansion
101 | for i in range(1, blocks):
102 | layers.append(block(self.inplanes, planes,1,None,pad,dilation))
103 |
104 | return nn.Sequential(*layers)
105 |
106 | def forward(self, x):
107 | output = self.firstconv(x)
108 | output = self.layer1(output)
109 | output_raw = self.layer2(output)
110 | output = self.layer3(output_raw)
111 | output_skip = self.layer4(output)
112 |
113 |
114 |         output_branch1 = self.branch1(output_skip)
115 |         output_branch1 = F.interpolate(output_branch1, (output_skip.size()[2], output_skip.size()[3]), mode='bilinear', align_corners=False)
116 |
117 |         output_branch2 = self.branch2(output_skip)
118 |         output_branch2 = F.interpolate(output_branch2, (output_skip.size()[2], output_skip.size()[3]), mode='bilinear', align_corners=False)
119 |
120 |         output_branch3 = self.branch3(output_skip)
121 |         output_branch3 = F.interpolate(output_branch3, (output_skip.size()[2], output_skip.size()[3]), mode='bilinear', align_corners=False)
122 |
123 |         output_branch4 = self.branch4(output_skip)
124 |         output_branch4 = F.interpolate(output_branch4, (output_skip.size()[2], output_skip.size()[3]), mode='bilinear', align_corners=False)
125 |
126 | output_feature = torch.cat((output_raw, output_skip, output_branch4, output_branch3, output_branch2, output_branch1), 1)
127 | output_feature = self.lastconv(output_feature)
128 |
129 | return output_feature
130 |
131 |
132 |
133 |
--------------------------------------------------------------------------------
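A quick shape check for `feature_extraction`: the stride-2 first convolution and the stride-2 `layer2` each halve the resolution, so features come out at 1/4 scale with 32 channels. A sketch with an MVSEC-sized input (any size whose quarter resolution is at least 64x64 for the pooling branches works):

```python
import torch
from models.submodule import feature_extraction

net = feature_extraction(in_ch=5).eval()  # 5-channel voxel input by default
x = torch.randn(1, 5, 288, 352)           # MVSEC-sized input
with torch.no_grad():
    feat = net(x)
print(feat.shape)                         # torch.Size([1, 32, 72, 88]) -- 1/4 resolution
```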
/models/submodule.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mickeykang16/TemporalEventStereo/d9dc74677f568dbdf2d1d06da20480c50410e274/models/submodule.pyc
--------------------------------------------------------------------------------
/pre-processing/README.md:
--------------------------------------------------------------------------------
1 |
2 | We utilize dense pseudo labels for stable training of the temporal stereo network. To generate the pseudo labels, we first train a single-frame event stereo model that has the same backbone as the temporal event stereo network (TESNet) but does not use temporal aggregation.
3 |
4 | ### MVSEC
5 | Please do not forget to cite [MVSEC](https://daniilidis-group.github.io/mvsec/) if you are using the MVSEC dataset.
6 | We provide the pre-processed MVSEC dataset here: [Link](https://drive.google.com/drive/folders/1ANrz99Z3UwAcTMBlQyz_PYnoRblySh8H?usp=sharing).
7 |
8 | ```
9 | The MVSEC dataset should have the following format:
10 | ├── MVSEC_dataset
11 | │ ├── indoor_flying_1
12 | │ │ ├── event0
13 | │ │ ├── event1
14 | │ │ ├── pseudo_disp
15 | │ │ ├── disparity_image
16 | │ │ ├── voxel0skip1bin5
17 | │ │ ├── voxel1skip1bin5
18 | │ └── indoor_flying_2
19 | │ ├── event0
20 | │ └── ...
21 | ```
22 |
23 | ### DSEC
24 | The DSEC dataset can be downloaded [here](https://dsec.ifi.uzh.ch/dsec-datasets/download/). Please do not forget to cite [DSEC](https://github.com/uzh-rpg/DSEC) if you are using the DSEC dataset.
25 | The pseudo labels generated by our work can be downloaded here: [Link](https://drive.google.com/drive/folders/1Hcoarpo40lVnlXQx4CERtW5loyfv6zNf?usp=sharing).
26 |
27 | ```
28 | The DSEC dataset should have the following format:
29 | ├── DSEC
30 | │ ├── train
31 | │ │ ├── zurich_city_00_a
32 | │ │ │ │ ├── events
33 | │ │ │ │ │ ├── left
34 | │ │ │ │ │ └── right
35 | │ │ │ │ ├── images
36 | │ │ │ │ │ ├── left
37 | │ │ │ │ │ ├── right
38 | │ │ │ │ │ └── timestamps.txt
39 | │ │ │ │ ├── raw_events
40 | │ │ │ │ │ ├── left
41 | │ │ │ │ │ └── right
42 | │ │ │ │ ├── disparity
43 | │ │ │ │ │ ├── event
44 | │ │ │ │ │ └── timestamps.txt
45 | │ │ │ │ └── voxel_50ms_15bin
46 | │ │ │ │ ├── left
47 | │ │ │ │ └── right
48 | │ │ └── ...
49 | │ └── test
50 | │ ├── zurich_city_12_a
51 | │ │ │ │ ├── events
52 | │ │ │ │ │ ├── left
53 | │ │ │ │ │ └── right
54 | │ │ │ │ ├── images
55 | │ │ │ │ │ ├── left
56 | │ │ │ │ │ ├── right
57 | │ │ │ │ │ └── timestamps.txt
58 | │ │ │ │ └── voxel_50ms_15bin
59 | │ │ │ │ ├── left
60 | │ │ │ │ └── right
61 | │ └── ...
62 | └── DSEC_pseudo_GT_all
63 | └── train
64 | ├── zurich_city_00_a
65 | │ ├── 000001.png
66 | │ └── ...
67 | ├── zurich_city_00_b
68 | └── ...
69 | ```
70 |
71 | We also pre-generate the voxel grids and raw events to avoid repeating this computation in the dataloader.
72 | The raw events are only used in the training phase, for the contrast maximization loss.
73 | The pre-generation can be done with the provided scripts in this directory using the following commands:
74 | ```bash
75 | python raw_event_parsing.py --dataset_path $DSEC_TRAIN_DATASET_PATH
76 | python voxel_generate_all.py --dataset_path $DSEC_TRAIN_DATASET_PATH
77 | python voxel_generate_test_all.py --dataset_path $DSEC_TEST_DATASET_PATH
78 | ```
79 |
--------------------------------------------------------------------------------
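After running the scripts above, each saved voxel file should hold a `num_bins x height x width` array; a small sanity-check sketch (the path is a placeholder):

```python
import numpy as np

# Voxel grids are written with np.save; allow_pickle is a harmless safeguard
# in case the tensor was stored as a pickled object.
vox = np.load('/home/data/DSEC/train/zurich_city_00_a/voxel_50ms_15bin/left/000002.npy',
              allow_pickle=True)
print(vox.shape)  # expected: (15, 480, 640) -- num_bins x height x width
```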
/pre-processing/eventslicer.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Dict, Tuple
3 |
4 | import h5py
5 | from numba import jit
6 | import numpy as np
7 |
8 |
9 | class EventSlicer:
10 | def __init__(self, h5f: h5py.File):
11 | self.h5f = h5f
12 |
13 | self.events = dict()
14 | for dset_str in ['p', 'x', 'y', 't']:
15 | self.events[dset_str] = self.h5f['events/{}'.format(dset_str)]
16 |
17 | # This is the mapping from milliseconds to event index:
18 | # It is defined such that
19 | # (1) t[ms_to_idx[ms]] >= ms*1000
20 | # (2) t[ms_to_idx[ms] - 1] < ms*1000
21 |         # where 'ms' is the time in milliseconds and 't' the event timestamps in microseconds.
22 | #
23 | # As an example, given 't' and 'ms':
24 | # t: 0 500 2100 5000 5000 7100 7200 7200 8100 9000
25 | # ms: 0 1 2 3 4 5 6 7 8 9
26 | #
27 | # we get
28 | #
29 | # ms_to_idx:
30 | # 0 2 2 3 3 3 5 5 8 9
31 | self.ms_to_idx = np.asarray(self.h5f['ms_to_idx'], dtype='int64')
32 |
33 | if "t_offset" in list(h5f.keys()):
34 | self.t_offset = int(h5f['t_offset'][()])
35 | else:
36 | self.t_offset = 0
37 | self.t_final = int(self.events['t'][-1]) + self.t_offset
38 |
39 | def get_start_time_us(self):
40 | return self.t_offset
41 |
42 | def get_final_time_us(self):
43 | return self.t_final
44 |
45 | def get_events(self, t_start_us: int, t_end_us: int) -> Dict[str, np.ndarray]:
46 | """Get events (p, x, y, t) within the specified time window
47 | Parameters
48 | ----------
49 | t_start_us: start time in microseconds
50 | t_end_us: end time in microseconds
51 | Returns
52 | -------
53 | events: dictionary of (p, x, y, t) or None if the time window cannot be retrieved
54 | """
55 | assert t_start_us < t_end_us
56 |
57 |         # Input times are absolute (they include t_offset), while stored timestamps are relative; hence subtract the offset:
58 | t_start_us -= self.t_offset
59 | t_end_us -= self.t_offset
60 |
61 | t_start_ms, t_end_ms = self.get_conservative_window_ms(t_start_us, t_end_us)
62 | t_start_ms_idx = self.ms2idx(t_start_ms)
63 | t_end_ms_idx = self.ms2idx(t_end_ms)
64 |
65 | if t_start_ms_idx is None or t_end_ms_idx is None:
66 | # Cannot guarantee window size anymore
67 | return None
68 |
69 | events = dict()
70 | time_array_conservative = np.asarray(self.events['t'][t_start_ms_idx:t_end_ms_idx])
71 | idx_start_offset, idx_end_offset = self.get_time_indices_offsets(time_array_conservative, t_start_us, t_end_us)
72 | t_start_us_idx = t_start_ms_idx + idx_start_offset
73 | t_end_us_idx = t_start_ms_idx + idx_end_offset
74 | # Again add t_offset to get gps time
75 | events['t'] = time_array_conservative[idx_start_offset:idx_end_offset] + self.t_offset
76 | for dset_str in ['p', 'x', 'y']:
77 | events[dset_str] = np.asarray(self.events[dset_str][t_start_us_idx:t_end_us_idx])
78 | assert events[dset_str].size == events['t'].size
79 | return events
80 |
81 |
82 | @staticmethod
83 | def get_conservative_window_ms(ts_start_us: int, ts_end_us) -> Tuple[int, int]:
84 |         """Compute a conservative time window with millisecond resolution.
85 | We have a time to index mapping for each millisecond. Hence, we need
86 | to compute the lower and upper millisecond to retrieve events.
87 | Parameters
88 | ----------
89 | ts_start_us: start time in microseconds
90 | ts_end_us: end time in microseconds
91 | Returns
92 | -------
93 | window_start_ms: conservative start time in milliseconds
94 | window_end_ms: conservative end time in milliseconds
95 | """
96 | assert ts_end_us > ts_start_us
97 | window_start_ms = math.floor(ts_start_us/1000)
98 | window_end_ms = math.ceil(ts_end_us/1000)
99 | return window_start_ms, window_end_ms
100 |
101 | @staticmethod
102 | @jit(nopython=True)
103 | def get_time_indices_offsets(
104 | time_array: np.ndarray,
105 | time_start_us: int,
106 | time_end_us: int) -> Tuple[int, int]:
107 | """Compute index offset of start and end timestamps in microseconds
108 | Parameters
109 | ----------
110 | time_array: timestamps (in us) of the events
111 | time_start_us: start timestamp (in us)
112 | time_end_us: end timestamp (in us)
113 | Returns
114 | -------
115 | idx_start: Index within this array corresponding to time_start_us
116 | idx_end: Index within this array corresponding to time_end_us
117 | such that (in non-edge cases)
118 | time_array[idx_start] >= time_start_us
119 | time_array[idx_end] >= time_end_us
120 | time_array[idx_start - 1] < time_start_us
121 | time_array[idx_end - 1] < time_end_us
122 | this means that
123 | time_start_us <= time_array[idx_start:idx_end] < time_end_us
124 | """
125 |
126 | assert time_array.ndim == 1
127 |
128 | idx_start = -1
129 | if time_array[-1] < time_start_us:
130 | # This can happen in extreme corner cases. E.g.
131 | # time_array[0] = 1016
132 | # time_array[-1] = 1984
133 | # time_start_us = 1990
134 | # time_end_us = 2000
135 |
136 | # Return same index twice: array[x:x] is empty.
137 | return time_array.size, time_array.size
138 | else:
139 | for idx_from_start in range(0, time_array.size, 1):
140 | if time_array[idx_from_start] >= time_start_us:
141 | idx_start = idx_from_start
142 | break
143 | assert idx_start >= 0
144 |
145 | idx_end = time_array.size
146 | for idx_from_end in range(time_array.size - 1, -1, -1):
147 | if time_array[idx_from_end] >= time_end_us:
148 | idx_end = idx_from_end
149 | else:
150 | break
151 |
152 | assert time_array[idx_start] >= time_start_us
153 | if idx_end < time_array.size:
154 | assert time_array[idx_end] >= time_end_us
155 | if idx_start > 0:
156 | assert time_array[idx_start - 1] < time_start_us
157 | if idx_end > 0:
158 | assert time_array[idx_end - 1] < time_end_us
159 | return idx_start, idx_end
160 |
161 | def ms2idx(self, time_ms: int) -> int:
162 | assert time_ms >= 0
163 | if time_ms >= self.ms_to_idx.size:
164 | return None
165 | return self.ms_to_idx[time_ms]
--------------------------------------------------------------------------------
/pre-processing/raw_event_parsing.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pdb
4 | import torch
5 | from pathlib import Path
6 | import hdf5plugin
7 | import h5py
8 | from eventslicer import EventSlicer
9 | import torch
10 | import argparse
11 |
12 | height = 480
13 | width = 640
14 |
15 | def rectify_events(x: np.ndarray, y: np.ndarray, location: str, rectify_ev_maps):
16 | rectify_map = rectify_ev_maps[location]
17 | assert rectify_map.shape == (height, width, 2), rectify_map.shape
18 | assert x.max() < width
19 | assert y.max() < height
20 | return rectify_map[y, x]
21 |
22 |
23 | if __name__=='__main__':
24 | parser = argparse.ArgumentParser(description='Dataset Processing.')
25 | parser.add_argument('--dataset_path', help='Path to dataset', required=True)
26 | args = parser.parse_args()
27 |
28 | dataset_dir = args.dataset_path
29 |
30 | folder_list_all = os.listdir(dataset_dir)
31 | folder_list_all.sort()
32 |
33 | event_prefix = 'events'
34 |
35 | locations = ['left', 'right']
36 | delta_t_ms = 50
37 |
38 | delta_t_us = delta_t_ms * 1000
39 |
40 | save_dir = dataset_dir
41 |
42 | for folder_name in folder_list_all:
43 | print(folder_name)
44 | seq_path = Path(os.path.join(dataset_dir, folder_name))
45 |
46 | disp_dir = seq_path / 'disparity'
47 | assert disp_dir.is_dir()
48 |
49 | timestamps = np.loadtxt(disp_dir / 'timestamps.txt', dtype='int64')
50 |
51 | # load disparity paths
52 | ev_disp_dir = disp_dir / 'event'
53 | assert ev_disp_dir.is_dir()
54 | disp_gt_pathstrings = list()
55 | for entry in ev_disp_dir.iterdir():
56 | assert str(entry.name).endswith('.png')
57 | disp_gt_pathstrings.append(str(entry))
58 | disp_gt_pathstrings.sort()
59 |         disp_gt_pathstrings = disp_gt_pathstrings  # no-op; left over from removed slicing
60 | assert len(disp_gt_pathstrings) == timestamps.size
61 |
62 | event_dir = seq_path / event_prefix
63 | assert event_dir.is_dir()
64 |
65 |
66 | assert int(Path(disp_gt_pathstrings[0]).stem) == 0
67 | disp_gt_pathstrings.pop(0)
68 | timestamps = timestamps[1:]
69 |
70 | h5f = dict()
71 | rectify_ev_maps = dict()
72 | event_slicers = dict()
73 |
74 | ev_dir = seq_path / 'events'
75 |
76 | event_vox_save_dir = os.path.join(save_dir, folder_name, 'raw_events')
77 | if not os.path.exists(event_vox_save_dir):
78 | os.makedirs(event_vox_save_dir)
79 | else:
80 | continue
81 | if not os.path.exists(os.path.join(event_vox_save_dir, 'left')):
82 | os.makedirs(os.path.join(event_vox_save_dir, 'left'))
83 | if not os.path.exists(os.path.join(event_vox_save_dir, 'right')):
84 | os.makedirs(os.path.join(event_vox_save_dir, 'right'))
85 |
86 |
87 | for location in locations:
88 | ev_dir_location = ev_dir / location
89 | ev_data_file = ev_dir_location / 'events.h5'
90 | ev_rect_file = ev_dir_location / 'rectify_map.h5'
91 |
92 | h5f_location = h5py.File(str(ev_data_file), 'r')
93 | h5f[location] = h5f_location
94 | event_slicers[location] = EventSlicer(h5f_location)
95 | with h5py.File(str(ev_rect_file), 'r') as h5_rect:
96 | rectify_ev_maps[location] = h5_rect['rectify_map'][()]
97 |
98 |
99 |
100 |
101 |
102 |
103 | for index in range(len(timestamps)):
104 | ts_end = timestamps[index]
105 | # ts_start should be fine (within the window as we removed the first disparity map)
106 | ts_start = ts_end - delta_t_us
107 |
108 | for location in locations:
109 |                 event_data = event_slicers[location].get_events(ts_start, ts_end)  # may return None if the window falls outside ms_to_idx
110 |
111 | p = event_data['p']
112 | t = event_data['t']
113 |
114 | t = (t - t[0]).astype('uint32')
115 |
116 | x = event_data['x']
117 | y = event_data['y']
118 |
119 | xy_rect = rectify_events(x, y, location, rectify_ev_maps)
120 | x_rect = xy_rect[:, 0]
121 | y_rect = xy_rect[:, 1]
122 |
123 | event_index = (x_rect >= 0) & (x_rect < width) & (y_rect >= 0) & (y_rect < height)
124 |
125 | x_rect = x_rect[event_index]
126 | y_rect = y_rect[event_index]
127 | t = t[event_index]
128 | p = p[event_index]
129 |
130 |
131 | event_representation = np.stack([t, np.round(x_rect).astype('int'), np.round(y_rect).astype('int'), p], 1).astype('uint32')
132 |
133 |
134 | np.save(event_vox_save_dir + '/' + str(location) + '/' + disp_gt_pathstrings[index].split('/')[-1].replace('.png', '.npy'), event_representation)
135 |
136 |
137 |
--------------------------------------------------------------------------------
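The arrays written by this script are `uint32` with one row per event and columns `[t, x, y, p]`, where `t` is microseconds since the window start and `x`, `y` are rectified, rounded coordinates. A sketch for reading one back (the path is a placeholder):

```python
import numpy as np

raw = np.load('/home/data/DSEC/train/zurich_city_00_a/raw_events/left/000002.npy')
t, x, y, p = raw[:, 0], raw[:, 1], raw[:, 2], raw[:, 3]
print(raw.shape, t.max(), p.min(), p.max())  # t in us since window start; p in {0, 1}
```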
/pre-processing/voxel_generate_all.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pdb
4 | import torch
5 | from pathlib import Path
6 | import hdf5plugin
7 | import h5py
8 | from eventslicer import EventSlicer
9 | import argparse
10 |
11 | num_bins = 15
12 | height = 480
13 | width = 640
14 |
15 |
16 | def events_to_voxel_grid(x, y, p, t, num_bins=num_bins, width=width, height=height):
17 | t = (t - t[0]).astype('float32')
18 | t = (t/t[-1])
19 | x = x.astype('float32')
20 | y = y.astype('float32')
21 | pol = p.astype('float32')
22 |
23 | x = torch.from_numpy(x)
24 | y = torch.from_numpy(y)
25 | pol = torch.from_numpy(pol)
26 | time = torch.from_numpy(t)
27 |
28 | with torch.no_grad():
29 | voxel_grid = torch.zeros((num_bins, height, width), dtype=torch.float, requires_grad=False)
30 | C, H, W = voxel_grid.shape
31 | t_norm = time
32 | t_norm = (C - 1) * (t_norm-t_norm[0]) / (t_norm[-1]-t_norm[0])
33 |
34 | x0 = x.int()
35 | y0 = y.int()
36 | t0 = t_norm.int()
37 | if int(pol.min()) == -1:
38 | value = pol
39 | else:
40 | value = 2*pol-1
41 | # import pdb; pdb.set_trace()
42 | for xlim in [x0,x0+1]:
43 | for ylim in [y0,y0+1]:
44 | for tlim in [t0,t0+1]:
45 | mask = (xlim < W) & (xlim >= 0) & (ylim < H) & (ylim >= 0) & (tlim >= 0) & (tlim < num_bins)
46 | interp_weights = value * (1 - (xlim-x).abs()) * (1 - (ylim-y).abs()) * (1 - (tlim - t_norm).abs())
47 | index = H * W * tlim.long() + \
48 | W * ylim.long() + \
49 | xlim.long()
50 |
51 | voxel_grid.put_(index[mask], interp_weights[mask], accumulate=True)
52 |
53 | mask = torch.nonzero(voxel_grid, as_tuple=True)
54 | if mask[0].size()[0] > 0:
55 | mean = voxel_grid[mask].mean()
56 | std = voxel_grid[mask].std()
57 | if std > 0:
58 | voxel_grid[mask] = (voxel_grid[mask] - mean) / std
59 | else:
60 | voxel_grid[mask] = voxel_grid[mask] - mean
61 |
62 | return voxel_grid
63 |
64 | def rectify_events(x: np.ndarray, y: np.ndarray, location: str, rectify_ev_maps):
65 | rectify_map = rectify_ev_maps[location]
66 | assert rectify_map.shape == (height, width, 2), rectify_map.shape
67 | assert x.max() < width
68 | assert y.max() < height
69 | return rectify_map[y, x]
70 |
71 |
72 | if __name__=='__main__':
73 | parser = argparse.ArgumentParser(description='Dataset Processing.')
74 | parser.add_argument('--dataset_path', help='Path to dataset', required=True)
75 | args = parser.parse_args()
76 |
77 | dataset_dir = args.dataset_path
78 | save_dir = dataset_dir
79 |
80 | folder_list_all = os.listdir(dataset_dir)
81 | folder_list_all.sort()
82 |
83 | event_prefix = 'events'
84 |
85 | locations = ['left', 'right']
86 | delta_t_ms = 50
87 | delta_t_us = delta_t_ms * 1000
88 |
89 | for folder_name in folder_list_all:
90 |
91 | print(folder_name)
92 |
93 | seq_path = Path(os.path.join(dataset_dir, folder_name))
94 |
95 |
96 | disp_dir = seq_path / 'disparity'
97 | assert disp_dir.is_dir()
98 |
99 | # timestamps = np.loadtxt(disp_dir / 'timestamps.txt', dtype='int64')
100 |
101 | img_dir = seq_path / 'images'
102 | timestamps = np.loadtxt(img_dir / 'timestamps.txt', dtype='int64')
103 | assert img_dir.is_dir()
104 |
105 | img_left_dir = img_dir / 'left/rectified'
106 | img_pathstrings = list()
107 | for entry in img_left_dir.iterdir():
108 | assert str(entry.name).endswith('.png')
109 | img_pathstrings.append(str(entry))
110 | img_pathstrings.sort()
111 |
112 |
113 | # load disparity paths
114 | ev_disp_dir = disp_dir / 'event'
115 | assert ev_disp_dir.is_dir()
116 | disp_gt_pathstrings = list()
117 | for entry in ev_disp_dir.iterdir():
118 | assert str(entry.name).endswith('.png')
119 | disp_gt_pathstrings.append(str(entry))
120 | disp_gt_pathstrings.sort()
121 |         disp_gt_pathstrings = disp_gt_pathstrings  # no-op; left over from removed slicing
122 |
123 | # assert len(disp_gt_pathstrings) == timestamps.size
124 |
125 |
126 | event_dir = seq_path / event_prefix
127 | assert event_dir.is_dir()
128 |
129 |
130 | assert int(Path(disp_gt_pathstrings[0]).stem) == 0
131 | disp_gt_pathstrings.pop(0)
132 | timestamps = timestamps[1:]
133 | img_pathstrings = img_pathstrings[1:]
134 |
135 | assert len(disp_gt_pathstrings) == (timestamps.size //2)
136 | assert len(img_pathstrings) == timestamps.size
137 |
138 | h5f = dict()
139 | rectify_ev_maps = dict()
140 | event_slicers = dict()
141 |
142 | ev_dir = seq_path / 'events'
143 | for location in locations:
144 | ev_dir_location = ev_dir / location
145 | ev_data_file = ev_dir_location / 'events.h5'
146 | ev_rect_file = ev_dir_location / 'rectify_map.h5'
147 |
148 | h5f_location = h5py.File(str(ev_data_file), 'r')
149 | h5f[location] = h5f_location
150 | event_slicers[location] = EventSlicer(h5f_location)
151 | with h5py.File(str(ev_rect_file), 'r') as h5_rect:
152 | rectify_ev_maps[location] = h5_rect['rectify_map'][()]
153 |
154 |
155 |
156 | event_vox_save_dir = os.path.join(save_dir, folder_name, 'voxel_50ms_15bin')
157 | if not os.path.exists(event_vox_save_dir):
158 | os.makedirs(event_vox_save_dir)
159 | if not os.path.exists(os.path.join(event_vox_save_dir, 'left')):
160 | os.makedirs(os.path.join(event_vox_save_dir, 'left'))
161 | if not os.path.exists(os.path.join(event_vox_save_dir, 'right')):
162 | os.makedirs(os.path.join(event_vox_save_dir, 'right'))
163 |
164 |
165 | for index in range(len(timestamps)):
166 | ts_end = timestamps[index]
167 | # ts_start should be fine (within the window as we removed the first disparity map)
168 | ts_start = ts_end - delta_t_us
169 |
170 |             print(ts_end, event_vox_save_dir + '/' + str(location) + '/' + img_pathstrings[index].split('/')[-1].replace('.png', '.npy'))  # note: 'location' is stale here (set by the earlier setup loop)
171 |
172 | for location in locations:
173 | event_data = event_slicers[location].get_events(ts_start, ts_end)
174 |
175 | p = event_data['p']
176 | t = event_data['t']
177 | x = event_data['x']
178 | y = event_data['y']
179 |
180 | xy_rect = rectify_events(x, y, location, rectify_ev_maps)
181 | x_rect = xy_rect[:, 0]
182 | y_rect = xy_rect[:, 1]
183 |
184 | event_representation = events_to_voxel_grid(x_rect, y_rect, p, t)
185 | np.save(event_vox_save_dir + '/' + str(location) + '/' + img_pathstrings[index].split('/')[-1].replace('.png', '.npy')
186 | ,event_representation)
187 |
188 |
--------------------------------------------------------------------------------
/pre-processing/voxel_generate_test_all.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pdb
4 | import torch
5 | from pathlib import Path
6 | import hdf5plugin
7 | import h5py
8 | from eventslicer import EventSlicer
9 | import argparse
10 |
11 | num_bins = 15
12 | height = 480
13 | width = 640
14 |
15 |
16 | def events_to_voxel_grid(x, y, p, t, num_bins=num_bins, width=width, height=height):
17 | t = (t - t[0]).astype('float32')
18 | t = (t/t[-1])
19 | x = x.astype('float32')
20 | y = y.astype('float32')
21 | pol = p.astype('float32')
22 |
23 | x = torch.from_numpy(x)
24 | y = torch.from_numpy(y)
25 | pol = torch.from_numpy(pol)
26 | time = torch.from_numpy(t)
27 |
28 | with torch.no_grad():
29 | voxel_grid = torch.zeros((num_bins, height, width), dtype=torch.float, requires_grad=False)
30 | C, H, W = voxel_grid.shape
31 | t_norm = time
32 | t_norm = (C - 1) * (t_norm-t_norm[0]) / (t_norm[-1]-t_norm[0])
33 |
34 | x0 = x.int()
35 | y0 = y.int()
36 | t0 = t_norm.int()
37 | if int(pol.min()) == -1:
38 | value = pol
39 | else:
40 | value = 2*pol-1
41 | # import pdb; pdb.set_trace()
42 | for xlim in [x0,x0+1]:
43 | for ylim in [y0,y0+1]:
44 | for tlim in [t0,t0+1]:
45 | mask = (xlim < W) & (xlim >= 0) & (ylim < H) & (ylim >= 0) & (tlim >= 0) & (tlim < num_bins)
46 | interp_weights = value * (1 - (xlim-x).abs()) * (1 - (ylim-y).abs()) * (1 - (tlim - t_norm).abs())
47 | index = H * W * tlim.long() + \
48 | W * ylim.long() + \
49 | xlim.long()
50 |
51 | voxel_grid.put_(index[mask], interp_weights[mask], accumulate=True)
52 |
53 | mask = torch.nonzero(voxel_grid, as_tuple=True)
54 | if mask[0].size()[0] > 0:
55 | mean = voxel_grid[mask].mean()
56 | std = voxel_grid[mask].std()
57 | if std > 0:
58 | voxel_grid[mask] = (voxel_grid[mask] - mean) / std
59 | else:
60 | voxel_grid[mask] = voxel_grid[mask] - mean
61 |
62 | return voxel_grid
63 |
64 | def rectify_events(x: np.ndarray, y: np.ndarray, location: str, rectify_ev_maps):
65 | rectify_map = rectify_ev_maps[location]
66 | assert rectify_map.shape == (height, width, 2), rectify_map.shape
67 | assert x.max() < width
68 | assert y.max() < height
69 | return rectify_map[y, x]
70 |
71 |
72 | if __name__=='__main__':
73 | parser = argparse.ArgumentParser(description='Dataset Processing.')
74 | parser.add_argument('--dataset_path', help='Path to dataset', required=True)
75 | args = parser.parse_args()
76 |
77 | dataset_dir = args.dataset_path
78 | save_dir = dataset_dir
79 |
80 | folder_list_all = os.listdir(dataset_dir)
81 | folder_list_all.sort()
82 |
83 | event_prefix = 'events'
84 |
85 | locations = ['left', 'right']
86 | delta_t_ms = 50
87 | delta_t_us = delta_t_ms * 1000
88 |
89 | for folder_name in folder_list_all:
90 | print(folder_name)
91 |
92 | seq_path = Path(os.path.join(dataset_dir, folder_name))
93 |
94 |
95 | img_dir = seq_path / 'images'
96 | assert img_dir.is_dir()
97 |
98 | timestamps = np.loadtxt(img_dir / 'timestamps.txt', dtype='int64')
99 |
100 |
101 | img_left_dir = img_dir / 'left/rectified'
102 | img_pathstrings = list()
103 | for entry in img_left_dir.iterdir():
104 | assert str(entry.name).endswith('.png')
105 | img_pathstrings.append(str(entry))
106 | img_pathstrings.sort()
107 |
108 |
109 | # timestamps = timestamps[::2]
110 | timestamps = timestamps[1:]
111 | img_pathstrings = img_pathstrings[1:]
112 | assert len(img_pathstrings) == timestamps.size
113 |
114 | event_dir = seq_path / event_prefix
115 | assert event_dir.is_dir()
116 |
117 | h5f = dict()
118 | rectify_ev_maps = dict()
119 | event_slicers = dict()
120 |
121 | ev_dir = seq_path / 'events'
122 | for location in locations:
123 | ev_dir_location = ev_dir / location
124 | ev_data_file = ev_dir_location / 'events.h5'
125 | ev_rect_file = ev_dir_location / 'rectify_map.h5'
126 |
127 | h5f_location = h5py.File(str(ev_data_file), 'r')
128 | h5f[location] = h5f_location
129 | event_slicers[location] = EventSlicer(h5f_location)
130 | with h5py.File(str(ev_rect_file), 'r') as h5_rect:
131 | rectify_ev_maps[location] = h5_rect['rectify_map'][()]
132 |
133 |
134 | event_vox_save_dir = os.path.join(save_dir, folder_name, 'voxel_50ms_15bin')
135 | if not os.path.exists(event_vox_save_dir):
136 | os.makedirs(event_vox_save_dir)
137 | if not os.path.exists(os.path.join(event_vox_save_dir, 'left')):
138 | os.makedirs(os.path.join(event_vox_save_dir, 'left'))
139 | if not os.path.exists(os.path.join(event_vox_save_dir, 'right')):
140 | os.makedirs(os.path.join(event_vox_save_dir, 'right'))
141 |
142 |
143 | for index in range(len(timestamps)):
144 | # print(str(2 * index + 2).zfill(6))
145 | ts_end = timestamps[index]
146 | # ts_start should be fine (within the window as we removed the first disparity map)
147 | ts_start = ts_end - delta_t_us
148 |
149 |             print(ts_end, event_vox_save_dir + '/' + str(location) + '/' +  # note: 'location' is stale here (from the setup loop)
150 | img_pathstrings[index].split('/')[-1].replace('.png', '.npy'))
151 |
152 |
153 | for location in locations:
154 | event_data = event_slicers[location].get_events(ts_start, ts_end)
155 |
156 | p = event_data['p']
157 | t = event_data['t']
158 | x = event_data['x']
159 | y = event_data['y']
160 |
161 | xy_rect = rectify_events(x, y, location, rectify_ev_maps)
162 | x_rect = xy_rect[:, 0]
163 | y_rect = xy_rect[:, 1]
164 |
165 | event_representation = events_to_voxel_grid(x_rect, y_rect, p, t)
166 | np.save(event_vox_save_dir + '/' + str(location) + '/' + img_pathstrings[index].split('/')[-1].replace('.png', '.npy')
167 | , event_representation)
168 |
169 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | timm==0.4.12
2 | omegaconf
3 | opencv-python
4 | opt-einsum
5 | pafy
6 | pypng
7 | PyQt5
8 | pytorch-lightning==1.5.2
9 | PyYAML
10 | seaborn
11 | scikit-image
12 | scikit-learn
13 | tensorboard
14 | thop
15 | tqdm
16 | typing-extensions
17 | yacs
18 | debugpy
19 | pillow==9.0.1
20 | einops
21 | h5py
22 | numba
23 | hdf5plugin
24 | setuptools==59.5.0
25 | flow_vis
--------------------------------------------------------------------------------
/resource/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mickeykang16/TemporalEventStereo/d9dc74677f568dbdf2d1d06da20480c50410e274/resource/teaser.png
--------------------------------------------------------------------------------
/resource/temporal_stereo_demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mickeykang16/TemporalEventStereo/d9dc74677f568dbdf2d1d06da20480c50410e274/resource/temporal_stereo_demo.gif
--------------------------------------------------------------------------------
/run_tmux.sh:
--------------------------------------------------------------------------------
1 | tmux new-session -s gpu$GPUNUM 'tmux source-file ./.tmux.conf; $SHELL'
2 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .logger import Logger
2 | from .visualization import disp_err_to_color
3 | from .viz import save_matrix
4 | from .softsplat import softsplat
5 | from .warp import flow_warp, disp_warp
6 | from .flow_vis import flow_to_color
--------------------------------------------------------------------------------
/utils/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | from .pixel_error import calc_error
2 | from .eval import do_evaluation, do_occlusion_evaluation, do_evaluation_test
3 | from .flow_pixel_error import flow_calc_error
4 | from .flow_eval import do_flow_evaluation
5 | from .inverse_warp import inverse_warp
--------------------------------------------------------------------------------
/utils/evaluation/eval.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import torch
4 |
5 | from .inverse_warp import inverse_warp
6 | from .pixel_error import calc_error, calc_error_test
7 |
8 |
9 | def do_evaluation(est_disp, gt_disp, lb, ub):
10 | """
11 | Do pixel error evaluation. (See KITTI evaluation protocols for details.)
12 | Args:
13 | est_disp: (Tensor), estimated disparity map,
14 | [..., Height, Width] layout
15 | gt_disp: (Tensor), ground truth disparity map
16 | [..., Height, Width] layout
17 | lb: (scalar), the lower bound of disparity you want to mask out
18 | ub: (scalar), the upper bound of disparity you want to mask out
19 | Returns:
20 | error_dict: (dict), the error of 1px, 2px, 3px, 5px, in percent,
21 | range [0,100] and average error epe
22 | """
23 | error_dict = {}
24 | if est_disp is None:
25 | warnings.warn('Estimated disparity map is None')
26 | return error_dict
27 | if gt_disp is None:
28 | warnings.warn('Reference ground truth disparity map is None')
29 | return error_dict
30 |
31 | if torch.is_tensor(est_disp):
32 | est_disp = est_disp.clone().cpu()
33 |
34 | if torch.is_tensor(gt_disp):
35 | gt_disp = gt_disp.clone().cpu()
36 |
37 | assert est_disp.shape == gt_disp.shape, "Estimated Disparity map with shape: {}, but GroundTruth Disparity map" \
38 | " with shape: {}".format(est_disp.shape, gt_disp.shape)
39 |
40 | error_dict = calc_error(est_disp, gt_disp, lb=lb, ub=ub)
41 |
42 | return error_dict
43 |
44 |
45 |
46 | def do_evaluation_test(est_disp, gt_disp, lb, ub):
47 | """
48 | Do pixel error evaluation. (See KITTI evaluation protocols for details.)
49 | Args:
50 | est_disp: (Tensor), estimated disparity map,
51 | [..., Height, Width] layout
52 | gt_disp: (Tensor), ground truth disparity map
53 | [..., Height, Width] layout
54 | lb: (scalar), the lower bound of disparity you want to mask out
55 | ub: (scalar), the upper bound of disparity you want to mask out
56 | Returns:
57 | error_dict: (dict), the error of 1px, 2px, 3px, 5px, in percent,
58 | range [0,100] and average error epe
59 | """
60 | error_dict = {}
61 | if est_disp is None:
62 | warnings.warn('Estimated disparity map is None')
63 | return error_dict
64 | if gt_disp is None:
65 | warnings.warn('Reference ground truth disparity map is None')
66 | return error_dict
67 |
68 | if torch.is_tensor(est_disp):
69 | est_disp = est_disp.clone().cpu()
70 |
71 | if torch.is_tensor(gt_disp):
72 | gt_disp = gt_disp.clone().cpu()
73 |
74 | assert est_disp.shape == gt_disp.shape, "Estimated Disparity map with shape: {}, but GroundTruth Disparity map" \
75 | " with shape: {}".format(est_disp.shape, gt_disp.shape)
76 |
77 | error_dict = calc_error_test(est_disp, gt_disp, lb=lb, ub=ub)
78 |
79 | return error_dict
80 |
81 |
82 | def do_occlusion_evaluation(est_disp, ref_gt_disp, target_gt_disp, lb, ub):
83 | """
84 |     Do occlusion evaluation.
85 | Args:
86 | est_disp: (Tensor), estimated disparity map
87 | [BatchSize, 1, Height, Width] layout
88 | ref_gt_disp: (Tensor), reference(left) ground truth disparity map
89 | [BatchSize, 1, Height, Width] layout
90 | target_gt_disp: (Tensor), target(right) ground truth disparity map,
91 | [BatchSize, 1, Height, Width] layout
92 | lb: (scalar): the lower bound of disparity you want to mask out
93 | ub: (scalar): the upper bound of disparity you want to mask out
94 |     Returns: error_dict (dict), calc_error metrics prefixed with 'occ_' (occluded) and 'noc_' (non-occluded regions)
95 |     """
96 | error_dict = {}
97 | if est_disp is None:
98 | warnings.warn('Estimated disparity map is None, expected given')
99 | return error_dict
100 | if ref_gt_disp is None:
101 | warnings.warn('Reference ground truth disparity map is None, expected given')
102 | return error_dict
103 | if target_gt_disp is None:
104 | warnings.warn('Target ground truth disparity map is None, expected given')
105 | return error_dict
106 |
107 | if torch.is_tensor(est_disp):
108 | est_disp = est_disp.clone().cpu()
109 | if torch.is_tensor(ref_gt_disp):
110 | ref_gt_disp = ref_gt_disp.clone().cpu()
111 | if torch.is_tensor(target_gt_disp):
112 | target_gt_disp = target_gt_disp.clone().cpu()
113 |
114 | assert est_disp.shape == ref_gt_disp.shape and target_gt_disp.shape == ref_gt_disp.shape, "{}, {}, {}".format(
115 | est_disp.shape, ref_gt_disp.shape, target_gt_disp.shape)
116 |
117 | warp_ref_gt_disp = inverse_warp(target_gt_disp.clone(), -ref_gt_disp.clone(), mode='disparity')
118 | theta = 1.0
119 | eps = 1e-6
120 | occlusion = (
121 | (torch.abs(warp_ref_gt_disp.clone() - ref_gt_disp.clone()) > theta) |
122 | (torch.abs(warp_ref_gt_disp.clone()) < eps)
123 | ).prod(dim=1, keepdim=True).type_as(ref_gt_disp)
124 | occlusion = occlusion.clamp(0, 1)
125 |
126 | occlusion_error_dict = calc_error(
127 | est_disp.clone() * occlusion.clone(),
128 | ref_gt_disp.clone() * occlusion.clone(),
129 | lb=lb, ub=ub
130 | )
131 | for key in occlusion_error_dict.keys():
132 | error_dict['occ_' + key] = occlusion_error_dict[key]
133 |
134 | not_occlusion = 1.0 - occlusion
135 | not_occlusion_error_dict = calc_error(
136 | est_disp.clone() * not_occlusion.clone(),
137 | ref_gt_disp.clone() * not_occlusion.clone(),
138 | lb=lb, ub=ub
139 | )
140 | for key in not_occlusion_error_dict.keys():
141 | error_dict['noc_' + key] = not_occlusion_error_dict[key]
142 |
143 | return error_dict
--------------------------------------------------------------------------------
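A toy usage sketch for `do_evaluation`; the maps are random, so the numbers are meaningless and the snippet only shows the call signature and the returned keys (per the docstring, pixel-error percentages plus `epe`):

```python
import torch
from utils.evaluation import do_evaluation

gt = torch.rand(1, 1, 64, 64) * 100   # fake disparity map in [0, 100)
est = gt + torch.randn_like(gt)       # noisy estimate
errors = do_evaluation(est, gt, lb=0, ub=192)
print({k: float(v) for k, v in errors.items()})  # e.g. '1px', '2px', '3px', '5px', 'epe'
```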
/utils/evaluation/flow_eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import warnings
3 |
4 | from .flow_pixel_error import flow_calc_error
5 |
6 | def do_flow_evaluation(est_flow, gt_flow, lb=0.0, ub=400, sparse=False):
7 | """
8 | Do pixel error evaluation. (See KITTI evaluation protocols for details.)
9 | Args:
10 | est_flow: (Tensor), estimated flow map
11 | [..., 2, Height, Width] layout
12 | gt_flow: (Tensor), ground truth flow map
13 | [..., 2, Height, Width] layout
14 |         lb: (scalar), the lower bound of flow magnitude you want to mask out
15 |         ub: (scalar), the upper bound of flow magnitude you want to mask out
16 | sparse: (bool), whether the given flow is sparse, default False
17 | Returns:
18 | error_dict (dict): the error of 1px, 2px, 3px, 5px, in percent,
19 | range [0,100] and average error epe
20 | """
21 | error_dict = {}
22 | if est_flow is None:
23 | warnings.warn('Estimated flow map is None')
24 | return error_dict
25 | if gt_flow is None:
26 | warnings.warn('Reference ground truth flow map is None')
27 | return error_dict
28 |
29 | if torch.is_tensor(est_flow):
30 | est_flow = est_flow.clone().cpu()
31 |
32 | if torch.is_tensor(gt_flow):
33 | gt_flow = gt_flow.clone().cpu()
34 |
35 |     error_dict = flow_calc_error(est_flow, gt_flow, lb=lb, ub=ub, sparse=sparse)
36 |
37 | return error_dict
38 |
--------------------------------------------------------------------------------
/utils/evaluation/flow_pixel_error.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 |
5 | def zero_mask(tensor, eps=1e-12):
6 |     mask = abs(tensor) < eps
7 |     return mask
8 |
9 | def flow_calc_error(est_flow=None, gt_flow=None, lb=0.0, ub=400, sparse=False):
10 | """
11 | Args:
12 | est_flow: (Tensor), estimated flow map
13 | [..., 2, Height, Width] layout
14 | gt_flow: (Tensor), ground truth flow map
15 | [..., 2, Height, Width] layout
16 |         lb: (scalar), the lower bound of flow magnitude you want to mask out
17 |         ub: (scalar), the upper bound of flow magnitude you want to mask out
18 | sparse: (bool), whether the given flow is sparse, default False
19 | Output:
20 | dict: the error of 1px, 2px, 3px, 5px, in percent,
21 | range [0,100] and average error epe
22 | """
23 | error1 = torch.Tensor([0.])
24 | error2 = torch.Tensor([0.])
25 | error3 = torch.Tensor([0.])
26 | error5 = torch.Tensor([0.])
27 | epe = torch.Tensor([0.])
28 |
29 | if (not torch.is_tensor(est_flow)) or (not torch.is_tensor(gt_flow)):
30 | return {
31 | '1px': error1 * 100,
32 | '2px': error2 * 100,
33 | '3px': error3 * 100,
34 | '5px': error5 * 100,
35 | 'epe': epe
36 | }
37 |
38 | assert torch.is_tensor(est_flow) and torch.is_tensor(gt_flow)
39 | assert est_flow.shape == gt_flow.shape
40 |
41 | est_flow = est_flow.clone().cpu()
42 | gt_flow = gt_flow.clone().cpu()
43 | if len(gt_flow.shape) == 3:
44 | gt_flow = gt_flow.unsqueeze(0)
45 | est_flow = est_flow.unsqueeze(0)
46 |
47 | assert gt_flow.shape[1] == 2, "flow should have horizontal and vertical dimension, " \
48 | "but got {}".format(gt_flow.shape[1])
49 |
50 | # [B, 1, H, W]
51 | gt_u, gt_v = gt_flow[:, 0:1, :, :], gt_flow[:, 1:2, :, :]
52 | est_u, est_v = est_flow[:, 0:1, :, :], est_flow[:, 1:2, :, :]
53 |
54 | # get valid mask
55 | # [B, 1, H, W]
56 | mask = torch.ones(gt_u.shape, dtype=torch.bool, device=gt_u.device)
57 | if sparse:
58 | mask = mask & (~(zero_mask(gt_u) & zero_mask(gt_v)))
59 | mask = mask & (~(torch.isnan(gt_u) | torch.isnan(gt_v)))
60 |
61 | rad = torch.sqrt(gt_u**2 + gt_v**2)
62 | mask = mask & (rad > lb) & (rad < ub)
63 |
64 | mask.detach_()
65 | if abs(mask.float().sum()) < 1.0:
66 | return {
67 | '1px': error1 * 100,
68 | '2px': error2 * 100,
69 | '3px': error3 * 100,
70 | '5px': error5 * 100,
71 | 'epe': epe
72 | }
73 |
74 |
75 |
76 | gt_u = gt_u[mask]
77 | gt_v = gt_v[mask]
78 | est_u = est_u[mask]
79 | est_v = est_v[mask]
80 |
81 | abs_error = torch.sqrt((gt_u - est_u)**2 + (gt_v - est_v)**2)
82 | total_num = mask.float().sum()
83 |
84 | error1 = torch.sum(torch.gt(abs_error, 1).float()) / total_num
85 | error2 = torch.sum(torch.gt(abs_error, 2).float()) / total_num
86 | error3 = torch.sum(torch.gt(abs_error, 3).float()) / total_num
87 | error5 = torch.sum(torch.gt(abs_error, 5).float()) / total_num
88 | epe = abs_error.float().mean()
89 |
90 | return {
91 | '1px': error1 * 100,
92 | '2px': error2 * 100,
93 | '3px': error3 * 100,
94 | '5px': error5 * 100,
95 | 'epe': epe
96 | }
--------------------------------------------------------------------------------
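The sparse flag above is what lets a sparse ground-truth map contribute only its valid pixels: entries where both flow components are (near-)zero or NaN are masked out before the statistics are taken. A minimal sketch (values are illustrative assumptions):

    import torch
    from utils.evaluation.flow_pixel_error import flow_calc_error

    gt = torch.zeros(1, 2, 4, 4)
    gt[:, :, 1, 1] = torch.tensor([3.0, 4.0])   # a single valid pixel, magnitude 5
    est = torch.zeros(1, 2, 4, 4)               # the estimate misses it entirely
    errors = flow_calc_error(est, gt, sparse=True)
    # computed over the one valid pixel: epe == 5.0, all outlier rates == 100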
/utils/evaluation/inverse_warp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def inverse_warp(img: torch.Tensor,
7 | motion: torch.Tensor,
8 | mode: str ='disparity',
9 | K: torch.Tensor = None,
10 | inv_K: torch.Tensor = None,
11 | T_target_to_source: torch.Tensor = None,
12 | interpolate_mode: str = 'bilinear',
13 | padding_mode: str = 'zeros',
14 | eps: float = 1e-7,
15 | output_all: bool = False):
16 | """
17 | sample the image pixel value from source image and project to target image space,
18 | Args:
19 | img: (Tensor): the source image (where to sample pixels)
20 | [BatchSize, C, Height, Width]
21 | motion: (Tensor): disparity/depth/flow map of the target image
22 | [BatchSize, 1, Height, Width]
23 | mode: (str): which kind of warp to perform, including: ['disparity', 'depth', 'flow']
24 | K: (Optional, Tensor): instrincs of camera
25 | [BatchSize, 3, 3] or [BatchSize, 4, 4]
26 | inv_K: (Optional, Tensor): invserse instrincs of camera
27 | [BatchSize, 3, 3] or [BatchSize, 4, 4]
28 | T_target_to_source: (Optional, Tensor): predicted transformation matrix from target image to source image frames
29 | [BatchSize, 4, 4]
30 | interpolate_mode: (str): interpolate mode when grid sample, default is bilinear
31 | padding_mode: (str): padding mode when grid sample, default is zero padding
32 | eps: (float): eplison value to avoid divide 0, default 1e-7
33 | output_all: (bool): if output all result during warp, default False
34 | Returns:
35 | projected_img: (Tensor): source image warped to the target image
36 | [BatchSize, C, Height, Width]
37 | output: (Optional, Dict): such as optical flow, flow mask, triangular depth, src_pixel_coord and so on
38 | """
39 | B, C, H, W = motion.shape
40 | device = motion.device
41 | dtype = motion.dtype
42 | output = {}
43 |
44 | if mode == 'disparity':
45 | assert C == 1, "Disparity map must be 1 channel, but {} got!".format(C)
46 | # [B, 2, H, W]
47 | pixel_coord = mesh_grid(B, H, W, device, dtype)
48 | X = pixel_coord[:, 0, :, :] + motion[:, 0]
49 | Y = pixel_coord[:, 1, :, :]
50 | elif mode == 'flow':
51 | assert C == 2, "Optical flow map must be 2 channel, but {} got!".format(C)
52 | # [B, 2, H, W]
53 | pixel_coord = mesh_grid(B, H, W, device, dtype)
54 | X = pixel_coord[:, 0, :, :] + motion[:, 0]
55 | Y = pixel_coord[:, 1, :, :] + motion[:, 1]
56 | elif mode == 'depth':
57 | assert C == 1, "Disparity map must be 1 channel, but {} got!".format(C)
58 | outs = project_to_3d(motion, K, inv_K, T_target_to_source, eps)
59 | output.update(outs)
60 | src_pixel_coord = outs['src_pixel_coord']
61 | X = src_pixel_coord[:, 0, :, :]
62 | Y = src_pixel_coord[:, 1, :, :]
63 |
64 | else:
65 | raise TypeError("Inverse warp only support [disparity, flow, depth] mode, but {} got".format(mode))
66 |
67 | X_norm = 2 * X / (W-1) - 1
68 | Y_norm = 2 * Y / (H-1) - 1
69 | # [B, H, W, 2]
70 | pixel_coord_norm = torch.stack((X_norm, Y_norm), dim=3)
71 |
72 | projected_img = F.grid_sample(img, pixel_coord_norm, mode=interpolate_mode, padding_mode=padding_mode, align_corners=True)
73 |
74 | if output_all:
75 | return projected_img, output
76 | else:
77 | return projected_img
78 |
79 |
80 | def mesh_grid(b, h, w, device, dtype=torch.float):
81 | """ construct pixel coordination in an image"""
82 | # [1, H, W] copy 0-width for h times : x coord
83 | x_range = torch.arange(0, w, device=device, dtype=dtype).view(1, 1, 1, w).expand(b, 1, h, w)
84 | # [1, H, W] copy 0-height for w times : y coord
85 | y_range = torch.arange(0, h, device=device, dtype=dtype).view(1, 1, h, 1).expand(b, 1, h, w)
86 |
87 | # [b, 2, h, w]
88 | pixel_coord = torch.cat((x_range, y_range), dim=1)
89 |
90 | return pixel_coord
91 |
92 | def project_to_3d(depth, K, inv_K=None, T_target_to_source:torch.Tensor=None, eps=1e-7):
93 | """
94 | project depth map to 3D space
95 | Args:
96 | depth: (Tensor): depth map(s), can be several depth maps concatenated at channel dimension
97 | [BatchSize, Channel, Height, Width]
98 | K: (Tensor): instrincs of camera
99 | [BatchSize, 3, 3] or [BatchSize, 4, 4]
100 | inv_K: (Optional, Tensor): invserse instrincs of camera
101 | [BatchSize, 3, 3] or [BatchSize, 4, 4]
102 | T_target_to_source: (Optional, Tensor): predicted transformation matrix from target image to source image frames
103 | [BatchSize, 4, 4]
104 | eps: (float): eplison value to avoid divide 0, default 1e-7
105 |
106 | Returns: Dict including
107 | homo_points_3d: (Tensor): the homogeneous points after depth project to 3D space, [x, y, z, 1]
108 | [BatchSize, 4, Channel*Height*Width]
109 | if T_target_to_source provided:
110 |
111 | triangular_depth: (Tensor): the depth map after the 3D points project to source camera
112 | [BatchSize, Channel, Height, Width]
113 | optical_flow: (Tensor): by 3D projection, the rigid flow can be got
114 | [BatchSize, Channel*2, Height, Width], to get the 2nd flow, index like [:, 2:4, :, :]
115 | flow_mask: (Tensor): the mask indicates which pixel's optical flow is valid
116 | [BatchSize, Channel, Height, Width]
117 | """
118 |
119 |     # supports C >= 1; C > 1 means several depth maps concatenated along the channel dimension
120 | B, C, H, W = depth.size()
121 | device = depth.device
122 | dtype = depth.dtype
123 | output = {}
124 |
125 | # [B, 2, H, W]
126 | pixel_coord = mesh_grid(B, H, W, device, dtype)
127 | ones = torch.ones(B, 1, H, W, device=device, dtype=dtype)
128 |     # [B, 3, H, W], homogeneous coordinates of image pixels, [x, y, 1]
129 | homo_pixel_coord = torch.cat((pixel_coord, ones), dim=1).contiguous()
130 |
131 | # [B, 3, H*W] -> [B, 3, C*H*W]
132 | homo_pixel_coord = homo_pixel_coord.view(B, 3, -1).repeat(1, 1, C).contiguous()
133 | # [B, C*H*W] -> [B, 1, C*H*W]
134 | depth = depth.view(B, -1).unsqueeze(dim=1).contiguous()
135 | if inv_K is None:
136 | inv_K = torch.inverse(K[:, :3, :3])
137 | # [B, 3, C*H*W]
138 | points_3d = torch.matmul(inv_K[:, :3, :3], homo_pixel_coord) * depth
139 | ones = torch.ones(B, 1, C*H*W, device=device, dtype=dtype)
140 |     # [B, 4, C*H*W], homogeneous coordinates, [x, y, z, 1]
141 | homo_points_3d = torch.cat((points_3d, ones), dim=1)
142 | output['homo_points_3d'] = homo_points_3d
143 |
144 | if T_target_to_source is not None:
145 | if K.shape[-1] == 3:
146 | new_K = torch.eye(4, device=device, dtype=dtype).unsqueeze(dim=0).repeat(B, 1, 1)
147 | new_K[:, :3, :3] = K[:, :3, :3]
148 | # [B, 3, 4]
149 | P = torch.matmul(new_K, T_target_to_source)[:, :3, :]
150 | else:
151 | # [B, 3, 4]
152 | P = torch.matmul(K, T_target_to_source)[:, :3, :]
153 | # [B, 3, C*H*W]
154 | src_points_3d = torch.matmul(P, homo_points_3d)
155 |
156 | # [B, C*H*W] -> [B, C, H, W], the depth map after 3D points projected to source camera
157 | triangular_depth = src_points_3d[:, -1, :].reshape(B, C, H, W).contiguous()
158 | output['triangular_depth'] = triangular_depth
159 | # [B, 2, C*H*W]
160 | src_pixel_coord = src_points_3d[:, :2, :] / (src_points_3d[:, 2:3, :] + eps)
161 | # [B, 2, C, H, W] -> [B, C, 2, H, W]
162 | src_pixel_coord = src_pixel_coord.reshape(B, 2, C, H, W).permute(0, 2, 1, 3, 4).contiguous()
163 |
164 | # [B, C, 1, H, W]
165 | mask = (src_pixel_coord[:, :, 0:1] >=0) & (src_pixel_coord[:, :, 0:1] <= W-1) \
166 | & (src_pixel_coord[:, :, 1:2] >=0) & (src_pixel_coord[:, :, 1:2] <= H-1)
167 |
168 | # valid flow mask
169 | mask = mask.reshape(B, C, H, W).contiguous()
170 | output['flow_mask'] = mask
171 | # [B, C*2, H, W]
172 | src_pixel_coord = src_pixel_coord.reshape(B, C*2, H, W).contiguous()
173 | output['src_pixel_coord'] = src_pixel_coord
174 | # [B, C*2, H, W]
175 | optical_flow = src_pixel_coord - pixel_coord.repeat(1, C, 1, 1)
176 | output['optical_flow'] = optical_flow
177 |
178 | return output
--------------------------------------------------------------------------------
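With align_corners=True, the normalization X_norm = 2*X/(W-1) - 1 maps pixel 0 to -1 and pixel W-1 to +1, so a zero flow field reduces inverse_warp to an (almost exact) identity. A minimal sanity-check sketch, with illustrative shapes:

    import torch
    from utils.evaluation.inverse_warp import inverse_warp

    img = torch.rand(1, 3, 64, 64)     # source image [B, C, H, W] (illustrative)
    flow = torch.zeros(1, 2, 64, 64)   # zero flow -> identity warp
    warped = inverse_warp(img, flow, mode='flow')
    assert torch.allclose(warped, img, atol=1e-4)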
/utils/evaluation/pixel_error.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 |
5 | FOCAL_LENGTH_X_BASELINE = {
6 | 'indoor_flying': 19.941772,
7 | 'outdoor_night': 19.651191,
8 | 'outdoor_day': 19.635287
9 | }
10 |
11 |
12 | def disparity_to_depth(disparity_image):
13 |     """Convert disparity to depth via depth = (focal_length * baseline) / disparity;
14 |     note that this hardcodes the 'indoor_flying' calibration constant."""
15 | unknown_disparity = disparity_image == float('inf')
16 | depth_image = \
17 | FOCAL_LENGTH_X_BASELINE['indoor_flying'] / (
18 | disparity_image + 1e-7)
19 | depth_image[unknown_disparity] = float('inf')
20 | return depth_image
21 |
22 |
23 | def calc_error(est_disp=None, gt_disp=None, lb=None, ub=None):
24 | """
25 | Args:
26 | est_disp (Tensor): in [..., Height, Width] layout
27 | gt_disp (Tensor): in [..., Height, Width] layout
28 | lb (scalar): the lower bound of disparity you want to mask out
29 | ub (scalar): the upper bound of disparity you want to mask out
30 | Output:
31 | dict: the error of 1px, 2px, 3px, 5px, in percent,
32 | range [0,100] and average error epe
33 | """
34 |
35 |
36 | error1 = torch.Tensor([0.])
37 | error2 = torch.Tensor([0.])
38 | error3 = torch.Tensor([0.])
39 | error5 = torch.Tensor([0.])
40 | epe = torch.Tensor([0.])
41 |
42 | if (not torch.is_tensor(est_disp)) or (not torch.is_tensor(gt_disp)):
43 | return {
44 | '1px': error1 * 100,
45 | '2px': error2 * 100,
46 | '3px': error3 * 100,
47 | '5px': error5 * 100,
48 | 'epe': epe
49 | }
50 |
51 | assert torch.is_tensor(est_disp) and torch.is_tensor(gt_disp)
52 | assert est_disp.shape == gt_disp.shape
53 |
54 | est_disp = est_disp.clone().cpu()
55 | gt_disp = gt_disp.clone().cpu()
56 |
57 | mask = torch.ones(gt_disp.shape, dtype=torch.bool, device=gt_disp.device)
58 | if lb is not None:
59 | mask = mask & (gt_disp >= lb)
60 | if ub is not None:
61 | mask = mask & (gt_disp <= ub)
62 | # mask = mask & (est_disp <= ub)
63 |
64 | mask.detach_()
65 | if abs(mask.float().sum()) < 1.0:
66 | return {
67 | '1px': error1 * 100,
68 | '2px': error2 * 100,
69 | '3px': error3 * 100,
70 | '5px': error5 * 100,
71 | 'epe': epe
72 | }
73 | ## original error compute
74 | gt_disp = gt_disp[mask]
75 | est_disp = est_disp[mask]
76 | abs_error = torch.abs(gt_disp - est_disp)
77 |
78 | ## hh error compute for top k ###################
79 | # abs_error = torch.abs(gt_disp - est_disp) * mask
80 | # pix_sum = torch.gt(abs_error, 1)
81 | # # error_sum = abs_error.sum(1).sum(1)
82 | # error_sum = pix_sum.sum(1).sum(1)
83 | # mask_sum = mask.sum(1).sum(1)
84 |
85 | # per_img_error = error_sum / mask_sum
86 |
87 | # values, indexs = torch.topk(-per_img_error, 1343)
88 | # gt_disp = gt_disp[indexs, :, :]
89 | # est_disp = est_disp[indexs, :, :]
90 | # mask = mask[indexs, :, :]
91 | # abs_error = torch.abs(gt_disp - est_disp) * mask
92 | ###############################
93 |
94 |
95 | total_num = mask.float().sum()
96 |
97 | error1 = torch.sum(torch.gt(abs_error, 1).float()) / total_num
98 | error2 = torch.sum(torch.gt(abs_error, 2).float()) / total_num
99 | error3 = torch.sum(torch.gt(abs_error, 3).float()) / total_num
100 | error5 = torch.sum(torch.gt(abs_error, 5).float()) / total_num
101 | epe = abs_error.float().mean()
102 |
103 |     # .mean() yields a 0-dim tensor (torch.Size([])); wrapping it in torch.Tensor([...]) gives torch.Size([1])
104 | return {
105 | '1px': torch.Tensor([error1 * 100]),
106 | '2px': torch.Tensor([error2 * 100]),
107 | '3px': torch.Tensor([error3 * 100]),
108 | '5px': torch.Tensor([error5 * 100]),
109 | 'epe': torch.Tensor([epe]),
110 | }
111 |
112 |
113 | def calc_error_test(est_disp=None, gt_disp=None, lb=None, ub=None):
114 | """
115 | Args:
116 | est_disp (Tensor): in [..., Height, Width] layout
117 | gt_disp (Tensor): in [..., Height, Width] layout
118 | lb (scalar): the lower bound of disparity you want to mask out
119 | ub (scalar): the upper bound of disparity you want to mask out
120 | Output:
121 | dict: the error of 1px, 2px, 3px, 5px, in percent,
122 | range [0,100] and average error epe
123 | """
124 |
125 | error1 = torch.Tensor([0.])
126 | # mean_depth = torch.Tensor([0.])
127 | # median_depth = torch.Tensor([0.])
128 | epe = torch.Tensor([0.])
129 |
130 | if (not torch.is_tensor(est_disp)) or (not torch.is_tensor(gt_disp)):
131 | return {
132 | '1px': error1 * 100,
133 | 'epe': epe
134 | }
135 |
136 | assert torch.is_tensor(est_disp) and torch.is_tensor(gt_disp)
137 | assert est_disp.shape == gt_disp.shape
138 |
139 | est_disp = est_disp.clone().cpu()
140 | gt_disp = gt_disp.clone().cpu()
141 |
142 | gt_mask = gt_disp > ub
143 | nan_gt_mask = torch.isnan(gt_disp)
144 | gt_disp[gt_mask] = float('inf')
145 | gt_disp[nan_gt_mask] = float('inf')
146 |
147 | estimated_depth = disparity_to_depth(est_disp)
148 | ground_truth_depth = disparity_to_depth(gt_disp)
149 |
150 |
151 | mean_depth = compute_absolute_error(estimated_depth, ground_truth_depth)[1]
152 | median_depth = compute_absolute_error(estimated_depth, ground_truth_depth, use_mean=False)[1]
153 |
154 | # est_mask = est_disp > ub
155 | # est_disp[est_mask] = float('inf')
156 |
157 |
158 |
159 | binary_error_map, one_pixel_error = compute_n_pixels_error(est_disp, gt_disp, n=1.0)
160 |
161 |
162 | mean_disparity_error = compute_absolute_error(est_disp, gt_disp)[1]
163 |
164 |
165 |
166 | mask = torch.ones(gt_disp.shape, dtype=torch.bool, device=gt_disp.device)
167 | if lb is not None:
168 | mask = mask & (gt_disp >= lb)
169 | if ub is not None:
170 | mask = mask & (gt_disp <= ub)
171 | # mask = mask & (est_disp <= ub)
172 |
173 | mask.detach_()
174 | if abs(mask.float().sum()) < 1.0:
175 | return {
176 | '1px': one_pixel_error,
177 | 'mean_depth': mean_depth * 100,
178 | 'median_depth': median_depth * 100,
179 | 'epe': mean_disparity_error
180 | }
181 | ## original error compute
182 | gt_disp = gt_disp[mask]
183 | est_disp = est_disp[mask]
184 | abs_error = torch.abs(gt_disp - est_disp)
185 |
186 | ## hh error compute for top k ###################
187 | # abs_error = torch.abs(gt_disp - est_disp) * mask
188 | # pix_sum = torch.gt(abs_error, 1)
189 | # # error_sum = abs_error.sum(1).sum(1)
190 | # error_sum = pix_sum.sum(1).sum(1)
191 | # mask_sum = mask.sum(1).sum(1)
192 |
193 | # per_img_error = error_sum / mask_sum
194 |
195 | # values, indexs = torch.topk(-per_img_error, 1343)
196 | # gt_disp = gt_disp[indexs, :, :]
197 | # est_disp = est_disp[indexs, :, :]
198 | # mask = mask[indexs, :, :]
199 | # abs_error = torch.abs(gt_disp - est_disp) * mask
200 | ###############################
201 |
202 |
203 | total_num = mask.float().sum()
204 |
205 | error1 = torch.sum(torch.gt(abs_error, 1).float()) / total_num
206 |
207 | epe = abs_error.float().mean()
208 |
209 |     # note: error1/epe computed above are superseded by one_pixel_error/mean_disparity_error in the dict below
210 | return {
211 | '1px': one_pixel_error,
212 | 'mean_depth': mean_depth * 100,
213 | 'median_depth': median_depth * 100,
214 | 'epe': mean_disparity_error,
215 | }
216 |
217 |
218 | def compute_absolute_error(estimated_disparity,
219 | ground_truth_disparity,
220 | use_mean=True):
221 | """Returns pixel-wise and mean absolute error.
222 |
223 |     Locations where ground truth is not available do not contribute to the mean
224 |     absolute error. In such locations the pixel-wise error is shown as zero.
225 |     If ground truth is not available in any location, the function returns 0.
226 |
227 |     Args:
228 |         ground_truth_disparity: ground truth disparity where locations with
229 |             unknown disparity are set to inf's.
230 |         estimated_disparity: estimated disparity.
231 |         use_mean: if True then use the mean to average pixelwise errors,
232 | otherwise use median.
233 | """
234 | absolute_difference = (estimated_disparity - ground_truth_disparity).abs()
235 | locations_without_ground_truth = torch.isinf(ground_truth_disparity)
236 | pixelwise_absolute_error = absolute_difference.clone()
237 | pixelwise_absolute_error[locations_without_ground_truth] = 0
238 | absolute_differece_with_ground_truth = absolute_difference[
239 | ~locations_without_ground_truth]
240 | if absolute_differece_with_ground_truth.numel() == 0:
241 | average_absolute_error = 0.0
242 | else:
243 | if use_mean:
244 | average_absolute_error = absolute_differece_with_ground_truth.mean(
245 | ).item()
246 | else:
247 | average_absolute_error = absolute_differece_with_ground_truth.median(
248 | ).item()
249 | return pixelwise_absolute_error, average_absolute_error
250 |
251 |
252 |
253 | def compute_n_pixels_error(estimated_disparity, ground_truth_disparity, n=3.0):
254 | """Return pixel-wise n-pixels error and % of pixels with n-pixels error.
255 |
256 |     Locations where ground truth is not available do not contribute to the mean
257 |     n-pixel error. In such locations the pixel-wise error is shown as zero.
258 |
259 |     Note that n-pixel error is equal to one if
260 |     |estimated_disparity-ground_truth_disparity| > n and zero otherwise.
261 |
262 |     If ground truth is not available in any location, the function returns 0.
263 |
264 |     Args:
265 |         ground_truth_disparity: ground truth disparity where locations with
266 |             unknown disparity are set to inf's.
267 | estimated_disparity: estimated disparity.
268 | n: maximum absolute disparity difference, that does not trigger
269 | n-pixel error.
270 | """
271 | locations_without_ground_truth = torch.isinf(ground_truth_disparity)
272 | more_than_n_pixels_absolute_difference = (
273 | estimated_disparity - ground_truth_disparity).abs().gt(n).float()
274 | pixelwise_n_pixels_error = more_than_n_pixels_absolute_difference.clone()
275 | pixelwise_n_pixels_error[locations_without_ground_truth] = 0.0
276 | more_than_n_pixels_absolute_difference_with_ground_truth = \
277 | more_than_n_pixels_absolute_difference[~locations_without_ground_truth]
278 | if more_than_n_pixels_absolute_difference_with_ground_truth.numel() == 0:
279 | percentage_of_pixels_with_error = 0.0
280 | else:
281 | percentage_of_pixels_with_error = \
282 | more_than_n_pixels_absolute_difference_with_ground_truth.mean(
283 | ).item() * 100
284 | return pixelwise_n_pixels_error, percentage_of_pixels_with_error
--------------------------------------------------------------------------------
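A minimal sketch of the disparity metrics above; values are illustrative assumptions. Note that disparity_to_depth divides the hardcoded indoor_flying focal_length*baseline by the disparity, so a 2 px disparity maps to roughly 19.941772 / 2 ≈ 9.97 in depth units:

    import torch
    from utils.evaluation.pixel_error import calc_error, disparity_to_depth

    gt = torch.full((1, 60, 80), 10.0)   # [B, H, W] ground-truth disparity (illustrative)
    est = gt + 1.5                       # constant 1.5 px error
    errors = calc_error(est, gt, lb=0.0, ub=192)
    # errors['1px'] -> tensor([100.]), errors['2px'] -> tensor([0.]), errors['epe'] -> tensor([1.5])

    depth = disparity_to_depth(torch.tensor([2.0]))   # tensor([9.9709...])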
/utils/evaluation_old/__init__.py:
--------------------------------------------------------------------------------
1 | from .pixel_error import calc_error
2 | from .eval import do_evaluation, do_occlusion_evaluation
3 | from .flow_pixel_error import flow_calc_error
4 | from .flow_eval import do_flow_evaluation
5 | from .inverse_warp import inverse_warp
--------------------------------------------------------------------------------
/utils/evaluation_old/eval.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import torch
4 |
5 | from .inverse_warp import inverse_warp
6 | from .pixel_error import calc_error
7 |
8 |
9 | def do_evaluation(est_disp, gt_disp, lb, ub):
10 | """
11 | Do pixel error evaluation. (See KITTI evaluation protocols for details.)
12 | Args:
13 | est_disp: (Tensor), estimated disparity map,
14 | [..., Height, Width] layout
15 | gt_disp: (Tensor), ground truth disparity map
16 | [..., Height, Width] layout
17 | lb: (scalar), the lower bound of disparity you want to mask out
18 | ub: (scalar), the upper bound of disparity you want to mask out
19 | Returns:
20 | error_dict: (dict), the error of 1px, 2px, 3px, 5px, in percent,
21 | range [0,100] and average error epe
22 | """
23 | error_dict = {}
24 | if est_disp is None:
25 | warnings.warn('Estimated disparity map is None')
26 | return error_dict
27 | if gt_disp is None:
28 | warnings.warn('Reference ground truth disparity map is None')
29 | return error_dict
30 |
31 | if torch.is_tensor(est_disp):
32 | est_disp = est_disp.clone().cpu()
33 |
34 | if torch.is_tensor(gt_disp):
35 | gt_disp = gt_disp.clone().cpu()
36 |
37 | assert est_disp.shape == gt_disp.shape, "Estimated Disparity map with shape: {}, but GroundTruth Disparity map" \
38 | " with shape: {}".format(est_disp.shape, gt_disp.shape)
39 |
40 | error_dict = calc_error(est_disp, gt_disp, lb=lb, ub=ub)
41 |
42 | return error_dict
43 |
44 |
45 | def do_occlusion_evaluation(est_disp, ref_gt_disp, target_gt_disp, lb, ub):
46 | """
47 |     Do occlusion evaluation.
48 | Args:
49 | est_disp: (Tensor), estimated disparity map
50 | [BatchSize, 1, Height, Width] layout
51 | ref_gt_disp: (Tensor), reference(left) ground truth disparity map
52 | [BatchSize, 1, Height, Width] layout
53 | target_gt_disp: (Tensor), target(right) ground truth disparity map,
54 | [BatchSize, 1, Height, Width] layout
55 | lb: (scalar): the lower bound of disparity you want to mask out
56 | ub: (scalar): the upper bound of disparity you want to mask out
57 |     Returns: error_dict (dict) with errors computed separately on occluded ('occ_') and non-occluded ('noc_') regions
58 |     """
59 | error_dict = {}
60 |     if est_disp is None:
61 |         warnings.warn('Estimated disparity map is None, but one was expected')
62 |         return error_dict
63 |     if ref_gt_disp is None:
64 |         warnings.warn('Reference ground truth disparity map is None, but one was expected')
65 |         return error_dict
66 |     if target_gt_disp is None:
67 |         warnings.warn('Target ground truth disparity map is None, but one was expected')
68 |         return error_dict
69 |
70 | if torch.is_tensor(est_disp):
71 | est_disp = est_disp.clone().cpu()
72 | if torch.is_tensor(ref_gt_disp):
73 | ref_gt_disp = ref_gt_disp.clone().cpu()
74 | if torch.is_tensor(target_gt_disp):
75 | target_gt_disp = target_gt_disp.clone().cpu()
76 |
77 | assert est_disp.shape == ref_gt_disp.shape and target_gt_disp.shape == ref_gt_disp.shape, "{}, {}, {}".format(
78 | est_disp.shape, ref_gt_disp.shape, target_gt_disp.shape)
79 |
80 | warp_ref_gt_disp = inverse_warp(target_gt_disp.clone(), -ref_gt_disp.clone(), mode='disparity')
81 | theta = 1.0
82 | eps = 1e-6
83 | occlusion = (
84 | (torch.abs(warp_ref_gt_disp.clone() - ref_gt_disp.clone()) > theta) |
85 | (torch.abs(warp_ref_gt_disp.clone()) < eps)
86 | ).prod(dim=1, keepdim=True).type_as(ref_gt_disp)
87 | occlusion = occlusion.clamp(0, 1)
88 |
89 | occlusion_error_dict = calc_error(
90 | est_disp.clone() * occlusion.clone(),
91 | ref_gt_disp.clone() * occlusion.clone(),
92 | lb=lb, ub=ub
93 | )
94 | for key in occlusion_error_dict.keys():
95 | error_dict['occ_' + key] = occlusion_error_dict[key]
96 |
97 | not_occlusion = 1.0 - occlusion
98 | not_occlusion_error_dict = calc_error(
99 | est_disp.clone() * not_occlusion.clone(),
100 | ref_gt_disp.clone() * not_occlusion.clone(),
101 | lb=lb, ub=ub
102 | )
103 | for key in not_occlusion_error_dict.keys():
104 | error_dict['noc_' + key] = not_occlusion_error_dict[key]
105 |
106 | return error_dict
--------------------------------------------------------------------------------
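do_occlusion_evaluation builds its occlusion mask by warping the right ground-truth disparity into the left view with inverse_warp and flagging pixels whose warped disparity disagrees with the left ground truth by more than 1 px (or warps to zero); calc_error is then run separately on the occluded ('occ_' keys) and non-occluded ('noc_' keys) regions. A minimal sketch of the plain wrapper, with illustrative values:

    import torch
    from utils.evaluation_old.eval import do_evaluation

    gt = torch.full((1, 1, 60, 80), 10.0)   # [B, 1, H, W] ground-truth disparity (illustrative)
    est = gt + 2.5                          # constant 2.5 px error
    errors = do_evaluation(est, gt, lb=0.0, ub=192)
    # errors['2px'] -> tensor([100.]), errors['3px'] -> tensor([0.]), errors['epe'] -> tensor([2.5])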
/utils/evaluation_old/flow_eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import warnings
3 |
4 | from .flow_pixel_error import flow_calc_error
5 |
6 | def do_flow_evaluation(est_flow, gt_flow, lb=0.0, ub=400, sparse=False):
7 | """
8 | Do pixel error evaluation. (See KITTI evaluation protocols for details.)
9 | Args:
10 | est_flow: (Tensor), estimated flow map
11 | [..., 2, Height, Width] layout
12 | gt_flow: (Tensor), ground truth flow map
13 | [..., 2, Height, Width] layout
14 |         lb: (scalar), the lower bound of flow magnitude; smaller values are masked out
15 |         ub: (scalar), the upper bound of flow magnitude; larger values are masked out
16 | sparse: (bool), whether the given flow is sparse, default False
17 | Returns:
18 | error_dict (dict): the error of 1px, 2px, 3px, 5px, in percent,
19 | range [0,100] and average error epe
20 | """
21 | error_dict = {}
22 | if est_flow is None:
23 | warnings.warn('Estimated flow map is None')
24 | return error_dict
25 | if gt_flow is None:
26 | warnings.warn('Reference ground truth flow map is None')
27 | return error_dict
28 |
29 | if torch.is_tensor(est_flow):
30 | est_flow = est_flow.clone().cpu()
31 |
32 | if torch.is_tensor(gt_flow):
33 | gt_flow = gt_flow.clone().cpu()
34 |
35 |     error_dict = flow_calc_error(est_flow, gt_flow, lb=lb, ub=ub, sparse=sparse)
36 |
37 | return error_dict
38 |
--------------------------------------------------------------------------------
/utils/evaluation_old/flow_pixel_error.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 |
5 | def zero_mask(x, eps=1e-12):
6 |     mask = abs(x) < eps  # elementwise mask of (near-)zero entries
7 |     return mask
8 |
9 | def flow_calc_error(est_flow=None, gt_flow=None, lb=0.0, ub=400, sparse=False):
10 | """
11 | Args:
12 | est_flow: (Tensor), estimated flow map
13 | [..., 2, Height, Width] layout
14 | gt_flow: (Tensor), ground truth flow map
15 | [..., 2, Height, Width] layout
16 |         lb: (scalar), the lower bound of flow magnitude; smaller values are masked out
17 |         ub: (scalar), the upper bound of flow magnitude; larger values are masked out
18 | sparse: (bool), whether the given flow is sparse, default False
19 | Output:
20 | dict: the error of 1px, 2px, 3px, 5px, in percent,
21 | range [0,100] and average error epe
22 | """
23 | error1 = torch.Tensor([0.])
24 | error2 = torch.Tensor([0.])
25 | error3 = torch.Tensor([0.])
26 | error5 = torch.Tensor([0.])
27 | epe = torch.Tensor([0.])
28 |
29 | if (not torch.is_tensor(est_flow)) or (not torch.is_tensor(gt_flow)):
30 | return {
31 | '1px': error1 * 100,
32 | '2px': error2 * 100,
33 | '3px': error3 * 100,
34 | '5px': error5 * 100,
35 | 'epe': epe
36 | }
37 |
38 | assert torch.is_tensor(est_flow) and torch.is_tensor(gt_flow)
39 | assert est_flow.shape == gt_flow.shape
40 |
41 | est_flow = est_flow.clone().cpu()
42 | gt_flow = gt_flow.clone().cpu()
43 | if len(gt_flow.shape) == 3:
44 | gt_flow = gt_flow.unsqueeze(0)
45 | est_flow = est_flow.unsqueeze(0)
46 |
47 | assert gt_flow.shape[1] == 2, "flow should have horizontal and vertical dimension, " \
48 | "but got {}".format(gt_flow.shape[1])
49 |
50 | # [B, 1, H, W]
51 | gt_u, gt_v = gt_flow[:, 0:1, :, :], gt_flow[:, 1:2, :, :]
52 | est_u, est_v = est_flow[:, 0:1, :, :], est_flow[:, 1:2, :, :]
53 |
54 | # get valid mask
55 | # [B, 1, H, W]
56 | mask = torch.ones(gt_u.shape, dtype=torch.bool, device=gt_u.device)
57 | if sparse:
58 | mask = mask & (~(zero_mask(gt_u) & zero_mask(gt_v)))
59 | mask = mask & (~(torch.isnan(gt_u) | torch.isnan(gt_v)))
60 |
61 | rad = torch.sqrt(gt_u**2 + gt_v**2)
62 | mask = mask & (rad > lb) & (rad < ub)
63 |
64 | mask.detach_()
65 | if abs(mask.float().sum()) < 1.0:
66 | return {
67 | '1px': error1 * 100,
68 | '2px': error2 * 100,
69 | '3px': error3 * 100,
70 | '5px': error5 * 100,
71 | 'epe': epe
72 | }
73 |
74 |
75 |
76 | gt_u = gt_u[mask]
77 | gt_v = gt_v[mask]
78 | est_u = est_u[mask]
79 | est_v = est_v[mask]
80 |
81 | abs_error = torch.sqrt((gt_u - est_u)**2 + (gt_v - est_v)**2)
82 | total_num = mask.float().sum()
83 |
84 | error1 = torch.sum(torch.gt(abs_error, 1).float()) / total_num
85 | error2 = torch.sum(torch.gt(abs_error, 2).float()) / total_num
86 | error3 = torch.sum(torch.gt(abs_error, 3).float()) / total_num
87 | error5 = torch.sum(torch.gt(abs_error, 5).float()) / total_num
88 | epe = abs_error.float().mean()
89 |
90 | return {
91 | '1px': error1 * 100,
92 | '2px': error2 * 100,
93 | '3px': error3 * 100,
94 | '5px': error5 * 100,
95 | 'epe': epe
96 | }
--------------------------------------------------------------------------------
/utils/evaluation_old/inverse_warp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def inverse_warp(img: torch.Tensor,
7 | motion: torch.Tensor,
8 | mode: str ='disparity',
9 | K: torch.Tensor = None,
10 | inv_K: torch.Tensor = None,
11 | T_target_to_source: torch.Tensor = None,
12 | interpolate_mode: str = 'bilinear',
13 | padding_mode: str = 'zeros',
14 | eps: float = 1e-7,
15 | output_all: bool = False):
16 | """
17 | sample the image pixel value from source image and project to target image space,
18 | Args:
19 | img: (Tensor): the source image (where to sample pixels)
20 | [BatchSize, C, Height, Width]
21 | motion: (Tensor): disparity/depth/flow map of the target image
22 | [BatchSize, 1, Height, Width]
23 | mode: (str): which kind of warp to perform, including: ['disparity', 'depth', 'flow']
24 | K: (Optional, Tensor): instrincs of camera
25 | [BatchSize, 3, 3] or [BatchSize, 4, 4]
26 | inv_K: (Optional, Tensor): invserse instrincs of camera
27 | [BatchSize, 3, 3] or [BatchSize, 4, 4]
28 | T_target_to_source: (Optional, Tensor): predicted transformation matrix from target image to source image frames
29 | [BatchSize, 4, 4]
30 | interpolate_mode: (str): interpolate mode when grid sample, default is bilinear
31 | padding_mode: (str): padding mode when grid sample, default is zero padding
32 | eps: (float): eplison value to avoid divide 0, default 1e-7
33 | output_all: (bool): if output all result during warp, default False
34 | Returns:
35 | projected_img: (Tensor): source image warped to the target image
36 | [BatchSize, C, Height, Width]
37 | output: (Optional, Dict): such as optical flow, flow mask, triangular depth, src_pixel_coord and so on
38 | """
39 | B, C, H, W = motion.shape
40 | device = motion.device
41 | dtype = motion.dtype
42 | output = {}
43 |
44 | if mode == 'disparity':
45 | assert C == 1, "Disparity map must be 1 channel, but {} got!".format(C)
46 | # [B, 2, H, W]
47 | pixel_coord = mesh_grid(B, H, W, device, dtype)
48 | X = pixel_coord[:, 0, :, :] + motion[:, 0]
49 | Y = pixel_coord[:, 1, :, :]
50 | elif mode == 'flow':
51 | assert C == 2, "Optical flow map must be 2 channel, but {} got!".format(C)
52 | # [B, 2, H, W]
53 | pixel_coord = mesh_grid(B, H, W, device, dtype)
54 | X = pixel_coord[:, 0, :, :] + motion[:, 0]
55 | Y = pixel_coord[:, 1, :, :] + motion[:, 1]
56 | elif mode == 'depth':
57 | assert C == 1, "Disparity map must be 1 channel, but {} got!".format(C)
58 | outs = project_to_3d(motion, K, inv_K, T_target_to_source, eps)
59 | output.update(outs)
60 | src_pixel_coord = outs['src_pixel_coord']
61 | X = src_pixel_coord[:, 0, :, :]
62 | Y = src_pixel_coord[:, 1, :, :]
63 |
64 | else:
65 | raise TypeError("Inverse warp only support [disparity, flow, depth] mode, but {} got".format(mode))
66 |
67 | X_norm = 2 * X / (W-1) - 1
68 | Y_norm = 2 * Y / (H-1) - 1
69 | # [B, H, W, 2]
70 | pixel_coord_norm = torch.stack((X_norm, Y_norm), dim=3)
71 |
72 | projected_img = F.grid_sample(img, pixel_coord_norm, mode=interpolate_mode, padding_mode=padding_mode, align_corners=True)
73 |
74 | if output_all:
75 | return projected_img, output
76 | else:
77 | return projected_img
78 |
79 |
80 | def mesh_grid(b, h, w, device, dtype=torch.float):
81 | """ construct pixel coordination in an image"""
82 | # [1, H, W] copy 0-width for h times : x coord
83 | x_range = torch.arange(0, w, device=device, dtype=dtype).view(1, 1, 1, w).expand(b, 1, h, w)
84 | # [1, H, W] copy 0-height for w times : y coord
85 | y_range = torch.arange(0, h, device=device, dtype=dtype).view(1, 1, h, 1).expand(b, 1, h, w)
86 |
87 | # [b, 2, h, w]
88 | pixel_coord = torch.cat((x_range, y_range), dim=1)
89 |
90 | return pixel_coord
91 |
92 | def project_to_3d(depth, K, inv_K=None, T_target_to_source:torch.Tensor=None, eps=1e-7):
93 | """
94 | project depth map to 3D space
95 | Args:
96 | depth: (Tensor): depth map(s), can be several depth maps concatenated at channel dimension
97 | [BatchSize, Channel, Height, Width]
98 | K: (Tensor): instrincs of camera
99 | [BatchSize, 3, 3] or [BatchSize, 4, 4]
100 | inv_K: (Optional, Tensor): invserse instrincs of camera
101 | [BatchSize, 3, 3] or [BatchSize, 4, 4]
102 | T_target_to_source: (Optional, Tensor): predicted transformation matrix from target image to source image frames
103 | [BatchSize, 4, 4]
104 | eps: (float): eplison value to avoid divide 0, default 1e-7
105 |
106 | Returns: Dict including
107 | homo_points_3d: (Tensor): the homogeneous points after depth project to 3D space, [x, y, z, 1]
108 | [BatchSize, 4, Channel*Height*Width]
109 | if T_target_to_source provided:
110 |
111 | triangular_depth: (Tensor): the depth map after the 3D points project to source camera
112 | [BatchSize, Channel, Height, Width]
113 | optical_flow: (Tensor): by 3D projection, the rigid flow can be got
114 | [BatchSize, Channel*2, Height, Width], to get the 2nd flow, index like [:, 2:4, :, :]
115 | flow_mask: (Tensor): the mask indicates which pixel's optical flow is valid
116 | [BatchSize, Channel, Height, Width]
117 | """
118 |
119 |     # supports C >= 1; C > 1 means several depth maps concatenated along the channel dimension
120 | B, C, H, W = depth.size()
121 | device = depth.device
122 | dtype = depth.dtype
123 | output = {}
124 |
125 | # [B, 2, H, W]
126 | pixel_coord = mesh_grid(B, H, W, device, dtype)
127 | ones = torch.ones(B, 1, H, W, device=device, dtype=dtype)
128 |     # [B, 3, H, W], homogeneous coordinates of image pixels, [x, y, 1]
129 | homo_pixel_coord = torch.cat((pixel_coord, ones), dim=1).contiguous()
130 |
131 | # [B, 3, H*W] -> [B, 3, C*H*W]
132 | homo_pixel_coord = homo_pixel_coord.view(B, 3, -1).repeat(1, 1, C).contiguous()
133 | # [B, C*H*W] -> [B, 1, C*H*W]
134 | depth = depth.view(B, -1).unsqueeze(dim=1).contiguous()
135 | if inv_K is None:
136 | inv_K = torch.inverse(K[:, :3, :3])
137 | # [B, 3, C*H*W]
138 | points_3d = torch.matmul(inv_K[:, :3, :3], homo_pixel_coord) * depth
139 | ones = torch.ones(B, 1, C*H*W, device=device, dtype=dtype)
140 |     # [B, 4, C*H*W], homogeneous coordinates, [x, y, z, 1]
141 | homo_points_3d = torch.cat((points_3d, ones), dim=1)
142 | output['homo_points_3d'] = homo_points_3d
143 |
144 | if T_target_to_source is not None:
145 | if K.shape[-1] == 3:
146 | new_K = torch.eye(4, device=device, dtype=dtype).unsqueeze(dim=0).repeat(B, 1, 1)
147 | new_K[:, :3, :3] = K[:, :3, :3]
148 | # [B, 3, 4]
149 | P = torch.matmul(new_K, T_target_to_source)[:, :3, :]
150 | else:
151 | # [B, 3, 4]
152 | P = torch.matmul(K, T_target_to_source)[:, :3, :]
153 | # [B, 3, C*H*W]
154 | src_points_3d = torch.matmul(P, homo_points_3d)
155 |
156 | # [B, C*H*W] -> [B, C, H, W], the depth map after 3D points projected to source camera
157 | triangular_depth = src_points_3d[:, -1, :].reshape(B, C, H, W).contiguous()
158 | output['triangular_depth'] = triangular_depth
159 | # [B, 2, C*H*W]
160 | src_pixel_coord = src_points_3d[:, :2, :] / (src_points_3d[:, 2:3, :] + eps)
161 | # [B, 2, C, H, W] -> [B, C, 2, H, W]
162 | src_pixel_coord = src_pixel_coord.reshape(B, 2, C, H, W).permute(0, 2, 1, 3, 4).contiguous()
163 |
164 | # [B, C, 1, H, W]
165 | mask = (src_pixel_coord[:, :, 0:1] >=0) & (src_pixel_coord[:, :, 0:1] <= W-1) \
166 | & (src_pixel_coord[:, :, 1:2] >=0) & (src_pixel_coord[:, :, 1:2] <= H-1)
167 |
168 | # valid flow mask
169 | mask = mask.reshape(B, C, H, W).contiguous()
170 | output['flow_mask'] = mask
171 | # [B, C*2, H, W]
172 | src_pixel_coord = src_pixel_coord.reshape(B, C*2, H, W).contiguous()
173 | output['src_pixel_coord'] = src_pixel_coord
174 | # [B, C*2, H, W]
175 | optical_flow = src_pixel_coord - pixel_coord.repeat(1, C, 1, 1)
176 | output['optical_flow'] = optical_flow
177 |
178 | return output
--------------------------------------------------------------------------------
/utils/evaluation_old/pixel_error.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 |
5 |
6 | def calc_error(est_disp=None, gt_disp=None, lb=None, ub=None):
7 | """
8 | Args:
9 | est_disp (Tensor): in [..., Height, Width] layout
10 | gt_disp (Tensor): in [..., Height, Width] layout
11 | lb (scalar): the lower bound of disparity you want to mask out
12 | ub (scalar): the upper bound of disparity you want to mask out
13 | Output:
14 | dict: the error of 1px, 2px, 3px, 5px, in percent,
15 | range [0,100] and average error epe
16 | """
17 | error1 = torch.Tensor([0.])
18 | error2 = torch.Tensor([0.])
19 | error3 = torch.Tensor([0.])
20 | error5 = torch.Tensor([0.])
21 | epe = torch.Tensor([0.])
22 |
23 | if (not torch.is_tensor(est_disp)) or (not torch.is_tensor(gt_disp)):
24 | return {
25 | '1px': error1 * 100,
26 | '2px': error2 * 100,
27 | '3px': error3 * 100,
28 | '5px': error5 * 100,
29 | 'epe': epe
30 | }
31 |
32 | assert torch.is_tensor(est_disp) and torch.is_tensor(gt_disp)
33 | assert est_disp.shape == gt_disp.shape
34 |
35 | est_disp = est_disp.clone().cpu()
36 | gt_disp = gt_disp.clone().cpu()
37 |
38 | mask = torch.ones(gt_disp.shape, dtype=torch.bool, device=gt_disp.device)
39 | if lb is not None:
40 | mask = mask & (gt_disp >= lb)
41 | if ub is not None:
42 | mask = mask & (gt_disp <= ub)
43 | # mask = mask & (est_disp <= ub)
44 |
45 | mask.detach_()
46 | if abs(mask.float().sum()) < 1.0:
47 | return {
48 | '1px': error1 * 100,
49 | '2px': error2 * 100,
50 | '3px': error3 * 100,
51 | '5px': error5 * 100,
52 | 'epe': epe
53 | }
54 | ## original error compute
55 | gt_disp = gt_disp[mask]
56 | est_disp = est_disp[mask]
57 | abs_error = torch.abs(gt_disp - est_disp)
58 |
59 | ## hh error compute for top k ###################
60 | # abs_error = torch.abs(gt_disp - est_disp) * mask
61 | # pix_sum = torch.gt(abs_error, 1)
62 | # # error_sum = abs_error.sum(1).sum(1)
63 | # error_sum = pix_sum.sum(1).sum(1)
64 | # mask_sum = mask.sum(1).sum(1)
65 |
66 | # per_img_error = error_sum / mask_sum
67 |
68 | # values, indexs = torch.topk(-per_img_error, 1343)
69 | # gt_disp = gt_disp[indexs, :, :]
70 | # est_disp = est_disp[indexs, :, :]
71 | # mask = mask[indexs, :, :]
72 | # abs_error = torch.abs(gt_disp - est_disp) * mask
73 | ###############################
74 |
75 |
76 | total_num = mask.float().sum()
77 |
78 | error1 = torch.sum(torch.gt(abs_error, 1).float()) / total_num
79 | error2 = torch.sum(torch.gt(abs_error, 2).float()) / total_num
80 | error3 = torch.sum(torch.gt(abs_error, 3).float()) / total_num
81 | error5 = torch.sum(torch.gt(abs_error, 5).float()) / total_num
82 | epe = abs_error.float().mean()
83 |
84 |     # .mean() yields a 0-dim tensor (torch.Size([])); wrapping it in torch.Tensor([...]) gives torch.Size([1])
85 | return {
86 | '1px': torch.Tensor([error1 * 100]),
87 | '2px': torch.Tensor([error2 * 100]),
88 | '3px': torch.Tensor([error3 * 100]),
89 | '5px': torch.Tensor([error5 * 100]),
90 | 'epe': torch.Tensor([epe]),
91 | }
--------------------------------------------------------------------------------
/utils/flow.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import torch
5 | import numpy as np
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | parent_dir_name = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
10 | sys.path.append(parent_dir_name)
11 |
12 | from utils.iwe import purge_unfeasible, get_interpolation, interpolate
13 |
14 |
15 | class EventWarping(nn.Module):
16 | """
17 | Contrast maximization loss, as described in Section 3.2 of the paper 'Unsupervised Event-based Learning
18 | of Optical Flow, Depth, and Egomotion', Zhu et al., CVPR'19.
19 | The contrast maximization loss is the minimization of the per-pixel and per-polarity image of averaged
20 | timestamps of the input events after they have been compensated for their motion using the estimated
21 | optical flow. This minimization is performed in a forward and in a backward fashion to prevent scaling
22 | issues during backpropagation.
23 | """
24 |
25 | def __init__(self, res, config):
26 | super(EventWarping, self).__init__()
27 | self.res = res
28 | self.flow_scaling = 1.0
29 | # self.flow_scaling = max(res)
30 | self.weight = config.get("flow_regul_weight", 1.0)
31 | # self.weight = 0.0
32 |
33 |
34 | def forward(self, flow_list, event_list, pol_mask):
35 | """
36 | :param flow_list: [[batch_size x 2 x H x W]] list of optical flow maps
37 |         :param event_list: [batch_size x N x 4] input events (ts, y, x, p), timestamps normalized to [0, 1]
38 | :param pol_mask: [batch_size x N x 2] per-polarity binary mask of the input events
39 | """
40 | # |-------------->x
41 | # |
42 | # |
43 | # |
44 | # V
45 | # y
46 |         event_list[:, :, 0:1] = (1 - event_list[:, :, 0:1])  # reverse the normalized timestamps in place
47 |
48 |
49 |
50 |
51 | # import pdb; pdb.set_trace()
52 |
53 |         # replicate per-event data for the 4 bilinear corners produced by get_interpolation
54 |         pol_mask = torch.cat([pol_mask for i in range(4)], dim=1)
55 |         ts_list = torch.cat([event_list[:, :, 0:1] for i in range(4)], dim=1)
56 |
57 | # flow vector per input event
58 | flow_idx = event_list[:, :, 1:3].clone()
59 |         flow_idx[:, :, 0] *= self.res[1]  # linear index y * W + x; torch.view is row-major
60 | flow_idx = torch.sum(flow_idx, dim=2)
61 |
62 |
63 | loss = 0
64 | for flow in flow_list:
65 | flow = flow.contiguous()
66 | # get flow for every event in the list
67 | flow = flow.view(flow.shape[0], 2, -1)
68 | event_flowy = torch.gather(flow[:, 1, :], 1, flow_idx.long()) # vertical component
69 | event_flowx = torch.gather(flow[:, 0, :], 1, flow_idx.long()) # horizontal component
70 | # event_flowy = torch.gather(flow[:, 0, :], 1, flow_idx.long()) # vertical component
71 | # event_flowx = torch.gather(flow[:, 1, :], 1, flow_idx.long()) # horizontal component
72 |
73 | event_flowy = event_flowy.view(event_flowy.shape[0], event_flowy.shape[1], 1)
74 | event_flowx = event_flowx.view(event_flowx.shape[0], event_flowx.shape[1], 1)
75 | event_flow = torch.cat([event_flowy, event_flowx], dim=2)
76 |             # [B x N x 2] per-event flow, (y, x) order
77 | # interpolate forward
78 | tref = 1
79 | fw_idx, fw_weights = get_interpolation(event_list, event_flow, tref, self.res, self.flow_scaling)
80 |
81 | # per-polarity image of (forward) warped events
82 | fw_iwe_pos = interpolate(fw_idx.long(), fw_weights, self.res, polarity_mask=pol_mask[:, :, 0:1])
83 | fw_iwe_neg = interpolate(fw_idx.long(), fw_weights, self.res, polarity_mask=pol_mask[:, :, 1:2])
84 |
85 | # image of (forward) warped averaged timestamps
86 | fw_iwe_pos_ts = interpolate(
87 | fw_idx.long(), fw_weights * ts_list, self.res, polarity_mask=pol_mask[:, :, 0:1]
88 | )
89 | fw_iwe_neg_ts = interpolate(
90 | fw_idx.long(), fw_weights * ts_list, self.res, polarity_mask=pol_mask[:, :, 1:2]
91 | )
92 | fw_iwe_pos_ts /= fw_iwe_pos + 1e-9
93 | fw_iwe_neg_ts /= fw_iwe_neg + 1e-9
94 |
95 | # interpolate backward
96 | tref = 0
97 | bw_idx, bw_weights = get_interpolation(event_list, event_flow, tref, self.res, self.flow_scaling)
98 |
99 | # per-polarity image of (backward) warped events
100 | bw_iwe_pos = interpolate(bw_idx.long(), bw_weights, self.res, polarity_mask=pol_mask[:, :, 0:1])
101 | bw_iwe_neg = interpolate(bw_idx.long(), bw_weights, self.res, polarity_mask=pol_mask[:, :, 1:2])
102 |
103 | # image of (backward) warped averaged timestamps
104 | bw_iwe_pos_ts = interpolate(
105 | bw_idx.long(), bw_weights * (1 - ts_list), self.res, polarity_mask=pol_mask[:, :, 0:1]
106 | )
107 | bw_iwe_neg_ts = interpolate(
108 | bw_idx.long(), bw_weights * (1 - ts_list), self.res, polarity_mask=pol_mask[:, :, 1:2]
109 | )
110 | bw_iwe_pos_ts /= bw_iwe_pos + 1e-9
111 | bw_iwe_neg_ts /= bw_iwe_neg + 1e-9
112 |
113 | # flow smoothing
114 | flow = flow.view(flow.shape[0], 2, self.res[0], self.res[1])
115 |
116 | flow = flow / max(self.res)
117 |
118 | flow_dx = flow[:, :, :-1, :] - flow[:, :, 1:, :]
119 | flow_dy = flow[:, :, :, :-1] - flow[:, :, :, 1:]
120 | flow_dx = torch.sqrt(flow_dx ** 2 + 1e-6) # charbonnier
121 | flow_dy = torch.sqrt(flow_dy ** 2 + 1e-6) # charbonnier
122 | loss += (
123 | torch.sum(fw_iwe_pos_ts ** 2)
124 | + torch.sum(fw_iwe_neg_ts ** 2)
125 | + torch.sum(bw_iwe_pos_ts ** 2)
126 | + torch.sum(bw_iwe_neg_ts ** 2)
127 | + self.weight * (flow_dx.sum() + flow_dy.sum())
128 | )
129 |
130 | return loss
131 |
132 |
133 | # class AveragedIWE(nn.Module):
134 | # """
135 | # Returns an image of the per-pixel and per-polarity average number of warped events given
136 | # an optical flow map.
137 | # """
138 |
139 | # def __init__(self, config, device):
140 | # super(AveragedIWE, self).__init__()
141 | # self.res = config["loader"]["resolution"]
142 | # self.flow_scaling = max(config["loader"]["resolution"])
143 | # self.batch_size = config["loader"]["batch_size"]
144 | # self.device = device
145 |
146 | # def forward(self, flow, event_list, pol_mask):
147 | # """
148 | # :param flow: [batch_size x 2 x H x W] optical flow maps
149 | # :param event_list: [batch_size x N x 4] input events (y, x, ts, p)
150 | # :param pol_mask: [batch_size x N x 2] per-polarity binary mask of the input events
151 | # """
152 |
153 | # # original location of events
154 | # idx = event_list[:, :, 1:3].clone()
155 | # idx[:, :, 0] *= self.res[1] # torch.view is row-major
156 | # idx = torch.sum(idx, dim=2, keepdim=True)
157 |
158 | # # flow vector per input event
159 | # flow_idx = event_list[:, :, 1:3].clone()
160 | # flow_idx[:, :, 0] *= self.res[1] # torch.view is row-major
161 | # flow_idx = torch.sum(flow_idx, dim=2)
162 |
163 | # # get flow for every event in the list
164 | # flow = flow.view(flow.shape[0], 2, -1)
165 | # event_flowy = torch.gather(flow[:, 1, :], 1, flow_idx.long()) # vertical component
166 | # event_flowx = torch.gather(flow[:, 0, :], 1, flow_idx.long()) # horizontal component
167 | # event_flowy = event_flowy.view(event_flowy.shape[0], event_flowy.shape[1], 1)
168 | # event_flowx = event_flowx.view(event_flowx.shape[0], event_flowx.shape[1], 1)
169 | # event_flow = torch.cat([event_flowy, event_flowx], dim=2)
170 |
171 | # # interpolate forward
172 | # fw_idx, fw_weights = get_interpolation(event_list, event_flow, 1, self.res, self.flow_scaling, round_idx=True)
173 |
174 | # # per-polarity image of (forward) warped events
175 | # fw_iwe_pos = interpolate(fw_idx.long(), fw_weights, self.res, polarity_mask=pol_mask[:, :, 0:1])
176 | # fw_iwe_neg = interpolate(fw_idx.long(), fw_weights, self.res, polarity_mask=pol_mask[:, :, 1:2])
177 | # if fw_idx.shape[1] == 0:
178 | # return torch.cat([fw_iwe_pos, fw_iwe_neg], dim=1)
179 |
180 | # # make sure unfeasible mappings are not considered
181 | # pol_list = event_list[:, :, 3:4].clone()
182 | # pol_list[pol_list < 1] = 0 # negative polarity set to 0
183 | # pol_list[fw_weights == 0] = 2 # fake polarity to detect unfeasible mappings
184 |
185 | # # encode unique ID for pixel location mapping (idx <-> fw_idx = m_idx)
186 | # m_idx = torch.cat([idx.long(), fw_idx.long()], dim=2)
187 | # m_idx[:, :, 0] *= self.res[0] * self.res[1]
188 | # m_idx = torch.sum(m_idx, dim=2, keepdim=True)
189 |
190 | # # encode unique ID for per-polarity pixel location mapping (pol_list <-> m_idx = pm_idx)
191 | # pm_idx = torch.cat([pol_list.long(), m_idx.long()], dim=2)
192 | # pm_idx[:, :, 0] *= (self.res[0] * self.res[1]) ** 2
193 | # pm_idx = torch.sum(pm_idx, dim=2, keepdim=True)
194 |
195 | # # number of different pixels locations from where pixels originate during warping
196 | # # this needs to be done per batch as the number of unique indices differs
197 | # fw_iwe_pos_contrib = torch.zeros((flow.shape[0], self.res[0] * self.res[1], 1)).to(self.device)
198 | # fw_iwe_neg_contrib = torch.zeros((flow.shape[0], self.res[0] * self.res[1], 1)).to(self.device)
199 | # for b in range(0, self.batch_size):
200 |
201 | # # per-polarity unique mapping combinations
202 | # unique_pm_idx = torch.unique(pm_idx[b, :, :], dim=0)
203 | # unique_pm_idx = torch.cat(
204 | # [
205 | # unique_pm_idx // ((self.res[0] * self.res[1]) ** 2),
206 | # unique_pm_idx % ((self.res[0] * self.res[1]) ** 2),
207 | # ],
208 | # dim=1,
209 | # ) # (pol_idx, mapping_idx)
210 | # unique_pm_idx = torch.cat(
211 | # [unique_pm_idx[:, 0:1], unique_pm_idx[:, 1:2] % (self.res[0] * self.res[1])], dim=1
212 | # ) # (pol_idx, fw_idx)
213 | # unique_pm_idx[:, 0] *= self.res[0] * self.res[1]
214 | # unique_pm_idx = torch.sum(unique_pm_idx, dim=1, keepdim=True)
215 |
216 | # # per-polarity unique receiving pixels
217 | # unique_pfw_idx, contrib_pfw = torch.unique(unique_pm_idx[:, 0], dim=0, return_counts=True)
218 | # unique_pfw_idx = unique_pfw_idx.view((unique_pfw_idx.shape[0], 1))
219 | # contrib_pfw = contrib_pfw.view((contrib_pfw.shape[0], 1))
220 | # unique_pfw_idx = torch.cat(
221 | # [unique_pfw_idx // (self.res[0] * self.res[1]), unique_pfw_idx % (self.res[0] * self.res[1])],
222 | # dim=1,
223 | # ) # (polarity mask, fw_idx)
224 |
225 | # # positive scatter pixel contribution
226 | # mask_pos = unique_pfw_idx[:, 0:1].clone()
227 | # mask_pos[mask_pos == 2] = 0 # remove unfeasible mappings
228 | # b_fw_iwe_pos_contrib = torch.zeros((self.res[0] * self.res[1], 1)).to(self.device)
229 | # b_fw_iwe_pos_contrib = b_fw_iwe_pos_contrib.scatter_add_(
230 | # 0, unique_pfw_idx[:, 1:2], mask_pos.float() * contrib_pfw.float()
231 | # )
232 |
233 | # # negative scatter pixel contribution
234 | # mask_neg = unique_pfw_idx[:, 0:1].clone()
235 | # mask_neg[mask_neg == 2] = 1 # remove unfeasible mappings
236 | # mask_neg = 1 - mask_neg # invert polarities
237 | # b_fw_iwe_neg_contrib = torch.zeros((self.res[0] * self.res[1], 1)).to(self.device)
238 | # b_fw_iwe_neg_contrib = b_fw_iwe_neg_contrib.scatter_add_(
239 | # 0, unique_pfw_idx[:, 1:2], mask_neg.float() * contrib_pfw.float()
240 | # )
241 |
242 | # # store info
243 | # fw_iwe_pos_contrib[b, :, :] = b_fw_iwe_pos_contrib
244 | # fw_iwe_neg_contrib[b, :, :] = b_fw_iwe_neg_contrib
245 |
246 | # # average number of warped events per pixel
247 | # fw_iwe_pos_contrib = fw_iwe_pos_contrib.view((flow.shape[0], 1, self.res[0], self.res[1]))
248 | # fw_iwe_neg_contrib = fw_iwe_neg_contrib.view((flow.shape[0], 1, self.res[0], self.res[1]))
249 | # fw_iwe_pos[fw_iwe_pos_contrib > 0] /= fw_iwe_pos_contrib[fw_iwe_pos_contrib > 0]
250 | # fw_iwe_neg[fw_iwe_neg_contrib > 0] /= fw_iwe_neg_contrib[fw_iwe_neg_contrib > 0]
251 |
252 | # return torch.cat([fw_iwe_pos, fw_iwe_neg], dim=1)
253 |
--------------------------------------------------------------------------------
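A minimal sketch of driving EventWarping with synthetic events. The (ts, y, x, p) event layout with timestamps normalized to [0, 1] follows the indexing inside forward(); the config dict and all shapes/values below are illustrative assumptions:

    import torch
    from utils.flow import EventWarping

    H, W, N = 16, 16, 8
    loss_fn = EventWarping(res=(H, W), config={"flow_regul_weight": 1.0})

    ts = torch.rand(1, N, 1)                      # normalized timestamps in [0, 1]
    y = torch.randint(0, H, (1, N, 1)).float()    # integer pixel rows
    x = torch.randint(0, W, (1, N, 1)).float()    # integer pixel columns
    p = torch.randint(0, 2, (1, N, 1)).float() * 2 - 1   # polarity in {-1, +1}
    event_list = torch.cat([ts, y, x, p], dim=2)  # [B, N, 4] = (ts, y, x, p)
    pol_mask = torch.cat([(p > 0).float(), (p < 0).float()], dim=2)  # [B, N, 2]

    flow = torch.zeros(1, 2, H, W, requires_grad=True)
    loss = loss_fn([flow], event_list, pol_mask)
    loss.backward()   # the loss is differentiable w.r.t. the flow maps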
/utils/flow_vis.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Tom Runia
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to conditions.
11 | #
12 | # Author: Tom Runia
13 | # Date Created: 2018-08-03
14 |
15 | import numpy as np
16 |
17 | def make_colorwheel():
18 | """
19 | Generates a color wheel for optical flow visualization as presented in:
20 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
21 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
22 |
23 | Code follows the original C++ source code of Daniel Scharstein.
24 |     Code follows the Matlab source code of Deqing Sun.
25 |
26 | Returns:
27 | np.ndarray: Color wheel
28 | """
29 |
30 | RY = 15
31 | YG = 6
32 | GC = 4
33 | CB = 11
34 | BM = 13
35 | MR = 6
36 |
37 | ncols = RY + YG + GC + CB + BM + MR
38 | colorwheel = np.zeros((ncols, 3))
39 | col = 0
40 |
41 | # RY
42 | colorwheel[0:RY, 0] = 255
43 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
44 | col = col+RY
45 | # YG
46 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
47 | colorwheel[col:col+YG, 1] = 255
48 | col = col+YG
49 | # GC
50 | colorwheel[col:col+GC, 1] = 255
51 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
52 | col = col+GC
53 | # CB
54 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
55 | colorwheel[col:col+CB, 2] = 255
56 | col = col+CB
57 | # BM
58 | colorwheel[col:col+BM, 2] = 255
59 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
60 | col = col+BM
61 | # MR
62 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
63 | colorwheel[col:col+MR, 0] = 255
64 | return colorwheel
65 |
66 |
67 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
68 | """
69 | Applies the flow color wheel to (possibly clipped) flow components u and v.
70 |
71 | According to the C++ source code of Daniel Scharstein
72 | According to the Matlab source code of Deqing Sun
73 |
74 | Args:
75 | u (np.ndarray): Input horizontal flow of shape [H,W]
76 | v (np.ndarray): Input vertical flow of shape [H,W]
77 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
78 |
79 | Returns:
80 | np.ndarray: Flow visualization image of shape [H,W,3]
81 | """
82 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
83 | colorwheel = make_colorwheel() # shape [55x3]
84 | ncols = colorwheel.shape[0]
85 | rad = np.sqrt(np.square(u) + np.square(v))
86 | a = np.arctan2(-v, -u)/np.pi
87 | fk = (a+1) / 2*(ncols-1)
88 | k0 = np.floor(fk).astype(np.int32)
89 | k1 = k0 + 1
90 | k1[k1 == ncols] = 0
91 | f = fk - k0
92 | for i in range(colorwheel.shape[1]):
93 | tmp = colorwheel[:,i]
94 | col0 = tmp[k0] / 255.0
95 | col1 = tmp[k1] / 255.0
96 | col = (1-f)*col0 + f*col1
97 | idx = (rad <= 1)
98 | col[idx] = 1 - rad[idx] * (1-col[idx])
99 | col[~idx] = col[~idx] * 0.75 # out of range
100 | # Note the 2-i => BGR instead of RGB
101 | ch_idx = 2-i if convert_to_bgr else i
102 | flow_image[:,:,ch_idx] = np.floor(255 * col)
103 | return flow_image
104 |
105 |
106 | def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False):
107 | """
108 |     Expects a two dimensional flow image of shape [H,W,2].
109 |
110 | Args:
111 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
112 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
113 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
114 |
115 | Returns:
116 | np.ndarray: Flow visualization image of shape [H,W,3]
117 | """
118 | assert flow_uv.ndim == 3, 'input flow must have three dimensions'
119 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
120 | if clip_flow is not None:
121 | flow_uv = np.clip(flow_uv, -clip_flow, clip_flow)
122 | u = flow_uv[:,:,0]
123 | v = flow_uv[:,:,1]
124 | rad = np.sqrt(np.square(u) + np.square(v))
125 | rad_max = np.max(rad)
126 | if clip_flow is not None:
127 | rad_max = clip_flow
128 | epsilon = 1e-5
129 | u = u / (rad_max + epsilon)
130 | v = v / (rad_max + epsilon)
131 | return flow_uv_to_colors(u, v, convert_to_bgr)
--------------------------------------------------------------------------------
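A minimal sketch of flow_to_color on a synthetic flow field (the field itself is an illustrative assumption):

    import numpy as np
    from utils.flow_vis import flow_to_color

    h, w = 64, 64
    u, v = np.meshgrid(np.linspace(-1, 1, w), np.linspace(-1, 1, h))
    flow_uv = np.stack([u, v], axis=2).astype(np.float32)   # [H, W, 2]
    rgb = flow_to_color(flow_uv)   # uint8 [H, W, 3] color-wheel visualization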
/utils/iwe.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def purge_unfeasible(x, res):
5 | """
6 | Purge unfeasible event locations by setting their interpolation weights to zero.
7 | :param x: location of motion compensated events
8 | :param res: resolution of the image space
9 | :return masked indices
10 | :return mask for interpolation weights
11 | """
12 |
13 | mask = torch.ones((x.shape[0], x.shape[1], 1)).to(x.device)
14 | mask_y = (x[:, :, 0:1] < 0) + (x[:, :, 0:1] >= res[0])
15 | mask_x = (x[:, :, 1:2] < 0) + (x[:, :, 1:2] >= res[1])
16 |     mask[mask_y | mask_x] = 0  # logical OR of the out-of-bounds conditions
17 | return x * mask, mask
18 |
19 |
20 | def get_interpolation(events, flow, tref, res, flow_scaling, round_idx=False):
21 | """
22 |     Warp the input events according to the provided optical flow map and compute the bilinear interpolation
23 |     (or rounding) weights to distribute the events to the closest (integer) locations in the image space.
24 |     :param events: [batch_size x N x 4] input events (ts, y, x, p)
25 | :param flow: [batch_size x 2 x H x W] optical flow map
26 | :param tref: reference time toward which events are warped
27 | :param res: resolution of the image space
28 | :param flow_scaling: scalar that multiplies the optical flow map
29 | :param round_idx: whether or not to round the event locations instead of doing bilinear interp. (default = False)
30 | :return interpolated event indices
31 | :return interpolation weights
32 | """
33 |
34 | # event propagation
35 | warped_events = events[:, :, 1:3] + (tref - events[:, :, 0:1]) * flow * flow_scaling
36 | # warped_events = events[:, :, 1:3] - (tref - events[:, :, 0:1]) * flow * flow_scaling
37 |
38 | if round_idx:
39 |
40 | # no bilinear interpolation
41 | idx = torch.round(warped_events)
42 | weights = torch.ones(idx.shape).to(events.device)
43 |
44 | else:
45 |
46 | # get scattering indices
47 | top_y = torch.floor(warped_events[:, :, 0:1])
48 | bot_y = torch.floor(warped_events[:, :, 0:1] + 1)
49 | left_x = torch.floor(warped_events[:, :, 1:2])
50 | right_x = torch.floor(warped_events[:, :, 1:2] + 1)
51 |
52 | top_left = torch.cat([top_y, left_x], dim=2)
53 | top_right = torch.cat([top_y, right_x], dim=2)
54 | bottom_left = torch.cat([bot_y, left_x], dim=2)
55 | bottom_right = torch.cat([bot_y, right_x], dim=2)
56 | idx = torch.cat([top_left, top_right, bottom_left, bottom_right], dim=1)
57 |
58 | # get scattering interpolation weights
59 | warped_events = torch.cat([warped_events for i in range(4)], dim=1)
60 | zeros = torch.zeros(warped_events.shape).to(events.device)
61 | weights = torch.max(zeros, 1 - torch.abs(warped_events - idx))
62 |
63 | # purge unfeasible indices
64 | idx, mask = purge_unfeasible(idx, res)
65 |
66 | # make unfeasible weights zero
67 | weights = torch.prod(weights, dim=-1, keepdim=True) * mask # bilinear interpolation
68 |
69 | # prepare indices
70 | idx[:, :, 0] *= res[1] # torch.view is row-major
71 | idx = torch.sum(idx, dim=2, keepdim=True)
72 |
73 |
74 |
75 | return idx, weights
76 |
77 |
78 | def interpolate(idx, weights, res, polarity_mask=None):
79 | """
80 | Create an image-like representation of the warped events.
81 | :param idx: [batch_size x N x 1] warped event locations
82 | :param weights: [batch_size x N x 1] interpolation weights for the warped events
83 | :param res: resolution of the image space
84 | :param polarity_mask: [batch_size x N x 2] polarity mask for the warped events (default = None)
85 | :return image of warped events
86 | """
87 |
88 | if polarity_mask is not None:
89 | weights = weights * polarity_mask
90 | iwe = torch.zeros((idx.shape[0], res[0] * res[1], 1)).to(idx.device)
91 | iwe = iwe.scatter_add_(1, idx.long(), weights)
92 | iwe = iwe.view((idx.shape[0], 1, res[0], res[1]))
93 | return iwe
94 |
95 |
96 | def deblur_events(flow, event_list, res, flow_scaling=128, round_idx=True, polarity_mask=None):
97 | """
98 | Deblur the input events given an optical flow map.
99 | Event timestamps need to be normalized between 0 and 1.
100 | :param flow: [batch_size x 2 x H x W] optical flow map
101 | :param event_list: [batch_size x N x 4] input events (ts, y, x, p)
102 | :param res: resolution of the image space
103 | :param flow_scaling: scalar that multiplies the optical flow map
104 | :param round_idx: whether or not to round the event locations instead of doing bilinear interp. (default = True)
105 | :param polarity_mask: [batch_size x N x 2] polarity mask for the warped events (default = None)
106 | :return iwe: [batch_size x 1 x H x W] image of warped events
107 | """
108 |
109 | # flow vector per input event
110 | flow_idx = event_list[:, :, 1:3].clone()
111 | flow_idx[:, :, 0] *= res[1] # torch.view is row-major
112 | flow_idx = torch.sum(flow_idx, dim=2)
113 |
114 | # get flow for every event in the list
115 | flow = flow.view(flow.shape[0], 2, -1)
116 | event_flowy = torch.gather(flow[:, 1, :], 1, flow_idx.long()) # vertical component
117 | event_flowx = torch.gather(flow[:, 0, :], 1, flow_idx.long()) # horizontal component
118 | event_flowy = event_flowy.view(event_flowy.shape[0], event_flowy.shape[1], 1)
119 | event_flowx = event_flowx.view(event_flowx.shape[0], event_flowx.shape[1], 1)
120 | event_flow = torch.cat([event_flowy, event_flowx], dim=2)
121 |
122 | # interpolate forward
123 | fw_idx, fw_weights = get_interpolation(event_list, event_flow, 1, res, flow_scaling, round_idx=round_idx)
124 | if not round_idx and polarity_mask is not None:
125 | polarity_mask = torch.cat([polarity_mask for i in range(4)], dim=1)
126 |
127 | # image of (forward) warped events
128 | iwe = interpolate(fw_idx.long(), fw_weights, res, polarity_mask=polarity_mask)
129 |
130 | return iwe
131 |
132 |
133 | def compute_pol_iwe(flow, event_list, res, pos_mask, neg_mask, flow_scaling=128, round_idx=True):
134 | """
135 | Create a per-polarity image of warped events given an optical flow map.
136 | :param flow: [batch_size x 2 x H x W] optical flow map
137 | :param event_list: [batch_size x N x 4] input events (ts, y, x, p)
138 | :param res: resolution of the image space
139 | :param pos_mask: [batch_size x N x 1] polarity mask for positive events
140 | :param neg_mask: [batch_size x N x 1] polarity mask for negative events
141 | :param flow_scaling: scalar that multiplies the optical flow map
142 | :param round_idx: whether or not to round the event locations instead of doing bilinear interp. (default = True)
143 | :return iwe: [batch_size x 2 x H x W] image of warped events
144 | """
145 |
146 | iwe_pos = deblur_events(
147 | flow, event_list, res, flow_scaling=flow_scaling, round_idx=round_idx, polarity_mask=pos_mask
148 | )
149 | iwe_neg = deblur_events(
150 | flow, event_list, res, flow_scaling=flow_scaling, round_idx=round_idx, polarity_mask=neg_mask
151 | )
152 | iwe = torch.cat([iwe_pos, iwe_neg], dim=1)
153 |
154 | return iwe
155 |
--------------------------------------------------------------------------------
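A quick sanity check (not part of the repo) of compute_pol_iwe above, assuming the repo root is on PYTHONPATH. Events follow the (ts, y, x, p) layout that deblur_events indexes, with timestamps normalized to [0, 1]; with zero flow the image of warped events reduces to a per-polarity event count.

import torch
from utils.iwe import compute_pol_iwe

B, N, H, W = 1, 512, 64, 64
ts = torch.rand(B, N, 1)                             # normalized timestamps
ys = torch.randint(0, H, (B, N, 1)).float()
xs = torch.randint(0, W, (B, N, 1)).float()
ps = torch.randint(0, 2, (B, N, 1)).float() * 2 - 1  # polarity in {-1, +1}
events = torch.cat([ts, ys, xs, ps], dim=2)          # [B, N, 4] as (ts, y, x, p)

pos_mask = (ps > 0).float()
neg_mask = (ps < 0).float()
flow = torch.zeros(B, 2, H, W)                       # zero flow: no warping

iwe = compute_pol_iwe(flow, events, (H, W), pos_mask, neg_mask, round_idx=True)
print(iwe.shape)                                     # torch.Size([1, 2, 64, 64])
assert int(iwe.sum().item()) == N                    # each event lands on exactly one pixel
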
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | class Logger():
4 | def __init__(self, savemodel_path, accelerator=None):
5 | self.accelerator = accelerator
6 | if self.accelerator is None or self.accelerator.is_local_main_process:
7 | file_path = os.path.join(savemodel_path, "log.txt")
8 | self.file_ = open(file_path, "a+")
9 | self.counter = 0
10 | def __del__(self):
11 | if self.accelerator is None or self.accelerator.is_local_main_process:
12 | self.file_.close()
13 | def log_and_print(self, contents):
14 | if self.accelerator is None or self.accelerator.is_local_main_process:
15 | self.counter += 1
16 | self.file_.write(str(contents) + "\n")
17 | print(contents)
18 | if self.counter == 5:
19 | self.file_.flush()
20 | self.counter = 0
--------------------------------------------------------------------------------
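A minimal usage sketch (not part of the repo) of the Logger, assuming the target directory exists. It appends to log.txt under savemodel_path and flushes every fifth call; when an accelerate Accelerator is passed, only the local main process writes.

from utils.logger import Logger  # assumes the repo root is on PYTHONPATH

logger = Logger("./")                        # appends to ./log.txt
logger.log_and_print("epoch 1 | loss 0.42")  # printed and written to the file
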
/utils/preprocess.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms as transforms
3 | import random
4 |
5 | __imagenet_stats = {'mean': [0.485, 0.456, 0.406],
6 | 'std': [0.229, 0.224, 0.225]}
7 |
8 | #__imagenet_stats = {'mean': [0.5, 0.5, 0.5],
9 | # 'std': [0.5, 0.5, 0.5]}
10 |
11 | __imagenet_pca = {
12 | 'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]),
13 | 'eigvec': torch.Tensor([
14 | [-0.5675, 0.7192, 0.4009],
15 | [-0.5808, -0.0045, -0.8140],
16 | [-0.5836, -0.6948, 0.4203],
17 | ])
18 | }
19 |
20 |
21 | def scale_crop(input_size, scale_size=None, normalize=__imagenet_stats):
22 | t_list = [
23 | transforms.ToTensor(),
24 | transforms.Normalize(**normalize),
25 | ]
26 | #if scale_size != input_size:
27 | #t_list = [transforms.Scale((960,540))] + t_list
28 |
29 | return transforms.Compose(t_list)
30 |
31 |
32 | def scale_random_crop(input_size, scale_size=None, normalize=__imagenet_stats):
33 | t_list = [
34 | transforms.RandomCrop(input_size),
35 | transforms.ToTensor(),
36 | transforms.Normalize(**normalize),
37 | ]
38 | if scale_size != input_size:
39 | t_list = [transforms.Resize(scale_size)] + t_list  # transforms.Scale was renamed to Resize
40 |
41 | return transforms.Compose(t_list)
42 |
43 |
44 | def pad_random_crop(input_size, scale_size=None, normalize=__imagenet_stats):
45 | padding = int((scale_size - input_size) / 2)
46 | return transforms.Compose([
47 | transforms.RandomCrop(input_size, padding=padding),
48 | transforms.RandomHorizontalFlip(),
49 | transforms.ToTensor(),
50 | transforms.Normalize(**normalize),
51 | ])
52 |
53 |
54 | def inception_preproccess(input_size, normalize=__imagenet_stats):
55 | return transforms.Compose([
56 | transforms.RandomResizedCrop(input_size),  # RandomSizedCrop was renamed to RandomResizedCrop
57 | transforms.RandomHorizontalFlip(),
58 | transforms.ToTensor(),
59 | transforms.Normalize(**normalize)
60 | ])
61 | def inception_color_preproccess(input_size, normalize=__imagenet_stats):
62 | return transforms.Compose([
63 | #transforms.RandomSizedCrop(input_size),
64 | #transforms.RandomHorizontalFlip(),
65 | transforms.ToTensor(),
66 | ColorJitter(
67 | brightness=0.4,
68 | contrast=0.4,
69 | saturation=0.4,
70 | ),
71 | Lighting(0.1, __imagenet_pca['eigval'], __imagenet_pca['eigvec']),
72 | transforms.Normalize(**normalize)
73 | ])
74 |
75 |
76 | def get_transform(name='imagenet', input_size=None,
77 | scale_size=None, normalize=None, augment=True):
78 | normalize = __imagenet_stats
79 | input_size = 256
80 | if augment:
81 | return inception_color_preproccess(input_size, normalize=normalize)
82 | else:
83 | return scale_crop(input_size=input_size,
84 | scale_size=scale_size, normalize=normalize)
85 |
86 |
87 |
88 |
89 | class Lighting(object):
90 | """Lighting noise(AlexNet - style PCA - based noise)"""
91 |
92 | def __init__(self, alphastd, eigval, eigvec):
93 | self.alphastd = alphastd
94 | self.eigval = eigval
95 | self.eigvec = eigvec
96 |
97 | def __call__(self, img):
98 | if self.alphastd == 0:
99 | return img
100 |
101 | alpha = img.new().resize_(3).normal_(0, self.alphastd)
102 | rgb = self.eigvec.type_as(img).clone()\
103 | .mul(alpha.view(1, 3).expand(3, 3))\
104 | .mul(self.eigval.view(1, 3).expand(3, 3))\
105 | .sum(1).squeeze()
106 |
107 | return img.add(rgb.view(3, 1, 1).expand_as(img))
108 |
109 |
110 | class Grayscale(object):
111 |
112 | def __call__(self, img):
113 | gs = img.clone()
114 | gs[0].mul_(0.299).add_(gs[1], alpha=0.587).add_(gs[2], alpha=0.114)
115 | gs[1].copy_(gs[0])
116 | gs[2].copy_(gs[0])
117 | return gs
118 |
119 |
120 | class Saturation(object):
121 |
122 | def __init__(self, var):
123 | self.var = var
124 |
125 | def __call__(self, img):
126 | gs = Grayscale()(img)
127 | alpha = random.uniform(0, self.var)
128 | return img.lerp(gs, alpha)
129 |
130 |
131 | class Brightness(object):
132 |
133 | def __init__(self, var):
134 | self.var = var
135 |
136 | def __call__(self, img):
137 | gs = img.new().resize_as_(img).zero_()
138 | alpha = random.uniform(0, self.var)
139 | return img.lerp(gs, alpha)
140 |
141 |
142 | class Contrast(object):
143 |
144 | def __init__(self, var):
145 | self.var = var
146 |
147 | def __call__(self, img):
148 | gs = Grayscale()(img)
149 | gs.fill_(gs.mean())
150 | alpha = random.uniform(0, self.var)
151 | return img.lerp(gs, alpha)
152 |
153 |
154 | class RandomOrder(object):
155 | """ Composes several transforms together in random order.
156 | """
157 |
158 | def __init__(self, transforms):
159 | self.transforms = transforms
160 |
161 | def __call__(self, img):
162 | if self.transforms is None:
163 | return img
164 | order = torch.randperm(len(self.transforms))
165 | for i in order:
166 | img = self.transforms[i](img)
167 | return img
168 |
169 |
170 | class ColorJitter(RandomOrder):
171 |
172 | def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4):
173 | self.transforms = []
174 | if brightness != 0:
175 | self.transforms.append(Brightness(brightness))
176 | if contrast != 0:
177 | self.transforms.append(Contrast(contrast))
178 | if saturation != 0:
179 | self.transforms.append(Saturation(saturation))
180 |
--------------------------------------------------------------------------------
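A sketch (not part of the repo) of the augmentation pipeline, assuming torchvision is installed and the repo root is on PYTHONPATH. Unlike torchvision's own ColorJitter, the jitter and Lighting classes here operate on tensors, so they run after ToTensor and before Normalize.

import numpy as np
from PIL import Image
from utils.preprocess import get_transform

transform = get_transform(augment=True)  # ToTensor -> ColorJitter -> Lighting -> Normalize

# A random RGB image stands in for a real frame.
img = Image.fromarray((np.random.rand(256, 256, 3) * 255).astype(np.uint8))
out = transform(img)                     # [3, 256, 256] float tensor, ImageNet-normalized
print(out.shape, out.mean().item())
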
/utils/readpfm.py:
--------------------------------------------------------------------------------
1 | import re
2 | import numpy as np
3 | import sys
4 |
5 |
6 | def readPFM(file):
7 | file = open(file, 'rb')
8 |
9 | color = None
10 | width = None
11 | height = None
12 | scale = None
13 | endian = None
14 |
15 | header = file.readline().decode('ascii').rstrip()  # file is binary; decode before comparing
16 | if header == 'PF':
17 | color = True
18 | elif header == 'Pf':
19 | color = False
20 | else:
21 | raise Exception('Not a PFM file.')
22 |
23 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('ascii'))
24 | if dim_match:
25 | width, height = map(int, dim_match.groups())
26 | else:
27 | raise Exception('Malformed PFM header.')
28 |
29 | scale = float(file.readline().decode('ascii').rstrip())
30 | if scale < 0: # little-endian
31 | endian = '<'
32 | scale = -scale
33 | else:
34 | endian = '>' # big-endian
35 |
36 | data = np.fromfile(file, endian + 'f')
37 | shape = (height, width, 3) if color else (height, width)
38 |
39 | data = np.reshape(data, shape)
40 | data = np.flipud(data)
41 | file.close()
42 | return data, scale
43 |
44 |
--------------------------------------------------------------------------------
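A usage sketch (not part of the repo) with 'disparity.pfm' as a placeholder path. PFM files store rows bottom-up, which readPFM undoes with np.flipud, and a negative scale in the header marks little-endian data.

from utils.readpfm import readPFM  # assumes the repo root is on PYTHONPATH

data, scale = readPFM('disparity.pfm')  # data: (H, W) or (H, W, 3) float array
print(data.shape, data.dtype, scale)
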
/utils/visualization.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def disp_err_to_color(disp_est, disp_gt):
4 | """
5 | Calculate the error map between disparity estimation and disparity ground-truth
6 | hot color -> big error, cold color -> small error
7 | Args:
8 | disp_est (numpy.array): estimated disparity map
9 | in (Height, Width) layout, range [0,inf]
10 | disp_gt (numpy.array): ground truth disparity map
11 | in (Height, Width) layout, range [0,inf]
12 | Returns:
13 | disp_err (numpy.array): disparity error map
14 | in (Height, Width, 3) layout, range [0,1]
15 | """
16 | """ matlab
17 | function D_err = disp_error_image (D_gt,D_est,tau,dilate_radius)
18 | if nargin==3
19 | dilate_radius = 1;
20 | end
21 | [E,D_val] = disp_error_map (D_gt,D_est);
22 | E = min(E/tau(1),(E./abs(D_gt))/tau(2));
23 | cols = error_colormap();
24 | D_err = zeros([size(D_gt) 3]);
25 | for i=1:size(cols,1)
26 | [v,u] = find(D_val > 0 & E >= cols(i,1) & E <= cols(i,2));
27 | D_err(sub2ind(size(D_err),v,u,1*ones(length(v),1))) = cols(i,3);
28 | D_err(sub2ind(size(D_err),v,u,2*ones(length(v),1))) = cols(i,4);
29 | D_err(sub2ind(size(D_err),v,u,3*ones(length(v),1))) = cols(i,5);
30 | end
31 | D_err = imdilate(D_err,strel('disk',dilate_radius));
32 | """
33 | # error color map with interval (0, 0.1875, 0.375, 0.75, 1.5, 3, 6, 12, 24, 48, inf)/3.0
34 | # different interval corresponds to different 3-channel projection
35 | cols = np.array(
36 | [
37 | [0 / 3.0, 0.1875 / 3.0, 49, 54, 149],
38 | [0.1875 / 3.0, 0.375 / 3.0, 69, 117, 180],
39 | [0.375 / 3.0, 0.75 / 3.0, 116, 173, 209],
40 | [0.75 / 3.0, 1.5 / 3.0, 171, 217, 233],
41 | [1.5 / 3.0, 3 / 3.0, 224, 243, 248],
42 | [3 / 3.0, 6 / 3.0, 254, 224, 144],
43 | [6 / 3.0, 12 / 3.0, 253, 174, 97],
44 | [12 / 3.0, 24 / 3.0, 244, 109, 67],
45 | [24 / 3.0, 48 / 3.0, 215, 48, 39],
46 | [48 / 3.0, float("inf"), 165, 0, 38]
47 | ]
48 | )
49 |
50 | # [0, 1] -> [0, 255.0]
51 | disp_est = disp_est.copy() * 255.0
52 | disp_gt = disp_gt.copy() * 255.0
53 | # get the error (<3px or <5%) map
54 | tau = [3.0, 0.05]
55 | E = np.abs(disp_est - disp_gt)
56 |
57 | not_empty = disp_gt > 0.0
58 | tmp = np.zeros_like(disp_gt)
59 | tmp[not_empty] = E[not_empty] / disp_gt[not_empty] / tau[1]
60 | E = np.minimum(E / tau[0], tmp)
61 |
62 | h, w = disp_gt.shape
63 | err_im = np.zeros(shape=(h, w, 3)).astype(np.uint8)
64 | for col in cols:
65 | y_x = not_empty & (E >= col[0]) & (E <= col[1])
66 | err_im[y_x] = col[2:]
67 |
68 | # value range [0, 1], shape in [H, W, 3]
69 | err_im = err_im.astype(np.float64) / 255.0
70 |
71 | return err_im
--------------------------------------------------------------------------------
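A synthetic exercise (not part of the repo) of disp_err_to_color. Both inputs are expected in [0, 1]; the function rescales them to [0, 255] internally before applying the KITTI-style (<3 px or <5 %) thresholds from the MATLAB reference above.

import numpy as np
from utils.visualization import disp_err_to_color

disp_gt = np.random.rand(96, 128)                                    # [0, 1] ground truth
disp_est = np.clip(disp_gt + np.random.randn(96, 128) * 0.02, 0, 1)  # noisy estimate
err_im = disp_err_to_color(disp_est, disp_gt)
print(err_im.shape, err_im.min(), err_im.max())                      # (96, 128, 3), in [0, 1]
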
/utils/viz.py:
--------------------------------------------------------------------------------
1 | # Copyrights. All rights reserved.
2 | # ECOLE POLYTECHNIQUE FEDERALE DE LAUSANNE, Switzerland,
3 | # Space Center (eSpace), 2018
4 | # See the LICENSE.TXT file for more details.
5 |
6 | import os
7 |
8 | import torch as th
9 | import numpy as np
10 |
11 | # The Matplotlib backend should be chosen before pyplot is imported.
12 | import matplotlib
13 | matplotlib.use('Agg')
14 | from matplotlib import pyplot as plt
15 | from mpl_toolkits import axes_grid1
16 |
17 |
18 | def gray_to_color(array, colormap_name='jet', vmin=None, vmax=None):
19 | cmap = plt.get_cmap(colormap_name)
20 | norm = plt.Normalize(vmin, vmax)
21 | return cmap(norm(array))
22 |
23 |
24 | def _add_scaled_colorbar(plot, aspect=20, pad_fraction=0.5, **kwargs):
25 | """Adds scaled colorbar to existing plot."""
26 | divider = axes_grid1.make_axes_locatable(plot.axes)
27 | width = axes_grid1.axes_size.AxesY(plot.axes, aspect=1. / aspect)
28 | pad = axes_grid1.axes_size.Fraction(pad_fraction, width)
29 | current_axis = plt.gca()
30 | cax = divider.append_axes("right", size=width, pad=pad)
31 | plt.sca(current_axis)
32 | return plot.axes.figure.colorbar(plot, cax=cax, **kwargs)
33 |
34 |
35 | def save_image(filename, image, color_first=True):
36 | """Save color image to file.
37 |
38 | Args:
39 | filename: image file where the image will be saved.
40 | image: 3d image tensor.
41 | color_first: if True, the color dimension is the first
42 | dimension of the "image", otherwise the
43 | color dimension is the last dimension.
44 | """
45 | figure = plt.figure()
46 | if color_first:
47 | numpy_image = image.permute(1, 2, 0).numpy()
48 | else:
49 | numpy_image = image.numpy()
50 | plot = plt.imshow(numpy_image.astype(np.uint8))
51 | plot.axes.get_xaxis().set_visible(False)
52 | plot.axes.get_yaxis().set_visible(False)
53 | figure.savefig(filename, bbox_inches='tight', dpi=200)
54 | plt.close()
55 |
56 |
57 | def save_matrix(filename,
58 | matrix,
59 | minimum_value=None,
60 | maximum_value=None,
61 | colormap='magma',
62 | is_colorbar=True):
63 | """Saves the matrix to the image file.
64 |
65 | Args:
66 | filename: image file where the matrix will be saved.
67 | matrix: tensor of size (height x width). Some values might be
68 | equal to inf.
69 | minimum_value, maximum_value: boundaries of the range.
70 | Values outside of the range are
71 | shown in white. The colors of other
72 | values are determined by the colormap.
73 | If maximum and minimum values are not
74 | given they are calculated as 0.001 and
75 | 0.999 quantile.
76 | colormap: map that determines color coding of matrix values.
77 | """
78 | figure = plt.figure()
79 | noninf_mask = matrix != float('inf')
80 | if minimum_value is None:
81 | minimum_value = np.quantile(matrix[noninf_mask], 0.001)
82 | if maximum_value is None:
83 | maximum_value = np.quantile(matrix[noninf_mask], 0.999)
84 | plot = plt.imshow(
85 | matrix.numpy(), colormap, vmin=minimum_value, vmax=maximum_value)
86 | if is_colorbar:
87 | _add_scaled_colorbar(plot)
88 | plot.axes.get_xaxis().set_visible(False)
89 | plot.axes.get_yaxis().set_visible(False)
90 | figure.savefig(filename, bbox_inches='tight', dpi=200)
91 | plt.close()
92 |
93 | def plot_points_on_background(points_coordinates,
94 | background,
95 | points_color=[0, 0, 255]):
96 | """
97 | Args:
98 | points_coordinates: array of (y, x) points coordinates
99 | of size (number_of_points x 2).
100 | background: (3 x height x width)
101 | gray or color image uint8.
102 | points_color: color of points [red, green, blue] uint8.
103 | """
104 | if not (len(background.size()) == 3 and background.size(0) == 3):
105 | raise ValueError('background should be (color x height x width).')
106 | _, height, width = background.size()
107 | background_with_points = background.clone()
108 | y, x = points_coordinates.transpose(0, 1)
109 | x_min, x_max = x.min(), x.max()
110 | y_min, y_max = y.min(), y.max()
111 | if not (x_min >= 0 and y_min >= 0 and x_max < width and y_max < height):
112 | raise ValueError('points coordinates are outside of "background" '
113 | 'boundaries.')
114 | background_with_points[:, y, x] = th.Tensor(points_color).type_as(
115 | background).unsqueeze(-1)
116 | return background_with_points
117 |
118 |
119 | def overlay_image_with_binary_error(color_image, binary_error):
120 | """Returns byte image overlayed with the binary error.
121 |
122 | Contrast of the image is reduced, brightness is increased,
123 | and locations with the errors are shown in blue.
124 |
125 | Args:
126 | color_image: byte image tensor of size
127 | (color_index, height, width);
128 | binary_error: byte tensor of size (height x width),
129 | where "True"s correspond to error,
130 | and "False"s correspond to absence of error.
131 | """
132 | points_coordinates = th.nonzero(binary_error)
133 | washed_out_image = color_image // 2 + 128
134 | return plot_points_on_background(points_coordinates, washed_out_image)
135 |
136 |
137 | class Logger(object):
138 | def __init__(self, filename):
139 | self._filename = filename
140 |
141 | def log(self, text):
142 | """Appends text line to the file."""
143 | if os.path.isfile(self._filename):
144 | handler = open(self._filename, 'r')
145 | lines = handler.readlines()
146 | handler.close()
147 | else:
148 | lines = []
149 | lines.append(text + '\n')
150 | handler = open(self._filename, 'w')
151 | handler.writelines(lines)
152 | handler.close()
153 |
154 |
155 | def plot_losses_and_errors(filename,
156 | losses,
157 | errors,
158 | righ_y_axis_label='Validation error, [%]'):
159 | """Plots the loss and the error.
160 |
161 | The plot has two y-axis: the left is reserved for the loss
162 | and the right is reserved for the error. The axes have
163 | different scales. The axis and the curve of the loss are shown
164 | in blue and the axis and the curve for the error are shown
165 | in red.
166 |
167 | Args:
168 | filename: image file where plot is saved;
169 | losses, errors: lists with loss and error values
170 | respectively. Every element of the
171 | list corresponds to an epoch.
172 | """
173 | epochs = range(1, len(losses) + 1)
174 | figure, loss_axis = plt.subplots()
175 | smallest_loss = min(losses)
176 | loss_label = 'Training loss (smallest {0:.3f})'.format(smallest_loss)
177 | loss_plot = loss_axis.plot(epochs, losses, 'bs-', label=loss_label)[0]
178 | loss_axis.set_ylabel('Training loss', color='blue')
179 | loss_axis.set_xlabel('Epoch')
180 | error_axis = loss_axis.twinx()
181 | smallest_error = min(errors)
182 | error_label = 'Validation error (smallest {0:.3f})'.format(smallest_error)
183 | error_plot = error_axis.plot(epochs, errors, 'ro--', label=error_label)[0]
184 | error_axis.set_ylabel(righ_y_axis_label, color='red')
185 | error_axis.legend(handles=[loss_plot, error_plot])
186 | figure.savefig(filename, bbox_inches='tight')
187 | plt.close()
188 |
--------------------------------------------------------------------------------
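A usage sketch (not part of the repo): the plotting helpers take CPU torch tensors and plain lists, and the output paths below are placeholders. Because the Agg backend is forced at import time, they also work on headless machines.

import torch
from utils.viz import save_matrix, plot_losses_and_errors

depth = torch.rand(64, 96) * 10.0
save_matrix('/tmp/depth.png', depth, colormap='magma')  # color-coded matrix with colorbar

losses = [1.00, 0.62, 0.41]
errors = [12.0, 9.5, 8.1]
plot_losses_and_errors('/tmp/curves.png', losses, errors)  # twin-axis loss/error plot
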
/utils/warp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.utils.data
4 | from torch.autograd import Variable
5 | import torch.nn.functional as F
6 | import math
7 |
8 | def normalize_coords(grid):
9 | """Normalize coordinates of image scale to [-1, 1]
10 | Args:
11 | grid: [B, 2, H, W]
12 | """
13 | assert grid.size(1) == 2
14 | h, w = grid.size()[2:]
15 | grid[:, 0, :, :] = 2 * (grid[:, 0, :, :].clone() / (w - 1)) - 1 # x: [-1, 1]
16 | grid[:, 1, :, :] = 2 * (grid[:, 1, :, :].clone() / (h - 1)) - 1 # y: [-1, 1]
17 | grid = grid.permute((0, 2, 3, 1)) # [B, H, W, 2]
18 | return grid
19 |
20 | def meshgrid(img, homogeneous=False):
21 | """Generate meshgrid in image scale
22 | Args:
23 | img: [B, _, H, W]
24 | homogeneous: whether to return homogeneous coordinates
25 | Return:
26 | grid: [B, 2, H, W]
27 | """
28 | b, _, h, w = img.size()
29 |
30 | x_range = torch.arange(0, w).view(1, 1, w).expand(1, h, w).type_as(img) # [1, H, W]
31 | y_range = torch.arange(0, h).view(1, h, 1).expand(1, h, w).type_as(img)
32 |
33 | grid = torch.cat((x_range, y_range), dim=0) # [2, H, W], grid[:, i, j] = [j, i]
34 | grid = grid.unsqueeze(0).expand(b, 2, h, w) # [B, 2, H, W]
35 |
36 | if homogeneous:
37 | ones = torch.ones_like(x_range).unsqueeze(0).expand(b, 1, h, w) # [B, 1, H, W]
38 | grid = torch.cat((grid, ones), dim=1) # [B, 3, H, W]
39 | assert grid.size(1) == 3
40 | return grid
41 |
42 | def disp_warp(img, disp, padding_mode='border', interpolate_mode='bilinear'):
43 | """Warping by disparity
44 | Args:
45 | img: [B, 3, H, W]
46 | disp: [B, 1, H, W], positive
47 | padding_mode: 'zeros' or 'border'
48 | Returns:
49 | warped_img: [B, 3, H, W]
50 | valid_mask: [B, 3, H, W]
51 | """
52 | # assert disp.min() >= 0
53 |
54 | grid = meshgrid(img) # [B, 2, H, W] in image scale
55 | # Note that -disp here
56 | offset = torch.cat((-disp, torch.zeros_like(disp)), dim=1) # [B, 2, H, W]
57 | sample_grid = grid + offset
58 | sample_grid = normalize_coords(sample_grid) # [B, H, W, 2] in [-1, 1]
59 | warped_img = F.grid_sample(img, sample_grid, mode=interpolate_mode, padding_mode=padding_mode, align_corners=True)
60 |
61 | mask = torch.ones_like(img)
62 | valid_mask = F.grid_sample(mask, sample_grid, mode=interpolate_mode, padding_mode='zeros', align_corners=True)
63 | valid_mask[valid_mask < 0.9999] = 0
64 | valid_mask[valid_mask > 0] = 1
65 | return warped_img, valid_mask
66 |
67 |
68 | def flow_warp(feature, flow, mask=False, padding_mode='zeros', interpolate_mode='bilinear'):
69 | b, c, h, w = feature.size()
70 | assert flow.size(1) == 2
71 |
72 | grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W]
73 |
74 | return bilinear_sampler(feature, grid, mode=interpolate_mode, mask=mask, padding_mode=padding_mode)
75 |
76 |
77 | def coords_grid(batch, ht, wd, normalize=False):
78 | if normalize: # [-1, 1]
79 | coords = torch.meshgrid(2 * torch.arange(ht) / (ht - 1) - 1,
80 | 2 * torch.arange(wd) / (wd - 1) - 1)
81 | else:
82 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
83 | coords = torch.stack(coords[::-1], dim=0).float()
84 |
85 | return coords[None].repeat(batch, 1, 1, 1) # [B, 2, H, W]
86 |
87 | def bilinear_sampler(img, coords, mode='bilinear', mask=False, padding_mode='zeros'):
88 | """ Wrapper for grid_sample, uses pixel coordinates """
89 | if coords.size(-1) != 2: # [B, 2, H, W] -> [B, H, W, 2]
90 | coords = coords.permute(0, 2, 3, 1)
91 |
92 | H, W = img.shape[-2:]
93 | # H = height if height is not None else img.shape[-2]
94 | # W = width if width is not None else img.shape[-1]
95 |
96 | xgrid, ygrid = coords.split([1, 1], dim=-1)
97 |
98 | # To handle H or W equal to 1, explicitly redefine the height and width
99 | if H == 1:
100 | assert ygrid.abs().max() < 1e-8
101 | H = 10
102 | if W == 1:
103 | assert xgrid.abs().max() < 1e-8
104 | W = 10
105 |
106 | xgrid = 2 * xgrid / (W - 1) - 1
107 | ygrid = 2 * ygrid / (H - 1) - 1
108 |
109 | grid = torch.cat([xgrid, ygrid], dim=-1)
110 | img = F.grid_sample(img, grid, mode=mode,
111 | padding_mode=padding_mode,
112 | align_corners=True)
113 |
114 | if mask:
115 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
116 | return img, mask.squeeze(-1).float()
117 |
118 | return img
--------------------------------------------------------------------------------
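A quick identity check (not part of the repo) of the warping helpers above. With zero disparity or zero flow, both samplers should return the input (up to float precision) and mark every pixel valid, which is a handy sanity test before plugging them into a photometric loss.

import torch
from utils.warp import disp_warp, flow_warp  # assumes the repo root is on PYTHONPATH

img = torch.rand(2, 3, 32, 48)

warped, valid = disp_warp(img, torch.zeros(2, 1, 32, 48))
assert torch.allclose(warped, img, atol=1e-5)
assert valid.min() == 1                           # no pixel sampled out of bounds

same = flow_warp(img, torch.zeros(2, 2, 32, 48))
assert torch.allclose(same, img, atol=1e-5)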