├── models
│   ├── __init__.py
│   └── logonet.py
├── seg_losses
│   ├── __init__.py
│   └── losses.py
├── graph
│   ├── model4.png
│   └── module3.png
├── samples
│   ├── test_label.txt
│   ├── train_label.txt
│   ├── img
│   │   ├── 1.2.250.1.204.5.8373722513.201612141038584163.75.dcm.jpg
│   │   └── 1.2.250.1.204.5.8373722513.20170907131030778276.75.dcm.jpg
│   └── msk
│       ├── 1.2.250.1.204.5.8373722513.201612141038584163.75.dcm.jpg
│       └── 1.2.250.1.204.5.8373722513.20170907131030778276.75.dcm.jpg
├── matrics.py
├── README.md
├── train.py
└── utils.py

/models/__init__.py:
--------------------------------------------------------------------------------
from .logonet import *
--------------------------------------------------------------------------------
/seg_losses/__init__.py:
--------------------------------------------------------------------------------
from .losses import *
--------------------------------------------------------------------------------
/graph/model4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phanzsx/LoGo-Net/HEAD/graph/model4.png
--------------------------------------------------------------------------------
/graph/module3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phanzsx/LoGo-Net/HEAD/graph/module3.png
--------------------------------------------------------------------------------
/samples/test_label.txt:
--------------------------------------------------------------------------------
1.2.250.1.204.5.8373722513.201612141038584163.75.dcm.jpg 1
1.2.250.1.204.5.8373722513.20170907131030778276.75.dcm.jpg 0
--------------------------------------------------------------------------------
/samples/train_label.txt:
--------------------------------------------------------------------------------
1.2.250.1.204.5.8373722513.201612141038584163.75.dcm.jpg 1
1.2.250.1.204.5.8373722513.20170907131030778276.75.dcm.jpg 0
--------------------------------------------------------------------------------
/samples/img/1.2.250.1.204.5.8373722513.201612141038584163.75.dcm.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phanzsx/LoGo-Net/HEAD/samples/img/1.2.250.1.204.5.8373722513.201612141038584163.75.dcm.jpg
--------------------------------------------------------------------------------
/samples/msk/1.2.250.1.204.5.8373722513.201612141038584163.75.dcm.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phanzsx/LoGo-Net/HEAD/samples/msk/1.2.250.1.204.5.8373722513.201612141038584163.75.dcm.jpg
--------------------------------------------------------------------------------
/samples/img/1.2.250.1.204.5.8373722513.20170907131030778276.75.dcm.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phanzsx/LoGo-Net/HEAD/samples/img/1.2.250.1.204.5.8373722513.20170907131030778276.75.dcm.jpg
--------------------------------------------------------------------------------
/samples/msk/1.2.250.1.204.5.8373722513.20170907131030778276.75.dcm.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phanzsx/LoGo-Net/HEAD/samples/msk/1.2.250.1.204.5.8373722513.20170907131030778276.75.dcm.jpg
--------------------------------------------------------------------------------
/matrics.py:
--------------------------------------------------------------------------------
import numpy as np
import torch


def iou_score(output, target):
    """IoU between a binarized prediction and a binary target mask."""
    smooth = 1e-5
    if isinstance(output, list):  # deep-supervision style output: use the last map
        output = output[-1]
    output = torch.sigmoid(output)
    if torch.is_tensor(output):
        output = output.view(-1).data.cpu().numpy()
    if torch.is_tensor(target):
        target = target.view(-1).data.cpu().numpy()
    output_ = output > 0.5
    target_ = target > 0.5
    intersection = (output_ & target_).sum()
    union = (output_ | target_).sum()

    return (intersection + smooth) / (union + smooth)


def dice_coef(output, target):
    """Dice coefficient between a binarized prediction and a binary target mask."""
    smooth = 1e-5
    if isinstance(output, list):
        output = output[-1]
    output = torch.sigmoid(output)
    output = output.view(-1).data.cpu().numpy()
    target = target.view(-1).data.cpu().numpy()
    output[output >= 0.5] = 1
    output[output < 0.5] = 0
    intersection = (output * target).sum()

    return (2. * intersection + smooth) / \
        (output.sum() + target.sum() + smooth)
--------------------------------------------------------------------------------
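
A quick sanity check for the two metrics above (a minimal sketch; the logits and mask tensors are synthetic, chosen so the expected values are easy to verify by hand):

```python
import torch
from matrics import iou_score, dice_coef

# Fake pre-sigmoid logits and a binary ground-truth mask, shape 1x1x4x4.
logits = torch.full((1, 1, 4, 4), -10.0)
logits[..., :2, :2] = 10.0          # prediction covers the top-left 2x2 patch
mask = torch.zeros((1, 1, 4, 4))
mask[..., :2, :3] = 1.0             # ground truth covers a 2x3 patch

print(iou_score(logits, mask))      # 4 shared pixels / 6 in union -> ~0.667
print(dice_coef(logits, mask))      # 2*4 / (4 + 6) = 0.8
```
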
/README.md:
--------------------------------------------------------------------------------
# LoGo-Net
This is the code for the paper "A Local and Global Feature Disentangled Network: Toward Classification of Benign-malignant Thyroid Nodules from Ultrasound Image"; it is also available for download on our homepage https://www.neuro.uestc.edu.cn/vccl/home.html.

## Introduction
In this study, inspired by the domain knowledge sonographers apply when diagnosing ultrasound images, we propose a local and global feature disentangled network (LoGo-Net) to classify benign and malignant thyroid nodules. The model imitates the dual-pathway structure of human vision and introduces a new feature extraction method to improve nodule recognition. A tissue-anatomy disentangled (TAD) block connects the two pathways, decoupling local and global feature cues through a self-attention mechanism.

![Model architecture](graph/model4.png)

![TAD block](graph/module3.png)

## Pre-requirements
The codebase was tested with the following setup.
* Python>=3.7
* PyTorch>=1.6.0
* torchvision>=0.7

## Train
* For easier use of LoGo-Net, this project provides a simple example framework. Three model scales are available: logonet18, logonet34, and logonet50 (see the sketch below this section).
```
python train.py
```
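
The models can also be built directly. A minimal usage sketch (shapes follow the sample pipeline: 224x224 inputs, 56x56 masks; `pretrained=True` is only wired up for logonet18/34):

```python
import torch
from models import logonet18, logonet34, logonet50

model = logonet18(pretrained=False)   # or logonet34() / logonet50()
images = torch.randn(2, 3, 224, 224)  # grayscale images replicated to 3 channels
out_class, out_seg = model(images)
print(out_class.shape)  # torch.Size([2, 2])        - benign/malignant logits
print(out_seg.shape)    # torch.Size([2, 1, 56, 56]) - segmentation logits
```
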
## Citation
If you use this code in your research, please cite the paper:
```BibTex
@ARTICLE{logonet,
  author={Zhao, Shi-Xuan and Chen, Yang and Yang, Kai-Fu and Luo, Yan and Ma, Bu-Yun and Li, Yong-Jie},
  journal={IEEE Transactions on Medical Imaging},
  title={A Local and Global Feature Disentangled Network: Toward Classification of Benign-malignant Thyroid Nodules from Ultrasound Image},
  year={2022},
  volume={41},
  number={6},
  pages={1497--1509},
  doi={10.1109/TMI.2022.3140797}}
```
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from models import *
from utils import Trainer


def default_loader(path, is_img):
    # Images are loaded as single-channel grayscale; masks as binary ('1').
    # Note: Image.ANTIALIAS was renamed to Image.LANCZOS in newer Pillow releases.
    if is_img:
        img = Image.open(path).convert('L')
        img = img.resize((224, 224), Image.ANTIALIAS)
    else:
        img = Image.open(path).convert('1')
        img = img.resize((56, 56), Image.ANTIALIAS)
    return img


class MyDataset(Dataset):
    def __init__(self, mode, txt, transform=None, loader=default_loader):
        # Each line of the label file is "<filename> <label>".
        with open(txt, 'r') as fh:
            imgs = []
            for line in fh:
                words = line.strip().split()
                imgs.append((words[0], int(words[1])))
        self.imgs = imgs
        self.mode = mode
        self.transform = transform
        self.loader = loader

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = self.loader(data_root + 'img/' + fn, True)
        tar = self.loader(data_root + 'msk/' + fn, False)
        # Random horizontal flip, applied jointly to image and mask, during training.
        if random.randint(0, 1) and self.mode == 'train':
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            tar = tar.transpose(Image.FLIP_LEFT_RIGHT)
        img = self.transform(img)
        tar = torch.from_numpy(np.array(tar, np.float32, copy=False))
        tar = torch.unsqueeze(tar, 0)
        return img, tar, label

    def __len__(self):
        return len(self.imgs)
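
# A minimal sketch of how MyDataset is meant to be used (never called here;
# note that __getitem__ reads the module-level data_root set in __main__ below,
# so this assumes data_root points at ./samples/).
def _dataset_demo():
    demo_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.3309], std=[0.1924]),
    ])
    data = MyDataset(mode='test', txt='./samples/test_label.txt',
                     transform=demo_transform)
    img, tar, label = data[0]   # 1x224x224 image, 1x56x56 mask, int label
    print(img.shape, tar.shape, label)
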
def main(model):
    train_data = MyDataset(
        mode='train',
        txt=label_root + 'train_label.txt',
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[img_mean],
                                 std=[img_std])
        ]))
    test_data = MyDataset(
        mode='test',
        txt=label_root + 'test_label.txt',
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[img_mean],
                                 std=[img_std])
        ]))

    train_loader = DataLoader(
        train_data, batch_size=batch_size, shuffle=True, drop_last=True)
    test_loader = DataLoader(
        test_data, batch_size=batch_size, shuffle=False)

    # model.cuda()
    # model = nn.DataParallel(model.cuda(), device_ids=[0])
    optimizer = optim.SGD(params=model.parameters(),
                          lr=base_lr, momentum=0.9, weight_decay=1e-5)
    # optimizer = optim.Adam(params=model.parameters(), lr=0.001)
    # scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 50, 0)
    scheduler = optim.lr_scheduler.StepLR(optimizer, 10, gamma=0.1)
    trainer = Trainer(model, optimizer, save_dir=save_root)
    trainer.loop(max_epoch, train_loader, test_loader, scheduler, save_freq=1)


if __name__ == '__main__':
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    torch.backends.cudnn.benchmark = True
    # TODO: replace these constants with argparse options.
    batch_size = 1
    data_root = './samples/'
    label_root = './samples/'
    save_root = './temp/'
    img_mean = 0.3309   # dataset-wide grayscale mean/std used for normalization
    img_std = 0.1924
    model = logonet18()
    base_lr = 0.001
    max_epoch = 100

    if not os.path.exists(save_root):
        os.makedirs(save_root)
    main(model)
--------------------------------------------------------------------------------
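
The `img_mean`/`img_std` constants above are hard-coded dataset statistics. If you retrain on your own data, something like the following recomputes them (a sketch under the same grayscale-loading assumptions; `grayscale_stats` is an illustrative helper, not part of this repository):

```python
import numpy as np
from PIL import Image

def grayscale_stats(paths):
    """Mean/std over all pixels of the given grayscale images, scaled to [0, 1]."""
    pixels = [np.asarray(Image.open(p).convert('L'), np.float32).ravel() / 255.0
              for p in paths]
    pixels = np.concatenate(pixels)
    return pixels.mean(), pixels.std()

# e.g. paths = ['./samples/img/' + line.split()[0]
#               for line in open('./samples/train_label.txt')]
```
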
/utils.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

from seg_losses import *
from matrics import *


class Trainer(object):
    def __init__(self, model, optimizer, save_dir=None, save_freq=1):
        self.model = model
        self.optimizer = optimizer
        self.save_dir = save_dir
        self.save_freq = save_freq

    def _loop(self, data_loader, ep, is_train=True):
        loop_loss_class, correct, loop_loss_seg, loop_iou, loop_dice = [], [], [], [], []
        mode = 'train' if is_train else 'test'
        for data, tar, label in tqdm(data_loader):
            # data, tar, label = data.cuda(), tar.cuda(), label.cuda()
            torch.set_grad_enabled(is_train)
            # The grayscale image is replicated to 3 channels for the ResNet stem.
            out_class, out_seg = self.model(torch.cat([data, data, data], 1))
            n = out_seg.size(0)
            loss_class = F.cross_entropy(out_class, label)
            loss_seg = F.binary_cross_entropy(torch.sigmoid(out_seg.view(n, -1)), tar.view(n, -1)) + \
                IOULoss(out_seg, tar)
            # Joint objective: segmentation loss plus down-weighted classification loss.
            loss = 0.5 * loss_class + loss_seg

            loop_loss_class.append(loss_class.detach() / len(data_loader))
            loop_loss_seg.append(loss_seg.detach() / len(data_loader))
            out = (out_class.data.max(1)[1] == label.data).sum()
            correct.append(float(out) / len(data_loader.dataset))

            for j in range(n):
                loop_iou.append(iou_score(out_seg[j], tar[j]))
                loop_dice.append(dice_coef(out_seg[j], tar[j]))

            if is_train:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        print(mode + ': loss_class: {:.6f}, Acc: {:.6%}, loss_seg: {:.6f}, iou: {:.6f}, dice: {:.6f}'.format(
            sum(loop_loss_class), sum(correct), sum(loop_loss_seg), sum(loop_iou) / len(loop_iou), sum(loop_dice) / len(loop_dice)))
        return sum(loop_loss_class), sum(correct), sum(loop_loss_seg), sum(loop_iou) / len(loop_iou), sum(loop_dice) / len(loop_dice)

    def train(self, data_loader, ep):
        self.model.train()
        return self._loop(data_loader, ep)

    def test(self, data_loader, ep):
        self.model.eval()
        return self._loop(data_loader, ep, is_train=False)

    def loop(self, epochs, train_data, test_data, scheduler=None, save_freq=5):
        # Truncate the log file at the start of a run.
        f = open(self.save_dir + 'log.txt', 'w')
        # f.write('train_loss_cls train_acc train_loss_seg train_iou train_dice ' +
        #         'test_loss_cls test_acc test_loss_seg test_iou test_dice\n')
        f.close()
        for ep in range(1, epochs + 1):
            print('epoch {}'.format(ep))
            train_results = np.array(self.train(train_data, ep))
            test_results = np.array(self.test(test_data, ep))
            if scheduler is not None:
                # Step after the epoch's optimizer updates (PyTorch >= 1.1 order).
                scheduler.step()
            with open(self.save_dir + 'log.txt', 'a') as f:
                for i in np.append(train_results, test_results):
                    f.write(str(round(i, 6)) + ' ')
                f.write('\n')
            if not ep % save_freq:
                self.save(ep)

    def save(self, epoch, **kwargs):
        if self.save_dir:
            name = self.save_dir + 'train' + str(epoch) + 'models.pth'
            torch.save(self.model.state_dict(), name)
            # torch.save(self.model, name)
--------------------------------------------------------------------------------
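
The joint loss in `_loop` can be exercised in isolation (a minimal sketch with random tensors; the 0.5 weight mirrors the Trainer code above):

```python
import torch
import torch.nn.functional as F
from seg_losses import IOULoss

out_class = torch.randn(2, 2)                 # classification logits
out_seg = torch.randn(2, 1, 56, 56)           # segmentation logits
label = torch.tensor([0, 1])
tar = (torch.rand(2, 1, 56, 56) > 0.5).float()

n = out_seg.size(0)
loss_class = F.cross_entropy(out_class, label)
loss_seg = F.binary_cross_entropy(torch.sigmoid(out_seg.view(n, -1)),
                                  tar.view(n, -1)) + IOULoss(out_seg, tar)
loss = 0.5 * loss_class + loss_seg
print(float(loss))
```
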
/seg_losses/losses.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def IOULoss(input, target, weight=None):
    """Soft IoU (Jaccard) loss on sigmoid probabilities, averaged over the batch."""
    epsilon = 1e-5

    num = target.size(0)
    input = torch.sigmoid(input)
    input = input.view(num, -1)
    target = target.view(num, -1).float()

    intersect = (input * target).sum(1)
    if weight is not None:
        intersect = weight * intersect

    union = (input + target).sum(1) - intersect

    return 1. - torch.mean(intersect / union.clamp(min=epsilon))


def DICELoss(input, target, weight=None, dimension=2):
    """Soft Dice loss on sigmoid probabilities.

    dimension=2 uses the squared denominator (input^2 + target^2).sum()
    from V-Net; any other value uses the standard (input + target).sum().
    """
    epsilon = 1e-5

    input = torch.sigmoid(input)
    input = input.view(-1)
    target = target.view(-1).float()

    intersect = (input * target).sum()
    if weight is not None:
        intersect = weight * intersect

    if dimension == 2:
        denominator = (input * input).sum() + (target * target).sum()
    else:
        denominator = (input + target).sum()

    return 1. - 2 * (intersect / denominator.clamp(min=epsilon))


def KLDivLoss(input, target):
    """KL divergence between the softmax of the prediction and the
    normalized target mask, both treated as spatial distributions."""
    epsilon = 1e-5

    n = input.size(0)
    input = F.softmax(input.view(n, -1), -1)
    target = target.float().view(n, -1)
    add = target.sum(-1)[:, None]
    target = torch.div(target, add.clamp(min=epsilon))

    return F.kl_div(torch.log(input), target, reduction='batchmean')


class FocalLoss(nn.Module):
    """An implementation of Focal Loss for binary (sigmoid) outputs, proposed in
    'Focal Loss for Dense Object Detection' (https://arxiv.org/abs/1708.02002).
    Focal_Loss = -alpha * (1 - pt)^gamma * log(pt)
    :param alpha: (list/ndarray/float) per-class scalar factor for this criterion
    :param gamma: (float) gamma > 0 reduces the relative loss for well-classified
        examples (p > 0.5), putting more focus on hard, misclassified examples
    :param reduction: `none` | `mean` | `sum`
    """

    def __init__(self, alpha=[1.0, 1.0], gamma=2, ignore_index=None, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.smooth = 1e-6
        self.ignore_index = ignore_index
        self.reduction = reduction

        assert self.reduction in ['none', 'mean', 'sum']

        if self.alpha is None:
            self.alpha = np.ones(2)
        elif isinstance(self.alpha, (list, np.ndarray)):
            self.alpha = np.reshape(np.asarray(self.alpha), (2,))
            assert self.alpha.shape[0] == 2, \
                'the `alpha` shape does not match the number of classes'
        elif isinstance(self.alpha, (float, int)):
            self.alpha = np.asarray([self.alpha, 1.0 - self.alpha], dtype=np.float64)
        else:
            raise TypeError('{} not supported'.format(type(self.alpha)))

    def forward(self, output, target):
        prob = torch.sigmoid(output)
        prob = torch.clamp(prob, self.smooth, 1.0 - self.smooth)

        pos_mask = (target == 1).float()
        neg_mask = (target == 0).float()

        pos_loss = -self.alpha[0] * torch.pow(torch.sub(1.0, prob), self.gamma) * torch.log(prob) * pos_mask
        neg_loss = -self.alpha[1] * torch.pow(prob, self.gamma) * \
            torch.log(torch.sub(1.0, prob)) * neg_mask

        neg_loss = neg_loss.sum()
        pos_loss = pos_loss.sum()
        num_pos = pos_mask.view(pos_mask.size(0), -1).sum()
        num_neg = neg_mask.view(neg_mask.size(0), -1).sum()

        if num_pos == 0:
            loss = neg_loss
        else:
            loss = pos_loss / num_pos + neg_loss / num_neg
        return loss


class CrossEntropyLabelSmooth(nn.Module):
    """Cross entropy loss with label smoothing regularizer.
    Reference:
    Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
    Equation: y = (1 - epsilon) * y + epsilon / K.
    Args:
        num_classes (int): number of classes.
        epsilon (float): smoothing weight.
    """

    def __init__(self, num_classes, epsilon=0.1, use_gpu=True):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.use_gpu = use_gpu
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        """
        Args:
            inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
            targets: ground truth labels with shape (batch_size,)
        """
        log_probs = self.logsoftmax(inputs)
        targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu(), 1)
        if self.use_gpu: targets = targets.cuda()
        targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
        loss = (- targets * log_probs).mean(0).sum()
        return loss
--------------------------------------------------------------------------------
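
A quick smoke test for the losses above (a minimal sketch with random tensors; only shapes and rough magnitudes are meaningful):

```python
import torch
from seg_losses import (IOULoss, DICELoss, KLDivLoss, FocalLoss,
                        CrossEntropyLabelSmooth)

seg_logits = torch.randn(2, 1, 56, 56)
mask = (torch.rand(2, 1, 56, 56) > 0.5).float()
print(float(IOULoss(seg_logits, mask)))     # soft IoU loss, in [0, 1]
print(float(DICELoss(seg_logits, mask)))    # soft Dice loss, in [0, 1]
print(float(KLDivLoss(seg_logits, mask)))
print(float(FocalLoss()(seg_logits, mask)))

cls_logits = torch.randn(4, 2)
labels = torch.tensor([0, 1, 1, 0])
print(float(CrossEntropyLabelSmooth(num_classes=2, use_gpu=False)(cls_logits, labels)))
```
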
/models/logonet.py:
--------------------------------------------------------------------------------
import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

__all__ = ['logonet18', 'logonet34', 'logonet50']


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=True, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=True)


class de_conv(nn.Module):
    """Conv-BN-ReLU block used by the decoder."""

    def __init__(self, in_ch, out_ch):
        super(de_conv, self).__init__()
        self.conv = nn.Sequential(
            conv3x3(in_ch, out_ch),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            # conv3x3(out_ch, out_ch),
            # nn.BatchNorm2d(out_ch),
            # nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


class up(nn.Module):
    """Decoder stage: upsample x1, pad it to x2's size, concatenate, convolve."""

    def __init__(self, in_ch, out_ch, scale=2, bilinear=True):
        super(up, self).__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=scale, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_ch // scale, in_ch // scale, kernel_size=3, stride=1, padding=1)

        self.conv = de_conv(in_ch, out_ch)
        self.dropout = nn.Dropout()

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Zero-pad x1 so its spatial size matches the skip connection x2.
        diffX = x2.size()[2] - x1.size()[2]
        diffY = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, (diffY // 2, math.ceil(diffY / 2),
                        diffX // 2, math.ceil(diffX / 2)), "constant", 0)
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        x = self.dropout(x)
        return x
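
# A minimal sketch (not called anywhere) of the decoder stage above:
# a coarse 7x7 feature map is upsampled and fused with a 14x14 skip connection.
def _up_demo():
    stage = up(512 + 256, 256)          # channels: decoder input + skip
    x1 = torch.randn(1, 512, 7, 7)      # coarse decoder features
    x2 = torch.randn(1, 256, 14, 14)    # skip-connection features
    print(stage(x1, x2).shape)          # torch.Size([1, 256, 14, 14])
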
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class TAD_block(nn.Module):
    """Tissue-anatomy disentangled block: splits features into a self-attention
    'tissue' stream and a mask-pooled 'anatomy' stream, then fuses them back."""

    def __init__(self, in_dim):
        super(TAD_block, self).__init__()
        self.channel_in = in_dim

        self.gate_conv = nn.Sequential(
            conv1x1(in_dim, 1),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.query_conv = conv1x1(in_dim, in_dim // 8)
        self.key_conv = conv1x1(in_dim, in_dim // 8)
        self.value_conv = conv1x1(in_dim, in_dim)
        self.mask_conv = conv1x1(in_dim, 1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """
        inputs:
            x : input feature maps (B x C x H x W)
        returns:
            out : self-attention value + input feature, B x C x H x W
            tissue : the attention-weighted local stream, B x C x H x W
        """
        B, C, H, W = x.size()
        proj_query = self.query_conv(x).view(B, -1, H * W).permute(0, 2, 1)  # B x N x C
        proj_query -= proj_query.mean(1).unsqueeze(1)
        proj_key = self.key_conv(x).view(B, -1, H * W)  # B x C x N
        proj_key -= proj_key.mean(2).unsqueeze(2)
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)  # B x N x N (N = H * W)
        gate = self.gate_conv(x).view(B, -1, H * W)  # B x 1 x N
        attention = attention.permute(0, 2, 1) * gate
        proj_value = self.value_conv(x).view(B, -1, H * W)  # B x C x N
        proj_value = self.relu(proj_value)

        tissue = torch.bmm(proj_value, attention)
        tissue = tissue.view(B, C, H, W)

        proj_mask = self.mask_conv(x).view(B, -1, H * W)  # B x 1 x N
        mask = self.softmax(proj_mask)
        anatomy = torch.bmm(proj_value, mask.permute(0, 2, 1)).unsqueeze(-1)

        out = tissue + anatomy
        out = self.gamma * out + x
        return out, tissue
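
# A minimal sketch (not called anywhere) of the TAD block's shapes:
# 'out' matches the input and continues down the encoder, while 'tissue'
# feeds the decoder as a skip connection.
def _tad_demo():
    block = TAD_block(64)
    x = torch.randn(1, 64, 56, 56)
    out, tissue = block(x)
    print(out.shape, tissue.shape)  # both torch.Size([1, 64, 56, 56])
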
class logonet(nn.Module):

    def __init__(self, block=BasicBlock, layers=[2, 2, 2, 2], num_classes=2):
        super(logonet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # One TAD block after each encoder stage feeds the decoder below.
        self.tad1 = TAD_block(64)
        self.tad2 = TAD_block(64)
        self.tad3 = TAD_block(128)
        self.tad4 = TAD_block(256)

        self.up4 = up(512 + 256, 256)
        self.up3 = up(256 + 128, 128)
        self.up2 = up(128 + 64, 64)
        self.up1 = up(64 + 64, 64, 1)

        self.outconv = conv3x3(64, 1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample=downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x1 = self.maxpool(x)
        x, tissue1 = self.tad1(x1)

        x2 = self.layer1(x)
        x, tissue2 = self.tad2(x2)
        x3 = self.layer2(x)
        x, tissue3 = self.tad3(x3)
        x4 = self.layer3(x)
        x, tissue4 = self.tad4(x4)
        x5 = self.layer4(x)

        # Global (classification) pathway.
        x = self.avgpool(x5)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        # Local (segmentation) pathway decodes the TAD tissue streams.
        y = self.up4(x5, tissue4)
        y = self.up3(y, tissue3)
        y = self.up2(y, tissue2)
        y = self.up1(y, tissue1)
        y = self.outconv(y)

        return x, y


def load_pretrained(model, premodel):
    pretrained_dict = premodel.state_dict()
    model_dict = model.state_dict()
    # 1. filter out unnecessary keys
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    model.load_state_dict(model_dict)
    return model


def logonet18(pretrained=False):
    model = logonet(BasicBlock, [2, 2, 2, 2])
    if pretrained:
        premodel = models.resnet18(pretrained=True)
        premodel.fc = nn.Linear(512, 2)
        model = load_pretrained(model, premodel)
    return model


def logonet34(pretrained=False):
    model = logonet(BasicBlock, [3, 4, 6, 3])
    if pretrained:
        premodel = models.resnet34(pretrained=True)
        premodel.fc = nn.Linear(512, 2)
        model = load_pretrained(model, premodel)
    return model


def logonet50(pretrained=False):
    # 50-layer variant built from BasicBlocks; no pretrained weights are loaded.
    model = logonet(BasicBlock, [4, 8, 8, 4])
    return model


if __name__ == '__main__':
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'
    images = torch.randn(1, 3, 224, 224)
    model = logonet18(pretrained=True)
    out_class, out_seg = model(images)
    print(out_class.shape, out_seg.shape)
--------------------------------------------------------------------------------
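
Trainer.save in utils.py writes bare state dicts named `train<epoch>models.pth` under `save_root`. To evaluate or resume from a checkpoint afterwards (a minimal sketch; the path assumes the defaults in train.py and that the run reached epoch 100):

```python
import torch
from models import logonet18

model = logonet18()
state = torch.load('./temp/train100models.pth', map_location='cpu')
model.load_state_dict(state)
model.eval()
```
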