├── README.md
└── code_oa
    ├── dataloaders
    │   ├── bezier_curve.py
    │   ├── createlist.py
    │   ├── dataset.py
    │   ├── preprocess.py
    │   └── utils.py
    ├── networks
    │   ├── discriminator.py
    │   ├── unet.py
    │   └── unet3d.py
    ├── selection.py
    ├── test_2D_fully.py
    ├── tools.py
    ├── train_finetune.py
    ├── train_source.py
    ├── utils
    │   ├── bezier_curve.py
    │   ├── loss
    │   │   ├── __init__.py
    │   │   ├── __pycache__
    │   │   │   ├── __init__.cpython-310.pyc
    │   │   │   ├── compound_losses.cpython-310.pyc
    │   │   │   ├── deep_supervision.cpython-310.pyc
    │   │   │   ├── dice.cpython-310.pyc
    │   │   │   └── robust_ce_loss.cpython-310.pyc
    │   │   ├── compound_losses.py
    │   │   ├── deep_supervision.py
    │   │   ├── dice.py
    │   │   └── robust_ce_loss.py
    │   ├── losses.py
    │   ├── metrics.py
    │   ├── ramps.py
    │   └── util.py
    └── val_2D.py

/README.md:
--------------------------------------------------------------------------------
1 | # UGTST
2 | An Uncertainty-guided Tiered Self-training Framework for Active Source-free Domain Adaptation in Prostate Segmentation (accepted at MICCAI 2024 🎉)
3 | 
4 | This repository provides the implementation of the MICCAI 2024 paper "An Uncertainty-guided Tiered Self-training Framework for Active Source-free Domain Adaptation in Prostate Segmentation".
5 | 
6 | # Contact
7 | For any inquiries, please contact:
8 | Zihao Luo, zihaoluoh@163.com
9 | 
--------------------------------------------------------------------------------
/code_oa/dataloaders/bezier_curve.py:
--------------------------------------------------------------------------------
1 | # [5] Zhou Z, Qi L, Yang X, et al. Generalizable cross-modality medical image segmentation via style augmentation and dual normalization[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022: 20856-20865.
2 | 
3 | import numpy as np
4 | import random
5 | import matplotlib.pyplot as plt
6 | try:
7 |     from scipy.special import comb
8 | except ImportError:
9 |     from scipy.misc import comb
10 | """
11 | this is for non-linear intensity transformation
12 | 
13 | 
14 | """
15 | 
16 | 
17 | # bernstein_poly(i, n, t): computes the Bernstein basis polynomial, where i is the index of the basis term, n is the polynomial order, and t is the parameter value; it provides the weighting coefficients used to evaluate the Bezier curve.
18 | def bernstein_poly(i, n, t):
19 |     """
20 |     The Bernstein polynomial of n, i as a function of t
21 |     """
22 |     return comb(n, i) * ( t**(n-i) ) * (1 - t)**i
23 | 
24 | 
25 | def bezier_curve(points, nTimes=1000):
26 |     """
27 |     Given a set of control points, return the
28 |     bezier curve defined by the control points.
29 | Control points should be a list of lists, or list of tuples 30 | such as [ [1,1], 31 | [2,3], 32 | [4,5], ..[Xn, Yn] ] 33 | nTimes is the number of time steps, defaults to 1000 34 | See http://processingjs.nihongoresources.com/bezierinfo/ 35 | """ 36 | 37 | nPoints = len(points) 38 | xPoints = np.array([p[0] for p in points]) 39 | yPoints = np.array([p[1] for p in points]) 40 | 41 | t = np.linspace(0.0, 1.0, nTimes) 42 | 43 | polynomial_array = np.array([bernstein_poly(i, nPoints-1, t) for i in range(0, nPoints)]) 44 | 45 | xvals = np.dot(xPoints, polynomial_array) 46 | yvals = np.dot(yPoints, polynomial_array) 47 | 48 | return xvals, yvals 49 | 50 | 51 | def nonlinear_transformation(x): 52 | points = [[0, 0], [random.random(), random.random()], [random.random(), random.random()], [1, 1]] 53 | xvals, yvals = bezier_curve(points, nTimes=1000) 54 | if random.random() < 0.5: 55 | # Half change to get flip 56 | xvals = np.sort(xvals) 57 | else: 58 | xvals, yvals = np.sort(xvals), np.sort(yvals) 59 | nonlinear_x = np.interp(x, xvals, yvals) 60 | return nonlinear_x -------------------------------------------------------------------------------- /code_oa/dataloaders/createlist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import re 4 | data_folder = "/media/disk2t_solid/zihao/SFDA/data/data_preprocessed/3D_pancreas/MCF" 5 | 6 | file_list = [file for file in os.listdir(data_folder) if file.endswith(".h5")] 7 | 8 | random.shuffle(file_list) 9 | 10 | split_point1 = int(0.7 * len(file_list)) 11 | split_point2 = int(0.8 * len(file_list)) 12 | 13 | train_set = file_list[:split_point1] 14 | val_set = file_list[split_point1:split_point2] 15 | test_set = file_list[split_point2:] 16 | 17 | with open(os.path.join(data_folder, "trainlist.txt"), 'w') as train_file: 18 | train_file.write('\n'.join(train_set)) 19 | 20 | with open(os.path.join(data_folder, "vallist.txt"), 'w') as val_file: 21 | val_file.write('\n'.join(val_set)) 22 | 23 | with open(os.path.join(data_folder, "testlist.txt"), 'w') as test_file: 24 | test_file.write('\n'.join(test_set)) 25 | 26 | trainlist_path = os.path.join(data_folder, 'trainlist.txt') 27 | with open(trainlist_path, 'r') as f: 28 | file_names = f.read().splitlines() 29 | 30 | slice_file_names = [] 31 | slices_folder_path = os.path.join(data_folder,'slices') 32 | print(slices_folder_path) 33 | 34 | for file_name in file_names: 35 | case_name = file_name.split('_')[0] 36 | case_name = case_name.replace('.h5','') 37 | pattern = re.compile(f'{case_name}_slice(\d+).h5') 38 | print(case_name) 39 | slice_files = [f for f in os.listdir(slices_folder_path) if re.match(pattern, f)] 40 | slice_file_names.extend(slice_files) 41 | 42 | output_path = os.path.join(data_folder,'train_slices.txt') 43 | with open(output_path, 'w') as f: 44 | for slice_file_name in slice_file_names: 45 | f.write(slice_file_name + '\n') 46 | 47 | print("Finished!") -------------------------------------------------------------------------------- /code_oa/dataloaders/dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from logging import root 4 | import os 5 | from scipy import ndimage 6 | import torch 7 | import random 8 | import h5py 9 | import pandas as pd 10 | from torch.utils.data import Dataset 11 | from torch.utils.data.sampler import Sampler 12 | import numpy as np 13 | import itertools 14 | from torchvision import transforms 15 | from scipy.ndimage 
import zoom 16 | import matplotlib.pyplot as plt 17 | from scipy.ndimage import gaussian_filter 18 | import cv2 19 | import skimage 20 | 21 | class h5DataSet(Dataset): 22 | def __init__( 23 | self, 24 | base_dir=None, 25 | split="train", 26 | num=None, 27 | transform=None, 28 | ops_weak=None, 29 | ops_strong=None, 30 | active_method=None 31 | ): 32 | self._base_dir = base_dir 33 | self.sample_list = [] 34 | self.split = split 35 | self.transform = transform 36 | self.ops_weak = ops_weak 37 | self.ops_strong = ops_strong 38 | self.active_method = active_method 39 | assert bool(ops_weak) == bool( 40 | ops_strong 41 | ), "For using CTAugment learned policies, provide both weak and strong batch augmentation policy" 42 | 43 | if self.split == "train": 44 | with open(self._base_dir + "/train_slices.txt", "r") as f1: 45 | self.sample_list = f1.readlines() 46 | self.sample_list = [item.replace("\n", "") for item in self.sample_list] 47 | elif self.split == "train_stage1" and self.active_method: 48 | with open(self._base_dir + f"/stage1_slice_{self.active_method}.txt", "r") as f1: 49 | self.sample_list = f1.readlines() 50 | self.sample_list = [item.replace("\n", "") for item in self.sample_list] 51 | elif self.split == "train_stage2" and self.active_method: 52 | with open(self._base_dir + f"/all_slice_{self.active_method}.txt", "r") as f1: 53 | self.sample_list = f1.readlines() 54 | self.sample_list = [item.replace("\n", "") for item in self.sample_list] 55 | elif self.split == "semi_train" and self.active_method: 56 | with open(self._base_dir + f"/all_slice_{self.active_method}.txt", "r") as f1: 57 | self.sample_list = f1.readlines() 58 | self.sample_list = [item.replace("\n", "") for item in self.sample_list] 59 | elif self.split == "semi_train": 60 | with open(self._base_dir + "/all_slice.txt", "r") as f1: 61 | self.sample_list = f1.readlines() 62 | self.sample_list = [item.replace("\n", "") for item in self.sample_list] 63 | elif self.split == "val": 64 | with open(self._base_dir + "/vallist.txt", "r") as f: 65 | self.sample_list = f.readlines() 66 | self.sample_list = [item.replace("\n", "") for item in self.sample_list] 67 | if num is not None and self.split == "train" or "train_stage1" or "train_stage2": 68 | self.sample_list = self.sample_list[:num] 69 | print("total {} samples".format(len(self.sample_list))) 70 | 71 | def __len__(self): 72 | return len(self.sample_list) 73 | 74 | def __getitem__(self, idx): 75 | case = self.sample_list[idx] 76 | if self.split == "val": 77 | h5f = h5py.File(self._base_dir + "/{}".format(case), "r") 78 | elif self.split == "train": 79 | h5f = h5py.File(self._base_dir + "/slices/{}".format(case), "r") 80 | elif self.split == "train_stage1": 81 | h5f = h5py.File(self._base_dir + "/slices/{}".format(case), "r") 82 | elif self.split == "train_stage2": 83 | h5f = h5py.File(self._base_dir + "/slices/{}".format(case), "r") 84 | elif self.split == "semi_train": 85 | h5f = h5py.File(self._base_dir + "/slices/{}".format(case), "r") 86 | 87 | image = h5f["image"][:] 88 | label = h5f["label"][:] 89 | image = image.astype(np.float32) 90 | 91 | label = label.astype(np.uint8) 92 | sample = {"image": image, "label": label} 93 | if (self.split == "train" or self.split == "train_stage1" or self.split == "train_stage2" or 94 | self.split == "semi_train"): 95 | if None not in (self.ops_weak, self.ops_strong): 96 | sample = self.transform(sample, self.ops_weak, self.ops_strong) 97 | else: 98 | sample = self.transform(sample) 99 | sample["idx"] = case 100 | return sample 101 | 
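# Note on the truncation guard in h5DataSet.__init__ above: the expression
# `num is not None and self.split == "train" or "train_stage1" or "train_stage2"`
# groups as `(num is not None and self.split == "train") or "train_stage1" or ...`,
# and the bare string literals are always truthy, so the `[:num]` cut is applied for
# every split whenever `num` is given. If the cut is meant only for the training
# splits, a minimal sketch of an explicit check (hedged rewrite, not the original code) is:
#
#     if num is not None and self.split in ("train", "train_stage1", "train_stage2"):
#         self.sample_list = self.sample_list[:num]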
102 | def random_rot_flip(image, label=None): 103 | k = np.random.randint(0, 4) 104 | image = np.rot90(image, k) 105 | axis = np.random.randint(0, 2) 106 | image = np.flip(image, axis=axis).copy() 107 | if label is not None: 108 | label = np.rot90(label, k) 109 | label = np.flip(label, axis=axis).copy() 110 | return image, label 111 | else: 112 | return image 113 | 114 | def random_rotate(image, label): 115 | angle = np.random.randint(-20, 20) 116 | image = ndimage.rotate(image, angle, order=0, reshape=False) 117 | label = ndimage.rotate(label, angle, order=0, reshape=False) 118 | return image, label 119 | 120 | def gaussian_noise(image, label): 121 | mean = 0 122 | std = 0.05 123 | noise = np.random.normal(mean, std, image.shape) 124 | image = image + noise 125 | return image, label 126 | 127 | def gaussian_blur(image, label): 128 | std_range = [0, 1] 129 | std = np.random.uniform(std_range[0], std_range[1]) 130 | image = gaussian_filter(image, std, order=0) 131 | return image, label 132 | 133 | def gammacorrection(image, label): 134 | gamma_min, gamma_max = 0.7, 1.5 135 | flip_prob = 0.5 136 | gamma_c = random.random() * (gamma_max - gamma_min) + gamma_min 137 | v_min = image.min() 138 | v_max = image.max() 139 | if (v_min < v_max): 140 | image = (image - v_min) / (v_max - v_min) 141 | if (np.random.uniform() < flip_prob): 142 | image = 1.0 - image 143 | image = np.power(image, gamma_c) * (v_max - v_min) + v_min 144 | image = image 145 | return image, label 146 | 147 | def contrastaug(image, label): 148 | contrast_range = [0.9, 1.1] 149 | preserve = True 150 | factor = np.random.uniform(contrast_range[0], contrast_range[1]) 151 | mean = image.mean() 152 | if preserve: 153 | minm = image.min() 154 | maxm = image.max() 155 | image = (image - mean) * factor + mean 156 | if preserve: 157 | image[image < minm] = minm 158 | image[image > maxm] = maxm 159 | return image, label 160 | 161 | def random_equalize_hist(image): 162 | image = skimage.exposure.equalize_hist(image) 163 | return image 164 | 165 | def random_sharpening(image): 166 | blurred = ndimage.gaussian_filter(image, 3) 167 | blurred_filter = ndimage.gaussian_filter(blurred, 1) 168 | alpha = random.randrange(1, 10) 169 | image = blurred + alpha * (blurred - blurred_filter) 170 | return image 171 | 172 | def min_max_normalize(image): 173 | min_val = image.min() 174 | max_val = image.max() 175 | normalized_image = (image - min_val) / (max_val - min_val) 176 | return normalized_image 177 | 178 | class RandomGenerator(object): 179 | def __init__(self, output_size, SpatialAug=True, IntensityAug=True, NonlinearAug=False): 180 | self.output_size = output_size 181 | self.SpatialAug = SpatialAug 182 | self.IntensityAug = IntensityAug 183 | self.NonlinearAug = NonlinearAug 184 | def __call__(self, sample): 185 | image, label = sample["image"], sample["label"] 186 | 187 | # ind = random.randrange(0, img.shape[0]) 188 | # image = img[ind, ...] 189 | # label = lab[ind, ...] 
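        # The cascade below is applied per 2D slice, each branch stochastically:
        # (1) optional Bezier-curve nonlinear intensity remapping (NonlinearAug);
        # (2) at most one spatial transform: a random 90-degree rotation/flip, or a
        #     small-angle (+/-20 degree) rotation;
        # (3) independent intensity perturbations (gamma correction, contrast,
        #     histogram equalization, sharpening, Gaussian blur/noise), each gated
        #     by its own probability;
        # (4) nearest-neighbour zoom to `output_size` and conversion to tensors.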
190 | if self.NonlinearAug: 191 | if random.random() > 0.5: 192 | image = nonlinear_transformation(image) 193 | if self.SpatialAug: 194 | if random.random() > 0.5: 195 | image, label = random_rot_flip(image, label) 196 | elif random.random() > 0.5: 197 | image, label = random_rotate(image, label) 198 | if self.IntensityAug: 199 | if random.random() > 0.7: 200 | image, label = gammacorrection(image, label) 201 | if random.random() > 0.7: 202 | image, label = contrastaug(image, label) 203 | if random.random() > 0.7: 204 | image = random_equalize_hist(image) 205 | if random.random() > 0.7: 206 | image = random_sharpening(image) 207 | if random.random() > 0.7: 208 | image, label = gaussian_blur(image, label) 209 | if random.random() > 0.5: 210 | image, label = gaussian_noise(image, label) 211 | 212 | x, y = image.shape 213 | image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y), order=0) 214 | label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0) 215 | image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0) 216 | label = torch.from_numpy(label.astype(np.uint8)) 217 | sample = {"image": image, "label": label} 218 | return sample 219 | 220 | class TwoStreamBatchSampler(Sampler): 221 | """Iterate two sets of indices 222 | 223 | An 'epoch' is one iteration through the primary indices. 224 | During the epoch, the secondary indices are iterated through 225 | as many times as needed. 226 | """ 227 | 228 | def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size): 229 | self.primary_indices = primary_indices 230 | self.secondary_indices = secondary_indices 231 | self.secondary_batch_size = secondary_batch_size 232 | self.primary_batch_size = batch_size - secondary_batch_size 233 | 234 | assert len(self.primary_indices) >= self.primary_batch_size > 0 235 | assert len(self.secondary_indices) >= self.secondary_batch_size > 0 236 | 237 | def __iter__(self): 238 | primary_iter = iterate_once(self.primary_indices) 239 | secondary_iter = iterate_eternally(self.secondary_indices) 240 | return ( 241 | primary_batch + secondary_batch 242 | for (primary_batch, secondary_batch) 243 | in zip(grouper(primary_iter, self.primary_batch_size), 244 | grouper(secondary_iter, self.secondary_batch_size)) 245 | ) 246 | 247 | def __len__(self): 248 | return len(self.primary_indices) // self.primary_batch_size 249 | 250 | 251 | def iterate_once(iterable): 252 | return np.random.permutation(iterable) 253 | 254 | 255 | def iterate_eternally(indices): 256 | def infinite_shuffles(): 257 | while True: 258 | yield np.random.permutation(indices) 259 | 260 | return itertools.chain.from_iterable(infinite_shuffles()) 261 | 262 | 263 | def grouper(iterable, n): 264 | "Collect data into fixed-length chunks or blocks" 265 | # grouper('ABCDEFG', 3) --> ABC DEF" 266 | args = [iter(iterable)] * n 267 | return zip(*args) 268 | # [5] Zhou Z, Qi L, Yang X, et al. Generalizable cross-modality medical image segmentation via style augmentation and dual normalization[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022: 20856-20865. 
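# A minimal usage sketch for the TwoStreamBatchSampler defined above: it builds batches
# that mix a fixed number of indices from two pools (e.g. pseudo-labeled and
# actively-labeled slices). The index lists and batch sizes here are hypothetical and
# not taken from this repository; the sketch is kept in comments so module import is
# unaffected.
#
#   labeled_idxs = list(range(0, 20))        # e.g. slices selected for annotation
#   unlabeled_idxs = list(range(20, 200))    # e.g. pseudo-labeled slices
#   batch_sampler = TwoStreamBatchSampler(
#       primary_indices=unlabeled_idxs,      # iterated exactly once per epoch
#       secondary_indices=labeled_idxs,      # re-shuffled and recycled as needed
#       batch_size=16,
#       secondary_batch_size=8)
#   train_loader = torch.utils.data.DataLoader(db_train, batch_sampler=batch_sampler,
#                                              num_workers=1)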
269 | 270 | import numpy as np 271 | import random 272 | import matplotlib.pyplot as plt 273 | try: 274 | from scipy.special import comb 275 | except: 276 | from scipy.misc import comb 277 | """ 278 | this is for none linear transformation 279 | 280 | 281 | """ 282 | 283 | def bernstein_poly(i, n, t): 284 | """ 285 | The Bernstein polynomial of n, i as a function of t 286 | """ 287 | return comb(n, i) * ( t**(n-i) ) * (1 - t)**i 288 | 289 | 290 | def bezier_curve(points, nTimes=1000): 291 | """ 292 | Given a set of control points, return the 293 | bezier curve defined by the control points. 294 | Control points should be a list of lists, or list of tuples 295 | such as [ [1,1], 296 | [2,3], 297 | [4,5], ..[Xn, Yn] ] 298 | nTimes is the number of time steps, defaults to 1000 299 | See http://processingjs.nihongoresources.com/bezierinfo/ 300 | """ 301 | 302 | nPoints = len(points) 303 | xPoints = np.array([p[0] for p in points]) 304 | yPoints = np.array([p[1] for p in points]) 305 | 306 | t = np.linspace(0.0, 1.0, nTimes) 307 | 308 | polynomial_array = np.array([bernstein_poly(i, nPoints-1, t) for i in range(0, nPoints)]) 309 | 310 | xvals = np.dot(xPoints, polynomial_array) 311 | yvals = np.dot(yPoints, polynomial_array) 312 | 313 | return xvals, yvals 314 | 315 | 316 | def nonlinear_transformation(x): 317 | points = [[0, 0], [random.random(), random.random()], [random.random(), random.random()], [1, 1]] 318 | xvals, yvals = bezier_curve(points, nTimes=1000) 319 | if random.random() < 0.5: 320 | # Half change to get flip 321 | xvals = np.sort(xvals) 322 | else: 323 | xvals, yvals = np.sort(xvals), np.sort(yvals) 324 | nonlinear_x = np.interp(x, xvals, yvals) 325 | return nonlinear_x 326 | 327 | def visualize_sample(sample): 328 | image = sample["image"][0].numpy() # Assuming the batch size is 1 329 | label = sample["label"][0].numpy() 330 | 331 | plt.figure(figsize=(10, 5)) 332 | 333 | plt.subplot(1, 2, 1) 334 | plt.imshow(image[0], cmap="gray") 335 | plt.title("Original Image") 336 | 337 | plt.subplot(1, 2, 2) 338 | if len(label.shape) == 2: 339 | plt.imshow(label, cmap="gray") 340 | elif len(label.shape) == 3: 341 | plt.imshow(label[0], cmap="gray") # Assuming label is 3D (batch_size, height, width) 342 | else: 343 | raise ValueError("Invalid shape for label data") 344 | 345 | plt.title("Original Label") 346 | 347 | plt.show() 348 | 349 | 350 | if __name__ == "__main__": 351 | root_dir = r"F:\SFDA\data\data_preprocessed\SAML\3T" 352 | dataset = h5DataSet(base_dir=root_dir, split="train") 353 | db_train = h5DataSet(base_dir=root_dir, split="train",transform=transforms.Compose([ 354 | RandomGenerator(output_size=(384, 384)), 355 | ])) 356 | train_loader = torch.utils.data.DataLoader(db_train, batch_size=24, shuffle=True, num_workers=1) 357 | for sample in train_loader: 358 | visualize_sample(sample) 359 | 360 | db_val = h5DataSet(base_dir=root_dir, split="val") 361 | 362 | valloader = torch.utils.data.DataLoader(db_val, batch_size=1, shuffle=False, 363 | num_workers=1) 364 | for sample in train_loader: 365 | image = sample['image'] 366 | label = sample['label'] 367 | print(image.shape, label.shape) 368 | print(image.min(), image.max(), label.max()) 369 | -------------------------------------------------------------------------------- /code_oa/dataloaders/preprocess.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import h5py 4 | import numpy as np 5 | import SimpleITK as sitk 6 | import re 7 | from scipy import ndimage 8 
| image_pattern = re.compile(r'Case(\d+).nii.gz') 9 | 10 | data_raw_path = "/media/disk2t_solid/zihao/SFDA/data/data_raw/Prostate" 11 | data_preprocessed_path = "/media/disk2t_solid/zihao/SFDA/data/data_preprocessed/Prostate" 12 | 13 | subfolders = glob.glob(os.path.join(data_raw_path, "*")) 14 | 15 | for subfolder in subfolders: 16 | image_paths = glob.glob(os.path.join(subfolder, "*.nii.gz")) 17 | 18 | subfolder_name = os.path.basename(subfolder) 19 | 20 | for image_path in image_paths: 21 | 22 | image_match = image_pattern.search(os.path.basename(image_path)) 23 | if image_match: 24 | case_number = image_match.group(1) 25 | print("Processing Case:", case_number) 26 | 27 | label_path = os.path.join(subfolder, f'Case{case_number}_segmentation.nii.gz') 28 | 29 | if os.path.exists(label_path): 30 | print("Processing Label:", label_path) 31 | 32 | img_itk = sitk.ReadImage(image_path) 33 | origin = img_itk.GetOrigin() 34 | spacing = img_itk.GetSpacing() 35 | direction = img_itk.GetDirection() 36 | image = sitk.GetArrayFromImage(img_itk) 37 | 38 | label_itk = sitk.ReadImage(label_path) 39 | label = sitk.GetArrayFromImage(label_itk).astype(np.uint8) 40 | image = (image - image.min()) / (image.max() - image.min()) 41 | image = image.astype(np.float32) 42 | 43 | num_slices = image.shape[0] 44 | 45 | item = image_path.split("/")[-1].split(".")[0] 46 | if image.shape != label.shape: 47 | print("Error") 48 | print(item) 49 | file_path = os.path.join(data_preprocessed_path, subfolder_name) 50 | hdf5_file_path = os.path.join(file_path, 51 | f'Case{case_number}.h5') 52 | if not os.path.exists(file_path): 53 | os.makedirs(file_path) 54 | os.makedirs(os.path.join(file_path, 'slices')) 55 | with h5py.File(hdf5_file_path, 'w') as f: 56 | f.create_dataset('image', data=image, compression="gzip") 57 | f.create_dataset('label', data=label, compression="gzip") 58 | f.create_dataset('spacing',data=spacing, compression="gzip") 59 | 60 | for slice_ind in range(num_slices): 61 | hdf5_file_path = os.path.join(file_path, 62 | f'slices/Case{case_number}_slice{slice_ind}.h5') 63 | 64 | with h5py.File(hdf5_file_path, 'w') as f: 65 | f.create_dataset('image', data=image[slice_ind], compression="gzip") 66 | f.create_dataset('label', data=label[slice_ind], compression="gzip") 67 | print("Preprocess completed.") 68 | -------------------------------------------------------------------------------- /code_oa/dataloaders/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | import matplotlib.pyplot as plt 6 | from skimage import measure 7 | import scipy.ndimage as nd 8 | 9 | 10 | def recursive_glob(rootdir='.', suffix=''): 11 | """Performs recursive glob with given suffix and rootdir 12 | :param rootdir is the root directory 13 | :param suffix is the suffix to be searched 14 | """ 15 | return [os.path.join(looproot, filename) 16 | for looproot, _, filenames in os.walk(rootdir) 17 | for filename in filenames if filename.endswith(suffix)] 18 | 19 | def get_cityscapes_labels(): 20 | return np.array([ 21 | # [ 0, 0, 0], 22 | [128, 64, 128], 23 | [244, 35, 232], 24 | [70, 70, 70], 25 | [102, 102, 156], 26 | [190, 153, 153], 27 | [153, 153, 153], 28 | [250, 170, 30], 29 | [220, 220, 0], 30 | [107, 142, 35], 31 | [152, 251, 152], 32 | [0, 130, 180], 33 | [220, 20, 60], 34 | [255, 0, 0], 35 | [0, 0, 142], 36 | [0, 0, 70], 37 | [0, 60, 100], 38 | [0, 80, 100], 39 | [0, 0, 230], 40 | [119, 11, 32]]) 41 | 42 | def 
get_pascal_labels(): 43 | """Load the mapping that associates pascal classes with label colors 44 | Returns: 45 | np.ndarray with dimensions (21, 3) 46 | """ 47 | return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], 48 | [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], 49 | [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], 50 | [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], 51 | [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], 52 | [0, 64, 128]]) 53 | 54 | 55 | def encode_segmap(mask): 56 | """Encode segmentation label images as pascal classes 57 | Args: 58 | mask (np.ndarray): raw segmentation label image of dimension 59 | (M, N, 3), in which the Pascal classes are encoded as colours. 60 | Returns: 61 | (np.ndarray): class map with dimensions (M,N), where the value at 62 | a given location is the integer denoting the class index. 63 | """ 64 | mask = mask.astype(int) 65 | label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16) 66 | for ii, label in enumerate(get_pascal_labels()): 67 | label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii 68 | label_mask = label_mask.astype(int) 69 | return label_mask 70 | 71 | 72 | def decode_seg_map_sequence(label_masks, dataset='pascal'): 73 | rgb_masks = [] 74 | for label_mask in label_masks: 75 | rgb_mask = decode_segmap(label_mask, dataset) 76 | rgb_masks.append(rgb_mask) 77 | rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2])) 78 | return rgb_masks 79 | 80 | def decode_segmap(label_mask, dataset, plot=False): 81 | """Decode segmentation class labels into a color image 82 | Args: 83 | label_mask (np.ndarray): an (M,N) array of integer values denoting 84 | the class label at each spatial location. 85 | plot (bool, optional): whether to show the resulting color image 86 | in a figure. 87 | Returns: 88 | (np.ndarray, optional): the resulting decoded color image. 
89 | """ 90 | if dataset == 'pascal': 91 | n_classes = 21 92 | label_colours = get_pascal_labels() 93 | elif dataset == 'cityscapes': 94 | n_classes = 19 95 | label_colours = get_cityscapes_labels() 96 | else: 97 | raise NotImplementedError 98 | 99 | r = label_mask.copy() 100 | g = label_mask.copy() 101 | b = label_mask.copy() 102 | for ll in range(0, n_classes): 103 | r[label_mask == ll] = label_colours[ll, 0] 104 | g[label_mask == ll] = label_colours[ll, 1] 105 | b[label_mask == ll] = label_colours[ll, 2] 106 | rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3)) 107 | rgb[:, :, 0] = r / 255.0 108 | rgb[:, :, 1] = g / 255.0 109 | rgb[:, :, 2] = b / 255.0 110 | if plot: 111 | plt.imshow(rgb) 112 | plt.show() 113 | else: 114 | return rgb 115 | 116 | def generate_param_report(logfile, param): 117 | log_file = open(logfile, 'w') 118 | # for key, val in param.items(): 119 | # log_file.write(key + ':' + str(val) + '\n') 120 | log_file.write(str(param)) 121 | log_file.close() 122 | 123 | def cross_entropy2d(logit, target, ignore_index=255, weight=None, size_average=True, batch_average=True): 124 | n, c, h, w = logit.size() 125 | # logit = logit.permute(0, 2, 3, 1) 126 | target = target.squeeze(1) 127 | if weight is None: 128 | criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index, size_average=False) 129 | else: 130 | criterion = nn.CrossEntropyLoss(weight=torch.from_numpy(np.array(weight)).float().cuda(), ignore_index=ignore_index, size_average=False) 131 | loss = criterion(logit, target.long()) 132 | 133 | if size_average: 134 | loss /= (h * w) 135 | 136 | if batch_average: 137 | loss /= n 138 | 139 | return loss 140 | 141 | def lr_poly(base_lr, iter_, max_iter=100, power=0.9): 142 | return base_lr * ((1 - float(iter_) / max_iter) ** power) 143 | 144 | 145 | def get_iou(pred, gt, n_classes=21): 146 | total_iou = 0.0 147 | for i in range(len(pred)): 148 | pred_tmp = pred[i] 149 | gt_tmp = gt[i] 150 | 151 | intersect = [0] * n_classes 152 | union = [0] * n_classes 153 | for j in range(n_classes): 154 | match = (pred_tmp == j) + (gt_tmp == j) 155 | 156 | it = torch.sum(match == 2).item() 157 | un = torch.sum(match > 0).item() 158 | 159 | intersect[j] += it 160 | union[j] += un 161 | 162 | iou = [] 163 | for k in range(n_classes): 164 | if union[k] == 0: 165 | continue 166 | iou.append(intersect[k] / union[k]) 167 | 168 | img_iou = (sum(iou) / len(iou)) 169 | total_iou += img_iou 170 | 171 | return total_iou 172 | 173 | def get_dice(pred, gt): 174 | total_dice = 0.0 175 | pred = pred.long() 176 | gt = gt.long() 177 | for i in range(len(pred)): 178 | pred_tmp = pred[i] 179 | gt_tmp = gt[i] 180 | dice = 2.0*torch.sum(pred_tmp*gt_tmp).item()/(1.0+torch.sum(pred_tmp**2)+torch.sum(gt_tmp**2)).item() 181 | print(dice) 182 | total_dice += dice 183 | 184 | return total_dice 185 | 186 | def get_mc_dice(pred, gt, num=2): 187 | # num is the total number of classes, include the background 188 | total_dice = np.zeros(num-1) 189 | pred = pred.long() 190 | gt = gt.long() 191 | for i in range(len(pred)): 192 | for j in range(1, num): 193 | pred_tmp = (pred[i]==j) 194 | gt_tmp = (gt[i]==j) 195 | dice = 2.0*torch.sum(pred_tmp*gt_tmp).item()/(1.0+torch.sum(pred_tmp**2)+torch.sum(gt_tmp**2)).item() 196 | total_dice[j-1] +=dice 197 | return total_dice 198 | 199 | def post_processing(prediction): 200 | prediction = nd.binary_fill_holes(prediction) 201 | label_cc, num_cc = measure.label(prediction,return_num=True) 202 | total_cc = np.sum(prediction) 203 | measure.regionprops(label_cc) 204 | 
for cc in range(1,num_cc+1): 205 | single_cc = (label_cc==cc) 206 | single_vol = np.sum(single_cc) 207 | if single_vol/total_cc<0.2: 208 | prediction[single_cc]=0 209 | 210 | return prediction 211 | -------------------------------------------------------------------------------- /code_oa/networks/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils.spectral_norm import spectral_norm 4 | 5 | class MaskDiscriminator_SN(nn.Module): 6 | def __init__(self, num_classes=2, ndf=32): 7 | super(MaskDiscriminator_SN, self).__init__() 8 | self.conv0 = spectral_norm(nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1)) 9 | self.conv1 = spectral_norm(nn.Conv2d(ndf, ndf*2, kernel_size=4, stride=2, padding=1)) 10 | self.conv2 = spectral_norm(nn.Conv2d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1)) 11 | self.conv3 = spectral_norm(nn.Conv2d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1)) 12 | self.conv4 = spectral_norm(nn.Conv2d(ndf*8, ndf*16, kernel_size=4, stride=2, padding=1)) 13 | self.classifier = spectral_norm(nn.Linear(ndf*16, 1)) 14 | self.avgpool = nn.AvgPool2d((7, 7)) 15 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 16 | self.dropout = nn.Dropout2d(0.5) 17 | # self.tanh = nn.Tanh() 18 | 19 | def forward(self, map): 20 | x = self.conv0(map) 21 | x = self.leaky_relu(x) 22 | x = self.dropout(x) 23 | x = self.conv1(x) 24 | x = self.leaky_relu(x) 25 | x = self.dropout(x) 26 | x = self.conv2(x) 27 | x = self.leaky_relu(x) 28 | x = self.dropout(x) 29 | x = self.conv3(x) 30 | x = self.leaky_relu(x) 31 | x = self.dropout(x) 32 | x = self.conv4(x) 33 | x = self.leaky_relu(x) 34 | x = self.avgpool(x) 35 | x = x.view(x.size(0), -1) 36 | x = self.classifier(x) 37 | # x = self.tanh(x) 38 | 39 | return x 40 | 41 | class MaskDiscriminator(nn.Module): 42 | def __init__(self, num_classes=2, ndf=16): 43 | super(MaskDiscriminator, self).__init__() 44 | self.conv0 = nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1) 45 | self.conv1 = nn.Conv2d(ndf, ndf*2, kernel_size=4, stride=2, padding=1) 46 | self.conv2 = nn.Conv2d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1) 47 | self.conv3 = nn.Conv2d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1) 48 | self.conv4 = nn.Conv2d(ndf*8, ndf*16, kernel_size=4, stride=2, padding=1) 49 | self.classifier = nn.Linear(ndf*16, 1) 50 | self.avgpool = nn.AvgPool2d((7, 7)) 51 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 52 | self.dropout = nn.Dropout2d(0.5) 53 | self.sigmoid = nn.Sigmoid() 54 | 55 | def forward(self, map): 56 | x = self.conv0(map) 57 | x = self.leaky_relu(x) 58 | x = self.dropout(x) 59 | x = self.conv1(x) 60 | x = self.leaky_relu(x) 61 | x = self.dropout(x) 62 | x = self.conv2(x) 63 | x = self.leaky_relu(x) 64 | x = self.dropout(x) 65 | x = self.conv3(x) 66 | x = self.leaky_relu(x) 67 | x = self.dropout(x) 68 | x = self.conv4(x) 69 | x = self.leaky_relu(x) 70 | x = self.avgpool(x) 71 | x = x.view(x.size(0), -1) 72 | x = self.classifier(x) 73 | x = self.sigmoid(x) 74 | 75 | return x 76 | 77 | def create_circle_tensor(size): 78 | y, x = torch.meshgrid([torch.arange(size), torch.arange(size)]) 79 | distance = torch.sqrt((x - size // 2) ** 2 + (y - size // 2) ** 2) 80 | circle_tensor = (distance <= size // 4).float().unsqueeze(0) 81 | return circle_tensor 82 | 83 | def one_hot_encoder(n_classes, input_tensor): 84 | tensor_list = [] 85 | 86 | for i in range(n_classes): 87 | temp_prob = 
(input_tensor == i).unsqueeze(1) 88 | tensor_list.append(temp_prob) 89 | 90 | output_tensor = torch.cat(tensor_list, dim=1) 91 | 92 | return output_tensor.float() 93 | if __name__ == "__main__": 94 | # 示例 95 | model = MaskDiscriminator() 96 | print(model) 97 | map_input = torch.randn(1, 2, 384, 384) 98 | circle_tensor = create_circle_tensor(size=384) 99 | circle_tensor = one_hot_encoder(n_classes=2, input_tensor=circle_tensor) 100 | print(circle_tensor.shape) 101 | save_mode_path = (r'F:\SFDA\model\model_D\3T_al_labeled\Discriminator_best_model.pth') 102 | model.load_state_dict(torch.load(save_mode_path)) 103 | output = model(circle_tensor) 104 | # output = model(map_input) 105 | print("output:", output) 106 | -------------------------------------------------------------------------------- /code_oa/networks/unet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | The implementation is borrowed from: https://github.com/HiLab-git/PyMIC 4 | """ 5 | from __future__ import division, print_function 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | from torch.distributions.uniform import Uniform 11 | 12 | 13 | def kaiming_normal_init_weight(model): 14 | for m in model.modules(): 15 | if isinstance(m, nn.Conv3d): 16 | torch.nn.init.kaiming_normal_(m.weight) 17 | elif isinstance(m, nn.BatchNorm3d): 18 | m.weight.data.fill_(1) 19 | m.bias.data.zero_() 20 | return model 21 | 22 | 23 | def sparse_init_weight(model): 24 | for m in model.modules(): 25 | if isinstance(m, nn.Conv3d): 26 | torch.nn.init.sparse_(m.weight, sparsity=0.1) 27 | elif isinstance(m, nn.BatchNorm3d): 28 | m.weight.data.fill_(1) 29 | m.bias.data.zero_() 30 | return model 31 | 32 | 33 | class ConvBlock(nn.Module): 34 | """two convolution layers with batch norm and leaky relu""" 35 | 36 | def __init__(self, in_channels, out_channels, dropout_p): 37 | super(ConvBlock, self).__init__() 38 | self.conv_conv = nn.Sequential( 39 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 40 | nn.BatchNorm2d(out_channels), 41 | nn.LeakyReLU(), 42 | nn.Dropout(dropout_p), 43 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), 44 | nn.BatchNorm2d(out_channels), 45 | nn.LeakyReLU() 46 | ) 47 | 48 | def forward(self, x): 49 | return self.conv_conv(x) 50 | 51 | 52 | class DownBlock(nn.Module): 53 | """Downsampling followed by ConvBlock""" 54 | 55 | def __init__(self, in_channels, out_channels, dropout_p): 56 | super(DownBlock, self).__init__() 57 | self.maxpool_conv = nn.Sequential( 58 | nn.MaxPool2d(2), 59 | ConvBlock(in_channels, out_channels, dropout_p) 60 | 61 | ) 62 | 63 | def forward(self, x): 64 | return self.maxpool_conv(x) 65 | 66 | 67 | class UpBlock(nn.Module): 68 | """Upssampling followed by ConvBlock""" 69 | 70 | def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, 71 | bilinear=True): 72 | super(UpBlock, self).__init__() 73 | self.bilinear = bilinear 74 | if bilinear: 75 | self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size=1) 76 | self.up = nn.Upsample( 77 | scale_factor=2, mode='bilinear', align_corners=True) 78 | else: 79 | self.up = nn.ConvTranspose2d( 80 | in_channels1, in_channels2, kernel_size=2, stride=2) 81 | self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) 82 | 83 | def forward(self, x1, x2): 84 | if self.bilinear: 85 | x1 = self.conv1x1(x1) 86 | x1 = self.up(x1) 87 | x = torch.cat([x2, x1], dim=1) 88 | return self.conv(x) 89 | 90 | 91 | class 
Encoder(nn.Module): 92 | def __init__(self, params): 93 | super(Encoder, self).__init__() 94 | self.params = params 95 | self.in_chns = self.params['in_chns'] 96 | self.ft_chns = self.params['feature_chns'] 97 | self.n_class = self.params['class_num'] 98 | self.bilinear = self.params['bilinear'] 99 | self.dropout = self.params['dropout'] 100 | assert (len(self.ft_chns) == 5) 101 | self.in_conv = ConvBlock( 102 | self.in_chns, self.ft_chns[0], self.dropout[0]) 103 | self.down1 = DownBlock( 104 | self.ft_chns[0], self.ft_chns[1], self.dropout[1]) 105 | self.down2 = DownBlock( 106 | self.ft_chns[1], self.ft_chns[2], self.dropout[2]) 107 | self.down3 = DownBlock( 108 | self.ft_chns[2], self.ft_chns[3], self.dropout[3]) 109 | self.down4 = DownBlock( 110 | self.ft_chns[3], self.ft_chns[4], self.dropout[4]) 111 | 112 | def forward(self, x): 113 | x0 = self.in_conv(x) 114 | x1 = self.down1(x0) 115 | x2 = self.down2(x1) 116 | x3 = self.down3(x2) 117 | x4 = self.down4(x3) 118 | return [x0, x1, x2, x3, x4] 119 | 120 | 121 | class Decoder(nn.Module): 122 | def __init__(self, params): 123 | super(Decoder, self).__init__() 124 | self.params = params 125 | self.in_chns = self.params['in_chns'] 126 | self.ft_chns = self.params['feature_chns'] 127 | self.n_class = self.params['class_num'] 128 | self.bilinear = self.params['bilinear'] 129 | assert (len(self.ft_chns) == 5) 130 | 131 | self.up1 = UpBlock( 132 | self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0) 133 | self.up2 = UpBlock( 134 | self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0) 135 | self.up3 = UpBlock( 136 | self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0) 137 | self.up4 = UpBlock( 138 | self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0) 139 | 140 | self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, 141 | kernel_size=3, padding=1) 142 | 143 | def forward(self, feature): 144 | x0 = feature[0] 145 | x1 = feature[1] 146 | x2 = feature[2] 147 | x3 = feature[3] 148 | x4 = feature[4] 149 | 150 | x = self.up1(x4, x3) 151 | feature1 = x 152 | x = self.up2(x, x2) 153 | feature2 = x 154 | x = self.up3(x, x1) 155 | feature3 = x 156 | x = self.up4(x, x0) 157 | feature4 = x 158 | output = self.out_conv(x) 159 | return output, [feature1, feature2, feature3, feature4] 160 | 161 | 162 | class Decoder_DS(nn.Module): 163 | def __init__(self, params): 164 | super(Decoder_DS, self).__init__() 165 | self.params = params 166 | self.in_chns = self.params['in_chns'] 167 | self.ft_chns = self.params['feature_chns'] 168 | self.n_class = self.params['class_num'] 169 | self.bilinear = self.params['bilinear'] 170 | assert (len(self.ft_chns) == 5) 171 | 172 | self.up1 = UpBlock( 173 | self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0) 174 | self.up2 = UpBlock( 175 | self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0) 176 | self.up3 = UpBlock( 177 | self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0) 178 | self.up4 = UpBlock( 179 | self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0) 180 | 181 | self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, 182 | kernel_size=3, padding=1) 183 | self.out_conv_dp4 = nn.Conv2d(self.ft_chns[4], self.n_class, 184 | kernel_size=3, padding=1) 185 | self.out_conv_dp3 = nn.Conv2d(self.ft_chns[3], self.n_class, 186 | kernel_size=3, padding=1) 187 | self.out_conv_dp2 = nn.Conv2d(self.ft_chns[2], self.n_class, 188 | kernel_size=3, padding=1) 189 | self.out_conv_dp1 = nn.Conv2d(self.ft_chns[1], self.n_class, 190 | 
kernel_size=3, padding=1) 191 | 192 | def forward(self, feature, shape): 193 | x0 = feature[0] 194 | x1 = feature[1] 195 | x2 = feature[2] 196 | x3 = feature[3] 197 | x4 = feature[4] 198 | x = self.up1(x4, x3) 199 | dp3_out_seg = self.out_conv_dp3(x) 200 | dp3_out_seg = torch.nn.functional.interpolate(dp3_out_seg, shape) 201 | 202 | x = self.up2(x, x2) 203 | dp2_out_seg = self.out_conv_dp2(x) 204 | dp2_out_seg = torch.nn.functional.interpolate(dp2_out_seg, shape) 205 | 206 | x = self.up3(x, x1) 207 | dp1_out_seg = self.out_conv_dp1(x) 208 | dp1_out_seg = torch.nn.functional.interpolate(dp1_out_seg, shape) 209 | 210 | x = self.up4(x, x0) 211 | dp0_out_seg = self.out_conv(x) 212 | return dp0_out_seg, dp1_out_seg, dp2_out_seg, dp3_out_seg 213 | 214 | 215 | class Decoder_URPC(nn.Module): 216 | def __init__(self, params): 217 | super(Decoder_URPC, self).__init__() 218 | self.params = params 219 | self.in_chns = self.params['in_chns'] 220 | self.ft_chns = self.params['feature_chns'] 221 | self.n_class = self.params['class_num'] 222 | self.bilinear = self.params['bilinear'] 223 | assert (len(self.ft_chns) == 5) 224 | 225 | self.up1 = UpBlock( 226 | self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0) 227 | self.up2 = UpBlock( 228 | self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0) 229 | self.up3 = UpBlock( 230 | self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0) 231 | self.up4 = UpBlock( 232 | self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0) 233 | 234 | self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, 235 | kernel_size=3, padding=1) 236 | self.out_conv_dp4 = nn.Conv2d(self.ft_chns[4], self.n_class, 237 | kernel_size=3, padding=1) 238 | self.out_conv_dp3 = nn.Conv2d(self.ft_chns[3], self.n_class, 239 | kernel_size=3, padding=1) 240 | self.out_conv_dp2 = nn.Conv2d(self.ft_chns[2], self.n_class, 241 | kernel_size=3, padding=1) 242 | self.out_conv_dp1 = nn.Conv2d(self.ft_chns[1], self.n_class, 243 | kernel_size=3, padding=1) 244 | self.feature_noise = FeatureNoise() 245 | 246 | def forward(self, feature, shape): 247 | x0 = feature[0] 248 | x1 = feature[1] 249 | x2 = feature[2] 250 | x3 = feature[3] 251 | x4 = feature[4] 252 | x = self.up1(x4, x3) 253 | if self.training: 254 | dp3_out_seg = self.out_conv_dp3(Dropout(x, p=0.5)) 255 | else: 256 | dp3_out_seg = self.out_conv_dp3(x) 257 | dp3_out_seg = torch.nn.functional.interpolate(dp3_out_seg, shape) 258 | 259 | x = self.up2(x, x2) 260 | if self.training: 261 | dp2_out_seg = self.out_conv_dp2(FeatureDropout(x)) 262 | else: 263 | dp2_out_seg = self.out_conv_dp2(x) 264 | dp2_out_seg = torch.nn.functional.interpolate(dp2_out_seg, shape) 265 | 266 | x = self.up3(x, x1) 267 | if self.training: 268 | dp1_out_seg = self.out_conv_dp1(self.feature_noise(x)) 269 | else: 270 | dp1_out_seg = self.out_conv_dp1(x) 271 | dp1_out_seg = torch.nn.functional.interpolate(dp1_out_seg, shape) 272 | 273 | x = self.up4(x, x0) 274 | dp0_out_seg = self.out_conv(x) 275 | return dp0_out_seg, dp1_out_seg, dp2_out_seg, dp3_out_seg 276 | 277 | 278 | def Dropout(x, p=0.3): 279 | x = torch.nn.functional.dropout(x, p) 280 | return x 281 | 282 | 283 | def FeatureDropout(x): 284 | attention = torch.mean(x, dim=1, keepdim=True) 285 | max_val, _ = torch.max(attention.view( 286 | x.size(0), -1), dim=1, keepdim=True) 287 | threshold = max_val * np.random.uniform(0.7, 0.9) 288 | threshold = threshold.view(x.size(0), 1, 1, 1).expand_as(attention) 289 | drop_mask = (attention < threshold).float() 290 | x = x.mul(drop_mask) 291 
| return x 292 | 293 | 294 | class FeatureNoise(nn.Module): 295 | def __init__(self, uniform_range=0.3): 296 | super(FeatureNoise, self).__init__() 297 | self.uni_dist = Uniform(-uniform_range, uniform_range) 298 | 299 | def feature_based_noise(self, x): 300 | noise_vector = self.uni_dist.sample( 301 | x.shape[1:]).to(x.device).unsqueeze(0) 302 | x_noise = x.mul(noise_vector) + x 303 | return x_noise 304 | 305 | def forward(self, x): 306 | x = self.feature_based_noise(x) 307 | return x 308 | 309 | 310 | 311 | 312 | class UNet(nn.Module): 313 | def __init__(self, in_chns, class_num): 314 | super(UNet, self).__init__() 315 | 316 | params = {'in_chns': in_chns, 317 | 'feature_chns': [16, 32, 64, 128, 256], 318 | 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], 319 | 'class_num': class_num, 320 | 'bilinear': False, 321 | 'acti_func': 'relu'} 322 | 323 | self.encoder = Encoder(params) 324 | self.decoder = Decoder(params) 325 | 326 | def forward(self, x): 327 | feature = self.encoder(x) 328 | output, feature_list = self.decoder(feature) 329 | return output, feature 330 | 331 | 332 | class UNet_CCT(nn.Module): 333 | def __init__(self, in_chns, class_num): 334 | super(UNet_CCT, self).__init__() 335 | 336 | params = {'in_chns': in_chns, 337 | 'feature_chns': [16, 32, 64, 128, 256], 338 | 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], 339 | 'class_num': class_num, 340 | 'bilinear': False, 341 | 'acti_func': 'relu'} 342 | self.encoder = Encoder(params) 343 | self.main_decoder = Decoder(params) 344 | self.aux_decoder1 = Decoder(params) 345 | self.aux_decoder2 = Decoder(params) 346 | self.aux_decoder3 = Decoder(params) 347 | 348 | def forward(self, x): 349 | feature = self.encoder(x) 350 | main_seg = self.main_decoder(feature) 351 | aux1_feature = [FeatureNoise()(i) for i in feature] 352 | aux_seg1 = self.aux_decoder1(aux1_feature) 353 | aux2_feature = [Dropout(i) for i in feature] 354 | aux_seg2 = self.aux_decoder2(aux2_feature) 355 | aux3_feature = [FeatureDropout(i) for i in feature] 356 | aux_seg3 = self.aux_decoder3(aux3_feature) 357 | return main_seg, aux_seg1, aux_seg2, aux_seg3 358 | 359 | class MCNet2d_v1(nn.Module): 360 | def __init__(self, in_chns, class_num): 361 | super(MCNet2d_v1, self).__init__() 362 | 363 | params1 = {'in_chns': in_chns, 364 | 'feature_chns': [16, 32, 64, 128, 256], 365 | 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], 366 | 'class_num': class_num, 367 | 'bilinear': False, 368 | 'acti_func': 'relu'} 369 | params2 = {'in_chns': in_chns, 370 | 'feature_chns': [16, 32, 64, 128, 256], 371 | 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], 372 | 'class_num': class_num, 373 | 'bilinear': True, 374 | 'acti_func': 'relu'} 375 | self.encoder = Encoder(params1) 376 | self.decoder1 = Decoder(params1) 377 | self.decoder2 = Decoder(params2) 378 | 379 | def forward(self, x): 380 | feature = self.encoder(x) 381 | output1, _ = self.decoder1(feature) 382 | output2, _ = self.decoder2(feature) 383 | return output1, output2 384 | 385 | class UNet_UPL(nn.Module): 386 | def __init__(self, in_chns, class_num): 387 | super(UNet_UPL, self).__init__() 388 | 389 | params = {'in_chns': in_chns, 390 | 'feature_chns': [16, 32, 64, 128, 256], 391 | 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], 392 | 'class_num': class_num, 393 | 'bilinear': False, 394 | 'acti_func': 'relu'} 395 | self.encoder = Encoder(params) 396 | self.decoder = Decoder(params) 397 | self.aux_decoder1 = Decoder(params) 398 | self.aux_decoder2 = Decoder(params) 399 | self.aux_decoder3 = Decoder(params) 400 | 401 | 402 | def forward(self, x): 403 | 404 | A_1 = 
rotate_single_with_label(x, 1) 405 | A_2 = rotate_single_with_label(x, 2) 406 | A_3 = rotate_single_with_label(x, 3) 407 | 408 | feature = self.encoder(x) 409 | feature_1 = self.encoder(A_1) 410 | feature_2 = self.encoder(A_2) 411 | feature_3 = self.encoder(A_3) 412 | 413 | main_seg = self.decoder(feature)[0].softmax(1) 414 | aux_seg1 = self.aux_decoder1(feature_1)[0].softmax(1) 415 | aux_seg2 = self.aux_decoder2(feature_2)[0].softmax(1) 416 | aux_seg3 = self.aux_decoder3(feature_3)[0].softmax(1) 417 | 418 | aux_seg1 = rotate_single_with_label(aux_seg1, 3) 419 | aux_seg2 = rotate_single_with_label(aux_seg2, 2) 420 | aux_seg3 = rotate_single_with_label(aux_seg3, 1) 421 | return main_seg, aux_seg1, aux_seg2, aux_seg3 422 | 423 | def rotate_single_with_label(img, label): 424 | def tensor_rot_90(x): 425 | x_shape = list(x.shape) 426 | if (len(x_shape) == 4): 427 | return x.flip(3).transpose(2, 3) 428 | else: 429 | return x.flip(2).transpose(1, 2) 430 | 431 | def tensor_rot_180(x): 432 | x_shape = list(x.shape) 433 | if (len(x_shape) == 4): 434 | return x.flip(3).flip(2) 435 | else: 436 | return x.flip(2).flip(1) 437 | 438 | def tensor_rot_270(x): 439 | x_shape = list(x.shape) 440 | if (len(x_shape) == 4): 441 | return x.transpose(2, 3).flip(3) 442 | else: 443 | return x.transpose(1, 2).flip(2) 444 | if label == 1: 445 | img = tensor_rot_90(img) 446 | elif label == 2: 447 | img = tensor_rot_180(img) 448 | elif label == 3: 449 | img = tensor_rot_270(img) 450 | else: 451 | img = img 452 | return img 453 | 454 | class UNet_DST(nn.Module): 455 | def __init__(self, in_chns, class_num): 456 | super(UNet_DST, self).__init__() 457 | 458 | params1 = {'in_chns': in_chns, 459 | 'feature_chns': [16, 32, 64, 128, 256], 460 | 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], 461 | 'class_num': class_num, 462 | 'bilinear': False, 463 | 'acti_func': 'relu'} 464 | params2 = {'in_chns': in_chns, 465 | 'feature_chns': [16, 32, 64, 128, 256], 466 | 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], 467 | 'class_num': class_num, 468 | 'bilinear': False, 469 | 'acti_func': 'relu'} 470 | self.encoder = Encoder(params1) 471 | self.decoder = Decoder(params1) 472 | self.aux_decoder = Decoder(params2) 473 | 474 | def forward(self,x): 475 | feature = self.encoder(x) 476 | main_output = self.decoder(feature) 477 | aux_output = self.aux_decoder(feature) 478 | return main_output[0], aux_output[0] 479 | 480 | class UNet_Twodec(nn.Module): 481 | def __init__(self, in_chns, class_num): 482 | super(UNet_Twodec, self).__init__() 483 | 484 | params1 = {'in_chns': in_chns, 485 | 'feature_chns': [16, 32, 64, 128, 256], 486 | 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], 487 | 'class_num': class_num, 488 | 'bilinear': False, 489 | 'acti_func': 'relu'} 490 | params2 = {'in_chns': in_chns, 491 | 'feature_chns': [16, 32, 64, 128, 256], 492 | 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], 493 | 'class_num': class_num, 494 | 'bilinear': False, 495 | 'acti_func': 'relu'} 496 | self.encoder = Encoder(params1) 497 | self.decoder = Decoder(params1) 498 | self.aux_decoder = Decoder(params2) 499 | 500 | def forward(self,x): 501 | feature = self.encoder(x) 502 | main_output = self.decoder(feature) 503 | aux_output = self.aux_decoder(feature) 504 | return main_output[0], aux_output[0] 505 | 506 | class UNet_URPC(nn.Module): 507 | def __init__(self, in_chns, class_num): 508 | super(UNet_URPC, self).__init__() 509 | 510 | params = {'in_chns': in_chns, 511 | 'feature_chns': [16, 32, 64, 128, 256], 512 | 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], 513 | 'class_num': class_num, 514 | 
'bilinear': False, 515 | 'acti_func': 'relu'} 516 | self.encoder = Encoder(params) 517 | self.decoder = Decoder_URPC(params) 518 | 519 | def forward(self, x): 520 | shape = x.shape[2:] 521 | feature = self.encoder(x) 522 | dp1_out_seg, dp2_out_seg, dp3_out_seg, dp4_out_seg = self.decoder( 523 | feature, shape) 524 | return dp1_out_seg, dp2_out_seg, dp3_out_seg, dp4_out_seg 525 | 526 | 527 | class UNet_DS(nn.Module): 528 | def __init__(self, in_chns, class_num): 529 | super(UNet_DS, self).__init__() 530 | params = {'in_chns': in_chns, 531 | 'feature_chns': [16, 32, 64, 128, 256], 532 | 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], 533 | 'class_num': class_num, 534 | 'bilinear': False, 535 | 'acti_func': 'relu'} 536 | self.encoder = Encoder(params) 537 | self.decoder = Decoder_DS(params) 538 | 539 | def forward(self, x): 540 | shape = x.shape[2:] 541 | feature = self.encoder(x) 542 | dp0_out_seg, dp1_out_seg, dp2_out_seg, dp3_out_seg = self.decoder( 543 | feature, shape) 544 | return dp0_out_seg, dp1_out_seg, dp2_out_seg, dp3_out_seg 545 | 546 | if __name__ == "__main__": 547 | model = UNet(in_chns=1, class_num=2) 548 | model.train() 549 | map_input = torch.randn(1, 1, 384, 384) 550 | pre, feature = model(map_input) 551 | print(feature[-1].shape) -------------------------------------------------------------------------------- /code_oa/networks/unet3d.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from dynamic_network_architectures.architectures.unet import PlainConvUNet 4 | 5 | model = PlainConvUNet( 6 | input_channels=1, 7 | n_stages=6, 8 | features_per_stage=[32, 64, 128, 256, 320, 320], 9 | conv_op=torch.nn.Conv3d, 10 | kernel_sizes=[[1, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]], 11 | strides=[[1, 1, 1], [1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [1, 2, 2]], 12 | n_conv_per_stage=[2, 2, 2, 2, 2, 2], 13 | num_classes=2, 14 | n_conv_per_stage_decoder=[2, 2, 2, 2, 2], 15 | conv_bias=True, 16 | norm_op=torch.nn.InstanceNorm3d, 17 | norm_op_kwargs={"eps": 1e-05, "affine": True}, 18 | nonlin=torch.nn.LeakyReLU, 19 | nonlin_kwargs={"inplace": True} 20 | ) 21 | 22 | print(model) 23 | -------------------------------------------------------------------------------- /code_oa/selection.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import h5py 5 | from tqdm import tqdm 6 | from networks.unet import UNet 7 | from sklearn.cluster import KMeans 8 | from sklearn.metrics import pairwise_distances_argmin_min 9 | from itertools import combinations 10 | import torch.nn as nn 11 | import torch 12 | import numpy as np 13 | from tools import * 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--root_path', type=str, 16 | default='../data/data_preprocessed/NPC_SMU/SMU', help='Name of Experiment') 17 | parser.add_argument('--exp', type=str, 18 | default='NPC/source_train', help='experiment_name') 19 | parser.add_argument('--model', type=str, 20 | default='UNet', help='model_name') 21 | parser.add_argument('--num_classes', type=int, default=2, 22 | help='output channel of network') 23 | parser.add_argument('--seed', type=int, default=1337, help='random seed') 24 | parser.add_argument('--tta_num', type=int, default=8, 25 | help='test time augmentation') 26 | parser.add_argument('--spatial_aug', type=bool, default=False) 27 | parser.add_argument('--intensity_aug', type=bool, default=True) 28 | parser.add_argument('--C', type=int, 
default=4, help='Hyperparameter: capacity of Dtu') 29 | 30 | 31 | def cluster_and_select_samples(combined_data, k=3): 32 | features = [data[2] for data in combined_data] 33 | X = np.array(features) 34 | kmeans = KMeans(n_clusters=k, init='k-means++') 35 | kmeans.fit(X) 36 | centroids = kmeans.cluster_centers_ 37 | closest, _ = pairwise_distances_argmin_min(X, centroids) 38 | selected_samples = [] 39 | for cluster_idx in range(k): 40 | cluster_samples = [combined_data[idx] for idx, cl_idx in enumerate(closest) if cl_idx == cluster_idx] 41 | closest_sample_idx = np.argmin([np.linalg.norm(sample[2] - centroids[cluster_idx]) for sample in cluster_samples]) 42 | selected_samples.append(cluster_samples[closest_sample_idx]) 43 | return selected_samples 44 | 45 | def predict_with_tta_for_uncertainty_selection(image, net, output_path, parser, ratio=0.05): 46 | uncertainty = [] 47 | feature_list = [] 48 | img_name_list = [] 49 | image_list = [] 50 | pseudo_label = [] 51 | label_list = [] 52 | real_size = [] 53 | size_esti = [] 54 | for case in tqdm(image): 55 | h5f = h5py.File(parser.root_path + "/{}.h5".format(case), 'r') 56 | image = h5f['image'][:] 57 | label = h5f['label'][:] 58 | m = parser.tta_num 59 | image_copy = image 60 | image = np.expand_dims(image, axis=0) 61 | image = np.repeat(image, m, axis=0) 62 | for i in range(image.shape[0]): 63 | if parser.intensity_aug: 64 | image[i, :, :, :] = intensity_augmentor(image[i, :, :, :]).cpu().numpy() 65 | for ind in range(image.shape[1]): 66 | img_name = f'{case}_slice{ind}.h5' 67 | img_name_list.append(img_name) 68 | slice = image[:, ind, :, :] 69 | params = [] 70 | input = torch.from_numpy(slice).unsqueeze( 71 | 1).float().cpu() 72 | original_image = image_copy[ind, :, :] 73 | real_label = label[ind, :, :] 74 | pixel_count_real = np.count_nonzero(real_label) 75 | real_size.append(pixel_count_real) 76 | image_list.append(original_image) 77 | label_list.append(real_label) 78 | pixel_count_real = np.count_nonzero(real_label) 79 | real_size.append(pixel_count_real) 80 | volume_batch = input 81 | original_view = torch.from_numpy(original_image).unsqueeze( 82 | 0).unsqueeze(0).float().cpu() 83 | spatial_augmentor = SpatialAugmentation(rotation_range=15) 84 | if parser.spatial_aug: 85 | for i in range(volume_batch.shape[0]): 86 | volume_batch[i, :, :, :], params_i = spatial_augmentor.augment(volume_batch[i, :, :, :]) 87 | params.append(params_i) 88 | volume_batch = volume_batch.cuda() 89 | original_view = original_view.cuda() 90 | net.eval().cuda() 91 | with torch.no_grad(): 92 | out_put, _ = net(volume_batch) 93 | out_main, features = net(original_view) 94 | features = features[-1] 95 | 96 | features_np = features.squeeze().cpu().numpy() 97 | 98 | features = features_np.flatten() 99 | feature_list.append(features) 100 | if parser.spatial_aug: 101 | for i in range(out_put.shape[0]): 102 | out_put[i, :, :, :] = spatial_augmentor.reverse_augment(out_put[i, :, :, :], params[i]) 103 | output_prob = torch.softmax(out_main, dim=1) 104 | out_put = torch.softmax(out_put, dim=1) 105 | out_put = torch.mean(out_put, dim=0, keepdim=True) 106 | pse_label = torch.argmax(out_put, dim=1).squeeze() 107 | pse_label = pse_label.cpu().numpy() 108 | pseudo_label.append(pse_label) 109 | entropy = softmax_entropy(output_prob, softmax=False) 110 | entropy_np = entropy.cpu().detach().numpy() 111 | threshold = compute_entropy_density(entropy_np) 112 | selected_points = np.where(entropy_np > threshold) 113 | GAUA_uncertainty = np.mean(entropy_np[selected_points]) 114 | 
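            # Slice-level uncertainty aggregation: pixels whose entropy exceeds the
            # density-based threshold form the uncertain region, and their mean entropy
            # becomes the slice's score used for ranking and active selection below.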
uncertainty.append(GAUA_uncertainty) 115 | size_esti.append(len(selected_points[0])) 116 | 117 | budget = int(len(image_list) * ratio) 118 | print('The number of labeled slices is:', budget) 119 | combined_data = list(zip(img_name_list, uncertainty, feature_list)) 120 | combined_data.sort(key=lambda x: x[1], reverse=True) 121 | uncertainty_selected_samples = combined_data[:budget * parser.C] 122 | method = 'UGTST' 123 | selected_samples = cluster_and_select_samples(uncertainty_selected_samples, k=budget) 124 | selected_img_names = [sample[0] for sample in selected_samples] 125 | 126 | all_data = list(zip(img_name_list, image_list, pseudo_label, label_list, uncertainty)) 127 | for i, data in enumerate(all_data): 128 | img_name, image, pseudo_label, true_label, uncertainty = data 129 | if img_name in selected_img_names: 130 | all_data[i] = (img_name, image, true_label, -1) 131 | else: 132 | all_data[i] = (img_name, image, pseudo_label, uncertainty) 133 | all_data = [(img_name, image, pseudo_label, uncertainty) for img_name, image, pseudo_label, uncertainty in all_data] 134 | all_data.sort(key=lambda x: x[3]) 135 | 136 | Dts_percent = int(len(all_data) * (1 - ratio * (parser.C - 1))) 137 | Dts_img_names = [data[0] for data in all_data[:Dts_percent]] 138 | with open(f"{output_path}/stage1_slice_{method}.txt", "w") as f: 139 | for name in Dts_img_names: 140 | f.write(name + '\n') 141 | 142 | all_img_names = [data[0] for data in all_data] 143 | with open(f"{output_path}/all_slice_{method}.txt", "w") as f: 144 | for name in all_img_names: 145 | f.write(name + '\n') 146 | os.makedirs(os.path.join(output_path, 'slice_pseudo'), exist_ok=True) 147 | for img_name, image, label, jsd in all_data: 148 | hdf5_file_path = os.path.join(output_path, 149 | f'slices_pseudo/{img_name}') 150 | with h5py.File(hdf5_file_path, "w") as f: 151 | f.create_dataset('image', data=image, compression="gzip") 152 | f.create_dataset('label', data=label, compression="gzip") 153 | with open(os.path.join(output_path, f'selection_{method}.txt'), 'w') as f: 154 | for item in selected_samples: 155 | f.write(f"{item[0]}: {item[1]}\n") 156 | 157 | 158 | 159 | if __name__ == '__main__': 160 | parser = parser.parse_args() 161 | random.seed(parser.seed) 162 | np.random.seed(parser.seed) 163 | torch.manual_seed(parser.seed) 164 | torch.cuda.manual_seed(parser.seed) 165 | with open(parser.root_path + f'/trainlist.txt', 'r') as f: 166 | image_list = f.readlines() 167 | image_list = sorted([item.replace('\n', '').split(".")[0] 168 | for item in image_list]) 169 | snapshot_path = "../model/{}/".format(parser.exp) 170 | net = UNet(in_chns=1, class_num=parser.num_classes) 171 | save_mode_path = os.path.join( 172 | snapshot_path, 'UNet_best_model.pth') 173 | net.load_state_dict(torch.load(save_mode_path)) 174 | print("init weight from {}".format(save_mode_path)) 175 | net.eval() 176 | predict_with_tta_for_uncertainty_selection(image_list, net, output_path=parser.root_path, parser=parser) 177 | -------------------------------------------------------------------------------- /code_oa/test_2D_fully.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | 5 | import h5py 6 | import nibabel as nib 7 | import numpy as np 8 | import SimpleITK as sitk 9 | import torch 10 | from medpy import metric 11 | from scipy.ndimage import zoom 12 | from scipy.ndimage.interpolation import zoom 13 | from tqdm import tqdm 14 | from networks.unet import UNet, UNet_DST, UNet_UPL 15 | 
from scipy import ndimage 16 | # from networks.efficientunet import UNet, 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--root_path', type=str, 19 | default='../data/data_preprocessed/NPC_SMU/SMU', help='Name of Experiment') 20 | parser.add_argument('--exp', type=str, 21 | default='NPC_ablation/SMU_UGTST+_5%_tst_with_cons', help='experiment_name') 22 | parser.add_argument('--model', type=str, 23 | default='UNet', help='model_name') 24 | parser.add_argument('--num_classes', type=int, default=2, 25 | help='output channel of network') 26 | parser.add_argument('--largest_component', type=bool, default=True, 27 | help='get the largest component') 28 | parser.add_argument('--target_set', type=str, default='val', 29 | help='target_set') 30 | 31 | def get_largest_component(image): 32 | dim = len(image.shape) 33 | if(image.sum() == 0 ): 34 | # print('the largest component is null') 35 | return image 36 | if(dim == 2): 37 | s = ndimage.generate_binary_structure(2,1) 38 | elif(dim == 3): 39 | s = ndimage.generate_binary_structure(3,1) 40 | else: 41 | raise ValueError("the dimension number should be 2 or 3") 42 | labeled_array, numpatches = ndimage.label(image, s) 43 | sizes = ndimage.sum(image, labeled_array, range(1, numpatches + 1)) 44 | max_label = np.where(sizes == sizes.max())[0] + 1 45 | output = np.asarray(labeled_array == max_label, np.uint8) 46 | return output 47 | 48 | def calculate_metric_percase(pred, gt, spacing): 49 | pred[pred > 0] = 1 50 | gt[gt > 0] = 1 51 | dice = metric.binary.dc(pred, gt) 52 | asd = metric.binary.assd(pred, gt, spacing) 53 | hd95 = metric.binary.hd95(pred, gt, spacing) 54 | return dice, asd, hd95 55 | 56 | 57 | def test_single_volume(case, net, test_save_path, FLAGS): 58 | h5f = h5py.File(FLAGS.root_path + "/{}.h5".format(case), 'r') 59 | image = h5f['image'][:] 60 | label = h5f['label'][:] 61 | org_spacing = h5f['spacing'][:] 62 | 63 | spacing = [org_spacing[2], org_spacing[0], org_spacing[1]] 64 | prediction = np.zeros_like(label) 65 | for ind in range(image.shape[0]): 66 | slice = image[ind, :, :] 67 | x, y = slice.shape[0], slice.shape[1] 68 | slice = zoom(slice, (320 / x, 320 / y), order=0) 69 | input = torch.from_numpy(slice).unsqueeze( 70 | 0).unsqueeze(0).float().cuda() 71 | net.eval().cuda() 72 | with torch.no_grad(): 73 | if FLAGS.model == "unet_urds": 74 | out_main, _, _, _ = net(input) 75 | elif FLAGS.model == "UNet_UPL": 76 | outputs1, outputs2, outputs3, outputs4 = net(input) 77 | out_main = (outputs1 + outputs2 + outputs3 + outputs4) / 4.0 78 | else: 79 | out_main, _ = net(input) 80 | out = torch.argmax(torch.softmax( 81 | out_main, dim=1), dim=1).squeeze(0) 82 | out = out.cpu().detach().numpy() 83 | pred = zoom(out, (x / 320, y / 320), order=0) 84 | prediction[ind] = pred 85 | if FLAGS.largest_component: 86 | prediction = get_largest_component(prediction) 87 | first_metric = calculate_metric_percase(prediction == 1, label == 1, spacing=spacing) 88 | # second_metric = calculate_metric_percase(prediction == 2, label == 2) 89 | # third_metric = calculate_metric_percase(prediction == 3, label == 3) 90 | 91 | img_itk = sitk.GetImageFromArray(image.astype(np.float32)) 92 | 93 | img_itk.SetSpacing(org_spacing) 94 | prd_itk = sitk.GetImageFromArray(prediction.astype(np.float32)) 95 | prd_itk.SetSpacing(org_spacing) 96 | lab_itk = sitk.GetImageFromArray(label.astype(np.float32)) 97 | lab_itk.SetSpacing(org_spacing) 98 | sitk.WriteImage(prd_itk, test_save_path + case + "_pred.nii.gz") 99 | sitk.WriteImage(img_itk, test_save_path + case + 
"_img.nii.gz") 100 | sitk.WriteImage(lab_itk, test_save_path + case + "_gt.nii.gz") 101 | return first_metric 102 | 103 | 104 | def Inference(FLAGS): 105 | 106 | with open(FLAGS.root_path + f'/{FLAGS.target_set}list.txt', 'r') as f: 107 | image_list = f.readlines() 108 | image_list = sorted([item.replace('\n', '').split(".")[0] 109 | for item in image_list]) 110 | snapshot_path = "../model/{}/".format(FLAGS.exp) 111 | pre = os.path.basename(FLAGS.root_path) 112 | test_save_path = "../model/{}/{}_{}_{}_predictions/".format( 113 | FLAGS.exp, FLAGS.model, pre, FLAGS.target_set) 114 | if os.path.exists(test_save_path): 115 | shutil.rmtree(test_save_path) 116 | os.makedirs(test_save_path) 117 | if FLAGS.model == 'UNet_UPL': 118 | net = UNet_UPL(in_chns=1, class_num=FLAGS.num_classes) 119 | elif FLAGS.model == 'UNet_DST': 120 | net = UNet_DST(in_chns=1, class_num=FLAGS.num_classes) 121 | else: 122 | net = UNet(in_chns=1, class_num=FLAGS.num_classes) 123 | save_mode_path = os.path.join( 124 | snapshot_path, f'UNet_best_model.pth') 125 | net.load_state_dict(torch.load(save_mode_path)) 126 | print("init weight from {}".format(save_mode_path)) 127 | net.eval() 128 | 129 | dice_scores = [] 130 | assd_values = [] 131 | hd95_values = [] 132 | 133 | with open(os.path.join(test_save_path, 'metrics.txt'), 'w') as f: 134 | f.write("Case\tDice\tHD95\tASSD\n") 135 | for case in tqdm(image_list): 136 | metrics = test_single_volume(case, net, test_save_path, FLAGS) 137 | dice_scores.append(metrics[0]) 138 | hd95_values.append(metrics[2]) 139 | assd_values.append(metrics[1]) 140 | f.write(f"{case}\t{metrics[0]}\t{metrics[2]}\t{metrics[1]}\n") 141 | 142 | avg_dice = np.mean(dice_scores) 143 | std_dice = np.std(dice_scores) 144 | 145 | avg_hd95 = np.mean(hd95_values) 146 | std_hd95 = np.std(hd95_values) 147 | 148 | avg_asd = np.mean(assd_values) 149 | std_asd = np.std(assd_values) 150 | 151 | avg_metrics = { 152 | 'avg_dice': avg_dice, 153 | 'std_dice': std_dice, 154 | 'avg_hd95': avg_hd95, 155 | 'std_hd95': std_hd95, 156 | 'avg_assd': avg_asd, 157 | 'std_assd': std_asd, 158 | } 159 | formatted_metrics = { 160 | 'dice': f'{avg_metrics["avg_dice"] * 100:.2f}' + f'±{avg_metrics["std_dice"] * 100:.2f}', 161 | 'hd95': f'{avg_metrics["avg_hd95"]:.2f}±{avg_metrics["std_hd95"]:.2f}', 162 | 'assd': f'{avg_metrics["avg_assd"]:.2f}±{avg_metrics["std_assd"]:.2f}', 163 | } 164 | print("Formatted Metrics:") 165 | for key, value in formatted_metrics.items(): 166 | print(f'{key}: {value}') 167 | 168 | with open(os.path.join(test_save_path, 'overall_metrics.txt'), 'w') as f: 169 | for key, value in avg_metrics.items(): 170 | f.write(f'{key}: {value}\n') 171 | for key, value in formatted_metrics.items(): 172 | f.write(f'{key}: {value}\n') 173 | 174 | return avg_metrics 175 | 176 | if __name__ == '__main__': 177 | FLAGS = parser.parse_args() 178 | metric = Inference(FLAGS) 179 | print(metric) 180 | # print((metric[0]+metric[1]+metric[2])/3) 181 | -------------------------------------------------------------------------------- /code_oa/tools.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torchvision.transforms.functional 3 | import numpy as np 4 | import torch 5 | import random 6 | from scipy.ndimage import gaussian_filter 7 | from sklearn.cluster import KMeans 8 | from sklearn.metrics import pairwise_distances_argmin_min 9 | from scipy.signal import find_peaks 10 | 11 | def embedding(feature, mask): 12 | last_feature = feature[-1] 13 | mask = 
torch.nn.AdaptiveAvgPool2d(mask, output_size=(128, 128)).argmax(1) 14 | embedding = torch.nn.MaxPool2d() 15 | 16 | 17 | def gaussian_noise(image): 18 | mean = 0 19 | std = 0.05 20 | noise = np.random.normal(mean, std, image.shape) 21 | image = image + noise 22 | return image 23 | 24 | 25 | def gaussian_blur(image): 26 | std_range = [0, 1] 27 | std = np.random.uniform(std_range[0], std_range[1]) 28 | image = gaussian_filter(image, std, order=0) 29 | return image 30 | 31 | 32 | def clip_filter(image, lower_percentile=0.5, upper_percentile=99.5): 33 | lower_limit = np.percentile(image, lower_percentile) 34 | upper_limit = np.percentile(image, upper_percentile) 35 | 36 | clipped_image = np.clip(image, lower_limit, upper_limit) 37 | v_min = clipped_image.min() 38 | v_max = clipped_image.max() 39 | clipped_image = (clipped_image - v_min) / (v_max - v_min) 40 | return clipped_image 41 | 42 | 43 | def gammacorrection(image): 44 | gamma_min, gamma_max = 0.7, 1.5 45 | flip_prob = 0.5 46 | gamma_c = random.random() * (gamma_max - gamma_min) + gamma_min 47 | v_min = image.min() 48 | v_max = image.max() 49 | if (v_min < v_max): 50 | image = (image - v_min) / (v_max - v_min) 51 | if (np.random.uniform() < flip_prob): 52 | image = 1.0 - image 53 | image = np.power(image, gamma_c) * (v_max - v_min) + v_min 54 | image = image 55 | return image 56 | 57 | 58 | def contrastaug(image): 59 | contrast_range = [0.8, 1.2] 60 | preserve = True 61 | factor = np.random.uniform(contrast_range[0], contrast_range[1]) 62 | mean = image.mean() 63 | if preserve: 64 | minm = image.min() 65 | maxm = image.max() 66 | image = (image - mean) * factor + mean 67 | image[image < minm] = minm 68 | image[image > maxm] = maxm 69 | 70 | return image 71 | 72 | 73 | import numpy as np 74 | import random 75 | import matplotlib.pyplot as plt 76 | try: 77 | from scipy.special import comb 78 | except: 79 | from scipy.misc import comb 80 | """ 81 | this is for none linear transformation 82 | 83 | 84 | """ 85 | 86 | 87 | # bernstein_poly(i, n, t):计算伯恩斯坦多项式,其中 i 为多项式的次数,n 为多项式的阶数,t 为参数化值。该函数用于计算贝塞尔曲线中的权重系数。 88 | def bernstein_poly(i, n, t): 89 | """ 90 | The Bernstein polynomial of n, i as a function of t 91 | """ 92 | return comb(n, i) * ( t**(n-i) ) * (1 - t)**i 93 | 94 | 95 | def bezier_curve(points, nTimes=1000): 96 | """ 97 | Given a set of control points, return the 98 | bezier curve defined by the control points. 
99 | Control points should be a list of lists, or list of tuples 100 | such as [ [1,1], 101 | [2,3], 102 | [4,5], ..[Xn, Yn] ] 103 | nTimes is the number of time steps, defaults to 1000 104 | See http://processingjs.nihongoresources.com/bezierinfo/ 105 | """ 106 | 107 | nPoints = len(points) 108 | xPoints = np.array([p[0] for p in points]) 109 | yPoints = np.array([p[1] for p in points]) 110 | 111 | t = np.linspace(0.0, 1.0, nTimes) 112 | 113 | polynomial_array = np.array([bernstein_poly(i, nPoints-1, t) for i in range(0, nPoints)]) 114 | 115 | xvals = np.dot(xPoints, polynomial_array) 116 | yvals = np.dot(yPoints, polynomial_array) 117 | 118 | return xvals, yvals 119 | 120 | 121 | def nonlinear_transformation(x): 122 | points = [[0, 0], [random.random(), random.random()], [random.random(), random.random()], [1, 1]] 123 | xvals, yvals = bezier_curve(points, nTimes=1000) 124 | if random.random() < 0.5: 125 | # Half change to get flip 126 | xvals = np.sort(xvals) 127 | else: 128 | xvals, yvals = np.sort(xvals), np.sort(yvals) 129 | nonlinear_x = np.interp(x, xvals, yvals) 130 | return nonlinear_x 131 | 132 | def intensity_augmentor(sample): 133 | # Check if input is a tensor, if so, convert to NumPy (on CPU) for augmentation 134 | if isinstance(sample, torch.Tensor): 135 | sample = sample.cpu().numpy() # Move tensor to CPU and convert to NumPy 136 | 137 | image = sample 138 | if random.random() > 0.5: 139 | image = nonlinear_transformation(image) 140 | if random.random() > 0.5: 141 | image = contrastaug(image) 142 | if random.random() > 0.5: 143 | image = clip_filter(image) 144 | if random.random() > 0.5: 145 | image = gammacorrection(image) 146 | if random.random() > 0.5: 147 | image = gaussian_blur(image) 148 | if random.random() > 0.5: 149 | image = gaussian_noise(image) 150 | 151 | # After augmentation, convert back to Tensor and move to CUDA if necessary 152 | sample = torch.from_numpy(image).float() # Convert back to Tensor 153 | if torch.cuda.is_available(): 154 | sample = sample.cuda() # Move Tensor back to CUDA if needed 155 | 156 | return sample 157 | 158 | 159 | def softmax_confidence(preds, softmax): 160 | if softmax: 161 | preds = torch.nn.functional.softmax(preds, dim=1) 162 | CONF = torch.max(preds, 1)[0] 163 | CONF *= -1 # The small the better --> Reverse it makes it the large the better 164 | return CONF 165 | 166 | def softmax_entropy(preds, softmax=True): 167 | # Softmax Entropy 168 | if softmax: 169 | preds = torch.nn.functional.softmax(preds, dim=1) 170 | ENT = torch.sum(-preds * torch.log2(preds + 1e-12), dim=1) # The large the better 171 | return ENT 172 | 173 | 174 | class SpatialAugmentation: 175 | def __init__(self, rotation_range=0): 176 | self.rotation_range = rotation_range 177 | 178 | def augment(self, img): 179 | rotation_90 = int(np.random.choice([0])) 180 | flip = int(np.random.choice([2, 3])) 181 | rotation_angle = np.random.uniform(-self.rotation_range, self.rotation_range) 182 | # img = torch.flip(img, [flip]) 183 | img = torchvision.transforms.functional.rotate(img, rotation_90) 184 | # img = torchvision.transforms.functional.rotate(img, rotation_angle) 185 | 186 | return img, {'rotation_90': rotation_90, 'rotation_angle': rotation_angle} 187 | 188 | def reverse_augment(self, augmented_img, params): 189 | rotation_90 = -params['rotation_90'] 190 | rotation_angle = -params['rotation_angle'] 191 | # flip = params['flip'] 192 | # reversed_img = torchvision.transforms.functional.rotate(augmented_img, rotation_angle) 193 | reversed_img = 
torchvision.transforms.functional.rotate(augmented_img, rotation_90) 194 | # reversed_img = torch.flip(reversed_img, [flip]) 195 | 196 | return reversed_img 197 | 198 | 199 | def compute_entropy_density(entropy_np): 200 | entropy_flat = entropy_np.flatten() 201 | hist, bins = np.histogram(entropy_flat, bins=100, density=True) 202 | peaks, _ = find_peaks(hist) 203 | threshold = bins[peaks[0]] 204 | return threshold 205 | 206 | 207 | 208 | def cluster_and_select_samples(combined_data, k=3): 209 | 210 | features = [data[2] for data in combined_data] 211 | X = np.array(features) 212 | kmeans = KMeans(n_clusters=k, init='k-means++', n_init=10) 213 | kmeans.fit(X) 214 | centroids = kmeans.cluster_centers_ 215 | closest, _ = pairwise_distances_argmin_min(X, centroids) 216 | 217 | selected_samples = [] 218 | 219 | for cluster_idx in range(k): 220 | cluster_samples = [combined_data[idx] for idx, cl_idx in enumerate(closest) if cl_idx == cluster_idx] 221 | closest_sample_idx = np.argmin([np.linalg.norm(sample[2] - centroids[cluster_idx]) for sample in cluster_samples]) 222 | selected_samples.append(cluster_samples[closest_sample_idx]) 223 | 224 | return selected_samples 225 | 226 | 227 | def js_divergence(p, q): 228 | p = p.cpu().numpy() 229 | q = q.cpu().numpy() 230 | 231 | def kl_divergence(p, q): 232 | return np.sum(np.where(p != 0, p * np.log(p / q), 0)) 233 | 234 | m = 0.5 * (p + q) 235 | return 0.5 * (kl_divergence(p, m) + kl_divergence(q, m)) 236 | 237 | def create_circle_tensor(size): 238 | y, x = torch.meshgrid([torch.arange(size), torch.arange(size)]) 239 | distance = torch.sqrt((x - size // 2) ** 2 + (y - size // 2) ** 2) 240 | circle_tensor = (distance <= size // 4).float().unsqueeze(0).unsqueeze(0) 241 | return circle_tensor 242 | 243 | 244 | def visualize_tensor(tensor): 245 | plt.imshow(tensor[0, 0, :, :], cmap='gray') 246 | plt.show() 247 | 248 | 249 | def one_hot_encoder(n_classes, input_tensor): 250 | tensor_list = [] 251 | 252 | for i in range(n_classes): 253 | temp_prob = (input_tensor == i).unsqueeze(1) 254 | tensor_list.append(temp_prob) 255 | 256 | output_tensor = torch.cat(tensor_list, dim=1) 257 | 258 | return output_tensor.float() 259 | 260 | 261 | def smooth_segmentation_labels(n_classes, target_tensor): 262 | 263 | tensor_list = [] 264 | for i in range(n_classes): 265 | class_mask = (target_tensor == i).unsqueeze(1) 266 | temp_prob = class_mask.float() 267 | smoothing_factor = torch.rand_like(temp_prob) * 0.1 268 | smoothing = (1.0 - smoothing_factor) * temp_prob + smoothing_factor / n_classes 269 | tensor_list.append(smoothing) 270 | 271 | output_tensor = torch.cat(tensor_list, dim=1) 272 | 273 | return output_tensor.float() 274 | 275 | def add_noise_boxes(incoming_mask, n_classes, image_size, mask_type, n_boxes=3, probability=None, real_mask=False): 276 | if probability is None: 277 | probability = {'random': 1.0, 'jigsaw': 1.0, 'zeros': 1.0} 278 | for p in probability.values(): 279 | assert 0.0 <= p <= 1.0 280 | 281 | if type(mask_type) is not list: 282 | assert type(mask_type) == str 283 | mask_type = [mask_type] 284 | 285 | def _py_corrupt(mask): 286 | mask = mask.numpy() 287 | mask = mask.astype(np.float32) # Ensure the array is in float32 288 | jigsaw_op = np.random.choice([True, False], p=[probability['jigsaw'], 1.0 - probability['jigsaw']]) 289 | zeros_op = np.random.choice([True, False], p=[probability['zeros'], 1.0 - probability['zeros']]) 290 | random_op = np.random.choice([True, False], p=[probability['random'], 1.0 - probability['random']]) 291 | if not 
(jigsaw_op or zeros_op): 292 | random_op = True 293 | 294 | for _ in range(n_boxes): 295 | 296 | def get_box_params(low, high): 297 | r = np.random.randint(low=low, high=high) 298 | mcx = np.random.randint(r + 1, image_size[0] - r - 1) 299 | mcy = np.random.randint(r + 1, image_size[1] - r - 1) 300 | return r, mcx, mcy 301 | 302 | if 'random' in mask_type and random_op: 303 | r, mcx, mcy = get_box_params(low=1, high=5) 304 | mask[:, mcx - r:mcx + r, mcy - r:mcy + r] = 0 305 | mask[:, mcx - r:mcx + r, mcy - r:mcy + r] = 1 306 | if 'jigsaw' in mask_type and jigsaw_op: 307 | ll = np.min([image_size[0], image_size[1]]) // 10 308 | hh = np.min([image_size[0], image_size[1]]) // 5 309 | r, mcx, mcy = get_box_params(low=ll, high=hh) 310 | mask[mcx - r:mcx + r, mcy - r:mcy + r] = 0 311 | mcx_src = np.random.randint(r + 1, image_size[0] - r - 1) 312 | mcy_src = np.random.randint(r + 1, image_size[1] - r - 1) 313 | mask_copy = mask.copy() 314 | mask[:, mcx - r:mcx + r, mcy - r:mcy + r] = mask_copy[:, mcx_src - r:mcx_src + r, 315 | mcy_src - r:mcy_src + r] 316 | if 'zeros' in mask_type and zeros_op: 317 | r, mcx, mcy = get_box_params(low=1, high=10) 318 | mask[:, mcx - r:mcx + r, mcy - r:mcy + r] = 0 319 | mask[:, mcx - r:mcx + r, mcy - r:mcy + r] = 1 320 | return mask 321 | 322 | incoming_mask = incoming_mask.cpu() 323 | if real_mask: 324 | incoming_mask = one_hot_encoder(n_classes=n_classes, input_tensor=incoming_mask) 325 | noisy_masks = [_py_corrupt(m) for m in incoming_mask] 326 | 327 | mask = torch.from_numpy(np.array(noisy_masks)) 328 | mask = mask.cuda() 329 | return mask 330 | 331 | 332 | # binary_tensor = torch.randint(2, size=(1, 384, 384), dtype=torch.uint8) 333 | # predict_tensor = torch.randn(1, 2, 384, 384) 334 | # noised_tensor = add_noise_boxes(binary_tensor, 2, image_size=[384, 384], mask_type=['jigsaw', 'random', 'zeros'], 335 | # real_mask=True) 336 | # noised_pre_tensor = add_noise_boxes(predict_tensor, 2, image_size=[384, 384], mask_type=['jigsaw', 'random', 'zeros']) 337 | # circle_tensor = create_circle_tensor(size=384) 338 | # circle_noise_tensor = add_noise_boxes(circle_tensor, 2, image_size=[384, 384], n_boxes=20, 339 | # probability={'random': 0.9, 'jigsaw': 0.5, 'zeros': 0.5}, mask_type=['random']) 340 | # print('noised shape:', noised_tensor.shape) 341 | # visualize_tensor(circle_tensor) 342 | # visualize_tensor(circle_noise_tensor) 343 | # print('noised shape:', noised_tensor.shape) 344 | # print('noised pre shape:', noised_pre_tensor.shape) 345 | -------------------------------------------------------------------------------- /code_oa/train_finetune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import random 5 | import shutil 6 | import sys 7 | import time 8 | 9 | import numpy as np 10 | import torch 11 | import torch.backends.cudnn as cudnn 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import torch.optim as optim 15 | from tensorboardX import SummaryWriter 16 | from torch.nn import BCEWithLogitsLoss 17 | from torch.nn.modules.loss import CrossEntropyLoss 18 | from torch.utils.data import DataLoader 19 | from torchvision import transforms 20 | from torchvision.utils import make_grid 21 | from tqdm import tqdm 22 | 23 | from dataloaders import utils 24 | from dataloaders.dataset import H5DataSet, RandomGenerator, TwoStreamBatchSampler 25 | from networks.unet import UNet 26 | from utils import losses, metrics, ramps 27 | from val_2D import 
test_single_volume 28 | 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--root_path', type=str, 31 | default='../data/data_preprocessed/NPC/WCH', help='Name of Experiment') 32 | parser.add_argument('--exp', type=str, 33 | default='NPC_WCH/WCH_self_training_UGTST+_5%', help='experiment_name') 34 | parser.add_argument('--model', type=str, 35 | default='UNet', help='model_name') 36 | parser.add_argument('--max_iterations', type=int, 37 | default=20000, help='maximum epoch number to train') 38 | parser.add_argument('--batch_size', type=int, default=24, 39 | help='batch_size per gpu') 40 | parser.add_argument('--deterministic', type=int, default=1, 41 | help='whether use deterministic training') 42 | parser.add_argument('--base_lr', type=float, default=0.001, 43 | help='segmentation network learning rate') 44 | parser.add_argument('--patch_size', type=list, default=[320, 320], 45 | help='patch size of network input') 46 | parser.add_argument('--seed', type=int, default=1337, help='random seed') 47 | parser.add_argument('--num_classes', type=int, default=2, 48 | help='output channel of network') 49 | # label and unlabel 50 | parser.add_argument('--labeled_bs', type=int, default=12, 51 | help='labeled_batch_size per gpu') 52 | # costs 53 | parser.add_argument('--ema_decay', type=float, default=0.99, help='ema_decay') 54 | parser.add_argument('--consistency_type', type=str, 55 | default="mse", help='consistency_type') 56 | parser.add_argument('--consistency', type=float, 57 | default=0.1, help='consistency') 58 | parser.add_argument('--consistency_rampup', type=float, 59 | default=200.0, help='consistency_rampup') 60 | parser.add_argument('--early_stop_patient', type=float, default=5000, 61 | help='num for early stop patient') 62 | parser.add_argument('--pretrained_path', type=str, 63 | default='../model/NPC/source_train/UNet_best_model.pth', help='Path to the pretrained model') 64 | parser.add_argument('--labeled_num', type=int, default=56, help='labeled slices') 65 | parser.add_argument('--active_method', type=str, 66 | default='UGTST', help='active learning method') 67 | args = parser.parse_args() 68 | 69 | def get_current_consistency_weight(epoch): 70 | # Consistency ramp-up from https://arxiv.org/abs/1610.02242 71 | return args.consistency * ramps.sigmoid_rampup(epoch, args.consistency_rampup) 72 | 73 | def update_ema_variables(model, ema_model, alpha, global_step): 74 | # Use the true average until the exponential average is more correct 75 | alpha = min(1 - 1 / (global_step + 1), alpha) 76 | for ema_param, param in zip(ema_model.parameters(), model.parameters()): 77 | ema_param.data.mul_(alpha).add_(1 - alpha, param.data) 78 | 79 | def worker_init_fn(worker_id): 80 | random.seed(args.seed + worker_id) 81 | 82 | def train(args, snapshot_path): 83 | base_lr = args.base_lr 84 | num_classes = args.num_classes 85 | batch_size = args.batch_size 86 | max_iterations = args.max_iterations 87 | early_stop_patient = args.early_stop_patient 88 | def create_model(ema=False): 89 | # Network definition 90 | model = UNet(in_chns=1, 91 | class_num=num_classes) 92 | if args.pretrained_path is not None: 93 | model.load_state_dict(torch.load(args.pretrained_path)) 94 | logging.info(f"Loaded pretrained model from {args.pretrained_path}") 95 | model = model.cuda() 96 | if ema: 97 | for param in model.parameters(): 98 | param.detach_() 99 | return model 100 | 101 | model = create_model() 102 | 103 | db_train = H5DataSet(base_dir=args.root_path, split="semi_train", 
active_method=f'{args.active_method}', 104 | num=None, transform=transforms.Compose([ 105 | RandomGenerator(args.patch_size, IntensityAug=True, SpatialAug=True, NonlinearAug=False) 106 | ])) 107 | db_val = H5DataSet(base_dir=args.root_path, split="val") 108 | 109 | total_slices = len(db_train) 110 | labeled_slice = args.labeled_num 111 | print("Total silices is: {}, labeled slices is: {}".format( 112 | total_slices, labeled_slice)) 113 | labeled_idxs = list(range(0, labeled_slice)) 114 | unlabeled_idxs = list(range(labeled_slice, total_slices)) 115 | batch_sampler = TwoStreamBatchSampler( 116 | labeled_idxs, unlabeled_idxs, batch_size, batch_size-args.labeled_bs) 117 | 118 | trainloader = DataLoader(db_train, batch_sampler=batch_sampler, 119 | num_workers=8, pin_memory=True, worker_init_fn=worker_init_fn) 120 | 121 | model.train() 122 | 123 | valloader = DataLoader(db_val, batch_size=1, shuffle=False, 124 | num_workers=0) 125 | 126 | optimizer = optim.SGD(model.parameters(), lr=base_lr, 127 | momentum=0.9, weight_decay=0.0001) 128 | 129 | ce_loss = CrossEntropyLoss() 130 | dice_loss = losses.DiceLoss(num_classes) 131 | 132 | writer = SummaryWriter(snapshot_path + '/log') 133 | logging.info("{} iterations per epoch".format(len(trainloader))) 134 | 135 | iter_num = 0 136 | max_epoch = max_iterations // len(trainloader) + 1 137 | best_performance = 0.0 138 | no_improvement_counter = 0.0 139 | iterator = tqdm(range(max_epoch), ncols=70) 140 | for epoch_num in iterator: 141 | for i_batch, sampled_batch in enumerate(trainloader): 142 | 143 | volume_batch, label_batch = sampled_batch['image'], sampled_batch['label'] 144 | volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda() 145 | 146 | output, _ = model(volume_batch) 147 | outputs_soft = torch.softmax(output, dim=1) 148 | 149 | consistency_weight = get_current_consistency_weight(iter_num // 150) 150 | 151 | loss_lab = 0.5 * (ce_loss(output[:args.labeled_bs], label_batch[:][:args.labeled_bs].long()) + dice_loss( 152 | outputs_soft[:args.labeled_bs], label_batch[:args.labeled_bs].unsqueeze(1))) 153 | pseudo_outputs = torch.argmax(outputs_soft[args.labeled_bs:].detach(), dim=1, keepdim=False) 154 | pseudo_supervision = ce_loss(output[args.labeled_bs:], pseudo_outputs) 155 | 156 | # loss = loss_lab 157 | loss = loss_lab + pseudo_supervision 158 | optimizer.zero_grad() 159 | 160 | loss.backward() 161 | optimizer.step() 162 | 163 | iter_num = iter_num + 1 164 | lr_ = base_lr * (1.0 - iter_num / max_iterations) ** 0.9 165 | for param_group in optimizer.param_groups: 166 | param_group['lr'] = lr_ 167 | 168 | writer.add_scalar('lr', lr_, iter_num) 169 | writer.add_scalar( 170 | 'consistency_weight/consistency_weight', consistency_weight, iter_num) 171 | writer.add_scalar('loss', 172 | loss, iter_num) 173 | 174 | logging.info('iteration %d : lab_loss : %f ulab_loss : %f' % (iter_num, loss_lab.item(), pseudo_supervision.item())) 175 | if iter_num % 200 == 0: 176 | image = volume_batch[1, 0:1, :, :] 177 | writer.add_image('train/Image', image, iter_num) 178 | outputs = torch.argmax(torch.softmax( 179 | output, dim=1), dim=1, keepdim=True) 180 | writer.add_image('train/Prediction', 181 | outputs[1, ...] 
* 50, iter_num) 182 | 183 | if iter_num > 0 and iter_num % 50 == 0: 184 | model.eval() 185 | metric_list = 0.0 186 | for i_batch, sampled_batch in enumerate(valloader): 187 | metric_i = test_single_volume( 188 | sampled_batch["image"], sampled_batch["label"], model, 189 | classes=num_classes, patch_size=args.patch_size) 190 | metric_list += np.array(metric_i) 191 | metric_list = metric_list / len(db_val) 192 | for class_i in range(num_classes-1): 193 | writer.add_scalar('info/model1_val_{}_dice'.format(class_i+1), 194 | metric_list[class_i, 0], iter_num) 195 | writer.add_scalar('info/model1_val_{}_hd95'.format(class_i+1), 196 | metric_list[class_i, 1], iter_num) 197 | 198 | performance = np.mean(metric_list, axis=0)[0] 199 | 200 | mean_hd95 = np.mean(metric_list, axis=0)[1] 201 | writer.add_scalar('info/model1_val_mean_dice', performance, iter_num) 202 | writer.add_scalar('info/model1_val_mean_hd95', mean_hd95, iter_num) 203 | 204 | if performance > best_performance: 205 | best_performance = performance 206 | save_mode_path = os.path.join(snapshot_path, 207 | 'model1_iter_{}_dice_{}.pth'.format( 208 | iter_num, round(best_performance, 4))) 209 | save_best = os.path.join(snapshot_path, 210 | '{}_best_model.pth'.format(args.model)) 211 | torch.save(model.state_dict(), save_mode_path) 212 | torch.save(model.state_dict(), save_best) 213 | no_improvement_counter = 0 214 | else: 215 | no_improvement_counter += 50 216 | logging.info( 217 | 'iteration %d : model1_mean_dice : %f model1_mean_hd95 : %f' % (iter_num, performance, mean_hd95)) 218 | model.train() 219 | 220 | 221 | if iter_num % 3000 == 0: 222 | save_mode_path = os.path.join( 223 | snapshot_path, 'iter_' + str(iter_num) + '.pth') 224 | torch.save(model.state_dict(), save_mode_path) 225 | logging.info("save model to {}".format(save_mode_path)) 226 | 227 | if iter_num >= max_iterations: 228 | time1 = time.time() 229 | break 230 | if no_improvement_counter >= early_stop_patient: 231 | logging.info('No improvement in Validation mean_dice for {} iterations. 
Early stopping...'.format(early_stop_patient)) 232 | iterator.close() 233 | break 234 | if iter_num >= max_iterations: 235 | iterator.close() 236 | break 237 | writer.close() 238 | 239 | if __name__ == "__main__": 240 | if not args.deterministic: 241 | cudnn.benchmark = True 242 | cudnn.deterministic = False 243 | else: 244 | cudnn.benchmark = False 245 | cudnn.deterministic = True 246 | 247 | random.seed(args.seed) 248 | np.random.seed(args.seed) 249 | torch.manual_seed(args.seed) 250 | torch.cuda.manual_seed(args.seed) 251 | 252 | snapshot_path = "../model/{}".format( 253 | args.exp) 254 | if not os.path.exists(snapshot_path): 255 | os.makedirs(snapshot_path) 256 | if os.path.exists(snapshot_path + '/code'): 257 | shutil.rmtree(snapshot_path + '/code') 258 | shutil.copytree('.', snapshot_path + '/code', 259 | shutil.ignore_patterns(['.git', '__pycache__'])) 260 | 261 | logging.basicConfig(filename=snapshot_path+"/log.txt", level=logging.INFO, 262 | format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 263 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 264 | logging.info(str(args)) 265 | train(args, snapshot_path) 266 | -------------------------------------------------------------------------------- /code_oa/train_source.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import random 5 | import shutil 6 | import sys 7 | import time 8 | 9 | import numpy as np 10 | import torch 11 | import torch.backends.cudnn as cudnn 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import torch.optim as optim 15 | from tensorboardX import SummaryWriter 16 | from torch.nn import BCEWithLogitsLoss 17 | from torch.nn.modules.loss import CrossEntropyLoss 18 | from torch.utils.data import DataLoader 19 | from torchvision import transforms 20 | from torchvision.utils import make_grid 21 | from tqdm import tqdm 22 | from networks.unet import UNet 23 | from dataloaders import utils 24 | from utils import losses, metrics, ramps 25 | from dataloaders.dataset import h5DataSet, RandomGenerator, TwoStreamBatchSampler 26 | from val_2D import test_single_volume, test_single_volume_ds, test_single_volume_fast 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--root_path', type=str, 29 | default='../data/data_preprocessed/NPC/source', help='Name of Experiment') 30 | parser.add_argument('--exp', type=str, 31 | default='NPC_new/source', help='experiment_name') 32 | parser.add_argument('--num_classes', type=int, default=2, 33 | help='output channel of network') 34 | parser.add_argument('--max_iterations', type=int, 35 | default=60000, help='maximum epoch number to train') 36 | parser.add_argument('--batch_size', type=int, default=24, 37 | help='batch_size per gpu') 38 | parser.add_argument('--deterministic', type=int, default=1, 39 | help='whether use deterministic training') 40 | parser.add_argument('--base_lr', type=float, default=0.01, 41 | help='segmentation network learning rate') 42 | parser.add_argument('--patch_size', type=list, default=[320, 320], 43 | help='patch size of network input') 44 | parser.add_argument('--pretrained_path', type=str, default=None, help='Path to the pretrained model') 45 | parser.add_argument('--seed', type=int, default=1337, help='random seed') 46 | parser.add_argument('--f', type=str, default='train') 47 | args = parser.parse_args() 48 | 49 | def worker_init_fn(worker_id): 50 | random.seed(args.seed + worker_id) 51 | pass 52 | 53 
| def train(args, snapshot_path): 54 | 55 | base_lr = args.base_lr 56 | num_classes = args.num_classes 57 | batch_size = args.batch_size 58 | max_iterations = args.max_iterations 59 | 60 | model = UNet(in_chns=1, class_num=num_classes) 61 | model = model.cuda() 62 | split = args.f 63 | if args.f == '1': 64 | split = '1' 65 | if args.f == '2': 66 | split = '2' 67 | if args.f == '3': 68 | split = '3' 69 | if args.f == '4': 70 | split = '4' 71 | if args.f == 'train': 72 | split = 'train' 73 | if args.pretrained_path is not None: 74 | model.load_state_dict(torch.load(args.pretrained_path)) 75 | logging.info(f"Loaded pretrained model from {args.pretrained_path}") 76 | db_train = h5DataSet(base_dir=args.root_path, split=split, transform=transforms.Compose([ 77 | RandomGenerator(args.patch_size, IntensityAug=True, SpatialAug=True, NonlinearAug=True) 78 | ])) 79 | db_val = h5DataSet(base_dir=args.root_path, split="val") 80 | 81 | 82 | 83 | trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, 84 | num_workers=8, pin_memory=True, worker_init_fn=worker_init_fn) 85 | valloader = DataLoader(db_val, batch_size=1, shuffle=False, 86 | num_workers=1) 87 | 88 | model.train() 89 | 90 | optimizer = optim.SGD(model.parameters(), lr=base_lr, 91 | momentum=0.9, weight_decay=0.0001) 92 | ce_loss = CrossEntropyLoss() 93 | dice_loss = losses.DiceLoss(num_classes) 94 | 95 | writer = SummaryWriter(snapshot_path + '/log') 96 | logging.info("{} iterations per epoch".format(len(trainloader))) 97 | 98 | iter_num = 0 99 | max_epoch = max_iterations // len(trainloader) + 1 100 | best_performance = 0.0 101 | 102 | iterator = tqdm(range(max_epoch), ncols=70) 103 | for epoch_num in iterator: 104 | for i_batch, sampled_batch in enumerate(trainloader): 105 | 106 | volume_batch, label_batch = sampled_batch['image'], sampled_batch['label'] 107 | volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda() 108 | 109 | outputs, _ = model(volume_batch) 110 | outputs_soft = torch.softmax(outputs, dim=1) 111 | loss_ce = ce_loss(outputs, label_batch[:].long()) 112 | loss_dice = dice_loss(outputs_soft, label_batch.unsqueeze(1)) 113 | loss = 0.5 * (loss_dice + loss_ce) 114 | optimizer.zero_grad() 115 | loss.backward() 116 | optimizer.step() 117 | 118 | lr_ = base_lr * (1.0 - iter_num / max_iterations) ** 0.9 119 | for param_group in optimizer.param_groups: 120 | param_group['lr'] = lr_ 121 | 122 | iter_num = iter_num + 1 123 | writer.add_scalar('info/lr', lr_, iter_num) 124 | writer.add_scalar('info/total_loss', loss, iter_num) 125 | writer.add_scalar('info/loss_ce', loss_ce, iter_num) 126 | writer.add_scalar('info/loss_dice', loss_dice, iter_num) 127 | 128 | logging.info( 129 | 'iteration %d :learning rate %f: loss : %f, loss_ce: %f, loss_dice: %f' % 130 | (iter_num, lr_, loss.item(), loss_ce.item(), loss_dice.item())) 131 | 132 | if iter_num % 20 == 0: 133 | image = volume_batch[0, 0:1, :, :] 134 | writer.add_image('train/Image', image, iter_num) 135 | outputs = torch.argmax(torch.softmax( 136 | outputs, dim=1), dim=1, keepdim=True) 137 | writer.add_image('train/Prediction', 138 | outputs[0, ...] 
* 50, iter_num) 139 | labs = label_batch[0, ...].unsqueeze(0) * 50 140 | writer.add_image('train/GroundTruth', labs, iter_num) 141 | 142 | if iter_num > 0 and iter_num % 200 == 0: 143 | model.eval() 144 | metric_list = 0.0 145 | for i_batch, sampled_batch in enumerate(valloader): 146 | metric_i = test_single_volume( 147 | sampled_batch["image"], sampled_batch["label"], model, 148 | classes=num_classes, patch_size=args.patch_size) 149 | metric_list += np.array(metric_i) 150 | metric_list = metric_list / len(db_val) 151 | for class_i in range(num_classes-1): 152 | writer.add_scalar('info/val_{}_dice'.format(class_i+1), 153 | metric_list[class_i, 0], iter_num) 154 | writer.add_scalar('info/val_{}_hd95'.format(class_i+1), 155 | metric_list[class_i, 1], iter_num) 156 | 157 | performance = np.mean(metric_list, axis=0)[0] 158 | 159 | mean_hd95 = np.mean(metric_list, axis=0)[1] 160 | writer.add_scalar('info/val_mean_dice', performance, iter_num) 161 | writer.add_scalar('info/val_mean_hd95', mean_hd95, iter_num) 162 | 163 | if performance > best_performance: 164 | best_performance = performance 165 | save_mode_path = os.path.join(snapshot_path, 166 | 'iter_{}_dice_{}.pth'.format( 167 | iter_num, round(best_performance, 4))) 168 | save_best = os.path.join(snapshot_path, 169 | '{}_best_model.pth'.format('UNet')) 170 | torch.save(model.state_dict(), save_mode_path) 171 | torch.save(model.state_dict(), save_best) 172 | 173 | logging.info( 174 | 'iteration %d : mean_dice : %f mean_hd95 : %f' % (iter_num, performance, mean_hd95)) 175 | model.train() 176 | 177 | if iter_num % 3000 == 0: 178 | save_mode_path = os.path.join( 179 | snapshot_path, 'iter_' + str(iter_num) + '.pth') 180 | torch.save(model.state_dict(), save_mode_path) 181 | logging.info("save model to {}".format(save_mode_path)) 182 | if iter_num >= max_iterations: 183 | break 184 | if iter_num >= max_iterations: 185 | iterator.close() 186 | break 187 | writer.close() 188 | return "Training Finished!" 
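# A quick illustration of the polynomial ("poly") learning-rate schedule used in train()
# above and in train_finetune.py: lr = base_lr * (1 - iter_num / max_iterations) ** 0.9,
# recomputed after every optimizer step. With the defaults of this script
# (base_lr = 0.01, max_iterations = 60000) the approximate values are:
#     iter_num =     0  ->  lr = 0.01
#     iter_num = 30000  ->  lr ~ 0.0054
#     iter_num = 59999  ->  lr ~ 0.0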
189 | 190 | 191 | if __name__ == "__main__": 192 | if not args.deterministic: 193 | cudnn.benchmark = True 194 | cudnn.deterministic = False 195 | else: 196 | cudnn.benchmark = False 197 | cudnn.deterministic = True 198 | 199 | random.seed(args.seed) 200 | np.random.seed(args.seed) 201 | torch.manual_seed(args.seed) 202 | torch.cuda.manual_seed(args.seed) 203 | 204 | snapshot_path = "../model/{}_{}/".format( 205 | args.exp, args.f) 206 | if not os.path.exists(snapshot_path): 207 | os.makedirs(snapshot_path) 208 | if os.path.exists(snapshot_path + '/code'): 209 | shutil.rmtree(snapshot_path + '/code') 210 | shutil.copytree('.', snapshot_path + '/code', 211 | shutil.ignore_patterns(['.git', '__pycache__'])) 212 | 213 | logging.basicConfig(filename=snapshot_path+"/log.txt", level=logging.INFO, 214 | format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 215 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 216 | logging.info(str(args)) 217 | train(args, snapshot_path) -------------------------------------------------------------------------------- /code_oa/utils/bezier_curve.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import torch 4 | import matplotlib.pyplot as plt 5 | try: 6 | from scipy.special import comb 7 | except: 8 | from scipy.misc import comb 9 | 10 | 11 | def bernstein_poly(i, n, t): 12 | """ 13 | The Bernstein polynomial of n, i as a function of t 14 | """ 15 | return comb(n, i) * ( t**(n-i) ) * (1 - t)**i 16 | 17 | 18 | def bezier_curve(points, nTimes=1000): 19 | """ 20 | Given a set of control points, return the 21 | bezier curve defined by the control points. 22 | Control points should be a list of lists, or list of tuples 23 | such as [ [1,1], 24 | [2,3], 25 | [4,5], ..[Xn, Yn] ] 26 | nTimes is the number of time steps, defaults to 1000 27 | See http://processingjs.nihongoresources.com/bezierinfo/ 28 | """ 29 | 30 | nPoints = len(points) 31 | xPoints = np.array([p[0] for p in points]) 32 | yPoints = np.array([p[1] for p in points]) 33 | 34 | t = np.linspace(0.0, 1.0, nTimes) 35 | 36 | polynomial_array = np.array([bernstein_poly(i, nPoints-1, t) for i in range(0, nPoints)]) 37 | 38 | xvals = np.dot(xPoints, polynomial_array) 39 | yvals = np.dot(yPoints, polynomial_array) 40 | 41 | return xvals, yvals 42 | 43 | 44 | def nonlinear_transformation(x, prob=0.5): 45 | # Note that this function will not help you to do normalization. 46 | # Once it normalizes the image, it would transform back after the nonlinear transformation. 
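# In short (illustrative call only; argument names as defined below): with probability
# (1 - prob) the input is returned unchanged; otherwise the intensities are brought into
# [0, 1] by max-min normalisation, remapped through a random monotonic Bezier curve built
# from two random control points (with a 50% chance of an additional flip of the mapping),
# and rescaled back to the original range, so only the intensity distribution changes
# while the spatial content is untouched, e.g.
#     augmented = nonlinear_transformation(image_tensor, prob=0.5)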
47 | if random.random() >= prob: 48 | return x 49 | 50 | maxvalue = x.max() 51 | minvalue = x.min() 52 | 53 | # Normalize to [0, 1] using max-min normalization 54 | if maxvalue > 1 or minvalue < 0: 55 | # Normalize x to [0, 1] 56 | x_maxmin = (x - minvalue) / torch.clamp((maxvalue - minvalue), min=1e-5) 57 | 58 | # Do nonlinear transformation (Bezier curve) 59 | points = [[0, 0], [random.random(), random.random()], [random.random(), random.random()], [1, 1]] 60 | xvals, yvals = bezier_curve(points, nTimes=100000) 61 | 62 | if random.random() < 0.5: 63 | # Half chance to get flip 64 | xvals = np.sort(xvals) 65 | else: 66 | xvals, yvals = np.sort(xvals), np.sort(yvals) 67 | 68 | # Convert x_maxmin to numpy (it should be on the CPU at this point) 69 | x_maxmin_cpu = x_maxmin.cpu().numpy() # Move the tensor to CPU and convert to numpy 70 | 71 | # Perform interpolation using numpy 72 | nonlinear_x = np.interp(x_maxmin_cpu, xvals, yvals) 73 | 74 | # Convert nonlinear_x back to a tensor, and ensure it's on the same device as x 75 | nonlinear_x = torch.tensor(nonlinear_x, dtype=torch.float32, 76 | device=x.device) # Ensure it is on the same device as x 77 | 78 | # Restore the original intensity from the MAXMIN normalization 79 | nonlinear_x = nonlinear_x * torch.clamp((maxvalue - minvalue), min=1e-5) + minvalue 80 | 81 | return nonlinear_x -------------------------------------------------------------------------------- /code_oa/utils/loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/UGTST/5b01660e043a653c5f8703a653f38d1bdc898ab1/code_oa/utils/loss/__init__.py -------------------------------------------------------------------------------- /code_oa/utils/loss/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/UGTST/5b01660e043a653c5f8703a653f38d1bdc898ab1/code_oa/utils/loss/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /code_oa/utils/loss/__pycache__/compound_losses.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/UGTST/5b01660e043a653c5f8703a653f38d1bdc898ab1/code_oa/utils/loss/__pycache__/compound_losses.cpython-310.pyc -------------------------------------------------------------------------------- /code_oa/utils/loss/__pycache__/deep_supervision.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/UGTST/5b01660e043a653c5f8703a653f38d1bdc898ab1/code_oa/utils/loss/__pycache__/deep_supervision.cpython-310.pyc -------------------------------------------------------------------------------- /code_oa/utils/loss/__pycache__/dice.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/UGTST/5b01660e043a653c5f8703a653f38d1bdc898ab1/code_oa/utils/loss/__pycache__/dice.cpython-310.pyc -------------------------------------------------------------------------------- /code_oa/utils/loss/__pycache__/robust_ce_loss.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/UGTST/5b01660e043a653c5f8703a653f38d1bdc898ab1/code_oa/utils/loss/__pycache__/robust_ce_loss.cpython-310.pyc 
-------------------------------------------------------------------------------- /code_oa/utils/loss/compound_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from nnunetv2.training.loss.dice import SoftDiceLoss, MemoryEfficientSoftDiceLoss 3 | from nnunetv2.training.loss.robust_ce_loss import RobustCrossEntropyLoss, TopKLoss 4 | from nnunetv2.utilities.helpers import softmax_helper_dim1 5 | from torch import nn 6 | 7 | 8 | class DC_and_CE_loss(nn.Module): 9 | def __init__(self, soft_dice_kwargs, ce_kwargs, weight_ce=1, weight_dice=1, ignore_label=None, 10 | dice_class=SoftDiceLoss): 11 | """ 12 | Weights for CE and Dice do not need to sum to one. You can set whatever you want. 13 | :param soft_dice_kwargs: 14 | :param ce_kwargs: 15 | :param aggregate: 16 | :param square_dice: 17 | :param weight_ce: 18 | :param weight_dice: 19 | """ 20 | super(DC_and_CE_loss, self).__init__() 21 | if ignore_label is not None: 22 | ce_kwargs['ignore_index'] = ignore_label 23 | 24 | self.weight_dice = weight_dice 25 | self.weight_ce = weight_ce 26 | self.ignore_label = ignore_label 27 | 28 | self.ce = RobustCrossEntropyLoss(**ce_kwargs) 29 | self.dc = dice_class(apply_nonlin=softmax_helper_dim1, **soft_dice_kwargs) 30 | 31 | def forward(self, net_output: torch.Tensor, target: torch.Tensor): 32 | """ 33 | target must be b, c, x, y(, z) with c=1 34 | :param net_output: 35 | :param target: 36 | :return: 37 | """ 38 | if self.ignore_label is not None: 39 | assert target.shape[1] == 1, 'ignore label is not implemented for one hot encoded target variables ' \ 40 | '(DC_and_CE_loss)' 41 | mask = target != self.ignore_label 42 | # remove ignore label from target, replace with one of the known labels. It doesn't matter because we 43 | # ignore gradients in those areas anyway 44 | target_dice = torch.where(mask, target, 0) 45 | num_fg = mask.sum() 46 | else: 47 | target_dice = target 48 | mask = None 49 | 50 | dc_loss = self.dc(net_output, target_dice, loss_mask=mask) \ 51 | if self.weight_dice != 0 else 0 52 | ce_loss = self.ce(net_output, target[:, 0]) \ 53 | if self.weight_ce != 0 and (self.ignore_label is None or num_fg > 0) else 0 54 | 55 | result = self.weight_ce * ce_loss + self.weight_dice * dc_loss 56 | return result 57 | 58 | 59 | class DC_and_BCE_loss(nn.Module): 60 | def __init__(self, bce_kwargs, soft_dice_kwargs, weight_ce=1, weight_dice=1, use_ignore_label: bool = False, 61 | dice_class=MemoryEfficientSoftDiceLoss): 62 | """ 63 | DO NOT APPLY NONLINEARITY IN YOUR NETWORK! 64 | 65 | target mut be one hot encoded 66 | IMPORTANT: We assume use_ignore_label is located in target[:, -1]!!! 67 | 68 | :param soft_dice_kwargs: 69 | :param bce_kwargs: 70 | :param aggregate: 71 | """ 72 | super(DC_and_BCE_loss, self).__init__() 73 | if use_ignore_label: 74 | bce_kwargs['reduction'] = 'none' 75 | 76 | self.weight_dice = weight_dice 77 | self.weight_ce = weight_ce 78 | self.use_ignore_label = use_ignore_label 79 | 80 | self.ce = nn.BCEWithLogitsLoss(**bce_kwargs) 81 | self.dc = dice_class(apply_nonlin=torch.sigmoid, **soft_dice_kwargs) 82 | 83 | def forward(self, net_output: torch.Tensor, target: torch.Tensor): 84 | if self.use_ignore_label: 85 | # target is one hot encoded here. invert it so that it is True wherever we can compute the loss 86 | if target.dtype == torch.bool: 87 | mask = ~target[:, -1:] 88 | else: 89 | mask = (1 - target[:, -1:]).bool() 90 | # remove ignore channel now that we have the mask 91 | # why did we use clone in the past? 
Should have documented that... 92 | # target_regions = torch.clone(target[:, :-1]) 93 | target_regions = target[:, :-1] 94 | else: 95 | target_regions = target 96 | mask = None 97 | 98 | dc_loss = self.dc(net_output, target_regions, loss_mask=mask) 99 | target_regions = target_regions.float() 100 | if mask is not None: 101 | ce_loss = (self.ce(net_output, target_regions) * mask).sum() / torch.clip(mask.sum(), min=1e-8) 102 | else: 103 | ce_loss = self.ce(net_output, target_regions) 104 | result = self.weight_ce * ce_loss + self.weight_dice * dc_loss 105 | return result 106 | 107 | 108 | class DC_and_topk_loss(nn.Module): 109 | def __init__(self, soft_dice_kwargs, ce_kwargs, weight_ce=1, weight_dice=1, ignore_label=None): 110 | """ 111 | Weights for CE and Dice do not need to sum to one. You can set whatever you want. 112 | :param soft_dice_kwargs: 113 | :param ce_kwargs: 114 | :param aggregate: 115 | :param square_dice: 116 | :param weight_ce: 117 | :param weight_dice: 118 | """ 119 | super().__init__() 120 | if ignore_label is not None: 121 | ce_kwargs['ignore_index'] = ignore_label 122 | 123 | self.weight_dice = weight_dice 124 | self.weight_ce = weight_ce 125 | self.ignore_label = ignore_label 126 | 127 | self.ce = TopKLoss(**ce_kwargs) 128 | self.dc = SoftDiceLoss(apply_nonlin=softmax_helper_dim1, **soft_dice_kwargs) 129 | 130 | def forward(self, net_output: torch.Tensor, target: torch.Tensor): 131 | """ 132 | target must be b, c, x, y(, z) with c=1 133 | :param net_output: 134 | :param target: 135 | :return: 136 | """ 137 | if self.ignore_label is not None: 138 | assert target.shape[1] == 1, 'ignore label is not implemented for one hot encoded target variables ' \ 139 | '(DC_and_CE_loss)' 140 | mask = (target != self.ignore_label).bool() 141 | # remove ignore label from target, replace with one of the known labels. It doesn't matter because we 142 | # ignore gradients in those areas anyway 143 | target_dice = torch.clone(target) 144 | target_dice[target == self.ignore_label] = 0 145 | num_fg = mask.sum() 146 | else: 147 | target_dice = target 148 | mask = None 149 | 150 | dc_loss = self.dc(net_output, target_dice, loss_mask=mask) \ 151 | if self.weight_dice != 0 else 0 152 | ce_loss = self.ce(net_output, target) \ 153 | if self.weight_ce != 0 and (self.ignore_label is None or num_fg > 0) else 0 154 | 155 | result = self.weight_ce * ce_loss + self.weight_dice * dc_loss 156 | return result 157 | -------------------------------------------------------------------------------- /code_oa/utils/loss/deep_supervision.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class DeepSupervisionWrapper(nn.Module): 6 | def __init__(self, loss, weight_factors=None): 7 | """ 8 | Wraps a loss function so that it can be applied to multiple outputs. Forward accepts an arbitrary number of 9 | inputs. Each input is expected to be a tuple/list. Each tuple/list must have the same length. The loss is then 10 | applied to each entry like this: 11 | l = w0 * loss(input0[0], input1[0], ...) + w1 * loss(input0[1], input1[1], ...) + ... 12 | If weights are None, all w will be 1. 
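        A minimal usage sketch (loss and weights here are illustrative, not taken from this
        repository), wrapping the SoftDiceLoss defined in dice.py over three deep-supervision
        outputs:
            ds_loss = DeepSupervisionWrapper(SoftDiceLoss(), weight_factors=(1.0, 0.5, 0.25))
            l = ds_loss((out_full, out_half, out_quarter), (gt_full, gt_half, gt_quarter))
        Note that, as written, weight_factors must be provided: the assert below iterates over
        it, so the None default cannot actually be used.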
13 | """ 14 | super(DeepSupervisionWrapper, self).__init__() 15 | assert any([x != 0 for x in weight_factors]), "At least one weight factor should be != 0.0" 16 | self.weight_factors = tuple(weight_factors) 17 | self.loss = loss 18 | 19 | def forward(self, *args): 20 | assert all([isinstance(i, (tuple, list)) for i in args]), \ 21 | f"all args must be either tuple or list, got {[type(i) for i in args]}" 22 | # we could check for equal lengths here as well, but we really shouldn't overdo it with checks because 23 | # this code is executed a lot of times! 24 | 25 | if self.weight_factors is None: 26 | weights = (1, ) * len(args[0]) 27 | else: 28 | weights = self.weight_factors 29 | 30 | return sum([weights[i] * self.loss(*inputs) for i, inputs in enumerate(zip(*args)) if weights[i] != 0.0]) 31 | -------------------------------------------------------------------------------- /code_oa/utils/loss/dice.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import torch 4 | from nnunetv2.utilities.ddp_allgather import AllGatherGrad 5 | from torch import nn 6 | 7 | 8 | class SoftDiceLoss(nn.Module): 9 | def __init__(self, apply_nonlin: Callable = None, batch_dice: bool = False, do_bg: bool = True, smooth: float = 1., 10 | ddp: bool = True, clip_tp: float = None): 11 | """ 12 | """ 13 | super(SoftDiceLoss, self).__init__() 14 | 15 | self.do_bg = do_bg 16 | self.batch_dice = batch_dice 17 | self.apply_nonlin = apply_nonlin 18 | self.smooth = smooth 19 | self.clip_tp = clip_tp 20 | self.ddp = ddp 21 | 22 | def forward(self, x, y, loss_mask=None): 23 | shp_x = x.shape 24 | 25 | if self.batch_dice: 26 | axes = [0] + list(range(2, len(shp_x))) 27 | else: 28 | axes = list(range(2, len(shp_x))) 29 | 30 | if self.apply_nonlin is not None: 31 | x = self.apply_nonlin(x) 32 | 33 | tp, fp, fn, _ = get_tp_fp_fn_tn(x, y, axes, loss_mask, False) 34 | 35 | if self.ddp and self.batch_dice: 36 | tp = AllGatherGrad.apply(tp).sum(0) 37 | fp = AllGatherGrad.apply(fp).sum(0) 38 | fn = AllGatherGrad.apply(fn).sum(0) 39 | 40 | if self.clip_tp is not None: 41 | tp = torch.clip(tp, min=self.clip_tp , max=None) 42 | 43 | nominator = 2 * tp 44 | denominator = 2 * tp + fp + fn 45 | 46 | dc = (nominator + self.smooth) / (torch.clip(denominator + self.smooth, 1e-8)) 47 | 48 | if not self.do_bg: 49 | if self.batch_dice: 50 | dc = dc[1:] 51 | else: 52 | dc = dc[:, 1:] 53 | dc = dc.mean() 54 | 55 | return -dc 56 | 57 | 58 | class MemoryEfficientSoftDiceLoss(nn.Module): 59 | def __init__(self, apply_nonlin: Callable = None, batch_dice: bool = False, do_bg: bool = True, smooth: float = 1., 60 | ddp: bool = True): 61 | """ 62 | saves 1.6 GB on Dataset017 3d_lowres 63 | """ 64 | super(MemoryEfficientSoftDiceLoss, self).__init__() 65 | 66 | self.do_bg = do_bg 67 | self.batch_dice = batch_dice 68 | self.apply_nonlin = apply_nonlin 69 | self.smooth = smooth 70 | self.ddp = ddp 71 | 72 | def forward(self, x, y, loss_mask=None): 73 | if self.apply_nonlin is not None: 74 | x = self.apply_nonlin(x) 75 | 76 | # make everything shape (b, c) 77 | axes = tuple(range(2, x.ndim)) 78 | 79 | with torch.no_grad(): 80 | if x.ndim != y.ndim: 81 | y = y.view((y.shape[0], 1, *y.shape[1:])) 82 | 83 | if x.shape == y.shape: 84 | # if this is the case then gt is probably already a one hot encoding 85 | y_onehot = y 86 | else: 87 | y_onehot = torch.zeros(x.shape, device=x.device, dtype=torch.bool) 88 | y_onehot.scatter_(1, y.long(), 1) 89 | 90 | if not self.do_bg: 91 | y_onehot = 
y_onehot[:, 1:] 92 | 93 | sum_gt = y_onehot.sum(axes) if loss_mask is None else (y_onehot * loss_mask).sum(axes) 94 | 95 | # this one MUST be outside the with torch.no_grad(): context. Otherwise no gradients for you 96 | if not self.do_bg: 97 | x = x[:, 1:] 98 | 99 | if loss_mask is None: 100 | intersect = (x * y_onehot).sum(axes) 101 | sum_pred = x.sum(axes) 102 | else: 103 | intersect = (x * y_onehot * loss_mask).sum(axes) 104 | sum_pred = (x * loss_mask).sum(axes) 105 | 106 | if self.batch_dice: 107 | if self.ddp: 108 | intersect = AllGatherGrad.apply(intersect).sum(0) 109 | sum_pred = AllGatherGrad.apply(sum_pred).sum(0) 110 | sum_gt = AllGatherGrad.apply(sum_gt).sum(0) 111 | 112 | intersect = intersect.sum(0) 113 | sum_pred = sum_pred.sum(0) 114 | sum_gt = sum_gt.sum(0) 115 | 116 | dc = (2 * intersect + self.smooth) / (torch.clip(sum_gt + sum_pred + self.smooth, 1e-8)) 117 | 118 | dc = dc.mean() 119 | return -dc 120 | 121 | 122 | def get_tp_fp_fn_tn(net_output, gt, axes=None, mask=None, square=False): 123 | """ 124 | net_output must be (b, c, x, y(, z))) 125 | gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z)) 126 | if mask is provided it must have shape (b, 1, x, y(, z))) 127 | :param net_output: 128 | :param gt: 129 | :param axes: can be (, ) = no summation 130 | :param mask: mask must be 1 for valid pixels and 0 for invalid pixels 131 | :param square: if True then fp, tp and fn will be squared before summation 132 | :return: 133 | """ 134 | if axes is None: 135 | axes = tuple(range(2, net_output.ndim)) 136 | 137 | with torch.no_grad(): 138 | if net_output.ndim != gt.ndim: 139 | gt = gt.view((gt.shape[0], 1, *gt.shape[1:])) 140 | 141 | if net_output.shape == gt.shape: 142 | # if this is the case then gt is probably already a one hot encoding 143 | y_onehot = gt 144 | else: 145 | y_onehot = torch.zeros(net_output.shape, device=net_output.device, dtype=torch.bool) 146 | y_onehot.scatter_(1, gt.long(), 1) 147 | 148 | tp = net_output * y_onehot 149 | fp = net_output * (~y_onehot) 150 | fn = (1 - net_output) * y_onehot 151 | tn = (1 - net_output) * (~y_onehot) 152 | 153 | if mask is not None: 154 | with torch.no_grad(): 155 | mask_here = torch.tile(mask, (1, tp.shape[1], *[1 for _ in range(2, tp.ndim)])) 156 | tp *= mask_here 157 | fp *= mask_here 158 | fn *= mask_here 159 | tn *= mask_here 160 | # benchmark whether tiling the mask would be faster (torch.tile). 
It probably is for large batch sizes 161 | # OK it barely makes a difference but the implementation above is a tiny bit faster + uses less vram 162 | # (using nnUNetv2_train 998 3d_fullres 0) 163 | # tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1) 164 | # fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1) 165 | # fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1) 166 | # tn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tn, dim=1)), dim=1) 167 | 168 | if square: 169 | tp = tp ** 2 170 | fp = fp ** 2 171 | fn = fn ** 2 172 | tn = tn ** 2 173 | 174 | if len(axes) > 0: 175 | tp = tp.sum(dim=axes, keepdim=False) 176 | fp = fp.sum(dim=axes, keepdim=False) 177 | fn = fn.sum(dim=axes, keepdim=False) 178 | tn = tn.sum(dim=axes, keepdim=False) 179 | 180 | return tp, fp, fn, tn 181 | 182 | 183 | if __name__ == '__main__': 184 | from nnunetv2.utilities.helpers import softmax_helper_dim1 185 | pred = torch.rand((2, 3, 32, 32, 32)) 186 | ref = torch.randint(0, 3, (2, 32, 32, 32)) 187 | 188 | dl_old = SoftDiceLoss(apply_nonlin=softmax_helper_dim1, batch_dice=True, do_bg=False, smooth=0, ddp=False) 189 | dl_new = MemoryEfficientSoftDiceLoss(apply_nonlin=softmax_helper_dim1, batch_dice=True, do_bg=False, smooth=0, ddp=False) 190 | res_old = dl_old(pred, ref) 191 | res_new = dl_new(pred, ref) 192 | print(res_old, res_new) 193 | -------------------------------------------------------------------------------- /code_oa/utils/loss/robust_ce_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | import numpy as np 4 | 5 | 6 | class RobustCrossEntropyLoss(nn.CrossEntropyLoss): 7 | """ 8 | this is just a compatibility layer because my target tensor is float and has an extra dimension 9 | 10 | input must be logits, not probabilities! 11 | """ 12 | def forward(self, input: Tensor, target: Tensor) -> Tensor: 13 | if target.ndim == input.ndim: 14 | assert target.shape[1] == 1 15 | target = target[:, 0] 16 | return super().forward(input, target.long()) 17 | 18 | 19 | class TopKLoss(RobustCrossEntropyLoss): 20 | """ 21 | input must be logits, not probabilities! 
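    k is a percentage: only the k% largest per-voxel cross-entropy values are kept and
    averaged, which focuses training on the hardest voxels. An illustrative call (shapes
    assumed, not from this repository):
        loss = TopKLoss(k=10)                   # keep the hardest 10% of voxels
        value = loss(logits, target[:, None])   # logits: (B, C, H, W), target: (B, H, W)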
22 | """ 23 | def __init__(self, weight=None, ignore_index: int = -100, k: float = 10, label_smoothing: float = 0): 24 | self.k = k 25 | super(TopKLoss, self).__init__(weight, False, ignore_index, reduce=False, label_smoothing=label_smoothing) 26 | 27 | def forward(self, inp, target): 28 | target = target[:, 0].long() 29 | res = super(TopKLoss, self).forward(inp, target) 30 | num_voxels = np.prod(res.shape, dtype=np.int64) 31 | res, _ = torch.topk(res.view((-1, )), int(num_voxels * self.k / 100), sorted=False) 32 | return res.mean() 33 | -------------------------------------------------------------------------------- /code_oa/utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | import numpy as np 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | 7 | 8 | def dice_loss(score, target): 9 | target = target.float() 10 | smooth = 1e-5 11 | intersect = torch.sum(score * target) 12 | y_sum = torch.sum(target * target) 13 | z_sum = torch.sum(score * score) 14 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 15 | loss = 1 - loss 16 | return loss 17 | 18 | 19 | def dice_loss1(score, target): 20 | target = target.float() 21 | smooth = 1e-5 22 | intersect = torch.sum(score * target) 23 | y_sum = torch.sum(target) 24 | z_sum = torch.sum(score) 25 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 26 | loss = 1 - loss 27 | return loss 28 | 29 | 30 | def entropy_loss(p, C=2): 31 | # p N*C*W*H*D 32 | y1 = -1 * torch.sum(p * torch.log(p + 1e-6), dim=1) / \ 33 | torch.tensor(np.log(C)).cuda() 34 | ent = torch.mean(y1) 35 | 36 | return ent 37 | 38 | 39 | def softmax_dice_loss(input_logits, target_logits): 40 | """Takes softmax on both sides and returns MSE loss 41 | 42 | Note: 43 | - Returns the sum over all examples. Divide by the batch size afterwards 44 | if you want the mean. 45 | - Sends gradients to inputs but not the targets. 46 | """ 47 | assert input_logits.size() == target_logits.size() 48 | input_softmax = F.softmax(input_logits, dim=1) 49 | target_softmax = F.softmax(target_logits, dim=1) 50 | n = input_logits.shape[1] 51 | dice = 0 52 | for i in range(0, n): 53 | dice += dice_loss1(input_softmax[:, i], target_softmax[:, i]) 54 | mean_dice = dice / n 55 | 56 | return mean_dice 57 | 58 | 59 | def entropy_loss_map(p, C=2): 60 | ent = -1 * torch.sum(p * torch.log(p + 1e-6), dim=1, 61 | keepdim=True) / torch.tensor(np.log(C)).cuda() 62 | return ent 63 | 64 | 65 | def softmax_mse_loss(input_logits, target_logits, sigmoid=False): 66 | """Takes softmax on both sides and returns MSE loss 67 | 68 | Note: 69 | - Returns the sum over all examples. Divide by the batch size afterwards 70 | if you want the mean. 71 | - Sends gradients to inputs but not the targets. 72 | """ 73 | assert input_logits.size() == target_logits.size() 74 | if sigmoid: 75 | input_softmax = torch.sigmoid(input_logits) 76 | target_softmax = torch.sigmoid(target_logits) 77 | else: 78 | input_softmax = F.softmax(input_logits, dim=1) 79 | target_softmax = F.softmax(target_logits, dim=1) 80 | 81 | mse_loss = (input_softmax - target_softmax) ** 2 82 | return mse_loss 83 | 84 | 85 | def softmax_kl_loss(input_logits, target_logits, sigmoid=False): 86 | """Takes softmax on both sides and returns KL divergence 87 | 88 | Note: 89 | - Returns the sum over all examples. Divide by the batch size afterwards 90 | if you want the mean. 91 | - Sends gradients to inputs but not the targets. 
92 | """ 93 | assert input_logits.size() == target_logits.size() 94 | if sigmoid: 95 | input_log_softmax = torch.log(torch.sigmoid(input_logits)) 96 | target_softmax = torch.sigmoid(target_logits) 97 | else: 98 | input_log_softmax = F.log_softmax(input_logits, dim=1) 99 | target_softmax = F.softmax(target_logits, dim=1) 100 | 101 | # return F.kl_div(input_log_softmax, target_softmax) 102 | kl_div = F.kl_div(input_log_softmax, target_softmax, reduction='mean') 103 | # mean_kl_div = torch.mean(0.2*kl_div[:,0,...]+0.8*kl_div[:,1,...]) 104 | return kl_div 105 | 106 | 107 | def symmetric_mse_loss(input1, input2): 108 | """Like F.mse_loss but sends gradients to both directions 109 | 110 | Note: 111 | - Returns the sum over all examples. Divide by the batch size afterwards 112 | if you want the mean. 113 | - Sends gradients to both input1 and input2. 114 | """ 115 | assert input1.size() == input2.size() 116 | return torch.mean((input1 - input2) ** 2) 117 | 118 | 119 | class FocalLoss(nn.Module): 120 | def __init__(self, gamma=2, alpha=None, size_average=True): 121 | super(FocalLoss, self).__init__() 122 | self.gamma = gamma 123 | self.alpha = alpha 124 | if isinstance(alpha, (float, int)): 125 | self.alpha = torch.Tensor([alpha, 1 - alpha]) 126 | if isinstance(alpha, list): 127 | self.alpha = torch.Tensor(alpha) 128 | self.size_average = size_average 129 | 130 | def forward(self, input, target): 131 | if input.dim() > 2: 132 | # N,C,H,W => N,C,H*W 133 | input = input.view(input.size(0), input.size(1), -1) 134 | input = input.transpose(1, 2) # N,C,H*W => N,H*W,C 135 | input = input.contiguous().view(-1, input.size(2)) # N,H*W,C => N*H*W,C 136 | target = target.view(-1, 1) 137 | 138 | logpt = F.log_softmax(input, dim=1) 139 | logpt = logpt.gather(1, target) 140 | logpt = logpt.view(-1) 141 | pt = Variable(logpt.data.exp()) 142 | 143 | if self.alpha is not None: 144 | if self.alpha.type() != input.data.type(): 145 | self.alpha = self.alpha.type_as(input.data) 146 | at = self.alpha.gather(0, target.data.view(-1)) 147 | logpt = logpt * Variable(at) 148 | 149 | loss = -1 * (1 - pt) ** self.gamma * logpt 150 | if self.size_average: 151 | return loss.mean() 152 | else: 153 | return loss.sum() 154 | 155 | def mse_loss(input1, input2): 156 | return torch.mean((input1 - input2)**2) 157 | 158 | 159 | class DiceLoss(nn.Module): 160 | def __init__(self, n_classes, onehot=True): 161 | super(DiceLoss, self).__init__() 162 | self.n_classes = n_classes 163 | self.onehot = onehot 164 | 165 | def _one_hot_encoder(self, input_tensor): 166 | tensor_list = [] 167 | for i in range(self.n_classes): 168 | temp_prob = input_tensor == i * torch.ones_like(input_tensor) 169 | tensor_list.append(temp_prob) 170 | output_tensor = torch.cat(tensor_list, dim=1) 171 | return output_tensor.float() 172 | 173 | def _dice_loss(self, score, target): 174 | target = target.float() 175 | smooth = 1e-5 176 | intersect = torch.sum(score * target) 177 | y_sum = torch.sum(target * target) 178 | z_sum = torch.sum(score * score) 179 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 180 | loss = 1 - loss 181 | return loss 182 | 183 | def forward(self, inputs, target, weight=None, softmax=False): 184 | if softmax: 185 | inputs = torch.softmax(inputs, dim=1) 186 | if self.onehot: 187 | target = self._one_hot_encoder(target) 188 | if weight is None: 189 | weight = [1] * self.n_classes 190 | assert inputs.size() == target.size(), 'predict & target shape do not match' 191 | class_wise_dice = [] 192 | loss = 0.0 193 | for i in range(0, 
self.n_classes): 194 | dice = self._dice_loss(inputs[:, i], target[:, i]) 195 | class_wise_dice.append(1.0 - dice.item()) 196 | loss += dice * weight[i] 197 | return loss / self.n_classes 198 | 199 | 200 | def entropy_minmization(p): 201 | y1 = -1 * torch.sum(p * torch.log(p + 1e-6), dim=1) 202 | ent = torch.mean(y1) 203 | 204 | return ent 205 | 206 | 207 | def entropy_map(p): 208 | ent_map = -1 * torch.sum(p * torch.log(p + 1e-6), dim=1, 209 | keepdim=True) 210 | return ent_map 211 | 212 | 213 | def compute_kl_loss(p, q): 214 | p_loss = F.kl_div(F.log_softmax(p, dim=-1), 215 | F.softmax(q, dim=-1), reduction='none') 216 | q_loss = F.kl_div(F.log_softmax(q, dim=-1), 217 | F.softmax(p, dim=-1), reduction='none') 218 | 219 | # Using function "sum" and "mean" are depending on your task 220 | p_loss = p_loss.mean() 221 | q_loss = q_loss.mean() 222 | 223 | loss = (p_loss + q_loss) / 2 224 | return loss 225 | 226 | class LeastSquareGAN(object): 227 | def __init__(self): 228 | super(LeastSquareGAN, self).__init__() 229 | self.real_label = 1.0 230 | self.fake_label = -1.0 231 | 232 | @staticmethod 233 | def generator_loss(disc_pred_fake, real_label=1.0): 234 | loss = 0.5 * torch.mean((disc_pred_fake - real_label) ** 2) 235 | return loss 236 | 237 | @staticmethod 238 | def discriminator_loss(disc_pred_real, disc_pred_fake, real_label=1.0, fake_label=-1.0): 239 | loss = (0.5 * torch.mean((disc_pred_real - real_label) ** 2) + 240 | 0.5 * torch.mean((disc_pred_fake - fake_label) ** 2)) 241 | return loss 242 | 243 | @staticmethod 244 | def discriminator_fake_loss(disc_pred_fake, fake_label=-1.0): 245 | loss = 0.5 * torch.mean((disc_pred_fake - fake_label) ** 2) 246 | return loss 247 | 248 | @staticmethod 249 | def discriminator_real_loss(disc_pred_real, real_label=1.0): 250 | loss = 0.5 * torch.mean((disc_pred_real - real_label) ** 2) 251 | return loss 252 | 253 | 254 | class VanillaGAN(object): 255 | def __init__(self): 256 | super(VanillaGAN, self).__init__() 257 | 258 | @staticmethod 259 | def generator_loss(disc_pred_fake): 260 | labels = torch.ones_like(disc_pred_fake) 261 | loss = F.binary_cross_entropy_with_logits(disc_pred_fake, labels) 262 | return loss 263 | 264 | @staticmethod 265 | def discriminator_loss(disc_pred_real, disc_pred_fake): 266 | real_labels = torch.ones_like(disc_pred_real) 267 | loss_real = F.binary_cross_entropy_with_logits(disc_pred_real, real_labels) 268 | fake_labels = torch.zeros_like(disc_pred_fake) 269 | loss_fake = F.binary_cross_entropy_with_logits(disc_pred_fake, fake_labels) 270 | loss = loss_real + loss_fake 271 | return loss 272 | 273 | @staticmethod 274 | def discriminator_fake_loss(disc_pred_fake): 275 | labels = torch.ones_like(disc_pred_fake) 276 | loss = F.binary_cross_entropy_with_logits(disc_pred_fake, labels) 277 | return loss 278 | 279 | @staticmethod 280 | def discriminator_real_loss(disc_pred_real): 281 | labels = torch.ones_like(disc_pred_real) 282 | loss = F.binary_cross_entropy_with_logits(disc_pred_real, labels) 283 | return loss 284 | 285 | def gradient_penalty(discriminator, x_real, x_fake, gp_weight=10.0): 286 | 287 | if x_fake is None: 288 | # For real samples, calculate gradients 289 | x_interpolated = x_real.clone().detach().requires_grad_(True) 290 | else: 291 | # For interpolated samples, calculate gradients 292 | epsilon = torch.rand(x_real.shape[0], 1, 1, 1).to(x_real.device) 293 | x_interpolated = epsilon * x_real + (1 - epsilon) * x_fake 294 | x_interpolated.requires_grad_(True) 295 | # Get discriminator predictions on interpolated 
samples 296 | disc_pred_interpolated = discriminator(x_interpolated) 297 | 298 | # Calculate gradients of disc_pred_interpolated with respect to x_interpolated 299 | gradients = torch.autograd.grad(outputs=disc_pred_interpolated, 300 | inputs=x_interpolated, 301 | grad_outputs=torch.ones_like(disc_pred_interpolated), 302 | create_graph=True, 303 | retain_graph=True)[0] 304 | 305 | # Calculate gradient penalty 306 | slopes = torch.sqrt(torch.sum(gradients.pow(2), dim=[1, 2, 3]) + 1e-8) 307 | penalty = torch.mean((slopes - 1.) ** 2) 308 | 309 | # Apply penalty weight 310 | gp = gp_weight * penalty 311 | 312 | return gp 313 | -------------------------------------------------------------------------------- /code_oa/utils/metrics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2019/12/14 下午4:41 4 | # @Author : chuyu zhang 5 | # @File : metrics.py 6 | # @Software: PyCharm 7 | 8 | 9 | import numpy as np 10 | from medpy import metric 11 | 12 | 13 | def cal_dice(prediction, label, num=2): 14 | total_dice = np.zeros(num-1) 15 | for i in range(1, num): 16 | prediction_tmp = (prediction == i) 17 | label_tmp = (label == i) 18 | prediction_tmp = prediction_tmp.astype(np.float) 19 | label_tmp = label_tmp.astype(np.float) 20 | 21 | dice = 2 * np.sum(prediction_tmp * label_tmp) / (np.sum(prediction_tmp) + np.sum(label_tmp)) 22 | total_dice[i - 1] += dice 23 | 24 | return total_dice 25 | 26 | 27 | def calculate_metric_percase(pred, gt): 28 | dc = metric.binary.dc(pred, gt) 29 | jc = metric.binary.jc(pred, gt) 30 | hd = metric.binary.hd95(pred, gt) 31 | asd = metric.binary.asd(pred, gt) 32 | 33 | return dc, jc, hd, asd 34 | 35 | 36 | def dice(input, target, ignore_index=None): 37 | smooth = 1. 38 | # using clone, so that it can do change to original target. 39 | iflat = input.clone().view(-1) 40 | tflat = target.clone().view(-1) 41 | if ignore_index is not None: 42 | mask = tflat == ignore_index 43 | tflat[mask] = 0 44 | iflat[mask] = 0 45 | intersection = (iflat * tflat).sum() 46 | 47 | return (2. * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth) -------------------------------------------------------------------------------- /code_oa/utils/ramps.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, Curious AI Ltd. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Functions for ramping hyperparameters up or down 9 | 10 | Each function takes the current training step or epoch, and the 11 | ramp length in the same format, and returns a multiplier between 12 | 0 and 1. 
13 | """ 14 | 15 | 16 | import numpy as np 17 | 18 | 19 | def sigmoid_rampup(current, rampup_length): 20 | """Exponential rampup from https://arxiv.org/abs/1610.02242""" 21 | if rampup_length == 0: 22 | return 1.0 23 | else: 24 | current = np.clip(current, 0.0, rampup_length) 25 | phase = 1.0 - current / rampup_length 26 | return float(np.exp(-5.0 * phase * phase)) 27 | 28 | 29 | def linear_rampup(current, rampup_length): 30 | """Linear rampup""" 31 | assert current >= 0 and rampup_length >= 0 32 | if current >= rampup_length: 33 | return 1.0 34 | else: 35 | return current / rampup_length 36 | 37 | 38 | def cosine_rampdown(current, rampdown_length): 39 | """Cosine rampdown from https://arxiv.org/abs/1608.03983""" 40 | assert 0 <= current <= rampdown_length 41 | return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1)) 42 | -------------------------------------------------------------------------------- /code_oa/utils/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import os 8 | import pickle 9 | import numpy as np 10 | import re 11 | from scipy.ndimage import distance_transform_edt as distance 12 | from skimage import segmentation as skimage_seg 13 | import torch 14 | from torch.utils.data.sampler import Sampler 15 | import torch.distributed as dist 16 | 17 | # many issues with this function 18 | def load_model(path): 19 | """Loads model and return it without DataParallel table.""" 20 | if os.path.isfile(path): 21 | print("=> loading checkpoint '{}'".format(path)) 22 | checkpoint = torch.load(path) 23 | 24 | for key in checkpoint["state_dict"]: 25 | print(key) 26 | 27 | # size of the top layer 28 | N = checkpoint["state_dict"]["decoder.out_conv.bias"].size() 29 | 30 | # build skeleton of the model 31 | sob = "sobel.0.weight" in checkpoint["state_dict"].keys() 32 | model = models.__dict__[checkpoint["arch"]](sobel=sob, out=int(N[0])) 33 | 34 | # deal with a dataparallel table 35 | def rename_key(key): 36 | if not "module" in key: 37 | return key 38 | return "".join(key.split(".module")) 39 | 40 | checkpoint["state_dict"] = { 41 | rename_key(key): val for key, val in checkpoint["state_dict"].items() 42 | } 43 | 44 | # load weights 45 | model.load_state_dict(checkpoint["state_dict"]) 46 | print("Loaded") 47 | else: 48 | model = None 49 | print("=> no checkpoint found at '{}'".format(path)) 50 | return model 51 | 52 | 53 | def load_checkpoint(path, model, optimizer, from_ddp=False): 54 | """loads previous checkpoint 55 | 56 | Args: 57 | path (str): path to checkpoint 58 | model (model): model to restore checkpoint to 59 | optimizer (optimizer): torch optimizer to load optimizer state_dict to 60 | from_ddp (bool, optional): load DistributedDataParallel checkpoint to regular model. Defaults to False. 
61 | 62 | Returns: 63 | model, optimizer, epoch_num, loss 64 | """ 65 | # load checkpoint 66 | checkpoint = torch.load(path) 67 | # transfer state_dict from checkpoint to model 68 | model.load_state_dict(checkpoint["state_dict"]) 69 | # transfer optimizer state_dict from checkpoint to model 70 | optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) 71 | # track loss 72 | loss = checkpoint["loss"] 73 | return model, optimizer, checkpoint["epoch"], loss.item() 74 | 75 | 76 | def restore_model(logger, snapshot_path, model_num=None): 77 | """wrapper function to read log dir and load restore a previous checkpoint 78 | 79 | Args: 80 | logger (Logger): logger object (for info output to console) 81 | snapshot_path (str): path to checkpoint directory 82 | 83 | Returns: 84 | model, optimizer, start_epoch, performance 85 | """ 86 | try: 87 | # check if there is previous progress to be restored: 88 | logger.info(f"Snapshot path: {snapshot_path}") 89 | iter_num = [] 90 | name = "model_iter" 91 | if model_num: 92 | name = model_num 93 | for filename in os.listdir(snapshot_path): 94 | if name in filename: 95 | basename, extension = os.path.splitext(filename) 96 | iter_num.append(int(basename.split("_")[2])) 97 | iter_num = max(iter_num) 98 | for filename in os.listdir(snapshot_path): 99 | if name in filename and str(iter_num) in filename: 100 | model_checkpoint = filename 101 | except Exception as e: 102 | logger.warning(f"Error finding previous checkpoints: {e}") 103 | 104 | try: 105 | logger.info(f"Restoring model checkpoint: {model_checkpoint}") 106 | model, optimizer, start_epoch, performance = load_checkpoint( 107 | snapshot_path + "/" + model_checkpoint, model, optimizer 108 | ) 109 | logger.info(f"Models restored from iteration {iter_num}") 110 | return model, optimizer, start_epoch, performance 111 | except Exception as e: 112 | logger.warning(f"Unable to restore model checkpoint: {e}, using new model") 113 | 114 | 115 | def save_checkpoint(epoch, model, optimizer, loss, path): 116 | """Saves model as checkpoint""" 117 | torch.save( 118 | { 119 | "epoch": epoch, 120 | "state_dict": model.state_dict(), 121 | "optimizer_state_dict": optimizer.state_dict(), 122 | "loss": loss, 123 | }, 124 | path, 125 | ) 126 | 127 | 128 | class UnifLabelSampler(Sampler): 129 | """Samples elements uniformely accross pseudolabels. 130 | Args: 131 | N (int): size of returned iterator. 
132 | images_lists: dict of key (target), value (list of data with this target) 133 | """ 134 | 135 | def __init__(self, N, images_lists): 136 | self.N = N 137 | self.images_lists = images_lists 138 | self.indexes = self.generate_indexes_epoch() 139 | 140 | def generate_indexes_epoch(self): 141 | size_per_pseudolabel = int(self.N / len(self.images_lists)) + 1 142 | res = np.zeros(size_per_pseudolabel * len(self.images_lists)) 143 | 144 | for i in range(len(self.images_lists)): 145 | indexes = np.random.choice( 146 | self.images_lists[i], 147 | size_per_pseudolabel, 148 | replace=(len(self.images_lists[i]) <= size_per_pseudolabel), 149 | ) 150 | res[i * size_per_pseudolabel : (i + 1) * size_per_pseudolabel] = indexes 151 | 152 | np.random.shuffle(res) 153 | return res[: self.N].astype("int") 154 | 155 | def __iter__(self): 156 | return iter(self.indexes) 157 | 158 | def __len__(self): 159 | return self.N 160 | 161 | 162 | class AverageMeter(object): 163 | """Computes and stores the average and current value""" 164 | 165 | def __init__(self): 166 | self.reset() 167 | 168 | def reset(self): 169 | self.val = 0 170 | self.avg = 0 171 | self.sum = 0 172 | self.count = 0 173 | 174 | def update(self, val, n=1): 175 | self.val = val 176 | self.sum += val * n 177 | self.count += n 178 | self.avg = self.sum / self.count 179 | 180 | 181 | def learning_rate_decay(optimizer, t, lr_0): 182 | for param_group in optimizer.param_groups: 183 | lr = lr_0 / np.sqrt(1 + lr_0 * param_group["weight_decay"] * t) 184 | param_group["lr"] = lr 185 | 186 | 187 | class Logger: 188 | """Class to update every epoch to keep trace of the results 189 | Methods: 190 | - log() log and save 191 | """ 192 | 193 | def __init__(self, path): 194 | self.path = path 195 | self.data = [] 196 | 197 | def log(self, train_point): 198 | self.data.append(train_point) 199 | with open(os.path.join(self.path), "wb") as fp: 200 | pickle.dump(self.data, fp, -1) 201 | 202 | 203 | def compute_sdf(img_gt, out_shape): 204 | """ 205 | compute the signed distance map of binary mask 206 | input: segmentation, shape = (batch_size, x, y, z) 207 | output: the Signed Distance Map (SDM) 208 | sdf(x) = 0; x in segmentation boundary 209 | -inf|x-y|; x in segmentation 210 | +inf|x-y|; x out of segmentation 211 | normalize sdf to [-1,1] 212 | """ 213 | 214 | img_gt = img_gt.astype(np.uint8) 215 | normalized_sdf = np.zeros(out_shape) 216 | 217 | for b in range(out_shape[0]): # batch size 218 | posmask = img_gt[b].astype(np.bool) 219 | if posmask.any(): 220 | negmask = ~posmask 221 | posdis = distance(posmask) 222 | negdis = distance(negmask) 223 | boundary = skimage_seg.find_boundaries(posmask, mode="inner").astype( 224 | np.uint8 225 | ) 226 | sdf = (negdis - np.min(negdis)) / (np.max(negdis) - np.min(negdis)) - ( 227 | posdis - np.min(posdis) 228 | ) / (np.max(posdis) - np.min(posdis)) 229 | sdf[boundary == 1] = 0 230 | normalized_sdf[b] = sdf 231 | # assert np.min(sdf) == -1.0, print(np.min(posdis), np.max(posdis), np.min(negdis), np.max(negdis)) 232 | # assert np.max(sdf) == 1.0, print(np.min(posdis), np.min(negdis), np.max(posdis), np.max(negdis)) 233 | 234 | return normalized_sdf 235 | 236 | 237 | # set up process group for distributed computing 238 | def distributed_setup(rank, world_size): 239 | os.environ["MASTER_ADDR"] = "localhost" 240 | os.environ["MASTER_PORT"] = "12355" 241 | print("setting up dist process group now") 242 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 243 | 244 | 245 | def load_ddp_to_nddp(state_dict): 246 | 
pattern = re.compile("module") 247 | for k, v in state_dict.items(): 248 | if re.search("module", k): 249 | model_dict[re.sub(pattern, "", k)] = v 250 | else: 251 | model_dict = state_dict 252 | return model_dict 253 | -------------------------------------------------------------------------------- /code_oa/val_2D.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from medpy import metric 4 | from scipy.ndimage import zoom 5 | from scipy import ndimage 6 | 7 | def calculate_metric_percase(pred, gt): 8 | pred[pred > 0] = 1 9 | gt[gt > 0] = 1 10 | if pred.sum() > 0: 11 | dice = metric.binary.dc(pred, gt) 12 | hd95 = metric.binary.hd95(pred, gt) 13 | return dice, hd95 14 | else: 15 | return 0, 0 16 | 17 | def get_largest_component(image): 18 | dim = len(image.shape) 19 | if(image.sum() == 0 ): 20 | # print('the largest component is null') 21 | return image 22 | if(dim == 2): 23 | s = ndimage.generate_binary_structure(2,1) 24 | elif(dim == 3): 25 | s = ndimage.generate_binary_structure(3,1) 26 | else: 27 | raise ValueError("the dimension number should be 2 or 3") 28 | labeled_array, numpatches = ndimage.label(image, s) 29 | sizes = ndimage.sum(image, labeled_array, range(1, numpatches + 1)) 30 | max_label = np.where(sizes == sizes.max())[0] + 1 31 | output = np.asarray(labeled_array == max_label, np.uint8) 32 | return output 33 | 34 | def test_single_volume(image, label, net, classes, patch_size=[384, 384]): 35 | image, label = image.squeeze(0).cpu().detach( 36 | ).numpy(), label.squeeze(0).cpu().detach().numpy() 37 | prediction = np.zeros_like(label) 38 | for ind in range(image.shape[0]): 39 | slice = image[ind, :, :] 40 | x, y = slice.shape[0], slice.shape[1] 41 | slice = zoom(slice, (patch_size[0] / x, patch_size[1] / y), order=0) 42 | input = torch.from_numpy(slice).unsqueeze( 43 | 0).unsqueeze(0).float().cuda() 44 | net.eval() 45 | with torch.no_grad(): 46 | out = torch.argmax(torch.softmax( 47 | net(input)[0], dim=1), dim=1).squeeze(0) 48 | out = out.cpu().detach().numpy() 49 | pred = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0) 50 | prediction[ind] = pred 51 | # prediction = get_largest_component(prediction) 52 | metric_list = [] 53 | for i in range(1, classes): 54 | metric_list.append(calculate_metric_percase( 55 | prediction == i, label == i)) 56 | return metric_list 57 | 58 | def test_single_volume_fast(image, label, net, classes): 59 | image, label = image.squeeze(0).cpu().detach( 60 | ).numpy(), label.squeeze(0).cpu().detach().numpy() 61 | input = torch.from_numpy(image).float().cuda().unsqueeze(1) 62 | net.eval() 63 | with torch.no_grad(): 64 | out = torch.argmax(torch.softmax( 65 | net(input)[0], dim=1), dim=1).squeeze(0) 66 | out = out.cpu().detach().numpy() 67 | prediction = out 68 | metric_list = [] 69 | for i in range(1, classes): 70 | metric_list.append(calculate_metric_percase( 71 | prediction == i, label == i)) 72 | return metric_list 73 | 74 | 75 | def test_single_volume_ds(image, label, net, classes, patch_size=[384, 384]): 76 | image, label = image.squeeze(0).cpu().detach( 77 | ).numpy(), label.squeeze(0).cpu().detach().numpy() 78 | prediction = np.zeros_like(label) 79 | for ind in range(image.shape[0]): 80 | slice = image[ind, :, :] 81 | x, y = slice.shape[0], slice.shape[1] 82 | slice = zoom(slice, (patch_size[0] / x, patch_size[1] / y), order=0) 83 | input = torch.from_numpy(slice).unsqueeze( 84 | 0).unsqueeze(0).float().cuda() 85 | net.eval() 86 | with torch.no_grad(): 87 | 
output_main, _, _, _ = net(input) 88 | out = torch.argmax(torch.softmax( 89 | output_main, dim=1), dim=1).squeeze(0) 90 | out = out.cpu().detach().numpy() 91 | pred = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0) 92 | prediction[ind] = pred 93 | metric_list = [] 94 | for i in range(1, classes): 95 | metric_list.append(calculate_metric_percase( 96 | prediction == i, label == i)) 97 | return metric_list 98 | --------------------------------------------------------------------------------
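
The files above are plain library code; as a quick orientation, here is an illustrative sketch (not taken from the repository) of how these utilities could fit together in a 2D training and validation loop. Everything outside the imports — the model, optimizer, data loader, class count, loss weights and ramp-up length — is a hypothetical placeholder, and the imports assume the code is run from inside code_oa/; the actual training and evaluation logic lives in train_source.py, train_finetune.py and test_2D_fully.py.

import numpy as np
import torch

from utils.losses import DiceLoss, LeastSquareGAN, gradient_penalty
from utils.ramps import sigmoid_rampup
from val_2D import test_single_volume

num_classes = 2                                    # assumption: background + prostate
ce_loss = torch.nn.CrossEntropyLoss()
dice_loss = DiceLoss(n_classes=num_classes)


def consistency_weight(iter_num, rampup_iters=200, w_max=0.1):
    # Hypothetical schedule: scale a maximum weight by the sigmoid ramp-up from ramps.py.
    return w_max * sigmoid_rampup(iter_num, rampup_iters)


def supervised_step(model, optimizer, image, label):
    # image: (B, 1, H, W) float tensor, label: (B, H, W) integer mask.
    # The first network output is assumed to be the segmentation logits, as in val_2D.py.
    model.train()
    logits = model(image)[0]
    loss_ce = ce_loss(logits, label.long())
    loss_dice = dice_loss(logits, label.unsqueeze(1), softmax=True)
    loss = 0.5 * (loss_ce + loss_dice)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


@torch.no_grad()
def validate(model, val_loader):
    # Assumes a volume-level loader with batch_size=1 and "image"/"label" keys.
    per_case = [test_single_volume(batch["image"], batch["label"], model, classes=num_classes)
                for batch in val_loader]
    return np.mean(per_case, axis=0)               # mean (Dice, HD95) per foreground class

In the same spirit, a hedged sketch of a discriminator update built from the LeastSquareGAN helpers and gradient_penalty defined in utils/losses.py; the discriminator and the two (B, C, H, W) input tensors are again placeholders rather than the repository's own adversarial training code.

def discriminator_step(discriminator, d_optimizer, real_images, fake_images):
    # Least-squares GAN targets are +1 for real and -1 for fake predictions.
    pred_real = discriminator(real_images)
    pred_fake = discriminator(fake_images.detach())
    d_loss = LeastSquareGAN.discriminator_loss(pred_real, pred_fake)
    # Add a WGAN-GP style penalty on gradients at interpolated samples.
    d_loss = d_loss + gradient_penalty(discriminator, real_images, fake_images.detach())
    d_optimizer.zero_grad()
    d_loss.backward()
    d_optimizer.step()
    return d_loss.item()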