├── 1-s2.0-S0010482524004037-main.pdf ├── LICENSE ├── MyModel.py ├── README.md ├── Train_Seg_Cls_5.py ├── datasets.py ├── joint_transforms.py └── utils.py /1-s2.0-S0010482524004037-main.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qqhe-frank/BUS-segmentation-and-classification/b611cee686b4c2bb753f0308c2716761df56d1c7/1-s2.0-S0010482524004037-main.pdf -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 贺琪琪 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MyModel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision.models.resnet import resnet18 4 | import torch.nn.functional as F 5 | from Multi_Scale_Module import PAFEM, GPM, FoldConv_aspp, HMU, SIM 6 | from DCN import DeformConv2d 7 | 8 | 9 | class CA_Module(nn.Module): 10 | def __init__(self, in_channel): 11 | super(CA_Module, self).__init__() 12 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 13 | self.maxpool = nn.AdaptiveMaxPool2d((1, 1)) 14 | 15 | self.linear = nn.Sequential(nn.Linear(2 * in_channel, in_channel // 16), 16 | nn.ReLU(), 17 | nn.Linear(in_channel // 16, in_channel), 18 | nn.Sigmoid()) 19 | 20 | def forward(self, x): 21 | b, c, _, _ = x.size() 22 | p1 = self.avgpool(x) 23 | p2 = self.maxpool(x) 24 | p = torch.flatten(torch.cat([p1, p2], dim=1), 1) 25 | po = self.linear(p).view(b, c, 1, 1) 26 | out = nn.ReLU()(x * po) 27 | return out 28 | 29 | 30 | class classfiler_1(nn.Module): 31 | def __init__(self): 32 | super(classfiler_1, self).__init__() 33 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # output size = (1, 1) 34 | self.fc = nn.Sequential(nn.Linear(1280, 640), nn.ReLU(), nn.Linear(640, 2)) 35 | self.ca1 = CA_Module(256) 36 | self.ca2 = CA_Module(512) 37 | self.ca3 = CA_Module(512) 38 | self.up = nn.UpsamplingBilinear2d(scale_factor=2) 39 | 40 | def forward(self, input1, input2, input3): 41 | c1 = self.ca1(input1) 42 | c2 = self.up(self.ca2(input2)) 43 | c3 = self.ca3(input3) 44 | 45 | all = self.avgpool(torch.cat([c1, c2, c3], 1)) 46 | c5 = torch.flatten(all, 1) 47 | out = self.fc(c5) 48 | return out 49 | 50 | 51 | class GSA_Module(nn.Module): 52 | def __init__(self, 
in_channel):
53 |         super(GSA_Module, self).__init__()
54 | 
55 |         self.output = nn.Sequential(nn.Conv2d(in_channel, in_channel, kernel_size=3, padding=1),
56 |                                     nn.BatchNorm2d(in_channel), nn.ReLU(inplace=True))
57 | 
58 |         self.gate = nn.Sequential(nn.Conv2d(in_channel, in_channel // 8, 3, 1),
59 |                                   nn.BatchNorm2d(in_channel // 8), nn.ReLU(inplace=True),
60 |                                   nn.Conv2d(in_channel // 8, 1, kernel_size=3, padding=1))
61 | 
62 |         self.sa = SpatialAttentionModule()
63 | 
64 |     def forward(self, x):
65 |         b, c, _, _ = x.size()
66 | 
67 |         sal = self.sa(x) * x
68 | 
69 |         g1 = self.gate(x)
70 |         g2 = F.adaptive_avg_pool2d(torch.sigmoid(g1), 1)
71 |         g3 = self.output(g2.repeat(1, c, 1, 1) * sal)
72 |         return g3
73 | 
74 | 
75 | class SpatialAttentionModule(nn.Module):
76 |     def __init__(self):
77 |         super(SpatialAttentionModule, self).__init__()
78 |         self.conv2d = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=3, stride=1, padding=1)
79 |         self.sigmoid = nn.Sigmoid()
80 | 
81 |     def forward(self, x):
82 |         avgout = torch.mean(x, dim=1, keepdim=True)
83 |         maxout, _ = torch.max(x, dim=1, keepdim=True)
84 |         out = torch.cat([avgout, maxout], dim=1)
85 |         out = F.relu(self.sigmoid(self.conv2d(out)) * x)
86 |         return out
87 | 
88 | 
89 | class GSA(nn.Module):
90 |     def __init__(self, in_channel):
91 |         super(GSA, self).__init__()
92 | 
93 |         self.output = nn.Sequential(nn.Conv2d(in_channel, in_channel, kernel_size=3, padding=1),
94 |                                     nn.GroupNorm(in_channel // 4, in_channel), nn.ReLU())
95 | 
96 |         self.gate = nn.Sequential(nn.Conv2d(in_channel, in_channel // 8, kernel_size=3, padding=1),
97 |                                   nn.GroupNorm(in_channel // 16, in_channel // 8), nn.ReLU(),
98 |                                   nn.Conv2d(in_channel // 8, 2, kernel_size=3, padding=1))
99 | 
100 |     def forward(self, x):
101 |         b, c, _, _ = x.size()
102 |         g1 = self.gate(x)
103 |         g2 = F.adaptive_avg_pool2d(torch.sigmoid(g1), 1)
104 |         g3 = F.adaptive_max_pool2d(torch.sigmoid(g1), 1)
105 |         output = self.output(g2[:, 0, :, :].unsqueeze(1).repeat(1, c, 1, 1) * x + \
106 |                              g3[:, 1, :, :].unsqueeze(1).repeat(1, c, 1, 1) * x)
107 | 
108 |         return output
109 | 
110 | 
111 | class DSAModule(nn.Module):
112 |     def __init__(self, in_ch):
113 |         super(DSAModule, self).__init__()
114 | 
115 |         self.dcn1 = nn.Sequential(DeformConv2d(in_ch, in_ch, 3, padding=1, modulation=True),
116 |                                   nn.GroupNorm(in_ch // 4, in_ch), nn.ReLU())
117 | 
118 |         self.dcn2 = nn.Sequential(DeformConv2d(2 * in_ch, in_ch, 3, padding=1, modulation=True),
119 |                                   nn.GroupNorm(in_ch // 4, in_ch), nn.ReLU())
120 | 
121 |         self.conv1 = nn.Sequential(nn.Conv2d(1, 1, 3, 1, 1), nn.Sigmoid())
122 |         self.conv2 = nn.Sequential(nn.Conv2d(1, 1, 3, 1, 1), nn.Sigmoid())
123 | 
124 |     def forward(self, x):
125 |         c1 = self.dcn1(x)
126 |         # channel-averaged spatial attention
127 |         avgout = self.conv1(torch.mean(c1, 1).unsqueeze(1))
128 |         o1 = avgout * c1
129 |         # channel-max spatial attention
130 |         maxout = self.conv2(torch.max(c1, 1)[0].unsqueeze(1))
131 |         o2 = maxout * c1
132 | 
133 |         out = torch.cat([o1, o2], dim=1)
134 |         out = self.dcn2(out)
135 |         return out
136 | 
137 | 
138 | 
139 | class MyModel(nn.Module):
140 |     def __init__(self):
141 |         super(MyModel, self).__init__()
142 |         self.base_model = resnet18(pretrained=True)
143 |         self.base_layers = list(self.base_model.children())
144 |         self.layer0 = nn.Sequential(*self.base_layers[:3])  # 64*128*128
145 |         self.layer1 = nn.Sequential(*self.base_layers[3:5])  # 64*64*64
146 |         self.layer2 = self.base_layers[5]  # 128 *32*32
147 |         self.layer3 = self.base_layers[6]  # 256 *16*16
148 |         self.layer4 = self.base_layers[7]  # 512 *8*8
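        # (annotation) For reference, resnet18's children are ordered
        # [conv1, bn1, relu, maxpool, layer1, layer2, layer3, layer4, avgpool, fc],
        # so base_layers[:3] is the stem, [3:5] is maxpool + layer1, and [5]/[6]/[7]
        # are layer2/3/4. The channels*H*W comments above assume 256x256 inputs,
        # as used in Train_Seg_Cls_5.py.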
149 | 150 | self.conv2 = nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1), nn.GroupNorm(16, 512), 151 | nn.ReLU(inplace=True), 152 | nn.Conv2d(512, 512, 3, 1, 1), nn.GroupNorm(16, 512), 153 | nn.ReLU(inplace=True)) 154 | 155 | self.conv3 = nn.Sequential(nn.Conv2d(256, 256, 3, 1, 1), nn.GroupNorm(8, 256), 156 | nn.ReLU(inplace=True), 157 | nn.Conv2d(256, 256, 3, 1, 1), nn.GroupNorm(8, 256), 158 | nn.ReLU(inplace=True)) 159 | 160 | self.conv4 = nn.Sequential(nn.Conv2d(128, 128, 3, 1, 1), nn.GroupNorm(4, 128), 161 | nn.ReLU(inplace=True), 162 | nn.Conv2d(128, 128, 3, 1, 1), nn.GroupNorm(4, 128), 163 | nn.ReLU(inplace=True)) 164 | 165 | self.conv5 = nn.Sequential(nn.Conv2d(128, 128, 1), nn.ReLU(inplace=True), 166 | nn.Conv2d(128, 128, 3, 1, 1), nn.GroupNorm(4, 128), 167 | nn.ReLU(inplace=True)) 168 | 169 | self.up1 = nn.Sequential(nn.Conv2d(512, 256, 1), nn.ReLU(inplace=True), 170 | nn.UpsamplingBilinear2d(scale_factor=2)) 171 | 172 | self.up2 = nn.Sequential(nn.Conv2d(512, 128, 1), nn.ReLU(inplace=True), 173 | nn.UpsamplingBilinear2d(scale_factor=2)) 174 | 175 | self.up3 = nn.Sequential(nn.Conv2d(256, 64, 1), nn.ReLU(inplace=True), 176 | nn.UpsamplingBilinear2d(scale_factor=2)) 177 | 178 | self.up4 = nn.Sequential(nn.Conv2d(128, 64, 1), nn.ReLU(inplace=True), 179 | nn.UpsamplingBilinear2d(scale_factor=2)) 180 | 181 | self.out = nn.Sequential(nn.UpsamplingBilinear2d(scale_factor=2), 182 | nn.Conv2d(128, 128, 3, 1, 1), nn.GroupNorm(4, 128), 183 | nn.ReLU(inplace=True), 184 | nn.Conv2d(128, 1, 1)) 185 | 186 | self.classfer = classfiler_1() 187 | 188 | self.g1 = GSA(512) 189 | self.g2 = GSA(256) 190 | self.g3 = GSA(128) 191 | self.g4 = GSA(64) 192 | self.g5 = GSA(64) 193 | 194 | self.sa1 = DSAModule(512) 195 | self.sa2 = DSAModule(256) 196 | 197 | 198 | def forward(self, x): 199 | layer0 = self.layer0(x) # 64 200 | layer1 = self.layer1(layer0) # 64 201 | layer2 = self.layer2(layer1) # 128 202 | layer3 = self.layer3(layer2) # 256 203 | layer4 = self.layer4(layer3) # 512 204 | 205 | up1 = self.up1(self.g1(layer4)) 206 | ffm1 = torch.cat([self.g2(layer3), up1], dim=1) 207 | ffm1 = self.sa1(self.conv2(ffm1)) 208 | 209 | classfiler = self.classfer(layer3, layer4, ffm1) 210 | 211 | up2 = self.up2(ffm1) 212 | ffm2 = torch.cat([self.g3(layer2), up2], dim=1) 213 | ffm2 = self.sa2(self.conv3(ffm2)) 214 | 215 | up3 = self.up3(ffm2) 216 | ffm3 = torch.cat([self.g4(layer1), up3], dim=1) 217 | ffm3 = self.conv4(ffm3) 218 | 219 | up4 = self.up4(ffm3) 220 | ffm4 = torch.cat([self.g5(layer0), up4], dim=1) 221 | out = self.conv5(ffm4) 222 | out = self.out(out) 223 | 224 | return classfiler, out 225 | 226 | 227 | 228 | if __name__ == '__main__': 229 | from torch.autograd import Variable 230 | x = Variable(torch.rand(2, 3, 256, 256)).cuda() 231 | model = MyModel().cuda() 232 | c, s = model(x) 233 | print('Output s shape:', s.shape) 234 | print('Output c shape:', c.shape) 235 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BUS-segmentation-and-classification 2 | Multi-task learning for segmentation and classification of breast tumors from ultrasound images 3 | 4 | # Abstract 5 | Segmentation and classification of breast tumors are critical components of breast ultrasound (BUS) computer-aided diagnosis (CAD), which significantly improves the diagnostic accuracy of breast cancer. 
However, the characteristics of tumor regions in BUS images, such as non-uniform intensity distributions, ambiguous or missing boundaries, and varying tumor shapes and sizes, pose significant challenges to automated segmentation and classification solutions. Many previous studies have proposed multi-task learning methods to jointly tackle tumor segmentation and classification by sharing the features extracted by the encoder. Unfortunately, this often introduces redundant or misleading information, which hinders effective feature exploitation and adversely affects performance. To address this issue, we present ACSNet, a novel multi-task learning network designed to optimize tumor segmentation and classification in BUS images. The segmentation network incorporates a novel gate unit to allow optimal transfer of valuable contextual information from the encoder to the decoder. In addition, we develop the Deformable Spatial Attention Module (DSAModule) to improve segmentation accuracy by overcoming the limitations of conventional convolution in dealing with morphological variations of tumors. In the classification branch, multi-scale feature extraction and channel attention mechanisms are integrated to discriminate between benign and malignant breast tumors. Experiments on two publicly available BUS datasets demonstrate that ACSNet not only outperforms mainstream multi-task learning methods for both breast tumor segmentation and classification tasks, but also achieves state-of-the-art results for BUS tumor segmentation.
6 | # Framework
7 | ![network](https://github.com/user-attachments/assets/8e422aad-c244-488c-b099-645c3a60aab7)
8 | 
9 | Thanks to the following open-source projects:
10 | 1. https://github.com/msracver/Deformable-ConvNets (DCN)
11 | 2. https://github.com/xorangecheng/GlobalGuidance-Net
12 | 3. https://github.com/mroussak/BUS_Deep_Learning
13 | 4. https://github.com/SimonVandenhende/Multi-Task-Learning-PyTorch
14 | 
15 | # Note
16 | DCN: the Deformable Convolution code (`DCN.py`, imported by `MyModel.py`) comes from https://github.com/msracver/Deformable-ConvNets. Note that the `Multi_Scale_Module` and `DCN` modules imported by `MyModel.py` are not included in this repository.
17 | 
18 | If you have any questions, please contact:
19 | 20210217h@gmail.com
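# Usage
A minimal sketch of running the model, mirroring the `__main__` block of `MyModel.py`. It assumes a CUDA device and that the `Multi_Scale_Module` and `DCN` modules imported by `MyModel.py` (see the Note above) are on the Python path:

```python
import torch
from MyModel import MyModel

model = MyModel().cuda()
x = torch.rand(2, 3, 256, 256).cuda()   # batch of 256x256 BUS images
class_logits, seg_logits = model(x)     # (2, 2) benign/malignant logits, (2, 1, 256, 256) mask logits
```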
--------------------------------------------------------------------------------
/Train_Seg_Cls_5.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import random
4 | import numpy as np
5 | import torch
6 | from sklearn.model_selection import KFold
7 | from torch import optim
8 | from torch.autograd import Variable
9 | from joint_transforms import Compose, Resize, RandomHorizontallyFlip, RandomRotate, RandomVerticalFlip
10 | from torchvision import transforms
11 | from datasets import ImageFolder
12 | from torch.utils.data import DataLoader, SubsetRandomSampler
13 | from MyModel import MyModel
14 | # NOTE: the original script imported an 'all_transfroms' module and several comparison
15 | # models (model_mt, Model_Our, SModel, Model_Last) that are not included in this repo;
16 | # joint_transforms.py and MyModel.py provide the pieces actually used below.
17 | import torch.nn as nn
18 | from utils import DiceLoss, dice_coef, AutomaticWeightedLoss, \
19 |     AutoWeightedLoss, AutoLoss, metric_seg, cmp_3, metric_seg_1
20 | import torch.backends.cudnn as cudnn
21 | from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
22 | import matplotlib.pyplot as plt
23 | from scipy.ndimage import gaussian_filter1d
24 | 
25 | 
26 | train_path = "./Dataset/"
27 | 
28 | joint_transforms = Compose([
29 |     Resize((256, 256)),
30 |     RandomHorizontallyFlip(0.6),
31 |     RandomRotate(30),
32 |     RandomVerticalFlip(0.6)
33 | ])
34 | 
35 | val_transform = transforms.Compose([
36 |     transforms.Resize((256, 256)),
37 |     transforms.ToTensor(),
38 |     transforms.Normalize([0.330, 0.330, 0.330], [0.204, 0.204, 0.204])
39 | ])
40 | 
41 | # alternate per-channel statistics for the second BUS dataset:
42 | # transforms.Normalize([0.248, 0.248, 0.248], [0.151, 0.151, 0.151])
43 | 
44 | transform = transforms.Compose([
45 |     transforms.ToTensor(),
46 |     transforms.Normalize([0.330, 0.330, 0.330], [0.204, 0.204, 0.204])
47 | ])
48 | 
49 | target_transform = transforms.ToTensor()
50 | val_target_transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
51 | 
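# (annotation) The Normalize statistics above are dataset-specific channel means/stds;
# the images are grayscale replicated to RGB, hence the identical per-channel values.
# The commented [0.248]/[0.151] pair appears to belong to the second public BUS dataset
# mentioned in the paper; recompute these statistics when training on a different dataset.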
52 | 
53 | def main():
54 | 
55 |     train_set = ImageFolder(train_path, joint_transforms, transform, target_transform)
56 |     test_set = ImageFolder(train_path, None, val_transform, val_target_transform)
57 | 
58 |     cv = KFold(n_splits=5, random_state=42, shuffle=True)
59 |     fold = 1
60 | 
61 |     num_epochs = 90
62 | 
63 |     tr_loss = []
64 |     val_loss = []
65 |     test_hd95 = []
66 |     test_asd = []
67 |     test_ji = []
68 |     test_dice = []
69 |     test_acc = []
70 |     test_pre = []
71 |     test_recall = []
72 |     test_f1 = []
73 | 
74 |     for train_idx, test_idx in cv.split(train_set):
75 | 
76 |         print("\nCross validation fold %d" % fold)
77 | 
78 |         train_sampler = SubsetRandomSampler(train_idx)
79 |         test_sampler = SubsetRandomSampler(test_idx)
80 | 
81 |         train_loader = DataLoader(train_set, batch_size=8,
82 |                                   num_workers=1, shuffle=False, sampler=train_sampler)
83 | 
84 |         val_loader = DataLoader(test_set, batch_size=2,
85 |                                 num_workers=1, shuffle=False, sampler=test_sampler)
86 | 
87 |         test_loader = DataLoader(test_set, batch_size=1, num_workers=1,
88 |                                  shuffle=False, sampler=test_sampler)
89 | 
90 |         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
91 |         net = MyModel()  # assumption: MyModel (MyModel.py) is the network to train; the original 'base2' is not defined in this repo
92 |         net = net.to(device)
93 | 
94 |         train_loss, avg_val_loss, pre_ji, pre_dice, \
95 |             hd95, asd, pre_class_acc, pre, recall, f1 = train(net, train_loader, val_loader, test_loader, fold, num_epochs)
96 | 
97 |         tr_loss.append(train_loss)
98 |         val_loss.append(avg_val_loss)
99 |         test_ji.append(pre_ji)
100 |         test_dice.append(pre_dice)
101 |         test_hd95.append(hd95)
102 |         test_asd.append(asd)
103 |         test_acc.append(pre_class_acc)
104 |         test_pre.append(pre)
105 |         test_recall.append(recall)
106 |         test_f1.append(f1)
107 | 
108 |         fold += 1
109 |         torch.cuda.empty_cache()
110 | 
111 |     print('\n', '#' * 10, 'Final 5-fold cross-validation results', '#' * 10)
112 |     print('Average Train Loss:{:.4f}'.format(np.mean(tr_loss)))
113 |     print('Average Val Loss:{:.4f}'.format(np.mean(val_loss)))
114 |     print('\n', '#' * 10, 'Segmentation Results', '#' * 10)
115 |     print('Average Test Jaccard:{:.2%}±{:.4} '.format(np.mean(test_ji), np.std(test_ji)))
116 |     print('Average Test Dice:{:.2%}±{:.4}'.format(np.mean(test_dice), np.std(test_dice)))
117 |     print('Average Test HD95:{:.2f}±{:.4}'.format(np.mean(test_hd95), np.std(test_hd95)))
118 |     print('Average Test ASD:{:.2f}±{:.4}'.format(np.mean(test_asd), np.std(test_asd)))
119 |     print('\n', '#' * 10, 'Classification Results', '#' * 10)
120 |     print('Average Test Accuracy:{:.2%}±{:.4}'.format(np.mean(test_acc), np.std(test_acc)))
121 |     print('Average Test Precision:{:.2%}±{:.4}'.format(np.mean(test_pre), np.std(test_pre)))
122 |     print('Average Test Recall:{:.2%}±{:.4}'.format(np.mean(test_recall), np.std(test_recall)))
123 |     print('Average Test F1 Score:{:.2%}±{:.4}'.format(np.mean(test_f1), np.std(test_f1)))
124 | 
125 | 
126 | class AvgMeter(object):
127 |     def __init__(self):
128 |         self.reset()
129 | 
130 |     def reset(self):
131 |         self.val = 0
132 |         self.avg = 0
133 |         self.sum = 0
134 |         self.count = 0
135 | 
136 |     def update(self, val, n=1):
137 |         self.val = val
138 |         self.sum += val * n
139 |         self.count += n
140 |         self.avg = self.sum / self.count
141 | 
142 | 
143 | def train(model, train_loader, val_loader, test_loader, fold, num_epochs):
144 | 
145 |     global train_loss, class_acc, avg_val_loss, seg_dice
146 |     avg_train_loss = AvgMeter()
147 |     avg_train_dice = AvgMeter()
148 |     avg_train_ji = AvgMeter()
149 | 
150 |     bce_logit = DiceLoss().cuda()
151 |     bc_class = nn.CrossEntropyLoss().cuda()
152 |     awl = AutomaticWeightedLoss().cuda()
153 | 
154 |     train_loss_pic = []
155 |     val_loss_pic = []
156 |     val_dice_pic = []
157 |     val_ji_pic = []
158 |     train_ji_pic = []
159 |     train_dice_pic = []
160 | 
161 |     params = [p for p in model.parameters() if p.requires_grad]
162 |     params.append(awl.params)
163 | 
164 |     optimizer = optim.Adam(params, lr=0.0001, weight_decay=1e-4)
165 |     lf = lambda x: ((1 + math.cos(x * math.pi / num_epochs)) / 2) * (1 - 0.1) + 0.1
166 |     scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
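    # (annotation) lf implements a cosine decay of the lr factor from 1.0 at epoch 0
    # to 0.1 at num_epochs: lf(x) = ((1 + cos(x*pi/E)) / 2) * (1 - 0.1) + 0.1.
    # AutomaticWeightedLoss (utils.py) appends its two uncertainty parameters to the
    # optimizer, so the task weighting 0.5/s_i^2 * L_i + log(1 + s_i^2) is learned
    # jointly with the network weights.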
167 | 
168 |     all_result = 0.0
169 |     os.makedirs('./Result/pth', exist_ok=True)   # make sure the output folders exist
170 |     os.makedirs('./Result/pic', exist_ok=True)
171 |     save_path = './Result/pth/Train_{}_folder_model.pth'.format(fold)
172 | 
173 |     for epoch in range(num_epochs):
174 |         model.train()
175 | 
176 |         for step, data in enumerate(train_loader):
177 |             optimizer.zero_grad()
178 | 
179 |             images, masks, labels = data
180 | 
181 |             images = Variable(images).cuda()
182 |             masks = Variable(masks).cuda()
183 |             labels = Variable(labels).cuda()
184 | 
185 |             class_logits, seg_logits = model(images)
186 | 
187 |             seg_loss = bce_logit(seg_logits, masks)
188 |             class_loss = bc_class(class_logits, labels)
189 |             loss = awl(seg_loss, class_loss)
190 |             # fixed-weight alternative: loss = 0.7*seg_loss + 0.3*class_loss
191 |             dc, jc = metric_seg_1(seg_logits, masks)
192 | 
193 |             loss.backward()
194 |             optimizer.step()
195 | 
196 |             avg_train_loss.update(loss.item(), images.size(0))
197 |             avg_train_dice.update(dc, images.size(0))
198 |             avg_train_ji.update(jc, images.size(0))
199 | 
200 |         train_loss = avg_train_loss.avg
201 |         train_dc = avg_train_dice.avg
202 |         train_jc = avg_train_ji.avg
203 | 
204 |         train_loss_pic.append(train_loss)
205 |         train_dice_pic.append(train_dc)
206 |         train_ji_pic.append(train_jc)
207 | 
208 |         avg_val_loss, seg_dice, ji, hd95, \
209 |             asd, class_acc, pre, recall, f1score = validate(model, val_loader)
210 | 
211 |         scheduler.step()
212 |         val_loss_pic.append(avg_val_loss)
213 |         val_dice_pic.append(seg_dice)
214 |         val_ji_pic.append(ji)
215 | 
216 |         all_re = 0.7 * seg_dice + 0.3 * class_acc
217 |         if all_re > all_result:
218 |             all_result = all_re
219 |             torch.save(model.state_dict(), save_path)
220 | 
221 |         print('Epoch:{}| TrainLoss:{:.4f} ValidLoss:{:.4f}| '
222 |               'Acc:{:.2%} Pre:{:.2%} Recall:{:.2%} F1:{:.2%}|'
223 |               'Dice:{:.2%} JA:{:.2%} HD95:{:.2f} ASD:{:.2f}.'
224 |               .format(epoch + 1, train_loss, avg_val_loss, class_acc,
225 |                       pre, recall, f1score, seg_dice, ji, hd95, asd))
226 | 
227 |     pre_ji, pre_dice, hd95, asd, pre_class_acc, pre, recall, f1 = test(save_path, model, test_loader, fold)
228 | 
229 |     # =======================================================================
230 |     epochsn = np.arange(1, len(train_loss_pic) + 1, 1)
231 |     plt.figure(figsize=(18, 5))
232 |     # figsize: width and height of the figure, in inches
233 |     plt.subplot(131)
234 |     y_sm = gaussian_filter1d(val_loss_pic, sigma=1)
235 |     # one figure holds several subplots, drawn with subplot()
236 |     plt.plot(epochsn, train_loss_pic, 'b', label='Training Loss')
237 |     plt.plot(epochsn, y_sm, 'r', label='Validation Loss')
238 | 
239 |     plt.grid(color='gray', linestyle='--')
240 |     plt.legend()  # adds the legend to the plot
241 |     plt.title('Loss, Epochs={}, Batch={}'.format(num_epochs, 8))
242 |     plt.xlabel('Epochs')
243 |     plt.ylabel('Loss')
244 | 
245 |     plt.subplot(132)
246 |     plt.plot(epochsn, train_dice_pic, 'g', label='Train Dice')
247 |     y2_sm = gaussian_filter1d(val_dice_pic, sigma=1)
248 |     plt.plot(epochsn, y2_sm, 'cyan', label='Validation Dice')
249 |     plt.grid(color='gray', linestyle='--')
250 |     plt.legend()
251 |     plt.title('Dice coefficient score')
252 |     plt.xlabel('Epochs')
253 |     plt.ylabel('DSC')
254 | 
255 |     plt.subplot(133)
256 |     plt.plot(epochsn, train_ji_pic, 'chocolate', label='Train Jaccard')
257 |     y3_sm = gaussian_filter1d(val_ji_pic, sigma=1)
258 |     plt.plot(epochsn, y3_sm, 'm', label='Validation Jaccard')
259 |     plt.grid(color='gray', linestyle='--')
260 |     plt.legend()
261 |     plt.title('Jaccard coefficient score')
262 |     plt.xlabel('Epochs')
263 |     plt.ylabel('JSC')
264 |     plt.savefig('./Result/pic/savefig_{}.png'.format(fold))
265 |     plt.show()
266 | 
267 |     return train_loss, avg_val_loss, pre_ji, pre_dice, hd95, asd, pre_class_acc, pre, recall, f1
268 | 
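# (annotation) Inside train(), the checkpoint that test() later reloads is selected by a
# combined validation score, all_re = 0.7 * Dice + 0.3 * Accuracy, i.e. segmentation quality
# is weighted more heavily than classification accuracy when picking the best epoch.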
269 | 
270 | def validate(model, val_loader):
271 |     losses = AvgMeter()
272 |     avg_val_dice = AvgMeter()
273 |     avg_val_jc = AvgMeter()
274 |     avg_val_hd = AvgMeter()
275 |     avg_val_asd = AvgMeter()
276 | 
277 |     bce_logit = DiceLoss().cuda()
278 |     class_logit = nn.CrossEntropyLoss().cuda()
279 |     awl = AutomaticWeightedLoss().cuda()
280 | 
281 |     val_preds = []
282 |     val_trues = []
283 | 
284 |     model.eval()
285 | 
286 |     with torch.no_grad():
287 |         torch.cuda.empty_cache()
288 |         for i, (input, target, label) in enumerate(val_loader):
289 |             input = input.cuda()
290 |             target = target.cuda()
291 |             label = label.cuda()
292 | 
293 |             pre_class, output = model(input)
294 | 
295 |             seg_loss = bce_logit(output, target)
296 |             class_loss = class_logit(pre_class, label)
297 |             loss = awl(seg_loss, class_loss)
298 |             dc, jc, hdc, asdc = metric_seg(output, target)
299 | 
300 |             avg_val_dice.update(dc, input.size(0))
301 |             avg_val_jc.update(jc, input.size(0))
302 |             avg_val_hd.update(hdc, input.size(0))
303 |             avg_val_asd.update(asdc, input.size(0))
304 | 
305 |             pre_class = torch.sigmoid(pre_class)
306 |             predict_class = torch.max(pre_class, dim=1)[1]
307 |             val_preds.extend(predict_class.detach().cpu().numpy())
308 |             val_trues.extend(label.detach().cpu().numpy())
309 | 
310 |             losses.update(loss.item(), input.size(0))
311 |     val_loss = losses.avg
312 | 
313 |     sklearn_accuracy = accuracy_score(val_trues, val_preds)
314 |     # note: precision and F1 use 'weighted' averaging while recall uses 'macro'
315 |     sklearn_precision = precision_score(val_trues, val_preds, average='weighted')
316 |     sklearn_recall = recall_score(val_trues, val_preds, average='macro')
317 |     sklearn_f1 = f1_score(val_trues, val_preds, average='weighted')
318 | 
319 |     return val_loss, np.around(avg_val_dice.avg, 3), np.around(avg_val_jc.avg, 3), \
320 |         np.around(avg_val_hd.avg, 3), np.around(avg_val_asd.avg, 3), \
321 |         sklearn_accuracy, sklearn_precision, sklearn_recall, sklearn_f1
322 | 
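# (annotation) validate() builds a fresh AutomaticWeightedLoss on every call, so the reported
# validation loss uses newly initialised task weights rather than the ones learned during
# training; the Dice/Jaccard/HD95/ASD and classification metrics are unaffected by this.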
323 | 
324 | def test(save_path, model, test_loader, fold):
325 | 
326 |     JI = []
327 |     Dices = []
328 |     test_preds = []
329 |     test_trues = []
330 |     HD95_1 = []
331 |     ASD_1 = []
332 |     weights_path = save_path
333 |     assert os.path.exists(weights_path), f"file: '{weights_path}' does not exist."
334 |     model.load_state_dict(torch.load(weights_path))
335 |     to_pil = transforms.ToPILImage()
336 |     os.makedirs('./Result/TestResult/{}'.format(fold), exist_ok=True)
337 | 
338 |     model.eval()
339 |     with torch.no_grad():
340 |         for i, (input, target, label) in enumerate(test_loader):
341 | 
342 |             image = Variable(input).cuda()
343 |             target = Variable(target).cuda()
344 |             label = Variable(label).cuda()
345 | 
346 |             pro_class, pro_seg = model(image)
347 |             pro_class = torch.sigmoid(pro_class)
348 |             predict_class = torch.max(pro_class, dim=1)[1]
349 |             test_preds.extend(predict_class.detach().cpu().numpy())
350 |             test_trues.extend(label.detach().cpu().numpy())
351 | 
352 |             a = target.squeeze(0)
353 |             b = image.squeeze(0)
354 |             a = to_pil(a)
355 |             b = to_pil(b)
356 |             a.save('./Result/TestResult/{}/mask{}.png'.format(fold, i))
357 |             b.save('./Result/TestResult/{}/img{}.png'.format(fold, i))
358 | 
359 |             pro = torch.sigmoid(pro_seg).data.squeeze(0).cpu()
360 |             c = to_pil(pro)
361 |             c.save('./Result/TestResult/{}/pre{}.png'.format(fold, i))
362 |             target = target.squeeze(0).cpu()
363 |             pro = np.array(pro)
364 |             target = np.array(target)
365 |             pro[pro >= 0.5] = 1
366 |             pro[pro < 0.5] = 0
367 |             TP = float(np.sum(np.logical_and(pro == 1, target == 1)))
368 |             TN = float(np.sum(np.logical_and(pro == 0, target == 0)))
369 |             FP = float(np.sum(np.logical_and(pro == 1, target == 0)))
370 |             FN = float(np.sum(np.logical_and(pro == 0, target == 1)))
371 |             JA = TP / ((TP + FN + FP) + 1e-5)
372 |             DI = 2 * TP / ((2 * TP + FN + FP + 1e-5))
373 | 
374 |             Dices.append(DI)
375 |             JI.append(JA)
376 |             hd95, asd = cmp_3(pro, target)
377 |             HD95_1.append(hd95)
378 |             ASD_1.append(asd)
379 | 
380 |     sklearn_accuracy = accuracy_score(test_trues, test_preds)
381 |     sklearn_precision = precision_score(test_trues, test_preds, average='weighted')
382 |     sklearn_recall = recall_score(test_trues, test_preds, average='macro')
383 |     sklearn_f1 = f1_score(test_trues, test_preds, average='weighted')
384 | 
385 |     print('Test Result:\n''Segmentation:\n Jaccard:{:.2%} '
386 |           'Dice:{:.2%} HD95:{:.2f} ASD:{:.2f}'.format(np.around(np.mean(JI), 3),
387 |                                                       np.around(np.mean(Dices), 3),
388 |                                                       np.around(np.mean(HD95_1), 3),
389 |                                                       np.around(np.mean(ASD_1), 3)))
390 |     print('Classification:\n Accuracy:{:.2%} '
391 |           'Precision:{:.2%} Recall:{:.2%} F1:{:.2%}'.format(sklearn_accuracy,
392 |                                                             sklearn_precision,
393 |                                                             sklearn_recall,
394 |                                                             sklearn_f1))
395 | 
396 |     return np.around(np.mean(JI), 3), np.around(np.mean(Dices), 3), np.around(np.mean(HD95_1), 3), \
397 |         np.around(np.mean(ASD_1), 3), sklearn_accuracy, sklearn_precision, sklearn_recall, sklearn_f1
398 | 
399 | 
400 | if __name__ == "__main__":
401 |     seed = 2
402 |     random.seed(seed)
403 |     np.random.seed(seed)
404 |     torch.manual_seed(seed)
405 |     torch.cuda.manual_seed(seed)
406 |     torch.cuda.manual_seed_all(seed)
407 | 
408 |     torch.backends.cudnn.deterministic = True
409 |     torch.backends.cudnn.benchmark = False  # benchmark=True would defeat the deterministic setting above
410 |     main()
411 | 
--------------------------------------------------------------------------------
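(Annotation) make_dataset() in datasets.py below expects train_path ("./Dataset/") to be laid
out as follows, with every image X.png paired with a mask named X_mask.png:

Dataset/
├── imgs/
│   ├── benign/
│   └── malignant/
└── masks/
    ├── benign/
    └── malignant/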
/datasets.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import torch.utils.data as data
4 | from PIL import Image
5 | 
6 | 
7 | def make_dataset(root):
8 | 
9 |     img_num_class = [cla for cla in os.listdir(os.path.join(root, 'imgs'))
10 |                      if os.path.isdir(os.path.join(root, 'imgs', cla))]
11 | 
12 |     img_num_class.sort()
13 | 
14 |     class_indices = dict((k, v) for v, k in enumerate(img_num_class))
15 |     json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=1)
16 |     with open('class_indices.json', 'w') as json_file:
17 |         json_file.write(json_str)
18 | 
19 |     image_class1 = class_indices['benign']
20 | 
21 |     a = [(os.path.join(root, 'imgs', 'benign', img_name),
22 |           os.path.join(root, 'masks', 'benign', img_name.split('.')[0] + '_mask.png'),
23 |           image_class1) for img_name in os.listdir(os.path.join(root, 'imgs', 'benign'))]
24 | 
25 |     image_class2 = class_indices['malignant']
26 | 
27 |     b = [(os.path.join(root, 'imgs', 'malignant', img_name),
28 |           os.path.join(root, 'masks', 'malignant', img_name.split('.')[0] + '_mask.png'),
29 |           image_class2) for img_name in os.listdir(os.path.join(root, 'imgs', 'malignant'))]
30 |     c = a + b
31 |     return c
32 | 
33 | 
34 | class ImageFolder(data.Dataset):
35 |     def __init__(self, root, joint_transform=None, transform=None, target_transform=None):
36 |         self.root = root
37 |         self.imgs = make_dataset(root)
38 |         self.joint_transform = joint_transform
39 |         self.transform = transform
40 |         self.target_transform = target_transform
41 | 
42 |     def __getitem__(self, index):
43 |         img_path, gt_path, label = self.imgs[index]
44 |         img = Image.open(img_path).convert('RGB')
45 |         target = Image.open(gt_path).convert('L')
46 |         cla = label
47 |         if self.joint_transform is not None:
48 |             # joint transforms act on the image/mask pair; the class label is unaffected
49 |             img, target = self.joint_transform(img, target)
50 | 
51 |         if self.transform is not None:
52 |             img = self.transform(img)
53 |         if self.target_transform is not None:
54 |             target = self.target_transform(target)
55 |         return img, target, cla
56 | 
57 |     def __len__(self):
58 |         return len(self.imgs)
59 | 
60 | 
61 | class ImageFolder2(data.Dataset):
62 |     def __init__(self, root, joint_transform=None, transform=None, target_transform=None):
63 |         self.root = root
64 |         self.imgs = make_dataset(root)
65 |         self.joint_transform = joint_transform
66 |         self.transform = transform
67 |         self.target_transform = target_transform
68 | 
69 |     def __getitem__(self, index):
70 |         img_path, gt_path, _ = self.imgs[index]  # make_dataset returns (image, mask, label) triples
71 |         img = Image.open(img_path).convert('RGB')
72 |         target = Image.open(gt_path).convert('L')
73 |         if self.joint_transform is not None:
74 |             img, target = self.joint_transform(img, target)
75 | 
76 |         if self.transform is not None:
77 |             img = self.transform(img)
78 |         if self.target_transform is not None:
79 |             target = self.target_transform(target)
80 | 
81 |         return img, target
82 | 
83 |     def __len__(self):
84 |         return len(self.imgs)
--------------------------------------------------------------------------------
/joint_transforms.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numbers
3 | import random
4 | 
5 | import torch
6 | from PIL import Image, ImageOps
7 | import numpy as np
8 | import torchvision.transforms.functional as F
9 | 
10 | 
11 | class Compose(object):
12 |     def __init__(self, transforms):
13 |         self.transforms = transforms
14 | 
15 |     def __call__(self, img, mask):
16 |         assert img.size == mask.size
17 |         for t in self.transforms:
18 |             img, mask = t(img, mask)
19 |         return img, mask
20 | 
21 | 
22 | class RandomCrop(object):
23 |     def __init__(self, size, padding=0):
24 |         if isinstance(size, numbers.Number):
25 |             self.size = (int(size), int(size))
26 |         else:
27 |             self.size = size
28 |         self.padding = padding
29 | 
30 |     def __call__(self, img, mask):
31 |         if self.padding > 0:
32 |             img = ImageOps.expand(img, border=self.padding, fill=0)
33 |             mask = ImageOps.expand(mask, border=self.padding, fill=0)
34 | 
35 |         assert img.size == mask.size
36 |         w, h = img.size
37 |         th, tw = self.size
38 |         if w == tw
and h == th: 39 | return img, mask 40 | if w < tw or h < th: 41 | return img.resize((tw, th), Image.BILINEAR), mask.resize((tw, th), Image.NEAREST) 42 | 43 | x1 = random.randint(0, w - tw) 44 | y1 = random.randint(0, h - th) 45 | return img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th)) 46 | 47 | 48 | class CenterCrop(object): 49 | def __init__(self, size): 50 | if isinstance(size, numbers.Number): 51 | self.size = (int(size), int(size)) 52 | else: 53 | self.size = size 54 | 55 | def __call__(self, img, mask): 56 | assert img.size == mask.size 57 | w, h = img.size 58 | th, tw = self.size 59 | x1 = int(round((w - tw) / 2.)) 60 | y1 = int(round((h - th) / 2.)) 61 | return img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th)) 62 | 63 | class RandomHorizontallyFlip(object): 64 | def __init__(self, p=0.5): 65 | self.p = p 66 | 67 | def __call__(self, img, mask): 68 | 69 | if torch.rand(1) < self.p: 70 | return F.hflip(img), F.hflip(mask) 71 | return img, mask 72 | 73 | 74 | 75 | class FreeScale(object): 76 | def __init__(self, size): 77 | self.size = tuple(reversed(size)) # size: (h, w) 78 | 79 | def __call__(self, img, mask): 80 | assert img.size == mask.size 81 | return img.resize(self.size, Image.BILINEAR), mask.resize(self.size, Image.NEAREST) 82 | 83 | 84 | class Scale(object): 85 | def __init__(self, size): 86 | self.size = size 87 | 88 | def __call__(self, img, mask): 89 | assert img.size == mask.size 90 | w, h = img.size 91 | if (w >= h and w == self.size) or (h >= w and h == self.size): 92 | return img, mask 93 | if w > h: 94 | ow = self.size 95 | oh = int(self.size * h / w) 96 | return img.resize((ow, oh), Image.BILINEAR), mask.resize((ow, oh), Image.NEAREST) 97 | else: 98 | oh = self.size 99 | ow = int(self.size * w / h) 100 | return img.resize((ow, oh), Image.BILINEAR), mask.resize((ow, oh), Image.NEAREST) 101 | 102 | 103 | class RandomSizedCrop(object): 104 | def __init__(self, size): 105 | self.size = size 106 | 107 | def __call__(self, img, mask): 108 | assert img.size == mask.size 109 | for attempt in range(10): 110 | area = img.size[0] * img.size[1] 111 | target_area = random.uniform(0.45, 1.0) * area 112 | aspect_ratio = random.uniform(0.5, 2) 113 | 114 | w = int(round(math.sqrt(target_area * aspect_ratio))) 115 | h = int(round(math.sqrt(target_area / aspect_ratio))) 116 | 117 | if random.random() < 0.5: 118 | w, h = h, w 119 | 120 | if w <= img.size[0] and h <= img.size[1]: 121 | x1 = random.randint(0, img.size[0] - w) 122 | y1 = random.randint(0, img.size[1] - h) 123 | 124 | img = img.crop((x1, y1, x1 + w, y1 + h)) 125 | mask = mask.crop((x1, y1, x1 + w, y1 + h)) 126 | assert (img.size == (w, h)) 127 | 128 | return img.resize((self.size, self.size), Image.BILINEAR), mask.resize((self.size, self.size), 129 | Image.NEAREST) 130 | 131 | # Fallback 132 | scale = Scale(self.size) 133 | crop = CenterCrop(self.size) 134 | return crop(*scale(img, mask)) 135 | 136 | 137 | class RandomRotate(object): 138 | def __init__(self, degree): 139 | self.degree = degree 140 | 141 | def __call__(self, img, mask): 142 | rotate_degree = random.random() * 2 * self.degree - self.degree 143 | return img.rotate(rotate_degree, Image.BILINEAR), mask.rotate(rotate_degree, Image.NEAREST) 144 | 145 | 146 | class RandomSized(object): 147 | def __init__(self, size): 148 | self.size = size 149 | self.scale = Scale(self.size) 150 | self.crop = RandomCrop(self.size) 151 | 152 | def __call__(self, img, mask): 153 | assert img.size == mask.size 154 | 155 | w = 
int(random.uniform(0.5, 2) * img.size[0]) 156 | h = int(random.uniform(0.5, 2) * img.size[1]) 157 | 158 | img, mask = img.resize((w, h), Image.BILINEAR), mask.resize((w, h), Image.NEAREST) 159 | 160 | return self.crop(*self.scale(img, mask)) 161 | 162 | 163 | class SlidingCropOld(object): 164 | def __init__(self, crop_size, stride_rate, ignore_label): 165 | self.crop_size = crop_size 166 | self.stride_rate = stride_rate 167 | self.ignore_label = ignore_label 168 | 169 | def _pad(self, img, mask): 170 | h, w = img.shape[: 2] 171 | pad_h = max(self.crop_size - h, 0) 172 | pad_w = max(self.crop_size - w, 0) 173 | img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), 'constant') 174 | mask = np.pad(mask, ((0, pad_h), (0, pad_w)), 'constant', constant_values=self.ignore_label) 175 | return img, mask 176 | 177 | def __call__(self, img, mask): 178 | assert img.size == mask.size 179 | 180 | w, h = img.size 181 | long_size = max(h, w) 182 | 183 | img = np.array(img) 184 | mask = np.array(mask) 185 | 186 | if long_size > self.crop_size: 187 | stride = int(math.ceil(self.crop_size * self.stride_rate)) 188 | h_step_num = int(math.ceil((h - self.crop_size) / float(stride))) + 1 189 | w_step_num = int(math.ceil((w - self.crop_size) / float(stride))) + 1 190 | img_sublist, mask_sublist = [], [] 191 | for yy in range(h_step_num): 192 | for xx in range(w_step_num): 193 | sy, sx = yy * stride, xx * stride 194 | ey, ex = sy + self.crop_size, sx + self.crop_size 195 | img_sub = img[sy: ey, sx: ex, :] 196 | mask_sub = mask[sy: ey, sx: ex] 197 | img_sub, mask_sub = self._pad(img_sub, mask_sub) 198 | img_sublist.append(Image.fromarray(img_sub.astype(np.uint8)).convert('RGB')) 199 | mask_sublist.append(Image.fromarray(mask_sub.astype(np.uint8)).convert('P')) 200 | return img_sublist, mask_sublist 201 | else: 202 | img, mask = self._pad(img, mask) 203 | img = Image.fromarray(img.astype(np.uint8)).convert('RGB') 204 | mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 205 | return img, mask 206 | 207 | 208 | class SlidingCrop(object): 209 | def __init__(self, crop_size, stride_rate, ignore_label): 210 | self.crop_size = crop_size 211 | self.stride_rate = stride_rate 212 | self.ignore_label = ignore_label 213 | 214 | def _pad(self, img, mask): 215 | h, w = img.shape[: 2] 216 | pad_h = max(self.crop_size - h, 0) 217 | pad_w = max(self.crop_size - w, 0) 218 | img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), 'constant') 219 | mask = np.pad(mask, ((0, pad_h), (0, pad_w)), 'constant', constant_values=self.ignore_label) 220 | return img, mask, h, w 221 | 222 | def __call__(self, img, mask): 223 | assert img.size == mask.size 224 | 225 | w, h = img.size 226 | long_size = max(h, w) 227 | 228 | img = np.array(img) 229 | mask = np.array(mask) 230 | 231 | if long_size > self.crop_size: 232 | stride = int(math.ceil(self.crop_size * self.stride_rate)) 233 | h_step_num = int(math.ceil((h - self.crop_size) / float(stride))) + 1 234 | w_step_num = int(math.ceil((w - self.crop_size) / float(stride))) + 1 235 | img_slices, mask_slices, slices_info = [], [], [] 236 | for yy in range(h_step_num): 237 | for xx in range(w_step_num): 238 | sy, sx = yy * stride, xx * stride 239 | ey, ex = sy + self.crop_size, sx + self.crop_size 240 | img_sub = img[sy: ey, sx: ex, :] 241 | mask_sub = mask[sy: ey, sx: ex] 242 | img_sub, mask_sub, sub_h, sub_w = self._pad(img_sub, mask_sub) 243 | img_slices.append(Image.fromarray(img_sub.astype(np.uint8)).convert('RGB')) 244 | 
mask_slices.append(Image.fromarray(mask_sub.astype(np.uint8)).convert('P')) 245 | slices_info.append([sy, ey, sx, ex, sub_h, sub_w]) 246 | return img_slices, mask_slices, slices_info 247 | else: 248 | img, mask, sub_h, sub_w = self._pad(img, mask) 249 | img = Image.fromarray(img.astype(np.uint8)).convert('RGB') 250 | mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 251 | return [img], [mask], [[0, sub_h, 0, sub_w, sub_h, sub_w]] 252 | 253 | class Resize(object): 254 | 255 | def __init__(self, size): 256 | self.size = size 257 | 258 | def __call__(self, img, mask): 259 | 260 | img = F.resize(img, self.size, F.InterpolationMode.BILINEAR) 261 | mask = F.resize(mask, self.size, F.InterpolationMode.NEAREST) 262 | 263 | return img, mask 264 | 265 | class RandomVerticalFlip(object): 266 | def __init__(self, p=0.5): 267 | self.p = p 268 | 269 | 270 | def __call__(self, img, mask): 271 | if torch.rand(1) < self.p: 272 | return F.vflip(img), F.vflip(mask) 273 | return img, mask 274 | 275 | 276 | 277 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | import torch 4 | import pydensecrf.densecrf as dcrf 5 | from torch.optim import lr_scheduler 6 | from medpy import metric 7 | 8 | class DiceLoss(nn.Module): 9 | def __init__(self): 10 | super(DiceLoss, self).__init__() 11 | self.sigmoid = nn.Sigmoid() 12 | 13 | def forward(self, input, target): 14 | N = target.size(0) 15 | smooth = 1e-5 16 | input = self.sigmoid(input) 17 | 18 | input_flat = input.view(N, -1) 19 | target_flat = target.view(N, -1) 20 | 21 | intersection = input_flat * target_flat 22 | 23 | loss = 2 * (intersection.sum(1) + smooth) / (input_flat.sum(1) + target_flat.sum(1) + smooth) 24 | loss = 1 - loss.sum() / N 25 | 26 | return loss 27 | 28 | 29 | def dice_coef(output, target): 30 | smooth = 1e-5 31 | 32 | output = torch.sigmoid(output).view(-1).data.cpu().numpy() 33 | target = target.view(-1).data.cpu().numpy() 34 | intersection = (output * target).sum() 35 | 36 | return (2. 
* intersection + smooth) / \ 37 | (output.sum() + target.sum() + smooth) 38 | 39 | class AutomaticWeightedLoss(nn.Module): 40 | """automatically weighted multitask loss 41 | Params: 42 | num: int,the number of loss 43 | x: multitask loss 44 | Examples: 45 | loss1=1 46 | loss2=2 47 | awl = AutomaticWeightedLoss(2) 48 | loss_sum = awl(loss1, loss2) 49 | """ 50 | def __init__(self, num=2): 51 | super(AutomaticWeightedLoss, self).__init__() 52 | params = torch.ones(num, requires_grad=True) 53 | self.params = torch.nn.Parameter(params) 54 | 55 | def forward(self, *x): 56 | loss_sum = 0 57 | for i, loss in enumerate(x): 58 | loss_sum += 0.5 / (self.params[i] ** 2) * loss + torch.log(1 + self.params[i] ** 2) 59 | return loss_sum 60 | 61 | 62 | class AutoWeightedLoss(nn.Module): 63 | 64 | def __init__(self, num=2): 65 | super(AutoWeightedLoss, self).__init__() 66 | params = torch.ones(num, requires_grad=True) 67 | self.params = torch.nn.Parameter(params) 68 | 69 | def forward(self, *x): 70 | loss_sum = 0 71 | for i, loss in enumerate(x): 72 | loss_sum += (self.params[i]) * loss 73 | return loss_sum 74 | 75 | 76 | 77 | class AutoLoss(nn.Module): 78 | 79 | def __init__(self, ): 80 | super(AutoLoss, self).__init__() 81 | params = torch.tensor(0.3, requires_grad=True) 82 | self.params = torch.nn.Parameter(params) 83 | 84 | def forward(self, x1, x2): 85 | loss_sum = self.params * x2 + (1-self.params) * x1 86 | return loss_sum 87 | 88 | 89 | def metric_seg(pred,gt): 90 | pred = torch.sigmoid(pred) 91 | pre = np.array((pred>0.5).to(torch.int).detach().cpu()) 92 | target = np.array(gt.to(torch.int).detach().cpu()) 93 | 94 | di = metric.binary.dc(pre, target) 95 | ji = metric.binary.jc(pre, target) 96 | 97 | if pre.sum() > 0 and target.sum() > 0: 98 | hd95 = metric.binary.hd95(pre, target) 99 | asd = metric.binary.asd(pre, target) 100 | return di, ji, hd95, asd 101 | else: 102 | return di, ji, 0, 0 103 | 104 | 105 | def metric_seg_1(pred,gt): 106 | pre = torch.sigmoid(pred) 107 | pre = np.array((pre>0.5).to(torch.int).detach().cpu()) 108 | target = np.array(gt.to(torch.int).detach().cpu()) 109 | 110 | di = metric.binary.dc(pre, target) 111 | ji = metric.binary.jc(pre, target) 112 | 113 | return di, ji 114 | 115 | 116 | 117 | 118 | 119 | 120 | def cmp_3(pred, gt): 121 | 122 | if pred.sum() > 0 and gt.sum() > 0: 123 | hd95 = metric.binary.hd95(pred, gt) 124 | asd = metric.binary.asd(pred, gt) 125 | return hd95, asd 126 | else: 127 | return 0, 0 128 | 129 | 130 | def crf_refine(img, annos): 131 | assert img.dtype == np.uint8 132 | assert annos.dtype == np.uint8 133 | assert img.shape[:2] == annos.shape 134 | 135 | def _sigmoid(x): 136 | return 1 / (1 + np.exp(-x)) 137 | 138 | # img and annos should be np array with data type uint8 139 | 140 | EPSILON = 1e-8 141 | 142 | M = 2 # salient or not 143 | tau = 1.05 144 | # Setup the CRF model 145 | d = dcrf.DenseCRF2D(img.shape[1], img.shape[0], M) 146 | 147 | anno_norm = annos / 255. 
148 | 
149 |     n_energy = -np.log((1.0 - anno_norm + EPSILON)) / (tau * _sigmoid(1 - anno_norm))
150 |     p_energy = -np.log(anno_norm + EPSILON) / (tau * _sigmoid(anno_norm))
151 | 
152 |     U = np.zeros((M, img.shape[0] * img.shape[1]), dtype='float32')
153 |     U[0, :] = n_energy.flatten()
154 |     U[1, :] = p_energy.flatten()
155 | 
156 |     d.setUnaryEnergy(U)
157 | 
158 |     d.addPairwiseGaussian(sxy=3, compat=3)
159 |     d.addPairwiseBilateral(sxy=60, srgb=5, rgbim=img, compat=5)
160 | 
161 |     # Do the inference
162 |     infer = np.array(d.inference(1)).astype('float32')
163 |     res = infer[1, :]
164 | 
165 |     res = res * 255
166 |     res = res.reshape(img.shape[:2])
167 |     return res.astype('uint8')
168 | 
169 | 
170 | 
171 | def create_lr_scheduler(optimizer,
172 |                         num_step: int,
173 |                         epochs: int,
174 |                         warmup=True,
175 |                         warmup_epochs=1,
176 |                         warmup_factor=0.01):
177 |     assert num_step > 0 and epochs > 0
178 |     if warmup is False:
179 |         warmup_epochs = 0
180 | 
181 |     def f(x):
182 |         """
183 |         Return a learning-rate multiplier for the given step.
184 |         Note that PyTorch calls lr_scheduler.step() once before training starts.
185 |         """
186 |         if warmup is True and x <= (warmup_epochs * num_step):
187 |             alpha = float(x) / (warmup_epochs * num_step)
188 |             # during warmup the lr factor grows from warmup_factor -> 1
189 |             return warmup_factor * (1 - alpha) + alpha
190 |         else:
191 |             # after warmup the lr factor decays from 1 -> 0,
192 |             # following the DeepLab v2 "poly" learning rate policy
193 |             return (1 - (x - warmup_epochs * num_step) / ((epochs - warmup_epochs) * num_step)) ** 0.9
194 | 
195 |     return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f)
196 | 
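# (annotation) minimal usage sketch for create_lr_scheduler; the optimizer and loader
# names are illustrative:
#
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
#     scheduler = create_lr_scheduler(optimizer, num_step=len(train_loader),
#                                     epochs=90, warmup=True, warmup_epochs=1)
#     for images, masks, labels in train_loader:
#         ...  # forward/backward, then optimizer.step(); scheduler.step() once per batch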
197 | 
198 | def get_scheduler(optimizer, opt):
199 |     """Return a learning rate scheduler
200 |     Parameters:
201 |         optimizer -- the optimizer of the network
202 |         opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
203 |                               opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
204 |     For 'linear', we keep the same learning rate for the first opt.n_epochs epochs
205 |     and then linearly decay the rate to zero.
206 |     For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
207 |     See https://pytorch.org/docs/stable/optim.html for more details.
208 |     """
209 |     if opt.lr_policy == 'linear':
210 |         def lambda_rule(epoch):
211 |             # constant lr for the first opt.n_epochs epochs, then linear decay to zero
212 |             lr_l = 1.0 - max(0, epoch + 1 - opt.n_epochs) / float(opt.n_epochs + 1)
213 |             return lr_l
214 |         scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
215 |     elif opt.lr_policy == 'step':
216 |         scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
217 |     elif opt.lr_policy == 'plateau':
218 |         scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
219 |     elif opt.lr_policy == 'cosine':
220 |         scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
221 |     else:
222 |         raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
223 |     return scheduler
224 | 
225 | 
226 | import torch.nn.functional as F
227 | import torchvision
228 | 
229 | 
230 | def structure_loss(pred, mask):
231 |     weit = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
232 |     wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
233 |     wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))
234 | 
235 |     pred = torch.sigmoid(pred)
236 |     inter = ((pred * mask) * weit).sum(dim=(2, 3))
237 |     union = ((pred + mask) * weit).sum(dim=(2, 3))
238 |     wiou = 1 - (inter + 1) / (union - inter + 1)
239 |     return (wbce + wiou).mean()
240 | 
241 | 
242 | # usage: loss2u = LossNet()(torch.sigmoid(output), mask)
243 | 
244 | class LossNet(torch.nn.Module):
245 |     def __init__(self, resize=True):
246 |         super(LossNet, self).__init__()
247 |         blocks = []
248 |         blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
249 |         blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
250 |         blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
251 |         blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())
252 |         for bl in blocks:
253 |             for p in bl.parameters():
254 |                 p.requires_grad = False
255 |         self.blocks = torch.nn.ModuleList(blocks)
256 |         self.transform = torch.nn.functional.interpolate
257 |         self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
258 |         self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
259 |         self.resize = resize
260 | 
261 |     def forward(self, input, target):
262 |         if input.shape[1] != 3:
263 |             input = input.repeat(1, 3, 1, 1)
264 |             target = target.repeat(1, 3, 1, 1)
265 |         input = (input - self.mean) / self.std
266 |         target = (target - self.mean) / self.std
267 |         if self.resize:
268 |             input = self.transform(input, mode='bilinear', size=(256, 256), align_corners=False)
269 |             target = self.transform(target, mode='bilinear', size=(256, 256), align_corners=False)
270 |         loss = 0.0
271 |         x = input
272 |         y = target
273 | 
274 |         for block in self.blocks:
275 |             x = block(x)
276 |             y = block(y)
277 |             loss += torch.nn.functional.mse_loss(x, y)
278 |         return loss
--------------------------------------------------------------------------------
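A quick, self-contained smoke test for the loss utilities in utils.py (an annotation, not part of the original repo; it assumes utils.py's dependencies such as medpy and pydensecrf are installed, since they are imported at module load):

import torch
from utils import DiceLoss, AutomaticWeightedLoss, structure_loss

pred = torch.randn(2, 1, 256, 256)                      # dummy segmentation logits
mask = (torch.rand(2, 1, 256, 256) > 0.5).float()       # dummy binary masks

seg_loss = DiceLoss()(pred, mask)
cls_loss = torch.tensor(0.7)                            # stand-in classification loss
total = AutomaticWeightedLoss(2)(seg_loss, cls_loss)    # uncertainty-weighted sum
print(total.item(), structure_loss(pred, mask).item())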