├── .gitignore ├── DSEC ├── dataset │ ├── provider.py │ ├── representations.py │ ├── sequence.py │ ├── sequence_recurrent.py │ └── visualization.py ├── utils │ ├── __init__.py │ ├── eventslicer.py │ └── viz_utils.py └── visualization │ ├── __init__.py │ └── eventreader.py ├── LICENSE ├── README.md ├── __init__.py ├── config ├── __init__.py ├── settings.py ├── settings_DDD17.yaml └── settings_DSEC.yaml ├── datasets ├── DSEC_events_loader.py ├── cityscapes_loader.py ├── data_util.py ├── ddd17_events_loader.py ├── extract_data_tools │ └── example_loader_ddd17.py └── wrapper_dataloader.py ├── e2vid ├── LICENSE ├── README.md ├── base │ ├── __init__.py │ └── base_model.py ├── image_reconstructor.py ├── model │ ├── __init__.py │ ├── model.py │ ├── submodules.py │ └── unet.py ├── options │ ├── __init__.py │ └── inference_options.py ├── pretrained │ └── .gitignore ├── run_reconstruction.py └── utils │ ├── __init__.py │ ├── event_readers.py │ ├── inference_utils.py │ ├── loading_utils.py │ ├── path_utils.py │ ├── timers.py │ └── util.py ├── evaluation ├── __init__.py └── metrics.py ├── models ├── __init__.py ├── style_networks.py └── submodules.py ├── requirements.txt ├── resources └── ESS.png ├── train.py ├── training ├── __init__.py ├── base_trainer.py ├── ess_supervised_trainer.py └── ess_trainer.py └── utils ├── __init__.py ├── labels.py ├── loss_functions.py ├── radam.py ├── saver.py └── viz_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | log/ 3 | 4 | #data loader test 5 | cityscapes_test.py 6 | ddd17_test.py 7 | DSEC_align.py 8 | DSEC_final.py 9 | DSEC_final_UDA.py 10 | DSEC_hdr.py 11 | DSEC_loader_test.py 12 | DSEC_test.py 13 | DSEC_test_2.py 14 | e2vid_test.py 15 | EventScape_final.py 16 | Eventscape_final_UDA.py 17 | eventscapeloadertest.py 18 | flowtest.py 19 | 20 | config/settings_DDD17.yaml 21 | config/eventscape_settings_snaga.yaml 22 | 23 | training/img_to_events_pre_trainer.py 24 | training/img_to_events_trainer.py 25 | training/semantic_segmentation_pre_trainer.py 26 | training/semantic_segmentation_recurrent_trainer.py 27 | training/supervised_baseline.py 28 | training/ess_supervised_trainer.py 29 | training/semantic_segmentation_baseline_supervised.py 30 | training/style_latent_trainer.py 31 | training/style_object_det_trainer.py -------------------------------------------------------------------------------- /DSEC/dataset/provider.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | 5 | from DSEC.dataset.sequence import Sequence 6 | 7 | 8 | class DatasetProvider: 9 | def __init__(self, dataset_path: Path, mode: str = 'train', event_representation: str = 'voxel_grid', 10 | nr_events_data: int = 5, delta_t_per_data: int = 20, 11 | nr_events_window=-1, nr_bins_per_data=5, require_paired_data=False, normalize_event=False, 12 | separate_pol=False, semseg_num_classes=11, augmentation=False, 13 | fixed_duration=False, resize=False): 14 | train_path = dataset_path / 'train' 15 | val_path = dataset_path / 'test' 16 | assert dataset_path.is_dir(), str(dataset_path) 17 | assert train_path.is_dir(), str(train_path) 18 | assert val_path.is_dir(), str(val_path) 19 | 20 | if mode == 'train': 21 | train_sequences = list() 22 | train_sequences_namelist = ['zurich_city_00_a', 'zurich_city_01_a', 'zurich_city_02_a', 23 | 'zurich_city_04_a', 'zurich_city_05_a', 'zurich_city_06_a', 24 | 'zurich_city_07_a', 'zurich_city_08_a'] 25 | for child 
in train_path.iterdir(): 26 | if any(k in str(child) for k in train_sequences_namelist): 27 | train_sequences.append(Sequence(child, 'train', event_representation, nr_events_data, delta_t_per_data, 28 | nr_events_window, nr_bins_per_data, require_paired_data, normalize_event 29 | , separate_pol, semseg_num_classes, augmentation, fixed_duration 30 | , resize=resize)) 31 | else: 32 | continue 33 | 34 | self.train_dataset = torch.utils.data.ConcatDataset(train_sequences) 35 | self.train_dataset.require_paired_data = require_paired_data 36 | 37 | elif mode == 'val': 38 | val_sequences = list() 39 | val_sequences_namelist = ['zurich_city_13_a', 'zurich_city_14_c', 'zurich_city_15_a'] 40 | for child in val_path.iterdir(): 41 | if any(k in str(child) for k in val_sequences_namelist): 42 | val_sequences.append(Sequence(child, 'val', event_representation, nr_events_data, delta_t_per_data, 43 | nr_events_window, nr_bins_per_data, require_paired_data, normalize_event 44 | , separate_pol, semseg_num_classes, augmentation, fixed_duration 45 | , resize=resize)) 46 | else: 47 | continue 48 | 49 | self.val_dataset = torch.utils.data.ConcatDataset(val_sequences) 50 | self.val_dataset.require_paired_data = require_paired_data 51 | 52 | 53 | def get_train_dataset(self): 54 | return self.train_dataset 55 | 56 | def get_val_dataset(self): 57 | # Implement this according to your needs. 58 | return self.val_dataset 59 | 60 | def get_test_dataset(self): 61 | # Implement this according to your needs. 62 | raise NotImplementedError 63 | -------------------------------------------------------------------------------- /DSEC/dataset/representations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class EventRepresentation: 5 | def convert(self, x: torch.Tensor, y: torch.Tensor, pol: torch.Tensor, time: torch.Tensor): 6 | raise NotImplementedError 7 | 8 | 9 | class VoxelGrid(EventRepresentation): 10 | def __init__(self, channels: int, height: int, width: int, normalize: bool): 11 | self.voxel_grid = torch.zeros((channels, height, width), dtype=torch.float, requires_grad=False) 12 | self.nb_channels = channels 13 | self.normalize = normalize 14 | 15 | def convert(self, x: torch.Tensor, y: torch.Tensor, pol: torch.Tensor, time: torch.Tensor): 16 | assert x.shape == y.shape == pol.shape == time.shape 17 | assert x.ndim == 1 18 | 19 | C, H, W = self.voxel_grid.shape 20 | with torch.no_grad(): 21 | self.voxel_grid = self.voxel_grid.to(pol.device) 22 | voxel_grid = self.voxel_grid.clone() 23 | 24 | t_norm = time 25 | t_norm = (C - 1) * (t_norm-t_norm[0]) / (t_norm[-1]-t_norm[0]) 26 | 27 | x0 = x.int() 28 | y0 = y.int() 29 | t0 = t_norm.int() 30 | 31 | value = 2*pol-1 32 | 33 | for xlim in [x0,x0+1]: 34 | for ylim in [y0,y0+1]: 35 | for tlim in [t0,t0+1]: 36 | mask = (xlim < W) & (xlim >= 0) & (ylim < H) & (ylim >= 0) & (tlim >= 0) & (tlim < self.nb_channels) 37 | interp_weights = value * (1 - (xlim-x).abs()) * (1 - (ylim-y).abs()) * (1 - (tlim - t_norm).abs()) 38 | 39 | index = H * W * tlim.long() + \ 40 | W * ylim.long() + \ 41 | xlim.long() 42 | 43 | voxel_grid.put_(index[mask], interp_weights[mask], accumulate=True) 44 | 45 | if self.normalize: 46 | mask = torch.nonzero(voxel_grid, as_tuple=True) 47 | if mask[0].size()[0] > 0: 48 | mean = voxel_grid[mask].mean() 49 | std = voxel_grid[mask].std() 50 | if std > 0: 51 | voxel_grid[mask] = (voxel_grid[mask] - mean) / std 52 | else: 53 | voxel_grid[mask] = voxel_grid[mask] - mean 54 | 55 | return voxel_grid 56 | 
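The `VoxelGrid.convert` method above distributes each event trilinearly over its two nearest pixels in x and y and its two nearest temporal bins, with polarity mapped from {0, 1} to {-1, +1}. A minimal usage sketch (the event values below are made up purely for illustration; raw microsecond timestamps are fine, since `convert` rescales them to the `[0, C-1]` bin range internally):

```python
import torch

from DSEC.dataset.representations import VoxelGrid

# Toy events (illustrative values only): pixel coordinates, polarity in {0, 1},
# and raw timestamps in microseconds.
x = torch.tensor([10.0, 20.5, 300.2, 301.7])
y = torch.tensor([15.0, 40.1, 200.7, 199.3])
p = torch.tensor([1.0, 0.0, 1.0, 1.0])   # mapped internally to -1/+1 via 2*p - 1
t = torch.tensor([0.0, 5e3, 1e4, 1.5e4])

grid = VoxelGrid(channels=5, height=480, width=640, normalize=True)
voxel = grid.convert(x, y, p, t)
print(voxel.shape)  # torch.Size([5, 480, 640])
```

With `normalize=True`, the non-zero voxels are additionally standardized to zero mean and unit variance, which is what the data loaders toggle via the `normalize_event` flag.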
-------------------------------------------------------------------------------- /DSEC/dataset/sequence.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from https://github.com/uzh-rpg/DSEC/blob/main/scripts/dataset/sequence.py 3 | """ 4 | from pathlib import Path 5 | import weakref 6 | 7 | import cv2 8 | import h5py 9 | import numpy as np 10 | import torch 11 | import torch.nn.functional as f 12 | from torch.utils.data import Dataset 13 | import torchvision.transforms as transforms 14 | from PIL import Image 15 | from joblib import Parallel, delayed 16 | 17 | from DSEC.dataset.representations import VoxelGrid 18 | from DSEC.utils.eventslicer import EventSlicer 19 | import albumentations as A 20 | import datasets.data_util as data_util 21 | import random 22 | 23 | class Sequence(Dataset): 24 | # This class assumes the following structure in a sequence directory: 25 | # 26 | # seq_name (e.g. zurich_city_00_a) 27 | # ├── semantic 28 | # │ ├── left 29 | # │ │ ├── 11classes 30 | # │ │ │ └──data 31 | # │ │ │ ├── 000000.png 32 | # │ │ │ └── ... 33 | # │ │ └── 19classes 34 | # │ │ └──data 35 | # │ │ ├── 000000.png 36 | # │ │ └── ... 37 | # │ └── timestamps.txt 38 | # └── events 39 | # └── left 40 | # ├── events.h5 41 | # └── rectify_map.h5 42 | 43 | def __init__(self, seq_path: Path, mode: str='train', event_representation: str = 'voxel_grid', 44 | nr_events_data: int = 5, delta_t_per_data: int = 20, nr_events_per_data: int = 100000, 45 | nr_bins_per_data: int = 5, require_paired_data=False, normalize_event=False, separate_pol=False, 46 | semseg_num_classes: int = 11, augmentation=False, fixed_duration=False, remove_time_window: int = 250, 47 | resize=False): 48 | assert nr_bins_per_data >= 1 49 | assert seq_path.is_dir() 50 | self.sequence_name = seq_path.name 51 | self.mode = mode 52 | 53 | # Save output dimensions 54 | self.height = 480 55 | self.width = 640 56 | self.resize = resize 57 | self.shape_resize = None 58 | if self.resize: 59 | self.shape_resize = [448, 640] 60 | 61 | # Set event representation 62 | self.nr_events_data = nr_events_data 63 | self.num_bins = nr_bins_per_data 64 | assert nr_events_per_data > 0 65 | self.nr_events_per_data = nr_events_per_data 66 | self.event_representation = event_representation 67 | self.separate_pol = separate_pol 68 | self.normalize_event = normalize_event 69 | self.voxel_grid = VoxelGrid(self.num_bins, self.height, self.width, normalize=self.normalize_event) 70 | 71 | self.locations = ['left'] 72 | self.semseg_num_classes = semseg_num_classes 73 | self.augmentation = augmentation 74 | 75 | # Save delta timestamp 76 | self.fixed_duration = fixed_duration 77 | if self.fixed_duration: 78 | delta_t_ms = nr_events_data * delta_t_per_data 79 | self.delta_t_us = delta_t_ms * 1000 80 | self.remove_time_window = remove_time_window 81 | 82 | self.require_paired_data = require_paired_data 83 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 84 | 85 | # load timestamps 86 | self.timestamps = np.loadtxt(str(seq_path / 'semantic' / 'timestamps.txt'), dtype='int64') 87 | 88 | # load label paths 89 | if self.semseg_num_classes == 11: 90 | label_dir = seq_path / 'semantic' / '11classes' / 'data' 91 | elif self.semseg_num_classes == 19: 92 | label_dir = seq_path / 'semantic' / '19classes' / 'data' 93 | else: 94 | raise ValueError 95 | assert label_dir.is_dir() 96 | label_pathstrings = list() 97 | for entry in label_dir.iterdir(): 98 | assert str(entry.name).endswith('.png') 99 | 
label_pathstrings.append(str(entry)) 100 | label_pathstrings.sort() 101 | self.label_pathstrings = label_pathstrings 102 | 103 | assert len(self.label_pathstrings) == self.timestamps.size 104 | 105 | # load images paths 106 | if self.require_paired_data: 107 | img_dir = seq_path / 'images' 108 | img_left_dir = img_dir / 'left' / 'ev_inf' 109 | assert img_left_dir.is_dir() 110 | img_left_pathstrings = list() 111 | for entry in img_left_dir.iterdir(): 112 | assert str(entry.name).endswith('.png') 113 | img_left_pathstrings.append(str(entry)) 114 | img_left_pathstrings.sort() 115 | self.img_left_pathstrings = img_left_pathstrings 116 | 117 | assert len(self.img_left_pathstrings) == self.timestamps.size 118 | 119 | # Remove several label paths and corresponding timestamps in the remove_time_window. 120 | # This is necessary because we do not have enough events before the first label. 121 | self.timestamps = self.timestamps[(self.remove_time_window // 100 + 1) * 2:] 122 | del self.label_pathstrings[:(self.remove_time_window // 100 + 1) * 2] 123 | assert len(self.label_pathstrings) == self.timestamps.size 124 | if self.require_paired_data: 125 | del self.img_left_pathstrings[:(self.remove_time_window // 100 + 1) * 2] 126 | assert len(self.img_left_pathstrings) == self.timestamps.size 127 | 128 | self.h5f = dict() 129 | self.rectify_ev_maps = dict() 130 | self.event_slicers = dict() 131 | 132 | ev_dir = seq_path / 'events' 133 | for location in self.locations: 134 | ev_dir_location = ev_dir / location 135 | ev_data_file = ev_dir_location / 'events.h5' 136 | ev_rect_file = ev_dir_location / 'rectify_map.h5' 137 | 138 | h5f_location = h5py.File(str(ev_data_file), 'r') 139 | self.h5f[location] = h5f_location 140 | self.event_slicers[location] = EventSlicer(h5f_location) 141 | with h5py.File(str(ev_rect_file), 'r') as h5_rect: 142 | self.rectify_ev_maps[location] = h5_rect['rectify_map'][()] 143 | 144 | def events_to_voxel_grid(self, x, y, p, t): 145 | t = (t - t[0]).astype('float32') 146 | t = (t/t[-1]) 147 | x = x.astype('float32') 148 | y = y.astype('float32') 149 | pol = p.astype('float32') 150 | return self.voxel_grid.convert( 151 | torch.from_numpy(x), 152 | torch.from_numpy(y), 153 | torch.from_numpy(pol), 154 | torch.from_numpy(t)) 155 | 156 | def getHeightAndWidth(self): 157 | return self.height, self.width 158 | 159 | @staticmethod 160 | def get_disparity_map(filepath: Path): 161 | assert filepath.is_file() 162 | disp_16bit = cv2.imread(str(filepath), cv2.IMREAD_ANYDEPTH) 163 | return disp_16bit.astype('float32')/256 164 | 165 | @staticmethod 166 | def get_img(filepath: Path, shape_resize=None): 167 | assert filepath.is_file() 168 | img = Image.open(str(filepath)) 169 | if shape_resize is not None: 170 | img = img.resize((shape_resize[1], shape_resize[0])) 171 | img_transform = transforms.Compose([ 172 | transforms.Grayscale(), 173 | transforms.ToTensor() 174 | ]) 175 | img_tensor = img_transform(img) 176 | return img_tensor 177 | 178 | @staticmethod 179 | def get_label(filepath: Path): 180 | assert filepath.is_file() 181 | label = Image.open(str(filepath)) 182 | label = np.array(label) 183 | return label 184 | 185 | @staticmethod 186 | def close_callback(h5f_dict): 187 | for k, h5f in h5f_dict.items(): 188 | h5f.close() 189 | 190 | def __len__(self): 191 | return (self.timestamps.size + 1) // 2 192 | 193 | def rectify_events(self, x: np.ndarray, y: np.ndarray, location: str): 194 | assert location in self.locations 195 | # From distorted to undistorted 196 | rectify_map = 
self.rectify_ev_maps[location] 197 | assert rectify_map.shape == (self.height, self.width, 2), rectify_map.shape 198 | assert x.max() < self.width 199 | assert y.max() < self.height 200 | return rectify_map[y, x] 201 | 202 | def generate_event_tensor(self, job_id, events, event_tensor, nr_events_per_data): 203 | id_start = job_id * nr_events_per_data 204 | id_end = (job_id + 1) * nr_events_per_data 205 | events_temp = events[id_start:id_end] 206 | event_representation = self.events_to_voxel_grid(events_temp[:, 0], events_temp[:, 1], events_temp[:, 3], 207 | events_temp[:, 2]) 208 | event_tensor[(job_id * self.num_bins):((job_id+1) * self.num_bins), :, :] = event_representation 209 | 210 | def __getitem__(self, index): 211 | label_path = Path(self.label_pathstrings[index * 2]) 212 | if self.resize: 213 | segmentation_mask = cv2.imread(str(label_path), 0) 214 | segmentation_mask = cv2.resize(segmentation_mask, (self.shape_resize[1], self.shape_resize[0]), 215 | interpolation=cv2.INTER_NEAREST) 216 | label = np.array(segmentation_mask) 217 | else: 218 | label = self.get_label(label_path) 219 | 220 | ts_end = self.timestamps[index * 2] 221 | 222 | output = {} 223 | for location in self.locations: 224 | if self.fixed_duration: 225 | ts_start = ts_end - self.delta_t_us 226 | event_tensor = None 227 | self.delta_t_per_data_us = self.delta_t_us / self.nr_events_data 228 | for i in range(self.nr_events_data): 229 | t_s = ts_start + i * self.delta_t_per_data_us 230 | t_end = ts_start + (i+1) * self.delta_t_per_data_us 231 | event_data = self.event_slicers[location].get_events(t_s, t_end) 232 | 233 | p = event_data['p'] 234 | t = event_data['t'] 235 | x = event_data['x'] 236 | y = event_data['y'] 237 | 238 | xy_rect = self.rectify_events(x, y, location) 239 | x_rect = xy_rect[:, 0] 240 | y_rect = xy_rect[:, 1] 241 | 242 | if self.event_representation == 'voxel_grid': 243 | event_representation = self.events_to_voxel_grid(x_rect, y_rect, p, t) 244 | else: 245 | events = np.stack([x_rect, y_rect, t, p], axis=1) 246 | event_representation = data_util.generate_input_representation(events, self.event_representation, 247 | (self.height, self.width)) 248 | event_representation = torch.from_numpy(event_representation).type(torch.FloatTensor) 249 | 250 | if event_tensor is None: 251 | event_tensor = event_representation 252 | else: 253 | event_tensor = torch.cat([event_tensor, event_representation], dim=0) 254 | 255 | else: 256 | num_bins_total = self.nr_events_data * self.num_bins 257 | event_tensor = torch.zeros((num_bins_total, self.height, self.width)) 258 | self.nr_events = self.nr_events_data * self.nr_events_per_data 259 | event_data = self.event_slicers[location].get_events_fixed_num(ts_end, self.nr_events) 260 | 261 | if self.nr_events >= event_data['t'].size: 262 | start_index = 0 263 | else: 264 | start_index = -self.nr_events 265 | 266 | p = event_data['p'][start_index:] 267 | t = event_data['t'][start_index:] 268 | x = event_data['x'][start_index:] 269 | y = event_data['y'][start_index:] 270 | nr_events_loaded = t.size 271 | 272 | xy_rect = self.rectify_events(x, y, location) 273 | x_rect = xy_rect[:, 0] 274 | y_rect = xy_rect[:, 1] 275 | 276 | nr_events_temp = nr_events_loaded // self.nr_events_data 277 | events = np.stack([x_rect, y_rect, t, p], axis=-1) 278 | Parallel(n_jobs=8, backend="threading")( 279 | delayed(self.generate_event_tensor)(i, events, event_tensor, nr_events_temp) for i in range(self.nr_events_data)) 280 | 281 | # remove 40 bottom rows 282 | event_tensor = event_tensor[:, 
:-40, :] 283 | 284 | if self.resize: 285 | event_tensor = f.interpolate(event_tensor.unsqueeze(0), 286 | size=(self.shape_resize[0], self.shape_resize[1]), 287 | mode='bilinear', align_corners=True).squeeze(0) 288 | 289 | label_tensor = torch.from_numpy(label).long() 290 | 291 | if self.augmentation: 292 | value_flip = round(random.random()) 293 | if value_flip > 0.5: 294 | event_tensor = torch.flip(event_tensor, [2]) 295 | label_tensor = torch.flip(label_tensor, [1]) 296 | 297 | if 'representation' not in output: 298 | output['representation'] = dict() 299 | output['representation'][location] = event_tensor 300 | 301 | if self.require_paired_data: 302 | img_left_path = Path(self.img_left_pathstrings[index * 2]) 303 | output['img_left'] = self.get_img(img_left_path, self.shape_resize) 304 | return output['representation']['left'], output['img_left'], label_tensor 305 | return output['representation']['left'], label_tensor 306 | 307 | 308 | -------------------------------------------------------------------------------- /DSEC/dataset/sequence_recurrent.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import weakref 3 | 4 | import cv2 5 | import h5py 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import Dataset 9 | import torchvision.transforms as transforms 10 | from PIL import Image 11 | 12 | from DSEC.dataset.representations import VoxelGrid 13 | from DSEC.utils.eventslicer import EventSlicer 14 | import albumentations as A 15 | 16 | 17 | class SequenceRecurrent(Dataset): 18 | # NOTE: This is just an EXAMPLE class for convenience. Adapt it to your case. 19 | # In this example, we use the voxel grid representation. 20 | # 21 | # This class assumes the following structure in a sequence directory: 22 | # 23 | # seq_name (e.g. zurich_city_11_a) 24 | # ├── disparity 25 | # │   ├── event 26 | # │   │   ├── 000000.png 27 | # │   │   └── ... 28 | # │   └── timestamps.txt 29 | # └── events 30 | #    ├── left 31 | #    │   ├── events.h5 32 | #    │   └── rectify_map.h5 33 | #    └── right 34 | #    ├── events.h5 35 | #    └── rectify_map.h5 36 | 37 | def __init__(self, seq_path: Path, mode: str='train', event_representation: str = 'voxel_grid', 38 | nr_events_data: int = 5, delta_t_per_data: int = 20, nr_events_per_data: int = 100000, 39 | nr_bins_per_data: int = 5, require_paired_data=False, normalize_event=False, separate_pol=False, 40 | semseg_num_classes: int = 11, augmentation=True, fixed_duration=False, loading_time_window: int = 250): 41 | assert nr_bins_per_data >= 1 42 | # assert delta_t_ms <= 200, 'adapt this code, if duration is higher than 100 ms' 43 | assert seq_path.is_dir() 44 | 45 | # NOTE: Adapt this code according to the present mode (e.g. train, val or test). 
46 | self.mode = mode 47 | self.augmentation = augmentation 48 | 49 | # Save output dimensions 50 | self.height = 480 51 | self.width = 640 52 | 53 | # Set event representation 54 | self.nr_events_data = nr_events_data 55 | self.num_bins = nr_bins_per_data 56 | assert nr_events_per_data > 0 57 | self.nr_events_per_data = nr_events_per_data 58 | self.event_representation = event_representation 59 | self.separate_pol = separate_pol 60 | self.normalize_event = normalize_event 61 | self.voxel_grid = VoxelGrid(self.num_bins, self.height, self.width, normalize=self.normalize_event) 62 | 63 | self.locations = ['left'] # 'right' 64 | self.semseg_num_classes = semseg_num_classes 65 | 66 | # Save delta timestamp in ms 67 | self.fixed_duration = fixed_duration 68 | if self.fixed_duration: 69 | delta_t_ms = nr_events_data * delta_t_per_data 70 | else: 71 | delta_t_ms = loading_time_window 72 | self.delta_t_us = delta_t_ms * 1000 73 | 74 | self.require_paired_data = require_paired_data 75 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 76 | 77 | # load disparity timestamps 78 | # disp_dir = seq_path / 'disparity' 79 | img_dir = seq_path / 'images' 80 | assert img_dir.is_dir() 81 | 82 | self.timestamps = np.loadtxt(img_dir / 'left' / 'exposure_timestamps.txt', comments='#', delimiter=',', dtype='int64')[:, 1] 83 | 84 | # load images paths 85 | if self.require_paired_data: 86 | img_left_dir = img_dir / 'left' / 'ev_inf' 87 | assert img_left_dir.is_dir() 88 | img_left_pathstrings = list() 89 | for entry in img_left_dir.iterdir(): 90 | assert str(entry.name).endswith('.png') 91 | img_left_pathstrings.append(str(entry)) 92 | img_left_pathstrings.sort() 93 | self.img_left_pathstrings = img_left_pathstrings 94 | 95 | assert len(self.img_left_pathstrings) == self.timestamps.size 96 | 97 | if self.mode == 'val': 98 | if self.semseg_num_classes == 11: 99 | label_dir = seq_path / 'semantic' / '11classes' / 'data' 100 | elif self.semseg_num_classes == 19: 101 | label_dir = seq_path / 'semantic' / '19classes' / 'data' 102 | elif self.semseg_num_classes == 6: 103 | label_dir = seq_path / 'semantic' / '6classes' / 'data' 104 | else: 105 | raise ValueError 106 | assert label_dir.is_dir() 107 | label_pathstrings = list() 108 | for entry in label_dir.iterdir(): 109 | assert str(entry.name).endswith('.png') 110 | label_pathstrings.append(str(entry)) 111 | label_pathstrings.sort() 112 | self.label_pathstrings = label_pathstrings 113 | 114 | assert len(self.label_pathstrings) == self.timestamps.size 115 | 116 | # Remove first disparity path and corresponding timestamp. 117 | # This is necessary because we do not have events before the first disparity map. 
118 | # assert int(Path(self.disp_gt_pathstrings[0]).stem) == 0 119 | # del self.disp_gt_pathstrings[:(delta_t_ms // 100 + 1)] 120 | self.timestamps = self.timestamps[(delta_t_ms // 50 + 1):] 121 | if self.require_paired_data: 122 | del self.img_left_pathstrings[:(delta_t_ms // 50 + 1)] 123 | assert len(self.img_left_pathstrings) == self.timestamps.size 124 | if self.mode == 'val': 125 | del self.label_pathstrings[:(delta_t_ms // 50 + 1)] 126 | assert len(self.img_left_pathstrings) == len(self.label_pathstrings) 127 | 128 | self.h5f = dict() 129 | self.rectify_ev_maps = dict() 130 | self.event_slicers = dict() 131 | 132 | ev_dir = seq_path / 'events' 133 | for location in self.locations: 134 | ev_dir_location = ev_dir / location 135 | ev_data_file = ev_dir_location / 'events.h5' 136 | ev_rect_file = ev_dir_location / 'rectify_map.h5' 137 | 138 | h5f_location = h5py.File(str(ev_data_file), 'r') 139 | self.h5f[location] = h5f_location 140 | self.event_slicers[location] = EventSlicer(h5f_location) 141 | with h5py.File(str(ev_rect_file), 'r') as h5_rect: 142 | self.rectify_ev_maps[location] = h5_rect['rectify_map'][()] 143 | 144 | def events_to_voxel_grid(self, x, y, p, t, device: str='cpu'): 145 | t = (t - t[0]).astype('float32') 146 | t = (t/t[-1]) 147 | x = x.astype('float32') 148 | y = y.astype('float32') 149 | pol = p.astype('float32') 150 | return self.voxel_grid.convert( 151 | torch.from_numpy(x), 152 | torch.from_numpy(y), 153 | torch.from_numpy(pol), 154 | torch.from_numpy(t)) 155 | 156 | def getHeightAndWidth(self): 157 | return self.height, self.width 158 | 159 | @staticmethod 160 | def get_disparity_map(filepath: Path): 161 | assert filepath.is_file() 162 | disp_16bit = cv2.imread(str(filepath), cv2.IMREAD_ANYDEPTH) 163 | return disp_16bit.astype('float32')/256 164 | 165 | @staticmethod 166 | def get_img(filepath: Path): 167 | assert filepath.is_file() 168 | img = Image.open(str(filepath)) 169 | img_transform = transforms.Compose([ 170 | transforms.Grayscale(), 171 | transforms.ToTensor() 172 | ]) 173 | img_tensor = img_transform(img) 174 | return img_tensor 175 | 176 | @staticmethod 177 | def get_label(filepath: Path): 178 | assert filepath.is_file() 179 | label = Image.open(str(filepath)) 180 | label = np.array(label) 181 | # label_tensor = torch.from_numpy(label).long() 182 | return label 183 | 184 | @staticmethod 185 | def close_callback(h5f_dict): 186 | for k, h5f in h5f_dict.items(): 187 | h5f.close() 188 | 189 | def __len__(self): 190 | if self.fixed_duration: 191 | return self.timestamps.size 192 | else: 193 | return self.event_slicers['left'].events['t'].size // self.nr_events_per_data 194 | 195 | def rectify_events(self, x: np.ndarray, y: np.ndarray, location: str): 196 | assert location in self.locations 197 | # From distorted to undistorted 198 | rectify_map = self.rectify_ev_maps[location] 199 | assert rectify_map.shape == (self.height, self.width, 2), rectify_map.shape 200 | assert x.max() < self.width 201 | assert y.max() < self.height 202 | return rectify_map[y, x] 203 | 204 | def __getitem__(self, index): 205 | if self.augmentation: 206 | transform_a = A.ReplayCompose([ 207 | A.HorizontalFlip(p=0.5), 208 | # A.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0.1, p=0.5, border_mode=0), 209 | # A.PadIfNeeded(min_height=self.height-40, min_width=self.width, always_apply=True, border_mode=0), 210 | # A.RandomCrop(height=self.height-40, width=self.width, p=1) 211 | ]) 212 | A_data = None 213 | 214 | label = np.zeros((self.height-40, self.width)) 215 | 216 
| output = {} 217 | for location in self.locations: 218 | if self.fixed_duration: 219 | if self.mode == 'val': 220 | label_path = Path(self.label_pathstrings[index]) 221 | label = self.get_label(label_path) 222 | 223 | ts_end = self.timestamps[index] 224 | # ts_start should be fine (within the window as we removed the first disparity map) 225 | ts_start = ts_end - self.delta_t_us 226 | 227 | event_tensor = None 228 | self.delta_t_per_data_us = self.delta_t_us / self.nr_events_data 229 | for i in range(self.nr_events_data): 230 | t_s = ts_start + i * self.delta_t_per_data_us 231 | t_end = ts_start + (i+1) * self.delta_t_per_data_us 232 | event_data = self.event_slicers[location].get_events(t_s, t_end) 233 | 234 | p = event_data['p'] 235 | t = event_data['t'] 236 | x = event_data['x'] 237 | y = event_data['y'] 238 | 239 | xy_rect = self.rectify_events(x, y, location) 240 | x_rect = xy_rect[:, 0] 241 | y_rect = xy_rect[:, 1] 242 | 243 | event_representation = self.events_to_voxel_grid(x_rect, y_rect, p, t) 244 | 245 | if event_tensor is None: 246 | event_tensor = event_representation 247 | else: 248 | event_tensor = torch.cat([event_tensor, event_representation], dim=0) 249 | else: 250 | num_bins_total = self.nr_events_data * self.num_bins 251 | self.nr_events = self.nr_events_data * self.nr_events_per_data 252 | t_start_us_idx = index * self.nr_events 253 | t_end_us_idx = t_start_us_idx + self.nr_events 254 | event_data = self.event_slicers[location].get_events_fixed_num_recurrent(t_start_us_idx, t_end_us_idx) 255 | 256 | p = event_data['p'] 257 | t = event_data['t'] 258 | x = event_data['x'] 259 | y = event_data['y'] 260 | 261 | xy_rect = self.rectify_events(x, y, location) 262 | x_rect = xy_rect[:, 0] 263 | y_rect = xy_rect[:, 1] 264 | 265 | event_representation = self.events_to_voxel_grid(x_rect, y_rect, p, t) 266 | 267 | event_tensor = event_representation 268 | 269 | 270 | event_tensor = event_tensor[:, :-40, :] # remove 40 bottom rows 271 | 272 | if self.augmentation: 273 | if A_data is None: 274 | A_data = transform_a(image=event_tensor[0, :, :].numpy(), mask=label) 275 | label = A_data['mask'] 276 | for k in range(event_tensor.shape[0]): 277 | event_tensor[k, :, :] = torch.from_numpy( 278 | A.ReplayCompose.replay(A_data['replay'], image=event_tensor[k, :, :].numpy())['image']) 279 | 280 | if 'representation' not in output: 281 | output['representation'] = dict() 282 | output['representation'][location] = event_tensor 283 | 284 | label_tensor = torch.from_numpy(label).long() 285 | 286 | if self.require_paired_data: 287 | img_left_path = Path(self.img_left_pathstrings[index]) 288 | # output['img_left'] = self.get_img(img_left_path)[:, :-40, :] # remove 40 bottom rows 289 | output['img_left'] = self.get_img(img_left_path) 290 | return output['representation']['left'], output['img_left'], label_tensor 291 | return output['representation']['left'], label_tensor 292 | -------------------------------------------------------------------------------- /DSEC/dataset/visualization.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import matplotlib as mpl 3 | import matplotlib.cm as cm 4 | import numpy as np 5 | 6 | 7 | def disp_img_to_rgb_img(disp_array: np.ndarray): 8 | disp_pixels = np.argwhere(disp_array > 0) 9 | u_indices = disp_pixels[:, 1] 10 | v_indices = disp_pixels[:, 0] 11 | disp = disp_array[v_indices, u_indices] 12 | max_disp = 80 13 | 14 | norm = mpl.colors.Normalize(vmin=0, vmax=max_disp, clip=True) 15 | mapper = 
cm.ScalarMappable(norm=norm, cmap='inferno') 16 | 17 | disp_color = mapper.to_rgba(disp)[..., :3] 18 | output_image = np.zeros((disp_array.shape[0], disp_array.shape[1], 3)) 19 | output_image[v_indices, u_indices, :] = disp_color 20 | output_image = (255 * output_image).astype("uint8") 21 | output_image = cv2.cvtColor(output_image, cv2.COLOR_RGB2BGR) 22 | return output_image 23 | 24 | def show_image(image): 25 | cv2.namedWindow('viz', cv2.WND_PROP_FULLSCREEN) 26 | cv2.imshow('viz', image) 27 | cv2.waitKey(0) 28 | 29 | def get_disp_overlay(image_1c, disp_rgb_image, height, width): 30 | image = np.repeat(image_1c[..., np.newaxis], 3, axis=2) 31 | overlay = cv2.addWeighted(image, 0.1, disp_rgb_image, 0.9, 0) 32 | return overlay 33 | 34 | def show_disp_overlay(image_1c, disp_rgb_image, height, width): 35 | overlay = get_disp_overlay(image_1c, disp_rgb_image, height, width) 36 | show_image(overlay) 37 | -------------------------------------------------------------------------------- /DSEC/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uzh-rpg/ess/46bf1eed677d869c4733c89d2284e43ea27f97bd/DSEC/utils/__init__.py -------------------------------------------------------------------------------- /DSEC/utils/eventslicer.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Dict, Tuple 3 | 4 | import h5py 5 | import hdf5plugin 6 | from numba import jit 7 | import numpy as np 8 | 9 | 10 | class EventSlicer: 11 | def __init__(self, h5f: h5py.File): 12 | self.h5f = h5f 13 | 14 | self.events = dict() 15 | for dset_str in ['p', 'x', 'y', 't']: 16 | self.events[dset_str] = self.h5f['events/{}'.format(dset_str)] 17 | 18 | # This is the mapping from milliseconds to event index: 19 | # It is defined such that 20 | # (1) t[ms_to_idx[ms]] >= ms*1000 21 | # (2) t[ms_to_idx[ms] - 1] < ms*1000 22 | # ,where 'ms' is the time in milliseconds and 't' the event timestamps in microseconds. 
23 | # 24 | # As an example, given 't' and 'ms': 25 | # t: 0 500 2100 5000 5000 7100 7200 7200 8100 9000 26 | # ms: 0 1 2 3 4 5 6 7 8 9 27 | # 28 | # we get 29 | # 30 | # ms_to_idx: 31 | # 0 2 2 3 3 3 5 5 8 9 32 | self.ms_to_idx = np.asarray(self.h5f['ms_to_idx'], dtype='int64') 33 | 34 | if "t_offset" in list(h5f.keys()): 35 | self.t_offset = int(h5f['t_offset'][()]) 36 | else: 37 | self.t_offset = 0 38 | self.t_final = int(self.events['t'][-1]) + self.t_offset 39 | 40 | def get_start_time_us(self): 41 | return self.t_offset 42 | 43 | def get_final_time_us(self): 44 | return self.t_final 45 | 46 | def get_events(self, t_start_us: int, t_end_us: int, max_events_per_data: int = -1) -> Dict[str, np.ndarray]: 47 | """Get events (p, x, y, t) within the specified time window 48 | Parameters 49 | ---------- 50 | t_start_us: start time in microseconds 51 | t_end_us: end time in microseconds 52 | Returns 53 | ------- 54 | events: dictionary of (p, x, y, t) or None if the time window cannot be retrieved 55 | """ 56 | assert t_start_us < t_end_us 57 | 58 | # We assume that the times are top-off-day, hence subtract offset: 59 | t_start_us -= self.t_offset 60 | t_end_us -= self.t_offset 61 | 62 | # print(t_start_us, self.t_offset) 63 | t_start_ms, t_end_ms = self.get_conservative_window_ms(t_start_us, t_end_us) 64 | t_start_ms_idx = self.ms2idx(t_start_ms) 65 | t_end_ms_idx = self.ms2idx(t_end_ms) 66 | 67 | if t_start_ms_idx is None or t_end_ms_idx is None: 68 | print('Error', 'start', t_start_us, 'end', t_end_us) 69 | # Cannot guarantee window size anymore 70 | return None 71 | 72 | events = dict() 73 | time_array_conservative = np.asarray(self.events['t'][t_start_ms_idx:t_end_ms_idx]) 74 | idx_start_offset, idx_end_offset = self.get_time_indices_offsets(time_array_conservative, t_start_us, t_end_us) 75 | t_start_us_idx = t_start_ms_idx + idx_start_offset 76 | t_end_us_idx = t_start_ms_idx + idx_end_offset 77 | # Again add t_offset to get gps time 78 | events['t'] = time_array_conservative[idx_start_offset:idx_end_offset] + self.t_offset 79 | for dset_str in ['p', 'x', 'y']: 80 | events[dset_str] = np.asarray(self.events[dset_str][t_start_us_idx:t_end_us_idx]) 81 | assert events[dset_str].size == events['t'].size 82 | # if max_events_per_data != -1 and events['t'].size > max_events_per_data: 83 | # idx = np.round(np.linspace(0, events['t'].size - 1, max_events_per_data)).astype(int) 84 | # for key in events.keys(): 85 | # events[key] = events[key][idx] 86 | return events 87 | 88 | def get_events_fixed_num(self, t_end_us: int, nr_events: int = 100000) -> Dict[str, np.ndarray]: 89 | """Get events (p, x, y, t) with fixed number of events 90 | Parameters 91 | ---------- 92 | t_end_us: end time in microseconds 93 | nr_events: number of events to load 94 | Returns 95 | ------- 96 | events: dictionary of (p, x, y, t) or None if the time window cannot be retrieved 97 | """ 98 | # t_start_us = t_end_us - 1000 99 | # assert t_start_us < t_end_us 100 | 101 | # We assume that the times are top-off-day, hence subtract offset: 102 | # t_start_us -= self.t_offset 103 | t_end_us -= self.t_offset 104 | 105 | # print(t_start_us, self.t_offset) 106 | t_end_lower_ms, t_end_upper_ms = self.get_conservative_ms(t_end_us) 107 | t_end_lower_ms_idx = self.ms2idx(t_end_lower_ms) 108 | t_end_upper_ms_idx = self.ms2idx(t_end_upper_ms) 109 | 110 | if t_end_lower_ms_idx is None or t_end_upper_ms_idx is None: 111 | # Cannot guarantee window size anymore 112 | return None 113 | 114 | events = dict() 115 | time_array_conservative = 
np.asarray(self.events['t'][t_end_lower_ms_idx:t_end_upper_ms_idx]) 116 | _, idx_end_offset = self.get_time_indices_offsets(time_array_conservative, t_end_us, t_end_us) 117 | t_end_us_idx = t_end_lower_ms_idx + idx_end_offset 118 | t_start_us_idx = t_end_us_idx - nr_events 119 | if t_start_us_idx < 0: 120 | t_start_us_idx = 0 121 | 122 | for dset_str in self.events.keys(): 123 | events[dset_str] = np.asarray(self.events[dset_str][t_start_us_idx:t_end_us_idx]) 124 | 125 | return events 126 | 127 | def get_events_fixed_num_recurrent(self, t_start_us_idx: int, t_end_us_idx: int) -> Dict[str, np.ndarray]: 128 | """Get events (p, x, y, t) with fixed number of events 129 | Parameters 130 | ---------- 131 | t_start_us_idx: start id 132 | t_end_us_idx: end id 133 | Returns 134 | ------- 135 | events: dictionary of (p, x, y, t) or None if the time window cannot be retrieved 136 | """ 137 | assert t_start_us_idx < t_end_us_idx 138 | 139 | events = dict() 140 | for dset_str in self.events.keys(): 141 | events[dset_str] = np.asarray(self.events[dset_str][t_start_us_idx:t_end_us_idx]) 142 | 143 | return events 144 | 145 | 146 | @staticmethod 147 | def get_conservative_window_ms(ts_start_us: int, ts_end_us) -> Tuple[int, int]: 148 | """Compute a conservative time window of time with millisecond resolution. 149 | We have a time to index mapping for each millisecond. Hence, we need 150 | to compute the lower and upper millisecond to retrieve events. 151 | Parameters 152 | ---------- 153 | ts_start_us: start time in microseconds 154 | ts_end_us: end time in microseconds 155 | Returns 156 | ------- 157 | window_start_ms: conservative start time in milliseconds 158 | window_end_ms: conservative end time in milliseconds 159 | """ 160 | assert ts_end_us > ts_start_us 161 | window_start_ms = math.floor(ts_start_us/1000) 162 | window_end_ms = math.ceil(ts_end_us/1000) 163 | return window_start_ms, window_end_ms 164 | 165 | @staticmethod 166 | def get_conservative_ms(ts_us: int) -> Tuple[int, int]: 167 | """Convert time in microseconds into milliseconds 168 | ---------- 169 | ts_us: time in microseconds 170 | Returns 171 | ------- 172 | ts_lower_ms: lower millisecond 173 | ts_upper_ms: upper millisecond 174 | """ 175 | ts_lower_ms = math.floor(ts_us / 1000) 176 | ts_upper_ms = math.ceil(ts_us / 1000) 177 | return ts_lower_ms, ts_upper_ms 178 | 179 | @staticmethod 180 | @jit(nopython=True) 181 | def get_time_indices_offsets( 182 | time_array: np.ndarray, 183 | time_start_us: int, 184 | time_end_us: int) -> Tuple[int, int]: 185 | """Compute index offset of start and end timestamps in microseconds 186 | Parameters 187 | ---------- 188 | time_array: timestamps (in us) of the events 189 | time_start_us: start timestamp (in us) 190 | time_end_us: end timestamp (in us) 191 | Returns 192 | ------- 193 | idx_start: Index within this array corresponding to time_start_us 194 | idx_end: Index within this array corresponding to time_end_us 195 | such that (in non-edge cases) 196 | time_array[idx_start] >= time_start_us 197 | time_array[idx_end] >= time_end_us 198 | time_array[idx_start - 1] < time_start_us 199 | time_array[idx_end - 1] < time_end_us 200 | this means that 201 | time_start_us <= time_array[idx_start:idx_end] < time_end_us 202 | """ 203 | 204 | assert time_array.ndim == 1 205 | 206 | idx_start = -1 207 | if time_array[-1] < time_start_us: 208 | # This can happen in extreme corner cases. E.g. 
209 | # time_array[0] = 1016 210 | # time_array[-1] = 1984 211 | # time_start_us = 1990 212 | # time_end_us = 2000 213 | 214 | # Return same index twice: array[x:x] is empty. 215 | return time_array.size, time_array.size 216 | else: 217 | for idx_from_start in range(0, time_array.size, 1): 218 | if time_array[idx_from_start] >= time_start_us: 219 | idx_start = idx_from_start 220 | break 221 | assert idx_start >= 0 222 | 223 | idx_end = time_array.size 224 | for idx_from_end in range(time_array.size - 1, -1, -1): 225 | if time_array[idx_from_end] >= time_end_us: 226 | idx_end = idx_from_end 227 | else: 228 | break 229 | 230 | assert time_array[idx_start] >= time_start_us 231 | if idx_end < time_array.size: 232 | assert time_array[idx_end] >= time_end_us 233 | if idx_start > 0: 234 | assert time_array[idx_start - 1] < time_start_us 235 | if idx_end > 0: 236 | assert time_array[idx_end - 1] < time_end_us 237 | return idx_start, idx_end 238 | 239 | def ms2idx(self, time_ms: int) -> int: 240 | assert time_ms >= 0 241 | if time_ms >= self.ms_to_idx.size: 242 | return None 243 | return self.ms_to_idx[time_ms] 244 | -------------------------------------------------------------------------------- /DSEC/utils/viz_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import numpy as np 4 | import torchvision.utils 5 | import torch.nn.functional as F 6 | import matplotlib.pyplot as plt 7 | import itertools 8 | 9 | 10 | def createRGBGrid(tensor_list, nrow): 11 | """Creates a grid of rgb values based on the tensor stored in tensor_list""" 12 | vis_tensor_list = [] 13 | for tensor in tensor_list: 14 | vis_tensor_list.append(visualizeTensors(tensor)) 15 | 16 | return torchvision.utils.make_grid(torch.cat(vis_tensor_list, dim=0), nrow=nrow) 17 | 18 | 19 | def createRGBImage(tensor, separate_pol=True): 20 | """Creates a grid of rgb values based on the tensor stored in tensor_list""" 21 | if tensor.shape[1] == 3: 22 | return tensor 23 | elif tensor.shape[1] == 1: 24 | return tensor.expand(-1, 3, -1, -1) 25 | elif tensor.shape[1] == 2: 26 | return visualizeHistogram(tensor) 27 | elif tensor.shape[1] > 3: 28 | return visualizeVoxelGrid(tensor, separate_pol) 29 | 30 | 31 | def visualizeTensors(tensor): 32 | """Creates a rgb image of the given tensor. 
Can be event histogram, event voxel grid, grayscale and rgb.""" 33 | if tensor.shape[1] == 3: 34 | return tensor 35 | elif tensor.shape[1] == 1: 36 | return tensor.expand(-1, 3, -1, -1) 37 | elif tensor.shape[1] == 2: 38 | return visualizeHistogram(tensor) 39 | elif tensor.shape[1] > 3: 40 | return visualizeVoxelGrid(tensor) 41 | 42 | 43 | def visualizeHistogram(histogram): 44 | """Visualizes the input histogram""" 45 | batch, _, height, width = histogram.shape 46 | torch_image = torch.zeros([batch, 1, height, width], device=histogram.device) 47 | 48 | return torch.cat([histogram.clamp(0, 1), torch_image], dim=1) 49 | 50 | 51 | def visualizeVoxelGrid(voxel_grid, separate_pol=True): 52 | """Visualizes the input histogram""" 53 | batch, nr_channels, height, width = voxel_grid.shape 54 | if separate_pol: 55 | pos_events_idx = nr_channels // 2 56 | temporal_scaling = torch.arange(start=1, end=pos_events_idx+1, dtype=voxel_grid.dtype, 57 | device=voxel_grid.device)[None, :, None, None] / pos_events_idx 58 | pos_voxel_grid = voxel_grid[:, :pos_events_idx] * temporal_scaling 59 | neg_voxel_grid = voxel_grid[:, pos_events_idx:] * temporal_scaling 60 | 61 | torch_image = torch.zeros([batch, 1, height, width], device=voxel_grid.device) 62 | pos_image = torch.sum(pos_voxel_grid, dim=1, keepdim=True) 63 | neg_image = torch.sum(neg_voxel_grid, dim=1, keepdim=True) 64 | 65 | return torch.cat([neg_image.clamp(0, 1), pos_image.clamp(0, 1), torch_image], dim=1) 66 | 67 | sum_events = torch.sum(voxel_grid, dim=1).detach() 68 | event_preview = torch.zeros((batch, 3, height, width)) 69 | b = event_preview[:, 0, :, :] 70 | r = event_preview[:, 2, :, :] 71 | b[sum_events > 0] = 255 72 | r[sum_events < 0] = 255 73 | return event_preview 74 | 75 | 76 | def visualizeConfusionMatrix(confusion_matrix, path_name=None): 77 | """ 78 | Visualizes the confustion matrix using matplotlib. 79 | 80 | :param confusion_matrix: NxN numpy array 81 | :param path_name: if no path name is given, just an image is returned 82 | """ 83 | import matplotlib.pyplot as plt 84 | nr_classes = confusion_matrix.shape[0] 85 | fig, ax = plt.subplots(1, 1, figsize=(16, 16)) 86 | ax.matshow(confusion_matrix) 87 | ax.plot([-0.5, nr_classes - 0.5], [-0.5, nr_classes - 0.5], '-', color='grey') 88 | ax.set_xlabel('Labels') 89 | ax.set_ylabel('Predicted') 90 | 91 | if path_name is None: 92 | fig.tight_layout(pad=0) 93 | fig.canvas.draw() 94 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 95 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 96 | plt.close() 97 | 98 | return data 99 | 100 | else: 101 | fig.savefig(path_name) 102 | plt.close() 103 | 104 | 105 | def drawBoundingBoxes(np_image, bounding_boxes, class_name=None, ground_truth=True, rescale_image=False): 106 | """ 107 | Draws the bounding boxes in the image 108 | 109 | :param np_image: [H, W, C] 110 | :param bounding_boxes: list of bounding boxes with shape [x, y, width, height]. 
111 | :param class_name: string 112 | """ 113 | np_image = np_image.astype(np.float) 114 | resize_scale = 1.5 115 | if rescale_image: 116 | bounding_boxes[:, :4] = (bounding_boxes.astype(np.float)[:, :4] * resize_scale) 117 | new_dim = np.array(np_image.shape[:2], dtype=np.float) * resize_scale 118 | np_image = cv2.resize(np_image, tuple(new_dim.astype(int)[::-1]), interpolation=cv2.INTER_NEAREST) 119 | 120 | for i, bounding_box in enumerate(bounding_boxes): 121 | if bounding_box.sum() == 0: 122 | break 123 | if class_name is None: 124 | np_image = drawBoundingBox(np_image, bounding_box, ground_truth=ground_truth) 125 | else: 126 | np_image = drawBoundingBox(np_image, bounding_box, class_name[i], ground_truth) 127 | 128 | return np_image 129 | 130 | 131 | def drawBoundingBox(np_image, bounding_box, class_name=None, ground_truth=False): 132 | """ 133 | Draws a bounding box in the image. 134 | 135 | :param np_image: [H, W, C] 136 | :param bounding_box: [x, y, width, height]. 137 | :param class_name: string 138 | """ 139 | if ground_truth: 140 | bbox_color = np.array([0, 1, 1]) 141 | else: 142 | bbox_color = np.array([1, 0, 1]) 143 | height, width = bounding_box[2:4] 144 | 145 | np_image[bounding_box[0], bounding_box[1]:(bounding_box[1] + width)] = bbox_color 146 | np_image[bounding_box[0]:(bounding_box[0] + height), (bounding_box[1] + width)] = bbox_color 147 | np_image[(bounding_box[0] + height), bounding_box[1]:(bounding_box[1] + width)] = bbox_color 148 | np_image[bounding_box[0]:(bounding_box[0] + height), bounding_box[1]] = bbox_color 149 | 150 | if class_name is not None: 151 | font = cv2.FONT_HERSHEY_SIMPLEX 152 | font_color = (0, 0, 0) 153 | font_scale = 0.5 154 | thickness = 1 155 | bottom_left = tuple(((bounding_box[[1, 0]] + np.array([+1, height - 2]))).astype(int)) 156 | 157 | # Draw Box 158 | (text_width, text_height) = cv2.getTextSize(class_name, font, fontScale=font_scale, thickness=thickness)[0] 159 | box_coords = ((bottom_left[0], bottom_left[1] + 2), 160 | (bottom_left[0] + text_width + 2, bottom_left[1] - text_height - 2 + 2)) 161 | color_format = (int(bbox_color[0]), int(bbox_color[1]), int(bbox_color[2])) 162 | # np_image = cv2.UMat(np_image) 163 | np_image = cv2.UMat(np_image).get() 164 | cv2.rectangle(np_image, box_coords[0], box_coords[1], color_format, cv2.FILLED) 165 | 166 | cv2.putText(np_image, class_name, bottom_left, font, font_scale, font_color, thickness, cv2.LINE_AA) 167 | 168 | return np_image 169 | 170 | 171 | def visualizeFlow(tensor_flow_map): 172 | """ 173 | Visualizes the direction flow based on the HSV model 174 | """ 175 | np_flow_map = tensor_flow_map.cpu().detach().numpy() 176 | batch_s, channel, height, width = np_flow_map.shape 177 | viz_array = np.zeros([batch_s, height, width, 3], dtype=np.uint8) 178 | hsv = np.zeros([height, width, 3], dtype=np.uint8) 179 | 180 | for i, sample_flow_map in enumerate(np_flow_map): 181 | hsv[..., 1] = 255 182 | mag, ang = cv2.cartToPolar(sample_flow_map[0, :, :], sample_flow_map[1, :, :]) 183 | hsv[..., 0] = ang * 180 / np.pi / 2 184 | hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX) 185 | 186 | bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) 187 | viz_array[i] = bgr 188 | 189 | return torch.from_numpy(viz_array.transpose([0, 3, 1, 2]) / 255.).to(tensor_flow_map.device) 190 | 191 | def create_checkerboard(N, C, H, W): 192 | cell_sz = max(min(H, W) // 32, 1) 193 | mH = (H + cell_sz - 1) // cell_sz 194 | mW = (W + cell_sz - 1) // cell_sz 195 | checkerboard = torch.full((mH, mW), 0.25, dtype=torch.float32) 
196 | checkerboard[0::2, 0::2] = 0.75 197 | checkerboard[1::2, 1::2] = 0.75 198 | checkerboard = checkerboard.float().view(1, 1, mH, mW) 199 | checkerboard = F.interpolate(checkerboard, scale_factor=cell_sz, mode='nearest') 200 | checkerboard = checkerboard[:, :, :H, :W].repeat(N, C, 1, 1) 201 | return checkerboard 202 | 203 | 204 | def prepare_semseg(img, semseg_color_map, semseg_ignore_label): 205 | assert (img.dim() == 3 or img.dim() == 4 and img.shape[1] == 1) and img.dtype in (torch.int, torch.long), \ 206 | f'Expecting 4D tensor with semseg classes, got {img.shape}' 207 | if img.dim() == 4: 208 | img = img.squeeze(1) 209 | colors = torch.tensor(semseg_color_map, dtype=torch.float32) 210 | assert colors.dim() == 2 and colors.shape[1] == 3 211 | if torch.max(colors) > 128: 212 | colors /= 255 213 | img = img.cpu().clone() # N x H x W 214 | N, H, W = img.shape 215 | img_color_ids = torch.unique(img) 216 | assert all(c_id == semseg_ignore_label or 0 <= c_id < len(semseg_color_map) for c_id in img_color_ids) 217 | checkerboard, mask_ignore = None, None 218 | if semseg_ignore_label in img_color_ids: 219 | checkerboard = create_checkerboard(N, 3, H, W) 220 | mask_ignore = img == semseg_ignore_label 221 | img[mask_ignore] = 0 222 | img = colors[img] # N x H x W x 3 223 | img = img.permute(0, 3, 1, 2) 224 | 225 | # checkerboard 226 | # if semseg_ignore_label in img_color_ids: 227 | # mask_ignore = mask_ignore.unsqueeze(1).repeat(1, 3, 1, 1) 228 | # img[mask_ignore] = checkerboard[mask_ignore] 229 | return img 230 | 231 | 232 | def plot_confusion_matrix(cm, classes, 233 | normalize=False, 234 | title='Confusion matrix', 235 | cmap=plt.cm.Blues): 236 | """ 237 | This function prints and plots the confusion matrix. 238 | Normalization can be applied by setting `normalize=True`. 239 | """ 240 | cm = cm.numpy() 241 | if normalize: 242 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] 243 | print("Normalized confusion matrix") 244 | else: 245 | print('Confusion matrix, without normalization') 246 | 247 | # print(cm) 248 | 249 | fig = plt.figure() 250 | 251 | plt.imshow(cm, interpolation='nearest', cmap=cmap) 252 | plt.title(title) 253 | plt.colorbar() 254 | tick_marks = np.arange(len(classes)) 255 | plt.xticks(tick_marks, classes, rotation=45) 256 | plt.yticks(tick_marks, classes) 257 | 258 | fmt = '.2f' if normalize else 'd' 259 | thresh = cm.max() / 2. 
260 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): 261 | plt.text(j, i, format(cm[i, j], fmt), 262 | horizontalalignment="center", 263 | color="white" if cm[i, j] > thresh else "black") 264 | 265 | plt.tight_layout() 266 | plt.ylabel('True label') 267 | plt.xlabel('Predicted label') 268 | return fig 269 | -------------------------------------------------------------------------------- /DSEC/visualization/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uzh-rpg/ess/46bf1eed677d869c4733c89d2284e43ea27f97bd/DSEC/visualization/__init__.py -------------------------------------------------------------------------------- /DSEC/visualization/eventreader.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import weakref 3 | 4 | import h5py 5 | 6 | from utils.eventslicer import EventSlicer 7 | 8 | 9 | class EventReaderAbstract: 10 | def __init__(self, filepath: Path): 11 | assert filepath.is_file() 12 | assert filepath.name.endswith('.h5') 13 | self.h5f = h5py.File(str(filepath), 'r') 14 | self._finalizer = weakref.finalize(self, self.close_callback, self.h5f) 15 | 16 | @staticmethod 17 | def close_callback(h5f: h5py.File): 18 | h5f.close() 19 | 20 | def __enter__(self): 21 | return self 22 | 23 | def __exit__(self, exc_type, exc_value, traceback): 24 | self._finalizer() 25 | 26 | def __iter__(self): 27 | return self 28 | 29 | def __next__(self): 30 | raise NotImplementedError 31 | 32 | 33 | class EventReader(EventReaderAbstract): 34 | def __init__(self, filepath: Path, dt_milliseconds: int): 35 | super().__init__(filepath) 36 | self.event_slicer = EventSlicer(self.h5f) 37 | 38 | self.dt_us = int(dt_milliseconds * 1000) 39 | self.t_start_us = self.event_slicer.get_start_time_us() 40 | self.t_end_us = self.event_slicer.get_final_time_us() 41 | 42 | self._length = (self.t_end_us - self.t_start_us)//self.dt_us 43 | 44 | def __len__(self): 45 | return self._length 46 | 47 | def __next__(self): 48 | t_end_us = self.t_start_us + self.dt_us 49 | if t_end_us > self.t_end_us: 50 | raise StopIteration 51 | events = self.event_slicer.get_events(self.t_start_us, t_end_us) 52 | if events is None: 53 | raise StopIteration 54 | 55 | self.t_start_us = t_end_us 56 | return events 57 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ESS: Learning Event-based Semantic Segmentation from Still Images 2 | 3 |
![Transfer](resources/ESS.png) 9 | This is the code for the paper **ESS: Learning Event-based Semantic Segmentation from Still Images** 10 | ([PDF](https://rpg.ifi.uzh.ch/docs/ECCV22_Sun.pdf)) by Zhaoning Sun*, [Nico Messikommer](https://messikommernico.github.io/)*, [Daniel Gehrig](https://danielgehrig18.github.io), and [Davide Scaramuzza](http://rpg.ifi.uzh.ch/people_scaramuzza.html). For an overview of our method, check out our [video](https://youtu.be/Tby5c9IDsDc). 11 | 12 | If you use any of this code, please cite the following publication: 13 | 14 | ```bibtex 15 | @Article{Sun22eccv, 16 | author = {Zhaoning Sun* and Nico Messikommer* and Daniel Gehrig and Davide Scaramuzza}, 17 | title = {ESS: Learning Event-based Semantic Segmentation from Still Images}, 18 | journal = {European Conference on Computer Vision. (ECCV)}, 19 | year = {2022}, 20 | } 21 | ``` 22 | 23 | ## Abstract 24 | Retrieving accurate semantic information in challenging high dynamic range (HDR) and high-speed conditions 25 | remains an open challenge for image-based algorithms due to severe image degradations. Event cameras promise 26 | to address these challenges since they feature a much higher dynamic range and are resilient to motion blur. 27 | Nonetheless, semantic segmentation with event cameras is still in its infancy, which is chiefly due to the 28 | novelty of the sensor and the lack of high-quality, labeled datasets. In this work, we introduce ESS, which 29 | tackles this problem by directly transferring the semantic segmentation task from existing labeled image 30 | datasets to unlabeled events via unsupervised domain adaptation (UDA). Compared to existing UDA methods, 31 | our approach aligns recurrent, motion-invariant event embeddings with image embeddings. For this reason, our 32 | method neither requires video data nor per-pixel alignment between images and events and, crucially, does not 33 | need to hallucinate motion from still images. Additionally, to spur further research in event-based semantic 34 | segmentation, we introduce DSEC-Semantic, the first large-scale event-based dataset with fine-grained labels. 35 | We show that using image labels alone, ESS outperforms existing UDA approaches, and when combined with event 36 | labels, it even outperforms state-of-the-art supervised approaches on both DDD17 and DSEC-Semantic. Finally, 37 | ESS is general-purpose, which unlocks the vast amount of existing labeled image datasets and paves the way 38 | for new and exciting research directions in new fields previously inaccessible for event cameras. 39 | 40 | ## Installation 41 | 42 | ### Dependencies 43 | If desired, a conda environment can be created using the following command: 44 | ```bash 45 | conda create -n <env_name> 46 | ``` 47 | As an initial step, the wheel package needs to be installed with the following command: 48 | ```bash 49 | pip install wheel 50 | ``` 51 | The required python packages are listed in the `requirements.txt` file. 52 | ```bash 53 | pip install -r requirements.txt 54 | ``` 55 | 56 | ### Pre-trained E2VID Model 57 | The pre-trained E2VID model needs to be downloaded [here](https://github.com/uzh-rpg/rpg_e2vid) and placed in `/e2vid/pretrained/`. 58 | 59 | ## Datasets 60 | 61 | ### DSEC-Semantic 62 | The DSEC-Semantic dataset can be downloaded [here](https://dsec.ifi.uzh.ch/dsec-semantic/).
The dataset should have the following format: 63 | 64 | ├── DSEC_Semantic 65 | │ ├── train 66 | │ │ ├── zurich_city_00_a 67 | │ │ │ ├── semantic 68 | │ │ │ │ ├── left 69 | │ │ │ │ │ ├── 11classes 70 | │ │ │ │ │ │ └──data 71 | │ │ │ │ │ │ ├── 000000.png 72 | │ │ │ │ │ │ └── ... 73 | │ │ │ │ │ └── 19classes 74 | │ │ │ │ │ └──data 75 | │ │ │ │ │ ├── 000000.png 76 | │ │ │ │ │ └── ... 77 | │ │ │ │ └── timestamps.txt 78 | │ │ │ └── events 79 | │ │ │ └── left 80 | │ │ │ ├── events.h5 81 | │ │ │ └── rectify_map.h5 82 | │ │ └── ... 83 | │ └── test 84 | │ ├── zurich_city_13_a 85 | │ │ └── ... 86 | │ └── ... 87 | 88 | ### DDD17 89 | The original DDD17 dataset with semantic segmentation labels can be downloaded [here](https://github.com/Shathe/Ev-SegNet). 90 | Additionally, we provide a pre-processed DDD17 dataset with semantic labels [here](https://download.ifi.uzh.ch/rpg/ESS/ddd17_seg.tar.gz). Please do not forget to cite [DDD17](https://sensors.ini.uzh.ch/news_page/DDD17.html) and [Ev-SegNet](https://github.com/Shathe/Ev-SegNet) if you are using DDD17 with semantic labels. 91 | 92 | ### Cityscapes 93 | The Cityscapes dataset can be downloaded [here](https://www.cityscapes-dataset.com/). 94 | 95 | ## Training 96 | The settings for the training can be specified in `config/settings_XXXX.yaml`. 97 | Two different models can be trained: 98 | - ess: ESS UDA / ESS supervised (events labels + frames labels) 99 | - ess_supervised: ESS supervised (only events labels) 100 | 101 | The following command starts the training: 102 | 103 | ```bash 104 | CUDA_VISIBLE_DEVICES=<GPU_ID>, python train.py --settings_file config/settings_XXXX.yaml 105 | ``` 106 | 107 | For testing the pre-trained models, please set `load_pretrained_weights=True` and specify the path of the pre-trained weights in `pretrained_file`. 108 | 109 | ## Pre-trained Weights 110 | To download the pre-trained weights for the models on DDD17 and DSEC in the UDA setting, please fill in your details in [this](https://docs.google.com/forms/d/e/1FAIpQLScn5XWvBcmjPoSbaIqEUoEpWeheLGpQTUeK6Pp19wx2jNCPpA/viewform?usp=sf_link) Google Form.
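As a quick sanity check that the DSEC-Semantic layout above is in place, the data provider in `DSEC/dataset/provider.py` can be used directly. The sketch below is illustrative only: the dataset path is a placeholder, and the loader settings mirror the values in `config/settings_DSEC.yaml`.

```python
from pathlib import Path

from torch.utils.data import DataLoader

from DSEC.dataset.provider import DatasetProvider

# Placeholder path: point this at the DSEC_Semantic root shown in the tree above
# (both the train/ and test/ subfolders must exist).
dsec_root = Path('/data/DSEC_Semantic')

provider = DatasetProvider(dsec_root, mode='train', event_representation='voxel_grid',
                           nr_events_data=20, nr_events_window=100000,
                           nr_bins_per_data=5, semseg_num_classes=11)
loader = DataLoader(provider.get_train_dataset(), batch_size=8, shuffle=True, num_workers=8)

# With require_paired_data=False (the default), each sample is (events, label).
events, labels = next(iter(loader))
print(events.shape, labels.shape)
```

Note that `nr_events_window` has to be positive here: it is forwarded to `Sequence` as the number of events per temporal slice, and the default of -1 in `DatasetProvider` trips an assertion inside `Sequence`.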
111 | 112 | # Acknowledgement 113 | Several network architectures were adapted from: 114 | https://github.com/uzh-rpg/rpg_e2vid 115 | 116 | The general training framework was inspired by: 117 | https://github.com/uzh-rpg/rpg_ev-transfer 118 | 119 | The DSEC data loader was adapted from: 120 | https://github.com/uzh-rpg/DSEC 121 | 122 | The optimizer was adapted from: 123 | https://github.com/LiyuanLucasLiu/RAdam 124 | 125 | The DICE loss was adapted from: 126 | https://github.com/Guocode/DiceLoss.Pytorch -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uzh-rpg/ess/46bf1eed677d869c4733c89d2284e43ea27f97bd/config/__init__.py -------------------------------------------------------------------------------- /config/settings_DDD17.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | name_a: 'Cityscapes_gray' 3 | name_b: 'DDD17_events' 4 | DDD17_events: 5 | dataset_path: 6 | split_train: 'train' 7 | shape: [200, 346] 8 | nr_events_data: 20 9 | nr_events_files_per_data: None 10 | fixed_duration: False 11 | delta_t_per_data: 50 12 | require_paired_data_train: False 13 | require_paired_data_val: True 14 | nr_events_window: 32000 15 | event_representation: 'voxel_grid' 16 | nr_temporal_bins: 5 17 | separate_pol: False 18 | normalize_event: False 19 | cityscapes_img: 20 | dataset_path: 21 | shape: [200, 352] # [200, 352] for DDD17, [440, 640] for DSEC 22 | random_crop: True # True for DDD17, False for DSEC 23 | read_two_imgs: False 24 | require_paired_data_train: False 25 | require_paired_data_val: False 26 | task: 27 | semseg_num_classes: 6 # 6 for DDD17, 11 for DSEC 28 | dir: 29 | log: 30 | model: 31 | model_name: 'ess' # ['ess', 'ess_supervised'] 32 | skip_connect_encoder: True 33 | skip_connect_task: True 34 | skip_connect_task_type: 'concat' 35 | data_augmentation_train: True 36 | train_on_event_labels: False # True for ESS supervised (events labels + frames labels), False for ESS UDA 37 | optim: 38 | batch_size_a: 16 39 | batch_size_b: 16 40 | lr_front: 1e-5 41 | lr_back: 1e-4 42 | lr_decay: 1 43 | num_epochs: 20 44 | val_epoch_step: 1 45 | weight_task_loss: 1 46 | weight_cycle_pred_loss: 1 47 | weight_cycle_emb_loss: 0.01 48 | weight_cycle_task_loss: 0.01 49 | task_loss: ['dice', 'cross_entropy'] 50 | checkpoint: 51 | save_checkpoint: True 52 | resume_training: False 53 | load_pretrained_weights: False # True for loading pre-trained weights 54 | resume_file: 55 | pretrained_file: 56 | hardware: 57 | # num_cpu_workers: {-1: auto, 0: main thread, >0: ...} 58 | num_cpu_workers: 8 59 | gpu_device: 0 # [0 or 'cpu'] 60 | 61 | -------------------------------------------------------------------------------- /config/settings_DSEC.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | name_a: 'Cityscapes_gray' 3 | name_b: 'DSEC_events' 4 | DSEC_events: 5 | dataset_path: 6 | shape: [440, 640] 7 | nr_events_data: 20 8 | nr_events_files_per_data: None 9 | fixed_duration: False 10 | delta_t_per_data: 50 11 | require_paired_data_train: False 12 | require_paired_data_val: True 13 | nr_events_window: 100000 14 | event_representation: 'voxel_grid' 15 | 
nr_temporal_bins: 5 16 | separate_pol: False 17 | normalize_event: False 18 | cityscapes_img: 19 | dataset_path: 20 | shape: [440, 640] # [200, 352] for DDD17, [440, 640] for DSEC 21 | random_crop: False # True for DDD17, False for DSEC 22 | read_two_imgs: False 23 | require_paired_data_train: False 24 | require_paired_data_val: False 25 | task: 26 | semseg_num_classes: 11 # 6 for DDD17, 11 for DSEC 27 | dir: 28 | log: 29 | model: 30 | model_name: 'ess' # ['ess', 'ess_supervised'] 31 | skip_connect_encoder: True 32 | skip_connect_task: True 33 | skip_connect_task_type: 'concat' 34 | data_augmentation_train: True 35 | train_on_event_labels: False # True for ESS supervised (events labels + frames labels), False for ESS UDA 36 | optim: 37 | batch_size_a: 8 38 | batch_size_b: 8 39 | lr_front: 5e-4 40 | lr_back: 5e-4 41 | lr_decay: 1 42 | num_epochs: 50 43 | val_epoch_step: 5 44 | weight_task_loss: 1 45 | weight_cycle_pred_loss: 1 46 | weight_cycle_emb_loss: 1 47 | weight_cycle_task_loss: 1 48 | task_loss: ['dice', 'cross_entropy'] 49 | checkpoint: 50 | save_checkpoint: True 51 | resume_training: False 52 | load_pretrained_weights: False # True for loading pre-trained weights 53 | resume_file: 54 | pretrained_file: 55 | hardware: 56 | # num_cpu_workers: {-1: auto, 0: main thread, >0: ...} 57 | num_cpu_workers: 8 58 | gpu_device: 0 # [0 or 'cpu'] 59 | 60 | -------------------------------------------------------------------------------- /datasets/DSEC_events_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import numpy as np 5 | from PIL import Image 6 | 7 | import torchvision.transforms as transforms 8 | import datasets.data_util as data_util 9 | from pathlib import Path 10 | 11 | from DSEC.dataset.provider import DatasetProvider 12 | 13 | 14 | def DSECEvents(dsec_dir, nr_events_data=1, delta_t_per_data=50, nr_events_window=-1, 15 | augmentation=False, mode='train', task='segmentation', event_representation='voxel_grid', 16 | nr_bins_per_data=5, require_paired_data=False, separate_pol=False, normalize_event=False, 17 | semseg_num_classes=11, fixed_duration=False, resize=False): 18 | """ 19 | Creates an iterator over the EventScape dataset. 
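    The events are read from the DSEC-Semantic dataset (not EventScape) via DSEC.dataset.provider.DatasetProvider;
    the dsec_dir root is expected to contain the 'train' and 'test' subfolders shown in the README.
    :param nr_bins_per_data: number of temporal bins in each voxel grid
    :param semseg_num_classes: number of semantic classes (11 for the DSEC-Semantic labels used here)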
20 | 21 | :param root: path to dataset root 22 | :param height: height of dataset image 23 | :param width: width of dataset image 24 | :param nr_events_window: number of events summed in the sliding histogram 25 | :param augmentation: flip, shift and random window start for training 26 | :param mode: 'train', 'test' or 'val' 27 | """ 28 | dsec_dir = Path(dsec_dir) 29 | assert dsec_dir.is_dir() 30 | 31 | dataset_provider = DatasetProvider(dsec_dir, mode, event_representation=event_representation, 32 | nr_events_data=nr_events_data, delta_t_per_data=delta_t_per_data, 33 | nr_events_window=nr_events_window, nr_bins_per_data=nr_bins_per_data, 34 | require_paired_data=require_paired_data, normalize_event=normalize_event, 35 | separate_pol=separate_pol, semseg_num_classes=semseg_num_classes, 36 | augmentation=augmentation, fixed_duration=fixed_duration, resize=resize) 37 | if mode == 'train': 38 | train_dataset = dataset_provider.get_train_dataset() 39 | return train_dataset 40 | else: 41 | val_dataset = dataset_provider.get_val_dataset() 42 | return val_dataset 43 | -------------------------------------------------------------------------------- /datasets/cityscapes_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import Image 4 | from torchvision import datasets 5 | import torchvision.transforms as transforms 6 | from torch.utils.data import Dataset 7 | import albumentations as A 8 | from utils.labels import Id2label_6_Cityscapes, Id2label_11_Cityscapes, fromIdToTrainId 9 | 10 | 11 | class CityscapesGray(Dataset): 12 | def __init__(self, root, height=None, width=None, augmentation=False, split='train', target_type='semantic', 13 | semseg_num_classes=6, standardization=False, random_crop=True): 14 | 15 | self.root = root 16 | self.split = split 17 | self.height = height 18 | self.width = width 19 | self.random_crop = random_crop 20 | if self.random_crop: 21 | self.height_resize = 256 # 154 22 | self.width_resize = 512 # 308 23 | else: 24 | self.height_resize = height 25 | self.width_resize = width 26 | self.transform = transforms.Compose([ 27 | transforms.Grayscale(), 28 | transforms.Resize([self.height_resize, self.width_resize]) 29 | ]) 30 | self.cityscapes_dataset = datasets.Cityscapes(self.root, split=self.split, mode='fine', target_type=target_type, 31 | transform=self.transform, target_transform=None) 32 | self.augmentation = augmentation 33 | self.standardization = standardization 34 | if self.standardization: 35 | mean = 0.3091 36 | std = 0.1852 37 | self.standardization_a = A.Normalize(mean=mean, std=std) 38 | 39 | if self.augmentation: 40 | self.transform_a = A.Compose([ 41 | A.HorizontalFlip(p=0.5), 42 | A.ShiftScaleRotate(scale_limit=(0, 0.5), rotate_limit=0, shift_limit=0.1, p=0.5, border_mode=0), 43 | A.PadIfNeeded(min_height=self.height, min_width=self.width, always_apply=True, border_mode=0), 44 | A.RandomCrop(height=self.height, width=self.width, always_apply=True), 45 | A.GaussNoise(p=0.2), 46 | A.Perspective(p=0.2), 47 | A.RandomBrightnessContrast(p=0.5), 48 | A.OneOf( 49 | [ 50 | A.Sharpen(p=1), 51 | A.Blur(blur_limit=3, p=1), 52 | A.MotionBlur(blur_limit=3, p=1), 53 | ], 54 | p=0.5, 55 | ) 56 | ]) 57 | 58 | self.transform_a_random_crop = A.Compose([ 59 | A.HorizontalFlip(p=0.5), 60 | A.ShiftScaleRotate(scale_limit=(0, 0.5), rotate_limit=0, shift_limit=0, p=0.5, border_mode=0), 61 | A.PadIfNeeded(min_height=self.height, min_width=self.width, always_apply=True, border_mode=0), 62 | 
A.RandomCrop(height=self.height, width=self.width, always_apply=True), 63 | A.GaussNoise(p=0.2), 64 | A.Perspective(p=0.2), 65 | A.RandomBrightnessContrast(p=0.5), 66 | A.OneOf( 67 | [ 68 | A.Sharpen(p=1), 69 | A.Blur(blur_limit=3, p=1), 70 | A.MotionBlur(blur_limit=3, p=1), 71 | ], 72 | p=0.5, 73 | ) 74 | ]) 75 | 76 | self.transform_a_center_crop = A.Compose([ 77 | A.CenterCrop(height=self.height, width=self.width, always_apply=True), 78 | ]) 79 | 80 | self.semseg_num_classes = semseg_num_classes 81 | self.require_paired_data = False 82 | 83 | def __len__(self): 84 | return len(self.cityscapes_dataset) 85 | 86 | def __getitem__(self, idx): 87 | img, label = self.cityscapes_dataset[idx] 88 | img = np.array(img) 89 | label = label.resize((self.width_resize, self.height_resize), Image.NEAREST) 90 | label = np.array(label) 91 | 92 | if self.standardization: 93 | Imin = np.min(img) 94 | Imax = np.max(img) 95 | img = 255.0 * (img - Imin) / (Imax - Imin) 96 | img = img.astype('uint8') 97 | 98 | if self.random_crop: 99 | img = img[:self.height, :] 100 | label = label[:self.height, :] 101 | 102 | if self.augmentation: 103 | sample = self.transform_a_random_crop(image=img, mask=label) 104 | else: 105 | sample = self.transform_a_center_crop(image=img, mask=label) 106 | img, label = sample["image"], sample['mask'] 107 | 108 | else: 109 | if self.augmentation: 110 | sample = self.transform_a(image=img, mask=label) 111 | img, label = sample["image"], sample['mask'] 112 | 113 | img = Image.fromarray(img.astype('uint8')) 114 | 115 | if self.semseg_num_classes == 6: 116 | label = fromIdToTrainId(label, Id2label_6_Cityscapes) 117 | elif self.semseg_num_classes == 11: 118 | label = fromIdToTrainId(label, Id2label_11_Cityscapes) 119 | 120 | label_tensor = torch.from_numpy(label).long() 121 | img_transform = transforms.Compose([ 122 | transforms.ToTensor() 123 | ]) 124 | img_tensor = img_transform(img) 125 | 126 | return img_tensor, label_tensor 127 | 128 | -------------------------------------------------------------------------------- /datasets/data_util.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def generate_input_representation(events, event_representation, shape, nr_temporal_bins=5, separate_pol=True): 7 | """ 8 | Events: N x 4, where cols are x, y, t, polarity, and polarity is in {-1, 1}. x and y correspond to image 9 | coordinates u and v. 10 | """ 11 | if event_representation == 'histogram': 12 | return generate_event_histogram(events, shape) 13 | elif event_representation == 'voxel_grid': 14 | return generate_voxel_grid(events, shape, nr_temporal_bins, separate_pol) 15 | 16 | 17 | def generate_event_histogram(events, shape): 18 | """ 19 | Events: N x 4, where cols are x, y, t, polarity, and polarity is in {-1, 1}. x and y correspond to image 20 | coordinates u and v. 
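    Returns a (2, height, width) float32 array with negative-polarity counts in channel 0
    and positive-polarity counts in channel 1.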
21 | """ 22 | height, width = shape 23 | x, y, t, p = events.T 24 | x = x.astype(np.int) 25 | y = y.astype(np.int) 26 | p[p == 0] = -1 # polarity should be +1 / -1 27 | img_pos = np.zeros((height * width,), dtype="float32") 28 | img_neg = np.zeros((height * width,), dtype="float32") 29 | 30 | np.add.at(img_pos, x[p == 1] + width * y[p == 1], 1) 31 | np.add.at(img_neg, x[p == -1] + width * y[p == -1], 1) 32 | 33 | histogram = np.stack([img_neg, img_pos], 0).reshape((2, height, width)) 34 | 35 | return histogram 36 | 37 | 38 | def normalize_voxel_grid(events): 39 | """Normalize event voxel grids""" 40 | nonzero_ev = (events != 0) 41 | num_nonzeros = nonzero_ev.sum() 42 | if num_nonzeros > 0: 43 | # compute mean and stddev of the **nonzero** elements of the event tensor 44 | # we do not use PyTorch's default mean() and std() functions since it's faster 45 | # to compute it by hand than applying those funcs to a masked array 46 | mean = events.sum() / num_nonzeros 47 | stddev = torch.sqrt((events ** 2).sum() / num_nonzeros - mean ** 2) 48 | mask = nonzero_ev.float() 49 | events = mask * (events - mean) / stddev 50 | 51 | return events 52 | 53 | 54 | def generate_voxel_grid(events, shape, nr_temporal_bins, separate_pol=True): 55 | """ 56 | Build a voxel grid with bilinear interpolation in the time domain from a set of events. 57 | :param events: a [N x 4] NumPy array containing one event per row in the form: [timestamp, x, y, polarity] 58 | :param nr_temporal_bins: number of bins in the temporal axis of the voxel grid 59 | :param shape: dimensions of the voxel grid 60 | """ 61 | height, width = shape 62 | assert(events.shape[1] == 4) 63 | assert(nr_temporal_bins > 0) 64 | assert(width > 0) 65 | assert(height > 0) 66 | 67 | voxel_grid_positive = np.zeros((nr_temporal_bins, height, width), np.float32).ravel() 68 | voxel_grid_negative = np.zeros((nr_temporal_bins, height, width), np.float32).ravel() 69 | 70 | # normalize the event timestamps so that they lie between 0 and num_bins 71 | last_stamp = events[-1, 2] 72 | first_stamp = events[0, 2] 73 | deltaT = last_stamp - first_stamp 74 | 75 | if deltaT == 0: 76 | deltaT = 1.0 77 | 78 | # events[:, 2] = (nr_temporal_bins - 1) * (events[:, 2] - first_stamp) / deltaT 79 | xs = events[:, 0].astype(np.int) 80 | ys = events[:, 1].astype(np.int) 81 | # ts = events[:, 2] 82 | # print(ts[:10]) 83 | ts = (nr_temporal_bins - 1) * (events[:, 2] - first_stamp) / deltaT 84 | 85 | pols = events[:, 3] 86 | pols[pols == 0] = -1 # polarity should be +1 / -1 87 | 88 | tis = ts.astype(np.int) 89 | dts = ts - tis 90 | vals_left = np.abs(pols) * (1.0 - dts) 91 | vals_right = np.abs(pols) * dts 92 | pos_events_indices = pols == 1 93 | 94 | # Positive Voxels Grid 95 | valid_indices_pos = np.logical_and(tis < nr_temporal_bins, pos_events_indices) 96 | valid_pos = (xs < width) & (xs >= 0) & (ys < height) & (ys >= 0) & (ts >= 0) & (ts < nr_temporal_bins) 97 | valid_indices_pos = np.logical_and(valid_indices_pos, valid_pos) 98 | 99 | np.add.at(voxel_grid_positive, xs[valid_indices_pos] + ys[valid_indices_pos] * width + 100 | tis[valid_indices_pos] * width * height, vals_left[valid_indices_pos]) 101 | 102 | valid_indices_pos = np.logical_and((tis + 1) < nr_temporal_bins, pos_events_indices) 103 | valid_indices_pos = np.logical_and(valid_indices_pos, valid_pos) 104 | np.add.at(voxel_grid_positive, xs[valid_indices_pos] + ys[valid_indices_pos] * width + 105 | (tis[valid_indices_pos] + 1) * width * height, vals_right[valid_indices_pos]) 106 | 107 | # Negative Voxels Grid 108 | 
valid_indices_neg = np.logical_and(tis < nr_temporal_bins, ~pos_events_indices) 109 | valid_indices_neg = np.logical_and(valid_indices_neg, valid_pos) 110 | 111 | np.add.at(voxel_grid_negative, xs[valid_indices_neg] + ys[valid_indices_neg] * width + 112 | tis[valid_indices_neg] * width * height, vals_left[valid_indices_neg]) 113 | 114 | valid_indices_neg = np.logical_and((tis + 1) < nr_temporal_bins, ~pos_events_indices) 115 | valid_indices_neg = np.logical_and(valid_indices_neg, valid_pos) 116 | np.add.at(voxel_grid_negative, xs[valid_indices_neg] + ys[valid_indices_neg] * width + 117 | (tis[valid_indices_neg] + 1) * width * height, vals_right[valid_indices_neg]) 118 | 119 | voxel_grid_positive = np.reshape(voxel_grid_positive, (nr_temporal_bins, height, width)) 120 | voxel_grid_negative = np.reshape(voxel_grid_negative, (nr_temporal_bins, height, width)) 121 | 122 | if separate_pol: 123 | return np.concatenate([voxel_grid_positive, voxel_grid_negative], axis=0) 124 | 125 | voxel_grid = voxel_grid_positive - voxel_grid_negative 126 | return voxel_grid 127 | 128 | -------------------------------------------------------------------------------- /datasets/ddd17_events_loader.py: -------------------------------------------------------------------------------- 1 | import glob 2 | from os.path import join, exists, dirname, basename 3 | import os 4 | import cv2 5 | import torch 6 | import numpy as np 7 | from torch.utils.data import Dataset 8 | import torch.nn.functional as f 9 | import torchvision.transforms as transforms 10 | 11 | from datasets.extract_data_tools.example_loader_ddd17 import load_files_in_directory, extract_events_from_memmap 12 | import datasets.data_util as data_util 13 | import albumentations as A 14 | from PIL import Image 15 | from utils.labels import shiftUpId, shiftDownId 16 | 17 | 18 | def get_split(dirs, split): 19 | return { 20 | "train": [dirs[0], dirs[2], dirs[3], dirs[5], dirs[6]], 21 | "test": [dirs[4]], 22 | "valid": [dirs[1]] 23 | }[split] 24 | 25 | 26 | def unzip_segmentation_masks(dirs): 27 | for d in dirs: 28 | assert exists(join(d, "segmentation_masks.zip")) 29 | if not exists(join(d, "segmentation_masks")): 30 | print("Unzipping segmentation mask in %s" % d) 31 | os.system("unzip %s -d %s" % (join(d, "segmentation_masks"), d)) 32 | 33 | 34 | class DDD17Events(Dataset): 35 | def __init__(self, root, split="train", event_representation='voxel_grid', 36 | nr_events_data=5, delta_t_per_data=50, nr_bins_per_data=5, require_paired_data=False, 37 | separate_pol=False, normalize_event=False, augmentation=False, fixed_duration=False, 38 | nr_events_per_data=32000, resize=True, random_crop=False): 39 | data_dirs = sorted(glob.glob(join(root, "dir*"))) 40 | assert len(data_dirs) > 0 41 | assert split in ["train", "valid", "test"] 42 | 43 | self.split = split 44 | self.augmentation = augmentation 45 | self.fixed_duration = fixed_duration 46 | self.nr_events_per_data = nr_events_per_data 47 | 48 | self.nr_events_data = nr_events_data 49 | self.delta_t_per_data = delta_t_per_data 50 | if self.fixed_duration: 51 | self.t_interval = nr_events_data * delta_t_per_data 52 | else: 53 | self.t_interval = -1 54 | self.nr_events = self.nr_events_data * self.nr_events_per_data 55 | assert self.t_interval in [10, 50, 250, -1] 56 | self.nr_temporal_bins = nr_bins_per_data 57 | self.require_paired_data = require_paired_data 58 | self.event_representation = event_representation 59 | self.shape = [260, 346] 60 | self.resize = resize 61 | self.shape_resize = [260, 352] 62 | 
self.random_crop = random_crop 63 | self.shape_crop = [120, 216] 64 | self.separate_pol = separate_pol 65 | self.normalize_event = normalize_event 66 | self.dirs = get_split(data_dirs, split) 67 | # unzip_segmentation_masks(self.dirs) 68 | 69 | self.files = [] 70 | for d in self.dirs: 71 | self.files += glob.glob(join(d, "segmentation_masks", "*.png")) 72 | 73 | print("[DDD17Segmentation]: Found %s segmentation masks for split %s" % (len(self.files), split)) 74 | 75 | # load events and image_idx -> event index mapping 76 | self.img_timestamp_event_idx = {} 77 | self.event_data = {} 78 | 79 | print("[DDD17Segmentation]: Loading real events.") 80 | self.event_dirs = self.dirs 81 | 82 | for d in self.event_dirs: 83 | img_timestamp_event_idx, t_events, xyp_events, _ = load_files_in_directory(d, self.t_interval) 84 | self.img_timestamp_event_idx[d] = img_timestamp_event_idx 85 | self.event_data[d] = [t_events, xyp_events] 86 | 87 | if self.augmentation: 88 | self.transform_a = A.ReplayCompose([ 89 | A.HorizontalFlip(p=0.5) 90 | ]) 91 | self.transform_a_random_crop = A.ReplayCompose([ 92 | A.HorizontalFlip(p=0.5), 93 | A.RandomCrop(height=self.shape_crop[0], width=self.shape_crop[1], always_apply=True)]) 94 | self.transform_a_center_crop = A.ReplayCompose([ 95 | A.CenterCrop(height=self.shape_crop[0], width=self.shape_crop[1], always_apply=True), 96 | ]) 97 | 98 | def __len__(self): 99 | return len(self.files) 100 | 101 | def apply_augmentation(self, transform_a, events, label): 102 | label = shiftUpId(label) 103 | A_data = transform_a(image=events[0, :, :].numpy(), mask=label) 104 | label = A_data['mask'] 105 | label = shiftDownId(label) 106 | if self.random_crop and self.split == 'train': 107 | events_tensor = torch.zeros((events.shape[0], self.shape_crop[0], self.shape_crop[1])) 108 | else: 109 | events_tensor = events 110 | for k in range(events.shape[0]): 111 | events_tensor[k, :, :] = torch.from_numpy( 112 | A.ReplayCompose.replay(A_data['replay'], image=events[k, :, :].numpy())['image']) 113 | return events_tensor, label 114 | 115 | def __getitem__(self, idx): 116 | segmentation_mask_file = self.files[idx] 117 | segmentation_mask = cv2.imread(segmentation_mask_file, 0) 118 | label_original = np.array(segmentation_mask) 119 | if self.resize: 120 | segmentation_mask = cv2.resize(segmentation_mask, (self.shape_resize[1], self.shape_resize[0] - 60), 121 | interpolation=cv2.INTER_NEAREST) 122 | label = np.array(segmentation_mask) 123 | 124 | directory = dirname(dirname(segmentation_mask_file)) 125 | 126 | img_idx = int(basename(segmentation_mask_file).split("_")[-1].split(".")[0]) - 1 127 | 128 | img_timestamp_event_idx = self.img_timestamp_event_idx[directory] 129 | t_events, xyp_events = self.event_data[directory] 130 | 131 | # events has form x, y, t_ns, p (in [0,1]) 132 | events = extract_events_from_memmap(t_events, xyp_events, img_idx, img_timestamp_event_idx, self.fixed_duration, 133 | self.nr_events) 134 | t_ns = events[:, 2] 135 | delta_t_ns = int((t_ns[-1] - t_ns[0]) / self.nr_events_data) 136 | nr_events_loaded = events.shape[0] 137 | nr_events_temp = nr_events_loaded // self.nr_events_data 138 | 139 | id_end = 0 140 | event_tensor = None 141 | for i in range(self.nr_events_data): 142 | id_start = id_end 143 | if self.fixed_duration: 144 | id_end = np.searchsorted(t_ns, t_ns[0] + (i + 1) * delta_t_ns) 145 | else: 146 | id_end += nr_events_temp 147 | 148 | if id_end > nr_events_loaded: 149 | id_end = nr_events_loaded 150 | 151 | event_representation = 
data_util.generate_input_representation(events[id_start:id_end], 152 | self.event_representation, 153 | self.shape, 154 | nr_temporal_bins=self.nr_temporal_bins, 155 | separate_pol=self.separate_pol) 156 | 157 | event_representation = torch.from_numpy(event_representation) 158 | 159 | if self.normalize_event: 160 | event_representation = data_util.normalize_voxel_grid(event_representation) 161 | 162 | if self.resize: 163 | event_representation_resize = f.interpolate(event_representation.unsqueeze(0), 164 | size=(self.shape_resize[0], self.shape_resize[1]), 165 | mode='bilinear', align_corners=True) 166 | event_representation = event_representation_resize.squeeze(0) 167 | 168 | if event_tensor is None: 169 | event_tensor = event_representation 170 | else: 171 | event_tensor = torch.cat([event_tensor, event_representation], dim=0) 172 | 173 | event_tensor = event_tensor[:, :-60, :] # remove 60 bottom rows 174 | 175 | if self.random_crop and self.split == 'train': 176 | event_tensor = event_tensor[:, -self.shape_crop[0]:, :] 177 | label = label[-self.shape_crop[0]:, :] 178 | if self.augmentation: 179 | event_tensor, label = self.apply_augmentation(self.transform_a_random_crop, event_tensor, label) 180 | 181 | else: 182 | if self.augmentation: 183 | event_tensor, label = self.apply_augmentation(self.transform_a, event_tensor, label) 184 | 185 | label_tensor = torch.from_numpy(label).long() 186 | 187 | if self.split == 'valid' and self.require_paired_data: 188 | segmentation_mask_filepath_list = str(segmentation_mask_file).split('/') 189 | segmentation_mask_filename = segmentation_mask_filepath_list[-1] 190 | dir_name = segmentation_mask_filepath_list[-3] 191 | filename_id = segmentation_mask_filename.split('_')[-1] 192 | img_filename = '_'.join(['img', filename_id]) 193 | img_filepath_list = segmentation_mask_filepath_list 194 | img_filepath_list[-2] = 'imgs' 195 | img_filepath_list[-1] = img_filename 196 | img_file = '/'.join(img_filepath_list) 197 | if not os.path.exists(img_file): 198 | img_filename = filename_id.zfill(14) 199 | img_filepath_list[-1] = img_filename 200 | img_file = '/'.join(img_filepath_list) 201 | img = Image.open(img_file) 202 | 203 | if self.resize: 204 | img = img.resize((self.shape_resize[1], self.shape_resize[0])) 205 | img_transform = transforms.Compose([ 206 | transforms.Grayscale(), 207 | transforms.ToTensor() 208 | ]) 209 | img_tensor = img_transform(img) 210 | img_tensor = img_tensor[:, :-60, :] 211 | 212 | label_original_tensor = torch.from_numpy(label_original).long() 213 | return event_tensor, img_tensor, label_tensor, label_original_tensor 214 | return event_tensor, label_tensor 215 | 216 | -------------------------------------------------------------------------------- /datasets/extract_data_tools/example_loader_ddd17.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import numpy as np 4 | import cv2 5 | import matplotlib.pyplot as plt 6 | import copy 7 | 8 | 9 | def load_files_in_directory(directory, t_interval=50): 10 | # Load files: these include these have the following form 11 | # 12 | # idx.npy : 13 | # t0_ns idx0 14 | # t1_ns idx1 15 | # ... 16 | # tj_ns idxj 17 | # ... 
18 | # tN_ns idxN 19 | # 20 | # This file contains a mapping from j -> tj_ns idxj, 21 | # where j+1 is the idx of the img with timestamp tj_ns (in nanoseconds) 22 | # and idxj is the idx of the last event before the img (in events.dat.t and events.dat.xyp) 23 | if t_interval == 10: 24 | img_timestamp_event_idx = np.load(os.path.join(directory, "index/index_10ms.npy")) 25 | elif t_interval == 50: 26 | img_timestamp_event_idx = np.load(os.path.join(directory, "index/index_50ms.npy")) 27 | elif t_interval == 250: 28 | img_timestamp_event_idx = np.load(os.path.join(directory, "index/index_250ms.npy")) 29 | else: 30 | img_timestamp_event_idx = np.load(os.path.join(directory, "index/index_50ms.npy")) 31 | 32 | # events.dat.t : 33 | # t0_ns 34 | # t1_ns 35 | # ... 36 | # tM_ns 37 | # 38 | # events.dat.xyp : 39 | # x0 y0 p0 40 | # ... 41 | # xM yM pM 42 | events_t_file = os.path.join(directory, "events.dat.t") 43 | events_xyp_file = os.path.join(directory, "events.dat.xyp") 44 | 45 | t_events, xyp_events = load_events(events_t_file, events_xyp_file) 46 | 47 | # Since the imgs are in a video format, they cannot be loaded directly, however, the segmentation masks from the 48 | # original dataset (EV-SegNet) have been copied into this folder. First unzip the segmentation masks with 49 | # 50 | # unzip segmentation_masks.zip 51 | # 52 | segmentation_mask_files = sorted(glob.glob(os.path.join(directory, "segmentation_masks", "*.png"))) 53 | 54 | return img_timestamp_event_idx, t_events, xyp_events, segmentation_mask_files 55 | 56 | 57 | def load_events(t_file, xyp_file): 58 | # events.dat.t saves the timestamps of the indiviual events (in nanoseconds -> int64) 59 | # events.dat.xyp saves the x, y and polarity of events in uint8 to save storage. The polarity is 0 or 1. 60 | # We first need to compute the number of events in the memmap since it does not do it for us. 
We can do 61 | # this by computing the file size of the timestamps and dividing by 8 (since timestamps take 8 bytes) 62 | 63 | num_events = int(os.path.getsize(t_file) / 8) 64 | t_events = np.memmap(t_file, dtype="int64", mode="r", shape=(num_events, 1)) 65 | xyp_events = np.memmap(xyp_file, dtype="int16", mode="r", shape=(num_events, 3)) 66 | 67 | return t_events, xyp_events 68 | 69 | 70 | def extract_events_from_memmap(t_events, xyp_events, img_idx, img_timestamp_event_idx, fixed_duration=False, 71 | nr_events=32000): 72 | # timestep, event_idx = img_timestamp_event_idx[img_idx] 73 | # _, event_idx_before = img_timestamp_event_idx[img_idx - 1] 74 | if fixed_duration: 75 | timestep, event_idx, event_idx_before = img_timestamp_event_idx[img_idx] 76 | event_idx_before = max([event_idx_before, 0]) 77 | else: 78 | timestep, event_idx, _ = img_timestamp_event_idx[img_idx] 79 | event_idx_before = max([event_idx - nr_events, 0]) 80 | events_between_imgs = np.concatenate([ 81 | np.array(t_events[event_idx_before:event_idx], dtype="int64"), 82 | np.array(xyp_events[event_idx_before:event_idx], dtype="int64") 83 | ], -1) 84 | events_between_imgs = events_between_imgs[:, [1, 2, 0, 3]] # events have format xytp, and p is in [0,1] 85 | 86 | return events_between_imgs 87 | 88 | 89 | def generate_event_img(shape, events): 90 | H, W = shape 91 | # generate event img 92 | event_img_pos = np.zeros((H * W,), dtype="float32") 93 | event_img_neg = np.zeros((H * W,), dtype="float32") 94 | 95 | x, y, t, p = events.T 96 | 97 | np.add.at(event_img_pos, x[p == 1] + W * y[p == 1], p[p == 1]) 98 | np.add.at(event_img_neg, x[p == 0] + W * y[p == 0], p[p == 0] + 1) 99 | 100 | event_img_pos = event_img_pos.reshape((H, W)) 101 | event_img_neg = event_img_neg.reshape((H, W)) 102 | 103 | return event_img_neg, event_img_pos 104 | 105 | 106 | def generate_colored_label_img(shape, label_mask): 107 | H, W = shape 108 | 109 | colors = [[0, 0, 255], [255, 0, 0], [0, 0, 255], [255, 255, 0], [255, 0, 255], [0, 255, 255]] 110 | mask = segmentation_mask.reshape((-1, 3))[:, 0] 111 | 112 | img = np.zeros((H * W, 3), dtype="uint8") 113 | 114 | for i in np.unique(segmentation_mask): 115 | c = colors[int(i)] 116 | img[mask == i, 0] = c[0] 117 | img[mask == i, 1] = c[1] 118 | img[mask == i, 2] = c[2] 119 | 120 | img = img.reshape((H, W, 3)) 121 | 122 | return img 123 | 124 | 125 | def generate_rendered_events_on_img(img, event_map_neg, event_map_pos): 126 | orig_shape = img.shape 127 | 128 | img = img.copy() 129 | 130 | img = img.reshape((-1, 3)) 131 | pos_mask = event_map_pos.reshape((-1,)) > 0 132 | neg_mask = event_map_neg.reshape((-1,)) > 0 133 | 134 | img[neg_mask, 0] = 255 135 | img[pos_mask, 2] = 255 136 | img[neg_mask | pos_mask, 1] = 0 137 | 138 | img = img.reshape(orig_shape) 139 | 140 | return img 141 | 142 | 143 | if __name__ == "__main__": 144 | 145 | directories = sorted(glob.glob(os.path.join(os.path.dirname(__file__), "dir*"))) 146 | assert len(directories) > 0 147 | # example with one directory 148 | directory = directories[1] 149 | print(directories) 150 | print("Using directory: %s" % directory) 151 | 152 | # load all files that are in the directory (these are real events) 153 | img_timestamp_event_idx, t_events, xyp_events, segmentation_mask_files = \ 154 | load_files_in_directory(directory) 155 | 156 | # load all files that are in video_upsampled_events. These are simulated data. 
157 | sim_directory = os.path.join(directory, "video_upsampled_events") 158 | load_sim = os.path.exists(sim_directory) 159 | img_timestamp_event_idx_sim, t_events_sim, xyp_events_sim = None, None, None 160 | if load_sim: 161 | print("Loading sim data") 162 | img_timestamp_event_idx_sim, t_events_sim, xyp_events_sim, _ = \ 163 | load_files_in_directory(sim_directory) 164 | 165 | num_plots = 3 if load_sim else 2 166 | fig, ax = plt.subplots(ncols=num_plots) 167 | img_handles = [] 168 | assert len(segmentation_mask_files) > 0 169 | for segmentation_mask_file in segmentation_mask_files[-100:]: 170 | # take an example mask and extract the corresponding idx 171 | print("Using segmentation mask: %s" % segmentation_mask_file) 172 | segmentation_mask = cv2.imread(segmentation_mask_file) 173 | 174 | img_idx = int(os.path.basename(segmentation_mask_file).split("_")[-1].split(".")[0]) - 1 175 | print("Loading img with idx %s" % img_idx) 176 | 177 | # load corresponding img 178 | # first decompress video by running 179 | # 180 | # mkdir imgs 181 | # ffmpeg -i video.mp4 imgs/img_%08d.png 182 | # 183 | img_file = segmentation_mask_file.replace("segmentation_masks", "imgs").replace("/segmentation_", "/img_") 184 | # img_file = '/'.join(img_file.split('/')[:-1]) + '/' + '{:0>10}'.format(img_file.split('_')[-1][:-4]) + '.png' 185 | img = cv2.imread(img_file) 186 | 187 | # crop img since this was done in EV-SegNet 188 | img = img[:200] 189 | 190 | # find events between this idx and the last 191 | events_between_imgs = \ 192 | extract_events_from_memmap(t_events, xyp_events, img_idx, img_timestamp_event_idx) 193 | print("Found %s events" % (len(events_between_imgs))) 194 | 195 | # remove all events with y > 200 since these were cropped from the dataset 196 | events_between_imgs = events_between_imgs[events_between_imgs[:, 1] < 200] 197 | 198 | if load_sim: 199 | # find sim events between this idx and the last 200 | events_between_imgs_sim = \ 201 | extract_events_from_memmap(t_events_sim, xyp_events_sim, img_idx, img_timestamp_event_idx_sim) 202 | print("Found %s simulated events" % (len(events_between_imgs_sim))) 203 | 204 | # remove all events with y > 200 since these were cropped from the dataset 205 | events_between_imgs_sim = events_between_imgs_sim[events_between_imgs_sim[:, 1] < 200] 206 | 207 | event_img_neg, event_img_pos = generate_event_img((200, 346), events_between_imgs) 208 | event_img_neg_sim, event_img_pos_sim = generate_event_img((200, 346), events_between_imgs_sim) 209 | 210 | # generate view of labels 211 | colored_label_img = generate_colored_label_img((200, 346), segmentation_mask) 212 | 213 | # draw events on img 214 | rendered_events_on_img = generate_rendered_events_on_img(copy.deepcopy(img), event_img_neg, event_img_pos) 215 | 216 | if load_sim: 217 | # draw events on img 218 | rendered_events_on_img_sim = generate_rendered_events_on_img(copy.deepcopy(img), event_img_neg_sim, 219 | event_img_pos_sim) 220 | 221 | print("Error: ", 222 | np.abs((rendered_events_on_img_sim).astype("float32") - (rendered_events_on_img).astype("float32")).sum()) 223 | 224 | if len(img_handles) == 0: 225 | img_handles += [ax[0].imshow(colored_label_img)] 226 | img_handles += [ax[1].imshow(rendered_events_on_img)] 227 | if load_sim: 228 | img_handles += [ax[2].imshow(rendered_events_on_img_sim)] 229 | plt.show(block=False) 230 | else: 231 | img_handles[0].set_data(colored_label_img) 232 | img_handles[1].set_data(rendered_events_on_img) 233 | if load_sim: 234 | 
img_handles[2].set_data(rendered_events_on_img_sim) 235 | fig.canvas.draw() 236 | plt.pause(0.002) 237 | 238 | 239 | -------------------------------------------------------------------------------- /datasets/wrapper_dataloader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | 3 | 4 | class WrapperDataset(Dataset): 5 | def __init__(self, dataloader_a, dataloader_b, device, dataset_len_to_use=None): 6 | self.dataloader_a = dataloader_a 7 | self.dataloader_b = dataloader_b 8 | self.dataloader_a_iter = iter(self.dataloader_a) 9 | self.dataloader_b_iter = iter(self.dataloader_b) 10 | self.require_paired_data_a = dataloader_a.dataset.require_paired_data 11 | self.require_paired_data_b = dataloader_b.dataset.require_paired_data 12 | self.device = device 13 | 14 | self.dataset_a_larger = False 15 | if self.dataloader_a.__len__() > self.dataloader_b.__len__(): 16 | self.dataset_a_larger = True 17 | 18 | if dataset_len_to_use == 'first': 19 | self.dataset_a_larger = True 20 | elif dataset_len_to_use == 'second': 21 | self.dataset_a_larger = False 22 | 23 | def __len__(self): 24 | if self.dataset_a_larger: 25 | return self.dataloader_a.__len__() 26 | else: 27 | return self.dataloader_b.__len__() 28 | 29 | def createIterators(self): 30 | self.dataloader_a_iter = iter(self.dataloader_a) 31 | self.dataloader_b_iter = iter(self.dataloader_b) 32 | 33 | def __getitem__(self, idx): 34 | """ 35 | Returns two samples 36 | """ 37 | if not self.require_paired_data_a and not self.require_paired_data_b: 38 | if self.dataset_a_larger: 39 | try: 40 | dataset_b_data, dataset_b_label = next(self.dataloader_b_iter) 41 | except StopIteration: 42 | self.dataloader_b_iter = iter(self.dataloader_b) 43 | dataset_b_data, dataset_b_label = next(self.dataloader_b_iter) 44 | dataset_a_data, dataset_a_label = next(self.dataloader_a_iter) 45 | else: 46 | try: 47 | dataset_a_data, dataset_a_label = next(self.dataloader_a_iter) 48 | except StopIteration: 49 | self.dataloader_a_iter = iter(self.dataloader_a) 50 | dataset_a_data, dataset_a_label = next(self.dataloader_a_iter) 51 | dataset_b_data, dataset_b_label = next(self.dataloader_b_iter) 52 | 53 | return [dataset_a_data.to(self.device), dataset_a_label.to(self.device)], \ 54 | [dataset_b_data.to(self.device), dataset_b_label.to(self.device)] 55 | 56 | if self.require_paired_data_a and not self.require_paired_data_b: 57 | if self.dataset_a_larger: 58 | try: 59 | dataset_b_data, dataset_b_label = next(self.dataloader_b_iter) 60 | except StopIteration: 61 | self.dataloader_b_iter = iter(self.dataloader_b) 62 | dataset_b_data, dataset_b_label = next(self.dataloader_b_iter) 63 | dataset_a_data, dataset_a_paired_data, dataset_a_label = next(self.dataloader_a_iter) 64 | else: 65 | try: 66 | dataset_a_data, dataset_a_paired_data, dataset_a_label = next(self.dataloader_a_iter) 67 | except StopIteration: 68 | self.dataloader_a_iter = iter(self.dataloader_a) 69 | dataset_a_data, dataset_a_paired_data, dataset_a_label = next(self.dataloader_a_iter) 70 | dataset_b_data, dataset_b_label = next(self.dataloader_b_iter) 71 | 72 | return [dataset_a_data.to(self.device), dataset_a_paired_data.to(self.device), dataset_a_label.to(self.device)], \ 73 | [dataset_b_data.to(self.device), dataset_b_label.to(self.device)] 74 | 75 | if not self.require_paired_data_a and self.require_paired_data_b: 76 | if self.dataset_a_larger: 77 | try: 78 | dataset_b_data, dataset_b_paired_data, dataset_b_label = next(self.dataloader_b_iter) 
79 | except StopIteration: 80 | self.dataloader_b_iter = iter(self.dataloader_b) 81 | dataset_b_data, dataset_b_paired_data, dataset_b_label = next(self.dataloader_b_iter) 82 | dataset_a_data, dataset_a_label = next(self.dataloader_a_iter) 83 | else: 84 | try: 85 | dataset_a_data, dataset_a_label = next(self.dataloader_a_iter) 86 | except StopIteration: 87 | self.dataloader_a_iter = iter(self.dataloader_a) 88 | dataset_a_data, dataset_a_label = next(self.dataloader_a_iter) 89 | dataset_b_data, dataset_b_paired_data, dataset_b_label = next(self.dataloader_b_iter) 90 | 91 | return [dataset_a_data.to(self.device), dataset_a_label.to(self.device)], \ 92 | [dataset_b_data.to(self.device), dataset_b_paired_data.to(self.device), dataset_b_label.to(self.device)] 93 | 94 | if self.require_paired_data_a and self.require_paired_data_b: 95 | if self.dataset_a_larger: 96 | try: 97 | dataset_b_data, dataset_b_paired_data, dataset_b_label = next(self.dataloader_b_iter) 98 | except StopIteration: 99 | self.dataloader_b_iter = iter(self.dataloader_b) 100 | dataset_b_data, dataset_b_paired_data, dataset_b_label = next(self.dataloader_b_iter) 101 | dataset_a_data, dataset_a_paired_data, dataset_a_label = next(self.dataloader_a_iter) 102 | else: 103 | try: 104 | dataset_a_data, dataset_a_paired_data, dataset_a_label = next(self.dataloader_a_iter) 105 | except StopIteration: 106 | self.dataloader_a_iter = iter(self.dataloader_a) 107 | dataset_a_data, dataset_a_paired_data, dataset_a_label = next(self.dataloader_a_iter) 108 | dataset_b_data, dataset_b_paired_data, dataset_b_label = next(self.dataloader_b_iter) 109 | 110 | return [dataset_a_data.to(self.device), dataset_a_paired_data.to(self.device), dataset_a_label.to(self.device)], \ 111 | [dataset_b_data.to(self.device), dataset_b_paired_data.to(self.device), dataset_b_label.to(self.device)] -------------------------------------------------------------------------------- /e2vid/README.md: -------------------------------------------------------------------------------- 1 | # High Speed and High Dynamic Range Video with an Event Camera 2 | 3 | [![High Speed and High Dynamic Range Video with an Event Camera](http://rpg.ifi.uzh.ch/E2VID/video_thumbnail.png)](https://youtu.be/eomALySSGVU) 4 | 5 | This is the code for the paper **High Speed and High Dynamic Range Video with an Event Camera** by [Henri Rebecq](http://henri.rebecq.fr), Rene Ranftl, [Vladlen Koltun](http://vladlen.info/) and [Davide Scaramuzza](http://rpg.ifi.uzh.ch/people_scaramuzza.html): 6 | 7 | You can find a pdf of the paper [here](http://rpg.ifi.uzh.ch/docs/TPAMI19_Rebecq.pdf). 8 | If you use any of this code, please cite the following publications: 9 | 10 | ```bibtex 11 | @Article{Rebecq19pami, 12 | author = {Henri Rebecq and Ren{\'{e}} Ranftl and Vladlen Koltun and Davide Scaramuzza}, 13 | title = {High Speed and High Dynamic Range Video with an Event Camera}, 14 | journal = {{IEEE} Trans. Pattern Anal. Mach. Intell. (T-PAMI)}, 15 | url = {http://rpg.ifi.uzh.ch/docs/TPAMI19_Rebecq.pdf}, 16 | year = 2019 17 | } 18 | ``` 19 | 20 | 21 | ```bibtex 22 | @Article{Rebecq19cvpr, 23 | author = {Henri Rebecq and Ren{\'{e}} Ranftl and Vladlen Koltun and Davide Scaramuzza}, 24 | title = {Events-to-Video: Bringing Modern Computer Vision to Event Cameras}, 25 | journal = {{IEEE} Conf. Comput. Vis. Pattern Recog. 
(CVPR)}, 26 | year = 2019 27 | } 28 | ``` 29 | 30 | ## Install 31 | 32 | Dependencies: 33 | 34 | - [PyTorch](https://pytorch.org/get-started/locally/) >= 1.0 35 | - [NumPy](https://www.numpy.org/) 36 | - [Pandas](https://pandas.pydata.org/) 37 | - [OpenCV](https://opencv.org/) 38 | 39 | ### Install with Anaconda 40 | 41 | The installation requires [Anaconda3](https://www.anaconda.com/distribution/). 42 | You can create a new Anaconda environment with the required dependencies as follows (make sure to adapt the CUDA toolkit version according to your setup): 43 | 44 | ```bash 45 | conda create -n E2VID 46 | conda activate E2VID 47 | conda install pytorch torchvision cudatoolkit=10.0 -c pytorch 48 | conda install pandas 49 | conda install -c conda-forge opencv 50 | ``` 51 | 52 | ## Run 53 | 54 | - Download the pretrained model: 55 | 56 | ```bash 57 | wget "http://rpg.ifi.uzh.ch/data/E2VID/models/E2VID_lightweight.pth.tar" -O pretrained/E2VID_lightweight.pth.tar 58 | ``` 59 | 60 | - Download an example file with event data: 61 | 62 | ```bash 63 | wget "http://rpg.ifi.uzh.ch/data/E2VID/datasets/ECD_IJRR17/dynamic_6dof.zip" -O data/dynamic_6dof.zip 64 | ``` 65 | 66 | Before running the reconstruction, make sure the conda environment is sourced: 67 | 68 | ```bash 69 | conda activate E2VID 70 | ``` 71 | 72 | - Run reconstruction: 73 | 74 | ```bash 75 | python run_reconstruction.py \ 76 | -c pretrained/E2VID_lightweight.pth.tar \ 77 | -i data/dynamic_6dof.zip \ 78 | --auto_hdr \ 79 | --display \ 80 | --show_events 81 | ``` 82 | 83 | ## Parameters 84 | 85 | Below is a description of the most important parameters: 86 | 87 | #### Main parameters 88 | 89 | - ``--window_size`` / ``-N`` (default: None) Number of events per window. This is the parameter that has the most influence of the image reconstruction quality. If set to None, this number will be automatically computed based on the sensor size, as N = width * height * num_events_per_pixel (see description of that parameter below). Ignored if `--fixed_duration` is set. 90 | - ``--fixed_duration`` (default: False) If True, will use windows of events with a fixed duration (i.e. a fixed output frame rate). 91 | - ``--window_duration`` / ``-T`` (default: 33 ms) Duration of each event window, in milliseconds. The value of this parameter has strong influence on the image reconstruction quality. Its value may need to be adapted to the dynamics of the scene. Ignored if `--fixed_duration` is not set. 92 | - ``--Imin`` (default: 0.0), `--Imax` (default: 1.0): linear tone mapping is performed by normalizing the output image as follows: `I = (I - Imin) / (Imax - Imin)`. If `--auto_hdr` is set to True, `--Imin` and `--Imax` will be automatically computed as the min (resp. max) intensity values. 93 | - ``--auto_hdr`` (default: False) Automatically compute `--Imin` and `--Imax`. Disabled when `--color` is set. 94 | - ``--color`` (default: False): if True, will perform color reconstruction as described in the paper. Only use this with a [color event camera](http://rpg.ifi.uzh.ch/CED.html) such as the Color DAVIS346. 95 | 96 | #### Output parameters 97 | 98 | - ``--output_folder``: path of the output folder. If not set, the image reconstructions will not be saved to disk. 99 | - ``--dataset_name``: name of the output folder directory (default: 'reconstruction'). 100 | 101 | #### Display parameters 102 | 103 | - ``--display`` (default: False): display the video reconstruction in real-time in an OpenCV window. 
104 | - ``--show_events`` (default: False): show the input events side-by-side with the reconstruction. If ``--output_folder`` is set, the previews will also be saved to disk in ``/path/to/output/folder/events``. 105 | 106 | #### Additional parameters 107 | 108 | - ``--num_events_per_pixel`` (default: 0.35): Parameter used to automatically estimate the window size based on the sensor size. The value of 0.35 was chosen to correspond to ~ 15,000 events on a 240x180 sensor such as the DAVIS240C. 109 | - ``--no-normalize`` (default: False): Disable event tensor normalization: this will improve speed a bit, but might degrade the image quality a bit. 110 | - ``--no-recurrent`` (default: False): Disable the recurrent connection (i.e. do not maintain a state). For experimenting only, the results will be flickering a lot. 111 | - ``--hot_pixels_file`` (default: None): Path to a file specifying the locations of hot pixels (such a file can be obtained with [this tool](https://github.com/cedric-scheerlinck/dvs_tools/tree/master/dvs_hot_pixel_filter) for example). These pixels will be ignored (i.e. zeroed out in the event tensors). 112 | 113 | ## Example datasets 114 | 115 | We provide a list of example (publicly available) event datasets to get started with E2VID. 116 | 117 | - [High Speed (gun shooting!) and HDR Dataset](http://rpg.ifi.uzh.ch/E2VID.html) 118 | - [Event Camera Dataset](http://rpg.ifi.uzh.ch/data/E2VID/datasets/ECD_IJRR17/) 119 | - [Bardow et al., CVPR'16](http://rpg.ifi.uzh.ch/data/E2VID/datasets/SOFIE_CVPR16/) 120 | - [Scherlinck et al., ACCV'18](http://rpg.ifi.uzh.ch/data/E2VID/datasets/HF_ACCV18/) 121 | - [Color event sequences from the CED dataset Scheerlinck et al., CVPR'18](http://rpg.ifi.uzh.ch/data/E2VID/datasets/CED_CVPRW19/) 122 | 123 | ## Working with ROS 124 | 125 | Because PyTorch recommends Python 3 and ROS is only compatible with Python2, it is not straightforward to have the PyTorch reconstruction code and ROS code running in the same environment. 126 | To make things easy, the reconstruction code we provide has no dependency on ROS, and simply read events from a text file or ZIP file. 127 | We provide convenience functions to convert ROS bags (a popular format for event datasets) into event text files. 128 | In addition, we also provide scripts to convert a folder containing image reconstructions back to a rosbag (or to append image reconstructions to an existing rosbag). 129 | 130 | **Note**: it is **not** necessary to have a sourced conda environment to run the following scripts. However, [ROS](https://www.ros.org/) needs to be installed and sourced. 
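Once events have been extracted to a text file (see the conversion commands below), they can be inspected without any ROS installation. The following is a minimal sketch under the assumption of one event per line in `t x y p` order, optionally preceded by a single `width height` header line; the exact layout used by the reconstruction code should be checked against `e2vid/utils/event_readers.py`.

```python
# Minimal sketch: load an extracted event text file into NumPy.
# Assumptions (verify against e2vid/utils/event_readers.py): one event per line
# in "t x y p" order, optionally preceded by a single "width height" header line.
import numpy as np

path = '/path/to/output/folder/events.txt'               # placeholder
with open(path) as f:
    skip = 1 if len(f.readline().split()) == 2 else 0    # skip a "width height" header, if present
events = np.loadtxt(path, skiprows=skip)

t = events[:, 0]
x, y, p = events[:, 1].astype(int), events[:, 2].astype(int), events[:, 3].astype(int)
print('loaded %d events, time span %.3f (in the file time units)' % (len(t), t[-1] - t[0]))
```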
131 | 132 | ### rosbag -> events.txt 133 | 134 | To extract the events from a rosbag to a zip file containing the event data: 135 | 136 | ```bash 137 | python scripts/extract_events_from_rosbag.py /path/to/rosbag.bag \ 138 | --output_folder=/path/to/output/folder \ 139 | --event_topic=/dvs/events 140 | ``` 141 | 142 | ### image reconstruction folder -> rosbag 143 | 144 | ```bash 145 | python scripts/image_folder_to_rosbag.py \ 146 | --datasets dynamic_6dof \ 147 | --image_folder /path/to/image/folder \ 148 | --output_folder /path/to/output_folder \ 149 | --image_topic /dvs/image_reconstructed 150 | ``` 151 | 152 | ### Append image_reconstruction_folder to an existing rosbag 153 | 154 | ```bash 155 | cd scripts 156 | python embed_reconstructed_images_in_rosbag.py \ 157 | --rosbag_folder /path/to/rosbag/folder \ 158 | --datasets dynamic_6dof \ 159 | --image_folder /path/to/image/folder \ 160 | --output_folder /path/to/output_folder \ 161 | --image_topic /dvs/image_reconstructed 162 | ``` 163 | 164 | ### Generating a video reconstruction (with a fixed framerate) 165 | 166 | It can be convenient to convert an image folder to a video with a fixed framerate (for example for use in a video editing tool). 167 | You can proceed as follows: 168 | 169 | ```bash 170 | export FRAMERATE=30 171 | python resample_reconstructions.py -i /path/to/input_folder -o /tmp/resampled -r $FRAMERATE 172 | ffmpeg -framerate $FRAMERATE -i /tmp/resampled/frame_%010d.png video_"$FRAMERATE"Hz.mp4 173 | ``` 174 | 175 | ## Acknowledgements 176 | 177 | This code borrows from the following open source projects, whom we would like to thank: 178 | 179 | - [pytorch-template](https://github.com/victoresque/pytorch-template) 180 | -------------------------------------------------------------------------------- /e2vid/base/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_model import * 2 | -------------------------------------------------------------------------------- /e2vid/base/base_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class BaseModel(nn.Module): 7 | """ 8 | Base class for all models 9 | """ 10 | def __init__(self, config): 11 | super(BaseModel, self).__init__() 12 | self.config = config 13 | self.logger = logging.getLogger(self.__class__.__name__) 14 | 15 | def forward(self, *input): 16 | """ 17 | Forward pass logic 18 | 19 | :return: Model output 20 | """ 21 | raise NotImplementedError 22 | 23 | def summary(self): 24 | """ 25 | Model summary 26 | """ 27 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 28 | params = sum([np.prod(p.size()) for p in model_parameters]) 29 | self.logger.info('Trainable parameters: {}'.format(params)) 30 | self.logger.info(self) 31 | -------------------------------------------------------------------------------- /e2vid/image_reconstructor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import numpy as np 4 | from e2vid.model.model import * 5 | from e2vid.utils.inference_utils import CropParameters, EventPreprocessor, IntensityRescaler, ImageFilter, ImageDisplay, \ 6 | ImageWriter, UnsharpMaskFilter 7 | from e2vid.utils.inference_utils import upsample_color_image, \ 8 | merge_channels_into_color_image # for color reconstruction 9 | from e2vid.utils.util import robust_min, robust_max 10 | from e2vid.utils.timers 
import CudaTimer, cuda_timers 11 | from os.path import join 12 | from collections import deque 13 | import torchvision.transforms as transforms 14 | import albumentations as A 15 | from PIL import Image 16 | 17 | 18 | class ImageReconstructor: 19 | def __init__(self, model, height, width, num_bins, device, options, augmentation=False, standardization=False): 20 | 21 | self.model = model 22 | self.use_gpu = options.use_gpu 23 | self.device = device 24 | self.height = height 25 | self.width = width 26 | self.num_bins = num_bins 27 | 28 | self.standardization = standardization 29 | 30 | self.augmentation = augmentation 31 | if self.augmentation: 32 | self.transform_a = A.Compose([ 33 | A.GaussNoise(p=0.2), 34 | A.RandomBrightnessContrast(p=0.5), 35 | A.OneOf( 36 | [ 37 | A.Sharpen(p=1), 38 | A.Blur(blur_limit=3, p=1), 39 | A.MotionBlur(blur_limit=3, p=1), 40 | ], 41 | p=0.5, 42 | ) 43 | ]) 44 | self.img_transform = transforms.Compose([ 45 | transforms.Grayscale(), 46 | transforms.ToTensor() 47 | ]) 48 | 49 | self.initialize(self.height, self.width, options) 50 | 51 | def initialize(self, height, width, options): 52 | # print('== Image reconstruction == ') 53 | # print('Image size: {}x{}'.format(self.height, self.width)) 54 | 55 | self.no_recurrent = options.no_recurrent 56 | if self.no_recurrent: 57 | print('!!Recurrent connection disabled!!') 58 | 59 | self.perform_color_reconstruction = options.color # whether to perform color reconstruction (only use this with the DAVIS346color) 60 | if self.perform_color_reconstruction: 61 | if options.auto_hdr: 62 | print('!!Warning: disabling auto HDR for color reconstruction!!') 63 | options.auto_hdr = False # disable auto_hdr for color reconstruction (otherwise, each channel will be normalized independently) 64 | 65 | self.crop = CropParameters(self.width, self.height, self.model.num_encoders) 66 | 67 | self.last_states_for_each_channel = {'grayscale': None} 68 | 69 | if self.perform_color_reconstruction: 70 | self.crop_halfres = CropParameters(int(width / 2), int(height / 2), 71 | self.model.num_encoders) 72 | for channel in ['R', 'G', 'B', 'W']: 73 | self.last_states_for_each_channel[channel] = None 74 | 75 | self.event_preprocessor = EventPreprocessor(options) 76 | self.intensity_rescaler = IntensityRescaler(options) 77 | self.image_filter = ImageFilter(options) 78 | self.unsharp_mask_filter = UnsharpMaskFilter(options, device=self.device) 79 | # self.image_writer = ImageWriter(options) 80 | # self.image_display = ImageDisplay(options) 81 | 82 | def update_reconstruction(self, event_tensor, event_tensor_id=None, stamp=None): 83 | with torch.no_grad(): 84 | 85 | with CudaTimer('Reconstruction'): 86 | 87 | with CudaTimer('NumPy (CPU) -> Tensor (GPU)'): 88 | events = event_tensor 89 | events = events.to(self.device) 90 | 91 | events = self.event_preprocessor(events) 92 | 93 | # Resize tensor to [1 x C x crop_size x crop_size] by applying zero padding 94 | events_for_each_channel = {'grayscale': self.crop.pad(events)} 95 | reconstructions_for_each_channel = {} 96 | # if self.perform_color_reconstruction: 97 | # events_for_each_channel['R'] = self.crop_halfres.pad(events[:, :, 0::2, 0::2]) 98 | # events_for_each_channel['G'] = self.crop_halfres.pad(events[:, :, 0::2, 1::2]) 99 | # events_for_each_channel['W'] = self.crop_halfres.pad(events[:, :, 1::2, 0::2]) 100 | # events_for_each_channel['B'] = self.crop_halfres.pad(events[:, :, 1::2, 1::2]) 101 | 102 | # Reconstruct new intensity image for each channel (grayscale + RGBW if color reconstruction is 
enabled) 103 | for channel in events_for_each_channel.keys(): 104 | with CudaTimer('Inference'): 105 | new_predicted_frame, states, latent = self.model(events_for_each_channel[channel], 106 | self.last_states_for_each_channel[channel]) 107 | 108 | if self.no_recurrent: 109 | self.last_states_for_each_channel[channel] = None 110 | else: 111 | self.last_states_for_each_channel[channel] = states 112 | 113 | # Output reconstructed image 114 | # crop = self.crop if channel == 'grayscale' else self.crop_halfres 115 | 116 | # Unsharp mask (on GPU) 117 | # new_predicted_frame = self.unsharp_mask_filter(new_predicted_frame) 118 | 119 | # Intensity rescaler (on GPU) 120 | # new_predicted_frame = self.intensity_rescaler(new_predicted_frame) 121 | 122 | with CudaTimer('Tensor (GPU) -> NumPy (CPU)'): 123 | reconstructions_for_each_channel[channel] = new_predicted_frame 124 | # reconstructions_for_each_channel[channel] = new_predicted_frame.cpu().numpy() 125 | 126 | # if self.perform_color_reconstruction: 127 | # out = merge_channels_into_color_image(reconstructions_for_each_channel) 128 | # else: 129 | out = reconstructions_for_each_channel['grayscale'] 130 | 131 | if self.standardization: 132 | batch_size, height, width = out.size(0), out.size(2), out.size(3) 133 | out = out.view(out.size(0), -1) 134 | out -= out.min(1, keepdim=True)[0] 135 | out /= out.max(1, keepdim=True)[0] 136 | out = out.view(batch_size, 1, height, width) 137 | 138 | # Imin = torch.min(out).item() 139 | # Imax = torch.max(out).item() 140 | # out = 255.0 * (out - Imin) / (Imax - Imin) 141 | # out.clamp_(0.0, 255.0) 142 | # out = out.byte() # convert to 8-bit tensor 143 | # out = out.float().div(255) 144 | 145 | # mean = [0.5371 for i in range(out.shape[0])] 146 | # std = [0.1540 for i in range(out.shape[0])] 147 | # stand_transform = transforms.Normalize(mean=mean, std=std) 148 | # out = stand_transform(out.squeeze(1)).unsqueeze(1) 149 | # out = torch.clamp(out, min=-1.0, max=1.0) 150 | # out = (out + 1.0) / 2.0 151 | 152 | if self.augmentation: 153 | for i in range(out.shape[0]): 154 | img_aug = out[i].cpu() 155 | img_aug = transforms.ToPILImage()(img_aug) 156 | img_aug = np.array(img_aug) 157 | img_aug = self.transform_a(image=img_aug)["image"] 158 | img_aug = Image.fromarray(img_aug.astype('uint8')).convert('RGB') 159 | out[i] = self.img_transform(img_aug).to(self.device) 160 | # Post-processing, e.g bilateral filter (on CPU) 161 | # out = torch.from_numpy(self.image_filter(out)).to(self.device) 162 | 163 | return out, states, latent 164 | 165 | 166 | class PostProcessor: 167 | def __init__(self, device, options): 168 | self.device = device 169 | self.unsharp_mask_filter = UnsharpMaskFilter(options, device=self.device) 170 | self.intensity_rescaler = IntensityRescaler(options) 171 | self.image_filter = ImageFilter(options) 172 | 173 | def process(self, new_predicted_frame): 174 | with torch.no_grad(): 175 | # Unsharp mask (on GPU) 176 | new_predicted_frame = self.unsharp_mask_filter(new_predicted_frame) 177 | 178 | # Intensity rescaler (on GPU) 179 | new_predicted_frame = self.intensity_rescaler(new_predicted_frame) 180 | 181 | out = new_predicted_frame.cpu().numpy() 182 | 183 | # Post-processing, e.g bilateral filter (on CPU) 184 | out = torch.from_numpy(self.image_filter(out)).to(self.device) 185 | return out 186 | -------------------------------------------------------------------------------- /e2vid/model/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/uzh-rpg/ess/46bf1eed677d869c4733c89d2284e43ea27f97bd/e2vid/model/__init__.py -------------------------------------------------------------------------------- /e2vid/model/model.py: -------------------------------------------------------------------------------- 1 | from e2vid.base import BaseModel 2 | import torch.nn as nn 3 | import torch 4 | from e2vid.model.unet import UNet, UNetRecurrent, UNetDecoder, UNetTask 5 | from os.path import join 6 | from e2vid.model.submodules import ConvLSTM, ResidualBlock, ConvLayer, UpsampleConvLayer, TransposedConvLayer 7 | 8 | 9 | class BaseE2VID(BaseModel): 10 | def __init__(self, config): 11 | super().__init__(config) 12 | 13 | assert('num_bins' in config) 14 | self.num_bins = int(config['num_bins']) # number of bins in the voxel grid event tensor 15 | 16 | try: 17 | self.skip_type = str(config['skip_type']) 18 | except KeyError: 19 | self.skip_type = 'sum' 20 | 21 | try: 22 | self.num_encoders = int(config['num_encoders']) 23 | except KeyError: 24 | self.num_encoders = 4 25 | 26 | try: 27 | self.base_num_channels = int(config['base_num_channels']) 28 | except KeyError: 29 | self.base_num_channels = 32 30 | 31 | try: 32 | self.num_residual_blocks = int(config['num_residual_blocks']) 33 | except KeyError: 34 | self.num_residual_blocks = 2 35 | 36 | try: 37 | self.norm = str(config['norm']) 38 | except KeyError: 39 | self.norm = None 40 | 41 | try: 42 | self.use_upsample_conv = bool(config['use_upsample_conv']) 43 | except KeyError: 44 | self.use_upsample_conv = True 45 | 46 | 47 | class E2VID(BaseE2VID): 48 | def __init__(self, config): 49 | super(E2VID, self).__init__(config) 50 | 51 | self.unet = UNet(num_input_channels=self.num_bins, 52 | num_output_channels=1, 53 | skip_type=self.skip_type, 54 | activation='sigmoid', 55 | num_encoders=self.num_encoders, 56 | base_num_channels=self.base_num_channels, 57 | num_residual_blocks=self.num_residual_blocks, 58 | norm=self.norm, 59 | use_upsample_conv=self.use_upsample_conv) 60 | 61 | def forward(self, event_tensor, prev_states=None): 62 | """ 63 | :param event_tensor: N x num_bins x H x W 64 | :return: a predicted image of size N x 1 x H x W, taking values in [0,1]. 65 | """ 66 | return self.unet.forward(event_tensor), None 67 | 68 | 69 | class E2VIDRecurrent(BaseE2VID): 70 | """ 71 | Recurrent, UNet-like architecture where each encoder is followed by a ConvLSTM or ConvGRU. 72 | """ 73 | 74 | def __init__(self, config): 75 | super(E2VIDRecurrent, self).__init__(config) 76 | 77 | try: 78 | self.recurrent_block_type = str(config['recurrent_block_type']) 79 | except KeyError: 80 | self.recurrent_block_type = 'convlstm' # or 'convgru' 81 | 82 | self.unetrecurrent = UNetRecurrent(num_input_channels=self.num_bins, 83 | num_output_channels=1, 84 | skip_type=self.skip_type, 85 | recurrent_block_type=self.recurrent_block_type, 86 | activation='sigmoid', 87 | num_encoders=self.num_encoders, 88 | base_num_channels=self.base_num_channels, 89 | num_residual_blocks=self.num_residual_blocks, 90 | norm=self.norm, 91 | use_upsample_conv=self.use_upsample_conv) 92 | 93 | def forward(self, event_tensor, prev_states): 94 | """ 95 | :param event_tensor: N x num_bins x H x W 96 | :param prev_states: previous ConvLSTM state for each encoder module 97 | :return: reconstructed image, taking values in [0,1]. 
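Note: in this repository the call actually returns a tuple (img_pred, states, latent): the reconstructed image, the updated recurrent state of every encoder, and a dict of intermediate feature maps keyed by downsampling factor (see UNetRecurrent.forward).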
98 | """ 99 | img_pred, states, latent = self.unetrecurrent.forward(event_tensor, prev_states) 100 | return img_pred, states, latent 101 | 102 | class E2VIDDecoder(BaseE2VID): 103 | """ 104 | Recurrent, UNet-like architecture where each encoder is followed by a ConvLSTM or ConvGRU. 105 | """ 106 | 107 | def __init__(self, config): 108 | super(E2VIDDecoder, self).__init__(config) 109 | 110 | try: 111 | self.recurrent_block_type = str(config['recurrent_block_type']) 112 | except KeyError: 113 | self.recurrent_block_type = 'convlstm' # or 'convgru' 114 | 115 | self.unetrecurrent = UNetDecoder(num_input_channels=self.num_bins, 116 | num_output_channels=1, 117 | skip_type=self.skip_type, 118 | recurrent_block_type=self.recurrent_block_type, 119 | activation='sigmoid', 120 | num_encoders=self.num_encoders, 121 | base_num_channels=self.base_num_channels, 122 | num_residual_blocks=self.num_residual_blocks, 123 | norm=self.norm, 124 | use_upsample_conv=self.use_upsample_conv) 125 | 126 | def forward(self, x, blocks, head): 127 | """ 128 | :param event_tensor: N x num_bins x H x W 129 | :param prev_states: previous ConvLSTM state for each encoder module 130 | :return: reconstructed image, taking values in [0,1]. 131 | """ 132 | img_pred = self.unetrecurrent.forward(x, blocks, head) 133 | return img_pred 134 | 135 | class E2VIDTask(BaseE2VID): 136 | """ 137 | Recurrent, UNet-like architecture where each encoder is followed by a ConvLSTM or ConvGRU. 138 | """ 139 | 140 | def __init__(self, config): 141 | super(E2VIDTask, self).__init__(config) 142 | 143 | try: 144 | self.recurrent_block_type = str(config['recurrent_block_type']) 145 | except KeyError: 146 | self.recurrent_block_type = 'convlstm' # or 'convgru' 147 | 148 | self.unetrecurrent = UNetTask(num_input_channels=self.num_bins, 149 | num_output_channels=13, 150 | skip_type=self.skip_type, 151 | recurrent_block_type=self.recurrent_block_type, 152 | activation='sigmoid', 153 | num_encoders=self.num_encoders, 154 | base_num_channels=self.base_num_channels, 155 | num_residual_blocks=self.num_residual_blocks, 156 | norm=self.norm, 157 | use_upsample_conv=self.use_upsample_conv) 158 | 159 | def forward(self, input_dict): 160 | """ 161 | :param event_tensor: N x num_bins x H x W 162 | :param prev_states: previous ConvLSTM state for each encoder module 163 | :return: reconstructed image, taking values in [0,1]. 
164 | """ 165 | semseg_pred = self.unetrecurrent.forward(input_dict) 166 | return semseg_pred 167 | -------------------------------------------------------------------------------- /e2vid/model/submodules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as f 4 | from torch.nn import init 5 | 6 | 7 | class ConvLayer(nn.Module): 8 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, activation='relu', norm=None): 9 | super(ConvLayer, self).__init__() 10 | 11 | bias = False if norm == 'BN' else True 12 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) 13 | if activation is not None: 14 | self.activation = getattr(torch, activation, 'relu') 15 | else: 16 | self.activation = None 17 | 18 | self.norm = norm 19 | if norm == 'BN': 20 | self.norm_layer = nn.BatchNorm2d(out_channels) 21 | elif norm == 'IN': 22 | self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True) 23 | 24 | def forward(self, x): 25 | out = self.conv2d(x) 26 | if self.norm in ['BN', 'IN']: 27 | out = self.norm_layer(out) 28 | 29 | if self.activation is not None: 30 | out = self.activation(out) 31 | return out 32 | 33 | 34 | class TransposedConvLayer(nn.Module): 35 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, activation='relu', norm=None): 36 | super(TransposedConvLayer, self).__init__() 37 | 38 | bias = False if norm == 'BN' else True 39 | self.transposed_conv2d = nn.ConvTranspose2d( 40 | in_channels, out_channels, kernel_size, stride=2, padding=padding, output_padding=1, bias=bias) 41 | 42 | if activation is not None: 43 | self.activation = getattr(torch, activation, 'relu') 44 | else: 45 | self.activation = None 46 | 47 | self.norm = norm 48 | if norm == 'BN': 49 | self.norm_layer = nn.BatchNorm2d(out_channels) 50 | elif norm == 'IN': 51 | self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True) 52 | 53 | def forward(self, x): 54 | out = self.transposed_conv2d(x) 55 | 56 | if self.norm in ['BN', 'IN']: 57 | out = self.norm_layer(out) 58 | 59 | if self.activation is not None: 60 | out = self.activation(out) 61 | 62 | return out 63 | 64 | 65 | class UpsampleConvLayer(nn.Module): 66 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, activation='relu', norm=None): 67 | super(UpsampleConvLayer, self).__init__() 68 | 69 | bias = False if norm == 'BN' else True 70 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) 71 | 72 | if activation is not None: 73 | self.activation = getattr(torch, activation, 'relu') 74 | else: 75 | self.activation = None 76 | 77 | self.norm = norm 78 | if norm == 'BN': 79 | self.norm_layer = nn.BatchNorm2d(out_channels) 80 | elif norm == 'IN': 81 | self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True) 82 | 83 | def forward(self, x): 84 | x_upsampled = f.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) 85 | out = self.conv2d(x_upsampled) 86 | 87 | if self.norm in ['BN', 'IN']: 88 | out = self.norm_layer(out) 89 | 90 | if self.activation is not None: 91 | out = self.activation(out) 92 | 93 | return out 94 | 95 | 96 | class RecurrentConvLayer(nn.Module): 97 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, 98 | recurrent_block_type='convlstm', activation='relu', norm=None): 99 | super(RecurrentConvLayer, 
self).__init__() 100 | 101 | assert(recurrent_block_type in ['convlstm', 'convgru']) 102 | self.recurrent_block_type = recurrent_block_type 103 | if self.recurrent_block_type == 'convlstm': 104 | RecurrentBlock = ConvLSTM 105 | else: 106 | RecurrentBlock = ConvGRU 107 | self.conv = ConvLayer(in_channels, out_channels, kernel_size, stride, padding, activation, norm) 108 | self.recurrent_block = RecurrentBlock(input_size=out_channels, hidden_size=out_channels, kernel_size=3) 109 | 110 | def forward(self, x, prev_state): 111 | x = self.conv(x) 112 | # print('x', x) 113 | state = self.recurrent_block(x, prev_state) 114 | x = state[0] if self.recurrent_block_type == 'convlstm' else state 115 | return x, state 116 | 117 | 118 | class DownsampleRecurrentConvLayer(nn.Module): 119 | def __init__(self, in_channels, out_channels, kernel_size=3, recurrent_block_type='convlstm', padding=0, activation='relu'): 120 | super(DownsampleRecurrentConvLayer, self).__init__() 121 | 122 | self.activation = getattr(torch, activation, 'relu') 123 | 124 | assert(recurrent_block_type in ['convlstm', 'convgru']) 125 | self.recurrent_block_type = recurrent_block_type 126 | if self.recurrent_block_type == 'convlstm': 127 | RecurrentBlock = ConvLSTM 128 | else: 129 | RecurrentBlock = ConvGRU 130 | self.recurrent_block = RecurrentBlock(input_size=in_channels, hidden_size=out_channels, kernel_size=kernel_size) 131 | 132 | def forward(self, x, prev_state): 133 | state = self.recurrent_block(x, prev_state) 134 | x = state[0] if self.recurrent_block_type == 'convlstm' else state 135 | x = f.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False) 136 | return self.activation(x), state 137 | 138 | 139 | # Residual block 140 | class ResidualBlock(nn.Module): 141 | def __init__(self, in_channels, out_channels, stride=1, downsample=None, norm=None): 142 | super(ResidualBlock, self).__init__() 143 | bias = False if norm == 'BN' else True 144 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=bias) 145 | self.norm = norm 146 | if norm == 'BN': 147 | self.bn1 = nn.BatchNorm2d(out_channels) 148 | self.bn2 = nn.BatchNorm2d(out_channels) 149 | elif norm == 'IN': 150 | self.bn1 = nn.InstanceNorm2d(out_channels) 151 | self.bn2 = nn.InstanceNorm2d(out_channels) 152 | 153 | self.relu = nn.ReLU(inplace=True) 154 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=bias) 155 | self.downsample = downsample 156 | 157 | def forward(self, x): 158 | residual = x 159 | out = self.conv1(x) 160 | if self.norm in ['BN', 'IN']: 161 | out = self.bn1(out) 162 | out = self.relu(out) 163 | out = self.conv2(out) 164 | if self.norm in ['BN', 'IN']: 165 | out = self.bn2(out) 166 | 167 | if self.downsample: 168 | residual = self.downsample(x) 169 | 170 | out += residual 171 | out = self.relu(out) 172 | return out 173 | 174 | 175 | class ConvLSTM(nn.Module): 176 | """Adapted from: https://github.com/Atcold/pytorch-CortexNet/blob/master/model/ConvLSTMCell.py """ 177 | 178 | def __init__(self, input_size, hidden_size, kernel_size): 179 | super(ConvLSTM, self).__init__() 180 | 181 | self.input_size = input_size 182 | self.hidden_size = hidden_size 183 | pad = kernel_size // 2 184 | 185 | # cache a tensor filled with zeros to avoid reallocating memory at each inference step if --no-recurrent is enabled 186 | self.zero_tensors = {} 187 | 188 | self.Gates = nn.Conv2d(input_size + hidden_size, 4 * hidden_size, kernel_size, padding=pad) 189 | 190 | def forward(self, 
input_, prev_state=None): 191 | # get batch and spatial sizes 192 | batch_size = input_.data.size()[0] 193 | spatial_size = input_.data.size()[2:] 194 | 195 | # generate empty prev_state, if None is provided 196 | if prev_state is None: 197 | 198 | # create the zero tensor if it has not been created already 199 | state_size = tuple([batch_size, self.hidden_size] + list(spatial_size)) 200 | if state_size not in self.zero_tensors: 201 | # allocate a tensor with size `spatial_size`, filled with zero (if it has not been allocated already) 202 | self.zero_tensors[state_size] = ( 203 | torch.zeros(state_size).to(input_.device), 204 | torch.zeros(state_size).to(input_.device) 205 | ) 206 | 207 | prev_state = self.zero_tensors[tuple(state_size)] 208 | 209 | prev_hidden, prev_cell = prev_state 210 | 211 | # data size is [batch, channel, height, width] 212 | stacked_inputs = torch.cat((input_, prev_hidden), 1) 213 | gates = self.Gates(stacked_inputs) 214 | 215 | # chunk across channel dimension 216 | in_gate, remember_gate, out_gate, cell_gate = gates.chunk(4, 1) 217 | 218 | # apply sigmoid non linearity 219 | in_gate = torch.sigmoid(in_gate) 220 | remember_gate = torch.sigmoid(remember_gate) 221 | out_gate = torch.sigmoid(out_gate) 222 | 223 | # apply tanh non linearity 224 | cell_gate = torch.tanh(cell_gate) 225 | 226 | # compute current cell and hidden state 227 | cell = (remember_gate * prev_cell) + (in_gate * cell_gate) 228 | hidden = out_gate * torch.tanh(cell) 229 | 230 | return hidden, cell 231 | 232 | 233 | class ConvGRU(nn.Module): 234 | """ 235 | Generate a convolutional GRU cell 236 | Adapted from: https://github.com/jacobkimmel/pytorch_convgru/blob/master/convgru.py 237 | """ 238 | 239 | def __init__(self, input_size, hidden_size, kernel_size): 240 | super().__init__() 241 | padding = kernel_size // 2 242 | self.input_size = input_size 243 | self.hidden_size = hidden_size 244 | self.reset_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding) 245 | self.update_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding) 246 | self.out_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding) 247 | 248 | init.orthogonal_(self.reset_gate.weight) 249 | init.orthogonal_(self.update_gate.weight) 250 | init.orthogonal_(self.out_gate.weight) 251 | init.constant_(self.reset_gate.bias, 0.) 252 | init.constant_(self.update_gate.bias, 0.) 253 | init.constant_(self.out_gate.bias, 0.) 
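# The three convolutions above implement the standard GRU gating on feature maps (see forward() below):
#   z = sigmoid(update_gate([x, h]))      # update gate
#   r = sigmoid(reset_gate([x, h]))       # reset gate
#   h_tilde = tanh(out_gate([x, r * h]))  # candidate state
#   h_new = (1 - z) * h + z * h_tilde
# Illustrative shapes, assuming a 32-channel cell: ConvGRU(32, 32, kernel_size=3) applied to a
# 1 x 32 x 64 x 64 input with prev_state=None returns a new hidden state of the same shape.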
254 | 255 | def forward(self, input_, prev_state): 256 | 257 | # get batch and spatial sizes 258 | batch_size = input_.data.size()[0] 259 | spatial_size = input_.data.size()[2:] 260 | 261 | # generate empty prev_state, if None is provided 262 | if prev_state is None: 263 | state_size = [batch_size, self.hidden_size] + list(spatial_size) 264 | prev_state = torch.zeros(state_size).to(input_.device) 265 | 266 | # data size is [batch, channel, height, width] 267 | stacked_inputs = torch.cat([input_, prev_state], dim=1) 268 | update = torch.sigmoid(self.update_gate(stacked_inputs)) 269 | reset = torch.sigmoid(self.reset_gate(stacked_inputs)) 270 | out_inputs = torch.tanh(self.out_gate(torch.cat([input_, prev_state * reset], dim=1))) 271 | new_state = prev_state * (1 - update) + out_inputs * update 272 | 273 | return new_state 274 | -------------------------------------------------------------------------------- /e2vid/model/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as f 4 | from torch.nn import init 5 | from .submodules import ConvLayer, UpsampleConvLayer, TransposedConvLayer, RecurrentConvLayer, ResidualBlock, ConvLSTM, ConvGRU 6 | 7 | 8 | def skip_concat(x1, x2): 9 | return torch.cat([x1, x2], dim=1) 10 | 11 | 12 | def skip_sum(x1, x2): 13 | return x1 + x2 14 | 15 | 16 | class BaseUNet(nn.Module): 17 | def __init__(self, num_input_channels, num_output_channels=1, skip_type='sum', activation='sigmoid', 18 | num_encoders=4, base_num_channels=32, num_residual_blocks=2, norm=None, use_upsample_conv=True): 19 | super(BaseUNet, self).__init__() 20 | 21 | self.num_input_channels = num_input_channels 22 | self.num_output_channels = num_output_channels 23 | self.skip_type = skip_type 24 | self.apply_skip_connection = skip_sum if self.skip_type == 'sum' else skip_concat 25 | self.activation = activation 26 | self.norm = norm 27 | 28 | if use_upsample_conv: 29 | print('Using UpsampleConvLayer (slow, but no checkerboard artefacts)') 30 | self.UpsampleLayer = UpsampleConvLayer 31 | else: 32 | print('Using TransposedConvLayer (fast, with checkerboard artefacts)') 33 | self.UpsampleLayer = TransposedConvLayer 34 | 35 | self.num_encoders = num_encoders 36 | self.base_num_channels = base_num_channels 37 | self.num_residual_blocks = num_residual_blocks 38 | self.max_num_channels = self.base_num_channels * pow(2, self.num_encoders) 39 | 40 | assert(self.num_input_channels > 0) 41 | assert(self.num_output_channels > 0) 42 | 43 | self.encoder_input_sizes = [] 44 | for i in range(self.num_encoders): 45 | self.encoder_input_sizes.append(self.base_num_channels * pow(2, i)) 46 | 47 | self.encoder_output_sizes = [self.base_num_channels * pow(2, i + 1) for i in range(self.num_encoders)] 48 | 49 | self.activation = getattr(torch, self.activation, 'sigmoid') 50 | 51 | def build_resblocks(self): 52 | self.resblocks = nn.ModuleList() 53 | for i in range(self.num_residual_blocks): 54 | self.resblocks.append(ResidualBlock(self.max_num_channels, self.max_num_channels, norm=self.norm)) 55 | 56 | def build_decoders(self): 57 | decoder_input_sizes = list(reversed([self.base_num_channels * pow(2, i + 1) for i in range(self.num_encoders)])) 58 | 59 | self.decoders = nn.ModuleList() 60 | for input_size in decoder_input_sizes: 61 | self.decoders.append(self.UpsampleLayer(input_size if self.skip_type == 'sum' else 2 * input_size, 62 | input_size // 2, 63 | kernel_size=5, padding=2, norm=self.norm)) 64 | 65 | def 
build_prediction_layer(self): 66 | self.pred = ConvLayer(self.base_num_channels if self.skip_type == 'sum' else 2 * self.base_num_channels, 67 | self.num_output_channels, 1, activation=None, norm=self.norm) 68 | 69 | 70 | class UNet(BaseUNet): 71 | def __init__(self, num_input_channels, num_output_channels=1, skip_type='sum', activation='sigmoid', 72 | num_encoders=4, base_num_channels=32, num_residual_blocks=2, norm=None, use_upsample_conv=True): 73 | super(UNet, self).__init__(num_input_channels, num_output_channels, skip_type, activation, 74 | num_encoders, base_num_channels, num_residual_blocks, norm, use_upsample_conv) 75 | 76 | self.head = ConvLayer(self.num_input_channels, self.base_num_channels, 77 | kernel_size=5, stride=1, padding=2) # N x C x H x W -> N x 32 x H x W 78 | 79 | self.encoders = nn.ModuleList() 80 | for input_size, output_size in zip(self.encoder_input_sizes, self.encoder_output_sizes): 81 | self.encoders.append(ConvLayer(input_size, output_size, kernel_size=5, 82 | stride=2, padding=2, norm=self.norm)) 83 | 84 | self.build_resblocks() 85 | self.build_decoders() 86 | self.build_prediction_layer() 87 | 88 | def forward(self, x): 89 | """ 90 | :param x: N x num_input_channels x H x W 91 | :return: N x num_output_channels x H x W 92 | """ 93 | 94 | # head 95 | x = self.head(x) 96 | head = x 97 | 98 | # encoder 99 | blocks = [] 100 | for i, encoder in enumerate(self.encoders): 101 | x = encoder(x) 102 | blocks.append(x) 103 | 104 | # residual blocks 105 | for resblock in self.resblocks: 106 | x = resblock(x) 107 | 108 | # decoder 109 | for i, decoder in enumerate(self.decoders): 110 | x = decoder(self.apply_skip_connection(x, blocks[self.num_encoders - i - 1])) 111 | 112 | img = self.activation(self.pred(self.apply_skip_connection(x, head))) 113 | 114 | return img 115 | 116 | 117 | class UNetRecurrent(BaseUNet): 118 | """ 119 | Recurrent UNet architecture where every encoder is followed by a recurrent convolutional block, 120 | such as a ConvLSTM or a ConvGRU. 121 | Symmetric, skip connections on every encoding layer. 
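In addition to the reconstructed image, forward() returns the recurrent state of every encoder and a dict of intermediate activations keyed by downsampling factor (1 for the head features, 2, 4 and 8 for the encoder outputs), which the task decoders (e.g. UNetTask, SemSegE2VID) consume.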
122 | """ 123 | 124 | def __init__(self, num_input_channels, num_output_channels=1, skip_type='sum', 125 | recurrent_block_type='convlstm', activation='sigmoid', num_encoders=4, base_num_channels=32, 126 | num_residual_blocks=2, norm=None, use_upsample_conv=True): 127 | super(UNetRecurrent, self).__init__(num_input_channels, num_output_channels, skip_type, activation, 128 | num_encoders, base_num_channels, num_residual_blocks, norm, 129 | use_upsample_conv) 130 | 131 | self.head = ConvLayer(self.num_input_channels, self.base_num_channels, 132 | kernel_size=5, stride=1, padding=2) # N x C x H x W -> N x 32 x H x W 133 | 134 | self.encoders = nn.ModuleList() 135 | for input_size, output_size in zip(self.encoder_input_sizes, self.encoder_output_sizes): 136 | self.encoders.append(RecurrentConvLayer(input_size, output_size, 137 | kernel_size=5, stride=2, padding=2, 138 | recurrent_block_type=recurrent_block_type, 139 | norm=self.norm)) 140 | 141 | self.build_resblocks() 142 | self.build_decoders() 143 | self.build_prediction_layer() 144 | 145 | def forward(self, x, prev_states): 146 | """ 147 | :param x: N x num_input_channels x H x W 148 | :param prev_states: previous LSTM states for every encoder layer 149 | :return: N x num_output_channels x H x W 150 | """ 151 | 152 | # head 153 | x = self.head(x) 154 | head = x 155 | 156 | if prev_states is None: 157 | prev_states = [None] * self.num_encoders 158 | 159 | # encoder 160 | blocks = [] 161 | states = [] 162 | for i, encoder in enumerate(self.encoders): 163 | x, state = encoder(x, prev_states[i]) 164 | 165 | blocks.append(x) 166 | states.append(state) 167 | 168 | # residual blocks 169 | for resblock in self.resblocks: 170 | x = resblock(x) 171 | 172 | latent = {1: head, 2: blocks[0], 4: blocks[1], 8: blocks[2]} 173 | 174 | # decoder 175 | for i, decoder in enumerate(self.decoders): 176 | x = decoder(self.apply_skip_connection(x, blocks[self.num_encoders - i - 1])) 177 | 178 | # tail 179 | img = self.activation(self.pred(self.apply_skip_connection(x, head))) 180 | 181 | return img, states, latent 182 | 183 | class UNetDecoder(BaseUNet): 184 | """ 185 | Recurrent UNet architecture where every encoder is followed by a recurrent convolutional block, 186 | such as a ConvLSTM or a ConvGRU. 187 | Symmetric, skip connections on every encoding layer. 
188 | """ 189 | 190 | def __init__(self, num_input_channels, num_output_channels=1, skip_type='sum', 191 | recurrent_block_type='convlstm', activation='sigmoid', num_encoders=4, base_num_channels=32, 192 | num_residual_blocks=2, norm=None, use_upsample_conv=True): 193 | super(UNetDecoder, self).__init__(num_input_channels, num_output_channels, skip_type, activation, 194 | num_encoders, base_num_channels, num_residual_blocks, norm, 195 | use_upsample_conv) 196 | 197 | self.build_resblocks() 198 | self.build_decoders() 199 | self.build_prediction_layer() 200 | 201 | def forward(self, x, blocks, head): 202 | """ 203 | :param x: N x num_input_channels x H x W 204 | :param prev_states: previous LSTM states for every encoder layer 205 | :return: N x num_output_channels x H x W 206 | """ 207 | 208 | # residual blocks 209 | for resblock in self.resblocks: 210 | x = resblock(x) 211 | 212 | # decoder 213 | for i, decoder in enumerate(self.decoders): 214 | x = decoder(self.apply_skip_connection(x, blocks[self.num_encoders - i - 1])) 215 | 216 | # tail 217 | img = self.activation(self.pred(self.apply_skip_connection(x, head))) 218 | 219 | return img 220 | 221 | 222 | class UNetTask(BaseUNet): 223 | """ 224 | Recurrent UNet architecture where every encoder is followed by a recurrent convolutional block, 225 | such as a ConvLSTM or a ConvGRU. 226 | Symmetric, skip connections on every encoding layer. 227 | """ 228 | 229 | def __init__(self, num_input_channels, num_output_channels=1, skip_type='sum', 230 | recurrent_block_type='convlstm', activation='sigmoid', num_encoders=4, base_num_channels=32, 231 | num_residual_blocks=2, norm=None, use_upsample_conv=True): 232 | super(UNetTask, self).__init__(num_input_channels, num_output_channels, skip_type, activation, 233 | num_encoders, base_num_channels, num_residual_blocks, norm, 234 | use_upsample_conv) 235 | self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') 236 | self.build_resblocks() 237 | self.build_decoders() 238 | self.build_prediction_layer_semseg() 239 | 240 | def build_prediction_layer_semseg(self): 241 | self.pred_semseg = torch.nn.Sequential(ConvLayer(self.base_num_channels if self.skip_type == 'sum' else 2 * self.base_num_channels, 242 | self.base_num_channels if self.skip_type == 'sum' else 2 * self.base_num_channels, 1, activation='relu', norm=self.norm), 243 | ConvLayer( 244 | self.base_num_channels if self.skip_type == 'sum' else 2 * self.base_num_channels, 245 | self.num_output_channels, 1, activation=None, norm=None) 246 | ) 247 | 248 | def update_skip_dict(self, skips, x, sz_in): 249 | rem, scale = sz_in % x.shape[3], sz_in // x.shape[3] 250 | assert rem == 0 251 | skips[scale] = x 252 | 253 | def forward(self, input_dict): 254 | """ 255 | :param x: N x num_input_channels x H x W 256 | :param prev_states: previous LSTM states for every encoder layer 257 | :return: N x num_output_channels x H x W 258 | """ 259 | sz_in = input_dict[1].shape[3] 260 | 261 | x = input_dict[8] 262 | out = {8: x} 263 | blocks = [input_dict[2], input_dict[4], input_dict[8]] 264 | head = torch.zeros((input_dict[2].shape[0], 32, 256, 512)).to(self.device) 265 | 266 | # residual blocks 267 | for resblock in self.resblocks: 268 | x = resblock(x) 269 | 270 | # decoder 271 | for i, decoder in enumerate(self.decoders): 272 | x = decoder(self.apply_skip_connection(x, blocks[self.num_encoders - i - 1])) 273 | self.update_skip_dict(out, x, sz_in) 274 | 275 | # tail 276 | pred = self.pred_semseg(self.apply_skip_connection(x, head)) 
277 | self.update_skip_dict(out, pred, sz_in) 278 | 279 | return out 280 | -------------------------------------------------------------------------------- /e2vid/options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uzh-rpg/ess/46bf1eed677d869c4733c89d2284e43ea27f97bd/e2vid/options/__init__.py -------------------------------------------------------------------------------- /e2vid/options/inference_options.py: -------------------------------------------------------------------------------- 1 | def set_inference_options(parser): 2 | 3 | parser.add_argument('-o', '--output_folder', default=None, type=str) # if None, will not write the images to disk 4 | parser.add_argument('--dataset_name', default='reconstruction', type=str) 5 | 6 | parser.add_argument('--use_gpu', dest='use_gpu', action='store_true') 7 | parser.set_defaults(use_gpu=True) 8 | 9 | """ Display """ 10 | parser.add_argument('--display', dest='display', action='store_true') 11 | parser.set_defaults(display=False) 12 | 13 | parser.add_argument('--show_events', dest='show_events', action='store_true') 14 | parser.set_defaults(show_events=False) 15 | 16 | parser.add_argument('--event_display_mode', default='red-blue', type=str, 17 | help="Event display mode ('red-blue' or 'grayscale')") 18 | 19 | parser.add_argument('--num_bins_to_show', default=-1, type=int, 20 | help="Number of bins of the voxel grid to show when displaying events (-1 means show all the bins).") 21 | 22 | parser.add_argument('--display_border_crop', default=0, type=int, 23 | help="Remove the outer border of size display_border_crop before displaying image.") 24 | 25 | parser.add_argument('--display_wait_time', default=1, type=int, 26 | help="Time to wait after each call to cv2.imshow, in milliseconds (default: 1)") 27 | 28 | """ Post-processing / filtering """ 29 | 30 | # (optional) path to a text file containing the locations of hot pixels to ignore 31 | parser.add_argument('--hot_pixels_file', default=None, type=str) 32 | 33 | # (optional) unsharp mask 34 | parser.add_argument('--unsharp_mask_amount', default=0.3, type=float) 35 | parser.add_argument('--unsharp_mask_sigma', default=1.0, type=float) 36 | 37 | # (optional) bilateral filter 38 | parser.add_argument('--bilateral_filter_sigma', default=0.0, type=float) 39 | 40 | # (optional) flip the event tensors vertically 41 | parser.add_argument('--flip', dest='flip', action='store_true') 42 | parser.set_defaults(flip=False) 43 | 44 | """ Tone mapping (i.e. rescaling of the image intensities)""" 45 | parser.add_argument('--Imin', default=0.0, type=float, 46 | help="Min intensity for intensity rescaling (linear tone mapping).") 47 | parser.add_argument('--Imax', default=1.0, type=float, 48 | help="Max intensity value for intensity rescaling (linear tone mapping).") 49 | parser.add_argument('--auto_hdr', dest='auto_hdr', action='store_true', 50 | help="If True, will compute Imin and Imax automatically.") 51 | parser.set_defaults(auto_hdr=False) 52 | parser.add_argument('--auto_hdr_median_filter_size', default=10, type=int, 53 | help="Size of the median filter window used to smooth temporally Imin and Imax") 54 | 55 | """ Perform color reconstruction? 
(only use this flag with the DAVIS346color) """ 56 | parser.add_argument('--color', dest='color', action='store_true') 57 | parser.set_defaults(color=False) 58 | 59 | """ Advanced parameters """ 60 | # disable normalization of input event tensors (saves a bit of time, but may produce slightly worse results) 61 | parser.add_argument('--no-normalize', dest='no_normalize', action='store_true') 62 | parser.set_defaults(no_normalize=False) 63 | 64 | # disable recurrent connection (will severely degrade the results; for testing purposes only) 65 | parser.add_argument('--no-recurrent', dest='no_recurrent', action='store_true') 66 | parser.set_defaults(no_recurrent=False) 67 | -------------------------------------------------------------------------------- /e2vid/pretrained/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uzh-rpg/ess/46bf1eed677d869c4733c89d2284e43ea27f97bd/e2vid/pretrained/.gitignore -------------------------------------------------------------------------------- /e2vid/run_reconstruction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from e2vid.utils.loading_utils import load_model, get_device 3 | import numpy as np 4 | import argparse 5 | import pandas as pd 6 | from e2vid.utils.event_readers import FixedSizeEventReader, FixedDurationEventReader 7 | from e2vid.utils.inference_utils import events_to_voxel_grid, events_to_voxel_grid_pytorch 8 | from e2vid.utils.timers import Timer 9 | import time 10 | from e2vid.image_reconstructor import ImageReconstructor 11 | from e2vid.options.inference_options import set_inference_options 12 | 13 | 14 | if __name__ == "__main__": 15 | 16 | parser = argparse.ArgumentParser( 17 | description='Evaluating a trained network') 18 | parser.add_argument('-c', '--path_to_model', required=True, type=str, 19 | help='path to model weights') 20 | parser.add_argument('-i', '--input_file', required=True, type=str) 21 | parser.add_argument('--fixed_duration', dest='fixed_duration', action='store_true') 22 | parser.set_defaults(fixed_duration=False) 23 | parser.add_argument('-N', '--window_size', default=None, type=int, 24 | help="Size of each event window, in number of events. Ignored if --fixed_duration=True") 25 | parser.add_argument('-T', '--window_duration', default=33.33, type=float, 26 | help="Duration of each event window, in milliseconds. 
Ignored if --fixed_duration=False") 27 | parser.add_argument('--num_events_per_pixel', default=0.35, type=float, 28 | help='in case N (window size) is not specified, it will be \ 29 | automatically computed as N = width * height * num_events_per_pixel') 30 | parser.add_argument('--skipevents', default=0, type=int) 31 | parser.add_argument('--suboffset', default=0, type=int) 32 | parser.add_argument('--compute_voxel_grid_on_cpu', dest='compute_voxel_grid_on_cpu', action='store_true') 33 | parser.set_defaults(compute_voxel_grid_on_cpu=False) 34 | 35 | set_inference_options(parser) 36 | 37 | args = parser.parse_args() 38 | 39 | # Read sensor size from the first first line of the event file 40 | path_to_events = args.input_file 41 | 42 | header = pd.read_csv(path_to_events, delim_whitespace=True, header=None, names=['width', 'height'], 43 | dtype={'width': np.int, 'height': np.int}, 44 | nrows=1) 45 | width, height = header.values[0] 46 | print('Sensor size: {} x {}'.format(width, height)) 47 | 48 | # Load model 49 | model = load_model(args.path_to_model) 50 | device = get_device(args.use_gpu) 51 | 52 | model = model.to(device) 53 | model.eval() 54 | 55 | reconstructor = ImageReconstructor(model, height, width, model.num_bins, args) 56 | 57 | """ Read chunks of events using Pandas """ 58 | 59 | # Loop through the events and reconstruct images 60 | N = args.window_size 61 | if not args.fixed_duration: 62 | if N is None: 63 | N = int(width * height * args.num_events_per_pixel) 64 | print('Will use {} events per tensor (automatically estimated with num_events_per_pixel={:0.2f}).'.format( 65 | N, args.num_events_per_pixel)) 66 | else: 67 | print('Will use {} events per tensor (user-specified)'.format(N)) 68 | mean_num_events_per_pixel = float(N) / float(width * height) 69 | if mean_num_events_per_pixel < 0.1: 70 | print('!!Warning!! the number of events used ({}) seems to be low compared to the sensor size. \ 71 | The reconstruction results might be suboptimal.'.format(N)) 72 | elif mean_num_events_per_pixel > 1.5: 73 | print('!!Warning!! the number of events used ({}) seems to be high compared to the sensor size. 
\ 74 | The reconstruction results might be suboptimal.'.format(N)) 75 | 76 | initial_offset = args.skipevents 77 | sub_offset = args.suboffset 78 | start_index = initial_offset + sub_offset 79 | 80 | if args.compute_voxel_grid_on_cpu: 81 | print('Will compute voxel grid on CPU.') 82 | 83 | if args.fixed_duration: 84 | event_window_iterator = FixedDurationEventReader(path_to_events, 85 | duration_ms=args.window_duration, 86 | start_index=start_index) 87 | else: 88 | event_window_iterator = FixedSizeEventReader(path_to_events, num_events=N, start_index=start_index) 89 | 90 | with Timer('Processing entire dataset'): 91 | for event_window in event_window_iterator: 92 | print(event_window.shape) 93 | last_timestamp = event_window[-1, 0] 94 | 95 | with Timer('Building event tensor'): 96 | if args.compute_voxel_grid_on_cpu: 97 | event_tensor = events_to_voxel_grid(event_window, 98 | num_bins=model.num_bins, 99 | width=width, 100 | height=height) 101 | event_tensor = torch.from_numpy(event_tensor) 102 | else: 103 | event_tensor = events_to_voxel_grid_pytorch(event_window, 104 | num_bins=model.num_bins, 105 | width=width, 106 | height=height, 107 | device=device) 108 | 109 | num_events_in_window = event_window.shape[0] 110 | reconstructor.update_reconstruction(event_tensor, start_index + num_events_in_window, last_timestamp) 111 | 112 | start_index += num_events_in_window 113 | -------------------------------------------------------------------------------- /e2vid/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uzh-rpg/ess/46bf1eed677d869c4733c89d2284e43ea27f97bd/e2vid/utils/__init__.py -------------------------------------------------------------------------------- /e2vid/utils/event_readers.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import zipfile 3 | from os.path import splitext 4 | import numpy as np 5 | from .timers import Timer 6 | 7 | 8 | class FixedSizeEventReader: 9 | """ 10 | Reads events from a '.txt' or '.zip' file, and packages the events into 11 | non-overlapping event windows, each containing a fixed number of events. 12 | """ 13 | 14 | def __init__(self, path_to_event_file, num_events=10000, start_index=0): 15 | print('Will use fixed size event windows with {} events'.format(num_events)) 16 | print('Output frame rate: variable') 17 | self.iterator = pd.read_csv(path_to_event_file, delim_whitespace=True, header=None, 18 | names=['t', 'x', 'y', 'pol'], 19 | dtype={'t': np.float64, 'x': np.int16, 'y': np.int16, 'pol': np.int16}, 20 | engine='c', 21 | skiprows=start_index + 1, chunksize=num_events, nrows=None, memory_map=True) 22 | 23 | def __iter__(self): 24 | return self 25 | 26 | def __next__(self): 27 | with Timer('Reading event window from file'): 28 | event_window = self.iterator.__next__().values 29 | return event_window 30 | 31 | 32 | class FixedDurationEventReader: 33 | """ 34 | Reads events from a '.txt' or '.zip' file, and packages the events into 35 | non-overlapping event windows, each of a fixed duration. 36 | 37 | **Note**: This reader is much slower than the FixedSizeEventReader. 38 | The reason is that the latter can use Pandas' very efficient cunk-based reading scheme implemented in C. 
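This reader instead parses the event file line by line in Python inside __next__(), accumulating events until the requested window duration has elapsed.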
39 | """ 40 | 41 | def __init__(self, path_to_event_file, duration_ms=50.0, start_index=0): 42 | print('Will use fixed duration event windows of size {:.2f} ms'.format(duration_ms)) 43 | print('Output frame rate: {:.1f} Hz'.format(1000.0 / duration_ms)) 44 | file_extension = splitext(path_to_event_file)[1] 45 | assert(file_extension in ['.txt', '.zip']) 46 | self.is_zip_file = (file_extension == '.zip') 47 | 48 | if self.is_zip_file: # '.zip' 49 | self.zip_file = zipfile.ZipFile(path_to_event_file) 50 | files_in_archive = self.zip_file.namelist() 51 | assert(len(files_in_archive) == 1) # make sure there is only one text file in the archive 52 | self.event_file = self.zip_file.open(files_in_archive[0], 'r') 53 | else: 54 | self.event_file = open(path_to_event_file, 'r') 55 | 56 | # ignore header + the first start_index lines 57 | for i in range(1 + start_index): 58 | self.event_file.readline() 59 | 60 | self.last_stamp = None 61 | self.duration_s = duration_ms / 1000.0 62 | 63 | def __iter__(self): 64 | return self 65 | 66 | def __del__(self): 67 | if self.is_zip_file: 68 | self.zip_file.close() 69 | 70 | self.event_file.close() 71 | 72 | def __next__(self): 73 | with Timer('Reading event window from file'): 74 | event_list = [] 75 | for line in self.event_file: 76 | if self.is_zip_file: 77 | line = line.decode("utf-8") 78 | t, x, y, pol = line.split(' ') 79 | t, x, y, pol = float(t), int(x), int(y), int(pol) 80 | event_list.append([t, x, y, pol]) 81 | if self.last_stamp is None: 82 | self.last_stamp = t 83 | if t > self.last_stamp + self.duration_s: 84 | self.last_stamp = t 85 | event_window = np.array(event_list) 86 | return event_window 87 | 88 | raise StopIteration 89 | -------------------------------------------------------------------------------- /e2vid/utils/loading_utils.py: -------------------------------------------------------------------------------- 1 | from e2vid.model.model import * 2 | from collections import OrderedDict 3 | 4 | 5 | def load_model(path_to_model, return_task=False): 6 | print('Loading model {}...'.format(path_to_model)) 7 | raw_model = torch.load(path_to_model) 8 | arch = raw_model['arch'] 9 | 10 | try: 11 | model_type = raw_model['model'] 12 | except KeyError: 13 | model_type = raw_model['config']['model'] 14 | 15 | # instantiate model 16 | model = eval(arch)(model_type) 17 | 18 | # load model weights 19 | model.load_state_dict(raw_model['state_dict']) 20 | 21 | # E2VID Decoder 22 | decoder = E2VIDDecoder(model_type) 23 | decoder.load_state_dict(raw_model['state_dict'], strict=False) 24 | 25 | if return_task: 26 | # E2VID Task 27 | task = E2VIDTask(model_type) 28 | new_dict = copyStateDict(raw_model['state_dict']) 29 | keys = [] 30 | for k, v in new_dict.items(): 31 | if k.startswith('unetrecurrent.pred'): 32 | continue 33 | keys.append(k) 34 | 35 | new_dict = {k: new_dict[k] for k in keys} 36 | task.load_state_dict(new_dict, strict=False) 37 | return model, decoder, task 38 | return model, decoder 39 | 40 | 41 | def get_device(use_gpu): 42 | if use_gpu and torch.cuda.is_available(): 43 | device = torch.device('cuda:0') 44 | else: 45 | device = torch.device('cpu') 46 | print('Device:', device) 47 | 48 | return device 49 | 50 | def copyStateDict(state_dict): 51 | if list(state_dict.keys())[0].startswith('module'): 52 | start_idx = 1 53 | else: 54 | start_idx = 0 55 | new_state_dict = OrderedDict() 56 | for k,v in state_dict.items(): 57 | name = '.'.join(k.split('.')[start_idx:]) 58 | 59 | new_state_dict[name] = v 60 | return new_state_dict 61 | 62 | 
-------------------------------------------------------------------------------- /e2vid/utils/path_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def ensure_dir(path): 5 | if not os.path.exists(path): 6 | os.makedirs(path) 7 | -------------------------------------------------------------------------------- /e2vid/utils/timers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import numpy as np 4 | import atexit 5 | 6 | cuda_timers = {} 7 | timers = {} 8 | 9 | 10 | class CudaTimer: 11 | def __init__(self, timer_name=''): 12 | self.timer_name = timer_name 13 | if self.timer_name not in cuda_timers: 14 | cuda_timers[self.timer_name] = [] 15 | 16 | self.start = torch.cuda.Event(enable_timing=True) 17 | self.end = torch.cuda.Event(enable_timing=True) 18 | 19 | def __enter__(self): 20 | self.start.record() 21 | return self 22 | 23 | def __exit__(self, *args): 24 | self.end.record() 25 | torch.cuda.synchronize() 26 | cuda_timers[self.timer_name].append(self.start.elapsed_time(self.end)) 27 | 28 | 29 | class Timer: 30 | def __init__(self, timer_name=''): 31 | self.timer_name = timer_name 32 | if self.timer_name not in timers: 33 | timers[self.timer_name] = [] 34 | 35 | def __enter__(self): 36 | self.start = time.time() 37 | return self 38 | 39 | def __exit__(self, *args): 40 | self.end = time.time() 41 | self.interval = self.end - self.start # measured in seconds 42 | self.interval *= 1000.0 # convert to milliseconds 43 | timers[self.timer_name].append(self.interval) 44 | 45 | 46 | def print_timing_info(): 47 | print('== Timing statistics ==') 48 | for timer_name, timing_values in [*cuda_timers.items(), *timers.items()]: 49 | timing_value = np.mean(np.array(timing_values)) 50 | if timing_value < 1000.0: 51 | print('{}: {:.2f} ms'.format(timer_name, timing_value)) 52 | else: 53 | print('{}: {:.2f} s'.format(timer_name, timing_value / 1000.0)) 54 | 55 | 56 | # this will print all the timer values upon termination of any program that imported this file 57 | atexit.register(print_timing_info) 58 | -------------------------------------------------------------------------------- /e2vid/utils/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from math import fabs 3 | 4 | 5 | def robust_min(img, p=5): 6 | return np.percentile(img.ravel(), p) 7 | 8 | 9 | def robust_max(img, p=95): 10 | return np.percentile(img.ravel(), p) 11 | 12 | 13 | def normalize(img, m=10, M=90): 14 | return np.clip((img - robust_min(img, m)) / (robust_max(img, M) - robust_min(img, m)), 0.0, 1.0) 15 | 16 | 17 | def first_element_greater_than(values, req_value): 18 | """Returns the pair (i, values[i]) such that i is the minimum value that satisfies values[i] >= req_value. 19 | Returns (-1, None) if there is no such i. 20 | Note: this function assumes that values is a sorted array!""" 21 | i = np.searchsorted(values, req_value) 22 | val = values[i] if i < len(values) else None 23 | return (i, val) 24 | 25 | 26 | def last_element_less_than(values, req_value): 27 | """Returns the pair (i, values[i]) such that i is the maximum value that satisfies values[i] <= req_value. 28 | Returns (-1, None) if there is no such i. 
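Example: last_element_less_than([1.0, 3.0, 5.0], 4.0) returns (1, 3.0).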
29 | Note: this function assumes that values is a sorted array!""" 30 | i = np.searchsorted(values, req_value, side='right') - 1 31 | val = values[i] if i >= 0 else None 32 | return (i, val) 33 | 34 | 35 | def closest_element_to(values, req_value): 36 | """Returns the tuple (i, values[i], diff) such that i is the closest value to req_value, 37 | and diff = |values(i) - req_value| 38 | Note: this function assumes that values is a sorted array!""" 39 | assert(len(values) > 0) 40 | 41 | i = np.searchsorted(values, req_value, side='left') 42 | if i > 0 and (i == len(values) or fabs(req_value - values[i - 1]) < fabs(req_value - values[i])): 43 | idx = i - 1 44 | val = values[i - 1] 45 | else: 46 | idx = i 47 | val = values[i] 48 | 49 | diff = fabs(val - req_value) 50 | return (idx, val, diff) 51 | -------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uzh-rpg/ess/46bf1eed677d869c4733c89d2284e43ea27f97bd/evaluation/__init__.py -------------------------------------------------------------------------------- /evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def semseg_compute_confusion(y_hat_lbl, y_lbl, num_classes, ignore_label): 5 | assert torch.is_tensor(y_hat_lbl) and torch.is_tensor(y_lbl), 'Inputs must be torch tensors' 6 | assert y_lbl.device == y_hat_lbl.device, 'Input tensors have different device placement' 7 | 8 | assert y_hat_lbl.dim() == 3 or y_hat_lbl.dim() == 4 and y_hat_lbl.shape[1] == 1 9 | assert y_lbl.dim() == 3 or y_lbl.dim() == 4 and y_lbl.shape[1] == 1 10 | if y_hat_lbl.dim() == 4: 11 | y_hat_lbl = y_hat_lbl.squeeze(1) 12 | if y_lbl.dim() == 4: 13 | y_lbl = y_lbl.squeeze(1) 14 | 15 | mask = y_lbl != ignore_label 16 | y_hat_lbl = y_hat_lbl[mask] 17 | y_lbl = y_lbl[mask] 18 | 19 | # hack for bincounting 2 arrays together 20 | x = y_hat_lbl + num_classes * y_lbl 21 | bincount_2d = torch.bincount(x.long(), minlength=num_classes ** 2) 22 | assert bincount_2d.numel() == num_classes ** 2, 'Internal error' 23 | conf = bincount_2d.view((num_classes, num_classes)).long() 24 | return conf 25 | 26 | 27 | def semseg_accum_confusion_to_iou(confusion_accum): 28 | conf = confusion_accum.double() 29 | diag = conf.diag() 30 | iou_per_class = 100 * diag / (conf.sum(dim=1) + conf.sum(dim=0) - diag).clamp(min=1e-12) 31 | iou_mean = iou_per_class.mean() 32 | return iou_mean, iou_per_class 33 | 34 | def semseg_accum_confusion_to_acc(confusion_accum): 35 | conf = confusion_accum.double() 36 | diag = conf.diag() 37 | acc = 100 * diag.sum() / (conf.sum(dim=1).sum()).clamp(min=1e-12) 38 | return acc 39 | 40 | class MetricsSemseg: 41 | def __init__(self, num_classes, ignore_label, class_names): 42 | self.num_classes = num_classes 43 | self.ignore_label = ignore_label 44 | self.class_names = class_names 45 | self.metrics_acc = None 46 | 47 | def reset(self): 48 | self.metrics_acc = None 49 | 50 | def update_batch(self, y_hat_lbl, y_lbl): 51 | with torch.no_grad(): 52 | metrics_batch = semseg_compute_confusion(y_hat_lbl, y_lbl, self.num_classes, self.ignore_label).cpu() 53 | if self.metrics_acc is None: 54 | self.metrics_acc = metrics_batch 55 | else: 56 | self.metrics_acc += metrics_batch 57 | 58 | def get_metrics_summary(self): 59 | iou_mean, iou_per_class = semseg_accum_confusion_to_iou(self.metrics_acc) 60 | out = {self.class_names[i]: iou for i, iou in 
enumerate(iou_per_class)} 61 | out['mean_iou'] = iou_mean 62 | acc = semseg_accum_confusion_to_acc((self.metrics_acc)) 63 | out['acc'] = acc 64 | out['cm'] = self.metrics_acc 65 | return out 66 | 67 | 68 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uzh-rpg/ess/46bf1eed677d869c4733c89d2284e43ea27f97bd/models/__init__.py -------------------------------------------------------------------------------- /models/style_networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | import torch.nn.functional as f 5 | 6 | from models.submodules import InterpolationLayer 7 | 8 | 9 | class SemSegE2VID(nn.Module): 10 | def __init__(self, input_c, output_c, skip_connect=False, skip_type='sum', input_index_map=False): 11 | super(SemSegE2VID, self).__init__() 12 | self.skip_connect = skip_connect 13 | self.skip_type = skip_type 14 | self.apply_skip_connection = skip_sum if self.skip_type == 'sum' else skip_concat 15 | tch = input_c 16 | self.index_coords = None 17 | self.input_index_map = input_index_map 18 | 19 | if self.skip_connect: 20 | decoder_list_1 = [] 21 | for i in range(0, 5): # 3, 5 22 | decoder_list_1 += [INSResBlock(tch, tch)] 23 | decoder_list_1 += [ReLUINSConv2d(tch, tch // 2, kernel_size=3, stride=1, padding=1)] 24 | self.decoder_scale_1 = torch.nn.Sequential(*decoder_list_1) 25 | self.decoder_scale_2 = nn.Sequential(ReLUINSConv2d(tch, tch // 2, kernel_size=3, stride=1, padding=1), 26 | ReLUINSConv2d(tch // 2, tch // 4, kernel_size=3, stride=1, padding=1)) 27 | tch = tch // 2 28 | self.decoder_scale_3 = nn.Sequential(ReLUINSConv2d(tch, tch // 2, kernel_size=3, stride=1, padding=1), 29 | ReLUINSConv2d(tch // 2, tch // 2, kernel_size=3, stride=1, padding=1)) 30 | tch = tch // 2 31 | self.decoder_scale_4 = nn.Sequential(ReLUINSConv2d(tch, tch // 2, kernel_size=3, stride=1, padding=1)) 32 | tch = tch // 2 33 | self.decoder_scale_5 = nn.Sequential( 34 | torch.nn.Conv2d(tch, output_c, kernel_size=1, stride=1, padding=0)) 35 | else: 36 | if self.input_index_map: 37 | tch += 2 38 | # Instance Norm 39 | decoder_list_1 = [] 40 | for i in range(0, 3): 41 | decoder_list_1 += [INSResBlock(tch, tch)] 42 | 43 | self.decoder_scale_1 = torch.nn.Sequential(*decoder_list_1) 44 | tch = tch 45 | if self.input_index_map: 46 | self.decoder_scale_2 = nn.Sequential(InterpolationLayer(scale_factor=2, mode='nearest'), 47 | ReLUINSConv2d(tch, (tch - 2) // 2, kernel_size=3, stride=1, 48 | padding=1)) 49 | tch = (tch - 2) // 2 50 | else: 51 | self.decoder_scale_2 = nn.Sequential(InterpolationLayer(scale_factor=2, mode='nearest'), 52 | ReLUINSConv2d(tch, tch // 2, kernel_size=3, stride=1, padding=1)) 53 | tch = tch // 2 54 | self.decoder_scale_3 = nn.Sequential(InterpolationLayer(scale_factor=2, mode='nearest'), 55 | ReLUINSConv2d(tch, tch // 2, kernel_size=3, stride=1, padding=1)) 56 | tch = tch // 2 57 | tch = tch 58 | self.decoder_scale_4 = nn.Sequential(InterpolationLayer(scale_factor=2, mode='nearest'), 59 | ReLUINSConv2d(tch, tch // 2, kernel_size=3, stride=1, padding=1)) 60 | tch = tch // 2 61 | self.decoder_scale_5 = nn.Sequential( 62 | torch.nn.Conv2d(tch, output_c, kernel_size=1, stride=1, padding=0)) 63 | 64 | def update_skip_dict(self, skips, x, sz_in): 65 | rem, scale = sz_in % x.shape[3], sz_in // x.shape[3] 66 | 
assert rem == 0 67 | skips[scale] = x 68 | 69 | def forward(self, input_dict): 70 | sz_in = input_dict[1].shape[3] 71 | 72 | x = input_dict[8] 73 | out = {8: x} 74 | 75 | if self.skip_connect: 76 | x = self.decoder_scale_1(x) 77 | x = f.interpolate(x, scale_factor=2, mode='nearest') 78 | x = self.apply_skip_connection(x, input_dict[4]) 79 | x = self.decoder_scale_2(x) 80 | self.update_skip_dict(out, x, sz_in) 81 | x = f.interpolate(x, scale_factor=2, mode='nearest') 82 | x = self.apply_skip_connection(x, input_dict[2]) 83 | x = self.decoder_scale_3(x) 84 | self.update_skip_dict(out, x, sz_in) 85 | x = f.interpolate(x, scale_factor=2, mode='nearest') 86 | x = self.decoder_scale_4(x) 87 | x = self.decoder_scale_5(x) 88 | self.update_skip_dict(out, x, sz_in) 89 | else: 90 | if self.input_index_map: 91 | if self.index_coords is None or self.index_coords.size(2) != x.size(2): 92 | x_coords = torch.arange(x.size(2), device=x.device, dtype=torch.float) 93 | y_coords = torch.arange(x.size(3), device=x.device, dtype=torch.float) 94 | self.index_coords = torch.stack(torch.meshgrid([x_coords, 95 | y_coords]), dim=0) 96 | self.index_coords = self.index_coords[None, :, :, :].repeat([x.size(0), 1, 1, 1]) 97 | x = torch.cat([x, self.index_coords], dim=1) 98 | 99 | x = self.decoder_scale_1(x) 100 | x = self.decoder_scale_2(x) 101 | self.update_skip_dict(out, x, sz_in) 102 | x = self.decoder_scale_3(x) 103 | self.update_skip_dict(out, x, sz_in) 104 | x = self.decoder_scale_4(x) 105 | x = self.decoder_scale_5(x) 106 | self.update_skip_dict(out, x, sz_in) 107 | return out 108 | 109 | 110 | class StyleEncoderE2VID(nn.Module): 111 | def __init__(self, input_dim, skip_connect=False): 112 | super(StyleEncoderE2VID, self).__init__() 113 | conv_list = [] 114 | self.skip_connect = skip_connect 115 | 116 | conv_list += [nn.Conv2d(input_dim, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)] 117 | conv_list += list(models.resnet18(pretrained=True).children())[1:3] 118 | conv_list += list(models.resnet18(pretrained=True).children())[4:5] 119 | self.encoder_scale_1 = nn.Sequential(*conv_list) 120 | self.encoder_scale_2 = list(models.resnet18(pretrained=True).children())[5] 121 | self.encoder_scale_3 = list(models.resnet18(pretrained=True).children())[6] 122 | 123 | def update_skip_dict(self, skips, x, sz_in): 124 | rem, scale = sz_in % x.shape[3], sz_in // x.shape[3] 125 | assert rem == 0 126 | skips[scale] = x 127 | 128 | def forward(self, x): 129 | out = {1: x} 130 | sz_in = x.shape[3] 131 | 132 | if self.skip_connect: 133 | x = self.encoder_scale_1(x) 134 | self.update_skip_dict(out, x, sz_in) 135 | x = self.encoder_scale_2(x) 136 | self.update_skip_dict(out, x, sz_in) 137 | x = self.encoder_scale_3(x) 138 | self.update_skip_dict(out, x, sz_in) 139 | else: 140 | x = self.encoder_scale_1(x) 141 | x = self.encoder_scale_2(x) 142 | x = self.encoder_scale_3(x) 143 | self.update_skip_dict(out, x, sz_in) 144 | 145 | return out 146 | 147 | 148 | #################################################################### 149 | # -------------------------- Basic Blocks -------------------------- 150 | #################################################################### 151 | 152 | def gaussian_weights_init(m): 153 | classname = m.__class__.__name__ 154 | if classname.find('Conv') != -1 and classname.find('Conv') == 0: 155 | m.weight.data.normal_(0.0, 0.02) 156 | 157 | 158 | class ReLUINSConv2d(nn.Module): 159 | def __init__(self, n_in, n_out, kernel_size, stride, padding=0): 160 | super(ReLUINSConv2d, 
self).__init__() 161 | model = [] 162 | model += [nn.Conv2d(n_in, n_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=True)] 163 | model += [nn.InstanceNorm2d(n_out, affine=False)] 164 | model += [nn.ReLU(inplace=True)] 165 | self.model = nn.Sequential(*model) 166 | self.model.apply(gaussian_weights_init) 167 | 168 | def forward(self, x): 169 | return self.model(x) 170 | 171 | 172 | class INSResBlock(nn.Module): 173 | def conv3x3(self, inplanes, out_planes, stride=1): 174 | return [nn.Conv2d(inplanes, out_planes, kernel_size=3, stride=stride, padding=1)] 175 | 176 | def __init__(self, inplanes, planes, stride=1, dropout=0.0): 177 | super(INSResBlock, self).__init__() 178 | model = [] 179 | model += self.conv3x3(inplanes, planes, stride) 180 | model += [nn.InstanceNorm2d(planes)] 181 | model += [nn.ReLU(inplace=True)] 182 | model += self.conv3x3(planes, planes) 183 | model += [nn.InstanceNorm2d(planes)] 184 | if dropout > 0: 185 | model += [nn.Dropout(p=dropout)] 186 | self.model = nn.Sequential(*model) 187 | self.model.apply(gaussian_weights_init) 188 | 189 | def forward(self, x): 190 | residual = x 191 | out = self.model(x) 192 | out += residual 193 | return out 194 | 195 | 196 | def skip_concat(x1, x2): 197 | return torch.cat([x1, x2], dim=1) 198 | 199 | 200 | def skip_sum(x1, x2): 201 | return x1 + x2 202 | 203 | 204 | 205 | -------------------------------------------------------------------------------- /models/submodules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as f 4 | from torch.nn import init 5 | 6 | 7 | class InterpolationLayer(nn.Module): 8 | def __init__(self, size=None, scale_factor=None, mode='nearest'): 9 | super(InterpolationLayer, self).__init__() 10 | self.interp = nn.functional.interpolate 11 | self.scale_factor = scale_factor 12 | self.size = size 13 | self.mode = mode 14 | 15 | def forward(self, x): 16 | if self.scale_factor is not None: 17 | if self.mode == 'nearest' and self.scale_factor == 2: 18 | return x[:, :, :, None, :, None].expand(-1, -1, -1, 2, -1, 2).reshape(x.size(0), x.size(1), 19 | 2 * x.size(2), 2 * x.size(3)) 20 | else: 21 | return self.interp(x, scale_factor=self.scale_factor, mode=self.mode) 22 | 23 | else: 24 | return self.interp(x, size=self.size, mode=self.mode) 25 | 26 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.11.0 2 | aiohttp==3.7.4.post0 3 | alabaster==0.7.12 4 | albumentations==1.1.0 5 | appdirs==1.4.4 6 | astor==0.8.1 7 | async-timeout==3.0.1 8 | attrs==21.2.0 9 | Babel==2.8.0 10 | backcall==0.2.0 11 | beautifulsoup4==4.9.1 12 | blinker==1.4 13 | brotlipy==0.7.0 14 | cachetools==4.2.1 15 | certifi==2021.10.8 16 | cffi==1.14.0 17 | chardet==4.0.0 18 | charset-normalizer==2.0.4 19 | cityscapesScripts==2.2.0 20 | click==8.0.1 21 | coloredlogs==15.0.1 22 | configparser==5.2.0 23 | coverage==5.5 24 | cryptography==3.4.8 25 | cycler==0.10.0 26 | Cython==0.29.24 27 | decorator==4.4.2 28 | docker-pycreds==0.4.0 29 | docutils==0.16 30 | future==0.18.2 31 | gitdb==4.0.9 32 | GitPython==3.1.27 33 | google-auth==1.27.0 34 | google-auth-oauthlib==0.4.2 35 | google-pasta==0.2.0 36 | gql==0.2.0 37 | graphql-core==1.1 38 | grpcio==1.35.0 39 | h5py==2.10.0 40 | hdf5plugin==3.2.0 41 | humanfriendly==10.0 42 | idna==2.10 43 | imageio==2.13.3 44 | imagesize==1.2.0 45 | 
importlib-metadata==4.8.1 46 | ipython==7.16.1 47 | ipython-genutils==0.2.0 48 | javalang==0.13.0 49 | javasphinx==0.9.15 50 | jedi==0.17.0 51 | Jinja2==2.11.2 52 | joblib==0.16.0 53 | kiwisolver==1.2.0 54 | llvmlite==0.38.0rc1 55 | lxml==4.5.1 56 | Markdown==3.3.4 57 | MarkupSafe==1.1.1 58 | matplotlib==3.3.1 59 | mkl-fft 60 | mkl-random 61 | mkl-service 62 | multidict==5.1.0 63 | networkx==2.6.3 64 | numba==0.55.0rc1 65 | numpy==1.18.5 66 | nvidia-ml-py3==7.352.0 67 | oauthlib==3.1.0 68 | olefile==0.46 69 | opencv-python==4.3.0.38 70 | opencv-python-headless==4.5.4.60 71 | packaging==20.4 72 | parso==0.8.1 73 | pathtools==0.1.2 74 | pexpect==4.8.0 75 | pickleshare==0.7.5 76 | Pillow==7.2.0 77 | pip==20.1.1 78 | promise==2.3 79 | prompt-toolkit==3.0.8 80 | protobuf==3.18.1 81 | psutil==5.9.0 82 | ptyprocess==0.7.0 83 | pyasn1==0.4.8 84 | pyasn1-modules==0.2.8 85 | pycparser==2.20 86 | Pygments==2.7.3 87 | PyJWT==2.1.0 88 | pyOpenSSL==20.0.1 89 | pyparsing==2.4.7 90 | pyquaternion==0.9.9 91 | PySocks==1.7.1 92 | python-dateutil==2.8.1 93 | pytz==2020.1 94 | PyWavelets==1.2.0 95 | PyYAML==5.3.1 96 | qudida==0.0.4 97 | requests==2.25.1 98 | requests-oauthlib==1.3.0 99 | rsa==4.7.2 100 | scikit-image==0.19.0 101 | scikit-learn==0.23.2 102 | scipy==1.4.1 103 | sentry-sdk==1.5.7 104 | setproctitle==1.2.2 105 | setuptools==47.3.1.post20200622 106 | shortuuid==1.0.8 107 | sip==4.19.13 108 | six==1.15.0 109 | smmap==5.0.0 110 | snowballstemmer==2.0.0 111 | soupsieve==2.0.1 112 | Sphinx==2.4.4 113 | sphinxcontrib-applehelp==1.0.2 114 | sphinxcontrib-devhelp==1.0.2 115 | sphinxcontrib-htmlhelp==1.0.3 116 | sphinxcontrib-jsmath==1.0.1 117 | sphinxcontrib-katex==0.6.1 118 | sphinxcontrib-qthelp==1.0.3 119 | sphinxcontrib-serializinghtml==1.1.4 120 | subprocess32==3.5.4 121 | tensorboard==2.4.1 122 | tensorboard-plugin-wit==1.8.0 123 | tensorboardX==2.1 124 | termcolor==1.1.0 125 | threadpoolctl==2.1.0 126 | tifffile==2021.11.2 127 | timm==0.5.4 128 | torch==1.6.0 129 | torchvision==0.7.0 130 | tornado==6.1 131 | tqdm==4.48.2 132 | traitlets==4.3.3 133 | typing==3.7.4.3 134 | typing-extensions==3.7.4.3 135 | urllib3==1.26.3 136 | wandb==0.9.7 137 | watchdog==2.1.6 138 | wcwidth==0.2.5 139 | Werkzeug==1.0.1 140 | wheel==0.34.2 141 | wrapt==1.12.1 142 | yarl==1.6.3 143 | yaspin==2.1.0 144 | zipp==3.4.0 145 | -------------------------------------------------------------------------------- /resources/ESS.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uzh-rpg/ess/46bf1eed677d869c4733c89d2284e43ea27f97bd/resources/ESS.png -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example usage: CUDA_VISIBLE_DEVICES=1, python train.py --settings_file "config/settings_DDD17.yaml" 3 | """ 4 | import argparse 5 | import wandb 6 | 7 | from config.settings import Settings 8 | from training.ess_trainer import ESSModel 9 | from training.ess_supervised_trainer import ESSSupervisedModel 10 | 11 | import numpy as np 12 | import torch 13 | import random 14 | import os 15 | 16 | # random seed 17 | seed_value = 6 18 | np.random.seed(seed_value) 19 | random.seed(seed_value) 20 | os.environ['PYTHONHASHSEED'] = str(seed_value) 21 | 22 | torch.manual_seed(seed_value) 23 | torch.cuda.manual_seed(seed_value) 24 | torch.cuda.manual_seed_all(seed_value) 25 | torch.backends.cudnn.deterministic = True 26 | 27 | def main(): 28 | parser = 
argparse.ArgumentParser(description='Train network.') 29 | parser.add_argument('--settings_file', help='Path to settings yaml', required=True) 30 | 31 | args = parser.parse_args() 32 | settings_filepath = args.settings_file 33 | settings = Settings(settings_filepath, generate_log=True) 34 | 35 | wandb.init(name=(settings.dataset_name_b.split("_")[0] + '_' + settings.timestr), project="zhaoning_sun_semester_thesis", entity="zhasun", sync_tensorboard=True) 36 | 37 | if settings.model_name == 'ess': 38 | trainer = ESSModel(settings) 39 | elif settings.model_name == 'ess_supervised': 40 | trainer = ESSSupervisedModel(settings) 41 | 42 | else: 43 | raise ValueError('Model name %s specified in the settings file is not implemented' % settings.model_name) 44 | 45 | wandb.config = { 46 | "random_seed": seed_value, 47 | "lr_front": settings.lr_front, 48 | "lr_back": settings.lr_back, 49 | "batch_size_a": settings.batch_size_a, 50 | "batch_size_b": settings.batch_size_b 51 | } 52 | 53 | trainer.train() 54 | 55 | 56 | if __name__ == "__main__": 57 | main() 58 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uzh-rpg/ess/46bf1eed677d869c4733c89d2284e43ea27f97bd/training/__init__.py -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uzh-rpg/ess/46bf1eed677d869c4733c89d2284e43ea27f97bd/utils/__init__.py -------------------------------------------------------------------------------- /utils/labels.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | # a label and all meta information 4 | Label = namedtuple('Label', [ 5 | 6 | 'name', # The identifier of this label, e.g. 'car', 'person', ... . 7 | # We use them to uniquely name a class 8 | 9 | 'id', # An integer ID that is associated with this label. 10 | # The IDs are used to represent the label in ground truth images 11 | # An ID of -1 means that this label does not have an ID and thus 12 | # is ignored when creating ground truth images (e.g. license plate). 13 | # Do not modify these IDs, since exactly these IDs are expected by the 14 | # evaluation server. 15 | 16 | 'trainId', # Feel free to modify these IDs as suitable for your method. Then create 17 | # ground truth images with train IDs, using the tools provided in the 18 | # 'preparation' folder. However, make sure to validate or submit results 19 | # to our evaluation server using the regular IDs above! 20 | # For trainIds, multiple labels might have the same ID. Then, these labels 21 | # are mapped to the same class in the ground truth images. For the inverse 22 | # mapping, we use the label that is defined first in the list below. 23 | # For example, mapping all void-type classes to the same ID in training, 24 | # might make sense for some approaches. 25 | # Max value is 255! 26 | 27 | 'category', # The name of the category that this label belongs to 28 | 29 | 'categoryId', # The ID of this category. Used to create ground truth images 30 | # on category level. 
31 | 32 | 'hasInstances', # Whether this label distinguishes between single instances or not 33 | 34 | 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored 35 | # during evaluations or not 36 | 37 | 'color', # The color of this label 38 | ]) 39 | 40 | labels_6_Cityscapes = [ 41 | # name id trainId category catId hasInstances ignoreInEval color 42 | Label('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)), 43 | Label('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)), 44 | Label('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)), 45 | Label('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)), 46 | Label('static', 4, 255, 'void', 0, False, True, (0, 0, 0)), 47 | Label('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)), 48 | Label('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)), 49 | Label('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)), 50 | Label('sidewalk', 8, 0, 'flat', 1, False, False, (244, 35, 232)), 51 | Label('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)), 52 | Label('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)), 53 | Label('building', 11, 1, 'construction', 2, False, False, (70, 70, 70)), 54 | Label('wall', 12, 1, 'construction', 2, False, False, (102, 102, 156)), 55 | Label('fence', 13, 1, 'construction', 2, False, False, (190, 153, 153)), 56 | Label('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)), 57 | Label('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)), 58 | Label('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)), 59 | Label('pole', 17, 2, 'object', 3, False, False, (153, 153, 153)), 60 | Label('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)), 61 | Label('traffic light', 19, 2, 'object', 3, False, False, (250, 170, 30)), 62 | Label('traffic sign', 20, 2, 'object', 3, False, False, (220, 220, 0)), 63 | Label('vegetation', 21, 3, 'nature', 4, False, False, (107, 142, 35)), 64 | Label('terrain', 22, 3, 'nature', 4, False, False, (152, 251, 152)), 65 | Label('sky', 23, 1, 'sky', 5, False, False, (70, 130, 180)), 66 | Label('person', 24, 4, 'human', 6, True, False, (220, 20, 60)), 67 | Label('rider', 25, 4, 'human', 6, True, False, (255, 0, 0)), 68 | Label('car', 26, 5, 'vehicle', 7, True, False, (0, 0, 142)), 69 | Label('truck', 27, 5, 'vehicle', 7, True, False, (0, 0, 70)), 70 | Label('bus', 28, 5, 'vehicle', 7, True, False, (0, 60, 100)), 71 | Label('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)), 72 | Label('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)), 73 | Label('train', 31, 5, 'vehicle', 7, True, False, (0, 80, 100)), 74 | Label('motorcycle', 32, 5, 'vehicle', 7, True, False, (0, 0, 230)), 75 | Label('bicycle', 33, 5, 'vehicle', 7, True, False, (119, 11, 32)), 76 | Label('license plate', -1, 255, 'vehicle', 7, False, True, (0, 0, 142)), 77 | ] 78 | 79 | Id2label_6_Cityscapes = {label.id: label for label in reversed(labels_6_Cityscapes)} 80 | 81 | labels_11_Cityscapes = [ 82 | # name id trainId category catId hasInstances ignoreInEval color 83 | Label('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)), 84 | Label('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)), 85 | Label('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)), 86 | Label('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)), 87 | Label('static', 4, 255, 'void', 0, False, True, (0, 0, 0)), 88 | Label('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)), 89 | 
Label('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)), 90 | Label('road', 7, 5, 'flat', 1, False, False, (128, 64, 128)), 91 | Label('sidewalk', 8, 6, 'flat', 1, False, False, (244, 35, 232)), 92 | Label('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)), 93 | Label('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)), 94 | Label('building', 11, 1, 'construction', 2, False, False, (70, 70, 70)), 95 | Label('wall', 12, 9, 'construction', 2, False, False, (102, 102, 156)), 96 | Label('fence', 13, 2, 'construction', 2, False, False, (190, 153, 153)), 97 | Label('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)), 98 | Label('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)), 99 | Label('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)), 100 | Label('pole', 17, 4, 'object', 3, False, False, (153, 153, 153)), 101 | Label('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)), 102 | Label('traffic light', 19, 10, 'object', 3, False, False, (250, 170, 30)), 103 | Label('traffic sign', 20, 10, 'object', 3, False, False, (220, 220, 0)), 104 | Label('vegetation', 21, 7, 'nature', 4, False, False, (107, 142, 35)), 105 | Label('terrain', 22, 7, 'nature', 4, False, False, (152, 251, 152)), 106 | Label('sky', 23, 0, 'sky', 5, False, False, (70, 130, 180)), 107 | Label('person', 24, 3, 'human', 6, True, False, (220, 20, 60)), 108 | Label('rider', 25, 3, 'human', 6, True, False, (255, 0, 0)), 109 | Label('car', 26, 8, 'vehicle', 7, True, False, (0, 0, 142)), 110 | Label('truck', 27, 8, 'vehicle', 7, True, False, (0, 0, 70)), 111 | Label('bus', 28, 8, 'vehicle', 7, True, False, (0, 60, 100)), 112 | Label('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)), 113 | Label('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)), 114 | Label('train', 31, 8, 'vehicle', 7, True, False, (0, 80, 100)), 115 | Label('motorcycle', 32, 8, 'vehicle', 7, True, False, (0, 0, 230)), 116 | Label('bicycle', 33, 8, 'vehicle', 7, True, False, (119, 11, 32)), 117 | Label('license plate', -1, 255, 'vehicle', 7, False, True, (0, 0, 142)), 118 | ] 119 | 120 | Id2label_11_Cityscapes = {label.id: label for label in reversed(labels_11_Cityscapes)} 121 | 122 | 123 | def fromIdToTrainId(imgin, Id2label): 124 | imgout = imgin.copy() 125 | for id in Id2label: 126 | imgout[imgin == id] = Id2label[id].trainId 127 | return imgout 128 | 129 | 130 | def shiftUpId(imgin): 131 | imgout = imgin.copy() + 1 132 | return imgout 133 | 134 | 135 | def shiftDownId(imgin): 136 | imgout = imgin.copy() 137 | imgout[imgin == 0] = 256 # ignore label + 1 138 | imgout -= 1 139 | return imgout 140 | -------------------------------------------------------------------------------- /utils/loss_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | 6 | class TaskLoss(torch.nn.Module): 7 | def __init__(self, losses=['cross_entropy'], gamma=2.0, num_classes=13, alpha=None, weight=None, ignore_index=None, reduction='mean'): 8 | super(TaskLoss, self).__init__() 9 | self.losses = losses 10 | self.weight = weight 11 | self.gamma = gamma 12 | self.alpha = alpha 13 | self.ignore_index = ignore_index 14 | self.dice_loss = DiceLoss(num_classes=num_classes, ignore_index=self.ignore_index) 15 | self.ce_loss = torch.nn.CrossEntropyLoss(ignore_index=self.ignore_index) 16 | 17 | def forward(self, predict, target): 18 | total_loss = 0 19 | if 'dice' in 
self.losses: 20 | total_loss += self.dice_loss(predict, target) 21 | if 'cross_entropy' in self.losses: 22 | total_loss += self.ce_loss(predict, target) 23 | 24 | return total_loss 25 | 26 | 27 | class symJSDivLoss(torch.nn.Module): 28 | def __init__(self, ): 29 | super(symJSDivLoss, self).__init__() 30 | self.KLDivLoss = torch.nn.KLDivLoss() 31 | 32 | def forward(self, predict, target): 33 | total_loss = 0 34 | total_loss += 0.5 * self.KLDivLoss(predict.softmax(dim=1).clamp(min=1e-10).log(), target.softmax(dim=1).clamp(min=1e-10)) 35 | total_loss += 0.5 * self.KLDivLoss(target.softmax(dim=1).clamp(min=1e-10).log(), predict.softmax(dim=1).clamp(min=1e-10)) 36 | 37 | return total_loss 38 | 39 | 40 | """ 41 | Adapted from https://github.com/Guocode/DiceLoss.Pytorch.git 42 | """ 43 | def make_one_hot(input, num_classes): 44 | """Convert class index tensor to one hot encoding tensor. 45 | Args: 46 | input: A tensor of shape [N, 1, *] 47 | num_classes: An int of number of class 48 | Returns: 49 | A tensor of shape [N, num_classes, *] 50 | """ 51 | shape = np.array(input.shape) 52 | shape[1] = num_classes 53 | shape = tuple(shape) 54 | result = torch.zeros(shape, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')) 55 | result = result.scatter_(1, input, 1) 56 | 57 | return result 58 | 59 | 60 | """ 61 | Adapted from https://github.com/Guocode/DiceLoss.Pytorch.git 62 | """ 63 | class BinaryDiceLoss(torch.nn.Module): 64 | """Dice loss of binary class 65 | Args: 66 | smooth: A float number to smooth loss, and avoid NaN error, default: 1 67 | p: Denominator value: \sum{x^p} + \sum{y^p}, default: 2 68 | predict: A tensor of shape [N, *] 69 | target: A tensor of shape same with predict 70 | Returns: 71 | Loss tensor according to arg reduction 72 | Raise: 73 | Exception if unexpected reduction 74 | """ 75 | def __init__(self, smooth=1, p=2): 76 | super(BinaryDiceLoss, self).__init__() 77 | self.smooth = smooth 78 | self.p = p 79 | 80 | def forward(self, predict, target): 81 | assert predict.shape[0] == target.shape[0], "predict & target batch size don't match" 82 | predict = predict.contiguous().view(predict.shape[0], -1) 83 | target = target.contiguous().view(target.shape[0], -1) 84 | 85 | num = torch.sum(torch.mul(predict, target))*2 + self.smooth 86 | den = torch.sum(predict.pow(self.p) + target.pow(self.p)) + self.smooth 87 | 88 | dice = num / den 89 | loss = 1 - dice 90 | return loss 91 | 92 | 93 | """ 94 | Adapted from https://github.com/Guocode/DiceLoss.Pytorch.git 95 | """ 96 | class DiceLoss(torch.nn.Module): 97 | """Dice loss, need one hot encode input 98 | Args: 99 | weight: An array of shape [num_classes,] 100 | ignore_index: class index to ignore 101 | predict: A tensor of shape [N, C, *] 102 | target: A tensor of same shape with predict 103 | other args pass to BinaryDiceLoss 104 | Return: 105 | same as BinaryDiceLoss 106 | """ 107 | def __init__(self, weight=None, num_classes=13, ignore_index=None, **kwargs): 108 | super(DiceLoss, self).__init__() 109 | self.kwargs = kwargs 110 | self.weight = weight 111 | self.num_classes = num_classes 112 | self.ignore_index = ignore_index 113 | 114 | def forward(self, predict, target): 115 | mask = target != self.ignore_index 116 | target = target * mask 117 | target = make_one_hot(torch.unsqueeze(target, 1), self.num_classes) 118 | target = target * mask.unsqueeze(1) 119 | 120 | assert predict.shape == target.shape, 'predict & target shape do not match' 121 | dice = BinaryDiceLoss(**self.kwargs) 122 | total_loss = 0 123 | predict = 
F.softmax(predict, dim=1) 124 | predict = predict * mask.unsqueeze(1) 125 | 126 | for i in range(target.shape[1]): 127 | if i != self.ignore_index: 128 | dice_loss = dice(predict[:, i], target[:, i]) 129 | if self.weight is not None: 130 | assert self.weight.shape[0] == target.shape[1], \ 131 | 'Expect weight shape [{}], get[{}]'.format(target.shape[1], self.weight.shape[0]) 132 | dice_loss *= self.weights[i] 133 | total_loss += dice_loss 134 | 135 | return total_loss/target.shape[1] 136 | 137 | 138 | 139 | -------------------------------------------------------------------------------- /utils/radam.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer, required 4 | 5 | 6 | class RAdam(Optimizer): 7 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 8 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 9 | self.buffer = [[None, None, None] for ind in range(10)] 10 | super(RAdam, self).__init__(params, defaults) 11 | 12 | def __setstate__(self, state): 13 | super(RAdam, self).__setstate__(state) 14 | 15 | def step(self, closure=None): 16 | 17 | loss = None 18 | if closure is not None: 19 | loss = closure() 20 | 21 | for group in self.param_groups: 22 | 23 | for p in group['params']: 24 | if p.grad is None: 25 | continue 26 | grad = p.grad.data.float() 27 | if grad.is_sparse: 28 | raise RuntimeError('RAdam does not support sparse gradients') 29 | 30 | p_data_fp32 = p.data.float() 31 | 32 | state = self.state[p] 33 | 34 | if len(state) == 0: 35 | state['step'] = 0 36 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 37 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 38 | else: 39 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 40 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 41 | 42 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 43 | beta1, beta2 = group['betas'] 44 | 45 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 46 | # exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 47 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 48 | 49 | state['step'] += 1 50 | buffered = self.buffer[int(state['step'] % 10)] 51 | if state['step'] == buffered[0]: 52 | N_sma, step_size = buffered[1], buffered[2] 53 | else: 54 | buffered[0] = state['step'] 55 | beta2_t = beta2 ** state['step'] 56 | N_sma_max = 2 / (1 - beta2) - 1 57 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 58 | buffered[1] = N_sma 59 | 60 | # more conservative since it's an approximated value 61 | if N_sma >= 5: 62 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / 63 | N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 64 | else: 65 | step_size = 1.0 / (1 - beta1 ** state['step']) 66 | buffered[2] = step_size 67 | 68 | if group['weight_decay'] != 0: 69 | p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr']) 70 | 71 | # more conservative since it's an approximated value 72 | if N_sma >= 5: 73 | denom = exp_avg_sq.sqrt().add_(group['eps']) 74 | p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr']) 75 | else: 76 | p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr']) 77 | 78 | p.data.copy_(p_data_fp32) 79 | 80 | return loss 81 | -------------------------------------------------------------------------------- /utils/saver.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | import datetime 4 | 5 | import torch 6 | 7 | 8 | class CheckpointSaver(object): 9 | def __init__(self, save_dir): 10 | if save_dir is not None: 11 | self.save_dir = os.path.abspath(save_dir) 12 | return 13 | 14 | # save checkpoint 15 | def save_checkpoint(self, models, optimizers, epoch, step_count, batch_size_a, batch_size_b): 16 | timestamp = datetime.datetime.now() 17 | checkpoint_filename = os.path.abspath(os.path.join(self.save_dir, 'Epoch_' + str(epoch) + '.pt')) 18 | checkpoint = {} 19 | for model in models: 20 | checkpoint[model] = models[model].state_dict() 21 | for optimizer in optimizers: 22 | checkpoint[optimizer] = optimizers[optimizer].state_dict() 23 | checkpoint['epoch'] = epoch 24 | checkpoint['step_count'] = step_count 25 | checkpoint['batch_size_a'] = batch_size_a 26 | checkpoint['batch_size_b'] = batch_size_b 27 | print() 28 | print(timestamp, 'Epoch:', epoch, 'Iteration:', step_count) 29 | print('Saving checkpoint file [' + checkpoint_filename + ']') 30 | torch.save(checkpoint, checkpoint_filename) 31 | return 32 | 33 | # load a checkpoint 34 | def load_checkpoint(self, models, optimizers, checkpoint_file=None, load_optimizer=True): 35 | checkpoint = torch.load(checkpoint_file) 36 | for model in models: 37 | if model in checkpoint: 38 | models[model].load_state_dict(checkpoint[model]) 39 | if load_optimizer: 40 | for optimizer in optimizers: 41 | if optimizer in checkpoint: 42 | optimizers[optimizer].load_state_dict(checkpoint[optimizer]) 43 | print("Loading checkpoint with epoch {}, step {}" 44 | .format(checkpoint['epoch'], checkpoint['step_count'])) 45 | return {'epoch': checkpoint['epoch'], 46 | 'step_count': checkpoint['step_count'], 47 | 'batch_size_a': checkpoint['batch_size_a'], 48 | 'batch_size_b': checkpoint['batch_size_b']} 49 | 50 | def load_pretrained_weights(self, models, model_list, checkpoint_file=None): 51 | checkpoint = torch.load(checkpoint_file) 52 | load_model_list = [] 53 | for model_name in model_list: 54 | if model_name in ['front_sensor_b', 'e2vid_decoder']: 55 | continue 56 | if model_name in checkpoint: 57 | load_model_list.append(model_name) 58 | models[model_name].load_state_dict(checkpoint[model_name]) 59 | 60 | print("Loading pretrained checkpoints for {}".format(load_model_list)) 61 | -------------------------------------------------------------------------------- /utils/viz_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import numpy as np 4 | import torchvision.utils 5 | import torch.nn.functional as F 6 | import matplotlib.pyplot as plt 7 | import itertools 8 | 9 | 10 | def createRGBGrid(tensor_list, nrow): 11 | """Creates a grid of rgb values based on the tensor stored in tensor_list""" 12 | vis_tensor_list = [] 13 | for tensor in tensor_list: 14 | vis_tensor_list.append(visualizeTensors(tensor)) 15 | 16 | return torchvision.utils.make_grid(torch.cat(vis_tensor_list, dim=0), nrow=nrow) 17 | 18 | 19 | def createRGBImage(tensor, separate_pol=True): 20 | """Creates a grid of rgb values based on the tensor stored in tensor_list""" 21 | if tensor.shape[1] == 3: 22 | return tensor 23 | elif tensor.shape[1] == 1: 24 | return tensor.expand(-1, 3, -1, -1) 25 | elif tensor.shape[1] == 2: 26 | return visualizeHistogram(tensor) 27 | elif tensor.shape[1] > 3: 28 | return visualizeVoxelGrid(tensor, separate_pol) 29 | 30 | 31 | def 
visualizeTensors(tensor): 32 | """Creates a rgb image of the given tensor. Can be event histogram, event voxel grid, grayscale and rgb.""" 33 | if tensor.shape[1] == 3: 34 | return tensor 35 | elif tensor.shape[1] == 1: 36 | return tensor.expand(-1, 3, -1, -1) 37 | elif tensor.shape[1] == 2: 38 | return visualizeHistogram(tensor) 39 | elif tensor.shape[1] > 3: 40 | return visualizeVoxelGrid(tensor) 41 | 42 | 43 | def visualizeHistogram(histogram): 44 | """Visualizes the input histogram""" 45 | batch, _, height, width = histogram.shape 46 | torch_image = torch.zeros([batch, 1, height, width], device=histogram.device) 47 | 48 | return torch.cat([histogram.clamp(0, 1), torch_image], dim=1) 49 | 50 | 51 | def visualizeVoxelGrid(voxel_grid, separate_pol=True): 52 | """Visualizes the input histogram""" 53 | batch, nr_channels, height, width = voxel_grid.shape 54 | if separate_pol: 55 | pos_events_idx = nr_channels // 2 56 | temporal_scaling = torch.arange(start=1, end=pos_events_idx+1, dtype=voxel_grid.dtype, 57 | device=voxel_grid.device)[None, :, None, None] / pos_events_idx 58 | pos_voxel_grid = voxel_grid[:, :pos_events_idx] * temporal_scaling 59 | neg_voxel_grid = voxel_grid[:, pos_events_idx:] * temporal_scaling 60 | 61 | torch_image = torch.zeros([batch, 1, height, width], device=voxel_grid.device) 62 | pos_image = torch.sum(pos_voxel_grid, dim=1, keepdim=True) 63 | neg_image = torch.sum(neg_voxel_grid, dim=1, keepdim=True) 64 | 65 | return torch.cat([neg_image.clamp(0, 1), pos_image.clamp(0, 1), torch_image], dim=1) 66 | 67 | sum_events = torch.sum(voxel_grid, dim=1).detach() 68 | event_preview = torch.zeros((batch, 3, height, width)) 69 | b = event_preview[:, 0, :, :] 70 | r = event_preview[:, 2, :, :] 71 | b[sum_events > 0] = 255 72 | r[sum_events < 0] = 255 73 | return event_preview 74 | 75 | 76 | def visualizeConfusionMatrix(confusion_matrix, path_name=None): 77 | """ 78 | Visualizes the confustion matrix using matplotlib. 
79 | 80 | :param confusion_matrix: NxN numpy array 81 | :param path_name: if no path name is given, just an image is returned 82 | """ 83 | import matplotlib.pyplot as plt 84 | nr_classes = confusion_matrix.shape[0] 85 | fig, ax = plt.subplots(1, 1, figsize=(16, 16)) 86 | ax.matshow(confusion_matrix) 87 | ax.plot([-0.5, nr_classes - 0.5], [-0.5, nr_classes - 0.5], '-', color='grey') 88 | ax.set_xlabel('Labels') 89 | ax.set_ylabel('Predicted') 90 | 91 | if path_name is None: 92 | fig.tight_layout(pad=0) 93 | fig.canvas.draw() 94 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 95 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 96 | plt.close() 97 | 98 | return data 99 | 100 | else: 101 | fig.savefig(path_name) 102 | plt.close() 103 | 104 | 105 | def create_checkerboard(N, C, H, W): 106 | cell_sz = max(min(H, W) // 32, 1) 107 | mH = (H + cell_sz - 1) // cell_sz 108 | mW = (W + cell_sz - 1) // cell_sz 109 | checkerboard = torch.full((mH, mW), 0.25, dtype=torch.float32) 110 | checkerboard[0::2, 0::2] = 0.75 111 | checkerboard[1::2, 1::2] = 0.75 112 | checkerboard = checkerboard.float().view(1, 1, mH, mW) 113 | checkerboard = F.interpolate(checkerboard, scale_factor=cell_sz, mode='nearest') 114 | checkerboard = checkerboard[:, :, :H, :W].repeat(N, C, 1, 1) 115 | return checkerboard 116 | 117 | 118 | def prepare_semseg(img, semseg_color_map, semseg_ignore_label): 119 | assert (img.dim() == 3 or img.dim() == 4 and img.shape[1] == 1) and img.dtype in (torch.int, torch.long), \ 120 | f'Expecting 4D tensor with semseg classes, got {img.shape}' 121 | if img.dim() == 4: 122 | img = img.squeeze(1) 123 | colors = torch.tensor(semseg_color_map, dtype=torch.float32) 124 | assert colors.dim() == 2 and colors.shape[1] == 3 125 | if torch.max(colors) > 128: 126 | colors /= 255 127 | img = img.cpu().clone() # N x H x W 128 | N, H, W = img.shape 129 | img_color_ids = torch.unique(img) 130 | assert all(c_id == semseg_ignore_label or 0 <= c_id < len(semseg_color_map) for c_id in img_color_ids) 131 | checkerboard, mask_ignore = None, None 132 | if semseg_ignore_label in img_color_ids: 133 | checkerboard = create_checkerboard(N, 3, H, W) 134 | # blackboard = create_blackboard(N, 3, H, W) 135 | mask_ignore = img == semseg_ignore_label 136 | img[mask_ignore] = 0 137 | img = colors[img] # N x H x W x 3 138 | img = img.permute(0, 3, 1, 2) 139 | 140 | # checkerboard 141 | if semseg_ignore_label in img_color_ids: 142 | mask_ignore = mask_ignore.unsqueeze(1).repeat(1, 3, 1, 1) 143 | img[mask_ignore] = checkerboard[mask_ignore] 144 | # img[mask_ignore] = blackboard[mask_ignore] 145 | return img 146 | 147 | 148 | def plot_confusion_matrix(cm, classes, 149 | normalize=False, 150 | title='Confusion matrix', 151 | cmap=plt.cm.Blues): 152 | """ 153 | This function prints and plots the confusion matrix. 154 | Normalization can be applied by setting `normalize=True`. 155 | """ 156 | cm = cm.numpy() 157 | if normalize: 158 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] 159 | print("Normalized confusion matrix") 160 | else: 161 | print('Confusion matrix, without normalization') 162 | 163 | fig = plt.figure() 164 | 165 | plt.imshow(cm, interpolation='nearest', cmap=cmap) 166 | plt.title(title) 167 | plt.colorbar() 168 | tick_marks = np.arange(len(classes)) 169 | plt.xticks(tick_marks, classes, rotation=45) 170 | plt.yticks(tick_marks, classes) 171 | 172 | fmt = '.2f' if normalize else 'd' 173 | thresh = cm.max() / 2. 
174 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): 175 | plt.text(j, i, format(cm[i, j], fmt), 176 | horizontalalignment="center", 177 | color="white" if cm[i, j] > thresh else "black") 178 | 179 | plt.tight_layout() 180 | plt.ylabel('True label') 181 | plt.xlabel('Predicted label') 182 | return fig 183 | --------------------------------------------------------------------------------
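Usage sketch (illustrative, not part of the repository): the snippet below shows one way the utilities listed above could be wired together in a toy training step, combining TaskLoss from utils/loss_functions.py, RAdam from utils/radam.py and CheckpointSaver from utils/saver.py. The stand-in 1x1 convolution, the dummy voxel-grid input, the hyper-parameter values and the output directory are assumptions chosen for illustration only; in the actual project the training loop lives in train.py and the trainers under training/.

import os
import torch
import torch.nn as nn

from utils.loss_functions import TaskLoss
from utils.radam import RAdam
from utils.saver import CheckpointSaver

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

num_classes = 11                      # 11-class DSEC label set used in this repo
voxel_channels = 5                    # assumed number of voxel-grid channels

# Stand-in for the real segmentation network: a single 1x1 convolution that
# maps the event representation to per-class logits.
model = nn.Conv2d(voxel_channels, num_classes, kernel_size=1).to(device)

# Combined Dice + cross-entropy loss; 255 marks ignored pixels.
criterion = TaskLoss(losses=['dice', 'cross_entropy'], num_classes=num_classes, ignore_index=255)

optimizer = RAdam(model.parameters(), lr=1e-4, weight_decay=0)

save_dir = 'log/usage_sketch'         # hypothetical checkpoint directory
os.makedirs(save_dir, exist_ok=True)
saver = CheckpointSaver(save_dir)

# Dummy batch: random voxel grids and random integer class labels.
events = torch.rand(2, voxel_channels, 64, 64, device=device)
labels = torch.randint(0, num_classes, (2, 64, 64), device=device)

for step in range(10):
    optimizer.zero_grad()
    loss = criterion(model(events), labels)
    loss.backward()
    optimizer.step()

# CheckpointSaver expects dictionaries of named models and optimizers.
saver.save_checkpoint(models={'toy_model': model}, optimizers={'toy_optimizer': optimizer},
                      epoch=0, step_count=10, batch_size_a=2, batch_size_b=2)

The dictionary keys passed to save_checkpoint are arbitrary in this sketch; load_checkpoint and load_pretrained_weights look entries up by the same keys, so in the real trainers they must match the names used when the checkpoint was written.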