├── logs
│   └── README.md
├── checkpoints
│   └── README.md
├── data
│   └── pre_model
│       └── README.md
├── Figs
│   └── Framework.png
├── net
│   ├── __pycache__
│   │   ├── ASPP.cpython-37.pyc
│   │   ├── convs.cpython-37.pyc
│   │   ├── gumbel.cpython-37.pyc
│   │   ├── loss.cpython-37.pyc
│   │   ├── models.cpython-37.pyc
│   │   ├── modules.cpython-37.pyc
│   │   └── xception.cpython-37.pyc
│   ├── sync_batchnorm
│   │   ├── __pycache__
│   │   │   ├── comm.cpython-36.pyc
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── batchnorm.cpython-36.pyc
│   │   │   └── replicate.cpython-36.pyc
│   │   ├── __init__.py
│   │   ├── unittest.py
│   │   ├── batchnorm_reimpl.py
│   │   ├── replicate.py
│   │   ├── comm.py
│   │   └── batchnorm.py
│   ├── convs.py
│   ├── gumbel.py
│   ├── ASPP.py
│   ├── loss.py
│   ├── models.py
│   ├── xception.py
│   └── modules.py
├── utils
│   ├── logger.py
│   ├── meter.py
│   ├── metrics.py
│   └── fp16util.py
├── training_script.sh
├── config.py
├── README.md
├── dataset
│   ├── transform_customize.py
│   ├── my_datasets.py
│   └── preprocess.py
├── visualization
│   └── utils.py
└── train_DSI_Net.py
/logs/README.md:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/checkpoints/README.md:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/data/pre_model/README.md:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/Figs/Framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CityU-AIM-Group/DSI-Net/HEAD/Figs/Framework.png
--------------------------------------------------------------------------------
/net/__pycache__/ASPP.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CityU-AIM-Group/DSI-Net/HEAD/net/__pycache__/ASPP.cpython-37.pyc
--------------------------------------------------------------------------------
/net/__pycache__/convs.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CityU-AIM-Group/DSI-Net/HEAD/net/__pycache__/convs.cpython-37.pyc
--------------------------------------------------------------------------------
/net/__pycache__/gumbel.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CityU-AIM-Group/DSI-Net/HEAD/net/__pycache__/gumbel.cpython-37.pyc
--------------------------------------------------------------------------------
/net/__pycache__/loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CityU-AIM-Group/DSI-Net/HEAD/net/__pycache__/loss.cpython-37.pyc
--------------------------------------------------------------------------------
/net/__pycache__/models.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CityU-AIM-Group/DSI-Net/HEAD/net/__pycache__/models.cpython-37.pyc
--------------------------------------------------------------------------------
/net/__pycache__/modules.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CityU-AIM-Group/DSI-Net/HEAD/net/__pycache__/modules.cpython-37.pyc
--------------------------------------------------------------------------------
/net/__pycache__/xception.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CityU-AIM-Group/DSI-Net/HEAD/net/__pycache__/xception.cpython-37.pyc
--------------------------------------------------------------------------------
/net/sync_batchnorm/__pycache__/comm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CityU-AIM-Group/DSI-Net/HEAD/net/sync_batchnorm/__pycache__/comm.cpython-36.pyc
--------------------------------------------------------------------------------
/net/sync_batchnorm/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CityU-AIM-Group/DSI-Net/HEAD/net/sync_batchnorm/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/net/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CityU-AIM-Group/DSI-Net/HEAD/net/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc
--------------------------------------------------------------------------------
/net/sync_batchnorm/__pycache__/replicate.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CityU-AIM-Group/DSI-Net/HEAD/net/sync_batchnorm/__pycache__/replicate.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 |
2 | import random
3 |
4 | def print_f(msg, f=None):
5 |     if f is not None:
6 |         print(msg, file=f)
7 |         # flush the log file only occasionally (~1 in 7 calls) to limit I/O
8 |         if random.randint(0, 20) < 3:
9 |             f.flush()
10 |     print(msg)
11 |
--------------------------------------------------------------------------------
/training_script.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH -J training
3 | #SBATCH --gres=gpu:1
4 | #SBATCH --partition=gpu_1d2g
5 | #SBATCH -c 2
6 | #SBATCH -N 1
7 |
8 | echo "Submitted from:"$SLURM_SUBMIT_DIR" on node:"$SLURM_SUBMIT_HOST
9 | echo "Running on node "$SLURM_JOB_NODELIST
10 | echo "Allocate Gpu Units:"$CUDA_VISIBLE_DEVICES
11 |
12 | nvidia-smi
13 |
14 | python train_DSI_Net.py --gpus 0 --K 100 --alpha 0.05 --image_list 'data/WCE/WCE_Dataset_image_list.pkl'
15 |
--------------------------------------------------------------------------------
/net/sync_batchnorm/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .replicate import DataParallelWithCallback, patch_replication_callback
13 |
--------------------------------------------------------------------------------
/net/convs.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class SeparableConv2d(nn.Module):
6 | def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False):
7 | super(SeparableConv2d,self).__init__()
8 |
9 | self.conv1 = nn.Conv2d(in_channels,in_channels,kernel_size,stride,padding,dilation,groups=in_channels,bias=bias)
10 | self.pointwise = nn.Conv2d(in_channels,out_channels,1,1,0,1,1,bias=bias)
11 |
12 | def forward(self,x):
13 | x = self.conv1(x)
14 | x = self.pointwise(x)
15 | return x
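16 |
17 | # SeparableConv2d factors a k x k convolution into a depthwise conv
18 | # (groups=in_channels) followed by a 1x1 pointwise conv, which cuts
19 | # parameters and FLOPs roughly by a factor of the kernel area.
20 | # Illustrative check (not part of the original file):
21 | # SeparableConv2d(64, 128, 3, padding=1)(torch.randn(1, 64, 32, 32)).shape  # -> (1, 128, 32, 32)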
--------------------------------------------------------------------------------
/utils/meter.py:
--------------------------------------------------------------------------------
1 | class AverageMeter(object):
2 | """Computes and stores the average and current value"""
3 | def __init__(self, avg_mom=0.5):
4 | self.avg_mom = avg_mom
5 | self.reset()
6 |
7 | def reset(self):
8 | self.val = 0
9 | self.avg = 0 # running average of whole epoch
10 | self.smooth_avg = 0
11 | self.sum = 0
12 | self.count = 0
13 |
14 |     def update(self, val, n=1):
15 |         self.val = val
16 |         # seed the smoothed value on the first update, then blend with the running average
17 |         self.smooth_avg = val if self.count == 0 else self.avg * self.avg_mom + val * (1 - self.avg_mom)
18 |         self.sum += val * n
19 |         self.count += n
20 |         self.avg = self.sum / self.count
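21 |
22 | # Minimal usage sketch (illustrative, not part of the original file):
23 | # meter = AverageMeter(avg_mom=0.9)
24 | # for loss in (0.8, 0.6, 0.5):
25 | #     meter.update(loss, n=8)  # n = number of samples in the batch
26 | # meter.avg is the sample-weighted mean; meter.smooth_avg an exponentially smoothed value.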
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 |
2 | #data
3 | DATA_ROOT = '/home/meiluzhu2/data/WCE/WCE_larger'
4 | BATCH_SIZE = 8
5 | NUM_WORKERS = 8
6 | DROP_LAST = True
7 | SIZE = 256
8 |
9 | #training
10 | LEARNING_RATE = 0.0001
11 | MOMENTUM = 0.9
12 | POWER = 0.9
13 | WEIGHT_DECAY = 1e-5
14 | NUM_CLASSES_CLS = 3
15 | TRAIN_NUM = 2470
16 | EPOCH = 200
17 | STEPS = (TRAIN_NUM/BATCH_SIZE)*EPOCH  # total training iterations (presumably for polynomial LR decay, cf. POWER)
18 | FP16 = False
19 | VERBOSE = False
20 | SAVE_PATH = 'checkpoints/'
21 | LOG_PATH = 'logs/'
22 | COLOR = ['red', 'green', 'blue', 'yellow', 'black', 'orange', 'purple', 'pink','peru']
23 |
24 | #network
25 | INTERMIDEATE_NUM = 64  # (sic: misspelling of INTERMEDIATE; net/models.py references this exact name)
26 | OS = 8
27 | EM_STEP = 3
28 | ##gumbel
29 | GUMBEL_FACTOR = 1.0
30 | GUMBEL_NOISE = True
31 |
32 |
33 |
34 |
--------------------------------------------------------------------------------
/net/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import unittest
12 | import torch
13 |
14 |
15 | class TorchTestCase(unittest.TestCase):
16 | def assertTensorClose(self, x, y):
17 | adiff = float((x - y).abs().max())
18 | if (y == 0).all():
19 | rdiff = 'NaN'
20 | else:
21 | rdiff = float((adiff / y).abs().max())
22 |
23 | message = (
24 | 'Tensor close check failed\n'
25 | 'adiff={}\n'
26 | 'rdiff={}\n'
27 | ).format(adiff, rdiff)
28 | self.assertTrue(torch.allclose(x, y), message)
29 |
30 |
--------------------------------------------------------------------------------
/net/gumbel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class Gumbel(nn.Module):
5 | '''
6 |     Returns differentiable discrete (0/1) outputs. Applies a Gumbel-sigmoid trick (binary Gumbel-Softmax) with a straight-through estimator to every element of x.
7 | '''
8 | def __init__(self, config):
9 | super(Gumbel, self).__init__()
10 | self.factor = config.GUMBEL_FACTOR
11 | self.gumbel_noise = config.GUMBEL_NOISE
12 |
13 | def forward(self, x):
14 | if not self.training: # no Gumbel noise during inference
15 | return (x >= 0).float()
16 |
17 | if self.gumbel_noise:
18 | U = torch.rand_like(x)
19 |             g = -torch.log(-torch.log(U + 1e-8) + 1e-8)  # Gumbel(0, 1) noise
20 | x = x + g
21 |
22 | soft = torch.sigmoid(x / self.factor)
23 | hard = ((soft >= 0.5).float() - soft).detach() + soft
24 | assert not torch.any(torch.isnan(hard))
25 |
26 | return hard
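27 |
28 | # Illustrative check (assumed config fields, not part of the original file):
29 | # from types import SimpleNamespace
30 | # gate = Gumbel(SimpleNamespace(GUMBEL_FACTOR=1.0, GUMBEL_NOISE=True)).train()
31 | # x = torch.randn(2, 1, 4, 4, requires_grad=True)
32 | # out = gate(x)          # values are exactly 0/1,
33 | # out.sum().backward()   # yet gradients flow through the soft sigmoid (straight-through)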
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn import metrics
3 |
4 |
5 | def Jaccard(pred_arg, mask):
6 | pred_arg = np.argmax(pred_arg.cpu().data.numpy(), axis=1)
7 | mask = mask.cpu().data.numpy()
8 |
9 | y_true_f = mask.reshape(mask.shape[0] * mask.shape[1] * mask.shape[2], order='F')
10 | y_pred_f = pred_arg.reshape(pred_arg.shape[0] * pred_arg.shape[1] * pred_arg.shape[2], order='F')
11 |
12 |     intersection = float(np.sum(y_true_f * y_pred_f))  # np.float was removed in NumPy 1.24
13 | jac_score = intersection / (np.sum(y_true_f) + np.sum(y_pred_f) - intersection)
14 |
15 | return jac_score
16 |
17 |
18 | def cla_evaluate(label, binary_score, pro_score):
19 |
20 | acc = metrics.accuracy_score(label, binary_score)
21 | AP = metrics.average_precision_score(label, pro_score)
22 | auc = metrics.roc_auc_score(label, pro_score)
23 | CM = metrics.confusion_matrix(label, binary_score)
24 | MCC = metrics.matthews_corrcoef(label,binary_score)
25 | F1 = metrics.f1_score(label,binary_score)
26 | sens = float(CM[1, 1]) / float(CM[1, 1] + CM[1, 0])
27 | spec = float(CM[0, 0]) / float(CM[0, 0] + CM[0, 1])
28 | return acc, auc, AP, sens, spec, MCC, F1
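29 |
30 | # Usage sketch (illustrative): label/binary_score are 0/1 arrays and pro_score
31 | # the predicted probability of the positive class, e.g.
32 | # acc, auc, AP, sens, spec, MCC, F1 = cla_evaluate(y_true, (p >= 0.5).astype(int), p)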
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DSI-Net
2 |
3 | This repository is an official PyTorch implementation of the paper [**"DSI-Net: Deep Synergistic Interaction Network for Joint Classification and Segmentation with Endoscope Images"**](https://ieeexplore.ieee.org/document/9440441), TMI 2021.
4 |
5 | ![Framework](Figs/Framework.png)
6 |
7 |
8 | ## Dependencies
9 | * Python 3.6
10 | * PyTorch >= 1.3.0
11 | * numpy
12 | * apex
13 | * scikit-learn
14 | * matplotlib
15 | * Pillow (PIL)
16 |
17 | ## Usage
18 | * Download the [**processed dataset**](https://drive.google.com/file/d/1BBF21SVlH5685XpsvtKlWN7iepr7YQPU/view?usp=sharing)
19 | * Train DSI-Net:
20 | ```bash
21 | python train_DSI_Net.py --gpus 0 --K 100 --alpha 0.05 --image_list 'data/WCE/WCE_Dataset_image_list.pkl'
22 | ```
23 |
24 | ## Citation
25 | ```
26 | @ARTICLE{9440441,
27 | author={Zhu, Meilu and Chen, Zhen and Yuan, Yixuan},
28 | journal={IEEE Transactions on Medical Imaging},
29 | title={DSI-Net: Deep Synergistic Interaction Network for Joint Classification and Segmentation with Endoscope Images},
30 | year={2021},
31 | doi={10.1109/TMI.2021.3083586}}
32 | ```
33 | ## Contact
34 |
35 | Meilu Zhu (meiluzhu2-c@my.cityu.edu.hk)
36 |
--------------------------------------------------------------------------------
/dataset/transform_customize.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.ndimage import map_coordinates
3 | from scipy.ndimage import gaussian_filter
4 |
5 |
6 | class RandomElasticTransform(object):
7 |     """Apply a random elastic deformation to an image"""
8 |     # https://gist.github.com/nasimrahaman/8ed04be1088e228c21d51291f47dd1e6
9 |     def __init__(self, alpha=2000, sigma=50):
10 | self.alpha = alpha
11 | self.sigma = sigma
12 |
13 | def __call__(self, img):
14 |
15 | shape = img.shape[:2]
16 | random_state = np.random
17 | dx = gaussian_filter((random_state.rand(*shape) * 2 - 1),
18 | self.sigma, mode="constant", cval=0) * self.alpha
19 | dy = gaussian_filter((random_state.rand(*shape) * 2 - 1),
20 | self.sigma, mode="constant", cval=0) * self.alpha
21 | x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij')
22 | indices = [np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1))]
23 |
24 | image = map_coordinates(img, indices, order=1, mode='nearest').reshape(shape)
25 |
26 | return image
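27 |
28 | # Elastic deformation in brief: a random displacement field is smoothed with a
29 | # Gaussian (sigma), scaled by alpha, and the image is resampled at the displaced
30 | # coordinates with bilinear interpolation (order=1).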
--------------------------------------------------------------------------------
/visualization/utils.py:
--------------------------------------------------------------------------------
1 |
2 | import matplotlib
3 | matplotlib.use('Agg')
4 | import matplotlib.pyplot as plt
5 |
6 |
7 | def show_seg_results(img, gt, pre, save_path = None, name = None):
8 |
9 | fig = plt.figure()
10 | ax = fig.add_subplot(131)
11 | ax.imshow(img)
12 | ax.axis('off')
13 | ax = fig.add_subplot(132)
14 | ax.imshow(gt)
15 | ax.axis('off')
16 | ax = fig.add_subplot(133)
17 | ax.imshow(pre)
18 | ax.axis('off')
19 | fig.suptitle('Img, GT, Prediction',fontsize=6)
20 |     if save_path is not None and name is not None:
21 | fig.savefig(save_path + name + '.png', dpi=200, bbox_inches='tight')
22 | ax.cla()
23 | fig.clf()
24 | plt.close()
25 |
26 | def draw_curves(data_list, label_list, color_list, linestyle_list = None, filename = 'training_curve.png'):
27 |
28 | plt.figure()
29 | for i in range(len(data_list)):
30 | data = data_list[i]
31 | label = label_list[i]
32 | color = color_list[i]
33 |         if linestyle_list is None:
34 | linestyle = '-'
35 | else:
36 | linestyle = linestyle_list[i]
37 | plt.plot(data, label = label, color = color, linestyle = linestyle)
38 | plt.legend(loc='best')
39 | plt.savefig(filename)
40 |     plt.clf()
41 |     plt.close('all')
--------------------------------------------------------------------------------
/utils/fp16util.py:
--------------------------------------------------------------------------------
1 | #https://github.com/cybertronai/imagenet18_old/blob/master/training/fp16util.py
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.autograd import Variable
6 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
7 |
8 | class tofp16(nn.Module):
9 | """
10 | Model wrapper that implements::
11 | def forward(self, input):
12 | return input.half()
13 | """
14 |
15 | def __init__(self):
16 | super(tofp16, self).__init__()
17 |
18 | def forward(self, input):
19 | return input.half()
20 |
21 |
22 | def BN_convert_float(module):
23 | '''
24 | Designed to work with network_to_half.
25 | BatchNorm layers need parameters in single precision.
26 | Find all layers and convert them back to float. This can't
27 | be done with built in .apply as that function will apply
28 | fn to all modules, parameters, and buffers. Thus we wouldn't
29 | be able to guard the float conversion based on the module type.
30 | '''
31 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
32 | module.float()
33 | for child in module.children():
34 | BN_convert_float(child)
35 | return module
36 |
37 |
38 | def network_to_half(network):
39 | """
40 | Convert model to half precision in a batchnorm-safe way.
41 | """
42 | # (AS) This is better as it does not change model structure
43 | return BN_convert_float(network.half())
44 | # return nn.Sequential(tofp16(), BN_convert_float(network.half()))
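45 |
46 | # Usage sketch (illustrative): keep BatchNorm in fp32 while the rest runs in fp16.
47 | # model = network_to_half(model.cuda())
48 | # out = model(images.cuda().half())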
--------------------------------------------------------------------------------
/net/ASPP.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class ASPP(nn.Module):
6 | def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):
7 | super(ASPP, self).__init__()
8 | self.branch1 = nn.Sequential(
9 | nn.Conv2d(dim_in, dim_out, 1, 1, padding=0, dilation=rate, bias=True),
10 | nn.BatchNorm2d(dim_out),
11 | nn.ReLU(inplace=True),
12 | )
13 | self.branch2 = nn.Sequential(
14 | nn.Conv2d(dim_in, dim_out, 3, 1, padding=6 * rate, dilation=6 * rate, bias=True),
15 | nn.BatchNorm2d(dim_out),
16 | nn.ReLU(inplace=True),
17 | )
18 | self.branch3 = nn.Sequential(
19 | nn.Conv2d(dim_in, dim_out, 3, 1, padding=12 * rate, dilation=12 * rate, bias=True),
20 | nn.BatchNorm2d(dim_out),
21 | nn.ReLU(inplace=True),
22 | )
23 | self.branch4 = nn.Sequential(
24 | nn.Conv2d(dim_in, dim_out, 3, 1, padding=18 * rate, dilation=18 * rate, bias=True),
25 | nn.BatchNorm2d(dim_out),
26 | nn.ReLU(inplace=True),
27 | )
28 | self.branch5_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=True)
29 | self.branch5_bn = nn.BatchNorm2d(dim_out)
30 | self.branch5_relu = nn.ReLU(inplace=True)
31 | self.conv_cat = nn.Sequential(
32 | nn.Conv2d(dim_out * 5, dim_out, 1, 1, padding=0, bias=True),
33 | nn.BatchNorm2d(dim_out),
34 | nn.ReLU(inplace=True),
35 | )
36 |
37 | def forward(self, x):
38 | [b, c, row, col] = x.size()
39 | conv1x1 = self.branch1(x)
40 | conv3x3_1 = self.branch2(x)
41 | conv3x3_2 = self.branch3(x)
42 | conv3x3_3 = self.branch4(x)
43 | global_feature = torch.mean(x, 2, True)
44 | global_feature = torch.mean(global_feature, 3, True)
45 | global_feature = self.branch5_conv(global_feature)
46 | global_feature = self.branch5_bn(global_feature)
47 | global_feature = self.branch5_relu(global_feature)
48 | global_feature = F.interpolate(global_feature, (row, col), None, 'bilinear', True)
49 |
50 | feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, global_feature], dim=1)
51 | result = self.conv_cat(feature_cat)
52 |
53 | return result
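54 |
55 | # Shape sketch (illustrative, not part of the original file):
56 | # aspp = ASPP(dim_in=2048, dim_out=256, rate=1)
57 | # aspp(torch.randn(2, 2048, 32, 32)).shape  # -> torch.Size([2, 256, 32, 32])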
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
--------------------------------------------------------------------------------
/net/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def _dice_loss(predict, target):
7 |
8 | smooth = 1e-5
9 |
10 | y_true_f = target.contiguous().view(target.shape[0], -1)
11 | y_pred_f = predict.contiguous().view(predict.shape[0], -1)
12 | intersection = torch.sum(torch.mul(y_pred_f, y_true_f), dim=1)
13 | union = torch.sum(y_pred_f, dim=1) + torch.sum(y_true_f, dim=1) + smooth
14 | dice_score = (2.0 * intersection / union)
15 |
16 | dice_loss = 1 - dice_score
17 |
18 | return dice_loss
19 |
20 |
21 | class Dice_Loss(nn.Module):
22 | def __init__(self):
23 | super(Dice_Loss, self).__init__()
24 |
25 | def forward(self, predicts, target):
26 |
27 | preds = torch.softmax(predicts, dim=1)
28 | dice_loss0 = _dice_loss(preds[:, 0, :, :], 1 - target)
29 | dice_loss1 = _dice_loss(preds[:, 1, :, :], target)
30 | loss_D = (dice_loss0.mean() + dice_loss1.mean())/2.0
31 |
32 | return loss_D
33 |
34 |
35 | class Task_Interaction_Loss(nn.Module):
36 |
37 | def __init__(self):
38 | super(Task_Interaction_Loss, self).__init__()
39 |
40 | def forward(self, cls_predict, seg_predict, target):
41 |
42 | b,c = cls_predict.shape
43 | h, w = seg_predict.shape[2], seg_predict.shape[3]
44 |
45 |         target = target.view(b, 1)
46 |         target = torch.zeros(b, c).cuda().scatter_(1, target, 1)
47 |         target_new = torch.zeros(b, c-1).cuda()
48 |         cls_pred = torch.zeros(b, c-1).cuda()
49 |
50 |         # merge the two lesion classes into a single "abnormal" bin
51 |         target_new[:, 0] = target[:, 0]
52 |         target_new[:, 1] = target[:, 1] + target[:, 2]
53 |
54 |         cls_pred[:, 0] = cls_predict[:, 0]
55 |         cls_pred[:, 1] = cls_predict[:, 1] + cls_predict[:, 2]
56 |
57 |         # pool segmentation logits spatially with a normalized log-sum-exp
58 |         seg_pred = torch.logsumexp(seg_predict, dim=(2, 3)) / (h * w)
59 |
60 |         # symmetrized (JS-style) KL between the class and pooled-seg distributions
61 |         seg_cls_kl = F.kl_div(torch.log_softmax(cls_pred, dim=-1), torch.softmax(seg_pred, dim=-1), reduction='none')
62 |         cls_seg_kl = F.kl_div(torch.log_softmax(seg_pred, dim=-1), torch.softmax(cls_pred, dim=-1), reduction='none')
63 |
64 |         seg_cls_kl = seg_cls_kl.sum(-1)
65 |         cls_seg_kl = cls_seg_kl.sum(-1)
66 |
67 |         # distill in a direction only when its source branch already predicts the right side
68 |         indicator1 = (cls_pred[:, 0] > cls_pred[:, 1]) == (target_new[:, 0] > target_new[:, 1])
69 |         indicator2 = (seg_pred[:, 0] > seg_pred[:, 1]) == (target_new[:, 0] > target_new[:, 1])
70 |
71 |         return (cls_seg_kl * indicator1 + seg_cls_kl * indicator2).sum() / 2. / b
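72 |
73 | # Usage sketch (illustrative, not part of the original file):
74 | # crit = Task_Interaction_Loss()
75 | # loss = crit(cls_logits,   # (B, 3) classification logits
76 | #             seg_logits,   # (B, 2, H, W) segmentation logits
77 | #             labels)       # (B,) int64 class labels in {0, 1, 2}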
--------------------------------------------------------------------------------
/net/sync_batchnorm/batchnorm_reimpl.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # File : batchnorm_reimpl.py
4 | # Author : acgtyrant
5 | # Date : 11/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.init as init
14 |
15 | __all__ = ['BatchNorm2dReimpl']
16 |
17 |
18 | class BatchNorm2dReimpl(nn.Module):
19 | """
20 | A re-implementation of batch normalization, used for testing the numerical
21 | stability.
22 |
23 | Author: acgtyrant
24 | See also:
25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
26 | """
27 | def __init__(self, num_features, eps=1e-5, momentum=0.1):
28 | super().__init__()
29 |
30 | self.num_features = num_features
31 | self.eps = eps
32 | self.momentum = momentum
33 | self.weight = nn.Parameter(torch.empty(num_features))
34 | self.bias = nn.Parameter(torch.empty(num_features))
35 | self.register_buffer('running_mean', torch.zeros(num_features))
36 | self.register_buffer('running_var', torch.ones(num_features))
37 | self.reset_parameters()
38 |
39 | def reset_running_stats(self):
40 | self.running_mean.zero_()
41 | self.running_var.fill_(1)
42 |
43 | def reset_parameters(self):
44 | self.reset_running_stats()
45 | init.uniform_(self.weight)
46 | init.zeros_(self.bias)
47 |
48 | def forward(self, input_):
49 | batchsize, channels, height, width = input_.size()
50 | numel = batchsize * height * width
51 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
52 | sum_ = input_.sum(1)
53 | sum_of_square = input_.pow(2).sum(1)
54 | mean = sum_ / numel
55 | sumvar = sum_of_square - sum_ * mean
56 |
57 | self.running_mean = (
58 | (1 - self.momentum) * self.running_mean
59 | + self.momentum * mean.detach()
60 | )
61 | unbias_var = sumvar / (numel - 1)
62 | self.running_var = (
63 | (1 - self.momentum) * self.running_var
64 | + self.momentum * unbias_var.detach()
65 | )
66 |
67 | bias_var = sumvar / numel
68 | inv_std = 1 / (bias_var + self.eps).pow(0.5)
69 | output = (
70 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
71 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1))
72 |
73 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
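74 |
75 | # Numerical-equivalence sketch (illustrative, not part of the original file):
76 | # ref, re = nn.BatchNorm2d(8), BatchNorm2dReimpl(8)
77 | # re.load_state_dict(ref.state_dict(), strict=False)  # re has no num_batches_tracked buffer
78 | # x = torch.randn(4, 8, 5, 5)
79 | # torch.testing.assert_close(ref(x), re(x), rtol=1e-4, atol=1e-5)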
--------------------------------------------------------------------------------
/dataset/my_datasets.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | from torch.utils import data
4 | from torchvision import transforms
5 | from PIL import Image
6 | import os
7 | import pickle
8 |
9 |
10 | class CADCAPDataset(data.Dataset):
11 | def __init__(self, dataset_root, DATA_PKL, SIZE, data_type = 'train', mode = 'train'):
12 | self.dataset_root = dataset_root
13 | if os.path.exists(DATA_PKL):
14 | with open(DATA_PKL, 'rb') as f:
15 | info = pickle.load(f)
16 | if data_type == 'train':
17 | self.data = info['train']
18 | else:
19 | self.data = info['test']
20 |         assert (data_type == 'train' and mode == 'train') or (data_type == 'train' and mode == 'test') or (data_type == 'test' and mode == 'test'), 'mode setting error in dataset'
21 | self.mode = mode
22 | self.train_augmentation = transforms.Compose(
23 | [transforms.RandomAffine(degrees=90, shear=5.729578),
24 | transforms.RandomVerticalFlip(p=0.5),
25 | transforms.RandomHorizontalFlip(p=0.5),
26 | transforms.ToTensor(),
27 | transforms.ToPILImage(),
28 | transforms.Resize(SIZE)
29 | ])
30 | self.train_gt_augmentation = transforms.Compose(
31 | [transforms.RandomAffine(degrees=90, shear=5.729578),
32 | transforms.RandomVerticalFlip(p=0.5),
33 | transforms.RandomHorizontalFlip(p=0.5),
34 | transforms.ToTensor(),
35 | transforms.ToPILImage(),
36 | transforms.Resize(SIZE)
37 | ])
38 |
39 | self.test_augmentation = transforms.Compose(
40 | [transforms.ToTensor(),
41 | transforms.ToPILImage(),
42 | transforms.Resize(SIZE)
43 | ])
44 | self.test_gt_augmentation = transforms.Compose(
45 | [transforms.ToTensor(),
46 | transforms.ToPILImage(),
47 | transforms.Resize(SIZE)
48 | ])
49 |
50 | def __len__(self):
51 | return len(self.data)
52 |
53 | def __getitem__(self, idx):
54 | patient = self.data[idx]
55 | image = Image.open(os.path.join(self.dataset_root, patient['image']))
56 | mask = Image.open(os.path.join(self.dataset_root, patient['mask'])).convert('1')
57 | label = patient['label']
58 |
59 | if self.mode == 'train':
60 |             seed = np.random.randint(123456)  # reuse one seed so image and mask get identical random transforms
61 | random.seed(seed)
62 | image = self.train_augmentation(image)
63 | random.seed(seed)
64 | mask = self.train_gt_augmentation(mask)
65 | else:
66 | image = self.test_augmentation(image)
67 | mask = self.test_gt_augmentation(mask)
68 |
69 | image = np.array(image) / 255.
70 | image = image.transpose((2, 0, 1))
71 | image = image.astype(np.float32)
72 | mask = np.array(mask)
73 | mask = np.float32(mask > 0)
74 | name = patient['image'].split('.')[0].replace('/','_' )
75 |
76 | return image.copy(), mask.copy(), label, name
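77 |
78 | # Usage sketch (illustrative; paths are placeholders):
79 | # ds = CADCAPDataset('/path/to/WCE', 'data/WCE/WCE_Dataset_image_list.pkl', 256, data_type='train', mode='train')
80 | # loader = data.DataLoader(ds, batch_size=8, shuffle=True, num_workers=8, drop_last=True)
81 | # image, mask, label, name = ds[0]  # image: (3, 256, 256) float32 in [0, 1]; mask: (256, 256) float32 in {0, 1}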
--------------------------------------------------------------------------------
/net/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | from torch.nn.parallel.data_parallel import DataParallel
14 |
15 | __all__ = [
16 | 'CallbackContext',
17 | 'execute_replication_callbacks',
18 | 'DataParallelWithCallback',
19 | 'patch_replication_callback'
20 | ]
21 |
22 |
23 | class CallbackContext(object):
24 | pass
25 |
26 |
27 | def execute_replication_callbacks(modules):
28 | """
29 |     Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30 |
31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32 |
33 |     Note that, as all modules are isomorphic, we assign each sub-module a context
34 | (shared among multiple copies of this module on different devices).
35 | Through this context, different copies can share some information.
36 |
37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
38 | of any slave copies.
39 | """
40 | master_copy = modules[0]
41 | nr_modules = len(list(master_copy.modules()))
42 | ctxs = [CallbackContext() for _ in range(nr_modules)]
43 |
44 | for i, module in enumerate(modules):
45 | for j, m in enumerate(module.modules()):
46 | if hasattr(m, '__data_parallel_replicate__'):
47 | m.__data_parallel_replicate__(ctxs[j], i)
48 |
49 |
50 | class DataParallelWithCallback(DataParallel):
51 | """
52 | Data Parallel with a replication callback.
53 |
54 |     A replication callback `__data_parallel_replicate__` of each module will be invoked after it is created by the
55 | original `replicate` function.
56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57 |
58 | Examples:
59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61 | # sync_bn.__data_parallel_replicate__ will be invoked.
62 | """
63 |
64 | def replicate(self, module, device_ids):
65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66 | execute_replication_callbacks(modules)
67 | return modules
68 |
69 |
70 | def patch_replication_callback(data_parallel):
71 | """
72 | Monkey-patch an existing `DataParallel` object. Add the replication callback.
73 | Useful when you have customized `DataParallel` implementation.
74 |
75 | Examples:
76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
78 | > patch_replication_callback(sync_bn)
79 | # this is equivalent to
80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
82 | """
83 |
84 | assert isinstance(data_parallel, DataParallel)
85 |
86 | old_replicate = data_parallel.replicate
87 |
88 | @functools.wraps(old_replicate)
89 | def new_replicate(module, device_ids):
90 | modules = old_replicate(module, device_ids)
91 | execute_replication_callbacks(modules)
92 | return modules
93 |
94 | data_parallel.replicate = new_replicate
95 |
--------------------------------------------------------------------------------
/net/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import queue
12 | import collections
13 | import threading
14 |
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 |
17 |
18 | class FutureResult(object):
19 | """A thread-safe future implementation. Used only as one-to-one pipe."""
20 |
21 | def __init__(self):
22 | self._result = None
23 | self._lock = threading.Lock()
24 | self._cond = threading.Condition(self._lock)
25 |
26 | def put(self, result):
27 | with self._lock:
28 |             assert self._result is None, 'Previous result hasn\'t been fetched.'
29 | self._result = result
30 | self._cond.notify()
31 |
32 | def get(self):
33 | with self._lock:
34 | if self._result is None:
35 | self._cond.wait()
36 |
37 | res = self._result
38 | self._result = None
39 | return res
40 |
41 |
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 |
45 |
46 | class SlavePipe(_SlavePipeBase):
47 | """Pipe for master-slave communication."""
48 |
49 | def run_slave(self, msg):
50 | self.queue.put((self.identifier, msg))
51 | ret = self.result.get()
52 | self.queue.put(True)
53 | return ret
54 |
55 |
56 | class SyncMaster(object):
57 | """An abstract `SyncMaster` object.
58 |
59 |     - During replication, as data parallel triggers a callback on each module, every slave device should
60 |     call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
61 |     - During the forward pass, the master device invokes `run_master`; all messages from slave devices are
62 |     collected and passed to a registered callback.
63 |     - After receiving the messages, the master device determines the message to be passed
64 |     back to each slave device.
65 | """
66 |
67 | def __init__(self, master_callback):
68 | """
69 |
70 | Args:
71 | master_callback: a callback to be invoked after having collected messages from slave devices.
72 | """
73 | self._master_callback = master_callback
74 | self._queue = queue.Queue()
75 | self._registry = collections.OrderedDict()
76 | self._activated = False
77 |
78 | def __getstate__(self):
79 | return {'master_callback': self._master_callback}
80 |
81 | def __setstate__(self, state):
82 | self.__init__(state['master_callback'])
83 |
84 | def register_slave(self, identifier):
85 | """
86 |         Register a slave device.
87 |
88 | Args:
89 | identifier: an identifier, usually is the device id.
90 |
91 | Returns: a `SlavePipe` object which can be used to communicate with the master device.
92 |
93 | """
94 | if self._activated:
95 | assert self._queue.empty(), 'Queue is not clean before next initialization.'
96 | self._activated = False
97 | self._registry.clear()
98 | future = FutureResult()
99 | self._registry[identifier] = _MasterRegistry(future)
100 | return SlavePipe(identifier, self._queue, future)
101 |
102 | def run_master(self, master_msg):
103 | """
104 | Main entry for the master device in each forward pass.
105 |         The messages are first collected from each device (including the master), and then
106 |         a callback is invoked to compute the message to be sent back to each device
107 |         (including the master).
108 |
109 | Args:
110 |             master_msg: the message that the master wants to send to itself. This will be placed as the first
111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
112 |
113 | Returns: the message to be sent back to the master device.
114 |
115 | """
116 | self._activated = True
117 |
118 | intermediates = [(0, master_msg)]
119 | for i in range(self.nr_slaves):
120 | intermediates.append(self._queue.get())
121 |
122 | results = self._master_callback(intermediates)
123 |         assert results[0][0] == 0, 'The first result should belong to the master.'
124 |
125 | for i, res in results:
126 | if i == 0:
127 | continue
128 | self._registry[i].result.put(res)
129 |
130 | for i in range(self.nr_slaves):
131 | assert self._queue.get() is True
132 |
133 | return results[0][1]
134 |
135 | @property
136 | def nr_slaves(self):
137 | return len(self._registry)
138 |
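139 | # Protocol sketch (illustrative, not part of the original file):
140 | # master = SyncMaster(lambda inter: [(i, sum(m for _, m in inter)) for i, _ in inter])
141 | # pipe = master.register_slave(1)               # done on each replica during replication
142 | # # master thread:  total = master.run_master(msg0)
143 | # # slave thread:   total = pipe.run_slave(msg1)  # blocks until the master broadcasts results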
--------------------------------------------------------------------------------
/net/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import net.xception as xception
5 |
6 | from .ASPP import ASPP
7 | from .convs import SeparableConv2d
8 | from .modules import Lesion_Location_Mining
9 | from .modules import Category_guided_Feature_Generation
10 | from .modules import Global_Prototypes_Generator
11 |
12 | class DSI_Net(nn.Module):
13 | def __init__(self, config, K=100):
14 | super(DSI_Net, self).__init__()
15 | self.backbone = None
16 | self.backbone_layers = None
17 | self.dropout = nn.Dropout(0.5)
18 | self.upsample_sub_x2 = nn.UpsamplingBilinear2d(scale_factor=2)
19 | self.upsample_sub_x4 = nn.UpsamplingBilinear2d(scale_factor=4)
20 | self.shortcut_conv = nn.Sequential(nn.Conv2d(256, 48, 1, 1, padding=1//2, bias=True),
21 | nn.BatchNorm2d(48),
22 | nn.ReLU(inplace=True),
23 | )
24 | self.aspp = ASPP(dim_in=2048, dim_out=256, rate=16//16, bn_mom = 0.99)
25 | self.coarse_head = nn.Sequential(
26 | nn.Conv2d(256+48, 256, 3, 1, padding=1, bias=True),
27 | nn.BatchNorm2d(256),
28 | nn.ReLU(inplace=True),
29 | nn.Dropout(0.5),
30 | nn.Conv2d(256, 256, 3, 1, padding=1, bias=True),
31 | nn.BatchNorm2d(256),
32 | nn.ReLU(inplace=True),
33 | nn.Dropout(0.1),
34 | nn.Conv2d(256, 2, kernel_size=1, stride=1, padding=0, bias=True)
35 | )
36 |
37 | self.fine_head = nn.Sequential(
38 | nn.Conv2d(256+64+48, 256, 3, 1, padding=1, bias=True),
39 | nn.BatchNorm2d(256),
40 | nn.ReLU(inplace=True),
41 | nn.Dropout(0.5),
42 | nn.Conv2d(256, 256, 3, 1, padding=1, bias=True),
43 | nn.BatchNorm2d(256),
44 | nn.ReLU(inplace=True),
45 | nn.Dropout(0.1),
46 | nn.Conv2d(256, 2, kernel_size=1, stride=1, padding=0, bias=True)
47 | )
48 |
49 | self.cls_head = nn.Sequential(
50 | SeparableConv2d(1024, 1536, 3, dilation=2, stride=1, padding=2, bias=False),
51 | nn.BatchNorm2d(1536),
52 | nn.ReLU(inplace=True),
53 | SeparableConv2d(1536, 1536, 3, dilation=2, stride=1, padding=2, bias=False),
54 | nn.BatchNorm2d(1536),
55 | nn.ReLU(inplace=True),
56 | SeparableConv2d(1536, 2048, 3, dilation=2, stride=1, padding=2, bias=False),
57 | nn.BatchNorm2d(2048),
58 | nn.ReLU(inplace=True))
59 | self.LLM = Lesion_Location_Mining(config, 1024, K)
60 | self.GPG = Global_Prototypes_Generator(2048, config.INTERMIDEATE_NUM)
61 | self.CFG = Category_guided_Feature_Generation(256, config.INTERMIDEATE_NUM, config.EM_STEP)
62 | self.avgpool = nn.AdaptiveAvgPool2d(1)
63 | self.cls_predict = nn.Linear(2048, config.NUM_CLASSES_CLS, bias = False)
64 |
65 | for m in self.modules():
66 | if isinstance(m, nn.Conv2d):
67 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
68 | elif isinstance(m, nn.BatchNorm2d):
69 | nn.init.constant_(m.weight, 1)
70 | nn.init.constant_(m.bias, 0)
71 |
72 | self.backbone = xception.Xception(os = config.OS)
73 | self.backbone_layers = self.backbone.get_layers()
74 |
75 |
76 | def forward(self, x):
77 | x = self.backbone(x)
78 |
79 |         # shallow feature (stride-4 backbone hook for the decoder shortcut)
80 | layers = self.backbone.get_layers()
81 | feature_shallow = self.shortcut_conv(layers[0])
82 | feature_aspp = self.aspp(layers[-1])
83 |
84 |         # coarse segmentation
85 |         feature_coarse = self.dropout(feature_aspp)
86 |         feature_coarse = self.upsample_sub_x2(feature_coarse)
87 |         feature_coarse = torch.cat([feature_coarse, feature_shallow], 1)
88 |         seg_coarse = self.coarse_head(feature_coarse)
89 |
90 |         # classification branch (lesion location mining + classification head)
91 | cls_feats = layers[-2]
92 | b, c, h, w = cls_feats.size()
93 | mask_coarse = torch.softmax(seg_coarse, dim = 1)
94 | mask_coarse = F.interpolate(mask_coarse, size=(h, w), mode="bilinear", align_corners=False)
95 |
96 | cls_feats = self.LLM(cls_feats, mask_coarse)
97 | cls_feats = self.cls_head(cls_feats)
98 | cls_out = self.avgpool(cls_feats)
99 | cls_out = cls_out.view(b, -1)
100 | cls_out = self.cls_predict(cls_out)
101 |
102 |         # fine segmentation, guided by category context
103 |         global_prototypes = self.GPG(self.cls_predict.weight.detach(), cls_out.detach())
104 |         context = self.CFG(feature_aspp, mask_coarse, global_prototypes)
105 |         context = self.upsample_sub_x2(context)
106 |         feature_fine = self.dropout(feature_aspp)
107 |         feature_fine = self.upsample_sub_x2(feature_fine)
108 |         feature_fine = torch.cat([feature_fine, context, feature_shallow], 1)
109 |         seg_fine = self.fine_head(feature_fine)
110 |
111 |         # upsample both predictions to input resolution
112 | seg_coarse = self.upsample_sub_x4(seg_coarse)
113 | seg_fine = self.upsample_sub_x4(seg_fine)
114 |
115 | return seg_coarse, seg_fine, cls_out
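116 |
117 | # Shape sketch (illustrative, assuming 256 x 256 inputs and config.OS = 8):
118 | # net = DSI_Net(config, K=100).cuda()
119 | # seg_coarse, seg_fine, cls_out = net(torch.randn(2, 3, 256, 256).cuda())
120 | # seg_coarse, seg_fine: (2, 2, 256, 256) segmentation logits; cls_out: (2, 3) class logits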
--------------------------------------------------------------------------------
/net/xception.py:
--------------------------------------------------------------------------------
1 | """
2 | Ported to pytorch thanks to [tstandley](https://github.com/tstandley/Xception-PyTorch)
3 | @author: tstandley
4 | Adapted by cadene
5 | Creates an Xception Model as defined in:
6 | Francois Chollet
7 | Xception: Deep Learning with Depthwise Separable Convolutions
8 | https://arxiv.org/pdf/1610.02357.pdf
9 | These weights were ported from the Keras implementation and achieve the following performance on the validation set:
10 | Loss:0.9173 Prec@1:78.892 Prec@5:94.292
11 | REMEMBER to set your image size to 3x299x299 for both test and validation
12 | normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
13 | std=[0.5, 0.5, 0.5])
14 | The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
15 | """
16 | import math
17 | import torch
18 | import torch.nn as nn
19 |
20 | bn_mom = 0.0003
21 | __all__ = ['xception']
22 |
23 |
24 | class SeparableConv2d(nn.Module):
25 | def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False,activate_first=True,inplace=True):
26 | super(SeparableConv2d,self).__init__()
27 | self.relu0 = nn.ReLU(inplace=inplace)
28 | self.depthwise = nn.Conv2d(in_channels,in_channels,kernel_size,stride,padding,dilation,groups=in_channels,bias=bias)
29 | self.bn1 = nn.BatchNorm2d(in_channels)
30 | self.relu1 = nn.ReLU(inplace=True)
31 | self.pointwise = nn.Conv2d(in_channels,out_channels,1,1,0,1,1,bias=bias)
32 | self.bn2 = nn.BatchNorm2d(out_channels)
33 | self.relu2 = nn.ReLU(inplace=True)
34 | self.activate_first = activate_first
35 | def forward(self,x):
36 | if self.activate_first:
37 | x = self.relu0(x)
38 | x = self.depthwise(x)
39 | x = self.bn1(x)
40 | if not self.activate_first:
41 | x = self.relu1(x)
42 | x = self.pointwise(x)
43 | x = self.bn2(x)
44 | if not self.activate_first:
45 | x = self.relu2(x)
46 | return x
47 |
48 |
49 | class Block(nn.Module):
50 | def __init__(self,in_filters,out_filters,strides=1,atrous=None,grow_first=True,activate_first=True,inplace=True):
51 | super(Block, self).__init__()
52 |         if atrous is None:
53 | atrous = [1]*3
54 | elif isinstance(atrous, int):
55 | atrous_list = [atrous]*3
56 | atrous = atrous_list
57 | idx = 0
58 | self.head_relu = True
59 | if out_filters != in_filters or strides!=1:
60 | self.skip = nn.Conv2d(in_filters,out_filters,1,stride=strides, bias=False)
61 | self.skipbn = nn.BatchNorm2d(out_filters)
62 | self.head_relu = False
63 | else:
64 | self.skip=None
65 |
66 | self.hook_layer = None
67 | if grow_first:
68 | filters = out_filters
69 | else:
70 | filters = in_filters
71 | self.sepconv1 = SeparableConv2d(in_filters,filters,3,stride=1,padding=1*atrous[0],dilation=atrous[0],bias=False,activate_first=activate_first,inplace=self.head_relu)
72 | self.sepconv2 = SeparableConv2d(filters,out_filters,3,stride=1,padding=1*atrous[1],dilation=atrous[1],bias=False,activate_first=activate_first)
73 | self.sepconv3 = SeparableConv2d(out_filters,out_filters,3,stride=strides,padding=1*atrous[2],dilation=atrous[2],bias=False,activate_first=activate_first,inplace=inplace)
74 |
75 | def forward(self,inp):
76 |
77 | if self.skip is not None:
78 | skip = self.skip(inp)
79 | skip = self.skipbn(skip)
80 | else:
81 | skip = inp
82 |
83 | x = self.sepconv1(inp)
84 | x = self.sepconv2(x)
85 | self.hook_layer = x
86 | x = self.sepconv3(x)
87 |
88 |         x += skip
89 | return x
90 |
91 |
92 | class Xception(nn.Module):
93 | """
94 | Xception optimized for the ImageNet dataset, as specified in
95 | https://arxiv.org/pdf/1610.02357.pdf
96 | """
97 | def __init__(self, os):
98 | """ Constructor
99 | Args:
100 |             os: output stride of the backbone (8 or 16)
101 | """
102 | super(Xception, self).__init__()
103 |
104 | stride_list = None
105 | if os == 8:
106 | stride_list = [2,1,1]
107 | elif os == 16:
108 | stride_list = [2,2,1]
109 | else:
110 | raise ValueError('xception.py: output stride=%d is not supported.'%os)
111 | self.conv1 = nn.Conv2d(3, 32, 3, 2, 1, bias=False)
112 | self.bn1 = nn.BatchNorm2d(32)
113 | self.relu = nn.ReLU(inplace=True)
114 |
115 | self.conv2 = nn.Conv2d(32,64,3,1,1,bias=False)
116 | self.bn2 = nn.BatchNorm2d(64)
117 | #do relu here
118 |
119 | self.block1=Block(64,128,2)
120 | self.block2=Block(128,256,stride_list[0],inplace=False)
121 | self.block3=Block(256,728,stride_list[1])
122 |
123 | rate = 16//os
124 | self.block4=Block(728,728,1,atrous=rate)
125 | self.block5=Block(728,728,1,atrous=rate)
126 | self.block6=Block(728,728,1,atrous=rate)
127 | self.block7=Block(728,728,1,atrous=rate)
128 |
129 | self.block8=Block(728,728,1,atrous=rate)
130 | self.block9=Block(728,728,1,atrous=rate)
131 | self.block10=Block(728,728,1,atrous=rate)
132 | self.block11=Block(728,728,1,atrous=rate)
133 |
134 | self.block12=Block(728,728,1,atrous=rate)
135 | self.block13=Block(728,728,1,atrous=rate)
136 | self.block14=Block(728,728,1,atrous=rate)
137 | self.block15=Block(728,728,1,atrous=rate)
138 |
139 | self.block16=Block(728,728,1,atrous=[1*rate,1*rate,1*rate])
140 | self.block17=Block(728,728,1,atrous=[1*rate,1*rate,1*rate])
141 | self.block18=Block(728,728,1,atrous=[1*rate,1*rate,1*rate])
142 | self.block19=Block(728,728,1,atrous=[1*rate,1*rate,1*rate])
143 |
144 | self.block20=Block(728,1024,stride_list[2],atrous=rate,grow_first=False)
145 | #self.block12=Block(728,1024,2,2,start_with_relu=True,grow_first=False)
146 |
147 | self.conv3 = SeparableConv2d(1024,1536,3,1,1*rate,dilation=rate,activate_first=False)
148 | # self.bn3 = nn.BatchNorm2d(1536)
149 |
150 | self.conv4 = SeparableConv2d(1536,1536,3,1,1*rate,dilation=rate,activate_first=False)
151 | # self.bn4 = nn.BatchNorm2d(1536)
152 |
153 | #do relu here
154 | self.conv5 = SeparableConv2d(1536,2048,3,1,1*rate,dilation=rate,activate_first=False)
155 | # self.bn5 = nn.BatchNorm2d(2048)
156 | self.layers = []
157 |
158 | #------- init weights --------
159 | for m in self.modules():
160 | if isinstance(m, nn.Conv2d):
161 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
162 | m.weight.data.normal_(0, math.sqrt(2. / n))
163 | elif isinstance(m, nn.BatchNorm2d):
164 | m.weight.data.fill_(1)
165 | m.bias.data.zero_()
166 | #-----------------------------
167 |
168 | def forward(self, input):
169 | self.layers = []
170 | x = self.conv1(input)
171 | x = self.bn1(x)
172 | x = self.relu(x)
173 | #self.layers.append(x)
174 | x = self.conv2(x)
175 | x = self.bn2(x)
176 | x = self.relu(x)
177 |
178 | x = self.block1(x)
179 | x = self.block2(x)
180 | self.layers.append(self.block2.hook_layer)
181 | x = self.block3(x)
182 | # self.layers.append(self.block3.hook_layer)
183 | x = self.block4(x)
184 | x = self.block5(x)
185 | x = self.block6(x)
186 | x = self.block7(x)
187 | x = self.block8(x)
188 | x = self.block9(x)
189 | x = self.block10(x)
190 | x = self.block11(x)
191 | x = self.block12(x)
192 | x = self.block13(x)
193 | x = self.block14(x)
194 | x = self.block15(x)
195 | x = self.block16(x)
196 | x = self.block17(x)
197 | x = self.block18(x)
198 | x = self.block19(x)
199 | x = self.block20(x)
200 | # self.layers.append(self.block20.hook_layer)
201 | self.layers.append(x)
202 | x = self.conv3(x)
203 | # x = self.bn3(x)
204 | # x = self.relu(x)
205 |
206 | x = self.conv4(x)
207 | # x = self.bn4(x)
208 | # x = self.relu(x)
209 |
210 | x = self.conv5(x)
211 | # x = self.bn5(x)
212 | # x = self.relu(x)
213 | self.layers.append(x)
214 |
215 | return x
216 |
217 | def get_layers(self):
218 | return self.layers
219 |
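220 | # get_layers() exposes the intermediates consumed by DSI-Net (net/models.py):
221 | # layers[0]:  block2 hook feature (256 ch, stride 4) for the decoder shortcut,
222 | # layers[-2]: block20 output (1024 ch) for the classification branch,
223 | # layers[-1]: conv5 output (2048 ch) fed to ASPP.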
--------------------------------------------------------------------------------
/dataset/preprocess.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Thu Oct 15 16:36:36 2020
4 |
5 | @author: meiluzhu
6 | """
7 |
8 | import os
9 | import numpy as np
10 | import pickle
11 | import cv2
12 |
13 | m_type = ['vascularlesions','inflammatory','normal']
14 | patients = []
15 | images = []
16 | labels = []
17 | res = 288
18 |
19 | #### CAD-CAP
20 | base = 'C:\\ZML\\Dataset\\WCE\\CAD-CAP'
21 | save_base = 'C:\\ZML\\Dataset\\WCE\\temp\\CAD-CAP'
22 |
23 | file = os.listdir(os.path.join(base, m_type[1]))
24 | for f in file:
25 | filename = f.split('.')
26 | if os.path.exists(os.path.join(base,m_type[1],filename[0])+'_a'+'.jpg'):
27 | img_dir = m_type[1]+'\\'+ f
28 | mask_dir = m_type[1]+'\\'+ filename[0]+'_a.jpg'
29 | print(img_dir)
30 | image = cv2.imread(os.path.join(base, img_dir))
31 | mask = cv2.imread(os.path.join(base, mask_dir),cv2.IMREAD_GRAYSCALE)
32 | h,w,c = image.shape
33 | h_top = 0
34 | h_bot = h
35 | w_lef = 0
36 | w_rig = w
37 | if h == 576 or h == 704:
38 | h_top = 32
39 | h_bot = h-32
40 | w_lef = 32
41 | w_rig = w-32
42 | post_img = image[h_top:h_bot,w_lef:w_rig,:]
43 | post_img[0:45,0:15,:] = 0
44 | post_mask = mask[h_top:h_bot,w_lef:w_rig]
45 | post_img = cv2.resize(post_img, (res,res), interpolation=cv2.INTER_NEAREST)
46 | post_mask = cv2.resize(post_mask, (res,res), interpolation=cv2.INTER_NEAREST)
47 | new_dir = os.path.join(save_base,img_dir)
48 | cv2.imwrite(new_dir, post_img)
49 | new_dir = os.path.join(save_base,mask_dir)
50 | cv2.imwrite(new_dir, post_mask)
51 | post_mask[post_mask>0] = 1
52 | patient = {'image':'CAD-CAP/'+img_dir.replace('\\','/' ), 'mask': 'CAD-CAP/'+mask_dir.replace('\\','/' ),'label': 2}
53 | patients.append(patient)
54 |
55 | file = os.listdir(os.path.join(base, m_type[0]))
56 | for f in file:
57 | filename = f.split('.')
58 | if os.path.exists(os.path.join(base,m_type[0],filename[0])+'_a'+'.jpg'):
59 | img_dir = m_type[0]+'/'+ f
60 | mask_dir = m_type[0]+'/'+ filename[0]+'_a.jpg'
61 | print(img_dir)
62 | image = cv2.imread(os.path.join(base, img_dir))
63 | mask = cv2.imread(os.path.join(base, mask_dir),cv2.IMREAD_GRAYSCALE)
64 | ######
65 | h,w,c = image.shape
66 | h_top = 0
67 | h_bot = h
68 | w_lef = 0
69 | w_rig = w
70 | if h == 576 or h == 704:
71 | h_top = 32
72 | h_bot = h-32
73 | w_lef = 32
74 | w_rig = w-32
75 | post_img = image[h_top:h_bot,w_lef:w_rig,:]
76 | h,w,c = post_img.shape
77 | if h>600:
78 | post_img[0:10,0:139,:] = 0
79 | post_img[h-2:h,0:191,:] = 0
80 | post_img[h-5:h,0:150,:] = 0
81 | else:
82 | post_img[0:10,0:115,:] = 0
83 | post_img[h-2:h,0:133,:] = 0
84 | post_img[h-5:h,0:110,:] = 0
85 | post_img[0:60,0:21,:] = 0
86 | post_mask = mask[h_top:h_bot,w_lef:w_rig]
87 | post_img = cv2.resize(post_img, (res,res), interpolation=cv2.INTER_NEAREST)
88 | post_mask = cv2.resize(post_mask, (res,res), interpolation=cv2.INTER_NEAREST)
89 | new_dir = os.path.join(save_base,img_dir)
90 | cv2.imwrite(new_dir, post_img)
91 | new_dir = os.path.join(save_base,mask_dir)
92 | cv2.imwrite(new_dir, post_mask)
93 | post_mask[post_mask>0] = 1
94 | patient = {'image':'CAD-CAP/'+img_dir.replace('\\','/' ), 'mask': 'CAD-CAP/'+mask_dir.replace('\\','/' ),'label': 1}
95 | patients.append(patient)
96 |
97 | file = os.listdir(os.path.join(base, m_type[2]))
98 | for f in file:
99 | img_dir = m_type[2]+'/'+ f
100 | print(img_dir)
101 | image = cv2.imread(os.path.join(base, img_dir))
102 | h,w,c = image.shape
103 | h_top = 0
104 | h_bot = h
105 | w_lef = 0
106 | w_rig = w
107 | if h == 576 or h == 704:
108 | h_top = 32
109 | h_bot = h-32
110 | w_lef = 32
111 | w_rig = w-32
112 | post_img = image[h_top:h_bot,w_lef:w_rig,:]
113 | h,w,c = post_img.shape
114 | post_img[0:45,0:15,:] = 0
115 | post_img = cv2.resize(post_img, (res,res), interpolation=cv2.INTER_NEAREST)
116 | img_dir = img_dir.replace(' ', '_')
117 | new_dir = os.path.join(save_base,img_dir)
118 | cv2.imwrite(new_dir, post_img)
119 | post_mask = np.zeros((res,res), dtype = np.uint8)
120 | filename = img_dir.split('.')
121 | mask_dir = filename[0]+'_a.jpg'
122 | new_dir = os.path.join(save_base,mask_dir)
123 | cv2.imwrite(new_dir, post_mask)
124 | post_mask[post_mask>0] = 1
125 | patient = {'image':'CAD-CAP/'+img_dir.replace('\\','/' ), 'mask': 'CAD-CAP/'+mask_dir.replace('\\','/' ),'label': 0}
126 | patients.append(patient)
127 |
128 | ####KID
129 | base = 'C:\\ZML\\Dataset\\WCE\\KID'
130 | save_base = 'C:\\ZML\\Dataset\\WCE\\temp\\KID'
131 | era = 4
132 | file = os.listdir(os.path.join(base, m_type[1]))
133 | for f in file:
134 | filename = f.split('.')
135 | if os.path.exists(os.path.join(base,m_type[1],filename[0])+'m'+'.png'):
136 | img_dir = m_type[1]+'\\'+ f
137 | mask_dir = m_type[1]+'\\'+ filename[0]+'m.png'
138 | print(img_dir)
139 | image = cv2.imread(os.path.join(base, img_dir))
140 | mask = cv2.imread(os.path.join(base, mask_dir),cv2.IMREAD_GRAYSCALE)
141 | h,w,c = image.shape
142 | h_top = 20 + era
143 | h_bot = h - 20 - era
144 | w_lef = 20 + era
145 | w_rig = w - 20 - era
146 | post_img = image[h_top:h_bot,w_lef:w_rig,:]
147 | post_mask = mask[h_top:h_bot,w_lef:w_rig]
148 | post_img = cv2.resize(post_img, (res,res), interpolation=cv2.INTER_NEAREST)
149 | post_mask = cv2.resize(post_mask, (res,res), interpolation=cv2.INTER_NEAREST)
150 | new_dir = os.path.join(save_base,img_dir)
151 | cv2.imwrite(new_dir, post_img)
152 | new_dir = os.path.join(save_base,mask_dir)
153 | cv2.imwrite(new_dir, post_mask)
154 | post_mask[post_mask>0] = 1
155 | patient = {'image':'KID/'+img_dir.replace('\\','/' ), 'mask': 'KID/'+mask_dir.replace('\\','/' ),'label': 2}
156 | patients.append(patient)
157 |
158 | file = os.listdir(os.path.join(base, m_type[0]))
159 | for f in file:
160 | filename = f.split('.')
161 | if os.path.exists(os.path.join(base,m_type[0],filename[0])+'m'+'.png'):
162 | img_dir = m_type[0]+'/'+ f
163 | mask_dir = m_type[0]+'/'+ filename[0]+'m.png'
164 | print(img_dir)
165 | image = cv2.imread(os.path.join(base, img_dir))
166 | mask = cv2.imread(os.path.join(base, mask_dir),cv2.IMREAD_GRAYSCALE)
167 | h,w,c = image.shape
168 | h_top = 20 + era
169 | h_bot = h - 20 - era
170 | w_lef = 20 + era
171 | w_rig = w - 20 - era
172 | post_img = image[h_top:h_bot,w_lef:w_rig,:]
173 | post_mask = mask[h_top:h_bot,w_lef:w_rig]
174 | post_img = cv2.resize(post_img, (res,res), interpolation=cv2.INTER_NEAREST)
175 | post_mask = cv2.resize(post_mask, (res,res), interpolation=cv2.INTER_NEAREST)
176 | new_dir = os.path.join(save_base,img_dir)
177 | cv2.imwrite(new_dir, post_img)
178 | new_dir = os.path.join(save_base,mask_dir)
179 | cv2.imwrite(new_dir, post_mask)
180 | post_mask[post_mask>0] = 1
181 | patient = {'image':'KID/'+img_dir.replace('\\','/' ), 'mask': 'KID/'+mask_dir.replace('\\','/' ),'label': 1}
182 | patients.append(patient)
183 |
184 | m_type[2] = 'normal-small-bowel'
185 | file = os.listdir(os.path.join(base, m_type[2]))
186 | for f in file:
187 | img_dir = m_type[2]+'/'+ f
188 | print(img_dir)
189 | image = cv2.imread(os.path.join(base, img_dir))
190 | h,w,c = image.shape
191 | h_top = 20 + era
192 | h_bot = h - 20 - era
193 | w_lef = 20 + era
194 | w_rig = w - 20 - era
195 | post_img = image[h_top:h_bot,w_lef:w_rig,:]
196 | post_img = cv2.resize(post_img, (res,res), interpolation=cv2.INTER_NEAREST)
197 | img_dir = img_dir.replace(' ', '_')
198 | img_dir = img_dir.replace('normal-small-bowel', 'normal')
199 | new_dir = os.path.join(save_base,img_dir)
200 | cv2.imwrite(new_dir, post_img)
201 | post_mask = np.zeros((res,res), dtype = np.uint8)
202 | filename = img_dir.split('.')
203 | mask_dir = filename[0]+'m.png'
204 | new_dir = os.path.join(save_base,mask_dir)
205 | cv2.imwrite(new_dir, post_mask)
206 | post_mask[post_mask>0] = 1
207 | patient = {'image':'KID/'+img_dir.replace('\\','/' ), 'mask': 'KID/'+mask_dir.replace('\\','/' ),'label': 0}
208 | patients.append(patient)
209 |
210 | np.random.shuffle(patients)
211 | trainset = patients[0:2470]
212 | testset = patients[2470:]
213 | dataset = {'train': trainset, 'test': testset}
214 | path = os.path.join('C:\\ZML\\Dataset\\WCE\\temp', 'WCE_Dataset_larger_Fold1.pkl')
215 | if os.path.exists(path):
216 | os.remove(path)
217 | with open(path,'wb') as f:
218 | pickle.dump(dataset, f)
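219 |
220 | # The resulting pickle maps 'train'/'test' to lists of dicts with keys
221 | # 'image' and 'mask' (paths relative to the dataset root) and 'label'
222 | # (0 = normal, 1 = vascular lesions, 2 = inflammatory).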
--------------------------------------------------------------------------------
/net/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .gumbel import Gumbel
5 |
6 | class Category_guided_Feature_Generation(nn.Module):
7 |
8 | def __init__(self,
9 | in_channels = 256,
10 | out_channels = 64, EM_STEP = 3):
11 | super(Category_guided_Feature_Generation, self).__init__()
12 | self.out_channels = out_channels
13 | self.EM_STEP = EM_STEP
14 | self.conv0 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1),
15 | nn.BatchNorm2d(out_channels),
16 | nn.ReLU(True),
17 | nn.Dropout2d(0.2, False),
18 | nn.Conv2d(out_channels, out_channels, 1),
19 | nn.BatchNorm2d(out_channels),
20 | nn.ReLU(True),
21 | nn.Dropout2d(0.1, False),
22 | )
23 |
24 | self.conv1 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1),
25 | nn.BatchNorm2d(out_channels),
26 | nn.ReLU(True),
27 | nn.Dropout2d(0.2, False),
28 | nn.Conv2d(out_channels, out_channels, 1),
29 | nn.BatchNorm2d(out_channels),
30 | nn.ReLU(True),
31 | nn.Dropout2d(0.1, False),
32 | )
33 |
34 | self.conv2 = nn.Sequential(nn.Conv2d(out_channels*2, out_channels, 1),
35 | nn.BatchNorm2d(out_channels),
36 | nn.ReLU(True),
37 | nn.Dropout2d(0.2, False)
38 | )
39 |
40 | def forward(self, x, coarse_mask, global_prototypes, regular = 0.5):
41 |
42 | b, h, w = x.size(0), x.size(2), x.size(3)
43 | classes_num = coarse_mask.size(1)
44 | feats = self.conv0(x)
45 | pseudo_mask = coarse_mask.view(b, classes_num, -1)
46 | feats = feats.view(b, self.out_channels, -1).permute(0, 2, 1)
47 | # EM refinement: alternately update local prototypes (M-step) and soft assignments (E-step)
48 | T = self.EM_STEP
49 | for t in range(T):
50 | prototypes = torch.bmm(pseudo_mask, feats)
51 | prototypes = prototypes / (1e-8 + prototypes.norm(dim=1, keepdim=True))
52 | attention = torch.bmm(prototypes, feats.permute(0, 2, 1))
53 | attention = (self.out_channels**-regular) * attention
54 | pseudo_mask = torch.softmax(attention, dim=1)
55 | pseudo_mask = pseudo_mask / (1e-8 + pseudo_mask.sum(dim=1, keepdim=True))
56 | context_l = torch.bmm(prototypes.permute(0, 2, 1), pseudo_mask).view(b, self.out_channels, h, w)
57 |
58 | feats = self.conv1(x)
59 | feats = feats.view(b, self.out_channels, -1).permute(0, 2, 1)
60 | global_prototypes = global_prototypes / (1e-8 + global_prototypes.norm(dim=1, keepdim=True))
61 |
62 | # EM refinement against the image-level global prototypes
63 | T = self.EM_STEP
64 | for t in range(T):
65 | attention = torch.bmm(global_prototypes, feats.permute(0, 2, 1))
66 | attention = (self.out_channels**-regular) * attention
67 | pseudo_mask = torch.softmax(attention, dim=1)
68 | pseudo_mask = pseudo_mask / (1e-8 + pseudo_mask.sum(dim=1, keepdim=True))
69 | global_prototypes = torch.bmm(pseudo_mask, feats)
70 | global_prototypes = global_prototypes / (1e-8 + global_prototypes.norm(dim=1, keepdim=True))
71 | context_g = torch.bmm(global_prototypes.permute(0, 2, 1), pseudo_mask).view(b, self.out_channels, h, w) # b, out_channels, h, w
72 |
73 | context = torch.cat((context_l, context_g), dim = 1)
74 | context = self.conv2(context)
75 |
76 | return context
77 |
78 |
79 | class Global_Prototypes_Generator(nn.Module):
80 |
81 | def __init__(self,
82 | in_channels = 2048,
83 | out_channels = 64):
84 | super(Global_Prototypes_Generator, self).__init__()
85 | self.out_channels = out_channels
86 | self.conv1 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1),
87 | nn.BatchNorm2d(out_channels),
88 | nn.ReLU(True),
89 | nn.Conv2d(out_channels, out_channels, 1)
90 | )
91 |
92 | def forward(self, prototypes, category):
93 | classes_num, c = prototypes.size(0), prototypes.size(1)
94 | prototypes = prototypes.view(classes_num, c, 1, 1)
95 | prototypes = self.conv1(prototypes).view(classes_num, self.out_channels)
96 | category = torch.softmax(category, dim = 1)
97 | b = category.size(0)
98 | bg_prototypes = prototypes[0]
99 | bg_prototypes = bg_prototypes.repeat(b, 1, 1)
100 | fg_prototypes = category[:,1:].view(b, classes_num-1, 1) * prototypes[1:]
101 | prototypes = torch.cat((bg_prototypes, fg_prototypes), dim = 1)
102 |
103 | return prototypes
104 |
105 |
106 |
107 |
108 | class Binary_Gate_Unit(nn.Module):
109 |
110 | def __init__(self, config, in_channels = 1024, k = 100):
111 | super(Binary_Gate_Unit, self).__init__()
112 | self.in_channels = in_channels
113 | self.k = k
114 | self.conv = nn.Sequential(
115 | nn.Conv2d(in_channels=self.in_channels, out_channels=1, kernel_size=1, stride=1, padding=0, bias = False),
116 | nn.ReLU(inplace=True)
117 | )
118 | self.fc1 = nn.Linear(k, (k + 1) // 2)
119 | self.fc2 = nn.Linear((k + 1) // 2, k)
120 |
121 | self.gumbel = Gumbel(config)
122 |
123 | def forward(self,topk_prototypes):
124 |
125 | b = topk_prototypes.size(0)
126 | proto_weights = self.conv(topk_prototypes) # meta learner
127 | proto_weights = proto_weights.view(b, -1)
128 | proto_weights = self.fc1(torch.relu(proto_weights))
129 | proto_weights = self.fc2(proto_weights) # b, k
130 | proto_weights = self.gumbel(proto_weights)
131 | proto_weights = proto_weights.view(b, 1, self.k, 1)
132 |
133 | return proto_weights
134 |
135 |
136 | class Lesion_Location_Mining(nn.Module):
137 |
138 | def __init__(self, config, in_channels = 1024, k = 100):
139 | super(Lesion_Location_Mining, self).__init__()
140 | self.k = k
141 | self.BGU_fore = Binary_Gate_Unit(config, in_channels = in_channels, k = k)
142 | self.BGU_back = Binary_Gate_Unit(config, in_channels = in_channels, k = k)
143 |
144 | def forward(self, feats, soft_mask):
145 |
146 | b,c,h,w = feats.size()
147 | hard_mask = torch.max(soft_mask, dim = 1, keepdim = True)[1] # b, 1, h, w
148 | background_hard_mask = (hard_mask == 0).float()
149 | foreground_hard_mask = (hard_mask == 1).float()
150 | assert torch.sum(hard_mask == 2) == 0, 'Lesion_Location_Mining expects a binary (background/foreground) soft_mask'
151 | background_soft_mask, foreground_soft_mask = soft_mask.split(1, dim = 1) #b, 1, h, w
152 | foreground_feats = feats * foreground_hard_mask # b, c, h, w
153 | background_feats = feats * background_hard_mask # b, c, h, w
154 | feats = feats.view(b, c, -1) # b, c, hw
155 |
156 | #****** foreground-->background **********#
157 | #key generator
158 | foreground_soft_mask = foreground_soft_mask.view(b, 1, -1)
159 | topk_idx = torch.topk(foreground_soft_mask, self.k, dim = -1, largest=True)[1]
160 | topk_prototypes = []
161 | for i in range(b):
162 | feats_temp = feats[i,:,topk_idx[i]] # c, 1, k (advanced indexing keeps the singleton dim of topk_idx[i])
163 | topk_prototypes.append(feats_temp)
164 | topk_prototypes = torch.stack(topk_prototypes) # b, c, k
165 | topk_prototypes = topk_prototypes.view(b, c, self.k, 1)
166 | proto_weights = self.BGU_fore(topk_prototypes)
167 | topk_prototypes = topk_prototypes * proto_weights # b, c, k, 1
168 |
169 | # compare the gated foreground seeds (b, c, k) against the background features (b, c, h, w)
170 | background_feats = background_feats.view(b, c, -1) # b, c, hw
171 | topk_prototypes = topk_prototypes.view(b, c, -1).permute(0, 2, 1) # b, k, c
172 | fore_attention_map = torch.matmul(topk_prototypes, background_feats) # b, k ,hw
173 |
174 | #norm + relu
175 | norm_prototypes = torch.norm(topk_prototypes, dim = -1, keepdim=True) # b, k, 1
176 | norm_background_feats = torch.norm(background_feats, dim = 1, keepdim=True) #b, 1, hw
177 | norm = torch.bmm(norm_prototypes, norm_background_feats) # b, k, hw
178 | fore_attention_map = fore_attention_map /(norm + 1e-8)
179 | fore_attention_map = torch.relu(fore_attention_map)
180 | fore_attention_map = fore_attention_map.view(b, self.k, h, w)
181 | fore_attention_map = torch.max(fore_attention_map, dim = 1, keepdim = True)[0]
182 |
183 | #****** background-->foreground**********#
184 | #key generator
185 | background_soft_mask = background_soft_mask.view(b, 1, -1)
186 | topk_idx = torch.topk(background_soft_mask, self.k, dim = -1, largest=True)[1]
187 | topk_prototypes = []
188 | for i in range(b):
189 | feats_temp = feats[i,:,topk_idx[i]] # c, 1, k (advanced indexing keeps the singleton dim of topk_idx[i])
190 | topk_prototypes.append(feats_temp)
191 | topk_prototypes = torch.stack(topk_prototypes) # b, c, k
192 | topk_prototypes = topk_prototypes.view(b, c, self.k, 1)
193 | proto_weights = self.BGU_back(topk_prototypes)
194 | topk_prototypes = topk_prototypes * proto_weights # b, c, k, 1
195 |
196 | # compare the gated background seeds (b, c, k) against the foreground features (b, c, h, w)
197 | foreground_feats = foreground_feats.view(b, c, -1) # b, c, hw
198 | topk_prototypes = topk_prototypes.view(b, c, -1).permute(0, 2, 1) # b, k, c
199 | back_attention_map = torch.matmul(topk_prototypes, foreground_feats) # b, k ,hw
200 | #norm + relu
201 | norm_prototypes = torch.norm(topk_prototypes, dim = -1, keepdim=True) # b, k, 1
202 | norm_foreground_feats = torch.norm(foreground_feats, dim = 1, keepdim=True) #b, 1, hw
203 | norm = torch.bmm(norm_prototypes, norm_foreground_feats) # b, k, hw
204 | back_attention_map = back_attention_map /(norm + 1e-8)
205 | back_attention_map = torch.relu(back_attention_map)
206 | back_attention_map = back_attention_map.view(b, self.k, h, w)
207 | back_attention_map = torch.max(back_attention_map, dim = 1, keepdim = True)[0]
208 |
209 | #merging
210 | feats = feats.view(b, c, h, w)
211 | foreground_soft_mask = foreground_soft_mask.view(b, 1, h, w)
212 | feats = feats + feats * (foreground_soft_mask - back_attention_map + fore_attention_map)
213 |
214 | return feats # b, c, h, w
215 |
216 |
--------------------------------------------------------------------------------
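A minimal shape-check sketch for the two prototype modules above, using dummy inputs: proto_bank and category_logits are hypothetical stand-ins for the learned class prototypes and classification logits that the full model supplies; Binary_Gate_Unit and Lesion_Location_Mining are left out because their Gumbel gate needs the project config.

import torch

from net.modules import Category_guided_Feature_Generation, Global_Prototypes_Generator

b, n_cls, h, w = 2, 2, 28, 28

cgfg = Category_guided_Feature_Generation(in_channels=256, out_channels=64).eval()
gpg = Global_Prototypes_Generator(in_channels=2048, out_channels=64).eval()

x = torch.randn(b, 256, h, w)                                    # decoder features
coarse_mask = torch.softmax(torch.randn(b, n_cls, h, w), dim=1)  # coarse seg probabilities
proto_bank = torch.randn(n_cls, 2048)                            # hypothetical learned prototypes
category_logits = torch.randn(b, n_cls)                          # hypothetical cls-head output

with torch.no_grad():
    global_protos = gpg(proto_bank, category_logits)  # (b, n_cls, 64)
    context = cgfg(x, coarse_mask, global_protos)     # (b, 64, h, w)
print(global_protos.shape, context.shape)
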
/train_DSI_Net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import torch.backends.cudnn as cudnn
5 | import os
6 | from sklearn import metrics
7 | from sklearn.metrics import accuracy_score
8 | from net.models import DSI_Net
9 | from net.loss import Task_Interaction_Loss, Dice_Loss
10 | from dataset.my_datasets import CADCAPDataset
11 | from torch.utils import data
12 | from apex import amp
13 | from utils.logger import print_f
14 | import time
15 | import config
16 | import argparse
17 | from visualization.utils import show_seg_results, draw_curves
18 |
19 | #https://drive.google.com/file/d/12RjjEKM4nXtskHSJkWMdJ7S5PeaVFow3/view?usp=sharing
20 | model_urls = {'deeplabv3plus_xception': 'data/pre_model/deeplabv3plus_xception_VOC2012_epoch46_all.pth'}
21 |
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument('--image_list', default='/home/meiluzhu2/data/WCE/WCE_Dataset_larger_Fold1.pkl', type=str, help='image list pkl')
24 | parser.add_argument('--gpus', default='7', type=str, help='gpus')
25 | parser.add_argument('--K', default=100, type=int, help='number of top-k seed prototypes')
26 | parser.add_argument('--alpha', default=0.05, type=float, help='the weight of interaction loss')
27 | args = parser.parse_args()
28 |
29 | def lr_poly(base_lr, cur_iter, max_iter, power):
30 | return base_lr * ((1 - float(cur_iter) / max_iter) ** power)
31 |
32 | def adjust_learning_rate(optimizer, i_iter):
33 | lr = lr_poly(config.LEARNING_RATE, i_iter, config.STEPS, config.POWER)
34 | optimizer.param_groups[0]['lr'] = lr
35 | return lr
36 |
37 | def test(valloader, model, epoch, path = None, verbose = False):
38 | # validation
39 | #cls
40 | pro_score_crop = []
41 | label_val_crop = []
42 |
43 | #refine seg
44 | seg_dice = []
45 | seg_sen = []
46 | seg_spe = []
47 | seg_acc = []
48 | seg_jac_score = []
49 |
50 | for index, batch in enumerate(valloader):
51 | data, masks, label, name = batch
52 | data = data.cuda()
53 | label = label.cuda()
54 | mask = masks[0].data.numpy()
55 | val_mask = np.int64(mask > 0)
56 |
57 | model.eval()
58 | with torch.no_grad():
59 | pred_seg_coarse, pred_seg_fine, pred_cls = model(data)
60 |
61 | #cls
62 | pro_score_crop.append(torch.softmax(pred_cls[0], dim=0).cpu().data.numpy())
63 | label_val_crop.append(label[0].cpu().data.numpy())
64 |
65 | #seg
66 | y_true_f = val_mask.reshape(val_mask.shape[0]*val_mask.shape[1], order='F')
67 | pred_seg = torch.softmax(pred_seg_fine, dim=1).cpu().data.numpy()
68 | pred_arg = np.argmax(pred_seg[0], axis=0)
69 | if np.sum(y_true_f) != 0 and label[0].cpu().data.numpy() != 0:
70 | y_pred_f = pred_arg.reshape(pred_arg.shape[0]*pred_arg.shape[1], order='F')
71 | intersection = float(np.sum(y_true_f * y_pred_f))
72 | seg_dice.append((2. * intersection) / (np.sum(y_true_f) + np.sum(y_pred_f)))
73 | seg_sen.append(intersection / np.sum(y_true_f))
74 | intersection0 = float(np.sum((1 - y_true_f) * (1 - y_pred_f)))
75 | seg_spe.append(intersection0 / np.sum(1 - y_true_f))
76 | seg_acc.append(accuracy_score(y_true_f, y_pred_f))
77 | seg_jac_score.append(intersection / (np.sum(y_true_f) + np.sum(y_pred_f) - intersection))
78 |
79 | if verbose and epoch == config.EPOCH-1:
80 | show_seg_results(data[0].cpu().data.numpy().transpose(1, 2, 0), mask, pred_arg, path, name[0])
81 | #cls
82 | pro_score_crop = np.array(pro_score_crop)
83 | label_val_crop = np.array(label_val_crop)
84 | preds = np.argmax(pro_score_crop, axis=-1)
85 | binary_score = np.eye(3)[preds]
86 | label_val = np.eye(3)[np.int64(label_val_crop)]
87 | CK = metrics.cohen_kappa_score(label_val_crop, preds)
88 | OA = metrics.accuracy_score(label_val_crop, preds)
89 | EREC = metrics.recall_score(label_val, binary_score, average=None)
90 |
91 | result = {}
92 | result['seg'] = [np.array(seg_acc), np.array(seg_dice), np.array(seg_sen), np.array(seg_spe), np.array(seg_jac_score)]
93 | result['cls'] = [CK, OA, EREC]
94 | return result
95 |
96 |
97 | def main():
98 | """Create the network and start the training."""
99 |
100 | cudnn.enabled = True
101 | cudnn.benchmark = True
102 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
103 |
104 | ############# Create mask-guided classification network.
105 | model = DSI_Net(config, K = args.K)
106 | optimizer = torch.optim.Adam(model.parameters(), lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY)
107 | model.cuda()
108 | if config.FP16 is True:
109 | model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
110 | model = torch.nn.DataParallel(model)
111 |
112 | ############# Load pretrained weights
113 | pretrained_dict = torch.load(model_urls['deeplabv3plus_xception'])
114 | net_dict = model.state_dict()
115 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if (k in net_dict) and (v.shape == net_dict[k].shape)}
116 | net_dict.update(pretrained_dict)
117 | model.load_state_dict(net_dict)
118 | print('state_dict entries in model: {}'.format(len(net_dict)))
119 | print('pretrained entries matched and loaded: {}'.format(len(pretrained_dict)))
120 | model.train()
121 | model.float()
122 | ce_loss = nn.CrossEntropyLoss()
123 | dice_loss = Dice_Loss()
124 | task_interaction_loss = Task_Interaction_Loss()
125 |
126 | ############# Load training and validation data
127 | trainloader = data.DataLoader(CADCAPDataset(config.DATA_ROOT, args.image_list, config.SIZE, data_type='train', mode = 'train'), batch_size=config.BATCH_SIZE, shuffle=True,
128 | num_workers=config.NUM_WORKERS, pin_memory=True, drop_last = config.DROP_LAST)
129 | testloader = data.DataLoader(CADCAPDataset(config.DATA_ROOT, args.image_list,config.SIZE, data_type = 'test', mode='test'),
130 | batch_size=1, shuffle=False, num_workers=config.NUM_WORKERS, pin_memory=True)
131 | train_testloader = data.DataLoader(CADCAPDataset(config.DATA_ROOT, args.image_list,config.SIZE, data_type = 'train', mode = 'test'),
132 | batch_size=1, shuffle=False, num_workers=config.NUM_WORKERS, pin_memory=True)
133 |
134 |
135 | if not os.path.isdir(config.SAVE_PATH):
136 | os.mkdir(config.SAVE_PATH)
137 | if not os.path.isdir(config.SAVE_PATH+'Seg_results/'):
138 | os.mkdir(config.SAVE_PATH+'Seg_results/')
139 | if not os.path.isdir(config.LOG_PATH):
140 | os.mkdir(config.LOG_PATH)
141 |
142 | f_path = config.LOG_PATH + 'training_output.log'
143 | logfile = open(f_path, 'a')
144 |
145 | print_f(os.getcwd(), f=logfile)
146 | print_f('Device: {}'.format(args.gpus), f=logfile)
147 | print_f('==={}==='.format(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())), f=logfile)
148 | print_f('===Setting===', f=logfile)
149 | print_f(' Data_list: {}'.format(args.image_list), f=logfile)
150 | print_f(' K: {}'.format(args.K), f=logfile)
151 | print_f(' Loss_weight: {}'.format(args.alpha), f=logfile)
152 | print_f(' LR: {}'.format(config.LEARNING_RATE), f=logfile)
153 |
154 | OA_bulk_train = []
155 | CK_bulk_train = []
156 | DI_bulk_train = []
157 | JA_bulk_train = []
158 | SE_bulk_train = []
159 |
160 | OA_bulk_test = []
161 | CK_bulk_test = []
162 | DI_bulk_test = []
163 | JA_bulk_test = []
164 | SE_bulk_test = []
165 |
166 |
167 | for epoch in range(config.EPOCH):
168 | #cls
169 | cls_train_loss = []
170 | seg_train_loss = []
171 | train_inter_loss = []
172 | ############# Start the training
173 | for i_iter, batch in enumerate(trainloader):
174 | step = (config.TRAIN_NUM/config.BATCH_SIZE)*epoch+i_iter
175 | images, masks, labels, name = batch
176 | images = images.cuda()
177 | labels = labels.cuda().long()
178 | masks = masks.cuda().squeeze(1)
179 | optimizer.zero_grad()
180 | lr = adjust_learning_rate(optimizer, step)
181 | model.train()
182 | preds_seg_coarse, preds_seg_fine, preds_cls = model(images)
183 | cls_loss = ce_loss(preds_cls, labels)
184 | seg_loss_fine = dice_loss(preds_seg_fine, masks)
185 | seg_loss_coarse = dice_loss(preds_seg_coarse, masks)
186 | inter_loss = task_interaction_loss(preds_cls, preds_seg_fine, labels)
187 | loss = cls_loss + seg_loss_fine + seg_loss_coarse + args.alpha * inter_loss
188 |
189 | if config.FP16 is True:
190 | with amp.scale_loss(loss, optimizer) as scaled_loss:
191 | scaled_loss.backward()
192 | else:
193 | loss.backward()
194 | optimizer.step()
195 | #cls
196 | cls_train_loss.append(cls_loss.cpu().data.numpy())
197 | seg_train_loss.append(seg_loss_fine.cpu().data.numpy())
198 | train_inter_loss.append(inter_loss.cpu().data.numpy())
199 |
200 | ############ train log
201 | line = "Train-Epoch [%d/%d] [All]: Seg_loss = %.6f, Class_loss = %.6f, Inter_loss = %.6f, LR = %0.9f\n" % (epoch, config.EPOCH, np.nanmean(seg_train_loss), np.nanmean(cls_train_loss), np.nanmean(train_inter_loss), lr)
202 | print_f(line, f=logfile)
203 |
204 | result = test(train_testloader, model, epoch, verbose=False)
205 | #cls
206 | [CK, OA, EREC] = result['cls']
207 | OA_bulk_train.append(OA)
208 | CK_bulk_train.append(CK)
209 |
210 | # seg
211 | [AC, DI, SE, SP, JA] = result['seg']
212 | JA_bulk_train.append(np.nanmean(JA))
213 | DI_bulk_train.append(np.nanmean(DI))
214 | SE_bulk_train.append(np.nanmean(SE))
215 |
216 | ############# Start the test
217 | result = test(testloader, model, epoch, config.SAVE_PATH+'Seg_results/' , verbose = config.VERBOSE)
218 | #cls
219 | [CK, OA, EREC] = result['cls']
220 | line = "Test -Epoch [%d/%d] [Cls]: CK-Score = %f, OA = %f, Rec-N = %f, Rec-V = %f, Rec-I=%f \n" % (epoch, config.EPOCH, CK, OA, EREC[0],EREC[1],EREC[2] )
221 | print_f(line, f=logfile)
222 | OA_bulk_test.append(OA)
223 | CK_bulk_test.append(CK)
224 |
225 | # seg
226 | [AC, DI, SE, SP, JA] = result['seg']
227 | line = "Test -Epoch [%d/%d] [Seg]: AC = %f, DI = %f, SE = %f, SP = %f, JA = %f \n" % (epoch, config.EPOCH, np.nanmean(AC), np.nanmean(DI), np.nanmean(SE), np.nanmean(SP), np.nanmean(JA))
228 | print_f(line, f=logfile)
229 |
230 | JA_bulk_test.append(np.nanmean(JA))
231 | DI_bulk_test.append(np.nanmean(DI))
232 | SE_bulk_test.append(np.nanmean(SE))
233 |
234 | ############# Plot val curve
235 | filename = os.path.join(config.LOG_PATH, 'cls_curves.png')
236 | data_list = [OA_bulk_train, OA_bulk_test, CK_bulk_train, CK_bulk_test]
237 | label_list = ['OA_train','OA_test','CK_train','CK_test']
238 | draw_curves(data_list = data_list, label_list = label_list, color_list = config.COLOR[0:4], filename = filename)
239 | filename = os.path.join(config.LOG_PATH, 'seg_curves.png')
240 | data_list = [JA_bulk_train, JA_bulk_test, DI_bulk_train, DI_bulk_test, SE_bulk_train, SE_bulk_test]
241 | label_list = ['JA_train','JA_test','DI_train','DI_test', 'SE_train','SE_test']
242 | draw_curves(data_list = data_list, label_list = label_list, color_list = config.COLOR[0:6], filename = filename)
243 |
244 | if __name__ == '__main__':
245 | main()
246 |
247 |
--------------------------------------------------------------------------------
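The poly schedule in lr_poly above decays the learning rate as base_lr * (1 - iter / max_iter) ** power. A standalone sketch of the curve it produces; the numbers are hypothetical stand-ins for config.LEARNING_RATE, config.STEPS and config.POWER:

def lr_poly(base_lr, cur_iter, max_iter, power):
    return base_lr * ((1 - float(cur_iter) / max_iter) ** power)

# hypothetical settings: base_lr=1e-4, 10000 total steps, power=0.9
for step in (0, 2500, 5000, 7500, 9999):
    print(step, lr_poly(1e-4, step, 10000, 0.9))
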
/net/sync_batchnorm/batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import collections
12 |
13 | import torch
14 | import torch.nn.functional as F
15 |
16 | from torch.nn.modules.batchnorm import _BatchNorm
17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
18 |
19 | from .comm import SyncMaster
20 |
21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
22 |
23 |
24 | def _sum_ft(tensor):
25 | """sum over the first and last dimention"""
26 | return tensor.sum(dim=0).sum(dim=-1)
27 |
28 |
29 | def _unsqueeze_ft(tensor):
30 | """add new dementions at the front and the tail"""
31 | return tensor.unsqueeze(0).unsqueeze(-1)
32 |
33 |
34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
36 |
37 |
38 | class _SynchronizedBatchNorm(_BatchNorm):
39 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
41 |
42 | self._sync_master = SyncMaster(self._data_parallel_master)
43 |
44 | self._is_parallel = False
45 | self._parallel_id = None
46 | self._slave_pipe = None
47 |
48 | def forward(self, input):
49 | # If this is not a parallel computation, or we are in evaluation mode, fall back to PyTorch's implementation.
50 | if not (self._is_parallel and self.training):
51 | return F.batch_norm(
52 | input, self.running_mean, self.running_var, self.weight, self.bias,
53 | self.training, self.momentum, self.eps)
54 |
55 | # Resize the input to (B, C, -1).
56 | input_shape = input.size()
57 | input = input.view(input.size(0), self.num_features, -1)
58 |
59 | # Compute the sum and square-sum.
60 | sum_size = input.size(0) * input.size(2)
61 | input_sum = _sum_ft(input)
62 | input_ssum = _sum_ft(input ** 2)
63 |
64 | # Reduce-and-broadcast the statistics.
65 | if self._parallel_id == 0:
66 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
67 | else:
68 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
69 |
70 | # Compute the output.
71 | if self.affine:
72 | # MJY:: Fuse the multiplication for speed.
73 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
74 | else:
75 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
76 |
77 | # Reshape it.
78 | return output.view(input_shape)
79 |
80 | def __data_parallel_replicate__(self, ctx, copy_id):
81 | self._is_parallel = True
82 | self._parallel_id = copy_id
83 |
84 | # parallel_id == 0 means master device.
85 | if self._parallel_id == 0:
86 | ctx.sync_master = self._sync_master
87 | else:
88 | self._slave_pipe = ctx.sync_master.register_slave(copy_id)
89 |
90 | def _data_parallel_master(self, intermediates):
91 | """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
92 |
93 | # Always using the same "device order" makes the ReduceAdd operation faster.
94 | # Thanks to:: Tete Xiao (http://tetexiao.com/)
95 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
96 |
97 | to_reduce = [i[1][:2] for i in intermediates]
98 | to_reduce = [j for i in to_reduce for j in i] # flatten
99 | target_gpus = [i[1].sum.get_device() for i in intermediates]
100 |
101 | sum_size = sum([i[1].sum_size for i in intermediates])
102 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
103 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
104 |
105 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
106 |
107 | outputs = []
108 | for i, rec in enumerate(intermediates):
109 | outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
110 |
111 | return outputs
112 |
113 | def _compute_mean_std(self, sum_, ssum, size):
114 | """Compute the mean and standard-deviation with sum and square-sum. This method
115 | also maintains the moving average on the master device."""
116 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
117 | mean = sum_ / size
118 | sumvar = ssum - sum_ * mean
119 | unbias_var = sumvar / (size - 1)
120 | bias_var = sumvar / size
121 |
122 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
123 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
124 |
125 | return mean, bias_var.clamp(self.eps) ** -0.5
126 |
127 |
128 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
129 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
130 | mini-batch.
131 |
132 | .. math::
133 |
134 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
135 |
136 | This module differs from the built-in PyTorch BatchNorm1d as the mean and
137 | standard-deviation are reduced across all devices during training.
138 |
139 | For example, when one uses `nn.DataParallel` to wrap the network during
140 | training, PyTorch's implementation normalizes the tensor on each device using
141 | only the statistics on that device, which accelerates the computation and
142 | is also easy to implement, but the statistics might be inaccurate.
143 | Instead, in this synchronized version, the statistics will be computed
144 | over all training samples distributed on multiple devices.
145 | 
146 | Note that, in the one-GPU or CPU-only case, this module behaves exactly
147 | the same as the built-in PyTorch implementation.
148 |
149 | The mean and standard-deviation are calculated per-dimension over
150 | the mini-batches and gamma and beta are learnable parameter vectors
151 | of size C (where C is the input size).
152 |
153 | During training, this layer keeps a running estimate of its computed mean
154 | and variance. The running sum is kept with a default momentum of 0.1.
155 |
156 | During evaluation, this running mean/variance is used for normalization.
157 |
158 | Because the BatchNorm is done over the `C` dimension, computing statistics
159 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
160 |
161 | Args:
162 | num_features: num_features from an expected input of size
163 | `batch_size x num_features [x width]`
164 | eps: a value added to the denominator for numerical stability.
165 | Default: 1e-5
166 | momentum: the value used for the running_mean and running_var
167 | computation. Default: 0.1
168 | affine: a boolean value that when set to ``True``, gives the layer learnable
169 | affine parameters. Default: ``True``
170 |
171 | Shape:
172 | - Input: :math:`(N, C)` or :math:`(N, C, L)`
173 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
174 |
175 | Examples:
176 | >>> # With Learnable Parameters
177 | >>> m = SynchronizedBatchNorm1d(100)
178 | >>> # Without Learnable Parameters
179 | >>> m = SynchronizedBatchNorm1d(100, affine=False)
180 | >>> input = torch.randn(20, 100)
181 | >>> output = m(input)
182 | """
183 |
184 | def _check_input_dim(self, input):
185 | if input.dim() != 2 and input.dim() != 3:
186 | raise ValueError('expected 2D or 3D input (got {}D input)'
187 | .format(input.dim()))
188 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
189 |
190 |
191 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
192 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
193 | of 3d inputs
194 |
195 | .. math::
196 |
197 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
198 |
199 | This module differs from the built-in PyTorch BatchNorm2d as the mean and
200 | standard-deviation are reduced across all devices during training.
201 |
202 | For example, when one uses `nn.DataParallel` to wrap the network during
203 | training, PyTorch's implementation normalizes the tensor on each device using
204 | only the statistics on that device, which accelerates the computation and
205 | is also easy to implement, but the statistics might be inaccurate.
206 | Instead, in this synchronized version, the statistics will be computed
207 | over all training samples distributed on multiple devices.
208 | 
209 | Note that, in the one-GPU or CPU-only case, this module behaves exactly
210 | the same as the built-in PyTorch implementation.
211 |
212 | The mean and standard-deviation are calculated per-dimension over
213 | the mini-batches and gamma and beta are learnable parameter vectors
214 | of size C (where C is the input size).
215 |
216 | During training, this layer keeps a running estimate of its computed mean
217 | and variance. The running sum is kept with a default momentum of 0.1.
218 |
219 | During evaluation, this running mean/variance is used for normalization.
220 |
221 | Because the BatchNorm is done over the `C` dimension, computing statistics
222 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
223 |
224 | Args:
225 | num_features: num_features from an expected input of
226 | size batch_size x num_features x height x width
227 | eps: a value added to the denominator for numerical stability.
228 | Default: 1e-5
229 | momentum: the value used for the running_mean and running_var
230 | computation. Default: 0.1
231 | affine: a boolean value that when set to ``True``, gives the layer learnable
232 | affine parameters. Default: ``True``
233 |
234 | Shape:
235 | - Input: :math:`(N, C, H, W)`
236 | - Output: :math:`(N, C, H, W)` (same shape as input)
237 |
238 | Examples:
239 | >>> # With Learnable Parameters
240 | >>> m = SynchronizedBatchNorm2d(100)
241 | >>> # Without Learnable Parameters
242 | >>> m = SynchronizedBatchNorm2d(100, affine=False)
243 | >>> input = torch.randn(20, 100, 35, 45)
244 | >>> output = m(input)
245 | """
246 |
247 | def _check_input_dim(self, input):
248 | if input.dim() != 4:
249 | raise ValueError('expected 4D input (got {}D input)'
250 | .format(input.dim()))
251 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
252 |
253 |
254 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
255 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
256 | of 4d inputs
257 |
258 | .. math::
259 |
260 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
261 |
262 | This module differs from the built-in PyTorch BatchNorm3d as the mean and
263 | standard-deviation are reduced across all devices during training.
264 |
265 | For example, when one uses `nn.DataParallel` to wrap the network during
266 | training, PyTorch's implementation normalizes the tensor on each device using
267 | only the statistics on that device, which accelerates the computation and
268 | is also easy to implement, but the statistics might be inaccurate.
269 | Instead, in this synchronized version, the statistics will be computed
270 | over all training samples distributed on multiple devices.
271 | 
272 | Note that, in the one-GPU or CPU-only case, this module behaves exactly
273 | the same as the built-in PyTorch implementation.
274 |
275 | The mean and standard-deviation are calculated per-dimension over
276 | the mini-batches and gamma and beta are learnable parameter vectors
277 | of size C (where C is the input size).
278 |
279 | During training, this layer keeps a running estimate of its computed mean
280 | and variance. The running sum is kept with a default momentum of 0.1.
281 |
282 | During evaluation, this running mean/variance is used for normalization.
283 |
284 | Because the BatchNorm is done over the `C` dimension, computing statistics
285 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
286 | or Spatio-temporal BatchNorm
287 |
288 | Args:
289 | num_features: num_features from an expected input of
290 | size batch_size x num_features x depth x height x width
291 | eps: a value added to the denominator for numerical stability.
292 | Default: 1e-5
293 | momentum: the value used for the running_mean and running_var
294 | computation. Default: 0.1
295 | affine: a boolean value that when set to ``True``, gives the layer learnable
296 | affine parameters. Default: ``True``
297 |
298 | Shape:
299 | - Input: :math:`(N, C, D, H, W)`
300 | - Output: :math:`(N, C, D, H, W)` (same shape as input)
301 |
302 | Examples:
303 | >>> # With Learnable Parameters
304 | >>> m = SynchronizedBatchNorm3d(100)
305 | >>> # Without Learnable Parameters
306 | >>> m = SynchronizedBatchNorm3d(100, affine=False)
307 | >>> input = torch.randn(20, 100, 35, 45, 10)
308 | >>> output = m(input)
309 | """
310 |
311 | def _check_input_dim(self, input):
312 | if input.dim() != 5:
313 | raise ValueError('expected 5D input (got {}D input)'
314 | .format(input.dim()))
315 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
316 |
--------------------------------------------------------------------------------
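In the one-GPU or CPU case the forward above falls back to F.batch_norm, so the module can be sanity-checked without any replication machinery; multi-GPU use additionally requires the replication callback in replicate.py so that __data_parallel_replicate__ runs on every model copy. A minimal single-device sketch:

import torch

from net.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d

bn = SynchronizedBatchNorm2d(16)   # training mode by default
x = torch.randn(4, 16, 8, 8)

y = bn(x)                          # _is_parallel is False -> plain F.batch_norm path
print(y.shape)                     # torch.Size([4, 16, 8, 8])
print(bn.running_mean[:3])         # running stats were updated by this forward pass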