├── code ├── networks │ ├── net_factory_3d.py │ ├── net_factory_3dArgs.py │ ├── discriminator.py │ ├── attention.py │ ├── unet_3D.py │ ├── VoxResNet.py │ ├── vision_transformer.py │ ├── unet_3D_dv_semi.py │ ├── pnet.py │ ├── net_factory.py │ ├── net_factory_args.py │ ├── net_factory_args_HAR.py │ ├── attention_unet.py │ ├── encoder_tool.py │ ├── config.py │ ├── efficientunet.py │ ├── vnet.py │ └── vnetWithArgs.py ├── utils │ ├── metrics.py │ ├── ramps.py │ ├── deepcluster_vgg16.py │ ├── util.py │ └── losses.py ├── test_3D.py ├── train_3D.sh ├── train_2D.sh ├── dataloaders │ ├── dataset_synapse.py │ ├── utils.py │ └── la_heart.py ├── test_2D.py ├── loss.py ├── test.py ├── build_dataset.py ├── test_util.py ├── model.py └── augment_3d.py ├── README.md └── .gitignore /code/networks/net_factory_3d.py: -------------------------------------------------------------------------------- 1 | from networks.unet_3D import unet_3D 2 | from networks.vnet import VNet 3 | from networks.VoxResNet import VoxResNet 4 | from networks.attention_unet import Attention_UNet 5 | from networks.nnunet import initialize_network 6 | 7 | 8 | def net_factory_3d(net_type="unet_3D", in_chns=1, class_num=2): 9 | if net_type == "unet_3D": 10 | net = unet_3D(n_classes=class_num, in_channels=in_chns).cuda() 11 | elif net_type == "attention_unet": 12 | net = Attention_UNet(n_classes=class_num, in_channels=in_chns).cuda() 13 | elif net_type == "voxresnet": 14 | net = VoxResNet(in_chns=in_chns, feature_chns=64, 15 | class_num=class_num).cuda() 16 | elif net_type == "vnet": 17 | net = VNet(n_channels=in_chns, n_classes=class_num, 18 | normalization='batchnorm', has_dropout=True).cuda() 19 | elif net_type == "nnUNet": 20 | net = initialize_network(num_classes=class_num).cuda() 21 | else: 22 | net = None 23 | return net 24 | -------------------------------------------------------------------------------- /code/networks/net_factory_3dArgs.py: -------------------------------------------------------------------------------- 1 | from networks.unet_3D import unet_3D 2 | from networks.vnetWithArgs import VNet 3 | from networks.VoxResNet import VoxResNet 4 | from networks.attention_unet import Attention_UNet 5 | from networks.nnunet import initialize_network 6 | 7 | 8 | def net_factory_3d(net_type="unet_3D", in_chns=1, class_num=2): 9 | if net_type == "unet_3D": 10 | net = unet_3D(n_classes=class_num, in_channels=in_chns) # .cuda() 11 | elif net_type == "attention_unet": 12 | net = Attention_UNet(n_classes=class_num, in_channels=in_chns).cuda() 13 | elif net_type == "voxresnet": 14 | net = VoxResNet(in_chns=in_chns, feature_chns=64, 15 | class_num=class_num).cuda() 16 | elif net_type == "vnet": 17 | net = VNet(n_channels=in_chns, n_classes=class_num, 18 | normalization='batchnorm', has_dropout=True) # .cuda() 19 | elif net_type == "nnUNet": 20 | net = initialize_network(num_classes=class_num).cuda() 21 | else: 22 | net = None 23 | return net 24 | -------------------------------------------------------------------------------- /code/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from medpy import metric 3 | 4 | 5 | def cal_dice(prediction, label, num=2): 6 | total_dice = np.zeros(num-1) 7 | for i in range(1, num): 8 | prediction_tmp = (prediction == i) 9 | label_tmp = (label == i) 10 | prediction_tmp = prediction_tmp.astype(np.float) 11 | label_tmp = label_tmp.astype(np.float) 12 | 13 | dice = 2 * np.sum(prediction_tmp * label_tmp) / (np.sum(prediction_tmp) + 
np.sum(label_tmp)) 14 | total_dice[i - 1] += dice 15 | 16 | return total_dice 17 | 18 | 19 | def calculate_metric_percase(pred, gt): 20 | dc = metric.binary.dc(pred, gt) 21 | jc = metric.binary.jc(pred, gt) 22 | hd = metric.binary.hd95(pred, gt) 23 | asd = metric.binary.asd(pred, gt) 24 | 25 | return dc, jc, hd, asd 26 | 27 | 28 | def dice(input, target, ignore_index=None): 29 | smooth = 1. 30 | # using clone, so that it can do change to original target. 31 | iflat = input.clone().view(-1) 32 | tflat = target.clone().view(-1) 33 | if ignore_index is not None: 34 | mask = tflat == ignore_index 35 | tflat[mask] = 0 36 | iflat[mask] = 0 37 | intersection = (iflat * tflat).sum() 38 | 39 | return (2. * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth) -------------------------------------------------------------------------------- /code/test_3D.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from networks.vnet import VNet 5 | from test_util import test_all_case 6 | # from model_ISD import * 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--root_path', type=str, default='/data/data/2018LA_Seg_Training/2018LA_Seg_Training Set', help='Name of Experiment') 10 | parser.add_argument('--model', type=str, default='LA/RECO_1_labeledfinal/vnet', help='model_name') 11 | parser.add_argument('--gpu', type=str, default='1', help='GPU to use') 12 | FLAGS = parser.parse_args() 13 | 14 | os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu 15 | snapshot_path = "../model/"+FLAGS.model+"/" 16 | test_save_path = "../model/prediction/"+FLAGS.model+"_post/" 17 | if not os.path.exists(test_save_path): 18 | os.makedirs(test_save_path) 19 | 20 | num_classes = 2 21 | 22 | with open(FLAGS.root_path + '/../test.list', 'r') as f: 23 | image_list = f.readlines() 24 | image_list = [FLAGS.root_path +item.replace('\n', '')+"/mri_norm2.h5" for item in image_list] 25 | 26 | 27 | def test_calculate_metric(epoch_num): 28 | with torch.no_grad(): 29 | net = VNet(n_channels=1, n_classes=2, normalization='batchnorm', has_dropout=False).cuda() 30 | save_mode_path = os.path.join(snapshot_path, 'iter_' + str(epoch_num) + '.pth') 31 | net.load_state_dict(torch.load(save_mode_path)) 32 | print("init weight from {}".format(save_mode_path)) 33 | 34 | avg_metric = test_all_case(net, image_list, num_classes=num_classes, 35 | patch_size=(112, 112, 80), stride_xy=18, stride_z=4, 36 | save_result=True, test_save_path=test_save_path) 37 | 38 | return avg_metric 39 | 40 | 41 | if __name__ == '__main__': 42 | metric = test_calculate_metric(30000) 43 | print(metric) -------------------------------------------------------------------------------- /code/utils/ramps.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, Curious AI Ltd. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Functions for ramping hyperparameters up or down 9 | 10 | Each function takes the current training step or epoch, and the 11 | ramp length in the same format, and returns a multiplier between 12 | 0 and 1. 
13 | """ 14 | 15 | 16 | import numpy as np 17 | 18 | 19 | def sigmoid_rampup(current, rampup_length): 20 | """Exponential rampup from https://arxiv.org/abs/1610.02242""" 21 | if rampup_length == 0: 22 | return 1.0 23 | else: 24 | current = np.clip(current, 0.0, rampup_length) 25 | phase = 1.0 - current / rampup_length 26 | return float(np.exp(-5.0 * phase * phase)) 27 | 28 | 29 | def linear_rampup(current, rampup_length): 30 | """Linear rampup""" 31 | assert current >= 0 and rampup_length >= 0 32 | if current >= rampup_length: 33 | return 1.0 34 | else: 35 | return current / rampup_length 36 | 37 | 38 | def cosine_rampdown(current, rampdown_length): 39 | """Cosine rampdown from https://arxiv.org/abs/1608.03983""" 40 | assert 0 <= current <= rampdown_length 41 | return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1)) 42 | 43 | 44 | def exp_rampup(rampup_length): 45 | """Exponential rampup from https://arxiv.org/abs/1610.02242""" 46 | def warpper(epoch): 47 | if epoch < rampup_length: 48 | epoch = np.clip(epoch, 0.0, rampup_length) 49 | phase = 1.0 - epoch / rampup_length 50 | return float(np.exp(-5.0 * phase * phase)) 51 | else: 52 | return 1.0 53 | return warpper -------------------------------------------------------------------------------- /code/train_3D.sh: -------------------------------------------------------------------------------- 1 | # training action 2 | # for pretraining stage 1 3 | CUDA_VISIBLE_DEVICES=1 python train_3D_pretrain.py \ 4 | --train_encoder 1 \ 5 | --train_decoder 0 \ 6 | --K 36 \ 7 | --exp LA/pretrain \ 8 | --k1 1 \ 9 | --k2 1 \ 10 | --latent_pooling_size 1 \ 11 | --latent_feature_size 128 \ 12 | --output_pooling_size 4 \ 13 | --T_s 0.1 \ 14 | --T_t 0.6 \ 15 | --max_iterations 30000 \ 16 | --labeled_num 1 \ 17 | --num_classes 2 18 | 19 | 20 | # for pretraining stage 2 21 | CUDA_VISIBLE_DEVICES=1 python train_3D_pretrain.py \ 22 | --train_encoder 1 \ 23 | --train_decoder 1 \ 24 | --K 36 \ 25 | --exp LA/pretrain \ 26 | --resume LA/pretrain \ 27 | --k1 1 \ 28 | --k2 1 \ 29 | --latent_pooling_size 1 \ 30 | --latent_feature_size 128 \ 31 | --output_pooling_size 4 \ 32 | --T_s 0.1 \ 33 | --T_t 0.6 \ 34 | --max_iterations 30000 \ 35 | --labeled_num 1 \ 36 | --num_classes 2 37 | 38 | 39 | # for finetuning 40 | CUDA_VISIBLE_DEVICES=1 python train_3D_action.py \ 41 | --resume LA/pretrain \ 42 | --exp LA/action \ 43 | --max_iterations 30000 \ 44 | --labeled_num 1 \ 45 | --batch_size 1 \ 46 | --num_classes 2 47 | 48 | 49 | # for action++ 50 | # pretraining 51 | CUDA_VISIBLE_DEVICES=1 python train_3D_pretrain++.py \ 52 | --train_encoder 1 \ 53 | --train_decoder 1 \ 54 | --K 36 \ 55 | --exp LA/pretrain \ 56 | --k1 1 \ 57 | --k2 1 \ 58 | --latent_pooling_size 1 \ 59 | --latent_feature_size 128 \ 60 | --output_pooling_size 4 \ 61 | --T_s 0.1 \ 62 | --T_t 0.6 \ 63 | --max_iterations 30000 \ 64 | --labeled_num 1 \ 65 | --num_classes 2 66 | 67 | # for finetuning 68 | CUDA_VISIBLE_DEVICES=1 python train_3D_action.py \ 69 | --resume LA/pretrain++ \ 70 | --exp LA/action++ \ 71 | --max_iterations 30000 \ 72 | --labeled_num 1 \ 73 | --batch_size 1 \ 74 | --num_classes 2 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ACTION Family 2 | 3 | This is the official PyTorch implementation of our IPMI 2023 [![arXiv](https://img.shields.io/badge/arXiv-2206.02307-b31b1b.svg)](https://arxiv.org/abs/2206.02307) and MICCAI 2023 
[![arXiv](https://img.shields.io/badge/arXiv-2304.02689-b31b1b.svg)](https://arxiv.org/abs/2304.02689) papers by [Chenyu You](http://chenyuyou.me/), [Weicheng Dai](https://weichengdai1.github.io/), [Yifei Min](https://scholar.google.com/citations?user=pFWnzL0AAAAJ&hl=en/), [Lawrence Staib](https://medicine.yale.edu/profile/lawrence-staib/), [Jasjeet S. Sekhon](https://www.jsekhon.com/), and [James S. Duncan](https://medicine.yale.edu/profile/james-duncan/): 4 | 5 | > [**Bootstrapping Semi-supervised Medical Image Segmentation with Anatomical-Aware Contrastive Distillation**](https://arxiv.org/abs/2206.02307) 6 | > Chenyu You, Weicheng Dai, Yifei Min, Lawrence Staib, James S. Duncan 7 | > *In International Conference on Information Processing in Medical Imaging (IPMI), 2023* 8 | 9 | Our improved 2D/3D framework, **ACTION++**, has also been accepted by MICCAI 2023: 10 | 11 | > **[ACTION++: Improving Semi-supervised Medical Image Segmentation with Adaptive Anatomical Contrast](https://arxiv.org/abs/2304.02689)**
12 | > Chenyu You, Weicheng Dai, Yifei Min, Lawrence Staib, Jasjeet S. Sekhon, James S. Duncan
13 | > *In International Conference on Medical Image Computing and Computer Assisted Intervention (MICCAI), 2023* [Early Accept] 14 | 15 | 16 | ## Citation 17 | 18 | If you find this project useful, please consider citing: 19 | 20 | ```bibtex 21 | @inproceedings{you2023bootstrapping, 22 | title={Bootstrapping semi-supervised medical image segmentation with anatomical-aware contrastive distillation}, 23 | author={You, Chenyu and Dai, Weicheng and Min, Yifei and Staib, Lawrence and Duncan, James S}, 24 | booktitle={IPMI}, 25 | year={2023} 26 | } 27 | 28 | @inproceedings{you2023actionplus, 29 | title={Action++: Improving semi-supervised medical image segmentation with adaptive anatomical contrast}, 30 | author={You, Chenyu and Dai, Weicheng and Min, Yifei and Staib, Lawrence and Sekhon, Jas and Duncan, James S}, 31 | booktitle={MICCAI}, 32 | year={2023} 33 | } 34 | ``` 35 | 36 | 37 | -------------------------------------------------------------------------------- /code/train_2D.sh: -------------------------------------------------------------------------------- 1 | 2 | # training action 3 | # for pretraining stage 1 4 | CUDA_VISIBLE_DEVICES=1 python train_2D_pretrain.py \ 5 | --train_encoder 1 \ 6 | --train_decoder 0 \ 7 | --K 36 \ 8 | --root_path /data/data/ACDC \ 9 | --exp ACDC/pretrain \ 10 | --k1 1 \ 11 | --k2 1 \ 12 | --latent_pooling_size 1 \ 13 | --latent_feature_size 512 \ 14 | --output_pooling_size 8 \ 15 | --T_s 0.1 \ 16 | --T_t 0.01 \ 17 | --max_iterations 30000 \ 18 | --labeled_num 1 \ 19 | --num_classes 4 20 | 21 | # for pretraining stage 2 22 | CUDA_VISIBLE_DEVICES=1 python train_2D_pretrain.py \ 23 | --train_encoder 1 \ 24 | --train_decoder 1 \ 25 | --K 36 \ 26 | --root_path /data/data/ACDC \ 27 | --exp ACDC/pretrain \ 28 | --resume ACDC/pretrain \ 29 | --k1 1 \ 30 | --k2 1 \ 31 | --latent_pooling_size 1 \ 32 | --latent_feature_size 512 \ 33 | --output_pooling_size 8 \ 34 | --T_s 0.1 \ 35 | --T_t 0.01 \ 36 | --max_iterations 30000 \ 37 | --labeled_num 1 \ 38 | --num_classes 4 39 | 40 | # for finetuning 41 | CUDA_VISIBLE_DEVICES=1 python train_2D_action.py \ 42 | --root_path /data/data/ACDC \ 43 | --exp ACDC/action \ 44 | --resume ACDC/pretrain \ 45 | --batch_size 4 \ 46 | --max_iterations 30000 \ 47 | --apply_aug cutmix \ 48 | --labeled_num 1 \ 49 | --num_classes 4 50 | 51 | 52 | # for action++ 53 | # pretraining 54 | CUDA_VISIBLE_DEVICES=1 python train_2D_pretrain++.py \ 55 | --train_encoder 1 \ 56 | --train_decoder 1 \ 57 | --K 36 \ 58 | --root_path /data/data/ACDC \ 59 | --exp ACDC/pretrain++ \ 60 | --k1 1 \ 61 | --k2 1 \ 62 | --latent_pooling_size 1 \ 63 | --latent_feature_size 512 \ 64 | --output_pooling_size 8 \ 65 | --T_s 0.1 \ 66 | --T_t 0.01 \ 67 | --max_iterations 30000 \ 68 | --labeled_num 1 \ 69 | --num_classes 4 70 | 71 | # for finetuning 72 | CUDA_VISIBLE_DEVICES=1 python train_2D_action++.py \ 73 | --root_path /data/data/ACDC \ 74 | --exp ACDC/action++ \ 75 | --resume ACDC/pretrain++ \ 76 | --batch_size 4 \ 77 | --max_iterations 30000 \ 78 | --apply_aug cutmix \ 79 | --labeled_num 1 \ 80 | --num_classes 4 -------------------------------------------------------------------------------- /code/dataloaders/dataset_synapse.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import h5py 4 | import numpy as np 5 | import torch 6 | from scipy import ndimage 7 | from scipy.ndimage.interpolation import zoom 8 | from torch.utils.data import Dataset 9 | 10 | 11 | def random_rot_flip(image, label): 12 | k = 
np.random.randint(0, 4) 13 | image = np.rot90(image, k) 14 | label = np.rot90(label, k) 15 | axis = np.random.randint(0, 2) 16 | image = np.flip(image, axis=axis).copy() 17 | label = np.flip(label, axis=axis).copy() 18 | return image, label 19 | 20 | 21 | def random_rotate(image, label): 22 | angle = np.random.randint(-20, 20) 23 | image = ndimage.rotate(image, angle, order=0, reshape=False) 24 | label = ndimage.rotate(label, angle, order=0, reshape=False) 25 | return image, label 26 | 27 | 28 | class RandomGenerator(object): 29 | def __init__(self, output_size): 30 | self.output_size = output_size 31 | 32 | def __call__(self, sample): 33 | image, label = sample['image'], sample['label'] 34 | 35 | if random.random() > 0.5: 36 | image, label = random_rot_flip(image, label) 37 | elif random.random() > 0.5: 38 | image, label = random_rotate(image, label) 39 | x, y = image.shape 40 | if x != self.output_size[0] or y != self.output_size[1]: 41 | image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y), order=3) # why not 3? 42 | label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0) 43 | image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0) 44 | label = torch.from_numpy(label.astype(np.float32)) 45 | sample = {'image': image, 'label': label.long()} 46 | return sample 47 | 48 | 49 | class Synapse_dataset(Dataset): 50 | def __init__(self, base_dir, list_dir, split, transform=None): 51 | self.transform = transform # using transform in torch! 52 | self.split = split 53 | if (split == "test" or split == 'val'): 54 | self.sample_list = open(os.path.join(list_dir, self.split+'_vol_40.txt')).readlines() 55 | else: 56 | self.sample_list = open(os.path.join(list_dir, self.split+'_40.txt')).readlines() 57 | self.data_dir = base_dir 58 | 59 | def __len__(self): 60 | return len(self.sample_list) 61 | 62 | def __getitem__(self, idx): 63 | if self.split == "train": 64 | slice_name = self.sample_list[idx].strip('\n') 65 | data_path = os.path.join(self.data_dir, slice_name+'.npz') 66 | # print(data_path) 67 | data = np.load(data_path) 68 | image, label = data['image'], data['label'] 69 | else: 70 | vol_name = self.sample_list[idx].strip('\n') 71 | filepath = self.data_dir + "_40/{}.h5".format(vol_name) 72 | data = h5py.File(filepath) 73 | image, label = data['image'][:], data['label'][:] 74 | 75 | sample = {'image': image, 'label': label} 76 | if self.transform: 77 | sample = self.transform(sample) 78 | sample['case_name'] = self.sample_list[idx].strip('\n') 79 | return sample -------------------------------------------------------------------------------- /code/utils/deepcluster_vgg16.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | from random import random as rd 5 | 6 | __all__ = [ 'VGG', 'vgg16'] 7 | 8 | 9 | class VGG(nn.Module): 10 | 11 | def __init__(self, features, num_classes, sobel): 12 | super(VGG, self).__init__() 13 | self.features = features 14 | self.classifier = nn.Sequential( 15 | nn.Linear(512 * 7 * 7, 4096), 16 | nn.ReLU(True), 17 | nn.Dropout(0.5), 18 | nn.Linear(4096, 4096), 19 | nn.ReLU(True) 20 | ) 21 | self.top_layer = nn.Linear(4096, num_classes) 22 | self._initialize_weights() 23 | if sobel: 24 | grayscale = nn.Conv2d(3, 1, kernel_size=1, stride=1, padding=0) 25 | grayscale.weight.data.fill_(1.0 / 3.0) 26 | grayscale.bias.data.zero_() 27 | sobel_filter = nn.Conv2d(1, 2, kernel_size=3, stride=1, padding=1) 28 | 
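            # The two fixed 3x3 kernels copied in below are the standard Sobel
            # derivative filters: output channel 0 takes the horizontal-gradient
            # kernel and channel 1 the vertical-gradient kernel. Together with the
            # 1x1 grayscale convolution above, this yields a frozen 2-channel edge
            # map (all Sobel parameters are set to requires_grad=False further down).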
sobel_filter.weight.data[0,0].copy_( 29 | torch.FloatTensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]) 30 | ) 31 | sobel_filter.weight.data[1,0].copy_( 32 | torch.FloatTensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]) 33 | ) 34 | sobel_filter.bias.data.zero_() 35 | self.sobel = nn.Sequential(grayscale, sobel_filter) 36 | for p in self.sobel.parameters(): 37 | p.requires_grad = False 38 | else: 39 | self.sobel = None 40 | 41 | def forward(self, x): 42 | if self.sobel: 43 | x = self.sobel(x) 44 | x = self.features(x) 45 | x = x.view(x.size(0), -1) 46 | x = self.classifier(x) 47 | if self.top_layer: 48 | x = self.top_layer(x) 49 | return x 50 | 51 | def _initialize_weights(self): 52 | for y,m in enumerate(self.modules()): 53 | if isinstance(m, nn.Conv2d): 54 | #print(y) 55 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 56 | for i in range(m.out_channels): 57 | m.weight.data[i].normal_(0, math.sqrt(2. / n)) 58 | if m.bias is not None: 59 | m.bias.data.zero_() 60 | elif isinstance(m, nn.BatchNorm2d): 61 | m.weight.data.fill_(1) 62 | m.bias.data.zero_() 63 | elif isinstance(m, nn.Linear): 64 | m.weight.data.normal_(0, 0.01) 65 | m.bias.data.zero_() 66 | 67 | 68 | def make_layers(input_dim, batch_norm): 69 | layers = [] 70 | in_channels = input_dim 71 | cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'] 72 | for v in cfg: 73 | if v == 'M': 74 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 75 | else: 76 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 77 | if batch_norm: 78 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 79 | else: 80 | layers += [conv2d, nn.ReLU(inplace=True)] 81 | in_channels = v 82 | return nn.Sequential(*layers) 83 | 84 | 85 | def vgg16(sobel=False, bn=True, out=1000): 86 | dim = 2 + int(not sobel) 87 | model = VGG(make_layers(dim, bn), out, sobel) 88 | return model -------------------------------------------------------------------------------- /code/networks/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FC3DDiscriminator(nn.Module): 7 | 8 | def __init__(self, num_classes, ndf=64, n_channel=1): 9 | super(FC3DDiscriminator, self).__init__() 10 | # downsample 16 11 | self.conv0 = nn.Conv3d( 12 | num_classes, ndf, kernel_size=4, stride=2, padding=1) 13 | self.conv1 = nn.Conv3d( 14 | n_channel, ndf, kernel_size=4, stride=2, padding=1) 15 | 16 | self.conv2 = nn.Conv3d(ndf, ndf*2, kernel_size=4, stride=2, padding=1) 17 | self.conv3 = nn.Conv3d( 18 | ndf*2, ndf*4, kernel_size=4, stride=2, padding=1) 19 | self.conv4 = nn.Conv3d( 20 | ndf*4, ndf*8, kernel_size=4, stride=2, padding=1) 21 | self.avgpool = nn.AvgPool3d((6, 6, 6)) # (D/16, W/16, H/16) 22 | self.classifier = nn.Linear(ndf*8, 2) 23 | 24 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 25 | self.dropout = nn.Dropout3d(0.5) 26 | self.Softmax = nn.Softmax() 27 | 28 | def forward(self, map, image): 29 | batch_size = map.shape[0] 30 | map_feature = self.conv0(map) 31 | image_feature = self.conv1(image) 32 | x = torch.add(map_feature, image_feature) 33 | x = self.leaky_relu(x) 34 | x = self.dropout(x) 35 | 36 | x = self.conv2(x) 37 | x = self.leaky_relu(x) 38 | x = self.dropout(x) 39 | 40 | x = self.conv3(x) 41 | x = self.leaky_relu(x) 42 | x = self.dropout(x) 43 | 44 | x = self.conv4(x) 45 | x = self.leaky_relu(x) 46 | 47 | x = self.avgpool(x) 48 | 49 | x = x.view(batch_size, -1) 50 | 
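        # Flatten the pooled feature volume to one vector per sample so that the
        # fully connected layer below can map the ndf*8 features to the two
        # output logits returned by this discriminator.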
51 | x = self.classifier(x) 52 | x = x.reshape((batch_size, 2)) 53 | # x = self.Softmax(x) 54 | 55 | return x 56 | 57 | 58 | class FCDiscriminator(nn.Module): 59 | 60 | def __init__(self, num_classes, ndf=64, n_channel=1): 61 | super(FCDiscriminator, self).__init__() 62 | self.conv0 = nn.Conv2d( 63 | num_classes, ndf, kernel_size=4, stride=2, padding=1) 64 | self.conv1 = nn.Conv2d( 65 | n_channel, ndf, kernel_size=4, stride=2, padding=1) 66 | self.conv2 = nn.Conv2d(ndf, ndf*2, kernel_size=4, stride=2, padding=1) 67 | self.conv3 = nn.Conv2d( 68 | ndf*2, ndf*4, kernel_size=4, stride=2, padding=1) 69 | self.conv4 = nn.Conv2d( 70 | ndf*4, ndf*8, kernel_size=4, stride=2, padding=1) 71 | self.classifier = nn.Linear(ndf*32, 2) 72 | self.avgpool = nn.AvgPool2d((7, 7)) 73 | 74 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 75 | self.dropout = nn.Dropout2d(0.5) 76 | # self.up_sample = nn.Upsample(scale_factor=32, mode='bilinear') 77 | # self.sigmoid = nn.Sigmoid() 78 | 79 | def forward(self, map, feature): 80 | map_feature = self.conv0(map) 81 | image_feature = self.conv1(feature) 82 | x = torch.add(map_feature, image_feature) 83 | 84 | x = self.conv2(x) 85 | x = self.leaky_relu(x) 86 | x = self.dropout(x) 87 | 88 | x = self.conv3(x) 89 | x = self.leaky_relu(x) 90 | x = self.dropout(x) 91 | 92 | x = self.conv4(x) 93 | x = self.leaky_relu(x) 94 | x = self.avgpool(x) 95 | x = x.view(x.size(0), -1) 96 | x = self.classifier(x) 97 | # x = self.up_sample(x) 98 | # x = self.sigmoid(x) 99 | 100 | return x 101 | -------------------------------------------------------------------------------- /code/networks/attention.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | try: 4 | from inplace_abn import InPlaceABN 5 | except ImportError: 6 | InPlaceABN = None 7 | 8 | 9 | class Conv2dReLU(nn.Sequential): 10 | def __init__( 11 | self, 12 | in_channels, 13 | out_channels, 14 | kernel_size, 15 | padding=0, 16 | stride=1, 17 | use_batchnorm=True, 18 | ): 19 | 20 | if use_batchnorm == "inplace" and InPlaceABN is None: 21 | raise RuntimeError( 22 | "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. 
" 23 | + "To install see: https://github.com/mapillary/inplace_abn" 24 | ) 25 | 26 | super().__init__() 27 | 28 | conv = nn.Conv2d( 29 | in_channels, 30 | out_channels, 31 | kernel_size, 32 | stride=stride, 33 | padding=padding, 34 | bias=not (use_batchnorm), 35 | ) 36 | relu = nn.ReLU(inplace=True) 37 | 38 | if use_batchnorm == "inplace": 39 | bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0) 40 | relu = nn.Identity() 41 | 42 | elif use_batchnorm and use_batchnorm != "inplace": 43 | bn = nn.BatchNorm2d(out_channels) 44 | 45 | else: 46 | bn = nn.Identity() 47 | 48 | super(Conv2dReLU, self).__init__(conv, bn, relu) 49 | 50 | 51 | class SCSEModule(nn.Module): 52 | def __init__(self, in_channels, reduction=16): 53 | super().__init__() 54 | self.cSE = nn.Sequential( 55 | nn.AdaptiveAvgPool2d(1), 56 | nn.Conv2d(in_channels, in_channels // reduction, 1), 57 | nn.ReLU(inplace=True), 58 | nn.Conv2d(in_channels // reduction, in_channels, 1), 59 | nn.Sigmoid(), 60 | ) 61 | self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid()) 62 | 63 | def forward(self, x): 64 | return x * self.cSE(x) + x * self.sSE(x) 65 | 66 | 67 | class Activation(nn.Module): 68 | 69 | def __init__(self, name, **params): 70 | 71 | super().__init__() 72 | 73 | if name is None or name == 'identity': 74 | self.activation = nn.Identity(**params) 75 | elif name == 'sigmoid': 76 | self.activation = nn.Sigmoid() 77 | elif name == 'softmax2d': 78 | self.activation = nn.Softmax(dim=1, **params) 79 | elif name == 'softmax': 80 | self.activation = nn.Softmax(**params) 81 | elif name == 'logsoftmax': 82 | self.activation = nn.LogSoftmax(**params) 83 | elif callable(name): 84 | self.activation = name(**params) 85 | else: 86 | raise ValueError('Activation should be callable/sigmoid/softmax/logsoftmax/None; got {}'.format(name)) 87 | 88 | def forward(self, x): 89 | return self.activation(x) 90 | 91 | 92 | class Attention(nn.Module): 93 | 94 | def __init__(self, name, **params): 95 | super().__init__() 96 | 97 | if name is None: 98 | self.attention = nn.Identity(**params) 99 | elif name == 'scse': 100 | self.attention = SCSEModule(**params) 101 | else: 102 | raise ValueError("Attention {} is not implemented".format(name)) 103 | 104 | def forward(self, x): 105 | return self.attention(x) 106 | 107 | 108 | class Flatten(nn.Module): 109 | def forward(self, x): 110 | return x.view(x.shape[0], -1) 111 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /code/utils/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import os 8 | import pickle 9 | 10 | import numpy as np 11 | import torch 12 | from torch.utils.data.sampler import Sampler 13 | 14 | # import networks 15 | 16 | def load_model(path): 17 | """Loads model and return it without DataParallel table.""" 18 | if os.path.isfile(path): 19 | print("=> loading checkpoint '{}'".format(path)) 20 | checkpoint = torch.load(path) 21 | 22 | # size of the top layer 23 | N = checkpoint['state_dict']['top_layer.bias'].size() 24 | 25 | # build skeleton of the model 26 | sob = 'sobel.0.weight' in checkpoint['state_dict'].keys() 27 | model = models.__dict__[checkpoint['arch']](sobel=sob, out=int(N[0])) 28 | 29 | # deal with a dataparallel table 30 | def rename_key(key): 31 | if not 'module' in key: 32 | return key 33 | return ''.join(key.split('.module')) 34 | 35 | checkpoint['state_dict'] = {rename_key(key): val 36 | for key, val 37 | in checkpoint['state_dict'].items()} 38 | 39 | # load weights 40 | model.load_state_dict(checkpoint['state_dict']) 41 | print("Loaded") 42 | else: 43 | model = None 44 | print("=> no checkpoint found at '{}'".format(path)) 45 | return model 46 | 47 | 48 | class UnifLabelSampler(Sampler): 49 | """Samples elements uniformely accross pseudolabels. 50 | Args: 51 | N (int): size of returned iterator. 52 | images_lists: dict of key (target), value (list of data with this target) 53 | """ 54 | 55 | def __init__(self, N, images_lists): 56 | self.N = N 57 | self.images_lists = images_lists 58 | self.indexes = self.generate_indexes_epoch() 59 | 60 | def generate_indexes_epoch(self): 61 | size_per_pseudolabel = int(self.N / len(self.images_lists)) + 1 62 | res = np.zeros(size_per_pseudolabel * len(self.images_lists)) 63 | 64 | for i in range(len(self.images_lists)): 65 | indexes = np.random.choice( 66 | self.images_lists[i], 67 | size_per_pseudolabel, 68 | replace=(len(self.images_lists[i]) <= size_per_pseudolabel) 69 | ) 70 | res[i * size_per_pseudolabel: (i + 1) * size_per_pseudolabel] = indexes 71 | 72 | np.random.shuffle(res) 73 | return res[:self.N].astype('int') 74 | 75 | def __iter__(self): 76 | return iter(self.indexes) 77 | 78 | def __len__(self): 79 | return self.N 80 | 81 | 82 | class AverageMeter(object): 83 | """Computes and stores the average and current value""" 84 | def __init__(self): 85 | self.reset() 86 | 87 | def reset(self): 88 | self.val = 0 89 | self.avg = 0 90 | self.sum = 0 91 | self.count = 0 92 | 93 | def update(self, val, n=1): 94 | self.val = val 95 | self.sum += val * n 96 | self.count += n 97 | self.avg = self.sum / self.count 98 | 99 | 100 | def learning_rate_decay(optimizer, t, lr_0): 101 | for param_group in optimizer.param_groups: 102 | lr = lr_0 / np.sqrt(1 + lr_0 * param_group['weight_decay'] * t) 103 | param_group['lr'] = lr 104 | 105 | 106 | class Logger(): 107 | """ Class to update every epoch to keep trace of the results 108 | Methods: 109 | - log() log and save 110 | """ 111 | 112 | def __init__(self, path): 113 | self.path = path 114 | self.data = [] 115 | 116 | def log(self, train_point): 117 | self.data.append(train_point) 118 | with open(os.path.join(self.path), 'wb') as fp: 119 | pickle.dump(self.data, fp, -1) 120 | -------------------------------------------------------------------------------- /code/networks/unet_3D.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | An 
implementation of the 3D U-Net paper: 4 | Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf Ronneberger: 5 | 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation. 6 | MICCAI (2) 2016: 424-432 7 | Note that there are some modifications from the original paper, such as 8 | the use of batch normalization, dropout, and leaky relu here. 9 | The implementation is borrowed from: https://github.com/ozan-oktay/Attention-Gated-Networks 10 | """ 11 | import math 12 | 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | from networks.networks_other import init_weights 17 | from networks.utils import UnetConv3, UnetUp3, UnetUp3_CT 18 | 19 | 20 | class unet_3D(nn.Module): 21 | 22 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True): 23 | super(unet_3D, self).__init__() 24 | self.is_deconv = is_deconv 25 | self.in_channels = in_channels 26 | self.is_batchnorm = is_batchnorm 27 | self.feature_scale = feature_scale 28 | 29 | filters = [64, 128, 256, 512, 1024] 30 | filters = [int(x / self.feature_scale) for x in filters] 31 | 32 | # downsampling 33 | self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm, kernel_size=( 34 | 3, 3, 3), padding_size=(1, 1, 1)) 35 | self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 36 | 37 | self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm, kernel_size=( 38 | 3, 3, 3), padding_size=(1, 1, 1)) 39 | self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 40 | 41 | self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm, kernel_size=( 42 | 3, 3, 3), padding_size=(1, 1, 1)) 43 | self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 44 | 45 | self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm, kernel_size=( 46 | 3, 3, 3), padding_size=(1, 1, 1)) 47 | self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 48 | 49 | self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm, kernel_size=( 50 | 3, 3, 3), padding_size=(1, 1, 1)) 51 | 52 | # upsampling 53 | self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm) 54 | self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm) 55 | self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm) 56 | self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm) 57 | 58 | # final conv (without any concat) 59 | self.final = nn.Conv3d(filters[0], n_classes, 1) 60 | 61 | self.dropout1 = nn.Dropout(p=0.3) 62 | self.dropout2 = nn.Dropout(p=0.3) 63 | 64 | # initialise weights 65 | for m in self.modules(): 66 | if isinstance(m, nn.Conv3d): 67 | init_weights(m, init_type='kaiming') 68 | elif isinstance(m, nn.BatchNorm3d): 69 | init_weights(m, init_type='kaiming') 70 | 71 | def forward(self, inputs): 72 | conv1 = self.conv1(inputs) 73 | maxpool1 = self.maxpool1(conv1) 74 | 75 | conv2 = self.conv2(maxpool1) 76 | maxpool2 = self.maxpool2(conv2) 77 | 78 | conv3 = self.conv3(maxpool2) 79 | maxpool3 = self.maxpool3(conv3) 80 | 81 | conv4 = self.conv4(maxpool3) 82 | maxpool4 = self.maxpool4(conv4) 83 | 84 | center = self.center(maxpool4) 85 | center = self.dropout1(center) 86 | up4 = self.up_concat4(conv4, center) 87 | up3 = self.up_concat3(conv3, up4) 88 | up2 = self.up_concat2(conv2, up3) 89 | up1 = self.up_concat1(conv1, up2) 90 | up1 = self.dropout2(up1) 91 | 92 | final = self.final(up1) 93 | 94 | return final 95 | 96 | @staticmethod 97 | def apply_argmax_softmax(pred): 98 | log_p = F.softmax(pred, dim=1) 99 | 100 | return log_p 101 | 
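# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file): builds the network
# with the same arguments net_factory_3d passes for "unet_3D" (1 input channel,
# 2 classes) and runs a forward pass on a dummy LA-sized patch. The
# (112, 112, 80) patch size follows test_3D.py; the dummy tensor and variable
# names below are assumptions made purely for illustration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    model = unet_3D(n_classes=2, in_channels=1)
    model.eval()
    with torch.no_grad():
        # (batch, channels, D, H, W); each spatial size must be divisible by 16
        # because the encoder applies four 2x2x2 max-poolings.
        dummy = torch.zeros(1, 1, 112, 112, 80)
        logits = model(dummy)                            # -> (1, 2, 112, 112, 80)
        probs = unet_3D.apply_argmax_softmax(logits)     # softmax over classes
        pred = probs.argmax(dim=1)                       # voxel-wise labels
    print(tuple(pred.shape))                             # (1, 112, 112, 80)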
-------------------------------------------------------------------------------- /code/networks/VoxResNet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import print_function, division 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class SEBlock(nn.Module): 10 | def __init__(self, in_channels, r): 11 | super(SEBlock, self).__init__() 12 | 13 | redu_chns = int(in_channels / r) 14 | self.se_layers = nn.Sequential( 15 | nn.AdaptiveAvgPool3d(1), 16 | nn.Conv3d(in_channels, redu_chns, kernel_size=1, padding=0), 17 | nn.ReLU(), 18 | nn.Conv3d(redu_chns, in_channels, kernel_size=1, padding=0), 19 | nn.ReLU()) 20 | 21 | def forward(self, x): 22 | f = self.se_layers(x) 23 | return f * x + x 24 | 25 | 26 | class VoxRex(nn.Module): 27 | def __init__(self, in_channels): 28 | super(VoxRex, self).__init__() 29 | self.block = nn.Sequential( 30 | nn.InstanceNorm3d(in_channels), 31 | nn.ReLU(inplace=True), 32 | nn.Conv3d(in_channels, in_channels, 33 | kernel_size=3, padding=1, bias=False), 34 | nn.InstanceNorm3d(in_channels), 35 | nn.ReLU(inplace=True), 36 | nn.Conv3d(in_channels, in_channels, 37 | kernel_size=3, padding=1, bias=False) 38 | ) 39 | 40 | def forward(self, x): 41 | return self.block(x)+x 42 | 43 | 44 | class ConvBlock(nn.Module): 45 | """two convolution layers with batch norm and leaky relu""" 46 | 47 | def __init__(self, in_channels, out_channels): 48 | super(ConvBlock, self).__init__() 49 | self.conv_conv = nn.Sequential( 50 | nn.InstanceNorm3d(in_channels), 51 | nn.ReLU(inplace=True), 52 | nn.Conv3d(in_channels, out_channels, 53 | kernel_size=3, padding=1, bias=False), 54 | nn.InstanceNorm3d(out_channels), 55 | nn.ReLU(inplace=True), 56 | nn.Conv3d(out_channels, out_channels, 57 | kernel_size=3, padding=1, bias=False) 58 | ) 59 | 60 | def forward(self, x): 61 | return self.conv_conv(x) 62 | 63 | 64 | class UpBlock(nn.Module): 65 | """Upssampling followed by ConvBlock""" 66 | 67 | def __init__(self, in_channels, out_channels): 68 | super(UpBlock, self).__init__() 69 | self.up = nn.Upsample( 70 | scale_factor=2, mode='trilinear', align_corners=True) 71 | self.conv = ConvBlock(in_channels, out_channels) 72 | 73 | def forward(self, x1, x2): 74 | x1 = self.up(x1) 75 | x = torch.cat([x2, x1], dim=1) 76 | return self.conv(x) 77 | 78 | 79 | class VoxResNet(nn.Module): 80 | def __init__(self, in_chns=1, feature_chns=64, class_num=2): 81 | super(VoxResNet, self).__init__() 82 | self.in_chns = in_chns 83 | self.ft_chns = feature_chns 84 | self.n_class = class_num 85 | 86 | self.conv1 = nn.Conv3d(in_chns, feature_chns, kernel_size=3, padding=1) 87 | self.res1 = VoxRex(feature_chns) 88 | self.res2 = VoxRex(feature_chns) 89 | self.res3 = VoxRex(feature_chns) 90 | self.res4 = VoxRex(feature_chns) 91 | self.res5 = VoxRex(feature_chns) 92 | self.res6 = VoxRex(feature_chns) 93 | 94 | self.up1 = UpBlock(feature_chns * 2, feature_chns) 95 | self.up2 = UpBlock(feature_chns * 2, feature_chns) 96 | 97 | self.out = nn.Conv3d(feature_chns, self.n_class, kernel_size=1) 98 | 99 | self.maxpool = nn.MaxPool3d(2) 100 | self.upsample = nn.Upsample( 101 | scale_factor=2, mode='trilinear', align_corners=True) 102 | 103 | def forward(self, x): 104 | x = self.maxpool(self.conv1(x)) 105 | x1 = self.res1(x) 106 | x2 = self.res2(x1) 107 | x2_pool = self.maxpool(x2) 108 | x3 = self.res3(x2_pool) 109 | x4 = self.maxpool(self.res4(x3)) 110 | x5 = self.res5(x4) 111 | x6 = self.res6(x5) 112 | up1 = self.up1(x6, x2_pool) 113 | up2 = 
self.up2(up1, x) 114 | up = self.upsample(up2) 115 | out = self.out(up) 116 | return out 117 | -------------------------------------------------------------------------------- /code/networks/vision_transformer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # This file borrowed from Swin-UNet: https://github.com/HuCaoFighting/Swin-Unet 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import copy 8 | import logging 9 | import math 10 | 11 | from os.path import join as pjoin 12 | 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | 17 | from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm 18 | from torch.nn.modules.utils import _pair 19 | from scipy import ndimage 20 | from networks.swin_transformer_unet_skip_expand_decoder_sys import SwinTransformerSys 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | class SwinUnet(nn.Module): 25 | def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False): 26 | super(SwinUnet, self).__init__() 27 | self.num_classes = num_classes 28 | self.zero_head = zero_head 29 | self.config = config 30 | 31 | self.swin_unet = SwinTransformerSys(img_size=config.DATA.IMG_SIZE, 32 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 33 | in_chans=config.MODEL.SWIN.IN_CHANS, 34 | num_classes=self.num_classes, 35 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 36 | depths=config.MODEL.SWIN.DEPTHS, 37 | num_heads=config.MODEL.SWIN.NUM_HEADS, 38 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 39 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 40 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 41 | qk_scale=config.MODEL.SWIN.QK_SCALE, 42 | drop_rate=config.MODEL.DROP_RATE, 43 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 44 | ape=config.MODEL.SWIN.APE, 45 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 46 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 47 | 48 | def forward(self, x): 49 | if x.size()[1] == 1: 50 | x = x.repeat(1,3,1,1) 51 | logits = self.swin_unet(x) 52 | return logits 53 | 54 | def load_from(self, config): 55 | pretrained_path = config.MODEL.PRETRAIN_CKPT 56 | if pretrained_path is not None: 57 | print("pretrained_path:{}".format(pretrained_path)) 58 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 59 | pretrained_dict = torch.load(pretrained_path, map_location=device) 60 | if "model" not in pretrained_dict: 61 | print("---start load pretrained modle by splitting---") 62 | pretrained_dict = {k[17:]:v for k,v in pretrained_dict.items()} 63 | for k in list(pretrained_dict.keys()): 64 | if "output" in k: 65 | print("delete key:{}".format(k)) 66 | del pretrained_dict[k] 67 | msg = self.swin_unet.load_state_dict(pretrained_dict,strict=False) 68 | # print(msg) 69 | return 70 | pretrained_dict = pretrained_dict['model'] 71 | print("---start load pretrained modle of swin encoder---") 72 | 73 | model_dict = self.swin_unet.state_dict() 74 | full_dict = copy.deepcopy(pretrained_dict) 75 | for k, v in pretrained_dict.items(): 76 | if "layers." in k: 77 | current_layer_num = 3-int(k[7:8]) 78 | current_k = "layers_up." 
+ str(current_layer_num) + k[8:] 79 | full_dict.update({current_k:v}) 80 | for k in list(full_dict.keys()): 81 | if k in model_dict: 82 | if full_dict[k].shape != model_dict[k].shape: 83 | print("delete:{};shape pretrain:{};shape model:{}".format(k,v.shape,model_dict[k].shape)) 84 | del full_dict[k] 85 | 86 | msg = self.swin_unet.load_state_dict(full_dict, strict=False) 87 | # print(msg) 88 | else: 89 | print("none pretrain") 90 | -------------------------------------------------------------------------------- /code/networks/unet_3D_dv_semi.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is adapted from https://github.com/ozan-oktay/Attention-Gated-Networks 3 | """ 4 | 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | from networks.utils import UnetConv3, UnetUp3, UnetUp3_CT, UnetDsv3 9 | import torch.nn.functional as F 10 | from networks.networks_other import init_weights 11 | 12 | 13 | class unet_3D_dv_semi(nn.Module): 14 | 15 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True): 16 | super(unet_3D_dv_semi, self).__init__() 17 | self.is_deconv = is_deconv 18 | self.in_channels = in_channels 19 | self.is_batchnorm = is_batchnorm 20 | self.feature_scale = feature_scale 21 | 22 | filters = [64, 128, 256, 512, 1024] 23 | filters = [int(x / self.feature_scale) for x in filters] 24 | 25 | # downsampling 26 | self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm, kernel_size=( 27 | 3, 3, 3), padding_size=(1, 1, 1)) 28 | self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 29 | 30 | self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm, kernel_size=( 31 | 3, 3, 3), padding_size=(1, 1, 1)) 32 | self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 33 | 34 | self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm, kernel_size=( 35 | 3, 3, 3), padding_size=(1, 1, 1)) 36 | self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 37 | 38 | self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm, kernel_size=( 39 | 3, 3, 3), padding_size=(1, 1, 1)) 40 | self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 41 | 42 | self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm, kernel_size=( 43 | 3, 3, 3), padding_size=(1, 1, 1)) 44 | 45 | # upsampling 46 | self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm) 47 | self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm) 48 | self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm) 49 | self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm) 50 | 51 | # deep supervision 52 | self.dsv4 = UnetDsv3( 53 | in_size=filters[3], out_size=n_classes, scale_factor=8) 54 | self.dsv3 = UnetDsv3( 55 | in_size=filters[2], out_size=n_classes, scale_factor=4) 56 | self.dsv2 = UnetDsv3( 57 | in_size=filters[1], out_size=n_classes, scale_factor=2) 58 | self.dsv1 = nn.Conv3d( 59 | in_channels=filters[0], out_channels=n_classes, kernel_size=1) 60 | 61 | self.dropout1 = nn.Dropout3d(p=0.5) 62 | self.dropout2 = nn.Dropout3d(p=0.3) 63 | self.dropout3 = nn.Dropout3d(p=0.2) 64 | self.dropout4 = nn.Dropout3d(p=0.1) 65 | 66 | # initialise weights 67 | for m in self.modules(): 68 | if isinstance(m, nn.Conv3d): 69 | init_weights(m, init_type='kaiming') 70 | elif isinstance(m, nn.BatchNorm3d): 71 | init_weights(m, init_type='kaiming') 72 | 73 | def forward(self, inputs): 74 | conv1 = self.conv1(inputs) 75 | maxpool1 = self.maxpool1(conv1) 76 | 77 | conv2 = self.conv2(maxpool1) 78 
| maxpool2 = self.maxpool2(conv2) 79 | 80 | conv3 = self.conv3(maxpool2) 81 | maxpool3 = self.maxpool3(conv3) 82 | 83 | conv4 = self.conv4(maxpool3) 84 | maxpool4 = self.maxpool4(conv4) 85 | 86 | center = self.center(maxpool4) 87 | 88 | up4 = self.up_concat4(conv4, center) 89 | up4 = self.dropout1(up4) 90 | 91 | up3 = self.up_concat3(conv3, up4) 92 | up3 = self.dropout2(up3) 93 | 94 | up2 = self.up_concat2(conv2, up3) 95 | up2 = self.dropout3(up2) 96 | 97 | up1 = self.up_concat1(conv1, up2) 98 | up1 = self.dropout4(up1) 99 | 100 | # Deep Supervision 101 | dsv4 = self.dsv4(up4) 102 | dsv3 = self.dsv3(up3) 103 | dsv2 = self.dsv2(up2) 104 | dsv1 = self.dsv1(up1) 105 | 106 | return dsv1, dsv2, dsv3, dsv4 107 | 108 | @staticmethod 109 | def apply_argmax_softmax(pred): 110 | log_p = F.softmax(pred, dim=1) 111 | 112 | return log_p 113 | -------------------------------------------------------------------------------- /code/networks/pnet.py: -------------------------------------------------------------------------------- 1 | 2 | # -*- coding: utf-8 -*- 3 | """ 4 | An PyTorch implementation of the DeepIGeoS paper: 5 | Wang, Guotai and Zuluaga, Maria A and Li, Wenqi and Pratt, Rosalind and Patel, Premal A and Aertsen, Michael and Doel, Tom and David, Anna L and Deprest, Jan and Ourselin, S{\'e}bastien and others: 6 | DeepIGeoS: a deep interactive geodesic framework for medical image segmentation. 7 | TPAMI (7) 2018: 1559--1572 8 | Note that there are some modifications from the original paper, such as 9 | the use of leaky relu here. 10 | """ 11 | from __future__ import division, print_function 12 | 13 | import torch 14 | import torch.nn as nn 15 | 16 | 17 | class PNetBlock(nn.Module): 18 | def __init__(self, in_channels, out_channels, dilation, padding): 19 | super(PNetBlock, self).__init__() 20 | 21 | self.in_chns = in_channels 22 | self.out_chns = out_channels 23 | self.dilation = dilation 24 | self.padding = padding 25 | 26 | self.conv1 = nn.Conv2d(self.in_chns, self.out_chns, kernel_size=3, 27 | padding=self.padding, dilation=self.dilation, groups=1, bias=True) 28 | self.conv2 = nn.Conv2d(self.out_chns, self.out_chns, kernel_size=3, 29 | padding=self.padding, dilation=self.dilation, groups=1, bias=True) 30 | self.in1 = nn.BatchNorm2d(self.out_chns) 31 | self.in2 = nn.BatchNorm2d(self.out_chns) 32 | self.ac1 = nn.LeakyReLU() 33 | self.ac2 = nn.LeakyReLU() 34 | 35 | def forward(self, x): 36 | x = self.conv1(x) 37 | x = self.in1(x) 38 | x = self.ac1(x) 39 | x = self.conv2(x) 40 | x = self.in2(x) 41 | x = self.ac2(x) 42 | return x 43 | 44 | 45 | class ConcatBlock(nn.Module): 46 | def __init__(self, in_channels, out_channels): 47 | super(ConcatBlock, self).__init__() 48 | self.in_chns = in_channels 49 | self.out_chns = out_channels 50 | self.conv1 = nn.Conv2d( 51 | self.in_chns, self.in_chns, kernel_size=1, padding=0) 52 | self.conv2 = nn.Conv2d( 53 | self.in_chns, self.out_chns, kernel_size=1, padding=0) 54 | self.ac1 = nn.LeakyReLU() 55 | self.ac2 = nn.LeakyReLU() 56 | 57 | def forward(self, x): 58 | x = self.conv1(x) 59 | x = self.ac1(x) 60 | x = self.conv2(x) 61 | x = self.ac2(x) 62 | return x 63 | 64 | 65 | class OutPutBlock(nn.Module): 66 | def __init__(self, in_channels, out_channels): 67 | super(OutPutBlock, self).__init__() 68 | self.in_chns = in_channels 69 | self.out_chns = out_channels 70 | self.conv1 = nn.Conv2d( 71 | self.in_chns, self.in_chns // 2, kernel_size=1, padding=0) 72 | self.conv2 = nn.Conv2d( 73 | self.in_chns // 2, self.out_chns, kernel_size=1, padding=0) 74 | self.drop1 = 
nn.Dropout2d(0.3) 75 | self.drop2 = nn.Dropout2d(0.3) 76 | self.ac1 = nn.LeakyReLU() 77 | 78 | def forward(self, x): 79 | x = self.drop1(x) 80 | x = self.conv1(x) 81 | x = self.ac1(x) 82 | x = self.drop2(x) 83 | x = self.conv2(x) 84 | return x 85 | 86 | 87 | class PNet2D(nn.Module): 88 | def __init__(self, in_chns, out_chns, num_filters, ratios): 89 | super(PNet2D, self).__init__() 90 | 91 | self.in_chns = in_chns 92 | self.out_chns = out_chns 93 | self.ratios = ratios 94 | self.num_filters = num_filters 95 | 96 | self.block1 = PNetBlock( 97 | self.in_chns, self.num_filters, self.ratios[0], padding=self.ratios[0]) 98 | 99 | self.block2 = PNetBlock( 100 | self.num_filters, self.num_filters, self.ratios[1], padding=self.ratios[1]) 101 | 102 | self.block3 = PNetBlock( 103 | self.num_filters, self.num_filters, self.ratios[2], padding=self.ratios[2]) 104 | 105 | self.block4 = PNetBlock( 106 | self.num_filters, self.num_filters, self.ratios[3], padding=self.ratios[3]) 107 | 108 | self.block5 = PNetBlock( 109 | self.num_filters, self.num_filters, self.ratios[4], padding=self.ratios[4]) 110 | self.catblock = ConcatBlock(self.num_filters * 5, self.num_filters * 2) 111 | self.out = OutPutBlock(self.num_filters * 2, self.out_chns) 112 | 113 | def forward(self, x): 114 | x1 = self.block1(x) 115 | x2 = self.block2(x1) 116 | x3 = self.block3(x2) 117 | x4 = self.block4(x3) 118 | x5 = self.block5(x4) 119 | conx = torch.cat([x1, x2, x3, x4, x5], dim=1) 120 | conx = self.catblock(conx) 121 | out = self.out(conx) 122 | return out 123 | -------------------------------------------------------------------------------- /code/networks/net_factory.py: -------------------------------------------------------------------------------- 1 | from networks.efficientunet import Effi_UNet 2 | from networks.enet import ENet 3 | from networks.pnet import PNet2D 4 | from networks.unet import UNet, UNet_DS, UNet_URPC, UNet_CCT 5 | import argparse 6 | from networks.vision_transformer import SwinUnet as ViT_seg 7 | from networks.config import get_config 8 | from networks.nnunet import initialize_network 9 | 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--root_path', type=str, 13 | default='../data/ACDC', help='Name of Experiment') 14 | parser.add_argument('--exp', type=str, 15 | default='ACDC/Cross_Supervision_CNN_Trans2D', help='experiment_name') 16 | parser.add_argument('--model', type=str, 17 | default='unet', help='model_name') 18 | parser.add_argument('--max_iterations', type=int, 19 | default=30000, help='maximum epoch number to train') 20 | parser.add_argument('--batch_size', type=int, default=8, 21 | help='batch_size per gpu') 22 | parser.add_argument('--deterministic', type=int, default=1, 23 | help='whether use deterministic training') 24 | parser.add_argument('--base_lr', type=float, default=0.01, 25 | help='segmentation network learning rate') 26 | parser.add_argument('--patch_size', type=list, default=[224, 224], 27 | help='patch size of network input') 28 | parser.add_argument('--seed', type=int, default=1337, help='random seed') 29 | parser.add_argument('--num_classes', type=int, default=4, 30 | help='output channel of network') 31 | parser.add_argument( 32 | '--cfg', type=str, default="../code/configs/swin_tiny_patch4_window7_224_lite.yaml", help='path to config file', ) 33 | parser.add_argument( 34 | "--opts", 35 | help="Modify config options by adding 'KEY VALUE' pairs. 
", 36 | default=None, 37 | nargs='+', 38 | ) 39 | parser.add_argument('--zip', action='store_true', 40 | help='use zipped dataset instead of folder dataset') 41 | parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'], 42 | help='no: no cache, ' 43 | 'full: cache all data, ' 44 | 'part: sharding the dataset into nonoverlapping pieces and only cache one piece') 45 | parser.add_argument('--resume', help='resume from checkpoint') 46 | parser.add_argument('--accumulation-steps', type=int, 47 | help="gradient accumulation steps") 48 | parser.add_argument('--use-checkpoint', action='store_true', 49 | help="whether to use gradient checkpointing to save memory") 50 | parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'], 51 | help='mixed precision opt level, if O0, no amp is used') 52 | parser.add_argument('--tag', help='tag of experiment') 53 | parser.add_argument('--eval', action='store_true', 54 | help='Perform evaluation only') 55 | parser.add_argument('--throughput', action='store_true', 56 | help='Test throughput only') 57 | 58 | # label and unlabel 59 | parser.add_argument('--labeled_bs', type=int, default=4, 60 | help='labeled_batch_size per gpu') 61 | parser.add_argument('--labeled_num', type=int, default=7, 62 | help='labeled data') 63 | # costs 64 | parser.add_argument('--ema_decay', type=float, default=0.99, help='ema_decay') 65 | parser.add_argument('--consistency_type', type=str, 66 | default="mse", help='consistency_type') 67 | parser.add_argument('--consistency', type=float, 68 | default=0.1, help='consistency') 69 | parser.add_argument('--consistency_rampup', type=float, 70 | default=200.0, help='consistency_rampup') 71 | args = parser.parse_args() 72 | config = get_config(args) 73 | 74 | 75 | def net_factory(net_type="unet", in_chns=1, class_num=3): 76 | if net_type == "unet": 77 | net = UNet(in_chns=in_chns, class_num=class_num).cuda() 78 | elif net_type == "enet": 79 | net = ENet(in_channels=in_chns, num_classes=class_num).cuda() 80 | elif net_type == "unet_ds": 81 | net = UNet_DS(in_chns=in_chns, class_num=class_num).cuda() 82 | elif net_type == "unet_cct": 83 | net = UNet_CCT(in_chns=in_chns, class_num=class_num).cuda() 84 | elif net_type == "unet_urpc": 85 | net = UNet_URPC(in_chns=in_chns, class_num=class_num).cuda() 86 | elif net_type == "efficient_unet": 87 | net = Effi_UNet('efficientnet-b3', encoder_weights='imagenet', 88 | in_channels=in_chns, classes=class_num).cuda() 89 | elif net_type == "ViT_Seg": 90 | net = ViT_seg(config, img_size=args.patch_size, 91 | num_classes=args.num_classes).cuda() 92 | elif net_type == "pnet": 93 | net = PNet2D(in_chns, class_num, 64, [1, 2, 4, 8, 16]).cuda() 94 | elif net_type == "nnUNet": 95 | net = initialize_network(num_classes=class_num).cuda() 96 | else: 97 | net = None 98 | return net 99 | -------------------------------------------------------------------------------- /code/networks/net_factory_args.py: -------------------------------------------------------------------------------- 1 | from networks.efficientunet import Effi_UNet 2 | from networks.enet import ENet 3 | from networks.pnet import PNet2D 4 | from networks.unetWithArgs import UNet, UNet_DS, UNet_URPC, UNet_CCT 5 | # from networks.unetWithArgsSTEGO import UNet, UNet_DS, UNet_URPC, UNet_CCT 6 | import argparse 7 | from networks.vision_transformer import SwinUnet as ViT_seg 8 | from networks.config import get_config 9 | from networks.nnunet import initialize_network 10 | 11 | 12 | # parser = 
argparse.ArgumentParser() 13 | # parser.add_argument('--root_path', type=str, 14 | # default='../data/ACDC', help='Name of Experiment') 15 | # parser.add_argument('--exp', type=str, 16 | # default='ACDC/Cross_Supervision_CNN_Trans2D', help='experiment_name') 17 | # parser.add_argument('--model', type=str, 18 | # default='unet', help='model_name') 19 | # parser.add_argument('--max_iterations', type=int, 20 | # default=30000, help='maximum epoch number to train') 21 | # parser.add_argument('--batch_size', type=int, default=8, 22 | # help='batch_size per gpu') 23 | # parser.add_argument('--deterministic', type=int, default=1, 24 | # help='whether use deterministic training') 25 | # parser.add_argument('--base_lr', type=float, default=0.01, 26 | # help='segmentation network learning rate') 27 | # parser.add_argument('--patch_size', type=list, default=[224, 224], 28 | # help='patch size of network input') 29 | # parser.add_argument('--seed', type=int, default=1337, help='random seed') 30 | # parser.add_argument('--num_classes', type=int, default=4, 31 | # help='output channel of network') 32 | # parser.add_argument( 33 | # '--cfg', type=str, default="../code/configs/swin_tiny_patch4_window7_224_lite.yaml", help='path to config file', ) 34 | # parser.add_argument( 35 | # "--opts", 36 | # help="Modify config options by adding 'KEY VALUE' pairs. ", 37 | # default=None, 38 | # nargs='+', 39 | # ) 40 | # parser.add_argument('--zip', action='store_true', 41 | # help='use zipped dataset instead of folder dataset') 42 | # parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'], 43 | # help='no: no cache, ' 44 | # 'full: cache all data, ' 45 | # 'part: sharding the dataset into nonoverlapping pieces and only cache one piece') 46 | # parser.add_argument('--resume', help='resume from checkpoint') 47 | # parser.add_argument('--accumulation-steps', type=int, 48 | # help="gradient accumulation steps") 49 | # parser.add_argument('--use-checkpoint', action='store_true', 50 | # help="whether to use gradient checkpointing to save memory") 51 | # parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'], 52 | # help='mixed precision opt level, if O0, no amp is used') 53 | # parser.add_argument('--tag', help='tag of experiment') 54 | # parser.add_argument('--eval', action='store_true', 55 | # help='Perform evaluation only') 56 | # parser.add_argument('--throughput', action='store_true', 57 | # help='Test throughput only') 58 | 59 | # # label and unlabel 60 | # parser.add_argument('--labeled_bs', type=int, default=4, 61 | # help='labeled_batch_size per gpu') 62 | # parser.add_argument('--labeled_num', type=int, default=7, 63 | # help='labeled data') 64 | # # costs 65 | # parser.add_argument('--ema_decay', type=float, default=0.99, help='ema_decay') 66 | # parser.add_argument('--consistency_type', type=str, 67 | # default="mse", help='consistency_type') 68 | # parser.add_argument('--consistency', type=float, 69 | # default=0.1, help='consistency') 70 | # parser.add_argument('--consistency_rampup', type=float, 71 | # default=200.0, help='consistency_rampup') 72 | # args = parser.parse_args() 73 | # config = get_config(args) 74 | 75 | 76 | def net_factory(net_type="unet", in_chns=1, class_num=3, train_encoder=True, train_decoder=True, unfreeze_seg=True): 77 | if net_type == "unet": 78 | net = UNet(in_chns=in_chns, class_num=class_num, \ 79 | train_encoder=train_encoder, train_decoder=train_decoder, unfreeze_seg=unfreeze_seg).cuda() 80 | elif net_type 
== "enet": 81 | net = ENet(in_channels=in_chns, num_classes=class_num).cuda() 82 | elif net_type == "unet_ds": 83 | net = UNet_DS(in_chns=in_chns, class_num=class_num).cuda() 84 | elif net_type == "unet_cct": 85 | net = UNet_CCT(in_chns=in_chns, class_num=class_num).cuda() 86 | elif net_type == "unet_urpc": 87 | net = UNet_URPC(in_chns=in_chns, class_num=class_num).cuda() 88 | elif net_type == "efficient_unet": 89 | net = Effi_UNet('efficientnet-b3', encoder_weights='imagenet', 90 | in_channels=in_chns, classes=class_num).cuda() 91 | # elif net_type == "ViT_Seg": 92 | # net = ViT_seg(config, img_size=args.patch_size, 93 | # num_classes=args.num_classes).cuda() 94 | elif net_type == "pnet": 95 | net = PNet2D(in_chns, class_num, 64, [1, 2, 4, 8, 16]).cuda() 96 | elif net_type == "nnUNet": 97 | net = initialize_network(num_classes=class_num).cuda() 98 | else: 99 | net = None 100 | return net 101 | -------------------------------------------------------------------------------- /code/networks/net_factory_args_HAR.py: -------------------------------------------------------------------------------- 1 | from networks.efficientunet import Effi_UNet 2 | from networks.enet import ENet 3 | from networks.pnet import PNet2D 4 | from networks.unetWithArgs_HAR import UNet, UNet_DS, UNet_URPC, UNet_CCT 5 | # from networks.unetWithArgsSTEGO import UNet, UNet_DS, UNet_URPC, UNet_CCT 6 | import argparse 7 | from networks.vision_transformer import SwinUnet as ViT_seg 8 | from networks.config import get_config 9 | from networks.nnunet import initialize_network 10 | 11 | 12 | # parser = argparse.ArgumentParser() 13 | # parser.add_argument('--root_path', type=str, 14 | # default='../data/ACDC', help='Name of Experiment') 15 | # parser.add_argument('--exp', type=str, 16 | # default='ACDC/Cross_Supervision_CNN_Trans2D', help='experiment_name') 17 | # parser.add_argument('--model', type=str, 18 | # default='unet', help='model_name') 19 | # parser.add_argument('--max_iterations', type=int, 20 | # default=30000, help='maximum epoch number to train') 21 | # parser.add_argument('--batch_size', type=int, default=8, 22 | # help='batch_size per gpu') 23 | # parser.add_argument('--deterministic', type=int, default=1, 24 | # help='whether use deterministic training') 25 | # parser.add_argument('--base_lr', type=float, default=0.01, 26 | # help='segmentation network learning rate') 27 | # parser.add_argument('--patch_size', type=list, default=[224, 224], 28 | # help='patch size of network input') 29 | # parser.add_argument('--seed', type=int, default=1337, help='random seed') 30 | # parser.add_argument('--num_classes', type=int, default=4, 31 | # help='output channel of network') 32 | # parser.add_argument( 33 | # '--cfg', type=str, default="../code/configs/swin_tiny_patch4_window7_224_lite.yaml", help='path to config file', ) 34 | # parser.add_argument( 35 | # "--opts", 36 | # help="Modify config options by adding 'KEY VALUE' pairs. 
", 37 | # default=None, 38 | # nargs='+', 39 | # ) 40 | # parser.add_argument('--zip', action='store_true', 41 | # help='use zipped dataset instead of folder dataset') 42 | # parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'], 43 | # help='no: no cache, ' 44 | # 'full: cache all data, ' 45 | # 'part: sharding the dataset into nonoverlapping pieces and only cache one piece') 46 | # parser.add_argument('--resume', help='resume from checkpoint') 47 | # parser.add_argument('--accumulation-steps', type=int, 48 | # help="gradient accumulation steps") 49 | # parser.add_argument('--use-checkpoint', action='store_true', 50 | # help="whether to use gradient checkpointing to save memory") 51 | # parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'], 52 | # help='mixed precision opt level, if O0, no amp is used') 53 | # parser.add_argument('--tag', help='tag of experiment') 54 | # parser.add_argument('--eval', action='store_true', 55 | # help='Perform evaluation only') 56 | # parser.add_argument('--throughput', action='store_true', 57 | # help='Test throughput only') 58 | 59 | # # label and unlabel 60 | # parser.add_argument('--labeled_bs', type=int, default=4, 61 | # help='labeled_batch_size per gpu') 62 | # parser.add_argument('--labeled_num', type=int, default=7, 63 | # help='labeled data') 64 | # # costs 65 | # parser.add_argument('--ema_decay', type=float, default=0.99, help='ema_decay') 66 | # parser.add_argument('--consistency_type', type=str, 67 | # default="mse", help='consistency_type') 68 | # parser.add_argument('--consistency', type=float, 69 | # default=0.1, help='consistency') 70 | # parser.add_argument('--consistency_rampup', type=float, 71 | # default=200.0, help='consistency_rampup') 72 | # args = parser.parse_args() 73 | # config = get_config(args) 74 | 75 | 76 | def net_factory(net_type="unet", in_chns=1, class_num=3, train_encoder=True, train_decoder=True, unfreeze_seg=True): 77 | if net_type == "unet": 78 | net = UNet(in_chns=in_chns, class_num=class_num, \ 79 | train_encoder=train_encoder, train_decoder=train_decoder, unfreeze_seg=unfreeze_seg).cuda() 80 | elif net_type == "enet": 81 | net = ENet(in_channels=in_chns, num_classes=class_num).cuda() 82 | elif net_type == "unet_ds": 83 | net = UNet_DS(in_chns=in_chns, class_num=class_num).cuda() 84 | elif net_type == "unet_cct": 85 | net = UNet_CCT(in_chns=in_chns, class_num=class_num).cuda() 86 | elif net_type == "unet_urpc": 87 | net = UNet_URPC(in_chns=in_chns, class_num=class_num).cuda() 88 | elif net_type == "efficient_unet": 89 | net = Effi_UNet('efficientnet-b3', encoder_weights='imagenet', 90 | in_channels=in_chns, classes=class_num).cuda() 91 | # elif net_type == "ViT_Seg": 92 | # net = ViT_seg(config, img_size=args.patch_size, 93 | # num_classes=args.num_classes).cuda() 94 | elif net_type == "pnet": 95 | net = PNet2D(in_chns, class_num, 64, [1, 2, 4, 8, 16]).cuda() 96 | elif net_type == "nnUNet": 97 | net = initialize_network(num_classes=class_num).cuda() 98 | else: 99 | net = None 100 | return net 101 | -------------------------------------------------------------------------------- /code/test_2D.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | 5 | import h5py 6 | import nibabel as nib 7 | import numpy as np 8 | import SimpleITK as sitk 9 | import torch 10 | from medpy import metric 11 | from scipy.ndimage import zoom 12 | from tqdm import tqdm 13 | 14 | 
from model import * 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--root_path', type=str, 18 | default='/data/data/ACDC', help='Name of Experiment') 19 | parser.add_argument('--exp', type=str, 20 | default='ACDC/training_pool', help='experiment_name') 21 | parser.add_argument('--model', type=str, 22 | default='unet', help='model_name') 23 | parser.add_argument('--num_classes', type=int, default=4, 24 | help='output channel of network') 25 | parser.add_argument('--labeled_num', type=int, default=7, 26 | help='labeled data') 27 | parser.add_argument('--K', type=int, default=36, help='the size of cache') 28 | parser.add_argument('--latent_pooling_size', type=int, default=1, help='the pooling size of latent vector') 29 | parser.add_argument('--latent_feature_size', type=int, default=512, help='the feature size of latent vectors') 30 | parser.add_argument('--output_pooling_size', type=int, default=8, help='the pooling size of output head') 31 | parser.add_argument('--epoch', type=int, 32 | default=30000, help='testing epoch') 33 | FLAGS = parser.parse_args() 34 | 35 | def calculate_metric_percase(pred, gt): 36 | pred[pred > 0] = 1 37 | gt[gt > 0] = 1 38 | if pred.sum() > 0 and gt.sum() > 0: 39 | 40 | dice = metric.binary.dc(pred, gt) 41 | jc = metric.binary.jc(pred, gt) 42 | asd = metric.binary.asd(pred, gt) 43 | hd95 = metric.binary.hd95(pred, gt) 44 | return dice, jc, hd95, asd 45 | elif pred.sum() > 0 and gt.sum() == 0: 46 | return 1, 1, 0, 0 47 | else: 48 | return 0, 0, 0, 0 49 | 50 | 51 | 52 | def test_single_volume(case, net, classes, test_save_path, FLAGS): 53 | h5f = h5py.File(FLAGS.root_path + "/data/{}.h5".format(case), 'r') 54 | image = h5f['image'][:] 55 | label = h5f['label'][:] 56 | prediction = np.zeros_like(label) 57 | for ind in range(image.shape[0]): 58 | slice = image[ind, :, :] 59 | x, y = slice.shape[0], slice.shape[1] 60 | slice = zoom(slice, (256 / x, 256 / y), order=0) 61 | input = torch.from_numpy(slice).unsqueeze( 62 | 0).unsqueeze(0).float().cuda() 63 | net.eval() 64 | with torch.no_grad(): 65 | if FLAGS.model == "unet_urds": 66 | out_main, _, _, _ = net(input) 67 | else: 68 | out_main = net(input)[0] # , torch.zeros_like(input) 69 | out = torch.argmax(torch.softmax( 70 | out_main, dim=1), dim=1).squeeze(0) 71 | out = out.cpu().detach().numpy() 72 | pred = zoom(out, (x / 256, y / 256), order=0) 73 | prediction[ind] = pred 74 | 75 | # first_metric = calculate_metric_percase(prediction == 1, label == 1) 76 | # second_metric = calculate_metric_percase(prediction == 2, label == 2) 77 | # third_metric = calculate_metric_percase(prediction == 3, label == 3) 78 | 79 | metric_list = [] 80 | for i in range(1, classes): 81 | metric_list.append(calculate_metric_percase(prediction == i, label == i)) 82 | img_itk = sitk.GetImageFromArray(image.astype(np.float32)) 83 | img_itk.SetSpacing((1, 1, 10)) 84 | prd_itk = sitk.GetImageFromArray(prediction.astype(np.float32)) 85 | prd_itk.SetSpacing((1, 1, 10)) 86 | lab_itk = sitk.GetImageFromArray(label.astype(np.float32)) 87 | lab_itk.SetSpacing((1, 1, 10)) 88 | sitk.WriteImage(prd_itk, test_save_path + case + "_pred.nii.gz") 89 | sitk.WriteImage(img_itk, test_save_path + case + "_img.nii.gz") 90 | sitk.WriteImage(lab_itk, test_save_path + case + "_gt.nii.gz") 91 | return metric_list 92 | 93 | 94 | def Inference(FLAGS): 95 | with open(FLAGS.root_path + '/test.list', 'r') as f: 96 | image_list = f.readlines() 97 | image_list = sorted([item.replace('\n', '').split(".")[0] 98 | for item in image_list]) 99 | 
test_save_path = "../model/{}_{}_labeledfinal/{}_predictions/".format(# 100 | FLAGS.exp, FLAGS.labeled_num, FLAGS.model) 101 | snapshot_path = "../model/{}_{}_labeledfinal/{}".format(# 102 | FLAGS.exp, FLAGS.labeled_num, FLAGS.model) 103 | if os.path.exists(test_save_path): 104 | shutil.rmtree(test_save_path) 105 | os.makedirs(test_save_path) 106 | net = create_model(ema=False, num_classes=FLAGS.num_classes, train_encoder=False, train_decoder=False) 107 | 108 | save_mode_path = os.path.join( 109 | snapshot_path, 'iter_{}.pth'.format(FLAGS.epoch)) 110 | net.load_state_dict(torch.load(save_mode_path, map_location=lambda storage, loc: storage)) 111 | print("init weight from {}".format(save_mode_path)) 112 | net.eval() 113 | 114 | metric_list = 0.0 115 | for case in tqdm(image_list): 116 | metric_i = test_single_volume( 117 | case, net, FLAGS.num_classes, test_save_path, FLAGS) 118 | metric_list += np.array(metric_i) 119 | avg_metric = metric_list / len(image_list) 120 | return avg_metric 121 | 122 | 123 | if __name__ == '__main__': 124 | metric = Inference(FLAGS) 125 | print(metric) 126 | cur = None 127 | for i in metric: 128 | try: 129 | if cur == None: 130 | cur = i 131 | else: 132 | cur += i 133 | except: 134 | if cur.all() == None: 135 | cur = i 136 | else: 137 | cur += i 138 | print(cur/len(metric)) 139 | -------------------------------------------------------------------------------- /code/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | __all__ = ['InfoNCE', 'info_nce'] 6 | 7 | class InfoNCE(nn.Module): 8 | """ 9 | Calculates the InfoNCE loss for self-supervised learning. 10 | This contrastive loss enforces the embeddings of similar (positive) samples to be close 11 | and those of different (negative) samples to be distant. 12 | A query embedding is compared with one positive key and with one or more negative keys. 13 | References: 14 | https://arxiv.org/abs/1807.03748v2 15 | https://arxiv.org/abs/2010.05113 16 | Args: 17 | temperature: Logits are divided by temperature before calculating the cross entropy. 18 | reduction: Reduction method applied to the output. 19 | Value must be one of ['none', 'sum', 'mean']. 20 | See torch.nn.functional.cross_entropy for more details about each option. 21 | negative_mode: Determines how the (optional) negative_keys are handled. 22 | Value must be one of ['paired', 'unpaired']. 23 | If 'paired', then each query sample is paired with a number of negative keys. 24 | Comparable to a triplet loss, but with multiple negatives per sample. 25 | If 'unpaired', then the set of negative keys are all unrelated to any positive key. 26 | Input shape: 27 | query: (N, D) Tensor with query samples (e.g. embeddings of the input). 28 | positive_key: (N, D) Tensor with positive samples (e.g. embeddings of augmented input). 29 | negative_keys (optional): Tensor with negative samples (e.g. embeddings of other inputs) 30 | If negative_mode = 'paired', then negative_keys is a (N, M, D) Tensor. 31 | If negative_mode = 'unpaired', then negative_keys is a (M, D) Tensor. 32 | If None, then the negative keys for a sample are the positive keys for the other samples. 33 | Returns: 34 | Value of the InfoNCE Loss. 
35 |     Examples:
36 |         >>> loss = InfoNCE()
37 |         >>> batch_size, num_negative, embedding_size = 32, 48, 128
38 |         >>> query = torch.randn(batch_size, embedding_size)
39 |         >>> positive_key = torch.randn(batch_size, embedding_size)
40 |         >>> negative_keys = torch.randn(num_negative, embedding_size)
41 |         >>> output = loss(query, positive_key, negative_keys)
42 |     """
43 | 
44 |     def __init__(self, temperature=0.1, reduction='mean', negative_mode='unpaired'):
45 |         super().__init__()
46 |         self.temperature = temperature
47 |         self.reduction = reduction
48 |         self.negative_mode = negative_mode
49 | 
50 |     def forward(self, query, positive_key, negative_keys=None):
51 |         return info_nce(query, positive_key, negative_keys,
52 |                         temperature=self.temperature,
53 |                         reduction=self.reduction,
54 |                         negative_mode=self.negative_mode)
55 | 
56 | 
57 | def info_nce(query, positive_key, negative_keys=None, temperature=0.1, reduction='mean', negative_mode='unpaired'):
58 |     # Check input dimensionality.
59 |     if query.dim() != 2:
60 |         raise ValueError('<query> must have 2 dimensions.')
61 |     if positive_key.dim() != 2:
62 |         raise ValueError('<positive_key> must have 2 dimensions.')
63 |     if negative_keys is not None:
64 |         if negative_mode == 'unpaired' and negative_keys.dim() != 2:
65 |             raise ValueError("<negative_keys> must have 2 dimensions if <negative_mode> == 'unpaired'.")
66 |         if negative_mode == 'paired' and negative_keys.dim() != 3:
67 |             raise ValueError("<negative_keys> must have 3 dimensions if <negative_mode> == 'paired'.")
68 | 
69 |     # Check matching number of samples.
70 |     if len(query) != len(positive_key):
71 |         raise ValueError('<query> and <positive_key> must have the same number of samples.')
72 |     if negative_keys is not None:
73 |         if negative_mode == 'paired' and len(query) != len(negative_keys):
74 |             raise ValueError("If negative_mode == 'paired', then <negative_keys> must have the same number of samples as <query>.")
75 | 
76 |     # Embedding vectors should have same number of components.
77 |     if query.shape[-1] != positive_key.shape[-1]:
78 |         raise ValueError('Vectors of <query> and <positive_key> should have the same number of components.')
79 |     if negative_keys is not None:
80 |         if query.shape[-1] != negative_keys.shape[-1]:
81 |             raise ValueError('Vectors of <query> and <negative_keys> should have the same number of components.')
82 | 
83 |     # Normalize to unit vectors
84 |     query, positive_key, negative_keys = normalize(query, positive_key, negative_keys)
85 |     if negative_keys is not None:
86 |         # Explicit negative keys
87 | 
88 |         # Cosine between positive pairs
89 |         positive_logit = torch.sum(query * positive_key, dim=1, keepdim=True)
90 | 
91 |         if negative_mode == 'unpaired':
92 |             # Cosine between all query-negative combinations
93 |             negative_logits = query @ transpose(negative_keys)
94 | 
95 |         elif negative_mode == 'paired':
96 |             query = query.unsqueeze(1)
97 |             negative_logits = query @ transpose(negative_keys)
98 |             negative_logits = negative_logits.squeeze(1)
99 | 
100 |         # First index in last dimension are the positive samples
101 |         logits = torch.cat([positive_logit, negative_logits], dim=1)
102 |         labels = torch.zeros(len(logits), dtype=torch.long, device=query.device)
103 |     else:
104 |         # Negative keys are implicitly off-diagonal positive keys.
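        # With no explicit negatives this reduces to the usual "in-batch negatives"
        # InfoNCE: after the normalization above, the (N, N) matrix computed below
        # holds all pairwise cosine similarities, entry (i, i) is the positive pair
        # for sample i, and every other entry in row i acts as a negative, i.e.
        #   logits[i, j] = query[i] . positive_key[j]   and   labels = [0, ..., N-1]
        # before the temperature-scaled cross-entropy at the end of the function.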
105 | 106 | # Cosine between all combinations 107 | logits = query @ transpose(positive_key) 108 | 109 | # Positive keys are the entries on the diagonal 110 | labels = torch.arange(len(query), device=query.device) 111 | 112 | return F.cross_entropy(logits / temperature, labels, reduction=reduction) 113 | 114 | 115 | def transpose(x): 116 | return x.transpose(-2, -1) 117 | 118 | 119 | def normalize(*xs): 120 | return [None if x is None else F.normalize(x, dim=-1) for x in xs] -------------------------------------------------------------------------------- /code/networks/attention_unet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from networks.utils import UnetConv3, UnetUp3_CT, UnetGridGatingSignal3, UnetDsv3 4 | import torch.nn.functional as F 5 | from networks.networks_other import init_weights 6 | from networks.grid_attention_layer import GridAttentionBlock3D 7 | 8 | 9 | class Attention_UNet(nn.Module): 10 | 11 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, 12 | nonlocal_mode='concatenation', attention_dsample=(2,2,2), is_batchnorm=True): 13 | super(Attention_UNet, self).__init__() 14 | self.is_deconv = is_deconv 15 | self.in_channels = in_channels 16 | self.is_batchnorm = is_batchnorm 17 | self.feature_scale = feature_scale 18 | 19 | filters = [64, 128, 256, 512, 1024] 20 | filters = [int(x / self.feature_scale) for x in filters] 21 | 22 | # downsampling 23 | self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) 24 | self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 25 | 26 | self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) 27 | self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 28 | 29 | self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) 30 | self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 31 | 32 | self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) 33 | self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 34 | 35 | self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) 36 | self.gating = UnetGridGatingSignal3(filters[4], filters[4], kernel_size=(1, 1, 1), is_batchnorm=self.is_batchnorm) 37 | 38 | # attention blocks 39 | self.attentionblock2 = MultiAttentionBlock(in_size=filters[1], gate_size=filters[2], inter_size=filters[1], 40 | nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample) 41 | self.attentionblock3 = MultiAttentionBlock(in_size=filters[2], gate_size=filters[3], inter_size=filters[2], 42 | nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample) 43 | self.attentionblock4 = MultiAttentionBlock(in_size=filters[3], gate_size=filters[4], inter_size=filters[3], 44 | nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample) 45 | 46 | # upsampling 47 | self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm) 48 | self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm) 49 | self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm) 50 | self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm) 51 | 52 | # deep supervision 53 | self.dsv4 = UnetDsv3(in_size=filters[3], out_size=n_classes, scale_factor=8) 54 | self.dsv3 = UnetDsv3(in_size=filters[2], out_size=n_classes, scale_factor=4) 55 | self.dsv2 = 
UnetDsv3(in_size=filters[1], out_size=n_classes, scale_factor=2) 56 | self.dsv1 = nn.Conv3d(in_channels=filters[0], out_channels=n_classes, kernel_size=1) 57 | 58 | # final conv (without any concat) 59 | self.final = nn.Conv3d(n_classes*4, n_classes, 1) 60 | 61 | # initialise weights 62 | for m in self.modules(): 63 | if isinstance(m, nn.Conv3d): 64 | init_weights(m, init_type='kaiming') 65 | elif isinstance(m, nn.BatchNorm3d): 66 | init_weights(m, init_type='kaiming') 67 | 68 | def forward(self, inputs): 69 | # Feature Extraction 70 | conv1 = self.conv1(inputs) 71 | maxpool1 = self.maxpool1(conv1) 72 | 73 | conv2 = self.conv2(maxpool1) 74 | maxpool2 = self.maxpool2(conv2) 75 | 76 | conv3 = self.conv3(maxpool2) 77 | maxpool3 = self.maxpool3(conv3) 78 | 79 | conv4 = self.conv4(maxpool3) 80 | maxpool4 = self.maxpool4(conv4) 81 | 82 | # Gating Signal Generation 83 | center = self.center(maxpool4) 84 | gating = self.gating(center) 85 | 86 | # Attention Mechanism 87 | # Upscaling Part (Decoder) 88 | g_conv4, att4 = self.attentionblock4(conv4, gating) 89 | up4 = self.up_concat4(g_conv4, center) 90 | g_conv3, att3 = self.attentionblock3(conv3, up4) 91 | up3 = self.up_concat3(g_conv3, up4) 92 | g_conv2, att2 = self.attentionblock2(conv2, up3) 93 | up2 = self.up_concat2(g_conv2, up3) 94 | up1 = self.up_concat1(conv1, up2) 95 | 96 | # Deep Supervision 97 | dsv4 = self.dsv4(up4) 98 | dsv3 = self.dsv3(up3) 99 | dsv2 = self.dsv2(up2) 100 | dsv1 = self.dsv1(up1) 101 | final = self.final(torch.cat([dsv1,dsv2,dsv3,dsv4], dim=1)) 102 | 103 | return final 104 | 105 | 106 | @staticmethod 107 | def apply_argmax_softmax(pred): 108 | log_p = F.softmax(pred, dim=1) 109 | 110 | return log_p 111 | 112 | 113 | class MultiAttentionBlock(nn.Module): 114 | def __init__(self, in_size, gate_size, inter_size, nonlocal_mode, sub_sample_factor): 115 | super(MultiAttentionBlock, self).__init__() 116 | self.gate_block_1 = GridAttentionBlock3D(in_channels=in_size, gating_channels=gate_size, 117 | inter_channels=inter_size, mode=nonlocal_mode, 118 | sub_sample_factor= sub_sample_factor) 119 | self.gate_block_2 = GridAttentionBlock3D(in_channels=in_size, gating_channels=gate_size, 120 | inter_channels=inter_size, mode=nonlocal_mode, 121 | sub_sample_factor=sub_sample_factor) 122 | self.combine_gates = nn.Sequential(nn.Conv3d(in_size*2, in_size, kernel_size=1, stride=1, padding=0), 123 | nn.BatchNorm3d(in_size), 124 | nn.ReLU(inplace=True) 125 | ) 126 | 127 | # initialise the blocks 128 | for m in self.children(): 129 | if m.__class__.__name__.find('GridAttentionBlock3D') != -1: continue 130 | init_weights(m, init_type='kaiming') 131 | 132 | def forward(self, input, gating_signal): 133 | gate_1, attention_1 = self.gate_block_1(input, gating_signal) 134 | gate_2, attention_2 = self.gate_block_2(input, gating_signal) 135 | 136 | return self.combine_gates(torch.cat([gate_1, gate_2], 1)), torch.cat([attention_1, attention_2], 1) -------------------------------------------------------------------------------- /code/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import random 5 | import sys 6 | import numpy as np 7 | import torch 8 | import shutil 9 | import torch.backends.cudnn as cudnn 10 | import torch.nn as nn 11 | from torch.utils.data import DataLoader 12 | from networks.net_factory_args import net_factory 13 | from tqdm import tqdm 14 | from dataloaders.dataset_synapse import Synapse_dataset 15 | # from utils import 
test_single_volume 16 | import SimpleITK as sitk 17 | from scipy.ndimage import zoom 18 | from model import * 19 | from medpy import metric 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--root_path', type=str, 23 | default='/data/data/Lits/test_vol_h5', help='root dir for validation volume data') # for acdc volume_path=root_dir 24 | parser.add_argument('--exp', type=str, 25 | default='Lits/action', help='experiment_name') 26 | parser.add_argument('--num_classes', type=int, 27 | default=3, help='output channel of network') 28 | parser.add_argument('--model', type=str, 29 | default='unet', help='model_name') 30 | parser.add_argument('--list_dir', type=str, 31 | default='/data/data/Lits', help='list dir') 32 | parser.add_argument('--labeled_num', type=int, default=7, 33 | help='labeled data') 34 | parser.add_argument('--epoch', type=int, 35 | default=6000, help='testing epoch') 36 | parser.add_argument('--K', type=int, default=36, help='the size of cache') 37 | parser.add_argument('--latent_pooling_size', type=int, default=1, help='the pooling size of latent vector') 38 | parser.add_argument('--latent_feature_size', type=int, default=512, help='the feature size of latent vectors') 39 | parser.add_argument('--output_pooling_size', type=int, default=8, help='the pooling size of output head') 40 | args = parser.parse_args() 41 | 42 | 43 | def inference(args, model, test_save_path=None): 44 | db_test = Synapse_dataset(base_dir=args.root_path, split="test", list_dir=args.list_dir) 45 | testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1) 46 | logging.info("{} test iterations per epoch".format(len(testloader))) 47 | model.eval() 48 | metric_list = 0.0 49 | for i_batch, sampled_batch in tqdm(enumerate(testloader)): 50 | h, w = sampled_batch["image"].size()[2:] 51 | image, label, case_name = sampled_batch["image"], sampled_batch["label"], sampled_batch['case_name'][0] 52 | metric_i = test_single_volume(image, label, model, classes=args.num_classes, patch_size=[256, 256], 53 | test_save_path=test_save_path, case=case_name, z_spacing=10) 54 | metric_list += np.array(metric_i) 55 | logging.info('idx %d case %s mean_dice %f jc %f mean_hd95 %f asd %f ' % (i_batch, case_name, np.mean(metric_i, axis=0)[0], np.mean(metric_i, axis=0)[1], np.mean(metric_i, axis=0)[2], np.mean(metric_i, axis=0)[3])) 56 | metric_list = metric_list / len(db_test) 57 | for i in range(1, args.num_classes): 58 | logging.info('Mean class %d mean_dice %f jc %f mean_hd95 %f asd %f ' % (i, metric_list[i-1][0], metric_list[i-1][1], metric_list[i-1][2], metric_list[i-1][3])) 59 | performance = np.mean(metric_list, axis=0)[0] 60 | jc = np.mean(metric_list, axis=0)[1] 61 | mean_hd95 =np.mean(metric_list, axis=0)[2] 62 | asd = np.mean(metric_list, axis=0)[3] 63 | logging.info('Testing performance in best val model: mean_dice : %f jc %f mean_hd95 : %f asd %f' % (performance, jc, mean_hd95, asd)) 64 | return metric_list 65 | 66 | def test_single_volume(image, label, net, classes, patch_size=[256, 256], test_save_path=None, case=None, z_spacing=1): 67 | image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy() 68 | if len(image.shape) == 3: 69 | prediction = np.zeros_like(label) 70 | for ind in range(image.shape[0]): 71 | slice = image[ind, :, :] 72 | x, y = slice.shape[0], slice.shape[1] 73 | if x != patch_size[0] or y != patch_size[1]: 74 | slice = zoom(slice, (patch_size[0] / x, patch_size[1] / y), order=3) # previous using 0 75 | input = 
torch.from_numpy(slice).unsqueeze(0).unsqueeze(0).float().cuda() 76 | net.eval() 77 | with torch.no_grad(): 78 | outputs = net(input)[0] 79 | out = torch.argmax(torch.softmax(outputs, dim=1), dim=1).squeeze(0) 80 | out = out.cpu().detach().numpy() 81 | if x != patch_size[0] or y != patch_size[1]: 82 | pred = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0) 83 | else: 84 | pred = out 85 | prediction[ind] = pred 86 | else: 87 | input = torch.from_numpy(image).unsqueeze( 88 | 0).unsqueeze(0).float().cuda() 89 | net.eval() 90 | with torch.no_grad(): 91 | out = torch.argmax(torch.softmax(net(input), dim=1), dim=1).squeeze(0) 92 | prediction = out.cpu().detach().numpy() 93 | metric_list = [] 94 | for i in range(1, classes): 95 | metric_list.append(calculate_metric_percase(prediction == i, label == i)) 96 | 97 | if test_save_path is not None: 98 | img_itk = sitk.GetImageFromArray(image.astype(np.float32)) 99 | prd_itk = sitk.GetImageFromArray(prediction.astype(np.float32)) 100 | lab_itk = sitk.GetImageFromArray(label.astype(np.float32)) 101 | img_itk.SetSpacing((1, 1, z_spacing)) 102 | prd_itk.SetSpacing((1, 1, z_spacing)) 103 | lab_itk.SetSpacing((1, 1, z_spacing)) 104 | sitk.WriteImage(prd_itk, test_save_path + '/'+case + "_pred.nii.gz") 105 | sitk.WriteImage(img_itk, test_save_path + '/'+ case + "_img.nii.gz") 106 | sitk.WriteImage(lab_itk, test_save_path + '/'+ case + "_gt.nii.gz") 107 | return metric_list 108 | 109 | def calculate_metric_percase(pred, gt): 110 | pred[pred > 0] = 1 111 | gt[gt > 0] = 1 112 | if pred.sum() > 0 and gt.sum() > 0: 113 | 114 | dice = metric.binary.dc(pred, gt) 115 | jc = metric.binary.jc(pred, gt) 116 | asd = metric.binary.asd(pred, gt) 117 | hd95 = metric.binary.hd95(pred, gt) 118 | return dice, jc, hd95, asd 119 | elif pred.sum() > 0 and gt.sum() == 0: 120 | return 1, 1, 0, 0 121 | else: 122 | return 0, 0, 0, 0 123 | 124 | if __name__ == '__main__': 125 | FLAGS = parser.parse_args() 126 | test_save_path = "../model/{}_{}_labeledfinal/{}_predictions/".format(# 127 | FLAGS.exp, FLAGS.labeled_num, FLAGS.model) 128 | snapshot_path = "../model/{}_{}_labeledfinal/{}".format(# 129 | FLAGS.exp, FLAGS.labeled_num, FLAGS.model) 130 | if os.path.exists(test_save_path): 131 | shutil.rmtree(test_save_path) 132 | os.makedirs(test_save_path) 133 | net = create_model(ema=False, num_classes=FLAGS.num_classes, train_encoder=False, train_decoder=False) 134 | save_mode_path = os.path.join( 135 | snapshot_path, 'iter_{}.pth'.format(FLAGS.epoch)) 136 | if torch.cuda.device_count() > 1: 137 | net = torch.nn.DataParallel(net) 138 | net.load_state_dict(torch.load(save_mode_path)) 139 | print("init weight from {}".format(save_mode_path)) 140 | metric = inference(FLAGS, model=net, test_save_path=test_save_path) 141 | print(metric) 142 | print((metric[0]+metric[1])/2) 143 | -------------------------------------------------------------------------------- /code/dataloaders/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | import matplotlib.pyplot as plt 6 | from skimage import measure 7 | import scipy.ndimage as nd 8 | 9 | 10 | def recursive_glob(rootdir='.', suffix=''): 11 | """Performs recursive glob with given suffix and rootdir 12 | :param rootdir is the root directory 13 | :param suffix is the suffix to be searched 14 | """ 15 | return [os.path.join(looproot, filename) 16 | for looproot, _, filenames in os.walk(rootdir) 17 | for filename in filenames if 
filename.endswith(suffix)] 18 | 19 | def get_cityscapes_labels(): 20 | return np.array([ 21 | # [ 0, 0, 0], 22 | [128, 64, 128], 23 | [244, 35, 232], 24 | [70, 70, 70], 25 | [102, 102, 156], 26 | [190, 153, 153], 27 | [153, 153, 153], 28 | [250, 170, 30], 29 | [220, 220, 0], 30 | [107, 142, 35], 31 | [152, 251, 152], 32 | [0, 130, 180], 33 | [220, 20, 60], 34 | [255, 0, 0], 35 | [0, 0, 142], 36 | [0, 0, 70], 37 | [0, 60, 100], 38 | [0, 80, 100], 39 | [0, 0, 230], 40 | [119, 11, 32]]) 41 | 42 | def get_pascal_labels(): 43 | """Load the mapping that associates pascal classes with label colors 44 | Returns: 45 | np.ndarray with dimensions (21, 3) 46 | """ 47 | return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], 48 | [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], 49 | [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], 50 | [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], 51 | [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], 52 | [0, 64, 128]]) 53 | 54 | 55 | def encode_segmap(mask): 56 | """Encode segmentation label images as pascal classes 57 | Args: 58 | mask (np.ndarray): raw segmentation label image of dimension 59 | (M, N, 3), in which the Pascal classes are encoded as colours. 60 | Returns: 61 | (np.ndarray): class map with dimensions (M,N), where the value at 62 | a given location is the integer denoting the class index. 63 | """ 64 | mask = mask.astype(int) 65 | label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16) 66 | for ii, label in enumerate(get_pascal_labels()): 67 | label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii 68 | label_mask = label_mask.astype(int) 69 | return label_mask 70 | 71 | 72 | def decode_seg_map_sequence(label_masks, dataset='pascal'): 73 | rgb_masks = [] 74 | for label_mask in label_masks: 75 | rgb_mask = decode_segmap(label_mask, dataset) 76 | rgb_masks.append(rgb_mask) 77 | rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2])) 78 | return rgb_masks 79 | 80 | def decode_segmap(label_mask, dataset, plot=False): 81 | """Decode segmentation class labels into a color image 82 | Args: 83 | label_mask (np.ndarray): an (M,N) array of integer values denoting 84 | the class label at each spatial location. 85 | plot (bool, optional): whether to show the resulting color image 86 | in a figure. 87 | Returns: 88 | (np.ndarray, optional): the resulting decoded color image. 
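    Example (illustrative sketch with a hand-made mask):
        >>> mask = np.zeros((4, 4), dtype=int)
        >>> mask[1:3, 1:3] = 1                      # a block of pascal class 1
        >>> rgb = decode_segmap(mask, dataset='pascal')
        >>> rgb.shape                               # colours are scaled to [0, 1]
        (4, 4, 3)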
89 | """ 90 | if dataset == 'pascal': 91 | n_classes = 21 92 | label_colours = get_pascal_labels() 93 | elif dataset == 'cityscapes': 94 | n_classes = 19 95 | label_colours = get_cityscapes_labels() 96 | else: 97 | raise NotImplementedError 98 | 99 | r = label_mask.copy() 100 | g = label_mask.copy() 101 | b = label_mask.copy() 102 | for ll in range(0, n_classes): 103 | r[label_mask == ll] = label_colours[ll, 0] 104 | g[label_mask == ll] = label_colours[ll, 1] 105 | b[label_mask == ll] = label_colours[ll, 2] 106 | rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3)) 107 | rgb[:, :, 0] = r / 255.0 108 | rgb[:, :, 1] = g / 255.0 109 | rgb[:, :, 2] = b / 255.0 110 | if plot: 111 | plt.imshow(rgb) 112 | plt.show() 113 | else: 114 | return rgb 115 | 116 | def generate_param_report(logfile, param): 117 | log_file = open(logfile, 'w') 118 | # for key, val in param.items(): 119 | # log_file.write(key + ':' + str(val) + '\n') 120 | log_file.write(str(param)) 121 | log_file.close() 122 | 123 | def cross_entropy2d(logit, target, ignore_index=255, weight=None, size_average=True, batch_average=True): 124 | n, c, h, w = logit.size() 125 | # logit = logit.permute(0, 2, 3, 1) 126 | target = target.squeeze(1) 127 | if weight is None: 128 | criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index, size_average=False) 129 | else: 130 | criterion = nn.CrossEntropyLoss(weight=torch.from_numpy(np.array(weight)).float().cuda(), ignore_index=ignore_index, size_average=False) 131 | loss = criterion(logit, target.long()) 132 | 133 | if size_average: 134 | loss /= (h * w) 135 | 136 | if batch_average: 137 | loss /= n 138 | 139 | return loss 140 | 141 | def lr_poly(base_lr, iter_, max_iter=100, power=0.9): 142 | return base_lr * ((1 - float(iter_) / max_iter) ** power) 143 | 144 | 145 | def get_iou(pred, gt, n_classes=21): 146 | total_iou = 0.0 147 | for i in range(len(pred)): 148 | pred_tmp = pred[i] 149 | gt_tmp = gt[i] 150 | 151 | intersect = [0] * n_classes 152 | union = [0] * n_classes 153 | for j in range(n_classes): 154 | match = (pred_tmp == j) + (gt_tmp == j) 155 | 156 | it = torch.sum(match == 2).item() 157 | un = torch.sum(match > 0).item() 158 | 159 | intersect[j] += it 160 | union[j] += un 161 | 162 | iou = [] 163 | for k in range(n_classes): 164 | if union[k] == 0: 165 | continue 166 | iou.append(intersect[k] / union[k]) 167 | 168 | img_iou = (sum(iou) / len(iou)) 169 | total_iou += img_iou 170 | 171 | return total_iou 172 | 173 | def get_dice(pred, gt): 174 | total_dice = 0.0 175 | pred = pred.long() 176 | gt = gt.long() 177 | for i in range(len(pred)): 178 | pred_tmp = pred[i] 179 | gt_tmp = gt[i] 180 | dice = 2.0*torch.sum(pred_tmp*gt_tmp).item()/(1.0+torch.sum(pred_tmp**2)+torch.sum(gt_tmp**2)).item() 181 | print(dice) 182 | total_dice += dice 183 | 184 | return total_dice 185 | 186 | def get_mc_dice(pred, gt, num=2): 187 | # num is the total number of classes, include the background 188 | total_dice = np.zeros(num-1) 189 | pred = pred.long() 190 | gt = gt.long() 191 | for i in range(len(pred)): 192 | for j in range(1, num): 193 | pred_tmp = (pred[i]==j) 194 | gt_tmp = (gt[i]==j) 195 | dice = 2.0*torch.sum(pred_tmp*gt_tmp).item()/(1.0+torch.sum(pred_tmp**2)+torch.sum(gt_tmp**2)).item() 196 | total_dice[j-1] +=dice 197 | return total_dice 198 | 199 | def post_processing(prediction): 200 | prediction = nd.binary_fill_holes(prediction) 201 | label_cc, num_cc = measure.label(prediction,return_num=True) 202 | total_cc = np.sum(prediction) 203 | measure.regionprops(label_cc) 204 | 
for cc in range(1,num_cc+1): 205 | single_cc = (label_cc==cc) 206 | single_vol = np.sum(single_cc) 207 | if single_vol/total_cc<0.2: 208 | prediction[single_cc]=0 209 | 210 | return prediction -------------------------------------------------------------------------------- /code/networks/encoder_tool.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.utils.model_zoo as model_zoo 6 | from efficientnet_pytorch import EfficientNet 7 | from efficientnet_pytorch.utils import get_model_params, url_map 8 | 9 | 10 | class EncoderMixin: 11 | """Add encoder functionality such as: 12 | - output channels specification of feature tensors (produced by encoder) 13 | - patching first convolution for arbitrary input channels 14 | """ 15 | 16 | @property 17 | def out_channels(self) -> List: 18 | """Return channels dimensions for each tensor of forward output of encoder""" 19 | return self._out_channels[: self._depth + 1] 20 | 21 | def set_in_channels(self, in_channels): 22 | """Change first convolution chennels""" 23 | if in_channels == 3: 24 | return 25 | 26 | self._in_channels = in_channels 27 | if self._out_channels[0] == 3: 28 | self._out_channels = tuple([in_channels] + list(self._out_channels)[1:]) 29 | 30 | patch_first_conv(model=self, in_channels=in_channels) 31 | 32 | 33 | def patch_first_conv(model, in_channels): 34 | """Change first convolution layer input channels. 35 | In case: 36 | in_channels == 1 or in_channels == 2 -> reuse original weights 37 | in_channels > 3 -> make random kaiming normal initialization 38 | """ 39 | 40 | # get first conv 41 | for module in model.modules(): 42 | if isinstance(module, nn.Conv2d): 43 | break 44 | 45 | # change input channels for first conv 46 | module.in_channels = in_channels 47 | weight = module.weight.detach() 48 | reset = False 49 | 50 | if in_channels == 1: 51 | weight = weight.sum(1, keepdim=True) 52 | elif in_channels == 2: 53 | weight = weight[:, :2] * (3.0 / 2.0) 54 | else: 55 | reset = True 56 | weight = torch.Tensor( 57 | module.out_channels, 58 | module.in_channels // module.groups, 59 | *module.kernel_size 60 | ) 61 | 62 | module.weight = nn.parameter.Parameter(weight) 63 | if reset: 64 | module.reset_parameters() 65 | 66 | 67 | class EfficientNetEncoder(EfficientNet, EncoderMixin): 68 | def __init__(self, stage_idxs, out_channels, model_name, depth=5): 69 | 70 | blocks_args, global_params = get_model_params(model_name, override_params=None) 71 | super().__init__(blocks_args, global_params) 72 | 73 | self._stage_idxs = list(stage_idxs) + [len(self._blocks)] 74 | self._out_channels = out_channels 75 | self._depth = depth 76 | self._in_channels = 3 77 | 78 | del self._fc 79 | 80 | def forward(self, x): 81 | 82 | features = [x] 83 | 84 | if self._depth > 0: 85 | x = self._swish(self._bn0(self._conv_stem(x))) 86 | features.append(x) 87 | 88 | if self._depth > 1: 89 | skip_connection_idx = 0 90 | for idx, block in enumerate(self._blocks): 91 | drop_connect_rate = self._global_params.drop_connect_rate 92 | if drop_connect_rate: 93 | drop_connect_rate *= float(idx) / len(self._blocks) 94 | x = block(x, drop_connect_rate=drop_connect_rate) 95 | if idx == self._stage_idxs[skip_connection_idx] - 1: 96 | skip_connection_idx += 1 97 | features.append(x) 98 | if skip_connection_idx + 1 == self._depth: 99 | break 100 | return features 101 | 102 | def load_state_dict(self, state_dict, **kwargs): 103 | state_dict.pop("_fc.bias") 104 | 
state_dict.pop("_fc.weight") 105 | super().load_state_dict(state_dict, **kwargs) 106 | 107 | 108 | def _get_pretrained_settings(encoder): 109 | pretrained_settings = { 110 | "imagenet": { 111 | "mean": [0.485, 0.456, 0.406], 112 | "std": [0.229, 0.224, 0.225], 113 | "url": url_map[encoder], 114 | "input_space": "RGB", 115 | "input_range": [0, 1], 116 | } 117 | } 118 | return pretrained_settings 119 | 120 | 121 | efficient_net_encoders = { 122 | "efficientnet-b0": { 123 | "encoder": EfficientNetEncoder, 124 | "pretrained_settings": _get_pretrained_settings("efficientnet-b0"), 125 | "params": { 126 | "out_channels": (3, 32, 24, 40, 112, 320), 127 | "stage_idxs": (3, 5, 9), 128 | "model_name": "efficientnet-b0", 129 | }, 130 | }, 131 | "efficientnet-b1": { 132 | "encoder": EfficientNetEncoder, 133 | "pretrained_settings": _get_pretrained_settings("efficientnet-b1"), 134 | "params": { 135 | "out_channels": (3, 32, 24, 40, 112, 320), 136 | "stage_idxs": (5, 8, 16), 137 | "model_name": "efficientnet-b1", 138 | }, 139 | }, 140 | "efficientnet-b2": { 141 | "encoder": EfficientNetEncoder, 142 | "pretrained_settings": _get_pretrained_settings("efficientnet-b2"), 143 | "params": { 144 | "out_channels": (3, 32, 24, 48, 120, 352), 145 | "stage_idxs": (5, 8, 16), 146 | "model_name": "efficientnet-b2", 147 | }, 148 | }, 149 | "efficientnet-b3": { 150 | "encoder": EfficientNetEncoder, 151 | "pretrained_settings": _get_pretrained_settings("efficientnet-b3"), 152 | "params": { 153 | "out_channels": (3, 40, 32, 48, 136, 384), 154 | "stage_idxs": (5, 8, 18), 155 | "model_name": "efficientnet-b3", 156 | }, 157 | }, 158 | "efficientnet-b4": { 159 | "encoder": EfficientNetEncoder, 160 | "pretrained_settings": _get_pretrained_settings("efficientnet-b4"), 161 | "params": { 162 | "out_channels": (3, 48, 32, 56, 160, 448), 163 | "stage_idxs": (6, 10, 22), 164 | "model_name": "efficientnet-b4", 165 | }, 166 | }, 167 | "efficientnet-b5": { 168 | "encoder": EfficientNetEncoder, 169 | "pretrained_settings": _get_pretrained_settings("efficientnet-b5"), 170 | "params": { 171 | "out_channels": (3, 48, 40, 64, 176, 512), 172 | "stage_idxs": (8, 13, 27), 173 | "model_name": "efficientnet-b5", 174 | }, 175 | }, 176 | "efficientnet-b6": { 177 | "encoder": EfficientNetEncoder, 178 | "pretrained_settings": _get_pretrained_settings("efficientnet-b6"), 179 | "params": { 180 | "out_channels": (3, 56, 40, 72, 200, 576), 181 | "stage_idxs": (9, 15, 31), 182 | "model_name": "efficientnet-b6", 183 | }, 184 | }, 185 | "efficientnet-b7": { 186 | "encoder": EfficientNetEncoder, 187 | "pretrained_settings": _get_pretrained_settings("efficientnet-b7"), 188 | "params": { 189 | "out_channels": (3, 64, 48, 80, 224, 640), 190 | "stage_idxs": (11, 18, 38), 191 | "model_name": "efficientnet-b7", 192 | }, 193 | }, 194 | } 195 | 196 | encoders = {} 197 | encoders.update(efficient_net_encoders) 198 | 199 | 200 | def get_encoder(name, in_channels=3, depth=5, weights=None): 201 | Encoder = encoders[name]["encoder"] 202 | params = encoders[name]["params"] 203 | params.update(depth=depth) 204 | encoder = Encoder(**params) 205 | 206 | if weights is not None: 207 | settings = encoders[name]["pretrained_settings"][weights] 208 | encoder.load_state_dict(model_zoo.load_url(settings["url"])) 209 | 210 | encoder.set_in_channels(in_channels) 211 | 212 | return encoder 213 | -------------------------------------------------------------------------------- /code/build_dataset.py: -------------------------------------------------------------------------------- 1 | 
from torch.utils.data.dataset import Dataset 2 | from PIL import Image 3 | from PIL import ImageFilter 4 | import pandas as pd 5 | import numpy as np 6 | import torch 7 | import os 8 | import random 9 | import itertools 10 | import glob 11 | 12 | import torch.utils.data.sampler as sampler 13 | import torchvision.transforms as transforms 14 | import torchvision.transforms.functional as transforms_f 15 | from torch.utils.data.sampler import Sampler 16 | import h5py 17 | 18 | class BaseDataSetsWithIndex(Dataset): 19 | def __init__(self, base_dir=None, split='train', num=None, transform=None, index=16, label_type=0): 20 | self._base_dir = base_dir 21 | self.index = index 22 | self.sample_list = [] 23 | self.split = split 24 | self.transform = transform 25 | if self.split == 'train' and 'ACDC' in base_dir: 26 | with open(self._base_dir + '/train_slices.list', 'r') as f1: 27 | self.sample_list = f1.readlines() 28 | self.sample_list = [item.replace('\n', '') 29 | for item in self.sample_list] 30 | if(label_type==1): 31 | self.sample_list = self.sample_list[:index] 32 | else: 33 | self.sample_list = self.sample_list[index:] 34 | elif self.split == 'train' and 'MM' in base_dir: 35 | with open(self._base_dir + '/train_slices.txt', 'r') as f1: 36 | self.sample_list = f1.readlines() 37 | self.sample_list = [item.replace('.h5\n', '') 38 | for item in self.sample_list] 39 | if(label_type==1): 40 | self.sample_list = self.sample_list[:index] 41 | else: 42 | self.sample_list = self.sample_list[index:] 43 | 44 | elif self.split == 'val': 45 | with open(self._base_dir + '/val.list', 'r') as f: 46 | self.sample_list = f.readlines() 47 | self.sample_list = [item.replace('\n', '') 48 | for item in self.sample_list] 49 | if num is not None and self.split == "train": 50 | self.sample_list = self.sample_list[:num-index] 51 | print("total {} samples".format(len(self.sample_list))) 52 | 53 | def __len__(self): 54 | return len(self.sample_list) 55 | 56 | def __getitem__(self, idx): 57 | case = self.sample_list[idx] 58 | if self.split == "train": 59 | h5f = h5py.File(self._base_dir + 60 | "/data/slices/{}.h5".format(case), 'r') 61 | else: 62 | h5f = h5py.File(self._base_dir + "/data/{}.h5".format(case), 'r') 63 | image = h5f['image'][:] 64 | label = h5f['label'][:] 65 | sample = {'image': image, 'label': label} 66 | if self.split == "train" and self.transform!=None: 67 | sample = self.transform(sample) 68 | sample["idx"] = idx 69 | return sample 70 | 71 | # class TwoStreamBatchSampler(Sampler): 72 | # """Iterate two sets of indices 73 | # An 'epoch' is one iteration through the primary indices. 74 | # During the epoch, the secondary indices are iterated through 75 | # as many times as needed. 
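#     Concretely (illustrative): with batch_size=4 and secondary_batch_size=2,
#     every batch would pair 2 labeled (primary) indices with 2 unlabeled
#     (secondary) indices, and one epoch would yield
#     len(primary_indices) // (batch_size - secondary_batch_size) batches.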
76 | # """ 77 | 78 | # def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size): 79 | # self.primary_indices = primary_indices 80 | # self.secondary_indices = secondary_indices 81 | # self.secondary_batch_size = secondary_batch_size 82 | # self.primary_batch_size = batch_size - secondary_batch_size 83 | 84 | # assert len(self.primary_indices) >= self.primary_batch_size > 0 85 | # assert len(self.secondary_indices) >= self.secondary_batch_size > 0 86 | 87 | # def __iter__(self): 88 | # primary_iter = iterate_once(self.primary_indices) 89 | # secondary_iter = self.iterate_eternally(self.secondary_indices) 90 | # return ( 91 | # primary_batch + secondary_batch 92 | # for (primary_batch, secondary_batch) 93 | # in zip(grouper(primary_iter, self.primary_batch_size), 94 | # grouper(secondary_iter, self.secondary_batch_size)) 95 | # ) 96 | 97 | # def __len__(self): 98 | # return len(self.primary_indices) // self.primary_batch_size 99 | 100 | 101 | # def iterate_eternally(self,indices): 102 | # n = len(self.data_source) 103 | # # def infinite_shuffles(): 104 | # # while True: 105 | # # yield np.random.permutation(indices) 106 | # # return itertools.chain.from_iterable(infinite_shuffles()) 107 | # for _ in range(self.num_samples // 32): 108 | # yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=torch.Generator()).tolist() 109 | # yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=torch.Generator()).tolist() 110 | 111 | 112 | # def iterate_once(iterable): 113 | # # return np.random.permutation(iterable) # changes here 114 | # def infinite_shuffles(): 115 | # while True: 116 | # yield np.random.permutation(iterable) 117 | # return itertools.chain.from_iterable(infinite_shuffles()) 118 | 119 | 120 | 121 | # def grouper(iterable, n): 122 | # "Collect data into fixed-length chunks or blocks" 123 | # # grouper('ABCDEFG', 3) --> ABC DEF" 124 | # args = [iter(iterable)] * n 125 | # return zip(*args) 126 | 127 | class Synapse_dataset(Dataset): 128 | def __init__(self, base_dir, list_dir, split, transform=None): 129 | self.transform = transform # using transform in torch! 130 | self.split = split 131 | if (split == 'test' or split == 'val'): 132 | self.sample_list = open(os.path.join(list_dir, self.split+'_vol_40.txt')).readlines() 133 | else: 134 | self.sample_list = open(os.path.join(list_dir, self.split+'_40.txt')).readlines() 135 | self.data_dir = base_dir 136 | 137 | def __len__(self): 138 | return len(self.sample_list) 139 | 140 | def __getitem__(self, idx): 141 | if self.split == "train": 142 | slice_name = self.sample_list[idx].strip('\n') 143 | data_path = os.path.join(self.data_dir, slice_name+'.npz') 144 | # print(data_path) 145 | data = np.load(data_path) 146 | image, label = data['image'], data['label'] 147 | else: 148 | vol_name = self.sample_list[idx].strip('\n') 149 | filepath = self.data_dir + "/{}.npy.h5".format(vol_name) 150 | data = h5py.File(filepath) 151 | image, label = data['image'][:], data['label'][:] 152 | 153 | sample = {'image': image, 'label': label} 154 | if self.transform: 155 | sample = self.transform(sample) 156 | sample['case_name'] = self.sample_list[idx].strip('\n') 157 | return sample 158 | 159 | class Synapse_datasetWithIndex(Dataset): 160 | def __init__(self, base_dir, list_dir, split, transform=None, index=221, label_type=1): 161 | self.transform = transform # using transform in torch! 
162 | self.split = split 163 | self.sample_list = open(os.path.join(list_dir, self.split+'_40.txt')).readlines() 164 | self.data_dir = base_dir 165 | self.index = index 166 | self.label_type = label_type 167 | if(label_type==1): 168 | self.sample_list = self.sample_list[:index] 169 | else: 170 | self.sample_list = self.sample_list[index:] 171 | 172 | def __len__(self): 173 | return len(self.sample_list) 174 | 175 | def __getitem__(self, idx): 176 | if self.split == "train": 177 | slice_name = self.sample_list[idx].strip('\n') 178 | data_path = os.path.join(self.data_dir, slice_name+'.npz') 179 | # print(data_path) 180 | data = np.load(data_path) 181 | image, label = data['image'], data['label'] 182 | else: 183 | vol_name = self.sample_list[idx].strip('\n') 184 | filepath = self.data_dir + "/{}.npy.h5".format(vol_name) 185 | data = h5py.File(filepath) 186 | image, label = data['image'][:], data['label'][:] 187 | 188 | sample = {'image': image, 'label': label} 189 | if self.transform: 190 | sample = self.transform(sample) 191 | sample['case_name'] = self.sample_list[idx].strip('\n') 192 | return sample 193 | -------------------------------------------------------------------------------- /code/networks/config.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # --------------------------------------------------------' 7 | 8 | import os 9 | import yaml 10 | from yacs.config import CfgNode as CN 11 | 12 | _C = CN() 13 | 14 | # Base config files 15 | _C.BASE = [''] 16 | 17 | # ----------------------------------------------------------------------------- 18 | # Data settings 19 | # ----------------------------------------------------------------------------- 20 | _C.DATA = CN() 21 | # Batch size for a single GPU, could be overwritten by command line argument 22 | _C.DATA.BATCH_SIZE = 128 23 | # Path to dataset, could be overwritten by command line argument 24 | _C.DATA.DATA_PATH = '' 25 | # Dataset name 26 | _C.DATA.DATASET = 'imagenet' 27 | # Input image size 28 | _C.DATA.IMG_SIZE = 224 29 | # Interpolation to resize image (random, bilinear, bicubic) 30 | _C.DATA.INTERPOLATION = 'bicubic' 31 | # Use zipped dataset instead of folder dataset 32 | # could be overwritten by command line argument 33 | _C.DATA.ZIP_MODE = False 34 | # Cache Data in Memory, could be overwritten by command line argument 35 | _C.DATA.CACHE_MODE = 'part' 36 | # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU. 
37 | _C.DATA.PIN_MEMORY = True 38 | # Number of data loading threads 39 | _C.DATA.NUM_WORKERS = 8 40 | 41 | # ----------------------------------------------------------------------------- 42 | # Model settings 43 | # ----------------------------------------------------------------------------- 44 | _C.MODEL = CN() 45 | # Model type 46 | _C.MODEL.TYPE = 'swin' 47 | # Model name 48 | _C.MODEL.NAME = 'swin_tiny_patch4_window7_224' 49 | # Checkpoint to resume, could be overwritten by command line argument 50 | _C.MODEL.PRETRAIN_CKPT = './pretrained_ckpt/swin_tiny_patch4_window7_224.pth' 51 | _C.MODEL.RESUME = '' 52 | # Number of classes, overwritten in data preparation 53 | _C.MODEL.NUM_CLASSES = 1000 54 | # Dropout rate 55 | _C.MODEL.DROP_RATE = 0.0 56 | # Drop path rate 57 | _C.MODEL.DROP_PATH_RATE = 0.1 58 | # Label Smoothing 59 | _C.MODEL.LABEL_SMOOTHING = 0.1 60 | 61 | # Swin Transformer parameters 62 | _C.MODEL.SWIN = CN() 63 | _C.MODEL.SWIN.PATCH_SIZE = 4 64 | _C.MODEL.SWIN.IN_CHANS = 3 65 | _C.MODEL.SWIN.EMBED_DIM = 96 66 | _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2] 67 | _C.MODEL.SWIN.DECODER_DEPTHS = [2, 2, 6, 2] 68 | _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24] 69 | _C.MODEL.SWIN.WINDOW_SIZE = 7 70 | _C.MODEL.SWIN.MLP_RATIO = 4. 71 | _C.MODEL.SWIN.QKV_BIAS = True 72 | _C.MODEL.SWIN.QK_SCALE = None 73 | _C.MODEL.SWIN.APE = False 74 | _C.MODEL.SWIN.PATCH_NORM = True 75 | _C.MODEL.SWIN.FINAL_UPSAMPLE= "expand_first" 76 | 77 | # ----------------------------------------------------------------------------- 78 | # Training settings 79 | # ----------------------------------------------------------------------------- 80 | _C.TRAIN = CN() 81 | _C.TRAIN.START_EPOCH = 0 82 | _C.TRAIN.EPOCHS = 300 83 | _C.TRAIN.WARMUP_EPOCHS = 20 84 | _C.TRAIN.WEIGHT_DECAY = 0.05 85 | _C.TRAIN.BASE_LR = 5e-4 86 | _C.TRAIN.WARMUP_LR = 5e-7 87 | _C.TRAIN.MIN_LR = 5e-6 88 | # Clip gradient norm 89 | _C.TRAIN.CLIP_GRAD = 5.0 90 | # Auto resume from latest checkpoint 91 | _C.TRAIN.AUTO_RESUME = True 92 | # Gradient accumulation steps 93 | # could be overwritten by command line argument 94 | _C.TRAIN.ACCUMULATION_STEPS = 0 95 | # Whether to use gradient checkpointing to save memory 96 | # could be overwritten by command line argument 97 | _C.TRAIN.USE_CHECKPOINT = False 98 | 99 | # LR scheduler 100 | _C.TRAIN.LR_SCHEDULER = CN() 101 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine' 102 | # Epoch interval to decay LR, used in StepLRScheduler 103 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30 104 | # LR decay rate, used in StepLRScheduler 105 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 106 | 107 | # Optimizer 108 | _C.TRAIN.OPTIMIZER = CN() 109 | _C.TRAIN.OPTIMIZER.NAME = 'adamw' 110 | # Optimizer Epsilon 111 | _C.TRAIN.OPTIMIZER.EPS = 1e-8 112 | # Optimizer Betas 113 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999) 114 | # SGD momentum 115 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9 116 | 117 | # ----------------------------------------------------------------------------- 118 | # Augmentation settings 119 | # ----------------------------------------------------------------------------- 120 | _C.AUG = CN() 121 | # Color jitter factor 122 | _C.AUG.COLOR_JITTER = 0.4 123 | # Use AutoAugment policy. 
"v0" or "original" 124 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1' 125 | # Random erase prob 126 | _C.AUG.REPROB = 0.25 127 | # Random erase mode 128 | _C.AUG.REMODE = 'pixel' 129 | # Random erase count 130 | _C.AUG.RECOUNT = 1 131 | # Mixup alpha, mixup enabled if > 0 132 | _C.AUG.MIXUP = 0.8 133 | # Cutmix alpha, cutmix enabled if > 0 134 | _C.AUG.CUTMIX = 1.0 135 | # Cutmix min/max ratio, overrides alpha and enables cutmix if set 136 | _C.AUG.CUTMIX_MINMAX = None 137 | # Probability of performing mixup or cutmix when either/both is enabled 138 | _C.AUG.MIXUP_PROB = 1.0 139 | # Probability of switching to cutmix when both mixup and cutmix enabled 140 | _C.AUG.MIXUP_SWITCH_PROB = 0.5 141 | # How to apply mixup/cutmix params. Per "batch", "pair", or "elem" 142 | _C.AUG.MIXUP_MODE = 'batch' 143 | 144 | # ----------------------------------------------------------------------------- 145 | # Testing settings 146 | # ----------------------------------------------------------------------------- 147 | _C.TEST = CN() 148 | # Whether to use center crop when testing 149 | _C.TEST.CROP = True 150 | 151 | # ----------------------------------------------------------------------------- 152 | # Misc 153 | # ----------------------------------------------------------------------------- 154 | # Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2') 155 | # overwritten by command line argument 156 | _C.AMP_OPT_LEVEL = '' 157 | # Path to output folder, overwritten by command line argument 158 | _C.OUTPUT = '' 159 | # Tag of experiment, overwritten by command line argument 160 | _C.TAG = 'default' 161 | # Frequency to save checkpoint 162 | _C.SAVE_FREQ = 1 163 | # Frequency to logging info 164 | _C.PRINT_FREQ = 10 165 | # Fixed random seed 166 | _C.SEED = 0 167 | # Perform evaluation only, overwritten by command line argument 168 | _C.EVAL_MODE = False 169 | # Test throughput only, overwritten by command line argument 170 | _C.THROUGHPUT_MODE = False 171 | # local rank for DistributedDataParallel, given by command line argument 172 | _C.LOCAL_RANK = 0 173 | 174 | 175 | def _update_config_from_file(config, cfg_file): 176 | config.defrost() 177 | with open(cfg_file, 'r') as f: 178 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) 179 | 180 | for cfg in yaml_cfg.setdefault('BASE', ['']): 181 | if cfg: 182 | _update_config_from_file( 183 | config, os.path.join(os.path.dirname(cfg_file), cfg) 184 | ) 185 | print('=> merge config from {}'.format(cfg_file)) 186 | config.merge_from_file(cfg_file) 187 | config.freeze() 188 | 189 | 190 | def update_config(config, args): 191 | _update_config_from_file(config, args.cfg) 192 | 193 | config.defrost() 194 | if args.opts: 195 | config.merge_from_list(args.opts) 196 | 197 | # merge from specific arguments 198 | if args.batch_size: 199 | config.DATA.BATCH_SIZE = args.batch_size 200 | if args.zip: 201 | config.DATA.ZIP_MODE = True 202 | if args.cache_mode: 203 | config.DATA.CACHE_MODE = args.cache_mode 204 | if args.resume: 205 | config.MODEL.RESUME = args.resume 206 | if args.accumulation_steps: 207 | config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps 208 | if args.use_checkpoint: 209 | config.TRAIN.USE_CHECKPOINT = True 210 | if args.amp_opt_level: 211 | config.AMP_OPT_LEVEL = args.amp_opt_level 212 | if args.tag: 213 | config.TAG = args.tag 214 | if args.eval: 215 | config.EVAL_MODE = True 216 | if args.throughput: 217 | config.THROUGHPUT_MODE = True 218 | 219 | config.freeze() 220 | 221 | 222 | def get_config(args): 223 | """Get a yacs CfgNode object 
with default values.""" 224 | # Return a clone so that the defaults will not be altered 225 | # This is for the "local variable" use pattern 226 | config = _C.clone() 227 | update_config(config, args) 228 | 229 | return config 230 | -------------------------------------------------------------------------------- /code/utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | import numpy as np 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | 7 | def entropy(x, idx=None): 8 | """ 9 | Helper function to compute the entropy over the batch 10 | input: batch w/ shape [b, H, W, C] 11 | idx: batch w/ shape [b, H, W] 12 | output: entropy value [is ideally -log(num_classes)] 13 | """ 14 | EPS = 1e-8 15 | x_ = torch.clamp(x, min = EPS) 16 | b = x_ * torch.log(x_) 17 | 18 | # if(len(b.size()) == 4): 19 | # b = b.permute(0, 2, 3, 1) 20 | # if idx == None: 21 | # idx = (torch.randn(*x.shape[:3])>0.5).bool() 22 | # b = b[idx] 23 | # b = b.flatten(0, 2) 24 | 25 | if len(b.size()) == 4: # Sample-wise entropy 26 | print('use this one') 27 | return - b.sum(dim = -1) 28 | elif len(b.size()) == 1: # Distribution-wise entropy 29 | return - b.sum() 30 | else: 31 | raise ValueError('Input tensor is %d-Dimensional' %(len(b.size()))) 32 | 33 | 34 | def dice_loss(score, target): 35 | target = target.float() 36 | smooth = 1e-5 37 | intersect = torch.sum(score * target) 38 | y_sum = torch.sum(target * target) 39 | z_sum = torch.sum(score * score) 40 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 41 | loss = 1 - loss 42 | return loss 43 | 44 | 45 | def dice_loss1(score, target): 46 | target = target.float() 47 | smooth = 1e-5 48 | intersect = torch.sum(score * target) 49 | y_sum = torch.sum(target) 50 | z_sum = torch.sum(score) 51 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 52 | loss = 1 - loss 53 | return loss 54 | 55 | 56 | def entropy_loss(p, C=2): 57 | # p N*C*W*H*D 58 | y1 = -1*torch.sum(p*torch.log(p+1e-6), dim=1) / \ 59 | torch.tensor(np.log(C)).cuda() 60 | ent = torch.mean(y1) 61 | 62 | return ent 63 | 64 | 65 | def softmax_dice_loss(input_logits, target_logits): 66 | """Takes softmax on both sides and returns MSE loss 67 | Note: 68 | - Returns the sum over all examples. Divide by the batch size afterwards 69 | if you want the mean. 70 | - Sends gradients to inputs but not the targets. 71 | """ 72 | assert input_logits.size() == target_logits.size() 73 | input_softmax = F.softmax(input_logits, dim=1) 74 | target_softmax = F.softmax(target_logits, dim=1) 75 | n = input_logits.shape[1] 76 | dice = 0 77 | for i in range(0, n): 78 | dice += dice_loss1(input_softmax[:, i], target_softmax[:, i]) 79 | mean_dice = dice / n 80 | 81 | return mean_dice 82 | 83 | 84 | def entropy_loss_map(p, C=2): 85 | ent = -1*torch.sum(p * torch.log(p + 1e-6), dim=1, 86 | keepdim=True)/torch.tensor(np.log(C)).cuda() 87 | return ent 88 | 89 | 90 | def softmax_mse_loss(input_logits, target_logits, sigmoid=False): 91 | """Takes softmax on both sides and returns MSE loss 92 | Note: 93 | - Returns the sum over all examples. Divide by the batch size afterwards 94 | if you want the mean. 95 | - Sends gradients to inputs but not the targets. 
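    - Example (a minimal sketch; the tensor names and shapes below are
      illustrative only, not taken from any training script in this repo):
          student_logits = torch.randn(2, 2, 112, 112)
          teacher_logits = torch.randn(2, 2, 112, 112)
          consistency = softmax_mse_loss(student_logits, teacher_logits).mean()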
96 | """ 97 | assert input_logits.size() == target_logits.size() 98 | if sigmoid: 99 | input_softmax = torch.sigmoid(input_logits) 100 | target_softmax = torch.sigmoid(target_logits) 101 | else: 102 | input_softmax = F.softmax(input_logits, dim=1) 103 | target_softmax = F.softmax(target_logits, dim=1) 104 | 105 | mse_loss = (input_softmax-target_softmax)**2 106 | return mse_loss 107 | 108 | 109 | def softmax_kl_loss(input_logits, target_logits, sigmoid=False): 110 | """Takes softmax on both sides and returns KL divergence 111 | Note: 112 | - Returns the sum over all examples. Divide by the batch size afterwards 113 | if you want the mean. 114 | - Sends gradients to inputs but not the targets. 115 | """ 116 | assert input_logits.size() == target_logits.size() 117 | if sigmoid: 118 | input_log_softmax = torch.log(torch.sigmoid(input_logits)) 119 | target_softmax = torch.sigmoid(target_logits) 120 | else: 121 | input_log_softmax = F.log_softmax(input_logits, dim=1) 122 | target_softmax = F.softmax(target_logits, dim=1) 123 | 124 | # return F.kl_div(input_log_softmax, target_softmax) 125 | kl_div = F.kl_div(input_log_softmax, target_softmax, reduction='mean') 126 | # mean_kl_div = torch.mean(0.2*kl_div[:,0,...]+0.8*kl_div[:,1,...]) 127 | return kl_div 128 | 129 | 130 | def symmetric_mse_loss(input1, input2): 131 | """Like F.mse_loss but sends gradients to both directions 132 | Note: 133 | - Returns the sum over all examples. Divide by the batch size afterwards 134 | if you want the mean. 135 | - Sends gradients to both input1 and input2. 136 | """ 137 | assert input1.size() == input2.size() 138 | return torch.mean((input1 - input2)**2) 139 | 140 | 141 | class FocalLoss(nn.Module): 142 | def __init__(self, gamma=2, alpha=None, size_average=True): 143 | super(FocalLoss, self).__init__() 144 | self.gamma = gamma 145 | self.alpha = alpha 146 | if isinstance(alpha, (float, int)): 147 | self.alpha = torch.Tensor([alpha, 1-alpha]) 148 | if isinstance(alpha, list): 149 | self.alpha = torch.Tensor(alpha) 150 | self.size_average = size_average 151 | 152 | def forward(self, input, target): 153 | if input.dim() > 2: 154 | # N,C,H,W => N,C,H*W 155 | input = input.view(input.size(0), input.size(1), -1) 156 | input = input.transpose(1, 2) # N,C,H*W => N,H*W,C 157 | input = input.contiguous().view(-1, input.size(2)) # N,H*W,C => N*H*W,C 158 | target = target.view(-1, 1) 159 | 160 | logpt = F.log_softmax(input, dim=1) 161 | logpt = logpt.gather(1, target) 162 | logpt = logpt.view(-1) 163 | pt = Variable(logpt.data.exp()) 164 | 165 | if self.alpha is not None: 166 | if self.alpha.type() != input.data.type(): 167 | self.alpha = self.alpha.type_as(input.data) 168 | at = self.alpha.gather(0, target.data.view(-1)) 169 | logpt = logpt * Variable(at) 170 | 171 | loss = -1 * (1-pt)**self.gamma * logpt 172 | if self.size_average: 173 | return loss.mean() 174 | else: 175 | return loss.sum() 176 | 177 | 178 | class DiceLoss(nn.Module): 179 | def __init__(self, n_classes): 180 | super(DiceLoss, self).__init__() 181 | self.n_classes = n_classes 182 | 183 | def _one_hot_encoder(self, input_tensor): 184 | tensor_list = [] 185 | for i in range(self.n_classes): 186 | temp_prob = input_tensor == i * torch.ones_like(input_tensor) 187 | tensor_list.append(temp_prob) 188 | output_tensor = torch.cat(tensor_list, dim=1) 189 | return output_tensor.float() 190 | 191 | def _dice_loss(self, score, target): 192 | target = target.float() 193 | smooth = 1e-5 194 | intersect = torch.sum(score * target) 195 | y_sum = torch.sum(target * 
target) 196 | z_sum = torch.sum(score * score) 197 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 198 | loss = 1 - loss 199 | return loss 200 | 201 | def forward(self, inputs, target, weight=None, softmax=False): 202 | if softmax: 203 | inputs = torch.softmax(inputs, dim=1) 204 | target = self._one_hot_encoder(target) 205 | if weight is None: 206 | weight = [1] * self.n_classes 207 | assert inputs.size() == target.size(), 'predict & target shape do not match' 208 | # class_wise_dice = [] 209 | loss = 0.0 210 | for i in range(0, self.n_classes): 211 | dice = self._dice_loss(inputs[:, i], target[:, i]) 212 | # class_wise_dice.append(1.0 - dice.item()) 213 | loss += dice * weight[i] 214 | return loss / self.n_classes 215 | 216 | 217 | def entropy_minmization(p): 218 | y1 = -1*torch.sum(p*torch.log(p+1e-6), dim=1) 219 | ent = torch.mean(y1) 220 | 221 | return ent 222 | 223 | 224 | def entropy_map(p): 225 | ent_map = -1*torch.sum(p * torch.log(p + 1e-6), dim=1, 226 | keepdim=True) 227 | return ent_map 228 | 229 | 230 | if __name__ == '__main__': 231 | input = torch.randn(20, 81, 82, 4) 232 | output = entropy(input, idx=(torch.randn(20, 81, 82)>0.5).bool()) 233 | print(output.shape) 234 | 235 | sorted, indices = torch.sort(output, dim=0) 236 | print(indices.shape) -------------------------------------------------------------------------------- /code/networks/efficientunet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from networks.attention import * 6 | from networks.efficient_encoder import get_encoder 7 | 8 | 9 | def initialize_decoder(module): 10 | for m in module.modules(): 11 | 12 | if isinstance(m, nn.Conv2d): 13 | nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu") 14 | if m.bias is not None: 15 | nn.init.constant_(m.bias, 0) 16 | 17 | elif isinstance(m, nn.BatchNorm2d): 18 | nn.init.constant_(m.weight, 1) 19 | nn.init.constant_(m.bias, 0) 20 | 21 | elif isinstance(m, nn.Linear): 22 | nn.init.xavier_uniform_(m.weight) 23 | if m.bias is not None: 24 | nn.init.constant_(m.bias, 0) 25 | 26 | 27 | class DecoderBlock(nn.Module): 28 | def __init__( 29 | self, 30 | in_channels, 31 | skip_channels, 32 | out_channels, 33 | use_batchnorm=True, 34 | attention_type=None, 35 | ): 36 | super().__init__() 37 | self.conv1 = Conv2dReLU( 38 | in_channels + skip_channels, 39 | out_channels, 40 | kernel_size=3, 41 | padding=1, 42 | use_batchnorm=use_batchnorm, 43 | ) 44 | self.attention1 = Attention(attention_type, in_channels=in_channels + skip_channels) 45 | self.conv2 = Conv2dReLU( 46 | out_channels, 47 | out_channels, 48 | kernel_size=3, 49 | padding=1, 50 | use_batchnorm=use_batchnorm, 51 | ) 52 | self.attention2 = Attention(attention_type, in_channels=out_channels) 53 | 54 | def forward(self, x, skip=None): 55 | x = F.interpolate(x, scale_factor=2, mode="nearest") 56 | if skip is not None: 57 | x = torch.cat([x, skip], dim=1) 58 | x = self.attention1(x) 59 | x = self.conv1(x) 60 | x = self.conv2(x) 61 | x = self.attention2(x) 62 | return x 63 | 64 | 65 | class CenterBlock(nn.Sequential): 66 | def __init__(self, in_channels, out_channels, use_batchnorm=True): 67 | conv1 = Conv2dReLU( 68 | in_channels, 69 | out_channels, 70 | kernel_size=3, 71 | padding=1, 72 | use_batchnorm=use_batchnorm, 73 | ) 74 | conv2 = Conv2dReLU( 75 | out_channels, 76 | out_channels, 77 | kernel_size=3, 78 | padding=1, 79 | use_batchnorm=use_batchnorm, 80 | ) 81 | 
super().__init__(conv1, conv2) 82 | 83 | 84 | class UnetDecoder(nn.Module): 85 | def __init__( 86 | self, 87 | encoder_channels, 88 | decoder_channels, 89 | n_blocks=5, 90 | use_batchnorm=True, 91 | attention_type=None, 92 | center=False, 93 | ): 94 | super().__init__() 95 | 96 | if n_blocks != len(decoder_channels): 97 | raise ValueError( 98 | "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format( 99 | n_blocks, len(decoder_channels) 100 | ) 101 | ) 102 | 103 | encoder_channels = encoder_channels[1:] # remove first skip with same spatial resolution 104 | encoder_channels = encoder_channels[::-1] # reverse channels to start from head of encoder 105 | 106 | # computing blocks input and output channels 107 | head_channels = encoder_channels[0] 108 | in_channels = [head_channels] + list(decoder_channels[:-1]) 109 | skip_channels = list(encoder_channels[1:]) + [0] 110 | out_channels = decoder_channels 111 | 112 | if center: 113 | self.center = CenterBlock( 114 | head_channels, head_channels, use_batchnorm=use_batchnorm 115 | ) 116 | else: 117 | self.center = nn.Identity() 118 | 119 | # combine decoder keyword arguments 120 | kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type) 121 | blocks = [ 122 | DecoderBlock(in_ch, skip_ch, out_ch, **kwargs) 123 | for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels) 124 | ] 125 | self.blocks = nn.ModuleList(blocks) 126 | 127 | def forward(self, *features): 128 | 129 | features = features[1:] # remove first skip with same spatial resolution 130 | features = features[::-1] # reverse channels to start from head of encoder 131 | 132 | head = features[0] 133 | skips = features[1:] 134 | 135 | x = self.center(head) 136 | for i, decoder_block in enumerate(self.blocks): 137 | skip = skips[i] if i < len(skips) else None 138 | x = decoder_block(x, skip) 139 | 140 | return x 141 | 142 | 143 | class Effi_UNet(nn.Module): 144 | """Unet_ is a fully convolution neural network for image semantic segmentation 145 | 146 | Args: 147 | encoder_name: name of classification model (without last dense layers) used as feature 148 | extractor to build segmentation model. 149 | encoder_depth (int): number of stages used in decoder, larger depth - more features are generated. 150 | e.g. for depth=3 encoder will generate list of features with following spatial shapes 151 | [(H,W), (H/2, W/2), (H/4, W/4), (H/8, W/8)], so in general the deepest feature tensor will have 152 | spatial resolution (H/(2^depth), W/(2^depth)] 153 | encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet). 154 | decoder_channels: list of numbers of ``Conv2D`` layer filters in decoder blocks 155 | decoder_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers 156 | is used. If 'inplace' InplaceABN will be used, allows to decrease memory consumption. 157 | One of [True, False, 'inplace'] 158 | decoder_attention_type: attention module used in decoder of the model 159 | One of [``None``, ``scse``] 160 | in_channels: number of input channels for model, default is 3. 161 | classes: a number of classes for output (output shape - ``(batch, classes, h, w)``). 
162 | activation: activation function to apply after final convolution; 163 | One of [``sigmoid``, ``softmax``, ``logsoftmax``, ``identity``, callable, None] 164 | aux_params: if specified model will have additional classification auxiliary output 165 | build on top of encoder, supported params: 166 | - classes (int): number of classes 167 | - pooling (str): one of 'max', 'avg'. Default is 'avg'. 168 | - dropout (float): dropout factor in [0, 1) 169 | - activation (str): activation function to apply "sigmoid"/"softmax" (could be None to return logits) 170 | 171 | Returns: 172 | ``torch.nn.Module``: **Unet** 173 | 174 | .. _Unet: 175 | https://arxiv.org/pdf/1505.04597 176 | 177 | """ 178 | 179 | def __init__( 180 | self, 181 | encoder_name: str = "resnet34", 182 | encoder_depth: int = 5, 183 | encoder_weights: str = "imagenet", 184 | decoder_use_batchnorm=True, 185 | decoder_channels=(256, 128, 64, 32, 16), 186 | decoder_attention_type=None, 187 | in_channels: int = 3, 188 | classes: int = 1): 189 | super().__init__() 190 | 191 | self.encoder = get_encoder( 192 | encoder_name, 193 | in_channels=in_channels, 194 | depth=encoder_depth, 195 | weights=encoder_weights, 196 | ) 197 | 198 | self.decoder = UnetDecoder( 199 | encoder_channels=self.encoder.out_channels, 200 | decoder_channels=decoder_channels, 201 | n_blocks=encoder_depth, 202 | use_batchnorm=decoder_use_batchnorm, 203 | center=True if encoder_name.startswith("vgg") else False, 204 | attention_type=decoder_attention_type, 205 | ) 206 | initialize_decoder(self.decoder) 207 | self.classifier = nn.Conv2d(decoder_channels[-1], classes, 1) 208 | 209 | def forward(self, x): 210 | """Sequentially pass `x` trough model`s encoder, decoder and heads""" 211 | features = self.encoder(x) 212 | decoder_output = self.decoder(*features) 213 | output = self.classifier(decoder_output) 214 | 215 | return output 216 | 217 | 218 | # unet = UNet('efficientnet-b3', encoder_weights='imagenet', in_channels=1, classes=1, decoder_attention_type="scse") 219 | # t = torch.rand(2, 1, 224, 224) 220 | # print(unet) 221 | # print(unet(t).shape) 222 | -------------------------------------------------------------------------------- /code/test_util.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import math 3 | import nibabel as nib 4 | import numpy as np 5 | from medpy import metric 6 | import torch 7 | import torch.nn.functional as F 8 | from tqdm import tqdm 9 | from skimage.measure import label 10 | 11 | def getLargestCC(segmentation): 12 | labels = label(segmentation) 13 | assert( labels.max() != 0 ) # assume at least 1 CC 14 | largestCC = labels == np.argmax(np.bincount(labels.flat)[1:])+1 15 | return largestCC 16 | 17 | def var_all_case_LA(model, num_classes, patch_size=(112, 112, 80), stride_xy=18, stride_z=4): 18 | 19 | with open('../data/LA/test.list', 'r') as f: 20 | image_list = f.readlines() 21 | image_list = ["../data/LA/2018LA_Seg_Training Set/" + item.replace('\n', '') + "/mri_norm2.h5" for item in image_list] 22 | loader = tqdm(image_list) 23 | total_dice = 0.0 24 | for image_path in loader: 25 | h5f = h5py.File(image_path, 'r') 26 | image = h5f['image'][:] 27 | label = h5f['label'][:] 28 | prediction, score_map = test_single_case(model, image, stride_xy, stride_z, patch_size, num_classes=num_classes) 29 | if np.sum(prediction)==0: 30 | dice = 0 31 | else: 32 | dice = metric.binary.dc(prediction, label) 33 | total_dice += dice 34 | avg_dice = total_dice / len(image_list) 35 | print('average 
metric is {}'.format(avg_dice)) 36 | return avg_dice 37 | 38 | def test_all_case(model, image_list, num_classes, patch_size=(112, 112, 80), stride_xy=18, stride_z=4, save_result=True, test_save_path=None, preproc_fn=None, metric_detail=0, nms=0): 39 | 40 | loader = tqdm(image_list) if not metric_detail else image_list 41 | total_metric = 0.0 42 | ith = 0 43 | for image_path in loader: 44 | # id = image_path.split('/')[-2] 45 | h5f = h5py.File(image_path, 'r') 46 | image = h5f['image'][:] 47 | label = h5f['label'][:] 48 | if preproc_fn is not None: 49 | image = preproc_fn(image) 50 | prediction, score_map = test_single_case(model, image, stride_xy, stride_z, patch_size, num_classes=num_classes) 51 | if nms: 52 | prediction = getLargestCC(prediction) 53 | 54 | if np.sum(prediction)==0: 55 | single_metric = (0,0,0,0) 56 | else: 57 | single_metric = calculate_metric_percase(prediction, label[:]) 58 | 59 | if metric_detail: 60 | print('%02d,\t%.5f, %.5f, %.5f, %.5f' % (ith, single_metric[0], single_metric[1], single_metric[2], single_metric[3])) 61 | 62 | total_metric += np.asarray(single_metric) 63 | 64 | if save_result: 65 | nib.save(nib.Nifti1Image(prediction.astype(np.float32), np.eye(4)), test_save_path + "%02d_pred.nii.gz" % ith) 66 | nib.save(nib.Nifti1Image(score_map[0].astype(np.float32), np.eye(4)), test_save_path + "%02d_scores.nii.gz" % ith) 67 | 68 | ith += 1 69 | 70 | avg_metric = total_metric / len(image_list) 71 | print('average metric is {}'.format(avg_metric)) 72 | 73 | with open(test_save_path+'../performance.txt', 'w') as f: 74 | f.writelines('average metric is {} \n'.format(avg_metric)) 75 | return avg_metric 76 | 77 | 78 | # def test_single_case(model, image, stride_xy, stride_z, patch_size, num_classes=1): 79 | # w, h, d = image.shape 80 | 81 | # # if the size of image is less than patch_size, then padding it 82 | # add_pad = False 83 | # if w < patch_size[0]: 84 | # w_pad = patch_size[0]-w 85 | # add_pad = True 86 | # else: 87 | # w_pad = 0 88 | # if h < patch_size[1]: 89 | # h_pad = patch_size[1]-h 90 | # add_pad = True 91 | # else: 92 | # h_pad = 0 93 | # if d < patch_size[2]: 94 | # d_pad = patch_size[2]-d 95 | # add_pad = True 96 | # else: 97 | # d_pad = 0 98 | # wl_pad, wr_pad = w_pad//2,w_pad-w_pad//2 99 | # hl_pad, hr_pad = h_pad//2,h_pad-h_pad//2 100 | # dl_pad, dr_pad = d_pad//2,d_pad-d_pad//2 101 | # if add_pad: 102 | # image = np.pad(image, [(wl_pad,wr_pad),(hl_pad,hr_pad), (dl_pad, dr_pad)], mode='constant', constant_values=0) 103 | # ww,hh,dd = image.shape 104 | 105 | # sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1 106 | # sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1 107 | # sz = math.ceil((dd - patch_size[2]) / stride_z) + 1 108 | # # print("{}, {}, {}".format(sx, sy, sz)) 109 | # score_map = np.zeros((num_classes, ) + image.shape).astype(np.float32) 110 | # cnt = np.zeros(image.shape).astype(np.float32) 111 | 112 | # for x in range(0, sx): 113 | # xs = min(stride_xy*x, ww-patch_size[0]) 114 | # for y in range(0, sy): 115 | # ys = min(stride_xy * y,hh-patch_size[1]) 116 | # for z in range(0, sz): 117 | # zs = min(stride_z * z, dd-patch_size[2]) 118 | # test_patch = image[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] 119 | # test_patch = np.expand_dims(np.expand_dims(test_patch,axis=0),axis=0).astype(np.float32) 120 | # test_patch = torch.from_numpy(test_patch).cuda() 121 | 122 | # with torch.no_grad(): 123 | # y1, features = model(test_patch) 124 | # y = F.softmax(y1, dim=1) 125 | 126 | # y = y.cpu().data.numpy() 127 | # y = 
y[0,:,:,:,:] 128 | # score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 129 | # = score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + y 130 | # cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 131 | # = cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + 1 132 | # score_map = score_map/np.expand_dims(cnt,axis=0) 133 | # label_map = (score_map[0]>0.5).astype(np.int) 134 | # if add_pad: 135 | # label_map = label_map[wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d] 136 | # score_map = score_map[:,wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d] 137 | # return label_map, score_map 138 | 139 | def test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=1): 140 | w, h, d = image.shape 141 | 142 | # if the size of image is less than patch_size, then padding it 143 | add_pad = False 144 | if w < patch_size[0]: 145 | w_pad = patch_size[0]-w 146 | add_pad = True 147 | else: 148 | w_pad = 0 149 | if h < patch_size[1]: 150 | h_pad = patch_size[1]-h 151 | add_pad = True 152 | else: 153 | h_pad = 0 154 | if d < patch_size[2]: 155 | d_pad = patch_size[2]-d 156 | add_pad = True 157 | else: 158 | d_pad = 0 159 | wl_pad, wr_pad = w_pad//2, w_pad-w_pad//2 160 | hl_pad, hr_pad = h_pad//2, h_pad-h_pad//2 161 | dl_pad, dr_pad = d_pad//2, d_pad-d_pad//2 162 | if add_pad: 163 | image = np.pad(image, [(wl_pad, wr_pad), (hl_pad, hr_pad), 164 | (dl_pad, dr_pad)], mode='constant', constant_values=0) 165 | ww, hh, dd = image.shape 166 | 167 | sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1 168 | sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1 169 | sz = math.ceil((dd - patch_size[2]) / stride_z) + 1 170 | # print("{}, {}, {}".format(sx, sy, sz)) 171 | score_map = np.zeros((num_classes, ) + image.shape).astype(np.float32) 172 | cnt = np.zeros(image.shape).astype(np.float32) 173 | 174 | for x in range(0, sx): 175 | xs = min(stride_xy*x, ww-patch_size[0]) 176 | for y in range(0, sy): 177 | ys = min(stride_xy * y, hh-patch_size[1]) 178 | for z in range(0, sz): 179 | zs = min(stride_z * z, dd-patch_size[2]) 180 | test_patch = image[xs:xs+patch_size[0], 181 | ys:ys+patch_size[1], zs:zs+patch_size[2]] 182 | test_patch = np.expand_dims(np.expand_dims( 183 | test_patch, axis=0), axis=0).astype(np.float32) 184 | test_patch = torch.from_numpy(test_patch).cuda() 185 | 186 | with torch.no_grad(): 187 | # test_patch = test_patch 188 | # print('each test', test_patch.shape) # each test torch.Size([1, 1, 256, 224, 20]) 189 | # test_patch = test_patch.permute(0, 1, 4, 2, 3) 190 | # exit() 191 | y1 = net(test_patch) 192 | # ensemble 193 | y = torch.softmax(y1, dim=1) 194 | # print('prediction', y.shape) # prediction torch.Size([1, 4, 20, 256, 224]) 195 | # y = y.permute(0, 1, 3, 4, 2) 196 | # exit() 197 | y = y.cpu().data.numpy() 198 | y = y[0, :, :, :, :] 199 | score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 200 | = score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + y 201 | cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 202 | = cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + 1 203 | score_map = score_map/np.expand_dims(cnt, axis=0) 204 | label_map = np.argmax(score_map, axis=0) 205 | 206 | if add_pad: 207 | label_map = label_map[wl_pad:wl_pad+w, 208 | hl_pad:hl_pad+h, dl_pad:dl_pad+d] 209 | score_map = score_map[:, wl_pad:wl_pad + 210 | w, hl_pad:hl_pad+h, dl_pad:dl_pad+d] 211 | return label_map, score_map 212 | 213 | 
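# Note on test_single_case above: the volume is zero-padded up to patch_size if
# needed, then scanned with a sliding window (stride_xy in-plane, stride_z along
# depth). The softmax probabilities of every patch are accumulated into
# score_map, cnt records how many patches covered each voxel, and dividing the
# two averages the overlapping predictions before the final argmax.
#
# Minimal usage sketch (illustrative only; the model and volume below are
# hypothetical and not taken from this repository):
#
#   net = VNet(n_channels=1, n_classes=2, normalization='batchnorm').cuda().eval()
#   volume = h5py.File('example/mri_norm2.h5', 'r')['image'][:]
#   pred, prob = test_single_case(net, volume, stride_xy=18, stride_z=4,
#                                 patch_size=(112, 112, 80), num_classes=2)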
214 | def calculate_metric_percase(pred, gt): 215 | dice = metric.binary.dc(pred, gt) 216 | jc = metric.binary.jc(pred, gt) 217 | hd = metric.binary.hd95(pred, gt) 218 | asd = metric.binary.asd(pred, gt) 219 | 220 | return dice, jc, hd, asd 221 | -------------------------------------------------------------------------------- /code/networks/vnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | class ConvBlock(nn.Module): 6 | def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'): 7 | super(ConvBlock, self).__init__() 8 | 9 | ops = [] 10 | for i in range(n_stages): 11 | if i==0: 12 | input_channel = n_filters_in 13 | else: 14 | input_channel = n_filters_out 15 | 16 | ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1)) 17 | if normalization == 'batchnorm': 18 | ops.append(nn.BatchNorm3d(n_filters_out)) 19 | elif normalization == 'groupnorm': 20 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 21 | elif normalization == 'instancenorm': 22 | ops.append(nn.InstanceNorm3d(n_filters_out)) 23 | elif normalization != 'none': 24 | assert False 25 | ops.append(nn.ReLU(inplace=True)) 26 | 27 | self.conv = nn.Sequential(*ops) 28 | 29 | def forward(self, x): 30 | x = self.conv(x) 31 | return x 32 | 33 | 34 | class ResidualConvBlock(nn.Module): 35 | def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'): 36 | super(ResidualConvBlock, self).__init__() 37 | 38 | ops = [] 39 | for i in range(n_stages): 40 | if i == 0: 41 | input_channel = n_filters_in 42 | else: 43 | input_channel = n_filters_out 44 | 45 | ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1)) 46 | if normalization == 'batchnorm': 47 | ops.append(nn.BatchNorm3d(n_filters_out)) 48 | elif normalization == 'groupnorm': 49 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 50 | elif normalization == 'instancenorm': 51 | ops.append(nn.InstanceNorm3d(n_filters_out)) 52 | elif normalization != 'none': 53 | assert False 54 | 55 | if i != n_stages-1: 56 | ops.append(nn.ReLU(inplace=True)) 57 | 58 | self.conv = nn.Sequential(*ops) 59 | self.relu = nn.ReLU(inplace=True) 60 | 61 | def forward(self, x): 62 | x = (self.conv(x) + x) 63 | x = self.relu(x) 64 | return x 65 | 66 | 67 | class DownsamplingConvBlock(nn.Module): 68 | def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'): 69 | super(DownsamplingConvBlock, self).__init__() 70 | 71 | ops = [] 72 | if normalization != 'none': 73 | ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 74 | if normalization == 'batchnorm': 75 | ops.append(nn.BatchNorm3d(n_filters_out)) 76 | elif normalization == 'groupnorm': 77 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 78 | elif normalization == 'instancenorm': 79 | ops.append(nn.InstanceNorm3d(n_filters_out)) 80 | else: 81 | assert False 82 | else: 83 | ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 84 | 85 | ops.append(nn.ReLU(inplace=True)) 86 | 87 | self.conv = nn.Sequential(*ops) 88 | 89 | def forward(self, x): 90 | x = self.conv(x) 91 | return x 92 | 93 | 94 | class UpsamplingDeconvBlock(nn.Module): 95 | def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'): 96 | super(UpsamplingDeconvBlock, self).__init__() 97 | 98 | ops = [] 99 | if normalization != 'none': 100 | 
ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 101 | if normalization == 'batchnorm': 102 | ops.append(nn.BatchNorm3d(n_filters_out)) 103 | elif normalization == 'groupnorm': 104 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 105 | elif normalization == 'instancenorm': 106 | ops.append(nn.InstanceNorm3d(n_filters_out)) 107 | else: 108 | assert False 109 | else: 110 | ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 111 | 112 | ops.append(nn.ReLU(inplace=True)) 113 | 114 | self.conv = nn.Sequential(*ops) 115 | 116 | def forward(self, x): 117 | x = self.conv(x) 118 | return x 119 | 120 | 121 | class Upsampling(nn.Module): 122 | def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'): 123 | super(Upsampling, self).__init__() 124 | 125 | ops = [] 126 | ops.append(nn.Upsample(scale_factor=stride, mode='trilinear',align_corners=False)) 127 | ops.append(nn.Conv3d(n_filters_in, n_filters_out, kernel_size=3, padding=1)) 128 | if normalization == 'batchnorm': 129 | ops.append(nn.BatchNorm3d(n_filters_out)) 130 | elif normalization == 'groupnorm': 131 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 132 | elif normalization == 'instancenorm': 133 | ops.append(nn.InstanceNorm3d(n_filters_out)) 134 | elif normalization != 'none': 135 | assert False 136 | ops.append(nn.ReLU(inplace=True)) 137 | 138 | self.conv = nn.Sequential(*ops) 139 | 140 | def forward(self, x): 141 | x = self.conv(x) 142 | return x 143 | 144 | 145 | class VNet(nn.Module): 146 | def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False): 147 | super(VNet, self).__init__() 148 | self.has_dropout = has_dropout 149 | 150 | self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization) 151 | self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization) 152 | 153 | self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization) 154 | self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization) 155 | 156 | self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization) 157 | self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization) 158 | 159 | self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization) 160 | self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization) 161 | 162 | self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization) 163 | self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization) 164 | 165 | self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization) 166 | self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization) 167 | 168 | self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization) 169 | self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization) 170 | 171 | self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization) 172 | self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization) 173 | 174 | self.block_nine = ConvBlock(1, n_filters, n_filters, 
normalization=normalization) 175 | self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0) 176 | 177 | self.dropout = nn.Dropout3d(p=0.5, inplace=False) 178 | # self.__init_weight() 179 | 180 | def encoder(self, input): 181 | x1 = self.block_one(input) 182 | x1_dw = self.block_one_dw(x1) 183 | 184 | x2 = self.block_two(x1_dw) 185 | x2_dw = self.block_two_dw(x2) 186 | 187 | x3 = self.block_three(x2_dw) 188 | x3_dw = self.block_three_dw(x3) 189 | 190 | x4 = self.block_four(x3_dw) 191 | x4_dw = self.block_four_dw(x4) 192 | 193 | x5 = self.block_five(x4_dw) 194 | # x5 = F.dropout3d(x5, p=0.5, training=True) 195 | if self.has_dropout: 196 | x5 = self.dropout(x5) 197 | 198 | res = [x1, x2, x3, x4, x5] 199 | 200 | return res 201 | 202 | def decoder(self, features): 203 | x1 = features[0] 204 | x2 = features[1] 205 | x3 = features[2] 206 | x4 = features[3] 207 | x5 = features[4] 208 | 209 | x5_up = self.block_five_up(x5) 210 | x5_up = x5_up + x4 211 | 212 | x6 = self.block_six(x5_up) 213 | x6_up = self.block_six_up(x6) 214 | x6_up = x6_up + x3 215 | 216 | x7 = self.block_seven(x6_up) 217 | x7_up = self.block_seven_up(x7) 218 | x7_up = x7_up + x2 219 | 220 | x8 = self.block_eight(x7_up) 221 | x8_up = self.block_eight_up(x8) 222 | x8_up = x8_up + x1 223 | x9 = self.block_nine(x8_up) 224 | # x9 = F.dropout3d(x9, p=0.5, training=True) 225 | if self.has_dropout: 226 | x9 = self.dropout(x9) 227 | out = self.out_conv(x9) 228 | return out 229 | 230 | 231 | def forward(self, input, turnoff_drop=False): 232 | if turnoff_drop: 233 | has_dropout = self.has_dropout 234 | self.has_dropout = False 235 | features = self.encoder(input) 236 | out = self.decoder(features) 237 | if turnoff_drop: 238 | self.has_dropout = has_dropout 239 | return out 240 | 241 | # def __init_weight(self): 242 | # for m in self.modules(): 243 | # if isinstance(m, nn.Conv3d): 244 | # torch.nn.init.kaiming_normal_(m.weight) 245 | # elif isinstance(m, nn.BatchNorm3d): 246 | # m.weight.data.fill_(1) 247 | # m.bias.data.zero_() 248 | if __name__ == '__main__': 249 | # compute FLOPS & PARAMETERS 250 | from thop import profile 251 | from thop import clever_format 252 | model = VNet(n_channels=1, n_classes=2) 253 | input = torch.randn(4, 1, 112, 112, 80) 254 | flops, params = profile(model, inputs=(input,)) 255 | print(flops, params) 256 | macs, params = clever_format([flops, params], "%.3f") 257 | print(macs, params) 258 | print("VNet have {} paramerters in total".format(sum(x.numel() for x in model.parameters()))) -------------------------------------------------------------------------------- /code/networks/vnetWithArgs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | class ConvBlock(nn.Module): 6 | def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'): 7 | super(ConvBlock, self).__init__() 8 | 9 | ops = [] 10 | for i in range(n_stages): 11 | if i==0: 12 | input_channel = n_filters_in 13 | else: 14 | input_channel = n_filters_out 15 | 16 | ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1)) 17 | if normalization == 'batchnorm': 18 | ops.append(nn.BatchNorm3d(n_filters_out)) 19 | elif normalization == 'groupnorm': 20 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 21 | elif normalization == 'instancenorm': 22 | ops.append(nn.InstanceNorm3d(n_filters_out)) 23 | elif normalization != 'none': 24 | assert False 25 | ops.append(nn.ReLU(inplace=True)) 26 | 27 | 
self.conv = nn.Sequential(*ops) 28 | 29 | def forward(self, x): 30 | x = self.conv(x) 31 | return x 32 | 33 | 34 | class ResidualConvBlock(nn.Module): 35 | def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'): 36 | super(ResidualConvBlock, self).__init__() 37 | 38 | ops = [] 39 | for i in range(n_stages): 40 | if i == 0: 41 | input_channel = n_filters_in 42 | else: 43 | input_channel = n_filters_out 44 | 45 | ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1)) 46 | if normalization == 'batchnorm': 47 | ops.append(nn.BatchNorm3d(n_filters_out)) 48 | elif normalization == 'groupnorm': 49 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 50 | elif normalization == 'instancenorm': 51 | ops.append(nn.InstanceNorm3d(n_filters_out)) 52 | elif normalization != 'none': 53 | assert False 54 | 55 | if i != n_stages-1: 56 | ops.append(nn.ReLU(inplace=True)) 57 | 58 | self.conv = nn.Sequential(*ops) 59 | self.relu = nn.ReLU(inplace=True) 60 | 61 | def forward(self, x): 62 | x = (self.conv(x) + x) 63 | x = self.relu(x) 64 | return x 65 | 66 | 67 | class DownsamplingConvBlock(nn.Module): 68 | def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'): 69 | super(DownsamplingConvBlock, self).__init__() 70 | 71 | ops = [] 72 | if normalization != 'none': 73 | ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 74 | if normalization == 'batchnorm': 75 | ops.append(nn.BatchNorm3d(n_filters_out)) 76 | elif normalization == 'groupnorm': 77 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 78 | elif normalization == 'instancenorm': 79 | ops.append(nn.InstanceNorm3d(n_filters_out)) 80 | else: 81 | assert False 82 | else: 83 | ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 84 | 85 | ops.append(nn.ReLU(inplace=True)) 86 | 87 | self.conv = nn.Sequential(*ops) 88 | 89 | def forward(self, x): 90 | x = self.conv(x) 91 | return x 92 | 93 | 94 | class UpsamplingDeconvBlock(nn.Module): 95 | def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'): 96 | super(UpsamplingDeconvBlock, self).__init__() 97 | 98 | ops = [] 99 | if normalization != 'none': 100 | ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 101 | if normalization == 'batchnorm': 102 | ops.append(nn.BatchNorm3d(n_filters_out)) 103 | elif normalization == 'groupnorm': 104 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 105 | elif normalization == 'instancenorm': 106 | ops.append(nn.InstanceNorm3d(n_filters_out)) 107 | else: 108 | assert False 109 | else: 110 | ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 111 | 112 | ops.append(nn.ReLU(inplace=True)) 113 | 114 | self.conv = nn.Sequential(*ops) 115 | 116 | def forward(self, x): 117 | x = self.conv(x) 118 | return x 119 | 120 | 121 | class Upsampling(nn.Module): 122 | def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'): 123 | super(Upsampling, self).__init__() 124 | 125 | ops = [] 126 | ops.append(nn.Upsample(scale_factor=stride, mode='trilinear',align_corners=False)) 127 | ops.append(nn.Conv3d(n_filters_in, n_filters_out, kernel_size=3, padding=1)) 128 | if normalization == 'batchnorm': 129 | ops.append(nn.BatchNorm3d(n_filters_out)) 130 | elif normalization == 'groupnorm': 131 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 132 | elif normalization == 
'instancenorm': 133 | ops.append(nn.InstanceNorm3d(n_filters_out)) 134 | elif normalization != 'none': 135 | assert False 136 | ops.append(nn.ReLU(inplace=True)) 137 | 138 | self.conv = nn.Sequential(*ops) 139 | 140 | def forward(self, x): 141 | x = self.conv(x) 142 | return x 143 | 144 | 145 | class VNet(nn.Module): 146 | def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False): 147 | super(VNet, self).__init__() 148 | self.has_dropout = has_dropout 149 | 150 | self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization) 151 | self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization) 152 | 153 | self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization) 154 | self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization) 155 | 156 | self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization) 157 | self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization) 158 | 159 | self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization) 160 | self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization) 161 | 162 | self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization) 163 | self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization) 164 | 165 | self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization) 166 | self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization) 167 | 168 | self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization) 169 | self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization) 170 | 171 | self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization) 172 | self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization) 173 | 174 | self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization) 175 | self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0) 176 | 177 | self.dropout = nn.Dropout3d(p=0.5, inplace=False) 178 | # self.__init_weight() 179 | 180 | def encoder(self, input): 181 | x1 = self.block_one(input) 182 | x1_dw = self.block_one_dw(x1) 183 | 184 | x2 = self.block_two(x1_dw) 185 | x2_dw = self.block_two_dw(x2) 186 | 187 | x3 = self.block_three(x2_dw) 188 | x3_dw = self.block_three_dw(x3) 189 | 190 | x4 = self.block_four(x3_dw) 191 | x4_dw = self.block_four_dw(x4) 192 | 193 | x5 = self.block_five(x4_dw) 194 | # x5 = F.dropout3d(x5, p=0.5, training=True) 195 | if self.has_dropout: 196 | x5 = self.dropout(x5) 197 | 198 | res = [x1, x2, x3, x4, x5] 199 | 200 | return res 201 | 202 | def decoder(self, features): 203 | x1 = features[0] 204 | x2 = features[1] 205 | x3 = features[2] 206 | x4 = features[3] 207 | x5 = features[4] 208 | 209 | # feature_map = [x5] 210 | 211 | x5_up = self.block_five_up(x5) 212 | x5_up = x5_up + x4 213 | 214 | feature_map = [x5_up] 215 | 216 | x6 = self.block_six(x5_up) 217 | x6_up = self.block_six_up(x6) 218 | x6_up = x6_up + x3 219 | 220 | feature_map.append(x6_up) 221 | 222 | x7 = self.block_seven(x6_up) 223 | x7_up = self.block_seven_up(x7) 224 | x7_up = x7_up + x2 225 | 226 | feature_map.append(x7_up) 227 | 228 | x8 = 
self.block_eight(x7_up) 229 | x8_up = self.block_eight_up(x8) 230 | x8_up = x8_up + x1 231 | 232 | feature_map.append(x8_up) 233 | 234 | x9 = self.block_nine(x8_up) 235 | 236 | feature_map.append(x9) 237 | # x9 = F.dropout3d(x9, p=0.5, training=True) 238 | if self.has_dropout: 239 | x9 = self.dropout(x9) 240 | out = self.out_conv(x9) 241 | return out, feature_map 242 | 243 | 244 | def forward(self, input, turnoff_drop=False): 245 | if turnoff_drop: 246 | has_dropout = self.has_dropout 247 | self.has_dropout = False 248 | features = self.encoder(input) 249 | out, feature_map = self.decoder(features) 250 | if turnoff_drop: 251 | self.has_dropout = has_dropout 252 | return out, feature_map[0], feature_map 253 | 254 | # def __init_weight(self): 255 | # for m in self.modules(): 256 | # if isinstance(m, nn.Conv3d): 257 | # torch.nn.init.kaiming_normal_(m.weight) 258 | # elif isinstance(m, nn.BatchNorm3d): 259 | # m.weight.data.fill_(1) 260 | # m.bias.data.zero_() 261 | 262 | if __name__ == '__main__': 263 | # compute FLOPS & PARAMETERS 264 | # from thop import profile 265 | # from thop import clever_format 266 | model = VNet(n_channels=1, n_classes=2) 267 | input = torch.randn(4, 1, 112, 112, 80) 268 | 269 | res, latent, feat = model(input) 270 | print(res.shape) 271 | print(latent.shape) # torch.Size([4, 128, 14, 14, 10]) 272 | print([item.shape for item in feat]) 273 | # [torch.Size([4, 128, 14, 14, 10]), 274 | # torch.Size([4, 64, 28, 28, 20]), 275 | # torch.Size([4, 32, 56, 56, 40]), 276 | # torch.Size([4, 16, 112, 112, 80]), 277 | # torch.Size([4, 16, 112, 112, 80])] 278 | 279 | # flops, params = profile(model, inputs=(input,)) 280 | # print(flops, params) 281 | # macs, params = clever_format([flops, params], "%.3f") 282 | # print(macs, params) 283 | # print("VNet have {} paramerters in total".format(sum(x.numel() for x in model.parameters()))) -------------------------------------------------------------------------------- /code/dataloaders/la_heart.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from glob import glob 5 | from torch.utils.data import Dataset 6 | import h5py 7 | import itertools 8 | from torch.utils.data.sampler import Sampler 9 | from PIL.ImageEnhance import * 10 | from torchvision.transforms import * 11 | from PIL import ImageFilter 12 | import random 13 | 14 | class LAHeart(Dataset): 15 | """ LA Dataset """ 16 | def __init__(self, base_dir=None, split='train', num=None, transform=None): 17 | self._base_dir = base_dir 18 | self.transform = transform 19 | self.sample_list = [] 20 | if split=='train': 21 | with open(self._base_dir+'/../train.list', 'r') as f: 22 | self.image_list = f.readlines() 23 | elif split == 'test': 24 | with open(self._base_dir+'/../test.list', 'r') as f: 25 | self.image_list = f.readlines() 26 | self.image_list = [item.replace('\n','') for item in self.image_list] 27 | if num is not None: 28 | self.image_list = self.image_list[:num] 29 | print("total {} samples".format(len(self.image_list))) 30 | 31 | def __len__(self): 32 | return len(self.image_list) 33 | 34 | def __getitem__(self, idx): 35 | image_name = self.image_list[idx] 36 | h5f = h5py.File(self._base_dir+"/"+image_name+"/mri_norm2.h5", 'r') 37 | image = h5f['image'][:] 38 | label = h5f['label'][:] 39 | sample = {'image': image, 'label': label} 40 | if self.transform: 41 | sample = self.transform(sample) 42 | sample['idx'] = idx 43 | return sample 44 | 45 | 46 | class LAHeartWithIndex(Dataset): 47 
| """ LA Dataset """ 48 | def __init__(self, base_dir=None, split='train', num=None, transform=None, 49 | index=4, label_type=1): 50 | self._base_dir = base_dir 51 | self.transform = transform 52 | self.sample_list = [] 53 | if split=='train': 54 | with open(self._base_dir+'/../train.list', 'r') as f: 55 | self.image_list = f.readlines() 56 | elif split == 'test': 57 | with open(self._base_dir+'/../test.list', 'r') as f: 58 | self.image_list = f.readlines() 59 | self.image_list = [item.replace('\n','') for item in self.image_list] 60 | 61 | if(label_type==1): 62 | self.image_list = self.image_list[:index] 63 | else: 64 | self.image_list = self.image_list[index:] 65 | 66 | if num is not None: 67 | self.image_list = self.image_list[:num] 68 | print("total {} samples".format(len(self.image_list))) 69 | 70 | def __len__(self): 71 | return len(self.image_list) 72 | 73 | def __getitem__(self, idx): 74 | image_name = self.image_list[idx] 75 | h5f = h5py.File(self._base_dir+"/"+image_name+"/mri_norm2.h5", 'r') 76 | image = h5f['image'][:] 77 | label = h5f['label'][:] 78 | sample = {'image': image, 'label': label} 79 | if self.transform: 80 | sample = self.transform(sample) 81 | sample['idx'] = idx 82 | return sample 83 | 84 | 85 | class CenterCrop(object): 86 | def __init__(self, output_size): 87 | self.output_size = output_size 88 | 89 | def __call__(self, sample): 90 | image, label = sample['image'], sample['label'] 91 | 92 | # pad the sample if necessary 93 | if label.shape[0] <= self.output_size[0] or label.shape[1] <= self.output_size[1] or label.shape[2] <= \ 94 | self.output_size[2]: 95 | pw = max((self.output_size[0] - label.shape[0]) // 2 + 3, 0) 96 | ph = max((self.output_size[1] - label.shape[1]) // 2 + 3, 0) 97 | pd = max((self.output_size[2] - label.shape[2]) // 2 + 3, 0) 98 | image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) 99 | label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) 100 | 101 | (w, h, d) = image.shape 102 | 103 | w1 = int(round((w - self.output_size[0]) / 2.)) 104 | h1 = int(round((h - self.output_size[1]) / 2.)) 105 | d1 = int(round((d - self.output_size[2]) / 2.)) 106 | 107 | label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]] 108 | image = image[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]] 109 | 110 | return {'image': image, 'label': label} 111 | 112 | 113 | class RandomCrop(object): 114 | """ 115 | Crop randomly the image in a sample 116 | Args: 117 | output_size (int): Desired output size 118 | """ 119 | 120 | def __init__(self, output_size): 121 | self.output_size = output_size 122 | 123 | def __call__(self, sample): 124 | image, label = sample['image'], sample['label'] 125 | 126 | # pad the sample if necessary 127 | if label.shape[0] <= self.output_size[0] or label.shape[1] <= self.output_size[1] or label.shape[2] <= \ 128 | self.output_size[2]: 129 | pw = max((self.output_size[0] - label.shape[0]) // 2 + 3, 0) 130 | ph = max((self.output_size[1] - label.shape[1]) // 2 + 3, 0) 131 | pd = max((self.output_size[2] - label.shape[2]) // 2 + 3, 0) 132 | image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) 133 | label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0) 134 | 135 | (w, h, d) = image.shape 136 | # if np.random.uniform() > 0.33: 137 | # w1 = np.random.randint((w - self.output_size[0])//4, 3*(w - self.output_size[0])//4) 
138 | # h1 = np.random.randint((h - self.output_size[1])//4, 3*(h - self.output_size[1])//4) 139 | # else: 140 | w1 = np.random.randint(0, w - self.output_size[0]) 141 | h1 = np.random.randint(0, h - self.output_size[1]) 142 | d1 = np.random.randint(0, d - self.output_size[2]) 143 | 144 | label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]] 145 | image = image[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]] 146 | return {'image': image, 'label': label} 147 | 148 | 149 | class RandomRotFlip(object): 150 | """ 151 | Crop randomly flip the dataset in a sample 152 | Args: 153 | output_size (int): Desired output size 154 | """ 155 | 156 | def __call__(self, sample): 157 | image, label = sample['image'], sample['label'] 158 | k = np.random.randint(0, 4) 159 | image = np.rot90(image, k) 160 | label = np.rot90(label, k) 161 | axis = np.random.randint(0, 2) 162 | image = np.flip(image, axis=axis).copy() 163 | label = np.flip(label, axis=axis).copy() 164 | 165 | return {'image': image, 'label': label} 166 | 167 | 168 | # class RandomNoise(object): 169 | # def __init__(self, mu=0, sigma=0.1): 170 | # self.mu = mu 171 | # self.sigma = sigma 172 | 173 | # def __call__(self, sample): 174 | # image, label = sample['image'], sample['label'] 175 | # noise = np.clip(self.sigma * np.random.randn(image.shape[0], image.shape[1], image.shape[2]), -2*self.sigma, 2*self.sigma) 176 | # noise = noise + self.mu 177 | # image = image + noise 178 | # return {'image': image, 'label': label} 179 | 180 | 181 | class CreateOnehotLabel(object): 182 | def __init__(self, num_classes): 183 | self.num_classes = num_classes 184 | 185 | def __call__(self, sample): 186 | image, label = sample['image'], sample['label'] 187 | onehot_label = np.zeros((self.num_classes, label.shape[0], label.shape[1], label.shape[2]), dtype=np.float32) 188 | for i in range(self.num_classes): 189 | onehot_label[i, :, :, :] = (label == i).astype(np.float32) 190 | return {'image': image, 'label': label,'onehot_label':onehot_label} 191 | 192 | 193 | class ToTensor(object): 194 | """Convert ndarrays in sample to Tensors.""" 195 | 196 | def __call__(self, sample): 197 | image = sample['image'] 198 | image = image.reshape(1, image.shape[0], image.shape[1], image.shape[2]).astype(np.float32) 199 | if 'onehot_label' in sample: 200 | return {'image': torch.from_numpy(image), 'label': torch.from_numpy(sample['label']).long(), 201 | 'onehot_label': torch.from_numpy(sample['onehot_label']).long()} 202 | else: 203 | return {'image': torch.from_numpy(image), 'label': torch.from_numpy(sample['label']).long()} 204 | 205 | 206 | class TwoStreamBatchSampler(Sampler): 207 | """Iterate two sets of indices 208 | 209 | An 'epoch' is one iteration through the primary indices. 210 | During the epoch, the secondary indices are iterated through 211 | as many times as needed. 
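    Example (an illustrative sketch; the index lists and batch sizes are made
    up, not taken from any training script in this repository):
        labeled_idxs = list(range(16))        # primary indices, seen once per epoch
        unlabeled_idxs = list(range(16, 80))  # secondary indices, recycled as needed
        sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs,
                                        batch_size=4, secondary_batch_size=2)
        # every yielded batch then holds 2 primary and 2 secondary indices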
212 | """ 213 | def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size): 214 | self.primary_indices = primary_indices 215 | self.secondary_indices = secondary_indices 216 | self.secondary_batch_size = secondary_batch_size 217 | self.primary_batch_size = batch_size - secondary_batch_size 218 | 219 | assert len(self.primary_indices) >= self.primary_batch_size > 0 220 | assert len(self.secondary_indices) >= self.secondary_batch_size > 0 221 | 222 | def __iter__(self): 223 | primary_iter = iterate_once(self.primary_indices) 224 | secondary_iter = iterate_eternally(self.secondary_indices) 225 | return ( 226 | primary_batch + secondary_batch 227 | for (primary_batch, secondary_batch) 228 | in zip(grouper(primary_iter, self.primary_batch_size), 229 | grouper(secondary_iter, self.secondary_batch_size)) 230 | ) 231 | 232 | def __len__(self): 233 | return len(self.primary_indices) // self.primary_batch_size 234 | 235 | def iterate_once(iterable): 236 | return np.random.permutation(iterable) 237 | 238 | 239 | def iterate_eternally(indices): 240 | def infinite_shuffles(): 241 | while True: 242 | yield np.random.permutation(indices) 243 | return itertools.chain.from_iterable(infinite_shuffles()) 244 | 245 | 246 | def grouper(iterable, n): 247 | "Collect data into fixed-length chunks or blocks" 248 | # grouper('ABCDEFG', 3) --> ABC DEF" 249 | args = [iter(iterable)] * n 250 | return zip(*args) 251 | 252 | 253 | 254 | class RandomColorJitter(object): 255 | def __init__(self, color = (0.04, 0.04, 0.04, 0.01), p=0.1) -> None: 256 | self.color = color 257 | self.p = p 258 | 259 | def __call__(self, sample): 260 | if np.random.uniform(low=0, high=1, size=1) > self.p: 261 | return sample 262 | else: 263 | image, label = sample['image'], sample['label'] 264 | for j in range(image.shape[0]): 265 | for t in range(image.shape[-1]): 266 | image[j, :, :, :, t] = ColorJitter( 267 | brightness=self.color[0], 268 | contrast=self.color[1], 269 | saturation=self.color[2], 270 | hue=self.color[3])((image[j, :, :, :, t])) 271 | 272 | 273 | return {'image': image, 'label': label} 274 | 275 | 276 | class RandomNoise(object): 277 | def __init__(self, p=0.5): 278 | self.p = p 279 | def __call__(self, sample): 280 | if np.random.uniform(low=0, high=1, size=1) > self.p: 281 | return sample 282 | else: 283 | image, label = sample['image'], sample['label'] 284 | new_image = [] 285 | # noise = np.clip(self.sigma * np.random.randn(*image.shape), -2*self.sigma, 2*self.sigma) 286 | sigma = random.uniform(0.15, 1.15) 287 | for i in range(image.shape[0]): 288 | for j in range(image.shape[-1]): 289 | image[i, 0, :, :, j] = torch.FloatTensor(np.array(ToPILImage()(image[i, 0, :, :, j]).filter(ImageFilter.GaussianBlur(radius=sigma)))) 290 | # new_image.append(np.array(image_i)/255) 291 | # image = image.filter(ImageFilter.GaussianBlur(radius=sigma)) 292 | 293 | # image = torch.tensor(np.array(new_image), dtype=torch.float64) 294 | return {'image': image, 'label': label} 295 | 296 | if __name__ == '__main__': 297 | from torchvision import transforms 298 | train_data_path = '/home/weicheng/selfLearning/DTC/data/2018LA_Seg_Training Set' 299 | db_train = LAHeart(base_dir=train_data_path, 300 | split='train', 301 | transform = transforms.Compose([ 302 | RandomRotFlip(), 303 | RandomCrop([112, 112, 80]), 304 | ToTensor(), 305 | ])) -------------------------------------------------------------------------------- /code/model.py: -------------------------------------------------------------------------------- 1 | 
import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import os 5 | import argparse 6 | from networks.net_factory_args import net_factory 7 | import numpy as np 8 | 9 | 10 | 11 | class FeatureExtractor(nn.Module): 12 | def __init__(self, fea_dim=[256, 128, 64, 32, 16], output_dim=256) -> None: 13 | super().__init__() 14 | assert len(fea_dim)==5, 'input_dim is not correct' 15 | cnt = fea_dim[0] 16 | self.fea0 = nn.Conv2d(in_channels=cnt, out_channels=cnt, kernel_size=1, bias=False) 17 | cnt += fea_dim[1] 18 | self.fea1 = nn.Conv2d(in_channels=cnt, out_channels=cnt, kernel_size=1, bias=False) 19 | cnt += fea_dim[2] 20 | self.fea2 = nn.Conv2d(in_channels=cnt, out_channels=cnt, kernel_size=1, bias=False) 21 | cnt += fea_dim[3] 22 | self.fea3 = nn.Conv2d(in_channels=cnt, out_channels=cnt, kernel_size=1, bias=False) 23 | cnt += fea_dim[4] 24 | self.fea4 = nn.Conv2d(in_channels=cnt, out_channels=output_dim, kernel_size=1, bias=False) 25 | 26 | def forward(self, fea_list): 27 | feature0 = fea_list[0] 28 | feature1 = fea_list[1] 29 | feature2 = fea_list[2] 30 | feature3 = fea_list[3] 31 | feature4 = fea_list[4] 32 | x = self.fea0(feature0) + feature0 33 | x = nn.Upsample(size = feature1.shape[-2:], mode='bilinear', align_corners=True)(x) 34 | x = torch.cat((x, feature1), dim=1) 35 | x = self.fea1(x) + x 36 | x = nn.Upsample(size = feature2.shape[-2:], mode='bilinear', align_corners=True)(x) 37 | x = torch.cat((x, feature2), dim=1) 38 | x = self.fea2(x) + x 39 | x = nn.Upsample(size = feature3.shape[-2:], mode='bilinear', align_corners=True)(x) 40 | x = torch.cat((x, feature3), dim=1) 41 | x = self.fea3(x) + x 42 | x = nn.Upsample(size = feature4.shape[-2:], mode='bilinear', align_corners=True)(x) 43 | x = torch.cat((x, feature4), dim=1) 44 | x = self.fea4(x) 45 | return x 46 | 47 | def create_model(ema=False, num_classes=4, train_encoder=True, train_decoder=True): 48 | # Network definition 49 | model = net_factory(net_type='unet', in_chns=1, 50 | class_num=num_classes, train_encoder=train_encoder, train_decoder=train_decoder) 51 | if ema: 52 | for param in model.parameters(): 53 | param.detach_() 54 | return model 55 | 56 | 57 | class ProjectionHead(nn.Module): 58 | def __init__(self, dim_in=4, proj_dim=4, output_pooling_size=16, proj='convmlp'): 59 | super(ProjectionHead, self).__init__() 60 | 61 | if proj == 'linear': 62 | self.proj = nn.Conv2d(dim_in, proj_dim, kernel_size=1) 63 | elif proj == 'convmlp': 64 | self.proj = nn.Sequential( 65 | nn.AdaptiveAvgPool2d(output_pooling_size), 66 | nn.Conv2d(dim_in, dim_in*2, kernel_size=1), 67 | nn.Conv2d(dim_in*2, proj_dim, kernel_size=1) 68 | ) 69 | 70 | def forward(self, x): 71 | return self.proj(x) 72 | 73 | 74 | class MLP(nn.Module): 75 | def __init__(self, input_channels=256, num_class=128, pooling_size=1): 76 | super().__init__() 77 | 78 | self.gap = nn.AdaptiveAvgPool2d(pooling_size) 79 | self.f1 = nn.Linear(input_channels*pooling_size**2, input_channels) 80 | self.f2 = nn.Linear(input_channels, num_class) 81 | 82 | def forward(self, x): 83 | x = self.gap(x) 84 | x = x.view(x.shape[0], -1) 85 | y = self.f1(x) 86 | y = self.f2(y) 87 | 88 | return y 89 | 90 | 91 | class RepresentationHead(nn.Module): 92 | def __init__(self, num_classes=256+128+64+32+16, output_channel=512): 93 | super(RepresentationHead, self).__init__() 94 | self.proj = nn.Sequential( 95 | nn.Conv2d(num_classes, output_channel, kernel_size=3, padding=1, bias=False), 96 | nn.Conv2d(output_channel, output_channel, kernel_size=1) 97 | ) 98 | def forward(self, 
x): 99 | return self.proj(x) 100 | 101 | 102 | class ISD(nn.Module): 103 | def __init__(self, K=48, m=0.99, Ts=0.1, Tt = 0.01, num_classes=4, train_encoder=True, train_decoder=True, 104 | latent_pooling_size=1, latent_feature_size=256, output_pooling_size=16, patch_size=64): # K=48 105 | super(ISD, self).__init__() 106 | 107 | self.K = K 108 | self.m = m 109 | self.Ts = Ts 110 | self.Tt = Tt 111 | self.num_classes = num_classes 112 | self.patch_size = patch_size 113 | self.latent_feature_size = latent_feature_size 114 | 115 | self.model = create_model(num_classes=num_classes, train_encoder=train_encoder, train_decoder=train_decoder) 116 | 117 | self.ema_model = create_model(ema=True, num_classes=num_classes, train_encoder=False, train_decoder=False) 118 | 119 | self.k_latent_head = MLP(input_channels=256, num_class=self.latent_feature_size, pooling_size=latent_pooling_size) 120 | 121 | self.q_latent_head = MLP(input_channels=256, num_class=self.latent_feature_size, pooling_size=latent_pooling_size) 122 | 123 | self.latent_predictor = nn.Sequential( 124 | nn.Linear(self.latent_feature_size, self.latent_feature_size), 125 | nn.Linear(self.latent_feature_size, self.latent_feature_size), 126 | ) 127 | 128 | 129 | self.k_outputs_head = ProjectionHead(dim_in=num_classes, proj_dim=num_classes, output_pooling_size=output_pooling_size) 130 | self.q_outputs_head = ProjectionHead(dim_in=num_classes, proj_dim=num_classes, output_pooling_size=output_pooling_size) 131 | 132 | 133 | self.outputs_predictor = nn.Sequential( 134 | nn.Conv2d(num_classes, num_classes, kernel_size=1), 135 | nn.Conv2d(num_classes, num_classes, kernel_size=1), 136 | ) 137 | 138 | # copy query encoder weights to key encoder 139 | for param_q, param_k in zip(self.model.parameters(), self.ema_model.parameters()): 140 | param_k.data.copy_(param_q.data) 141 | param_k.requires_grad = False 142 | 143 | # setup queue 144 | self.register_buffer('queue', torch.randn(self.K, self.latent_feature_size)) 145 | self.register_buffer('queue_mask', torch.randn(self.K, 36, num_classes*output_pooling_size**2)) 146 | 147 | # normalize the queue 148 | self.queue = nn.functional.normalize(self.queue, dim=0) 149 | self.queue_mask = nn.functional.normalize(self.queue_mask, dim=0) 150 | 151 | # setup the queue pointer 152 | self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long)) 153 | self.register_buffer('mask_queue_ptr', torch.zeros(1, dtype=torch.long)) 154 | 155 | 156 | @torch.no_grad() 157 | def _momentum_update_key_encoder(self): 158 | for param_q, param_k in zip(self.model.parameters(), self.ema_model.parameters()): 159 | param_k.data = param_k.data * self.m + param_q.data * (1. 
- self.m) 160 | 161 | 162 | @torch.no_grad() 163 | def data_parallel(self): 164 | self.model = torch.nn.DataParallel(self.model) 165 | self.ema_model = torch.nn.DataParallel(self.ema_model) 166 | self.k_latent_head = torch.nn.DataParallel(self.k_latent_head) 167 | self.q_latent_head = torch.nn.DataParallel(self.q_latent_head) 168 | self.latent_predictor = torch.nn.DataParallel(self.latent_predictor) 169 | self.k_outputs_head = torch.nn.DataParallel(self.k_outputs_head) 170 | self.q_outputs_head = torch.nn.DataParallel(self.q_outputs_head) 171 | self.outputs_predictor = torch.nn.DataParallel(self.outputs_predictor) 172 | 173 | 174 | @torch.no_grad() 175 | def _dequeue_and_enqueue(self, keys, queue, queue_ptr): 176 | batch_size = keys.shape[0] 177 | 178 | ptr = int(queue_ptr) 179 | assert self.K % batch_size == 0 180 | 181 | queue[ptr:ptr + batch_size] = keys 182 | ptr = (ptr + batch_size) % self.K # move pointer 183 | 184 | queue_ptr[0] = ptr 185 | 186 | 187 | def forward(self, im_q, im_k): 188 | 189 | batch_size = im_q.shape[0] 190 | 191 | if not self.training: 192 | outputs, latent_vector, _ = self.model(im_q) 193 | return outputs, latent_vector 194 | 195 | outputs, latent_vector, _ = self.model(im_q) 196 | outputs_tmp = outputs 197 | 198 | 199 | with torch.no_grad(): 200 | ema_inputs = im_k 201 | ema_output_tmp, _, _ = self.ema_model(ema_inputs) 202 | 203 | ############################################# calculate the KLD with anchors ################################# 204 | with torch.no_grad(): 205 | self._momentum_update_key_encoder() 206 | 207 | shuffle_ids, reverse_ids = get_shuffle_ids(im_k.shape[0]) 208 | im_k = im_k[shuffle_ids] 209 | 210 | ema_output, ema_latent_vector, _ = self.ema_model(im_k) 211 | 212 | ema_latent_vector = ema_latent_vector[reverse_ids] 213 | ema_output = ema_output[reverse_ids] 214 | 215 | queue = self.queue.clone().detach() 216 | queue_mask = self.queue_mask.clone().detach().transpose(0, 1).contiguous() 217 | patch_size = self.patch_size 218 | step = self.patch_size // 2 219 | output_stu_after_head = [] 220 | output_tea_after_head = [] 221 | for i in range(0, outputs.shape[2]-patch_size, step): 222 | for j in range(0, outputs.shape[3]-patch_size, step): 223 | output_stu_after_head.append(self.outputs_predictor(self.q_outputs_head(outputs[:, :, i:i+patch_size, j:j+patch_size]))) 224 | output_tea_after_head.append(self.k_outputs_head(ema_output[:, :, i:i+patch_size, j:j+patch_size])) 225 | 226 | 227 | output_stu_after_head = torch.cat(output_stu_after_head).reshape(batch_size, -1, 228 | output_stu_after_head[0].shape[1], 229 | output_stu_after_head[0].shape[2], 230 | output_stu_after_head[0].shape[3]).contiguous() 231 | 232 | 233 | output_tea_after_head = torch.cat(output_tea_after_head).reshape(batch_size, -1, 234 | output_tea_after_head[0].shape[1], 235 | output_tea_after_head[0].shape[2], 236 | output_tea_after_head[0].shape[3]).contiguous() 237 | 238 | desired_compressed_lat_k = self.k_latent_head(ema_latent_vector) 239 | 240 | desired_compressed_lat_q = self.latent_predictor(self.q_latent_head(latent_vector)) 241 | 242 | output_tea_after_head_tmp = output_tea_after_head.reshape(output_tea_after_head.shape[0], output_tea_after_head.shape[1], -1).contiguous() 243 | 244 | output_stu_after_head = output_stu_after_head.reshape((output_stu_after_head.shape[1], batch_size, -1)).contiguous() 245 | output_tea_after_head = output_tea_after_head.reshape((output_tea_after_head.shape[1], batch_size, -1)).contiguous() 246 | 247 | output_stu_after_head = 
output_stu_after_head.reshape(-1, output_stu_after_head.shape[0]).contiguous() # xxx*36 248 | output_tea_after_head = output_tea_after_head.reshape(-1, output_tea_after_head.shape[0]).contiguous() # xxx*36 249 | 250 | queue_mask = queue_mask.reshape(-1, queue_mask.shape[0]).contiguous() 251 | 252 | # compute the 4 logits 253 | ema_latent_logits = compute_logits(desired_compressed_lat_k, queue, self.Tt) 254 | latent_logits = compute_logits(desired_compressed_lat_q, queue, self.Ts) 255 | 256 | ema_output_logits = compute_logits(output_tea_after_head, queue_mask, self.Tt) 257 | output_logits = compute_logits(output_stu_after_head, queue_mask, self.Ts) 258 | 259 | self._dequeue_and_enqueue(desired_compressed_lat_k, self.queue, self.queue_ptr) 260 | self._dequeue_and_enqueue(output_tea_after_head_tmp, self.queue_mask, self.mask_queue_ptr) 261 | 262 | return outputs_tmp, ema_output_tmp, ema_latent_logits, latent_logits, ema_output_logits, output_logits 263 | 264 | 265 | 266 | def get_shuffle_ids(bsz): 267 | forward_inds = torch.randperm(bsz).long().cuda() 268 | backward_inds = torch.zeros(bsz).long().cuda() 269 | value = torch.arange(bsz).long().cuda() 270 | backward_inds.index_copy_(0, forward_inds, value) 271 | return forward_inds, backward_inds 272 | 273 | def compute_logits(z_anchor, z_positive, temp_fac): 274 | z_anchor = nn.functional.normalize(z_anchor, dim=1) 275 | z_positive = nn.functional.normalize(z_positive, dim=1) 276 | logits_out = torch.matmul(z_anchor.cuda(), z_positive.T.cuda())/temp_fac 277 | return logits_out 278 | 279 | 280 | if __name__ == '__main__': 281 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' 282 | torch.cuda.empty_cache() 283 | model = ISD(K=24, m=0.99, Ts=0.07, Tt=0.01, num_classes=4, train_encoder=True, train_decoder=True, 284 | latent_pooling_size=1, latent_feature_size=512, output_pooling_size=8, patch_size=64).cuda() 285 | model.data_parallel() 286 | input = torch.ones([6, 1, 256, 256]).cuda() 287 | for i in range(5): 288 | outputs, ema_output, ema_latent_logits, latent_logits, ema_output_logits, output_logits= model(input, input) 289 | # exit() 290 | print(outputs.shape) -------------------------------------------------------------------------------- /code/augment_3d.py: -------------------------------------------------------------------------------- 1 | from torch import equal 2 | from torch.utils.data import DataLoader 3 | # from dataloaders.la_heart import * 4 | import random 5 | from PIL import ImageFilter 6 | 7 | from operator import index 8 | import os 9 | import torch 10 | import numpy as np 11 | from glob import glob 12 | from torch.utils.data import Dataset 13 | import h5py 14 | import itertools 15 | from torch.utils.data.sampler import Sampler 16 | import torchvision.transforms.functional as transforms_f 17 | from torchvision.transforms import * 18 | from PIL.ImageEnhance import * 19 | from PIL import Image 20 | import copy 21 | from scipy import ndimage 22 | from scipy.ndimage.interpolation import zoom 23 | try: # SciPy >= 0.19 24 | from scipy.special import comb 25 | except ImportError: 26 | from scipy.misc import comb 27 | from adv_morph import * 28 | 29 | 30 | def bezier_curve(points, nTimes=1000): 31 | """ 32 | Given a set of control points, return the 33 | bezier curve defined by the control points. 
34 | Control points should be a list of lists, or list of tuples 35 | such as [ [1,1], 36 | [2,3], 37 | [4,5], ..[Xn, Yn] ] 38 | nTimes is the number of time steps, defaults to 1000 39 | See http://processingjs.nihongoresources.com/bezierinfo/ 40 | """ 41 | 42 | nPoints = len(points) 43 | xPoints = np.array([p[0] for p in points]) 44 | yPoints = np.array([p[1] for p in points]) 45 | 46 | t = np.linspace(0.0, 1.0, nTimes) 47 | 48 | polynomial_array = np.array([ bernstein_poly(i, nPoints-1, t) for i in range(0, nPoints) ]) 49 | 50 | xvals = np.dot(xPoints, polynomial_array) 51 | yvals = np.dot(yPoints, polynomial_array) 52 | 53 | return xvals, yvals 54 | 55 | def bernstein_poly(i, n, t): 56 | """ 57 | The Bernstein polynomial of n, i as a function of t 58 | """ 59 | 60 | return comb(n, i) * ( t**(n-i) ) * (1 - t)**i 61 | 62 | 63 | def local_pixel_shuffling(x, prob=0.5): 64 | if random.random() >= prob: 65 | return x 66 | image_temp = copy.deepcopy(x) 67 | orig_image = copy.deepcopy(x) 68 | _, img_rows, img_cols = x.shape 69 | num_block = 10000 70 | for _ in range(num_block): 71 | block_noise_size_x = random.randint(1, img_rows//10) 72 | block_noise_size_y = random.randint(1, img_cols//10) 73 | noise_x = random.randint(0, img_rows-block_noise_size_x) 74 | noise_y = random.randint(0, img_cols-block_noise_size_y) 75 | window = orig_image[0, noise_x:noise_x+block_noise_size_x, 76 | noise_y:noise_y+block_noise_size_y 77 | ] 78 | window = window.flatten() 79 | np.random.shuffle(window) 80 | window = window.reshape((block_noise_size_x, 81 | block_noise_size_y)) 82 | image_temp[0, noise_x:noise_x+block_noise_size_x, 83 | noise_y:noise_y+block_noise_size_y] = window 84 | local_shuffling_x = image_temp 85 | 86 | return local_shuffling_x 87 | 88 | def nonlinear_transformation(x, prob=0.5): 89 | if random.random() >= prob: 90 | return x 91 | points = [[0, 0], [random.random(), random.random()], [random.random(), random.random()], [1, 1]] 92 | xpoints = [p[0] for p in points] 93 | ypoints = [p[1] for p in points] 94 | xvals, yvals = bezier_curve(points, nTimes=100000) 95 | if random.random() < 0.5: 96 | # Half chance to get flip 97 | xvals = np.sort(xvals) 98 | else: 99 | xvals, yvals = np.sort(xvals), np.sort(yvals) 100 | nonlinear_x = np.interp(x, xvals, yvals) 101 | return nonlinear_x 102 | 103 | def image_in_painting(x): 104 | _, img_rows, img_cols = x.shape 105 | cnt = 5 106 | while cnt > 0 and random.random() < 0.95: 107 | block_noise_size_x = random.randint(img_rows//6, img_rows//3) 108 | block_noise_size_y = random.randint(img_cols//6, img_cols//3) 109 | noise_x = random.randint(3, img_rows-block_noise_size_x-3) 110 | noise_y = random.randint(3, img_cols-block_noise_size_y-3) 111 | x[:, 112 | noise_x:noise_x+block_noise_size_x, 113 | noise_y:noise_y+block_noise_size_y] = np.random.rand(block_noise_size_x, 114 | block_noise_size_y) * 1.0 115 | cnt -= 1 116 | return x 117 | 118 | 119 | def image_out_painting(x): 120 | _, img_rows, img_cols = x.shape 121 | image_temp = copy.deepcopy(x) 122 | x = np.random.rand(x.shape[0], x.shape[1], x.shape[2]) * 1.0  # x is 3-D (C, H, W) here, so only three axes are sampled 123 | block_noise_size_x = img_rows - random.randint(3*img_rows//7, 4*img_rows//7) 124 | block_noise_size_y = img_cols - random.randint(3*img_cols//7, 4*img_cols//7) 125 | noise_x = random.randint(3, img_rows-block_noise_size_x-3) 126 | noise_y = random.randint(3, img_cols-block_noise_size_y-3) 127 | x[:, 128 | noise_x:noise_x+block_noise_size_x, 129 | noise_y:noise_y+block_noise_size_y] = image_temp[:, noise_x:noise_x+block_noise_size_x, 
130 | noise_y:noise_y+block_noise_size_y] 131 | return x 132 | 133 | def transform(image, label, logits=None, crop_size=(256, 256), scale_size=(0.8, 1.0), augmentation=True): 134 | # Random rescale image 135 | # raw_w, raw_h, raw_d = image.shape 136 | if augmentation: 137 | # Random color jitter 138 | if torch.rand(1) > 0.5: 139 | # For PyTorch 1.9 / torchvision 0.10 users: color_transform = transforms.ColorJitter((0.75, 1.25), (0.75, 1.25), (0.75, 1.25), (-0.25, 0.25)) 140 | # color_transform = transforms.ColorJitter.get_params((0.75, 1.25), (0.75, 1.25), (0.75, 1.25), (-0.25, 0.25)) 141 | color_transform = ColorJitter((0.98, 1.02), (0.98, 1.02), (0.98, 1.02), (-0.01, 0.01))  # ColorJitter comes from the torchvision.transforms star import; the bare 'transforms' name is not bound in this module 142 | # print(type(color_transform)) 143 | for d in range(image.shape[-1]): 144 | image[:, :, d] = color_transform(image[:, :, d]) 145 | 146 | # Random Gaussian filter 147 | if torch.rand(1) > 0.5: 148 | sigma = random.uniform(0.15, 1.15) 149 | for d in range(image.shape[-1]): 150 | image[:, :, d] = \ 151 | torch.FloatTensor(np.array(ToPILImage()(image[:, :, d])\ 152 | .filter(ImageFilter.GaussianBlur(radius=sigma)))) 153 | 154 | # Transform to tensor 155 | # image = transforms_f.to_tensor(image) 156 | if logits is not None: 157 | return image, label, logits 158 | else: 159 | return image, label 160 | 161 | 162 | # def denormalise(x, imagenet=True): 163 | # # if imagenet: 164 | # # x = transforms_f.normalize(x, mean=[0., 0., 0.], std=[1 / 0.229, 1 / 0.224, 1 / 0.225]) 165 | # # x = transforms_f.normalize(x, mean=[-0.485, -0.456, -0.406], std=[1., 1., 1.]) 166 | # # return x 167 | # # else: 168 | # return (x + 1) / 2 169 | 170 | 171 | # def tensor_to_pil(im, label, logits): 172 | # # im = denormalise(im) 173 | # im = transforms_f.to_pil_image(im.cpu()) 174 | 175 | # label = label.float() / 255. 
176 | # label = transforms_f.to_pil_image(label.unsqueeze(0).cpu()) 177 | 178 | # logits = transforms_f.to_pil_image(logits.unsqueeze(0).cpu()) 179 | # return im, label, logits 180 | 181 | 182 | def generate_cutout_mask_3d(img_size, ratio=2, dep=80): 183 | cutout_area = img_size[0] * img_size[1] / ratio 184 | 185 | w = np.random.randint(img_size[1] / ratio + 1, img_size[1]) 186 | h = np.round(cutout_area / w) 187 | 188 | x_start = np.random.randint(0, img_size[1] - w + 1) 189 | y_start = np.random.randint(0, img_size[0] - h + 1) 190 | z_start = np.random.randint(0, img_size[2] - 20 + 1) 191 | 192 | x_end = int(x_start + w) 193 | y_end = int(y_start + h) 194 | z_end = int(z_start + 20) 195 | 196 | mask = torch.ones(img_size) 197 | mask[y_start:y_end, x_start:x_end, z_start:z_end] = 0 198 | return mask.float() 199 | 200 | 201 | def generate_class_mask(pseudo_labels): 202 | labels = torch.unique(pseudo_labels) # all unique labels 203 | labels_select = labels[torch.randperm(len(labels))][:len(labels) // 2] # randomly select half of labels 204 | 205 | mask = (pseudo_labels.unsqueeze(-1) == labels_select).any(-1) 206 | return mask.float() 207 | 208 | 209 | def batch_transform(data, label, logits, scale_size, apply_augmentation): 210 | data_list, label_list, logits_list = [], [], [] 211 | data_size = data.shape 212 | 213 | for k in range(data.shape[0]): 214 | data_pil, label_pil, logits_pil = (data[k], label[k], logits[k]) 215 | aug_data, aug_label, aug_logits = transform(data_pil, label_pil, logits_pil, 216 | scale_size=scale_size, 217 | augmentation=apply_augmentation) 218 | data_list.append(aug_data.unsqueeze(0)) 219 | label_list.append(aug_label.unsqueeze(0)) 220 | logits_list.append(aug_logits.unsqueeze(0)) 221 | 222 | data_trans, label_trans, logits_trans = \ 223 | torch.cat(data_list).cuda(), torch.cat(label_list).cuda(), torch.cat(logits_list).cuda() 224 | 225 | return data_trans, label_trans, logits_trans 226 | 227 | 228 | def generate_unsup_data_3d(data, target, logits, mode='cutout'): 229 | batch_size, _, im_h, im_w, im_z = data.shape 230 | device = data.device 231 | 232 | new_data = [] 233 | new_target = [] 234 | new_logits = [] 235 | for i in range(batch_size): 236 | if mode == 'cutout': 237 | mix_mask = generate_cutout_mask_3d([im_h, im_w, im_z], ratio=2).to(device) 238 | target[i][(1 - mix_mask).bool()] = -1 239 | 240 | new_data.append((data[i] * mix_mask).unsqueeze(0)) 241 | new_target.append(target[i].unsqueeze(0)) 242 | new_logits.append((logits[i] * mix_mask).unsqueeze(0)) 243 | continue 244 | 245 | if mode == 'cutmix': 246 | mix_mask = generate_cutout_mask_3d([im_h, im_w, im_z]).to(device) 247 | if mode == 'classmix': 248 | mix_mask = generate_class_mask(target[i]).to(device) 249 | 250 | new_data.append((data[i] * mix_mask + data[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 251 | new_target.append((target[i] * mix_mask + target[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 252 | new_logits.append((logits[i] * mix_mask + logits[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 253 | 254 | new_data, new_target, new_logits = torch.cat(new_data), torch.cat(new_target), torch.cat(new_logits) 255 | return new_data, new_target.long(), new_logits 256 | 257 | 258 | 259 | def random_rot_flip(image, label, logit): 260 | k = np.random.randint(0, 4) 261 | image = np.rot90(image, k) 262 | label = np.rot90(label, k) 263 | logit = np.rot90(logit, k) 264 | 265 | axis = np.random.randint(0, 2) 266 | image = np.flip(image, axis=axis).copy() 267 | label = np.flip(label, 
axis=axis).copy() 268 | logit = np.flip(logit, axis=axis).copy() 269 | return image, label, logit 270 | 271 | 272 | def random_rotate(image, label, logit): 273 | angle = np.random.randint(-20, 20) 274 | image = ndimage.rotate(image, angle, order=0, reshape=False) 275 | label = ndimage.rotate(label, angle, order=0, reshape=False) 276 | logit = ndimage.rotate(logit, angle, order=0, reshape=False) 277 | return image, label, logit 278 | 279 | 280 | 281 | def randomGeneratorWithLogits(image, label, logit, output_size=[256, 256]): 282 | new_data = [] 283 | new_target = [] 284 | new_logits = [] 285 | _, _, x, y = image.shape 286 | for i in range(image.shape[0]): 287 | image_i = image[i, 0, :, :].data.cpu().numpy() 288 | label_i = label[i, :, :].data.cpu().numpy() 289 | logit_i = logit[i, :, :].data.cpu().numpy() 290 | 291 | image_i = zoom( 292 | image_i, (output_size[0] / x, output_size[1] / y), order=0) 293 | label_i = zoom( 294 | label_i, (output_size[0] / x, output_size[1] / y), order=0) 295 | logit_i = zoom( 296 | logit_i, (output_size[0] / x, output_size[1] / y), order=0) 297 | 298 | new_data.append(image_i) 299 | new_target.append(label_i) 300 | new_logits.append(logit_i) 301 | 302 | image = torch.from_numpy(np.array(new_data).astype(np.float32)).unsqueeze(1) 303 | label = torch.from_numpy(np.array(new_target).astype(np.uint8)).long() 304 | logit = torch.from_numpy(np.array(new_logits)) 305 | return image, label, logit 306 | 307 | 308 | # if __name__ == '__main__': 309 | # trainloader = DataLoader(db_train_teacher, batch_sampler=batch_sampler, num_workers=4, pin_memory=True,worker_init_fn=worker_init_fn) 310 | 311 | # for i_batch, batch in enumerate(trainloader): 312 | # # print(batch.shape) 313 | # teacher_batch = transform_teacher(batch) 314 | # student_batch = transform_student(batch) 315 | # # print(teacher) 316 | # teacher_batch, teacher_label = teacher_batch['image'], teacher_batch['label'] 317 | # student_batch, student_label = student_batch['image'], student_batch['label'] 318 | # # print(torch.max(teacher_batch)) 319 | # # teacher_batch = transform_teacher(teacher_batch) 320 | # # student_batch = transform_student(student_batch) 321 | # # print(teacher_batch.shape) 322 | # # print(teacher_label.shape) 323 | # print(torch.max(teacher_batch)) 324 | # print(torch.max(student_batch)) 325 | # # print(type(teacher_label)) 326 | # exit() 327 | # # isEqual = teacher_batch.eq(student_batch) 328 | # # if not (isEqual.all()): 329 | # # print('problem at index ', i_batch) 330 | # # exit() 331 | # print('no pb.') --------------------------------------------------------------------------------