├── best_model.pkl
├── starburst_black.png
├── test
│   ├── mask
│   │   ├── 000000337021.jpg
│   │   ├── 000000337043.jpg
│   │   ├── 000000337047.jpg
│   │   ├── 000000337720.jpg
│   │   ├── 000000337733.jpg
│   │   ├── 000000337735.jpg
│   │   ├── 000000338387.jpg
│   │   ├── 000000339828.jpg
│   │   ├── 000000339839.jpg
│   │   ├── 000000339848.jpg
│   │   ├── 000000339850.jpg
│   │   ├── 000000340534.jpg
│   │   ├── 000000340542.jpg
│   │   ├── 000000341973.jpg
│   │   ├── 000000342664.jpg
│   │   ├── 000000343356.jpg
│   │   ├── 000000344729.jpg
│   │   ├── 000000344746.jpg
│   │   ├── 000000345428.jpg
│   │   ├── 000000346092.jpg
│   │   ├── 000000346102.jpg
│   │   ├── 000000346107.jpg
│   │   ├── 000000346784.jpg
│   │   ├── 000000347467.jpg
│   │   ├── 000000347488.jpg
│   │   ├── 000000348174.jpg
│   │   ├── 000000348182.jpg
│   │   ├── 000000348212.jpg
│   │   ├── 000000348897.jpg
│   │   ├── 000000349618.jpg
│   │   ├── 000000350263.jpg
│   │   ├── 000000350293.jpg
│   │   ├── 000000352383.jpg
│   │   ├── 000000352385.jpg
│   │   ├── 000000352398.jpg
│   │   ├── 000000352410.jpg
│   │   ├── 000000353117.jpg
│   │   ├── 000000353820.jpg
│   │   ├── 000000355220.jpg
│   │   └── 000000355243.jpg
│   └── labels
│       ├── 000000337021.npy
│       ├── 000000337043.npy
│       ├── 000000337047.npy
│       ├── 000000337720.npy
│       ├── 000000337733.npy
│       ├── 000000337735.npy
│       ├── 000000338387.npy
│       ├── 000000339828.npy
│       ├── 000000339839.npy
│       ├── 000000339848.npy
│       ├── 000000339850.npy
│       ├── 000000340534.npy
│       ├── 000000340542.npy
│       ├── 000000341973.npy
│       ├── 000000342664.npy
│       ├── 000000343356.npy
│       ├── 000000344729.npy
│       ├── 000000344746.npy
│       ├── 000000345428.npy
│       ├── 000000346092.npy
│       ├── 000000346102.npy
│       ├── 000000346107.npy
│       ├── 000000346784.npy
│       ├── 000000347467.npy
│       ├── 000000347488.npy
│       ├── 000000348174.npy
│       ├── 000000348182.npy
│       ├── 000000348212.npy
│       ├── 000000348897.npy
│       ├── 000000349618.npy
│       ├── 000000350263.npy
│       ├── 000000350293.npy
│       ├── 000000352383.npy
│       ├── 000000352385.npy
│       ├── 000000352398.npy
│       ├── 000000352410.npy
│       ├── 000000353117.npy
│       ├── 000000353820.npy
│       ├── 000000355220.npy
│       └── 000000355243.npy
├── requirements.txt
├── models.py
├── opt.py
├── test.py
├── densenet.py
├── train.py
├── utils.py
├── dataset.py
├── environment.yml
└── README.md

/best_model.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/best_model.pkl
--------------------------------------------------------------------------------
/starburst_black.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/starburst_black.png
--------------------------------------------------------------------------------
/test/mask/000000337021.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000337021.jpg
--------------------------------------------------------------------------------
/test/mask/000000337043.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000337043.jpg
--------------------------------------------------------------------------------
/test/mask/000000337047.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000337047.jpg
--------------------------------------------------------------------------------
/test/mask/000000337720.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000337720.jpg -------------------------------------------------------------------------------- /test/mask/000000337733.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000337733.jpg -------------------------------------------------------------------------------- /test/mask/000000337735.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000337735.jpg -------------------------------------------------------------------------------- /test/mask/000000338387.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000338387.jpg -------------------------------------------------------------------------------- /test/mask/000000339828.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000339828.jpg -------------------------------------------------------------------------------- /test/mask/000000339839.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000339839.jpg -------------------------------------------------------------------------------- /test/mask/000000339848.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000339848.jpg -------------------------------------------------------------------------------- /test/mask/000000339850.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000339850.jpg -------------------------------------------------------------------------------- /test/mask/000000340534.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000340534.jpg -------------------------------------------------------------------------------- /test/mask/000000340542.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000340542.jpg -------------------------------------------------------------------------------- /test/mask/000000341973.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000341973.jpg -------------------------------------------------------------------------------- /test/mask/000000342664.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000342664.jpg -------------------------------------------------------------------------------- /test/mask/000000343356.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000343356.jpg 
-------------------------------------------------------------------------------- /test/mask/000000344729.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000344729.jpg -------------------------------------------------------------------------------- /test/mask/000000344746.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000344746.jpg -------------------------------------------------------------------------------- /test/mask/000000345428.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000345428.jpg -------------------------------------------------------------------------------- /test/mask/000000346092.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000346092.jpg -------------------------------------------------------------------------------- /test/mask/000000346102.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000346102.jpg -------------------------------------------------------------------------------- /test/mask/000000346107.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000346107.jpg -------------------------------------------------------------------------------- /test/mask/000000346784.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000346784.jpg -------------------------------------------------------------------------------- /test/mask/000000347467.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000347467.jpg -------------------------------------------------------------------------------- /test/mask/000000347488.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000347488.jpg -------------------------------------------------------------------------------- /test/mask/000000348174.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000348174.jpg -------------------------------------------------------------------------------- /test/mask/000000348182.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000348182.jpg -------------------------------------------------------------------------------- /test/mask/000000348212.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000348212.jpg -------------------------------------------------------------------------------- /test/mask/000000348897.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000348897.jpg -------------------------------------------------------------------------------- /test/mask/000000349618.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000349618.jpg -------------------------------------------------------------------------------- /test/mask/000000350263.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000350263.jpg -------------------------------------------------------------------------------- /test/mask/000000350293.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000350293.jpg -------------------------------------------------------------------------------- /test/mask/000000352383.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000352383.jpg -------------------------------------------------------------------------------- /test/mask/000000352385.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000352385.jpg -------------------------------------------------------------------------------- /test/mask/000000352398.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000352398.jpg -------------------------------------------------------------------------------- /test/mask/000000352410.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000352410.jpg -------------------------------------------------------------------------------- /test/mask/000000353117.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000353117.jpg -------------------------------------------------------------------------------- /test/mask/000000353820.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000353820.jpg -------------------------------------------------------------------------------- /test/mask/000000355220.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000355220.jpg -------------------------------------------------------------------------------- /test/mask/000000355243.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/mask/000000355243.jpg -------------------------------------------------------------------------------- /test/labels/000000337021.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000337021.npy -------------------------------------------------------------------------------- /test/labels/000000337043.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000337043.npy -------------------------------------------------------------------------------- /test/labels/000000337047.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000337047.npy -------------------------------------------------------------------------------- /test/labels/000000337720.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000337720.npy -------------------------------------------------------------------------------- /test/labels/000000337733.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000337733.npy -------------------------------------------------------------------------------- /test/labels/000000337735.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000337735.npy -------------------------------------------------------------------------------- /test/labels/000000338387.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000338387.npy -------------------------------------------------------------------------------- /test/labels/000000339828.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000339828.npy -------------------------------------------------------------------------------- /test/labels/000000339839.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000339839.npy -------------------------------------------------------------------------------- /test/labels/000000339848.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000339848.npy -------------------------------------------------------------------------------- /test/labels/000000339850.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000339850.npy -------------------------------------------------------------------------------- /test/labels/000000340534.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000340534.npy -------------------------------------------------------------------------------- /test/labels/000000340542.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000340542.npy 
-------------------------------------------------------------------------------- /test/labels/000000341973.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000341973.npy -------------------------------------------------------------------------------- /test/labels/000000342664.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000342664.npy -------------------------------------------------------------------------------- /test/labels/000000343356.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000343356.npy -------------------------------------------------------------------------------- /test/labels/000000344729.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000344729.npy -------------------------------------------------------------------------------- /test/labels/000000344746.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000344746.npy -------------------------------------------------------------------------------- /test/labels/000000345428.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000345428.npy -------------------------------------------------------------------------------- /test/labels/000000346092.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000346092.npy -------------------------------------------------------------------------------- /test/labels/000000346102.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000346102.npy -------------------------------------------------------------------------------- /test/labels/000000346107.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000346107.npy -------------------------------------------------------------------------------- /test/labels/000000346784.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000346784.npy -------------------------------------------------------------------------------- /test/labels/000000347467.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000347467.npy -------------------------------------------------------------------------------- /test/labels/000000347488.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000347488.npy -------------------------------------------------------------------------------- 
/test/labels/000000348174.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000348174.npy -------------------------------------------------------------------------------- /test/labels/000000348182.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000348182.npy -------------------------------------------------------------------------------- /test/labels/000000348212.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000348212.npy -------------------------------------------------------------------------------- /test/labels/000000348897.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000348897.npy -------------------------------------------------------------------------------- /test/labels/000000349618.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000349618.npy -------------------------------------------------------------------------------- /test/labels/000000350263.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000350263.npy -------------------------------------------------------------------------------- /test/labels/000000350293.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000350293.npy -------------------------------------------------------------------------------- /test/labels/000000352383.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000352383.npy -------------------------------------------------------------------------------- /test/labels/000000352385.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000352385.npy -------------------------------------------------------------------------------- /test/labels/000000352398.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000352398.npy -------------------------------------------------------------------------------- /test/labels/000000352410.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000352410.npy -------------------------------------------------------------------------------- /test/labels/000000353117.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000353117.npy -------------------------------------------------------------------------------- /test/labels/000000353820.npy: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000353820.npy
--------------------------------------------------------------------------------
/test/labels/000000355220.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000355220.npy
--------------------------------------------------------------------------------
/test/labels/000000355243.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ofirkris/Eye-Segmentation/HEAD/test/labels/000000355243.npy
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torchsummary
2 | tqdm
3 | matplotlib
4 | numpy
5 | torch
6 | Pillow
7 | opencv-python
8 | torchvision
9 | scipy
10 | scikit-learn
11 | 
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Sun Sep 8 18:50:11 2019
5 | 
6 | @author: manoj
7 | """
8 | 
9 | 
10 | from densenet import DenseNet2D
11 | model_dict = {}
12 | 
13 | model_dict['densenet'] = DenseNet2D(dropout=True,prob=0.2)
14 | 
--------------------------------------------------------------------------------
/opt.py:
--------------------------------------------------------------------------------
1 | from pprint import pprint
2 | import argparse
3 | 
4 | def parse_args():
5 | 
6 |     parser = argparse.ArgumentParser()
7 |     # Data input settings
8 |     parser.add_argument('--dataset', type=str, default='Semantic_Segmentation_Dataset/', help='name of dataset')
9 |     # Optimization: General
10 |     parser.add_argument('--bs', type=int, default=8)
11 |     parser.add_argument('--epochs', type=int, help='Number of epochs', default=250)
12 |     parser.add_argument('--workers', type=int, help='Number of workers', default=4)
13 |     parser.add_argument('--model', help='model name', default='densenet')
14 |     parser.add_argument('--evalsplit', help='eval split', default='val')
15 |     parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
16 |     parser.add_argument('--save', help='save folder name', default='0try')
17 |     parser.add_argument('--seed', type=int, default=1111, help='random seed')
18 |     parser.add_argument('--load', type=str, default=None, help='load checkpoint file name')
19 |     parser.add_argument('--resume', action='store_true', help='resume train from load chkpoint')
20 |     parser.add_argument('--test', action='store_true', help='test only')
21 |     parser.add_argument('--savemodel', action='store_true', help='checkpoint save the model')
22 |     parser.add_argument('--testrun', action='store_true', help='test run with few dataset')
23 |     parser.add_argument('--expname', type=str, default='info', help='extra explanation of the method')
24 |     parser.add_argument('--useGPU', type=lambda s: str(s).lower() in ('true', '1', 'yes'), default=True, help='Set it as False if GPU is unavailable')  # converts 'False' to a real boolean; with type=str any non-empty string was truthy
25 | 
26 |     # parse
27 |     args = parser.parse_args()
28 |     opt = vars(args)
29 |     pprint('parsed input parameters:')
30 |     pprint(opt)
31 |     return args
32 | 
33 | if __name__ == '__main__':
34 | 
35 |     opt = parse_args()
36 |     print('opt.dataset is ', opt.dataset)
37 | 
38 | 
39 | 
40 | 
41 | 
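Note: a minimal sketch (not part of the repo) of how the lambda-based --useGPU converter above behaves; with the original type=str, the literal string 'False' would have stayed truthy in train.py and test.py.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--useGPU', type=lambda s: str(s).lower() in ('true', '1', 'yes'), default=True)
print(parser.parse_args(['--useGPU', 'False']).useGPU)  # False (a real boolean)
print(parser.parse_args([]).useGPU)                     # True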
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Mon Sep 2 11:37:59 2019
5 | 
6 | @author: aaa
7 | """
8 | import torch
9 | from dataset import IrisDataset
10 | from torch.utils.data import DataLoader
11 | import numpy as np
12 | import matplotlib.pyplot as plt
13 | from dataset import transform
14 | import os
15 | from opt import parse_args
16 | from models import model_dict
17 | from tqdm import tqdm
18 | from utils import get_predictions
19 | #%%
20 | 
21 | if __name__ == '__main__':
22 | 
23 |     args = parse_args()
24 | 
25 |     if args.model not in model_dict:
26 |         print("Model not found !!!")
27 |         print("valid models are:", list(model_dict.keys()))
28 |         exit(1)
29 | 
30 |     if args.useGPU:
31 |         device = torch.device("cuda")
32 |     else:
33 |         device = torch.device("cpu")
34 | 
35 |     model = model_dict[args.model]
36 |     model = model.to(device)
37 |     filename = args.load
38 |     if filename is None or not os.path.exists(filename):
39 |         print("model path not found !!!")
40 |         exit(1)
41 | 
42 |     model.load_state_dict(torch.load(filename))
43 |     model = model.to(device)
44 |     model.eval()
45 | 
46 |     test_set = IrisDataset(filepath='Semantic_Segmentation_Dataset/',
47 |                            split='test', transform=transform)
48 | 
49 |     testloader = DataLoader(test_set, batch_size=args.bs,
50 |                             shuffle=False, num_workers=2)
51 |     counter = 0
52 | 
53 |     os.makedirs('test/labels/', exist_ok=True)
54 |     os.makedirs('test/output/', exist_ok=True)
55 |     os.makedirs('test/mask/', exist_ok=True)
56 | 
57 |     with torch.no_grad():
58 |         for i, batchdata in tqdm(enumerate(testloader), total=len(testloader)):
59 |             img, labels, index, x, y = batchdata
60 |             data = img.to(device)
61 |             output = model(data)
62 |             predict = get_predictions(output)
63 |             for j in range(len(index)):
64 |                 np.save('test/labels/{}.npy'.format(index[j]), predict[j].cpu().numpy())
65 |                 try:
66 |                     plt.imsave('test/output/{}.jpg'.format(index[j]), 255*labels[j].cpu().numpy())
67 |                 except Exception:
68 |                     pass
69 | 
70 |                 pred_img = predict[j].cpu().numpy()/3.0
71 |                 inp = img[j].squeeze() * 0.5 + 0.5
72 |                 img_orig = np.clip(inp, 0, 1)
73 |                 img_orig = np.array(img_orig)
74 |                 combine = np.hstack([img_orig, pred_img])
75 |                 plt.imsave('test/mask/{}.jpg'.format(index[j]), combine)
76 | 
77 |     os.rename('test', args.save)
78 | 
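Note: a minimal sketch (not part of the repo) of reloading one of the per-image predictions that test.py writes; the file name is one of the test ids above, and the class-id meanings follow the comments in utils.py.

import numpy as np

pred = np.load('test/labels/000000337021.npy')  # per-pixel class ids from get_predictions
print(pred.shape)       # (H, W) at the network's output resolution
print(np.unique(pred))  # subset of {0, 1, 2, 3}: background, sclera, iris, pupil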
--------------------------------------------------------------------------------
/densenet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Mon Sep 2 11:20:33 2019
5 | 
6 | @author: Shusil Dangi
7 | 
8 | References:
9 |     https://github.com/ShusilDangi/DenseUNet-K
10 | It is a simplified version of DenseNet with U-NET architecture.
11 | 2D implementation
12 | """
13 | import torch
14 | import math
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | 
18 | class DenseNet2D_down_block(nn.Module):
19 |     def __init__(self,input_channels,output_channels,down_size,dropout=False,prob=0):
20 |         super(DenseNet2D_down_block, self).__init__()
21 |         self.conv1 = nn.Conv2d(input_channels,output_channels,kernel_size=(3,3),padding=(1,1))
22 |         self.conv21 = nn.Conv2d(input_channels+output_channels,output_channels,kernel_size=(1,1),padding=(0,0))
23 |         self.conv22 = nn.Conv2d(output_channels,output_channels,kernel_size=(3,3),padding=(1,1))
24 |         self.conv31 = nn.Conv2d(input_channels+2*output_channels,output_channels,kernel_size=(1,1),padding=(0,0))
25 |         self.conv32 = nn.Conv2d(output_channels,output_channels,kernel_size=(3,3),padding=(1,1))
26 |         self.pool = nn.AvgPool2d(kernel_size=down_size)  # average pooling between down blocks
27 | 
28 |         self.relu = nn.LeakyReLU()
29 |         self.down_size = down_size
30 |         self.dropout = dropout
31 |         self.dropout1 = nn.Dropout(p=prob)
32 |         self.dropout2 = nn.Dropout(p=prob)
33 |         self.dropout3 = nn.Dropout(p=prob)
34 |         self.bn = torch.nn.BatchNorm2d(num_features=output_channels)
35 | 
36 |     def forward(self, x):
37 |         if self.down_size is not None:
38 |             x = self.pool(x)
39 | 
40 |         if self.dropout:
41 |             x1 = self.relu(self.dropout1(self.conv1(x)))
42 |             x21 = torch.cat((x,x1),dim=1)
43 |             x22 = self.relu(self.dropout2(self.conv22(self.conv21(x21))))
44 |             x31 = torch.cat((x21,x22),dim=1)
45 |             out = self.relu(self.dropout3(self.conv32(self.conv31(x31))))
46 |         else:
47 |             x1 = self.relu(self.conv1(x))
48 |             x21 = torch.cat((x,x1),dim=1)
49 |             x22 = self.relu(self.conv22(self.conv21(x21)))
50 |             x31 = torch.cat((x21,x22),dim=1)
51 |             out = self.relu(self.conv32(self.conv31(x31)))
52 |         return self.bn(out)
53 | 
54 | 
55 | class DenseNet2D_up_block_concat(nn.Module):
56 |     def __init__(self,skip_channels,input_channels,output_channels,up_stride,dropout=False,prob=0):
57 |         super(DenseNet2D_up_block_concat, self).__init__()
58 |         self.conv11 = nn.Conv2d(skip_channels+input_channels,output_channels,kernel_size=(1,1),padding=(0,0))
59 |         self.conv12 = nn.Conv2d(output_channels,output_channels,kernel_size=(3,3),padding=(1,1))
60 |         self.conv21 = nn.Conv2d(skip_channels+input_channels+output_channels,output_channels,
61 |                                 kernel_size=(1,1),padding=(0,0))
62 |         self.conv22 = nn.Conv2d(output_channels,output_channels,kernel_size=(3,3),padding=(1,1))
63 |         self.relu = nn.LeakyReLU()
64 |         self.up_stride = up_stride
65 |         self.dropout = dropout
66 |         self.dropout1 = nn.Dropout(p=prob)
67 |         self.dropout2 = nn.Dropout(p=prob)
68 | 
69 |     def forward(self,prev_feature_map,x):
70 |         x = nn.functional.interpolate(x,scale_factor=self.up_stride,mode='nearest')
71 |         x = torch.cat((x,prev_feature_map),dim=1)
72 |         if self.dropout:
73 |             x1 = self.relu(self.dropout1(self.conv12(self.conv11(x))))
74 |             x21 = torch.cat((x,x1),dim=1)
75 |             out = self.relu(self.dropout2(self.conv22(self.conv21(x21))))
76 |         else:
77 |             x1 = self.relu(self.conv12(self.conv11(x)))
78 |             x21 = torch.cat((x,x1),dim=1)
79 |             out = self.relu(self.conv22(self.conv21(x21)))
80 |         return out
81 | 
82 | class DenseNet2D(nn.Module):
83 |     def __init__(self,in_channels=1,out_channels=4,channel_size=32,concat=True,dropout=False,prob=0):
84 |         super(DenseNet2D, self).__init__()
85 | 
86 |         self.down_block1 = DenseNet2D_down_block(input_channels=in_channels,output_channels=channel_size,
87 |                                                  down_size=None,dropout=dropout,prob=prob)
88 |         self.down_block2 =
DenseNet2D_down_block(input_channels=channel_size,output_channels=channel_size, 89 | down_size=(2,2),dropout=dropout,prob=prob) 90 | self.down_block3 = DenseNet2D_down_block(input_channels=channel_size,output_channels=channel_size, 91 | down_size=(2,2),dropout=dropout,prob=prob) 92 | self.down_block4 = DenseNet2D_down_block(input_channels=channel_size,output_channels=channel_size, 93 | down_size=(2,2),dropout=dropout,prob=prob) 94 | self.down_block5 = DenseNet2D_down_block(input_channels=channel_size,output_channels=channel_size, 95 | down_size=(2,2),dropout=dropout,prob=prob) 96 | 97 | self.up_block1 = DenseNet2D_up_block_concat(skip_channels=channel_size,input_channels=channel_size, 98 | output_channels=channel_size,up_stride=(2,2),dropout=dropout,prob=prob) 99 | self.up_block2 = DenseNet2D_up_block_concat(skip_channels=channel_size,input_channels=channel_size, 100 | output_channels=channel_size,up_stride=(2,2),dropout=dropout,prob=prob) 101 | self.up_block3 = DenseNet2D_up_block_concat(skip_channels=channel_size,input_channels=channel_size, 102 | output_channels=channel_size,up_stride=(2,2),dropout=dropout,prob=prob) 103 | self.up_block4 = DenseNet2D_up_block_concat(skip_channels=channel_size,input_channels=channel_size, 104 | output_channels=channel_size,up_stride=(2,2),dropout=dropout,prob=prob) 105 | 106 | self.out_conv1 = nn.Conv2d(in_channels=channel_size,out_channels=out_channels,kernel_size=1,padding=0) 107 | self.concat = concat 108 | self.dropout = dropout 109 | self.dropout1 = nn.Dropout(p=prob) 110 | 111 | self._initialize_weights() 112 | 113 | def _initialize_weights(self): 114 | for m in self.modules(): 115 | if isinstance(m, nn.Conv2d): 116 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 117 | m.weight.data.normal_(0, math.sqrt(2. 
/ n))
118 |                 if m.bias is not None:
119 |                     m.bias.data.zero_()
120 |             elif isinstance(m, nn.BatchNorm2d):
121 |                 m.weight.data.fill_(1)
122 |                 m.bias.data.zero_()
123 |             elif isinstance(m, nn.Linear):
124 |                 n = m.weight.size(1)
125 |                 m.weight.data.normal_(0, 0.01)
126 |                 m.bias.data.zero_()
127 | 
128 |     def forward(self,x):
129 |         self.x1 = self.down_block1(x)
130 |         self.x2 = self.down_block2(self.x1)
131 |         self.x3 = self.down_block3(self.x2)
132 |         self.x4 = self.down_block4(self.x3)
133 |         self.x5 = self.down_block5(self.x4)
134 |         self.x6 = self.up_block1(self.x4,self.x5)
135 |         self.x7 = self.up_block2(self.x3,self.x6)
136 |         self.x8 = self.up_block3(self.x2,self.x7)
137 |         self.x9 = self.up_block4(self.x1,self.x8)
138 |         if self.dropout:
139 |             out = self.out_conv1(self.dropout1(self.x9))
140 |         else:
141 |             out = self.out_conv1(self.x9)
142 | 
143 |         return out
144 | 
145 | 
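Note: a quick shape check (a sketch, not part of the repo). With the 640x400 grayscale inputs that train.py passes to torchsummary, the four 2x down-/up-sampling stages bring the feature map back to the input resolution, with one output channel per class.

import torch
from densenet import DenseNet2D

model = DenseNet2D(dropout=True, prob=0.2)  # same configuration as models.py
model.eval()                                # disable dropout for a deterministic pass
x = torch.randn(1, 1, 640, 400)             # one grayscale eye image
with torch.no_grad():
    out = model(x)
print(out.shape)  # torch.Size([1, 4, 640, 400]) -- per-pixel logits for 4 classes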
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Mon Sep 2 11:22:32 2019
5 | 
6 | @author: aayush
7 | """
8 | 
9 | from models import model_dict
10 | from torch.utils.data import DataLoader
11 | from dataset import IrisDataset
12 | import torch
13 | from utils import mIoU, CrossEntropyLoss2d,total_metric,get_nparams,Logger,GeneralizedDiceLoss,SurfaceLoss
14 | import numpy as np
15 | from dataset import transform
16 | from opt import parse_args
17 | import os
18 | from utils import get_predictions
19 | from tqdm import tqdm
20 | import matplotlib.pyplot as plt
21 | #%%
22 | 
23 | def lossandaccuracy(loader,model,factor):
24 |     epoch_loss = []
25 |     ious = []
26 |     model.eval()
27 |     with torch.no_grad():
28 |         for i, batchdata in enumerate(loader):
29 |             # print (len(batchdata))
30 |             img,labels,index,spatialWeights,maxDist=batchdata
31 |             data = img.to(device)
32 | 
33 |             target = labels.to(device).long()
34 |             output = model(data)
35 | 
36 |             ## cross entropy loss, weighted per pixel by (1 + 20*Canny edge map) for boundary awareness
37 |             CE_loss = criterion(output,target)
38 |             loss = CE_loss*(torch.from_numpy(np.ones(spatialWeights.shape)).to(torch.float32).to(device)+(spatialWeights).to(torch.float32).to(device))
39 | 
40 |             loss=torch.mean(loss).to(torch.float32).to(device)
41 |             loss_dice = criterion_DICE(output,target)
42 |             loss_sl = torch.mean(criterion_SL(output.to(device),(maxDist).to(device)))
43 | 
44 |             ##total loss is the weighted sum of surface loss and dice loss plus the boundary weighted cross entropy loss
45 |             loss = (1-factor)*loss_sl+factor*(loss_dice)+loss
46 | 
47 |             epoch_loss.append(loss.item())
48 |             predict = get_predictions(output)
49 |             iou = mIoU(predict,labels)
50 |             ious.append(iou)
51 |     return np.average(epoch_loss),np.average(ious)
52 | 
53 | #%%
54 | if __name__ == '__main__':
55 | 
56 |     args = parse_args()
57 |     kwargs = vars(args)
58 | 
59 |     # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
60 | 
61 |     if args.useGPU:
62 |         device=torch.device("cuda")
63 |         torch.cuda.manual_seed(12)
64 |     else:
65 |         device=torch.device("cpu")
66 |         torch.manual_seed(12)
67 | 
68 |     torch.backends.cudnn.deterministic=False
69 | 
70 |     if args.model not in model_dict:
71 |         print ("Model not found !!!")
72 |         print ("valid models are:",list(model_dict.keys()))
73 |         exit(1)
74 | 
75 |     LOGDIR = 'logs/{}'.format(args.expname)
76 |     os.makedirs(LOGDIR,exist_ok=True)
77 |     os.makedirs(LOGDIR+'/models',exist_ok=True)
78 |     logger = Logger(os.path.join(LOGDIR,'logs.log'))
79 | 
80 |     model = model_dict[args.model]
81 |     model = model.to(device)
82 |     torch.save(model.state_dict(), '{}/models/dense_net{}.pkl'.format(LOGDIR,'_0'))
83 |     model.train()
84 |     nparams = get_nparams(model)
85 | 
86 |     try:
87 |         from torchsummary import summary
88 |         summary(model,input_size=(1,640,400))
89 |         print("Max params:", 1024*1024/4.0)  # parameter budget: 1 MiB of float32 weights
90 |         logger.write_summary(str(model.parameters))
91 |     except Exception:
92 |         print ("Torch summary not found !!!")
93 | 
94 |     optimizer = torch.optim.Adam(model.parameters(), lr = args.lr)
95 |     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min',patience=5)
96 | 
97 |     criterion = CrossEntropyLoss2d()
98 |     criterion_DICE = GeneralizedDiceLoss(softmax=True, reduction=True)
99 |     criterion_SL = SurfaceLoss()
100 | 
101 |     Path2file = args.dataset
102 |     train = IrisDataset(filepath = Path2file,split='train',
103 |                         transform = transform, **kwargs)
104 | 
105 |     valid = IrisDataset(filepath = Path2file , split='validation',
106 |                         transform = transform, **kwargs)
107 | 
108 |     trainloader = DataLoader(train, batch_size = args.bs,
109 |                              shuffle=True, num_workers = args.workers)
110 | 
111 |     validloader = DataLoader(valid, batch_size = args.bs,
112 |                              shuffle= False, num_workers = args.workers)
113 | 
114 |     test = IrisDataset(filepath = Path2file , split='test',
115 |                        transform = transform, **kwargs)
116 | 
117 |     testloader = DataLoader(test, batch_size = args.bs,
118 |                             shuffle=False, num_workers = args.workers)
119 | 
120 | 
121 |     # alpha = 1 - np.arange(1,args.epochs)/args.epochs
122 |     ##The weighing function for the dice loss and surface loss
123 |     alpha = np.zeros(args.epochs)
124 |     alpha[0:np.min([125,args.epochs])]=1 - np.arange(1,np.min([125,args.epochs])+1)/np.min([125,args.epochs])
125 |     if args.epochs>125:
126 |         alpha[125:]=1
127 |     ious = []
128 |     for epoch in range(args.epochs):
129 |         for i, batchdata in enumerate(trainloader):
130 |             # print (len(batchdata))
131 |             img,labels,index,spatialWeights,maxDist= batchdata
132 |             data = img.to(device)
133 |             target = labels.to(device).long()
134 |             optimizer.zero_grad()
135 |             output = model(data)
136 |             ## cross entropy loss, weighted per pixel by (1 + 20*Canny edge map) for boundary awareness
137 |             CE_loss = criterion(output,target)
138 |             loss = CE_loss*(torch.from_numpy(np.ones(spatialWeights.shape)).to(torch.float32).to(device)+(spatialWeights).to(torch.float32).to(device))
139 | 
140 |             loss=torch.mean(loss).to(torch.float32).to(device)
141 |             loss_dice = criterion_DICE(output,target)
142 |             loss_sl = torch.mean(criterion_SL(output.to(device),(maxDist).to(device)))
143 | 
144 |             ##total loss is the weighted sum of surface loss and dice loss plus the boundary weighted cross entropy loss
145 |             loss = (1-alpha[epoch])*loss_sl+alpha[epoch]*(loss_dice)+loss
146 | 
147 |             predict = get_predictions(output)
148 |             iou = mIoU(predict,labels)
149 |             ious.append(iou)
150 | 
151 |             if i%10 == 0:
152 |                 logger.write('Epoch:{} [{}/{}], Loss: {:.3f}'.format(epoch,i,len(trainloader),loss.item()))
153 | 
154 |             loss.backward()
155 |             optimizer.step()
156 | 
157 |         logger.write('Epoch:{}, Train mIoU: {}'.format(epoch,np.average(ious)))
158 |         lossvalid , miou = lossandaccuracy(validloader,model,alpha[epoch])
159 |         totalperf = total_metric(nparams,miou)
160 |         f = 'Epoch:{}, Valid Loss: {:.3f} mIoU: {} Complexity: {} total: {}'
161 |         logger.write(f.format(epoch,lossvalid, miou,nparams,totalperf))
162 | 
163 |         scheduler.step(lossvalid)
164 | 
165 |         ##save the model every epoch
166 |         if epoch %1 == 0:
167 |             torch.save(model.state_dict(), '{}/models/dense_net{}.pkl'.format(LOGDIR,epoch))
168 | 
169 |         ##visualize the output every 5 epochs
170 |         if epoch %5 ==0:
171 |             os.makedirs('test/epoch/labels/',exist_ok=True)
172 |             os.makedirs('test/epoch/output/',exist_ok=True)
173 |             os.makedirs('test/epoch/mask/',exist_ok=True)
174 | 
175 |             with torch.no_grad():
176 |                 for i, batchdata in tqdm(enumerate(testloader),total=len(testloader)):
177 |                     img,labels,index,x,maxDist= batchdata
178 |                     data = img.to(device)
179 |                     output = model(data)
180 |                     predict = get_predictions(output)
181 |                     for j in range (len(index)):
182 |                         np.save('test/epoch/labels/{}.npy'.format(index[j]),predict[j].cpu().numpy())
183 |                         try:
184 |                             plt.imsave('test/epoch/output/{}.jpg'.format(index[j]),255*labels[j].cpu().numpy())
185 |                         except Exception:
186 |                             pass
187 |                         pred_img = predict[j].cpu().numpy()/3.0
188 |                         inp = img[j].squeeze() * 0.5 + 0.5
189 |                         img_orig = np.clip(inp,0,1)
190 |                         img_orig = np.array(img_orig)
191 |                         combine = np.hstack([img_orig,pred_img])
192 |                         plt.imsave('test/epoch/mask/{}.jpg'.format(index[j]),combine)
193 | 
194 | 
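Note: a worked example (a sketch, not part of the repo) of the size-aware score computed by utils.total_metric below; the parameter count is hypothetical.

from utils import total_metric

# S = nparams * 4 bytes / 1 MiB; a model at or under 1 MiB gets full size credit,
# so the score reduces to (1 + mIoU) / 2.
nparams = 65536                     # hypothetical model: 0.25 MiB of float32 weights
print(total_metric(nparams, 0.95))  # 0.5 * (min(1, 1/0.25) + 0.95) = 0.975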
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Tue Aug 27 16:04:18 2019
5 | 
6 | @author: Aayush Chaudhary
7 | 
8 | References:
9 |     https://evalai-forum.cloudcv.org/t/fyi-on-semantic-segmentation/180
10 |     https://github.com/ycszen/pytorch-segmentation/blob/master/loss.py
11 |     https://discuss.pytorch.org/t/using-cross-entropy-loss-with-semantic-segmentation-model/31988
12 |     https://github.com/LIVIAETS/surface-loss
13 | """
14 | import numpy as np
15 | import torch
16 | import torch.nn as nn
17 | import torch.nn.functional as F
18 | import torchvision
19 | import cv2
20 | import os
21 | 
22 | from sklearn.metrics import precision_score, recall_score, f1_score
23 | from scipy.ndimage import distance_transform_edt as distance
24 | #%%
25 | class FocalLoss2d(nn.Module):
26 |     def __init__(self, weight=None,gamma=2):
27 |         super(FocalLoss2d,self).__init__()
28 |         self.gamma = gamma
29 |         self.loss = nn.NLLLoss(weight)
30 |     def forward(self, outputs, targets):
31 |         return self.loss((1 - F.softmax(outputs,dim=1)).pow(self.gamma) * F.log_softmax(outputs,dim=1), targets)  # log_softmax for numerical stability
32 | 
33 | ###https://github.com/ycszen/pytorch-segmentation/blob/master/loss.py
34 | # https://discuss.pytorch.org/t/using-cross-entropy-loss-with-semantic-segmentation-model/31988
35 | class CrossEntropyLoss2d(nn.Module):
36 | 
37 |     def __init__(self, weight=None):
38 |         super(CrossEntropyLoss2d,self).__init__()
39 |         self.loss = nn.NLLLoss(weight)
40 | 
41 |     def forward(self, outputs, targets):
42 |         return self.loss(F.log_softmax(outputs,dim=1), targets)
43 | 
44 | class SurfaceLoss(nn.Module):
45 |     # Author: Rakshit Kothari
46 |     def __init__(self, epsilon=1e-5, softmax=True):
47 |         super(SurfaceLoss, self).__init__()
48 |         self.weight_map = []
49 |     def forward(self, x, distmap):
50 |         x = torch.softmax(x, dim=1)
51 |         self.weight_map = distmap
52 |         score = x.flatten(start_dim=2)*distmap.flatten(start_dim=2)
53 |         score = torch.mean(score, dim=2) # Mean between pixels per channel
54 |         score = torch.mean(score, dim=1) # Mean between channels
55 |         return score
56 | 
57 | 
58 | class GeneralizedDiceLoss(nn.Module):
59 |     # Author: Rakshit Kothari
60 |     # Input: (B, C, ...)
61 |     # Target: (B, C, ...)
62 |     def __init__(self, epsilon=1e-5, weight=None, softmax=True, reduction=True):
63 |         super(GeneralizedDiceLoss, self).__init__()
64 |         self.epsilon = epsilon
65 |         self.weight = []
66 |         self.reduction = reduction
67 |         if softmax:
68 |             self.norm = nn.Softmax(dim=1)
69 |         else:
70 |             self.norm = nn.Sigmoid()
71 | 
72 |     def forward(self, ip, target):
73 | 
74 |         # Rapid way to convert to one-hot. For future version, use functional
75 |         Label = (np.arange(4) == target.cpu().numpy()[..., None]).astype(np.uint8)
76 |         target = torch.from_numpy(np.rollaxis(Label, 3,start=1)).to(ip.device)  # follow the input's device instead of hard-coding .cuda()
77 | 
78 |         assert ip.shape == target.shape
79 |         ip = self.norm(ip)
80 | 
81 |         # Flatten for multidimensional data
82 |         ip = torch.flatten(ip, start_dim=2, end_dim=-1).to(torch.float32)
83 |         target = torch.flatten(target, start_dim=2, end_dim=-1).to(torch.float32)
84 | 
85 |         numerator = ip*target
86 |         denominator = ip + target
87 | 
88 |         class_weights = 1./(torch.sum(target, dim=2)**2).clamp(min=self.epsilon)
89 | 
90 |         A = class_weights*torch.sum(numerator, dim=2)
91 |         B = class_weights*torch.sum(denominator, dim=2)
92 | 
93 |         dice_metric = 2.*torch.sum(A, dim=1)/torch.sum(B, dim=1)
94 |         if self.reduction:
95 |             return torch.mean(1. - dice_metric.clamp(min=self.epsilon))
96 |         else:
97 |             return 1. - dice_metric.clamp(min=self.epsilon)
98 | 
99 | #https://github.com/LIVIAETS/surface-loss
100 | def one_hot2dist(posmask):
101 |     # Input: Mask. Will be converted to Bool.
102 |     # Author: Rakshit Kothari
103 |     assert len(posmask.shape) == 2
104 |     h, w = posmask.shape
105 |     res = np.zeros_like(posmask)
106 |     posmask = posmask.astype(bool)
107 |     mxDist = np.sqrt((h-1)**2 + (w-1)**2)
108 |     if posmask.any():
109 |         negmask = ~posmask
110 |         res = distance(negmask) * negmask - (distance(posmask) - 1) * posmask
111 |     return res/mxDist
112 | 
113 | def mIoU(predictions, targets,info=False): ### mean per-class IoU
114 |     unique_labels = np.unique(targets)
115 |     num_unique_labels = len(unique_labels)
116 |     ious = []
117 |     for index in range(num_unique_labels):
118 |         pred_i = predictions == index
119 |         label_i = targets == index
120 |         intersection = np.logical_and(label_i, pred_i)
121 |         union = np.logical_or(label_i, pred_i)
122 |         iou_score = np.sum(np.asarray(intersection))/np.sum(np.asarray(union))  # works for tensors or arrays
123 |         ious.append(iou_score)
124 |     if info:
125 |         print ("per-class mIOU: ", ious)
126 |     return np.mean(ious)
127 | 
128 | #https://evalai-forum.cloudcv.org/t/fyi-on-semantic-segmentation/180
129 | #GA: Global Pixel Accuracy
130 | #CA: Mean Class Accuracy for different classes
131 | #
132 | #Back: Background (non-eye part of peri-ocular region)
133 | #Sclera: Sclera
134 | #Iris: Iris
135 | #Pupil: Pupil
136 | #Precision: Computed using sklearn.metrics.precision_score(gt, pred, average='weighted')
137 | #Recall: Computed using sklearn.metrics.recall_score(gt, pred, average='weighted')
138 | #F1: Computed using sklearn.metrics.f1_score(gt, pred, average='weighted')
139 | #IoU: Computed using the function below
140 | def compute_mean_iou(flat_pred, flat_label,info=False):
141 |     '''
142 |     compute mean intersection over union (IOU) over all classes
143 |     :param flat_pred: flattened prediction matrix
144 |     :param flat_label: flattened label matrix
145 |     :return: mean IOU
146 |     '''
147 |     unique_labels = np.unique(flat_label)
148 |     num_unique_labels = len(unique_labels)
149 | 
150 |     Intersect = np.zeros(num_unique_labels)
151 |     Union = np.zeros(num_unique_labels)
152 |     precision = np.zeros(num_unique_labels)
153 |     recall = np.zeros(num_unique_labels)
154 |     f1 = np.zeros(num_unique_labels)
155 | 
156 |     for index, val in enumerate(unique_labels):
157 |         pred_i = flat_pred == val
158 |         label_i = flat_label == val
159 | 
160 |         if info:
161 |             precision[index] = precision_score(label_i, pred_i, average='weighted')  # average must be a keyword; ground truth comes first
162 |             recall[index] = recall_score(label_i, pred_i, average='weighted')
163 |             f1[index] = f1_score(label_i, pred_i, average='weighted')
164 | 
165 |         Intersect[index] = float(np.sum(np.logical_and(label_i, pred_i)))
166 |         Union[index] = float(np.sum(np.logical_or(label_i, pred_i)))
167 | 
168 |     if info:
169 |         print ("per-class mIOU: ", Intersect / Union)
170 |         print ("per-class precision: ", precision)
171 |         print ("per-class recall: ", recall)
172 |         print ("per-class f1: ", f1)
173 |     mean_iou = np.mean(Intersect / Union)
174 |     return mean_iou
175 | 
176 | def total_metric(nparams,miou):
177 |     S = nparams * 4.0 / (1024 * 1024)
178 |     total = min(1,1.0/S) + miou
179 |     return total * 0.5
180 | 
181 | 
182 | def get_nparams(model):
183 |     return sum(p.numel() for p in model.parameters() if p.requires_grad)
184 | 
185 | 
186 | def get_predictions(output):
187 |     bs,c,h,w = output.size()
188 |     values, indices = output.cpu().max(1)
189 |     indices = indices.view(bs,h,w) # bs x h x w
190 |     return indices
191 | 
192 | 
193 | class Logger():
194 |     def __init__(self, output_name):
195 |         dirname = os.path.dirname(output_name)
196 |         if not os.path.exists(dirname):
197 |             os.mkdir(dirname)
198 |         self.dirname = dirname
199 |         self.log_file = open(output_name, 'a+')
200 |         self.infos = {}
201 | 
202 |     def append(self, key, val):
203 |         vals = self.infos.setdefault(key, [])
204 |         vals.append(val)
205 | 
206 |     def log(self, extra_msg=''):
207 |         msgs = [extra_msg]
208 |         for key, vals in self.infos.items():
209 |             msgs.append('%s %.6f' % (key, np.mean(vals)))
210 |         msg = '\n'.join(msgs)
211 |         self.log_file.write(msg + '\n')
212 |         self.log_file.flush()
213 |         self.infos = {}
214 |         return msg
215 | 
216 |     def write_silent(self, msg):
217 |         self.log_file.write(msg + '\n')
218 |         self.log_file.flush()
219 | 
220 |     def write(self, msg):
221 |         self.log_file.write(msg + '\n')
222 |         self.log_file.flush()
223 |         print (msg)
224 |     def write_summary(self,msg):
225 |         self.log_file.write(msg)
226 |         self.log_file.write('\n')
227 |         self.log_file.flush()
228 |         print (msg)
229 | 
230 | 
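Note: the fixed-gamma preprocessing in IrisDataset.__getitem__ below reduces to a 256-entry lookup table; a minimal sketch on a synthetic image (the random input is only for illustration).

import numpy as np
import cv2

table = 255.0 * (np.linspace(0, 1, 256) ** 0.8)          # out = 255 * (in/255) ** 0.8
img = (np.random.rand(480, 640) * 255).astype(np.uint8)  # stand-in for a grayscale eye image
corrected = cv2.LUT(img, table)
print(corrected.dtype, corrected.min(), corrected.max())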
21 | """ 22 | 23 | import numpy as np 24 | import torch 25 | from torch.utils.data import Dataset 26 | import os 27 | from PIL import Image 28 | from torchvision import transforms 29 | import cv2 30 | import random 31 | import os.path as osp 32 | from utils import one_hot2dist 33 | import copy 34 | 35 | transform = transforms.Compose( 36 | [transforms.ToTensor(), 37 | transforms.Normalize([0.5], [0.5])]) 38 | 39 | #%% 40 | class RandomHorizontalFlip(object): 41 | def __call__(self, img,label): 42 | if random.random() < 0.5: 43 | return img.transpose(Image.FLIP_LEFT_RIGHT),\ 44 | label.transpose(Image.FLIP_LEFT_RIGHT) 45 | return img,label 46 | 47 | class Starburst_augment(object): 48 | ## We have generated the starburst pattern from a train image 000000240768.png 49 | ## Please follow the file Starburst_generation_from_train_image_000000240768.pdf attached in the folder 50 | ## This procedure is used in order to handle people with multiple reflections for glasses 51 | ## a random translation of mask of starburst pattern 52 | def __call__(self, img): 53 | x=np.random.randint(1, 40) 54 | y=np.random.randint(1, 40) 55 | mode = np.random.randint(0, 2) 56 | starburst=Image.open('starburst_black.png').convert("L") 57 | if mode == 0: 58 | starburst = np.pad(starburst, pad_width=((0, 0), (x, 0)), mode='constant') 59 | starburst = starburst[:, :-x] 60 | if mode == 1: 61 | starburst = np.pad(starburst, pad_width=((0, 0), (0, x)), mode='constant') 62 | starburst = starburst[:, x:] 63 | 64 | img[92+y:549+y,0:400]=np.array(img)[92+y:549+y,0:400]*((255-np.array(starburst))/255)+np.array(starburst) 65 | return Image.fromarray(img) 66 | 67 | def getRandomLine(xc, yc, theta): 68 | x1 = xc - 50*np.random.rand(1)*(1 if np.random.rand(1) < 0.5 else -1) 69 | y1 = (x1 - xc)*np.tan(theta) + yc 70 | x2 = xc - (150*np.random.rand(1) + 50)*(1 if np.random.rand(1) < 0.5 else -1) 71 | y2 = (x2 - xc)*np.tan(theta) + yc 72 | return x1, y1, x2, y2 73 | 74 | class Gaussian_blur(object): 75 | def __call__(self, img): 76 | sigma_value=np.random.randint(2, 7) 77 | return Image.fromarray(cv2.GaussianBlur(img,(7,7),sigma_value)) 78 | 79 | class Translation(object): 80 | def __call__(self, base,mask): 81 | factor_h = 2*np.random.randint(1, 20) 82 | factor_v = 2*np.random.randint(1, 20) 83 | mode = np.random.randint(0, 4) 84 | # print (mode,factor_h,factor_v) 85 | if mode == 0: 86 | aug_base = np.pad(base, pad_width=((factor_v, 0), (0, 0)), mode='constant') 87 | aug_mask = np.pad(mask, pad_width=((factor_v, 0), (0, 0)), mode='constant') 88 | aug_base = aug_base[:-factor_v, :] 89 | aug_mask = aug_mask[:-factor_v, :] 90 | if mode == 1: 91 | aug_base = np.pad(base, pad_width=((0, factor_v), (0, 0)), mode='constant') 92 | aug_mask = np.pad(mask, pad_width=((0, factor_v), (0, 0)), mode='constant') 93 | aug_base = aug_base[factor_v:, :] 94 | aug_mask = aug_mask[factor_v:, :] 95 | if mode == 2: 96 | aug_base = np.pad(base, pad_width=((0, 0), (factor_h, 0)), mode='constant') 97 | aug_mask = np.pad(mask, pad_width=((0, 0), (factor_h, 0)), mode='constant') 98 | aug_base = aug_base[:, :-factor_h] 99 | aug_mask = aug_mask[:, :-factor_h] 100 | if mode == 3: 101 | aug_base = np.pad(base, pad_width=((0, 0), (0, factor_h)), mode='constant') 102 | aug_mask = np.pad(mask, pad_width=((0, 0), (0, factor_h)), mode='constant') 103 | aug_base = aug_base[:, factor_h:] 104 | aug_mask = aug_mask[:, factor_h:] 105 | return Image.fromarray(aug_base), Image.fromarray(aug_mask) 106 | 107 | class Line_augment(object): 108 | def __call__(self, base): 109 | 
class Line_augment(object):
    def __call__(self, base):
        yc, xc = (0.3 + 0.4*np.random.rand(1))*base.shape
        aug_base = copy.deepcopy(base)
        num_lines = np.random.randint(1, 10)
        for i in np.arange(0, num_lines):
            theta = np.pi*np.random.rand(1)
            x1, y1, x2, y2 = getRandomLine(xc, yc, theta)
            # cv2.line requires integer endpoints
            aug_base = cv2.line(aug_base, (int(x1), int(y1)), (int(x2), int(y2)), (255, 255, 255), 4)
        aug_base = aug_base.astype(np.uint8)
        return Image.fromarray(aug_base)

class MaskToTensor(object):
    def __call__(self, img):
        return torch.from_numpy(np.array(img, dtype=np.int32)).long()


class IrisDataset(Dataset):
    def __init__(self, filepath, split='train', transform=None, **args):
        self.transform = transform
        self.filepath = osp.join(filepath, split)
        self.split = split
        listall = []

        for file in os.listdir(osp.join(self.filepath, 'images')):
            if file.endswith(".png"):
                # splitext, not str.strip(): strip(".png") removes any of the
                # characters '.', 'p', 'n', 'g' from both ends of the name
                listall.append(osp.splitext(file)[0])
        self.list_files = listall

        self.testrun = args.get('testrun')

        # PREPROCESSING STEP FOR ALL TRAIN, VALIDATION AND TEST INPUTS
        # Local contrast-limited adaptive histogram equalization (CLAHE)
        self.clahe = cv2.createCLAHE(clipLimit=1.5, tileGridSize=(8, 8))

    def __len__(self):
        if self.testrun:
            return 10
        return len(self.list_files)

    def __getitem__(self, idx):
        imagepath = osp.join(self.filepath, 'images', self.list_files[idx]+'.png')
        pilimg = Image.open(imagepath).convert("L")
        W, H = pilimg.width, pilimg.height  # PIL reports (width, height)

        # PREPROCESSING STEP FOR ALL TRAIN, VALIDATION AND TEST INPUTS
        # Fixed gamma correction (gamma = 0.8) applied via a lookup table
        table = 255.0*(np.linspace(0, 1, 256)**0.8)
        pilimg = cv2.LUT(np.array(pilimg), table)

        if self.split != 'test':
            labelpath = osp.join(self.filepath, 'labels', self.list_files[idx]+'.npy')
            label = np.load(labelpath)
            label = np.resize(label, (H, W))  # numpy order: (rows, cols)
            label = Image.fromarray(label)

        if self.transform is not None:
            if self.split == 'train':
                if random.random() < 0.2:
                    pilimg = Starburst_augment()(np.array(pilimg))
                if random.random() < 0.2:
                    pilimg = Line_augment()(np.array(pilimg))
                if random.random() < 0.2:
                    pilimg = Gaussian_blur()(np.array(pilimg))
                if random.random() < 0.4:
                    pilimg, label = Translation()(np.array(pilimg), np.array(label))

        img = self.clahe.apply(np.array(np.uint8(pilimg)))
        img = Image.fromarray(img)

        if self.transform is not None:
            if self.split == 'train':
                img, label = RandomHorizontalFlip()(img, label)
            img = self.transform(img)

        if self.split != 'test':
            ## This is for the boundary-aware cross entropy calculation
            spatialWeights = cv2.Canny(np.array(label), 0, 3)/255
            # 3x3 structuring element (a bare tuple is not a valid dilate kernel)
            spatialWeights = cv2.dilate(spatialWeights, np.ones((3, 3), np.uint8), iterations=1)*20

            ## This is the implementation for the surface loss
            # Distance map for each class
            distMap = []
            for i in range(0, 4):
                distMap.append(one_hot2dist(np.array(label) == i))
            distMap = np.stack(distMap, 0)
            # spatialWeights = np.float32(distMap)

        if self.split == 'test':
            ## label, spatialWeights and distMap are not needed for test images
            return img, 0, self.list_files[idx], 0, 0

        label = MaskToTensor()(label)
        return img, label, self.list_files[idx], spatialWeights, np.float32(distMap)

if __name__ == "__main__":
    import matplotlib.pyplot as plt
    ds = 
IrisDataset('Semantic_Segmentation_Dataset',split='train',transform=transform) 208 | # for i in range(1000): 209 | img, label, idx,x,y= ds[0] 210 | plt.subplot(121) 211 | plt.imshow(np.array(label)) 212 | plt.subplot(122) 213 | plt.imshow(np.array(img)[0,:,:],cmap='gray') -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: base 2 | channels: 3 | - anaconda 4 | - menpo 5 | - pytorch 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - _ipyw_jlab_nb_ext_conf=0.1.0=py36he11e457_0 10 | - _libgcc_mutex=0.1=main 11 | - alabaster=0.7.10=py36h306e16b_0 12 | - anaconda-client=1.6.14=py36_0 13 | - anaconda-navigator=1.8.7=py36_0 14 | - anaconda-project=0.8.2=py36h44fb852_0 15 | - asn1crypto=0.24.0=py36_0 16 | - astroid=1.6.3=py36_0 17 | - astropy=3.0.2=py36h3010b51_1 18 | - attrs=18.1.0=py36_0 19 | - av=6.0.0=py36h7273d18_0 20 | - babel=2.5.3=py36_0 21 | - backcall=0.1.0=py36_0 22 | - backports=1.0=py36hfa02d7e_1 23 | - backports.shutil_get_terminal_size=1.0.0=py36hfea85ff_2 24 | - beautifulsoup4=4.6.0=py36h49b8c8c_1 25 | - bitarray=0.8.1=py36h14c3975_1 26 | - bkcharts=0.2=py36h735825a_0 27 | - blas=1.0=mkl 28 | - blaze=0.11.3=py36h4e06776_0 29 | - bleach=2.1.3=py36_0 30 | - blosc=1.14.3=hdbcaa40_0 31 | - bokeh=0.12.16=py36_0 32 | - boto=2.48.0=py36h6e4cd66_1 33 | - bottleneck=1.2.1=py36haac1ea0_0 34 | - bzip2=1.0.6=h14c3975_5 35 | - ca-certificates=2019.5.15=0 36 | - cairo=1.14.12=h7636065_2 37 | - certifi=2019.6.16=py36_1 38 | - cffi=1.12.3=py36h2e261b9_0 39 | - chardet=3.0.4=py36h0f667ec_1 40 | - click=6.7=py36h5253387_0 41 | - cloudpickle=0.5.3=py36_0 42 | - clyent=1.2.2=py36h7e57e65_1 43 | - colorama=0.3.9=py36h489cec4_0 44 | - conda=4.6.11=py36_0 45 | - conda-build=3.10.9=py36_0 46 | - conda-env=2.6.0=h36134e3_1 47 | - conda-verify=2.0.0=py36h98955d8_0 48 | - contextlib2=0.5.5=py36h6c84a62_0 49 | - cryptography=2.2.2=py36h14c3975_0 50 | - cuda75=1.0=hf2493ae_0 51 | - cudatoolkit=8.0=3 52 | - curl=7.60.0=h84994c4_0 53 | - cycler=0.10.0=py36h93f1223_0 54 | - cython=0.28.2=py36h14c3975_0 55 | - cytoolz=0.9.0.1=py36h14c3975_0 56 | - dask=0.17.5=py36_0 57 | - dask-core=0.17.5=py36_0 58 | - datashape=0.5.4=py36h3ad6b5c_0 59 | - dbus=1.13.2=h714fa37_1 60 | - decorator=4.3.0=py36_0 61 | - distributed=1.21.8=py36_0 62 | - docutils=0.14=py36hb0f60f5_0 63 | - entrypoints=0.2.3=py36h1aec115_2 64 | - et_xmlfile=1.0.1=py36hd6bccc3_0 65 | - expat=2.2.5=he0dffb1_0 66 | - fastcache=1.0.2=py36h14c3975_2 67 | - ffmpeg=4.0.2=ha6a6e2b_0 68 | - filelock=3.0.4=py36_0 69 | - flask=1.0.2=py36_1 70 | - flask-cors=3.0.4=py36_0 71 | - fontconfig=2.12.6=h49f89f6_0 72 | - freeglut=2.8.1=0 73 | - freetype=2.8.1=hfa320df_1 74 | - get_terminal_size=1.0.0=haa9412d_0 75 | - gevent=1.3.0=py36h14c3975_0 76 | - glib=2.56.1=h000015b_0 77 | - glob2=0.6=py36he249c77_0 78 | - gmp=6.1.2=h6c8ec71_1 79 | - gmpy2=2.0.8=py36hc8893dd_2 80 | - gnutls=3.5.19=h2a4e5f8_1 81 | - graphite2=1.3.11=h16798f4_2 82 | - greenlet=0.4.13=py36h14c3975_0 83 | - gst-plugins-base=1.14.0=hbbd80ab_1 84 | - gstreamer=1.14.0=hb453b48_1 85 | - h5py=2.7.1=py36ha1f6525_2 86 | - harfbuzz=1.7.6=h5f0a787_1 87 | - hdf5=1.10.2=hba1933b_1 88 | - heapdict=1.0.0=py36_2 89 | - html5lib=1.0.1=py36h2f9c1c0_0 90 | - icu=58.2=h9c2bf20_1 91 | - idna=2.6=py36h82fb2a8_1 92 | - imageio=2.3.0=py36_0 93 | - imagesize=1.0.0=py36_0 94 | - intel-openmp=2018.0.0=8 95 | - ipykernel=4.8.2=py36_0 96 | - ipython=6.4.0=py36_0 97 | - 
ipython_genutils=0.2.0=py36hb52b0d5_0 98 | - ipywidgets=7.2.1=py36_0 99 | - isort=4.3.4=py36_0 100 | - itsdangerous=0.24=py36h93cc618_1 101 | - jasper=1.900.1=hd497a04_4 102 | - jbig=2.1=hdba287a_0 103 | - jdcal=1.4=py36_0 104 | - jedi=0.12.0=py36_1 105 | - jinja2=2.10=py36ha16c418_0 106 | - jpeg=9b=h024ee3a_2 107 | - jsonschema=2.6.0=py36h006f8b5_0 108 | - jupyter=1.0.0=py36_4 109 | - jupyter_client=5.2.3=py36_0 110 | - jupyter_console=5.2.0=py36he59e554_1 111 | - jupyter_core=4.4.0=py36h7c827e3_0 112 | - jupyterlab=0.32.1=py36_0 113 | - jupyterlab_launcher=0.10.5=py36_0 114 | - kiwisolver=1.0.1=py36h764f252_0 115 | - lazy-object-proxy=1.3.1=py36h10fcdad_0 116 | - libcurl=7.60.0=h1ad7b7a_0 117 | - libedit=3.1.20181209=hc058e9b_0 118 | - libffi=3.2.1=hd88cf55_4 119 | - libgcc-ng=9.1.0=hdf63c60_0 120 | - libgfortran-ng=7.2.0=hdf63c60_3 121 | - libiconv=1.15=h470a237_3 122 | - libopencv=3.4.1=h1a3b859_1 123 | - libopus=1.2.1=hb9ed12e_0 124 | - libpng=1.6.34=hb9fc6fc_0 125 | - libprotobuf=3.5.2=h6f1eeef_0 126 | - libsodium=1.0.16=h1bed415_0 127 | - libssh2=1.8.0=h9cfc8f7_4 128 | - libstdcxx-ng=9.1.0=hdf63c60_0 129 | - libtiff=4.0.9=he85c1e1_1 130 | - libtool=2.4.6=h544aabb_3 131 | - libvpx=1.7.0=h439df22_0 132 | - libxcb=1.13=h1bed415_1 133 | - libxml2=2.9.8=h26e45fe_1 134 | - libxslt=1.1.32=h1312cb7_0 135 | - llvmlite=0.23.1=py36hdbcaa40_0 136 | - locket=0.2.0=py36h787c0ad_1 137 | - lxml=4.2.1=py36h23eabaa_0 138 | - lzo=2.10=h49e0be7_2 139 | - markupsafe=1.0=py36hd9260cd_1 140 | - matplotlib=2.2.2=py36h0e671d2_1 141 | - mccabe=0.6.1=py36h5ad9710_1 142 | - mistune=0.8.3=py36h14c3975_1 143 | - mkl=2018.0.3=1 144 | - mkl-service=1.1.2=py36h17a0993_4 145 | - mkl_fft=1.0.6=py36h7dd41cf_0 146 | - mkl_random=1.0.1=py36h4414c95_1 147 | - more-itertools=4.1.0=py36_0 148 | - mpc=1.0.3=hec55b23_5 149 | - mpfr=3.1.5=h11a74b3_2 150 | - mpmath=1.0.0=py36hfeacd6b_2 151 | - msgpack-python=0.5.6=py36h6bb024c_0 152 | - multipledispatch=0.5.0=py36_0 153 | - navigator-updater=0.2.1=py36_0 154 | - nbconvert=5.3.1=py36hb41ffb7_0 155 | - nbformat=4.4.0=py36h31c9010_0 156 | - ncurses=6.1=he6710b0_1 157 | - nettle=3.3=0 158 | - networkx=2.1=py36_0 159 | - ninja=1.7.2=0 160 | - nltk=3.3.0=py36_0 161 | - nose=1.3.7=py36hcdf7029_2 162 | - notebook=5.5.0=py36_0 163 | - numba=0.38.0=py36h637b7d7_0 164 | - numexpr=2.6.5=py36h7bf3b9c_0 165 | - numpy=1.15.4=py36h1d66e8a_0 166 | - numpy-base=1.15.4=py36h81de0dd_0 167 | - numpydoc=0.8.0=py36_0 168 | - odo=0.5.1=py36h90ed295_0 169 | - olefile=0.45.1=py36_0 170 | - opencv3=3.1.0=py36_0 171 | - openh264=1.7.0=0 172 | - openpyxl=2.5.3=py36_0 173 | - openssl=1.1.1c=h7b6447c_1 174 | - packaging=17.1=py36_0 175 | - pandas=0.23.0=py36h637b7d7_0 176 | - pandoc=1.19.2.1=hea2e7c5_1 177 | - pandocfilters=1.4.2=py36ha6701b7_1 178 | - pango=1.41.0=hd475d92_0 179 | - parso=0.2.0=py36_0 180 | - partd=0.3.8=py36h36fd896_0 181 | - patchelf=0.9=hf79760b_2 182 | - path.py=11.0.1=py36_0 183 | - pathlib2=2.3.2=py36_0 184 | - patsy=0.5.0=py36_0 185 | - pcre=8.42=h439df22_0 186 | - pep8=1.7.1=py36_0 187 | - pexpect=4.5.0=py36_0 188 | - pickleshare=0.7.4=py36h63277f8_0 189 | - pillow=5.1.0=py36h3deb7b8_0 190 | - pixman=0.34.0=hceecf20_3 191 | - pkginfo=1.4.2=py36_1 192 | - pluggy=0.6.0=py36hb689045_0 193 | - ply=3.11=py36_0 194 | - prompt_toolkit=1.0.15=py36h17d85b1_0 195 | - psutil=5.4.5=py36h14c3975_0 196 | - ptyprocess=0.5.2=py36h69acd42_0 197 | - py=1.5.3=py36_0 198 | - py-opencv=3.4.1=py36h0676e08_1 199 | - pycodestyle=2.4.0=py36_0 200 | - pycosat=0.6.3=py36h0a5515d_0 201 | - 
pycparser=2.19=py36_0 202 | - pycrypto=2.6.1=py36h14c3975_8 203 | - pycurl=7.43.0.1=py36hb7f436b_0 204 | - pyflakes=1.6.0=py36h7bd6a15_0 205 | - pygments=2.2.0=py36h0d3125c_0 206 | - pylint=1.8.4=py36_0 207 | - pyodbc=4.0.23=py36hf484d3e_0 208 | - pyopengl=3.1.1a1=py36_0 209 | - pyopenssl=18.0.0=py36_0 210 | - pyparsing=2.2.0=py36hee85983_1 211 | - pyqt=5.9.2=py36h751905a_0 212 | - pyserial=3.4=py36_0 213 | - pysocks=1.6.8=py36_0 214 | - pytables=3.4.3=py36h02b9ad4_2 215 | - pytest=3.5.1=py36_0 216 | - pytest-arraydiff=0.2=py36_0 217 | - pytest-astropy=0.3.0=py36_0 218 | - pytest-doctestplus=0.1.3=py36_0 219 | - pytest-openfiles=0.3.0=py36_0 220 | - pytest-remotedata=0.2.1=py36_0 221 | - python=3.6.9=h265db76_0 222 | - python-dateutil=2.7.3=py36_0 223 | - pytz=2018.4=py36_0 224 | - pywavelets=0.5.2=py36he602eb0_0 225 | - pyyaml=3.12=py36hafb9ca4_1 226 | - pyzmq=17.0.0=py36h14c3975_0 227 | - qt=5.9.5=h7e424d6_0 228 | - qtawesome=0.4.4=py36h609ed8c_0 229 | - qtconsole=4.3.1=py36h8f73b5b_0 230 | - qtpy=1.4.1=py36_0 231 | - readline=7.0=h7b6447c_5 232 | - requests=2.18.4=py36he2e5f8d_1 233 | - rope=0.10.7=py36h147e2ec_0 234 | - ruamel_yaml=0.15.35=py36h14c3975_1 235 | - scikit-image=0.13.1=py36h14c3975_1 236 | - scikit-learn=0.19.1=py36h7aa7ec6_0 237 | - scipy=1.1.0=py36hfc37229_0 238 | - seaborn=0.8.1=py36hfad7ec4_0 239 | - send2trash=1.5.0=py36_0 240 | - setuptools=41.0.1=py36_0 241 | - simplegeneric=0.8.1=py36_2 242 | - singledispatch=3.4.0.3=py36h7a266c3_0 243 | - sip=4.19.8=py36hf484d3e_0 244 | - six=1.11.0=py36h372c433_1 245 | - snappy=1.1.7=hbae5bb6_3 246 | - snowballstemmer=1.2.1=py36h6febd40_0 247 | - sortedcollections=0.6.1=py36_0 248 | - sortedcontainers=1.5.10=py36_0 249 | - sphinx=1.7.4=py36_0 250 | - sphinxcontrib=1.0=py36h6d0f590_1 251 | - sphinxcontrib-websupport=1.0.1=py36hb5cb234_1 252 | - spyder=3.2.8=py36_0 253 | - sqlalchemy=1.2.7=py36h6b74fdf_0 254 | - sqlite=3.29.0=h7b6447c_0 255 | - statsmodels=0.9.0=py36h3010b51_0 256 | - sympy=1.1.1=py36hc6d1c1c_0 257 | - tblib=1.3.2=py36h34cf8b6_0 258 | - terminado=0.8.1=py36_1 259 | - testpath=0.3.1=py36h8cadb63_0 260 | - tk=8.6.8=hbc83047_0 261 | - toolz=0.9.0=py36_0 262 | - tornado=5.0.2=py36_0 263 | - traitlets=4.3.2=py36h674d592_0 264 | - typing=3.6.4=py36_0 265 | - unicodecsv=0.14.1=py36ha668878_0 266 | - unixodbc=2.3.6=h1bed415_0 267 | - urllib3=1.22=py36hbe7ace6_0 268 | - wcwidth=0.1.7=py36hdf4376a_0 269 | - webencodings=0.5.1=py36h800622e_1 270 | - werkzeug=0.14.1=py36_0 271 | - wheel=0.33.4=py36_0 272 | - widgetsnbextension=3.2.1=py36_0 273 | - wrapt=1.10.11=py36h28b7045_0 274 | - x264=1!152.20180717=h470a237_1 275 | - xlrd=1.1.0=py36h1db9f0c_1 276 | - xlsxwriter=1.0.4=py36_0 277 | - xlwt=1.3.0=py36h7b00a1f_0 278 | - xz=5.2.4=h14c3975_4 279 | - yaml=0.1.7=had09818_2 280 | - zeromq=4.2.5=h439df22_0 281 | - zict=0.1.3=py36h3a3bf81_0 282 | - zlib=1.2.11=h7b6447c_3 283 | - pip: 284 | - deepdish==0.3.6 285 | - dlib==19.16.0 286 | - enum34==1.1.6 287 | - ffprobe==0.5 288 | - future==0.17.1 289 | - imutils==0.5.1 290 | - iso8601==0.1.12 291 | - jupyter-http-over-ws==0.0.6 292 | - open-3d==0.3.0.0 293 | - open3d-official==0.3.0.0 294 | - pims==0.4.1 295 | - pip==19.0.3 296 | - pkg-config==0.0.1 297 | - plyfile==0.7 298 | - poppy==0.8.0 299 | - pptk==0.1.0 300 | - py-tvd==1.0 301 | - pybind11==2.2.4 302 | - python-pptx==0.6.17 303 | - scikit-video==1.1.10 304 | - simpleitk==1.2.2 305 | - slicerator==0.9.8 306 | - torch==1.0.1 307 | - torchsummary==1.5.1 308 | - torchvision==0.4.0 309 | - tqdm==4.35.0 310 | prefix: 
/home/aaa/anaconda3

-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
If you use this code, please cite:
```
@INPROCEEDINGS{chaudhary2019ritnet,
author={A. K. {Chaudhary} and R. {Kothari} and M. {Acharya} and S. {Dangi} and N. {Nair} and R. {Bailey} and C. {Kanan} and G. {Diaz} and J. B. {Pelz}},
booktitle={2019 IEEE/CVF International Conference on Computer Vision Workshop (ICCVW)},
title={RITnet: Real-time Semantic Segmentation of the Eye for Gaze Tracking},
year={2019},
volume={},
number={},
pages={3698-3702},
keywords={Semantics;Image segmentation;Computational modeling;Iris;Robustness;Convolution;Real-time systems},
doi={10.1109/ICCVW.2019.00568},
ISSN={2473-9936},
month={Oct},}
```

Instructions:

```python train.py --help```

To train the densenet model:

```python train.py --model densenet --expname FINAL --bs 8 --useGPU True --dataset Semantic_Segmentation_Dataset/```

To test the result:

```python test.py --model densenet --load best_model.pkl --bs 4 --dataset Semantic_Segmentation_Dataset/```
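For quick experiments outside `test.py`, the minimal sketch below shows one way to run the released checkpoint on a single test image. It assumes `models.py` exposes the densenet network under a lookup named `model_dict['densenet']` and that `best_model.pkl` holds a `state_dict` — both are assumptions about code not reproduced in this README, so adapt the names to the actual API:

```
# Hedged inference sketch -- model_dict and the checkpoint format are assumptions.
import torch
from dataset import IrisDataset, transform
from utils import get_predictions
from models import model_dict            # assumption: defined in models.py

model = model_dict['densenet']           # assumption: adjust to the real API
model.load_state_dict(torch.load('best_model.pkl', map_location='cpu'))
model.eval()

ds = IrisDataset('Semantic_Segmentation_Dataset', split='test', transform=transform)
img, _, name, _, _ = ds[0]               # the test split returns placeholders for labels
with torch.no_grad():
    output = model(img.unsqueeze(0))     # shape (1, 4, 640, 400): 4 classes
pred = get_predictions(output)[0]        # per-pixel argmax -> (640, 400) label map
print(name, pred.shape)
```

If the pickle stores the entire module rather than a `state_dict`, the variant to try is `model = torch.load('best_model.pkl', map_location='cpu')`.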

# Contents in the zip folder
```
best_model.pkl      :: Our final model (potential winner model); contains all weights in float32 format (248,900 parameters).
requirements.txt    :: All packages necessary for the source code to run.
environment.yml     :: All packages and versions from one of our systems on which the code ran successfully.
dataset.py          :: Data loader and augmentation
train.py            :: Train code
test.py             :: Test code
densenet.py         :: Model code
utils.py            :: Utility functions
opt.py              :: Arguments for the argument parser
models.py           :: List of all models
starburst_black.png :: A fixed structured pattern (with translation) used on train images to handle cases such as multiple reflections. (Train image: 000000240768.png)
Starburst generation from train image 000000240768.pdf :: Procedure for generating the starburst pattern
```

The requirements.txt file contains all the packages necessary for the code to run. We have also included an environment.yml file to recreate the conda environment we used.

We have submitted two models from this version of the code:

1. Epoch: 151, validation accuracy: 95.7780, test accuracy: 95.276 (potential winner model; our last submission)
2. Epoch: 117, validation accuracy: 95.7023, test accuracy: 95.159 (our second-to-last submission)

We could reach up to epoch 240 with validation accuracy 95.7820 (test accuracy: N/A; not submitted, as the result came after the deadline).

dataset.py contains the data loader, preprocessing and augmentation steps.

Required preprocessing for all images (test, train and validation sets); a code sketch of these steps follows the augmentation list below:

1. Gamma correction by a factor of 0.8
2. Local contrast-limited adaptive histogram equalization (CLAHE) with clipLimit=1.5, tileGridSize=(8,8)
3. Normalization (mean 0.5, std 0.5)

Train image augmentation procedure followed (not required during test):

1. Random horizontal flip with 50% probability.
2. Starburst pattern augmentation with 20% probability.
3. Random length lines (1 to 9) augmentation around a random center with 20% probability.
4. Gaussian blur with kernel size (7,7) and random sigma (2 to 6) with 20% probability.
5. Translation of image and labels in any direction with a random factor less than 20, with 40% probability.
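A minimal sketch of the three preprocessing steps listed above, mirroring the corresponding calls in dataset.py (the input file name `example.png` is a placeholder):

```
import cv2
import numpy as np
from PIL import Image
from torchvision import transforms

img = np.array(Image.open('example.png').convert("L"))   # placeholder input

# 1. Gamma correction (gamma = 0.8) via a lookup table
table = 255.0 * (np.linspace(0, 1, 256) ** 0.8)
img = cv2.LUT(img, table)

# 2. CLAHE with clipLimit=1.5, tileGridSize=(8,8)
clahe = cv2.createCLAHE(clipLimit=1.5, tileGridSize=(8, 8))
img = clahe.apply(np.uint8(img))

# 3. Normalization to mean 0.5, std 0.5 (after ToTensor scales to [0, 1])
to_tensor = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize([0.5], [0.5])])
tensor = to_tensor(Image.fromarray(img))
```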

```
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 32, 640, 400]             320
           Dropout-2         [-1, 32, 640, 400]               0
         LeakyReLU-3         [-1, 32, 640, 400]               0
            Conv2d-4         [-1, 32, 640, 400]           1,088
            Conv2d-5         [-1, 32, 640, 400]           9,248
           Dropout-6         [-1, 32, 640, 400]               0
         LeakyReLU-7         [-1, 32, 640, 400]               0
            Conv2d-8         [-1, 32, 640, 400]           2,112
            Conv2d-9         [-1, 32, 640, 400]           9,248
          Dropout-10         [-1, 32, 640, 400]               0
        LeakyReLU-11         [-1, 32, 640, 400]               0
      BatchNorm2d-12         [-1, 32, 640, 400]              64
DenseNet2D_down_block-13     [-1, 32, 640, 400]               0
        AvgPool2d-14         [-1, 32, 320, 200]               0
           Conv2d-15         [-1, 32, 320, 200]           9,248
          Dropout-16         [-1, 32, 320, 200]               0
        LeakyReLU-17         [-1, 32, 320, 200]               0
           Conv2d-18         [-1, 32, 320, 200]           2,080
           Conv2d-19         [-1, 32, 320, 200]           9,248
          Dropout-20         [-1, 32, 320, 200]               0
        LeakyReLU-21         [-1, 32, 320, 200]               0
           Conv2d-22         [-1, 32, 320, 200]           3,104
           Conv2d-23         [-1, 32, 320, 200]           9,248
          Dropout-24         [-1, 32, 320, 200]               0
        LeakyReLU-25         [-1, 32, 320, 200]               0
      BatchNorm2d-26         [-1, 32, 320, 200]              64
DenseNet2D_down_block-27     [-1, 32, 320, 200]               0
        AvgPool2d-28         [-1, 32, 160, 100]               0
           Conv2d-29         [-1, 32, 160, 100]           9,248
          Dropout-30         [-1, 32, 160, 100]               0
        LeakyReLU-31         [-1, 32, 160, 100]               0
           Conv2d-32         [-1, 32, 160, 100]           2,080
           Conv2d-33         [-1, 32, 160, 100]           9,248
          Dropout-34         [-1, 32, 160, 100]               0
        LeakyReLU-35         [-1, 32, 160, 100]               0
           Conv2d-36         [-1, 32, 160, 100]           3,104
           Conv2d-37         [-1, 32, 160, 100]           9,248
          Dropout-38         [-1, 32, 160, 100]               0
        LeakyReLU-39         [-1, 32, 160, 100]               0
      BatchNorm2d-40         [-1, 32, 160, 100]              64
DenseNet2D_down_block-41     [-1, 32, 160, 100]               0
        AvgPool2d-42           [-1, 32, 80, 50]               0
           Conv2d-43           [-1, 32, 80, 50]           9,248
          Dropout-44           [-1, 32, 80, 50]               0
        LeakyReLU-45           [-1, 32, 80, 50]               0
           Conv2d-46           [-1, 32, 80, 50]           2,080
           Conv2d-47           [-1, 32, 80, 50]           9,248
          Dropout-48           [-1, 32, 80, 50]               0
        LeakyReLU-49           [-1, 32, 80, 50]               0
           Conv2d-50           [-1, 32, 80, 50]           3,104
           Conv2d-51           [-1, 32, 80, 50]           9,248
          Dropout-52           [-1, 32, 80, 50]               0
        LeakyReLU-53           [-1, 32, 80, 50]               0
      BatchNorm2d-54           [-1, 32, 80, 50]              64
DenseNet2D_down_block-55       [-1, 32, 80, 50]               0
        AvgPool2d-56           [-1, 32, 40, 25]               0
           Conv2d-57           [-1, 32, 40, 25]           9,248
          Dropout-58           [-1, 32, 40, 25]               0
        LeakyReLU-59           [-1, 32, 40, 25]               0
           Conv2d-60           [-1, 32, 40, 25]           2,080
           Conv2d-61           [-1, 32, 40, 25]           9,248
          Dropout-62           [-1, 32, 40, 25]               0
        LeakyReLU-63           [-1, 32, 40, 25]               0
           Conv2d-64           [-1, 32, 40, 25]           3,104
           Conv2d-65           [-1, 32, 40, 25]           9,248
          Dropout-66           [-1, 32, 40, 25]               0
        LeakyReLU-67           [-1, 32, 40, 25]               0
      BatchNorm2d-68           [-1, 32, 40, 25]              64
DenseNet2D_down_block-69       [-1, 32, 40, 25]               0
           Conv2d-70           [-1, 32, 80, 50]           2,080
           Conv2d-71           [-1, 32, 80, 50]           9,248
          Dropout-72           [-1, 32, 80, 50]               0
        LeakyReLU-73           [-1, 32, 80, 50]               0
           Conv2d-74           [-1, 32, 80, 50]           3,104
           Conv2d-75           [-1, 32, 80, 50]           9,248
          Dropout-76           [-1, 32, 80, 50]               0
        LeakyReLU-77           [-1, 32, 80, 50]               0
DenseNet2D_up_block_concat-78  [-1, 32, 80, 50]               0
           Conv2d-79         [-1, 32, 160, 100]           2,080
           Conv2d-80         [-1, 32, 160, 100]           9,248
          Dropout-81         [-1, 32, 160, 100]               0
        LeakyReLU-82         [-1, 32, 160, 100]               0
           Conv2d-83         [-1, 32, 160, 100]           3,104
           Conv2d-84         [-1, 32, 160, 100]           9,248
          Dropout-85         [-1, 32, 160, 100]               0
        LeakyReLU-86         [-1, 32, 160, 100]               0
DenseNet2D_up_block_concat-87 [-1, 32, 160, 100]              0
           Conv2d-88         [-1, 32, 320, 200]           2,080
           Conv2d-89         [-1, 32, 320, 200]           9,248
          Dropout-90         [-1, 32, 320, 200]               0
        LeakyReLU-91         [-1, 32, 320, 200]               0
           Conv2d-92         [-1, 32, 320, 200]           3,104
           Conv2d-93         [-1, 32, 320, 200]           9,248
          Dropout-94         [-1, 32, 320, 200]               0
        LeakyReLU-95         [-1, 32, 320, 200]               0
DenseNet2D_up_block_concat-96 [-1, 32, 320, 200]              0
           Conv2d-97         [-1, 32, 640, 400]           2,080
           Conv2d-98         [-1, 32, 640, 400]           9,248
          Dropout-99         [-1, 32, 640, 400]               0
       LeakyReLU-100         [-1, 32, 640, 400]               0
          Conv2d-101         [-1, 32, 640, 400]           3,104
          Conv2d-102         [-1, 32, 640, 400]           9,248
         Dropout-103         [-1, 32, 640, 400]               0
       LeakyReLU-104         [-1, 32, 640, 400]               0
DenseNet2D_up_block_concat-105 [-1, 32, 640, 400]             0
         Dropout-106         [-1, 32, 640, 400]               0
          Conv2d-107          [-1, 4, 640, 400]             132
================================================================
Total params: 248,900
Trainable params: 248,900
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.98
Forward/backward pass size (MB): 1920.41
Params size (MB): 0.95
Estimated Total Size (MB): 1922.34
----------------------------------------------------------------

```
--------------------------------------------------------------------------------