├── .gitignore ├── .images ├── augmentation.png ├── cmax.png ├── representations.png └── visualisations.png ├── LICENSE ├── README.md ├── __init__.py ├── lib ├── augmentation │ ├── __init__.py │ └── event_augmentation.py ├── contrast_max │ ├── Doxyfile │ ├── __init__.py │ ├── events_cmax.py │ ├── objectives.py │ └── warps.py ├── data_formats │ ├── __init__.py │ ├── add_hdf5_attribute.py │ ├── data_providers.py │ ├── data_utils.py │ ├── event_packagers.py │ ├── h5_to_memmap.py │ ├── read_events.py │ └── rosbag_to_h5.py ├── data_loaders │ ├── __init__.py │ ├── base_dataset.py │ ├── data_augmentation.py │ ├── data_util.py │ ├── dataloader_util.py │ ├── hdf5_dataset.py │ ├── memmap_dataset.py │ └── npy_dataset.py ├── representations │ ├── image.py │ └── voxel_grid.py ├── transforms │ └── optic_flow.py ├── util │ ├── __init__.py │ ├── event_util.py │ └── util.py └── visualization │ ├── __init__.py │ ├── draw_event_stream.py │ ├── draw_event_stream_mayavi.py │ ├── draw_flow.py │ ├── utils │ ├── draw_plane.py │ └── draw_plane_simple.py │ ├── visualization_utils.py │ ├── visualizers.py │ └── visualizers_mayavi.py ├── visualize.py ├── visualize_events.py ├── visualize_flow.py └── visualize_voxel.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Local config folders 2 | config/tt 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | /tmp 8 | */latex 9 | */html 10 | 11 | # C extensions 12 | *.so 13 | data_generator/voxel_generation/build 14 | 15 | # Distribution / packaging 16 | .Python 17 | env/ 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | .hypothesis/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # dotenv 88 | .env 89 | 90 | # virtualenv 91 | .venv 92 | venv/ 93 | ENV/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | 108 | # input data, saved log, checkpoints 109 | data/ 110 | input/ 111 | saved/ 112 | datasets/ 113 | 114 | # editor, os cache directory 115 | .vscode/ 116 | .idea/ 117 | __MACOSX/ 118 | 119 | # outputs 120 | *.jpg 121 | *.jpeg 122 | *.h5 123 | *.swp 124 | 125 | # dirs 126 | /configs/r2 127 | -------------------------------------------------------------------------------- /.images/augmentation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TimoStoff/event_utils/dc0a0712156bb0c3659d90b33e211fa58a83a75f/.images/augmentation.png -------------------------------------------------------------------------------- /.images/cmax.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TimoStoff/event_utils/dc0a0712156bb0c3659d90b33e211fa58a83a75f/.images/cmax.png -------------------------------------------------------------------------------- /.images/representations.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TimoStoff/event_utils/dc0a0712156bb0c3659d90b33e211fa58a83a75f/.images/representations.png -------------------------------------------------------------------------------- /.images/visualisations.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TimoStoff/event_utils/dc0a0712156bb0c3659d90b33e211fa58a83a75f/.images/visualisations.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Timo Stoffregen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # __init__.py 2 | -------------------------------------------------------------------------------- /lib/augmentation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TimoStoff/event_utils/dc0a0712156bb0c3659d90b33e211fa58a83a75f/lib/augmentation/__init__.py -------------------------------------------------------------------------------- /lib/augmentation/event_augmentation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from lib.representations.voxel_grid import events_to_neg_pos_voxel 3 | from lib.data_formats.read_events import read_h5_event_components 4 | from lib.visualization.draw_event_stream import plot_events 5 | from lib.util.event_util import clip_events_to_bounds 6 | import matplotlib.pyplot as plt 7 | 8 | def sample(cdf, ts): 9 | """ 10 | Given a cumulative distribution function (CDF) and timestamps, draw 11 | a random sample from the CDF, then find the index of the corresponding 12 | event. The idea is to allow fair sampling of an event stream's timestamps 13 | @param cdf The CDF as np array 14 | @param ts The timestamps to sample from 15 | @returns The index of the sampled event 16 | """ 17 | minval = cdf[0] 18 | maxval = cdf[-1] 19 | rnd = np.random.uniform(minval, maxval) 20 | idx = np.searchsorted(ts, rnd) 21 | return idx 22 | 23 | def events_to_block(xs, ys, ts, ps): 24 | """ 25 | Given events as lists of components, return an Nx4 numpy array of the events, 26 | where N is the number of events 27 | @param xs x component of events 28 | @param ys y component of events 29 | @param ts t component of events 30 | @param ps p component of events 31 | @returns The block of events 32 | """ 33 | block_events = np.concatenate(( 34 | xs[:,np.newaxis], 35 | ys[:,np.newaxis], 36 | ts[:,np.newaxis], 37 | ps[:,np.newaxis]), axis=1) 38 | return block_events 39 | 40 | def merge_events(event_sets): 41 | """ 42 | Merge multiple sets of events 43 | @param event_sets A list of event streams, where each event stream consists 44 | of four numpy arrays of xs, ys, ts and ps 45 | @returns The merged events as a single Nx4 block (see events_to_block) 46 | """ 47 | xs,ys,ts,ps = [],[],[],[] 48 | for events in event_sets: 49 | xs.append(events[0]) 50 | ys.append(events[1]) 51 | ts.append(events[2]) 52 | ps.append(events[3]) 53 | merged = events_to_block( 54 | np.concatenate(xs), 55 | np.concatenate(ys), 56 | np.concatenate(ts), 57 | np.concatenate(ps)) 58 | return merged 59 | 60 | def add_random_events(xs, ys, ts, ps, to_add, sensor_resolution=None, 61 | sort=True, return_merged=True): 62 | """ 63 | Add new, random events drawn from a uniform distribution. 64 | Event coordinates are drawn from a uniform distribution over the sensor resolution and 65 | duration of the events.
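For example (illustrative numbers), add_random_events(xs, ys, ts, ps, 1000) on a 240x180 event stream spanning one second scatters 1,000 extra events uniformly over that spatio-temporal volume.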
66 | @param xs x component of events 67 | @param ys y component of events 68 | @param ts t component of events 69 | @param ps p component of events 70 | @param to_add How many events to add 71 | @param sensor_resolution The resolution of the events. If left None, takes the range 72 | of the spatial coordinates of the input events 73 | @param sort Whether to sort the output events 74 | @param return_merged Whether to return the random events separately or merged into 75 | the original input events 76 | @returns The random (or merged) events as tuple: xs, ys, ts, ps 77 | """ 78 | res = (np.max(ys)+1, np.max(xs)+1) if sensor_resolution is None else sensor_resolution 79 | xs_new, ys_new = np.random.randint(res[1], size=to_add), np.random.randint(res[0], size=to_add) 80 | ts_new = np.random.uniform(np.min(ts), np.max(ts), size=to_add) 81 | ps_new = (np.random.randint(2, size=to_add))*2-1 82 | if return_merged: 83 | new_events = merge_events([[xs_new, ys_new, ts_new, ps_new], [xs, ys, ts, ps]]) 84 | if sort: 85 | new_events.view('f8,f8,f8,f8').sort(order=['f2'], axis=0) 86 | return new_events[:,0], new_events[:,1], new_events[:,2], new_events[:,3] 87 | elif sort: 88 | new_events = events_to_block(xs_new, ys_new, ts_new, ps_new) 89 | new_events.view('f8,f8,f8,f8').sort(order=['f2'], axis=0) 90 | return new_events[:,0], new_events[:,1], new_events[:,2], new_events[:,3] 91 | else: 92 | return xs_new, ys_new, ts_new, ps_new 93 | 94 | def remove_events(xs, ys, ts, ps, to_remove, add_noise=0): 95 | """ 96 | Remove events by random selection 97 | @param xs x component of events 98 | @param ys y component of events 99 | @param ts t component of events 100 | @param ps p component of events 101 | @param to_remove How many events to remove 102 | @param add_noise How many noise events to add (0 by default) 103 | @returns Event stream with events removed as tuple: xs, ys, ts, ps 104 | """ 105 | if to_remove > len(xs): 106 | return np.array([]), np.array([]), np.array([]), np.array([]) 107 | to_select = len(xs)-to_remove 108 | idx = np.random.choice(np.arange(len(xs)), size=to_select, replace=False) 109 | if add_noise <= 0: 110 | idx.sort() 111 | return xs[idx], ys[idx], ts[idx], ps[idx] 112 | else: 113 | nsx, nsy, nst, nsp = add_random_events(xs, ys, ts, ps, add_noise, sort=False, return_merged=False) 114 | new_events = merge_events([[xs[idx], ys[idx], ts[idx], ps[idx]], [nsx, nsy, nst, nsp]]) 115 | new_events.view('f8,f8,f8,f8').sort(order=['f2'], axis=0) 116 | return new_events[:,0], new_events[:,1], new_events[:,2], new_events[:,3] 117 | 118 | def add_correlated_events(xs, ys, ts, ps, to_add, sort=True, return_merged=True, xy_std = 1.5, ts_std = 0.001, add_noise=0): 119 | """ 120 | Add events in the vicinity of existing events. Each original event has a Gaussian bubble 121 | placed around it from which the new events are sampled.
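For example, with the default xy_std=1.5 and ts_std=0.001, each new event lands within a pixel or two and around a millisecond of its parent event.
@param xs x component of events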
122 | @param ys y component of events 123 | @param ts t component of events 124 | @param ps p component of events 125 | @param to_add How many events to add 126 | @param sort Whether to sort the output events 127 | @param return_merged Whether to return the random events separately or merged into 128 | the original input events 129 | @param xy_std Standard deviation of new xy coords 130 | @param ts_std Standard deviation of new timestamps 131 | @param add_noise How many random noise events to add (default 0) 132 | @returns Events augmented with correlated events in tuple: xs, ys, ts, ps 133 | """ 134 | iters = int(to_add/len(xs))+1 135 | xs_new, ys_new, ts_new, ps_new = [], [], [], [] 136 | for i in range(iters): 137 | xs_new.append(xs+np.random.normal(scale=xy_std, size=xs.shape).astype(int)) 138 | ys_new.append(ys+np.random.normal(scale=xy_std, size=ys.shape).astype(int)) 139 | ts_new.append(ts+np.random.normal(scale=ts_std, size=ts.shape)) 140 | ps_new.append(ps) 141 | xs_new = np.concatenate(xs_new, axis=0) 142 | ys_new = np.concatenate(ys_new, axis=0) 143 | ts_new = np.concatenate(ts_new, axis=0) 144 | ps_new = np.concatenate(ps_new, axis=0) 145 | idx = np.random.choice(np.arange(len(xs_new)), size=to_add, replace=False) 146 | xs_new = np.clip(xs_new[idx], 0, np.max(xs)) 147 | ys_new = np.clip(ys_new[idx], 0, np.max(ys)) 148 | ts_new = ts_new[idx] 149 | ps_new = ps_new[idx] 150 | nsx, nsy, nst, nsp = add_random_events(xs, ys, ts, ps, add_noise, sort=False, return_merged=False) 151 | if return_merged: 152 | new_events = merge_events([[xs, ys, ts, ps], [xs_new, ys_new, ts_new, ps_new], [nsx, nsy, nst, nsp]]) 153 | else: 154 | new_events = events_to_block(xs_new, ys_new, ts_new, ps_new) 155 | if sort: 156 | new_events.view('f8,f8,f8,f8').sort(order=['f2'], axis=0) 157 | return new_events[:,0], new_events[:,1], new_events[:,2], new_events[:,3] 158 | 159 | def flip_events_x(xs, ys, ts, ps, sensor_resolution=(180,240)): 160 | """ 161 | Flip events along x axis 162 | @param xs x component of events 163 | @param ys y component of events 164 | @param ts t component of events 165 | @param ps p component of events 166 | @returns Flipped events 167 | """ 168 | xs = sensor_resolution[1]-1-xs 169 | return xs, ys, ts, ps 170 | 171 | def flip_events_y(xs, ys, ts, ps, sensor_resolution=(180,240)): 172 | """ 173 | Flip events along y axis 174 | @param xs x component of events 175 | @param ys y component of events 176 | @param ts t component of events 177 | @param ps p component of events 178 | @returns Flipped events 179 | """ 180 | ys = sensor_resolution[0]-1-ys 181 | return xs, ys, ts, ps 182 | 183 | def crop_events(xs, ys, sensor_resolution, new_resolution): 184 | """ 185 | Crop events to new resolution 186 | @param xs x component of events 187 | @param ys y component of events 188 | @param sensor_resolution Original resolution 189 | @param new_resolution New desired resolution 190 | @returns Events cropped to new resolution as tuple: xs, ys 191 | """ 192 | clip = clip_events_to_bounds(xs, ys, None, None, new_resolution) 193 | return clip[0], clip[1] 194 | 195 | def rotate_events(xs, ys, sensor_resolution=(180,240), 196 | theta_radians=None, center_of_rotation=None, clip_to_range=False): 197 | """ 198 | Rotate events by a given angle around a given center of rotation. 199 | Note that the output events are floating point and may no longer 200 | be in the range of the image sensor. Thus, if 'standard' events are 201 | required, conversion to int and clipping to range may be necessary.
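For example, theta_radians=np.pi/2 with center_of_rotation=(90, 120) turns the events a quarter rotation about a point near the middle of a 180x240 sensor.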
202 | @param xs x component of events 203 | @param ys y component of events 204 | @param sensor_resolution Size of event camera sensor 205 | @param theta_radians Angle of rotation in radians. If left empty, choose random 206 | @param center_of_rotation Center of the rotation. If left empty, choose random 207 | @param clip_to_range If True, remove events that lie outside of image plane after rotation 208 | @returns Rotated event coords and rotation parameters: xs, ys, 209 | theta_radians, center_of_rotation 210 | """ 211 | theta_radians = np.random.uniform(0, 2*np.pi) if theta_radians is None else theta_radians 212 | corx = int(np.random.uniform(0, sensor_resolution[1])) 213 | cory = int(np.random.uniform(0, sensor_resolution[0])) 214 | center_of_rotation = (corx, cory) if center_of_rotation is None else center_of_rotation 215 | 216 | cxs = xs-center_of_rotation[0] 217 | cys = ys-center_of_rotation[1] 218 | new_xs = (cxs*np.cos(theta_radians)-cys*np.sin(theta_radians))+center_of_rotation[0] 219 | new_ys = (cxs*np.sin(theta_radians)+cys*np.cos(theta_radians))+center_of_rotation[1] 220 | if clip_to_range: 221 | clip = clip_events_to_bounds(new_xs, new_ys, None, None, sensor_resolution) 222 | new_xs, new_ys = clip[0], clip[1] 223 | return new_xs, new_ys, theta_radians, center_of_rotation 224 | 225 | if __name__ == "__main__": 226 | """ 227 | Tool to augment a set of events. 228 | """ 229 | import argparse 230 | import os 231 | parser = argparse.ArgumentParser() 232 | parser.add_argument("path", help="Path to event file") 233 | parser.add_argument("--output_path", default="/tmp/extracted_data", help="Folder where to put augmented events") 234 | parser.add_argument("--to_add", type=float, default=1.0, help="How many more events, as a proportion \ 235 | (eg, 1.5 will result in 150%% more events, 0.2 will result in 20%% of the events).") 236 | args = parser.parse_args() 237 | out_dir = args.output_path 238 | 239 | xs, ys, ts, ps = read_h5_event_components(args.path) 240 | ys = 180-ys  # flip y, assuming a 180-row sensor 241 | num = 50000 242 | s = 0  # 10000 243 | num_to_add = num*2 244 | num_comp = 5000 245 | 246 | pth = os.path.join(out_dir, "img0") 247 | plot_events(xs[s:s+num], ys[s:s+num], ts[s:s+num], ps[s:s+num], elev=30, num_compress=num_comp, num_show=-1, save_path=pth, show_axes=True, compress_front=True) 248 | 249 | pth = os.path.join(out_dir, "img1") 250 | nx, ny, nt, npo = add_correlated_events(xs[s:s+num], ys[s:s+num], ts[s:s+num], ps[s:s+num], num_to_add) 251 | plot_events(nx, ny, nt, npo, elev=30, num_compress=num_comp, num_show=-1, save_path=pth, show_axes=True, compress_front=True) 252 | 253 | pth = os.path.join(out_dir, "img3") 254 | nx, ny, nt, npo = add_random_events(xs[s:s+num], ys[s:s+num], ts[s:s+num], ps[s:s+num], num_to_add, sensor_resolution=(180,240)) 255 | plot_events(nx, ny, nt, npo, elev=30, num_compress=num_comp, num_show=-1, save_path=pth, show_axes=True, compress_front=True) 256 | 257 | pth = os.path.join(out_dir, "img4") 258 | nx, ny, nt, npo = remove_events(xs[s:s+num], ys[s:s+num], ts[s:s+num], ps[s:s+num], num//2) 259 | plot_events(nx, ny, nt, npo, elev=30, num_compress=num_comp, num_show=-1, save_path=pth, show_axes=True, compress_front=True) 260 | 261 | pth = os.path.join(out_dir, "img5") 262 | nx, ny, rot, cor = rotate_events(xs[s:s+num], ys[s:s+num], theta_radians=1.4, center_of_rotation=(90, 120), clip_to_range=False)  # keep all events so the sliced ts/ps below still line up 263 | plot_events(nx, ny, ts[s:s+num], ps[s:s+num], elev=30, num_compress=num_comp, num_show=-1, save_path=pth, show_axes=True, compress_front=True) 264 | 265 | pth = os.path.join(out_dir, "img6") 266 | nx, ny, nt, npo = flip_events_x(xs[s:s+num], ys[s:s+num], ts[s:s+num], ps[s:s+num]) 267 | plot_events(nx, ny, nt, npo, elev=30, num_compress=num_comp, num_show=-1, save_path=pth, show_axes=True, compress_front=True) 268 |
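# Example invocation (illustrative paths), run from the repository root:
#   python -m lib.augmentation.event_augmentation /path/to/events.h5 --output_path /tmp/aug_demo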
-------------------------------------------------------------------------------- /lib/contrast_max/__init__.py: -------------------------------------------------------------------------------- 1 | # __init__.py 2 | from .events_cmax import * 3 | from .warps import * 4 | from .objectives import * 5 | -------------------------------------------------------------------------------- /lib/contrast_max/events_cmax.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import scipy 4 | import scipy.optimize as opt 5 | from scipy.ndimage import gaussian_filter 6 | import torch 7 | import copy 8 | from ..util.event_util import infer_resolution, get_events_from_mask 9 | from ..util.util import plot_image, save_image, plot_image_grid 10 | from ..visualization.draw_event_stream import plot_events 11 | from .objectives import * 12 | from .warps import * 13 | import matplotlib.pyplot as plt 14 | def get_hsv_shifted(): 15 | """ 16 | Get the colormap used in Mitrokhin et al., "Event-based Moving Object Detection and Tracking" 17 | """ 18 | from matplotlib import cm 19 | from matplotlib.colors import LinearSegmentedColormap 20 | 21 | hsv = cm.get_cmap('hsv') 22 | hsv_shifted = [] 23 | for i in np.arange(0, 0.6666, 0.01): 24 | hsv_shifted.append(hsv(np.fmod(i+0.6666, 1.0))) 25 | hsv_shifted = LinearSegmentedColormap.from_list('hsv_shifted', hsv_shifted, N=100) 26 | return hsv_shifted 27 | 28 | def grid_cmax(xs, ys, ts, ps, roi_size=(20,20), step=None, warp=linvel_warp(), 29 | obj=variance_objective(adaptive_lifespan=True, minimum_events=105), 30 | min_events=10): 31 | """ 32 | Break sensor into a grid and perform contrast maximisation on each sector of grid separately.
Main input parameters are the events and the size of each window of the 33 | grid (roi_size) 34 | @param xs x components of events as list 35 | @param ys y components of events as list 36 | @param ts t components of events as list 37 | @param ps p components of events as list 38 | @param roi_size The size of the grid regions of interest (rois) 39 | @param step The sliding window step size (same as roi_size if left empty) 40 | @param warp The warp function to be used 41 | @param obj The objective function to be used 42 | @param min_events The minimum number of events in a ROI to be considered valid 43 | @returns List of optimal parameters, optimal function evaluations and rois 44 | """ 45 | 46 | step = roi_size if step is None else step 47 | resolution = infer_resolution(xs, ys) 48 | warpfunc = linvel_warp() 49 | 50 | results_params = [] 51 | results_rois = [] 52 | results_f_evals = [] 53 | for xc in range(0, resolution[1], step[1]): 54 | x_roi_idc = np.argwhere((xs>=xc) & (xs<xc+step[1])) 55 | x_subset, y_subset = xs[x_roi_idc], ys[x_roi_idc] 56 | t_subset, p_subset = ts[x_roi_idc], ps[x_roi_idc] 57 | for yc in range(0, resolution[0], step[0]): 58 | y_roi_idc = np.argwhere((y_subset>=yc) & (y_subset<yc+step[0])) 59 | roi_xs = x_subset[y_roi_idc] 60 | roi_ys = y_subset[y_roi_idc] 61 | roi_ts = t_subset[y_roi_idc] 62 | roi_ps = p_subset[y_roi_idc] 63 | 64 | if len(roi_xs) > min_events: 65 | obj = variance_objective(adaptive_lifespan=True, minimum_events=105) 66 | params = optimize_contrast(roi_xs, roi_ys, roi_ts, roi_ps, warp, obj, numeric_grads=False, blur_sigma=2.0, img_size=resolution, grid_search_init=True) 67 | params = optimize_contrast(roi_xs, roi_ys, roi_ts, roi_ps, warp, obj, numeric_grads=False, blur_sigma=1.0, img_size=resolution, x0=params) 68 | iwe, d_iwe = get_iwe(params, xs, ys, ts, ps, warp, resolution, 69 | use_polarity=True, compute_gradient=False, return_events=False) 70 | f_eval = obj.evaluate_function(iwe=iwe) 71 | 72 | results_params.append(params) 73 | results_rois.append([yc, xc, step[0], step[1]]) 74 | results_f_evals.append(f_eval) 75 | 76 | return results_params, results_rois, results_f_evals 77 | 78 | def segmentation_mask_from_d_iwe(d_iwe, th=None): 79 | """ 80 | Generate a segmentation mask from the derivative of the IWE wrt motion params 81 | @param d_iwe First derivative of IWE wrt motion parameters 82 | @param th Value threshold for segmentation mask, auto generated if left blank 83 | @returns Segmentation mask 84 | """ 85 | th1 = np.percentile(np.abs(d_iwe), 90) 86 | validx = d_iwe[0].flatten()[np.argwhere(np.abs(d_iwe[0].flatten()) > th1).squeeze()] 87 | validy = d_iwe[1].flatten()[np.argwhere(np.abs(d_iwe[1].flatten()) > th1).squeeze()] 88 | x_c = np.percentile(validx, 95) 89 | y_c = np.percentile(validy, 95) 90 | 91 | thx = x_c if th is None else th 92 | thy = y_c if th is None else th 93 | 94 | imgxp = np.where(d_iwe[0] > thx, 1, 0) 95 | imgyp = np.where(d_iwe[1] > thy, 1, 0) 96 | imgxn = np.where(d_iwe[0] < -thx, 1, 0) 97 | imgyn = np.where(d_iwe[1] < -thy, 1, 0) 98 | imgx = imgxp + imgxn 99 | imgy = imgyp + imgyn 100 | img = np.clip(np.add(imgx, imgy), 0, 1) 101 | return img 102 | 103 | def draw_objective_function(xs, ys, ts, ps, objective=variance_objective(minimum_events=1), 104 | warpfunc=linvel_warp(), x_range=(-200, 200), y_range=(-200, 200), 105 | gt=(0,0), show_gt=True, resolution=20, img_size=(180, 240), show_axes=True, norm_min=None, norm_max=None, 106 | show=True): 107 | """ 108 | Draw the objective function by sampling over a range. Depending on the value of resolution, this 109 | can involve many samples and take some time.
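With the default ranges of (-200, 200) and resolution=20, for example, the objective is evaluated on a 20x20 grid, i.e. 400 samples.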
110 | @param xs x components of events as np array 111 | @param ys y components of events as np array 112 | @param ts t components of events as np array 113 | @param ps p components of events as np array 114 | @param objective (object) The objective function 115 | @param warpfunc (object) The warp function 116 | @param x_range, y_range (tuple) the range over which to plot the parameters 117 | @param gt (tuple) The ground truth 118 | @param show_gt (bool) Whether to draw the ground truth into the plot 119 | @param resolution (float) The resolution (step size) of the sampling 120 | @param img_size (tuple) The image sensor size. The remaining arguments control the axis labels, normalisation and whether to call plt.show() 121 | """ 122 | width = x_range[1]-x_range[0] 123 | height = y_range[1]-y_range[0] 124 | print("Drawing objective function. Taking {} samples".format((width*height)/(resolution**2))) 125 | imshape = (int(height/resolution+0.5), int(width/resolution+0.5)) 126 | img = np.zeros(imshape) 127 | for x in range(img.shape[1]): 128 | for y in range(img.shape[0]): 129 | params = np.array([x*resolution+x_range[0], y*resolution+y_range[0]]) 130 | img[y,x] = -objective.evaluate_function(params, xs, ys, ts, ps, warpfunc, img_size, blur_sigma=0) 131 | norm_min = np.min(img) if norm_min is None else norm_min 132 | norm_max = np.max(img) if norm_max is None else norm_max 133 | img = (img-norm_min)/((norm_max-norm_min)+1e-6) 134 | 135 | plt.imshow(img, interpolation='bilinear', cmap='viridis') 136 | if not show_axes: 137 | plt.xticks([]) 138 | plt.yticks([]) 139 | else: 140 | xt = plt.xticks()[0][1:-1] 141 | xticklabs = np.linspace(x_range[0], x_range[1], len(xt)) 142 | xticklabs = ["{}".format(int(x)) for x in xticklabs] 143 | 144 | yt = plt.yticks()[0][1:-1] 145 | yticklabs = np.linspace(y_range[0], y_range[1], len(yt)) 146 | yticklabs = ["{}".format(int(y)) for y in yticklabs] 147 | 148 | plt.xticks(ticks=xt, labels=xticklabs) 149 | plt.yticks(ticks=yt, labels=yticklabs) 150 | 151 | plt.xlabel("$v_x$") 152 | plt.ylabel("$v_y$") 153 | 154 | if show_gt: 155 | xloc = ((gt[0]-x_range[0])/(width))*imshape[1] 156 | yloc = ((gt[1]-y_range[0])/(height))*imshape[0] 157 | plt.axhline(y=yloc, color='r', linestyle='--') 158 | plt.axvline(x=xloc, color='r', linestyle='--') 159 | if show: 160 | plt.show() 161 | 162 | def find_new_range(search_axes, param): 163 | """ 164 | During grid search, we need to find a new search range once we have located 165 | an optimal parameter. This function gives us a new search range for a given axis 166 | of the search space, given a parameter value, such that all the unsearched domain around 167 | that parameter is encompassed.
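For example, if the previous samples along an axis were [-100, -50, 0, 50, 100] and the optimum was found at 50, the new range is (0, 100): the entire unsearched neighbourhood of 50.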
168 | @param search_axes The previous set of samples along one axis of the search space 169 | @param param The current motion parameter 170 | @returns The new parameter search range 171 | """ 172 | # find where param sits among the previously searched samples 173 | nearest_idx = np.searchsorted(search_axes, param) 174 | if nearest_idx >= len(search_axes)-1: 175 | d1 = np.abs(search_axes[-1]-search_axes[-2]) 176 | d2 = d1 177 | elif nearest_idx == 0: 178 | d1 = np.abs(search_axes[0]-search_axes[-1]) 179 | d2 = np.abs(search_axes[0]-search_axes[1]) 180 | else: 181 | d1 = np.abs(search_axes[nearest_idx]-search_axes[nearest_idx-1]) 182 | d2 = np.abs(search_axes[nearest_idx]-search_axes[nearest_idx+1]) 183 | param_range = [param-d1, param+d2] 184 | return param_range 185 | 186 | def grid_search_optimisation(xs, ys, ts, ps, warp_function, objective_function, img_size, param_ranges=None, 187 | log_scale=True, num_samples_per_param=5, depth=0, th0=1, max_iters=20): 188 | """ 189 | Recursive grid-search optimization as per SOFAS. For each axis of the parameter space, samples that 190 | space evenly. Having found the best point in the space, resamples the region surrounding that point, 191 | expanding the range if necessary. Continues to do this until convergence (search space is smaller than 192 | th0) or until iterations exceed max_iters. The search space can optionally be sampled logarithmically (i.e. 193 | samples are taken more densely near the origin). 194 | 195 | @param xs x components of events as np array 196 | @param ys y components of events as np array 197 | @param ts t components of events as np array 198 | @param ps p components of events as np array 199 | @param warp_function The warp function to use 200 | @param objective_function The objective function to use 201 | @param img_size The size of the event camera sensor 202 | @param param_ranges A list of lists, where each list contains the search range for 203 | the given warp function parameter. If None, the default is to search from -150 to 150 for 204 | each parameter. 205 | @param log_scale If true, the sample points are drawn from a log scale. This means that 206 | the parameter space is searched more frequently near the origin and less frequently at 207 | the fringes. 208 | @param num_samples_per_param How many samples to take per parameter. The number of evaluations 209 | this method needs to perform is equal to num_samples_per_param^warp_function.dims. Thus, 210 | for high dimensional warp functions, it is advised to keep this value low. Must be odd 211 | and at least 5. 212 | @param depth Keeps track of the recursion depth 213 | @param th0 When the subgrid search radius is smaller than th0, convergence is reached.
214 | @param max_iters Maximum number of iterations 215 | @returns The optimal parameter 216 | """ 217 | assert num_samples_per_param%2==1 and num_samples_per_param>=5 218 | 219 | optimal = grid_search_initial(xs, ys, ts, ps, warp_function, copy.deepcopy(objective_function), 220 | img_size, param_ranges=param_ranges, log_scale=log_scale, 221 | num_samples_per_param=num_samples_per_param) 222 | 223 | params = optimal["min_params"] 224 | new_param_ranges = [] 225 | max_range = 0 226 | # Iterate over each search axis and each element of the 227 | # optimal parameter to find new search range 228 | for sa, param in zip(optimal["search_axes"], params): 229 | new_range = find_new_range(sa, param) 230 | new_param_ranges.append(new_range) 231 | max_range = np.abs(new_range[1]-new_range[0]) if np.abs(new_range[1]-new_range[0]) > max_range else max_range 232 | if max_range >= th0 and depth < max_iters: 233 | return grid_search_optimisation(xs, ys, ts, ps, warp_function, objective_function, img_size, 234 | param_ranges=new_param_ranges, log_scale=log_scale, 235 | num_samples_per_param=num_samples_per_param, depth=depth+1) 236 | else: 237 | return optimal 238 | 239 | 240 | 241 | def grid_search_initial(xs, ys, ts, ps, warp_function, objective_function, img_size, param_ranges=None, 242 | log_scale=True, num_samples_per_param=5): 243 | """ 244 | Initial grid search as per SOFAS. Given a set of ranges for each parametrisation axis, 245 | searches that range at evenly sampled points. Can also use a logarithmically sampled space (samples are 246 | denser near the origin) if desired. 247 | 248 | @param xs x components of events as np array 249 | @param ys y components of events as np array 250 | @param ts t components of events as np array 251 | @param ps p components of events as np array 252 | @param warp_function The warp function to use 253 | @param objective_function The objective function to use 254 | @param img_size The size of the event camera sensor 255 | @param param_ranges A list of lists, where each list contains the search range for 256 | the given warp function parameter. If None, the default is to search from -150 to 150 for 257 | each parameter. 258 | @param log_scale If true, the sample points are drawn from a log scale. This means that 259 | the parameter space is searched more frequently near the origin and less frequently at 260 | the fringes. 261 | @param num_samples_per_param How many samples to take per parameter. The number of evaluations 262 | this method needs to perform is equal to num_samples_per_param^warp_function.dims. Thus, 263 | for high dimensional warp functions, it is advised to keep this value low. Must be odd 264 | and at least 5. 265 | @returns optimal is a dict with keys 'params' (the list of sampling coordinates used), 266 | 'eval' (the evaluation at each sample coordinate), 'search_axes' (the sample coordinates on each parameter axis), 267 | 'min_params' (the best parameter, minimising the optimisation problem) and 'min_func_eval' (the function value at 268 | the best parameter).
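Example (illustrative; assumes events are already loaded as numpy arrays):
    result = grid_search_initial(xs, ys, ts, ps, linvel_warp(), variance_objective(), (180, 240))
    best_params = result['min_params']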
269 | """ 270 | assert num_samples_per_param%2 == 1 271 | 272 | if log_scale: 273 | #Function is sampled from 10^x from 0 to 2 274 | scale = np.logspace(0, 2.0, int(num_samples_per_param/2.0)+1)[1:] 275 | scale /= scale[-1] 276 | else: 277 | scale = np.linspace(0, 1.0, int(num_samples_per_param/2.0)+1)[1:] 278 | 279 | # If the parameter ranges are empty, intialise them 280 | if param_ranges is None: 281 | param_ranges = [] 282 | for i in range(warp_function.dims): 283 | param_ranges.append([-150, 150]) 284 | 285 | axes = [] 286 | for param_range in param_ranges: 287 | rng = param_range[1]-param_range[0] 288 | mid = param_range[0] + rng/2.0 289 | rescale_pos = np.array(mid+scale*(rng/2.0)) 290 | rescale_neg = np.array(mid-scale*(rng/2.0))[::-1] 291 | rescale = np.concatenate((rescale_neg, np.array([mid]), rescale_pos)) 292 | axes.append(rescale) 293 | grids = np.meshgrid(*axes) 294 | coords = np.vstack(map(np.ravel, grids)) 295 | 296 | output = {"params":[], "eval": [], "search_axes": axes} 297 | best_eval = 0 298 | best_params = None 299 | 300 | for params in zip(*coords): 301 | f_eval = objective_function.evaluate_function(params=params, xs=xs, ys=ys, ts=ts, ps=ps, 302 | warpfunc=warp_function, img_size=img_size, blur_sigma=1.0) 303 | output["params"].append(params) 304 | output["eval"].append(f_eval) 305 | if f_eval < best_eval: 306 | best_eval = f_eval 307 | best_params = params 308 | 309 | output["min_params"] = best_params 310 | output["min_func_eval"] = best_eval 311 | return output 312 | 313 | def optimize_contrast(xs, ys, ts, ps, warp_function, objective, optimizer=opt.fmin_bfgs, x0=None, 314 | numeric_grads=False, blur_sigma=None, img_size=(180, 240), grid_search_init=False, minimum_events=200): 315 | """ 316 | Optimize contrast for a set of events using gradient based optimiser 317 | @param xs x components of events as np array 318 | @param ys y components of events as np array 319 | @param ts t components of events as np array 320 | @param ps p components of events as np array 321 | @param warp_function (function) The function with which to warp the events 322 | @param objective (objective class object) The objective to optimize 323 | @param optimizer (function) The optimizer to use 324 | @param x0 (np array) The initial guess for optimization 325 | @param numeric_grads (bool) If true, use numeric derivatives, otherwise use analytic drivatives if available. 326 | Numeric grads tend to be more stable as they are a little less prone to noise and don't require as much 327 | tuning on the blurring parameter. However, they do make optimization slower. 328 | @param img_size (tuple) The size of the event camera sensor 329 | @param blur_sigma (float) Size of the blurring kernel. Blurring the images of warped events can 330 | have a large impact on the convergence of the optimization. 
331 | @returns The max arguments for the warp parameters wrt the objective 332 | """ 333 | if grid_search_init and x0 is None: 334 | init_obj = copy.deepcopy(objective) 335 | init_obj.adaptive_lifespan = False 336 | minv = grid_search_optimisation(xs, ys, ts, ps, warp_function, init_obj, img_size, log_scale=False) 337 | x0 = minv["min_params"] 338 | elif x0 is None: 339 | x0 = np.zeros(warp_function.dims) 340 | objective.iter_update(x0) 341 | args = (xs, ys, ts, ps, warp_function, img_size, blur_sigma) 342 | if numeric_grads: 343 | argmax = optimizer(objective.evaluate_function, x0, args=args, epsilon=1, disp=False, callback=objective.iter_update) 344 | else: 345 | argmax = optimizer(objective.evaluate_function, x0, fprime=objective.evaluate_gradient, args=args, disp=True, callback=objective.iter_update) 346 | return argmax 347 | 348 | def optimize(xs, ys, ts, ps, warp, obj, numeric_grads=True, img_size=(180, 240)): 349 | """ 350 | Optimize contrast for a set of events using gradient based optimiser. 351 | Uses optimize_contrast() for the optimization, but allows 352 | blurring schedules for successive optimization iterations. 353 | Parameters: 354 | @param xs x components of events as np array 355 | @param ys y components of events as np array 356 | @param ts t components of events as np array 357 | @param ps p components of events as np array 358 | @param warp (function) The function with which to warp the events 359 | @param obj (objective class object) The objective to optimize 360 | @param numeric_grads (bool) If true, use numeric derivatives, otherwise use analytic derivatives if available. 361 | Numeric grads tend to be more stable as they are a little less prone to noise and don't require as much 362 | tuning on the blurring parameter. However, they do make optimization slower. 363 | @param img_size (tuple) The size of the event camera sensor 364 | @returns The max arguments for the warp parameters wrt the objective 365 | """ 366 | numeric_grads = numeric_grads if obj.has_derivative else True 367 | argmax_an = optimize_contrast(xs, ys, ts, ps, warp, obj, numeric_grads=numeric_grads, blur_sigma=1.0, img_size=img_size) 368 | return argmax_an 369 | 370 | def optimize_r2(xs, ys, ts, ps, warp, obj, numeric_grads=True, img_size=(180, 240)): 371 | """ 372 | Optimize contrast for a set of events, finishing with SoE loss. 373 | @param xs x components of events as np array 374 | @param ys y components of events as np array 375 | @param ts t components of events as np array 376 | @param ps p components of events as np array 377 | @param warp (function) The function with which to warp the events 378 | @param obj (objective class object) The objective to optimize 379 | @param numeric_grads (bool) If true, use numeric derivatives, otherwise use analytic derivatives if available. 380 | Numeric grads tend to be more stable as they are a little less prone to noise and don't require as much 381 | tuning on the blurring parameter. However, they do make optimization slower.
382 | @param img_size (tuple) The size of the event camera sensor 383 | @returns The max arguments for the warp parameters wrt the objective 384 | """ 385 | soe_obj = soe_objective() 386 | numeric_grads = numeric_grads if obj.has_derivative else True 387 | argmax_an = optimize_contrast(xs, ys, ts, ps, warp, obj, numeric_grads=numeric_grads, blur_sigma=None) 388 | argmax_an = optimize_contrast(xs, ys, ts, ps, warp, soe_obj, x0=argmax_an, numeric_grads=numeric_grads, blur_sigma=1.0) 389 | return argmax_an 390 | 391 | if __name__ == "__main__": 392 | """ 393 | Quick demo of various objectives. 394 | Args: 395 | path Path to h5 file with event data 396 | gt Ground truth optic flow for event slice 397 | img_size The size of the event camera sensor 398 | """ 399 | import argparse 400 | parser = argparse.ArgumentParser() 401 | parser.add_argument("path", help="h5 events path") 402 | parser.add_argument("--gt", nargs='+', type=float, default=(0,0)) 403 | parser.add_argument("--img_size", nargs='+', type=int, default=(180,240)) 404 | args = parser.parse_args() 405 | from lib.data_formats.read_events import read_h5_event_components 406 | xs, ys, ts, ps = read_h5_event_components(args.path) 407 | ts = ts-ts[0] 408 | gt_params = tuple(args.gt) 409 | img_size = tuple(args.img_size) 410 | 411 | start_idx = 20000 412 | end_idx = start_idx+15000 413 | blur = None 414 | 415 | draw_objective_function(xs[start_idx:end_idx], ys[start_idx:end_idx], ts[start_idx:end_idx], ps[start_idx:end_idx], variance_objective(), linvel_warp()) 416 | 417 | objectives = [r1_objective(), zhu_timestamp_objective(), variance_objective(), sos_objective(), soe_objective(), moa_objective(), 418 | isoa_objective(), sosa_objective(), rms_objective()] 419 | warp = linvel_warp() 420 | for obj in objectives: 421 | argmax = optimize(xs[start_idx:end_idx], ys[start_idx:end_idx], ts[start_idx:end_idx], ps[start_idx:end_idx], warp, obj, numeric_grads=True) 422 | loss = obj.evaluate_function(argmax, xs[start_idx:end_idx], ys[start_idx:end_idx], ts[start_idx:end_idx], 423 | ps[start_idx:end_idx], warp, img_size=img_size) 424 | gtloss = obj.evaluate_function(gt_params, xs[start_idx:end_idx], ys[start_idx:end_idx], 425 | ts[start_idx:end_idx], ps[start_idx:end_idx], warp, img_size=img_size) 426 | print("{}:({})={}, gt={}".format(obj.name, argmax, loss, gtloss)) 427 | if obj.has_derivative: 428 | argmax = optimize(xs[start_idx:end_idx], ys[start_idx:end_idx], ts[start_idx:end_idx], 429 | ps[start_idx:end_idx], warp, obj, numeric_grads=False) 430 | loss_an = obj.evaluate_function(argmax, xs[start_idx:end_idx], ys[start_idx:end_idx], 431 | ts[start_idx:end_idx], ps[start_idx:end_idx], warp, img_size=img_size) 432 | print(" analytical:{}={}".format(argmax, loss_an)) 433 | -------------------------------------------------------------------------------- /lib/contrast_max/warps.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from ..util.event_util import * 4 | from abc import ABC, abstractmethod 5 | 6 | class warp_function(ABC): 7 | """ 8 | Base class for objects that can warp events to a reference time 9 | via a parametrizable, differentiable motion model 10 | """ 11 | def __init__(self, name, dims): 12 | """ 13 | Constructor.
14 | @param name The name of the warp function (eg 'optic flow') 15 | @param dims The number of degrees of freedom of the motion model 16 | """ 17 | self.name = name 18 | self.dims = dims 19 | super().__init__() 20 | 21 | @abstractmethod 22 | def warp(self, xs, ys, ts, ps, t0, params, compute_grad=False): 23 | """ 24 | Warp function which given a set of events and a reference time, 25 | moves the events to that reference time via a motion model 26 | @param xs x components of events as list 27 | @param ys y components of events as list 28 | @param ts t components of events as list 29 | @param ps p components of events as list 30 | @param t0 The reference time to which to warp the events to 31 | @param params The parameters of the motion model for 32 | which to warp the events 33 | @param compute_grad If True, compute the gradient of the warp with 34 | respect to the motion parameters for each event (the Jacobian) 35 | @returns xs_warped, ys_warped, xs_jacobian, ys_jacobian: The warped 36 | event locations and the gradients for each event as a tuple of four 37 | numpy arrays 38 | """ 39 | #Warp the events... 40 | #if compute_grad: 41 | # compute the jacobian of the warp function 42 | pass 43 | 44 | class linvel_warp(warp_function): 45 | """ 46 | This class implements linear velocity warping (global optic flow) 47 | """ 48 | def __init__(self): 49 | warp_function.__init__(self, 'linvel_warp', 2) 50 | 51 | def warp(self, xs, ys, ts, ps, t0, params, compute_grad=False): 52 | dt = ts-t0 53 | x_prime = xs-dt*params[0] 54 | y_prime = ys-dt*params[1] 55 | jacobian_x, jacobian_y = None, None 56 | if compute_grad: 57 | jacobian_x = np.zeros((2, len(x_prime))) 58 | jacobian_y = np.zeros((2, len(y_prime))) 59 | jacobian_x[0, :] = -dt 60 | jacobian_y[1, :] = -dt 61 | return x_prime, y_prime, jacobian_x, jacobian_y 62 | 63 | class xyztheta_warp(warp_function): 64 | """ 65 | This class implements 4-DoF x,y,z,rotation warps from Mitrokhin et al., 66 | "Event-based moving object detection and tracking" 67 | """ 68 | def __init__(self): 69 | warp_function.__init__(self, 'xyztheta_warp', 4) 70 | 71 | def warp(self, xs, ys, ts, ps, t0, params, compute_grad=False): 72 | pass 73 | 74 | class pure_rotation_warp(warp_function): 75 | """ 76 | This class implements pure rotation warps, with params 77 | x, y, theta (x, y is the center of rotation, theta is the angular velocity) 78 | """ 79 | def __init__(self): 80 | warp_function.__init__(self, 'pure_rotation_warp', 3) 81 | 82 | def warp(self, xs, ys, ts, ps, t0, params, compute_grad=False): 83 | pass 84 | -------------------------------------------------------------------------------- /lib/data_formats/__init__.py: -------------------------------------------------------------------------------- 1 | # __init__.py 2 | from .data_utils import * 3 | from .read_events import * 4 | -------------------------------------------------------------------------------- /lib/data_formats/add_hdf5_attribute.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import h5py 4 | import os 5 | import glob 6 | 7 | def endswith(path, extensions): 8 | for ext in extensions: 9 | if path.endswith(ext): 10 | return True 11 | return False 12 | 13 | def get_filepaths_from_path_or_file(path, extensions=[], datafile_extensions=[".txt", ".csv"]): 14 | files = [] 15 | path = path.rstrip("/") 16 | if os.path.isdir(path): 17 | for ext in extensions: 18 | files += sorted(glob.glob("{}/*{}".format(path, ext))) 19 | else:
20 | if endswith(path, extensions): 21 | files.append(path) 22 | elif endswith(path, datafile_extensions): 23 | with open(path, 'r') as f: 24 | # the data file lists one file path per line 25 | files = [line.strip() for line in f.readlines()] 26 | return files 27 | 28 | def add_attribute(h5_filepaths, group, attribute_name, attribute_value, dry_run=False): 29 | for h5_filepath in h5_filepaths: 30 | print("adding {}/{}[{}]={}".format(h5_filepath, group, attribute_name, attribute_value)) 31 | if dry_run: 32 | continue 33 | h5_file = h5py.File(h5_filepath, 'a') 34 | dset = h5_file["{}/".format(group)] 35 | dset.attrs[attribute_name] = attribute_value 36 | h5_file.close() 37 | 38 | if __name__ == "__main__": 39 | # arguments 40 | parser = argparse.ArgumentParser() 41 | parser._action_groups.pop() 42 | required = parser.add_argument_group('required arguments') 43 | optional = parser.add_argument_group('optional arguments') 44 | 45 | required.add_argument("--path", help="Can be either 1: path to individual hdf file, " + 46 | "2: txt file with list of hdf files, or " + 47 | "3: directory (all hdf files in directory will be processed).", required=True) 48 | required.add_argument("--attr_name", help="Name of new attribute", required=True) 49 | required.add_argument("--attr_val", help="Value of new attribute", required=True) 50 | optional.add_argument("--group", help="Group to add attribute to. Subgroups " + 51 | "are represented like paths, eg: /group1/subgroup2...", default="") 52 | optional.add_argument("--dry_run", default=0, type=int, 53 | help="If set to 1, will print changes without performing them") 54 | 55 | args = parser.parse_args() 56 | path = args.path 57 | extensions = [".hdf", ".h5"] 58 | files = get_filepaths_from_path_or_file(path, extensions=extensions) 59 | print(files) 60 | dry_run = False if args.dry_run <= 0 else True 61 | add_attribute(files, args.group, args.attr_name, args.attr_val, dry_run=dry_run) 62 | -------------------------------------------------------------------------------- /lib/data_formats/data_providers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py 3 | 4 | 5 | class BaseDataLoader(): 6 | def __init__(self, data_root, iter_method='between_frames'): 7 | pass 8 | 9 | def __getitem__(self, index): 10 | pass 11 | 12 | def __len__(self): 13 | return self.length 14 | -------------------------------------------------------------------------------- /lib/data_formats/data_utils.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | 4 | def binary_search_h5_dset(dset, x, l=None, r=None, side='left'): 5 | l = 0 if l is None else l 6 | r = len(dset)-1 if r is None else r 7 | while l <= r: 8 | mid = l + (r - l)//2 9 | midval = dset[mid] 10 | if midval == x: 11 | return mid 12 | elif midval < x: 13 | l = mid + 1 14 | else: 15 | r = mid - 1 16 | if side == 'left': 17 | return l 18 | return r 19 | 20 | def binary_search_h5_timestamp(hdf_path, l, r, x, side='left'): 21 | f = h5py.File(hdf_path, 'r') 22 | return binary_search_h5_dset(f['events/ts'], x, l=l, r=r, side=side) 23 | -------------------------------------------------------------------------------- /lib/data_formats/event_packagers.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | import h5py 3 | import cv2 as cv 4 | import numpy as np 5 | 6 | class packager(metaclass=ABCMeta): 7 | """ 8 | Abstract base class for
classes that package event-based data to 9 | some storage format 10 | """ 11 | 12 | 13 | def __init__(self, name, output_path, max_buffer_size=1000000): 14 | """ 15 | Set class attributes 16 | @param name The name of the packager (eg: txt_packager) 17 | @param output_path Where to dump event data 18 | @param max_buffer_size For packagers that buffer data prior to 19 | writing, how large this buffer may maximally be 20 | """ 21 | self.name = name 22 | self.output_path = output_path 23 | self.max_buffer_size = max_buffer_size 24 | 25 | @abstractmethod 26 | def package_events(self, xs, ys, ts, ps): 27 | """ 28 | Given events, write them to the file/store them into the buffer 29 | @param xs x component of events 30 | @param ys y component of events 31 | @param ts t component of events 32 | @param ps p component of events 33 | @returns None 34 | """ 35 | pass 36 | 37 | @abstractmethod 38 | def package_image(self, frame, timestamp, img_idx): 39 | """ 40 | Given an image, write it to the file/buffer 41 | @param frame The image frame to write to the file/buffer 42 | @param timestamp The timestamp of the frame, img_idx its index 43 | @returns None 44 | """ 45 | pass 46 | 47 | @abstractmethod 48 | def package_flow(self, flow, timestamp, flow_idx): 49 | """ 50 | Given an optic flow image, write it to the file/buffer 51 | @param flow The optic flow frame to write to the file/buffer 52 | @param timestamp The timestamp of the optic flow frame, flow_idx its index 53 | @returns None 54 | """ 55 | pass 56 | 57 | @abstractmethod 58 | def add_metadata(self, num_pos, num_neg, 59 | duration, t0, tk, num_imgs, num_flow, sensor_size): 60 | """ 61 | Add metadata to the file 62 | @param num_pos The number of positive events in the sequence 63 | @param num_neg The number of negative events in the sequence 64 | @param duration The length of the sequence in seconds 65 | @param t0 The start time of the sequence 66 | @param tk The end time of the sequence 67 | @param num_imgs The number of images in the sequence 68 | @param num_flow The number of optic flow frames in the sequence 69 | @param sensor_size The resolution of the event sensor 70 | """ 71 | pass 72 | 73 | @abstractmethod 74 | def set_data_available(self, num_images, num_flow): 75 | """ 76 | Configure the file/buffers depending on which data needs to be written 77 | @param num_images How many images in the dataset 78 | @param num_flow How many optic flow frames in the dataset 79 | """ 80 | pass 81 | 82 | class hdf5_packager(packager): 83 | """ 84 | This class packages data to hdf5 files 85 | """ 86 | def __init__(self, output_path, max_buffer_size=1000000): 87 | packager.__init__(self, 'hdf5', output_path, max_buffer_size) 88 | print("CREATING FILE IN {}".format(output_path)) 89 | self.events_file = h5py.File(output_path, 'w') 90 | self.event_xs = self.events_file.create_dataset("events/xs", (0, ), dtype=np.dtype(np.int16), maxshape=(None, ), chunks=True) 91 | self.event_ys = self.events_file.create_dataset("events/ys", (0, ), dtype=np.dtype(np.int16), maxshape=(None, ), chunks=True) 92 | self.event_ts = self.events_file.create_dataset("events/ts", (0, ), dtype=np.dtype(np.float64), maxshape=(None, ), chunks=True) 93 | self.event_ps = self.events_file.create_dataset("events/ps", (0, ), dtype=np.dtype(np.bool_), maxshape=(None, ), chunks=True) 94 | 95 | def append_to_dataset(self, dataset, data): 96 | dataset.resize(dataset.shape[0] + len(data), axis=0) 97 | if len(data) == 0: 98 | return 99 | dataset[-len(data):] = data[:] 100 | 101 | def package_events(self, xs, ys, ts, ps): 102 |
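"""Append the event components to the corresponding resizable HDF5 datasets."""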
self.append_to_dataset(self.event_xs, xs) 103 | self.append_to_dataset(self.event_ys, ys) 104 | self.append_to_dataset(self.event_ts, ts) 105 | self.append_to_dataset(self.event_ps, ps) 106 | 107 | def package_image(self, image, timestamp, img_idx): 108 | image_dset = self.events_file.create_dataset("images/image{:09d}".format(img_idx), 109 | data=image, dtype=np.dtype(np.uint8)) 110 | image_dset.attrs['size'] = image.shape 111 | image_dset.attrs['timestamp'] = timestamp 112 | image_dset.attrs['type'] = "greyscale" if image.shape[-1] == 1 or len(image.shape) == 2 else "color_bgr" 113 | 114 | def package_flow(self, flow_image, timestamp, flow_idx): 115 | flow_dset = self.events_file.create_dataset("flow/flow{:09d}".format(flow_idx), 116 | data=flow_image, dtype=np.dtype(np.float32)) 117 | flow_dset.attrs['size'] = flow_image.shape 118 | flow_dset.attrs['timestamp'] = timestamp 119 | 120 | def add_event_indices(self): 121 | datatypes = ['images', 'flow'] 122 | for datatype in datatypes: 123 | if datatype in self.events_file.keys(): 124 | s = 0 125 | added = 0 126 | ts = self.events_file["events/ts"][s:s+self.max_buffer_size] 127 | for image in self.events_file[datatype]: 128 | img_ts = self.events_file[datatype][image].attrs['timestamp'] 129 | event_idx = np.searchsorted(ts, img_ts) 130 | if event_idx == len(ts): 131 | added += len(ts) 132 | s += self.max_buffer_size 133 | ts = self.events_file["events/ts"][s:s+self.max_buffer_size] 134 | event_idx = np.searchsorted(ts, img_ts) 135 | event_idx = max(0, event_idx-1) 136 | self.events_file[datatype][image].attrs['event_idx'] = event_idx + added 137 | 138 | def add_metadata(self, num_pos, num_neg, 139 | duration, t0, tk, num_imgs, num_flow, sensor_size): 140 | self.events_file.attrs['num_events'] = num_pos+num_neg 141 | self.events_file.attrs['num_pos'] = num_pos 142 | self.events_file.attrs['num_neg'] = num_neg 143 | self.events_file.attrs['duration'] = tk-t0 144 | self.events_file.attrs['t0'] = t0 145 | self.events_file.attrs['tk'] = tk 146 | self.events_file.attrs['num_imgs'] = num_imgs 147 | self.events_file.attrs['num_flow'] = num_flow 148 | self.events_file.attrs['sensor_resolution'] = sensor_size 149 | self.add_event_indices() 150 | 151 | def set_data_available(self, num_images, num_flow): 152 | if num_images > 0: 153 | self.image_dset = self.events_file.create_group("images") 154 | self.image_dset.attrs['num_images'] = num_images 155 | if num_flow > 0: 156 | self.flow_dset = self.events_file.create_group("flow") 157 | self.flow_dset.attrs['num_images'] = num_flow 158 | 159 | -------------------------------------------------------------------------------- /lib/data_formats/h5_to_memmap.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import h5py 3 | import numpy as np 4 | import os, shutil 5 | import json 6 | 7 | class NpEncoder(json.JSONEncoder): 8 | def default(self, obj): 9 | if isinstance(obj, np.integer): 10 | return int(obj) 11 | elif isinstance(obj, np.floating): 12 | return float(obj) 13 | elif isinstance(obj, np.ndarray): 14 | return obj.tolist() 15 | else: 16 | return super(NpEncoder, self).default(obj) 17 | 18 | def find_safe_alternative(output_base_path): 19 | i = 0 20 | alternative_path = "{}_{:09d}".format(output_base_path, i) 21 | while(os.path.exists(alternative_path)): 22 | i += 1 23 | alternative_path = "{}_{:09d}".format(output_base_path, i) 24 | assert(i < 999999999) 25 | return alternative_path 26 | 27 | def save_additional_data_as_mmap(f, mmap_pth, data): 
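"""
Write one auxiliary data stream (images or flow) from the open HDF5 file 'f' into memmap
files (data, timestamps and per-frame event indices) inside 'mmap_pth', as described by
the 'data' dict (see 'additional_data' in h5_to_memmap below).
"""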
28 | data_path = os.path.join(mmap_pth, data['mmap_filename']) 29 | data_ts_path = os.path.join(mmap_pth, data['mmap_ts_filename']) 30 | data_event_idx_path = os.path.join(mmap_pth, data['mmap_event_idx_filename']) 31 | data_key = data['h5_key'] 32 | print('Writing {} to mmap {}, timestamps to {}'.format(data_key, data_path, data_ts_path)) 33 | h, w, c = 1, 1, 1 34 | if data_key in f.keys(): 35 | num_data = len(f[data_key].keys()) 36 | if num_data > 0: 37 | data_keys = list(f[data_key].keys()) 38 | data_size = f[data_key][data_keys[0]].attrs['size'] 39 | h, w = data_size[0], data_size[1] 40 | c = 1 if len(data_size) <= 2 else data_size[2] 41 | else: 42 | num_data = 1 43 | mmp_imgs = np.memmap(data_path, dtype='uint8', mode='w+', shape=(num_data, h, w, c)) 44 | mmp_img_ts = np.memmap(data_ts_path, dtype='float64', mode='w+', shape=(num_data, 1)) 45 | mmp_event_indices = np.memmap(data_event_idx_path, dtype='uint64', mode='w+', shape=(num_data, 1))  # uint64: event indices easily exceed the uint16 range 46 | 47 | if data_key in f.keys(): 48 | data = [] 49 | data_timestamps = [] 50 | data_event_index = [] 51 | for img_key in f[data_key].keys(): 52 | data.append(f[data_key][img_key][:]) 53 | data_timestamps.append(f[data_key][img_key].attrs['timestamp']) 54 | data_event_index.append(f[data_key][img_key].attrs['event_idx']) 55 | 56 | data_stack = np.expand_dims(np.stack(data), axis=3) 57 | data_ts_stack = np.expand_dims(np.stack(data_timestamps), axis=1) 58 | data_event_indices_stack = np.expand_dims(np.stack(data_event_index), axis=1) 59 | mmp_imgs[...] = data_stack 60 | mmp_img_ts[...] = data_ts_stack 61 | mmp_event_indices[...] = data_event_indices_stack 62 | 63 | def write_metadata(f, metadata_path): 64 | metadata = {} 65 | for attr in f.attrs: 66 | val = f.attrs[attr] 67 | if isinstance(val, np.ndarray): 68 | val = val.tolist() 69 | metadata[attr] = val 70 | with open(metadata_path, 'w') as js: 71 | json.dump(metadata, js, cls=NpEncoder) 72 | 73 | def h5_to_memmap(h5_file_path, output_base_path, overwrite=True): 74 | output_pth = output_base_path 75 | if os.path.exists(output_pth): 76 | if overwrite: 77 | print("Overwriting {}".format(output_pth)) 78 | shutil.rmtree(output_pth) 79 | else: 80 | output_pth = find_safe_alternative(output_base_path) 81 | print('Data will be extracted to: {}'.format(output_pth)) 82 | os.makedirs(output_pth) 83 | mmap_pth = os.path.join(output_pth, "memmap") 84 | os.makedirs(mmap_pth) 85 | 86 | ts_path = os.path.join(mmap_pth, 't.npy') 87 | xy_path = os.path.join(mmap_pth, 'xy.npy') 88 | ps_path = os.path.join(mmap_pth, 'p.npy') 89 | metadata_path = os.path.join(mmap_pth, 'metadata.json') 90 | 91 | additional_data = { 92 | "images": 93 | { 94 | 'h5_key' : 'images', 95 | 'mmap_filename' : 'images.npy', 96 | 'mmap_ts_filename' : 'timestamps.npy', 97 | 'mmap_event_idx_filename' : 'image_event_indices.npy', 98 | 'dims' : 3 99 | }, 100 | "flow": 101 | { 102 | 'h5_key' : 'flow', 103 | 'mmap_filename' : 'flow.npy', 104 | 'mmap_ts_filename' : 'flow_timestamps.npy', 105 | 'mmap_event_idx_filename' : 'flow_event_indices.npy', 106 | 'dims' : 3 107 | } 108 | } 109 | 110 | with h5py.File(h5_file_path, 'r') as f: 111 | num_events = f.attrs['num_events'] 112 | num_images = f.attrs['num_imgs'] 113 | num_flow = f.attrs['num_flow'] 114 | 115 | mmp_ts = np.memmap(ts_path, dtype='float64', mode='w+', shape=(num_events, 1)) 116 | mmp_xy = np.memmap(xy_path, dtype='int16', mode='w+', shape=(num_events, 2)) 117 | mmp_ps = np.memmap(ps_path, dtype='uint8', mode='w+', shape=(num_events, 1)) 118 | 119 | mmp_ts[:, 0] = f['events/ts'][:] 120 |
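# pack x and y together into one (num_events, 2) array, matching the RPG memmap layout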
120 |         mmp_xy[:, :] = np.stack((f['events/xs'][:], f['events/ys'][:])).transpose()
121 |         mmp_ps[:, 0] = f['events/ps'][:]
122 | 
123 |         for data in additional_data:
124 |             save_additional_data_as_mmap(f, mmap_pth, additional_data[data])
125 |         write_metadata(f, metadata_path)
126 | 
127 | 
128 | if __name__ == "__main__":
129 |     """
130 |     Tool to convert this project's style of HDF5 files to the memmap format used in some RPG projects
131 |     """
132 |     parser = argparse.ArgumentParser()
133 |     parser.add_argument("path", help="HDF5 file to convert")
134 |     parser.add_argument("--output_dir", default=None, help="Path to extract to (defaults to the input file's directory if left empty)")
135 |     parser.add_argument('--not_overwrite', action='store_false', help='If set, will not overwrite an\
136 |             existing memmap, but will write to a safe alternative path instead')
137 | 
138 |     args = parser.parse_args()
139 | 
140 |     bagname = os.path.splitext(os.path.basename(args.path))[0]
141 |     if args.output_dir is None:
142 |         output_path = os.path.join(os.path.dirname(os.path.abspath(args.path)), bagname)
143 |     else:
144 |         output_path = os.path.join(args.output_dir, bagname)
145 |     h5_to_memmap(args.path, output_path, overwrite=args.not_overwrite)
146 | 
--------------------------------------------------------------------------------
/lib/data_formats/read_events.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import numpy as np
3 | import os
4 | 
5 | def compute_indices(event_stamps, frame_stamps):
6 |     """
7 |     Given event timestamps and frame timestamps as arrays,
8 |     find the event indices that correspond to the beginning and
9 |     end of each frame's period
10 |     @param event_stamps The event timestamps (Nx1 array)
11 |     @param frame_stamps The frame timestamps
12 |     @returns The end/start indices as an (N-1)x2 numpy array (N=number of frames)
13 |     """
14 |     indices_first = np.searchsorted(event_stamps[:,0], frame_stamps[1:])
15 |     indices_last = np.searchsorted(event_stamps[:,0], frame_stamps[:-1])
16 |     index = np.stack([indices_first, indices_last], -1)
17 |     return index
18 | 
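# Illustrative sketch (not part of the original module): how compute_indices
# pairs frames with event ranges. The timestamps below are invented.
def _compute_indices_example():
    event_stamps = np.array([[0.00], [0.01], [0.02], [0.05], [0.09]])  # Nx1, as expected above
    frame_stamps = np.array([0.0, 0.03, 0.1])
    idx = compute_indices(event_stamps, frame_stamps)
    # As written above, column 1 holds the start index and column 0 the end
    # index of each inter-frame period, i.e. the events falling between
    # frame_stamps[i] and frame_stamps[i+1] are event_stamps[idx[i, 1]:idx[i, 0]].
    print(idx)  # -> [[3 0] [5 3]]
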
19 | def read_memmap_events(memmap_path, skip_frames=1, return_events=False, images_file = 'images.npy',
20 |         images_ts_file = 'timestamps.npy', optic_flow_file = 'optic_flow.npy',
21 |         optic_flow_ts_file = 'optic_flow_timestamps.npy', events_xy_file = 'xy.npy',
22 |         events_p_file = 'p.npy', events_t_file = 't.npy'):
23 |     """
24 |     Given a path to an RPG-style memmap, read the events it contains.
25 |     These memmaps break images, timestamps, optic flow, xy, p and t
26 |     components of events into separate files.
27 |     @param memmap_path Path to the root directory of the memmap
28 |     @param skip_frames Only read every `skip_frames`th frame (default=1, i.e. read every frame)
29 |     @param return_events If True, return the events as numpy arrays, else return
30 |         a handle to the event data files (which can be indexed, but does not load
31 |         events into RAM)
32 |     @param images_file The file containing images
33 |     @param images_ts_file The file containing image timestamps
34 |     @param optic_flow_file The file containing optic flow frames
35 |     @param optic_flow_ts_file The file containing optic flow frame timestamps
36 |     @param events_xy_file The file containing event coordinate data
37 |     @param events_p_file The file containing the event polarities
38 |     @param events_t_file The file containing the event timestamps
39 |     @return dict with event data:
40 |         data = {
41 |             "index": index mapping image index to event idx
42 |             "frame_stamps": frame timestamps
43 |             "images": images
44 |             "optic_flow": optic flow
45 |             "optic_flow_stamps": of timestamps
46 |             "t": event timestamps
47 |             "xy": event coords
48 |             "p": event polarities
49 |             "t0": t0}
50 |     """
51 |     assert os.path.isdir(memmap_path), '%s is not a valid directory' % memmap_path
52 | 
53 |     data = {}
54 |     has_flow = False
55 |     for subroot, _, fnames in sorted(os.walk(memmap_path)):
56 |         for fname in sorted(fnames):
57 |             path = os.path.join(subroot, fname)
58 |             if fname.endswith(".npy"):
59 |                 if fname=="index.npy":  # index mapping image index to event idx
60 |                     indices = np.load(path)  # N x 2
61 |                     assert len(indices.shape) == 2 and indices.shape[1] == 2
62 |                     indices = indices.astype("int64")  # ignore event indices which are 0 (before first image)
63 |                     data["index"] = indices.T
64 |                 elif fname==images_ts_file:
65 |                     data["frame_stamps"] = np.load(path)[::skip_frames,...]
66 |                 elif fname==images_file:
67 |                     data["images"] = np.load(path, mmap_mode="r")[::skip_frames,...]
68 |                 elif fname==optic_flow_file:
69 |                     data["optic_flow"] = np.load(path, mmap_mode="r")[::skip_frames,...]
70 |                     has_flow = True
71 |                 elif fname==optic_flow_ts_file:
72 |                     data["optic_flow_stamps"] = np.load(path)[::skip_frames,...]
73 | 
74 |                 handle = np.load(path, mmap_mode="r")
75 |                 if fname==events_t_file:  # timestamps
76 |                     data["t"] = handle[:].squeeze() if return_events else handle
77 |                     data["t0"] = handle[0]
78 |                 elif fname==events_xy_file:  # coordinates
79 |                     data["xy"] = handle[:].squeeze() if return_events else handle
80 |                 elif fname==events_p_file:  # polarity
81 |                     data["p"] = handle[:].squeeze() if return_events else handle
82 | 
83 |         if len(data) > 0:
84 |             data['path'] = subroot
85 |             if "t" not in data:
86 |                 raise Exception(f"Ignoring directory {subroot} since it contains no events")
87 |             if not (len(data['p']) == len(data['xy']) and len(data['p']) == len(data['t'])):
88 |                 raise Exception(f"Events from {subroot} invalid")
89 |             data["num_events"] = len(data['p'])
90 | 
91 |             if "index" not in data and "frame_stamps" in data:
92 |                 data["index"] = compute_indices(data["t"], data['frame_stamps'])
93 |     return data
94 | 
95 | def read_memmap_events_dict(memmap_path, skip_frames=1, return_events=False, images_file = 'images.npy',
96 |         images_ts_file = 'timestamps.npy', optic_flow_file = 'optic_flow.npy',
97 |         optic_flow_ts_file = 'optic_flow_timestamps.npy', events_xy_file = 'xy.npy',
98 |         events_p_file = 'p.npy', events_t_file = 't.npy'):
99 |     """
100 |     Read memmap file events and return them in a dict
101 |     """
102 |     data = read_memmap_events(memmap_path, skip_frames, return_events, images_file, images_ts_file,
103 |             optic_flow_file, optic_flow_ts_file, events_xy_file, events_p_file, events_t_file)
104 |     events = {
105 |         'xs':data['xy'][:,0].squeeze(),
106 |         'ys':data['xy'][:,1].squeeze(),
107 |         'ts':data['t'][:].squeeze(),
108 |         'ps':data['p'][:].squeeze()}
109 |     return events
110 | 
111 | def read_h5_events(hdf_path):
112 |     """
113 |     Read events from HDF5 file (Monash style).
114 |     @param hdf_path Path to HDF5 file
115 |     @returns Events as Nx4 numpy array (N=num events)
116 |     """
117 |     f = h5py.File(hdf_path, 'r')
118 |     if 'events/x' in f:
119 |         #legacy
120 |         events = np.stack((f['events/x'][:], f['events/y'][:], f['events/ts'][:], np.where(f['events/p'][:], 1, -1)), axis=1)
121 |     else:
122 |         events = np.stack((f['events/xs'][:], f['events/ys'][:], f['events/ts'][:], np.where(f['events/ps'][:], 1, -1)), axis=1)
123 |     return events
124 | 
125 | def read_h5_event_components(hdf_path):
126 |     """
127 |     Read events from HDF5 file (Monash style).
128 |     @param hdf_path Path to HDF5 file
129 |     @returns Events as four np arrays with the event components
130 |     """
131 |     f = h5py.File(hdf_path, 'r')
132 |     if 'events/x' in f:
133 |         #legacy
134 |         return (f['events/x'][:], f['events/y'][:], f['events/ts'][:], np.where(f['events/p'][:], 1, -1))
135 |     else:
136 |         return (f['events/xs'][:], f['events/ys'][:], f['events/ts'][:], np.where(f['events/ps'][:], 1, -1))
137 | 
138 | def read_h5_events_dict(hdf_path, read_frames=True):
139 |     """
140 |     Read events from HDF5 file (Monash style).
141 |     @param hdf_path Path to HDF5 file
142 |     @returns Events as a dict with entries 'xs', 'ys', 'ts', 'ps' containing the event components,
143 |         'frames' containing the frames, 'frame_timestamps' containing frame timestamps and
144 |         'frame_event_indices' containing the indices of the corresponding event for each frame
145 |     """
146 |     f = h5py.File(hdf_path, 'r')
147 |     if 'events/x' in f:
148 |         #legacy
149 |         events = {
150 |             'xs':f['events/x'][:],
151 |             'ys':f['events/y'][:],
152 |             'ts':f['events/ts'][:],
153 |             'ps':np.where(f['events/p'][:], 1, -1)
154 |         }
155 |         return events
156 |     else:
157 |         events = {
158 |             'xs':f['events/xs'][:],
159 |             'ys':f['events/ys'][:],
160 |             'ts':f['events/ts'][:],
161 |             'ps':np.where(f['events/ps'][:], 1, -1)
162 |         }
163 |         if read_frames:
164 |             images = []
165 |             image_stamps = []
166 |             image_event_indices = []
167 |             for key in f['images']:
168 |                 frame = f['images/{}'.format(key)][:]
169 |                 images.append(frame)
170 |                 image_stamps.append(f['images/{}'.format(key)].attrs['timestamp'])
171 |                 image_event_indices.append(f['images/{}'.format(key)].attrs['event_idx'])
172 |             events['frames'] = images
173 |             #np.concatenate(images, axis=2).swapaxes(0,2) if len(frame.shape)==3 else np.stack(images, axis=0)
174 |             events['frame_timestamps'] = np.array(image_stamps)
175 |             events['frame_event_indices'] = np.array(image_event_indices)
176 |         return events
--------------------------------------------------------------------------------
/lib/data_formats/rosbag_to_h5.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import argparse
3 | import rosbag
4 | import rospy
5 | from cv_bridge import CvBridge, CvBridgeError
6 | import os
7 | import h5py
8 | import numpy as np
9 | from event_packagers import *
10 | from tqdm import tqdm
11 | 
12 | 
13 | def append_to_dataset(dataset, data):
14 |     dataset.resize(dataset.shape[0] + len(data), axis=0)
15 |     if len(data) == 0:
16 |         return
17 |     dataset[-len(data):] = data[:]
18 | 
19 | 
20 | def timestamp_float(ts):
21 |     return ts.secs + ts.nsecs / float(1e9)
22 | 
23 | 
24 | def get_rosbag_stats(bag, event_topic, image_topic=None, flow_topic=None):
25 |     num_event_msgs = 0
26 |     num_img_msgs = 0
27 |     num_flow_msgs = 0
28 |     topics = bag.get_type_and_topic_info().topics
29 |     for topic_name, topic_info in topics.items():  # .items() works on both Python 2 and 3, unlike .iteritems()
30 |         if topic_name == event_topic:
31 |             num_event_msgs = topic_info.message_count
32 |             print('Found events topic: {} with {} messages'.format(topic_name, topic_info.message_count))
33 |         if topic_name == image_topic:
34 |             num_img_msgs = topic_info.message_count
35 |             print('Found image topic: {} with {} messages'.format(topic_name, num_img_msgs))
36 |         if topic_name == flow_topic:
37 |             num_flow_msgs = topic_info.message_count
38 |             print('Found flow topic: {} with {} messages'.format(topic_name, num_flow_msgs))
39 |     return num_event_msgs, num_img_msgs, num_flow_msgs
40 | 
41 | 
42 | # Inspired by https://github.com/uzh-rpg/rpg_e2vid
43 | def extract_rosbag(rosbag_path, output_path, event_topic, image_topic=None,
44 |         flow_topic=None, start_time=None, end_time=None, zero_timestamps=False,
45 |         packager=hdf5_packager, is_color=False):
46 |     ep = packager(output_path)
47 |     topics = (event_topic, image_topic, flow_topic)
48 |     event_msg_sum = 0
49 |     num_msgs_between_logs = 25
50 |     first_ts = -1
51 |     t0 = -1
52 |     sensor_size = None
53 |     if not os.path.exists(rosbag_path):
54 |         print("{} does not exist!".format(rosbag_path))
55 |         return
56 |     with rosbag.Bag(rosbag_path, 'r') as bag:
57 |         # Look for the topics that are available and save the total number of messages for each topic (useful for the progress bar)
58 |         num_event_msgs, num_img_msgs, num_flow_msgs = get_rosbag_stats(bag, event_topic, image_topic, flow_topic)
59 |         # Extract events to h5
60 |         xs, ys, ts, ps = [], [], [], []
61 |         max_buffer_size = 1e20
62 |         ep.set_data_available(num_img_msgs, num_flow_msgs)
63 |         num_pos, num_neg, last_ts, img_cnt, flow_cnt = 0, 0, 0, 0, 0
64 | 
65 |         for topic, msg, t in tqdm(bag.read_messages()):
66 |             if first_ts == -1 and topic in topics:
67 |                 timestamp = timestamp_float(msg.header.stamp)
68 |                 first_ts = timestamp
69 |                 if zero_timestamps:
70 |                     timestamp = timestamp-first_ts
71 |                 if start_time is None:
72 |                     start_time = 0  # treat missing start_time as 'from the bag start'; adding first_ts below previously double-counted it
73 |                 start_time = start_time + first_ts
74 |                 if end_time is not None:
75 |                     end_time = end_time+start_time
76 |                 t0 = timestamp
77 | 
78 |             if topic == image_topic:
79 |                 timestamp = timestamp_float(msg.header.stamp)-(first_ts if zero_timestamps else 0)
80 |                 if is_color:
81 |                     image = CvBridge().imgmsg_to_cv2(msg, "bgr8")
82 |                 else:
83 |                     image = CvBridge().imgmsg_to_cv2(msg, "mono8")
84 | 
85 |                 ep.package_image(image, timestamp, img_cnt)
86 |                 sensor_size = image.shape
87 |                 img_cnt += 1
88 | 
89 |             elif topic == flow_topic:
90 |                 timestamp = timestamp_float(msg.header.stamp)-(first_ts if zero_timestamps else 0)
91 | 
92 |                 flow_x = np.array(msg.flow_x)
93 |                 flow_y = np.array(msg.flow_y)
94 |                 flow_x.shape = (msg.height, msg.width)
95 |                 flow_y.shape = (msg.height, msg.width)
96 |                 flow_image = np.stack((flow_x, flow_y), axis=0)
97 | 
98 |                 ep.package_flow(flow_image, timestamp, flow_cnt)
99 |                 flow_cnt += 1
100 | 
101 |             elif topic == event_topic:
102 |                 event_msg_sum += 1
103 |                 #if event_msg_sum % num_msgs_between_logs == 0 or event_msg_sum >= num_event_msgs - 1:
104 |                 #    print('Event messages: {} / {}'.format(event_msg_sum + 1, num_event_msgs))
105 |                 for e in msg.events:
106 |                     timestamp = timestamp_float(e.ts)-(first_ts if zero_timestamps else 0)
107 |                     xs.append(e.x)
108 |                     ys.append(e.y)
109 |                     ts.append(timestamp)
110 |                     ps.append(1 if e.polarity else 0)
111 |                     if e.polarity:
112 |                         num_pos += 1
113 |                     else:
114 |                         num_neg += 1
115 |                     last_ts = timestamp
116 |                 if (len(xs) > max_buffer_size and timestamp >= start_time) or (end_time is not None and timestamp >= end_time):  # compare with end_time; comparing with start_time flushed/stopped at the very first event
117 |                     print("Writing events")
118 |                     if sensor_size is None or sensor_size[0] < max(ys) or sensor_size[1] < max(xs):
119 |                         sensor_size = [max(ys), max(xs)]
120 |                         print("Sensor size inferred from events as {}".format(sensor_size))
121 |                     ep.package_events(xs, ys, ts, ps)
122 |                     del xs[:]
123 |                     del ys[:]
124 |                     del ts[:]
125 |                     del ps[:]
126 |                     if end_time is not None and timestamp >= end_time:
127 |                         return
128 |         if sensor_size is None or sensor_size[0] < max(ys) or sensor_size[1] < max(xs):
129 |             sensor_size = [max(ys), max(xs)]
130 |             print("Sensor size inferred from events as {}".format(sensor_size))
131 |         ep.package_events(xs, ys, ts, ps)
132 |         del xs[:]
133 |         del ys[:]
134 |         del ts[:]
135 |         del ps[:]
136 |         if sensor_size is None:
137 |             raise Exception("ERROR: No sensor size detected, implies no events/images in bag topics?")
138 |         print("Detected sensor size {}".format(sensor_size))
139 |         ep.add_metadata(num_pos, num_neg, last_ts-t0, t0, last_ts, img_cnt, flow_cnt, sensor_size)
140 | 
141 | 
142 | def extract_rosbags(rosbag_paths, output_dir, event_topic, image_topic, flow_topic,
143 |         zero_timestamps=False, is_color=False):
144 |     for path in rosbag_paths:
145 |         bagname = os.path.splitext(os.path.basename(path))[0]
146 |         out_path = os.path.join(output_dir, "{}.h5".format(bagname))
147 |         print("Extracting {} to {}".format(path, out_path))
148 |         extract_rosbag(path, out_path, event_topic, image_topic=image_topic,
149 |                 flow_topic=flow_topic, zero_timestamps=zero_timestamps, is_color=is_color)
150 | 
151 | 
152 | if __name__ == "__main__":
153 |     """
154 |     Tool for converting rosbag events to an efficient HDF5 format that can be speedily
155 |     accessed by python code.
156 |     """
157 |     parser = argparse.ArgumentParser()
158 |     parser.add_argument("path", help="ROS bag file to extract or directory containing bags")
159 |     parser.add_argument("--output_dir", default="/tmp/extracted_data", help="Folder to extract the data into")
160 |     parser.add_argument("--event_topic", default="/dvs/events", help="Event topic")
161 |     parser.add_argument("--image_topic", default=None, help="Image topic (if left empty, no images will be collected)")
162 |     parser.add_argument("--flow_topic", default=None, help="Flow topic (if left empty, no flow will be collected)")
163 |     parser.add_argument('--zero_timestamps', action='store_true', help='If true, timestamps will be offset to start at 0')
164 |     parser.add_argument('--is_color', action='store_true', help='Set flag to save frames from image_topic as 3-channel, bgr color images')
165 |     args = parser.parse_args()
166 | 
167 |     print('Data will be extracted in folder: {}'.format(args.output_dir))
168 |     if not os.path.exists(args.output_dir):
169 |         os.makedirs(args.output_dir)
170 |     if os.path.isdir(args.path):
171 |         rosbag_paths = sorted(glob.glob(os.path.join(args.path, "*.bag")))
172 |     else:
173 |         rosbag_paths = [args.path]
174 |     extract_rosbags(rosbag_paths, args.output_dir, args.event_topic, args.image_topic,
175 |             args.flow_topic, zero_timestamps=args.zero_timestamps, is_color=args.is_color)
--------------------------------------------------------------------------------
/lib/data_loaders/__init__.py:
--------------------------------------------------------------------------------
1 | # __init__.py
2 | from .base_dataset import *
3 | from .memmap_dataset import *
4 | from .hdf5_dataset import *
5 | from .npy_dataset import *
--------------------------------------------------------------------------------
/lib/data_loaders/data_augmentation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numbers
3 | import torchvision.transforms
4 | 
5 | 
6 | class Compose(object):
7 |     """
8 |     Composes several transforms together.
9 |     Example:
10 |         >>> Compose([
11 |         >>>     CenterCrop(10),
12 |         >>>     RobustNorm(),
13 |         >>> ])
14 |     """
15 | 
16 |     def __init__(self, transforms):
17 |         """
18 |         @param transforms (list of ``Transform`` objects): list of transforms to compose.
19 |         """
20 |         self.transforms = transforms
21 | 
22 |     def __call__(self, x, is_flow=False):
23 |         """
24 |         Call the transform.
25 |         @param x The tensor to transform
26 |         @param is_flow Set true if tensor represents optic flow
27 |         @returns Transformed tensor
28 |         """
29 |         for t in self.transforms:
30 |             x = t(x, is_flow)
31 |         return x
32 | 
33 |     def __repr__(self):
34 |         format_string = self.__class__.__name__ + '('
35 |         for t in self.transforms:
36 |             format_string += '\n'
37 |             format_string += '    {0}'.format(t)
38 |         format_string += '\n)'
39 |         return format_string
40 | 
41 | 
42 | class CenterCrop(object):
43 |     """
44 |     Center crop the tensor to a certain size.
45 |     """
46 | 
47 |     def __init__(self, size, preserve_mosaicing_pattern=False):
48 |         if isinstance(size, numbers.Number):
49 |             self.size = (int(size), int(size))
50 |         else:
51 |             self.size = size
52 | 
53 |         self.preserve_mosaicing_pattern = preserve_mosaicing_pattern
54 | 
55 |     def __call__(self, x, is_flow=False):
56 |         """
57 |         @param x [C x H x W] Tensor to be cropped.
58 |         @param is_flow this parameter does not have any effect
59 |         @returns Cropped tensor.
60 |         """
61 |         w, h = x.shape[2], x.shape[1]
62 |         th, tw = self.size
63 |         assert(th <= h)
64 |         assert(tw <= w)
65 |         i = int(round((h - th) / 2.))
66 |         j = int(round((w - tw) / 2.))
67 | 
68 |         if self.preserve_mosaicing_pattern:
69 |             # make sure that i and j are even, to preserve
70 |             # the mosaicing pattern
71 |             if i % 2 == 1:
72 |                 i = i + 1
73 |             if j % 2 == 1:
74 |                 j = j + 1
75 | 
76 |         return x[:, i:i + th, j:j + tw]
77 | 
78 |     def __repr__(self):
79 |         return self.__class__.__name__ + '(size={0})'.format(self.size)
80 | 
81 | 
82 | class RobustNorm(object):
83 | 
84 |     """
85 |     Robustly normalize tensor (i.e. normalise it between the top and
86 |     bottom percentiles of the tensor's value range)
87 |     """
88 | 
89 |     def __init__(self, low_perc=0, top_perc=95):
90 |         self.top_perc = top_perc
91 |         self.low_perc = low_perc
92 | 
93 |     @staticmethod
94 |     def percentile(t, q):
95 |         """
96 |         Return the ``q``-th percentile of the flattened input tensor's data.
97 |         CAUTION:
98 |             * Needs PyTorch >= 1.1.0, as ``torch.kthvalue()`` is used.
99 |             * Values are not interpolated, which corresponds to
100 |               ``numpy.percentile(..., interpolation="nearest")``.
101 |         @param t Input tensor.
102 |         @param q Percentile to compute, which must be between 0 and 100 inclusive.
103 |         @returns Resulting value (scalar).
104 |         """
105 |         # Note that ``kthvalue()`` works one-based, i.e. the first sorted value
106 |         # indeed corresponds to k=1, not k=0! Use float(q) instead of q directly,
107 |         # so that ``round()`` returns an integer, even if q is a np.float32.
108 |         k = 1 + round(.01 * float(q) * (t.numel() - 1))
109 |         try:
110 |             result = t.view(-1).kthvalue(k).values.item()
111 |         except RuntimeError:
112 |             result = t.reshape(-1).kthvalue(k).values.item()
113 |         return result
114 | 
115 |     def __call__(self, x, is_flow=False):
116 |         """
117 |         Call the transform.
118 |         @param x The tensor to normalise
119 |         @param is_flow Set true if the tensor represents optic flow
120 |         @returns Normalised tensor
121 |         """
122 |         t_max = self.percentile(x, self.top_perc)
123 |         t_min = self.percentile(x, self.low_perc)
124 |         # print("t_max={}, t_min={}".format(t_max, t_min))
125 |         if t_max == 0 and t_min == 0:
126 |             return x
127 |         eps = 1e-6
128 |         normed = torch.clamp(x, min=t_min, max=t_max)
129 |         normed = (normed-t_min)/(t_max-t_min+eps)  # min-max normalise between the percentile bounds; dividing by the max alone did not map onto [0, 1]
130 |         return normed
131 | 
132 |     def __repr__(self):
133 |         format_string = self.__class__.__name__
134 |         format_string += '(top_perc={:.2f}'.format(self.top_perc)
135 |         format_string += ', low_perc={:.2f})'.format(self.low_perc)
136 |         return format_string
--------------------------------------------------------------------------------
/lib/data_loaders/data_util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | from tqdm import tqdm
4 | from torch.utils.data import ConcatDataset
5 | 
6 | 
7 | data_sources = ('esim', 'ijrr', 'mvsec', 'eccd', 'hqfd', 'unknown')
8 | # Usage: name = data_sources[1], idx = data_sources.index('ijrr')
9 | 
10 | 
11 | def concatenate_subfolders(data_file, dataset, dataset_kwargs):
12 |     """
13 |     Create an instance of ConcatDataset by aggregating all the datasets in a given folder
14 |     """
15 |     if os.path.isdir(data_file):
16 |         subfolders = [os.path.join(data_file, s) for s in os.listdir(data_file)]
17 |     elif os.path.isfile(data_file):
18 |         subfolders = pd.read_csv(data_file, header=None).values.flatten().tolist()
19 |     else:
20 |         raise Exception('{} must be data_file.txt or base/folder'.format(data_file))
21 |     print('Found {} samples in {}'.format(len(subfolders), data_file))
22 |     datasets = []
23 |     for subfolder in subfolders:
24 |         dataset_kwargs['item_kwargs'].update({'base_folder': subfolder})
25 |         datasets.append(dataset(**dataset_kwargs))
26 |     return ConcatDataset(datasets)
27 | 
28 | 
29 | def concatenate_datasets(data_file, dataset_type, dataset_kwargs=None):
30 |     """
31 |     Generates a dataset for each cti_path specified in data_file and concatenates the datasets.
32 |     :param data_file: A file containing a list of paths to CTI h5 files.
33 |         Each file is expected to have a sequence of frame_{:09d}
34 |     :param dataset_type: Pointer to dataset class
35 |     :param dataset_kwargs: Dataset keyword arguments
36 |     :return ConcatDataset: concatenated dataset of all cti_paths in data_file
37 |     """
38 |     if dataset_kwargs is None:
39 |         dataset_kwargs = {}
40 | 
41 |     cti_paths = pd.read_csv(data_file, header=None).values.flatten().tolist()
42 |     dataset_list = []
43 |     print('Concatenating {} datasets'.format(dataset_type))
44 |     for cti_path in tqdm(cti_paths):
45 |         dataset_kwargs['dataset_kwargs'].update({'h5_path': cti_path})
46 |         dataset_list.append(dataset_type(**dataset_kwargs))
47 |     return ConcatDataset(dataset_list)
48 | 
49 | 
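# Illustrative sketch (not part of the original module): using
# concatenate_datasets with a hypothetical 'train_files.txt' that lists one
# HDF5 path per line. Note that dataset_kwargs must already contain a
# 'dataset_kwargs' entry, which is updated with each file's 'h5_path'; the
# dataset class is assumed to accept these keyword arguments.
def _concatenate_datasets_example():
    from .hdf5_dataset import DynamicH5Dataset
    dset = concatenate_datasets('train_files.txt', DynamicH5Dataset,
                                dataset_kwargs={'dataset_kwargs': {}})
    print('Concatenated dataset holds {} items'.format(len(dset)))
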
50 | def concatenate_memmap_datasets(data_file, dataset_type, dataset_kwargs):
51 |     """
52 |     Generates a dataset for each memmap_path specified in data_file and concatenates the datasets.
53 |     :param data_file: A file containing a list of paths to memmap root dirs.
54 |     :param dataset_type: Pointer to dataset class
55 |     :param dataset_kwargs: Dataset keyword arguments
56 |     :return ConcatDataset: concatenated dataset of all memmap_paths in data_file
57 |     """
58 |     if dataset_kwargs is None:
59 |         dataset_kwargs = {}
60 | 
61 |     memmap_paths = pd.read_csv(data_file, header=None).values.flatten().tolist()
62 |     dataset_list = []
63 |     print('Concatenating {} datasets'.format(dataset_type))
64 |     for memmap_path in tqdm(memmap_paths):
65 |         dataset_kwargs['dataset_kwargs'].update({'root': memmap_path})
66 |         dataset_list.append(dataset_type(**dataset_kwargs))
67 |     return ConcatDataset(dataset_list)
--------------------------------------------------------------------------------
/lib/data_loaders/dataloader_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | def unpack_batched_events(events, batch_indices):
4 |     """
5 |     When returning events from a pytorch dataloader, it is often convenient when
6 |     batching, to place them into a contiguous 1x1xNx4 array, where N=length of all
7 |     B event arrays in the batch. This function unpacks the events into a Bx1xMx4 array,
8 |     where B is the batch size, M is the length of the *longest* event array in the
9 |     batch. The shorter event arrays are then padded with zeros.
10 |     Parameters
11 |     ----------
12 |     events : 1x1xNx4 array of the events
13 |     batch_indices : A list of the end indices of events, where one event array ends and
14 |         the next begins. For example, if you batched two event arrays A and B of length
15 |         200 and 700 respectively, batch_indices=[200, 900]
16 |     Returns
17 |     -------
18 |     unpacked_events: Bx1xMx4 array of unpacked events
19 |     """
20 |     maxlen = 0
21 |     start_idx = 0
22 |     for b_idx in range(len(batch_indices)):
23 |         end_idx = batch_indices[b_idx]
24 |         maxlen = end_idx-start_idx if end_idx-start_idx > maxlen else maxlen
25 |         start_idx = end_idx
26 | 
27 |     unpacked_events = torch.zeros((len(batch_indices), 1, maxlen, 4))
28 |     start_idx = 0
29 |     for b_idx in range(len(batch_indices)):
30 |         end_idx = batch_indices[b_idx]
31 |         num_events = end_idx-start_idx
32 |         unpacked_events[b_idx, 0, 0:num_events, :] = events[0, 0, start_idx:end_idx, :]
33 |         start_idx = end_idx
34 |     return unpacked_events
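# Illustrative sketch (not part of the original module): unpacking two batched
# event arrays of lengths 200 and 700, matching the docstring example above.
def _unpack_batched_events_example():
    events = torch.rand((1, 1, 900, 4))    # two event arrays packed end to end
    unpacked = unpack_batched_events(events, [200, 900])
    print(unpacked.shape)                  # torch.Size([2, 1, 700, 4])
    # The first array occupies rows 0:200 of batch 0; rows 200:700 there are zero padding.
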
--------------------------------------------------------------------------------
/lib/data_loaders/hdf5_dataset.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | from ..util.event_util import binary_search_h5_dset
3 | from .base_dataset import BaseVoxelDataset
4 | import matplotlib.pyplot as plt
5 | 
6 | class DynamicH5Dataset(BaseVoxelDataset):
7 |     """
8 |     Dataloader for events saved in the Monash University HDF5 events format
9 |     (see https://github.com/TimoStoff/event_utils for code to convert datasets)
10 |     """
11 | 
12 |     def get_frame(self, index):
13 |         return self.h5_file['images']['image{:09d}'.format(index)][:]
14 | 
15 |     def get_flow(self, index):
16 |         return self.h5_file['flow']['flow{:09d}'.format(index)][:]
17 | 
18 |     def get_events(self, idx0, idx1):
19 |         xs = self.h5_file['events/xs'][idx0:idx1]
20 |         ys = self.h5_file['events/ys'][idx0:idx1]
21 |         ts = self.h5_file['events/ts'][idx0:idx1]
22 |         ps = self.h5_file['events/ps'][idx0:idx1] * 2.0 - 1.0
23 |         return xs, ys, ts, ps
24 | 
25 |     def load_data(self, data_path):
26 |         self.data_sources = ('esim', 'ijrr', 'mvsec', 'eccd', 'hqfd', 'unknown')
27 |         try:
28 |             self.h5_file = h5py.File(data_path, 'r')
29 |         except OSError as err:
30 |             print("Couldn't open {}: {}".format(data_path, err))
31 |             raise  # continuing without a file handle would only fail later with a confusing error
32 |         if self.sensor_resolution is None:
33 |             self.sensor_resolution = self.h5_file.attrs['sensor_resolution'][0:2]
34 |         else:
35 |             self.sensor_resolution = self.sensor_resolution[0:2]
36 |         print("sensor resolution = {}".format(self.sensor_resolution))
37 |         self.has_flow = 'flow' in self.h5_file.keys() and len(self.h5_file['flow']) > 0
38 |         self.t0 = self.h5_file['events/ts'][0]
39 |         self.tk = self.h5_file['events/ts'][-1]
40 |         self.num_events = self.h5_file.attrs["num_events"]
41 |         self.num_frames = self.h5_file.attrs["num_imgs"]
42 | 
43 |         self.frame_ts = []
44 |         for img_name in self.h5_file['images']:
45 |             self.frame_ts.append(self.h5_file['images/{}'.format(img_name)].attrs['timestamp'])
46 | 
47 |         data_source = self.h5_file.attrs.get('source', 'unknown')
48 |         try:
49 |             self.data_source_idx = self.data_sources.index(data_source)
50 |         except ValueError:
51 |             self.data_source_idx = -1
52 | 
53 |     def find_ts_index(self, timestamp):
54 |         idx = binary_search_h5_dset(self.h5_file['events/ts'], timestamp)
55 |         return idx
56 | 
57 |     def ts(self, index):
58 |         return self.h5_file['events/ts'][index]
59 | 
60 |     def compute_frame_indices(self):
61 |         frame_indices = []
62 |         start_idx = 0
63 |         for img_name in self.h5_file['images']:
64 |             end_idx = self.h5_file['images/{}'.format(img_name)].attrs['event_idx']
65 |             frame_indices.append([start_idx, end_idx])
66 |             start_idx = end_idx
67 |         return frame_indices
68 | 
69 | if __name__ == "__main__":
70 |     """
71 |     Quick test of the HDF5 dataset loader.
72 |     """
73 |     import argparse
74 |     parser = argparse.ArgumentParser()
75 |     parser.add_argument("path", help="Path to event file")
76 |     args = parser.parse_args()
77 | 
78 |     dloader = DynamicH5Dataset(args.path)
79 |     for item in dloader:
80 |         print(item['events'].shape)
--------------------------------------------------------------------------------
/lib/data_loaders/memmap_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from .base_dataset import BaseVoxelDataset
4 | 
5 | class MemMapDataset(BaseVoxelDataset):
6 |     """
7 |     Dataloader for events saved in the MemMap events format used at RPG.
8 |     (see https://github.com/TimoStoff/event_utils for code to convert datasets)
9 |     """
10 | 
11 |     def get_frame(self, index):
12 |         frame = self.filehandle['images'][index][:, :, 0]
13 |         return frame
14 | 
15 |     def get_flow(self, index):
16 |         flow = self.filehandle['optic_flow'][index]
17 |         return flow
18 | 
19 |     def get_events(self, idx0, idx1):
20 |         xy = self.filehandle["xy"][idx0:idx1]
21 |         xs = xy[:, 0].astype(np.float32)
22 |         ys = xy[:, 1].astype(np.float32)
23 |         ts = self.filehandle["t"][idx0:idx1]
24 |         ps = self.filehandle["p"][idx0:idx1] * 2.0 - 1.0
25 |         return xs, ys, ts, ps
26 | 
27 |     def load_data(self, data_path, timestamp_fname="timestamps.npy", image_fname="images.npy",
28 |             optic_flow_fname="optic_flow.npy", optic_flow_stamps_fname="optic_flow_stamps.npy",
29 |             t_fname="t.npy", xy_fname="xy.npy", p_fname="p.npy"):
30 | 
31 |         assert os.path.isdir(data_path), '%s is not a valid data_path' % data_path
32 | 
33 |         data = {}
34 |         self.has_flow = False
35 |         for subroot, _, fnames in sorted(os.walk(data_path)):
36 |             for fname in sorted(fnames):
37 |                 path = os.path.join(subroot, fname)
38 |                 if fname.endswith(".npy"):
39 |                     if fname.endswith(timestamp_fname):
40 |                         frame_stamps = np.load(path)
41 |                         data["frame_stamps"] = frame_stamps
42 |                     elif fname.endswith(image_fname):
43 |                         data["images"] = np.load(path, mmap_mode="r")
44 |                     elif fname.endswith(optic_flow_fname):
45 |                         data["optic_flow"] = np.load(path, mmap_mode="r")
46 |                         self.has_flow = True
47 |                     elif fname.endswith(optic_flow_stamps_fname):
48 |                         optic_flow_stamps = np.load(path)
49 |                         data["optic_flow_stamps"] = optic_flow_stamps
50 | 
51 |                     try:
52 |                         handle = np.load(path, mmap_mode="r")
53 |                     except Exception as err:
54 |                         print("Couldn't load {}:".format(path))
55 |                         raise err
56 |                     if fname.endswith(t_fname):  # timestamps
57 |                         data["t"] = handle.squeeze()
58 |                     elif fname.endswith(xy_fname):  # coordinates
59 |                         data["xy"] = handle.squeeze()
60 |                     elif fname.endswith(p_fname):  # polarity
61 |                         data["p"] = handle.squeeze()
62 |             if len(data) > 0:
63 |                 data['path'] = subroot
64 |                 if "t" not in data:
65 |                     print("Ignoring root {} since no events".format(subroot))
66 |                     continue
67 |                 assert (len(data['p']) == len(data['xy']) and len(data['p']) == len(data['t']))
68 | 
69 |                 self.t0, self.tk = data['t'][0], data['t'][-1]
70 |                 self.num_events = len(data['p'])
71 |                 self.num_frames = len(data['images'])
72 | 
73 |                 self.frame_ts = []
74 |                 for ts in data["frame_stamps"]:
75 |                     self.frame_ts.append(ts)
76 |                 data["index"] = self.frame_ts
77 | 
78 |         self.filehandle = data
79 |         self.find_config(data_path)
80 | 
81 |     def find_ts_index(self, timestamp):
82 |         index = np.searchsorted(self.filehandle["t"], timestamp)
83 |         return index
84 | 
85 |     def ts(self, index):
86 |         return self.filehandle["t"][index]
87 | 
88 |     def infer_resolution(self):
89 |         if len(self.filehandle["images"]) > 0:
90 |             sr = self.filehandle["images"][0].shape[0:2]
91 |         else:
92 |             sr = [np.max(self.filehandle["xy"][:, 1]) + 1, np.max(self.filehandle["xy"][:, 0]) + 1]
93 |         print("Inferred sensor resolution: {}".format(sr))
94 |         return sr
95 | 
96 |     def find_config(self, data_path):
97 |         if self.sensor_resolution is None:
98 |             config = os.path.join(data_path, "dataset_config.json")
99 |             if os.path.exists(config):
100 |                 from ..util.util import read_json  # assumed location of this helper; it was previously used without being imported
101 |                 self.config = read_json(config)
102 |                 self.data_source = self.config['data_source']
103 |                 self.sensor_resolution = self.config["sensor_resolution"]
104 |             else:
105 |                 self.data_source = 'unknown'
106 |                 self.sensor_resolution = self.infer_resolution()
107 | 
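# Illustrative sketch (not part of the original module): iterating over an
# RPG-style memmap directory. The path is hypothetical, and the constructor
# is assumed to take the data root as its only positional argument, as in
# the __main__ blocks of the sibling loaders.
def _memmap_dataset_example():
    dset = MemMapDataset('/path/to/memmap_root')
    for item in dset:
        print(item['events'].shape)
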
--------------------------------------------------------------------------------
/lib/data_loaders/npy_dataset.py:
--------------------------------------------------------------------------------
1 | from .base_dataset import BaseVoxelDataset
2 | import numpy as np
3 | 
4 | class NpyDataset(BaseVoxelDataset):
5 |     """
6 |     Dataloader for events saved as an Nx4 numpy array with columns
7 |     (x, y, p, t), where t is given in microseconds
8 |     """
9 | 
10 |     def get_frame(self, index):
11 |         return None
12 | 
13 |     def get_flow(self, index):
14 |         return None
15 | 
16 |     def get_events(self, idx0, idx1):
17 |         xs = self.xs[idx0:idx1]
18 |         ys = self.ys[idx0:idx1]
19 |         ts = self.ts[idx0:idx1]
20 |         ps = self.ps[idx0:idx1]
21 |         return xs, ys, ts, ps
22 | 
23 |     def load_data(self, data_path):
24 |         try:
25 |             self.data = np.load(data_path)
26 |             self.xs, self.ys, self.ps, self.ts = self.data[:, 0], self.data[:, 1], self.data[:, 2]*2-1, self.data[:, 3]*1e-6
27 |         except OSError as err:
28 |             print("Couldn't open {}: {}".format(data_path, err))
29 |             raise
30 | 
31 |         if self.sensor_resolution is None:
32 |             self.sensor_resolution = [np.max(self.ys)+1, np.max(self.xs)+1]  # (H, W), consistent with the other loaders
33 |             print("Inferred resolution as {}".format(self.sensor_resolution))
34 |         else:
35 |             self.sensor_resolution = self.sensor_resolution[0:2]
36 |             print("sensor resolution = {}".format(self.sensor_resolution))
37 |         self.has_flow = False
38 |         self.has_frames = False
39 |         self.t0 = self.ts[0]
40 |         self.tk = self.ts[-1]
41 |         self.num_events = len(self.xs)
42 |         self.num_frames = 0
43 |         self.frame_ts = []
44 | 
45 |     def find_ts_index(self, timestamp):
46 |         idx = np.searchsorted(self.ts, timestamp)
47 |         return idx
48 | 
49 |     def ts(self, index):
50 |         return self.ts[index]
51 | 
52 |     def compute_frame_indices(self):
53 |         return None
54 | 
55 | if __name__ == "__main__":
56 |     """
57 |     Quick test of the npy dataset loader.
58 |     """
59 |     import argparse
60 |     parser = argparse.ArgumentParser()
61 |     parser.add_argument("path", help="Path to event file")
62 |     args = parser.parse_args()
63 | 
64 |     dloader = NpyDataset(args.path)
65 |     for item in dloader:
66 |         print(item['events'].shape)
--------------------------------------------------------------------------------
/lib/representations/image.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.stats import rankdata
3 | import torch
4 | 
5 | def events_to_image(xs, ys, ps, sensor_size=(180, 240), interpolation=None, padding=False, meanval=False, default=0):
6 |     """
7 |     Place events into an image using numpy
8 |     @param xs x coords of events
9 |     @param ys y coords of events
10 |     @param ps Event polarities/weights
11 |     @param sensor_size The size of the event camera sensor
12 |     @param interpolation Whether to add the events to the pixels by interpolation (values: None, 'bilinear')
13 |     @param padding If true, pad the output image to include events otherwise warped off sensor
14 |     @param meanval If true, divide the sum of the values by the number of events at that location
15 |     @returns Event image from the input events
16 |     """
17 |     img_size = (sensor_size[0]+1, sensor_size[1]+1)
18 |     if interpolation == 'bilinear' and xs.dtype is not torch.long and ys.dtype is not torch.long:
19 |         xt, yt, pt = torch.from_numpy(xs), torch.from_numpy(ys), torch.from_numpy(ps)
20 |         xt, yt, pt = xt.float(), yt.float(), pt.float()
21 |         img = events_to_image_torch(xt, yt, pt, clip_out_of_range=True, interpolation='bilinear', padding=padding)
22 |         img[img==0] = default
23 |         img = img.numpy()
24 |         if meanval:
25 |             event_count_image = events_to_image_torch(xt, yt, torch.ones_like(xt),
26 |                     clip_out_of_range=True, padding=padding)
27 |             event_count_image = event_count_image.numpy()
28 |     else:
29 |         coords = np.stack((ys, xs))
30 |         try:
31 |             abs_coords = np.ravel_multi_index(coords, img_size)
32 |         except ValueError:
33 |             print("Issue with input arrays! minx={}, maxx={}, miny={}, maxy={}, coords.shape={}, \
34 |                     sum(coords)={}, sensor_size={}".format(np.min(xs), np.max(xs), np.min(ys), np.max(ys),
35 |                     coords.shape, np.sum(coords), img_size))
36 |             raise ValueError
37 |         img = np.bincount(abs_coords, weights=ps, minlength=img_size[0]*img_size[1])
38 |         img = img.reshape(img_size)
39 |         if meanval:
40 |             event_count_image = np.bincount(abs_coords, weights=np.ones_like(xs), minlength=img_size[0]*img_size[1])
41 |             event_count_image = event_count_image.reshape(img_size)
42 |     if meanval:
43 |         img = np.divide(img, event_count_image, out=np.ones_like(img)*default, where=event_count_image!=0)
44 |     return img[0:sensor_size[0], 0:sensor_size[1]]
45 | 
46 | def events_to_image_torch(xs, ys, ps,
47 |         device=None, sensor_size=(180, 240), clip_out_of_range=True,
48 |         interpolation=None, padding=True, default=0):
49 |     """
50 |     Method to turn event tensor to image. Allows for bilinear interpolation.
51 |     @param xs Tensor of x coords of events
52 |     @param ys Tensor of y coords of events
53 |     @param ps Tensor of event polarities/weights
54 |     @param device The device on which the image is. If none, set to events device
55 |     @param sensor_size The size of the image sensor/output image
56 |     @param clip_out_of_range If the events go beyond the desired image size,
57 |         clip the events to fit into the image
58 |     @param interpolation Which interpolation to use. Options=None, 'bilinear'
59 |     @param padding If bilinear interpolation, allow padding the image by 1 to allow events to fit
60 |     @returns Event image from the events
61 |     """
62 |     if device is None:
63 |         device = xs.device
64 |     if interpolation == 'bilinear' and padding:
65 |         img_size = (sensor_size[0]+1, sensor_size[1]+1)
66 |     else:
67 |         img_size = list(sensor_size)
68 | 
69 |     mask = torch.ones(xs.size(), device=device)
70 |     if clip_out_of_range:
71 |         zero_v = torch.tensor([0.], device=device)
72 |         ones_v = torch.tensor([1.], device=device)
73 |         clipx = img_size[1] if interpolation is None and padding==False else img_size[1]-1
74 |         clipy = img_size[0] if interpolation is None and padding==False else img_size[0]-1
75 |         mask = torch.where(xs>=clipx, zero_v, ones_v)*torch.where(ys>=clipy, zero_v, ones_v)
76 | 
77 |     img = (torch.ones(img_size)*default).to(device)
78 |     if interpolation == 'bilinear' and xs.dtype is not torch.long and ys.dtype is not torch.long:
79 |         pxs = (xs.floor()).float()
80 |         pys = (ys.floor()).float()
81 |         dxs = (xs-pxs).float()
82 |         dys = (ys-pys).float()
83 |         pxs = (pxs*mask).long()
84 |         pys = (pys*mask).long()
85 |         masked_ps = ps.squeeze()*mask
86 |         interpolate_to_image(pxs, pys, dxs, dys, masked_ps, img)
87 |     else:
88 |         if xs.dtype is not torch.long:
89 |             xs = xs.long().to(device)
90 |         if ys.dtype is not torch.long:
91 |             ys = ys.long().to(device)
92 |         try:
93 |             mask = mask.long().to(device)
94 |             xs, ys = xs*mask, ys*mask
95 |             img.index_put_((ys, xs), ps, accumulate=True)
96 |         except Exception as e:
97 |             print("Unable to put tensor {} positions ({}, {}) into {}. Range = {},{}".format(
98 |                 ps.shape, ys.shape, xs.shape, img.shape, torch.max(ys), torch.max(xs)))
99 |             raise e
100 |     return img
101 | 
102 | def interpolate_to_image(pxs, pys, dxs, dys, weights, img):
103 |     """
104 |     Accumulate x and y coords to an image using bilinear interpolation
105 |     @param pxs Numpy array of integer typecast x coords of events
106 |     @param pys Numpy array of integer typecast y coords of events
107 |     @param dxs Numpy array of residual difference between x coord and int(x coord)
108 |     @param dys Numpy array of residual difference between y coord and int(y coord)
109 |     @returns Image, with the event `weights` accumulated into `img` in place
110 |     """
111 |     img.index_put_((pys,   pxs  ), weights*(1.0-dxs)*(1.0-dys), accumulate=True)
112 |     img.index_put_((pys,   pxs+1), weights*dxs*(1.0-dys), accumulate=True)
113 |     img.index_put_((pys+1, pxs  ), weights*(1.0-dxs)*dys, accumulate=True)
114 |     img.index_put_((pys+1, pxs+1), weights*dxs*dys, accumulate=True)
115 |     return img
116 | 
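# Illustrative sketch (not part of the original module): a single event at
# (x, y) = (1.25, 2.5) splits its weight over the four surrounding pixels
# in proportion to its distance from each.
def _interpolate_to_image_example():
    xs, ys = torch.tensor([1.25]), torch.tensor([2.5])
    pxs, pys = xs.floor().long(), ys.floor().long()
    dxs, dys = xs - pxs, ys - pys
    img = torch.zeros((5, 5))
    interpolate_to_image(pxs, pys, dxs, dys, torch.tensor([1.0]), img)
    print(img[2, 1], img[2, 2], img[3, 1], img[3, 2])  # 0.3750, 0.1250, 0.3750, 0.1250
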
117 | def interpolate_to_derivative_img(pxs, pys, dxs, dys, d_img, w1, w2):
118 |     """
119 |     Accumulate x and y coords to an image using double weighted bilinear interpolation.
120 |     This allows for computing gradient images, since in the gradient image the interpolation
121 |     is weighted by the values of the Jacobian.
122 |     @param pxs Numpy array of integer typecast x coords of events
123 |     @param pys Numpy array of integer typecast y coords of events
124 |     @param dxs Numpy array of residual difference between x coord and int(x coord)
125 |     @param dys Numpy array of residual difference between y coord and int(y coord)
126 |     @param d_img Derivative image (needs to be of appropriate dimensions)
127 |     @param w1 Weight for x component of bilinear interpolation
128 |     @param w2 Weight for y component of bilinear interpolation
129 |     @returns Image
130 |     """
131 |     for i in range(d_img.shape[0]):
132 |         d_img[i].index_put_((pys,   pxs  ), w1[i] * (-(1.0-dys)) + w2[i] * (-(1.0-dxs)), accumulate=True)
133 |         d_img[i].index_put_((pys,   pxs+1), w1[i] * (1.0-dys) + w2[i] * (-dxs), accumulate=True)
134 |         d_img[i].index_put_((pys+1, pxs  ), w1[i] * (-dys) + w2[i] * (1.0-dxs), accumulate=True)
135 |         d_img[i].index_put_((pys+1, pxs+1), w1[i] * dys + w2[i] * dxs, accumulate=True)
136 |     return d_img
137 | 
138 | def image_to_event_weights(xs, ys, img):
139 |     """
140 |     Given an image and a set of event coordinates, get the pixel value
141 |     of the image for each event using reverse bilinear interpolation
142 |     @param xs x coords of events
143 |     @param ys y coords of events
144 |     @param img The image from which to draw the weights
145 |     @return List containing the value in the image for each event
146 |     """
147 |     clipx, clipy = img.shape[1]-1, img.shape[0]-1
148 |     mask = np.where(xs>=clipx, 0, 1)*np.where(ys>=clipy, 0, 1)
149 | 
150 |     pxs = np.floor(xs*mask).astype(int)
151 |     pys = np.floor(ys*mask).astype(int)
152 |     dxs = xs-pxs
153 |     dys = ys-pys
154 |     wxs, wys = 1.0-dxs, 1.0-dys
155 | 
156 |     weights  = img[pys,   pxs  ] *wxs*wys
157 |     weights += img[pys,   pxs+1] *dxs*wys
158 |     weights += img[pys+1, pxs  ] *wxs*dys
159 |     weights += img[pys+1, pxs+1] *dxs*dys
160 |     return weights*mask
161 | 
162 | def events_to_image_drv(xn, yn, pn, jacobian_xn, jacobian_yn,
163 |         device=None, sensor_size=(180, 240), clip_out_of_range=True,
164 |         interpolation='bilinear', padding=True, compute_gradient=False):
165 |     """
166 |     Method to turn event tensor to image and derivative image (given event Jacobians).
167 |     Allows for bilinear interpolation.
168 |     @param xn Numpy array of x coords of events
169 |     @param yn Numpy array of y coords of events
170 |     @param pn Numpy array of event polarities/weights
171 |     @param device The device on which the image is. If none, set to events device
172 |     @param sensor_size The size of the image sensor/output image
173 |     @param clip_out_of_range If the events go beyond the desired image size,
174 |         clip the events to fit into the image
175 |     @param interpolation Which interpolation to use. Options=None, 'bilinear'
176 |     @param padding If bilinear interpolation, allow padding the image by 1 to allow events to fit
177 |     @param compute_gradient If True, compute and return the image gradient as well (else None is returned in its place)
178 |     """
179 |     xt, yt, pt = torch.from_numpy(xn), torch.from_numpy(yn), torch.from_numpy(pn)
180 |     xs, ys, ps = xt.float(), yt.float(), pt.float()
181 |     if compute_gradient:
182 |         jacobian_x, jacobian_y = torch.from_numpy(jacobian_xn), torch.from_numpy(jacobian_yn)
183 |         jacobian_x, jacobian_y = jacobian_x.float(), jacobian_y.float()
184 |     if device is None:
185 |         device = xs.device
186 |     if padding:
187 |         img_size = (sensor_size[0]+1, sensor_size[1]+1)
188 |     else:
189 |         img_size = sensor_size
190 | 
191 |     mask = torch.ones(xs.size())
192 |     if clip_out_of_range:
193 |         zero_v = torch.tensor([0.])
194 |         ones_v = torch.tensor([1.])
195 |         clipx = img_size[1] if interpolation is None and padding==False else img_size[1]-1
196 |         clipy = img_size[0] if interpolation is None and padding==False else img_size[0]-1
197 |         mask = torch.where(xs>=clipx, zero_v, ones_v)*torch.where(ys>=clipy, zero_v, ones_v)
198 | 
199 |     pxs = xs.floor()
200 |     pys = ys.floor()
201 |     dxs = xs-pxs
202 |     dys = ys-pys
203 |     pxs = (pxs*mask).long()
204 |     pys = (pys*mask).long()
205 |     masked_ps = ps*mask
206 |     img = torch.zeros(img_size).to(device)
207 |     interpolate_to_image(pxs, pys, dxs, dys, masked_ps, img)
208 | 
209 |     if compute_gradient:
210 |         d_img = torch.zeros((2, *img_size)).to(device)
211 |         w1 = jacobian_x*masked_ps
212 |         w2 = jacobian_y*masked_ps
213 |         interpolate_to_derivative_img(pxs, pys, dxs, dys, d_img, w1, w2)
214 |         d_img = d_img.numpy()
215 |     else:
216 |         d_img = None
217 |     return img.numpy(), d_img
218 | 
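# Illustrative sketch (not part of the original module): the average timestamp
# images produced by the function below. The events are invented; the two
# positive events land on pixel (2, 1) and the negative event on pixel (0, 3).
def _timestamp_image_example():
    xn = np.array([1.0, 1.0, 3.0])
    yn = np.array([2.0, 2.0, 0.0])
    tn = np.array([0.0, 0.5, 1.0])
    pn = np.array([1.0, 1.0, -1.0])
    ti_pos, ti_neg = events_to_timestamp_image(xn, yn, tn, pn, sensor_size=(5, 5))
    print(ti_pos[2, 1], ti_neg[0, 3])  # ~0.25 (mean of 0.0 and 0.5) and ~1.0
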
219 | def events_to_timestamp_image(xn, yn, ts, pn,
220 |         device=None, sensor_size=(180, 240), clip_out_of_range=True,
221 |         interpolation='bilinear', padding=True, normalize_timestamps=True):
222 |     """
223 |     Method to generate the average timestamp images from 'Zhu19, Unsupervised Event-based Learning
224 |     of Optical Flow, Depth, and Egomotion'. This method does not have a known derivative.
225 |     @param xn Numpy array of event x coordinates
226 |     @param yn Numpy array of event y coordinates
227 |     @param ts Numpy array of event timestamps
228 |     @param pn Numpy array of event polarities
229 |     @param device The device that the events are on
230 |     @param sensor_size The size of the event sensor/output voxels
231 |     @param clip_out_of_range If the events go beyond the desired image size,
232 |         clip the events to fit into the image
233 |     @param interpolation Which interpolation to use. Options=None, 'bilinear'
234 |     @param padding If bilinear interpolation, allow padding the image by 1 to allow events to fit
235 |     @returns Timestamp images of the positive and negative events: ti_pos, ti_neg
236 |     """
237 | 
238 |     t0 = ts[0]
239 |     xt, yt, ts, pt = torch.from_numpy(xn), torch.from_numpy(yn), torch.from_numpy(ts-t0), torch.from_numpy(pn)
240 |     xs, ys, ts, ps = xt.float(), yt.float(), ts.float(), pt.float()
241 |     zero_v = torch.tensor([0.])
242 |     ones_v = torch.tensor([1.])
243 |     if device is None:
244 |         device = xs.device
245 |     if padding:
246 |         img_size = (sensor_size[0]+1, sensor_size[1]+1)
247 |     else:
248 |         img_size = sensor_size
249 | 
250 |     mask = torch.ones(xs.size())
251 |     if clip_out_of_range:
252 |         clipx = img_size[1] if interpolation is None and padding==False else img_size[1]-1
253 |         clipy = img_size[0] if interpolation is None and padding==False else img_size[0]-1
254 |         mask = torch.where(xs>=clipx, zero_v, ones_v)*torch.where(ys>=clipy, zero_v, ones_v)
255 | 
256 |     pos_events_mask = torch.where(ps>0, ones_v, zero_v)
257 |     neg_events_mask = torch.where(ps<=0, ones_v, zero_v)
258 |     normalized_ts = (ts-ts[0])/(ts[-1]+1e-6) if normalize_timestamps else ts
259 |     pxs = xs.floor()
260 |     pys = ys.floor()
261 |     dxs = xs-pxs
262 |     dys = ys-pys
263 |     pxs = (pxs*mask).long()
264 |     pys = (pys*mask).long()
265 |     masked_ps = ps*mask
266 | 
267 |     pos_weights = normalized_ts*pos_events_mask
268 |     neg_weights = normalized_ts*neg_events_mask
269 |     img_pos = torch.zeros(img_size).to(device)
270 |     img_pos_cnt = torch.zeros(img_size).to(device)  # zeros, not ones: a ones-initialised count biases every average by one and makes the ==0 guard below dead code
271 |     img_neg = torch.zeros(img_size).to(device)
272 |     img_neg_cnt = torch.zeros(img_size).to(device)
273 | 
274 |     interpolate_to_image(pxs, pys, dxs, dys, pos_weights, img_pos)
275 |     interpolate_to_image(pxs, pys, dxs, dys, pos_events_mask, img_pos_cnt)
276 |     interpolate_to_image(pxs, pys, dxs, dys, neg_weights, img_neg)
277 |     interpolate_to_image(pxs, pys, dxs, dys, neg_events_mask, img_neg_cnt)
278 | 
279 |     img_pos, img_pos_cnt = img_pos.numpy(), img_pos_cnt.numpy()
280 |     img_pos_cnt[img_pos_cnt==0] = 1
281 |     img_neg, img_neg_cnt = img_neg.numpy(), img_neg_cnt.numpy()
282 |     img_neg_cnt[img_neg_cnt==0] = 1
283 |     img_pos, img_neg = img_pos/img_pos_cnt, img_neg/img_neg_cnt
284 |     return img_pos, img_neg
285 | 
286 | def events_to_timestamp_image_torch(xs, ys, ts, ps,
287 |         device=None, sensor_size=(180, 240), clip_out_of_range=True,
288 |         interpolation='bilinear', padding=True, timestamp_reverse=False):
289 |     """
290 |     Method to generate the average timestamp images from 'Zhu19, Unsupervised Event-based Learning
291 |     of Optical Flow, Depth, and Egomotion'. This method does not have a known derivative.
292 |     @param xs List of event x coordinates
293 |     @param ys List of event y coordinates
294 |     @param ts List of event timestamps
295 |     @param ps List of event polarities
296 |     @param device The device that the events are on
297 |     @param sensor_size The size of the event sensor/output voxels
298 |     @param clip_out_of_range If the events go beyond the desired image size,
299 |         clip the events to fit into the image
300 |     @param interpolation Which interpolation to use. Options=None, 'bilinear'
301 |     @param padding If bilinear interpolation, allow padding the image by 1 to allow events to fit
302 |     @param timestamp_reverse Reverse the timestamps of the events, for backward warping
303 |     @returns Timestamp images of the positive and negative events: ti_pos, ti_neg
304 |     """
305 |     if device is None:
306 |         device = xs.device
307 |     xs, ys, ps, ts = xs.squeeze(), ys.squeeze(), ps.squeeze(), ts.squeeze()
308 |     if padding:
309 |         img_size = (sensor_size[0]+1, sensor_size[1]+1)
310 |     else:
311 |         img_size = sensor_size
312 |     zero_v = torch.tensor([0.], device=device)
313 |     ones_v = torch.tensor([1.], device=device)
314 | 
315 |     mask = torch.ones(xs.size(), device=device)
316 |     if clip_out_of_range:
317 |         clipx = img_size[1] if interpolation is None and padding==False else img_size[1]-1
318 |         clipy = img_size[0] if interpolation is None and padding==False else img_size[0]-1
319 |         mask = torch.where(xs>=clipx, zero_v, ones_v)*torch.where(ys>=clipy, zero_v, ones_v)
320 | 
321 |     pos_events_mask = torch.where(ps>0, ones_v, zero_v)
322 |     neg_events_mask = torch.where(ps<=0, ones_v, zero_v)
323 |     epsilon = 1e-6
324 |     if timestamp_reverse:
325 |         normalized_ts = ((-ts+ts[-1])/(ts[-1]-ts[0]+epsilon)).squeeze()
326 |     else:
327 |         normalized_ts = ((ts-ts[0])/(ts[-1]-ts[0]+epsilon)).squeeze()
328 |     pxs = xs.floor().float()
329 |     pys = ys.floor().float()
330 |     dxs = (xs-pxs).float()
331 |     dys = (ys-pys).float()
332 |     pxs = (pxs*mask).long()
333 |     pys = (pys*mask).long()
334 |     masked_ps = ps*mask
335 | 
336 |     pos_weights = (normalized_ts*pos_events_mask).float()
337 |     neg_weights = (normalized_ts*neg_events_mask).float()
338 |     img_pos = torch.zeros(img_size).to(device)
339 |     img_pos_cnt = torch.zeros(img_size).to(device)  # zeros, not ones: see the numpy version above
340 |     img_neg = torch.zeros(img_size).to(device)
341 |     img_neg_cnt = torch.zeros(img_size).to(device)
342 | 
343 |     interpolate_to_image(pxs, pys, dxs, dys, pos_weights, img_pos)
344 |     interpolate_to_image(pxs, pys, dxs, dys, pos_events_mask, img_pos_cnt)
345 |     interpolate_to_image(pxs, pys, dxs, dys, neg_weights, img_neg)
346 |     interpolate_to_image(pxs, pys, dxs, dys, neg_events_mask, img_neg_cnt)
347 | 
348 |     # Avoid division by 0
349 |     img_pos_cnt[img_pos_cnt==0] = 1
350 |     img_neg_cnt[img_neg_cnt==0] = 1
351 |     img_pos = img_pos.div(img_pos_cnt)
352 |     img_neg = img_neg.div(img_neg_cnt)
353 |     return img_pos, img_neg
354 | 
355 | class TimestampImage:
356 | 
357 |     def __init__(self, sensor_size):
358 |         self.sensor_size = sensor_size
359 |         self.num_pixels = sensor_size[0]*sensor_size[1]
360 |         self.image = np.ones(sensor_size)
361 | 
362 |     def set_init(self, value):
363 |         self.image = np.ones_like(self.image)*value
364 | 
365 |     def add_event(self, x, y, t, p):
366 |         self.image[int(y), int(x)] = t
367 | 
368 |     def add_events(self, xs, ys, ts, ps):
369 |         for x, y, t in zip(xs, ys, ts):
370 |             self.add_event(x, y, t, 0)
371 | 
372 |     def get_image(self):
373 |         sort_args = rankdata(self.image, method='dense')
374 |         sort_args = sort_args-1
375 |         sort_args = sort_args.reshape(self.sensor_size)
376 |         sort_args = sort_args/np.max(sort_args)
377 |         return sort_args
378 | 
379 | class EventImage:
380 | 
381 |     def __init__(self, sensor_size):
382 |         self.sensor_size = sensor_size
383 |         self.num_pixels = sensor_size[0]*sensor_size[1]
384 |         self.image = np.ones(sensor_size)
385 | 
386 |     def add_event(self, x, y, t, p):
387 |         self.image[int(y), int(x)] += p
388 | 
389 |     def add_events(self, xs, ys, ts, ps):
390 |         for x, y, t, p in zip(xs, ys, ts, ps):  # was zip(xs, ys, ts) with p=0, which never accumulated anything
391 |             self.add_event(x, y, t, p)
392 | 
393 |     def get_image(self):
394 |         mn, mx = np.min(self.image), np.max(self.image)
395 |         norm_img = (self.image-mn)/(mx-mn)
396 |         return norm_img
397 | 
--------------------------------------------------------------------------------
/lib/representations/voxel_grid.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import cv2 as cv
5 | import torch
6 | from ..util.event_util import events_bounds_mask, binary_search_torch_tensor  # binary_search_torch_tensor is assumed to live here, alongside binary_search_h5_dset; it was previously used without being imported
7 | from .image import events_to_image, events_to_image_torch
8 | 
9 | def get_voxel_grid_as_image(voxelgrid):
10 |     """
11 |     Debug function. Returns a voxelgrid as a series of images,
12 |     one for each bin for display.
13 |     @param voxelgrid Input voxel grid
14 |     @returns Image of N bins placed side by side
15 |     """
16 |     images = []
17 |     splitter = np.ones((voxelgrid.shape[1], 2))*np.max(voxelgrid)
18 |     for image in voxelgrid:
19 |         images.append(image)
20 |         images.append(splitter)
21 |     images.pop()
22 |     sidebyside = np.hstack(images)
23 |     sidebyside = cv.normalize(sidebyside, None, 0, 255, cv.NORM_MINMAX)
24 |     return sidebyside
25 | 
26 | def plot_voxel_grid(voxelgrid, cmap='gray'):
27 |     """
28 |     Debug function. Given a voxel grid, display it as an image.
29 |     @param voxelgrid The input voxel grid
30 |     @param cmap The color map to use
31 |     @returns None
32 |     """
33 |     sidebyside = get_voxel_grid_as_image(voxelgrid)
34 |     plt.imshow(sidebyside, cmap=cmap)
35 |     plt.show()
36 | 
37 | def voxel_grids_fixed_n_torch(xs, ys, ts, ps, B, n, sensor_size=(180, 240), temporal_bilinear=True):
38 |     """
39 |     Given a set of events, return the voxel grid formed with a fixed number of events.
40 |     @param xs List of event x coordinates (torch tensor)
41 |     @param ys List of event y coordinates (torch tensor)
42 |     @param ts List of event timestamps (torch tensor)
43 |     @param ps List of event polarities (torch tensor)
44 |     @param B Number of bins in output voxel grids (int)
45 |     @param n The number of events per voxel
46 |     @param sensor_size The size of the event sensor/output voxels
47 |     @param temporal_bilinear Whether the events should be naively
48 |         accumulated to the voxels (faster), or properly
49 |         temporally distributed
50 |     @returns List of output voxel grids
51 |     """
52 |     voxels = []
53 |     for idx in range(0, len(xs)-n, n):
54 |         voxels.append(events_to_voxel_torch(xs[idx:idx+n], ys[idx:idx+n],
55 |             ts[idx:idx+n], ps[idx:idx+n], B, sensor_size=sensor_size,
56 |             temporal_bilinear=temporal_bilinear))
57 |     return voxels
58 | 
59 | def voxel_grids_fixed_t_torch(xs, ys, ts, ps, B, t, sensor_size=(180, 240), temporal_bilinear=True):
60 |     """
61 |     Given a set of events, return a voxel grid with a fixed temporal width.
62 | @param xs List of event x coordinates (torch tensor) 63 | @param ys List of event y coordinates (torch tensor) 64 | @param ts List of event timestamps (torch tensor) 65 | @param ps List of event polarities (torch tensor) 66 | @param B Number of bins in output voxel grids (int) 67 | @param t The time width of the voxel grids 68 | @param sensor_size The size of the event sensor/output voxels 69 | @param temporal_bilinear Whether the events should be naively 70 | accumulated to the voxels (faster), or properly 71 | temporally distributed 72 | @returns List of output voxel grids 73 | """ 74 | device = xs.device 75 | voxels = [] 76 | np_ts = ts.cpu().numpy() 77 | for t_start in np.arange(ts[0].item(), ts[-1].item()-t, t): 78 | voxels.append(events_to_voxel_timesync_torch(xs, ys, ts, ps, B, t_start, t_start+t, np_ts=np_ts, 79 | sensor_size=sensor_size, temporal_bilinear=temporal_bilinear)) 80 | return voxels 81 | 82 | def events_to_voxel_timesync_torch(xs, ys, ts, ps, B, t0, t1, device=None, np_ts=None, 83 | sensor_size=(180, 240), temporal_bilinear=True): 84 | """ 85 | Given a set of events, return a voxel grid of the events between t0 and t1 86 | @param xs List of event x coordinates (torch tensor) 87 | @param ys List of event y coordinates (torch tensor) 88 | @param ts List of event timestamps (torch tensor) 89 | @param ps List of event polarities (torch tensor) 90 | @param B Number of bins in output voxel grids (int) 91 | @param t0 The start time of the voxel grid 92 | @param t1 The end time of the voxel grid 93 | @param device Device to put voxel grid. If left empty, same device as events 94 | @param np_ts A numpy copy of ts (optional). If not given, will be created in situ 95 | @param sensor_size The size of the event sensor/output voxels 96 | @param temporal_bilinear Whether the events should be naively 97 | accumulated to the voxels (faster), or properly 98 | temporally distributed 99 | @returns Voxel of the events between t0 and t1 100 | """ 101 | assert(t1>t0) 102 | if np_ts is None: 103 | np_ts = ts.cpu().numpy() 104 | if device is None: 105 | device = xs.device 106 | start_idx = np.searchsorted(np_ts, t0) 107 | end_idx = np.searchsorted(np_ts, t1) 108 | assert(start_idx < end_idx) 109 | voxel = events_to_voxel_torch(xs[start_idx:end_idx], ys[start_idx:end_idx], 110 | ts[start_idx:end_idx], ps[start_idx:end_idx], B, device, sensor_size=sensor_size, 111 | temporal_bilinear=temporal_bilinear) 112 | return voxel 113 | 114 | def events_to_voxel_torch(xs, ys, ts, ps, B, device=None, sensor_size=(180, 240), temporal_bilinear=True): 115 | """ 116 | Turn set of events to a voxel grid tensor, using temporal bilinear interpolation 117 | @param xs List of event x coordinates (torch tensor) 118 | @param ys List of event y coordinates (torch tensor) 119 | @param ts List of event timestamps (torch tensor) 120 | @param ps List of event polarities (torch tensor) 121 | @param B Number of bins in output voxel grids (int) 122 | @param device Device to put voxel grid. 
def events_to_voxel_torch(xs, ys, ts, ps, B, device=None, sensor_size=(180, 240), temporal_bilinear=True):
    """
    Turn set of events to a voxel grid tensor, using temporal bilinear interpolation
    @param xs List of event x coordinates (torch tensor)
    @param ys List of event y coordinates (torch tensor)
    @param ts List of event timestamps (torch tensor)
    @param ps List of event polarities (torch tensor)
    @param B Number of bins in output voxel grids (int)
    @param device Device to put voxel grid. If left empty, same device as events
    @param sensor_size The size of the event sensor/output voxels
    @param temporal_bilinear Whether the events should be naively
        accumulated to the voxels (faster), or properly
        temporally distributed
    @returns Voxel grid of the events
    """
    if device is None:
        device = xs.device
    assert(len(xs) == len(ys) and len(ys) == len(ts) and len(ts) == len(ps))
    bins = []
    dt = ts[-1]-ts[0]
    t_norm = (ts-ts[0])/dt*(B-1)
    zeros = torch.zeros_like(t_norm)
    for bi in range(B):
        if temporal_bilinear:
            bilinear_weights = torch.max(zeros, 1.0-torch.abs(t_norm-bi))
            weights = ps*bilinear_weights
            vb = events_to_image_torch(xs, ys,
                    weights, device, sensor_size=sensor_size,
                    clip_out_of_range=False)
        else:
            # Naive accumulation: each bin covers an equal slice dt/B of
            # the total duration
            bin_dt = dt/B
            tstart = ts[0] + bin_dt*bi
            tend = tstart + bin_dt
            beg = binary_search_torch_tensor(ts, 0, len(ts)-1, tstart)
            end = binary_search_torch_tensor(ts, 0, len(ts)-1, tend)
            vb = events_to_image_torch(xs[beg:end], ys[beg:end],
                    ps[beg:end], device, sensor_size=sensor_size,
                    clip_out_of_range=False)
        bins.append(vb)
    bins = torch.stack(bins)
    return bins
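def _demo_events_to_voxel_torch():
    # Illustrative sketch, not part of the original API: with B=3 bins,
    # timestamps are normalised to [0, B-1] = [0, 2] and each event is
    # split between its two nearest bins. The event at t=0.25 below maps
    # to t_norm=0.5, so half its polarity lands in bin 0 and half in bin 1.
    # Summing each bin over the image preserves these weights regardless
    # of how events_to_image_torch splats them spatially.
    xs = torch.tensor([2., 5., 7.])
    ys = torch.tensor([3., 1., 4.])
    ts = torch.tensor([0.0, 0.25, 1.0])
    ps = torch.tensor([1., 1., -1.])
    voxel = events_to_voxel_torch(xs, ys, ts, ps, B=3, sensor_size=(8, 10))
    print(voxel.sum(dim=(1, 2)))  # expected: tensor([1.5000, 0.5000, -1.0000])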
def events_to_neg_pos_voxel_torch(xs, ys, ts, ps, B, device=None,
        sensor_size=(180, 240), temporal_bilinear=True):
    """
    Turn set of events to a voxel grid tensor, using temporal bilinear interpolation.
    Positive and negative events are put into separate voxel grids
    @param xs List of event x coordinates (torch tensor)
    @param ys List of event y coordinates (torch tensor)
    @param ts List of event timestamps (torch tensor)
    @param ps List of event polarities (torch tensor)
    @param B Number of bins in output voxel grids (int)
    @param device Device to put voxel grid. If left empty, same device as events
    @param sensor_size The size of the event sensor/output voxels
    @param temporal_bilinear Whether the events should be naively
        accumulated to the voxels (faster), or properly
        temporally distributed
    @returns Two voxel grids, one for positive one for negative events
    """
    zero_v = torch.tensor(0., device=ps.device)
    ones_v = torch.tensor(1., device=ps.device)
    pos_weights = torch.where(ps > 0, ones_v, zero_v)
    neg_weights = torch.where(ps <= 0, ones_v, zero_v)

    voxel_pos = events_to_voxel_torch(xs, ys, ts, pos_weights, B, device=device,
            sensor_size=sensor_size, temporal_bilinear=temporal_bilinear)
    voxel_neg = events_to_voxel_torch(xs, ys, ts, neg_weights, B, device=device,
            sensor_size=sensor_size, temporal_bilinear=temporal_bilinear)

    return voxel_pos, voxel_neg

def events_to_voxel(xs, ys, ts, ps, B, sensor_size=(180, 240), temporal_bilinear=True):
    """
    Turn set of events to a voxel grid tensor, using temporal bilinear interpolation
    @param xs List of event x coordinates (numpy array)
    @param ys List of event y coordinates (numpy array)
    @param ts List of event timestamps (numpy array)
    @param ps List of event polarities (numpy array)
    @param B Number of bins in output voxel grids (int)
    @param sensor_size The size of the event sensor/output voxels
    @param temporal_bilinear Whether the events should be naively
        accumulated to the voxels (faster), or properly
        temporally distributed
    @returns Voxel grid of the events
    """
    assert(len(xs) == len(ys) and len(ys) == len(ts) and len(ts) == len(ps))
    num_events_per_bin = len(xs)//B
    bins = []
    dt = ts[-1]-ts[0]
    t_norm = (ts-ts[0])/dt*(B-1)
    zeros = np.zeros(t_norm.shape[0])
    for bi in range(B):
        if temporal_bilinear:
            bilinear_weights = np.maximum(zeros, 1.0-np.abs(t_norm-bi))
            weights = ps*bilinear_weights
            vb = events_to_image(xs.squeeze(), ys.squeeze(), weights.squeeze(),
                    sensor_size=sensor_size, interpolation=None)
        else:
            # Naive accumulation: each bin receives an equal number of events
            beg = bi*num_events_per_bin
            end = beg + num_events_per_bin
            vb = events_to_image(xs[beg:end], ys[beg:end],
                    ps[beg:end], sensor_size=sensor_size)
        bins.append(vb)
    bins = np.stack(bins)
    return bins
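def _demo_events_to_voxel_numpy():
    # Illustrative usage sketch, not part of the original API: the numpy
    # voxel builder mirrors events_to_voxel_torch. Event data are random
    # stand-ins; the grid is displayed with plot_voxel_grid from above.
    rng = np.random.default_rng(0)
    n, sensor_size = 10000, (180, 240)
    xs = rng.integers(0, sensor_size[1], n)
    ys = rng.integers(0, sensor_size[0], n)
    ts = np.sort(rng.random(n))
    ps = rng.choice([-1.0, 1.0], n)
    voxel = events_to_voxel(xs, ys, ts, ps, B=5, sensor_size=sensor_size)
    plot_voxel_grid(voxel)  # shows the 5 bins side by side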
def events_to_neg_pos_voxel(xs, ys, ts, ps, B,
        sensor_size=(180, 240), temporal_bilinear=True):
    """
    Turn set of events to a voxel grid tensor, using temporal bilinear interpolation.
    Positive and negative events are put into separate voxel grids
    @param xs List of event x coordinates (numpy array)
    @param ys List of event y coordinates (numpy array)
    @param ts List of event timestamps (numpy array)
    @param ps List of event polarities (numpy array)
    @param B Number of bins in output voxel grids (int)
    @param sensor_size The size of the event sensor/output voxels
    @param temporal_bilinear Whether the events should be naively
        accumulated to the voxels (faster), or properly
        temporally distributed
    @returns Two voxel grids, one for positive one for negative events
    """
    pos_weights = np.where(ps > 0, 1, 0)
    neg_weights = np.where(ps <= 0, 1, 0)

    voxel_pos = events_to_voxel(xs, ys, ts, pos_weights, B,
            sensor_size=sensor_size, temporal_bilinear=temporal_bilinear)
    voxel_neg = events_to_voxel(xs, ys, ts, neg_weights, B,
            sensor_size=sensor_size, temporal_bilinear=temporal_bilinear)

    return voxel_pos, voxel_neg
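def _demo_events_to_neg_pos_voxel():
    # Illustrative sketch, not part of the original API: split a random
    # event stream into positive and negative voxel grids. Since every
    # event carries weight 1 in exactly one of the two grids, and the
    # bilinear weights of an event sum to 1 across bins, the two grids
    # together sum to the total event count.
    rng = np.random.default_rng(0)
    n, sensor_size = 10000, (180, 240)
    xs = rng.integers(0, sensor_size[1], n)
    ys = rng.integers(0, sensor_size[0], n)
    ts = np.sort(rng.random(n))
    ps = rng.choice([-1.0, 1.0], n)
    voxel_pos, voxel_neg = events_to_neg_pos_voxel(xs, ys, ts, ps, B=5, sensor_size=sensor_size)
    print(voxel_pos.sum() + voxel_neg.sum())  # ~= n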
--------------------------------------------------------------------------------
/lib/transforms/optic_flow.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn.functional as F

def warp_events_flow_torch(xt, yt, tt, pt, flow_field, t0=None,
        batched=False, batch_indices=None):
    """
    Given events and a flow field, warp the events by the flow
    Parameters
    ----------
    xt : list of event x coordinates
    yt : list of event y coordinates
    tt : list of event timestamps
    pt : list of event polarities
    flow_field : 2D tensor containing the flow at each x,y position
    t0 : the reference time to warp events to. If empty, will use the
        timestamp of the last event
    Returns
    -------
    warped_xt: x coords of warped events
    warped_yt: y coords of warped events
    """
    if len(xt.shape) > 1:
        xt, yt, tt, pt = xt.squeeze(), yt.squeeze(), tt.squeeze(), pt.squeeze()
    if t0 is None:
        t0 = tt[-1]
    while len(flow_field.size()) < 4:
        flow_field = flow_field.unsqueeze(0)
    if len(xt.size()) == 1:
        event_indices = torch.transpose(torch.stack((xt, yt), dim=0), 0, 1)
    else:
        event_indices = torch.transpose(torch.cat((xt, yt), dim=1), 0, 1)
    event_indices = torch.reshape(event_indices, [1, 1, len(xt), 2])

    # Event indices need to be between -1 and 1 for F.grid_sample
    event_indices[:,:,:,0] = event_indices[:,:,:,0]/(flow_field.shape[-1]-1)*2.0-1.0
    event_indices[:,:,:,1] = event_indices[:,:,:,1]/(flow_field.shape[-2]-1)*2.0-1.0

    flow_at_event = F.grid_sample(flow_field, event_indices, align_corners=True)
    dt = (tt-t0).squeeze()

    warped_xt = xt+flow_at_event[:,0,:,:].squeeze()*dt
    warped_yt = yt+flow_at_event[:,1,:,:].squeeze()*dt

    return warped_xt, warped_yt

--------------------------------------------------------------------------------
/lib/util/__init__.py:
--------------------------------------------------------------------------------
# __init__.py
from .event_util import *
from .util import *

--------------------------------------------------------------------------------
/lib/util/event_util.py:
--------------------------------------------------------------------------------
import numpy as np
import h5py
from ..representations.image import events_to_image

def infer_resolution(xs, ys):
    """
    Given events, guess the resolution by looking at the max coordinate values
    @param xs Event x coords
    @param ys Event y coords
    @returns Inferred resolution
    """
    sr = [np.max(ys) + 1, np.max(xs) + 1]
    return sr

def events_bounds_mask(xs, ys, x_min, x_max, y_min, y_max):
    """
    Get a mask of the events that are within the given bounds
    (half-open ranges [x_min, x_max) and [y_min, y_max))
    @param xs Event x coords
    @param ys Event y coords
    @param x_min Lower bound of x axis
    @param x_max Upper bound of x axis
    @param y_min Lower bound of y axis
    @param y_max Upper bound of y axis
    @returns mask
    """
    mask = np.where(np.logical_or(xs < x_min, xs >= x_max), 0.0, 1.0)
    mask *= np.where(np.logical_or(ys < y_min, ys >= y_max), 0.0, 1.0)
    return mask

def cut_events_to_lifespan(xs, ys, ts, ps, params,
        pixel_crossings, minimum_events=100, side='back'):
    """
    Given motion model parameters, compute the speed and thus
    the lifespan, given a desired number of pixel crossings
    @param xs Event x coords
    @param ys Event y coords
    @param ts Event timestamps
    @param ps Event polarities
    @param params Motion model parameters
    @param pixel_crossings Number of pixel crossings
    @param minimum_events The minimum number of events to cut down to
    @param side Cut events from 'back' or 'front'
    @returns Cut events
    """
    magnitude = np.linalg.norm(params)
    dt = pixel_crossings/magnitude
    if side == 'back':
        s_idx = np.searchsorted(ts, ts[-1]-dt)
        num_events = len(xs)-s_idx
        s_idx = len(xs)-minimum_events if num_events < minimum_events else s_idx
        return xs[s_idx:-1], ys[s_idx:-1], ts[s_idx:-1], ps[s_idx:-1]
    elif side == 'front':
        s_idx = np.searchsorted(ts, dt+ts[0])
        num_events = s_idx
        s_idx = minimum_events if num_events < minimum_events else s_idx
        return xs[0:s_idx], ys[0:s_idx], ts[0:s_idx], ps[0:s_idx]
    else:
        raise Exception("Invalid side given: {}. To cut events, must provide an "
                "appropriate side to cut from, either 'front' or 'back'".format(side))

def clip_events_to_bounds(xs, ys, ts, ps, bounds, set_zero=False):
    """
    Clip events to the given bounds.
    @param xs x coords of events
    @param ys y coords of events
    @param ts Timestamps of events (may be None)
    @param ps Polarities of events (may be None)
    @param bounds The bounds to clip to. Must be a list of
        length 2 (in which case the lower bound is assumed to be 0,0)
        or length 4, in format [min_y, max_y, min_x, max_x]
    @param set_zero If True, simply multiplies the out of bounds events with a 0 mask.
        Otherwise, removes the events.
    @returns Clipped events
    """
    if len(bounds) == 2:
        bounds = [0, bounds[0], 0, bounds[1]]
    elif len(bounds) != 4:
        raise Exception("Bounds must be of length 2 or 4 (not {})".format(len(bounds)))
    miny, maxy, minx, maxx = bounds
    if set_zero:
        mask = events_bounds_mask(xs, ys, minx, maxx, miny, maxy)
        ts_mask = None if ts is None else ts*mask
        ps_mask = None if ps is None else ps*mask
        return xs*mask, ys*mask, ts_mask, ps_mask
    else:
        x_clip_idc = np.argwhere((xs >= minx) & (xs < maxx))[:, 0]
        y_subset = ys[x_clip_idc]
        y_clip_idc = np.argwhere((y_subset >= miny) & (y_subset < maxy))[:, 0]

        xs_clip = xs[x_clip_idc][y_clip_idc]
        ys_clip = ys[x_clip_idc][y_clip_idc]
        ts_clip = None if ts is None else ts[x_clip_idc][y_clip_idc]
        ps_clip = None if ps is None else ps[x_clip_idc][y_clip_idc]
        return xs_clip, ys_clip, ts_clip, ps_clip

def get_events_from_mask(mask, xs, ys):
    """
    Given an image mask, return the indices of all events at each location in the mask
    @param mask The image mask
    @param xs x components of events as list
    @param ys y components of events as list
    @returns Indices of events that lie on the mask
    """
    xs = xs.astype(int)
    ys = ys.astype(int)
    idx = np.stack((ys, xs))
    event_vals = mask[tuple(idx)]
    event_indices = np.argwhere(event_vals >= 0.01).squeeze()
    return event_indices

def binary_search_h5_dset(dset, x, l=None, r=None, side='left'):
    """
    Binary search for a timestamp in an HDF5 event file, without
    loading the entire file into RAM
    @param dset The HDF5 dataset
    @param x The timestamp being searched for
    @param l Starting guess for the left side (0 if None is chosen)
    @param r Starting guess for the right side (-1 if None is chosen)
    @param side Which side to take final result for if exact match is not found
    @returns Index of nearest event to 'x'
    """
    l = 0 if l is None else l
    r = len(dset)-1 if r is None else r
    while l <= r:
        mid = l + (r - l)//2
        midval = dset[mid]
        if midval == x:
            return mid
        elif midval < x:
            l = mid + 1
        else:
            r = mid - 1
    if side == 'left':
        return l
    return r

def binary_search_h5_timestamp(hdf_path, l, r, x, side='left'):
    """
    Binary search for a timestamp in an HDF5 event file, given its path
    @param hdf_path Path to the HDF5 file
    @param l Starting guess for the left side
    @param r Starting guess for the right side
    @param x The timestamp being searched for
    @param side Which side to take final result for if exact match is not found
    @returns Index of nearest event to 'x'
    """
    # Use a context manager so the file handle is closed after the search
    with h5py.File(hdf_path, 'r') as f:
        return binary_search_h5_dset(f['events/ts'], x, l=l, r=r, side=side)
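def _demo_binary_search_h5():
    # Illustrative usage sketch, not part of the original API: search a
    # sorted timestamp dataset without loading it into RAM. The in-memory
    # HDF5 file ('core' driver, nothing written to disk) stands in for a
    # real event file with an 'events/ts' dataset.
    with h5py.File('demo.h5', 'w', driver='core', backing_store=False) as f:
        f.create_dataset('events/ts', data=np.arange(1000.0))
        idx = binary_search_h5_dset(f['events/ts'], 500.0)
        print("t=500.0 found at index {}".format(idx))  # 500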
def binary_search_torch_tensor(t, l, r, x, side='left'):
    """
    Binary search implemented for pytorch tensors (no native implementation exists)
    @param t The tensor to search
    @param l Starting lower bound (0 if None is chosen)
    @param r Starting upper bound (-1 if None is chosen)
    @param x The value being searched for
    @param side Which side to take final result for if exact match is not found
    @returns Index of nearest event to 'x'
    """
    l = 0 if l is None else l
    if r is None:
        r = len(t)-1
    while l <= r:
        mid = l + (r - l)//2
        midval = t[mid]
        if midval == x:
            return mid
        elif midval < x:
            l = mid + 1
        else:
            r = mid - 1
    if side == 'left':
        return l
    return r

def remove_hot_pixels(xs, ys, ts, ps, sensor_size=(180, 240), num_hot=50):
    """
    Given a set of events, removes the 'hot' pixel events.
    Accumulates all of the events into an event image and removes
    the 'num_hot' highest value pixels.
    @param xs Event x coords
    @param ys Event y coords
    @param ts Event timestamps
    @param ps Event polarities
    @param sensor_size The size of the event camera sensor
    @param num_hot The number of hot pixels to remove
    @returns The events with hot-pixel events removed
    """
    img = events_to_image(xs, ys, ps, sensor_size=sensor_size)
    hot = np.array([], dtype=np.int64)
    for i in range(num_hot):
        maxc = np.unravel_index(np.argmax(img), sensor_size)
        img[maxc] = 0
        h = np.where((xs == maxc[1]) & (ys == maxc[0]))
        hot = np.concatenate((hot, h[0]))
    xs, ys, ts, ps = np.delete(xs, hot), np.delete(ys, hot), np.delete(ts, hot), np.delete(ps, hot)
    return xs, ys, ts, ps
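def _demo_remove_hot_pixels():
    # Illustrative usage sketch, not part of the original API: plant an
    # artificial hot pixel at (x, y) = (10, 10) in a random stream and
    # filter it out. Positive polarities are used for the planted events
    # so the accumulated event image clearly peaks at that pixel.
    rng = np.random.default_rng(0)
    n = 5000
    xs = rng.integers(0, 240, n)
    ys = rng.integers(0, 180, n)
    ts = np.sort(rng.random(n))
    ps = rng.choice([-1.0, 1.0], n)
    xs[:1000], ys[:1000], ps[:1000] = 10, 10, 1.0
    xs2, ys2, ts2, ps2 = remove_hot_pixels(xs, ys, ts, ps, num_hot=1)
    print("{} -> {} events".format(len(xs), len(xs2)))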
--------------------------------------------------------------------------------
/lib/util/util.py:
--------------------------------------------------------------------------------
import json
import numpy as np
import cv2 as cv
from pathlib import Path
from itertools import repeat
from collections import OrderedDict
from math import ceil, floor
from torch.nn import ZeroPad2d
import matplotlib.pyplot as plt
import matplotlib.patches as patches


def ensure_dir(dirname):
    """
    Ensure a directory exists, if not create it
    @param dirname Directory name
    @returns None
    """
    dirname = Path(dirname)
    if not dirname.is_dir():
        dirname.mkdir(parents=True, exist_ok=False)


def read_json(fname):
    fname = Path(fname)
    with fname.open('rt') as handle:
        return json.load(handle, object_hook=OrderedDict)


def write_json(content, fname):
    fname = Path(fname)
    with fname.open('wt') as handle:
        json.dump(content, handle, indent=4, sort_keys=False)


def inf_loop(data_loader):
    ''' wrapper function for endless data loader. '''
    for loader in repeat(data_loader):
        yield from loader


def optimal_crop_size(max_size, max_subsample_factor, safety_margin=0):
    """ Find the optimal crop size for a given max_size and subsample_factor.
    The optimal crop size is the smallest integer which is greater or equal than max_size,
    while being divisible by 2^max_subsample_factor.
    """
    crop_size = int(pow(2, max_subsample_factor) * ceil(max_size / pow(2, max_subsample_factor)))
    crop_size += safety_margin * pow(2, max_subsample_factor)
    return crop_size


class CropParameters:
    """ Helper class to compute and store useful parameters for pre-processing and post-processing
    of images in and out of E2VID.
    Pre-processing: finding the best image size for the network, and padding the input image with zeros
    Post-processing: crop the output image back to the original image size
    """

    def __init__(self, width, height, num_encoders, safety_margin=0):

        self.height = height
        self.width = width
        self.num_encoders = num_encoders
        self.width_crop_size = optimal_crop_size(self.width, num_encoders, safety_margin)
        self.height_crop_size = optimal_crop_size(self.height, num_encoders, safety_margin)

        self.padding_top = ceil(0.5 * (self.height_crop_size - self.height))
        self.padding_bottom = floor(0.5 * (self.height_crop_size - self.height))
        self.padding_left = ceil(0.5 * (self.width_crop_size - self.width))
        self.padding_right = floor(0.5 * (self.width_crop_size - self.width))
        self.pad = ZeroPad2d((self.padding_left, self.padding_right, self.padding_top, self.padding_bottom))

        self.cx = floor(self.width_crop_size / 2)
        self.cy = floor(self.height_crop_size / 2)

        self.ix0 = self.cx - floor(self.width / 2)
        self.ix1 = self.cx + ceil(self.width / 2)
        self.iy0 = self.cy - floor(self.height / 2)
        self.iy1 = self.cy + ceil(self.height / 2)

    def crop(self, img):
        return img[..., self.iy0:self.iy1, self.ix0:self.ix1]


def format_power(size):
    """ Express 'size' in the nearest power-of-1000 unit (K, M, G, T). """
    power = 1e3
    n = 0
    power_labels = {0: '', 1: 'K', 2: 'M', 3: 'G', 4: 'T'}
    while size > power:
        size /= power
        n += 1
    return size, power_labels[n]

def plot_image(image, lognorm=False, cmap='gray', bbox=None, ticks=False, norm=True, savename=None, colorbar=False):
    """
    Plot an image
    :param image: The image to plot, as np array
    :param lognorm: If true, apply a log transform to the image before normalizing
    :param cmap: Colormap (default: gray)
    :param bbox: Optional bounding box to draw on image, as array with [[top corner x,y,w,h]]
    :param ticks: Whether or not to draw axis ticks
    :param norm: If true, normalize the image to [0, 1]
105 | :param savename: Optional save path 106 | :param colorbar: Display color bar if true 107 | """ 108 | fig, ax = plt.subplots(1) 109 | if lognorm: 110 | image = np.log10(image) 111 | cmap='viridis' 112 | if norm: 113 | image = cv.normalize(image, None, 0, 1.0, cv.NORM_MINMAX) 114 | ims = ax.imshow(image, cmap=cmap) 115 | if bbox is not None: 116 | w,h = bbox[2], bbox[3] 117 | rect = patches.Rectangle((bbox[0:2]), w, h, linewidth=1, edgecolor='r', facecolor='none') 118 | ax.add_patch(rect) 119 | if colorbar: 120 | fig.colorbar(ims) 121 | if not ticks: 122 | plt.axis('off') 123 | if savename is not None: 124 | plt.savefig(savename) 125 | plt.show() 126 | 127 | def plot_image_grid(images, grid_shape=None, lognorm=False, 128 | cmap='gray', bbox=None, norm=True, savename=None, 129 | colorbar=False): 130 | """ 131 | Given a list of images, stitches them into a grid and displays/saves the grid 132 | @param images List of images 133 | @param grid_shape Shape of the grid 134 | @param lognorm Logarithmic normalise the image 135 | @param cmap Color map to use 136 | @param bbox Draw a bounding box on the image 137 | @param norm If True, normalise the image 138 | @param savename If set, save the image to that path 139 | @param colorbar If true, plot the colorbar 140 | """ 141 | if grid_shape is None: 142 | grid_shape = [1, len(images)] 143 | 144 | col = [] 145 | img_idx = 0 146 | for xc in range(grid_shape[0]): 147 | row = [] 148 | for yc in range(grid_shape[1]): 149 | image = images[img_idx] 150 | if lognorm: 151 | image = np.log10(image) 152 | cmap='viridis' 153 | if norm: 154 | image = cv.normalize(image, None, 0, 1.0, cv.NORM_MINMAX) 155 | row.append(image) 156 | img_idx += 1 157 | col.append(np.concatenate(row, axis=1)) 158 | comp_img = np.concatenate(col, axis=0) 159 | if savename is None: 160 | plot_image(comp_img, norm=False, colorbar=colorbar, cmap=cmap) 161 | else: 162 | save_image(comp_img, fname=savename, colorbar=colorbar, cmap=cmap) 163 | 164 | def save_image(image, fname=None, lognorm=False, cmap='gray', bbox=None, colorbar=False): 165 | fname = "/tmp/img.png" if fname is None else fname 166 | fig, ax = plt.subplots(1) 167 | if lognorm: 168 | image = np.log10(image) 169 | cmap='viridis' 170 | image = cv.normalize(image, None, 0, 1.0, cv.NORM_MINMAX) 171 | ims = ax.imshow(image, cmap=cmap) 172 | if bbox is not None: 173 | w = bbox[1][0]-bbox[0][0] 174 | h = bbox[1][1]-bbox[0][1] 175 | rect = patches.Rectangle((bbox[0]), w, h, linewidth=1, edgecolor='r', facecolor='none') 176 | ax.add_patch(rect) 177 | if colorbar: 178 | fig.colorbar(ims) 179 | plt.savefig(fname, dpi=150) 180 | plt.close() 181 | 182 | def flow2bgr_np(disp_x, disp_y, max_magnitude=None): 183 | """ 184 | Convert an optic flow tensor to an RGB color map for visualization 185 | Code adapted from: https://github.com/ClementPinard/FlowNetPytorch/blob/master/main.py#L339 186 | @param disp_x A [H x W] NumPy array containing the X displacement 187 | @param disp_y A [H x W] NumPy array containing the Y displacement 188 | @returns A [H x W x 3] NumPy array containing a color-coded representation of the flow [0, 255] 189 | """ 190 | assert(disp_x.shape == disp_y.shape) 191 | H, W = disp_x.shape 192 | 193 | # X, Y = np.meshgrid(np.linspace(-1, 1, H), np.linspace(-1, 1, W)) 194 | 195 | # flow_x = (X - disp_x) * float(W) / 2 196 | # flow_y = (Y - disp_y) * float(H) / 2 197 | # magnitude, angle = cv.cartToPolar(flow_x, flow_y) 198 | # magnitude, angle = cv.cartToPolar(disp_x, disp_y) 199 | 200 | # follow alex zhu color convention 
https://github.com/daniilidis-group/EV-FlowNet 201 | 202 | flows = np.stack((disp_x, disp_y), axis=2) 203 | magnitude = np.linalg.norm(flows, axis=2) 204 | 205 | angle = np.arctan2(disp_y, disp_x) 206 | angle += np.pi 207 | angle *= 180. / np.pi / 2. 208 | angle = angle.astype(np.uint8) 209 | 210 | if max_magnitude is None: 211 | v = np.zeros(magnitude.shape, dtype=np.uint8) 212 | cv.normalize(src=magnitude, dst=v, alpha=0, beta=255, norm_type=cv.NORM_MINMAX, dtype=cv.CV_8U) 213 | else: 214 | v = np.clip(255.0 * magnitude / max_magnitude, 0, 255) 215 | v = v.astype(np.uint8) 216 | 217 | hsv = np.zeros((H, W, 3), dtype=np.uint8) 218 | hsv[..., 1] = 255 219 | hsv[..., 0] = angle 220 | hsv[..., 2] = v 221 | bgr = cv.cvtColor(hsv, cv.COLOR_HSV2BGR) 222 | 223 | return bgr 224 | -------------------------------------------------------------------------------- /lib/visualization/__init__.py: -------------------------------------------------------------------------------- 1 | # __init__.py 2 | from . import draw_event_stream 3 | -------------------------------------------------------------------------------- /lib/visualization/draw_event_stream.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numpy.lib.recfunctions as nlr 3 | import cv2 as cv 4 | from skimage.measure import block_reduce 5 | import os 6 | import matplotlib.pyplot as plt 7 | from mpl_toolkits.mplot3d import Axes3D 8 | 9 | from ..representations.image import events_to_image 10 | from ..representations.voxel_grid import events_to_voxel 11 | from ..util.event_util import clip_events_to_bounds 12 | from .visualization_utils import * 13 | from tqdm import tqdm 14 | 15 | def plot_events_sliding(xs, ys, ts, ps, args, frames=[], frame_ts=[]): 16 | """ 17 | Plot the given events in a sliding window fashion to generate a video 18 | @param xs x component of events 19 | @param ys y component of events 20 | @param ts t component of events 21 | @param ps p component of events 22 | @param args Arguments for the rendering (see args list 23 | for 'plot_events' function) 24 | @param frames List of image frames 25 | @param frame_ts List of the image timestamps 26 | @returns None 27 | """ 28 | dt, sdt = args.w_width, args.sw_width 29 | if dt is None: 30 | dt = (ts[-1]-ts[0])/10 31 | sdt = dt/10 32 | print("Using dt={}, sdt={}".format(dt, sdt)) 33 | 34 | if len(frames) > 0: 35 | has_frames = True 36 | sensor_size = frames[0].shape 37 | frame_ts = frame_ts[:,1] if len(frame_ts.shape) == 2 else frame_ts 38 | else: 39 | has_frames = False 40 | sensor_size = [max(ys), max(xs)] 41 | 42 | n_frames = len(np.arange(ts[0], ts[-1]-dt, sdt)) 43 | for i, t0 in enumerate(tqdm(np.arange(ts[0], ts[-1]-dt, sdt))): 44 | te = t0+dt 45 | eidx0 = np.searchsorted(ts, t0) 46 | eidx1 = np.searchsorted(ts, te) 47 | wxs, wys, wts, wps = xs[eidx0:eidx1], ys[eidx0:eidx1], ts[eidx0:eidx1], ps[eidx0:eidx1], 48 | 49 | wframes, wframe_ts = [], [] 50 | if has_frames: 51 | fidx0 = np.searchsorted(frame_ts, t0) 52 | fidx1 = np.searchsorted(frame_ts, te) 53 | wframes = [frames[fidx0]] 54 | wframe_ts = [wts[0]] 55 | 56 | save_path = os.path.join(args.output_path, "frame_{:010d}.jpg".format(i)) 57 | 58 | perc = i/n_frames 59 | min_p, max_p = 0.2, 0.7 60 | elev, azim = args.elev, args.azim 61 | max_elev, max_azim = 10, 45 62 | if perc > min_p and perc < max_p: 63 | p_way = (perc-min_p)/(max_p-min_p) 64 | elev = elev + (max_elev*p_way) 65 | azim = azim - (max_azim*p_way) 66 | elif perc >= max_p: 67 | elev, azim = max_elev, 
max_azim 68 | 69 | plot_events(wxs, wys, wts, wps, save_path=save_path, num_show=args.num_show, event_size=args.event_size, 70 | imgs=wframes, img_ts=wframe_ts, show_events=not args.hide_events, azim=azim, 71 | elev=elev, show_frames=not args.hide_frames, crop=args.crop, compress_front=args.compress_front, 72 | invert=args.invert, num_compress=args.num_compress, show_plot=args.show_plot, img_size=sensor_size, 73 | show_axes=args.show_axes, stride=args.stride) 74 | 75 | def plot_voxel_grid(xs, ys, ts, ps, bins=5, frames=[], frame_ts=[], 76 | sensor_size=None, crop=None, elev=0, azim=45, show_axes=False): 77 | """ 78 | @param xs x component of events 79 | @param ys y component of events 80 | @param ts t component of events 81 | @param ps p component of events 82 | @param bins The number of bins to have in the voxel grid 83 | @param frames The list of image frames 84 | @param frame_ts The list of image timestamps 85 | @param sensor_size The size of the event sensor resolution 86 | @param crop Cropping parameters for the voxel grid (no crop if None) 87 | @param elev The elevation of the plot 88 | @param azim The azimuth of the plot 89 | @param show_axes Show the axes of the plot 90 | @returns None 91 | """ 92 | if sensor_size is None: 93 | sensor_size = [np.max(ys)+1, np.max(xs)+1] if len(frames)==0 else frames[0].shape 94 | if crop is not None: 95 | xs, ys, ts, ps = clip_events_to_bounds(xs, ys, ts, ps, crop) 96 | sensor_size = crop_to_size(crop) 97 | xs, ys = xs-crop[2], ys-crop[0] 98 | num = 10000 99 | xs, ys, ts, ps = xs[0:num], ys[0:num], ts[0:num], ps[0:num] 100 | if len(xs) == 0: 101 | return 102 | voxels = events_to_voxel(xs, ys, ts, ps, bins, sensor_size=sensor_size) 103 | voxels = block_reduce(voxels, block_size=(1,10,10), func=np.mean, cval=0) 104 | dimdiff = voxels.shape[1]-voxels.shape[0] 105 | filler = np.zeros((dimdiff, *voxels.shape[1:])) 106 | voxels = np.concatenate((filler, voxels), axis=0) 107 | voxels = voxels.transpose(0,2,1) 108 | 109 | pltvoxels = voxels != 0 110 | pvp, nvp = voxels > 0, voxels < 0 111 | pvox, nvox = voxels*np.where(voxels > 0, 1, 0), voxels*np.where(voxels < 0, 1, 0) 112 | pvox, nvox = (pvox/np.max(pvox))*0.5+0.5, (np.abs(nvox)/np.max(np.abs(nvox)))*0.5+0.5 113 | zeros = np.zeros_like(voxels) 114 | 115 | colors = np.empty(voxels.shape, dtype=object) 116 | 117 | redvals = np.stack((pvox, zeros, pvox-0.5), axis=3) 118 | redvals = nlr.unstructured_to_structured(redvals).astype('O') 119 | 120 | bluvals = np.stack((nvox-0.5, zeros, nvox), axis=3) 121 | bluvals = nlr.unstructured_to_structured(bluvals).astype('O') 122 | 123 | colors[pvp] = redvals[pvp] 124 | colors[nvp] = bluvals[nvp] 125 | 126 | fig = plt.figure() 127 | ax = fig.gca(projection='3d') 128 | ax.voxels(pltvoxels, facecolors=colors, edgecolor='k') 129 | ax.view_init(elev=elev, azim=azim) 130 | 131 | ax.grid(False) 132 | # Hide panes 133 | ax.xaxis.pane.fill = False 134 | ax.yaxis.pane.fill = False 135 | ax.zaxis.pane.fill = False 136 | if not show_axes: 137 | # Hide spines 138 | ax.w_xaxis.line.set_color((1.0, 1.0, 1.0, 0.0)) 139 | ax.w_yaxis.line.set_color((1.0, 1.0, 1.0, 0.0)) 140 | ax.w_zaxis.line.set_color((1.0, 1.0, 1.0, 0.0)) 141 | ax.set_frame_on(False) 142 | # Hide xy axes 143 | ax.set_xticks([]) 144 | ax.set_yticks([]) 145 | ax.set_zticks([]) 146 | 147 | ax.xaxis.set_visible(False) 148 | ax.axes.get_yaxis().set_visible(False) 149 | 150 | plt.show() 151 | 152 | def plot_events(xs, ys, ts, ps, save_path=None, num_compress='auto', num_show=1000, 153 | event_size=2, elev=0, azim=45, 
imgs=[], img_ts=[], show_events=True, 154 | show_frames=True, show_plot=False, crop=None, compress_front=False, 155 | marker='.', stride = 1, invert=False, img_size=None, show_axes=False): 156 | """ 157 | Given events, plot these in a spatiotemporal volume. 158 | @param xs x coords of events 159 | @param ys y coords of events 160 | @param ts t coords of events 161 | @param ps p coords of events 162 | @param save_path If set, will save plot to here 163 | @param num_compress Takes num_compress events from the beginning of the 164 | sequence and draws them in the plot at time $t=0$ in black 165 | @param compress_front If True, display the compressed events in black at the 166 | front of the spatiotemporal volume rather than the back 167 | @param num_show Sets the number of events to plot. If set to -1 168 | will plot all of the events (can be potentially expensive) 169 | @param event_size Sets the size of the plotted events 170 | @param elev Sets the elevation of the plot 171 | @param azim Sets the azimuth of the plot 172 | @param imgs A list of images to draw into the spatiotemporal volume 173 | @param img_ts A list of the position on the temporal axis where each 174 | image from 'imgs' is to be placed (the timestamp of the images, usually) 175 | @param show_events If False, will not plot the events (only images) 176 | @param show_plot If True, display the plot in a matplotlib window as 177 | well as saving to disk 178 | @param crop A list of length 4 that sets the crop of the plot (must 179 | be in the format [top_left_y, top_left_x, height, width] 180 | @param marker Which marker should be used to display the events (default 181 | is '.', which results in points, but circles 'o' or crosses 'x' are 182 | among many other possible options) 183 | @param stride Determines the pixel stride of the image rendering 184 | (1=full resolution, but can be quite resource intensive) 185 | @param invert Inverts the color scheme for black backgrounds 186 | @param img_size The size of the sensor resolution. Inferred if empty. 187 | @param show_axes If True, draw axes onto the plot. 
188 | @returns None 189 | """ 190 | #Crop events 191 | if img_size is None: 192 | img_size = [max(ys), max(xs)] if len(imgs)==0 else imgs[0].shape[0:2] 193 | print("Inferred image size = {}".format(img_size)) 194 | crop = [0, img_size[0], 0, img_size[1]] if crop is None else crop 195 | xs, ys, ts, ps = clip_events_to_bounds(xs, ys, ts, ps, crop, set_zero=False) 196 | xs, ys = xs-crop[2], ys-crop[0] 197 | 198 | #Defaults and range checks 199 | num_show = len(xs) if num_show == -1 else num_show 200 | skip = max(len(xs)//num_show, 1) 201 | num_compress = len(xs) if num_compress == -1 else num_compress 202 | num_compress = min(img_size[0]*img_size[1]*0.5, len(xs)) if num_compress=='auto' else num_compress 203 | xs, ys, ts, ps = xs[::skip], ys[::skip], ts[::skip], ps[::skip] 204 | 205 | #Prepare the plot, set colors 206 | fig = plt.figure() 207 | ax = fig.add_subplot(111, projection='3d', proj_type = 'ortho') 208 | colors = ['r' if p>0 else ('#00DAFF' if invert else 'b') for p in ps] 209 | 210 | #Plot images 211 | if len(imgs)>0 and show_frames: 212 | for imgidx, (img, img_ts) in enumerate(zip(imgs, img_ts)): 213 | img = img[crop[0]:crop[1], crop[2]:crop[3]] 214 | if len(img.shape)==2: 215 | img = np.stack((img, img, img), axis=2) 216 | if num_compress > 0: 217 | events_img = events_to_image(xs[0:num_compress], ys[0:num_compress], 218 | np.ones(num_compress), sensor_size=img.shape[0:2]) 219 | events_img[events_img>0] = 1 220 | img[:,:,1]+=events_img[:,:] 221 | img = np.clip(img, 0, 1) 222 | x, y = np.ogrid[0:img.shape[0], 0:img.shape[1]] 223 | event_idx = np.searchsorted(ts, img_ts) 224 | 225 | ax.scatter(xs[0:event_idx], ts[0:event_idx], ys[0:event_idx], zdir='z', 226 | c=colors[0:event_idx], facecolors=colors[0:event_idx], 227 | s=np.ones(xs.shape)*event_size, marker=marker, linewidths=0, alpha=1.0 if show_events else 0) 228 | 229 | ax.plot_surface(y, img_ts, x, rstride=stride, cstride=stride, facecolors=img, alpha=1) 230 | 231 | ax.scatter(xs[event_idx:-1], ts[event_idx:-1], ys[event_idx:-1], zdir='z', 232 | c=colors[event_idx:-1], facecolors=colors[event_idx:-1], 233 | s=np.ones(xs.shape)*event_size, marker=marker, linewidths=0, alpha=1.0 if show_events else 0) 234 | 235 | elif num_compress > 0: 236 | # Plot events 237 | ax.scatter(xs[::skip], ts[::skip], ys[::skip], zdir='z', c=colors[::skip], facecolors=colors[::skip], 238 | s=np.ones(xs.shape)*event_size, marker=marker, linewidths=0, alpha=1.0 if show_events else 0) 239 | num_compress = min(num_compress, len(xs)) 240 | if not compress_front: 241 | ax.scatter(xs[0:num_compress], np.ones(num_compress)*ts[0], ys[0:num_compress], 242 | marker=marker, zdir='z', c='w' if invert else 'k', s=np.ones(num_compress)*event_size) 243 | else: 244 | ax.scatter(xs[-num_compress-1:-1], np.ones(num_compress)*ts[-1], ys[-num_compress-1:-1], 245 | marker=marker, zdir='z', c='w' if invert else 'k', s=np.ones(num_compress)*event_size) 246 | else: 247 | # Plot events 248 | ax.scatter(xs, ts, ys,zdir='z', c=colors, facecolors=colors, s=np.ones(xs.shape)*event_size, marker=marker, linewidths=0, alpha=1.0 if show_events else 0) 249 | 250 | ax.view_init(elev=elev, azim=azim) 251 | ax.grid(False) 252 | # Hide panes 253 | ax.xaxis.pane.fill = False 254 | ax.yaxis.pane.fill = False 255 | ax.zaxis.pane.fill = False 256 | if not show_axes: 257 | # Hide spines 258 | ax.w_xaxis.line.set_color((1.0, 1.0, 1.0, 0.0)) 259 | ax.w_yaxis.line.set_color((1.0, 1.0, 1.0, 0.0)) 260 | ax.w_zaxis.line.set_color((1.0, 1.0, 1.0, 0.0)) 261 | ax.set_frame_on(False) 262 | # Hide xy axes 
263 | ax.set_xticks([]) 264 | ax.set_yticks([]) 265 | ax.set_zticks([]) 266 | # Flush axes 267 | ax.set_xlim3d(0, img_size[1]) 268 | ax.set_ylim3d(ts[0], ts[-1]) 269 | ax.set_zlim3d(0,img_size[0]) 270 | 271 | if show_plot: 272 | plt.show() 273 | if save_path is not None: 274 | ensure_dir(save_path) 275 | plt.savefig(save_path, transparent=True, dpi=600, bbox_inches = 'tight') 276 | plt.close() 277 | 278 | def plot_between_frames(xs, ys, ts, ps, frames, frame_event_idx, args, plttype='voxel'): 279 | """ 280 | Plot events between frames for an entire sequence to form a video 281 | @param xs x component of events 282 | @param ys y component of events 283 | @param ts t component of events 284 | @param ps p component of events 285 | @param frames List of the frames 286 | @param frame_event_idx The event index for each frame 287 | @param args Arguments for the rendering function 'plot_events' 288 | @param plttype Whether to plot 'voxel' or 'events' 289 | @return None 290 | """ 291 | args.crop = None if args.crop is None else parse_crop(args.crop) 292 | prev_idx = 0 293 | for i in range(0, len(frames), args.skip_frames): 294 | if args.hide_skipped: 295 | frame = [frames[i]] 296 | frame_indices = frame_event_idx[i][np.newaxis, ...] 297 | else: 298 | frame = frames[i:i+args.skip_frames] 299 | frame_indices = frame_event_idx[i:i+args.skip_frames] 300 | print("Processing frame {}".format(i)) 301 | s, e = frame_indices[0,1], frame_indices[-1,0] 302 | img_ts = [] 303 | for f_idx in frame_indices: 304 | img_ts.append(ts[f_idx[1]]) 305 | fname = os.path.join(args.output_path, "events_{:09d}.png".format(i)) 306 | if plttype == 'voxel': 307 | plot_voxel_grid(xs[s:e], ys[s:e], ts[s:e], ps[s:e], bins=args.num_bins, crop=args.crop, 308 | frames=frame, frame_ts=img_ts, elev=args.elev, azim=args.azim) 309 | elif plttype == 'events': 310 | plot_events(xs[s:e], ys[s:e], ts[s:e], ps[s:e], save_path=fname, 311 | num_show=args.num_show, event_size=args.event_size, imgs=frame, 312 | img_ts=img_ts, show_events=not args.hide_events, azim=args.azim, 313 | elev=args.elev, show_frames=not args.hide_frames, crop=args.crop, 314 | compress_front=args.compress_front, invert=args.invert, 315 | num_compress=args.num_compress, show_plot=args.show_plot, stride=args.stride) 316 | 317 | -------------------------------------------------------------------------------- /lib/visualization/draw_event_stream_mayavi.py: -------------------------------------------------------------------------------- 1 | from mayavi import mlab 2 | from mayavi.api import Engine 3 | import numpy as np 4 | import numpy.lib.recfunctions as nlr 5 | import cv2 as cv 6 | from skimage.measure import block_reduce 7 | import os 8 | #import matplotlib.pyplot as plt 9 | #from mpl_toolkits.mplot3d import Axes3D 10 | 11 | from ..representations.image import events_to_image 12 | from ..representations.voxel_grid import events_to_voxel 13 | from ..util.event_util import clip_events_to_bounds 14 | from ..visualization.visualization_utils import * 15 | from tqdm import tqdm 16 | 17 | def plot_events_sliding(xs, ys, ts, ps, args, dt=None, sdt=None, frames=None, frame_ts=None, padding=True): 18 | 19 | skip = max(len(xs)//args.num_show, 1) 20 | xs, ys, ts, ps = xs[::skip], ys[::skip], ts[::skip], ps[::skip] 21 | t0 = ts[0] 22 | sx,sy, st, sp = [], [], [], [] 23 | if padding: 24 | for i in np.arange(ts[0]-dt, ts[0], sdt): 25 | sx.append(0) 26 | sy.append(0) 27 | st.append(i) 28 | sp.append(0) 29 | print(len(sx)) 30 | print(st) 31 | print(ts) 32 | xs = 
np.concatenate((np.array(sx), xs)) 33 | ys = np.concatenate((np.array(sy), ys)) 34 | ts = np.concatenate((np.array(st), ts)) 35 | ps = np.concatenate((np.array(sp), ps)) 36 | print(ts) 37 | 38 | ts += -st[0] 39 | frame_ts += -st[0] 40 | t0 += -st[0] 41 | print(ts) 42 | 43 | f = mlab.figure(bgcolor=(1,1,1), size=(1080, 720)) 44 | engine = mlab.get_engine() 45 | scene = engine.scenes[0] 46 | scene.scene.camera.position = [373.1207907160101, 5353.96218497846, 7350.065665045519] 47 | scene.scene.camera.focal_point = [228.0033999234376, 37.75424682790012, 3421.439332472788] 48 | scene.scene.camera.view_angle = 30.0 49 | scene.scene.camera.view_up = [0.9997493712140433, -0.02027499237784438, -0.009493125997461629] 50 | scene.scene.camera.clipping_range = [2400.251302762254, 11907.415293888362] 51 | scene.scene.camera.compute_view_plane_normal() 52 | 53 | print("ts from {} to {}, imgs from {} to {}".format(ts[0], ts[-1], frame_ts[0], frame_ts[-1])) 54 | frame_ts = np.array([t0]+list(frame_ts[0:-1])) 55 | if dt is None: 56 | dt = (ts[-1]-ts[0])/10 57 | sdt = dt/10 58 | print("Using dt={}, sdt={}".format(dt, sdt)) 59 | if frames is not None: 60 | sensor_size = frames[0].shape 61 | else: 62 | sensor_size = [max(ys), max(xs)] 63 | 64 | if len(frame_ts.shape) == 2: 65 | frame_ts = frame_ts[:,1] 66 | for i, t0 in enumerate(tqdm(np.arange(ts[0], ts[-1]-dt, sdt))): 67 | te = t0+dt 68 | eidx0 = np.searchsorted(ts, t0) 69 | eidx1 = np.searchsorted(ts, te) 70 | fidx0 = np.searchsorted(frame_ts, t0) 71 | fidx1 = np.searchsorted(frame_ts, te) 72 | #print("{}:{} = {}".format(frame_ts[fidx0], ts[eidx0], fidx0)) 73 | 74 | wxs, wys, wts, wps = xs[eidx0:eidx1], ys[eidx0:eidx1], ts[eidx0:eidx1], ps[eidx0:eidx1], 75 | if fidx0 == fidx1: 76 | wframes=[] 77 | wframe_ts=[] 78 | else: 79 | wframes = frames[fidx0:fidx1] 80 | wframe_ts = frame_ts[fidx0:fidx1] 81 | 82 | save_path = os.path.join(args.output_path, "frame_{:010d}.jpg".format(i)) 83 | plot_events(wxs, wys, wts, wps, save_path=save_path, num_show=-1, event_size=args.event_size, 84 | imgs=wframes, img_ts=wframe_ts, show_events=not args.hide_events, azim=args.azim, 85 | elev=args.elev, show_frames=not args.hide_frames, crop=args.crop, compress_front=args.compress_front, 86 | invert=args.invert, num_compress=args.num_compress, show_plot=args.show_plot, img_size=sensor_size, 87 | show_axes=args.show_axes, ts_scale=args.ts_scale) 88 | 89 | if save_path is not None: 90 | ensure_dir(save_path) 91 | #mlab.savefig(save_path, figure=f, magnification=10) 92 | #GUI().process_events() 93 | #img = mlab.screenshot(figure=f, mode='rgba', antialiased=True) 94 | #print(img.shape) 95 | mlab.savefig(save_path, figure=f, magnification=8) 96 | 97 | mlab.clf() 98 | 99 | def plot_voxel_grid(xs, ys, ts, ps, bins=5, frames=[], frame_ts=[], 100 | sensor_size=None, crop=None, elev=0, azim=45, show_axes=False): 101 | if sensor_size is None: 102 | sensor_size = [np.max(ys)+1, np.max(xs)+1] if len(frames)==0 else frames[0].shape 103 | if crop is not None: 104 | xs, ys, ts, ps = clip_events_to_bounds(xs, ys, ts, ps, crop) 105 | sensor_size = crop_to_size(crop) 106 | xs, ys = xs-crop[2], ys-crop[0] 107 | num = 10000 108 | xs, ys, ts, ps = xs[0:num], ys[0:num], ts[0:num], ps[0:num] 109 | if len(xs) == 0: 110 | return 111 | voxels = events_to_voxel(xs, ys, ts, ps, bins, sensor_size=sensor_size) 112 | voxels = block_reduce(voxels, block_size=(1,10,10), func=np.mean, cval=0) 113 | dimdiff = voxels.shape[1]-voxels.shape[0] 114 | filler = np.zeros((dimdiff, *voxels.shape[1:])) 115 | voxels = 
np.concatenate((filler, voxels), axis=0) 116 | voxels = voxels.transpose(0,2,1) 117 | 118 | pltvoxels = voxels != 0 119 | pvp, nvp = voxels > 0, voxels < 0 120 | pvox, nvox = voxels*np.where(voxels > 0, 1, 0), voxels*np.where(voxels < 0, 1, 0) 121 | pvox, nvox = (pvox/np.max(pvox))*0.5+0.5, (np.abs(nvox)/np.max(np.abs(nvox)))*0.5+0.5 122 | zeros = np.zeros_like(voxels) 123 | 124 | colors = np.empty(voxels.shape, dtype=object) 125 | 126 | redvals = np.stack((pvox, zeros, pvox-0.5), axis=3) 127 | redvals = nlr.unstructured_to_structured(redvals).astype('O') 128 | 129 | bluvals = np.stack((nvox-0.5, zeros, nvox), axis=3) 130 | bluvals = nlr.unstructured_to_structured(bluvals).astype('O') 131 | 132 | colors[pvp] = redvals[pvp] 133 | colors[nvp] = bluvals[nvp] 134 | 135 | fig = plt.figure() 136 | ax = fig.gca(projection='3d') 137 | ax.voxels(pltvoxels, facecolors=colors, edgecolor='k') 138 | ax.view_init(elev=elev, azim=azim) 139 | 140 | ax.grid(False) 141 | # Hide panes 142 | ax.xaxis.pane.fill = False 143 | ax.yaxis.pane.fill = False 144 | ax.zaxis.pane.fill = False 145 | if not show_axes: 146 | # Hide spines 147 | ax.w_xaxis.line.set_color((1.0, 1.0, 1.0, 0.0)) 148 | ax.w_yaxis.line.set_color((1.0, 1.0, 1.0, 0.0)) 149 | ax.w_zaxis.line.set_color((1.0, 1.0, 1.0, 0.0)) 150 | ax.set_frame_on(False) 151 | # Hide xy axes 152 | ax.set_xticks([]) 153 | ax.set_yticks([]) 154 | ax.set_zticks([]) 155 | 156 | ax.xaxis.set_visible(False) 157 | ax.axes.get_yaxis().set_visible(False) 158 | 159 | plt.show() 160 | 161 | def plot_events(xs, ys, ts, ps, save_path=None, num_compress='auto', num_show=1000, 162 | event_size=2, elev=0, azim=45, imgs=[], img_ts=[], show_events=True, 163 | show_frames=True, show_plot=False, crop=None, compress_front=False, 164 | marker='.', stride = 1, invert=False, img_size=None, show_axes=False, 165 | ts_scale = 100000): 166 | """ 167 | Given events, plot these in a spatiotemporal volume. 168 | :param: xs x coords of events 169 | :param: ys y coords of events 170 | :param: ts t coords of events 171 | :param: ps p coords of events 172 | :param: save_path if set, will save plot to here 173 | :param: num_compress will take this number of events from the end 174 | and create an event image from these. This event image will 175 | be displayed at the end of the spatiotemporal volume 176 | :param: num_show sets the number of events to plot. 
        If set to -1
        will plot all of the events (can be potentially expensive)
    :param: event_size sets the size of the plotted events
    :param: elev sets the elevation of the plot
    :param: azim sets the azimuth of the plot
    :param: imgs a list of images to draw into the spatiotemporal volume
    :param: img_ts a list of the position on the temporal axis where each
        image from 'imgs' is to be placed (the timestamp of the images, usually)
    :param: show_events if False, will not plot the events (only images)
    :param: crop a list of length 4 that sets the crop of the plot (must
        be in the format [top_left_y, top_left_x, height, width]
    """
    #Crop events
    if img_size is None:
        img_size = [max(ys), max(xs)] if len(imgs)==0 else imgs[0].shape[0:2]
    crop = [0, img_size[0], 0, img_size[1]] if crop is None else crop
    xs, ys, ts, ps = clip_events_to_bounds(xs, ys, ts, ps, crop, set_zero=False)
    xs, ys = xs-crop[2], ys-crop[0]

    #Defaults and range checks
    num_show = len(xs) if num_show == -1 else num_show
    skip = max(len(xs)//num_show, 1)
    print("Has {} events, show only {}, skip = {}".format(len(xs), num_show, skip))
    num_compress = len(xs) if num_compress == -1 else num_compress
    num_compress = min(img_size[0]*img_size[1]*0.5, len(xs)) if num_compress=='auto' else num_compress
    xs, ys, ts, ps = xs[::skip], ys[::skip], ts[::skip], ps[::skip]

    t0 = ts[0]
    ts = ts-t0

    #mlab.options.offscreen = True

    #Plot images
    if len(imgs)>0 and show_frames:
        for imgidx, (img, img_t) in enumerate(zip(imgs, img_ts)):
            img = img[crop[0]:crop[1], crop[2]:crop[3]]

            mlab.imshow(img, colormap='gray', extent=[0, img.shape[0], 0, img.shape[1], (img_t-t0)*ts_scale, (img_t-t0)*ts_scale+0.01], opacity=1.0, transparent=False)

    colors = [0 if p>0 else 240 for p in ps]
    ones = np.array([0 if p==0 else 1 for p in ps])
    p3d = mlab.quiver3d(ys, xs, ts*ts_scale, ones, ones, ones, scalars=colors, mode='sphere', scale_factor=event_size)
    p3d.glyph.color_mode = 'color_by_scalar'
    p3d.module_manager.scalar_lut_manager.lut.table = colors
    #mlab.draw()

    #mlab.view(84.5, 54, 5400, np.array([ 187, 175, 2276]), roll=95)

    if show_plot:
        mlab.show()
    #if save_path is not None:
    #    ensure_dir(save_path)
    #    print("Saving to {}".format(save_path))
    #    imgmap = mlab.screenshot(mode='rgba', antialiased=True)
    #    print(imgmap.shape)
    #    cv.imwrite(save_path, imgmap)

def plot_between_frames(xs, ys, ts, ps, frames, frame_event_idx, args, plttype='voxel'):
    args.crop = None if args.crop is None else parse_crop(args.crop)
    prev_idx = 0
    for i in range(0, len(frames), args.skip_frames):
        if args.hide_skipped:
            frame = [frames[i]]
            frame_indices = frame_event_idx[i][np.newaxis, ...]
243 | else: 244 | frame = frames[i:i+args.skip_frames] 245 | frame_indices = frame_event_idx[i:i+args.skip_frames] 246 | print("Processing frame {}".format(i)) 247 | s, e = frame_indices[0,1], frame_indices[-1,0] 248 | img_ts = [] 249 | for f_idx in frame_indices: 250 | img_ts.append(ts[f_idx[1]]) 251 | fname = os.path.join(args.output_path, "events_{:09d}.png".format(i)) 252 | if plttype == 'voxel': 253 | plot_voxel_grid(xs[s:e], ys[s:e], ts[s:e], ps[s:e], bins=args.num_bins, crop=args.crop, 254 | frames=frame, frame_ts=img_ts, elev=args.elev, azim=args.azim) 255 | elif plttype == 'events': 256 | print("plot events") 257 | plot_events(xs[s:e], ys[s:e], ts[s:e], ps[s:e], save_path=fname, 258 | num_show=args.num_show, event_size=args.event_size, imgs=frame, 259 | img_ts=img_ts, show_events=not args.hide_events, azim=args.azim, 260 | elev=args.elev, show_frames=not args.hide_frames, crop=args.crop, 261 | compress_front=args.compress_front, invert=args.invert, 262 | num_compress=args.num_compress, show_plot=args.show_plot, stride=args.stride) 263 | -------------------------------------------------------------------------------- /lib/visualization/draw_flow.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import cv2 as cv 4 | import os 5 | from tqdm import tqdm 6 | import matplotlib.pyplot as plt 7 | from mpl_toolkits.mplot3d import Axes3D 8 | 9 | from ..util.event_util import clip_events_to_bounds 10 | from ..util.util import flow2bgr_np 11 | from ..transforms.optic_flow import warp_events_flow_torch 12 | from ..representations.image import events_to_image_torch 13 | from .visualization_utils import * 14 | 15 | def motion_compensate(xs, ys, ts, ps, flow, fname="/tmp/img.png", crop=None): 16 | xs, ys, ts, ps, flow = torch.from_numpy(xs).type(torch.float32), torch.from_numpy(ys).type(torch.float32),\ 17 | torch.from_numpy(ts).type(torch.float32), torch.from_numpy(ps).type(torch.float32), torch.from_numpy(flow).type(torch.float32) 18 | xw, yw = warp_events_flow_torch(xs, ys, ts, ps, flow) 19 | img_size = list(flow.shape) 20 | img_size.remove(2) 21 | img = events_to_image_torch(xw, yw, ps, sensor_size=img_size, interpolation='bilinear') 22 | img = np.flip(np.flip(img.numpy(), axis=0), axis=1) 23 | img = cv.normalize(img, None, 0, 255, cv.NORM_MINMAX) 24 | if crop is not None: 25 | img = img[crop[0]:crop[1], crop[2]: crop[3]] 26 | cv.imwrite(fname, img) 27 | 28 | def plot_flow_and_events(xs, ys, ts, ps, flow, save_path=None, 29 | num_show=1000, event_size=2, elev=0, azim=45, show_events=True, 30 | show_frames=True, show_plot=False, crop=None, 31 | marker='.', stride = 20, img_size=None, show_axes=False, 32 | invert=False): 33 | 34 | print(event_size) 35 | #Crop events 36 | if img_size is None: 37 | img_size = [max(ys), max(xs)] if len(flow)==0 else flow[0].shape[1:3] 38 | crop = [0, img_size[0], 0, img_size[1]] if crop is None else crop 39 | xs, ys = img_size[1]-xs, img_size[0]-ys 40 | xs, ys, ts, ps = clip_events_to_bounds(xs, ys, ts, ps, crop, set_zero=False) 41 | xs -= crop[2] 42 | ys -= crop[0] 43 | img_size = [crop[1]-crop[0], crop[3]-crop[2]] 44 | xs, ys = img_size[1]-xs, img_size[0]-ys 45 | #flow[0] = flow[0][:, crop[0]:crop[1], crop[2]:crop[3]] 46 | flow = flow[0][:, crop[0]:crop[1], crop[2]:crop[3]] 47 | flow = np.flip(np.flip(flow, axis=1), axis=2) 48 | 49 | #Defaults and range checks 50 | num_show = len(xs) if num_show == -1 else num_show 51 | skip = max(len(xs)//num_show, 1) 52 | xs, ys, ts, ps = 
xs[::skip], ys[::skip], ts[::skip], ps[::skip] 53 | 54 | #Prepare the plot, set colors 55 | fig = plt.figure() 56 | ax = fig.add_subplot(111, projection='3d', proj_type = 'ortho') 57 | colors = ['r' if p>0 else ('#00DAFF' if invert else 'b') for p in ps] 58 | 59 | # Plot quivers 60 | f_reshape = flow.transpose(1,2,0) 61 | print(f_reshape.shape) 62 | t_w = ts[-1]-ts[0] 63 | coords, flow_vals, magnitudes = [], [], [] 64 | s = 20 65 | offset = 0 66 | thresh = 0 67 | print(img_size) 68 | for x in np.linspace(offset, img_size[1]-1-offset, s): 69 | for y in np.linspace(offset, img_size[0]-1-offset, s): 70 | ix, iy = int(x), int(y) 71 | flow_v = np.array([f_reshape[iy,ix,0]*t_w, f_reshape[iy,ix,1]*t_w, t_w]) 72 | mag = np.linalg.norm(flow_v) 73 | if mag >= thresh: 74 | flow_vals.append(flow_v) 75 | magnitudes.append(mag) 76 | coords.append([x,y]) 77 | magnitudes = np.array(magnitudes) 78 | max_flow = np.percentile(magnitudes, 99) 79 | 80 | x,y,z,u,v,w = [],[],[],[],[],[] 81 | idx = 0 82 | for coord, flow_vec, mag in zip(coords, flow_vals, magnitudes): 83 | #q_start = [coord[0], ts[0], coord[1]] 84 | rel_len = mag/max_flow 85 | flow_vec = flow_vec*rel_len 86 | x.append(coord[0]) 87 | y.append(0.065) 88 | z.append(coord[1]) 89 | u.append(max(1, flow_vec[0])) 90 | v.append(flow_vec[2]) 91 | w.append(max(1, flow_vec[1])) 92 | ax.quiver(x,y,z,u,v,w,color='c', arrow_length_ratio=0, alpha=0.8) 93 | 94 | img = flow2bgr_np(flow[0, :], flow[1, :]) 95 | img = img/255 96 | 97 | x, y = np.ogrid[0:img.shape[0], 0:img.shape[1]] 98 | ax.plot_surface(y, ts[0], x, rstride=stride, cstride=stride, facecolors=img, alpha=1) 99 | 100 | ax.scatter(xs, ts, ys, zdir='z', c=colors, facecolors=colors, 101 | s=np.ones(xs.shape)*event_size, marker=marker, linewidths=0, alpha=1.0 if show_events else 0) 102 | 103 | ax.view_init(elev=elev, azim=azim) 104 | 105 | ax.grid(False) 106 | # Hide panes 107 | ax.xaxis.pane.fill = False 108 | ax.yaxis.pane.fill = False 109 | ax.zaxis.pane.fill = False 110 | if not show_axes: 111 | # Hide spines 112 | ax.w_xaxis.line.set_color((1.0, 1.0, 1.0, 0.0)) 113 | ax.w_yaxis.line.set_color((1.0, 1.0, 1.0, 0.0)) 114 | ax.w_zaxis.line.set_color((1.0, 1.0, 1.0, 0.0)) 115 | ax.set_frame_on(False) 116 | # Hide xy axes 117 | ax.set_xticks([]) 118 | ax.set_yticks([]) 119 | ax.set_zticks([]) 120 | 121 | ax.xaxis.set_visible(False) 122 | ax.axes.get_yaxis().set_visible(False) 123 | 124 | plt.show() 125 | 126 | 127 | 128 | def plot_between_frames(xs, ys, ts, ps, flows, flow_imgs, flow_ts, args, plttype='voxel'): 129 | args.crop = None if args.crop is None else parse_crop(args.crop) 130 | 131 | flow_event_idx = get_frame_indices(ts, flow_ts) 132 | if len(flow_ts.shape) == 1: 133 | flow_ts = frame_stamps_to_start_end(flow_ts) 134 | flow_event_idx = frame_stamps_to_start_end(flow_event_idx) 135 | prev_idx = 0 136 | for i in range(0, len(flows), args.skip_frames): 137 | if i != 12: 138 | continue 139 | flow = flows[i:i+args.skip_frames] 140 | flow_indices = flow_event_idx[i:i+args.skip_frames] 141 | s, e = flow_indices[-1,0], flow_indices[0,1] 142 | 143 | motion_compensate(xs[s:e], ys[s:e], ts[s:e], ps[s:e], -np.flip(np.flip(flow[0], axis=1), axis=2).copy(), fname="/tmp/comp.png", crop=args.crop) 144 | motion_compensate(xs[s:e], ys[s:e], ts[s:e], ps[s:e], np.zeros_like(flow[0]), fname="/tmp/zero.png", crop=args.crop) 145 | e = np.searchsorted(ts, ts[s]+0.02) 146 | flow_ts = [] 147 | for f_idx in flow_indices: 148 | flow_ts.append(ts[f_idx[1]]) 149 | fname = os.path.join(args.output_path, 
"events_{:09d}.png".format(i)) 150 | 151 | print("se: {}, {}".format(s, e)) 152 | plot_flow_and_events(xs[s:e], ys[s:e], ts[s:e], ps[s:e], flow, 153 | num_show=args.num_show, event_size=args.event_size, elev=args.elev, 154 | azim=args.azim, show_events=not args.hide_events, 155 | show_frames=not args.hide_frames, show_plot=args.show_plot, crop=args.crop, 156 | stride=args.stride, show_axes=args.show_axes, invert=args.invert) 157 | -------------------------------------------------------------------------------- /lib/visualization/utils/draw_plane.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib 3 | import matplotlib.pyplot as plt 4 | from mpl_toolkits.mplot3d import axes3d, Axes3D #<-- Note the capitalization! 5 | 6 | # z = ax + by + d 7 | 8 | x_min = 0 9 | x_max = 100 10 | y_min = 0 11 | y_max = 100 12 | 13 | a = 0 14 | b = 10 15 | d = 10 16 | 17 | num_points = 5000 18 | point_size = 10 19 | 20 | points = np.random.rand(num_points, 3) 21 | points[:, 0] = points[:, 0]*(x_max-x_min) + x_min 22 | points[:, 1] = points[:, 1]*(y_max-y_min) + y_min 23 | points[:, 2] = points[:, 0]*a + points[:, 1]*b + d 24 | 25 | mean = 0 26 | stdev = 10 27 | noise = np.random.normal(mean, stdev, num_points) 28 | points[:, 2] = points[:, 2] + noise 29 | 30 | print(points) 31 | new_points = points[np.where(points[:, 1] < 50)] 32 | print(new_points) 33 | 34 | for x in range(y_min, y_max, 1): 35 | 36 | fig = plt.figure() 37 | ax = Axes3D(fig) 38 | ax.set_xlabel('x') 39 | ax.set_ylabel('y') 40 | ax.set_zlabel('time') 41 | ax.set_ylim([0, 100]) 42 | 43 | new_points = points[np.where(points[:, 1] < x)] 44 | ax.scatter(new_points[:, 0], new_points[:, 1], new_points[:, 2], s=point_size, c=(new_points[:, 2]), 45 | edgecolors='none', cmap='plasma') 46 | ax.scatter(points[:, 0], points[:, 1], points[:, 2], s=0, c=(points[:, 2]), 47 | edgecolors='none', cmap='plasma') 48 | 49 | point = np.array([0, 1, 0]) 50 | normal = np.array([0, 0, 1]) 51 | 52 | # a plane is a*x+b*y+c*z+d=0 53 | # [a,b,c] is the normal. Thus, we have to calculate 54 | # d and we're set 55 | d = -point.dot(normal) 56 | 57 | # create x,y 58 | xx, yy = np.meshgrid(range(100), range(10)) 59 | yy = yy + x - 10 60 | 61 | # calculate corresponding z 62 | z = (-normal[0] * xx - normal[1] * yy - d) * 1. / normal[2] 63 | # plot the surface 64 | # plt3d = plt.figure().gca(projection='3d') 65 | ax.plot_surface(xx, yy, z, alpha=1) 66 | 67 | save_name = ("frame_" + str(x) + ".png") 68 | fig.tight_layout() 69 | fig.savefig(save_name, dpi=300, transparent=True) 70 | 71 | # plt.show() 72 | plt.close() -------------------------------------------------------------------------------- /lib/visualization/utils/draw_plane_simple.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib 3 | import matplotlib.pyplot as plt 4 | from mpl_toolkits.mplot3d import axes3d, Axes3D #<-- Note the capitalization! 
 5 | 
 6 | fig = plt.figure()
 7 | ax = Axes3D(fig)
 8 | ax.set_xlabel('x')
 9 | ax.set_ylabel('y')
10 | ax.set_zlabel('time')
11 | 
12 | # Plane model: z = a*x + b*y + d
13 | x_min = 0
14 | x_max = 10
15 | y_min = 0
16 | y_max = 10
17 | 
18 | a = 0
19 | b = 10
20 | d = 10
21 | 
22 | num_points = 50
23 | point_size = 20
24 | 
25 | points = np.random.rand(num_points, 3)  # random (x, y) samples, lifted onto the plane below
26 | points[:, 0] = points[:, 0]*(x_max-x_min) + x_min
27 | points[:, 1] = points[:, 1]*(y_max-y_min) + y_min
28 | points[:, 2] = points[:, 0]*a + points[:, 1]*b + d
29 | 
30 | mean = 0
31 | stdev = 10
32 | noise = np.random.normal(mean, stdev, num_points)
33 | points[:, 2] = points[:, 2] + noise  # perturb the time coordinate with Gaussian noise
34 | 
35 | ax.scatter(points[:, 0], points[:, 1], points[:, 2], s=point_size, c=(points[:, 2]),
36 |            edgecolors='none', cmap='plasma')
37 | 
38 | 
39 | # create x,y
40 | xx, yy = np.meshgrid(range(10), range(10))
41 | 
42 | # calculate corresponding z
43 | z = xx*a + yy*b + d
44 | # plot the surface
45 | ax.plot_surface(xx, yy, z, alpha=0.2)
46 | 
47 | save_name = "plane.png"
48 | fig.tight_layout()
49 | fig.savefig(save_name, dpi=600, transparent=True)
50 | plt.close()
--------------------------------------------------------------------------------
/lib/visualization/visualization_utils.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | import os
 3 | 
 4 | def frame_stamps_to_start_end(stamps):
 5 |     ends = list(stamps[1:])
 6 |     ends.append(ends[-1])
 7 |     se_stamps = np.stack((stamps, np.array(ends)), axis=1)
 8 |     return se_stamps
 9 | 
10 | def get_frame_indices(ts, frame_ts):
11 |     indices = [np.searchsorted(ts, fts) for fts in frame_ts]
12 |     return np.array(indices)
13 | 
14 | def crop_to_size(crop):
15 |     return [crop[1]-crop[0], crop[3]-crop[2]]  # crop is [miny, maxy, minx, maxx] -> [height, width]
16 | 
17 | def parse_crop(cropstr):
18 |     """
19 |     Crop is provided as string, same as imagemagick:
20 |     size_x, size_y, offset_x, offset_y, eg 10x10+30+30 would cut a 10x10 square at 30,30
21 |     Output is the indices as would be used in a numpy array.
In the example, 22 | [30,40,30,40] (ie [miny, maxy, minx, maxx]) 23 | 24 | """ 25 | split = cropstr.split("x") 26 | xsize = int(split[0]) 27 | split = split[1].split("+") 28 | ysize = int(split[0]) 29 | xoff = int(split[1]) 30 | yoff = int(split[2]) 31 | crop = [yoff, yoff+ysize, xoff, xoff+xsize] 32 | return crop 33 | 34 | def ensure_dir(file_path): 35 | directory = os.path.dirname(file_path) 36 | if not os.path.exists(directory): 37 | print(f"Creating {directory}") 38 | os.makedirs(directory) 39 | 40 | -------------------------------------------------------------------------------- /lib/visualization/visualizers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numpy.lib.recfunctions as nlr 3 | import cv2 as cv 4 | import colorsys 5 | from skimage.measure import block_reduce 6 | import os 7 | import matplotlib.pyplot as plt 8 | from mpl_toolkits.mplot3d import Axes3D 9 | 10 | from ..representations.image import events_to_image, TimestampImage 11 | from ..representations.voxel_grid import events_to_voxel 12 | from ..util.event_util import clip_events_to_bounds 13 | from .visualization_utils import * 14 | from tqdm import tqdm 15 | 16 | class Visualizer(): 17 | 18 | def __init__(self): 19 | raise NotImplementedError 20 | 21 | def plot_events(self, data, save_path, **kwargs): 22 | raise NotImplementedError 23 | 24 | @staticmethod 25 | def unpackage_events(events): 26 | return events[:,0].astype(int), events[:,1].astype(int), events[:,2], events[:,3] 27 | 28 | class TimeStampImageVisualizer(Visualizer): 29 | 30 | def __init__(self, sensor_size): 31 | self.ts_img = TimestampImage(sensor_size) 32 | self.sensor_size = sensor_size 33 | 34 | def plot_events(self, data, save_path, **kwargs): 35 | xs, ys, ts, ps = self.unpackage_events(data['events']) 36 | self.ts_img.set_init(ts[0]) 37 | self.ts_img.add_events(xs, ys, ts, ps) 38 | timestamp_image = self.ts_img.get_image() 39 | fig = plt.figure() 40 | plt.imshow(timestamp_image, cmap='viridis') 41 | ensure_dir(save_path) 42 | plt.savefig(save_path, transparent=True, dpi=600, bbox_inches = 'tight') 43 | #plt.show() 44 | 45 | class EventImageVisualizer(Visualizer): 46 | 47 | def __init__(self, sensor_size): 48 | self.sensor_size = sensor_size 49 | 50 | def plot_events(self, data, save_path, **kwargs): 51 | xs, ys, ts, ps = self.unpackage_events(data['events']) 52 | img = events_to_image(xs.astype(int), ys.astype(int), ps, self.sensor_size, interpolation=None, padding=False) 53 | mn, mx = np.min(img), np.max(img) 54 | img = (img-mn)/(mx-mn) 55 | 56 | fig = plt.figure() 57 | plt.imshow(img, cmap='gray') 58 | ensure_dir(save_path) 59 | plt.savefig(save_path, transparent=True, dpi=600, bbox_inches = 'tight') 60 | #plt.show() 61 | 62 | 63 | class EventsVisualizer(Visualizer): 64 | 65 | def __init__(self, sensor_size): 66 | self.sensor_size = sensor_size 67 | 68 | def plot_events(self, data, save_path, 69 | num_compress='auto', num_show=1000, 70 | event_size=2, elev=0, azim=45, show_events=True, 71 | show_frames=True, show_plot=False, crop=None, compress_front=False, 72 | marker='.', stride = 1, invert=False, show_axes=False, flip_x=False): 73 | """ 74 | Given events, plot these in a spatiotemporal volume. 
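
        A minimal usage sketch (hypothetical sensor size, save path and data;
        'events' is an Nx4 array with columns x, y, t, p):

            viz = EventsVisualizer(sensor_size=(180, 240))
            viz.plot_events({'events': events, 'frame': frame, 'frame_ts': frame_t},
                            '/tmp/events.png', num_show=5000, elev=20, azim=45)
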
 75 |         :param: data dict containing the events as an Nx4 array (x, y, t, p
 76 |             columns) under the key 'events', as well as frames under 'frame'
 77 |             and the frame timestamps under 'frame_ts'
 78 |         :param: save_path if set, will save the plot to here
 79 |         :param: num_compress will take this number of events from the end
 80 |             and create an event image from these. This event image will
 81 |             be displayed at the end of the spatiotemporal volume
 82 |         :param: num_show sets the number of events to plot. If set to -1
 83 |             will plot all of the events (can be potentially expensive)
 84 |         :param: event_size sets the size of the plotted events
 85 |         :param: elev sets the elevation of the plot
 86 |         :param: azim sets the azimuth of the plot
 87 |         :param: show_events if False, will not plot the events (only images)
 88 |         :param: show_frames if False, will not plot the frames (only events)
 89 |         :param: crop a list of length 4 that sets the crop of the plot (must
 90 |             be in the format [min_y, max_y, min_x, max_x], as produced by parse_crop)
 91 |         :param: invert swaps the polarity colors for plots on dark backgrounds
 92 |         :param: flip_x if True, mirrors events and frames horizontally
 93 | 
 94 |         """
 95 |         xs, ys, ts, ps = self.unpackage_events(data['events'])
 96 |         imgs, img_ts = data['frame'], data['frame_ts']
 97 |         if not (isinstance(imgs, list) or isinstance(imgs, tuple)):
 98 |             imgs, img_ts = [imgs], [img_ts]
 99 | 
100 |         ys = self.sensor_size[0]-ys
101 |         xs = self.sensor_size[1]-xs if flip_x else xs
102 |         # Crop events
103 |         img_size = self.sensor_size
104 |         if img_size is None:
105 |             img_size = [max(ys), max(xs)] if len(imgs)==0 else imgs[0].shape[0:2]
106 |         crop = [0, img_size[0], 0, img_size[1]] if crop is None else crop
107 |         xs, ys, ts, ps = clip_events_to_bounds(xs, ys, ts, ps, crop, set_zero=False)
108 |         xs, ys = xs-crop[2], ys-crop[0]
109 | 
110 |         if len(xs) < 2:
111 |             xs = np.array([0,0])
112 |             ys = np.array([0,0])
113 |             if img_ts is None:
114 |                 ts = np.array([0,0])
115 |             else:
116 |                 ts = np.array([img_ts[0], img_ts[0]+0.000001])
117 |             ps = np.array([0.,0.])
118 | 
119 |         # Defaults and range checks
120 |         num_show = len(xs) if num_show == -1 else num_show
121 |         skip = max(len(xs)//num_show, 1)
122 |         num_compress = len(xs) if num_compress == 'all' else num_compress
123 |         num_compress = min(int(img_size[0]*img_size[1]*0.5), len(xs)) if num_compress == 'auto' else (0 if num_compress == 'none' else num_compress)
124 |         xs, ys, ts, ps = xs[::skip], ys[::skip], ts[::skip], ps[::skip]
125 | 
126 |         # Prepare the plot, set colors
127 |         fig = plt.figure()
128 |         ax = fig.add_subplot(111, projection='3d', proj_type='ortho')
129 |         colors = ['r' if p>0 else ('#00DAFF' if invert else 'b') for p in ps]
130 | 
131 |         # Plot images
132 |         if len(imgs)>0 and show_frames:
133 |             for imgidx, (img, img_ts) in enumerate(zip(imgs, img_ts)):
134 |                 img = img[crop[0]:crop[1], crop[2]:crop[3]].astype(float)
135 |                 img = np.flip(img, axis=0)
136 |                 img = np.flip(img, axis=1) if flip_x else img
137 |                 if len(img.shape)==2:
138 |                     img = np.stack((img, img, img), axis=2)
139 |                 if num_compress > 0:
140 |                     events_img = events_to_image(xs[0:num_compress], ys[0:num_compress],
141 |                             np.ones(min(num_compress, len(xs))), sensor_size=img.shape[0:2])
142 |                     events_img[events_img>0] = 1
143 |                     img[:,:,1] += events_img[:,:]
144 |                     img = np.clip(img, 0, 1)
145 |                 x, y = np.ogrid[0:img.shape[0], 0:img.shape[1]]
146 |                 event_idx = np.searchsorted(ts, img_ts)
147 | 
148 |                 ax.scatter(xs[0:event_idx], ts[0:event_idx], ys[0:event_idx], zdir='z',
149 |                         c=colors[0:event_idx], facecolors=colors[0:event_idx],
150 |                         s=event_size, marker=marker, linewidths=0, alpha=1.0 if show_events else 0)
151 | 
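                # The frame is drawn as a textured surface at its timestamp; events
                # before it were drawn above and the remaining events are drawn after,
                # so matplotlib's painter-style compositing approximates the correct
                # occlusion (mplot3d has no true depth buffer).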
152 |                 img /= 255.0
153 |                 #img = cv.normalize(img, None, 0, 1, cv.NORM_MINMAX)
154 |                 ax.plot_surface(y, img_ts, x, rstride=stride, cstride=stride, facecolors=img, alpha=1)
155 | 
156 |                 ax.scatter(xs[event_idx:], ts[event_idx:], ys[event_idx:], zdir='z',
157 |                         c=colors[event_idx:], facecolors=colors[event_idx:],
158 |                         s=event_size, marker=marker, linewidths=0, alpha=1.0 if show_events else 0)
159 | 
160 |         elif num_compress > 0:
161 |             # Plot events (already decimated by 'skip' above)
162 |             ax.scatter(xs, ts, ys, zdir='z', c=colors, facecolors=colors,
163 |                     s=np.ones(xs.shape)*event_size, marker=marker, linewidths=0, alpha=1.0 if show_events else 0)
164 |             num_compress = min(num_compress, len(xs))
165 |             if not compress_front:
166 |                 ax.scatter(xs[0:num_compress], np.ones(num_compress)*ts[0], ys[0:num_compress],
167 |                         marker=marker, zdir='z', c='w' if invert else 'k', s=np.ones(num_compress)*event_size)
168 |             else:
169 |                 ax.scatter(xs[-num_compress-1:-1], np.ones(num_compress)*ts[-1], ys[-num_compress-1:-1],
170 |                         marker=marker, zdir='z', c='w' if invert else 'k', s=np.ones(num_compress)*event_size)
171 |         else:
172 |             # Plot events
173 |             ax.scatter(xs, ts, ys, zdir='z', c=colors, facecolors=colors, s=np.ones(xs.shape)*event_size, marker=marker, linewidths=0, alpha=1.0 if show_events else 0)
174 | 
175 |         ax.view_init(elev=elev, azim=azim)
176 |         ax.grid(False)
177 |         # Hide panes
178 |         ax.xaxis.pane.fill = False
179 |         ax.yaxis.pane.fill = False
180 |         ax.zaxis.pane.fill = False
181 |         if not show_axes:
182 |             # Hide spines
183 |             ax.w_xaxis.line.set_color((1.0, 1.0, 1.0, 0.0))
184 |             ax.w_yaxis.line.set_color((1.0, 1.0, 1.0, 0.0))
185 |             ax.w_zaxis.line.set_color((1.0, 1.0, 1.0, 0.0))
186 |             ax.set_frame_on(False)
187 |         # Hide xy axes
188 |         ax.set_xticks([])
189 |         ax.set_yticks([])
190 |         ax.set_zticks([])
191 |         # Flush axes
192 |         ax.set_xlim3d(0, img_size[1])
193 |         ax.set_ylim3d(ts[0], ts[-1])
194 |         ax.set_zlim3d(0, img_size[0])
195 |         #ax.xaxis.set_visible(False)
196 |         #ax.axes.get_yaxis().set_visible(False)
197 | 
198 |         if show_plot:
199 |             plt.show()
200 |         if save_path is not None:
201 |             ensure_dir(save_path)
202 |             print("Saving to {}".format(save_path))
203 |             plt.savefig(save_path, transparent=True, dpi=600, bbox_inches='tight')
204 |         plt.close()
205 | 
206 | class VoxelVisualizer(Visualizer):
207 | 
208 |     def __init__(self, sensor_size):
209 |         self.sensor_size = sensor_size
210 | 
211 |     @staticmethod
212 |     def increase_brightness(rgb, increase=0.5):
213 |         rgb = (rgb*255).astype('uint8')
214 |         channels = rgb.shape[1]
215 |         hsv = (np.stack([cv.cvtColor(rgb[:,x,:,:], cv.COLOR_RGB2HSV) for x in range(channels)])).astype(float)
216 |         hsv[:,:,:,2] = np.clip(hsv[:,:,:,2] + increase*255, 0, 255)
217 |         hsv = hsv.astype('uint8')
218 |         rgb_new = np.stack([cv.cvtColor(hsv[x,:,:,:], cv.COLOR_HSV2RGB) for x in range(channels)])
219 |         rgb_new = (rgb_new.transpose(1,0,2,3)).astype(float)
220 |         return rgb_new/255.0
221 | 
222 |     def plot_events(self, data, save_path, bins=5, crop=None, elev=0, azim=45, show_axes=False,
223 |                     show_plot=False, flip_x=False, size_reduction=10):
224 | 
225 |         xs, ys, ts, ps = self.unpackage_events(data['events'])
226 |         if len(xs) < 2:
227 |             return
228 |         ys = self.sensor_size[0]-ys
229 |         xs = self.sensor_size[1]-xs if flip_x else xs
230 | 
231 |         frames, frame_ts = data['frame'], data['frame_ts']
232 |         if not isinstance(frames, list):
233 |             frames, frame_ts = [frames], [frame_ts]
234 | 
235 |         if self.sensor_size is None:
236 |             self.sensor_size = [np.max(ys)+1, np.max(xs)+1] if len(frames)==0
else frames[0].shape 237 | if crop is not None: 238 | xs, ys, ts, ps = clip_events_to_bounds(xs, ys, ts, ps, crop) 239 | self.sensor_size = crop_to_size(crop) 240 | xs, ys = xs-crop[2], ys-crop[0] 241 | num = 10000 242 | xs, ys, ts, ps = xs[0:num], ys[0:num], ts[0:num], ps[0:num] 243 | if len(xs) == 0: 244 | return 245 | voxels = events_to_voxel(xs, ys, ts, ps, bins, sensor_size=self.sensor_size) 246 | voxels = block_reduce(voxels, block_size=(1,size_reduction,size_reduction), func=np.mean, cval=0) 247 | dimdiff = voxels.shape[1]-voxels.shape[0] 248 | filler = np.zeros((dimdiff, *voxels.shape[1:])) 249 | voxels = np.concatenate((filler, voxels), axis=0) 250 | voxels = voxels.transpose(0,2,1) 251 | 252 | pltvoxels = voxels != 0 253 | pvp, nvp = voxels > 0, voxels < 0 254 | rng = 0.2 255 | min_r, min_b, max_g = 80/255.0, 80/255.0, 0/255.0 256 | 257 | vox_cols = voxels/(max(np.abs(np.max(voxels)), np.abs(np.min(voxels)))) 258 | pvox, nvox = vox_cols*np.where(vox_cols > 0, 1, 0), np.abs(vox_cols)*np.where(vox_cols < 0, 1, 0) 259 | pvox, nvox = pvox*(1-min_r)+min_r, nvox*(1-min_b)+min_b 260 | zeros = np.zeros_like(voxels) 261 | 262 | colors = np.empty(voxels.shape, dtype=object) 263 | 264 | increase = 0.5 265 | redvals = np.stack((pvox, (1.0-pvox)*max_g, pvox-min_r), axis=3) 266 | redvals = self.increase_brightness(redvals, increase=increase) 267 | redvals = nlr.unstructured_to_structured(redvals).astype('O') 268 | 269 | bluvals = np.stack((nvox-min_b, (1.0-nvox)*max_g, nvox), axis=3) 270 | bluvals = self.increase_brightness(bluvals, increase=increase) 271 | bluvals = nlr.unstructured_to_structured(bluvals).astype('O') 272 | 273 | colors[pvp] = redvals[pvp] 274 | colors[nvp] = bluvals[nvp] 275 | 276 | fig = plt.figure() 277 | ax = fig.gca(projection='3d') 278 | ax.voxels(pltvoxels, facecolors=colors) 279 | ax.view_init(elev=elev, azim=azim) 280 | 281 | ax.grid(False) 282 | # Hide panes 283 | ax.xaxis.pane.fill = False 284 | ax.yaxis.pane.fill = False 285 | ax.zaxis.pane.fill = False 286 | if not show_axes: 287 | # Hide spines 288 | ax.w_xaxis.line.set_color((1.0, 1.0, 1.0, 0.0)) 289 | ax.w_yaxis.line.set_color((1.0, 1.0, 1.0, 0.0)) 290 | ax.w_zaxis.line.set_color((1.0, 1.0, 1.0, 0.0)) 291 | ax.set_frame_on(False) 292 | # Hide xy axes 293 | ax.set_xticks([]) 294 | ax.set_yticks([]) 295 | ax.set_zticks([]) 296 | 297 | ax.xaxis.set_visible(False) 298 | ax.axes.get_yaxis().set_visible(False) 299 | 300 | if show_plot: 301 | plt.show() 302 | if save_path is not None: 303 | ensure_dir(save_path) 304 | print("Saving to {}".format(save_path)) 305 | plt.savefig(save_path, transparent=True, dpi=600, bbox_inches = 'tight') 306 | plt.close() 307 | -------------------------------------------------------------------------------- /lib/visualization/visualizers_mayavi.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TimoStoff/event_utils/dc0a0712156bb0c3659d90b33e211fa58a83a75f/lib/visualization/visualizers_mayavi.py -------------------------------------------------------------------------------- /visualize.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from tqdm import tqdm 4 | import numpy as np 5 | from lib.data_formats.read_events import read_memmap_events, read_h5_events_dict 6 | from lib.data_loaders import MemMapDataset, DynamicH5Dataset, NpyDataset 7 | from lib.visualization.visualizers import TimeStampImageVisualizer, EventImageVisualizer, \ 8 | 
EventsVisualizer, VoxelVisualizer 9 | 10 | if __name__ == "__main__": 11 | """ 12 | Quick demo 13 | """ 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("path", help="memmap events path") 16 | parser.add_argument("--output_path", type=str, default="/tmp/visualization", help="Where to save image outputs") 17 | parser.add_argument("--filetype", type=str, default="png", help="Which filetype to save as", choices=["png", "jpg", "pdf"]) 18 | 19 | parser.add_argument('--plot_method', default='between_frames', type=str, 20 | help='which method should be used to visualize', 21 | choices=['between_frames', 'k_events', 't_seconds', 'fixed_frames']) 22 | parser.add_argument('--w_width', type=float, default=0.01, 23 | help='new plot is formed every t seconds/k events (required if voxel_method is t_seconds)') 24 | parser.add_argument('--sw_width', type=float, 25 | help='sliding_window size in seconds/events (required if voxel_method is t_seconds)') 26 | parser.add_argument('--num_frames', type=int, default=100, help='if fixed_frames chosen as voxel method, sets the number of frames') 27 | 28 | parser.add_argument('--visualization', type=str, default='events', choices=['events', 'voxels', 'event_image', 'ts_image']) 29 | 30 | parser.add_argument("--num_bins", type=int, default=6, help="How many bins voxels should have.") 31 | 32 | parser.add_argument('--show_plot', action='store_true', help='If true, will also display the plot in an interactive window.\ 33 | Useful for selecting the desired orientation.') 34 | 35 | parser.add_argument("--num_show", type=int, default=-1, help="How many events to show per plot. If -1, show all events.") 36 | parser.add_argument("--event_size", type=float, default=2, help="Marker size of the plotted events") 37 | parser.add_argument("--ts_scale", type=int, default=10000, help="Scales the time axis. Only applicable for mayavi rendering.") 38 | parser.add_argument("--elev", type=float, default=0, help="Elevation of plot") 39 | parser.add_argument("--azim", type=float, default=45, help="Azimuth of plot") 40 | parser.add_argument("--stride", type=int, default=1, help="Downsample stride for plotted images.") 41 | parser.add_argument("--skip_frames", type=int, default=1, help="Amount of frames to place per plot.") 42 | parser.add_argument("--start_frame", type=int, default=0, help="On which frame to start.") 43 | parser.add_argument('--hide_skipped', action='store_true', help='Do not draw skipped frames into plot.') 44 | parser.add_argument('--hide_events', action='store_true', help='Do not draw events') 45 | parser.add_argument('--hide_frames', action='store_true', help='Do not draw frames') 46 | parser.add_argument('--show_axes', action='store_true', help='Draw axes') 47 | parser.add_argument('--flip_x', action='store_true', help='Flip in the x axis') 48 | parser.add_argument("--num_compress", type=str, default='auto', help="How many events to draw compressed. If 'auto'\ 49 | will automatically determine.", choices=['auto', 'none', 'all']) 50 | parser.add_argument('--compress_front', action='store_true', help='If set, will put the compressed events at the _start_\ 51 | of the event volume, rather than the back.') 52 | parser.add_argument('--invert', action='store_true', help='If the figure is for a black background, you can invert the \ 53 | colors for better visibility.') 54 | parser.add_argument("--crop", type=str, default=None, help="Set a crop of both images and events. 
Uses 'imagemagick' \ 55 | syntax, eg for a crop of 10x20 starting from point 30,40 use: 10x20+30+40.") 56 | parser.add_argument("--renderer", type=str, default="matplotlib", help="Which renderer to use (mayavi is faster)", choices=["matplotlib", "mayavi"]) 57 | args = parser.parse_args() 58 | if not os.path.exists(args.output_path): 59 | os.makedirs(args.output_path) 60 | 61 | if os.path.isdir(args.path): 62 | loader_type = MemMapDataset 63 | elif os.path.splitext(args.path)[1] == ".npy": 64 | loader_type = NpyDataset 65 | else: 66 | loader_type = DynamicH5Dataset 67 | dataloader = loader_type(args.path, voxel_method={'method':args.plot_method, 't':args.w_width, 68 | 'k':args.w_width, 'sliding_window_t':args.sw_width, 'sliding_window_w':args.sw_width, 'num_frames':args.num_frames}, 69 | return_events=True, return_voxelgrid=False, return_frame=True, return_flow=True, return_format='numpy') 70 | sensor_size = dataloader.size() 71 | 72 | if args.visualization == 'events': 73 | kwargs = {'num_compress':args.num_compress, 'num_show':args.num_show, 'event_size':args.event_size, 74 | 'elev':args.elev, 'azim':args.azim, 'show_events':not args.hide_events, 75 | 'show_frames':not args.hide_frames, 'show_plot':args.show_plot, 'crop':args.crop, 76 | 'compress_front':args.compress_front, 'marker':'.', 'stride':args.stride, 77 | 'invert':args.invert, 'show_axes':args.show_axes, 'flip_x':args.flip_x} 78 | visualizer = EventsVisualizer(sensor_size) 79 | elif args.visualization == 'voxels': 80 | kwargs = {'bins':args.num_bins, 'crop':args.crop, 'elev':args.elev, 'azim':args.azim, 81 | 'show_axes':args.show_axes, 'show_plot':args.show_plot, 'flip_x':args.flip_x} 82 | visualizer = VoxelVisualizer(sensor_size) 83 | elif args.visualization == 'event_image': 84 | kwargs = {} 85 | visualizer = EventImageVisualizer(sensor_size) 86 | elif args.visualization == 'ts_image': 87 | kwargs = {} 88 | visualizer = TimeStampImageVisualizer(sensor_size) 89 | else: 90 | raise Exception("Unknown visualization chosen: {}".format(args.visualization)) 91 | 92 | plot_data = {'events':np.ones((0, 4)), 'frame':[], 'frame_ts':[]} 93 | print("{} frames in sequence".format(len(dataloader))) 94 | for i, data in enumerate(tqdm(dataloader)): 95 | plot_data['events'] = np.concatenate((plot_data['events'], data['events'])) 96 | if args.plot_method == 'between_frames': 97 | plot_data['frame'].append(data['frame']) 98 | plot_data['frame_ts'].append(data['frame_ts']) 99 | else: 100 | plot_data['frame'] = data['frame'] 101 | plot_data['frame_ts'] = data['frame_ts'] 102 | 103 | output_path = os.path.join(args.output_path, "frame_{:010d}.{}".format(i, args.filetype)) 104 | if i%args.skip_frames == 0: 105 | visualizer.plot_events(plot_data, output_path, **kwargs) 106 | plot_data = {'events':np.ones((0, 4)), 'frame':[], 'frame_ts':[]} 107 | 108 | #if args.plot_method == 'between_frames': 109 | # if args.renderer == "mayavi": 110 | # from lib.visualization.draw_event_stream_mayavi import plot_between_frames 111 | # plot_between_frames(xs, ys, ts, ps, frames, frame_idx, args, plttype='events') 112 | # elif args.renderer == "matplotlib": 113 | # from lib.visualization.draw_event_stream import plot_between_frames 114 | # plot_between_frames(xs, ys, ts, ps, frames, frame_idx, args, plttype='events') 115 | #elif args.plot_method == 'k_events': 116 | # print(args.renderer) 117 | # pass 118 | #elif args.plot_method == 't_seconds': 119 | # if args.renderer == "mayavi": 120 | # from lib.visualization.draw_event_stream_mayavi import plot_events_sliding 121 | 
#            plot_events_sliding(xs, ys, ts, ps, args, dt=args.w_width, sdt=args.sw_width, frames=frames, frame_ts=frame_ts)
122 |     #    elif args.renderer == "matplotlib":
123 |     #        from lib.visualization.draw_event_stream import plot_events_sliding
124 |     #        plot_events_sliding(xs, ys, ts, ps, args, frames=frames, frame_ts=frame_ts)
125 | 
--------------------------------------------------------------------------------
/visualize_events.py:
--------------------------------------------------------------------------------
 1 | import argparse
 2 | import os
 3 | import numpy as np
 4 | from lib.data_formats.read_events import read_memmap_events, read_h5_events_dict
 5 | 
 6 | if __name__ == "__main__":
 7 |     """
 8 |     Quick demo
 9 |     """
10 |     parser = argparse.ArgumentParser()
11 |     parser.add_argument("path", help="events path (memmap directory or HDF5 file)")
12 |     parser.add_argument("--output_path", type=str, default="/tmp/visualization", help="Where to save image outputs")
13 | 
14 |     parser.add_argument('--plot_method', default='between_frames', type=str,
15 |                         help='which method should be used to visualize',
16 |                         choices=['between_frames', 'k_events', 't_seconds'])
17 |     parser.add_argument('--w_width', type=float, default=0.01,
18 |                         help='new plot is formed every t seconds (required if voxel_method is t_seconds)')
19 |     parser.add_argument('--sw_width', type=float,
20 |                         help='sliding_window size in seconds (required if voxel_method is t_seconds)')
21 | 
22 |     parser.add_argument("--num_bins", type=int, default=6, help="How many bins voxels should have.")
23 | 
24 |     parser.add_argument('--show_plot', action='store_true', help='If true, will also display the plot in an interactive window.\
25 |                         Useful for selecting the desired orientation.')
26 | 
27 |     parser.add_argument("--num_show", type=int, default=-1, help="How many events to show per plot. If -1, show all events.")
28 |     parser.add_argument("--event_size", type=float, default=2, help="Marker size of the plotted events")
29 |     parser.add_argument("--ts_scale", type=int, default=10000, help="Scales the time axis. Only applicable for mayavi rendering.")
30 |     parser.add_argument("--elev", type=float, default=20, help="Elevation of plot")
31 |     parser.add_argument("--azim", type=float, default=45, help="Azimuth of plot")
32 |     parser.add_argument("--stride", type=int, default=1, help="Downsample stride for plotted images.")
33 |     parser.add_argument("--skip_frames", type=int, default=1, help="Amount of frames to place per plot.")
34 |     parser.add_argument("--start_frame", type=int, default=0, help="On which frame to start.")
35 |     parser.add_argument('--hide_skipped', action='store_true', help='Do not draw skipped frames into plot.')
36 |     parser.add_argument('--hide_events', action='store_true', help='Do not draw events')
37 |     parser.add_argument('--hide_frames', action='store_true', help='Do not draw frames')
38 |     parser.add_argument('--show_axes', action='store_true', help='Draw axes')
39 |     parser.add_argument("--num_compress", type=int, default=0, help="How many events to draw compressed\
40 |                         at the end of the event volume (0 to disable).")
41 |     parser.add_argument('--compress_front', action='store_true', help='If set, will put the compressed events at the _start_\
42 |                         of the event volume, rather than the back.')
43 |     parser.add_argument('--invert', action='store_true', help='If the figure is for a black background, you can invert the \
44 |                         colors for better visibility.')
45 |     parser.add_argument("--crop", type=str, default=None, help="Set a crop of both images and events.
Uses 'imagemagick' \ 46 | syntax, eg for a crop of 10x20 starting from point 30,40 use: 10x20+30+40.") 47 | parser.add_argument("--renderer", type=str, default="matplotlib", help="Which renderer to use (mayavi is faster)", choices=["matplotlib", "mayavi"]) 48 | args = parser.parse_args() 49 | 50 | if os.path.isdir(args.path): 51 | events = read_memmap_events(args.path) 52 | 53 | ts = events['t'][:].squeeze() 54 | t0 = ts[0] 55 | ts = ts-t0 56 | frames = (events['images'][args.start_frame+1::])/255 57 | frame_idx = events['index'][args.start_frame::] 58 | frame_ts = events['frame_stamps'][args.start_frame+1::]-t0 59 | 60 | start_idx = np.searchsorted(ts, frame_ts[0]) 61 | print("Starting from frame {}, event {}".format(args.start_frame, start_idx)) 62 | 63 | xs = events['xy'][:,0] 64 | ys = events['xy'][:,1] 65 | ts = ts[:] 66 | ps = events['p'][:] 67 | 68 | print("Have {} frames".format(frames.shape)) 69 | else: 70 | events = read_h5_events_dict(args.path) 71 | xs = events['xs'] 72 | ys = events['ys'] 73 | ts = events['ts'] 74 | ps = events['ps'] 75 | t0 = ts[0] 76 | ts = ts-t0 77 | frames = [np.flip(np.flip(x/255., axis=0), axis=1) for x in events['frames']] 78 | frame_ts = events['frame_timestamps'][1:]-t0 79 | frame_end = events['frame_event_indices'][1:] 80 | frame_start = np.concatenate((np.array([0]), frame_end)) 81 | frame_idx = np.stack((frame_end, frame_start[0:-1]), axis=1) 82 | ys = frames[0].shape[0]-ys 83 | xs = frames[0].shape[1]-xs 84 | 85 | if args.plot_method == 'between_frames': 86 | if args.renderer == "mayavi": 87 | from lib.visualization.draw_event_stream_mayavi import plot_between_frames 88 | plot_between_frames(xs, ys, ts, ps, frames, frame_idx, args, plttype='events') 89 | elif args.renderer == "matplotlib": 90 | from lib.visualization.draw_event_stream import plot_between_frames 91 | plot_between_frames(xs, ys, ts, ps, frames, frame_idx, args, plttype='events') 92 | elif args.plot_method == 'k_events': 93 | print(args.renderer) 94 | pass 95 | elif args.plot_method == 't_seconds': 96 | if args.renderer == "mayavi": 97 | from lib.visualization.draw_event_stream_mayavi import plot_events_sliding 98 | plot_events_sliding(xs, ys, ts, ps, args, dt=args.w_width, sdt=args.sw_width, frames=frames, frame_ts=frame_ts) 99 | elif args.renderer == "matplotlib": 100 | from lib.visualization.draw_event_stream import plot_events_sliding 101 | plot_events_sliding(xs, ys, ts, ps, args, frames=frames, frame_ts=frame_ts) 102 | -------------------------------------------------------------------------------- /visualize_flow.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pandas as pd 4 | import glob 5 | import numpy as np 6 | import cv2 as cv 7 | from lib.data_formats.read_events import read_memmap_events, read_h5_events_dict 8 | 9 | if __name__ == "__main__": 10 | """ 11 | Quick demo 12 | """ 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("path", help="events path") 15 | parser.add_argument("flow_path", help="flow path") 16 | parser.add_argument("--output_path", type=str, default="/tmp/visualization", help="Where to save image outputs") 17 | 18 | parser.add_argument('--plot_method', default='between_frames', type=str, 19 | help='which method should be used to visualize', 20 | choices=['between_frames', 'k_events', 't_seconds']) 21 | parser.add_argument('--w_width', type=float, default=0.01, 22 | help='new plot is formed every t seconds (required if voxel_method is t_seconds)') 23 | 
parser.add_argument('--sw_width', type=float,
24 |                         help='sliding_window size in seconds (required if voxel_method is t_seconds)')
25 | 
26 |     parser.add_argument("--num_bins", type=int, default=6, help="How many bins voxels should have.")
27 | 
28 |     parser.add_argument('--show_plot', action='store_true', help='If true, will also display the plot in an interactive window.\
29 |                         Useful for selecting the desired orientation.')
30 | 
31 |     parser.add_argument("--num_show", type=int, default=-1, help="How many events to show per plot. If -1, show all events.")
32 |     parser.add_argument("--event_size", type=float, default=2, help="Marker size of the plotted events")
33 |     parser.add_argument("--ts_scale", type=int, default=10000, help="Scales the time axis. Only applicable for mayavi rendering.")
34 |     parser.add_argument("--elev", type=float, default=0, help="Elevation of plot")
35 |     parser.add_argument("--azim", type=float, default=45, help="Azimuth of plot")
36 |     parser.add_argument("--stride", type=int, default=1, help="Downsample stride for plotted images.")
37 |     parser.add_argument("--skip_frames", type=int, default=1, help="Amount of frames to place per plot.")
38 |     parser.add_argument("--start_frame", type=int, default=0, help="On which frame to start.")
39 |     parser.add_argument('--hide_skipped', action='store_true', help='Do not draw skipped frames into plot.')
40 |     parser.add_argument('--hide_events', action='store_true', help='Do not draw events')
41 |     parser.add_argument('--hide_frames', action='store_true', help='Do not draw frames')
42 |     parser.add_argument('--show_axes', action='store_true', help='Draw axes')
43 |     parser.add_argument("--num_compress", type=int, default=0, help="How many events to draw compressed\
44 |                         at the end of the event volume (0 to disable).")
45 |     parser.add_argument('--compress_front', action='store_true', help='If set, will put the compressed events at the _start_\
46 |                         of the event volume, rather than the back.')
47 |     parser.add_argument('--invert', action='store_true', help='If the figure is for a black background, you can invert the \
48 |                         colors for better visibility.')
49 |     parser.add_argument("--crop", type=str, default=None, help="Set a crop of both images and events. Uses 'imagemagick' \
Uses 'imagemagick' \ 50 | syntax, eg for a crop of 10x20 starting from point 30,40 use: 10x20+30+40.") 51 | parser.add_argument("--renderer", type=str, default="matplotlib", help="Which renderer to use (mayavi is faster)", choices=["matplotlib", "mayavi"]) 52 | args = parser.parse_args() 53 | 54 | events = read_h5_events_dict(args.path) 55 | xs = events['xs'] 56 | ys = events['ys'] 57 | ts = events['ts'] 58 | ps = events['ps'] 59 | t0 = ts[0] 60 | ts = ts-t0 61 | frames = [np.flip(np.flip(x/255., axis=0), axis=1) for x in events['frames']] 62 | frame_ts = events['frame_timestamps'][1:]-t0 63 | frame_end = events['frame_event_indices'][1:] 64 | frame_start = np.concatenate((np.array([0]), frame_end)) 65 | frame_idx = np.stack((frame_end, frame_start[0:-1]), axis=1) 66 | ys = frames[0].shape[0]-ys 67 | xs = frames[0].shape[1]-xs 68 | 69 | flow_paths = sorted(glob.glob(os.path.join(args.flow_path, "*.npy"))) 70 | flow_img_paths = sorted(glob.glob(os.path.join(args.flow_path, "*.png"))) 71 | flow_ts = pd.read_csv(os.path.join(args.flow_path, "timestamps.txt"), delimiter=" ", names=["fname", "timestamp"]) 72 | flow_ts = np.array(flow_ts["timestamp"]) 73 | 74 | #flows = [-np.flip(np.flip(np.load(fp), axis=1), axis=2) for fp in flow_paths] 75 | flows = [-np.load(fp) for fp in flow_paths] 76 | flow_imgs = [cv.imread(fi) for fi in flow_img_paths] 77 | print("Loaded {} flow, {} img, {} ts".format(len(flows), len(flow_imgs), len(flow_ts))) 78 | 79 | if args.plot_method == 'between_frames': 80 | if args.renderer == "mayavi": 81 | print(args.renderer) 82 | pass 83 | elif args.renderer == "matplotlib": 84 | from lib.visualization.draw_flow import plot_between_frames 85 | plot_between_frames(xs, ys, ts, ps, flows, flow_imgs, flow_ts, args) 86 | print(args.renderer) 87 | pass 88 | elif args.plot_method == 'k_events': 89 | print(args.renderer) 90 | pass 91 | elif args.plot_method == 't_seconds': 92 | if args.renderer == "mayavi": 93 | print(args.renderer) 94 | pass 95 | elif args.renderer == "matplotlib": 96 | print(args.renderer) 97 | pass 98 | -------------------------------------------------------------------------------- /visualize_voxel.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | from lib.visualization.draw_event_stream import plot_between_frames 5 | from lib.data_formats.read_events import read_memmap_events, read_h5_events_dict 6 | 7 | def plot_events_sliding(xs, ys, ts, ps, args, frames=None, frame_ts=None): 8 | if dt is None: 9 | dt = (ts[-1]-ts[0])/10 10 | sdt = dt/10 11 | print("Using dt={}, sdt={}".format(dt, sdt)) 12 | if frames is not None: 13 | sensor_size = frames[0].shape 14 | else: 15 | sensor_size = [max(ys), max(xs)] 16 | 17 | if len(frame_ts.shape) == 2: 18 | frame_ts = frame_ts[:,1] 19 | for i, t0 in enumerate(tqdm(np.arange(ts[0], ts[-1]-dt, sdt))): 20 | te = t0+dt 21 | eidx0 = np.searchsorted(ts, t0) 22 | eidx1 = np.searchsorted(ts, te) 23 | fidx0 = np.searchsorted(frame_ts, t0) 24 | fidx1 = np.searchsorted(frame_ts, te) 25 | #print("{}:{} = {}".format(frame_ts[fidx0], ts[eidx0], fidx0)) 26 | 27 | wxs, wys, wts, wps = xs[eidx0:eidx1], ys[eidx0:eidx1], ts[eidx0:eidx1], ps[eidx0:eidx1], 28 | if fidx0 == fidx1: 29 | wframes=[] 30 | wframe_ts=[] 31 | else: 32 | wframes = frames[fidx0:fidx1] 33 | wframe_ts = frame_ts[fidx0:fidx1] 34 | 35 | save_path = os.path.join(args.output_path, "frame_{:010d}.png".format(i)) 36 | plot_events(wxs, wys, wts, wps, save_path=save_path, 
37 |         plot_events(wxs, wys, wts, wps, save_path=save_path,
38 |                 num_show=args.num_show, event_size=args.event_size,
39 |                 imgs=wframes, img_ts=wframe_ts, show_events=not args.hide_events, azim=args.azim,
40 |                 elev=args.elev, show_frames=not args.hide_frames, crop=args.crop, compress_front=args.compress_front,
41 |                 invert=args.invert, num_compress=args.num_compress, show_plot=args.show_plot, img_size=sensor_size,
42 |                 show_axes=args.show_axes)
43 | 
44 | if __name__ == "__main__":
45 |     """
46 |     Quick demo
47 |     """
48 |     parser = argparse.ArgumentParser()
49 |     parser.add_argument("path", help="events path (memmap directory or HDF5 file)")
50 |     parser.add_argument("--output_path", type=str, default="/tmp/visualization", help="Where to save image outputs")
51 | 
52 |     parser.add_argument('--plot_method', default='between_frames', type=str,
53 |                         help='which method should be used to visualize',
54 |                         choices=['between_frames', 'k_events', 't_seconds'])
55 |     parser.add_argument('--k', type=int,
56 |                         help='new plot is formed every k events (required if voxel_method is k_events)')
57 |     parser.add_argument('--sliding_window_w', type=int,
58 |                         help='sliding_window size (required if voxel_method is k_events)')
59 |     parser.add_argument('--t', type=float,
60 |                         help='new plot is formed every t seconds (required if voxel_method is t_seconds)')
61 |     parser.add_argument('--sliding_window_t', type=float,
62 |                         help='sliding_window size in seconds (required if voxel_method is t_seconds)')
63 |     parser.add_argument("--num_bins", type=int, default=6, help="How many bins voxels should have.")
64 | 
65 |     parser.add_argument('--show_plot', action='store_true', help='If true, will also display the plot in an interactive window.\
66 |                         Useful for selecting the desired orientation.')
67 | 
68 |     parser.add_argument("--num_show", type=int, default=-1, help="How many events to show per plot. If -1, show all events.")
69 |     parser.add_argument("--event_size", type=float, default=2, help="Marker size of the plotted events")
70 |     parser.add_argument("--elev", type=float, default=20, help="Elevation of plot")
71 |     parser.add_argument("--azim", type=float, default=-25, help="Azimuth of plot")
72 |     parser.add_argument("--stride", type=int, default=1, help="Downsample stride for plotted images.")
73 |     parser.add_argument("--skip_frames", type=int, default=1, help="Amount of frames to place per plot.")
74 |     parser.add_argument("--start_frame", type=int, default=0, help="On which frame to start.")
75 |     parser.add_argument('--hide_skipped', action='store_true', help='Do not draw skipped frames into plot.')
76 |     parser.add_argument('--hide_events', action='store_true', help='Do not draw events')
77 |     parser.add_argument('--hide_frames', action='store_true', help='Do not draw frames')
78 |     parser.add_argument('--show_axes', action='store_true', help='Draw axes')
79 |     parser.add_argument("--num_compress", type=int, default=0, help="How many events to draw compressed\
80 |                         at the end of the event volume (0 to disable).")
81 |     parser.add_argument('--compress_front', action='store_true', help='If set, will put the compressed events at the _start_\
82 |                         of the event volume, rather than the back.')
83 |     parser.add_argument('--invert', action='store_true', help='If the figure is for a black background, you can invert the \
84 |                         colors for better visibility.')
85 |     parser.add_argument("--crop", type=str, default=None, help="Set a crop of both images and events. Uses 'imagemagick' \
Uses 'imagemagick' \ 84 | syntax, eg for a crop of 10x20 starting from point 30,40 use: 10x20+30+40.") 85 | args = parser.parse_args() 86 | 87 | if os.path.isdir(args.path): 88 | events = read_memmap_events(args.path) 89 | 90 | ts = events['t'][:].squeeze() 91 | t0 = ts[0] 92 | ts = ts-t0 93 | frames = (events['images'][args.start_frame+1::])/255 94 | frame_idx = events['index'][args.start_frame::] 95 | frame_ts = events['frame_stamps'][args.start_frame+1::]-t0 96 | 97 | start_idx = np.searchsorted(ts, frame_ts[0]) 98 | print("Starting from frame {}, event {}".format(args.start_frame, start_idx)) 99 | 100 | xs = events['xy'][:,0] 101 | ys = events['xy'][:,1] 102 | ts = ts[:] 103 | ps = events['p'][:] 104 | 105 | print("Have {} frames".format(frames.shape)) 106 | else: 107 | events = read_h5_events_dict(args.path) 108 | xs = events['xs'] 109 | ys = events['ys'] 110 | ts = events['ts'] 111 | ps = events['ps'] 112 | t0 = ts[0] 113 | ts = ts-t0 114 | frames = [np.flip(x/255., axis=0) for x in events['frames']] 115 | frame_ts = events['frame_timestamps'][1:]-t0 116 | frame_end = events['frame_event_indices'][1:] 117 | frame_start = np.concatenate((np.array([0]), frame_end)) 118 | frame_idx = np.stack((frame_end, frame_start[0:-1]), axis=1) 119 | ys = frames[0].shape[0]-ys 120 | 121 | plot_between_frames(xs, ys, ts, ps, frames, frame_idx, args, plttype='voxel') 122 | --------------------------------------------------------------------------------