├── misc.py ├── evaluation.py ├── data_loader.py ├── dataset.py ├── main.py ├── README.md ├── solver.py └── network.py /misc.py: -------------------------------------------------------------------------------- 1 | 2 | def printProgressBar (iteration, total, prefix = '', suffix = '', decimals = 1, length = 100, fill = '█'): 3 | """ 4 | Call in a loop to create terminal progress bar 5 | @params: 6 | iteration - Required : current iteration (Int) 7 | total - Required : total iterations (Int) 8 | prefix - Optional : prefix string (Str) 9 | suffix - Optional : suffix string (Str) 10 | decimals - Optional : positive number of decimals in percent complete (Int) 11 | length - Optional : character length of bar (Int) 12 | fill - Optional : bar fill character (Str) 13 | """ 14 | percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total))) 15 | filledLength = int(length * iteration // total) 16 | bar = fill * filledLength + '-' * (length - filledLength) 17 | print('\r%s |%s| %s%% %s' % (prefix, bar, percent, suffix), end = '\r') 18 | # Print New Line on Complete 19 | if iteration == total: 20 | print() 21 | 22 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def binary_threshold(tensor, threshold=0.5): 4 | return torch.where(tensor > threshold, torch.tensor(1), torch.tensor(0)) 5 | 6 | 7 | def get_accuracy(SR, GT, threshold=0.5): 8 | SR_binary = binary_threshold(SR, threshold) 9 | GT_binary = GT == torch.max(GT) 10 | 11 | correct_pixels = torch.sum(SR_binary == GT_binary) 12 | total_pixels = SR.numel() 13 | accuracy = float(correct_pixels) / float(total_pixels) 14 | 15 | return accuracy 16 | 17 | def get_sensitivity(SR, GT, threshold=0.5): 18 | SR_binary = binary_threshold(SR, threshold) 19 | GT_binary = GT == torch.max(GT) 20 | 21 | TP = ((SR_binary == 1) & (GT_binary == 1)).float() 22 | FN = ((SR_binary == 0) & (GT_binary == 1)).float() 23 | 24 | sensitivity = torch.sum(TP) / (torch.sum(TP + FN) + 1e-6) 25 | return sensitivity 26 | 27 | def get_specificity(SR, GT, threshold=0.5): 28 | SR_binary = binary_threshold(SR, threshold) 29 | GT_binary = GT == torch.max(GT) 30 | 31 | TN = ((SR_binary == 0) & (GT_binary == 0)).float() 32 | FP = ((SR_binary == 1) & (GT_binary == 0)).float() 33 | 34 | specificity = torch.sum(TN) / (torch.sum(TN + FP) + 1e-6) 35 | return specificity 36 | 37 | def get_precision(SR, GT, threshold=0.5): 38 | SR_binary = binary_threshold(SR, threshold) 39 | GT_binary = GT == torch.max(GT) 40 | 41 | TP = ((SR_binary == 1) & (GT_binary == 1)).float() 42 | FP = ((SR_binary == 1) & (GT_binary == 0)).float() 43 | 44 | precision = torch.sum(TP) / (torch.sum(TP + FP) + 1e-6) 45 | return precision 46 | 47 | def get_F1(SR, GT, threshold=0.5): 48 | sensitivity = get_sensitivity(SR, GT, threshold=threshold) 49 | precision = get_precision(SR, GT, threshold=threshold) 50 | 51 | F1 = 2 * sensitivity * precision / (sensitivity + precision + 1e-6) 52 | return F1 53 | 54 | def get_JS(SR, GT, threshold=0.5): 55 | SR_binary = binary_threshold(SR, threshold) 56 | GT_binary = GT == torch.max(GT) 57 | 58 | intersection = torch.sum((SR_binary + GT_binary) == 2) 59 | union = torch.sum((SR_binary + GT_binary) >= 1) 60 | 61 | JS = float(intersection) / (float(union) + 1e-6) 62 | return JS 63 | 64 | def get_DC(SR, GT, threshold=0.5): 65 | SR_binary = binary_threshold(SR, threshold) 66 | GT_binary = GT == torch.max(GT) 67 | 68 | 
intersection = torch.sum((SR_binary + GT_binary) == 2) 69 | dice_coefficient = float(2 * intersection) / (float(torch.sum(SR_binary) + torch.sum(GT_binary)) + 1e-6) 70 | 71 | return dice_coefficient 72 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from random import shuffle 4 | import numpy as np 5 | import torch 6 | from torch.utils import data 7 | from torchvision import transforms as T 8 | from torchvision.transforms import functional as F 9 | from PIL import Image 10 | 11 | 12 | class ImageFolder(data.Dataset): 13 | def __init__(self, root,image_size=224,mode='train',augmentation_prob=0.4): 14 | """Initializes image paths and preprocessing module.""" 15 | self.root = root 16 | 17 | # GT : Ground Truth 18 | self.GT_paths = root[:-1]+'_GT/' 19 | self.image_paths = list(map(lambda x: os.path.join(root, x), os.listdir(root))) 20 | self.image_size = image_size 21 | self.mode = mode 22 | self.RotationDegree = [0,90,180,270] 23 | self.augmentation_prob = augmentation_prob 24 | print("image count in {} path :{}".format(self.mode,len(self.image_paths))) 25 | 26 | def __getitem__(self, index): 27 | """Reads an image from a file and preprocesses it and returns.""" 28 | image_path = self.image_paths[index] 29 | filename = image_path.split('_')[-1][:-len(".jpg")] 30 | GT_path = self.GT_paths + 'ISIC_' + filename + '_segmentation.png' 31 | 32 | image = Image.open(image_path) 33 | GT = Image.open(GT_path) 34 | 35 | aspect_ratio = image.size[1]/image.size[0] 36 | 37 | Transform = [] 38 | 39 | ResizeRange = random.randint(300,320) 40 | Transform.append(T.Resize((int(ResizeRange*aspect_ratio),ResizeRange))) 41 | p_transform = random.random() 42 | 43 | if (self.mode == 'train') and p_transform <= self.augmentation_prob: 44 | RotationDegree = random.randint(0,3) 45 | RotationDegree = self.RotationDegree[RotationDegree] 46 | if (RotationDegree == 90) or (RotationDegree == 270): 47 | aspect_ratio = 1/aspect_ratio 48 | 49 | Transform.append(T.RandomRotation((RotationDegree,RotationDegree))) 50 | 51 | RotationRange = random.randint(-10,10) 52 | Transform.append(T.RandomRotation((RotationRange,RotationRange))) 53 | CropRange = random.randint(250,270) 54 | Transform.append(T.CenterCrop((int(CropRange*aspect_ratio),CropRange))) 55 | Transform = T.Compose(Transform) 56 | 57 | image = Transform(image) 58 | GT = Transform(GT) 59 | 60 | ShiftRange_left = random.randint(0,20) 61 | ShiftRange_upper = random.randint(0,20) 62 | ShiftRange_right = image.size[0] - random.randint(0,20) 63 | ShiftRange_lower = image.size[1] - random.randint(0,20) 64 | image = image.crop(box=(ShiftRange_left,ShiftRange_upper,ShiftRange_right,ShiftRange_lower)) 65 | GT = GT.crop(box=(ShiftRange_left,ShiftRange_upper,ShiftRange_right,ShiftRange_lower)) 66 | 67 | if random.random() < 0.5: 68 | image = F.hflip(image) 69 | GT = F.hflip(GT) 70 | 71 | if random.random() < 0.5: 72 | image = F.vflip(image) 73 | GT = F.vflip(GT) 74 | 75 | Transform = T.ColorJitter(brightness=0.2,contrast=0.2,hue=0.02) 76 | 77 | image = Transform(image) 78 | 79 | Transform =[] 80 | 81 | 82 | Transform.append(T.Resize((int(256*aspect_ratio)-int(256*aspect_ratio)%16,256))) 83 | Transform.append(T.ToTensor()) 84 | Transform = T.Compose(Transform) 85 | 86 | image = Transform(image) 87 | GT = Transform(GT) 88 | 89 | Norm_ = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 90 | image = Norm_(image) 91 | 92 | 
return image, GT 93 | 94 | def __len__(self): 95 | """Returns the total number of image files.""" 96 | return len(self.image_paths) 97 | 98 | def get_loader(image_path, image_size, batch_size, num_workers=2, mode='train',augmentation_prob=0.4): 99 | """Builds and returns Dataloader.""" 100 | 101 | dataset = ImageFolder(root = image_path, image_size =image_size, mode=mode,augmentation_prob=augmentation_prob) 102 | data_loader = data.DataLoader(dataset=dataset, 103 | batch_size=batch_size, 104 | shuffle=True, 105 | num_workers=num_workers) 106 | return data_loader 107 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import random 4 | import shutil 5 | from shutil import copyfile 6 | from misc import printProgressBar 7 | 8 | 9 | def rm_mkdir(dir_path): 10 | if os.path.exists(dir_path): 11 | shutil.rmtree(dir_path) 12 | print('Remove path - %s'%dir_path) 13 | os.makedirs(dir_path) 14 | print('Create path - %s'%dir_path) 15 | 16 | def main(config): 17 | 18 | rm_mkdir(config.train_path) 19 | rm_mkdir(config.train_GT_path) 20 | rm_mkdir(config.valid_path) 21 | rm_mkdir(config.valid_GT_path) 22 | rm_mkdir(config.test_path) 23 | rm_mkdir(config.test_GT_path) 24 | 25 | filenames = os.listdir(config.origin_data_path) 26 | data_list = [] 27 | GT_list = [] 28 | 29 | for filename in filenames: 30 | ext = os.path.splitext(filename)[-1] 31 | if ext =='.jpg': 32 | filename = filename.split('_')[-1][:-len('.jpg')] 33 | data_list.append('ISIC_'+filename+'.jpg') 34 | GT_list.append('ISIC_'+filename+'_segmentation.png') 35 | 36 | num_total = len(data_list) 37 | num_train = int((config.train_ratio/(config.train_ratio+config.valid_ratio+config.test_ratio))*num_total) 38 | num_valid = int((config.valid_ratio/(config.train_ratio+config.valid_ratio+config.test_ratio))*num_total) 39 | num_test = num_total - num_train - num_valid 40 | 41 | print('\nNum of train set : ',num_train) 42 | print('\nNum of valid set : ',num_valid) 43 | print('\nNum of test set : ',num_test) 44 | 45 | Arange = list(range(num_total)) 46 | random.shuffle(Arange) 47 | 48 | for i in range(num_train): 49 | idx = Arange.pop() 50 | 51 | src = os.path.join(config.origin_data_path, data_list[idx]) 52 | dst = os.path.join(config.train_path,data_list[idx]) 53 | copyfile(src, dst) 54 | 55 | src = os.path.join(config.origin_GT_path, GT_list[idx]) 56 | dst = os.path.join(config.train_GT_path, GT_list[idx]) 57 | copyfile(src, dst) 58 | 59 | printProgressBar(i + 1, num_train, prefix = 'Producing train set:', suffix = 'Complete', length = 50) 60 | 61 | 62 | for i in range(num_valid): 63 | idx = Arange.pop() 64 | 65 | src = os.path.join(config.origin_data_path, data_list[idx]) 66 | dst = os.path.join(config.valid_path,data_list[idx]) 67 | copyfile(src, dst) 68 | 69 | src = os.path.join(config.origin_GT_path, GT_list[idx]) 70 | dst = os.path.join(config.valid_GT_path, GT_list[idx]) 71 | copyfile(src, dst) 72 | 73 | printProgressBar(i + 1, num_valid, prefix = 'Producing valid set:', suffix = 'Complete', length = 50) 74 | 75 | for i in range(num_test): 76 | idx = Arange.pop() 77 | 78 | src = os.path.join(config.origin_data_path, data_list[idx]) 79 | dst = os.path.join(config.test_path,data_list[idx]) 80 | copyfile(src, dst) 81 | 82 | src = os.path.join(config.origin_GT_path, GT_list[idx]) 83 | dst = os.path.join(config.test_GT_path, GT_list[idx]) 84 | copyfile(src, dst) 85 | 86 | 87 | 
printProgressBar(i + 1, num_test, prefix = 'Producing test set:', suffix = 'Complete', length = 50) 88 | 89 | if __name__ == '__main__': 90 | parser = argparse.ArgumentParser() 91 | 92 | # model hyper-parameters 93 | parser.add_argument('--train_ratio', type=float, default=0.6) 94 | parser.add_argument('--valid_ratio', type=float, default=0.2) 95 | parser.add_argument('--test_ratio', type=float, default=0.2) 96 | 97 | # data path 98 | parser.add_argument('--origin_data_path', type=str, default='../ISIC/dataset/ISIC2018_Task1-2_Training_Input') 99 | parser.add_argument('--origin_GT_path', type=str, default='../ISIC/dataset/ISIC2018_Task1_Training_GroundTruth') 100 | 101 | parser.add_argument('--train_path', type=str, default='./dataset/train/') 102 | parser.add_argument('--train_GT_path', type=str, default='./dataset/train_GT/') 103 | parser.add_argument('--valid_path', type=str, default='./dataset/valid/') 104 | parser.add_argument('--valid_GT_path', type=str, default='./dataset/valid_GT/') 105 | parser.add_argument('--test_path', type=str, default='./dataset/test/') 106 | parser.add_argument('--test_GT_path', type=str, default='./dataset/test_GT/') 107 | 108 | config = parser.parse_args() 109 | print(config) 110 | main(config) 111 | 112 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from solver import Solver 4 | from data_loader import get_loader 5 | from torch.backends import cudnn 6 | import random 7 | 8 | def main(config): 9 | cudnn.benchmark = True 10 | # Set the model type here (this overrides the --model_type command-line flag) 11 | config.model_type='R2U_Net' 12 | 13 | if config.model_type not in ['U_Net','R2U_Net','AttU_Net','R2AttU_Net']: 14 | print('ERROR!! model_type should be one of U_Net/R2U_Net/AttU_Net/R2AttU_Net') 15 | print('Your input for model_type was %s'%config.model_type) 16 | return 17 | 18 | # Create directories if not exist 19 | if not os.path.exists(config.model_path): 20 | os.makedirs(config.model_path) 21 | if not os.path.exists(config.result_path): 22 | os.makedirs(config.result_path) 23 | config.result_path = os.path.join(config.result_path,config.model_type) 24 | if not os.path.exists(config.result_path): 25 | os.makedirs(config.result_path) 26 | 27 | # Randomly sample hyper-parameters for this run 28 | lr = random.random()*0.0005 + 0.0000005 29 | augmentation_prob= random.random()*0.7 30 | epoch = random.choice([3,5,7]) 31 | decay_ratio = random.random()*0.8 32 | decay_epoch = int(epoch*decay_ratio) 33 | 34 | config.augmentation_prob = augmentation_prob 35 | config.num_epochs = epoch 36 | config.lr = lr 37 | config.num_epochs_decay = decay_epoch 38 | 39 | print(config) 40 | 41 | train_loader = get_loader(image_path=config.train_path, 42 | image_size=config.image_size, 43 | batch_size=config.batch_size, 44 | num_workers=config.num_workers, 45 | mode='train', 46 | augmentation_prob=config.augmentation_prob) 47 | valid_loader = get_loader(image_path=config.valid_path, 48 | image_size=config.image_size, 49 | batch_size=config.batch_size, 50 | num_workers=config.num_workers, 51 | mode='valid', 52 | augmentation_prob=0.) 53 | test_loader = get_loader(image_path=config.test_path, 54 | image_size=config.image_size, 55 | batch_size=config.batch_size, 56 | num_workers=config.num_workers, 57 | mode='test', 58 | augmentation_prob=0.) 
59 | 60 | solver = Solver(config, train_loader, valid_loader, test_loader) 61 | 62 | # Train and sample the images 63 | if config.mode == 'train': 64 | solver.train() 65 | elif config.mode == 'test': 66 | solver.test() 67 | 68 | if __name__ == '__main__': 69 | parser = argparse.ArgumentParser() 70 | 71 | # model hyper-parameters 72 | parser.add_argument('--image_size', type=int, default=224) 73 | parser.add_argument('--t', type=int, default=3, help='t for Recurrent step of R2U_Net or R2AttU_Net') 74 | 75 | # training hyper-parameters 76 | parser.add_argument('--img_ch', type=int, default=3) 77 | parser.add_argument('--output_ch', type=int, default=1) 78 | parser.add_argument('--num_epochs', type=int, default=100) 79 | parser.add_argument('--num_epochs_decay', type=int, default=70) 80 | parser.add_argument('--batch_size', type=int, default=1) 81 | parser.add_argument('--num_workers', type=int, default=8) 82 | parser.add_argument('--lr', type=float, default=0.0002) 83 | parser.add_argument('--beta1', type=float, default=0.5) # momentum1 in Adam 84 | parser.add_argument('--beta2', type=float, default=0.999) # momentum2 in Adam 85 | parser.add_argument('--augmentation_prob', type=float, default=0.4) 86 | 87 | parser.add_argument('--log_step', type=int, default=2) 88 | parser.add_argument('--val_step', type=int, default=2) 89 | 90 | # misc 91 | parser.add_argument('--mode', type=str, default='train') 92 | parser.add_argument('--model_type', type=str, default='U_Net', help='U_Net/R2U_Net/AttU_Net/R2AttU_Net') 93 | parser.add_argument('--model_path', type=str, default='./models') 94 | parser.add_argument('--train_path', type=str, default='./dataset/train/') 95 | parser.add_argument('--valid_path', type=str, default='./dataset/valid/') 96 | parser.add_argument('--test_path', type=str, default='./dataset/test/') 97 | parser.add_argument('--result_path', type=str, default='./result/') 98 | parser.add_argument('--cuda_idx', type=int, default=1) 99 | 100 | config = parser.parse_args() 101 | main(config) 102 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | ## This project implements training of the U-Net, R2U-Net, Attention U-Net, and Attention R2U-Net models in PyTorch. The key parameters of the four models are analyzed and compared in detail, with the aim of giving a more complete picture of each model's strengths and weaknesses. 3 | Note 1: To avoid path-lookup errors when running the code, download the project into a newly created **ISIC** directory
![image.png](https://cdn.nlark.com/yuque/0/2024/png/21820237/1706884484205-ea376c7b-8bab-4ca4-891d-96fc7500d427.png#averageHue=%23fbf9f8&clientId=uffba3bf4-4186-4&from=paste&height=375&id=ud933ff88&originHeight=540&originWidth=818&originalType=binary&ratio=1.25&rotation=0&showTitle=false&size=61230&status=done&style=none&taskId=uadb96c6c-0387-47f9-82f8-be3082e2b07&title=&width=567.4000244140625)
Note 2: Create the **dataset**, **models**, and **result** folders yourself<br/>
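If you prefer to create these folders from code, here is a minimal convenience sketch (not part of the repository; main.py also creates the models and result folders itself when they are missing):
```python
import os

# Create the three top-level folders the notes above ask for.
for d in ('dataset', 'models', 'result'):
    os.makedirs(d, exist_ok=True)
```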
Note 3: The **训练结果** (training results) folder holds data from my own experiments and is for reference only<br/>
Note 4: Runtime environment: **python3.9**<br/>
![image.png](https://cdn.nlark.com/yuque/0/2024/png/21820237/1706927257238-c0d5a2b6-b96f-4af3-80bd-e934523c1f92.png#averageHue=%23dbe2db&clientId=u25d34394-3259-4&from=paste&height=397&id=ub8baff11&originHeight=496&originWidth=1095&originalType=binary&ratio=1.25&rotation=0&showTitle=false&size=55770&status=done&style=none&taskId=u6d649412-01cd-498b-8c8f-925368299ef&title=&width=876) 4 | 5 | 6 | # Dataset download 7 | The **ISIC-2018** dataset is used: [https://challenge.isic-archive.com/data/#2018](https://challenge.isic-archive.com/data/#2018)<br />
The dataset is split into three subsets: a training set, a validation set, and a test set, accounting for 70%, 10%, and 20% of the whole dataset respectively. The full dataset contains 2594 images, of which 1815 are used for training, 259 for validation, and 520 for testing the model.<br/>
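Note that dataset.py defaults to a 0.6/0.2/0.2 split, so reproducing the 70%/10%/20% split above requires passing the ratios explicitly, e.g. `python dataset.py --train_ratio 0.7 --valid_ratio 0.1 --test_ratio 0.2`. The subset sizes then follow from the arithmetic in dataset.py's main():
```python
# Split-size arithmetic as in dataset.py, using the ratios quoted above.
num_total = 2594
train_ratio, valid_ratio, test_ratio = 0.7, 0.1, 0.2
ratio_sum = train_ratio + valid_ratio + test_ratio
num_train = int(train_ratio / ratio_sum * num_total)  # 1815
num_valid = int(valid_ratio / ratio_sum * num_total)  # 259
num_test = num_total - num_train - num_valid          # 520
```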
**Step1**: Download the files boxed in the screenshot below<br/>
![image.png](https://cdn.nlark.com/yuque/0/2024/png/21820237/1706882998279-b095698c-6fad-4141-8872-ec8581e30fc2.png#averageHue=%23fbfaf9&clientId=uffba3bf4-4186-4&from=paste&height=205&id=u74c32ede&originHeight=537&originWidth=1620&originalType=binary&ratio=1.25&rotation=0&showTitle=false&size=113273&status=done&style=none&taskId=ua4ce2aa6-9acd-4e4a-9d26-198116592b1&title=&width=617.4000244140625)
**Step2**: After unzipping, the folder names are as shown below; place the two folders under the dataset directory. No further steps are needed.<br/>
![image.png](https://cdn.nlark.com/yuque/0/2024/png/21820237/1706883240550-808fd460-9b0f-4044-9646-776c6acc7336.png#averageHue=%23fbf9f7&clientId=uffba3bf4-4186-4&from=paste&height=56&id=u84a35588&originHeight=70&originWidth=843&originalType=binary&ratio=1.25&rotation=0&showTitle=false&size=8135&status=done&style=none&taskId=u8c38396c-ef35-4408-88ee-71933b9dcd2&title=&width=674.4) 8 | 9 | 10 | # How to run the code 11 | **Step1**: Run the **dataset.py** file on its own to preprocess the dataset<br/>
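Be aware that dataset.py starts by deleting and recreating the six output folders (train/valid/test and their _GT counterparts) via rm_mkdir, so re-running it discards any previous split before copying the images afresh. It is run as `python dataset.py`; the ratio and path flags defined at the bottom of the file can be overridden on the command line.<br/>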
**Step2**: Configure the **main.py** file<br/>
![image.png](https://cdn.nlark.com/yuque/0/2024/png/21820237/1706885131681-b7b21e39-c310-4025-a19a-f422b40afac5.png#averageHue=%23fcf8f7&clientId=uffba3bf4-4186-4&from=paste&height=427&id=u62ea1e1e&originHeight=534&originWidth=1078&originalType=binary&ratio=1.25&rotation=0&showTitle=false&size=73617&status=done&style=none&taskId=uf597efb0-e5d1-4f1d-ad93-b1a258cf65a&title=&width=862.4)
![image.png](https://cdn.nlark.com/yuque/0/2024/png/21820237/1706885261905-27164510-7d16-46ff-85c1-8e240d3d000c.png#averageHue=%23f9f6f5&clientId=uffba3bf4-4186-4&from=paste&height=383&id=u4feee0e5&originHeight=479&originWidth=1146&originalType=binary&ratio=1.25&rotation=0&showTitle=false&size=88746&status=done&style=none&taskId=u89edc979-e4ee-460d-9252-d87374fbcad&title=&width=916.8)
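The screenshots above show edits to main.py. The corresponding assignments live in main(), where the model type is hard-coded (overriding the --model_type flag) and the hyper-parameters are re-sampled at random on every run; here they are as a standalone sketch with the `config.` prefix dropped:
```python
import random

model_type = 'R2U_Net'  # one of U_Net / R2U_Net / AttU_Net / R2AttU_Net
lr = random.random()*0.0005 + 0.0000005   # learning rate
augmentation_prob = random.random()*0.7   # probability of train-time augmentation
epoch = random.choice([3, 5, 7])          # number of training epochs
decay_ratio = random.random()*0.8
decay_epoch = int(epoch*decay_ratio)      # number of final epochs over which lr decays
print(model_type, lr, augmentation_prob, epoch, decay_epoch)
```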
**Step3**: Check the output after each run<br/>
![image.png](https://cdn.nlark.com/yuque/0/2024/png/21820237/1706885337722-733e8880-fd63-4070-b5b3-c21b5d15886c.png#averageHue=%23fbf4f3&clientId=uffba3bf4-4186-4&from=paste&height=418&id=Fl6EB&originHeight=522&originWidth=794&originalType=binary&ratio=1.25&rotation=0&showTitle=false&size=58280&status=done&style=none&taskId=uecc18f01-0aa3-4a1c-8b1a-c405c99b20f&title=&width=635.2) 12 | 13 | # Model architecture overview 14 | 15 | ## U-Net 16 | [![](https://github.com/LeeJunHyun/Image_Segmentation/raw/master/img/U-Net.png#from=url&id=iZhyF&originHeight=278&originWidth=418&originalType=binary&ratio=1.25&rotation=0&showTitle=false&status=done&style=none&title=)](https://github.com/LeeJunHyun/Image_Segmentation/blob/master/img/U-Net.png) 17 | 18 | ## R2U-Net 19 | [![](https://github.com/LeeJunHyun/Image_Segmentation/raw/master/img/R2U-Net.png#from=url&id=dbA4K&originHeight=335&originWidth=960&originalType=binary&ratio=1.25&rotation=0&showTitle=false&status=done&style=none&title=)](https://github.com/LeeJunHyun/Image_Segmentation/blob/master/img/R2U-Net.png) 20 | 21 | ## Attention U-Net 22 | [![](https://github.com/LeeJunHyun/Image_Segmentation/raw/master/img/AttU-Net.png#from=url&id=nUv0K&originHeight=822&originWidth=1272&originalType=binary&ratio=1.25&rotation=0&showTitle=false&status=done&style=none&title=)](https://github.com/LeeJunHyun/Image_Segmentation/blob/master/img/AttU-Net.png) 23 | 24 | ## Attention R2U-Net 25 | [![](https://github.com/LeeJunHyun/Image_Segmentation/raw/master/img/AttR2U-Net.png#from=url&id=lqcCP&originHeight=522&originWidth=1500&originalType=binary&ratio=1.25&rotation=0&showTitle=false&status=done&style=none&title=)](https://github.com/LeeJunHyun/Image_Segmentation/blob/master/img/AttR2U-Net.png) 26 | 27 | # Model evaluation results 28 | [![](https://github.com/LeeJunHyun/Image_Segmentation/raw/master/img/Evaluation.png#from=url&id=toXKM&originHeight=673&originWidth=1670&originalType=binary&ratio=1.25&rotation=0&showTitle=false&status=done&style=none&title=)](https://github.com/LeeJunHyun/Image_Segmentation/blob/master/img/Evaluation.png) 29 | -------------------------------------------------------------------------------- /solver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import time 4 | import datetime 5 | import torch 6 | import torchvision 7 | from torch import optim 8 | from torch.autograd import Variable 9 | import torch.nn.functional as F 10 | from evaluation import * 11 | from network import U_Net,R2U_Net,AttU_Net,R2AttU_Net 12 | import csv 13 | 14 | 15 | class Solver(object): 16 | best_unet = None 17 | best_unet_score = 0. 
18 | best_epoch = 0 19 | def __init__(self, config, train_loader, valid_loader, test_loader): 20 | 21 | # Data loader 22 | self.train_loader = train_loader 23 | self.valid_loader = valid_loader 24 | self.test_loader = test_loader 25 | 26 | # Models 27 | self.unet = None 28 | self.optimizer = None 29 | self.img_ch = config.img_ch 30 | self.output_ch = config.output_ch 31 | self.criterion = torch.nn.BCELoss() 32 | self.augmentation_prob = config.augmentation_prob 33 | 34 | # Hyper-parameters 35 | self.lr = config.lr 36 | self.beta1 = config.beta1 37 | self.beta2 = config.beta2 38 | 39 | # Training settings 40 | self.num_epochs = config.num_epochs 41 | self.num_epochs_decay = config.num_epochs_decay 42 | self.batch_size = config.batch_size 43 | 44 | # Step size 45 | self.log_step = config.log_step 46 | self.val_step = config.val_step 47 | 48 | # Path 49 | self.model_path = config.model_path 50 | self.result_path = config.result_path 51 | self.mode = config.mode 52 | 53 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 54 | self.model_type = config.model_type 55 | self.t = config.t 56 | self.build_model() 57 | 58 | def build_model(self): 59 | """Build the selected U-Net variant and its Adam optimizer.""" 60 | if self.model_type =='U_Net': 61 | self.unet = U_Net(img_ch=3,output_ch=1) 62 | elif self.model_type =='R2U_Net': 63 | self.unet = R2U_Net(img_ch=3,output_ch=1,t=self.t) 64 | elif self.model_type =='AttU_Net': 65 | self.unet = AttU_Net(img_ch=3,output_ch=1) 66 | elif self.model_type == 'R2AttU_Net': 67 | self.unet = R2AttU_Net(img_ch=3,output_ch=1,t=self.t) 68 | 69 | self.optimizer = optim.Adam(list(self.unet.parameters()), 70 | self.lr, [self.beta1, self.beta2]) 71 | self.unet.to(self.device) 72 | 73 | # self.print_network(self.unet, self.model_type) 74 | 75 | def print_network(self, model, name): 76 | """Print out the network information.""" 77 | num_params = 0 78 | for p in model.parameters(): 79 | num_params += p.numel() 80 | print(model) 81 | print(name) 82 | print("The number of parameters: {}".format(num_params)) 83 | 84 | def to_data(self, x): 85 | """Move a tensor to CPU and return its data.""" 86 | if torch.cuda.is_available(): 87 | x = x.cpu() 88 | return x.data 89 | 118 | def update_lr(self,lr): 119 | for param_group in self.optimizer.param_groups: 120 | param_group['lr'] = lr 121 | 122 | def reset_grad(self): 123 | """Zero the gradient buffers.""" 124 | self.unet.zero_grad() 125 | 126 | def compute_accuracy(self,SR,GT): 127 | SR_flat = SR.view(-1) 128 | GT_flat = GT.view(-1) 129 | 130 | acc = GT_flat.data.cpu()==(SR_flat.data.cpu()>0.5) 131 | return acc.float().mean() # fraction of correctly classified pixels 132 | def tensor2img(self,x): 133 | img = (x[:,0,:,:]>x[:,1,:,:]).float() 134 | img = img*255 135 | return img 136 | 137 | 138 | 139 | def train(self): 140 | """Train the network, validate each epoch, and test the best model.""" 141 | 142 | 
#====================================== Training ===========================================# 143 | 144 | global best_epoch 145 | unet_path = os.path.join(self.model_path, '%s-%d-%.4f-%d-%.4f.pkl' %(self.model_type,self.num_epochs,self.lr,self.num_epochs_decay,self.augmentation_prob)) 146 | 147 | # U-Net Train 148 | if os.path.isfile(unet_path): 149 | # Load the pretrained Encoder 150 | self.unet.load_state_dict(torch.load(unet_path)) 151 | print('%s is Successfully Loaded from %s'%(self.model_type,unet_path)) 152 | else: 153 | # Train for Encoder 154 | lr = self.lr 155 | best_unet_score = 0. 156 | best_unet = None 157 | 158 | for epoch in range(self.num_epochs): 159 | 160 | self.unet.train(True) 161 | epoch_loss = 0 162 | 163 | acc = 0. # Accuracy 164 | SE = 0. # Sensitivity (Recall) 165 | SP = 0. # Specificity 166 | PC = 0. # Precision 167 | F1 = 0. # F1 Score 168 | JS = 0. # Jaccard Similarity 169 | DC = 0. # Dice Coefficient 170 | length = 0 171 | 172 | for i, (images, GT) in enumerate(self.train_loader): 173 | # GT : Ground Truth 174 | 175 | images = images.to(self.device) 176 | GT = GT.to(self.device) 177 | 178 | # SR : Segmentation Result 179 | SR = self.unet(images) 180 | SR_probs = F.sigmoid(SR) 181 | SR_flat = SR_probs.view(SR_probs.size(0),-1) 182 | 183 | GT_flat = GT.view(GT.size(0),-1) 184 | loss = self.criterion(SR_flat,GT_flat) 185 | epoch_loss += loss.item() 186 | 187 | # Backprop + optimize 188 | self.reset_grad() 189 | loss.backward() 190 | self.optimizer.step() 191 | 192 | acc += get_accuracy(SR,GT) 193 | SE += get_sensitivity(SR,GT) 194 | SP += get_specificity(SR,GT) 195 | PC += get_precision(SR,GT) 196 | F1 += get_F1(SR,GT) 197 | JS += get_JS(SR,GT) 198 | DC += get_DC(SR,GT) 199 | length += images.size(0) 200 | 201 | acc = acc/length 202 | SE = SE/length 203 | SP = SP/length 204 | PC = PC/length 205 | F1 = F1/length 206 | JS = JS/length 207 | DC = DC/length 208 | 209 | # Print the log info 210 | print('Epoch [%d/%d], Loss: %.4f, \n[Training] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f' % ( 211 | epoch+1, self.num_epochs, 212 | epoch_loss, 213 | acc,SE,SP,PC,F1,JS,DC)) 214 | 215 | 216 | # Decay learning rate 217 | if (epoch+1) > (self.num_epochs - self.num_epochs_decay): 218 | lr -= (self.lr / float(self.num_epochs_decay)) 219 | for param_group in self.optimizer.param_groups: 220 | param_group['lr'] = lr 221 | print ('Decay learning rate to lr: {}.'.format(lr)) 222 | 223 | 224 | #===================================== Validation ====================================# 225 | self.unet.train(False) 226 | self.unet.eval() 227 | 228 | acc = 0. # Accuracy 229 | SE = 0. # Sensitivity (Recall) 230 | SP = 0. # Specificity 231 | PC = 0. # Precision 232 | F1 = 0. # F1 Score 233 | JS = 0. # Jaccard Similarity 234 | DC = 0. 
# Dice Coefficient 235 | length=0 236 | for i, (images, GT) in enumerate(self.valid_loader): 237 | 238 | images = images.to(self.device) 239 | GT = GT.to(self.device) 240 | SR = F.sigmoid(self.unet(images)) 241 | acc += get_accuracy(SR,GT) 242 | SE += get_sensitivity(SR,GT) 243 | SP += get_specificity(SR,GT) 244 | PC += get_precision(SR,GT) 245 | F1 += get_F1(SR,GT) 246 | JS += get_JS(SR,GT) 247 | DC += get_DC(SR,GT) 248 | 249 | length += images.size(0) 250 | 251 | acc = acc/length 252 | SE = SE/length 253 | SP = SP/length 254 | PC = PC/length 255 | F1 = F1/length 256 | JS = JS/length 257 | DC = DC/length 258 | unet_score = JS + DC 259 | 260 | print('[Validation] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f'%(acc,SE,SP,PC,F1,JS,DC)) 261 | 262 | ''' 263 | torchvision.utils.save_image(images.data.cpu(), 264 | os.path.join(self.result_path, 265 | '%s_valid_%d_image.png'%(self.model_type,epoch+1))) 266 | torchvision.utils.save_image(SR.data.cpu(), 267 | os.path.join(self.result_path, 268 | '%s_valid_%d_SR.png'%(self.model_type,epoch+1))) 269 | torchvision.utils.save_image(GT.data.cpu(), 270 | os.path.join(self.result_path, 271 | '%s_valid_%d_GT.png'%(self.model_type,epoch+1))) 272 | ''' 273 | 274 | 275 | # Save Best U-Net model 276 | if unet_score > best_unet_score: 277 | best_unet_score = unet_score 278 | best_epoch = epoch 279 | best_unet = self.unet.state_dict() 280 | print('Best %s model score : %.4f'%(self.model_type,best_unet_score)) 281 | torch.save(best_unet,unet_path) 282 | 283 | #===================================== Test ====================================# 284 | del self.unet 285 | del best_unet 286 | self.build_model() 287 | self.unet.load_state_dict(torch.load(unet_path)) 288 | 289 | self.unet.train(False) 290 | self.unet.eval() 291 | 292 | acc = 0. # Accuracy 293 | SE = 0. # Sensitivity (Recall) 294 | SP = 0. # Specificity 295 | PC = 0. # Precision 296 | F1 = 0. # F1 Score 297 | JS = 0. # Jaccard Similarity 298 | DC = 0. 
# Dice Coefficient 299 | length=0 300 | for i, (images, GT) in enumerate(self.test_loader): # evaluate the best model on the held-out test set 301 | 302 | images = images.to(self.device) 303 | GT = GT.to(self.device) 304 | SR = F.sigmoid(self.unet(images)) 305 | acc += get_accuracy(SR,GT) 306 | SE += get_sensitivity(SR,GT) 307 | SP += get_specificity(SR,GT) 308 | PC += get_precision(SR,GT) 309 | F1 += get_F1(SR,GT) 310 | JS += get_JS(SR,GT) 311 | DC += get_DC(SR,GT) 312 | 313 | length += images.size(0) 314 | 315 | acc = acc/length 316 | SE = SE/length 317 | SP = SP/length 318 | PC = PC/length 319 | F1 = F1/length 320 | JS = JS/length 321 | DC = DC/length 322 | unet_score = JS + DC 323 | 324 | f = open(os.path.join(self.result_path,'result.csv'), 'a', encoding='utf-8', newline='') 325 | wr = csv.writer(f) 326 | wr.writerow([self.model_type,acc,SE,SP,PC,F1,JS,DC,self.lr,best_epoch,self.num_epochs,self.num_epochs_decay,self.augmentation_prob]) 327 | f.close() 328 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import init 5 | 6 | 7 | def init_weights(net, init_type='normal', gain=0.02): 8 | def init_func(m): 9 | classname = m.__class__.__name__ 10 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 11 | if init_type == 'normal': 12 | init.normal_(m.weight.data, 0.0, gain) 13 | elif init_type == 'xavier': 14 | init.xavier_normal_(m.weight.data, gain=gain) 15 | elif init_type == 'kaiming': 16 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 17 | elif init_type == 'orthogonal': 18 | init.orthogonal_(m.weight.data, gain=gain) 19 | else: 20 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 21 | if hasattr(m, 'bias') and m.bias is not None: 22 | init.constant_(m.bias.data, 0.0) 23 | elif classname.find('BatchNorm2d') != -1: 24 | init.normal_(m.weight.data, 1.0, gain) 25 | init.constant_(m.bias.data, 0.0) 26 | 27 | print('initialize network with %s' % init_type) 28 | net.apply(init_func) 29 | 30 | class conv_block(nn.Module): 31 | def __init__(self,ch_in,ch_out): 32 | super(conv_block,self).__init__() 33 | self.conv = nn.Sequential( 34 | nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True), 35 | nn.BatchNorm2d(ch_out), 36 | nn.ReLU(inplace=True), 37 | nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True), 38 | nn.BatchNorm2d(ch_out), 39 | nn.ReLU(inplace=True) 40 | ) 41 | 42 | def forward(self,x): 43 | x = self.conv(x) 44 | return x 45 | 46 | class up_conv(nn.Module): 47 | def __init__(self,ch_in,ch_out): 48 | super(up_conv,self).__init__() 49 | self.up = nn.Sequential( 50 | nn.Upsample(scale_factor=2), 51 | nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True), 52 | nn.BatchNorm2d(ch_out), 53 | nn.ReLU(inplace=True) 54 | ) 55 | 56 | def forward(self,x): 57 | x = self.up(x) 58 | return x 59 | 60 | class Recurrent_block(nn.Module): 61 | def __init__(self,ch_out,t=2): 62 | super(Recurrent_block,self).__init__() 63 | self.t = t 64 | self.ch_out = ch_out 65 | self.conv = nn.Sequential( 66 | nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1,bias=True), 67 | nn.BatchNorm2d(ch_out), 68 | nn.ReLU(inplace=True) 69 | ) 70 | 71 | def forward(self,x): 72 | for i in range(self.t): 73 | 74 | if i==0: 75 | x1 = self.conv(x) 76 | 77 | x1 = self.conv(x+x1) 78 | return x1 79 | 80 | class
RRCNN_block(nn.Module): 81 | def __init__(self,ch_in,ch_out,t=2): 82 | super(RRCNN_block,self).__init__() 83 | self.RCNN = nn.Sequential( 84 | Recurrent_block(ch_out,t=t), 85 | Recurrent_block(ch_out,t=t) 86 | ) 87 | self.Conv_1x1 = nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=1,padding=0) 88 | 89 | def forward(self,x): 90 | x = self.Conv_1x1(x) 91 | x1 = self.RCNN(x) 92 | return x+x1 93 | 94 | 95 | class single_conv(nn.Module): 96 | def __init__(self,ch_in,ch_out): 97 | super(single_conv,self).__init__() 98 | self.conv = nn.Sequential( 99 | nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True), 100 | nn.BatchNorm2d(ch_out), 101 | nn.ReLU(inplace=True) 102 | ) 103 | 104 | def forward(self,x): 105 | x = self.conv(x) 106 | return x 107 | 108 | class Attention_block(nn.Module): 109 | def __init__(self,F_g,F_l,F_int): 110 | super(Attention_block,self).__init__() 111 | self.W_g = nn.Sequential( 112 | nn.Conv2d(F_g, F_int, kernel_size=1,stride=1,padding=0,bias=True), 113 | nn.BatchNorm2d(F_int) 114 | ) 115 | 116 | self.W_x = nn.Sequential( 117 | nn.Conv2d(F_l, F_int, kernel_size=1,stride=1,padding=0,bias=True), 118 | nn.BatchNorm2d(F_int) 119 | ) 120 | 121 | self.psi = nn.Sequential( 122 | nn.Conv2d(F_int, 1, kernel_size=1,stride=1,padding=0,bias=True), 123 | nn.BatchNorm2d(1), 124 | nn.Sigmoid() 125 | ) 126 | 127 | self.relu = nn.ReLU(inplace=True) 128 | 129 | def forward(self,g,x): 130 | g1 = self.W_g(g) 131 | x1 = self.W_x(x) 132 | psi = self.relu(g1+x1) 133 | psi = self.psi(psi) 134 | 135 | return x*psi 136 | 137 | 138 | class U_Net(nn.Module): 139 | def __init__(self,img_ch=3,output_ch=1): 140 | super(U_Net,self).__init__() 141 | 142 | self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2) 143 | 144 | self.Conv1 = conv_block(ch_in=img_ch,ch_out=64) 145 | self.Conv2 = conv_block(ch_in=64,ch_out=128) 146 | self.Conv3 = conv_block(ch_in=128,ch_out=256) 147 | self.Conv4 = conv_block(ch_in=256,ch_out=512) 148 | self.Conv5 = conv_block(ch_in=512,ch_out=1024) 149 | 150 | self.Up5 = up_conv(ch_in=1024,ch_out=512) 151 | self.Up_conv5 = conv_block(ch_in=1024, ch_out=512) 152 | 153 | self.Up4 = up_conv(ch_in=512,ch_out=256) 154 | self.Up_conv4 = conv_block(ch_in=512, ch_out=256) 155 | 156 | self.Up3 = up_conv(ch_in=256,ch_out=128) 157 | self.Up_conv3 = conv_block(ch_in=256, ch_out=128) 158 | 159 | self.Up2 = up_conv(ch_in=128,ch_out=64) 160 | self.Up_conv2 = conv_block(ch_in=128, ch_out=64) 161 | 162 | self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0) 163 | 164 | 165 | def forward(self,x): 166 | # encoding path 167 | x1 = self.Conv1(x) 168 | 169 | x2 = self.Maxpool(x1) 170 | x2 = self.Conv2(x2) 171 | 172 | x3 = self.Maxpool(x2) 173 | x3 = self.Conv3(x3) 174 | 175 | x4 = self.Maxpool(x3) 176 | x4 = self.Conv4(x4) 177 | 178 | x5 = self.Maxpool(x4) 179 | x5 = self.Conv5(x5) 180 | 181 | # decoding + concat path 182 | d5 = self.Up5(x5) 183 | d5 = torch.cat((x4,d5),dim=1) 184 | 185 | d5 = self.Up_conv5(d5) 186 | 187 | d4 = self.Up4(d5) 188 | d4 = torch.cat((x3,d4),dim=1) 189 | d4 = self.Up_conv4(d4) 190 | 191 | d3 = self.Up3(d4) 192 | d3 = torch.cat((x2,d3),dim=1) 193 | d3 = self.Up_conv3(d3) 194 | 195 | d2 = self.Up2(d3) 196 | d2 = torch.cat((x1,d2),dim=1) 197 | d2 = self.Up_conv2(d2) 198 | 199 | d1 = self.Conv_1x1(d2) 200 | 201 | return d1 202 | 203 | 204 | class R2U_Net(nn.Module): 205 | def __init__(self,img_ch=3,output_ch=1,t=2): 206 | super(R2U_Net,self).__init__() 207 | 208 | self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2) 209 | self.Upsample = 
nn.Upsample(scale_factor=2) 210 | 211 | self.RRCNN1 = RRCNN_block(ch_in=img_ch,ch_out=64,t=t) 212 | 213 | self.RRCNN2 = RRCNN_block(ch_in=64,ch_out=128,t=t) 214 | 215 | self.RRCNN3 = RRCNN_block(ch_in=128,ch_out=256,t=t) 216 | 217 | self.RRCNN4 = RRCNN_block(ch_in=256,ch_out=512,t=t) 218 | 219 | self.RRCNN5 = RRCNN_block(ch_in=512,ch_out=1024,t=t) 220 | 221 | 222 | self.Up5 = up_conv(ch_in=1024,ch_out=512) 223 | self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512,t=t) 224 | 225 | self.Up4 = up_conv(ch_in=512,ch_out=256) 226 | self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256,t=t) 227 | 228 | self.Up3 = up_conv(ch_in=256,ch_out=128) 229 | self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128,t=t) 230 | 231 | self.Up2 = up_conv(ch_in=128,ch_out=64) 232 | self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64,t=t) 233 | 234 | self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0) 235 | 236 | 237 | def forward(self,x): 238 | # encoding path 239 | x1 = self.RRCNN1(x) 240 | 241 | x2 = self.Maxpool(x1) 242 | x2 = self.RRCNN2(x2) 243 | 244 | x3 = self.Maxpool(x2) 245 | x3 = self.RRCNN3(x3) 246 | 247 | x4 = self.Maxpool(x3) 248 | x4 = self.RRCNN4(x4) 249 | 250 | x5 = self.Maxpool(x4) 251 | x5 = self.RRCNN5(x5) 252 | 253 | # decoding + concat path 254 | d5 = self.Up5(x5) 255 | d5 = torch.cat((x4,d5),dim=1) 256 | d5 = self.Up_RRCNN5(d5) 257 | 258 | d4 = self.Up4(d5) 259 | d4 = torch.cat((x3,d4),dim=1) 260 | d4 = self.Up_RRCNN4(d4) 261 | 262 | d3 = self.Up3(d4) 263 | d3 = torch.cat((x2,d3),dim=1) 264 | d3 = self.Up_RRCNN3(d3) 265 | 266 | d2 = self.Up2(d3) 267 | d2 = torch.cat((x1,d2),dim=1) 268 | d2 = self.Up_RRCNN2(d2) 269 | 270 | d1 = self.Conv_1x1(d2) 271 | 272 | return d1 273 | 274 | 275 | 276 | class AttU_Net(nn.Module): 277 | def __init__(self,img_ch=3,output_ch=1): 278 | super(AttU_Net,self).__init__() 279 | 280 | self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2) 281 | 282 | self.Conv1 = conv_block(ch_in=img_ch,ch_out=64) 283 | self.Conv2 = conv_block(ch_in=64,ch_out=128) 284 | self.Conv3 = conv_block(ch_in=128,ch_out=256) 285 | self.Conv4 = conv_block(ch_in=256,ch_out=512) 286 | self.Conv5 = conv_block(ch_in=512,ch_out=1024) 287 | 288 | self.Up5 = up_conv(ch_in=1024,ch_out=512) 289 | self.Att5 = Attention_block(F_g=512,F_l=512,F_int=256) 290 | self.Up_conv5 = conv_block(ch_in=1024, ch_out=512) 291 | 292 | self.Up4 = up_conv(ch_in=512,ch_out=256) 293 | self.Att4 = Attention_block(F_g=256,F_l=256,F_int=128) 294 | self.Up_conv4 = conv_block(ch_in=512, ch_out=256) 295 | 296 | self.Up3 = up_conv(ch_in=256,ch_out=128) 297 | self.Att3 = Attention_block(F_g=128,F_l=128,F_int=64) 298 | self.Up_conv3 = conv_block(ch_in=256, ch_out=128) 299 | 300 | self.Up2 = up_conv(ch_in=128,ch_out=64) 301 | self.Att2 = Attention_block(F_g=64,F_l=64,F_int=32) 302 | self.Up_conv2 = conv_block(ch_in=128, ch_out=64) 303 | 304 | self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0) 305 | 306 | 307 | def forward(self,x): 308 | # encoding path 309 | x1 = self.Conv1(x) 310 | 311 | x2 = self.Maxpool(x1) 312 | x2 = self.Conv2(x2) 313 | 314 | x3 = self.Maxpool(x2) 315 | x3 = self.Conv3(x3) 316 | 317 | x4 = self.Maxpool(x3) 318 | x4 = self.Conv4(x4) 319 | 320 | x5 = self.Maxpool(x4) 321 | x5 = self.Conv5(x5) 322 | 323 | # decoding + concat path 324 | d5 = self.Up5(x5) 325 | x4 = self.Att5(g=d5,x=x4) 326 | d5 = torch.cat((x4,d5),dim=1) 327 | d5 = self.Up_conv5(d5) 328 | 329 | d4 = self.Up4(d5) 330 | x3 = self.Att4(g=d4,x=x3) 331 | d4 = torch.cat((x3,d4),dim=1) 332 | d4 = self.Up_conv4(d4) 333 | 334 
| d3 = self.Up3(d4) 335 | x2 = self.Att3(g=d3,x=x2) 336 | d3 = torch.cat((x2,d3),dim=1) 337 | d3 = self.Up_conv3(d3) 338 | 339 | d2 = self.Up2(d3) 340 | x1 = self.Att2(g=d2,x=x1) 341 | d2 = torch.cat((x1,d2),dim=1) 342 | d2 = self.Up_conv2(d2) 343 | 344 | d1 = self.Conv_1x1(d2) 345 | 346 | return d1 347 | 348 | 349 | class R2AttU_Net(nn.Module): 350 | def __init__(self,img_ch=3,output_ch=1,t=2): 351 | super(R2AttU_Net,self).__init__() 352 | 353 | self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2) 354 | self.Upsample = nn.Upsample(scale_factor=2) 355 | 356 | self.RRCNN1 = RRCNN_block(ch_in=img_ch,ch_out=64,t=t) 357 | 358 | self.RRCNN2 = RRCNN_block(ch_in=64,ch_out=128,t=t) 359 | 360 | self.RRCNN3 = RRCNN_block(ch_in=128,ch_out=256,t=t) 361 | 362 | self.RRCNN4 = RRCNN_block(ch_in=256,ch_out=512,t=t) 363 | 364 | self.RRCNN5 = RRCNN_block(ch_in=512,ch_out=1024,t=t) 365 | 366 | 367 | self.Up5 = up_conv(ch_in=1024,ch_out=512) 368 | self.Att5 = Attention_block(F_g=512,F_l=512,F_int=256) 369 | self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512,t=t) 370 | 371 | self.Up4 = up_conv(ch_in=512,ch_out=256) 372 | self.Att4 = Attention_block(F_g=256,F_l=256,F_int=128) 373 | self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256,t=t) 374 | 375 | self.Up3 = up_conv(ch_in=256,ch_out=128) 376 | self.Att3 = Attention_block(F_g=128,F_l=128,F_int=64) 377 | self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128,t=t) 378 | 379 | self.Up2 = up_conv(ch_in=128,ch_out=64) 380 | self.Att2 = Attention_block(F_g=64,F_l=64,F_int=32) 381 | self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64,t=t) 382 | 383 | self.Conv_1x1 = nn.Conv2d(64,output_ch,kernel_size=1,stride=1,padding=0) 384 | 385 | 386 | def forward(self,x): 387 | # encoding path 388 | x1 = self.RRCNN1(x) 389 | 390 | x2 = self.Maxpool(x1) 391 | x2 = self.RRCNN2(x2) 392 | 393 | x3 = self.Maxpool(x2) 394 | x3 = self.RRCNN3(x3) 395 | 396 | x4 = self.Maxpool(x3) 397 | x4 = self.RRCNN4(x4) 398 | 399 | x5 = self.Maxpool(x4) 400 | x5 = self.RRCNN5(x5) 401 | 402 | # decoding + concat path 403 | d5 = self.Up5(x5) 404 | x4 = self.Att5(g=d5,x=x4) 405 | d5 = torch.cat((x4,d5),dim=1) 406 | d5 = self.Up_RRCNN5(d5) 407 | 408 | d4 = self.Up4(d5) 409 | x3 = self.Att4(g=d4,x=x3) 410 | d4 = torch.cat((x3,d4),dim=1) 411 | d4 = self.Up_RRCNN4(d4) 412 | 413 | d3 = self.Up3(d4) 414 | x2 = self.Att3(g=d3,x=x2) 415 | d3 = torch.cat((x2,d3),dim=1) 416 | d3 = self.Up_RRCNN3(d3) 417 | 418 | d2 = self.Up2(d3) 419 | x1 = self.Att2(g=d2,x=x1) 420 | d2 = torch.cat((x1,d2),dim=1) 421 | d2 = self.Up_RRCNN2(d2) 422 | 423 | d1 = self.Conv_1x1(d2) 424 | 425 | return d1 426 | --------------------------------------------------------------------------------
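As a quick sanity check (a hypothetical snippet, not part of the repository), every network above maps an N×3×H×W batch to an N×1×H×W logit map, provided H and W are divisible by 16 (the encoder applies four 2×2 max-pools); solver.py then applies a sigmoid to these logits:
```python
import torch
from network import U_Net, R2U_Net, AttU_Net, R2AttU_Net

x = torch.randn(1, 3, 224, 224)  # H and W must be divisible by 16
for Net in (U_Net, R2U_Net, AttU_Net, R2AttU_Net):
    model = Net(img_ch=3, output_ch=1)
    with torch.no_grad():
        logits = model(x)  # raw logits; the sigmoid is applied in solver.py
    assert logits.shape == (1, 1, 224, 224)
    print(Net.__name__, 'output shape', tuple(logits.shape))
```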