├── README.md
├── code
│   ├── dataloaders
│   │   ├── la_heart.py
│   │   └── la_heart_processing.py
│   ├── networks
│   │   └── vnet.py
│   ├── test_LA.py
│   ├── test_util.py
│   ├── train_LA.py
│   ├── train_LA_UPC.py
│   └── utils
│       ├── losses.py
│       └── ramps.py
└── data
    ├── test.list
    └── train.list

/README.md:
--------------------------------------------------------------------------------
 1 | # UPC: Uncertainty-aware Pseudo-label and Consistency for Semi-supervised Medical Image Segmentation
 2 | 
 3 | by [Liyun Lu](https://github.com/liyun-lu), Mengxiao Yin, Liyao Fu, Feng Yang.
 4 | 
 5 | ## Introduction
 6 | This repository is the PyTorch implementation of "Uncertainty-aware Pseudo-label and Consistency for Semi-supervised Medical Image Segmentation".
 7 | 
 8 | ## Requirements
 9 | We ran our experiments on the supercomputing system of Guangxi University, with the following configuration:
10 | * CentOS 7.4
11 | * NVIDIA Tesla V100 32G
12 | * Intel Xeon Gold 6230 2.1 GHz 20-core processor
13 | 
14 | Some important required packages:
15 | * CUDA 10.1
16 | * Pytorch == 1.6.0
17 | * Python == 3.8
18 | * Basic Python packages such as NumPy, Scikit-image, SimpleITK, and SciPy
19 | 
20 | ## Usage
21 | 
22 | 1. Clone the repo:
23 | ```
24 | git clone https://github.com/GXU-GMU-MICCAI/UPC-Pytorch.git
25 | cd UPC-Pytorch
26 | ```
27 | 
28 | 2. Download the Left Atrium dataset from [Google drive](https://drive.google.com/file/d/1CKEtfOGRQhjySYf4MnTgrdEOcuYbBC2t/view?usp=sharing).
29 | Put the data in the './data/' folder, then run the preprocessing script:
30 | ```
31 | cd code/dataloaders
32 | python la_heart_processing.py
33 | ```
34 | 
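After these two steps the loaders expect roughly the following layout (inferred from the paths used in `la_heart.py` and `la_heart_processing.py`; `<case_id>` stands for the directory names listed in `train.list`/`test.list`, and `mri_norm2.h5` is written next to each case's `.nrrd` files):
```
data/
├── train.list
├── test.list
└── Training_Set/
    └── <case_id>/
        ├── lgemri.nrrd
        ├── laendo.nrrd
        └── mri_norm2.h5
```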
35 | 3. Train the model
36 | ```
37 | cd code
38 | python train_LA_UPC.py
39 | ```
40 | 
41 | 4. Test the model
42 | ```
43 | python test_LA.py
44 | ```
45 | 
46 | ## Citation
47 | 
48 | ## Acknowledgement
49 | Part of the code is adapted from [UA-MT](https://github.com/yulequan/UA-MT).
50 | 
51 | We thank Dr. Lequan Yu for the elegant and efficient codebase.
52 | 
53 | ## Note
54 | * The repository is being updated.
55 | * Contact: Liyun Lu (luly1061@163.com)
56 | 
--------------------------------------------------------------------------------
/code/dataloaders/la_heart.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import torch
  3 | import numpy as np
  4 | from glob import glob
  5 | from torch.utils.data import Dataset
  6 | import h5py
  7 | import itertools
  8 | from torch.utils.data.sampler import Sampler
  9 | 
 10 | class LAHeart(Dataset):
 11 |     """ LA Dataset """
 12 |     def __init__(self, base_dir=None, split='train', num=None, transform=None):
 13 |         self._base_dir = base_dir  # dataset path
 14 |         self.transform = transform  # data augmentation transform
 15 |         self.sample_list = []
 16 |         if split=='train':  # 80 training cases
 17 |             with open(self._base_dir+'/../train.list', 'r') as f:
 18 |                 self.image_list = f.readlines()
 19 |         elif split == 'test':  # 20 test cases
 20 |             with open(self._base_dir+'/../test.list', 'r') as f:
 21 |                 self.image_list = f.readlines()
 22 |         self.image_list = [item.replace('\n','') for item in self.image_list]
 23 |         if num is not None:  # num limits how many cases are loaded
 24 |             self.image_list = self.image_list[:num]
 25 |         print("total {} samples".format(len(self.image_list)))
 26 | 
 27 |     def __len__(self):
 28 |         return len(self.image_list)
 29 | 
 30 |     def __getitem__(self, idx):
 31 |         image_name = self.image_list[idx]
 32 |         h5f = h5py.File(self._base_dir+"/"+image_name+"/mri_norm2.h5", 'r')
 33 |         image = h5f['image'][:]
 34 |         label = h5f['label'][:]
 35 |         sample = {'image': image, 'label': label}
 36 |         if self.transform:  # apply data augmentation
 37 |             sample = self.transform(sample)
 38 | 
 39 |         return sample
 40 | 
 41 | # center-crop the volume
 42 | class CenterCrop(object):
 43 |     def __init__(self, output_size):
 44 |         self.output_size = output_size
 45 | 
 46 |     def __call__(self, sample):
 47 |         image, label = sample['image'], sample['label']
 48 | 
 49 |         # pad the sample if necessary
 50 |         if label.shape[0] <= self.output_size[0] or label.shape[1] <= self.output_size[1] or label.shape[2] <= \
 51 |                 self.output_size[2]:
 52 |             pw = max((self.output_size[0] - label.shape[0]) // 2 + 3, 0)
 53 |             ph = max((self.output_size[1] - label.shape[1]) // 2 + 3, 0)
 54 |             pd = max((self.output_size[2] - label.shape[2]) // 2 + 3, 0)
 55 |             # pad with numpy.pad(array, pad_width, mode, **kwargs)
 56 |             image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
 57 |             label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
 58 | 
 59 |         (w, h, d) = image.shape
 60 | 
 61 |         w1 = int(round((w - self.output_size[0]) / 2.))
 62 |         h1 = int(round((h - self.output_size[1]) / 2.))
 63 |         d1 = int(round((d - self.output_size[2]) / 2.))
 64 | 
 65 |         label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
 66 |         image = image[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
 67 | 
 68 |         return {'image': image, 'label': label}
 69 | 
 70 | # randomly crop the volume
 71 | class RandomCrop(object):
 72 |     """
 73 |     Randomly crop the image in a sample
 74 |     Args:
 75 |         output_size (int): Desired output size
 76 |     """
 77 | 
 78 |     def __init__(self, output_size):
 79 |         self.output_size = output_size
 80 | 
 81 |     def __call__(self, sample):
 82 |         image, label = sample['image'], sample['label']
 83 | 
 84 |         # pad the sample if necessary
 85 |         if label.shape[0] <= self.output_size[0] or label.shape[1] <= self.output_size[1] or label.shape[2] <= \
 86 |                 self.output_size[2]:
 87 |             pw = max((self.output_size[0] - label.shape[0]) // 2 + 3, 0)
 88 |             ph = max((self.output_size[1] - label.shape[1]) // 2 + 3, 0)
 89 |             pd = max((self.output_size[2] - label.shape[2]) // 2 + 3, 0)
 90 |             image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
 91 |             label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
 92 | 
 93 |         (w, h, d) = image.shape
 94 |         # if np.random.uniform() > 0.33:
 95 |         #     w1 = np.random.randint((w - self.output_size[0])//4, 3*(w - self.output_size[0])//4)
 96 |         #     h1 = np.random.randint((h - self.output_size[1])//4, 3*(h - self.output_size[1])//4)
 97 |         # else:
 98 |         w1 = np.random.randint(0, w - self.output_size[0])  # random offset within the valid range
 99 |         h1 = np.random.randint(0, h - self.output_size[1])
100 |         d1 = np.random.randint(0, d - self.output_size[2])
101 | 
102 |         label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
103 |         image = image[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
104 |         return {'image': image, 'label': label}
105 | 
106 | # random rotation and flip
107 | class RandomRotFlip(object):
108 |     """
109 |     Randomly rotate the image/label pair by a multiple of 90 degrees,
110 |     then flip it along a randomly chosen axis (the same rotation and
111 |     flip are applied to image and label).
112 |     """
113 | 
114 |     def __call__(self, sample):
115 |         image, label = sample['image'], sample['label']
116 |         k = np.random.randint(0, 4)
117 |         image = np.rot90(image, k)  # rotate by a random multiple of 90 degrees
118 |         label = np.rot90(label, k)
119 |         axis = np.random.randint(0, 2)  # axis to flip along: 0 or 1
120 |         image = np.flip(image, axis=axis).copy()  # .copy() returns a contiguous deep copy
121 |         label = np.flip(label, axis=axis).copy()
122 | 
123 |         return {'image': image, 'label': label}
124 | 
125 | # additive random noise
126 | class RandomNoise(object):
127 |     def __init__(self, mu=0, sigma=0.1):
128 |         self.mu = mu
129 |         self.sigma = sigma
130 | 
131 |     def __call__(self, sample):
132 |         image, label = sample['image'], sample['label']
133 |         # np.clip truncates the noise to the range [-2*sigma, 2*sigma]
134 |         noise = np.clip(self.sigma * np.random.randn(image.shape[0], image.shape[1], image.shape[2]), -2*self.sigma, 2*self.sigma)
135 |         noise = noise + self.mu
136 |         image = image + noise
137 |         return {'image': image, 'label': label}
138 | 
139 | 
140 | class CreateOnehotLabel(object):
141 |     def __init__(self, num_classes):
142 |         self.num_classes = num_classes
143 | 
144 |     def __call__(self, sample):
145 |         image, label = sample['image'], sample['label']
146 |         onehot_label = np.zeros((self.num_classes, label.shape[0], label.shape[1], label.shape[2]), dtype=np.float32)
147 |         for i in range(self.num_classes):
148 |             onehot_label[i, :, :, :] = (label == i).astype(np.float32)
149 |         return {'image': image, 'label': label, 'onehot_label': onehot_label}
150 | 
151 | 
152 | class ToTensor(object):
153 |     """Convert ndarrays in sample to Tensors."""
154 | 
155 |     def __call__(self, sample):
156 |         image = sample['image']
157 |         image = image.reshape(1, image.shape[0], image.shape[1], image.shape[2]).astype(np.float32)
158 |         if 'onehot_label' in sample:
159 |             return {'image': torch.from_numpy(image), 'label': torch.from_numpy(sample['label']).long(),
160 |                     'onehot_label': torch.from_numpy(sample['onehot_label']).long()}
161 |         else:
162 |             return {'image': torch.from_numpy(image), 'label': torch.from_numpy(sample['label']).long()}
163 | 
164 | 
165 | class TwoStreamBatchSampler(Sampler):
166 |     """Iterate two sets of indices
167 | 
168 |     An 'epoch' is one iteration through the primary indices.
169 |     During the epoch, the secondary indices are iterated through
170 |     as many times as needed.
171 |     """
172 |     def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
173 |         self.primary_indices = primary_indices
174 |         self.secondary_indices = secondary_indices
175 |         self.secondary_batch_size = secondary_batch_size
176 |         self.primary_batch_size = batch_size - secondary_batch_size
177 | 
178 |         assert len(self.primary_indices) >= self.primary_batch_size > 0
179 |         assert len(self.secondary_indices) >= self.secondary_batch_size > 0
180 | 
181 |     def __iter__(self):
182 |         primary_iter = iterate_once(self.primary_indices)
183 |         secondary_iter = iterate_eternally(self.secondary_indices)
184 |         return (
185 |             primary_batch + secondary_batch
186 |             for (primary_batch, secondary_batch)
187 |             in zip(grouper(primary_iter, self.primary_batch_size),
188 |                    grouper(secondary_iter, self.secondary_batch_size))
189 |         )
190 | 
191 |     def __len__(self):
192 |         return len(self.primary_indices) // self.primary_batch_size
193 | 
194 | def iterate_once(iterable):
195 |     return np.random.permutation(iterable)
196 | 
197 | 
198 | def iterate_eternally(indices):
199 |     def infinite_shuffles():
200 |         while True:
201 |             yield np.random.permutation(indices)
202 |     return itertools.chain.from_iterable(infinite_shuffles())
203 | 
204 | 
205 | def grouper(iterable, n):
206 |     "Collect data into fixed-length chunks or blocks"
207 |     # grouper('ABCDEFG', 3) --> ABC DEF
208 |     args = [iter(iterable)] * n
209 |     return zip(*args)
--------------------------------------------------------------------------------
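As a quick, illustrative sketch (not part of the repository), this is how `train_LA_UPC.py` drives `TwoStreamBatchSampler` with its settings of 80 cases, 16 of them labeled, batch size 4, and 2 labeled samples per batch:
```
labeled_idxs = list(range(16))
unlabeled_idxs = list(range(16, 80))
sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs,
                                batch_size=4, secondary_batch_size=2)
print(len(sampler))        # 8 batches per epoch (16 labeled // 2 per batch)
for batch in sampler:      # each batch holds 2 labeled + 2 unlabeled indices
    print(batch)           # e.g. (7, 3, 52, 41)
    break
```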
171 | """ 172 | def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size): 173 | self.primary_indices = primary_indices 174 | self.secondary_indices = secondary_indices 175 | self.secondary_batch_size = secondary_batch_size 176 | self.primary_batch_size = batch_size - secondary_batch_size 177 | 178 | assert len(self.primary_indices) >= self.primary_batch_size > 0 179 | assert len(self.secondary_indices) >= self.secondary_batch_size > 0 180 | 181 | def __iter__(self): 182 | primary_iter = iterate_once(self.primary_indices) 183 | secondary_iter = iterate_eternally(self.secondary_indices) 184 | return ( 185 | primary_batch + secondary_batch 186 | for (primary_batch, secondary_batch) 187 | in zip(grouper(primary_iter, self.primary_batch_size), 188 | grouper(secondary_iter, self.secondary_batch_size)) 189 | ) 190 | 191 | def __len__(self): 192 | return len(self.primary_indices) // self.primary_batch_size 193 | 194 | def iterate_once(iterable): 195 | return np.random.permutation(iterable) 196 | 197 | 198 | def iterate_eternally(indices): 199 | def infinite_shuffles(): 200 | while True: 201 | yield np.random.permutation(indices) 202 | return itertools.chain.from_iterable(infinite_shuffles()) 203 | 204 | 205 | def grouper(iterable, n): 206 | "Collect data into fixed-length chunks or blocks" 207 | # grouper('ABCDEFG', 3) --> ABC DEF" 208 | args = [iter(iterable)] * n 209 | return zip(*args) 210 | -------------------------------------------------------------------------------- /code/dataloaders/la_heart_processing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from glob import glob 3 | from tqdm import tqdm 4 | import h5py 5 | import nrrd 6 | 7 | output_size =[112, 112, 80] # 裁剪图片成为112*112*80 8 | # 160,192,224 9 | # nnrd包读取后,预处理后转变为h5py文件 10 | def covert_h5(): 11 | listt = glob('E:/project/Master-1/code/segmetation-UAMT/UA-MT/data/Training_Set/*/lgemri.nrrd') 12 | print(listt) 13 | 14 | for item in tqdm(listt): 15 | image, img_header = nrrd.read(item) 16 | label, gt_header = nrrd.read(item.replace('lgemri.nrrd', 'laendo.nrrd')) 17 | label = (label == 255).astype(np.uint8) 18 | w, h, d = label.shape 19 | 20 | tempL = np.nonzero(label) # 得到数组array中非零元素的位置(下标) 21 | minx, maxx = np.min(tempL[0]), np.max(tempL[0]) 22 | miny, maxy = np.min(tempL[1]), np.max(tempL[1]) 23 | minz, maxz = np.min(tempL[2]), np.max(tempL[2]) 24 | 25 | px = max(output_size[0] - (maxx - minx), 0) // 2 26 | py = max(output_size[1] - (maxy - miny), 0) // 2 27 | pz = max(output_size[2] - (maxz - minz), 0) // 2 28 | minx = max(minx - np.random.randint(10, 20) - px, 0) 29 | maxx = min(maxx + np.random.randint(10, 20) + px, w) 30 | miny = max(miny - np.random.randint(10, 20) - py, 0) 31 | maxy = min(maxy + np.random.randint(10, 20) + py, h) 32 | minz = max(minz - np.random.randint(5, 10) - pz, 0) 33 | maxz = min(maxz + np.random.randint(5, 10) + pz, d) 34 | 35 | image = (image - np.mean(image)) / np.std(image) # 标准化/0均值化,将数据中心转移到原点处,np.std计算标准差 36 | image = image.astype(np.float32) # 转为浮点数 37 | image = image[minx:maxx, miny:maxy] 38 | label = label[minx:maxx, miny:maxy] 39 | print(label.shape) 40 | f = h5py.File(item.replace('lgemri.nrrd', 'mri_norm2.h5'), 'w') 41 | f.create_dataset('image', data=image, compression="gzip") 42 | f.create_dataset('label', data=label, compression="gzip") 43 | f.close() 44 | 45 | 46 | if __name__ == '__main__': 47 | covert_h5() -------------------------------------------------------------------------------- 
/code/networks/vnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | class ConvBlock(nn.Module): 6 | def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'): 7 | super(ConvBlock, self).__init__() 8 | 9 | ops = [] 10 | for i in range(n_stages): 11 | if i==0: 12 | input_channel = n_filters_in 13 | else: 14 | input_channel = n_filters_out 15 | 16 | ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1)) 17 | if normalization == 'batchnorm': 18 | ops.append(nn.BatchNorm3d(n_filters_out)) 19 | elif normalization == 'groupnorm': 20 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 21 | elif normalization == 'instancenorm': 22 | ops.append(nn.InstanceNorm3d(n_filters_out)) 23 | elif normalization != 'none': 24 | assert False 25 | ops.append(nn.ReLU(inplace=True)) 26 | 27 | self.conv = nn.Sequential(*ops) 28 | 29 | def forward(self, x): 30 | x = self.conv(x) 31 | return x 32 | 33 | 34 | class ResidualConvBlock(nn.Module): 35 | def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'): 36 | super(ResidualConvBlock, self).__init__() 37 | 38 | ops = [] 39 | for i in range(n_stages): 40 | if i == 0: 41 | input_channel = n_filters_in 42 | else: 43 | input_channel = n_filters_out 44 | 45 | ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1)) 46 | if normalization == 'batchnorm': 47 | ops.append(nn.BatchNorm3d(n_filters_out)) 48 | elif normalization == 'groupnorm': 49 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 50 | elif normalization == 'instancenorm': 51 | ops.append(nn.InstanceNorm3d(n_filters_out)) 52 | elif normalization != 'none': 53 | assert False 54 | 55 | if i != n_stages-1: 56 | ops.append(nn.ReLU(inplace=True)) 57 | 58 | self.conv = nn.Sequential(*ops) 59 | self.relu = nn.ReLU(inplace=True) 60 | 61 | def forward(self, x): 62 | x = (self.conv(x) + x) 63 | x = self.relu(x) 64 | return x 65 | 66 | 67 | class DownsamplingConvBlock(nn.Module): 68 | def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'): 69 | super(DownsamplingConvBlock, self).__init__() 70 | 71 | ops = [] 72 | if normalization != 'none': 73 | ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 74 | if normalization == 'batchnorm': 75 | ops.append(nn.BatchNorm3d(n_filters_out)) 76 | elif normalization == 'groupnorm': 77 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 78 | elif normalization == 'instancenorm': 79 | ops.append(nn.InstanceNorm3d(n_filters_out)) 80 | else: 81 | assert False 82 | else: 83 | ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 84 | 85 | ops.append(nn.ReLU(inplace=True)) 86 | 87 | self.conv = nn.Sequential(*ops) 88 | 89 | def forward(self, x): 90 | x = self.conv(x) 91 | return x 92 | 93 | 94 | class UpsamplingDeconvBlock(nn.Module): 95 | def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'): 96 | super(UpsamplingDeconvBlock, self).__init__() 97 | 98 | ops = [] 99 | if normalization != 'none': 100 | ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 101 | if normalization == 'batchnorm': 102 | ops.append(nn.BatchNorm3d(n_filters_out)) 103 | elif normalization == 'groupnorm': 104 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 105 | elif normalization == 'instancenorm': 106 | 
ops.append(nn.InstanceNorm3d(n_filters_out)) 107 | else: 108 | assert False 109 | else: 110 | ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride)) 111 | 112 | ops.append(nn.ReLU(inplace=True)) 113 | 114 | self.conv = nn.Sequential(*ops) 115 | 116 | def forward(self, x): 117 | x = self.conv(x) 118 | return x 119 | 120 | 121 | class Upsampling(nn.Module): 122 | def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'): 123 | super(Upsampling, self).__init__() 124 | 125 | ops = [] 126 | ops.append(nn.Upsample(scale_factor=stride, mode='trilinear',align_corners=False)) 127 | ops.append(nn.Conv3d(n_filters_in, n_filters_out, kernel_size=3, padding=1)) 128 | if normalization == 'batchnorm': 129 | ops.append(nn.BatchNorm3d(n_filters_out)) 130 | elif normalization == 'groupnorm': 131 | ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out)) 132 | elif normalization == 'instancenorm': 133 | ops.append(nn.InstanceNorm3d(n_filters_out)) 134 | elif normalization != 'none': 135 | assert False 136 | ops.append(nn.ReLU(inplace=True)) 137 | 138 | self.conv = nn.Sequential(*ops) 139 | 140 | def forward(self, x): 141 | x = self.conv(x) 142 | return x 143 | 144 | 145 | class VNet(nn.Module): 146 | def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False, 147 | dropout_rate=0.5): 148 | super(VNet, self).__init__() 149 | self.has_dropout = has_dropout 150 | 151 | self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization) 152 | self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization) 153 | 154 | self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization) 155 | self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization) 156 | 157 | self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization) 158 | self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization) 159 | 160 | self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization) 161 | self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization) 162 | 163 | self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization) 164 | self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, normalization=normalization) 165 | 166 | self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization) 167 | self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization) 168 | 169 | self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization) 170 | self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization) 171 | 172 | self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization) 173 | self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization) 174 | 175 | self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization) 176 | self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0) 177 | 178 | self.dropout = nn.Dropout3d(p=dropout_rate, inplace=False) 179 | # self.__init_weight() 180 | 181 | def encoder(self, input): 182 | x1 = self.block_one(input) 183 | x1_dw = self.block_one_dw(x1) 184 | 185 | x2 = self.block_two(x1_dw) 186 | 
x2_dw = self.block_two_dw(x2) 187 | 188 | x3 = self.block_three(x2_dw) 189 | x3_dw = self.block_three_dw(x3) 190 | 191 | x4 = self.block_four(x3_dw) 192 | x4_dw = self.block_four_dw(x4) 193 | 194 | x5 = self.block_five(x4_dw) 195 | if self.has_dropout: 196 | x5 = self.dropout(x5) 197 | # x5 = F.dropout3d(x5, p=0.5, training=True) 198 | 199 | res = [x1, x2, x3, x4, x5] 200 | 201 | return res 202 | 203 | def decoder(self, features): 204 | x1 = features[0] 205 | x2 = features[1] 206 | x3 = features[2] 207 | x4 = features[3] 208 | x5 = features[4] 209 | 210 | x5_up = self.block_five_up(x5) 211 | x5_up = x5_up + x4 212 | 213 | x6 = self.block_six(x5_up) 214 | x6_up = self.block_six_up(x6) 215 | x6_up = x6_up + x3 216 | 217 | x7 = self.block_seven(x6_up) 218 | x7_up = self.block_seven_up(x7) 219 | x7_up = x7_up + x2 220 | 221 | x8 = self.block_eight(x7_up) 222 | x8_up = self.block_eight_up(x8) 223 | x8_up = x8_up + x1 224 | x9 = self.block_nine(x8_up) 225 | if self.has_dropout: 226 | x9 = self.dropout(x9) 227 | # x9 = F.dropout3d(x9, p=0.5, training=True) 228 | out = self.out_conv(x9) 229 | return out 230 | 231 | 232 | def forward(self, input, turnoff_drop=False): 233 | if turnoff_drop: 234 | has_dropout = self.has_dropout 235 | self.has_dropout = False 236 | features = self.encoder(input) 237 | out = self.decoder(features) 238 | if turnoff_drop: 239 | self.has_dropout = has_dropout 240 | return out 241 | 242 | # def __init_weight(self): 243 | # for m in self.modules(): 244 | # if isinstance(m, nn.Conv3d): 245 | # torch.nn.init.kaiming_normal_(m.weight) 246 | # elif isinstance(m, nn.BatchNorm3d): 247 | # m.weight.data.fill_(1) 248 | # m.bias.data.zero_() 249 | -------------------------------------------------------------------------------- /code/test_LA.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from networks.vnet import VNet 5 | from test_util import test_all_case 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--root_path', type=str, default='../data/Training_Set/', help='Name of Experiment') 9 | parser.add_argument('--model', type=str, default='PL', help='model_name') 10 | parser.add_argument('--gpu', type=str, default='0', help='GPU to use') 11 | FLAGS = parser.parse_args() 12 | 13 | os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu 14 | # snapshot_path = "../model/"+FLAGS.model+"/" 15 | test_save_path = "../model/prediction/"+FLAGS.model+"_post/" 16 | if not os.path.exists(test_save_path): 17 | os.makedirs(test_save_path) 18 | 19 | num_classes = 2 20 | 21 | with open(FLAGS.root_path + '/../test.list', 'r') as f: 22 | image_list = f.readlines() 23 | image_list = [FLAGS.root_path +item.replace('\n', '')+"/mri_norm2.h5" for item in image_list] 24 | 25 | 26 | def test_calculate_metric(save_model_path): 27 | net = VNet(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=False).cuda() 28 | net.load_state_dict(torch.load(save_model_path)) 29 | print("init weight from {}".format(save_model_path)) 30 | net.eval() 31 | 32 | avg_metric = test_all_case(net, image_list, num_classes=num_classes, 33 | patch_size=(112, 112, 80), stride_xy=18, stride_z=4, 34 | save_result=True, test_save_path=test_save_path) 35 | 36 | return avg_metric 37 | 38 | if __name__ == '__main__': 39 | model_save_path = '../model/iter_6000.pth' 40 | metric = test_calculate_metric(model_save_path) 41 | print(metric) -------------------------------------------------------------------------------- 
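As a hedged sketch (not part of the repository), the V-Net above can be shape-checked on one training-sized patch; run it from inside `code/` so that `networks` is importable:
```
import torch
from networks.vnet import VNet

net = VNet(n_channels=1, n_classes=2, normalization='batchnorm', has_dropout=True)
x = torch.randn(1, 1, 112, 112, 80)   # one patch at the (112, 112, 80) crop size
with torch.no_grad():
    print(net(x).shape)               # torch.Size([1, 2, 112, 112, 80])
```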
/code/test_util.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import math 3 | import nibabel as nib 4 | import numpy as np 5 | from medpy import metric 6 | import torch 7 | import torch.nn.functional as F 8 | from tqdm import tqdm 9 | 10 | 11 | def test_all_case(net, image_list, num_classes, patch_size=(112, 112, 80), stride_xy=18, stride_z=4, save_result=True, 12 | test_save_path=None, preproc_fn=None): 13 | total_metric = 0.0 14 | for image_path in image_list: 15 | id = image_path.split('/')[-1] 16 | h5f = h5py.File(image_path, 'r') 17 | image = h5f['image'][:] 18 | label = h5f['label'][:] 19 | if preproc_fn is not None: 20 | image = preproc_fn(image) 21 | prediction, score_map = test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=num_classes) 22 | 23 | if np.sum(prediction) == 0: 24 | single_metric = (0, 0, 0, 0, 0, 0, 0) 25 | else: 26 | single_metric = calculate_metric_percase(prediction, label[:]) 27 | total_metric += np.asarray(single_metric) 28 | 29 | if save_result: 30 | nib.save(nib.Nifti1Image(prediction.astype(np.float32), np.eye(4)), test_save_path + id + "_pred.nii.gz") 31 | nib.save(nib.Nifti1Image(image[:].astype(np.float32), np.eye(4)), test_save_path + id + "_img.nii.gz") 32 | nib.save(nib.Nifti1Image(label[:].astype(np.float32), np.eye(4)), test_save_path + id + "_gt.nii.gz") 33 | avg_metric = total_metric / len(image_list) 34 | # print('average metric is {}'.format(avg_metric)) 35 | 36 | return avg_metric 37 | 38 | 39 | def test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=1): 40 | w, h, d = image.shape 41 | 42 | # if the size of image is less than patch_size, then padding it 43 | add_pad = False 44 | if w < patch_size[0]: 45 | w_pad = patch_size[0] - w 46 | add_pad = True 47 | else: 48 | w_pad = 0 49 | if h < patch_size[1]: 50 | h_pad = patch_size[1] - h 51 | add_pad = True 52 | else: 53 | h_pad = 0 54 | if d < patch_size[2]: 55 | d_pad = patch_size[2] - d 56 | add_pad = True 57 | else: 58 | d_pad = 0 59 | wl_pad, wr_pad = w_pad // 2, w_pad - w_pad // 2 60 | hl_pad, hr_pad = h_pad // 2, h_pad - h_pad // 2 61 | dl_pad, dr_pad = d_pad // 2, d_pad - d_pad // 2 62 | if add_pad: 63 | image = np.pad(image, [(wl_pad, wr_pad), (hl_pad, hr_pad), (dl_pad, dr_pad)], mode='constant', 64 | constant_values=0) 65 | ww, hh, dd = image.shape 66 | 67 | sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1 68 | sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1 69 | sz = math.ceil((dd - patch_size[2]) / stride_z) + 1 70 | # print("{}, {}, {}".format(sx, sy, sz)) 71 | score_map = np.zeros((num_classes,) + image.shape).astype(np.float32) 72 | cnt = np.zeros(image.shape).astype(np.float32) 73 | 74 | for x in range(0, sx): 75 | xs = min(stride_xy * x, ww - patch_size[0]) 76 | for y in range(0, sy): 77 | ys = min(stride_xy * y, hh - patch_size[1]) 78 | for z in range(0, sz): 79 | zs = min(stride_z * z, dd - patch_size[2]) 80 | test_patch = image[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] 81 | test_patch = np.expand_dims(np.expand_dims(test_patch, axis=0), axis=0).astype(np.float32) 82 | test_patch = torch.from_numpy(test_patch).cuda() 83 | y1 = net(test_patch) 84 | y = F.softmax(y1, dim=1) 85 | y = y.cpu().data.numpy() 86 | y = y[0, :, :, :, :] 87 | score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \ 88 | = score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + y 89 | cnt[xs:xs + patch_size[0], ys:ys + 
patch_size[1], zs:zs + patch_size[2]] \
 90 |                     = cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + 1
 91 |     score_map = score_map / np.expand_dims(cnt, axis=0)
 92 |     label_map = np.argmax(score_map, axis=0)
 93 |     if add_pad:
 94 |         label_map = label_map[wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
 95 |         score_map = score_map[:, wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
 96 |     return label_map, score_map
 97 | 
 98 | 
 99 | def cal_dice(prediction, label, num=2):
100 |     total_dice = np.zeros(num - 1)
101 |     for i in range(1, num):
102 |         prediction_tmp = (prediction == i)
103 |         label_tmp = (label == i)
104 |         prediction_tmp = prediction_tmp.astype(np.float32)
105 |         label_tmp = label_tmp.astype(np.float32)
106 | 
107 |         dice = 2 * np.sum(prediction_tmp * label_tmp) / (np.sum(prediction_tmp) + np.sum(label_tmp))
108 |         total_dice[i - 1] += dice
109 | 
110 |     return total_dice
111 | 
112 | 
113 | def calculate_metric_percase(pred, gt):
114 |     dice = metric.binary.dc(pred, gt)
115 |     jc = metric.binary.jc(pred, gt)
116 |     hd = metric.binary.hd95(pred, gt)
117 |     asd = metric.binary.asd(pred, gt)
118 |     precision = metric.binary.precision(pred, gt)
119 |     sensitivity = metric.binary.sensitivity(pred, gt)
120 |     specificity = metric.binary.specificity(pred, gt)
121 | 
122 |     return dice, jc, asd, hd, precision, sensitivity, specificity
123 | 
--------------------------------------------------------------------------------
/code/train_LA.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import sys
 3 | from tqdm import tqdm
 4 | from tensorboardX import SummaryWriter
 5 | import shutil
 6 | import argparse
 7 | import logging
 8 | import time
 9 | import random
10 | import numpy as np
11 | 
12 | import torch
13 | import torch.optim as optim
14 | from torchvision import transforms
15 | import torch.nn.functional as F
16 | import torch.backends.cudnn as cudnn
17 | from torch.utils.data import DataLoader
18 | from torchvision.utils import make_grid
19 | 
20 | from networks.vnet import VNet
21 | from utils.losses import dice_loss
22 | from dataloaders.la_heart import LAHeart, RandomCrop, CenterCrop, RandomRotFlip, ToTensor, TwoStreamBatchSampler
23 | from test_LA import test_calculate_metric
24 | 
25 | ######## labeled num #########
26 | num_labeled = 80
27 | parser = argparse.ArgumentParser()
28 | parser.add_argument('--root_path', type=str, default='../data/Training_Set/', help='Name of Experiment')
29 | parser.add_argument('--exp', type=str, default='vnet_supervisedonly_dp/labeled_80', help='model_name')
30 | parser.add_argument('--max_iterations', type=int, default=6000, help='maximum iteration number to train')  # max_epoch = max_iterations//len(trainloader)+1
31 | parser.add_argument('--batch_size', type=int, default=4, help='batch_size per gpu')
32 | parser.add_argument('--base_lr', type=float, default=0.01, help='base learning rate')
33 | parser.add_argument('--deterministic', type=int, default=1, help='whether to use deterministic training')
34 | parser.add_argument('--seed', type=int, default=1337, help='random seed')
35 | parser.add_argument('--gpu', type=str, default='7', help='GPU to use')
36 | args = parser.parse_args()
37 | 
38 | train_data_path = args.root_path
39 | snapshot_path = "../model/" + args.exp + "/"
40 | 
41 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
42 | batch_size = args.batch_size * len(args.gpu.split(','))
43 | max_iterations = args.max_iterations
44 | base_lr = args.base_lr
45 | 
46 | if args.deterministic:
47 |     cudnn.benchmark = False
48 |     cudnn.deterministic = True
49 |     random.seed(args.seed)
50 |     np.random.seed(args.seed)
51 |     torch.manual_seed(args.seed)
52 |     torch.cuda.manual_seed(args.seed)
53 | 
54 | patch_size = (112, 112, 80)  # crop size
55 | num_classes = 2
56 | 
57 | if __name__ == "__main__":
58 |     ## make logger file
59 |     if not os.path.exists(snapshot_path):
60 |         os.makedirs(snapshot_path)
61 |     if os.path.exists(snapshot_path + '/code'):
62 |         shutil.rmtree(snapshot_path + '/code')
63 |     shutil.copytree('.', snapshot_path + '/code', ignore=shutil.ignore_patterns('.git', '__pycache__'))
64 | 
65 |     logging.basicConfig(filename=snapshot_path+"/log.txt", level=logging.INFO,
66 |                         format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
67 |     logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
68 |     logging.info(str(args))
69 | 
70 |     ################ define the network ##############
71 |     net = VNet(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=True)
72 |     net = net.cuda()
73 | 
74 |     ############## build the datasets, with augmentation (random rotation/flip and crop) for training ############
75 |     db_train = LAHeart(base_dir=train_data_path,
76 |                        split='train',
77 |                        num=num_labeled,
78 |                        transform = transforms.Compose([
79 |                            RandomRotFlip(),
80 |                            RandomCrop(patch_size),
81 |                            ToTensor(),
82 |                        ]))
83 |     db_test = LAHeart(base_dir=train_data_path,
84 |                       split='test',
85 |                       transform = transforms.Compose([
86 |                           CenterCrop(patch_size),
87 |                           ToTensor()
88 |                       ]))
89 |     def worker_init_fn(worker_id):
90 |         random.seed(args.seed+worker_id)
91 |     ############### data loading #################
92 |     trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)
93 |     print(len(trainloader))
94 | 
95 | 
96 |     net.train()
97 |     # optimizer
98 |     optimizer = optim.SGD(net.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)
99 | 
100 |     writer = SummaryWriter(snapshot_path+'/log')
101 |     logging.info("{} iterations per epoch".format(len(trainloader)))
102 | 
103 |     ################# start training ##################
104 |     iter_num = 0
105 |     max_epoch = max_iterations//len(trainloader)+1
106 |     lr_ = base_lr
107 |     net.train()
108 |     for epoch_num in tqdm(range(max_epoch), ncols=70):  # tqdm shows a progress bar
109 |         time1 = time.time()
110 |         for i_batch, sampled_batch in enumerate(trainloader):
111 |             time2 = time.time()
112 |             # print('fetch data cost {}'.format(time2-time1))
113 |             volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
114 |             volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
115 |             outputs = net(volume_batch)
116 | 
117 |             # loss: cross-entropy + Dice
118 |             loss_seg = F.cross_entropy(outputs, label_batch)
119 |             outputs_soft = F.softmax(outputs, dim=1)
120 |             loss_seg_dice = dice_loss(outputs_soft[:, 1, :, :, :], label_batch == 1)
121 |             loss = 0.5*(loss_seg+loss_seg_dice)
122 | 
123 |             optimizer.zero_grad()
124 |             loss.backward()
125 |             optimizer.step()
126 | 
127 |             iter_num = iter_num + 1
128 |             writer.add_scalar('lr', lr_, iter_num)
129 |             writer.add_scalar('loss/loss_seg', loss_seg, iter_num)
130 |             writer.add_scalar('loss/loss_seg_dice', loss_seg_dice, iter_num)
131 |             writer.add_scalar('loss/loss', loss, iter_num)
132 |             logging.info('iteration %d : loss : %f' % (iter_num, loss.item()))
133 |             if iter_num % 50 == 0:
134 |                 image = volume_batch[0, 0:1, :, :, 20:61:10].permute(3,0,1,2).repeat(1,3,1,1)
135 |                 # make_grid tiles several images into one grid image
136 |                 grid_image = make_grid(image, 5, normalize=True)
137 |                 writer.add_image('train/Image', grid_image, iter_num)
138 | 
139 |                 outputs_soft = F.softmax(outputs, 1)
140 |                 image = outputs_soft[0, 1:2, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
141 |                 grid_image = make_grid(image, 5, normalize=False)
142 |                 writer.add_image('train/Predicted_label', grid_image, iter_num)
143 | 
144 |                 image = label_batch[0, :, :, 20:61:10].unsqueeze(0).permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
145 |                 grid_image = make_grid(image, 5, normalize=False)
146 |                 writer.add_image('train/Groundtruth_label', grid_image, iter_num)
147 | 
148 |             ## change lr
149 |             if iter_num % 2500 == 0:
150 |                 lr_ = base_lr * 0.1 ** (iter_num // 2500)
151 |                 for param_group in optimizer.param_groups:
152 |                     param_group['lr'] = lr_
153 |             # if iter_num % 1000 == 0:
154 |             #     save_mode_path = os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth')
155 |             #     torch.save(net.state_dict(), save_mode_path)
156 |             #     logging.info("save model to {}".format(save_mode_path))
157 | 
158 |             if iter_num > max_iterations:
159 |                 break
160 |             time1 = time.time()
161 |         if iter_num > max_iterations:
162 |             break
163 | 
164 |     save_mode_path = os.path.join(snapshot_path, 'iter_'+str(max_iterations)+'.pth')
165 |     torch.save(net.state_dict(), save_mode_path)
166 |     logging.info("save model to {}".format(save_mode_path))
167 | 
168 |     # evaluate the saved checkpoint on the test set and record the metrics
169 |     metric = test_calculate_metric(save_mode_path)
170 |     ff_path = os.path.join(snapshot_path, 'test_metric.txt')
171 |     with open(ff_path, 'w') as ff:
172 |         ff.write(str(metric))
173 |     print(metric)
174 |     writer.close()
175 | 
--------------------------------------------------------------------------------
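Both training scripts share the step decay `lr = base_lr * 0.1 ** (iter_num // 2500)`; as a small worked example with `base_lr = 0.01`:
```
for it in (2500, 5000):
    print(it, 0.01 * 0.1 ** (it // 2500))  # 2500 -> 0.001, 5000 -> 0.0001
```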
/code/train_LA_UPC.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import sys
  3 | from tqdm import tqdm
  4 | from tensorboardX import SummaryWriter
  5 | import shutil
  6 | import argparse
  7 | import logging
  8 | import time
  9 | import random
 10 | import numpy as np
 11 | 
 12 | import torch
 13 | import torch.optim as optim
 14 | import torch.nn as nn
 15 | from torchvision import transforms
 16 | import torch.nn.functional as F
 17 | import torch.backends.cudnn as cudnn
 18 | from torch.utils.data import DataLoader
 19 | from torchvision.utils import make_grid
 20 | 
 21 | from networks.vnet import VNet
 22 | from dataloaders import utils
 23 | from utils import ramps, losses
 24 | from dataloaders.la_heart import LAHeart, RandomCrop, CenterCrop, RandomRotFlip, ToTensor, TwoStreamBatchSampler
 25 | 
 26 | 
 27 | ######## labeled num #########
 28 | num_labeled = 16
 29 | parser = argparse.ArgumentParser()
 30 | parser.add_argument('--root_path', type=str, default='../data/Training_Set/', help='Name of Experiment')
 31 | parser.add_argument('--exp', type=str, default='UAMT-ada4-ema/decay-0.9999', help='model_name')
 32 | parser.add_argument('--max_iterations', type=int, default=6000, help='maximum iteration number to train')
 33 | parser.add_argument('--batch_size', type=int, default=4, help='batch_size per gpu')
 34 | parser.add_argument('--labeled_bs', type=int, default=2, help='labeled_batch_size per gpu')
 35 | parser.add_argument('--base_lr', type=float, default=0.01, help='base learning rate')
 36 | parser.add_argument('--deterministic', type=int, default=1, help='whether to use deterministic training')
 37 | parser.add_argument('--seed', type=int, default=1337, help='random seed')
 38 | parser.add_argument('--gpu', type=str, default='0', help='GPU to use')
 39 | 
 40 | parser.add_argument('--pseudo', action='store_true', default=True, help='generate the pseudo label')
 41 | parser.add_argument('--pseudo_rect', action='store_true', default=False, help='rectify the pseudo label')
 42 | parser.add_argument('--threshold', type=float, default=0.90, help='pseudo label threshold')
 43 | parser.add_argument('--T', type=float, default=1)
 44 | parser.add_argument('--ratio', type=float, default=0.20, help='model noise ratio')
 45 | parser.add_argument('--dropout_rate', type=float, default=0.9)
 46 | ### costs
 47 | parser.add_argument('--ema_decay', type=float, default=0.9999, help='ema_decay')
 48 | parser.add_argument('--consistency', type=float, default=1.0, help='consistency')
 49 | parser.add_argument('--consistency_rampup', type=float, default=40.0, help='consistency_rampup')
 50 | args = parser.parse_args()
 51 | 
 52 | train_data_path = args.root_path
 53 | snapshot_path = "../model/" + args.exp + "/"
 54 | 
 55 | 
 56 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
 57 | batch_size = args.batch_size * len(args.gpu.split(','))
 58 | max_iterations = args.max_iterations
 59 | base_lr = args.base_lr
 60 | labeled_bs = args.labeled_bs
 61 | 
 62 | if args.deterministic:
 63 |     cudnn.benchmark = False
 64 |     cudnn.deterministic = True
 65 |     random.seed(args.seed)
 66 |     np.random.seed(args.seed)
 67 |     torch.manual_seed(args.seed)
 68 |     torch.cuda.manual_seed(args.seed)
 69 | 
 70 | num_classes = 2
 71 | patch_size = (112, 112, 80)
 72 | 
 73 | def get_current_consistency_weight(epoch):
 74 |     # Consistency ramp-up from https://arxiv.org/abs/1610.02242
 75 |     return args.consistency * ramps.sigmoid_rampup(epoch, args.consistency_rampup)
 76 | 
 77 | def update_ema_variables(model, ema_model, alpha, global_step):
 78 |     # Use the true average until the exponential average is more correct
 79 |     alpha = min(1 - 1 / (global_step + 1), alpha)
 80 |     for ema_param, param in zip(ema_model.parameters(), model.parameters()):
 81 |         ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)
 82 | 
 83 | def update_variance(pred1, pred2, loss_origin):
 84 |     sm = nn.Softmax(dim=1)
 85 |     log_sm = nn.LogSoftmax(dim=1)
 86 |     kl_distance = nn.KLDivLoss(reduction='none')
 87 | 
 88 |     # loss_kl approximates the prediction variance
 89 |     loss_kl = torch.sum(kl_distance(log_sm(pred1), sm(pred2)), dim=1)  # pred1 is the student prediction, the one being guided
 90 |     exp_loss_kl = torch.exp(-loss_kl)
 91 |     # print(variance.shape)
 92 |     # print('variance mean: %.4f' % torch.mean(exp_variance[:]))
 93 |     # print('variance min: %.4f' % torch.min(exp_variance[:]))
 94 |     # print('variance max: %.4f' % torch.max(exp_variance[:]))
 95 |     loss_rect = torch.mean(loss_origin * exp_loss_kl) + torch.mean(loss_kl)
 96 |     return loss_rect
 97 | 
 98 | def update_consistency_loss(pred1, pred2):
 99 |     if args.pseudo:
100 |         criterion = nn.CrossEntropyLoss(reduction='none')
101 |         # generate the pseudo label from pred2
102 |         pseudo_label = torch.softmax(pred2.detach() / args.T, dim=1)  # T: sharpening temperature
103 |         max_probs, targets = torch.max(pseudo_label, dim=1)  # confidence and class index
104 |         # print(targets.shape)
105 |         if args.pseudo_rect:  # rectify the pseudo label with the variance of the two predictions
106 |             # for CrossEntropyLoss the input is [batch, n_class, h, w, d] and the target is [batch, h, w, d]
107 |             loss_ce = criterion(pred1, targets)  # output shape [batch, h, w, d]
108 |             # print(pred1.shape, targets.shape)
109 |             loss = update_variance(pred1, pred2, loss_ce)
110 |         else:
111 |             mask = max_probs.ge(args.threshold).float()  # keep voxels with confidence >= threshold
112 |             loss_ce = criterion(pred1, targets)
113 |             loss = torch.mean(loss_ce * mask)
114 |             # print(loss)
115 |     else:
116 |         criterion = nn.MSELoss(reduction='none')
117 |         loss_mse = criterion(pred1, pred2)
118 |         loss = torch.mean(loss_mse)
119 | 
120 |     return loss
121 | 
122 | 
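A toy run of `update_variance` above (shapes are illustrative; training uses [batch, 2, 112, 112, 80] logits, and `torch`/`nn` are already imported in this file):
```
pred_student = torch.randn(2, 2, 4, 4, 4)   # student logits
pred_teacher = torch.randn(2, 2, 4, 4, 4)   # teacher (EMA) logits
targets = torch.softmax(pred_teacher.detach(), dim=1).argmax(dim=1)
loss_ce = nn.CrossEntropyLoss(reduction='none')(pred_student, targets)
print(update_variance(pred_student, pred_teacher, loss_ce))
# scalar: mean(loss_ce * exp(-KL)) + mean(KL)
```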
123 | if __name__ == "__main__":
124 |     ## make logger file
125 |     if not os.path.exists(snapshot_path):
126 |         os.makedirs(snapshot_path)
127 |     if os.path.exists(snapshot_path + '/code'):
128 |         shutil.rmtree(snapshot_path + '/code')
129 |     shutil.copytree('.', snapshot_path + '/code', ignore=shutil.ignore_patterns('.git', '__pycache__'))
130 | 
131 |     logging.basicConfig(filename=snapshot_path+"/log.txt", level=logging.INFO,
132 |                         format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
133 |     logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
134 |     logging.info(str(args))
135 | 
136 |     def create_model(ema=False, has_dropout=False):
137 |         # Network definition
138 |         if has_dropout:
139 |             net = VNet(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=True,
140 |                        dropout_rate=args.dropout_rate)
141 |         else:
142 |             net = VNet(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=False)
143 |         model = net.cuda()
144 |         if ema:
145 |             for param in model.parameters():
146 |                 param.detach_()
147 |         return model
148 | 
149 |     model = create_model(has_dropout=True)  # student model
150 |     ema_model = create_model(ema=True, has_dropout=True)  # teacher model
151 | 
152 |     db_train = LAHeart(base_dir=train_data_path,
153 |                        split='train',
154 |                        transform = transforms.Compose([
155 |                            RandomRotFlip(),
156 |                            RandomCrop(patch_size),
157 |                            ToTensor(),
158 |                        ]))
159 |     db_test = LAHeart(base_dir=train_data_path,
160 |                       split='test',
161 |                       transform = transforms.Compose([
162 |                           CenterCrop(patch_size),
163 |                           ToTensor()
164 |                       ]))
165 |     labeled_idxs = list(range(num_labeled))
166 |     unlabeled_idxs = list(range(num_labeled, 80))
167 |     batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, batch_size, batch_size-labeled_bs)
168 |     def worker_init_fn(worker_id):
169 |         random.seed(args.seed+worker_id)
170 |     trainloader = DataLoader(db_train, batch_sampler=batch_sampler, num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)
171 | 
172 | 
173 |     model.train()
174 |     ema_model.train()
175 |     optimizer = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)
176 | 
177 |     writer = SummaryWriter(snapshot_path+'/log')
178 |     logging.info("{} iterations per epoch".format(len(trainloader)))
179 | 
180 |     iter_num = 0
181 |     max_epoch = max_iterations//len(trainloader)+1
182 |     lr_ = base_lr
183 |     # only the student model is updated by back-propagation
184 |     model.train()
185 |     time1 = time.time()
186 |     for epoch_num in tqdm(range(max_epoch), ncols=70):
187 |         for i_batch, sampled_batch in enumerate(trainloader):
188 |             volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
189 |             volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
190 | 
191 |             ratio = args.ratio
192 |             noise = torch.clamp(torch.randn_like(volume_batch) * ratio, -(2 * ratio), (2 * ratio))
193 |             # the student gets noised inputs; the teacher input is left clean
194 |             student_inputs = volume_batch + noise
195 |             ema_inputs = volume_batch
196 |             outputs = model(student_inputs)  # forward pass of the student model
197 |             with torch.no_grad():
198 |                 ema_output = ema_model(ema_inputs)
199 | 
200 |             ## calculate the loss
201 |             # the labeled data loss
202 |             loss_seg = F.cross_entropy(outputs[:labeled_bs], label_batch[:labeled_bs])  # only the labeled part of the batch
203 |             outputs_soft = F.softmax(outputs, dim=1)
204 |             loss_seg_dice = losses.dice_loss(outputs_soft[:labeled_bs, 1, :, :, :], label_batch[:labeled_bs] == 1)
205 |             supervised_loss = 0.5 * (loss_seg + loss_seg_dice)  # only on labeled data
206 |             # print('************ supervised loss:{}'.format(supervised_loss))
207 | 
208 |             # consistency loss
209 |             consistency_loss = update_consistency_loss(outputs, ema_output)
210 |             # print('************ consistency loss:{}'.format(consistency_loss))
211 | 
212 |             consistency_weight = get_current_consistency_weight(iter_num // 150)
213 |             loss = supervised_loss + consistency_weight * consistency_loss
214 |             # print('************ Total loss:{}'.format(loss))
215 | 
216 |             optimizer.zero_grad()
217 |             loss.backward()
218 |             optimizer.step()
219 |             # update the teacher model as an EMA of the student parameters
220 |             update_ema_variables(model, ema_model, args.ema_decay, iter_num)
221 | 
222 |             iter_num = iter_num + 1
223 |             writer.add_scalar('lr', lr_, iter_num)
224 |             writer.add_scalar('loss/loss', loss, iter_num)
225 |             writer.add_scalar('loss/loss_seg', loss_seg, iter_num)
226 |             writer.add_scalar('loss/loss_seg_dice', loss_seg_dice, iter_num)
227 |             writer.add_scalar('loss/supervised_loss', supervised_loss, iter_num)
228 |             writer.add_scalar('train/consistency_loss', consistency_loss, iter_num)
229 |             writer.add_scalar('train/consistency_weight', consistency_weight, iter_num)
230 | 
231 | 
232 |             logging.info('iteration %d : loss : %f supervised_loss: %f consistency_loss: %f consistency_weight: %f' %
233 |                          (iter_num, loss.item(), supervised_loss.item(), consistency_loss.item(), consistency_weight))
234 | 
235 |             if iter_num % 50 == 0:
236 |                 image = volume_batch[0, 0:1, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
237 |                 grid_image = make_grid(image, 5, normalize=True)
238 |                 writer.add_image('train/Image', grid_image, iter_num)
239 | 
240 |                 # image = outputs_soft[0, 3:4, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
241 |                 image = torch.max(outputs_soft[0, :, :, :, 20:61:10], 0)[1].permute(2, 0, 1).data.cpu().numpy()
242 |                 image = utils.decode_seg_map_sequence(image)
243 |                 grid_image = make_grid(image, 5, normalize=False)
244 |                 writer.add_image('train/Predicted_label', grid_image, iter_num)
245 | 
246 |                 image = label_batch[0, :, :, 20:61:10].permute(2, 0, 1)
247 |                 grid_image = make_grid(utils.decode_seg_map_sequence(image.data.cpu().numpy()), 5, normalize=False)
248 |                 writer.add_image('train/Groundtruth_label', grid_image, iter_num)
249 | 
250 |                 image = volume_batch[-1, 0:1, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
251 |                 grid_image = make_grid(image, 5, normalize=True)
252 |                 writer.add_image('unlabel/Image', grid_image, iter_num)
253 | 
254 |                 # image = outputs_soft[-1, 3:4, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
255 |                 image = torch.max(outputs_soft[-1, :, :, :, 20:61:10], 0)[1].permute(2, 0, 1).data.cpu().numpy()
256 |                 image = utils.decode_seg_map_sequence(image)
257 |                 grid_image = make_grid(image, 5, normalize=False)
258 |                 writer.add_image('unlabel/Predicted_label', grid_image, iter_num)
259 | 
260 |                 image = label_batch[-1, :, :, 20:61:10].permute(2, 0, 1)
261 |                 grid_image = make_grid(utils.decode_seg_map_sequence(image.data.cpu().numpy()), 5, normalize=False)
262 |                 writer.add_image('unlabel/Groundtruth_label', grid_image, iter_num)
263 | 
264 |             ## change lr
265 |             if iter_num % 2500 == 0:
266 |                 lr_ = base_lr * 0.1 ** (iter_num // 2500)
267 |                 for param_group in optimizer.param_groups:
268 |                     param_group['lr'] = lr_
269 | 
270 |             if (iter_num % 1000 == 0) and (iter_num >= 6000):
271 |                 save_mode_path = os.path.join(snapshot_path, 'ada_iter_' + str(iter_num) + '.pth')
272 |                 torch.save(model.state_dict(), save_mode_path)
273 |                 logging.info("save model to {}".format(save_mode_path))
274 | 
275 |             if iter_num >= max_iterations:
276 |                 break
277 |         if iter_num >= max_iterations:
278 |             break
279 | 
280 |     time2 = time.time()
281 |     total_time = (time2 - time1) / 3600
282 |     print('total train time:', total_time)  # hours
283 | 
284 |     save_mode_path = os.path.join(snapshot_path, 'iter_' + str(max_iterations) + '.pth')
285 |     torch.save(model.state_dict(), save_mode_path)
286 | 
287 |     logging.info("save model to {}".format(save_mode_path))
288 |     writer.close()
--------------------------------------------------------------------------------
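For reference, the EMA warm-up in `update_ema_variables` behaves as follows (a purely illustrative printout of `alpha = min(1 - 1/(step+1), 0.9999)`):
```
for step in (0, 9, 99, 9999):
    print(step, min(1 - 1 / (step + 1), 0.9999))  # 0.0, 0.9, 0.99, 0.9999
```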
/code/utils/losses.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | from torch.nn import functional as F
 3 | import numpy as np
 4 | 
 5 | def dice_loss(score, target):  # soft Dice loss with squared sums in the denominator
 6 |     target = target.float()
 7 |     smooth = 1e-5
 8 |     intersect = torch.sum(score * target)
 9 |     y_sum = torch.sum(target * target)
10 |     z_sum = torch.sum(score * score)
11 |     loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
12 |     loss = 1 - loss
13 |     return loss
14 | 
15 | def dice_loss1(score, target):  # soft Dice loss with plain sums in the denominator
16 |     target = target.float()
17 |     smooth = 1e-5
18 |     intersect = torch.sum(score * target)
19 |     y_sum = torch.sum(target)
20 |     z_sum = torch.sum(score)
21 |     loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
22 |     loss = 1 - loss
23 |     return loss
24 | 
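A tiny behavioral check of `dice_loss` above (illustrative only; `dice_loss` is the function defined in this file):
```
import torch

t = torch.tensor([[1., 1.], [0., 0.]])
print(dice_loss(t, t))                    # ~0.0: perfect overlap
print(dice_loss(torch.zeros_like(t), t))  # ~1.0: no overlap
```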
78 | """ 79 | assert input_logits.size() == target_logits.size() 80 | input_log_softmax = F.log_softmax(input_logits, dim=1) 81 | target_softmax = F.softmax(target_logits, dim=1) 82 | 83 | # return F.kl_div(input_log_softmax, target_softmax) 84 | kl_div = F.kl_div(input_log_softmax, target_softmax, reduction='none') 85 | # mean_kl_div = torch.mean(0.2*kl_div[:,0,...]+0.8*kl_div[:,1,...]) 86 | return kl_div 87 | 88 | def symmetric_mse_loss(input1, input2): 89 | """Like F.mse_loss but sends gradients to both directions 90 | 91 | Note: 92 | - Returns the sum over all examples. Divide by the batch size afterwards 93 | if you want the mean. 94 | - Sends gradients to both input1 and input2. 95 | """ 96 | assert input1.size() == input2.size() 97 | return torch.mean((input1 - input2)**2) 98 | -------------------------------------------------------------------------------- /code/utils/ramps.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, Curious AI Ltd. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Functions for ramping hyperparameters up or down 9 | 10 | Each function takes the current training step or epoch, and the 11 | ramp length in the same format, and returns a multiplier between 12 | 0 and 1. 13 | """ 14 | 15 | 16 | import numpy as np 17 | 18 | 19 | def sigmoid_rampup(current, rampup_length): 20 | """Exponential rampup from https://arxiv.org/abs/1610.02242""" 21 | if rampup_length == 0: 22 | return 1.0 23 | else: 24 | current = np.clip(current, 0.0, rampup_length) 25 | phase = 1.0 - current / rampup_length 26 | return float(np.exp(-5.0 * phase * phase)) 27 | 28 | 29 | def linear_rampup(current, rampup_length): 30 | """Linear rampup""" 31 | assert current >= 0 and rampup_length >= 0 32 | if current >= rampup_length: 33 | return 1.0 34 | else: 35 | return current / rampup_length 36 | 37 | 38 | def cosine_rampdown(current, rampdown_length): 39 | """Cosine rampdown from https://arxiv.org/abs/1608.03983""" 40 | assert 0 <= current <= rampdown_length 41 | return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1)) 42 | -------------------------------------------------------------------------------- /data/test.list: -------------------------------------------------------------------------------- 1 | UPT6DX9IQY9JAZ7HJKA7 2 | UTBUJIWZMKP64E3N73YC 3 | ULHWPWKKLTE921LQLH1P 4 | V0MZOWJ6MU3RMRCV9EXR 5 | VDOF02M8ZHEAADFMS6NP 6 | VG4C826RAAKVMV9BQLVD 7 | VIXBEFTNVHZWKAKURJBN 8 | VQ2L3WM8KEVF6L44E6G9 9 | WBG9WYZ1B25WDT5WAT8T 10 | WMDG2EFA6L2SNDZXIRU0 11 | WNPKE0W404QE9AELX1LR 12 | WSJB9P4JCXUVHBOYFVWL 13 | WW8F5CO4S4K5IM5Z7EXX 14 | X18LU5AOBNNDMLTA0JZL 15 | XYDLYJ5CS19FDBVLJIPI 16 | Y7ZU0B2APPF54WG6PDMF 17 | YDKD1HVHSME6NVMA8I39 18 | Z9GMG63CJLL0VW893BB1 19 | ZIJLJAVQV3FJ6JSQOH1E 20 | ZQPMJ4XEC5A4BISD45P1 21 | -------------------------------------------------------------------------------- /data/train.list: -------------------------------------------------------------------------------- 1 | 06SR5RBREL16DQ6M8LWS 2 | 0RZDK210BSMWAA6467LU 3 | 1D7CUD1955YZPGK8XHJX 4 | 1GU15S0GJ6PFNARO469W 5 | 1MHBF3G6DCPWHSKG7XCP 6 | 23X6SY44VT9KFHR7S7OC 7 | 2XL5HSFSE93RMOJDRGR4 8 | 38CWS74285MFGZZXR09Z 9 | 3C2QTUNI0852XV7ZH4Q1 10 | 3DA0T2V6JJ2NLUAV6FWM 11 | 4498CA6DZWELOXCBRYRF 
12 | 45C45I6IXAFGNRO067W9 13 | 4CHFJGF6ZUM7CMZTNFQF 14 | 4EPVTT1HPA8U60CDUKXE 15 | 57SGAJMLCTCH92QUA0EE 16 | 5BHTH9RHH3PQT913I59W 17 | 5FKQL4K14KCB72Y8YMC2 18 | 5HH0WPWIY06DLAFOBQ4M 19 | 5QFK2PMHNX7UALK52NNA 20 | 5UB5KFD2PK38Z4LS6W80 21 | 6799D6LEBH3NSRV1KH27 22 | 78NJ5YFQF72BGC8RO51C 23 | 7FUCNXB39F78WTOP5K71 24 | 8GYK8A9MBRC9TV0FVSRA 25 | 8M99G0JLAXG9GLPV0O8G 26 | 8RE90C8H5DKF4V6HO8UU 27 | 8ZG2TRZ81MAWHZPN9KKG 28 | 9DCM2IB45SK6YKQNYUQY 29 | 9DHWWP5Y66VDMPXISZ13 30 | 9DQYTIU00I4JC0OEOKQQ 31 | A11O45O3NAXWM7T2H8CH 32 | A4R1S23KR0KU2WSYHK2X 33 | A5RNNK0A891WUSC2V624 34 | AT5CRO5JUDBWD4RUPXSQ 35 | BNK95S2SJXEGSW7VAKYU 36 | BXJWOUYP2J3EN4U92517 37 | BYSRSI3H4YTWKMM3MADP 38 | BZUFJX66T0W6ZPVTL9DU 39 | CB5P5W7X310NIIVU7UZV 40 | CBIJFVZ5L9BS0LKWE8YL 41 | CCGAKN4EDT72KC8TTJ76 42 | CLXFYOBQDCVXQ9P7YC07 43 | CMPXO4J23G58J53Q98SZ 44 | CZPMV6KWZ4I7IJJP9FOK 45 | DLKXBV73A55ZTSZ0QQI2 46 | DQ5UYBGR5QP6L692QSG6 47 | DYXSCIWHLSUOZIDDSZ40 48 | E2ZMO66WGS74UKXTZPPQ 49 | EJ5V7SPR4961JWD6SS8V 50 | FGM5NIWN3URY4HF4WNUW 51 | GSC9KNY0VEZXFSGWNF25 52 | HVE7DR3CUA2IM3RC6OMA 53 | HZZ4O0BRKF8S0YX3NNF7 54 | I2VZ7N8H9QYNYT7ZZF1Y 55 | IDWWHGWJ5STOQXSDT6GU 56 | IIY6TYJMTJIZRIZLB9YW 57 | IJJY51YW3W4YJJ7DTVTK 58 | IQYKPTWXVV9H0IHB8YXC 59 | JEC6HJ7SQJXBKVREX03F 60 | JGFOLWJF7YCYD8DPHQNH 61 | K32FD6LRSUSSXGS1YUOX 62 | KM5RYAMP4P4ZP6XWP3Q2 63 | KSNYHUBHHUJTYJ14UQZR 64 | LH4FVU3TQDEC87YGN6FL 65 | LJSDNMND9SHKM7Q4IRHJ 66 | MFTDVMBWFNQ3F5KHBRDR 67 | MJHV7F65TB2A76CQLOC3 68 | MVKIPGBKTNSENNP1S4HB 69 | O5TSIKRD4AIB8K84WIR9 70 | OIRDLE32TXZX942FVZMM 71 | P1OTI3IWJUIB5NRLULLH 72 | PVNXUK681N9BY14K4Z86 73 | Q0MEX9ZIKAGJORSPLQ3Y 74 | Q7J0WYM695R9MA285ZW0 75 | QZC1W0FNR19KJFLOCFLH 76 | R8ER97O9UUN77C02VE2J 77 | RSZY41MT2FGDKHWWL5L2 78 | SN4LF8SGBSRQUPTDSX78 79 | TDDI6L3Y0L9VVFP9MNFS 80 | UZUZZT2W9IUSHL6ASOX3 81 | --------------------------------------------------------------------------------