├── misc.py
├── evaluation.py
├── data_loader.py
├── dataset.py
├── main.py
├── README.md
├── solver.py
└── network.py
/misc.py:
--------------------------------------------------------------------------------
1 |
2 | def printProgressBar (iteration, total, prefix = '', suffix = '', decimals = 1, length = 100, fill = '█'):
3 | """
4 | Call in a loop to create terminal progress bar
5 | @params:
6 | iteration - Required : current iteration (Int)
7 | total - Required : total iterations (Int)
8 | prefix - Optional : prefix string (Str)
9 | suffix - Optional : suffix string (Str)
10 | decimals - Optional : positive number of decimals in percent complete (Int)
11 | length - Optional : character length of bar (Int)
12 | fill - Optional : bar fill character (Str)
13 | """
14 | percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total)))
15 | filledLength = int(length * iteration // total)
16 | bar = fill * filledLength + '-' * (length - filledLength)
17 | print('\r%s |%s| %s%% %s' % (prefix, bar, percent, suffix), end = '\r')
18 | # Print New Line on Complete
19 | if iteration == total:
20 | print()
21 |
22 |
--------------------------------------------------------------------------------
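A quick, hypothetical driver loop showing how `printProgressBar` is meant to be called (the `items` list and the `time.sleep` stand-in are illustrative only, not part of the project):

```python
import time
from misc import printProgressBar

items = list(range(30))
printProgressBar(0, len(items), prefix='Processing:', suffix='Complete', length=50)
for i, _ in enumerate(items):
    time.sleep(0.05)  # stand-in for real per-item work
    printProgressBar(i + 1, len(items), prefix='Processing:', suffix='Complete', length=50)
```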
/evaluation.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def binary_threshold(tensor, threshold=0.5):
4 |     return torch.where(tensor > threshold, torch.ones_like(tensor), torch.zeros_like(tensor))  # ones_like/zeros_like keep the input's device and dtype, so this also works on CUDA tensors
5 |
6 |
7 | def get_accuracy(SR, GT, threshold=0.5):
8 | SR_binary = binary_threshold(SR, threshold)
9 | GT_binary = GT == torch.max(GT)
10 |
11 | correct_pixels = torch.sum(SR_binary == GT_binary)
12 | total_pixels = SR.numel()
13 | accuracy = float(correct_pixels) / float(total_pixels)
14 |
15 | return accuracy
16 |
17 | def get_sensitivity(SR, GT, threshold=0.5):
18 | SR_binary = binary_threshold(SR, threshold)
19 | GT_binary = GT == torch.max(GT)
20 |
21 | TP = ((SR_binary == 1) & (GT_binary == 1)).float()
22 | FN = ((SR_binary == 0) & (GT_binary == 1)).float()
23 |
24 | sensitivity = torch.sum(TP) / (torch.sum(TP + FN) + 1e-6)
25 | return sensitivity
26 |
27 | def get_specificity(SR, GT, threshold=0.5):
28 | SR_binary = binary_threshold(SR, threshold)
29 | GT_binary = GT == torch.max(GT)
30 |
31 | TN = ((SR_binary == 0) & (GT_binary == 0)).float()
32 | FP = ((SR_binary == 1) & (GT_binary == 0)).float()
33 |
34 | specificity = torch.sum(TN) / (torch.sum(TN + FP) + 1e-6)
35 | return specificity
36 |
37 | def get_precision(SR, GT, threshold=0.5):
38 | SR_binary = binary_threshold(SR, threshold)
39 | GT_binary = GT == torch.max(GT)
40 |
41 | TP = ((SR_binary == 1) & (GT_binary == 1)).float()
42 | FP = ((SR_binary == 1) & (GT_binary == 0)).float()
43 |
44 | precision = torch.sum(TP) / (torch.sum(TP + FP) + 1e-6)
45 | return precision
46 |
47 | def get_F1(SR, GT, threshold=0.5):
48 | sensitivity = get_sensitivity(SR, GT, threshold=threshold)
49 | precision = get_precision(SR, GT, threshold=threshold)
50 |
51 | F1 = 2 * sensitivity * precision / (sensitivity + precision + 1e-6)
52 | return F1
53 |
54 | def get_JS(SR, GT, threshold=0.5):
55 | SR_binary = binary_threshold(SR, threshold)
56 | GT_binary = GT == torch.max(GT)
57 |
58 | intersection = torch.sum((SR_binary + GT_binary) == 2)
59 | union = torch.sum((SR_binary + GT_binary) >= 1)
60 |
61 | JS = float(intersection) / (float(union) + 1e-6)
62 | return JS
63 |
64 | def get_DC(SR, GT, threshold=0.5):
65 | SR_binary = binary_threshold(SR, threshold)
66 | GT_binary = GT == torch.max(GT)
67 |
68 | intersection = torch.sum((SR_binary + GT_binary) == 2)
69 | dice_coefficient = float(2 * intersection) / (float(torch.sum(SR_binary) + torch.sum(GT_binary)) + 1e-6)
70 |
71 | return dice_coefficient
72 |
--------------------------------------------------------------------------------
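A minimal sanity check of these metrics on a hand-made 2x2 example (hypothetical values; the `1e-6` terms in the module keep the denominators non-zero):

```python
import torch
from evaluation import (get_accuracy, get_sensitivity, get_specificity,
                        get_precision, get_F1, get_JS, get_DC)

SR = torch.tensor([[0.9, 0.2],   # predicted probabilities
                   [0.8, 0.1]])
GT = torch.tensor([[1.0, 0.0],   # ground-truth mask
                   [1.0, 1.0]])

print(get_accuracy(SR, GT))     # 3 of 4 pixels agree -> 0.75
print(get_sensitivity(SR, GT))  # TP=2, FN=1 -> ~0.6667
print(get_specificity(SR, GT))  # TN=1, FP=0 -> ~1.0
print(get_precision(SR, GT))    # TP=2, FP=0 -> ~1.0
print(get_JS(SR, GT))           # intersection=2, union=3 -> ~0.6667
print(get_DC(SR, GT))           # 2*2 / (2+3) -> 0.8
```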
/data_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from random import shuffle
4 | import numpy as np
5 | import torch
6 | from torch.utils import data
7 | from torchvision import transforms as T
8 | from torchvision.transforms import functional as F
9 | from PIL import Image
10 |
11 |
12 | class ImageFolder(data.Dataset):
13 | def __init__(self, root,image_size=224,mode='train',augmentation_prob=0.4):
14 | """Initializes image paths and preprocessing module."""
15 | self.root = root
16 |
17 | # GT : Ground Truth
18 | self.GT_paths = root[:-1]+'_GT/'
19 | self.image_paths = list(map(lambda x: os.path.join(root, x), os.listdir(root)))
20 | self.image_size = image_size
21 | self.mode = mode
22 | self.RotationDegree = [0,90,180,270]
23 | self.augmentation_prob = augmentation_prob
24 | print("image count in {} path :{}".format(self.mode,len(self.image_paths)))
25 |
26 | def __getitem__(self, index):
27 | """Reads an image from a file and preprocesses it and returns."""
28 | image_path = self.image_paths[index]
29 | filename = image_path.split('_')[-1][:-len(".jpg")]
30 | GT_path = self.GT_paths + 'ISIC_' + filename + '_segmentation.png'
31 |
32 | image = Image.open(image_path)
33 | GT = Image.open(GT_path)
34 |
35 | aspect_ratio = image.size[1]/image.size[0]
36 |
37 | Transform = []
38 |
39 | ResizeRange = random.randint(300,320)
40 | Transform.append(T.Resize((int(ResizeRange*aspect_ratio),ResizeRange)))
41 | p_transform = random.random()
42 |
43 | if (self.mode == 'train') and p_transform <= self.augmentation_prob:
44 | RotationDegree = random.randint(0,3)
45 | RotationDegree = self.RotationDegree[RotationDegree]
46 | if (RotationDegree == 90) or (RotationDegree == 270):
47 | aspect_ratio = 1/aspect_ratio
48 |
49 | Transform.append(T.RandomRotation((RotationDegree,RotationDegree)))
50 |
51 | RotationRange = random.randint(-10,10)
52 | Transform.append(T.RandomRotation((RotationRange,RotationRange)))
53 | CropRange = random.randint(250,270)
54 | Transform.append(T.CenterCrop((int(CropRange*aspect_ratio),CropRange)))
55 | Transform = T.Compose(Transform)
56 |
57 | image = Transform(image)
58 | GT = Transform(GT)
59 |
60 | ShiftRange_left = random.randint(0,20)
61 | ShiftRange_upper = random.randint(0,20)
62 | ShiftRange_right = image.size[0] - random.randint(0,20)
63 | ShiftRange_lower = image.size[1] - random.randint(0,20)
64 | image = image.crop(box=(ShiftRange_left,ShiftRange_upper,ShiftRange_right,ShiftRange_lower))
65 | GT = GT.crop(box=(ShiftRange_left,ShiftRange_upper,ShiftRange_right,ShiftRange_lower))
66 |
67 | if random.random() < 0.5:
68 | image = F.hflip(image)
69 | GT = F.hflip(GT)
70 |
71 | if random.random() < 0.5:
72 | image = F.vflip(image)
73 | GT = F.vflip(GT)
74 |
75 | Transform = T.ColorJitter(brightness=0.2,contrast=0.2,hue=0.02)
76 |
77 | image = Transform(image)
78 |
79 | Transform =[]
80 |
81 |
82 | Transform.append(T.Resize((int(256*aspect_ratio)-int(256*aspect_ratio)%16,256)))
83 | Transform.append(T.ToTensor())
84 | Transform = T.Compose(Transform)
85 |
86 | image = Transform(image)
87 | GT = Transform(GT)
88 |
89 | Norm_ = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
90 | image = Norm_(image)
91 |
92 | return image, GT
93 |
94 | def __len__(self):
95 | """Returns the total number of font files."""
96 | return len(self.image_paths)
97 |
98 | def get_loader(image_path, image_size, batch_size, num_workers=2, mode='train',augmentation_prob=0.4):
99 | """Builds and returns Dataloader."""
100 |
101 | dataset = ImageFolder(root = image_path, image_size =image_size, mode=mode,augmentation_prob=augmentation_prob)
102 | data_loader = data.DataLoader(dataset=dataset,
103 | batch_size=batch_size,
104 | shuffle=True,
105 | num_workers=num_workers)
106 | return data_loader
107 |
--------------------------------------------------------------------------------
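A sketch of building a loader and inspecting one batch. It assumes dataset.py has already produced `./dataset/train/` and `./dataset/train_GT/`:

```python
from data_loader import get_loader

train_loader = get_loader(image_path='./dataset/train/',
                          image_size=224,
                          batch_size=1,
                          num_workers=0,  # single-process, simplest for a quick check
                          mode='train',
                          augmentation_prob=0.4)

images, GT = next(iter(train_loader))
# The final Resize fixes the width at 256 and rounds the height down to a
# multiple of 16, so both tensors are [N, C, H, 256] with H % 16 == 0.
print(images.shape, GT.shape)  # e.g. torch.Size([1, 3, 192, 256]) and torch.Size([1, 1, 192, 256])
```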
/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import random
4 | import shutil
5 | from shutil import copyfile
6 | from misc import printProgressBar
7 |
8 |
9 | def rm_mkdir(dir_path):
10 | if os.path.exists(dir_path):
11 | shutil.rmtree(dir_path)
12 | print('Remove path - %s'%dir_path)
13 | os.makedirs(dir_path)
14 | print('Create path - %s'%dir_path)
15 |
16 | def main(config):
17 |
18 | rm_mkdir(config.train_path)
19 | rm_mkdir(config.train_GT_path)
20 | rm_mkdir(config.valid_path)
21 | rm_mkdir(config.valid_GT_path)
22 | rm_mkdir(config.test_path)
23 | rm_mkdir(config.test_GT_path)
24 |
25 | filenames = os.listdir(config.origin_data_path)
26 | data_list = []
27 | GT_list = []
28 |
29 | for filename in filenames:
30 | ext = os.path.splitext(filename)[-1]
31 | if ext =='.jpg':
32 | filename = filename.split('_')[-1][:-len('.jpg')]
33 | data_list.append('ISIC_'+filename+'.jpg')
34 | GT_list.append('ISIC_'+filename+'_segmentation.png')
35 |
36 | num_total = len(data_list)
37 | num_train = int((config.train_ratio/(config.train_ratio+config.valid_ratio+config.test_ratio))*num_total)
38 | num_valid = int((config.valid_ratio/(config.train_ratio+config.valid_ratio+config.test_ratio))*num_total)
39 | num_test = num_total - num_train - num_valid
40 |
41 | print('\nNum of train set : ',num_train)
42 | print('\nNum of valid set : ',num_valid)
43 | print('\nNum of test set : ',num_test)
44 |
45 | Arange = list(range(num_total))
46 | random.shuffle(Arange)
47 |
48 | for i in range(num_train):
49 | idx = Arange.pop()
50 |
51 | src = os.path.join(config.origin_data_path, data_list[idx])
52 | dst = os.path.join(config.train_path,data_list[idx])
53 | copyfile(src, dst)
54 |
55 | src = os.path.join(config.origin_GT_path, GT_list[idx])
56 | dst = os.path.join(config.train_GT_path, GT_list[idx])
57 | copyfile(src, dst)
58 |
59 | printProgressBar(i + 1, num_train, prefix = 'Producing train set:', suffix = 'Complete', length = 50)
60 |
61 |
62 | for i in range(num_valid):
63 | idx = Arange.pop()
64 |
65 | src = os.path.join(config.origin_data_path, data_list[idx])
66 | dst = os.path.join(config.valid_path,data_list[idx])
67 | copyfile(src, dst)
68 |
69 | src = os.path.join(config.origin_GT_path, GT_list[idx])
70 | dst = os.path.join(config.valid_GT_path, GT_list[idx])
71 | copyfile(src, dst)
72 |
73 | printProgressBar(i + 1, num_valid, prefix = 'Producing valid set:', suffix = 'Complete', length = 50)
74 |
75 | for i in range(num_test):
76 | idx = Arange.pop()
77 |
78 | src = os.path.join(config.origin_data_path, data_list[idx])
79 | dst = os.path.join(config.test_path,data_list[idx])
80 | copyfile(src, dst)
81 |
82 | src = os.path.join(config.origin_GT_path, GT_list[idx])
83 | dst = os.path.join(config.test_GT_path, GT_list[idx])
84 | copyfile(src, dst)
85 |
86 |
87 | printProgressBar(i + 1, num_test, prefix = 'Producing test set:', suffix = 'Complete', length = 50)
88 |
89 | if __name__ == '__main__':
90 | parser = argparse.ArgumentParser()
91 |
92 |     # dataset split ratios
93 | parser.add_argument('--train_ratio', type=float, default=0.6)
94 | parser.add_argument('--valid_ratio', type=float, default=0.2)
95 | parser.add_argument('--test_ratio', type=float, default=0.2)
96 |
97 | # data path
98 | parser.add_argument('--origin_data_path', type=str, default='../ISIC/dataset/ISIC2018_Task1-2_Training_Input')
99 | parser.add_argument('--origin_GT_path', type=str, default='../ISIC/dataset/ISIC2018_Task1_Training_GroundTruth')
100 |
101 | parser.add_argument('--train_path', type=str, default='./dataset/train/')
102 | parser.add_argument('--train_GT_path', type=str, default='./dataset/train_GT/')
103 | parser.add_argument('--valid_path', type=str, default='./dataset/valid/')
104 | parser.add_argument('--valid_GT_path', type=str, default='./dataset/valid_GT/')
105 | parser.add_argument('--test_path', type=str, default='./dataset/test/')
106 | parser.add_argument('--test_GT_path', type=str, default='./dataset/test_GT/')
107 |
108 | config = parser.parse_args()
109 | print(config)
110 | main(config)
111 |
112 |
--------------------------------------------------------------------------------
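For reference, here is the split arithmetic from `main()` evaluated by hand for the full 2594-image ISIC-2018 training set. Note that the README's 1815/259/520 split corresponds to ratios of 0.7/0.1/0.2 rather than the script's 0.6/0.2/0.2 defaults, so pass `--train_ratio 0.7 --valid_ratio 0.1 --test_ratio 0.2` to reproduce it:

```python
num_total = 2594
train_ratio, valid_ratio, test_ratio = 0.7, 0.1, 0.2

ratio_sum = train_ratio + valid_ratio + test_ratio
num_train = int((train_ratio / ratio_sum) * num_total)  # 1815
num_valid = int((valid_ratio / ratio_sum) * num_total)  # 259
num_test = num_total - num_train - num_valid            # everything left over: 520

print(num_train, num_valid, num_test)  # 1815 259 520
```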
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from solver import Solver
4 | from data_loader import get_loader
5 | from torch.backends import cudnn
6 | import random
7 |
8 | def main(config):
9 | cudnn.benchmark = True
10 |     # Override the model type (this takes precedence over the --model_type CLI argument)
11 | config.model_type='R2U_Net'
12 |
13 | if config.model_type not in ['U_Net','R2U_Net','AttU_Net','R2AttU_Net']:
14 | print('ERROR!! model_type should be selected in U_Net/R2U_Net/AttU_Net/R2AttU_Net')
15 | print('Your input for model_type was %s'%config.model_type)
16 | return
17 |
18 | # Create directories if not exist
19 | if not os.path.exists(config.model_path):
20 | os.makedirs(config.model_path)
21 | if not os.path.exists(config.result_path):
22 | os.makedirs(config.result_path)
23 | config.result_path = os.path.join(config.result_path,config.model_type)
24 | if not os.path.exists(config.result_path):
25 | os.makedirs(config.result_path)
26 |
27 |     # Randomly sample hyper-parameters for this run (a crude random search)
28 | lr = random.random()*0.0005 + 0.0000005
29 | augmentation_prob= random.random()*0.7
30 | epoch = random.choice([3,5,7])
31 | decay_ratio = random.random()*0.8
32 | decay_epoch = int(epoch*decay_ratio)
33 |
34 | config.augmentation_prob = augmentation_prob
35 | config.num_epochs = epoch
36 | config.lr = lr
37 | config.num_epochs_decay = decay_epoch
38 |
39 | print(config)
40 |
41 | train_loader = get_loader(image_path=config.train_path,
42 | image_size=config.image_size,
43 | batch_size=config.batch_size,
44 | num_workers=config.num_workers,
45 | mode='train',
46 | augmentation_prob=config.augmentation_prob)
47 | valid_loader = get_loader(image_path=config.valid_path,
48 | image_size=config.image_size,
49 | batch_size=config.batch_size,
50 | num_workers=config.num_workers,
51 | mode='valid',
52 | augmentation_prob=0.)
53 | test_loader = get_loader(image_path=config.test_path,
54 | image_size=config.image_size,
55 | batch_size=config.batch_size,
56 | num_workers=config.num_workers,
57 | mode='test',
58 | augmentation_prob=0.)
59 |
60 | solver = Solver(config, train_loader, valid_loader, test_loader)
61 |
62 | # Train and sample the images
63 | if config.mode == 'train':
64 | solver.train()
65 | elif config.mode == 'test':
66 | solver.test()
67 |
68 | if __name__ == '__main__':
69 | parser = argparse.ArgumentParser()
70 |
71 | # model hyper-parameters
72 | parser.add_argument('--image_size', type=int, default=224)
73 | parser.add_argument('--t', type=int, default=3, help='t for Recurrent step of R2U_Net or R2AttU_Net')
74 |
75 | # training hyper-parameters
76 | parser.add_argument('--img_ch', type=int, default=3)
77 | parser.add_argument('--output_ch', type=int, default=1)
78 | parser.add_argument('--num_epochs', type=int, default=100)
79 | parser.add_argument('--num_epochs_decay', type=int, default=70)
80 | parser.add_argument('--batch_size', type=int, default=1)
81 | parser.add_argument('--num_workers', type=int, default=8)
82 | parser.add_argument('--lr', type=float, default=0.0002)
83 | parser.add_argument('--beta1', type=float, default=0.5) # momentum1 in Adam
84 | parser.add_argument('--beta2', type=float, default=0.999) # momentum2 in Adam
85 | parser.add_argument('--augmentation_prob', type=float, default=0.4)
86 |
87 | parser.add_argument('--log_step', type=int, default=2)
88 | parser.add_argument('--val_step', type=int, default=2)
89 |
90 | # misc
91 | parser.add_argument('--mode', type=str, default='train')
92 | parser.add_argument('--model_type', type=str, default='U_Net', help='U_Net/R2U_Net/AttU_Net/R2AttU_Net')
93 | parser.add_argument('--model_path', type=str, default='./models')
94 | parser.add_argument('--train_path', type=str, default='./dataset/train/')
95 | parser.add_argument('--valid_path', type=str, default='./dataset/valid/')
96 | parser.add_argument('--test_path', type=str, default='./dataset/test/')
97 | parser.add_argument('--result_path', type=str, default='./result/')
98 | parser.add_argument('--cuda_idx', type=int, default=1)
99 |
100 | config = parser.parse_args()
101 | main(config)
102 |
--------------------------------------------------------------------------------
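Note that `main()` overwrites `model_type`, `lr`, `num_epochs`, `num_epochs_decay`, and `augmentation_prob` with randomly sampled values, so the corresponding CLI arguments are effectively ignored unless you comment out the override block. A hypothetical sketch of driving `main()` programmatically with an explicit config (the fields mirror the argparse defaults above):

```python
from argparse import Namespace
from main import main

config = Namespace(
    image_size=224, t=3, img_ch=3, output_ch=1,
    num_epochs=100, num_epochs_decay=70, batch_size=1, num_workers=8,
    lr=0.0002, beta1=0.5, beta2=0.999, augmentation_prob=0.4,
    log_step=2, val_step=2, mode='train', model_type='U_Net',
    model_path='./models', train_path='./dataset/train/',
    valid_path='./dataset/valid/', test_path='./dataset/test/',
    result_path='./result/', cuda_idx=1)

main(config)  # the random overrides inside main() still apply
```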
/README.md:
--------------------------------------------------------------------------------
1 |
2 | This project uses PyTorch to implement and train the U-Net, R2U-Net, Attention U-Net, and Attention R2U-Net models. The key parameters of the four models are analyzed and compared in detail, in order to evaluate the strengths and weaknesses of each model more comprehensively.
3 | Note 1: To avoid path-lookup errors while the code runs, download the project into a newly created **ISIC** directory.
Note 2: Create the **dataset**, **models**, and **result** folders yourself.
Note 3: The **训练结果** (training results) folder contains data from my own experiments and is for reference only.
Note 4: Runtime environment: **Python 3.9**.
4 |
5 |
6 | # Dataset Download
7 | The **ISIC-2018** dataset is used: [https://challenge.isic-archive.com/data/#2018](https://challenge.isic-archive.com/data/#2018)
The dataset is split into three subsets (training, validation, and test) accounting for 70%, 10%, and 20% of the whole. The full dataset contains 2594 images, of which 1815 are used for training, 259 for validation, and 520 for testing.
**Step 1**: Download the indicated files (screenshot omitted).
**Step 2**: After unzipping, the folder names are as described (screenshot omitted). Place the two folders under the **dataset** directory; no further steps are needed.
8 |
9 |
10 | # How to Run
11 | **Step 1**: Run **dataset.py** on its own first to split and prepare the dataset.
**Step 2**: Configure and run **main.py**.
**Step 3**: Check the output after each run.
12 |
13 | # Model Architectures
14 |
15 | ## U-Net
16 | ![U-Net](https://github.com/LeeJunHyun/Image_Segmentation/blob/master/img/U-Net.png)
17 |
18 | ## R2U-Net
19 | ![R2U-Net](https://github.com/LeeJunHyun/Image_Segmentation/blob/master/img/R2U-Net.png)
20 |
21 | ## Attention U-Net
22 | ![Attention U-Net](https://github.com/LeeJunHyun/Image_Segmentation/blob/master/img/AttU-Net.png)
23 |
24 | ## Attention R2U-Net
25 | ![Attention R2U-Net](https://github.com/LeeJunHyun/Image_Segmentation/blob/master/img/AttR2U-Net.png)
26 |
27 | # Model Evaluation Results
28 | ![Evaluation](https://github.com/LeeJunHyun/Image_Segmentation/blob/master/img/Evaluation.png)
29 |
--------------------------------------------------------------------------------
/solver.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import time
4 | import datetime
5 | import torch
6 | import torchvision
7 | from torch import optim
8 | from torch.autograd import Variable
9 | import torch.nn.functional as F
10 | from evaluation import *
11 | from network import U_Net,R2U_Net,AttU_Net,R2AttU_Net
12 | import csv
13 |
14 |
15 | class Solver(object):
19 | def __init__(self, config, train_loader, valid_loader, test_loader):
20 |
21 | # Data loader
22 | self.train_loader = train_loader
23 | self.valid_loader = valid_loader
24 | self.test_loader = test_loader
25 |
26 | # Models
27 | self.unet = None
28 | self.optimizer = None
29 | self.img_ch = config.img_ch
30 | self.output_ch = config.output_ch
31 | self.criterion = torch.nn.BCELoss()
32 | self.augmentation_prob = config.augmentation_prob
33 |
34 | # Hyper-parameters
35 | self.lr = config.lr
36 | self.beta1 = config.beta1
37 | self.beta2 = config.beta2
38 |
39 | # Training settings
40 | self.num_epochs = config.num_epochs
41 | self.num_epochs_decay = config.num_epochs_decay
42 | self.batch_size = config.batch_size
43 |
44 | # Step size
45 | self.log_step = config.log_step
46 | self.val_step = config.val_step
47 |
48 | # Path
49 | self.model_path = config.model_path
50 | self.result_path = config.result_path
51 | self.mode = config.mode
52 |
53 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
54 | self.model_type = config.model_type
55 | self.t = config.t
56 | self.build_model()
57 |
58 | def build_model(self):
59 | """Build generator and discriminator."""
60 | if self.model_type =='U_Net':
61 | self.unet = U_Net(img_ch=3,output_ch=1)
62 | elif self.model_type =='R2U_Net':
63 | self.unet = R2U_Net(img_ch=3,output_ch=1,t=self.t)
64 | elif self.model_type =='AttU_Net':
65 | self.unet = AttU_Net(img_ch=3,output_ch=1)
66 | elif self.model_type == 'R2AttU_Net':
67 | self.unet = R2AttU_Net(img_ch=3,output_ch=1,t=self.t)
68 |
69 | self.optimizer = optim.Adam(list(self.unet.parameters()),
70 | self.lr, [self.beta1, self.beta2])
71 | self.unet.to(self.device)
72 |
73 | # self.print_network(self.unet, self.model_type)
74 |
75 | def print_network(self, model, name):
76 | """Print out the network information."""
77 | num_params = 0
78 | for p in model.parameters():
79 | num_params += p.numel()
80 | print(model)
81 | print(name)
82 | print("The number of parameters: {}".format(num_params))
83 |
84 | def to_data(self, x):
85 | """Convert variable to tensor."""
86 | if torch.cuda.is_available():
87 | x = x.cpu()
88 | return x.data
89 |
90 |     # Note: an earlier version of update_lr took (g_lr, d_lr) but assigned an
91 |     # undefined variable `lr` inside the loop. It now takes a single learning
92 |     # rate and applies it to every parameter group of the optimizer.
118 | def update_lr(self,lr):
119 | for param_group in self.optimizer.param_groups:
120 | param_group['lr'] = lr
121 |
122 | def reset_grad(self):
123 | """Zero the gradient buffers."""
124 | self.unet.zero_grad()
125 |
126 | def compute_accuracy(self,SR,GT):
127 | SR_flat = SR.view(-1)
128 | GT_flat = GT.view(-1)
129 |
130 |         acc = (GT_flat.data.cpu() == (SR_flat.data.cpu() > 0.5)).float().mean()
131 |         return acc
132 | def tensor2img(self,x):
133 | img = (x[:,0,:,:]>x[:,1,:,:]).float()
134 | img = img*255
135 | return img
136 |
137 |
138 |
139 | def train(self):
140 | """Train encoder, generator and discriminator."""
141 |
142 | #====================================== Training ===========================================#
143 |
144 |         best_unet, best_unet_score, best_epoch = None, 0., 0  # initialized up front so the test section below also works when a pretrained model is loaded
145 | unet_path = os.path.join(self.model_path, '%s-%d-%.4f-%d-%.4f.pkl' %(self.model_type,self.num_epochs,self.lr,self.num_epochs_decay,self.augmentation_prob))
146 |
147 | # U-Net Train
148 | if os.path.isfile(unet_path):
149 |             # Load the previously trained model
150 | self.unet.load_state_dict(torch.load(unet_path))
151 | print('%s is Successfully Loaded from %s'%(self.model_type,unet_path))
152 | else:
153 |             # Train from scratch
154 | lr = self.lr
157 |
158 | for epoch in range(self.num_epochs):
159 |
160 | self.unet.train(True)
161 | epoch_loss = 0
162 |
163 | acc = 0. # Accuracy
164 | SE = 0. # Sensitivity (Recall)
165 | SP = 0. # Specificity
166 | PC = 0. # Precision
167 | F1 = 0. # F1 Score
168 | JS = 0. # Jaccard Similarity
169 | DC = 0. # Dice Coefficient
170 | length = 0
171 |
172 | for i, (images, GT) in enumerate(self.train_loader):
173 | # GT : Ground Truth
174 |
175 | images = images.to(self.device)
176 | GT = GT.to(self.device)
177 |
178 | # SR : Segmentation Result
179 | SR = self.unet(images)
180 |                     SR_probs = torch.sigmoid(SR)  # F.sigmoid is deprecated
181 | SR_flat = SR_probs.view(SR_probs.size(0),-1)
182 |
183 | GT_flat = GT.view(GT.size(0),-1)
184 | loss = self.criterion(SR_flat,GT_flat)
185 | epoch_loss += loss.item()
186 |
187 | # Backprop + optimize
188 | self.reset_grad()
189 | loss.backward()
190 | self.optimizer.step()
191 |
192 |                     acc += get_accuracy(SR_probs,GT)
193 |                     SE += get_sensitivity(SR_probs,GT)
194 |                     SP += get_specificity(SR_probs,GT)
195 |                     PC += get_precision(SR_probs,GT)
196 |                     F1 += get_F1(SR_probs,GT)
197 |                     JS += get_JS(SR_probs,GT)
198 |                     DC += get_DC(SR_probs,GT)
199 |                     length += 1  # the metrics are already averaged over the batch, so count batches (thresholding raw logits at 0.5 would be wrong)
200 |
201 | acc = acc/length
202 | SE = SE/length
203 | SP = SP/length
204 | PC = PC/length
205 | F1 = F1/length
206 | JS = JS/length
207 | DC = DC/length
208 |
209 | # Print the log info
210 | print('Epoch [%d/%d], Loss: %.4f, \n[Training] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f' % (
211 | epoch+1, self.num_epochs,
212 | epoch_loss,
213 | acc,SE,SP,PC,F1,JS,DC))
214 |
215 |
216 | # Decay learning rate
217 | if (epoch+1) > (self.num_epochs - self.num_epochs_decay):
218 | lr -= (self.lr / float(self.num_epochs_decay))
219 | for param_group in self.optimizer.param_groups:
220 | param_group['lr'] = lr
221 | print ('Decay learning rate to lr: {}.'.format(lr))
222 |
223 |
224 | #===================================== Validation ====================================#
225 | self.unet.train(False)
226 | self.unet.eval()
227 |
228 | acc = 0. # Accuracy
229 | SE = 0. # Sensitivity (Recall)
230 | SP = 0. # Specificity
231 | PC = 0. # Precision
232 | F1 = 0. # F1 Score
233 | JS = 0. # Jaccard Similarity
234 | DC = 0. # Dice Coefficient
235 | length=0
236 | for i, (images, GT) in enumerate(self.valid_loader):
237 |
238 | images = images.to(self.device)
239 | GT = GT.to(self.device)
240 |                     SR = torch.sigmoid(self.unet(images))
241 | acc += get_accuracy(SR,GT)
242 | SE += get_sensitivity(SR,GT)
243 | SP += get_specificity(SR,GT)
244 | PC += get_precision(SR,GT)
245 | F1 += get_F1(SR,GT)
246 | JS += get_JS(SR,GT)
247 | DC += get_DC(SR,GT)
248 |
249 |                     length += 1  # count batches; the metrics are per-batch averages
250 |
251 | acc = acc/length
252 | SE = SE/length
253 | SP = SP/length
254 | PC = PC/length
255 | F1 = F1/length
256 | JS = JS/length
257 | DC = DC/length
258 | unet_score = JS + DC
259 |
260 | print('[Validation] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f'%(acc,SE,SP,PC,F1,JS,DC))
261 |
262 | '''
263 | torchvision.utils.save_image(images.data.cpu(),
264 | os.path.join(self.result_path,
265 | '%s_valid_%d_image.png'%(self.model_type,epoch+1)))
266 | torchvision.utils.save_image(SR.data.cpu(),
267 | os.path.join(self.result_path,
268 | '%s_valid_%d_SR.png'%(self.model_type,epoch+1)))
269 | torchvision.utils.save_image(GT.data.cpu(),
270 | os.path.join(self.result_path,
271 | '%s_valid_%d_GT.png'%(self.model_type,epoch+1)))
272 | '''
273 |
274 |
275 | # Save Best U-Net model
276 | if unet_score > best_unet_score:
277 | best_unet_score = unet_score
278 | best_epoch = epoch
279 | best_unet = self.unet.state_dict()
280 | print('Best %s model score : %.4f'%(self.model_type,best_unet_score))
281 | torch.save(best_unet,unet_path)
282 |
283 | #===================================== Test ====================================#
284 | del self.unet
285 | del best_unet
286 | self.build_model()
287 | self.unet.load_state_dict(torch.load(unet_path))
288 |
289 | self.unet.train(False)
290 | self.unet.eval()
291 |
292 | acc = 0. # Accuracy
293 | SE = 0. # Sensitivity (Recall)
294 | SP = 0. # Specificity
295 | PC = 0. # Precision
296 | F1 = 0. # F1 Score
297 | JS = 0. # Jaccard Similarity
298 | DC = 0. # Dice Coefficient
299 | length=0
300 |             for i, (images, GT) in enumerate(self.test_loader):  # the final evaluation runs on the held-out test set
301 |
302 | images = images.to(self.device)
303 | GT = GT.to(self.device)
304 |                 SR = torch.sigmoid(self.unet(images))
305 | acc += get_accuracy(SR,GT)
306 | SE += get_sensitivity(SR,GT)
307 | SP += get_specificity(SR,GT)
308 | PC += get_precision(SR,GT)
309 | F1 += get_F1(SR,GT)
310 | JS += get_JS(SR,GT)
311 | DC += get_DC(SR,GT)
312 |
313 |                 length += 1  # count batches; the metrics are per-batch averages
314 |
315 | acc = acc/length
316 | SE = SE/length
317 | SP = SP/length
318 | PC = PC/length
319 | F1 = F1/length
320 | JS = JS/length
321 | DC = DC/length
322 | unet_score = JS + DC
323 |
324 | f = open(os.path.join(self.result_path,'result.csv'), 'a', encoding='utf-8', newline='')
325 | wr = csv.writer(f)
326 | wr.writerow([self.model_type,acc,SE,SP,PC,F1,JS,DC,self.lr,best_epoch,self.num_epochs,self.num_epochs_decay,self.augmentation_prob])
327 | f.close()
328 |
--------------------------------------------------------------------------------
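The learning-rate schedule in `Solver.train()` is worth making explicit: the rate is held at `config.lr` and then reduced linearly over the final `num_epochs_decay` epochs, ending near zero. A standalone sketch of the schedule using the default settings (100 epochs, 70 decay epochs, lr 2e-4):

```python
num_epochs, num_epochs_decay, base_lr = 100, 70, 0.0002

lr = base_lr
for epoch in range(num_epochs):
    # mirrors the decay condition in Solver.train()
    if (epoch + 1) > (num_epochs - num_epochs_decay):
        lr -= base_lr / float(num_epochs_decay)
    print(epoch + 1, round(lr, 8))  # constant for epochs 1-30, then linear decay to ~0
```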
/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn import init
5 |
6 |
7 | def init_weights(net, init_type='normal', gain=0.02):
8 | def init_func(m):
9 | classname = m.__class__.__name__
10 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
11 | if init_type == 'normal':
12 | init.normal_(m.weight.data, 0.0, gain)
13 | elif init_type == 'xavier':
14 | init.xavier_normal_(m.weight.data, gain=gain)
15 | elif init_type == 'kaiming':
16 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
17 | elif init_type == 'orthogonal':
18 | init.orthogonal_(m.weight.data, gain=gain)
19 | else:
20 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
21 | if hasattr(m, 'bias') and m.bias is not None:
22 | init.constant_(m.bias.data, 0.0)
23 | elif classname.find('BatchNorm2d') != -1:
24 | init.normal_(m.weight.data, 1.0, gain)
25 | init.constant_(m.bias.data, 0.0)
26 |
27 | print('initialize network with %s' % init_type)
28 | net.apply(init_func)
29 |
30 | class conv_block(nn.Module):
31 | def __init__(self,ch_in,ch_out):
32 | super(conv_block,self).__init__()
33 | self.conv = nn.Sequential(
34 | nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
35 | nn.BatchNorm2d(ch_out),
36 | nn.ReLU(inplace=True),
37 | nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
38 | nn.BatchNorm2d(ch_out),
39 | nn.ReLU(inplace=True)
40 | )
41 |
42 | def forward(self,x):
43 | x = self.conv(x)
44 | return x
45 |
46 | class up_conv(nn.Module):
47 | def __init__(self,ch_in,ch_out):
48 | super(up_conv,self).__init__()
49 | self.up = nn.Sequential(
50 | nn.Upsample(scale_factor=2),
51 | nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
52 | nn.BatchNorm2d(ch_out),
53 | nn.ReLU(inplace=True)
54 | )
55 |
56 | def forward(self,x):
57 | x = self.up(x)
58 | return x
59 |
60 | class Recurrent_block(nn.Module):
61 | def __init__(self,ch_out,t=2):
62 | super(Recurrent_block,self).__init__()
63 | self.t = t
64 | self.ch_out = ch_out
65 | self.conv = nn.Sequential(
66 | nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
67 | nn.BatchNorm2d(ch_out),
68 | nn.ReLU(inplace=True)
69 | )
70 |
71 | def forward(self,x):
72 | for i in range(self.t):
73 |
74 | if i==0:
75 | x1 = self.conv(x)
76 |
77 | x1 = self.conv(x+x1)
78 | return x1
79 |
80 | class RRCNN_block(nn.Module):
81 | def __init__(self,ch_in,ch_out,t=2):
82 | super(RRCNN_block,self).__init__()
83 | self.RCNN = nn.Sequential(
84 | Recurrent_block(ch_out,t=t),
85 | Recurrent_block(ch_out,t=t)
86 | )
87 | self.Conv_1x1 = nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=1,padding=0)
88 |
89 | def forward(self,x):
90 | x = self.Conv_1x1(x)
91 | x1 = self.RCNN(x)
92 | return x+x1
93 |
94 |
95 | class single_conv(nn.Module):
96 | def __init__(self,ch_in,ch_out):
97 | super(single_conv,self).__init__()
98 | self.conv = nn.Sequential(
99 | nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
100 | nn.BatchNorm2d(ch_out),
101 | nn.ReLU(inplace=True)
102 | )
103 |
104 | def forward(self,x):
105 | x = self.conv(x)
106 | return x
107 |
108 | class Attention_block(nn.Module):
109 | def __init__(self,F_g,F_l,F_int):
110 | super(Attention_block,self).__init__()
111 | self.W_g = nn.Sequential(
112 | nn.Conv2d(F_g, F_int, kernel_size=1,stride=1,padding=0,bias=True),
113 | nn.BatchNorm2d(F_int)
114 | )
115 |
116 | self.W_x = nn.Sequential(
117 | nn.Conv2d(F_l, F_int, kernel_size=1,stride=1,padding=0,bias=True),
118 | nn.BatchNorm2d(F_int)
119 | )
120 |
121 | self.psi = nn.Sequential(
122 | nn.Conv2d(F_int, 1, kernel_size=1,stride=1,padding=0,bias=True),
123 | nn.BatchNorm2d(1),
124 | nn.Sigmoid()
125 | )
126 |
127 | self.relu = nn.ReLU(inplace=True)
128 |
129 | def forward(self,g,x):
130 | g1 = self.W_g(g)
131 | x1 = self.W_x(x)
132 | psi = self.relu(g1+x1)
133 | psi = self.psi(psi)
134 |
135 | return x*psi
136 |
137 |
138 | class U_Net(nn.Module):
139 | def __init__(self,img_ch=3,output_ch=1):
140 | super(U_Net,self).__init__()
141 |
142 | self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2)
143 |
144 | self.Conv1 = conv_block(ch_in=img_ch,ch_out=64)
145 | self.Conv2 = conv_block(ch_in=64,ch_out=128)
146 | self.Conv3 = conv_block(ch_in=128,ch_out=256)
147 | self.Conv4 = conv_block(ch_in=256,ch_out=512)
148 | self.Conv5 = conv_block(ch_in=512,ch_out=1024)
149 |
150 | self.Up5 = up_conv(ch_in=1024,ch_out=512)
151 | self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)
152 |
153 | self.Up4 = up_conv(ch_in=512,ch_out=256)
154 | self.Up_conv4 = conv_block(ch_in=512, ch_out=256)
155 |
156 | self.Up3 = up_conv(ch_in=256,ch_out=128)
157 | self.Up_conv3 = conv_block(ch_in=256, ch_out=128)
158 |
159 | self.Up2 = up_conv(ch_in=128,ch_out=64)
160 | self.Up_conv2 = conv_block(ch_in=128, ch_out=64)
161 |
162 | self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0)
163 |
164 |
165 | def forward(self,x):
166 | # encoding path
167 | x1 = self.Conv1(x)
168 |
169 | x2 = self.Maxpool(x1)
170 | x2 = self.Conv2(x2)
171 |
172 | x3 = self.Maxpool(x2)
173 | x3 = self.Conv3(x3)
174 |
175 | x4 = self.Maxpool(x3)
176 | x4 = self.Conv4(x4)
177 |
178 | x5 = self.Maxpool(x4)
179 | x5 = self.Conv5(x5)
180 |
181 | # decoding + concat path
182 | d5 = self.Up5(x5)
183 | d5 = torch.cat((x4,d5),dim=1)
184 |
185 | d5 = self.Up_conv5(d5)
186 |
187 | d4 = self.Up4(d5)
188 | d4 = torch.cat((x3,d4),dim=1)
189 | d4 = self.Up_conv4(d4)
190 |
191 | d3 = self.Up3(d4)
192 | d3 = torch.cat((x2,d3),dim=1)
193 | d3 = self.Up_conv3(d3)
194 |
195 | d2 = self.Up2(d3)
196 | d2 = torch.cat((x1,d2),dim=1)
197 | d2 = self.Up_conv2(d2)
198 |
199 | d1 = self.Conv_1x1(d2)
200 |
201 | return d1
202 |
203 |
204 | class R2U_Net(nn.Module):
205 | def __init__(self,img_ch=3,output_ch=1,t=2):
206 | super(R2U_Net,self).__init__()
207 |
208 | self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2)
209 | self.Upsample = nn.Upsample(scale_factor=2)
210 |
211 | self.RRCNN1 = RRCNN_block(ch_in=img_ch,ch_out=64,t=t)
212 |
213 | self.RRCNN2 = RRCNN_block(ch_in=64,ch_out=128,t=t)
214 |
215 | self.RRCNN3 = RRCNN_block(ch_in=128,ch_out=256,t=t)
216 |
217 | self.RRCNN4 = RRCNN_block(ch_in=256,ch_out=512,t=t)
218 |
219 | self.RRCNN5 = RRCNN_block(ch_in=512,ch_out=1024,t=t)
220 |
221 |
222 | self.Up5 = up_conv(ch_in=1024,ch_out=512)
223 | self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512,t=t)
224 |
225 | self.Up4 = up_conv(ch_in=512,ch_out=256)
226 | self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256,t=t)
227 |
228 | self.Up3 = up_conv(ch_in=256,ch_out=128)
229 | self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128,t=t)
230 |
231 | self.Up2 = up_conv(ch_in=128,ch_out=64)
232 | self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64,t=t)
233 |
234 | self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0)
235 |
236 |
237 | def forward(self,x):
238 | # encoding path
239 | x1 = self.RRCNN1(x)
240 |
241 | x2 = self.Maxpool(x1)
242 | x2 = self.RRCNN2(x2)
243 |
244 | x3 = self.Maxpool(x2)
245 | x3 = self.RRCNN3(x3)
246 |
247 | x4 = self.Maxpool(x3)
248 | x4 = self.RRCNN4(x4)
249 |
250 | x5 = self.Maxpool(x4)
251 | x5 = self.RRCNN5(x5)
252 |
253 | # decoding + concat path
254 | d5 = self.Up5(x5)
255 | d5 = torch.cat((x4,d5),dim=1)
256 | d5 = self.Up_RRCNN5(d5)
257 |
258 | d4 = self.Up4(d5)
259 | d4 = torch.cat((x3,d4),dim=1)
260 | d4 = self.Up_RRCNN4(d4)
261 |
262 | d3 = self.Up3(d4)
263 | d3 = torch.cat((x2,d3),dim=1)
264 | d3 = self.Up_RRCNN3(d3)
265 |
266 | d2 = self.Up2(d3)
267 | d2 = torch.cat((x1,d2),dim=1)
268 | d2 = self.Up_RRCNN2(d2)
269 |
270 | d1 = self.Conv_1x1(d2)
271 |
272 | return d1
273 |
274 |
275 |
276 | class AttU_Net(nn.Module):
277 | def __init__(self,img_ch=3,output_ch=1):
278 | super(AttU_Net,self).__init__()
279 |
280 | self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2)
281 |
282 | self.Conv1 = conv_block(ch_in=img_ch,ch_out=64)
283 | self.Conv2 = conv_block(ch_in=64,ch_out=128)
284 | self.Conv3 = conv_block(ch_in=128,ch_out=256)
285 | self.Conv4 = conv_block(ch_in=256,ch_out=512)
286 | self.Conv5 = conv_block(ch_in=512,ch_out=1024)
287 |
288 | self.Up5 = up_conv(ch_in=1024,ch_out=512)
289 | self.Att5 = Attention_block(F_g=512,F_l=512,F_int=256)
290 | self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)
291 |
292 | self.Up4 = up_conv(ch_in=512,ch_out=256)
293 | self.Att4 = Attention_block(F_g=256,F_l=256,F_int=128)
294 | self.Up_conv4 = conv_block(ch_in=512, ch_out=256)
295 |
296 | self.Up3 = up_conv(ch_in=256,ch_out=128)
297 | self.Att3 = Attention_block(F_g=128,F_l=128,F_int=64)
298 | self.Up_conv3 = conv_block(ch_in=256, ch_out=128)
299 |
300 | self.Up2 = up_conv(ch_in=128,ch_out=64)
301 | self.Att2 = Attention_block(F_g=64,F_l=64,F_int=32)
302 | self.Up_conv2 = conv_block(ch_in=128, ch_out=64)
303 |
304 | self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0)
305 |
306 |
307 | def forward(self,x):
308 | # encoding path
309 | x1 = self.Conv1(x)
310 |
311 | x2 = self.Maxpool(x1)
312 | x2 = self.Conv2(x2)
313 |
314 | x3 = self.Maxpool(x2)
315 | x3 = self.Conv3(x3)
316 |
317 | x4 = self.Maxpool(x3)
318 | x4 = self.Conv4(x4)
319 |
320 | x5 = self.Maxpool(x4)
321 | x5 = self.Conv5(x5)
322 |
323 | # decoding + concat path
324 | d5 = self.Up5(x5)
325 | x4 = self.Att5(g=d5,x=x4)
326 | d5 = torch.cat((x4,d5),dim=1)
327 | d5 = self.Up_conv5(d5)
328 |
329 | d4 = self.Up4(d5)
330 | x3 = self.Att4(g=d4,x=x3)
331 | d4 = torch.cat((x3,d4),dim=1)
332 | d4 = self.Up_conv4(d4)
333 |
334 | d3 = self.Up3(d4)
335 | x2 = self.Att3(g=d3,x=x2)
336 | d3 = torch.cat((x2,d3),dim=1)
337 | d3 = self.Up_conv3(d3)
338 |
339 | d2 = self.Up2(d3)
340 | x1 = self.Att2(g=d2,x=x1)
341 | d2 = torch.cat((x1,d2),dim=1)
342 | d2 = self.Up_conv2(d2)
343 |
344 | d1 = self.Conv_1x1(d2)
345 |
346 | return d1
347 |
348 |
349 | class R2AttU_Net(nn.Module):
350 | def __init__(self,img_ch=3,output_ch=1,t=2):
351 | super(R2AttU_Net,self).__init__()
352 |
353 | self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2)
354 | self.Upsample = nn.Upsample(scale_factor=2)
355 |
356 | self.RRCNN1 = RRCNN_block(ch_in=img_ch,ch_out=64,t=t)
357 |
358 | self.RRCNN2 = RRCNN_block(ch_in=64,ch_out=128,t=t)
359 |
360 | self.RRCNN3 = RRCNN_block(ch_in=128,ch_out=256,t=t)
361 |
362 | self.RRCNN4 = RRCNN_block(ch_in=256,ch_out=512,t=t)
363 |
364 | self.RRCNN5 = RRCNN_block(ch_in=512,ch_out=1024,t=t)
365 |
366 |
367 | self.Up5 = up_conv(ch_in=1024,ch_out=512)
368 | self.Att5 = Attention_block(F_g=512,F_l=512,F_int=256)
369 | self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512,t=t)
370 |
371 | self.Up4 = up_conv(ch_in=512,ch_out=256)
372 | self.Att4 = Attention_block(F_g=256,F_l=256,F_int=128)
373 | self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256,t=t)
374 |
375 | self.Up3 = up_conv(ch_in=256,ch_out=128)
376 | self.Att3 = Attention_block(F_g=128,F_l=128,F_int=64)
377 | self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128,t=t)
378 |
379 | self.Up2 = up_conv(ch_in=128,ch_out=64)
380 | self.Att2 = Attention_block(F_g=64,F_l=64,F_int=32)
381 | self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64,t=t)
382 |
383 | self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0)
384 |
385 |
386 | def forward(self,x):
387 | # encoding path
388 | x1 = self.RRCNN1(x)
389 |
390 | x2 = self.Maxpool(x1)
391 | x2 = self.RRCNN2(x2)
392 |
393 | x3 = self.Maxpool(x2)
394 | x3 = self.RRCNN3(x3)
395 |
396 | x4 = self.Maxpool(x3)
397 | x4 = self.RRCNN4(x4)
398 |
399 | x5 = self.Maxpool(x4)
400 | x5 = self.RRCNN5(x5)
401 |
402 | # decoding + concat path
403 | d5 = self.Up5(x5)
404 | x4 = self.Att5(g=d5,x=x4)
405 | d5 = torch.cat((x4,d5),dim=1)
406 | d5 = self.Up_RRCNN5(d5)
407 |
408 | d4 = self.Up4(d5)
409 | x3 = self.Att4(g=d4,x=x3)
410 | d4 = torch.cat((x3,d4),dim=1)
411 | d4 = self.Up_RRCNN4(d4)
412 |
413 | d3 = self.Up3(d4)
414 | x2 = self.Att3(g=d3,x=x2)
415 | d3 = torch.cat((x2,d3),dim=1)
416 | d3 = self.Up_RRCNN3(d3)
417 |
418 | d2 = self.Up2(d3)
419 | x1 = self.Att2(g=d2,x=x1)
420 | d2 = torch.cat((x1,d2),dim=1)
421 | d2 = self.Up_RRCNN2(d2)
422 |
423 | d1 = self.Conv_1x1(d2)
424 |
425 | return d1
426 |
--------------------------------------------------------------------------------
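All four networks downsample four times with 2x2 max-pooling, so input height and width must be divisible by 16 (which the final Resize in data_loader.py guarantees), and they return raw logits (the sigmoid is applied in solver.py). A quick hypothetical shape check on the two lighter models:

```python
import torch
from network import U_Net, AttU_Net

x = torch.randn(1, 3, 256, 256)  # H and W divisible by 16

for Net in (U_Net, AttU_Net):
    model = Net(img_ch=3, output_ch=1)
    model.eval()
    with torch.no_grad():
        y = model(x)
    print(Net.__name__, y.shape)  # torch.Size([1, 1, 256, 256]) for both
```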