├── .gitignore ├── README.md ├── util.py ├── model.py ├── dataset.py ├── train.py └── test.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | __pycache__/* 3 | data/* 4 | saved_models/* 5 | 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # IDRiD 2 | 3 | We use deep learning to solve the [IDRiD challenge](https://idrid.grand-challenge.org/) sub-challenge 1. 4 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -* 2 | import torch 3 | from sklearn.metrics import f1_score 4 | import numpy as np 5 | import os 6 | 7 | def weighted_BCELoss(output, target, weights=None): 8 | 9 | output = output.clamp(min=1e-5, max=1-1e-5) 10 | if weights is not None: 11 | assert len(weights) == 2 12 | 13 | loss = -weights[0] * (target * torch.log(output)) - weights[1] * ((1 - target) * torch.log(1 - output)) 14 | else: 15 | loss = -target * torch.log(output) - (1 - target) * torch.log(1 - output) 16 | 17 | return torch.mean(loss) 18 | 19 | def evaluate(y_true, y_pred): 20 | ''' 21 | Calculate statistic matrix. 22 | 23 | Args: 24 | y_true:the pytorch tensor of ground truth 25 | y_pred:the pytorch tensor of prediction 26 | return: 27 | The F1 score 28 | ''' 29 | y_true = y_true.numpy().flatten() 30 | y_pred = np.rint(y_pred.numpy().flatten()) 31 | f1 = f1_score(y_true, y_pred) 32 | return f1 33 | 34 | 35 | def save_model(model, save_dir, name): 36 | #save model 37 | if not os.path.exists(save_dir): 38 | os.makedirs(save_dir) 39 | path = os.path.join(save_dir, name) 40 | print('Saving model to directory "%s"'%(path)) 41 | torch.save(model.state_dict(), path) 42 | 43 | 44 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | ''' 2 | The code is modified from https://github.com/ZijunDeng/pytorch-semantic-segmentation 3 | ''' 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torchvision import models 7 | 8 | # many are borrowed from https://github.com/ycszen/pytorch-ss/blob/master/gcn.py 9 | class _GlobalConvModule(nn.Module): 10 | def __init__(self, in_dim, out_dim, kernel_size): 11 | super(_GlobalConvModule, self).__init__() 12 | pad0 = int((kernel_size[0] - 1) / 2) 13 | pad1 = int((kernel_size[1] - 1) / 2) 14 | # kernel size had better be odd number so as to avoid alignment error 15 | super(_GlobalConvModule, self).__init__() 16 | self.conv_l1 = nn.Conv2d(in_dim, out_dim, kernel_size=(kernel_size[0], 1), 17 | padding=(pad0, 0)) 18 | self.conv_l2 = nn.Conv2d(out_dim, out_dim, kernel_size=(1, kernel_size[1]), 19 | padding=(0, pad1)) 20 | self.conv_r1 = nn.Conv2d(in_dim, out_dim, kernel_size=(1, kernel_size[1]), 21 | padding=(0, pad1)) 22 | self.conv_r2 = nn.Conv2d(out_dim, out_dim, kernel_size=(kernel_size[0], 1), 23 | padding=(pad0, 0)) 24 | 25 | def forward(self, x): 26 | x_l = self.conv_l1(x) 27 | x_l = self.conv_l2(x_l) 28 | x_r = self.conv_r1(x) 29 | x_r = self.conv_r2(x_r) 30 | x = x_l + x_r 31 | return x 32 | 33 | 34 | class _BoundaryRefineModule(nn.Module): 35 | def __init__(self, dim): 36 | super(_BoundaryRefineModule, self).__init__() 37 | self.relu = nn.ReLU(inplace=True) 38 | self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1) 39 | self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1) 40 | 41 | def forward(self, x): 42 | residual = self.conv1(x) 43 | residual = self.relu(residual) 44 | residual = self.conv2(residual) 45 | out = x + residual 46 | return out 47 | 48 | 49 | class GCN(nn.Module): 50 | def __init__(self, num_classes, input_size): 51 | super(GCN, self).__init__() 52 | self.input_size = input_size 53 | resnet = models.resnet152(pretrained=True) 54 | 55 | self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu) 56 | self.layer1 = nn.Sequential(resnet.maxpool, resnet.layer1) 57 | self.layer2 = resnet.layer2 58 | self.layer3 = resnet.layer3 59 | self.layer4 = resnet.layer4 60 | 61 | self.gcm1 = _GlobalConvModule(2048, num_classes, (7, 7)) 62 | self.gcm2 = _GlobalConvModule(1024, num_classes, (7, 7)) 63 | self.gcm3 = _GlobalConvModule(512, num_classes, (7, 7)) 64 | self.gcm4 = _GlobalConvModule(256, num_classes, (7, 7)) 65 | 66 | self.brm1 = _BoundaryRefineModule(num_classes) 67 | self.brm2 = _BoundaryRefineModule(num_classes) 68 | self.brm3 = _BoundaryRefineModule(num_classes) 69 | self.brm4 = _BoundaryRefineModule(num_classes) 70 | self.brm5 = _BoundaryRefineModule(num_classes) 71 | self.brm6 = _BoundaryRefineModule(num_classes) 72 | self.brm7 = _BoundaryRefineModule(num_classes) 73 | self.brm8 = _BoundaryRefineModule(num_classes) 74 | self.brm9 = _BoundaryRefineModule(num_classes) 75 | 76 | initialize_weights(self.gcm1, self.gcm2, self.gcm3, self.gcm4, self.brm1, self.brm2, self.brm3, 77 | self.brm4, self.brm5, self.brm6, self.brm7, self.brm8, self.brm9) 78 | 79 | def forward(self, x): 80 | # if x: 512 81 | fm0 = self.layer0(x) # 256 82 | fm1 = self.layer1(fm0) # 128 83 | fm2 = self.layer2(fm1) # 64 84 | fm3 = self.layer3(fm2) # 32 85 | fm4 = self.layer4(fm3) # 16 86 | 87 | gcfm1 = self.brm1(self.gcm1(fm4)) # 16 88 | gcfm2 = self.brm2(self.gcm2(fm3)) # 32 89 | gcfm3 = self.brm3(self.gcm3(fm2)) # 64 90 | gcfm4 = self.brm4(self.gcm4(fm1)) # 128 91 | 92 | fs1 = self.brm5(F.upsample(gcfm1, fm3.size()[2:], mode='bilinear') + gcfm2) # 32 93 | fs2 = self.brm6(F.upsample(fs1, fm2.size()[2:], mode='bilinear') + gcfm3) # 64 94 | fs3 = self.brm7(F.upsample(fs2, fm1.size()[2:], mode='bilinear') + gcfm4) # 128 95 | fs4 = self.brm8(F.upsample(fs3, fm0.size()[2:], mode='bilinear')) # 256 96 | out = self.brm9(F.upsample(fs4, self.input_size, mode='bilinear')) # 512 97 | 98 | return out 99 | 100 | 101 | def initialize_weights(*models): 102 | for model in models: 103 | for module in model.modules(): 104 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear): 105 | nn.init.kaiming_normal(module.weight) 106 | if module.bias is not None: 107 | module.bias.data.zero_() 108 | elif isinstance(module, nn.BatchNorm2d): 109 | module.weight.data.fill_(1) 110 | module.bias.data.zero_() -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -* 2 | import torch 3 | from torch.utils.data import Dataset 4 | from torchvision import transforms 5 | from PIL import Image 6 | import numpy as np 7 | import os 8 | import random 9 | 10 | 11 | class IDRiD_sub1_dataset(Dataset): 12 | """ 13 | Put the images into these directories respectivly: 14 | Apparent Retinopathy, No Apparent Retinopathy, MA, EX, HE, SE 15 | 16 | It will load in the original 4288x2848 image and crop it into 9x6 small images with 17 | size of 512x512 on the fly. 18 | 19 | Use shuffle=False for this dataset as it caches only one image. 20 | 21 | Data Argumentation and transforms has not been implemented yet. 22 | """ 23 | 24 | def __init__(self, root_dir): 25 | """ 26 | Args: 27 | root_dir:the root directory of the images 28 | """ 29 | self.task_type_list = ['MA', 'EX', 'HE', 'SE'] 30 | self.root_dir = root_dir 31 | self.data_idx = []#(image_dir, mask_dirs, name) mask_dirs is a list(None for NAR images) 32 | self.data_cache = {'image': None, 'mask': None, 'name': "", 'index': None}#cache the original size image 33 | 34 | image_root = os.path.join(self.root_dir, 'Apparent Retinopathy') 35 | image_NAR_root = os.path.join(self.root_dir, 'No Apparent Retinopathy') 36 | 37 | #Get the file index 38 | #AR images 39 | for filename in os.listdir(image_root): 40 | image_dir = os.path.join(image_root, filename) 41 | mask_dirs = {task_type:None for task_type in self.task_type_list} 42 | for task_type in self.task_type_list: 43 | m_dir = os.path.join(self.root_dir, task_type, filename[:-4]+'_'+task_type+'.tif') 44 | if os.path.isfile(m_dir): mask_dirs[task_type] = m_dir 45 | name = filename[:-4] 46 | self.data_idx.append((image_dir, mask_dirs, name)) 47 | #NAR images 48 | for filename in os.listdir(image_NAR_root): 49 | image_dir = os.path.join(image_NAR_root, filename) 50 | mask_dirs = {task_type:None for task_type in self.task_type_list} 51 | name = filename[:-4] 52 | self.data_idx.append((image_dir, mask_dirs, name)) 53 | #Shuffle 54 | random.shuffle(self.data_idx) 55 | 56 | def __len__(self): 57 | return len(self.data_idx)*9*6 58 | 59 | def __getitem__(self, idx): 60 | # crop the 4288x2848 image into 512x512 => 9x6 grid 61 | # 1 image => 9x6 = 54 small images 62 | n = int(idx/(6*9))#image index 63 | r = int((idx%(6*9))/9)#row 64 | c = (idx%(6*9))%9#column 65 | 66 | #Load the images if it's not in the cache 67 | if self.data_cache['index'] != n: 68 | image_dir, mask_dirs, name = self.data_idx[n] 69 | image = Image.open(image_dir) 70 | 71 | masks = [] 72 | for task_type in self.task_type_list: 73 | if mask_dirs[task_type] is not None: 74 | #AR images 75 | mask = Image.open(mask_dirs[task_type]) 76 | mask = np.array(mask, dtype='float32') 77 | else: 78 | #NAR images 79 | w, h = image.size 80 | mask = np.zeros((h, w), dtype='float32') 81 | masks.append(mask) 82 | masks = np.array(masks) 83 | masks = np.pad(masks, ((0, 0), (0, 224), (0, 320)), 'constant', constant_values=0)#padding 84 | self.data_cache = {'image': image, 'masks': masks, 'name': name, 'index': n} 85 | 86 | #crop the image 87 | image_crop = self.data_cache['image'].crop((c*512, r*512, c*512 + 512, r*512 + 512)) 88 | masks_crop = self.data_cache['masks'][:, r*512:r*512 + 512, c*512:c*512 + 512] 89 | image_crop = transforms.ToTensor()(image_crop) 90 | masks_crop = torch.from_numpy(masks_crop) 91 | name = self.data_cache['name']+'(%2d, %2d)'%(r, c) 92 | return image_crop, masks_crop, name 93 | 94 | if __name__ == '__main__': 95 | dataset = IDRiD_sub1_dataset('./data/sub1/train') 96 | print('dataset length: %d'%(len(dataset))) 97 | #data formate test 98 | 99 | print('dataset sample') 100 | image, mask, name = dataset[random.randint(0, len(dataset)-1)] 101 | print(image, mask, name) 102 | 103 | #show image test 104 | ''' 105 | import matplotlib.pyplot as plt 106 | for i in range(len(dataset)): 107 | image, mask, name = dataset[i] 108 | print(image, mask, name) 109 | plt.imshow(transforms.ToPILImage()(image)) 110 | plt.show() 111 | plt.imshow(mask.numpy()[1]) 112 | plt.show() 113 | ''' 114 | 115 | # dataloader test 116 | from torch.utils.data import DataLoader 117 | import time 118 | t = time.time() 119 | dataloader = DataLoader(dataset, batch_size=100, shuffle=False, num_workers=4) 120 | for data in dataloader: 121 | pass 122 | print('%ds'%(time.time()-t)) 123 | 124 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -* 2 | from dataset import IDRiD_sub1_dataset 3 | from util import evaluate, save_model, weighted_BCELoss 4 | from model import GCN 5 | import torch 6 | import torch.optim as optim 7 | from torch.optim.lr_scheduler import ReduceLROnPlateau 8 | import torch.nn.functional as F 9 | from torch.autograd import Variable 10 | from torch.utils.data import DataLoader 11 | import time 12 | import copy 13 | import os 14 | 15 | #gcn_v3 with weighted loss, AR only, 256x256 16 | #gcn_v4 with random crop, 256x256 17 | #gcn_v3_2 fine tune with AR and NAR 18 | #gcn_v5 512x512 19 | 20 | use_gpu = torch.cuda.is_available 21 | save_dir = "./saved_models" 22 | model_name = "test.pth" 23 | data_train_dir = './data/sub1/train' 24 | data_val_dir = './data/sub1/val' 25 | batch_size = 6 26 | num_epochs = 100 27 | lr = 1e-4 28 | 29 | def make_dataloaders(batch_size=batch_size): 30 | dataset_train = IDRiD_sub1_dataset(data_train_dir) 31 | dataset_val = IDRiD_sub1_dataset(data_val_dir) 32 | dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=False, num_workers=4) 33 | dataloader_val = DataLoader(dataset_val, batch_size=batch_size, shuffle=False, num_workers=4) 34 | dataloaders = {'train': dataloader_train, 'val': dataloader_val} 35 | print('Training data: %d\nValidation data: %d'%((len(dataset_train)), len(dataset_val))) 36 | return dataloaders 37 | 38 | def train_model(model, num_epochs, dataloaders, optimizer, scheduler): 39 | since = time.time() 40 | best_model_wts = copy.deepcopy(model.state_dict()) 41 | best_f1 = 0.0 42 | 43 | for epoch in range(num_epochs): 44 | print('Epoch {}/{}'.format(epoch, num_epochs - 1)) 45 | print('-' * 10) 46 | 47 | for phase in ['train', 'val']: 48 | if phase == 'train': 49 | model.train(True) # Set model to training mode 50 | else: 51 | model.train(False) # Set model to evaluate mode 52 | 53 | running_loss = 0.0 54 | running_f1 = 0.0 55 | data_num = 0 56 | 57 | for idx, data in enumerate(dataloaders[phase]): 58 | images, masks, names = data 59 | 60 | #weight for loss 61 | weights = [5, 1] 62 | if use_gpu: 63 | weights = torch.FloatTensor(weights).cuda() 64 | 65 | 66 | if use_gpu: 67 | images = images.cuda() 68 | masks = masks.cuda() 69 | if phase == 'train': 70 | images, masks = Variable(images, volatile=False), Variable(masks, volatile=False) 71 | else: 72 | images, masks = Variable(images, volatile=True), Variable(masks, volatile=True) 73 | 74 | optimizer.zero_grad() 75 | 76 | #forward 77 | 78 | outputs = model(images) 79 | outputs = F.sigmoid(outputs)#remenber to apply sigmoid befor usage 80 | loss = weighted_BCELoss(outputs, masks, weights) 81 | 82 | #backword 83 | 84 | if phase == 'train': 85 | loss.backward() 86 | optimizer.step() 87 | 88 | # statistics 89 | running_loss += loss.data[0]*images.size(0) 90 | data_num += images.size(0) 91 | outputs = outputs.cpu().data 92 | masks = masks.cpu().data 93 | running_f1 += evaluate(masks, outputs)*images.size(0) 94 | 95 | #verbose 96 | if idx%5==0 and idx!=0: 97 | print('\r{} {:.2f}%'.format(phase, 100*idx/len(dataloaders[phase])), end='\r') 98 | 99 | #print() 100 | epoch_loss = running_loss / data_num 101 | epoch_f1 = running_f1 / data_num 102 | if phase == 'val': 103 | scheduler.step(epoch_loss) 104 | print('{} Loss: {:.4f} F1 score: {:.4f}'.format(phase, epoch_loss, epoch_f1)) 105 | # deep copy the model 106 | if phase == 'val' and epoch_f1 > best_f1: 107 | best_f1 = epoch_f1 108 | best_model_wts = copy.deepcopy(model.state_dict()) 109 | save_model(model, save_dir, model_name) 110 | 111 | print() 112 | 113 | time_elapsed = time.time() - since 114 | print('Training complete in {:.0f}m {:.0f}s'.format( 115 | time_elapsed // 60, time_elapsed % 60)) 116 | 117 | print('Best F1 score: {:.4f}'.format(best_f1)) 118 | 119 | # load best model weights 120 | model.load_state_dict(best_model_wts) 121 | return model 122 | 123 | if __name__ == '__main__': 124 | # dataset 125 | dataloaders = make_dataloaders(batch_size=batch_size) 126 | 127 | #model 128 | model = GCN(4, 512) 129 | if use_gpu: 130 | model = model.cuda() 131 | #model = torch.nn.DataParallel(model).cuda() 132 | model.load_state_dict(torch.load(os.path.join(save_dir, 'gcn_v5.pth'))) 133 | #training 134 | optimizer = optim.Adam(model.parameters(), lr = lr) 135 | scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True) 136 | model = train_model(model, num_epochs, dataloaders, optimizer, scheduler) 137 | 138 | #save 139 | save_model(model, save_dir, model_name) 140 | 141 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -* 2 | from dataset import IDRiD_sub1_dataset 3 | from model import GCN 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | from torch.utils.data import DataLoader, Dataset 8 | from torchvision import transforms 9 | import numpy as np 10 | from sklearn.metrics import precision_recall_fscore_support, confusion_matrix 11 | from PIL import Image 12 | import matplotlib.pyplot as plt 13 | import os 14 | 15 | use_gpu = torch.cuda.is_available 16 | save_dir = "./saved_models" 17 | model_name = "gcn_v5.pth" 18 | data_dir = './data/sub1/val' 19 | batch_size = 4 20 | 21 | def show_image_sample(): 22 | # dataset 23 | dataset = IDRiD_sub1_dataset(data_dir) 24 | 25 | 26 | #model 27 | model = GCN(4, 512) 28 | if use_gpu: 29 | model = model.cuda() 30 | #model = torch.nn.DataParallel(model).cuda() 31 | model.load_state_dict(torch.load(os.path.join(save_dir, model_name))) 32 | model.train(False) 33 | for n in range(12): 34 | #test 35 | full_image = np.zeros((3, 2848, 4288), dtype='float32') 36 | full_mask = np.zeros((4, 2848, 4288), dtype='float32') 37 | full_output = np.zeros((4, 2848, 4288), dtype='float32')#(C, H, W) 38 | title = '' 39 | for idx in range(9*6*n, 9*6*(n+1)): 40 | image, mask, name = dataset[idx] 41 | n = int(idx/(6*9))#image index 42 | r = int((idx%(6*9))/9)#row 43 | c = (idx%(6*9))%9#column 44 | title = name[:-8] 45 | 46 | if use_gpu: 47 | image = image.cuda() 48 | mask = mask.cuda() 49 | image, mask = Variable(image, volatile=True), Variable(mask, volatile=True) 50 | 51 | #forward 52 | output = model(image.unsqueeze(0)) 53 | output = F.sigmoid(output) 54 | output = output[0] 55 | if c < 8: 56 | if r == 5: 57 | full_output[:, r*512:r*512+512-224, c*512:c*512+512] = output.cpu().data.numpy()[:, :-224, :] 58 | full_mask[:, r*512:r*512+512-224, c*512:c*512+512] = mask.cpu().data.numpy()[:, :-224, :] 59 | full_image[:, r*512:r*512+512-224, c*512:c*512+512] = image.cpu().data.numpy()[:, :-224, :] 60 | 61 | else: 62 | full_output[:, r*512:r*512+512, c*512:c*512+512] = output.cpu().data.numpy() 63 | full_mask[:, r*512:r*512+512, c*512:c*512+512] = mask.cpu().data.numpy() 64 | full_image[:, r*512:r*512+512, c*512:c*512+512] = image.cpu().data.numpy() 65 | 66 | 67 | full_image = full_image.transpose(1, 2, 0) 68 | MA = full_output[0] 69 | EX = full_output[1] 70 | HE = full_output[2] 71 | SE = full_output[3] 72 | 73 | 74 | plt.figure() 75 | plt.axis('off') 76 | plt.suptitle(title) 77 | plt.subplot(331) 78 | plt.title('image') 79 | fig = plt.imshow(full_image) 80 | fig.axes.get_xaxis().set_visible(False) 81 | fig.axes.get_yaxis().set_visible(False) 82 | plt.subplot(332) 83 | plt.title('ground truth MA') 84 | fig = plt.imshow(full_mask[0]) 85 | fig.axes.get_xaxis().set_visible(False) 86 | fig.axes.get_yaxis().set_visible(False) 87 | plt.subplot(333) 88 | plt.title('ground truth EX') 89 | fig = plt.imshow(full_mask[1]) 90 | fig.axes.get_xaxis().set_visible(False) 91 | fig.axes.get_yaxis().set_visible(False) 92 | plt.subplot(334) 93 | plt.title('ground truth HE') 94 | fig = plt.imshow(full_mask[2]) 95 | fig.axes.get_xaxis().set_visible(False) 96 | fig.axes.get_yaxis().set_visible(False) 97 | plt.subplot(335) 98 | plt.title('ground truth SE') 99 | fig = plt.imshow(full_mask[3]) 100 | fig.axes.get_xaxis().set_visible(False) 101 | fig.axes.get_yaxis().set_visible(False) 102 | plt.subplot(336) 103 | plt.title('predict MA') 104 | fig = plt.imshow(MA) 105 | fig.axes.get_xaxis().set_visible(False) 106 | fig.axes.get_yaxis().set_visible(False) 107 | plt.subplot(337) 108 | plt.title('predict EX') 109 | fig = plt.imshow(EX) 110 | fig.axes.get_xaxis().set_visible(False) 111 | fig.axes.get_yaxis().set_visible(False) 112 | plt.subplot(338) 113 | plt.title('predict HE') 114 | fig = plt.imshow(HE) 115 | fig.axes.get_xaxis().set_visible(False) 116 | fig.axes.get_yaxis().set_visible(False) 117 | plt.subplot(339) 118 | plt.title('predict SE') 119 | fig = plt.imshow(SE) 120 | fig.axes.get_xaxis().set_visible(False) 121 | fig.axes.get_yaxis().set_visible(False) 122 | 123 | 124 | plt.show() 125 | 126 | class save_predict_dataset(Dataset): 127 | 128 | 129 | def __init__(self, root_dir): 130 | self.root_dir = root_dir 131 | self.data_idx = []#(image_dir, mask_dirs, name) mask_dirs is a list(None for NAR images) 132 | self.data_cache = {'image': None, 'name': "", 'index': None}#cache the original size image 133 | 134 | 135 | #Get the file index 136 | for filename in os.listdir(root_dir): 137 | image_dir = os.path.join(root_dir, filename) 138 | name = filename[:-4] 139 | self.data_idx.append((image_dir, name)) 140 | 141 | def __len__(self): 142 | return len(self.data_idx)*6*9 143 | 144 | def __getitem__(self, idx): 145 | # crop the 4288x2848 image into 512x512 => 9x6 grid 146 | # 1 image => 9x6 = 54 small images 147 | n = int(idx/(6*9))#image index 148 | r = int((idx%(6*9))/9)#row 149 | c = (idx%(6*9))%9#column 150 | 151 | #Load the images if it's not in the cache 152 | if self.data_cache['index'] != n: 153 | image_dir, name = self.data_idx[n] 154 | image = Image.open(image_dir) 155 | 156 | self.data_cache = {'image': image, 'name': name, 'index': n} 157 | 158 | 159 | #crop the image 160 | 161 | image_crop = self.data_cache['image'].crop((c*512, r*512, c*512 + 512, r*512 + 512)) 162 | image_crop = transforms.ToTensor()(image_crop) 163 | name = self.data_cache['name'] 164 | 165 | return image_crop, name 166 | 167 | def save_output(root_dir, output_dir): 168 | # dataset 169 | dataset = save_predict_dataset(root_dir) 170 | 171 | 172 | #model 173 | model = GCN(4, 512) 174 | if use_gpu: 175 | model = model.cuda() 176 | #model = torch.nn.DataParallel(model).cuda() 177 | model.load_state_dict(torch.load(os.path.join(save_dir, model_name))) 178 | model.train(False) 179 | for n in range(int(len(dataset)/(6*9))): 180 | #test 181 | full_output = np.zeros((4, 2848, 4288), dtype='float32')#(C, H, W) 182 | title = '' 183 | for idx in range(6*9*n, 6*9*(n+1)): 184 | image, name = dataset[idx] 185 | r = int((idx%(6*9))/9)#row 186 | c = (idx%(6*9))%9#column 187 | title = name 188 | 189 | if use_gpu: 190 | image = image.cuda() 191 | image = Variable(image, volatile=True) 192 | 193 | #forward 194 | output = model(image.unsqueeze(0)) 195 | output = F.sigmoid(output) 196 | output = output[0] 197 | 198 | if c < 8: 199 | if r == 5: 200 | full_output[:, r*512:r*512+512-224, c*512:c*512+512] = output.cpu().data.numpy()[:, :-224, :] 201 | else: 202 | full_output[:, r*512:r*512+512, c*512:c*512+512] = output.cpu().data.numpy() 203 | 204 | for i, d in enumerate(['MA', 'EX', 'HE', 'SE']): 205 | if not os.path.exists(os.path.join(output_dir, d)): 206 | os.makedirs(os.path.join(output_dir, d)) 207 | im = np.expand_dims(full_output[i], axis=0).transpose(1, 2, 0) 208 | im = full_output[i]*255 209 | im = np.uint8(im) 210 | im = Image.fromarray(im) 211 | im.save(os.path.join(output_dir, d, title+'.jpg')) 212 | 213 | 214 | def run_statistic(threshold): 215 | ''' 216 | evaluate on small images result 217 | ''' 218 | # dataset 219 | dataset = IDRiD_sub1_dataset(data_dir) 220 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4) 221 | #print('Data: %d'%(len(dataset))) 222 | 223 | #model 224 | model = GCN(4, 512) 225 | if use_gpu: 226 | model = model.cuda() 227 | #model = torch.nn.DataParallel(model).cuda() 228 | model.load_state_dict(torch.load(os.path.join(save_dir, model_name))) 229 | model.train(False) 230 | for i in range(4): 231 | y_pred_list = [] 232 | y_true_list = [] 233 | for idx, data in enumerate(dataloader): 234 | images, masks, names = data 235 | 236 | if use_gpu: 237 | images = images.cuda() 238 | masks = masks.cuda() 239 | images, masks = Variable(images, volatile=True), Variable(masks, volatile=True) 240 | 241 | #forward 242 | outputs = model(images) 243 | 244 | # statistics 245 | outputs = F.sigmoid(outputs).cpu().data#remenber to apply sigmoid befor usage 246 | masks = masks.cpu().data 247 | #for i in range(len(outputs)): 248 | y_pred = outputs[i] 249 | y_true = masks[i] 250 | y_pred = y_pred.numpy().flatten() 251 | y_pred = np.where(y_pred > threshold, 1, 0) 252 | y_true = y_true.numpy().flatten() 253 | y_pred_list.append(y_pred) 254 | y_true_list.append(y_true) 255 | 256 | #verbose 257 | if idx%5==0 and idx!=0: 258 | print('\r{:.2f}%'.format(100*idx/len(dataloader)), end='\r') 259 | #print() 260 | type_list = ['MA', 'EX', 'HE', 'SE'] 261 | precision, recall, f1, _ = precision_recall_fscore_support(np.array(y_true_list).flatten(), np.array(y_pred_list).flatten(), average='binary') 262 | print('{} \nThreshold: {:.2f}\nPrecision: {:.4f}\nRecall: {:.4f}\nF1: {:.4f}'.format(type_list[i], threshold, precision, recall, f1)) 263 | 264 | def evaluate(threshold): 265 | ''' 266 | evaluate results with original image size 267 | ''' 268 | task_type_list = ['MA', 'EX', 'HE', 'SE'] 269 | result = [] 270 | print('-------------------') 271 | for i in range(4): 272 | print('--------') 273 | mean = [0, 0, 0, 0] 274 | 275 | for filename in os.listdir('./data/sub1/val/Apparent Retinopathy'): 276 | gt_dirs = {task_type:None for task_type in task_type_list} 277 | for task_type in task_type_list: 278 | m_dir = os.path.join('./data/sub1/val', task_type, filename[:-4]+'_'+task_type+'.tif') 279 | if os.path.isfile(m_dir): gt_dirs[task_type] = m_dir 280 | pd_dirs = {task_type:None for task_type in task_type_list} 281 | for task_type in task_type_list: 282 | m_dir = os.path.join('./data/sub1/predict', task_type, filename[:-4]+'.jpg') 283 | if os.path.isfile(m_dir): pd_dirs[task_type] = m_dir 284 | gts = [] 285 | for task_type in task_type_list: 286 | mask = Image.open(gt_dirs[task_type]) 287 | mask = np.array(mask, dtype='float32') 288 | gts.append(mask) 289 | gts = np.array(gts[i]) 290 | pds = [] 291 | for task_type in task_type_list: 292 | mask = Image.open(pd_dirs[task_type]) 293 | mask = np.array(mask, dtype='float32') 294 | mask /= 255 295 | pds.append(mask) 296 | pds = np.array(pds[i]) 297 | pds = np.where(pds > threshold, 1, 0) 298 | tn, fp, fn, tp = confusion_matrix(gts.flatten(), pds.flatten()).ravel() 299 | ppv = tp/(tp+fp) 300 | sensitivity = tp/(tp+fn) 301 | specificity = tn/(tn+fp) 302 | f1 = (2*tp)/(2*tp+fp+fn) 303 | 304 | mean[0] += ppv 305 | mean[1] += sensitivity 306 | mean[2] += specificity 307 | mean[3] += f1 308 | 309 | for filename in os.listdir('./data/sub1/val/No Apparent Retinopathy'): 310 | gt_dirs = {task_type:None for task_type in task_type_list} 311 | for task_type in task_type_list: 312 | m_dir = os.path.join('./data/sub1/val', task_type, filename[:-4]+'_'+task_type+'.tif') 313 | if os.path.isfile(m_dir): gt_dirs[task_type] = m_dir 314 | pd_dirs = {task_type:None for task_type in task_type_list} 315 | for task_type in task_type_list: 316 | m_dir = os.path.join('./data/sub1/predict', task_type, filename[:-4]+'.jpg') 317 | if os.path.isfile(m_dir): pd_dirs[task_type] = m_dir 318 | 319 | gts = [] 320 | for task_type in task_type_list: 321 | mask = mask = np.zeros((2848, 4288), dtype='float32') 322 | gts.append(mask) 323 | gts = np.array(gts[i]) 324 | 325 | pds = [] 326 | for task_type in task_type_list: 327 | mask = Image.open(pd_dirs[task_type]) 328 | mask = np.array(mask, dtype='float32') 329 | mask /= 255 330 | pds.append(mask) 331 | pds = np.array(pds[i]) 332 | pds = np.where(pds > threshold, 1, 0) 333 | 334 | 335 | try: 336 | tn, fp, fn, tp = confusion_matrix(gts.flatten(), pds.flatten()).ravel() 337 | ppv = 0 338 | sensitivity = 0 339 | specificity = tn/(tn+fp) 340 | f1 = 0 341 | except: 342 | ppv = 0 343 | sensitivity = 0 344 | specificity = 0 345 | f1 = 0 346 | 347 | mean[0] += ppv 348 | mean[1] += sensitivity 349 | mean[2] += specificity 350 | mean[3] += f1 351 | 352 | print(task_type_list[i]) 353 | print('Threshold: {:.2f}\nPPV: {:.4f}\nSensitivity: {:.4f}\nSpecificity: {:.4f}\nF1: {:.4f}'.format(threshold, mean[0]/6, mean[1]/6, mean[2]/12, mean[3]/6)) 354 | result.append((mean[0]/6, mean[1]/6, mean[2]/12, mean[3]/6)) 355 | 356 | avg = [0, 0, 0, 0] 357 | for r in result: 358 | for i in range(4): 359 | avg[i]+=r[i] 360 | 361 | print("---------") 362 | print('Average') 363 | print('Threshold: {:.2f}\nPPV: {:.4f}\nSensitivity: {:.4f}\nSpecificity: {:.4f}\nF1: {:.4f}'.format(threshold, avg[0]/4, avg[1]/4, avg[2]/4, avg[3]/4)) 364 | 365 | if __name__ == '__main__': 366 | #save_output('./data/sub1/val/Apparent Retinopathy', '../data/sub1/predict') 367 | #print(model_name) 368 | #run_statistic(0.3) 369 | show_image_sample() 370 | ''' 371 | for th in [0.3]: 372 | evaluate(th) 373 | ''' 374 | 375 | --------------------------------------------------------------------------------