├── code
│   ├── utils
│   │   ├── heheda
│   │   ├── metrics.py
│   │   ├── ramps.py
│   │   ├── val_2d.py
│   │   ├── util.py
│   │   └── losses.py
│   ├── dataloaders
│   │   ├── heheda
│   │   ├── acdc_data_processing.py
│   │   ├── la_heart_processing.py
│   │   └── dataset.py
│   ├── networks
│   │   ├── heheda
│   │   ├── networks_other.py
│   │   ├── net_factory.py
│   │   ├── UCPCnetwork.py
│   │   ├── discriminator.py
│   │   ├── vnet_sdf.py
│   │   ├── utils.py
│   │   ├── VNet.py
│   │   └── unet.py
│   ├── test_2d.py
│   └── SCPNettrain.py
├── README.md
└── .gitignore
--------------------------------------------------------------------------------
/code/utils/heheda:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/code/dataloaders/heheda:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/code/networks/heheda:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SCP-Net
2 | Early Accepted in MICCAI 2023
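3 | 
4 | A minimal usage sketch. The flags below are taken from `code/SCPNettrain.py` and `code/test_2d.py`; the dataset location under `./data/ACDC` matches the scripts' defaults but is otherwise an assumption:
5 | 
6 | ```bash
7 | # semi-supervised training on ACDC with 7 labeled patients
8 | python code/SCPNettrain.py --root_path ./data/ACDC --exp SCACDClabel2 --model unet_pro --labeled_num 7 --gpu 0
9 | 
10 | # inference on the test list (note: test_2d.py currently loads a hard-coded checkpoint path)
11 | python code/test_2d.py --root_path ./data/ACDC --exp SCACDClabel --model unet_pro --labelnum 7 --gpu 0
12 | ```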
13 | """ 14 | 15 | 16 | import numpy as np 17 | 18 | 19 | def sigmoid_rampup(current, rampup_length): 20 | """Exponential rampup from https://arxiv.org/abs/1610.02242""" 21 | if rampup_length == 0: 22 | return 1.0 23 | else: 24 | current = np.clip(current, 0.0, rampup_length) 25 | phase = 1.0 - current / rampup_length 26 | return float(np.exp(-5.0 * phase * phase)) 27 | 28 | 29 | def linear_rampup(current, rampup_length): 30 | """Linear rampup""" 31 | assert current >= 0 and rampup_length >= 0 32 | if current >= rampup_length: 33 | return 1.0 34 | else: 35 | return current / rampup_length 36 | 37 | 38 | def cosine_rampdown(current, rampdown_length): 39 | """Cosine rampdown from https://arxiv.org/abs/1608.03983""" 40 | assert 0 <= current <= rampdown_length 41 | return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1)) 42 | -------------------------------------------------------------------------------- /code/dataloaders/acdc_data_processing.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | import h5py 5 | import numpy as np 6 | import SimpleITK as sitk 7 | 8 | slice_num = 0 9 | mask_path = sorted(glob.glob("/home/xdluo/data/ACDC/image/*.nii.gz")) 10 | for case in mask_path: 11 | img_itk = sitk.ReadImage(case) 12 | origin = img_itk.GetOrigin() 13 | spacing = img_itk.GetSpacing() 14 | direction = img_itk.GetDirection() 15 | image = sitk.GetArrayFromImage(img_itk) 16 | msk_path = case.replace("image", "label").replace(".nii.gz", "_gt.nii.gz") 17 | if os.path.exists(msk_path): 18 | print(msk_path) 19 | msk_itk = sitk.ReadImage(msk_path) 20 | mask = sitk.GetArrayFromImage(msk_itk) 21 | image = (image - image.min()) / (image.max() - image.min()) 22 | print(image.shape) 23 | image = image.astype(np.float32) 24 | item = case.split("/")[-1].split(".")[0] 25 | if image.shape != mask.shape: 26 | print("Error") 27 | print(item) 28 | for slice_ind in range(image.shape[0]): 29 | f = h5py.File( 30 | '/home/xdluo/data/ACDC/data/{}_slice_{}.h5'.format(item, slice_ind), 'w') 31 | f.create_dataset( 32 | 'image', data=image[slice_ind], compression="gzip") 33 | f.create_dataset('label', data=mask[slice_ind], compression="gzip") 34 | f.close() 35 | slice_num += 1 36 | print("Converted all ACDC volumes to 2D slices") 37 | print("Total {} slices".format(slice_num)) 38 | -------------------------------------------------------------------------------- /code/utils/val_2d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from medpy import metric 4 | from scipy.ndimage import zoom 5 | 6 | 7 | def calculate_metric_percase(pred, gt): 8 | pred[pred > 0] = 1 9 | gt[gt > 0] = 1 10 | if pred.sum() > 0: 11 | dice = metric.binary.dc(pred, gt) 12 | hd95 = metric.binary.hd95(pred, gt) 13 | return dice, hd95 14 | else: 15 | return 0, 0 16 | 17 | 18 | def test_single_volume(image, label, model, classes, patch_size=[256, 256]): 19 | image, label = image.squeeze(0).cpu().detach( 20 | ).numpy(), label.squeeze(0).cpu().detach().numpy() 21 | prediction = np.zeros_like(label) 22 | for ind in range(image.shape[0]): 23 | slice = image[ind, :, :] 24 | x, y = slice.shape[0], slice.shape[1] 25 | slice = zoom(slice, (patch_size[0] / x, patch_size[1] / y), order=0) 26 | input = torch.from_numpy(slice).unsqueeze(0).unsqueeze(0).float().cuda() 27 | model.eval() 28 | with torch.no_grad(): 29 | output = model(input) 30 | if len(output)>1: 31 | output = 
31 |                 output = output[3] + output[1] + output[0]  # fuse the auxiliary decoder outputs
32 |             out = torch.argmax(output, dim=1).squeeze(0)
33 | 
34 |             out = out.cpu().detach().numpy()
35 |             pred = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)
36 |             prediction[ind] = pred
37 |     metric_list = []
38 |     for i in range(1, classes):
39 |         metric_list.append(calculate_metric_percase(prediction == i, label == i))
40 |     return metric_list
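
For reference, this is how `test_single_volume` is consumed during validation (a condensed sketch of the loop in `SCPNettrain.py`; `valloader`, `db_val`, and `model` are placeholders):

import numpy as np

metric_list = 0.0
for _, sampled_batch in enumerate(valloader):
    # one (dice, hd95) pair per foreground class
    metric_i = test_single_volume(sampled_batch["image"], sampled_batch["label"], model, classes=4)
    metric_list += np.array(metric_i)
metric_list = metric_list / len(db_val)
mean_dice, mean_hd95 = np.mean(metric_list, axis=0)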
--------------------------------------------------------------------------------
/code/dataloaders/la_heart_processing.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from glob import glob
3 | from tqdm import tqdm
4 | import h5py
5 | import nrrd
6 | 
7 | output_size = [112, 112, 80]
8 | 
9 | def convert_h5():
10 |     listt = glob('../data/LA/2018LA_Seg_Training Set/*/lgemri.nrrd')
11 |     for item in tqdm(listt):
12 |         image, img_header = nrrd.read(item)
13 |         label, gt_header = nrrd.read(item.replace('lgemri.nrrd', 'laendo.nrrd'))
14 |         label = (label == 255).astype(np.uint8)
15 |         w, h, d = label.shape
16 | 
17 |         tempL = np.nonzero(label)
18 |         minx, maxx = np.min(tempL[0]), np.max(tempL[0])
19 |         miny, maxy = np.min(tempL[1]), np.max(tempL[1])
20 |         minz, maxz = np.min(tempL[2]), np.max(tempL[2])
21 | 
22 |         px = max(output_size[0] - (maxx - minx), 0) // 2
23 |         py = max(output_size[1] - (maxy - miny), 0) // 2
24 |         pz = max(output_size[2] - (maxz - minz), 0) // 2
25 |         minx = max(minx - np.random.randint(10, 20) - px, 0)
26 |         maxx = min(maxx + np.random.randint(10, 20) + px, w)
27 |         miny = max(miny - np.random.randint(10, 20) - py, 0)
28 |         maxy = min(maxy + np.random.randint(10, 20) + py, h)
29 |         minz = max(minz - np.random.randint(5, 10) - pz, 0)
30 |         maxz = min(maxz + np.random.randint(5, 10) + pz, d)
31 | 
32 |         image = (image - np.mean(image)) / np.std(image)
33 |         image = image.astype(np.float32)
34 |         # crop to the randomly padded bounding box in all three axes
35 |         image = image[minx:maxx, miny:maxy, minz:maxz]
36 |         label = label[minx:maxx, miny:maxy, minz:maxz]
37 |         print(label.shape)
38 |         f = h5py.File(item.replace('lgemri.nrrd', 'mri_norm2.h5'), 'w')
39 |         f.create_dataset('image', data=image, compression="gzip")
40 |         f.create_dataset('label', data=label, compression="gzip")
41 |         f.close()
42 | 
43 | if __name__ == '__main__':
44 |     convert_h5()
--------------------------------------------------------------------------------
/code/networks/networks_other.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import time
3 | 
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | from torch.autograd import Variable
8 | from torch.nn import init
9 | from torch.optim import lr_scheduler
10 | 
11 | ###############################################################################
12 | # Functions
13 | ###############################################################################
14 | 
15 | 
16 | def weights_init_normal(m):
17 |     classname = m.__class__.__name__
18 |     #print(classname)
19 |     if classname.find('Conv') != -1:
20 |         init.normal_(m.weight.data, 0.0, 0.02)
21 |     elif classname.find('Linear') != -1:
22 |         init.normal_(m.weight.data, 0.0, 0.02)
23 |     elif classname.find('BatchNorm') != -1:
24 |         init.normal_(m.weight.data, 1.0, 0.02)
25 |         init.constant_(m.bias.data, 0.0)
26 | 
27 | 
28 | def weights_init_xavier(m):
29 |     classname = m.__class__.__name__
30 |     #print(classname)
31 |     if classname.find('Conv') != -1:
32 |         init.xavier_normal_(m.weight.data, gain=1)
33 |     elif classname.find('Linear') != -1:
34 |         init.xavier_normal_(m.weight.data, gain=1)
35 |     elif classname.find('BatchNorm') != -1:
36 |         init.normal_(m.weight.data, 1.0, 0.02)
37 |         init.constant_(m.bias.data, 0.0)
38 | 
39 | 
40 | def weights_init_kaiming(m):
41 |     classname = m.__class__.__name__
42 |     #print(classname)
43 |     if classname.find('Conv') != -1:
44 |         init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
45 |     elif classname.find('Linear') != -1:
46 |         init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
47 |     elif classname.find('BatchNorm') != -1:
48 |         init.normal_(m.weight.data, 1.0, 0.02)
49 |         init.constant_(m.bias.data, 0.0)
50 | 
51 | 
52 | def weights_init_orthogonal(m):
53 |     classname = m.__class__.__name__
54 |     #print(classname)
55 |     if classname.find('Conv') != -1:
56 |         init.orthogonal_(m.weight.data, gain=1)
57 |     elif classname.find('Linear') != -1:
58 |         init.orthogonal_(m.weight.data, gain=1)
59 |     elif classname.find('BatchNorm') != -1:
60 |         init.normal_(m.weight.data, 1.0, 0.02)
61 |         init.constant_(m.bias.data, 0.0)
62 | 
63 | 
64 | def init_weights(net, init_type='normal'):
65 |     #print('initialization method [%s]' % init_type)
66 |     if init_type == 'normal':
67 |         net.apply(weights_init_normal)
68 |     elif init_type == 'xavier':
69 |         net.apply(weights_init_xavier)
70 |     elif init_type == 'kaiming':
71 |         net.apply(weights_init_kaiming)
72 |     elif init_type == 'orthogonal':
73 |         net.apply(weights_init_orthogonal)
74 |     else:
75 |         raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
76 | 
--------------------------------------------------------------------------------
/code/networks/net_factory.py:
--------------------------------------------------------------------------------
1 | from networks.unet import UNet, MCNet2d_v1, MCNet2d_v2, MCNet2d_v3, UNet_URPC, UNet_CCT, UNet_pro, BFDCNet2d_v1, DiceCENet2d_fuse, UNet_sdf
2 | from networks.VNet import VNet, MCNet3d_v1, MCNet3d_v2, ECNet3d, DiceCENet3d, DiceCENet3d_fuse, DiceCENet3d_fuse_2
3 | 
4 | def net_factory(net_type="unet", in_chns=1, class_num=4, mode="train"):
5 |     if net_type == "unet":
6 |         net = UNet(in_chns=in_chns, class_num=class_num).cuda()
7 |     elif net_type == "mcnet2d_v1":
8 |         net = MCNet2d_v1(in_chns=in_chns, class_num=class_num).cuda()
9 |     elif net_type == "unetsdf":
10 |         net = UNet_sdf(in_chns=in_chns, class_num=class_num).cuda()
11 |     elif net_type == "mcnet2d_v2":
12 |         net = MCNet2d_v2(in_chns=in_chns, class_num=class_num).cuda()
13 |     elif net_type == "mcnet2d_v3":
14 |         net = MCNet2d_v3(in_chns=in_chns, class_num=class_num).cuda()
15 |     elif net_type == "vnet" and mode == "train":
16 |         net = VNet(n_channels=in_chns, n_classes=class_num, normalization='batchnorm', has_dropout=True).cuda()
17 |     elif net_type == "mcnet3d_v1" and mode == "train":
18 |         net = MCNet3d_v1(n_channels=in_chns, n_classes=class_num, normalization='batchnorm', has_dropout=True).cuda()
19 |     elif net_type == "mcnet3d_v2" and mode == "train":
20 |         net = MCNet3d_v2(n_channels=in_chns, n_classes=class_num, normalization='batchnorm', has_dropout=True).cuda()
21 |     elif net_type == "ecnet3d" and mode == "train":
22 |         net = ECNet3d(n_channels=in_chns, n_classes=class_num, normalization='batchnorm', has_dropout=True).cuda()
23 |     elif net_type == "dicecenet3d" and mode == "train":
24 |         net = DiceCENet3d(n_channels=in_chns, n_classes=class_num, normalization='batchnorm', has_dropout=True).cuda()
25 |     elif net_type == "vnet" and mode == "test":
26 |         net = VNet(n_channels=in_chns, n_classes=class_num, normalization='batchnorm', has_dropout=False).cuda()
27 |     elif net_type == 
"mcnet3d_v1" and mode == "test": 28 | net = MCNet3d_v1(n_channels=in_chns, n_classes=class_num, normalization='batchnorm', has_dropout=False).cuda() 29 | elif net_type == "mcnet3d_v2" and mode == "test": 30 | net = MCNet3d_v2(n_channels=in_chns, n_classes=class_num, normalization='batchnorm', has_dropout=False).cuda() 31 | elif net_type == "unet_cct" and mode == "train": 32 | net = UNet_CCT(in_chns=in_chns, class_num=class_num).cuda() 33 | elif net_type == "unet_urpc" and mode == "train": 34 | net = UNet_URPC(in_chns=in_chns, class_num=class_num).cuda() 35 | elif net_type == "unet_pro" and mode == "train": 36 | net = UNet_pro(in_chns=in_chns, class_num=class_num).cuda() 37 | elif net_type == "bfdcnet2d" and mode == "train": 38 | net = BFDCNet2d_v1(in_chns=in_chns, class_num=class_num).cuda() 39 | elif net_type == "dicecenetfuse" and mode == "train": 40 | net = DiceCENet3d_fuse(n_channels=in_chns, n_classes=class_num, normalization='batchnorm', has_dropout=True).cuda() 41 | elif net_type == "dicecenetfuse_2" and mode == "train": 42 | net = DiceCENet3d_fuse_2(n_channels=in_chns, n_classes=class_num, normalization='batchnorm', has_dropout=True).cuda() 43 | if net_type == "dicecenetfuse2d": 44 | net = DiceCENet2d_fuse(in_chns=in_chns, class_num=class_num).cuda() 45 | return net 46 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 
95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /code/networks/UCPCnetwork.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from networks.utils import UnetConv3, UnetUp3, UnetUp3_CT, UnetDsv3 5 | import torch.nn.functional as F 6 | from networks.networks_other import init_weights 7 | 8 | 9 | class unet_3D_dv_semi(nn.Module): 10 | 11 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True): 12 | super(unet_3D_dv_semi, self).__init__() 13 | self.is_deconv = is_deconv 14 | self.in_channels = in_channels 15 | self.is_batchnorm = is_batchnorm 16 | self.feature_scale = feature_scale 17 | 18 | filters = [64, 128, 256, 512, 1024] 19 | filters = [int(x / self.feature_scale) for x in filters] 20 | 21 | # downsampling 22 | self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm, kernel_size=( 23 | 3, 3, 3), padding_size=(1, 1, 1)) 24 | self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 25 | 26 | self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm, kernel_size=( 27 | 3, 3, 3), padding_size=(1, 1, 1)) 28 | self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 29 | 30 | self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm, kernel_size=( 31 | 3, 3, 3), padding_size=(1, 1, 1)) 32 | self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 33 | 34 | self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm, kernel_size=( 35 | 3, 3, 3), padding_size=(1, 1, 1)) 36 | self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 37 | 38 | self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm, kernel_size=( 39 | 3, 3, 3), padding_size=(1, 1, 1)) 40 | 41 | # upsampling 42 | self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm) 43 | self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm) 44 | self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm) 45 | self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm) 46 | 47 | # deep supervision 48 | self.dsv4 = UnetDsv3( 49 | in_size=filters[3], out_size=n_classes, scale_factor=8) 50 | self.dsv3 = UnetDsv3( 51 | in_size=filters[2], out_size=n_classes, scale_factor=4) 52 | self.dsv2 = UnetDsv3( 53 | in_size=filters[1], out_size=n_classes, scale_factor=2) 54 | self.dsv1 = nn.Conv3d( 55 | in_channels=filters[0], out_channels=n_classes, kernel_size=1) 56 | 57 | self.dropout1 = nn.Dropout3d(p=0.5) 58 | self.dropout2 = nn.Dropout3d(p=0.3) 59 | self.dropout3 = nn.Dropout3d(p=0.2) 60 | self.dropout4 = nn.Dropout3d(p=0.1) 61 | 62 | # initialise weights 63 | for m in self.modules(): 64 | if isinstance(m, nn.Conv3d): 65 | init_weights(m, init_type='kaiming') 66 | elif isinstance(m, nn.BatchNorm3d): 67 | init_weights(m, init_type='kaiming') 68 | 69 | def forward(self, inputs): 70 | conv1 = self.conv1(inputs) 71 | maxpool1 = self.maxpool1(conv1) 72 | 73 | conv2 = self.conv2(maxpool1) 74 | maxpool2 = self.maxpool2(conv2) 75 | 76 | conv3 = self.conv3(maxpool2) 77 | maxpool3 = self.maxpool3(conv3) 78 | 79 | conv4 = self.conv4(maxpool3) 80 | maxpool4 = self.maxpool4(conv4) 81 | 82 | center = self.center(maxpool4) 83 | 84 | up4 = self.up_concat4(conv4, center) 85 | up4 = self.dropout1(up4) 86 | 87 | up3 = self.up_concat3(conv3, up4) 88 | up3 = self.dropout2(up3) 89 | 90 | up2 = self.up_concat2(conv2, up3) 91 | up2 = self.dropout3(up2) 92 | 93 | up1 = self.up_concat1(conv1, up2) 94 | up1 = self.dropout4(up1) 95 
|
96 |         # Deep Supervision
97 |         dsv4 = self.dsv4(up4)
98 |         dsv3 = self.dsv3(up3)
99 |         dsv2 = self.dsv2(up2)
100 |         dsv1 = self.dsv1(up1)
101 | 
102 |         return dsv1, dsv2, dsv3, dsv4
103 | 
104 |     @staticmethod
105 |     def apply_argmax_softmax(pred):
106 |         # despite the name, this returns softmax probabilities (no log, no argmax)
107 |         log_p = F.softmax(pred, dim=1)
108 | 
109 |         return log_p
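
A quick sanity check of the four deep-supervision heads (a sketch; run from `code/` so `networks` is importable, and the input size here is an arbitrary choice):

import torch
from networks.UCPCnetwork import unet_3D_dv_semi

net = unet_3D_dv_semi(feature_scale=4, n_classes=2, in_channels=1)
x = torch.randn(2, 1, 64, 64, 64)
dsv1, dsv2, dsv3, dsv4 = net(x)
# every head is upsampled back to the input resolution:
# torch.Size([2, 2, 64, 64, 64]) four times
print(dsv1.shape, dsv2.shape, dsv3.shape, dsv4.shape)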
--------------------------------------------------------------------------------
/code/utils/util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | import os
8 | import pickle
9 | import numpy as np
10 | from scipy.ndimage import distance_transform_edt as distance
11 | from skimage import segmentation as skimage_seg
12 | import torch
13 | from torch.utils.data.sampler import Sampler
14 | 
15 | import networks
16 | 
17 | def load_model(path):
18 |     """Loads model and returns it without DataParallel table."""
19 |     if os.path.isfile(path):
20 |         print("=> loading checkpoint '{}'".format(path))
21 |         checkpoint = torch.load(path)
22 | 
23 |         # size of the top layer
24 |         N = checkpoint['state_dict']['top_layer.bias'].size()
25 | 
26 |         # build skeleton of the model
27 |         sob = 'sobel.0.weight' in checkpoint['state_dict'].keys()
28 |         model = networks.__dict__[checkpoint['arch']](sobel=sob, out=int(N[0]))
29 | 
30 |         # deal with a dataparallel table
31 |         def rename_key(key):
32 |             if not 'module' in key:
33 |                 return key
34 |             return ''.join(key.split('.module'))
35 | 
36 |         checkpoint['state_dict'] = {rename_key(key): val
37 |                                     for key, val
38 |                                     in checkpoint['state_dict'].items()}
39 | 
40 |         # load weights
41 |         model.load_state_dict(checkpoint['state_dict'])
42 |         print("Loaded")
43 |     else:
44 |         model = None
45 |         print("=> no checkpoint found at '{}'".format(path))
46 |     return model
47 | 
48 | 
49 | class UnifLabelSampler(Sampler):
50 |     """Samples elements uniformly across pseudolabels.
51 |     Args:
52 |         N (int): size of returned iterator.
53 |         images_lists: dict of key (target), value (list of data with this target)
54 |     """
55 | 
56 |     def __init__(self, N, images_lists):
57 |         self.N = N
58 |         self.images_lists = images_lists
59 |         self.indexes = self.generate_indexes_epoch()
60 | 
61 |     def generate_indexes_epoch(self):
62 |         size_per_pseudolabel = int(self.N / len(self.images_lists)) + 1
63 |         res = np.zeros(size_per_pseudolabel * len(self.images_lists))
64 | 
65 |         for i in range(len(self.images_lists)):
66 |             indexes = np.random.choice(
67 |                 self.images_lists[i],
68 |                 size_per_pseudolabel,
69 |                 replace=(len(self.images_lists[i]) <= size_per_pseudolabel)
70 |             )
71 |             res[i * size_per_pseudolabel: (i + 1) * size_per_pseudolabel] = indexes
72 | 
73 |         np.random.shuffle(res)
74 |         return res[:self.N].astype('int')
75 | 
76 |     def __iter__(self):
77 |         return iter(self.indexes)
78 | 
79 |     def __len__(self):
80 |         return self.N
81 | 
82 | 
83 | class AverageMeter(object):
84 |     """Computes and stores the average and current value"""
85 |     def __init__(self):
86 |         self.reset()
87 | 
88 |     def reset(self):
89 |         self.val = 0
90 |         self.avg = 0
91 |         self.sum = 0
92 |         self.count = 0
93 | 
94 |     def update(self, val, n=1):
95 |         self.val = val
96 |         self.sum += val * n
97 |         self.count += n
98 |         self.avg = self.sum / self.count
99 | 
100 | 
101 | def learning_rate_decay(optimizer, t, lr_0):
102 |     for param_group in optimizer.param_groups:
103 |         lr = lr_0 / np.sqrt(1 + lr_0 * param_group['weight_decay'] * t)
104 |         param_group['lr'] = lr
105 | 
106 | 
107 | class Logger():
108 |     """ Class to update every epoch to keep trace of the results
109 |     Methods:
110 |         - log() log and save
111 |     """
112 | 
113 |     def __init__(self, path):
114 |         self.path = path
115 |         self.data = []
116 | 
117 |     def log(self, train_point):
118 |         self.data.append(train_point)
119 |         with open(os.path.join(self.path), 'wb') as fp:
120 |             pickle.dump(self.data, fp, -1)
121 | 
122 | 
123 | def compute_sdf(img_gt, out_shape):
124 |     """
125 |     compute the signed distance map of a binary mask
126 |     input: segmentation, shape = (batch_size, x, y, z)
127 |     output: the Signed Distance Map (SDM)
128 |     sdf(x) = 0                          if x is on the segmentation boundary
129 |            = -min_{y in boundary} ||x-y||   if x is inside the segmentation
130 |            = +min_{y in boundary} ||x-y||   if x is outside the segmentation
131 |     finally the sdf is normalized to [-1, 1]
132 |     """
133 | 
134 |     img_gt = img_gt.astype(np.uint8)
135 |     normalized_sdf = np.zeros(out_shape)
136 | 
137 |     for b in range(out_shape[0]):  # batch size
138 |         posmask = img_gt[b].astype(bool)
139 |         if posmask.any():
140 |             negmask = ~posmask
141 |             posdis = distance(posmask)
142 |             negdis = distance(negmask)
143 |             boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8)
144 |             sdf = (negdis-np.min(negdis))/(np.max(negdis)-np.min(negdis)) - (posdis-np.min(posdis))/(np.max(posdis)-np.min(posdis))
145 |             sdf[boundary==1] = 0
146 |             normalized_sdf[b] = sdf
147 |             # assert np.min(sdf) == -1.0, print(np.min(posdis), np.max(posdis), np.min(negdis), np.max(negdis))
148 |             # assert np.max(sdf) == 1.0, print(np.min(posdis), np.min(negdis), np.max(posdis), np.max(negdis))
149 | 
150 |     return normalized_sdf
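
A toy check of `compute_sdf` (a sketch; the foreground cube and its size are arbitrary, and the import assumes the `code/utils` layout above):

import numpy as np
# from utils.util import compute_sdf

gt = np.zeros((1, 16, 16, 16), dtype=np.uint8)   # batch of one volume
gt[0, 4:12, 4:12, 4:12] = 1                      # small foreground cube
sdf = compute_sdf(gt, gt.shape)
# negative inside the cube, positive outside, 0 on the boundary,
# normalized to roughly [-1, 1]
print(sdf.min(), sdf.max(), sdf[0, 8, 8, 8])     # the center is the most negative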
--------------------------------------------------------------------------------
/code/test_2d.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 | 
5 | import h5py
6 | import nibabel as nib
7 | import numpy as np
8 | import SimpleITK as sitk
9 | import torch
10 | from medpy import metric
11 | from scipy.ndimage import zoom
12 | from tqdm import tqdm
13 | from networks.net_factory import net_factory
14 | 
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('--root_path', type=str, default='./data/ACDC', help='path of the dataset')  # if prostate, num_classes equals 2
17 | parser.add_argument('--exp', type=str, default='SCACDClabel', help='experiment_name')
18 | parser.add_argument('--model', type=str, default='unet_pro', help='model_name')
19 | parser.add_argument('--num_classes', type=int, default=4, help='output channel of network')
20 | parser.add_argument('--labelnum', type=int, default=7, help='labeled data')
21 | parser.add_argument('--gpu', type=str, default='0', help='GPU to use')
22 | 
23 | def calculate_metric_percase(pred, gt):
24 |     pred[pred > 0] = 1
25 |     gt[gt > 0] = 1
26 |     dice = metric.binary.dc(pred, gt)
27 |     jc = metric.binary.jc(pred, gt)
28 |     asd = metric.binary.asd(pred, gt)
29 |     hd95 = metric.binary.hd95(pred, gt)
30 |     return dice, jc, hd95, asd
31 | 
32 | def test_single_volume(case, net, test_save_path, FLAGS):
33 |     h5f = h5py.File(FLAGS.root_path + "/data/{}.h5".format(case), 'r')
34 |     image = h5f['image'][:]
35 |     label = h5f['label'][:]
36 |     prediction = np.zeros_like(label)
37 |     for ind in range(image.shape[0]):
38 |         slice = image[ind, :, :]
39 |         x, y = slice.shape[0], slice.shape[1]
40 |         slice = zoom(slice, (256 / x, 256 / y), order=0)
41 |         input = torch.from_numpy(slice).unsqueeze(0).unsqueeze(0).float().cuda()
42 |         net.eval()
43 |         with torch.no_grad():
44 |             out_main = net(input)
45 |             if len(out_main) > 1:
46 |                 # multi-head models return a tuple; evaluate the main head
47 |                 out_main = out_main[0]
48 |             out = torch.argmax(out_main, dim=1).squeeze(0)
49 |             out = out.cpu().detach().numpy()
50 |             pred = zoom(out, (x / 256, y / 256), order=0)
51 |             prediction[ind] = pred
52 |     if np.sum(prediction == 1) == 0:
53 |         first_metric = 0, 0, 0, 0
54 |     else:
55 |         first_metric = calculate_metric_percase(prediction == 1, label == 1)
56 | 
57 |     if np.sum(prediction == 2) == 0:
58 |         second_metric = 0, 0, 0, 0
59 |     else:
60 |         second_metric = calculate_metric_percase(prediction == 2, label == 2)
61 | 
62 |     if np.sum(prediction == 3) == 0:
63 |         third_metric = 0, 0, 0, 0
64 |     else:
65 |         third_metric = calculate_metric_percase(prediction == 3, label == 3)
66 | 
67 |     img_itk = sitk.GetImageFromArray(image.astype(np.float32))
68 |     img_itk.SetSpacing((1, 1, 10))
69 |     prd_itk = sitk.GetImageFromArray(prediction.astype(np.float32))
70 |     prd_itk.SetSpacing((1, 1, 10))
71 |     lab_itk = sitk.GetImageFromArray(label.astype(np.float32))
72 |     lab_itk.SetSpacing((1, 1, 10))
73 |     sitk.WriteImage(prd_itk, test_save_path + case + "_pred.nii.gz")
74 |     sitk.WriteImage(img_itk, test_save_path + case + "_img.nii.gz")
75 |     sitk.WriteImage(lab_itk, test_save_path + case + "_gt.nii.gz")
76 |     return first_metric, second_metric, third_metric
77 | 
78 | 
79 | def Inference(FLAGS):
80 |     with open(FLAGS.root_path + '/test.list', 'r') as f:
81 |         image_list = f.readlines()
82 |     image_list = sorted([item.replace('\n', '').split(".")[0] for item in image_list])
83 |     snapshot_path = "/home/zxzhang/MC-Net-Main/model/ACDC_{}_{}_labeled/{}".format(FLAGS.exp, FLAGS.labelnum, FLAGS.model)
"/home/zxzhang/MC-Net-Main/model/ACDC_{}_{}_labeled/{}".format(FLAGS.exp, FLAGS.labelnum, FLAGS.model) 94 | test_save_path = "/home/zxzhang/MC-Net-Main/model/ACDC_{}_{}_labeled/ACDC{}_predictions/".format(FLAGS.exp, FLAGS.labelnum, FLAGS.model) 95 | if os.path.exists(test_save_path): 96 | shutil.rmtree(test_save_path) 97 | os.makedirs(test_save_path) 98 | net = net_factory(net_type=FLAGS.model, in_chns=1, class_num=FLAGS.num_classes) 99 | save_model_path = "/home/zxzhang/MC-Net-Main/model/ACDC_DiceCEACDCclassall_7_labeled/dicecenetfuse2d/iter_49700_dice_0.8593.pth"#os.path.join(snapshot_path, '{}_best_model.pth'.format(FLAGS.model))#"/home/zxzhang/MC-Net-Main/data/ACDCSASSNetACDCclass3/data/ACDC_SASSNetACDC_7_labeled/unetsdf//iter_6200_dice_0.8938.pth"#os.path.join(snapshot_path, '{}_best_model.pth'.format(FLAGS.model))#os.path.join(snapshot_path, '{}_best_model.pth'.format(FLAGS.model))#os.path.join(snapshot_path, '{}_best_model.pth'.format(FLAGS.model)) 100 | net.load_state_dict(torch.load(save_model_path), strict=False) 101 | print("init weight from {}".format(save_model_path)) 102 | net.eval() 103 | 104 | first_total = 0.0 105 | second_total = 0.0 106 | third_total = 0.0 107 | for case in tqdm(image_list): 108 | first_metric, second_metric, third_metric = test_single_volume(case, net, test_save_path, FLAGS) 109 | first_total += np.asarray(first_metric) 110 | second_total += np.asarray(second_metric) 111 | third_total += np.asarray(third_metric) 112 | avg_metric = [first_total / len(image_list), second_total / len(image_list), third_total / len(image_list)] 113 | return avg_metric, test_save_path 114 | 115 | 116 | if __name__ == '__main__': 117 | FLAGS = parser.parse_args() 118 | os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu 119 | metric, test_save_path = Inference(FLAGS) 120 | print(metric) 121 | print((metric[0]+metric[1]+metric[2])/3) 122 | with open(test_save_path+'../performance.txt', 'w') as f: 123 | f.writelines('metric is {} \n'.format(metric)) 124 | f.writelines('average metric is {}\n'.format((metric[0]+metric[1]+metric[2])/3)) 125 | -------------------------------------------------------------------------------- /code/networks/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | 5 | 6 | class FCDiscriminator(nn.Module): 7 | 8 | def __init__(self, num_classes, ndf=64, n_channel=1): 9 | super(FCDiscriminator, self).__init__() 10 | self.conv0 = nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1) 11 | self.conv1 = nn.Conv2d(n_channel, ndf, kernel_size=4, stride=2, padding=1) 12 | self.conv2 = nn.Conv2d(ndf, ndf*2, kernel_size=4, stride=2, padding=1) 13 | self.conv3 = nn.Conv2d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1) 14 | self.conv4 = nn.Conv2d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1) 15 | self.classifier = nn.Linear(ndf*8, 2) 16 | self.avgpool = nn.AvgPool2d((16, 16)) 17 | 18 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 19 | self.dropout = nn.Dropout2d(0.5) 20 | # self.up_sample = nn.Upsample(scale_factor=32, mode='bilinear') 21 | # self.sigmoid = nn.Sigmoid() 22 | 23 | def forward(self, map, feature): 24 | batch_size = map.shape[0] 25 | map_feature = self.conv0(map) 26 | image_feature = self.conv1(feature) 27 | x = torch.add(map_feature, image_feature) 28 | 29 | x = self.conv2(x) 30 | x = self.leaky_relu(x) 31 | x = self.dropout(x) 32 | 33 | x = self.conv3(x) 34 | x = self.leaky_relu(x) 35 | x = 
self.dropout(x) 36 | 37 | x = self.conv4(x) 38 | x = self.leaky_relu(x) 39 | #print(x.shape) 40 | x = self.avgpool(x) 41 | x = x.view(x.size(0), -1) 42 | # print(x.shape) 43 | x = self.classifier(x) 44 | x = x.reshape((batch_size, 2)) 45 | # x = self.up_sample(x) 46 | # x = self.sigmoid(x) 47 | 48 | return x 49 | 50 | 51 | class FC3DDiscriminator(nn.Module): 52 | 53 | def __init__(self, num_classes, ndf=64, n_channel=1): 54 | super(FC3DDiscriminator, self).__init__() 55 | # downsample 16 56 | self.conv0 = nn.Conv3d(num_classes, ndf, kernel_size=4, stride=2, padding=1) 57 | self.conv1 = nn.Conv3d(n_channel, ndf, kernel_size=4, stride=2, padding=1) 58 | 59 | self.conv2 = nn.Conv3d(ndf, ndf*2, kernel_size=4, stride=2, padding=1) 60 | self.conv3 = nn.Conv3d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1) 61 | self.conv4 = nn.Conv3d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1) 62 | self.avgpool = nn.AvgPool3d((7, 7, 5)) 63 | self.classifier = nn.Linear(ndf*8, 2) 64 | 65 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 66 | self.dropout = nn.Dropout3d(0.5) 67 | self.Softmax = nn.Softmax() 68 | 69 | def forward(self, map, image): 70 | batch_size = map.shape[0] 71 | map_feature = self.conv0(map) 72 | image_feature = self.conv1(image) 73 | x = torch.add(map_feature, image_feature) 74 | x = self.leaky_relu(x) 75 | x = self.dropout(x) 76 | 77 | x = self.conv2(x) 78 | x = self.leaky_relu(x) 79 | x = self.dropout(x) 80 | 81 | x = self.conv3(x) 82 | x = self.leaky_relu(x) 83 | x = self.dropout(x) 84 | 85 | x = self.conv4(x) 86 | x = self.leaky_relu(x) 87 | 88 | x = self.avgpool(x) 89 | 90 | x = x.view(batch_size, -1) 91 | x = self.classifier(x) 92 | x = x.reshape((batch_size, 2)) 93 | # x = self.Softmax(x) 94 | 95 | return x 96 | class FC3DDiscriminatorPanc(nn.Module): 97 | 98 | def __init__(self, num_classes, ndf=64, n_channel=1): 99 | super(FC3DDiscriminatorPanc, self).__init__() 100 | # downsample 16 101 | self.conv0 = nn.Conv3d(num_classes, ndf, kernel_size=4, stride=2, padding=1) 102 | self.conv1 = nn.Conv3d(n_channel, ndf, kernel_size=4, stride=2, padding=1) 103 | 104 | self.conv2 = nn.Conv3d(ndf, ndf*2, kernel_size=4, stride=2, padding=1) 105 | self.conv3 = nn.Conv3d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1) 106 | self.conv4 = nn.Conv3d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1) 107 | self.avgpool = nn.AvgPool3d((6, 6, 6)) 108 | self.classifier = nn.Linear(ndf*8, 2) 109 | 110 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 111 | self.dropout = nn.Dropout3d(0.5) 112 | self.Softmax = nn.Softmax() 113 | 114 | def forward(self, map, image): 115 | batch_size = map.shape[0] 116 | map_feature = self.conv0(map) 117 | image_feature = self.conv1(image) 118 | x = torch.add(map_feature, image_feature) 119 | x = self.leaky_relu(x) 120 | x = self.dropout(x) 121 | 122 | x = self.conv2(x) 123 | x = self.leaky_relu(x) 124 | x = self.dropout(x) 125 | 126 | x = self.conv3(x) 127 | x = self.leaky_relu(x) 128 | x = self.dropout(x) 129 | 130 | x = self.conv4(x) 131 | x = self.leaky_relu(x) 132 | 133 | x = self.avgpool(x) 134 | 135 | x = x.view(batch_size, -1) 136 | x = self.classifier(x) 137 | x = x.reshape((batch_size, 2)) 138 | # x = self.Softmax(x) 139 | 140 | return x 141 | 142 | class FC3DDiscriminatorNIH(nn.Module): 143 | def __init__(self, num_classes, ndf=64, n_channel=1): 144 | super(FC3DDiscriminatorNIH, self).__init__() 145 | # downsample 16 146 | self.conv0 = nn.Conv3d(num_classes, ndf, kernel_size=4, stride=2, padding=1) 147 | self.conv1 = 
nn.Conv3d(n_channel, ndf, kernel_size=4, stride=2, padding=1) 148 | 149 | self.conv2 = nn.Conv3d(ndf, ndf*2, kernel_size=4, stride=2, padding=1) 150 | self.conv3 = nn.Conv3d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1) 151 | self.conv4 = nn.Conv3d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1) 152 | self.avgpool = nn.AvgPool3d((13, 10, 9)) 153 | self.classifier = nn.Linear(ndf*8, 2) 154 | 155 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 156 | self.dropout = nn.Dropout3d(0.5) 157 | self.Softmax = nn.Softmax() 158 | 159 | def forward(self, map, image): 160 | batch_size = map.shape[0] 161 | map_feature = self.conv0(map) 162 | image_feature = self.conv1(image) 163 | x = torch.add(map_feature, image_feature) 164 | x = self.leaky_relu(x) 165 | x = self.dropout(x) 166 | 167 | x = self.conv2(x) 168 | x = self.leaky_relu(x) 169 | x = self.dropout(x) 170 | 171 | x = self.conv3(x) 172 | x = self.leaky_relu(x) 173 | x = self.dropout(x) 174 | 175 | x = self.conv4(x) 176 | x = self.leaky_relu(x) 177 | 178 | x = self.avgpool(x) 179 | 180 | x = x.view(batch_size, -1) 181 | x = self.classifier(x) 182 | x = x.reshape((batch_size, 2)) 183 | # x = self.Softmax(x) 184 | 185 | return x 186 | 187 | 188 | class FCDiscriminatorDAP(nn.Module): 189 | def __init__(self, num_classes, ndf = 64): 190 | super(FCDiscriminatorDAP, self).__init__() 191 | 192 | self.conv1 = nn.Conv3d(num_classes, ndf, kernel_size=4, stride=2, padding=1) 193 | self.conv2 = nn.Conv3d(ndf, ndf*2, kernel_size=4, stride=2, padding=1) 194 | self.conv3 = nn.Conv3d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1) 195 | self.classifier = nn.Conv3d(ndf*4, 1, kernel_size=4, stride=2, padding=1) 196 | 197 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 198 | self.up_sample = nn.Upsample(scale_factor=16, mode='trilinear', align_corners=True) 199 | self.sigmoid = nn.Sigmoid() 200 | 201 | def forward(self, x): 202 | x = self.conv1(x) 203 | x = self.leaky_relu(x) 204 | x = self.conv2(x) 205 | x = self.leaky_relu(x) 206 | x = self.conv3(x) 207 | x = self.leaky_relu(x) 208 | x = self.classifier(x) 209 | x = self.up_sample(x) 210 | x = self.sigmoid(x) 211 | 212 | return x 213 | 214 | if __name__ == '__main__': 215 | # compute FLOPS & PARAMETERS 216 | from thop import profile 217 | from thop import clever_format 218 | model = FC3DDiscriminator(num_classes=1) 219 | input = torch.randn(4, 1, 112, 112, 80) 220 | flops, params = profile(model, inputs=(input,input)) 221 | macs, params = clever_format([flops, params], "%.3f") 222 | print(macs, params) 223 | 224 | model = FCDiscriminatorDAP(num_classes=2) 225 | input = torch.randn(4, 2, 112, 112, 80) 226 | flops, params = profile(model, inputs=(input,)) 227 | macs, params = clever_format([flops, params], "%.3f") 228 | print(macs, params) 229 | 230 | import ipdb; ipdb.set_trace() -------------------------------------------------------------------------------- /code/networks/vnet_sdf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | """ 6 | Differences with V-Net 7 | Adding nn.Tanh in the end of the conv. to make the outputs in [-1, 1]. 
8 | """ 9 | 10 | class ConvBlock(nn.Module): 11 | def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'): 12 | super(ConvBlock, self).__init__() 13 | 14 | ops = [] 15 | for i in range(n_stages): 16 | if i==0: 17 | input_channel = n_filters_in 18 | else: 19 | input_channel = n_filters_out 20 | 21 | ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1)) 22 | if normalization == 'batchnorm': 23 | ops.append(nn.BatchNorm3d(n_filters_out)) 24 | elif normalization == 'groupnorm': 25 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 26 | elif normalization == 'instancenorm': 27 | ops.append(nn.InstanceNorm3d(n_filters_out)) 28 | elif normalization != 'none': 29 | assert False 30 | ops.append(nn.ReLU(inplace=True)) 31 | 32 | self.conv = nn.Sequential(*ops) 33 | 34 | def forward(self, x): 35 | x = self.conv(x) 36 | return x 37 | 38 | 39 | class ResidualConvBlock(nn.Module): 40 | def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'): 41 | super(ResidualConvBlock, self).__init__() 42 | 43 | ops = [] 44 | for i in range(n_stages): 45 | if i == 0: 46 | input_channel = n_filters_in 47 | else: 48 | input_channel = n_filters_out 49 | 50 | ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1)) 51 | if normalization == 'batchnorm': 52 | ops.append(nn.BatchNorm3d(n_filters_out)) 53 | elif normalization == 'groupnorm': 54 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 55 | elif normalization == 'instancenorm': 56 | ops.append(nn.InstanceNorm3d(n_filters_out)) 57 | elif normalization != 'none': 58 | assert False 59 | 60 | if i != n_stages-1: 61 | ops.append(nn.ReLU(inplace=True)) 62 | 63 | self.conv = nn.Sequential(*ops) 64 | self.relu = nn.ReLU(inplace=True) 65 | 66 | def forward(self, x): 67 | x = (self.conv(x) + x) 68 | x = self.relu(x) 69 | return x 70 | 71 | 72 | class DownsamplingConvBlock(nn.Module): 73 | def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'): 74 | super(DownsamplingConvBlock, self).__init__() 75 | 76 | ops = [] 77 | if normalization != 'none': 78 | ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 79 | if normalization == 'batchnorm': 80 | ops.append(nn.BatchNorm3d(n_filters_out)) 81 | elif normalization == 'groupnorm': 82 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 83 | elif normalization == 'instancenorm': 84 | ops.append(nn.InstanceNorm3d(n_filters_out)) 85 | else: 86 | assert False 87 | else: 88 | ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 89 | 90 | ops.append(nn.ReLU(inplace=True)) 91 | 92 | self.conv = nn.Sequential(*ops) 93 | 94 | def forward(self, x): 95 | x = self.conv(x) 96 | return x 97 | 98 | 99 | class UpsamplingDeconvBlock(nn.Module): 100 | def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'): 101 | super(UpsamplingDeconvBlock, self).__init__() 102 | 103 | ops = [] 104 | if normalization != 'none': 105 | ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 106 | if normalization == 'batchnorm': 107 | ops.append(nn.BatchNorm3d(n_filters_out)) 108 | elif normalization == 'groupnorm': 109 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 110 | elif normalization == 'instancenorm': 111 | ops.append(nn.InstanceNorm3d(n_filters_out)) 112 | else: 113 | assert False 114 | else: 115 | ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, 
padding=0, stride=stride)) 116 | 117 | ops.append(nn.ReLU(inplace=True)) 118 | 119 | self.conv = nn.Sequential(*ops) 120 | 121 | def forward(self, x): 122 | x = self.conv(x) 123 | return x 124 | 125 | 126 | class Upsampling(nn.Module): 127 | def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'): 128 | super(Upsampling, self).__init__() 129 | 130 | ops = [] 131 | ops.append(nn.Upsample(scale_factor=stride, mode='trilinear',align_corners=False)) 132 | ops.append(nn.Conv3d(n_filters_in, n_filters_out, kernel_size=3, padding=1)) 133 | if normalization == 'batchnorm': 134 | ops.append(nn.BatchNorm3d(n_filters_out)) 135 | elif normalization == 'groupnorm': 136 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 137 | elif normalization == 'instancenorm': 138 | ops.append(nn.InstanceNorm3d(n_filters_out)) 139 | elif normalization != 'none': 140 | assert False 141 | ops.append(nn.ReLU(inplace=True)) 142 | 143 | self.conv = nn.Sequential(*ops) 144 | 145 | def forward(self, x): 146 | x = self.conv(x) 147 | return x 148 | 149 | 150 | class VNet(nn.Module): 151 | def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False, has_residual=False): 152 | super(VNet, self).__init__() 153 | self.has_dropout = has_dropout 154 | convBlock = ConvBlock if not has_residual else ResidualConvBlock 155 | 156 | self.block_one = convBlock(1, n_channels, n_filters, normalization=normalization) 157 | self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization) 158 | 159 | self.block_two = convBlock(2, n_filters * 2, n_filters * 2, normalization=normalization) 160 | self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization) 161 | 162 | self.block_three = convBlock(3, n_filters * 4, n_filters * 4, normalization=normalization) 163 | self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization) 164 | 165 | self.block_four = convBlock(3, n_filters * 8, n_filters * 8, normalization=normalization) 166 | self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization) 167 | 168 | self.block_five = convBlock(3, n_filters * 16, n_filters * 16, normalization=normalization) 169 | self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization) 170 | 171 | self.block_six = convBlock(3, n_filters * 8, n_filters * 8, normalization=normalization) 172 | self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization) 173 | 174 | self.block_seven = convBlock(3, n_filters * 4, n_filters * 4, normalization=normalization) 175 | self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization) 176 | 177 | self.block_eight = convBlock(2, n_filters * 2, n_filters * 2, normalization=normalization) 178 | self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization) 179 | 180 | self.block_nine = convBlock(1, n_filters, n_filters, normalization=normalization) 181 | self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0) 182 | self.out_conv2 = nn.Conv3d(n_filters, n_classes, 1, padding=0) 183 | self.tanh = nn.Tanh() 184 | 185 | self.dropout = nn.Dropout3d(p=0.5, inplace=False) 186 | # self.__init_weight() 187 | 188 | def encoder(self, input): 189 | x1 = self.block_one(input) 190 | x1_dw = self.block_one_dw(x1) 191 | 192 | x2 = self.block_two(x1_dw) 193 | x2_dw 
= self.block_two_dw(x2)
194 | 
195 |         x3 = self.block_three(x2_dw)
196 |         x3_dw = self.block_three_dw(x3)
197 | 
198 |         x4 = self.block_four(x3_dw)
199 |         x4_dw = self.block_four_dw(x4)
200 | 
201 |         x5 = self.block_five(x4_dw)
202 |         # x5 = F.dropout3d(x5, p=0.5, training=True)
203 |         if self.has_dropout:
204 |             x5 = self.dropout(x5)
205 | 
206 |         res = [x1, x2, x3, x4, x5]
207 | 
208 |         return res
209 | 
210 |     def decoder(self, features):
211 |         x1 = features[0]
212 |         x2 = features[1]
213 |         x3 = features[2]
214 |         x4 = features[3]
215 |         x5 = features[4]
216 | 
217 |         x5_up = self.block_five_up(x5)
218 |         x5_up = x5_up + x4
219 | 
220 |         x6 = self.block_six(x5_up)
221 |         x6_up = self.block_six_up(x6)
222 |         x6_up = x6_up + x3
223 | 
224 |         x7 = self.block_seven(x6_up)
225 |         x7_up = self.block_seven_up(x7)
226 |         x7_up = x7_up + x2
227 | 
228 |         x8 = self.block_eight(x7_up)
229 |         x8_up = self.block_eight_up(x8)
230 |         x8_up = x8_up + x1
231 |         x9 = self.block_nine(x8_up)
232 |         # x9 = F.dropout3d(x9, p=0.5, training=True)
233 |         if self.has_dropout:
234 |             x9 = self.dropout(x9)
235 |         out = self.out_conv(x9)
236 |         out_tanh = self.tanh(out)
237 |         out_seg = self.out_conv2(x9)
238 |         return out_tanh, out_seg
239 | 
240 | 
241 |     def forward(self, input, turnoff_drop=False):
242 |         if turnoff_drop:
243 |             has_dropout = self.has_dropout
244 |             self.has_dropout = False
245 |         features = self.encoder(input)
246 |         out_tanh, out_seg = self.decoder(features)
247 |         if turnoff_drop:
248 |             self.has_dropout = has_dropout
249 |         return out_tanh, out_seg
250 | 
251 |     # def __init_weight(self):
252 |     #     for m in self.modules():
253 |     #         if isinstance(m, nn.Conv3d):
254 |     #             torch.nn.init.kaiming_normal_(m.weight)
255 |     #         elif isinstance(m, nn.BatchNorm3d):
256 |     #             m.weight.data.fill_(1)
257 | 
258 | if __name__ == '__main__':
259 |     # compute FLOPS & PARAMETERS
260 |     from thop import profile
261 |     from thop import clever_format
262 |     model = VNet(n_channels=1, n_classes=2)
263 |     input = torch.randn(4, 1, 112, 112, 80)
264 |     flops, params = profile(model, inputs=(input,))
265 |     macs, params = clever_format([flops, params], "%.3f")
266 |     print(macs, params)
267 |     print("VNet has {} parameters in total".format(sum(x.numel() for x in model.parameters())))
268 | 
269 |     # import ipdb; ipdb.set_trace()
--------------------------------------------------------------------------------
/code/SCPNettrain.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | import random
5 | import shutil
6 | import sys
7 | import time
8 | 
9 | import numpy as np
10 | import torch
11 | import torch.backends.cudnn as cudnn
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | import torch.optim as optim
15 | from tensorboardX import SummaryWriter
16 | from torch.nn import BCEWithLogitsLoss
17 | from torch.nn.modules.loss import CrossEntropyLoss
18 | from torch.utils.data import DataLoader
19 | from torchvision import transforms
20 | from torchvision.utils import make_grid
21 | from tqdm import tqdm
22 | 
23 | #from dataloaders import utils
24 | from dataloaders.dataset import BaseDataSets, RandomGenerator, TwoStreamBatchSampler
25 | from utils import losses, metrics, ramps, val_2d
26 | from networks.net_factory import net_factory
27 | 
28 | parser = argparse.ArgumentParser()
29 | parser.add_argument('--root_path', type=str,
30 |                     default='./data/ACDC', help='path of the dataset')
31 | parser.add_argument('--exp', type=str,
32 |                     default='SCACDClabel2', help='experiment_name')
33 | parser.add_argument('--model', type=str,
34 |                     default='unet_pro', help='model_name')
35 | parser.add_argument('--max_iterations', type=int,
36 |                     default=40000, help='maximum iteration number to train')
37 | parser.add_argument('--batch_size', type=int, default=24,
38 |                     help='batch_size per gpu')
39 | parser.add_argument('--deterministic', type=int, default=1,
40 |                     help='whether use deterministic training')
41 | parser.add_argument('--base_lr', type=float, default=0.1,
42 |                     help='segmentation network learning rate')
43 | parser.add_argument('--patch_size', type=list, default=[256, 256],
44 |                     help='patch size of network input')
45 | parser.add_argument('--seed', type=int, default=1337, help='random seed')
46 | parser.add_argument('--num_classes', type=int, default=4,
47 |                     help='output channel of network')
48 | 
49 | # label and unlabel
50 | parser.add_argument('--labeled_bs', type=int, default=12,
51 |                     help='labeled_batch_size per gpu')
52 | parser.add_argument('--labeled_num', type=int, default=28,
53 |                     help='labeled data')
54 | # costs
55 | parser.add_argument('--consistency', type=float,
56 |                     default=0.1, help='consistency')
57 | parser.add_argument('--consistency_rampup', type=float,
58 |                     default=1000.0, help='consistency_rampup')
59 | parser.add_argument('--gpu', type=str, default='0', help='GPU to use')
60 | args = parser.parse_args()
61 | 
62 | 
63 | def patients_to_slices(dataset, patients_num):
64 |     ref_dict = None
65 |     if "ACDC" in dataset:
66 |         ref_dict = {"3": 68, "7": 136,
67 |                     "14": 256, "21": 396, "28": 512, "35": 664, "140": 1312}
68 |     elif "Prostate" in dataset:
69 |         ref_dict = {"2": 47, "4": 111, "7": 191,
70 |                     "11": 306, "14": 391, "18": 478, "35": 940}
71 |     else:
72 |         raise ValueError("unknown dataset: {}".format(dataset))
73 |     return ref_dict[str(patients_num)]
74 | 
75 | 
76 | def get_current_consistency_weight(epoch):
77 |     # Consistency ramp-up from https://arxiv.org/abs/1610.02242
78 |     return args.consistency * ramps.sigmoid_rampup(epoch, args.consistency_rampup)
79 | 
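
To make the ramp-up concrete, a few values under the defaults above (consistency = 0.1, consistency_rampup = 1000; the training loop passes iter_num // 150 as the "epoch"):

# w(t) = consistency * exp(-5 * (1 - t / rampup)**2), with t clipped to [0, rampup]
for t in [0, 500, 1000]:
    print(t, get_current_consistency_weight(t))
# t = 0    -> 0.1 * exp(-5.00) ~ 0.00067
# t = 500  -> 0.1 * exp(-1.25) ~ 0.0287
# t = 1000 -> 0.1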
80 | 
81 | def train(args, snapshot_path):
82 |     os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
83 |     base_lr = args.base_lr
84 |     num_classes = args.num_classes
85 |     batch_size = args.batch_size
86 |     max_iterations = args.max_iterations
87 | 
88 |     model = net_factory(net_type=args.model, in_chns=1,
89 |                         class_num=num_classes)
90 |     #model.load_state_dict({k.replace('module.',''):v for k,v in torch.load('/home/zxzhang/MC-Net-Main/model/Prostate_EntropyProLearningAbL_cpcc_7_labeled/unet_pro/iter_16100_dice_0.7502.pth').items()})
91 | 
92 |     def worker_init_fn(worker_id):
93 |         random.seed(args.seed + worker_id)
94 | 
95 |     db_train = BaseDataSets(base_dir=args.root_path, split="train", num=None, transform=transforms.Compose([
96 |         RandomGenerator(args.patch_size)
97 |     ]))
98 |     db_val = BaseDataSets(base_dir=args.root_path, split="val")
99 |     total_slices = len(db_train)
100 |     labeled_slice = patients_to_slices(args.root_path, args.labeled_num)
101 |     print("Total slices is: {}, labeled slices is: {}".format(
102 |         total_slices, labeled_slice))
103 |     labeled_idxs = list(range(0, labeled_slice))
104 |     unlabeled_idxs = list(range(labeled_slice, total_slices))
105 |     batch_sampler = TwoStreamBatchSampler(
106 |         labeled_idxs, unlabeled_idxs, batch_size, batch_size - args.labeled_bs)
107 | 
108 |     trainloader = DataLoader(db_train, batch_sampler=batch_sampler,
109 |                              num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)
110 | 
111 |     model.train()
112 | 
113 |     valloader = DataLoader(db_val, batch_size=1, shuffle=False,
114 |                            num_workers=1)
115 | 
116 |     optimizer = optim.SGD(model.parameters(), lr=base_lr,
117 |                           momentum=0.9, weight_decay=0.0001)
118 |     ce_loss = CrossEntropyLoss()
119 |     dice_loss = losses.DiceLoss(num_classes)
120 |     self_proloss = losses.weight_self_pro_softmax_mse_loss
121 |     cross_proloss = losses.double_weight_cross_pro_softmax_mse_loss  # (weight, input_logits, target_logits)
122 |     writer = SummaryWriter(snapshot_path + '/log')
123 |     logging.info("{} iterations per epoch".format(len(trainloader)))
124 | 
125 |     iter_num = 0
126 |     max_epoch = max_iterations // len(trainloader) + 1
127 |     best_performance = 0.0
128 |     iterator = tqdm(range(max_epoch), ncols=70)
129 |     for epoch_num in iterator:
130 |         for i_batch, sampled_batch in enumerate(trainloader):
131 | 
132 | 
133 |             volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
134 |             volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
135 | 
136 |             outputs, selfproout, crossproout, entropy = model(volume_batch)
137 |             outputs_soft = torch.softmax(outputs, dim=1)
138 | 
139 |             loss_ce = ce_loss(outputs[:args.labeled_bs],
140 |                               label_batch[:args.labeled_bs][:].long())
141 | 
142 |             loss_dice = dice_loss(
143 |                 outputs_soft[:args.labeled_bs], label_batch[:args.labeled_bs].unsqueeze(1))
144 | 
145 |             supervised_loss = (loss_ce + loss_dice) / 2
146 | 
147 |             consistency_weight = get_current_consistency_weight(iter_num // 150)
148 |             consistency_self_pro = self_proloss(selfproout, outputs, entropy)
149 |             consistency_cross_pro = cross_proloss(selfproout, crossproout, outputs, entropy)
150 |             consistency_loss_aux1 = torch.mean(consistency_self_pro)
151 |             consistency_loss_aux2 = torch.mean(consistency_cross_pro)
152 |             consistency_loss = (consistency_loss_aux1 + consistency_loss_aux2) / 2
153 |             loss = supervised_loss + consistency_weight * consistency_loss
154 |             optimizer.zero_grad()
155 |             loss.backward()
156 |             optimizer.step()
157 | 
158 |             lr_ = base_lr * (1.0 - iter_num / max_iterations) ** 0.9
159 |             for param_group in optimizer.param_groups:
160 |                 param_group['lr'] = lr_
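            # poly learning-rate decay: lr = base_lr * (1 - iter_num / max_iterations) ** 0.9
            # a worked example with the defaults base_lr = 0.1 and max_iterations = 40000:
            #   iter_num = 0     -> lr = 0.100
            #   iter_num = 20000 -> lr = 0.1 * 0.5**0.9 ~ 0.054
            #   iter_num = 36000 -> lr = 0.1 * 0.1**0.9 ~ 0.013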
195 |                     best_performance = performance
196 |                     save_mode_path = os.path.join(snapshot_path,
197 |                                                   'iter_{}_dice_{}.pth'.format(iter_num, round(best_performance, 4)))
198 |                     save_best_path = os.path.join(snapshot_path, '{}_best_model.pth'.format(args.model))
199 |                     torch.save(model.state_dict(), save_mode_path)
200 |                     torch.save(model.state_dict(), save_best_path)
201 |
202 |                 logging.info('iteration %d : mean_dice : %f mean_hd95 : %f' % (iter_num, performance, mean_hd95))
203 |                 model.train()
204 |             if iter_num >= max_iterations:
205 |                 break
206 |         if iter_num >= max_iterations:
207 |             iterator.close()
208 |             break
209 |     writer.close()
210 |     return "Training Finished!"
211 |
212 |
213 |
214 |
215 | if __name__ == "__main__":
216 |     if not args.deterministic:
217 |         cudnn.benchmark = True
218 |         cudnn.deterministic = False
219 |     else:
220 |         cudnn.benchmark = False
221 |         cudnn.deterministic = True
222 |
223 |     random.seed(args.seed)
224 |     np.random.seed(args.seed)
225 |     torch.manual_seed(args.seed)
226 |     torch.cuda.manual_seed(args.seed)
227 |
228 |     snapshot_path = "./model/Prostate_{}_{}_labeled/{}".format(
229 |         args.exp, args.labeled_num, args.model)
230 |     if not os.path.exists(snapshot_path):
231 |         os.makedirs(snapshot_path)
232 |     logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO,
233 |                         format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
234 |     logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
235 |     logging.info(str(args))
236 |     train(args, snapshot_path)
237 |
--------------------------------------------------------------------------------
/code/dataloaders/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.utils.data import Dataset
4 | import h5py
5 | import itertools
6 | from scipy import ndimage
7 | import random
8 | from torch.utils.data.sampler import Sampler
9 | from skimage import transform as sk_trans
10 | from scipy.ndimage import rotate, zoom
11 |
12 | class BaseDataSets(Dataset):
13 |     def __init__(self, base_dir=None, split='train', num=None, transform=None):
14 |         self._base_dir = base_dir
15 |         self.sample_list = []
16 |         self.split = split
17 |         self.transform = transform
18 |         if self.split == 'train':
19 |             with open(self._base_dir + '/train_slices.list', 'r') as f1:
20 |                 self.sample_list = f1.readlines()
21 |             self.sample_list = [item.replace('\n', '') for item in self.sample_list]
22 |
23 |         elif self.split == 'val':
24 |             with open(self._base_dir + '/val.list', 'r') as f:
25 |                 self.sample_list = f.readlines()
26 |             self.sample_list = [item.replace('\n', '') for item in self.sample_list]
27 |         if num is not None and self.split == "train":
28 |             self.sample_list = self.sample_list[:num]
29 |         print("total {} samples".format(len(self.sample_list)))
30 |
31 |     def __len__(self):
32 |         return len(self.sample_list)
33 |
34 |     def __getitem__(self, idx):
35 |         case = self.sample_list[idx]
36 |         if self.split == "train":
37 |             h5f = h5py.File(self._base_dir + "/data/slices/{}.h5".format(case), 'r')
38 |         else:
39 |             h5f = h5py.File(self._base_dir + "/data/{}.h5".format(case), 'r')
40 |         image = h5f['image'][:]
41 |         label = h5f['label'][:]
42 |         # label[label==1]=0
43 |         # label[label==3]=0
44 |         # label[label==2]=1
45 |         sample = {'image': image, 'label': label}
46 |         if self.split == "train":
47 |             sample = self.transform(sample)
48 |         sample["idx"] = idx
49 |         return sample
50 |
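# --- Illustrative sketch (added for clarity, not part of the original file):
# minimal usage of BaseDataSets with the RandomGenerator transform defined
# below. The base_dir is hypothetical; it must contain train_slices.list and
# per-slice .h5 files with 'image' and 'label' datasets.
def _sketch_basedatasets_usage():
    from torchvision import transforms  # assumed available, as in the training script
    db = BaseDataSets(base_dir='/path/to/ACDC', split='train',
                      transform=transforms.Compose([RandomGenerator([256, 256])]))
    sample = db[0]
    # image: float32 tensor of shape (1, 256, 256); label: uint8 tensor (256, 256)
    print(sample['image'].shape, sample['label'].shape, sample['idx'])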
51 | def random_rot_flip(image, label):
52 |     k = np.random.randint(0, 4)
53 |     image = np.rot90(image, k)
54 |     label = np.rot90(label, k)
55 |     axis = np.random.randint(0, 2)
56 |     image = np.flip(image, axis=axis).copy()
57 |     label = np.flip(label, axis=axis).copy()
58 |     return image, label
59 |
60 |
61 | def random_rotate(image, label):
62 |     angle = np.random.randint(-20, 20)
63 |     image = ndimage.rotate(image, angle, order=0, reshape=False)
64 |     label = ndimage.rotate(label, angle, order=0, reshape=False)
65 |     return image, label
66 |
67 |
68 | class RandomGenerator(object):
69 |     def __init__(self, output_size):
70 |         self.output_size = output_size
71 |
72 |     def __call__(self, sample):
73 |         image, label = sample['image'], sample['label']
74 |         # ind = random.randrange(0, img.shape[0])
75 |         # image = img[ind, ...]
76 |         # label = lab[ind, ...]
77 |         if random.random() > 0.5:
78 |             image, label = random_rot_flip(image, label)
79 |         elif random.random() > 0.5:  # rotation is only tried when rot-flip is skipped (~25% of calls overall)
80 |             image, label = random_rotate(image, label)
81 |         x, y = image.shape
82 |         image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y), order=0)
83 |         label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0)
84 |         image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0)
85 |         label = torch.from_numpy(label.astype(np.uint8))
86 |         sample = {'image': image, 'label': label}
87 |         return sample
88 |
89 |
90 | class LAHeart(Dataset):
91 |     """ LA Dataset """
92 |     def __init__(self, base_dir=None, split='train', num=None, transform=None):
93 |         self._base_dir = base_dir
94 |         self.transform = transform
95 |         self.sample_list = []
96 |
97 |         train_path = self._base_dir+'/train.list'
98 |         test_path = self._base_dir+'/test.list'
99 |
100 |         if split=='train':
101 |             with open(train_path, 'r') as f:
102 |                 self.image_list = f.readlines()
103 |         elif split == 'test':
104 |             with open(test_path, 'r') as f:
105 |                 self.image_list = f.readlines()
106 |
107 |         self.image_list = [item.replace('\n','') for item in self.image_list]
108 |         if num is not None:
109 |             self.image_list = self.image_list[:num]
110 |         print("total {} samples".format(len(self.image_list)))
111 |
112 |     def __len__(self):
113 |         return len(self.image_list)
114 |
115 |     def __getitem__(self, idx):
116 |         image_name = self.image_list[idx]
117 |         h5f = h5py.File(self._base_dir + "/2018LA_Seg_TrainingSet/" + image_name + "/mri_norm2.h5", 'r')
118 |         image = h5f['image'][:]
119 |         label = h5f['label'][:]
120 |         sample = {'image': image, 'label': label}
121 |         if self.transform:
122 |             sample = self.transform(sample)
123 |
124 |         return sample
125 |
126 | class Pancreas(Dataset):
127 |     """ Pancreas Dataset """
128 |     def __init__(self, base_dir=None, split='train', num=None, transform=None):
129 |         self._base_dir = base_dir
130 |         self.transform = transform
131 |         self.sample_list = []
132 |
133 |         train_path = self._base_dir+'/train.list'
134 |         test_path = self._base_dir+'/test.list'
135 |
136 |         if split=='train':
137 |             with open(train_path, 'r') as f:
138 |                 self.image_list = f.readlines()
139 |         elif split == 'test':
140 |             with open(test_path, 'r') as f:
141 |                 self.image_list = f.readlines()
142 |
143 |         self.image_list = [item.replace('\n','') for item in self.image_list]
144 |         if num is not None:
145 |             self.image_list = self.image_list[:num]
146 |         print("total {} samples".format(len(self.image_list)))
147 |
148 |     def __len__(self):
149 |         return len(self.image_list)
150 |
151 |     def __getitem__(self, idx):
152 |         image_name = self.image_list[idx]
153 |         h5f = h5py.File(self._base_dir +'/' + image_name + "_norm.h5", 'r')
154 |         image = h5f['image'][:]
155 |         label = h5f['label'][:]
156 |         sample = {'image': image, 'label': label}
157 |         if self.transform:
158 |             sample = self.transform(sample)
159 |
160 |         return sample
161 |
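# --- Illustrative sketch (added for clarity, not part of the original file):
# how TwoStreamBatchSampler (defined at the bottom of this file) mixes labeled
# and unlabeled indices. With 4 labeled and 8 unlabeled indices, batch_size=4
# and secondary_batch_size=2, every batch is 2 labeled followed by 2 unlabeled
# indices, and one "epoch" is len(labeled) // 2 == 2 batches.
def _sketch_two_stream_batches():
    sampler = TwoStreamBatchSampler(primary_indices=list(range(4)),
                                    secondary_indices=list(range(4, 12)),
                                    batch_size=4, secondary_batch_size=2)
    for batch in sampler:
        print(batch)  # e.g. (2, 0, 7, 9) -- first half labeled, second half unlabeled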
162 | class Resize(object):
163 |
164 |     def __init__(self, output_size):
165 |         self.output_size = output_size
166 |
167 |     def __call__(self, sample):
168 |         image, label = sample['image'], sample['label']
169 |         (w, h, d) = image.shape
170 |         label = label.astype(bool)
171 |         image = sk_trans.resize(image, self.output_size, order=1, mode='constant', cval=0)
172 |         label = sk_trans.resize(label, self.output_size, order=0)
173 |         assert (np.max(label) == 1 and np.min(label) == 0)
174 |         assert (np.unique(label).shape[0] == 2)
175 |
176 |         return {'image': image, 'label': label}
177 |
178 |
179 | class CenterCrop(object):
180 |     def __init__(self, output_size):
181 |         self.output_size = output_size
182 |
183 |     def __call__(self, sample):
184 |         image, label = sample['image'], sample['label']
185 |
186 |         # pad the sample if necessary
187 |         if label.shape[0] <= self.output_size[0] or label.shape[1] <= self.output_size[1] or label.shape[2] <= \
188 |                 self.output_size[2]:
189 |             pw = max((self.output_size[0] - label.shape[0]) // 2 + 3, 0)
190 |             ph = max((self.output_size[1] - label.shape[1]) // 2 + 3, 0)
191 |             pd = max((self.output_size[2] - label.shape[2]) // 2 + 3, 0)
192 |             image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
193 |             label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
194 |
195 |         (w, h, d) = image.shape
196 |
197 |         w1 = int(round((w - self.output_size[0]) / 2.))
198 |         h1 = int(round((h - self.output_size[1]) / 2.))
199 |         d1 = int(round((d - self.output_size[2]) / 2.))
200 |
201 |         label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
202 |         image = image[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
203 |
204 |         return {'image': image, 'label': label}
205 |
206 |
207 | class RandomCrop(object):
208 |     """
209 |     Crop randomly the image in a sample
210 |     Args:
211 |         output_size (int): Desired output size
212 |     """
213 |
214 |     def __init__(self, output_size, with_sdf=False):
215 |         self.output_size = output_size
216 |         self.with_sdf = with_sdf
217 |
218 |     def __call__(self, sample):
219 |         image, label = sample['image'], sample['label']
220 |         if self.with_sdf:
221 |             sdf = sample['sdf']
222 |
223 |         # pad the sample if necessary
224 |         if label.shape[0] <= self.output_size[0] or label.shape[1] <= self.output_size[1] or label.shape[2] <= \
225 |                 self.output_size[2]:
226 |             pw = max((self.output_size[0] - label.shape[0]) // 2 + 3, 0)
227 |             ph = max((self.output_size[1] - label.shape[1]) // 2 + 3, 0)
228 |             pd = max((self.output_size[2] - label.shape[2]) // 2 + 3, 0)
229 |             image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
230 |             label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
231 |             if self.with_sdf:
232 |                 sdf = np.pad(sdf, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
233 |
234 |         (w, h, d) = image.shape
235 |
236 |         w1 = np.random.randint(0, w - self.output_size[0])
237 |         h1 = np.random.randint(0, h - self.output_size[1])
238 |         d1 = np.random.randint(0, d - self.output_size[2])
239 |
240 |         label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
241 |         image = image[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
242 |         if self.with_sdf:
243 |             sdf = sdf[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
244 |             return {'image': image, 'label': label, 'sdf': sdf}
245 |         else:
246 |             return {'image': image, 'label': label}
247 |
248 |
249 | class RandomRotFlip(object):
250 |     """
251 |     Randomly rotate (by a multiple of 90 degrees) and flip the image and label in a sample
252 |     Args:
253 |         sample (dict): holds the 'image' and 'label' arrays
254 |     """
255 |
256 |     def __call__(self, sample):
257 |         image, label = sample['image'], sample['label']
258 |         image, label = random_rot_flip(image, label)
259 |
260 |         return {'image': image, 'label': label}
261 |
262 | class RandomRot(object):
263 |     """
264 |     Randomly rotate the image and label in a sample by a small angle
265 |     Args:
266 |         sample (dict): holds the 'image' and 'label' arrays
267 |     """
268 |
269 |     def __call__(self, sample):
270 |         image, label = sample['image'], sample['label']
271 |         image, label = random_rotate(image, label)
272 |
273 |         return {'image': image, 'label': label}
274 |
275 |
276 | class RandomNoise(object):
277 |     def __init__(self, mu=0, sigma=0.1):
278 |         self.mu = mu
279 |         self.sigma = sigma
280 |
281 |     def __call__(self, sample):
282 |         image, label = sample['image'], sample['label']
283 |         noise = np.clip(self.sigma * np.random.randn(image.shape[0], image.shape[1], image.shape[2]), -2*self.sigma, 2*self.sigma)
284 |         noise = noise + self.mu
285 |         image = image + noise
286 |         return {'image': image, 'label': label}
287 |
288 |
289 | class CreateOnehotLabel(object):
290 |     def __init__(self, num_classes):
291 |         self.num_classes = num_classes
292 |
293 |     def __call__(self, sample):
294 |         image, label = sample['image'], sample['label']
295 |         onehot_label = np.zeros((self.num_classes, label.shape[0], label.shape[1], label.shape[2]), dtype=np.float32)
296 |         for i in range(self.num_classes):
297 |             onehot_label[i, :, :, :] = (label == i).astype(np.float32)
298 |         return {'image': image, 'label': label, 'onehot_label': onehot_label}
299 |
300 |
301 | class ToTensor(object):
302 |     """Convert ndarrays in sample to Tensors."""
303 |
304 |     def __call__(self, sample):
305 |         image = sample['image']
306 |         image = image.reshape(1, image.shape[0], image.shape[1], image.shape[2]).astype(np.float32)
307 |         if 'onehot_label' in sample:
308 |             return {'image': torch.from_numpy(image), 'label': torch.from_numpy(sample['label']).long(),
309 |                     'onehot_label': torch.from_numpy(sample['onehot_label']).long()}
310 |         else:
311 |             return {'image': torch.from_numpy(image), 'label': torch.from_numpy(sample['label']).long()}
312 |
313 |
314 | class TwoStreamBatchSampler(Sampler):
315 |     """Iterate two sets of indices
316 |
317 |     An 'epoch' is one iteration through the primary indices.
318 |     During the epoch, the secondary indices are iterated through
319 |     as many times as needed.
320 | """ 321 | def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size): 322 | self.primary_indices = primary_indices 323 | self.secondary_indices = secondary_indices 324 | self.secondary_batch_size = secondary_batch_size 325 | self.primary_batch_size = batch_size - secondary_batch_size 326 | 327 | assert len(self.primary_indices) >= self.primary_batch_size > 0 328 | assert len(self.secondary_indices) >= self.secondary_batch_size > 0 329 | 330 | def __iter__(self): 331 | primary_iter = iterate_once(self.primary_indices) 332 | secondary_iter = iterate_eternally(self.secondary_indices) 333 | return ( 334 | primary_batch + secondary_batch 335 | for (primary_batch, secondary_batch) 336 | in zip(grouper(primary_iter, self.primary_batch_size), 337 | grouper(secondary_iter, self.secondary_batch_size)) 338 | ) 339 | 340 | def __len__(self): 341 | return len(self.primary_indices) // self.primary_batch_size 342 | 343 | def iterate_once(iterable): 344 | return np.random.permutation(iterable) 345 | 346 | 347 | def iterate_eternally(indices): 348 | def infinite_shuffles(): 349 | while True: 350 | yield np.random.permutation(indices) 351 | return itertools.chain.from_iterable(infinite_shuffles()) 352 | 353 | 354 | def grouper(iterable, n): 355 | "Collect data into fixed-length chunks or blocks" 356 | # grouper('ABCDEFG', 3) --> ABC DEF" 357 | args = [iter(iterable)] * n 358 | return zip(*args) 359 | -------------------------------------------------------------------------------- /code/networks/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from networks.networks_other import init_weights 6 | 7 | 8 | class conv2DBatchNorm(nn.Module): 9 | def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True): 10 | super(conv2DBatchNorm, self).__init__() 11 | 12 | self.cb_unit = nn.Sequential(nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size, 13 | padding=padding, stride=stride, bias=bias), 14 | nn.BatchNorm2d(int(n_filters)),) 15 | 16 | def forward(self, inputs): 17 | outputs = self.cb_unit(inputs) 18 | return outputs 19 | 20 | 21 | class deconv2DBatchNorm(nn.Module): 22 | def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True): 23 | super(deconv2DBatchNorm, self).__init__() 24 | 25 | self.dcb_unit = nn.Sequential(nn.ConvTranspose2d(int(in_channels), int(n_filters), kernel_size=k_size, 26 | padding=padding, stride=stride, bias=bias), 27 | nn.BatchNorm2d(int(n_filters)),) 28 | 29 | def forward(self, inputs): 30 | outputs = self.dcb_unit(inputs) 31 | return outputs 32 | 33 | 34 | class conv2DBatchNormRelu(nn.Module): 35 | def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True): 36 | super(conv2DBatchNormRelu, self).__init__() 37 | 38 | self.cbr_unit = nn.Sequential(nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size, 39 | padding=padding, stride=stride, bias=bias), 40 | nn.BatchNorm2d(int(n_filters)), 41 | nn.ReLU(inplace=True),) 42 | 43 | def forward(self, inputs): 44 | outputs = self.cbr_unit(inputs) 45 | return outputs 46 | 47 | 48 | class deconv2DBatchNormRelu(nn.Module): 49 | def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True): 50 | super(deconv2DBatchNormRelu, self).__init__() 51 | 52 | self.dcbr_unit = nn.Sequential(nn.ConvTranspose2d(int(in_channels), int(n_filters), kernel_size=k_size, 53 | padding=padding, stride=stride, bias=bias), 
54 | nn.BatchNorm2d(int(n_filters)), 55 | nn.ReLU(inplace=True),) 56 | 57 | def forward(self, inputs): 58 | outputs = self.dcbr_unit(inputs) 59 | return outputs 60 | 61 | 62 | class unetConv2(nn.Module): 63 | def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1): 64 | super(unetConv2, self).__init__() 65 | self.n = n 66 | self.ks = ks 67 | self.stride = stride 68 | self.padding = padding 69 | s = stride 70 | p = padding 71 | if is_batchnorm: 72 | for i in range(1, n+1): 73 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), 74 | nn.BatchNorm2d(out_size), 75 | nn.ReLU(inplace=True),) 76 | setattr(self, 'conv%d'%i, conv) 77 | in_size = out_size 78 | 79 | else: 80 | for i in range(1, n+1): 81 | conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), 82 | nn.ReLU(inplace=True),) 83 | setattr(self, 'conv%d'%i, conv) 84 | in_size = out_size 85 | 86 | # initialise the blocks 87 | for m in self.children(): 88 | init_weights(m, init_type='kaiming') 89 | 90 | def forward(self, inputs): 91 | x = inputs 92 | for i in range(1, self.n+1): 93 | conv = getattr(self, 'conv%d'%i) 94 | x = conv(x) 95 | 96 | return x 97 | 98 | 99 | class UnetConv3(nn.Module): 100 | def __init__(self, in_size, out_size, is_batchnorm, kernel_size=(3,3,1), padding_size=(1,1,0), init_stride=(1,1,1)): 101 | super(UnetConv3, self).__init__() 102 | 103 | if is_batchnorm: 104 | self.conv1 = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, init_stride, padding_size), 105 | nn.InstanceNorm3d(out_size), 106 | nn.ReLU(inplace=True),) 107 | self.conv2 = nn.Sequential(nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size), 108 | nn.InstanceNorm3d(out_size), 109 | nn.ReLU(inplace=True),) 110 | else: 111 | self.conv1 = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, init_stride, padding_size), 112 | nn.ReLU(inplace=True),) 113 | self.conv2 = nn.Sequential(nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size), 114 | nn.ReLU(inplace=True),) 115 | 116 | # initialise the blocks 117 | for m in self.children(): 118 | init_weights(m, init_type='kaiming') 119 | 120 | def forward(self, inputs): 121 | outputs = self.conv1(inputs) 122 | outputs = self.conv2(outputs) 123 | return outputs 124 | 125 | 126 | class FCNConv3(nn.Module): 127 | def __init__(self, in_size, out_size, is_batchnorm, kernel_size=(3,3,1), padding_size=(1,1,0), init_stride=(1,1,1)): 128 | super(FCNConv3, self).__init__() 129 | 130 | if is_batchnorm: 131 | self.conv1 = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, init_stride, padding_size), 132 | nn.InstanceNorm3d(out_size), 133 | nn.ReLU(inplace=True),) 134 | self.conv2 = nn.Sequential(nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size), 135 | nn.InstanceNorm3d(out_size), 136 | nn.ReLU(inplace=True),) 137 | self.conv3 = nn.Sequential(nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size), 138 | nn.InstanceNorm3d(out_size), 139 | nn.ReLU(inplace=True),) 140 | else: 141 | self.conv1 = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, init_stride, padding_size), 142 | nn.ReLU(inplace=True),) 143 | self.conv2 = nn.Sequential(nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size), 144 | nn.ReLU(inplace=True),) 145 | self.conv3 = nn.Sequential(nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size), 146 | nn.ReLU(inplace=True),) 147 | 148 | # initialise the blocks 149 | for m in self.children(): 150 | init_weights(m, init_type='kaiming') 151 | 152 | def forward(self, inputs): 153 | outputs = self.conv1(inputs) 154 | outputs = 
self.conv2(outputs) 155 | outputs = self.conv3(outputs) 156 | return outputs 157 | 158 | 159 | class UnetGatingSignal3(nn.Module): 160 | def __init__(self, in_size, out_size, is_batchnorm): 161 | super(UnetGatingSignal3, self).__init__() 162 | self.fmap_size = (4, 4, 4) 163 | 164 | if is_batchnorm: 165 | self.conv1 = nn.Sequential(nn.Conv3d(in_size, in_size//2, (1,1,1), (1,1,1), (0,0,0)), 166 | nn.InstanceNorm3d(in_size//2), 167 | nn.ReLU(inplace=True), 168 | nn.AdaptiveAvgPool3d(output_size=self.fmap_size), 169 | ) 170 | self.fc1 = nn.Linear(in_features=(in_size//2) * self.fmap_size[0] * self.fmap_size[1] * self.fmap_size[2], 171 | out_features=out_size, bias=True) 172 | else: 173 | self.conv1 = nn.Sequential(nn.Conv3d(in_size, in_size//2, (1,1,1), (1,1,1), (0,0,0)), 174 | nn.ReLU(inplace=True), 175 | nn.AdaptiveAvgPool3d(output_size=self.fmap_size), 176 | ) 177 | self.fc1 = nn.Linear(in_features=(in_size//2) * self.fmap_size[0] * self.fmap_size[1] * self.fmap_size[2], 178 | out_features=out_size, bias=True) 179 | 180 | # initialise the blocks 181 | for m in self.children(): 182 | init_weights(m, init_type='kaiming') 183 | 184 | def forward(self, inputs): 185 | batch_size = inputs.size(0) 186 | outputs = self.conv1(inputs) 187 | outputs = outputs.view(batch_size, -1) 188 | outputs = self.fc1(outputs) 189 | return outputs 190 | 191 | 192 | class UnetGridGatingSignal3(nn.Module): 193 | def __init__(self, in_size, out_size, kernel_size=(1,1,1), is_batchnorm=True): 194 | super(UnetGridGatingSignal3, self).__init__() 195 | 196 | if is_batchnorm: 197 | self.conv1 = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, (1,1,1), (0,0,0)), 198 | nn.InstanceNorm3d(out_size), 199 | nn.ReLU(inplace=True), 200 | ) 201 | else: 202 | self.conv1 = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, (1,1,1), (0,0,0)), 203 | nn.ReLU(inplace=True), 204 | ) 205 | 206 | # initialise the blocks 207 | for m in self.children(): 208 | init_weights(m, init_type='kaiming') 209 | 210 | def forward(self, inputs): 211 | outputs = self.conv1(inputs) 212 | return outputs 213 | 214 | 215 | class unetUp(nn.Module): 216 | def __init__(self, in_size, out_size, is_deconv): 217 | super(unetUp, self).__init__() 218 | self.conv = unetConv2(in_size, out_size, False) 219 | if is_deconv: 220 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=4, stride=2, padding=1) 221 | else: 222 | self.up = nn.UpsamplingBilinear2d(scale_factor=2) 223 | 224 | # initialise the blocks 225 | for m in self.children(): 226 | if m.__class__.__name__.find('unetConv2') != -1: continue 227 | init_weights(m, init_type='kaiming') 228 | 229 | def forward(self, inputs1, inputs2): 230 | outputs2 = self.up(inputs2) 231 | offset = outputs2.size()[2] - inputs1.size()[2] 232 | padding = 2 * [offset // 2, offset // 2] 233 | outputs1 = F.pad(inputs1, padding) 234 | return self.conv(torch.cat([outputs1, outputs2], 1)) 235 | 236 | 237 | class UnetUp3(nn.Module): 238 | def __init__(self, in_size, out_size, is_deconv, is_batchnorm=True): 239 | super(UnetUp3, self).__init__() 240 | if is_deconv: 241 | self.conv = UnetConv3(in_size, out_size, is_batchnorm) 242 | self.up = nn.ConvTranspose3d(in_size, out_size, kernel_size=(4,4,1), stride=(2,2,1), padding=(1,1,0)) 243 | else: 244 | self.conv = UnetConv3(in_size+out_size, out_size, is_batchnorm) 245 | self.up = nn.Upsample(scale_factor=(2, 2, 1), mode='trilinear') 246 | 247 | # initialise the blocks 248 | for m in self.children(): 249 | if m.__class__.__name__.find('UnetConv3') != -1: continue 250 | 
init_weights(m, init_type='kaiming') 251 | 252 | def forward(self, inputs1, inputs2): 253 | outputs2 = self.up(inputs2) 254 | offset = outputs2.size()[2] - inputs1.size()[2] 255 | padding = 2 * [offset // 2, offset // 2, 0] 256 | outputs1 = F.pad(inputs1, padding) 257 | return self.conv(torch.cat([outputs1, outputs2], 1)) 258 | 259 | 260 | class UnetUp3_CT(nn.Module): 261 | def __init__(self, in_size, out_size, is_batchnorm=True): 262 | super(UnetUp3_CT, self).__init__() 263 | self.conv = UnetConv3(in_size + out_size, out_size, is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) 264 | self.up = nn.Upsample(scale_factor=(2, 2, 2), mode='trilinear') 265 | 266 | # initialise the blocks 267 | for m in self.children(): 268 | if m.__class__.__name__.find('UnetConv3') != -1: continue 269 | init_weights(m, init_type='kaiming') 270 | 271 | def forward(self, inputs1, inputs2): 272 | outputs2 = self.up(inputs2) 273 | offset = outputs2.size()[2] - inputs1.size()[2] 274 | padding = 2 * [offset // 2, offset // 2, 0] 275 | outputs1 = F.pad(inputs1, padding) 276 | return self.conv(torch.cat([outputs1, outputs2], 1)) 277 | 278 | 279 | # Squeeze-and-Excitation Network 280 | class SqEx(nn.Module): 281 | 282 | def __init__(self, n_features, reduction=6): 283 | super(SqEx, self).__init__() 284 | 285 | if n_features % reduction != 0: 286 | raise ValueError('n_features must be divisible by reduction (default = 4)') 287 | 288 | self.linear1 = nn.Linear(n_features, n_features // reduction, bias=False) 289 | self.nonlin1 = nn.ReLU(inplace=True) 290 | self.linear2 = nn.Linear(n_features // reduction, n_features, bias=False) 291 | self.nonlin2 = nn.Sigmoid() 292 | 293 | def forward(self, x): 294 | 295 | y = F.avg_pool3d(x, kernel_size=x.size()[2:5]) 296 | y = y.permute(0, 2, 3, 4, 1) 297 | y = self.nonlin1(self.linear1(y)) 298 | y = self.nonlin2(self.linear2(y)) 299 | y = y.permute(0, 4, 1, 2, 3) 300 | y = x * y 301 | return y 302 | 303 | class UnetUp3_SqEx(nn.Module): 304 | def __init__(self, in_size, out_size, is_deconv, is_batchnorm): 305 | super(UnetUp3_SqEx, self).__init__() 306 | if is_deconv: 307 | self.sqex = SqEx(n_features=in_size+out_size) 308 | self.conv = UnetConv3(in_size, out_size, is_batchnorm) 309 | self.up = nn.ConvTranspose3d(in_size, out_size, kernel_size=(4,4,1), stride=(2,2,1), padding=(1,1,0)) 310 | else: 311 | self.sqex = SqEx(n_features=in_size+out_size) 312 | self.conv = UnetConv3(in_size+out_size, out_size, is_batchnorm) 313 | self.up = nn.Upsample(scale_factor=(2, 2, 1), mode='trilinear') 314 | 315 | # initialise the blocks 316 | for m in self.children(): 317 | if m.__class__.__name__.find('UnetConv3') != -1: continue 318 | init_weights(m, init_type='kaiming') 319 | 320 | def forward(self, inputs1, inputs2): 321 | outputs2 = self.up(inputs2) 322 | offset = outputs2.size()[2] - inputs1.size()[2] 323 | padding = 2 * [offset // 2, offset // 2, 0] 324 | outputs1 = F.pad(inputs1, padding) 325 | concat = torch.cat([outputs1, outputs2], 1) 326 | gated = self.sqex(concat) 327 | return self.conv(gated) 328 | 329 | class residualBlock(nn.Module): 330 | expansion = 1 331 | 332 | def __init__(self, in_channels, n_filters, stride=1, downsample=None): 333 | super(residualBlock, self).__init__() 334 | 335 | self.convbnrelu1 = conv2DBatchNormRelu(in_channels, n_filters, 3, stride, 1, bias=False) 336 | self.convbn2 = conv2DBatchNorm(n_filters, n_filters, 3, 1, 1, bias=False) 337 | self.downsample = downsample 338 | self.stride = stride 339 | self.relu = nn.ReLU(inplace=True) 340 | 341 | def 
forward(self, x):
342 |         residual = x
343 |
344 |         out = self.convbnrelu1(x)
345 |         out = self.convbn2(out)
346 |
347 |         if self.downsample is not None:
348 |             residual = self.downsample(x)
349 |
350 |         out += residual
351 |         out = self.relu(out)
352 |         return out
353 |
354 |
355 | class residualBottleneck(nn.Module):
356 |     expansion = 4
357 |
358 |     def __init__(self, in_channels, n_filters, stride=1, downsample=None):
359 |         super(residualBottleneck, self).__init__()
360 |         self.convbn1 = conv2DBatchNorm(in_channels, n_filters, k_size=1, stride=1, padding=0, bias=False)
361 |         self.convbn2 = conv2DBatchNorm(n_filters, n_filters, k_size=3, stride=stride, padding=1, bias=False)
362 |         self.convbn3 = conv2DBatchNorm(n_filters, n_filters * 4, k_size=1, stride=1, padding=0, bias=False)
363 |         self.relu = nn.ReLU(inplace=True)
364 |         self.downsample = downsample
365 |         self.stride = stride
366 |
367 |     def forward(self, x):
368 |         residual = x
369 |
370 |         out = self.convbn1(x)
371 |         out = self.convbn2(out)
372 |         out = self.convbn3(out)
373 |
374 |         if self.downsample is not None:
375 |             residual = self.downsample(x)
376 |
377 |         out += residual
378 |         out = self.relu(out)
379 |
380 |         return out
381 |
382 |
383 |
384 |
385 | class SeqModelFeatureExtractor(nn.Module):
386 |     def __init__(self, submodule, extracted_layers):
387 |         super(SeqModelFeatureExtractor, self).__init__()
388 |
389 |         self.submodule = submodule
390 |         self.extracted_layers = extracted_layers
391 |
392 |     def forward(self, x):
393 |         outputs = []
394 |         for name, module in self.submodule._modules.items():
395 |             x = module(x)
396 |             if name in self.extracted_layers:
397 |                 outputs += [x]
398 |         return outputs + [x]
399 |
400 |
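# --- Illustrative sketch (added for clarity, not part of the original file):
# SeqModelFeatureExtractor walks a wrapped module's children in order and
# collects the activations of the named layers, plus the final output. The
# tiny network below is hypothetical and only serves to show the mechanics.
def _sketch_feature_extraction():
    net = nn.Sequential()
    net.add_module('conv', nn.Conv2d(1, 8, 3, padding=1))
    net.add_module('relu', nn.ReLU())
    net.add_module('pool', nn.MaxPool2d(2))
    extractor = SeqModelFeatureExtractor(net, extracted_layers=['relu'])
    feats = extractor(torch.randn(1, 1, 32, 32))
    # feats[0] is the 'relu' activation, feats[-1] the final output:
    print([f.shape for f in feats])  # [(1, 8, 32, 32), (1, 8, 16, 16)]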
401 | class HookBasedFeatureExtractor(nn.Module):
402 |     def __init__(self, submodule, layername, upscale=False):
403 |         super(HookBasedFeatureExtractor, self).__init__()
404 |
405 |         self.submodule = submodule
406 |         self.submodule.eval()
407 |         self.layername = layername
408 |         self.outputs_size = None
409 |         self.outputs = None
410 |         self.inputs = None
411 |         self.inputs_size = None
412 |         self.upscale = upscale
413 |
414 |     def get_input_array(self, m, i, o):
415 |         if isinstance(i, tuple):
416 |             self.inputs = [i[index].data.clone() for index in range(len(i))]
417 |             self.inputs_size = [input.size() for input in self.inputs]
418 |         else:
419 |             self.inputs = i.data.clone()
420 |             self.inputs_size = self.inputs.size()
421 |         print('Input Array Size: ', self.inputs_size)
422 |
423 |     def get_output_array(self, m, i, o):
424 |         if isinstance(o, tuple):
425 |             self.outputs = [o[index].data.clone() for index in range(len(o))]
426 |             self.outputs_size = [output.size() for output in self.outputs]
427 |         else:
428 |             self.outputs = o.data.clone()
429 |             self.outputs_size = self.outputs.size()
430 |         print('Output Array Size: ', self.outputs_size)
431 |
432 |     def rescale_output_array(self, newsize):
433 |         us = nn.Upsample(size=newsize[2:], mode='bilinear')
434 |         if isinstance(self.outputs, list):
435 |             for index in range(len(self.outputs)): self.outputs[index] = us(self.outputs[index]).data
436 |         else:
437 |             self.outputs = us(self.outputs).data
438 |
439 |     def forward(self, x):
440 |         target_layer = self.submodule._modules.get(self.layername)
441 |
442 |         # Collect the input and output tensors of the target layer via forward hooks
443 |         h_inp = target_layer.register_forward_hook(self.get_input_array)
444 |         h_out = target_layer.register_forward_hook(self.get_output_array)
445 |         self.submodule(x)
446 |         h_inp.remove()
447 |         h_out.remove()
448 |
449 |         # Rescale the feature map if required
450 |         if self.upscale: self.rescale_output_array(x.size())
451 |
452 |         return self.inputs, self.outputs
453 |
454 |
455 | class UnetDsv3(nn.Module):
456 |     def __init__(self, in_size, out_size, scale_factor):
457 |         super(UnetDsv3, self).__init__()
458 |         self.dsv = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size=1, stride=1, padding=0),
459 |                                  nn.Upsample(scale_factor=scale_factor, mode='trilinear'), )
460 |
461 |     def forward(self, input):
462 |         return self.dsv(input)
--------------------------------------------------------------------------------
/code/utils/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from skimage.measure import label
5 | import numpy as np
6 | def weight_self_pro_softmax_mse_loss(input_logits, target_logits, entropy):
7 |     """MSE between the softmaxed self-prototype prediction and the softmax of
8 |     the classifier logits; the entropy argument is currently unused.
9 |     Note:
10 |     - Returns the per-element loss; take the mean afterwards.
11 |     - Sends gradients to the classifier logits, not the prototype side.
12 |     """
13 |     target_logits = target_logits.view(target_logits.size(0), target_logits.size(1), -1)
14 |     target_logits = target_logits.transpose(1, 2)  # [N, HW, C]
15 |     assert input_logits.size() == target_logits.size()
16 |     input_softmax = F.softmax(input_logits, dim=2)
17 |     target_softmax = F.softmax(target_logits, dim=2)
18 |     mse_loss = (input_softmax.detach() - target_softmax) ** 2
19 |     # entropy = 1 - entropy.unsqueeze(-1).detach()
20 |     # mse_loss = entropy * mse_loss
21 |     return mse_loss
22 | def weight_cross_pro_softmax_mse_loss(weight, input_logits, target_logits):  # target_logits = classifier output, input_logits = cross-prototype prediction
23 |     """MSE between the cross-prototype prediction and the softmax of the
24 |     classifier logits, weighted by the detached prototype confidence.
25 |     Note:
26 |     - Returns the per-element loss; take the mean afterwards.
27 |     - Sends gradients to the classifier logits, not the prototype side.
28 |     """
29 |     weight = F.softmax(weight, dim=2)
30 |     target_logits = target_logits.view(target_logits.size(0), target_logits.size(1), -1)
31 |     target_logits = target_logits.transpose(1, 2)  # [N, HW, C]
32 |     assert input_logits.size() == target_logits.size()
33 |     # input_softmax = F.softmax(input_logits, dim=2)
34 |     target_softmax = F.softmax(target_logits, dim=2)
35 |     mse_loss = (input_logits.detach() - target_softmax) ** 2
36 |     mse_loss = weight.detach() * mse_loss
37 |     return mse_loss
38 | def double_weight_cross_pro_softmax_mse_loss(weight, input_logits, target_logits, entropy):  # target_logits = classifier output, input_logits = cross-prototype prediction
39 |     """MSE between the cross-prototype prediction and the softmax of the
40 |     classifier logits, weighted both by the detached prototype confidence and
41 |     by (1 - entropy).
42 |     Note:
43 |     - Sends gradients to the classifier logits, not the prototype side.
44 | """ 45 | weight = F.softmax(weight,dim=2) 46 | weight = torch.max(weight,dim=2,keepdim=True)[0] 47 | target_logits = target_logits.view(target_logits .size(0), target_logits .size(1), -1) 48 | target_logits = target_logits.transpose(1, 2) # [N, HW, C] 49 | assert input_logits.size() == target_logits.size() 50 | #input_softmax = F.softmax(input_logits, dim=2) 51 | target_softmax = F.softmax(target_logits, dim=2) 52 | mse_loss = (input_logits.detach()-target_softmax)**2 53 | mse_loss = weight.detach()*mse_loss 54 | entropy = 1 - entropy.unsqueeze(-1).detach() 55 | return entropy * mse_loss 56 | def softmax_kl_loss(input_logits, target_logits, sigmoid=False): 57 | """Takes softmax on both sides and returns KL divergence 58 | 59 | Note: 60 | - Returns the sum over all examples. Divide by the batch size afterwards 61 | if you want the mean. 62 | - Sends gradients to inputs but not the targets. 63 | """ 64 | assert input_logits.size() == target_logits.size() 65 | if sigmoid: 66 | input_log_softmax = torch.log(torch.sigmoid(input_logits)) 67 | target_softmax = torch.sigmoid(target_logits) 68 | else: 69 | input_log_softmax = F.log_softmax(input_logits, dim=1) 70 | target_softmax = F.softmax(target_logits, dim=1) 71 | 72 | # return F.kl_div(input_log_softmax, target_softmax) 73 | kl_div = F.kl_div(input_log_softmax, target_softmax, reduction='mean') 74 | # mean_kl_div = torch.mean(0.2*kl_div[:,0,...]+0.8*kl_div[:,1,...]) 75 | return kl_div 76 | def dice_loss(score, target): 77 | target = target.float() 78 | smooth = 1e-5 79 | intersect = torch.sum(score * target) 80 | y_sum = torch.sum(target * target) 81 | z_sum = torch.sum(score * score) 82 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 83 | loss = 1 - loss 84 | return loss 85 | def softmax_mse_loss(input_logits, target_logits): 86 | """Takes softmax on both sides and returns MSE loss 87 | Note: 88 | - Returns the sum over all examples. Divide by the batch size afterwards 89 | if you want the mean. 90 | - Sends gradients to inputs but not the targets. 91 | """ 92 | assert input_logits.size() == target_logits.size() 93 | input_softmax = F.softmax(input_logits, dim=1) 94 | #target_softmax = F.softmax(target_logits, dim=1) 95 | 96 | mse_loss = (input_softmax-target_logits)**2 97 | return mse_loss 98 | def softmax_mae_loss_EGV(input_logits, target_logits): 99 | """Takes softmax on both sides and returns MSE loss 100 | Note: 101 | - Returns the sum over all examples. Divide by the batch size afterwards 102 | if you want the mean. 103 | - Sends gradients to inputs but not the targets. 104 | """ 105 | assert input_logits.size() == target_logits.size() 106 | input_softmax = F.softmax(input_logits, dim=1) 107 | #target_softmax = F.softmax(target_logits, dim=1) 108 | 109 | mae_loss = torch.abs(input_softmax-target_logits) 110 | return mae_loss 111 | def softmax_mse_loss_VGE(input_logits, target_logits): 112 | """Takes softmax on both sides and returns MSE loss 113 | Note: 114 | - Returns the sum over all examples. Divide by the batch size afterwards 115 | if you want the mean. 116 | - Sends gradients to inputs but not the targets. 
117 | """ 118 | assert input_logits.size() == target_logits.size() 119 | target_softmax = F.softmax(target_logits, dim=1) 120 | #target_softmax = F.softmax(target_logits, dim=1) 121 | 122 | mse_loss = (input_logits-target_softmax.detach())**2 123 | return mse_loss 124 | def softmax_mae_loss_VGE(input_logits, target_logits): 125 | """Takes softmax on both sides and returns MSE loss 126 | Note: 127 | - Returns the sum over all examples. Divide by the batch size afterwards 128 | if you want the mean. 129 | - Sends gradients to inputs but not the targets. 130 | """ 131 | assert input_logits.size() == target_logits.size() 132 | target_softmax = F.softmax(target_logits, dim=1) 133 | #target_softmax = F.softmax(target_logits, dim=1) 134 | 135 | mae_loss = torch.abs(input_logits-target_softmax.detach()) 136 | return mae_loss 137 | def softmax_mae_loss_DiceCE(input_logits, target_logits): 138 | """Takes softmax on both sides and returns MSE loss 139 | Note: 140 | - Returns the sum over all examples. Divide by the batch size afterwards 141 | if you want the mean. 142 | - Sends gradients to inputs but not the targets. 143 | """ 144 | assert input_logits.size() == target_logits.size() 145 | #target_softmax = F.softmax(target_logits, dim=1) 146 | #target_softmax = F.softmax(target_logits, dim=1) 147 | 148 | mae_loss = torch.abs(input_logits-target_logits.detach()) 149 | return mae_loss 150 | def dce_eviloss(p, alpha, c, global_step, annealing_step): 151 | #evidence = F.softplus(prob) 152 | # L_dice = TDice(alpha,p,criterion_dl) 153 | alpha = alpha.view(alpha.size(0), alpha.size(1), -1) # [N, C, HW] 154 | alpha = alpha.transpose(1, 2) # [N, HW, C] 155 | alpha = alpha.contiguous().view(-1, alpha.size(2)) 156 | S = torch.sum(alpha, dim=1, keepdim=True) 157 | E = alpha - 1 158 | label = F.one_hot(p, num_classes=c) 159 | label = label.view(-1, c) 160 | # digama loss 161 | a = torch.tensor(([1,10])) 162 | L_ace = torch.sum(label * (torch.digamma(S) - torch.digamma(alpha))*a.cuda(), dim=1, keepdim=True) 163 | L_ace = torch.mean(L_ace.squeeze()) 164 | #print(L_ace.shape) 165 | # log loss 166 | # labelK = label * (torch.log(S) - torch.log(alpha)) 167 | # L_ace = torch.sum(label * (torch.log(S) - torch.log(alpha)), dim=1, keepdim=True) 168 | 169 | annealing_coef = min(1, global_step / annealing_step)*0.4 170 | alp = E * (1 - label) + 1 171 | L_KL = annealing_coef * KL(alp, c) 172 | L_KL = torch.mean(L_KL.squeeze()) 173 | return L_ace + L_KL 174 | def dce_eviloss_2d(p, alpha, c, global_step, annealing_step): 175 | #evidence = F.softplus(prob) 176 | # L_dice = TDice(alpha,p,criterion_dl) 177 | alpha = alpha.view(alpha.size(0), alpha.size(1), -1) # [N, C, HW] 178 | alpha = alpha.transpose(1, 2) # [N, HW, C] 179 | alpha = alpha.contiguous().view(-1, alpha.size(2)) 180 | S = torch.sum(alpha, dim=1, keepdim=True) 181 | E = alpha - 1 182 | label = F.one_hot(p.long(), num_classes=c) 183 | label = label.view(-1, c) 184 | # digama loss 185 | a = torch.tensor(([1,5,10,5])) 186 | L_ace = torch.sum(label * (torch.digamma(S) - torch.digamma(alpha))*a.cuda(), dim=1, keepdim=True) 187 | L_ace = torch.mean(L_ace.squeeze()) 188 | #print(L_ace.shape) 189 | # log loss 190 | # labelK = label * (torch.log(S) - torch.log(alpha)) 191 | # L_ace = torch.sum(label * (torch.log(S) - torch.log(alpha)), dim=1, keepdim=True) 192 | 193 | annealing_coef = min(1, global_step / annealing_step) 194 | alp = E * (1 - label) + 1 195 | L_KL = annealing_coef * KL(alp, c) 196 | L_KL = torch.mean(L_KL.squeeze()) 197 | return (L_ace + L_KL) 198 | def 
L_KL(label, alpha, c, global_step, annealing_step): 199 | #evidence = F.softplus(prob) 200 | # L_dice = TDice(alpha,p,criterion_dl) 201 | alpha = alpha.view(alpha.size(0), alpha.size(1), -1) # [N, C, HW] 202 | alpha = alpha.transpose(1, 2) # [N, HW, C] 203 | alpha = alpha.contiguous().view(-1, alpha.size(2)) 204 | S = torch.sum(alpha, dim=1, keepdim=True) 205 | E = alpha - 1 206 | label = F.one_hot(label, num_classes=c) 207 | label = label.view(-1, c) 208 | # digama loss 209 | # log loss 210 | # labelK = label * (torch.log(S) - torch.log(alpha)) 211 | # L_ace = torch.sum(label * (torch.log(S) - torch.log(alpha)), dim=1, keepdim=True) 212 | 213 | annealing_coef = min(1, global_step / annealing_step) 214 | alp = E * (1 - label) + 1 215 | L_KL = annealing_coef * KL(alp, c) 216 | L_KL = torch.mean(L_KL.squeeze()) 217 | return L_KL 218 | def unlabelplainloss(unlabelpred, alpha, c): 219 | #evidence = F.softplus(prob) 220 | # L_dice = TDice(alpha,p,criterion_dl) 221 | alpha = alpha.view(alpha.size(0), alpha.size(1), -1) # [N, C, HW] 222 | alpha = alpha.transpose(1, 2) # [N, HW, C] 223 | alpha = alpha.contiguous().view(-1, alpha.size(2)) 224 | S = torch.sum(alpha, dim=1, keepdim=True) 225 | W = 1-c/S 226 | P = alpha/S 227 | unlabelpred = unlabelpred.view(unlabelpred.size(0), unlabelpred.size(1), -1) 228 | unlabelpred = unlabelpred.transpose(1, 2) # [N, HW, C] 229 | unlabelpred = unlabelpred.contiguous().view(-1, unlabelpred.size(2)) 230 | L_con_1 = torch.sum(softmax_mse_loss(unlabelpred,P.detach())*W,dim=1)/2 231 | L_con_1 = torch.mean(L_con_1) 232 | return L_con_1 233 | def unlabelplainmaeloss(unlabelpred, alpha, c): 234 | #evidence = F.softplus(prob) 235 | # L_dice = TDice(alpha,p,criterion_dl) 236 | alpha = alpha.view(alpha.size(0), alpha.size(1), -1) # [N, C, HW] 237 | alpha = alpha.transpose(1, 2) # [N, HW, C] 238 | alpha = alpha.contiguous().view(-1, alpha.size(2)) 239 | S = torch.sum(alpha, dim=1, keepdim=True) 240 | W = 1-c/S 241 | P = alpha/S 242 | unlabelpred = unlabelpred.view(unlabelpred.size(0), unlabelpred.size(1), -1) 243 | unlabelpred = unlabelpred.transpose(1, 2) # [N, HW, C] 244 | unlabelpred = unlabelpred.contiguous().view(-1, unlabelpred.size(2)) 245 | L_con_1 = torch.sum(softmax_mae_loss_EGV(unlabelpred,P.detach())*W,dim=1)/2 246 | L_con_1 = torch.mean(L_con_1) 247 | return L_con_1 248 | def unlabelcrossclassiferloss(unlabelpred, alpha,c): 249 | #evidence = F.softplus(prob) 250 | # L_dice = TDice(alpha,p,criterion_dl) 251 | alpha = alpha.view(alpha.size(0), alpha.size(1), -1) # [N, C, HW] 252 | alpha = alpha.transpose(1, 2) # [N, HW, C] 253 | alpha = alpha.contiguous().view(-1, alpha.size(2)) 254 | S = torch.sum(alpha, dim=1, keepdim=True) 255 | W = 1-c/S 256 | P = alpha/S 257 | unlabelpred = unlabelpred.view(unlabelpred.size(0), unlabelpred.size(1), -1) 258 | unlabelpred = unlabelpred.transpose(1, 2) # [N, HW, C] 259 | unlabelpred = unlabelpred.contiguous().view(-1, unlabelpred.size(2)) 260 | L_con_1 = torch.sum(softmax_mse_loss(unlabelpred,P.detach())*W.detach(),dim=1)/2 261 | L_con_2 = torch.sum(softmax_mse_loss_VGE(P,unlabelpred.detach())*W.detach(),dim=1)/2 262 | L_con = torch.mean(L_con_1+L_con_2) 263 | 264 | return L_con 265 | def unlabelcrossclassifermaeloss(unlabelpred, alpha,c): 266 | #evidence = F.softplus(prob) 267 | # L_dice = TDice(alpha,p,criterion_dl) 268 | alpha = alpha.view(alpha.size(0), alpha.size(1), -1) # [N, C, HW] 269 | alpha = alpha.transpose(1, 2) # [N, HW, C] 270 | alpha = alpha.contiguous().view(-1, alpha.size(2)) 271 | S = torch.sum(alpha, dim=1, 
keepdim=True) 272 | W = 1-c/S 273 | P = alpha/S 274 | unlabelpred = unlabelpred.view(unlabelpred.size(0), unlabelpred.size(1), -1) 275 | unlabelpred = unlabelpred.transpose(1, 2) # [N, HW, C] 276 | unlabelpred = unlabelpred.contiguous().view(-1, unlabelpred.size(2)) 277 | L_con_1 = torch.sum(softmax_mae_loss_EGV(unlabelpred,P.detach())*W.detach(),dim=1)/2 278 | L_con_2 = torch.sum(softmax_mae_loss_VGE(P,unlabelpred.detach())*W.detach(),dim=1)/2 279 | L_con = torch.mean(L_con_1+L_con_2) 280 | return L_con 281 | def no_meanmae_loss(input1, input2): 282 | return(torch.abs(input1 - input2)) 283 | def DiceCEmaeconloss(unlabelpred, alpha,c): 284 | #evidence = F.softplus(prob) 285 | # L_dice = TDice(alpha,p,criterion_dl) 286 | alpha = alpha.view(alpha.size(0), alpha.size(1), -1) # [N, C, HW] 287 | alpha = alpha.transpose(1, 2) # [N, HW, C] 288 | alpha = alpha.contiguous().view(-1, alpha.size(2)) 289 | S1 = torch.sum(alpha, dim=1, keepdim=True) 290 | W1 = 1-c/S1 291 | P1 = alpha/S1 292 | unlabelpred = unlabelpred.view(unlabelpred.size(0), unlabelpred.size(1), -1) 293 | unlabelpred = unlabelpred.transpose(1, 2) # [N, HW, C] 294 | unlabelpred = unlabelpred.contiguous().view(-1, unlabelpred.size(2)) 295 | S2 = torch.sum(unlabelpred, dim=1, keepdim=True) 296 | W2 = 1-c/S2 297 | P2 = unlabelpred/S2 298 | L_con_1 = torch.sum(mse_loss(P2,P1.detach())*W1.detach(),dim=1)/2 299 | L_con_2 = torch.sum(mse_loss(P1,P2.detach())*W2.detach(),dim=1)/2 300 | L_con = torch.mean(L_con_1+L_con_2)#*5 301 | return L_con 302 | def Binary_dice_loss(predictive, target, ep=1e-8): 303 | intersection = 2 * torch.sum(predictive * target) + ep 304 | union = torch.sum(predictive) + torch.sum(target) + ep 305 | loss = 1 - intersection / union 306 | return loss 307 | 308 | def kl_loss(inputs, targets, ep=1e-8): 309 | kl_loss=nn.KLDivLoss(reduction='mean') 310 | consist_loss = kl_loss(torch.log(inputs+ep), targets) 311 | return consist_loss 312 | 313 | def soft_ce_loss(inputs, target, ep=1e-8): 314 | logprobs = torch.log(inputs+ep) 315 | return torch.mean(-(target[:,0,...]*logprobs[:,0,...]+target[:,1,...]*logprobs[:,1,...])) 316 | 317 | def mse_loss(input1, input2): 318 | return torch.mean((input1 - input2)**2) 319 | def mae_loss(input1, input2): 320 | return torch.mean(torch.abs(input1 - input2)) 321 | def get_cut_mask(probs, thres=0.5, nms=0): 322 | masks = (probs >= thres).type(torch.int64) 323 | masks = masks[:, 1, :, :].contiguous() 324 | if nms == 1: 325 | masks = LargestCC_pancreas(masks) 326 | return masks 327 | 328 | 329 | def LargestCC_pancreas(segmentation): 330 | N = segmentation.shape[0] 331 | batch_list = [] 332 | for n in range(N): 333 | n_prob = segmentation[n].detach().cpu().numpy() 334 | labels = label(n_prob) 335 | if labels.max() != 0: 336 | largestCC = labels == np.argmax(np.bincount(labels.flat)[1:]) + 1 337 | else: 338 | largestCC = n_prob 339 | batch_list.append(largestCC) 340 | 341 | return torch.Tensor(batch_list).cuda() 342 | 343 | def one_resize_mse_loss(input1, input2): 344 | input1 = input1.view(input1.size(0), input1.size(1), -1) # [N, C, HW] 345 | input1 = input1.transpose(1, 2) # [N, HW, C] 346 | input1 = input1.contiguous().view(-1, input1.size(2)) 347 | return torch.mean((input1 - input2.detach())**2) 348 | def uncertainty_mse_loss(input1, input2,uncertainty): 349 | return torch.mean(((input1 - input2)**2)*(1-uncertainty.unsqueeze(1).detach())) 350 | def dice_loss(score, target): 351 | target = target.float() 352 | smooth = 1e-5 353 | intersect = torch.sum(score * target) 354 | y_sum = 
torch.sum(target * target) 355 | z_sum = torch.sum(score * score) 356 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 357 | loss = 1 - loss 358 | return loss 359 | 360 | class DiceLoss(nn.Module): 361 | def __init__(self, n_classes): 362 | super(DiceLoss, self).__init__() 363 | self.n_classes = n_classes 364 | 365 | def _one_hot_encoder(self, input_tensor): 366 | tensor_list = [] 367 | for i in range(self.n_classes): 368 | temp_prob = input_tensor == i * torch.ones_like(input_tensor) 369 | tensor_list.append(temp_prob) 370 | output_tensor = torch.cat(tensor_list, dim=1) 371 | return output_tensor.float() 372 | 373 | def _dice_loss(self, score, target): 374 | target = target.float() 375 | smooth = 1e-10 376 | intersection = torch.sum(score * target) 377 | union = torch.sum(score * score) + torch.sum(target * target) + smooth 378 | loss = 1 - intersection / union 379 | return loss 380 | 381 | def forward(self, inputs, target, weight=None, softmax=False): 382 | if softmax: 383 | inputs = torch.softmax(inputs, dim=1) 384 | target = self._one_hot_encoder(target) 385 | if weight is None: 386 | weight = [1] * self.n_classes 387 | assert inputs.size() == target.size(), 'predict & target shape do not match' 388 | class_wise_dice = [] 389 | loss = 0.0 390 | for i in range(0, self.n_classes): 391 | dice = self._dice_loss(inputs[:, i], target[:, i]) 392 | class_wise_dice.append(1.0 - dice.item()) 393 | loss += dice * weight[i] 394 | return loss / self.n_classes 395 | CE = nn.CrossEntropyLoss() 396 | Dice= DiceLoss(2) 397 | def one_resize_CE_loss(input1, input2): 398 | patch_size = input1.size() 399 | input2 = input2.view(patch_size[0],-1,patch_size[1]) 400 | input2 = input2.transpose(1,2) 401 | input2 = input2.reshape(patch_size[0],patch_size[1],patch_size[2],patch_size[3],patch_size[4]) 402 | input2 = get_cut_mask(input2,0.5,1) 403 | return CE(input1, input2.long())#+Dice(input1/torch.sum(input1,dim=1,keepdim=True),input2.unsqueeze(1)) 404 | class DiceLoss_evi(nn.Module): 405 | def __init__(self, n_classes): 406 | super(DiceLoss_evi, self).__init__() 407 | self.n_classes = n_classes 408 | 409 | def _one_hot_encoder(self, input_tensor): 410 | tensor_list = [] 411 | for i in range(self.n_classes): 412 | temp_prob = input_tensor == i * torch.ones_like(input_tensor) 413 | tensor_list.append(temp_prob) 414 | output_tensor = torch.cat(tensor_list, dim=1) 415 | return output_tensor.float() 416 | 417 | def _dice_loss(self, score, target): 418 | target = target.float() 419 | smooth = 1e-10 420 | intersection = torch.sum(score * target) 421 | union = torch.sum(score) + torch.sum(target) + smooth 422 | loss = 1 - intersection / union 423 | return loss 424 | 425 | def forward(self, inputs, target, weight=None, softmax=False): 426 | if softmax: 427 | inputs = torch.softmax(inputs, dim=1) 428 | target = self._one_hot_encoder(target) 429 | if weight is None: 430 | weight = [1] * self.n_classes 431 | #print(inputs.shape,target.shape) 432 | assert inputs.size() == target.size(), 'predict & target shape do not match' 433 | class_wise_dice = [] 434 | loss = 0.0 435 | for i in range(0, self.n_classes): 436 | dice = self._dice_loss(inputs[:, i], target[:, i]) 437 | class_wise_dice.append(1.0 - dice.item()) 438 | loss += dice * weight[i] 439 | return loss / self.n_classes 440 | class evidentDiceLoss(nn.Module): 441 | def __init__(self, n_classes): 442 | super(evidentDiceLoss, self).__init__() 443 | self.n_classes = n_classes 444 | 445 | def _one_hot_encoder(self, input_tensor): 446 | tensor_list = [] 447 | 
for i in range(self.n_classes):
448 |             temp_prob = input_tensor == i * torch.ones_like(input_tensor)
449 |             tensor_list.append(temp_prob.unsqueeze(1))
450 |         output_tensor = torch.cat(tensor_list, dim=1)
451 |
452 |         return output_tensor.float()
453 |
454 |     def _dice_loss(self, alpha, target):
455 |         S = torch.sum(alpha, dim=1, keepdim=True)
456 |         p = alpha/S
457 |         target = target.float()
458 |         smooth = 1e-10
459 |         intersection = torch.sum(p * target)
460 |         sDiceDen = torch.sum(p * p) + torch.sum(target * target) + smooth
461 |         varfenzi = alpha*(S-alpha)   # variance numerator of the Dirichlet mean
462 |         varfenmu = S*S*(S+1)         # variance denominator: S^2 * (S+1)
463 |         var = torch.sum(varfenzi/varfenmu)
464 |         union = sDiceDen + var
465 |         sumi = intersection / union
466 |         return sumi
467 |
468 |     def forward(self, inputs, target, weight=None, softmax=False):
469 |         if softmax:
470 |             inputs = torch.softmax(inputs, dim=1)
471 |         target = self._one_hot_encoder(target)
472 |         if weight is None:
473 |             weight = [1] * self.n_classes
474 |         assert inputs.size() == target.size(), 'predict & target shape do not match'
475 |         #class_wise_dice = []
476 |         loss = 0.0
477 |         for i in range(0, self.n_classes):
478 |             dice = self._dice_loss(inputs[:, i], target[:, i])
479 |             #class_wise_dice.append(dice.item())
480 |             loss += dice * weight[i]
481 |         loss = 1 - (loss / self.n_classes * 2)
482 |         return loss
483 | def softmax_mae_loss(alpha1, alpha2, c):
484 |     """Mutual MAE between the expected probabilities (alpha / S) of two
485 |     Dirichlet outputs.
486 |     Note:
487 |     - Each direction weights the error by the other side's evidence-based
488 |       confidence (1 - c/S) and detaches the opposing prediction.
489 |     """
490 |     S1 = torch.sum(alpha1, dim=1, keepdim=True)
491 |     W1 = 1 - c / S1
492 |     P1 = alpha1 / S1
493 |     S2 = torch.sum(alpha2, dim=1, keepdim=True)
494 |     W2 = 1 - c / S2
495 |     P2 = alpha2 / S2
496 |     loss1 = torch.sum(torch.abs(P1 - P2.detach())*W2) / (alpha1.shape[0] * alpha1.shape[1] * alpha1.shape[2] * alpha1.shape[3] * alpha1.shape[4])
497 |     loss2 = torch.sum(torch.abs(P2 - P1.detach())*W1) / (alpha1.shape[0] * alpha1.shape[1] * alpha1.shape[2] * alpha1.shape[3] * alpha1.shape[4])
498 |     loss = loss1 + loss2
499 |     #input_softmax = F.softmax(input_logits, dim=1)
500 |     #target_softmax = F.softmax(target_logits, dim=1)
501 |     return loss
--------------------------------------------------------------------------------
/code/networks/VNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | class ConvBlock(nn.Module):
6 |     def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'):
7 |         super(ConvBlock, self).__init__()
8 |
9 |         ops = []
10 |         for i in range(n_stages):
11 |             if i == 0:
12 |                 input_channel = n_filters_in
13 |             else:
14 |                 input_channel = n_filters_out
15 |
16 |             ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1))
17 |             if normalization == 'batchnorm':
18 |                 ops.append(nn.BatchNorm3d(n_filters_out))
19 |             elif normalization == 'groupnorm':
20 |                 ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
21 |             elif normalization == 'instancenorm':
22 |                 ops.append(nn.InstanceNorm3d(n_filters_out))
23 |             elif normalization != 'none':
24 |                 assert False
25 |             ops.append(nn.ReLU(inplace=True))
26 |
27 |         self.conv = nn.Sequential(*ops)
28 |
29 |     def forward(self, x):
30 |         x = self.conv(x)
31 |         return x
32 |
33 |
34 | class ResidualConvBlock(nn.Module):
35 |     def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'):
36 | 
ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 28 | elif normalization == 'instancenorm': 29 | ops.append(nn.InstanceNorm2d(n_filters_out)) 30 | elif normalization != 'none': 31 | assert False 32 | ops.append(nn.ReLU(inplace=True)) 33 | 34 | self.conv = nn.Sequential(*ops) 35 | 36 | def forward(self, x): 37 | x = self.conv(x) 38 | return x 39 | class ConvBlock(nn.Module): 40 | """two convolution layers with batch norm and leaky relu""" 41 | def __init__(self, in_channels, out_channels, dropout_p): 42 | super(ConvBlock, self).__init__() 43 | self.conv_conv = nn.Sequential( 44 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 45 | nn.BatchNorm2d(out_channels), 46 | nn.LeakyReLU(), 47 | nn.Dropout(dropout_p), 48 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), 49 | nn.BatchNorm2d(out_channels), 50 | nn.LeakyReLU() 51 | ) 52 | 53 | def forward(self, x): 54 | return self.conv_conv(x) 55 | 56 | class DownBlock(nn.Module): 57 | """Downsampling followed by ConvBlock""" 58 | def __init__(self, in_channels, out_channels, dropout_p): 59 | super(DownBlock, self).__init__() 60 | self.maxpool_conv = nn.Sequential( 61 | nn.MaxPool2d(2), 62 | ConvBlock(in_channels, out_channels, dropout_p) 63 | ) 64 | 65 | def forward(self, x): 66 | return self.maxpool_conv(x) 67 | 68 | 69 | class UpBlock(nn.Module): 70 | """Upssampling followed by ConvBlock""" 71 | def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, mode_upsampling=1): 72 | super(UpBlock, self).__init__() 73 | self.mode_upsampling = mode_upsampling 74 | if mode_upsampling==0: 75 | self.up = nn.ConvTranspose2d(in_channels1, in_channels2, kernel_size=2, stride=2) 76 | elif mode_upsampling==1: 77 | self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size=1) 78 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 79 | elif mode_upsampling==2: 80 | self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size=1) 81 | self.up = nn.Upsample(scale_factor=2, mode='nearest') 82 | elif mode_upsampling==3: 83 | self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size=1) 84 | self.up = nn.Upsample(scale_factor=2, mode='bicubic', align_corners=True) 85 | self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) 86 | 87 | def forward(self, x1, x2): 88 | if self.mode_upsampling != 0: 89 | x1 = self.conv1x1(x1) 90 | x1 = self.up(x1) 91 | x = torch.cat([x2, x1], dim=1) 92 | x = self.conv(x) 93 | return x 94 | 95 | 96 | class Encoder(nn.Module): 97 | def __init__(self, params): 98 | super(Encoder, self).__init__() 99 | self.params = params 100 | self.in_chns = self.params['in_chns'] 101 | self.ft_chns = self.params['feature_chns'] 102 | self.n_class = self.params['class_num'] 103 | self.dropout = self.params['dropout'] 104 | assert (len(self.ft_chns) == 5) 105 | self.in_conv = ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) 106 | self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) 107 | self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) 108 | self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) 109 | self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) 110 | 111 | def forward(self, x): 112 | x0 = self.in_conv(x) 113 | x1 = self.down1(x0) 114 | x2 = self.down2(x1) 115 | x3 = self.down3(x2) 116 | x4 = self.down4(x3) 117 | return [x0, x1, x2, x3, x4] 118 | class DecoderCCT(nn.Module): 119 | def __init__(self, params): 120 | super(DecoderCCT, self).__init__() 
121 | self.params = params 122 | self.in_chns = self.params['in_chns'] 123 | self.ft_chns = self.params['feature_chns'] 124 | self.n_class = self.params['class_num'] 125 | self.bilinear = self.params['bilinear'] 126 | assert (len(self.ft_chns) == 5) 127 | 128 | self.up1 = UpBlock( 129 | self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0) 130 | self.up2 = UpBlock( 131 | self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0) 132 | self.up3 = UpBlock( 133 | self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0) 134 | self.up4 = UpBlock( 135 | self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0) 136 | 137 | self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, 138 | kernel_size=3, padding=1) 139 | 140 | def forward(self, feature): 141 | x0 = feature[0] 142 | x1 = feature[1] 143 | x2 = feature[2] 144 | x3 = feature[3] 145 | x4 = feature[4] 146 | 147 | x = self.up1(x4, x3) 148 | x = self.up2(x, x2) 149 | x = self.up3(x, x1) 150 | x = self.up4(x, x0) 151 | output = self.out_conv(x) 152 | return output 153 | 154 | class Decoder(nn.Module): 155 | def __init__(self, params): 156 | super(Decoder, self).__init__() 157 | self.params = params 158 | self.in_chns = self.params['in_chns'] 159 | self.ft_chns = self.params['feature_chns'] 160 | self.n_class = self.params['class_num'] 161 | self.up_type = self.params['up_type'] 162 | assert (len(self.ft_chns) == 5) 163 | 164 | self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0, mode_upsampling=self.up_type) 165 | self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0, mode_upsampling=self.up_type) 166 | self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0, mode_upsampling=self.up_type) 167 | self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0, mode_upsampling=self.up_type) 168 | 169 | self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size=3, padding=1) 170 | 171 | def forward(self, feature): 172 | x0 = feature[0] 173 | x1 = feature[1] 174 | x2 = feature[2] 175 | x3 = feature[3] 176 | x4 = feature[4] 177 | 178 | x = self.up1(x4, x3) 179 | x = self.up2(x, x2) 180 | x = self.up3(x, x1) 181 | x = self.up4(x, x0) 182 | output = self.out_conv(x) 183 | return output 184 | class Decoder_sdf(nn.Module): 185 | def __init__(self, params): 186 | super(Decoder_sdf, self).__init__() 187 | self.params = params 188 | self.in_chns = self.params['in_chns'] 189 | self.ft_chns = self.params['feature_chns'] 190 | self.n_class = self.params['class_num'] 191 | self.up_type = self.params['up_type'] 192 | assert (len(self.ft_chns) == 5) 193 | 194 | self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0, mode_upsampling=self.up_type) 195 | self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0, mode_upsampling=self.up_type) 196 | self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0, mode_upsampling=self.up_type) 197 | self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0, mode_upsampling=self.up_type) 198 | self.out_conv2 = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size=3, padding=1) 199 | self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size=3, padding=1) 200 | self.tanh = nn.Tanh() 201 | self.dropout = nn.Dropout2d(p=0.5, inplace=False) 202 | 203 | def forward(self, feature): 204 | x0 = feature[0] 205 | x1 = feature[1] 206 | x2 = feature[2] 207 | x3 = 
feature[3] 208 | x4 = feature[4] 209 | 210 | x = self.up1(x4, x3) 211 | x = self.up2(x, x2) 212 | x = self.up3(x, x1) 213 | x = self.up4(x, x0) 214 | output = self.out_conv(x) 215 | outputsdf = self.out_conv2(x) 216 | out_tanh = self.tanh(outputsdf) 217 | return out_tanh, output 218 | class Decoder_URPC(nn.Module): 219 | def __init__(self, params): 220 | super(Decoder_URPC, self).__init__() 221 | self.params = params 222 | self.in_chns = self.params['in_chns'] 223 | self.ft_chns = self.params['feature_chns'] 224 | self.n_class = self.params['class_num'] 225 | self.bilinear = self.params['bilinear'] 226 | assert (len(self.ft_chns) == 5) 227 | 228 | self.up1 = UpBlock( 229 | self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0) 230 | self.up2 = UpBlock( 231 | self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0) 232 | self.up3 = UpBlock( 233 | self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0) 234 | self.up4 = UpBlock( 235 | self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0) 236 | 237 | self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, 238 | kernel_size=3, padding=1) 239 | self.out_conv_dp4 = nn.Conv2d(self.ft_chns[4], self.n_class, 240 | kernel_size=3, padding=1) 241 | self.out_conv_dp3 = nn.Conv2d(self.ft_chns[3], self.n_class, 242 | kernel_size=3, padding=1) 243 | self.out_conv_dp2 = nn.Conv2d(self.ft_chns[2], self.n_class, 244 | kernel_size=3, padding=1) 245 | self.out_conv_dp1 = nn.Conv2d(self.ft_chns[1], self.n_class, 246 | kernel_size=3, padding=1) 247 | self.feature_noise = FeatureNoise() 248 | 249 | def forward(self, feature, shape): 250 | x0 = feature[0] 251 | x1 = feature[1] 252 | x2 = feature[2] 253 | x3 = feature[3] 254 | x4 = feature[4] 255 | x = self.up1(x4, x3) 256 | if self.training: 257 | dp3_out_seg = self.out_conv_dp3(Dropout(x, p=0.5)) 258 | else: 259 | dp3_out_seg = self.out_conv_dp3(x) 260 | dp3_out_seg = torch.nn.functional.interpolate(dp3_out_seg, shape) 261 | 262 | x = self.up2(x, x2) 263 | if self.training: 264 | dp2_out_seg = self.out_conv_dp2(FeatureDropout(x)) 265 | else: 266 | dp2_out_seg = self.out_conv_dp2(x) 267 | dp2_out_seg = torch.nn.functional.interpolate(dp2_out_seg, shape) 268 | 269 | x = self.up3(x, x1) 270 | if self.training: 271 | dp1_out_seg = self.out_conv_dp1(self.feature_noise(x)) 272 | else: 273 | dp1_out_seg = self.out_conv_dp1(x) 274 | dp1_out_seg = torch.nn.functional.interpolate(dp1_out_seg, shape) 275 | 276 | x = self.up4(x, x0) 277 | dp0_out_seg = self.out_conv(x) 278 | return dp0_out_seg, dp1_out_seg, dp2_out_seg, dp3_out_seg 279 | 280 | def masked_average_pooling(feature, mask): 281 | #print(feature.shape[-2:]) 282 | mask = F.interpolate(mask, size=feature.shape[-2:], mode='bilinear', align_corners=True) 283 | #print((feature*mask).shape) 284 | masked_feature = torch.sum(feature * mask, dim=(2, 3)) \ 285 | / (mask.sum(dim=(2, 3)) + 1e-5) 286 | return masked_feature 287 | 288 | def batch_prototype(feature,mask): #return B*C*feature_size 289 | batch_pro = torch.zeros(mask.shape[0], mask.shape[1], feature.shape[1]) 290 | for i in range(mask.shape[1]): 291 | classmask = mask[:,i,:,:] 292 | proclass = masked_average_pooling(feature,classmask.unsqueeze(1)) 293 | batch_pro[:,i,:] = proclass 294 | return batch_pro 295 | def entropy_value(p, C): 296 | # p N*C*W*H*D 297 | y1 = -1*torch.sum(p*torch.log(p+1e-6), dim=2) / \ 298 | torch.tensor(np.log(C))#.cuda() 299 | return y1 300 | def agreementmap(similarity_map): 301 | score_map = torch.argmax(similarity_map,dim=3) 302 | #score_map 
=score_map.transpose(1,2) 303 | ##print(score_map.shape, 'score',score_map[0,0,:]) 304 | gt_onthot = F.one_hot(score_map,6) 305 | avg_onehot = torch.sum(gt_onthot,dim=2).float() 306 | avg_onehot = F.normalize(avg_onehot,1.0,dim=2) 307 | ##print(gt_onthot[0,0,:,:],avg_onehot[0,0,:]) 308 | weight = 1-entropy_value(avg_onehot,similarity_map.shape[3]) 309 | ##print(weight[0,0]) 310 | #score_map = torch.sum(score_map,dim=2) 311 | return weight 312 | def similarity_calulation(feature,batchpro): #feature_size = B*C*H*W batchpro= B*C*dim 313 | B = feature.size(0) 314 | feature = feature.view(feature.size(0), feature.size(1), -1) # [N, C, HW] 315 | feature = feature.transpose(1, 2) # [N, HW, C] 316 | feature = feature.contiguous().view(-1, feature.size(2)) 317 | C = batchpro.size(1) 318 | batchpro = batchpro.contiguous().view(-1, batchpro.size(2)) 319 | feature = F.normalize(feature, p=2.0, dim=1) 320 | batchpro = F.normalize(batchpro, p=2.0, dim=1).cuda() 321 | similarity = torch.mm(feature, batchpro.T) 322 | similarity = similarity.reshape(-1, B, C) 323 | similarity = similarity.reshape(B, -1, B, C) 324 | return similarity 325 | def selfsimilaritygen(similarity): 326 | B = similarity.shape[0] 327 | mapsize = similarity.shape[1] 328 | C = similarity.shape[3] 329 | selfsimilarity = torch.zeros(B,mapsize,C) 330 | for i in range(similarity.shape[2]): 331 | selfsimilarity[i,:,:] = similarity[i,:,i,:] 332 | return selfsimilarity.cuda() 333 | def othersimilaritygen(similarity): 334 | similarity = torch.exp(similarity) 335 | for i in range(similarity.shape[2]): 336 | similarity[i,:,i,:] =0 337 | similaritysum = torch.sum(similarity,dim=2) 338 | similaritysum_union = torch.sum(similaritysum,dim=2).unsqueeze(-1) 339 | #print(similaritysum_union.shape) 340 | othersimilarity = similaritysum/similaritysum_union 341 | #print(othersimilarity[1,1,:].sum()) 342 | return othersimilarity 343 | def Dropout(x, p=0.3): 344 | x = torch.nn.functional.dropout(x, p) 345 | return x 346 | class Decoder_pro(nn.Module): 347 | def __init__(self, params): 348 | super(Decoder_pro, self).__init__() 349 | self.params = params 350 | self.in_chns = self.params['in_chns'] 351 | self.ft_chns = self.params['feature_chns'] 352 | self.n_class = self.params['class_num'] 353 | self.up_type = self.params['up_type'] 354 | assert (len(self.ft_chns) == 5) 355 | 356 | self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0, mode_upsampling=self.up_type) 357 | self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0, mode_upsampling=self.up_type) 358 | self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0, mode_upsampling=self.up_type) 359 | self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0, mode_upsampling=self.up_type) 360 | 361 | self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size=3, padding=1) 362 | 363 | def forward(self, feature): 364 | x0 = feature[0] 365 | x1 = feature[1] 366 | x2 = feature[2] 367 | x3 = feature[3] 368 | x4 = feature[4] 369 | 370 | x = self.up1(x4, x3) 371 | x = self.up2(x, x2) 372 | x = self.up3(x, x1) 373 | x = self.up4(x, x0) 374 | print(x.shape,'feature_shape') 375 | output = self.out_conv(x) 376 | mask = torch.softmax(output,dim=1) 377 | batch_pro = batch_prototype(x,mask) 378 | similarity_map = similarity_calulation(x,batch_pro) 379 | entropy_weight = agreementmap(similarity_map) 380 | self_simi_map = selfsimilaritygen(similarity_map) #B*HW*C 381 | other_simi_map = 
othersimilaritygen(similarity_map)#B*HW*C 382 | return output, self_simi_map, other_simi_map, entropy_weight 383 | class DS1(nn.Module): 384 | def __init__(self, n_prototypes, n_feature_maps): 385 | super(DS1, self).__init__() 386 | self.n_prototypes = n_prototypes 387 | self.w = torch.nn.Linear(in_features=n_feature_maps, out_features=n_prototypes, bias=False).weight 388 | def forward(self, inputs): 389 | inputs = inputs.transpose(1, 2) 390 | inputs = inputs.transpose(2, 3) 391 | for i in range(self.n_prototypes): 392 | if i == 0: 393 | un_mass_i = (self.w[i, :] - inputs) ** 2 394 | un_mass_i = torch.sum(un_mass_i, dim=-1, keepdim=True) 395 | print(un_mass_i.shape) 396 | un_mass = un_mass_i 397 | #un_mass = torch.unsqueeze(un_mass_i, -1) 398 | if i >= 1: 399 | un_mass_i = (self.w[i, :] - inputs) ** 2 400 | un_mass_i = torch.sum(un_mass_i, dim=-1, keepdim=True) 401 | un_mass = torch.cat([un_mass, un_mass_i], -1) 402 | return un_mass 403 | class DistanceActivation_layer(torch.nn.Module): 404 | ''' 405 | verified 406 | ''' 407 | def __init__(self, n_prototypes,init_alpha=0,init_gamma=0.1): 408 | super(DistanceActivation_layer, self).__init__() 409 | self.eta = torch.nn.Linear(in_features=n_prototypes, out_features=1, bias=False)#.weight.data.fill_(torch.from_numpy(np.array(init_gamma)).to(device)) 410 | self.xi = torch.nn.Linear(in_features=n_prototypes, out_features=1, bias=False)#.weight.data.fill_(torch.from_numpy(np.array(init_alpha)).to(device)) 411 | #torch.nn.init.kaiming_uniform_(self.eta.weight) 412 | #torch.nn.init.kaiming_uniform_(self.xi.weight) 413 | torch.nn.init.constant_(self.eta.weight,init_gamma) 414 | torch.nn.init.constant_(self.xi.weight,init_alpha) 415 | #self.alpha_test = 1/(torch.exp(-self.xi.weight)+1) 416 | self.n_prototypes = n_prototypes 417 | self.alpha = None 418 | 419 | def forward(self, inputs): 420 | gamma=torch.square(self.eta.weight) 421 | alpha=torch.neg(self.xi.weight) 422 | alpha=torch.exp(alpha)+1 423 | alpha=torch.div(1, alpha) 424 | self.alpha=alpha 425 | si=torch.mul(gamma, inputs) 426 | si=torch.neg(si) 427 | si=torch.exp(si) 428 | si = torch.mul(si, alpha) 429 | max_val, max_idx = torch.max(si, dim=-1, keepdim=True) 430 | si /= (max_val + 0.0001) 431 | #print(si.shape,'si') 432 | 433 | return si,alpha 434 | class Belief_layer(torch.nn.Module): 435 | ''' 436 | verified 437 | ''' 438 | def __init__(self, n_prototypes, num_class): 439 | super(Belief_layer, self).__init__() 440 | self.beta = torch.nn.Linear(in_features=n_prototypes, out_features=num_class, bias=False).weight 441 | self.num_class = num_class 442 | def forward(self, inputs): 443 | beta = torch.square(self.beta) 444 | beta_sum = torch.sum(beta, dim=0, keepdim=True) 445 | u = torch.div(beta, beta_sum) 446 | #print(inputs.shape,u.shape) 447 | mass_prototype = torch.einsum('cp,b...p->b...pc',u, inputs) 448 | return mass_prototype 449 | class Omega_layer(torch.nn.Module): 450 | ''' 451 | verified, give same results 452 | ''' 453 | def __init__(self, n_prototypes, num_class): 454 | super(Omega_layer, self).__init__() 455 | self.n_prototypes = n_prototypes 456 | self.num_class = num_class 457 | 458 | def forward(self, inputs): 459 | mass_omega_sum = 1 - torch.sum(inputs, -1, keepdim=True) 460 | #mass_omega_sum = 1. 
- mass_omega_sum[..., 0] 461 | #mass_omega_sum = torch.unsqueeze(mass_omega_sum, -1) 462 | mass_with_omega = torch.cat([inputs, mass_omega_sum], -1) 463 | return mass_with_omega 464 | class Dempster_layer(torch.nn.Module): 465 | ''' 466 | verified give same results 467 | ''' 468 | def __init__(self, n_prototypes, num_class): 469 | super(Dempster_layer, self).__init__() 470 | self.n_prototypes = n_prototypes 471 | self.num_class = num_class 472 | 473 | def forward(self, inputs): 474 | m1 = inputs[..., 0, :] 475 | omega1 = torch.unsqueeze(inputs[..., 0, -1], -1) 476 | for i in range(self.n_prototypes - 1): 477 | m2 = inputs[..., (i + 1), :] 478 | omega2 = torch.unsqueeze(inputs[..., (i + 1), -1], -1) 479 | combine1 = torch.mul(m1, m2) 480 | combine2 = torch.mul(m1, omega2) 481 | combine3 = torch.mul(omega1, m2) 482 | combine1_2 = combine1 + combine2 483 | combine2_3 = combine1_2 + combine3 484 | combine2_3 = combine2_3 / torch.sum(combine2_3, dim=-1, keepdim=True) 485 | m1 = combine2_3 486 | omega1 = torch.unsqueeze(combine2_3[..., -1], -1) 487 | return m1 488 | class DempsterNormalize_layer(torch.nn.Module): 489 | ''' 490 | verified 491 | ''' 492 | def __init__(self): 493 | super(DempsterNormalize_layer, self).__init__() 494 | def forward(self, inputs): 495 | mass_combine_normalize = inputs / torch.sum(inputs, dim=-1, keepdim=True) 496 | return mass_combine_normalize 497 | 498 | class DS1_activate(nn.Module): 499 | def __init__(self, input_dim): 500 | super(DS1_activate, self).__init__() 501 | self.eta = torch.nn.Linear(input_dim, 1, bias=False)#.weight.data.fill_(torch.from_numpy(np.array(init_gamma)).to(device)) 502 | self.xi = torch.nn.Linear(input_dim, 1, bias=False)#.weight.data.fill_(torch.from_numpy(np.array(init_alpha)).to(device)) 503 | #torch.nn.init.kaiming_uniform_(self.eta.weight) 504 | #torch.nn.init.kaiming_uniform_(self.xi.weight) 505 | torch.nn.init.constant_(self.eta.weight,0.1) 506 | torch.nn.init.constant_(self.xi.weight,0) 507 | #self.xi = nn.Parameter(torch.randn(1, input_dim), requires_grad=True) 508 | #self.eta = nn.Parameter(torch.randn(1, input_dim), requires_grad=True) 509 | self.input_dim = input_dim 510 | 511 | def forward(self, inputs): 512 | gamma = torch.square(self.eta.weight) 513 | alpha = -self.xi.weight 514 | alpha = torch.exp(alpha) + 1 515 | alpha = torch.reciprocal(alpha) 516 | si = gamma * inputs 517 | si = -si 518 | si = torch.exp(si) 519 | si = si * alpha 520 | # si = si / (torch.max(si, dim=-1, keepdim=True)[0] + 0.0001) 521 | return si,alpha 522 | 523 | 524 | 525 | class DS2(nn.Module): 526 | def __init__(self, input_dim, num_class): 527 | super(DS2, self).__init__() 528 | #self.beta = nn.Parameter(torch.randn(input_dim, num_class),requires_grad=True) 529 | self.beta = torch.nn.Linear(num_class, input_dim, bias=False).weight 530 | self.input_dim = input_dim 531 | self.num_class = num_class 532 | 533 | def forward(self, inputs): 534 | beta = torch.square(self.beta) 535 | beta_sum = torch.sum(beta, dim=1, keepdim=True) 536 | u = beta/ beta_sum ##class probability 537 | print(u.shape,'uuuu',u.max(dim=1)[1]) 538 | inputs_new = torch.unsqueeze(inputs, -1) 539 | #print(inputs_new.shape) 540 | a = inputs_new.expand(-1, -1, -1, -1,self.num_class-1) 541 | #print(inputs_new.shape,a.shape) 542 | inputs_new = torch.cat([a, inputs_new], dim=-1) 543 | mass_prototype = None 544 | for i in range(self.input_dim): 545 | mass_prototype_i = torch.mul(u[i, :], inputs_new[:, :, :, i, :]) 546 | mass_prototype_i = torch.unsqueeze(mass_prototype_i, -2) 547 | if 
mass_prototype is None: 548 | mass_prototype = mass_prototype_i 549 | else: 550 | mass_prototype = torch.cat([mass_prototype, mass_prototype_i], dim=-2) 551 | return mass_prototype 552 | class DS2_omega(nn.Module): 553 | def __init__(self, input_dim, num_class): 554 | super(DS2_omega, self).__init__() 555 | self.input_dim = input_dim 556 | self.num_class = num_class 557 | def forward(self, inputs): 558 | mass_omega_sum = torch.sum(inputs, -1, keepdim=True) 559 | #print(mass_omega_sum.min(),mass_omega_sum.max()) 560 | mass_omega_sum = 1-mass_omega_sum[:, :, :, :, 0] 561 | mass_omega_sum = torch.unsqueeze(mass_omega_sum, -1) 562 | mass_with_omega = torch.cat([inputs, mass_omega_sum], -1) 563 | return mass_with_omega 564 | 565 | 566 | class DS3_Dempster(nn.Module): 567 | def __init__(self, input_dim, num_class): 568 | super(DS3_Dempster, self).__init__() 569 | self.input_dim = input_dim 570 | self.num_class = num_class 571 | 572 | def forward(self, inputs): 573 | m1 = inputs[:, :, :, 0, :] 574 | omega1 = torch.unsqueeze(inputs[:, :, :, 0, -1], -1) 575 | 576 | for i in range(self.input_dim - 1): 577 | m2 = inputs[:, :, :, i + 1, :] 578 | omega2 = torch.unsqueeze(inputs[:, :, :, i + 1, -1], -1) 579 | 580 | combine1 = torch.mul(m1, m2) 581 | combine2 = torch.mul(m1, omega2) 582 | combine3 = torch.mul(omega1, m2) 583 | combine1_2 = torch.add(combine1, combine2) 584 | combine2_3 = torch.add(combine1_2, combine3) 585 | m1 = combine2_3[:, :, :, :-1] 586 | omega1 = torch.mul(omega1, omega2) 587 | #omega1 = 1 - torch.sum(m1,dim=-1,keepdim=True) 588 | print(omega1.max(),'omega',omega1.min()) 589 | m1 = torch.cat([m1, omega1], -1) 590 | return m1 591 | class DM_pignistic(nn.Module): 592 | def __init__(self, num_class): 593 | super(DM_pignistic, self).__init__() 594 | self.num_class = num_class 595 | 596 | def forward(self, inputs): 597 | aveage_Pignistic = torch.div(inputs[:, :, :, -1], self.num_class) 598 | aveage_Pignistic = torch.unsqueeze(aveage_Pignistic, -1) 599 | mass_class = inputs[:, :, :, 0:-1] 600 | Pignistic_prob = torch.add(mass_class, aveage_Pignistic) 601 | 602 | return Pignistic_prob,inputs[:, :, :, -1] 603 | class DS3_normalize(nn.Module): 604 | def __init__(self): 605 | super(DS3_normalize, self).__init__() 606 | 607 | def forward(self, inputs): 608 | mass_combine_normalize = inputs / torch.sum(inputs, dim=-1, keepdim=True) 609 | print(mass_combine_normalize[0,1,1,:]) 610 | return mass_combine_normalize 611 | class Decoder_BF(nn.Module): 612 | def __init__(self, params): 613 | super(Decoder_BF, self).__init__() 614 | self.params = params 615 | self.in_chns = self.params['in_chns'] 616 | self.ft_chns = self.params['feature_chns'] 617 | self.n_class = self.params['class_num'] 618 | self.up_type = self.params['up_type'] 619 | self.pro_num = self.params['pro_num'] 620 | assert (len(self.ft_chns) == 5) 621 | 622 | self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0, mode_upsampling=self.up_type) 623 | self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0, mode_upsampling=self.up_type) 624 | self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0, mode_upsampling=self.up_type) 625 | self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0, mode_upsampling=self.up_type) 626 | 627 | #self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size=3, padding=1) 628 | self.DS1 = DS1(self.pro_num, self.ft_chns[0]) 629 | self.DS1_1 = DS1_activate(self.pro_num) 630 | self.DS2 = 
DS2(self.pro_num,self.n_class) 631 | self.DS2_1 = DS2_omega(self.pro_num,self.n_class) 632 | self.DS3 = DS3_Dempster(self.pro_num,self.n_class) 633 | self.DS3_1 = DM_pignistic(self.n_class) 634 | self.DS_out = DS3_normalize() 635 | def forward(self, feature): 636 | x0 = feature[0] 637 | x1 = feature[1] 638 | x2 = feature[2] 639 | x3 = feature[3] 640 | x4 = feature[4] 641 | 642 | x = self.up1(x4, x3) 643 | x = self.up2(x, x2) 644 | x = self.up3(x, x1) 645 | x = self.up4(x, x0) 646 | x = self.DS1(x) 647 | x, alpha = self.DS1_1(x) 648 | x = self.DS2(x) 649 | x = self.DS2_1(x) 650 | x = self.DS3(x) 651 | x,uncertainty = self.DS3_1(x) 652 | x = self.DS_out(x) 653 | ##print(uncertainty.shape,'feature_shape') 654 | return x, alpha, uncertainty 655 | class UNet_pro(nn.Module): 656 | def __init__(self, in_chns, class_num): 657 | super(UNet_pro, self).__init__() 658 | 659 | params1 = {'in_chns': in_chns, 660 | 'feature_chns': [32, 32, 64, 128, 256], 661 | 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], 662 | 'class_num': class_num, 663 | 'up_type': 1, 664 | 'acti_func': 'relu'} 665 | 666 | self.encoder = Encoder(params1) 667 | self.decoder1 = Decoder_pro(params1) 668 | 669 | def forward(self, x): 670 | feature = self.encoder(x) 671 | output1,self_simi_map, other_simi_map,entropy_weight = self.decoder1(feature) 672 | return output1,self_simi_map, other_simi_map,entropy_weight 673 | def FeatureDropout(x): 674 | attention = torch.mean(x, dim=1, keepdim=True) 675 | max_val, _ = torch.max(attention.view( 676 | x.size(0), -1), dim=1, keepdim=True) 677 | threshold = max_val * np.random.uniform(0.7, 0.9) 678 | threshold = threshold.view(x.size(0), 1, 1, 1).expand_as(attention) 679 | drop_mask = (attention < threshold).float() 680 | x = x.mul(drop_mask) 681 | return x 682 | 683 | 684 | class FeatureNoise(nn.Module): 685 | def __init__(self, uniform_range=0.3): 686 | super(FeatureNoise, self).__init__() 687 | self.uni_dist = Uniform(-uniform_range, uniform_range) 688 | 689 | def feature_based_noise(self, x): 690 | noise_vector = self.uni_dist.sample( 691 | x.shape[1:]).to(x.device).unsqueeze(0) 692 | x_noise = x.mul(noise_vector) + x 693 | return x_noise 694 | 695 | def forward(self, x): 696 | x = self.feature_based_noise(x) 697 | return x 698 | class UNet_CCT(nn.Module): 699 | def __init__(self, in_chns, class_num): 700 | super(UNet_CCT, self).__init__() 701 | 702 | params = {'in_chns': in_chns, 703 | 'feature_chns': [16, 32, 64, 128, 256], 704 | 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], 705 | 'class_num': class_num, 706 | 'bilinear': False, 707 | 'acti_func': 'relu'} 708 | self.encoder = Encoder(params) 709 | self.main_decoder = DecoderCCT(params) 710 | self.aux_decoder1 = DecoderCCT(params) 711 | self.aux_decoder2 = DecoderCCT(params) 712 | self.aux_decoder3 = DecoderCCT(params) 713 | 714 | def forward(self, x): 715 | feature = self.encoder(x) 716 | main_seg = self.main_decoder(feature) 717 | aux1_feature = [FeatureNoise()(i) for i in feature] 718 | aux_seg1 = self.aux_decoder1(aux1_feature) 719 | aux2_feature = [Dropout(i) for i in feature] 720 | aux_seg2 = self.aux_decoder2(aux2_feature) 721 | aux3_feature = [FeatureDropout(i) for i in feature] 722 | aux_seg3 = self.aux_decoder3(aux3_feature) 723 | return main_seg, aux_seg1, aux_seg2, aux_seg3 724 | 725 | 726 | class UNet_URPC(nn.Module): 727 | def __init__(self, in_chns, class_num): 728 | super(UNet_URPC, self).__init__() 729 | 730 | params = {'in_chns': in_chns, 731 | 'feature_chns': [16, 32, 64, 128, 256], 732 | 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], 733 | 
'class_num': class_num, 734 | 'bilinear': False, 735 | 'acti_func': 'relu'} 736 | self.encoder = Encoder(params) 737 | self.decoder = Decoder_URPC(params) 738 | 739 | def forward(self, x): 740 | shape = x.shape[2:] 741 | feature = self.encoder(x) 742 | dp1_out_seg, dp2_out_seg, dp3_out_seg, dp4_out_seg = self.decoder( 743 | feature, shape) 744 | return dp1_out_seg, dp2_out_seg, dp3_out_seg, dp4_out_seg 745 | class UNet(nn.Module): 746 | def __init__(self, in_chns, class_num): 747 | super(UNet, self).__init__() 748 | 749 | params1 = {'in_chns': in_chns, 750 | 'feature_chns': [16, 32, 64, 128, 256], 751 | 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], 752 | 'class_num': class_num, 753 | 'up_type': 1, 754 | 'acti_func': 'relu'} 755 | 756 | self.encoder = Encoder(params1) 757 | self.decoder1 = Decoder(params1) 758 | 759 | def forward(self, x): 760 | feature = self.encoder(x) 761 | output1 = self.decoder1(feature) 762 | return output1 763 | class UNet_sdf(nn.Module): 764 | def __init__(self, in_chns, class_num): 765 | super(UNet_sdf, self).__init__() 766 | 767 | params1 = {'in_chns': in_chns, 768 | 'feature_chns': [16, 32, 64, 128, 256], 769 | 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], 770 | 'class_num': class_num, 771 | 'up_type': 1, 772 | 'acti_func': 'relu'} 773 | 774 | self.encoder = Encoder(params1) 775 | self.decoder1 = Decoder_sdf(params1) 776 | 777 | def forward(self, x): 778 | feature = self.encoder(x) 779 | outputsdf,output = self.decoder1(feature) 780 | return outputsdf,output 781 | class BFDCNet2d_v1(nn.Module): 782 | def __init__(self, in_chns, class_num): 783 | super(BFDCNet2d_v1, self).__init__() 784 | 785 | params1 = {'in_chns': in_chns, 786 | 'feature_chns': [16, 32, 64, 128, 256], 787 | 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], 788 | 'class_num': class_num, 789 | 'up_type': 1, 790 | 'acti_func': 'relu'} 791 | params2 = {'in_chns': in_chns, 792 | 'feature_chns': [16, 32, 64, 128, 256], 793 | 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], 794 | 'class_num': class_num, 795 | 'up_type': 0, 796 | 'acti_func': 'relu', 797 | 'pro_num':16} 798 | self.encoder = Encoder(params1) 799 | self.decoder1 = Decoder(params1) 800 | self.decoder2 = Decoder_BF(params2) 801 | 802 | def forward(self, x): 803 | feature = self.encoder(x) 804 | output1 = self.decoder1(feature) 805 | output2,alpha,uncertainty = self.decoder2(feature) 806 | output2 = output2.transpose(3, 2) 807 | output2 = output2.transpose(2,1) 808 | #uncertainty = uncertainty.transpose(3, 2) 809 | #uncertainty = uncertainty.transpose(2,1) 810 | print(uncertainty.max(),output2.shape) 811 | alpha = torch.mean(torch.abs(alpha)) 812 | uncertaintyavg = torch.mean(uncertainty*uncertainty) 813 | return output1, output2,alpha,uncertainty,uncertaintyavg 814 | class MCNet2d_v1(nn.Module): 815 | def __init__(self, in_chns, class_num): 816 | super(MCNet2d_v1, self).__init__() 817 | 818 | params1 = {'in_chns': in_chns, 819 | 'feature_chns': [16, 32, 64, 128, 256], 820 | 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], 821 | 'class_num': class_num, 822 | 'up_type': 1, 823 | 'acti_func': 'relu'} 824 | params2 = {'in_chns': in_chns, 825 | 'feature_chns': [16, 32, 64, 128, 256], 826 | 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], 827 | 'class_num': class_num, 828 | 'up_type': 0, 829 | 'acti_func': 'relu'} 830 | self.encoder = Encoder(params1) 831 | self.decoder1 = Decoder(params1) 832 | self.decoder2 = Decoder(params2) 833 | 834 | def forward(self, x): 835 | feature = self.encoder(x) 836 | output1 = self.decoder1(feature) 837 | output2 = self.decoder2(feature) 838 | return output1, output2 
839 | 840 | class MCNet2d_v2(nn.Module): 841 | def __init__(self, in_chns, class_num): 842 | super(MCNet2d_v2, self).__init__() 843 | 844 | params1 = {'in_chns': in_chns, 845 | 'feature_chns': [16, 32, 64, 128, 256], 846 | 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], 847 | 'class_num': class_num, 848 | 'up_type': 1, 849 | 'acti_func': 'relu'} 850 | params2 = {'in_chns': in_chns, 851 | 'feature_chns': [16, 32, 64, 128, 256], 852 | 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], 853 | 'class_num': class_num, 854 | 'up_type': 0, 855 | 'acti_func': 'relu'} 856 | params3 = {'in_chns': in_chns, 857 | 'feature_chns': [16, 32, 64, 128, 256], 858 | 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], 859 | 'class_num': class_num, 860 | 'up_type': 2, 861 | 'acti_func': 'relu'} 862 | self.encoder = Encoder(params1) 863 | self.decoder1 = Decoder(params1) 864 | self.decoder2 = Decoder(params2) 865 | self.decoder3 = Decoder(params3) 866 | 867 | def forward(self, x): 868 | feature = self.encoder(x) 869 | output1 = self.decoder1(feature) 870 | output2 = self.decoder2(feature) 871 | output3 = self.decoder3(feature) 872 | return output1, output2, output3 873 | class DecoderDiceCE(nn.Module): 874 | def __init__(self, params): 875 | super(DecoderDiceCE, self).__init__() 876 | self.params = params 877 | self.in_chns = self.params['in_chns'] 878 | self.ft_chns = self.params['feature_chns'] 879 | self.n_class = self.params['class_num'] 880 | self.up_type = self.params['up_type'] 881 | assert (len(self.ft_chns) == 5) 882 | 883 | self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0, mode_upsampling=self.up_type) 884 | self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0, mode_upsampling=self.up_type) 885 | self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0, mode_upsampling=self.up_type) 886 | self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0, mode_upsampling=self.up_type) 887 | 888 | self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size=3, padding=1) 889 | 890 | def forward(self, feature): 891 | x0 = feature[0] 892 | x1 = feature[1] 893 | x2 = feature[2] 894 | x3 = feature[3] 895 | x4 = feature[4] 896 | 897 | x = self.up1(x4, x3) 898 | x = self.up2(x, x2) 899 | x = self.up3(x, x1) 900 | x = self.up4(x, x0) 901 | output = self.out_conv(x) 902 | return output,x 903 | def DS_Combin_two(alpha1, alpha2,class_num): 904 | """ 905 | :param alpha1: Dirichlet distribution parameters of view 1 906 | :param alpha2: Dirichlet distribution parameters of view 2 907 | :return: Combined Dirichlet distribution parameters 908 | """ 909 | alpha = dict() 910 | alpha[0], alpha[1] = alpha1,alpha2 911 | b, S, E, u = dict(), dict(), dict(), dict() 912 | for v in range(2): 913 | S[v] = torch.sum(alpha[v], dim=1, keepdim=True) 914 | E[v] = alpha[v] - 1 915 | b[v] = E[v] / (S[v].expand(E[v].shape)) 916 | #print(b[v].shape) 917 | u[v] = class_num / S[v]#B*C*1 918 | 919 | # b^0 @ b^(0+1) 920 | bb = torch.bmm(b[0].view(-1, class_num, 1), b[1].view(-1, 1, class_num )) 921 | # b^0 * u^1 922 | uv1_expand = u[1].expand(b[1].shape) #B*C*1 923 | #print(uv1_expand.shape,'uv1') 924 | bu = torch.mul(b[0], uv1_expand)#B*C*1 925 | 926 | # b^1 * u^0 927 | uv2_expand = u[0].expand(b[0].shape) 928 | ub = torch.mul(b[1], uv2_expand) 929 | # calculate C 930 | bb_sum = torch.sum(bb, dim=(1, 2), out=None)#B 931 | #print(bb.shape, 'bb_sum',torch.diagonal(bb, dim1=-2, dim2=-1)) 932 | bb_diag = torch.diagonal(bb, dim1=-2, dim2=-1).sum(-1) 933 | C = 
bb_sum - bb_diag 934 | 935 | # calculate b^a 936 | b_a = (torch.mul(b[0], b[1]) + bu + ub) / ((1 - C).view(-1,1).expand(b[0].shape)) 937 | # calculate u^a 938 | u_a = torch.mul(u[0], u[1]) / ((1 - C).view(-1,1).expand(u[0].shape)) 939 | 940 | # calculate new S 941 | S_a = class_num / u_a 942 | # calculate new e_k 943 | e_a = torch.mul(b_a, S_a.expand(b_a.shape)) 944 | alpha_a = e_a + 1 945 | return alpha_a 946 | class DiceCENet2d_fuse(nn.Module): 947 | def __init__(self, in_chns, class_num): 948 | super(DiceCENet2d_fuse, self).__init__() 949 | 950 | params1 = {'in_chns': in_chns, 951 | 'feature_chns': [16, 32, 64, 128, 256], 952 | 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], 953 | 'class_num': class_num, 954 | 'up_type': 1, 955 | 'acti_func': 'relu'} 956 | self.params = params1 957 | self.ft_chns = self.params['feature_chns'] 958 | self.n_class = self.params['class_num'] 959 | self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size=1, padding=0) 960 | self.encoder = Encoder(params1) 961 | self.decoder1 = DecoderDiceCE(params1) 962 | self.decoder2 = DecoderDiceCE(params1) 963 | self.decoder3 = DecoderDiceCE(params1) 964 | self.dropout = nn.Dropout2d(p=0.01, inplace=False) 965 | self.block_nine= ConvBlock_2d(2,self.ft_chns[0], self.ft_chns[0]) 966 | self.BN = nn.ReLU() 967 | def forward(self, x): 968 | feature = self.encoder(x) 969 | out_seg1,x_up1 = self.decoder1(feature) 970 | out_seg2,x_up2 = self.decoder2(feature) 971 | out_seg3,x_up3 = self.decoder3(feature) 972 | evidence1 = F.softplus(out_seg1) 973 | alpha1 = evidence1+1 974 | evidence2 = F.softplus(out_seg2) 975 | alpha2 = evidence2 + 1 976 | prob2 = alpha2/torch.sum(alpha2,dim=1,keepdim=True) 977 | resize_alpha1 = alpha1.view(alpha1.size(0), alpha1.size(1), -1) # [N, C, HW] 978 | resize_alpha1 = resize_alpha1.transpose(1, 2) # [N, HW, C] 979 | resize_alpha1 = resize_alpha1.contiguous().view(-1, resize_alpha1.size(2)) 980 | resize_alpha2 = alpha2.view(alpha2.size(0), alpha2.size(1), -1) # [N, C, HW] 981 | resize_alpha2 = resize_alpha2.transpose(1, 2) # [N, HW, C] 982 | resize_alpha2 = resize_alpha2.contiguous().view(-1, resize_alpha2.size(2)) 983 | fuse_out_sup = DS_Combin_two( resize_alpha1, resize_alpha2,4) 984 | fuse_out_sup = fuse_out_sup/torch.sum(fuse_out_sup,dim=1,keepdim=True) 985 | #fuse_out =self.out_conv(((self.block_nine(((x_up1))+((x_up2)))))) 986 | fuse_out = F.softplus(out_seg3) 987 | fuse_out = fuse_out+1 988 | #fuse_out = fuse_out.view(fuse_out.size(0), fuse_out.size(1), -1) # [N, C, HW] 989 | #fuse_out = fuse_out.transpose(1, 2) # [N, HW, C] 990 | #fuse_out = fuse_out.contiguous().view(-1, fuse_out.size(2)) 991 | alpha3_prob = fuse_out/torch.sum(fuse_out,dim=1,keepdim=True) 992 | #print(fuse_out.max(), fuse_out.min()) 993 | return alpha1,alpha2,prob2,alpha3_prob,fuse_out_sup,fuse_out 994 | class MCNet2d_v3(nn.Module): 995 | def __init__(self, in_chns, class_num): 996 | super(MCNet2d_v3, self).__init__() 997 | 998 | params1 = {'in_chns': in_chns, 999 | 'feature_chns': [16, 32, 64, 128, 256], 1000 | 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], 1001 | 'class_num': class_num, 1002 | 'up_type': 1, 1003 | 'acti_func': 'relu'} 1004 | params2 = {'in_chns': in_chns, 1005 | 'feature_chns': [16, 32, 64, 128, 256], 1006 | 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], 1007 | 'class_num': class_num, 1008 | 'up_type': 0, 1009 | 'acti_func': 'relu'} 1010 | params3 = {'in_chns': in_chns, 1011 | 'feature_chns': [16, 32, 64, 128, 256], 1012 | 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], 1013 | 'class_num': class_num, 1014 | 'up_type': 2, 1015 | 
'acti_func': 'relu'} 1016 | params4 = {'in_chns': in_chns, 1017 | 'feature_chns': [16, 32, 64, 128, 256], 1018 | 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], 1019 | 'class_num': class_num, 1020 | 'up_type': 3, 1021 | 'acti_func': 'relu'} 1022 | self.encoder = Encoder(params1) 1023 | self.decoder1 = Decoder(params1) 1024 | self.decoder2 = Decoder(params2) 1025 | self.decoder3 = Decoder(params3) 1026 | self.decoder4 = Decoder(params4) 1027 | 1028 | def forward(self, x): 1029 | feature = self.encoder(x) 1030 | output1 = self.decoder1(feature) 1031 | output2 = self.decoder2(feature) 1032 | output3 = self.decoder3(feature) 1033 | output4 = self.decoder4(feature) 1034 | return output1, output2, output3, output4 1035 | 1036 | if __name__ == '__main__': 1037 | # compute FLOPS & PARAMETERS 1038 | from ptflops import get_model_complexity_info 1039 | model = UNet(in_chns=1, class_num=4).cuda() 1040 | with torch.cuda.device(0): 1041 | macs, params = get_model_complexity_info(model, (1, 256, 256), as_strings=True, 1042 | print_per_layer_stat=True, verbose=True) 1043 | print('{:<30} {:<8}'.format('Computational complexity: ', macs)) 1044 | print('{:<30} {:<8}'.format('Number of parameters: ', params)) 1045 | import ipdb; ipdb.set_trace() 1046 | --------------------------------------------------------------------------------