├── __init__.py ├── dataset ├── __init__.py ├── utils.py ├── generate_hvsmr.py ├── generate_chd.py ├── hvsmr.py ├── mmwhs.py ├── generate_mmwhs.py ├── acdc.py ├── chd.py ├── generate_acdc.py └── augmentation.py ├── network ├── __init__.py └── unet2d.py ├── figures └── overview.jpg ├── .gitignore ├── requirements.txt ├── experiment_log.py ├── run_script.sh ├── lr_scheduler.py ├── myconfig.py ├── README.md ├── loss └── contrast_loss.py ├── train_contrast.py ├── train_supervised.py ├── utils.py └── metrics.py /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /network/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /figures/overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dewenzeng/positional_cl/HEAD/figures/overview.jpg -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # results 2 | results/ 3 | 4 | # logs 5 | runs/ 6 | 7 | runs_*/ 8 | 9 | # vscode 10 | .vscode/ 11 | 12 | train_job.sh 13 | 14 | test.py 15 | 16 | __pycache__ 17 | 18 | run.sh -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | batchgenerators==0.23 2 | numpy==1.21.5 3 | pandas==1.4.1 4 | Pillow==9.0.1 5 | tensorboard==2.9.1 6 | torch==1.12.0 7 | torchvision==0.13.0 8 | opencv-python 9 | SimpleITK 10 | scikit-image 11 | matplotlib 12 | scikit-learn 13 | scipy 14 | 15 | -------------------------------------------------------------------------------- /experiment_log.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | class PytorchExperimentLogger(object): 4 | """ 5 | A single class for logging your pytorch experiments to file. 
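Usage sketch (as in train_contrast.py): logger = PytorchExperimentLogger(save_path, "elog", ShowTerminal=True); logger.print("msg") appends the message to save_path/elog.txt.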
6 | Extends the ExperimentLogger and also creates an experiment folder with a file structure.
7 | """
8 | 
9 | def __init__(self, saveDir, fileName, ShowTerminal=False):
10 | 
11 | self.saveFile = os.path.join(saveDir, fileName+".txt")
12 | self.ShowTerminal = ShowTerminal
13 | 
14 | def print(self, strT):
15 | # optionally echo to the terminal, then append to the log file
16 | if self.ShowTerminal:
17 | print(strT)
18 | with open(self.saveFile, 'a') as f:
19 | f.writelines(strT+'\n')
20 | 
21 | -------------------------------------------------------------------------------- /run_script.sh: --------------------------------------------------------------------------------
1 | # some examples of training commands
2 | 
3 | # contrastive learning
4 | # train pcl on the chd dataset
5 | CUDA_VISIBLE_DEVICES=0,1 python train_contrast.py --device cuda:0 --batch_size 32 --epochs 300 --data_dir your_data_dir --lr 0.1 --do_contrast --dataset chd --patch_size 512 512 \
6 | --experiment_name contrast_chd_pcl_temp01_thresh01_ --slice_threshold 0.1 --temp 0.1 --initial_filter_size 32 --classes 512 --contrastive_method pcl
7 | 
8 | # train pcl on the acdc dataset
9 | CUDA_VISIBLE_DEVICES=0,1 python train_contrast.py --device cuda:0 --batch_size 32 --epochs 300 --data_dir your_data_dir --lr 0.1 --do_contrast --dataset acdc --patch_size 352 352 \
10 | --experiment_name contrast_acdc_pcl_temp01_thresh035_ --slice_threshold 0.35 --temp 0.1 --initial_filter_size 48 --classes 512 --contrastive_method pcl
11 | 
12 | # train simclr on the chd dataset
13 | CUDA_VISIBLE_DEVICES=0,1 python train_contrast.py --device cuda:0 --batch_size 32 --epochs 300 --data_dir your_data_dir --lr 0.1 --do_contrast --dataset chd --patch_size 512 512 \
14 | --experiment_name contrast_chd_simclr_temp01_ --slice_threshold 0.1 --temp 0.1 --initial_filter_size 32 --classes 512 --contrastive_method simclr
15 | 
16 | # supervised finetuning
17 | # train from scratch on the chd dataset using 40 samples.
18 | CUDA_VISIBLE_DEVICES=0 python train_supervised.py --device cuda:0 --batch_size 10 --epochs 100 --data_dir your_data_dir --lr 5e-5 --min_lr 1e-6 --dataset chd --patch_size 512 512 \
19 | --experiment_name supervised_chd_scratch_sample_40_ --initial_filter_size 32 --classes 8 --enable_few_data --sampling_k 40
20 | 
21 | # finetune from a pcl-pretrained model on the chd dataset using 40 samples.
22 | CUDA_VISIBLE_DEVICES=0 python train_supervised.py --device cuda:0 --batch_size 10 --epochs 100 --data_dir your_data_dir --lr 5e-5 --min_lr 1e-6 --dataset chd --patch_size 512 512 \
23 | --experiment_name supervised_chd_pcl_sample_40_ --initial_filter_size 32 --classes 8 --enable_few_data --sampling_k 40 \
24 | --restart --pretrained_model_path your_pretrained_model_path
25 | 
26 | # transfer from an acdc-pretrained model to HVSMR using 6 samples.
27 | CUDA_VISIBLE_DEVICES=0 python train_supervised.py --device cuda:0 --batch_size 10 --epochs 100 --data_dir your_data_dir --lr 5e-5 --min_lr 5e-6 --dataset hvsmr --patch_size 352 352 \
28 | --experiment_name supervised_hvsmr_pcl_sample_6_ --initial_filter_size 48 --classes 3 --enable_few_data --sampling_k 6 --restart --pretrained_model_path your_pretrained_model_path
29 | 
30 | # Note: for ACDC and HVSMR we use initial_filter_size 48 to align with a SOTA solution, https://github.com/MIC-DKFZ/ACDC2017. For CHD and MMWHS, we use initial_filter_size 32.
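31 | 
32 | # (optional) global contrastive learning (gcl) follows the same pattern; the experiment name below is just an example:
33 | # CUDA_VISIBLE_DEVICES=0,1 python train_contrast.py --device cuda:0 --batch_size 32 --epochs 300 --data_dir your_data_dir --lr 0.1 --do_contrast --dataset acdc --patch_size 352 352 \
34 | #     --experiment_name contrast_acdc_gcl_temp01_ --temp 0.1 --initial_filter_size 48 --classes 512 --contrastive_method gcl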
-------------------------------------------------------------------------------- /lr_scheduler.py: --------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Hang Zhang
3 | ## ECE Department, Rutgers University
4 | ## Email: zhang.hang@rutgers.edu
5 | ## Copyright (c) 2017
6 | ##
7 | ## This source code is licensed under the MIT-style license found in the
8 | ## LICENSE file in the root directory of this source tree
9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10 | 
11 | import math
12 | 
13 | 
14 | class LR_Scheduler(object):
15 | """Learning Rate Scheduler
16 | 
17 | Step mode: ``lr = baselr * 0.1 ^ {floor(epoch / lr_step)}``
18 | 
19 | Cosine mode: ``lr = baselr * 0.5 * (1 + cos(pi * iter / maxiter))``
20 | 
21 | Poly mode: ``lr = baselr * (1 - iter/maxiter) ^ 0.9``
22 | 
23 | Args:
24 | args:
25 | :attr:`args.lr_scheduler` lr scheduler mode (`cos`, `poly`, `step`),
26 | :attr:`args.lr` base learning rate, :attr:`args.epochs` number of epochs,
27 | :attr:`args.lr_step`
28 | 
29 | iters_per_epoch: number of iterations per epoch
30 | """
31 | 
32 | def __init__(self, mode, base_lr, num_epochs, iters_per_epoch=0,
33 | lr_step=0, warmup_epochs=0, min_lr=None):
34 | self.mode = mode
35 | print('Using {} LR Scheduler!'.format(self.mode))
36 | self.lr = base_lr
37 | if mode == 'step':
38 | assert lr_step
39 | self.lr_step = lr_step
40 | self.iters_per_epoch = iters_per_epoch
41 | self.N = num_epochs * iters_per_epoch
42 | self.epoch = -1
43 | self.warmup_iters = warmup_epochs * iters_per_epoch
44 | self.min_lr = min_lr
45 | 
46 | def __call__(self, optimizer, i, epoch):
47 | T = epoch * self.iters_per_epoch + i
48 | if self.mode == 'cos':
49 | lr = 0.5 * self.lr * (1 + math.cos(1.0 * T / self.N * math.pi))
50 | elif self.mode == 'poly':
51 | lr = self.lr * pow((1 - 1.0 * T / self.N), 0.9)
52 | elif self.mode == 'step':
53 | lr = self.lr * (0.1 ** (epoch // self.lr_step))
54 | else:
55 | raise NotImplementedError
56 | # clamp to min_lr, then apply linear warmup during the first warmup_iters iterations
57 | if self.min_lr is not None:
58 | if lr < self.min_lr:
59 | lr = self.min_lr
60 | if self.warmup_iters > 0 and T < self.warmup_iters:
61 | lr = lr * 1.0 * T / self.warmup_iters
62 | if epoch > self.epoch:
63 | # print('=>Epoches %i, learning rate = %.4f' % (epoch, lr))
64 | self.epoch = epoch
65 | 
66 | assert lr >= 0
67 | self._adjust_learning_rate(optimizer, lr)
68 | 
69 | def _adjust_learning_rate(self, optimizer, lr):
70 | if len(optimizer.param_groups) == 1:
71 | optimizer.param_groups[0]['lr'] = lr
72 | else:
73 | # enlarge the lr at the head
74 | optimizer.param_groups[0]['lr'] = lr
75 | for i in range(1, len(optimizer.param_groups)):
76 | optimizer.param_groups[i]['lr'] = lr * 10
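77 | 
78 | # Minimal usage sketch (mirrors train_contrast.py): the scheduler is called once per
79 | # iteration and writes the new learning rate into the optimizer in place.
80 | # scheduler = LR_Scheduler('cos', base_lr=0.1, num_epochs=300, iters_per_epoch=len(train_loader))
81 | # for epoch in range(300):
82 | #     for i, batch in enumerate(train_loader):
83 | #         scheduler(optimizer, i, epoch)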
-------------------------------------------------------------------------------- /dataset/utils.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | 
4 | def pad_if_too_small(image, new_shape, pad_value=None):
5 | """Pad an image so that it is at least new_shape in every dimension.
6 | 
7 | The result shape will be [max(image.shape[0], new_shape[0]), max(image.shape[1], new_shape[1])].
8 | e.g.,
9 | 1. image:[10,20], new_shape:(30,30), the res shape is [30,30].
10 | 2. image:[10,20], new_shape:(10,10), the res shape is [10,20].
11 | 3. image:[3,10,20], new_shape:(3,20,20), the res shape is [3,20,20].
12 | 
13 | Args:
14 | image: a numpy array.
15 | new_shape: a tuple; the number of elements must match image.ndim.
16 | pad_value: padding value; if None, the first corner voxel of the image is used.
17 | 
18 | Returns:
19 | res: a numpy array.
20 | """
21 | shape = tuple(list(image.shape))
22 | new_shape = tuple(np.max(np.concatenate((shape, new_shape)).reshape((2, len(shape))), axis=0))
23 | if pad_value is None:
24 | if len(shape) == 2:
25 | pad_value = image[0, 0]
26 | elif len(shape) == 3:
27 | pad_value = image[0, 0, 0]
28 | else:
29 | raise ValueError("Image must be either 2 or 3 dimensional")
30 | res = np.ones(list(new_shape), dtype=image.dtype) * pad_value
31 | start = np.array(new_shape) / 2. - np.array(shape) / 2.
32 | if len(shape) == 2:
33 | res[int(start[0]):int(start[0]) + int(shape[0]), int(start[1]):int(start[1]) + int(shape[1])] = image
34 | elif len(shape) == 3:
35 | res[int(start[0]):int(start[0]) + int(shape[0]), int(start[1]):int(start[1]) + int(shape[1]),
36 | int(start[2]):int(start[2]) + int(shape[2])] = image
37 | return res
38 | 
39 | def pad_and_or_crop(orig_data, new_shape, mode=None, coords=None):
40 | 
41 | data = pad_if_too_small(orig_data, new_shape, pad_value=0)
42 | 
43 | h, w = data.shape
44 | if mode == "centre":
45 | h_c = int(h / 2.)
46 | w_c = int(w / 2.)
47 | elif mode == "fixed":
48 | assert (coords is not None)
49 | h_c, w_c = coords
50 | elif mode == "random":
51 | h_c_min = int(new_shape[0] / 2.)
52 | w_c_min = int(new_shape[1] / 2.)
53 | 
54 | if new_shape[0] % 2 == 1:
55 | h_c_max = h - 1 - int(new_shape[0] / 2.)
56 | w_c_max = w - 1 - int(new_shape[1] / 2.)
57 | else:
58 | h_c_max = h - int(new_shape[0] / 2.)
59 | w_c_max = w - int(new_shape[1] / 2.)
60 | 
61 | h_c = np.random.randint(low=h_c_min, high=(h_c_max + 1))
62 | w_c = np.random.randint(low=w_c_min, high=(w_c_max + 1))
63 | 
64 | h_start = h_c - int(new_shape[0] / 2.)
65 | w_start = w_c - int(new_shape[1] / 2.)
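# crop a window of exactly new_shape centred at (h_c, w_c); in 'random' mode the bounds above keep the window fully inside the padded image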
66 | data = data[h_start:(h_start + new_shape[0]), w_start:(w_start + new_shape[1])] 67 | 68 | return data, (h_c, w_c) 69 | 70 | def matplotlib_imshow(img, one_channel=False): 71 | if one_channel: 72 | img = img.mean(dim=0) 73 | # img = img / 2 + 0.5 # unnormalize 74 | npimg = img.numpy() 75 | if one_channel: 76 | plt.imshow(npimg) 77 | else: 78 | # use this function if image is grayscale 79 | plt.imshow(npimg[0,:,:],'gray') 80 | # use this function if image is RGB 81 | # plt.imshow(np.transpose(npimg, (1, 2, 0))) 82 | -------------------------------------------------------------------------------- /myconfig.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | parser = argparse.ArgumentParser() 5 | 6 | # Environment 7 | parser.add_argument("--device", type=str, default='cuda:0') 8 | parser.add_argument("--num_works", type=int, default=8) 9 | parser.add_argument("--exp_load", type=str, default=None) 10 | parser.add_argument('--save', metavar='SAVE', default='', help='saved folder') 11 | parser.add_argument('--results_dir', metavar='RESULTS_DIR', default='./results', help='results dir') 12 | parser.add_argument('--runs_dir', default='./runs', help='runs dir') 13 | 14 | # Data 15 | parser.add_argument("--dataset", type=str, default="chd", help='can be chd, acdc, mmwhs, hvsmr') 16 | parser.add_argument("--data_dir", type=str, default="/afs/crc.nd.edu/user/d/dzeng2/data/acdc/preprocessed_data/2D/") 17 | parser.add_argument('--batch_size', type=int, default=5) 18 | parser.add_argument('--seed', type=int, default=1234) 19 | parser.add_argument("--enable_few_data", default=False, action='store_true') 20 | parser.add_argument('--sampling_k', type=int, default=10) 21 | parser.add_argument('--cross_vali_num', type=int, default=5) 22 | 23 | # Model 24 | parser.add_argument("--initial_filter_size", type=int, default=48) 25 | parser.add_argument("--patch_size", nargs='+', type=int) 26 | parser.add_argument("--classes", type=int, default=4) 27 | 28 | # Train 29 | parser.add_argument("--experiment_name", type=str, default="contrast_chd_simclr_") 30 | parser.add_argument("--restart", default=False, action='store_true') 31 | parser.add_argument("--pretrained_model_path", type=str, default='/afs/crc.nd.edu/user/d/dzeng2/UnsupervisedSegmentation/results/supervised_v3_train_2020-10-26_18-41-29/model/latest.pth') 32 | parser.add_argument("--epochs", type=int, default=100) 33 | parser.add_argument("--lr", type=float, default=1e-4) 34 | parser.add_argument("--min_lr", type=float, default=1e-6) 35 | parser.add_argument("--decay", type=str, default='50-100-150-200') 36 | parser.add_argument("--gamma", type=float, default=0.5) 37 | parser.add_argument("--optimizer", type=str, default='rmsprop', 38 | choices=('sgd', 'adam', 'rmsprop')) 39 | parser.add_argument("--weight_decay", type=float, default=1e-4) 40 | parser.add_argument("--momentum", type=float, default=0.9) 41 | parser.add_argument("--betas", type=tuple, default=(0.9, 0.999)) 42 | parser.add_argument("--epsilon", type=float, default=1e-8) 43 | parser.add_argument("--do_contrast", default=False, action='store_true') 44 | parser.add_argument("--lr_scheduler", type=str, default='cos') 45 | parser.add_argument("--contrastive_method", type=str, default='simclr', help='simclr, gcl(global contrastive learning), pcl(positional contrastive learning)') 46 | 47 | # Loss 48 | parser.add_argument("--temp", type=float, default=0.1) 49 | parser.add_argument("--slice_threshold", type=float, 
default=0.05)
50 | 
51 | def save_args(obj, defaults, kwargs):
52 | for k,v in defaults.items():
53 | if k in kwargs: v = kwargs[k]
54 | setattr(obj, k, v)
55 | 
56 | def get_config():
57 | config = parser.parse_args()
58 | config.data_dir = os.path.expanduser(config.data_dir)
59 | config.patch_size = tuple(config.patch_size)
60 | return config
61 | -------------------------------------------------------------------------------- /dataset/generate_hvsmr.py: --------------------------------------------------------------------------------
1 | import SimpleITK as sitk
2 | import os
3 | import numpy as np
4 | from skimage.transform import resize
5 | 
6 | def resize_image(image, old_spacing, new_spacing, order=3):
7 | new_shape = (int(np.round(old_spacing[0]/new_spacing[0]*float(image.shape[0]))),
8 | int(np.round(old_spacing[1]/new_spacing[1]*float(image.shape[1]))),
9 | int(np.round(old_spacing[2]/new_spacing[2]*float(image.shape[2]))))
10 | return resize(image, new_shape, order=order, mode='edge')
11 | 
12 | 
13 | def convert_to_one_hot(seg):
14 | vals = np.unique(seg)
15 | res = np.zeros([len(vals)] + list(seg.shape), seg.dtype)
16 | for c in range(len(vals)):
17 | res[c][seg == c] = 1
18 | return res
19 | 
20 | def preprocess_image(itk_image, is_seg=False, spacing_target=(1, 0.5, 0.5), keep_z_spacing=False):
21 | spacing = np.array(itk_image.GetSpacing())[[2, 1, 0]]
22 | image = sitk.GetArrayFromImage(itk_image).astype(float)
23 | if keep_z_spacing:
24 | spacing_target = list(spacing_target)
25 | spacing_target[0] = spacing[0]
26 | if not is_seg:
27 | order_img = 3
28 | if not keep_z_spacing:
29 | order_img = 1
30 | image = resize_image(image, spacing, spacing_target, order=order_img).astype(np.float32)
31 | min_val_1p=np.percentile(image,1)
32 | max_val_99p=np.percentile(image,99)
33 | image[image<min_val_1p]=min_val_1p
34 | image[image>max_val_99p]=max_val_99p
35 | image -= image.mean()
36 | image /= image.std()
37 | else:
38 | tmp = convert_to_one_hot(image)
39 | vals = np.unique(image)
40 | results = []
41 | for i in range(len(tmp)):
42 | results.append(resize_image(tmp[i].astype(float), spacing, spacing_target, 1)[None])
43 | image = vals[np.vstack(results).argmax(0)]
44 | return image
45 | 
46 | def generate_hvsmr_dataset(data_dir, save_dir):
47 | if not os.path.exists(save_dir):
48 | os.mkdir(save_dir)
49 | for i in range(10):
50 | image_path = os.path.join(data_dir, 'Training_dataset_sx_cropped', 'training_sa_crop_pat'+str(i)+'.nii.gz')
51 | label_path = os.path.join(data_dir, 'Ground_truth', 'training_sa_crop_pat'+str(i)+'-label.nii.gz')
52 | itk_image = sitk.ReadImage(image_path)
53 | itk_label = sitk.ReadImage(label_path)
54 | print(f'image spacing:{itk_image.GetSpacing()}')
55 | print(f'original image size:{itk_image.GetSize()}')
56 | image = preprocess_image(itk_image, is_seg=False, spacing_target=(1, 0.7, 0.7), keep_z_spacing=True)
57 | label = preprocess_image(itk_label, is_seg=True, spacing_target=(1, 0.7, 0.7), keep_z_spacing=True)
58 | print(f'resized image size:{image.shape}')
59 | if not os.path.exists(os.path.join(save_dir, 'patient_'+str(i))):
60 | os.mkdir(os.path.join(save_dir, 'patient_'+str(i)))
61 | for j in range(image.shape[0]):
62 | tmp_image = image[j,:,:]
63 | tmp_label = label[j,:,:]
64 | save_path_image = os.path.join(save_dir, 'patient_'+str(i), 'frame_%03d'%j)
65 | all_data = np.stack([tmp_image, tmp_label],axis=0)
66 | np.save(save_path_image, all_data)
67 | 
68 | if __name__ == "__main__":
69 | import argparse
70 | parser = argparse.ArgumentParser()
71 | parser.add_argument("-indir", help="folder where the extracted training data is", type=str, default='d:/data/hvsmr/')
72 | parser.add_argument("-labeled_outdir", help="folder where to save the data for the 2d network", type=str, default='d:/data/hvsmr/test')
73 | args = parser.parse_args()
74 | generate_hvsmr_dataset(args.indir, args.labeled_outdir)
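75 | 
76 | # example invocation (the default paths above are placeholders; adjust to your local layout):
77 | # python generate_hvsmr.py -indir d:/data/hvsmr/ -labeled_outdir d:/data/hvsmr/test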
-------------------------------------------------------------------------------- /dataset/generate_chd.py: --------------------------------------------------------------------------------
1 | import SimpleITK as sitk
2 | import os
3 | import numpy as np
4 | from batchgenerators.utilities.file_and_folder_operations import *
5 | import pickle
6 | from collections import OrderedDict
7 | 
8 | def generate_chd_dataset(data_dir, labeled_save_dir, unlabeled_save_dir):
9 | results = OrderedDict()
10 | for i in range(1001,1129):
11 | if os.path.exists(os.path.join(data_dir,'ct_'+str(i)+'_image.nii.gz')):
12 | print(f'processing i={i}')
13 | image_path = os.path.join(data_dir,'ct_'+str(i)+'_image.nii.gz')
14 | label_path = os.path.join(data_dir,'ct_'+str(i)+'_label.nii.gz')
15 | image = sitk.ReadImage(image_path)
16 | label = sitk.ReadImage(label_path)
17 | # print(f'image spacing:{image.GetSpacing()}')
18 | image_npy = sitk.GetArrayViewFromImage(image)
19 | label_npy = sitk.GetArrayViewFromImage(label)
20 | # convert label
21 | image_npy_copy = image_npy.copy()
22 | min_val_1p=np.percentile(image_npy_copy,1)
23 | max_val_99p=np.percentile(image_npy_copy,99)
24 | image_npy_copy[image_npy_copy<min_val_1p]=min_val_1p
25 | image_npy_copy[image_npy_copy>max_val_99p]=max_val_99p
26 | mean = image_npy_copy.astype(float).mean()
27 | std = image_npy_copy.astype(float).std()
28 | label_npy_copy = label_npy.copy()
29 | label_npy_copy[label_npy_copy>7]=0
30 | results['ct_'+str(i)] = {}
31 | results['ct_'+str(i)]['mean'] = mean
32 | results['ct_'+str(i)]['std'] = std
33 | # print(f'mean:{mean}, std:{std}')
34 | # we save the integer version instead of the float version to save space. normalization is done on-the-fly.
35 | # we save one labeled version and one unlabeled version, maybe not the best solution, but ok.
36 | # you can also add new unlabeled CT data into the unlabeled dataset for contrastive learning.
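# (normalization happens at load time: dataset/chd.py's __getitem__ applies image -= mean; image /= std using the values stored in mean_std.pkl)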
37 | for j in range(image_npy.shape[0]):
38 | tmp_image = image_npy_copy[j,:,:]
39 | tmp_label = label_npy_copy[j,:,:]
40 | all_data = np.stack([tmp_image, tmp_label],axis=0)
41 | maybe_mkdir_p(os.path.join(labeled_save_dir, 'train', 'ct_'+str(i)))
42 | save_path_image = os.path.join(labeled_save_dir, 'train', 'ct_'+str(i), 'frame%03d' % j)
43 | np.savez_compressed(save_path_image, data=all_data)
44 | maybe_mkdir_p(os.path.join(unlabeled_save_dir, 'train', 'ct_'+str(i)))
45 | save_path_image = os.path.join(unlabeled_save_dir, 'train', 'ct_'+str(i), 'frame%03d' % j)
46 | np.save(save_path_image, tmp_image)
47 | with open(os.path.join(labeled_save_dir, "mean_std.pkl"), 'wb') as f:
48 | pickle.dump(results, f)
49 | 
50 | with open(os.path.join(unlabeled_save_dir, "mean_std.pkl"), 'wb') as f:
51 | pickle.dump(results, f)
52 | 
53 | if __name__ == "__main__":
54 | import argparse
55 | parser = argparse.ArgumentParser()
56 | parser.add_argument("-indir", help="folder where the extracted training data is", type=str)
57 | parser.add_argument("-labeled_outdir", help="folder where to save the labeled data for the 2d network", type=str)
58 | parser.add_argument("-unlabeled_outdir", help="folder where to save the unlabeled data for the 2d network", type=str)
59 | args = parser.parse_args()
60 | generate_chd_dataset(args.indir, args.labeled_outdir, args.unlabeled_outdir)
61 | 
62 | # python generate_chd.py -indir /afs/crc.nd.edu/user/d/dzeng2/data/chd/raw_image -labeled_outdir /afs/crc.nd.edu/user/d/dzeng2/data/chd/test/supervised -unlabeled_outdir /afs/crc.nd.edu/user/d/dzeng2/data/chd/test/contrastive -------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | ### Positional Contrastive Learning
2 | 
3 | Implementation of the paper 'Positional Contrastive Learning for Volumetric Medical Image Segmentation' ([paper](https://arxiv.org/pdf/2106.09157v3.pdf), MICCAI'21).
4 | 
5 | 
6 | ![overview](figures/overview.jpg)
7 | 
8 | 
9 | ### Dataset
10 | - Congenital Heart Disease (CHD) dataset, CT, [link](https://www.kaggle.com/datasets/xiaoweixumedicalai/chd68-segmentation-dataset-miccai19)
11 | - MMWHS dataset, CT, [link](https://zmiclab.github.io/zxh/0/mmwhs/)
12 | - ACDC dataset, MRI, [link](https://www.creatis.insa-lyon.fr/Challenge/acdc/databases.html)
13 | - HVSMR dataset, MRI, [link](http://segchd.csail.mit.edu/)
14 | 
15 | ### Preprocessing
16 | Use the scripts in the `dataset` folder to preprocess each dataset, converting the original data into .npy/.npz files for training and testing.
17 | ```
18 | # convert the CHD dataset
19 | python generate_chd.py -indir raw_image_dir -labeled_outdir save_dir_for_labeled_data -unlabeled_outdir save_dir_for_unlabeled_data
20 | # convert the ACDC dataset
21 | python generate_acdc.py -i raw_image_dir -out_labeled save_dir_for_labeled_data -out_unlabeled save_dir_for_unlabeled_data
22 | # convert the MMWHS dataset
23 | python generate_mmwhs.py -indir raw_image_dir -labeled_outdir save_dir_for_labeled_data
24 | # convert the HVSMR dataset
25 | python generate_hvsmr.py -indir raw_image_dir -labeled_outdir save_dir_for_labeled_data
26 | ```
27 | 
28 | ### Running
29 | 
30 | (1) PCL pretraining on the CHD dataset
31 | ```
32 | CUDA_VISIBLE_DEVICES=0,1 python train_contrast.py --device cuda:0 --batch_size 32 --epochs 300 --data_dir chd_dataset --lr 0.1 --do_contrast --dataset chd --patch_size 512 512 \
33 | --experiment_name contrast_chd_pcl_temp01_thresh01_ --slice_threshold 0.1 --temp 0.1 --initial_filter_size 32 --classes 512 --contrastive_method pcl
34 | ```
35 | 
36 | (2) PCL pretraining on the ACDC dataset
37 | ```
38 | CUDA_VISIBLE_DEVICES=0,1 python train_contrast.py --device cuda:0 --batch_size 32 --epochs 300 --data_dir acdc_dataset --lr 0.1 --do_contrast --dataset acdc --patch_size 352 352 \
39 | --experiment_name contrast_acdc_pcl_temp01_thresh035_ --slice_threshold 0.35 --temp 0.1 --initial_filter_size 48 --classes 512 --contrastive_method pcl
40 | ```
41 | 
42 | (3) Semi-supervised finetuning on the CHD dataset with 40 samples, using 5-fold cross validation
43 | ```
44 | CUDA_VISIBLE_DEVICES=0 python train_supervised.py --device cuda:0 --batch_size 10 --epochs 100 --data_dir chd_dataset --lr 5e-5 --min_lr 5e-6 --dataset chd --patch_size 512 512 \
45 | --experiment_name supervised_chd_pcl_sample_40_ --initial_filter_size 32 --classes 8 --enable_few_data --sampling_k 40 \
46 | --restart --pretrained_model_path /afs/crc.nd.edu/user/d/dzeng2/positional_cl/results/contrast_pcl_2020-11-27_17-08-52/model/latest.pth
47 | ```
48 | 
49 | (4) Transfer-learning finetuning on the MMWHS dataset with 10 samples, using 5-fold cross validation
50 | ```
51 | CUDA_VISIBLE_DEVICES=0 python train_supervised.py --device cuda:0 --batch_size 10 --epochs 100 --data_dir mmwhs_dataset --lr 5e-5 --min_lr 5e-6 --dataset mmwhs --patch_size 256 256 \
52 | --experiment_name supervised_mmwhs_pcl_sample_10_ --initial_filter_size 32 --classes 8 --enable_few_data --sampling_k 10 \
53 | --restart --pretrained_model_path /afs/crc.nd.edu/user/d/dzeng2/positional_cl/results/contrast_pcl_2020-11-27_17-08-52/model/latest.pth
54 | ```
55 | 
56 | Please refer to [run_script.sh](run_script.sh) for more example training commands.
57 | 
58 | ### Pretrained model
59 | 
60 | The pretrained models using PCL can be found [here](https://drive.google.com/drive/folders/16vnZj9c5Mp-9lazmHtR-01AxHGUe0q_6?usp=sharing). Note that for CHD the initial_filter_size is 32; for ACDC it is 48, to align with a SOTA solution [link](https://github.com/MIC-DKFZ/ACDC2017).
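
A minimal sketch of how `train_supervised.py` restores a pretrained checkpoint (only encoder parameters are loaded, the decoder is trained from scratch; the checkpoint path below is a placeholder):
```
import torch
from network.unet2d import UNet2D

model = UNet2D(in_channels=1, initial_filter_size=32, kernel_size=3, classes=8, do_instancenorm=True)
ckpt = torch.load('path/to/latest.pth', map_location='cpu')["net"]
# keep only encoder weights and merge them into the freshly initialized model
encoder_weights = {k: v for k, v in ckpt.items() if "encoder" in k}
model_dict = model.state_dict()
model_dict.update(encoder_weights)
model.load_state_dict(model_dict)
```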
61 | 62 | ### How to cite this code 63 | 64 | Please cite the original publication: 65 | ``` 66 | @inproceedings{zeng2021positional, 67 | title={Positional contrastive learning for volumetric medical image segmentation}, 68 | author={Zeng, Dewen and Wu, Yawen and Hu, Xinrong and Xu, Xiaowei and Yuan, Haiyun and Huang, Meiping and Zhuang, Jian and Hu, Jingtong and Shi, Yiyu}, 69 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention}, 70 | pages={221--230}, 71 | year={2021}, 72 | organization={Springer} 73 | } 74 | ``` -------------------------------------------------------------------------------- /loss/contrast_loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Yonglong Tian (yonglong@mit.edu) 3 | Date: May 07, 2020 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | class SupConLoss(nn.Module): 9 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 10 | It also supports the unsupervised contrastive loss in SimCLR""" 11 | def __init__(self, threshold=0.1, temperature=0.07, contrast_mode='all', 12 | base_temperature=0.07, contrastive_method='simclr'): 13 | super(SupConLoss, self).__init__() 14 | self.temperature = temperature 15 | self.contrast_mode = contrast_mode 16 | self.base_temperature = base_temperature 17 | self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1) 18 | self.threshold = threshold 19 | self.contrastive_method = contrastive_method 20 | 21 | def _cosine_simililarity(self, x, y): 22 | # x shape: (N, 1, C) 23 | # y shape: (1, N, C) 24 | # v shape: (N, N) 25 | v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0)) 26 | return v 27 | 28 | def forward(self, features, labels=None, mask=None): 29 | """Compute loss for model. If both `labels` and `mask` are None, 30 | it degenerates to SimCLR unsupervised loss: 31 | https://arxiv.org/pdf/2002.05709.pdf 32 | 33 | Args: 34 | features: hidden vector of shape [bsz, n_views, ...]. 35 | labels: ground truth of shape [bsz]. 36 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 37 | has the same class as sample i. Can be asymmetric. 38 | Returns: 39 | A loss scalar. 
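Example (pcl): with labels holding normalized slice positions in [0, 1] and threshold=0.1, samples i and j are treated as a positive pair iff |labels[i] - labels[j]| < 0.1.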
40 | """ 41 | device = features.device 42 | 43 | if len(features.shape) < 3: 44 | raise ValueError('`features` needs to be [bsz, n_views, ...],' 45 | 'at least 3 dimensions are required') 46 | if len(features.shape) > 3: 47 | features = features.view(features.shape[0], features.shape[1], -1) 48 | 49 | batch_size = features.shape[0] 50 | if labels is not None and mask is not None: 51 | raise ValueError('Cannot define both `labels` and `mask`') 52 | elif labels is None and mask is None: 53 | mask = torch.eye(batch_size, dtype=torch.float32).to(device) 54 | elif labels is not None: 55 | labels = labels.contiguous().view(-1, 1) 56 | if labels.shape[0] != batch_size: 57 | raise ValueError('Num of labels does not match num of features') 58 | if self.contrastive_method == 'gcl': 59 | mask = torch.eq(labels, labels.T).float().to(device) 60 | elif self.contrastive_method == 'pcl': 61 | mask = (torch.abs(labels.T.repeat(batch_size,1) - labels.repeat(1,batch_size)) < self.threshold).float().to(device) 62 | else: 63 | mask = mask.float().to(device) 64 | 65 | contrast_count = features.shape[1] 66 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 67 | if self.contrast_mode == 'one': 68 | anchor_feature = features[:, 0] 69 | anchor_count = 1 70 | elif self.contrast_mode == 'all': 71 | anchor_feature = contrast_feature 72 | anchor_count = contrast_count 73 | else: 74 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) 75 | 76 | # compute logits 77 | logits = torch.div( 78 | self._cosine_simililarity(anchor_feature, contrast_feature), 79 | self.temperature) 80 | # tile mask 81 | mask = mask.repeat(anchor_count, contrast_count) 82 | # mask-out self-contrast cases 83 | logits_mask = torch.scatter( 84 | torch.ones_like(mask), 85 | 1, 86 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 87 | 0 88 | ) 89 | mask = mask * logits_mask 90 | 91 | # compute log_prob 92 | exp_logits = torch.exp(logits) * logits_mask 93 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 94 | 95 | # compute mean of log-likelihood over positive 96 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 97 | 98 | # loss 99 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos 100 | loss = loss.view(anchor_count, batch_size).mean() 101 | 102 | return loss 103 | -------------------------------------------------------------------------------- /dataset/hvsmr.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision 4 | import matplotlib.pyplot as plt 5 | from batchgenerators.utilities.file_and_folder_operations import * 6 | from batchgenerators.transforms.abstract_transforms import Compose, RndTransform 7 | from batchgenerators.transforms.spatial_transforms import SpatialTransform, MirrorTransform 8 | from batchgenerators.transforms.crop_and_pad_transforms import RandomCropTransform 9 | from torch.utils.data.dataset import Dataset 10 | from random import choice 11 | from .utils import * 12 | 13 | class HVSMR(Dataset): 14 | 15 | def __init__(self, keys, purpose, args): 16 | self.data_dir = args.data_dir 17 | self.patch_size = args.patch_size 18 | self.purpose = purpose 19 | self.classes = args.classes 20 | self.files = [] 21 | for key in keys: 22 | frames = subfiles(os.path.join(self.data_dir, 'patient_%d'%key), False, None, ".npy", True) 23 | frames.sort() 24 | for frame in frames: 25 | self.files.append(os.path.join(self.data_dir, 'patient_%d'%key, frame)) 26 | 
print(f'dataset length: {len(self.files)}')
27 | 
28 | def __getitem__(self, index):
29 | data = np.load(self.files[index])  # load once; data[0] is the image, data[1] the label
30 | img, label = data[0].astype(np.float32), data[1].astype(np.float32)
31 | img, label = self.prepare_supervised(img, label)
32 | # print(f'finish transform {self.files[index]}')
33 | return img, label
34 | 
35 | # this function is for normal supervised training
36 | def prepare_supervised(self, img, label):
37 | if self.purpose == 'train':
38 | # resize image
39 | img, coord = pad_and_or_crop(img, self.patch_size, mode='random')
40 | label, _ = pad_and_or_crop(label, self.patch_size, mode='fixed', coords=coord)
41 | # the image and label should be [batch, c, x, y, z]; this is the adaptation for using batchgenerators :)
42 | data_dict = {'data':img[None, None], 'seg':label[None, None]}
43 | tr_transforms = []
44 | tr_transforms.append(MirrorTransform((0, 1)))
45 | tr_transforms.append(RndTransform(SpatialTransform(self.patch_size, list(np.array(self.patch_size)//2),
46 | True, (100., 350.), (14., 17.),
47 | True, (0, 2.*np.pi), (-0.000001, 0.00001), (-0.000001, 0.00001),
48 | True, (0.7, 1.3), 'constant', 0, 3, 'constant', 0, 0,
49 | random_crop=False), prob=0.67, alternative_transform=RandomCropTransform(self.patch_size)))
50 | 
51 | train_transform = Compose(tr_transforms)
52 | data_dict = train_transform(**data_dict)
53 | img = data_dict.get('data')[0]
54 | label = data_dict.get('seg')[0]
55 | return img, label
56 | else:
57 | # resize image
58 | img, coord = pad_and_or_crop(img, self.patch_size, mode='centre')
59 | label, _ = pad_and_or_crop(label, self.patch_size, mode='fixed', coords=coord)
60 | return img[None], label[None]
61 | 
62 | def __len__(self):
63 | return len(self.files)
64 | 
65 | if __name__ == "__main__":
66 | import argparse
67 | from torch.autograd import Variable
68 | parser = argparse.ArgumentParser()
69 | parser.add_argument("--data_dir", type=str, default="d:/data/hvsmr/preprocessed")
70 | # parser.add_argument("--data_dir", type=str, default="/afs/crc.nd.edu/user/d/dzeng2/data/hvsmr/preprocessed")
71 | parser.add_argument("--patch_size", type=tuple, default=(320, 320))
72 | parser.add_argument("--classes", type=int, default=4)
73 | args = parser.parse_args()
74 | 
75 | all_keys = np.arange(0, 10)
76 | train_dataset = HVSMR(keys=all_keys, purpose='train', args=args)
77 | train_dataloader = torch.utils.data.DataLoader(train_dataset,
78 | batch_size=5,
79 | shuffle=True,
80 | num_workers=8,
81 | drop_last=False)
82 | 
83 | for batch_idx, tup in enumerate(train_dataloader):
84 | image, label = tup
85 | print(f'image shape:{image.shape}')
86 | plt.figure(1)
87 | img_grid = torchvision.utils.make_grid(image)
88 | matplotlib_imshow(img_grid, one_channel=False)
89 | plt.figure(2)
90 | img_grid = torchvision.utils.make_grid(label)
91 | matplotlib_imshow(img_grid, one_channel=False)
92 | plt.show()
93 | break
94 | -------------------------------------------------------------------------------- /train_contrast.py: --------------------------------------------------------------------------------
1 | import os
2 | from datetime import datetime
3 | from utils import *
4 | from torch.autograd import Variable
5 | from loss.contrast_loss import SupConLoss
6 | from network.unet2d import UNet2D_classification
7 | from dataset.chd import CHD
8 | from dataset.acdc import ACDC
9 | from myconfig import get_config
10 | from batchgenerators.utilities.file_and_folder_operations import *
11 | from lr_scheduler import LR_Scheduler
12 | from 
torch.utils.tensorboard import SummaryWriter 13 | from experiment_log import PytorchExperimentLogger 14 | 15 | def main(): 16 | # initialize config 17 | args = get_config() 18 | 19 | if args.save == '': 20 | args.save = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') 21 | save_path = os.path.join(args.results_dir, args.experiment_name + args.save) 22 | if not os.path.exists(save_path): 23 | os.makedirs(save_path) 24 | 25 | logger = PytorchExperimentLogger(save_path, "elog", ShowTerminal=True) 26 | model_result_dir = join(save_path, 'model') 27 | maybe_mkdir_p(model_result_dir) 28 | args.model_result_dir = model_result_dir 29 | 30 | logger.print(f"saving to {save_path}") 31 | writer = SummaryWriter('runs/' + args.experiment_name + args.save) 32 | 33 | # setup cuda 34 | args.device = torch.device(args.device if torch.cuda.is_available() else "cpu") 35 | # logger.print(f"the model will run on device {args.device}") 36 | 37 | # create model 38 | logger.print("creating model ...") 39 | model = UNet2D_classification(in_channels=1, initial_filter_size=args.initial_filter_size, kernel_size=3, classes=args.classes, do_instancenorm=True) 40 | 41 | if args.restart: 42 | logger.print('loading from saved model'+args.pretrained_model_path) 43 | dict = torch.load(args.pretrained_model_path, 44 | map_location=lambda storage, loc: storage) 45 | save_model = dict["net"] 46 | model.load_state_dict(save_model) 47 | 48 | model.to(args.device) 49 | model = torch.nn.DataParallel(model) 50 | 51 | num_parameters = sum([l.nelement() for l in model.module.parameters()]) 52 | logger.print(f"number of parameters: {num_parameters}") 53 | 54 | if args.dataset == 'chd': 55 | training_keys = os.listdir(os.path.join(args.data_dir,'train')) 56 | training_keys.sort() 57 | train_dataset = CHD(keys=training_keys, purpose='train', args=args) 58 | elif args.dataset == 'acdc': 59 | train_dataset = ACDC(keys=list(range(1,101)), purpose='train', args=args) 60 | 61 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_works, drop_last=True) 62 | 63 | # define loss function (criterion) and optimizer 64 | criterion = SupConLoss(threshold=args.slice_threshold, temperature=args.temp, contrastive_method=args.contrastive_method).to(args.device) 65 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-5) 66 | scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs, len(train_loader)) 67 | 68 | for epoch in range(args.epochs): 69 | # train for one epoch 70 | train_loss = train(train_loader, model, criterion, epoch, optimizer, scheduler, logger, args) 71 | 72 | logger.print('\n Epoch: {0}\t' 73 | 'Training Loss {train_loss:.4f} \t' 74 | .format(epoch + 1, train_loss=train_loss)) 75 | 76 | writer.add_scalar('training_loss', train_loss, epoch) 77 | writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch) 78 | 79 | # save model 80 | save_dict = {"net": model.module.state_dict()} 81 | torch.save(save_dict, os.path.join(args.model_result_dir, "latest.pth")) 82 | 83 | def train(data_loader, model, criterion, epoch, optimizer, scheduler, logger, args): 84 | model.train() 85 | losses = AverageMeter() 86 | for batch_idx, tup in enumerate(data_loader): 87 | scheduler(optimizer, batch_idx, epoch) 88 | img1, img2, slice_position, partition = tup 89 | image1_var = Variable(img1.float(), requires_grad=False).to(args.device) 90 | image2_var = Variable(img2.float(), requires_grad=False).to(args.device) 91 | f1_1 = model(image1_var) 92 | 
f2_1 = model(image2_var)
93 | bsz = img1.shape[0]
94 | features = torch.cat([f1_1.unsqueeze(1), f2_1.unsqueeze(1)], dim=1)
95 | if args.contrastive_method == 'pcl':
96 | loss = criterion(features, labels=slice_position)
97 | elif args.contrastive_method == 'gcl':
98 | loss = criterion(features, labels=partition)
99 | else: # simclr
100 | loss = criterion(features)
101 | losses.update(loss.item(), bsz)
102 | optimizer.zero_grad()
103 | loss.backward()
104 | optimizer.step()
105 | logger.print(f"epoch:{epoch}, batch:{batch_idx}/{len(data_loader)}, lr:{optimizer.param_groups[0]['lr']:.6f}, loss:{losses.avg:.4f}")
106 | return losses.avg
107 | 
108 | if __name__ == '__main__':
109 | main() -------------------------------------------------------------------------------- /dataset/mmwhs.py: --------------------------------------------------------------------------------
1 | import pickle
2 | import numpy as np
3 | import torch
4 | import torchvision
5 | import matplotlib.pyplot as plt
6 | from batchgenerators.utilities.file_and_folder_operations import *
7 | from batchgenerators.transforms.abstract_transforms import Compose, RndTransform
8 | from batchgenerators.transforms.spatial_transforms import SpatialTransform, MirrorTransform
9 | from batchgenerators.transforms.crop_and_pad_transforms import RandomCropTransform
10 | from torch.utils.data.dataset import Dataset
11 | from random import choice
12 | from .utils import *
13 | 
14 | class MMWHS(Dataset):
15 | 
16 | def __init__(self, keys, purpose, args):
17 | self.data_dir = args.data_dir
18 | self.patch_size = args.patch_size
19 | self.purpose = purpose
20 | self.classes = args.classes
21 | self.files = []
22 | with open(os.path.join(self.data_dir, "mean_std.pkl"), 'rb') as f:
23 | mean_std = pickle.load(f)
24 | self.means = []
25 | self.stds = []
26 | for key in keys:
27 | frames = subfiles(join(self.data_dir, 'supervised', 'ct_train_'+str(key)), False, None, ".npz", True)
28 | frames.sort()
29 | for frame in frames:
30 | self.means.append(mean_std['ct_train_'+str(key)]['mean'])
31 | self.stds.append(mean_std['ct_train_'+str(key)]['std'])
32 | self.files.append(join(self.data_dir, 'supervised', 'ct_train_'+str(key), frame))
33 | print(f'dataset length: {len(self.files)}')
34 | 
35 | def __getitem__(self, index):
36 | all_data = np.load(self.files[index])['data']
37 | img = all_data[0].astype(np.float32)
38 | img -= self.means[index]
39 | img /= self.stds[index]
40 | label = all_data[1].astype(np.float32)
41 | img, label = self.prepare_supervised(img, label)
42 | return img, label
43 | 
44 | # this function is for normal supervised training
45 | def prepare_supervised(self, img, label):
46 | if self.purpose == 'train':
47 | # pad image
48 | img, coord = pad_and_or_crop(img, self.patch_size, mode='random')
49 | label, _ = pad_and_or_crop(label, self.patch_size, mode='fixed', coords=coord)
50 | # the image and label should be [batch, c, x, y, z]; this is the adaptation for using batchgenerators :)
51 | data_dict = {'data':img[None, None], 'seg':label[None, None]}
52 | tr_transforms = []
53 | tr_transforms.append(MirrorTransform((0, 1)))
54 | tr_transforms.append(RndTransform(SpatialTransform(self.patch_size, list(np.array(self.patch_size)//2),
55 | True, (100., 350.), (14., 17.),
56 | True, (0, 2.*np.pi), (-0.000001, 0.00001), (-0.000001, 0.00001),
57 | True, (0.7, 1.3), 'constant', 0, 3, 'constant', 0, 0,
58 | random_crop=False), prob=0.67, alternative_transform=RandomCropTransform(self.patch_size)))
59 | 
60 | train_transform = Compose(tr_transforms)
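# apply the same composed transform jointly to the image ('data') and the label ('seg')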
61 | data_dict = train_transform(**data_dict) 62 | img = data_dict.get('data')[0] 63 | label = data_dict.get('seg')[0] 64 | return img, label 65 | else: 66 | # pad image 67 | img, coord = pad_and_or_crop(img, self.patch_size, mode='centre') 68 | label, _ = pad_and_or_crop(label, self.patch_size, mode='fixed', coords=coord) 69 | return img[None], label[None] 70 | 71 | def __len__(self): 72 | return len(self.files) 73 | 74 | if __name__ == "__main__": 75 | import argparse 76 | parser = argparse.ArgumentParser() 77 | parser.add_argument("--data_dir", type=str, default="d:/data/mmwhs/test/") 78 | parser.add_argument("--patch_size", type=tuple, default=(256, 256)) 79 | parser.add_argument("--device", type=str, default='cpu') 80 | parser.add_argument("--classes", type=int, default=8) 81 | args = parser.parse_args() 82 | 83 | all_keys = np.arange(1001, 1011) 84 | train_dataset = MMWHS(keys=all_keys, purpose='val', args=args) 85 | train_dataloader = torch.utils.data.DataLoader(train_dataset, 86 | batch_size=20, 87 | shuffle=True, 88 | num_workers=8, 89 | drop_last=False) 90 | for batch_idx, tup in enumerate(train_dataloader): 91 | img, label = tup 92 | print(f'img shape:{img.shape}') 93 | print(f'label unique:{np.unique(label.numpy())}') 94 | plt.figure(1) 95 | img_grid = torchvision.utils.make_grid(img) 96 | matplotlib_imshow(img_grid, one_channel=False) 97 | plt.figure(2) 98 | img_grid = torchvision.utils.make_grid(label) 99 | matplotlib_imshow(img_grid, one_channel=False) 100 | plt.show() 101 | break -------------------------------------------------------------------------------- /dataset/generate_mmwhs.py: -------------------------------------------------------------------------------- 1 | import SimpleITK as sitk 2 | import os 3 | import numpy as np 4 | from utils import * 5 | import pickle 6 | from collections import OrderedDict 7 | from skimage.transform import resize 8 | 9 | def resize_image(image, old_spacing, new_spacing, order=3): 10 | new_shape = (int(np.round(old_spacing[0]/new_spacing[0]*float(image.shape[0]))), 11 | int(np.round(old_spacing[1]/new_spacing[1]*float(image.shape[1]))), 12 | int(np.round(old_spacing[2]/new_spacing[2]*float(image.shape[2])))) 13 | return resize(image, new_shape, order=order, mode='edge') 14 | 15 | 16 | def convert_to_one_hot(seg): 17 | vals = np.unique(seg) 18 | res = np.zeros([len(vals)] + list(seg.shape), seg.dtype) 19 | for c in range(len(vals)): 20 | res[c][seg == c] = 1 21 | return res 22 | 23 | def preprocess_image(itk_image, is_seg=False, spacing_target=(1, 0.5, 0.5), keep_z_spacing=False): 24 | spacing = np.array(itk_image.GetSpacing())[[2, 1, 0]] 25 | image = sitk.GetArrayFromImage(itk_image).astype(float) 26 | if keep_z_spacing: 27 | spacing_target = list(spacing_target) 28 | spacing_target[0] = spacing[0] 29 | if not is_seg: 30 | order_img = 3 31 | if not keep_z_spacing: 32 | order_img = 1 33 | image = resize_image(image, spacing, spacing_target, order=order_img).astype(np.float32) 34 | # image -= image.mean() 35 | # image /= image.std() 36 | else: 37 | tmp = convert_to_one_hot(image) 38 | vals = np.unique(image) 39 | results = [] 40 | for i in range(len(tmp)): 41 | results.append(resize_image(tmp[i].astype(float), spacing, spacing_target, 1)[None]) 42 | image = vals[np.vstack(results).argmax(0)] 43 | return image 44 | 45 | def generate_mmwhs_dataset(data_dir, save_dir): 46 | if not os.path.exists(save_dir): 47 | os.mkdir(save_dir) 48 | if not os.path.exists(os.path.join(save_dir, 'supervised')): 49 | os.mkdir(os.path.join(save_dir, 
'supervised'))
50 | i = 1001
51 | results = OrderedDict()
52 | for j in range(20):
53 | print(f'processing ct_train_{i}...')
54 | image_path = os.path.join(data_dir,'ct_train_'+str(i)+'_image.nii.gz')
55 | label_path = os.path.join(data_dir,'ct_train_'+str(i)+'_label.nii.gz')
56 | itk_image = sitk.ReadImage(image_path)
57 | itk_label = sitk.ReadImage(label_path)
58 | label_npy = sitk.GetArrayViewFromImage(itk_label)
59 | label_npy_copy = label_npy.copy()
60 | label_npy_copy[label_npy_copy==205]=1
61 | label_npy_copy[label_npy_copy==420]=2
62 | label_npy_copy[label_npy_copy==500]=3
63 | label_npy_copy[label_npy_copy==550]=4
64 | label_npy_copy[label_npy_copy==600]=5
65 | label_npy_copy[label_npy_copy==820]=6
66 | label_npy_copy[label_npy_copy==850]=7
67 | label_npy_copy[label_npy_copy>100]=0
68 | itk_label_copy = sitk.GetImageFromArray(label_npy_copy)
69 | itk_label_copy.SetSpacing(itk_label.GetSpacing())
70 | itk_label_copy.SetDirection(itk_label.GetDirection())
71 | print(f'image spacing:{itk_image.GetSpacing()}')
72 | print(f'original image size:{itk_image.GetSize()}')
73 | image_npy = preprocess_image(itk_image, is_seg=False, spacing_target=(1, 1.0, 1.0), keep_z_spacing=True)
74 | label_npy = preprocess_image(itk_label_copy, is_seg=True, spacing_target=(1, 1.0, 1.0), keep_z_spacing=True)
75 | print(f'resized image size:{image_npy.shape}')
76 | # convert label
77 | image_npy_copy = image_npy.copy().astype(np.int16)
78 | # clipping extreme pixel values increases the image contrast
79 | min_val_1p=np.percentile(image_npy_copy,1)
80 | max_val_99p=np.percentile(image_npy_copy,99)
81 | image_npy_copy[image_npy_copy<min_val_1p]=min_val_1p
82 | image_npy_copy[image_npy_copy>max_val_99p]=max_val_99p
83 | image_npy_copy[image_npy_copy<-2000]=-1110
84 | mean = image_npy_copy.astype(float).mean()
85 | std = image_npy_copy.astype(float).std()
86 | results['ct_train_'+str(i)] = {}
87 | results['ct_train_'+str(i)]['mean'] = mean
88 | results['ct_train_'+str(i)]['std'] = std
89 | print(f'after label_npy unique:{np.unique(label_npy)}')
90 | for n in range(image_npy.shape[0]):
91 | tmp_image = image_npy_copy[n,:,:]
92 | tmp_label = label_npy[n,:,:]
93 | all_data = np.stack([tmp_image, tmp_label],axis=0)
94 | if not os.path.exists(os.path.join(save_dir, 'supervised', 'ct_train_'+str(i))):
95 | os.mkdir(os.path.join(save_dir, 'supervised', 'ct_train_'+str(i)))
96 | save_path_image = os.path.join(save_dir, 'supervised', 'ct_train_'+str(i), 'frame%03d' % n)
97 | np.savez_compressed(save_path_image, data=all_data)
98 | i = i + 1
99 | 
100 | with open(os.path.join(save_dir, "mean_std.pkl"), 'wb') as f:
101 | pickle.dump(results, f)
102 | 
103 | if __name__ == "__main__":
104 | import argparse
105 | parser = argparse.ArgumentParser()
106 | parser.add_argument("-indir", help="folder where the extracted training data is", type=str, default='d:/data/mmwhs/ct_train1')
107 | parser.add_argument("-labeled_outdir", help="folder where to save the data for the 2d network", type=str, default='d:/data/mmwhs/test')
108 | args = parser.parse_args()
109 | generate_mmwhs_dataset(args.indir, args.labeled_outdir)
110 | 
111 | # example
112 | # python generate_mmwhs.py -indir /afs/crc.nd.edu/user/d/dzeng2/data/mmwhs/ct/raw_data -labeled_outdir /afs/crc.nd.edu/user/d/dzeng2/data/mmwhs/ct/test
113 | -------------------------------------------------------------------------------- /dataset/acdc.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from batchgenerators.utilities.file_and_folder_operations import *
4 | from batchgenerators.transforms.abstract_transforms import Compose, RndTransform
5 | from batchgenerators.transforms.spatial_transforms import SpatialTransform, MirrorTransform
6 | from batchgenerators.transforms.crop_and_pad_transforms import RandomCropTransform
7 | from torch.utils.data.dataset import Dataset
8 | from random import choice
9 | from .utils import *
10 | 
11 | class ACDC(Dataset):
12 | 
13 | def __init__(self, keys, purpose, args):
14 | self.data_dir = args.data_dir
15 | self.patch_size = args.patch_size
16 | self.purpose = purpose
17 | self.classes = args.classes
18 | self.do_contrast = args.do_contrast
19 | self.files = []
20 | if self.do_contrast:
21 | # we do not pre-load all data; instead, data are loaded in __getitem__
22 | self.slice_position = []
23 | self.partition = []
24 | self.slices = []
25 | for key in keys:
26 | frames = subfiles(join(self.data_dir, 'patient_%03d'%key), False, None, ".npy", True)
27 | for frame in frames:
28 | image = np.load(join(self.data_dir, 'patient_%03d'%key, frame))
29 | for i in range(image.shape[0]):
30 | self.files.append(join(self.data_dir, 'patient_%03d'%key, frame))
31 | self.slices.append(i)
32 | self.slice_position.append(float(i+1)/image.shape[0])
33 | part = image.shape[0] / 4.0
34 | if part - int(part) >= 0.5:
35 | part = int(part + 1)
36 | else:
37 | part = int(part)
38 | self.partition.append(max(0,min(int(i//part),3)+1))
39 | else:
40 | for key in keys:
41 | frames = subfiles(join(self.data_dir, 'patient_%03d'%key), False, None, ".npy", True)
42 | for frame in frames:
43 | image = np.load(join(self.data_dir, 'patient_%03d'%key, frame))
44 | for i in range(image.shape[1]):
45 | self.files.append(image[:,i])
46 | print(f'dataset length: {len(self.files)}')
47 | 
48 | def __getitem__(self, index):
49 | if not self.do_contrast:
50 | img = self.files[index][0].astype(np.float32)
51 | label = self.files[index][1]
52 | img, label = self.prepare_supervised(img, label)
53 | return img, label
54 | else:
55 | img = np.load(self.files[index]).astype(np.float32)[self.slices[index]]
56 | img1, img2 = self.prepare_contrast(img)
57 | return img1, img2, self.slice_position[index], self.partition[index]
58 | 
59 | # this function is for normal supervised training
60 | def prepare_supervised(self, img, label):
61 | if self.purpose == 'train':
62 | # resize image
63 | img, coord = pad_and_or_crop(img, self.patch_size, mode='random')
64 | label, _ = pad_and_or_crop(label, self.patch_size, mode='fixed', coords=coord)
65 | # the image and label should be [batch, c, x, y, z]; this is the adaptation for using batchgenerators :)
66 | data_dict = {'data':img[None, None], 'seg':label[None, None]}
67 | tr_transforms = []
68 | tr_transforms.append(MirrorTransform((0, 1)))
69 | tr_transforms.append(RndTransform(SpatialTransform(self.patch_size, list(np.array(self.patch_size)//2),
70 | True, (100., 350.), (14., 17.),
71 | True, (0, 2.*np.pi), (-0.000001, 0.00001), (-0.000001, 0.00001),
72 | True, (0.7, 1.3), 'constant', 0, 3, 'constant', 0, 0,
73 | random_crop=False), prob=0.67, alternative_transform=RandomCropTransform(self.patch_size)))
74 | 
75 | train_transform = Compose(tr_transforms)
76 | data_dict = train_transform(**data_dict)
77 | img = data_dict.get('data')[0]
78 | label = data_dict.get('seg')[0]
79 | return img, label
80 | else:
81 | # resize image
82 | img, coord = pad_and_or_crop(img, self.patch_size, mode='centre')
83 | label, _ = pad_and_or_crop(label, self.patch_size, mode='fixed', coords=coord)
84 | return img[None], label[None]
85 | 
86 | # use this function for contrastive learning
87 | def prepare_contrast(self, img):
88 | # resize image
89 | img, coord = pad_and_or_crop(img, self.patch_size, mode='random')
90 | # the image and label should be [batch, c, x, y, z]; this is the adaptation for using batchgenerators :)
91 | data_dict = {'data':img[None, None]}
92 | tr_transforms = []
93 | tr_transforms.append(MirrorTransform((0, 1)))
94 | tr_transforms.append(RndTransform(SpatialTransform(self.patch_size, list(np.array(self.patch_size)//2),
95 | True, (100., 350.), (14., 17.),
96 | True, (0, 2.*np.pi), (-0.000001, 0.00001), (-0.000001, 0.00001),
97 | True, (0.7, 1.3), 'constant', 0, 3, 'constant', 0, 0,
98 | random_crop=False), prob=0.67, alternative_transform=RandomCropTransform(self.patch_size)))
99 | 
100 | train_transform = Compose(tr_transforms)
101 | data_dict1 = train_transform(**data_dict)
102 | img1 = data_dict1.get('data')[0]
103 | data_dict2 = train_transform(**data_dict)
104 | img2 = data_dict2.get('data')[0]
105 | return img1, img2
106 | 
107 | def __len__(self):
108 | return len(self.files)
109 | 
110 | if __name__ == "__main__":
111 | import argparse
112 | parser = argparse.ArgumentParser()
113 | # parser.add_argument("--data_dir", type=str, default="d:/data/acdc/acdc_contrastive/contrastive/2d/")
114 | parser.add_argument("--data_dir", type=str, default="/afs/crc.nd.edu/user/d/dzeng2/data/acdc/acdc_contrastive/contrastive/2d/")
115 | parser.add_argument("--patch_size", type=tuple, default=(352, 352))
116 | parser.add_argument("--classes", type=int, default=4)
117 | parser.add_argument("--do_contrast", default=True, action='store_true')
118 | parser.add_argument("--slice_threshold", type=float, default=0.5)
119 | args = parser.parse_args()
120 | 
121 | 
122 | train_dataset = ACDC(keys=list(range(1,101)), purpose='train', args=args)
123 | train_dataloader = torch.utils.data.DataLoader(train_dataset,
124 | batch_size=32,
125 | shuffle=True,
126 | num_workers=8,
127 | drop_last=False)
128 | 
129 | pp = []
130 | for batch_idx, tup in enumerate(train_dataloader):
131 | print(f'the {batch_idx}th/{len(train_dataloader)} minibatch...')
132 | img1, img2, slice_position, partition = tup
133 | batch_size = img1.shape[0]
134 | # print(f'batch_size:{batch_size}, slice_position:{slice_position}')
135 | slice_position = slice_position.contiguous().view(-1, 1)
136 | mask = (torch.abs(slice_position.T.repeat(batch_size,1) - slice_position.repeat(1,batch_size)) < args.slice_threshold).float()
137 | # count how many positive pairs are in each batch
138 | for i in range(batch_size):
139 | pp.append(2*mask[i].sum()-1)
140 | pp = np.asarray(pp)
141 | pp_mean = np.mean(pp)
142 | pp_std = np.std(pp)
143 | print(f'average number of positive pairs mean:{pp_mean}, std:{pp_std}')
144 | -------------------------------------------------------------------------------- /dataset/chd.py: --------------------------------------------------------------------------------
1 | import pickle
2 | import numpy as np
3 | import torch
4 | import os
5 | from batchgenerators.utilities.file_and_folder_operations import *
6 | from batchgenerators.transforms.abstract_transforms import Compose, RndTransform
7 | from batchgenerators.transforms.spatial_transforms import SpatialTransform, MirrorTransform
8 | from batchgenerators.transforms.crop_and_pad_transforms import RandomCropTransform
9 | from torch.utils.data.dataset import Dataset
10 | from random import choice
11 | from .utils import *
12 | from sklearn.model_selection import KFold  # needed by get_split_chd below
13 | 
14 | class CHD(Dataset):
15 | 
16 | def __init__(self, keys, purpose, args):
17 | self.data_dir = args.data_dir
18 | self.patch_size = args.patch_size
19 | self.purpose = purpose
20 | self.classes = args.classes
21 | self.do_contrast = args.do_contrast
22 | self.files = []
23 | with open(os.path.join(self.data_dir, "mean_std.pkl"), 'rb') as f:
24 | mean_std = pickle.load(f)
25 | if self.do_contrast:
26 | # we do not pre-load all data; instead, data are loaded in __getitem__
27 | self.slice_position = []
28 | self.partition = []
29 | self.means = []
30 | self.stds = []
31 | for key in keys:
32 | frames = subfiles(join(self.data_dir, 'train', key), False, None, ".npy", True)
33 | frames.sort()
34 | i = 0
35 | for frame in frames:
36 | self.files.append(join(self.data_dir, 'train', key, frame))
37 | self.means.append(mean_std[key]['mean'])
38 | self.stds.append(mean_std[key]['std'])
39 | self.slice_position.append(float(i+1)/len(frames))
40 | part = len(frames) / 4.0
41 | if part - int(part) >= 0.5:
42 | part = int(part + 1)
43 | else:
44 | part = int(part)
45 | self.partition.append(max(0,min(int(i//part),3)+1))
46 | i = i + 1
47 | else:
48 | self.means = []
49 | self.stds = []
50 | for key in keys:
51 | frames = subfiles(join(self.data_dir, 'train', 'ct_'+str(key)), False, None, ".npz", True)
52 | frames.sort()
53 | for frame in frames:
54 | self.means.append(mean_std['ct_'+str(key)]['mean'])
55 | self.stds.append(mean_std['ct_'+str(key)]['std'])
56 | self.files.append(join(self.data_dir, 'train', 'ct_'+str(key), frame))
57 | print(f'dataset length: {len(self.files)}')
58 | 
59 | def __getitem__(self, index):
60 | if self.do_contrast:
61 | image = np.load(self.files[index]).astype(np.float32)
62 | # do preprocessing
63 | image -= self.means[index]
64 | image /= self.stds[index]
65 | img1, img2 = self.prepare_contrast(image)
66 | return img1, img2, self.slice_position[index], self.partition[index]
67 | else:
68 | all_data = np.load(self.files[index])['data']
69 | img = all_data[0].astype(np.float32)
70 | img -= self.means[index]
71 | img /= self.stds[index]
72 | label = all_data[1].astype(np.float32)
73 | img, label = self.prepare_supervised(img, label)
74 | return img, label
75 | 
76 | # this function is for normal supervised training
77 | def prepare_supervised(self, img, label):
78 | if self.purpose == 'train':
79 | # pad image
80 | img, coord = pad_and_or_crop(img, self.patch_size, mode='random')
81 | label, _ = pad_and_or_crop(label, self.patch_size, mode='fixed', coords=coord)
82 | # No augmentation is used in the finetuning because augmentation could hurt the performance.
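# (prepare_contrast below, used for pretraining, still applies heavy spatial augmentation)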
82 | return img[None], label[None] 83 | 84 | else: 85 | # resize image 86 | img, coord = pad_and_or_crop(img, self.patch_size, mode='centre') 87 | label, _ = pad_and_or_crop(label, self.patch_size, mode='fixed', coords=coord) 88 | return img[None], label[None] 89 | 90 | # use this function for contrastive learning 91 | def prepare_contrast(self, img): 92 | # resize image 93 | img, coord = pad_and_or_crop(img, self.patch_size, mode='random') 94 | # the image and label should be [batch, c, x, y, z], this is the adaptation for using batchgenerators :) 95 | data_dict = {'data':img[None, None]} 96 | tr_transforms = [] 97 | tr_transforms.append(MirrorTransform((0, 1))) 98 | tr_transforms.append(RndTransform(SpatialTransform(self.patch_size, list(np.array(self.patch_size)//2), 99 | True, (100., 350.), (14., 17.), 100 | True, (0, 2.*np.pi), (-0.000001, 0.00001), (-0.000001, 0.00001), 101 | True, (0.7, 1.3), 'constant', 0, 3, 'constant', 0, 0, 102 | random_crop=False), prob=0.67, alternative_transform=RandomCropTransform(self.patch_size))) 103 | 104 | train_transform = Compose(tr_transforms) 105 | data_dict1 = train_transform(**data_dict) 106 | img1 = data_dict1.get('data')[0] 107 | data_dict2 = train_transform(**data_dict) 108 | img2 = data_dict2.get('data')[0] 109 | return img1, img2 110 | 111 | def __len__(self): 112 | return len(self.files) 113 | 114 | def get_split_chd(data_dir, fold, seed=12345): 115 | # this is seeded, will be identical each time 116 | all_keys = np.arange(0, 50) 117 | cases = os.listdir(data_dir) 118 | cases.sort() 119 | i = 0 120 | for case in cases: 121 | all_keys[i] = int(case[-4:]) 122 | i = i + 1 123 | kf = KFold(n_splits=5, shuffle=True, random_state=seed) 124 | splits = kf.split(all_keys) 125 | for i, (train_idx, test_idx) in enumerate(splits): 126 | train_keys = all_keys[train_idx] 127 | test_keys = all_keys[test_idx] 128 | if i == fold: 129 | break 130 | return train_keys, test_keys 131 | 132 | if __name__ == "__main__": 133 | import argparse 134 | parser = argparse.ArgumentParser() 135 | parser.add_argument("--data_dir", type=str, default="/afs/crc.nd.edu/user/d/dzeng2/data/chd/preprocessed_without_label/") 136 | parser.add_argument("--patch_size", type=tuple, default=(512, 512)) 137 | parser.add_argument("--classes", type=int, default=8) 138 | parser.add_argument("--do_contrast", default=True, action='store_true') 139 | parser.add_argument("--slice_threshold", type=float, default=0.05) 140 | args = parser.parse_args() 141 | 142 | train_keys = os.listdir(os.path.join(args.data_dir,'train')) 143 | train_keys.sort() 144 | train_dataset = CHD(keys=train_keys, purpose='train', args=args) 145 | train_dataloader = torch.utils.data.DataLoader(train_dataset, 146 | batch_size=30, 147 | shuffle=True, 148 | num_workers=8, 149 | drop_last=False) 150 | 151 | pp = [] 152 | n = 0 153 | for batch_idx, tup in enumerate(train_dataloader): 154 | print(f'the {n}th minibatch...') 155 | img1, img2, slice_position, partition = tup 156 | batch_size = img1.shape[0] 157 | # print(f'batch_size:{batch_size}, slice_position:{slice_position}') 158 | slice_position = slice_position.contiguous().view(-1, 1) 159 | mask = (torch.abs(slice_position.T.repeat(batch_size,1) - slice_position.repeat(1,batch_size)) < args.slice_threshold).float() 160 | # count how many positive pairs there are in each batch 161 | for i in range(mask.shape[0]): 162 | pp.append(mask[i].sum()-1) 163 | n = n + 1 164 | if n > 100: 165 | break 166 | pp = np.asarray(pp) 167 | pp_mean = np.mean(pp) 168 | pp_std = np.std(pp) 169 |
print(f'mean:{pp_mean}, std:{pp_std}') -------------------------------------------------------------------------------- /train_supervised.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | from utils import * 4 | import random 5 | from network.unet2d import UNet2D 6 | from dataset.chd import CHD 7 | from dataset.acdc import ACDC 8 | from dataset.mmwhs import MMWHS 9 | from dataset.hvsmr import HVSMR 10 | import torch.nn.functional as F 11 | from metrics import SegmentationMetric 12 | from myconfig import get_config 13 | from batchgenerators.utilities.file_and_folder_operations import * 14 | from lr_scheduler import LR_Scheduler 15 | from torch.utils.tensorboard import SummaryWriter 16 | from experiment_log import PytorchExperimentLogger 17 | 18 | def run(fold, writer, args): 19 | 20 | maybe_mkdir_p(os.path.join(args.save_path, 'cross_val_'+str(fold))) 21 | logger = PytorchExperimentLogger(os.path.join(args.save_path, 'cross_val_'+str(fold)), "elog", ShowTerminal=True) 22 | # setup cuda 23 | args.device = torch.device(args.device if torch.cuda.is_available() else "cpu") 24 | logger.print(f"the model will run on device:{args.device}") 25 | torch.manual_seed(args.seed) 26 | if 'cuda' in str(args.device): 27 | torch.cuda.manual_seed_all(args.seed) 28 | logger.print(f"starting training for cross validation fold {fold} ...") 29 | model_result_dir = join(args.save_path, 'cross_val_'+str(fold), 'model') 30 | maybe_mkdir_p(model_result_dir) 31 | args.model_result_dir = model_result_dir 32 | # create model 33 | logger.print("creating model ...") 34 | model = UNet2D(in_channels=1, initial_filter_size=args.initial_filter_size, kernel_size=3, classes=args.classes, do_instancenorm=True) 35 | if args.restart: 36 | logger.print('loading from saved model ' + args.pretrained_model_path) 37 | checkpoint = torch.load(args.pretrained_model_path, 38 | map_location=lambda storage, loc: storage) 39 | save_model = checkpoint["net"] 40 | model_dict = model.state_dict() 41 | # we only need to load the parameters of the encoder 42 | state_dict = {k: v for k, v in save_model.items() if "encoder" in k} 43 | model_dict.update(state_dict) 44 | model.load_state_dict(model_dict) 45 | model.to(args.device) 46 | 47 | num_parameters = sum([l.nelement() for l in model.parameters()]) 48 | logger.print(f"number of parameters: {num_parameters}") 49 | 50 | if args.dataset == 'chd': 51 | train_keys, val_keys = get_split_chd(os.path.join(args.data_dir,'train'), fold, args.cross_vali_num) 52 | # now randomly sample train_keys 53 | if args.enable_few_data: 54 | random.seed(args.seed) 55 | train_keys = random.sample(list(train_keys), k=args.sampling_k) 56 | logger.print(f'train_keys:{train_keys}') 57 | logger.print(f'val_keys:{val_keys}') 58 | train_dataset = CHD(keys=train_keys, purpose='train', args=args) 59 | validate_dataset = CHD(keys=val_keys, purpose='val', args=args) 60 | elif args.dataset == 'mmwhs': 61 | train_keys, val_keys = get_split_mmwhs(fold, args.cross_vali_num) 62 | if args.enable_few_data: 63 | random.seed(args.seed) 64 | train_keys = random.sample(list(train_keys), k=args.sampling_k) 65 | logger.print(f'train_keys:{train_keys}') 66 | train_dataset = MMWHS(keys=train_keys, purpose='val', args=args) 67 | logger.print('training data dir '+train_dataset.data_dir) 68 | validate_dataset = MMWHS(keys=val_keys, purpose='val', args=args) 69 | elif args.dataset == 'acdc': 70 | train_keys, val_keys = get_split_acdc(fold, args.cross_vali_num) 71 | if
args.enable_few_data: 72 | random.seed(args.seed) 73 | train_keys = random.sample(list(train_keys), k=args.sampling_k) 74 | logger.print(f'train_keys:{train_keys}') 75 | logger.print(f'val_keys:{val_keys}') 76 | train_dataset = ACDC(keys=train_keys, purpose='train', args=args) 77 | validate_dataset = ACDC(keys=val_keys, purpose='val', args=args) 78 | elif args.dataset == 'hvsmr': 79 | train_keys, val_keys = get_split_hvsmr(fold, args.cross_vali_num) 80 | if args.enable_few_data: 81 | random.seed(args.seed) 82 | train_keys = random.sample(list(train_keys), k=args.sampling_k) 83 | logger.print(f'train_keys:{train_keys}') 84 | logger.print(f'val_keys:{val_keys}') 85 | train_dataset = HVSMR(keys=train_keys, purpose='train', args=args) 86 | validate_dataset = HVSMR(keys=val_keys, purpose='val', args=args) 87 | 88 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_works, drop_last=False) 89 | validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_works, drop_last=False) 90 | 91 | criterion = torch.nn.CrossEntropyLoss() 92 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=1e-5) 93 | scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs, len(train_loader), min_lr=args.min_lr) 94 | best_dice = 0 95 | for epoch in range(args.epochs): 96 | # train for one epoch 97 | train_loss, train_dice = train(train_loader, model, criterion, epoch, optimizer, scheduler, logger, args) 98 | writer.add_scalar('training_loss_fold'+str(fold), train_loss, epoch) 99 | writer.add_scalar('training_dice_fold'+str(fold), train_dice, epoch) 100 | writer.add_scalar('learning_rate_fold'+str(fold), optimizer.param_groups[0]['lr'], epoch) 101 | if (epoch % 2 == 0): 102 | # evaluate for one epoch 103 | val_dice = validate(validate_loader, model, epoch, logger, args) 104 | 105 | logger.print('Epoch: {0}\t' 106 | 'Training Loss {train_loss:.4f} \t' 107 | 'Validation Dice {val_dice:.4f} \t' 108 | .format(epoch, train_loss=train_loss, val_dice=val_dice)) 109 | 110 | if best_dice < val_dice: 111 | best_dice = val_dice 112 | save_dict = {"net": model.state_dict()} 113 | torch.save(save_dict, os.path.join(args.model_result_dir, "best.pth")) 114 | writer.add_scalar('validate_dice_fold'+str(fold), val_dice, epoch) 115 | writer.add_scalar('best_dice_fold'+str(fold), best_dice, epoch) 116 | # save model 117 | save_dict = {"net": model.state_dict()} 118 | torch.save(save_dict, os.path.join(args.model_result_dir, "latest.pth")) 119 | 120 | def train(data_loader, model, criterion, epoch, optimizer, scheduler, logger, args): 121 | model.train() 122 | metric_val = SegmentationMetric(args.classes) 123 | metric_val.reset() 124 | losses = AverageMeter() 125 | for batch_idx, tup in enumerate(data_loader): 126 | img, label = tup 127 | image_var = img.float().to(args.device) 128 | label = label.long().to(args.device) 129 | scheduler(optimizer, batch_idx, epoch) 130 | x_out = model(image_var) 131 | loss = criterion(x_out, label.squeeze(dim=1)) 132 | losses.update(loss.item(), image_var.size(0)) 133 | optimizer.zero_grad() 134 | loss.backward() 135 | optimizer.step() 136 | # Do softmax 137 | x_out = F.softmax(x_out, dim=1) 138 | metric_val.update(label.long().squeeze(dim=1), x_out) 139 | _, _, Dice = metric_val.get() 140 | logger.print(f"Training epoch:{epoch}, batch:{batch_idx}/{len(data_loader)}, lr:{optimizer.param_groups[0]['lr']:.6f}, 
loss:{losses.avg:.4f}, mean Dice:{Dice:.4f}") 141 | pixAcc, mIoU, mDice = metric_val.get() 142 | return losses.avg, mDice 143 | 144 | def validate(data_loader, model, epoch, logger, args): 145 | model.eval() 146 | metric_val = SegmentationMetric(args.classes) 147 | metric_val.reset() 148 | with torch.no_grad(): 149 | for batch_idx, tup in enumerate(data_loader): 150 | img, label = tup 151 | image_var = img.float().to(args.device) 152 | label = label.long().to(args.device) 153 | x_out = model(image_var) 154 | x_out = F.softmax(x_out, dim=1) 155 | metric_val.update(label.long().squeeze(dim=1), x_out) 156 | pixAcc, mIoU, Dice = metric_val.get() 157 | logger.print(f"Validation epoch:{epoch}, batch:{batch_idx}/{len(data_loader)}, mean Dice:{Dice}") 158 | pixAcc, mIoU, Dice = metric_val.get() 159 | return Dice 160 | 161 | if __name__ == '__main__': 162 | # initialize config 163 | args = get_config() 164 | if args.save == '': 165 | args.save = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') 166 | args.save_path = os.path.join(args.results_dir, args.experiment_name + args.save) 167 | if not os.path.exists(args.save_path): 168 | os.makedirs(args.save_path) 169 | writer = SummaryWriter(os.path.join(args.runs_dir, args.experiment_name + args.save)) 170 | for i in range(0, args.cross_vali_num): 171 | run(i, writer, args) -------------------------------------------------------------------------------- /network/unet2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | __all__ = ["UNet2D"] 6 | 7 | class InitWeights_He(object): 8 | def __init__(self, neg_slope=1e-2): 9 | self.neg_slope = neg_slope 10 | 11 | def __call__(self, module): 12 | if isinstance(module, nn.Conv3d) or isinstance(module, nn.Conv2d) or isinstance(module, 13 | nn.ConvTranspose2d) or isinstance( 14 | module, nn.ConvTranspose3d): 15 | module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope) 16 | if module.bias is not None: 17 | module.bias = nn.init.constant_(module.bias, 0) 18 | 19 | class encoder(nn.Module): 20 | def __init__(self, in_channels, initial_filter_size, kernel_size, do_instancenorm): 21 | super().__init__() 22 | self.contr_1_1 = self.contract(in_channels, initial_filter_size, kernel_size, instancenorm=do_instancenorm) 23 | self.contr_1_2 = self.contract(initial_filter_size, initial_filter_size, kernel_size, 24 | instancenorm=do_instancenorm) 25 | self.pool = nn.MaxPool2d(2, stride=2) 26 | 27 | self.contr_2_1 = self.contract(initial_filter_size, initial_filter_size * 2, kernel_size, 28 | instancenorm=do_instancenorm) 29 | self.contr_2_2 = self.contract(initial_filter_size * 2, initial_filter_size * 2, kernel_size, 30 | instancenorm=do_instancenorm) 31 | 32 | self.contr_3_1 = self.contract(initial_filter_size * 2, initial_filter_size * 2 ** 2, kernel_size, 33 | instancenorm=do_instancenorm) 34 | self.contr_3_2 = self.contract(initial_filter_size * 2 ** 2, initial_filter_size * 2 ** 2, kernel_size, 35 | instancenorm=do_instancenorm) 36 | 37 | self.contr_4_1 = self.contract(initial_filter_size * 2 ** 2, initial_filter_size * 2 ** 3, kernel_size, 38 | instancenorm=do_instancenorm) 39 | self.contr_4_2 = self.contract(initial_filter_size * 2 ** 3, initial_filter_size * 2 ** 3, kernel_size, 40 | instancenorm=do_instancenorm) 41 | self.center = nn.Sequential( 42 | nn.Conv2d(initial_filter_size * 2 ** 3, initial_filter_size * 2 ** 4, 3, padding=1), 43 | nn.ReLU(inplace=True), 44 | 
nn.Conv2d(initial_filter_size * 2 ** 4, initial_filter_size * 2 ** 4, 3, padding=1), 45 | nn.ReLU(inplace=True) 46 | ) 47 | 48 | def forward(self, x): 49 | contr_1 = self.contr_1_2(self.contr_1_1(x)) 50 | pool = self.pool(contr_1) 51 | 52 | contr_2 = self.contr_2_2(self.contr_2_1(pool)) 53 | pool = self.pool(contr_2) 54 | 55 | contr_3 = self.contr_3_2(self.contr_3_1(pool)) 56 | pool = self.pool(contr_3) 57 | 58 | contr_4 = self.contr_4_2(self.contr_4_1(pool)) 59 | pool = self.pool(contr_4) 60 | 61 | out = self.center(pool) 62 | return out, contr_4, contr_3, contr_2, contr_1 63 | 64 | @staticmethod 65 | def contract(in_channels, out_channels, kernel_size=3, instancenorm=True): 66 | if instancenorm: 67 | layer = nn.Sequential( 68 | nn.Conv2d(in_channels, out_channels, kernel_size, padding=1), 69 | nn.BatchNorm2d(out_channels), # note: BatchNorm2d is used here even though the flag is named instancenorm 70 | nn.LeakyReLU(inplace=True)) 71 | else: 72 | layer = nn.Sequential( 73 | nn.Conv2d(in_channels, out_channels, kernel_size, padding=1), 74 | nn.LeakyReLU(inplace=True)) 75 | return layer 76 | 77 | class decoder(nn.Module): 78 | def __init__(self, initial_filter_size, classes): 79 | super().__init__() 80 | # self.concat_weight = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True) 81 | self.upscale5 = nn.ConvTranspose2d(initial_filter_size * 2 ** 4, initial_filter_size * 2 ** 3, kernel_size=2, 82 | stride=2) 83 | self.expand_4_1 = self.expand(initial_filter_size * 2 ** 4, initial_filter_size * 2 ** 3) 84 | self.expand_4_2 = self.expand(initial_filter_size * 2 ** 3, initial_filter_size * 2 ** 3) 85 | self.upscale4 = nn.ConvTranspose2d(initial_filter_size * 2 ** 3, initial_filter_size * 2 ** 2, kernel_size=2, 86 | stride=2) 87 | 88 | self.expand_3_1 = self.expand(initial_filter_size * 2 ** 3, initial_filter_size * 2 ** 2) 89 | self.expand_3_2 = self.expand(initial_filter_size * 2 ** 2, initial_filter_size * 2 ** 2) 90 | self.upscale3 = nn.ConvTranspose2d(initial_filter_size * 2 ** 2, initial_filter_size * 2, 2, stride=2) 91 | 92 | self.expand_2_1 = self.expand(initial_filter_size * 2 ** 2, initial_filter_size * 2) 93 | self.expand_2_2 = self.expand(initial_filter_size * 2, initial_filter_size * 2) 94 | self.upscale2 = nn.ConvTranspose2d(initial_filter_size * 2, initial_filter_size, 2, stride=2) 95 | 96 | self.expand_1_1 = self.expand(initial_filter_size * 2, initial_filter_size) 97 | self.expand_1_2 = self.expand(initial_filter_size, initial_filter_size) 98 | self.head = nn.Sequential( 99 | nn.Conv2d(initial_filter_size, classes, kernel_size=1, 100 | stride=1, bias=False)) 101 | 102 | def forward(self, x, contr_4, contr_3, contr_2, contr_1): 103 | 104 | concat_weight = 1 105 | upscale = self.upscale5(x) 106 | crop = self.center_crop(contr_4, upscale.size()[2], upscale.size()[3]) 107 | concat = torch.cat([upscale, crop * concat_weight], 1) 108 | 109 | expand = self.expand_4_2(self.expand_4_1(concat)) 110 | upscale = self.upscale4(expand) 111 | 112 | crop = self.center_crop(contr_3, upscale.size()[2], upscale.size()[3]) 113 | concat = torch.cat([upscale, crop * concat_weight], 1) 114 | 115 | expand = self.expand_3_2(self.expand_3_1(concat)) 116 | upscale = self.upscale3(expand) 117 | 118 | crop = self.center_crop(contr_2, upscale.size()[2], upscale.size()[3]) 119 | concat = torch.cat([upscale, crop * concat_weight], 1) 120 | 121 | expand = self.expand_2_2(self.expand_2_1(concat)) 122 | upscale = self.upscale2(expand) 123 | 124 | crop = self.center_crop(contr_1, upscale.size()[2], upscale.size()[3]) 125 | concat = torch.cat([upscale, crop * concat_weight], 1) 126
| 127 | expand = self.expand_1_2(self.expand_1_1(concat)) 128 | 129 | out = self.head(expand) 130 | return out 131 | 132 | @staticmethod 133 | def center_crop(layer, target_width, target_height): 134 | batch_size, n_channels, layer_width, layer_height = layer.size() 135 | xy1 = (layer_width - target_width) // 2 136 | xy2 = (layer_height - target_height) // 2 137 | return layer[:, :, xy1:(xy1 + target_width), xy2:(xy2 + target_height)] 138 | 139 | @staticmethod 140 | def expand(in_channels, out_channels, kernel_size=3): 141 | layer = nn.Sequential( 142 | nn.Conv2d(in_channels, out_channels, kernel_size, padding=1), 143 | nn.BatchNorm2d(out_channels), 144 | nn.LeakyReLU(inplace=True), 145 | ) 146 | return layer 147 | 148 | class UNet2D(nn.Module): 149 | def __init__(self, in_channels=1, initial_filter_size=32, kernel_size=3, classes=4, do_instancenorm=True): 150 | super().__init__() 151 | 152 | self.encoder = encoder(in_channels, initial_filter_size, kernel_size, do_instancenorm) 153 | self.decoder = decoder(initial_filter_size, classes) 154 | 155 | self.apply(InitWeights_He(1e-2)) 156 | 157 | def forward(self, x): 158 | 159 | x_1, contr_4, contr_3, contr_2, contr_1 = self.encoder(x) 160 | out = self.decoder(x_1, contr_4, contr_3, contr_2, contr_1) 161 | 162 | return out 163 | 164 | class UNet2D_classification(nn.Module): 165 | def __init__(self, in_channels=1, initial_filter_size=32, kernel_size=3, classes=3, do_instancenorm=True): 166 | super().__init__() 167 | 168 | self.encoder = encoder(in_channels, initial_filter_size, kernel_size, do_instancenorm) 169 | 170 | self.head = nn.Sequential( 171 | nn.AdaptiveAvgPool2d((1, 1)), 172 | nn.Flatten(), 173 | nn.Linear(initial_filter_size * 2 ** 4, initial_filter_size * 2 ** 4), 174 | nn.ReLU(inplace=True), 175 | nn.Linear(initial_filter_size * 2 ** 4, classes) 176 | ) 177 | 178 | self.apply(InitWeights_He(1e-2)) 179 | 180 | def forward(self, x): 181 | 182 | x_1, _, _, _, _ = self.encoder(x) 183 | out = self.head(x_1) 184 | 185 | return out 186 | 187 | 188 | if __name__ == '__main__': 189 | model = UNet2D(in_channels=1, initial_filter_size=32, kernel_size=3, classes=3, do_instancenorm=True) 190 | input = torch.randn(5,1,160,160) 191 | out = model(input) 192 | print(f'out shape:{out.shape}') -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from batchgenerators.utilities.file_and_folder_operations import * 5 | from typing import Union, Tuple, List 6 | from scipy.ndimage.filters import gaussian_filter 7 | from sklearn.model_selection import KFold 8 | 9 | def get_split_chd(data_dir, fold, cross_vali_num, seed=12345): 10 | # this is seeded, will be identical each time 11 | all_keys = np.arange(0, 68) 12 | cases = os.listdir(data_dir) 13 | cases.sort() 14 | i = 0 15 | for case in cases: 16 | all_keys[i] = int(case[-4:]) 17 | i = i + 1 18 | kf = KFold(n_splits=cross_vali_num, shuffle=True, random_state=seed) 19 | splits = kf.split(all_keys) 20 | for i, (train_idx, test_idx) in enumerate(splits): 21 | train_keys = all_keys[train_idx] 22 | test_keys = all_keys[test_idx] 23 | if i == fold: 24 | break 25 | return train_keys, test_keys 26 | 27 | def get_split_mmwhs(fold, cross_vali_num, seed=12345): 28 | # this is seeded, will be identical each time 29 | all_keys = np.arange(1001, 1021) 30 | kf = KFold(n_splits=cross_vali_num, shuffle=True, random_state=seed) 31 | splits = 
kf.split(all_keys) 32 | for i, (train_idx, test_idx) in enumerate(splits): 33 | train_keys = all_keys[train_idx] 34 | test_keys = all_keys[test_idx] 35 | if i == fold: 36 | break 37 | return train_keys, test_keys 38 | 39 | def get_split_acdc(fold, cross_vali_num, seed=12345): 40 | # this is seeded, will be identical each time 41 | kf = KFold(n_splits=cross_vali_num, shuffle=True, random_state=seed) 42 | all_keys = np.arange(1, 101) 43 | splits = kf.split(all_keys) 44 | for i, (train_idx, test_idx) in enumerate(splits): 45 | train_keys = all_keys[train_idx] 46 | test_keys = all_keys[test_idx] 47 | if i == fold: 48 | break 49 | return train_keys, test_keys 50 | 51 | def get_split_hvsmr(fold, cross_vali_num, seed=12345): 52 | # this is seeded, will be identical each time 53 | kf = KFold(n_splits=cross_vali_num, shuffle=True, random_state=seed) 54 | all_keys = np.arange(0, 10) 55 | splits = kf.split(all_keys) 56 | for i, (train_idx, test_idx) in enumerate(splits): 57 | train_keys = all_keys[train_idx] 58 | test_keys = all_keys[test_idx] 59 | if i == fold: 60 | break 61 | return train_keys, test_keys 62 | 63 | def soft_dice(y_pred, y_true): 64 | # sum over axes 65 | axes = tuple([0] + list(range(2, len(y_pred.size())))) 66 | # y_pred is softmax output of shape (num_samples, num_classes) 67 | # y_true is the label that should be converted to one hot encoding of target (shape= (num_samples, num_classes)) 68 | y_onehot = torch.zeros(y_pred.shape).to(y_pred.device) 69 | y_onehot.scatter_(1, y_true, 1) 70 | intersect = (y_pred * y_onehot).sum(dim=axes) 71 | denominator = y_pred.sum(dim=axes) + y_onehot.sum(dim=axes) 72 | dice_scores = 2 * intersect / (denominator + 1e-6) 73 | # we do not count the background dice though 74 | return -1 * dice_scores[1:].mean() 75 | 76 | class AverageMeter(object): 77 | """Computes and stores the average and current value""" 78 | 79 | def __init__(self): 80 | self.reset() 81 | 82 | def reset(self): 83 | self.val = 0 84 | self.avg = 0 85 | self.sum = 0 86 | self.count = 0 87 | 88 | def update(self, val, n=1): 89 | self.val = val 90 | self.sum += val * n 91 | self.count += n 92 | self.avg = self.sum / self.count 93 | 94 | def pad_nd_image(image, new_shape=None, mode="constant", kwargs=None, return_slicer=False, shape_must_be_divisible_by=None): 95 | """ 96 | one padder to pad them all. Documentation? Well okay. A little bit 97 | 98 | :param image: nd image. can be anything 99 | :param new_shape: what shape do you want? new_shape does not have to have the same dimensionality as image. If 100 | len(new_shape) < len(image.shape) then the last axes of image will be padded. If new_shape < image.shape in any of 101 | the axes then we will not pad that axis, but also not crop! (interpret new_shape as new_min_shape) 102 | Example: 103 | image.shape = (10, 1, 512, 512); new_shape = (768, 768) -> result: (10, 1, 768, 768). Cool, huh? 104 | image.shape = (10, 1, 512, 512); new_shape = (364, 768) -> result: (10, 1, 512, 768). 105 | 106 | :param mode: see np.pad for documentation 107 | :param return_slicer: if True then this function will also return what coords you will need to use when cropping back 108 | to original shape 109 | :param shape_must_be_divisible_by: for network prediction. After applying new_shape, make sure the new shape is 110 | divisible by that number (can also be a list with an entry for each axis).
Whatever is missing to match that will 111 | be padded (so the result may be larger than new_shape if shape_must_be_divisible_by is not None) 112 | :param kwargs: see np.pad for documentation 113 | """ 114 | if kwargs is None: 115 | kwargs = {'constant_values': 0} 116 | 117 | if new_shape is not None: 118 | old_shape = np.array(image.shape[-len(new_shape):]) 119 | else: 120 | assert shape_must_be_divisible_by is not None 121 | assert isinstance(shape_must_be_divisible_by, (list, tuple, np.ndarray)) 122 | new_shape = image.shape[-len(shape_must_be_divisible_by):] 123 | old_shape = new_shape 124 | 125 | num_axes_nopad = len(image.shape) - len(new_shape) 126 | 127 | new_shape = [max(new_shape[i], old_shape[i]) for i in range(len(new_shape))] 128 | 129 | if not isinstance(new_shape, np.ndarray): 130 | new_shape = np.array(new_shape) 131 | 132 | if shape_must_be_divisible_by is not None: 133 | if not isinstance(shape_must_be_divisible_by, (list, tuple, np.ndarray)): 134 | shape_must_be_divisible_by = [shape_must_be_divisible_by] * len(new_shape) 135 | else: 136 | assert len(shape_must_be_divisible_by) == len(new_shape) 137 | 138 | for i in range(len(new_shape)): 139 | if new_shape[i] % shape_must_be_divisible_by[i] == 0: 140 | new_shape[i] -= shape_must_be_divisible_by[i] 141 | 142 | new_shape = np.array([new_shape[i] + shape_must_be_divisible_by[i] - new_shape[i] % shape_must_be_divisible_by[i] for i in range(len(new_shape))]) 143 | 144 | difference = new_shape - old_shape 145 | pad_below = difference // 2 146 | pad_above = difference // 2 + difference % 2 147 | pad_list = [[0, 0]]*num_axes_nopad + list([list(i) for i in zip(pad_below, pad_above)]) 148 | 149 | if not ((all([i == 0 for i in pad_below])) and (all([i == 0 for i in pad_above]))): 150 | res = np.pad(image, pad_list, mode, **kwargs) 151 | else: 152 | res = image 153 | 154 | if not return_slicer: 155 | return res 156 | else: 157 | pad_list = np.array(pad_list) 158 | pad_list[:, 1] = np.array(res.shape) - pad_list[:, 1] 159 | slicer = list(slice(*i) for i in pad_list) 160 | return res, slicer 161 | 162 | def compute_steps_for_sliding_window(patch_size: Tuple[int, ...], image_size: Tuple[int, ...], step_size: float) -> \ 163 | List[List[int]]: 164 | assert [i >= j for i, j in zip(image_size, patch_size)], "image size must be as large or larger than patch_size" 165 | assert 0 < step_size <= 1, 'step_size must be larger than 0 and smaller or equal to 1' 166 | 167 | # our step width is patch_size*step_size at most, but can be narrower. For example if we have image size of 168 | # 110, patch size of 32 and step_size of 0.5, then we want to make 6 steps starting at coordinates 0, 16, 31, 47, 62, 78 169 | target_step_sizes_in_voxels = [i * step_size for i in patch_size] 170 | 171 | num_steps = [int(np.ceil((i - k) / j)) + 1 for i, j, k in zip(image_size, target_step_sizes_in_voxels, patch_size)] 172 | 173 | steps = [] 174 | for dim in range(len(patch_size)): 175 | # the highest step value for this dimension is 176 | max_step_value = image_size[dim] - patch_size[dim] 177 | if num_steps[dim] > 1: 178 | actual_step_size = max_step_value / (num_steps[dim] - 1) 179 | else: 180 | actual_step_size = 99999999999 # does not matter because there is only one step at 0 181 | 182 | steps_here = [int(np.round(actual_step_size * i)) for i in range(num_steps[dim])] 183 | 184 | steps.append(steps_here) 185 | 186 | return steps 187 | 188 | def get_gaussian(patch_size, sigma_scale=1.
/ 8) -> np.ndarray: 189 | tmp = np.zeros(patch_size) 190 | center_coords = [i // 2 for i in patch_size] 191 | sigmas = [i * sigma_scale for i in patch_size] 192 | tmp[tuple(center_coords)] = 1 193 | gaussian_importance_map = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0) 194 | gaussian_importance_map = gaussian_importance_map / np.max(gaussian_importance_map) * 1 195 | gaussian_importance_map = gaussian_importance_map.astype(np.float32) 196 | 197 | # gaussian_importance_map cannot be 0, otherwise we may end up with nans! 198 | gaussian_importance_map[gaussian_importance_map == 0] = np.min( 199 | gaussian_importance_map[gaussian_importance_map != 0]) 200 | 201 | return gaussian_importance_map 202 | 203 | __optimizers = { 204 | 'SGD': torch.optim.SGD, 205 | 'ASGD': torch.optim.ASGD, 206 | 'Adam': torch.optim.Adam, 207 | 'Adamax': torch.optim.Adamax, 208 | 'Adagrad': torch.optim.Adagrad, 209 | 'Adadelta': torch.optim.Adadelta, 210 | 'Rprop': torch.optim.Rprop, 211 | 'RMSprop': torch.optim.RMSprop 212 | } -------------------------------------------------------------------------------- /dataset/generate_acdc.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
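# Example invocation (paths here are placeholders; the -i, -out_labeled and -out_unlabeled
# flags are defined in the __main__ block at the bottom of this file):
#   python dataset/generate_acdc.py -i /path/to/ACDC/training \
#       -out_labeled /path/to/acdc_2d_labeled -out_unlabeled /path/to/acdc_2d_unlabeled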
14 | 15 | # This preprocessing code is built on nnUNet: https://github.com/MIC-DKFZ/nnUNet 16 | # Images are saved as np.float 17 | 18 | import SimpleITK as sitk 19 | import os 20 | from multiprocessing import pool 21 | import pickle 22 | import numpy as np 23 | from skimage.transform import resize 24 | 25 | def resize_image(image, old_spacing, new_spacing, order=3): 26 | new_shape = (int(np.round(old_spacing[0]/new_spacing[0]*float(image.shape[0]))), 27 | int(np.round(old_spacing[1]/new_spacing[1]*float(image.shape[1]))), 28 | int(np.round(old_spacing[2]/new_spacing[2]*float(image.shape[2])))) 29 | return resize(image, new_shape, order=order, mode='edge') 30 | 31 | str_to_ind = {'DCM':0, 'HCM':1, 'MINF':2, 'NOR':3, 'RV':4} 32 | ind_to_str = {} 33 | for k in str_to_ind.keys(): 34 | ind_to_str[str_to_ind[k]] = k 35 | 36 | # def view_patient_raw_data(patient, width=400, height=400): 37 | # import batchviewer 38 | # a = [] 39 | # a.append(patient['ed_data'][None]) 40 | # a.append(patient['ed_gt'][None]) 41 | # a.append(patient['es_data'][None]) 42 | # a.append(patient['es_gt'][None]) 43 | # batchviewer.view_batch(np.vstack(a), width, height) 44 | 45 | def convert_to_one_hot(seg): 46 | vals = np.unique(seg) 47 | res = np.zeros([len(vals)] + list(seg.shape), seg.dtype) 48 | for c in range(len(vals)): 49 | res[c][seg == c] = 1 50 | return res 51 | 52 | 53 | def preprocess_image(itk_image, is_seg=False, spacing_target=(1, 0.5, 0.5), keep_z_spacing=False): 54 | spacing = np.array(itk_image.GetSpacing())[[2, 1, 0]] 55 | image = sitk.GetArrayFromImage(itk_image).astype(float) 56 | if keep_z_spacing: 57 | spacing_target = list(spacing_target) 58 | spacing_target[0] = spacing[0] 59 | if not is_seg: 60 | order_img = 3 61 | if not keep_z_spacing: 62 | order_img = 1 63 | image = resize_image(image, spacing, spacing_target, order=order_img).astype(np.float32) 64 | image -= image.mean() 65 | image /= image.std() 66 | else: 67 | tmp = convert_to_one_hot(image) 68 | vals = np.unique(image) 69 | results = [] 70 | for i in range(len(tmp)): 71 | results.append(resize_image(tmp[i].astype(float), spacing, spacing_target, 1)[None]) 72 | image = vals[np.vstack(results).argmax(0)] 73 | return image 74 | 75 | 76 | def load_dataset(ids=range(101), root_dir="/home/fabian/drives/E132-Projekte/ACDC/new_dataset_preprocessed_for_2D_v2/"): 77 | with open(os.path.join(root_dir, "patient_info.pkl"), 'rb') as f: 78 | patient_info = pickle.load(f) 79 | 80 | data = {} 81 | for i in ids: 82 | if os.path.isfile(os.path.join(root_dir, "pat_%03.0d.npy"%i)): 83 | a = np.load(os.path.join(root_dir, "pat_%03.0d.npy"%i), mmap_mode='r') 84 | data[i] = {} 85 | data[i]['height'] = patient_info[i]['height'] 86 | data[i]['weight'] = patient_info[i]['weight'] 87 | data[i]['pathology'] = patient_info[i]['pathology'] 88 | data[i]['ed_data'] = a[0, :] 89 | data[i]['ed_gt'] = a[1, :] 90 | data[i]['es_data'] = a[2, :] 91 | data[i]['es_gt'] = a[3, :] 92 | return data 93 | 94 | def process_patient(args): 95 | id, patient_info, folder, folder_out, keep_z_spc = args 96 | #print id 97 | # if id in [286, 288]: 98 | # return 99 | patient_folder = os.path.join(folder, "patient%03.0d"%id) 100 | if not os.path.isdir(patient_folder): 101 | return 102 | images = {} 103 | 104 | fname = os.path.join(patient_folder, "patient%03.0d_frame%02.0d.nii.gz" % (id, patient_info[id]['ed'])) 105 | if os.path.isfile(fname): 106 | images["ed"] = sitk.ReadImage(fname) 107 | fname = os.path.join(patient_folder, "patient%03.0d_frame%02.0d_gt.nii.gz" % (id, patient_info[id]['ed']))
108 | if os.path.isfile(fname): 109 | images["ed_seg"] = sitk.ReadImage(fname) 110 | fname = os.path.join(patient_folder, "patient%03.0d_frame%02.0d.nii.gz" % (id, patient_info[id]['es'])) 111 | if os.path.isfile(fname): 112 | images["es"] = sitk.ReadImage(fname) 113 | fname = os.path.join(patient_folder, "patient%03.0d_frame%02.0d_gt.nii.gz" % (id, patient_info[id]['es'])) 114 | if os.path.isfile(fname): 115 | images["es_seg"] = sitk.ReadImage(fname) 116 | 117 | print(f'{id}, {images["es_seg"].GetSpacing()}') 118 | 119 | for k in images.keys(): 120 | #print k 121 | images[k] = preprocess_image(images[k], is_seg=(k == "ed_seg" or k == "es_seg"), 122 | spacing_target=(10, 1.25, 1.25), keep_z_spacing=keep_z_spc) 123 | 124 | img_as_list = [] 125 | for k in ['ed', 'ed_seg', 'es', 'es_seg']: 126 | if k not in images.keys(): 127 | print(f'{id}, has missing key: {k}') 128 | img_as_list.append(images[k][None]) 129 | try: 130 | all_img = np.vstack(img_as_list) 131 | except: 132 | print(f'{id}, has a problem with spacings') 133 | os.mkdir(os.path.join(folder_out, "patient_%03.0d" % id)) 134 | all_img = np.vstack(img_as_list[:2]) 135 | np.save(os.path.join(folder_out, "patient_%03.0d" % id, 'frame_%02.0d'%patient_info[id]['ed']), all_img.astype(np.float32)) 136 | all_img = np.vstack(img_as_list[-2:]) 137 | np.save(os.path.join(folder_out, "patient_%03.0d" % id, 'frame_%02.0d'%patient_info[id]['es']), all_img.astype(np.float32)) 138 | 139 | def process_patient_video(args): 140 | id, patient_info, folder, folder_out, keep_z_spc = args 141 | #print id 142 | # if id in [286, 288]: 143 | # return 144 | patient_folder = os.path.join(folder, "patient%03.0d"%id) 145 | if not os.path.isdir(patient_folder): 146 | return 147 | 148 | fname = os.path.join(patient_folder, "patient%03.0d_4d.nii.gz" % (id)) 149 | if os.path.isfile(fname): 150 | images = sitk.ReadImage(fname) 151 | 152 | print(f'{id}, {images.GetSpacing()}') 153 | 154 | os.mkdir(os.path.join(folder_out, "patient_%03.0d" % id)) 155 | slices = images.GetSize()[3] 156 | for k in range(slices): 157 | image = preprocess_image(images[:,:,:,k], is_seg=False, spacing_target=(10, 1.25, 1.25), keep_z_spacing=keep_z_spc) 158 | np.save(os.path.join(folder_out, "patient_%03.0d" % id, 'frame_%02.0d'%k), image.astype(np.float32)) 159 | 160 | def generate_patient_info(folder): 161 | patient_info={} 162 | for id in range(101): 163 | fldr = os.path.join(folder, 'patient%03.0d'%id) 164 | if not os.path.isdir(fldr): 165 | print(f'could not find dir of patient, {id}') 166 | continue 167 | nfo = np.loadtxt(os.path.join(fldr, "Info.cfg"), dtype=str, delimiter=': ') 168 | patient_info[id] = {} 169 | patient_info[id]['ed'] = int(nfo[0, 1]) 170 | patient_info[id]['es'] = int(nfo[1, 1]) 171 | patient_info[id]['height'] = float(nfo[3, 1]) 172 | patient_info[id]['pathology'] = nfo[2, 1] 173 | patient_info[id]['weight'] = float(nfo[5, 1]) 174 | return patient_info 175 | 176 | 177 | def run_preprocessing_labeled(folder="/media/fabian/My Book/datasets/ACDC/training/", 178 | folder_out = "/media/fabian/DeepLearningData/datasets/ACDC_forReal_orig_Z/", keep_z_spacing=True): 179 | 180 | print('start processing labeled data...') 181 | patient_info = generate_patient_info(folder) 182 | 183 | if not os.path.isdir(folder_out): 184 | os.mkdir(folder_out) 185 | with open(os.path.join(folder_out, "patient_info.pkl"), 'wb') as f: 186 | pickle.dump(patient_info, f) 187 | 188 | # beware of z spacing!!! see process_patient for more info! 
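# The zip below fans the shared arguments out to the worker pool: each process_patient call
# receives one (id, patient_info, folder, folder_out, keep_z_spacing) tuple, and ids without
# a matching patient folder simply return early via the isdir check inside process_patient.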
189 | ids = range(101) 190 | p = pool.Pool(8) 191 | p.map(process_patient, zip(ids, [patient_info]*101, [folder]*101, [folder_out]*101, [keep_z_spacing]*101)) 192 | p.close() 193 | p.join() 194 | 195 | def run_preprocessing_unlabeled(folder="/media/fabian/My Book/datasets/ACDC/training/", 196 | folder_out = "/media/fabian/DeepLearningData/datasets/ACDC_forReal_orig_Z/", keep_z_spacing=True): 197 | 198 | print('start processing unlabeled data...') 199 | patient_info = generate_patient_info(folder) 200 | 201 | if not os.path.isdir(folder_out): 202 | os.mkdir(folder_out) 203 | with open(os.path.join(folder_out, "patient_info.pkl"), 'wb') as f: 204 | pickle.dump(patient_info, f) 205 | 206 | # beware of z spacing!!! see process_patient for more info! 207 | ids = range(101) 208 | p = pool.Pool(8) 209 | p.map(process_patient_video, zip(ids, [patient_info]*101, [folder]*101, [folder_out]*101, [keep_z_spacing]*101)) 210 | p.close() 211 | p.join() 212 | 213 | if __name__ == "__main__": 214 | import argparse 215 | parser = argparse.ArgumentParser() 216 | parser.add_argument("-i", help="folder where the extracted training data is", type=str) 217 | parser.add_argument("-out_labeled", help="folder where to save the data for the 2d network", type=str) 218 | parser.add_argument("-out_unlabeled", help="folder where to save the data for the 2d network", type=str) 219 | args = parser.parse_args() 220 | run_preprocessing_labeled(args.i, args.out_labeled, True) 221 | run_preprocessing_unlabeled(args.i, args.out_unlabeled, True) 222 | # run_preprocessing(args.i, args.out3d, False) -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import torch 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | class SegmentationMetric(object): 7 | """Computes pixAcc and mIoU metric scores""" 8 | def __init__(self, nclass): 9 | self.nclass = nclass 10 | self.lock = threading.Lock() 11 | self.reset() 12 | 13 | def update(self, labels, preds): 14 | def evaluate_worker(self, label, pred): 15 | correct, labeled = batch_pix_accuracy( 16 | pred, label) 17 | inter, union = batch_intersection_union( 18 | pred, label, self.nclass) 19 | with self.lock: 20 | self.total_correct += correct 21 | self.total_label += labeled 22 | self.total_inter += inter 23 | self.total_union += union 24 | return 25 | 26 | if isinstance(preds, torch.Tensor): 27 | evaluate_worker(self, labels, preds) 28 | elif isinstance(preds, (list, tuple)): 29 | threads = [threading.Thread(target=evaluate_worker, 30 | args=(self, label, pred), 31 | ) 32 | for (label, pred) in zip(labels, preds)] 33 | for thread in threads: 34 | thread.start() 35 | for thread in threads: 36 | thread.join() 37 | else: 38 | raise NotImplementedError 39 | 40 | def get(self, mode='mean'): 41 | pixAcc = 1.0 * self.total_correct / (np.spacing(1) + self.total_label) 42 | IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union) 43 | Dice = 2.0 * self.total_inter / (np.spacing(1) + self.total_union + self.total_inter) 44 | if mode=='mean': 45 | mIoU = IoU.mean() 46 | Dice = Dice.mean() 47 | return pixAcc, mIoU, Dice 48 | else: 49 | return pixAcc, IoU, Dice 50 | 51 | def reset(self): 52 | self.total_inter = 0 53 | self.total_union = 0 54 | self.total_correct = 0 55 | self.total_label = 0 56 | return 57 | 58 | def batch_pix_accuracy(output, target): 59 | """Batch Pixel Accuracy 60 | Args: 61 | predict: input 4D tensor 62 | target: label
3D tensor 63 | """ 64 | # predict = torch.max(output, 1)[1] 65 | predict = torch.argmax(output, dim=1) 66 | # predict = output 67 | 68 | # label: 0, 1, ..., nclass - 1 69 | # Note: 0 is background 70 | predict = predict.cpu().numpy().astype('int64') + 1 71 | target = target.cpu().numpy().astype('int64') + 1 72 | 73 | pixel_labeled = np.sum(target > 0) 74 | pixel_correct = np.sum((predict == target)*(target > 0)) 75 | assert pixel_correct <= pixel_labeled, \ 76 | "Correct area should be smaller than Labeled" 77 | return pixel_correct, pixel_labeled 78 | 79 | def batch_intersection_union(output, target, nclass): # e.g. nclass = 2 when only distinguishing background and organ 80 | """Batch Intersection of Union 81 | Args: 82 | predict: input 4D tensor # model output 83 | target: label 3D Tensor # label 84 | nclass: number of categories (int) # e.g. nclass = 2: background vs. organ 85 | """ 86 | predict = torch.max(output, dim=1)[1] # per-pixel predicted class 87 | # predict = output 88 | mini = 1 89 | maxi = nclass-1 # nclass = 2, maxi=1 90 | nbins = nclass-1 # nclass = 2, nbins=1 91 | 92 | # label is: 0, 1, 2, ..., nclass-1 93 | # Note: 0 is background 94 | predict = predict.cpu().numpy().astype('int64') 95 | target = target.cpu().numpy().astype('int64') 96 | 97 | predict = predict * (target >= 0).astype(predict.dtype) 98 | intersection = predict * (predict == target) # keep predictions that agree with the target (true positives) 99 | 100 | # areas of intersection and union 101 | area_inter, _ = np.histogram(intersection, bins=nbins, range=(mini, maxi)) # count intersection pixels per foreground class: TP 102 | area_pred, _ = np.histogram(predict, bins=nbins, range=(mini, maxi)) # count predicted pixels per foreground class: TP + FP 103 | area_lab, _ = np.histogram(target, bins=nbins, range=(mini, maxi)) # count target pixels per foreground class: TP + FN 104 | area_union = area_pred + area_lab - area_inter # union area: TP + FP + FN 105 | assert (area_inter <= area_union).all(), \ 106 | "Intersection area should be smaller than Union area" 107 | return area_inter, area_union 108 | 109 | # ref https://github.com/CSAILVision/sceneparsing/blob/master/evaluationCode/utils_eval.py 110 | def pixel_accuracy(im_pred, im_lab): 111 | 112 | im_pred = np.asarray(im_pred) 113 | im_lab = np.asarray(im_lab) 114 | 115 | # Remove classes from unlabeled pixels in gt image. 116 | # We should not penalize detections in unlabeled portions of the image. 117 | pixel_labeled = np.sum(im_lab > 0) 118 | pixel_correct = np.sum((im_pred == im_lab) * (im_lab > 0)) 119 | 120 | return pixel_correct, pixel_labeled 121 | 122 | def intersection_and_union(im_pred, im_lab, num_class): 123 | im_pred = np.asarray(im_pred) 124 | im_lab = np.asarray(im_lab) 125 | # Remove classes from unlabeled pixels in gt image. 126 | im_pred = im_pred * (im_lab > 0) 127 | # Compute area intersection: 128 | intersection = im_pred * (im_pred == im_lab) 129 | area_inter, _ = np.histogram(intersection, bins=num_class-1, 130 | range=(1, num_class - 1)) 131 | # Compute area union: 132 | area_pred, _ = np.histogram(im_pred, bins=num_class-1, 133 | range=(1, num_class - 1)) 134 | area_lab, _ = np.histogram(im_lab, bins=num_class-1, 135 | range=(1, num_class - 1)) 136 | area_union = area_pred + area_lab - area_inter 137 | return area_inter, area_union 138 | 139 | def _fast_hist(label_true, label_pred, n_class): 140 | mask = (label_true >= 0) & (label_true < n_class) 141 | hist = np.bincount( 142 | n_class * label_true[mask].astype(int) + 143 | label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class) 144 | return hist 145 | 146 | def label_accuracy_score(label_trues, label_preds, n_class): 147 | """Returns accuracy score evaluation result.
148 | - overall accuracy 149 | - mean accuracy 150 | - mean IU 151 | - fwavacc 152 | """ 153 | hist = np.zeros((n_class, n_class)) 154 | for lt, lp in zip(label_trues, label_preds): 155 | hist += _fast_hist(lt.flatten(), lp.flatten(), n_class) 156 | acc = np.diag(hist).sum() / hist.sum() 157 | acc_cls = np.diag(hist) / hist.sum(axis=1) 158 | acc_cls = np.nanmean(acc_cls) 159 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 160 | mean_iu = np.nanmean(iu) 161 | freq = hist.sum(axis=1) / hist.sum() 162 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 163 | return acc, acc_cls, mean_iu, fwavacc 164 | 165 | def rel_abs_vol_diff(y_true, y_pred): 166 | 167 | return np.abs((y_pred.sum()/y_true.sum() - 1)*100) 168 | 169 | def get_boundary(data, img_dim=2, shift=-1): 170 | data = data>0 171 | edge = np.zeros_like(data) 172 | for nn in range(img_dim): 173 | edge += ~(data ^ np.roll(~data,shift=shift,axis=nn)) 174 | return edge.astype(int) 175 | 176 | def numpy_dice(y_true, y_pred, axis=None, smooth=1.0): 177 | intersection = y_true*y_pred 178 | return (2. * intersection.sum(axis=axis) + smooth) / (np.sum(y_true, axis=axis) + np.sum(y_pred, axis=axis) + smooth) 179 | 180 | def dice_coefficient(input, target, smooth=1.0): 181 | assert smooth > 0, 'Smooth must be greater than 0.' 182 | 183 | probs = F.softmax(input, dim=1) 184 | 185 | encoded_target = probs.detach() * 0 186 | encoded_target.scatter_(1, target.unsqueeze(1), 1) 187 | encoded_target = encoded_target.float() 188 | 189 | num = probs * encoded_target # b, c, h, w -- p*g 190 | num = torch.sum(num, dim=3) # b, c, h 191 | num = torch.sum(num, dim=2) # b, c 192 | 193 | den1 = probs * probs # b, c, h, w -- p^2 194 | den1 = torch.sum(den1, dim=3) # b, c, h 195 | den1 = torch.sum(den1, dim=2) # b, c 196 | 197 | den2 = encoded_target * encoded_target # b, c, h, w -- g^2 198 | den2 = torch.sum(den2, dim=3) # b, c, h 199 | den2 = torch.sum(den2, dim=2) # b, c 200 | 201 | dice = (2 * num + smooth) / (den1 + den2 + smooth) # b, c 202 | 203 | return dice.mean().mean() 204 | 205 | def dice_iou_pytorch(outputs: torch.Tensor, labels: torch.Tensor, N_class): 206 | SMOOTH = 1e-5 207 | # You can comment out this line if you are passing tensors of equal shape 208 | # But if you are passing output from UNet or something it will most probably 209 | # be with the BATCH x 1 x H x W shape 210 | outputs = outputs.squeeze(dim=1).float() 211 | labels = labels.squeeze(dim=1).float() 212 | dice = torch.ones(N_class-1).float() 213 | iou = torch.ones(N_class-1).float() 214 | ## for test 215 | #outputs = torch.tensor([[1,1],[3,3]]).float() 216 | #labels = torch.tensor([[0, 1], [2, 3]]).float() 217 | 218 | for iter in range(1,N_class): ## ignore the background 219 | predict_temp = torch.eq(outputs, iter) 220 | label_temp = torch.eq(labels, iter) 221 | intersection = predict_temp & label_temp 222 | intersection = intersection.float().sum((1,2)) 223 | union_dice = (predict_temp.float().sum((1,2)) + label_temp.float().sum((1,2))) 224 | union_iou = (predict_temp | label_temp).float().sum((1,2)) 225 | # if intersection>0 and union>0: 226 | # dice_temp = (2*intersection)/(union) 227 | # else: 228 | # dice_temp = 0 229 | dice[iter-1] = ((2 * intersection + SMOOTH) / (union_dice + SMOOTH)).mean() 230 | iou[iter-1] = ((intersection + SMOOTH) / (union_iou + SMOOTH)).mean() 231 | return dice, iou 232 | 233 | if __name__ == '__main__': 234 | outputs = torch.zeros(5, 256, 256) 235 | labels = torch.LongTensor(5, 1, 256,
256).random_(0, 5) 236 | dice, iou = dice_iou_pytorch(outputs=outputs, labels=labels, N_class=5) 237 | print(f'dice:{dice}, iou:{iou}') 238 | 239 | -------------------------------------------------------------------------------- /dataset/augmentation.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numbers 3 | import random 4 | import torch 5 | import torchvision.transforms.functional as tf 6 | from PIL import Image, ImageOps 7 | import cv2 8 | import numpy as np 9 | # zbabby(2019/2/21) 10 | # All of the augmentations here operate on PIL images 11 | 12 | class Compose(object): 13 | def __init__(self, augmentations): 14 | self.augmentations = augmentations 15 | self.PIL2Numpy = False 16 | 17 | def __call__(self, img, mask): 18 | if isinstance(img, np.ndarray): 19 | img = Image.fromarray(img, mode="L") 20 | mask = Image.fromarray(mask, mode="L") 21 | self.PIL2Numpy = True 22 | 23 | assert img.size == mask.size 24 | for a in self.augmentations: 25 | img, mask = a(img, mask) 26 | 27 | if self.PIL2Numpy: 28 | img, mask = np.array(img), np.array(mask, dtype=np.uint8) 29 | 30 | return img, mask 31 | 32 | class ToTensor(object): 33 | def __call__(self, img, mask): 34 | return tf.to_tensor(img), torch.from_numpy(np.array(mask)).long() 35 | 36 | 37 | class Resize(object): 38 | def __init__(self, size): 39 | self.size = size 40 | 41 | def __call__(self, img, mask): 42 | return tf.resize(img,self.size), tf.resize(mask,self.size) 43 | 44 | 45 | class AdjustGamma(object): 46 | def __init__(self, gamma): 47 | self.gamma = gamma 48 | 49 | def __call__(self, img, mask): 50 | assert img.size == mask.size 51 | return tf.adjust_gamma(img, random.uniform(1, 1 + self.gamma)), mask 52 | 53 | 54 | class AdjustSaturation(object): 55 | def __init__(self, saturation): 56 | self.saturation = saturation 57 | 58 | def __call__(self, img, mask): 59 | assert img.size == mask.size 60 | return tf.adjust_saturation(img, 61 | random.uniform(1 - self.saturation, 62 | 1 + self.saturation)), mask 63 | class AdjustHue(object): 64 | def __init__(self, hue): 65 | self.hue = hue 66 | 67 | def __call__(self, img, mask): 68 | assert img.size == mask.size 69 | return tf.adjust_hue(img, random.uniform(-self.hue, 70 | self.hue)), mask 71 | 72 | class AdjustBrightness(object): 73 | def __init__(self, bf): 74 | self.bf = bf 75 | 76 | def __call__(self, img, mask): 77 | assert img.size == mask.size 78 | return tf.adjust_brightness(img, 79 | random.uniform(1 - self.bf, 80 | 1 + self.bf)), mask 81 | 82 | class AdjustContrast(object): 83 | def __init__(self, cf): 84 | self.cf = cf 85 | 86 | def __call__(self, img, mask): 87 | assert img.size == mask.size 88 | return tf.adjust_contrast(img, 89 | random.uniform(1 - self.cf, 90 | 1 + self.cf)), mask 91 | 92 | 93 | class RandomHorizontallyFlip(object): 94 | def __init__(self, p=0.5): 95 | self.p = p 96 | 97 | def __call__(self, img, mask): 98 | if random.random() < self.p: 99 | return ( 100 | img.transpose(Image.FLIP_LEFT_RIGHT), 101 | mask.transpose(Image.FLIP_LEFT_RIGHT), 102 | ) 103 | return (img, mask) 104 | 105 | 106 | class RandomVerticallyFlip(object): 107 | def __init__(self, p=0.5): 108 | self.p = p 109 | 110 | def __call__(self, img, mask): 111 | if random.random() < self.p: 112 | return ( 113 | img.transpose(Image.FLIP_TOP_BOTTOM), 114 | mask.transpose(Image.FLIP_TOP_BOTTOM), 115 | ) 116 | return (img, mask) 117 | 118 | 119 | class FreeScale(object): 120 | def __init__(self, size): 121 | self.size = tuple(reversed(size)) # size: (h, w) 122 | 123 | def
__call__(self, img, mask): 124 | assert img.size == mask.size 125 | return ( 126 | img.resize(self.size, Image.BILINEAR), 127 | mask.resize(self.size, Image.NEAREST), 128 | ) 129 | 130 | class RandomZoom(object): 131 | def __init__(self, size): 132 | self.size = tuple(reversed(size)) # size: (h, w) 133 | 134 | def __call__(self, img, mask): 135 | assert img.size == mask.size 136 | if random.random() < 0.5: 137 | new_size = (int(img.size[0]*self.size[0]), int(img.size[1]*self.size[1])) 138 | return ( 139 | img.resize(new_size, Image.BILINEAR), 140 | mask.resize(new_size, Image.NEAREST), 141 | ) 142 | return (img, mask) 143 | 144 | 145 | class RandomTranslate(object): 146 | def __init__(self, offset): 147 | self.offset = offset # tuple (delta_x, delta_y), 0~1 148 | 149 | def __call__(self, img, mask): 150 | assert img.size == mask.size 151 | x_offset = int((2 * (random.random() - 0.5) * self.offset[0])*img.size[0]) 152 | y_offset = int((2 * (random.random() - 0.5) * self.offset[1])*img.size[1]) 153 | 154 | x_crop_offset = x_offset 155 | y_crop_offset = y_offset 156 | if x_offset < 0: 157 | x_crop_offset = 0 158 | if y_offset < 0: 159 | y_crop_offset = 0 160 | 161 | cropped_img = tf.crop(img, 162 | y_crop_offset, 163 | x_crop_offset, 164 | img.size[1] - abs(y_offset), 165 | img.size[0] - abs(x_offset)) 166 | 167 | if x_offset >= 0 and y_offset >= 0: 168 | padding_tuple = (0, 0, x_offset, y_offset) 169 | 170 | elif x_offset >= 0 and y_offset < 0: 171 | padding_tuple = (0, abs(y_offset), x_offset, 0) 172 | 173 | elif x_offset < 0 and y_offset >= 0: 174 | padding_tuple = (abs(x_offset), 0, 0, y_offset) 175 | 176 | elif x_offset < 0 and y_offset < 0: 177 | padding_tuple = (abs(x_offset), abs(y_offset), 0, 0) 178 | 179 | return ( 180 | tf.pad(cropped_img, 181 | padding_tuple, 182 | padding_mode='reflect'), 183 | tf.affine(mask, 184 | translate=(-x_offset, -y_offset), 185 | scale=1.0, 186 | angle=0.0, 187 | shear=0.0, 188 | fillcolor=0)) 189 | 190 | class RandomRotate(object): 191 | def __init__(self, degree): 192 | self.degree = degree # -180 and 180 193 | 194 | def __call__(self, img, mask): 195 | rotate_degree = random.random() * 2 * self.degree - self.degree 196 | return ( 197 | tf.affine(img, 198 | translate=(0, 0), 199 | scale=1.0, 200 | angle=rotate_degree, 201 | resample=Image.NEAREST, 202 | fillcolor=(0, 0, 0) if len(img.size) == 3 else 0, 203 | shear=0.0), 204 | tf.affine(mask, 205 | translate=(0, 0), 206 | scale=1.0, 207 | angle=rotate_degree, 208 | resample=Image.NEAREST, 209 | fillcolor=0, 210 | shear=0.0)) 211 | 212 | class Scale(object): 213 | def __init__(self, size): 214 | self.size = size 215 | 216 | def __call__(self, img, mask): 217 | assert img.size == mask.size 218 | w, h = img.size 219 | if (w >= h and w == self.size) or (h >= w and h == self.size): 220 | return img, mask 221 | if w > h: 222 | ow = self.size 223 | oh = int(self.size * h / w) 224 | return ( 225 | img.resize((ow, oh), Image.BILINEAR), 226 | mask.resize((ow, oh), Image.NEAREST), 227 | ) 228 | else: 229 | oh = self.size 230 | ow = int(self.size * w / h) 231 | return ( 232 | img.resize((ow, oh), Image.BILINEAR), 233 | mask.resize((ow, oh), Image.NEAREST), 234 | ) 235 | 236 | 237 | class RandomCrop(object): 238 | def __init__(self, size, padding=0): 239 | if isinstance(size, numbers.Number): 240 | self.size = (int(size), int(size)) 241 | else: 242 | self.size = size 243 | self.padding = padding 244 | 245 | def __call__(self, img, mask): 246 | if self.padding > 0: 247 | img = ImageOps.expand(img, 
border=self.padding, fill=0) 248 | mask = ImageOps.expand(mask, border=self.padding, fill=0) 249 | 250 | assert img.size == mask.size 251 | w, h = img.size 252 | th, tw = self.size 253 | if w == tw and h == th: 254 | return img, mask 255 | if w < tw or h < th: 256 | return ( 257 | img.resize((tw, th), Image.BILINEAR), 258 | mask.resize((tw, th), Image.NEAREST), 259 | ) 260 | 261 | x1 = random.randint(0, w - tw) 262 | y1 = random.randint(0, h - th) 263 | return ( 264 | img.crop((x1, y1, x1 + tw, y1 + th)), 265 | mask.crop((x1, y1, x1 + tw, y1 + th)), 266 | ) 267 | 268 | 269 | class RandomSizedCrop(object): 270 | def __init__(self, size): 271 | self.size = size 272 | 273 | def __call__(self, img, mask): 274 | assert img.size == mask.size 275 | for attempt in range(10): 276 | area = img.size[0] * img.size[1] 277 | target_area = random.uniform(0.45, 1.0) * area 278 | aspect_ratio = random.uniform(0.5, 2) 279 | 280 | w = int(round(math.sqrt(target_area * aspect_ratio))) 281 | h = int(round(math.sqrt(target_area / aspect_ratio))) 282 | 283 | if random.random() < 0.5: 284 | w, h = h, w 285 | 286 | if w <= img.size[0] and h <= img.size[1]: 287 | x1 = random.randint(0, img.size[0] - w) 288 | y1 = random.randint(0, img.size[1] - h) 289 | 290 | img = img.crop((x1, y1, x1 + w, y1 + h)) 291 | mask = mask.crop((x1, y1, x1 + w, y1 + h)) 292 | assert img.size == (w, h) 293 | 294 | return ( 295 | img.resize((self.size, self.size), Image.BILINEAR), 296 | mask.resize((self.size, self.size), Image.NEAREST), 297 | ) 298 | 299 | # Notice, we must guarantee crop to the expected size 300 | scale = Scale(self.size) 301 | crop = CenterCrop(self.size) 302 | return crop(*scale(img, mask)) 303 | 304 | class CenterCrop(object): 305 | def __init__(self, size): 306 | if isinstance(size, numbers.Number): 307 | self.size = (int(size), int(size)) 308 | else: 309 | self.size = size 310 | 311 | def __call__(self, img, mask): 312 | assert img.size == mask.size 313 | w, h = img.size 314 | th, tw = self.size 315 | x1 = int(round((w - tw) / 2.)) 316 | y1 = int(round((h - th) / 2.)) 317 | return ( 318 | img.crop((x1, y1, x1 + tw, y1 + th)), 319 | mask.crop((x1, y1, x1 + tw, y1 + th)), 320 | ) 321 | 322 | class RandomSized(object): 323 | def __init__(self, size): 324 | self.size = size 325 | self.scale = Scale(self.size) 326 | self.crop = RandomCrop(self.size) 327 | 328 | def __call__(self, img, mask): 329 | assert img.size == mask.size 330 | 331 | w = int(random.uniform(0.5, 2) * img.size[0]) 332 | h = int(random.uniform(0.5, 2) * img.size[1]) 333 | 334 | img, mask = ( 335 | img.resize((w, h), Image.BILINEAR), 336 | mask.resize((w, h), Image.NEAREST), 337 | ) 338 | 339 | return self.crop(*self.scale(img, mask)) 340 | 341 | class Pad(object): 342 | """Pads the given PIL.Image on all sides with the given "pad" value""" 343 | 344 | def __init__(self, padding, fill=0): 345 | assert isinstance(padding, numbers.Number) 346 | assert isinstance(fill, numbers.Number) or isinstance(fill, str) or isinstance(fill, tuple) 347 | self.padding = padding 348 | self.fill = fill 349 | 350 | def __call__(self, img, mask): 351 | return (ImageOps.expand(img, border=self.padding, fill=self.fill), 352 | ImageOps.expand(mask, border=self.padding, fill=self.fill)) 353 | 354 | class GaussianBlur(object): 355 | # Implements Gaussian blur as described in the SimCLR paper 356 | def __init__(self, kernel_size, min=0.1, max=2.0): 357 | self.min = min 358 | self.max = max 359 | # kernel size is set to be 10% of the image height/width 360 | self.kernel_size = 
class SobelFilter(object):
    # Applies a Sobel edge filter to both the image and the mask
    def __init__(self, kernel_size):
        # kernel size is set to be 10% of the image height/width
        self.kernel_size = kernel_size

    def __call__(self, img, mask):
        img = np.array(img)
        mask = np.array(mask)

        # ksize must be passed by keyword: the fifth positional argument of
        # cv2.Sobel is dst, not the kernel size
        img_x = np.absolute(cv2.Sobel(img, cv2.CV_64F, 1, 0, ksize=self.kernel_size))
        img_y = np.absolute(cv2.Sobel(img, cv2.CV_64F, 0, 1, ksize=self.kernel_size))
        img = img_x + img_y
        img = (255 * (img - img.min()) / (img.max() - img.min())).astype(np.uint8)
        mask_x = np.absolute(cv2.Sobel(mask, cv2.CV_64F, 1, 0, ksize=self.kernel_size))
        mask_y = np.absolute(cv2.Sobel(mask, cv2.CV_64F, 0, 1, ksize=self.kernel_size))
        mask = mask_x + mask_y
        # rescale to uint8 as well, since Image.fromarray(mode='L') expects it;
        # the epsilon guards against an all-constant mask
        mask = (255 * (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)).astype(np.uint8)

        return (Image.fromarray(img, mode='L'), Image.fromarray(mask, mode='L'))


class RandomElasticTransform(object):
    def __init__(self, alpha=3, sigma=0.07, img_type='L'):
        # alpha and sigma are given as fractions of the image height
        self.alpha = alpha
        self.sigma = sigma
        self.img_type = img_type

    def _elastic_transform(self, img, mask):
        # convert to numpy
        img = np.array(img)   # h x w (x c)
        mask = np.array(mask)

        shape1 = img.shape

        alpha = self.alpha * shape1[0]
        sigma = self.sigma * shape1[0]

        # random displacement field: per-pixel uniform noise in [-1, 1],
        # Gaussian-smoothed with width sigma, then scaled by alpha
        x, y = np.meshgrid(np.arange(shape1[0]), np.arange(shape1[1]), indexing='ij')
        blur_size = int(4 * sigma) | 1  # or-ing with 1 forces an odd kernel size
        dx = cv2.GaussianBlur((np.random.rand(shape1[0], shape1[1]) * 2 - 1), ksize=(blur_size, blur_size), sigmaX=sigma) * alpha
        dy = cv2.GaussianBlur((np.random.rand(shape1[0], shape1[1]) * 2 - 1), ksize=(blur_size, blur_size), sigmaX=sigma) * alpha

        map_x = (x + dx).astype(np.float32)
        map_y = (y + dy).astype(np.float32)
        # convert to fixed-point maps, which makes cv2.remap faster
        map_x, map_y = cv2.convertMaps(map_x, map_y, cv2.CV_16SC2)

        # the image is interpolated bilinearly, the mask with nearest neighbour
        img = cv2.remap(img, map_y, map_x, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT).reshape(shape1)
        mask = cv2.remap(mask, map_y, map_x, interpolation=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT).reshape(shape1)

        return (Image.fromarray(img, mode=self.img_type), Image.fromarray(mask, mode='L'))

    def __call__(self, img, mask):
        """Elastic deformation of images as described in [Simard2003]_.

        .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
           Convolutional Neural Networks applied to Visual Document Analysis",
           in Proc. of the International Conference on Document Analysis and
           Recognition, 2003.
        """
        # apply the deformation with a 50% chance
        if random.random() < 0.5:
            return self._elastic_transform(img, mask)
        else:
            return (img, mask)
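# --- Editor's sketch (assumed shapes; not part of the repo's pipeline) ---
# What _elastic_transform computes, stated compactly, following [Simard2003]:
# sample noise in [-1, 1] per pixel, smooth it with a Gaussian of width sigma,
# scale by alpha, and add the result to the sampling grid before remapping.
def _elastic_displacement_demo(h, w, alpha=3.0, sigma=0.07):
    # alpha and sigma are fractions of the image height, as in the class above
    alpha, sigma = alpha * h, sigma * h
    blur_size = int(4 * sigma) | 1  # force an odd kernel size
    dx = cv2.GaussianBlur(np.random.rand(h, w) * 2 - 1, (blur_size, blur_size), sigma) * alpha
    dy = cv2.GaussianBlur(np.random.rand(h, w) * 2 - 1, (blur_size, blur_size), sigma) * alpha
    return dx, dy  # per-pixel displacements added to the meshgrid coordinates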
# imports for the SimpleITK-based smoothing utility below
import cv2
import numpy as np
import SimpleITK as sitk

def smooth_images(imgs, t_step=0.125, n_iter=5):
    """
    Curvature-driven image denoising.
    In my experience, this helps significantly with segmentation.
    """
    # filter each 2D slice independently with SimpleITK's CurvatureFlow
    for mm in range(len(imgs)):
        img = sitk.GetImageFromArray(imgs[mm])
        img = sitk.CurvatureFlow(image1=img,
                                 timeStep=t_step,
                                 numberOfIterations=n_iter)

        imgs[mm] = sitk.GetArrayFromImage(img)

    return imgs
--------------------------------------------------------------------------------
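Note: smooth_images denoises a stack of 2D slices in place and returns it. A minimal usage sketch, with a random array standing in for real image data:

import numpy as np
# stand-in for (num_slices, h, w) image data loaded from a scan
volume = np.random.rand(4, 64, 64).astype(np.float32)
volume = smooth_images(volume, t_step=0.125, n_iter=5)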