├── .gitignore
├── README.md
├── confusion_matrix.py
├── data
│   ├── augmentCityscapes.py
│   ├── cityscapeLabels.txt
│   ├── dataLoaderUtils.py
│   ├── loadCityscapes.py
│   └── segmented_data.py
├── main.py
├── media
│   └── opt.txt
├── models
│   ├── linknet.py
│   ├── model.py
│   └── nobypass.py
├── opts.py
├── test.py
├── train.py
├── transforms.py
└── visualize.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.pyc
*.swp

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# LinkNet

This repository contains the PyTorch implementation of LinkNet, the network we developed at e-Lab.
You can go to our [blogpost](https://codeac29.github.io/projects/linknet/) or read the article [LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation](https://arxiv.org/abs/1707.03718) for further details.

**The training script has known issues and is still a work in progress.**

## Dependencies:

+ Python 3.4 or greater
+ [PyTorch](https://pytorch.org)
+ [OpenCV](https://opencv.org/)

Currently the network can be trained on two datasets:

| Datasets | Input Resolution | # of classes |
|:--------:|:----------------:|:------------:|
| [CamVid](http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/) (cv) | 768x576 | 11 |
| [Cityscapes](https://www.cityscapes-dataset.com/) (cs) | 1024x512 | 19 |

To download the datasets, follow the links provided above.
Both datasets are first resized by the training script; if you want, you can cache the resized data using the `--cachepath` option.
For CamVid, the available video data is first split into train/validation/test sets.
This is done by [prepCamVid.lua](data/prepCamVid.lua), and [dataDistributionCV.txt](misc/dataDistributionCV.txt) details how the CamVid data is split.
These steps run automatically before training starts.

LinkNet performance on both of the above datasets:

| Datasets | Best IoU | Best iIoU |
|:--------:|:----------------:|:------------:|
| Cityscapes | 76.44 | 60.78 |
| CamVid | 69.10 | 55.83 |

## Files/folders and their usage:

* [main.py](main.py) : main file
* [opts.py](opts.py) : contains all the input options used by the training script
* [data](data) : data loaders for loading datasets
* [models](models) : all the model architectures are defined here
* [train.py](train.py) : loading of models and error calculation
* [test.py](test.py) : calculates testing error and saves confusion matrices
* [confusion_matrix.py](confusion_matrix.py) : implements a confusion matrix

There are three model files present in the `models` folder:

* [model.py](models/model.py) : our LinkNet architecture
* [model-res-dec.py](models/model-res-dec.py) : LinkNet with a residual connection in each of the decoder blocks.
This slightly improves the result, but the `bilinear interpolation` required in the residual connection prevented us from running the trained model on a TX1.
* [nobypass.py](models/nobypass.py) : this architecture does not use any link between encoder and decoder.
You can use this model to verify whether connecting the encoder and decoder modules actually improves performance.
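
As a quick sanity check, the `LinkNet` class from [linknet.py](models/linknet.py) can be instantiated and run on a dummy batch (a minimal sketch; the constructor pulls ImageNet-pretrained ResNet-18 weights through torchvision, so the first run needs network access):

```python
import torch
from models.linknet import LinkNet

model = LinkNet(n_classes=20)  # 19 Cityscapes classes + unlabeled
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 512, 1024))
print(out.shape)  # torch.Size([1, 20, 512, 1024]), per-pixel log-probabilities
```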

A sample command to train the network is given below:

```
python main.py --datapath /media/HDD1/Datasets/ --dataset cs --model linknet --save /Trained_models/cityscapes/ --saveAll --visdom
```

### License

This software is released under a Creative Commons license which allows for personal and research use only.
For a commercial license please contact the authors.
You can view a license summary here: http://creativecommons.org/licenses/by-nc/4.0/

--------------------------------------------------------------------------------
/confusion_matrix.py:
--------------------------------------------------------------------------------
import numpy as np
from sklearn.metrics import confusion_matrix


class ConfusionMatrix:
    def __init__(self, nclasses, classes, useUnlabeled=False):
        self.mat = np.zeros((nclasses, nclasses), dtype=float)
        self.valids = np.zeros((nclasses), dtype=float)
        self.IoU = np.zeros((nclasses), dtype=float)
        self.mIoU = 0
        self.accuracy = 0

        self.nclasses = nclasses
        self.classes = classes
        self.list_classes = list(range(nclasses))
        self.useUnlabeled = useUnlabeled
        self.matStartIdx = 1 if not self.useUnlabeled else 0

    def update_matrix(self, target, prediction):
        if not(isinstance(prediction, np.ndarray)) or not(isinstance(target, np.ndarray)):
            raise TypeError("Expecting ndarray")
        elif len(target.shape) == 3:          # batched spatial target (N, H, W)
            if len(prediction.shape) == 4:    # prediction is one-hot encoded (N, C, H, W)
                temp_prediction = np.argmax(prediction, axis=1).flatten()
            elif len(prediction.shape) == 3:
                temp_prediction = prediction.flatten()
            else:
                raise ValueError("Make sure prediction and target dimensions are correct")

            temp_target = target.flatten()
        elif len(target.shape) == 2:          # spatial target (H, W)
            if len(prediction.shape) == 3:    # prediction is one-hot encoded (C, H, W)
                temp_prediction = np.argmax(prediction, axis=0).flatten()
            elif len(prediction.shape) == 2:
                temp_prediction = prediction.flatten()
            else:
                raise ValueError("Make sure prediction and target dimensions are correct")

            temp_target = target.flatten()
        elif len(target.shape) == 1:
            if len(prediction.shape) == 2:    # prediction is one-hot encoded (N, C)
                temp_prediction = np.argmax(prediction, axis=1).flatten()
            elif len(prediction.shape) == 1:
                temp_prediction = prediction
            else:
                raise ValueError("Make sure prediction and target dimensions are correct")

            temp_target = target
        else:
            raise ValueError("Data with this dimension cannot be handled")

        self.mat += confusion_matrix(temp_target, temp_prediction, labels=self.list_classes)

    def scores(self):
        tp = 0
        fp = 0
        fn = 0
        total = 0   # Total true positives
        for i in range(self.matStartIdx, self.nclasses):
            tp = self.mat[i][i]
            fp = sum(self.mat[self.matStartIdx:, i]) - tp
            fn = sum(self.mat[i, self.matStartIdx:]) - tp

            if (tp + fp) == 0:
                self.valids[i] = 0
            else:
                self.valids[i] = tp/(tp + fp)

            if (tp + fp + fn) == 0:
                self.IoU[i] = 0
            else:
                self.IoU[i] = tp/(tp + fp + fn)

            total += tp

        self.mIoU = sum(self.IoU[self.matStartIdx:])/(self.nclasses - self.matStartIdx)
        self.accuracy = total/(sum(sum(self.mat[self.matStartIdx:, self.matStartIdx:])))

        return self.valids, self.accuracy, self.IoU, self.mIoU, self.mat

    def plot_confusion_matrix(self, filename):
        # TODO: plotting is not implemented yet; this only reports the target file
        print(filename)

    def reset(self):
        self.mat = np.zeros((self.nclasses, self.nclasses), dtype=float)
        self.valids = np.zeros((self.nclasses), dtype=float)
        self.IoU = np.zeros((self.nclasses), dtype=float)
        self.mIoU = 0
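
A minimal usage sketch for the class above (shapes follow the batched branch of `update_matrix`; the values here are random stand-ins):

```python
import numpy as np
from confusion_matrix import ConfusionMatrix

cm = ConfusionMatrix(nclasses=3, classes=['Unlabeled', 'Road', 'Car'])
target = np.random.randint(0, 3, size=(2, 8, 8))   # (N, H, W) ground-truth ids
prediction = np.random.rand(2, 3, 8, 8)            # (N, C, H, W) class scores
cm.update_matrix(target, prediction)
valids, accuracy, IoU, mIoU, mat = cm.scores()
cm.reset()
```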

--------------------------------------------------------------------------------
/data/augmentCityscapes.py:
--------------------------------------------------------------------------------
import os
import cv2
from tqdm import trange

print('\033[0;0f\033[0J')
root_dir = '/media/SSD1/cityscapes/'

data_dir = os.path.join(root_dir, 'leftImg8bit/train/')
label_dir = os.path.join(root_dir, 'gtFine/train/')

def transform(img):
    # Center crop covering the middle 3/4 of a 2048x1024 Cityscapes frame
    center_y = 512
    center_x = 1024
    y1, y2 = center_y - 3*center_y//4, center_y + 3*center_y//4
    x1, x2 = center_x - 3*center_x//4, center_x + 3*center_x//4
    r_img = img[y1:y2, x1:x2]

    f_img = cv2.flip(img, 1)     # Horizontal flip
    rf_img = cv2.flip(r_img, 1)  # Horizontal flip of cropped img

    return r_img, f_img, rf_img


pbar1 = trange(len(os.listdir(data_dir)), position=0, desc='Overall progress       ')
for folder in os.listdir(data_dir):
    d = os.path.join(data_dir, folder)
    if not os.path.isdir(d):
        continue

    pbar2 = trange(len(os.listdir(d)), position=1, desc='Within folder progress ')
    for filename in os.listdir(d):
        if filename.endswith('.png'):
            data_path = '{0}/{1}/{2}'.format(data_dir, folder, filename)
            label_file = filename.replace('leftImg8bit', 'gtFine_labelIds')
            label_path = '{0}/{1}/{2}'.format(label_dir, folder, label_file)

            source_img = cv2.imread(data_path)
            r_img, f_img, rf_img = transform(source_img)
            dest_path = data_path[:-4] + '_r.png'
            cv2.imwrite(dest_path, r_img)
            dest_path = data_path[:-4] + '_f.png'
            cv2.imwrite(dest_path, f_img)
            dest_path = data_path[:-4] + '_rf.png'
            cv2.imwrite(dest_path, rf_img)

            source_img = cv2.imread(label_path, 0)
            r_img, f_img, rf_img = transform(source_img)
            dest_path = label_path[:-4] + '_r.png'
            cv2.imwrite(dest_path, r_img)
            dest_path = label_path[:-4] + '_f.png'
            cv2.imwrite(dest_path, f_img)
            dest_path = label_path[:-4] + '_rf.png'
            cv2.imwrite(dest_path, rf_img)

        pbar2.update(1)

    pbar2.close()
    pbar1.update(1)

pbar1.close()
print('\nData augmentation complete')
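
The geometry of `transform` above can be checked on a dummy array (a sketch; the crop bounds are hard-coded for 2048x1024 frames):

```python
import numpy as np

dummy = np.zeros((1024, 2048, 3), dtype=np.uint8)  # (H, W, C) stand-in frame
r_img, f_img, rf_img = transform(dummy)
print(r_img.shape)   # (768, 1536, 3): central 3/4 crop
print(f_img.shape)   # (1024, 2048, 3): mirrored full frame
print(rf_img.shape)  # (768, 1536, 3): mirrored crop
```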

--------------------------------------------------------------------------------
/data/cityscapeLabels.txt:
--------------------------------------------------------------------------------
Unlabeled
Road
Sidewalk
Building
Wall
Fence
Pole
TrafficLight
TrafficSign
Vegetation
Terrain
Sky
Person
Rider
Car
Truck
Bus
Train
Motorcycle
Bicycle

--------------------------------------------------------------------------------
/data/dataLoaderUtils.py:
--------------------------------------------------------------------------------
###############################
# Helper functions for Data Loaders
###############################

import os

class dataLoaderUtils:
    """
    params  : filename - name of the file to read
    returns : list of lines after stripping '\n'
              (newline character) from the end.
    """
    @staticmethod
    def readLines(filename):
        filepath = os.path.join("data", filename)
        assert os.path.isfile(filepath), "File: " + filename + " does not exist"
        lines = open(filepath).read().splitlines()
        return lines


    """
    params : path - location of the directory

    Create a directory at the path specified if it does not
    already exist.
    """

    @staticmethod
    def mkdir(path):
        if not os.path.isdir(path):
            os.mkdir(path)

--------------------------------------------------------------------------------
/data/loadCityscapes.py:
--------------------------------------------------------------------------------
#######################
# Cityscapes Data Loader

# 28th August
########################

#####################
# Note : If a component is an absolute path, all previous components are thrown away and joining continues from the absolute path component.

#####################
import torch
import os
import sys
import gc

sys.path.insert(0, '..')

from data.dataLoaderUtils import dataLoaderUtils as utils
from PIL import Image
from torchvision import transforms
#from progress.bar import Bar  # for tracking progress


class DataModel:
    def __init__(self, size, args):
        self.data = torch.FloatTensor(size, args['channels'], args['imHeight'], args['imWidth'])
        self.labels = torch.LongTensor(size, args['imHeight'], args['imWidth'])
        self.prev_error = 1e10   # a really huge value
        self.size = size


class CityScapeDataLoader:
    def __init__(self, opts):
        self.train_size = 2975                          # number of Cityscapes train images
        self.val_size = 500                             # number of Cityscapes validation images
        self.labels_filename = "cityscapeLabels.txt"    # Cityscapes labels file
        self.args = opts                                # command line arguments
        self.classes = utils.readLines(self.labels_filename)
        self.histClasses = torch.FloatTensor(len(self.classes)).zero_()
        self.loaded_from_cache = False
        self.dataset_name = "cityscapes"
        self.val_data = None
        self.train_data = None
        self.cacheFilePath = None
        self.conClasses = None
        # defining conClasses and classMap
        self.define_conClasses()
        self.define_classMap()

        # defining paths
        self.define_data_loader_paths()
        self.data_loader()
        print("\n\ncache file path: ", self.cacheFilePath)

    def define_data_loader_paths(self):
        dir_name = str(self.args['imHeight']) + "_" + str(self.args['imWidth'])
        dir_path = os.path.join(self.args['cachepath'], self.dataset_name, dir_name)
        os.makedirs(dir_path, exist_ok=True)   # make sure the cache directory exists
        self.cacheFilePath = os.path.join(dir_path, "data.pyt")

    def define_conClasses(self):
        # copy the class list so that self.classes keeps "Unlabeled"
        self.conClasses = [c for c in self.classes if c != "Unlabeled"]

    def define_classMap(self):
        # Ignoring unnecessary classes
        self.classMap = {}
        self.classMap[-1] = 1   # license plate
        self.classMap[0] = 1    # Unlabeled
        self.classMap[1] = 1    # Ego vehicle
        self.classMap[2] = 1    # Rectification border
        self.classMap[3] = 1    # Out of ROI
        self.classMap[4] = 1    # Static
        self.classMap[5] = 1    # Dynamic
        self.classMap[6] = 1    # Ground
        self.classMap[7] = 2    # Road
        self.classMap[8] = 3    # Sidewalk
        self.classMap[9] = 1    # Parking
        self.classMap[10] = 1   # Rail track
        self.classMap[11] = 4   # Building
        self.classMap[12] = 5   # Wall
        self.classMap[13] = 6   # Fence
        self.classMap[14] = 1   # Guard rail
        self.classMap[15] = 1   # Bridge
        self.classMap[16] = 1   # Tunnel
        self.classMap[17] = 7   # Pole
        self.classMap[18] = 1   # Polegroup
        self.classMap[19] = 8   # Traffic light
        self.classMap[20] = 9   # Traffic sign
        self.classMap[21] = 10  # Vegetation
        self.classMap[22] = 11  # Terrain
        self.classMap[23] = 12  # Sky
        self.classMap[24] = 13  # Person
        self.classMap[25] = 14  # Rider
        self.classMap[26] = 15  # Car
        self.classMap[27] = 16  # Truck
        self.classMap[28] = 17  # Bus
        self.classMap[29] = 1   # Caravan
        self.classMap[30] = 1   # Trailer
        self.classMap[31] = 18  # Train
        self.classMap[32] = 19  # Motorcycle
        self.classMap[33] = 20  # Bicycle

    def valid_file_extension(self, filename, extensions):
        ext = os.path.splitext(filename)[-1]
        return ext in extensions

    def data_loader(self):
        print('\n\033[31m\033[4mLoading Cityscapes dataset\033[0m')
        print('# of classes: ', len(self.classes))

        #print("cacheFilePath: ", self.cacheFilePath)
        if self.args['cachepath'] is not None and os.path.exists(self.cacheFilePath):
            #print('\033[32mData cache found at: \033[0m\033[4m', self.cacheFilePath, '\033[0m')
            data_cache = torch.load(self.cacheFilePath)
            self.train_data = data_cache['trainData']
            self.val_data = data_cache['testData']
            self.histClasses = data_cache['histClasses']
            self.loaded_from_cache = True
            data_cache = None
            gc.collect()
        else:
            self.train_data = DataModel(self.train_size, self.args)
            self.val_data = DataModel(self.val_size, self.args)

            data_path_root_train = os.path.join(self.args['datapath'], self.dataset_name, 'leftImg8bit/train/')
            self.load_data(data_path_root_train, self.train_data)

            data_path_root_val = os.path.join(self.args['datapath'], self.dataset_name, 'leftImg8bit/val/')
            self.load_data(data_path_root_val, self.val_data)

            if self.args['cachepath'] is not None and not self.loaded_from_cache:
                print('==> Saving data to cache: ' + self.cacheFilePath)
                data_cache = {}
                data_cache["trainData"] = self.train_data
                data_cache["testData"] = self.val_data
                data_cache["histClasses"] = self.histClasses

                torch.save(data_cache, self.cacheFilePath)
                gc.collect()

    def load_data(self, data_path_root, data_model):
        extensions = {".jpeg", ".jpg", ".png", ".ppm", ".pgm"}
        assert os.path.exists(data_path_root), 'No training folder found at: ' + data_path_root
        count = 0
        dir_names = next(os.walk(data_path_root))[1]

        image_loader = transforms.Compose(
            [transforms.Resize((self.args['imHeight'], self.args['imWidth'])), transforms.ToTensor()])

        # Initializing the progress bar
        #bar = Bar("Processing", max=data_model.size)

        for dir in dir_names:
            dir_path = os.path.join(data_path_root, dir)
            file_names = next(os.walk(dir_path))[2]
            for file in file_names:
                # process each image
                if self.valid_file_extension(file, extensions) and count < data_model.size:
                    file_path = os.path.join(dir_path, file)
                    print("attempting to load image " + file_path)
                    # Load training images
                    image = Image.open(file_path)
                    data_model.data[count] = image_loader(image).float()

                    # Get corresponding label filename
                    label_filename = file_path.replace("leftImg8bit", "gtFine")
                    label_filename = label_filename.replace(".png", "_labelIds.png")

                    # Load training labels with the same base filename as the input image
                    print("attempting to load label file " + label_filename)
                    label = Image.open(label_filename)
                    label_file = image_loader(label).float()

                    # TODO: apply self.classMap to the labels (note: ToTensor scales 8-bit ids to [0, 1])
                    self.histClasses = self.histClasses + torch.histc(label_file, bins=len(self.classes), min=1,
                                                                      max=len(self.classes))
                    data_model.labels[count] = label_file[0].long()
                    count = count + 1
                    #bar.next()
                    gc.collect()
        #bar.finish()


    @staticmethod
    def main(opts):
        print("inside the main")
        loader = CityScapeDataLoader(opts)   # __init__ already invokes data_loader()
        print("leaving the main")
        return loader

if __name__ == '__main__':
    opts = dict()
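
The loader above is driven by a plain option dict rather than argparse; a sketch of the keys it indexes (values are illustrative, mirroring media/opt.txt):

```python
opts = {
    'channels': 3,
    'imHeight': 512,
    'imWidth': 1024,
    'datapath': '/media/HDD1/Datasets',
    'cachepath': '/dataCache',
}
loader = CityScapeDataLoader.main(opts)
print(loader.train_data.data.shape)  # torch.Size([2975, 3, 512, 1024])
```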

--------------------------------------------------------------------------------
/data/segmented_data.py:
--------------------------------------------------------------------------------
import os
import os.path
import cv2
import torch
import torch.utils.data as data

def find_classes(root_dir):
    classes = ['Unlabeled', 'Road', 'Sidewalk', 'Building', 'Wall', 'Fence',
               'Pole', 'TrafficLight', 'TrafficSign', 'Vegetation', 'Terrain', 'Sky', 'Person',
               'Rider', 'Car', 'Truck', 'Bus', 'Train', 'Motorcycle', 'Bicycle']

    class_to_idx = {classes[i]: i for i in range(len(classes))}
    return classes, class_to_idx


def make_dataset(root_dir, mode):
    tensors = []
    data_dir = os.path.join(root_dir, 'leftImg8bit', mode)
    target_dir = os.path.join(root_dir, 'gtFine', mode)
    for folder in os.listdir(data_dir):
        d = os.path.join(data_dir, folder)
        if not os.path.isdir(d):
            continue

        for filename in os.listdir(d):
            if filename.endswith('.png'):
                data_path = '{0}/{1}/{2}'.format(data_dir, folder, filename)
                target_file = filename.replace('leftImg8bit', 'gtFine_labelIds')
                target_path = '{0}/{1}/{2}'.format(target_dir, folder, target_file)
                item = (data_path, target_path)
                tensors.append(item)

    return tensors


def default_loader(input_path, target_path, img_transform, target_transform):
    raw_input_image = cv2.imread(input_path)
    # Get torch tensor
    input_image = img_transform(raw_input_image)

    raw_target_image = cv2.imread(target_path, 0)
    # Get torch tensor
    target_image = target_transform(raw_target_image)

    return input_image.float(), target_image.type(torch.LongTensor)


def remap_class():
    class_remap = {}
    class_remap[-1] = 0   # license plate
    class_remap[0] = 0    # Unlabeled
    class_remap[1] = 0    # Ego vehicle
    class_remap[2] = 0    # Rectification border
    class_remap[3] = 0    # Out of ROI
    class_remap[4] = 0    # Static
    class_remap[5] = 0    # Dynamic
    class_remap[6] = 0    # Ground
    class_remap[7] = 1    # Road
    class_remap[8] = 2    # Sidewalk
    class_remap[9] = 0    # Parking
    class_remap[10] = 0   # Rail track
    class_remap[11] = 3   # Building
    class_remap[12] = 4   # Wall
    class_remap[13] = 5   # Fence
    class_remap[14] = 0   # Guard rail
    class_remap[15] = 0   # Bridge
    class_remap[16] = 0   # Tunnel
    class_remap[17] = 6   # Pole
    class_remap[18] = 0   # Polegroup
    class_remap[19] = 7   # Traffic light
    class_remap[20] = 8   # Traffic sign
    class_remap[21] = 9   # Vegetation
    class_remap[22] = 10  # Terrain
    class_remap[23] = 11  # Sky
    class_remap[24] = 12  # Person
    class_remap[25] = 13  # Rider
    class_remap[26] = 14  # Car
    class_remap[27] = 15  # Truck
    class_remap[28] = 16  # Bus
    class_remap[29] = 0   # Caravan
    class_remap[30] = 0   # Trailer
    class_remap[31] = 17  # Train
    class_remap[32] = 18  # Motorcycle
    class_remap[33] = 19  # Bicycle

    return class_remap


class SegmentedData(data.Dataset):
    def __init__(self, root, mode, data_mode='small', transform=None, target_transform=None, loader=default_loader):
        """
        Load data kept in folders and their corresponding segmented data

        :param root: path to the root directory of data
        :type root: str
        :param mode: train/val mode
        :type mode: str
        :param transform: input transform
        :type transform: torch-vision transforms
        :param loader: type of data loader
        :type loader: function
        """
        classes, class_to_idx = find_classes(root)
        tensors = make_dataset(root, mode)

        self.data_mode = data_mode
        self.tensors = tensors
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

        self.class_map = remap_class()


    def __getitem__(self, index):
        # Get path of input image and ground truth
        input_path, target_path = self.tensors[index]
        # Acquire input image and ground truth
        input_tensor, target = self.loader(input_path, target_path, self.transform, self.target_transform)
        if self.data_mode == 'small':
            # collapse the raw Cityscapes ids onto the 20 training classes
            target.apply_(lambda x: self.class_map[x])

        return input_tensor, target


    def __len__(self):
        return len(self.tensors)


    def class_name(self):
        return self.classes
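
A minimal sketch of wiring the dataset into a loader, mirroring what main.py does below (`Compose` and `ToTensor` come from the project's cv2-based transforms module; the root path is illustrative):

```python
import transforms
from torch.utils.data import DataLoader
from data.segmented_data import SegmentedData

prep_data = transforms.Compose([transforms.Resize((1024, 512)), transforms.ToTensor()])
prep_target = transforms.Compose([transforms.Resize((1024, 512)), transforms.ToTensor(basic=True)])

dataset = SegmentedData(root='/media/HDD1/Datasets/cityscapes', mode='train',
                        transform=prep_data, target_transform=prep_target)
x, y = dataset[0]   # x: FloatTensor (C, H, W); y: LongTensor of remapped class ids
loader = DataLoader(dataset, batch_size=8, shuffle=True, pin_memory=True)
```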

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import os
import numpy as np
from subprocess import call

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.autograd import Variable

from opts import get_args                      # Get all the input arguments
from test import Test
from train import Train
from confusion_matrix import ConfusionMatrix
import data.segmented_data as segmented_data
import transforms

print('\033[0;0f\033[0J')
# Color Palette
CP_R = '\033[31m'
CP_G = '\033[32m'
CP_B = '\033[34m'
CP_Y = '\033[33m'
CP_C = '\033[0m'

args = get_args()   # Holds all the input arguments


def cross_entropy2d(x, target, weight=None, size_average=True):
    # Taken from https://github.com/meetshah1995/pytorch-semseg/blob/master/ptsemseg/loss.py
    n, c, h, w = x.size()
    log_p = F.log_softmax(x, dim=1)
    log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
    log_p = log_p[target.view(n * h * w, 1).repeat(1, c) >= 0]
    log_p = log_p.view(-1, c)

    mask = target >= 0
    target = target[mask]
    loss = F.nll_loss(log_p, target, ignore_index=250,
                      weight=weight, size_average=False)
    if size_average:
        loss /= mask.data.sum()
    return loss


def save_model(checkpoint, class_names, conf_matrix, test_error, prev_error, avg_accuracy, class_iou, save_dir, save_all):
    if test_error >= prev_error:
        prev_error = test_error

        print(CP_G + 'Saving model!!!' + CP_C)
        torch.save(checkpoint, save_dir + '/model_best.pth')

        np.savetxt(save_dir + '/confusion_matrix_best.txt', conf_matrix, fmt='%10s', delimiter=' ')

        conf_file = open(save_dir + '/confusion_matrix_best.txt', 'a')
        conf_file.write('{:-<80}\n'.format(''))
        first = True
        for value in class_iou:
            if first:
                conf_file.write("{:>10}".format("{:2.2f}".format(100*value)))
                first = False
            else:
                conf_file.write("{:>14}".format("{:2.2f}".format(100*value)))

        conf_file.write("\n")

        first = True
        for value in class_names:
            if first:
                conf_file.write("{:>10}".format(value))
                first = False
            else:
                conf_file.write("{:>14}".format(value))

        conf_file.write('\n{:-<80}\n\n'.format(''))
        conf_file.write('mIoU : ' + str(test_error) + '\n')
        conf_file.write('Average Accuracy : ' + str(avg_accuracy))
        conf_file.close()

    if save_all:
        torch.save(checkpoint, save_dir + '/all/model_' + str(checkpoint['epoch']) + '.pth')

        conf_file_path = save_dir + '/all/confusion_matrix_' + str(checkpoint['epoch']) + '.txt'
        np.savetxt(conf_file_path, conf_matrix, fmt='%10s', delimiter=' ')

        conf_file = open(conf_file_path, 'a')
        conf_file.write('{:-<80}\n'.format(''))
        first = True
        for value in class_iou:
            if first:
                conf_file.write("{:>10}".format("{:2.2f}".format(100*value)))
                first = False
            else:
                conf_file.write("{:>14}".format("{:2.2f}".format(100*value)))

        conf_file.write("\n")

        first = True
        for value in class_names:
            if first:
                conf_file.write("{:>10}".format(value))
                first = False
            else:
                conf_file.write("{:>14}".format(value))

        conf_file.write('\n{:-<80}\n'.format(''))
        conf_file.write('mIoU : ' + str(test_error) + '\n')
        conf_file.write('Average Accuracy : ' + str(avg_accuracy))
        conf_file.close()

    torch.save(checkpoint, save_dir + '/model_resume.pth')

    return prev_error


def main():
    print(CP_R + "e-Lab Segmentation Training Script" + CP_C)
    #################################################################
    # Initialization step
    torch.manual_seed(args.seed)
    cudnn.benchmark = True
    torch.set_default_tensor_type('torch.FloatTensor')

    #################################################################
    # Acquire dataset loader object
    # Normalization factors based on ImageNet/ResNet stats, in BGR order to match cv2
    prep_data = transforms.Compose([
        #transforms.Crop((512, 512)),
        transforms.Resize((1024, 512)),
        transforms.ToTensor(),
        transforms.Normalize([[0.406, 0.456, 0.485], [0.225, 0.224, 0.229]])
    ])

    prep_target = transforms.Compose([
        #transforms.Crop((512, 512)),
        transforms.Resize((1024, 512)),
        transforms.ToTensor(basic=True),
    ])

    if args.dataset == 'cs':
        print("{}Cityscapes dataset in use{}!!!".format(CP_G, CP_C))
    else:
        print("{}Invalid data-loader{}".format(CP_R, CP_C))
        return

    # Training data loader
    data_obj_train = segmented_data.SegmentedData(root=args.datapath, mode='train',
                                                  transform=prep_data, target_transform=prep_target)
    data_loader_train = DataLoader(data_obj_train, batch_size=args.bs, shuffle=True,
                                   num_workers=args.workers, pin_memory=True)
    data_len_train = len(data_obj_train)

    # Testing data loader
    data_obj_test = segmented_data.SegmentedData(root=args.datapath, mode='val',
                                                 transform=prep_data, target_transform=prep_target)
    data_loader_test = DataLoader(data_obj_test, batch_size=args.bs, shuffle=False,
                                  num_workers=args.workers, pin_memory=True)
    data_len_test = len(data_obj_test)

    class_names = data_obj_train.class_name()
    n_classes = len(class_names)
    #################################################################
    # Load model
    epoch = 0
    prev_iou = 0.0001
    # Load fresh model definition
    print('{}{:=<80}{}'.format(CP_R, '', CP_C))
    print('{}Models will be saved in: {}{}'.format(CP_Y, CP_C, str(args.save)))
    if not os.path.exists(str(args.save)):
        os.mkdir(str(args.save))

    if args.saveAll:
        if not os.path.exists(str(args.save)+'/all'):
            os.mkdir(str(args.save)+'/all')

    if args.model == 'linknet':
        # Save model definition script
        call(["cp", "./models/linknet.py", args.save])

        from models.linknet import LinkNet
        from torchvision.models import resnet18
        model = LinkNet(n_classes)

        # # Copy weights of resnet18 into encoder
        # pretrained_model = resnet18(pretrained=True)
        # for i, j in zip(model.modules(), pretrained_model.modules()):
        #     if not list(i.children()):
        #         if not isinstance(i, nn.Linear) and len(i.state_dict()) > 0:
        #             i.weight.data = j.weight.data

    model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
    model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), args.lr)  #, momentum=args.momentum, weight_decay=args.wd)

    if args.resume:
        # Load previous model state
        checkpoint = torch.load(args.save + '/model_resume.pth')
        epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])

        optimizer.load_state_dict(checkpoint['optim_state'])
        prev_iou = checkpoint['min_error']
        print('{}Loaded model from previous checkpoint epoch # {}({})'.format(CP_G, CP_C, epoch))

    # Criterion
    print("Model initialized for training...")

    hist_path = os.path.join(args.save, 'hist')
    if os.path.isfile(hist_path + '.npy'):
        hist = np.load(hist_path + '.npy')
        print('{}Loaded cached dataset stats{}!!!'.format(CP_Y, CP_C))
    else:
        # Get class weights based on training data
        hist = np.zeros((n_classes), dtype=float)
        for batch_idx, (x, yt) in enumerate(data_loader_train):
            h, bins = np.histogram(yt.numpy(), list(range(n_classes + 1)))
            hist += h

        hist = hist/(max(hist))   # Normalize histogram
        print('{}Saving dataset stats{}...'.format(CP_Y, CP_C))
        np.save(hist_path, hist)

    criterion_weight = 1/np.log(1.02 + hist)
    criterion_weight[0] = 0
    criterion = nn.NLLLoss(weight=torch.from_numpy(criterion_weight).float().cuda())
    print('{}Using weighted criterion{}!!!'.format(CP_Y, CP_C))
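    # Worked example of the weighting above: for normalized histogram entries
    # 0.1, 1.0 and 0.5, the weights 1/np.log(1.02 + hist) come out to roughly
    # 8.8, 1.4 and 2.4, so rare classes dominate the loss; index 0 (Unlabeled)
    # is then zeroed out entirely.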
    #criterion = cross_entropy2d

    # Save arguments used for training
    args_log = open(args.save + '/args.log', 'w')
    for k in args.__dict__:
        args_log.write(k + ' : ' + str(args.__dict__[k]) + '\n')
    args_log.close()

    # Setup Metrics
    metrics = ConfusionMatrix(n_classes, class_names, useUnlabeled=args.use_unlabeled)

    train = Train(model, data_loader_train, optimizer, criterion, args.lr, args.wd, args.bs, args.visdom)
    test = Test(model, data_loader_test, criterion, metrics, args.bs, args.visdom)

    # Save error values in log file
    logger = open(args.save + '/error.log', 'w')
    logger.write('{:10} {:10} {:10}'.format('Train Error', 'Test Error', 'Mean IoU'))
    logger.write('\n{:-<30}'.format(''))
    while epoch <= args.maxepoch:
        train_error = 0
        print('{}{:-<80}{}'.format(CP_R, '', CP_C))
        print('{}Epoch #: {}{:03}'.format(CP_B, CP_C, epoch))
        train_error = train.forward()
        test_error, accuracy, avg_accuracy, iou, miou, conf_mat = test.forward()

        logger.write('\n{:.6f} {:.6f} {:.6f}'.format(train_error, test_error, miou))
        print('{}Training Error: {}{:.6f} | {}Testing Error: {}{:.6f} | {}Mean IoU: {}{:.6f}'.format(
            CP_B, CP_C, train_error, CP_B, CP_C, test_error, CP_G, CP_C, miou))

        # Save weights and model definition
        prev_iou = save_model({
            'epoch': epoch,
            'model_def': model,
            'state_dict': model.state_dict(),
            'optim_state': optimizer.state_dict(),
            'min_error': prev_iou
            }, class_names, conf_mat, miou, prev_iou, avg_accuracy, iou, args.save, args.saveAll)

        epoch += 1

    logger.close()


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/media/opt.txt:
--------------------------------------------------------------------------------
weightDecay:0.0002
dataset:cs
channels:3
imHeight:512
maxepoch:300
printNorm:False
plot:False
batchSize:8
devid:1
datapath:/media/
nGPU:4
learningRate:0.0005
save:media
lrDecayEvery:100
momentum:0.9
learningRateDecay:1e-07
threads:8
saveAll:False
saveTrainConf:False
cachepath:/media/
model:/models/model.py
imWidth:1024
pretrained:/media/HDD1/Models/pretrained/resnet-18.t7

--------------------------------------------------------------------------------
/models/linknet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torchvision.models import resnet


class BasicBlock(nn.Module):

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, groups=1, bias=False):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=bias)
        self.bn1 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size, 1, padding, groups=groups, bias=bias)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.downsample = None
        if stride > 1:
            self.downsample = nn.Sequential(nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                                            nn.BatchNorm2d(out_planes),)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Encoder(nn.Module):

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, groups=1, bias=False):
        super(Encoder, self).__init__()
        self.block1 = BasicBlock(in_planes, out_planes, kernel_size, stride, padding, groups, bias)
        self.block2 = BasicBlock(out_planes, out_planes, kernel_size, 1, padding, groups, bias)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)

        return x


class Decoder(nn.Module):

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=False):
        # TODO bias=True
        super(Decoder, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(in_planes, in_planes//4, 1, 1, 0, bias=bias),
                                   nn.BatchNorm2d(in_planes//4),
                                   nn.ReLU(inplace=True),)
        self.tp_conv = nn.Sequential(nn.ConvTranspose2d(in_planes//4, in_planes//4, kernel_size, stride, padding, output_padding, bias=bias),
                                     nn.BatchNorm2d(in_planes//4),
                                     nn.ReLU(inplace=True),)
        self.conv2 = nn.Sequential(nn.Conv2d(in_planes//4, out_planes, 1, 1, 0, bias=bias),
                                   nn.BatchNorm2d(out_planes),
                                   nn.ReLU(inplace=True),)

    def forward(self, x):
        x = self.conv1(x)
        x = self.tp_conv(x)
        x = self.conv2(x)

        return x


class LinkNet(nn.Module):
    """
    Generate Model Architecture
    """

    def __init__(self, n_classes=21):
        """
        Model initialization
        :param n_classes: number of output classes
        :type n_classes: int
        """
        super(LinkNet, self).__init__()

        base = resnet.resnet18(pretrained=True)

        self.in_block = nn.Sequential(
            base.conv1,
            base.bn1,
            base.relu,
            base.maxpool
        )

        self.encoder1 = base.layer1
        self.encoder2 = base.layer2
        self.encoder3 = base.layer3
        self.encoder4 = base.layer4

        self.decoder1 = Decoder(64, 64, 3, 1, 1, 0)
        self.decoder2 = Decoder(128, 64, 3, 2, 1, 1)
        self.decoder3 = Decoder(256, 128, 3, 2, 1, 1)
        self.decoder4 = Decoder(512, 256, 3, 2, 1, 1)

        # Classifier
        self.tp_conv1 = nn.Sequential(nn.ConvTranspose2d(64, 32, 3, 2, 1, 1),
                                      nn.BatchNorm2d(32),
                                      nn.ReLU(inplace=True),)
        self.conv2 = nn.Sequential(nn.Conv2d(32, 32, 3, 1, 1),
                                   nn.BatchNorm2d(32),
                                   nn.ReLU(inplace=True),)
        self.tp_conv2 = nn.ConvTranspose2d(32, n_classes, 2, 2, 0)
        self.lsm = nn.LogSoftmax(dim=1)


    def forward(self, x):
        # Initial block
        x = self.in_block(x)

        # Encoder blocks
        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)

        # Decoder blocks with encoder bypasses
        d4 = e3 + self.decoder4(e4)
        d3 = e2 + self.decoder3(d4)
        d2 = e1 + self.decoder2(d3)
        d1 = x + self.decoder1(d2)
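        # Note: ConvTranspose2d output size is
        # (H_in - 1)*stride - 2*padding + kernel_size + output_padding,
        # so each stride-2 decoder exactly doubles the feature map and the
        # skip additions above always see matching shapes.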

        # Classifier
        y = self.tp_conv1(d1)
        y = self.conv2(y)
        y = self.tp_conv2(y)

        y = self.lsm(y)

        return y

class LinkNetBase(nn.Module):
    """
    Generate model architecture
    """

    def __init__(self, n_classes=21):
        """
        Model initialization
        :param n_classes: number of output classes
        :type n_classes: int
        """
        super(LinkNetBase, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(3, 2, 1)

        self.encoder1 = Encoder(64, 64, 3, 1, 1)
        self.encoder2 = Encoder(64, 128, 3, 2, 1)
        self.encoder3 = Encoder(128, 256, 3, 2, 1)
        self.encoder4 = Encoder(256, 512, 3, 2, 1)

        self.decoder1 = Decoder(64, 64, 3, 1, 1, 0)
        self.decoder2 = Decoder(128, 64, 3, 2, 1, 1)
        self.decoder3 = Decoder(256, 128, 3, 2, 1, 1)
        self.decoder4 = Decoder(512, 256, 3, 2, 1, 1)

        # Classifier
        self.tp_conv1 = nn.Sequential(nn.ConvTranspose2d(64, 32, 3, 2, 1, 1),
                                      nn.BatchNorm2d(32),
                                      nn.ReLU(inplace=True),)
        self.conv2 = nn.Sequential(nn.Conv2d(32, 32, 3, 1, 1),
                                   nn.BatchNorm2d(32),
                                   nn.ReLU(inplace=True),)
        self.tp_conv2 = nn.ConvTranspose2d(32, n_classes, 2, 2, 0)
        self.lsm = nn.LogSoftmax(dim=1)

    def forward(self, x):
        # Initial block
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # Encoder blocks
        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)

        # Decoder blocks with encoder bypasses
        d4 = e3 + self.decoder4(e4)
        d3 = e2 + self.decoder3(d4)
        d2 = e1 + self.decoder2(d3)
        d1 = x + self.decoder1(d2)

        # Classifier
        y = self.tp_conv1(d1)
        y = self.conv2(y)
        y = self.tp_conv2(y)

        y = self.lsm(y)

        return y

--------------------------------------------------------------------------------
/models/model.py:
--------------------------------------------------------------------------------
from torch import FloatTensor as tensor
from torch import cuda
from torch import nn
from collections import OrderedDict as od
import torch
import math
import os
import torchvision.models as models

class Model(object):

    def __init__(self, opt):
        self.opt = opt

        print('\n\033[31m\033[4mConstructing Neural Network\033[0m')
        print('Using pretrained ResNet-18')

        # loading model
        self.oModel = models.resnet18(True)
        #self.oModel = torch.load(opt['pretrained'])
        self.classes = opt['Classes']
        self.histClasses = opt['histClasses']


        # Getting rid of the classifier
        self.oModel = nn.Sequential(*list(self.oModel.children())[:-3])
        # Last layer is size 512x8x8

        # Function and variable definition
        self.iChannels = 64
        self.Convolution = nn.ConvTranspose2d
        self.Avg = nn.AvgPool2d
        self.ReLU = nn.ReLU
        self.Max = nn.MaxPool2d
        self.SBatchNorm = nn.BatchNorm2d
        self.arcum = []
        self.model = None
        self.loss = None

        if os.path.isfile(self.opt['save'] + '/all/model-last.net'):
            model = torch.load(self.opt['save'] + '/all/model-last.net')
        else:
od([("oModel layer 1", list(self.oModel.children())[0]), ("oModel layer 2", list(self.oModel.children())[1]), 47 | ("oModel layer 3", list(self.oModel.children())[2]), ("oModel layer 4", list(self.oModel.children())[3]), 48 | ("bypass2dec layer", self.bypass2dec(64, 1, 1, 0)), 49 | ("spacial layer 1", nn.ConvTranspose2d(64, 32, (3, 3), padding=(1, 1), output_padding=(1, 1), stride=(2, 2))), 50 | ("batch norm layer 1", self.SBatchNorm(32)), ("ReLu layer 1", self.ReLU(True)), 51 | ("conv layer 1", self.Convolution(32, 32, 3, 3, 1, 1, 1, 1)), 52 | ("batch norm layer 2", self.SBatchNorm(32, eps=1e-3)), ("rectified layer 2", self.ReLU(True)), 53 | ("spacial layer 2", nn.ConvTranspose2d(32, len(self.classes), (2, 2), stride=(2, 2), 54 | padding=(0, 0), output_padding=(0, 0)))]) 55 | 56 | 57 | """ 58 | model.add_module("oModel layer 1", self.oModel.get(1)) 59 | model.add_module("oModel layer 2", self.oModel.get(2)) 60 | model.add_module("oModel layer 3", self.oModel.get(3)) 61 | 62 | model.add_module("oModel layer 4", self.oModel.get(4)) 63 | model.add_module("bypass2dec layer", self.bypass2dec(64, 1, 1, 0)) 64 | 65 | # -- Decoder section without bypassed information 66 | model.add_module("spacial layer 1", nn.ConvTranspose2d(64, 32, (3, 3), padding=(1, 1), output_padding=(1, 1) 67 | , stride=(2, 2))) 68 | model.add_module("batch norm layer 1", self.SBatchNorm(32)) 69 | model.add_module("ReLu layer 1", self.ReLU(True)) 70 | # -- 64x128x128 71 | model.add_module("conv layer 1", self.Convolution(32, 32, 3, 3, 1, 1, 1, 1)) 72 | model.add_module("batch norm layer 2", self.SBatchNorm(32, eps=1e-3)) 73 | model.add_module("rectified layer 2", self.ReLU(True)) 74 | # -- 32x128x128 75 | model.add_module("spacial layer 2", nn.ConvTranspose2d(32, len(self.classes), (2, 2), stride=(2, 2), 76 | padding=(0, 0), output_padding=(0, 0))) 77 | """ 78 | 79 | # -- Model definition ends here 80 | 81 | # -- Initialize convolutions and batch norm existing in later stage of decoder 82 | for i in range(1, 2): 83 | self.ConvInit(list(layers.items())[len(layers)-1][1]) 84 | self.ConvInit(list(layers.items())[len(layers)-1][1]) 85 | self.ConvInit(list(layers.items())[len(layers) - 4][1]) 86 | self.ConvInit(list(layers.items())[len(layers) - 4][1]) 87 | self.ConvInit(list(layers.items())[len(layers) - 7][1]) 88 | self.ConvInit(list(layers.items())[len(layers) - 7][1]) 89 | 90 | self.BNInit(list(layers.items())[len(layers) - 3][1]) 91 | self.BNInit(list(layers.items())[len(layers) - 3][1]) 92 | self.BNInit(list(layers.items())[len(layers) - 6][1]) 93 | self.BNInit(list(layers.items())[len(layers) - 6][1]) 94 | 95 | model = nn.Sequential(layers) 96 | 97 | """if torch.cuda.device_count() > 1: 98 | gpu_list = [] 99 | for i in range(0, torch.cuda.device_count()): 100 | gpu_list.append(i) 101 | model = nn.DataParallel(model) 102 | print('\27[32m' + str(self.opt['nGPU']) + " GPUs being used\27[0m") 103 | """ 104 | print('Defining loss function...') 105 | classWeights = torch.pow(torch.log(1.02 + self.histClasses / self.histClasses.max()), -1) 106 | #classWeights[0] = 0 107 | 108 | self.loss = torch.nn.CrossEntropyLoss("""weight=classWeights""") 109 | 110 | #model.cuda() 111 | #self.loss.cuda() 112 | 113 | self.model = model 114 | 115 | @staticmethod 116 | def ConvInit(vector): 117 | n = vector.kernel_size[0] * vector.kernel_size[1] * vector.out_channels 118 | vector.weight = torch.nn.Parameter(tensor(vector.in_channels, vector.out_channels // vector.groups, 119 | *vector.kernel_size).normal_(0, math.sqrt(2 / n))) 120 | # removed 

    @staticmethod
    def BNInit(vector):
        vector.weight = torch.nn.Parameter(tensor(vector.num_features).fill_(1))
        vector.bias = torch.nn.Parameter(tensor(vector.num_features).zero_())

    def decode(self, iFeatures, oFeatures, stride, adjS):
        layers = od([("conv layer 1", self.Convolution(int(iFeatures), int(iFeatures / 4), (1, 1), stride=(1, 1), padding=(0, 0))),
                     ("batch norm 1", self.SBatchNorm(int(iFeatures / 4), eps=1e-3)),
                     ("rectifier layer 1", nn.ReLU(True)),
                     ("spacial layer 1", nn.ConvTranspose2d(int(iFeatures / 4), int(iFeatures / 4), (3, 3),
                                                            stride=(stride, stride), padding=(1, 1), output_padding=(adjS, adjS))),
                     ("batch norm layer 2", self.SBatchNorm(int(iFeatures / 4), eps=1e-3)),
                     ("rectifier layer 2", nn.ReLU(True)),
                     ("conv layer 2", self.Convolution(int(iFeatures / 4), oFeatures, (1, 1), stride=(1, 1), padding=(0, 0))),
                     ("batch norm layer 3", self.SBatchNorm(oFeatures, eps=1e-3)),
                     ("rectifier layer 3", nn.ReLU(True))])

        self.ConvInit(list(layers.items())[0][1])
        self.ConvInit(list(layers.items())[3][1])
        self.ConvInit(list(layers.items())[6][1])

        self.BNInit(list(layers.items())[1][1])
        self.BNInit(list(layers.items())[4][1])
        self.BNInit(list(layers.items())[7][1])
        mainBlock = nn.Sequential(layers)

        return mainBlock

    def layer(self, layerN, features):
        self.iChannels = features
        s = nn.Sequential()
        for i in range(0, 2):
            s.add_module("Feature layer" + str(i), list(self.oModel.children())[i])
        return s

    def bypass2dec(self, features, layers, stride, adjS):
        container = nn.Sequential()
        prim = nn.Sequential()   # Container for encoder
        oFeatures = self.iChannels

        accum = [prim]   # FIXME

        # -- Add the bottleneck modules
        prim.add_module("bypass_layer_" + str(layers), self.layer(layers, features))
        if layers == 4:
            # -- DECODER
            prim.add_module("decoder_layer_mod4" + str(layers), self.decode(features, oFeatures, 2, 1))
            #container.add_module("arcum_decoder_" + str(layers), self.ConcatTable)
            #container.add_module("decoder_CAddTable_" + str(layers), self.CAddTable)
            container.add_module("rectifier_decoder_" + str(layers), nn.ReLU(True))
            return container
        # -- Move on to next bottleneck
        prim.add_module("bypass2dec_layer_" + str(layers), self.bypass2dec(2 * features, layers + 1, 2, 1))

        # -- Add decoder module
        prim.add_module("decoder_layer_" + str(layers), self.decode(features, oFeatures, stride, adjS))
        #container.add_module("arcum_decoder_" + str(layers), self.ConcatTable)
        #container.add_module("decoder_CAddTable_" + str(layers), self.CAddTable)
        container.add_module("rectifier_decoder_" + str(layers), nn.ReLU(True))

        return container

    @staticmethod
    def CAddTable(in1, in2):
        return in1 + in2

    def ConcatTable(self, new):
        return self.arcum.append(new)

--------------------------------------------------------------------------------
/models/nobypass.py:
--------------------------------------------------------------------------------
from torch import FloatTensor as tensor
from torch import nn as nn
import torch
import math
from torch import cuda
import os
from collections import OrderedDict as od


class Model(object):

    def __init__(self, opt):
        self.opt = opt

        print('\n\033[31m\033[4mConstructing Neural Network\033[0m')
        print('Using pretrained ResNet-18')

        # loading model
        self.oModel = torch.load(opt['pretrained'])
        self.classes = opt['Classes']
        self.histClasses = opt['histClasses']


        # Getting rid of the classifier (drops the last three modules)
        self.oModel = nn.Sequential(*list(self.oModel.children())[:-3])
        # Last layer is size 512x8x8

        # Function and variable definition
        self.iChannels = 64
        self.Convolution = nn.ConvTranspose2d
        self.Avg = nn.AvgPool2d
        self.ReLU = nn.ReLU
        self.Max = nn.MaxPool2d
        self.SBatchNorm = nn.BatchNorm2d

        self.model = None
        self.loss = None

        if os.path.isfile(self.opt['save'] + '/all/model-last.net'):
            model = torch.load(self.opt['save'] + '/all/model-last.net')
        else:
            layers = od([("oModel layer 1", list(self.oModel.children())[0]),
                         ("oModel layer 2", list(self.oModel.children())[1]),
                         ("oModel layer 3", list(self.oModel.children())[2]),
                         ("oModel layer 4", list(self.oModel.children())[3]),
                         ("enc_dec layer", self.enc_dec(64, 1, 1, 0)),
                         ("spacial layer 1", nn.ConvTranspose2d(64, 32, (3, 3), padding=(1, 1), output_padding=(1, 1),
                                                                stride=(2, 2))),
                         ("batch norm layer 1", self.SBatchNorm(32)),
                         ("ReLu layer 1", self.ReLU(True)),
                         ("conv layer 1", self.Convolution(32, 32, (3, 3), stride=(1, 1), padding=(1, 1)).type(cuda.FloatTensor)),
                         ("batch norm layer 2", self.SBatchNorm(32, eps=1e-3)),
                         ("rectified layer 2", self.ReLU(True)),
                         ("spacial layer 2", nn.ConvTranspose2d(32, len(self.classes), (2, 2), stride=(2, 2),
                                                                padding=(0, 0), output_padding=(0, 0)))])

            # -- Model definition ends here

            # -- Initialize convolutions and batch norms in the later stage of the decoder
            self.ConvInit(list(layers.items())[len(layers) - 1][1])
            self.ConvInit(list(layers.items())[len(layers) - 4][1])
            self.ConvInit(list(layers.items())[len(layers) - 7][1])

            self.BNInit(list(layers.items())[len(layers) - 3][1])
            self.BNInit(list(layers.items())[len(layers) - 6][1])

            model = nn.Sequential(layers)

        if torch.cuda.device_count() > 1:
            gpu_list = []
            for i in range(0, torch.cuda.device_count()):
                gpu_list.append(i)
            model = nn.DataParallel(model, device_ids=gpu_list)
            print('\033[32m' + str(self.opt['nGPU']) + " GPUs being used\033[0m")

        print('Defining loss function...')
        classWeights = torch.pow(torch.log(1.02 + self.histClasses / self.histClasses.max()), -1)
        #classWeights[0] = 0

        self.loss = torch.nn.CrossEntropyLoss(weight=classWeights)

        model.cuda()
        self.loss.cuda()

        self.model = model

    @staticmethod
    def ConvInit(vector):
        n = vector.kernel_size[0] * vector.kernel_size[1] * vector.out_channels
        vector.weight = torch.nn.Parameter(tensor(vector.in_channels, vector.out_channels // vector.groups,
                                                  *vector.kernel_size).normal_(0, math.sqrt(2 / n)))
        # replaces the Lua weight:normal() initialization

    @staticmethod
    def BNInit(vector):
        vector.weight = torch.nn.Parameter(tensor(vector.num_features).fill_(1))
        vector.bias = torch.nn.Parameter(tensor(vector.num_features).zero_())

    def decode(self, iFeatures, oFeatures, stride, adjS):
        layers = od([("conv layer 1", self.Convolution(iFeatures, iFeatures // 4, (1, 1), stride=(1, 1), padding=(0, 0)).type(cuda.FloatTensor)),
                     ("batch norm 1", self.SBatchNorm(iFeatures // 4, eps=1e-3)),
                     ("rectifier layer 1", nn.ReLU(True)),
                     ("spacial layer 1", nn.ConvTranspose2d(iFeatures // 4, iFeatures // 4, (3, 3),
                                                            stride=(stride, stride), padding=(1, 1), output_padding=(adjS, adjS))),
                     ("batch norm layer 2", self.SBatchNorm(iFeatures // 4, eps=1e-3)),
("rectifier layer 2", nn.ReLU(True)), 151 | ("conv layer 2", self.Convolution(iFeatures / 4, oFeatures, 1, 1, 1, 1, 0, 0).type(cuda.FloatTensor)), 152 | ("batch norm layer 3", self.SBatchNorm(oFeatures, eps=1e-3)), 153 | ("rectifier layer 3", nn.ReLU(True))]) 154 | 155 | for i in xrange(1, 2): 156 | self.ConvInit(layers.items()[0][1]) 157 | self.ConvInit(layers.items()[3][1]) 158 | self.ConvInit(layers.items()[6][1]) 159 | 160 | self.BNInit(layers.items()[1][1]) 161 | self.BNInit(layers.items()[4][1]) 162 | self.BNInit(layers.items()[7][1]) 163 | mainBlock = nn.Sequential(layers) 164 | 165 | return mainBlock 166 | 167 | def layer(self, layerN, features): 168 | self.iChannels = features 169 | s = nn.Sequential() 170 | for i in xrange(1, 2): 171 | s.add_module("Feature layer" + str(i), list(list(self.oModel.children())[4+layerN].children)[i]) 172 | return s 173 | 174 | # -- Creates bypass modules for decoders 175 | def enc_dec(self, features, layers, stride, adjS): 176 | accum = nn.Sequential() 177 | oFeatures = self.iChannels 178 | 179 | # -- Add the bottleneck modules 180 | accum.add_module("Bottleneck_layer_"+str(layers), self.layer(layers, features)) 181 | if layers == 4: 182 | # --DECODER 183 | accum.add_module("bottleneck_Decoder_layer_"+str(layers), self.decode(features, oFeatures, 2, 1)) 184 | return accum 185 | 186 | # -- Move on to next bottleneck 187 | accum.add_module("enc_dec_layer_"+str(layers), self.enc_dec(2 * features, layers + 1, 2, 1)) 188 | 189 | # -- Add decoder module 190 | accum.add_module("bypass_decoder_layer_"+str(layers), self.decode(features, oFeatures, stride, adjS)) 191 | return accum 192 | 193 | 194 | 195 | 196 | -------------------------------------------------------------------------------- /opts.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | 4 | def get_args(): 5 | # training related 6 | parser = ArgumentParser(description='e-Lab Segmentation Script') 7 | arg = parser.add_argument 8 | arg('--bs', type=float, default=8, help='batch size') 9 | arg('--lr', type=float, default=5e-4, help='learning rate, default is 5e-4') 10 | arg('--lrd', type=float, default=1e-7, help='learning rate decay (in # samples)') 11 | arg('--wd', type=float, default=2e-4, help='L2 penalty on the weights, default is 2e-4') 12 | arg('-m', '--momentum', type=float, default=.9, help='momentum, default: .9') 13 | 14 | # device related 15 | arg('--workers', type=int, default=8, help='# of cpu threads for data-loader') 16 | arg('--maxepoch', type=int, default=300, help='maximum number of training epochs') 17 | arg('--seed', type=int, default=0, help='seed value for random number generator') 18 | arg('--nGPU', type=int, default=4, help='number of GPUs you want to train on') 19 | arg('--save', type=str, default='media', help='save trained model here') 20 | 21 | # data set related: 22 | arg('--datapath', type=str, default='/media/HDD1/Datasets', help='dataset location') 23 | arg('--dataset', type=str, default='cs', choices=["cs", "cv"], 24 | help='dataset type: cs(cityscapes)/cv(CamVid)') 25 | arg('--img_size', type=int, default=512, help='image height (576 cv/512 cs)') 26 | arg('--use_unlabeled', action='store_true', help='use unlabeled class annotation') 27 | 28 | # model related 29 | arg('--model', type=str, default='linknet', help='linknet') 30 | 31 | # Saving/Displaying Information 32 | arg('--visdom', action='store_true', help='Plot using visdom') 33 | arg('--saveAll', action='store_true', help='Save all 
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import visdom
3 | from tqdm import trange
4 | from torch.autograd import Variable
5 | 
6 | class Test(object):
7 |     def __init__(self, model, data_loader, criterion, metrics, batch_size, vis):
8 |         super(Test, self).__init__()
9 |         self.model = model
10 |         self.data_loader = data_loader
11 |         self.criterion = criterion
12 |         self.metrics = metrics
13 |         self.bs = batch_size
14 |         self.vis = None
15 | 
16 |         if vis:
17 |             self.vis = visdom.Visdom()
18 | 
19 |             self.loss_window = self.vis.line(X=torch.zeros((1,)).cpu(),
20 |                                              Y=torch.zeros((1,)).cpu(),
21 |                                              opts=dict(xlabel='minibatches',
22 |                                                        ylabel='Loss',
23 |                                                        title='Validation Loss',
24 |                                                        legend=['Loss']))
25 | 
26 |         self.iterations = 0
27 | 
28 |     def forward(self):
29 |         self.model.eval()
30 |         # TODO adjust learning rate
31 | 
32 |         total_loss = 0
33 |         pbar = trange(len(self.data_loader.dataset), desc='Validation ')
34 | 
35 |         for batch_idx, (x, yt) in enumerate(self.data_loader):
36 |             x = x.cuda(non_blocking=True)   # `async` is a reserved word in Python 3.7+
37 |             yt = yt.cuda(non_blocking=True)
38 |             input_var = Variable(x, requires_grad=False)
39 |             target_var = Variable(yt, requires_grad=False)
40 | 
41 |             # compute output
42 |             y = self.model(input_var)
43 |             loss = self.criterion(y, target_var)
44 | 
45 |             # accumulate the loss
46 |             total_loss += loss.item()
47 | 
48 |             # update the confusion matrix used for mIoU
49 |             pred = y.data.cpu().numpy()
50 |             gt = yt.cpu().numpy()
51 |             self.metrics.update_matrix(gt, pred)
52 | 
53 |             if batch_idx % 10 == 0:
54 |                 # update tqdm bar
55 |                 if (batch_idx*self.bs + 10*len(x)) <= len(self.data_loader.dataset):
56 |                     pbar.update(10 * len(x))
57 |                 else:
58 |                     pbar.update(len(self.data_loader.dataset) - int(batch_idx*self.bs))
59 | 
60 |                 # display plot using visdom
61 |                 if self.vis:
62 |                     self.vis.line(
63 |                         X=torch.ones((1,)).cpu() * self.iterations,
64 |                         Y=loss.data.cpu(),
65 |                         win=self.loss_window,
66 |                         update='append')
67 | 
68 |                 self.iterations += 1
69 | 
70 |         accuracy, avg_accuracy, IoU, mIoU, conf_mat = self.metrics.scores()
71 |         self.metrics.reset()
72 |         pbar.close()
73 | 
74 |         return (total_loss*self.bs/len(self.data_loader.dataset), accuracy, avg_accuracy, IoU, mIoU, conf_mat)
75 | 
--------------------------------------------------------------------------------
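`Test.forward()` returns six values: the per-sample loss plus the metrics computed by `ConfusionMatrix.scores()`. Below is a hedged sketch of driving it, in which `model`, `val_loader`, `criterion`, and `class_names` are assumed to be constructed elsewhere:

```python
# Hypothetical validation driver; none of these objects are built by test.py itself.
from confusion_matrix import ConfusionMatrix
from test import Test

metrics = ConfusionMatrix(20, class_names)   # 20 classes: 19 Cityscapes + unlabeled
tester = Test(model, val_loader, criterion, metrics, batch_size=8, vis=False)
loss, accuracy, avg_accuracy, IoU, mIoU, conf_mat = tester.forward()
print('val loss {:.4f} | mIoU {:.4f}'.format(loss, mIoU))
```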
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import visdom
3 | from tqdm import trange
4 | from torch.autograd import Variable
5 | 
6 | class Train(object):
7 |     def __init__(self, model, data_loader, optimizer, criterion, lr, wd, batch_size, vis):
8 |         super(Train, self).__init__()
9 |         self.model = model
10 |         self.data_loader = data_loader
11 |         self.optimizer = optimizer
12 |         self.criterion = criterion
13 |         self.lr = lr
14 |         self.wd = wd
15 |         self.bs = batch_size
16 |         self.vis = None
17 | 
18 |         if vis:
19 |             self.vis = visdom.Visdom()
20 | 
21 |             self.loss_window = self.vis.line(X=torch.zeros((1,)).cpu(),
22 |                                              Y=torch.zeros((1,)).cpu(),
23 |                                              opts=dict(xlabel='minibatches',
24 |                                                        ylabel='Loss',
25 |                                                        title='Training Loss',
26 |                                                        legend=['Loss']))
27 | 
28 |         self.iterations = 0
29 | 
30 |     def forward(self):
31 |         self.model.train()
32 |         # TODO adjust learning rate
33 | 
34 |         total_loss = 0
35 |         pbar = trange(len(self.data_loader.dataset), desc='Training ')
36 | 
37 |         for batch_idx, (x, yt) in enumerate(self.data_loader):
38 |             x = x.cuda(non_blocking=True)   # `async` is a reserved word in Python 3.7+
39 |             yt = yt.cuda(non_blocking=True)
40 |             input_var = Variable(x)
41 |             target_var = Variable(yt)
42 | 
43 |             # compute output
44 |             y = self.model(input_var)
45 |             loss = self.criterion(y, target_var)
46 | 
47 |             # accumulate the loss
48 |             total_loss += loss.item()
49 | 
50 |             # compute gradient and do SGD step
51 |             self.optimizer.zero_grad()
52 |             loss.backward()
53 |             self.optimizer.step()
54 | 
55 |             if batch_idx % 10 == 0:
56 |                 # update tqdm bar
57 |                 if (batch_idx*self.bs + 10*len(x)) <= len(self.data_loader.dataset):
58 |                     pbar.update(10 * len(x))
59 |                 else:
60 |                     pbar.update(len(self.data_loader.dataset) - int(batch_idx*self.bs))
61 | 
62 |                 # display plot using visdom
63 |                 if self.vis:
64 |                     self.vis.line(
65 |                         X=torch.ones((1,)).cpu() * self.iterations,
66 |                         Y=loss.data.cpu(),
67 |                         win=self.loss_window,
68 |                         update='append')
69 | 
70 |                 self.iterations += 1
71 | 
72 |         pbar.close()
73 | 
74 |         return total_loss*self.bs/len(self.data_loader.dataset)
75 | 
--------------------------------------------------------------------------------
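`Train.forward()` runs one epoch and returns the average per-sample loss, so a typical outer loop alternates it with `Test.forward()`. A sketch under the assumption that `trainer`, `tester`, `model`, and `args` have been set up as above:

```python
# Hypothetical epoch loop; keeps the checkpoint with the best validation mIoU.
import torch

best_miou = 0
for epoch in range(args.maxepoch):
    train_loss = trainer.forward()
    val_loss, acc, avg_acc, iou, miou, conf = tester.forward()
    print('epoch {:d}: train {:.4f} | val {:.4f} | mIoU {:.4f}'.format(
        epoch, train_loss, val_loss, miou))
    if miou > best_miou:
        best_miou = miou
        torch.save({'state_dict': model.state_dict()}, 'model_best.pth')
```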
/transforms.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | 
4 | class Resize:
5 |     '''
6 |     Attributes
7 |     ----------
8 |     factor : scale factor applied to both image axes
9 | 
10 |     Methods
11 |     -------
12 |     forward(img=input_image)
13 |         Resizes a numpy image of shape HWC
14 |     '''
15 |     def __init__(self, factor):
16 |         self.factor = factor
17 | 
18 |     def __call__(self, img):
19 |         return self.forward(img)
20 | 
21 |     def forward(self, img):
22 |         '''
23 |         Parameters
24 |         ----------
25 |         img : opencv image
26 | 
27 |         Returns
28 |         -------
29 |         numpy array
30 |             Resized image
31 |         '''
32 | 
33 |         return cv2.resize(img, None, fx=self.factor, fy=self.factor)
34 | 
35 | 
36 | class Normalize:
37 |     '''
38 |     Attributes
39 |     ----------
40 |     factor : list containing 2 lists with mean and standard deviation for each channel
41 | 
42 |     Methods
43 |     -------
44 |     forward(img=input_image)
45 |         Normalizes an input image based on mean and standard deviation
46 |     '''
47 |     def __init__(self, factor):
48 |         self.factor = factor
49 | 
50 |     def __call__(self, img):
51 |         return self.forward(img)
52 | 
53 |     def forward(self, img):
54 |         '''
55 |         Parameters
56 |         ----------
57 |         img : image CHW
58 | 
59 |         Returns
60 |         -------
61 |         array
62 |             Normalized image
63 |         '''
64 | 
65 |         mean = self.factor[0]
66 |         std = self.factor[1]
67 | 
68 |         assert (img.shape[0] == len(mean)), \
69 |             "{:d} channels in image but {:d} in normalization".format(img.shape[0], len(mean))
70 | 
71 |         for i in range(len(mean)):
72 |             img[i] = (img[i] - mean[i])/std[i]
73 | 
74 |         return img
75 | 
76 | 
77 | class Crop:
78 |     '''
79 |     Attributes
80 |     ----------
81 |     (h, w) : center crop with this height and width value
82 | 
83 |     Methods
84 |     -------
85 |     forward(img=input_image)
86 |         Center crop of image
87 |     '''
88 |     def __init__(self, dim):
89 |         self.dim = dim
90 | 
91 |     def __call__(self, img):
92 |         return self.forward(img)
93 | 
94 |     def forward(self, img):
95 |         '''
96 |         Parameters
97 |         ----------
98 |         img : image HW or HWC
99 | 
100 |         Returns
101 |         -------
102 |         array
103 |             Cropped image
104 |         '''
105 | 
106 |         h, w = self.dim
107 |         img_h, img_w = img.shape[:2]   # handles both HW and HWC inputs
108 |         assert (img_h >= h and img_w >= w), \
109 |             "Cannot create a crop of {}x{} from image of resolution {}x{}".format(h, w, img_h, img_w)
110 | 
111 |         ch, cw = img_h//2, img_w//2
112 |         y1, y2 = ch - h//2, ch + h//2
113 |         x1, x2 = cw - w//2, cw + w//2
114 | 
115 |         return img[y1:y2, x1:x2]
116 | 
117 | 
118 | class ToTensor:
119 |     '''
120 |     Attributes
121 |     ----------
122 |     basic : if True, convert numpy to PyTorch tensor without any reordering
123 | 
124 |     Methods
125 |     -------
126 |     forward(img=input_image)
127 |         Convert HWC OpenCV image into CHW PyTorch Tensor
128 |     '''
129 |     def __init__(self, basic=False):
130 |         self.basic = basic
131 | 
132 |     def __call__(self, img):
133 |         return self.forward(img)
134 | 
135 |     def forward(self, img):
136 |         '''
137 |         Parameters
138 |         ----------
139 |         img : opencv/numpy image
140 | 
141 |         Returns
142 |         -------
143 |         Torch tensor
144 |             BGR -> RGB, [0, 255] -> [0, 1]
145 |         '''
146 | 
147 |         if self.basic:
148 |             return torch.from_numpy(img)
149 |         else:
150 |             img_RGB = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)/255
151 |             return torch.from_numpy(img_RGB.transpose(2, 0, 1))
152 | 
153 | 
154 | class Compose:
155 |     def __init__(self, transforms):
156 |         self.transforms = transforms
157 | 
158 |     def __call__(self, img):
159 |         return self.forward(img)
160 | 
161 |     def forward(self, img):
162 |         for t in self.transforms:
163 |             img = t(img)
164 | 
165 |         return img
166 | 
--------------------------------------------------------------------------------
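These transforms compose left to right through `Compose`. A short example of the intended pipeline on an OpenCV image, where `'frame.png'` is a placeholder path:

```python
import cv2
import transforms

# Half-size the frame, centre-crop, convert BGR HWC uint8 -> RGB CHW in [0, 1],
# then normalize each channel (ImageNet statistics, as in visualize.py below).
prep = transforms.Compose([
    transforms.Resize(0.5),
    transforms.Crop((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize([[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]]),
])
tensor = prep(cv2.imread('frame.png'))   # torch tensor of shape (3, 512, 512)
```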
/visualize.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import time
4 | import torch
5 | import numpy as np
6 | from argparse import ArgumentParser
7 | 
8 | import transforms
9 | 
10 | parser = ArgumentParser(description='e-Lab Segmentation Visualizer')
11 | _ = parser.add_argument
12 | _('--model_path', type=str, default='/media/', help='model to load')
13 | _('--data_path', type=str, default='/media/', help='image folder')
14 | _('--mode', type=int, default=0, help='mode 0, 1, 2')
15 | _('--fullscreen', action='store_true', help='Show output in full screen')
16 | 
17 | args = parser.parse_args()
18 | 
19 | # Clear screen
20 | print('\033[0;0f\033[0J')
21 | # Color palette for terminal output
22 | CP_R = '\033[31m'
23 | CP_G = '\033[32m'
24 | CP_B = '\033[34m'
25 | CP_Y = '\033[33m'
26 | CP_C = '\033[0m'
27 | 
28 | # Define color scheme (Cityscapes classes)
29 | color_map = np.array([
30 |     [0, 0, 0],        # Unlabeled
31 |     [128, 64, 128],   # Road
32 |     [244, 35, 232],   # Sidewalk
33 |     [70, 70, 70],     # Building
34 |     [102, 102, 156],  # Wall
35 |     [190, 153, 153],  # Fence
36 |     [153, 153, 153],  # Pole
37 |     [250, 170, 30],   # Traffic light
38 |     [220, 220, 0],    # Traffic sign
39 |     [107, 142, 35],   # Vegetation
40 |     [152, 251, 152],  # Terrain
41 |     [70, 130, 180],   # Sky
42 |     [220, 20, 60],    # Person
43 |     [255, 0, 0],      # Rider
44 |     [0, 0, 142],      # Car
45 |     [0, 0, 70],       # Truck
46 |     [0, 60, 100],     # Bus
47 |     [0, 80, 100],     # Train
48 |     [0, 0, 230],      # Motorcycle
49 |     [119, 11, 32]     # Bicycle
50 | ], dtype=np.uint8)
51 | 
52 | # Load model from the path given on the command line
53 | m = torch.load(args.model_path)
54 | model = torch.nn.DataParallel(m['model_def'](20))
55 | model.load_state_dict(m['state_dict'])
56 | model.cuda()
57 | model.eval()
58 | 
59 | root_dir = os.path.join(args.data_path, 'stuttgart_0' + str(args.mode))
60 | first_idx = [1, 3500, 5100]
61 | last_idx = [599, 4599, 6299]
62 | idx = first_idx[args.mode]
63 | fps = 'NA'
64 | # buffer for the colorized prediction at the network's half resolution
65 | pred_map = np.zeros((512, 1024, 3), dtype=np.uint8)
66 | 
67 | win_title = 'Overlayed Image'
68 | if args.fullscreen:
69 |     cv2.namedWindow(win_title, cv2.WND_PROP_FULLSCREEN)
70 |     cv2.setWindowProperty(win_title, cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_FULLSCREEN)
71 | 
72 | while idx <= last_idx[args.mode]:
73 |     # Load image, resize and convert into a 'batchified' cuda tensor
74 |     start_time = time.time()
75 |     filename = '{}/stuttgart_0{:d}_000000_{:06d}_leftImg8bit.png'.format(root_dir, args.mode, idx)
76 | 
77 |     if os.path.isfile(filename):
78 |         x = cv2.imread(filename)
79 |         read_time = time.time() - start_time
80 | 
81 |         resize = transforms.Resize(0.5)
82 |         x = resize(x)
83 |         prep_data = transforms.Compose([
84 |             #transforms.Crop((512, 512)),
85 |             transforms.ToTensor(),
86 |             transforms.Normalize([[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]])
87 |         ])
88 |         input_image = prep_data(x)
89 |         #input_image = torch.from_numpy(cv2.cvtColor(x, cv2.COLOR_BGR2RGB).transpose(2, 0, 1))/255
90 |         input_image = input_image.unsqueeze(0).float().cuda()
91 |         prep_time = time.time() - start_time - read_time
92 | 
93 |         # Get neural network output
94 |         y = model(torch.autograd.Variable(input_image))
95 |         y = y.squeeze()
96 |         pred = y.data.cpu().numpy()
97 |         model_time = time.time() - start_time - read_time - prep_time
98 | 
99 |         # Calculate prediction and colorized segmented output
100 |         prediction = np.argmax(pred, axis=0)
101 |         num_classes = 20
102 |         pred_map *= 0
103 |         for i in range(num_classes):
104 |             pred_map[prediction == i] = color_map[i]
105 | 
106 |         pred_map_BGR = cv2.cvtColor(pred_map, cv2.COLOR_RGB2BGR)
107 |         overlay = cv2.addWeighted(x, 0.5, pred_map_BGR, 0.5, 0)
108 |         pred_time = time.time() - start_time - read_time - prep_time - model_time
109 | 
110 |         #cv2.imshow('Original Image', x)
111 |         #cv2.imshow('Segmented Output', pred_map_BGR)
112 |         cv2.imshow(win_title, overlay)
113 |         disp_time = time.time() - start_time - read_time - prep_time - model_time - pred_time
114 |         fps = 1/(time.time() - start_time)
115 | 
116 |         print("{}Read: {}{:4.2f} ms | {}Norm: {}{:4.2f} ms | {}Model: {}{:4.2f} ms | {}Predict: {}{:4.2f} ms | {}Display: {}{:4.2f} ms".format(
117 |             CP_Y, CP_C, read_time*1000, CP_G, CP_C, prep_time*1000, CP_G, CP_C, model_time*1000,
118 |             CP_R, CP_C, pred_time*1000, CP_B, CP_C, disp_time*1000))
119 |     else:
120 |         print("{}Warning{}!!! {}{} image unavailable{}.".format(CP_R, CP_C, filename, CP_R, CP_C))
121 | 
122 |     idx += 1
123 |     if cv2.waitKey(1) == 27:   # ESC to stop
124 |         break
125 | 
126 | 
--------------------------------------------------------------------------------
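A sample command to run the visualizer on the Cityscapes `stuttgart` demo sequences (both paths below are placeholders for your own model checkpoint and demo-video folder):

```
python3 visualize.py --model_path /path/to/model_best.pth --data_path /path/to/demoVideo --mode 0
```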