├── .gitignore ├── README.md ├── eval ├── README.md ├── dataset.py ├── erfnet.py ├── erfnet_nobn.py ├── eval_cityscapes_color.py ├── eval_cityscapes_server.py ├── eval_forwardTime.py ├── eval_iou.py ├── iouEval.py └── transform.py ├── example_segmentation.png ├── imagenet ├── README.md ├── erfnet_imagenet.py └── main.py ├── license.txt ├── train ├── README.md ├── dataset.py ├── erfnet.py ├── erfnet_imagenet.py ├── iouEval.py ├── main.py ├── transform.py └── visualize.py └── trained_models ├── erfnet_encoder_pretrained.pth.tar └── erfnet_pretrained.pth /.gitignore: -------------------------------------------------------------------------------- 1 | #Files 2 | *.pyc 3 | *.pyo 4 | */__pycache__/ 5 | */*/__pycache__/ 6 | */*/*/__pycache__/ 7 | eval/save_results/ 8 | eval/save_color/ 9 | save/ 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ERFNet (PyTorch version) 2 | 3 | This code is a toolbox that uses **PyTorch** for training and evaluating the **ERFNet** architecture for semantic segmentation. 4 | 5 | **For the Original Torch version please go [HERE](https://github.com/Eromera/erfnet)** 6 | 7 | NOTE: This PyTorch version has a slightly better result than the ones in the Torch version (used in the paper): 72.1 IoU in Val set and 69.8 IoU in test set. 8 | 9 | ![Example segmentation](example_segmentation.png?raw=true "Example segmentation") 10 | 11 | ## Publications 12 | 13 | If you use this software in your research, please cite our publications: 14 | 15 | **"Efficient ConvNet for Real-time Semantic Segmentation"**, E. Romera, J. M. Alvarez, L. M. Bergasa and R. Arroyo, IEEE Intelligent Vehicles Symposium (IV), pp. 1789-1794, Redondo Beach (California, USA), June 2017. 16 | **[Best Student Paper Award]**, [[pdf]](http://www.robesafe.uah.es/personal/eduardo.romera/pdfs/Romera17iv.pdf) 17 | 18 | **"ERFNet: Efficient Residual Factorized ConvNet for Real-time Semantic Segmentation"**, E. Romera, J. M. Alvarez, L. M. Bergasa and R. Arroyo, Transactions on Intelligent Transportation Systems (T-ITS), December 2017. [[pdf]](http://www.robesafe.uah.es/personal/eduardo.romera/pdfs/Romera17tits.pdf) 19 | 20 | ## Packages 21 | For instructions please refer to the README on each folder: 22 | 23 | * [train](train) contains tools for training the network for semantic segmentation. 24 | * [eval](eval) contains tools for evaluating/visualizing the network's output. 25 | * [imagenet](imagenet) Contains script and model for pretraining ERFNet's encoder in Imagenet. 26 | * [trained_models](trained_models) Contains the trained models used in the papers. NOTE: the pytorch version is slightly different from the torch models. 27 | 28 | ## Requirements: 29 | 30 | * [**The Cityscapes dataset**](https://www.cityscapes-dataset.com/): Download the "leftImg8bit" for the RGB images and the "gtFine" for the labels. 
**Please note that for training you should use the "_labelTrainIds" and not the "_labelIds", you can download the [cityscapes scripts](https://github.com/mcordts/cityscapesScripts) and use the [conversor](https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/preparation/createTrainIdLabelImgs.py) to generate trainIds from labelIds** 31 | * [**Python 3.6**](https://www.python.org/): If you don't have Python3.6 in your system, I recommend installing it with [Anaconda](https://www.anaconda.com/download/#linux) 32 | * [**PyTorch**](http://pytorch.org/): Make sure to install the Pytorch version for Python 3.6 with CUDA support (code only tested for CUDA 8.0). 33 | * **Additional Python packages**: numpy, matplotlib, Pillow, torchvision and visdom (optional for --visualize flag) 34 | 35 | In Anaconda you can install with: 36 | ``` 37 | conda install numpy matplotlib torchvision Pillow 38 | conda install -c conda-forge visdom 39 | ``` 40 | 41 | If you use Pip (make sure to have it configured for Python3.6) you can install with: 42 | 43 | ``` 44 | pip install numpy matplotlib torchvision Pillow visdom 45 | ``` 46 | 47 | ## License 48 | 49 | This work is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License, which allows for personal and research use only. For a commercial license please contact the authors. You can view a license summary here: http://creativecommons.org/licenses/by-nc/4.0/ 50 | -------------------------------------------------------------------------------- /eval/README.md: -------------------------------------------------------------------------------- 1 | # Functions for evaluating/visualizing the network's output 2 | 3 | Currently there are 4 usable functions to evaluate stuff: 4 | - eval_cityscapes_color 5 | - eval_cityscapes_server 6 | - eval_iou 7 | - eval_forwardTime 8 | 9 | ## eval_cityscapes_color.py 10 | 11 | This code can be used to produce segmentation of the Cityscapes images in color for visualization purposes. By default it saves images in eval/save_color/ folder. You can also visualize results in visdom with --visualize flag. 12 | 13 | **Options:** Specify the Cityscapes folder path with '--datadir' option. Select the cityscapes subset with '--subset' ('val', 'test', 'train' or 'demoSequence'). For other options check the bottom side of the file. 14 | 15 | **Examples:** 16 | ``` 17 | python eval_cityscapes_color.py --datadir /home/datasets/cityscapes/ --subset val 18 | ``` 19 | 20 | ## eval_cityscapes_server.py 21 | 22 | This code can be used to produce segmentation of the Cityscapes images and convert the output indices to the original 'labelIds' so it can be evaluated using the scripts from Cityscapes dataset (evalPixelLevelSemanticLabeling.py) or uploaded to Cityscapes test server. By default it saves images in eval/save_results/ folder. 23 | 24 | **Options:** Specify the Cityscapes folder path with '--datadir' option. Select the cityscapes subset with '--subset' ('val', 'test', 'train' or 'demoSequence'). For other options check the bottom side of the file. 25 | 26 | **Examples:** 27 | ``` 28 | python eval_cityscapes_server.py --datadir /home/datasets/cityscapes/ --subset val 29 | ``` 30 | 31 | ## eval_iou.py 32 | 33 | This code can be used to calculate the IoU (mean and per-class) in a subset of images with labels available, like Cityscapes val/train sets. 34 | 35 | **Options:** Specify the Cityscapes folder path with '--datadir' option. Select the cityscapes subset with '--subset' ('val' or 'train'). 
For other options check the bottom side of the file. 36 | 37 | **Examples:** 38 | ``` 39 | python eval_iou.py --datadir /home/datasets/cityscapes/ --subset val 40 | ``` 41 | 42 | ## eval_forwardTime.py 43 | This function loads a model specified by '-m' and enters a loop to continuously estimate forward pass time (fwt) in the specified resolution. 44 | 45 | **Options:** Option '--width' specifies the width (default: 1024). Option '--height' specifies the height (default: 512). For other options check the bottom side of the file. 46 | 47 | **Examples:** 48 | ``` 49 | python eval_forwardTime.py 50 | ``` 51 | 52 | **NOTE**: Paper values were obtained with a single Titan X (Maxwell) and a Jetson TX1 using the original Torch code. The pytorch code is a bit faster, but cudahalf (FP16) seems to give problems at the moment for some pytorch versions so this code only runs at FP32 (a bit slower). 53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /eval/dataset.py: -------------------------------------------------------------------------------- 1 | # Code with dataset loader for VOC12 and Cityscapes (adapted from bodokaiser/piwise code) 2 | # Sept 2017 3 | # Eduardo Romera 4 | ####################### 5 | 6 | import numpy as np 7 | import os 8 | 9 | from PIL import Image 10 | 11 | from torch.utils.data import Dataset 12 | 13 | EXTENSIONS = ['.jpg', '.png'] 14 | 15 | def load_image(file): 16 | return Image.open(file) 17 | 18 | def is_image(filename): 19 | return any(filename.endswith(ext) for ext in EXTENSIONS) 20 | 21 | def is_label(filename): 22 | return filename.endswith("_labelTrainIds.png") 23 | 24 | def image_path(root, basename, extension): 25 | return os.path.join(root, f'{basename}{extension}') 26 | 27 | def image_path_city(root, name): 28 | return os.path.join(root, f'{name}') 29 | 30 | def image_basename(filename): 31 | return os.path.basename(os.path.splitext(filename)[0]) 32 | 33 | class VOC12(Dataset): 34 | 35 | def __init__(self, root, input_transform=None, target_transform=None): 36 | self.images_root = os.path.join(root, 'images') 37 | self.labels_root = os.path.join(root, 'labels') 38 | 39 | self.filenames = [image_basename(f) 40 | for f in os.listdir(self.labels_root) if is_image(f)] 41 | self.filenames.sort() 42 | 43 | self.input_transform = input_transform 44 | self.target_transform = target_transform 45 | 46 | def __getitem__(self, index): 47 | filename = self.filenames[index] 48 | 49 | with open(image_path(self.images_root, filename, '.jpg'), 'rb') as f: 50 | image = load_image(f).convert('RGB') 51 | with open(image_path(self.labels_root, filename, '.png'), 'rb') as f: 52 | label = load_image(f).convert('P') 53 | 54 | if self.input_transform is not None: 55 | image = self.input_transform(image) 56 | if self.target_transform is not None: 57 | label = self.target_transform(label) 58 | 59 | return image, label 60 | 61 | def __len__(self): 62 | return len(self.filenames) 63 | 64 | 65 | class cityscapes(Dataset): 66 | 67 | def __init__(self, root, input_transform=None, target_transform=None, subset='val'): 68 | self.images_root = os.path.join(root, 'leftImg8bit/' + subset) 69 | self.labels_root = os.path.join(root, 'gtFine/' + subset) 70 | 71 | self.filenames = [os.path.join(dp, f) for dp, dn, fn in os.walk(os.path.expanduser(self.images_root)) for f in fn if is_image(f)] 72 | self.filenames.sort() 73 | 74 | self.filenamesGt = [os.path.join(dp, f) for dp, dn, fn in os.walk(os.path.expanduser(self.labels_root)) for f in fn if 
is_label(f)] 75 | self.filenamesGt.sort() 76 | 77 | self.input_transform = input_transform 78 | self.target_transform = target_transform 79 | 80 | def __getitem__(self, index): 81 | filename = self.filenames[index] 82 | filenameGt = self.filenamesGt[index] 83 | 84 | #print(filename) 85 | 86 | with open(image_path_city(self.images_root, filename), 'rb') as f: 87 | image = load_image(f).convert('RGB') 88 | with open(image_path_city(self.labels_root, filenameGt), 'rb') as f: 89 | label = load_image(f).convert('P') 90 | 91 | if self.input_transform is not None: 92 | image = self.input_transform(image) 93 | if self.target_transform is not None: 94 | label = self.target_transform(label) 95 | 96 | return image, label, filename, filenameGt 97 | 98 | def __len__(self): 99 | return len(self.filenames) 100 | 101 | -------------------------------------------------------------------------------- /eval/erfnet.py: -------------------------------------------------------------------------------- 1 | # ERFNET full network definition for Pytorch 2 | # Sept 2017 3 | # Eduardo Romera 4 | ####################### 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.init as init 9 | import torch.nn.functional as F 10 | 11 | 12 | class DownsamplerBlock (nn.Module): 13 | def __init__(self, ninput, noutput): 14 | super().__init__() 15 | 16 | self.conv = nn.Conv2d(ninput, noutput-ninput, (3, 3), stride=2, padding=1, bias=True) 17 | self.pool = nn.MaxPool2d(2, stride=2) 18 | self.bn = nn.BatchNorm2d(noutput, eps=1e-3) 19 | 20 | def forward(self, input): 21 | output = torch.cat([self.conv(input), self.pool(input)], 1) 22 | output = self.bn(output) 23 | return F.relu(output) 24 | 25 | 26 | class non_bottleneck_1d (nn.Module): 27 | def __init__(self, chann, dropprob, dilated): 28 | super().__init__() 29 | 30 | self.conv3x1_1 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1,0), bias=True) 31 | 32 | self.conv1x3_1 = nn.Conv2d(chann, chann, (1,3), stride=1, padding=(0,1), bias=True) 33 | 34 | self.bn1 = nn.BatchNorm2d(chann, eps=1e-03) 35 | 36 | self.conv3x1_2 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1*dilated,0), bias=True, dilation = (dilated,1)) 37 | 38 | self.conv1x3_2 = nn.Conv2d(chann, chann, (1,3), stride=1, padding=(0,1*dilated), bias=True, dilation = (1, dilated)) 39 | 40 | self.bn2 = nn.BatchNorm2d(chann, eps=1e-03) 41 | 42 | self.dropout = nn.Dropout2d(dropprob) 43 | 44 | 45 | def forward(self, input): 46 | 47 | output = self.conv3x1_1(input) 48 | output = F.relu(output) 49 | output = self.conv1x3_1(output) 50 | output = self.bn1(output) 51 | output = F.relu(output) 52 | 53 | output = self.conv3x1_2(output) 54 | output = F.relu(output) 55 | output = self.conv1x3_2(output) 56 | output = self.bn2(output) 57 | 58 | if (self.dropout.p != 0): 59 | output = self.dropout(output) 60 | 61 | return F.relu(output+input) #+input = identity (residual connection) 62 | 63 | 64 | class Encoder(nn.Module): 65 | def __init__(self, num_classes): 66 | super().__init__() 67 | self.initial_block = DownsamplerBlock(3,16) 68 | 69 | self.layers = nn.ModuleList() 70 | 71 | self.layers.append(DownsamplerBlock(16,64)) 72 | 73 | for x in range(0, 5): #5 times 74 | self.layers.append(non_bottleneck_1d(64, 0.1, 1)) 75 | 76 | self.layers.append(DownsamplerBlock(64,128)) 77 | 78 | for x in range(0, 2): #2 times 79 | self.layers.append(non_bottleneck_1d(128, 0.1, 2)) 80 | self.layers.append(non_bottleneck_1d(128, 0.1, 4)) 81 | self.layers.append(non_bottleneck_1d(128, 0.1, 8)) 82 | 
self.layers.append(non_bottleneck_1d(128, 0.1, 16)) 83 | 84 | #only for encoder mode: 85 | self.output_conv = nn.Conv2d(128, num_classes, 1, stride=1, padding=0, bias=True) 86 | 87 | def forward(self, input, predict=False): 88 | output = self.initial_block(input) 89 | 90 | for layer in self.layers: 91 | output = layer(output) 92 | 93 | if predict: 94 | output = self.output_conv(output) 95 | 96 | return output 97 | 98 | 99 | class UpsamplerBlock (nn.Module): 100 | def __init__(self, ninput, noutput): 101 | super().__init__() 102 | self.conv = nn.ConvTranspose2d(ninput, noutput, 3, stride=2, padding=1, output_padding=1, bias=True) 103 | self.bn = nn.BatchNorm2d(noutput, eps=1e-3) 104 | 105 | def forward(self, input): 106 | output = self.conv(input) 107 | output = self.bn(output) 108 | return F.relu(output) 109 | 110 | class Decoder (nn.Module): 111 | def __init__(self, num_classes): 112 | super().__init__() 113 | 114 | self.layers = nn.ModuleList() 115 | 116 | self.layers.append(UpsamplerBlock(128,64)) 117 | self.layers.append(non_bottleneck_1d(64, 0, 1)) 118 | self.layers.append(non_bottleneck_1d(64, 0, 1)) 119 | 120 | self.layers.append(UpsamplerBlock(64,16)) 121 | self.layers.append(non_bottleneck_1d(16, 0, 1)) 122 | self.layers.append(non_bottleneck_1d(16, 0, 1)) 123 | 124 | self.output_conv = nn.ConvTranspose2d( 16, num_classes, 2, stride=2, padding=0, output_padding=0, bias=True) 125 | 126 | def forward(self, input): 127 | output = input 128 | 129 | for layer in self.layers: 130 | output = layer(output) 131 | 132 | output = self.output_conv(output) 133 | 134 | return output 135 | 136 | 137 | class ERFNet(nn.Module): 138 | def __init__(self, num_classes, encoder=None): #use encoder to pass pretrained encoder 139 | super().__init__() 140 | 141 | if (encoder == None): 142 | self.encoder = Encoder(num_classes) 143 | else: 144 | self.encoder = encoder 145 | self.decoder = Decoder(num_classes) 146 | 147 | def forward(self, input, only_encode=False): 148 | if only_encode: 149 | return self.encoder.forward(input, predict=True) 150 | else: 151 | output = self.encoder(input) #predict=False by default 152 | return self.decoder.forward(output) 153 | 154 | -------------------------------------------------------------------------------- /eval/erfnet_nobn.py: -------------------------------------------------------------------------------- 1 | # ERFNET full network definition for Pytorch - without batch normalization layers nor dropout 2 | # Sept 2017 3 | # Eduardo Romera 4 | ####################### 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.init as init 9 | import torch.nn.functional as F 10 | 11 | #ERFNET definition 12 | 13 | class DownsamplerBlock (nn.Module): 14 | def __init__(self, ninput, noutput): 15 | super().__init__() 16 | 17 | self.conv = nn.Conv2d(ninput, noutput-ninput, (3, 3), stride=2, padding=1, bias=True) 18 | self.pool = nn.MaxPool2d(2, stride=2) 19 | self.bn = nn.BatchNorm2d(noutput, eps=1e-3) 20 | 21 | def forward(self, input): 22 | output = torch.cat([self.conv(input), self.pool(input)], 1) 23 | #output = self.bn(output) 24 | return F.relu(output, inplace=True) 25 | 26 | 27 | class non_bottleneck_1d (nn.Module): 28 | def __init__(self, chann, dropprob, dilated): 29 | super().__init__() 30 | 31 | self.conv3x1_1 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1,0), bias=True) 32 | 33 | self.conv1x3_1 = nn.Conv2d(chann, chann, (1,3), stride=1, padding=(0,1), bias=True) 34 | 35 | self.bn1 = nn.BatchNorm2d(chann, eps=1e-03) 36 | 37 | self.conv3x1_2 = 
nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1*dilated,0), bias=True, dilation = (dilated,1)) 38 | 39 | self.conv1x3_2 = nn.Conv2d(chann, chann, (1,3), stride=1, padding=(0,1*dilated), bias=True, dilation = (1, dilated)) 40 | 41 | self.bn2 = nn.BatchNorm2d(chann, eps=1e-03) 42 | 43 | self.dropout = nn.Dropout2d(dropprob) 44 | 45 | 46 | def forward(self, input): 47 | 48 | output = self.conv3x1_1(input) 49 | output = F.relu(output, inplace=True) 50 | output = self.conv1x3_1(output) 51 | #output = self.bn1(output) 52 | output = F.relu(output, inplace=True) 53 | 54 | output = self.conv3x1_2(output) 55 | output = F.relu(output) 56 | output = self.conv1x3_2(output) 57 | #output = self.bn2(output) 58 | #output = F.relu(output) #ESTO ESTABA MAL 59 | 60 | #if (self.dropout.p != 0): 61 | # output = self.dropout(output) 62 | 63 | return F.relu(output+input, inplace=True) #+input = identity (residual connection) 64 | 65 | 66 | class Encoder(nn.Module): 67 | def __init__(self, num_classes): 68 | super().__init__() 69 | self.initial_block = DownsamplerBlock(3,16) 70 | 71 | self.layers = nn.ModuleList() 72 | 73 | self.layers.append(DownsamplerBlock(16,64)) 74 | 75 | for x in range(0, 5): #5 times 76 | self.layers.append(non_bottleneck_1d(64, 0.03, 1)) #Dropout here was wrong in prev trainings 77 | 78 | self.layers.append(DownsamplerBlock(64,128)) 79 | 80 | for x in range(0, 2): #2 times 81 | self.layers.append(non_bottleneck_1d(128, 0.3, 2)) 82 | self.layers.append(non_bottleneck_1d(128, 0.3, 4)) 83 | self.layers.append(non_bottleneck_1d(128, 0.3, 8)) 84 | self.layers.append(non_bottleneck_1d(128, 0.3, 16)) 85 | 86 | #only for encoder mode: 87 | self.output_conv = nn.Conv2d(128, num_classes, 1, stride=1, padding=0, bias=True) 88 | 89 | def forward(self, input, predict=False): 90 | output = self.initial_block(input) 91 | 92 | for layer in self.layers: 93 | output = layer(output) 94 | 95 | if predict: 96 | output = self.output_conv(output) 97 | 98 | return output 99 | 100 | 101 | class UpsamplerBlock (nn.Module): 102 | def __init__(self, ninput, noutput): 103 | super().__init__() 104 | self.conv = nn.ConvTranspose2d(ninput, noutput, 3, stride=2, padding=1, output_padding=1, bias=True) 105 | self.bn = nn.BatchNorm2d(noutput, eps=1e-3) 106 | 107 | def forward(self, input): 108 | output = self.conv(input) 109 | #output = self.bn(output) 110 | return F.relu(output, inplace=True) 111 | 112 | class Decoder (nn.Module): 113 | def __init__(self, num_classes): 114 | super().__init__() 115 | 116 | self.layers = nn.ModuleList() 117 | 118 | self.layers.append(UpsamplerBlock(128,64)) 119 | self.layers.append(non_bottleneck_1d(64, 0, 1)) 120 | self.layers.append(non_bottleneck_1d(64, 0, 1)) 121 | 122 | self.layers.append(UpsamplerBlock(64,16)) 123 | self.layers.append(non_bottleneck_1d(16, 0, 1)) 124 | self.layers.append(non_bottleneck_1d(16, 0, 1)) 125 | 126 | self.output_conv = nn.ConvTranspose2d( 16, num_classes, 2, stride=2, padding=0, output_padding=0, bias=True) 127 | 128 | def forward(self, input): 129 | output = input 130 | 131 | for layer in self.layers: 132 | output = layer(output) 133 | 134 | output = self.output_conv(output) 135 | 136 | return output 137 | 138 | 139 | class ERFNet(nn.Module): 140 | def __init__(self, num_classes, encoder=None): #use encoder to pass pretrained encoder 141 | super().__init__() 142 | 143 | if (encoder == None): 144 | self.encoder = Encoder(num_classes) 145 | else: 146 | self.encoder = encoder 147 | self.decoder = Decoder(num_classes) 148 | 149 | def forward(self, input, 
only_encode=False): 150 | if only_encode: 151 | return self.encoder.forward(input, predict=True) 152 | else: 153 | output = self.encoder(input) #predict=False by default 154 | return self.decoder.forward(output) 155 | 156 | -------------------------------------------------------------------------------- /eval/eval_cityscapes_color.py: -------------------------------------------------------------------------------- 1 | # Code to produce colored segmentation output in Pytorch for all cityscapes subsets 2 | # Sept 2017 3 | # Eduardo Romera 4 | ####################### 5 | 6 | import numpy as np 7 | import torch 8 | import os 9 | import importlib 10 | 11 | from PIL import Image 12 | from argparse import ArgumentParser 13 | 14 | from torch.autograd import Variable 15 | from torch.utils.data import DataLoader 16 | from torchvision.transforms import Compose, CenterCrop, Normalize, Resize 17 | from torchvision.transforms import ToTensor, ToPILImage 18 | 19 | from dataset import cityscapes 20 | from erfnet import ERFNet 21 | from transform import Relabel, ToLabel, Colorize 22 | 23 | import visdom 24 | 25 | 26 | NUM_CHANNELS = 3 27 | NUM_CLASSES = 20 28 | 29 | image_transform = ToPILImage() 30 | input_transform_cityscapes = Compose([ 31 | Resize((512,1024),Image.BILINEAR), 32 | ToTensor(), 33 | #Normalize([.485, .456, .406], [.229, .224, .225]), 34 | ]) 35 | target_transform_cityscapes = Compose([ 36 | Resize((512,1024),Image.NEAREST), 37 | ToLabel(), 38 | Relabel(255, 19), #ignore label to 19 39 | ]) 40 | 41 | cityscapes_trainIds2labelIds = Compose([ 42 | Relabel(19, 255), 43 | Relabel(18, 33), 44 | Relabel(17, 32), 45 | Relabel(16, 31), 46 | Relabel(15, 28), 47 | Relabel(14, 27), 48 | Relabel(13, 26), 49 | Relabel(12, 25), 50 | Relabel(11, 24), 51 | Relabel(10, 23), 52 | Relabel(9, 22), 53 | Relabel(8, 21), 54 | Relabel(7, 20), 55 | Relabel(6, 19), 56 | Relabel(5, 17), 57 | Relabel(4, 13), 58 | Relabel(3, 12), 59 | Relabel(2, 11), 60 | Relabel(1, 8), 61 | Relabel(0, 7), 62 | Relabel(255, 0), 63 | ToPILImage(), 64 | ]) 65 | 66 | def main(args): 67 | 68 | modelpath = args.loadDir + args.loadModel 69 | weightspath = args.loadDir + args.loadWeights 70 | 71 | print ("Loading model: " + modelpath) 72 | print ("Loading weights: " + weightspath) 73 | 74 | #Import ERFNet model from the folder 75 | #Net = importlib.import_module(modelpath.replace("/", "."), "ERFNet") 76 | model = ERFNet(NUM_CLASSES) 77 | 78 | model = torch.nn.DataParallel(model) 79 | if (not args.cpu): 80 | model = model.cuda() 81 | 82 | #model.load_state_dict(torch.load(args.state)) 83 | #model.load_state_dict(torch.load(weightspath)) #not working if missing key 84 | 85 | def load_my_state_dict(model, state_dict): #custom function to load model when not all dict elements 86 | own_state = model.state_dict() 87 | for name, param in state_dict.items(): 88 | if name not in own_state: 89 | continue 90 | own_state[name].copy_(param) 91 | return model 92 | 93 | model = load_my_state_dict(model, torch.load(weightspath)) 94 | print ("Model and weights LOADED successfully") 95 | 96 | model.eval() 97 | 98 | if(not os.path.exists(args.datadir)): 99 | print ("Error: datadir could not be loaded") 100 | 101 | 102 | loader = DataLoader(cityscapes(args.datadir, input_transform_cityscapes, target_transform_cityscapes, subset=args.subset), 103 | num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False) 104 | 105 | # For visualizer: 106 | # must launch in other window "python3.6 -m visdom.server -port 8097" 107 | # and access localhost:8097 to 
see it 108 | if (args.visualize): 109 | vis = visdom.Visdom() 110 | 111 | for step, (images, labels, filename, filenameGt) in enumerate(loader): 112 | if (not args.cpu): 113 | images = images.cuda() 114 | #labels = labels.cuda() 115 | 116 | inputs = Variable(images) 117 | #targets = Variable(labels) 118 | with torch.no_grad(): 119 | outputs = model(inputs) 120 | 121 | label = outputs[0].max(0)[1].byte().cpu().data 122 | #label_cityscapes = cityscapes_trainIds2labelIds(label.unsqueeze(0)) 123 | label_color = Colorize()(label.unsqueeze(0)) 124 | 125 | filenameSave = "./save_color/" + filename[0].split("leftImg8bit/")[1] 126 | os.makedirs(os.path.dirname(filenameSave), exist_ok=True) 127 | #image_transform(label.byte()).save(filenameSave) 128 | label_save = ToPILImage()(label_color) 129 | label_save.save(filenameSave) 130 | 131 | if (args.visualize): 132 | vis.image(label_color.numpy()) 133 | print (step, filenameSave) 134 | 135 | 136 | 137 | if __name__ == '__main__': 138 | parser = ArgumentParser() 139 | 140 | parser.add_argument('--state') 141 | 142 | parser.add_argument('--loadDir',default="../trained_models/") 143 | parser.add_argument('--loadWeights', default="erfnet_pretrained.pth") 144 | parser.add_argument('--loadModel', default="erfnet.py") 145 | parser.add_argument('--subset', default="val") #can be val, test, train, demoSequence 146 | 147 | parser.add_argument('--datadir', default=os.getenv("HOME") + "/datasets/cityscapes/") 148 | parser.add_argument('--num-workers', type=int, default=4) 149 | parser.add_argument('--batch-size', type=int, default=1) 150 | parser.add_argument('--cpu', action='store_true') 151 | 152 | parser.add_argument('--visualize', action='store_true') 153 | main(parser.parse_args()) 154 | -------------------------------------------------------------------------------- /eval/eval_cityscapes_server.py: -------------------------------------------------------------------------------- 1 | # Code to produce segmentation output in Pytorch for all cityscapes subset 2 | # Sept 2017 3 | # Eduardo Romera 4 | ####################### 5 | 6 | import numpy as np 7 | import torch 8 | import os 9 | import importlib 10 | 11 | from PIL import Image 12 | from argparse import ArgumentParser 13 | 14 | from torch.autograd import Variable 15 | from torch.utils.data import DataLoader 16 | from torchvision.transforms import Compose, CenterCrop, Normalize, Resize 17 | from torchvision.transforms import ToTensor, ToPILImage 18 | 19 | from dataset import cityscapes 20 | from erfnet import ERFNet 21 | from transform import Relabel, ToLabel, Colorize 22 | 23 | 24 | NUM_CHANNELS = 3 25 | NUM_CLASSES = 20 26 | 27 | image_transform = ToPILImage() 28 | input_transform_cityscapes = Compose([ 29 | Resize(512), 30 | ToTensor(), 31 | #Normalize([.485, .456, .406], [.229, .224, .225]), 32 | ]) 33 | target_transform_cityscapes = Compose([ 34 | Resize(512), 35 | ToLabel(), 36 | Relabel(255, 19), #ignore label to 19 37 | ]) 38 | 39 | cityscapes_trainIds2labelIds = Compose([ 40 | Relabel(19, 255), 41 | Relabel(18, 33), 42 | Relabel(17, 32), 43 | Relabel(16, 31), 44 | Relabel(15, 28), 45 | Relabel(14, 27), 46 | Relabel(13, 26), 47 | Relabel(12, 25), 48 | Relabel(11, 24), 49 | Relabel(10, 23), 50 | Relabel(9, 22), 51 | Relabel(8, 21), 52 | Relabel(7, 20), 53 | Relabel(6, 19), 54 | Relabel(5, 17), 55 | Relabel(4, 13), 56 | Relabel(3, 12), 57 | Relabel(2, 11), 58 | Relabel(1, 8), 59 | Relabel(0, 7), 60 | Relabel(255, 0), 61 | ToPILImage(), 62 | Resize(1024, Image.NEAREST), 63 | ]) 64 | 65 | def main(args): 
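    # Overview of main() below: build ERFNet, wrap it in DataParallel, load the pretrained
    # weights given by --loadDir/--loadWeights (checkpoint keys that do not match the model
    # are skipped), run the network over the chosen Cityscapes subset, convert the predicted
    # trainIds back to the original labelIds with cityscapes_trainIds2labelIds, and save the
    # resulting PNGs under ./save_results/ for the official Cityscapes scripts or the test server.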
66 | 67 | modelpath = args.loadDir + args.loadModel 68 | weightspath = args.loadDir + args.loadWeights 69 | 70 | print ("Loading model: " + modelpath) 71 | print ("Loading weights: " + weightspath) 72 | 73 | #Import ERFNet model from the folder 74 | #Net = importlib.import_module(modelpath.replace("/", "."), "ERFNet") 75 | model = ERFNet(NUM_CLASSES) 76 | 77 | model = torch.nn.DataParallel(model) 78 | if (not args.cpu): 79 | model = model.cuda() 80 | 81 | #model.load_state_dict(torch.load(args.state)) 82 | #model.load_state_dict(torch.load(weightspath)) #not working if missing key 83 | 84 | def load_my_state_dict(model, state_dict): #custom function to load model when not all dict elements 85 | own_state = model.state_dict() 86 | for name, param in state_dict.items(): 87 | if name not in own_state: 88 | continue 89 | own_state[name].copy_(param) 90 | return model 91 | 92 | model = load_my_state_dict(model, torch.load(weightspath)) 93 | print ("Model and weights LOADED successfully") 94 | 95 | model.eval() 96 | 97 | if(not os.path.exists(args.datadir)): 98 | print ("Error: datadir could not be loaded") 99 | 100 | 101 | loader = DataLoader(cityscapes(args.datadir, input_transform_cityscapes, target_transform_cityscapes, subset=args.subset), 102 | num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False) 103 | 104 | for step, (images, labels, filename, filenameGt) in enumerate(loader): 105 | if (not args.cpu): 106 | images = images.cuda() 107 | #labels = labels.cuda() 108 | 109 | inputs = Variable(images) 110 | #targets = Variable(labels) 111 | with torch.no_grad(): 112 | outputs = model(inputs) 113 | 114 | label = outputs[0].max(0)[1].byte().cpu().data 115 | label_cityscapes = cityscapes_trainIds2labelIds(label.unsqueeze(0)) 116 | #print (numpy.unique(label.numpy())) #debug 117 | 118 | filenameSave = "./save_results/" + filename[0].split("leftImg8bit/")[1] 119 | os.makedirs(os.path.dirname(filenameSave), exist_ok=True) 120 | #image_transform(label.byte()).save(filenameSave) 121 | label_cityscapes.save(filenameSave) 122 | 123 | print (step, filenameSave) 124 | 125 | 126 | 127 | if __name__ == '__main__': 128 | parser = ArgumentParser() 129 | 130 | parser.add_argument('--state') 131 | 132 | parser.add_argument('--loadDir',default="../trained_models/") 133 | parser.add_argument('--loadWeights', default="erfnet_pretrained.pth") 134 | parser.add_argument('--loadModel', default="erfnet.py") 135 | parser.add_argument('--subset', default="val") #can be val, test, train, demoSequence 136 | parser.add_argument('--datadir', default=os.getenv("HOME") + "/datasets/cityscapes/") 137 | parser.add_argument('--num-workers', type=int, default=4) 138 | parser.add_argument('--batch-size', type=int, default=1) 139 | parser.add_argument('--cpu', action='store_true') 140 | 141 | main(parser.parse_args()) 142 | -------------------------------------------------------------------------------- /eval/eval_forwardTime.py: -------------------------------------------------------------------------------- 1 | # Code to evaluate forward pass time in Pytorch 2 | # Sept 2017 3 | # Eduardo Romera 4 | ####################### 5 | 6 | import os 7 | import numpy as np 8 | import torch 9 | import time 10 | 11 | from PIL import Image 12 | from argparse import ArgumentParser 13 | 14 | from torch.autograd import Variable 15 | 16 | from erfnet_nobn import ERFNet 17 | from transform import Relabel, ToLabel, Colorize 18 | 19 | import torch.backends.cudnn as cudnn 20 | cudnn.benchmark = True 21 | 22 | def main(args): 23 | 
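    # Overview of main() below: instantiate the batch-norm-free ERFNet (erfnet_nobn) with 19
    # classes, move it to the GPU unless --cpu is set, then loop forever timing the forward
    # pass on a random batch of size --batch-size x --num-channels x --height x --width.
    # The first iteration is discarded (CUDA/cuDNN warm-up), torch.cuda.synchronize() is
    # called before measuring the elapsed time, and a 1 s sleep between runs avoids
    # overheating the GPU.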
model = ERFNet(19) 24 | if (not args.cpu): 25 | model = model.cuda()#.half() #HALF seems to be doing slower for some reason 26 | #model = torch.nn.DataParallel(model).cuda() 27 | 28 | model.eval() 29 | 30 | 31 | images = torch.randn(args.batch_size, args.num_channels, args.height, args.width) 32 | 33 | if (not args.cpu): 34 | images = images.cuda()#.half() 35 | 36 | time_train = [] 37 | 38 | i=0 39 | 40 | while(True): 41 | #for step, (images, labels, filename, filenameGt) in enumerate(loader): 42 | 43 | start_time = time.time() 44 | 45 | inputs = Variable(images) 46 | with torch.no_grad(): 47 | outputs = model(inputs) 48 | 49 | #preds = outputs.cpu() 50 | if (not args.cpu): 51 | torch.cuda.synchronize() #wait for cuda to finish (cuda is asynchronous!) 52 | 53 | if i!=0: #first run always takes some time for setup 54 | fwt = time.time() - start_time 55 | time_train.append(fwt) 56 | print ("Forward time per img (b=%d): %.3f (Mean: %.3f)" % (args.batch_size, fwt/args.batch_size, sum(time_train) / len(time_train) / args.batch_size)) 57 | 58 | time.sleep(1) #to avoid overheating the GPU too much 59 | i+=1 60 | 61 | if __name__ == '__main__': 62 | parser = ArgumentParser() 63 | 64 | parser.add_argument('--width', type=int, default=1024) 65 | parser.add_argument('--height', type=int, default=512) 66 | parser.add_argument('--num-channels', type=int, default=3) 67 | parser.add_argument('--batch-size', type=int, default=1) 68 | parser.add_argument('--cpu', action='store_true') 69 | 70 | main(parser.parse_args()) 71 | -------------------------------------------------------------------------------- /eval/eval_iou.py: -------------------------------------------------------------------------------- 1 | # Code to calculate IoU (mean and per-class) in a dataset 2 | # Nov 2017 3 | # Eduardo Romera 4 | ####################### 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | import os 10 | import importlib 11 | import time 12 | 13 | from PIL import Image 14 | from argparse import ArgumentParser 15 | 16 | from torch.autograd import Variable 17 | from torch.utils.data import DataLoader 18 | from torchvision.transforms import Compose, CenterCrop, Normalize, Resize 19 | from torchvision.transforms import ToTensor, ToPILImage 20 | 21 | from dataset import cityscapes 22 | from erfnet import ERFNet 23 | from transform import Relabel, ToLabel, Colorize 24 | from iouEval import iouEval, getColorEntry 25 | 26 | NUM_CHANNELS = 3 27 | NUM_CLASSES = 20 28 | 29 | image_transform = ToPILImage() 30 | input_transform_cityscapes = Compose([ 31 | Resize(512, Image.BILINEAR), 32 | ToTensor(), 33 | ]) 34 | target_transform_cityscapes = Compose([ 35 | Resize(512, Image.NEAREST), 36 | ToLabel(), 37 | Relabel(255, 19), #ignore label to 19 38 | ]) 39 | 40 | def main(args): 41 | 42 | modelpath = args.loadDir + args.loadModel 43 | weightspath = args.loadDir + args.loadWeights 44 | 45 | print ("Loading model: " + modelpath) 46 | print ("Loading weights: " + weightspath) 47 | 48 | model = ERFNet(NUM_CLASSES) 49 | 50 | #model = torch.nn.DataParallel(model) 51 | if (not args.cpu): 52 | model = torch.nn.DataParallel(model).cuda() 53 | 54 | def load_my_state_dict(model, state_dict): #custom function to load model when not all dict elements 55 | own_state = model.state_dict() 56 | for name, param in state_dict.items(): 57 | if name not in own_state: 58 | if name.startswith("module."): 59 | own_state[name.split("module.")[-1]].copy_(param) 60 | else: 61 | print(name, " not loaded") 62 | continue 63 | else: 
64 | own_state[name].copy_(param) 65 | return model 66 | 67 | model = load_my_state_dict(model, torch.load(weightspath, map_location=lambda storage, loc: storage)) 68 | print ("Model and weights LOADED successfully") 69 | 70 | 71 | model.eval() 72 | 73 | if(not os.path.exists(args.datadir)): 74 | print ("Error: datadir could not be loaded") 75 | 76 | 77 | loader = DataLoader(cityscapes(args.datadir, input_transform_cityscapes, target_transform_cityscapes, subset=args.subset), num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False) 78 | 79 | 80 | iouEvalVal = iouEval(NUM_CLASSES) 81 | 82 | start = time.time() 83 | 84 | for step, (images, labels, filename, filenameGt) in enumerate(loader): 85 | if (not args.cpu): 86 | images = images.cuda() 87 | labels = labels.cuda() 88 | 89 | inputs = Variable(images) 90 | with torch.no_grad(): 91 | outputs = model(inputs) 92 | 93 | iouEvalVal.addBatch(outputs.max(1)[1].unsqueeze(1).data, labels) 94 | 95 | filenameSave = filename[0].split("leftImg8bit/")[1] 96 | 97 | print (step, filenameSave) 98 | 99 | 100 | iouVal, iou_classes = iouEvalVal.getIoU() 101 | 102 | iou_classes_str = [] 103 | for i in range(iou_classes.size(0)): 104 | iouStr = getColorEntry(iou_classes[i])+'{:0.2f}'.format(iou_classes[i]*100) + '\033[0m' 105 | iou_classes_str.append(iouStr) 106 | 107 | print("---------------------------------------") 108 | print("Took ", time.time()-start, "seconds") 109 | print("=======================================") 110 | #print("TOTAL IOU: ", iou * 100, "%") 111 | print("Per-Class IoU:") 112 | print(iou_classes_str[0], "Road") 113 | print(iou_classes_str[1], "sidewalk") 114 | print(iou_classes_str[2], "building") 115 | print(iou_classes_str[3], "wall") 116 | print(iou_classes_str[4], "fence") 117 | print(iou_classes_str[5], "pole") 118 | print(iou_classes_str[6], "traffic light") 119 | print(iou_classes_str[7], "traffic sign") 120 | print(iou_classes_str[8], "vegetation") 121 | print(iou_classes_str[9], "terrain") 122 | print(iou_classes_str[10], "sky") 123 | print(iou_classes_str[11], "person") 124 | print(iou_classes_str[12], "rider") 125 | print(iou_classes_str[13], "car") 126 | print(iou_classes_str[14], "truck") 127 | print(iou_classes_str[15], "bus") 128 | print(iou_classes_str[16], "train") 129 | print(iou_classes_str[17], "motorcycle") 130 | print(iou_classes_str[18], "bicycle") 131 | print("=======================================") 132 | iouStr = getColorEntry(iouVal)+'{:0.2f}'.format(iouVal*100) + '\033[0m' 133 | print ("MEAN IoU: ", iouStr, "%") 134 | 135 | if __name__ == '__main__': 136 | parser = ArgumentParser() 137 | 138 | parser.add_argument('--state') 139 | 140 | parser.add_argument('--loadDir',default="../trained_models/") 141 | parser.add_argument('--loadWeights', default="erfnet_pretrained.pth") 142 | parser.add_argument('--loadModel', default="erfnet.py") 143 | parser.add_argument('--subset', default="val") #can be val or train (must have labels) 144 | parser.add_argument('--datadir', default=os.getenv("HOME") + "/datasets/cityscapes/") 145 | parser.add_argument('--num-workers', type=int, default=4) 146 | parser.add_argument('--batch-size', type=int, default=1) 147 | parser.add_argument('--cpu', action='store_true') 148 | 149 | main(parser.parse_args()) 150 | -------------------------------------------------------------------------------- /eval/iouEval.py: -------------------------------------------------------------------------------- 1 | # Code for evaluating IoU 2 | # Nov 2017 3 | # Eduardo Romera 4 | 
####################### 5 | 6 | import torch 7 | 8 | class iouEval: 9 | 10 | def __init__(self, nClasses, ignoreIndex=19): 11 | self.nClasses = nClasses 12 | self.ignoreIndex = ignoreIndex if nClasses>ignoreIndex else -1 #if ignoreIndex is larger than nClasses, consider no ignoreIndex 13 | self.reset() 14 | 15 | def reset (self): 16 | classes = self.nClasses if self.ignoreIndex==-1 else self.nClasses-1 17 | self.tp = torch.zeros(classes).double() 18 | self.fp = torch.zeros(classes).double() 19 | self.fn = torch.zeros(classes).double() 20 | 21 | def addBatch(self, x, y): #x=preds, y=targets 22 | #sizes should be "batch_size x nClasses x H x W" 23 | 24 | #print ("X is cuda: ", x.is_cuda) 25 | #print ("Y is cuda: ", y.is_cuda) 26 | 27 | if (x.is_cuda or y.is_cuda): 28 | x = x.cuda() 29 | y = y.cuda() 30 | 31 | #if size is "batch_size x 1 x H x W" scatter to onehot 32 | if (x.size(1) == 1): 33 | x_onehot = torch.zeros(x.size(0), self.nClasses, x.size(2), x.size(3)) 34 | if x.is_cuda: 35 | x_onehot = x_onehot.cuda() 36 | x_onehot.scatter_(1, x, 1).float() 37 | else: 38 | x_onehot = x.float() 39 | 40 | if (y.size(1) == 1): 41 | y_onehot = torch.zeros(y.size(0), self.nClasses, y.size(2), y.size(3)) 42 | if y.is_cuda: 43 | y_onehot = y_onehot.cuda() 44 | y_onehot.scatter_(1, y, 1).float() 45 | else: 46 | y_onehot = y.float() 47 | 48 | if (self.ignoreIndex != -1): 49 | ignores = y_onehot[:,self.ignoreIndex].unsqueeze(1) 50 | x_onehot = x_onehot[:, :self.ignoreIndex] 51 | y_onehot = y_onehot[:, :self.ignoreIndex] 52 | else: 53 | ignores=0 54 | 55 | #print(type(x_onehot)) 56 | #print(type(y_onehot)) 57 | #print(x_onehot.size()) 58 | #print(y_onehot.size()) 59 | 60 | tpmult = x_onehot * y_onehot #times prediction and gt coincide is 1 61 | tp = torch.sum(torch.sum(torch.sum(tpmult, dim=0, keepdim=True), dim=2, keepdim=True), dim=3, keepdim=True).squeeze() 62 | fpmult = x_onehot * (1-y_onehot-ignores) #times prediction says its that class and gt says its not (subtracting cases when its ignore label!) 63 | fp = torch.sum(torch.sum(torch.sum(fpmult, dim=0, keepdim=True), dim=2, keepdim=True), dim=3, keepdim=True).squeeze() 64 | fnmult = (1-x_onehot) * (y_onehot) #times prediction says its not that class and gt says it is 65 | fn = torch.sum(torch.sum(torch.sum(fnmult, dim=0, keepdim=True), dim=2, keepdim=True), dim=3, keepdim=True).squeeze() 66 | 67 | self.tp += tp.double().cpu() 68 | self.fp += fp.double().cpu() 69 | self.fn += fn.double().cpu() 70 | 71 | def getIoU(self): 72 | num = self.tp 73 | den = self.tp + self.fp + self.fn + 1e-15 74 | iou = num / den 75 | return torch.mean(iou), iou #returns "iou mean", "iou per class" 76 | 77 | # Class for colors 78 | class colors: 79 | RED = '\033[31;1m' 80 | GREEN = '\033[32;1m' 81 | YELLOW = '\033[33;1m' 82 | BLUE = '\033[34;1m' 83 | MAGENTA = '\033[35;1m' 84 | CYAN = '\033[36;1m' 85 | BOLD = '\033[1m' 86 | UNDERLINE = '\033[4m' 87 | ENDC = '\033[0m' 88 | 89 | # Colored value output if colorized flag is activated. 
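# Note on the evaluation above: addBatch() accumulates per-class TP, FP and FN pixel counts
# (predictions falling on pixels carrying the ignore index are excluded from the FP count),
# and getIoU() returns IoU_c = TP_c / (TP_c + FP_c + FN_c) for each class c together with
# their mean; the 1e-15 in the denominator only guards against division by zero.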
90 | def getColorEntry(val): 91 | if not isinstance(val, float): 92 | return colors.ENDC 93 | if (val < .20): 94 | return colors.RED 95 | elif (val < .40): 96 | return colors.YELLOW 97 | elif (val < .60): 98 | return colors.BLUE 99 | elif (val < .80): 100 | return colors.CYAN 101 | else: 102 | return colors.GREEN 103 | 104 | -------------------------------------------------------------------------------- /eval/transform.py: -------------------------------------------------------------------------------- 1 | # Code with transformations for Cityscapes (adapted from bodokaiser/piwise code) 2 | # Sept 2017 3 | # Eduardo Romera 4 | ####################### 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from PIL import Image 10 | 11 | def colormap_cityscapes(n): 12 | cmap=np.zeros([n, 3]).astype(np.uint8) 13 | cmap[0,:] = np.array([128, 64,128]) 14 | cmap[1,:] = np.array([244, 35,232]) 15 | cmap[2,:] = np.array([ 70, 70, 70]) 16 | cmap[3,:] = np.array([ 102,102,156]) 17 | cmap[4,:] = np.array([ 190,153,153]) 18 | cmap[5,:] = np.array([ 153,153,153]) 19 | 20 | cmap[6,:] = np.array([ 250,170, 30]) 21 | cmap[7,:] = np.array([ 220,220, 0]) 22 | cmap[8,:] = np.array([ 107,142, 35]) 23 | cmap[9,:] = np.array([ 152,251,152]) 24 | cmap[10,:] = np.array([ 70,130,180]) 25 | 26 | cmap[11,:] = np.array([ 220, 20, 60]) 27 | cmap[12,:] = np.array([ 255, 0, 0]) 28 | cmap[13,:] = np.array([ 0, 0,142]) 29 | cmap[14,:] = np.array([ 0, 0, 70]) 30 | cmap[15,:] = np.array([ 0, 60,100]) 31 | 32 | cmap[16,:] = np.array([ 0, 80,100]) 33 | cmap[17,:] = np.array([ 0, 0,230]) 34 | cmap[18,:] = np.array([ 119, 11, 32]) 35 | cmap[19,:] = np.array([ 0, 0, 0]) 36 | 37 | return cmap 38 | 39 | 40 | def colormap(n): 41 | cmap=np.zeros([n, 3]).astype(np.uint8) 42 | 43 | for i in np.arange(n): 44 | r, g, b = np.zeros(3) 45 | 46 | for j in np.arange(8): 47 | r = r + (1<<(7-j))*((i&(1<<(3*j))) >> (3*j)) 48 | g = g + (1<<(7-j))*((i&(1<<(3*j+1))) >> (3*j+1)) 49 | b = b + (1<<(7-j))*((i&(1<<(3*j+2))) >> (3*j+2)) 50 | 51 | cmap[i,:] = np.array([r, g, b]) 52 | 53 | return cmap 54 | 55 | class Relabel: 56 | 57 | def __init__(self, olabel, nlabel): 58 | self.olabel = olabel 59 | self.nlabel = nlabel 60 | 61 | def __call__(self, tensor): 62 | assert isinstance(tensor, torch.LongTensor) or isinstance(tensor, torch.ByteTensor) , 'tensor needs to be LongTensor' 63 | tensor[tensor == self.olabel] = self.nlabel 64 | return tensor 65 | 66 | 67 | class ToLabel: 68 | 69 | def __call__(self, image): 70 | return torch.from_numpy(np.array(image)).long().unsqueeze(0) 71 | 72 | 73 | class Colorize: 74 | 75 | def __init__(self, n=22): 76 | #self.cmap = colormap(256) 77 | self.cmap = colormap_cityscapes(256) 78 | self.cmap[n] = self.cmap[-1] 79 | self.cmap = torch.from_numpy(self.cmap[:n]) 80 | 81 | def __call__(self, gray_image): 82 | size = gray_image.size() 83 | color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0) 84 | 85 | #for label in range(1, len(self.cmap)): 86 | for label in range(0, len(self.cmap)): 87 | mask = gray_image[0] == label 88 | 89 | color_image[0][mask] = self.cmap[label][0] 90 | color_image[1][mask] = self.cmap[label][1] 91 | color_image[2][mask] = self.cmap[label][2] 92 | 93 | return color_image 94 | -------------------------------------------------------------------------------- /example_segmentation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Eromera/erfnet_pytorch/d4a46faf9e465286c89ebd9c44bc929b2d213fb3/example_segmentation.png 
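A minimal usage sketch for the label helpers defined in eval/transform.py above, mirroring how eval_cityscapes_color.py composes them (the input filename is a hypothetical example):

```
# Colorize a Cityscapes "_labelTrainIds" image with the helpers from eval/transform.py
from PIL import Image
from torchvision.transforms import Compose, Resize, ToPILImage

from transform import Relabel, ToLabel, Colorize

label_transform = Compose([
    Resize((512, 1024), Image.NEAREST),  # nearest neighbour so label ids are preserved
    ToLabel(),                           # PIL image -> LongTensor of shape 1 x H x W
    Relabel(255, 19),                    # map the Cityscapes ignore label (255) to class 19
])

label = label_transform(Image.open("some_city_gtFine_labelTrainIds.png"))  # hypothetical file
color = Colorize()(label)               # 3 x H x W ByteTensor with the Cityscapes palette
ToPILImage()(color).save("label_color.png")
```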
-------------------------------------------------------------------------------- /imagenet/README.md: -------------------------------------------------------------------------------- 1 | # Imagenet pretraining script and model 2 | 3 | This folder contains the script and model definition to pretrain ERFNet's encoder in Imagenet Data. 4 | 5 | The script is an adaptation from the code in [Pytorch Imagenet example](https://github.com/pytorch/examples/tree/master/imagenet). Please make sure that you have Imagenet dataset split in train and val folders before launching the script. Refer to that repository for instructions about usage and main.py options. Basic command: 6 | 7 | ``` 8 | python main.py 9 | ``` 10 | -------------------------------------------------------------------------------- /imagenet/erfnet_imagenet.py: -------------------------------------------------------------------------------- 1 | # ERFNet encoder model definition used for pretraining in ImageNet 2 | # Sept 2017 3 | # Eduardo Romera 4 | ####################### 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.init as init 9 | import torch.nn.functional as F 10 | 11 | class DownsamplerBlock (nn.Module): 12 | def __init__(self, ninput, noutput): 13 | super().__init__() 14 | 15 | self.conv = nn.Conv2d(ninput, noutput-ninput, (3, 3), stride=2, padding=1, bias=True) 16 | self.pool = nn.MaxPool2d(2, stride=2) 17 | self.bn = nn.BatchNorm2d(noutput, eps=1e-3) 18 | 19 | def forward(self, input): 20 | output = torch.cat([self.conv(input), self.pool(input)], 1) 21 | output = self.bn(output) 22 | return F.relu(output) 23 | 24 | class non_bottleneck_1d (nn.Module): 25 | def __init__(self, chann, dropprob, dilated): 26 | super().__init__() 27 | 28 | 29 | self.conv3x1_1 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1,0), bias=True) 30 | self.conv1x3_1 = nn.Conv2d(chann, chann, (1,3), stride=1, padding=(0,1), bias=True) 31 | 32 | self.conv3x1_2 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1*dilated,0), bias=True, dilation = (dilated,1)) 33 | self.conv1x3_2 = nn.Conv2d(chann, chann, (1,3), stride=1, padding=(0,1*dilated), bias=True, dilation = (1, dilated)) 34 | 35 | self.bn1 = nn.BatchNorm2d(chann, eps=1e-03) 36 | self.bn2 = nn.BatchNorm2d(chann, eps=1e-03) 37 | 38 | self.dropout = nn.Dropout2d(dropprob) 39 | 40 | 41 | def forward(self, input): 42 | 43 | output = self.conv3x1_1(input) 44 | output = F.relu(output) 45 | output = self.conv1x3_1(output) 46 | output = self.bn1(output) 47 | output = F.relu(output) 48 | 49 | output = self.conv3x1_2(output) 50 | output = F.relu(output) 51 | output = self.conv1x3_2(output) 52 | output = self.bn2(output) 53 | 54 | if (self.dropout.p != 0): 55 | output = self.dropout(output) 56 | 57 | return F.relu(output+input) 58 | 59 | 60 | class Encoder(nn.Module): 61 | def __init__(self): 62 | super().__init__() 63 | self.initial_block = DownsamplerBlock(3,16) 64 | 65 | self.layers = nn.ModuleList() 66 | 67 | self.layers.append(DownsamplerBlock(16,64)) 68 | 69 | for x in range(0, 5): #5 times 70 | self.layers.append(non_bottleneck_1d(64, 0.1, 1)) 71 | self.layers.append(DownsamplerBlock(64,128)) 72 | 73 | for x in range(0, 2): #2 times 74 | self.layers.append(non_bottleneck_1d(128, 0.1, 2)) 75 | self.layers.append(non_bottleneck_1d(128, 0.1, 4)) 76 | self.layers.append(non_bottleneck_1d(128, 0.1, 8)) 77 | self.layers.append(non_bottleneck_1d(128, 0.1, 16)) 78 | 79 | 80 | def forward(self, input): 81 | output = self.initial_block(input) 82 | 83 | for layer in self.layers: 84 | 
output = layer(output) 85 | 86 | return output 87 | 88 | 89 | class Features(nn.Module): 90 | def __init__(self): 91 | super().__init__() 92 | self.encoder = Encoder() 93 | self.extralayer1 = nn.MaxPool2d(2, stride=2) 94 | self.extralayer2 = nn.AvgPool2d(14,1,0) 95 | 96 | def forward(self, input): 97 | output = self.encoder(input) 98 | output = self.extralayer1(output) 99 | output = self.extralayer2(output) 100 | return output 101 | 102 | class Classifier(nn.Module): 103 | def __init__(self, num_classes): 104 | super().__init__() 105 | self.linear = nn.Linear(128, num_classes) 106 | 107 | def forward(self, input): 108 | output = input.view(input.size(0), 128) #first is batch_size 109 | output = self.linear(output) 110 | return output 111 | 112 | class ERFNet(nn.Module): 113 | def __init__(self, num_classes): #use encoder to pass pretrained encoder 114 | super().__init__() 115 | 116 | self.features = Features() 117 | self.classifier = Classifier(num_classes) 118 | 119 | def forward(self, input): 120 | output = self.features(input) 121 | output = self.classifier(output) 122 | return output 123 | 124 | -------------------------------------------------------------------------------- /imagenet/main.py: -------------------------------------------------------------------------------- 1 | # Script for training ERFNet encoder in ImageNet 2 | # Sept 2017 3 | # Eduardo Romera 4 | ####################### 5 | 6 | import argparse 7 | import os 8 | import shutil 9 | import time 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.parallel 14 | import torch.backends.cudnn as cudnn 15 | import torch.optim 16 | import torch.utils.data 17 | import torchvision.transforms as transforms 18 | import torchvision.datasets as datasets 19 | import torchvision.models as models 20 | 21 | from torch.optim import lr_scheduler 22 | 23 | from erfnet_imagenet import ERFNet 24 | 25 | model_names = sorted(name for name in models.__dict__ 26 | if name.islower() and not name.startswith("__") 27 | and callable(models.__dict__[name])) 28 | 29 | 30 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 31 | parser.add_argument('data', metavar='DIR', 32 | help='path to dataset') 33 | parser.add_argument('--arch', '-a', metavar='ARCH', default='erfnet', 34 | #choices=model_names, 35 | help='model architecture: ' + 36 | ' | '.join(model_names) + 37 | ' (default: resnet18)') 38 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 39 | help='number of data loading workers (default: 4)') 40 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 41 | help='number of total epochs to run') 42 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 43 | help='manual epoch number (useful on restarts)') 44 | parser.add_argument('-b', '--batch-size', default=256, type=int, 45 | metavar='N', help='mini-batch size (default: 256)') 46 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 47 | metavar='LR', help='initial learning rate') 48 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 49 | help='momentum') 50 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 51 | metavar='W', help='weight decay (default: 1e-4)') 52 | parser.add_argument('--print-freq', '-p', default=10, type=int, 53 | metavar='N', help='print frequency (default: 10)') 54 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 55 | help='path to latest checkpoint (default: none)') 56 | 
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 57 | help='evaluate model on validation set') 58 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 59 | help='use pre-trained model') 60 | 61 | best_prec1 = 0 62 | 63 | def main(): 64 | global args, best_prec1 65 | args = parser.parse_args() 66 | 67 | # create model 68 | if (args.arch == 'erfnet'): 69 | model = ERFNet(1000) 70 | else: 71 | if args.pretrained: 72 | print("=> using pre-trained model '{}'".format(args.arch)) 73 | model = models.__dict__[args.arch](pretrained=True) 74 | else: 75 | print("=> creating model '{}'".format(args.arch)) 76 | model = models.__dict__[args.arch]() 77 | 78 | model = torch.nn.DataParallel(model).cuda() 79 | 80 | # define loss function (criterion) and optimizer 81 | criterion = nn.CrossEntropyLoss().cuda() 82 | 83 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 84 | momentum=args.momentum, 85 | weight_decay=args.weight_decay) 86 | 87 | # optionally resume from a checkpoint 88 | if args.resume: 89 | if os.path.isfile(args.resume): 90 | print("=> loading checkpoint '{}'".format(args.resume)) 91 | checkpoint = torch.load(args.resume) 92 | args.start_epoch = checkpoint['epoch'] 93 | best_prec1 = checkpoint['best_prec1'] 94 | model.load_state_dict(checkpoint['state_dict']) 95 | optimizer.load_state_dict(checkpoint['optimizer']) 96 | print("=> loaded checkpoint '{}' (epoch {})" 97 | .format(args.resume, checkpoint['epoch'])) 98 | else: 99 | print("=> no checkpoint found at '{}'".format(args.resume)) 100 | 101 | cudnn.benchmark = True 102 | 103 | # Data loading code 104 | traindir = os.path.join(args.data, 'train') 105 | valdir = os.path.join(args.data, 'val') 106 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 107 | std=[0.229, 0.224, 0.225]) 108 | 109 | train_loader = torch.utils.data.DataLoader( 110 | datasets.ImageFolder(traindir, transforms.Compose([ 111 | #RemoveExif(), 112 | transforms.RandomResizedCrop(224), 113 | transforms.RandomHorizontalFlip(), 114 | transforms.ToTensor(), 115 | normalize, 116 | ])), 117 | batch_size=args.batch_size, shuffle=True, 118 | num_workers=args.workers, pin_memory=True) 119 | 120 | val_loader = torch.utils.data.DataLoader( 121 | datasets.ImageFolder(valdir, transforms.Compose([ 122 | transforms.Resize(256), 123 | transforms.CenterCrop(224), 124 | transforms.ToTensor(), 125 | normalize, 126 | ])), 127 | batch_size=args.batch_size, shuffle=False, 128 | num_workers=args.workers, pin_memory=True) 129 | 130 | if args.evaluate: 131 | validate(val_loader, model, criterion) 132 | return 133 | 134 | for epoch in range(args.start_epoch, args.epochs): 135 | adjust_learning_rate(optimizer, epoch) 136 | 137 | # train for one epoch 138 | train(train_loader, model, criterion, optimizer, epoch) 139 | 140 | # evaluate on validation set 141 | prec1 = validate(val_loader, model, criterion) 142 | 143 | # remember best prec@1 and save checkpoint 144 | is_best = prec1 > best_prec1 145 | best_prec1 = max(prec1, best_prec1) 146 | save_checkpoint({ 147 | 'epoch': epoch + 1, 148 | 'arch': args.arch, 149 | 'state_dict': model.state_dict(), 150 | 'best_prec1': best_prec1, 151 | 'optimizer' : optimizer.state_dict(), 152 | }, is_best) 153 | 154 | #scheduler.step(prec1, epoch) #decreases learning rate if prec1 plateaus 155 | 156 | 157 | def train(train_loader, model, criterion, optimizer, epoch): 158 | batch_time = AverageMeter() 159 | data_time = AverageMeter() 160 | losses = AverageMeter() 161 | top1 = AverageMeter() 162 | 
top5 = AverageMeter() 163 | 164 | # switch to train mode 165 | model.train() 166 | 167 | end = time.time() 168 | for i, (input, target) in enumerate(train_loader): 169 | # measure data loading time 170 | data_time.update(time.time() - end) 171 | 172 | target = target.cuda(async=True) 173 | input_var = torch.autograd.Variable(input) 174 | target_var = torch.autograd.Variable(target) 175 | 176 | # compute output 177 | output = model(input_var) 178 | loss = criterion(output, target_var) 179 | 180 | # measure accuracy and record loss 181 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 182 | losses.update(loss.data[0], input.size(0)) 183 | top1.update(prec1[0], input.size(0)) 184 | top5.update(prec5[0], input.size(0)) 185 | 186 | # compute gradient and do SGD step 187 | optimizer.zero_grad() 188 | loss.backward() 189 | optimizer.step() 190 | 191 | # measure elapsed time 192 | batch_time.update(time.time() - end) 193 | end = time.time() 194 | 195 | if i % args.print_freq == 0: 196 | for param_group in optimizer.param_groups: 197 | lr = param_group['lr'] 198 | print('Epoch: [{0}][{1}/{2}][lr:{lr:.6g}]\t' 199 | 'Time {batch_time.val:.3f} ({batch_time.avg:.2f}) / ' 200 | 'Data {data_time.val:.3f} ({data_time.avg:.2f})\t' 201 | 'Loss {loss.val:.3f} ({loss.avg:.3f})\t' 202 | 'Prec@1 {top1.val:.2f} ({top1.avg:.2f})\t' 203 | 'Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format( 204 | epoch, i, len(train_loader), batch_time=batch_time, 205 | data_time=data_time, loss=losses, top1=top1, top5=top5, lr=lr)) 206 | 207 | 208 | def validate(val_loader, model, criterion): 209 | batch_time = AverageMeter() 210 | losses = AverageMeter() 211 | top1 = AverageMeter() 212 | top5 = AverageMeter() 213 | 214 | # switch to evaluate mode 215 | model.eval() 216 | 217 | end = time.time() 218 | for i, (input, target) in enumerate(val_loader): 219 | target = target.cuda(async=True) 220 | input_var = torch.autograd.Variable(input, volatile=True) 221 | target_var = torch.autograd.Variable(target, volatile=True) 222 | 223 | # compute output 224 | output = model(input_var) 225 | loss = criterion(output, target_var) 226 | 227 | # measure accuracy and record loss 228 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 229 | losses.update(loss.data[0], input.size(0)) 230 | top1.update(prec1[0], input.size(0)) 231 | top5.update(prec5[0], input.size(0)) 232 | 233 | # measure elapsed time 234 | batch_time.update(time.time() - end) 235 | end = time.time() 236 | 237 | if i % args.print_freq == 0: 238 | print('Test: [{0}/{1}]\t' 239 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 240 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 241 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 242 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 243 | i, len(val_loader), batch_time=batch_time, loss=losses, 244 | top1=top1, top5=top5)) 245 | 246 | print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}' 247 | .format(top1=top1, top5=top5)) 248 | 249 | return top1.avg 250 | 251 | 252 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 253 | torch.save(state, filename) 254 | if is_best: 255 | shutil.copyfile(filename, 'model_best.pth.tar') 256 | 257 | 258 | class AverageMeter(object): 259 | """Computes and stores the average and current value""" 260 | def __init__(self): 261 | self.reset() 262 | 263 | def reset(self): 264 | self.val = 0 265 | self.avg = 0 266 | self.sum = 0 267 | self.count = 0 268 | 269 | def update(self, val, n=1): 270 | self.val = val 271 | self.sum += val * n 272 | self.count += n 273 | self.avg = 
self.sum / self.count 274 | 275 | 276 | def adjust_learning_rate(optimizer, epoch): 277 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 278 | lr = args.lr 279 | wd = 1e-4 280 | milestone = 15 #after epoch milestone, lr is reduced exponentially 281 | if epoch > milestone: 282 | lr = args.lr * (0.95 ** (epoch-milestone)) 283 | wd = 0 284 | for param_group in optimizer.param_groups: 285 | param_group['lr'] = lr 286 | param_group['weight_decay'] = wd 287 | 288 | def accuracy(output, target, topk=(1,)): 289 | """Computes the precision@k for the specified values of k""" 290 | maxk = max(topk) 291 | batch_size = target.size(0) 292 | 293 | _, pred = output.topk(maxk, 1, True, True) 294 | pred = pred.t() 295 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 296 | 297 | res = [] 298 | for k in topk: 299 | correct_k = correct[:k].view(-1).float().sum(0) 300 | res.append(correct_k.mul_(100.0 / batch_size)) 301 | return res 302 | 303 | 304 | if __name__ == '__main__': 305 | main() 306 | -------------------------------------------------------------------------------- /license.txt: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. 
If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. 
Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | 142 | Section 2 -- Scope. 143 | 144 | a. License grant. 145 | 146 | 1. Subject to the terms and conditions of this Public License, 147 | the Licensor hereby grants You a worldwide, royalty-free, 148 | non-sublicensable, non-exclusive, irrevocable license to 149 | exercise the Licensed Rights in the Licensed Material to: 150 | 151 | a. reproduce and Share the Licensed Material, in whole or 152 | in part, for NonCommercial purposes only; and 153 | 154 | b. produce, reproduce, and Share Adapted Material for 155 | NonCommercial purposes only. 156 | 157 | 2. Exceptions and Limitations. For the avoidance of doubt, where 158 | Exceptions and Limitations apply to Your use, this Public 159 | License does not apply, and You do not need to comply with 160 | its terms and conditions. 161 | 162 | 3. Term. The term of this Public License is specified in Section 163 | 6(a). 164 | 165 | 4. Media and formats; technical modifications allowed. The 166 | Licensor authorizes You to exercise the Licensed Rights in 167 | all media and formats whether now known or hereafter created, 168 | and to make technical modifications necessary to do so. The 169 | Licensor waives and/or agrees not to assert any right or 170 | authority to forbid You from making technical modifications 171 | necessary to exercise the Licensed Rights, including 172 | technical modifications necessary to circumvent Effective 173 | Technological Measures. For purposes of this Public License, 174 | simply making modifications authorized by this Section 2(a) 175 | (4) never produces Adapted Material. 176 | 177 | 5. Downstream recipients. 178 | 179 | a. Offer from the Licensor -- Licensed Material. 
Every 180 | recipient of the Licensed Material automatically 181 | receives an offer from the Licensor to exercise the 182 | Licensed Rights under the terms and conditions of this 183 | Public License. 184 | 185 | b. No downstream restrictions. You may not offer or impose 186 | any additional or different terms or conditions on, or 187 | apply any Effective Technological Measures to, the 188 | Licensed Material if doing so restricts exercise of the 189 | Licensed Rights by any recipient of the Licensed 190 | Material. 191 | 192 | 6. No endorsement. Nothing in this Public License constitutes or 193 | may be construed as permission to assert or imply that You 194 | are, or that Your use of the Licensed Material is, connected 195 | with, or sponsored, endorsed, or granted official status by, 196 | the Licensor or others designated to receive attribution as 197 | provided in Section 3(a)(1)(A)(i). 198 | 199 | b. Other rights. 200 | 201 | 1. Moral rights, such as the right of integrity, are not 202 | licensed under this Public License, nor are publicity, 203 | privacy, and/or other similar personality rights; however, to 204 | the extent possible, the Licensor waives and/or agrees not to 205 | assert any such rights held by the Licensor to the limited 206 | extent necessary to allow You to exercise the Licensed 207 | Rights, but not otherwise. 208 | 209 | 2. Patent and trademark rights are not licensed under this 210 | Public License. 211 | 212 | 3. To the extent possible, the Licensor waives any right to 213 | collect royalties from You for the exercise of the Licensed 214 | Rights, whether directly or through a collecting society 215 | under any voluntary or waivable statutory or compulsory 216 | licensing scheme. In all other cases the Licensor expressly 217 | reserves any right to collect such royalties, including when 218 | the Licensed Material is used other than for NonCommercial 219 | purposes. 220 | 221 | 222 | Section 3 -- License Conditions. 223 | 224 | Your exercise of the Licensed Rights is expressly made subject to the 225 | following conditions. 226 | 227 | a. Attribution. 228 | 229 | 1. If You Share the Licensed Material (including in modified 230 | form), You must: 231 | 232 | a. retain the following if it is supplied by the Licensor 233 | with the Licensed Material: 234 | 235 | i. identification of the creator(s) of the Licensed 236 | Material and any others designated to receive 237 | attribution, in any reasonable manner requested by 238 | the Licensor (including by pseudonym if 239 | designated); 240 | 241 | ii. a copyright notice; 242 | 243 | iii. a notice that refers to this Public License; 244 | 245 | iv. a notice that refers to the disclaimer of 246 | warranties; 247 | 248 | v. a URI or hyperlink to the Licensed Material to the 249 | extent reasonably practicable; 250 | 251 | b. indicate if You modified the Licensed Material and 252 | retain an indication of any previous modifications; and 253 | 254 | c. indicate the Licensed Material is licensed under this 255 | Public License, and include the text of, or the URI or 256 | hyperlink to, this Public License. 257 | 258 | 2. You may satisfy the conditions in Section 3(a)(1) in any 259 | reasonable manner based on the medium, means, and context in 260 | which You Share the Licensed Material. For example, it may be 261 | reasonable to satisfy the conditions by providing a URI or 262 | hyperlink to a resource that includes the required 263 | information. 264 | 265 | 3. 
If requested by the Licensor, You must remove any of the 266 | information required by Section 3(a)(1)(A) to the extent 267 | reasonably practicable. 268 | 269 | 4. If You Share Adapted Material You produce, the Adapter's 270 | License You apply must not prevent recipients of the Adapted 271 | Material from complying with this Public License. 272 | 273 | 274 | Section 4 -- Sui Generis Database Rights. 275 | 276 | Where the Licensed Rights include Sui Generis Database Rights that 277 | apply to Your use of the Licensed Material: 278 | 279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 280 | to extract, reuse, reproduce, and Share all or a substantial 281 | portion of the contents of the database for NonCommercial purposes 282 | only; 283 | 284 | b. if You include all or a substantial portion of the database 285 | contents in a database in which You have Sui Generis Database 286 | Rights, then the database in which You have Sui Generis Database 287 | Rights (but not its individual contents) is Adapted Material; and 288 | 289 | c. You must comply with the conditions in Section 3(a) if You Share 290 | all or a substantial portion of the contents of the database. 291 | 292 | For the avoidance of doubt, this Section 4 supplements and does not 293 | replace Your obligations under this Public License where the Licensed 294 | Rights include other Copyright and Similar Rights. 295 | 296 | 297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 298 | 299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 309 | 310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 319 | 320 | c. The disclaimer of warranties and limitation of liability provided 321 | above shall be interpreted in a manner that, to the extent 322 | possible, most closely approximates an absolute disclaimer and 323 | waiver of all liability. 324 | 325 | 326 | Section 6 -- Term and Termination. 327 | 328 | a. This Public License applies for the term of the Copyright and 329 | Similar Rights licensed here. However, if You fail to comply with 330 | this Public License, then Your rights under this Public License 331 | terminate automatically. 332 | 333 | b. Where Your right to use the Licensed Material has terminated under 334 | Section 6(a), it reinstates: 335 | 336 | 1. 
automatically as of the date the violation is cured, provided 337 | it is cured within 30 days of Your discovery of the 338 | violation; or 339 | 340 | 2. upon express reinstatement by the Licensor. 341 | 342 | For the avoidance of doubt, this Section 6(b) does not affect any 343 | right the Licensor may have to seek remedies for Your violations 344 | of this Public License. 345 | 346 | c. For the avoidance of doubt, the Licensor may also offer the 347 | Licensed Material under separate terms or conditions or stop 348 | distributing the Licensed Material at any time; however, doing so 349 | will not terminate this Public License. 350 | 351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 352 | License. 353 | 354 | 355 | Section 7 -- Other Terms and Conditions. 356 | 357 | a. The Licensor shall not be bound by any additional or different 358 | terms or conditions communicated by You unless expressly agreed. 359 | 360 | b. Any arrangements, understandings, or agreements regarding the 361 | Licensed Material not stated herein are separate from and 362 | independent of the terms and conditions of this Public License. 363 | 364 | 365 | Section 8 -- Interpretation. 366 | 367 | a. For the avoidance of doubt, this Public License does not, and 368 | shall not be interpreted to, reduce, limit, restrict, or impose 369 | conditions on any use of the Licensed Material that could lawfully 370 | be made without permission under this Public License. 371 | 372 | b. To the extent possible, if any provision of this Public License is 373 | deemed unenforceable, it shall be automatically reformed to the 374 | minimum extent necessary to make it enforceable. If the provision 375 | cannot be reformed, it shall be severed from this Public License 376 | without affecting the enforceability of the remaining terms and 377 | conditions. 378 | 379 | c. No term or condition of this Public License will be waived and no 380 | failure to comply consented to unless expressly agreed to by the 381 | Licensor. 382 | 383 | d. Nothing in this Public License constitutes or may be interpreted 384 | as a limitation upon, or waiver of, any privileges and immunities 385 | that apply to the Licensor or You, including from the legal 386 | processes of any jurisdiction or authority. 387 | 388 | ======================================================================= 389 | 390 | Creative Commons is not a party to its public 391 | licenses. Notwithstanding, Creative Commons may elect to apply one of 392 | its public licenses to material it publishes and in those instances 393 | will be considered the “Licensor.” The text of the Creative Commons 394 | public licenses is dedicated to the public domain under the CC0 Public 395 | Domain Dedication. Except for the limited purpose of indicating that 396 | material is shared under a Creative Commons public license or as 397 | otherwise permitted by the Creative Commons policies published at 398 | creativecommons.org/policies, Creative Commons does not authorize the 399 | use of the trademark "Creative Commons" or any other trademark or logo 400 | of Creative Commons without its prior written consent including, 401 | without limitation, in connection with any unauthorized modifications 402 | to any of its public licenses or any other arrangements, 403 | understandings, or agreements concerning use of licensed material. For 404 | the avoidance of doubt, this paragraph does not form part of the 405 | public licenses. 
406 | 407 | Creative Commons may be contacted at creativecommons.org. 408 | 409 | -------------------------------------------------------------------------------- /train/README.md: -------------------------------------------------------------------------------- 1 | # Training ERFNet in PyTorch 2 | 3 | PyTorch code for training the ERFNet model on Cityscapes. The code was initially based on the code from [bodokaiser/piwise](https://github.com/bodokaiser/piwise), adapted with several custom modifications and tweaks. Some of them are: 4 | - Load cityscapes dataset 5 | - ERFNet model definition 6 | - Calculate IoU on each epoch during training 7 | - Save snapshots and best model during training 8 | - Save additional output files useful for checking results (see below "Output files...") 9 | - Resume training from checkpoint (use "--resume" flag in the command) 10 | 11 | ## Options 12 | For all options and defaults please see the bottom of the "main.py" file. The required ones are --savedir (name for creating a new folder with all the outputs of the training) and --datadir (path to the Cityscapes directory). 13 | 14 | ## Example commands 15 | Train the encoder for 150 epochs with batch size 6 and then train the decoder (decoder training starts after encoder training): 16 | ``` 17 | python main.py --savedir erfnet_training1 --datadir /home/datasets/cityscapes/ --num-epochs 150 --batch-size 6 18 | ``` 19 | 20 | Train the decoder using the encoder's weights pretrained on ImageNet: 21 | ``` 22 | python main.py --savedir erfnet_training1 --datadir /home/datasets/cityscapes/ --num-epochs 150 --batch-size 6 --decoder --pretrainedEncoder "../trained_models/erfnet_encoder_pretrained.pth.tar" 23 | ``` 24 | 25 | ## Output files generated for each training: 26 | Each training will create a new folder in the "erfnet_pytorch/save/" directory named after the --savedir parameter and containing the following files: 27 | * **automated_log.txt**: Plain text file that contains, in columns, the following info for each epoch: {Epoch, Train-loss, Test-loss, Train-IoU, Test-IoU, learningRate}. It can be plotted with Gnuplot or Excel (or with the snippet shown after this list). 28 | * **best.txt**: Plain text file containing a line with the best IoU achieved during training and its epoch. 29 | * **checkpoint.pth.tar**: Bundle file containing the checkpoint of the last trained epoch, with the following elements: 'epoch' (epoch number as int), 'arch' (net definition as a string), 'state_dict' (saved weights dictionary loadable by pytorch), 'best_acc' (best achieved accuracy as float), 'optimizer' (saved optimizer parameters). 30 | * **{model}.py**: Copy of the model file used (default erfnet.py). 31 | * **model.txt**: Plain text file that displays the model's layers. 32 | * **model_best.pth**: Saved weights of the epoch that achieved the best val accuracy. 33 | * **model_best.pth.tar**: Same parameters as "checkpoint.pth.tar" but for the epoch with the best val accuracy. 34 | * **opts.txt**: Plain text file containing the options used for this training. 35 | 36 | NOTE: Encoder trainings have an added "_encoder" tag to each file's name.
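As a quick sanity check of these output files, both the log and the checkpoint bundle can be inspected with a few lines of Python. The following is only a minimal sketch, not part of the toolbox: it assumes the example run name "erfnet_training1" from the commands above, that it is executed from the "train" folder, and that numpy and matplotlib are installed (column order follows automated_log.txt; checkpoint keys follow the list above):

```
# Minimal sketch: plot the IoU columns of automated_log.txt and peek into checkpoint.pth.tar
import numpy as np
import matplotlib.pyplot as plt
import torch

savedir = '../save/erfnet_training1'   # assumed run name, taken from the example commands above

# Columns: Epoch, Train-loss, Test-loss, Train-IoU, Test-IoU, learningRate (tab-separated, one header line)
log = np.atleast_2d(np.loadtxt(savedir + '/automated_log.txt', skiprows=1))
plt.plot(log[:, 0], log[:, 3], label='Train-IoU')   # stays 0 unless --iouTrain was used
plt.plot(log[:, 0], log[:, 4], label='Val-IoU')
plt.xlabel('epoch'); plt.ylabel('IoU'); plt.legend(); plt.show()

# checkpoint.pth.tar is a plain dict written by torch.save
ckpt = torch.load(savedir + '/checkpoint.pth.tar', map_location=lambda storage, loc: storage)
print(ckpt['epoch'], ckpt['best_acc'])
```

The 'state_dict' and 'optimizer' entries of the same dict are what main.py reloads when the "--resume" flag is used.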
37 | 38 | ## IoU display during training 39 | 40 | NEW: In previous code, IoU was calculated using a port of the cityscapes scripts, but new code has been added in "iouEval.py" to make it class-general, independent of other code, and much faster (using CUDA). 41 | 42 | By default, only the Validation IoU is calculated, for faster training (this can be changed in the options). 43 | 44 | ## Visualization 45 | If you want to visualize the outputs during training, add the "--visualize" flag and open an extra tab with: 46 | ``` 47 | python -m visdom.server -port 8097 48 | ``` 49 | The plots will be available in your browser at http://localhost:8097 50 | 51 | ## Multi-GPU 52 | If you wish to specify which GPUs to use, set the CUDA_VISIBLE_DEVICES environment variable: 53 | ``` 54 | CUDA_VISIBLE_DEVICES=0 python main.py ... 55 | CUDA_VISIBLE_DEVICES=0,1 python main.py ... 56 | ``` 57 | 58 | 59 | -------------------------------------------------------------------------------- /train/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | from PIL import Image 5 | 6 | from torch.utils.data import Dataset 7 | 8 | EXTENSIONS = ['.jpg', '.png'] 9 | 10 | def load_image(file): 11 | return Image.open(file) 12 | 13 | def is_image(filename): 14 | return any(filename.endswith(ext) for ext in EXTENSIONS) 15 | 16 | def is_label(filename): 17 | return filename.endswith("_labelTrainIds.png") 18 | 19 | def image_path(root, basename, extension): 20 | return os.path.join(root, f'{basename}{extension}') 21 | 22 | def image_path_city(root, name): 23 | return os.path.join(root, f'{name}') 24 | 25 | def image_basename(filename): 26 | return os.path.basename(os.path.splitext(filename)[0]) 27 | 28 | class VOC12(Dataset): 29 | 30 | def __init__(self, root, input_transform=None, target_transform=None): 31 | self.images_root = os.path.join(root, 'images') 32 | self.labels_root = os.path.join(root, 'labels') 33 | 34 | self.filenames = [image_basename(f) 35 | for f in os.listdir(self.labels_root) if is_image(f)] 36 | self.filenames.sort() 37 | 38 | self.input_transform = input_transform 39 | self.target_transform = target_transform 40 | 41 | def __getitem__(self, index): 42 | filename = self.filenames[index] 43 | 44 | with open(image_path(self.images_root, filename, '.jpg'), 'rb') as f: 45 | image = load_image(f).convert('RGB') 46 | with open(image_path(self.labels_root, filename, '.png'), 'rb') as f: 47 | label = load_image(f).convert('P') 48 | 49 | if self.input_transform is not None: 50 | image = self.input_transform(image) 51 | if self.target_transform is not None: 52 | label = self.target_transform(label) 53 | 54 | return image, label 55 | 56 | def __len__(self): 57 | return len(self.filenames) 58 | 59 | 60 | 61 | 62 | class cityscapes(Dataset): 63 | 64 | def __init__(self, root, co_transform=None, subset='train'): 65 | self.images_root = os.path.join(root, 'leftImg8bit/') 66 | self.labels_root = os.path.join(root, 'gtFine/') 67 | 68 | self.images_root += subset 69 | self.labels_root += subset 70 | 71 | print (self.images_root) 72 | #self.filenames = [image_basename(f) for f in os.listdir(self.images_root) if is_image(f)] 73 | self.filenames = [os.path.join(dp, f) for dp, dn, fn in os.walk(os.path.expanduser(self.images_root)) for f in fn if is_image(f)] 74 | self.filenames.sort() 75 | 76 | #[os.path.join(dp, f) for dp, dn, fn in os.walk(os.path.expanduser(".")) for f in fn] 77 | #self.filenamesGt = [image_basename(f) for f in os.listdir(self.labels_root) if
is_image(f)] 78 | self.filenamesGt = [os.path.join(dp, f) for dp, dn, fn in os.walk(os.path.expanduser(self.labels_root)) for f in fn if is_label(f)] 79 | self.filenamesGt.sort() 80 | 81 | self.co_transform = co_transform # ADDED THIS 82 | 83 | 84 | def __getitem__(self, index): 85 | filename = self.filenames[index] 86 | filenameGt = self.filenamesGt[index] 87 | 88 | with open(image_path_city(self.images_root, filename), 'rb') as f: 89 | image = load_image(f).convert('RGB') 90 | with open(image_path_city(self.labels_root, filenameGt), 'rb') as f: 91 | label = load_image(f).convert('P') 92 | 93 | if self.co_transform is not None: 94 | image, label = self.co_transform(image, label) 95 | 96 | return image, label 97 | 98 | def __len__(self): 99 | return len(self.filenames) 100 | 101 | -------------------------------------------------------------------------------- /train/erfnet.py: -------------------------------------------------------------------------------- 1 | # ERFNet full model definition for Pytorch 2 | # Sept 2017 3 | # Eduardo Romera 4 | ####################### 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.init as init 9 | import torch.nn.functional as F 10 | 11 | class DownsamplerBlock (nn.Module): 12 | def __init__(self, ninput, noutput): 13 | super().__init__() 14 | 15 | self.conv = nn.Conv2d(ninput, noutput-ninput, (3, 3), stride=2, padding=1, bias=True) 16 | self.pool = nn.MaxPool2d(2, stride=2) 17 | self.bn = nn.BatchNorm2d(noutput, eps=1e-3) 18 | 19 | def forward(self, input): 20 | output = torch.cat([self.conv(input), self.pool(input)], 1) 21 | output = self.bn(output) 22 | return F.relu(output) 23 | 24 | 25 | class non_bottleneck_1d (nn.Module): 26 | def __init__(self, chann, dropprob, dilated): 27 | super().__init__() 28 | 29 | self.conv3x1_1 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1,0), bias=True) 30 | 31 | self.conv1x3_1 = nn.Conv2d(chann, chann, (1,3), stride=1, padding=(0,1), bias=True) 32 | 33 | self.bn1 = nn.BatchNorm2d(chann, eps=1e-03) 34 | 35 | self.conv3x1_2 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1*dilated,0), bias=True, dilation = (dilated,1)) 36 | 37 | self.conv1x3_2 = nn.Conv2d(chann, chann, (1,3), stride=1, padding=(0,1*dilated), bias=True, dilation = (1, dilated)) 38 | 39 | self.bn2 = nn.BatchNorm2d(chann, eps=1e-03) 40 | 41 | self.dropout = nn.Dropout2d(dropprob) 42 | 43 | 44 | def forward(self, input): 45 | 46 | output = self.conv3x1_1(input) 47 | output = F.relu(output) 48 | output = self.conv1x3_1(output) 49 | output = self.bn1(output) 50 | output = F.relu(output) 51 | 52 | output = self.conv3x1_2(output) 53 | output = F.relu(output) 54 | output = self.conv1x3_2(output) 55 | output = self.bn2(output) 56 | 57 | if (self.dropout.p != 0): 58 | output = self.dropout(output) 59 | 60 | return F.relu(output+input) #+input = identity (residual connection) 61 | 62 | 63 | class Encoder(nn.Module): 64 | def __init__(self, num_classes): 65 | super().__init__() 66 | self.initial_block = DownsamplerBlock(3,16) 67 | 68 | self.layers = nn.ModuleList() 69 | 70 | self.layers.append(DownsamplerBlock(16,64)) 71 | 72 | for x in range(0, 5): #5 times 73 | self.layers.append(non_bottleneck_1d(64, 0.03, 1)) 74 | 75 | self.layers.append(DownsamplerBlock(64,128)) 76 | 77 | for x in range(0, 2): #2 times 78 | self.layers.append(non_bottleneck_1d(128, 0.3, 2)) 79 | self.layers.append(non_bottleneck_1d(128, 0.3, 4)) 80 | self.layers.append(non_bottleneck_1d(128, 0.3, 8)) 81 | self.layers.append(non_bottleneck_1d(128, 0.3, 16)) 82 | 83 | 
#Only in encoder mode: 84 | self.output_conv = nn.Conv2d(128, num_classes, 1, stride=1, padding=0, bias=True) 85 | 86 | def forward(self, input, predict=False): 87 | output = self.initial_block(input) 88 | 89 | for layer in self.layers: 90 | output = layer(output) 91 | 92 | if predict: 93 | output = self.output_conv(output) 94 | 95 | return output 96 | 97 | 98 | class UpsamplerBlock (nn.Module): 99 | def __init__(self, ninput, noutput): 100 | super().__init__() 101 | self.conv = nn.ConvTranspose2d(ninput, noutput, 3, stride=2, padding=1, output_padding=1, bias=True) 102 | self.bn = nn.BatchNorm2d(noutput, eps=1e-3) 103 | 104 | def forward(self, input): 105 | output = self.conv(input) 106 | output = self.bn(output) 107 | return F.relu(output) 108 | 109 | class Decoder (nn.Module): 110 | def __init__(self, num_classes): 111 | super().__init__() 112 | 113 | self.layers = nn.ModuleList() 114 | 115 | self.layers.append(UpsamplerBlock(128,64)) 116 | self.layers.append(non_bottleneck_1d(64, 0, 1)) 117 | self.layers.append(non_bottleneck_1d(64, 0, 1)) 118 | 119 | self.layers.append(UpsamplerBlock(64,16)) 120 | self.layers.append(non_bottleneck_1d(16, 0, 1)) 121 | self.layers.append(non_bottleneck_1d(16, 0, 1)) 122 | 123 | self.output_conv = nn.ConvTranspose2d( 16, num_classes, 2, stride=2, padding=0, output_padding=0, bias=True) 124 | 125 | def forward(self, input): 126 | output = input 127 | 128 | for layer in self.layers: 129 | output = layer(output) 130 | 131 | output = self.output_conv(output) 132 | 133 | return output 134 | 135 | #ERFNet 136 | class Net(nn.Module): 137 | def __init__(self, num_classes, encoder=None): #use encoder to pass pretrained encoder 138 | super().__init__() 139 | 140 | if (encoder == None): 141 | self.encoder = Encoder(num_classes) 142 | else: 143 | self.encoder = encoder 144 | self.decoder = Decoder(num_classes) 145 | 146 | def forward(self, input, only_encode=False): 147 | if only_encode: 148 | return self.encoder.forward(input, predict=True) 149 | else: 150 | output = self.encoder(input) #predict=False by default 151 | return self.decoder.forward(output) 152 | -------------------------------------------------------------------------------- /train/erfnet_imagenet.py: -------------------------------------------------------------------------------- 1 | # ERFNet encoder model definition used for pretraining in ImageNet 2 | # Sept 2017 3 | # Eduardo Romera 4 | ####################### 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.init as init 9 | import torch.nn.functional as F 10 | 11 | class DownsamplerBlock (nn.Module): 12 | def __init__(self, ninput, noutput): 13 | super().__init__() 14 | 15 | self.conv = nn.Conv2d(ninput, noutput-ninput, (3, 3), stride=2, padding=1, bias=True) 16 | self.pool = nn.MaxPool2d(2, stride=2) 17 | self.bn = nn.BatchNorm2d(noutput, eps=1e-3) 18 | 19 | def forward(self, input): 20 | output = torch.cat([self.conv(input), self.pool(input)], 1) 21 | output = self.bn(output) 22 | return F.relu(output) 23 | 24 | 25 | class non_bottleneck_1d (nn.Module): 26 | def __init__(self, chann, dropprob, dilated): 27 | super().__init__() 28 | 29 | 30 | self.conv3x1_1 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1,0), bias=True) 31 | 32 | self.conv1x3_1 = nn.Conv2d(chann, chann, (1,3), stride=1, padding=(0,1), bias=True) 33 | 34 | 35 | self.conv3x1_2 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1*dilated,0), bias=True, dilation = (dilated,1)) 36 | 37 | self.conv1x3_2 = nn.Conv2d(chann, chann, (1,3), stride=1, 
padding=(0,1*dilated), bias=True, dilation = (1, dilated)) 38 | 39 | self.bn1 = nn.BatchNorm2d(chann, eps=1e-03) 40 | self.bn2 = nn.BatchNorm2d(chann, eps=1e-03) 41 | 42 | self.dropout = nn.Dropout2d(dropprob) 43 | 44 | 45 | def forward(self, input): 46 | 47 | output = self.conv3x1_1(input) 48 | output = F.relu(output) 49 | output = self.conv1x3_1(output) 50 | output = self.bn1(output) 51 | output = F.relu(output) 52 | 53 | output = self.conv3x1_2(output) 54 | output = F.relu(output) 55 | output = self.conv1x3_2(output) 56 | output = self.bn2(output) 57 | 58 | if (self.dropout.p != 0): 59 | output = self.dropout(output) 60 | 61 | return F.relu(output+input) #+input = identity (residual connection) 62 | 63 | 64 | class Encoder(nn.Module): 65 | def __init__(self): 66 | super().__init__() 67 | self.initial_block = DownsamplerBlock(3,16) 68 | 69 | self.layers = nn.ModuleList() 70 | 71 | self.layers.append(DownsamplerBlock(16,64)) 72 | 73 | for x in range(0, 5): #5 times 74 | self.layers.append(non_bottleneck_1d(64, 0.1, 1)) 75 | 76 | self.layers.append(DownsamplerBlock(64,128)) 77 | 78 | for x in range(0, 2): #2 times 79 | self.layers.append(non_bottleneck_1d(128, 0.1, 2)) 80 | self.layers.append(non_bottleneck_1d(128, 0.1, 4)) 81 | self.layers.append(non_bottleneck_1d(128, 0.1, 8)) 82 | self.layers.append(non_bottleneck_1d(128, 0.1, 16)) 83 | 84 | 85 | def forward(self, input): 86 | output = self.initial_block(input) 87 | 88 | for layer in self.layers: 89 | output = layer(output) 90 | 91 | return output 92 | 93 | 94 | class Features(nn.Module): 95 | def __init__(self): 96 | super().__init__() 97 | self.encoder = Encoder() 98 | self.extralayer1 = nn.MaxPool2d(2, stride=2) 99 | self.extralayer2 = nn.AvgPool2d(14,1,0) 100 | 101 | def forward(self, input): 102 | #print("Feat input: ", input.size()) 103 | output = self.encoder(input) 104 | output = self.extralayer1(output) 105 | output = self.extralayer2(output) 106 | #print("Feat output: ", output.size()) 107 | return output 108 | 109 | class Classifier(nn.Module): 110 | def __init__(self, num_classes): 111 | super().__init__() 112 | self.linear = nn.Linear(128, num_classes) 113 | 114 | def forward(self, input): 115 | output = input.view(input.size(0), 128) #first is batch_size 116 | output = self.linear(output) 117 | return output 118 | 119 | class ERFNet(nn.Module): 120 | def __init__(self, num_classes): #use encoder to pass pretrained encoder 121 | super().__init__() 122 | 123 | self.features = Features() 124 | self.classifier = Classifier(num_classes) 125 | 126 | def forward(self, input): 127 | output = self.features(input) 128 | output = self.classifier(output) 129 | return output 130 | 131 | 132 | -------------------------------------------------------------------------------- /train/iouEval.py: -------------------------------------------------------------------------------- 1 | # Code for evaluating IoU 2 | # Nov 2017 3 | # Eduardo Romera 4 | ####################### 5 | 6 | import torch 7 | 8 | class iouEval: 9 | 10 | def __init__(self, nClasses, ignoreIndex=19): 11 | self.nClasses = nClasses 12 | self.ignoreIndex = ignoreIndex if nClasses>ignoreIndex else -1 #if ignoreIndex is larger than nClasses, consider no ignoreIndex 13 | self.reset() 14 | 15 | def reset (self): 16 | classes = self.nClasses if self.ignoreIndex==-1 else self.nClasses-1 17 | self.tp = torch.zeros(classes).double() 18 | self.fp = torch.zeros(classes).double() 19 | self.fn = torch.zeros(classes).double() 20 | 21 | def addBatch(self, x, y): #x=preds, y=targets 22 | #sizes 
should be "batch_size x nClasses x H x W" 23 | 24 | #print ("X is cuda: ", x.is_cuda) 25 | #print ("Y is cuda: ", y.is_cuda) 26 | 27 | if (x.is_cuda or y.is_cuda): 28 | x = x.cuda() 29 | y = y.cuda() 30 | 31 | #if size is "batch_size x 1 x H x W" scatter to onehot 32 | if (x.size(1) == 1): 33 | x_onehot = torch.zeros(x.size(0), self.nClasses, x.size(2), x.size(3)) 34 | if x.is_cuda: 35 | x_onehot = x_onehot.cuda() 36 | x_onehot.scatter_(1, x, 1).float() 37 | else: 38 | x_onehot = x.float() 39 | 40 | if (y.size(1) == 1): 41 | y_onehot = torch.zeros(y.size(0), self.nClasses, y.size(2), y.size(3)) 42 | if y.is_cuda: 43 | y_onehot = y_onehot.cuda() 44 | y_onehot.scatter_(1, y, 1).float() 45 | else: 46 | y_onehot = y.float() 47 | 48 | if (self.ignoreIndex != -1): 49 | ignores = y_onehot[:,self.ignoreIndex].unsqueeze(1) 50 | x_onehot = x_onehot[:, :self.ignoreIndex] 51 | y_onehot = y_onehot[:, :self.ignoreIndex] 52 | else: 53 | ignores=0 54 | 55 | #print(type(x_onehot)) 56 | #print(type(y_onehot)) 57 | #print(x_onehot.size()) 58 | #print(y_onehot.size()) 59 | 60 | tpmult = x_onehot * y_onehot #times prediction and gt coincide is 1 61 | tp = torch.sum(torch.sum(torch.sum(tpmult, dim=0, keepdim=True), dim=2, keepdim=True), dim=3, keepdim=True).squeeze() 62 | fpmult = x_onehot * (1-y_onehot-ignores) #times prediction says its that class and gt says its not (subtracting cases when its ignore label!) 63 | fp = torch.sum(torch.sum(torch.sum(fpmult, dim=0, keepdim=True), dim=2, keepdim=True), dim=3, keepdim=True).squeeze() 64 | fnmult = (1-x_onehot) * (y_onehot) #times prediction says its not that class and gt says it is 65 | fn = torch.sum(torch.sum(torch.sum(fnmult, dim=0, keepdim=True), dim=2, keepdim=True), dim=3, keepdim=True).squeeze() 66 | 67 | self.tp += tp.double().cpu() 68 | self.fp += fp.double().cpu() 69 | self.fn += fn.double().cpu() 70 | 71 | def getIoU(self): 72 | num = self.tp 73 | den = self.tp + self.fp + self.fn + 1e-15 74 | iou = num / den 75 | return torch.mean(iou), iou #returns "iou mean", "iou per class" 76 | 77 | # Class for colors 78 | class colors: 79 | RED = '\033[31;1m' 80 | GREEN = '\033[32;1m' 81 | YELLOW = '\033[33;1m' 82 | BLUE = '\033[34;1m' 83 | MAGENTA = '\033[35;1m' 84 | CYAN = '\033[36;1m' 85 | BOLD = '\033[1m' 86 | UNDERLINE = '\033[4m' 87 | ENDC = '\033[0m' 88 | 89 | # Colored value output if colorized flag is activated. 
90 | def getColorEntry(val): 91 | if not isinstance(val, float): 92 | return colors.ENDC 93 | if (val < .20): 94 | return colors.RED 95 | elif (val < .40): 96 | return colors.YELLOW 97 | elif (val < .60): 98 | return colors.BLUE 99 | elif (val < .80): 100 | return colors.CYAN 101 | else: 102 | return colors.GREEN 103 | 104 | -------------------------------------------------------------------------------- /train/main.py: -------------------------------------------------------------------------------- 1 | # Main code for training ERFNet model in Cityscapes dataset 2 | # Sept 2017 3 | # Eduardo Romera 4 | ####################### 5 | 6 | import os 7 | import random 8 | import time 9 | import numpy as np 10 | import torch 11 | import math 12 | 13 | from PIL import Image, ImageOps 14 | from argparse import ArgumentParser 15 | 16 | from torch.optim import SGD, Adam, lr_scheduler 17 | from torch.autograd import Variable 18 | from torch.utils.data import DataLoader 19 | from torchvision.transforms import Compose, CenterCrop, Normalize, Resize, Pad 20 | from torchvision.transforms import ToTensor, ToPILImage 21 | 22 | from dataset import VOC12,cityscapes 23 | from transform import Relabel, ToLabel, Colorize 24 | from visualize import Dashboard 25 | 26 | import importlib 27 | from iouEval import iouEval, getColorEntry 28 | 29 | from shutil import copyfile 30 | 31 | NUM_CHANNELS = 3 32 | NUM_CLASSES = 20 #pascal=22, cityscapes=20 33 | 34 | color_transform = Colorize(NUM_CLASSES) 35 | image_transform = ToPILImage() 36 | 37 | #Augmentations - different function implemented to perform random augments on both image and target 38 | class MyCoTransform(object): 39 | def __init__(self, enc, augment=True, height=512): 40 | self.enc=enc 41 | self.augment = augment 42 | self.height = height 43 | pass 44 | def __call__(self, input, target): 45 | # do something to both images 46 | input = Resize(self.height, Image.BILINEAR)(input) 47 | target = Resize(self.height, Image.NEAREST)(target) 48 | 49 | if(self.augment): 50 | # Random hflip 51 | hflip = random.random() 52 | if (hflip < 0.5): 53 | input = input.transpose(Image.FLIP_LEFT_RIGHT) 54 | target = target.transpose(Image.FLIP_LEFT_RIGHT) 55 | 56 | #Random translation 0-2 pixels (fill rest with padding 57 | transX = random.randint(-2, 2) 58 | transY = random.randint(-2, 2) 59 | 60 | input = ImageOps.expand(input, border=(transX,transY,0,0), fill=0) 61 | target = ImageOps.expand(target, border=(transX,transY,0,0), fill=255) #pad label filling with 255 62 | input = input.crop((0, 0, input.size[0]-transX, input.size[1]-transY)) 63 | target = target.crop((0, 0, target.size[0]-transX, target.size[1]-transY)) 64 | 65 | input = ToTensor()(input) 66 | if (self.enc): 67 | target = Resize(int(self.height/8), Image.NEAREST)(target) 68 | target = ToLabel()(target) 69 | target = Relabel(255, 19)(target) 70 | 71 | return input, target 72 | 73 | 74 | class CrossEntropyLoss2d(torch.nn.Module): 75 | 76 | def __init__(self, weight=None): 77 | super().__init__() 78 | 79 | self.loss = torch.nn.NLLLoss2d(weight) 80 | 81 | def forward(self, outputs, targets): 82 | return self.loss(torch.nn.functional.log_softmax(outputs, dim=1), targets) 83 | 84 | 85 | def train(args, model, enc=False): 86 | best_acc = 0 87 | 88 | #TODO: calculate weights by processing dataset histogram (now its being set by hand from the torch values) 89 | #create a loder to run all images and calculate histogram of labels, then create weight array using class balancing 90 | 91 | weight = torch.ones(NUM_CLASSES) 92 | 
if (enc): 93 | weight[0] = 2.3653597831726 94 | weight[1] = 4.4237880706787 95 | weight[2] = 2.9691488742828 96 | weight[3] = 5.3442072868347 97 | weight[4] = 5.2983593940735 98 | weight[5] = 5.2275490760803 99 | weight[6] = 5.4394111633301 100 | weight[7] = 5.3659925460815 101 | weight[8] = 3.4170460700989 102 | weight[9] = 5.2414722442627 103 | weight[10] = 4.7376127243042 104 | weight[11] = 5.2286224365234 105 | weight[12] = 5.455126285553 106 | weight[13] = 4.3019247055054 107 | weight[14] = 5.4264230728149 108 | weight[15] = 5.4331531524658 109 | weight[16] = 5.433765411377 110 | weight[17] = 5.4631009101868 111 | weight[18] = 5.3947434425354 112 | else: 113 | weight[0] = 2.8149201869965 114 | weight[1] = 6.9850029945374 115 | weight[2] = 3.7890393733978 116 | weight[3] = 9.9428062438965 117 | weight[4] = 9.7702074050903 118 | weight[5] = 9.5110931396484 119 | weight[6] = 10.311357498169 120 | weight[7] = 10.026463508606 121 | weight[8] = 4.6323022842407 122 | weight[9] = 9.5608062744141 123 | weight[10] = 7.8698215484619 124 | weight[11] = 9.5168733596802 125 | weight[12] = 10.373730659485 126 | weight[13] = 6.6616044044495 127 | weight[14] = 10.260489463806 128 | weight[15] = 10.287888526917 129 | weight[16] = 10.289801597595 130 | weight[17] = 10.405355453491 131 | weight[18] = 10.138095855713 132 | 133 | weight[19] = 0 134 | 135 | assert os.path.exists(args.datadir), "Error: datadir (dataset directory) could not be loaded" 136 | 137 | co_transform = MyCoTransform(enc, augment=True, height=args.height)#1024) 138 | co_transform_val = MyCoTransform(enc, augment=False, height=args.height)#1024) 139 | dataset_train = cityscapes(args.datadir, co_transform, 'train') 140 | dataset_val = cityscapes(args.datadir, co_transform_val, 'val') 141 | 142 | loader = DataLoader(dataset_train, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=True) 143 | loader_val = DataLoader(dataset_val, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False) 144 | 145 | if args.cuda: 146 | weight = weight.cuda() 147 | criterion = CrossEntropyLoss2d(weight) 148 | print(type(criterion)) 149 | 150 | savedir = f'../save/{args.savedir}' 151 | 152 | if (enc): 153 | automated_log_path = savedir + "/automated_log_encoder.txt" 154 | modeltxtpath = savedir + "/model_encoder.txt" 155 | else: 156 | automated_log_path = savedir + "/automated_log.txt" 157 | modeltxtpath = savedir + "/model.txt" 158 | 159 | if (not os.path.exists(automated_log_path)): #dont add first line if it exists 160 | with open(automated_log_path, "a") as myfile: 161 | myfile.write("Epoch\t\tTrain-loss\t\tTest-loss\t\tTrain-IoU\t\tTest-IoU\t\tlearningRate") 162 | 163 | with open(modeltxtpath, "w") as myfile: 164 | myfile.write(str(model)) 165 | 166 | 167 | #TODO: reduce memory in first gpu: https://discuss.pytorch.org/t/multi-gpu-training-memory-usage-in-balance/4163/4 #https://github.com/pytorch/pytorch/issues/1893 168 | 169 | #optimizer = Adam(model.parameters(), 5e-4, (0.9, 0.999), eps=1e-08, weight_decay=2e-4) ## scheduler 1 170 | optimizer = Adam(model.parameters(), 5e-4, (0.9, 0.999), eps=1e-08, weight_decay=1e-4) ## scheduler 2 171 | 172 | start_epoch = 1 173 | if args.resume: 174 | #Must load weights, optimizer, epoch and best value. 
175 | if enc: 176 | filenameCheckpoint = savedir + '/checkpoint_enc.pth.tar' 177 | else: 178 | filenameCheckpoint = savedir + '/checkpoint.pth.tar' 179 | 180 | assert os.path.exists(filenameCheckpoint), "Error: resume option was used but checkpoint was not found in folder" 181 | checkpoint = torch.load(filenameCheckpoint) 182 | start_epoch = checkpoint['epoch'] 183 | model.load_state_dict(checkpoint['state_dict']) 184 | optimizer.load_state_dict(checkpoint['optimizer']) 185 | best_acc = checkpoint['best_acc'] 186 | print("=> Loaded checkpoint at epoch {})".format(checkpoint['epoch'])) 187 | 188 | #scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5) # set up scheduler ## scheduler 1 189 | lambda1 = lambda epoch: pow((1-((epoch-1)/args.num_epochs)),0.9) ## scheduler 2 190 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) ## scheduler 2 191 | 192 | if args.visualize and args.steps_plot > 0: 193 | board = Dashboard(args.port) 194 | 195 | for epoch in range(start_epoch, args.num_epochs+1): 196 | print("----- TRAINING - EPOCH", epoch, "-----") 197 | 198 | scheduler.step(epoch) ## scheduler 2 199 | 200 | epoch_loss = [] 201 | time_train = [] 202 | 203 | doIouTrain = args.iouTrain 204 | doIouVal = args.iouVal 205 | 206 | if (doIouTrain): 207 | iouEvalTrain = iouEval(NUM_CLASSES) 208 | 209 | usedLr = 0 210 | for param_group in optimizer.param_groups: 211 | print("LEARNING RATE: ", param_group['lr']) 212 | usedLr = float(param_group['lr']) 213 | 214 | model.train() 215 | for step, (images, labels) in enumerate(loader): 216 | 217 | start_time = time.time() 218 | #print (labels.size()) 219 | #print (np.unique(labels.numpy())) 220 | #print("labels: ", np.unique(labels[0].numpy())) 221 | #labels = torch.ones(4, 1, 512, 1024).long() 222 | if args.cuda: 223 | images = images.cuda() 224 | labels = labels.cuda() 225 | 226 | inputs = Variable(images) 227 | targets = Variable(labels) 228 | outputs = model(inputs, only_encode=enc) 229 | 230 | #print("targets", np.unique(targets[:, 0].cpu().data.numpy())) 231 | 232 | optimizer.zero_grad() 233 | loss = criterion(outputs, targets[:, 0]) 234 | loss.backward() 235 | optimizer.step() 236 | 237 | epoch_loss.append(loss.data[0]) 238 | time_train.append(time.time() - start_time) 239 | 240 | if (doIouTrain): 241 | #start_time_iou = time.time() 242 | iouEvalTrain.addBatch(outputs.max(1)[1].unsqueeze(1).data, targets.data) 243 | #print ("Time to add confusion matrix: ", time.time() - start_time_iou) 244 | 245 | #print(outputs.size()) 246 | if args.visualize and args.steps_plot > 0 and step % args.steps_plot == 0: 247 | start_time_plot = time.time() 248 | image = inputs[0].cpu().data 249 | #image[0] = image[0] * .229 + .485 250 | #image[1] = image[1] * .224 + .456 251 | #image[2] = image[2] * .225 + .406 252 | #print("output", np.unique(outputs[0].cpu().max(0)[1].data.numpy())) 253 | board.image(image, f'input (epoch: {epoch}, step: {step})') 254 | if isinstance(outputs, list): #merge gpu tensors 255 | board.image(color_transform(outputs[0][0].cpu().max(0)[1].data.unsqueeze(0)), 256 | f'output (epoch: {epoch}, step: {step})') 257 | else: 258 | board.image(color_transform(outputs[0].cpu().max(0)[1].data.unsqueeze(0)), 259 | f'output (epoch: {epoch}, step: {step})') 260 | board.image(color_transform(targets[0].cpu().data), 261 | f'target (epoch: {epoch}, step: {step})') 262 | print ("Time to paint images: ", time.time() - start_time_plot) 263 | if args.steps_loss > 0 and step % args.steps_loss == 0: 264 | average = sum(epoch_loss) / 
len(epoch_loss) 265 | print(f'loss: {average:0.4} (epoch: {epoch}, step: {step})', 266 | "// Avg time/img: %.4f s" % (sum(time_train) / len(time_train) / args.batch_size)) 267 | 268 | 269 | average_epoch_loss_train = sum(epoch_loss) / len(epoch_loss) 270 | 271 | iouTrain = 0 272 | if (doIouTrain): 273 | iouTrain, iou_classes = iouEvalTrain.getIoU() 274 | iouStr = getColorEntry(iouTrain)+'{:0.2f}'.format(iouTrain*100) + '\033[0m' 275 | print ("EPOCH IoU on TRAIN set: ", iouStr, "%") 276 | 277 | #Validate on 500 val images after each epoch of training 278 | print("----- VALIDATING - EPOCH", epoch, "-----") 279 | model.eval() 280 | epoch_loss_val = [] 281 | time_val = [] 282 | 283 | if (doIouVal): 284 | iouEvalVal = iouEval(NUM_CLASSES) 285 | 286 | for step, (images, labels) in enumerate(loader_val): 287 | start_time = time.time() 288 | if args.cuda: 289 | images = images.cuda() 290 | labels = labels.cuda() 291 | 292 | inputs = Variable(images, volatile=True) #volatile flag makes it free backward or outputs for eval 293 | targets = Variable(labels, volatile=True) 294 | outputs = model(inputs, only_encode=enc) 295 | 296 | loss = criterion(outputs, targets[:, 0]) 297 | epoch_loss_val.append(loss.data[0]) 298 | time_val.append(time.time() - start_time) 299 | 300 | 301 | #Add batch to calculate TP, FP and FN for iou estimation 302 | if (doIouVal): 303 | #start_time_iou = time.time() 304 | iouEvalVal.addBatch(outputs.max(1)[1].unsqueeze(1).data, targets.data) 305 | #print ("Time to add confusion matrix: ", time.time() - start_time_iou) 306 | 307 | if args.visualize and args.steps_plot > 0 and step % args.steps_plot == 0: 308 | start_time_plot = time.time() 309 | image = inputs[0].cpu().data 310 | board.image(image, f'VAL input (epoch: {epoch}, step: {step})') 311 | if isinstance(outputs, list): #merge gpu tensors 312 | board.image(color_transform(outputs[0][0].cpu().max(0)[1].data.unsqueeze(0)), 313 | f'VAL output (epoch: {epoch}, step: {step})') 314 | else: 315 | board.image(color_transform(outputs[0].cpu().max(0)[1].data.unsqueeze(0)), 316 | f'VAL output (epoch: {epoch}, step: {step})') 317 | board.image(color_transform(targets[0].cpu().data), 318 | f'VAL target (epoch: {epoch}, step: {step})') 319 | print ("Time to paint images: ", time.time() - start_time_plot) 320 | if args.steps_loss > 0 and step % args.steps_loss == 0: 321 | average = sum(epoch_loss_val) / len(epoch_loss_val) 322 | print(f'VAL loss: {average:0.4} (epoch: {epoch}, step: {step})', 323 | "// Avg time/img: %.4f s" % (sum(time_val) / len(time_val) / args.batch_size)) 324 | 325 | 326 | average_epoch_loss_val = sum(epoch_loss_val) / len(epoch_loss_val) 327 | #scheduler.step(average_epoch_loss_val, epoch) ## scheduler 1 # update lr if needed 328 | 329 | iouVal = 0 330 | if (doIouVal): 331 | iouVal, iou_classes = iouEvalVal.getIoU() 332 | iouStr = getColorEntry(iouVal)+'{:0.2f}'.format(iouVal*100) + '\033[0m' 333 | print ("EPOCH IoU on VAL set: ", iouStr, "%") 334 | 335 | 336 | # remember best valIoU and save checkpoint 337 | if iouVal == 0: 338 | current_acc = -average_epoch_loss_val 339 | else: 340 | current_acc = iouVal 341 | is_best = current_acc > best_acc 342 | best_acc = max(current_acc, best_acc) 343 | if enc: 344 | filenameCheckpoint = savedir + '/checkpoint_enc.pth.tar' 345 | filenameBest = savedir + '/model_best_enc.pth.tar' 346 | else: 347 | filenameCheckpoint = savedir + '/checkpoint.pth.tar' 348 | filenameBest = savedir + '/model_best.pth.tar' 349 | save_checkpoint({ 350 | 'epoch': epoch + 1, 351 | 'arch': str(model), 
352 | 'state_dict': model.state_dict(), 353 | 'best_acc': best_acc, 354 | 'optimizer' : optimizer.state_dict(), 355 | }, is_best, filenameCheckpoint, filenameBest) 356 | 357 | #SAVE MODEL AFTER EPOCH 358 | if (enc): 359 | filename = f'{savedir}/model_encoder-{epoch:03}.pth' 360 | filenamebest = f'{savedir}/model_encoder_best.pth' 361 | else: 362 | filename = f'{savedir}/model-{epoch:03}.pth' 363 | filenamebest = f'{savedir}/model_best.pth' 364 | if args.epochs_save > 0 and step > 0 and step % args.epochs_save == 0: 365 | torch.save(model.state_dict(), filename) 366 | print(f'save: {filename} (epoch: {epoch})') 367 | if (is_best): 368 | torch.save(model.state_dict(), filenamebest) 369 | print(f'save: {filenamebest} (epoch: {epoch})') 370 | if (not enc): 371 | with open(savedir + "/best.txt", "w") as myfile: 372 | myfile.write("Best epoch is %d, with Val-IoU= %.4f" % (epoch, iouVal)) 373 | else: 374 | with open(savedir + "/best_encoder.txt", "w") as myfile: 375 | myfile.write("Best epoch is %d, with Val-IoU= %.4f" % (epoch, iouVal)) 376 | 377 | #SAVE TO FILE A ROW WITH THE EPOCH RESULT (train loss, val loss, train IoU, val IoU) 378 | #Epoch Train-loss Test-loss Train-IoU Test-IoU learningRate 379 | with open(automated_log_path, "a") as myfile: 380 | myfile.write("\n%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.8f" % (epoch, average_epoch_loss_train, average_epoch_loss_val, iouTrain, iouVal, usedLr )) 381 | 382 | return(model) #return model (convenience for encoder-decoder training) 383 | 384 | def save_checkpoint(state, is_best, filenameCheckpoint, filenameBest): 385 | torch.save(state, filenameCheckpoint) 386 | if is_best: 387 | print ("Saving model as best") 388 | torch.save(state, filenameBest) 389 | 390 | 391 | def main(args): 392 | savedir = f'../save/{args.savedir}' 393 | 394 | if not os.path.exists(savedir): 395 | os.makedirs(savedir) 396 | 397 | with open(savedir + '/opts.txt', "w") as myfile: 398 | myfile.write(str(args)) 399 | 400 | #Load Model 401 | assert os.path.exists(args.model + ".py"), "Error: model definition not found" 402 | model_file = importlib.import_module(args.model) 403 | model = model_file.Net(NUM_CLASSES) 404 | copyfile(args.model + ".py", savedir + '/' + args.model + ".py") 405 | 406 | if args.cuda: 407 | model = torch.nn.DataParallel(model).cuda() 408 | 409 | if args.state: 410 | #if args.state is provided then load this state for training 411 | #Note: this only loads initialized weights. If you want to resume a training use "--resume" option!! 412 | """ 413 | try: 414 | model.load_state_dict(torch.load(args.state)) 415 | except AssertionError: 416 | model.load_state_dict(torch.load(args.state, 417 | map_location=lambda storage, loc: storage)) 418 | #When model is saved as DataParallel it adds a model. to each key. 
To remove: 419 | #state_dict = {k.partition('module.')[2]: v for k, v in state_dict.items()} 420 | #https://discuss.pytorch.org/t/prefix-parameter-names-in-saved-model-if-trained-by-multi-gpu/494 421 | """ 422 | def load_my_state_dict(model, state_dict): #custom function to load model when not all dict keys are there 423 | own_state = model.state_dict() 424 | for name, param in state_dict.items(): 425 | if name not in own_state: 426 | continue 427 | own_state[name].copy_(param) 428 | return model 429 | 430 | #print(torch.load(args.state)) 431 | model = load_my_state_dict(model, torch.load(args.state)) 432 | 433 | """ 434 | def weights_init(m): 435 | classname = m.__class__.__name__ 436 | if classname.find('Conv') != -1: 437 | #m.weight.data.normal_(0.0, 0.02) 438 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 439 | m.weight.data.normal_(0, math.sqrt(2. / n)) 440 | elif classname.find('BatchNorm') != -1: 441 | #m.weight.data.normal_(1.0, 0.02) 442 | m.weight.data.fill_(1) 443 | m.bias.data.fill_(0) 444 | 445 | #TO ACCESS MODEL IN DataParallel: next(model.children()) 446 | #next(model.children()).decoder.apply(weights_init) 447 | #Reinitialize weights for decoder 448 | 449 | next(model.children()).decoder.layers.apply(weights_init) 450 | next(model.children()).decoder.output_conv.apply(weights_init) 451 | 452 | #print(model.state_dict()) 453 | f = open('weights5.txt', 'w') 454 | f.write(str(model.state_dict())) 455 | f.close() 456 | """ 457 | 458 | #train(args, model) 459 | if (not args.decoder): 460 | print("========== ENCODER TRAINING ===========") 461 | model = train(args, model, True) #Train encoder 462 | #CAREFUL: for some reason, after training the encoder alone, the decoder gets weights=0. 463 | #We must reinit the decoder weights or reload the network passing only the encoder in order to train the decoder 464 | print("========== DECODER TRAINING ===========") 465 | if (not args.state): 466 | if args.pretrainedEncoder: 467 | print("Loading encoder pretrained on ImageNet") 468 | from erfnet_imagenet import ERFNet as ERFNet_imagenet 469 | pretrainedEnc = torch.nn.DataParallel(ERFNet_imagenet(1000)) 470 | pretrainedEnc.load_state_dict(torch.load(args.pretrainedEncoder)['state_dict']) 471 | pretrainedEnc = next(pretrainedEnc.children()).features.encoder 472 | if (not args.cuda): 473 | pretrainedEnc = pretrainedEnc.cpu() #because the loaded encoder was probably saved as a CUDA model 474 | else: 475 | pretrainedEnc = next(model.children()).encoder 476 | model = model_file.Net(NUM_CLASSES, encoder=pretrainedEnc) #Add decoder to encoder 477 | if args.cuda: 478 | model = torch.nn.DataParallel(model).cuda() 479 | #Rebuilding the Net around the encoder reinitializes the decoder weights, which end up at 0 after encoder-only training 480 | model = train(args, model, False) #Train decoder 481 | print("========== TRAINING FINISHED ===========") 482 | 483 | if __name__ == '__main__': 484 | parser = ArgumentParser() 485 | parser.add_argument('--cuda', action='store_true', default=True) #NOTE: cpu-only has not been tested so you might have to change code if you deactivate this flag 486 | parser.add_argument('--model', default="erfnet") 487 | parser.add_argument('--state') 488 | 489 | parser.add_argument('--port', type=int, default=8097) 490 | parser.add_argument('--datadir', default=os.getenv("HOME") + "/datasets/cityscapes/") 491 | parser.add_argument('--height', type=int, default=512) 492 | parser.add_argument('--num-epochs', type=int, default=150) 493 | parser.add_argument('--num-workers', type=int, default=4) 494 |
parser.add_argument('--batch-size', type=int, default=6) 495 | parser.add_argument('--steps-loss', type=int, default=50) 496 | parser.add_argument('--steps-plot', type=int, default=50) 497 | parser.add_argument('--epochs-save', type=int, default=0) #You can use this value to save model every X epochs 498 | parser.add_argument('--savedir', required=True) 499 | parser.add_argument('--decoder', action='store_true') 500 | parser.add_argument('--pretrainedEncoder') #, default="../trained_models/erfnet_encoder_pretrained.pth.tar") 501 | parser.add_argument('--visualize', action='store_true') 502 | 503 | parser.add_argument('--iouTrain', action='store_true', default=False) #recommended: False (takes more time to train otherwise) 504 | parser.add_argument('--iouVal', action='store_true', default=True) 505 | parser.add_argument('--resume', action='store_true') #Use this flag to load last checkpoint for training 506 | 507 | main(parser.parse_args()) 508 | -------------------------------------------------------------------------------- /train/transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from PIL import Image 5 | 6 | def colormap_cityscapes(n): 7 | cmap=np.zeros([n, 3]).astype(np.uint8) 8 | cmap[0,:] = np.array([128, 64,128]) 9 | cmap[1,:] = np.array([244, 35,232]) 10 | cmap[2,:] = np.array([ 70, 70, 70]) 11 | cmap[3,:] = np.array([ 102,102,156]) 12 | cmap[4,:] = np.array([ 190,153,153]) 13 | cmap[5,:] = np.array([ 153,153,153]) 14 | 15 | cmap[6,:] = np.array([ 250,170, 30]) 16 | cmap[7,:] = np.array([ 220,220, 0]) 17 | cmap[8,:] = np.array([ 107,142, 35]) 18 | cmap[9,:] = np.array([ 152,251,152]) 19 | cmap[10,:] = np.array([ 70,130,180]) 20 | 21 | cmap[11,:] = np.array([ 220, 20, 60]) 22 | cmap[12,:] = np.array([ 255, 0, 0]) 23 | cmap[13,:] = np.array([ 0, 0,142]) 24 | cmap[14,:] = np.array([ 0, 0, 70]) 25 | cmap[15,:] = np.array([ 0, 60,100]) 26 | 27 | cmap[16,:] = np.array([ 0, 80,100]) 28 | cmap[17,:] = np.array([ 0, 0,230]) 29 | cmap[18,:] = np.array([ 119, 11, 32]) 30 | cmap[19,:] = np.array([ 0, 0, 0]) 31 | 32 | return cmap 33 | 34 | 35 | def colormap(n): 36 | cmap=np.zeros([n, 3]).astype(np.uint8) 37 | 38 | for i in np.arange(n): 39 | r, g, b = np.zeros(3) 40 | 41 | for j in np.arange(8): 42 | r = r + (1<<(7-j))*((i&(1<<(3*j))) >> (3*j)) 43 | g = g + (1<<(7-j))*((i&(1<<(3*j+1))) >> (3*j+1)) 44 | b = b + (1<<(7-j))*((i&(1<<(3*j+2))) >> (3*j+2)) 45 | 46 | cmap[i,:] = np.array([r, g, b]) 47 | 48 | return cmap 49 | 50 | class Relabel: 51 | 52 | def __init__(self, olabel, nlabel): 53 | self.olabel = olabel 54 | self.nlabel = nlabel 55 | 56 | def __call__(self, tensor): 57 | assert (isinstance(tensor, torch.LongTensor) or isinstance(tensor, torch.ByteTensor)) , 'tensor needs to be LongTensor' 58 | tensor[tensor == self.olabel] = self.nlabel 59 | return tensor 60 | 61 | 62 | class ToLabel: 63 | 64 | def __call__(self, image): 65 | return torch.from_numpy(np.array(image)).long().unsqueeze(0) 66 | 67 | 68 | class Colorize: 69 | 70 | def __init__(self, n=22): 71 | #self.cmap = colormap(256) 72 | self.cmap = colormap_cityscapes(256) 73 | self.cmap[n] = self.cmap[-1] 74 | self.cmap = torch.from_numpy(self.cmap[:n]) 75 | 76 | def __call__(self, gray_image): 77 | size = gray_image.size() 78 | #print(size) 79 | color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0) 80 | #color_image = torch.ByteTensor(3, size[0], size[1]).fill_(0) 81 | 82 | #for label in range(1, len(self.cmap)): 83 | for label in range(0, 
len(self.cmap)): 84 | mask = gray_image[0] == label 85 | #mask = gray_image == label 86 | 87 | color_image[0][mask] = self.cmap[label][0] 88 | color_image[1][mask] = self.cmap[label][1] 89 | color_image[2][mask] = self.cmap[label][2] 90 | 91 | return color_image 92 | -------------------------------------------------------------------------------- /train/visualize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from torch.autograd import Variable 4 | 5 | from visdom import Visdom 6 | 7 | class Dashboard: 8 | 9 | def __init__(self, port): 10 | self.vis = Visdom(port=port) 11 | 12 | def loss(self, losses, title): 13 | x = np.arange(1, len(losses)+1, 1) 14 | 15 | self.vis.line(losses, x, env='loss', opts=dict(title=title)) 16 | 17 | def image(self, image, title): 18 | if image.is_cuda: 19 | image = image.cpu() 20 | if isinstance(image, Variable): 21 | image = image.data 22 | image = image.numpy() 23 | 24 | self.vis.image(image, env='images', opts=dict(title=title)) -------------------------------------------------------------------------------- /trained_models/erfnet_encoder_pretrained.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Eromera/erfnet_pytorch/d4a46faf9e465286c89ebd9c44bc929b2d213fb3/trained_models/erfnet_encoder_pretrained.pth.tar -------------------------------------------------------------------------------- /trained_models/erfnet_pretrained.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Eromera/erfnet_pytorch/d4a46faf9e465286c89ebd9c44bc929b2d213fb3/trained_models/erfnet_pretrained.pth --------------------------------------------------------------------------------