├── README.md ├── config └── eval.yml ├── data ├── cityscapes_loader.py ├── cityscapes_video │ ├── gtFine │ │ └── val │ │ │ └── stuttgart_00 │ │ │ ├── stuttgart_00_000000_000001_gtFine_color.png │ │ │ ├── stuttgart_00_000000_000001_gtFine_labelIds.png │ │ │ ├── stuttgart_00_000000_000002_gtFine_color.png │ │ │ ├── stuttgart_00_000000_000002_gtFine_labelIds.png │ │ │ ├── stuttgart_00_000000_000003_gtFine_color.png │ │ │ ├── stuttgart_00_000000_000003_gtFine_labelIds.png │ │ │ ├── stuttgart_00_000000_000004_gtFine_color.png │ │ │ ├── stuttgart_00_000000_000004_gtFine_labelIds.png │ │ │ ├── stuttgart_00_000000_000005_gtFine_color.png │ │ │ └── stuttgart_00_000000_000005_gtFine_labelIds.png │ └── leftImg8bit │ │ ├── test │ │ └── munich │ │ │ ├── munich_000000_000000_leftImg8bit.png │ │ │ ├── munich_000000_000001_leftImg8bit.png │ │ │ └── munich_000000_000002_leftImg8bit.png │ │ └── val │ │ └── stuttgart_00 │ │ ├── stuttgart_00_000000_000001_leftImg8bit.png │ │ ├── stuttgart_00_000000_000002_leftImg8bit.png │ │ ├── stuttgart_00_000000_000003_leftImg8bit.png │ │ ├── stuttgart_00_000000_000004_leftImg8bit.png │ │ └── stuttgart_00_000000_000005_leftImg8bit.png ├── labels.png └── metrics.py ├── eval.py ├── model ├── ESPNetL1b-Cityscapes-pretrained.pkl ├── convolutional_lstm.py ├── esp_net.py └── loss.py ├── requirements.txt ├── resources └── esp_vs_our_model.gif └── utils └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Frame-To-Frame Consistent Semantic Segmentation 2 | 3 | This code implements the method introduced in the publication [Frame-To-Frame Consistent Semantic Segmentation](https://arxiv.org/abs/2008.00948): 4 | 5 | @InProceedings{Rebol_2020_ACVRW, 6 | author = {Rebol, Manuel and Knöbelreiter, Patrick}, 7 | title = {Frame-To-Frame Consistent Semantic Segmentation}, 8 | booktitle = {Joint Austrian Computer Vision And Robotics Workshop (ACVRW)}, 9 | month = {April}, 10 | year = {2020}, 11 | pages = {79-86}, 12 | doi = {10.3217/978-3-85125-752-6-18} 13 | } 14 | 15 | ![ESPNet vs our model](https://github.com/mrebol/f2fcss/blob/master/resources/esp_vs_our_model.gif) 16 | *ESPNet vs Our Model ESPNet_L1b* 17 | 18 | ## Dependencies 19 | + CUDA 11.0 for execution on GPU (optional) 20 | + Python 3.7 21 | + PyTorch 1.3.1 22 | + pip packages in `requirements.txt` 23 | 24 | 25 | ## Dataset 26 | We provide a data loader for the Cityscapes dataset in the `data` directory. 27 | Additionally, we included frames of the Cityscapes Demo Video in the repository to support quick first experiments. 28 | Images placed inside the `data/cityscapes_video/leftImg8bit/val` directory do require corresponding ground truth files, whereas images in the `data/cityscapes_video/leftImg8bit/test` directory don't. 29 | 30 | ## Configuration 31 | The default configuration file is stored in `config/eval.yml`. 32 | It loads the pretrained ESPNet_L1b model and inputs the dataset provided. 33 | If the GPU parameter in the config is enabled, CUDA 11.0 needs to be installed additionally. 34 | 35 | 36 | ## Evaluation 37 | 38 | The required python packages need to be installed using the provided `requirements.txt`: 39 | 40 | pip install -r requirements.txt 41 | 42 | To run the evaluation with the default config file located at `config/eval.yml` enter: 43 | 44 | python eval.py --config config/eval.yml 45 | 46 | ## Results 47 | The predicted semantic segmentation images are saved in the `output//images/` folder. 
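As an illustration, a run with the default config on the bundled Stuttgart frames writes its color predictions into a per-run folder (the folder is named after the start time of the run; sub-folder and file names follow `eval.py`):

    output/<run timestamp>/images/cityscapes_video_val-dataset_color/
    ├── stuttgart_00_000000_000001_pred_color.png
    ├── stuttgart_00_000000_000002_pred_color.png
    └── ...
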
48 | Depending on the input config, the folder contains the semantic color maps and/or the semantic label ids. 49 | 50 | 51 | Additionally, we generate Tensorboard logs at `output//tensorboard/`. These logs can be 52 | examined after installing Tensorboard 53 | 54 | pip install tensorboard 55 | with the command: 56 | 57 | tensorboard --logdir output 58 | 59 | Tensorboard visualizes the statistics at [http://localhost:6006/](http://localhost:6006/) by default. 60 | 61 | -------------------------------------------------------------------------------- /config/eval.yml: -------------------------------------------------------------------------------- 1 | # This file is part of f2fcss. 2 | # 3 | # Copyright (C) 2020 Manuel Rebol 4 | # Patrick Knoebelreiter 5 | # Institute for Computer Graphics and Vision, Graz University of Technology 6 | # https://www.tugraz.at/institute/icg/teams/team-pock/ 7 | # 8 | # f2fcss is free software: you can redistribute it and/or modify it under the 9 | # terms of the GNU Affero General Public License as published by the Free Software 10 | # Foundation, either version 3 of the License, or any later version. 11 | # 12 | # f2fcss is distributed in the hope that it will be useful, but WITHOUT ANY 13 | # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS 14 | # FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU Affero General Public License 17 | # along with this program. If not, see . 18 | # 19 | model: 20 | load: model/ESPNetL1b-Cityscapes-pretrained.pkl 21 | input_size: [512, 1024] # height, width of model input 22 | gpu: 0 # 0..CPU, 1..GPU1, 2..GPU2,... 23 | 24 | data: 25 | dataset: data/cityscapes_video 26 | splits: val # test, train, val 27 | nr_of_scenes: all # 0,1,2,..,all 28 | nr_of_sequences: all # 0,1,2,..,all 29 | sequence_length: 1 # nr of sequential images passed to the network (~batch size) 30 | 31 | output: 32 | save_pred_color: True # Save segmentation color images 33 | save_label_ids: False # For Cityscapes test submission 34 | 35 | -------------------------------------------------------------------------------- /data/cityscapes_loader.py: -------------------------------------------------------------------------------- 1 | # This file is part of f2fcss. 2 | # 3 | # Copyright (C) 2020 Manuel Rebol 4 | # Patrick Knoebelreiter 5 | # Institute for Computer Graphics and Vision, Graz University of Technology 6 | # https://www.tugraz.at/institute/icg/teams/team-pock/ 7 | # 8 | # f2fcss is free software: you can redistribute it and/or modify it under the 9 | # terms of the GNU Affero General Public License as published by the Free Software 10 | # Foundation, either version 3 of the License, or any later version. 11 | # 12 | # f2fcss is distributed in the hope that it will be useful, but WITHOUT ANY 13 | # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS 14 | # FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU Affero General Public License 17 | # along with this program. If not, see . 
18 | # 19 | import random 20 | import re 21 | import numpy as np 22 | import torch 23 | from PIL import Image 24 | from torch.utils import data 25 | from os import path, sep 26 | from utils.utils import recursive_glob 27 | 28 | 29 | class CityscapesLoader(data.Dataset): 30 | ego_vehicle_class_id = 1 31 | rectification_border_class_id = 2 32 | out_of_roi_class_id = 3 33 | 34 | colors = [ 35 | [128, 64, 128], 36 | [244, 35, 232], 37 | [70, 70, 70], 38 | [102, 102, 156], 39 | [190, 153, 153], 40 | [153, 153, 153], 41 | [250, 170, 30], 42 | [220, 220, 0], 43 | [107, 142, 35], 44 | [152, 251, 152], 45 | [0, 130, 180], 46 | [220, 20, 60], 47 | [255, 0, 0], 48 | [0, 0, 142], 49 | [0, 0, 70], 50 | [0, 60, 100], 51 | [0, 80, 100], 52 | [0, 0, 230], 53 | [119, 11, 32], 54 | ] # len=19 55 | class_names = [ 56 | "unlabelled", 57 | "road", # 0 58 | "sidewalk", # 1 59 | "building", # 2 60 | "wall", # 3 61 | "fence", # 4 62 | "pole", # 5 63 | "traffic_light", # 6 64 | "traffic_sign", # 7 65 | "vegetation", # 8 66 | "terrain", # 9 67 | "sky", # 10 68 | "person", # 11 69 | "rider", # 12 70 | "car", # 13 71 | "truck", # 14 72 | "bus", # 15 73 | "train", # 16 74 | "motorcycle", # 17 75 | "bicycle", # 18 76 | ] # len=20 77 | label_colours = dict(zip(range(19), colors)) 78 | 79 | statistics = {'cityscapes': # Cityscapes, SingleFrame, 2048x1024, train 80 | {'class_weights': [2.601622, 6.704553, 3.522924, 9.876478, 9.684879, 9.397963, 81 | 10.288337, 9.969174, 4.3375425, 9.453512, 7.622256, 9.404625, 82 | 10.358636, 6.3711667, 10.231368, 10.262094, 10.264279, 10.39429, 10.09429], 83 | 'mean': [73.15842, 82.90896, 72.39239], 84 | 'std': [44.91484, 46.152893, 45.319214]}, 85 | 'cityscapes_seq': # Cityscapes_seq_gtDeepLab, SequenceData, 2048x1024, train 86 | {'class_weights': [2.630181, 6.467089, 3.500679, 9.828134, 9.666817, 9.3363495, 87 | 10.296027, 9.910831, 4.3855543, 9.411499, 7.660441, 9.438194, 88 | 10.371957, 6.399998, 10.219244, 10.274652, 10.271446, 10.397583, 10.083669], 89 | 'mean': [73.20033613, 82.95346218, 72.43207843], 90 | 'std': [44.91366387, 46.15787395, 45.32914566] 91 | } 92 | } 93 | 94 | valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33, ] # len=19 95 | n_classes = 19 96 | void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1] 97 | ignore_index = 250 98 | class_map = dict(zip(valid_classes, range(19))) 99 | full_resolution = [1024, 2048] 100 | 101 | def __init__( 102 | self, 103 | root, 104 | scene_group_size, 105 | sequence_overlap, 106 | nr_of_scenes, 107 | nr_of_sequences, 108 | split, 109 | augmentations, 110 | shuffle_before_cut=False, 111 | is_shuffle_scenes=False, 112 | scale_size=None, 113 | seq_len=1, 114 | img_dtype=torch.float, 115 | lbl_dtype=torch.uint8, 116 | path_full_res=None, 117 | ): 118 | assert sequence_overlap < seq_len 119 | if isinstance(root, list): 120 | assert len(root) == len(nr_of_scenes) == len(nr_of_sequences) == len(split) == len(shuffle_before_cut) 121 | self.root = root 122 | self.split = split 123 | else: 124 | self.root = [root] 125 | self.split = [split] 126 | nr_of_scenes = [nr_of_scenes] 127 | nr_of_sequences = [nr_of_sequences] 128 | shuffle_before_cut = [shuffle_before_cut] 129 | self.nr_of_datasets = len(self.root) 130 | self.augmentations = augmentations 131 | self.scale_size = scale_size 132 | 133 | folder_name = self.root[0].rstrip('/') 134 | if scale_size: 135 | self.final_size = scale_size 136 | else: 137 | self.final_size = self.full_resolution 138 | 139 | self.statistics = 
self.statistics['cityscapes'] 140 | self.seq_len = seq_len 141 | self.img_dtype = img_dtype 142 | self.lbl_dtype = lbl_dtype 143 | self.path_full_res = path_full_res 144 | self.sequence_overlap = sequence_overlap 145 | self.scene_group_size = scene_group_size 146 | self.is_shuffle_scenes = is_shuffle_scenes 147 | self.train_scales = [0] * len(self.root) 148 | 149 | self.sequences = self._generate_sequence_list(self.root, self.split, nr_of_scenes, nr_of_sequences, 150 | self.seq_len, self.sequence_overlap, shuffle_before_cut, 151 | self.train_scales) 152 | 153 | if self.scene_group_size > 0: # Assign Scene Group number 154 | for sequence in self.sequences: 155 | sequence[3] = sequence[4] // self.scene_group_size 156 | # Sort first by train_scale, then by scene_group_nr, then by seq_in_scene, then by scene_nr 157 | self.sequences.sort(key=lambda x: (x[7][1], x[3], x[2], x[4])) 158 | 159 | self.nr_of_scenes = 0 160 | if self.sequences: 161 | self.nr_of_scenes = len(set(np.array(self.sequences)[:, 4])) # 4... scene_nr 162 | 163 | print( 164 | "Cityscapes Loader: Found %d scenes and %d %s intervals." % (self.nr_of_scenes, len(self.sequences), split)) 165 | 166 | def _generate_sequence_list(self, dirs, splits, nr_scenes, nr_sequences, seq_len, seq_overlap, shuffle_before_cut, 167 | train_scales): 168 | sequences = [] # [(img_path, lbl_path),...], 'scene_name', seq_in_scene, scene_group_nr, scene_nr, 'file_type', files_seq_nr[], augmentation[hflip, scaleSize] 169 | scene_nr = 0 170 | scene_group_nr = 0 171 | augmentation = [0] 172 | for dir, split, nr_of_scenes, nr_of_sequences, shuffle_bef_cut, train_scale in zip(dirs, splits, nr_scenes, 173 | nr_sequences, 174 | shuffle_before_cut, 175 | train_scales): 176 | images_base = path.join(dir, "leftImg8bit", split) 177 | lbls_base = path.join(dir, "gtFine", split) 178 | files = recursive_glob(rootdir=images_base, suffix=".png") 179 | files = sorted(files) 180 | 181 | if split != 'test': # remove files without labels 182 | files_temp = [] 183 | for i, file in enumerate(files): 184 | lbl_path = path.join(lbls_base, file.split(sep)[-2], 185 | path.basename(file)[:-15] + "gtFine_labelIds.png") 186 | if path.isfile(lbl_path): 187 | files_temp.append(file) 188 | files = files_temp 189 | 190 | # create sequence list 191 | sequences_dataset = [] 192 | if sequences: 193 | scene_nr = sequences[-1][4] + 1 194 | seq_in_scene = 0 195 | for i, file in enumerate(files): 196 | lbl_path = path.join(lbls_base, file.split(sep)[-2], 197 | path.basename(file)[:-15] + "gtFine_labelIds.png") 198 | current_file_type = '' 199 | if len(re.findall(r'_\d+_\d+_\d+_', path.basename(file))) == 1: # video-file 200 | current_file_type = 'video-file' 201 | seq_nr_str = re.findall(r'\d+_\d+_\d+', path.basename(file))[0] 202 | elif len(re.findall(r'_\d+_\d+_', path.basename(file))) == 1: # single-frame file or sequence file 203 | current_file_type = 'seq-file' 204 | seq_nr_str = re.findall(r'_\d+_\d+_', path.basename(file))[0] 205 | seq_nr_str = seq_nr_str[1:-1] 206 | 207 | seq_nr = int(seq_nr_str.replace('_', '')) 208 | 209 | if len(sequences_dataset) == 0: # very first interval 210 | seq_in_scene = 0 211 | sequences_dataset.append( 212 | [[(file, lbl_path)], seq_nr_str, seq_in_scene, scene_group_nr, scene_nr, current_file_type, 213 | [seq_nr], augmentation.copy() + [train_scale]]) 214 | continue 215 | 216 | prev_interval = sequences_dataset[-1] 217 | if current_file_type != prev_interval[5] or prev_interval[6][-1] + 1 != seq_nr: # new scene: 218 | scene_nr += 1 219 | seq_in_scene = 0 
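                        # a gap in the frame numbering or a change of file type ends the previous
                        # scene, so start a fresh interval for the new scene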
220 | sequences_dataset.append( 221 | [[(file, lbl_path)], seq_nr_str, seq_in_scene, scene_group_nr, scene_nr, current_file_type, 222 | [seq_nr], augmentation.copy() + [train_scale]]) 223 | elif len(prev_interval[0]) == seq_len: # check if last interval full --> new interval, same_scene 224 | seq_in_scene += 1 225 | if seq_overlap > 0: 226 | sequences_dataset.append( 227 | [prev_interval[0][-seq_overlap:] + [(file, lbl_path)], prev_interval[1], 228 | seq_in_scene, scene_group_nr, scene_nr, current_file_type, 229 | prev_interval[6][-seq_overlap:] + [seq_nr], augmentation.copy() + [train_scale]]) 230 | else: 231 | sequences_dataset.append( 232 | [[(file, lbl_path)], prev_interval[1], seq_in_scene, scene_group_nr, scene_nr, 233 | current_file_type, [seq_nr], augmentation.copy() + [train_scale]]) 234 | else: # same interval, same scene 235 | prev_interval[0].append((file, lbl_path)) 236 | prev_interval[6].append(seq_nr) 237 | 238 | # Cut sequence list 239 | assert not (nr_of_scenes != 'all' and nr_of_sequences != 'all') 240 | # Requires file_intervals list to be sorted by scenes. 241 | if nr_of_scenes != 'all': 242 | if shuffle_bef_cut: 243 | self.shuffle_scenes_of_sequences(sequences_dataset) # shuffle before cut scenes 244 | nr_of_scenes_curr = 0 245 | prev_scene_name = '' 246 | for index, file_interval in enumerate(sequences_dataset): 247 | if prev_scene_name != file_interval[1]: 248 | nr_of_scenes_curr += 1 249 | if nr_of_scenes_curr > nr_of_scenes: 250 | sequences_dataset = sequences_dataset[:index] 251 | break 252 | prev_scene_name = file_interval[1] 253 | elif nr_of_sequences != "all": 254 | sequences_dataset = sequences_dataset[:nr_of_sequences] 255 | 256 | sequences += sequences_dataset 257 | return sequences 258 | 259 | def shuffle_scenes(self): 260 | assert self.scene_group_size > 0 261 | scene_nr_at_scale = [[] for x in self.train_scales] 262 | for seq in self.sequences: 263 | scene_nr_at_scale[seq[7][1]].append(seq[4]) # get scene_nrs for each scale 264 | for i in range(len(self.train_scales)): # shuffle each scale 265 | scene_nr_at_scale[i] = np.repeat(np.unique(np.array(scene_nr_at_scale[i]))[None, :], 2, axis=0) 266 | random.shuffle(scene_nr_at_scale[i][1]) # inplace 267 | 268 | # assign new scene_nr 269 | for sequence in self.sequences: 270 | dictionary = scene_nr_at_scale[sequence[7][1]] 271 | sequence[4] = np.asscalar(dictionary[1, dictionary[0] == sequence[4]]) 272 | 273 | for sequence in self.sequences: # assign scene group nr 274 | sequence[3] = sequence[4] // self.scene_group_size # works, because sequence is reference # 4... scene_nr 275 | self.sequences.sort(key=lambda x: (x[7][1], x[3], x[2], x[ 276 | 4])) # sort first by train_scale, then by scene_group_nr, then by seq_in_scene, then by scene_nr 277 | 278 | def shuffle_scenes_of_sequences(self, sequences): 279 | assert self.scene_group_size > 0 280 | # shuffle scene_numbers 281 | scene_numbers = list(set(np.array(sequences)[:, 4])) # 4... scene_nr 282 | start_scene_number = scene_numbers[0] 283 | assert start_scene_number + len(scene_numbers) - 1 == scene_numbers[-1] 284 | random.shuffle(scene_numbers) 285 | 286 | # assign new scene_nr and scene group nr 287 | for sequence in sequences: 288 | sequence[4] = scene_numbers[sequence[4] - start_scene_number] 289 | sequence[3] = sequence[4] // self.scene_group_size # works, because sequence is reference # 4... 
scene_nr 290 | 291 | sequences.sort( 292 | key=lambda x: (x[4], x[2])) # sort first by scene_group_nr, then by seq_in_scene, then by scene_nr 293 | 294 | def hflip_scenes(self): 295 | self.sequences.sort(key=lambda x: (x[4])) # sort by scene_nr 296 | old_scene_nr = -1 297 | hflip = 0 298 | for seq in self.sequences: 299 | if old_scene_nr != seq[4]: 300 | hflip = random.random() < 0.5 301 | old_scene_nr = seq[4] 302 | seq[7][0] = hflip 303 | 304 | def prepare_iteration(self): 305 | if self.is_shuffle_scenes: 306 | self.shuffle_scenes() 307 | if 'hflip' in self.augmentations: 308 | self.hflip_scenes() 309 | # sort first by train_scale, then by scene_group_nr, then by seq_in_scene, then by scene_nr 310 | self.sequences.sort(key=lambda x: (x[7][1], x[3], x[2], x[4])) 311 | 312 | def __len__(self): 313 | return len(self.sequences) 314 | 315 | def __getitem__(self, index): 316 | sequence = self.sequences[ 317 | index] # [(img_path, lbl_path),...], 'scene_name', seq_in_scene, scene_group_nr, scene_nr, 'file_type', files_seq_nr[] 318 | imgs = [] 319 | lbls = [] 320 | for time_step, (img_path, lbl_path) in enumerate(sequence[0]): 321 | img = Image.open(img_path) 322 | img = np.array(img, dtype=np.uint8) 323 | if path.exists(lbl_path): 324 | lbl = Image.open(lbl_path) 325 | lbl = self.labelId_to_segmap(np.array(lbl, dtype=np.uint8)) 326 | else: 327 | lbl = np.zeros((1, 1)) 328 | img, lbl = self.transform(img, lbl) 329 | imgs.append(img) 330 | lbls.append(lbl) 331 | 332 | imgs = torch.stack(imgs) 333 | lbls = torch.stack(lbls) 334 | 335 | return [imgs, lbls, sequence[0], sequence[1], index, sequence[2], sequence[4], sequence[3]] 336 | 337 | def transform(self, img, lbl): 338 | if self.scale_size: 339 | img = np.array(Image.fromarray(img).resize((self.scale_size[1], self.scale_size[0]), Image.BILINEAR)) 340 | img = img.astype(np.float32) 341 | 342 | if 'uniform' in self.augmentations: 343 | img -= self.statistics['mean'] 344 | img /= self.statistics['std'] 345 | elif 'normalize' in self.augmentations: 346 | img /= 255.0 347 | 348 | img = img.transpose(2, 0, 1) # HWC -> CHW 349 | img = torch.from_numpy(img).type(self.img_dtype) 350 | lbl = torch.from_numpy(lbl).type(self.lbl_dtype) 351 | 352 | return img, lbl 353 | 354 | def segmap_to_color(self, segmap, gt=None, insert_mask=False, float_output=True): # color an input segmap [0..18] 355 | rgb = np.zeros((segmap.shape[0], segmap.shape[1], 3), dtype=np.int8) 356 | for cl in range(0, self.n_classes): 357 | rgb[segmap == cl] = self.label_colours[cl] 358 | if float_output: 359 | rgb = rgb.astype(np.float) 360 | rgb /= 255.0 361 | if insert_mask: 362 | rgb[gt == 250] = [0, 0, 0] 363 | return rgb 364 | 365 | def labelId_to_segmap(self, mask): # input:7,8,11,... output: 0,1,2,... 366 | for void_class in self.void_classes: 367 | mask[mask == void_class] = self.ignore_index 368 | for valid_class in self.valid_classes: 369 | mask[mask == valid_class] = self.class_map[valid_class] # Put all valid classes in range 0..18 370 | return mask 371 | 372 | @staticmethod 373 | def segmap_to_labelId_static(segmap): # input: 0,1,2,... output: 7,8,11,... 
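        # valid_classes lists the 19 Cityscapes label ids in train-id order, so indexing it with a
        # prediction map of train ids (0..18) converts every pixel back to the original label ids
        # (7, 8, 11, ...) expected for a Cityscapes test submission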
374 | return np.array(CityscapesLoader.valid_classes, dtype=np.uint8)[segmap] 375 | 376 | def _get_index_of_filepath(self, files, filepath): 377 | return files.index(filepath) 378 | 379 | 380 | def get_ego_vehicle_mask(label_id_img): 381 | img = Image.open(label_id_img) 382 | img = np.array(img, dtype=np.uint8) 383 | return img == CityscapesLoader.ego_vehicle_class_id 384 | 385 | 386 | def compute_class_weights(classWeights, histogram, classes): 387 | normHist = histogram / np.sum(histogram) 388 | for i in range(classes): 389 | classWeights[i] = 1 / ( 390 | np.log(1.10 + normHist[i])) # 1.10 is the normalization value, as defined in ERFNet paper 391 | return classWeights 392 | -------------------------------------------------------------------------------- /data/cityscapes_video/gtFine/val/stuttgart_00/stuttgart_00_000000_000001_gtFine_color.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrebol/f2f-consistent-semantic-segmentation/f7e362c5cf0d3461435e8af4f334cca974e521dc/data/cityscapes_video/gtFine/val/stuttgart_00/stuttgart_00_000000_000001_gtFine_color.png -------------------------------------------------------------------------------- /data/cityscapes_video/gtFine/val/stuttgart_00/stuttgart_00_000000_000001_gtFine_labelIds.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrebol/f2f-consistent-semantic-segmentation/f7e362c5cf0d3461435e8af4f334cca974e521dc/data/cityscapes_video/gtFine/val/stuttgart_00/stuttgart_00_000000_000001_gtFine_labelIds.png -------------------------------------------------------------------------------- /data/cityscapes_video/gtFine/val/stuttgart_00/stuttgart_00_000000_000002_gtFine_color.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrebol/f2f-consistent-semantic-segmentation/f7e362c5cf0d3461435e8af4f334cca974e521dc/data/cityscapes_video/gtFine/val/stuttgart_00/stuttgart_00_000000_000002_gtFine_color.png -------------------------------------------------------------------------------- /data/cityscapes_video/gtFine/val/stuttgart_00/stuttgart_00_000000_000002_gtFine_labelIds.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrebol/f2f-consistent-semantic-segmentation/f7e362c5cf0d3461435e8af4f334cca974e521dc/data/cityscapes_video/gtFine/val/stuttgart_00/stuttgart_00_000000_000002_gtFine_labelIds.png -------------------------------------------------------------------------------- /data/cityscapes_video/gtFine/val/stuttgart_00/stuttgart_00_000000_000003_gtFine_color.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrebol/f2f-consistent-semantic-segmentation/f7e362c5cf0d3461435e8af4f334cca974e521dc/data/cityscapes_video/gtFine/val/stuttgart_00/stuttgart_00_000000_000003_gtFine_color.png -------------------------------------------------------------------------------- /data/cityscapes_video/gtFine/val/stuttgart_00/stuttgart_00_000000_000003_gtFine_labelIds.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrebol/f2f-consistent-semantic-segmentation/f7e362c5cf0d3461435e8af4f334cca974e521dc/data/cityscapes_video/gtFine/val/stuttgart_00/stuttgart_00_000000_000003_gtFine_labelIds.png 
-------------------------------------------------------------------------------- /data/cityscapes_video/gtFine/val/stuttgart_00/stuttgart_00_000000_000004_gtFine_color.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrebol/f2f-consistent-semantic-segmentation/f7e362c5cf0d3461435e8af4f334cca974e521dc/data/cityscapes_video/gtFine/val/stuttgart_00/stuttgart_00_000000_000004_gtFine_color.png -------------------------------------------------------------------------------- /data/cityscapes_video/gtFine/val/stuttgart_00/stuttgart_00_000000_000004_gtFine_labelIds.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrebol/f2f-consistent-semantic-segmentation/f7e362c5cf0d3461435e8af4f334cca974e521dc/data/cityscapes_video/gtFine/val/stuttgart_00/stuttgart_00_000000_000004_gtFine_labelIds.png -------------------------------------------------------------------------------- /data/cityscapes_video/gtFine/val/stuttgart_00/stuttgart_00_000000_000005_gtFine_color.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrebol/f2f-consistent-semantic-segmentation/f7e362c5cf0d3461435e8af4f334cca974e521dc/data/cityscapes_video/gtFine/val/stuttgart_00/stuttgart_00_000000_000005_gtFine_color.png -------------------------------------------------------------------------------- /data/cityscapes_video/gtFine/val/stuttgart_00/stuttgart_00_000000_000005_gtFine_labelIds.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrebol/f2f-consistent-semantic-segmentation/f7e362c5cf0d3461435e8af4f334cca974e521dc/data/cityscapes_video/gtFine/val/stuttgart_00/stuttgart_00_000000_000005_gtFine_labelIds.png -------------------------------------------------------------------------------- /data/cityscapes_video/leftImg8bit/test/munich/munich_000000_000000_leftImg8bit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrebol/f2f-consistent-semantic-segmentation/f7e362c5cf0d3461435e8af4f334cca974e521dc/data/cityscapes_video/leftImg8bit/test/munich/munich_000000_000000_leftImg8bit.png -------------------------------------------------------------------------------- /data/cityscapes_video/leftImg8bit/test/munich/munich_000000_000001_leftImg8bit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrebol/f2f-consistent-semantic-segmentation/f7e362c5cf0d3461435e8af4f334cca974e521dc/data/cityscapes_video/leftImg8bit/test/munich/munich_000000_000001_leftImg8bit.png -------------------------------------------------------------------------------- /data/cityscapes_video/leftImg8bit/test/munich/munich_000000_000002_leftImg8bit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrebol/f2f-consistent-semantic-segmentation/f7e362c5cf0d3461435e8af4f334cca974e521dc/data/cityscapes_video/leftImg8bit/test/munich/munich_000000_000002_leftImg8bit.png -------------------------------------------------------------------------------- /data/cityscapes_video/leftImg8bit/val/stuttgart_00/stuttgart_00_000000_000001_leftImg8bit.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mrebol/f2f-consistent-semantic-segmentation/f7e362c5cf0d3461435e8af4f334cca974e521dc/data/cityscapes_video/leftImg8bit/val/stuttgart_00/stuttgart_00_000000_000001_leftImg8bit.png -------------------------------------------------------------------------------- /data/cityscapes_video/leftImg8bit/val/stuttgart_00/stuttgart_00_000000_000002_leftImg8bit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrebol/f2f-consistent-semantic-segmentation/f7e362c5cf0d3461435e8af4f334cca974e521dc/data/cityscapes_video/leftImg8bit/val/stuttgart_00/stuttgart_00_000000_000002_leftImg8bit.png -------------------------------------------------------------------------------- /data/cityscapes_video/leftImg8bit/val/stuttgart_00/stuttgart_00_000000_000003_leftImg8bit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrebol/f2f-consistent-semantic-segmentation/f7e362c5cf0d3461435e8af4f334cca974e521dc/data/cityscapes_video/leftImg8bit/val/stuttgart_00/stuttgart_00_000000_000003_leftImg8bit.png -------------------------------------------------------------------------------- /data/cityscapes_video/leftImg8bit/val/stuttgart_00/stuttgart_00_000000_000004_leftImg8bit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrebol/f2f-consistent-semantic-segmentation/f7e362c5cf0d3461435e8af4f334cca974e521dc/data/cityscapes_video/leftImg8bit/val/stuttgart_00/stuttgart_00_000000_000004_leftImg8bit.png -------------------------------------------------------------------------------- /data/cityscapes_video/leftImg8bit/val/stuttgart_00/stuttgart_00_000000_000005_leftImg8bit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrebol/f2f-consistent-semantic-segmentation/f7e362c5cf0d3461435e8af4f334cca974e521dc/data/cityscapes_video/leftImg8bit/val/stuttgart_00/stuttgart_00_000000_000005_leftImg8bit.png -------------------------------------------------------------------------------- /data/labels.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrebol/f2f-consistent-semantic-segmentation/f7e362c5cf0d3461435e8af4f334cca974e521dc/data/labels.png -------------------------------------------------------------------------------- /data/metrics.py: -------------------------------------------------------------------------------- 1 | # Adapted from score written by wkentaro 2 | # https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py 3 | # 4 | import numpy as np 5 | 6 | 7 | class RunningScore(object): 8 | def __init__(self, n_classes): 9 | self.n_classes = n_classes 10 | self.confusion_matrix = np.zeros((n_classes, n_classes)) 11 | 12 | def _fast_hist(self, label_true, label_pred, n_class): # input labels as vectors 13 | valid_mask = (label_true >= 0) & (label_true < n_class) 14 | hist = np.bincount( 15 | n_class * label_true[valid_mask].astype(int) + label_pred[valid_mask], 16 | # dim0 (y-axis): nr_classes * true_label + dim1 (x-axis): pred_label 17 | minlength=n_class ** 2, # to ensure quadratic shape 18 | ).reshape(n_class, n_class) 19 | return hist 20 | 21 | def update(self, label_trues, label_preds): 22 | for lt, lp in zip(label_trues, label_preds): # zip through batch_size 23 | self.confusion_matrix += self._fast_hist( 24 | lt.flatten(), lp.flatten(), 
self.n_classes 25 | ) 26 | 27 | def get_scores(self): 28 | hist = self.confusion_matrix 29 | 30 | # Accuracy for all pixels 31 | acc = np.diag(hist).sum() / hist.sum() # sum of confusion matrix is equal to all pixels 32 | 33 | # For each class Intersection over Union (IU) score is: 34 | # true positive / (true positive + false positive + false negative) 35 | np.seterr(invalid='ignore') # , divide='ignore') 36 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) # div by 0 if class not in val set 37 | mean_iu = np.nanmean(iu) # nanmean ignores NaNs (happens if class not in val set), RuntimeWarning is raised 38 | return acc, mean_iu 39 | 40 | def reset(self): 41 | self.confusion_matrix = np.zeros((self.n_classes, self.n_classes)) 42 | 43 | def get_consistency(self, gt, pred): # time, bs, h, w 44 | consistency = np.empty((gt.shape[0] - 1)) 45 | for t in range(1, gt.shape[0]): 46 | valid_mask = (gt[t - 1] >= 0) & (gt[t - 1] < self.n_classes) & (gt[t] >= 0) & (gt[t] < self.n_classes) 47 | diff_pred_valid = ((pred[t - 1] != pred[t]) & valid_mask) 48 | diff_gt_valid = ((gt[t - 1] != gt[t]) & valid_mask) 49 | inconsistencies_pred = diff_pred_valid & np.logical_not(diff_gt_valid) 50 | consistency[t - 1] = 1 - (inconsistencies_pred.sum() / (valid_mask & np.logical_not(diff_gt_valid)).sum()) 51 | return consistency 52 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | # This file is part of f2fcss. 2 | # 3 | # Copyright (C) 2020 Manuel Rebol 4 | # Patrick Knoebelreiter 5 | # Institute for Computer Graphics and Vision, Graz University of Technology 6 | # https://www.tugraz.at/institute/icg/teams/team-pock/ 7 | # 8 | # f2fcss is free software: you can redistribute it and/or modify it under the 9 | # terms of the GNU Affero General Public License as published by the Free Software 10 | # Foundation, either version 3 of the License, or any later version. 11 | # 12 | # f2fcss is distributed in the hope that it will be useful, but WITHOUT ANY 13 | # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS 14 | # FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU Affero General Public License 17 | # along with this program. If not, see . 
18 | # 19 | import numpy as np 20 | import yaml 21 | import argparse 22 | import torch 23 | from PIL import Image 24 | from torch.utils import data 25 | from datetime import datetime 26 | from data.cityscapes_loader import CityscapesLoader 27 | from data.metrics import RunningScore 28 | from utils.utils import * 29 | from os import path, makedirs 30 | from model.esp_net import ESPNet_L1b 31 | 32 | 33 | def eval(val_loader, net, tensorboard, device, batch_nr_training, image_save_path, save_pred_color=False, 34 | save_label_ids=False): 35 | batch_size = val_loader.batch_size 36 | assert batch_size == 1 # == val_loader.dataset.time_interval 37 | 38 | dataset_name = path.basename(val_loader.dataset.root[0]) + '_' + val_loader.dataset.split[0] 39 | if save_pred_color: 40 | image_save_color_path = path.join(image_save_path, "{}-dataset_color".format(dataset_name)) 41 | makedirs(image_save_color_path, exist_ok=True) 42 | if save_label_ids: 43 | image_save_label_path = path.join(image_save_path, "{}-dataset_labelIds".format(dataset_name)) 44 | makedirs(image_save_label_path, exist_ok=True) 45 | 46 | single_frame_net = False 47 | lstm_net = False 48 | if net.__class__.__name__ in ['ESPNet', 'ESPNet_C']: 49 | single_frame_net = True 50 | if net.__class__.__name__ in ['ESPNet_L1b', 'ESPNet_L1c', 'ESPNet_L1d', 'ESPNet_C_L1b']: 51 | lstm_net = True 52 | 53 | consistency_metric = [] 54 | running_metrics = RunningScore(19) 55 | 56 | for batch_nr, loader_data in enumerate(val_loader): 57 | images = loader_data[0] 58 | gt = loader_data[1] 59 | file_names = np.array(loader_data[2])[:, :, 0] # time, img/lbl, bs 60 | seq_nr_in_scene = loader_data[5][0].item() 61 | scene_nr = loader_data[6][0].item() 62 | 63 | images = images.transpose(0, 1).to(device) # => TimeStep, BatchSize, ... 
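        # Shape note (assuming the default eval config): the loader yields images of shape
        # (batch=1, seq_len, 3, 512, 1024); the transpose above makes the tensor time-major,
        # (seq_len, batch=1, 3, 512, 1024), as the ConvLSTM expects. Labels are not rescaled
        # by the loader, so gt keeps the full 1024x2048 resolution (or is a 1x1 dummy when no
        # ground truth file exists).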
64 | gt = gt[0].numpy() 65 | 66 | # Model prediction 67 | lstm_states = None 68 | if lstm_net and (seq_nr_in_scene != 0): 69 | lstm_states = lstm_states_prev 70 | if lstm_net: 71 | outputs, lstm_states_prev = net.forward(images, lstm_states) 72 | elif single_frame_net and images.shape[0] > 1: 73 | outputs = net.forward(images.transpose(0, 1)) 74 | outputs = outputs.transpose(0, 1) 75 | else: 76 | outputs = net.forward(images) 77 | prediction = outputs.data.argmax(2).cpu() 78 | prediction = prediction.reshape(prediction.shape[0] * prediction.shape[1], 1, prediction.shape[2], 79 | prediction.shape[3]) 80 | prediction = torch.nn.functional.interpolate(prediction.type(torch.float32), 81 | size=val_loader.dataset.full_resolution, mode='nearest').type( 82 | torch.uint8)[:, 0, :, :].numpy() 83 | if seq_nr_in_scene != 0: # if not scene start 84 | prediction_scene = np.concatenate( 85 | (prediction_prev[None, :, :], prediction)) # stack prev pred here in time dimension 86 | gt_scene = np.concatenate((gt_prev[None, :, :], gt)) 87 | else: 88 | prediction_scene = prediction 89 | gt_scene = gt 90 | 91 | # Compute metrics 92 | if gt.shape[2] != 1: 93 | running_metrics.update(gt, prediction) 94 | if prediction_scene.shape[0] > 1: # check if not time_steps = 1 and first sequence 95 | consistency_metric.append(running_metrics.get_consistency(gt_scene, prediction_scene)) 96 | mean_consistency = np.mean(consistency_metric[-1], axis=0) 97 | tensorboard.add_scalar('Validation_{}-dataset/Interval-Consistency'.format(dataset_name), 98 | mean_consistency, batch_nr) 99 | 100 | # Save images 101 | for t in range(prediction.shape[0]): 102 | if save_pred_color: 103 | Image.fromarray(val_loader.dataset.segmap_to_color(prediction[t], gt[t], gt.shape[2] != 1, False), 104 | mode='RGB').save(path.join(image_save_color_path, 105 | "{}.png".format( 106 | path.splitext(path.basename(file_names[t, 0]))[ 107 | 0].replace("leftImg8bit", "pred_color")))) 108 | if save_label_ids: 109 | Image.fromarray( 110 | CityscapesLoader.segmap_to_labelId_static(prediction[t]), 111 | mode='L').save(path.join(image_save_label_path, "{}.png".format( 112 | path.splitext(path.basename(file_names[t, 1]))[0]))) 113 | 114 | file_names_print = np.stack((file_names[0, 0], file_names[-1, 0])) # file_in_seq, img/lbl 115 | print_func = np.vectorize(lambda x: path.basename(x).replace('_leftImg8bit.png', '')) 116 | file_names_print = print_func(file_names_print).transpose() 117 | output_string = "<> Batch_nr [%d/%d] Batch_total[%d] " \ 118 | "%s Scene_nr-Seq_nr_in_scene [%s-%s]" % \ 119 | (batch_nr + 1, len(val_loader), batch_nr_training, 120 | file_names_print, scene_nr, seq_nr_in_scene) 121 | print(BColors.GREEN + output_string + BColors.RESET) 122 | 123 | prediction_prev = prediction[-1] 124 | gt_prev = gt[-1] 125 | 126 | # Final reports 127 | if consistency_metric: 128 | mean_consistency = np.concatenate(consistency_metric, axis=0).mean(axis=0) 129 | print("Mean Consistency: {:5.2f}%".format(mean_consistency * 100)) 130 | tensorboard.add_scalar("Validation_{}-dataset/Consistency".format(dataset_name, ), mean_consistency, 131 | batch_nr_training) 132 | if val_loader.dataset.split[0] != 'test': 133 | acc, mean_iou = running_metrics.get_scores() 134 | print('Accuracy: {:5.2f}% mIoU: {:5.2f}%'.format(acc * 100, mean_iou * 100)) 135 | tensorboard.add_scalar("Validation_{}-dataset/Accuracy".format(dataset_name), acc, batch_nr_training) 136 | tensorboard.add_scalar("Validation_{}-dataset/Mean_IoU".format(dataset_name), mean_iou, batch_nr_training) 137 | 
running_metrics.reset() 138 | print("Validation on", dataset_name, " dataset finished!") 139 | 140 | 141 | if __name__ == "__main__": 142 | parser = argparse.ArgumentParser(description="config") 143 | parser.add_argument( 144 | "--config", 145 | nargs="?", 146 | type=str, 147 | default="config/eval.yml", 148 | help="Configuration file to use" 149 | ) 150 | args = parser.parse_args() 151 | with open(args.config) as fp: 152 | cfg = yaml.safe_load(fp) 153 | 154 | # CUDA 155 | torch.cuda.init() 156 | cuda_available = torch.cuda.is_available() 157 | print("CUDA available:", cuda_available, ". Number of CUDA devices = ", torch.cuda.device_count()) 158 | device = torch.device( 159 | "cuda:{}".format(cfg['model']['gpu'] - 1) if torch.cuda.is_available() and cfg['model']['gpu'] else "cpu") 160 | 161 | batch_nr_training = 0 162 | if cfg['model']['gpu']: 163 | checkpoint = torch.load(cfg['model']['load']) 164 | else: 165 | checkpoint = torch.load(cfg['model']['load'], map_location='cpu') 166 | model_name = checkpoint["model_name"] 167 | 168 | # Setup dataloader 169 | v_loader = CityscapesLoader(root=cfg['data']['dataset'], 170 | split=cfg['data']['splits'], 171 | scene_group_size=0, 172 | nr_of_scenes=cfg['data']['nr_of_scenes'], 173 | nr_of_sequences=cfg['data']['nr_of_sequences'], 174 | seq_len=cfg['data']['sequence_length'], 175 | sequence_overlap=0, 176 | augmentations=['uniform'], 177 | scale_size=cfg['model']['input_size'], 178 | ) 179 | val_loader = data.DataLoader(v_loader, batch_size=1, num_workers=4, shuffle=False, ) 180 | 181 | # Load model 182 | if path.isfile(cfg['model']['load']): 183 | n_classes = CityscapesLoader.n_classes 184 | model_class = globals()[model_name] 185 | if model_class == ESPNet_L1b: 186 | if checkpoint['model_state']['encoder.clstm.cell.peephole_weights'] is not None: 187 | cell_type = 5 188 | if 'encoder.batch_norm.weight' in checkpoint['model_state']: 189 | activation_function = 'tanh' 190 | elif 'encoder.clstm.cell.activation_function.weight' in checkpoint['model_state']: 191 | activation_function = 'prelu' 192 | else: 193 | activation_function = 'lrelu' 194 | if checkpoint['model_state']['encoder.clstm.c0'].requires_grad: 195 | state_init = 'learn' 196 | net = ESPNet_L1b('default', activation_function, cfg['model']['input_size'], 197 | device, torch.float32, 1, state_init, cell_type, 1, 1, 198 | checkpoint['model_state']['encoder.clstm.cell.convolution.weight'].shape[-1], 0) 199 | if (cfg['model']['input_size'][0]//8 != checkpoint['model_state']['encoder.clstm.h0'].data.shape[1] or 200 | cfg['model']['input_size'][1]//8 != checkpoint['model_state']['encoder.clstm.h0'].data.shape[2]): 201 | print('WARNING: The model was trained with a different input size. 
Evaluation at different scale' 202 | ' leads to worse results.') 203 | checkpoint['model_state']['encoder.clstm.h0'].data = torch.nn.functional.interpolate( 204 | checkpoint['model_state']['encoder.clstm.h0'].data.unsqueeze(0), 205 | size=[cfg['model']['input_size'][0]//8, cfg['model']['input_size'][1]//8], 206 | mode='bilinear', 207 | align_corners=False).squeeze(0) 208 | checkpoint['model_state']['encoder.clstm.c0'].data = torch.nn.functional.interpolate( 209 | checkpoint['model_state']['encoder.clstm.c0'].data.unsqueeze(0), 210 | size=[cfg['model']['input_size'][0] // 8, cfg['model']['input_size'][1] // 8], 211 | mode='bilinear', 212 | align_corners=False).squeeze(0) 213 | net.load_state_dict(checkpoint["model_state"]) 214 | batch_nr_training = checkpoint["batch_total"] 215 | print("Loaded checkpoint '{}' (iteration {})".format(cfg['model']['load'], checkpoint["iter"])) 216 | else: 217 | raise Exception("No checkpoint found at '{}'".format(cfg['model']['load'])) 218 | 219 | # Set paths 220 | output_path = path.join('output', datetime.now().strftime("%Y-%m-%d_%H-%M-%S")) 221 | image_path = path.join(output_path, 'images') 222 | makedirs(image_path) 223 | 224 | # Load tensorboard 225 | tensorboard = load_tensorboard(output_path, batch_nr_training) 226 | 227 | net.eval() 228 | net = net.to(device) 229 | with torch.no_grad(): 230 | eval(val_loader, net, tensorboard, device, batch_nr_training, image_path, 231 | cfg['output']['save_pred_color'], cfg['output']['save_label_ids']) 232 | 233 | tensorboard.close() 234 | -------------------------------------------------------------------------------- /model/ESPNetL1b-Cityscapes-pretrained.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrebol/f2f-consistent-semantic-segmentation/f7e362c5cf0d3461435e8af4f334cca974e521dc/model/ESPNetL1b-Cityscapes-pretrained.pkl -------------------------------------------------------------------------------- /model/convolutional_lstm.py: -------------------------------------------------------------------------------- 1 | # This file is part of f2fcss. 2 | # 3 | # Copyright (C) 2020 Manuel Rebol 4 | # Patrick Knoebelreiter 5 | # Institute for Computer Graphics and Vision, Graz University of Technology 6 | # https://www.tugraz.at/institute/icg/teams/team-pock/ 7 | # 8 | # f2fcss is free software: you can redistribute it and/or modify it under the 9 | # terms of the GNU Affero General Public License as published by the Free Software 10 | # Foundation, either version 3 of the License, or any later version. 11 | # 12 | # f2fcss is distributed in the hope that it will be useful, but WITHOUT ANY 13 | # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS 14 | # FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU Affero General Public License 17 | # along with this program. If not, see . 
18 | # 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | 23 | 24 | class ConvLSTMCell5(nn.Module): # normal conv with peephole connections 25 | def __init__(self, input_channels, hidden_channels, kernel_size, dilation, activation_function): 26 | super(ConvLSTMCell5, self).__init__() 27 | self.input_channels = input_channels 28 | self.hidden_channels = hidden_channels 29 | self.kernel_size = kernel_size 30 | self.dilation = dilation 31 | self.num_gates = 4 # f i g o 32 | self.padding = int((kernel_size - 1) / 2) 33 | self.convolution = nn.Conv2d(self.input_channels + self.hidden_channels, 4 * self.hidden_channels, 34 | self.kernel_size, stride=1, padding=self.padding, dilation=dilation) 35 | self.activation_function = activation_function 36 | self.peephole_weights = nn.Parameter(torch.zeros(3, self.hidden_channels), requires_grad=True) 37 | 38 | def forward(self, x, h, c): # batch, channel, height, width 39 | x_stack_h = torch.cat((x, h), dim=1) 40 | A = self.convolution((x_stack_h)) 41 | split_size = int(A.shape[1] / self.num_gates) 42 | (ai, af, ao, ag) = torch.split(A, split_size, dim=1) 43 | f = torch.sigmoid(af + c * self.peephole_weights[1, :, None, None]) 44 | i = torch.sigmoid(ai + c * self.peephole_weights[0, :, None, None]) 45 | g = self.activation_function(ag) 46 | o = torch.sigmoid(ao + c * self.peephole_weights[2, :, None, None]) 47 | new_c = f * c + i * g 48 | new_h = o * self.activation_function(new_c) 49 | return new_h, new_c 50 | 51 | 52 | class ConvLSTM(nn.Module): 53 | # input_channels corresponds to the first input feature map 54 | # hidden state is a list of succeeding lstm layers. 55 | def __init__(self, input_channels, hidden_channels, kernel_size, activation_function, device, 56 | dtype, state_init, cell_type, batch_size, time_steps, overlap, dilation=1, init='default', 57 | is_stateful=True, state_img_size=None): 58 | super(ConvLSTM, self).__init__() 59 | 60 | self.input_channels = input_channels 61 | self.hidden_channels = hidden_channels 62 | self.kernel_size = kernel_size 63 | self.dilation = dilation 64 | if activation_function == 'tanh': 65 | activation_function = torch.tanh 66 | elif activation_function == 'lrelu': 67 | activation_function = F.leaky_relu 68 | elif activation_function == 'prelu': 69 | activation_function = nn.PReLU() 70 | self.cell_type = cell_type 71 | if cell_type == 5: 72 | self.cell = ConvLSTMCell5(self.input_channels, self.hidden_channels, self.kernel_size, self.dilation, 73 | activation_function) 74 | self.is_stateful = is_stateful 75 | self.dtype = dtype 76 | self.device = device 77 | self.state_init = state_init 78 | self.state_img_size = state_img_size 79 | 80 | self.update_parameters(batch_size, time_steps, overlap) 81 | self.init_states(state_img_size, state_init) 82 | 83 | # initialization 84 | if init == 'default': 85 | self.cell.convolution.bias.data.fill_(0) # init all biases with 0 86 | nn.init.xavier_normal_(self.cell.convolution.weight.data[ 87 | 0 * self.cell.hidden_channels: 1 * self.cell.hidden_channels]) # sigmoid, i 88 | nn.init.xavier_normal_(self.cell.convolution.weight.data[ 89 | 1 * self.cell.hidden_channels: 2 * self.cell.hidden_channels]) # sigmoid, f 90 | self.cell.convolution.bias.data[1 * self.cell.hidden_channels: 2 * self.cell.hidden_channels].fill_( 91 | 0.1) # f bias 92 | nn.init.xavier_normal_(self.cell.convolution.weight.data[ 93 | 2 * self.cell.hidden_channels: 3 * self.cell.hidden_channels]) # sigmoid, o 94 | if cell_type == 5: 95 | 
nn.init.constant_(self.cell.peephole_weights, 0.1) 96 | if activation_function == 'tanh': 97 | nn.init.xavier_normal_(self.cell.convolution.weight.data[ 98 | 3 * self.cell.hidden_channels: 4 * self.cell.hidden_channels]) # tanh, g 99 | elif activation_function in ['lrelu', 'prelu']: 100 | nn.init.kaiming_normal_( 101 | self.cell.convolution.weight.data[3 * self.cell.hidden_channels: 4 * self.cell.hidden_channels], 102 | nonlinearity='leaky_relu') # lrelu, g 103 | 104 | def forward(self, inputs, states): # inputs shape: time_step, batch_size, channels, height, width 105 | new_states = None 106 | time_steps = inputs.shape[0] 107 | outputs = torch.empty(time_steps, self.batch_size, self.hidden_channels, inputs.shape[3], 108 | inputs.shape[4], dtype=self.dtype, device=self.device) 109 | if self.is_stateful == 0 or states is None: 110 | h = nn.functional.interpolate(self.h0.expand(self.batch_size, -1, -1, -1), 111 | size=(inputs.shape[3], inputs.shape[4]), mode='bilinear', align_corners=True) 112 | c = nn.functional.interpolate(self.c0.expand(self.batch_size, -1, -1, -1), 113 | size=(inputs.shape[3], inputs.shape[4]), mode='bilinear', align_corners=True) 114 | print("Init LSTM") 115 | else: 116 | c = states[0] 117 | h = states[1] 118 | 119 | for time_step in range(time_steps): 120 | x = inputs[time_step] 121 | h, c = self.cell(x, h, c) # to run hooks (pre, post) and .forward() 122 | 123 | if self.cell_type == 4: 124 | outputs[time_step] = h[:, :, 0] 125 | else: 126 | outputs[time_step] = h 127 | if self.is_stateful and time_step == time_steps - (self.overlap + 1): 128 | new_states = torch.stack((c.data, h.data)) 129 | 130 | return outputs, new_states 131 | 132 | def init_states(self, state_size, state_init): 133 | if state_init == 'zero': 134 | self.h0 = nn.Parameter(torch.zeros(self.hidden_channels, state_size[0], state_size[1], dtype=self.dtype), 135 | requires_grad=False) 136 | self.c0 = nn.Parameter(torch.zeros(self.hidden_channels, state_size[0], state_size[1], dtype=self.dtype), 137 | requires_grad=False) 138 | elif state_init == 'rand': # cell_state rand [0,1) init 139 | self.h0 = nn.Parameter(torch.rand(self.hidden_channels, state_size[0], state_size[1], dtype=self.dtype), 140 | requires_grad=False) 141 | self.c0 = nn.Parameter(torch.rand(self.hidden_channels, state_size[0], state_size[1], dtype=self.dtype), 142 | requires_grad=False) 143 | elif state_init == 'learn': 144 | if self.cell_type == 4: 145 | self.h0 = nn.Parameter(torch.zeros(19, 4, state_size[0], state_size[1], dtype=self.dtype), 146 | requires_grad=True) 147 | self.c0 = nn.Parameter(torch.zeros(19, 4, state_size[0], state_size[1], dtype=self.dtype), 148 | requires_grad=True) 149 | else: 150 | self.h0 = nn.Parameter( 151 | torch.zeros(self.hidden_channels, state_size[0], state_size[1], dtype=self.dtype), 152 | requires_grad=True) 153 | self.c0 = nn.Parameter( 154 | torch.zeros(self.hidden_channels, state_size[0], state_size[1], dtype=self.dtype), 155 | requires_grad=True) 156 | 157 | def update_parameters(self, batch_size, time_steps, overlap): 158 | self.time_steps = time_steps 159 | self.batch_size = batch_size 160 | self.overlap = overlap 161 | -------------------------------------------------------------------------------- /model/esp_net.py: -------------------------------------------------------------------------------- 1 | # Adapted from code written by Sachin Mehta 2 | # https://github.com/sacmehta/ESPNet/tree/master/test 3 | # 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import 
model.convolutional_lstm as clstm 8 | 9 | 10 | # Convolution with succeeding batch normalization and PReLU activation 11 | class CBR(nn.Module): 12 | ''' 13 | This class defines the convolution layer with batch normalization and PReLU activation 14 | ''' 15 | 16 | def __init__(self, nIn, nOut, kSize, stride=1): 17 | ''' 18 | :param nIn: number of input channels 19 | :param nOut: number of output channels 20 | :param kSize: kernel size 21 | :param stride: stride rate for down-sampling. Default is 1 22 | ''' 23 | super().__init__() 24 | padding = int((kSize - 1) / 2) 25 | # self.conv = nn.Conv2d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False) 26 | self.conv = nn.Conv2d(nIn, nOut, (kSize, kSize), stride=stride, padding=(padding, padding), bias=False) 27 | # self.conv1 = nn.Conv2d(nOut, nOut, (1, kSize), stride=1, padding=(0, padding), bias=False) 28 | self.bn = nn.BatchNorm2d(nOut, eps=1e-03) 29 | self.act = nn.PReLU(nOut) 30 | 31 | def forward(self, input): 32 | ''' 33 | :param input: input feature map 34 | :return: transformed feature map 35 | ''' 36 | output = self.conv(input) 37 | # output = self.conv1(output) 38 | output = self.bn(output) 39 | output = self.act(output) 40 | return output 41 | 42 | 43 | # Batch normalization with succeeding PReLU activation 44 | class BR(nn.Module): 45 | ''' 46 | This class groups the batch normalization and PReLU activation 47 | ''' 48 | 49 | def __init__(self, nOut): 50 | ''' 51 | :param nOut: output feature maps 52 | ''' 53 | super().__init__() 54 | self.bn = nn.BatchNorm2d(nOut, eps=1e-03) 55 | self.act = nn.PReLU(nOut) 56 | 57 | def forward(self, input): 58 | ''' 59 | :param input: input feature map 60 | :return: normalized and thresholded feature map 61 | ''' 62 | output = self.bn(input) 63 | output = self.act(output) 64 | return output 65 | 66 | 67 | # Convolution with succeeding batch normalization 68 | class CB(nn.Module): 69 | ''' 70 | This class groups the convolution and batch normalization 71 | ''' 72 | 73 | def __init__(self, nIn, nOut, kSize, stride=1): 74 | ''' 75 | :param nIn: number of input channels 76 | :param nOut: number of output channels 77 | :param kSize: kernel size 78 | :param stride: optinal stide for down-sampling 79 | ''' 80 | super().__init__() 81 | padding = int((kSize - 1) / 2) 82 | self.conv = nn.Conv2d(nIn, nOut, (kSize, kSize), stride=stride, padding=(padding, padding), bias=False) 83 | self.bn = nn.BatchNorm2d(nOut, eps=1e-03) 84 | 85 | def forward(self, input): 86 | ''' 87 | 88 | :param input: input feature map 89 | :return: transformed feature map 90 | ''' 91 | output = self.conv(input) 92 | output = self.bn(output) 93 | return output 94 | 95 | 96 | # Convolution with zero-padding 97 | class C(nn.Module): 98 | ''' 99 | This class is for a convolutional layer. 
100 | ''' 101 | 102 | def __init__(self, nIn, nOut, kSize, stride=1): 103 | ''' 104 | 105 | :param nIn: number of input channels 106 | :param nOut: number of output channels 107 | :param kSize: kernel size 108 | :param stride: optional stride rate for down-sampling 109 | ''' 110 | super().__init__() 111 | padding = int((kSize - 1) / 2) 112 | self.conv = nn.Conv2d(nIn, nOut, (kSize, kSize), stride=stride, padding=(padding, padding), bias=False) 113 | 114 | def forward(self, input): 115 | ''' 116 | :param input: input feature map 117 | :return: transformed feature map 118 | ''' 119 | output = self.conv(input) 120 | return output 121 | 122 | 123 | # Dilated convolution with zero-padding 124 | class CDilated(nn.Module): 125 | ''' 126 | This class defines the dilated convolution. 127 | ''' 128 | 129 | def __init__(self, nIn, nOut, kSize, stride=1, d=1): 130 | ''' 131 | :param nIn: number of input channels 132 | :param nOut: number of output channels 133 | :param kSize: kernel size 134 | :param stride: optional stride rate for down-sampling 135 | :param d: optional dilation rate 136 | ''' 137 | super().__init__() 138 | padding = int((kSize - 1) / 2) * d 139 | self.conv = nn.Conv2d(nIn, nOut, (kSize, kSize), stride=stride, padding=(padding, padding), bias=False, 140 | dilation=d) 141 | 142 | def forward(self, input): 143 | ''' 144 | :param input: input feature map 145 | :return: transformed feature map 146 | ''' 147 | output = self.conv(input) 148 | return output 149 | 150 | 151 | # ESP block with downsampling (red): Spatial dimensions /2 e.g. 256x512 -> 128x256 152 | class DownSamplerB(nn.Module): 153 | def __init__(self, nIn, nOut): 154 | super().__init__() 155 | n = int(nOut / 5) 156 | n1 = nOut - 4 * n 157 | self.c1 = C(nIn, n, 3, 2) # os=2: difference to ESP block 158 | self.d1 = CDilated(n, n1, 3, 1, 1) 159 | self.d2 = CDilated(n, n, 3, 1, 2) 160 | self.d4 = CDilated(n, n, 3, 1, 4) 161 | self.d8 = CDilated(n, n, 3, 1, 8) 162 | self.d16 = CDilated(n, n, 3, 1, 16) 163 | self.bn = nn.BatchNorm2d(nOut, eps=1e-3) 164 | self.act = nn.PReLU(nOut) 165 | 166 | def forward(self, input): 167 | output1 = self.c1(input) 168 | d1 = self.d1(output1) # convolution with different dil on 169 | d2 = self.d2(output1) 170 | d4 = self.d4(output1) 171 | d8 = self.d8(output1) 172 | d16 = self.d16(output1) 173 | 174 | add1 = d2 # add this different dilations 175 | add2 = add1 + d4 176 | add3 = add2 + d8 177 | add4 = add3 + d16 178 | 179 | combine = torch.cat([d1, add1, add2, add3, add4], 1) 180 | # combine_in_out = input + combine 181 | output = self.bn(combine) 182 | output = self.act(output) 183 | return output 184 | 185 | 186 | # ESP block: spatial dim stay the same 187 | class DilatedParllelResidualBlockB(nn.Module): 188 | ''' 189 | This class defines the ESP block, which is based on the following principle 190 | Reduce ---> Split ---> Transform --> Merge 191 | ''' 192 | 193 | def __init__(self, nIn, nOut, add=True): 194 | ''' 195 | :param nIn: number of input channels 196 | :param nOut: number of output channels 197 | :param add: if true, add a residual connection through identity operation. 
You can use projection too as 198 | in ResNet paper, but we avoid to use it if the dimensions are not the same because we do not want to 199 | increase the module complexity 200 | ''' 201 | super().__init__() 202 | n = int(nOut / 5) 203 | n1 = nOut - 4 * n 204 | self.c1 = C(nIn, n, 1, 1) 205 | self.d1 = CDilated(n, n1, 3, 1, 1) # dilation rate of 2^0 206 | self.d2 = CDilated(n, n, 3, 1, 2) # dilation rate of 2^1 207 | self.d4 = CDilated(n, n, 3, 1, 4) # dilation rate of 2^2 208 | self.d8 = CDilated(n, n, 3, 1, 8) # dilation rate of 2^3 209 | self.d16 = CDilated(n, n, 3, 1, 16) # dilation rate of 2^4 210 | self.bn = BR(nOut) 211 | self.add = add 212 | 213 | def forward(self, input): 214 | ''' 215 | :param input: input feature map 216 | :return: transformed feature map 217 | ''' 218 | # reduce 219 | output1 = self.c1(input) 220 | # split and transform 221 | d1 = self.d1(output1) 222 | d2 = self.d2(output1) 223 | d4 = self.d4(output1) 224 | d8 = self.d8(output1) 225 | d16 = self.d16(output1) 226 | 227 | # heirarchical fusion for de-gridding 228 | add1 = d2 229 | add2 = add1 + d4 230 | add3 = add2 + d8 231 | add4 = add3 + d16 232 | 233 | # merge 234 | combine = torch.cat([d1, add1, add2, add3, add4], 1) 235 | 236 | # if residual version 237 | if self.add: 238 | combine = input + combine 239 | output = self.bn(combine) 240 | return output 241 | 242 | 243 | # Apply avg-pooling n-times, RGB images with red arrow 244 | class InputProjectionA(nn.Module): 245 | ''' 246 | This class projects the input image to the same spatial dimensions as the feature map. 247 | For example, if the input image is 512 x512 x3 and spatial dimensions of feature map size are 56x56xF, then 248 | this class will generate an output of 56x56x3 249 | ''' 250 | 251 | def __init__(self, samplingTimes): 252 | ''' 253 | :param samplingTimes: The rate at which you want to down-sample the image 254 | ''' 255 | super().__init__() 256 | self.pool = nn.ModuleList() 257 | for i in range(0, samplingTimes): 258 | # pyramid-based approach for down-sampling 259 | self.pool.append(nn.AvgPool2d(3, stride=2, padding=1)) 260 | 261 | def forward(self, input): 262 | ''' 263 | :param input: Input RGB Image 264 | :return: down-sampled image (pyramid-based approach) 265 | ''' 266 | for pool in self.pool: 267 | input = pool(input) 268 | return input 269 | 270 | 271 | # ESPNet-C, Encoder part 272 | class ESPNet_Encoder(nn.Module): 273 | ''' 274 | This class defines the ESPNet-C network in the paper 275 | ''' 276 | 277 | def __init__(self, classes=19, p=2, q=3): 278 | ''' 279 | :param classes: number of classes in the dataset. 
Default is 19 for the cityscapes_video 280 | :param p: depth multiplier 281 | :param q: depth multiplier 282 | ''' 283 | super().__init__() 284 | self.level1 = CBR(3, 16, 3, 2) 285 | self.sample1 = InputProjectionA(1) 286 | self.sample2 = InputProjectionA(2) 287 | 288 | self.b1 = BR(16 + 3) 289 | self.level2_0 = DownSamplerB(16 + 3, 64) 290 | 291 | self.level2 = nn.ModuleList() 292 | for i in range(0, p): 293 | self.level2.append(DilatedParllelResidualBlockB(64, 64)) 294 | self.b2 = BR(128 + 3) 295 | 296 | self.level3_0 = DownSamplerB(128 + 3, 128) 297 | self.level3 = nn.ModuleList() 298 | for i in range(0, q): 299 | self.level3.append(DilatedParllelResidualBlockB(128, 128)) 300 | self.b3 = BR(256) 301 | 302 | self.classifier = C(256, classes, 1, 1) 303 | 304 | def forward(self, input): 305 | ''' 306 | :param input: Receives the input RGB image 307 | :return: the transformed feature map with spatial dimensions 1/8th of the input image 308 | ''' 309 | input = input.squeeze(0) # 5Ax -> 4Ax 310 | 311 | output0 = self.level1(input) # 512x1024 --> 256x512 312 | inp1 = self.sample1(input) # scale down RGB 313 | inp2 = self.sample2(input) # scale down RGB 314 | 315 | output0_cat = self.b1(torch.cat([output0, inp1], 1)) # Concat_0 316 | output1_0 = self.level2_0(output0_cat) # down-sampled, ESP_red_0 317 | 318 | for i, layer in enumerate(self.level2): # p ESP-blocks, p..alpha_2 319 | if i == 0: 320 | output1 = layer(output1_0) 321 | else: 322 | output1 = layer(output1) 323 | 324 | output1_cat = self.b2(torch.cat([output1, output1_0, inp2], 1)) # Concat_1 325 | 326 | output2_0 = self.level3_0(output1_cat) # down-sampled, ESP_red_1 327 | for i, layer in enumerate(self.level3): # q ESP-blocks, q..alpha_3 328 | if i == 0: 329 | output2 = layer(output2_0) 330 | else: 331 | output2 = layer(output2) 332 | 333 | output2_cat = self.b3(torch.cat([output2_0, output2], 1)) # Concat_2 334 | 335 | classifier = self.classifier(output2_cat) 336 | 337 | classifier = F.softmax(classifier, dim=1) 338 | classifier = classifier.unsqueeze(0) # 4Ax -> 5Ax 339 | return classifier 340 | 341 | 342 | class ESPNet(nn.Module): 343 | ''' 344 | This class defines the ESPNet network 345 | ''' 346 | 347 | def __init__(self, lstm_filter_size, device, dtype, state_init, cell_type, batch_size, time_steps, overlap, 348 | val_img_size, 349 | lstm_activation_function, classes=19, p=2, q=3, encoder_type=None, encoderFile=None): 350 | ''' 351 | :param classes: number of classes in the dataset. Default is 19 for the cityscapes_video 352 | :param p: depth multiplier 353 | :param q: depth multiplier 354 | :param encoderFile: pretrained encoder weights. Recall that we first trained the ESPNet-C and then attached the 355 | RUM-based light weight decoder. See paper for more details. 
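:param lstm_filter_size: kernel size of the ConvLSTM cell (only used with the ESPNet_C_L1b encoder)
:param device: torch device used by the ConvLSTM
:param dtype: tensor dtype used for the ConvLSTM states
:param state_init: ConvLSTM state initialization scheme, e.g. 'learn' for learned initial states
:param cell_type: ConvLSTM cell variant, forwarded to model.convolutional_lstm.ConvLSTM
:param batch_size: number of image sequences per batch
:param time_steps: number of sequential frames passed through the ConvLSTM
:param overlap: frame overlap between consecutive sequences, forwarded to the ConvLSTM
:param val_img_size: [height, width] of the input images; the ConvLSTM state is kept at 1/8 of this resolution
:param lstm_activation_function: activation inside the ConvLSTM; 'tanh' additionally enables batch normalization before the LSTM
:param encoder_type: 'ESPNet_C' for the plain encoder or 'ESPNet_C_L1b' for the ConvLSTM-based encoder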
356 | ''' 357 | super().__init__() 358 | if encoder_type == 'ESPNet_C_L1b': 359 | self.encoder = ESPNet_C_L1b(lstm_filter_size, device, dtype, state_init, cell_type, batch_size, 360 | time_steps, overlap, val_img_size, lstm_activation_function, classes, p, q) 361 | elif encoder_type == 'ESPNet_C': 362 | self.encoder = ESPNet_Encoder(classes, p, q) 363 | else: 364 | assert False 365 | if encoderFile != None: 366 | self.encoder.load_state_dict(torch.load(encoderFile)) 367 | print('Encoder loaded!') 368 | # load the encoder modules 369 | self.modules = [] 370 | for i, m in enumerate(self.encoder.children()): 371 | self.modules.append(m) 372 | 373 | # light-weight decoder 374 | self.level3_C = C(128 + 3, classes, 1, 1) 375 | self.br = nn.BatchNorm2d(classes, eps=1e-03) 376 | self.conv = CBR(16 + classes, classes, 3, 1) 377 | 378 | self.up_l3 = nn.Sequential( 379 | nn.ConvTranspose2d(classes, classes, 2, stride=2, padding=0, output_padding=0, bias=False)) 380 | self.combine_l2_l3 = nn.Sequential(BR(2 * classes), 381 | DilatedParllelResidualBlockB(2 * classes, classes, add=False)) 382 | 383 | self.up_l2 = nn.Sequential( 384 | nn.ConvTranspose2d(classes, classes, 2, stride=2, padding=0, output_padding=0, bias=False), BR(classes)) 385 | 386 | self.classifier = nn.ConvTranspose2d(classes, classes, 2, stride=2, padding=0, output_padding=0, bias=False) 387 | 388 | def forward(self, input): 389 | ''' 390 | :param input: RGB image 391 | :return: transformed feature map 392 | ''' 393 | input = input.squeeze(0) # 5Ax -> 4Ax 394 | 395 | output0 = self.modules[0](input) # Conv-3_red 396 | inp1 = self.modules[1](input) # RGB_1, down-scaled by recursive avg-pooling 397 | inp2 = self.modules[2](input) # RGB_2, down-scaled by recursive 0avg-pooling 398 | 399 | output0_cat = self.modules[3](torch.cat([output0, inp1], 1)) # Concat_0 400 | output1_0 = self.modules[4](output0_cat) # down-sampled, ESP_red_0 401 | 402 | for i, layer in enumerate(self.modules[5]): # p times ESP_0 403 | if i == 0: 404 | output1 = layer(output1_0) 405 | else: 406 | output1 = layer(output1) 407 | 408 | output1_cat = self.modules[6](torch.cat([output1, output1_0, inp2], 1)) # Concat_1 409 | 410 | output2_0 = self.modules[7](output1_cat) # down-sampled, ESP_red_1 411 | for i, layer in enumerate(self.modules[8]): # q times ESP_1 412 | if i == 0: 413 | output2 = layer(output2_0) 414 | else: 415 | output2 = layer(output2) 416 | 417 | output2_cat = self.modules[9]( 418 | torch.cat([output2_0, output2], 1)) # concatenate for feature map width expansion, Concat_2 419 | 420 | output2_c = self.up_l3(self.br(self.modules[10](output2_cat))) # RUM, Conv-1_2 + DeConv_green_0 421 | 422 | output1_C = self.level3_C(output1_cat) # project to C-dimensional space, Conv-1_1 423 | comb_l2_l3 = self.up_l2( 424 | self.combine_l2_l3(torch.cat([output1_C, output2_c], 1))) # RUM, Concat_3 + ESP_2 + DeConv_green_1 425 | 426 | concat_features = self.conv(torch.cat([comb_l2_l3, output0], 1)) # Concat_4 + Conv-1 427 | 428 | classifier = self.classifier(concat_features) # DeConv_green_2 429 | 430 | classifier = F.softmax(classifier, dim=1) 431 | classifier = classifier.unsqueeze(0) # 4Ax -> 5Ax 432 | return classifier 433 | 434 | def logging(self, batch_total, tensorboard): 435 | # no biases 436 | tensorboard.add_histogram("Conv-3/weights", self.modules[0].conv.weight.data, batch_total) 437 | tensorboard.add_histogram("ESP_red_1/c1-weights", self.modules[7].c1.conv.weight.data, batch_total) 438 | tensorboard.add_histogram("ESP_red_1/d16-weights", 
self.modules[7].d16.conv.weight.data, batch_total) 439 | tensorboard.add_histogram("DeConv_green_1/weights", self.up_l2[0].weight.data, batch_total) 440 | 441 | tensorboard.add_scalar('Weights/Conv-3_abs_mean', self.modules[0].conv.weight.data.abs().mean(), batch_total) 442 | tensorboard.add_scalar('Weights/ESP_red_1-c1_abs_mean', self.modules[7].c1.conv.weight.data.abs().mean(), 443 | batch_total) 444 | tensorboard.add_scalar('Weights/ESP_red_1-d16_abs_mean', self.modules[7].d16.conv.weight.data.abs().mean(), 445 | batch_total) 446 | tensorboard.add_scalar('Weights/DeConv_green_1_abs_mean', self.up_l2[0].weight.data.abs().mean(), batch_total) 447 | 448 | if self.modules[0].conv.weight.requires_grad and self.modules[0].conv.weight.grad is not None: 449 | tensorboard.add_histogram('Conv-3/grad_hist', self.modules[0].conv.weight.grad.data, batch_total) 450 | tensorboard.add_scalar('Gradient/Conv-3_abs_mean', self.modules[0].conv.weight.grad.data.abs().mean(), 451 | batch_total) 452 | if self.modules[7].c1.conv.weight.requires_grad and self.modules[7].c1.conv.weight.grad is not None: 453 | tensorboard.add_histogram('ESP_red_1/c1-grad_hist', self.modules[7].c1.conv.weight.grad.data, batch_total) 454 | tensorboard.add_scalar('Gradient/CESP_red_1-c1_abs_mean', 455 | self.modules[7].c1.conv.weight.grad.data.abs().mean(), batch_total) 456 | if self.modules[7].d16.conv.weight.requires_grad and self.modules[7].d16.conv.weight.grad is not None: 457 | tensorboard.add_histogram('ESP_red_1/d16-grad_hist', self.modules[7].d16.conv.weight.grad.data, batch_total) 458 | tensorboard.add_scalar('Gradient/CESP_red_1-d16_abs_mean', 459 | self.modules[7].d16.conv.weight.grad.data.abs().mean(), batch_total) 460 | if self.up_l2[0].weight.requires_grad and self.up_l2[0].weight.grad is not None: 461 | tensorboard.add_histogram('DeConv_green_1/grad_hist', self.up_l2[0].weight.grad.data, batch_total) 462 | tensorboard.add_scalar('Gradient/DeConv_green_1_abs_mean', self.up_l2[0].weight.grad.data.abs().mean(), 463 | batch_total) 464 | 465 | 466 | class ESPNet_C_L1b(nn.Module): 467 | def __init__(self, lstm_filter_size, device, dtype, state_init, cell_type, batch_size, time_steps, overlap, 468 | val_img_size, lstm_activation_function, classes=19, p=2, q=3, init='default'): 469 | super().__init__() 470 | self.val_img_size = val_img_size 471 | self.state_scale_factor = 8 472 | self.batch_size = batch_size 473 | self.state_channels = 19 474 | 475 | self.level1 = CBR(3, 16, 3, 2) 476 | self.sample1 = InputProjectionA(1) 477 | self.sample2 = InputProjectionA(2) 478 | 479 | self.b1 = BR(16 + 3) 480 | self.level2_0 = DownSamplerB(16 + 3, 64) 481 | 482 | self.level2 = nn.ModuleList() 483 | for i in range(0, p): 484 | self.level2.append(DilatedParllelResidualBlockB(64, 64)) 485 | self.b2 = BR(128 + 3) 486 | 487 | self.level3_0 = DownSamplerB(128 + 3, 128) 488 | self.level3 = nn.ModuleList() 489 | for i in range(0, q): 490 | self.level3.append(DilatedParllelResidualBlockB(128, 128)) 491 | self.b3 = BR(256) 492 | 493 | # LSTM stuff 494 | self.clstm = clstm.ConvLSTM(256, 19, lstm_filter_size, lstm_activation_function, device, dtype, 495 | state_init, cell_type, batch_size, time_steps, overlap, 496 | state_img_size=[val_img_size[0] // 8, val_img_size[1] // 8]) 497 | self.is_batch_norm = False 498 | if lstm_activation_function == 'tanh': 499 | self.is_batch_norm = True 500 | if self.is_batch_norm: 501 | self.batch_norm = nn.BatchNorm2d(256) 502 | 503 | def forward(self, input, states): 504 | ''' 505 | :param input: Receives the input 
RGB image 506 | :return: the transformed feature map with spatial dimensions 1/8th of the input image 507 | ''' 508 | input = input.contiguous().view(-1, input.shape[2], input.shape[3], input.shape[4]) # merge time and bs dim 509 | 510 | output0 = self.level1(input) # 512x1024 --> 256x512 511 | inp1 = self.sample1(input) # scale down RGB 512 | inp2 = self.sample2(input) # scale down RGB 513 | 514 | output0_cat = self.b1(torch.cat([output0, inp1], 1)) # Concat_0 515 | output1_0 = self.level2_0(output0_cat) # down-sampled, ESP_red_0 516 | 517 | for i, layer in enumerate(self.level2): # p ESP-blocks, p..alpha_2 518 | if i == 0: 519 | output1 = layer(output1_0) 520 | else: 521 | output1 = layer(output1) 522 | 523 | output1_cat = self.b2(torch.cat([output1, output1_0, inp2], 1)) # Concat_1 524 | 525 | output2_0 = self.level3_0(output1_cat) # down-sampled, ESP_red_1 526 | for i, layer in enumerate(self.level3): # q ESP-blocks, q..alpha_3 527 | if i == 0: 528 | output2 = layer(output2_0) 529 | else: 530 | output2 = layer(output2) 531 | 532 | output2_cat = self.b3(torch.cat([output2_0, output2], 1)) # Concat_2 533 | 534 | # LSTM here 535 | if self.is_batch_norm: 536 | batch_norm_features = self.batch_norm(output2_cat) 537 | else: 538 | batch_norm_features = output2_cat 539 | lstm_in = batch_norm_features.view(-1, self.batch_size, batch_norm_features.shape[1], 540 | batch_norm_features.shape[2], 541 | batch_norm_features.shape[3]) # -1 ... time_steps, 1..bs 542 | lstm_out, new_states = self.clstm(lstm_in, states) 543 | 544 | classifier = F.softmax(lstm_out, dim=2) 545 | return classifier, new_states 546 | 547 | def logging(self, batch_total, tensorboard): 548 | # SingleFrame 549 | tensorboard.add_histogram("Conv-3/weights", self.level1.conv.weight.data, batch_total) 550 | tensorboard.add_histogram("ESP_red_1/c1-weights", self.level2_0.c1.conv.weight.data, batch_total) 551 | tensorboard.add_histogram("ESP_red_1/d16-weights", self.level2_0.d16.conv.weight.data, batch_total) 552 | 553 | tensorboard.add_scalar('Weights/Conv-3_abs_mean', self.level1.conv.weight.data.abs().mean(), batch_total) 554 | tensorboard.add_scalar('Weights/ESP_red_1-c1_abs_mean', self.level2_0.c1.conv.weight.data.abs().mean(), 555 | batch_total) 556 | tensorboard.add_scalar('Weights/ESP_red_1-d16_abs_mean', self.level2_0.d16.conv.weight.data.abs().mean(), 557 | batch_total) 558 | 559 | # ConvLSTM 560 | tensorboard.add_histogram("clstm/input/weights", self.clstm.cell.convolution.weight.data[ 561 | 0 * self.clstm.cell.hidden_channels: 1 * self.clstm.cell.hidden_channels], 562 | batch_total) 563 | tensorboard.add_histogram("clstm/input/bias", self.clstm.cell.convolution.bias.data[ 564 | 0 * self.clstm.cell.hidden_channels: 1 * self.clstm.cell.hidden_channels], 565 | batch_total) 566 | tensorboard.add_histogram("clstm/forget/weights", self.clstm.cell.convolution.weight.data[ 567 | 1 * self.clstm.cell.hidden_channels: 2 * self.clstm.cell.hidden_channels], 568 | batch_total) 569 | tensorboard.add_histogram("clstm/forget/bias", self.clstm.cell.convolution.bias.data[ 570 | 1 * self.clstm.cell.hidden_channels: 2 * self.clstm.cell.hidden_channels], 571 | batch_total) 572 | tensorboard.add_histogram("clstm/output/weights", self.clstm.cell.convolution.weight.data[ 573 | 2 * self.clstm.cell.hidden_channels: 3 * self.clstm.cell.hidden_channels], 574 | batch_total) 575 | tensorboard.add_histogram("clstm/output/bias", self.clstm.cell.convolution.bias.data[ 576 | 2 * self.clstm.cell.hidden_channels: 3 * self.clstm.cell.hidden_channels], 577 | 
batch_total) 578 | tensorboard.add_histogram("clstm/gate/weights", self.clstm.cell.convolution.weight.data[ 579 | 3 * self.clstm.cell.hidden_channels: 4 * self.clstm.cell.hidden_channels], 580 | batch_total) 581 | tensorboard.add_histogram("clstm/gate/bias", self.clstm.cell.convolution.bias.data[ 582 | 3 * self.clstm.cell.hidden_channels: 4 * self.clstm.cell.hidden_channels], 583 | batch_total) 584 | 585 | tensorboard.add_scalar('Weights/clstm_abs_mean', self.clstm.cell.convolution.weight.data.abs().mean(), 586 | batch_total) 587 | tensorboard.add_scalar('Weights/clstm/forget_abs_mean', self.clstm.cell.convolution.weight.data[ 588 | 1 * self.clstm.cell.hidden_channels: 2 * self.clstm.cell.hidden_channels].abs().mean(), 589 | batch_total) 590 | tensorboard.add_scalar('Bias/clstm/forget_abs_mean', self.clstm.cell.convolution.weight.data[ 591 | 1 * self.clstm.cell.hidden_channels: 2 * self.clstm.cell.hidden_channels].abs().mean(), 592 | batch_total) 593 | 594 | if self.clstm.cell.convolution.weight.requires_grad and self.clstm.cell.convolution.weight.grad is not None and False: 595 | tensorboard.add_histogram('clstm/grad_hist', self.clstm.cell.convolution.weight.grad.data, batch_total) 596 | tensorboard.add_scalar('Gradient/clstm/input_weight_abs_mean', self.clstm.cell.convolution.weight.grad.data[ 597 | 0 * self.clstm.cell.hidden_channels: 1 * self.clstm.cell.hidden_channels].abs().mean(), 598 | batch_total) 599 | tensorboard.add_scalar('Gradient/clstm/input_bias_abs_mean', self.clstm.cell.convolution.bias.grad.data[ 600 | 0 * self.clstm.cell.hidden_channels: 1 * self.clstm.cell.hidden_channels].abs().mean(), 601 | batch_total) 602 | tensorboard.add_scalar('Gradient/clstm/forget_weight_abs_mean', 603 | self.clstm.cell.convolution.weight.grad.data[ 604 | 1 * self.clstm.cell.hidden_channels: 2 * self.clstm.cell.hidden_channels].abs().mean(), 605 | batch_total) 606 | tensorboard.add_scalar('Gradient/clstm/forget_bias_abs_mean', self.clstm.cell.convolution.bias.grad.data[ 607 | 1 * self.clstm.cell.hidden_channels: 2 * self.clstm.cell.hidden_channels].abs().mean(), 608 | batch_total) 609 | tensorboard.add_scalar('Gradient/clstm/output_weight_abs_mean', 610 | self.clstm.cell.convolution.weight.grad.data[ 611 | 2 * self.clstm.cell.hidden_channels: 3 * self.clstm.cell.hidden_channels].abs().mean(), 612 | batch_total) 613 | tensorboard.add_scalar('Gradient/clstm/output_bias_abs_mean', self.clstm.cell.convolution.bias.grad.data[ 614 | 2 * self.clstm.cell.hidden_channels: 3 * self.clstm.cell.hidden_channels].abs().mean(), 615 | batch_total) 616 | tensorboard.add_scalar('Gradient/clstm/gate_weight_abs_mean', self.clstm.cell.convolution.weight.grad.data[ 617 | 3 * self.clstm.cell.hidden_channels: 4 * self.clstm.cell.hidden_channels].abs().mean(), 618 | batch_total) 619 | tensorboard.add_scalar('Gradient/clstm/gate_bias_abs_mean', self.clstm.cell.convolution.bias.grad.data[ 620 | 3 * self.clstm.cell.hidden_channels: 4 * self.clstm.cell.hidden_channels].abs().mean(), 621 | batch_total) 622 | tensorboard.add_scalar('Gradient/clstm_abs_mean', self.clstm.cell.convolution.weight.grad.data.abs().mean(), 623 | batch_total) 624 | 625 | if self.clstm.state_init == 'learn': 626 | tensorboard.add_histogram("clstm/c0", self.clstm.c0.data, batch_total) 627 | tensorboard.add_histogram("clstm/h0", self.clstm.h0.data, batch_total) 628 | 629 | if self.is_batch_norm: 630 | tensorboard.add_histogram('clstm/batch_norm_before', self.batch_norm.weight.data, batch_total) 631 | 632 | def update_parameters(self, batch_size, 
time_steps, overlap): 633 | self.time_steps = time_steps 634 | self.batch_size = batch_size 635 | self.clstm.update_parameters(batch_size, time_steps, overlap) 636 | 637 | 638 | class ESPNet_L1b(ESPNet): 639 | 640 | def __init__(self, init, lstm_activation_function, img_size, device, dtype, 641 | is_stateful, state_init, cell_type, batch_size, time_steps, 642 | lstm_filter_size, overlap, classes=19, p=2, q=3, encoderFile=None): 643 | super().__init__(lstm_filter_size, device, dtype, state_init, cell_type, batch_size, time_steps, overlap, 644 | img_size, lstm_activation_function, classes, p, q, 'ESPNet_C_L1b', encoderFile) 645 | 646 | def forward(self, input, states): 647 | # Dimensions: Time, BatchSize, Channels, Height, Width 648 | input = input.view(-1, input.shape[2], input.shape[3], input.shape[4]) # merge time and bs dim 649 | 650 | output0 = self.modules[0](input) # Conv-3_red 651 | inp1 = self.modules[1](input) # RGB_1, down-scaled by recursive avg-pooling 652 | inp2 = self.modules[2](input) # RGB_2, down-scaled by recursive 0avg-pooling 653 | output0_cat = self.modules[3](torch.cat([output0, inp1], 1)) # Concat_0 654 | output1_0 = self.modules[4](output0_cat) # down-sampled, ESP_red_0 655 | 656 | for i, layer in enumerate(self.modules[5]): # p times ESP_0 657 | if i == 0: 658 | output1 = layer(output1_0) 659 | else: 660 | output1 = layer(output1) 661 | 662 | output1_cat = self.modules[6](torch.cat([output1, output1_0, inp2], 1)) # Concat_1 663 | 664 | output2_0 = self.modules[7](output1_cat) # down-sampled, ESP_red_1 665 | for i, layer in enumerate(self.modules[8]): # q times ESP_1 666 | if i == 0: 667 | output2 = layer(output2_0) 668 | else: 669 | output2 = layer(output2) 670 | 671 | output2_cat = self.modules[9]( 672 | torch.cat([output2_0, output2], 1)) # concatenate for feature map width expansion, Concat_2 673 | 674 | if self.encoder.is_batch_norm: 675 | batch_norm_features = self.modules[10](output2_cat) 676 | lstm_in = batch_norm_features.view(-1, self.encoder.batch_size, batch_norm_features.shape[1], 677 | batch_norm_features.shape[2], 678 | batch_norm_features.shape[3]) # -1 ... time_steps, 1..bs 679 | lstm_out, new_states = self.modules[11](lstm_in, states) 680 | lstm_out = lstm_out.view(-1, lstm_out.shape[2], lstm_out.shape[3], lstm_out.shape[4]) 681 | else: 682 | batch_norm_features = output2_cat 683 | lstm_in = batch_norm_features.view(-1, self.encoder.batch_size, batch_norm_features.shape[1], 684 | batch_norm_features.shape[2], 685 | batch_norm_features.shape[3]) # -1 ... time_steps, 1..bs 686 | lstm_out, new_states = self.modules[10](lstm_in, states) 687 | lstm_out = lstm_out.view(-1, lstm_out.shape[2], lstm_out.shape[3], lstm_out.shape[4]) 688 | 689 | output2_c = self.up_l3(self.br(lstm_out)) # RUM, Conv-1_2 + DeConv_green_0 690 | 691 | output1_C = self.level3_C(output1_cat) # project to C-dimensional space, Conv-1_1 692 | comb_l2_l3 = self.up_l2( 693 | self.combine_l2_l3(torch.cat([output1_C, output2_c], 1))) # RUM, Concat_3 + ESP_2 + DeConv_green_1 694 | 695 | concat_features = self.conv(torch.cat([comb_l2_l3, output0], 1)) # Concat_4 + Conv-1 696 | 697 | classifier = self.classifier(concat_features) # DeConv_green_2 698 | classifier = classifier.view(-1, self.encoder.batch_size, classifier.shape[1], classifier.shape[2], 699 | classifier.shape[3]) # -1 ... 
time_steps, 1..bs 700 | return F.softmax(classifier, dim=2), new_states 701 | 702 | def logging(self, batch_total, tensorboard): 703 | super().logging(batch_total, tensorboard) 704 | # ConvLSTM 705 | tensorboard.add_histogram("encoder.clstm/input/weights", self.encoder.clstm.cell.convolution.weight.data[ 706 | 0 * self.encoder.clstm.cell.hidden_channels: 1 * self.encoder.clstm.cell.hidden_channels], 707 | batch_total) 708 | tensorboard.add_histogram("encoder.clstm/input/bias", self.encoder.clstm.cell.convolution.bias.data[ 709 | 0 * self.encoder.clstm.cell.hidden_channels: 1 * self.encoder.clstm.cell.hidden_channels], 710 | batch_total) 711 | tensorboard.add_histogram("encoder.clstm/forget/weights", self.encoder.clstm.cell.convolution.weight.data[ 712 | 1 * self.encoder.clstm.cell.hidden_channels: 2 * self.encoder.clstm.cell.hidden_channels], 713 | batch_total) 714 | tensorboard.add_histogram("encoder.clstm/forget/bias", self.encoder.clstm.cell.convolution.bias.data[ 715 | 1 * self.encoder.clstm.cell.hidden_channels: 2 * self.encoder.clstm.cell.hidden_channels], 716 | batch_total) 717 | tensorboard.add_histogram("encoder.clstm/output/weights", self.encoder.clstm.cell.convolution.weight.data[ 718 | 2 * self.encoder.clstm.cell.hidden_channels: 3 * self.encoder.clstm.cell.hidden_channels], 719 | batch_total) 720 | tensorboard.add_histogram("encoder.clstm/output/bias", self.encoder.clstm.cell.convolution.bias.data[ 721 | 2 * self.encoder.clstm.cell.hidden_channels: 3 * self.encoder.clstm.cell.hidden_channels], 722 | batch_total) 723 | tensorboard.add_histogram("encoder.clstm/gate/weights", self.encoder.clstm.cell.convolution.weight.data[ 724 | 3 * self.encoder.clstm.cell.hidden_channels: 4 * self.encoder.clstm.cell.hidden_channels], 725 | batch_total) 726 | tensorboard.add_histogram("encoder.clstm/gate/bias", self.encoder.clstm.cell.convolution.bias.data[ 727 | 3 * self.encoder.clstm.cell.hidden_channels: 4 * self.encoder.clstm.cell.hidden_channels], 728 | batch_total) 729 | 730 | tensorboard.add_scalar('Weights/encoder.clstm_abs_mean', 731 | self.encoder.clstm.cell.convolution.weight.data.abs().mean(), batch_total) 732 | tensorboard.add_scalar('Weights/encoder.clstm/forget_abs_mean', self.encoder.clstm.cell.convolution.weight.data[ 733 | 1 * self.encoder.clstm.cell.hidden_channels: 2 * self.encoder.clstm.cell.hidden_channels].abs().mean(), 734 | batch_total) 735 | tensorboard.add_scalar('Bias/encoder.clstm/forget_abs_mean', self.encoder.clstm.cell.convolution.weight.data[ 736 | 1 * self.encoder.clstm.cell.hidden_channels: 2 * self.encoder.clstm.cell.hidden_channels].abs().mean(), 737 | batch_total) 738 | 739 | if self.encoder.clstm.cell.convolution.weight.requires_grad and self.encoder.clstm.cell.convolution.weight.grad is not None: 740 | tensorboard.add_histogram('encoder.clstm/grad_hist', self.encoder.clstm.cell.convolution.weight.grad.data, 741 | batch_total) 742 | tensorboard.add_scalar('Gradient/encoder.clstm/input_weight_abs_mean', 743 | self.encoder.clstm.cell.convolution.weight.grad.data[ 744 | 0 * self.encoder.clstm.cell.hidden_channels: 1 * self.encoder.clstm.cell.hidden_channels].abs().mean(), 745 | batch_total) 746 | tensorboard.add_scalar('Gradient/encoder.clstm/input_bias_abs_mean', 747 | self.encoder.clstm.cell.convolution.bias.grad.data[ 748 | 0 * self.encoder.clstm.cell.hidden_channels: 1 * self.encoder.clstm.cell.hidden_channels].abs().mean(), 749 | batch_total) 750 | tensorboard.add_scalar('Gradient/encoder.clstm/forget_weight_abs_mean', 751 | 
self.encoder.clstm.cell.convolution.weight.grad.data[ 752 | 1 * self.encoder.clstm.cell.hidden_channels: 2 * self.encoder.clstm.cell.hidden_channels].abs().mean(), 753 | batch_total) 754 | tensorboard.add_scalar('Gradient/encoder.clstm/forget_bias_abs_mean', 755 | self.encoder.clstm.cell.convolution.bias.grad.data[ 756 | 1 * self.encoder.clstm.cell.hidden_channels: 2 * self.encoder.clstm.cell.hidden_channels].abs().mean(), 757 | batch_total) 758 | tensorboard.add_scalar('Gradient/encoder.clstm/output_weight_abs_mean', 759 | self.encoder.clstm.cell.convolution.weight.grad.data[ 760 | 2 * self.encoder.clstm.cell.hidden_channels: 3 * self.encoder.clstm.cell.hidden_channels].abs().mean(), 761 | batch_total) 762 | tensorboard.add_scalar('Gradient/encoder.clstm/output_bias_abs_mean', 763 | self.encoder.clstm.cell.convolution.bias.grad.data[ 764 | 2 * self.encoder.clstm.cell.hidden_channels: 3 * self.encoder.clstm.cell.hidden_channels].abs().mean(), 765 | batch_total) 766 | tensorboard.add_scalar('Gradient/encoder.clstm/gate_weight_abs_mean', 767 | self.encoder.clstm.cell.convolution.weight.grad.data[ 768 | 3 * self.encoder.clstm.cell.hidden_channels: 4 * self.encoder.clstm.cell.hidden_channels].abs().mean(), 769 | batch_total) 770 | tensorboard.add_scalar('Gradient/encoder.clstm/gate_bias_abs_mean', 771 | self.encoder.clstm.cell.convolution.bias.grad.data[ 772 | 3 * self.encoder.clstm.cell.hidden_channels: 4 * self.encoder.clstm.cell.hidden_channels].abs().mean(), 773 | batch_total) 774 | tensorboard.add_scalar('Gradient/encoder.clstm_abs_mean', 775 | self.encoder.clstm.cell.convolution.weight.grad.data.abs().mean(), batch_total) 776 | 777 | if self.encoder.clstm.state_init == 'learn': 778 | tensorboard.add_histogram("encoder.clstm/c0", self.encoder.clstm.c0.data, batch_total) 779 | tensorboard.add_histogram("encoder.clstm/h0", self.encoder.clstm.h0.data, batch_total) 780 | 781 | def update_parameters(self, batch_size, time_steps, overlap): 782 | self.encoder.update_parameters(batch_size, time_steps, overlap) 783 | -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | # This file is part of f2fcss. 2 | # 3 | # Copyright (C) 2020 Manuel Rebol 4 | # Patrick Knoebelreiter 5 | # Institute for Computer Graphics and Vision, Graz University of Technology 6 | # https://www.tugraz.at/institute/icg/teams/team-pock/ 7 | # 8 | # f2fcss is free software: you can redistribute it and/or modify it under the 9 | # terms of the GNU Affero General Public License as published by the Free Software 10 | # Foundation, either version 3 of the License, or any later version. 11 | # 12 | # f2fcss is distributed in the hope that it will be useful, but WITHOUT ANY 13 | # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS 14 | # FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU Affero General Public License 17 | # along with this program. If not, see . 
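#
# The functions below define the training objective: video_loss combines a
# weighted per-pixel cross-entropy term with a weighted temporal inconsistency
# term. inconsistency_loss penalizes prediction differences between frame t and
# frame t+1 at pixels where the ground truth does not change (regions of
# ground-truth change are dilated and excluded), and it is only evaluated when
# more than one frame is passed in.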
18 | # 19 | import torch 20 | import numpy as np 21 | import torch.nn as nn 22 | import scipy.ndimage as sp_img 23 | 24 | 25 | def video_loss(output, target, cross_entropy_lambda, consistency_lambda, consistency_function, ignore_class): 26 | # output: Time, BatchSize, Channels, Height, Width 27 | # labels: Time, BatchSize, Height, Width 28 | valid_mask = (target != ignore_class) 29 | target_select = target.clone() 30 | target_select[target_select == ignore_class] = 0 31 | target_select = target_select[:, :, None, :, :].long() 32 | 33 | loss_cross_entropy = torch.tensor([0.0], dtype=torch.float32, device=output.device) 34 | if cross_entropy_lambda > 0: 35 | loss_cross_entropy = cross_entropy_lambda * cross_entropy_loss(output, target_select, valid_mask) 36 | 37 | loss_inconsistency = torch.tensor([0.0], dtype=torch.float32, device=output.device) 38 | if consistency_lambda > 0 and output.shape[0] > 1: 39 | loss_inconsistency = consistency_lambda * inconsistency_loss(output, target, consistency_function, valid_mask, 40 | target_select) 41 | 42 | return loss_cross_entropy, loss_inconsistency 43 | 44 | 45 | def cross_entropy_loss(output, target_select, valid_mask): 46 | pixel_loss = torch.gather(output, dim=2, index=target_select).squeeze(dim=2) 47 | pixel_loss = - torch.log(pixel_loss.clamp(min=1e-10)) # clamp: values smaller than 1e-10 become 1e-10 48 | pixel_loss = pixel_loss * valid_mask.to(dtype=torch.float32) # without ignore pixels 49 | total_loss = pixel_loss.sum() 50 | return total_loss / valid_mask.sum().to(dtype=torch.float32) # normalize 51 | 52 | 53 | def inconsistency_loss(output, target, consistency_function, valid_mask, target_select): 54 | pred = torch.argmax(output, dim=2).to(dtype=target.dtype) 55 | valid_mask_sum = torch.tensor([0.0], dtype=torch.float32, device=output.device) 56 | inconsistencies_sum = torch.tensor([0.0], dtype=torch.float32, device=output.device) 57 | 58 | for t in range(output.shape[0] - 1): 59 | gt1 = target[t] 60 | gt2 = target[t + 1] 61 | valid_mask2 = valid_mask[t] & valid_mask[t + 1] # valid mask always has to be calculated over 2 imgs 62 | 63 | if consistency_function == 'argmax_pred': 64 | pred1 = pred[t] 65 | pred2 = pred[t + 1] 66 | diff_pred_valid = ((pred1 != pred2) & valid_mask2).to(output.dtype) 67 | elif consistency_function == 'abs_diff': 68 | diff_pred_valid = (torch.abs(output[t] - output[t + 1])).sum(dim=1) * valid_mask2.to(output.dtype) 69 | elif consistency_function == 'sq_diff': 70 | diff_pred_valid = (torch.pow(output[t] - output[t + 1], 2)).sum(dim=1) * valid_mask2.to(output.dtype) 71 | elif consistency_function == 'abs_diff_true': 72 | pred1 = pred[t] 73 | pred2 = pred[t + 1] 74 | right_pred_mask = (pred1 == gt1) | (pred2 == gt2) 75 | diff_pred = torch.abs(output[t] - output[t + 1]) 76 | diff_pred_true = torch.gather(diff_pred, dim=1, index=target_select[t]).squeeze(dim=1) 77 | diff_pred_valid = diff_pred_true * (valid_mask2 & right_pred_mask).to(dtype=output.dtype) 78 | elif consistency_function == 'sq_diff_true': 79 | pred1 = pred[t] 80 | pred2 = pred[t + 1] 81 | right_pred_mask = (pred1 == gt1) | (pred2 == gt2) 82 | diff_pred = torch.pow(output[t] - output[t + 1], 2) 83 | diff_pred_true = torch.gather(diff_pred, dim=1, index=target_select[t]).squeeze(dim=1) 84 | diff_pred_valid = diff_pred_true * (valid_mask2 & right_pred_mask).to(dtype=output.dtype) 85 | elif consistency_function == 'sq_diff_true_XOR': 86 | pred1 = pred[t] 87 | pred2 = pred[t + 1] 88 | right_pred_mask = (pred1 == gt1) ^ (pred2 == gt2) 89 | diff_pred = 
torch.pow(output[t] - output[t + 1], 2) 90 | diff_pred_true = torch.gather(diff_pred, dim=1, index=target_select[t]).squeeze(dim=1) 91 | diff_pred_valid = diff_pred_true * (valid_mask2 & right_pred_mask).to(dtype=output.dtype) 92 | elif consistency_function == 'abs_diff_th20': 93 | th_mask = (output[t] > 0.2) & (output[t + 1] > 0.2) 94 | diff_pred_valid = (torch.abs((output[t] - output[t + 1]) * th_mask.to(dtype=output.dtype))).sum( 95 | dim=1) * valid_mask2.to(output.dtype) 96 | 97 | diff_gt_valid = ((gt1 != gt2) & valid_mask2) # torch.uint8 98 | diff_gt_valid_dil = sp_img.binary_dilation(diff_gt_valid.cpu().numpy(), 99 | iterations=2) # default: 4-neighbourhood 100 | inconsistencies = diff_pred_valid * torch.from_numpy(np.logical_not(diff_gt_valid_dil).astype(np.uint8)).to( 101 | output.device, dtype=output.dtype) 102 | valid_mask_sum += valid_mask2.sum() 103 | inconsistencies_sum += inconsistencies.sum() 104 | 105 | return inconsistencies_sum / valid_mask_sum 106 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | certifi==2020.6.20 2 | cycler==0.10.0 3 | kiwisolver==1.2.0 4 | matplotlib==3.3.1 5 | numpy==1.17.4 6 | Pillow==7.2.0 7 | protobuf==3.13.0 8 | pyparsing==2.4.7 9 | python-dateutil==2.8.1 10 | PyYAML==5.2 11 | six==1.15.0 12 | tensorboardX==1.9 13 | torch==1.3.1 14 | -------------------------------------------------------------------------------- /resources/esp_vs_our_model.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrebol/f2f-consistent-semantic-segmentation/f7e362c5cf0d3461435e8af4f334cca974e521dc/resources/esp_vs_our_model.gif -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | # This file is part of f2fcss. 2 | # 3 | # Copyright (C) 2020 Manuel Rebol 4 | # Patrick Knoebelreiter 5 | # Institute for Computer Graphics and Vision, Graz University of Technology 6 | # https://www.tugraz.at/institute/icg/teams/team-pock/ 7 | # 8 | # f2fcss is free software: you can redistribute it and/or modify it under the 9 | # terms of the GNU Affero General Public License as published by the Free Software 10 | # Foundation, either version 3 of the License, or any later version. 11 | # 12 | # f2fcss is distributed in the hope that it will be useful, but WITHOUT ANY 13 | # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS 14 | # FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU Affero General Public License 17 | # along with this program. If not, see . 
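#
# Miscellaneous helpers: ANSI terminal color codes (BColors), recursive file
# listing, matplotlib-based visualization of predictions next to the ground
# truth, directory creation, and tensorboardX SummaryWriter setup.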
18 | # 19 | from os import path, makedirs, walk 20 | import matplotlib.pyplot as plt 21 | from tensorboardX import SummaryWriter 22 | import shutil 23 | 24 | 25 | class BColors: 26 | HEADER = '\033[95m' 27 | OKBLUE = '\033[94m' 28 | OKGREEN = '\033[92m' 29 | WARNING = '\033[93m' 30 | FAIL = '\033[91m' 31 | ENDC = '\033[0m' 32 | BOLD = '\033[1m' 33 | UNDERLINE = '\033[4m' 34 | 35 | BLACK = '\u001b[30m' 36 | RED = '\u001b[31m' 37 | GREEN = '\u001b[32m' 38 | YELLOW = '\u001b[33m' 39 | BLUE = '\u001b[34m' 40 | MAGENTA = '\u001b[35m' 41 | CYAN = '\u001b[36m' 42 | WHITE = '\u001b[37m' 43 | RESET = '\u001b[0m' 44 | 45 | Bright_Black = '\u001b[30;1m' 46 | Bright_Red = '\u001b[31;1m' 47 | Bright_Green = '\u001b[32;1m' 48 | Bright_Yellow = '\u001b[33;1m' 49 | Bright_Blue = '\u001b[34;1m' 50 | Bright_Magenta = '\u001b[35;1m' 51 | Bright_Cyan = '\u001b[36;1m' 52 | Bright_White = '\u001b[37;1m' 53 | 54 | 55 | def recursive_glob(rootdir=".", suffix=""): 56 | return [ 57 | path.join(looproot, filename) 58 | for looproot, _, filenames in walk(rootdir) 59 | for filename in filenames 60 | if filename.endswith(suffix) 61 | ] 62 | 63 | 64 | def filenames_in_dir(dir=".", suffix=""): 65 | return [ 66 | filename 67 | for _, _, filenames in walk(dir) 68 | for filename in filenames 69 | if filename.endswith(suffix) 70 | ] 71 | 72 | 73 | def show_val(pred, gt, bs, loader): 74 | f, axarr = plt.subplots(bs, 2) 75 | if bs > 1: 76 | for j in range(bs): 77 | axarr[j][0].imshow(loader.segmap_to_color(pred[j])) 78 | axarr[j][1].imshow(loader.segmap_to_color(gt[j])) 79 | else: 80 | axarr[0].imshow(loader.segmap_to_color(pred[0])) 81 | axarr[1].imshow(loader.segmap_to_color(gt[0])) 82 | plt.show() 83 | 84 | 85 | def save_val(imgs, pred, gt, bs, loader, batch_total_nr, valid_img_nr, experiment_path, tensorboard): 86 | f, axarr = plt.subplots(bs, 3) 87 | if bs > 1: 88 | for j in range(bs): 89 | axarr[j][0].imshow(loader.segmap_to_color(pred[j])) 90 | axarr[j][1].imshow(loader.segmap_to_color(gt[j])) 91 | axarr[j][2].imshow(imgs[0].permute(1, 2, 0)) 92 | else: 93 | axarr[0].imshow(loader.segmap_to_color(pred[0])) 94 | axarr[1].imshow(loader.segmap_to_color(gt[0])) 95 | axarr[2].imshow(imgs[0].permute(1, 2, 0)) 96 | plt.savefig(path.join(experiment_path, "val", 97 | "batch_nr_{:06}_val_img_{:02}.png".format(batch_total_nr, valid_img_nr))) 98 | plt.close() 99 | 100 | 101 | def mkdir_p(mypath): 102 | makedirs(mypath, exist_ok=True) 103 | 104 | 105 | 106 | def load_tensorboard(dir, batch_total): 107 | if path.isdir(path.join(dir, "tensorboard", "bn-{:06}".format(batch_total))): 108 | ex = "Tensorboard folder {} already exists!".format( 109 | path.join(dir, "tensorboard", "bn-{:06}".format(batch_total))) 110 | print(ex) 111 | response = input("Remove existing? (y/n): ") 112 | if response == 'y': 113 | shutil.rmtree(path.join(dir, "tensorboard", "bn-{:06}".format(batch_total))) 114 | mkdir_p(path.join(dir, "tensorboard", "bn-{:06}".format(batch_total))) 115 | else: 116 | raise Exception(ex) # remove manually 117 | return SummaryWriter(path.join(dir, "tensorboard", "bn-{:06}".format(batch_total))) 118 | --------------------------------------------------------------------------------
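The loss in `model/loss.py` can be exercised on its own. The snippet below is a minimal sketch, not part of the repository: the import assumes the repository root is on `PYTHONPATH`, and the tensor shapes, the ignore index of 255 and the lambda weights are arbitrary illustration values (note that `model/loss.py` also imports `scipy`, which is not pinned in `requirements.txt`).

    import torch
    from model.loss import video_loss  # assumes the repository root is on PYTHONPATH

    # Dummy prediction: Time=2, BatchSize=1, Classes=19, Height=64, Width=128,
    # softmax-normalized over the class dimension like the network output.
    output = torch.softmax(torch.randn(2, 1, 19, 64, 128), dim=2)

    # Dummy labels: Time, BatchSize, Height, Width, with an assumed ignore index of 255.
    target = torch.randint(0, 19, (2, 1, 64, 128))
    target[0, 0, :4, :4] = 255

    ce, inconsistency = video_loss(output, target,
                                   cross_entropy_lambda=1.0,
                                   consistency_lambda=0.1,
                                   consistency_function='argmax_pred',
                                   ignore_class=255)
    print(ce.item(), inconsistency.item())

With `consistency_function='argmax_pred'` the inconsistency term counts label flips between consecutive frames; the other options ('abs_diff', 'sq_diff', ...) compare the class probabilities directly.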