├── requirements_dev.txt
├── save
│   ├── ENet_CamVid
│   │   ├── ENet
│   │   └── ENet_summary.txt
│   ├── ENet_Cityscapes
│   │   ├── ENet
│   │   └── ENet_summary.txt
│   └── README.md
├── data
│   ├── __init__.py
│   ├── README.md
│   ├── camvid.py
│   ├── utils.py
│   └── cityscapes.py
├── metric
│   ├── __init__.py
│   ├── metric.py
│   ├── confusionmatrix.py
│   └── iou.py
├── requirements.txt
├── Dockerfile
├── LICENSE
├── .gitignore
├── test.py
├── train.py
├── transforms.py
├── args.py
├── utils.py
├── README.md
├── main.py
└── models
    └── enet.py
/requirements_dev.txt:
--------------------------------------------------------------------------------
1 | -r requirements.txt
2 | flake8
3 | black
4 |
--------------------------------------------------------------------------------
/save/ENet_CamVid/ENet:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidtvs/PyTorch-ENet/HEAD/save/ENet_CamVid/ENet
--------------------------------------------------------------------------------
/save/ENet_Cityscapes/ENet:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidtvs/PyTorch-ENet/HEAD/save/ENet_Cityscapes/ENet
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .camvid import CamVid
2 | from .cityscapes import Cityscapes
3 |
4 | __all__ = ['CamVid', 'Cityscapes']
5 |
--------------------------------------------------------------------------------
/metric/__init__.py:
--------------------------------------------------------------------------------
1 | from .confusionmatrix import ConfusionMatrix
2 | from .iou import IoU
3 | from .metric import Metric
4 |
5 | __all__ = ['ConfusionMatrix', 'IoU', 'Metric']
6 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | cycler>=0.10.0
2 | kiwisolver>=1.0.1
3 | matplotlib>=3.0.2
4 | numpy>=1.16.0
5 | Pillow>=6.2.0
6 | pyparsing>=2.3.1
7 | python-dateutil>=2.7.5
8 | pytz>=2018.9
9 | six>=1.12.0
10 | torch>=1.1.0
11 | torchvision>=0.2.2
12 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # Based on a PyTorch docker image that matches the minimum requirements: PyTorch 1.1.0
2 | FROM pytorch/pytorch:1.1.0-cuda10.0-cudnn7.5-runtime
3 |
4 | RUN python -m pip install --upgrade pip
5 |
6 | COPY . /enet
7 | WORKDIR /enet
8 | RUN pip install -r requirements.txt
9 |
--------------------------------------------------------------------------------
/metric/metric.py:
--------------------------------------------------------------------------------
1 | class Metric(object):
2 | """Base class for all metrics.
3 |
4 | From: https://github.com/pytorch/tnt/blob/master/torchnet/meter/meter.py
5 | """
6 | def reset(self):
7 | pass
8 |
9 | def add(self):
10 | pass
11 |
12 | def value(self):
13 | pass
14 |
--------------------------------------------------------------------------------
/save/ENet_CamVid/ENet_summary.txt:
--------------------------------------------------------------------------------
1 | ARGUMENTS
2 | batch_size: 10
3 | dataset: camvid
4 | dataset_dir: ../CamVid/
5 | device: cuda
6 | epochs: 300
7 | height: 360
8 | ignore_unlabeled: True
9 | imshow_batch: False
10 | learning_rate: 0.0005
11 | lr_decay: 0.1
12 | lr_decay_epochs: 100
13 | mode: full
14 | name: ENet
15 | print_step: False
16 | resume: False
17 | save_dir: save/new_camvid/
18 | weighing: ENet
19 | weight_decay: 0.0002
20 | width: 480
21 | workers: 4
22 |
23 | BEST VALIDATION
24 | Epoch: 280
25 | Mean IoU: 0.6518655444842216
26 |
--------------------------------------------------------------------------------
/save/ENet_Cityscapes/ENet_summary.txt:
--------------------------------------------------------------------------------
1 | ARGUMENTS
2 | batch_size: 4
3 | dataset: cityscapes
4 | dataset_dir: ../Cityscapes/1024x512/
5 | device: cuda
6 | epochs: 300
7 | height: 512
8 | ignore_unlabeled: True
9 | imshow_batch: False
10 | learning_rate: 0.0005
11 | lr_decay: 0.1
12 | lr_decay_epochs: 100
13 | mode: train
14 | name: ENet
15 | print_step: False
16 | resume: False
17 | save_dir: save/cityscapes_2/
18 | weighing: ENet
19 | weight_decay: 0.0002
20 | width: 1024
21 | workers: 4
22 |
23 | BEST VALIDATION
24 | Epoch: 250
25 | Mean IoU: 0.5949690267526815
26 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 davidtvs
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Supported datasets
2 |
3 | - CamVid
4 | - Cityscapes
5 |
6 | Note: When referring to the number of classes, the void/unlabeled class is excluded.
7 |
8 | ## CamVid Dataset
9 |
10 | The Cambridge-driving Labeled Video Database (CamVid) is a collection of over ten minutes of high-quality 30Hz footage with object class semantic labels at 1Hz and in part, 15Hz. Each pixel is associated with one of 32 classes.
11 |
12 | The CamVid dataset supported here is a 12-class version developed by the authors of SegNet. [Download link here](https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid). For actual training, an 11-class version is used - the "road marking" class is combined with the "road" class.
13 |
14 | More detailed information about the CamVid dataset can be found [here](http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/) and on the [SegNet GitHub repository](https://github.com/alexgkendall/SegNet-Tutorial).
15 |
16 | ## Cityscapes
17 |
18 | Cityscapes is a set of stereo video sequences recorded in streets from 50 different cities with 34 different classes. There are 5000 images with fine annotations and 20000 images coarsely annotated.
19 |
20 | The version supported here is the finely annotated one with 19 classes.
21 |
22 | For more detailed information see the official [website](https://www.cityscapes-dataset.com/) and [repository](https://github.com/mcordts/cityscapesScripts).
23 |
24 | The dataset can be downloaded from https://www.cityscapes-dataset.com/downloads/. At this time, a registration is required to download the data.
25 |
--------------------------------------------------------------------------------
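Usage note (not part of the repository): the two datasets above are exposed by ``data/camvid.py`` and ``data/cityscapes.py`` as regular ``torch.utils.data.Dataset`` classes. A minimal sketch of loading one CamVid sample the way ``main.py`` does; the ``./CamVid`` path and the 360x480 resolution are assumptions, not requirements.

```
# Sketch: load a CamVid sample with this repo's dataset class (paths are assumptions).
from PIL import Image
import torchvision.transforms as transforms

import transforms as ext_transforms  # the repo-level transforms.py
from data import CamVid

image_transform = transforms.Compose(
    [transforms.Resize((360, 480)), transforms.ToTensor()])
label_transform = transforms.Compose(
    [transforms.Resize((360, 480), Image.NEAREST),
     ext_transforms.PILToLongTensor()])

train_set = CamVid("./CamVid", mode='train',
                   transform=image_transform,
                   label_transform=label_transform)
img, label = train_set[0]  # img: (3, 360, 480) float; label: (360, 480) long
```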
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | .static_storage/
57 | .media/
58 | local_settings.py
59 |
60 | # Flask stuff:
61 | instance/
62 | .webassets-cache
63 |
64 | # Scrapy stuff:
65 | .scrapy
66 |
67 | # Sphinx documentation
68 | docs/_build/
69 |
70 | # PyBuilder
71 | target/
72 |
73 | # Jupyter Notebook
74 | .ipynb_checkpoints
75 |
76 | # pyenv
77 | .python-version
78 |
79 | # celery beat schedule file
80 | celerybeat-schedule
81 |
82 | # SageMath parsed files
83 | *.sage.py
84 |
85 | # Environments
86 | .env
87 | .venv
88 | env/
89 | venv/
90 | ENV/
91 | env.bak/
92 | venv.bak/
93 |
94 | # Spyder project settings
95 | .spyderproject
96 | .spyproject
97 |
98 | # Rope project settings
99 | .ropeproject
100 |
101 | # mkdocs documentation
102 | /site
103 |
104 | # mypy
105 | .mypy_cache/
106 |
107 | # Sublime text project
108 | *.sublime-workspace
109 | *.sublime-project
110 |
111 | # VSCode
112 | .vscode/
113 | *.code-workspace
114 | .devcontainer
115 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Test:
5 | """Tests the ``model`` on the specified test dataset using the
6 | data loader, and loss criterion.
7 |
8 | Keyword arguments:
9 | - model (``nn.Module``): the model instance to test.
10 | - data_loader (``Dataloader``): Provides single or multi-process
11 | iterators over the dataset.
12 |     - criterion (``nn.Module``): The loss criterion.
13 |     - metric (``Metric``): An instance specifying the metric to return.
14 | - device (``torch.device``): An object representing the device on which
15 | tensors are allocated.
16 |
17 | """
18 |
19 | def __init__(self, model, data_loader, criterion, metric, device):
20 | self.model = model
21 | self.data_loader = data_loader
22 | self.criterion = criterion
23 | self.metric = metric
24 | self.device = device
25 |
26 | def run_epoch(self, iteration_loss=False):
27 | """Runs an epoch of validation.
28 |
29 | Keyword arguments:
30 | - iteration_loss (``bool``, optional): Prints loss at every step.
31 |
32 | Returns:
33 | - The epoch loss (float), and the values of the specified metrics
34 |
35 | """
36 | self.model.eval()
37 | epoch_loss = 0.0
38 | self.metric.reset()
39 | for step, batch_data in enumerate(self.data_loader):
40 | # Get the inputs and labels
41 | inputs = batch_data[0].to(self.device)
42 | labels = batch_data[1].to(self.device)
43 |
44 | with torch.no_grad():
45 | # Forward propagation
46 | outputs = self.model(inputs)
47 |
48 | # Loss computation
49 | loss = self.criterion(outputs, labels)
50 |
51 | # Keep track of loss for current epoch
52 | epoch_loss += loss.item()
53 |
54 |             # Keep track of the evaluation metric
55 | self.metric.add(outputs.detach(), labels.detach())
56 |
57 | if iteration_loss:
58 | print("[Step: %d] Iteration loss: %.4f" % (step, loss.item()))
59 |
60 | return epoch_loss / len(self.data_loader), self.metric.value()
61 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | class Train:
2 | """Performs the training of ``model`` given a training dataset data
3 | loader, the optimizer, and the loss criterion.
4 |
5 | Keyword arguments:
6 | - model (``nn.Module``): the model instance to train.
7 | - data_loader (``Dataloader``): Provides single or multi-process
8 | iterators over the dataset.
9 | - optim (``Optimizer``): The optimization algorithm.
10 |     - criterion (``nn.Module``): The loss criterion.
11 |     - metric (``Metric``): An instance specifying the metric to return.
12 | - device (``torch.device``): An object representing the device on which
13 | tensors are allocated.
14 |
15 | """
16 |
17 | def __init__(self, model, data_loader, optim, criterion, metric, device):
18 | self.model = model
19 | self.data_loader = data_loader
20 | self.optim = optim
21 | self.criterion = criterion
22 | self.metric = metric
23 | self.device = device
24 |
25 | def run_epoch(self, iteration_loss=False):
26 | """Runs an epoch of training.
27 |
28 | Keyword arguments:
29 | - iteration_loss (``bool``, optional): Prints loss at every step.
30 |
31 | Returns:
32 |         - The epoch loss (float), and the values of the specified metric.
33 |
34 | """
35 | self.model.train()
36 | epoch_loss = 0.0
37 | self.metric.reset()
38 | for step, batch_data in enumerate(self.data_loader):
39 | # Get the inputs and labels
40 | inputs = batch_data[0].to(self.device)
41 | labels = batch_data[1].to(self.device)
42 |
43 | # Forward propagation
44 | outputs = self.model(inputs)
45 |
46 | # Loss computation
47 | loss = self.criterion(outputs, labels)
48 |
49 | # Backpropagation
50 | self.optim.zero_grad()
51 | loss.backward()
52 | self.optim.step()
53 |
54 | # Keep track of loss for current epoch
55 | epoch_loss += loss.item()
56 |
57 | # Keep track of the evaluation metric
58 | self.metric.add(outputs.detach(), labels.detach())
59 |
60 | if iteration_loss:
61 | print("[Step: %d] Iteration loss: %.4f" % (step, loss.item()))
62 |
63 | return epoch_loss / len(self.data_loader), self.metric.value()
64 |
--------------------------------------------------------------------------------
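``Train`` above (and ``Test`` earlier) only run a single epoch; the surrounding loop lives in ``main.py``. A rough sketch of how the two classes are usually driven, where ``model``, the data loaders, ``optimizer``, ``criterion``, ``metric``, ``device`` and ``args`` are assumed to already exist:

```
# Sketch of the outer epoch loop around Train/Test (main.py holds the real version).
from train import Train
from test import Test

trainer = Train(model, train_loader, optimizer, criterion, metric, device)
validator = Test(model, val_loader, criterion, metric, device)

for epoch in range(args.epochs):
    epoch_loss, (iou, miou) = trainer.run_epoch(args.print_step)
    val_loss, (val_iou, val_miou) = validator.run_epoch(args.print_step)
    print("Epoch {0}: train loss {1:.4f} | val mIoU {2:.4f}".format(
        epoch, epoch_loss, val_miou))
```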
/save/README.md:
--------------------------------------------------------------------------------
1 | # Pre-trained models
2 |
3 | | Dataset | Classes <sup>1</sup> | Input resolution | Batch size | Epochs | Mean IoU (%) | GPU memory (GiB) | Training time (hours) <sup>2</sup> |
4 | | :------------------------------------------------------------------: | :------------------: | :--------------: | :--------: | :----: | :---------------: | :--------------: | :-------------------------------: |
5 | | [CamVid](http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/) | 11 | 480x360 | 10 | 300 | 52.13 | 4.2 | 1 |
6 | | [Cityscapes](https://www.cityscapes-dataset.com/) | 19 | 1024x512 | 4 | 300 | 59.54 | 5.4 | 20 |
7 |
8 | ## Per-class IoU: CamVid <sup>3</sup>
9 |
10 | | | Sky | Building | Pole | Road | Pavement | Tree | Sign Symbol | Fence | Car | Pedestrian | Bicyclist |
11 | | :-----: | :---: | :------: | :---: | :---: | :------: | :---: | :---------: | :---: | :---: | :--------: | :-------: |
12 | | IoU (%) | 90.2 | 68.6 | 22.6 | 91.5 | 73.2 | 63.6 | 19.3 | 16.7 | 65.1 | 27.2 | 35.0 |
13 |
14 | ## Per-class IoU: Cityscapes <sup>4</sup>
15 |
16 | | | Road | Sidewalk | Building | Wall | Fence | Pole | Traffic light | Traffic Sign | Vegetation | Terrain | Sky | Person | Rider | Car | Truck | Bus | Train | Motorcycle | Bicycle |
17 | | :-----: | :---: | :------: | :------: | :---: | :---: | :---: | :-----------: | :----------: | :--------: | :-----: | :---: | :----: | :---: | :---: | :---: | :---: | :---: | :--------: | :-----: |
18 | | IoU (%) | 96.1 | 73.3 | 85.8 | 44.1 | 40.5 | 45.3 | 42.5 | 53.9 | 87.9 | 53.5 | 90.1 | 62.3 | 44.3 | 87.6 | 46.6 | 58.2 | 34.8 | 25.8 | 57.9 |
19 |
20 | <sup>1</sup> When referring to the number of classes, the void/unlabeled class is always excluded.
21 | <sup>2</sup> These are just for reference. Implementation, datasets, and hardware changes can lead to very different results. Reference hardware: Nvidia GTX 1070 and an AMD Ryzen 5 3600 3.6GHz. You can also train for 100 epochs or so and get similar mean IoU (± 2%).
22 | <sup>3</sup> Test set.
23 | <sup>4</sup> Validation set.
24 |
--------------------------------------------------------------------------------
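The two ``ENet`` files referenced above are plain ``torch.save`` checkpoints written by ``save_checkpoint`` in ``utils.py`` (keys: ``epoch``, ``miou``, ``state_dict``, ``optimizer``). A minimal sketch for restoring the CamVid weights; the path, the ``num_classes=12`` value (11 classes plus unlabeled), and the assumption that ``ENet`` takes the class count as its first argument come from the rest of this repository, not from the checkpoint itself:

```
# Sketch: inspect/restore a pre-trained checkpoint (path and class count are assumptions).
import torch
from models.enet import ENet

checkpoint = torch.load("save/ENet_CamVid/ENet", map_location="cpu")
model = ENet(12)  # 11 CamVid classes + unlabeled
model.load_state_dict(checkpoint["state_dict"])
print("epoch:", checkpoint["epoch"], "mean IoU:", checkpoint["miou"])
```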
/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from PIL import Image
4 | from collections import OrderedDict
5 | from torchvision.transforms import ToPILImage
6 |
7 |
8 | class PILToLongTensor(object):
9 | """Converts a ``PIL Image`` to a ``torch.LongTensor``.
10 |
11 | Code adapted from: http://pytorch.org/docs/master/torchvision/transforms.html?highlight=totensor
12 |
13 | """
14 |
15 | def __call__(self, pic):
16 | """Performs the conversion from a ``PIL Image`` to a ``torch.LongTensor``.
17 |
18 | Keyword arguments:
19 | - pic (``PIL.Image``): the image to convert to ``torch.LongTensor``
20 |
21 | Returns:
22 | A ``torch.LongTensor``.
23 |
24 | """
25 |         if not isinstance(pic, (Image.Image, np.ndarray)):
26 |             raise TypeError("pic should be PIL Image or numpy.ndarray. "
27 |                             "Got {}".format(type(pic)))
28 |
29 | # handle numpy array
30 | if isinstance(pic, np.ndarray):
31 | img = torch.from_numpy(pic.transpose((2, 0, 1)))
32 | # backward compatibility
33 | return img.long()
34 |
35 | # Convert PIL image to ByteTensor
36 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
37 |
38 | # Reshape tensor
39 | nchannel = len(pic.mode)
40 | img = img.view(pic.size[1], pic.size[0], nchannel)
41 |
42 | # Convert to long and squeeze the channels
43 | return img.transpose(0, 1).transpose(0,
44 | 2).contiguous().long().squeeze_()
45 |
46 |
47 | class LongTensorToRGBPIL(object):
48 | """Converts a ``torch.LongTensor`` to a ``PIL image``.
49 |
50 | The input is a ``torch.LongTensor`` where each pixel's value identifies the
51 | class.
52 |
53 | Keyword arguments:
54 | - rgb_encoding (``OrderedDict``): An ``OrderedDict`` that relates pixel
55 | values, class names, and class colors.
56 |
57 | """
58 | def __init__(self, rgb_encoding):
59 | self.rgb_encoding = rgb_encoding
60 |
61 | def __call__(self, tensor):
62 | """Performs the conversion from ``torch.LongTensor`` to a ``PIL image``
63 |
64 | Keyword arguments:
65 | - tensor (``torch.LongTensor``): the tensor to convert
66 |
67 | Returns:
68 | A ``PIL.Image``.
69 |
70 | """
71 | # Check if label_tensor is a LongTensor
72 | if not isinstance(tensor, torch.LongTensor):
73 | raise TypeError("label_tensor should be torch.LongTensor. Got {}"
74 | .format(type(tensor)))
75 |         # Check if encoding is an ordered dictionary
76 | if not isinstance(self.rgb_encoding, OrderedDict):
77 | raise TypeError("encoding should be an OrderedDict. Got {}".format(
78 | type(self.rgb_encoding)))
79 |
80 | # label_tensor might be an image without a channel dimension, in this
81 | # case unsqueeze it
82 | if len(tensor.size()) == 2:
83 | tensor.unsqueeze_(0)
84 |
85 | color_tensor = torch.ByteTensor(3, tensor.size(1), tensor.size(2))
86 |
87 | for index, (class_name, color) in enumerate(self.rgb_encoding.items()):
88 | # Get a mask of elements equal to index
89 | mask = torch.eq(tensor, index).squeeze_()
90 | # Fill color_tensor with corresponding colors
91 | for channel, color_value in enumerate(color):
92 | color_tensor[channel].masked_fill_(mask, color_value)
93 |
94 | return ToPILImage()(color_tensor)
95 |
--------------------------------------------------------------------------------
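A short sketch of the round trip the two transforms above provide, using the CamVid color encoding defined in ``data/camvid.py``; the 2x2 label image is synthetic and only for illustration:

```
# Sketch: label image -> class-index LongTensor -> color PIL image.
import numpy as np
from PIL import Image

from transforms import PILToLongTensor, LongTensorToRGBPIL
from data.camvid import CamVid

# Fake single-channel label image whose pixel values are class indices.
label = Image.fromarray(np.array([[0, 4], [9, 12]], dtype=np.uint8))

tensor = PILToLongTensor()(label)                        # LongTensor, shape (2, 2)
rgb = LongTensorToRGBPIL(CamVid.color_encoding)(tensor)  # RGB PIL image, 2x2
```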
/metric/confusionmatrix.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from metric import metric
4 |
5 |
6 | class ConfusionMatrix(metric.Metric):
7 |     """Constructs a confusion matrix for multi-class classification problems.
8 |
9 | Does not support multi-label, multi-class problems.
10 |
11 | Keyword arguments:
12 | - num_classes (int): number of classes in the classification problem.
13 | - normalized (boolean, optional): Determines whether or not the confusion
14 |     matrix is normalized. Default: False.
15 |
16 | Modified from: https://github.com/pytorch/tnt/blob/master/torchnet/meter/confusionmeter.py
17 | """
18 |
19 | def __init__(self, num_classes, normalized=False):
20 | super().__init__()
21 |
22 | self.conf = np.ndarray((num_classes, num_classes), dtype=np.int64)
23 | self.normalized = normalized
24 | self.num_classes = num_classes
25 | self.reset()
26 |
27 | def reset(self):
28 | self.conf.fill(0)
29 |
30 | def add(self, predicted, target):
31 | """Computes the confusion matrix
32 |
33 | The shape of the confusion matrix is K x K, where K is the number
34 | of classes.
35 |
36 | Keyword arguments:
37 | - predicted (Tensor or numpy.ndarray): Can be an N x K tensor/array of
38 | predicted scores obtained from the model for N examples and K classes,
39 | or an N-tensor/array of integer values between 0 and K-1.
40 | - target (Tensor or numpy.ndarray): Can be an N x K tensor/array of
41 | ground-truth classes for N examples and K classes, or an N-tensor/array
42 | of integer values between 0 and K-1.
43 |
44 | """
45 | # If target and/or predicted are tensors, convert them to numpy arrays
46 | if torch.is_tensor(predicted):
47 | predicted = predicted.cpu().numpy()
48 | if torch.is_tensor(target):
49 | target = target.cpu().numpy()
50 |
51 | assert predicted.shape[0] == target.shape[0], \
52 | 'number of targets and predicted outputs do not match'
53 |
54 | if np.ndim(predicted) != 1:
55 | assert predicted.shape[1] == self.num_classes, \
56 | 'number of predictions does not match size of confusion matrix'
57 | predicted = np.argmax(predicted, 1)
58 | else:
59 | assert (predicted.max() < self.num_classes) and (predicted.min() >= 0), \
60 | 'predicted values are not between 0 and k-1'
61 |
62 | if np.ndim(target) != 1:
63 | assert target.shape[1] == self.num_classes, \
64 | 'Onehot target does not match size of confusion matrix'
65 | assert (target >= 0).all() and (target <= 1).all(), \
66 | 'in one-hot encoding, target values should be 0 or 1'
67 | assert (target.sum(1) == 1).all(), \
68 | 'multi-label setting is not supported'
69 | target = np.argmax(target, 1)
70 | else:
71 | assert (target.max() < self.num_classes) and (target.min() >= 0), \
72 | 'target values are not between 0 and k-1'
73 |
74 | # hack for bincounting 2 arrays together
75 | x = predicted + self.num_classes * target
76 | bincount_2d = np.bincount(
77 | x.astype(np.int64), minlength=self.num_classes**2)
78 | assert bincount_2d.size == self.num_classes**2
79 | conf = bincount_2d.reshape((self.num_classes, self.num_classes))
80 |
81 | self.conf += conf
82 |
83 | def value(self):
84 | """
85 | Returns:
86 |             Confusion matrix of K rows and K columns, where rows correspond
87 |             to ground-truth targets and columns correspond to predicted
88 | targets.
89 | """
90 | if self.normalized:
91 | conf = self.conf.astype(np.float32)
92 | return conf / conf.sum(1).clip(min=1e-12)[:, None]
93 | else:
94 | return self.conf
95 |
--------------------------------------------------------------------------------
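For reference, a tiny sketch of feeding integer class labels straight into ``ConfusionMatrix`` (N x K score arrays are also accepted and argmax-ed internally, as the docstring of ``add`` describes):

```
# Sketch: a 3-class confusion matrix from integer predictions and targets.
import torch
from metric.confusionmatrix import ConfusionMatrix

cm = ConfusionMatrix(num_classes=3)
cm.add(torch.tensor([0, 1, 2, 2]), torch.tensor([0, 1, 1, 2]))
print(cm.value())
# [[1 0 0]
#  [0 1 1]
#  [0 0 1]]   rows: ground truth, columns: predictions
```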
/metric/iou.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from metric import metric
4 | from metric.confusionmatrix import ConfusionMatrix
5 |
6 |
7 | class IoU(metric.Metric):
8 | """Computes the intersection over union (IoU) per class and corresponding
9 | mean (mIoU).
10 |
11 | Intersection over union (IoU) is a common evaluation metric for semantic
12 | segmentation. The predictions are first accumulated in a confusion matrix
13 | and the IoU is computed from it as follows:
14 |
15 | IoU = true_positive / (true_positive + false_positive + false_negative).
16 |
17 | Keyword arguments:
18 | - num_classes (int): number of classes in the classification problem
19 | - normalized (boolean, optional): Determines whether or not the confusion
20 |     matrix is normalized. Default: False.
21 | - ignore_index (int or iterable, optional): Index of the classes to ignore
22 | when computing the IoU. Can be an int, or any iterable of ints.
23 | """
24 |
25 | def __init__(self, num_classes, normalized=False, ignore_index=None):
26 | super().__init__()
27 | self.conf_metric = ConfusionMatrix(num_classes, normalized)
28 |
29 | if ignore_index is None:
30 | self.ignore_index = None
31 | elif isinstance(ignore_index, int):
32 | self.ignore_index = (ignore_index,)
33 | else:
34 | try:
35 | self.ignore_index = tuple(ignore_index)
36 | except TypeError:
37 | raise ValueError("'ignore_index' must be an int or iterable")
38 |
39 | def reset(self):
40 | self.conf_metric.reset()
41 |
42 | def add(self, predicted, target):
43 | """Adds the predicted and target pair to the IoU metric.
44 |
45 | Keyword arguments:
46 | - predicted (Tensor): Can be a (N, K, H, W) tensor of
47 | predicted scores obtained from the model for N examples and K classes,
48 | or (N, H, W) tensor of integer values between 0 and K-1.
49 | - target (Tensor): Can be a (N, K, H, W) tensor of
50 | target scores for N examples and K classes, or (N, H, W) tensor of
51 | integer values between 0 and K-1.
52 |
53 | """
54 | # Dimensions check
55 | assert predicted.size(0) == target.size(0), \
56 | 'number of targets and predicted outputs do not match'
57 | assert predicted.dim() == 3 or predicted.dim() == 4, \
58 | "predictions must be of dimension (N, H, W) or (N, K, H, W)"
59 | assert target.dim() == 3 or target.dim() == 4, \
60 | "targets must be of dimension (N, H, W) or (N, K, H, W)"
61 |
62 | # If the tensor is in categorical format convert it to integer format
63 | if predicted.dim() == 4:
64 | _, predicted = predicted.max(1)
65 | if target.dim() == 4:
66 | _, target = target.max(1)
67 |
68 | self.conf_metric.add(predicted.view(-1), target.view(-1))
69 |
70 | def value(self):
71 | """Computes the IoU and mean IoU.
72 |
73 | The mean computation ignores NaN elements of the IoU array.
74 |
75 | Returns:
76 |             Tuple: (IoU, mIoU). The first output is the per-class IoU; for
77 |             K classes it is a numpy.ndarray with K elements. The second
78 |             output is the mean IoU.
79 | """
80 | conf_matrix = self.conf_metric.value()
81 | if self.ignore_index is not None:
82 | conf_matrix[:, self.ignore_index] = 0
83 | conf_matrix[self.ignore_index, :] = 0
84 | true_positive = np.diag(conf_matrix)
85 | false_positive = np.sum(conf_matrix, 0) - true_positive
86 | false_negative = np.sum(conf_matrix, 1) - true_positive
87 |
88 | # Just in case we get a division by 0, ignore/hide the error
89 | with np.errstate(divide='ignore', invalid='ignore'):
90 | iou = true_positive / (true_positive + false_positive + false_negative)
91 |
92 | return iou, np.nanmean(iou)
93 |
--------------------------------------------------------------------------------
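A quick sketch of the metric on random data, following the (N, K, H, W) score / (N, H, W) target convention documented in ``add``; the sizes and the ignored index are arbitrary choices for the example:

```
# Sketch: per-class IoU and mean IoU on random predictions (shapes are arbitrary).
import torch
from metric.iou import IoU

num_classes = 12
metric = IoU(num_classes, ignore_index=num_classes - 1)  # e.g. ignore the unlabeled class

scores = torch.rand(4, num_classes, 360, 480)            # model outputs (N, K, H, W)
targets = torch.randint(0, num_classes, (4, 360, 480))   # ground truth (N, H, W)

metric.add(scores, targets)
per_class_iou, mean_iou = metric.value()
print(per_class_iou.shape, mean_iou)
```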
/args.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 |
4 | def get_arguments():
5 | """Defines command-line arguments, and parses them.
6 |
7 | """
8 | parser = ArgumentParser()
9 |
10 | # Execution mode
11 | parser.add_argument(
12 | "--mode",
13 | "-m",
14 | choices=['train', 'test', 'full'],
15 | default='train',
16 | help=("train: performs training and validation; test: tests the model "
17 | "found in \"--save-dir\" with name \"--name\" on \"--dataset\"; "
18 | "full: combines train and test modes. Default: train"))
19 | parser.add_argument(
20 | "--resume",
21 | action='store_true',
22 |         help=("The model found in \"--save-dir\" with filename "
23 |               "\"--name\" is loaded."))
24 |
25 | # Hyperparameters
26 | parser.add_argument(
27 | "--batch-size",
28 | "-b",
29 | type=int,
30 | default=10,
31 | help="The batch size. Default: 10")
32 | parser.add_argument(
33 | "--epochs",
34 | type=int,
35 | default=300,
36 | help="Number of training epochs. Default: 300")
37 | parser.add_argument(
38 | "--learning-rate",
39 | "-lr",
40 | type=float,
41 | default=5e-4,
42 | help="The learning rate. Default: 5e-4")
43 | parser.add_argument(
44 | "--lr-decay",
45 | type=float,
46 | default=0.1,
47 |         help="The learning rate decay factor. Default: 0.1")
48 | parser.add_argument(
49 | "--lr-decay-epochs",
50 | type=int,
51 | default=100,
52 | help="The number of epochs before adjusting the learning rate. "
53 | "Default: 100")
54 | parser.add_argument(
55 | "--weight-decay",
56 | "-wd",
57 | type=float,
58 | default=2e-4,
59 | help="L2 regularization factor. Default: 2e-4")
60 |
61 | # Dataset
62 | parser.add_argument(
63 | "--dataset",
64 | choices=['camvid', 'cityscapes'],
65 | default='camvid',
66 | help="Dataset to use. Default: camvid")
67 | parser.add_argument(
68 | "--dataset-dir",
69 | type=str,
70 | default="data/CamVid",
71 | help="Path to the root directory of the selected dataset. "
72 | "Default: data/CamVid")
73 | parser.add_argument(
74 | "--height",
75 | type=int,
76 | default=360,
77 | help="The image height. Default: 360")
78 | parser.add_argument(
79 | "--width",
80 | type=int,
81 | default=480,
82 | help="The image width. Default: 480")
83 | parser.add_argument(
84 | "--weighing",
85 | choices=['enet', 'mfb', 'none'],
86 |         default='enet',
87 | help="The class weighing technique to apply to the dataset. "
88 | "Default: enet")
89 | parser.add_argument(
90 | "--with-unlabeled",
91 | dest='ignore_unlabeled',
92 | action='store_false',
93 | help="The unlabeled class is not ignored.")
94 |
95 | # Settings
96 | parser.add_argument(
97 | "--workers",
98 | type=int,
99 | default=4,
100 | help="Number of subprocesses to use for data loading. Default: 4")
101 | parser.add_argument(
102 | "--print-step",
103 | action='store_true',
104 | help="Print loss every step")
105 | parser.add_argument(
106 | "--imshow-batch",
107 | action='store_true',
108 | help=("Displays batch images when loading the dataset and making "
109 | "predictions."))
110 | parser.add_argument(
111 | "--device",
112 | default='cuda',
113 | help="Device on which the network will be trained. Default: cuda")
114 |
115 | # Storage settings
116 | parser.add_argument(
117 | "--name",
118 | type=str,
119 | default='ENet',
120 | help="Name given to the model when saving. Default: ENet")
121 | parser.add_argument(
122 | "--save-dir",
123 | type=str,
124 | default='save',
125 | help="The directory where models are saved. Default: save")
126 |
127 | return parser.parse_args()
128 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | import os
6 |
7 |
8 | def batch_transform(batch, transform):
9 | """Applies a transform to a batch of samples.
10 |
11 | Keyword arguments:
12 |     - batch (``Tensor``): a batch of samples
13 | - transform (callable): A function/transform to apply to ``batch``
14 |
15 | """
16 |
17 | # Convert the single channel label to RGB in tensor form
18 | # 1. torch.unbind removes the 0-dimension of "labels" and returns a tuple of
19 | # all slices along that dimension
20 | # 2. the transform is applied to each slice
21 | transf_slices = [transform(tensor) for tensor in torch.unbind(batch)]
22 |
23 | return torch.stack(transf_slices)
24 |
25 |
26 | def imshow_batch(images, labels):
27 | """Displays two grids of images. The top grid displays ``images``
28 | and the bottom grid ``labels``
29 |
30 | Keyword arguments:
31 | - images (``Tensor``): a 4D mini-batch tensor of shape
32 | (B, C, H, W)
33 | - labels (``Tensor``): a 4D mini-batch tensor of shape
34 | (B, C, H, W)
35 |
36 | """
37 |
38 | # Make a grid with the images and labels and convert it to numpy
39 | images = torchvision.utils.make_grid(images).numpy()
40 | labels = torchvision.utils.make_grid(labels).numpy()
41 |
42 | fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 7))
43 | ax1.imshow(np.transpose(images, (1, 2, 0)))
44 | ax2.imshow(np.transpose(labels, (1, 2, 0)))
45 |
46 | plt.show()
47 |
48 |
49 | def save_checkpoint(model, optimizer, epoch, miou, args):
50 |     """Saves the model in a specified directory with a specified name.
51 |
52 | Keyword arguments:
53 | - model (``nn.Module``): The model to save.
54 | - optimizer (``torch.optim``): The optimizer state to save.
55 | - epoch (``int``): The current epoch for the model.
56 | - miou (``float``): The mean IoU obtained by the model.
57 | - args (``ArgumentParser``): An instance of ArgumentParser which contains
58 | the arguments used to train ``model``. The arguments are written to a text
59 | file in ``args.save_dir`` named "``args.name``_args.txt".
60 |
61 | """
62 | name = args.name
63 | save_dir = args.save_dir
64 |
65 | assert os.path.isdir(
66 | save_dir), "The directory \"{0}\" doesn't exist.".format(save_dir)
67 |
68 | # Save model
69 | model_path = os.path.join(save_dir, name)
70 | checkpoint = {
71 | 'epoch': epoch,
72 | 'miou': miou,
73 | 'state_dict': model.state_dict(),
74 | 'optimizer': optimizer.state_dict()
75 | }
76 | torch.save(checkpoint, model_path)
77 |
78 | # Save arguments
79 | summary_filename = os.path.join(save_dir, name + '_summary.txt')
80 | with open(summary_filename, 'w') as summary_file:
81 | sorted_args = sorted(vars(args))
82 | summary_file.write("ARGUMENTS\n")
83 | for arg in sorted_args:
84 | arg_str = "{0}: {1}\n".format(arg, getattr(args, arg))
85 | summary_file.write(arg_str)
86 |
87 | summary_file.write("\nBEST VALIDATION\n")
88 |         summary_file.write("Epoch: {0}\n".format(epoch))
89 |         summary_file.write("Mean IoU: {0}\n".format(miou))
90 |
91 |
92 | def load_checkpoint(model, optimizer, folder_dir, filename):
93 |     """Loads a previously saved model state given its directory and filename.
94 |
95 | Keyword arguments:
96 | - model (``nn.Module``): The stored model state is copied to this model
97 | instance.
98 | - optimizer (``torch.optim``): The stored optimizer state is copied to this
99 | optimizer instance.
100 | - folder_dir (``string``): The path to the folder where the saved model
101 | state is located.
102 | - filename (``string``): The model filename.
103 |
104 | Returns:
105 | The epoch, mean IoU, ``model``, and ``optimizer`` loaded from the
106 | checkpoint.
107 |
108 | """
109 | assert os.path.isdir(
110 | folder_dir), "The directory \"{0}\" doesn't exist.".format(folder_dir)
111 |
112 |     # Build the path to the saved model and check that it exists
113 | model_path = os.path.join(folder_dir, filename)
114 | assert os.path.isfile(
115 | model_path), "The model file \"{0}\" doesn't exist.".format(filename)
116 |
117 | # Load the stored model parameters to the model instance
118 | checkpoint = torch.load(model_path)
119 | model.load_state_dict(checkpoint['state_dict'])
120 | optimizer.load_state_dict(checkpoint['optimizer'])
121 | epoch = checkpoint['epoch']
122 | miou = checkpoint['miou']
123 |
124 | return model, optimizer, epoch, miou
125 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch-ENet
2 |
3 | PyTorch (v1.1.0) implementation of [*ENet: A Deep Neural Network Architecture for Real-Time Semantic Segmentation*](https://arxiv.org/abs/1606.02147), ported from the lua-torch implementation [ENet-training](https://github.com/e-lab/ENet-training) created by the authors.
4 |
5 | This implementation has been tested on the CamVid and Cityscapes datasets. Pre-trained versions of the model, trained on CamVid and Cityscapes, are available [here](https://github.com/davidtvs/PyTorch-ENet/tree/master/save).
6 |
7 | | Dataset | Classes <sup>1</sup> | Input resolution | Batch size | Epochs | Mean IoU (%) | GPU memory (GiB) | Training time (hours) <sup>2</sup> |
8 | | :------------------------------------------------------------------: | :------------------: | :--------------: | :--------: | :----: | :---------------: | :--------------: | :-------------------------------: |
9 | | [CamVid](http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/) | 11 | 480x360 | 10 | 300 | 52.13 <sup>3</sup> | 4.2 | 1 |
10 | | [Cityscapes](https://www.cityscapes-dataset.com/) | 19 | 1024x512 | 4 | 300 | 59.54 <sup>4</sup> | 5.4 | 20 |
11 |
12 | <sup>1</sup> When referring to the number of classes, the void/unlabeled class is always excluded.
13 | <sup>2</sup> These are just for reference. Implementation, datasets, and hardware changes can lead to very different results. Reference hardware: Nvidia GTX 1070 and an AMD Ryzen 5 3600 3.6GHz. You can also train for 100 epochs or so and get similar mean IoU (± 2%).
14 | <sup>3</sup> Test set.
15 | <sup>4</sup> Validation set.
16 |
17 | ## Installation
18 |
19 | ### Local pip
20 |
21 | 1. Python 3 and pip
22 | 2. Set up a virtual environment (optional, but recommended)
23 | 3. Install dependencies using pip: `pip install -r requirements.txt`
24 |
25 | ### Docker image
26 |
27 | 1. Build the image: `docker build -t enet .`
28 | 2. Run: `docker run -it --gpus all --ipc host enet`
29 |
30 | ## Usage
31 |
32 | Run [``main.py``](https://github.com/davidtvs/PyTorch-ENet/blob/master/main.py), the main script file used for training and/or testing the model. The following options are supported:
33 |
34 | ```
35 | python main.py [-h] [--mode {train,test,full}] [--resume]
36 | [--batch-size BATCH_SIZE] [--epochs EPOCHS]
37 | [--learning-rate LEARNING_RATE] [--lr-decay LR_DECAY]
38 | [--lr-decay-epochs LR_DECAY_EPOCHS]
39 | [--weight-decay WEIGHT_DECAY] [--dataset {camvid,cityscapes}]
40 | [--dataset-dir DATASET_DIR] [--height HEIGHT] [--width WIDTH]
41 | [--weighing {enet,mfb,none}] [--with-unlabeled]
42 | [--workers WORKERS] [--print-step] [--imshow-batch]
43 | [--device DEVICE] [--name NAME] [--save-dir SAVE_DIR]
44 | ```
45 |
46 | For help on the optional arguments run: ``python main.py -h``
47 |
48 |
49 | ### Examples: Training
50 |
51 | ```
52 | python main.py -m train --save-dir save/folder/ --name model_name --dataset name --dataset-dir path/root_directory/
53 | ```
54 |
55 |
56 | ### Examples: Resuming training
57 |
58 | ```
59 | python main.py -m train --resume --save-dir save/folder/ --name model_name --dataset name --dataset-dir path/root_directory/
60 | ```
61 |
62 |
63 | ### Examples: Testing
64 |
65 | ```
66 | python main.py -m test --save-dir save/folder/ --name model_name --dataset name --dataset-dir path/root_directory/
67 | ```
68 |
69 |
70 | ## Project structure
71 |
72 | ### Folders
73 |
74 | - [``data``](https://github.com/davidtvs/PyTorch-ENet/tree/master/data): Contains instructions on how to download the datasets and the code that handles data loading.
75 | - [``metric``](https://github.com/davidtvs/PyTorch-ENet/tree/master/metric): Evaluation-related metrics.
76 | - [``models``](https://github.com/davidtvs/PyTorch-ENet/tree/master/models): ENet model definition.
77 | - [``save``](https://github.com/davidtvs/PyTorch-ENet/tree/master/save): By default, ``main.py`` will save models in this folder. The pre-trained models can also be found here.
78 |
79 | ### Files
80 |
81 | - [``args.py``](https://github.com/davidtvs/PyTorch-ENet/blob/master/args.py): Contains all command-line options.
82 | - [``main.py``](https://github.com/davidtvs/PyTorch-ENet/blob/master/main.py): Main script file used for training and/or testing the model.
83 | - [``test.py``](https://github.com/davidtvs/PyTorch-ENet/blob/master/test.py): Defines the ``Test`` class which is responsible for testing the model.
84 | - [``train.py``](https://github.com/davidtvs/PyTorch-ENet/blob/master/train.py): Defines the ``Train`` class which is responsible for training the model.
85 | - [``transforms.py``](https://github.com/davidtvs/PyTorch-ENet/blob/master/transforms.py): Defines image transformations to convert an RGB image encoding classes to a ``torch.LongTensor`` and vice versa.
86 |
--------------------------------------------------------------------------------
/data/camvid.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import OrderedDict
3 | import torch.utils.data as data
4 | from . import utils
5 |
6 |
7 | class CamVid(data.Dataset):
8 | """CamVid dataset loader where the dataset is arranged as in
9 | https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid.
10 |
11 |
12 | Keyword arguments:
13 | - root_dir (``string``): Root directory path.
14 | - mode (``string``): The type of dataset: 'train' for training set, 'val'
15 | for validation set, and 'test' for test set.
16 | - transform (``callable``, optional): A function/transform that takes in
17 |     a PIL image and returns a transformed version. Default: None.
18 | - label_transform (``callable``, optional): A function/transform that takes
19 | in the target and transforms it. Default: None.
20 | - loader (``callable``, optional): A function to load an image given its
21 |     path. By default, ``utils.pil_loader`` is used.
22 |
23 | """
24 | # Training dataset root folders
25 | train_folder = 'train'
26 | train_lbl_folder = 'trainannot'
27 |
28 | # Validation dataset root folders
29 | val_folder = 'val'
30 | val_lbl_folder = 'valannot'
31 |
32 | # Test dataset root folders
33 | test_folder = 'test'
34 | test_lbl_folder = 'testannot'
35 |
36 | # Images extension
37 | img_extension = '.png'
38 |
39 | # Default encoding for pixel value, class name, and class color
40 | color_encoding = OrderedDict([
41 | ('sky', (128, 128, 128)),
42 | ('building', (128, 0, 0)),
43 | ('pole', (192, 192, 128)),
44 | ('road_marking', (255, 69, 0)),
45 | ('road', (128, 64, 128)),
46 | ('pavement', (60, 40, 222)),
47 | ('tree', (128, 128, 0)),
48 | ('sign_symbol', (192, 128, 128)),
49 | ('fence', (64, 64, 128)),
50 | ('car', (64, 0, 128)),
51 | ('pedestrian', (64, 64, 0)),
52 | ('bicyclist', (0, 128, 192)),
53 | ('unlabeled', (0, 0, 0))
54 | ])
55 |
56 | def __init__(self,
57 | root_dir,
58 | mode='train',
59 | transform=None,
60 | label_transform=None,
61 | loader=utils.pil_loader):
62 | self.root_dir = root_dir
63 | self.mode = mode
64 | self.transform = transform
65 | self.label_transform = label_transform
66 | self.loader = loader
67 |
68 | if self.mode.lower() == 'train':
69 | # Get the training data and labels filepaths
70 | self.train_data = utils.get_files(
71 | os.path.join(root_dir, self.train_folder),
72 | extension_filter=self.img_extension)
73 |
74 | self.train_labels = utils.get_files(
75 | os.path.join(root_dir, self.train_lbl_folder),
76 | extension_filter=self.img_extension)
77 | elif self.mode.lower() == 'val':
78 | # Get the validation data and labels filepaths
79 | self.val_data = utils.get_files(
80 | os.path.join(root_dir, self.val_folder),
81 | extension_filter=self.img_extension)
82 |
83 | self.val_labels = utils.get_files(
84 | os.path.join(root_dir, self.val_lbl_folder),
85 | extension_filter=self.img_extension)
86 | elif self.mode.lower() == 'test':
87 | # Get the test data and labels filepaths
88 | self.test_data = utils.get_files(
89 | os.path.join(root_dir, self.test_folder),
90 | extension_filter=self.img_extension)
91 |
92 | self.test_labels = utils.get_files(
93 | os.path.join(root_dir, self.test_lbl_folder),
94 | extension_filter=self.img_extension)
95 | else:
96 | raise RuntimeError("Unexpected dataset mode. "
97 | "Supported modes are: train, val and test")
98 |
99 | def __getitem__(self, index):
100 | """
101 | Args:
102 | - index (``int``): index of the item in the dataset
103 |
104 | Returns:
105 | A tuple of ``PIL.Image`` (image, label) where label is the ground-truth
106 | of the image.
107 |
108 | """
109 | if self.mode.lower() == 'train':
110 | data_path, label_path = self.train_data[index], self.train_labels[
111 | index]
112 | elif self.mode.lower() == 'val':
113 | data_path, label_path = self.val_data[index], self.val_labels[
114 | index]
115 | elif self.mode.lower() == 'test':
116 | data_path, label_path = self.test_data[index], self.test_labels[
117 | index]
118 | else:
119 | raise RuntimeError("Unexpected dataset mode. "
120 | "Supported modes are: train, val and test")
121 |
122 | img, label = self.loader(data_path, label_path)
123 |
124 | if self.transform is not None:
125 | img = self.transform(img)
126 |
127 | if self.label_transform is not None:
128 | label = self.label_transform(label)
129 |
130 | return img, label
131 |
132 | def __len__(self):
133 | """Returns the length of the dataset."""
134 | if self.mode.lower() == 'train':
135 | return len(self.train_data)
136 | elif self.mode.lower() == 'val':
137 | return len(self.val_data)
138 | elif self.mode.lower() == 'test':
139 | return len(self.test_data)
140 | else:
141 | raise RuntimeError("Unexpected dataset mode. "
142 | "Supported modes are: train, val and test")
143 |
--------------------------------------------------------------------------------
/data/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import numpy as np
4 |
5 |
6 | def get_files(folder, name_filter=None, extension_filter=None):
7 | """Helper function that returns the list of files in a specified folder
8 | with a specified extension.
9 |
10 | Keyword arguments:
11 | - folder (``string``): The path to a folder.
12 |     - name_filter (``string``, optional): The returned files must contain
13 | this substring in their filename. Default: None; files are not filtered.
14 | - extension_filter (``string``, optional): The desired file extension.
15 | Default: None; files are not filtered
16 |
17 | """
18 | if not os.path.isdir(folder):
19 | raise RuntimeError("\"{0}\" is not a folder.".format(folder))
20 |
21 | # Filename filter: if not specified don't filter (condition always true);
22 | # otherwise, use a lambda expression to filter out files that do not
23 | # contain "name_filter"
24 | if name_filter is None:
25 | # This looks hackish...there is probably a better way
26 | name_cond = lambda filename: True
27 | else:
28 | name_cond = lambda filename: name_filter in filename
29 |
30 | # Extension filter: if not specified don't filter (condition always true);
31 | # otherwise, use a lambda expression to filter out files whose extension
32 | # is not "extension_filter"
33 | if extension_filter is None:
34 | # This looks hackish...there is probably a better way
35 | ext_cond = lambda filename: True
36 | else:
37 | ext_cond = lambda filename: filename.endswith(extension_filter)
38 |
39 | filtered_files = []
40 |
41 | # Explore the directory tree to get files that contain "name_filter" and
42 | # with extension "extension_filter"
43 | for path, _, files in os.walk(folder):
44 | files.sort()
45 | for file in files:
46 | if name_cond(file) and ext_cond(file):
47 | full_path = os.path.join(path, file)
48 | filtered_files.append(full_path)
49 |
50 | return filtered_files
51 |
52 |
53 | def pil_loader(data_path, label_path):
54 | """Loads a sample and label image given their path as PIL images.
55 |
56 | Keyword arguments:
57 | - data_path (``string``): The filepath to the image.
58 | - label_path (``string``): The filepath to the ground-truth image.
59 |
60 | Returns the image and the label as PIL images.
61 |
62 | """
63 | data = Image.open(data_path)
64 | label = Image.open(label_path)
65 |
66 | return data, label
67 |
68 |
69 | def remap(image, old_values, new_values):
70 | assert isinstance(image, Image.Image) or isinstance(
71 | image, np.ndarray), "image must be of type PIL.Image or numpy.ndarray"
72 | assert type(new_values) is tuple, "new_values must be of type tuple"
73 | assert type(old_values) is tuple, "old_values must be of type tuple"
74 | assert len(new_values) == len(
75 | old_values), "new_values and old_values must have the same length"
76 |
77 | # If image is a PIL.Image convert it to a numpy array
78 | if isinstance(image, Image.Image):
79 | image = np.array(image)
80 |
81 | # Replace old values by the new ones
82 | tmp = np.zeros_like(image)
83 | for old, new in zip(old_values, new_values):
84 | # Since tmp is already initialized as zeros we can skip new values
85 | # equal to 0
86 | if new != 0:
87 | tmp[image == old] = new
88 |
89 | return Image.fromarray(tmp)
90 |
91 |
92 | def enet_weighing(dataloader, num_classes, c=1.02):
93 | """Computes class weights as described in the ENet paper:
94 |
95 | w_class = 1 / (ln(c + p_class)),
96 |
97 | where c is usually 1.02 and p_class is the propensity score of that
98 | class:
99 |
100 | propensity_score = freq_class / total_pixels.
101 |
102 | References: https://arxiv.org/abs/1606.02147
103 |
104 | Keyword arguments:
105 | - dataloader (``data.Dataloader``): A data loader to iterate over the
106 | dataset.
107 | - num_classes (``int``): The number of classes.
108 |     - c (``float``, optional): An additional hyper-parameter which restricts
109 | the interval of values for the weights. Default: 1.02.
110 |
111 | """
112 | class_count = 0
113 | total = 0
114 | for _, label in dataloader:
115 | label = label.cpu().numpy()
116 |
117 | # Flatten label
118 | flat_label = label.flatten()
119 |
120 | # Sum up the number of pixels of each class and the total pixel
121 | # counts for each label
122 | class_count += np.bincount(flat_label, minlength=num_classes)
123 | total += flat_label.size
124 |
125 | # Compute propensity score and then the weights for each class
126 | propensity_score = class_count / total
127 | class_weights = 1 / (np.log(c + propensity_score))
128 |
129 | return class_weights
130 |
131 |
132 | def median_freq_balancing(dataloader, num_classes):
133 | """Computes class weights using median frequency balancing as described
134 | in https://arxiv.org/abs/1411.4734:
135 |
136 | w_class = median_freq / freq_class,
137 |
138 | where freq_class is the number of pixels of a given class divided by
139 | the total number of pixels in images where that class is present, and
140 | median_freq is the median of freq_class.
141 |
142 | Keyword arguments:
143 | - dataloader (``data.Dataloader``): A data loader to iterate over the
144 |     dataset whose weights are going to be computed.
145 |     - num_classes (``int``): The number of classes.
146 |
147 |
148 | """
149 | class_count = 0
150 | total = 0
151 | for _, label in dataloader:
152 | label = label.cpu().numpy()
153 |
154 | # Flatten label
155 | flat_label = label.flatten()
156 |
157 | # Sum up the class frequencies
158 | bincount = np.bincount(flat_label, minlength=num_classes)
159 |
160 |         # Create a mask of classes that exist in the label
161 | mask = bincount > 0
162 | # Multiply the mask by the pixel count. The resulting array has
163 | # one element for each class. The value is either 0 (if the class
164 | # does not exist in the label) or equal to the pixel count (if
165 | # the class exists in the label)
166 | total += mask * flat_label.size
167 |
168 | # Sum up the number of pixels found for each class
169 | class_count += bincount
170 |
171 | # Compute the frequency and its median
172 | freq = class_count / total
173 | med = np.median(freq)
174 |
175 | return med / freq
176 |
--------------------------------------------------------------------------------
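As in ``main.py``, the class weights computed by ``enet_weighing`` (or ``median_freq_balancing``) are meant to be handed to the loss. A small sketch, assuming ``train_loader`` and ``device`` already exist and using the 12 CamVid classes (11 plus unlabeled); zeroing the unlabeled weight is one way to ignore that class and is an assumption of this example, not necessarily exactly what ``main.py`` does:

```
# Sketch: ENet-style class weights -> weighted cross-entropy loss.
# train_loader and device are assumed to exist already.
import torch
import torch.nn as nn
from data.utils import enet_weighing

num_classes = 12
class_weights = enet_weighing(train_loader, num_classes)       # numpy array, length 12
class_weights = torch.from_numpy(class_weights).float().to(device)
class_weights[-1] = 0                                          # unlabeled is last in the encoding
criterion = nn.CrossEntropyLoss(weight=class_weights)
```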
/data/cityscapes.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import OrderedDict
3 | import torch.utils.data as data
4 | from . import utils
5 |
6 |
7 | class Cityscapes(data.Dataset):
8 | """Cityscapes dataset https://www.cityscapes-dataset.com/.
9 |
10 | Keyword arguments:
11 | - root_dir (``string``): Root directory path.
12 | - mode (``string``): The type of dataset: 'train' for training set, 'val'
13 | for validation set, and 'test' for test set.
14 | - transform (``callable``, optional): A function/transform that takes in
15 |     a PIL image and returns a transformed version. Default: None.
16 | - label_transform (``callable``, optional): A function/transform that takes
17 | in the target and transforms it. Default: None.
18 | - loader (``callable``, optional): A function to load an image given its
19 |     path. By default, ``utils.pil_loader`` is used.
20 |
21 | """
22 | # Training dataset root folders
23 | train_folder = "leftImg8bit_trainvaltest/leftImg8bit/train"
24 | train_lbl_folder = "gtFine_trainvaltest/gtFine/train"
25 |
26 | # Validation dataset root folders
27 | val_folder = "leftImg8bit_trainvaltest/leftImg8bit/val"
28 | val_lbl_folder = "gtFine_trainvaltest/gtFine/val"
29 |
30 | # Test dataset root folders
31 | test_folder = "leftImg8bit_trainvaltest/leftImg8bit/test"
32 | test_lbl_folder = "gtFine_trainvaltest/gtFine/test"
33 |
34 | # Filters to find the images
35 | img_extension = '.png'
36 | lbl_name_filter = 'labelIds'
37 |
38 | # The values associated with the 35 classes
39 | full_classes = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
40 | 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
41 | 32, 33, -1)
42 | # The values above are remapped to the following
43 | new_classes = (0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 5, 0, 0, 0, 6, 0, 7,
44 | 8, 9, 10, 11, 12, 13, 14, 15, 16, 0, 0, 17, 18, 19, 0)
45 |
46 | # Default encoding for pixel value, class name, and class color
47 | color_encoding = OrderedDict([
48 | ('unlabeled', (0, 0, 0)),
49 | ('road', (128, 64, 128)),
50 | ('sidewalk', (244, 35, 232)),
51 | ('building', (70, 70, 70)),
52 | ('wall', (102, 102, 156)),
53 | ('fence', (190, 153, 153)),
54 | ('pole', (153, 153, 153)),
55 | ('traffic_light', (250, 170, 30)),
56 | ('traffic_sign', (220, 220, 0)),
57 | ('vegetation', (107, 142, 35)),
58 | ('terrain', (152, 251, 152)),
59 | ('sky', (70, 130, 180)),
60 | ('person', (220, 20, 60)),
61 | ('rider', (255, 0, 0)),
62 | ('car', (0, 0, 142)),
63 | ('truck', (0, 0, 70)),
64 | ('bus', (0, 60, 100)),
65 | ('train', (0, 80, 100)),
66 | ('motorcycle', (0, 0, 230)),
67 | ('bicycle', (119, 11, 32))
68 | ])
69 |
70 | def __init__(self,
71 | root_dir,
72 | mode='train',
73 | transform=None,
74 | label_transform=None,
75 | loader=utils.pil_loader):
76 | self.root_dir = root_dir
77 | self.mode = mode
78 | self.transform = transform
79 | self.label_transform = label_transform
80 | self.loader = loader
81 |
82 | if self.mode.lower() == 'train':
83 | # Get the training data and labels filepaths
84 | self.train_data = utils.get_files(
85 | os.path.join(root_dir, self.train_folder),
86 | extension_filter=self.img_extension)
87 |
88 | self.train_labels = utils.get_files(
89 | os.path.join(root_dir, self.train_lbl_folder),
90 | name_filter=self.lbl_name_filter,
91 | extension_filter=self.img_extension)
92 | elif self.mode.lower() == 'val':
93 | # Get the validation data and labels filepaths
94 | self.val_data = utils.get_files(
95 | os.path.join(root_dir, self.val_folder),
96 | extension_filter=self.img_extension)
97 |
98 | self.val_labels = utils.get_files(
99 | os.path.join(root_dir, self.val_lbl_folder),
100 | name_filter=self.lbl_name_filter,
101 | extension_filter=self.img_extension)
102 | elif self.mode.lower() == 'test':
103 | # Get the test data and labels filepaths
104 | self.test_data = utils.get_files(
105 | os.path.join(root_dir, self.test_folder),
106 | extension_filter=self.img_extension)
107 |
108 | self.test_labels = utils.get_files(
109 | os.path.join(root_dir, self.test_lbl_folder),
110 | name_filter=self.lbl_name_filter,
111 | extension_filter=self.img_extension)
112 | else:
113 | raise RuntimeError("Unexpected dataset mode. "
114 | "Supported modes are: train, val and test")
115 |
116 | def __getitem__(self, index):
117 | """
118 | Args:
119 | - index (``int``): index of the item in the dataset
120 |
121 | Returns:
122 | A tuple of ``PIL.Image`` (image, label) where label is the ground-truth
123 | of the image.
124 |
125 | """
126 | if self.mode.lower() == 'train':
127 | data_path, label_path = self.train_data[index], self.train_labels[
128 | index]
129 | elif self.mode.lower() == 'val':
130 | data_path, label_path = self.val_data[index], self.val_labels[
131 | index]
132 | elif self.mode.lower() == 'test':
133 | data_path, label_path = self.test_data[index], self.test_labels[
134 | index]
135 | else:
136 | raise RuntimeError("Unexpected dataset mode. "
137 | "Supported modes are: train, val and test")
138 |
139 | img, label = self.loader(data_path, label_path)
140 |
141 | # Remap class labels
142 | label = utils.remap(label, self.full_classes, self.new_classes)
143 |
144 | if self.transform is not None:
145 | img = self.transform(img)
146 |
147 | if self.label_transform is not None:
148 | label = self.label_transform(label)
149 |
150 | return img, label
151 |
152 | def __len__(self):
153 | """Returns the length of the dataset."""
154 | if self.mode.lower() == 'train':
155 | return len(self.train_data)
156 | elif self.mode.lower() == 'val':
157 | return len(self.val_data)
158 | elif self.mode.lower() == 'test':
159 | return len(self.test_data)
160 | else:
161 | raise RuntimeError("Unexpected dataset mode. "
162 | "Supported modes are: train, val and test")
163 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | import torch.optim.lr_scheduler as lr_scheduler
7 | import torch.utils.data as data
8 | import torchvision.transforms as transforms
9 |
10 | from PIL import Image
11 |
12 | import transforms as ext_transforms
13 | from models.enet import ENet
14 | from train import Train
15 | from test import Test
16 | from metric.iou import IoU
17 | from args import get_arguments
18 | from data.utils import enet_weighing, median_freq_balancing
19 | import utils
20 |
21 | # Get the arguments
22 | args = get_arguments()
23 |
24 | device = torch.device(args.device)
25 |
26 |
27 | def load_dataset(dataset):
28 | print("\nLoading dataset...\n")
29 |
30 | print("Selected dataset:", args.dataset)
31 | print("Dataset directory:", args.dataset_dir)
32 | print("Save directory:", args.save_dir)
33 |
34 | image_transform = transforms.Compose(
35 | [transforms.Resize((args.height, args.width)),
36 | transforms.ToTensor()])
37 |
38 | label_transform = transforms.Compose([
39 | transforms.Resize((args.height, args.width), Image.NEAREST),
40 | ext_transforms.PILToLongTensor()
41 | ])
42 |
43 | # Get selected dataset
44 | # Load the training set as tensors
45 | train_set = dataset(
46 | args.dataset_dir,
47 | transform=image_transform,
48 | label_transform=label_transform)
49 | train_loader = data.DataLoader(
50 | train_set,
51 | batch_size=args.batch_size,
52 | shuffle=True,
53 | num_workers=args.workers)
54 |
55 | # Load the validation set as tensors
56 | val_set = dataset(
57 | args.dataset_dir,
58 | mode='val',
59 | transform=image_transform,
60 | label_transform=label_transform)
61 | val_loader = data.DataLoader(
62 | val_set,
63 | batch_size=args.batch_size,
64 | shuffle=False,
65 | num_workers=args.workers)
66 |
67 | # Load the test set as tensors
68 | test_set = dataset(
69 | args.dataset_dir,
70 | mode='test',
71 | transform=image_transform,
72 | label_transform=label_transform)
73 | test_loader = data.DataLoader(
74 | test_set,
75 | batch_size=args.batch_size,
76 | shuffle=False,
77 | num_workers=args.workers)
78 |
79 |     # Get encoding between pixel values in label images and RGB colors
80 | class_encoding = train_set.color_encoding
81 |
82 | # Remove the road_marking class from the CamVid dataset as it's merged
83 | # with the road class
84 | if args.dataset.lower() == 'camvid':
85 | del class_encoding['road_marking']
86 |
87 | # Get number of classes to predict
88 | num_classes = len(class_encoding)
89 |
90 | # Print information for debugging
91 | print("Number of classes to predict:", num_classes)
92 | print("Train dataset size:", len(train_set))
93 | print("Validation dataset size:", len(val_set))
94 |
95 | # Get a batch of samples to display
96 | if args.mode.lower() == 'test':
97 |         images, labels = next(iter(test_loader))
98 | else:
99 |         images, labels = next(iter(train_loader))
100 | print("Image size:", images.size())
101 | print("Label size:", labels.size())
102 | print("Class-color encoding:", class_encoding)
103 |
104 | # Show a batch of samples and labels
105 | if args.imshow_batch:
106 | print("Close the figure window to continue...")
107 | label_to_rgb = transforms.Compose([
108 | ext_transforms.LongTensorToRGBPIL(class_encoding),
109 | transforms.ToTensor()
110 | ])
111 | color_labels = utils.batch_transform(labels, label_to_rgb)
112 | utils.imshow_batch(images, color_labels)
113 |
114 | # Get class weights from the selected weighing technique
115 | print("\nWeighing technique:", args.weighing)
116 | print("Computing class weights...")
117 | print("(this can take a while depending on the dataset size)")
118 | class_weights = 0
119 | if args.weighing.lower() == 'enet':
120 | class_weights = enet_weighing(train_loader, num_classes)
121 | elif args.weighing.lower() == 'mfb':
122 | class_weights = median_freq_balancing(train_loader, num_classes)
123 | else:
124 | class_weights = None
125 |
126 | if class_weights is not None:
127 | class_weights = torch.from_numpy(class_weights).float().to(device)
128 | # Set the weight of the unlabeled class to 0
129 | if args.ignore_unlabeled:
130 | ignore_index = list(class_encoding).index('unlabeled')
131 | class_weights[ignore_index] = 0
132 |
133 | print("Class weights:", class_weights)
134 |
135 | return (train_loader, val_loader,
136 | test_loader), class_weights, class_encoding
137 |
138 |
139 | def train(train_loader, val_loader, class_weights, class_encoding):
140 | print("\nTraining...\n")
141 |
142 | num_classes = len(class_encoding)
143 |
144 |     # Initialize ENet
145 | model = ENet(num_classes).to(device)
146 | # Check if the network architecture is correct
147 | print(model)
148 |
149 |     # We are going to use the CrossEntropyLoss loss function as it's the most
150 |     # frequently used loss in multi-class classification problems, which fits
151 |     # this task. This criterion combines LogSoftmax and NLLLoss.
152 | criterion = nn.CrossEntropyLoss(weight=class_weights)
153 |
154 | # ENet authors used Adam as the optimizer
155 | optimizer = optim.Adam(
156 | model.parameters(),
157 | lr=args.learning_rate,
158 | weight_decay=args.weight_decay)
159 |
160 | # Learning rate decay scheduler
161 | lr_updater = lr_scheduler.StepLR(optimizer, args.lr_decay_epochs,
162 | args.lr_decay)
163 |
164 | # Evaluation metric
165 | if args.ignore_unlabeled:
166 | ignore_index = list(class_encoding).index('unlabeled')
167 | else:
168 | ignore_index = None
169 | metric = IoU(num_classes, ignore_index=ignore_index)
170 |
171 | # Optionally resume from a checkpoint
172 | if args.resume:
173 | model, optimizer, start_epoch, best_miou = utils.load_checkpoint(
174 | model, optimizer, args.save_dir, args.name)
175 | print("Resuming from model: Start epoch = {0} "
176 | "| Best mean IoU = {1:.4f}".format(start_epoch, best_miou))
177 | else:
178 | start_epoch = 0
179 | best_miou = 0
180 |
181 | # Start Training
182 | print()
183 | train = Train(model, train_loader, optimizer, criterion, metric, device)
184 | val = Test(model, val_loader, criterion, metric, device)
185 | for epoch in range(start_epoch, args.epochs):
186 | print(">>>> [Epoch: {0:d}] Training".format(epoch))
187 |
188 | epoch_loss, (iou, miou) = train.run_epoch(args.print_step)
189 | lr_updater.step()
190 |
191 | print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".
192 | format(epoch, epoch_loss, miou))
193 |
194 | if (epoch + 1) % 10 == 0 or epoch + 1 == args.epochs:
195 | print(">>>> [Epoch: {0:d}] Validation".format(epoch))
196 |
197 | loss, (iou, miou) = val.run_epoch(args.print_step)
198 |
199 | print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".
200 | format(epoch, loss, miou))
201 |
202 | # Print per class IoU on last epoch or if best iou
203 | if epoch + 1 == args.epochs or miou > best_miou:
204 | for key, class_iou in zip(class_encoding.keys(), iou):
205 | print("{0}: {1:.4f}".format(key, class_iou))
206 |
207 | # Save the model if it's the best thus far
208 | if miou > best_miou:
209 | print("\nBest model thus far. Saving...\n")
210 | best_miou = miou
211 | utils.save_checkpoint(model, optimizer, epoch + 1, best_miou,
212 | args)
213 |
214 | return model
215 |
216 |
217 | def test(model, test_loader, class_weights, class_encoding):
218 | print("\nTesting...\n")
219 |
220 | num_classes = len(class_encoding)
221 |
222 |     # We are going to use the CrossEntropyLoss loss function as it's the most
223 |     # frequently used loss in multi-class classification problems, which fits
224 |     # this task. This criterion combines LogSoftmax and NLLLoss.
225 | criterion = nn.CrossEntropyLoss(weight=class_weights)
226 |
227 | # Evaluation metric
228 | if args.ignore_unlabeled:
229 | ignore_index = list(class_encoding).index('unlabeled')
230 | else:
231 | ignore_index = None
232 | metric = IoU(num_classes, ignore_index=ignore_index)
233 |
234 | # Test the trained model on the test set
235 | test = Test(model, test_loader, criterion, metric, device)
236 |
237 | print(">>>> Running test dataset")
238 |
239 | loss, (iou, miou) = test.run_epoch(args.print_step)
240 | class_iou = dict(zip(class_encoding.keys(), iou))
241 |
242 | print(">>>> Avg. loss: {0:.4f} | Mean IoU: {1:.4f}".format(loss, miou))
243 |
244 | # Print per class IoU
245 | for key, class_iou in zip(class_encoding.keys(), iou):
246 | print("{0}: {1:.4f}".format(key, class_iou))
247 |
248 | # Show a batch of samples and labels
249 | if args.imshow_batch:
250 | print("A batch of predictions from the test set...")
251 |         images, _ = next(iter(test_loader))
252 | predict(model, images, class_encoding)
253 |
254 |
255 | def predict(model, images, class_encoding):
256 | images = images.to(device)
257 |
258 | # Make predictions!
259 | model.eval()
260 | with torch.no_grad():
261 | predictions = model(images)
262 |
263 |     # The network output has "num_classes" channels, one score per class.
264 |     # Convert it to per-pixel class indices by taking the channel-wise argmax
265 | _, predictions = torch.max(predictions.data, 1)
266 |
267 | label_to_rgb = transforms.Compose([
268 | ext_transforms.LongTensorToRGBPIL(class_encoding),
269 | transforms.ToTensor()
270 | ])
271 | color_predictions = utils.batch_transform(predictions.cpu(), label_to_rgb)
272 | utils.imshow_batch(images.data.cpu(), color_predictions)
273 |
274 |
275 | # Run only if this module is being run directly
276 | if __name__ == '__main__':
277 |
278 | # Fail fast if the dataset directory doesn't exist
279 | assert os.path.isdir(
280 | args.dataset_dir), "The directory \"{0}\" doesn't exist.".format(
281 | args.dataset_dir)
282 |
283 | # Fail fast if the saving directory doesn't exist
284 | assert os.path.isdir(
285 | args.save_dir), "The directory \"{0}\" doesn't exist.".format(
286 | args.save_dir)
287 |
288 | # Import the requested dataset
289 | if args.dataset.lower() == 'camvid':
290 | from data import CamVid as dataset
291 | elif args.dataset.lower() == 'cityscapes':
292 | from data import Cityscapes as dataset
293 | else:
294 | # Should never happen...but just in case it does
295 | raise RuntimeError("\"{0}\" is not a supported dataset.".format(
296 | args.dataset))
297 |
298 | loaders, w_class, class_encoding = load_dataset(dataset)
299 | train_loader, val_loader, test_loader = loaders
300 |
301 | if args.mode.lower() in {'train', 'full'}:
302 | model = train(train_loader, val_loader, w_class, class_encoding)
303 |
304 | if args.mode.lower() in {'test', 'full'}:
305 | if args.mode.lower() == 'test':
306 |             # Initialize a new ENet model
307 | num_classes = len(class_encoding)
308 | model = ENet(num_classes).to(device)
309 |
310 |         # Initialize an optimizer just so we can retrieve the model from the
311 | # checkpoint
312 | optimizer = optim.Adam(model.parameters())
313 |
314 |         # Load the previously saved model state to the ENet model
315 | model = utils.load_checkpoint(model, optimizer, args.save_dir,
316 | args.name)[0]
317 |
318 | if args.mode.lower() == 'test':
319 | print(model)
320 |
321 | test(model, test_loader, w_class, class_encoding)
322 |
--------------------------------------------------------------------------------
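
A minimal sketch of loading a saved checkpoint and running inference outside of main.py. The checkpoint directory, checkpoint name, class count, and input size are assumptions for illustration; utils.load_checkpoint is called with the same signature used in main.py, and device handling may need adjustment depending on where the checkpoint was saved.

    import torch
    import torch.optim as optim

    import utils
    from models.enet import ENet

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    num_classes = 12  # assumed class count of the saved model
    model = ENet(num_classes).to(device)
    # An optimizer is only needed so load_checkpoint can restore its state
    optimizer = optim.Adam(model.parameters())
    model = utils.load_checkpoint(model, optimizer, 'save/ENet_CamVid', 'ENet')[0]

    model.eval()
    with torch.no_grad():
        batch = torch.randn(1, 3, 360, 480, device=device)  # stand-in for real images
        logits = model(batch)               # [1, num_classes, 360, 480]
        predictions = logits.argmax(dim=1)  # per-pixel class indices
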
/models/enet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 |
4 |
5 | class InitialBlock(nn.Module):
6 | """The initial block is composed of two branches:
7 | 1. a main branch which performs a regular convolution with stride 2;
8 | 2. an extension branch which performs max-pooling.
9 |
10 | Doing both operations in parallel and concatenating their results
11 | allows for efficient downsampling and expansion. The main branch
12 | outputs 13 feature maps while the extension branch outputs 3, for a
13 | total of 16 feature maps after concatenation.
14 |
15 | Keyword arguments:
16 | - in_channels (int): the number of input channels.
17 |     - out_channels (int): the number of output channels.
22 | - bias (bool, optional): Adds a learnable bias to the output if
23 | ``True``. Default: False.
24 | - relu (bool, optional): When ``True`` ReLU is used as the activation
25 | function; otherwise, PReLU is used. Default: True.
26 |
27 | """
28 |
29 | def __init__(self,
30 | in_channels,
31 | out_channels,
32 | bias=False,
33 | relu=True):
34 | super().__init__()
35 |
36 | if relu:
37 | activation = nn.ReLU
38 | else:
39 | activation = nn.PReLU
40 |
41 | # Main branch - As stated above the number of output channels for this
42 | # branch is the total minus 3, since the remaining channels come from
43 | # the extension branch
44 | self.main_branch = nn.Conv2d(
45 | in_channels,
46 | out_channels - 3,
47 | kernel_size=3,
48 | stride=2,
49 | padding=1,
50 | bias=bias)
51 |
52 | # Extension branch
53 | self.ext_branch = nn.MaxPool2d(3, stride=2, padding=1)
54 |
55 | # Initialize batch normalization to be used after concatenation
56 | self.batch_norm = nn.BatchNorm2d(out_channels)
57 |
58 |         # Activation (ReLU or PReLU) to apply after concatenating the branches
59 | self.out_activation = activation()
60 |
61 | def forward(self, x):
62 | main = self.main_branch(x)
63 | ext = self.ext_branch(x)
64 |
65 | # Concatenate branches
66 | out = torch.cat((main, ext), 1)
67 |
68 | # Apply batch normalization
69 | out = self.batch_norm(out)
70 |
71 | return self.out_activation(out)
72 |
73 |
74 | class RegularBottleneck(nn.Module):
75 | """Regular bottlenecks are the main building block of ENet.
76 | Main branch:
77 | 1. Shortcut connection.
78 |
79 | Extension branch:
80 | 1. 1x1 convolution which decreases the number of channels by
81 | ``internal_ratio``, also called a projection;
82 | 2. regular, dilated or asymmetric convolution;
83 | 3. 1x1 convolution which increases the number of channels back to
84 | ``channels``, also called an expansion;
85 | 4. dropout as a regularizer.
86 |
87 | Keyword arguments:
88 | - channels (int): the number of input and output channels.
89 | - internal_ratio (int, optional): a scale factor applied to
90 | ``channels`` used to compute the number of
91 |     channels after the projection. E.g. given ``channels`` equal to 128 and
92 | internal_ratio equal to 2 the number of channels after the projection
93 | is 64. Default: 4.
94 | - kernel_size (int, optional): the kernel size of the filters used in
95 | the convolution layer described above in item 2 of the extension
96 | branch. Default: 3.
97 | - padding (int, optional): zero-padding added to both sides of the
98 | input. Default: 0.
99 | - dilation (int, optional): spacing between kernel elements for the
100 | convolution described in item 2 of the extension branch. Default: 1.
101 |     - asymmetric (bool, optional): flags if the convolution described in
102 | item 2 of the extension branch is asymmetric or not. Default: False.
103 | - dropout_prob (float, optional): probability of an element to be
104 | zeroed. Default: 0 (no dropout).
105 | - bias (bool, optional): Adds a learnable bias to the output if
106 | ``True``. Default: False.
107 | - relu (bool, optional): When ``True`` ReLU is used as the activation
108 | function; otherwise, PReLU is used. Default: True.
109 |
110 | """
111 |
112 | def __init__(self,
113 | channels,
114 | internal_ratio=4,
115 | kernel_size=3,
116 | padding=0,
117 | dilation=1,
118 | asymmetric=False,
119 | dropout_prob=0,
120 | bias=False,
121 | relu=True):
122 | super().__init__()
123 |
124 |         # Check if the internal_ratio parameter is within the expected range
125 | # [1, channels]
126 | if internal_ratio <= 1 or internal_ratio > channels:
127 | raise RuntimeError("Value out of range. Expected value in the "
128 |                                "interval [1, {0}], got internal_ratio={1}."
129 | .format(channels, internal_ratio))
130 |
131 | internal_channels = channels // internal_ratio
132 |
133 | if relu:
134 | activation = nn.ReLU
135 | else:
136 | activation = nn.PReLU
137 |
138 | # Main branch - shortcut connection
139 |
140 | # Extension branch - 1x1 convolution, followed by a regular, dilated or
141 | # asymmetric convolution, followed by another 1x1 convolution, and,
142 | # finally, a regularizer (spatial dropout). Number of channels is constant.
143 |
144 | # 1x1 projection convolution
145 | self.ext_conv1 = nn.Sequential(
146 | nn.Conv2d(
147 | channels,
148 | internal_channels,
149 | kernel_size=1,
150 | stride=1,
151 | bias=bias), nn.BatchNorm2d(internal_channels), activation())
152 |
153 | # If the convolution is asymmetric we split the main convolution in
154 |         # two. E.g. for a 5x5 asymmetric convolution we have two convolutions:
155 | # the first is 5x1 and the second is 1x5.
156 | if asymmetric:
157 | self.ext_conv2 = nn.Sequential(
158 | nn.Conv2d(
159 | internal_channels,
160 | internal_channels,
161 | kernel_size=(kernel_size, 1),
162 | stride=1,
163 | padding=(padding, 0),
164 | dilation=dilation,
165 | bias=bias), nn.BatchNorm2d(internal_channels), activation(),
166 | nn.Conv2d(
167 | internal_channels,
168 | internal_channels,
169 | kernel_size=(1, kernel_size),
170 | stride=1,
171 | padding=(0, padding),
172 | dilation=dilation,
173 | bias=bias), nn.BatchNorm2d(internal_channels), activation())
174 | else:
175 | self.ext_conv2 = nn.Sequential(
176 | nn.Conv2d(
177 | internal_channels,
178 | internal_channels,
179 | kernel_size=kernel_size,
180 | stride=1,
181 | padding=padding,
182 | dilation=dilation,
183 | bias=bias), nn.BatchNorm2d(internal_channels), activation())
184 |
185 | # 1x1 expansion convolution
186 | self.ext_conv3 = nn.Sequential(
187 | nn.Conv2d(
188 | internal_channels,
189 | channels,
190 | kernel_size=1,
191 | stride=1,
192 | bias=bias), nn.BatchNorm2d(channels), activation())
193 |
194 | self.ext_regul = nn.Dropout2d(p=dropout_prob)
195 |
196 |         # Activation (ReLU or PReLU) to apply after adding the branches
197 | self.out_activation = activation()
198 |
199 | def forward(self, x):
200 | # Main branch shortcut
201 | main = x
202 |
203 | # Extension branch
204 | ext = self.ext_conv1(x)
205 | ext = self.ext_conv2(ext)
206 | ext = self.ext_conv3(ext)
207 | ext = self.ext_regul(ext)
208 |
209 | # Add main and extension branches
210 | out = main + ext
211 |
212 | return self.out_activation(out)
213 |
214 |
215 | class DownsamplingBottleneck(nn.Module):
216 | """Downsampling bottlenecks further downsample the feature map size.
217 |
218 | Main branch:
219 | 1. max pooling with stride 2; indices are saved to be used for
220 | unpooling later.
221 |
222 | Extension branch:
223 | 1. 2x2 convolution with stride 2 that decreases the number of channels
224 | by ``internal_ratio``, also called a projection;
225 | 2. regular convolution (by default, 3x3);
226 | 3. 1x1 convolution which increases the number of channels to
227 | ``out_channels``, also called an expansion;
228 | 4. dropout as a regularizer.
229 |
230 | Keyword arguments:
231 | - in_channels (int): the number of input channels.
232 | - out_channels (int): the number of output channels.
233 | - internal_ratio (int, optional): a scale factor applied to ``channels``
234 |     used to compute the number of channels after the projection. E.g. given
235 | ``channels`` equal to 128 and internal_ratio equal to 2 the number of
236 | channels after the projection is 64. Default: 4.
237 | - return_indices (bool, optional): if ``True``, will return the max
238 | indices along with the outputs. Useful when unpooling later.
239 | - dropout_prob (float, optional): probability of an element to be
240 | zeroed. Default: 0 (no dropout).
241 | - bias (bool, optional): Adds a learnable bias to the output if
242 | ``True``. Default: False.
243 | - relu (bool, optional): When ``True`` ReLU is used as the activation
244 | function; otherwise, PReLU is used. Default: True.
245 |
246 | """
247 |
248 | def __init__(self,
249 | in_channels,
250 | out_channels,
251 | internal_ratio=4,
252 | return_indices=False,
253 | dropout_prob=0,
254 | bias=False,
255 | relu=True):
256 | super().__init__()
257 |
258 | # Store parameters that are needed later
259 | self.return_indices = return_indices
260 |
261 | # Check in the internal_scale parameter is within the expected range
262 | # [1, channels]
263 | if internal_ratio <= 1 or internal_ratio > in_channels:
264 | raise RuntimeError("Value out of range. Expected value in the "
265 |                                "interval [1, {0}], got internal_ratio={1}. "
266 | .format(in_channels, internal_ratio))
267 |
268 | internal_channels = in_channels // internal_ratio
269 |
270 | if relu:
271 | activation = nn.ReLU
272 | else:
273 | activation = nn.PReLU
274 |
275 | # Main branch - max pooling followed by feature map (channels) padding
276 | self.main_max1 = nn.MaxPool2d(
277 | 2,
278 | stride=2,
279 | return_indices=return_indices)
280 |
281 |         # Extension branch - 2x2 projection convolution with stride 2, followed
282 |         # by a regular 3x3 convolution, followed by a 1x1 expansion convolution.
283 |         # The number of channels increases from in_channels to out_channels.
284 |
285 | # 2x2 projection convolution with stride 2
286 | self.ext_conv1 = nn.Sequential(
287 | nn.Conv2d(
288 | in_channels,
289 | internal_channels,
290 | kernel_size=2,
291 | stride=2,
292 | bias=bias), nn.BatchNorm2d(internal_channels), activation())
293 |
294 | # Convolution
295 | self.ext_conv2 = nn.Sequential(
296 | nn.Conv2d(
297 | internal_channels,
298 | internal_channels,
299 | kernel_size=3,
300 | stride=1,
301 | padding=1,
302 | bias=bias), nn.BatchNorm2d(internal_channels), activation())
303 |
304 | # 1x1 expansion convolution
305 | self.ext_conv3 = nn.Sequential(
306 | nn.Conv2d(
307 | internal_channels,
308 | out_channels,
309 | kernel_size=1,
310 | stride=1,
311 | bias=bias), nn.BatchNorm2d(out_channels), activation())
312 |
313 | self.ext_regul = nn.Dropout2d(p=dropout_prob)
314 |
315 |         # Activation (ReLU or PReLU) to apply after adding the branches
316 | self.out_activation = activation()
317 |
318 | def forward(self, x):
319 | # Main branch shortcut
320 | if self.return_indices:
321 | main, max_indices = self.main_max1(x)
322 | else:
323 |             main, max_indices = self.main_max1(x), None
324 |
325 | # Extension branch
326 | ext = self.ext_conv1(x)
327 | ext = self.ext_conv2(ext)
328 | ext = self.ext_conv3(ext)
329 | ext = self.ext_regul(ext)
330 |
331 | # Main branch channel padding
332 | n, ch_ext, h, w = ext.size()
333 | ch_main = main.size()[1]
334 | padding = torch.zeros(n, ch_ext - ch_main, h, w)
335 |
336 | # Before concatenating, check if main is on the CPU or GPU and
337 | # convert padding accordingly
338 | if main.is_cuda:
339 | padding = padding.cuda()
340 |
341 | # Concatenate
342 | main = torch.cat((main, padding), 1)
343 |
344 | # Add main and extension branches
345 | out = main + ext
346 |
347 | return self.out_activation(out), max_indices
348 |
349 |
350 | class UpsamplingBottleneck(nn.Module):
351 | """The upsampling bottlenecks upsample the feature map resolution using max
352 | pooling indices stored from the corresponding downsampling bottleneck.
353 |
354 | Main branch:
355 | 1. 1x1 convolution with stride 1 that decreases the number of channels by
356 | ``internal_ratio``, also called a projection;
357 | 2. max unpool layer using the max pool indices from the corresponding
358 | downsampling max pool layer.
359 |
360 | Extension branch:
361 | 1. 1x1 convolution with stride 1 that decreases the number of channels by
362 | ``internal_ratio``, also called a projection;
363 | 2. transposed convolution (by default, 3x3);
364 | 3. 1x1 convolution which increases the number of channels to
365 | ``out_channels``, also called an expansion;
366 | 4. dropout as a regularizer.
367 |
368 | Keyword arguments:
369 | - in_channels (int): the number of input channels.
370 | - out_channels (int): the number of output channels.
371 | - internal_ratio (int, optional): a scale factor applied to ``in_channels``
372 |     used to compute the number of channels after the projection. E.g. given
373 | ``in_channels`` equal to 128 and ``internal_ratio`` equal to 2 the number
374 | of channels after the projection is 64. Default: 4.
375 | - dropout_prob (float, optional): probability of an element to be zeroed.
376 | Default: 0 (no dropout).
377 | - bias (bool, optional): Adds a learnable bias to the output if ``True``.
378 | Default: False.
379 | - relu (bool, optional): When ``True`` ReLU is used as the activation
380 | function; otherwise, PReLU is used. Default: True.
381 |
382 | """
383 |
384 | def __init__(self,
385 | in_channels,
386 | out_channels,
387 | internal_ratio=4,
388 | dropout_prob=0,
389 | bias=False,
390 | relu=True):
391 | super().__init__()
392 |
393 |         # Check if the internal_ratio parameter is within the expected range
394 | # [1, channels]
395 | if internal_ratio <= 1 or internal_ratio > in_channels:
396 | raise RuntimeError("Value out of range. Expected value in the "
397 |                                "interval [1, {0}], got internal_ratio={1}. "
398 | .format(in_channels, internal_ratio))
399 |
400 | internal_channels = in_channels // internal_ratio
401 |
402 | if relu:
403 | activation = nn.ReLU
404 | else:
405 | activation = nn.PReLU
406 |
407 |         # Main branch - 1x1 convolution followed by max unpooling
408 | self.main_conv1 = nn.Sequential(
409 | nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias),
410 | nn.BatchNorm2d(out_channels))
411 |
412 | # Remember that the stride is the same as the kernel_size, just like
413 | # the max pooling layers
414 | self.main_unpool1 = nn.MaxUnpool2d(kernel_size=2)
415 |
416 |         # Extension branch - 1x1 projection convolution, followed by a
417 |         # transposed convolution, followed by a 1x1 expansion convolution.
418 |         # The number of channels decreases from in_channels to out_channels.
419 |
420 | # 1x1 projection convolution with stride 1
421 | self.ext_conv1 = nn.Sequential(
422 | nn.Conv2d(
423 | in_channels, internal_channels, kernel_size=1, bias=bias),
424 | nn.BatchNorm2d(internal_channels), activation())
425 |
426 | # Transposed convolution
427 | self.ext_tconv1 = nn.ConvTranspose2d(
428 | internal_channels,
429 | internal_channels,
430 | kernel_size=2,
431 | stride=2,
432 | bias=bias)
433 | self.ext_tconv1_bnorm = nn.BatchNorm2d(internal_channels)
434 | self.ext_tconv1_activation = activation()
435 |
436 | # 1x1 expansion convolution
437 | self.ext_conv2 = nn.Sequential(
438 | nn.Conv2d(
439 | internal_channels, out_channels, kernel_size=1, bias=bias),
440 | nn.BatchNorm2d(out_channels))
441 |
442 | self.ext_regul = nn.Dropout2d(p=dropout_prob)
443 |
444 |         # Activation (ReLU or PReLU) to apply after adding the branches
445 | self.out_activation = activation()
446 |
447 | def forward(self, x, max_indices, output_size):
448 | # Main branch shortcut
449 | main = self.main_conv1(x)
450 | main = self.main_unpool1(
451 | main, max_indices, output_size=output_size)
452 |
453 | # Extension branch
454 | ext = self.ext_conv1(x)
455 | ext = self.ext_tconv1(ext, output_size=output_size)
456 | ext = self.ext_tconv1_bnorm(ext)
457 | ext = self.ext_tconv1_activation(ext)
458 | ext = self.ext_conv2(ext)
459 | ext = self.ext_regul(ext)
460 |
461 | # Add main and extension branches
462 | out = main + ext
463 |
464 | return self.out_activation(out)
465 |
466 |
467 | class ENet(nn.Module):
468 | """Generate the ENet model.
469 |
470 | Keyword arguments:
471 | - num_classes (int): the number of classes to segment.
472 | - encoder_relu (bool, optional): When ``True`` ReLU is used as the
473 | activation function in the encoder blocks/layers; otherwise, PReLU
474 | is used. Default: False.
475 | - decoder_relu (bool, optional): When ``True`` ReLU is used as the
476 | activation function in the decoder blocks/layers; otherwise, PReLU
477 | is used. Default: True.
478 |
479 | """
480 |
481 | def __init__(self, num_classes, encoder_relu=False, decoder_relu=True):
482 | super().__init__()
483 |
484 | self.initial_block = InitialBlock(3, 16, relu=encoder_relu)
485 |
486 | # Stage 1 - Encoder
487 | self.downsample1_0 = DownsamplingBottleneck(
488 | 16,
489 | 64,
490 | return_indices=True,
491 | dropout_prob=0.01,
492 | relu=encoder_relu)
493 | self.regular1_1 = RegularBottleneck(
494 | 64, padding=1, dropout_prob=0.01, relu=encoder_relu)
495 | self.regular1_2 = RegularBottleneck(
496 | 64, padding=1, dropout_prob=0.01, relu=encoder_relu)
497 | self.regular1_3 = RegularBottleneck(
498 | 64, padding=1, dropout_prob=0.01, relu=encoder_relu)
499 | self.regular1_4 = RegularBottleneck(
500 | 64, padding=1, dropout_prob=0.01, relu=encoder_relu)
501 |
502 | # Stage 2 - Encoder
503 | self.downsample2_0 = DownsamplingBottleneck(
504 | 64,
505 | 128,
506 | return_indices=True,
507 | dropout_prob=0.1,
508 | relu=encoder_relu)
509 | self.regular2_1 = RegularBottleneck(
510 | 128, padding=1, dropout_prob=0.1, relu=encoder_relu)
511 | self.dilated2_2 = RegularBottleneck(
512 | 128, dilation=2, padding=2, dropout_prob=0.1, relu=encoder_relu)
513 | self.asymmetric2_3 = RegularBottleneck(
514 | 128,
515 | kernel_size=5,
516 | padding=2,
517 | asymmetric=True,
518 | dropout_prob=0.1,
519 | relu=encoder_relu)
520 | self.dilated2_4 = RegularBottleneck(
521 | 128, dilation=4, padding=4, dropout_prob=0.1, relu=encoder_relu)
522 | self.regular2_5 = RegularBottleneck(
523 | 128, padding=1, dropout_prob=0.1, relu=encoder_relu)
524 | self.dilated2_6 = RegularBottleneck(
525 | 128, dilation=8, padding=8, dropout_prob=0.1, relu=encoder_relu)
526 | self.asymmetric2_7 = RegularBottleneck(
527 | 128,
528 | kernel_size=5,
529 | asymmetric=True,
530 | padding=2,
531 | dropout_prob=0.1,
532 | relu=encoder_relu)
533 | self.dilated2_8 = RegularBottleneck(
534 | 128, dilation=16, padding=16, dropout_prob=0.1, relu=encoder_relu)
535 |
536 | # Stage 3 - Encoder
537 | self.regular3_0 = RegularBottleneck(
538 | 128, padding=1, dropout_prob=0.1, relu=encoder_relu)
539 | self.dilated3_1 = RegularBottleneck(
540 | 128, dilation=2, padding=2, dropout_prob=0.1, relu=encoder_relu)
541 | self.asymmetric3_2 = RegularBottleneck(
542 | 128,
543 | kernel_size=5,
544 | padding=2,
545 | asymmetric=True,
546 | dropout_prob=0.1,
547 | relu=encoder_relu)
548 | self.dilated3_3 = RegularBottleneck(
549 | 128, dilation=4, padding=4, dropout_prob=0.1, relu=encoder_relu)
550 | self.regular3_4 = RegularBottleneck(
551 | 128, padding=1, dropout_prob=0.1, relu=encoder_relu)
552 | self.dilated3_5 = RegularBottleneck(
553 | 128, dilation=8, padding=8, dropout_prob=0.1, relu=encoder_relu)
554 | self.asymmetric3_6 = RegularBottleneck(
555 | 128,
556 | kernel_size=5,
557 | asymmetric=True,
558 | padding=2,
559 | dropout_prob=0.1,
560 | relu=encoder_relu)
561 | self.dilated3_7 = RegularBottleneck(
562 | 128, dilation=16, padding=16, dropout_prob=0.1, relu=encoder_relu)
563 |
564 | # Stage 4 - Decoder
565 | self.upsample4_0 = UpsamplingBottleneck(
566 | 128, 64, dropout_prob=0.1, relu=decoder_relu)
567 | self.regular4_1 = RegularBottleneck(
568 | 64, padding=1, dropout_prob=0.1, relu=decoder_relu)
569 | self.regular4_2 = RegularBottleneck(
570 | 64, padding=1, dropout_prob=0.1, relu=decoder_relu)
571 |
572 | # Stage 5 - Decoder
573 | self.upsample5_0 = UpsamplingBottleneck(
574 | 64, 16, dropout_prob=0.1, relu=decoder_relu)
575 | self.regular5_1 = RegularBottleneck(
576 | 16, padding=1, dropout_prob=0.1, relu=decoder_relu)
577 | self.transposed_conv = nn.ConvTranspose2d(
578 | 16,
579 | num_classes,
580 | kernel_size=3,
581 | stride=2,
582 | padding=1,
583 | bias=False)
584 |
585 | def forward(self, x):
586 | # Initial block
587 | input_size = x.size()
588 | x = self.initial_block(x)
589 |
590 | # Stage 1 - Encoder
591 | stage1_input_size = x.size()
592 | x, max_indices1_0 = self.downsample1_0(x)
593 | x = self.regular1_1(x)
594 | x = self.regular1_2(x)
595 | x = self.regular1_3(x)
596 | x = self.regular1_4(x)
597 |
598 | # Stage 2 - Encoder
599 | stage2_input_size = x.size()
600 | x, max_indices2_0 = self.downsample2_0(x)
601 | x = self.regular2_1(x)
602 | x = self.dilated2_2(x)
603 | x = self.asymmetric2_3(x)
604 | x = self.dilated2_4(x)
605 | x = self.regular2_5(x)
606 | x = self.dilated2_6(x)
607 | x = self.asymmetric2_7(x)
608 | x = self.dilated2_8(x)
609 |
610 | # Stage 3 - Encoder
611 | x = self.regular3_0(x)
612 | x = self.dilated3_1(x)
613 | x = self.asymmetric3_2(x)
614 | x = self.dilated3_3(x)
615 | x = self.regular3_4(x)
616 | x = self.dilated3_5(x)
617 | x = self.asymmetric3_6(x)
618 | x = self.dilated3_7(x)
619 |
620 | # Stage 4 - Decoder
621 | x = self.upsample4_0(x, max_indices2_0, output_size=stage2_input_size)
622 | x = self.regular4_1(x)
623 | x = self.regular4_2(x)
624 |
625 | # Stage 5 - Decoder
626 | x = self.upsample5_0(x, max_indices1_0, output_size=stage1_input_size)
627 | x = self.regular5_1(x)
628 | x = self.transposed_conv(x, output_size=input_size)
629 |
630 | return x
631 |
--------------------------------------------------------------------------------
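
A quick shape sanity check for the model above; the class count and input size are arbitrary choices for illustration. Because the decoder unpools and deconvolves back to the recorded input sizes, the output spatial resolution matches the input.

    import torch
    from models.enet import ENet

    model = ENet(num_classes=12)  # any positive class count works for this check
    model.eval()
    with torch.no_grad():
        x = torch.randn(1, 3, 360, 480)
        y = model(x)
    print(y.shape)  # expected: torch.Size([1, 12, 360, 480])
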