├── models
│   ├── __init__.py
│   └── logonet.py
├── seg_losses
│   ├── __init__.py
│   └── losses.py
├── graph
│   ├── model4.png
│   └── module3.png
├── samples
│   ├── test_label.txt
│   ├── train_label.txt
│   ├── img
│   │   ├── 1.2.250.1.204.5.8373722513.201612141038584163.75.dcm.jpg
│   │   └── 1.2.250.1.204.5.8373722513.20170907131030778276.75.dcm.jpg
│   └── msk
│       ├── 1.2.250.1.204.5.8373722513.201612141038584163.75.dcm.jpg
│       └── 1.2.250.1.204.5.8373722513.20170907131030778276.75.dcm.jpg
├── matrics.py
├── README.md
├── train.py
└── utils.py
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .logonet import *
--------------------------------------------------------------------------------
/seg_losses/__init__.py:
--------------------------------------------------------------------------------
1 | from .losses import *
--------------------------------------------------------------------------------
/graph/model4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phanzsx/LoGo-Net/HEAD/graph/model4.png
--------------------------------------------------------------------------------
/graph/module3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phanzsx/LoGo-Net/HEAD/graph/module3.png
--------------------------------------------------------------------------------
/samples/test_label.txt:
--------------------------------------------------------------------------------
1 | 1.2.250.1.204.5.8373722513.201612141038584163.75.dcm.jpg 1
2 | 1.2.250.1.204.5.8373722513.20170907131030778276.75.dcm.jpg 0
--------------------------------------------------------------------------------
/samples/train_label.txt:
--------------------------------------------------------------------------------
1 | 1.2.250.1.204.5.8373722513.201612141038584163.75.dcm.jpg 1
2 | 1.2.250.1.204.5.8373722513.20170907131030778276.75.dcm.jpg 0
--------------------------------------------------------------------------------
/samples/img/1.2.250.1.204.5.8373722513.201612141038584163.75.dcm.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phanzsx/LoGo-Net/HEAD/samples/img/1.2.250.1.204.5.8373722513.201612141038584163.75.dcm.jpg
--------------------------------------------------------------------------------
/samples/msk/1.2.250.1.204.5.8373722513.201612141038584163.75.dcm.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phanzsx/LoGo-Net/HEAD/samples/msk/1.2.250.1.204.5.8373722513.201612141038584163.75.dcm.jpg
--------------------------------------------------------------------------------
/samples/img/1.2.250.1.204.5.8373722513.20170907131030778276.75.dcm.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phanzsx/LoGo-Net/HEAD/samples/img/1.2.250.1.204.5.8373722513.20170907131030778276.75.dcm.jpg
--------------------------------------------------------------------------------
/samples/msk/1.2.250.1.204.5.8373722513.20170907131030778276.75.dcm.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phanzsx/LoGo-Net/HEAD/samples/msk/1.2.250.1.204.5.8373722513.20170907131030778276.75.dcm.jpg
--------------------------------------------------------------------------------
/matrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 |
5 |
6 | def iou_score(output, target):
7 |
8 | smooth = 1e-5
9 |     if isinstance(output, list):
10 | output = output[-1]
11 | output = torch.sigmoid(output)
12 | if torch.is_tensor(output):
13 | output = output.view(-1).data.cpu().numpy()
14 | if torch.is_tensor(target):
15 | target = target.view(-1).data.cpu().numpy()
16 | output_ = output > 0.5
17 | target_ = target > 0.5
18 | intersection = (output_ & target_).sum()
19 | union = (output_ | target_).sum()
20 |
21 | return (intersection + smooth) / (union + smooth)
22 |
23 |
24 | def dice_coef(output, target):
25 |
26 | smooth = 1e-5
27 |     if isinstance(output, list):
28 | output = output[-1]
29 | output = torch.sigmoid(output)
30 | output = output.view(-1).data.cpu().numpy()
31 | target = target.view(-1).data.cpu().numpy()
32 | output[output>=0.5] = 1
33 | output[output<0.5] = 0
34 | intersection = (output * target).sum()
35 |
36 | return (2. * intersection + smooth) / \
37 | (output.sum() + target.sum() + smooth)
--------------------------------------------------------------------------------
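Usage note: both metrics accept raw (pre-sigmoid) logits, binarize the prediction at 0.5, and add a 1e-5 smoothing term. A minimal sketch (not part of the repository), assuming the repository root is on the Python path:

```python
# Exercise the metrics on dummy tensors.
import torch
from matrics import iou_score, dice_coef

logits = torch.randn(1, 56, 56)               # raw decoder output (pre-sigmoid)
mask = (torch.rand(1, 56, 56) > 0.5).float()  # binary ground-truth mask

print('IoU: ', iou_score(logits, mask))       # sigmoid -> threshold at 0.5 -> IoU
print('Dice:', dice_coef(logits, mask))
```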
/README.md:
--------------------------------------------------------------------------------
1 | # LoGo-Net
2 | This repository contains the code for the paper "A Local and Global Feature Disentangled Network: Toward Classification of Benign-malignant Thyroid Nodules from Ultrasound Image". The code is also available for download on our homepage: https://www.neuro.uestc.edu.cn/vccl/home.html.
3 |
4 | ## Introduction
5 | In this study, inspired by the domain knowledge that sonographers apply when diagnosing ultrasound images, a local and global feature disentangled network (LoGo-Net) is proposed to classify benign and malignant thyroid nodules. The model imitates the dual-pathway structure of human vision and establishes a new feature extraction method to improve nodule recognition performance. We use the tissue-anatomy disentangled (TAD) block to connect the two pathways; it decouples the cues of local and global features based on the self-attention mechanism.
6 |
7 | ![The architecture of LoGo-Net](graph/model4.png)
8 | 
9 | ![The TAD block](graph/module3.png)
10 | 
11 | 
12 | 
13 | ## Prerequisites
14 | The codebase has been tested under the following settings.
15 | * Python>=3.7
16 | * PyTorch>=1.6.0
17 | * torchvision>=0.7
18 |
19 | ## Train
20 | * To make LoGo-Net easier to use, this project provides a simple example training framework. Three model scales are available: logonet18, logonet34, and logonet50 (a selection sketch follows this README).
21 | ```
22 | python train.py
23 | ```
24 |
25 | ## Citation
26 | If you use this code in your research, please cite the paper:
27 | ```bibtex
28 | @ARTICLE{logonet,
29 | author={Zhao, Shi-Xuan and Chen, Yang and Yang, Kai-Fu and Luo, Yan and Ma, Bu-Yun and Li, Yong-Jie},
30 | journal={IEEE Transactions on Medical Imaging},
31 | title={A Local and Global Feature Disentangled Network: Toward Classification of Benign-malignant Thyroid Nodules from Ultrasound Image},
32 | year={2022},
33 | volume={41},
34 | number={6},
35 | pages={1497--1509},
36 | doi={10.1109/TMI.2022.3140797}}
37 | ```
38 |
--------------------------------------------------------------------------------
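As referenced in the Train section above, a minimal sketch (not part of the repository) of selecting among the three scales exposed by models/logonet.py:

```python
# Choose a model scale for train.py (which instantiates logonet18()).
from models import logonet18, logonet34, logonet50

model = logonet18(pretrained=True)    # BasicBlock [2, 2, 2, 2]; loads matching ImageNet ResNet-18 weights
# model = logonet34(pretrained=True)  # BasicBlock [3, 4, 6, 3]; ResNet-34 weights
# model = logonet50()                 # BasicBlock [4, 8, 8, 4]; the pretrained flag is unused
```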
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | from torch.utils.data import Dataset, DataLoader
6 | from torchvision import transforms
7 | from PIL import Image
8 | import numpy as np
9 | from models import *
10 | from utils import Trainer
11 | import random
12 |
13 | def default_loader(path, is_img):
14 |     if is_img:
15 |         img = Image.open(path).convert('L')  # grayscale ultrasound image
16 |         img = img.resize((224, 224), Image.LANCZOS)  # ANTIALIAS alias; removed in Pillow 10
17 |     else:
18 |         img = Image.open(path).convert('1')  # binary nodule mask
19 |         img = img.resize((56, 56), Image.LANCZOS)
20 |     return img
21 |
22 | class MyDataset(Dataset):
23 | def __init__(self, mode, txt, transform=None, loader=default_loader):
24 | fh = open(txt, 'r')
25 | imgs = []
26 | for line in fh:
27 | line = line.strip('\n')
28 | line = line.rstrip()
29 | words = line.split()
30 | imgs.append((words[0], int(words[1])))
31 | self.imgs = imgs
32 | self.mode = mode
33 | self.transform = transform
34 | self.loader = loader
35 |
36 | def __getitem__(self, index):
37 | fn, label = self.imgs[index]
38 | img = self.loader(data_root + 'img/' + fn, True)
39 | tar = self.loader(data_root + 'msk/' + fn, False)
40 | if random.randint(0, 1) and self.mode == 'train':
41 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
42 | tar = tar.transpose(Image.FLIP_LEFT_RIGHT)
43 | img = self.transform(img)
44 | tar = torch.from_numpy(np.array(tar, np.float32, copy=False))
45 | tar = torch.unsqueeze(tar, 0)
46 | return img, tar, label
47 |
48 | def __len__(self):
49 | return len(self.imgs)
50 |
51 |
52 | def main(model):
53 | train_data = MyDataset(
54 | mode='train',
55 | txt=label_root + 'train_label.txt',
56 | transform=transforms.Compose([
57 | transforms.ToTensor(),
58 |             # normalize with the dataset-level grayscale mean/std
59 | transforms.Normalize(mean=[img_mean],
60 | std=[img_std])
61 | ]))
62 | test_data = MyDataset(
63 | mode='test',
64 | txt=label_root + 'test_label.txt',
65 | transform=transforms.Compose([
66 | transforms.ToTensor(),
67 | transforms.Normalize(mean=[img_mean],
68 | std=[img_std])
69 | ]))
70 |
71 | train_loader = DataLoader(
72 | train_data, batch_size=batch_size, shuffle=True, drop_last=True)
73 | test_loader = DataLoader(
74 | test_data, batch_size=batch_size, shuffle=False)
75 |
76 | # model.cuda()
77 | # model = nn.DataParallel(model.cuda(), device_ids=[0])
78 | optimizer = optim.SGD(params=model.parameters(),
79 | lr=base_lr, momentum=0.9, weight_decay=1e-5)
80 | # optimizer = optim.Adam(params=model.parameters(), lr=0.001)
81 | # scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 50, 0)
82 | scheduler = optim.lr_scheduler.StepLR(optimizer, 10, gamma=0.1)
83 | trainer = Trainer(model, optimizer, save_dir=save_root)
84 | trainer.loop(max_epoch, train_loader, test_loader, scheduler, save_freq=1)
85 |
86 |
87 | if __name__ == '__main__':
88 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
89 | torch.backends.cudnn.benchmark = True
90 |     # hyper-parameters (could be exposed via argparse)
91 | batch_size = 1
92 | data_root = './samples/'
93 | label_root = './samples/'
94 | save_root = './temp/'
95 | img_mean = 0.3309
96 | img_std = 0.1924
97 | model = logonet18()
98 | base_lr = 0.001
99 | max_epoch = 100
100 |
101 | if not os.path.exists(save_root):
102 | os.makedirs(save_root)
103 | main(model)
--------------------------------------------------------------------------------
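The normalization constants img_mean and img_std above are hard-coded. A hypothetical helper (an assumption, not the authors' script) showing how such per-dataset grayscale statistics could be estimated:

```python
# Hypothetical helper (not in the repository) for estimating the grayscale
# mean/std that train.py hard-codes as img_mean / img_std.
import os
import numpy as np
from PIL import Image

def grayscale_stats(img_dir):
    pixels = []
    for name in sorted(os.listdir(img_dir)):
        img = Image.open(os.path.join(img_dir, name)).convert('L')
        pixels.append(np.asarray(img, dtype=np.float32).ravel() / 255.0)
    pixels = np.concatenate(pixels)
    return pixels.mean(), pixels.std()

# compare against the constants hard-coded in train.py (0.3309, 0.1924)
print(grayscale_stats('./samples/img/'))
```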
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import scipy.io as sio
4 | import numpy as np
5 | from torch.autograd import Variable
6 | import torch.optim as optim
7 | import torch.nn as nn
8 | from tqdm import tqdm
9 | import torch.nn.functional as F
10 | import pdb
11 | from torchvision import transforms
12 | from PIL import Image
13 | import time
14 | from seg_losses import *
15 | from matrics import *
16 |
17 |
18 | class Trainer(object):
19 | def __init__(self, model, optimizer, save_dir=None, save_freq=1):
20 | self.model = model
21 | self.optimizer = optimizer
22 | self.save_dir = save_dir
23 | self.save_freq = save_freq
24 |
25 | def _loop(self, data_loader, ep, is_train=True):
26 | tensor2img = transforms.ToPILImage()
27 | loop_loss_class, correct, loop_loss_seg, loop_iou, loop_dice = [], [], [], [], []
28 | mode = 'train' if is_train else 'test'
29 | for data, tar, label in tqdm(data_loader):
30 | # data, tar, label = data.cuda(), tar.cuda(), label.cuda()
31 | if is_train:
32 | torch.set_grad_enabled(True)
33 | else:
34 | torch.set_grad_enabled(False)
35 |             out_class, out_seg = self.model(torch.cat([data, data, data], 1))  # replicate grayscale to 3 channels
36 | n = out_seg.size(0)
37 | loss_class = F.cross_entropy(out_class, label)
38 | loss_seg = F.binary_cross_entropy(torch.sigmoid(out_seg.view(n, -1)), tar.view(n, -1)) + \
39 | IOULoss(out_seg, tar)
40 | loss = 0.5 * loss_class + loss_seg
41 |
42 |             loop_loss_class.append(loss_class.item() / len(data_loader))
43 |             loop_loss_seg.append(loss_seg.item() / len(data_loader))
44 | out = (out_class.data.max(1)[1] == label.data).sum()
45 | correct.append(float(out) / len(data_loader.dataset))
46 |
47 | for j in range(n):
48 | loop_iou.append(iou_score(out_seg[j], tar[j]))
49 | loop_dice.append(dice_coef(out_seg[j], tar[j]))
50 |
51 | if is_train:
52 | self.optimizer.zero_grad()
53 | loss.backward()
54 | self.optimizer.step()
55 |
56 | print(mode + ': loss_class: {:.6f}, Acc: {:.6%}, loss_seg: {:.6f}, iou: {:.6f}, dice: {:.6f}'.format(
57 | sum(loop_loss_class), sum(correct), sum(loop_loss_seg), sum(loop_iou)/len(loop_iou), sum(loop_dice)/len(loop_dice)))
58 | return sum(loop_loss_class), sum(correct), sum(loop_loss_seg), sum(loop_iou)/len(loop_iou), sum(loop_dice)/len(loop_dice)
59 |
60 | def train(self, data_loader, ep):
61 | self.model.train()
62 | results = self._loop(data_loader, ep)
63 | return results
64 |
65 | def test(self, data_loader, ep):
66 | self.model.eval()
67 | results = self._loop(data_loader, ep, is_train=False)
68 | return results
69 |
70 | def loop(self, epochs, train_data, test_data, scheduler=None, save_freq=5):
71 | f = open(self.save_dir + 'log.txt', 'w')
72 | # f.write('train_loss_cls train_acc train_loss_seg train_iou train_dice ' +
73 | # 'test_loss_cls test_acc test_loss_seg test_iou test_dice\n')
74 | f.close()
75 |         for ep in range(1, epochs + 1):
76 |             print('epoch {}'.format(ep))
77 |             train_results = np.array(self.train(train_data, ep))
78 |             test_results = np.array(self.test(test_data, ep))
79 |             if scheduler is not None:
80 |                 scheduler.step()  # step after the epoch's updates (PyTorch >= 1.1 convention)
81 |             with open(self.save_dir + 'log.txt', 'a') as f:
82 |                 for i in np.append(train_results, test_results):
83 |                     f.write(str(round(i, 6)) + ' ')
84 |                 f.write('\n')
85 |             if not ep % save_freq:
86 |                 self.save(ep)
87 |
88 | def save(self, epoch, **kwargs):
89 | if self.save_dir:
90 | name = self.save_dir + 'train' + str(epoch) + 'models.pth'
91 | torch.save(self.model.state_dict(), name)
92 | # torch.save(self.model, name)
--------------------------------------------------------------------------------
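A minimal sketch of driving Trainer directly (train.py wires this up in full; the commented loop call expects the DataLoader objects built there):

```python
# Standalone Trainer setup mirroring train.py's defaults.
import torch.optim as optim
from models import logonet18
from utils import Trainer

model = logonet18()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
scheduler = optim.lr_scheduler.StepLR(optimizer, 10, gamma=0.1)
trainer = Trainer(model, optimizer, save_dir='./temp/')
# trainer.loop(100, train_loader, test_loader, scheduler, save_freq=1)
```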
/seg_losses/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | import numpy as np
6 |
7 |
8 | def IOULoss(input, target, weight=None):
9 |
10 | epsilon = 1e-5
11 |
12 | num = target.size(0)
13 | input = torch.sigmoid(input)
14 | input = input.view(num, -1)
15 | target = target.view(num, -1)
16 | target = target.float()
17 |
18 | intersect = (input * target).sum(1)
19 | if weight is not None:
20 | intersect = weight * intersect
21 |
22 | union = (input + target).sum(1) - intersect
23 |
24 | return 1. - torch.mean(intersect / union.clamp(min=epsilon))
25 |
26 |
27 | def DICELoss(input, target, weight=None, dimension=2):
28 |
29 | epsilon = 1e-5
30 |
31 | input = torch.sigmoid(input)
32 | input = input.view(-1)
33 | target = target.view(-1)
34 | target = target.float()
35 |
36 | # compute per channel Dice Coefficient
37 | intersect = (input * target).sum()
38 | if weight is not None:
39 | intersect = weight * intersect
40 |
41 | # here we can use standard dice (input + target).sum(-1) or extension (see V-Net) (input^2 + target^2).sum(-1)
42 |     if dimension == 2:
43 | denominator = (input * input).sum() + (target * target).sum()
44 | else:
45 | denominator = (input + target).sum()
46 |
47 | return 1. - 2 * (intersect / denominator.clamp(min=epsilon))
48 |
49 |
50 | def KLDivLoss(input, target):
51 | epsilon = 1e-5
52 |
53 | n = input.size(0)
54 | input = F.softmax(input.view(n, -1), -1)
55 | target = target.float()
56 | target = target.view(n, -1)
57 | add = target.sum(-1)[:,None]
58 | target = torch.div(target, add.clamp(min=epsilon))
59 |
60 | return F.kl_div(torch.log(input), target, reduction='batchmean')
61 |
62 |
63 | class FocalLoss(nn.Module):
64 | """
65 | This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in
66 | 'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'
67 | Focal_Loss= -1*alpha*(1-pt)*log(pt)
68 | :param num_class:
69 | :param alpha: (tensor) 3D or 4D the scalar factor for this criterion
70 | :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more
71 | focus on hard misclassified example
72 | :param reduction: `none`|`mean`|`sum`
73 | :param **kwargs
74 | balance_index: (int) balance class index, should be specific when alpha is float
75 | """
76 |
77 | def __init__(self, alpha=[1.0, 1.0], gamma=2, ignore_index=None, reduction='mean'):
78 | super(FocalLoss, self).__init__()
79 | if alpha is None:
80 | alpha = [0.25, 0.75]
81 | self.alpha = alpha
82 | self.gamma = gamma
83 | self.smooth = 1e-6
84 | self.ignore_index = ignore_index
85 | self.reduction = reduction
86 |
87 | assert self.reduction in ['none', 'mean', 'sum']
88 |
89 | if self.alpha is None:
90 | self.alpha = torch.ones(2)
91 | elif isinstance(self.alpha, (list, np.ndarray)):
92 | self.alpha = np.asarray(self.alpha)
93 | self.alpha = np.reshape(self.alpha, (2))
94 | assert self.alpha.shape[0] == 2, \
95 | 'the `alpha` shape is not match the number of class'
96 | elif isinstance(self.alpha, (float, int)):
97 |             self.alpha = np.asarray([self.alpha, 1.0 - self.alpha], dtype=np.float32)
98 |
99 | else:
100 | raise TypeError('{} not supported'.format(type(self.alpha)))
101 |
102 | def forward(self, output, target):
103 | prob = torch.sigmoid(output)
104 | prob = torch.clamp(prob, self.smooth, 1.0 - self.smooth)
105 |
106 | pos_mask = (target == 1).float()
107 | neg_mask = (target == 0).float()
108 |
109 | pos_loss = -self.alpha[0] * torch.pow(torch.sub(1.0, prob), self.gamma) * torch.log(prob) * pos_mask
110 | neg_loss = -self.alpha[1] * torch.pow(prob, self.gamma) * \
111 | torch.log(torch.sub(1.0, prob)) * neg_mask
112 |
113 | neg_loss = neg_loss.sum()
114 | pos_loss = pos_loss.sum()
115 | num_pos = pos_mask.view(pos_mask.size(0), -1).sum()
116 | num_neg = neg_mask.view(neg_mask.size(0), -1).sum()
117 |
118 | if num_pos == 0:
119 | loss = neg_loss
120 | else:
121 | loss = pos_loss / num_pos + neg_loss / num_neg
122 | return loss
123 |
124 |
125 | class CrossEntropyLabelSmooth(nn.Module):
126 | """Cross entropy loss with label smoothing regularizer.
127 | Reference:
128 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
129 | Equation: y = (1 - epsilon) * y + epsilon / K.
130 | Args:
131 | num_classes (int): number of classes.
132 | epsilon (float): weight.
133 | """
134 |
135 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True):
136 | super(CrossEntropyLabelSmooth, self).__init__()
137 | self.num_classes = num_classes
138 | self.epsilon = epsilon
139 | self.use_gpu = use_gpu
140 | self.logsoftmax = nn.LogSoftmax(dim=1)
141 |
142 | def forward(self, inputs, targets):
143 | """
144 | Args:
145 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
146 |             targets: ground truth labels with shape (batch_size)
147 | """
148 | log_probs = self.logsoftmax(inputs)
149 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu(), 1)
150 | if self.use_gpu: targets = targets.cuda()
151 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
152 | loss = (- targets * log_probs).mean(0).sum()
153 | return loss
--------------------------------------------------------------------------------
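Usage note: the segmentation losses above all take raw logits and apply their own sigmoid (or softmax) internally. A minimal sketch (not part of the repository):

```python
# Exercise the segmentation losses on dummy logits and binary masks.
import torch
from seg_losses import IOULoss, DICELoss, FocalLoss

logits = torch.randn(2, 1, 56, 56)                 # raw segmentation output
target = (torch.rand(2, 1, 56, 56) > 0.5).float()  # binary masks

print(IOULoss(logits, target))      # 1 - mean soft IoU
print(DICELoss(logits, target))     # 1 - soft Dice (squared-sum denominator by default)
print(FocalLoss()(logits, target))  # binary focal loss, alpha=[1.0, 1.0], gamma=2
```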
/models/logonet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import torch
4 | import torch.nn.functional as F
5 | import os
6 | import torchvision.models as models
7 |
8 | __all__ = ['logonet18', 'logonet34', 'logonet50']
9 |
10 |
11 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
12 | """3x3 convolution with padding"""
13 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
14 | padding=dilation, groups=groups, bias=True, dilation=dilation)
15 |
16 | def conv1x1(in_planes, out_planes, stride=1):
17 | """1x1 convolution"""
18 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=True)
19 |
20 |
21 | class de_conv(nn.Module):
22 | def __init__(self, in_ch, out_ch):
23 | super(de_conv, self).__init__()
24 | self.conv = nn.Sequential(
25 | conv3x3(in_ch, out_ch),
26 | nn.BatchNorm2d(out_ch),
27 | nn.ReLU(inplace=True),
28 | # conv3x3(out_ch, out_ch),
29 | # nn.BatchNorm2d(out_ch),
30 | # nn.ReLU(inplace=True)
31 | )
32 |
33 | def forward(self, x):
34 | x = self.conv(x)
35 | return x
36 |
37 |
38 | class up(nn.Module):
39 | def __init__(self, in_ch, out_ch, scale=2, bilinear=True):
40 | super(up, self).__init__()
41 |
42 | if bilinear:
43 | self.up = nn.Upsample(scale_factor=scale, mode='bilinear', align_corners=True)
44 | else:
45 |             self.up = nn.ConvTranspose2d(in_ch//scale, in_ch//scale, kernel_size=3, stride=1, padding=1)  # note: stride=1 keeps the spatial size; only the bilinear branch upsamples
46 |
47 | self.conv = de_conv(in_ch, out_ch)
48 | self.dropout = nn.Dropout()
49 |
50 | def forward(self, x1, x2):
51 | x1 = self.up(x1)
52 | diffX = x2.size()[2] - x1.size()[2]
53 | diffY = x2.size()[3] - x1.size()[3]
54 | x1 = F.pad(x1, (diffY // 2, math.ceil(diffY / 2),
55 | diffX // 2, math.ceil(diffX / 2)), "constant", 0)
56 | x = torch.cat([x2, x1], dim=1)
57 | x = self.conv(x)
58 | x = self.dropout(x)
59 | return x
60 |
61 |
62 | class BasicBlock(nn.Module):
63 | expansion = 1
64 |
65 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
66 | base_width=64, dilation=1, norm_layer=None):
67 | super(BasicBlock, self).__init__()
68 | if norm_layer is None:
69 | norm_layer = nn.BatchNorm2d
70 | if groups != 1 or base_width != 64:
71 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
72 | if dilation > 1:
73 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
74 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
75 | self.conv1 = conv3x3(inplanes, planes, stride)
76 | self.bn1 = norm_layer(planes)
77 | self.relu = nn.ReLU(inplace=True)
78 | self.conv2 = conv3x3(planes, planes)
79 | self.bn2 = norm_layer(planes)
80 | self.downsample = downsample
81 | self.stride = stride
82 |
83 | def forward(self, x):
84 | identity = x
85 |
86 | out = self.conv1(x)
87 | out = self.bn1(out)
88 | out = self.relu(out)
89 |
90 | out = self.conv2(out)
91 | out = self.bn2(out)
92 |
93 | if self.downsample is not None:
94 | identity = self.downsample(x)
95 |
96 | out += identity
97 | out = self.relu(out)
98 |
99 | return out
100 |
101 |
102 | class TAD_block(nn.Module):
103 |
104 | def __init__(self, in_dim):
105 | super(TAD_block, self).__init__()
106 |         self.channel_in = in_dim
107 |
108 | self.gate_conv = nn.Sequential(
109 | conv1x1(in_dim, 1),
110 | nn.BatchNorm2d(1),
111 | nn.Sigmoid()
112 | )
113 | self.query_conv = conv1x1(in_dim, in_dim // 8)
114 | self.key_conv = conv1x1(in_dim, in_dim // 8)
115 | self.value_conv = conv1x1(in_dim, in_dim)
116 | self.mask_conv = conv1x1(in_dim, 1)
117 | self.gamma = nn.Parameter(torch.zeros(1))
118 |
119 | self.softmax = nn.Softmax(dim=-1)
120 | self.relu = nn.ReLU()
121 |
122 | def forward(self, x):
123 | """
124 | inputs :
125 | x : input feature maps( B * C * H * W)
126 | returns :
127 | out : self attention value + input feature
128 | attention: B * N * N (N is H * W)
129 | """
130 | B, C, H, W = x.size()
131 | proj_query = self.query_conv(x).view(B, -1, H * W).permute(0, 2, 1) # B * N * C
132 | proj_query -= proj_query.mean(1).unsqueeze(1)
133 | proj_key = self.key_conv(x).view(B, -1, H * W) # B * C * N
134 | proj_key -= proj_key.mean(2).unsqueeze(2)
135 | energy = torch.bmm(proj_query, proj_key)
136 | attention = self.softmax(energy) # B * N * N
137 | gate = self.gate_conv(x).view(B, -1, H * W) # B * 1 * N
138 | attention = attention.permute(0, 2, 1) * gate
139 | proj_value = self.value_conv(x).view(B, -1, H * W) # B * C * N
140 | proj_value = self.relu(proj_value)
141 |
142 | tissue = torch.bmm(proj_value, attention)
143 | tissue = tissue.view(B, C, H, W)
144 |
145 | proj_mask = self.mask_conv(x).view(B, -1, H * W) # B * 1 * N
146 | mask = self.softmax(proj_mask)
147 | anatomy = torch.bmm(proj_value, mask.permute(0, 2, 1)).unsqueeze(-1)
148 |
149 | out = tissue + anatomy
150 | out = self.gamma * out + x
151 | return out, tissue
152 |
153 |
154 | class logonet(nn.Module):
155 |
156 | def __init__(self, block=BasicBlock, layers=[2, 2, 2, 2], num_classes=2):
157 | super(logonet, self).__init__()
158 | self.inplanes = 64
159 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
160 | bias=False)
161 | self.bn1 = nn.BatchNorm2d(64)
162 | self.relu = nn.ReLU(inplace=True)
163 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
164 | self.layer1 = self._make_layer(block, 64, layers[0])
165 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
166 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
167 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
168 | self.avgpool = nn.AdaptiveAvgPool2d(1)
169 | self.fc = nn.Linear(512 * block.expansion, num_classes)
170 |
171 | self.tad1 = TAD_block(64)
172 | self.tad2 = TAD_block(64)
173 | self.tad3 = TAD_block(128)
174 | self.tad4 = TAD_block(256)
175 |
176 | self.up4 = up(512 + 256, 256)
177 | self.up3 = up(256 + 128, 128)
178 | self.up2 = up(128 + 64, 64)
179 | self.up1 = up(64 + 64, 64, 1)
180 |
181 | self.outconv = conv3x3(64, 1)
182 |
183 | for m in self.modules():
184 | if isinstance(m, nn.Conv2d):
185 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
186 | elif isinstance(m, nn.BatchNorm2d):
187 | nn.init.constant_(m.weight, 1)
188 | nn.init.constant_(m.bias, 0)
189 |
190 | def _make_layer(self, block, planes, blocks, stride=1):
191 | downsample = None
192 | if stride != 1 or self.inplanes != planes * block.expansion:
193 | downsample = nn.Sequential(
194 | nn.Conv2d(self.inplanes, planes * block.expansion,
195 | kernel_size=1, stride=stride, bias=False),
196 | nn.BatchNorm2d(planes * block.expansion),
197 | )
198 |
199 | layers = []
200 | layers.append(block(self.inplanes, planes, stride, downsample=downsample))
201 | self.inplanes = planes * block.expansion
202 | for i in range(1, blocks):
203 | layers.append(block(self.inplanes, planes))
204 |
205 | return nn.Sequential(*layers)
206 |
207 | def forward(self, x):
208 | x = self.conv1(x)
209 | x = self.bn1(x)
210 | x = self.relu(x)
211 | x1 = self.maxpool(x)
212 | x, tissue1 = self.tad1(x1)
213 |
214 | x2 = self.layer1(x)
215 | x, tissue2 = self.tad2(x2)
216 | x3 = self.layer2(x)
217 | x, tissue3 = self.tad3(x3)
218 | x4 = self.layer3(x)
219 | x, tissue4 = self.tad4(x4)
220 | x5 = self.layer4(x)
221 |
222 | x = self.avgpool(x5)
223 | x = x.view(x.size(0), -1)
224 | x = self.fc(x)
225 |
226 | y = self.up4(x5, tissue4)
227 | y = self.up3(y, tissue3)
228 | y = self.up2(y, tissue2)
229 | y = self.up1(y, tissue1)
230 | y = self.outconv(y)
231 |
232 | return x, y
233 |
234 |
235 | def load_pretrained(model, premodel):
236 | pretrained_dict = premodel.state_dict()
237 | model_dict = model.state_dict()
238 | # 1. filter out unnecessary keys
239 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
240 | # 2. overwrite entries in the existing state dict
241 | model_dict.update(pretrained_dict)
242 | # 3. load the new state dict
243 | model.load_state_dict(model_dict)
244 | return model
245 |
246 | def logonet18(pretrained=False):
247 | model = logonet(BasicBlock, [2, 2, 2, 2])
248 | if pretrained:
249 | premodel = models.resnet18(pretrained=True)
250 | premodel.fc = nn.Linear(512, 2)
251 | model = load_pretrained(model, premodel)
252 | return model
253 |
254 | def logonet34(pretrained=False):
255 | model = logonet(BasicBlock, [3, 4, 6, 3])
256 | if pretrained:
257 | premodel = models.resnet34(pretrained=True)
258 | premodel.fc = nn.Linear(512, 2)
259 | model = load_pretrained(model, premodel)
260 | return model
261 |
262 | def logonet50(pretrained=False):
263 |     model = logonet(BasicBlock, [4, 8, 8, 4])  # note: the pretrained flag is currently unused here
264 | return model
265 |
266 | if __name__ == '__main__':
267 | os.environ['CUDA_VISIBLE_DEVICES'] = '1'
268 | images = torch.randn(1, 3, 224, 224)
269 | model = logonet18(pretrained=True)
270 | out_class, out_seg = model(images)
271 | print(out_class.shape, out_seg.shape)
--------------------------------------------------------------------------------