├── logs
│   └── README.md
├── checkpoints
│   └── README.md
├── data
│   └── pre_model
│       └── README.md
├── Figs
│   └── Framework.png
├── net
│   ├── sync_batchnorm
│   │   ├── __init__.py
│   │   ├── unittest.py
│   │   ├── batchnorm_reimpl.py
│   │   ├── replicate.py
│   │   ├── comm.py
│   │   └── batchnorm.py
│   ├── convs.py
│   ├── gumbel.py
│   ├── ASPP.py
│   ├── loss.py
│   ├── models.py
│   ├── xception.py
│   └── modules.py
├── utils
│   ├── logger.py
│   ├── meter.py
│   ├── metrics.py
│   └── fp16util.py
├── training_script.sh
├── config.py
├── README.md
├── dataset
│   ├── transform_customize.py
│   ├── my_datasets.py
│   └── preprocess.py
├── visualization
│   └── utils.py
└── train_DSI_Net.py
--------------------------------------------------------------------------------
/logs/README.md:
--------------------------------------------------------------------------------
 1 | 
 2 | 
--------------------------------------------------------------------------------
/checkpoints/README.md:
--------------------------------------------------------------------------------
 1 | 
 2 | 
--------------------------------------------------------------------------------
/data/pre_model/README.md:
--------------------------------------------------------------------------------
 1 | 
 2 | 
--------------------------------------------------------------------------------
/Figs/Framework.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/CityU-AIM-Group/DSI-Net/HEAD/Figs/Framework.png
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
 1 | 
 2 | import random
 3 | 
 4 | def print_f(msg, f=None):
 5 |     if f is not None:
 6 |         print(msg, file=f)
 7 |         if random.randint(0, 20) < 3:  # flush the log file only occasionally to limit I/O overhead
 8 |             f.flush()
 9 |     print(msg)
10 | 
11 | 
--------------------------------------------------------------------------------
/training_script.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | #SBATCH -J training
 3 | #SBATCH --gres=gpu:1
 4 | #SBATCH --partition=gpu_1d2g
 5 | #SBATCH -c 2
 6 | #SBATCH -N 1
 7 | 
 8 | echo "Submitted from: $SLURM_SUBMIT_DIR on node: $SLURM_SUBMIT_HOST"
 9 | echo "Running on node: $SLURM_JOB_NODELIST"
10 | echo "Allocated GPU units: $CUDA_VISIBLE_DEVICES"
11 | 
12 | nvidia-smi
13 | 
14 | python train_DSI_Net.py --gpus 0 --K 100 --alpha 0.05 --image_list 'data/WCE/WCE_Dataset_image_list.pkl'
--------------------------------------------------------------------------------
/net/sync_batchnorm/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | # File   : __init__.py
 3 | # Author : Jiayuan Mao
 4 | # Email  : maojiayuan@gmail.com
 5 | # Date   : 27/01/2018
 6 | # 
 7 | # This file is part of Synchronized-BatchNorm-PyTorch.
 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
 9 | # Distributed under MIT License.
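10 | # (Editor's note) Minimal usage sketch, assuming a multi-GPU setup; `my_model`
11 | # is a placeholder module (see also the examples in replicate.py):
12 | #   from net.sync_batchnorm import DataParallelWithCallback
13 | #   model = DataParallelWithCallback(my_model, device_ids=[0, 1])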
14 | 
15 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
16 | from .replicate import DataParallelWithCallback, patch_replication_callback
17 | 
--------------------------------------------------------------------------------
/net/convs.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import torch.nn.functional as F
 4 | 
 5 | class SeparableConv2d(nn.Module):
 6 |     def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False):
 7 |         super(SeparableConv2d,self).__init__()
 8 | 
 9 |         self.conv1 = nn.Conv2d(in_channels,in_channels,kernel_size,stride,padding,dilation,groups=in_channels,bias=bias)
10 |         self.pointwise = nn.Conv2d(in_channels,out_channels,1,1,0,1,1,bias=bias)
11 | 
12 |     def forward(self,x):
13 |         x = self.conv1(x)
14 |         x = self.pointwise(x)
15 |         return x
--------------------------------------------------------------------------------
/utils/meter.py:
--------------------------------------------------------------------------------
 1 | class AverageMeter(object):
 2 |     """Computes and stores the average and current value"""
 3 |     def __init__(self, avg_mom=0.5):
 4 |         self.avg_mom = avg_mom
 5 |         self.reset()
 6 | 
 7 |     def reset(self):
 8 |         self.val = 0
 9 |         self.avg = 0  # running average of whole epoch
10 |         self.smooth_avg = 0
11 |         self.sum = 0
12 |         self.count = 0
13 | 
14 |     def update(self, val, n=1):
15 |         self.val = val
16 |         self.sum += val * n
17 |         self.count += n
18 |         # the first update seeds the momentum-smoothed average with the raw value
19 |         self.smooth_avg = val if self.count == n else self.avg*self.avg_mom + val*(1-self.avg_mom)
20 |         self.avg = self.sum / self.count
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
 1 | 
 2 | # data
 3 | DATA_ROOT = '/home/meiluzhu2/data/WCE/WCE_larger'
 4 | BATCH_SIZE = 8
 5 | NUM_WORKERS = 8
 6 | DROP_LAST = True
 7 | SIZE = 256
 8 | 
 9 | # training
10 | LEARNING_RATE = 0.0001
11 | MOMENTUM = 0.9
12 | POWER = 0.9
13 | WEIGHT_DECAY = 1e-5
14 | NUM_CLASSES_CLS = 3
15 | TRAIN_NUM = 2470
16 | EPOCH = 200
17 | STEPS = (TRAIN_NUM/BATCH_SIZE)*EPOCH
18 | FP16 = False
19 | VERBOSE = False
20 | SAVE_PATH = 'checkpoints/'
21 | LOG_PATH = 'logs/'
22 | COLOR = ['red', 'green', 'blue', 'yellow', 'black', 'orange', 'purple', 'pink', 'peru']
23 | 
24 | # network
25 | INTERMIDEATE_NUM = 64
26 | OS = 8
27 | EM_STEP = 3
28 | ## gumbel
29 | GUMBEL_FACTOR = 1.0
30 | GUMBEL_NOISE = True
31 | 
32 | 
--------------------------------------------------------------------------------
/net/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | # File   : unittest.py
 3 | # Author : Jiayuan Mao
 4 | # Email  : maojiayuan@gmail.com
 5 | # Date   : 27/01/2018
 6 | # 
 7 | # This file is part of Synchronized-BatchNorm-PyTorch.
 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
 9 | # Distributed under MIT License.
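10 | # (Editor's note) Usage sketch -- subclass TorchTestCase in a test module and
11 | # compare tensors, e.g.:
12 | #   class IdentityTest(TorchTestCase):
13 | #       def test_identity(self):
14 | #           self.assertTensorClose(torch.ones(3), torch.ones(3))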
15 | 
16 | import unittest
17 | import torch
18 | 
19 | 
20 | class TorchTestCase(unittest.TestCase):
21 |     def assertTensorClose(self, x, y):
22 |         adiff = float((x - y).abs().max())
23 |         if (y == 0).all():
24 |             rdiff = 'NaN'
25 |         else:
26 |             rdiff = float((adiff / y).abs().max())
27 | 
28 |         message = (
29 |             'Tensor close check failed\n'
30 |             'adiff={}\n'
31 |             'rdiff={}\n'
32 |         ).format(adiff, rdiff)
33 |         self.assertTrue(torch.allclose(x, y), message)
34 | 
35 | 
--------------------------------------------------------------------------------
/net/gumbel.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | 
 4 | class Gumbel(nn.Module):
 5 |     '''
 6 |     Returns differentiable discrete outputs. Applies a Gumbel-Softmax trick on every element of x.
 7 |     '''
 8 |     def __init__(self, config):
 9 |         super(Gumbel, self).__init__()
10 |         self.factor = config.GUMBEL_FACTOR
11 |         self.gumbel_noise = config.GUMBEL_NOISE
12 | 
13 |     def forward(self, x):
14 |         if not self.training:  # no Gumbel noise during inference
15 |             return (x >= 0).float()
16 | 
17 |         if self.gumbel_noise:
18 |             U = torch.rand_like(x)
19 |             g = -torch.log(-torch.log(U + 1e-8) + 1e-8)  # Gumbel(0, 1) noise
20 |             x = x + g
21 | 
22 |         soft = torch.sigmoid(x / self.factor)
23 |         hard = ((soft >= 0.5).float() - soft).detach() + soft  # straight-through estimator
24 |         assert not torch.any(torch.isnan(hard))
25 | 
26 |         return hard
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | from sklearn import metrics
 3 | 
 4 | 
 5 | def Jaccard(pred_arg, mask):
 6 |     pred_arg = np.argmax(pred_arg.cpu().data.numpy(), axis=1)
 7 |     mask = mask.cpu().data.numpy()
 8 | 
 9 |     y_true_f = mask.reshape(mask.shape[0] * mask.shape[1] * mask.shape[2], order='F')
10 |     y_pred_f = pred_arg.reshape(pred_arg.shape[0] * pred_arg.shape[1] * pred_arg.shape[2], order='F')
11 | 
12 |     intersection = float(np.sum(y_true_f * y_pred_f))  # np.float was removed in NumPy 1.24+
13 |     jac_score = intersection / (np.sum(y_true_f) + np.sum(y_pred_f) - intersection)
14 | 
15 |     return jac_score
16 | 
17 | 
18 | def cla_evaluate(label, binary_score, pro_score):
19 | 
20 |     acc = metrics.accuracy_score(label, binary_score)
21 |     AP = metrics.average_precision_score(label, pro_score)
22 |     auc = metrics.roc_auc_score(label, pro_score)
23 |     CM = metrics.confusion_matrix(label, binary_score)
24 |     MCC = metrics.matthews_corrcoef(label, binary_score)
25 |     F1 = metrics.f1_score(label, binary_score)
26 |     sens = float(CM[1, 1]) / float(CM[1, 1] + CM[1, 0])
27 |     spec = float(CM[0, 0]) / float(CM[0, 0] + CM[0, 1])
28 |     return acc, auc, AP, sens, spec, MCC, F1
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # DSI-Net
 2 | 
 3 | This repository is an official PyTorch implementation of the paper [**"DSI-Net: Deep Synergistic Interaction Network for Joint Classification and Segmentation with Endoscope Images"**](https://ieeexplore.ieee.org/document/9440441), TMI 2021.
 4 | 
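 5 | ![DSI-Net framework](Figs/Framework.png)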
 6 | 
 7 | 
 8 | ## Dependencies
 9 | * Python 3.6
10 | * PyTorch >= 1.3.0
11 | * numpy
12 | * apex
13 | * scikit-learn
14 | * matplotlib
15 | * Pillow (PIL)
16 | 
17 | ## Usage
18 | * Download the [**processed dataset**](https://drive.google.com/file/d/1BBF21SVlH5685XpsvtKlWN7iepr7YQPU/view?usp=sharing)
19 | * Train DSI-Net:
20 | ```bash
21 | python train_DSI_Net.py --gpus 0 --K 100 --alpha 0.05 --image_list 'data/WCE/WCE_Dataset_image_list.pkl'
22 | ```
23 | 
24 | ## Citation
25 | ```
26 | @ARTICLE{9440441,
27 |   author={Zhu, Meilu and Chen, Zhen and Yuan, Yixuan},
28 |   journal={IEEE Transactions on Medical Imaging},
29 |   title={DSI-Net: Deep Synergistic Interaction Network for Joint Classification and Segmentation with Endoscope Images},
30 |   year={2021},
31 |   doi={10.1109/TMI.2021.3083586}}
32 | ```
33 | ## Contact
34 | 
35 | Meilu Zhu (meiluzhu2-c@my.cityu.edu.hk)
--------------------------------------------------------------------------------
/dataset/transform_customize.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | from scipy.ndimage import map_coordinates  # scipy.ndimage.interpolation is deprecated
 3 | from scipy.ndimage import gaussian_filter  # scipy.ndimage.filters is deprecated
 4 | 
 5 | 
 6 | class RandomElasticTransform(object):
 7 |     """Randomly apply an elastic deformation to the image"""
 8 |     # https://gist.github.com/nasimrahaman/8ed04be1088e228c21d51291f47dd1e6
 9 |     def __init__(self, alpha=2000, sigma=50):
10 |         self.alpha = alpha
11 |         self.sigma = sigma
12 | 
13 |     def __call__(self, img):
14 | 
15 |         shape = img.shape[:2]
16 |         random_state = np.random
17 |         dx = gaussian_filter((random_state.rand(*shape) * 2 - 1),
18 |                              self.sigma, mode="constant", cval=0) * self.alpha
19 |         dy = gaussian_filter((random_state.rand(*shape) * 2 - 1),
20 |                              self.sigma, mode="constant", cval=0) * self.alpha
21 |         x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij')
22 |         indices = [np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1))]
23 | 
24 |         image = map_coordinates(img, indices, order=1, mode='nearest').reshape(shape)
25 | 
26 |         return image
27 | 
28 | 
--------------------------------------------------------------------------------
/visualization/utils.py:
--------------------------------------------------------------------------------
 1 | 
 2 | import matplotlib
 3 | matplotlib.use('Agg')
 4 | import matplotlib.pyplot as plt
 5 | 
 6 | 
 7 | def show_seg_results(img, gt, pre, save_path=None, name=None):
 8 | 
 9 |     fig = plt.figure()
10 |     ax = fig.add_subplot(131)
11 |     ax.imshow(img)
12 |     ax.axis('off')
13 |     ax = fig.add_subplot(132)
14 |     ax.imshow(gt)
15 |     ax.axis('off')
16 |     ax = fig.add_subplot(133)
17 |     ax.imshow(pre)
18 |     ax.axis('off')
19 |     fig.suptitle('Img, GT, Prediction', fontsize=6)
20 |     if save_path is not None and name is not None:
21 |         fig.savefig(save_path + name + '.png', dpi=200, bbox_inches='tight')
22 |     ax.cla()
23 |     fig.clf()
24 |     plt.close()
25 | 
26 | def draw_curves(data_list, label_list, color_list, linestyle_list=None, filename='training_curve.png'):
27 | 
28 |     plt.figure()
29 |     for i in range(len(data_list)):
30 |         data = data_list[i]
31 |         label = label_list[i]
32 |         color = color_list[i]
33 |         if linestyle_list is None:
34 |             linestyle = '-'
35 |         else:
36 |             linestyle = linestyle_list[i]
37 |         plt.plot(data, label=label, color=color, linestyle=linestyle)
38 |     plt.legend(loc='best')
39 |     plt.savefig(filename)
40 |     plt.clf()
41 |     plt.close()
42 |     plt.close('all')
--------------------------------------------------------------------------------
/utils/fp16util.py:
--------------------------------------------------------------------------------
 1 | # https://github.com/cybertronai/imagenet18_old/blob/master/training/fp16util.py
 2 | 
 3 | import torch
 4 | import torch.nn as nn
 5 | from torch.autograd import Variable
 6 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 7 | 
 8 | class tofp16(nn.Module):
 9 |     """
10 |     Model wrapper that implements::
11 |         def forward(self, input):
12 |             return input.half()
13 |     """
14 | 
15 |     def __init__(self):
16 |         super(tofp16, self).__init__()
17 | 
18 |     def forward(self, input):
19 |         return input.half()
20 | 
21 | 
22 | def BN_convert_float(module):
23 |     '''
24 |     Designed to work with network_to_half.
25 |     BatchNorm layers need parameters in single precision.
26 |     Find all layers and convert them back to float. This can't
27 |     be done with built in .apply as that function will apply
28 |     fn to all modules, parameters, and buffers. Thus we wouldn't
29 |     be able to guard the float conversion based on the module type.
30 |     '''
31 |     if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
32 |         module.float()
33 |     for child in module.children():
34 |         BN_convert_float(child)
35 |     return module
36 | 
37 | 
38 | def network_to_half(network):
39 |     """
40 |     Convert model to half precision in a batchnorm-safe way.
41 |     """
42 |     # (AS) This is better as it does not change model structure
43 |     return BN_convert_float(network.half())
44 |     # return nn.Sequential(tofp16(), BN_convert_float(network.half()))
--------------------------------------------------------------------------------
/net/ASPP.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import torch.nn.functional as F
 4 | 
 5 | class ASPP(nn.Module):
 6 |     def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):
 7 |         super(ASPP, self).__init__()
 8 |         self.branch1 = nn.Sequential(
 9 |             nn.Conv2d(dim_in, dim_out, 1, 1, padding=0, dilation=rate, bias=True),
10 |             nn.BatchNorm2d(dim_out),
11 |             nn.ReLU(inplace=True),
12 |         )
13 |         self.branch2 = nn.Sequential(
14 |             nn.Conv2d(dim_in, dim_out, 3, 1, padding=6 * rate, dilation=6 * rate, bias=True),
15 |             nn.BatchNorm2d(dim_out),
16 |             nn.ReLU(inplace=True),
17 |         )
18 |         self.branch3 = nn.Sequential(
19 |             nn.Conv2d(dim_in, dim_out, 3, 1, padding=12 * rate, dilation=12 * rate, bias=True),
20 |             nn.BatchNorm2d(dim_out),
21 |             nn.ReLU(inplace=True),
22 |         )
23 |         self.branch4 = nn.Sequential(
24 |             nn.Conv2d(dim_in, dim_out, 3, 1, padding=18 * rate, dilation=18 * rate, bias=True),
25 |             nn.BatchNorm2d(dim_out),
26 |             nn.ReLU(inplace=True),
27 |         )
28 |         self.branch5_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=True)
29 |         self.branch5_bn = nn.BatchNorm2d(dim_out)
30 |         self.branch5_relu = nn.ReLU(inplace=True)
31 |         self.conv_cat = nn.Sequential(
32 |             nn.Conv2d(dim_out * 5, dim_out, 1, 1, padding=0, bias=True),
33 |             nn.BatchNorm2d(dim_out),
34 |             nn.ReLU(inplace=True),
35 |         )
36 | 
37 |     def forward(self, x):
38 |         [b, c, row, col] = x.size()
39 |         conv1x1 = self.branch1(x)
40 |         conv3x3_1 = self.branch2(x)
41 |         conv3x3_2 = self.branch3(x)
42 |         conv3x3_3 = self.branch4(x)
43 |         # image-level branch: global average pooling, 1x1 conv, then upsample back
44 |         global_feature = torch.mean(x, 2, True)
45 |         global_feature = torch.mean(global_feature, 3, True)
46 |         global_feature = self.branch5_conv(global_feature)
47 |         global_feature = self.branch5_bn(global_feature)
48 |         global_feature = self.branch5_relu(global_feature)
49 |         global_feature = F.interpolate(global_feature, (row, col), None, 'bilinear', True)
50 | 
51 |         feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, global_feature], dim=1)
52 |         result = self.conv_cat(feature_cat)
53 | 
54 |         return result
55 | 
--------------------------------------------------------------------------------
/net/loss.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import torch.nn.functional as F
 4 | 
 5 | def _dice_loss(predict, target):
 6 | 
 7 |     smooth = 1e-5
 8 | 
 9 |     y_true_f = target.contiguous().view(target.shape[0], -1)
10 |     y_pred_f = predict.contiguous().view(predict.shape[0], -1)
11 |     intersection = torch.sum(torch.mul(y_pred_f, y_true_f), dim=1)
12 |     union = torch.sum(y_pred_f, dim=1) + torch.sum(y_true_f, dim=1) + smooth
13 |     dice_score = (2.0 * intersection / union)
14 | 
15 |     dice_loss = 1 - dice_score
16 | 
17 |     return dice_loss
18 | 
19 | 
20 | class Dice_Loss(nn.Module):
21 |     def __init__(self):
22 |         super(Dice_Loss, self).__init__()
23 | 
24 |     def forward(self, predicts, target):
25 | 
26 |         preds = torch.softmax(predicts, dim=1)
27 |         dice_loss0 = _dice_loss(preds[:, 0, :, :], 1 - target)
28 |         dice_loss1 = _dice_loss(preds[:, 1, :, :], target)
29 |         loss_D = (dice_loss0.mean() + dice_loss1.mean())/2.0
30 | 
31 |         return loss_D
32 | 
33 | 
34 | class Task_Interaction_Loss(nn.Module):
35 | 
36 |     def __init__(self):
37 |         super(Task_Interaction_Loss, self).__init__()
38 | 
39 |     def forward(self, cls_predict, seg_predict, target):
40 | 
41 |         b, c = cls_predict.shape
42 |         h, w = seg_predict.shape[2], seg_predict.shape[3]
43 | 
44 |         # one-hot encode the 3-way label, then merge the two lesion classes into one
45 |         target = target.view(b, 1)
46 |         target = torch.zeros(b, c).cuda().scatter_(1, target, 1)
47 |         target_new = torch.zeros(b, c-1).cuda()
48 |         cls_pred = torch.zeros(b, c-1).cuda()
49 | 
50 |         target_new[:, 0] = target[:, 0]
51 |         target_new[:, 1] = target[:, 1] + target[:, 2]
52 | 
53 |         cls_pred[:, 0] = cls_predict[:, 0]
54 |         cls_pred[:, 1] = cls_predict[:, 1] + cls_predict[:, 2]
55 | 
56 |         # Log-Sum-Exp pooling of the segmentation logits
57 |         seg_pred = torch.logsumexp(seg_predict, dim=(2, 3))/(h*w)
58 | 
59 |         # symmetric (Jensen-Shannon-style) KL between the two task predictions
60 |         seg_cls_kl = F.kl_div(torch.log_softmax(cls_pred, dim=-1), torch.softmax(seg_pred, dim=-1), reduction='none')
61 |         cls_seg_kl = F.kl_div(torch.log_softmax(seg_pred, dim=-1), torch.softmax(cls_pred, dim=-1), reduction='none')
62 | 
63 |         seg_cls_kl = seg_cls_kl.sum(-1)
64 |         cls_seg_kl = cls_seg_kl.sum(-1)
65 | 
66 |         # only transfer knowledge from a task whose prediction agrees with the label
67 |         indicator1 = (cls_pred[:, 0] > cls_pred[:, 1]) == (target_new[:, 0] > target_new[:, 1])
68 |         indicator2 = (seg_pred[:, 0] > seg_pred[:, 1]) == (target_new[:, 0] > target_new[:, 1])
69 | 
70 |         return (cls_seg_kl*indicator1 + seg_cls_kl*indicator2).sum()/2./b
71 | 
--------------------------------------------------------------------------------
/net/sync_batchnorm/batchnorm_reimpl.py:
--------------------------------------------------------------------------------
 1 | #! /usr/bin/env python3
 2 | # -*- coding: utf-8 -*-
 3 | # File   : batchnorm_reimpl.py
 4 | # Author : acgtyrant
 5 | # Date   : 11/01/2018
 6 | # 
 7 | # This file is part of Synchronized-BatchNorm-PyTorch.
 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
 9 | # Distributed under MIT License.
10 | 
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.init as init
14 | 
15 | __all__ = ['BatchNorm2dReimpl']
16 | 
17 | 
18 | class BatchNorm2dReimpl(nn.Module):
19 |     """
20 |     A re-implementation of batch normalization, used for testing the numerical
21 |     stability.
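22 | 
23 |     (Editor's note) Usage sketch, assuming a standard NCHW input:
24 |         bn = BatchNorm2dReimpl(16)
25 |         y = bn(torch.randn(4, 16, 8, 8))  # comparable to nn.BatchNorm2d(16) in training mode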
26 | 
27 |     Author: acgtyrant
28 |     See also:
29 |     https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
30 |     """
31 |     def __init__(self, num_features, eps=1e-5, momentum=0.1):
32 |         super().__init__()
33 | 
34 |         self.num_features = num_features
35 |         self.eps = eps
36 |         self.momentum = momentum
37 |         self.weight = nn.Parameter(torch.empty(num_features))
38 |         self.bias = nn.Parameter(torch.empty(num_features))
39 |         self.register_buffer('running_mean', torch.zeros(num_features))
40 |         self.register_buffer('running_var', torch.ones(num_features))
41 |         self.reset_parameters()
42 | 
43 |     def reset_running_stats(self):
44 |         self.running_mean.zero_()
45 |         self.running_var.fill_(1)
46 | 
47 |     def reset_parameters(self):
48 |         self.reset_running_stats()
49 |         init.uniform_(self.weight)
50 |         init.zeros_(self.bias)
51 | 
52 |     def forward(self, input_):
53 |         batchsize, channels, height, width = input_.size()
54 |         numel = batchsize * height * width
55 |         input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
56 |         sum_ = input_.sum(1)
57 |         sum_of_square = input_.pow(2).sum(1)
58 |         mean = sum_ / numel
59 |         sumvar = sum_of_square - sum_ * mean
60 | 
61 |         self.running_mean = (
62 |             (1 - self.momentum) * self.running_mean
63 |             + self.momentum * mean.detach()
64 |         )
65 |         unbias_var = sumvar / (numel - 1)
66 |         self.running_var = (
67 |             (1 - self.momentum) * self.running_var
68 |             + self.momentum * unbias_var.detach()
69 |         )
70 | 
71 |         bias_var = sumvar / numel
72 |         inv_std = 1 / (bias_var + self.eps).pow(0.5)
73 |         output = (
74 |             (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
75 |             self.weight.unsqueeze(1) + self.bias.unsqueeze(1))
76 | 
77 |         return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
78 | 
--------------------------------------------------------------------------------
/dataset/my_datasets.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | import random
 3 | from torch.utils import data
 4 | from torchvision import transforms
 5 | from PIL import Image
 6 | import os
 7 | import pickle
 8 | 
 9 | 
10 | class CADCAPDataset(data.Dataset):
11 |     def __init__(self, dataset_root, DATA_PKL, SIZE, data_type='train', mode='train'):
12 |         self.dataset_root = dataset_root
13 |         if os.path.exists(DATA_PKL):
14 |             with open(DATA_PKL, 'rb') as f:
15 |                 info = pickle.load(f)
16 |             if data_type == 'train':
17 |                 self.data = info['train']
18 |             else:
19 |                 self.data = info['test']
20 |         assert (data_type == 'train' and mode == 'train') or (data_type == 'train' and mode == 'test') or (data_type == 'test' and mode == 'test'), 'mode setting error in dataset....'
21 |         self.mode = mode
22 |         self.train_augmentation = transforms.Compose(
23 |             [transforms.RandomAffine(degrees=90, shear=5.729578),
24 |              transforms.RandomVerticalFlip(p=0.5),
25 |              transforms.RandomHorizontalFlip(p=0.5),
26 |              transforms.ToTensor(),
27 |              transforms.ToPILImage(),
28 |              transforms.Resize(SIZE)
29 |              ])
30 |         self.train_gt_augmentation = transforms.Compose(
31 |             [transforms.RandomAffine(degrees=90, shear=5.729578),
32 |              transforms.RandomVerticalFlip(p=0.5),
33 |              transforms.RandomHorizontalFlip(p=0.5),
34 |              transforms.ToTensor(),
35 |              transforms.ToPILImage(),
36 |              transforms.Resize(SIZE)
37 |              ])
38 | 
39 |         self.test_augmentation = transforms.Compose(
40 |             [transforms.ToTensor(),
41 |              transforms.ToPILImage(),
42 |              transforms.Resize(SIZE)
43 |              ])
44 |         self.test_gt_augmentation = transforms.Compose(
45 |             [transforms.ToTensor(),
46 |              transforms.ToPILImage(),
47 |              transforms.Resize(SIZE)
48 |              ])
49 | 
50 |     def __len__(self):
51 |         return len(self.data)
52 | 
53 |     def __getitem__(self, idx):
54 |         patient = self.data[idx]
55 |         image = Image.open(os.path.join(self.dataset_root, patient['image']))
56 |         mask = Image.open(os.path.join(self.dataset_root, patient['mask'])).convert('1')
57 |         label = patient['label']
58 | 
59 |         if self.mode == 'train':
60 |             seed = np.random.randint(123456)  # one shared seed so image and mask receive identical random transforms
61 |             random.seed(seed)
62 |             image = self.train_augmentation(image)
63 |             random.seed(seed)
64 |             mask = self.train_gt_augmentation(mask)
65 |         else:
66 |             image = self.test_augmentation(image)
67 |             mask = self.test_gt_augmentation(mask)
68 | 
69 |         image = np.array(image) / 255.
70 |         image = image.transpose((2, 0, 1))
71 |         image = image.astype(np.float32)
72 |         mask = np.array(mask)
73 |         mask = np.float32(mask > 0)
74 |         name = patient['image'].split('.')[0].replace('/', '_')
75 | 
76 |         return image.copy(), mask.copy(), label, name
77 | 
78 | 
--------------------------------------------------------------------------------
/net/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | # File   : replicate.py
 3 | # Author : Jiayuan Mao
 4 | # Email  : maojiayuan@gmail.com
 5 | # Date   : 27/01/2018
 6 | # 
 7 | # This file is part of Synchronized-BatchNorm-PyTorch.
 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
 9 | # Distributed under MIT License.
10 | 
11 | import functools
12 | 
13 | from torch.nn.parallel.data_parallel import DataParallel
14 | 
15 | __all__ = [
16 |     'CallbackContext',
17 |     'execute_replication_callbacks',
18 |     'DataParallelWithCallback',
19 |     'patch_replication_callback'
20 | ]
21 | 
22 | 
23 | class CallbackContext(object):
24 |     pass
25 | 
26 | 
27 | def execute_replication_callbacks(modules):
28 |     """
29 |     Execute a replication callback `__data_parallel_replicate__` on each module created by original replication.
30 | 
31 |     The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32 | 
33 |     Note that, as all modules are isomorphic, we assign each sub-module a context
34 |     (shared among multiple copies of this module on different devices).
35 |     Through this context, different copies can share some information.
36 | 
37 |     We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
38 |     of any slave copies.
39 |     """
40 |     master_copy = modules[0]
41 |     nr_modules = len(list(master_copy.modules()))
42 |     ctxs = [CallbackContext() for _ in range(nr_modules)]
43 | 
44 |     for i, module in enumerate(modules):
45 |         for j, m in enumerate(module.modules()):
46 |             if hasattr(m, '__data_parallel_replicate__'):
47 |                 m.__data_parallel_replicate__(ctxs[j], i)
48 | 
49 | 
50 | class DataParallelWithCallback(DataParallel):
51 |     """
52 |     Data Parallel with a replication callback.
53 | 
54 |     A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
55 |     the original `replicate` function.
56 |     The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57 | 
58 |     Examples:
59 |         > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60 |         > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61 |         # sync_bn.__data_parallel_replicate__ will be invoked.
62 |     """
63 | 
64 |     def replicate(self, module, device_ids):
65 |         modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66 |         execute_replication_callbacks(modules)
67 |         return modules
68 | 
69 | 
70 | def patch_replication_callback(data_parallel):
71 |     """
72 |     Monkey-patch an existing `DataParallel` object. Add the replication callback.
73 |     Useful when you have a customized `DataParallel` implementation.
74 | 
75 |     Examples:
76 |         > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
77 |         > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
78 |         > patch_replication_callback(sync_bn)
79 |         # this is equivalent to
80 |         > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
81 |         > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
82 |     """
83 | 
84 |     assert isinstance(data_parallel, DataParallel)
85 | 
86 |     old_replicate = data_parallel.replicate
87 | 
88 |     @functools.wraps(old_replicate)
89 |     def new_replicate(module, device_ids):
90 |         modules = old_replicate(module, device_ids)
91 |         execute_replication_callbacks(modules)
92 |         return modules
93 | 
94 |     data_parallel.replicate = new_replicate
95 | 
--------------------------------------------------------------------------------
/net/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | # File   : comm.py
 3 | # Author : Jiayuan Mao
 4 | # Email  : maojiayuan@gmail.com
 5 | # Date   : 27/01/2018
 6 | # 
 7 | # This file is part of Synchronized-BatchNorm-PyTorch.
 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
 9 | # Distributed under MIT License.
10 | 
11 | import queue
12 | import collections
13 | import threading
14 | 
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 | 
17 | 
18 | class FutureResult(object):
19 |     """A thread-safe future implementation. Used only as one-to-one pipe."""
20 | 
21 |     def __init__(self):
22 |         self._result = None
23 |         self._lock = threading.Lock()
24 |         self._cond = threading.Condition(self._lock)
25 | 
26 |     def put(self, result):
27 |         with self._lock:
28 |             assert self._result is None, 'Previous result hasn\'t been fetched.'
29 |             self._result = result
30 |             self._cond.notify()
31 | 
32 |     def get(self):
33 |         with self._lock:
34 |             if self._result is None:
35 |                 self._cond.wait()
36 | 
37 |             res = self._result
38 |             self._result = None
39 |             return res
40 | 
41 | 
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 | 
45 | 
46 | class SlavePipe(_SlavePipeBase):
47 |     """Pipe for master-slave communication."""
48 | 
49 |     def run_slave(self, msg):
50 |         self.queue.put((self.identifier, msg))
51 |         ret = self.result.get()
52 |         self.queue.put(True)
53 |         return ret
54 | 
55 | 
56 | class SyncMaster(object):
57 |     """An abstract `SyncMaster` object.
58 | 
59 |     - During replication, as data parallel triggers a callback on each module, all slave devices should
60 |     call `register(id)` and obtain a `SlavePipe` to communicate with the master.
61 |     - During the forward pass, the master device invokes `run_master`; all messages from the slave devices are collected
62 |     and passed to a registered callback.
63 |     - After receiving the messages, the master device should gather the information and determine the message to be passed
64 |     back to each slave device.
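65 | 
66 |     (Editor's note) Usage sketch: the master owns `SyncMaster(callback)`; each replica registers via
67 |     `register_slave(device_id)` and then calls `SlavePipe.run_slave(msg)`, while the master calls `run_master(master_msg)`.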
68 |     """
69 | 
70 |     def __init__(self, master_callback):
71 |         """
72 | 
73 |         Args:
74 |             master_callback: a callback to be invoked after having collected messages from slave devices.
75 |         """
76 |         self._master_callback = master_callback
77 |         self._queue = queue.Queue()
78 |         self._registry = collections.OrderedDict()
79 |         self._activated = False
80 | 
81 |     def __getstate__(self):
82 |         return {'master_callback': self._master_callback}
83 | 
84 |     def __setstate__(self, state):
85 |         self.__init__(state['master_callback'])
86 | 
87 |     def register_slave(self, identifier):
88 |         """
89 |         Register a slave device.
90 | 
91 |         Args:
92 |             identifier: an identifier, usually the device id.
93 | 
94 |         Returns: a `SlavePipe` object which can be used to communicate with the master device.
95 | 
96 |         """
97 |         if self._activated:
98 |             assert self._queue.empty(), 'Queue is not clean before next initialization.'
99 |             self._activated = False
100 |             self._registry.clear()
101 |         future = FutureResult()
102 |         self._registry[identifier] = _MasterRegistry(future)
103 |         return SlavePipe(identifier, self._queue, future)
104 | 
105 |     def run_master(self, master_msg):
106 |         """
107 |         Main entry for the master device in each forward pass.
108 |         The messages are first collected from each device (including the master device), and then
109 |         a callback will be invoked to compute the message to be sent back to each device
110 |         (including the master device).
111 | 
112 |         Args:
113 |             master_msg: the message that the master wants to send to itself. This will be placed as the first
114 |             message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
115 | 
116 |         Returns: the message to be sent back to the master device.
117 | 
118 |         """
119 |         self._activated = True
120 | 
121 |         intermediates = [(0, master_msg)]
122 |         for i in range(self.nr_slaves):
123 |             intermediates.append(self._queue.get())
124 | 
125 |         results = self._master_callback(intermediates)
126 |         assert results[0][0] == 0, 'The first result should belong to the master.'
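127 |         # (Editor's note) `results` pairs each device id with its reply; entry 0 is the master's
128 |         # own, and the loop below routes the rest back through the registered slave pipes.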
129 | 
130 |         for i, res in results:
131 |             if i == 0:
132 |                 continue
133 |             self._registry[i].result.put(res)
134 | 
135 |         for i in range(self.nr_slaves):
136 |             assert self._queue.get() is True
137 | 
138 |         return results[0][1]
139 | 
140 |     @property
141 |     def nr_slaves(self):
142 |         return len(self._registry)
143 | 
--------------------------------------------------------------------------------
/net/models.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import torch.nn.functional as F
 4 | import net.xception as xception
 5 | 
 6 | from .ASPP import ASPP
 7 | from .convs import SeparableConv2d
 8 | from .modules import Lesion_Location_Mining
 9 | from .modules import Category_guided_Feature_Generation
10 | from .modules import Global_Prototypes_Generator
11 | 
12 | class DSI_Net(nn.Module):
13 |     def __init__(self, config, K=100):
14 |         super(DSI_Net, self).__init__()
15 |         self.backbone = None
16 |         self.backbone_layers = None
17 |         self.dropout = nn.Dropout(0.5)
18 |         self.upsample_sub_x2 = nn.UpsamplingBilinear2d(scale_factor=2)
19 |         self.upsample_sub_x4 = nn.UpsamplingBilinear2d(scale_factor=4)
20 |         self.shortcut_conv = nn.Sequential(nn.Conv2d(256, 48, 1, 1, padding=1//2, bias=True),
21 |                                            nn.BatchNorm2d(48),
22 |                                            nn.ReLU(inplace=True),
23 |                                            )
24 |         self.aspp = ASPP(dim_in=2048, dim_out=256, rate=16//16, bn_mom=0.99)
25 |         self.coarse_head = nn.Sequential(
26 |             nn.Conv2d(256+48, 256, 3, 1, padding=1, bias=True),
27 |             nn.BatchNorm2d(256),
28 |             nn.ReLU(inplace=True),
29 |             nn.Dropout(0.5),
30 |             nn.Conv2d(256, 256, 3, 1, padding=1, bias=True),
31 |             nn.BatchNorm2d(256),
32 |             nn.ReLU(inplace=True),
33 |             nn.Dropout(0.1),
34 |             nn.Conv2d(256, 2, kernel_size=1, stride=1, padding=0, bias=True)
35 |         )
36 | 
37 |         self.fine_head = nn.Sequential(
38 |             nn.Conv2d(256+64+48, 256, 3, 1, padding=1, bias=True),
39 |             nn.BatchNorm2d(256),
40 |             nn.ReLU(inplace=True),
41 |             nn.Dropout(0.5),
42 |             nn.Conv2d(256, 256, 3, 1, padding=1, bias=True),
43 |             nn.BatchNorm2d(256),
44 |             nn.ReLU(inplace=True),
45 |             nn.Dropout(0.1),
46 |             nn.Conv2d(256, 2, kernel_size=1, stride=1, padding=0, bias=True)
47 |         )
48 | 
49 |         self.cls_head = nn.Sequential(
50 |             SeparableConv2d(1024, 1536, 3, dilation=2, stride=1, padding=2, bias=False),
51 |             nn.BatchNorm2d(1536),
52 |             nn.ReLU(inplace=True),
53 |             SeparableConv2d(1536, 1536, 3, dilation=2, stride=1, padding=2, bias=False),
54 |             nn.BatchNorm2d(1536),
55 |             nn.ReLU(inplace=True),
56 |             SeparableConv2d(1536, 2048, 3, dilation=2, stride=1, padding=2, bias=False),
57 |             nn.BatchNorm2d(2048),
58 |             nn.ReLU(inplace=True))
59 |         self.LLM = Lesion_Location_Mining(config, 1024, K)
60 |         self.GPG = Global_Prototypes_Generator(2048, config.INTERMIDEATE_NUM)
61 |         self.CFG = Category_guided_Feature_Generation(256, config.INTERMIDEATE_NUM, config.EM_STEP)
62 |         self.avgpool = nn.AdaptiveAvgPool2d(1)
63 |         self.cls_predict = nn.Linear(2048, config.NUM_CLASSES_CLS, bias=False)
64 | 
65 |         for m in self.modules():
66 |             if isinstance(m, nn.Conv2d):
67 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
68 |             elif isinstance(m, nn.BatchNorm2d):
69 |                 nn.init.constant_(m.weight, 1)
70 |                 nn.init.constant_(m.bias, 0)
71 | 
72 |         self.backbone = xception.Xception(os=config.OS)
73 |         self.backbone_layers = self.backbone.get_layers()
74 | 
75 | 
76 |     def forward(self, x):
77 |         x = self.backbone(x)
78 | 
79 |         # shallow feature
80 |         layers = self.backbone.get_layers()
81 |         feature_shallow = self.shortcut_conv(layers[0])
82 |         feature_aspp = self.aspp(layers[-1])
83 | 
84 |         # coarse segmentation
85 |         feature_coarse = self.dropout(feature_aspp)
86 |         feature_coarse = self.upsample_sub_x2(feature_coarse)
87 |         feature_coarse = torch.cat([feature_coarse, feature_shallow], 1)
88 |         seg_coarse = self.coarse_head(feature_coarse)
89 | 
90 |         # classification
91 |         cls_feats = layers[-2]
92 |         b, c, h, w = cls_feats.size()
93 |         mask_coarse = torch.softmax(seg_coarse, dim=1)
94 |         mask_coarse = F.interpolate(mask_coarse, size=(h, w), mode="bilinear", align_corners=False)
95 | 
96 |         cls_feats = self.LLM(cls_feats, mask_coarse)
97 |         cls_feats = self.cls_head(cls_feats)
98 |         cls_out = self.avgpool(cls_feats)
99 |         cls_out = cls_out.view(b, -1)
100 |         cls_out = self.cls_predict(cls_out)
101 | 
102 |         # fine segmentation
103 |         global_prototypes = self.GPG(self.cls_predict.weight.detach(), cls_out.detach())
104 |         context = self.CFG(feature_aspp, mask_coarse, global_prototypes)
105 |         context = self.upsample_sub_x2(context)
106 |         feature_fine = self.dropout(feature_aspp)
107 |         feature_fine = self.upsample_sub_x2(feature_fine)
108 |         feature_fine = torch.cat([feature_fine, context, feature_shallow], 1)
109 |         seg_fine = self.fine_head(feature_fine)
110 | 
111 |         # final segmentation maps at input resolution
112 |         seg_coarse = self.upsample_sub_x4(seg_coarse)
113 |         seg_fine = self.upsample_sub_x4(seg_fine)
114 | 
115 |         return seg_coarse, seg_fine, cls_out
116 | 
117 | 
--------------------------------------------------------------------------------
/net/xception.py:
--------------------------------------------------------------------------------
 1 | """
 2 | Ported to pytorch thanks to [tstandley](https://github.com/tstandley/Xception-PyTorch)
 3 | @author: tstandley
 4 | Adapted by cadene
 5 | Creates an Xception Model as defined in:
 6 | Francois Chollet
 7 | Xception: Deep Learning with Depthwise Separable Convolutions
 8 | https://arxiv.org/pdf/1610.02357.pdf
 9 | These weights were ported from the Keras implementation. Achieves the following performance on the validation set:
10 | Loss:0.9173 Prec@1:78.892 Prec@5:94.292
11 | REMEMBER to set your image size to 3x299x299 for both test and validation
12 | normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
13 |                                  std=[0.5, 0.5, 0.5])
14 | The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
15 | """
16 | import math
17 | import torch
18 | import torch.nn as nn
19 | 
20 | bn_mom = 0.0003
21 | __all__ = ['xception']
22 | 
23 | 
24 | class SeparableConv2d(nn.Module):
25 |     def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False,activate_first=True,inplace=True):
26 |         super(SeparableConv2d,self).__init__()
27 |         self.relu0 = nn.ReLU(inplace=inplace)
28 |         self.depthwise = nn.Conv2d(in_channels,in_channels,kernel_size,stride,padding,dilation,groups=in_channels,bias=bias)
29 |         self.bn1 = nn.BatchNorm2d(in_channels)
30 |         self.relu1 = nn.ReLU(inplace=True)
31 |         self.pointwise = nn.Conv2d(in_channels,out_channels,1,1,0,1,1,bias=bias)
32 |         self.bn2 = nn.BatchNorm2d(out_channels)
33 |         self.relu2 = nn.ReLU(inplace=True)
34 |         self.activate_first = activate_first
35 |     def forward(self,x):
36 |         if self.activate_first:
37 |             x = self.relu0(x)
38 |         x = self.depthwise(x)
39 |         x = self.bn1(x)
40 |         if not self.activate_first:
41 |             x = self.relu1(x)
42 |         x = self.pointwise(x)
43 |         x = self.bn2(x)
44 |         if not self.activate_first:
45 |             x = self.relu2(x)
46 |         return x
47 | 
48 | 
49 | class Block(nn.Module):
50 |     def __init__(self,in_filters,out_filters,strides=1,atrous=None,grow_first=True,activate_first=True,inplace=True):
51 |         super(Block, self).__init__()
52 |         if atrous == None:
53 |             atrous = [1]*3
54 |         elif isinstance(atrous, int):
55 |             atrous_list = [atrous]*3
56 |             atrous = atrous_list
57 |         idx = 0
58 |         self.head_relu = True
59 |         if out_filters != in_filters or strides!=1:
60 |             self.skip = nn.Conv2d(in_filters,out_filters,1,stride=strides, bias=False)
61 |             self.skipbn = nn.BatchNorm2d(out_filters)
62 |             self.head_relu = False
63 |         else:
64 |             self.skip=None
65 | 
66 |         self.hook_layer = None
67 |         if grow_first:
68 |             filters = out_filters
69 |         else:
70 |             filters = in_filters
71 |         self.sepconv1 = SeparableConv2d(in_filters,filters,3,stride=1,padding=1*atrous[0],dilation=atrous[0],bias=False,activate_first=activate_first,inplace=self.head_relu)
72 |         self.sepconv2 = SeparableConv2d(filters,out_filters,3,stride=1,padding=1*atrous[1],dilation=atrous[1],bias=False,activate_first=activate_first)
73 |         self.sepconv3 = SeparableConv2d(out_filters,out_filters,3,stride=strides,padding=1*atrous[2],dilation=atrous[2],bias=False,activate_first=activate_first,inplace=inplace)
74 | 
75 |     def forward(self,inp):
76 | 
77 |         if self.skip is not None:
78 |             skip = self.skip(inp)
79 |             skip = self.skipbn(skip)
80 |         else:
81 |             skip = inp
82 | 
83 |         x = self.sepconv1(inp)
84 |         x = self.sepconv2(x)
85 |         self.hook_layer = x
86 |         x = self.sepconv3(x)
87 | 
88 |         x += skip
89 |         return x
90 | 
91 | 
92 | class Xception(nn.Module):
93 |     """
94 |     Xception optimized for the ImageNet dataset, as specified in
95 |     https://arxiv.org/pdf/1610.02357.pdf
96 |     """
97 |     def __init__(self, os):
98 |         """ Constructor
99 |         Args:
100 |             num_classes: number of classes
101 |         """
102 |         super(Xception, self).__init__()
103 | 
104 |         stride_list = None
105 |         if os == 8:
106 |             stride_list = [2,1,1]
107 |         elif os == 16:
108 |             stride_list = [2,2,1]
109 |         else:
110 |             raise ValueError('xception.py: output stride=%d is not supported.'%os)
111 |         self.conv1 = nn.Conv2d(3, 32, 3, 2, 1, bias=False)
112 |         self.bn1 = nn.BatchNorm2d(32)
113 |         self.relu = nn.ReLU(inplace=True)
114 | 
115 |         self.conv2 = nn.Conv2d(32,64,3,1,1,bias=False)
116 |         self.bn2 = nn.BatchNorm2d(64)
117 |         #do relu here
118 | 
119 |         self.block1=Block(64,128,2)
120 |         self.block2=Block(128,256,stride_list[0],inplace=False)
121 |         self.block3=Block(256,728,stride_list[1])
122 | 
123 |         rate = 16//os
124 |         self.block4=Block(728,728,1,atrous=rate)
125 |         self.block5=Block(728,728,1,atrous=rate)
126 |         self.block6=Block(728,728,1,atrous=rate)
127 |         self.block7=Block(728,728,1,atrous=rate)
128 | 
129 |         self.block8=Block(728,728,1,atrous=rate)
130 |         self.block9=Block(728,728,1,atrous=rate)
131 |         self.block10=Block(728,728,1,atrous=rate)
132 |         self.block11=Block(728,728,1,atrous=rate)
133 | 
134 |         self.block12=Block(728,728,1,atrous=rate)
135 |         self.block13=Block(728,728,1,atrous=rate)
136 |         self.block14=Block(728,728,1,atrous=rate)
137 |         self.block15=Block(728,728,1,atrous=rate)
138 | 
139 |         self.block16=Block(728,728,1,atrous=[1*rate,1*rate,1*rate])
140 |         self.block17=Block(728,728,1,atrous=[1*rate,1*rate,1*rate])
141 |         self.block18=Block(728,728,1,atrous=[1*rate,1*rate,1*rate])
142 |         self.block19=Block(728,728,1,atrous=[1*rate,1*rate,1*rate])
143 | 
144 |         self.block20=Block(728,1024,stride_list[2],atrous=rate,grow_first=False)
145 |         #self.block12=Block(728,1024,2,2,start_with_relu=True,grow_first=False)
146 | 
147 |         self.conv3 = SeparableConv2d(1024,1536,3,1,1*rate,dilation=rate,activate_first=False)
148 |         # self.bn3 = nn.BatchNorm2d(1536)
149 | 
150 |         self.conv4 = SeparableConv2d(1536,1536,3,1,1*rate,dilation=rate,activate_first=False)
151 |         # self.bn4 = nn.BatchNorm2d(1536)
152 | 
153 |         #do relu here
154 |         self.conv5 = SeparableConv2d(1536,2048,3,1,1*rate,dilation=rate,activate_first=False)
155 |         # self.bn5 = nn.BatchNorm2d(2048)
156 |         self.layers = []
157 | 
158 |         #------- init weights --------
159 |         for m in self.modules():
160 |             if isinstance(m, nn.Conv2d):
161 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
162 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
163 |             elif isinstance(m, nn.BatchNorm2d):
164 |                 m.weight.data.fill_(1)
165 |                 m.bias.data.zero_()
166 |         #-----------------------------
167 | 
168 |     def forward(self, input):
169 |         self.layers = []
170 |         x = self.conv1(input)
171 |         x = self.bn1(x)
172 |         x = self.relu(x)
173 |         #self.layers.append(x)
174 |         x = self.conv2(x)
175 |         x = self.bn2(x)
176 |         x = self.relu(x)
177 | 
178 |         x = self.block1(x)
179 |         x = self.block2(x)
180 |         self.layers.append(self.block2.hook_layer)
181 |         x = self.block3(x)
182 |         # self.layers.append(self.block3.hook_layer)
183 |         x = self.block4(x)
184 |         x = self.block5(x)
185 |         x = self.block6(x)
186 |         x = self.block7(x)
187 |         x = self.block8(x)
188 |         x = self.block9(x)
189 |         x = self.block10(x)
190 |         x = self.block11(x)
191 |         x = self.block12(x)
192 |         x = self.block13(x)
193 |         x = self.block14(x)
194 |         x = self.block15(x)
195 |         x = self.block16(x)
196 |         x = self.block17(x)
197 |         x = self.block18(x)
198 |         x = self.block19(x)
199 |         x = self.block20(x)
200 |         # self.layers.append(self.block20.hook_layer)
201 |         self.layers.append(x)
202 |         x = self.conv3(x)
203 |         # x = self.bn3(x)
204 |         # x = self.relu(x)
205 | 
206 |         x = self.conv4(x)
207 |         # x = self.bn4(x)
208 |         # x = self.relu(x)
209 | 
210 |         x = self.conv5(x)
211 |         # x = self.bn5(x)
212 |         # x = self.relu(x)
213 |         self.layers.append(x)
214 | 
215 |         return x
216 | 
217 |     def get_layers(self):
218 |         return self.layers
219 | 
--------------------------------------------------------------------------------
/dataset/preprocess.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | """
 3 | Created on Thu Oct 15 16:36:36 2020
 4 | 
 5 | @author: meiluzhu
 6 | """
 7 | 
 8 | import os
 9 | import numpy as np
10 | import pickle
11 | import cv2
12 | 
13 | m_type = ['vascularlesions','inflammatory','normal']
14 | patients = []
15 | images = []
16 | labels = []
17 | res = 288
18 | 
19 | #### CAD-CAP
20 | base = 'C:\\ZML\\Dataset\\WCE\\CAD-CAP'
21 | save_base = 'C:\\ZML\\Dataset\\WCE\\temp\\CAD-CAP'
22 | 
23 | file = os.listdir(os.path.join(base, m_type[1]))
24 | for f in file:
25 |     filename = f.split('.')
26 |     if os.path.exists(os.path.join(base,m_type[1],filename[0])+'_a'+'.jpg'):
27 |         img_dir = m_type[1]+'\\'+ f
28 |         mask_dir = m_type[1]+'\\'+ filename[0]+'_a.jpg'
29 |         print(img_dir)
30 |         image = cv2.imread(os.path.join(base, img_dir))
31 |         mask = cv2.imread(os.path.join(base, mask_dir),cv2.IMREAD_GRAYSCALE)
32 |         h,w,c = image.shape
33 |         h_top = 0
34 |         h_bot = h
35 |         w_lef = 0
36 |         w_rig = w
37 |         if h == 576 or h == 704:
38 |             h_top = 32
39 |             h_bot = h-32
40 |             w_lef = 32
41 |             w_rig = w-32
42 |         post_img = image[h_top:h_bot,w_lef:w_rig,:]
43 |         post_img[0:45,0:15,:] = 0
44 |         post_mask = mask[h_top:h_bot,w_lef:w_rig]
45 |         post_img = cv2.resize(post_img, (res,res), interpolation=cv2.INTER_NEAREST)
46 |         post_mask = cv2.resize(post_mask, (res,res), interpolation=cv2.INTER_NEAREST)
47 |         new_dir = os.path.join(save_base,img_dir)
48 |         cv2.imwrite(new_dir, post_img)
49 |         new_dir = os.path.join(save_base,mask_dir)
50 |         cv2.imwrite(new_dir, post_mask)
51 |         post_mask[post_mask>0] = 1
52 |         patient = {'image':'CAD-CAP/'+img_dir.replace('\\','/' ), 'mask': 'CAD-CAP/'+mask_dir.replace('\\','/' ),'label': 2}
53 |         patients.append(patient)
54 | 
55 | file = os.listdir(os.path.join(base, m_type[0]))
56 | for f in file:
57 |     filename = f.split('.')
58 |     if os.path.exists(os.path.join(base,m_type[0],filename[0])+'_a'+'.jpg'):
59 |         img_dir = m_type[0]+'/'+ f
60 |         mask_dir = m_type[0]+'/'+ filename[0]+'_a.jpg'
filename[0]+'_a.jpg' 61 | print(img_dir) 62 | image = cv2.imread(os.path.join(base, img_dir)) 63 | mask = cv2.imread(os.path.join(base, mask_dir),cv2.IMREAD_GRAYSCALE) 64 | ###### 65 | h,w,c = image.shape 66 | h_top = 0 67 | h_bot = h 68 | w_lef = 0 69 | w_rig = w 70 | if h == 576 or h == 704: 71 | h_top = 32 72 | h_bot = h-32 73 | w_lef = 32 74 | w_rig = w-32 75 | post_img = image[h_top:h_bot,w_lef:w_rig,:] 76 | h,w,c = post_img.shape 77 | if h>600: 78 | post_img[0:10,0:139,:] = 0 79 | post_img[h-2:h,0:191,:] = 0 80 | post_img[h-5:h,0:150,:] = 0 81 | else: 82 | post_img[0:10,0:115,:] = 0 83 | post_img[h-2:h,0:133,:] = 0 84 | post_img[h-5:h,0:110,:] = 0 85 | post_img[0:60,0:21,:] = 0 86 | post_mask = mask[h_top:h_bot,w_lef:w_rig] 87 | post_img = cv2.resize(post_img, (res,res), interpolation=cv2.INTER_NEAREST) 88 | post_mask = cv2.resize(post_mask, (res,res), interpolation=cv2.INTER_NEAREST) 89 | new_dir = os.path.join(save_base,img_dir) 90 | cv2.imwrite(new_dir, post_img) 91 | new_dir = os.path.join(save_base,mask_dir) 92 | cv2.imwrite(new_dir, post_mask) 93 | post_mask[post_mask>0] = 1 94 | patient = {'image':'CAD-CAP/'+img_dir.replace('\\','/' ), 'mask': 'CAD-CAP/'+mask_dir.replace('\\','/' ),'label': 1} 95 | patients.append(patient) 96 | 97 | file = os.listdir(os.path.join(base, m_type[2])) 98 | for f in file: 99 | img_dir = m_type[2]+'/'+ f 100 | print(img_dir) 101 | image = cv2.imread(os.path.join(base, img_dir)) 102 | h,w,c = image.shape 103 | h_top = 0 104 | h_bot = h 105 | w_lef = 0 106 | w_rig = w 107 | if h == 576 or h == 704: 108 | h_top = 32 109 | h_bot = h-32 110 | w_lef = 32 111 | w_rig = w-32 112 | post_img = image[h_top:h_bot,w_lef:w_rig,:] 113 | h,w,c = post_img.shape 114 | post_img[0:45,0:15,:] = 0 115 | post_img = cv2.resize(post_img, (res,res), interpolation=cv2.INTER_NEAREST) 116 | img_dir = img_dir.replace(' ', '_') 117 | new_dir = os.path.join(save_base,img_dir) 118 | cv2.imwrite(new_dir, post_img) 119 | post_mask = np.zeros((res,res), dtype = np.uint8) 120 | filename = img_dir.split('.') 121 | mask_dir = filename[0]+'_a.jpg' 122 | new_dir = os.path.join(save_base,mask_dir) 123 | cv2.imwrite(new_dir, post_mask) 124 | post_mask[post_mask>0] = 1 125 | patient = {'image':'CAD-CAP/'+img_dir.replace('\\','/' ), 'mask': 'CAD-CAP/'+mask_dir.replace('\\','/' ),'label': 0} 126 | patients.append(patient) 127 | 128 | ####KID 129 | base = 'C:\\ZML\\Dataset\\WCE\\KID' 130 | save_base = 'C:\\ZML\\Dataset\\WCE\\temp\\KID' 131 | era = 4 132 | file = os.listdir(os.path.join(base, m_type[1])) 133 | for f in file: 134 | filename = f.split('.') 135 | if os.path.exists(os.path.join(base,m_type[1],filename[0])+'m'+'.png'): 136 | img_dir = m_type[1]+'\\'+ f 137 | mask_dir = m_type[1]+'\\'+ filename[0]+'m.png' 138 | print(img_dir) 139 | image = cv2.imread(os.path.join(base, img_dir)) 140 | mask = cv2.imread(os.path.join(base, mask_dir),cv2.IMREAD_GRAYSCALE) 141 | h,w,c = image.shape 142 | h_top = 20 + era 143 | h_bot = h - 20 - era 144 | w_lef = 20 + era 145 | w_rig = w - 20 - era 146 | post_img = image[h_top:h_bot,w_lef:w_rig,:] 147 | post_mask = mask[h_top:h_bot,w_lef:w_rig] 148 | post_img = cv2.resize(post_img, (res,res), interpolation=cv2.INTER_NEAREST) 149 | post_mask = cv2.resize(post_mask, (res,res), interpolation=cv2.INTER_NEAREST) 150 | new_dir = os.path.join(save_base,img_dir) 151 | cv2.imwrite(new_dir, post_img) 152 | new_dir = os.path.join(save_base,mask_dir) 153 | cv2.imwrite(new_dir, post_mask) 154 | post_mask[post_mask>0] = 1 155 | patient = 
{'image':'KID/'+img_dir.replace('\\','/' ), 'mask': 'KID/'+mask_dir.replace('\\','/' ),'label': 2} 156 | patients.append(patient) 157 | 158 | file = os.listdir(os.path.join(base, m_type[0])) 159 | for f in file: 160 | filename = f.split('.') 161 | if os.path.exists(os.path.join(base,m_type[0],filename[0])+'m'+'.png'): 162 | img_dir = m_type[0]+'/'+ f 163 | mask_dir = m_type[0]+'/'+ filename[0]+'m.png' 164 | print(img_dir) 165 | image = cv2.imread(os.path.join(base, img_dir)) 166 | mask = cv2.imread(os.path.join(base, mask_dir),cv2.IMREAD_GRAYSCALE) 167 | h,w,c = image.shape 168 | h_top = 20 + era 169 | h_bot = h - 20 - era 170 | w_lef = 20 + era 171 | w_rig = w - 20 - era 172 | post_img = image[h_top:h_bot,w_lef:w_rig,:] 173 | post_mask = mask[h_top:h_bot,w_lef:w_rig] 174 | post_img = cv2.resize(post_img, (res,res), interpolation=cv2.INTER_NEAREST) 175 | post_mask = cv2.resize(post_mask, (res,res), interpolation=cv2.INTER_NEAREST) 176 | new_dir = os.path.join(save_base,img_dir) 177 | cv2.imwrite(new_dir, post_img) 178 | new_dir = os.path.join(save_base,mask_dir) 179 | cv2.imwrite(new_dir, post_mask) 180 | post_mask[post_mask>0] = 1 181 | patient = {'image':'KID/'+img_dir.replace('\\','/' ), 'mask': 'KID/'+mask_dir.replace('\\','/' ),'label': 1} 182 | patients.append(patient) 183 | 184 | m_type[2] = 'normal-small-bowel' 185 | file = os.listdir(os.path.join(base, m_type[2])) 186 | for f in file: 187 | img_dir = m_type[2]+'/'+ f 188 | print(img_dir) 189 | image = cv2.imread(os.path.join(base, img_dir)) 190 | h,w,c = image.shape 191 | h_top = 20 + era 192 | h_bot = h - 20 - era 193 | w_lef = 20 + era 194 | w_rig = w - 20 - era 195 | post_img = image[h_top:h_bot,w_lef:w_rig,:] 196 | post_img = cv2.resize(post_img, (res,res), interpolation=cv2.INTER_NEAREST) 197 | img_dir = img_dir.replace(' ', '_') 198 | img_dir = img_dir.replace('normal-small-bowel', 'normal') 199 | new_dir = os.path.join(save_base,img_dir) 200 | cv2.imwrite(new_dir, post_img) 201 | post_mask = np.zeros((res,res), dtype = np.uint8) 202 | filename = img_dir.split('.') 203 | mask_dir = filename[0]+'m.png' 204 | new_dir = os.path.join(save_base,mask_dir) 205 | cv2.imwrite(new_dir, post_mask) 206 | post_mask[post_mask>0] = 1 207 | patient = {'image':'KID/'+img_dir.replace('\\','/' ), 'mask': 'KID/'+mask_dir.replace('\\','/' ),'label': 0} 208 | patients.append(patient) 209 | 210 | np.random.shuffle(patients) 211 | trainset = patients[0:2470] 212 | testset = patients[2470:] 213 | dataset = {'train': trainset, 'test': testset} 214 | path = os.path.join('C:\\ZML\\Dataset\\WCE\\temp', 'WCE_Dataset_larger_Fold1.pkl') 215 | if os.path.exists(path): 216 | os.remove(path) 217 | with open(path,'wb') as f: 218 | pickle.dump(dataset, f) 219 | 220 | 221 | -------------------------------------------------------------------------------- /net/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .gumbel import Gumbel 5 | 6 | class Category_guided_Feature_Generation(nn.Module): 7 | 8 | def __init__(self, 9 | in_channels = 256, 10 | out_channels = 64, EM_STEP = 3): 11 | super(Category_guided_Feature_Generation, self).__init__() 12 | self.out_channels = out_channels 13 | self.EM_STEP = EM_STEP 14 | self.conv0 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1), 15 | nn.BatchNorm2d(out_channels), 16 | nn.ReLU(True), 17 | nn.Dropout2d(0.2, False), 18 | nn.Conv2d(out_channels, out_channels, 1), 19 | nn.BatchNorm2d(out_channels), 20 | nn.ReLU(True), 21 | 
nn.Dropout2d(0.1, False), 22 | ) 23 | 24 | self.conv1 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1), 25 | nn.BatchNorm2d(out_channels), 26 | nn.ReLU(True), 27 | nn.Dropout2d(0.2, False), 28 | nn.Conv2d(out_channels, out_channels, 1), 29 | nn.BatchNorm2d(out_channels), 30 | nn.ReLU(True), 31 | nn.Dropout2d(0.1, False), 32 | ) 33 | 34 | self.conv2 = nn.Sequential(nn.Conv2d(out_channels*2, out_channels, 1), 35 | nn.BatchNorm2d(out_channels), 36 | nn.ReLU(True), 37 | nn.Dropout2d(0.2, False) 38 | ) 39 | 40 | def forward(self, x, coarse_mask, global_prototypes, regular = 0.5): 41 | 42 | b, h, w = x.size(0), x.size(2), x.size(3) 43 | classes_num = coarse_mask.size(1) 44 | feats = self.conv0(x) 45 | pseudo_mask = coarse_mask.view(b, classes_num, -1) 46 | feats = feats.view(b, self.out_channels, -1).permute(0, 2, 1) 47 | # EM refinement: alternately estimate class prototypes (M-step) and re-assign pixels to them (E-step) 48 | T = self.EM_STEP 49 | for t in range(T): 50 | prototypes = torch.bmm(pseudo_mask, feats) 51 | prototypes = prototypes / (1e-8 + prototypes.norm(dim=1, keepdim=True)) 52 | attention = torch.bmm(prototypes, feats.permute(0, 2, 1)) 53 | attention = (self.out_channels**-regular) * attention 54 | pseudo_mask = torch.softmax(attention, dim=1) 55 | pseudo_mask = pseudo_mask / (1e-8 + pseudo_mask.sum(dim=1, keepdim=True)) 56 | context_l = torch.bmm(prototypes.permute(0, 2, 1), pseudo_mask).view(b, self.out_channels, h, w) 57 | 58 | feats = self.conv1(x) 59 | feats = feats.view(b, self.out_channels, -1).permute(0, 2, 1) 60 | global_prototypes = global_prototypes / (1e-8 + global_prototypes.norm(dim=1, keepdim=True)) 61 | 62 | # EM refinement over the image-level global prototypes 63 | T = self.EM_STEP 64 | for t in range(T): 65 | attention = torch.bmm(global_prototypes, feats.permute(0, 2, 1)) 66 | attention = (self.out_channels**-regular) * attention 67 | pseudo_mask = torch.softmax(attention, dim=1) 68 | pseudo_mask = pseudo_mask / (1e-8 + pseudo_mask.sum(dim=1, keepdim=True)) 69 | global_prototypes = torch.bmm(pseudo_mask, feats) 70 | global_prototypes = global_prototypes / (1e-8 + global_prototypes.norm(dim=1, keepdim=True)) 71 | context_g = torch.bmm(global_prototypes.permute(0, 2, 1), pseudo_mask).view(b, self.out_channels, h, w) # b, out_channels, h, w 72 | 73 | context = torch.cat((context_l, context_g), dim = 1) 74 | context = self.conv2(context) 75 | 76 | return context 77 | 78 | 79 | class Global_Prototypes_Generator(nn.Module): 80 | 81 | def __init__(self, 82 | in_channels = 2048, 83 | out_channels = 64): 84 | super(Global_Prototypes_Generator, self).__init__() 85 | self.out_channels = out_channels 86 | self.conv1 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1), 87 | nn.BatchNorm2d(out_channels), 88 | nn.ReLU(True), 89 | nn.Conv2d(out_channels, out_channels, 1) 90 | ) 91 | 92 | def forward(self, prototypes, category): 93 | classes_num, c = prototypes.size(0), prototypes.size(1) 94 | prototypes = prototypes.view(classes_num, c, 1, 1) 95 | prototypes = self.conv1(prototypes).view(classes_num, self.out_channels) 96 | category = torch.softmax(category, dim = 1) 97 | b = category.size(0) 98 | bg_prototypes = prototypes[0] 99 | bg_prototypes = bg_prototypes.repeat(b, 1, 1) 100 | fg_prototypes = category[:,1:].view(b, classes_num-1, 1) * prototypes[1:] 101 | prototypes = torch.cat((bg_prototypes, fg_prototypes), dim = 1) 102 | 103 | return prototypes 104 | 105 | 106 | 107 | 108 | class Binary_Gate_Unit(nn.Module): 109 | 110 | def __init__(self, config, in_channels = 1024, k = 100): 111 | super(Binary_Gate_Unit, self).__init__() 112 | self.in_channels = in_channels 113 | self.k = k 114 | self.conv = nn.Sequential( 115 | 
nn.Conv2d(in_channels=self.in_channels, out_channels=1, kernel_size=1, stride=1, padding=0, bias = False), 116 | nn.ReLU(inplace=True) 117 | ) 118 | self.fc1 = nn.Linear(k, (k + 1) // 2) # ceil(k/2) without a tensor round-trip 119 | self.fc2 = nn.Linear((k + 1) // 2, k) 120 | 121 | self.gumbel = Gumbel(config) 122 | 123 | def forward(self,topk_prototypes): 124 | 125 | b = topk_prototypes.size(0) 126 | proto_weights = self.conv(topk_prototypes) # meta learner 127 | proto_weights = proto_weights.view(b, -1) 128 | proto_weights = self.fc1(torch.relu(proto_weights)) 129 | proto_weights = self.fc2(proto_weights) # b, k 130 | proto_weights = self.gumbel(proto_weights) 131 | proto_weights = proto_weights.view(b, 1, self.k, 1) 132 | 133 | return proto_weights 134 | 135 | 136 | class Lesion_Location_Mining(nn.Module): 137 | 138 | def __init__(self, config, in_channels = 1024, k = 100): 139 | super(Lesion_Location_Mining, self).__init__() 140 | self.k = k 141 | self.BGU_fore = Binary_Gate_Unit(config, in_channels = in_channels, k = k) 142 | self.BGU_back = Binary_Gate_Unit(config, in_channels = in_channels, k = k) 143 | 144 | def forward(self, feats, soft_mask): 145 | 146 | b,c,h,w = feats.size() 147 | hard_mask = torch.max(soft_mask, dim = 1, keepdim = True)[1] # b, 1, h, w 148 | background_hard_mask = (hard_mask == 0).float() 149 | foreground_hard_mask = (hard_mask == 1).float() 150 | assert torch.sum(hard_mask == 2) == 0, 'Error in Lesion_Location_Mining_Module' 151 | background_soft_mask, foreground_soft_mask = soft_mask.split(1, dim = 1) #b, 1, h, w 152 | foreground_feats = feats * foreground_hard_mask # b, c, h, w 153 | background_feats = feats * background_hard_mask # b, c, h, w 154 | feats = feats.view(b, c, -1) # b, c, hw 155 | 156 | #****** foreground-->background **********# 157 | # key generator: take the k most confident foreground pixels as prototypes 158 | foreground_soft_mask = foreground_soft_mask.view(b, 1, -1) 159 | topk_idx = torch.topk(foreground_soft_mask, self.k, dim = -1, largest=True)[1] 160 | topk_prototypes = [] 161 | for i in range(b): 162 | feats_temp = feats[i,:,topk_idx[i]] # c, k 163 | topk_prototypes.append(feats_temp) 164 | topk_prototypes = torch.stack(topk_prototypes) # b, c, k 165 | topk_prototypes = topk_prototypes.view(b, c, self.k, 1) 166 | proto_weights = self.BGU_fore(topk_prototypes) 167 | topk_prototypes = topk_prototypes * proto_weights # b, c, k, 1 168 | 169 | # similarity of the gated foreground prototypes (b, k, c) to every background pixel (b, c, hw) 170 | background_feats = background_feats.view(b, c, -1) # b, c, hw 171 | topk_prototypes = topk_prototypes.view(b, c, -1).permute(0, 2, 1) # b, k, c 172 | fore_attention_map = torch.matmul(topk_prototypes, background_feats) # b, k, hw 173 | 174 | # normalize to cosine similarity, then ReLU 175 | norm_prototypes = torch.norm(topk_prototypes, dim = -1, keepdim=True) # b, k, 1 176 | norm_background_feats = torch.norm(background_feats, dim = 1, keepdim=True) # b, 1, hw 177 | norm = torch.bmm(norm_prototypes, norm_background_feats) # b, k, hw 178 | fore_attention_map = fore_attention_map / (norm + 1e-8) 179 | fore_attention_map = torch.relu(fore_attention_map) 180 | fore_attention_map = fore_attention_map.view(b, self.k, h, w) 181 | fore_attention_map = torch.max(fore_attention_map, dim = 1, keepdim = True)[0] 182 | 183 | #****** background-->foreground **********# 184 | # key generator: take the k most confident background pixels as prototypes 185 | background_soft_mask = background_soft_mask.view(b, 1, -1) 186 | topk_idx = torch.topk(background_soft_mask, self.k, dim = -1, largest=True)[1] 187 | topk_prototypes = [] 188 | for i in range(b): 189 | feats_temp = feats[i,:,topk_idx[i]] # c, k 190 | 
topk_prototypes.append(feats_temp) 191 | topk_prototypes = torch.stack(topk_prototypes) # b, c, k 192 | topk_prototypes = topk_prototypes.view(b, c, self.k, 1) 193 | proto_weights = self.BGU_back(topk_prototypes) 194 | topk_prototypes = topk_prototypes * proto_weights # b, c, k, 1 195 | 196 | # similarity of the gated background prototypes (b, k, c) to every foreground pixel (b, c, hw) 197 | foreground_feats = foreground_feats.view(b, c, -1) # b, c, hw 198 | topk_prototypes = topk_prototypes.view(b, c, -1).permute(0, 2, 1) # b, k, c 199 | back_attention_map = torch.matmul(topk_prototypes, foreground_feats) # b, k, hw 200 | # normalize to cosine similarity, then ReLU 201 | norm_prototypes = torch.norm(topk_prototypes, dim = -1, keepdim=True) # b, k, 1 202 | norm_foreground_feats = torch.norm(foreground_feats, dim = 1, keepdim=True) # b, 1, hw 203 | norm = torch.bmm(norm_prototypes, norm_foreground_feats) # b, k, hw 204 | back_attention_map = back_attention_map / (norm + 1e-8) 205 | back_attention_map = torch.relu(back_attention_map) 206 | back_attention_map = back_attention_map.view(b, self.k, h, w) 207 | back_attention_map = torch.max(back_attention_map, dim = 1, keepdim = True)[0] 208 | 209 | # merging: amplify regions the foreground prototypes support, suppress regions the background prototypes explain 210 | feats = feats.view(b, c, h, w) 211 | foreground_soft_mask = foreground_soft_mask.view(b, 1, h, w) 212 | feats = feats + feats * (foreground_soft_mask - back_attention_map + fore_attention_map) 213 | 214 | return feats # b, c, h, w 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | -------------------------------------------------------------------------------- /train_DSI_Net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.backends.cudnn as cudnn 5 | import os 6 | from sklearn import metrics 7 | from sklearn.metrics import accuracy_score 8 | from net.models import DSI_Net 9 | from net.loss import Task_Interaction_Loss, Dice_Loss 10 | from dataset.my_datasets import CADCAPDataset 11 | from torch.utils import data 12 | from apex import amp 13 | from utils.logger import print_f 14 | import time 15 | import config 16 | import argparse 17 | from visualization.utils import show_seg_results, draw_curves 18 | 19 | #https://drive.google.com/file/d/12RjjEKM4nXtskHSJkWMdJ7S5PeaVFow3/view?usp=sharing 20 | model_urls = {'deeplabv3plus_xception': 'data/pre_model/deeplabv3plus_xception_VOC2012_epoch46_all.pth'} 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--image_list', default='/home/meiluzhu2/data/WCE/WCE_Dataset_larger_Fold1.pkl', type=str, help='image list pkl') 24 | parser.add_argument('--gpus', default='7', type=str, help='gpus') 25 | parser.add_argument('--K', default=100, type=int, help='number of top-K seeds selected in lesion location mining') 26 | parser.add_argument('--alpha', default=0.05, type=float, help='the weight of interaction loss') 27 | args = parser.parse_args() 28 | 29 | def lr_poly(base_lr, cur_iter, max_iter, power): 30 | return base_lr * ((1 - float(cur_iter) / max_iter) ** power) 31 | 32 | def adjust_learning_rate(optimizer, i_iter): 33 | lr = lr_poly(config.LEARNING_RATE, i_iter, config.STEPS, config.POWER) 34 | optimizer.param_groups[0]['lr'] = lr 35 | return lr 36 | 37 | def test(valloader, model, epoch, path = None, verbose = False): 38 | # validation 39 | #cls 40 | pro_score_crop = [] 41 | label_val_crop = [] 42 | 43 | # refined seg 44 | seg_dice = [] 45 | seg_sen = [] 46 | seg_spe = [] 47 | seg_acc = [] 48 | seg_jac_score = [] 49 | 50 | for index, batch in enumerate(valloader): 51 | data, masks, label, name = batch 52 | data = 
data.cuda() 53 | label = label.cuda() 54 | mask = masks[0].data.numpy() 55 | val_mask = np.int64(mask > 0) 56 | 57 | model.eval() 58 | with torch.no_grad(): 59 | pred_seg_coarse, pred_seg_fine, pred_cls = model(data) 60 | 61 | #cls 62 | pro_score_crop.append(torch.softmax(pred_cls[0], dim=0).cpu().data.numpy()) 63 | label_val_crop.append(label[0].cpu().data.numpy()) 64 | 65 | #seg 66 | y_true_f = val_mask.reshape(val_mask.shape[0]*val_mask.shape[1], order='F') 67 | if np.sum(y_true_f) != 0 and label[0].cpu().data.numpy() != 0: 68 | pred_seg = torch.softmax(pred_seg_fine, dim=1).cpu().data.numpy() 69 | pred_arg = np.argmax(pred_seg[0], axis=0) 70 | y_pred_f = pred_arg.reshape(pred_arg.shape[0]*pred_arg.shape[1], order='F') 71 | intersection = float(np.sum(y_true_f * y_pred_f)) 72 | seg_dice.append((2. * intersection) / (np.sum(y_true_f) + np.sum(y_pred_f))) 73 | seg_sen.append(intersection / np.sum(y_true_f)) 74 | intersection0 = float(np.sum((1 - y_true_f) * (1 - y_pred_f))) 75 | seg_spe.append(intersection0 / np.sum(1 - y_true_f)) 76 | seg_acc.append(accuracy_score(y_true_f, y_pred_f)) 77 | seg_jac_score.append(intersection / (np.sum(y_true_f) + np.sum(y_pred_f) - intersection)) 78 | 79 | if verbose and epoch == config.EPOCH-1: 80 | show_seg_results(data[0].cpu().data.numpy().transpose(1, 2, 0), mask, pred_arg, path, name[0]) 81 | #cls 82 | pro_score_crop = np.array(pro_score_crop) 83 | label_val_crop = np.array(label_val_crop) 84 | binary_score = np.eye(3)[np.argmax(np.array(pro_score_crop), axis=-1)] 85 | label_val = np.eye(3)[np.int64(np.array(label_val_crop))] 86 | preds = np.argmax(np.array(pro_score_crop), axis=-1) 87 | CK = metrics.cohen_kappa_score(label_val_crop, preds) 88 | OA = metrics.accuracy_score(label_val_crop, preds) 89 | EREC = metrics.recall_score(label_val, binary_score, average=None) 90 | 91 | result = {} 92 | result['seg'] = [np.array(seg_acc), np.array(seg_dice), np.array(seg_sen), np.array(seg_spe), np.array(seg_jac_score)] 93 | result['cls'] = [CK, OA, EREC] 94 | return result 95 | 96 | 97 | def main(): 98 | """Create the network and start the training.""" 99 | 100 | cudnn.enabled = True 101 | cudnn.benchmark = True 102 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus 103 | 104 | ############# Create mask-guided classification network. 
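# DSI_Net (net/models.py) couples the two tasks: a single forward pass returns a coarse segmentation map, a refined segmentation map, and classification logits, and all three outputs are supervised jointly in the training loop below.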
105 | model = DSI_Net(config, K = args.K) 106 | optimizer = torch.optim.Adam(model.parameters(), lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY) 107 | model.cuda() 108 | if config.FP16 is True: 109 | model, optimizer = amp.initialize(model, optimizer, opt_level="O1") 110 | model = torch.nn.DataParallel(model) 111 | 112 | ############# Load pretrained weights 113 | pretrained_dict = torch.load(model_urls['deeplabv3plus_xception']) 114 | net_dict = model.state_dict() 115 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if (k in net_dict) and (v.shape == net_dict[k].shape)} 116 | net_dict.update(pretrained_dict) 117 | model.load_state_dict(net_dict) 118 | print(len(net_dict)) # total parameter tensors in the model 119 | print(len(pretrained_dict)) # tensors actually matched and loaded from the checkpoint 120 | model.train() 121 | model.float() 122 | ce_loss = nn.CrossEntropyLoss() 123 | dice_loss = Dice_Loss() 124 | task_interaction_loss = Task_Interaction_Loss() 125 | 126 | ############# Load training and validation data 127 | trainloader = data.DataLoader(CADCAPDataset(config.DATA_ROOT, args.image_list, config.SIZE, data_type='train', mode = 'train'), batch_size=config.BATCH_SIZE, shuffle=True, 128 | num_workers=config.NUM_WORKERS, pin_memory=True, drop_last = config.DROP_LAST) 129 | testloader = data.DataLoader(CADCAPDataset(config.DATA_ROOT, args.image_list,config.SIZE, data_type = 'test', mode='test'), 130 | batch_size=1, shuffle=False, num_workers=config.NUM_WORKERS, pin_memory=True) 131 | train_testloader = data.DataLoader(CADCAPDataset(config.DATA_ROOT, args.image_list,config.SIZE, data_type = 'train', mode = 'test'), 132 | batch_size=1, shuffle=False, num_workers=config.NUM_WORKERS, pin_memory=True) 133 | 134 | 135 | if not os.path.isdir(config.SAVE_PATH): 136 | os.mkdir(config.SAVE_PATH) 137 | if not os.path.isdir(config.SAVE_PATH+'Seg_results/'): 138 | os.mkdir(config.SAVE_PATH+'Seg_results/') 139 | if not os.path.isdir(config.LOG_PATH): 140 | os.mkdir(config.LOG_PATH) 141 | 142 | f_path = config.LOG_PATH + 'training_output.log' 143 | logfile = open(f_path, 'a') 144 | 145 | print_f(os.getcwd(), f=logfile) 146 | print_f('Device: {}'.format(args.gpus), f=logfile) 147 | print_f('==={}==='.format(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())), f=logfile) 148 | print_f('===Setting===', f=logfile) 149 | print_f(' Data_list: {}'.format(args.image_list), f=logfile) 150 | print_f(' K: {}'.format(args.K), f=logfile) 151 | print_f(' Loss_weight: {}'.format(args.alpha), f=logfile) 152 | print_f(' LR: {}'.format(config.LEARNING_RATE), f=logfile) 153 | 154 | OA_bulk_train = [] 155 | CK_bulk_train = [] 156 | DI_bulk_train = [] 157 | JA_bulk_train = [] 158 | SE_bulk_train = [] 159 | 160 | OA_bulk_test = [] 161 | CK_bulk_test = [] 162 | DI_bulk_test = [] 163 | JA_bulk_test = [] 164 | SE_bulk_test = [] 165 | 166 | 167 | for epoch in range(config.EPOCH): 168 | #cls 169 | cls_train_loss = [] 170 | seg_train_loss = [] 171 | train_inter_loss = [] 172 | ############# Start the training 173 | for i_iter, batch in enumerate(trainloader): 174 | step = (config.TRAIN_NUM/config.BATCH_SIZE)*epoch+i_iter 175 | images, masks, labels, name = batch 176 | images = images.cuda() 177 | labels = labels.cuda().long() 178 | masks = masks.cuda().squeeze(1) 179 | optimizer.zero_grad() 180 | lr = adjust_learning_rate(optimizer, step) 181 | model.train() 182 | preds_seg_coarse, preds_seg_fine, preds_cls = model(images) 183 | cls_loss = ce_loss(preds_cls, labels) 184 | seg_loss_fine = dice_loss(preds_seg_fine, masks) 185 | seg_loss_coarse = dice_loss(preds_seg_coarse, masks) 186 | 
inter_loss = task_interaction_loss(preds_cls, preds_seg_fine, labels) 187 | loss = cls_loss + seg_loss_fine + seg_loss_coarse + args.alpha * inter_loss 188 | 189 | if config.FP16 is True: 190 | with amp.scale_loss(loss, optimizer) as scaled_loss: 191 | scaled_loss.backward() 192 | else: 193 | loss.backward() 194 | optimizer.step() 195 | #cls 196 | cls_train_loss.append(cls_loss.cpu().data.numpy()) 197 | seg_train_loss.append(seg_loss_fine.cpu().data.numpy()) 198 | train_inter_loss.append(inter_loss.cpu().data.numpy()) 199 | 200 | ############ train log 201 | line = "Train-Epoch [%d/%d] [All]: Seg_loss = %.6f, Class_loss = %.6f, Inter_loss = %.6f, LR = %0.9f\n" % (epoch, config.EPOCH, np.nanmean(seg_train_loss), np.nanmean(cls_train_loss), np.nanmean(train_inter_loss), lr) 202 | print_f(line, f=logfile) 203 | 204 | result = test(train_testloader, model, epoch, verbose=False) 205 | #cls 206 | [CK, OA, EREC] = result['cls'] 207 | OA_bulk_train.append(OA) 208 | CK_bulk_train.append(CK) 209 | 210 | # seg 211 | [AC, DI, SE, SP, JA] = result['seg'] 212 | JA_bulk_train.append(np.nanmean(JA)) 213 | DI_bulk_train.append(np.nanmean(DI)) 214 | SE_bulk_train.append(np.nanmean(SE)) 215 | 216 | ############# Start the test 217 | result = test(testloader, model, epoch, config.SAVE_PATH+'Seg_results/' , verbose = config.VERBOSE) 218 | #cls 219 | [CK, OA, EREC] = result['cls'] 220 | line = "Test -Epoch [%d/%d] [Cls]: CK-Score = %f, OA = %f, Rec-N = %f, Rec-V = %f, Rec-I=%f \n" % (epoch, config.EPOCH, CK, OA, EREC[0],EREC[1],EREC[2] ) 221 | print_f(line, f=logfile) 222 | OA_bulk_test.append(OA) 223 | CK_bulk_test.append(CK) 224 | 225 | # seg 226 | [AC, DI, SE, SP, JA] = result['seg'] 227 | line = "Test -Epoch [%d/%d] [Seg]: AC = %f, DI = %f, SE = %f, SP = %f, JA = %f \n" % (epoch, config.EPOCH, np.nanmean(AC), np.nanmean(DI), np.nanmean(SE), np.nanmean(SP), np.nanmean(JA)) 228 | print_f(line, f=logfile) 229 | 230 | JA_bulk_test.append(np.nanmean(JA)) 231 | DI_bulk_test.append(np.nanmean(DI)) 232 | SE_bulk_test.append(np.nanmean(SE)) 233 | 234 | ############# Plot val curve 235 | filename = os.path.join(config.LOG_PATH, 'cls_curves.png') 236 | data_list = [OA_bulk_train, OA_bulk_test, CK_bulk_train, CK_bulk_test] 237 | label_list = ['OA_train','OA_test','CK_train','CK_test'] 238 | draw_curves(data_list = data_list, label_list = label_list, color_list = config.COLOR[0:4], filename = filename) 239 | filename = os.path.join(config.LOG_PATH, 'seg_curves.png') 240 | data_list = [JA_bulk_train, JA_bulk_test, DI_bulk_train, DI_bulk_test, SE_bulk_train, SE_bulk_test] 241 | label_list = ['JA_train','JA_test','DI_train','DI_test', 'SE_train','SE_test'] 242 | draw_curves(data_list = data_list, label_list = label_list, color_list = config.COLOR[0:6], filename = filename) 243 | 244 | if __name__ == '__main__': 245 | main() 246 | 247 | -------------------------------------------------------------------------------- /net/sync_batchnorm/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
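# Overview: during multi-GPU training each replica sends its per-channel sum and
# squared sum to the master copy (_ChildMessage); the master reduces them, computes
# the global mean and inverse std (_compute_mean_std), and broadcasts the result
# back (_MasterMessage), so batch statistics cover the whole mini-batch rather
# than a single device.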
10 | 11 | import collections 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | from torch.nn.modules.batchnorm import _BatchNorm 17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 18 | 19 | from .comm import SyncMaster 20 | 21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] 22 | 23 | 24 | def _sum_ft(tensor): 25 | """sum over the first and last dimension""" 26 | return tensor.sum(dim=0).sum(dim=-1) 27 | 28 | 29 | def _unsqueeze_ft(tensor): 30 | """add new dimensions at the front and the tail""" 31 | return tensor.unsqueeze(0).unsqueeze(-1) 32 | 33 | 34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 36 | 37 | 38 | class _SynchronizedBatchNorm(_BatchNorm): 39 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): 40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) 41 | 42 | self._sync_master = SyncMaster(self._data_parallel_master) 43 | 44 | self._is_parallel = False 45 | self._parallel_id = None 46 | self._slave_pipe = None 47 | 48 | def forward(self, input): 49 | # If this is not a parallel computation or we are in evaluation mode, use PyTorch's implementation. 50 | if not (self._is_parallel and self.training): 51 | return F.batch_norm( 52 | input, self.running_mean, self.running_var, self.weight, self.bias, 53 | self.training, self.momentum, self.eps) 54 | 55 | # Resize the input to (B, C, -1). 56 | input_shape = input.size() 57 | input = input.view(input.size(0), self.num_features, -1) 58 | 59 | # Compute the sum and square-sum. 60 | sum_size = input.size(0) * input.size(2) 61 | input_sum = _sum_ft(input) 62 | input_ssum = _sum_ft(input ** 2) 63 | 64 | # Reduce-and-broadcast the statistics. 65 | if self._parallel_id == 0: 66 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 67 | else: 68 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 69 | 70 | # Compute the output. 71 | if self.affine: 72 | # MJY:: Fuse the multiplication for speed. 73 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 74 | else: 75 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 76 | 77 | # Reshape it. 78 | return output.view(input_shape) 79 | 80 | def __data_parallel_replicate__(self, ctx, copy_id): 81 | self._is_parallel = True 82 | self._parallel_id = copy_id 83 | 84 | # parallel_id == 0 means master device. 85 | if self._parallel_id == 0: 86 | ctx.sync_master = self._sync_master 87 | else: 88 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 89 | 90 | def _data_parallel_master(self, intermediates): 91 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 92 | 93 | # Always using the same "device order" makes the ReduceAdd operation faster. 
94 | # Thanks to:: Tete Xiao (http://tetexiao.com/) 95 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 96 | 97 | to_reduce = [i[1][:2] for i in intermediates] 98 | to_reduce = [j for i in to_reduce for j in i] # flatten 99 | target_gpus = [i[1].sum.get_device() for i in intermediates] 100 | 101 | sum_size = sum([i[1].sum_size for i in intermediates]) 102 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 103 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 104 | 105 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 106 | 107 | outputs = [] 108 | for i, rec in enumerate(intermediates): 109 | outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) 110 | 111 | return outputs 112 | 113 | def _compute_mean_std(self, sum_, ssum, size): 114 | """Compute the mean and standard-deviation from the sum and square-sum. This method 115 | also maintains the moving average on the master device.""" 116 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 117 | mean = sum_ / size 118 | sumvar = ssum - sum_ * mean 119 | unbias_var = sumvar / (size - 1) 120 | bias_var = sumvar / size 121 | 122 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 123 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 124 | 125 | return mean, bias_var.clamp(self.eps) ** -0.5 126 | 127 | 128 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 129 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 130 | mini-batch. 131 | 132 | .. math:: 133 | 134 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 135 | 136 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 137 | standard-deviation are reduced across all devices during training. 138 | 139 | For example, when one uses `nn.DataParallel` to wrap the network during 140 | training, PyTorch's implementation normalizes the tensor on each device using 141 | only the statistics on that device, which accelerates the computation and 142 | is easy to implement, but the statistics might be inaccurate. 143 | Instead, in this synchronized version, the statistics will be computed 144 | over all training samples distributed on multiple devices. 145 | 146 | Note that, for the one-GPU or CPU-only case, this module behaves exactly the same 147 | as the built-in PyTorch implementation. 148 | 149 | The mean and standard-deviation are calculated per-dimension over 150 | the mini-batches and gamma and beta are learnable parameter vectors 151 | of size C (where C is the input size). 152 | 153 | During training, this layer keeps a running estimate of its computed mean 154 | and variance. The running sum is kept with a default momentum of 0.1. 155 | 156 | During evaluation, this running mean/variance is used for normalization. 157 | 158 | Because the BatchNorm is done over the `C` dimension, computing statistics 159 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm. 160 | 161 | Args: 162 | num_features: num_features from an expected input of size 163 | `batch_size x num_features [x width]` 164 | eps: a value added to the denominator for numerical stability. 165 | Default: 1e-5 166 | momentum: the value used for the running_mean and running_var 167 | computation. Default: 0.1 168 | affine: a boolean value that when set to ``True``, gives the layer learnable 169 | affine parameters. 
Default: ``True`` 170 | 171 | Shape: 172 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 173 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 174 | 175 | Examples: 176 | >>> # With Learnable Parameters 177 | >>> m = SynchronizedBatchNorm1d(100) 178 | >>> # Without Learnable Parameters 179 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 180 | >>> input = torch.randn(20, 100) 181 | >>> output = m(input) 182 | """ 183 | 184 | def _check_input_dim(self, input): 185 | if input.dim() != 2 and input.dim() != 3: 186 | raise ValueError('expected 2D or 3D input (got {}D input)' 187 | .format(input.dim())) 188 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input) 189 | 190 | 191 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 192 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 193 | of 3d inputs. 194 | 195 | .. math:: 196 | 197 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 198 | 199 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 200 | standard-deviation are reduced across all devices during training. 201 | 202 | For example, when one uses `nn.DataParallel` to wrap the network during 203 | training, PyTorch's implementation normalizes the tensor on each device using 204 | only the statistics on that device, which accelerates the computation and 205 | is easy to implement, but the statistics might be inaccurate. 206 | Instead, in this synchronized version, the statistics will be computed 207 | over all training samples distributed on multiple devices. 208 | 209 | Note that, for the one-GPU or CPU-only case, this module behaves exactly the same 210 | as the built-in PyTorch implementation. 211 | 212 | The mean and standard-deviation are calculated per-dimension over 213 | the mini-batches and gamma and beta are learnable parameter vectors 214 | of size C (where C is the input size). 215 | 216 | During training, this layer keeps a running estimate of its computed mean 217 | and variance. The running sum is kept with a default momentum of 0.1. 218 | 219 | During evaluation, this running mean/variance is used for normalization. 220 | 221 | Because the BatchNorm is done over the `C` dimension, computing statistics 222 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm. 223 | 224 | Args: 225 | num_features: num_features from an expected input of 226 | size batch_size x num_features x height x width 227 | eps: a value added to the denominator for numerical stability. 228 | Default: 1e-5 229 | momentum: the value used for the running_mean and running_var 230 | computation. Default: 0.1 231 | affine: a boolean value that when set to ``True``, gives the layer learnable 232 | affine parameters. 
Default: ``True`` 233 | 234 | Shape: 235 | - Input: :math:`(N, C, H, W)` 236 | - Output: :math:`(N, C, H, W)` (same shape as input) 237 | 238 | Examples: 239 | >>> # With Learnable Parameters 240 | >>> m = SynchronizedBatchNorm2d(100) 241 | >>> # Without Learnable Parameters 242 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 243 | >>> input = torch.randn(20, 100, 35, 45) 244 | >>> output = m(input) 245 | """ 246 | 247 | def _check_input_dim(self, input): 248 | if input.dim() != 4: 249 | raise ValueError('expected 4D input (got {}D input)' 250 | .format(input.dim())) 251 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input) 252 | 253 | 254 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 255 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 256 | of 4d inputs. 257 | 258 | .. math:: 259 | 260 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 261 | 262 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 263 | standard-deviation are reduced across all devices during training. 264 | 265 | For example, when one uses `nn.DataParallel` to wrap the network during 266 | training, PyTorch's implementation normalizes the tensor on each device using 267 | only the statistics on that device, which accelerates the computation and 268 | is easy to implement, but the statistics might be inaccurate. 269 | Instead, in this synchronized version, the statistics will be computed 270 | over all training samples distributed on multiple devices. 271 | 272 | Note that, for the one-GPU or CPU-only case, this module behaves exactly the same 273 | as the built-in PyTorch implementation. 274 | 275 | The mean and standard-deviation are calculated per-dimension over 276 | the mini-batches and gamma and beta are learnable parameter vectors 277 | of size C (where C is the input size). 278 | 279 | During training, this layer keeps a running estimate of its computed mean 280 | and variance. The running sum is kept with a default momentum of 0.1. 281 | 282 | During evaluation, this running mean/variance is used for normalization. 283 | 284 | Because the BatchNorm is done over the `C` dimension, computing statistics 285 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 286 | or Spatio-temporal BatchNorm. 287 | 288 | Args: 289 | num_features: num_features from an expected input of 290 | size batch_size x num_features x depth x height x width 291 | eps: a value added to the denominator for numerical stability. 292 | Default: 1e-5 293 | momentum: the value used for the running_mean and running_var 294 | computation. Default: 0.1 295 | affine: a boolean value that when set to ``True``, gives the layer learnable 296 | affine parameters. 
Default: ``True`` 297 | 298 | Shape: 299 | - Input: :math:`(N, C, D, H, W)` 300 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 301 | 302 | Examples: 303 | >>> # With Learnable Parameters 304 | >>> m = SynchronizedBatchNorm3d(100) 305 | >>> # Without Learnable Parameters 306 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 307 | >>> input = torch.randn(20, 100, 35, 45, 10) 308 | >>> output = m(input) 309 | """ 310 | 311 | def _check_input_dim(self, input): 312 | if input.dim() != 5: 313 | raise ValueError('expected 5D input (got {}D input)' 314 | .format(input.dim())) 315 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input) 316 | --------------------------------------------------------------------------------
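Usage sketch: on a single device, _SynchronizedBatchNorm.forward falls back to F.batch_norm, so SynchronizedBatchNorm2d behaves like torch.nn.BatchNorm2d; the synchronized statistics only take effect once the module is replicated across GPUs (e.g. via the DataParallelWithCallback wrapper that the companion replicate.py provides in the upstream Synchronized-BatchNorm-PyTorch project). A minimal single-device check, assuming the package layout above with net/ on the Python path; the layer sizes here are arbitrary:

import torch
from net.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d

bn = SynchronizedBatchNorm2d(64)        # one BN layer for 64-channel feature maps
bn.train()                              # training mode: batch statistics are used
x = torch.randn(8, 64, 32, 32)          # (N, C, H, W)
y = bn(x)                               # single-device path: plain F.batch_norm
assert y.shape == x.shape
print(float(y.mean()), float(y.std()))  # approximately 0 and 1 after normalization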