├── .gitignore
├── DSEC
│   ├── dataset
│   │   ├── provider.py
│   │   ├── representations.py
│   │   ├── sequence.py
│   │   ├── sequence_recurrent.py
│   │   └── visualization.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── eventslicer.py
│   │   └── viz_utils.py
│   └── visualization
│       ├── __init__.py
│       └── eventreader.py
├── LICENSE
├── README.md
├── __init__.py
├── config
│   ├── __init__.py
│   ├── settings.py
│   ├── settings_DDD17.yaml
│   └── settings_DSEC.yaml
├── datasets
│   ├── DSEC_events_loader.py
│   ├── cityscapes_loader.py
│   ├── data_util.py
│   ├── ddd17_events_loader.py
│   ├── extract_data_tools
│   │   └── example_loader_ddd17.py
│   └── wrapper_dataloader.py
├── e2vid
│   ├── LICENSE
│   ├── README.md
│   ├── base
│   │   ├── __init__.py
│   │   └── base_model.py
│   ├── image_reconstructor.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── model.py
│   │   ├── submodules.py
│   │   └── unet.py
│   ├── options
│   │   ├── __init__.py
│   │   └── inference_options.py
│   ├── pretrained
│   │   └── .gitignore
│   ├── run_reconstruction.py
│   └── utils
│       ├── __init__.py
│       ├── event_readers.py
│       ├── inference_utils.py
│       ├── loading_utils.py
│       ├── path_utils.py
│       ├── timers.py
│       └── util.py
├── evaluation
│   ├── __init__.py
│   └── metrics.py
├── models
│   ├── __init__.py
│   ├── style_networks.py
│   └── submodules.py
├── requirements.txt
├── resources
│   └── ESS.png
├── train.py
├── training
│   ├── __init__.py
│   ├── base_trainer.py
│   ├── ess_supervised_trainer.py
│   └── ess_trainer.py
└── utils
    ├── __init__.py
    ├── labels.py
    ├── loss_functions.py
    ├── radam.py
    ├── saver.py
    └── viz_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | log/
3 |
4 | #data loader test
5 | cityscapes_test.py
6 | ddd17_test.py
7 | DSEC_align.py
8 | DSEC_final.py
9 | DSEC_final_UDA.py
10 | DSEC_hdr.py
11 | DSEC_loader_test.py
12 | DSEC_test.py
13 | DSEC_test_2.py
14 | e2vid_test.py
15 | EventScape_final.py
16 | Eventscape_final_UDA.py
17 | eventscapeloadertest.py
18 | flowtest.py
19 |
20 | config/settings_DDD17.yaml
21 | config/eventscape_settings_snaga.yaml
22 |
23 | training/img_to_events_pre_trainer.py
24 | training/img_to_events_trainer.py
25 | training/semantic_segmentation_pre_trainer.py
26 | training/semantic_segmentation_recurrent_trainer.py
27 | training/supervised_baseline.py
28 | training/ess_supervised_trainer.py
29 | training/semantic_segmentation_baseline_supervised.py
30 | training/style_latent_trainer.py
31 | training/style_object_det_trainer.py
--------------------------------------------------------------------------------
/DSEC/dataset/provider.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import torch
4 |
5 | from DSEC.dataset.sequence import Sequence
6 |
7 |
8 | class DatasetProvider:
9 | def __init__(self, dataset_path: Path, mode: str = 'train', event_representation: str = 'voxel_grid',
10 | nr_events_data: int = 5, delta_t_per_data: int = 20,
11 | nr_events_window=-1, nr_bins_per_data=5, require_paired_data=False, normalize_event=False,
12 | separate_pol=False, semseg_num_classes=11, augmentation=False,
13 | fixed_duration=False, resize=False):
14 | train_path = dataset_path / 'train'
15 | val_path = dataset_path / 'test'
16 | assert dataset_path.is_dir(), str(dataset_path)
17 | assert train_path.is_dir(), str(train_path)
18 | assert val_path.is_dir(), str(val_path)
19 |
20 | if mode == 'train':
21 | train_sequences = list()
22 | train_sequences_namelist = ['zurich_city_00_a', 'zurich_city_01_a', 'zurich_city_02_a',
23 | 'zurich_city_04_a', 'zurich_city_05_a', 'zurich_city_06_a',
24 | 'zurich_city_07_a', 'zurich_city_08_a']
25 | for child in train_path.iterdir():
26 | if any(k in str(child) for k in train_sequences_namelist):
27 | train_sequences.append(Sequence(child, 'train', event_representation, nr_events_data, delta_t_per_data,
28 | nr_events_window, nr_bins_per_data, require_paired_data, normalize_event
29 | , separate_pol, semseg_num_classes, augmentation, fixed_duration
30 | , resize=resize))
31 | else:
32 | continue
33 |
34 | self.train_dataset = torch.utils.data.ConcatDataset(train_sequences)
35 | self.train_dataset.require_paired_data = require_paired_data
36 |
37 | elif mode == 'val':
38 | val_sequences = list()
39 | val_sequences_namelist = ['zurich_city_13_a', 'zurich_city_14_c', 'zurich_city_15_a']
40 | for child in val_path.iterdir():
41 | if any(k in str(child) for k in val_sequences_namelist):
42 | val_sequences.append(Sequence(child, 'val', event_representation, nr_events_data, delta_t_per_data,
43 | nr_events_window, nr_bins_per_data, require_paired_data, normalize_event
44 | , separate_pol, semseg_num_classes, augmentation, fixed_duration
45 | , resize=resize))
46 | else:
47 | continue
48 |
49 | self.val_dataset = torch.utils.data.ConcatDataset(val_sequences)
50 | self.val_dataset.require_paired_data = require_paired_data
51 |
52 |
53 | def get_train_dataset(self):
54 | return self.train_dataset
55 |
56 | def get_val_dataset(self):
57 | # Implement this according to your needs.
58 | return self.val_dataset
59 |
60 | def get_test_dataset(self):
61 | # Implement this according to your needs.
62 | raise NotImplementedError
63 |
--------------------------------------------------------------------------------
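A minimal usage sketch for `DatasetProvider`: build the train split and wrap it in a `DataLoader`. The dataset root `/data/DSEC_Semantic` and the loader settings below are placeholders for illustration, not values prescribed by the repository.

```python
from pathlib import Path

import torch

from DSEC.dataset.provider import DatasetProvider

dsec_root = Path('/data/DSEC_Semantic')  # hypothetical dataset location

provider = DatasetProvider(dsec_root, mode='train', event_representation='voxel_grid',
                           nr_events_data=20, nr_events_window=100000,
                           nr_bins_per_data=5, semseg_num_classes=11)
train_loader = torch.utils.data.DataLoader(provider.get_train_dataset(),
                                           batch_size=8, shuffle=True, num_workers=4)

for events, label in train_loader:
    # events: [B, nr_events_data * nr_bins_per_data, H, W], label: [B, H, W]
    break
```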
/DSEC/dataset/representations.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class EventRepresentation:
5 | def convert(self, x: torch.Tensor, y: torch.Tensor, pol: torch.Tensor, time: torch.Tensor):
6 | raise NotImplementedError
7 |
8 |
9 | class VoxelGrid(EventRepresentation):
10 | def __init__(self, channels: int, height: int, width: int, normalize: bool):
11 | self.voxel_grid = torch.zeros((channels, height, width), dtype=torch.float, requires_grad=False)
12 | self.nb_channels = channels
13 | self.normalize = normalize
14 |
15 | def convert(self, x: torch.Tensor, y: torch.Tensor, pol: torch.Tensor, time: torch.Tensor):
16 | assert x.shape == y.shape == pol.shape == time.shape
17 | assert x.ndim == 1
18 |
19 | C, H, W = self.voxel_grid.shape
20 | with torch.no_grad():
21 | self.voxel_grid = self.voxel_grid.to(pol.device)
22 | voxel_grid = self.voxel_grid.clone()
23 |
24 | t_norm = time
25 | t_norm = (C - 1) * (t_norm-t_norm[0]) / (t_norm[-1]-t_norm[0])
26 |
27 | x0 = x.int()
28 | y0 = y.int()
29 | t0 = t_norm.int()
30 |
31 | value = 2*pol-1
32 |
33 | for xlim in [x0,x0+1]:
34 | for ylim in [y0,y0+1]:
35 | for tlim in [t0,t0+1]:
36 | mask = (xlim < W) & (xlim >= 0) & (ylim < H) & (ylim >= 0) & (tlim >= 0) & (tlim < self.nb_channels)
37 | interp_weights = value * (1 - (xlim-x).abs()) * (1 - (ylim-y).abs()) * (1 - (tlim - t_norm).abs())
38 |
39 | index = H * W * tlim.long() + \
40 | W * ylim.long() + \
41 | xlim.long()
42 |
43 | voxel_grid.put_(index[mask], interp_weights[mask], accumulate=True)
44 |
45 | if self.normalize:
46 | mask = torch.nonzero(voxel_grid, as_tuple=True)
47 | if mask[0].size()[0] > 0:
48 | mean = voxel_grid[mask].mean()
49 | std = voxel_grid[mask].std()
50 | if std > 0:
51 | voxel_grid[mask] = (voxel_grid[mask] - mean) / std
52 | else:
53 | voxel_grid[mask] = voxel_grid[mask] - mean
54 |
55 | return voxel_grid
56 |
--------------------------------------------------------------------------------
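A toy sketch of `VoxelGrid.convert` on a handful of synthetic events; the coordinates, polarities and timestamps are made up purely to show the expected input format.

```python
import torch

from DSEC.dataset.representations import VoxelGrid

voxel_grid = VoxelGrid(channels=5, height=480, width=640, normalize=True)

x = torch.tensor([10.5, 200.0, 638.0])   # event x coordinates (float, sub-pixel allowed)
y = torch.tensor([20.0, 100.5, 478.0])   # event y coordinates
pol = torch.tensor([1.0, 0.0, 1.0])      # polarity in {0, 1}; convert() maps it to {-1, +1}
t = torch.tensor([0.0, 0.5, 1.0])        # timestamps, here already normalized to [0, 1]

grid = voxel_grid.convert(x, y, pol, t)
print(grid.shape)  # torch.Size([5, 480, 640])
```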
/DSEC/dataset/sequence.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from https://github.com/uzh-rpg/DSEC/blob/main/scripts/dataset/sequence.py
3 | """
4 | from pathlib import Path
5 | import weakref
6 |
7 | import cv2
8 | import h5py
9 | import numpy as np
10 | import torch
11 | import torch.nn.functional as f
12 | from torch.utils.data import Dataset
13 | import torchvision.transforms as transforms
14 | from PIL import Image
15 | from joblib import Parallel, delayed
16 |
17 | from DSEC.dataset.representations import VoxelGrid
18 | from DSEC.utils.eventslicer import EventSlicer
19 | import albumentations as A
20 | import datasets.data_util as data_util
21 | import random
22 |
23 | class Sequence(Dataset):
24 | # This class assumes the following structure in a sequence directory:
25 | #
26 | # seq_name (e.g. zurich_city_00_a)
27 | # ├── semantic
28 | # │ ├── left
29 | # │ │ ├── 11classes
30 | # │ │ │ └──data
31 | # │ │ │ ├── 000000.png
32 | # │ │ │ └── ...
33 | # │ │ └── 19classes
34 | # │ │ └──data
35 | # │ │ ├── 000000.png
36 | # │ │ └── ...
37 | # │ └── timestamps.txt
38 | # └── events
39 | # └── left
40 | # ├── events.h5
41 | # └── rectify_map.h5
42 |
43 | def __init__(self, seq_path: Path, mode: str='train', event_representation: str = 'voxel_grid',
44 | nr_events_data: int = 5, delta_t_per_data: int = 20, nr_events_per_data: int = 100000,
45 | nr_bins_per_data: int = 5, require_paired_data=False, normalize_event=False, separate_pol=False,
46 | semseg_num_classes: int = 11, augmentation=False, fixed_duration=False, remove_time_window: int = 250,
47 | resize=False):
48 | assert nr_bins_per_data >= 1
49 | assert seq_path.is_dir()
50 | self.sequence_name = seq_path.name
51 | self.mode = mode
52 |
53 | # Save output dimensions
54 | self.height = 480
55 | self.width = 640
56 | self.resize = resize
57 | self.shape_resize = None
58 | if self.resize:
59 | self.shape_resize = [448, 640]
60 |
61 | # Set event representation
62 | self.nr_events_data = nr_events_data
63 | self.num_bins = nr_bins_per_data
64 | assert nr_events_per_data > 0
65 | self.nr_events_per_data = nr_events_per_data
66 | self.event_representation = event_representation
67 | self.separate_pol = separate_pol
68 | self.normalize_event = normalize_event
69 | self.voxel_grid = VoxelGrid(self.num_bins, self.height, self.width, normalize=self.normalize_event)
70 |
71 | self.locations = ['left']
72 | self.semseg_num_classes = semseg_num_classes
73 | self.augmentation = augmentation
74 |
75 | # Save delta timestamp
76 | self.fixed_duration = fixed_duration
77 | if self.fixed_duration:
78 | delta_t_ms = nr_events_data * delta_t_per_data
79 | self.delta_t_us = delta_t_ms * 1000
80 | self.remove_time_window = remove_time_window
81 |
82 | self.require_paired_data = require_paired_data
83 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
84 |
85 | # load timestamps
86 | self.timestamps = np.loadtxt(str(seq_path / 'semantic' / 'timestamps.txt'), dtype='int64')
87 |
88 | # load label paths
89 | if self.semseg_num_classes == 11:
90 | label_dir = seq_path / 'semantic' / '11classes' / 'data'
91 | elif self.semseg_num_classes == 19:
92 | label_dir = seq_path / 'semantic' / '19classes' / 'data'
93 | else:
94 | raise ValueError
95 | assert label_dir.is_dir()
96 | label_pathstrings = list()
97 | for entry in label_dir.iterdir():
98 | assert str(entry.name).endswith('.png')
99 | label_pathstrings.append(str(entry))
100 | label_pathstrings.sort()
101 | self.label_pathstrings = label_pathstrings
102 |
103 | assert len(self.label_pathstrings) == self.timestamps.size
104 |
105 | # load images paths
106 | if self.require_paired_data:
107 | img_dir = seq_path / 'images'
108 | img_left_dir = img_dir / 'left' / 'ev_inf'
109 | assert img_left_dir.is_dir()
110 | img_left_pathstrings = list()
111 | for entry in img_left_dir.iterdir():
112 | assert str(entry.name).endswith('.png')
113 | img_left_pathstrings.append(str(entry))
114 | img_left_pathstrings.sort()
115 | self.img_left_pathstrings = img_left_pathstrings
116 |
117 | assert len(self.img_left_pathstrings) == self.timestamps.size
118 |
119 | # Remove several label paths and corresponding timestamps in the remove_time_window.
120 | # This is necessary because we do not have enough events before the first label.
121 | self.timestamps = self.timestamps[(self.remove_time_window // 100 + 1) * 2:]
122 | del self.label_pathstrings[:(self.remove_time_window // 100 + 1) * 2]
123 | assert len(self.label_pathstrings) == self.timestamps.size
124 | if self.require_paired_data:
125 | del self.img_left_pathstrings[:(self.remove_time_window // 100 + 1) * 2]
126 | assert len(self.img_left_pathstrings) == self.timestamps.size
127 |
128 | self.h5f = dict()
129 | self.rectify_ev_maps = dict()
130 | self.event_slicers = dict()
131 |
132 | ev_dir = seq_path / 'events'
133 | for location in self.locations:
134 | ev_dir_location = ev_dir / location
135 | ev_data_file = ev_dir_location / 'events.h5'
136 | ev_rect_file = ev_dir_location / 'rectify_map.h5'
137 |
138 | h5f_location = h5py.File(str(ev_data_file), 'r')
139 | self.h5f[location] = h5f_location
140 | self.event_slicers[location] = EventSlicer(h5f_location)
141 | with h5py.File(str(ev_rect_file), 'r') as h5_rect:
142 | self.rectify_ev_maps[location] = h5_rect['rectify_map'][()]
143 |
144 | def events_to_voxel_grid(self, x, y, p, t):
145 | t = (t - t[0]).astype('float32')
146 | t = (t/t[-1])
147 | x = x.astype('float32')
148 | y = y.astype('float32')
149 | pol = p.astype('float32')
150 | return self.voxel_grid.convert(
151 | torch.from_numpy(x),
152 | torch.from_numpy(y),
153 | torch.from_numpy(pol),
154 | torch.from_numpy(t))
155 |
156 | def getHeightAndWidth(self):
157 | return self.height, self.width
158 |
159 | @staticmethod
160 | def get_disparity_map(filepath: Path):
161 | assert filepath.is_file()
162 | disp_16bit = cv2.imread(str(filepath), cv2.IMREAD_ANYDEPTH)
163 | return disp_16bit.astype('float32')/256
164 |
165 | @staticmethod
166 | def get_img(filepath: Path, shape_resize=None):
167 | assert filepath.is_file()
168 | img = Image.open(str(filepath))
169 | if shape_resize is not None:
170 | img = img.resize((shape_resize[1], shape_resize[0]))
171 | img_transform = transforms.Compose([
172 | transforms.Grayscale(),
173 | transforms.ToTensor()
174 | ])
175 | img_tensor = img_transform(img)
176 | return img_tensor
177 |
178 | @staticmethod
179 | def get_label(filepath: Path):
180 | assert filepath.is_file()
181 | label = Image.open(str(filepath))
182 | label = np.array(label)
183 | return label
184 |
185 | @staticmethod
186 | def close_callback(h5f_dict):
187 | for k, h5f in h5f_dict.items():
188 | h5f.close()
189 |
190 | def __len__(self):
191 | return (self.timestamps.size + 1) // 2
192 |
193 | def rectify_events(self, x: np.ndarray, y: np.ndarray, location: str):
194 | assert location in self.locations
195 | # From distorted to undistorted
196 | rectify_map = self.rectify_ev_maps[location]
197 | assert rectify_map.shape == (self.height, self.width, 2), rectify_map.shape
198 | assert x.max() < self.width
199 | assert y.max() < self.height
200 | return rectify_map[y, x]
201 |
202 | def generate_event_tensor(self, job_id, events, event_tensor, nr_events_per_data):
203 | id_start = job_id * nr_events_per_data
204 | id_end = (job_id + 1) * nr_events_per_data
205 | events_temp = events[id_start:id_end]
206 | event_representation = self.events_to_voxel_grid(events_temp[:, 0], events_temp[:, 1], events_temp[:, 3],
207 | events_temp[:, 2])
208 | event_tensor[(job_id * self.num_bins):((job_id+1) * self.num_bins), :, :] = event_representation
209 |
210 | def __getitem__(self, index):
211 | label_path = Path(self.label_pathstrings[index * 2])
212 | if self.resize:
213 | segmentation_mask = cv2.imread(str(label_path), 0)
214 | segmentation_mask = cv2.resize(segmentation_mask, (self.shape_resize[1], self.shape_resize[0]),
215 | interpolation=cv2.INTER_NEAREST)
216 | label = np.array(segmentation_mask)
217 | else:
218 | label = self.get_label(label_path)
219 |
220 | ts_end = self.timestamps[index * 2]
221 |
222 | output = {}
223 | for location in self.locations:
224 | if self.fixed_duration:
225 | ts_start = ts_end - self.delta_t_us
226 | event_tensor = None
227 | self.delta_t_per_data_us = self.delta_t_us / self.nr_events_data
228 | for i in range(self.nr_events_data):
229 | t_s = ts_start + i * self.delta_t_per_data_us
230 | t_end = ts_start + (i+1) * self.delta_t_per_data_us
231 | event_data = self.event_slicers[location].get_events(t_s, t_end)
232 |
233 | p = event_data['p']
234 | t = event_data['t']
235 | x = event_data['x']
236 | y = event_data['y']
237 |
238 | xy_rect = self.rectify_events(x, y, location)
239 | x_rect = xy_rect[:, 0]
240 | y_rect = xy_rect[:, 1]
241 |
242 | if self.event_representation == 'voxel_grid':
243 | event_representation = self.events_to_voxel_grid(x_rect, y_rect, p, t)
244 | else:
245 | events = np.stack([x_rect, y_rect, t, p], axis=1)
246 | event_representation = data_util.generate_input_representation(events, self.event_representation,
247 | (self.height, self.width))
248 | event_representation = torch.from_numpy(event_representation).type(torch.FloatTensor)
249 |
250 | if event_tensor is None:
251 | event_tensor = event_representation
252 | else:
253 | event_tensor = torch.cat([event_tensor, event_representation], dim=0)
254 |
255 | else:
256 | num_bins_total = self.nr_events_data * self.num_bins
257 | event_tensor = torch.zeros((num_bins_total, self.height, self.width))
258 | self.nr_events = self.nr_events_data * self.nr_events_per_data
259 | event_data = self.event_slicers[location].get_events_fixed_num(ts_end, self.nr_events)
260 |
261 | if self.nr_events >= event_data['t'].size:
262 | start_index = 0
263 | else:
264 | start_index = -self.nr_events
265 |
266 | p = event_data['p'][start_index:]
267 | t = event_data['t'][start_index:]
268 | x = event_data['x'][start_index:]
269 | y = event_data['y'][start_index:]
270 | nr_events_loaded = t.size
271 |
272 | xy_rect = self.rectify_events(x, y, location)
273 | x_rect = xy_rect[:, 0]
274 | y_rect = xy_rect[:, 1]
275 |
276 | nr_events_temp = nr_events_loaded // self.nr_events_data
277 | events = np.stack([x_rect, y_rect, t, p], axis=-1)
278 | Parallel(n_jobs=8, backend="threading")(
279 | delayed(self.generate_event_tensor)(i, events, event_tensor, nr_events_temp) for i in range(self.nr_events_data))
280 |
281 | # remove 40 bottom rows
282 | event_tensor = event_tensor[:, :-40, :]
283 |
284 | if self.resize:
285 | event_tensor = f.interpolate(event_tensor.unsqueeze(0),
286 | size=(self.shape_resize[0], self.shape_resize[1]),
287 | mode='bilinear', align_corners=True).squeeze(0)
288 |
289 | label_tensor = torch.from_numpy(label).long()
290 |
291 | if self.augmentation:
292 | value_flip = round(random.random())
293 | if value_flip > 0.5:
294 | event_tensor = torch.flip(event_tensor, [2])
295 | label_tensor = torch.flip(label_tensor, [1])
296 |
297 | if 'representation' not in output:
298 | output['representation'] = dict()
299 | output['representation'][location] = event_tensor
300 |
301 | if self.require_paired_data:
302 | img_left_path = Path(self.img_left_pathstrings[index * 2])
303 | output['img_left'] = self.get_img(img_left_path, self.shape_resize)
304 | return output['representation']['left'], output['img_left'], label_tensor
305 | return output['representation']['left'], label_tensor
306 |
307 |
308 |
--------------------------------------------------------------------------------
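A short sketch of reading one sample directly from a `Sequence`; the sequence directory is a placeholder and must follow the layout documented in the class header above.

```python
from pathlib import Path

from DSEC.dataset.sequence import Sequence

seq_path = Path('/data/DSEC_Semantic/train/zurich_city_00_a')  # hypothetical
seq = Sequence(seq_path, mode='train', event_representation='voxel_grid',
               nr_events_data=20, nr_events_per_data=100000, nr_bins_per_data=5,
               semseg_num_classes=11, fixed_duration=False)

events, label = seq[0]
# events: [nr_events_data * nr_bins_per_data, 440, 640] (the 40 bottom rows are cropped)
# label:  the semantic label map as a LongTensor
print(events.shape, label.shape)
```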
/DSEC/dataset/sequence_recurrent.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import weakref
3 |
4 | import cv2
5 | import h5py
6 | import numpy as np
7 | import torch
8 | from torch.utils.data import Dataset
9 | import torchvision.transforms as transforms
10 | from PIL import Image
11 |
12 | from DSEC.dataset.representations import VoxelGrid
13 | from DSEC.utils.eventslicer import EventSlicer
14 | import albumentations as A
15 |
16 |
17 | class SequenceRecurrent(Dataset):
18 | # NOTE: This is just an EXAMPLE class for convenience. Adapt it to your case.
19 | # In this example, we use the voxel grid representation.
20 | #
21 | # This class assumes the following structure in a sequence directory:
22 | #
23 | # seq_name (e.g. zurich_city_11_a)
24 | # ├── disparity
25 | # │ ├── event
26 | # │ │ ├── 000000.png
27 | # │ │ └── ...
28 | # │ └── timestamps.txt
29 | # └── events
30 | # ├── left
31 | # │ ├── events.h5
32 | # │ └── rectify_map.h5
33 | # └── right
34 | # ├── events.h5
35 | # └── rectify_map.h5
36 |
37 | def __init__(self, seq_path: Path, mode: str='train', event_representation: str = 'voxel_grid',
38 | nr_events_data: int = 5, delta_t_per_data: int = 20, nr_events_per_data: int = 100000,
39 | nr_bins_per_data: int = 5, require_paired_data=False, normalize_event=False, separate_pol=False,
40 | semseg_num_classes: int = 11, augmentation=True, fixed_duration=False, loading_time_window: int = 250):
41 | assert nr_bins_per_data >= 1
42 | # assert delta_t_ms <= 200, 'adapt this code, if duration is higher than 100 ms'
43 | assert seq_path.is_dir()
44 |
45 | # NOTE: Adapt this code according to the present mode (e.g. train, val or test).
46 | self.mode = mode
47 | self.augmentation = augmentation
48 |
49 | # Save output dimensions
50 | self.height = 480
51 | self.width = 640
52 |
53 | # Set event representation
54 | self.nr_events_data = nr_events_data
55 | self.num_bins = nr_bins_per_data
56 | assert nr_events_per_data > 0
57 | self.nr_events_per_data = nr_events_per_data
58 | self.event_representation = event_representation
59 | self.separate_pol = separate_pol
60 | self.normalize_event = normalize_event
61 | self.voxel_grid = VoxelGrid(self.num_bins, self.height, self.width, normalize=self.normalize_event)
62 |
63 | self.locations = ['left'] # 'right'
64 | self.semseg_num_classes = semseg_num_classes
65 |
66 | # Save delta timestamp in ms
67 | self.fixed_duration = fixed_duration
68 | if self.fixed_duration:
69 | delta_t_ms = nr_events_data * delta_t_per_data
70 | else:
71 | delta_t_ms = loading_time_window
72 | self.delta_t_us = delta_t_ms * 1000
73 |
74 | self.require_paired_data = require_paired_data
75 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
76 |
77 | # load disparity timestamps
78 | # disp_dir = seq_path / 'disparity'
79 | img_dir = seq_path / 'images'
80 | assert img_dir.is_dir()
81 |
82 | self.timestamps = np.loadtxt(img_dir / 'left' / 'exposure_timestamps.txt', comments='#', delimiter=',', dtype='int64')[:, 1]
83 |
84 | # load images paths
85 | if self.require_paired_data:
86 | img_left_dir = img_dir / 'left' / 'ev_inf'
87 | assert img_left_dir.is_dir()
88 | img_left_pathstrings = list()
89 | for entry in img_left_dir.iterdir():
90 | assert str(entry.name).endswith('.png')
91 | img_left_pathstrings.append(str(entry))
92 | img_left_pathstrings.sort()
93 | self.img_left_pathstrings = img_left_pathstrings
94 |
95 | assert len(self.img_left_pathstrings) == self.timestamps.size
96 |
97 | if self.mode == 'val':
98 | if self.semseg_num_classes == 11:
99 | label_dir = seq_path / 'semantic' / '11classes' / 'data'
100 | elif self.semseg_num_classes == 19:
101 | label_dir = seq_path / 'semantic' / '19classes' / 'data'
102 | elif self.semseg_num_classes == 6:
103 | label_dir = seq_path / 'semantic' / '6classes' / 'data'
104 | else:
105 | raise ValueError
106 | assert label_dir.is_dir()
107 | label_pathstrings = list()
108 | for entry in label_dir.iterdir():
109 | assert str(entry.name).endswith('.png')
110 | label_pathstrings.append(str(entry))
111 | label_pathstrings.sort()
112 | self.label_pathstrings = label_pathstrings
113 |
114 | assert len(self.label_pathstrings) == self.timestamps.size
115 |
116 | # Remove first disparity path and corresponding timestamp.
117 | # This is necessary because we do not have events before the first disparity map.
118 | # assert int(Path(self.disp_gt_pathstrings[0]).stem) == 0
119 | # del self.disp_gt_pathstrings[:(delta_t_ms // 100 + 1)]
120 | self.timestamps = self.timestamps[(delta_t_ms // 50 + 1):]
121 | if self.require_paired_data:
122 | del self.img_left_pathstrings[:(delta_t_ms // 50 + 1)]
123 | assert len(self.img_left_pathstrings) == self.timestamps.size
124 | if self.mode == 'val':
125 | del self.label_pathstrings[:(delta_t_ms // 50 + 1)]
126 | assert len(self.img_left_pathstrings) == len(self.label_pathstrings)
127 |
128 | self.h5f = dict()
129 | self.rectify_ev_maps = dict()
130 | self.event_slicers = dict()
131 |
132 | ev_dir = seq_path / 'events'
133 | for location in self.locations:
134 | ev_dir_location = ev_dir / location
135 | ev_data_file = ev_dir_location / 'events.h5'
136 | ev_rect_file = ev_dir_location / 'rectify_map.h5'
137 |
138 | h5f_location = h5py.File(str(ev_data_file), 'r')
139 | self.h5f[location] = h5f_location
140 | self.event_slicers[location] = EventSlicer(h5f_location)
141 | with h5py.File(str(ev_rect_file), 'r') as h5_rect:
142 | self.rectify_ev_maps[location] = h5_rect['rectify_map'][()]
143 |
144 | def events_to_voxel_grid(self, x, y, p, t, device: str='cpu'):
145 | t = (t - t[0]).astype('float32')
146 | t = (t/t[-1])
147 | x = x.astype('float32')
148 | y = y.astype('float32')
149 | pol = p.astype('float32')
150 | return self.voxel_grid.convert(
151 | torch.from_numpy(x),
152 | torch.from_numpy(y),
153 | torch.from_numpy(pol),
154 | torch.from_numpy(t))
155 |
156 | def getHeightAndWidth(self):
157 | return self.height, self.width
158 |
159 | @staticmethod
160 | def get_disparity_map(filepath: Path):
161 | assert filepath.is_file()
162 | disp_16bit = cv2.imread(str(filepath), cv2.IMREAD_ANYDEPTH)
163 | return disp_16bit.astype('float32')/256
164 |
165 | @staticmethod
166 | def get_img(filepath: Path):
167 | assert filepath.is_file()
168 | img = Image.open(str(filepath))
169 | img_transform = transforms.Compose([
170 | transforms.Grayscale(),
171 | transforms.ToTensor()
172 | ])
173 | img_tensor = img_transform(img)
174 | return img_tensor
175 |
176 | @staticmethod
177 | def get_label(filepath: Path):
178 | assert filepath.is_file()
179 | label = Image.open(str(filepath))
180 | label = np.array(label)
181 | # label_tensor = torch.from_numpy(label).long()
182 | return label
183 |
184 | @staticmethod
185 | def close_callback(h5f_dict):
186 | for k, h5f in h5f_dict.items():
187 | h5f.close()
188 |
189 | def __len__(self):
190 | if self.fixed_duration:
191 | return self.timestamps.size
192 | else:
193 | return self.event_slicers['left'].events['t'].size // self.nr_events_per_data
194 |
195 | def rectify_events(self, x: np.ndarray, y: np.ndarray, location: str):
196 | assert location in self.locations
197 | # From distorted to undistorted
198 | rectify_map = self.rectify_ev_maps[location]
199 | assert rectify_map.shape == (self.height, self.width, 2), rectify_map.shape
200 | assert x.max() < self.width
201 | assert y.max() < self.height
202 | return rectify_map[y, x]
203 |
204 | def __getitem__(self, index):
205 | if self.augmentation:
206 | transform_a = A.ReplayCompose([
207 | A.HorizontalFlip(p=0.5),
208 | # A.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0.1, p=0.5, border_mode=0),
209 | # A.PadIfNeeded(min_height=self.height-40, min_width=self.width, always_apply=True, border_mode=0),
210 | # A.RandomCrop(height=self.height-40, width=self.width, p=1)
211 | ])
212 | A_data = None
213 |
214 | label = np.zeros((self.height-40, self.width))
215 |
216 | output = {}
217 | for location in self.locations:
218 | if self.fixed_duration:
219 | if self.mode == 'val':
220 | label_path = Path(self.label_pathstrings[index])
221 | label = self.get_label(label_path)
222 |
223 | ts_end = self.timestamps[index]
224 | # ts_start should be fine (within the window as we removed the first disparity map)
225 | ts_start = ts_end - self.delta_t_us
226 |
227 | event_tensor = None
228 | self.delta_t_per_data_us = self.delta_t_us / self.nr_events_data
229 | for i in range(self.nr_events_data):
230 | t_s = ts_start + i * self.delta_t_per_data_us
231 | t_end = ts_start + (i+1) * self.delta_t_per_data_us
232 | event_data = self.event_slicers[location].get_events(t_s, t_end)
233 |
234 | p = event_data['p']
235 | t = event_data['t']
236 | x = event_data['x']
237 | y = event_data['y']
238 |
239 | xy_rect = self.rectify_events(x, y, location)
240 | x_rect = xy_rect[:, 0]
241 | y_rect = xy_rect[:, 1]
242 |
243 | event_representation = self.events_to_voxel_grid(x_rect, y_rect, p, t)
244 |
245 | if event_tensor is None:
246 | event_tensor = event_representation
247 | else:
248 | event_tensor = torch.cat([event_tensor, event_representation], dim=0)
249 | else:
250 | num_bins_total = self.nr_events_data * self.num_bins
251 | self.nr_events = self.nr_events_data * self.nr_events_per_data
252 | t_start_us_idx = index * self.nr_events
253 | t_end_us_idx = t_start_us_idx + self.nr_events
254 | event_data = self.event_slicers[location].get_events_fixed_num_recurrent(t_start_us_idx, t_end_us_idx)
255 |
256 | p = event_data['p']
257 | t = event_data['t']
258 | x = event_data['x']
259 | y = event_data['y']
260 |
261 | xy_rect = self.rectify_events(x, y, location)
262 | x_rect = xy_rect[:, 0]
263 | y_rect = xy_rect[:, 1]
264 |
265 | event_representation = self.events_to_voxel_grid(x_rect, y_rect, p, t)
266 |
267 | event_tensor = event_representation
268 |
269 |
270 | event_tensor = event_tensor[:, :-40, :] # remove 40 bottom rows
271 |
272 | if self.augmentation:
273 | if A_data is None:
274 | A_data = transform_a(image=event_tensor[0, :, :].numpy(), mask=label)
275 | label = A_data['mask']
276 | for k in range(event_tensor.shape[0]):
277 | event_tensor[k, :, :] = torch.from_numpy(
278 | A.ReplayCompose.replay(A_data['replay'], image=event_tensor[k, :, :].numpy())['image'])
279 |
280 | if 'representation' not in output:
281 | output['representation'] = dict()
282 | output['representation'][location] = event_tensor
283 |
284 | label_tensor = torch.from_numpy(label).long()
285 |
286 | if self.require_paired_data:
287 | img_left_path = Path(self.img_left_pathstrings[index])
288 | # output['img_left'] = self.get_img(img_left_path)[:, :-40, :] # remove 40 bottom rows
289 | output['img_left'] = self.get_img(img_left_path)
290 | return output['representation']['left'], output['img_left'], label_tensor
291 | return output['representation']['left'], label_tensor
292 |
--------------------------------------------------------------------------------
/DSEC/dataset/visualization.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import matplotlib as mpl
3 | import matplotlib.cm as cm
4 | import numpy as np
5 |
6 |
7 | def disp_img_to_rgb_img(disp_array: np.ndarray):
8 | disp_pixels = np.argwhere(disp_array > 0)
9 | u_indices = disp_pixels[:, 1]
10 | v_indices = disp_pixels[:, 0]
11 | disp = disp_array[v_indices, u_indices]
12 | max_disp = 80
13 |
14 | norm = mpl.colors.Normalize(vmin=0, vmax=max_disp, clip=True)
15 | mapper = cm.ScalarMappable(norm=norm, cmap='inferno')
16 |
17 | disp_color = mapper.to_rgba(disp)[..., :3]
18 | output_image = np.zeros((disp_array.shape[0], disp_array.shape[1], 3))
19 | output_image[v_indices, u_indices, :] = disp_color
20 | output_image = (255 * output_image).astype("uint8")
21 | output_image = cv2.cvtColor(output_image, cv2.COLOR_RGB2BGR)
22 | return output_image
23 |
24 | def show_image(image):
25 | cv2.namedWindow('viz', cv2.WND_PROP_FULLSCREEN)
26 | cv2.imshow('viz', image)
27 | cv2.waitKey(0)
28 |
29 | def get_disp_overlay(image_1c, disp_rgb_image, height, width):
30 | image = np.repeat(image_1c[..., np.newaxis], 3, axis=2)
31 | overlay = cv2.addWeighted(image, 0.1, disp_rgb_image, 0.9, 0)
32 | return overlay
33 |
34 | def show_disp_overlay(image_1c, disp_rgb_image, height, width):
35 | overlay = get_disp_overlay(image_1c, disp_rgb_image, height, width)
36 | show_image(overlay)
37 |
--------------------------------------------------------------------------------
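A small sketch exercising the disparity visualization helpers on synthetic data; it opens a blocking OpenCV window, so run it in an environment with display support.

```python
import numpy as np

from DSEC.dataset.visualization import disp_img_to_rgb_img, show_disp_overlay

disp = np.zeros((480, 640), dtype='float32')
disp[200:280, 100:400] = 40.0                    # synthetic disparity patch

rgb = disp_img_to_rgb_img(disp)                  # BGR uint8, black where disparity == 0
gray = np.full((480, 640), 128, dtype='uint8')   # stand-in for a grayscale frame
show_disp_overlay(gray, rgb, 480, 640)           # blocking OpenCV window
```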
/DSEC/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uzh-rpg/ess/46bf1eed677d869c4733c89d2284e43ea27f97bd/DSEC/utils/__init__.py
--------------------------------------------------------------------------------
/DSEC/utils/eventslicer.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Dict, Tuple
3 |
4 | import h5py
5 | import hdf5plugin
6 | from numba import jit
7 | import numpy as np
8 |
9 |
10 | class EventSlicer:
11 | def __init__(self, h5f: h5py.File):
12 | self.h5f = h5f
13 |
14 | self.events = dict()
15 | for dset_str in ['p', 'x', 'y', 't']:
16 | self.events[dset_str] = self.h5f['events/{}'.format(dset_str)]
17 |
18 | # This is the mapping from milliseconds to event index:
19 | # It is defined such that
20 | # (1) t[ms_to_idx[ms]] >= ms*1000
21 | # (2) t[ms_to_idx[ms] - 1] < ms*1000
22 | # ,where 'ms' is the time in milliseconds and 't' the event timestamps in microseconds.
23 | #
24 | # As an example, given 't' and 'ms':
25 | # t: 0 500 2100 5000 5000 7100 7200 7200 8100 9000
26 | # ms: 0 1 2 3 4 5 6 7 8 9
27 | #
28 | # we get
29 | #
30 | # ms_to_idx:
31 | # 0 2 2 3 3 3 5 5 8 9
32 | self.ms_to_idx = np.asarray(self.h5f['ms_to_idx'], dtype='int64')
33 |
34 | if "t_offset" in list(h5f.keys()):
35 | self.t_offset = int(h5f['t_offset'][()])
36 | else:
37 | self.t_offset = 0
38 | self.t_final = int(self.events['t'][-1]) + self.t_offset
39 |
40 | def get_start_time_us(self):
41 | return self.t_offset
42 |
43 | def get_final_time_us(self):
44 | return self.t_final
45 |
46 | def get_events(self, t_start_us: int, t_end_us: int, max_events_per_data: int = -1) -> Dict[str, np.ndarray]:
47 | """Get events (p, x, y, t) within the specified time window
48 | Parameters
49 | ----------
50 | t_start_us: start time in microseconds
51 | t_end_us: end time in microseconds
52 | Returns
53 | -------
54 | events: dictionary of (p, x, y, t) or None if the time window cannot be retrieved
55 | """
56 | assert t_start_us < t_end_us
57 |
58 | # We assume that the times are top-off-day, hence subtract offset:
59 | t_start_us -= self.t_offset
60 | t_end_us -= self.t_offset
61 |
62 | # print(t_start_us, self.t_offset)
63 | t_start_ms, t_end_ms = self.get_conservative_window_ms(t_start_us, t_end_us)
64 | t_start_ms_idx = self.ms2idx(t_start_ms)
65 | t_end_ms_idx = self.ms2idx(t_end_ms)
66 |
67 | if t_start_ms_idx is None or t_end_ms_idx is None:
68 | print('Error', 'start', t_start_us, 'end', t_end_us)
69 | # Cannot guarantee window size anymore
70 | return None
71 |
72 | events = dict()
73 | time_array_conservative = np.asarray(self.events['t'][t_start_ms_idx:t_end_ms_idx])
74 | idx_start_offset, idx_end_offset = self.get_time_indices_offsets(time_array_conservative, t_start_us, t_end_us)
75 | t_start_us_idx = t_start_ms_idx + idx_start_offset
76 | t_end_us_idx = t_start_ms_idx + idx_end_offset
77 | # Again add t_offset to get gps time
78 | events['t'] = time_array_conservative[idx_start_offset:idx_end_offset] + self.t_offset
79 | for dset_str in ['p', 'x', 'y']:
80 | events[dset_str] = np.asarray(self.events[dset_str][t_start_us_idx:t_end_us_idx])
81 | assert events[dset_str].size == events['t'].size
82 | # if max_events_per_data != -1 and events['t'].size > max_events_per_data:
83 | # idx = np.round(np.linspace(0, events['t'].size - 1, max_events_per_data)).astype(int)
84 | # for key in events.keys():
85 | # events[key] = events[key][idx]
86 | return events
87 |
88 | def get_events_fixed_num(self, t_end_us: int, nr_events: int = 100000) -> Dict[str, np.ndarray]:
89 | """Get events (p, x, y, t) with fixed number of events
90 | Parameters
91 | ----------
92 | t_end_us: end time in microseconds
93 | nr_events: number of events to load
94 | Returns
95 | -------
96 | events: dictionary of (p, x, y, t) or None if the time window cannot be retrieved
97 | """
98 | # t_start_us = t_end_us - 1000
99 | # assert t_start_us < t_end_us
100 |
101 | # We assume that the times are top-off-day, hence subtract offset:
102 | # t_start_us -= self.t_offset
103 | t_end_us -= self.t_offset
104 |
105 | # print(t_start_us, self.t_offset)
106 | t_end_lower_ms, t_end_upper_ms = self.get_conservative_ms(t_end_us)
107 | t_end_lower_ms_idx = self.ms2idx(t_end_lower_ms)
108 | t_end_upper_ms_idx = self.ms2idx(t_end_upper_ms)
109 |
110 | if t_end_lower_ms_idx is None or t_end_upper_ms_idx is None:
111 | # Cannot guarantee window size anymore
112 | return None
113 |
114 | events = dict()
115 | time_array_conservative = np.asarray(self.events['t'][t_end_lower_ms_idx:t_end_upper_ms_idx])
116 | _, idx_end_offset = self.get_time_indices_offsets(time_array_conservative, t_end_us, t_end_us)
117 | t_end_us_idx = t_end_lower_ms_idx + idx_end_offset
118 | t_start_us_idx = t_end_us_idx - nr_events
119 | if t_start_us_idx < 0:
120 | t_start_us_idx = 0
121 |
122 | for dset_str in self.events.keys():
123 | events[dset_str] = np.asarray(self.events[dset_str][t_start_us_idx:t_end_us_idx])
124 |
125 | return events
126 |
127 | def get_events_fixed_num_recurrent(self, t_start_us_idx: int, t_end_us_idx: int) -> Dict[str, np.ndarray]:
128 | """Get events (p, x, y, t) with fixed number of events
129 | Parameters
130 | ----------
131 | t_start_us_idx: start id
132 | t_end_us_idx: end id
133 | Returns
134 | -------
135 | events: dictionary of (p, x, y, t) or None if the time window cannot be retrieved
136 | """
137 | assert t_start_us_idx < t_end_us_idx
138 |
139 | events = dict()
140 | for dset_str in self.events.keys():
141 | events[dset_str] = np.asarray(self.events[dset_str][t_start_us_idx:t_end_us_idx])
142 |
143 | return events
144 |
145 |
146 | @staticmethod
147 | def get_conservative_window_ms(ts_start_us: int, ts_end_us) -> Tuple[int, int]:
148 | """Compute a conservative time window of time with millisecond resolution.
149 | We have a time to index mapping for each millisecond. Hence, we need
150 | to compute the lower and upper millisecond to retrieve events.
151 | Parameters
152 | ----------
153 | ts_start_us: start time in microseconds
154 | ts_end_us: end time in microseconds
155 | Returns
156 | -------
157 | window_start_ms: conservative start time in milliseconds
158 | window_end_ms: conservative end time in milliseconds
159 | """
160 | assert ts_end_us > ts_start_us
161 | window_start_ms = math.floor(ts_start_us/1000)
162 | window_end_ms = math.ceil(ts_end_us/1000)
163 | return window_start_ms, window_end_ms
164 |
165 | @staticmethod
166 | def get_conservative_ms(ts_us: int) -> Tuple[int, int]:
167 | """Convert time in microseconds into milliseconds
168 | ----------
169 | ts_us: time in microseconds
170 | Returns
171 | -------
172 | ts_lower_ms: lower millisecond
173 | ts_upper_ms: upper millisecond
174 | """
175 | ts_lower_ms = math.floor(ts_us / 1000)
176 | ts_upper_ms = math.ceil(ts_us / 1000)
177 | return ts_lower_ms, ts_upper_ms
178 |
179 | @staticmethod
180 | @jit(nopython=True)
181 | def get_time_indices_offsets(
182 | time_array: np.ndarray,
183 | time_start_us: int,
184 | time_end_us: int) -> Tuple[int, int]:
185 | """Compute index offset of start and end timestamps in microseconds
186 | Parameters
187 | ----------
188 | time_array: timestamps (in us) of the events
189 | time_start_us: start timestamp (in us)
190 | time_end_us: end timestamp (in us)
191 | Returns
192 | -------
193 | idx_start: Index within this array corresponding to time_start_us
194 | idx_end: Index within this array corresponding to time_end_us
195 | such that (in non-edge cases)
196 | time_array[idx_start] >= time_start_us
197 | time_array[idx_end] >= time_end_us
198 | time_array[idx_start - 1] < time_start_us
199 | time_array[idx_end - 1] < time_end_us
200 | this means that
201 | time_start_us <= time_array[idx_start:idx_end] < time_end_us
202 | """
203 |
204 | assert time_array.ndim == 1
205 |
206 | idx_start = -1
207 | if time_array[-1] < time_start_us:
208 | # This can happen in extreme corner cases. E.g.
209 | # time_array[0] = 1016
210 | # time_array[-1] = 1984
211 | # time_start_us = 1990
212 | # time_end_us = 2000
213 |
214 | # Return same index twice: array[x:x] is empty.
215 | return time_array.size, time_array.size
216 | else:
217 | for idx_from_start in range(0, time_array.size, 1):
218 | if time_array[idx_from_start] >= time_start_us:
219 | idx_start = idx_from_start
220 | break
221 | assert idx_start >= 0
222 |
223 | idx_end = time_array.size
224 | for idx_from_end in range(time_array.size - 1, -1, -1):
225 | if time_array[idx_from_end] >= time_end_us:
226 | idx_end = idx_from_end
227 | else:
228 | break
229 |
230 | assert time_array[idx_start] >= time_start_us
231 | if idx_end < time_array.size:
232 | assert time_array[idx_end] >= time_end_us
233 | if idx_start > 0:
234 | assert time_array[idx_start - 1] < time_start_us
235 | if idx_end > 0:
236 | assert time_array[idx_end - 1] < time_end_us
237 | return idx_start, idx_end
238 |
239 | def ms2idx(self, time_ms: int) -> int:
240 | assert time_ms >= 0
241 | if time_ms >= self.ms_to_idx.size:
242 | return None
243 | return self.ms_to_idx[time_ms]
244 |
--------------------------------------------------------------------------------
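A sketch of slicing a 50 ms event window with `EventSlicer`; the `events.h5` path is a placeholder, and `get_events` may return `None` if the requested window cannot be served.

```python
import h5py

from DSEC.utils.eventslicer import EventSlicer

events_file = '/data/DSEC_Semantic/train/zurich_city_00_a/events/left/events.h5'  # hypothetical
with h5py.File(events_file, 'r') as h5f:
    slicer = EventSlicer(h5f)
    t0 = slicer.get_start_time_us()
    events = slicer.get_events(t0 + 1_000_000, t0 + 1_050_000)  # 50 ms window, 1 s after start
    if events is not None:
        print(events['t'].size, 'events in the window')
```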
/DSEC/utils/viz_utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | import numpy as np
4 | import torchvision.utils
5 | import torch.nn.functional as F
6 | import matplotlib.pyplot as plt
7 | import itertools
8 |
9 |
10 | def createRGBGrid(tensor_list, nrow):
11 | """Creates a grid of rgb values based on the tensor stored in tensor_list"""
12 | vis_tensor_list = []
13 | for tensor in tensor_list:
14 | vis_tensor_list.append(visualizeTensors(tensor))
15 |
16 | return torchvision.utils.make_grid(torch.cat(vis_tensor_list, dim=0), nrow=nrow)
17 |
18 |
19 | def createRGBImage(tensor, separate_pol=True):
20 | """Creates a grid of rgb values based on the tensor stored in tensor_list"""
21 | if tensor.shape[1] == 3:
22 | return tensor
23 | elif tensor.shape[1] == 1:
24 | return tensor.expand(-1, 3, -1, -1)
25 | elif tensor.shape[1] == 2:
26 | return visualizeHistogram(tensor)
27 | elif tensor.shape[1] > 3:
28 | return visualizeVoxelGrid(tensor, separate_pol)
29 |
30 |
31 | def visualizeTensors(tensor):
32 | """Creates a rgb image of the given tensor. Can be event histogram, event voxel grid, grayscale and rgb."""
33 | if tensor.shape[1] == 3:
34 | return tensor
35 | elif tensor.shape[1] == 1:
36 | return tensor.expand(-1, 3, -1, -1)
37 | elif tensor.shape[1] == 2:
38 | return visualizeHistogram(tensor)
39 | elif tensor.shape[1] > 3:
40 | return visualizeVoxelGrid(tensor)
41 |
42 |
43 | def visualizeHistogram(histogram):
44 | """Visualizes the input histogram"""
45 | batch, _, height, width = histogram.shape
46 | torch_image = torch.zeros([batch, 1, height, width], device=histogram.device)
47 |
48 | return torch.cat([histogram.clamp(0, 1), torch_image], dim=1)
49 |
50 |
51 | def visualizeVoxelGrid(voxel_grid, separate_pol=True):
52 | """Visualizes the input histogram"""
53 | batch, nr_channels, height, width = voxel_grid.shape
54 | if separate_pol:
55 | pos_events_idx = nr_channels // 2
56 | temporal_scaling = torch.arange(start=1, end=pos_events_idx+1, dtype=voxel_grid.dtype,
57 | device=voxel_grid.device)[None, :, None, None] / pos_events_idx
58 | pos_voxel_grid = voxel_grid[:, :pos_events_idx] * temporal_scaling
59 | neg_voxel_grid = voxel_grid[:, pos_events_idx:] * temporal_scaling
60 |
61 | torch_image = torch.zeros([batch, 1, height, width], device=voxel_grid.device)
62 | pos_image = torch.sum(pos_voxel_grid, dim=1, keepdim=True)
63 | neg_image = torch.sum(neg_voxel_grid, dim=1, keepdim=True)
64 |
65 | return torch.cat([neg_image.clamp(0, 1), pos_image.clamp(0, 1), torch_image], dim=1)
66 |
67 | sum_events = torch.sum(voxel_grid, dim=1).detach()
68 | event_preview = torch.zeros((batch, 3, height, width))
69 | b = event_preview[:, 0, :, :]
70 | r = event_preview[:, 2, :, :]
71 | b[sum_events > 0] = 255
72 | r[sum_events < 0] = 255
73 | return event_preview
74 |
75 |
76 | def visualizeConfusionMatrix(confusion_matrix, path_name=None):
77 | """
78 | Visualizes the confustion matrix using matplotlib.
79 |
80 | :param confusion_matrix: NxN numpy array
81 | :param path_name: if no path name is given, just an image is returned
82 | """
83 | import matplotlib.pyplot as plt
84 | nr_classes = confusion_matrix.shape[0]
85 | fig, ax = plt.subplots(1, 1, figsize=(16, 16))
86 | ax.matshow(confusion_matrix)
87 | ax.plot([-0.5, nr_classes - 0.5], [-0.5, nr_classes - 0.5], '-', color='grey')
88 | ax.set_xlabel('Labels')
89 | ax.set_ylabel('Predicted')
90 |
91 | if path_name is None:
92 | fig.tight_layout(pad=0)
93 | fig.canvas.draw()
94 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
95 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
96 | plt.close()
97 |
98 | return data
99 |
100 | else:
101 | fig.savefig(path_name)
102 | plt.close()
103 |
104 |
105 | def drawBoundingBoxes(np_image, bounding_boxes, class_name=None, ground_truth=True, rescale_image=False):
106 | """
107 | Draws the bounding boxes in the image
108 |
109 | :param np_image: [H, W, C]
110 | :param bounding_boxes: list of bounding boxes with shape [x, y, width, height].
111 | :param class_name: string
112 | """
113 | np_image = np_image.astype(np.float64)
114 | resize_scale = 1.5
115 | if rescale_image:
116 | bounding_boxes[:, :4] = (bounding_boxes.astype(np.float64)[:, :4] * resize_scale)
117 | new_dim = np.array(np_image.shape[:2], dtype=np.float64) * resize_scale
118 | np_image = cv2.resize(np_image, tuple(new_dim.astype(int)[::-1]), interpolation=cv2.INTER_NEAREST)
119 |
120 | for i, bounding_box in enumerate(bounding_boxes):
121 | if bounding_box.sum() == 0:
122 | break
123 | if class_name is None:
124 | np_image = drawBoundingBox(np_image, bounding_box, ground_truth=ground_truth)
125 | else:
126 | np_image = drawBoundingBox(np_image, bounding_box, class_name[i], ground_truth)
127 |
128 | return np_image
129 |
130 |
131 | def drawBoundingBox(np_image, bounding_box, class_name=None, ground_truth=False):
132 | """
133 | Draws a bounding box in the image.
134 |
135 | :param np_image: [H, W, C]
136 | :param bounding_box: [x, y, width, height].
137 | :param class_name: string
138 | """
139 | if ground_truth:
140 | bbox_color = np.array([0, 1, 1])
141 | else:
142 | bbox_color = np.array([1, 0, 1])
143 | height, width = bounding_box[2:4]
144 |
145 | np_image[bounding_box[0], bounding_box[1]:(bounding_box[1] + width)] = bbox_color
146 | np_image[bounding_box[0]:(bounding_box[0] + height), (bounding_box[1] + width)] = bbox_color
147 | np_image[(bounding_box[0] + height), bounding_box[1]:(bounding_box[1] + width)] = bbox_color
148 | np_image[bounding_box[0]:(bounding_box[0] + height), bounding_box[1]] = bbox_color
149 |
150 | if class_name is not None:
151 | font = cv2.FONT_HERSHEY_SIMPLEX
152 | font_color = (0, 0, 0)
153 | font_scale = 0.5
154 | thickness = 1
155 | bottom_left = tuple(((bounding_box[[1, 0]] + np.array([+1, height - 2]))).astype(int))
156 |
157 | # Draw Box
158 | (text_width, text_height) = cv2.getTextSize(class_name, font, fontScale=font_scale, thickness=thickness)[0]
159 | box_coords = ((bottom_left[0], bottom_left[1] + 2),
160 | (bottom_left[0] + text_width + 2, bottom_left[1] - text_height - 2 + 2))
161 | color_format = (int(bbox_color[0]), int(bbox_color[1]), int(bbox_color[2]))
162 | # np_image = cv2.UMat(np_image)
163 | np_image = cv2.UMat(np_image).get()
164 | cv2.rectangle(np_image, box_coords[0], box_coords[1], color_format, cv2.FILLED)
165 |
166 | cv2.putText(np_image, class_name, bottom_left, font, font_scale, font_color, thickness, cv2.LINE_AA)
167 |
168 | return np_image
169 |
170 |
171 | def visualizeFlow(tensor_flow_map):
172 | """
173 | Visualizes the direction flow based on the HSV model
174 | """
175 | np_flow_map = tensor_flow_map.cpu().detach().numpy()
176 | batch_s, channel, height, width = np_flow_map.shape
177 | viz_array = np.zeros([batch_s, height, width, 3], dtype=np.uint8)
178 | hsv = np.zeros([height, width, 3], dtype=np.uint8)
179 |
180 | for i, sample_flow_map in enumerate(np_flow_map):
181 | hsv[..., 1] = 255
182 | mag, ang = cv2.cartToPolar(sample_flow_map[0, :, :], sample_flow_map[1, :, :])
183 | hsv[..., 0] = ang * 180 / np.pi / 2
184 | hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
185 |
186 | bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
187 | viz_array[i] = bgr
188 |
189 | return torch.from_numpy(viz_array.transpose([0, 3, 1, 2]) / 255.).to(tensor_flow_map.device)
190 |
191 | def create_checkerboard(N, C, H, W):
192 | cell_sz = max(min(H, W) // 32, 1)
193 | mH = (H + cell_sz - 1) // cell_sz
194 | mW = (W + cell_sz - 1) // cell_sz
195 | checkerboard = torch.full((mH, mW), 0.25, dtype=torch.float32)
196 | checkerboard[0::2, 0::2] = 0.75
197 | checkerboard[1::2, 1::2] = 0.75
198 | checkerboard = checkerboard.float().view(1, 1, mH, mW)
199 | checkerboard = F.interpolate(checkerboard, scale_factor=cell_sz, mode='nearest')
200 | checkerboard = checkerboard[:, :, :H, :W].repeat(N, C, 1, 1)
201 | return checkerboard
202 |
203 |
204 | def prepare_semseg(img, semseg_color_map, semseg_ignore_label):
205 | assert (img.dim() == 3 or img.dim() == 4 and img.shape[1] == 1) and img.dtype in (torch.int, torch.long), \
206 | f'Expecting 4D tensor with semseg classes, got {img.shape}'
207 | if img.dim() == 4:
208 | img = img.squeeze(1)
209 | colors = torch.tensor(semseg_color_map, dtype=torch.float32)
210 | assert colors.dim() == 2 and colors.shape[1] == 3
211 | if torch.max(colors) > 128:
212 | colors /= 255
213 | img = img.cpu().clone() # N x H x W
214 | N, H, W = img.shape
215 | img_color_ids = torch.unique(img)
216 | assert all(c_id == semseg_ignore_label or 0 <= c_id < len(semseg_color_map) for c_id in img_color_ids)
217 | checkerboard, mask_ignore = None, None
218 | if semseg_ignore_label in img_color_ids:
219 | checkerboard = create_checkerboard(N, 3, H, W)
220 | mask_ignore = img == semseg_ignore_label
221 | img[mask_ignore] = 0
222 | img = colors[img] # N x H x W x 3
223 | img = img.permute(0, 3, 1, 2)
224 |
225 | # checkerboard
226 | # if semseg_ignore_label in img_color_ids:
227 | # mask_ignore = mask_ignore.unsqueeze(1).repeat(1, 3, 1, 1)
228 | # img[mask_ignore] = checkerboard[mask_ignore]
229 | return img
230 |
231 |
232 | def plot_confusion_matrix(cm, classes,
233 | normalize=False,
234 | title='Confusion matrix',
235 | cmap=plt.cm.Blues):
236 | """
237 | This function prints and plots the confusion matrix.
238 | Normalization can be applied by setting `normalize=True`.
239 | """
240 | cm = cm.numpy()
241 | if normalize:
242 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
243 | print("Normalized confusion matrix")
244 | else:
245 | print('Confusion matrix, without normalization')
246 |
247 | # print(cm)
248 |
249 | fig = plt.figure()
250 |
251 | plt.imshow(cm, interpolation='nearest', cmap=cmap)
252 | plt.title(title)
253 | plt.colorbar()
254 | tick_marks = np.arange(len(classes))
255 | plt.xticks(tick_marks, classes, rotation=45)
256 | plt.yticks(tick_marks, classes)
257 |
258 | fmt = '.2f' if normalize else 'd'
259 | thresh = cm.max() / 2.
260 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
261 | plt.text(j, i, format(cm[i, j], fmt),
262 | horizontalalignment="center",
263 | color="white" if cm[i, j] > thresh else "black")
264 |
265 | plt.tight_layout()
266 | plt.ylabel('True label')
267 | plt.xlabel('Predicted label')
268 | return fig
269 |
--------------------------------------------------------------------------------
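A quick sketch of turning event tensors into RGB previews for logging, using random tensors purely for illustration.

```python
import torch

from DSEC.utils.viz_utils import createRGBGrid, createRGBImage

voxel = torch.rand(4, 10, 440, 640)           # [B, C, H, W] batch of event voxel grids
gray = torch.rand(4, 1, 440, 640)             # paired grayscale frames

rgb_events = createRGBImage(voxel)            # [4, 3, 440, 640]
grid = createRGBGrid([voxel, gray], nrow=4)   # [3, grid_H, grid_W], one row per input tensor
print(rgb_events.shape, grid.shape)
```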
/DSEC/visualization/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uzh-rpg/ess/46bf1eed677d869c4733c89d2284e43ea27f97bd/DSEC/visualization/__init__.py
--------------------------------------------------------------------------------
/DSEC/visualization/eventreader.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import weakref
3 |
4 | import h5py
5 |
6 | from utils.eventslicer import EventSlicer
7 |
8 |
9 | class EventReaderAbstract:
10 | def __init__(self, filepath: Path):
11 | assert filepath.is_file()
12 | assert filepath.name.endswith('.h5')
13 | self.h5f = h5py.File(str(filepath), 'r')
14 | self._finalizer = weakref.finalize(self, self.close_callback, self.h5f)
15 |
16 | @staticmethod
17 | def close_callback(h5f: h5py.File):
18 | h5f.close()
19 |
20 | def __enter__(self):
21 | return self
22 |
23 | def __exit__(self, exc_type, exc_value, traceback):
24 | self._finalizer()
25 |
26 | def __iter__(self):
27 | return self
28 |
29 | def __next__(self):
30 | raise NotImplementedError
31 |
32 |
33 | class EventReader(EventReaderAbstract):
34 | def __init__(self, filepath: Path, dt_milliseconds: int):
35 | super().__init__(filepath)
36 | self.event_slicer = EventSlicer(self.h5f)
37 |
38 | self.dt_us = int(dt_milliseconds * 1000)
39 | self.t_start_us = self.event_slicer.get_start_time_us()
40 | self.t_end_us = self.event_slicer.get_final_time_us()
41 |
42 | self._length = (self.t_end_us - self.t_start_us)//self.dt_us
43 |
44 | def __len__(self):
45 | return self._length
46 |
47 | def __next__(self):
48 | t_end_us = self.t_start_us + self.dt_us
49 | if t_end_us > self.t_end_us:
50 | raise StopIteration
51 | events = self.event_slicer.get_events(self.t_start_us, t_end_us)
52 | if events is None:
53 | raise StopIteration
54 |
55 | self.t_start_us = t_end_us
56 | return events
57 |
--------------------------------------------------------------------------------
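A sketch of iterating over an `events.h5` file in fixed 100 ms slices; the path is a placeholder, and since `eventreader.py` imports `utils.eventslicer`, the `DSEC` directory is assumed to be the working directory (or on `sys.path`).

```python
from pathlib import Path

from visualization.eventreader import EventReader

event_file = Path('/data/DSEC_Semantic/train/zurich_city_00_a/events/left/events.h5')  # hypothetical
with EventReader(event_file, dt_milliseconds=100) as reader:
    for events in reader:
        print(events['t'].size, 'events in this 100 ms slice')
        break
```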
/README.md:
--------------------------------------------------------------------------------
1 | # ESS: Learning Event-based Semantic Segmentation from Still Images
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | This is the code for the paper **ESS: Learning Event-based Semantic Segmentation from Still Images**
10 | ([PDF](https://rpg.ifi.uzh.ch/docs/ECCV22_Sun.pdf)) by Zhaoning Sun*, [Nico Messikommer](https://messikommernico.github.io/)*, [Daniel Gehrig](https://danielgehrig18.github.io), and [Davide Scaramuzza](http://rpg.ifi.uzh.ch/people_scaramuzza.html). For an overview of our method, check out our [video](https://youtu.be/Tby5c9IDsDc).
11 |
12 | If you use any of this code, please cite the following publication:
13 |
14 | ```bibtex
15 | @Article{Sun22eccv,
16 | author = {Zhaoning Sun* and Nico Messikommer* and Daniel Gehrig and Davide Scaramuzza},
17 | title = {ESS: Learning Event-based Semantic Segmentation from Still Images},
18 | journal = {European Conference on Computer Vision. (ECCV)},
19 | year = {2022},
20 | }
21 | ```
22 |
23 | ## Abstract
24 | Retrieving accurate semantic information in challenging high dynamic range (HDR) and high-speed conditions
25 | remains an open challenge for image-based algorithms due to severe image degradations. Event cameras promise
26 | to address these challenges since they feature a much higher dynamic range and are resilient to motion blur.
27 | Nonetheless, semantic segmentation with event cameras is still in its infancy which is chiefly due to the
28 | novelty of the sensor, and the lack of high-quality, labeled datasets. In this work, we introduce ESS, which
29 | tackles this problem by directly transferring the semantic segmentation task from existing labeled image
30 | datasets to unlabeled events via unsupervised domain adaptation (UDA). Compared to existing UDA methods,
31 | our approach aligns recurrent, motion-invariant event embeddings with image embeddings. For this reason, our
32 | method neither requires video data nor per-pixel alignment between images and events and, crucially, does not
33 | need to hallucinate motion from still images. Additionally, to spur further research in event-based semantic
34 | segmentation, we introduce DSEC-Semantic, the first large-scale event-based dataset with fine-grained labels.
35 | We show that using image labels alone, ESS outperforms existing UDA approaches, and when combined with event
36 | labels, it even outperforms state-of-the-art supervised approaches on both DDD17 and DSEC-Semantic. Finally,
37 | ESS is general-purpose, which unlocks the vast amount of existing labeled image datasets and paves the way
38 | for new and exciting research directions in new fields previously inaccessible for event cameras.
39 |
40 | ## Installation
41 |
42 | ### Dependencies
43 | If desired, a conda environment can be created using the following command:
44 | ```bash
45 | conda create -n
46 | ```
47 | As an initial step, the wheel package needs to be installed with the following command:
48 | ```bash
49 | pip install wheel
50 | ```
51 | The required python packages are listed in the `requirements.txt` file.
52 | ```bash
53 | pip install -r requirements.txt
54 | ```
55 |
56 | ### Pre-trained E2VID Model
57 | Pre-trained E2VID model needs to be downloaded [here](https://github.com/uzh-rpg/rpg_e2vid) and placed in `/e2vid/pretrained/`.
58 |
59 | ## Datasets
60 |
61 | ### DSEC-Semantic
62 | The DSEC-Semantic dataset can be downloaded [here](https://dsec.ifi.uzh.ch/dsec-semantic/). The dataset should have the following format:
63 |
64 | ├── DSEC_Semantic
65 | │ ├── train
66 | │ │ ├── zurich_city_00_a
67 | │ │ │ ├── semantic
68 | │ │ │ │ ├── left
69 | │ │ │ │ │ ├── 11classes
70 | │ │ │ │ │ │ └──data
71 | │ │ │ │ │ │ ├── 000000.png
72 | │ │ │ │ │ │ └── ...
73 | │ │ │ │ │ └── 19classes
74 | │ │ │ │ │ └──data
75 | │ │ │ │ │ ├── 000000.png
76 | │ │ │ │ │ └── ...
77 | │ │ │ │ └── timestamps.txt
78 | │ │ │ └── events
79 | │ │ │ └── left
80 | │ │ │ ├── events.h5
81 | │ │ │ └── rectify_map.h5
82 | │ │ └── ...
83 | │ └── test
84 | │ ├── zurich_city_13_a
85 | │ │ └── ...
86 | │ └── ...
87 |
88 | ### DDD17
89 | The original DDD17 dataset with semantic segmentation labels can be downloaded [here](https://github.com/Shathe/Ev-SegNet).
90 | Additionally, we provide a pre-processed DDD17 dataset with semantic labels [here](https://download.ifi.uzh.ch/rpg/ESS/ddd17_seg.tar.gz). Please do not forget to cite [DDD17](https://sensors.ini.uzh.ch/news_page/DDD17.html) and [Ev-SegNet](https://github.com/Shathe/Ev-SegNet) if you use the DDD17 dataset with semantic labels.
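A minimal loading sketch for the pre-processed data via `datasets/ddd17_events_loader.py` (the path is a placeholder; the values mirror `config/settings_DDD17.yaml`):

```python
from datasets.ddd17_events_loader import DDD17Events

# Placeholder path to the extracted ddd17_seg data (containing the dir* recordings).
train_set = DDD17Events('/path/to/ddd17_seg', split='train',
                        nr_events_data=20, nr_events_per_data=32000,
                        nr_bins_per_data=5, event_representation='voxel_grid')
```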
91 |
92 | ### Cityscapes
93 | The Cityscapes dataset can be downloaded [here](https://www.cityscapes-dataset.com/).
94 |
95 | ## Training
96 | The settings for the training can be specified in `config/settings_XXXX.yaml`.
97 | Two different models can be trained:
98 | - `ess`: ESS UDA / ESS supervised (event labels + frame labels)
99 | - `ess_supervised`: ESS supervised (event labels only)
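The model is selected in the `model` section of the settings file (the full configuration files are shown further below in this repository), e.g.:

```yaml
model:
  model_name: 'ess'              # ['ess', 'ess_supervised']
  train_on_event_labels: False   # True for ESS supervised (event + frame labels), False for ESS UDA
```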
100 |
101 | The following command starts the training:
102 |
103 | ```bash
104 | CUDA_VISIBLE_DEVICES=<GPU_ID> python train.py --settings_file config/settings_XXXX.yaml
105 | ```
106 |
107 | For testing the pre-trained models, please set `load_pretrained_weights: True` and specify the path to the pre-trained weights in `pretrained_file`.
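In the settings file, this corresponds to the `checkpoint` section, e.g. (the weights path is a placeholder):

```yaml
checkpoint:
  load_pretrained_weights: True  # load pre-trained weights for testing
  pretrained_file: /path/to/downloaded_weights
```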
108 |
109 | ## Pre-trained Weights
110 | To download the pre-trained weights for the models on DDD17 and DSEC in the UDA setting, please fill in your details in [this](https://docs.google.com/forms/d/e/1FAIpQLScn5XWvBcmjPoSbaIqEUoEpWeheLGpQTUeK6Pp19wx2jNCPpA/viewform?usp=sf_link) Google Form.
111 |
112 | ## Acknowledgement
113 | Several network architectures were adapted from:
114 | https://github.com/uzh-rpg/rpg_e2vid
115 |
116 | The general training framework was inspired by:
117 | https://github.com/uzh-rpg/rpg_ev-transfer
118 |
119 | The DSEC data loader was adapted from:
120 | https://github.com/uzh-rpg/DSEC
121 |
122 | The optimizer was adapted from:
123 | https://github.com/LiyuanLucasLiu/RAdam
124 |
125 | The DICE loss was adapted from:
126 | https://github.com/Guocode/DiceLoss.Pytorch
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uzh-rpg/ess/46bf1eed677d869c4733c89d2284e43ea27f97bd/config/__init__.py
--------------------------------------------------------------------------------
/config/settings_DDD17.yaml:
--------------------------------------------------------------------------------
1 | dataset:
2 | name_a: 'Cityscapes_gray'
3 | name_b: 'DDD17_events'
4 | DDD17_events:
5 | dataset_path:
6 | split_train: 'train'
7 | shape: [200, 346]
8 | nr_events_data: 20
9 | nr_events_files_per_data: None
10 | fixed_duration: False
11 | delta_t_per_data: 50
12 | require_paired_data_train: False
13 | require_paired_data_val: True
14 | nr_events_window: 32000
15 | event_representation: 'voxel_grid'
16 | nr_temporal_bins: 5
17 | separate_pol: False
18 | normalize_event: False
19 | cityscapes_img:
20 | dataset_path:
21 | shape: [200, 352] # [200, 352] for DDD17, [440, 640] for DSEC
22 | random_crop: True # True for DDD17, False for DSEC
23 | read_two_imgs: False
24 | require_paired_data_train: False
25 | require_paired_data_val: False
26 | task:
27 | semseg_num_classes: 6 # 6 for DDD17, 11 for DSEC
28 | dir:
29 | log:
30 | model:
31 | model_name: 'ess' # ['ess', 'ess_supervised']
32 | skip_connect_encoder: True
33 | skip_connect_task: True
34 | skip_connect_task_type: 'concat'
35 | data_augmentation_train: True
36 | train_on_event_labels: False # True for ESS supervised (events labels + frames labels), False for ESS UDA
37 | optim:
38 | batch_size_a: 16
39 | batch_size_b: 16
40 | lr_front: 1e-5
41 | lr_back: 1e-4
42 | lr_decay: 1
43 | num_epochs: 20
44 | val_epoch_step: 1
45 | weight_task_loss: 1
46 | weight_cycle_pred_loss: 1
47 | weight_cycle_emb_loss: 0.01
48 | weight_cycle_task_loss: 0.01
49 | task_loss: ['dice', 'cross_entropy']
50 | checkpoint:
51 | save_checkpoint: True
52 | resume_training: False
53 | load_pretrained_weights: False # True for loading pre-trained weights
54 | resume_file:
55 | pretrained_file:
56 | hardware:
57 | # num_cpu_workers: {-1: auto, 0: main thread, >0: ...}
58 | num_cpu_workers: 8
59 | gpu_device: 0 # [0 or 'cpu']
60 |
61 |
--------------------------------------------------------------------------------
/config/settings_DSEC.yaml:
--------------------------------------------------------------------------------
1 | dataset:
2 | name_a: 'Cityscapes_gray'
3 | name_b: 'DSEC_events'
4 | DSEC_events:
5 | dataset_path:
6 | shape: [440, 640]
7 | nr_events_data: 20
8 | nr_events_files_per_data: None
9 | fixed_duration: False
10 | delta_t_per_data: 50
11 | require_paired_data_train: False
12 | require_paired_data_val: True
13 | nr_events_window: 100000
14 | event_representation: 'voxel_grid'
15 | nr_temporal_bins: 5
16 | separate_pol: False
17 | normalize_event: False
18 | cityscapes_img:
19 | dataset_path:
20 | shape: [440, 640] # [200, 352] for DDD17, [440, 640] for DSEC
21 | random_crop: False # True for DDD17, False for DSEC
22 | read_two_imgs: False
23 | require_paired_data_train: False
24 | require_paired_data_val: False
25 | task:
26 | semseg_num_classes: 11 # 6 for DDD17, 11 for DSEC
27 | dir:
28 | log:
29 | model:
30 | model_name: 'ess' # ['ess', 'ess_supervised']
31 | skip_connect_encoder: True
32 | skip_connect_task: True
33 | skip_connect_task_type: 'concat'
34 | data_augmentation_train: True
35 | train_on_event_labels: False # True for ESS supervised (events labels + frames labels), False for ESS UDA
36 | optim:
37 | batch_size_a: 8
38 | batch_size_b: 8
39 | lr_front: 5e-4
40 | lr_back: 5e-4
41 | lr_decay: 1
42 | num_epochs: 50
43 | val_epoch_step: 5
44 | weight_task_loss: 1
45 | weight_cycle_pred_loss: 1
46 | weight_cycle_emb_loss: 1
47 | weight_cycle_task_loss: 1
48 | task_loss: ['dice', 'cross_entropy']
49 | checkpoint:
50 | save_checkpoint: True
51 | resume_training: False
52 | load_pretrained_weights: False # True for loading pre-trained weights
53 | resume_file:
54 | pretrained_file:
55 | hardware:
56 | # num_cpu_workers: {-1: auto, 0: main thread, >0: ...}
57 | num_cpu_workers: 8
58 | gpu_device: 0 # [0 or 'cpu']
59 |
60 |
--------------------------------------------------------------------------------
/datasets/DSEC_events_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import random
4 | import numpy as np
5 | from PIL import Image
6 |
7 | import torchvision.transforms as transforms
8 | import datasets.data_util as data_util
9 | from pathlib import Path
10 |
11 | from DSEC.dataset.provider import DatasetProvider
12 |
13 |
14 | def DSECEvents(dsec_dir, nr_events_data=1, delta_t_per_data=50, nr_events_window=-1,
15 | augmentation=False, mode='train', task='segmentation', event_representation='voxel_grid',
16 | nr_bins_per_data=5, require_paired_data=False, separate_pol=False, normalize_event=False,
17 | semseg_num_classes=11, fixed_duration=False, resize=False):
18 | """
19 |     Creates an iterator over the DSEC-Semantic dataset.
20 | 
21 |     :param dsec_dir: path to the DSEC dataset root directory
22 |     :param nr_events_data: number of event representations stacked per sample
23 |     :param nr_bins_per_data: number of temporal bins per event representation
24 | :param nr_events_window: number of events summed in the sliding histogram
25 | :param augmentation: flip, shift and random window start for training
26 | :param mode: 'train', 'test' or 'val'
27 | """
28 | dsec_dir = Path(dsec_dir)
29 | assert dsec_dir.is_dir()
30 |
31 | dataset_provider = DatasetProvider(dsec_dir, mode, event_representation=event_representation,
32 | nr_events_data=nr_events_data, delta_t_per_data=delta_t_per_data,
33 | nr_events_window=nr_events_window, nr_bins_per_data=nr_bins_per_data,
34 | require_paired_data=require_paired_data, normalize_event=normalize_event,
35 | separate_pol=separate_pol, semseg_num_classes=semseg_num_classes,
36 | augmentation=augmentation, fixed_duration=fixed_duration, resize=resize)
37 | if mode == 'train':
38 | train_dataset = dataset_provider.get_train_dataset()
39 | return train_dataset
40 | else:
41 | val_dataset = dataset_provider.get_val_dataset()
42 | return val_dataset
43 |
--------------------------------------------------------------------------------
/datasets/cityscapes_loader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from PIL import Image
4 | from torchvision import datasets
5 | import torchvision.transforms as transforms
6 | from torch.utils.data import Dataset
7 | import albumentations as A
8 | from utils.labels import Id2label_6_Cityscapes, Id2label_11_Cityscapes, fromIdToTrainId
9 |
10 |
11 | class CityscapesGray(Dataset):
12 | def __init__(self, root, height=None, width=None, augmentation=False, split='train', target_type='semantic',
13 | semseg_num_classes=6, standardization=False, random_crop=True):
14 |
15 | self.root = root
16 | self.split = split
17 | self.height = height
18 | self.width = width
19 | self.random_crop = random_crop
20 | if self.random_crop:
21 | self.height_resize = 256 # 154
22 | self.width_resize = 512 # 308
23 | else:
24 | self.height_resize = height
25 | self.width_resize = width
26 | self.transform = transforms.Compose([
27 | transforms.Grayscale(),
28 | transforms.Resize([self.height_resize, self.width_resize])
29 | ])
30 | self.cityscapes_dataset = datasets.Cityscapes(self.root, split=self.split, mode='fine', target_type=target_type,
31 | transform=self.transform, target_transform=None)
32 | self.augmentation = augmentation
33 | self.standardization = standardization
34 | if self.standardization:
35 | mean = 0.3091
36 | std = 0.1852
37 | self.standardization_a = A.Normalize(mean=mean, std=std)
38 |
39 | if self.augmentation:
40 | self.transform_a = A.Compose([
41 | A.HorizontalFlip(p=0.5),
42 | A.ShiftScaleRotate(scale_limit=(0, 0.5), rotate_limit=0, shift_limit=0.1, p=0.5, border_mode=0),
43 | A.PadIfNeeded(min_height=self.height, min_width=self.width, always_apply=True, border_mode=0),
44 | A.RandomCrop(height=self.height, width=self.width, always_apply=True),
45 | A.GaussNoise(p=0.2),
46 | A.Perspective(p=0.2),
47 | A.RandomBrightnessContrast(p=0.5),
48 | A.OneOf(
49 | [
50 | A.Sharpen(p=1),
51 | A.Blur(blur_limit=3, p=1),
52 | A.MotionBlur(blur_limit=3, p=1),
53 | ],
54 | p=0.5,
55 | )
56 | ])
57 |
58 | self.transform_a_random_crop = A.Compose([
59 | A.HorizontalFlip(p=0.5),
60 | A.ShiftScaleRotate(scale_limit=(0, 0.5), rotate_limit=0, shift_limit=0, p=0.5, border_mode=0),
61 | A.PadIfNeeded(min_height=self.height, min_width=self.width, always_apply=True, border_mode=0),
62 | A.RandomCrop(height=self.height, width=self.width, always_apply=True),
63 | A.GaussNoise(p=0.2),
64 | A.Perspective(p=0.2),
65 | A.RandomBrightnessContrast(p=0.5),
66 | A.OneOf(
67 | [
68 | A.Sharpen(p=1),
69 | A.Blur(blur_limit=3, p=1),
70 | A.MotionBlur(blur_limit=3, p=1),
71 | ],
72 | p=0.5,
73 | )
74 | ])
75 |
76 | self.transform_a_center_crop = A.Compose([
77 | A.CenterCrop(height=self.height, width=self.width, always_apply=True),
78 | ])
79 |
80 | self.semseg_num_classes = semseg_num_classes
81 | self.require_paired_data = False
82 |
83 | def __len__(self):
84 | return len(self.cityscapes_dataset)
85 |
86 | def __getitem__(self, idx):
87 | img, label = self.cityscapes_dataset[idx]
88 | img = np.array(img)
89 | label = label.resize((self.width_resize, self.height_resize), Image.NEAREST)
90 | label = np.array(label)
91 |
92 | if self.standardization:
93 | Imin = np.min(img)
94 | Imax = np.max(img)
95 | img = 255.0 * (img - Imin) / (Imax - Imin)
96 | img = img.astype('uint8')
97 |
98 | if self.random_crop:
99 | img = img[:self.height, :]
100 | label = label[:self.height, :]
101 |
102 | if self.augmentation:
103 | sample = self.transform_a_random_crop(image=img, mask=label)
104 | else:
105 | sample = self.transform_a_center_crop(image=img, mask=label)
106 | img, label = sample["image"], sample['mask']
107 |
108 | else:
109 | if self.augmentation:
110 | sample = self.transform_a(image=img, mask=label)
111 | img, label = sample["image"], sample['mask']
112 |
113 | img = Image.fromarray(img.astype('uint8'))
114 |
115 | if self.semseg_num_classes == 6:
116 | label = fromIdToTrainId(label, Id2label_6_Cityscapes)
117 | elif self.semseg_num_classes == 11:
118 | label = fromIdToTrainId(label, Id2label_11_Cityscapes)
119 |
120 | label_tensor = torch.from_numpy(label).long()
121 | img_transform = transforms.Compose([
122 | transforms.ToTensor()
123 | ])
124 | img_tensor = img_transform(img)
125 |
126 | return img_tensor, label_tensor
127 |
128 |
--------------------------------------------------------------------------------
/datasets/data_util.py:
--------------------------------------------------------------------------------
1 | import numpy
2 | import numpy as np
3 | import torch
4 |
5 |
6 | def generate_input_representation(events, event_representation, shape, nr_temporal_bins=5, separate_pol=True):
7 | """
8 | Events: N x 4, where cols are x, y, t, polarity, and polarity is in {-1, 1}. x and y correspond to image
9 | coordinates u and v.
10 | """
11 | if event_representation == 'histogram':
12 | return generate_event_histogram(events, shape)
13 | elif event_representation == 'voxel_grid':
14 | return generate_voxel_grid(events, shape, nr_temporal_bins, separate_pol)
15 |
16 |
17 | def generate_event_histogram(events, shape):
18 | """
19 | Events: N x 4, where cols are x, y, t, polarity, and polarity is in {-1, 1}. x and y correspond to image
20 | coordinates u and v.
21 | """
22 | height, width = shape
23 | x, y, t, p = events.T
24 |     x = x.astype(int)
25 |     y = y.astype(int)
26 | p[p == 0] = -1 # polarity should be +1 / -1
27 | img_pos = np.zeros((height * width,), dtype="float32")
28 | img_neg = np.zeros((height * width,), dtype="float32")
29 |
30 | np.add.at(img_pos, x[p == 1] + width * y[p == 1], 1)
31 | np.add.at(img_neg, x[p == -1] + width * y[p == -1], 1)
32 |
33 | histogram = np.stack([img_neg, img_pos], 0).reshape((2, height, width))
34 |
35 | return histogram
36 |
37 |
38 | def normalize_voxel_grid(events):
39 | """Normalize event voxel grids"""
40 | nonzero_ev = (events != 0)
41 | num_nonzeros = nonzero_ev.sum()
42 | if num_nonzeros > 0:
43 | # compute mean and stddev of the **nonzero** elements of the event tensor
44 | # we do not use PyTorch's default mean() and std() functions since it's faster
45 | # to compute it by hand than applying those funcs to a masked array
46 | mean = events.sum() / num_nonzeros
47 | stddev = torch.sqrt((events ** 2).sum() / num_nonzeros - mean ** 2)
48 | mask = nonzero_ev.float()
49 | events = mask * (events - mean) / stddev
50 |
51 | return events
52 |
53 |
54 | def generate_voxel_grid(events, shape, nr_temporal_bins, separate_pol=True):
55 | """
56 | Build a voxel grid with bilinear interpolation in the time domain from a set of events.
57 |     :param events: a [N x 4] NumPy array containing one event per row in the form: [x, y, timestamp, polarity]
58 | :param nr_temporal_bins: number of bins in the temporal axis of the voxel grid
59 | :param shape: dimensions of the voxel grid
60 | """
61 | height, width = shape
62 | assert(events.shape[1] == 4)
63 | assert(nr_temporal_bins > 0)
64 | assert(width > 0)
65 | assert(height > 0)
66 |
67 | voxel_grid_positive = np.zeros((nr_temporal_bins, height, width), np.float32).ravel()
68 | voxel_grid_negative = np.zeros((nr_temporal_bins, height, width), np.float32).ravel()
69 |
70 | # normalize the event timestamps so that they lie between 0 and num_bins
71 | last_stamp = events[-1, 2]
72 | first_stamp = events[0, 2]
73 | deltaT = last_stamp - first_stamp
74 |
75 | if deltaT == 0:
76 | deltaT = 1.0
77 |
78 | # events[:, 2] = (nr_temporal_bins - 1) * (events[:, 2] - first_stamp) / deltaT
79 |     xs = events[:, 0].astype(int)
80 |     ys = events[:, 1].astype(int)
81 | # ts = events[:, 2]
82 | # print(ts[:10])
83 | ts = (nr_temporal_bins - 1) * (events[:, 2] - first_stamp) / deltaT
84 |
85 | pols = events[:, 3]
86 | pols[pols == 0] = -1 # polarity should be +1 / -1
87 |
88 |     tis = ts.astype(int)
89 | dts = ts - tis
90 | vals_left = np.abs(pols) * (1.0 - dts)
91 | vals_right = np.abs(pols) * dts
92 | pos_events_indices = pols == 1
93 |
94 | # Positive Voxels Grid
95 | valid_indices_pos = np.logical_and(tis < nr_temporal_bins, pos_events_indices)
96 | valid_pos = (xs < width) & (xs >= 0) & (ys < height) & (ys >= 0) & (ts >= 0) & (ts < nr_temporal_bins)
97 | valid_indices_pos = np.logical_and(valid_indices_pos, valid_pos)
98 |
99 | np.add.at(voxel_grid_positive, xs[valid_indices_pos] + ys[valid_indices_pos] * width +
100 | tis[valid_indices_pos] * width * height, vals_left[valid_indices_pos])
101 |
102 | valid_indices_pos = np.logical_and((tis + 1) < nr_temporal_bins, pos_events_indices)
103 | valid_indices_pos = np.logical_and(valid_indices_pos, valid_pos)
104 | np.add.at(voxel_grid_positive, xs[valid_indices_pos] + ys[valid_indices_pos] * width +
105 | (tis[valid_indices_pos] + 1) * width * height, vals_right[valid_indices_pos])
106 |
107 | # Negative Voxels Grid
108 | valid_indices_neg = np.logical_and(tis < nr_temporal_bins, ~pos_events_indices)
109 | valid_indices_neg = np.logical_and(valid_indices_neg, valid_pos)
110 |
111 | np.add.at(voxel_grid_negative, xs[valid_indices_neg] + ys[valid_indices_neg] * width +
112 | tis[valid_indices_neg] * width * height, vals_left[valid_indices_neg])
113 |
114 | valid_indices_neg = np.logical_and((tis + 1) < nr_temporal_bins, ~pos_events_indices)
115 | valid_indices_neg = np.logical_and(valid_indices_neg, valid_pos)
116 | np.add.at(voxel_grid_negative, xs[valid_indices_neg] + ys[valid_indices_neg] * width +
117 | (tis[valid_indices_neg] + 1) * width * height, vals_right[valid_indices_neg])
118 |
119 | voxel_grid_positive = np.reshape(voxel_grid_positive, (nr_temporal_bins, height, width))
120 | voxel_grid_negative = np.reshape(voxel_grid_negative, (nr_temporal_bins, height, width))
121 |
122 | if separate_pol:
123 | return np.concatenate([voxel_grid_positive, voxel_grid_negative], axis=0)
124 |
125 | voxel_grid = voxel_grid_positive - voxel_grid_negative
126 | return voxel_grid
127 |
128 |
--------------------------------------------------------------------------------
/datasets/ddd17_events_loader.py:
--------------------------------------------------------------------------------
1 | import glob
2 | from os.path import join, exists, dirname, basename
3 | import os
4 | import cv2
5 | import torch
6 | import numpy as np
7 | from torch.utils.data import Dataset
8 | import torch.nn.functional as f
9 | import torchvision.transforms as transforms
10 |
11 | from datasets.extract_data_tools.example_loader_ddd17 import load_files_in_directory, extract_events_from_memmap
12 | import datasets.data_util as data_util
13 | import albumentations as A
14 | from PIL import Image
15 | from utils.labels import shiftUpId, shiftDownId
16 |
17 |
18 | def get_split(dirs, split):
19 | return {
20 | "train": [dirs[0], dirs[2], dirs[3], dirs[5], dirs[6]],
21 | "test": [dirs[4]],
22 | "valid": [dirs[1]]
23 | }[split]
24 |
25 |
26 | def unzip_segmentation_masks(dirs):
27 | for d in dirs:
28 | assert exists(join(d, "segmentation_masks.zip"))
29 | if not exists(join(d, "segmentation_masks")):
30 | print("Unzipping segmentation mask in %s" % d)
31 |             os.system("unzip %s -d %s" % (join(d, "segmentation_masks.zip"), d))
32 |
33 |
34 | class DDD17Events(Dataset):
35 | def __init__(self, root, split="train", event_representation='voxel_grid',
36 | nr_events_data=5, delta_t_per_data=50, nr_bins_per_data=5, require_paired_data=False,
37 | separate_pol=False, normalize_event=False, augmentation=False, fixed_duration=False,
38 | nr_events_per_data=32000, resize=True, random_crop=False):
39 | data_dirs = sorted(glob.glob(join(root, "dir*")))
40 | assert len(data_dirs) > 0
41 | assert split in ["train", "valid", "test"]
42 |
43 | self.split = split
44 | self.augmentation = augmentation
45 | self.fixed_duration = fixed_duration
46 | self.nr_events_per_data = nr_events_per_data
47 |
48 | self.nr_events_data = nr_events_data
49 | self.delta_t_per_data = delta_t_per_data
50 | if self.fixed_duration:
51 | self.t_interval = nr_events_data * delta_t_per_data
52 | else:
53 | self.t_interval = -1
54 | self.nr_events = self.nr_events_data * self.nr_events_per_data
55 | assert self.t_interval in [10, 50, 250, -1]
56 | self.nr_temporal_bins = nr_bins_per_data
57 | self.require_paired_data = require_paired_data
58 | self.event_representation = event_representation
59 | self.shape = [260, 346]
60 | self.resize = resize
61 | self.shape_resize = [260, 352]
62 | self.random_crop = random_crop
63 | self.shape_crop = [120, 216]
64 | self.separate_pol = separate_pol
65 | self.normalize_event = normalize_event
66 | self.dirs = get_split(data_dirs, split)
67 | # unzip_segmentation_masks(self.dirs)
68 |
69 | self.files = []
70 | for d in self.dirs:
71 | self.files += glob.glob(join(d, "segmentation_masks", "*.png"))
72 |
73 | print("[DDD17Segmentation]: Found %s segmentation masks for split %s" % (len(self.files), split))
74 |
75 | # load events and image_idx -> event index mapping
76 | self.img_timestamp_event_idx = {}
77 | self.event_data = {}
78 |
79 | print("[DDD17Segmentation]: Loading real events.")
80 | self.event_dirs = self.dirs
81 |
82 | for d in self.event_dirs:
83 | img_timestamp_event_idx, t_events, xyp_events, _ = load_files_in_directory(d, self.t_interval)
84 | self.img_timestamp_event_idx[d] = img_timestamp_event_idx
85 | self.event_data[d] = [t_events, xyp_events]
86 |
87 | if self.augmentation:
88 | self.transform_a = A.ReplayCompose([
89 | A.HorizontalFlip(p=0.5)
90 | ])
91 | self.transform_a_random_crop = A.ReplayCompose([
92 | A.HorizontalFlip(p=0.5),
93 | A.RandomCrop(height=self.shape_crop[0], width=self.shape_crop[1], always_apply=True)])
94 | self.transform_a_center_crop = A.ReplayCompose([
95 | A.CenterCrop(height=self.shape_crop[0], width=self.shape_crop[1], always_apply=True),
96 | ])
97 |
98 | def __len__(self):
99 | return len(self.files)
100 |
101 | def apply_augmentation(self, transform_a, events, label):
102 | label = shiftUpId(label)
103 | A_data = transform_a(image=events[0, :, :].numpy(), mask=label)
104 | label = A_data['mask']
105 | label = shiftDownId(label)
106 | if self.random_crop and self.split == 'train':
107 | events_tensor = torch.zeros((events.shape[0], self.shape_crop[0], self.shape_crop[1]))
108 | else:
109 | events_tensor = events
110 | for k in range(events.shape[0]):
111 | events_tensor[k, :, :] = torch.from_numpy(
112 | A.ReplayCompose.replay(A_data['replay'], image=events[k, :, :].numpy())['image'])
113 | return events_tensor, label
114 |
115 | def __getitem__(self, idx):
116 | segmentation_mask_file = self.files[idx]
117 | segmentation_mask = cv2.imread(segmentation_mask_file, 0)
118 | label_original = np.array(segmentation_mask)
119 | if self.resize:
120 | segmentation_mask = cv2.resize(segmentation_mask, (self.shape_resize[1], self.shape_resize[0] - 60),
121 | interpolation=cv2.INTER_NEAREST)
122 | label = np.array(segmentation_mask)
123 |
124 | directory = dirname(dirname(segmentation_mask_file))
125 |
126 | img_idx = int(basename(segmentation_mask_file).split("_")[-1].split(".")[0]) - 1
127 |
128 | img_timestamp_event_idx = self.img_timestamp_event_idx[directory]
129 | t_events, xyp_events = self.event_data[directory]
130 |
131 | # events has form x, y, t_ns, p (in [0,1])
132 | events = extract_events_from_memmap(t_events, xyp_events, img_idx, img_timestamp_event_idx, self.fixed_duration,
133 | self.nr_events)
134 | t_ns = events[:, 2]
135 | delta_t_ns = int((t_ns[-1] - t_ns[0]) / self.nr_events_data)
136 | nr_events_loaded = events.shape[0]
137 | nr_events_temp = nr_events_loaded // self.nr_events_data
138 |
139 | id_end = 0
140 | event_tensor = None
141 | for i in range(self.nr_events_data):
142 | id_start = id_end
143 | if self.fixed_duration:
144 | id_end = np.searchsorted(t_ns, t_ns[0] + (i + 1) * delta_t_ns)
145 | else:
146 | id_end += nr_events_temp
147 |
148 | if id_end > nr_events_loaded:
149 | id_end = nr_events_loaded
150 |
151 | event_representation = data_util.generate_input_representation(events[id_start:id_end],
152 | self.event_representation,
153 | self.shape,
154 | nr_temporal_bins=self.nr_temporal_bins,
155 | separate_pol=self.separate_pol)
156 |
157 | event_representation = torch.from_numpy(event_representation)
158 |
159 | if self.normalize_event:
160 | event_representation = data_util.normalize_voxel_grid(event_representation)
161 |
162 | if self.resize:
163 | event_representation_resize = f.interpolate(event_representation.unsqueeze(0),
164 | size=(self.shape_resize[0], self.shape_resize[1]),
165 | mode='bilinear', align_corners=True)
166 | event_representation = event_representation_resize.squeeze(0)
167 |
168 | if event_tensor is None:
169 | event_tensor = event_representation
170 | else:
171 | event_tensor = torch.cat([event_tensor, event_representation], dim=0)
172 |
173 | event_tensor = event_tensor[:, :-60, :] # remove 60 bottom rows
174 |
175 | if self.random_crop and self.split == 'train':
176 | event_tensor = event_tensor[:, -self.shape_crop[0]:, :]
177 | label = label[-self.shape_crop[0]:, :]
178 | if self.augmentation:
179 | event_tensor, label = self.apply_augmentation(self.transform_a_random_crop, event_tensor, label)
180 |
181 | else:
182 | if self.augmentation:
183 | event_tensor, label = self.apply_augmentation(self.transform_a, event_tensor, label)
184 |
185 | label_tensor = torch.from_numpy(label).long()
186 |
187 | if self.split == 'valid' and self.require_paired_data:
188 | segmentation_mask_filepath_list = str(segmentation_mask_file).split('/')
189 | segmentation_mask_filename = segmentation_mask_filepath_list[-1]
190 | dir_name = segmentation_mask_filepath_list[-3]
191 | filename_id = segmentation_mask_filename.split('_')[-1]
192 | img_filename = '_'.join(['img', filename_id])
193 | img_filepath_list = segmentation_mask_filepath_list
194 | img_filepath_list[-2] = 'imgs'
195 | img_filepath_list[-1] = img_filename
196 | img_file = '/'.join(img_filepath_list)
197 | if not os.path.exists(img_file):
198 | img_filename = filename_id.zfill(14)
199 | img_filepath_list[-1] = img_filename
200 | img_file = '/'.join(img_filepath_list)
201 | img = Image.open(img_file)
202 |
203 | if self.resize:
204 | img = img.resize((self.shape_resize[1], self.shape_resize[0]))
205 | img_transform = transforms.Compose([
206 | transforms.Grayscale(),
207 | transforms.ToTensor()
208 | ])
209 | img_tensor = img_transform(img)
210 | img_tensor = img_tensor[:, :-60, :]
211 |
212 | label_original_tensor = torch.from_numpy(label_original).long()
213 | return event_tensor, img_tensor, label_tensor, label_original_tensor
214 | return event_tensor, label_tensor
215 |
216 |
--------------------------------------------------------------------------------
/datasets/extract_data_tools/example_loader_ddd17.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import numpy as np
4 | import cv2
5 | import matplotlib.pyplot as plt
6 | import copy
7 |
8 |
9 | def load_files_in_directory(directory, t_interval=50):
10 |     # Load files: these have the following form
11 | #
12 | # idx.npy :
13 | # t0_ns idx0
14 | # t1_ns idx1
15 | # ...
16 | # tj_ns idxj
17 | # ...
18 | # tN_ns idxN
19 | #
20 | # This file contains a mapping from j -> tj_ns idxj,
21 | # where j+1 is the idx of the img with timestamp tj_ns (in nanoseconds)
22 | # and idxj is the idx of the last event before the img (in events.dat.t and events.dat.xyp)
23 | if t_interval == 10:
24 | img_timestamp_event_idx = np.load(os.path.join(directory, "index/index_10ms.npy"))
25 | elif t_interval == 50:
26 | img_timestamp_event_idx = np.load(os.path.join(directory, "index/index_50ms.npy"))
27 | elif t_interval == 250:
28 | img_timestamp_event_idx = np.load(os.path.join(directory, "index/index_250ms.npy"))
29 | else:
30 | img_timestamp_event_idx = np.load(os.path.join(directory, "index/index_50ms.npy"))
31 |
32 | # events.dat.t :
33 | # t0_ns
34 | # t1_ns
35 | # ...
36 | # tM_ns
37 | #
38 | # events.dat.xyp :
39 | # x0 y0 p0
40 | # ...
41 | # xM yM pM
42 | events_t_file = os.path.join(directory, "events.dat.t")
43 | events_xyp_file = os.path.join(directory, "events.dat.xyp")
44 |
45 | t_events, xyp_events = load_events(events_t_file, events_xyp_file)
46 |
47 | # Since the imgs are in a video format, they cannot be loaded directly, however, the segmentation masks from the
48 | # original dataset (EV-SegNet) have been copied into this folder. First unzip the segmentation masks with
49 | #
50 | # unzip segmentation_masks.zip
51 | #
52 | segmentation_mask_files = sorted(glob.glob(os.path.join(directory, "segmentation_masks", "*.png")))
53 |
54 | return img_timestamp_event_idx, t_events, xyp_events, segmentation_mask_files
55 |
56 |
57 | def load_events(t_file, xyp_file):
58 |     # events.dat.t saves the timestamps of the individual events (in nanoseconds -> int64)
59 |     # events.dat.xyp saves the x, y and polarity of events in int16 to save storage. The polarity is 0 or 1.
60 | # We first need to compute the number of events in the memmap since it does not do it for us. We can do
61 | # this by computing the file size of the timestamps and dividing by 8 (since timestamps take 8 bytes)
62 |
63 | num_events = int(os.path.getsize(t_file) / 8)
64 | t_events = np.memmap(t_file, dtype="int64", mode="r", shape=(num_events, 1))
65 | xyp_events = np.memmap(xyp_file, dtype="int16", mode="r", shape=(num_events, 3))
66 |
67 | return t_events, xyp_events
68 |
69 |
70 | def extract_events_from_memmap(t_events, xyp_events, img_idx, img_timestamp_event_idx, fixed_duration=False,
71 | nr_events=32000):
72 | # timestep, event_idx = img_timestamp_event_idx[img_idx]
73 | # _, event_idx_before = img_timestamp_event_idx[img_idx - 1]
74 | if fixed_duration:
75 | timestep, event_idx, event_idx_before = img_timestamp_event_idx[img_idx]
76 | event_idx_before = max([event_idx_before, 0])
77 | else:
78 | timestep, event_idx, _ = img_timestamp_event_idx[img_idx]
79 | event_idx_before = max([event_idx - nr_events, 0])
80 | events_between_imgs = np.concatenate([
81 | np.array(t_events[event_idx_before:event_idx], dtype="int64"),
82 | np.array(xyp_events[event_idx_before:event_idx], dtype="int64")
83 | ], -1)
84 | events_between_imgs = events_between_imgs[:, [1, 2, 0, 3]] # events have format xytp, and p is in [0,1]
85 |
86 | return events_between_imgs
87 |
88 |
89 | def generate_event_img(shape, events):
90 | H, W = shape
91 | # generate event img
92 | event_img_pos = np.zeros((H * W,), dtype="float32")
93 | event_img_neg = np.zeros((H * W,), dtype="float32")
94 |
95 | x, y, t, p = events.T
96 |
97 | np.add.at(event_img_pos, x[p == 1] + W * y[p == 1], p[p == 1])
98 | np.add.at(event_img_neg, x[p == 0] + W * y[p == 0], p[p == 0] + 1)
99 |
100 | event_img_pos = event_img_pos.reshape((H, W))
101 | event_img_neg = event_img_neg.reshape((H, W))
102 |
103 | return event_img_neg, event_img_pos
104 |
105 |
106 | def generate_colored_label_img(shape, label_mask):
107 | H, W = shape
108 |
109 | colors = [[0, 0, 255], [255, 0, 0], [0, 0, 255], [255, 255, 0], [255, 0, 255], [0, 255, 255]]
110 |     mask = label_mask.reshape((-1, 3))[:, 0]
111 |
112 | img = np.zeros((H * W, 3), dtype="uint8")
113 |
114 |     for i in np.unique(label_mask):
115 | c = colors[int(i)]
116 | img[mask == i, 0] = c[0]
117 | img[mask == i, 1] = c[1]
118 | img[mask == i, 2] = c[2]
119 |
120 | img = img.reshape((H, W, 3))
121 |
122 | return img
123 |
124 |
125 | def generate_rendered_events_on_img(img, event_map_neg, event_map_pos):
126 | orig_shape = img.shape
127 |
128 | img = img.copy()
129 |
130 | img = img.reshape((-1, 3))
131 | pos_mask = event_map_pos.reshape((-1,)) > 0
132 | neg_mask = event_map_neg.reshape((-1,)) > 0
133 |
134 | img[neg_mask, 0] = 255
135 | img[pos_mask, 2] = 255
136 | img[neg_mask | pos_mask, 1] = 0
137 |
138 | img = img.reshape(orig_shape)
139 |
140 | return img
141 |
142 |
143 | if __name__ == "__main__":
144 |
145 | directories = sorted(glob.glob(os.path.join(os.path.dirname(__file__), "dir*")))
146 | assert len(directories) > 0
147 | # example with one directory
148 | directory = directories[1]
149 | print(directories)
150 | print("Using directory: %s" % directory)
151 |
152 | # load all files that are in the directory (these are real events)
153 | img_timestamp_event_idx, t_events, xyp_events, segmentation_mask_files = \
154 | load_files_in_directory(directory)
155 |
156 | # load all files that are in video_upsampled_events. These are simulated data.
157 | sim_directory = os.path.join(directory, "video_upsampled_events")
158 | load_sim = os.path.exists(sim_directory)
159 | img_timestamp_event_idx_sim, t_events_sim, xyp_events_sim = None, None, None
160 | if load_sim:
161 | print("Loading sim data")
162 | img_timestamp_event_idx_sim, t_events_sim, xyp_events_sim, _ = \
163 | load_files_in_directory(sim_directory)
164 |
165 | num_plots = 3 if load_sim else 2
166 | fig, ax = plt.subplots(ncols=num_plots)
167 | img_handles = []
168 | assert len(segmentation_mask_files) > 0
169 | for segmentation_mask_file in segmentation_mask_files[-100:]:
170 | # take an example mask and extract the corresponding idx
171 | print("Using segmentation mask: %s" % segmentation_mask_file)
172 | segmentation_mask = cv2.imread(segmentation_mask_file)
173 |
174 | img_idx = int(os.path.basename(segmentation_mask_file).split("_")[-1].split(".")[0]) - 1
175 | print("Loading img with idx %s" % img_idx)
176 |
177 | # load corresponding img
178 | # first decompress video by running
179 | #
180 | # mkdir imgs
181 | # ffmpeg -i video.mp4 imgs/img_%08d.png
182 | #
183 | img_file = segmentation_mask_file.replace("segmentation_masks", "imgs").replace("/segmentation_", "/img_")
184 | # img_file = '/'.join(img_file.split('/')[:-1]) + '/' + '{:0>10}'.format(img_file.split('_')[-1][:-4]) + '.png'
185 | img = cv2.imread(img_file)
186 |
187 | # crop img since this was done in EV-SegNet
188 | img = img[:200]
189 |
190 | # find events between this idx and the last
191 | events_between_imgs = \
192 | extract_events_from_memmap(t_events, xyp_events, img_idx, img_timestamp_event_idx)
193 | print("Found %s events" % (len(events_between_imgs)))
194 |
195 | # remove all events with y > 200 since these were cropped from the dataset
196 | events_between_imgs = events_between_imgs[events_between_imgs[:, 1] < 200]
197 |
198 | if load_sim:
199 | # find sim events between this idx and the last
200 | events_between_imgs_sim = \
201 | extract_events_from_memmap(t_events_sim, xyp_events_sim, img_idx, img_timestamp_event_idx_sim)
202 | print("Found %s simulated events" % (len(events_between_imgs_sim)))
203 |
204 | # remove all events with y > 200 since these were cropped from the dataset
205 | events_between_imgs_sim = events_between_imgs_sim[events_between_imgs_sim[:, 1] < 200]
206 |
207 | event_img_neg, event_img_pos = generate_event_img((200, 346), events_between_imgs)
208 | event_img_neg_sim, event_img_pos_sim = generate_event_img((200, 346), events_between_imgs_sim)
209 |
210 | # generate view of labels
211 | colored_label_img = generate_colored_label_img((200, 346), segmentation_mask)
212 |
213 | # draw events on img
214 | rendered_events_on_img = generate_rendered_events_on_img(copy.deepcopy(img), event_img_neg, event_img_pos)
215 |
216 | if load_sim:
217 | # draw events on img
218 | rendered_events_on_img_sim = generate_rendered_events_on_img(copy.deepcopy(img), event_img_neg_sim,
219 | event_img_pos_sim)
220 |
221 | print("Error: ",
222 | np.abs((rendered_events_on_img_sim).astype("float32") - (rendered_events_on_img).astype("float32")).sum())
223 |
224 | if len(img_handles) == 0:
225 | img_handles += [ax[0].imshow(colored_label_img)]
226 | img_handles += [ax[1].imshow(rendered_events_on_img)]
227 | if load_sim:
228 | img_handles += [ax[2].imshow(rendered_events_on_img_sim)]
229 | plt.show(block=False)
230 | else:
231 | img_handles[0].set_data(colored_label_img)
232 | img_handles[1].set_data(rendered_events_on_img)
233 | if load_sim:
234 | img_handles[2].set_data(rendered_events_on_img_sim)
235 | fig.canvas.draw()
236 | plt.pause(0.002)
237 |
238 |
239 |
--------------------------------------------------------------------------------
/datasets/wrapper_dataloader.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 |
3 |
4 | class WrapperDataset(Dataset):
5 | def __init__(self, dataloader_a, dataloader_b, device, dataset_len_to_use=None):
6 | self.dataloader_a = dataloader_a
7 | self.dataloader_b = dataloader_b
8 | self.dataloader_a_iter = iter(self.dataloader_a)
9 | self.dataloader_b_iter = iter(self.dataloader_b)
10 | self.require_paired_data_a = dataloader_a.dataset.require_paired_data
11 | self.require_paired_data_b = dataloader_b.dataset.require_paired_data
12 | self.device = device
13 |
14 | self.dataset_a_larger = False
15 | if self.dataloader_a.__len__() > self.dataloader_b.__len__():
16 | self.dataset_a_larger = True
17 |
18 | if dataset_len_to_use == 'first':
19 | self.dataset_a_larger = True
20 | elif dataset_len_to_use == 'second':
21 | self.dataset_a_larger = False
22 |
23 | def __len__(self):
24 | if self.dataset_a_larger:
25 | return self.dataloader_a.__len__()
26 | else:
27 | return self.dataloader_b.__len__()
28 |
29 | def createIterators(self):
30 | self.dataloader_a_iter = iter(self.dataloader_a)
31 | self.dataloader_b_iter = iter(self.dataloader_b)
32 |
33 | def __getitem__(self, idx):
34 | """
35 | Returns two samples
36 | """
37 | if not self.require_paired_data_a and not self.require_paired_data_b:
38 | if self.dataset_a_larger:
39 | try:
40 | dataset_b_data, dataset_b_label = next(self.dataloader_b_iter)
41 | except StopIteration:
42 | self.dataloader_b_iter = iter(self.dataloader_b)
43 | dataset_b_data, dataset_b_label = next(self.dataloader_b_iter)
44 | dataset_a_data, dataset_a_label = next(self.dataloader_a_iter)
45 | else:
46 | try:
47 | dataset_a_data, dataset_a_label = next(self.dataloader_a_iter)
48 | except StopIteration:
49 | self.dataloader_a_iter = iter(self.dataloader_a)
50 | dataset_a_data, dataset_a_label = next(self.dataloader_a_iter)
51 | dataset_b_data, dataset_b_label = next(self.dataloader_b_iter)
52 |
53 | return [dataset_a_data.to(self.device), dataset_a_label.to(self.device)], \
54 | [dataset_b_data.to(self.device), dataset_b_label.to(self.device)]
55 |
56 | if self.require_paired_data_a and not self.require_paired_data_b:
57 | if self.dataset_a_larger:
58 | try:
59 | dataset_b_data, dataset_b_label = next(self.dataloader_b_iter)
60 | except StopIteration:
61 | self.dataloader_b_iter = iter(self.dataloader_b)
62 | dataset_b_data, dataset_b_label = next(self.dataloader_b_iter)
63 | dataset_a_data, dataset_a_paired_data, dataset_a_label = next(self.dataloader_a_iter)
64 | else:
65 | try:
66 | dataset_a_data, dataset_a_paired_data, dataset_a_label = next(self.dataloader_a_iter)
67 | except StopIteration:
68 | self.dataloader_a_iter = iter(self.dataloader_a)
69 | dataset_a_data, dataset_a_paired_data, dataset_a_label = next(self.dataloader_a_iter)
70 | dataset_b_data, dataset_b_label = next(self.dataloader_b_iter)
71 |
72 | return [dataset_a_data.to(self.device), dataset_a_paired_data.to(self.device), dataset_a_label.to(self.device)], \
73 | [dataset_b_data.to(self.device), dataset_b_label.to(self.device)]
74 |
75 | if not self.require_paired_data_a and self.require_paired_data_b:
76 | if self.dataset_a_larger:
77 | try:
78 | dataset_b_data, dataset_b_paired_data, dataset_b_label = next(self.dataloader_b_iter)
79 | except StopIteration:
80 | self.dataloader_b_iter = iter(self.dataloader_b)
81 | dataset_b_data, dataset_b_paired_data, dataset_b_label = next(self.dataloader_b_iter)
82 | dataset_a_data, dataset_a_label = next(self.dataloader_a_iter)
83 | else:
84 | try:
85 | dataset_a_data, dataset_a_label = next(self.dataloader_a_iter)
86 | except StopIteration:
87 | self.dataloader_a_iter = iter(self.dataloader_a)
88 | dataset_a_data, dataset_a_label = next(self.dataloader_a_iter)
89 | dataset_b_data, dataset_b_paired_data, dataset_b_label = next(self.dataloader_b_iter)
90 |
91 | return [dataset_a_data.to(self.device), dataset_a_label.to(self.device)], \
92 | [dataset_b_data.to(self.device), dataset_b_paired_data.to(self.device), dataset_b_label.to(self.device)]
93 |
94 | if self.require_paired_data_a and self.require_paired_data_b:
95 | if self.dataset_a_larger:
96 | try:
97 | dataset_b_data, dataset_b_paired_data, dataset_b_label = next(self.dataloader_b_iter)
98 | except StopIteration:
99 | self.dataloader_b_iter = iter(self.dataloader_b)
100 | dataset_b_data, dataset_b_paired_data, dataset_b_label = next(self.dataloader_b_iter)
101 | dataset_a_data, dataset_a_paired_data, dataset_a_label = next(self.dataloader_a_iter)
102 | else:
103 | try:
104 | dataset_a_data, dataset_a_paired_data, dataset_a_label = next(self.dataloader_a_iter)
105 | except StopIteration:
106 | self.dataloader_a_iter = iter(self.dataloader_a)
107 | dataset_a_data, dataset_a_paired_data, dataset_a_label = next(self.dataloader_a_iter)
108 | dataset_b_data, dataset_b_paired_data, dataset_b_label = next(self.dataloader_b_iter)
109 |
110 | return [dataset_a_data.to(self.device), dataset_a_paired_data.to(self.device), dataset_a_label.to(self.device)], \
111 | [dataset_b_data.to(self.device), dataset_b_paired_data.to(self.device), dataset_b_label.to(self.device)]
--------------------------------------------------------------------------------
/e2vid/README.md:
--------------------------------------------------------------------------------
1 | # High Speed and High Dynamic Range Video with an Event Camera
2 |
3 | [](https://youtu.be/eomALySSGVU)
4 |
5 | This is the code for the paper **High Speed and High Dynamic Range Video with an Event Camera** by [Henri Rebecq](http://henri.rebecq.fr), Rene Ranftl, [Vladlen Koltun](http://vladlen.info/) and [Davide Scaramuzza](http://rpg.ifi.uzh.ch/people_scaramuzza.html).
6 |
7 | You can find a pdf of the paper [here](http://rpg.ifi.uzh.ch/docs/TPAMI19_Rebecq.pdf).
8 | If you use any of this code, please cite the following publications:
9 |
10 | ```bibtex
11 | @Article{Rebecq19pami,
12 | author = {Henri Rebecq and Ren{\'{e}} Ranftl and Vladlen Koltun and Davide Scaramuzza},
13 | title = {High Speed and High Dynamic Range Video with an Event Camera},
14 | journal = {{IEEE} Trans. Pattern Anal. Mach. Intell. (T-PAMI)},
15 | url = {http://rpg.ifi.uzh.ch/docs/TPAMI19_Rebecq.pdf},
16 | year = 2019
17 | }
18 | ```
19 |
20 |
21 | ```bibtex
22 | @Article{Rebecq19cvpr,
23 | author = {Henri Rebecq and Ren{\'{e}} Ranftl and Vladlen Koltun and Davide Scaramuzza},
24 | title = {Events-to-Video: Bringing Modern Computer Vision to Event Cameras},
25 | journal = {{IEEE} Conf. Comput. Vis. Pattern Recog. (CVPR)},
26 | year = 2019
27 | }
28 | ```
29 |
30 | ## Install
31 |
32 | Dependencies:
33 |
34 | - [PyTorch](https://pytorch.org/get-started/locally/) >= 1.0
35 | - [NumPy](https://www.numpy.org/)
36 | - [Pandas](https://pandas.pydata.org/)
37 | - [OpenCV](https://opencv.org/)
38 |
39 | ### Install with Anaconda
40 |
41 | The installation requires [Anaconda3](https://www.anaconda.com/distribution/).
42 | You can create a new Anaconda environment with the required dependencies as follows (make sure to adapt the CUDA toolkit version according to your setup):
43 |
44 | ```bash
45 | conda create -n E2VID
46 | conda activate E2VID
47 | conda install pytorch torchvision cudatoolkit=10.0 -c pytorch
48 | conda install pandas
49 | conda install -c conda-forge opencv
50 | ```
51 |
52 | ## Run
53 |
54 | - Download the pretrained model:
55 |
56 | ```bash
57 | wget "http://rpg.ifi.uzh.ch/data/E2VID/models/E2VID_lightweight.pth.tar" -O pretrained/E2VID_lightweight.pth.tar
58 | ```
59 |
60 | - Download an example file with event data:
61 |
62 | ```bash
63 | wget "http://rpg.ifi.uzh.ch/data/E2VID/datasets/ECD_IJRR17/dynamic_6dof.zip" -O data/dynamic_6dof.zip
64 | ```
65 |
66 | Before running the reconstruction, make sure the conda environment is sourced:
67 |
68 | ```bash
69 | conda activate E2VID
70 | ```
71 |
72 | - Run reconstruction:
73 |
74 | ```bash
75 | python run_reconstruction.py \
76 | -c pretrained/E2VID_lightweight.pth.tar \
77 | -i data/dynamic_6dof.zip \
78 | --auto_hdr \
79 | --display \
80 | --show_events
81 | ```
82 |
83 | ## Parameters
84 |
85 | Below is a description of the most important parameters:
86 |
87 | #### Main parameters
88 |
89 | - ``--window_size`` / ``-N`` (default: None) Number of events per window. This is the parameter that has the most influence on the image reconstruction quality. If set to None, this number will be automatically computed based on the sensor size, as N = width * height * num_events_per_pixel (see description of that parameter below). Ignored if `--fixed_duration` is set.
90 | - ``--fixed_duration`` (default: False) If True, will use windows of events with a fixed duration (i.e. a fixed output frame rate).
91 | - ``--window_duration`` / ``-T`` (default: 33 ms) Duration of each event window, in milliseconds. The value of this parameter has a strong influence on the image reconstruction quality. Its value may need to be adapted to the dynamics of the scene. Ignored if `--fixed_duration` is not set.
92 | - ``--Imin`` (default: 0.0), `--Imax` (default: 1.0): linear tone mapping is performed by normalizing the output image as follows: `I = (I - Imin) / (Imax - Imin)`. If `--auto_hdr` is set to True, `--Imin` and `--Imax` will be automatically computed as the min (resp. max) intensity values.
93 | - ``--auto_hdr`` (default: False) Automatically compute `--Imin` and `--Imax`. Disabled when `--color` is set.
94 | - ``--color`` (default: False): if True, will perform color reconstruction as described in the paper. Only use this with a [color event camera](http://rpg.ifi.uzh.ch/CED.html) such as the Color DAVIS346.
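
As an example, a fixed-frame-rate reconstruction at roughly 30 Hz (33 ms windows) could be started with the flags above, reusing the example data from the Run section (a sketch, not an officially tested invocation):

```bash
python run_reconstruction.py \
  -c pretrained/E2VID_lightweight.pth.tar \
  -i data/dynamic_6dof.zip \
  --fixed_duration -T 33 \
  --display
```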
95 |
96 | #### Output parameters
97 |
98 | - ``--output_folder``: path of the output folder. If not set, the image reconstructions will not be saved to disk.
99 | - ``--dataset_name``: name of the output folder directory (default: 'reconstruction').
100 |
101 | #### Display parameters
102 |
103 | - ``--display`` (default: False): display the video reconstruction in real-time in an OpenCV window.
104 | - ``--show_events`` (default: False): show the input events side-by-side with the reconstruction. If ``--output_folder`` is set, the previews will also be saved to disk in ``/path/to/output/folder/events``.
105 |
106 | #### Additional parameters
107 |
108 | - ``--num_events_per_pixel`` (default: 0.35): Parameter used to automatically estimate the window size based on the sensor size. The value of 0.35 was chosen to correspond to ~ 15,000 events on a 240x180 sensor such as the DAVIS240C.
109 | - ``--no-normalize`` (default: False): Disable event tensor normalization: this will slightly improve speed, but might slightly degrade the image quality.
110 | - ``--no-recurrent`` (default: False): Disable the recurrent connection (i.e. do not maintain a state). For experimentation only; the results will flicker a lot.
111 | - ``--hot_pixels_file`` (default: None): Path to a file specifying the locations of hot pixels (such a file can be obtained with [this tool](https://github.com/cedric-scheerlinck/dvs_tools/tree/master/dvs_hot_pixel_filter) for example). These pixels will be ignored (i.e. zeroed out in the event tensors).
112 |
113 | ## Example datasets
114 |
115 | We provide a list of example (publicly available) event datasets to get started with E2VID.
116 |
117 | - [High Speed (gun shooting!) and HDR Dataset](http://rpg.ifi.uzh.ch/E2VID.html)
118 | - [Event Camera Dataset](http://rpg.ifi.uzh.ch/data/E2VID/datasets/ECD_IJRR17/)
119 | - [Bardow et al., CVPR'16](http://rpg.ifi.uzh.ch/data/E2VID/datasets/SOFIE_CVPR16/)
120 | - [Scheerlinck et al., ACCV'18](http://rpg.ifi.uzh.ch/data/E2VID/datasets/HF_ACCV18/)
121 | - [Color event sequences from the CED dataset Scheerlinck et al., CVPR'18](http://rpg.ifi.uzh.ch/data/E2VID/datasets/CED_CVPRW19/)
122 |
123 | ## Working with ROS
124 |
125 | Because PyTorch recommends Python 3 and ROS is only compatible with Python 2, it is not straightforward to have the PyTorch reconstruction code and ROS code running in the same environment.
126 | To make things easy, the reconstruction code we provide has no dependency on ROS and simply reads events from a text file or ZIP file.
127 | We provide convenience functions to convert ROS bags (a popular format for event datasets) into event text files.
128 | In addition, we also provide scripts to convert a folder containing image reconstructions back to a rosbag (or to append image reconstructions to an existing rosbag).
129 |
130 | **Note**: it is **not** necessary to have a sourced conda environment to run the following scripts. However, [ROS](https://www.ros.org/) needs to be installed and sourced.
131 |
132 | ### rosbag -> events.txt
133 |
134 | To extract the events from a rosbag to a zip file containing the event data:
135 |
136 | ```bash
137 | python scripts/extract_events_from_rosbag.py /path/to/rosbag.bag \
138 | --output_folder=/path/to/output/folder \
139 | --event_topic=/dvs/events
140 | ```
141 |
142 | ### image reconstruction folder -> rosbag
143 |
144 | ```bash
145 | python scripts/image_folder_to_rosbag.py \
146 | --datasets dynamic_6dof \
147 | --image_folder /path/to/image/folder \
148 | --output_folder /path/to/output_folder \
149 | --image_topic /dvs/image_reconstructed
150 | ```
151 |
152 | ### Append image_reconstruction_folder to an existing rosbag
153 |
154 | ```bash
155 | cd scripts
156 | python embed_reconstructed_images_in_rosbag.py \
157 | --rosbag_folder /path/to/rosbag/folder \
158 | --datasets dynamic_6dof \
159 | --image_folder /path/to/image/folder \
160 | --output_folder /path/to/output_folder \
161 | --image_topic /dvs/image_reconstructed
162 | ```
163 |
164 | ### Generating a video reconstruction (with a fixed framerate)
165 |
166 | It can be convenient to convert an image folder to a video with a fixed framerate (for example for use in a video editing tool).
167 | You can proceed as follows:
168 |
169 | ```bash
170 | export FRAMERATE=30
171 | python resample_reconstructions.py -i /path/to/input_folder -o /tmp/resampled -r $FRAMERATE
172 | ffmpeg -framerate $FRAMERATE -i /tmp/resampled/frame_%010d.png video_"$FRAMERATE"Hz.mp4
173 | ```
174 |
175 | ## Acknowledgements
176 |
177 | This code borrows from the following open source projects, whom we would like to thank:
178 |
179 | - [pytorch-template](https://github.com/victoresque/pytorch-template)
180 |
--------------------------------------------------------------------------------
/e2vid/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_model import *
2 |
--------------------------------------------------------------------------------
/e2vid/base/base_model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 |
6 | class BaseModel(nn.Module):
7 | """
8 | Base class for all models
9 | """
10 | def __init__(self, config):
11 | super(BaseModel, self).__init__()
12 | self.config = config
13 | self.logger = logging.getLogger(self.__class__.__name__)
14 |
15 | def forward(self, *input):
16 | """
17 | Forward pass logic
18 |
19 | :return: Model output
20 | """
21 | raise NotImplementedError
22 |
23 | def summary(self):
24 | """
25 | Model summary
26 | """
27 | model_parameters = filter(lambda p: p.requires_grad, self.parameters())
28 | params = sum([np.prod(p.size()) for p in model_parameters])
29 | self.logger.info('Trainable parameters: {}'.format(params))
30 | self.logger.info(self)
31 |
--------------------------------------------------------------------------------
/e2vid/image_reconstructor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import cv2
3 | import numpy as np
4 | from e2vid.model.model import *
5 | from e2vid.utils.inference_utils import CropParameters, EventPreprocessor, IntensityRescaler, ImageFilter, ImageDisplay, \
6 | ImageWriter, UnsharpMaskFilter
7 | from e2vid.utils.inference_utils import upsample_color_image, \
8 | merge_channels_into_color_image # for color reconstruction
9 | from e2vid.utils.util import robust_min, robust_max
10 | from e2vid.utils.timers import CudaTimer, cuda_timers
11 | from os.path import join
12 | from collections import deque
13 | import torchvision.transforms as transforms
14 | import albumentations as A
15 | from PIL import Image
16 |
17 |
18 | class ImageReconstructor:
19 | def __init__(self, model, height, width, num_bins, device, options, augmentation=False, standardization=False):
20 |
21 | self.model = model
22 | self.use_gpu = options.use_gpu
23 | self.device = device
24 | self.height = height
25 | self.width = width
26 | self.num_bins = num_bins
27 |
28 | self.standardization = standardization
29 |
30 | self.augmentation = augmentation
31 | if self.augmentation:
32 | self.transform_a = A.Compose([
33 | A.GaussNoise(p=0.2),
34 | A.RandomBrightnessContrast(p=0.5),
35 | A.OneOf(
36 | [
37 | A.Sharpen(p=1),
38 | A.Blur(blur_limit=3, p=1),
39 | A.MotionBlur(blur_limit=3, p=1),
40 | ],
41 | p=0.5,
42 | )
43 | ])
44 | self.img_transform = transforms.Compose([
45 | transforms.Grayscale(),
46 | transforms.ToTensor()
47 | ])
48 |
49 | self.initialize(self.height, self.width, options)
50 |
51 | def initialize(self, height, width, options):
52 | # print('== Image reconstruction == ')
53 | # print('Image size: {}x{}'.format(self.height, self.width))
54 |
55 | self.no_recurrent = options.no_recurrent
56 | if self.no_recurrent:
57 | print('!!Recurrent connection disabled!!')
58 |
59 | self.perform_color_reconstruction = options.color # whether to perform color reconstruction (only use this with the DAVIS346color)
60 | if self.perform_color_reconstruction:
61 | if options.auto_hdr:
62 | print('!!Warning: disabling auto HDR for color reconstruction!!')
63 | options.auto_hdr = False # disable auto_hdr for color reconstruction (otherwise, each channel will be normalized independently)
64 |
65 | self.crop = CropParameters(self.width, self.height, self.model.num_encoders)
66 |
67 | self.last_states_for_each_channel = {'grayscale': None}
68 |
69 | if self.perform_color_reconstruction:
70 | self.crop_halfres = CropParameters(int(width / 2), int(height / 2),
71 | self.model.num_encoders)
72 | for channel in ['R', 'G', 'B', 'W']:
73 | self.last_states_for_each_channel[channel] = None
74 |
75 | self.event_preprocessor = EventPreprocessor(options)
76 | self.intensity_rescaler = IntensityRescaler(options)
77 | self.image_filter = ImageFilter(options)
78 | self.unsharp_mask_filter = UnsharpMaskFilter(options, device=self.device)
79 | # self.image_writer = ImageWriter(options)
80 | # self.image_display = ImageDisplay(options)
81 |
82 | def update_reconstruction(self, event_tensor, event_tensor_id=None, stamp=None):
83 | with torch.no_grad():
84 |
85 | with CudaTimer('Reconstruction'):
86 |
87 | with CudaTimer('NumPy (CPU) -> Tensor (GPU)'):
88 | events = event_tensor
89 | events = events.to(self.device)
90 |
91 | events = self.event_preprocessor(events)
92 |
93 | # Resize tensor to [1 x C x crop_size x crop_size] by applying zero padding
94 | events_for_each_channel = {'grayscale': self.crop.pad(events)}
95 | reconstructions_for_each_channel = {}
96 | # if self.perform_color_reconstruction:
97 | # events_for_each_channel['R'] = self.crop_halfres.pad(events[:, :, 0::2, 0::2])
98 | # events_for_each_channel['G'] = self.crop_halfres.pad(events[:, :, 0::2, 1::2])
99 | # events_for_each_channel['W'] = self.crop_halfres.pad(events[:, :, 1::2, 0::2])
100 | # events_for_each_channel['B'] = self.crop_halfres.pad(events[:, :, 1::2, 1::2])
101 |
102 | # Reconstruct new intensity image for each channel (grayscale + RGBW if color reconstruction is enabled)
103 | for channel in events_for_each_channel.keys():
104 | with CudaTimer('Inference'):
105 | new_predicted_frame, states, latent = self.model(events_for_each_channel[channel],
106 | self.last_states_for_each_channel[channel])
107 |
108 | if self.no_recurrent:
109 | self.last_states_for_each_channel[channel] = None
110 | else:
111 | self.last_states_for_each_channel[channel] = states
112 |
113 | # Output reconstructed image
114 | # crop = self.crop if channel == 'grayscale' else self.crop_halfres
115 |
116 | # Unsharp mask (on GPU)
117 | # new_predicted_frame = self.unsharp_mask_filter(new_predicted_frame)
118 |
119 | # Intensity rescaler (on GPU)
120 | # new_predicted_frame = self.intensity_rescaler(new_predicted_frame)
121 |
122 | with CudaTimer('Tensor (GPU) -> NumPy (CPU)'):
123 | reconstructions_for_each_channel[channel] = new_predicted_frame
124 | # reconstructions_for_each_channel[channel] = new_predicted_frame.cpu().numpy()
125 |
126 | # if self.perform_color_reconstruction:
127 | # out = merge_channels_into_color_image(reconstructions_for_each_channel)
128 | # else:
129 | out = reconstructions_for_each_channel['grayscale']
130 |
131 | if self.standardization:  # per-sample min-max scaling of the reconstruction to [0, 1]
132 | batch_size, height, width = out.size(0), out.size(2), out.size(3)
133 | out = out.view(out.size(0), -1)
134 | out -= out.min(1, keepdim=True)[0]
135 | out /= out.max(1, keepdim=True)[0].clamp(min=1e-8)  # clamp guards against a constant (all-equal) reconstruction
136 | out = out.view(batch_size, 1, height, width)
137 |
138 | # Imin = torch.min(out).item()
139 | # Imax = torch.max(out).item()
140 | # out = 255.0 * (out - Imin) / (Imax - Imin)
141 | # out.clamp_(0.0, 255.0)
142 | # out = out.byte() # convert to 8-bit tensor
143 | # out = out.float().div(255)
144 |
145 | # mean = [0.5371 for i in range(out.shape[0])]
146 | # std = [0.1540 for i in range(out.shape[0])]
147 | # stand_transform = transforms.Normalize(mean=mean, std=std)
148 | # out = stand_transform(out.squeeze(1)).unsqueeze(1)
149 | # out = torch.clamp(out, min=-1.0, max=1.0)
150 | # out = (out + 1.0) / 2.0
151 |
152 | if self.augmentation:
153 | for i in range(out.shape[0]):
154 | img_aug = out[i].cpu()
155 | img_aug = transforms.ToPILImage()(img_aug)
156 | img_aug = np.array(img_aug)
157 | img_aug = self.transform_a(image=img_aug)["image"]
158 | img_aug = Image.fromarray(img_aug.astype('uint8')).convert('RGB')
159 | out[i] = self.img_transform(img_aug).to(self.device)
160 | # Post-processing, e.g. bilateral filter (on CPU)
161 | # out = torch.from_numpy(self.image_filter(out)).to(self.device)
162 |
163 | return out, states, latent
164 |
165 |
166 | class PostProcessor:
167 | def __init__(self, device, options):
168 | self.device = device
169 | self.unsharp_mask_filter = UnsharpMaskFilter(options, device=self.device)
170 | self.intensity_rescaler = IntensityRescaler(options)
171 | self.image_filter = ImageFilter(options)
172 |
173 | def process(self, new_predicted_frame):
174 | with torch.no_grad():
175 | # Unsharp mask (on GPU)
176 | new_predicted_frame = self.unsharp_mask_filter(new_predicted_frame)
177 |
178 | # Intensity rescaler (on GPU)
179 | new_predicted_frame = self.intensity_rescaler(new_predicted_frame)
180 |
181 | out = new_predicted_frame.cpu().numpy()
182 |
183 | # Post-processing, e.g. bilateral filter (on CPU)
184 | out = torch.from_numpy(self.image_filter(out)).to(self.device)
185 | return out
186 |
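# Usage sketch (added for illustration; not part of the repository file above): the per-sample
# min-max scaling that ImageReconstructor.update_reconstruction applies when `standardization`
# is enabled, reproduced on a dummy batch. The tensor shape and value range are assumptions.
import torch

out = torch.rand(2, 1, 256, 512) * 3.0 - 1.0                # stand-in for raw reconstructions
batch_size, height, width = out.size(0), out.size(2), out.size(3)
flat = out.view(batch_size, -1)
flat = flat - flat.min(1, keepdim=True)[0]                  # shift each sample so it starts at 0
flat = flat / flat.max(1, keepdim=True)[0].clamp(min=1e-8)  # scale each sample to [0, 1]
out = flat.view(batch_size, 1, height, width)
print(out.min().item(), out.max().item())                   # ~0.0 and 1.0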
--------------------------------------------------------------------------------
/e2vid/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uzh-rpg/ess/46bf1eed677d869c4733c89d2284e43ea27f97bd/e2vid/model/__init__.py
--------------------------------------------------------------------------------
/e2vid/model/model.py:
--------------------------------------------------------------------------------
1 | from e2vid.base import BaseModel
2 | import torch.nn as nn
3 | import torch
4 | from e2vid.model.unet import UNet, UNetRecurrent, UNetDecoder, UNetTask
5 | from os.path import join
6 | from e2vid.model.submodules import ConvLSTM, ResidualBlock, ConvLayer, UpsampleConvLayer, TransposedConvLayer
7 |
8 |
9 | class BaseE2VID(BaseModel):
10 | def __init__(self, config):
11 | super().__init__(config)
12 |
13 | assert('num_bins' in config)
14 | self.num_bins = int(config['num_bins']) # number of bins in the voxel grid event tensor
15 |
16 | try:
17 | self.skip_type = str(config['skip_type'])
18 | except KeyError:
19 | self.skip_type = 'sum'
20 |
21 | try:
22 | self.num_encoders = int(config['num_encoders'])
23 | except KeyError:
24 | self.num_encoders = 4
25 |
26 | try:
27 | self.base_num_channels = int(config['base_num_channels'])
28 | except KeyError:
29 | self.base_num_channels = 32
30 |
31 | try:
32 | self.num_residual_blocks = int(config['num_residual_blocks'])
33 | except KeyError:
34 | self.num_residual_blocks = 2
35 |
36 | try:
37 | self.norm = str(config['norm'])
38 | except KeyError:
39 | self.norm = None
40 |
41 | try:
42 | self.use_upsample_conv = bool(config['use_upsample_conv'])
43 | except KeyError:
44 | self.use_upsample_conv = True
45 |
46 |
47 | class E2VID(BaseE2VID):
48 | def __init__(self, config):
49 | super(E2VID, self).__init__(config)
50 |
51 | self.unet = UNet(num_input_channels=self.num_bins,
52 | num_output_channels=1,
53 | skip_type=self.skip_type,
54 | activation='sigmoid',
55 | num_encoders=self.num_encoders,
56 | base_num_channels=self.base_num_channels,
57 | num_residual_blocks=self.num_residual_blocks,
58 | norm=self.norm,
59 | use_upsample_conv=self.use_upsample_conv)
60 |
61 | def forward(self, event_tensor, prev_states=None):
62 | """
63 | :param event_tensor: N x num_bins x H x W
64 | :return: a predicted image of size N x 1 x H x W, taking values in [0,1].
65 | """
66 | return self.unet.forward(event_tensor), None
67 |
68 |
69 | class E2VIDRecurrent(BaseE2VID):
70 | """
71 | Recurrent, UNet-like architecture where each encoder is followed by a ConvLSTM or ConvGRU.
72 | """
73 |
74 | def __init__(self, config):
75 | super(E2VIDRecurrent, self).__init__(config)
76 |
77 | try:
78 | self.recurrent_block_type = str(config['recurrent_block_type'])
79 | except KeyError:
80 | self.recurrent_block_type = 'convlstm' # or 'convgru'
81 |
82 | self.unetrecurrent = UNetRecurrent(num_input_channels=self.num_bins,
83 | num_output_channels=1,
84 | skip_type=self.skip_type,
85 | recurrent_block_type=self.recurrent_block_type,
86 | activation='sigmoid',
87 | num_encoders=self.num_encoders,
88 | base_num_channels=self.base_num_channels,
89 | num_residual_blocks=self.num_residual_blocks,
90 | norm=self.norm,
91 | use_upsample_conv=self.use_upsample_conv)
92 |
93 | def forward(self, event_tensor, prev_states):
94 | """
95 | :param event_tensor: N x num_bins x H x W
96 | :param prev_states: previous ConvLSTM state for each encoder module
97 | :return: reconstructed image in [0,1], the updated recurrent states, and a dict of latent features.
98 | """
99 | img_pred, states, latent = self.unetrecurrent.forward(event_tensor, prev_states)
100 | return img_pred, states, latent
101 |
102 | class E2VIDDecoder(BaseE2VID):
103 | """
104 | Decoder-only part of the recurrent UNet: applies the residual blocks, decoders and prediction layer to pre-computed encoder features.
105 | """
106 |
107 | def __init__(self, config):
108 | super(E2VIDDecoder, self).__init__(config)
109 |
110 | try:
111 | self.recurrent_block_type = str(config['recurrent_block_type'])
112 | except KeyError:
113 | self.recurrent_block_type = 'convlstm' # or 'convgru'
114 |
115 | self.unetrecurrent = UNetDecoder(num_input_channels=self.num_bins,
116 | num_output_channels=1,
117 | skip_type=self.skip_type,
118 | recurrent_block_type=self.recurrent_block_type,
119 | activation='sigmoid',
120 | num_encoders=self.num_encoders,
121 | base_num_channels=self.base_num_channels,
122 | num_residual_blocks=self.num_residual_blocks,
123 | norm=self.norm,
124 | use_upsample_conv=self.use_upsample_conv)
125 |
126 | def forward(self, x, blocks, head):
127 | """
128 | :param x: bottleneck feature map produced by the encoder
129 | :param blocks, head: encoder feature maps and head output used for the skip connections
130 | :return: reconstructed image, taking values in [0,1].
131 | """
132 | img_pred = self.unetrecurrent.forward(x, blocks, head)
133 | return img_pred
134 |
135 | class E2VIDTask(BaseE2VID):
136 | """
137 | Task head built on the UNet decoder: predicts semantic segmentation logits from pre-computed encoder features.
138 | """
139 |
140 | def __init__(self, config):
141 | super(E2VIDTask, self).__init__(config)
142 |
143 | try:
144 | self.recurrent_block_type = str(config['recurrent_block_type'])
145 | except KeyError:
146 | self.recurrent_block_type = 'convlstm' # or 'convgru'
147 |
148 | self.unetrecurrent = UNetTask(num_input_channels=self.num_bins,
149 | num_output_channels=13,
150 | skip_type=self.skip_type,
151 | recurrent_block_type=self.recurrent_block_type,
152 | activation='sigmoid',
153 | num_encoders=self.num_encoders,
154 | base_num_channels=self.base_num_channels,
155 | num_residual_blocks=self.num_residual_blocks,
156 | norm=self.norm,
157 | use_upsample_conv=self.use_upsample_conv)
158 |
159 | def forward(self, input_dict):
160 | """
161 | :param input_dict: dict of feature maps keyed by downsampling factor (1, 2, 4, 8)
162 | :return: dict of semantic segmentation predictions keyed by downsampling factor,
163 | with num_output_channels class logits at the full-resolution entry.
164 | """
165 | semseg_pred = self.unetrecurrent.forward(input_dict)
166 | return semseg_pred
167 |
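# Usage sketch (added for illustration; not part of the repository file above): building the
# recurrent E2VID model from a minimal config dict and running it over a few dummy voxel grids
# while carrying the ConvLSTM states. The number of bins, batch size and 256x512 resolution are
# assumptions; all other settings fall back to the defaults defined in BaseE2VID.
import torch
from e2vid.model.model import E2VIDRecurrent

config = {'num_bins': 5}
model = E2VIDRecurrent(config).eval()

states = None                                    # no recurrent state before the first window
with torch.no_grad():
    for _ in range(3):                           # three consecutive event windows
        voxel_grid = torch.rand(1, config['num_bins'], 256, 512)
        img, states, latent = model(voxel_grid, states)
print(img.shape)                                 # torch.Size([1, 1, 256, 512]), values in [0, 1]
print(sorted(latent.keys()))                     # [1, 2, 4, 8]: features keyed by downsampling factor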
--------------------------------------------------------------------------------
/e2vid/model/submodules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as f
4 | from torch.nn import init
5 |
6 |
7 | class ConvLayer(nn.Module):
8 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, activation='relu', norm=None):
9 | super(ConvLayer, self).__init__()
10 |
11 | bias = False if norm == 'BN' else True
12 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
13 | if activation is not None:
14 | self.activation = getattr(torch, activation, torch.relu)  # fall back to torch.relu if the name is not a torch function
15 | else:
16 | self.activation = None
17 |
18 | self.norm = norm
19 | if norm == 'BN':
20 | self.norm_layer = nn.BatchNorm2d(out_channels)
21 | elif norm == 'IN':
22 | self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True)
23 |
24 | def forward(self, x):
25 | out = self.conv2d(x)
26 | if self.norm in ['BN', 'IN']:
27 | out = self.norm_layer(out)
28 |
29 | if self.activation is not None:
30 | out = self.activation(out)
31 | return out
32 |
33 |
34 | class TransposedConvLayer(nn.Module):
35 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, activation='relu', norm=None):
36 | super(TransposedConvLayer, self).__init__()
37 |
38 | bias = False if norm == 'BN' else True
39 | self.transposed_conv2d = nn.ConvTranspose2d(
40 | in_channels, out_channels, kernel_size, stride=2, padding=padding, output_padding=1, bias=bias)
41 |
42 | if activation is not None:
43 | self.activation = getattr(torch, activation, torch.relu)  # fall back to torch.relu if the name is not a torch function
44 | else:
45 | self.activation = None
46 |
47 | self.norm = norm
48 | if norm == 'BN':
49 | self.norm_layer = nn.BatchNorm2d(out_channels)
50 | elif norm == 'IN':
51 | self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True)
52 |
53 | def forward(self, x):
54 | out = self.transposed_conv2d(x)
55 |
56 | if self.norm in ['BN', 'IN']:
57 | out = self.norm_layer(out)
58 |
59 | if self.activation is not None:
60 | out = self.activation(out)
61 |
62 | return out
63 |
64 |
65 | class UpsampleConvLayer(nn.Module):
66 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, activation='relu', norm=None):
67 | super(UpsampleConvLayer, self).__init__()
68 |
69 | bias = False if norm == 'BN' else True
70 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
71 |
72 | if activation is not None:
73 | self.activation = getattr(torch, activation, torch.relu)  # fall back to torch.relu if the name is not a torch function
74 | else:
75 | self.activation = None
76 |
77 | self.norm = norm
78 | if norm == 'BN':
79 | self.norm_layer = nn.BatchNorm2d(out_channels)
80 | elif norm == 'IN':
81 | self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True)
82 |
83 | def forward(self, x):
84 | x_upsampled = f.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
85 | out = self.conv2d(x_upsampled)
86 |
87 | if self.norm in ['BN', 'IN']:
88 | out = self.norm_layer(out)
89 |
90 | if self.activation is not None:
91 | out = self.activation(out)
92 |
93 | return out
94 |
95 |
96 | class RecurrentConvLayer(nn.Module):
97 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0,
98 | recurrent_block_type='convlstm', activation='relu', norm=None):
99 | super(RecurrentConvLayer, self).__init__()
100 |
101 | assert(recurrent_block_type in ['convlstm', 'convgru'])
102 | self.recurrent_block_type = recurrent_block_type
103 | if self.recurrent_block_type == 'convlstm':
104 | RecurrentBlock = ConvLSTM
105 | else:
106 | RecurrentBlock = ConvGRU
107 | self.conv = ConvLayer(in_channels, out_channels, kernel_size, stride, padding, activation, norm)
108 | self.recurrent_block = RecurrentBlock(input_size=out_channels, hidden_size=out_channels, kernel_size=3)
109 |
110 | def forward(self, x, prev_state):
111 | x = self.conv(x)
112 | # print('x', x)
113 | state = self.recurrent_block(x, prev_state)
114 | x = state[0] if self.recurrent_block_type == 'convlstm' else state
115 | return x, state
116 |
117 |
118 | class DownsampleRecurrentConvLayer(nn.Module):
119 | def __init__(self, in_channels, out_channels, kernel_size=3, recurrent_block_type='convlstm', padding=0, activation='relu'):
120 | super(DownsampleRecurrentConvLayer, self).__init__()
121 |
122 | self.activation = getattr(torch, activation, torch.relu)  # fall back to torch.relu if the name is not a torch function
123 |
124 | assert(recurrent_block_type in ['convlstm', 'convgru'])
125 | self.recurrent_block_type = recurrent_block_type
126 | if self.recurrent_block_type == 'convlstm':
127 | RecurrentBlock = ConvLSTM
128 | else:
129 | RecurrentBlock = ConvGRU
130 | self.recurrent_block = RecurrentBlock(input_size=in_channels, hidden_size=out_channels, kernel_size=kernel_size)
131 |
132 | def forward(self, x, prev_state):
133 | state = self.recurrent_block(x, prev_state)
134 | x = state[0] if self.recurrent_block_type == 'convlstm' else state
135 | x = f.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)
136 | return self.activation(x), state
137 |
138 |
139 | # Residual block
140 | class ResidualBlock(nn.Module):
141 | def __init__(self, in_channels, out_channels, stride=1, downsample=None, norm=None):
142 | super(ResidualBlock, self).__init__()
143 | bias = False if norm == 'BN' else True
144 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=bias)
145 | self.norm = norm
146 | if norm == 'BN':
147 | self.bn1 = nn.BatchNorm2d(out_channels)
148 | self.bn2 = nn.BatchNorm2d(out_channels)
149 | elif norm == 'IN':
150 | self.bn1 = nn.InstanceNorm2d(out_channels)
151 | self.bn2 = nn.InstanceNorm2d(out_channels)
152 |
153 | self.relu = nn.ReLU(inplace=True)
154 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=bias)
155 | self.downsample = downsample
156 |
157 | def forward(self, x):
158 | residual = x
159 | out = self.conv1(x)
160 | if self.norm in ['BN', 'IN']:
161 | out = self.bn1(out)
162 | out = self.relu(out)
163 | out = self.conv2(out)
164 | if self.norm in ['BN', 'IN']:
165 | out = self.bn2(out)
166 |
167 | if self.downsample:
168 | residual = self.downsample(x)
169 |
170 | out += residual
171 | out = self.relu(out)
172 | return out
173 |
174 |
175 | class ConvLSTM(nn.Module):
176 | """Adapted from: https://github.com/Atcold/pytorch-CortexNet/blob/master/model/ConvLSTMCell.py """
177 |
178 | def __init__(self, input_size, hidden_size, kernel_size):
179 | super(ConvLSTM, self).__init__()
180 |
181 | self.input_size = input_size
182 | self.hidden_size = hidden_size
183 | pad = kernel_size // 2
184 |
185 | # cache a tensor filled with zeros to avoid reallocating memory at each inference step if --no-recurrent is enabled
186 | self.zero_tensors = {}
187 |
188 | self.Gates = nn.Conv2d(input_size + hidden_size, 4 * hidden_size, kernel_size, padding=pad)
189 |
190 | def forward(self, input_, prev_state=None):
191 | # get batch and spatial sizes
192 | batch_size = input_.data.size()[0]
193 | spatial_size = input_.data.size()[2:]
194 |
195 | # generate empty prev_state, if None is provided
196 | if prev_state is None:
197 |
198 | # create the zero tensor if it has not been created already
199 | state_size = tuple([batch_size, self.hidden_size] + list(spatial_size))
200 | if state_size not in self.zero_tensors:
201 | # allocate a tensor with size `spatial_size`, filled with zero (if it has not been allocated already)
202 | self.zero_tensors[state_size] = (
203 | torch.zeros(state_size).to(input_.device),
204 | torch.zeros(state_size).to(input_.device)
205 | )
206 |
207 | prev_state = self.zero_tensors[tuple(state_size)]
208 |
209 | prev_hidden, prev_cell = prev_state
210 |
211 | # data size is [batch, channel, height, width]
212 | stacked_inputs = torch.cat((input_, prev_hidden), 1)
213 | gates = self.Gates(stacked_inputs)
214 |
215 | # chunk across channel dimension
216 | in_gate, remember_gate, out_gate, cell_gate = gates.chunk(4, 1)
217 |
218 | # apply sigmoid non linearity
219 | in_gate = torch.sigmoid(in_gate)
220 | remember_gate = torch.sigmoid(remember_gate)
221 | out_gate = torch.sigmoid(out_gate)
222 |
223 | # apply tanh non linearity
224 | cell_gate = torch.tanh(cell_gate)
225 |
226 | # compute current cell and hidden state
227 | cell = (remember_gate * prev_cell) + (in_gate * cell_gate)
228 | hidden = out_gate * torch.tanh(cell)
229 |
230 | return hidden, cell
231 |
232 |
233 | class ConvGRU(nn.Module):
234 | """
235 | Generate a convolutional GRU cell
236 | Adapted from: https://github.com/jacobkimmel/pytorch_convgru/blob/master/convgru.py
237 | """
238 |
239 | def __init__(self, input_size, hidden_size, kernel_size):
240 | super().__init__()
241 | padding = kernel_size // 2
242 | self.input_size = input_size
243 | self.hidden_size = hidden_size
244 | self.reset_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding)
245 | self.update_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding)
246 | self.out_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding)
247 |
248 | init.orthogonal_(self.reset_gate.weight)
249 | init.orthogonal_(self.update_gate.weight)
250 | init.orthogonal_(self.out_gate.weight)
251 | init.constant_(self.reset_gate.bias, 0.)
252 | init.constant_(self.update_gate.bias, 0.)
253 | init.constant_(self.out_gate.bias, 0.)
254 |
255 | def forward(self, input_, prev_state):
256 |
257 | # get batch and spatial sizes
258 | batch_size = input_.data.size()[0]
259 | spatial_size = input_.data.size()[2:]
260 |
261 | # generate empty prev_state, if None is provided
262 | if prev_state is None:
263 | state_size = [batch_size, self.hidden_size] + list(spatial_size)
264 | prev_state = torch.zeros(state_size).to(input_.device)
265 |
266 | # data size is [batch, channel, height, width]
267 | stacked_inputs = torch.cat([input_, prev_state], dim=1)
268 | update = torch.sigmoid(self.update_gate(stacked_inputs))
269 | reset = torch.sigmoid(self.reset_gate(stacked_inputs))
270 | out_inputs = torch.tanh(self.out_gate(torch.cat([input_, prev_state * reset], dim=1)))
271 | new_state = prev_state * (1 - update) + out_inputs * update
272 |
273 | return new_state
274 |
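# Usage sketch (added for illustration; not part of the repository file above): stepping the
# ConvLSTM and ConvGRU cells defined above over a short feature-map sequence while threading
# their state. Channel count and spatial size are assumptions.
import torch
from e2vid.model.submodules import ConvLSTM, ConvGRU

lstm = ConvLSTM(input_size=32, hidden_size=32, kernel_size=3)
gru = ConvGRU(input_size=32, hidden_size=32, kernel_size=3)

state_lstm, state_gru = None, None               # both cells create zero states when None is passed
for _ in range(4):
    x = torch.rand(1, 32, 64, 64)
    state_lstm = lstm(x, state_lstm)             # returns the (hidden, cell) pair
    state_gru = gru(x, state_gru)                # returns only the new hidden state
print(state_lstm[0].shape, state_gru.shape)      # torch.Size([1, 32, 64, 64]) twice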
--------------------------------------------------------------------------------
/e2vid/model/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as f
4 | from torch.nn import init
5 | from .submodules import ConvLayer, UpsampleConvLayer, TransposedConvLayer, RecurrentConvLayer, ResidualBlock, ConvLSTM, ConvGRU
6 |
7 |
8 | def skip_concat(x1, x2):
9 | return torch.cat([x1, x2], dim=1)
10 |
11 |
12 | def skip_sum(x1, x2):
13 | return x1 + x2
14 |
15 |
16 | class BaseUNet(nn.Module):
17 | def __init__(self, num_input_channels, num_output_channels=1, skip_type='sum', activation='sigmoid',
18 | num_encoders=4, base_num_channels=32, num_residual_blocks=2, norm=None, use_upsample_conv=True):
19 | super(BaseUNet, self).__init__()
20 |
21 | self.num_input_channels = num_input_channels
22 | self.num_output_channels = num_output_channels
23 | self.skip_type = skip_type
24 | self.apply_skip_connection = skip_sum if self.skip_type == 'sum' else skip_concat
25 | self.activation = activation
26 | self.norm = norm
27 |
28 | if use_upsample_conv:
29 | print('Using UpsampleConvLayer (slow, but no checkerboard artefacts)')
30 | self.UpsampleLayer = UpsampleConvLayer
31 | else:
32 | print('Using TransposedConvLayer (fast, with checkerboard artefacts)')
33 | self.UpsampleLayer = TransposedConvLayer
34 |
35 | self.num_encoders = num_encoders
36 | self.base_num_channels = base_num_channels
37 | self.num_residual_blocks = num_residual_blocks
38 | self.max_num_channels = self.base_num_channels * pow(2, self.num_encoders)
39 |
40 | assert(self.num_input_channels > 0)
41 | assert(self.num_output_channels > 0)
42 |
43 | self.encoder_input_sizes = []
44 | for i in range(self.num_encoders):
45 | self.encoder_input_sizes.append(self.base_num_channels * pow(2, i))
46 |
47 | self.encoder_output_sizes = [self.base_num_channels * pow(2, i + 1) for i in range(self.num_encoders)]
48 |
49 | self.activation = getattr(torch, self.activation, torch.sigmoid)  # fall back to torch.sigmoid if the name is not a torch function
50 |
51 | def build_resblocks(self):
52 | self.resblocks = nn.ModuleList()
53 | for i in range(self.num_residual_blocks):
54 | self.resblocks.append(ResidualBlock(self.max_num_channels, self.max_num_channels, norm=self.norm))
55 |
56 | def build_decoders(self):
57 | decoder_input_sizes = list(reversed([self.base_num_channels * pow(2, i + 1) for i in range(self.num_encoders)]))
58 |
59 | self.decoders = nn.ModuleList()
60 | for input_size in decoder_input_sizes:
61 | self.decoders.append(self.UpsampleLayer(input_size if self.skip_type == 'sum' else 2 * input_size,
62 | input_size // 2,
63 | kernel_size=5, padding=2, norm=self.norm))
64 |
65 | def build_prediction_layer(self):
66 | self.pred = ConvLayer(self.base_num_channels if self.skip_type == 'sum' else 2 * self.base_num_channels,
67 | self.num_output_channels, 1, activation=None, norm=self.norm)
68 |
69 |
70 | class UNet(BaseUNet):
71 | def __init__(self, num_input_channels, num_output_channels=1, skip_type='sum', activation='sigmoid',
72 | num_encoders=4, base_num_channels=32, num_residual_blocks=2, norm=None, use_upsample_conv=True):
73 | super(UNet, self).__init__(num_input_channels, num_output_channels, skip_type, activation,
74 | num_encoders, base_num_channels, num_residual_blocks, norm, use_upsample_conv)
75 |
76 | self.head = ConvLayer(self.num_input_channels, self.base_num_channels,
77 | kernel_size=5, stride=1, padding=2) # N x C x H x W -> N x 32 x H x W
78 |
79 | self.encoders = nn.ModuleList()
80 | for input_size, output_size in zip(self.encoder_input_sizes, self.encoder_output_sizes):
81 | self.encoders.append(ConvLayer(input_size, output_size, kernel_size=5,
82 | stride=2, padding=2, norm=self.norm))
83 |
84 | self.build_resblocks()
85 | self.build_decoders()
86 | self.build_prediction_layer()
87 |
88 | def forward(self, x):
89 | """
90 | :param x: N x num_input_channels x H x W
91 | :return: N x num_output_channels x H x W
92 | """
93 |
94 | # head
95 | x = self.head(x)
96 | head = x
97 |
98 | # encoder
99 | blocks = []
100 | for i, encoder in enumerate(self.encoders):
101 | x = encoder(x)
102 | blocks.append(x)
103 |
104 | # residual blocks
105 | for resblock in self.resblocks:
106 | x = resblock(x)
107 |
108 | # decoder
109 | for i, decoder in enumerate(self.decoders):
110 | x = decoder(self.apply_skip_connection(x, blocks[self.num_encoders - i - 1]))
111 |
112 | img = self.activation(self.pred(self.apply_skip_connection(x, head)))
113 |
114 | return img
115 |
116 |
117 | class UNetRecurrent(BaseUNet):
118 | """
119 | Recurrent UNet architecture where every encoder is followed by a recurrent convolutional block,
120 | such as a ConvLSTM or a ConvGRU.
121 | Symmetric, skip connections on every encoding layer.
122 | """
123 |
124 | def __init__(self, num_input_channels, num_output_channels=1, skip_type='sum',
125 | recurrent_block_type='convlstm', activation='sigmoid', num_encoders=4, base_num_channels=32,
126 | num_residual_blocks=2, norm=None, use_upsample_conv=True):
127 | super(UNetRecurrent, self).__init__(num_input_channels, num_output_channels, skip_type, activation,
128 | num_encoders, base_num_channels, num_residual_blocks, norm,
129 | use_upsample_conv)
130 |
131 | self.head = ConvLayer(self.num_input_channels, self.base_num_channels,
132 | kernel_size=5, stride=1, padding=2) # N x C x H x W -> N x 32 x H x W
133 |
134 | self.encoders = nn.ModuleList()
135 | for input_size, output_size in zip(self.encoder_input_sizes, self.encoder_output_sizes):
136 | self.encoders.append(RecurrentConvLayer(input_size, output_size,
137 | kernel_size=5, stride=2, padding=2,
138 | recurrent_block_type=recurrent_block_type,
139 | norm=self.norm))
140 |
141 | self.build_resblocks()
142 | self.build_decoders()
143 | self.build_prediction_layer()
144 |
145 | def forward(self, x, prev_states):
146 | """
147 | :param x: N x num_input_channels x H x W
148 | :param prev_states: previous LSTM states for every encoder layer
149 | :return: N x num_output_channels x H x W
150 | """
151 |
152 | # head
153 | x = self.head(x)
154 | head = x
155 |
156 | if prev_states is None:
157 | prev_states = [None] * self.num_encoders
158 |
159 | # encoder
160 | blocks = []
161 | states = []
162 | for i, encoder in enumerate(self.encoders):
163 | x, state = encoder(x, prev_states[i])
164 |
165 | blocks.append(x)
166 | states.append(state)
167 |
168 | # residual blocks
169 | for resblock in self.resblocks:
170 | x = resblock(x)
171 |
172 | latent = {1: head, 2: blocks[0], 4: blocks[1], 8: blocks[2]}  # intermediate features keyed by downsampling factor w.r.t. the input
173 |
174 | # decoder
175 | for i, decoder in enumerate(self.decoders):
176 | x = decoder(self.apply_skip_connection(x, blocks[self.num_encoders - i - 1]))
177 |
178 | # tail
179 | img = self.activation(self.pred(self.apply_skip_connection(x, head)))
180 |
181 | return img, states, latent
182 |
183 | class UNetDecoder(BaseUNet):
184 | """
185 | Decoder part of the recurrent UNet: residual blocks, upsampling decoders and prediction layer
186 | operating on pre-computed encoder features.
187 | Skip connections are applied on every decoding layer.
188 | """
189 |
190 | def __init__(self, num_input_channels, num_output_channels=1, skip_type='sum',
191 | recurrent_block_type='convlstm', activation='sigmoid', num_encoders=4, base_num_channels=32,
192 | num_residual_blocks=2, norm=None, use_upsample_conv=True):
193 | super(UNetDecoder, self).__init__(num_input_channels, num_output_channels, skip_type, activation,
194 | num_encoders, base_num_channels, num_residual_blocks, norm,
195 | use_upsample_conv)
196 |
197 | self.build_resblocks()
198 | self.build_decoders()
199 | self.build_prediction_layer()
200 |
201 | def forward(self, x, blocks, head):
202 | """
203 | :param x: bottleneck feature map from the encoder
204 | :param blocks, head: encoder feature maps and head output used for the skip connections
205 | :return: N x num_output_channels x H x W
206 | """
207 |
208 | # residual blocks
209 | for resblock in self.resblocks:
210 | x = resblock(x)
211 |
212 | # decoder
213 | for i, decoder in enumerate(self.decoders):
214 | x = decoder(self.apply_skip_connection(x, blocks[self.num_encoders - i - 1]))
215 |
216 | # tail
217 | img = self.activation(self.pred(self.apply_skip_connection(x, head)))
218 |
219 | return img
220 |
221 |
222 | class UNetTask(BaseUNet):
223 | """
224 | Task head built on the UNet decoder: predicts semantic segmentation logits from a dict of
225 | pre-computed encoder features keyed by downsampling factor.
226 | Skip connections are applied on every decoding layer.
227 | """
228 |
229 | def __init__(self, num_input_channels, num_output_channels=1, skip_type='sum',
230 | recurrent_block_type='convlstm', activation='sigmoid', num_encoders=4, base_num_channels=32,
231 | num_residual_blocks=2, norm=None, use_upsample_conv=True):
232 | super(UNetTask, self).__init__(num_input_channels, num_output_channels, skip_type, activation,
233 | num_encoders, base_num_channels, num_residual_blocks, norm,
234 | use_upsample_conv)
235 | self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
236 | self.build_resblocks()
237 | self.build_decoders()
238 | self.build_prediction_layer_semseg()
239 |
240 | def build_prediction_layer_semseg(self):
241 | self.pred_semseg = torch.nn.Sequential(ConvLayer(self.base_num_channels if self.skip_type == 'sum' else 2 * self.base_num_channels,
242 | self.base_num_channels if self.skip_type == 'sum' else 2 * self.base_num_channels, 1, activation='relu', norm=self.norm),
243 | ConvLayer(
244 | self.base_num_channels if self.skip_type == 'sum' else 2 * self.base_num_channels,
245 | self.num_output_channels, 1, activation=None, norm=None)
246 | )
247 |
248 | def update_skip_dict(self, skips, x, sz_in):
249 | rem, scale = sz_in % x.shape[3], sz_in // x.shape[3]
250 | assert rem == 0
251 | skips[scale] = x
252 |
253 | def forward(self, input_dict):
254 | """
255 | :param input_dict: dict of feature maps keyed by downsampling factor (1, 2, 4, 8)
256 | :return: dict of outputs keyed by downsampling factor, with the semantic segmentation
257 | logits (N x num_output_channels x H x W) at the full-resolution entry
258 | """
259 | sz_in = input_dict[1].shape[3]
260 |
261 | x = input_dict[8]
262 | out = {8: x}
263 | blocks = [input_dict[2], input_dict[4], input_dict[8]]
264 | head = torch.zeros((input_dict[2].shape[0], 32, 256, 512)).to(self.device)
265 |
266 | # residual blocks
267 | for resblock in self.resblocks:
268 | x = resblock(x)
269 |
270 | # decoder
271 | for i, decoder in enumerate(self.decoders):
272 | x = decoder(self.apply_skip_connection(x, blocks[self.num_encoders - i - 1]))
273 | self.update_skip_dict(out, x, sz_in)
274 |
275 | # tail
276 | pred = self.pred_semseg(self.apply_skip_connection(x, head))
277 | self.update_skip_dict(out, pred, sz_in)
278 |
279 | return out
280 |
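# Usage sketch (added for illustration; not part of the repository file above): how UNetRecurrent
# and the stand-alone UNetDecoder above fit together. The recurrent forward pass returns a latent
# dict keyed by downsampling factor; the decoder consumes those features to produce an image
# again. The choice of 3 encoders, 5 input bins and the 256x512 input are assumptions, and the
# untrained weights of the two modules are independent here, so only wiring and shapes matter.
import torch
from e2vid.model.unet import UNetRecurrent, UNetDecoder

encoder = UNetRecurrent(num_input_channels=5, num_encoders=3).eval()
decoder = UNetDecoder(num_input_channels=5, num_encoders=3).eval()

with torch.no_grad():
    voxel_grid = torch.rand(1, 5, 256, 512)
    img, states, latent = encoder(voxel_grid, prev_states=None)
    # latent = {1: head, 2: ..., 4: ..., 8: ...}: rebuild the decoder inputs from it
    img_again = decoder(latent[8], blocks=[latent[2], latent[4], latent[8]], head=latent[1])
print(img.shape, img_again.shape)                # torch.Size([1, 1, 256, 512]) twice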
--------------------------------------------------------------------------------
/e2vid/options/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uzh-rpg/ess/46bf1eed677d869c4733c89d2284e43ea27f97bd/e2vid/options/__init__.py
--------------------------------------------------------------------------------
/e2vid/options/inference_options.py:
--------------------------------------------------------------------------------
1 | def set_inference_options(parser):
2 |
3 | parser.add_argument('-o', '--output_folder', default=None, type=str) # if None, will not write the images to disk
4 | parser.add_argument('--dataset_name', default='reconstruction', type=str)
5 |
6 | parser.add_argument('--use_gpu', dest='use_gpu', action='store_true')
7 | parser.set_defaults(use_gpu=True)
8 |
9 | """ Display """
10 | parser.add_argument('--display', dest='display', action='store_true')
11 | parser.set_defaults(display=False)
12 |
13 | parser.add_argument('--show_events', dest='show_events', action='store_true')
14 | parser.set_defaults(show_events=False)
15 |
16 | parser.add_argument('--event_display_mode', default='red-blue', type=str,
17 | help="Event display mode ('red-blue' or 'grayscale')")
18 |
19 | parser.add_argument('--num_bins_to_show', default=-1, type=int,
20 | help="Number of bins of the voxel grid to show when displaying events (-1 means show all the bins).")
21 |
22 | parser.add_argument('--display_border_crop', default=0, type=int,
23 | help="Remove the outer border of size display_border_crop before displaying image.")
24 |
25 | parser.add_argument('--display_wait_time', default=1, type=int,
26 | help="Time to wait after each call to cv2.imshow, in milliseconds (default: 1)")
27 |
28 | """ Post-processing / filtering """
29 |
30 | # (optional) path to a text file containing the locations of hot pixels to ignore
31 | parser.add_argument('--hot_pixels_file', default=None, type=str)
32 |
33 | # (optional) unsharp mask
34 | parser.add_argument('--unsharp_mask_amount', default=0.3, type=float)
35 | parser.add_argument('--unsharp_mask_sigma', default=1.0, type=float)
36 |
37 | # (optional) bilateral filter
38 | parser.add_argument('--bilateral_filter_sigma', default=0.0, type=float)
39 |
40 | # (optional) flip the event tensors vertically
41 | parser.add_argument('--flip', dest='flip', action='store_true')
42 | parser.set_defaults(flip=False)
43 |
44 | """ Tone mapping (i.e. rescaling of the image intensities)"""
45 | parser.add_argument('--Imin', default=0.0, type=float,
46 | help="Min intensity for intensity rescaling (linear tone mapping).")
47 | parser.add_argument('--Imax', default=1.0, type=float,
48 | help="Max intensity value for intensity rescaling (linear tone mapping).")
49 | parser.add_argument('--auto_hdr', dest='auto_hdr', action='store_true',
50 | help="If True, will compute Imin and Imax automatically.")
51 | parser.set_defaults(auto_hdr=False)
52 | parser.add_argument('--auto_hdr_median_filter_size', default=10, type=int,
53 | help="Size of the median filter window used to smooth temporally Imin and Imax")
54 |
55 | """ Perform color reconstruction? (only use this flag with the DAVIS346color) """
56 | parser.add_argument('--color', dest='color', action='store_true')
57 | parser.set_defaults(color=False)
58 |
59 | """ Advanced parameters """
60 | # disable normalization of input event tensors (saves a bit of time, but may produce slightly worse results)
61 | parser.add_argument('--no-normalize', dest='no_normalize', action='store_true')
62 | parser.set_defaults(no_normalize=False)
63 |
64 | # disable recurrent connection (will severely degrade the results; for testing purposes only)
65 | parser.add_argument('--no-recurrent', dest='no_recurrent', action='store_true')
66 | parser.set_defaults(no_recurrent=False)
67 |
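# Usage sketch (added for illustration; not part of the repository file above): obtaining the
# default e2vid inference options programmatically, e.g. for constructing an ImageReconstructor
# outside of run_reconstruction.py. Parsing an empty argument list yields the defaults set above.
import argparse
from e2vid.options.inference_options import set_inference_options

parser = argparse.ArgumentParser(description='E2VID inference options')
set_inference_options(parser)
options = parser.parse_args([])                       # empty argv -> all defaults
print(options.use_gpu, options.no_recurrent)          # True False
print(options.Imin, options.Imax, options.auto_hdr)   # 0.0 1.0 False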
--------------------------------------------------------------------------------
/e2vid/pretrained/.gitignore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uzh-rpg/ess/46bf1eed677d869c4733c89d2284e43ea27f97bd/e2vid/pretrained/.gitignore
--------------------------------------------------------------------------------
/e2vid/run_reconstruction.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from e2vid.utils.loading_utils import load_model, get_device
3 | import numpy as np
4 | import argparse
5 | import pandas as pd
6 | from e2vid.utils.event_readers import FixedSizeEventReader, FixedDurationEventReader
7 | from e2vid.utils.inference_utils import events_to_voxel_grid, events_to_voxel_grid_pytorch
8 | from e2vid.utils.timers import Timer
9 | import time
10 | from e2vid.image_reconstructor import ImageReconstructor
11 | from e2vid.options.inference_options import set_inference_options
12 |
13 |
14 | if __name__ == "__main__":
15 |
16 | parser = argparse.ArgumentParser(
17 | description='Evaluating a trained network')
18 | parser.add_argument('-c', '--path_to_model', required=True, type=str,
19 | help='path to model weights')
20 | parser.add_argument('-i', '--input_file', required=True, type=str)
21 | parser.add_argument('--fixed_duration', dest='fixed_duration', action='store_true')
22 | parser.set_defaults(fixed_duration=False)
23 | parser.add_argument('-N', '--window_size', default=None, type=int,
24 | help="Size of each event window, in number of events. Ignored if --fixed_duration=True")
25 | parser.add_argument('-T', '--window_duration', default=33.33, type=float,
26 | help="Duration of each event window, in milliseconds. Ignored if --fixed_duration=False")
27 | parser.add_argument('--num_events_per_pixel', default=0.35, type=float,
28 | help='in case N (window size) is not specified, it will be \
29 | automatically computed as N = width * height * num_events_per_pixel')
30 | parser.add_argument('--skipevents', default=0, type=int)
31 | parser.add_argument('--suboffset', default=0, type=int)
32 | parser.add_argument('--compute_voxel_grid_on_cpu', dest='compute_voxel_grid_on_cpu', action='store_true')
33 | parser.set_defaults(compute_voxel_grid_on_cpu=False)
34 |
35 | set_inference_options(parser)
36 |
37 | args = parser.parse_args()
38 |
39 | # Read sensor size from the first line of the event file
40 | path_to_events = args.input_file
41 |
42 | header = pd.read_csv(path_to_events, delim_whitespace=True, header=None, names=['width', 'height'],
43 | dtype={'width': int, 'height': int},  # np.int was removed in recent NumPy releases
44 | nrows=1)
45 | width, height = header.values[0]
46 | print('Sensor size: {} x {}'.format(width, height))
47 |
48 | # Load model
49 | model, _ = load_model(args.path_to_model)  # load_model returns (model, decoder)
50 | device = get_device(args.use_gpu)
51 |
52 | model = model.to(device)
53 | model.eval()
54 |
55 | reconstructor = ImageReconstructor(model, height, width, model.num_bins, device, args)
56 |
57 | """ Read chunks of events using Pandas """
58 |
59 | # Loop through the events and reconstruct images
60 | N = args.window_size
61 | if not args.fixed_duration:
62 | if N is None:
63 | N = int(width * height * args.num_events_per_pixel)
64 | print('Will use {} events per tensor (automatically estimated with num_events_per_pixel={:0.2f}).'.format(
65 | N, args.num_events_per_pixel))
66 | else:
67 | print('Will use {} events per tensor (user-specified)'.format(N))
68 | mean_num_events_per_pixel = float(N) / float(width * height)
69 | if mean_num_events_per_pixel < 0.1:
70 | print('!!Warning!! the number of events used ({}) seems to be low compared to the sensor size. \
71 | The reconstruction results might be suboptimal.'.format(N))
72 | elif mean_num_events_per_pixel > 1.5:
73 | print('!!Warning!! the number of events used ({}) seems to be high compared to the sensor size. \
74 | The reconstruction results might be suboptimal.'.format(N))
75 |
76 | initial_offset = args.skipevents
77 | sub_offset = args.suboffset
78 | start_index = initial_offset + sub_offset
79 |
80 | if args.compute_voxel_grid_on_cpu:
81 | print('Will compute voxel grid on CPU.')
82 |
83 | if args.fixed_duration:
84 | event_window_iterator = FixedDurationEventReader(path_to_events,
85 | duration_ms=args.window_duration,
86 | start_index=start_index)
87 | else:
88 | event_window_iterator = FixedSizeEventReader(path_to_events, num_events=N, start_index=start_index)
89 |
90 | with Timer('Processing entire dataset'):
91 | for event_window in event_window_iterator:
92 | print(event_window.shape)
93 | last_timestamp = event_window[-1, 0]
94 |
95 | with Timer('Building event tensor'):
96 | if args.compute_voxel_grid_on_cpu:
97 | event_tensor = events_to_voxel_grid(event_window,
98 | num_bins=model.num_bins,
99 | width=width,
100 | height=height)
101 | event_tensor = torch.from_numpy(event_tensor)
102 | else:
103 | event_tensor = events_to_voxel_grid_pytorch(event_window,
104 | num_bins=model.num_bins,
105 | width=width,
106 | height=height,
107 | device=device)
108 |
109 | num_events_in_window = event_window.shape[0]
110 | reconstructor.update_reconstruction(event_tensor, start_index + num_events_in_window, last_timestamp)
111 |
112 | start_index += num_events_in_window
113 |
--------------------------------------------------------------------------------
/e2vid/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uzh-rpg/ess/46bf1eed677d869c4733c89d2284e43ea27f97bd/e2vid/utils/__init__.py
--------------------------------------------------------------------------------
/e2vid/utils/event_readers.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import zipfile
3 | from os.path import splitext
4 | import numpy as np
5 | from .timers import Timer
6 |
7 |
8 | class FixedSizeEventReader:
9 | """
10 | Reads events from a '.txt' or '.zip' file, and packages the events into
11 | non-overlapping event windows, each containing a fixed number of events.
12 | """
13 |
14 | def __init__(self, path_to_event_file, num_events=10000, start_index=0):
15 | print('Will use fixed size event windows with {} events'.format(num_events))
16 | print('Output frame rate: variable')
17 | self.iterator = pd.read_csv(path_to_event_file, delim_whitespace=True, header=None,
18 | names=['t', 'x', 'y', 'pol'],
19 | dtype={'t': np.float64, 'x': np.int16, 'y': np.int16, 'pol': np.int16},
20 | engine='c',
21 | skiprows=start_index + 1, chunksize=num_events, nrows=None, memory_map=True)
22 |
23 | def __iter__(self):
24 | return self
25 |
26 | def __next__(self):
27 | with Timer('Reading event window from file'):
28 | event_window = self.iterator.__next__().values
29 | return event_window
30 |
31 |
32 | class FixedDurationEventReader:
33 | """
34 | Reads events from a '.txt' or '.zip' file, and packages the events into
35 | non-overlapping event windows, each of a fixed duration.
36 |
37 | **Note**: This reader is much slower than the FixedSizeEventReader.
38 | The reason is that the latter can use Pandas' very efficient chunk-based reading scheme implemented in C.
39 | """
40 |
41 | def __init__(self, path_to_event_file, duration_ms=50.0, start_index=0):
42 | print('Will use fixed duration event windows of size {:.2f} ms'.format(duration_ms))
43 | print('Output frame rate: {:.1f} Hz'.format(1000.0 / duration_ms))
44 | file_extension = splitext(path_to_event_file)[1]
45 | assert(file_extension in ['.txt', '.zip'])
46 | self.is_zip_file = (file_extension == '.zip')
47 |
48 | if self.is_zip_file: # '.zip'
49 | self.zip_file = zipfile.ZipFile(path_to_event_file)
50 | files_in_archive = self.zip_file.namelist()
51 | assert(len(files_in_archive) == 1) # make sure there is only one text file in the archive
52 | self.event_file = self.zip_file.open(files_in_archive[0], 'r')
53 | else:
54 | self.event_file = open(path_to_event_file, 'r')
55 |
56 | # ignore header + the first start_index lines
57 | for i in range(1 + start_index):
58 | self.event_file.readline()
59 |
60 | self.last_stamp = None
61 | self.duration_s = duration_ms / 1000.0
62 |
63 | def __iter__(self):
64 | return self
65 |
66 | def __del__(self):
67 | if self.is_zip_file:
68 | self.zip_file.close()
69 |
70 | self.event_file.close()
71 |
72 | def __next__(self):
73 | with Timer('Reading event window from file'):
74 | event_list = []
75 | for line in self.event_file:
76 | if self.is_zip_file:
77 | line = line.decode("utf-8")
78 | t, x, y, pol = line.split(' ')
79 | t, x, y, pol = float(t), int(x), int(y), int(pol)
80 | event_list.append([t, x, y, pol])
81 | if self.last_stamp is None:
82 | self.last_stamp = t
83 | if t > self.last_stamp + self.duration_s:
84 | self.last_stamp = t
85 | event_window = np.array(event_list)
86 | return event_window
87 |
88 | raise StopIteration
89 |
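# Usage sketch (added for illustration; not part of the repository file above): reading fixed-size
# event windows from a small synthetic text file in the expected format (a "width height" header
# line followed by one "t x y polarity" line per event). The path and event values are made up.
import os
import tempfile
from e2vid.utils.event_readers import FixedSizeEventReader

path = os.path.join(tempfile.gettempdir(), 'dummy_events.txt')
with open(path, 'w') as f:
    f.write('640 480\n')                              # sensor size header
    for i in range(25):
        f.write('{:.6f} {} {} {}\n'.format(i * 1e-3, i % 640, i % 480, i % 2))

reader = FixedSizeEventReader(path, num_events=10)
for window in reader:
    print(window.shape)                               # (10, 4), (10, 4), (5, 4): rows of [t, x, y, pol]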
--------------------------------------------------------------------------------
/e2vid/utils/loading_utils.py:
--------------------------------------------------------------------------------
1 | from e2vid.model.model import *
2 | from collections import OrderedDict
3 |
4 |
5 | def load_model(path_to_model, return_task=False):
6 | print('Loading model {}...'.format(path_to_model))
7 | raw_model = torch.load(path_to_model)
8 | arch = raw_model['arch']
9 |
10 | try:
11 | model_type = raw_model['model']
12 | except KeyError:
13 | model_type = raw_model['config']['model']
14 |
15 | # instantiate model
16 | model = eval(arch)(model_type)
17 |
18 | # load model weights
19 | model.load_state_dict(raw_model['state_dict'])
20 |
21 | # E2VID Decoder
22 | decoder = E2VIDDecoder(model_type)
23 | decoder.load_state_dict(raw_model['state_dict'], strict=False)
24 |
25 | if return_task:
26 | # E2VID Task
27 | task = E2VIDTask(model_type)
28 | new_dict = copyStateDict(raw_model['state_dict'])
29 | keys = []
30 | for k, v in new_dict.items():
31 | if k.startswith('unetrecurrent.pred'):
32 | continue
33 | keys.append(k)
34 |
35 | new_dict = {k: new_dict[k] for k in keys}
36 | task.load_state_dict(new_dict, strict=False)
37 | return model, decoder, task
38 | return model, decoder
39 |
40 |
41 | def get_device(use_gpu):
42 | if use_gpu and torch.cuda.is_available():
43 | device = torch.device('cuda:0')
44 | else:
45 | device = torch.device('cpu')
46 | print('Device:', device)
47 |
48 | return device
49 |
50 | def copyStateDict(state_dict):
51 | if list(state_dict.keys())[0].startswith('module'):
52 | start_idx = 1
53 | else:
54 | start_idx = 0
55 | new_state_dict = OrderedDict()
56 | for k,v in state_dict.items():
57 | name = '.'.join(k.split('.')[start_idx:])
58 |
59 | new_state_dict[name] = v
60 | return new_state_dict
61 |
62 |
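# Usage sketch (added for illustration; not part of the repository file above): loading a
# pretrained E2VID checkpoint and splitting it into the recurrent model, the decoder and the
# task head as returned by load_model. The checkpoint path is a placeholder for a downloaded
# E2VID checkpoint and must exist for this to run.
import torch
from e2vid.utils.loading_utils import load_model, get_device

device = get_device(use_gpu=True)                     # falls back to CPU if CUDA is unavailable
model, decoder, task = load_model('e2vid/pretrained/E2VID_lightweight.pth.tar', return_task=True)
model = model.to(device).eval()
print(type(model).__name__, model.num_bins)           # e.g. E2VIDRecurrent 5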
--------------------------------------------------------------------------------
/e2vid/utils/path_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | def ensure_dir(path):
5 | if not os.path.exists(path):
6 | os.makedirs(path)
7 |
--------------------------------------------------------------------------------
/e2vid/utils/timers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | import numpy as np
4 | import atexit
5 |
6 | cuda_timers = {}
7 | timers = {}
8 |
9 |
10 | class CudaTimer:
11 | def __init__(self, timer_name=''):
12 | self.timer_name = timer_name
13 | if self.timer_name not in cuda_timers:
14 | cuda_timers[self.timer_name] = []
15 |
16 | self.start = torch.cuda.Event(enable_timing=True)
17 | self.end = torch.cuda.Event(enable_timing=True)
18 |
19 | def __enter__(self):
20 | self.start.record()
21 | return self
22 |
23 | def __exit__(self, *args):
24 | self.end.record()
25 | torch.cuda.synchronize()
26 | cuda_timers[self.timer_name].append(self.start.elapsed_time(self.end))
27 |
28 |
29 | class Timer:
30 | def __init__(self, timer_name=''):
31 | self.timer_name = timer_name
32 | if self.timer_name not in timers:
33 | timers[self.timer_name] = []
34 |
35 | def __enter__(self):
36 | self.start = time.time()
37 | return self
38 |
39 | def __exit__(self, *args):
40 | self.end = time.time()
41 | self.interval = self.end - self.start # measured in seconds
42 | self.interval *= 1000.0 # convert to milliseconds
43 | timers[self.timer_name].append(self.interval)
44 |
45 |
46 | def print_timing_info():
47 | print('== Timing statistics ==')
48 | for timer_name, timing_values in [*cuda_timers.items(), *timers.items()]:
49 | timing_value = np.mean(np.array(timing_values))
50 | if timing_value < 1000.0:
51 | print('{}: {:.2f} ms'.format(timer_name, timing_value))
52 | else:
53 | print('{}: {:.2f} s'.format(timer_name, timing_value / 1000.0))
54 |
55 |
56 | # this will print all the timer values upon termination of any program that imported this file
57 | atexit.register(print_timing_info)
58 |
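# Usage sketch (added for illustration; not part of the repository file above): the Timer context
# manager accumulates wall-clock timings (in milliseconds) per name and prints a summary at
# program exit via the atexit hook registered above. CudaTimer works the same way but needs CUDA.
import time
from e2vid.utils.timers import Timer, timers

for _ in range(3):
    with Timer('toy workload'):
        time.sleep(0.01)
print(len(timers['toy workload']))                    # 3 recorded intervals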
--------------------------------------------------------------------------------
/e2vid/utils/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from math import fabs
3 |
4 |
5 | def robust_min(img, p=5):
6 | return np.percentile(img.ravel(), p)
7 |
8 |
9 | def robust_max(img, p=95):
10 | return np.percentile(img.ravel(), p)
11 |
12 |
13 | def normalize(img, m=10, M=90):
14 | return np.clip((img - robust_min(img, m)) / (robust_max(img, M) - robust_min(img, m)), 0.0, 1.0)
15 |
16 |
17 | def first_element_greater_than(values, req_value):
18 | """Returns the pair (i, values[i]) such that i is the minimum value that satisfies values[i] >= req_value.
19 | Returns (len(values), None) if there is no such i.
20 | Note: this function assumes that values is a sorted array!"""
21 | i = np.searchsorted(values, req_value)
22 | val = values[i] if i < len(values) else None
23 | return (i, val)
24 |
25 |
26 | def last_element_less_than(values, req_value):
27 | """Returns the pair (i, values[i]) such that i is the maximum value that satisfies values[i] <= req_value.
28 | Returns (-1, None) if there is no such i.
29 | Note: this function assumes that values is a sorted array!"""
30 | i = np.searchsorted(values, req_value, side='right') - 1
31 | val = values[i] if i >= 0 else None
32 | return (i, val)
33 |
34 |
35 | def closest_element_to(values, req_value):
36 | """Returns the tuple (i, values[i], diff) such that i is the closest value to req_value,
37 | and diff = |values(i) - req_value|
38 | Note: this function assumes that values is a sorted array!"""
39 | assert(len(values) > 0)
40 |
41 | i = np.searchsorted(values, req_value, side='left')
42 | if i > 0 and (i == len(values) or fabs(req_value - values[i - 1]) < fabs(req_value - values[i])):
43 | idx = i - 1
44 | val = values[i - 1]
45 | else:
46 | idx = i
47 | val = values[i]
48 |
49 | diff = fabs(val - req_value)
50 | return (idx, val, diff)
51 |
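# Usage sketch (added for illustration; not part of the repository file above): the robust
# percentile normalization and the sorted-array search helpers above, on small synthetic inputs.
import numpy as np
from e2vid.utils.util import normalize, first_element_greater_than, closest_element_to

img = np.random.randn(64, 64) * 0.2 + 0.5
print(normalize(img).min(), normalize(img).max())     # clipped to [0, 1] using the 10th/90th percentiles

timestamps = np.array([0.0, 0.1, 0.25, 0.4])
print(first_element_greater_than(timestamps, 0.2))    # (2, 0.25)
print(closest_element_to(timestamps, 0.32))           # (2, 0.25, ~0.07)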
--------------------------------------------------------------------------------
/evaluation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uzh-rpg/ess/46bf1eed677d869c4733c89d2284e43ea27f97bd/evaluation/__init__.py
--------------------------------------------------------------------------------
/evaluation/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def semseg_compute_confusion(y_hat_lbl, y_lbl, num_classes, ignore_label):
5 | assert torch.is_tensor(y_hat_lbl) and torch.is_tensor(y_lbl), 'Inputs must be torch tensors'
6 | assert y_lbl.device == y_hat_lbl.device, 'Input tensors have different device placement'
7 |
8 | assert y_hat_lbl.dim() == 3 or y_hat_lbl.dim() == 4 and y_hat_lbl.shape[1] == 1
9 | assert y_lbl.dim() == 3 or y_lbl.dim() == 4 and y_lbl.shape[1] == 1
10 | if y_hat_lbl.dim() == 4:
11 | y_hat_lbl = y_hat_lbl.squeeze(1)
12 | if y_lbl.dim() == 4:
13 | y_lbl = y_lbl.squeeze(1)
14 |
15 | mask = y_lbl != ignore_label
16 | y_hat_lbl = y_hat_lbl[mask]
17 | y_lbl = y_lbl[mask]
18 |
19 | # hack for bincounting 2 arrays together
20 | x = y_hat_lbl + num_classes * y_lbl
21 | bincount_2d = torch.bincount(x.long(), minlength=num_classes ** 2)
22 | assert bincount_2d.numel() == num_classes ** 2, 'Internal error'
23 | conf = bincount_2d.view((num_classes, num_classes)).long()
24 | return conf
25 |
26 |
27 | def semseg_accum_confusion_to_iou(confusion_accum):
28 | conf = confusion_accum.double()
29 | diag = conf.diag()
30 | iou_per_class = 100 * diag / (conf.sum(dim=1) + conf.sum(dim=0) - diag).clamp(min=1e-12)
31 | iou_mean = iou_per_class.mean()
32 | return iou_mean, iou_per_class
33 |
34 | def semseg_accum_confusion_to_acc(confusion_accum):
35 | conf = confusion_accum.double()
36 | diag = conf.diag()
37 | acc = 100 * diag.sum() / (conf.sum(dim=1).sum()).clamp(min=1e-12)
38 | return acc
39 |
40 | class MetricsSemseg:
41 | def __init__(self, num_classes, ignore_label, class_names):
42 | self.num_classes = num_classes
43 | self.ignore_label = ignore_label
44 | self.class_names = class_names
45 | self.metrics_acc = None
46 |
47 | def reset(self):
48 | self.metrics_acc = None
49 |
50 | def update_batch(self, y_hat_lbl, y_lbl):
51 | with torch.no_grad():
52 | metrics_batch = semseg_compute_confusion(y_hat_lbl, y_lbl, self.num_classes, self.ignore_label).cpu()
53 | if self.metrics_acc is None:
54 | self.metrics_acc = metrics_batch
55 | else:
56 | self.metrics_acc += metrics_batch
57 |
58 | def get_metrics_summary(self):
59 | iou_mean, iou_per_class = semseg_accum_confusion_to_iou(self.metrics_acc)
60 | out = {self.class_names[i]: iou for i, iou in enumerate(iou_per_class)}
61 | out['mean_iou'] = iou_mean
62 | acc = semseg_accum_confusion_to_acc(self.metrics_acc)
63 | out['acc'] = acc
64 | out['cm'] = self.metrics_acc
65 | return out
66 |
67 |
68 |
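# Usage sketch (added for illustration; not part of the repository file above): accumulating a
# confusion matrix over batches of predicted / ground-truth label maps and reading out mIoU and
# accuracy. The class count, class names, ignore label and label shapes are assumptions.
import torch
from evaluation.metrics import MetricsSemseg

num_classes, ignore_label = 3, 255
metrics = MetricsSemseg(num_classes, ignore_label, class_names=['flat', 'object', 'sky'])

for _ in range(2):                                    # two dummy batches
    y_lbl = torch.randint(0, num_classes, (4, 64, 64))
    y_lbl[:, :2, :] = ignore_label                    # mark some pixels as ignored
    y_hat_lbl = torch.randint(0, num_classes, (4, 64, 64))
    metrics.update_batch(y_hat_lbl, y_lbl)

summary = metrics.get_metrics_summary()
print(summary['mean_iou'].item(), summary['acc'].item())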
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uzh-rpg/ess/46bf1eed677d869c4733c89d2284e43ea27f97bd/models/__init__.py
--------------------------------------------------------------------------------
/models/style_networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision.models as models
4 | import torch.nn.functional as f
5 |
6 | from models.submodules import InterpolationLayer
7 |
8 |
9 | class SemSegE2VID(nn.Module):
10 | def __init__(self, input_c, output_c, skip_connect=False, skip_type='sum', input_index_map=False):
11 | super(SemSegE2VID, self).__init__()
12 | self.skip_connect = skip_connect
13 | self.skip_type = skip_type
14 | self.apply_skip_connection = skip_sum if self.skip_type == 'sum' else skip_concat
15 | tch = input_c
16 | self.index_coords = None
17 | self.input_index_map = input_index_map
18 |
19 | if self.skip_connect:
20 | decoder_list_1 = []
21 | for i in range(0, 5): # 3, 5
22 | decoder_list_1 += [INSResBlock(tch, tch)]
23 | decoder_list_1 += [ReLUINSConv2d(tch, tch // 2, kernel_size=3, stride=1, padding=1)]
24 | self.decoder_scale_1 = torch.nn.Sequential(*decoder_list_1)
25 | self.decoder_scale_2 = nn.Sequential(ReLUINSConv2d(tch, tch // 2, kernel_size=3, stride=1, padding=1),
26 | ReLUINSConv2d(tch // 2, tch // 4, kernel_size=3, stride=1, padding=1))
27 | tch = tch // 2
28 | self.decoder_scale_3 = nn.Sequential(ReLUINSConv2d(tch, tch // 2, kernel_size=3, stride=1, padding=1),
29 | ReLUINSConv2d(tch // 2, tch // 2, kernel_size=3, stride=1, padding=1))
30 | tch = tch // 2
31 | self.decoder_scale_4 = nn.Sequential(ReLUINSConv2d(tch, tch // 2, kernel_size=3, stride=1, padding=1))
32 | tch = tch // 2
33 | self.decoder_scale_5 = nn.Sequential(
34 | torch.nn.Conv2d(tch, output_c, kernel_size=1, stride=1, padding=0))
35 | else:
36 | if self.input_index_map:
37 | tch += 2
38 | # Instance Norm
39 | decoder_list_1 = []
40 | for i in range(0, 3):
41 | decoder_list_1 += [INSResBlock(tch, tch)]
42 |
43 | self.decoder_scale_1 = torch.nn.Sequential(*decoder_list_1)
44 | tch = tch
45 | if self.input_index_map:
46 | self.decoder_scale_2 = nn.Sequential(InterpolationLayer(scale_factor=2, mode='nearest'),
47 | ReLUINSConv2d(tch, (tch - 2) // 2, kernel_size=3, stride=1,
48 | padding=1))
49 | tch = (tch - 2) // 2
50 | else:
51 | self.decoder_scale_2 = nn.Sequential(InterpolationLayer(scale_factor=2, mode='nearest'),
52 | ReLUINSConv2d(tch, tch // 2, kernel_size=3, stride=1, padding=1))
53 | tch = tch // 2
54 | self.decoder_scale_3 = nn.Sequential(InterpolationLayer(scale_factor=2, mode='nearest'),
55 | ReLUINSConv2d(tch, tch // 2, kernel_size=3, stride=1, padding=1))
56 | tch = tch // 2
57 | tch = tch
58 | self.decoder_scale_4 = nn.Sequential(InterpolationLayer(scale_factor=2, mode='nearest'),
59 | ReLUINSConv2d(tch, tch // 2, kernel_size=3, stride=1, padding=1))
60 | tch = tch // 2
61 | self.decoder_scale_5 = nn.Sequential(
62 | torch.nn.Conv2d(tch, output_c, kernel_size=1, stride=1, padding=0))
63 |
64 | def update_skip_dict(self, skips, x, sz_in):
65 | rem, scale = sz_in % x.shape[3], sz_in // x.shape[3]
66 | assert rem == 0
67 | skips[scale] = x
68 |
69 | def forward(self, input_dict):
70 | sz_in = input_dict[1].shape[3]
71 |
72 | x = input_dict[8]
73 | out = {8: x}
74 |
75 | if self.skip_connect:
76 | x = self.decoder_scale_1(x)
77 | x = f.interpolate(x, scale_factor=2, mode='nearest')
78 | x = self.apply_skip_connection(x, input_dict[4])
79 | x = self.decoder_scale_2(x)
80 | self.update_skip_dict(out, x, sz_in)
81 | x = f.interpolate(x, scale_factor=2, mode='nearest')
82 | x = self.apply_skip_connection(x, input_dict[2])
83 | x = self.decoder_scale_3(x)
84 | self.update_skip_dict(out, x, sz_in)
85 | x = f.interpolate(x, scale_factor=2, mode='nearest')
86 | x = self.decoder_scale_4(x)
87 | x = self.decoder_scale_5(x)
88 | self.update_skip_dict(out, x, sz_in)
89 | else:
90 | if self.input_index_map:
91 | if self.index_coords is None or self.index_coords.size(2) != x.size(2):
92 | x_coords = torch.arange(x.size(2), device=x.device, dtype=torch.float)
93 | y_coords = torch.arange(x.size(3), device=x.device, dtype=torch.float)
94 | self.index_coords = torch.stack(torch.meshgrid([x_coords,
95 | y_coords]), dim=0)
96 | self.index_coords = self.index_coords[None, :, :, :].repeat([x.size(0), 1, 1, 1])
97 | x = torch.cat([x, self.index_coords], dim=1)
98 |
99 | x = self.decoder_scale_1(x)
100 | x = self.decoder_scale_2(x)
101 | self.update_skip_dict(out, x, sz_in)
102 | x = self.decoder_scale_3(x)
103 | self.update_skip_dict(out, x, sz_in)
104 | x = self.decoder_scale_4(x)
105 | x = self.decoder_scale_5(x)
106 | self.update_skip_dict(out, x, sz_in)
107 | return out
108 |
109 |
110 | class StyleEncoderE2VID(nn.Module):
111 | def __init__(self, input_dim, skip_connect=False):
112 | super(StyleEncoderE2VID, self).__init__()
113 | conv_list = []
114 | self.skip_connect = skip_connect
115 |
116 |         conv_list += [nn.Conv2d(input_dim, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)]
117 |         resnet_children = list(models.resnet18(pretrained=True).children())  # instantiate the backbone once
118 |         conv_list += resnet_children[1:3] + resnet_children[4:5]
119 |         self.encoder_scale_1 = nn.Sequential(*conv_list)
120 |         self.encoder_scale_2 = resnet_children[5]
121 |         self.encoder_scale_3 = resnet_children[6]
122 |
123 | def update_skip_dict(self, skips, x, sz_in):
124 | rem, scale = sz_in % x.shape[3], sz_in // x.shape[3]
125 | assert rem == 0
126 | skips[scale] = x
127 |
128 | def forward(self, x):
129 | out = {1: x}
130 | sz_in = x.shape[3]
131 |
132 | if self.skip_connect:
133 | x = self.encoder_scale_1(x)
134 | self.update_skip_dict(out, x, sz_in)
135 | x = self.encoder_scale_2(x)
136 | self.update_skip_dict(out, x, sz_in)
137 | x = self.encoder_scale_3(x)
138 | self.update_skip_dict(out, x, sz_in)
139 | else:
140 | x = self.encoder_scale_1(x)
141 | x = self.encoder_scale_2(x)
142 | x = self.encoder_scale_3(x)
143 | self.update_skip_dict(out, x, sz_in)
144 |
145 | return out
146 |
147 |
148 | ####################################################################
149 | # -------------------------- Basic Blocks --------------------------
150 | ####################################################################
151 |
152 | def gaussian_weights_init(m):
153 | classname = m.__class__.__name__
154 |     if classname.find('Conv') == 0:
155 | m.weight.data.normal_(0.0, 0.02)
156 |
157 |
158 | class ReLUINSConv2d(nn.Module):
159 | def __init__(self, n_in, n_out, kernel_size, stride, padding=0):
160 | super(ReLUINSConv2d, self).__init__()
161 | model = []
162 | model += [nn.Conv2d(n_in, n_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=True)]
163 | model += [nn.InstanceNorm2d(n_out, affine=False)]
164 | model += [nn.ReLU(inplace=True)]
165 | self.model = nn.Sequential(*model)
166 | self.model.apply(gaussian_weights_init)
167 |
168 | def forward(self, x):
169 | return self.model(x)
170 |
171 |
172 | class INSResBlock(nn.Module):
173 | def conv3x3(self, inplanes, out_planes, stride=1):
174 | return [nn.Conv2d(inplanes, out_planes, kernel_size=3, stride=stride, padding=1)]
175 |
176 | def __init__(self, inplanes, planes, stride=1, dropout=0.0):
177 | super(INSResBlock, self).__init__()
178 | model = []
179 | model += self.conv3x3(inplanes, planes, stride)
180 | model += [nn.InstanceNorm2d(planes)]
181 | model += [nn.ReLU(inplace=True)]
182 | model += self.conv3x3(planes, planes)
183 | model += [nn.InstanceNorm2d(planes)]
184 | if dropout > 0:
185 | model += [nn.Dropout(p=dropout)]
186 | self.model = nn.Sequential(*model)
187 | self.model.apply(gaussian_weights_init)
188 |
189 | def forward(self, x):
190 | residual = x
191 | out = self.model(x)
192 | out += residual
193 | return out
194 |
195 |
196 | def skip_concat(x1, x2):
197 | return torch.cat([x1, x2], dim=1)
198 |
199 |
200 | def skip_sum(x1, x2):
201 | return x1 + x2
202 |
203 |
204 |
205 |
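Usage note (illustrative, not part of the file): the encoder and decoder classes above collect their multi-scale outputs in a plain dict keyed by the downsampling factor of each feature map relative to the input width, which is exactly what update_skip_dict computes; skip_sum and skip_concat are the two ways of merging encoder and decoder features that share a key. A minimal sketch with made-up shapes:

import torch

sz_in = 256                                  # width of the full-resolution input
skips = {}
x_quarter = torch.randn(1, 64, 64, 64)       # feature map at 1/4 of the input resolution
rem, scale = sz_in % x_quarter.shape[3], sz_in // x_quarter.shape[3]
assert rem == 0                              # the feature width must divide the input width
skips[scale] = x_quarter                     # skips == {4: x_quarter}

merged = skip_sum(skips[4], torch.randn(1, 64, 64, 64))   # decoder-style merge at scale 4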
--------------------------------------------------------------------------------
/models/submodules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as f
4 | from torch.nn import init
5 |
6 |
7 | class InterpolationLayer(nn.Module):
8 | def __init__(self, size=None, scale_factor=None, mode='nearest'):
9 | super(InterpolationLayer, self).__init__()
10 | self.interp = nn.functional.interpolate
11 | self.scale_factor = scale_factor
12 | self.size = size
13 | self.mode = mode
14 |
15 | def forward(self, x):
16 | if self.scale_factor is not None:
17 | if self.mode == 'nearest' and self.scale_factor == 2:
18 | return x[:, :, :, None, :, None].expand(-1, -1, -1, 2, -1, 2).reshape(x.size(0), x.size(1),
19 | 2 * x.size(2), 2 * x.size(3))
20 | else:
21 | return self.interp(x, scale_factor=self.scale_factor, mode=self.mode)
22 |
23 | else:
24 | return self.interp(x, size=self.size, mode=self.mode)
25 |
26 |
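Illustrative check (assumed shapes): for mode='nearest' and scale_factor=2, the expand/reshape fast path above reproduces F.interpolate exactly, it just avoids the generic interpolation kernel:

import torch
import torch.nn.functional as F

up = InterpolationLayer(scale_factor=2, mode='nearest')
x = torch.randn(1, 3, 5, 7)
y = up(x)                                                    # fast path: expand + reshape
assert y.shape == (1, 3, 10, 14)
assert torch.equal(y, F.interpolate(x, scale_factor=2, mode='nearest'))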
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.11.0
2 | aiohttp==3.7.4.post0
3 | alabaster==0.7.12
4 | albumentations==1.1.0
5 | appdirs==1.4.4
6 | astor==0.8.1
7 | async-timeout==3.0.1
8 | attrs==21.2.0
9 | Babel==2.8.0
10 | backcall==0.2.0
11 | beautifulsoup4==4.9.1
12 | blinker==1.4
13 | brotlipy==0.7.0
14 | cachetools==4.2.1
15 | certifi==2021.10.8
16 | cffi==1.14.0
17 | chardet==4.0.0
18 | charset-normalizer==2.0.4
19 | cityscapesScripts==2.2.0
20 | click==8.0.1
21 | coloredlogs==15.0.1
22 | configparser==5.2.0
23 | coverage==5.5
24 | cryptography==3.4.8
25 | cycler==0.10.0
26 | Cython==0.29.24
27 | decorator==4.4.2
28 | docker-pycreds==0.4.0
29 | docutils==0.16
30 | future==0.18.2
31 | gitdb==4.0.9
32 | GitPython==3.1.27
33 | google-auth==1.27.0
34 | google-auth-oauthlib==0.4.2
35 | google-pasta==0.2.0
36 | gql==0.2.0
37 | graphql-core==1.1
38 | grpcio==1.35.0
39 | h5py==2.10.0
40 | hdf5plugin==3.2.0
41 | humanfriendly==10.0
42 | idna==2.10
43 | imageio==2.13.3
44 | imagesize==1.2.0
45 | importlib-metadata==4.8.1
46 | ipython==7.16.1
47 | ipython-genutils==0.2.0
48 | javalang==0.13.0
49 | javasphinx==0.9.15
50 | jedi==0.17.0
51 | Jinja2==2.11.2
52 | joblib==0.16.0
53 | kiwisolver==1.2.0
54 | llvmlite==0.38.0rc1
55 | lxml==4.5.1
56 | Markdown==3.3.4
57 | MarkupSafe==1.1.1
58 | matplotlib==3.3.1
59 | mkl-fft
60 | mkl-random
61 | mkl-service
62 | multidict==5.1.0
63 | networkx==2.6.3
64 | numba==0.55.0rc1
65 | numpy==1.18.5
66 | nvidia-ml-py3==7.352.0
67 | oauthlib==3.1.0
68 | olefile==0.46
69 | opencv-python==4.3.0.38
70 | opencv-python-headless==4.5.4.60
71 | packaging==20.4
72 | parso==0.8.1
73 | pathtools==0.1.2
74 | pexpect==4.8.0
75 | pickleshare==0.7.5
76 | Pillow==7.2.0
77 | pip==20.1.1
78 | promise==2.3
79 | prompt-toolkit==3.0.8
80 | protobuf==3.18.1
81 | psutil==5.9.0
82 | ptyprocess==0.7.0
83 | pyasn1==0.4.8
84 | pyasn1-modules==0.2.8
85 | pycparser==2.20
86 | Pygments==2.7.3
87 | PyJWT==2.1.0
88 | pyOpenSSL==20.0.1
89 | pyparsing==2.4.7
90 | pyquaternion==0.9.9
91 | PySocks==1.7.1
92 | python-dateutil==2.8.1
93 | pytz==2020.1
94 | PyWavelets==1.2.0
95 | PyYAML==5.3.1
96 | qudida==0.0.4
97 | requests==2.25.1
98 | requests-oauthlib==1.3.0
99 | rsa==4.7.2
100 | scikit-image==0.19.0
101 | scikit-learn==0.23.2
102 | scipy==1.4.1
103 | sentry-sdk==1.5.7
104 | setproctitle==1.2.2
105 | setuptools==47.3.1.post20200622
106 | shortuuid==1.0.8
107 | sip==4.19.13
108 | six==1.15.0
109 | smmap==5.0.0
110 | snowballstemmer==2.0.0
111 | soupsieve==2.0.1
112 | Sphinx==2.4.4
113 | sphinxcontrib-applehelp==1.0.2
114 | sphinxcontrib-devhelp==1.0.2
115 | sphinxcontrib-htmlhelp==1.0.3
116 | sphinxcontrib-jsmath==1.0.1
117 | sphinxcontrib-katex==0.6.1
118 | sphinxcontrib-qthelp==1.0.3
119 | sphinxcontrib-serializinghtml==1.1.4
120 | subprocess32==3.5.4
121 | tensorboard==2.4.1
122 | tensorboard-plugin-wit==1.8.0
123 | tensorboardX==2.1
124 | termcolor==1.1.0
125 | threadpoolctl==2.1.0
126 | tifffile==2021.11.2
127 | timm==0.5.4
128 | torch==1.6.0
129 | torchvision==0.7.0
130 | tornado==6.1
131 | tqdm==4.48.2
132 | traitlets==4.3.3
133 | typing==3.7.4.3
134 | typing-extensions==3.7.4.3
135 | urllib3==1.26.3
136 | wandb==0.9.7
137 | watchdog==2.1.6
138 | wcwidth==0.2.5
139 | Werkzeug==1.0.1
140 | wheel==0.34.2
141 | wrapt==1.12.1
142 | yarl==1.6.3
143 | yaspin==2.1.0
144 | zipp==3.4.0
145 |
--------------------------------------------------------------------------------
/resources/ESS.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uzh-rpg/ess/46bf1eed677d869c4733c89d2284e43ea27f97bd/resources/ESS.png
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Example usage: CUDA_VISIBLE_DEVICES=1 python train.py --settings_file "config/settings_DDD17.yaml"
3 | """
4 | import argparse
5 | import wandb
6 |
7 | from config.settings import Settings
8 | from training.ess_trainer import ESSModel
9 | from training.ess_supervised_trainer import ESSSupervisedModel
10 |
11 | import numpy as np
12 | import torch
13 | import random
14 | import os
15 |
16 | # random seed
17 | seed_value = 6
18 | np.random.seed(seed_value)
19 | random.seed(seed_value)
20 | os.environ['PYTHONHASHSEED'] = str(seed_value)
21 |
22 | torch.manual_seed(seed_value)
23 | torch.cuda.manual_seed(seed_value)
24 | torch.cuda.manual_seed_all(seed_value)
25 | torch.backends.cudnn.deterministic = True
26 |
27 | def main():
28 | parser = argparse.ArgumentParser(description='Train network.')
29 | parser.add_argument('--settings_file', help='Path to settings yaml', required=True)
30 |
31 | args = parser.parse_args()
32 | settings_filepath = args.settings_file
33 | settings = Settings(settings_filepath, generate_log=True)
34 |
35 | wandb.init(name=(settings.dataset_name_b.split("_")[0] + '_' + settings.timestr), project="zhaoning_sun_semester_thesis", entity="zhasun", sync_tensorboard=True)
36 |
37 | if settings.model_name == 'ess':
38 | trainer = ESSModel(settings)
39 | elif settings.model_name == 'ess_supervised':
40 | trainer = ESSSupervisedModel(settings)
41 |
42 | else:
43 | raise ValueError('Model name %s specified in the settings file is not implemented' % settings.model_name)
44 |
45 | wandb.config = {
46 | "random_seed": seed_value,
47 | "lr_front": settings.lr_front,
48 | "lr_back": settings.lr_back,
49 | "batch_size_a": settings.batch_size_a,
50 | "batch_size_b": settings.batch_size_b
51 | }
52 |
53 | trainer.train()
54 |
55 |
56 | if __name__ == "__main__":
57 | main()
58 |
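The seeding block at the top of the script can also be read as a single helper; a sketch of an equivalent function (illustrative only, the script keeps the calls at module level):

import os
import random

import numpy as np
import torch


def set_seed(seed_value: int = 6):
    """Seed every RNG the training code touches so runs are repeatable."""
    np.random.seed(seed_value)
    random.seed(seed_value)
    os.environ['PYTHONHASHSEED'] = str(seed_value)
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
    torch.backends.cudnn.deterministic = True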
--------------------------------------------------------------------------------
/training/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uzh-rpg/ess/46bf1eed677d869c4733c89d2284e43ea27f97bd/training/__init__.py
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uzh-rpg/ess/46bf1eed677d869c4733c89d2284e43ea27f97bd/utils/__init__.py
--------------------------------------------------------------------------------
/utils/labels.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 |
3 | # a label and all meta information
4 | Label = namedtuple('Label', [
5 |
6 | 'name', # The identifier of this label, e.g. 'car', 'person', ... .
7 | # We use them to uniquely name a class
8 |
9 | 'id', # An integer ID that is associated with this label.
10 | # The IDs are used to represent the label in ground truth images
11 | # An ID of -1 means that this label does not have an ID and thus
12 | # is ignored when creating ground truth images (e.g. license plate).
13 | # Do not modify these IDs, since exactly these IDs are expected by the
14 | # evaluation server.
15 |
16 | 'trainId', # Feel free to modify these IDs as suitable for your method. Then create
17 | # ground truth images with train IDs, using the tools provided in the
18 | # 'preparation' folder. However, make sure to validate or submit results
19 | # to our evaluation server using the regular IDs above!
20 | # For trainIds, multiple labels might have the same ID. Then, these labels
21 | # are mapped to the same class in the ground truth images. For the inverse
22 | # mapping, we use the label that is defined first in the list below.
23 | # For example, mapping all void-type classes to the same ID in training,
24 | # might make sense for some approaches.
25 | # Max value is 255!
26 |
27 | 'category', # The name of the category that this label belongs to
28 |
29 | 'categoryId', # The ID of this category. Used to create ground truth images
30 | # on category level.
31 |
32 | 'hasInstances', # Whether this label distinguishes between single instances or not
33 |
34 | 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored
35 | # during evaluations or not
36 |
37 | 'color', # The color of this label
38 | ])
39 |
40 | labels_6_Cityscapes = [
41 | # name id trainId category catId hasInstances ignoreInEval color
42 | Label('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)),
43 | Label('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)),
44 | Label('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)),
45 | Label('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)),
46 | Label('static', 4, 255, 'void', 0, False, True, (0, 0, 0)),
47 | Label('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)),
48 | Label('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)),
49 | Label('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
50 | Label('sidewalk', 8, 0, 'flat', 1, False, False, (244, 35, 232)),
51 | Label('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)),
52 | Label('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)),
53 | Label('building', 11, 1, 'construction', 2, False, False, (70, 70, 70)),
54 | Label('wall', 12, 1, 'construction', 2, False, False, (102, 102, 156)),
55 | Label('fence', 13, 1, 'construction', 2, False, False, (190, 153, 153)),
56 | Label('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)),
57 | Label('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)),
58 | Label('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)),
59 | Label('pole', 17, 2, 'object', 3, False, False, (153, 153, 153)),
60 | Label('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)),
61 | Label('traffic light', 19, 2, 'object', 3, False, False, (250, 170, 30)),
62 | Label('traffic sign', 20, 2, 'object', 3, False, False, (220, 220, 0)),
63 | Label('vegetation', 21, 3, 'nature', 4, False, False, (107, 142, 35)),
64 | Label('terrain', 22, 3, 'nature', 4, False, False, (152, 251, 152)),
65 | Label('sky', 23, 1, 'sky', 5, False, False, (70, 130, 180)),
66 | Label('person', 24, 4, 'human', 6, True, False, (220, 20, 60)),
67 | Label('rider', 25, 4, 'human', 6, True, False, (255, 0, 0)),
68 | Label('car', 26, 5, 'vehicle', 7, True, False, (0, 0, 142)),
69 | Label('truck', 27, 5, 'vehicle', 7, True, False, (0, 0, 70)),
70 | Label('bus', 28, 5, 'vehicle', 7, True, False, (0, 60, 100)),
71 | Label('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)),
72 | Label('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)),
73 | Label('train', 31, 5, 'vehicle', 7, True, False, (0, 80, 100)),
74 | Label('motorcycle', 32, 5, 'vehicle', 7, True, False, (0, 0, 230)),
75 | Label('bicycle', 33, 5, 'vehicle', 7, True, False, (119, 11, 32)),
76 | Label('license plate', -1, 255, 'vehicle', 7, False, True, (0, 0, 142)),
77 | ]
78 |
79 | Id2label_6_Cityscapes = {label.id: label for label in reversed(labels_6_Cityscapes)}
80 |
81 | labels_11_Cityscapes = [
82 | # name id trainId category catId hasInstances ignoreInEval color
83 | Label('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)),
84 | Label('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)),
85 | Label('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)),
86 | Label('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)),
87 | Label('static', 4, 255, 'void', 0, False, True, (0, 0, 0)),
88 | Label('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)),
89 | Label('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)),
90 | Label('road', 7, 5, 'flat', 1, False, False, (128, 64, 128)),
91 | Label('sidewalk', 8, 6, 'flat', 1, False, False, (244, 35, 232)),
92 | Label('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)),
93 | Label('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)),
94 | Label('building', 11, 1, 'construction', 2, False, False, (70, 70, 70)),
95 | Label('wall', 12, 9, 'construction', 2, False, False, (102, 102, 156)),
96 | Label('fence', 13, 2, 'construction', 2, False, False, (190, 153, 153)),
97 | Label('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)),
98 | Label('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)),
99 | Label('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)),
100 | Label('pole', 17, 4, 'object', 3, False, False, (153, 153, 153)),
101 | Label('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)),
102 | Label('traffic light', 19, 10, 'object', 3, False, False, (250, 170, 30)),
103 | Label('traffic sign', 20, 10, 'object', 3, False, False, (220, 220, 0)),
104 | Label('vegetation', 21, 7, 'nature', 4, False, False, (107, 142, 35)),
105 | Label('terrain', 22, 7, 'nature', 4, False, False, (152, 251, 152)),
106 | Label('sky', 23, 0, 'sky', 5, False, False, (70, 130, 180)),
107 | Label('person', 24, 3, 'human', 6, True, False, (220, 20, 60)),
108 | Label('rider', 25, 3, 'human', 6, True, False, (255, 0, 0)),
109 | Label('car', 26, 8, 'vehicle', 7, True, False, (0, 0, 142)),
110 | Label('truck', 27, 8, 'vehicle', 7, True, False, (0, 0, 70)),
111 | Label('bus', 28, 8, 'vehicle', 7, True, False, (0, 60, 100)),
112 | Label('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)),
113 | Label('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)),
114 | Label('train', 31, 8, 'vehicle', 7, True, False, (0, 80, 100)),
115 | Label('motorcycle', 32, 8, 'vehicle', 7, True, False, (0, 0, 230)),
116 | Label('bicycle', 33, 8, 'vehicle', 7, True, False, (119, 11, 32)),
117 | Label('license plate', -1, 255, 'vehicle', 7, False, True, (0, 0, 142)),
118 | ]
119 |
120 | Id2label_11_Cityscapes = {label.id: label for label in reversed(labels_11_Cityscapes)}
121 |
122 |
123 | def fromIdToTrainId(imgin, Id2label):
124 | imgout = imgin.copy()
125 | for id in Id2label:
126 | imgout[imgin == id] = Id2label[id].trainId
127 | return imgout
128 |
129 |
130 | def shiftUpId(imgin):
131 | imgout = imgin.copy() + 1
132 | return imgout
133 |
134 |
135 | def shiftDownId(imgin):
136 | imgout = imgin.copy()
137 | imgout[imgin == 0] = 256 # ignore label + 1
138 | imgout -= 1
139 | return imgout
140 |
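A small worked example of the id-to-trainId mapping above (the trainId values follow the 11-class table; the input array is made up):

import numpy as np

# Cityscapes ids: road, sky, car / unlabeled, person, bicycle
seg_ids = np.array([[7, 23, 26],
                    [0, 24, 33]])
train_ids = fromIdToTrainId(seg_ids, Id2label_11_Cityscapes)
# -> [[  5,   0,   8],
#     [255,   3,   8]]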
--------------------------------------------------------------------------------
/utils/loss_functions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 |
5 |
6 | class TaskLoss(torch.nn.Module):
7 | def __init__(self, losses=['cross_entropy'], gamma=2.0, num_classes=13, alpha=None, weight=None, ignore_index=None, reduction='mean'):
8 | super(TaskLoss, self).__init__()
9 | self.losses = losses
10 | self.weight = weight
11 | self.gamma = gamma
12 | self.alpha = alpha
13 | self.ignore_index = ignore_index
14 | self.dice_loss = DiceLoss(num_classes=num_classes, ignore_index=self.ignore_index)
15 | self.ce_loss = torch.nn.CrossEntropyLoss(ignore_index=self.ignore_index)
16 |
17 | def forward(self, predict, target):
18 | total_loss = 0
19 | if 'dice' in self.losses:
20 | total_loss += self.dice_loss(predict, target)
21 | if 'cross_entropy' in self.losses:
22 | total_loss += self.ce_loss(predict, target)
23 |
24 | return total_loss
25 |
26 |
27 | class symJSDivLoss(torch.nn.Module):
28 | def __init__(self, ):
29 | super(symJSDivLoss, self).__init__()
30 | self.KLDivLoss = torch.nn.KLDivLoss()
31 |
32 | def forward(self, predict, target):
33 | total_loss = 0
34 | total_loss += 0.5 * self.KLDivLoss(predict.softmax(dim=1).clamp(min=1e-10).log(), target.softmax(dim=1).clamp(min=1e-10))
35 | total_loss += 0.5 * self.KLDivLoss(target.softmax(dim=1).clamp(min=1e-10).log(), predict.softmax(dim=1).clamp(min=1e-10))
36 |
37 | return total_loss
38 |
39 |
40 | """
41 | Adapted from https://github.com/Guocode/DiceLoss.Pytorch.git
42 | """
43 | def make_one_hot(input, num_classes):
44 | """Convert class index tensor to one hot encoding tensor.
45 | Args:
46 | input: A tensor of shape [N, 1, *]
47 | num_classes: An int of number of class
48 | Returns:
49 | A tensor of shape [N, num_classes, *]
50 | """
51 | shape = np.array(input.shape)
52 | shape[1] = num_classes
53 | shape = tuple(shape)
54 | result = torch.zeros(shape, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
55 | result = result.scatter_(1, input, 1)
56 |
57 | return result
58 |
59 |
60 | """
61 | Adapted from https://github.com/Guocode/DiceLoss.Pytorch.git
62 | """
63 | class BinaryDiceLoss(torch.nn.Module):
64 | """Dice loss of binary class
65 | Args:
66 | smooth: A float number to smooth loss, and avoid NaN error, default: 1
67 | p: Denominator value: \sum{x^p} + \sum{y^p}, default: 2
68 | predict: A tensor of shape [N, *]
69 | target: A tensor of shape same with predict
70 | Returns:
71 | Loss tensor according to arg reduction
72 | Raise:
73 | Exception if unexpected reduction
74 | """
75 | def __init__(self, smooth=1, p=2):
76 | super(BinaryDiceLoss, self).__init__()
77 | self.smooth = smooth
78 | self.p = p
79 |
80 | def forward(self, predict, target):
81 | assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
82 | predict = predict.contiguous().view(predict.shape[0], -1)
83 | target = target.contiguous().view(target.shape[0], -1)
84 |
85 | num = torch.sum(torch.mul(predict, target))*2 + self.smooth
86 | den = torch.sum(predict.pow(self.p) + target.pow(self.p)) + self.smooth
87 |
88 | dice = num / den
89 | loss = 1 - dice
90 | return loss
91 |
92 |
93 | """
94 | Adapted from https://github.com/Guocode/DiceLoss.Pytorch.git
95 | """
96 | class DiceLoss(torch.nn.Module):
97 | """Dice loss, need one hot encode input
98 | Args:
99 | weight: An array of shape [num_classes,]
100 | ignore_index: class index to ignore
101 | predict: A tensor of shape [N, C, *]
102 | target: A tensor of same shape with predict
103 | other args pass to BinaryDiceLoss
104 | Return:
105 | same as BinaryDiceLoss
106 | """
107 | def __init__(self, weight=None, num_classes=13, ignore_index=None, **kwargs):
108 | super(DiceLoss, self).__init__()
109 | self.kwargs = kwargs
110 | self.weight = weight
111 | self.num_classes = num_classes
112 | self.ignore_index = ignore_index
113 |
114 | def forward(self, predict, target):
115 | mask = target != self.ignore_index
116 | target = target * mask
117 | target = make_one_hot(torch.unsqueeze(target, 1), self.num_classes)
118 | target = target * mask.unsqueeze(1)
119 |
120 | assert predict.shape == target.shape, 'predict & target shape do not match'
121 | dice = BinaryDiceLoss(**self.kwargs)
122 | total_loss = 0
123 | predict = F.softmax(predict, dim=1)
124 | predict = predict * mask.unsqueeze(1)
125 |
126 | for i in range(target.shape[1]):
127 | if i != self.ignore_index:
128 | dice_loss = dice(predict[:, i], target[:, i])
129 | if self.weight is not None:
130 | assert self.weight.shape[0] == target.shape[1], \
131 | 'Expect weight shape [{}], get[{}]'.format(target.shape[1], self.weight.shape[0])
132 |                     dice_loss *= self.weight[i]
133 | total_loss += dice_loss
134 |
135 | return total_loss/target.shape[1]
136 |
137 |
138 |
139 |
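A minimal usage sketch for TaskLoss, assuming 6 semantic classes and 255 as the ignore label (all shapes made up). Since make_one_hot allocates on the GPU when one is available, predictions and labels should live on that same device:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
criterion = TaskLoss(losses=['dice', 'cross_entropy'], num_classes=6, ignore_index=255)

logits = torch.randn(2, 6, 32, 32, device=device, requires_grad=True)   # N x C x H x W scores
labels = torch.randint(0, 6, (2, 32, 32), device=device)                # N x H x W class ids
labels[:, :4, :4] = 255                                                 # some ignored pixels

loss = criterion(logits, labels)
loss.backward()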
--------------------------------------------------------------------------------
/utils/radam.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim.optimizer import Optimizer, required
4 |
5 |
6 | class RAdam(Optimizer):
7 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
8 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
9 | self.buffer = [[None, None, None] for ind in range(10)]
10 | super(RAdam, self).__init__(params, defaults)
11 |
12 | def __setstate__(self, state):
13 | super(RAdam, self).__setstate__(state)
14 |
15 | def step(self, closure=None):
16 |
17 | loss = None
18 | if closure is not None:
19 | loss = closure()
20 |
21 | for group in self.param_groups:
22 |
23 | for p in group['params']:
24 | if p.grad is None:
25 | continue
26 | grad = p.grad.data.float()
27 | if grad.is_sparse:
28 | raise RuntimeError('RAdam does not support sparse gradients')
29 |
30 | p_data_fp32 = p.data.float()
31 |
32 | state = self.state[p]
33 |
34 | if len(state) == 0:
35 | state['step'] = 0
36 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
37 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
38 | else:
39 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
40 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
41 |
42 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
43 | beta1, beta2 = group['betas']
44 |
45 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
46 | # exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
47 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
48 |
49 | state['step'] += 1
50 | buffered = self.buffer[int(state['step'] % 10)]
51 | if state['step'] == buffered[0]:
52 | N_sma, step_size = buffered[1], buffered[2]
53 | else:
54 | buffered[0] = state['step']
55 | beta2_t = beta2 ** state['step']
56 | N_sma_max = 2 / (1 - beta2) - 1
57 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
58 | buffered[1] = N_sma
59 |
60 | # more conservative since it's an approximated value
61 | if N_sma >= 5:
62 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) /
63 | N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
64 | else:
65 | step_size = 1.0 / (1 - beta1 ** state['step'])
66 | buffered[2] = step_size
67 |
68 | if group['weight_decay'] != 0:
69 | p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
70 |
71 | # more conservative since it's an approximated value
72 | if N_sma >= 5:
73 | denom = exp_avg_sq.sqrt().add_(group['eps'])
74 | p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
75 | else:
76 | p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
77 |
78 | p.data.copy_(p_data_fp32)
79 |
80 | return loss
81 |
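RAdam is used like any other torch optimizer; a tiny sketch with a made-up model:

import torch

model = torch.nn.Linear(10, 2)
optimizer = RAdam(model.parameters(), lr=1e-3, weight_decay=1e-4)

for _ in range(5):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 10)).pow(2).mean()
    loss.backward()
    optimizer.step()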
--------------------------------------------------------------------------------
/utils/saver.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import os
3 | import datetime
4 |
5 | import torch
6 |
7 |
8 | class CheckpointSaver(object):
9 | def __init__(self, save_dir):
10 | if save_dir is not None:
11 | self.save_dir = os.path.abspath(save_dir)
12 | return
13 |
14 | # save checkpoint
15 | def save_checkpoint(self, models, optimizers, epoch, step_count, batch_size_a, batch_size_b):
16 | timestamp = datetime.datetime.now()
17 | checkpoint_filename = os.path.abspath(os.path.join(self.save_dir, 'Epoch_' + str(epoch) + '.pt'))
18 | checkpoint = {}
19 | for model in models:
20 | checkpoint[model] = models[model].state_dict()
21 | for optimizer in optimizers:
22 | checkpoint[optimizer] = optimizers[optimizer].state_dict()
23 | checkpoint['epoch'] = epoch
24 | checkpoint['step_count'] = step_count
25 | checkpoint['batch_size_a'] = batch_size_a
26 | checkpoint['batch_size_b'] = batch_size_b
27 | print()
28 | print(timestamp, 'Epoch:', epoch, 'Iteration:', step_count)
29 | print('Saving checkpoint file [' + checkpoint_filename + ']')
30 | torch.save(checkpoint, checkpoint_filename)
31 | return
32 |
33 | # load a checkpoint
34 | def load_checkpoint(self, models, optimizers, checkpoint_file=None, load_optimizer=True):
35 | checkpoint = torch.load(checkpoint_file)
36 | for model in models:
37 | if model in checkpoint:
38 | models[model].load_state_dict(checkpoint[model])
39 | if load_optimizer:
40 | for optimizer in optimizers:
41 | if optimizer in checkpoint:
42 | optimizers[optimizer].load_state_dict(checkpoint[optimizer])
43 | print("Loading checkpoint with epoch {}, step {}"
44 | .format(checkpoint['epoch'], checkpoint['step_count']))
45 | return {'epoch': checkpoint['epoch'],
46 | 'step_count': checkpoint['step_count'],
47 | 'batch_size_a': checkpoint['batch_size_a'],
48 | 'batch_size_b': checkpoint['batch_size_b']}
49 |
50 | def load_pretrained_weights(self, models, model_list, checkpoint_file=None):
51 | checkpoint = torch.load(checkpoint_file)
52 | load_model_list = []
53 | for model_name in model_list:
54 | if model_name in ['front_sensor_b', 'e2vid_decoder']:
55 | continue
56 | if model_name in checkpoint:
57 | load_model_list.append(model_name)
58 | models[model_name].load_state_dict(checkpoint[model_name])
59 |
60 | print("Loading pretrained checkpoints for {}".format(load_model_list))
61 |
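A usage sketch under assumed names (the model/optimizer keys and the save directory below are illustrative, not the ones the trainers actually register):

import os
import torch

os.makedirs('./log/checkpoints', exist_ok=True)
saver = CheckpointSaver(save_dir='./log/checkpoints')

models = {'encoder': torch.nn.Linear(8, 8), 'decoder': torch.nn.Linear(8, 8)}
optimizers = {'optimizer': torch.optim.Adam(models['encoder'].parameters())}

saver.save_checkpoint(models, optimizers, epoch=3, step_count=1200,
                      batch_size_a=8, batch_size_b=8)
state = saver.load_checkpoint(models, optimizers,
                              checkpoint_file='./log/checkpoints/Epoch_3.pt')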
--------------------------------------------------------------------------------
/utils/viz_utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | import numpy as np
4 | import torchvision.utils
5 | import torch.nn.functional as F
6 | import matplotlib.pyplot as plt
7 | import itertools
8 |
9 |
10 | def createRGBGrid(tensor_list, nrow):
11 | """Creates a grid of rgb values based on the tensor stored in tensor_list"""
12 | vis_tensor_list = []
13 | for tensor in tensor_list:
14 | vis_tensor_list.append(visualizeTensors(tensor))
15 |
16 | return torchvision.utils.make_grid(torch.cat(vis_tensor_list, dim=0), nrow=nrow)
17 |
18 |
19 | def createRGBImage(tensor, separate_pol=True):
20 | """Creates a grid of rgb values based on the tensor stored in tensor_list"""
21 | if tensor.shape[1] == 3:
22 | return tensor
23 | elif tensor.shape[1] == 1:
24 | return tensor.expand(-1, 3, -1, -1)
25 | elif tensor.shape[1] == 2:
26 | return visualizeHistogram(tensor)
27 | elif tensor.shape[1] > 3:
28 | return visualizeVoxelGrid(tensor, separate_pol)
29 |
30 |
31 | def visualizeTensors(tensor):
32 | """Creates a rgb image of the given tensor. Can be event histogram, event voxel grid, grayscale and rgb."""
33 | if tensor.shape[1] == 3:
34 | return tensor
35 | elif tensor.shape[1] == 1:
36 | return tensor.expand(-1, 3, -1, -1)
37 | elif tensor.shape[1] == 2:
38 | return visualizeHistogram(tensor)
39 | elif tensor.shape[1] > 3:
40 | return visualizeVoxelGrid(tensor)
41 |
42 |
43 | def visualizeHistogram(histogram):
44 | """Visualizes the input histogram"""
45 | batch, _, height, width = histogram.shape
46 | torch_image = torch.zeros([batch, 1, height, width], device=histogram.device)
47 |
48 | return torch.cat([histogram.clamp(0, 1), torch_image], dim=1)
49 |
50 |
51 | def visualizeVoxelGrid(voxel_grid, separate_pol=True):
52 | """Visualizes the input histogram"""
53 | batch, nr_channels, height, width = voxel_grid.shape
54 | if separate_pol:
55 | pos_events_idx = nr_channels // 2
56 | temporal_scaling = torch.arange(start=1, end=pos_events_idx+1, dtype=voxel_grid.dtype,
57 | device=voxel_grid.device)[None, :, None, None] / pos_events_idx
58 | pos_voxel_grid = voxel_grid[:, :pos_events_idx] * temporal_scaling
59 | neg_voxel_grid = voxel_grid[:, pos_events_idx:] * temporal_scaling
60 |
61 | torch_image = torch.zeros([batch, 1, height, width], device=voxel_grid.device)
62 | pos_image = torch.sum(pos_voxel_grid, dim=1, keepdim=True)
63 | neg_image = torch.sum(neg_voxel_grid, dim=1, keepdim=True)
64 |
65 | return torch.cat([neg_image.clamp(0, 1), pos_image.clamp(0, 1), torch_image], dim=1)
66 |
67 | sum_events = torch.sum(voxel_grid, dim=1).detach()
68 | event_preview = torch.zeros((batch, 3, height, width))
69 | b = event_preview[:, 0, :, :]
70 | r = event_preview[:, 2, :, :]
71 | b[sum_events > 0] = 255
72 | r[sum_events < 0] = 255
73 | return event_preview
74 |
75 |
76 | def visualizeConfusionMatrix(confusion_matrix, path_name=None):
77 | """
78 |     Visualizes the confusion matrix using matplotlib.
79 |
80 | :param confusion_matrix: NxN numpy array
81 | :param path_name: if no path name is given, just an image is returned
82 | """
83 | import matplotlib.pyplot as plt
84 | nr_classes = confusion_matrix.shape[0]
85 | fig, ax = plt.subplots(1, 1, figsize=(16, 16))
86 | ax.matshow(confusion_matrix)
87 | ax.plot([-0.5, nr_classes - 0.5], [-0.5, nr_classes - 0.5], '-', color='grey')
88 | ax.set_xlabel('Labels')
89 | ax.set_ylabel('Predicted')
90 |
91 | if path_name is None:
92 | fig.tight_layout(pad=0)
93 | fig.canvas.draw()
94 |         data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).copy()
95 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
96 | plt.close()
97 |
98 | return data
99 |
100 | else:
101 | fig.savefig(path_name)
102 | plt.close()
103 |
104 |
105 | def create_checkerboard(N, C, H, W):
106 | cell_sz = max(min(H, W) // 32, 1)
107 | mH = (H + cell_sz - 1) // cell_sz
108 | mW = (W + cell_sz - 1) // cell_sz
109 | checkerboard = torch.full((mH, mW), 0.25, dtype=torch.float32)
110 | checkerboard[0::2, 0::2] = 0.75
111 | checkerboard[1::2, 1::2] = 0.75
112 | checkerboard = checkerboard.float().view(1, 1, mH, mW)
113 | checkerboard = F.interpolate(checkerboard, scale_factor=cell_sz, mode='nearest')
114 | checkerboard = checkerboard[:, :, :H, :W].repeat(N, C, 1, 1)
115 | return checkerboard
116 |
117 |
118 | def prepare_semseg(img, semseg_color_map, semseg_ignore_label):
119 | assert (img.dim() == 3 or img.dim() == 4 and img.shape[1] == 1) and img.dtype in (torch.int, torch.long), \
120 |         f'Expecting a 3D (N, H, W) or 4D (N, 1, H, W) integer tensor of semseg class ids, got {img.shape}'
121 | if img.dim() == 4:
122 | img = img.squeeze(1)
123 | colors = torch.tensor(semseg_color_map, dtype=torch.float32)
124 | assert colors.dim() == 2 and colors.shape[1] == 3
125 | if torch.max(colors) > 128:
126 | colors /= 255
127 | img = img.cpu().clone() # N x H x W
128 | N, H, W = img.shape
129 | img_color_ids = torch.unique(img)
130 | assert all(c_id == semseg_ignore_label or 0 <= c_id < len(semseg_color_map) for c_id in img_color_ids)
131 | checkerboard, mask_ignore = None, None
132 | if semseg_ignore_label in img_color_ids:
133 | checkerboard = create_checkerboard(N, 3, H, W)
134 | # blackboard = create_blackboard(N, 3, H, W)
135 | mask_ignore = img == semseg_ignore_label
136 | img[mask_ignore] = 0
137 | img = colors[img] # N x H x W x 3
138 | img = img.permute(0, 3, 1, 2)
139 |
140 | # checkerboard
141 | if semseg_ignore_label in img_color_ids:
142 | mask_ignore = mask_ignore.unsqueeze(1).repeat(1, 3, 1, 1)
143 | img[mask_ignore] = checkerboard[mask_ignore]
144 | # img[mask_ignore] = blackboard[mask_ignore]
145 | return img
146 |
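A sketch of prepare_semseg with a hand-made three-colour palette (palette and shapes are illustrative): pixels carrying the ignore label are rendered as a checkerboard, everything else gets its class colour:

import torch

color_map = [(128, 64, 128), (70, 70, 70), (220, 20, 60)]       # e.g. road, building, person
pred = torch.randint(0, 3, (1, 32, 32), dtype=torch.long)
pred[0, :4, :] = 255                                            # an ignored region
rgb = prepare_semseg(pred, color_map, semseg_ignore_label=255)  # -> 1 x 3 x 32 x 32 float image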
147 |
148 | def plot_confusion_matrix(cm, classes,
149 | normalize=False,
150 | title='Confusion matrix',
151 | cmap=plt.cm.Blues):
152 | """
153 | This function prints and plots the confusion matrix.
154 | Normalization can be applied by setting `normalize=True`.
155 | """
156 | cm = cm.numpy()
157 | if normalize:
158 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
159 | print("Normalized confusion matrix")
160 | else:
161 | print('Confusion matrix, without normalization')
162 |
163 | fig = plt.figure()
164 |
165 | plt.imshow(cm, interpolation='nearest', cmap=cmap)
166 | plt.title(title)
167 | plt.colorbar()
168 | tick_marks = np.arange(len(classes))
169 | plt.xticks(tick_marks, classes, rotation=45)
170 | plt.yticks(tick_marks, classes)
171 |
172 | fmt = '.2f' if normalize else 'd'
173 | thresh = cm.max() / 2.
174 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
175 | plt.text(j, i, format(cm[i, j], fmt),
176 | horizontalalignment="center",
177 | color="white" if cm[i, j] > thresh else "black")
178 |
179 | plt.tight_layout()
180 | plt.ylabel('True label')
181 | plt.xlabel('Predicted label')
182 | return fig
183 |
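Finally, a sketch of how the grid helpers compose (batch size, bin count and resolution are made up): createRGBGrid converts every tensor to RGB via visualizeTensors and tiles the results into a single image, ready for logging:

import torch

voxel_grids = torch.rand(4, 10, 64, 64)     # assumed 2-polarity x 5-bin event voxel grids
gray_frames = torch.rand(4, 1, 64, 64)      # grayscale frames

grid = createRGBGrid([voxel_grids, gray_frames], nrow=4)        # 3 x H' x W' image grid
print(grid.shape)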
--------------------------------------------------------------------------------