├── README.md ├── args.py ├── data ├── README.md ├── __init__.py ├── camvid.py └── utils.py ├── functional.py ├── main.py ├── models ├── SparseConvNet.py └── enet.py ├── test.py ├── train.py ├── transforms.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Sparsity-Invariant-CNNs-pytorch 2 | Reproduced code for the paper “Sparsity Invariant CNNs” 3 | -------------------------------------------------------------------------------- /args.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | 4 | def get_arguments(): 5 | """Defines command-line arguments, and parses them. 6 | 7 | """ 8 | parser = ArgumentParser() 9 | 10 | # Execution mode 11 | parser.add_argument( 12 | "--mode", 13 | "-m", 14 | choices=['train', 'test', 'full'], 15 | default='train', 16 | help=("train: performs training and validation; test: tests the model " 17 | "found in \"--save_dir\" with name \"--name\" on \"--dataset\"; " 18 | "full: combines train and test modes. Default: train")) 19 | parser.add_argument( 20 | "--resume", 21 | action='store_true', 22 | help=("The model found in \"--save_dir\" with name " 23 | "\"--name\" is loaded.")) 24 | 25 | # Hyperparameters 26 | parser.add_argument( 27 | "--batch-size", 28 | "-b", 29 | type=int, 30 | default=1, 31 | help="The batch size. Default: 1") 32 | parser.add_argument( 33 | "--epochs", 34 | type=int, 35 | default=300, 36 | help="Number of training epochs. Default: 300") 37 | parser.add_argument( 38 | "--learning-rate", 39 | "-lr", 40 | type=float, 41 | default=1e-4, 42 | help="The learning rate. Default: 1e-4") 43 | parser.add_argument( 44 | "--lr-decay", 45 | type=float, 46 | default=0.1, 47 | help="The learning rate decay factor. Default: 0.1") 48 | parser.add_argument( 49 | "--lr-decay-epochs", 50 | type=int, 51 | default=100, 52 | help="The number of epochs before adjusting the learning rate. " 53 | "Default: 100") 54 | parser.add_argument( 55 | "--weight-decay", 56 | "-wd", 57 | type=float, 58 | default=2e-4, 59 | help="L2 regularization factor. Default: 2e-4") 60 | 61 | # Dataset 62 | parser.add_argument( 63 | "--dataset", 64 | choices=['camvid', 'cityscapes'], 65 | default='camvid', 66 | help="Dataset to use. Default: camvid") 67 | parser.add_argument( 68 | "--dataset-dir", 69 | type=str, 70 | default="/media/usr515/26C0245EC0243709/cxy/1909/data/srdata", 71 | help="Path to the root directory of the selected dataset. " 72 | "Default: the local path set above") 73 | 74 | 75 | # Settings 76 | parser.add_argument( 77 | "--workers", 78 | type=int, 79 | default=0, 80 | help="Number of subprocesses to use for data loading. Default: 0") 81 | parser.add_argument( 82 | "--print-step", 83 | action='store_true', 84 | help="Print loss every step") 85 | parser.add_argument( 86 | "--imshow-batch", 87 | action='store_true', 88 | help=("Displays batch images when loading the dataset and making " 89 | "predictions.")) 90 | parser.add_argument( 91 | "--device", 92 | default='cuda', 93 | help="Device on which the network will be trained. Default: cuda") 94 | 95 | # Storage settings 96 | parser.add_argument( 97 | "--name", 98 | type=str, 99 | default='ENet', 100 | help="Name given to the model when saving. Default: ENet") 101 | parser.add_argument( 102 | "--save-dir", 103 | type=str, 104 | default='save', 105 | help="The directory where models are saved. 
Default: save") 106 | 107 | return parser.parse_args() 108 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Supported datasets 2 | 3 | - CamVid 4 | - CityScapes 5 | 6 | Note: When referring to the number of classes, the void/unlabeled class is excluded. 7 | 8 | ## CamVid Dataset 9 | 10 | The Cambridge-driving Labeled Video Database (CamVid) is a collection of over ten minutes of high-quality 30Hz footage with object class semantic labels at 1Hz and in part, 15Hz. Each pixel is associated with one of 32 classes. 11 | 12 | The CamVid dataset supported here is a 12 class version developed by the authors of SegNet. [Download link here](https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid). For actual training, an 11 class version is used - the "road marking" class is combined with the "road" class. 13 | 14 | More detailed information about the CamVid dataset can be found [here](http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/) and on the [SegNet GitHub repository](https://github.com/alexgkendall/SegNet-Tutorial). 15 | 16 | ## Cityscapes 17 | 18 | Cityscapes is a set of stereo video sequences recorded in streets from 50 different cities with 34 different classes. There are 5000 images with fine annotations and 20000 images coarsely annotated. 19 | 20 | The version supported here is the finely annotated one with 19 classes. 21 | 22 | For more detailed information see the official [website](https://www.cityscapes-dataset.com/) and [repository](https://github.com/mcordts/cityscapesScripts). 23 | 24 | The dataset can be downloaded from https://www.cityscapes-dataset.com/downloads/. At this time, a registration is required to download the data. 25 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .camvid import CamVid 2 | 3 | __all__ = ['CamVid'] 4 | -------------------------------------------------------------------------------- /data/camvid.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.utils.data as data 3 | from . 
import utils 4 | 5 | class CamVid(data.Dataset): 6 | 7 | def __init__(self, 8 | root_dir, 9 | mode='train', 10 | transform=None, 11 | loader=utils.h5_loader): 12 | self.root_dir = root_dir 13 | self.mode = mode 14 | self.transform = transform 15 | self.loader = loader 16 | self._file_names = self._get_file_names(self.mode.lower()) 17 | 18 | 19 | def __getitem__(self, index): 20 | 21 | names = self._file_names[index] 22 | 23 | data_path = os.path.join(self.root_dir, names[0]) 24 | label_path = os.path.join(self.root_dir, names[1]) 25 | 26 | img, label = self.loader(data_path, label_path) 27 | 28 | if self.mode.lower() == 'train': 29 | img, label = self.transform(img, label) 30 | else: 31 | img = self.transform(img) 32 | label = self.transform(label) 33 | 34 | 35 | return img, label 36 | 37 | def _get_file_names(self, split_name): 38 | assert split_name in ['train', 'val', 'test'] 39 | split = split_name + '.txt' 40 | source = os.path.join(self.root_dir, split) 41 | 42 | file_names = [] 43 | with open(source) as f: 44 | files = f.readlines() 45 | 46 | for item in files: 47 | img_name, gt_name = self._process_item_names(item) 48 | file_names.append([img_name, gt_name]) 49 | 50 | return file_names 51 | 52 | def _process_item_names(self, item): 53 | item = item.strip() 54 | item = item.split('\t') 55 | img_name = item[0] 56 | gt_name = item[1] 57 | 58 | return img_name, gt_name 59 | 60 | def __len__(self): 61 | """Returns the length of the dataset.""" 62 | return len(self._file_names) 63 | -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import numpy as np 4 | import h5py 5 | 6 | def get_files(folder, name_filter=None, extension_filter=None): 7 | """Helper function that returns the list of files in a specified folder 8 | with a specified extension. 9 | 10 | Keyword arguments: 11 | - folder (``string``): The path to a folder. 12 | - name_filter (```string``, optional): The returned files must contain 13 | this substring in their filename. Default: None; files are not filtered. 14 | - extension_filter (``string``, optional): The desired file extension. 
15 | Default: None; files are not filtered 16 | 17 | """ 18 | if not os.path.isdir(folder): 19 | raise RuntimeError("\"{0}\" is not a folder.".format(folder)) 20 | 21 | # Filename filter: if not specified don't filter (condition always true); 22 | # otherwise, use a lambda expression to filter out files that do not 23 | # contain "name_filter" 24 | if name_filter is None: 25 | # This looks hackish...there is probably a better way 26 | name_cond = lambda filename: True 27 | else: 28 | name_cond = lambda filename: name_filter in filename 29 | 30 | # Extension filter: if not specified don't filter (condition always true); 31 | # otherwise, use a lambda expression to filter out files whose extension 32 | # is not "extension_filter" 33 | if extension_filter is None: 34 | # This looks hackish...there is probably a better way 35 | ext_cond = lambda filename: True 36 | else: 37 | ext_cond = lambda filename: filename.endswith(extension_filter) 38 | 39 | filtered_files = [] 40 | 41 | # Explore the directory tree to get files that contain "name_filter" and 42 | # with extension "extension_filter" 43 | for path, _, files in os.walk(folder): 44 | files.sort() 45 | for file in files: 46 | if name_cond(file) and ext_cond(file): 47 | full_path = os.path.join(path, file) 48 | filtered_files.append(full_path) 49 | 50 | return filtered_files 51 | 52 | 53 | def h5_loader(data_path, label_path): 54 | 55 | h = h5py.File(data_path,'r') 56 | data = np.array(h['data']).transpose(1,0) 57 | 58 | h = h5py.File(label_path,'r') 59 | label = np.array(h['data']).transpose(1,0) 60 | 61 | data=Image.fromarray(data) 62 | label=Image.fromarray(label) 63 | 64 | return data, label 65 | 66 | -------------------------------------------------------------------------------- /functional.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import math 4 | import random 5 | from PIL import Image, ImageOps, ImageEnhance, PILLOW_VERSION 6 | try: 7 | import accimage 8 | except ImportError: 9 | accimage = None 10 | import numpy as np 11 | import numbers 12 | import types 13 | import collections 14 | import warnings 15 | 16 | 17 | def _is_pil_image(img): 18 | if accimage is not None: 19 | return isinstance(img, (Image.Image, accimage.Image)) 20 | else: 21 | return isinstance(img, Image.Image) 22 | 23 | 24 | def _is_tensor_image(img): 25 | return torch.is_tensor(img) and img.ndimension() == 3 26 | 27 | 28 | def _is_numpy_image(img): 29 | return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) 30 | 31 | 32 | def to_tensor(pic): 33 | """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. 34 | 35 | See ``ToTensor`` for more details. 36 | 37 | Args: 38 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 39 | 40 | Returns: 41 | Tensor: Converted image. 42 | """ 43 | if not(_is_pil_image(pic) or _is_numpy_image(pic)): 44 | raise TypeError('pic should be PIL Image or ndarray. 
Got {}'.format(type(pic))) 45 | 46 | if isinstance(pic, np.ndarray): 47 | # handle numpy array 48 | img = torch.from_numpy(pic.transpose((2, 0, 1))) 49 | # backward compatibility 50 | if isinstance(img, torch.ByteTensor): 51 | return img.float().div(255) 52 | else: 53 | return img 54 | 55 | if accimage is not None and isinstance(pic, accimage.Image): 56 | nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32) 57 | pic.copyto(nppic) 58 | return torch.from_numpy(nppic) 59 | 60 | # handle PIL Image 61 | if pic.mode == 'I': 62 | img = torch.from_numpy(np.array(pic, np.int32, copy=False)) 63 | elif pic.mode == 'I;16': 64 | img = torch.from_numpy(np.array(pic, np.int16, copy=False)) 65 | elif pic.mode == 'F': 66 | img = torch.from_numpy(np.array(pic, np.float32, copy=False)) 67 | elif pic.mode == '1': 68 | img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False)) 69 | else: 70 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 71 | # PIL image mode: L, P, I, F, RGB, YCbCr, RGBA, CMYK 72 | if pic.mode == 'YCbCr': 73 | nchannel = 3 74 | elif pic.mode == 'I;16': 75 | nchannel = 1 76 | else: 77 | nchannel = len(pic.mode) 78 | img = img.view(pic.size[1], pic.size[0], nchannel) 79 | # put it from HWC to CHW format 80 | # yikes, this transpose takes 80% of the loading time/CPU 81 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 82 | if isinstance(img, torch.ByteTensor): 83 | return img.float().div(255) 84 | else: 85 | return img 86 | 87 | 88 | def to_pil_image(pic, mode=None): 89 | """Convert a tensor or an ndarray to PIL Image. 90 | 91 | See :class:`~torchvision.transforms.ToPIlImage` for more details. 92 | 93 | Args: 94 | pic (Tensor or numpy.ndarray): Image to be converted to PIL Image. 95 | mode (`PIL.Image mode`_): color space and pixel depth of input data (optional). 96 | 97 | .. _PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes 98 | 99 | Returns: 100 | PIL Image: Image converted to PIL Image. 101 | """ 102 | if not(_is_numpy_image(pic) or _is_tensor_image(pic)): 103 | raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic))) 104 | 105 | npimg = pic 106 | if isinstance(pic, torch.FloatTensor): 107 | pic = pic.mul(255).byte() 108 | if torch.is_tensor(pic): 109 | npimg = np.transpose(pic.numpy(), (1, 2, 0)) 110 | 111 | if not isinstance(npimg, np.ndarray): 112 | raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' + 113 | 'not {}'.format(type(npimg))) 114 | 115 | if npimg.shape[2] == 1: 116 | expected_mode = None 117 | npimg = npimg[:, :, 0] 118 | if npimg.dtype == np.uint8: 119 | expected_mode = 'L' 120 | elif npimg.dtype == np.int16: 121 | expected_mode = 'I;16' 122 | elif npimg.dtype == np.int32: 123 | expected_mode = 'I' 124 | elif npimg.dtype == np.float32: 125 | expected_mode = 'F' 126 | if mode is not None and mode != expected_mode: 127 | raise ValueError("Incorrect mode ({}) supplied for input type {}. 
Should be {}" 128 | .format(mode, np.dtype, expected_mode)) 129 | mode = expected_mode 130 | 131 | elif npimg.shape[2] == 4: 132 | permitted_4_channel_modes = ['RGBA', 'CMYK'] 133 | if mode is not None and mode not in permitted_4_channel_modes: 134 | raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes)) 135 | 136 | if mode is None and npimg.dtype == np.uint8: 137 | mode = 'RGBA' 138 | else: 139 | permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV'] 140 | if mode is not None and mode not in permitted_3_channel_modes: 141 | raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes)) 142 | if mode is None and npimg.dtype == np.uint8: 143 | mode = 'RGB' 144 | 145 | if mode is None: 146 | raise TypeError('Input type {} is not supported'.format(npimg.dtype)) 147 | 148 | return Image.fromarray(npimg, mode=mode) 149 | 150 | 151 | def normalize(tensor, mean, std): 152 | """Normalize a tensor image with mean and standard deviation. 153 | 154 | See ``Normalize`` for more details. 155 | 156 | Args: 157 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 158 | mean (sequence): Sequence of means for each channel. 159 | std (sequence): Sequence of standard deviations for each channely. 160 | 161 | Returns: 162 | Tensor: Normalized Tensor image. 163 | """ 164 | if not _is_tensor_image(tensor): 165 | raise TypeError('tensor is not a torch image.') 166 | # TODO: make efficient 167 | for t, m, s in zip(tensor, mean, std): 168 | t.sub_(m).div_(s) 169 | return tensor 170 | 171 | 172 | def resize(img, size, interpolation=Image.BILINEAR): 173 | """Resize the input PIL Image to the given size. 174 | 175 | Args: 176 | img (PIL Image): Image to be resized. 177 | size (sequence or int): Desired output size. If size is a sequence like 178 | (h, w), the output size will be matched to this. If size is an int, 179 | the smaller edge of the image will be matched to this number maintaing 180 | the aspect ratio. i.e, if height > width, then image will be rescaled to 181 | (size * height / width, size) 182 | interpolation (int, optional): Desired interpolation. Default is 183 | ``PIL.Image.BILINEAR`` 184 | 185 | Returns: 186 | PIL Image: Resized image. 187 | """ 188 | if not _is_pil_image(img): 189 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 190 | if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)): 191 | raise TypeError('Got inappropriate size arg: {}'.format(size)) 192 | 193 | if isinstance(size, int): 194 | w, h = img.size 195 | if (w <= h and w == size) or (h <= w and h == size): 196 | return img 197 | if w < h: 198 | ow = size 199 | oh = int(size * h / w) 200 | return img.resize((ow, oh), interpolation) 201 | else: 202 | oh = size 203 | ow = int(size * w / h) 204 | return img.resize((ow, oh), interpolation) 205 | else: 206 | return img.resize(size[::-1], interpolation) 207 | 208 | 209 | def scale(*args, **kwargs): 210 | warnings.warn("The use of the transforms.Scale transform is deprecated, " + 211 | "please use transforms.Resize instead.") 212 | return resize(*args, **kwargs) 213 | 214 | 215 | def pad(img, padding, fill=0, padding_mode='constant'): 216 | """Pad the given PIL Image on all sides with speficified padding mode and fill value. 217 | 218 | Args: 219 | img (PIL Image): Image to be padded. 220 | padding (int or tuple): Padding on each border. If a single int is provided this 221 | is used to pad all borders. 
If tuple of length 2 is provided this is the padding 222 | on left/right and top/bottom respectively. If a tuple of length 4 is provided 223 | this is the padding for the left, top, right and bottom borders 224 | respectively. 225 | fill: Pixel fill value for constant fill. Default is 0. If a tuple of 226 | length 3, it is used to fill R, G, B channels respectively. 227 | This value is only used when the padding_mode is constant 228 | padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant. 229 | constant: pads with a constant value, this value is specified with fill 230 | edge: pads with the last value on the edge of the image 231 | reflect: pads with reflection of image (without repeating the last value on the edge) 232 | padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode 233 | will result in [3, 2, 1, 2, 3, 4, 3, 2] 234 | symmetric: pads with reflection of image (repeating the last value on the edge) 235 | padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode 236 | will result in [2, 1, 1, 2, 3, 4, 4, 3] 237 | 238 | Returns: 239 | PIL Image: Padded image. 240 | """ 241 | if not _is_pil_image(img): 242 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 243 | 244 | if not isinstance(padding, (numbers.Number, tuple)): 245 | raise TypeError('Got inappropriate padding arg') 246 | if not isinstance(fill, (numbers.Number, str, tuple)): 247 | raise TypeError('Got inappropriate fill arg') 248 | if not isinstance(padding_mode, str): 249 | raise TypeError('Got inappropriate padding_mode arg') 250 | 251 | if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]: 252 | raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " + 253 | "{} element tuple".format(len(padding))) 254 | 255 | assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'], \ 256 | 'Padding mode should be either constant, edge, reflect or symmetric' 257 | 258 | if padding_mode == 'constant': 259 | return ImageOps.expand(img, border=padding, fill=fill) 260 | else: 261 | if isinstance(padding, int): 262 | pad_left = pad_right = pad_top = pad_bottom = padding 263 | if isinstance(padding, collections.Sequence) and len(padding) == 2: 264 | pad_left = pad_right = padding[0] 265 | pad_top = pad_bottom = padding[1] 266 | if isinstance(padding, collections.Sequence) and len(padding) == 4: 267 | pad_left = padding[0] 268 | pad_top = padding[1] 269 | pad_right = padding[2] 270 | pad_bottom = padding[3] 271 | 272 | img = np.asarray(img) 273 | # RGB image 274 | if len(img.shape) == 3: 275 | img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode) 276 | # Grayscale image 277 | if len(img.shape) == 2: 278 | img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode) 279 | 280 | return Image.fromarray(img) 281 | 282 | 283 | def crop(img, i, j, h, w): 284 | """Crop the given PIL Image. 285 | 286 | Args: 287 | img (PIL Image): Image to be cropped. 288 | i: Upper pixel coordinate. 289 | j: Left pixel coordinate. 290 | h: Height of the cropped image. 291 | w: Width of the cropped image. 292 | 293 | Returns: 294 | PIL Image: Cropped image. 295 | """ 296 | if not _is_pil_image(img): 297 | raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) 298 | 299 | return img.crop((j, i, j + w, i + h)) 300 | 301 | 302 | def center_crop(img, output_size): 303 | if isinstance(output_size, numbers.Number): 304 | output_size = (int(output_size), int(output_size)) 305 | w, h = img.size 306 | th, tw = output_size 307 | i = int(round((h - th) / 2.)) 308 | j = int(round((w - tw) / 2.)) 309 | return crop(img, i, j, th, tw) 310 | 311 | 312 | def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR): 313 | """Crop the given PIL Image and resize it to desired size. 314 | 315 | Notably used in RandomResizedCrop. 316 | 317 | Args: 318 | img (PIL Image): Image to be cropped. 319 | i: Upper pixel coordinate. 320 | j: Left pixel coordinate. 321 | h: Height of the cropped image. 322 | w: Width of the cropped image. 323 | size (sequence or int): Desired output size. Same semantics as ``scale``. 324 | interpolation (int, optional): Desired interpolation. Default is 325 | ``PIL.Image.BILINEAR``. 326 | Returns: 327 | PIL Image: Cropped image. 328 | """ 329 | assert _is_pil_image(img), 'img should be PIL Image' 330 | img = crop(img, i, j, h, w) 331 | img = resize(img, size, interpolation) 332 | return img 333 | 334 | 335 | def hflip(img): 336 | """Horizontally flip the given PIL Image. 337 | 338 | Args: 339 | img (PIL Image): Image to be flipped. 340 | 341 | Returns: 342 | PIL Image: Horizontall flipped image. 343 | """ 344 | if not _is_pil_image(img): 345 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 346 | 347 | return img.transpose(Image.FLIP_LEFT_RIGHT) 348 | 349 | 350 | def vflip(img): 351 | """Vertically flip the given PIL Image. 352 | 353 | Args: 354 | img (PIL Image): Image to be flipped. 355 | 356 | Returns: 357 | PIL Image: Vertically flipped image. 358 | """ 359 | if not _is_pil_image(img): 360 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 361 | 362 | return img.transpose(Image.FLIP_TOP_BOTTOM) 363 | 364 | 365 | def five_crop(img, size): 366 | """Crop the given PIL Image into four corners and the central crop. 367 | 368 | .. Note:: 369 | This transform returns a tuple of images and there may be a 370 | mismatch in the number of inputs and targets your ``Dataset`` returns. 371 | 372 | Args: 373 | size (sequence or int): Desired output size of the crop. If size is an 374 | int instead of sequence like (h, w), a square crop (size, size) is 375 | made. 376 | Returns: 377 | tuple: tuple (tl, tr, bl, br, center) corresponding top left, 378 | top right, bottom left, bottom right and center crop. 379 | """ 380 | if isinstance(size, numbers.Number): 381 | size = (int(size), int(size)) 382 | else: 383 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 384 | 385 | w, h = img.size 386 | crop_h, crop_w = size 387 | if crop_w > w or crop_h > h: 388 | raise ValueError("Requested crop size {} is bigger than input size {}".format(size, 389 | (h, w))) 390 | tl = img.crop((0, 0, crop_w, crop_h)) 391 | tr = img.crop((w - crop_w, 0, w, crop_h)) 392 | bl = img.crop((0, h - crop_h, crop_w, h)) 393 | br = img.crop((w - crop_w, h - crop_h, w, h)) 394 | center = center_crop(img, (crop_h, crop_w)) 395 | return (tl, tr, bl, br, center) 396 | 397 | 398 | def ten_crop(img, size, vertical_flip=False): 399 | """Crop the given PIL Image into four corners and the central crop plus the 400 | flipped version of these (horizontal flipping is used by default). 401 | 402 | .. 
Note:: 403 | This transform returns a tuple of images and there may be a 404 | mismatch in the number of inputs and targets your ``Dataset`` returns. 405 | 406 | Args: 407 | size (sequence or int): Desired output size of the crop. If size is an 408 | int instead of sequence like (h, w), a square crop (size, size) is 409 | made. 410 | vertical_flip (bool): Use vertical flipping instead of horizontal 411 | 412 | Returns: 413 | tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, 414 | br_flip, center_flip) corresponding top left, top right, 415 | bottom left, bottom right and center crop and same for the 416 | flipped image. 417 | """ 418 | if isinstance(size, numbers.Number): 419 | size = (int(size), int(size)) 420 | else: 421 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 422 | 423 | first_five = five_crop(img, size) 424 | 425 | if vertical_flip: 426 | img = vflip(img) 427 | else: 428 | img = hflip(img) 429 | 430 | second_five = five_crop(img, size) 431 | return first_five + second_five 432 | 433 | 434 | def adjust_brightness(img, brightness_factor): 435 | """Adjust brightness of an Image. 436 | 437 | Args: 438 | img (PIL Image): PIL Image to be adjusted. 439 | brightness_factor (float): How much to adjust the brightness. Can be 440 | any non negative number. 0 gives a black image, 1 gives the 441 | original image while 2 increases the brightness by a factor of 2. 442 | 443 | Returns: 444 | PIL Image: Brightness adjusted image. 445 | """ 446 | if not _is_pil_image(img): 447 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 448 | 449 | enhancer = ImageEnhance.Brightness(img) 450 | img = enhancer.enhance(brightness_factor) 451 | return img 452 | 453 | 454 | def adjust_contrast(img, contrast_factor): 455 | """Adjust contrast of an Image. 456 | 457 | Args: 458 | img (PIL Image): PIL Image to be adjusted. 459 | contrast_factor (float): How much to adjust the contrast. Can be any 460 | non negative number. 0 gives a solid gray image, 1 gives the 461 | original image while 2 increases the contrast by a factor of 2. 462 | 463 | Returns: 464 | PIL Image: Contrast adjusted image. 465 | """ 466 | if not _is_pil_image(img): 467 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 468 | 469 | enhancer = ImageEnhance.Contrast(img) 470 | img = enhancer.enhance(contrast_factor) 471 | return img 472 | 473 | 474 | def adjust_saturation(img, saturation_factor): 475 | """Adjust color saturation of an image. 476 | 477 | Args: 478 | img (PIL Image): PIL Image to be adjusted. 479 | saturation_factor (float): How much to adjust the saturation. 0 will 480 | give a black and white image, 1 will give the original image while 481 | 2 will enhance the saturation by a factor of 2. 482 | 483 | Returns: 484 | PIL Image: Saturation adjusted image. 485 | """ 486 | if not _is_pil_image(img): 487 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 488 | 489 | enhancer = ImageEnhance.Color(img) 490 | img = enhancer.enhance(saturation_factor) 491 | return img 492 | 493 | 494 | def adjust_hue(img, hue_factor): 495 | """Adjust hue of an image. 496 | 497 | The image hue is adjusted by converting the image to HSV and 498 | cyclically shifting the intensities in the hue channel (H). 499 | The image is then converted back to original image mode. 500 | 501 | `hue_factor` is the amount of shift in H channel and must be in the 502 | interval `[-0.5, 0.5]`. 503 | 504 | See https://en.wikipedia.org/wiki/Hue for more details on Hue. 
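Concretely, the shift is implemented below by adding ``hue_factor * 255`` (with uint8 wrap-around) to the H channel of the image's HSV representation; ``hue_factor=0.25``, for example, rotates every hue by a quarter of the hue circle.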
505 | 506 | Args: 507 | img (PIL Image): PIL Image to be adjusted. 508 | hue_factor (float): How much to shift the hue channel. Should be in 509 | [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in 510 | HSV space in positive and negative direction respectively. 511 | 0 means no shift. Therefore, both -0.5 and 0.5 will give an image 512 | with complementary colors while 0 gives the original image. 513 | 514 | Returns: 515 | PIL Image: Hue adjusted image. 516 | """ 517 | if not(-0.5 <= hue_factor <= 0.5): 518 | raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor)) 519 | 520 | if not _is_pil_image(img): 521 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 522 | 523 | input_mode = img.mode 524 | if input_mode in {'L', '1', 'I', 'F'}: 525 | return img 526 | 527 | h, s, v = img.convert('HSV').split() 528 | 529 | np_h = np.array(h, dtype=np.uint8) 530 | # uint8 addition take cares of rotation across boundaries 531 | with np.errstate(over='ignore'): 532 | np_h += np.uint8(hue_factor * 255) 533 | h = Image.fromarray(np_h, 'L') 534 | 535 | img = Image.merge('HSV', (h, s, v)).convert(input_mode) 536 | return img 537 | 538 | 539 | def adjust_gamma(img, gamma, gain=1): 540 | """Perform gamma correction on an image. 541 | 542 | Also known as Power Law Transform. Intensities in RGB mode are adjusted 543 | based on the following equation: 544 | 545 | I_out = 255 * gain * ((I_in / 255) ** gamma) 546 | 547 | See https://en.wikipedia.org/wiki/Gamma_correction for more details. 548 | 549 | Args: 550 | img (PIL Image): PIL Image to be adjusted. 551 | gamma (float): Non negative real number. gamma larger than 1 make the 552 | shadows darker, while gamma smaller than 1 make dark regions 553 | lighter. 554 | gain (float): The constant multiplier. 555 | """ 556 | if not _is_pil_image(img): 557 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 558 | 559 | if gamma < 0: 560 | raise ValueError('Gamma should be a non-negative real number') 561 | 562 | input_mode = img.mode 563 | img = img.convert('RGB') 564 | 565 | gamma_map = [255 * gain * pow(ele / 255., gamma) for ele in range(256)] * 3 566 | img = img.point(gamma_map) # use PIL's point-function to accelerate this part 567 | 568 | img = img.convert(input_mode) 569 | return img 570 | 571 | 572 | def rotate(img, angle, resample=False, expand=False, center=None): 573 | """Rotate the image by angle. 574 | 575 | 576 | Args: 577 | img (PIL Image): PIL Image to be rotated. 578 | angle ({float, int}): In degrees degrees counter clockwise order. 579 | resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): 580 | An optional resampling filter. 581 | See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters 582 | If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. 583 | expand (bool, optional): Optional expansion flag. 584 | If true, expands the output image to make it large enough to hold the entire rotated image. 585 | If false or omitted, make the output image the same size as the input image. 586 | Note that the expand flag assumes rotation around the center and no translation. 587 | center (2-tuple, optional): Optional center of rotation. 588 | Origin is the upper left corner. 589 | Default is the center of the image. 590 | """ 591 | 592 | if not _is_pil_image(img): 593 | raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) 594 | 595 | return img.rotate(angle, resample, expand, center) 596 | 597 | 598 | def _get_inverse_affine_matrix(center, angle, translate, scale, shear): 599 | # Helper method to compute inverse matrix for affine transformation 600 | 601 | # As it is explained in PIL.Image.rotate 602 | # We need compute INVERSE of affine transformation matrix: M = T * C * RSS * C^-1 603 | # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1] 604 | # C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1] 605 | # RSS is rotation with scale and shear matrix 606 | # RSS(a, scale, shear) = [ cos(a)*scale -sin(a + shear)*scale 0] 607 | # [ sin(a)*scale cos(a + shear)*scale 0] 608 | # [ 0 0 1] 609 | # Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1 610 | 611 | angle = math.radians(angle) 612 | shear = math.radians(shear) 613 | scale = 1.0 / scale 614 | 615 | # Inverted rotation matrix with scale and shear 616 | d = math.cos(angle + shear) * math.cos(angle) + math.sin(angle + shear) * math.sin(angle) 617 | matrix = [ 618 | math.cos(angle + shear), math.sin(angle + shear), 0, 619 | -math.sin(angle), math.cos(angle), 0 620 | ] 621 | matrix = [scale / d * m for m in matrix] 622 | 623 | # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1 624 | matrix[2] += matrix[0] * (-center[0] - translate[0]) + matrix[1] * (-center[1] - translate[1]) 625 | matrix[5] += matrix[3] * (-center[0] - translate[0]) + matrix[4] * (-center[1] - translate[1]) 626 | 627 | # Apply center translation: C * RSS^-1 * C^-1 * T^-1 628 | matrix[2] += center[0] 629 | matrix[5] += center[1] 630 | return matrix 631 | 632 | 633 | def affine(img, angle, translate, scale, shear, resample=0, fillcolor=None): 634 | """Apply affine transformation on the image keeping image center invariant 635 | 636 | Args: 637 | img (PIL Image): PIL Image to be rotated. 638 | angle ({float, int}): rotation angle in degrees between -180 and 180, clockwise direction. 639 | translate (list or tuple of integers): horizontal and vertical translations (post-rotation translation) 640 | scale (float): overall scale 641 | shear (float): shear angle value in degrees between -180 to 180, clockwise direction. 642 | resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): 643 | An optional resampling filter. 644 | See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters 645 | If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. 646 | fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0) 647 | """ 648 | if not _is_pil_image(img): 649 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 650 | 651 | assert isinstance(translate, (tuple, list)) and len(translate) == 2, \ 652 | "Argument translate should be a list or tuple of length 2" 653 | 654 | assert scale > 0.0, "Argument scale should be positive" 655 | 656 | output_size = img.size 657 | center = (img.size[0] * 0.5 + 0.5, img.size[1] * 0.5 + 0.5) 658 | matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear) 659 | kwargs = {"fillcolor": fillcolor} if PILLOW_VERSION[0] == '5' else {} 660 | return img.transform(output_size, Image.AFFINE, matrix, resample, **kwargs) 661 | 662 | 663 | def to_grayscale(img, num_output_channels=1): 664 | """Convert image to grayscale version of image. 665 | 666 | Args: 667 | img (PIL Image): Image to be converted to grayscale. 
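num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default: 1.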
668 | 669 | Returns: 670 | PIL Image: Grayscale version of the image. 671 | if num_output_channels == 1 : returned image is single channel 672 | if num_output_channels == 3 : returned image is 3 channel with r == g == b 673 | """ 674 | if not _is_pil_image(img): 675 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 676 | 677 | if num_output_channels == 1: 678 | img = img.convert('L') 679 | elif num_output_channels == 3: 680 | img = img.convert('L') 681 | np_img = np.array(img, dtype=np.uint8) 682 | np_img = np.dstack([np_img, np_img, np_img]) 683 | img = Image.fromarray(np_img, 'RGB') 684 | else: 685 | raise ValueError('num_output_channels should be either 1 or 3') 686 | 687 | return img 688 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import torch.optim.lr_scheduler as lr_scheduler 7 | import torch.utils.data as data 8 | import torchvision.transforms as transforms 9 | 10 | from PIL import Image 11 | 12 | import transforms as ext_transforms 13 | from models.SparseConvNet import SparseConvNet 14 | from train import Train 15 | from test import Test 16 | from args import get_arguments 17 | import utils 18 | from data import CamVid as dataset 19 | 20 | # Get the arguments 21 | args = get_arguments() 22 | 23 | device = torch.device(args.device) 24 | 25 | 26 | def load_dataset(dataset): 27 | print("\nLoading dataset...\n") 28 | 29 | print("Selected dataset:", args.dataset) 30 | print("Dataset directory:", args.dataset_dir) 31 | print("Save directory:", args.save_dir) 32 | 33 | image_transform = ext_transforms.RandomCrop(336) 34 | val_transform = transforms.ToTensor() 35 | 36 | train_set = dataset( 37 | args.dataset_dir, 38 | transform=image_transform) 39 | train_loader = data.DataLoader( 40 | train_set, 41 | batch_size=args.batch_size, 42 | shuffle=True, 43 | num_workers=args.workers) 44 | 45 | # Load the validation set as tensors 46 | val_set = dataset( 47 | args.dataset_dir, 48 | transform=val_transform, 49 | mode='val') 50 | val_loader = data.DataLoader( 51 | val_set, 52 | batch_size=args.batch_size, 53 | shuffle=False, 54 | num_workers=args.workers) 55 | 56 | # Load the test set as tensors 57 | test_set = dataset( 58 | args.dataset_dir, 59 | transform=val_transform, 60 | mode='test') 61 | test_loader = data.DataLoader( 62 | test_set, 63 | batch_size=args.batch_size, 64 | shuffle=False, 65 | num_workers=args.workers) 66 | 67 | return train_loader, val_loader, test_loader 68 | 69 | 70 | def train(train_loader, val_loader): 71 | print("\nTraining...\n") 72 | 73 | model = SparseConvNet().to(device) 74 | criterion = nn.MSELoss(reduction='none') 75 | 76 | optimizer = optim.Adam( 77 | model.parameters(), 78 | lr=args.learning_rate, 79 | weight_decay=args.weight_decay) 80 | 81 | # Learning rate decay scheduler 82 | lr_updater = lr_scheduler.StepLR(optimizer, args.lr_decay_epochs, 83 | args.lr_decay) 84 | 85 | # Optionally resume from a checkpoint 86 | if args.resume: 87 | model, optimizer, start_epoch, best_loss = utils.load_checkpoint( 88 | model, optimizer, args.save_dir, args.name) 89 | print("Resuming from model: Start epoch = {0} " 90 | "| Best mean loss = {1:.4f}".format(start_epoch, best_loss)) 91 | else: 92 | start_epoch = 0 93 | best_loss = 1000 94 | 95 | # Start Training 96 | print() 97 | train = Train(model, train_loader, optimizer, criterion, 
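# Note (assumption, sketching intent only): Train and Test are defined in
# train.py and test.py, which are not shown in this section; each run_epoch()
# call is assumed to perform one full pass over its loader. The MSE criterion
# above is created with reduction='none', presumably so the per-pixel loss can
# be restricted to observed (non-missing) pixels before averaging.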
device) 98 | val = Test(model, val_loader, criterion, device) 99 | for epoch in range(start_epoch, args.epochs): 100 | print(">>>> [Epoch: {0:d}] Training".format(epoch)) 101 | 102 | epoch_loss = train.run_epoch(lr_updater, args.print_step) 103 | 104 | print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f}". 105 | format(epoch, epoch_loss)) 106 | 107 | if (epoch + 1) % 1 == 0 or epoch + 1 == args.epochs: 108 | print(">>>> [Epoch: {0:d}] Validation".format(epoch)) 109 | 110 | loss = val.run_epoch(args.print_step) 111 | 112 | print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f}". 113 | format(epoch, loss)) 114 | 115 | # Save the model if it's the best thus far 116 | if loss < best_loss: 117 | print("\nBest model thus far. Saving...\n") 118 | best_loss = loss 119 | utils.save_checkpoint(model, optimizer, epoch + 1, best_loss, 120 | args) 121 | 122 | return model 123 | 124 | 125 | def test(model, test_loader): 126 | print("\nTesting...\n") 127 | 128 | criterion = nn.MSELoss() 129 | 130 | # Test the trained model on the test set 131 | test = Test(model, test_loader, criterion, device) 132 | 133 | print(">>>> Running test dataset") 134 | loss, (iou, miou) = test.run_epoch(args.print_step) 135 | print(">>>> Avg. loss: {0:.4f} | Mean IoU: {1:.4f}".format(loss, miou)) 136 | 137 | 138 | # Run only if this module is being run directly 139 | if __name__ == '__main__': 140 | 141 | # Fail fast if the dataset directory doesn't exist 142 | assert os.path.isdir( 143 | args.dataset_dir), "The directory \"{0}\" doesn't exist.".format( 144 | args.dataset_dir) 145 | 146 | # Fail fast if the saving directory doesn't exist 147 | assert os.path.isdir( 148 | args.save_dir), "The directory \"{0}\" doesn't exist.".format( 149 | args.save_dir) 150 | 151 | train_loader, val_loader, test_loader = load_dataset(dataset) 152 | 153 | if args.mode.lower() in {'train', 'full'}: 154 | model = train(train_loader, val_loader) 155 | 156 | if args.mode.lower() in {'test', 'full'}: 157 | if args.mode.lower() == 'test': 158 | # Initialize a new SparseConvNet model 159 | model = SparseConvNet().to(device) 160 | 161 | # Initialize an optimizer just so we can retrieve the model from the 162 | # checkpoint 163 | optimizer = optim.Adam(model.parameters()) 164 | 165 | # Load the previously saved model state to the SparseConvNet model 166 | model = utils.load_checkpoint(model, optimizer, args.save_dir, 167 | args.name)[0] 168 | 169 | test(model, test_loader) -------------------------------------------------------------------------------- /models/SparseConvNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class SparseConv(nn.Module): 6 | 7 | def __init__(self, 8 | in_channels, 9 | out_channels, 10 | kernel_size): 11 | super().__init__() 12 | 13 | padding = kernel_size//2 14 | 15 | self.conv = nn.Conv2d( 16 | in_channels, 17 | out_channels, 18 | kernel_size=kernel_size, 19 | padding=padding, 20 | bias=False) 21 | 22 | self.bias = nn.Parameter( 23 | torch.zeros(out_channels), 24 | requires_grad=True) 25 | 26 | self.sparsity = nn.Conv2d( 27 | in_channels, 28 | out_channels, 29 | kernel_size=kernel_size, 30 | padding=padding, 31 | bias=False) 32 | 33 | kernel = torch.FloatTensor(torch.ones([kernel_size, kernel_size])).unsqueeze(0).unsqueeze(0) 34 | 35 | self.sparsity.weight = nn.Parameter( 36 | data=kernel, 37 | requires_grad=False) 38 | 39 | self.relu = nn.ReLU(inplace=True) 40 | 41 | 42 | self.max_pool = nn.MaxPool2d( 43 | kernel_size, 44 | stride=1, 45 | 
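# The forward pass below implements the paper's sparsity-invariant convolution:
# the input is multiplied by its binary validity mask, convolved by `conv`
# (bias disabled), and divided by the number of observed pixels in each window,
# which `sparsity` computes with its fixed all-ones (1, 1, k, k) kernel; the
# learnable bias is added after this normalization. The mask is therefore
# expected to be single-channel, shape (N, 1, H, W), and broadcasts over the
# feature channels. This max pool (stride 1, same padding) propagates the mask:
# an output location counts as observed if any pixel in its k x k window was.
# Illustrative usage (an assumption about the expected inputs, not shown in
# this file): given sparse depth d of shape (N, 1, H, W) and mask
# m = (d > 0).float(), SparseConvNet()(d, m) returns a dense (N, 1, H, W) map.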
padding=padding) 46 | 47 | 48 | 49 | def forward(self, x, mask): 50 | x = x*mask 51 | x = self.conv(x) 52 | normalizer = 1/(self.sparsity(mask)+1e-8) 53 | x = x * normalizer + self.bias.unsqueeze(0).unsqueeze(2).unsqueeze(3) 54 | x = self.relu(x) 55 | 56 | mask = self.max_pool(mask) 57 | 58 | return x, mask 59 | 60 | 61 | 62 | class SparseConvNet(nn.Module): 63 | 64 | def __init__(self): 65 | super().__init__() 66 | 67 | self.SparseLayer1 = SparseConv(1, 16, 11) 68 | self.SparseLayer2 = SparseConv(16, 16, 7) 69 | self.SparseLayer3 = SparseConv(16, 16, 5) 70 | self.SparseLayer4 = SparseConv(16, 16, 3) 71 | self.SparseLayer5 = SparseConv(16, 16, 3) 72 | self.SparseLayer6 = SparseConv(16, 1, 1) 73 | 74 | def forward(self, x, mask): 75 | 76 | x, mask = self.SparseLayer1(x, mask) 77 | x, mask = self.SparseLayer2(x, mask) 78 | x, mask = self.SparseLayer3(x, mask) 79 | x, mask = self.SparseLayer4(x, mask) 80 | x, mask = self.SparseLayer5(x, mask) 81 | x, mask = self.SparseLayer6(x, mask) 82 | 83 | return x 84 | -------------------------------------------------------------------------------- /models/enet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class InitialBlock(nn.Module): 6 | """The initial block is composed of two branches: 7 | 1. a main branch which performs a regular convolution with stride 2; 8 | 2. an extension branch which performs max-pooling. 9 | 10 | Doing both operations in parallel and concatenating their results 11 | allows for efficient downsampling and expansion. The main branch 12 | outputs 13 feature maps while the extension branch outputs 3, for a 13 | total of 16 feature maps after concatenation. 14 | 15 | Keyword arguments: 16 | - in_channels (int): the number of input channels. 17 | - out_channels (int): the number output channels. 18 | - kernel_size (int, optional): the kernel size of the filters used in 19 | the convolution layer. Default: 3. 20 | - padding (int, optional): zero-padding added to both sides of the 21 | input. Default: 0. 22 | - bias (bool, optional): Adds a learnable bias to the output if 23 | ``True``. Default: False. 24 | - relu (bool, optional): When ``True`` ReLU is used as the activation 25 | function; otherwise, PReLU is used. Default: True. 
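Example (illustrative): ``InitialBlock(3, 16)`` maps a 3-channel input to 16 feature maps at half resolution, 13 from the strided convolution and 3 from the max-pooling branch.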
26 | 27 | """ 28 | 29 | def __init__(self, 30 | in_channels, 31 | out_channels, 32 | bias=False, 33 | relu=True): 34 | super().__init__() 35 | 36 | if relu: 37 | activation = nn.ReLU 38 | else: 39 | activation = nn.PReLU 40 | 41 | # Main branch - As stated above the number of output channels for this 42 | # branch is the total minus 3, since the remaining channels come from 43 | # the extension branch 44 | self.main_branch = nn.Conv2d( 45 | in_channels, 46 | out_channels - 3, 47 | kernel_size=3, 48 | stride=2, 49 | padding=1, 50 | bias=bias) 51 | 52 | # Extension branch 53 | self.ext_branch = nn.MaxPool2d(3, stride=2, padding=1) 54 | 55 | # Initialize batch normalization to be used after concatenation 56 | self.batch_norm = nn.BatchNorm2d(out_channels) 57 | 58 | # PReLU layer to apply after concatenating the branches 59 | self.out_activation = activation() 60 | 61 | def forward(self, x): 62 | main = self.main_branch(x) 63 | ext = self.ext_branch(x) 64 | 65 | # Concatenate branches 66 | out = torch.cat((main, ext), 1) 67 | 68 | # Apply batch normalization 69 | out = self.batch_norm(out) 70 | 71 | return self.out_activation(out) 72 | 73 | 74 | class RegularBottleneck(nn.Module): 75 | """Regular bottlenecks are the main building block of ENet. 76 | Main branch: 77 | 1. Shortcut connection. 78 | 79 | Extension branch: 80 | 1. 1x1 convolution which decreases the number of channels by 81 | ``internal_ratio``, also called a projection; 82 | 2. regular, dilated or asymmetric convolution; 83 | 3. 1x1 convolution which increases the number of channels back to 84 | ``channels``, also called an expansion; 85 | 4. dropout as a regularizer. 86 | 87 | Keyword arguments: 88 | - channels (int): the number of input and output channels. 89 | - internal_ratio (int, optional): a scale factor applied to 90 | ``channels`` used to compute the number of 91 | channels after the projection. eg. given ``channels`` equal to 128 and 92 | internal_ratio equal to 2 the number of channels after the projection 93 | is 64. Default: 4. 94 | - kernel_size (int, optional): the kernel size of the filters used in 95 | the convolution layer described above in item 2 of the extension 96 | branch. Default: 3. 97 | - padding (int, optional): zero-padding added to both sides of the 98 | input. Default: 0. 99 | - dilation (int, optional): spacing between kernel elements for the 100 | convolution described in item 2 of the extension branch. Default: 1. 101 | asymmetric (bool, optional): flags if the convolution described in 102 | item 2 of the extension branch is asymmetric or not. Default: False. 103 | - dropout_prob (float, optional): probability of an element to be 104 | zeroed. Default: 0 (no dropout). 105 | - bias (bool, optional): Adds a learnable bias to the output if 106 | ``True``. Default: False. 107 | - relu (bool, optional): When ``True`` ReLU is used as the activation 108 | function; otherwise, PReLU is used. Default: True. 109 | 110 | """ 111 | 112 | def __init__(self, 113 | channels, 114 | internal_ratio=4, 115 | kernel_size=3, 116 | padding=0, 117 | dilation=1, 118 | asymmetric=False, 119 | dropout_prob=0, 120 | bias=False, 121 | relu=True): 122 | super().__init__() 123 | 124 | # Check in the internal_scale parameter is within the expected range 125 | # [1, channels] 126 | if internal_ratio <= 1 or internal_ratio > channels: 127 | raise RuntimeError("Value out of range. Expected value in the " 128 | "interval [1, {0}], got internal_scale={1}." 
129 | .format(channels, internal_ratio)) 130 | 131 | internal_channels = channels // internal_ratio 132 | 133 | if relu: 134 | activation = nn.ReLU 135 | else: 136 | activation = nn.PReLU 137 | 138 | # Main branch - shortcut connection 139 | 140 | # Extension branch - 1x1 convolution, followed by a regular, dilated or 141 | # asymmetric convolution, followed by another 1x1 convolution, and, 142 | # finally, a regularizer (spatial dropout). Number of channels is constant. 143 | 144 | # 1x1 projection convolution 145 | self.ext_conv1 = nn.Sequential( 146 | nn.Conv2d( 147 | channels, 148 | internal_channels, 149 | kernel_size=1, 150 | stride=1, 151 | bias=bias), nn.BatchNorm2d(internal_channels), activation()) 152 | 153 | # If the convolution is asymmetric we split the main convolution in 154 | # two. Eg. for a 5x5 asymmetric convolution we have two convolution: 155 | # the first is 5x1 and the second is 1x5. 156 | if asymmetric: 157 | self.ext_conv2 = nn.Sequential( 158 | nn.Conv2d( 159 | internal_channels, 160 | internal_channels, 161 | kernel_size=(kernel_size, 1), 162 | stride=1, 163 | padding=(padding, 0), 164 | dilation=dilation, 165 | bias=bias), nn.BatchNorm2d(internal_channels), activation(), 166 | nn.Conv2d( 167 | internal_channels, 168 | internal_channels, 169 | kernel_size=(1, kernel_size), 170 | stride=1, 171 | padding=(0, padding), 172 | dilation=dilation, 173 | bias=bias), nn.BatchNorm2d(internal_channels), activation()) 174 | else: 175 | self.ext_conv2 = nn.Sequential( 176 | nn.Conv2d( 177 | internal_channels, 178 | internal_channels, 179 | kernel_size=kernel_size, 180 | stride=1, 181 | padding=padding, 182 | dilation=dilation, 183 | bias=bias), nn.BatchNorm2d(internal_channels), activation()) 184 | 185 | # 1x1 expansion convolution 186 | self.ext_conv3 = nn.Sequential( 187 | nn.Conv2d( 188 | internal_channels, 189 | channels, 190 | kernel_size=1, 191 | stride=1, 192 | bias=bias), nn.BatchNorm2d(channels), activation()) 193 | 194 | self.ext_regul = nn.Dropout2d(p=dropout_prob) 195 | 196 | # PReLU layer to apply after adding the branches 197 | self.out_activation = activation() 198 | 199 | def forward(self, x): 200 | # Main branch shortcut 201 | main = x 202 | 203 | # Extension branch 204 | ext = self.ext_conv1(x) 205 | ext = self.ext_conv2(ext) 206 | ext = self.ext_conv3(ext) 207 | ext = self.ext_regul(ext) 208 | 209 | # Add main and extension branches 210 | out = main + ext 211 | 212 | return self.out_activation(out) 213 | 214 | 215 | class DownsamplingBottleneck(nn.Module): 216 | """Downsampling bottlenecks further downsample the feature map size. 217 | 218 | Main branch: 219 | 1. max pooling with stride 2; indices are saved to be used for 220 | unpooling later. 221 | 222 | Extension branch: 223 | 1. 2x2 convolution with stride 2 that decreases the number of channels 224 | by ``internal_ratio``, also called a projection; 225 | 2. regular convolution (by default, 3x3); 226 | 3. 1x1 convolution which increases the number of channels to 227 | ``out_channels``, also called an expansion; 228 | 4. dropout as a regularizer. 229 | 230 | Keyword arguments: 231 | - in_channels (int): the number of input channels. 232 | - out_channels (int): the number of output channels. 233 | - internal_ratio (int, optional): a scale factor applied to ``channels`` 234 | used to compute the number of channels after the projection. eg. given 235 | ``channels`` equal to 128 and internal_ratio equal to 2 the number of 236 | channels after the projection is 64. Default: 4. 
237 | - return_indices (bool, optional): if ``True``, will return the max 238 | indices along with the outputs. Useful when unpooling later. 239 | - dropout_prob (float, optional): probability of an element to be 240 | zeroed. Default: 0 (no dropout). 241 | - bias (bool, optional): Adds a learnable bias to the output if 242 | ``True``. Default: False. 243 | - relu (bool, optional): When ``True`` ReLU is used as the activation 244 | function; otherwise, PReLU is used. Default: True. 245 | 246 | """ 247 | 248 | def __init__(self, 249 | in_channels, 250 | out_channels, 251 | internal_ratio=4, 252 | return_indices=False, 253 | dropout_prob=0, 254 | bias=False, 255 | relu=True): 256 | super().__init__() 257 | 258 | # Store parameters that are needed later 259 | self.return_indices = return_indices 260 | 261 | # Check in the internal_scale parameter is within the expected range 262 | # [1, channels] 263 | if internal_ratio <= 1 or internal_ratio > in_channels: 264 | raise RuntimeError("Value out of range. Expected value in the " 265 | "interval [1, {0}], got internal_scale={1}. " 266 | .format(in_channels, internal_ratio)) 267 | 268 | internal_channels = in_channels // internal_ratio 269 | 270 | if relu: 271 | activation = nn.ReLU 272 | else: 273 | activation = nn.PReLU 274 | 275 | # Main branch - max pooling followed by feature map (channels) padding 276 | self.main_max1 = nn.MaxPool2d( 277 | 2, 278 | stride=2, 279 | return_indices=return_indices) 280 | 281 | # Extension branch - 2x2 convolution, followed by a regular, dilated or 282 | # asymmetric convolution, followed by another 1x1 convolution. Number 283 | # of channels is doubled. 284 | 285 | # 2x2 projection convolution with stride 2 286 | self.ext_conv1 = nn.Sequential( 287 | nn.Conv2d( 288 | in_channels, 289 | internal_channels, 290 | kernel_size=2, 291 | stride=2, 292 | bias=bias), nn.BatchNorm2d(internal_channels), activation()) 293 | 294 | # Convolution 295 | self.ext_conv2 = nn.Sequential( 296 | nn.Conv2d( 297 | internal_channels, 298 | internal_channels, 299 | kernel_size=3, 300 | stride=1, 301 | padding=1, 302 | bias=bias), nn.BatchNorm2d(internal_channels), activation()) 303 | 304 | # 1x1 expansion convolution 305 | self.ext_conv3 = nn.Sequential( 306 | nn.Conv2d( 307 | internal_channels, 308 | out_channels, 309 | kernel_size=1, 310 | stride=1, 311 | bias=bias), nn.BatchNorm2d(out_channels), activation()) 312 | 313 | self.ext_regul = nn.Dropout2d(p=dropout_prob) 314 | 315 | # PReLU layer to apply after concatenating the branches 316 | self.out_activation = activation() 317 | 318 | def forward(self, x): 319 | # Main branch shortcut 320 | if self.return_indices: 321 | main, max_indices = self.main_max1(x) 322 | else: 323 | main = self.main_max1(x) 324 | 325 | # Extension branch 326 | ext = self.ext_conv1(x) 327 | ext = self.ext_conv2(ext) 328 | ext = self.ext_conv3(ext) 329 | ext = self.ext_regul(ext) 330 | 331 | # Main branch channel padding 332 | n, ch_ext, h, w = ext.size() 333 | ch_main = main.size()[1] 334 | padding = torch.zeros(n, ch_ext - ch_main, h, w) 335 | 336 | # Before concatenating, check if main is on the CPU or GPU and 337 | # convert padding accordingly 338 | if main.is_cuda: 339 | padding = padding.cuda() 340 | 341 | # Concatenate 342 | main = torch.cat((main, padding), 1) 343 | 344 | # Add main and extension branches 345 | out = main + ext 346 | 347 | return self.out_activation(out), max_indices 348 | 349 | 350 | class UpsamplingBottleneck(nn.Module): 351 | """The upsampling bottlenecks upsample the feature map 
resolution using max 352 | pooling indices stored from the corresponding downsampling bottleneck. 353 | 354 | Main branch: 355 | 1. 1x1 convolution with stride 1 that decreases the number of channels by 356 | ``internal_ratio``, also called a projection; 357 | 2. max unpool layer using the max pool indices from the corresponding 358 | downsampling max pool layer. 359 | 360 | Extension branch: 361 | 1. 1x1 convolution with stride 1 that decreases the number of channels by 362 | ``internal_ratio``, also called a projection; 363 | 2. transposed convolution (by default, 3x3); 364 | 3. 1x1 convolution which increases the number of channels to 365 | ``out_channels``, also called an expansion; 366 | 4. dropout as a regularizer. 367 | 368 | Keyword arguments: 369 | - in_channels (int): the number of input channels. 370 | - out_channels (int): the number of output channels. 371 | - internal_ratio (int, optional): a scale factor applied to ``in_channels`` 372 | used to compute the number of channels after the projection. eg. given 373 | ``in_channels`` equal to 128 and ``internal_ratio`` equal to 2 the number 374 | of channels after the projection is 64. Default: 4. 375 | - dropout_prob (float, optional): probability of an element to be zeroed. 376 | Default: 0 (no dropout). 377 | - bias (bool, optional): Adds a learnable bias to the output if ``True``. 378 | Default: False. 379 | - relu (bool, optional): When ``True`` ReLU is used as the activation 380 | function; otherwise, PReLU is used. Default: True. 381 | 382 | """ 383 | 384 | def __init__(self, 385 | in_channels, 386 | out_channels, 387 | internal_ratio=4, 388 | dropout_prob=0, 389 | bias=False, 390 | relu=True): 391 | super().__init__() 392 | 393 | # Check in the internal_scale parameter is within the expected range 394 | # [1, channels] 395 | if internal_ratio <= 1 or internal_ratio > in_channels: 396 | raise RuntimeError("Value out of range. Expected value in the " 397 | "interval [1, {0}], got internal_scale={1}. " 398 | .format(in_channels, internal_ratio)) 399 | 400 | internal_channels = in_channels // internal_ratio 401 | 402 | if relu: 403 | activation = nn.ReLU 404 | else: 405 | activation = nn.PReLU 406 | 407 | # Main branch - max pooling followed by feature map (channels) padding 408 | self.main_conv1 = nn.Sequential( 409 | nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias), 410 | nn.BatchNorm2d(out_channels)) 411 | 412 | # Remember that the stride is the same as the kernel_size, just like 413 | # the max pooling layers 414 | # self.main_unpool1 = nn.MaxUnpool2d(kernel_size=2) 415 | self.main_unpool1 = torch.nn.Upsample(scale_factor=2) 416 | 417 | # Extension branch - 1x1 convolution, followed by a regular, dilated or 418 | # asymmetric convolution, followed by another 1x1 convolution. Number 419 | # of channels is doubled. 
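# Note: in this reproduction the original ENet max-unpooling (the commented-out
# MaxUnpool2d above) has been replaced by nn.Upsample(scale_factor=2), so the
# max_indices argument received by forward() is no longer used for unpooling;
# only output_size is consumed, by the transposed convolution in the
# extension branch.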
420 | 
421 |         # 1x1 projection convolution with stride 1
422 |         self.ext_conv1 = nn.Sequential(
423 |             nn.Conv2d(
424 |                 in_channels, internal_channels, kernel_size=1, bias=bias),
425 |             nn.BatchNorm2d(internal_channels), activation())
426 | 
427 |         # Transposed convolution
428 |         self.ext_tconv1 = nn.ConvTranspose2d(
429 |             internal_channels,
430 |             internal_channels,
431 |             kernel_size=2,
432 |             stride=2,
433 |             bias=bias)
434 |         self.ext_tconv1_bnorm = nn.BatchNorm2d(internal_channels)
435 |         self.ext_tconv1_activation = activation()
436 | 
437 |         # 1x1 expansion convolution
438 |         self.ext_conv2 = nn.Sequential(
439 |             nn.Conv2d(
440 |                 internal_channels, out_channels, kernel_size=1, bias=bias),
441 |             nn.BatchNorm2d(out_channels), activation())
442 | 
443 |         self.ext_regul = nn.Dropout2d(p=dropout_prob)
444 | 
445 |         # Activation applied after adding the main and extension branches
446 |         self.out_activation = activation()
447 | 
448 |     def forward(self, x, max_indices, output_size):
449 |         # Main branch shortcut (max_indices is unused; kept for API compatibility)
450 |         main = self.main_conv1(x)
451 |         main = self.main_unpool1(main)
452 | 
453 | 
454 |         # Extension branch
455 |         ext = self.ext_conv1(x)
456 |         ext = self.ext_tconv1(ext, output_size=output_size)
457 |         ext = self.ext_tconv1_bnorm(ext)
458 |         ext = self.ext_tconv1_activation(ext)
459 |         ext = self.ext_conv2(ext)
460 |         ext = self.ext_regul(ext)
461 | 
462 |         # Add main and extension branches
463 |         out = main + ext
464 | 
465 |         return self.out_activation(out)
466 | 
467 | 
468 | class ENet(nn.Module):
469 |     """Generate the ENet model.
470 | 
471 |     Keyword arguments:
472 |     - num_classes (int): the number of classes to segment.
473 |     - encoder_relu (bool, optional): When ``True`` ReLU is used as the
474 |       activation function in the encoder blocks/layers; otherwise, PReLU
475 |       is used. Default: False.
476 |     - decoder_relu (bool, optional): When ``True`` ReLU is used as the
477 |       activation function in the decoder blocks/layers; otherwise, PReLU
478 |       is used. Default: True.
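
    Example (shape sketch only, not part of the original docstring; it assumes
    the input height and width are multiples of 8 so the encoder and decoder
    resolutions line up):

        >>> net = ENet(num_classes=1)
        >>> x = torch.randn(1, 3, 256, 256)
        >>> net(x).shape
        torch.Size([1, 1, 256, 256])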
479 | 480 | """ 481 | 482 | def __init__(self, num_classes, encoder_relu=False, decoder_relu=True): 483 | super().__init__() 484 | 485 | self.initial_block = InitialBlock(3, 16, relu=encoder_relu) 486 | 487 | # Stage 1 - Encoder 488 | self.downsample1_0 = DownsamplingBottleneck( 489 | 16, 490 | 64, 491 | return_indices=True, 492 | dropout_prob=0.01, 493 | relu=encoder_relu) 494 | self.regular1_1 = RegularBottleneck( 495 | 64, padding=1, dropout_prob=0.01, relu=encoder_relu) 496 | self.regular1_2 = RegularBottleneck( 497 | 64, padding=1, dropout_prob=0.01, relu=encoder_relu) 498 | self.regular1_3 = RegularBottleneck( 499 | 64, padding=1, dropout_prob=0.01, relu=encoder_relu) 500 | self.regular1_4 = RegularBottleneck( 501 | 64, padding=1, dropout_prob=0.01, relu=encoder_relu) 502 | 503 | # Stage 2 - Encoder 504 | self.downsample2_0 = DownsamplingBottleneck( 505 | 64, 506 | 128, 507 | return_indices=True, 508 | dropout_prob=0.1, 509 | relu=encoder_relu) 510 | self.regular2_1 = RegularBottleneck( 511 | 128, padding=1, dropout_prob=0.1, relu=encoder_relu) 512 | self.dilated2_2 = RegularBottleneck( 513 | 128, dilation=2, padding=2, dropout_prob=0.1, relu=encoder_relu) 514 | self.asymmetric2_3 = RegularBottleneck( 515 | 128, 516 | kernel_size=5, 517 | padding=2, 518 | asymmetric=True, 519 | dropout_prob=0.1, 520 | relu=encoder_relu) 521 | self.dilated2_4 = RegularBottleneck( 522 | 128, dilation=4, padding=4, dropout_prob=0.1, relu=encoder_relu) 523 | self.regular2_5 = RegularBottleneck( 524 | 128, padding=1, dropout_prob=0.1, relu=encoder_relu) 525 | self.dilated2_6 = RegularBottleneck( 526 | 128, dilation=8, padding=8, dropout_prob=0.1, relu=encoder_relu) 527 | self.asymmetric2_7 = RegularBottleneck( 528 | 128, 529 | kernel_size=5, 530 | asymmetric=True, 531 | padding=2, 532 | dropout_prob=0.1, 533 | relu=encoder_relu) 534 | self.dilated2_8 = RegularBottleneck( 535 | 128, dilation=16, padding=16, dropout_prob=0.1, relu=encoder_relu) 536 | 537 | # Stage 3 - Encoder 538 | self.regular3_0 = RegularBottleneck( 539 | 128, padding=1, dropout_prob=0.1, relu=encoder_relu) 540 | self.dilated3_1 = RegularBottleneck( 541 | 128, dilation=2, padding=2, dropout_prob=0.1, relu=encoder_relu) 542 | self.asymmetric3_2 = RegularBottleneck( 543 | 128, 544 | kernel_size=5, 545 | padding=2, 546 | asymmetric=True, 547 | dropout_prob=0.1, 548 | relu=encoder_relu) 549 | self.dilated3_3 = RegularBottleneck( 550 | 128, dilation=4, padding=4, dropout_prob=0.1, relu=encoder_relu) 551 | self.regular3_4 = RegularBottleneck( 552 | 128, padding=1, dropout_prob=0.1, relu=encoder_relu) 553 | self.dilated3_5 = RegularBottleneck( 554 | 128, dilation=8, padding=8, dropout_prob=0.1, relu=encoder_relu) 555 | self.asymmetric3_6 = RegularBottleneck( 556 | 128, 557 | kernel_size=5, 558 | asymmetric=True, 559 | padding=2, 560 | dropout_prob=0.1, 561 | relu=encoder_relu) 562 | self.dilated3_7 = RegularBottleneck( 563 | 128, dilation=16, padding=16, dropout_prob=0.1, relu=encoder_relu) 564 | 565 | # Stage 4 - Decoder 566 | self.upsample4_0 = UpsamplingBottleneck( 567 | 128, 64, dropout_prob=0.1, relu=decoder_relu) 568 | self.regular4_1 = RegularBottleneck( 569 | 64, padding=1, dropout_prob=0.1, relu=decoder_relu) 570 | self.regular4_2 = RegularBottleneck( 571 | 64, padding=1, dropout_prob=0.1, relu=decoder_relu) 572 | 573 | # Stage 5 - Decoder 574 | self.upsample5_0 = UpsamplingBottleneck( 575 | 64, 16, dropout_prob=0.1, relu=decoder_relu) 576 | self.regular5_1 = RegularBottleneck( 577 | 16, padding=1, dropout_prob=0.1, relu=decoder_relu) 
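        # Final full-resolution projection: a stride-2 transposed convolution
        # maps the 16 decoder channels to ``num_classes`` output channels.
        # A stride-2 transposed convolution admits two valid output sizes, so
        # the forward pass supplies ``output_size=input_size`` to recover the
        # exact input resolution.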
578 | self.transposed_conv = nn.ConvTranspose2d( 579 | 16, 580 | num_classes, 581 | kernel_size=3, 582 | stride=2, 583 | padding=1, 584 | bias=False) 585 | 586 | def forward(self, x): 587 | # Initial block 588 | input_size = x.size() 589 | x = self.initial_block(x) 590 | 591 | # Stage 1 - Encoder 592 | stage1_input_size = x.size() 593 | x, max_indices1_0 = self.downsample1_0(x) 594 | x = self.regular1_1(x) 595 | x = self.regular1_2(x) 596 | x = self.regular1_3(x) 597 | x = self.regular1_4(x) 598 | 599 | # Stage 2 - Encoder 600 | stage2_input_size = x.size() 601 | x, max_indices2_0 = self.downsample2_0(x) 602 | x = self.regular2_1(x) 603 | x = self.dilated2_2(x) 604 | x = self.asymmetric2_3(x) 605 | x = self.dilated2_4(x) 606 | x = self.regular2_5(x) 607 | x = self.dilated2_6(x) 608 | x = self.asymmetric2_7(x) 609 | x = self.dilated2_8(x) 610 | 611 | # Stage 3 - Encoder 612 | x = self.regular3_0(x) 613 | x = self.dilated3_1(x) 614 | x = self.asymmetric3_2(x) 615 | x = self.dilated3_3(x) 616 | x = self.regular3_4(x) 617 | x = self.dilated3_5(x) 618 | x = self.asymmetric3_6(x) 619 | x = self.dilated3_7(x) 620 | 621 | # Stage 4 - Decoder 622 | x = self.upsample4_0(x, max_indices2_0, output_size=stage2_input_size) 623 | x = self.regular4_1(x) 624 | x = self.regular4_2(x) 625 | 626 | # Stage 5 - Decoder 627 | x = self.upsample5_0(x, max_indices1_0, output_size=stage1_input_size) 628 | x = self.regular5_1(x) 629 | x = self.transposed_conv(x, output_size=input_size) 630 | 631 | return x 632 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import matplotlib.pyplot as plt 3 | 4 | class Test: 5 | 6 | def __init__(self, model, data_loader, criterion, device): 7 | self.model = model 8 | self.data_loader = data_loader 9 | self.criterion = criterion 10 | self.device = device 11 | 12 | def run_epoch(self, iteration_loss=False): 13 | 14 | self.model.eval() 15 | epoch_loss = 0.0 16 | for step, batch_data in enumerate(self.data_loader): 17 | # Get the inputs and labels 18 | inputs = batch_data[0].to(self.device) 19 | labels = batch_data[1].to(self.device) 20 | 21 | with torch.no_grad(): 22 | # Forward propagation 23 | mask = (inputs>0).float() 24 | outputs = self.model(inputs, mask) 25 | 26 | 27 | plt.figure() 28 | plt.imshow(inputs[0,0].cpu().detach().numpy()) 29 | plt.figure() 30 | plt.imshow(outputs[0,0].cpu().detach().numpy()) 31 | plt.figure() 32 | plt.imshow((outputs*mask)[0,0].cpu().detach().numpy()) 33 | plt.figure() 34 | plt.imshow(labels[0,0].cpu().detach().numpy()) 35 | plt.show() 36 | 37 | 38 | # Loss computation 39 | loss = (self.criterion(outputs, labels)*mask.detach()).sum()/mask.sum() 40 | 41 | # Keep track of loss for current epoch 42 | epoch_loss += loss.item() 43 | 44 | if iteration_loss: 45 | print("[Step: %d] Iteration loss: %.4f" % (step, loss.item())) 46 | 47 | return epoch_loss / len(self.data_loader) 48 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | class Train: 2 | 3 | def __init__(self, model, data_loader, optim, criterion, device): 4 | self.model = model 5 | self.data_loader = data_loader 6 | self.optim = optim 7 | self.criterion = criterion 8 | self.device = device 9 | 10 | def run_epoch(self, lr_updater, iteration_loss=False): 11 | """Runs an epoch of training. 
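
        Only pixels with a valid (non-zero) input value contribute to the
        loss: the per-pixel criterion output is multiplied by a binary mask
        and normalised by the number of valid pixels, i.e.
        ``loss = (criterion(outputs, labels) * mask).sum() / mask.sum()``.
        Note that ``lr_updater.step()`` is called after every optimizer step.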
12 | 13 | Keyword arguments: 14 | - iteration_loss (``bool``, optional): Prints loss at every step. 15 | 16 | Returns: 17 | - The epoch loss (float). 18 | 19 | """ 20 | self.model.train() 21 | epoch_loss = 0.0 22 | for step, batch_data in enumerate(self.data_loader): 23 | 24 | # Get the inputs and labels 25 | inputs = batch_data[0].to(self.device) 26 | labels = batch_data[1].to(self.device) 27 | 28 | # Forward propagation 29 | mask = (inputs>0).float() 30 | outputs = self.model(inputs, mask) 31 | 32 | # Loss computation 33 | loss = (self.criterion(outputs, labels)*mask.detach()).sum()/mask.sum() 34 | 35 | # Backpropagation 36 | self.optim.zero_grad() 37 | loss.backward() 38 | self.optim.step() 39 | lr_updater.step() 40 | 41 | # Keep track of loss for current epoch 42 | epoch_loss += loss.item() 43 | 44 | if iteration_loss: 45 | print("[Step: %d] Iteration loss: %.4f" % (step, loss.item())) 46 | 47 | return epoch_loss / len(self.data_loader) 48 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from PIL import Image 4 | import random 5 | from collections import OrderedDict 6 | from torchvision.transforms import ToPILImage 7 | import functional as F 8 | import numbers 9 | 10 | 11 | class RandomCrop(object): 12 | """Crop the given PIL Image at a random location. 13 | 14 | Args: 15 | size (sequence or int): Desired output size of the crop. If size is an 16 | int instead of sequence like (h, w), a square crop (size, size) is 17 | made. 18 | padding (int or sequence, optional): Optional padding on each border 19 | of the image. Default is 0, i.e no padding. If a sequence of length 20 | 4 is provided, it is used to pad left, top, right, bottom borders 21 | respectively. 22 | pad_if_needed (boolean): It will pad the image if smaller than the 23 | desired size to avoid raising an exception. 24 | """ 25 | 26 | def __init__(self, size, padding=0, pad_if_needed=False): 27 | if isinstance(size, numbers.Number): 28 | self.size = (int(size), int(size)) 29 | else: 30 | self.size = size 31 | self.padding = padding 32 | self.pad_if_needed = pad_if_needed 33 | 34 | @staticmethod 35 | def get_params(img, output_size): 36 | """Get parameters for ``crop`` for a random crop. 37 | 38 | Args: 39 | img (PIL Image): Image to be cropped. 40 | output_size (tuple): Expected output size of the crop. 41 | 42 | Returns: 43 | tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 44 | """ 45 | w, h = img.size 46 | th, tw = output_size 47 | if w == tw and h == th: 48 | return 0, 0, h, w 49 | 50 | i = random.randint(0, h - th) 51 | j = random.randint(0, w - tw) 52 | return i, j, th, tw 53 | 54 | def __call__(self, img, label): 55 | """ 56 | Args: 57 | img (PIL Image): Image to be cropped. 58 | 59 | Returns: 60 | PIL Image: Cropped image. 
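
        Note:
            Both ``img`` and ``label`` are cropped with the same parameters
            and converted with ``F.to_tensor``, so the method actually returns
            a tuple of two tensors rather than a single PIL Image.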
61 |         """
62 |         if self.padding > 0:
63 |             img = F.pad(img, self.padding)
64 |             label = F.pad(label, self.padding)
65 | 
66 |         # pad the width if needed (the label is padded identically so that
67 |         # the crop parameters computed from ``img`` remain valid for it)
68 |         if self.pad_if_needed and img.size[0] < self.size[1]:
69 |             pad = (int((1 + self.size[1] - img.size[0]) / 2), 0)
70 |             img, label = F.pad(img, pad), F.pad(label, pad)
71 |         # pad the height if needed
72 |         if self.pad_if_needed and img.size[1] < self.size[0]:
73 |             pad = (0, int((1 + self.size[0] - img.size[1]) / 2))
74 |             img, label = F.pad(img, pad), F.pad(label, pad)
75 | 
76 |         i, j, h, w = self.get_params(img, self.size)
77 | 
78 |         return F.to_tensor(F.crop(img, i, j, h, w)), F.to_tensor(F.crop(label, i, j, h, w))
79 | 
80 | class PILToLongTensor(object):
81 |     """Converts a ``PIL Image`` to a ``torch.LongTensor``.
82 | 
83 |     Code adapted from: http://pytorch.org/docs/master/torchvision/transforms.html?highlight=totensor
84 | 
85 |     """
86 | 
87 |     def __call__(self, pic):
88 |         """Performs the conversion from a ``PIL Image`` to a ``torch.LongTensor``.
89 | 
90 |         Keyword arguments:
91 |         - pic (``PIL.Image`` or ``numpy.ndarray``): the image to convert to a ``torch.LongTensor``
92 | 
93 |         Returns:
94 |         A ``torch.LongTensor``.
95 | 
96 |         """
97 |         if not isinstance(pic, (Image.Image, np.ndarray)):
98 |             raise TypeError("pic should be PIL Image or ndarray. Got {}".format(
99 |                 type(pic)))
100 | 
101 |         # handle numpy array
102 |         if isinstance(pic, np.ndarray):
103 |             img = torch.from_numpy(pic.transpose((2, 0, 1)))
104 |             # backward compatibility
105 |             return img.long()
106 | 
107 |         # Convert PIL image to ByteTensor
108 |         img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
109 | 
110 |         # Reshape tensor
111 |         nchannel = len(pic.mode)
112 |         img = img.view(pic.size[1], pic.size[0], nchannel)
113 | 
114 |         # Convert to long and squeeze the channels
115 |         return img.transpose(0, 1).transpose(0,
116 |                                              2).contiguous().long().squeeze_()
117 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | import os
6 | 
7 | 
8 | def batch_transform(batch, transform):
9 |     transf_slices = [transform(tensor) for tensor in torch.unbind(batch)]
10 |     return torch.stack(transf_slices)
11 | 
12 | 
13 | def save_checkpoint(model, optimizer, epoch, loss, args):
14 | 
15 |     name = args.name
16 |     save_dir = args.save_dir
17 | 
18 |     assert os.path.isdir(
19 |         save_dir), "The directory \"{0}\" doesn't exist.".format(save_dir)
20 | 
21 |     # Save model
22 |     model_path = os.path.join(save_dir, name)
23 |     checkpoint = {
24 |         'epoch': epoch,
25 |         'loss': loss,
26 |         'state_dict': model.state_dict(),
27 |         'optimizer': optimizer.state_dict()
28 |     }
29 |     torch.save(checkpoint, model_path)
30 | 
31 |     # Save arguments
32 |     summary_filename = os.path.join(save_dir, name + '_summary.txt')
33 |     with open(summary_filename, 'w') as summary_file:
34 |         sorted_args = sorted(vars(args))
35 |         summary_file.write("ARGUMENTS\n")
36 |         for arg in sorted_args:
37 |             arg_str = "{0}: {1}\n".format(arg, getattr(args, arg))
38 |             summary_file.write(arg_str)
39 | 
40 |         summary_file.write("\nBEST VALIDATION\n")
41 |         summary_file.write("Epoch: {0}\n".format(epoch))
42 |         summary_file.write("Loss: {0}\n".format(loss))
43 | 
44 | 
45 | def load_checkpoint(model, optimizer, folder_dir, filename):
46 | 
47 |     assert os.path.isdir(
48 |         folder_dir), "The directory \"{0}\" doesn't exist.".format(folder_dir)
49 | 
50 |     # Build the path to the stored model and make sure the file exists
51 |     model_path = os.path.join(folder_dir, filename)
52 |     assert os.path.isfile(
53 |         model_path), "The model file \"{0}\" doesn't exist.".format(filename)
54 | 
55 |     # Load the stored model parameters into the model instance
56 |     checkpoint = torch.load(model_path)
57 |     model.load_state_dict(checkpoint['state_dict'])
58 |     optimizer.load_state_dict(checkpoint['optimizer'])
59 |     epoch = checkpoint['epoch']
60 |     loss = checkpoint['loss']
61 | 
62 |     return model, optimizer, epoch, loss
63 | 
--------------------------------------------------------------------------------
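
The two helpers above are meant to be used as a pair: ``save_checkpoint`` during training and ``load_checkpoint`` when resuming or testing. The sketch below is illustrative only; it builds a tiny ``argparse.Namespace`` by hand instead of calling ``args.get_arguments()``, uses ENet as a stand-in for whatever model main.py constructs, and assumes models/enet.py is importable as ``models.enet``:

    import os
    from argparse import Namespace

    import torch
    from models.enet import ENet
    from utils import save_checkpoint, load_checkpoint

    # Hypothetical stand-ins for the values normally produced by args.py/main.py
    args = Namespace(name='ENet', save_dir='save')
    os.makedirs(args.save_dir, exist_ok=True)

    model = ENet(num_classes=1)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Persist the model, optimizer, epoch and validation loss ...
    save_checkpoint(model, optimizer, epoch=10, loss=0.123, args=args)

    # ... and restore them later, e.g. when --resume is passed
    model, optimizer, epoch, loss = load_checkpoint(
        model, optimizer, args.save_dir, args.name)
    print(epoch, loss)  # 10 0.123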