├── 1.png
├── README.md
├── datasets
│   └── cityscapes
│       └── instructions.txt
└── scripts
    ├── __pycache__
    │   ├── datasets.cpython-37.pyc
    │   ├── model.cpython-37.pyc
    │   └── utils.cpython-37.pyc
    ├── datasets.py
    ├── evaluate.py
    ├── model.py
    ├── train.py
    └── utils.py

/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hamdaan19/UNet-Multiclass/00cb35fb81e051e302dc720123832cbe6b40335f/1.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## About
This repository contains code used to train U-Net on a multi-class segmentation dataset. The code is written in Python. Inside the scripts folder, you can find all the Python files used to train, evaluate, and prepare the data. The dataset used is the [Cityscapes dataset](https://www.cityscapes-dataset.com/). If you want to know how to perform image segmentation tasks in PyTorch, visit my [text-based tutorial](https://medium.com/@mhamdaan/multi-class-semantic-segmentation-with-u-net-pytorch-ee81a66bba89) on Medium.

![image](1.png)

### Dependencies
| Sl. no. | Dependency | Install |
| --- | --- | --- |
| 1. | [PyTorch](https://pytorch.org/) | Head over to PyTorch's official website, choose your OS, CUDA version and other specifications, and install using the generated terminal command. |
| 2. | [torchvision](https://pytorch.org/vision/stable/index.html) | ```pip install torchvision``` |
| 3. | [Pillow](https://pillow.readthedocs.io/en/stable/) | ```python3 -m pip install --upgrade Pillow``` |
| 4. | [NumPy](https://numpy.org/) | ```pip install numpy``` |
| 5. | [CityscapesScripts](https://github.com/mcordts/cityscapesScripts) | ```python -m pip install cityscapesscripts``` |

A few other libraries, such as Matplotlib (for visualization) and tqdm (for the progress bar), are also used, but because they are only supplementary, they are not included in the table above.
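### Usage
Once the dataset is arranged as described in `datasets/cityscapes/instructions.txt`, training and evaluation are run through `scripts/train.py` and `scripts/evaluate.py` (set the `MODEL_PATH`, `EVAL` and `PLOT_LOSS` variables at the top of each script). As a rough sketch of how a trained checkpoint could then be used for inference (the checkpoint path and image path below are placeholders):

```python
import torch
from PIL import Image
from torchvision import transforms
from model import UNET  # run from inside the scripts/ folder

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
net = UNET(in_channels=3, classes=19).to(device)
checkpoint = torch.load('your_checkpoint.pt', map_location=device)  # placeholder path
net.load_state_dict(checkpoint['model_state_dict'])
net.eval()

img = Image.open('example_leftImg8bit.png')  # placeholder image
x = transforms.ToTensor()(transforms.Resize((110, 220))(img)).unsqueeze(0).to(device)
with torch.no_grad():
    pred = net(x).argmax(dim=1)  # (1, H, W) tensor of predicted trainIds
```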
--------------------------------------------------------------------------------
/datasets/cityscapes/instructions.txt:
--------------------------------------------------------------------------------
1. Note that 'cityscapes' is your root directory.
2. Keep the 'gtFine', 'gtCoarse' and 'leftImg8bit' folders here.
3. Cityscapes dataset files can be downloaded from https://www.cityscapes-dataset.com/downloads/
--------------------------------------------------------------------------------
/scripts/__pycache__/datasets.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hamdaan19/UNet-Multiclass/00cb35fb81e051e302dc720123832cbe6b40335f/scripts/__pycache__/datasets.cpython-37.pyc
--------------------------------------------------------------------------------
/scripts/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hamdaan19/UNet-Multiclass/00cb35fb81e051e302dc720123832cbe6b40335f/scripts/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/scripts/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hamdaan19/UNet-Multiclass/00cb35fb81e051e302dc720123832cbe6b40335f/scripts/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/scripts/datasets.py:
--------------------------------------------------------------------------------
import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class CityscapesDataset(Dataset):
    def __init__(self, split, root_dir, target_type='semantic', mode='fine',
                 relabelled=True, transform=None, eval=False):
        self.transform = transform
        if mode == 'fine':
            self.mode = 'gtFine'
        elif mode == 'coarse':
            self.mode = 'gtCoarse'
        self.split = split
        self.yLabel_list = []
        self.XImg_list = []
        self.eval = eval

        # Prepare lists of all RGB images and their ground-truth label
        # images. With relabelled=True (recommended), the 19-class
        # '*labelTrainIds.png' annotations are used; otherwise the
        # '*labelIds.png' annotations with the full label set are used.
        label_suffix = 'labelTrainIds.png' if relabelled else 'labelIds.png'

        self.label_path = os.path.join(os.getcwd(), root_dir, self.mode, self.split)
        self.rgb_path = os.path.join(os.getcwd(), root_dir, 'leftImg8bit', self.split)
        city_list = sorted(os.listdir(self.label_path))
        for city in city_list:
            # Sort both listings so that labels and images stay aligned.
            label_items = sorted(
                item for item in os.listdir(os.path.join(self.label_path, city))
                if item.endswith(label_suffix)
            )
            self.yLabel_list.extend('/' + city + '/' + item for item in label_items)
            self.XImg_list.extend(
                '/' + city + '/' + item
                for item in sorted(os.listdir(os.path.join(self.rgb_path, city)))
            )

    def __len__(self):
        return len(self.XImg_list)

    def __getitem__(self, index):
        image = Image.open(self.rgb_path + self.XImg_list[index])
        y = Image.open(self.label_path + self.yLabel_list[index])

        if self.transform is not None:
            image = self.transform(image)
            y = self.transform(y)

        image = transforms.ToTensor()(image)
        y = torch.from_numpy(np.array(y)).type(torch.LongTensor)

        if self.eval:
            return image, y, self.XImg_list[index]
        return image, y
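# Minimal usage sketch (an added example, assuming the dataset layout
# described in datasets/cityscapes/instructions.txt and execution from
# the repository root).
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = CityscapesDataset(split='train', mode='fine', relabelled=True,
                                root_dir='datasets/cityscapes')
    loader = DataLoader(dataset, batch_size=2, shuffle=True)
    image, label = next(iter(loader))
    print(image.shape, label.shape)  # torch.Size([2, 3, 1024, 2048]) torch.Size([2, 1024, 2048])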
--------------------------------------------------------------------------------
/scripts/evaluate.py:
--------------------------------------------------------------------------------
import torch
from PIL import Image
from model import UNET
from utils import *
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision import transforms
from cityscapesscripts.helpers.labels import trainId2label as t2l

if torch.cuda.is_available():
    device = 'cuda:0'
    print('Running on the GPU')
else:
    device = 'cpu'
    print('Running on the CPU')

ROOT_DIR_CITYSCAPES = 'datasets/cityscapes'
IMAGE_HEIGHT = 300
IMAGE_WIDTH = 600

MODEL_PATH = "YOUR-MODEL-PATH-WHICH-NEEDS-TO-BE-EVALUATED"

EVAL = False
PLOT_LOSS = False

def save_predictions(data, model):
    model.eval()
    with torch.no_grad():
        for idx, batch in enumerate(tqdm(data)):

            X, y, s = batch  # 's' is the image's path relative to the split directory
            X, y = X.to(device), y.to(device)
            predictions = model(X)

            predictions = torch.nn.functional.softmax(predictions, dim=1)
            pred_labels = torch.argmax(predictions, dim=1)
            pred_labels = pred_labels.float()

            # Remapping the predicted trainIds back to the original Cityscapes label ids
            pred_labels = pred_labels.to('cpu')
            pred_labels.apply_(lambda x: t2l[x].id)
            pred_labels = pred_labels.to(device)

            # Resizing predicted images to the original size. Nearest-neighbour
            # interpolation keeps the label ids valid.
            pred_labels = transforms.Resize(
                (1024, 2048), interpolation=transforms.InterpolationMode.NEAREST)(pred_labels)

            # Configure filename & location to save predictions as images
            s = s[0]  # the loader is used with batch_size=1, so take the single path
            name = s[s.rfind('/') + 1:].replace('_leftImg8bit.png', '')
            location = 'saved_images/multiclass_1'

            save_as_images(pred_labels, location, name)

def evaluate(path):
    T = transforms.Compose([
        transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH), interpolation=Image.NEAREST)
    ])

    val_set = get_cityscapes_data(
        root_dir=ROOT_DIR_CITYSCAPES,
        split='val',
        mode='fine',
        relabelled=True,
        transforms=T,
        shuffle=False,
        eval=True
    )

    print('Data has been loaded!')

    net = UNET(in_channels=3, classes=19).to(device)
    checkpoint = torch.load(path, map_location=device)
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()
    print(f'{path} has been loaded and initialized')
    save_predictions(val_set, net)

def plot_losses(path):
    checkpoint = torch.load(path, map_location='cpu')
    losses = checkpoint['loss_values']
    epoch = checkpoint['epoch']
    epoch_list = list(range(epoch + 1))  # one loss value is stored per completed epoch

    plt.plot(epoch_list, losses)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title(f"Loss over {epoch + 1} epoch(s)")
    plt.show()
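# Small, self-contained illustration (an added sketch, not part of the
# evaluation pipeline) of the trainId -> id remapping performed in
# save_predictions() via cityscapesscripts' trainId2label mapping.
def _demo_trainid_remap():
    # trainIds 0, 1 and 13 ('road', 'sidewalk', 'car') map back to the
    # original Cityscapes label ids 7, 8 and 26.
    fake_pred = torch.tensor([[0., 1., 13.]])
    fake_pred.apply_(lambda x: t2l[x].id)
    print(fake_pred)  # tensor([[ 7.,  8., 26.]])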
if __name__ == '__main__':
    if EVAL:
        evaluate(MODEL_PATH)
    if PLOT_LOSS:
        plot_losses(MODEL_PATH)
--------------------------------------------------------------------------------
/scripts/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

class UNET(nn.Module):

    def __init__(self, in_channels=3, classes=1):
        super(UNET, self).__init__()
        self.layers = [in_channels, 64, 128, 256, 512, 1024]

        self.double_conv_downs = nn.ModuleList(
            [self.__double_conv(layer, layer_n) for layer, layer_n in zip(self.layers[:-1], self.layers[1:])])

        self.up_trans = nn.ModuleList(
            [nn.ConvTranspose2d(layer, layer_n, kernel_size=2, stride=2)
             for layer, layer_n in zip(self.layers[::-1][:-2], self.layers[::-1][1:-1])])

        self.double_conv_ups = nn.ModuleList(
            [self.__double_conv(layer, layer // 2) for layer in self.layers[::-1][:-2]])

        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.final_conv = nn.Conv2d(64, classes, kernel_size=1)

    def __double_conv(self, in_channels, out_channels):
        conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        return conv

    def forward(self, x):
        # contracting path: store each pre-pooling feature map for the skip connections
        concat_layers = []

        for down in self.double_conv_downs:
            x = down(x)
            if down is not self.double_conv_downs[-1]:
                concat_layers.append(x)
                x = self.max_pool_2x2(x)

        concat_layers = concat_layers[::-1]

        # expansive path: upsample, concatenate with the matching skip connection, convolve
        for up_trans, double_conv_up, concat_layer in zip(self.up_trans, self.double_conv_ups, concat_layers):
            x = up_trans(x)
            if x.shape != concat_layer.shape:
                x = TF.resize(x, concat_layer.shape[2:])

            concatenated = torch.cat((concat_layer, x), dim=1)
            x = double_conv_up(concatenated)

        x = self.final_conv(x)

        return x
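# Quick shape check (an added sketch, not used by the training scripts):
# the network maps an RGB batch to per-pixel class scores of the same
# spatial size.
if __name__ == '__main__':
    x = torch.randn(1, 3, 110, 220)
    model = UNET(in_channels=3, classes=19)
    out = model(x)
    print(out.shape)  # expected: torch.Size([1, 19, 110, 220])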
--------------------------------------------------------------------------------
/scripts/train.py:
--------------------------------------------------------------------------------
import torch
import torch.optim as optim
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
from utils import *
from model import UNET

if torch.cuda.is_available():
    DEVICE = 'cuda:0'
    print('Running on the GPU')
else:
    DEVICE = "cpu"
    print('Running on the CPU')

MODEL_PATH = 'YOUR-MODEL-PATH'
LOAD_MODEL = False
ROOT_DIR = '../datasets/cityscapes'
IMG_HEIGHT = 110
IMG_WIDTH = 220
BATCH_SIZE = 16
LEARNING_RATE = 0.0005
EPOCHS = 5

def train_function(data, model, optimizer, loss_fn, device):
    print('Entering into train function')
    loss_values = []
    data = tqdm(data)
    for index, batch in enumerate(data):
        X, y = batch
        X, y = X.to(device), y.to(device)
        preds = model(X)

        loss = loss_fn(preds, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_values.append(loss.item())

    # Return the mean loss over the epoch instead of only the last batch's loss
    return sum(loss_values) / len(loss_values)


def main():
    epoch = 0  # epoch is initially assigned to 0. If LOAD_MODEL is true,
               # epoch is set to the last saved value + 1.
    LOSS_VALS = []  # A list to store the loss value after every epoch

    transform = transforms.Compose([
        transforms.Resize((IMG_HEIGHT, IMG_WIDTH), interpolation=Image.NEAREST),
    ])

    train_set = get_cityscapes_data(
        split='train',
        mode='fine',
        relabelled=True,
        root_dir=ROOT_DIR,
        transforms=transform,
        batch_size=BATCH_SIZE,
    )

    print('Data Loaded Successfully!')

    # Defining the model, optimizer and loss function
    unet = UNET(in_channels=3, classes=19).to(DEVICE).train()
    optimizer = optim.Adam(unet.parameters(), lr=LEARNING_RATE)
    loss_function = nn.CrossEntropyLoss(ignore_index=255)  # 255 marks unlabelled pixels

    # Loading a previously stored model from the MODEL_PATH variable
    if LOAD_MODEL:
        checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
        unet.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optim_state_dict'])
        epoch = checkpoint['epoch'] + 1
        LOSS_VALS = checkpoint['loss_values']
        print("Model successfully loaded!")

    # Training the model for every epoch; a checkpoint is saved after each one.
    for e in range(epoch, EPOCHS):
        print(f'Epoch: {e}')
        loss_val = train_function(train_set, unet, optimizer, loss_function, DEVICE)
        LOSS_VALS.append(loss_val)
        torch.save({
            'model_state_dict': unet.state_dict(),
            'optim_state_dict': optimizer.state_dict(),
            'epoch': e,
            'loss_values': LOSS_VALS
        }, MODEL_PATH)
        print("Epoch completed and model successfully saved!")


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/scripts/utils.py:
--------------------------------------------------------------------------------
import os

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import CityscapesDataset

def get_cityscapes_data(
    mode,
    split,
    relabelled,
    root_dir='datasets/cityscapes',
    target_type="semantic",
    transforms=None,
    batch_size=1,
    eval=False,
    shuffle=True,
    pin_memory=True,
):
    data = CityscapesDataset(
        mode=mode, split=split, target_type=target_type, relabelled=relabelled,
        transform=transforms, root_dir=root_dir, eval=eval)

    data_loaded = DataLoader(
        data, batch_size=batch_size, shuffle=shuffle, pin_memory=pin_memory)

    return data_loaded

# Function to save predictions as images
def save_as_images(tensor_pred, folder, image_name):
    tensor_pred = transforms.ToPILImage()(tensor_pred.cpu().byte())
    os.makedirs(folder, exist_ok=True)
    filename = os.path.join(folder, f"{image_name}.png")
    tensor_pred.save(filename)
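# Minimal usage sketch (an added example): assumes the dataset layout
# described in datasets/cityscapes/instructions.txt and that the script
# is run from the repository root.
if __name__ == '__main__':
    loader = get_cityscapes_data(mode='fine', split='val', relabelled=True,
                                 root_dir='datasets/cityscapes', batch_size=1)
    images, labels = next(iter(loader))
    print(images.shape, labels.shape)  # torch.Size([1, 3, 1024, 2048]) torch.Size([1, 1024, 2048])
--------------------------------------------------------------------------------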