├── requirements_dev.txt ├── save ├── ENet_CamVid │ ├── ENet │ └── ENet_summary.txt ├── ENet_Cityscapes │ ├── ENet │ └── ENet_summary.txt └── README.md ├── data ├── __init__.py ├── README.md ├── camvid.py ├── utils.py └── cityscapes.py ├── metric ├── __init__.py ├── metric.py ├── confusionmatrix.py └── iou.py ├── requirements.txt ├── Dockerfile ├── LICENSE ├── .gitignore ├── test.py ├── train.py ├── transforms.py ├── args.py ├── utils.py ├── README.md ├── main.py └── models └── enet.py /requirements_dev.txt: -------------------------------------------------------------------------------- 1 | -r requirements.txt 2 | flake8 3 | black 4 | -------------------------------------------------------------------------------- /save/ENet_CamVid/ENet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidtvs/PyTorch-ENet/HEAD/save/ENet_CamVid/ENet -------------------------------------------------------------------------------- /save/ENet_Cityscapes/ENet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidtvs/PyTorch-ENet/HEAD/save/ENet_Cityscapes/ENet -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .camvid import CamVid 2 | from .cityscapes import Cityscapes 3 | 4 | __all__ = ['CamVid', 'Cityscapes'] 5 | -------------------------------------------------------------------------------- /metric/__init__.py: -------------------------------------------------------------------------------- 1 | from .confusionmatrix import ConfusionMatrix 2 | from .iou import IoU 3 | from .metric import Metric 4 | 5 | __all__ = ['ConfusionMatrix', 'IoU', 'Metric'] 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cycler>=0.10.0 2 | kiwisolver>=1.0.1 3 | matplotlib>=3.0.2 4 | numpy>=1.16.0 5 | Pillow>=6.2.0 6 | pyparsing>=2.3.1 7 | python-dateutil>=2.7.5 8 | pytz>=2018.9 9 | six>=1.12.0 10 | torch>=1.1.0 11 | torchvision>=0.2.2 12 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Based on a PyTorch docker image that matches the minimum requirements: PyTorch 1.1.0 2 | FROM pytorch/pytorch:1.1.0-cuda10.0-cudnn7.5-runtime 3 | 4 | RUN python -m pip install --upgrade pip 5 | 6 | COPY . /enet 7 | WORKDIR /enet 8 | RUN pip install -r requirements.txt 9 | -------------------------------------------------------------------------------- /metric/metric.py: -------------------------------------------------------------------------------- 1 | class Metric(object): 2 | """Base class for all metrics. 
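    Subclasses are expected to override ``reset``, ``add``, and ``value``.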
3 | 4 | From: https://github.com/pytorch/tnt/blob/master/torchnet/meter/meter.py 5 | """ 6 | def reset(self): 7 | pass 8 | 9 | def add(self): 10 | pass 11 | 12 | def value(self): 13 | pass 14 | -------------------------------------------------------------------------------- /save/ENet_CamVid/ENet_summary.txt: -------------------------------------------------------------------------------- 1 | ARGUMENTS 2 | batch_size: 10 3 | dataset: camvid 4 | dataset_dir: ../CamVid/ 5 | device: cuda 6 | epochs: 300 7 | height: 360 8 | ignore_unlabeled: True 9 | imshow_batch: False 10 | learning_rate: 0.0005 11 | lr_decay: 0.1 12 | lr_decay_epochs: 100 13 | mode: full 14 | name: ENet 15 | print_step: False 16 | resume: False 17 | save_dir: save/new_camvid/ 18 | weighing: ENet 19 | weight_decay: 0.0002 20 | width: 480 21 | workers: 4 22 | 23 | BEST VALIDATION 24 | Epoch: 280 25 | Mean IoU: 0.6518655444842216 26 | -------------------------------------------------------------------------------- /save/ENet_Cityscapes/ENet_summary.txt: -------------------------------------------------------------------------------- 1 | ARGUMENTS 2 | batch_size: 4 3 | dataset: cityscapes 4 | dataset_dir: ../Cityscapes/1024x512/ 5 | device: cuda 6 | epochs: 300 7 | height: 512 8 | ignore_unlabeled: True 9 | imshow_batch: False 10 | learning_rate: 0.0005 11 | lr_decay: 0.1 12 | lr_decay_epochs: 100 13 | mode: train 14 | name: ENet 15 | print_step: False 16 | resume: False 17 | save_dir: save/cityscapes_2/ 18 | weighing: ENet 19 | weight_decay: 0.0002 20 | width: 1024 21 | workers: 4 22 | 23 | BEST VALIDATION 24 | Epoch: 250 25 | Mean IoU: 0.5949690267526815 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 davidtvs 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Supported datasets 2 | 3 | - CamVid 4 | - CityScapes 5 | 6 | Note: When referring to the number of classes, the void/unlabeled class is excluded. 7 | 8 | ## CamVid Dataset 9 | 10 | The Cambridge-driving Labeled Video Database (CamVid) is a collection of over ten minutes of high-quality 30Hz footage with object class semantic labels at 1Hz and in part, 15Hz. 
Each pixel is associated with one of 32 classes. 11 | 12 | The CamVid dataset supported here is a 12 class version developed by the authors of SegNet. [Download link here](https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid). For actual training, an 11 class version is used - the "road marking" class is combined with the "road" class. 13 | 14 | More detailed information about the CamVid dataset can be found [here](http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/) and on the [SegNet GitHub repository](https://github.com/alexgkendall/SegNet-Tutorial). 15 | 16 | ## Cityscapes 17 | 18 | Cityscapes is a set of stereo video sequences recorded in streets from 50 different cities with 34 different classes. There are 5000 images with fine annotations and 20000 images coarsely annotated. 19 | 20 | The version supported here is the finely annotated one with 19 classes. 21 | 22 | For more detailed information see the official [website](https://www.cityscapes-dataset.com/) and [repository](https://github.com/mcordts/cityscapesScripts). 23 | 24 | The dataset can be downloaded from https://www.cityscapes-dataset.com/downloads/. At this time, a registration is required to download the data. 25 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | .static_storage/ 57 | .media/ 58 | local_settings.py 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Environments 86 | .env 87 | .venv 88 | env/ 89 | venv/ 90 | ENV/ 91 | env.bak/ 92 | venv.bak/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # mkdocs documentation 102 | /site 103 | 104 | # mypy 105 | .mypy_cache/ 106 | 107 | # Sublime text project 108 | *.sublime-workspace 109 | *.sublime-project 110 | 111 | # VSCode 112 | .vscode/ 113 | *.code-workspace 114 | .devcontainer 115 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Test: 5 | """Tests the ``model`` on the specified test dataset using the 6 | data loader, and loss criterion. 7 | 8 | Keyword arguments: 9 | - model (``nn.Module``): the model instance to test. 10 | - data_loader (``Dataloader``): Provides single or multi-process 11 | iterators over the dataset. 12 | - criterion (``Optimizer``): The loss criterion. 13 | - metric (```Metric``): An instance specifying the metric to return. 14 | - device (``torch.device``): An object representing the device on which 15 | tensors are allocated. 16 | 17 | """ 18 | 19 | def __init__(self, model, data_loader, criterion, metric, device): 20 | self.model = model 21 | self.data_loader = data_loader 22 | self.criterion = criterion 23 | self.metric = metric 24 | self.device = device 25 | 26 | def run_epoch(self, iteration_loss=False): 27 | """Runs an epoch of validation. 28 | 29 | Keyword arguments: 30 | - iteration_loss (``bool``, optional): Prints loss at every step. 
31 | 32 | Returns: 33 | - The epoch loss (float), and the values of the specified metrics 34 | 35 | """ 36 | self.model.eval() 37 | epoch_loss = 0.0 38 | self.metric.reset() 39 | for step, batch_data in enumerate(self.data_loader): 40 | # Get the inputs and labels 41 | inputs = batch_data[0].to(self.device) 42 | labels = batch_data[1].to(self.device) 43 | 44 | with torch.no_grad(): 45 | # Forward propagation 46 | outputs = self.model(inputs) 47 | 48 | # Loss computation 49 | loss = self.criterion(outputs, labels) 50 | 51 | # Keep track of loss for current epoch 52 | epoch_loss += loss.item() 53 | 54 | # Keep track of evaluation the metric 55 | self.metric.add(outputs.detach(), labels.detach()) 56 | 57 | if iteration_loss: 58 | print("[Step: %d] Iteration loss: %.4f" % (step, loss.item())) 59 | 60 | return epoch_loss / len(self.data_loader), self.metric.value() 61 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | class Train: 2 | """Performs the training of ``model`` given a training dataset data 3 | loader, the optimizer, and the loss criterion. 4 | 5 | Keyword arguments: 6 | - model (``nn.Module``): the model instance to train. 7 | - data_loader (``Dataloader``): Provides single or multi-process 8 | iterators over the dataset. 9 | - optim (``Optimizer``): The optimization algorithm. 10 | - criterion (``Optimizer``): The loss criterion. 11 | - metric (```Metric``): An instance specifying the metric to return. 12 | - device (``torch.device``): An object representing the device on which 13 | tensors are allocated. 14 | 15 | """ 16 | 17 | def __init__(self, model, data_loader, optim, criterion, metric, device): 18 | self.model = model 19 | self.data_loader = data_loader 20 | self.optim = optim 21 | self.criterion = criterion 22 | self.metric = metric 23 | self.device = device 24 | 25 | def run_epoch(self, iteration_loss=False): 26 | """Runs an epoch of training. 27 | 28 | Keyword arguments: 29 | - iteration_loss (``bool``, optional): Prints loss at every step. 30 | 31 | Returns: 32 | - The epoch loss (float). 
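        - The values of the specified metric, i.e. ``self.metric.value()``.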
33 | 34 | """ 35 | self.model.train() 36 | epoch_loss = 0.0 37 | self.metric.reset() 38 | for step, batch_data in enumerate(self.data_loader): 39 | # Get the inputs and labels 40 | inputs = batch_data[0].to(self.device) 41 | labels = batch_data[1].to(self.device) 42 | 43 | # Forward propagation 44 | outputs = self.model(inputs) 45 | 46 | # Loss computation 47 | loss = self.criterion(outputs, labels) 48 | 49 | # Backpropagation 50 | self.optim.zero_grad() 51 | loss.backward() 52 | self.optim.step() 53 | 54 | # Keep track of loss for current epoch 55 | epoch_loss += loss.item() 56 | 57 | # Keep track of the evaluation metric 58 | self.metric.add(outputs.detach(), labels.detach()) 59 | 60 | if iteration_loss: 61 | print("[Step: %d] Iteration loss: %.4f" % (step, loss.item())) 62 | 63 | return epoch_loss / len(self.data_loader), self.metric.value() 64 | -------------------------------------------------------------------------------- /save/README.md: -------------------------------------------------------------------------------- 1 | # Pre-trained models 2 | 3 | | Dataset | Classes 1 | Input resolution | Batch size | Epochs | Mean IoU (%) | GPU memory (GiB) | Training time (hours)2 | 4 | | :------------------------------------------------------------------: | :------------------: | :--------------: | :--------: | :----: | :---------------: | :--------------: | :-------------------------------: | 5 | | [CamVid](http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/) | 11 | 480x360 | 10 | 300 | 52.13 | 4.2 | 1 | 6 | | [Cityscapes](https://www.cityscapes-dataset.com/) | 19 | 1024x512 | 4 | 300 | 59.54 | 5.4 | 20 | 7 | 8 | ## Per-class IoU: CamVid3 9 | 10 | | | Sky | Building | Pole | Road | Pavement | Tree | Sign Symbol | Fence | Car | Pedestrian | Bicyclist | 11 | | :-----: | :---: | :------: | :---: | :---: | :------: | :---: | :---------: | :---: | :---: | :--------: | :-------: | 12 | | IoU (%) | 90.2 | 68.6 | 22.6 | 91.5 | 73.2 | 63.6 | 19.3 | 16.7 | 65.1 | 27.2 | 35.0 | 13 | 14 | ## Per-class IoU: Cityscapes4 15 | 16 | | | Road | Sidewalk | Building | Wall | Fence | Pole | Traffic light | Traffic Sign | Vegetation | Terrain | Sky | Person | Rider | Car | Truck | Bus | Train | Motorcycle | Bicycle | 17 | | :-----: | :---: | :------: | :------: | :---: | :---: | :---: | :-----------: | :----------: | :--------: | :-----: | :---: | :----: | :---: | :---: | :---: | :---: | :---: | :--------: | :-----: | 18 | | IoU (%) | 96.1 | 73.3 | 85.8 | 44.1 | 40.5 | 45.3 | 42.5 | 53.9 | 87.9 | 53.5 | 90.1 | 62.3 | 44.3 | 87.6 | 46.6 | 58.2 | 34.8 | 25.8 | 57.9 | 19 | 20 | 1 When referring to the number of classes, the void/unlabeled class is always excluded.
21 | 2 These numbers are for reference only; changes to the implementation, datasets, or hardware can lead to very different results. Reference hardware: Nvidia GTX 1070 and an AMD Ryzen 5 3600 (3.6 GHz). You can also train for around 100 epochs and obtain a similar mean IoU (± 2%).
22 | 3 Test set.
23 | 4 Validation set. 24 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from PIL import Image 4 | from collections import OrderedDict 5 | from torchvision.transforms import ToPILImage 6 | 7 | 8 | class PILToLongTensor(object): 9 | """Converts a ``PIL Image`` to a ``torch.LongTensor``. 10 | 11 | Code adapted from: http://pytorch.org/docs/master/torchvision/transforms.html?highlight=totensor 12 | 13 | """ 14 | 15 | def __call__(self, pic): 16 | """Performs the conversion from a ``PIL Image`` to a ``torch.LongTensor``. 17 | 18 | Keyword arguments: 19 | - pic (``PIL.Image``): the image to convert to ``torch.LongTensor`` 20 | 21 | Returns: 22 | A ``torch.LongTensor``. 23 | 24 | """ 25 | if not isinstance(pic, Image.Image): 26 | raise TypeError("pic should be PIL Image. Got {}".format( 27 | type(pic))) 28 | 29 | # handle numpy array 30 | if isinstance(pic, np.ndarray): 31 | img = torch.from_numpy(pic.transpose((2, 0, 1))) 32 | # backward compatibility 33 | return img.long() 34 | 35 | # Convert PIL image to ByteTensor 36 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 37 | 38 | # Reshape tensor 39 | nchannel = len(pic.mode) 40 | img = img.view(pic.size[1], pic.size[0], nchannel) 41 | 42 | # Convert to long and squeeze the channels 43 | return img.transpose(0, 1).transpose(0, 44 | 2).contiguous().long().squeeze_() 45 | 46 | 47 | class LongTensorToRGBPIL(object): 48 | """Converts a ``torch.LongTensor`` to a ``PIL image``. 49 | 50 | The input is a ``torch.LongTensor`` where each pixel's value identifies the 51 | class. 52 | 53 | Keyword arguments: 54 | - rgb_encoding (``OrderedDict``): An ``OrderedDict`` that relates pixel 55 | values, class names, and class colors. 56 | 57 | """ 58 | def __init__(self, rgb_encoding): 59 | self.rgb_encoding = rgb_encoding 60 | 61 | def __call__(self, tensor): 62 | """Performs the conversion from ``torch.LongTensor`` to a ``PIL image`` 63 | 64 | Keyword arguments: 65 | - tensor (``torch.LongTensor``): the tensor to convert 66 | 67 | Returns: 68 | A ``PIL.Image``. 69 | 70 | """ 71 | # Check if label_tensor is a LongTensor 72 | if not isinstance(tensor, torch.LongTensor): 73 | raise TypeError("label_tensor should be torch.LongTensor. Got {}" 74 | .format(type(tensor))) 75 | # Check if encoding is a ordered dictionary 76 | if not isinstance(self.rgb_encoding, OrderedDict): 77 | raise TypeError("encoding should be an OrderedDict. 
Got {}".format( 78 | type(self.rgb_encoding))) 79 | 80 | # label_tensor might be an image without a channel dimension, in this 81 | # case unsqueeze it 82 | if len(tensor.size()) == 2: 83 | tensor.unsqueeze_(0) 84 | 85 | color_tensor = torch.ByteTensor(3, tensor.size(1), tensor.size(2)) 86 | 87 | for index, (class_name, color) in enumerate(self.rgb_encoding.items()): 88 | # Get a mask of elements equal to index 89 | mask = torch.eq(tensor, index).squeeze_() 90 | # Fill color_tensor with corresponding colors 91 | for channel, color_value in enumerate(color): 92 | color_tensor[channel].masked_fill_(mask, color_value) 93 | 94 | return ToPILImage()(color_tensor) 95 | -------------------------------------------------------------------------------- /metric/confusionmatrix.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from metric import metric 4 | 5 | 6 | class ConfusionMatrix(metric.Metric): 7 | """Constructs a confusion matrix for a multi-class classification problems. 8 | 9 | Does not support multi-label, multi-class problems. 10 | 11 | Keyword arguments: 12 | - num_classes (int): number of classes in the classification problem. 13 | - normalized (boolean, optional): Determines whether or not the confusion 14 | matrix is normalized or not. Default: False. 15 | 16 | Modified from: https://github.com/pytorch/tnt/blob/master/torchnet/meter/confusionmeter.py 17 | """ 18 | 19 | def __init__(self, num_classes, normalized=False): 20 | super().__init__() 21 | 22 | self.conf = np.ndarray((num_classes, num_classes), dtype=np.int64) 23 | self.normalized = normalized 24 | self.num_classes = num_classes 25 | self.reset() 26 | 27 | def reset(self): 28 | self.conf.fill(0) 29 | 30 | def add(self, predicted, target): 31 | """Computes the confusion matrix 32 | 33 | The shape of the confusion matrix is K x K, where K is the number 34 | of classes. 35 | 36 | Keyword arguments: 37 | - predicted (Tensor or numpy.ndarray): Can be an N x K tensor/array of 38 | predicted scores obtained from the model for N examples and K classes, 39 | or an N-tensor/array of integer values between 0 and K-1. 40 | - target (Tensor or numpy.ndarray): Can be an N x K tensor/array of 41 | ground-truth classes for N examples and K classes, or an N-tensor/array 42 | of integer values between 0 and K-1. 
43 | 44 | """ 45 | # If target and/or predicted are tensors, convert them to numpy arrays 46 | if torch.is_tensor(predicted): 47 | predicted = predicted.cpu().numpy() 48 | if torch.is_tensor(target): 49 | target = target.cpu().numpy() 50 | 51 | assert predicted.shape[0] == target.shape[0], \ 52 | 'number of targets and predicted outputs do not match' 53 | 54 | if np.ndim(predicted) != 1: 55 | assert predicted.shape[1] == self.num_classes, \ 56 | 'number of predictions does not match size of confusion matrix' 57 | predicted = np.argmax(predicted, 1) 58 | else: 59 | assert (predicted.max() < self.num_classes) and (predicted.min() >= 0), \ 60 | 'predicted values are not between 0 and k-1' 61 | 62 | if np.ndim(target) != 1: 63 | assert target.shape[1] == self.num_classes, \ 64 | 'Onehot target does not match size of confusion matrix' 65 | assert (target >= 0).all() and (target <= 1).all(), \ 66 | 'in one-hot encoding, target values should be 0 or 1' 67 | assert (target.sum(1) == 1).all(), \ 68 | 'multi-label setting is not supported' 69 | target = np.argmax(target, 1) 70 | else: 71 | assert (target.max() < self.num_classes) and (target.min() >= 0), \ 72 | 'target values are not between 0 and k-1' 73 | 74 | # hack for bincounting 2 arrays together 75 | x = predicted + self.num_classes * target 76 | bincount_2d = np.bincount( 77 | x.astype(np.int64), minlength=self.num_classes**2) 78 | assert bincount_2d.size == self.num_classes**2 79 | conf = bincount_2d.reshape((self.num_classes, self.num_classes)) 80 | 81 | self.conf += conf 82 | 83 | def value(self): 84 | """ 85 | Returns: 86 | Confustion matrix of K rows and K columns, where rows corresponds 87 | to ground-truth targets and columns corresponds to predicted 88 | targets. 89 | """ 90 | if self.normalized: 91 | conf = self.conf.astype(np.float32) 92 | return conf / conf.sum(1).clip(min=1e-12)[:, None] 93 | else: 94 | return self.conf 95 | -------------------------------------------------------------------------------- /metric/iou.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from metric import metric 4 | from metric.confusionmatrix import ConfusionMatrix 5 | 6 | 7 | class IoU(metric.Metric): 8 | """Computes the intersection over union (IoU) per class and corresponding 9 | mean (mIoU). 10 | 11 | Intersection over union (IoU) is a common evaluation metric for semantic 12 | segmentation. The predictions are first accumulated in a confusion matrix 13 | and the IoU is computed from it as follows: 14 | 15 | IoU = true_positive / (true_positive + false_positive + false_negative). 16 | 17 | Keyword arguments: 18 | - num_classes (int): number of classes in the classification problem 19 | - normalized (boolean, optional): Determines whether or not the confusion 20 | matrix is normalized or not. Default: False. 21 | - ignore_index (int or iterable, optional): Index of the classes to ignore 22 | when computing the IoU. Can be an int, or any iterable of ints. 
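    Default: None.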
23 | """ 24 | 25 | def __init__(self, num_classes, normalized=False, ignore_index=None): 26 | super().__init__() 27 | self.conf_metric = ConfusionMatrix(num_classes, normalized) 28 | 29 | if ignore_index is None: 30 | self.ignore_index = None 31 | elif isinstance(ignore_index, int): 32 | self.ignore_index = (ignore_index,) 33 | else: 34 | try: 35 | self.ignore_index = tuple(ignore_index) 36 | except TypeError: 37 | raise ValueError("'ignore_index' must be an int or iterable") 38 | 39 | def reset(self): 40 | self.conf_metric.reset() 41 | 42 | def add(self, predicted, target): 43 | """Adds the predicted and target pair to the IoU metric. 44 | 45 | Keyword arguments: 46 | - predicted (Tensor): Can be a (N, K, H, W) tensor of 47 | predicted scores obtained from the model for N examples and K classes, 48 | or (N, H, W) tensor of integer values between 0 and K-1. 49 | - target (Tensor): Can be a (N, K, H, W) tensor of 50 | target scores for N examples and K classes, or (N, H, W) tensor of 51 | integer values between 0 and K-1. 52 | 53 | """ 54 | # Dimensions check 55 | assert predicted.size(0) == target.size(0), \ 56 | 'number of targets and predicted outputs do not match' 57 | assert predicted.dim() == 3 or predicted.dim() == 4, \ 58 | "predictions must be of dimension (N, H, W) or (N, K, H, W)" 59 | assert target.dim() == 3 or target.dim() == 4, \ 60 | "targets must be of dimension (N, H, W) or (N, K, H, W)" 61 | 62 | # If the tensor is in categorical format convert it to integer format 63 | if predicted.dim() == 4: 64 | _, predicted = predicted.max(1) 65 | if target.dim() == 4: 66 | _, target = target.max(1) 67 | 68 | self.conf_metric.add(predicted.view(-1), target.view(-1)) 69 | 70 | def value(self): 71 | """Computes the IoU and mean IoU. 72 | 73 | The mean computation ignores NaN elements of the IoU array. 74 | 75 | Returns: 76 | Tuple: (IoU, mIoU). The first output is the per class IoU, 77 | for K classes it's numpy.ndarray with K elements. The second output, 78 | is the mean IoU. 79 | """ 80 | conf_matrix = self.conf_metric.value() 81 | if self.ignore_index is not None: 82 | conf_matrix[:, self.ignore_index] = 0 83 | conf_matrix[self.ignore_index, :] = 0 84 | true_positive = np.diag(conf_matrix) 85 | false_positive = np.sum(conf_matrix, 0) - true_positive 86 | false_negative = np.sum(conf_matrix, 1) - true_positive 87 | 88 | # Just in case we get a division by 0, ignore/hide the error 89 | with np.errstate(divide='ignore', invalid='ignore'): 90 | iou = true_positive / (true_positive + false_positive + false_negative) 91 | 92 | return iou, np.nanmean(iou) 93 | -------------------------------------------------------------------------------- /args.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | 4 | def get_arguments(): 5 | """Defines command-line arguments, and parses them. 6 | 7 | """ 8 | parser = ArgumentParser() 9 | 10 | # Execution mode 11 | parser.add_argument( 12 | "--mode", 13 | "-m", 14 | choices=['train', 'test', 'full'], 15 | default='train', 16 | help=("train: performs training and validation; test: tests the model " 17 | "found in \"--save-dir\" with name \"--name\" on \"--dataset\"; " 18 | "full: combines train and test modes. 
Default: train")) 19 | parser.add_argument( 20 | "--resume", 21 | action='store_true', 22 | help=("The model found in \"--save-dir/--name/\" and filename " 23 | "\"--name.h5\" is loaded.")) 24 | 25 | # Hyperparameters 26 | parser.add_argument( 27 | "--batch-size", 28 | "-b", 29 | type=int, 30 | default=10, 31 | help="The batch size. Default: 10") 32 | parser.add_argument( 33 | "--epochs", 34 | type=int, 35 | default=300, 36 | help="Number of training epochs. Default: 300") 37 | parser.add_argument( 38 | "--learning-rate", 39 | "-lr", 40 | type=float, 41 | default=5e-4, 42 | help="The learning rate. Default: 5e-4") 43 | parser.add_argument( 44 | "--lr-decay", 45 | type=float, 46 | default=0.1, 47 | help="The learning rate decay factor. Default: 0.5") 48 | parser.add_argument( 49 | "--lr-decay-epochs", 50 | type=int, 51 | default=100, 52 | help="The number of epochs before adjusting the learning rate. " 53 | "Default: 100") 54 | parser.add_argument( 55 | "--weight-decay", 56 | "-wd", 57 | type=float, 58 | default=2e-4, 59 | help="L2 regularization factor. Default: 2e-4") 60 | 61 | # Dataset 62 | parser.add_argument( 63 | "--dataset", 64 | choices=['camvid', 'cityscapes'], 65 | default='camvid', 66 | help="Dataset to use. Default: camvid") 67 | parser.add_argument( 68 | "--dataset-dir", 69 | type=str, 70 | default="data/CamVid", 71 | help="Path to the root directory of the selected dataset. " 72 | "Default: data/CamVid") 73 | parser.add_argument( 74 | "--height", 75 | type=int, 76 | default=360, 77 | help="The image height. Default: 360") 78 | parser.add_argument( 79 | "--width", 80 | type=int, 81 | default=480, 82 | help="The image width. Default: 480") 83 | parser.add_argument( 84 | "--weighing", 85 | choices=['enet', 'mfb', 'none'], 86 | default='ENet', 87 | help="The class weighing technique to apply to the dataset. " 88 | "Default: enet") 89 | parser.add_argument( 90 | "--with-unlabeled", 91 | dest='ignore_unlabeled', 92 | action='store_false', 93 | help="The unlabeled class is not ignored.") 94 | 95 | # Settings 96 | parser.add_argument( 97 | "--workers", 98 | type=int, 99 | default=4, 100 | help="Number of subprocesses to use for data loading. Default: 4") 101 | parser.add_argument( 102 | "--print-step", 103 | action='store_true', 104 | help="Print loss every step") 105 | parser.add_argument( 106 | "--imshow-batch", 107 | action='store_true', 108 | help=("Displays batch images when loading the dataset and making " 109 | "predictions.")) 110 | parser.add_argument( 111 | "--device", 112 | default='cuda', 113 | help="Device on which the network will be trained. Default: cuda") 114 | 115 | # Storage settings 116 | parser.add_argument( 117 | "--name", 118 | type=str, 119 | default='ENet', 120 | help="Name given to the model when saving. Default: ENet") 121 | parser.add_argument( 122 | "--save-dir", 123 | type=str, 124 | default='save', 125 | help="The directory where models are saved. Default: save") 126 | 127 | return parser.parse_args() 128 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import os 6 | 7 | 8 | def batch_transform(batch, transform): 9 | """Applies a transform to a batch of samples. 
10 | 11 | Keyword arguments: 12 | - batch (): a batch os samples 13 | - transform (callable): A function/transform to apply to ``batch`` 14 | 15 | """ 16 | 17 | # Convert the single channel label to RGB in tensor form 18 | # 1. torch.unbind removes the 0-dimension of "labels" and returns a tuple of 19 | # all slices along that dimension 20 | # 2. the transform is applied to each slice 21 | transf_slices = [transform(tensor) for tensor in torch.unbind(batch)] 22 | 23 | return torch.stack(transf_slices) 24 | 25 | 26 | def imshow_batch(images, labels): 27 | """Displays two grids of images. The top grid displays ``images`` 28 | and the bottom grid ``labels`` 29 | 30 | Keyword arguments: 31 | - images (``Tensor``): a 4D mini-batch tensor of shape 32 | (B, C, H, W) 33 | - labels (``Tensor``): a 4D mini-batch tensor of shape 34 | (B, C, H, W) 35 | 36 | """ 37 | 38 | # Make a grid with the images and labels and convert it to numpy 39 | images = torchvision.utils.make_grid(images).numpy() 40 | labels = torchvision.utils.make_grid(labels).numpy() 41 | 42 | fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 7)) 43 | ax1.imshow(np.transpose(images, (1, 2, 0))) 44 | ax2.imshow(np.transpose(labels, (1, 2, 0))) 45 | 46 | plt.show() 47 | 48 | 49 | def save_checkpoint(model, optimizer, epoch, miou, args): 50 | """Saves the model in a specified directory with a specified name.save 51 | 52 | Keyword arguments: 53 | - model (``nn.Module``): The model to save. 54 | - optimizer (``torch.optim``): The optimizer state to save. 55 | - epoch (``int``): The current epoch for the model. 56 | - miou (``float``): The mean IoU obtained by the model. 57 | - args (``ArgumentParser``): An instance of ArgumentParser which contains 58 | the arguments used to train ``model``. The arguments are written to a text 59 | file in ``args.save_dir`` named "``args.name``_args.txt". 60 | 61 | """ 62 | name = args.name 63 | save_dir = args.save_dir 64 | 65 | assert os.path.isdir( 66 | save_dir), "The directory \"{0}\" doesn't exist.".format(save_dir) 67 | 68 | # Save model 69 | model_path = os.path.join(save_dir, name) 70 | checkpoint = { 71 | 'epoch': epoch, 72 | 'miou': miou, 73 | 'state_dict': model.state_dict(), 74 | 'optimizer': optimizer.state_dict() 75 | } 76 | torch.save(checkpoint, model_path) 77 | 78 | # Save arguments 79 | summary_filename = os.path.join(save_dir, name + '_summary.txt') 80 | with open(summary_filename, 'w') as summary_file: 81 | sorted_args = sorted(vars(args)) 82 | summary_file.write("ARGUMENTS\n") 83 | for arg in sorted_args: 84 | arg_str = "{0}: {1}\n".format(arg, getattr(args, arg)) 85 | summary_file.write(arg_str) 86 | 87 | summary_file.write("\nBEST VALIDATION\n") 88 | summary_file.write("Epoch: {0}\n". format(epoch)) 89 | summary_file.write("Mean IoU: {0}\n". format(miou)) 90 | 91 | 92 | def load_checkpoint(model, optimizer, folder_dir, filename): 93 | """Saves the model in a specified directory with a specified name.save 94 | 95 | Keyword arguments: 96 | - model (``nn.Module``): The stored model state is copied to this model 97 | instance. 98 | - optimizer (``torch.optim``): The stored optimizer state is copied to this 99 | optimizer instance. 100 | - folder_dir (``string``): The path to the folder where the saved model 101 | state is located. 102 | - filename (``string``): The model filename. 103 | 104 | Returns: 105 | The epoch, mean IoU, ``model``, and ``optimizer`` loaded from the 106 | checkpoint. 
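    (returned in the order ``model, optimizer, epoch, miou``).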
107 | 108 | """ 109 | assert os.path.isdir( 110 | folder_dir), "The directory \"{0}\" doesn't exist.".format(folder_dir) 111 | 112 | # Create folder to save model and information 113 | model_path = os.path.join(folder_dir, filename) 114 | assert os.path.isfile( 115 | model_path), "The model file \"{0}\" doesn't exist.".format(filename) 116 | 117 | # Load the stored model parameters to the model instance 118 | checkpoint = torch.load(model_path) 119 | model.load_state_dict(checkpoint['state_dict']) 120 | optimizer.load_state_dict(checkpoint['optimizer']) 121 | epoch = checkpoint['epoch'] 122 | miou = checkpoint['miou'] 123 | 124 | return model, optimizer, epoch, miou 125 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch-ENet 2 | 3 | PyTorch (v1.1.0) implementation of [*ENet: A Deep Neural Network Architecture for Real-Time Semantic Segmentation*](https://arxiv.org/abs/1606.02147), ported from the lua-torch implementation [ENet-training](https://github.com/e-lab/ENet-training) created by the authors. 4 | 5 | This implementation has been tested on the CamVid and Cityscapes datasets. Currently, a pre-trained version of the model trained in CamVid and Cityscapes is available [here](https://github.com/davidtvs/PyTorch-ENet/tree/master/save). 6 | 7 | | Dataset | Classes 1 | Input resolution | Batch size | Epochs | Mean IoU (%) | GPU memory (GiB) | Training time (hours)2 | 8 | | :------------------------------------------------------------------: | :------------------: | :--------------: | :--------: | :----: | :---------------: | :--------------: | :-------------------------------: | 9 | | [CamVid](http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/) | 11 | 480x360 | 10 | 300 | 52.13 | 4.2 | 1 | 10 | | [Cityscapes](https://www.cityscapes-dataset.com/) | 19 | 1024x512 | 4 | 300 | 59.54 | 5.4 | 20 | 11 | 12 | 1 When referring to the number of classes, the void/unlabeled class is always excluded.
13 | 2 These numbers are for reference only; changes to the implementation, datasets, or hardware can lead to very different results. Reference hardware: Nvidia GTX 1070 and an AMD Ryzen 5 3600 (3.6 GHz). You can also train for around 100 epochs and obtain a similar mean IoU (± 2%).
14 | 3 Test set.
15 | 4 Validation set. 16 | 17 | ## Installation 18 | 19 | ### Local pip 20 | 21 | 1. Python 3 and pip 22 | 2. Set up a virtual environment (optional, but recommended) 23 | 3. Install dependencies using pip: `pip install -r requirements.txt` 24 | 25 | ### Docker image 26 | 27 | 1. Build the image: `docker build -t enet .` 28 | 2. Run: `docker run -it --gpus all --ipc host enet` 29 | 30 | ## Usage 31 | 32 | Run [``main.py``](https://github.com/davidtvs/PyTorch-ENet/blob/master/main.py), the main script file used for training and/or testing the model. The following options are supported: 33 | 34 | ``` 35 | python main.py [-h] [--mode {train,test,full}] [--resume] 36 | [--batch-size BATCH_SIZE] [--epochs EPOCHS] 37 | [--learning-rate LEARNING_RATE] [--lr-decay LR_DECAY] 38 | [--lr-decay-epochs LR_DECAY_EPOCHS] 39 | [--weight-decay WEIGHT_DECAY] [--dataset {camvid,cityscapes}] 40 | [--dataset-dir DATASET_DIR] [--height HEIGHT] [--width WIDTH] 41 | [--weighing {enet,mfb,none}] [--with-unlabeled] 42 | [--workers WORKERS] [--print-step] [--imshow-batch] 43 | [--device DEVICE] [--name NAME] [--save-dir SAVE_DIR] 44 | ``` 45 | 46 | For help on the optional arguments run: ``python main.py -h`` 47 | 48 | 49 | ### Examples: Training 50 | 51 | ``` 52 | python main.py -m train --save-dir save/folder/ --name model_name --dataset name --dataset-dir path/root_directory/ 53 | ``` 54 | 55 | 56 | ### Examples: Resuming training 57 | 58 | ``` 59 | python main.py -m train --resume True --save-dir save/folder/ --name model_name --dataset name --dataset-dir path/root_directory/ 60 | ``` 61 | 62 | 63 | ### Examples: Testing 64 | 65 | ``` 66 | python main.py -m test --save-dir save/folder/ --name model_name --dataset name --dataset-dir path/root_directory/ 67 | ``` 68 | 69 | 70 | ## Project structure 71 | 72 | ### Folders 73 | 74 | - [``data``](https://github.com/davidtvs/PyTorch-ENet/tree/master/data): Contains instructions on how to download the datasets and the code that handles data loading. 75 | - [``metric``](https://github.com/davidtvs/PyTorch-ENet/tree/master/metric): Evaluation-related metrics. 76 | - [``models``](https://github.com/davidtvs/PyTorch-ENet/tree/master/models): ENet model definition. 77 | - [``save``](https://github.com/davidtvs/PyTorch-ENet/tree/master/save): By default, ``main.py`` will save models in this folder. The pre-trained models can also be found here. 78 | 79 | ### Files 80 | 81 | - [``args.py``](https://github.com/davidtvs/PyTorch-ENet/blob/master/args.py): Contains all command-line options. 82 | - [``main.py``](https://github.com/davidtvs/PyTorch-ENet/blob/master/main.py): Main script file used for training and/or testing the model. 83 | - [``test.py``](https://github.com/davidtvs/PyTorch-ENet/blob/master/test.py): Defines the ``Test`` class which is responsible for testing the model. 84 | - [``train.py``](https://github.com/davidtvs/PyTorch-ENet/blob/master/train.py): Defines the ``Train`` class which is responsible for training the model. 85 | - [``transforms.py``](https://github.com/davidtvs/PyTorch-ENet/blob/master/transforms.py): Defines image transformations to convert an RGB image encoding classes to a ``torch.LongTensor`` and vice versa. 86 | -------------------------------------------------------------------------------- /data/camvid.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | import torch.utils.data as data 4 | from . 
import utils 5 | 6 | 7 | class CamVid(data.Dataset): 8 | """CamVid dataset loader where the dataset is arranged as in 9 | https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid. 10 | 11 | 12 | Keyword arguments: 13 | - root_dir (``string``): Root directory path. 14 | - mode (``string``): The type of dataset: 'train' for training set, 'val' 15 | for validation set, and 'test' for test set. 16 | - transform (``callable``, optional): A function/transform that takes in 17 | an PIL image and returns a transformed version. Default: None. 18 | - label_transform (``callable``, optional): A function/transform that takes 19 | in the target and transforms it. Default: None. 20 | - loader (``callable``, optional): A function to load an image given its 21 | path. By default ``default_loader`` is used. 22 | 23 | """ 24 | # Training dataset root folders 25 | train_folder = 'train' 26 | train_lbl_folder = 'trainannot' 27 | 28 | # Validation dataset root folders 29 | val_folder = 'val' 30 | val_lbl_folder = 'valannot' 31 | 32 | # Test dataset root folders 33 | test_folder = 'test' 34 | test_lbl_folder = 'testannot' 35 | 36 | # Images extension 37 | img_extension = '.png' 38 | 39 | # Default encoding for pixel value, class name, and class color 40 | color_encoding = OrderedDict([ 41 | ('sky', (128, 128, 128)), 42 | ('building', (128, 0, 0)), 43 | ('pole', (192, 192, 128)), 44 | ('road_marking', (255, 69, 0)), 45 | ('road', (128, 64, 128)), 46 | ('pavement', (60, 40, 222)), 47 | ('tree', (128, 128, 0)), 48 | ('sign_symbol', (192, 128, 128)), 49 | ('fence', (64, 64, 128)), 50 | ('car', (64, 0, 128)), 51 | ('pedestrian', (64, 64, 0)), 52 | ('bicyclist', (0, 128, 192)), 53 | ('unlabeled', (0, 0, 0)) 54 | ]) 55 | 56 | def __init__(self, 57 | root_dir, 58 | mode='train', 59 | transform=None, 60 | label_transform=None, 61 | loader=utils.pil_loader): 62 | self.root_dir = root_dir 63 | self.mode = mode 64 | self.transform = transform 65 | self.label_transform = label_transform 66 | self.loader = loader 67 | 68 | if self.mode.lower() == 'train': 69 | # Get the training data and labels filepaths 70 | self.train_data = utils.get_files( 71 | os.path.join(root_dir, self.train_folder), 72 | extension_filter=self.img_extension) 73 | 74 | self.train_labels = utils.get_files( 75 | os.path.join(root_dir, self.train_lbl_folder), 76 | extension_filter=self.img_extension) 77 | elif self.mode.lower() == 'val': 78 | # Get the validation data and labels filepaths 79 | self.val_data = utils.get_files( 80 | os.path.join(root_dir, self.val_folder), 81 | extension_filter=self.img_extension) 82 | 83 | self.val_labels = utils.get_files( 84 | os.path.join(root_dir, self.val_lbl_folder), 85 | extension_filter=self.img_extension) 86 | elif self.mode.lower() == 'test': 87 | # Get the test data and labels filepaths 88 | self.test_data = utils.get_files( 89 | os.path.join(root_dir, self.test_folder), 90 | extension_filter=self.img_extension) 91 | 92 | self.test_labels = utils.get_files( 93 | os.path.join(root_dir, self.test_lbl_folder), 94 | extension_filter=self.img_extension) 95 | else: 96 | raise RuntimeError("Unexpected dataset mode. " 97 | "Supported modes are: train, val and test") 98 | 99 | def __getitem__(self, index): 100 | """ 101 | Args: 102 | - index (``int``): index of the item in the dataset 103 | 104 | Returns: 105 | A tuple of ``PIL.Image`` (image, label) where label is the ground-truth 106 | of the image. 
107 | 108 | """ 109 | if self.mode.lower() == 'train': 110 | data_path, label_path = self.train_data[index], self.train_labels[ 111 | index] 112 | elif self.mode.lower() == 'val': 113 | data_path, label_path = self.val_data[index], self.val_labels[ 114 | index] 115 | elif self.mode.lower() == 'test': 116 | data_path, label_path = self.test_data[index], self.test_labels[ 117 | index] 118 | else: 119 | raise RuntimeError("Unexpected dataset mode. " 120 | "Supported modes are: train, val and test") 121 | 122 | img, label = self.loader(data_path, label_path) 123 | 124 | if self.transform is not None: 125 | img = self.transform(img) 126 | 127 | if self.label_transform is not None: 128 | label = self.label_transform(label) 129 | 130 | return img, label 131 | 132 | def __len__(self): 133 | """Returns the length of the dataset.""" 134 | if self.mode.lower() == 'train': 135 | return len(self.train_data) 136 | elif self.mode.lower() == 'val': 137 | return len(self.val_data) 138 | elif self.mode.lower() == 'test': 139 | return len(self.test_data) 140 | else: 141 | raise RuntimeError("Unexpected dataset mode. " 142 | "Supported modes are: train, val and test") 143 | -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import numpy as np 4 | 5 | 6 | def get_files(folder, name_filter=None, extension_filter=None): 7 | """Helper function that returns the list of files in a specified folder 8 | with a specified extension. 9 | 10 | Keyword arguments: 11 | - folder (``string``): The path to a folder. 12 | - name_filter (```string``, optional): The returned files must contain 13 | this substring in their filename. Default: None; files are not filtered. 14 | - extension_filter (``string``, optional): The desired file extension. 15 | Default: None; files are not filtered 16 | 17 | """ 18 | if not os.path.isdir(folder): 19 | raise RuntimeError("\"{0}\" is not a folder.".format(folder)) 20 | 21 | # Filename filter: if not specified don't filter (condition always true); 22 | # otherwise, use a lambda expression to filter out files that do not 23 | # contain "name_filter" 24 | if name_filter is None: 25 | # This looks hackish...there is probably a better way 26 | name_cond = lambda filename: True 27 | else: 28 | name_cond = lambda filename: name_filter in filename 29 | 30 | # Extension filter: if not specified don't filter (condition always true); 31 | # otherwise, use a lambda expression to filter out files whose extension 32 | # is not "extension_filter" 33 | if extension_filter is None: 34 | # This looks hackish...there is probably a better way 35 | ext_cond = lambda filename: True 36 | else: 37 | ext_cond = lambda filename: filename.endswith(extension_filter) 38 | 39 | filtered_files = [] 40 | 41 | # Explore the directory tree to get files that contain "name_filter" and 42 | # with extension "extension_filter" 43 | for path, _, files in os.walk(folder): 44 | files.sort() 45 | for file in files: 46 | if name_cond(file) and ext_cond(file): 47 | full_path = os.path.join(path, file) 48 | filtered_files.append(full_path) 49 | 50 | return filtered_files 51 | 52 | 53 | def pil_loader(data_path, label_path): 54 | """Loads a sample and label image given their path as PIL images. 55 | 56 | Keyword arguments: 57 | - data_path (``string``): The filepath to the image. 58 | - label_path (``string``): The filepath to the ground-truth image. 
59 | 60 | Returns the image and the label as PIL images. 61 | 62 | """ 63 | data = Image.open(data_path) 64 | label = Image.open(label_path) 65 | 66 | return data, label 67 | 68 | 69 | def remap(image, old_values, new_values): 70 | assert isinstance(image, Image.Image) or isinstance( 71 | image, np.ndarray), "image must be of type PIL.Image or numpy.ndarray" 72 | assert type(new_values) is tuple, "new_values must be of type tuple" 73 | assert type(old_values) is tuple, "old_values must be of type tuple" 74 | assert len(new_values) == len( 75 | old_values), "new_values and old_values must have the same length" 76 | 77 | # If image is a PIL.Image convert it to a numpy array 78 | if isinstance(image, Image.Image): 79 | image = np.array(image) 80 | 81 | # Replace old values by the new ones 82 | tmp = np.zeros_like(image) 83 | for old, new in zip(old_values, new_values): 84 | # Since tmp is already initialized as zeros we can skip new values 85 | # equal to 0 86 | if new != 0: 87 | tmp[image == old] = new 88 | 89 | return Image.fromarray(tmp) 90 | 91 | 92 | def enet_weighing(dataloader, num_classes, c=1.02): 93 | """Computes class weights as described in the ENet paper: 94 | 95 | w_class = 1 / (ln(c + p_class)), 96 | 97 | where c is usually 1.02 and p_class is the propensity score of that 98 | class: 99 | 100 | propensity_score = freq_class / total_pixels. 101 | 102 | References: https://arxiv.org/abs/1606.02147 103 | 104 | Keyword arguments: 105 | - dataloader (``data.Dataloader``): A data loader to iterate over the 106 | dataset. 107 | - num_classes (``int``): The number of classes. 108 | - c (``int``, optional): AN additional hyper-parameter which restricts 109 | the interval of values for the weights. Default: 1.02. 110 | 111 | """ 112 | class_count = 0 113 | total = 0 114 | for _, label in dataloader: 115 | label = label.cpu().numpy() 116 | 117 | # Flatten label 118 | flat_label = label.flatten() 119 | 120 | # Sum up the number of pixels of each class and the total pixel 121 | # counts for each label 122 | class_count += np.bincount(flat_label, minlength=num_classes) 123 | total += flat_label.size 124 | 125 | # Compute propensity score and then the weights for each class 126 | propensity_score = class_count / total 127 | class_weights = 1 / (np.log(c + propensity_score)) 128 | 129 | return class_weights 130 | 131 | 132 | def median_freq_balancing(dataloader, num_classes): 133 | """Computes class weights using median frequency balancing as described 134 | in https://arxiv.org/abs/1411.4734: 135 | 136 | w_class = median_freq / freq_class, 137 | 138 | where freq_class is the number of pixels of a given class divided by 139 | the total number of pixels in images where that class is present, and 140 | median_freq is the median of freq_class. 141 | 142 | Keyword arguments: 143 | - dataloader (``data.Dataloader``): A data loader to iterate over the 144 | dataset. 145 | whose weights are going to be computed. 146 | - num_classes (``int``): The number of classes 147 | 148 | """ 149 | class_count = 0 150 | total = 0 151 | for _, label in dataloader: 152 | label = label.cpu().numpy() 153 | 154 | # Flatten label 155 | flat_label = label.flatten() 156 | 157 | # Sum up the class frequencies 158 | bincount = np.bincount(flat_label, minlength=num_classes) 159 | 160 | # Create of mask of classes that exist in the label 161 | mask = bincount > 0 162 | # Multiply the mask by the pixel count. The resulting array has 163 | # one element for each class. 
The value is either 0 (if the class 164 | # does not exist in the label) or equal to the pixel count (if 165 | # the class exists in the label) 166 | total += mask * flat_label.size 167 | 168 | # Sum up the number of pixels found for each class 169 | class_count += bincount 170 | 171 | # Compute the frequency and its median 172 | freq = class_count / total 173 | med = np.median(freq) 174 | 175 | return med / freq 176 | -------------------------------------------------------------------------------- /data/cityscapes.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | import torch.utils.data as data 4 | from . import utils 5 | 6 | 7 | class Cityscapes(data.Dataset): 8 | """Cityscapes dataset https://www.cityscapes-dataset.com/. 9 | 10 | Keyword arguments: 11 | - root_dir (``string``): Root directory path. 12 | - mode (``string``): The type of dataset: 'train' for training set, 'val' 13 | for validation set, and 'test' for test set. 14 | - transform (``callable``, optional): A function/transform that takes in 15 | an PIL image and returns a transformed version. Default: None. 16 | - label_transform (``callable``, optional): A function/transform that takes 17 | in the target and transforms it. Default: None. 18 | - loader (``callable``, optional): A function to load an image given its 19 | path. By default ``default_loader`` is used. 20 | 21 | """ 22 | # Training dataset root folders 23 | train_folder = "leftImg8bit_trainvaltest/leftImg8bit/train" 24 | train_lbl_folder = "gtFine_trainvaltest/gtFine/train" 25 | 26 | # Validation dataset root folders 27 | val_folder = "leftImg8bit_trainvaltest/leftImg8bit/val" 28 | val_lbl_folder = "gtFine_trainvaltest/gtFine/val" 29 | 30 | # Test dataset root folders 31 | test_folder = "leftImg8bit_trainvaltest/leftImg8bit/test" 32 | test_lbl_folder = "gtFine_trainvaltest/gtFine/test" 33 | 34 | # Filters to find the images 35 | img_extension = '.png' 36 | lbl_name_filter = 'labelIds' 37 | 38 | # The values associated with the 35 classes 39 | full_classes = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 40 | 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 41 | 32, 33, -1) 42 | # The values above are remapped to the following 43 | new_classes = (0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 5, 0, 0, 0, 6, 0, 7, 44 | 8, 9, 10, 11, 12, 13, 14, 15, 16, 0, 0, 17, 18, 19, 0) 45 | 46 | # Default encoding for pixel value, class name, and class color 47 | color_encoding = OrderedDict([ 48 | ('unlabeled', (0, 0, 0)), 49 | ('road', (128, 64, 128)), 50 | ('sidewalk', (244, 35, 232)), 51 | ('building', (70, 70, 70)), 52 | ('wall', (102, 102, 156)), 53 | ('fence', (190, 153, 153)), 54 | ('pole', (153, 153, 153)), 55 | ('traffic_light', (250, 170, 30)), 56 | ('traffic_sign', (220, 220, 0)), 57 | ('vegetation', (107, 142, 35)), 58 | ('terrain', (152, 251, 152)), 59 | ('sky', (70, 130, 180)), 60 | ('person', (220, 20, 60)), 61 | ('rider', (255, 0, 0)), 62 | ('car', (0, 0, 142)), 63 | ('truck', (0, 0, 70)), 64 | ('bus', (0, 60, 100)), 65 | ('train', (0, 80, 100)), 66 | ('motorcycle', (0, 0, 230)), 67 | ('bicycle', (119, 11, 32)) 68 | ]) 69 | 70 | def __init__(self, 71 | root_dir, 72 | mode='train', 73 | transform=None, 74 | label_transform=None, 75 | loader=utils.pil_loader): 76 | self.root_dir = root_dir 77 | self.mode = mode 78 | self.transform = transform 79 | self.label_transform = label_transform 80 | self.loader = loader 81 | 82 | if self.mode.lower() == 'train': 83 | 
# Get the training data and labels filepaths 84 | self.train_data = utils.get_files( 85 | os.path.join(root_dir, self.train_folder), 86 | extension_filter=self.img_extension) 87 | 88 | self.train_labels = utils.get_files( 89 | os.path.join(root_dir, self.train_lbl_folder), 90 | name_filter=self.lbl_name_filter, 91 | extension_filter=self.img_extension) 92 | elif self.mode.lower() == 'val': 93 | # Get the validation data and labels filepaths 94 | self.val_data = utils.get_files( 95 | os.path.join(root_dir, self.val_folder), 96 | extension_filter=self.img_extension) 97 | 98 | self.val_labels = utils.get_files( 99 | os.path.join(root_dir, self.val_lbl_folder), 100 | name_filter=self.lbl_name_filter, 101 | extension_filter=self.img_extension) 102 | elif self.mode.lower() == 'test': 103 | # Get the test data and labels filepaths 104 | self.test_data = utils.get_files( 105 | os.path.join(root_dir, self.test_folder), 106 | extension_filter=self.img_extension) 107 | 108 | self.test_labels = utils.get_files( 109 | os.path.join(root_dir, self.test_lbl_folder), 110 | name_filter=self.lbl_name_filter, 111 | extension_filter=self.img_extension) 112 | else: 113 | raise RuntimeError("Unexpected dataset mode. " 114 | "Supported modes are: train, val and test") 115 | 116 | def __getitem__(self, index): 117 | """ 118 | Args: 119 | - index (``int``): index of the item in the dataset 120 | 121 | Returns: 122 | A tuple of ``PIL.Image`` (image, label) where label is the ground-truth 123 | of the image. 124 | 125 | """ 126 | if self.mode.lower() == 'train': 127 | data_path, label_path = self.train_data[index], self.train_labels[ 128 | index] 129 | elif self.mode.lower() == 'val': 130 | data_path, label_path = self.val_data[index], self.val_labels[ 131 | index] 132 | elif self.mode.lower() == 'test': 133 | data_path, label_path = self.test_data[index], self.test_labels[ 134 | index] 135 | else: 136 | raise RuntimeError("Unexpected dataset mode. " 137 | "Supported modes are: train, val and test") 138 | 139 | img, label = self.loader(data_path, label_path) 140 | 141 | # Remap class labels 142 | label = utils.remap(label, self.full_classes, self.new_classes) 143 | 144 | if self.transform is not None: 145 | img = self.transform(img) 146 | 147 | if self.label_transform is not None: 148 | label = self.label_transform(label) 149 | 150 | return img, label 151 | 152 | def __len__(self): 153 | """Returns the length of the dataset.""" 154 | if self.mode.lower() == 'train': 155 | return len(self.train_data) 156 | elif self.mode.lower() == 'val': 157 | return len(self.val_data) 158 | elif self.mode.lower() == 'test': 159 | return len(self.test_data) 160 | else: 161 | raise RuntimeError("Unexpected dataset mode. 
" 162 | "Supported modes are: train, val and test") 163 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import torch.optim.lr_scheduler as lr_scheduler 7 | import torch.utils.data as data 8 | import torchvision.transforms as transforms 9 | 10 | from PIL import Image 11 | 12 | import transforms as ext_transforms 13 | from models.enet import ENet 14 | from train import Train 15 | from test import Test 16 | from metric.iou import IoU 17 | from args import get_arguments 18 | from data.utils import enet_weighing, median_freq_balancing 19 | import utils 20 | 21 | # Get the arguments 22 | args = get_arguments() 23 | 24 | device = torch.device(args.device) 25 | 26 | 27 | def load_dataset(dataset): 28 | print("\nLoading dataset...\n") 29 | 30 | print("Selected dataset:", args.dataset) 31 | print("Dataset directory:", args.dataset_dir) 32 | print("Save directory:", args.save_dir) 33 | 34 | image_transform = transforms.Compose( 35 | [transforms.Resize((args.height, args.width)), 36 | transforms.ToTensor()]) 37 | 38 | label_transform = transforms.Compose([ 39 | transforms.Resize((args.height, args.width), Image.NEAREST), 40 | ext_transforms.PILToLongTensor() 41 | ]) 42 | 43 | # Get selected dataset 44 | # Load the training set as tensors 45 | train_set = dataset( 46 | args.dataset_dir, 47 | transform=image_transform, 48 | label_transform=label_transform) 49 | train_loader = data.DataLoader( 50 | train_set, 51 | batch_size=args.batch_size, 52 | shuffle=True, 53 | num_workers=args.workers) 54 | 55 | # Load the validation set as tensors 56 | val_set = dataset( 57 | args.dataset_dir, 58 | mode='val', 59 | transform=image_transform, 60 | label_transform=label_transform) 61 | val_loader = data.DataLoader( 62 | val_set, 63 | batch_size=args.batch_size, 64 | shuffle=False, 65 | num_workers=args.workers) 66 | 67 | # Load the test set as tensors 68 | test_set = dataset( 69 | args.dataset_dir, 70 | mode='test', 71 | transform=image_transform, 72 | label_transform=label_transform) 73 | test_loader = data.DataLoader( 74 | test_set, 75 | batch_size=args.batch_size, 76 | shuffle=False, 77 | num_workers=args.workers) 78 | 79 | # Get encoding between pixel valus in label images and RGB colors 80 | class_encoding = train_set.color_encoding 81 | 82 | # Remove the road_marking class from the CamVid dataset as it's merged 83 | # with the road class 84 | if args.dataset.lower() == 'camvid': 85 | del class_encoding['road_marking'] 86 | 87 | # Get number of classes to predict 88 | num_classes = len(class_encoding) 89 | 90 | # Print information for debugging 91 | print("Number of classes to predict:", num_classes) 92 | print("Train dataset size:", len(train_set)) 93 | print("Validation dataset size:", len(val_set)) 94 | 95 | # Get a batch of samples to display 96 | if args.mode.lower() == 'test': 97 | images, labels = iter(test_loader).next() 98 | else: 99 | images, labels = iter(train_loader).next() 100 | print("Image size:", images.size()) 101 | print("Label size:", labels.size()) 102 | print("Class-color encoding:", class_encoding) 103 | 104 | # Show a batch of samples and labels 105 | if args.imshow_batch: 106 | print("Close the figure window to continue...") 107 | label_to_rgb = transforms.Compose([ 108 | ext_transforms.LongTensorToRGBPIL(class_encoding), 109 | transforms.ToTensor() 110 | ]) 111 | 
color_labels = utils.batch_transform(labels, label_to_rgb) 112 | utils.imshow_batch(images, color_labels) 113 | 114 | # Get class weights from the selected weighing technique 115 | print("\nWeighing technique:", args.weighing) 116 | print("Computing class weights...") 117 | print("(this can take a while depending on the dataset size)") 118 | class_weights = 0 119 | if args.weighing.lower() == 'enet': 120 | class_weights = enet_weighing(train_loader, num_classes) 121 | elif args.weighing.lower() == 'mfb': 122 | class_weights = median_freq_balancing(train_loader, num_classes) 123 | else: 124 | class_weights = None 125 | 126 | if class_weights is not None: 127 | class_weights = torch.from_numpy(class_weights).float().to(device) 128 | # Set the weight of the unlabeled class to 0 129 | if args.ignore_unlabeled: 130 | ignore_index = list(class_encoding).index('unlabeled') 131 | class_weights[ignore_index] = 0 132 | 133 | print("Class weights:", class_weights) 134 | 135 | return (train_loader, val_loader, 136 | test_loader), class_weights, class_encoding 137 | 138 | 139 | def train(train_loader, val_loader, class_weights, class_encoding): 140 | print("\nTraining...\n") 141 | 142 | num_classes = len(class_encoding) 143 | 144 | # Intialize ENet 145 | model = ENet(num_classes).to(device) 146 | # Check if the network architecture is correct 147 | print(model) 148 | 149 | # We are going to use the CrossEntropyLoss loss function as it's most 150 | # frequentely used in classification problems with multiple classes which 151 | # fits the problem. This criterion combines LogSoftMax and NLLLoss. 152 | criterion = nn.CrossEntropyLoss(weight=class_weights) 153 | 154 | # ENet authors used Adam as the optimizer 155 | optimizer = optim.Adam( 156 | model.parameters(), 157 | lr=args.learning_rate, 158 | weight_decay=args.weight_decay) 159 | 160 | # Learning rate decay scheduler 161 | lr_updater = lr_scheduler.StepLR(optimizer, args.lr_decay_epochs, 162 | args.lr_decay) 163 | 164 | # Evaluation metric 165 | if args.ignore_unlabeled: 166 | ignore_index = list(class_encoding).index('unlabeled') 167 | else: 168 | ignore_index = None 169 | metric = IoU(num_classes, ignore_index=ignore_index) 170 | 171 | # Optionally resume from a checkpoint 172 | if args.resume: 173 | model, optimizer, start_epoch, best_miou = utils.load_checkpoint( 174 | model, optimizer, args.save_dir, args.name) 175 | print("Resuming from model: Start epoch = {0} " 176 | "| Best mean IoU = {1:.4f}".format(start_epoch, best_miou)) 177 | else: 178 | start_epoch = 0 179 | best_miou = 0 180 | 181 | # Start Training 182 | print() 183 | train = Train(model, train_loader, optimizer, criterion, metric, device) 184 | val = Test(model, val_loader, criterion, metric, device) 185 | for epoch in range(start_epoch, args.epochs): 186 | print(">>>> [Epoch: {0:d}] Training".format(epoch)) 187 | 188 | epoch_loss, (iou, miou) = train.run_epoch(args.print_step) 189 | lr_updater.step() 190 | 191 | print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}". 192 | format(epoch, epoch_loss, miou)) 193 | 194 | if (epoch + 1) % 10 == 0 or epoch + 1 == args.epochs: 195 | print(">>>> [Epoch: {0:d}] Validation".format(epoch)) 196 | 197 | loss, (iou, miou) = val.run_epoch(args.print_step) 198 | 199 | print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}". 
200 | format(epoch, loss, miou)) 201 | 202 | # Print per class IoU on last epoch or if best iou 203 | if epoch + 1 == args.epochs or miou > best_miou: 204 | for key, class_iou in zip(class_encoding.keys(), iou): 205 | print("{0}: {1:.4f}".format(key, class_iou)) 206 | 207 | # Save the model if it's the best thus far 208 | if miou > best_miou: 209 | print("\nBest model thus far. Saving...\n") 210 | best_miou = miou 211 | utils.save_checkpoint(model, optimizer, epoch + 1, best_miou, 212 | args) 213 | 214 | return model 215 | 216 | 217 | def test(model, test_loader, class_weights, class_encoding): 218 | print("\nTesting...\n") 219 | 220 | num_classes = len(class_encoding) 221 | 222 | # We are going to use the CrossEntropyLoss loss function as it's most 223 | # frequentely used in classification problems with multiple classes which 224 | # fits the problem. This criterion combines LogSoftMax and NLLLoss. 225 | criterion = nn.CrossEntropyLoss(weight=class_weights) 226 | 227 | # Evaluation metric 228 | if args.ignore_unlabeled: 229 | ignore_index = list(class_encoding).index('unlabeled') 230 | else: 231 | ignore_index = None 232 | metric = IoU(num_classes, ignore_index=ignore_index) 233 | 234 | # Test the trained model on the test set 235 | test = Test(model, test_loader, criterion, metric, device) 236 | 237 | print(">>>> Running test dataset") 238 | 239 | loss, (iou, miou) = test.run_epoch(args.print_step) 240 | class_iou = dict(zip(class_encoding.keys(), iou)) 241 | 242 | print(">>>> Avg. loss: {0:.4f} | Mean IoU: {1:.4f}".format(loss, miou)) 243 | 244 | # Print per class IoU 245 | for key, class_iou in zip(class_encoding.keys(), iou): 246 | print("{0}: {1:.4f}".format(key, class_iou)) 247 | 248 | # Show a batch of samples and labels 249 | if args.imshow_batch: 250 | print("A batch of predictions from the test set...") 251 | images, _ = iter(test_loader).next() 252 | predict(model, images, class_encoding) 253 | 254 | 255 | def predict(model, images, class_encoding): 256 | images = images.to(device) 257 | 258 | # Make predictions! 259 | model.eval() 260 | with torch.no_grad(): 261 | predictions = model(images) 262 | 263 | # Predictions is one-hot encoded with "num_classes" channels. 
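# (i.e. for every pixel the model outputs one score per class)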
264 | # Convert it to a single int using the indices where the maximum (1) occurs 265 | _, predictions = torch.max(predictions.data, 1) 266 | 267 | label_to_rgb = transforms.Compose([ 268 | ext_transforms.LongTensorToRGBPIL(class_encoding), 269 | transforms.ToTensor() 270 | ]) 271 | color_predictions = utils.batch_transform(predictions.cpu(), label_to_rgb) 272 | utils.imshow_batch(images.data.cpu(), color_predictions) 273 | 274 | 275 | # Run only if this module is being run directly 276 | if __name__ == '__main__': 277 | 278 | # Fail fast if the dataset directory doesn't exist 279 | assert os.path.isdir( 280 | args.dataset_dir), "The directory \"{0}\" doesn't exist.".format( 281 | args.dataset_dir) 282 | 283 | # Fail fast if the saving directory doesn't exist 284 | assert os.path.isdir( 285 | args.save_dir), "The directory \"{0}\" doesn't exist.".format( 286 | args.save_dir) 287 | 288 | # Import the requested dataset 289 | if args.dataset.lower() == 'camvid': 290 | from data import CamVid as dataset 291 | elif args.dataset.lower() == 'cityscapes': 292 | from data import Cityscapes as dataset 293 | else: 294 | # Should never happen...but just in case it does 295 | raise RuntimeError("\"{0}\" is not a supported dataset.".format( 296 | args.dataset)) 297 | 298 | loaders, w_class, class_encoding = load_dataset(dataset) 299 | train_loader, val_loader, test_loader = loaders 300 | 301 | if args.mode.lower() in {'train', 'full'}: 302 | model = train(train_loader, val_loader, w_class, class_encoding) 303 | 304 | if args.mode.lower() in {'test', 'full'}: 305 | if args.mode.lower() == 'test': 306 | # Intialize a new ENet model 307 | num_classes = len(class_encoding) 308 | model = ENet(num_classes).to(device) 309 | 310 | # Initialize a optimizer just so we can retrieve the model from the 311 | # checkpoint 312 | optimizer = optim.Adam(model.parameters()) 313 | 314 | # Load the previoulsy saved model state to the ENet model 315 | model = utils.load_checkpoint(model, optimizer, args.save_dir, 316 | args.name)[0] 317 | 318 | if args.mode.lower() == 'test': 319 | print(model) 320 | 321 | test(model, test_loader, w_class, class_encoding) 322 | -------------------------------------------------------------------------------- /models/enet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class InitialBlock(nn.Module): 6 | """The initial block is composed of two branches: 7 | 1. a main branch which performs a regular convolution with stride 2; 8 | 2. an extension branch which performs max-pooling. 9 | 10 | Doing both operations in parallel and concatenating their results 11 | allows for efficient downsampling and expansion. The main branch 12 | outputs 13 feature maps while the extension branch outputs 3, for a 13 | total of 16 feature maps after concatenation. 14 | 15 | Keyword arguments: 16 | - in_channels (int): the number of input channels. 17 | - out_channels (int): the number output channels. 18 | - kernel_size (int, optional): the kernel size of the filters used in 19 | the convolution layer. Default: 3. 20 | - padding (int, optional): zero-padding added to both sides of the 21 | input. Default: 0. 22 | - bias (bool, optional): Adds a learnable bias to the output if 23 | ``True``. Default: False. 24 | - relu (bool, optional): When ``True`` ReLU is used as the activation 25 | function; otherwise, PReLU is used. Default: True. 
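    Shape sketch: with ``in_channels=3`` and ``out_channels=16``, an input of
    size (N, 3, 360, 480) yields an output of size (N, 16, 180, 240); the
    stride-2 convolution and the stride-2 max-pooling each halve the spatial
    dimensions, and their 13 + 3 feature maps are concatenated.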
26 | 27 | """ 28 | 29 | def __init__(self, 30 | in_channels, 31 | out_channels, 32 | bias=False, 33 | relu=True): 34 | super().__init__() 35 | 36 | if relu: 37 | activation = nn.ReLU 38 | else: 39 | activation = nn.PReLU 40 | 41 | # Main branch - As stated above the number of output channels for this 42 | # branch is the total minus 3, since the remaining channels come from 43 | # the extension branch 44 | self.main_branch = nn.Conv2d( 45 | in_channels, 46 | out_channels - 3, 47 | kernel_size=3, 48 | stride=2, 49 | padding=1, 50 | bias=bias) 51 | 52 | # Extension branch 53 | self.ext_branch = nn.MaxPool2d(3, stride=2, padding=1) 54 | 55 | # Initialize batch normalization to be used after concatenation 56 | self.batch_norm = nn.BatchNorm2d(out_channels) 57 | 58 | # PReLU layer to apply after concatenating the branches 59 | self.out_activation = activation() 60 | 61 | def forward(self, x): 62 | main = self.main_branch(x) 63 | ext = self.ext_branch(x) 64 | 65 | # Concatenate branches 66 | out = torch.cat((main, ext), 1) 67 | 68 | # Apply batch normalization 69 | out = self.batch_norm(out) 70 | 71 | return self.out_activation(out) 72 | 73 | 74 | class RegularBottleneck(nn.Module): 75 | """Regular bottlenecks are the main building block of ENet. 76 | Main branch: 77 | 1. Shortcut connection. 78 | 79 | Extension branch: 80 | 1. 1x1 convolution which decreases the number of channels by 81 | ``internal_ratio``, also called a projection; 82 | 2. regular, dilated or asymmetric convolution; 83 | 3. 1x1 convolution which increases the number of channels back to 84 | ``channels``, also called an expansion; 85 | 4. dropout as a regularizer. 86 | 87 | Keyword arguments: 88 | - channels (int): the number of input and output channels. 89 | - internal_ratio (int, optional): a scale factor applied to 90 | ``channels`` used to compute the number of 91 | channels after the projection. eg. given ``channels`` equal to 128 and 92 | internal_ratio equal to 2 the number of channels after the projection 93 | is 64. Default: 4. 94 | - kernel_size (int, optional): the kernel size of the filters used in 95 | the convolution layer described above in item 2 of the extension 96 | branch. Default: 3. 97 | - padding (int, optional): zero-padding added to both sides of the 98 | input. Default: 0. 99 | - dilation (int, optional): spacing between kernel elements for the 100 | convolution described in item 2 of the extension branch. Default: 1. 101 | asymmetric (bool, optional): flags if the convolution described in 102 | item 2 of the extension branch is asymmetric or not. Default: False. 103 | - dropout_prob (float, optional): probability of an element to be 104 | zeroed. Default: 0 (no dropout). 105 | - bias (bool, optional): Adds a learnable bias to the output if 106 | ``True``. Default: False. 107 | - relu (bool, optional): When ``True`` ReLU is used as the activation 108 | function; otherwise, PReLU is used. Default: True. 109 | 110 | """ 111 | 112 | def __init__(self, 113 | channels, 114 | internal_ratio=4, 115 | kernel_size=3, 116 | padding=0, 117 | dilation=1, 118 | asymmetric=False, 119 | dropout_prob=0, 120 | bias=False, 121 | relu=True): 122 | super().__init__() 123 | 124 | # Check in the internal_scale parameter is within the expected range 125 | # [1, channels] 126 | if internal_ratio <= 1 or internal_ratio > channels: 127 | raise RuntimeError("Value out of range. Expected value in the " 128 | "interval [1, {0}], got internal_scale={1}." 
129 | .format(channels, internal_ratio)) 130 | 131 | internal_channels = channels // internal_ratio 132 | 133 | if relu: 134 | activation = nn.ReLU 135 | else: 136 | activation = nn.PReLU 137 | 138 | # Main branch - shortcut connection 139 | 140 | # Extension branch - 1x1 convolution, followed by a regular, dilated or 141 | # asymmetric convolution, followed by another 1x1 convolution, and, 142 | # finally, a regularizer (spatial dropout). Number of channels is constant. 143 | 144 | # 1x1 projection convolution 145 | self.ext_conv1 = nn.Sequential( 146 | nn.Conv2d( 147 | channels, 148 | internal_channels, 149 | kernel_size=1, 150 | stride=1, 151 | bias=bias), nn.BatchNorm2d(internal_channels), activation()) 152 | 153 | # If the convolution is asymmetric we split the main convolution in 154 | # two. Eg. for a 5x5 asymmetric convolution we have two convolution: 155 | # the first is 5x1 and the second is 1x5. 156 | if asymmetric: 157 | self.ext_conv2 = nn.Sequential( 158 | nn.Conv2d( 159 | internal_channels, 160 | internal_channels, 161 | kernel_size=(kernel_size, 1), 162 | stride=1, 163 | padding=(padding, 0), 164 | dilation=dilation, 165 | bias=bias), nn.BatchNorm2d(internal_channels), activation(), 166 | nn.Conv2d( 167 | internal_channels, 168 | internal_channels, 169 | kernel_size=(1, kernel_size), 170 | stride=1, 171 | padding=(0, padding), 172 | dilation=dilation, 173 | bias=bias), nn.BatchNorm2d(internal_channels), activation()) 174 | else: 175 | self.ext_conv2 = nn.Sequential( 176 | nn.Conv2d( 177 | internal_channels, 178 | internal_channels, 179 | kernel_size=kernel_size, 180 | stride=1, 181 | padding=padding, 182 | dilation=dilation, 183 | bias=bias), nn.BatchNorm2d(internal_channels), activation()) 184 | 185 | # 1x1 expansion convolution 186 | self.ext_conv3 = nn.Sequential( 187 | nn.Conv2d( 188 | internal_channels, 189 | channels, 190 | kernel_size=1, 191 | stride=1, 192 | bias=bias), nn.BatchNorm2d(channels), activation()) 193 | 194 | self.ext_regul = nn.Dropout2d(p=dropout_prob) 195 | 196 | # PReLU layer to apply after adding the branches 197 | self.out_activation = activation() 198 | 199 | def forward(self, x): 200 | # Main branch shortcut 201 | main = x 202 | 203 | # Extension branch 204 | ext = self.ext_conv1(x) 205 | ext = self.ext_conv2(ext) 206 | ext = self.ext_conv3(ext) 207 | ext = self.ext_regul(ext) 208 | 209 | # Add main and extension branches 210 | out = main + ext 211 | 212 | return self.out_activation(out) 213 | 214 | 215 | class DownsamplingBottleneck(nn.Module): 216 | """Downsampling bottlenecks further downsample the feature map size. 217 | 218 | Main branch: 219 | 1. max pooling with stride 2; indices are saved to be used for 220 | unpooling later. 221 | 222 | Extension branch: 223 | 1. 2x2 convolution with stride 2 that decreases the number of channels 224 | by ``internal_ratio``, also called a projection; 225 | 2. regular convolution (by default, 3x3); 226 | 3. 1x1 convolution which increases the number of channels to 227 | ``out_channels``, also called an expansion; 228 | 4. dropout as a regularizer. 229 | 230 | Keyword arguments: 231 | - in_channels (int): the number of input channels. 232 | - out_channels (int): the number of output channels. 233 | - internal_ratio (int, optional): a scale factor applied to ``channels`` 234 | used to compute the number of channels after the projection. eg. given 235 | ``channels`` equal to 128 and internal_ratio equal to 2 the number of 236 | channels after the projection is 64. Default: 4. 
237 | - return_indices (bool, optional): if ``True``, will return the max 238 | indices along with the outputs. Useful when unpooling later. 239 | - dropout_prob (float, optional): probability of an element to be 240 | zeroed. Default: 0 (no dropout). 241 | - bias (bool, optional): Adds a learnable bias to the output if 242 | ``True``. Default: False. 243 | - relu (bool, optional): When ``True`` ReLU is used as the activation 244 | function; otherwise, PReLU is used. Default: True. 245 | 246 | """ 247 | 248 | def __init__(self, 249 | in_channels, 250 | out_channels, 251 | internal_ratio=4, 252 | return_indices=False, 253 | dropout_prob=0, 254 | bias=False, 255 | relu=True): 256 | super().__init__() 257 | 258 | # Store parameters that are needed later 259 | self.return_indices = return_indices 260 | 261 | # Check in the internal_scale parameter is within the expected range 262 | # [1, channels] 263 | if internal_ratio <= 1 or internal_ratio > in_channels: 264 | raise RuntimeError("Value out of range. Expected value in the " 265 | "interval [1, {0}], got internal_scale={1}. " 266 | .format(in_channels, internal_ratio)) 267 | 268 | internal_channels = in_channels // internal_ratio 269 | 270 | if relu: 271 | activation = nn.ReLU 272 | else: 273 | activation = nn.PReLU 274 | 275 | # Main branch - max pooling followed by feature map (channels) padding 276 | self.main_max1 = nn.MaxPool2d( 277 | 2, 278 | stride=2, 279 | return_indices=return_indices) 280 | 281 | # Extension branch - 2x2 convolution, followed by a regular, dilated or 282 | # asymmetric convolution, followed by another 1x1 convolution. Number 283 | # of channels is doubled. 284 | 285 | # 2x2 projection convolution with stride 2 286 | self.ext_conv1 = nn.Sequential( 287 | nn.Conv2d( 288 | in_channels, 289 | internal_channels, 290 | kernel_size=2, 291 | stride=2, 292 | bias=bias), nn.BatchNorm2d(internal_channels), activation()) 293 | 294 | # Convolution 295 | self.ext_conv2 = nn.Sequential( 296 | nn.Conv2d( 297 | internal_channels, 298 | internal_channels, 299 | kernel_size=3, 300 | stride=1, 301 | padding=1, 302 | bias=bias), nn.BatchNorm2d(internal_channels), activation()) 303 | 304 | # 1x1 expansion convolution 305 | self.ext_conv3 = nn.Sequential( 306 | nn.Conv2d( 307 | internal_channels, 308 | out_channels, 309 | kernel_size=1, 310 | stride=1, 311 | bias=bias), nn.BatchNorm2d(out_channels), activation()) 312 | 313 | self.ext_regul = nn.Dropout2d(p=dropout_prob) 314 | 315 | # PReLU layer to apply after concatenating the branches 316 | self.out_activation = activation() 317 | 318 | def forward(self, x): 319 | # Main branch shortcut 320 | if self.return_indices: 321 | main, max_indices = self.main_max1(x) 322 | else: 323 | main = self.main_max1(x) 324 | 325 | # Extension branch 326 | ext = self.ext_conv1(x) 327 | ext = self.ext_conv2(ext) 328 | ext = self.ext_conv3(ext) 329 | ext = self.ext_regul(ext) 330 | 331 | # Main branch channel padding 332 | n, ch_ext, h, w = ext.size() 333 | ch_main = main.size()[1] 334 | padding = torch.zeros(n, ch_ext - ch_main, h, w) 335 | 336 | # Before concatenating, check if main is on the CPU or GPU and 337 | # convert padding accordingly 338 | if main.is_cuda: 339 | padding = padding.cuda() 340 | 341 | # Concatenate 342 | main = torch.cat((main, padding), 1) 343 | 344 | # Add main and extension branches 345 | out = main + ext 346 | 347 | return self.out_activation(out), max_indices 348 | 349 | 350 | class UpsamplingBottleneck(nn.Module): 351 | """The upsampling bottlenecks upsample the feature map 
resolution using max 352 | pooling indices stored from the corresponding downsampling bottleneck. 353 | 354 | Main branch: 355 | 1. 1x1 convolution with stride 1 that decreases the number of channels by 356 | ``internal_ratio``, also called a projection; 357 | 2. max unpool layer using the max pool indices from the corresponding 358 | downsampling max pool layer. 359 | 360 | Extension branch: 361 | 1. 1x1 convolution with stride 1 that decreases the number of channels by 362 | ``internal_ratio``, also called a projection; 363 | 2. transposed convolution (by default, 3x3); 364 | 3. 1x1 convolution which increases the number of channels to 365 | ``out_channels``, also called an expansion; 366 | 4. dropout as a regularizer. 367 | 368 | Keyword arguments: 369 | - in_channels (int): the number of input channels. 370 | - out_channels (int): the number of output channels. 371 | - internal_ratio (int, optional): a scale factor applied to ``in_channels`` 372 | used to compute the number of channels after the projection. eg. given 373 | ``in_channels`` equal to 128 and ``internal_ratio`` equal to 2 the number 374 | of channels after the projection is 64. Default: 4. 375 | - dropout_prob (float, optional): probability of an element to be zeroed. 376 | Default: 0 (no dropout). 377 | - bias (bool, optional): Adds a learnable bias to the output if ``True``. 378 | Default: False. 379 | - relu (bool, optional): When ``True`` ReLU is used as the activation 380 | function; otherwise, PReLU is used. Default: True. 381 | 382 | """ 383 | 384 | def __init__(self, 385 | in_channels, 386 | out_channels, 387 | internal_ratio=4, 388 | dropout_prob=0, 389 | bias=False, 390 | relu=True): 391 | super().__init__() 392 | 393 | # Check in the internal_scale parameter is within the expected range 394 | # [1, channels] 395 | if internal_ratio <= 1 or internal_ratio > in_channels: 396 | raise RuntimeError("Value out of range. Expected value in the " 397 | "interval [1, {0}], got internal_scale={1}. " 398 | .format(in_channels, internal_ratio)) 399 | 400 | internal_channels = in_channels // internal_ratio 401 | 402 | if relu: 403 | activation = nn.ReLU 404 | else: 405 | activation = nn.PReLU 406 | 407 | # Main branch - max pooling followed by feature map (channels) padding 408 | self.main_conv1 = nn.Sequential( 409 | nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias), 410 | nn.BatchNorm2d(out_channels)) 411 | 412 | # Remember that the stride is the same as the kernel_size, just like 413 | # the max pooling layers 414 | self.main_unpool1 = nn.MaxUnpool2d(kernel_size=2) 415 | 416 | # Extension branch - 1x1 convolution, followed by a regular, dilated or 417 | # asymmetric convolution, followed by another 1x1 convolution. Number 418 | # of channels is doubled. 
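# Note: in this decoder block the middle step is a transposed convolution and
# the expansion maps to out_channels, which is smaller than in_channels when
# upsampling (e.g. 128 -> 64), so the channel count decreases here.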
419 | 420 | # 1x1 projection convolution with stride 1 421 | self.ext_conv1 = nn.Sequential( 422 | nn.Conv2d( 423 | in_channels, internal_channels, kernel_size=1, bias=bias), 424 | nn.BatchNorm2d(internal_channels), activation()) 425 | 426 | # Transposed convolution 427 | self.ext_tconv1 = nn.ConvTranspose2d( 428 | internal_channels, 429 | internal_channels, 430 | kernel_size=2, 431 | stride=2, 432 | bias=bias) 433 | self.ext_tconv1_bnorm = nn.BatchNorm2d(internal_channels) 434 | self.ext_tconv1_activation = activation() 435 | 436 | # 1x1 expansion convolution 437 | self.ext_conv2 = nn.Sequential( 438 | nn.Conv2d( 439 | internal_channels, out_channels, kernel_size=1, bias=bias), 440 | nn.BatchNorm2d(out_channels)) 441 | 442 | self.ext_regul = nn.Dropout2d(p=dropout_prob) 443 | 444 | # PReLU layer to apply after concatenating the branches 445 | self.out_activation = activation() 446 | 447 | def forward(self, x, max_indices, output_size): 448 | # Main branch shortcut 449 | main = self.main_conv1(x) 450 | main = self.main_unpool1( 451 | main, max_indices, output_size=output_size) 452 | 453 | # Extension branch 454 | ext = self.ext_conv1(x) 455 | ext = self.ext_tconv1(ext, output_size=output_size) 456 | ext = self.ext_tconv1_bnorm(ext) 457 | ext = self.ext_tconv1_activation(ext) 458 | ext = self.ext_conv2(ext) 459 | ext = self.ext_regul(ext) 460 | 461 | # Add main and extension branches 462 | out = main + ext 463 | 464 | return self.out_activation(out) 465 | 466 | 467 | class ENet(nn.Module): 468 | """Generate the ENet model. 469 | 470 | Keyword arguments: 471 | - num_classes (int): the number of classes to segment. 472 | - encoder_relu (bool, optional): When ``True`` ReLU is used as the 473 | activation function in the encoder blocks/layers; otherwise, PReLU 474 | is used. Default: False. 475 | - decoder_relu (bool, optional): When ``True`` ReLU is used as the 476 | activation function in the decoder blocks/layers; otherwise, PReLU 477 | is used. Default: True. 
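    Shape sketch: ``ENet(num_classes=12)`` applied to an input of size
    (N, 3, 360, 480), the CamVid resolution used in this repository, returns a
    tensor of size (N, 12, 360, 480), i.e. one score map per class at the
    input resolution.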
478 | 479 | """ 480 | 481 | def __init__(self, num_classes, encoder_relu=False, decoder_relu=True): 482 | super().__init__() 483 | 484 | self.initial_block = InitialBlock(3, 16, relu=encoder_relu) 485 | 486 | # Stage 1 - Encoder 487 | self.downsample1_0 = DownsamplingBottleneck( 488 | 16, 489 | 64, 490 | return_indices=True, 491 | dropout_prob=0.01, 492 | relu=encoder_relu) 493 | self.regular1_1 = RegularBottleneck( 494 | 64, padding=1, dropout_prob=0.01, relu=encoder_relu) 495 | self.regular1_2 = RegularBottleneck( 496 | 64, padding=1, dropout_prob=0.01, relu=encoder_relu) 497 | self.regular1_3 = RegularBottleneck( 498 | 64, padding=1, dropout_prob=0.01, relu=encoder_relu) 499 | self.regular1_4 = RegularBottleneck( 500 | 64, padding=1, dropout_prob=0.01, relu=encoder_relu) 501 | 502 | # Stage 2 - Encoder 503 | self.downsample2_0 = DownsamplingBottleneck( 504 | 64, 505 | 128, 506 | return_indices=True, 507 | dropout_prob=0.1, 508 | relu=encoder_relu) 509 | self.regular2_1 = RegularBottleneck( 510 | 128, padding=1, dropout_prob=0.1, relu=encoder_relu) 511 | self.dilated2_2 = RegularBottleneck( 512 | 128, dilation=2, padding=2, dropout_prob=0.1, relu=encoder_relu) 513 | self.asymmetric2_3 = RegularBottleneck( 514 | 128, 515 | kernel_size=5, 516 | padding=2, 517 | asymmetric=True, 518 | dropout_prob=0.1, 519 | relu=encoder_relu) 520 | self.dilated2_4 = RegularBottleneck( 521 | 128, dilation=4, padding=4, dropout_prob=0.1, relu=encoder_relu) 522 | self.regular2_5 = RegularBottleneck( 523 | 128, padding=1, dropout_prob=0.1, relu=encoder_relu) 524 | self.dilated2_6 = RegularBottleneck( 525 | 128, dilation=8, padding=8, dropout_prob=0.1, relu=encoder_relu) 526 | self.asymmetric2_7 = RegularBottleneck( 527 | 128, 528 | kernel_size=5, 529 | asymmetric=True, 530 | padding=2, 531 | dropout_prob=0.1, 532 | relu=encoder_relu) 533 | self.dilated2_8 = RegularBottleneck( 534 | 128, dilation=16, padding=16, dropout_prob=0.1, relu=encoder_relu) 535 | 536 | # Stage 3 - Encoder 537 | self.regular3_0 = RegularBottleneck( 538 | 128, padding=1, dropout_prob=0.1, relu=encoder_relu) 539 | self.dilated3_1 = RegularBottleneck( 540 | 128, dilation=2, padding=2, dropout_prob=0.1, relu=encoder_relu) 541 | self.asymmetric3_2 = RegularBottleneck( 542 | 128, 543 | kernel_size=5, 544 | padding=2, 545 | asymmetric=True, 546 | dropout_prob=0.1, 547 | relu=encoder_relu) 548 | self.dilated3_3 = RegularBottleneck( 549 | 128, dilation=4, padding=4, dropout_prob=0.1, relu=encoder_relu) 550 | self.regular3_4 = RegularBottleneck( 551 | 128, padding=1, dropout_prob=0.1, relu=encoder_relu) 552 | self.dilated3_5 = RegularBottleneck( 553 | 128, dilation=8, padding=8, dropout_prob=0.1, relu=encoder_relu) 554 | self.asymmetric3_6 = RegularBottleneck( 555 | 128, 556 | kernel_size=5, 557 | asymmetric=True, 558 | padding=2, 559 | dropout_prob=0.1, 560 | relu=encoder_relu) 561 | self.dilated3_7 = RegularBottleneck( 562 | 128, dilation=16, padding=16, dropout_prob=0.1, relu=encoder_relu) 563 | 564 | # Stage 4 - Decoder 565 | self.upsample4_0 = UpsamplingBottleneck( 566 | 128, 64, dropout_prob=0.1, relu=decoder_relu) 567 | self.regular4_1 = RegularBottleneck( 568 | 64, padding=1, dropout_prob=0.1, relu=decoder_relu) 569 | self.regular4_2 = RegularBottleneck( 570 | 64, padding=1, dropout_prob=0.1, relu=decoder_relu) 571 | 572 | # Stage 5 - Decoder 573 | self.upsample5_0 = UpsamplingBottleneck( 574 | 64, 16, dropout_prob=0.1, relu=decoder_relu) 575 | self.regular5_1 = RegularBottleneck( 576 | 16, padding=1, dropout_prob=0.1, relu=decoder_relu) 
577 | self.transposed_conv = nn.ConvTranspose2d( 578 | 16, 579 | num_classes, 580 | kernel_size=3, 581 | stride=2, 582 | padding=1, 583 | bias=False) 584 | 585 | def forward(self, x): 586 | # Initial block 587 | input_size = x.size() 588 | x = self.initial_block(x) 589 | 590 | # Stage 1 - Encoder 591 | stage1_input_size = x.size() 592 | x, max_indices1_0 = self.downsample1_0(x) 593 | x = self.regular1_1(x) 594 | x = self.regular1_2(x) 595 | x = self.regular1_3(x) 596 | x = self.regular1_4(x) 597 | 598 | # Stage 2 - Encoder 599 | stage2_input_size = x.size() 600 | x, max_indices2_0 = self.downsample2_0(x) 601 | x = self.regular2_1(x) 602 | x = self.dilated2_2(x) 603 | x = self.asymmetric2_3(x) 604 | x = self.dilated2_4(x) 605 | x = self.regular2_5(x) 606 | x = self.dilated2_6(x) 607 | x = self.asymmetric2_7(x) 608 | x = self.dilated2_8(x) 609 | 610 | # Stage 3 - Encoder 611 | x = self.regular3_0(x) 612 | x = self.dilated3_1(x) 613 | x = self.asymmetric3_2(x) 614 | x = self.dilated3_3(x) 615 | x = self.regular3_4(x) 616 | x = self.dilated3_5(x) 617 | x = self.asymmetric3_6(x) 618 | x = self.dilated3_7(x) 619 | 620 | # Stage 4 - Decoder 621 | x = self.upsample4_0(x, max_indices2_0, output_size=stage2_input_size) 622 | x = self.regular4_1(x) 623 | x = self.regular4_2(x) 624 | 625 | # Stage 5 - Decoder 626 | x = self.upsample5_0(x, max_indices1_0, output_size=stage1_input_size) 627 | x = self.regular5_1(x) 628 | x = self.transposed_conv(x, output_size=input_size) 629 | 630 | return x 631 | --------------------------------------------------------------------------------
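A minimal way to exercise the model defined above (an illustrative sketch; it assumes only the repository layout shown earlier, so that models.enet is importable, and uses the CamVid input resolution of 360x480):

import torch
from models.enet import ENet

# 12 classes matches the CamVid setup in this repository (11 classes plus the
# unlabeled class, after road_marking is merged into road); any positive
# number of classes works the same way.
model = ENet(num_classes=12).eval()

with torch.no_grad():
    dummy = torch.randn(1, 3, 360, 480)   # N x C x H x W
    scores = model(dummy)

print(scores.shape)            # torch.Size([1, 12, 360, 480])
print(scores.argmax(1).shape)  # per-pixel class predictions: (1, 360, 480)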