├── README.md
├── UNet.sh
├── data_loader.py
├── dataset.py
├── evaluation.py
├── img
│   ├── AttR2U-Net.png
│   ├── AttU-Net.png
│   ├── Evaluation.png
│   ├── R2U-Net.png
│   └── U-Net.png
├── main.py
├── misc.py
├── network.py
└── solver.py

/README.md:
--------------------------------------------------------------------------------
 1 | ### PyTorch implementation of U-Net, R2U-Net, Attention U-Net, and Attention R2U-Net
 2 | 
 3 | **(This repository is no longer being updated)**
 4 | 
 5 | **U-Net: Convolutional Networks for Biomedical Image Segmentation**
 6 | 
 7 | https://arxiv.org/abs/1505.04597
 8 | 
 9 | **Recurrent Residual Convolutional Neural Network based on U-Net (R2U-Net) for Medical Image Segmentation**
10 | 
11 | https://arxiv.org/abs/1802.06955
12 | 
13 | **Attention U-Net: Learning Where to Look for the Pancreas**
14 | 
15 | https://arxiv.org/abs/1804.03999
16 | 
17 | **Attention R2U-Net: a straightforward combination of two recent works (R2U-Net + Attention U-Net)**
18 | 
19 | 
20 | ## U-Net
21 | ![U-Net](/img/U-Net.png)
22 | 
23 | 
24 | ## R2U-Net
25 | ![R2U-Net](/img/R2U-Net.png)
26 | 
27 | ## Attention U-Net
28 | ![AttU-Net](/img/AttU-Net.png)
29 | 
30 | ## Attention R2U-Net
31 | ![AttR2U-Net](/img/AttR2U-Net.png)
32 | 
33 | ## Evaluation
34 | We evaluated the models on the [ISIC 2018 dataset](https://challenge2018.isic-archive.com/task1/training/). The dataset was split into training, validation, and test sets in proportions of 70%, 10%, and 20%, respectively.
35 | Of the 2594 images in total, 1815 were used for training, 259 for validation, and 520 for testing.
36 | 
37 | ![evaluation](/img/Evaluation.png)
38 | 
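The counts above follow the split logic in `dataset.py` (the train and valid counts are floored, and the remainder becomes the test set). Note that the CLI defaults in `dataset.py` are 0.6/0.2/0.2, so reproducing the 70/10/20 split reported here presumably requires passing `--train_ratio=0.7 --valid_ratio=0.1 --test_ratio=0.2`. A minimal sketch of the arithmetic:

```python
# Mirrors dataset.py: floor the train/valid counts, give the remainder to test.
num_total = 2594
num_train = int(0.7 * num_total)               # 1815
num_valid = int(0.1 * num_total)               # 259
num_test  = num_total - num_train - num_valid  # 520
```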
--------------------------------------------------------------------------------
/UNet.sh:
--------------------------------------------------------------------------------
 1 | 
 2 | for ((i=0;i<100;i++));do
 3 |     a='U_Net'
 4 |     python3 main.py --model_type=$a
 5 |     a='R2U_Net'
 6 |     python3 main.py --model_type=$a
 7 |     a='AttU_Net'
 8 |     python3 main.py --model_type=$a
 9 |     a='R2AttU_Net'
10 |     python3 main.py --model_type=$a
11 | 
12 | done
13 | 
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import random
  3 | from random import shuffle
  4 | import numpy as np
  5 | import torch
  6 | from torch.utils import data
  7 | from torchvision import transforms as T
  8 | from torchvision.transforms import functional as F
  9 | from PIL import Image
 10 | 
 11 | class ImageFolder(data.Dataset):
 12 |     def __init__(self, root,image_size=224,mode='train',augmentation_prob=0.4):
 13 |         """Initializes image paths and preprocessing module."""
 14 |         self.root = root
 15 | 
 16 |         # GT : Ground Truth
 17 |         self.GT_paths = root[:-1]+'_GT/'
 18 |         self.image_paths = list(map(lambda x: os.path.join(root, x), os.listdir(root)))
 19 |         self.image_size = image_size
 20 |         self.mode = mode
 21 |         self.RotationDegree = [0,90,180,270]
 22 |         self.augmentation_prob = augmentation_prob
 23 |         print("image count in {} path :{}".format(self.mode,len(self.image_paths)))
 24 | 
 25 |     def __getitem__(self, index):
 26 |         """Reads an image and its ground truth from file, preprocesses them, and returns the pair."""
 27 |         image_path = self.image_paths[index]
 28 |         filename = image_path.split('_')[-1][:-len(".jpg")]
 29 |         GT_path = self.GT_paths + 'ISIC_' + filename + '_segmentation.png'
 30 | 
 31 |         image = Image.open(image_path)
 32 |         GT = Image.open(GT_path)
 33 | 
 34 |         aspect_ratio = image.size[1]/image.size[0]
 35 | 
 36 |         Transform = []
 37 | 
 38 |         ResizeRange = random.randint(300,320)
 39 |         Transform.append(T.Resize((int(ResizeRange*aspect_ratio),ResizeRange)))
 40 |         p_transform = random.random()
 41 | 
 42 |         if (self.mode == 'train') and p_transform <= self.augmentation_prob:
 43 |             RotationDegree = random.randint(0,3)
 44 |             RotationDegree = self.RotationDegree[RotationDegree]
 45 |             if (RotationDegree == 90) or (RotationDegree == 270):
 46 |                 aspect_ratio = 1/aspect_ratio
 47 | 
 48 |             Transform.append(T.RandomRotation((RotationDegree,RotationDegree)))
 49 | 
 50 |             RotationRange = random.randint(-10,10)
 51 |             Transform.append(T.RandomRotation((RotationRange,RotationRange)))
 52 |             CropRange = random.randint(250,270)
 53 |             Transform.append(T.CenterCrop((int(CropRange*aspect_ratio),CropRange)))
 54 |             Transform = T.Compose(Transform)
 55 | 
 56 |             image = Transform(image)
 57 |             GT = Transform(GT)
 58 | 
 59 |             ShiftRange_left = random.randint(0,20)
 60 |             ShiftRange_upper = random.randint(0,20)
 61 |             ShiftRange_right = image.size[0] - random.randint(0,20)
 62 |             ShiftRange_lower = image.size[1] - random.randint(0,20)
 63 |             image = image.crop(box=(ShiftRange_left,ShiftRange_upper,ShiftRange_right,ShiftRange_lower))
 64 |             GT = GT.crop(box=(ShiftRange_left,ShiftRange_upper,ShiftRange_right,ShiftRange_lower))
 65 | 
 66 |             if random.random() < 0.5:
 67 |                 image = F.hflip(image)
 68 |                 GT = F.hflip(GT)
 69 | 
 70 |             if random.random() < 0.5:
 71 |                 image = F.vflip(image)
 72 |                 GT = F.vflip(GT)
 73 | 
 74 |             Transform = T.ColorJitter(brightness=0.2,contrast=0.2,hue=0.02)
 75 | 
 76 |             image = Transform(image)
 77 | 
 78 |         Transform =[]
 79 | 
 80 | 
 81 |         Transform.append(T.Resize((int(256*aspect_ratio)-int(256*aspect_ratio)%16,256)))
 82 |         Transform.append(T.ToTensor())
 83 |         Transform = T.Compose(Transform)
 84 | 
 85 |         image = Transform(image)
 86 |         GT = Transform(GT)
 87 | 
 88 |         Norm_ = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
 89 |         image = Norm_(image)
 90 | 
 91 |         return image, GT
 92 | 
 93 |     def __len__(self):
 94 |         """Returns the total number of images in the dataset."""
 95 |         return len(self.image_paths)
 96 | 
 97 | def get_loader(image_path, image_size, batch_size, num_workers=2, mode='train',augmentation_prob=0.4):
 98 |     """Builds and returns Dataloader."""
 99 | 
100 |     dataset = ImageFolder(root = image_path, image_size =image_size, mode=mode,augmentation_prob=augmentation_prob)
101 |     data_loader = data.DataLoader(dataset=dataset,
102 |                                   batch_size=batch_size,
103 |                                   shuffle=True,
104 |                                   num_workers=num_workers)
105 |     return data_loader
106 | 
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import argparse
  3 | import random
  4 | import shutil
  5 | from shutil import copyfile
  6 | from misc import printProgressBar
  7 | 
  8 | 
  9 | def rm_mkdir(dir_path):
 10 |     if os.path.exists(dir_path):
 11 |         shutil.rmtree(dir_path)
 12 |         print('Remove path - %s'%dir_path)
 13 |     os.makedirs(dir_path)
 14 |     print('Create path - %s'%dir_path)
 15 | 
 16 | def main(config):
 17 | 
 18 |     rm_mkdir(config.train_path)
 19 |     rm_mkdir(config.train_GT_path)
 20 |     rm_mkdir(config.valid_path)
 21 |     rm_mkdir(config.valid_GT_path)
 22 |     rm_mkdir(config.test_path)
 23 |     rm_mkdir(config.test_GT_path)
 24 | 
 25 |     filenames = os.listdir(config.origin_data_path)
 26 |     data_list = []
 27 |     GT_list = []
 28 | 
 29 |     for filename in filenames:
 30 |         ext = os.path.splitext(filename)[-1]
 31 |         if ext =='.jpg':
 32 |             filename = filename.split('_')[-1][:-len('.jpg')]
 33 | 
data_list.append('ISIC_'+filename+'.jpg') 34 | GT_list.append('ISIC_'+filename+'_segmentation.png') 35 | 36 | num_total = len(data_list) 37 | num_train = int((config.train_ratio/(config.train_ratio+config.valid_ratio+config.test_ratio))*num_total) 38 | num_valid = int((config.valid_ratio/(config.train_ratio+config.valid_ratio+config.test_ratio))*num_total) 39 | num_test = num_total - num_train - num_valid 40 | 41 | print('\nNum of train set : ',num_train) 42 | print('\nNum of valid set : ',num_valid) 43 | print('\nNum of test set : ',num_test) 44 | 45 | Arange = list(range(num_total)) 46 | random.shuffle(Arange) 47 | 48 | for i in range(num_train): 49 | idx = Arange.pop() 50 | 51 | src = os.path.join(config.origin_data_path, data_list[idx]) 52 | dst = os.path.join(config.train_path,data_list[idx]) 53 | copyfile(src, dst) 54 | 55 | src = os.path.join(config.origin_GT_path, GT_list[idx]) 56 | dst = os.path.join(config.train_GT_path, GT_list[idx]) 57 | copyfile(src, dst) 58 | 59 | printProgressBar(i + 1, num_train, prefix = 'Producing train set:', suffix = 'Complete', length = 50) 60 | 61 | 62 | for i in range(num_valid): 63 | idx = Arange.pop() 64 | 65 | src = os.path.join(config.origin_data_path, data_list[idx]) 66 | dst = os.path.join(config.valid_path,data_list[idx]) 67 | copyfile(src, dst) 68 | 69 | src = os.path.join(config.origin_GT_path, GT_list[idx]) 70 | dst = os.path.join(config.valid_GT_path, GT_list[idx]) 71 | copyfile(src, dst) 72 | 73 | printProgressBar(i + 1, num_valid, prefix = 'Producing valid set:', suffix = 'Complete', length = 50) 74 | 75 | for i in range(num_test): 76 | idx = Arange.pop() 77 | 78 | src = os.path.join(config.origin_data_path, data_list[idx]) 79 | dst = os.path.join(config.test_path,data_list[idx]) 80 | copyfile(src, dst) 81 | 82 | src = os.path.join(config.origin_GT_path, GT_list[idx]) 83 | dst = os.path.join(config.test_GT_path, GT_list[idx]) 84 | copyfile(src, dst) 85 | 86 | 87 | printProgressBar(i + 1, num_test, prefix = 'Producing test set:', suffix = 'Complete', length = 50) 88 | 89 | if __name__ == '__main__': 90 | parser = argparse.ArgumentParser() 91 | 92 | 93 | # model hyper-parameters 94 | parser.add_argument('--train_ratio', type=float, default=0.6) 95 | parser.add_argument('--valid_ratio', type=float, default=0.2) 96 | parser.add_argument('--test_ratio', type=float, default=0.2) 97 | 98 | # data path 99 | parser.add_argument('--origin_data_path', type=str, default='../ISIC/dataset/ISIC2018_Task1-2_Training_Input') 100 | parser.add_argument('--origin_GT_path', type=str, default='../ISIC/dataset/ISIC2018_Task1_Training_GroundTruth') 101 | 102 | parser.add_argument('--train_path', type=str, default='./dataset/train/') 103 | parser.add_argument('--train_GT_path', type=str, default='./dataset/train_GT/') 104 | parser.add_argument('--valid_path', type=str, default='./dataset/valid/') 105 | parser.add_argument('--valid_GT_path', type=str, default='./dataset/valid_GT/') 106 | parser.add_argument('--test_path', type=str, default='./dataset/test/') 107 | parser.add_argument('--test_GT_path', type=str, default='./dataset/test_GT/') 108 | 109 | config = parser.parse_args() 110 | print(config) 111 | main(config) -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # SR : Segmentation Result 4 | # GT : Ground Truth 5 | 6 | def get_accuracy(SR,GT,threshold=0.5): 7 | SR = SR > threshold 8 | GT = GT == 
torch.max(GT)
  9 |     corr = torch.sum(SR==GT)
 10 |     tensor_size = SR.size(0)*SR.size(1)*SR.size(2)*SR.size(3)
 11 |     acc = float(corr)/float(tensor_size)
 12 | 
 13 |     return acc
 14 | 
 15 | def get_sensitivity(SR,GT,threshold=0.5):
 16 |     # Sensitivity == Recall
 17 |     SR = SR > threshold
 18 |     GT = GT == torch.max(GT)
 19 | 
 20 |     # TP : True Positive
 21 |     # FN : False Negative
 22 |     TP = SR & GT        # SR/GT are bool tensors here, so use logical ops;
 23 |     FN = (~SR) & GT     # the old `(x+y)==2` idiom breaks on bool tensors
 24 | 
 25 |     SE = float(torch.sum(TP))/(float(torch.sum(TP) + torch.sum(FN)) + 1e-6)
 26 | 
 27 |     return SE
 28 | 
 29 | def get_specificity(SR,GT,threshold=0.5):
 30 |     SR = SR > threshold
 31 |     GT = GT == torch.max(GT)
 32 | 
 33 |     # TN : True Negative
 34 |     # FP : False Positive
 35 |     TN = (~SR) & (~GT)
 36 |     FP = SR & (~GT)
 37 | 
 38 |     SP = float(torch.sum(TN))/(float(torch.sum(TN) + torch.sum(FP)) + 1e-6)
 39 | 
 40 |     return SP
 41 | 
 42 | def get_precision(SR,GT,threshold=0.5):
 43 |     SR = SR > threshold
 44 |     GT = GT == torch.max(GT)
 45 | 
 46 |     # TP : True Positive
 47 |     # FP : False Positive
 48 |     TP = SR & GT
 49 |     FP = SR & (~GT)
 50 | 
 51 |     PC = float(torch.sum(TP))/(float(torch.sum(TP) + torch.sum(FP)) + 1e-6)
 52 | 
 53 |     return PC
 54 | 
 55 | def get_F1(SR,GT,threshold=0.5):
 56 |     # Sensitivity == Recall
 57 |     SE = get_sensitivity(SR,GT,threshold=threshold)
 58 |     PC = get_precision(SR,GT,threshold=threshold)
 59 | 
 60 |     F1 = 2*SE*PC/(SE+PC + 1e-6)
 61 | 
 62 |     return F1
 63 | 
 64 | def get_JS(SR,GT,threshold=0.5):
 65 |     # JS : Jaccard similarity
 66 |     SR = SR > threshold
 67 |     GT = GT == torch.max(GT)
 68 | 
 69 |     Inter = torch.sum(SR & GT)
 70 |     Union = torch.sum(SR | GT)
 71 | 
 72 |     JS = float(Inter)/(float(Union) + 1e-6)
 73 | 
 74 |     return JS
 75 | 
 76 | def get_DC(SR,GT,threshold=0.5):
 77 |     # DC : Dice Coefficient
 78 |     SR = SR > threshold
 79 |     GT = GT == torch.max(GT)
 80 | 
 81 |     Inter = torch.sum(SR & GT)
 82 |     DC = float(2*Inter)/(float(torch.sum(SR)+torch.sum(GT)) + 1e-6)
 83 | 
 84 |     return DC
 85 | 
 86 | 
 87 | 
 88 | 
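For reference, a minimal sketch of how these metrics are consumed downstream (mirroring solver.py, which applies a sigmoid to the network output before scoring; the `model`, `images`, and `GT` names are assumed to be defined as in solver.py):

```python
import torch
from evaluation import get_accuracy, get_F1, get_JS, get_DC

# SR must hold probabilities in [0, 1] so that threshold=0.5 is meaningful.
with torch.no_grad():
    SR = torch.sigmoid(model(images))   # shape (B, 1, H, W)
print('Acc: %.4f  F1: %.4f  JS: %.4f  DC: %.4f'
      % (get_accuracy(SR, GT), get_F1(SR, GT), get_JS(SR, GT), get_DC(SR, GT)))
```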
--------------------------------------------------------------------------------
/img/AttR2U-Net.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeeJunHyun/Image_Segmentation/5e9da9395c52b119d55dfc6532c34ac0e88f446e/img/AttR2U-Net.png
--------------------------------------------------------------------------------
/img/AttU-Net.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeeJunHyun/Image_Segmentation/5e9da9395c52b119d55dfc6532c34ac0e88f446e/img/AttU-Net.png
--------------------------------------------------------------------------------
/img/Evaluation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeeJunHyun/Image_Segmentation/5e9da9395c52b119d55dfc6532c34ac0e88f446e/img/Evaluation.png
--------------------------------------------------------------------------------
/img/R2U-Net.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeeJunHyun/Image_Segmentation/5e9da9395c52b119d55dfc6532c34ac0e88f446e/img/R2U-Net.png
--------------------------------------------------------------------------------
/img/U-Net.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeeJunHyun/Image_Segmentation/5e9da9395c52b119d55dfc6532c34ac0e88f446e/img/U-Net.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
  1 | import argparse
  2 | import os
  3 | from solver import Solver
  4 | from data_loader import get_loader
  5 | from torch.backends import cudnn
  6 | import random
  7 | 
  8 | def main(config):
  9 |     cudnn.benchmark = True
 10 |     if config.model_type not in ['U_Net','R2U_Net','AttU_Net','R2AttU_Net']:
 11 |         print('ERROR!! model_type should be selected in U_Net/R2U_Net/AttU_Net/R2AttU_Net')
 12 |         print('Your input for model_type was %s'%config.model_type)
 13 |         return
 14 | 
 15 |     # Create directories if not exist
 16 |     if not os.path.exists(config.model_path):
 17 |         os.makedirs(config.model_path)
 18 |     if not os.path.exists(config.result_path):
 19 |         os.makedirs(config.result_path)
 20 |     config.result_path = os.path.join(config.result_path,config.model_type)
 21 |     if not os.path.exists(config.result_path):
 22 |         os.makedirs(config.result_path)
 23 |     # Random hyper-parameter search: these overwrite the CLI values above (UNet.sh re-runs main.py 100 times per model)
 24 |     lr = random.random()*0.0005 + 0.0000005
 25 |     augmentation_prob= random.random()*0.7
 26 |     epoch = random.choice([100,150,200,250])
 27 |     decay_ratio = random.random()*0.8
 28 |     decay_epoch = int(epoch*decay_ratio)
 29 | 
 30 |     config.augmentation_prob = augmentation_prob
 31 |     config.num_epochs = epoch
 32 |     config.lr = lr
 33 |     config.num_epochs_decay = decay_epoch
 34 | 
 35 |     print(config)
 36 | 
 37 |     train_loader = get_loader(image_path=config.train_path,
 38 |                               image_size=config.image_size,
 39 |                               batch_size=config.batch_size,
 40 |                               num_workers=config.num_workers,
 41 |                               mode='train',
 42 |                               augmentation_prob=config.augmentation_prob)
 43 |     valid_loader = get_loader(image_path=config.valid_path,
 44 |                               image_size=config.image_size,
 45 |                               batch_size=config.batch_size,
 46 |                               num_workers=config.num_workers,
 47 |                               mode='valid',
 48 |                               augmentation_prob=0.)
 49 |     test_loader = get_loader(image_path=config.test_path,
 50 |                               image_size=config.image_size,
 51 |                               batch_size=config.batch_size,
 52 |                               num_workers=config.num_workers,
 53 |                               mode='test',
 54 |                               augmentation_prob=0.)
 55 | 
 56 |     solver = Solver(config, train_loader, valid_loader, test_loader)
 57 | 
 58 | 
 59 |     # Train and evaluate the model
 60 |     if config.mode == 'train':
 61 |         solver.train()
 62 |     elif config.mode == 'test':
 63 |         solver.test()   # note: Solver defines no test(); train() already evaluates on the test split
 64 | 
 65 | 
 66 | if __name__ == '__main__':
 67 |     parser = argparse.ArgumentParser()
 68 | 
 69 | 
 70 |     # model hyper-parameters
 71 |     parser.add_argument('--image_size', type=int, default=224)
 72 |     parser.add_argument('--t', type=int, default=3, help='t for Recurrent step of R2U_Net or R2AttU_Net')
 73 | 
 74 |     # training hyper-parameters
 75 |     parser.add_argument('--img_ch', type=int, default=3)
 76 |     parser.add_argument('--output_ch', type=int, default=1)
 77 |     parser.add_argument('--num_epochs', type=int, default=100)
 78 |     parser.add_argument('--num_epochs_decay', type=int, default=70)
 79 |     parser.add_argument('--batch_size', type=int, default=1)
 80 |     parser.add_argument('--num_workers', type=int, default=8)
 81 |     parser.add_argument('--lr', type=float, default=0.0002)
 82 |     parser.add_argument('--beta1', type=float, default=0.5)        # momentum1 in Adam
 83 |     parser.add_argument('--beta2', type=float, default=0.999)      # momentum2 in Adam
 84 |     parser.add_argument('--augmentation_prob', type=float, default=0.4)
 85 | 
 86 |     parser.add_argument('--log_step', type=int, default=2)
 87 |     parser.add_argument('--val_step', type=int, default=2)
 88 | 
 89 |     # misc
 90 |     parser.add_argument('--mode', type=str, default='train')
 91 |     parser.add_argument('--model_type', type=str, default='U_Net', help='U_Net/R2U_Net/AttU_Net/R2AttU_Net')
 92 |     parser.add_argument('--model_path', type=str, default='./models')
 93 |     parser.add_argument('--train_path', type=str, default='./dataset/train/')
 94 |     parser.add_argument('--valid_path', type=str, default='./dataset/valid/')
 95 |     parser.add_argument('--test_path', type=str, default='./dataset/test/')
 96 |     parser.add_argument('--result_path', type=str, default='./result/')
 97 | 
 98 |     parser.add_argument('--cuda_idx', type=int, default=1)
 99 | 
100 |     config = parser.parse_args()
101 |     main(config)
102 | 
--------------------------------------------------------------------------------
/misc.py:
--------------------------------------------------------------------------------
  1 | 
  2 | def printProgressBar (iteration, total, prefix = '', suffix = '', decimals = 1, length = 100, fill = '█'):
  3 |     """
  4 |     Call in a loop to create terminal progress bar
  5 |     @params:
  6 |         iteration   - Required  : current iteration (Int)
  7 |         total       - Required  : total iterations (Int)
  8 |         prefix      - Optional  : prefix string (Str)
  9 |         suffix      - Optional  : suffix string (Str)
 10 |         decimals    - Optional  : positive number of decimals in percent complete (Int)
 11 |         length      - Optional  : character length of bar (Int)
 12 |         fill        - Optional  : bar fill character (Str)
 13 |     """
 14 |     percent = ("{0:."
+ str(decimals) + "f}").format(100 * (iteration / float(total))) 15 | filledLength = int(length * iteration // total) 16 | bar = fill * filledLength + '-' * (length - filledLength) 17 | print('\r%s |%s| %s%% %s' % (prefix, bar, percent, suffix), end = '\r') 18 | # Print New Line on Complete 19 | if iteration == total: 20 | print() -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import init 5 | 6 | def init_weights(net, init_type='normal', gain=0.02): 7 | def init_func(m): 8 | classname = m.__class__.__name__ 9 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 10 | if init_type == 'normal': 11 | init.normal_(m.weight.data, 0.0, gain) 12 | elif init_type == 'xavier': 13 | init.xavier_normal_(m.weight.data, gain=gain) 14 | elif init_type == 'kaiming': 15 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 16 | elif init_type == 'orthogonal': 17 | init.orthogonal_(m.weight.data, gain=gain) 18 | else: 19 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 20 | if hasattr(m, 'bias') and m.bias is not None: 21 | init.constant_(m.bias.data, 0.0) 22 | elif classname.find('BatchNorm2d') != -1: 23 | init.normal_(m.weight.data, 1.0, gain) 24 | init.constant_(m.bias.data, 0.0) 25 | 26 | print('initialize network with %s' % init_type) 27 | net.apply(init_func) 28 | 29 | class conv_block(nn.Module): 30 | def __init__(self,ch_in,ch_out): 31 | super(conv_block,self).__init__() 32 | self.conv = nn.Sequential( 33 | nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True), 34 | nn.BatchNorm2d(ch_out), 35 | nn.ReLU(inplace=True), 36 | nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True), 37 | nn.BatchNorm2d(ch_out), 38 | nn.ReLU(inplace=True) 39 | ) 40 | 41 | 42 | def forward(self,x): 43 | x = self.conv(x) 44 | return x 45 | 46 | class up_conv(nn.Module): 47 | def __init__(self,ch_in,ch_out): 48 | super(up_conv,self).__init__() 49 | self.up = nn.Sequential( 50 | nn.Upsample(scale_factor=2), 51 | nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True), 52 | nn.BatchNorm2d(ch_out), 53 | nn.ReLU(inplace=True) 54 | ) 55 | 56 | def forward(self,x): 57 | x = self.up(x) 58 | return x 59 | 60 | class Recurrent_block(nn.Module): 61 | def __init__(self,ch_out,t=2): 62 | super(Recurrent_block,self).__init__() 63 | self.t = t 64 | self.ch_out = ch_out 65 | self.conv = nn.Sequential( 66 | nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1,bias=True), 67 | nn.BatchNorm2d(ch_out), 68 | nn.ReLU(inplace=True) 69 | ) 70 | 71 | def forward(self,x): 72 | for i in range(self.t): 73 | 74 | if i==0: 75 | x1 = self.conv(x) 76 | 77 | x1 = self.conv(x+x1) 78 | return x1 79 | 80 | class RRCNN_block(nn.Module): 81 | def __init__(self,ch_in,ch_out,t=2): 82 | super(RRCNN_block,self).__init__() 83 | self.RCNN = nn.Sequential( 84 | Recurrent_block(ch_out,t=t), 85 | Recurrent_block(ch_out,t=t) 86 | ) 87 | self.Conv_1x1 = nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=1,padding=0) 88 | 89 | def forward(self,x): 90 | x = self.Conv_1x1(x) 91 | x1 = self.RCNN(x) 92 | return x+x1 93 | 94 | 95 | class single_conv(nn.Module): 96 | def __init__(self,ch_in,ch_out): 97 | super(single_conv,self).__init__() 98 | self.conv = nn.Sequential( 99 | nn.Conv2d(ch_in, ch_out, 
kernel_size=3,stride=1,padding=1,bias=True), 100 | nn.BatchNorm2d(ch_out), 101 | nn.ReLU(inplace=True) 102 | ) 103 | 104 | def forward(self,x): 105 | x = self.conv(x) 106 | return x 107 | 108 | class Attention_block(nn.Module): 109 | def __init__(self,F_g,F_l,F_int): 110 | super(Attention_block,self).__init__() 111 | self.W_g = nn.Sequential( 112 | nn.Conv2d(F_g, F_int, kernel_size=1,stride=1,padding=0,bias=True), 113 | nn.BatchNorm2d(F_int) 114 | ) 115 | 116 | self.W_x = nn.Sequential( 117 | nn.Conv2d(F_l, F_int, kernel_size=1,stride=1,padding=0,bias=True), 118 | nn.BatchNorm2d(F_int) 119 | ) 120 | 121 | self.psi = nn.Sequential( 122 | nn.Conv2d(F_int, 1, kernel_size=1,stride=1,padding=0,bias=True), 123 | nn.BatchNorm2d(1), 124 | nn.Sigmoid() 125 | ) 126 | 127 | self.relu = nn.ReLU(inplace=True) 128 | 129 | def forward(self,g,x): 130 | g1 = self.W_g(g) 131 | x1 = self.W_x(x) 132 | psi = self.relu(g1+x1) 133 | psi = self.psi(psi) 134 | 135 | return x*psi 136 | 137 | 138 | class U_Net(nn.Module): 139 | def __init__(self,img_ch=3,output_ch=1): 140 | super(U_Net,self).__init__() 141 | 142 | self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2) 143 | 144 | self.Conv1 = conv_block(ch_in=img_ch,ch_out=64) 145 | self.Conv2 = conv_block(ch_in=64,ch_out=128) 146 | self.Conv3 = conv_block(ch_in=128,ch_out=256) 147 | self.Conv4 = conv_block(ch_in=256,ch_out=512) 148 | self.Conv5 = conv_block(ch_in=512,ch_out=1024) 149 | 150 | self.Up5 = up_conv(ch_in=1024,ch_out=512) 151 | self.Up_conv5 = conv_block(ch_in=1024, ch_out=512) 152 | 153 | self.Up4 = up_conv(ch_in=512,ch_out=256) 154 | self.Up_conv4 = conv_block(ch_in=512, ch_out=256) 155 | 156 | self.Up3 = up_conv(ch_in=256,ch_out=128) 157 | self.Up_conv3 = conv_block(ch_in=256, ch_out=128) 158 | 159 | self.Up2 = up_conv(ch_in=128,ch_out=64) 160 | self.Up_conv2 = conv_block(ch_in=128, ch_out=64) 161 | 162 | self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0) 163 | 164 | 165 | def forward(self,x): 166 | # encoding path 167 | x1 = self.Conv1(x) 168 | 169 | x2 = self.Maxpool(x1) 170 | x2 = self.Conv2(x2) 171 | 172 | x3 = self.Maxpool(x2) 173 | x3 = self.Conv3(x3) 174 | 175 | x4 = self.Maxpool(x3) 176 | x4 = self.Conv4(x4) 177 | 178 | x5 = self.Maxpool(x4) 179 | x5 = self.Conv5(x5) 180 | 181 | # decoding + concat path 182 | d5 = self.Up5(x5) 183 | d5 = torch.cat((x4,d5),dim=1) 184 | 185 | d5 = self.Up_conv5(d5) 186 | 187 | d4 = self.Up4(d5) 188 | d4 = torch.cat((x3,d4),dim=1) 189 | d4 = self.Up_conv4(d4) 190 | 191 | d3 = self.Up3(d4) 192 | d3 = torch.cat((x2,d3),dim=1) 193 | d3 = self.Up_conv3(d3) 194 | 195 | d2 = self.Up2(d3) 196 | d2 = torch.cat((x1,d2),dim=1) 197 | d2 = self.Up_conv2(d2) 198 | 199 | d1 = self.Conv_1x1(d2) 200 | 201 | return d1 202 | 203 | 204 | class R2U_Net(nn.Module): 205 | def __init__(self,img_ch=3,output_ch=1,t=2): 206 | super(R2U_Net,self).__init__() 207 | 208 | self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2) 209 | self.Upsample = nn.Upsample(scale_factor=2) 210 | 211 | self.RRCNN1 = RRCNN_block(ch_in=img_ch,ch_out=64,t=t) 212 | 213 | self.RRCNN2 = RRCNN_block(ch_in=64,ch_out=128,t=t) 214 | 215 | self.RRCNN3 = RRCNN_block(ch_in=128,ch_out=256,t=t) 216 | 217 | self.RRCNN4 = RRCNN_block(ch_in=256,ch_out=512,t=t) 218 | 219 | self.RRCNN5 = RRCNN_block(ch_in=512,ch_out=1024,t=t) 220 | 221 | 222 | self.Up5 = up_conv(ch_in=1024,ch_out=512) 223 | self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512,t=t) 224 | 225 | self.Up4 = up_conv(ch_in=512,ch_out=256) 226 | self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256,t=t) 
227 | 228 | self.Up3 = up_conv(ch_in=256,ch_out=128) 229 | self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128,t=t) 230 | 231 | self.Up2 = up_conv(ch_in=128,ch_out=64) 232 | self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64,t=t) 233 | 234 | self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0) 235 | 236 | 237 | def forward(self,x): 238 | # encoding path 239 | x1 = self.RRCNN1(x) 240 | 241 | x2 = self.Maxpool(x1) 242 | x2 = self.RRCNN2(x2) 243 | 244 | x3 = self.Maxpool(x2) 245 | x3 = self.RRCNN3(x3) 246 | 247 | x4 = self.Maxpool(x3) 248 | x4 = self.RRCNN4(x4) 249 | 250 | x5 = self.Maxpool(x4) 251 | x5 = self.RRCNN5(x5) 252 | 253 | # decoding + concat path 254 | d5 = self.Up5(x5) 255 | d5 = torch.cat((x4,d5),dim=1) 256 | d5 = self.Up_RRCNN5(d5) 257 | 258 | d4 = self.Up4(d5) 259 | d4 = torch.cat((x3,d4),dim=1) 260 | d4 = self.Up_RRCNN4(d4) 261 | 262 | d3 = self.Up3(d4) 263 | d3 = torch.cat((x2,d3),dim=1) 264 | d3 = self.Up_RRCNN3(d3) 265 | 266 | d2 = self.Up2(d3) 267 | d2 = torch.cat((x1,d2),dim=1) 268 | d2 = self.Up_RRCNN2(d2) 269 | 270 | d1 = self.Conv_1x1(d2) 271 | 272 | return d1 273 | 274 | 275 | 276 | class AttU_Net(nn.Module): 277 | def __init__(self,img_ch=3,output_ch=1): 278 | super(AttU_Net,self).__init__() 279 | 280 | self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2) 281 | 282 | self.Conv1 = conv_block(ch_in=img_ch,ch_out=64) 283 | self.Conv2 = conv_block(ch_in=64,ch_out=128) 284 | self.Conv3 = conv_block(ch_in=128,ch_out=256) 285 | self.Conv4 = conv_block(ch_in=256,ch_out=512) 286 | self.Conv5 = conv_block(ch_in=512,ch_out=1024) 287 | 288 | self.Up5 = up_conv(ch_in=1024,ch_out=512) 289 | self.Att5 = Attention_block(F_g=512,F_l=512,F_int=256) 290 | self.Up_conv5 = conv_block(ch_in=1024, ch_out=512) 291 | 292 | self.Up4 = up_conv(ch_in=512,ch_out=256) 293 | self.Att4 = Attention_block(F_g=256,F_l=256,F_int=128) 294 | self.Up_conv4 = conv_block(ch_in=512, ch_out=256) 295 | 296 | self.Up3 = up_conv(ch_in=256,ch_out=128) 297 | self.Att3 = Attention_block(F_g=128,F_l=128,F_int=64) 298 | self.Up_conv3 = conv_block(ch_in=256, ch_out=128) 299 | 300 | self.Up2 = up_conv(ch_in=128,ch_out=64) 301 | self.Att2 = Attention_block(F_g=64,F_l=64,F_int=32) 302 | self.Up_conv2 = conv_block(ch_in=128, ch_out=64) 303 | 304 | self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0) 305 | 306 | 307 | def forward(self,x): 308 | # encoding path 309 | x1 = self.Conv1(x) 310 | 311 | x2 = self.Maxpool(x1) 312 | x2 = self.Conv2(x2) 313 | 314 | x3 = self.Maxpool(x2) 315 | x3 = self.Conv3(x3) 316 | 317 | x4 = self.Maxpool(x3) 318 | x4 = self.Conv4(x4) 319 | 320 | x5 = self.Maxpool(x4) 321 | x5 = self.Conv5(x5) 322 | 323 | # decoding + concat path 324 | d5 = self.Up5(x5) 325 | x4 = self.Att5(g=d5,x=x4) 326 | d5 = torch.cat((x4,d5),dim=1) 327 | d5 = self.Up_conv5(d5) 328 | 329 | d4 = self.Up4(d5) 330 | x3 = self.Att4(g=d4,x=x3) 331 | d4 = torch.cat((x3,d4),dim=1) 332 | d4 = self.Up_conv4(d4) 333 | 334 | d3 = self.Up3(d4) 335 | x2 = self.Att3(g=d3,x=x2) 336 | d3 = torch.cat((x2,d3),dim=1) 337 | d3 = self.Up_conv3(d3) 338 | 339 | d2 = self.Up2(d3) 340 | x1 = self.Att2(g=d2,x=x1) 341 | d2 = torch.cat((x1,d2),dim=1) 342 | d2 = self.Up_conv2(d2) 343 | 344 | d1 = self.Conv_1x1(d2) 345 | 346 | return d1 347 | 348 | 349 | class R2AttU_Net(nn.Module): 350 | def __init__(self,img_ch=3,output_ch=1,t=2): 351 | super(R2AttU_Net,self).__init__() 352 | 353 | self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2) 354 | self.Upsample = nn.Upsample(scale_factor=2) 355 | 356 | self.RRCNN1 = 
RRCNN_block(ch_in=img_ch,ch_out=64,t=t) 357 | 358 | self.RRCNN2 = RRCNN_block(ch_in=64,ch_out=128,t=t) 359 | 360 | self.RRCNN3 = RRCNN_block(ch_in=128,ch_out=256,t=t) 361 | 362 | self.RRCNN4 = RRCNN_block(ch_in=256,ch_out=512,t=t) 363 | 364 | self.RRCNN5 = RRCNN_block(ch_in=512,ch_out=1024,t=t) 365 | 366 | 367 | self.Up5 = up_conv(ch_in=1024,ch_out=512) 368 | self.Att5 = Attention_block(F_g=512,F_l=512,F_int=256) 369 | self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512,t=t) 370 | 371 | self.Up4 = up_conv(ch_in=512,ch_out=256) 372 | self.Att4 = Attention_block(F_g=256,F_l=256,F_int=128) 373 | self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256,t=t) 374 | 375 | self.Up3 = up_conv(ch_in=256,ch_out=128) 376 | self.Att3 = Attention_block(F_g=128,F_l=128,F_int=64) 377 | self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128,t=t) 378 | 379 | self.Up2 = up_conv(ch_in=128,ch_out=64) 380 | self.Att2 = Attention_block(F_g=64,F_l=64,F_int=32) 381 | self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64,t=t) 382 | 383 | self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0) 384 | 385 | 386 | def forward(self,x): 387 | # encoding path 388 | x1 = self.RRCNN1(x) 389 | 390 | x2 = self.Maxpool(x1) 391 | x2 = self.RRCNN2(x2) 392 | 393 | x3 = self.Maxpool(x2) 394 | x3 = self.RRCNN3(x3) 395 | 396 | x4 = self.Maxpool(x3) 397 | x4 = self.RRCNN4(x4) 398 | 399 | x5 = self.Maxpool(x4) 400 | x5 = self.RRCNN5(x5) 401 | 402 | # decoding + concat path 403 | d5 = self.Up5(x5) 404 | x4 = self.Att5(g=d5,x=x4) 405 | d5 = torch.cat((x4,d5),dim=1) 406 | d5 = self.Up_RRCNN5(d5) 407 | 408 | d4 = self.Up4(d5) 409 | x3 = self.Att4(g=d4,x=x3) 410 | d4 = torch.cat((x3,d4),dim=1) 411 | d4 = self.Up_RRCNN4(d4) 412 | 413 | d3 = self.Up3(d4) 414 | x2 = self.Att3(g=d3,x=x2) 415 | d3 = torch.cat((x2,d3),dim=1) 416 | d3 = self.Up_RRCNN3(d3) 417 | 418 | d2 = self.Up2(d3) 419 | x1 = self.Att2(g=d2,x=x1) 420 | d2 = torch.cat((x1,d2),dim=1) 421 | d2 = self.Up_RRCNN2(d2) 422 | 423 | d1 = self.Conv_1x1(d2) 424 | 425 | return d1 426 | -------------------------------------------------------------------------------- /solver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import time 4 | import datetime 5 | import torch 6 | import torchvision 7 | from torch import optim 8 | from torch.autograd import Variable 9 | import torch.nn.functional as F 10 | from evaluation import * 11 | from network import U_Net,R2U_Net,AttU_Net,R2AttU_Net 12 | import csv 13 | 14 | 15 | class Solver(object): 16 | def __init__(self, config, train_loader, valid_loader, test_loader): 17 | 18 | # Data loader 19 | self.train_loader = train_loader 20 | self.valid_loader = valid_loader 21 | self.test_loader = test_loader 22 | 23 | # Models 24 | self.unet = None 25 | self.optimizer = None 26 | self.img_ch = config.img_ch 27 | self.output_ch = config.output_ch 28 | self.criterion = torch.nn.BCELoss() 29 | self.augmentation_prob = config.augmentation_prob 30 | 31 | # Hyper-parameters 32 | self.lr = config.lr 33 | self.beta1 = config.beta1 34 | self.beta2 = config.beta2 35 | 36 | # Training settings 37 | self.num_epochs = config.num_epochs 38 | self.num_epochs_decay = config.num_epochs_decay 39 | self.batch_size = config.batch_size 40 | 41 | # Step size 42 | self.log_step = config.log_step 43 | self.val_step = config.val_step 44 | 45 | # Path 46 | self.model_path = config.model_path 47 | self.result_path = config.result_path 48 | self.mode = config.mode 49 | 50 | self.device = torch.device('cuda' 
if torch.cuda.is_available() else 'cpu')
 51 |         self.model_type = config.model_type
 52 |         self.t = config.t
 53 |         self.build_model()
 54 | 
 55 |     def build_model(self):
 56 |         """Build the segmentation network and its optimizer."""
 57 |         if self.model_type =='U_Net':
 58 |             self.unet = U_Net(img_ch=3,output_ch=1)
 59 |         elif self.model_type =='R2U_Net':
 60 |             self.unet = R2U_Net(img_ch=3,output_ch=1,t=self.t)
 61 |         elif self.model_type =='AttU_Net':
 62 |             self.unet = AttU_Net(img_ch=3,output_ch=1)
 63 |         elif self.model_type == 'R2AttU_Net':
 64 |             self.unet = R2AttU_Net(img_ch=3,output_ch=1,t=self.t)
 65 | 
 66 | 
 67 |         self.optimizer = optim.Adam(list(self.unet.parameters()),
 68 |                                     self.lr, [self.beta1, self.beta2])
 69 |         self.unet.to(self.device)
 70 | 
 71 |         # self.print_network(self.unet, self.model_type)
 72 | 
 73 |     def print_network(self, model, name):
 74 |         """Print out the network information."""
 75 |         num_params = 0
 76 |         for p in model.parameters():
 77 |             num_params += p.numel()
 78 |         print(model)
 79 |         print(name)
 80 |         print("The number of parameters: {}".format(num_params))
 81 | 
 82 |     def to_data(self, x):
 83 |         """Convert variable to tensor."""
 84 |         if torch.cuda.is_available():
 85 |             x = x.cpu()
 86 |         return x.data
 87 | 
 88 |     def update_lr(self, lr):
 89 |         for param_group in self.optimizer.param_groups:
 90 |             param_group['lr'] = lr
 91 | 
 92 |     def reset_grad(self):
 93 |         """Zero the gradient buffers."""
 94 |         self.unet.zero_grad()
 95 | 
 96 |     def compute_accuracy(self,SR,GT):
 97 |         SR_flat = SR.view(-1)
 98 |         GT_flat = GT.view(-1)
 99 | 
100 |         acc = GT_flat.data.cpu()==(SR_flat.data.cpu()>0.5)
101 |         return acc.float().mean()
102 | 
103 |     def tensor2img(self,x):
104 |         img = (x[:,0,:,:]>x[:,1,:,:]).float()
105 |         img = img*255
106 |         return img
107 | 
108 |     def train(self):
109 |         """Train, validate, and test the selected network."""
110 | 
111 |         #====================================== Training ===========================================#
112 |         #===========================================================================================#
113 | 
114 |         unet_path = os.path.join(self.model_path, '%s-%d-%.4f-%d-%.4f.pkl' %(self.model_type,self.num_epochs,self.lr,self.num_epochs_decay,self.augmentation_prob))
115 | 
116 |         # U-Net Train
117 |         if os.path.isfile(unet_path):
118 |             # Load the pretrained model
119 |             self.unet.load_state_dict(torch.load(unet_path))
120 |             print('%s is Successfully Loaded from %s'%(self.model_type,unet_path))
121 |         else:
122 |             # Train the model
123 |             lr = self.lr
124 |             best_unet_score = 0.
125 | 
126 |             for epoch in range(self.num_epochs):
127 | 
128 |                 self.unet.train(True)
129 |                 epoch_loss = 0
130 | 
131 |                 acc = 0.    # Accuracy
132 |                 SE = 0.     # Sensitivity (Recall)
133 |                 SP = 0.     # Specificity
134 |                 PC = 0.     # Precision
135 |                 F1 = 0.     # F1 Score
136 |                 JS = 0.     # Jaccard Similarity
137 |                 DC = 0.     # Dice Coefficient
138 |                 length = 0
139 | 
140 |                 for i, (images, GT) in enumerate(self.train_loader):
141 |                     # GT : Ground Truth
142 | 
143 |                     images = images.to(self.device)
144 |                     GT = GT.to(self.device)
145 | 
146 |                     # SR : Segmentation Result
147 |                     SR = self.unet(images)
148 |                     SR_probs = torch.sigmoid(SR)
149 |                     SR_flat = SR_probs.view(SR_probs.size(0),-1)
150 | 
151 |                     GT_flat = GT.view(GT.size(0),-1)
152 |                     loss = self.criterion(SR_flat,GT_flat)
153 |                     epoch_loss += loss.item()
154 | 
155 |                     # Backprop + optimize
156 |                     self.reset_grad()
157 |                     loss.backward()
158 |                     self.optimizer.step()
159 | 
160 |                     acc += get_accuracy(SR_probs,GT)    # score probabilities, not raw logits,
161 |                     SE += get_sensitivity(SR_probs,GT)  # so the 0.5 threshold is meaningful
162 |                     SP += get_specificity(SR_probs,GT)
163 |                     PC += get_precision(SR_probs,GT)
164 |                     F1 += get_F1(SR_probs,GT)
165 |                     JS += get_JS(SR_probs,GT)
166 |                     DC += get_DC(SR_probs,GT)
167 |                     length += images.size(0)    # note: metrics are per-batch averages; this weighting is exact only for batch_size=1
168 | 
169 |                 acc = acc/length
170 |                 SE = SE/length
171 |                 SP = SP/length
172 |                 PC = PC/length
173 |                 F1 = F1/length
174 |                 JS = JS/length
175 |                 DC = DC/length
176 | 
177 |                 # Print the log info
178 |                 print('Epoch [%d/%d], Loss: %.4f, \n[Training] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f' % (
179 |                       epoch+1, self.num_epochs, \
180 |                       epoch_loss,\
181 |                       acc,SE,SP,PC,F1,JS,DC))
182 | 
183 | 
184 | 
185 |                 # Decay learning rate
186 |                 if (epoch+1) > (self.num_epochs - self.num_epochs_decay):
187 |                     lr -= (self.lr / float(self.num_epochs_decay))
188 |                     for param_group in self.optimizer.param_groups:
189 |                         param_group['lr'] = lr
190 |                     print ('Decay learning rate to lr: {}.'.format(lr))
191 | 
192 | 
193 |                 #===================================== Validation ====================================#
194 |                 self.unet.train(False)
195 |                 self.unet.eval()
196 | 
197 |                 acc = 0.    # Accuracy
198 |                 SE = 0.     # Sensitivity (Recall)
199 |                 SP = 0.     # Specificity
200 |                 PC = 0.     # Precision
201 |                 F1 = 0.     # F1 Score
202 |                 JS = 0.     # Jaccard Similarity
203 |                 DC = 0.     # Dice Coefficient
204 |                 length=0
205 |                 for i, (images, GT) in enumerate(self.valid_loader):
206 | 
207 |                     images = images.to(self.device)
208 |                     GT = GT.to(self.device)
209 |                     SR = torch.sigmoid(self.unet(images))
210 |                     acc += get_accuracy(SR,GT)
211 |                     SE += get_sensitivity(SR,GT)
212 |                     SP += get_specificity(SR,GT)
213 |                     PC += get_precision(SR,GT)
214 |                     F1 += get_F1(SR,GT)
215 |                     JS += get_JS(SR,GT)
216 |                     DC += get_DC(SR,GT)
217 | 
218 |                     length += images.size(0)
219 | 
220 |                 acc = acc/length
221 |                 SE = SE/length
222 |                 SP = SP/length
223 |                 PC = PC/length
224 |                 F1 = F1/length
225 |                 JS = JS/length
226 |                 DC = DC/length
227 |                 unet_score = JS + DC
228 | 
229 |                 print('[Validation] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f'%(acc,SE,SP,PC,F1,JS,DC))
230 | 
231 |                 '''
232 |                 torchvision.utils.save_image(images.data.cpu(),
233 |                                             os.path.join(self.result_path,
234 |                                                         '%s_valid_%d_image.png'%(self.model_type,epoch+1)))
235 |                 torchvision.utils.save_image(SR.data.cpu(),
236 |                                             os.path.join(self.result_path,
237 |                                                         '%s_valid_%d_SR.png'%(self.model_type,epoch+1)))
238 |                 torchvision.utils.save_image(GT.data.cpu(),
239 |                                             os.path.join(self.result_path,
240 |                                                         '%s_valid_%d_GT.png'%(self.model_type,epoch+1)))
241 |                 '''
242 | 
243 | 
244 |                 # Save Best U-Net model
245 |                 if unet_score > best_unet_score:
246 |                     best_unet_score = unet_score
247 |                     best_epoch = epoch
248 |                     best_unet = self.unet.state_dict()
249 |                     print('Best %s model score : %.4f'%(self.model_type,best_unet_score))
250 |                     torch.save(best_unet,unet_path)
251 | 
252 |             #===================================== Test ====================================#
253 |             del self.unet
254 |             del best_unet
255 |             self.build_model()
256 |             self.unet.load_state_dict(torch.load(unet_path))
257 | 
258 |             self.unet.train(False)
259 |             self.unet.eval()
260 | 
261 |             acc = 0.    # Accuracy
262 |             SE = 0.     # Sensitivity (Recall)
263 |             SP = 0.     # Specificity
264 |             PC = 0.     # Precision
265 |             F1 = 0.     # F1 Score
266 |             JS = 0.     # Jaccard Similarity
267 |             DC = 0.     # Dice Coefficient
268 |             length=0
269 |             for i, (images, GT) in enumerate(self.test_loader):     # fixed: was valid_loader, so "test" scores re-used the validation set
270 | 
271 |                 images = images.to(self.device)
272 |                 GT = GT.to(self.device)
273 |                 SR = torch.sigmoid(self.unet(images))
274 |                 acc += get_accuracy(SR,GT)
275 |                 SE += get_sensitivity(SR,GT)
276 |                 SP += get_specificity(SR,GT)
277 |                 PC += get_precision(SR,GT)
278 |                 F1 += get_F1(SR,GT)
279 |                 JS += get_JS(SR,GT)
280 |                 DC += get_DC(SR,GT)
281 | 
282 |                 length += images.size(0)
283 | 
284 |             acc = acc/length
285 |             SE = SE/length
286 |             SP = SP/length
287 |             PC = PC/length
288 |             F1 = F1/length
289 |             JS = JS/length
290 |             DC = DC/length
291 |             unet_score = JS + DC
292 | 
293 | 
294 |             f = open(os.path.join(self.result_path,'result.csv'), 'a', encoding='utf-8', newline='')
295 |             wr = csv.writer(f)
296 |             wr.writerow([self.model_type,acc,SE,SP,PC,F1,JS,DC,self.lr,best_epoch,self.num_epochs,self.num_epochs_decay,self.augmentation_prob])
297 |             f.close()
298 | 
299 | 
300 | 
301 | 
--------------------------------------------------------------------------------
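For completeness, a minimal end-to-end inference sketch built from the pieces above. The checkpoint filename here is hypothetical: solver.py saves weights as '%s-%d-%.4f-%d-%.4f.pkl' under --model_path, so substitute the file produced by your own solver.train() run.

```python
import torch
from network import U_Net
from data_loader import get_loader
from evaluation import get_JS, get_DC

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = U_Net(img_ch=3, output_ch=1).to(device)
# Hypothetical checkpoint name; replace with the file saved by solver.train().
model.load_state_dict(torch.load('./models/U_Net-100-0.0002-70-0.4000.pkl', map_location=device))
model.eval()

test_loader = get_loader(image_path='./dataset/test/', image_size=224,
                         batch_size=1, mode='test', augmentation_prob=0.)
with torch.no_grad():
    for images, GT in test_loader:
        SR = torch.sigmoid(model(images.to(device)))   # probabilities in [0, 1]
        print('JS: %.4f  DC: %.4f' % (get_JS(SR, GT.to(device)), get_DC(SR, GT.to(device))))
```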